Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions crates/goth-ast/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,9 @@ impl std::fmt::Display for Expr {
Literal::True => write!(f, "⊤"),
Literal::False => write!(f, "⊥"),
Literal::Unit => write!(f, "⟨⟩"),
Literal::ImagI(x) => write!(f, "{}𝕚", x),
Literal::ImagJ(x) => write!(f, "{}𝕛", x),
Literal::ImagK(x) => write!(f, "{}𝕜", x),
},
Expr::Prim(name) => write!(f, "⊥{}", name),
Expr::App(func, arg) => write!(f, "({} {})", func, arg),
Expand Down
15 changes: 14 additions & 1 deletion crates/goth-ast/src/literal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,15 @@ pub enum Literal {

/// Unit value (empty tuple)
Unit,

/// Imaginary-i literal (coefficient stored as f64)
ImagI(f64),

/// Imaginary-j literal (quaternion j component)
ImagJ(f64),

/// Imaginary-k literal (quaternion k component)
ImagK(f64),
}

impl Literal {
Expand All @@ -40,6 +49,10 @@ impl Literal {
Literal::String(s.into())
}

pub fn imag_i(f: f64) -> Self { Literal::ImagI(f) }
pub fn imag_j(f: f64) -> Self { Literal::ImagJ(f) }
pub fn imag_k(f: f64) -> Self { Literal::ImagK(f) }

pub fn bool(b: bool) -> Self {
if b { Literal::True } else { Literal::False }
}
Expand All @@ -60,7 +73,7 @@ impl Literal {

/// Check if this is a numeric literal
pub fn is_numeric(&self) -> bool {
matches!(self, Literal::Int(_) | Literal::Float(_))
matches!(self, Literal::Int(_) | Literal::Float(_) | Literal::ImagI(_) | Literal::ImagJ(_) | Literal::ImagK(_))
}
}

Expand Down
6 changes: 6 additions & 0 deletions crates/goth-ast/src/pretty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,9 @@ impl Pretty {
self.write("'");
}
Literal::Unit => self.write("()"),
Literal::ImagI(x) => { self.write(&x.to_string()); self.write("𝕚"); }
Literal::ImagJ(x) => { self.write(&x.to_string()); self.write("𝕛"); }
Literal::ImagK(x) => { self.write(&x.to_string()); self.write("𝕜"); }
}

Expr::Name(name) => self.write(name),
Expand Down Expand Up @@ -541,6 +544,9 @@ impl Pretty {
self.write("'");
}
Literal::Unit => self.write("()"),
Literal::ImagI(x) => { self.write(&x.to_string()); self.write("𝕚"); }
Literal::ImagJ(x) => { self.write(&x.to_string()); self.write("𝕛"); }
Literal::ImagK(x) => { self.write(&x.to_string()); self.write("𝕜"); }
}
}
Pattern::Tuple(pats) => {
Expand Down
14 changes: 12 additions & 2 deletions crates/goth-ast/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,10 @@ pub enum PrimType {
// Arbitrary precision (for compile-time computation)
Nat, // ℕ - natural numbers
Int, // ℤ - integers

// Complex number types
Complex, // ℂ - complex (f64, f64)
Quaternion, // ℍ - quaternion (f64, f64, f64, f64)
}

/// Type representation
Expand Down Expand Up @@ -173,6 +177,8 @@ impl Type {
pub fn bool() -> Self { Type::Prim(PrimType::Bool) }
pub fn char() -> Self { Type::Prim(PrimType::Char) }
pub fn nat() -> Self { Type::Prim(PrimType::Nat) }
pub fn complex() -> Self { Type::Prim(PrimType::Complex) }
pub fn quaternion() -> Self { Type::Prim(PrimType::Quaternion) }

// Tensor
pub fn tensor(shape: Shape, elem: Type) -> Self {
Expand Down Expand Up @@ -282,7 +288,7 @@ impl PrimType {
}

pub fn is_float(&self) -> bool {
matches!(self, PrimType::F64 | PrimType::F32)
matches!(self, PrimType::F64 | PrimType::F32 | PrimType::Complex | PrimType::Quaternion)
}

pub fn is_int(&self) -> bool {
Expand All @@ -306,7 +312,9 @@ impl PrimType {
PrimType::F32 | PrimType::I32 | PrimType::U32 | PrimType::Char => Some(32),
PrimType::I16 | PrimType::U16 => Some(16),
PrimType::I8 | PrimType::U8 | PrimType::Byte | PrimType::Bool => Some(8),
PrimType::Nat | PrimType::Int | PrimType::String => None, // Variable/arbitrary size
PrimType::Complex => Some(128), // 2 × f64
PrimType::Quaternion => None, // 4 × f64 = 256, doesn't fit u8
PrimType::Nat | PrimType::Int | PrimType::String => None,
}
}
}
Expand All @@ -332,6 +340,8 @@ impl std::fmt::Display for PrimType {
PrimType::String => write!(f, "String"),
PrimType::Nat => write!(f, "ℕ"),
PrimType::Int => write!(f, "ℤ"),
PrimType::Complex => write!(f, "ℂ"),
PrimType::Quaternion => write!(f, "ℍ"),
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions crates/goth-check/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,8 @@ fn literal_type(lit: &Literal) -> Type {
Box::new(Type::Prim(PrimType::Char)),
),
Literal::Unit => Type::unit(),
Literal::ImagI(_) => Type::Prim(PrimType::Complex),
Literal::ImagJ(_) | Literal::ImagK(_) => Type::Prim(PrimType::Quaternion),
}
}

Expand Down
50 changes: 48 additions & 2 deletions crates/goth-eval/src/eval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,13 @@ impl Evaluator {
// String comparison
("strEq", PrimFn::StrEq), ("startsWith", PrimFn::StartsWith),
("endsWith", PrimFn::EndsWith), ("contains", PrimFn::Contains),
// Complex/quaternion decomposition
("re", PrimFn::Re), ("im", PrimFn::Im), ("conj", PrimFn::Conj), ("arg", PrimFn::Arg),
// Matrix utilities
("trace", PrimFn::Trace), ("tr", PrimFn::Trace),
("det", PrimFn::Det), ("inv", PrimFn::Inv),
("diag", PrimFn::Diag), ("eye", PrimFn::Eye),
("solve", PrimFn::Solve), ("solveWith", PrimFn::SolveWith),
];
for (name, prim) in prims { self.globals.borrow_mut().insert(name.to_string(), Value::Primitive(*prim)); }
// Register stream constants
Expand Down Expand Up @@ -152,7 +159,7 @@ impl Evaluator {
}

fn eval_literal(&self, lit: &Literal) -> Value {
match lit { Literal::Int(n) => Value::Int(*n), Literal::Float(f) => Value::float(*f), Literal::Char(c) => Value::Char(*c), Literal::String(s) => Value::string(s), Literal::True => Value::Bool(true), Literal::False => Value::Bool(false), Literal::Unit => Value::Unit }
match lit { Literal::Int(n) => Value::Int(*n), Literal::Float(f) => Value::float(*f), Literal::Char(c) => Value::Char(*c), Literal::String(s) => Value::string(s), Literal::True => Value::Bool(true), Literal::False => Value::Bool(false), Literal::Unit => Value::Unit, Literal::ImagI(f) => Value::Complex(0.0, *f), Literal::ImagJ(f) => Value::Quaternion(0.0, 0.0, *f, 0.0), Literal::ImagK(f) => Value::Quaternion(0.0, 0.0, 0.0, *f) }
}

fn eval_binop(&mut self, op: &BinOp, left: &Expr, right: &Expr, env: &Env) -> EvalResult<Value> {
Expand Down Expand Up @@ -419,7 +426,42 @@ impl Evaluator {
else if all_float { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Float(values.iter().map(|v| ordered_float::OrderedFloat(v.coerce_float().unwrap())).collect()) })) }
else if all_bool { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Bool(values.iter().map(|v| v.as_bool().unwrap()).collect()) })) }
else if all_char { Value::Tensor(Rc::new(Tensor { shape, data: crate::value::TensorData::Char(values.iter().map(|v| v.as_char().unwrap()).collect()) })) }
else { Value::Tensor(Rc::new(Tensor::from_values(shape, values))) }
else {
// Auto-flatten: if all values are tensors with the same shape and compatible data,
// combine into a higher-dimensional tensor (e.g., [[1,2],[3,4]] → shape [2,2])
let all_tensor_same_shape = if let Some(Value::Tensor(first)) = values.first() {
let s = &first.shape;
values.iter().all(|v| matches!(v, Value::Tensor(t) if &t.shape == s))
} else { false };
if all_tensor_same_shape {
let sub_tensors: Vec<&Tensor> = values.iter().map(|v| match v { Value::Tensor(t) => t.as_ref(), _ => unreachable!() }).collect();
let sub_shape = &sub_tensors[0].shape;
let mut new_shape = shape.clone();
new_shape.extend_from_slice(sub_shape);
// Try to flatten into typed tensor
let all_int = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Int(_)));
let all_float = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Int(_) | crate::value::TensorData::Float(_)));
let all_bool = sub_tensors.iter().all(|t| matches!(t.data, crate::value::TensorData::Bool(_)));
if all_int {
let data: Vec<i128> = sub_tensors.iter().flat_map(|t| match &t.data { crate::value::TensorData::Int(v) => v.iter().copied(), _ => unreachable!() }).collect();
Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Int(data) }))
} else if all_float {
let data: Vec<ordered_float::OrderedFloat<f64>> = sub_tensors.iter().flat_map(|t| match &t.data {
crate::value::TensorData::Float(v) => v.clone(),
crate::value::TensorData::Int(v) => v.iter().map(|&i| ordered_float::OrderedFloat(i as f64)).collect(),
_ => unreachable!()
}).collect();
Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Float(data) }))
} else if all_bool {
let data: Vec<bool> = sub_tensors.iter().flat_map(|t| match &t.data { crate::value::TensorData::Bool(v) => v.iter().copied(), _ => unreachable!() }).collect();
Value::Tensor(Rc::new(Tensor { shape: new_shape, data: crate::value::TensorData::Bool(data) }))
} else {
Value::Tensor(Rc::new(Tensor::from_values(shape, values)))
}
} else {
Value::Tensor(Rc::new(Tensor::from_values(shape, values)))
}
}
}

fn access_field(&self, val: Value, access: &FieldAccess) -> EvalResult<Value> {
Expand Down Expand Up @@ -605,6 +647,10 @@ fn prim_arity(prim: PrimFn) -> usize {
PrimFn::Print | PrimFn::Write | PrimFn::ReadLine | PrimFn::ReadKey | PrimFn::ReadFile | PrimFn::Sleep => 1,
PrimFn::Flush | PrimFn::RawModeEnter | PrimFn::RawModeExit => 1, // Terminal control (take unit)
PrimFn::Lines | PrimFn::Words | PrimFn::Bytes => 1, // String splitting (unary)
PrimFn::Re | PrimFn::Im | PrimFn::Conj | PrimFn::Arg => 1, // Complex decomposition
PrimFn::Trace | PrimFn::Det | PrimFn::Inv | PrimFn::Diag | PrimFn::Eye => 1, // Matrix utilities
PrimFn::Solve => 2, // Linear solve (default LU)
PrimFn::SolveWith => 3, // Linear solve with method string
PrimFn::WriteFile | PrimFn::ReadBytes | PrimFn::WriteBytes => 2, // Binary I/O takes 2 args
PrimFn::Fold => 3, // fold f acc arr
PrimFn::StrEq | PrimFn::StartsWith | PrimFn::EndsWith | PrimFn::Contains => 2, // String comparison (binary)
Expand Down
Loading
Loading