Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 18 additions & 16 deletions crates/goth-eval/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ impl Evaluator {
Expr::Lit(lit) => Ok(self.eval_literal(lit)),
Expr::Prim(name) => env.get_global(name).ok_or_else(|| EvalError::not_implemented(format!("primitive: {}", name))),
Expr::App(func, arg) => { let func_val = self.eval_with_env(func, env)?; let arg_val = self.eval_with_env(arg, env)?; self.apply(func_val, arg_val) }
Expr::Lam(body) => Ok(Value::Closure(Closure { arity: 1, body: (**body).clone(), env: env.capture(), preconditions: vec![], postconditions: vec![] })),
Expr::LamN(n, body) => Ok(Value::Closure(Closure { arity: *n, body: (**body).clone(), env: env.capture(), preconditions: vec![], postconditions: vec![] })),
Expr::Lam(body) => Ok(Value::Closure(Rc::new(Closure { arity: 1, body: (**body).clone(), env: env.capture(), preconditions: vec![], postconditions: vec![] }))),
Expr::LamN(n, body) => Ok(Value::Closure(Rc::new(Closure { arity: *n, body: (**body).clone(), env: env.capture(), preconditions: vec![], postconditions: vec![] }))),
Expr::Let { pattern, type_: _, value, body } => { let val = self.eval_with_env(value, env)?; let mut new_env = env.clone(); self.bind_pattern(pattern, val, &mut new_env)?; self.eval_with_env(body, &new_env) }
Expr::LetRec { bindings, body } => {
let mut new_env = env.clone();
Expand All @@ -135,7 +135,7 @@ impl Evaluator {
Expr::Tuple(exprs) => { let values: Vec<Value> = exprs.iter().map(|e| self.eval_with_env(e, env)).collect::<Result<_, _>>()?; Ok(Value::tuple(values)) }
Expr::Record(fields) => { let map: HashMap<String, Value> = fields.iter().map(|(name, expr)| { let val = self.eval_with_env(expr, env)?; Ok((name.to_string(), val)) }).collect::<EvalResult<_>>()?; Ok(Value::Record(Rc::new(map))) }
Expr::Array(exprs) => { let values: Vec<Value> = exprs.iter().map(|e| self.eval_with_env(e, env)).collect::<Result<_, _>>()?; Ok(self.values_to_tensor(values)) }
Expr::ArrayFill { shape, value } => { let shape_vals: Vec<usize> = shape.iter().map(|e| { let v = self.eval_with_env(e, env)?; v.as_int().map(|n| n as usize).ok_or_else(|| EvalError::type_error("Int", &v)) }).collect::<Result<_, _>>()?; let fill_val = self.eval_with_env(value, env)?; let size: usize = shape_vals.iter().product(); let data = vec![fill_val; size]; Ok(Value::Tensor(Tensor::from_values(shape_vals, data))) }
Expr::ArrayFill { shape, value } => { let shape_vals: Vec<usize> = shape.iter().map(|e| { let v = self.eval_with_env(e, env)?; v.as_int().map(|n| n as usize).ok_or_else(|| EvalError::type_error("Int", &v)) }).collect::<Result<_, _>>()?; let fill_val = self.eval_with_env(value, env)?; let size: usize = shape_vals.iter().product(); let data = vec![fill_val; size]; Ok(Value::Tensor(Rc::new(Tensor::from_values(shape_vals, data)))) }
Expr::Variant { constructor, payload } => { let payload_val = match payload { Some(p) => Some(self.eval_with_env(p, env)?), None => None }; Ok(Value::variant(constructor.to_string(), payload_val)) }
Expr::Field(base, access) => { let val = self.eval_with_env(base, env)?; self.access_field(val, access) }
Expr::Index(base, indices) => { let arr = self.eval_with_env(base, env)?; let idx_vals: Vec<usize> = indices.iter().map(|e| { let v = self.eval_with_env(e, env)?; v.as_index().ok_or_else(|| EvalError::type_error("Int", &v)) }).collect::<Result<_, _>>()?; self.index_value(arr, &idx_vals) }
Expand Down Expand Up @@ -197,8 +197,9 @@ impl Evaluator {
/// Single step of application that may return a tail call for trampolining.
fn apply_once(&mut self, func: Value, arg: Value) -> EvalResult<TcoResult> {
match func {
Value::Closure(closure) => {
if closure.arity == 1 {
Value::Closure(rc_closure) => {
if rc_closure.arity == 1 {
let closure = Rc::unwrap_or_clone(rc_closure);
let mut new_env = closure.env.clone();
new_env.push(arg.clone());

Expand All @@ -215,15 +216,16 @@ impl Evaluator {
Ok(TcoResult::Done(result))
}
} else {
let remaining = (closure.arity - 1) as usize;
Ok(TcoResult::Done(Value::Partial { func: Box::new(Value::Closure(closure)), args: vec![arg], remaining }))
let remaining = (rc_closure.arity - 1) as usize;
Ok(TcoResult::Done(Value::Partial { func: Box::new(Value::Closure(rc_closure)), args: vec![arg], remaining }))
}
}
Value::Partial { func, mut args, remaining } => {
args.push(arg);
if remaining == 1 {
match *func {
Value::Closure(closure) => {
Value::Closure(rc_closure) => {
let closure = Rc::unwrap_or_clone(rc_closure);
let mut new_env = closure.env.clone();
for a in &args {
new_env.push(a.clone());
Expand Down Expand Up @@ -392,7 +394,7 @@ impl Evaluator {
Pattern::Var(_) => { env.push(val.clone()); Ok(true) }
Pattern::Lit(lit) => { let lit_val = self.eval_literal(lit); Ok(val.deep_eq(&lit_val)) }
Pattern::Array(pats) => { match val { Value::Tensor(t) => { if t.rank() != 1 || t.len() != pats.len() { return Ok(false); } for (i, pat) in pats.iter().enumerate() { let elem = t.get_flat(i).unwrap(); if !self.match_pattern(pat, &elem, env)? { return Ok(false); } } Ok(true) } _ => Ok(false) } }
Pattern::ArraySplit { head, tail } => { match val { Value::Tensor(t) => { if t.rank() != 1 || t.len() < head.len() { return Ok(false); } for (i, pat) in head.iter().enumerate() { let elem = t.get_flat(i).unwrap(); if !self.match_pattern(pat, &elem, env)? { return Ok(false); } } let tail_data: Vec<Value> = (head.len()..t.len()).map(|i| t.get_flat(i).unwrap()).collect(); let tail_tensor = Tensor::from_values(vec![tail_data.len()], tail_data); self.match_pattern(tail, &Value::Tensor(tail_tensor), env) } _ => Ok(false) } }
Pattern::ArraySplit { head, tail } => { match val { Value::Tensor(t) => { if t.rank() != 1 || t.len() < head.len() { return Ok(false); } for (i, pat) in head.iter().enumerate() { let elem = t.get_flat(i).unwrap(); if !self.match_pattern(pat, &elem, env)? { return Ok(false); } } let tail_data: Vec<Value> = (head.len()..t.len()).map(|i| t.get_flat(i).unwrap()).collect(); let tail_tensor = Rc::new(Tensor::from_values(vec![tail_data.len()], tail_data)); self.match_pattern(tail, &Value::Tensor(tail_tensor), env) } _ => Ok(false) } }
Pattern::Tuple(pats) => { match val { Value::Tuple(vals) => { if vals.len() != pats.len() { return Ok(false); } for (pat, v) in pats.iter().zip(vals) { if !self.match_pattern(pat, v, env)? { return Ok(false); } } Ok(true) } Value::Unit if pats.is_empty() => Ok(true), _ => Ok(false) } }
Pattern::Variant { constructor, payload } => { match val { Value::Variant { tag, payload: val_payload } => { if tag.as_str() != constructor.as_ref() { return Ok(false); } match (payload, val_payload) { (None, None) => Ok(true), (Some(pat), Some(v)) => self.match_pattern(pat, v, env), _ => Ok(false) } } _ => Ok(false) } }
Pattern::Typed(inner, _ty) => self.match_pattern(inner, val, env),
Expand All @@ -408,16 +410,16 @@ impl Evaluator {
}

fn values_to_tensor_shaped(&self, shape: Vec<usize>, values: Vec<Value>) -> Value {
if values.is_empty() { return Value::Tensor(Tensor::from_ints(vec![])); }
if values.is_empty() { return Value::Tensor(Rc::new(Tensor::from_ints(vec![]))); }
let all_int = values.iter().all(|v| matches!(v, Value::Int(_)));
let all_float = values.iter().all(|v| matches!(v, Value::Float(_) | Value::Int(_)));
let all_bool = values.iter().all(|v| matches!(v, Value::Bool(_)));
let all_char = values.iter().all(|v| matches!(v, Value::Char(_)));
if all_int { Value::Tensor(Tensor { shape, data: crate::value::TensorData::Int(values.iter().map(|v| v.as_int().unwrap()).collect()) }) }
else if all_float { Value::Tensor(Tensor { shape, data: crate::value::TensorData::Float(values.iter().map(|v| ordered_float::OrderedFloat(v.coerce_float().unwrap())).collect()) }) }
else if all_bool { Value::Tensor(Tensor { shape, data: crate::value::TensorData::Bool(values.iter().map(|v| v.as_bool().unwrap()).collect()) }) }
else if all_char { Value::Tensor(Tensor { shape, data: crate::value::TensorData::Char(values.iter().map(|v| v.as_char().unwrap()).collect()) }) }
else { Value::Tensor(Tensor::from_values(shape, values)) }
if all_int { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Int(values.iter().map(|v| v.as_int().unwrap()).collect()) })) }
else if all_float { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Float(values.iter().map(|v| ordered_float::OrderedFloat(v.coerce_float().unwrap())).collect()) })) }
else if all_bool { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Bool(values.iter().map(|v| v.as_bool().unwrap()).collect()) })) }
else if all_char { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Char(values.iter().map(|v| v.as_char().unwrap()).collect()) })) }
else { Value::Tensor(Rc::new(Tensor::from_values(shape, values))) }
}

fn access_field(&self, val: Value, access: &FieldAccess) -> EvalResult<Value> {
Expand Down Expand Up @@ -513,7 +515,7 @@ impl Evaluator {
let body = Expr::App(Box::new(Expr::Idx(2)), Box::new(Expr::App(Box::new(Expr::Idx(1)), Box::new(Expr::Idx(0)))));
let mut env = Env::with_globals(Rc::clone(&self.globals));
env.push(f); env.push(g); // Push f first, then g, so g is at Idx(1) and f is at Idx(2)
Ok(Value::Closure(Closure { arity: 1, body, env, preconditions: vec![], postconditions: vec![] }))
Ok(Value::Closure(Rc::new(Closure { arity: 1, body, env, preconditions: vec![], postconditions: vec![] })))
}

fn eval_write(&mut self, content: Value, target: Value) -> EvalResult<Value> {
Expand Down
148 changes: 148 additions & 0 deletions crates/goth-eval/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod prelude {
mod tests {
use super::prelude::*;
use goth_ast::prelude::*;
use std::rc::Rc;

#[test] fn test_int_literal() { assert_eq!(eval(&Expr::int(42)).unwrap(), Value::Int(42)); }
#[test] fn test_float_literal() { assert_eq!(eval(&Expr::float(3.14)).unwrap(), Value::float(3.14)); }
Expand Down Expand Up @@ -1058,4 +1059,151 @@ mod tests {
}
let _ = fs::remove_file(temp_path);
}

// ============ Rc-wrapping invariant tests ============
// These must pass BEFORE and AFTER the Rc-wrapping refactor.

#[test]
fn invariant_tensor_string_roundtrip() {
let v = Value::string("hello");
match &v {
Value::Tensor(t) => assert_eq!(t.to_string_value(), Some("hello".to_string())),
_ => panic!("Expected Tensor"),
}
}

#[test]
fn invariant_tensor_clone_independence() {
let v1 = Value::Tensor(Rc::new(Tensor::from_ints(vec![1, 2, 3])));
let v2 = v1.clone();
assert_eq!(v1, v2);
}

#[test]
fn invariant_tensor_deep_eq() {
let a = Value::Tensor(Rc::new(Tensor::from_ints(vec![1, 2, 3])));
let b = Value::Tensor(Rc::new(Tensor::from_ints(vec![1, 2, 3])));
assert!(a.deep_eq(&b));
}

#[test]
fn invariant_tensor_as_tensor() {
let v = Value::Tensor(Rc::new(Tensor::from_floats(vec![1.0, 2.0])));
let t = v.as_tensor().unwrap();
assert_eq!(t.shape, vec![2]);
assert_eq!(t.len(), 2);
}

#[test]
fn invariant_tensor_display() {
let v = Value::Tensor(Rc::new(Tensor::from_ints(vec![1, 2, 3])));
let s = format!("{}", v);
assert!(s.contains("1") && s.contains("2") && s.contains("3"));
}

#[test]
fn invariant_concat_strings() {
let mut e = Evaluator::new();
let expr = Expr::app(
Expr::app(Expr::name("concat"), Expr::Lit(Literal::String("ab".into()))),
Expr::Lit(Literal::String("cd".into())),
);
let result = e.eval(&expr).unwrap();
match &result {
Value::Tensor(t) => assert_eq!(t.to_string_value(), Some("abcd".to_string())),
_ => panic!("Expected string tensor"),
}
}

#[test]
fn invariant_concat_int_arrays() {
let mut e = Evaluator::new();
let expr = Expr::app(
Expr::app(
Expr::name("concat"),
Expr::array(vec![Expr::int(1), Expr::int(2)]),
),
Expr::array(vec![Expr::int(3), Expr::int(4)]),
);
let result = e.eval(&expr).unwrap();
match &result {
Value::Tensor(t) => {
assert_eq!(t.len(), 4);
assert_eq!(t.get_flat(0), Some(Value::Int(1)));
assert_eq!(t.get_flat(3), Some(Value::Int(4)));
}
_ => panic!("Expected tensor"),
}
}

#[test]
fn invariant_closure_apply() {
let result = eval(&Expr::app(
Expr::lam(Expr::mul(Expr::idx(0), Expr::int(2))),
Expr::int(21),
)).unwrap();
assert_eq!(result, Value::Int(42));
}

#[test]
fn invariant_closure_partial_application() {
let add_fn = Expr::lam(Expr::lam(Expr::add(Expr::idx(1), Expr::idx(0))));
let add5 = Expr::app(add_fn, Expr::int(5));
let result = eval(&Expr::app(add5, Expr::int(3))).unwrap();
assert_eq!(result, Value::Int(8));
}

#[test]
fn invariant_closure_captures_env() {
let expr = Expr::let_(
Pattern::var("x"), Expr::int(10),
Expr::app(Expr::lam(Expr::add(Expr::idx(0), Expr::idx(1))), Expr::int(5)),
);
assert_eq!(eval(&expr).unwrap(), Value::Int(15));
}

#[test]
fn invariant_closure_eq() {
let c1 = Value::closure(1, Expr::idx(0), Env::new());
let c2 = Value::closure(1, Expr::idx(0), Env::new());
assert_eq!(c1, c2);
}

#[test]
fn invariant_closure_is_callable() {
let v = Value::closure(1, Expr::idx(0), Env::new());
assert!(v.is_callable());
}

// ============ Rc sharing tests ============

#[test]
fn sharing_closure_clone_shares_rc() {
let v1 = Value::closure(1, Expr::idx(0), Env::new());
let v2 = v1.clone();
match (&v1, &v2) {
(Value::Closure(a), Value::Closure(b)) => assert!(Rc::ptr_eq(a, b)),
_ => panic!("Expected closures"),
}
}

#[test]
fn sharing_tensor_clone_shares_rc() {
let v1 = Value::Tensor(Rc::new(Tensor::from_ints(vec![1, 2, 3])));
let v2 = v1.clone();
match (&v1, &v2) {
(Value::Tensor(a), Value::Tensor(b)) => assert!(Rc::ptr_eq(a, b)),
_ => panic!("Expected tensors"),
}
}

#[test]
fn sharing_string_value_is_rc_tensor() {
let v1 = Value::string("hello");
let v2 = v1.clone();
match (&v1, &v2) {
(Value::Tensor(a), Value::Tensor(b)) => assert!(Rc::ptr_eq(a, b)),
_ => panic!("Expected tensors"),
}
}
}
Loading
Loading