Compare commits
2 commits
f393c583b8
...
c86ad6e5ac
| Author | SHA1 | Date | |
|---|---|---|---|
| c86ad6e5ac | |||
| 96a783ae5a |
8 changed files with 364 additions and 28 deletions
|
|
@ -4,3 +4,7 @@ version = "0.1.0"
|
|||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
|
||||
[features]
|
||||
default = ["std"]
|
||||
std = []
|
||||
|
|
|
|||
82
src/ast.rs
82
src/ast.rs
|
|
@ -10,7 +10,13 @@ pub enum Prim {
|
|||
Input, // переменная x
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
|
||||
impl Prim {
|
||||
pub const fn is_commutative(&self) -> bool {
|
||||
matches!(self, Prim::AddPair)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ExprAst {
|
||||
Atom(Prim),
|
||||
Composition(Box<ExprAst>, Box<ExprAst>), // g ∘ f
|
||||
|
|
@ -26,3 +32,77 @@ pub fn hash_ast(e: &ExprAst) -> u64 {
|
|||
e.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
|
||||
impl ExprAst {
|
||||
fn fingerprint(&self) -> u64 {
|
||||
let mut h = DefaultHasher::new();
|
||||
self.hash(&mut h);
|
||||
h.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialEq for ExprAst {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(ExprAst::Atom(p1), ExprAst::Atom(p2)) => p1 == p2,
|
||||
(ExprAst::Composition(a1, b1), ExprAst::Composition(a2, b2)) => {
|
||||
match (b1.as_ref(), b2.as_ref()) {
|
||||
(ExprAst::Atom(p1), ExprAst::Atom(p2))
|
||||
if p1.is_commutative() && p2.is_commutative() =>
|
||||
{
|
||||
if p1 != p2 {
|
||||
return false;
|
||||
}
|
||||
match (a1.as_ref(), a2.as_ref()) {
|
||||
(ExprAst::Junction(l1, r1), ExprAst::Junction(l2, r2)) => {
|
||||
(l1 == l2 && r1 == r2) || (l1 == r2 && r1 == l2)
|
||||
}
|
||||
_ => a1 == a2,
|
||||
}
|
||||
}
|
||||
_ => a1 == a2 && b1 == b2,
|
||||
}
|
||||
}
|
||||
(ExprAst::Junction(l1, r1), ExprAst::Junction(l2, r2)) => l1 == l2 && r1 == r2,
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Eq for ExprAst {}
|
||||
|
||||
impl Hash for ExprAst {
|
||||
fn hash<H: Hasher>(&self, state: &mut H) {
|
||||
match self {
|
||||
ExprAst::Atom(p) => {
|
||||
0u8.hash(state);
|
||||
p.hash(state);
|
||||
}
|
||||
ExprAst::Composition(a, b) => {
|
||||
if let ExprAst::Atom(p) = b.as_ref() {
|
||||
if p.is_commutative() {
|
||||
// Canonicalize inputs for commutative operations.
|
||||
1u8.hash(state);
|
||||
p.hash(state);
|
||||
if let ExprAst::Junction(l, r) = a.as_ref() {
|
||||
let hl = l.fingerprint();
|
||||
let hr = r.fingerprint();
|
||||
let (lo, hi) = if hl <= hr { (hl, hr) } else { (hr, hl) };
|
||||
lo.hash(state);
|
||||
hi.hash(state);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
1u8.hash(state);
|
||||
a.hash(state);
|
||||
b.hash(state);
|
||||
}
|
||||
ExprAst::Junction(l, r) => {
|
||||
2u8.hash(state);
|
||||
l.hash(state);
|
||||
r.hash(state);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
83
src/compiler.rs
Normal file
83
src/compiler.rs
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
use crate::ast::{ExprAst, Prim};
|
||||
use crate::dsl::{Composition, Expr, Function, Junction};
|
||||
use crate::primitives::{add_pair, cos_f, id, sin_f, triple};
|
||||
use crate::support::Arc;
|
||||
|
||||
/// Runtime type for a compiled AST node.
|
||||
#[derive(Clone)]
|
||||
pub enum CompiledExpr {
|
||||
/// f64 -> f64
|
||||
F64(Arc<dyn Expr<f64, Out = f64> + Send + Sync + 'static>),
|
||||
/// f64 -> (f64, f64)
|
||||
Pair(Arc<dyn Expr<f64, Out = (f64, f64)> + Send + Sync + 'static>),
|
||||
/// (f64, f64) -> f64
|
||||
PairToF64(Arc<dyn Expr<(f64, f64), Out = f64> + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl CompiledExpr {
|
||||
}
|
||||
|
||||
pub fn compile_ast(ast: &ExprAst) -> Result<CompiledExpr, String> {
|
||||
match ast {
|
||||
ExprAst::Atom(prim) => compile_atom(prim),
|
||||
ExprAst::Composition(first, second) => compile_composition(first, second),
|
||||
ExprAst::Junction(left, right) => compile_junction(left, right),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn compile_to_scalar(
|
||||
ast: &ExprAst,
|
||||
) -> Result<Arc<dyn Expr<f64, Out = f64> + Send + Sync + 'static>, String> {
|
||||
match compile_ast(ast)? {
|
||||
CompiledExpr::F64(expr) => Ok(expr),
|
||||
_ => Err("expected AST that evaluates to f64".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_atom(prim: &Prim) -> Result<CompiledExpr, String> {
|
||||
let compiled = match prim {
|
||||
Prim::Sin => CompiledExpr::F64(Arc::new(Function { f: sin_f })),
|
||||
Prim::Cos => CompiledExpr::F64(Arc::new(Function { f: cos_f })),
|
||||
Prim::Triple => CompiledExpr::F64(Arc::new(Function { f: triple })),
|
||||
Prim::AddPair => CompiledExpr::PairToF64(Arc::new(Function { f: add_pair })),
|
||||
Prim::Input => CompiledExpr::F64(Arc::new(Function { f: id })),
|
||||
};
|
||||
|
||||
Ok(compiled)
|
||||
}
|
||||
|
||||
fn compile_composition(
|
||||
first: &ExprAst,
|
||||
second: &ExprAst,
|
||||
) -> Result<CompiledExpr, String> {
|
||||
let compiled_first = compile_ast(first)?;
|
||||
let compiled_second = compile_ast(second)?;
|
||||
|
||||
match (compiled_first, compiled_second) {
|
||||
// f64 -> f64 -> f64
|
||||
(CompiledExpr::F64(f1), CompiledExpr::F64(f2)) => Ok(CompiledExpr::F64(Arc::new(
|
||||
Composition { first: f1, second: f2 },
|
||||
))),
|
||||
// f64 -> f64 -> (f64,f64)
|
||||
(CompiledExpr::F64(f1), CompiledExpr::Pair(f2)) => Ok(CompiledExpr::Pair(Arc::new(
|
||||
Composition { first: f1, second: f2 },
|
||||
))),
|
||||
// f64 -> (f64,f64) -> f64
|
||||
(CompiledExpr::Pair(f1), CompiledExpr::PairToF64(f2)) => Ok(CompiledExpr::F64(
|
||||
Arc::new(Composition { first: f1, second: f2 }),
|
||||
)),
|
||||
_ => Err("type mismatch in composition".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_junction(left: &ExprAst, right: &ExprAst) -> Result<CompiledExpr, String> {
|
||||
let compiled_left = compile_ast(left)?;
|
||||
let compiled_right = compile_ast(right)?;
|
||||
|
||||
match (compiled_left, compiled_right) {
|
||||
(CompiledExpr::F64(l), CompiledExpr::F64(r)) => Ok(CompiledExpr::Pair(Arc::new(
|
||||
Junction { left: l, right: r },
|
||||
))),
|
||||
_ => Err("junction requires two f64-returning branches".to_string()),
|
||||
}
|
||||
}
|
||||
26
src/dsl.rs
26
src/dsl.rs
|
|
@ -1,5 +1,4 @@
|
|||
use std::sync::Arc;
|
||||
use std::thread;
|
||||
use crate::support::{run_parallel, Arc};
|
||||
|
||||
// Обобщённое выражение: I -> Out
|
||||
pub trait Expr<I> {
|
||||
|
|
@ -7,6 +6,18 @@ pub trait Expr<I> {
|
|||
fn eval(&self, x: I) -> Self::Out;
|
||||
}
|
||||
|
||||
// Позволяет использовать Arc<T> там, где требуется Expr.
|
||||
impl<I, T> Expr<I> for Arc<T>
|
||||
where
|
||||
T: Expr<I> + ?Sized,
|
||||
{
|
||||
type Out = T::Out;
|
||||
|
||||
fn eval(&self, x: I) -> Self::Out {
|
||||
(**self).eval(x)
|
||||
}
|
||||
}
|
||||
|
||||
// атомарная функция
|
||||
pub struct Function<I, O> {
|
||||
pub f: fn(I) -> O,
|
||||
|
|
@ -39,7 +50,7 @@ where
|
|||
}
|
||||
|
||||
// junction — две ветки параллельно
|
||||
pub struct Junction<F1, F2> {
|
||||
pub struct Junction<F1: ?Sized, F2: ?Sized> {
|
||||
pub left: Arc<F1>,
|
||||
pub right: Arc<F2>,
|
||||
}
|
||||
|
|
@ -47,8 +58,8 @@ pub struct Junction<F1, F2> {
|
|||
impl<I, F1, F2> Expr<I> for Junction<F1, F2>
|
||||
where
|
||||
I: Copy + Send + 'static,
|
||||
F1: Expr<I> + Send + Sync + 'static,
|
||||
F2: Expr<I> + Send + Sync + 'static,
|
||||
F1: Expr<I> + Send + Sync + 'static + ?Sized,
|
||||
F2: Expr<I> + Send + Sync + 'static + ?Sized,
|
||||
F1::Out: Send + 'static,
|
||||
F2::Out: Send + 'static,
|
||||
{
|
||||
|
|
@ -61,9 +72,6 @@ where
|
|||
let x1 = x;
|
||||
let x2 = x;
|
||||
|
||||
let h1 = thread::spawn(move || l.eval(x1));
|
||||
let h2 = thread::spawn(move || r.eval(x2));
|
||||
|
||||
(h1.join().unwrap(), h2.join().unwrap())
|
||||
run_parallel(move || l.eval(x1), move || r.eval(x2))
|
||||
}
|
||||
}
|
||||
|
|
|
|||
67
src/main.rs
67
src/main.rs
|
|
@ -1,33 +1,23 @@
|
|||
mod ast;
|
||||
mod compiler;
|
||||
mod dsl;
|
||||
mod primitives;
|
||||
mod store;
|
||||
mod support;
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use ast::{
|
||||
ExprAst,
|
||||
ExprAst::{Atom, Composition as C, Junction as J},
|
||||
Prim::{AddPair, Cos, Input, Sin, Triple},
|
||||
Prim::{AddPair, Cos, Sin, Triple},
|
||||
ast_input, hash_ast,
|
||||
};
|
||||
|
||||
use compiler::compile_to_scalar;
|
||||
use dsl::{Composition, Expr, Function, Junction};
|
||||
|
||||
// ========= тестовые функции =========
|
||||
|
||||
fn sin_f(x: f64) -> f64 {
|
||||
x.sin()
|
||||
}
|
||||
fn cos_f(x: f64) -> f64 {
|
||||
x.cos()
|
||||
}
|
||||
fn triple(x: f64) -> f64 {
|
||||
3.0 * x
|
||||
}
|
||||
|
||||
/// (a,b) -> a + b
|
||||
fn add_pair(p: (f64, f64)) -> f64 {
|
||||
p.0 + p.1
|
||||
}
|
||||
use primitives::{add_pair, cos_f, sin_f, triple};
|
||||
use store::ExprStore;
|
||||
|
||||
// AST для sin(cos(x) + 3*x)
|
||||
fn pupa() -> ExprAst {
|
||||
|
|
@ -95,4 +85,45 @@ fn main() {
|
|||
println!("equal? {}", ast1 == ast2);
|
||||
println!("hash1 = {:016x}", hash_ast(&ast1));
|
||||
println!("hash2 = {:016x}", hash_ast(&ast2));
|
||||
|
||||
// --- компиляция AST -> DSL и вычисление ---
|
||||
|
||||
let compiled = compile_to_scalar(&ast).expect("AST should compile to f64 -> f64");
|
||||
let compiled_y = compiled.eval(x);
|
||||
|
||||
println!(
|
||||
"compiled eval: sin(cos({}) + 3*{}) = {}",
|
||||
x, x, compiled_y
|
||||
);
|
||||
|
||||
// --- контент-адресуемое хранилище выражений ---
|
||||
|
||||
let mut store = ExprStore::new();
|
||||
let (stored_hash, stored_expr) = store
|
||||
.insert_scalar(ast.clone())
|
||||
.expect("store insert should compile");
|
||||
let store_eval = stored_expr.eval(x);
|
||||
let stored_ast = store
|
||||
.get_ast(stored_hash)
|
||||
.expect("AST should be retrievable by hash");
|
||||
let store_eval_via_lookup = store
|
||||
.eval_scalar(stored_hash, x)
|
||||
.expect("store should return scalar result");
|
||||
let compiled_from_store = store
|
||||
.get(stored_hash)
|
||||
.expect("compiled expression should be cached");
|
||||
|
||||
println!(
|
||||
"store eval (hash {:016x}): sin(cos({}) + 3*{}) = {}",
|
||||
stored_hash, x, x, store_eval
|
||||
);
|
||||
println!(
|
||||
"store eval via lookup (hash {:016x}): sin(cos({}) + 3*{}) = {}",
|
||||
stored_hash, x, x, store_eval_via_lookup
|
||||
);
|
||||
println!(
|
||||
"compiled cached? {}",
|
||||
matches!(compiled_from_store.as_ref(), compiler::CompiledExpr::F64(_))
|
||||
);
|
||||
println!("store keeps AST equal to original? {}", stored_ast == &ast);
|
||||
}
|
||||
|
|
|
|||
22
src/primitives.rs
Normal file
22
src/primitives.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
// Reusable primitive functions used by both the DSL and the AST compiler.
|
||||
pub fn sin_f(x: f64) -> f64 {
|
||||
x.sin()
|
||||
}
|
||||
|
||||
pub fn cos_f(x: f64) -> f64 {
|
||||
x.cos()
|
||||
}
|
||||
|
||||
pub fn triple(x: f64) -> f64 {
|
||||
3.0 * x
|
||||
}
|
||||
|
||||
/// (a,b) -> a + b
|
||||
pub fn add_pair(p: (f64, f64)) -> f64 {
|
||||
p.0 + p.1
|
||||
}
|
||||
|
||||
/// Identity for the AST `Input` primitive.
|
||||
pub fn id(x: f64) -> f64 {
|
||||
x
|
||||
}
|
||||
69
src/store.rs
Normal file
69
src/store.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::ast::{hash_ast, ExprAst};
|
||||
use crate::compiler::{compile_ast, compile_to_scalar, CompiledExpr};
|
||||
use crate::support::Arc;
|
||||
|
||||
/// In-memory content-addressed store keyed by AST hash.
|
||||
#[derive(Default)]
|
||||
pub struct ExprStore {
|
||||
compiled: HashMap<u64, Arc<CompiledExpr>>,
|
||||
asts: HashMap<u64, ExprAst>,
|
||||
}
|
||||
|
||||
impl ExprStore {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Insert an AST, compile it once, and return its hash.
|
||||
/// If the hash already exists, the cached entry is reused.
|
||||
pub fn insert(&mut self, ast: ExprAst) -> Result<u64, String> {
|
||||
let hash = hash_ast(&ast);
|
||||
if !self.compiled.contains_key(&hash) {
|
||||
let compiled = compile_ast(&ast)?;
|
||||
self.compiled.insert(hash, Arc::new(compiled));
|
||||
self.asts.insert(hash, ast);
|
||||
}
|
||||
Ok(hash)
|
||||
}
|
||||
|
||||
/// Fetch a compiled expression by hash.
|
||||
pub fn get(&self, hash: u64) -> Option<Arc<CompiledExpr>> {
|
||||
self.compiled.get(&hash).cloned()
|
||||
}
|
||||
|
||||
/// Fetch the original AST by hash.
|
||||
pub fn get_ast(&self, hash: u64) -> Option<&ExprAst> {
|
||||
self.asts.get(&hash)
|
||||
}
|
||||
|
||||
/// Evaluate a stored scalar expression for the provided input.
|
||||
pub fn eval_scalar(&self, hash: u64, input: f64) -> Result<f64, String> {
|
||||
let compiled = self
|
||||
.get(hash)
|
||||
.ok_or_else(|| format!("hash {:016x} not found", hash))?;
|
||||
|
||||
match compiled.as_ref() {
|
||||
CompiledExpr::F64(expr) => Ok(expr.eval(input)),
|
||||
CompiledExpr::Pair(_) => Err("expression returns a pair, not f64".into()),
|
||||
CompiledExpr::PairToF64(_) => Err("expression expects a pair input, not f64".into()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Insert an AST that must evaluate to f64 -> f64, returning its hash and compiled form.
|
||||
pub fn insert_scalar(
|
||||
&mut self,
|
||||
ast: ExprAst,
|
||||
) -> Result<
|
||||
(
|
||||
u64,
|
||||
Arc<dyn crate::dsl::Expr<f64, Out = f64> + Send + Sync + 'static>,
|
||||
),
|
||||
String,
|
||||
> {
|
||||
let hash = self.insert(ast.clone())?;
|
||||
let expr = compile_to_scalar(&ast)?;
|
||||
Ok((hash, expr))
|
||||
}
|
||||
}
|
||||
39
src/support.rs
Normal file
39
src/support.rs
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
// Platform helpers to ease future no_std portability.
|
||||
// The crate currently depends on `std`, but concentrating platform-specific
|
||||
// pieces here makes migration simpler.
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub use std::sync::Arc;
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
pub fn run_parallel<L, R, OutL, OutR>(left: L, right: R) -> (OutL, OutR)
|
||||
where
|
||||
L: FnOnce() -> OutL + Send + 'static,
|
||||
R: FnOnce() -> OutR + Send + 'static,
|
||||
OutL: Send + 'static,
|
||||
OutR: Send + 'static,
|
||||
{
|
||||
let h1 = std::thread::spawn(left);
|
||||
let h2 = std::thread::spawn(right);
|
||||
(h1.join().unwrap(), h2.join().unwrap())
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
mod no_std_support {
|
||||
extern crate alloc;
|
||||
pub use alloc::sync::Arc;
|
||||
|
||||
pub fn run_parallel<L, R, OutL, OutR>(left: L, right: R) -> (OutL, OutR)
|
||||
where
|
||||
L: FnOnce() -> OutL + Send + 'static,
|
||||
R: FnOnce() -> OutR + Send + 'static,
|
||||
OutL: Send + 'static,
|
||||
OutR: Send + 'static,
|
||||
{
|
||||
// No threads available: fall back to sequential execution.
|
||||
(left(), right())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "std"))]
|
||||
pub use no_std_support::*;
|
||||
Loading…
Add table
Add a link
Reference in a new issue