diff --git a/crates/goth-ast/src/expr.rs b/crates/goth-ast/src/expr.rs index 5688f2b..f8aa909 100644 --- a/crates/goth-ast/src/expr.rs +++ b/crates/goth-ast/src/expr.rs @@ -529,6 +529,9 @@ impl std::fmt::Display for Expr { Literal::True => write!(f, "⊀"), Literal::False => write!(f, "βŠ₯"), Literal::Unit => write!(f, "⟨⟩"), + Literal::ImagI(x) => write!(f, "{}π•š", x), + Literal::ImagJ(x) => write!(f, "{}𝕛", x), + Literal::ImagK(x) => write!(f, "{}π•œ", x), }, Expr::Prim(name) => write!(f, "βŠ₯{}", name), Expr::App(func, arg) => write!(f, "({} {})", func, arg), diff --git a/crates/goth-ast/src/literal.rs b/crates/goth-ast/src/literal.rs index 431b4b4..58c5e58 100644 --- a/crates/goth-ast/src/literal.rs +++ b/crates/goth-ast/src/literal.rs @@ -25,6 +25,15 @@ pub enum Literal { /// Unit value (empty tuple) Unit, + + /// Imaginary-i literal (coefficient stored as f64) + ImagI(f64), + + /// Imaginary-j literal (quaternion j component) + ImagJ(f64), + + /// Imaginary-k literal (quaternion k component) + ImagK(f64), } impl Literal { @@ -40,6 +49,10 @@ impl Literal { Literal::String(s.into()) } + pub fn imag_i(f: f64) -> Self { Literal::ImagI(f) } + pub fn imag_j(f: f64) -> Self { Literal::ImagJ(f) } + pub fn imag_k(f: f64) -> Self { Literal::ImagK(f) } + pub fn bool(b: bool) -> Self { if b { Literal::True } else { Literal::False } } @@ -60,7 +73,7 @@ impl Literal { /// Check if this is a numeric literal pub fn is_numeric(&self) -> bool { - matches!(self, Literal::Int(_) | Literal::Float(_)) + matches!(self, Literal::Int(_) | Literal::Float(_) | Literal::ImagI(_) | Literal::ImagJ(_) | Literal::ImagK(_)) } } diff --git a/crates/goth-ast/src/pretty.rs b/crates/goth-ast/src/pretty.rs index 73f0757..cf8f211 100644 --- a/crates/goth-ast/src/pretty.rs +++ b/crates/goth-ast/src/pretty.rs @@ -364,6 +364,9 @@ impl Pretty { self.write("'"); } Literal::Unit => self.write("()"), + Literal::ImagI(x) => { self.write(&x.to_string()); self.write("π•š"); } + Literal::ImagJ(x) => { self.write(&x.to_string()); self.write("𝕛"); } + Literal::ImagK(x) => { self.write(&x.to_string()); self.write("π•œ"); } } Expr::Name(name) => self.write(name), @@ -541,6 +544,9 @@ impl Pretty { self.write("'"); } Literal::Unit => self.write("()"), + Literal::ImagI(x) => { self.write(&x.to_string()); self.write("π•š"); } + Literal::ImagJ(x) => { self.write(&x.to_string()); self.write("𝕛"); } + Literal::ImagK(x) => { self.write(&x.to_string()); self.write("π•œ"); } } } Pattern::Tuple(pats) => { diff --git a/crates/goth-ast/src/types.rs b/crates/goth-ast/src/types.rs index 343a8fe..ddb7b31 100644 --- a/crates/goth-ast/src/types.rs +++ b/crates/goth-ast/src/types.rs @@ -43,6 +43,10 @@ pub enum PrimType { // Arbitrary precision (for compile-time computation) Nat, // β„• - natural numbers Int, // β„€ - integers + + // Complex number types + Complex, // β„‚ - complex (f64, f64) + Quaternion, // ℍ - quaternion (f64, f64, f64, f64) } /// Type representation @@ -173,6 +177,8 @@ impl Type { pub fn bool() -> Self { Type::Prim(PrimType::Bool) } pub fn char() -> Self { Type::Prim(PrimType::Char) } pub fn nat() -> Self { Type::Prim(PrimType::Nat) } + pub fn complex() -> Self { Type::Prim(PrimType::Complex) } + pub fn quaternion() -> Self { Type::Prim(PrimType::Quaternion) } // Tensor pub fn tensor(shape: Shape, elem: Type) -> Self { @@ -282,7 +288,7 @@ impl PrimType { } pub fn is_float(&self) -> bool { - matches!(self, PrimType::F64 | PrimType::F32) + matches!(self, PrimType::F64 | PrimType::F32 | PrimType::Complex | PrimType::Quaternion) } pub fn is_int(&self) -> bool { @@ -306,7 +312,9 @@ impl PrimType { PrimType::F32 | PrimType::I32 | PrimType::U32 | PrimType::Char => Some(32), PrimType::I16 | PrimType::U16 => Some(16), PrimType::I8 | PrimType::U8 | PrimType::Byte | PrimType::Bool => Some(8), - PrimType::Nat | PrimType::Int | PrimType::String => None, // Variable/arbitrary size + PrimType::Complex => Some(128), // 2 Γ— f64 + PrimType::Quaternion => None, // 4 Γ— f64 = 256, doesn't fit u8 + PrimType::Nat | PrimType::Int | PrimType::String => None, } } } @@ -332,6 +340,8 @@ impl std::fmt::Display for PrimType { PrimType::String => write!(f, "String"), PrimType::Nat => write!(f, "β„•"), PrimType::Int => write!(f, "β„€"), + PrimType::Complex => write!(f, "β„‚"), + PrimType::Quaternion => write!(f, "ℍ"), } } } diff --git a/crates/goth-check/src/infer.rs b/crates/goth-check/src/infer.rs index 4f7caed..828e89a 100644 --- a/crates/goth-check/src/infer.rs +++ b/crates/goth-check/src/infer.rs @@ -528,6 +528,8 @@ fn literal_type(lit: &Literal) -> Type { Box::new(Type::Prim(PrimType::Char)), ), Literal::Unit => Type::unit(), + Literal::ImagI(_) => Type::Prim(PrimType::Complex), + Literal::ImagJ(_) | Literal::ImagK(_) => Type::Prim(PrimType::Quaternion), } } diff --git a/crates/goth-eval/src/eval.rs b/crates/goth-eval/src/eval.rs index c5b415d..4ffc122 100644 --- a/crates/goth-eval/src/eval.rs +++ b/crates/goth-eval/src/eval.rs @@ -83,6 +83,13 @@ impl Evaluator { // String comparison ("strEq", PrimFn::StrEq), ("startsWith", PrimFn::StartsWith), ("endsWith", PrimFn::EndsWith), ("contains", PrimFn::Contains), + // Complex/quaternion decomposition + ("re", PrimFn::Re), ("im", PrimFn::Im), ("conj", PrimFn::Conj), ("arg", PrimFn::Arg), + // Matrix utilities + ("trace", PrimFn::Trace), ("tr", PrimFn::Trace), + ("det", PrimFn::Det), ("inv", PrimFn::Inv), + ("diag", PrimFn::Diag), ("eye", PrimFn::Eye), + ("solve", PrimFn::Solve), ("solveWith", PrimFn::SolveWith), ]; for (name, prim) in prims { self.globals.borrow_mut().insert(name.to_string(), Value::Primitive(*prim)); } // Register stream constants @@ -152,7 +159,7 @@ impl Evaluator { } fn eval_literal(&self, lit: &Literal) -> Value { - match lit { Literal::Int(n) => Value::Int(*n), Literal::Float(f) => Value::float(*f), Literal::Char(c) => Value::Char(*c), Literal::String(s) => Value::string(s), Literal::True => Value::Bool(true), Literal::False => Value::Bool(false), Literal::Unit => Value::Unit } + match lit { Literal::Int(n) => Value::Int(*n), Literal::Float(f) => Value::float(*f), Literal::Char(c) => Value::Char(*c), Literal::String(s) => Value::string(s), Literal::True => Value::Bool(true), Literal::False => Value::Bool(false), Literal::Unit => Value::Unit, Literal::ImagI(f) => Value::Complex(0.0, *f), Literal::ImagJ(f) => Value::Quaternion(0.0, 0.0, *f, 0.0), Literal::ImagK(f) => Value::Quaternion(0.0, 0.0, 0.0, *f) } } fn eval_binop(&mut self, op: &BinOp, left: &Expr, right: &Expr, env: &Env) -> EvalResult { @@ -419,7 +426,42 @@ impl Evaluator { else if all_float { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Float(values.iter().map(|v| ordered_float::OrderedFloat(v.coerce_float().unwrap())).collect()) })) } else if all_bool { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Bool(values.iter().map(|v| v.as_bool().unwrap()).collect()) })) } else if all_char { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Char(values.iter().map(|v| v.as_char().unwrap()).collect()) })) } - else { Value::Tensor(Rc::new(Tensor::from_values(shape, values))) } + else { + // Auto-flatten: if all values are tensors with the same shape and compatible data, + // combine into a higher-dimensional tensor (e.g., [[1,2],[3,4]] β†’ shape [2,2]) + let all_tensor_same_shape = if let Some(Value::Tensor(first)) = values.first() { + let s = &first.shape; + values.iter().all(|v| matches!(v, Value::Tensor(t) if &t.shape == s)) + } else { false }; + if all_tensor_same_shape { + let sub_tensors: Vec<&Tensor> = values.iter().map(|v| match v { Value::Tensor(t) => t.as_ref(), _ => unreachable!() }).collect(); + let sub_shape = &sub_tensors[0].shape; + let mut new_shape = shape.clone(); + new_shape.extend_from_slice(sub_shape); + // Try to flatten into typed tensor + let all_int = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Int(_))); + let all_float = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Int(_) | crate::value::TensorData::Float(_))); + let all_bool = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Bool(_))); + if all_int { + let data: Vec = sub_tensors.iter().flat_map(|t| match &t.data { crate::value::TensorData::Int(v) => v.iter().copied(), _ => unreachable!() }).collect(); + Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Int(data) })) + } else if all_float { + let data: Vec> = sub_tensors.iter().flat_map(|t| match &t.data { + crate::value::TensorData::Float(v) => v.clone(), + crate::value::TensorData::Int(v) => v.iter().map(|&i| ordered_float::OrderedFloat(i as f64)).collect(), + _ => unreachable!() + }).collect(); + Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Float(data) })) + } else if all_bool { + let data: Vec = sub_tensors.iter().flat_map(|t| match &t.data { crate::value::TensorData::Bool(v) => v.iter().copied(), _ => unreachable!() }).collect(); + Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Bool(data) })) + } else { + Value::Tensor(Rc::new(Tensor::from_values(shape, values))) + } + } else { + Value::Tensor(Rc::new(Tensor::from_values(shape, values))) + } + } } fn access_field(&self, val: Value, access: &FieldAccess) -> EvalResult { @@ -605,6 +647,10 @@ fn prim_arity(prim: PrimFn) -> usize { PrimFn::Print | PrimFn::Write | PrimFn::ReadLine | PrimFn::ReadKey | PrimFn::ReadFile | PrimFn::Sleep => 1, PrimFn::Flush | PrimFn::RawModeEnter | PrimFn::RawModeExit => 1, // Terminal control (take unit) PrimFn::Lines | PrimFn::Words | PrimFn::Bytes => 1, // String splitting (unary) + PrimFn::Re | PrimFn::Im | PrimFn::Conj | PrimFn::Arg => 1, // Complex decomposition + PrimFn::Trace | PrimFn::Det | PrimFn::Inv | PrimFn::Diag | PrimFn::Eye => 1, // Matrix utilities + PrimFn::Solve => 2, // Linear solve (default LU) + PrimFn::SolveWith => 3, // Linear solve with method string PrimFn::WriteFile | PrimFn::ReadBytes | PrimFn::WriteBytes => 2, // Binary I/O takes 2 args PrimFn::Fold => 3, // fold f acc arr PrimFn::StrEq | PrimFn::StartsWith | PrimFn::EndsWith | PrimFn::Contains => 2, // String comparison (binary) diff --git a/crates/goth-eval/src/lib.rs b/crates/goth-eval/src/lib.rs index 5dd9030..46584ef 100644 --- a/crates/goth-eval/src/lib.rs +++ b/crates/goth-eval/src/lib.rs @@ -1206,4 +1206,634 @@ mod tests { _ => panic!("Expected tensors"), } } + + // ============ Complex / Quaternion Foundation Tests ============ + + #[test] + fn test_complex_literal_i() { + let expr = Expr::Lit(Literal::ImagI(3.0)); + assert_eq!(eval(&expr).unwrap(), Value::Complex(0.0, 3.0)); + } + + #[test] + fn test_complex_literal_j() { + let expr = Expr::Lit(Literal::ImagJ(2.0)); + assert_eq!(eval(&expr).unwrap(), Value::Quaternion(0.0, 0.0, 2.0, 0.0)); + } + + #[test] + fn test_complex_literal_k() { + let expr = Expr::Lit(Literal::ImagK(5.0)); + assert_eq!(eval(&expr).unwrap(), Value::Quaternion(0.0, 0.0, 0.0, 5.0)); + } + + #[test] + fn test_complex_display() { + assert_eq!(format!("{}", Value::Complex(3.0, 4.0)), "3 + 4π•š"); + assert_eq!(format!("{}", Value::Complex(3.0, -4.0)), "3 - 4π•š"); + assert_eq!(format!("{}", Value::Complex(0.0, 1.0)), "1π•š"); + assert_eq!(format!("{}", Value::Complex(5.0, 0.0)), "5"); + assert_eq!(format!("{}", Value::Complex(0.0, 0.0)), "0"); + } + + #[test] + fn test_complex_type_name() { + assert_eq!(Value::Complex(1.0, 2.0).type_name(), "Complex"); + assert_eq!(Value::Quaternion(1.0, 0.0, 0.0, 0.0).type_name(), "Quaternion"); + } + + #[test] + fn test_complex_is_numeric() { + assert!(Value::Complex(1.0, 2.0).is_numeric()); + assert!(Value::Quaternion(1.0, 0.0, 0.0, 0.0).is_numeric()); + } + + #[test] + fn test_complex_deep_eq() { + assert!(Value::Complex(1.0, 2.0).deep_eq(&Value::Complex(1.0, 2.0))); + assert!(!Value::Complex(1.0, 2.0).deep_eq(&Value::Complex(1.0, 3.0))); + assert!(Value::Quaternion(1.0, 2.0, 3.0, 4.0).deep_eq(&Value::Quaternion(1.0, 2.0, 3.0, 4.0))); + assert!(!Value::Quaternion(1.0, 2.0, 3.0, 4.0).deep_eq(&Value::Quaternion(1.0, 2.0, 3.0, 5.0))); + } + + #[test] + fn test_complex_coerce_float() { + assert_eq!(Value::Complex(3.0, 0.0).coerce_float(), Some(3.0)); + assert_eq!(Value::Complex(3.0, 1.0).coerce_float(), None); + } + + #[test] + fn test_complex_coerce_complex() { + assert_eq!(Value::Int(5).coerce_complex(), Some((5.0, 0.0))); + assert_eq!(Value::float(3.14).coerce_complex(), Some((3.14, 0.0))); + assert_eq!(Value::Complex(1.0, 2.0).coerce_complex(), Some((1.0, 2.0))); + } + + #[test] + fn test_quaternion_coerce() { + assert_eq!(Value::Int(5).coerce_quaternion(), Some((5.0, 0.0, 0.0, 0.0))); + assert_eq!(Value::Complex(1.0, 2.0).coerce_quaternion(), Some((1.0, 2.0, 0.0, 0.0))); + assert_eq!(Value::Quaternion(1.0, 2.0, 3.0, 4.0).coerce_quaternion(), Some((1.0, 2.0, 3.0, 4.0))); + } + + // ── Phase 3: Complex + Quaternion arithmetic ── + + fn assert_complex_approx(result: &Value, re: f64, im: f64, tol: f64, label: &str) { + match result { + Value::Complex(r, i) => { + assert!((*r - re).abs() < tol, "{}: re = {}, expected {}", label, r, re); + assert!((*i - im).abs() < tol, "{}: im = {}, expected {}", label, i, im); + } + other => panic!("{}: expected Complex, got {:?}", label, other), + } + } + + fn assert_quat_approx(result: &Value, w: f64, i: f64, j: f64, k: f64, tol: f64, label: &str) { + match result { + Value::Quaternion(rw, ri, rj, rk) => { + assert!((*rw - w).abs() < tol, "{}: w = {}, expected {}", label, rw, w); + assert!((*ri - i).abs() < tol, "{}: i = {}, expected {}", label, ri, i); + assert!((*rj - j).abs() < tol, "{}: j = {}, expected {}", label, rj, j); + assert!((*rk - k).abs() < tol, "{}: k = {}, expected {}", label, rk, k); + } + other => panic!("{}: expected Quaternion, got {:?}", label, other), + } + } + + #[test] + fn test_complex_add() { + // (3+4i) + (1+2i) = 4+6i + let expr = Expr::add( + Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0))), + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 4.0, 6.0, 1e-10, "complex add"); + } + + #[test] + fn test_complex_sub() { + // (3+4i) - (1+2i) = 2+2i + let expr = Expr::sub( + Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0))), + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 2.0, 2.0, 1e-10, "complex sub"); + } + + #[test] + fn test_complex_mul() { + // (3+4i)(1+2i) = 3*1 - 4*2 + (3*2 + 4*1)i = -5+10i + let expr = Expr::mul( + Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0))), + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), -5.0, 10.0, 1e-10, "complex mul"); + } + + #[test] + fn test_complex_mul_i_squared() { + // i * i = -1 + let expr = Expr::mul( + Expr::lit(Literal::ImagI(1.0)), + Expr::lit(Literal::ImagI(1.0)), + ); + let result = eval(&expr).unwrap(); + assert_complex_approx(&result, -1.0, 0.0, 1e-10, "i*i"); + } + + #[test] + fn test_complex_div() { + // (3+4i)/(1+2i) = (3+8+4-6i)/(1+4) = (11-2i)/5 = 2.2-0.4i + let expr = Expr::div( + Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0))), + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 2.2, -0.4, 1e-10, "complex div"); + } + + #[test] + fn test_complex_abs() { + // |3+4i| = 5.0 + let expr = Expr::UnaryOp( + UnaryOp::Abs, + Box::new(Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0)))), + ); + assert_eq!(eval(&expr).unwrap(), Value::float(5.0)); + } + + #[test] + fn test_complex_negate() { + // -(3+4i) = -3-4i + let expr = Expr::UnaryOp( + UnaryOp::Neg, + Box::new(Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0)))), + ); + assert_complex_approx(&eval(&expr).unwrap(), -3.0, -4.0, 1e-10, "complex negate"); + } + + #[test] + fn test_complex_auto_promote() { + // 5 + 3i = Complex(5, 3) + let expr = Expr::add(Expr::int(5), Expr::lit(Literal::ImagI(3.0))); + assert_complex_approx(&eval(&expr).unwrap(), 5.0, 3.0, 1e-10, "auto-promote"); + } + + #[test] + fn test_complex_exp_euler() { + // exp(Ο€i) β‰ˆ -1 + 0i + let pi = std::f64::consts::PI; + let expr = Expr::UnaryOp( + UnaryOp::Exp, + Box::new(Expr::lit(Literal::ImagI(pi))), + ); + assert_complex_approx(&eval(&expr).unwrap(), -1.0, 0.0, 1e-10, "euler identity"); + } + + #[test] + fn test_complex_sin() { + // sin(i) = iΒ·sinh(1) + let expr = Expr::UnaryOp( + UnaryOp::Sin, + Box::new(Expr::lit(Literal::ImagI(1.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 0.0, 1.0_f64.sinh(), 1e-10, "sin(i)"); + } + + #[test] + fn test_complex_cos() { + // cos(i) = cosh(1) + let expr = Expr::UnaryOp( + UnaryOp::Cos, + Box::new(Expr::lit(Literal::ImagI(1.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 1.0_f64.cosh(), 0.0, 1e-10, "cos(i)"); + } + + #[test] + fn test_complex_ln() { + // ln(i) = Ο€i/2 + let expr = Expr::UnaryOp( + UnaryOp::Ln, + Box::new(Expr::lit(Literal::ImagI(1.0))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 0.0, std::f64::consts::FRAC_PI_2, 1e-10, "ln(i)"); + } + + #[test] + fn test_complex_sqrt_negative() { + // sqrt(-4) = 2i + let expr = Expr::UnaryOp( + UnaryOp::Sqrt, + Box::new(Expr::UnaryOp(UnaryOp::Neg, Box::new(Expr::int(4)))), + ); + assert_complex_approx(&eval(&expr).unwrap(), 0.0, 2.0, 1e-10, "sqrt(-4)"); + } + + #[test] + fn test_quaternion_ij_eq_k() { + // i Γ— j = k + let expr = Expr::mul( + Expr::lit(Literal::ImagI(1.0)), + Expr::lit(Literal::ImagJ(1.0)), + ); + assert_quat_approx(&eval(&expr).unwrap(), 0.0, 0.0, 0.0, 1.0, 1e-10, "i*j=k"); + } + + #[test] + fn test_quaternion_ji_eq_neg_k() { + // j Γ— i = -k + let expr = Expr::mul( + Expr::lit(Literal::ImagJ(1.0)), + Expr::lit(Literal::ImagI(1.0)), + ); + assert_quat_approx(&eval(&expr).unwrap(), 0.0, 0.0, 0.0, -1.0, 1e-10, "j*i=-k"); + } + + #[test] + fn test_quaternion_jk_eq_i() { + // j Γ— k = i + let expr = Expr::mul( + Expr::lit(Literal::ImagJ(1.0)), + Expr::lit(Literal::ImagK(1.0)), + ); + assert_quat_approx(&eval(&expr).unwrap(), 0.0, 1.0, 0.0, 0.0, 1e-10, "j*k=i"); + } + + #[test] + fn test_quaternion_ijk_eq_neg1() { + // i Γ— j Γ— k = -1 + let expr = Expr::mul( + Expr::mul( + Expr::lit(Literal::ImagI(1.0)), + Expr::lit(Literal::ImagJ(1.0)), + ), + Expr::lit(Literal::ImagK(1.0)), + ); + assert_quat_approx(&eval(&expr).unwrap(), -1.0, 0.0, 0.0, 0.0, 1e-10, "ijk=-1"); + } + + #[test] + fn test_quaternion_add() { + // (1+2i+3j+4k) + (5+6i+7j+8k) = 6+8i+10j+12k + let q1 = Expr::add( + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + Expr::add(Expr::lit(Literal::ImagJ(3.0)), Expr::lit(Literal::ImagK(4.0))), + ); + let q2 = Expr::add( + Expr::add(Expr::int(5), Expr::lit(Literal::ImagI(6.0))), + Expr::add(Expr::lit(Literal::ImagJ(7.0)), Expr::lit(Literal::ImagK(8.0))), + ); + let expr = Expr::add(q1, q2); + assert_quat_approx(&eval(&expr).unwrap(), 6.0, 8.0, 10.0, 12.0, 1e-10, "quat add"); + } + + #[test] + fn test_quaternion_norm() { + // |1+2i+3j+4k| = √(1+4+9+16) = √30 + let q = Expr::add( + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + Expr::add(Expr::lit(Literal::ImagJ(3.0)), Expr::lit(Literal::ImagK(4.0))), + ); + let expr = Expr::UnaryOp(UnaryOp::Abs, Box::new(q)); + assert_eq!(eval(&expr).unwrap(), Value::float(30.0_f64.sqrt())); + } + + // ── Phase 4: re, im, conj, arg primitives ── + + #[test] + fn test_re_complex() { + let mut e = Evaluator::new(); + let expr = Expr::app(Expr::name("re"), Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0)))); + assert_eq!(e.eval(&expr).unwrap(), Value::float(3.0)); + } + + #[test] + fn test_im_complex() { + let mut e = Evaluator::new(); + let expr = Expr::app(Expr::name("im"), Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0)))); + assert_eq!(e.eval(&expr).unwrap(), Value::float(4.0)); + } + + #[test] + fn test_conj_complex() { + let mut e = Evaluator::new(); + let expr = Expr::app(Expr::name("conj"), Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0)))); + assert_complex_approx(&e.eval(&expr).unwrap(), 3.0, -4.0, 1e-10, "conj(3+4i)"); + } + + #[test] + fn test_arg_complex() { + let mut e = Evaluator::new(); + // arg(i) = Ο€/2 + let expr = Expr::app(Expr::name("arg"), Expr::lit(Literal::ImagI(1.0))); + assert_approx( + match e.eval(&expr).unwrap() { Value::Float(f) => f.0, v => panic!("expected Float, got {:?}", v) }, + std::f64::consts::FRAC_PI_2, 1e-10, "arg(i)" + ); + } + + #[test] + fn test_conj_quaternion() { + let mut e = Evaluator::new(); + let q = Expr::add( + Expr::add(Expr::int(1), Expr::lit(Literal::ImagI(2.0))), + Expr::add(Expr::lit(Literal::ImagJ(3.0)), Expr::lit(Literal::ImagK(4.0))), + ); + let expr = Expr::app(Expr::name("conj"), q); + assert_quat_approx(&e.eval(&expr).unwrap(), 1.0, -2.0, -3.0, -4.0, 1e-10, "conj(quat)"); + } + + #[test] + fn test_re_of_real() { + let mut e = Evaluator::new(); + assert_eq!(e.eval(&Expr::app(Expr::name("re"), Expr::float(5.0))).unwrap(), Value::float(5.0)); + } + + #[test] + fn test_conj_of_real() { + let mut e = Evaluator::new(); + assert_eq!(e.eval(&Expr::app(Expr::name("conj"), Expr::float(5.0))).unwrap(), Value::float(5.0)); + } + + #[test] + fn test_z_times_conj_z() { + // z Γ— conj(z) = |z|Β² (real) + let mut e = Evaluator::new(); + let z = Expr::add(Expr::int(3), Expr::lit(Literal::ImagI(4.0))); + let conj_z = Expr::app(Expr::name("conj"), z.clone()); + let expr = Expr::mul(z, conj_z); + // 3Β² + 4Β² = 25, should be Complex(25, 0) + assert_complex_approx(&e.eval(&expr).unwrap(), 25.0, 0.0, 1e-10, "z*conj(z)"); + } + + // ── Matrix utility tests ── + + fn mat2x2(a: f64, b: f64, c: f64, d: f64) -> Expr { + Expr::array(vec![ + Expr::array(vec![Expr::float(a), Expr::float(b)]), + Expr::array(vec![Expr::float(c), Expr::float(d)]), + ]) + } + + fn mat3x3(vals: [f64; 9]) -> Expr { + Expr::array(vec![ + Expr::array(vec![Expr::float(vals[0]), Expr::float(vals[1]), Expr::float(vals[2])]), + Expr::array(vec![Expr::float(vals[3]), Expr::float(vals[4]), Expr::float(vals[5])]), + Expr::array(vec![Expr::float(vals[6]), Expr::float(vals[7]), Expr::float(vals[8])]), + ]) + } + + fn vec_expr(vals: &[f64]) -> Expr { + Expr::array(vals.iter().map(|&v| Expr::float(v)).collect()) + } + + fn assert_tensor_float(result: &Value, idx: &[usize], expected: f64, tol: f64, label: &str) { + match result { + Value::Tensor(t) => { + let v = t.get(idx).unwrap_or_else(|| panic!("{}: index {:?} out of bounds", label, idx)); + let f = v.coerce_float().unwrap_or_else(|| panic!("{}: not numeric at {:?}", label, idx)); + assert!((f - expected).abs() < tol, "{}: at {:?} got {}, expected {}", label, idx, f, expected); + } + other => panic!("{}: expected Tensor, got {:?}", label, other), + } + } + + // Phase 1: trace + eye + + #[test] + fn test_trace_general() { + let mut e = Evaluator::new(); + let expr = Expr::app(Expr::name("trace"), mat2x2(1.0, 2.0, 3.0, 4.0)); + assert_eq!(e.eval(&expr).unwrap(), Value::float(5.0)); + } + + #[test] + fn test_trace_identity_3x3() { + let mut e = Evaluator::new(); + let expr = Expr::app(Expr::name("trace"), Expr::app(Expr::name("eye"), Expr::int(3))); + assert_eq!(e.eval(&expr).unwrap(), Value::float(3.0)); + } + + #[test] + fn test_trace_non_square_error() { + let mut e = Evaluator::new(); + let mat = Expr::array(vec![ + Expr::array(vec![Expr::float(1.0), Expr::float(2.0), Expr::float(3.0)]), + Expr::array(vec![Expr::float(4.0), Expr::float(5.0), Expr::float(6.0)]), + ]); + assert!(e.eval(&Expr::app(Expr::name("trace"), mat)).is_err()); + } + + #[test] + fn test_eye_3() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("eye"), Expr::int(3))).unwrap(); + if let Value::Tensor(t) = &result { + assert_eq!(t.shape, vec![3, 3]); + for i in 0..3 { for j in 0..3 { + let expected = if i == j { 1.0 } else { 0.0 }; + assert_eq!(t.get(&[i, j]).unwrap().coerce_float().unwrap(), expected); + }} + } else { panic!("Expected tensor"); } + } + + #[test] + fn test_eye_1() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("eye"), Expr::int(1))).unwrap(); + assert_tensor_float(&result, &[0, 0], 1.0, 1e-15, "eye(1)"); + } + + // Phase 2: diag + + #[test] + fn test_diag_vec_to_matrix() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("diag"), vec_expr(&[1.0, 2.0, 3.0]))).unwrap(); + if let Value::Tensor(t) = &result { + assert_eq!(t.shape, vec![3, 3]); + } else { panic!("Expected tensor"); } + assert_tensor_float(&result, &[0, 0], 1.0, 1e-15, "diag[0,0]"); + assert_tensor_float(&result, &[1, 1], 2.0, 1e-15, "diag[1,1]"); + assert_tensor_float(&result, &[2, 2], 3.0, 1e-15, "diag[2,2]"); + assert_tensor_float(&result, &[0, 1], 0.0, 1e-15, "diag[0,1]"); + } + + #[test] + fn test_diag_matrix_to_vec() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("diag"), mat2x2(1.0, 2.0, 3.0, 4.0))).unwrap(); + if let Value::Tensor(t) = &result { + assert_eq!(t.shape, vec![2]); + assert_eq!(t.get(&[0]).unwrap().coerce_float().unwrap(), 1.0); + assert_eq!(t.get(&[1]).unwrap().coerce_float().unwrap(), 4.0); + } else { panic!("Expected tensor"); } + } + + // Phase 3: det + + #[test] + fn test_det_2x2() { + let mut e = Evaluator::new(); + let f = e.eval(&Expr::app(Expr::name("det"), mat2x2(1.0, 2.0, 3.0, 4.0))).unwrap() + .coerce_float().unwrap(); + assert!((f - (-2.0)).abs() < 1e-10, "det = {}, expected -2", f); + } + + #[test] + fn test_det_3x3() { + let mut e = Evaluator::new(); + let f = e.eval(&Expr::app(Expr::name("det"), mat3x3([6.0,1.0,1.0, 4.0,-2.0,5.0, 2.0,8.0,7.0]))).unwrap() + .coerce_float().unwrap(); + assert!((f - (-306.0)).abs() < 1e-8, "det = {}, expected -306", f); + } + + #[test] + fn test_det_identity() { + let mut e = Evaluator::new(); + let f = e.eval(&Expr::app(Expr::name("det"), Expr::app(Expr::name("eye"), Expr::int(3)))).unwrap() + .coerce_float().unwrap(); + assert!((f - 1.0).abs() < 1e-12, "det(I) = {}, expected 1", f); + } + + #[test] + fn test_det_singular() { + let mut e = Evaluator::new(); + let f = e.eval(&Expr::app(Expr::name("det"), mat2x2(1.0, 2.0, 2.0, 4.0))).unwrap() + .coerce_float().unwrap(); + assert!(f.abs() < 1e-10, "det(singular) = {}, expected 0", f); + } + + // Phase 4: inv + + #[test] + fn test_inv_2x2() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("inv"), mat2x2(1.0, 2.0, 3.0, 4.0))).unwrap(); + let tol = 1e-10; + assert_tensor_float(&result, &[0, 0], -2.0, tol, "inv[0,0]"); + assert_tensor_float(&result, &[0, 1], 1.0, tol, "inv[0,1]"); + assert_tensor_float(&result, &[1, 0], 1.5, tol, "inv[1,0]"); + assert_tensor_float(&result, &[1, 1], -0.5, tol, "inv[1,1]"); + } + + #[test] + fn test_inv_identity() { + let mut e = Evaluator::new(); + let result = e.eval(&Expr::app(Expr::name("inv"), Expr::app(Expr::name("eye"), Expr::int(3)))).unwrap(); + for i in 0..3 { for j in 0..3 { + let expected = if i == j { 1.0 } else { 0.0 }; + assert_tensor_float(&result, &[i, j], expected, 1e-10, &format!("inv(I)[{},{}]", i, j)); + }} + } + + #[test] + fn test_inv_singular_error() { + let mut e = Evaluator::new(); + assert!(e.eval(&Expr::app(Expr::name("inv"), mat2x2(1.0, 2.0, 2.0, 4.0))).is_err()); + } + + #[test] + fn test_inv_roundtrip() { + let mut e = Evaluator::new(); + let a = mat2x2(2.0, 1.0, 5.0, 3.0); + let inv_a = Expr::app(Expr::name("inv"), a.clone()); + let product = Expr::app(Expr::app(Expr::name("matmul"), a), inv_a); + let result = e.eval(&product).unwrap(); + for i in 0..2 { for j in 0..2 { + let expected = if i == j { 1.0 } else { 0.0 }; + assert_tensor_float(&result, &[i, j], expected, 1e-10, &format!("A*inv(A)[{},{}]", i, j)); + }} + } + + // Phase 5: solve + + #[test] + fn test_solve_2x2() { + let mut e = Evaluator::new(); + let a = mat2x2(2.0, 1.0, 5.0, 3.0); + let b = vec_expr(&[4.0, 7.0]); + let result = e.eval(&Expr::app(Expr::app(Expr::name("solve"), a), b)).unwrap(); + assert_tensor_float(&result, &[0], 5.0, 1e-10, "x[0]"); + assert_tensor_float(&result, &[1], -6.0, 1e-10, "x[1]"); + } + + #[test] + fn test_solve_3x3() { + let mut e = Evaluator::new(); + let a = mat3x3([1.0,1.0,1.0, 0.0,2.0,5.0, 2.0,5.0,-1.0]); + let b = vec_expr(&[6.0, -4.0, 27.0]); + let result = e.eval(&Expr::app(Expr::app(Expr::name("solve"), a), b)).unwrap(); + assert_tensor_float(&result, &[0], 5.0, 1e-10, "x[0]"); + assert_tensor_float(&result, &[1], 3.0, 1e-10, "x[1]"); + assert_tensor_float(&result, &[2], -2.0, 1e-10, "x[2]"); + } + + #[test] + fn test_solve_singular_error() { + let mut e = Evaluator::new(); + let a = mat2x2(1.0, 2.0, 2.0, 4.0); + let b = vec_expr(&[1.0, 2.0]); + assert!(e.eval(&Expr::app(Expr::app(Expr::name("solve"), a), b)).is_err()); + } + + #[test] + fn test_solve_dimension_mismatch() { + let mut e = Evaluator::new(); + let a = mat2x2(1.0, 2.0, 3.0, 4.0); + let b = vec_expr(&[1.0, 2.0, 3.0]); + assert!(e.eval(&Expr::app(Expr::app(Expr::name("solve"), a), b)).is_err()); + } + + // Phase 6: solveWith + QR + + #[test] + fn test_solve_with_lu_explicit() { + let mut e = Evaluator::new(); + let a = mat2x2(2.0, 1.0, 5.0, 3.0); + let b = vec_expr(&[4.0, 7.0]); + let method = Expr::Lit(Literal::string("lu")); + let expr = Expr::app(Expr::app(Expr::app(Expr::name("solveWith"), a), b), method); + let result = e.eval(&expr).unwrap(); + assert_tensor_float(&result, &[0], 5.0, 1e-10, "lu x[0]"); + assert_tensor_float(&result, &[1], -6.0, 1e-10, "lu x[1]"); + } + + #[test] + fn test_solve_with_qr() { + let mut e = Evaluator::new(); + let a = mat2x2(2.0, 1.0, 5.0, 3.0); + let b = vec_expr(&[4.0, 7.0]); + let method = Expr::Lit(Literal::string("qr")); + let expr = Expr::app(Expr::app(Expr::app(Expr::name("solveWith"), a), b), method); + let result = e.eval(&expr).unwrap(); + assert_tensor_float(&result, &[0], 5.0, 1e-8, "qr x[0]"); + assert_tensor_float(&result, &[1], -6.0, 1e-8, "qr x[1]"); + } + + #[test] + fn test_solve_with_qr_overdetermined() { + let mut e = Evaluator::new(); + let a = Expr::array(vec![ + Expr::array(vec![Expr::float(1.0), Expr::float(1.0)]), + Expr::array(vec![Expr::float(1.0), Expr::float(2.0)]), + Expr::array(vec![Expr::float(1.0), Expr::float(3.0)]), + ]); + let b = vec_expr(&[1.0, 2.0, 2.0]); + let method = Expr::Lit(Literal::string("qr")); + let expr = Expr::app(Expr::app(Expr::app(Expr::name("solveWith"), a), b), method); + let result = e.eval(&expr).unwrap(); + assert_tensor_float(&result, &[0], 2.0 / 3.0, 1e-8, "lstsq x[0]"); + assert_tensor_float(&result, &[1], 0.5, 1e-8, "lstsq x[1]"); + } + + #[test] + fn test_solve_with_unknown_method() { + let mut e = Evaluator::new(); + let a = mat2x2(1.0, 0.0, 0.0, 1.0); + let b = vec_expr(&[1.0, 2.0]); + let method = Expr::Lit(Literal::string("nonsense")); + let expr = Expr::app(Expr::app(Expr::app(Expr::name("solveWith"), a), b), method); + assert!(e.eval(&expr).is_err()); + } } diff --git a/crates/goth-eval/src/prim.rs b/crates/goth-eval/src/prim.rs index 71cf80c..f759f8a 100644 --- a/crates/goth-eval/src/prim.rs +++ b/crates/goth-eval/src/prim.rs @@ -1,7 +1,7 @@ //! Primitive operations for Goth use std::rc::Rc; -use crate::value::{Value, Tensor, PrimFn}; +use crate::value::{Value, Tensor, TensorData, PrimFn}; use crate::error::{EvalError, EvalResult}; use ordered_float::OrderedFloat; @@ -72,6 +72,65 @@ fn multiplicative_unc(a: f64, b: f64, da: f64, db: f64) -> f64 { } } +// ============ Complex arithmetic helpers ============ + +fn complex_mul(r1: f64, i1: f64, r2: f64, i2: f64) -> (f64, f64) { + (r1 * r2 - i1 * i2, r1 * i2 + i1 * r2) +} + +fn complex_div(r1: f64, i1: f64, r2: f64, i2: f64) -> (f64, f64) { + let denom = r2 * r2 + i2 * i2; + ((r1 * r2 + i1 * i2) / denom, (i1 * r2 - r1 * i2) / denom) +} + +fn complex_abs(re: f64, im: f64) -> f64 { (re * re + im * im).sqrt() } +fn complex_arg(re: f64, im: f64) -> f64 { im.atan2(re) } + +fn complex_exp(re: f64, im: f64) -> (f64, f64) { + let r = re.exp(); + (r * im.cos(), r * im.sin()) +} + +fn complex_ln(re: f64, im: f64) -> (f64, f64) { + (complex_abs(re, im).ln(), complex_arg(re, im)) +} + +fn complex_sqrt(re: f64, im: f64) -> (f64, f64) { + let r = complex_abs(re, im); + let re_out = ((r + re) / 2.0).sqrt(); + let im_out = ((r - re) / 2.0).sqrt() * if im >= 0.0 { 1.0 } else { -1.0 }; + (re_out, im_out) +} + +fn complex_sin(re: f64, im: f64) -> (f64, f64) { + (re.sin() * im.cosh(), re.cos() * im.sinh()) +} + +fn complex_cos(re: f64, im: f64) -> (f64, f64) { + (re.cos() * im.cosh(), -(re.sin() * im.sinh())) +} + +fn complex_pow(r1: f64, i1: f64, r2: f64, i2: f64) -> (f64, f64) { + let (ln_r, ln_i) = complex_ln(r1, i1); + let (mul_r, mul_i) = complex_mul(r2, i2, ln_r, ln_i); + complex_exp(mul_r, mul_i) +} + +// ============ Quaternion arithmetic helpers ============ + +fn quat_mul(a: (f64, f64, f64, f64), b: (f64, f64, f64, f64)) -> (f64, f64, f64, f64) { + ( + a.0*b.0 - a.1*b.1 - a.2*b.2 - a.3*b.3, + a.0*b.1 + a.1*b.0 + a.2*b.3 - a.3*b.2, + a.0*b.2 - a.1*b.3 + a.2*b.0 + a.3*b.1, + a.0*b.3 + a.1*b.2 - a.2*b.1 + a.3*b.0, + ) +} + +fn quat_norm(q: (f64, f64, f64, f64)) -> f64 { + (q.0*q.0 + q.1*q.1 + q.2*q.2 + q.3*q.3).sqrt() +} + pub fn apply_binop(op: &goth_ast::op::BinOp, left: Value, right: Value) -> EvalResult { use goth_ast::op::BinOp::*; match op { @@ -334,6 +393,17 @@ pub fn apply_prim(prim: PrimFn, args: Vec) -> EvalResult { PrimFn::BitXor => binary_args(&args, bitxor), PrimFn::Shl => binary_args(&args, shl), PrimFn::Shr => binary_args(&args, shr), + PrimFn::Re => unary_args(&args, prim_re), + PrimFn::Im => unary_args(&args, prim_im), + PrimFn::Conj => unary_args(&args, prim_conj), + PrimFn::Arg => unary_args(&args, prim_arg), + PrimFn::Trace => unary_args(&args, mat_trace), + PrimFn::Det => unary_args(&args, mat_det), + PrimFn::Inv => unary_args(&args, mat_inv), + PrimFn::Diag => unary_args(&args, mat_diag), + PrimFn::Eye => unary_args(&args, mat_eye), + PrimFn::Solve => binary_args(&args, mat_solve), + PrimFn::SolveWith => ternary_args(&args, mat_solve_with), _ => Err(EvalError::not_implemented(format!("primitive: {:?}", prim))), } } @@ -348,6 +418,11 @@ fn binary_args(args: &[Value], f: F) -> EvalResult where F: FnOnce(Val f(args[0].clone(), args[1].clone()) } +fn ternary_args(args: &[Value], f: F) -> EvalResult where F: FnOnce(Value, Value, Value) -> EvalResult { + if args.len() != 3 { return Err(EvalError::ArityMismatch { expected: 3, got: args.len() }); } + f(args[0].clone(), args[1].clone(), args[2].clone()) +} + fn add(left: Value, right: Value) -> EvalResult { match (&left, &right) { // Uncertain + Uncertain: additive propagation @@ -367,6 +442,24 @@ fn add(left: Value, right: Value) -> EvalResult { let (b, db) = uncertain_parts(&right).unwrap(); Ok(make_uncertain(a + b, db)) } + // Quaternion + anything (widest first) + (Value::Quaternion(w1, i1, j1, k1), _) if right.coerce_quaternion().is_some() => { + let (w2, i2, j2, k2) = right.coerce_quaternion().unwrap(); + Ok(Value::Quaternion(w1 + w2, i1 + i2, j1 + j2, k1 + k2)) + } + (_, Value::Quaternion(w2, i2, j2, k2)) if left.coerce_quaternion().is_some() => { + let (w1, i1, j1, k1) = left.coerce_quaternion().unwrap(); + Ok(Value::Quaternion(w1 + w2, i1 + i2, j1 + j2, k1 + k2)) + } + // Complex + anything + (Value::Complex(r1, i1), _) if right.coerce_complex().is_some() => { + let (r2, i2) = right.coerce_complex().unwrap(); + Ok(Value::Complex(r1 + r2, i1 + i2)) + } + (_, Value::Complex(r2, i2)) if left.coerce_complex().is_some() => { + let (r1, i1) = left.coerce_complex().unwrap(); + Ok(Value::Complex(r1 + r2, i1 + i2)) + } (Value::Int(a), Value::Int(b)) => a.checked_add(*b).map(Value::Int).ok_or_else(|| EvalError::Overflow(format!("{} + {} overflows", a, b))), (Value::Float(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(a.0 + b.0))), (Value::Int(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(*a as f64 + b.0))), @@ -399,6 +492,24 @@ fn sub(left: Value, right: Value) -> EvalResult { let (b, db) = uncertain_parts(&right).unwrap(); Ok(make_uncertain(a - b, db)) } + // Quaternion - anything (widest first) + (Value::Quaternion(w1, i1, j1, k1), _) if right.coerce_quaternion().is_some() => { + let (w2, i2, j2, k2) = right.coerce_quaternion().unwrap(); + Ok(Value::Quaternion(w1 - w2, i1 - i2, j1 - j2, k1 - k2)) + } + (_, Value::Quaternion(w2, i2, j2, k2)) if left.coerce_quaternion().is_some() => { + let (w1, i1, j1, k1) = left.coerce_quaternion().unwrap(); + Ok(Value::Quaternion(w1 - w2, i1 - i2, j1 - j2, k1 - k2)) + } + // Complex - anything + (Value::Complex(r1, i1), _) if right.coerce_complex().is_some() => { + let (r2, i2) = right.coerce_complex().unwrap(); + Ok(Value::Complex(r1 - r2, i1 - i2)) + } + (_, Value::Complex(r2, i2)) if left.coerce_complex().is_some() => { + let (r1, i1) = left.coerce_complex().unwrap(); + Ok(Value::Complex(r1 - r2, i1 - i2)) + } (Value::Int(a), Value::Int(b)) => a.checked_sub(*b).map(Value::Int).ok_or_else(|| EvalError::Overflow(format!("{} - {} overflows", a, b))), (Value::Float(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(a.0 - b.0))), (Value::Int(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(*a as f64 - b.0))), @@ -429,6 +540,32 @@ fn mul(left: Value, right: Value) -> EvalResult { let (b, db) = uncertain_parts(&right).unwrap(); Ok(make_uncertain(a * b, (a * db).abs())) } + // Quaternion Γ— anything (widest first, non-commutative!) + (Value::Quaternion(..), _) if right.coerce_quaternion().is_some() => { + let a = left.coerce_quaternion().unwrap(); + let b = right.coerce_quaternion().unwrap(); + let (w, i, j, k) = quat_mul(a, b); + Ok(Value::Quaternion(w, i, j, k)) + } + (_, Value::Quaternion(..)) if left.coerce_quaternion().is_some() => { + let a = left.coerce_quaternion().unwrap(); + let b = right.coerce_quaternion().unwrap(); + let (w, i, j, k) = quat_mul(a, b); + Ok(Value::Quaternion(w, i, j, k)) + } + // Complex Γ— anything + (Value::Complex(..), _) if right.coerce_complex().is_some() => { + let (r1, i1) = left.coerce_complex().unwrap(); + let (r2, i2) = right.coerce_complex().unwrap(); + let (r, i) = complex_mul(r1, i1, r2, i2); + Ok(Value::Complex(r, i)) + } + (_, Value::Complex(..)) if left.coerce_complex().is_some() => { + let (r1, i1) = left.coerce_complex().unwrap(); + let (r2, i2) = right.coerce_complex().unwrap(); + let (r, i) = complex_mul(r1, i1, r2, i2); + Ok(Value::Complex(r, i)) + } (Value::Int(a), Value::Int(b)) => a.checked_mul(*b).map(Value::Int).ok_or_else(|| EvalError::Overflow(format!("{} Γ— {} overflows", a, b))), (Value::Float(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(a.0 * b.0))), (Value::Int(a), Value::Float(b)) => Ok(Value::Float(OrderedFloat(*a as f64 * b.0))), @@ -468,6 +605,26 @@ fn div(left: Value, right: Value) -> EvalResult { let result = a / b; Ok(make_uncertain(result, result.abs() * (db / b).abs())) } + // Quaternion / anything: a Γ— conj(b) / |b|Β² + (Value::Quaternion(..), _) | (_, Value::Quaternion(..)) + if left.coerce_quaternion().is_some() && right.coerce_quaternion().is_some() => + { + let a = left.coerce_quaternion().unwrap(); + let b = right.coerce_quaternion().unwrap(); + let norm_sq = b.0*b.0 + b.1*b.1 + b.2*b.2 + b.3*b.3; + let conj_b = (b.0, -b.1, -b.2, -b.3); + let (w, i, j, k) = quat_mul(a, conj_b); + Ok(Value::Quaternion(w / norm_sq, i / norm_sq, j / norm_sq, k / norm_sq)) + } + // Complex / anything + (Value::Complex(..), _) | (_, Value::Complex(..)) + if left.coerce_complex().is_some() && right.coerce_complex().is_some() => + { + let (r1, i1) = left.coerce_complex().unwrap(); + let (r2, i2) = right.coerce_complex().unwrap(); + let (r, i) = complex_div(r1, i1, r2, i2); + Ok(Value::Complex(r, i)) + } (Value::Int(a), Value::Int(b)) => if *b == 0 { Err(EvalError::DivisionByZero) } else { Ok(Value::Int(a / b)) }, (Value::Float(a), Value::Float(b)) => if b.0 == 0.0 { Err(EvalError::DivisionByZero) } else { Ok(Value::Float(OrderedFloat(a.0 / b.0))) }, (Value::Int(a), Value::Float(b)) => if b.0 == 0.0 { Err(EvalError::DivisionByZero) } else { Ok(Value::Float(OrderedFloat(*a as f64 / b.0))) }, @@ -559,6 +716,15 @@ fn pow(left: Value, right: Value) -> EvalResult { let unc = if a > 0.0 { (result * a.ln() * db).abs() } else { 0.0 }; Ok(make_uncertain(result, unc)) } + // Complex ^ anything: exp(z2 * ln(z1)) + (Value::Complex(..), _) | (_, Value::Complex(..)) + if left.coerce_complex().is_some() && right.coerce_complex().is_some() => + { + let (r1, i1) = left.coerce_complex().unwrap(); + let (r2, i2) = right.coerce_complex().unwrap(); + let (r, i) = complex_pow(r1, i1, r2, i2); + Ok(Value::Complex(r, i)) + } (Value::Int(a), Value::Int(b)) => { if *b < 0 { Ok(Value::Float(OrderedFloat((*a as f64).powi(*b as i32)))) @@ -589,6 +755,8 @@ fn negate(value: Value) -> EvalResult { let (a, da) = uncertain_parts(&value).unwrap(); Ok(make_uncertain(-a, da)) } + Value::Quaternion(w, i, j, k) => Ok(Value::Quaternion(-w, -i, -j, -k)), + Value::Complex(r, i) => Ok(Value::Complex(-r, -i)), Value::Int(n) => n.checked_neg().map(Value::Int).ok_or_else(|| EvalError::Overflow(format!("negation of {} overflows", n))), Value::Float(f) => Ok(Value::Float(OrderedFloat(-f.0))), Value::Tensor(t) => Ok(Value::Tensor(Rc::new(t.map(|x| negate(x).unwrap_or(Value::Error("negate failed".into())))))), @@ -602,6 +770,8 @@ fn abs(value: Value) -> EvalResult { let (a, da) = uncertain_parts(&value).unwrap(); Ok(make_uncertain(a.abs(), da)) } + Value::Quaternion(w, i, j, k) => Ok(Value::Float(OrderedFloat(quat_norm((w, i, j, k))))), + Value::Complex(r, i) => Ok(Value::Float(OrderedFloat(complex_abs(r, i)))), Value::Int(n) => Ok(Value::Int(n.abs())), Value::Float(f) => Ok(Value::Float(OrderedFloat(f.0.abs()))), _ => Err(EvalError::type_error("numeric", &value)), @@ -610,10 +780,12 @@ fn abs(value: Value) -> EvalResult { fn exp(value: Value) -> EvalResult { if let Some((a, da)) = uncertain_parts(&value) { let r = a.exp(); return Ok(make_uncertain(r, r * da)); } + if let Value::Complex(re, im) = &value { let (r, i) = complex_exp(*re, *im); return Ok(Value::Complex(r, i)); } let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; Ok(Value::Float(OrderedFloat(f.exp()))) } fn ln(value: Value) -> EvalResult { if let Some((a, da)) = uncertain_parts(&value) { if a <= 0.0 { return Err(EvalError::type_error_msg("ln requires positive argument")); } return Ok(make_uncertain(a.ln(), da / a.abs())); } + if let Value::Complex(re, im) = &value { let (r, i) = complex_ln(*re, *im); return Ok(Value::Complex(r, i)); } let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; if f <= 0.0 { Err(EvalError::type_error_msg("ln requires positive argument")) } else { Ok(Value::Float(OrderedFloat(f.ln()))) } } fn log10(value: Value) -> EvalResult { @@ -626,14 +798,24 @@ fn log2(value: Value) -> EvalResult { } fn sqrt(value: Value) -> EvalResult { if let Some((a, da)) = uncertain_parts(&value) { if a < 0.0 { return Err(EvalError::type_error_msg("sqrt requires non-negative argument")); } let r = a.sqrt(); return Ok(make_uncertain(r, da / (2.0 * r))); } - let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; if f < 0.0 { Err(EvalError::type_error_msg("sqrt requires non-negative argument")) } else { Ok(Value::Float(OrderedFloat(f.sqrt()))) } + if let Value::Complex(re, im) = &value { let (r, i) = complex_sqrt(*re, *im); return Ok(Value::Complex(r, i)); } + let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; + if f < 0.0 { + // sqrt of negative real β†’ complex result + let (r, i) = complex_sqrt(f, 0.0); + Ok(Value::Complex(r, i)) + } else { + Ok(Value::Float(OrderedFloat(f.sqrt()))) + } } fn sin(value: Value) -> EvalResult { if let Some((a, da)) = uncertain_parts(&value) { return Ok(make_uncertain(a.sin(), a.cos().abs() * da)); } + if let Value::Complex(re, im) = &value { let (r, i) = complex_sin(*re, *im); return Ok(Value::Complex(r, i)); } let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; Ok(Value::Float(OrderedFloat(f.sin()))) } fn cos(value: Value) -> EvalResult { if let Some((a, da)) = uncertain_parts(&value) { return Ok(make_uncertain(a.cos(), a.sin().abs() * da)); } + if let Value::Complex(re, im) = &value { let (r, i) = complex_cos(*re, *im); return Ok(Value::Complex(r, i)); } let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; Ok(Value::Float(OrderedFloat(f.cos()))) } fn tan(value: Value) -> EvalResult { @@ -682,6 +864,54 @@ fn sign(value: Value) -> EvalResult { let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; Ok(Value::Float(OrderedFloat(if f > 0.0 { 1.0 } else if f < 0.0 { -1.0 } else { 0.0 }))) } +// Complex/quaternion decomposition primitives +fn prim_re(value: Value) -> EvalResult { + match &value { + Value::Quaternion(w, _, _, _) => Ok(Value::Float(OrderedFloat(*w))), + Value::Complex(re, _) => Ok(Value::Float(OrderedFloat(*re))), + _ => { + let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; + Ok(Value::Float(OrderedFloat(f))) + } + } +} + +fn prim_im(value: Value) -> EvalResult { + match &value { + Value::Quaternion(_, i, j, k) => Ok(Value::Tuple(vec![ + Value::Float(OrderedFloat(*i)), + Value::Float(OrderedFloat(*j)), + Value::Float(OrderedFloat(*k)), + ])), + Value::Complex(_, im) => Ok(Value::Float(OrderedFloat(*im))), + _ => { + let _ = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; + Ok(Value::Float(OrderedFloat(0.0))) + } + } +} + +fn prim_conj(value: Value) -> EvalResult { + match &value { + Value::Quaternion(w, i, j, k) => Ok(Value::Quaternion(*w, -i, -j, -k)), + Value::Complex(re, im) => Ok(Value::Complex(*re, -im)), + _ => { + let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; + Ok(Value::Float(OrderedFloat(f))) + } + } +} + +fn prim_arg(value: Value) -> EvalResult { + match &value { + Value::Complex(re, im) => Ok(Value::Float(OrderedFloat(complex_arg(*re, *im)))), + _ => { + let f = value.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &value))?; + Ok(Value::Float(OrderedFloat(if f >= 0.0 { 0.0 } else { std::f64::consts::PI }))) + } + } +} + // Digamma (ψ) function via finite differences for uncertainty propagation. fn digamma_approx(x: f64) -> f64 { let h = 1e-7; @@ -1275,3 +1505,320 @@ fn str_contains(str_val: Value, substr_val: Value) -> EvalResult { ))), } } + +// ── Matrix utility helpers ── + +fn tensor_to_floats(t: &Tensor) -> EvalResult> { + let n = t.shape.iter().product::(); + let mut out = Vec::with_capacity(n); + for i in 0..n { + let v = t.get_flat(i).ok_or_else(|| EvalError::shape_mismatch("index out of bounds"))?; + out.push(v.coerce_float().ok_or_else(|| EvalError::type_error("numeric", &v))?); + } + Ok(out) +} + +fn build_float_matrix(m: usize, n: usize, data: Vec) -> Tensor { + Tensor { shape: vec![m, n], data: TensorData::Float(data.into_iter().map(OrderedFloat).collect()) } +} + +fn build_float_vector(data: Vec) -> Tensor { + let len = data.len(); + Tensor { shape: vec![len], data: TensorData::Float(data.into_iter().map(OrderedFloat).collect()) } +} + +// ── LU decomposition with partial pivoting ── + +/// In-place LU decomposition. Returns (pivot_indices, det_sign). +/// Does NOT error on singular matrices β€” leaves near-zero pivots for caller to check. +fn lu_decompose(a: &mut [f64], n: usize) -> (Vec, f64) { + let mut piv: Vec = (0..n).collect(); + let mut sign = 1.0; + for k in 0..n { + // Find pivot + let mut max_val = a[k * n + k].abs(); + let mut max_row = k; + for i in (k + 1)..n { + let v = a[i * n + k].abs(); + if v > max_val { max_val = v; max_row = i; } + } + if max_row != k { + // Swap rows k and max_row + for j in 0..n { a.swap(k * n + j, max_row * n + j); } + piv.swap(k, max_row); + sign = -sign; + } + let pivot = a[k * n + k]; + if pivot.abs() < 1e-15 { continue; } + // Eliminate below + for i in (k + 1)..n { + let factor = a[i * n + k] / pivot; + a[i * n + k] = factor; // Store L factor + for j in (k + 1)..n { + a[i * n + j] -= factor * a[k * n + j]; + } + } + } + (piv, sign) +} + +/// Solve Ax = b given LU factorization with pivoting. +fn lu_solve(lu: &[f64], piv: &[usize], b: &[f64], n: usize) -> Vec { + // Apply permutation + let mut x: Vec = piv.iter().map(|&i| b[i]).collect(); + // Forward substitution (L * y = Pb, L has 1s on diagonal) + for i in 1..n { + for j in 0..i { + x[i] -= lu[i * n + j] * x[j]; + } + } + // Back substitution (U * x = y) + for i in (0..n).rev() { + for j in (i + 1)..n { + x[i] -= lu[i * n + j] * x[j]; + } + x[i] /= lu[i * n + i]; + } + x +} + +// ── QR decomposition (Householder reflections) ── + +/// QR decomposition of mΓ—n matrix (m >= n). Returns (Q, R) as flat arrays. +fn householder_qr(a: &[f64], m: usize, n: usize) -> (Vec, Vec) { + let mut r = a.to_vec(); + // Q starts as identity + let mut q = vec![0.0; m * m]; + for i in 0..m { q[i * m + i] = 1.0; } + + let k_max = if m > n { n } else { m }; + for k in 0..k_max { + // Extract column k below diagonal + let mut col = vec![0.0; m - k]; + for i in k..m { col[i - k] = r[i * n + k]; } + let col_norm = col.iter().map(|x| x * x).sum::().sqrt(); + if col_norm < 1e-15 { continue; } + + // Householder vector + let sign = if col[0] >= 0.0 { 1.0 } else { -1.0 }; + col[0] += sign * col_norm; + let v_norm_sq = col.iter().map(|x| x * x).sum::(); + if v_norm_sq < 1e-30 { continue; } + + // Apply H = I - 2vv^T/v^Tv to R (columns k..n) + for j in k..n { + let mut dot = 0.0; + for i in k..m { dot += col[i - k] * r[i * n + j]; } + let factor = 2.0 * dot / v_norm_sq; + for i in k..m { r[i * n + j] -= factor * col[i - k]; } + } + // Apply H to Q (all columns) + for j in 0..m { + let mut dot = 0.0; + for i in k..m { dot += col[i - k] * q[i * m + j]; } + let factor = 2.0 * dot / v_norm_sq; + for i in k..m { q[i * m + j] -= factor * col[i - k]; } + } + } + // Transpose Q (we accumulated H*Q^T, need Q = (H*Q^T)^T) + let mut qt = vec![0.0; m * m]; + for i in 0..m { for j in 0..m { qt[i * m + j] = q[j * m + i]; } } + (qt, r) +} + +/// Solve Ax=b via QR. Handles overdetermined (least squares) systems. +fn qr_solve_impl(a: &[f64], b: &[f64], m: usize, n: usize) -> Vec { + let (q, r) = householder_qr(a, m, n); + // Compute Q^T * b + let mut qtb = vec![0.0; m]; + for i in 0..m { + for j in 0..m { qtb[i] += q[j * m + i] * b[j]; } + } + // Back substitution with R (use first n rows/cols) + let mut x = vec![0.0; n]; + for i in (0..n).rev() { + x[i] = qtb[i]; + for j in (i + 1)..n { x[i] -= r[i * n + j] * x[j]; } + if r[i * n + i].abs() > 1e-15 { x[i] /= r[i * n + i]; } + } + x +} + +// ── Matrix primitive implementations ── + +fn mat_trace(value: Value) -> EvalResult { + match &value { + Value::Tensor(t) => { + if t.shape.len() != 2 || t.shape[0] != t.shape[1] { + return Err(EvalError::shape_mismatch("trace requires square 2D tensor")); + } + let n = t.shape[0]; + let mut sum = 0.0f64; + for i in 0..n { + sum += t.get(&[i, i]).and_then(|v| v.coerce_float()).unwrap_or(0.0); + } + Ok(Value::Float(OrderedFloat(sum))) + } + _ => Err(EvalError::type_error("Tensor", &value)), + } +} + +fn mat_eye(value: Value) -> EvalResult { + let n = value.as_index().ok_or_else(|| EvalError::type_error("Int", &value))?; + if n == 0 { return Err(EvalError::shape_mismatch("eye requires positive size")); } + let mut data = vec![0.0f64; n * n]; + for i in 0..n { data[i * n + i] = 1.0; } + Ok(Value::Tensor(Rc::new(build_float_matrix(n, n, data)))) +} + +fn mat_diag(value: Value) -> EvalResult { + match &value { + Value::Tensor(t) => { + if t.shape.len() == 1 { + // Vector β†’ diagonal matrix + let n = t.shape[0]; + let mut data = vec![0.0f64; n * n]; + for i in 0..n { + data[i * n + i] = t.get(&[i]).and_then(|v| v.coerce_float()).unwrap_or(0.0); + } + Ok(Value::Tensor(Rc::new(build_float_matrix(n, n, data)))) + } else if t.shape.len() == 2 { + // Matrix β†’ diagonal vector + let n = t.shape[0].min(t.shape[1]); + let mut data = Vec::with_capacity(n); + for i in 0..n { + data.push(t.get(&[i, i]).and_then(|v| v.coerce_float()).unwrap_or(0.0)); + } + Ok(Value::Tensor(Rc::new(build_float_vector(data)))) + } else { + Err(EvalError::shape_mismatch("diag requires 1D or 2D tensor")) + } + } + _ => Err(EvalError::type_error("Tensor", &value)), + } +} + +fn mat_det(value: Value) -> EvalResult { + match &value { + Value::Tensor(t) => { + if t.shape.len() != 2 || t.shape[0] != t.shape[1] { + return Err(EvalError::shape_mismatch("det requires square 2D tensor")); + } + let n = t.shape[0]; + let mut a = tensor_to_floats(t)?; + let (_piv, sign) = lu_decompose(&mut a, n); + let product: f64 = (0..n).map(|i| a[i * n + i]).product(); + Ok(Value::Float(OrderedFloat(sign * product))) + } + _ => Err(EvalError::type_error("Tensor", &value)), + } +} + +fn mat_inv(value: Value) -> EvalResult { + match &value { + Value::Tensor(t) => { + if t.shape.len() != 2 || t.shape[0] != t.shape[1] { + return Err(EvalError::shape_mismatch("inv requires square 2D tensor")); + } + let n = t.shape[0]; + let mut a = tensor_to_floats(t)?; + let (piv, _sign) = lu_decompose(&mut a, n); + // Check singularity + for i in 0..n { + if a[i * n + i].abs() < 1e-12 { + return Err(EvalError::shape_mismatch("singular matrix: cannot compute inverse")); + } + } + // Solve A * X = I column by column + let mut result = vec![0.0f64; n * n]; + for col in 0..n { + let mut e_col = vec![0.0; n]; + e_col[col] = 1.0; + let x = lu_solve(&a, &piv, &e_col, n); + for row in 0..n { result[row * n + col] = x[row]; } + } + Ok(Value::Tensor(Rc::new(build_float_matrix(n, n, result)))) + } + _ => Err(EvalError::type_error("Tensor", &value)), + } +} + +fn solve_lu_impl(a_val: &Value, b_val: &Value) -> EvalResult { + match (a_val, b_val) { + (Value::Tensor(at), Value::Tensor(bt)) => { + if at.shape.len() != 2 || at.shape[0] != at.shape[1] { + return Err(EvalError::shape_mismatch("solve requires square 2D matrix")); + } + let n = at.shape[0]; + if bt.shape.len() != 1 || bt.shape[0] != n { + return Err(EvalError::shape_mismatch(format!( + "solve dimension mismatch: {}Γ—{} matrix vs length-{} vector", + n, n, bt.shape[0] + ))); + } + let mut a = tensor_to_floats(at)?; + let b = tensor_to_floats(bt)?; + let (piv, _sign) = lu_decompose(&mut a, n); + for i in 0..n { + if a[i * n + i].abs() < 1e-12 { + return Err(EvalError::shape_mismatch("singular matrix: cannot solve")); + } + } + let x = lu_solve(&a, &piv, &b, n); + Ok(Value::Tensor(Rc::new(build_float_vector(x)))) + } + _ => Err(EvalError::type_error_msg(format!( + "solve requires (Tensor, Tensor), got ({}, {})", + a_val.type_name(), b_val.type_name() + ))), + } +} + +fn solve_qr_impl(a_val: &Value, b_val: &Value) -> EvalResult { + match (a_val, b_val) { + (Value::Tensor(at), Value::Tensor(bt)) => { + if at.shape.len() != 2 { + return Err(EvalError::shape_mismatch("solveWith(qr) requires 2D matrix")); + } + let m = at.shape[0]; + let n = at.shape[1]; + if m < n { + return Err(EvalError::shape_mismatch("solveWith(qr) requires m >= n (not underdetermined)")); + } + if bt.shape.len() != 1 || bt.shape[0] != m { + return Err(EvalError::shape_mismatch(format!( + "solveWith(qr) dimension mismatch: {}Γ—{} matrix vs length-{} vector", + m, n, bt.shape[0] + ))); + } + let a = tensor_to_floats(at)?; + let b = tensor_to_floats(bt)?; + let x = qr_solve_impl(&a, &b, m, n); + Ok(Value::Tensor(Rc::new(build_float_vector(x)))) + } + _ => Err(EvalError::type_error_msg(format!( + "solveWith requires (Tensor, Tensor, String), got ({}, {})", + a_val.type_name(), b_val.type_name() + ))), + } +} + +fn mat_solve(a_val: Value, b_val: Value) -> EvalResult { + solve_lu_impl(&a_val, &b_val) +} + +fn mat_solve_with(a_val: Value, b_val: Value, method_val: Value) -> EvalResult { + let method_str = match &method_val { + Value::Tensor(t) => t.to_string_value() + .ok_or_else(|| EvalError::type_error_msg("solveWith: third argument must be a method string"))?, + _ => return Err(EvalError::type_error_msg(format!( + "solveWith: third argument must be a string, got {}", + method_val.type_name() + ))), + }; + match method_str.as_str() { + "lu" => solve_lu_impl(&a_val, &b_val), + "qr" => solve_qr_impl(&a_val, &b_val), + other => Err(EvalError::not_implemented(format!("solve method '{}' (available: lu, qr)", other))), + } +} diff --git a/crates/goth-eval/src/value.rs b/crates/goth-eval/src/value.rs index f61ebaa..c9fa7d4 100644 --- a/crates/goth-eval/src/value.rs +++ b/crates/goth-eval/src/value.rs @@ -10,6 +10,8 @@ use ordered_float::OrderedFloat; pub enum Value { Int(i128), Float(OrderedFloat), + Complex(f64, f64), + Quaternion(f64, f64, f64, f64), Bool(bool), Char(char), Unit, @@ -76,6 +78,9 @@ pub enum PrimFn { Lines, Words, Bytes, // String splitting for wc StrEq, StartsWith, EndsWith, Contains, // String comparison BitAnd, BitOr, BitXor, Shl, Shr, // Bitwise operations + Re, Im, Conj, Arg, // Complex/quaternion decomposition + Trace, Det, Inv, Diag, Eye, // Matrix utilities + Solve, SolveWith, // Linear system solvers } #[derive(Debug, Clone)] @@ -94,6 +99,8 @@ pub struct Env { impl Value { pub fn int(n: impl Into) -> Self { Value::Int(n.into()) } pub fn float(f: f64) -> Self { Value::Float(OrderedFloat(f)) } + pub fn complex(re: f64, im: f64) -> Self { Value::Complex(re, im) } + pub fn quaternion(w: f64, i: f64, j: f64, k: f64) -> Self { Value::Quaternion(w, i, j, k) } pub fn bool(b: bool) -> Self { Value::Bool(b) } pub fn char(c: char) -> Self { Value::Char(c) } pub fn unit() -> Self { Value::Unit } @@ -115,7 +122,7 @@ impl Value { pub fn is_int(&self) -> bool { matches!(self, Value::Int(_)) } pub fn is_float(&self) -> bool { matches!(self, Value::Float(_)) } - pub fn is_numeric(&self) -> bool { matches!(self, Value::Int(_) | Value::Float(_)) } + pub fn is_numeric(&self) -> bool { matches!(self, Value::Int(_) | Value::Float(_) | Value::Complex(_, _) | Value::Quaternion(_, _, _, _)) } pub fn is_bool(&self) -> bool { matches!(self, Value::Bool(_)) } pub fn is_tensor(&self) -> bool { matches!(self, Value::Tensor(_)) } pub fn is_callable(&self) -> bool { matches!(self, Value::Closure(_) | Value::Primitive(_) | Value::Partial { .. }) } @@ -127,13 +134,17 @@ impl Value { pub fn as_char(&self) -> Option { match self { Value::Char(c) => Some(*c), _ => None } } pub fn as_tensor(&self) -> Option<&Tensor> { match self { Value::Tensor(t) => Some(t), _ => None } } pub fn as_tuple(&self) -> Option<&[Value]> { match self { Value::Tuple(vs) => Some(vs), Value::Unit => Some(&[]), _ => None } } - pub fn coerce_float(&self) -> Option { match self { Value::Float(f) => Some(f.0), Value::Int(n) => Some(*n as f64), _ => None } } + pub fn coerce_float(&self) -> Option { match self { Value::Float(f) => Some(f.0), Value::Int(n) => Some(*n as f64), Value::Complex(re, im) if *im == 0.0 => Some(*re), _ => None } } + pub fn coerce_complex(&self) -> Option<(f64, f64)> { match self { Value::Complex(re, im) => Some((*re, *im)), Value::Float(f) => Some((f.0, 0.0)), Value::Int(n) => Some((*n as f64, 0.0)), _ => None } } + pub fn coerce_quaternion(&self) -> Option<(f64, f64, f64, f64)> { match self { Value::Quaternion(w, i, j, k) => Some((*w, *i, *j, *k)), Value::Complex(re, im) => Some((*re, *im, 0.0, 0.0)), Value::Float(f) => Some((f.0, 0.0, 0.0, 0.0)), Value::Int(n) => Some((*n as f64, 0.0, 0.0, 0.0)), _ => None } } /// Coerce to a usize index: accepts Int directly, or Float if it's a whole number. pub fn as_index(&self) -> Option { match self { Value::Int(n) => Some(*n as usize), Value::Float(f) => { let v = f.0; if v.fract() == 0.0 && v >= 0.0 { Some(v as usize) } else { None } }, _ => None } } pub fn type_name(&self) -> &'static str { match self { - Value::Int(_) => "Int", Value::Float(_) => "Float", Value::Bool(_) => "Bool", + Value::Int(_) => "Int", Value::Float(_) => "Float", + Value::Complex(_, _) => "Complex", Value::Quaternion(_, _, _, _) => "Quaternion", + Value::Bool(_) => "Bool", Value::Char(_) => "Char", Value::Unit => "Unit", Value::Tensor(_) => "Tensor", Value::Tuple(_) => "Tuple", Value::Record(_) => "Record", Value::Variant { .. } => "Variant", Value::Closure(_) => "Closure", Value::Primitive(_) => "Primitive", @@ -147,6 +158,9 @@ impl Value { match (self, other) { (Value::Int(a), Value::Int(b)) => a == b, (Value::Float(a), Value::Float(b)) => a == b, + (Value::Complex(r1, i1), Value::Complex(r2, i2)) => r1 == r2 && i1 == i2, + (Value::Quaternion(w1, i1, j1, k1), Value::Quaternion(w2, i2, j2, k2)) => + w1 == w2 && i1 == i2 && j1 == j2 && k1 == k2, (Value::Bool(a), Value::Bool(b)) => a == b, (Value::Char(a), Value::Char(b)) => a == b, (Value::Unit, Value::Unit) => true, @@ -299,6 +313,20 @@ impl std::fmt::Display for Value { match self { Value::Int(n) => write!(f, "{}", n), Value::Float(x) => write!(f, "{}", x.0), + Value::Complex(re, im) => { + if *re == 0.0 && *im == 0.0 { write!(f, "0") } + else if *re == 0.0 { write!(f, "{}π•š", im) } + else if *im == 0.0 { write!(f, "{}", re) } + else if *im > 0.0 { write!(f, "{} + {}π•š", re, im) } + else { write!(f, "{} - {}π•š", re, -im) } + } + Value::Quaternion(w, i, j, k) => { + write!(f, "{}", w)?; + if *i >= 0.0 { write!(f, " + {}π•š", i)?; } else { write!(f, " - {}π•š", -i)?; } + if *j >= 0.0 { write!(f, " + {}𝕛", j)?; } else { write!(f, " - {}𝕛", -j)?; } + if *k >= 0.0 { write!(f, " + {}π•œ", k)?; } else { write!(f, " - {}π•œ", -k)?; } + Ok(()) + } Value::Bool(true) => write!(f, "⊀"), Value::Bool(false) => write!(f, "βŠ₯"), Value::Char(c) => write!(f, "'{}'", c), diff --git a/crates/goth-parse/src/lexer.rs b/crates/goth-parse/src/lexer.rs index 9695f78..5ca58c1 100644 --- a/crates/goth-parse/src/lexer.rs +++ b/crates/goth-parse/src/lexer.rs @@ -353,6 +353,48 @@ pub enum Token { #[regex(r"β‚€|₁|β‚‚|₃|β‚„|β‚…|₆|₇|β‚ˆ|₉|_[0-9]+", priority = 3, callback = |lex| parse_index(lex.slice()))] Index(u32), + // ============ Imaginary Literals ============ + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?\u{1D55A}", priority = 3, callback = |lex| { + let s = lex.slice(); + let num = &s[..s.len() - "\u{1D55A}".len()]; + num.parse::().ok() + })] + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?i", priority = 3, callback = |lex| { + let s = lex.slice(); + s[..s.len() - 1].parse::().ok() + })] + ImagLitI(f64), + + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?\u{1D55B}", priority = 3, callback = |lex| { + let s = lex.slice(); + let num = &s[..s.len() - "\u{1D55B}".len()]; + num.parse::().ok() + })] + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?j", priority = 3, callback = |lex| { + let s = lex.slice(); + s[..s.len() - 1].parse::().ok() + })] + ImagLitJ(f64), + + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?\u{1D55C}", priority = 3, callback = |lex| { + let s = lex.slice(); + let num = &s[..s.len() - "\u{1D55C}".len()]; + num.parse::().ok() + })] + #[regex(r"[0-9]+(\.[0-9]+([eE][+-]?[0-9]+)?)?k", priority = 3, callback = |lex| { + let s = lex.slice(); + s[..s.len() - 1].parse::().ok() + })] + ImagLitK(f64), + + // Standalone Unicode imaginary units + #[token("\u{1D55A}", priority = 5)] + ImagI, + #[token("\u{1D55B}", priority = 5)] + ImagJ, + #[token("\u{1D55C}", priority = 5)] + ImagK, + // ============ Literals ============ #[regex(r"[0-9]+", priority = 2, callback = |lex| lex.slice().parse::().ok())] Int(i128), @@ -417,6 +459,12 @@ pub enum Token { TyInt, #[token("Unit")] TyUnit, + #[token("\u{2102}")] + #[token("Complex")] + TyComplex, + #[token("\u{210D}")] + #[token("Quaternion")] + TyQuaternion, /// Invalid/unrecognized token (not generated by logos; constructed by the lexer) Error(String), @@ -526,6 +574,8 @@ impl Token { pub fn can_start_expr(&self) -> bool { matches!(self, Token::Int(_) | Token::Float(_) | Token::String(_) | Token::Char(_) | + Token::ImagLitI(_) | Token::ImagLitJ(_) | Token::ImagLitK(_) | + Token::ImagI | Token::ImagJ | Token::ImagK | Token::True | Token::False | Token::Pi | Token::Euler | Token::Ident(_) | Token::TyVar(_) | Token::AplIdent(_) | Token::Lambda | Token::LParen | Token::LBracket | Token::LAngle | @@ -587,6 +637,12 @@ impl fmt::Display for Token { Token::Bind => write!(f, " "), Token::Sum => write!(f, "Ξ£"), Token::Prod => write!(f, "Ξ "), + Token::ImagLitI(x) => write!(f, "{}π•š", x), + Token::ImagLitJ(x) => write!(f, "{}𝕛", x), + Token::ImagLitK(x) => write!(f, "{}π•œ", x), + Token::ImagI => write!(f, "π•š"), + Token::ImagJ => write!(f, "𝕛"), + Token::ImagK => write!(f, "π•œ"), Token::Error(s) => write!(f, "", s), _ => write!(f, "{:?}", self), } @@ -907,4 +963,56 @@ mod tests { assert_eq!(lex.next(), Some(Token::Pure)); assert_eq!(lex.next(), Some(Token::Diamond)); } + + // ============ Imaginary Literal Tests ============ + + #[test] + fn test_imaginary_int_unicode() { + let mut lex = Lexer::new("4\u{1D55A}"); + assert_eq!(lex.next(), Some(Token::ImagLitI(4.0))); + } + + #[test] + fn test_imaginary_float_unicode() { + let mut lex = Lexer::new("3.14\u{1D55A}"); + assert_eq!(lex.next(), Some(Token::ImagLitI(3.14))); + } + + #[test] + fn test_imaginary_ascii_suffix() { + let mut lex = Lexer::new("4i 2.5j 3k"); + assert_eq!(lex.next(), Some(Token::ImagLitI(4.0))); + assert_eq!(lex.next(), Some(Token::ImagLitJ(2.5))); + assert_eq!(lex.next(), Some(Token::ImagLitK(3.0))); + } + + #[test] + fn test_standalone_imaginary_unicode() { + let mut lex = Lexer::new("\u{1D55A} \u{1D55B} \u{1D55C}"); + assert_eq!(lex.next(), Some(Token::ImagI)); + assert_eq!(lex.next(), Some(Token::ImagJ)); + assert_eq!(lex.next(), Some(Token::ImagK)); + } + + #[test] + fn test_complex_type_tokens() { + let mut lex = Lexer::new("\u{2102} \u{210D}"); + assert_eq!(lex.next(), Some(Token::TyComplex)); + assert_eq!(lex.next(), Some(Token::TyQuaternion)); + } + + #[test] + fn test_complex_type_ascii() { + let mut lex = Lexer::new("Complex Quaternion"); + assert_eq!(lex.next(), Some(Token::TyComplex)); + assert_eq!(lex.next(), Some(Token::TyQuaternion)); + } + + #[test] + fn test_complex_expression_tokens() { + let mut lex = Lexer::new("3 + 4\u{1D55A}"); + assert_eq!(lex.next(), Some(Token::Int(3))); + assert_eq!(lex.next(), Some(Token::Plus)); + assert_eq!(lex.next(), Some(Token::ImagLitI(4.0))); + } } diff --git a/crates/goth-parse/src/parser.rs b/crates/goth-parse/src/parser.rs index 9812923..51182c2 100644 --- a/crates/goth-parse/src/parser.rs +++ b/crates/goth-parse/src/parser.rs @@ -312,6 +312,12 @@ impl<'a> Parser<'a> { Some(Token::False) => { self.next(); Expr::Lit(Literal::False) } Some(Token::Pi) => { self.next(); Expr::Lit(Literal::Float(std::f64::consts::PI)) } Some(Token::Euler) => { self.next(); Expr::Lit(Literal::Float(std::f64::consts::E)) } + Some(Token::ImagLitI(f)) => { self.next(); Expr::Lit(Literal::ImagI(f)) } + Some(Token::ImagLitJ(f)) => { self.next(); Expr::Lit(Literal::ImagJ(f)) } + Some(Token::ImagLitK(f)) => { self.next(); Expr::Lit(Literal::ImagK(f)) } + Some(Token::ImagI) => { self.next(); Expr::Lit(Literal::ImagI(1.0)) } + Some(Token::ImagJ) => { self.next(); Expr::Lit(Literal::ImagJ(1.0)) } + Some(Token::ImagK) => { self.next(); Expr::Lit(Literal::ImagK(1.0)) } // De Bruijn index Some(Token::Index(i)) => { self.next(); Expr::Idx(i) } @@ -940,6 +946,8 @@ impl<'a> Parser<'a> { Some(Token::TyNat) => { self.next(); Ok(Type::Prim(PrimType::Nat)) } Some(Token::TyInt) => { self.next(); Ok(Type::Prim(PrimType::Int)) } Some(Token::TyUnit) => { self.next(); Ok(Type::Tuple(vec![])) } + Some(Token::TyComplex) => { self.next(); Ok(Type::Prim(PrimType::Complex)) } + Some(Token::TyQuaternion) => { self.next(); Ok(Type::Prim(PrimType::Quaternion)) } // Type variable Some(Token::TyVar(v)) => { self.next(); Ok(Type::Var(v.into())) } diff --git a/examples/complex/complex_arithmetic.goth b/examples/complex/complex_arithmetic.goth new file mode 100644 index 0000000..a90c844 --- /dev/null +++ b/examples/complex/complex_arithmetic.goth @@ -0,0 +1,6 @@ +# Complex arithmetic: (3+4π•š) Γ— (1+2π•š) = -5+10π•š +# Postcondition verifies real and imaginary parts + +╭─ main : I64 β†’ β„‚ +β”‚ ⊨ re(β‚€) = 0.0 - 5.0 ∧ im(β‚€) = 10.0 +╰─ (3 + 4π•š) * (1 + 2π•š) diff --git a/examples/complex/euler_identity.goth b/examples/complex/euler_identity.goth new file mode 100644 index 0000000..15ca179 --- /dev/null +++ b/examples/complex/euler_identity.goth @@ -0,0 +1,6 @@ +# Euler's identity: e^(Ο€π•š) + 1 β‰ˆ 0 +# Postcondition verifies the result is near zero + +╭─ main : I64 β†’ β„‚ +β”‚ ⊨ abs(β‚€) < 0.0001 +╰─ exp(Ο€ * π•š) + 1 diff --git a/examples/complex/quaternion_rotation.goth b/examples/complex/quaternion_rotation.goth new file mode 100644 index 0000000..f51df03 --- /dev/null +++ b/examples/complex/quaternion_rotation.goth @@ -0,0 +1,6 @@ +# Hamilton's identity: π•š Γ— 𝕛 Γ— π•œ = -1 +# Postcondition verifies the result is exactly -1 + +╭─ main : I64 β†’ ℍ +β”‚ ⊨ re(β‚€) = 0.0 - 1.0 ∧ abs(β‚€ + 1) < 0.0001 +╰─ π•š * 𝕛 * π•œ diff --git a/examples/complex/sqrt_negative.goth b/examples/complex/sqrt_negative.goth new file mode 100644 index 0000000..b23b131 --- /dev/null +++ b/examples/complex/sqrt_negative.goth @@ -0,0 +1,6 @@ +# Square root of negative number: √(-4) = 2π•š +# Postcondition: squaring the result gives back -4 + +╭─ main : I64 β†’ β„‚ +β”‚ ⊨ abs(β‚€ * β‚€ + 4) < 0.0001 +╰─ √(0 - 4) diff --git a/examples/linalg/determinant.goth b/examples/linalg/determinant.goth new file mode 100644 index 0000000..47582e8 --- /dev/null +++ b/examples/linalg/determinant.goth @@ -0,0 +1,6 @@ +# Determinant of a 3x3 matrix +# Postcondition verifies the known value + +╭─ main : I64 β†’ F64 +β”‚ ⊨ abs(β‚€ + 306.0) < 0.001 +╰─ det [[6,1,1],[4,0-2,5],[2,8,7]] diff --git a/examples/linalg/matrix_inverse.goth b/examples/linalg/matrix_inverse.goth new file mode 100644 index 0000000..f5ec0d1 --- /dev/null +++ b/examples/linalg/matrix_inverse.goth @@ -0,0 +1,6 @@ +# Compute the inverse of a 2x2 matrix +# Postcondition: A Γ— A⁻¹ has trace 2 (identity) + +╭─ main : I64 β†’ [m][n]F64 +β”‚ ⊨ abs(tr(matmul [[1,2],[3,4]] β‚€) - 2.0) < 0.001 +╰─ inv [[1,2],[3,4]] diff --git a/examples/linalg/solve_linear.goth b/examples/linalg/solve_linear.goth new file mode 100644 index 0000000..7bd7aa7 --- /dev/null +++ b/examples/linalg/solve_linear.goth @@ -0,0 +1,7 @@ +# Solve a 3x3 linear system Ax = b +# A = [[1,1,1],[0,2,5],[2,5,-1]], b = [6,-4,27] +# Postcondition: solution components match expected [5, 3, -2] + +╭─ main : I64 β†’ [n]F64 +β”‚ ⊨ abs(β‚€[0] - 5.0) < 0.001 ∧ abs(β‚€[1] - 3.0) < 0.001 ∧ abs(β‚€[2] + 2.0) < 0.001 +╰─ solve [[1,1,1],[0,2,5],[2,5,-1]] [6,-4,27]