diff --git a/src/compiler.rs b/src/compiler.rs new file mode 100644 index 0000000..589931b --- /dev/null +++ b/src/compiler.rs @@ -0,0 +1,83 @@ +use crate::ast::{ExprAst, Prim}; +use crate::dsl::{Composition, Expr, Function, Junction}; +use crate::primitives::{add_pair, cos_f, id, sin_f, triple}; +use crate::support::Arc; + +/// Runtime type for a compiled AST node. +#[derive(Clone)] +pub enum CompiledExpr { + /// f64 -> f64 + F64(Arc + Send + Sync + 'static>), + /// f64 -> (f64, f64) + Pair(Arc + Send + Sync + 'static>), + /// (f64, f64) -> f64 + PairToF64(Arc + Send + Sync + 'static>), +} + +impl CompiledExpr { +} + +pub fn compile_ast(ast: &ExprAst) -> Result { + match ast { + ExprAst::Atom(prim) => compile_atom(prim), + ExprAst::Composition(first, second) => compile_composition(first, second), + ExprAst::Junction(left, right) => compile_junction(left, right), + } +} + +pub fn compile_to_scalar( + ast: &ExprAst, +) -> Result + Send + Sync + 'static>, String> { + match compile_ast(ast)? { + CompiledExpr::F64(expr) => Ok(expr), + _ => Err("expected AST that evaluates to f64".to_string()), + } +} + +fn compile_atom(prim: &Prim) -> Result { + let compiled = match prim { + Prim::Sin => CompiledExpr::F64(Arc::new(Function { f: sin_f })), + Prim::Cos => CompiledExpr::F64(Arc::new(Function { f: cos_f })), + Prim::Triple => CompiledExpr::F64(Arc::new(Function { f: triple })), + Prim::AddPair => CompiledExpr::PairToF64(Arc::new(Function { f: add_pair })), + Prim::Input => CompiledExpr::F64(Arc::new(Function { f: id })), + }; + + Ok(compiled) +} + +fn compile_composition( + first: &ExprAst, + second: &ExprAst, +) -> Result { + let compiled_first = compile_ast(first)?; + let compiled_second = compile_ast(second)?; + + match (compiled_first, compiled_second) { + // f64 -> f64 -> f64 + (CompiledExpr::F64(f1), CompiledExpr::F64(f2)) => Ok(CompiledExpr::F64(Arc::new( + Composition { first: f1, second: f2 }, + ))), + // f64 -> f64 -> (f64,f64) + (CompiledExpr::F64(f1), CompiledExpr::Pair(f2)) => Ok(CompiledExpr::Pair(Arc::new( + Composition { first: f1, second: f2 }, + ))), + // f64 -> (f64,f64) -> f64 + (CompiledExpr::Pair(f1), CompiledExpr::PairToF64(f2)) => Ok(CompiledExpr::F64( + Arc::new(Composition { first: f1, second: f2 }), + )), + _ => Err("type mismatch in composition".to_string()), + } +} + +fn compile_junction(left: &ExprAst, right: &ExprAst) -> Result { + let compiled_left = compile_ast(left)?; + let compiled_right = compile_ast(right)?; + + match (compiled_left, compiled_right) { + (CompiledExpr::F64(l), CompiledExpr::F64(r)) => Ok(CompiledExpr::Pair(Arc::new( + Junction { left: l, right: r }, + ))), + _ => Err("junction requires two f64-returning branches".to_string()), + } +} diff --git a/src/primitives.rs b/src/primitives.rs new file mode 100644 index 0000000..9b7652a --- /dev/null +++ b/src/primitives.rs @@ -0,0 +1,22 @@ +// Reusable primitive functions used by both the DSL and the AST compiler. +pub fn sin_f(x: f64) -> f64 { + x.sin() +} + +pub fn cos_f(x: f64) -> f64 { + x.cos() +} + +pub fn triple(x: f64) -> f64 { + 3.0 * x +} + +/// (a,b) -> a + b +pub fn add_pair(p: (f64, f64)) -> f64 { + p.0 + p.1 +} + +/// Identity for the AST `Input` primitive. +pub fn id(x: f64) -> f64 { + x +} diff --git a/src/store.rs b/src/store.rs new file mode 100644 index 0000000..f14e0f5 --- /dev/null +++ b/src/store.rs @@ -0,0 +1,69 @@ +use std::collections::HashMap; + +use crate::ast::{hash_ast, ExprAst}; +use crate::compiler::{compile_ast, compile_to_scalar, CompiledExpr}; +use crate::support::Arc; + +/// In-memory content-addressed store keyed by AST hash. +#[derive(Default)] +pub struct ExprStore { + compiled: HashMap>, + asts: HashMap, +} + +impl ExprStore { + pub fn new() -> Self { + Self::default() + } + + /// Insert an AST, compile it once, and return its hash. + /// If the hash already exists, the cached entry is reused. + pub fn insert(&mut self, ast: ExprAst) -> Result { + let hash = hash_ast(&ast); + if !self.compiled.contains_key(&hash) { + let compiled = compile_ast(&ast)?; + self.compiled.insert(hash, Arc::new(compiled)); + self.asts.insert(hash, ast); + } + Ok(hash) + } + + /// Fetch a compiled expression by hash. + pub fn get(&self, hash: u64) -> Option> { + self.compiled.get(&hash).cloned() + } + + /// Fetch the original AST by hash. + pub fn get_ast(&self, hash: u64) -> Option<&ExprAst> { + self.asts.get(&hash) + } + + /// Evaluate a stored scalar expression for the provided input. + pub fn eval_scalar(&self, hash: u64, input: f64) -> Result { + let compiled = self + .get(hash) + .ok_or_else(|| format!("hash {:016x} not found", hash))?; + + match compiled.as_ref() { + CompiledExpr::F64(expr) => Ok(expr.eval(input)), + CompiledExpr::Pair(_) => Err("expression returns a pair, not f64".into()), + CompiledExpr::PairToF64(_) => Err("expression expects a pair input, not f64".into()), + } + } + + /// Insert an AST that must evaluate to f64 -> f64, returning its hash and compiled form. + pub fn insert_scalar( + &mut self, + ast: ExprAst, + ) -> Result< + ( + u64, + Arc + Send + Sync + 'static>, + ), + String, + > { + let hash = self.insert(ast.clone())?; + let expr = compile_to_scalar(&ast)?; + Ok((hash, expr)) + } +} diff --git a/src/support.rs b/src/support.rs new file mode 100644 index 0000000..ccec14b --- /dev/null +++ b/src/support.rs @@ -0,0 +1,39 @@ +// Platform helpers to ease future no_std portability. +// The crate currently depends on `std`, but concentrating platform-specific +// pieces here makes migration simpler. + +#[cfg(feature = "std")] +pub use std::sync::Arc; + +#[cfg(feature = "std")] +pub fn run_parallel(left: L, right: R) -> (OutL, OutR) +where + L: FnOnce() -> OutL + Send + 'static, + R: FnOnce() -> OutR + Send + 'static, + OutL: Send + 'static, + OutR: Send + 'static, +{ + let h1 = std::thread::spawn(left); + let h2 = std::thread::spawn(right); + (h1.join().unwrap(), h2.join().unwrap()) +} + +#[cfg(not(feature = "std"))] +mod no_std_support { + extern crate alloc; + pub use alloc::sync::Arc; + + pub fn run_parallel(left: L, right: R) -> (OutL, OutR) + where + L: FnOnce() -> OutL + Send + 'static, + R: FnOnce() -> OutR + Send + 'static, + OutL: Send + 'static, + OutR: Send + 'static, + { + // No threads available: fall back to sequential execution. + (left(), right()) + } +} + +#[cfg(not(feature = "std"))] +pub use no_std_support::*;