new file: compiler.rs
new file: primitives.rs new file: store.rs new file: support.rs
This commit is contained in:
parent
96a783ae5a
commit
c86ad6e5ac
4 changed files with 213 additions and 0 deletions
83
src/compiler.rs
Normal file
83
src/compiler.rs
Normal file
|
|
@ -0,0 +1,83 @@
|
||||||
|
use crate::ast::{ExprAst, Prim};
|
||||||
|
use crate::dsl::{Composition, Expr, Function, Junction};
|
||||||
|
use crate::primitives::{add_pair, cos_f, id, sin_f, triple};
|
||||||
|
use crate::support::Arc;
|
||||||
|
|
||||||
|
/// Runtime type for a compiled AST node.
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub enum CompiledExpr {
|
||||||
|
/// f64 -> f64
|
||||||
|
F64(Arc<dyn Expr<f64, Out = f64> + Send + Sync + 'static>),
|
||||||
|
/// f64 -> (f64, f64)
|
||||||
|
Pair(Arc<dyn Expr<f64, Out = (f64, f64)> + Send + Sync + 'static>),
|
||||||
|
/// (f64, f64) -> f64
|
||||||
|
PairToF64(Arc<dyn Expr<(f64, f64), Out = f64> + Send + Sync + 'static>),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CompiledExpr {
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compile_ast(ast: &ExprAst) -> Result<CompiledExpr, String> {
|
||||||
|
match ast {
|
||||||
|
ExprAst::Atom(prim) => compile_atom(prim),
|
||||||
|
ExprAst::Composition(first, second) => compile_composition(first, second),
|
||||||
|
ExprAst::Junction(left, right) => compile_junction(left, right),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn compile_to_scalar(
|
||||||
|
ast: &ExprAst,
|
||||||
|
) -> Result<Arc<dyn Expr<f64, Out = f64> + Send + Sync + 'static>, String> {
|
||||||
|
match compile_ast(ast)? {
|
||||||
|
CompiledExpr::F64(expr) => Ok(expr),
|
||||||
|
_ => Err("expected AST that evaluates to f64".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_atom(prim: &Prim) -> Result<CompiledExpr, String> {
|
||||||
|
let compiled = match prim {
|
||||||
|
Prim::Sin => CompiledExpr::F64(Arc::new(Function { f: sin_f })),
|
||||||
|
Prim::Cos => CompiledExpr::F64(Arc::new(Function { f: cos_f })),
|
||||||
|
Prim::Triple => CompiledExpr::F64(Arc::new(Function { f: triple })),
|
||||||
|
Prim::AddPair => CompiledExpr::PairToF64(Arc::new(Function { f: add_pair })),
|
||||||
|
Prim::Input => CompiledExpr::F64(Arc::new(Function { f: id })),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(compiled)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_composition(
|
||||||
|
first: &ExprAst,
|
||||||
|
second: &ExprAst,
|
||||||
|
) -> Result<CompiledExpr, String> {
|
||||||
|
let compiled_first = compile_ast(first)?;
|
||||||
|
let compiled_second = compile_ast(second)?;
|
||||||
|
|
||||||
|
match (compiled_first, compiled_second) {
|
||||||
|
// f64 -> f64 -> f64
|
||||||
|
(CompiledExpr::F64(f1), CompiledExpr::F64(f2)) => Ok(CompiledExpr::F64(Arc::new(
|
||||||
|
Composition { first: f1, second: f2 },
|
||||||
|
))),
|
||||||
|
// f64 -> f64 -> (f64,f64)
|
||||||
|
(CompiledExpr::F64(f1), CompiledExpr::Pair(f2)) => Ok(CompiledExpr::Pair(Arc::new(
|
||||||
|
Composition { first: f1, second: f2 },
|
||||||
|
))),
|
||||||
|
// f64 -> (f64,f64) -> f64
|
||||||
|
(CompiledExpr::Pair(f1), CompiledExpr::PairToF64(f2)) => Ok(CompiledExpr::F64(
|
||||||
|
Arc::new(Composition { first: f1, second: f2 }),
|
||||||
|
)),
|
||||||
|
_ => Err("type mismatch in composition".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn compile_junction(left: &ExprAst, right: &ExprAst) -> Result<CompiledExpr, String> {
|
||||||
|
let compiled_left = compile_ast(left)?;
|
||||||
|
let compiled_right = compile_ast(right)?;
|
||||||
|
|
||||||
|
match (compiled_left, compiled_right) {
|
||||||
|
(CompiledExpr::F64(l), CompiledExpr::F64(r)) => Ok(CompiledExpr::Pair(Arc::new(
|
||||||
|
Junction { left: l, right: r },
|
||||||
|
))),
|
||||||
|
_ => Err("junction requires two f64-returning branches".to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
22
src/primitives.rs
Normal file
22
src/primitives.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
||||||
|
// Reusable primitive functions used by both the DSL and the AST compiler.
|
||||||
|
pub fn sin_f(x: f64) -> f64 {
|
||||||
|
x.sin()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn cos_f(x: f64) -> f64 {
|
||||||
|
x.cos()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn triple(x: f64) -> f64 {
|
||||||
|
3.0 * x
|
||||||
|
}
|
||||||
|
|
||||||
|
/// (a,b) -> a + b
|
||||||
|
pub fn add_pair(p: (f64, f64)) -> f64 {
|
||||||
|
p.0 + p.1
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Identity for the AST `Input` primitive.
|
||||||
|
pub fn id(x: f64) -> f64 {
|
||||||
|
x
|
||||||
|
}
|
||||||
69
src/store.rs
Normal file
69
src/store.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
use crate::ast::{hash_ast, ExprAst};
|
||||||
|
use crate::compiler::{compile_ast, compile_to_scalar, CompiledExpr};
|
||||||
|
use crate::support::Arc;
|
||||||
|
|
||||||
|
/// In-memory content-addressed store keyed by AST hash.
|
||||||
|
#[derive(Default)]
|
||||||
|
pub struct ExprStore {
|
||||||
|
compiled: HashMap<u64, Arc<CompiledExpr>>,
|
||||||
|
asts: HashMap<u64, ExprAst>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ExprStore {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an AST, compile it once, and return its hash.
|
||||||
|
/// If the hash already exists, the cached entry is reused.
|
||||||
|
pub fn insert(&mut self, ast: ExprAst) -> Result<u64, String> {
|
||||||
|
let hash = hash_ast(&ast);
|
||||||
|
if !self.compiled.contains_key(&hash) {
|
||||||
|
let compiled = compile_ast(&ast)?;
|
||||||
|
self.compiled.insert(hash, Arc::new(compiled));
|
||||||
|
self.asts.insert(hash, ast);
|
||||||
|
}
|
||||||
|
Ok(hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch a compiled expression by hash.
|
||||||
|
pub fn get(&self, hash: u64) -> Option<Arc<CompiledExpr>> {
|
||||||
|
self.compiled.get(&hash).cloned()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Fetch the original AST by hash.
|
||||||
|
pub fn get_ast(&self, hash: u64) -> Option<&ExprAst> {
|
||||||
|
self.asts.get(&hash)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Evaluate a stored scalar expression for the provided input.
|
||||||
|
pub fn eval_scalar(&self, hash: u64, input: f64) -> Result<f64, String> {
|
||||||
|
let compiled = self
|
||||||
|
.get(hash)
|
||||||
|
.ok_or_else(|| format!("hash {:016x} not found", hash))?;
|
||||||
|
|
||||||
|
match compiled.as_ref() {
|
||||||
|
CompiledExpr::F64(expr) => Ok(expr.eval(input)),
|
||||||
|
CompiledExpr::Pair(_) => Err("expression returns a pair, not f64".into()),
|
||||||
|
CompiledExpr::PairToF64(_) => Err("expression expects a pair input, not f64".into()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Insert an AST that must evaluate to f64 -> f64, returning its hash and compiled form.
|
||||||
|
pub fn insert_scalar(
|
||||||
|
&mut self,
|
||||||
|
ast: ExprAst,
|
||||||
|
) -> Result<
|
||||||
|
(
|
||||||
|
u64,
|
||||||
|
Arc<dyn crate::dsl::Expr<f64, Out = f64> + Send + Sync + 'static>,
|
||||||
|
),
|
||||||
|
String,
|
||||||
|
> {
|
||||||
|
let hash = self.insert(ast.clone())?;
|
||||||
|
let expr = compile_to_scalar(&ast)?;
|
||||||
|
Ok((hash, expr))
|
||||||
|
}
|
||||||
|
}
|
||||||
39
src/support.rs
Normal file
39
src/support.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
||||||
|
// Platform helpers to ease future no_std portability.
|
||||||
|
// The crate currently depends on `std`, but concentrating platform-specific
|
||||||
|
// pieces here makes migration simpler.
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use std::sync::Arc;
|
||||||
|
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub fn run_parallel<L, R, OutL, OutR>(left: L, right: R) -> (OutL, OutR)
|
||||||
|
where
|
||||||
|
L: FnOnce() -> OutL + Send + 'static,
|
||||||
|
R: FnOnce() -> OutR + Send + 'static,
|
||||||
|
OutL: Send + 'static,
|
||||||
|
OutR: Send + 'static,
|
||||||
|
{
|
||||||
|
let h1 = std::thread::spawn(left);
|
||||||
|
let h2 = std::thread::spawn(right);
|
||||||
|
(h1.join().unwrap(), h2.join().unwrap())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
mod no_std_support {
|
||||||
|
extern crate alloc;
|
||||||
|
pub use alloc::sync::Arc;
|
||||||
|
|
||||||
|
pub fn run_parallel<L, R, OutL, OutR>(left: L, right: R) -> (OutL, OutR)
|
||||||
|
where
|
||||||
|
L: FnOnce() -> OutL + Send + 'static,
|
||||||
|
R: FnOnce() -> OutR + Send + 'static,
|
||||||
|
OutL: Send + 'static,
|
||||||
|
OutR: Send + 'static,
|
||||||
|
{
|
||||||
|
// No threads available: fall back to sequential execution.
|
||||||
|
(left(), right())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(feature = "std"))]
|
||||||
|
pub use no_std_support::*;
|
||||||
Loading…
Add table
Add a link
Reference in a new issue