compiler and basic analyzing ast
This commit is contained in:
parent
f393c583b8
commit
96a783ae5a
4 changed files with 151 additions and 28 deletions
|
|
@ -4,3 +4,7 @@ version = "0.1.0"
|
||||||
edition = "2024"
|
edition = "2024"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
|
|
||||||
|
[features]
|
||||||
|
default = ["std"]
|
||||||
|
std = []
|
||||||
|
|
|
||||||
82
src/ast.rs
82
src/ast.rs
|
|
@ -10,7 +10,13 @@ pub enum Prim {
|
||||||
Input, // переменная x
|
Input, // переменная x
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
impl Prim {
|
||||||
|
pub const fn is_commutative(&self) -> bool {
|
||||||
|
matches!(self, Prim::AddPair)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub enum ExprAst {
|
pub enum ExprAst {
|
||||||
Atom(Prim),
|
Atom(Prim),
|
||||||
Composition(Box<ExprAst>, Box<ExprAst>), // g ∘ f
|
Composition(Box<ExprAst>, Box<ExprAst>), // g ∘ f
|
||||||
|
|
@ -26,3 +32,77 @@ pub fn hash_ast(e: &ExprAst) -> u64 {
|
||||||
e.hash(&mut h);
|
e.hash(&mut h);
|
||||||
h.finish()
|
h.finish()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl ExprAst {
|
||||||
|
fn fingerprint(&self) -> u64 {
|
||||||
|
let mut h = DefaultHasher::new();
|
||||||
|
self.hash(&mut h);
|
||||||
|
h.finish()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PartialEq for ExprAst {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
match (self, other) {
|
||||||
|
(ExprAst::Atom(p1), ExprAst::Atom(p2)) => p1 == p2,
|
||||||
|
(ExprAst::Composition(a1, b1), ExprAst::Composition(a2, b2)) => {
|
||||||
|
match (b1.as_ref(), b2.as_ref()) {
|
||||||
|
(ExprAst::Atom(p1), ExprAst::Atom(p2))
|
||||||
|
if p1.is_commutative() && p2.is_commutative() =>
|
||||||
|
{
|
||||||
|
if p1 != p2 {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
match (a1.as_ref(), a2.as_ref()) {
|
||||||
|
(ExprAst::Junction(l1, r1), ExprAst::Junction(l2, r2)) => {
|
||||||
|
(l1 == l2 && r1 == r2) || (l1 == r2 && r1 == l2)
|
||||||
|
}
|
||||||
|
_ => a1 == a2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => a1 == a2 && b1 == b2,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
(ExprAst::Junction(l1, r1), ExprAst::Junction(l2, r2)) => l1 == l2 && r1 == r2,
|
||||||
|
_ => false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Eq for ExprAst {}
|
||||||
|
|
||||||
|
impl Hash for ExprAst {
|
||||||
|
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||||
|
match self {
|
||||||
|
ExprAst::Atom(p) => {
|
||||||
|
0u8.hash(state);
|
||||||
|
p.hash(state);
|
||||||
|
}
|
||||||
|
ExprAst::Composition(a, b) => {
|
||||||
|
if let ExprAst::Atom(p) = b.as_ref() {
|
||||||
|
if p.is_commutative() {
|
||||||
|
// Canonicalize inputs for commutative operations.
|
||||||
|
1u8.hash(state);
|
||||||
|
p.hash(state);
|
||||||
|
if let ExprAst::Junction(l, r) = a.as_ref() {
|
||||||
|
let hl = l.fingerprint();
|
||||||
|
let hr = r.fingerprint();
|
||||||
|
let (lo, hi) = if hl <= hr { (hl, hr) } else { (hr, hl) };
|
||||||
|
lo.hash(state);
|
||||||
|
hi.hash(state);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
1u8.hash(state);
|
||||||
|
a.hash(state);
|
||||||
|
b.hash(state);
|
||||||
|
}
|
||||||
|
ExprAst::Junction(l, r) => {
|
||||||
|
2u8.hash(state);
|
||||||
|
l.hash(state);
|
||||||
|
r.hash(state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
26
src/dsl.rs
26
src/dsl.rs
|
|
@ -1,5 +1,4 @@
|
||||||
use std::sync::Arc;
|
use crate::support::{run_parallel, Arc};
|
||||||
use std::thread;
|
|
||||||
|
|
||||||
// Обобщённое выражение: I -> Out
|
// Обобщённое выражение: I -> Out
|
||||||
pub trait Expr<I> {
|
pub trait Expr<I> {
|
||||||
|
|
@ -7,6 +6,18 @@ pub trait Expr<I> {
|
||||||
fn eval(&self, x: I) -> Self::Out;
|
fn eval(&self, x: I) -> Self::Out;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Позволяет использовать Arc<T> там, где требуется Expr.
|
||||||
|
impl<I, T> Expr<I> for Arc<T>
|
||||||
|
where
|
||||||
|
T: Expr<I> + ?Sized,
|
||||||
|
{
|
||||||
|
type Out = T::Out;
|
||||||
|
|
||||||
|
fn eval(&self, x: I) -> Self::Out {
|
||||||
|
(**self).eval(x)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// атомарная функция
|
// атомарная функция
|
||||||
pub struct Function<I, O> {
|
pub struct Function<I, O> {
|
||||||
pub f: fn(I) -> O,
|
pub f: fn(I) -> O,
|
||||||
|
|
@ -39,7 +50,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
// junction — две ветки параллельно
|
// junction — две ветки параллельно
|
||||||
pub struct Junction<F1, F2> {
|
pub struct Junction<F1: ?Sized, F2: ?Sized> {
|
||||||
pub left: Arc<F1>,
|
pub left: Arc<F1>,
|
||||||
pub right: Arc<F2>,
|
pub right: Arc<F2>,
|
||||||
}
|
}
|
||||||
|
|
@ -47,8 +58,8 @@ pub struct Junction<F1, F2> {
|
||||||
impl<I, F1, F2> Expr<I> for Junction<F1, F2>
|
impl<I, F1, F2> Expr<I> for Junction<F1, F2>
|
||||||
where
|
where
|
||||||
I: Copy + Send + 'static,
|
I: Copy + Send + 'static,
|
||||||
F1: Expr<I> + Send + Sync + 'static,
|
F1: Expr<I> + Send + Sync + 'static + ?Sized,
|
||||||
F2: Expr<I> + Send + Sync + 'static,
|
F2: Expr<I> + Send + Sync + 'static + ?Sized,
|
||||||
F1::Out: Send + 'static,
|
F1::Out: Send + 'static,
|
||||||
F2::Out: Send + 'static,
|
F2::Out: Send + 'static,
|
||||||
{
|
{
|
||||||
|
|
@ -61,9 +72,6 @@ where
|
||||||
let x1 = x;
|
let x1 = x;
|
||||||
let x2 = x;
|
let x2 = x;
|
||||||
|
|
||||||
let h1 = thread::spawn(move || l.eval(x1));
|
run_parallel(move || l.eval(x1), move || r.eval(x2))
|
||||||
let h2 = thread::spawn(move || r.eval(x2));
|
|
||||||
|
|
||||||
(h1.join().unwrap(), h2.join().unwrap())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
67
src/main.rs
67
src/main.rs
|
|
@ -1,33 +1,23 @@
|
||||||
mod ast;
|
mod ast;
|
||||||
|
mod compiler;
|
||||||
mod dsl;
|
mod dsl;
|
||||||
|
mod primitives;
|
||||||
|
mod store;
|
||||||
|
mod support;
|
||||||
|
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use ast::{
|
use ast::{
|
||||||
ExprAst,
|
ExprAst,
|
||||||
ExprAst::{Atom, Composition as C, Junction as J},
|
ExprAst::{Atom, Composition as C, Junction as J},
|
||||||
Prim::{AddPair, Cos, Input, Sin, Triple},
|
Prim::{AddPair, Cos, Sin, Triple},
|
||||||
ast_input, hash_ast,
|
ast_input, hash_ast,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use compiler::compile_to_scalar;
|
||||||
use dsl::{Composition, Expr, Function, Junction};
|
use dsl::{Composition, Expr, Function, Junction};
|
||||||
|
use primitives::{add_pair, cos_f, sin_f, triple};
|
||||||
// ========= тестовые функции =========
|
use store::ExprStore;
|
||||||
|
|
||||||
fn sin_f(x: f64) -> f64 {
|
|
||||||
x.sin()
|
|
||||||
}
|
|
||||||
fn cos_f(x: f64) -> f64 {
|
|
||||||
x.cos()
|
|
||||||
}
|
|
||||||
fn triple(x: f64) -> f64 {
|
|
||||||
3.0 * x
|
|
||||||
}
|
|
||||||
|
|
||||||
/// (a,b) -> a + b
|
|
||||||
fn add_pair(p: (f64, f64)) -> f64 {
|
|
||||||
p.0 + p.1
|
|
||||||
}
|
|
||||||
|
|
||||||
// AST для sin(cos(x) + 3*x)
|
// AST для sin(cos(x) + 3*x)
|
||||||
fn pupa() -> ExprAst {
|
fn pupa() -> ExprAst {
|
||||||
|
|
@ -95,4 +85,45 @@ fn main() {
|
||||||
println!("equal? {}", ast1 == ast2);
|
println!("equal? {}", ast1 == ast2);
|
||||||
println!("hash1 = {:016x}", hash_ast(&ast1));
|
println!("hash1 = {:016x}", hash_ast(&ast1));
|
||||||
println!("hash2 = {:016x}", hash_ast(&ast2));
|
println!("hash2 = {:016x}", hash_ast(&ast2));
|
||||||
|
|
||||||
|
// --- компиляция AST -> DSL и вычисление ---
|
||||||
|
|
||||||
|
let compiled = compile_to_scalar(&ast).expect("AST should compile to f64 -> f64");
|
||||||
|
let compiled_y = compiled.eval(x);
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"compiled eval: sin(cos({}) + 3*{}) = {}",
|
||||||
|
x, x, compiled_y
|
||||||
|
);
|
||||||
|
|
||||||
|
// --- контент-адресуемое хранилище выражений ---
|
||||||
|
|
||||||
|
let mut store = ExprStore::new();
|
||||||
|
let (stored_hash, stored_expr) = store
|
||||||
|
.insert_scalar(ast.clone())
|
||||||
|
.expect("store insert should compile");
|
||||||
|
let store_eval = stored_expr.eval(x);
|
||||||
|
let stored_ast = store
|
||||||
|
.get_ast(stored_hash)
|
||||||
|
.expect("AST should be retrievable by hash");
|
||||||
|
let store_eval_via_lookup = store
|
||||||
|
.eval_scalar(stored_hash, x)
|
||||||
|
.expect("store should return scalar result");
|
||||||
|
let compiled_from_store = store
|
||||||
|
.get(stored_hash)
|
||||||
|
.expect("compiled expression should be cached");
|
||||||
|
|
||||||
|
println!(
|
||||||
|
"store eval (hash {:016x}): sin(cos({}) + 3*{}) = {}",
|
||||||
|
stored_hash, x, x, store_eval
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"store eval via lookup (hash {:016x}): sin(cos({}) + 3*{}) = {}",
|
||||||
|
stored_hash, x, x, store_eval_via_lookup
|
||||||
|
);
|
||||||
|
println!(
|
||||||
|
"compiled cached? {}",
|
||||||
|
matches!(compiled_from_store.as_ref(), compiler::CompiledExpr::F64(_))
|
||||||
|
);
|
||||||
|
println!("store keeps AST equal to original? {}", stored_ast == &ast);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue