diff --git a/Cargo.toml b/Cargo.toml index 1278e9126..027b40a77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "air", "codegen/winterfell", "codegen/ace", + "types", ] resolver = "2" @@ -22,3 +23,4 @@ rust-version = "1.89" anyhow = "1.0" miden-diagnostics = "0.1" thiserror = "2.0" +air-types = { version = "1.0", path = "types" } diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index ed15f61b6..815cca86e 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -252,7 +252,7 @@ impl AirBuilder<'_> { )); }; - let ConstantValue::Felt(rhs_value) = constant_value else { + let ConstantValue::Scalar(rhs_value) = constant_value else { return Err(CompileError::SemanticAnalysis( SemanticAnalysisError::InvalidExpr( ast::InvalidExprError::NonConstantExponent(rhs.span()), @@ -267,7 +267,7 @@ impl AirBuilder<'_> { let value = match mir_value { MirValue::Constant(constant_value) => { - if let ConstantValue::Felt(felt) = constant_value { + if let ConstantValue::Scalar(felt) = constant_value { crate::ir::Value::Constant(*felt) } else { unreachable!() @@ -323,7 +323,7 @@ impl AirBuilder<'_> { let value = match mir_value { MirValue::Constant(constant_value) => { - if let ConstantValue::Felt(felt) = constant_value { + if let ConstantValue::Scalar(felt) = constant_value { crate::ir::Value::Constant(*felt) } else { unreachable!() @@ -610,11 +610,7 @@ impl AirBuilder<'_> { .emit(); return Err(CompileError::Failed); } - MirTraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - } + MirTraceAccess::new(trace_access_binding.segment, trace_access_binding.offset, 0) }, SpannedMirValue { value: MirValue::BusAccess(bus_access), .. diff --git a/air/src/tests/trace.rs b/air/src/tests/trace.rs index 8bb69e9c1..065a3ea6a 100644 --- a/air/src/tests/trace.rs +++ b/air/src/tests/trace.rs @@ -151,5 +151,5 @@ fn err_ic_trace_cols_group_used_as_scalar() { enf a[0]' = a + clk; }"; - expect_diagnostic(source, "type mismatch"); + expect_diagnostic(source, "invalid binary expression"); } diff --git a/mir/Cargo.toml b/mir/Cargo.toml index f1bce99c5..c1d8e0c43 100644 --- a/mir/Cargo.toml +++ b/mir/Cargo.toml @@ -14,6 +14,7 @@ edition.workspace = true [dependencies] air-parser = { package = "air-parser", path = "../parser", version = "0.5" } air-pass = { package = "air-pass", path = "../pass", version = "0.5" } +air-types.workspace = true anyhow = { workspace = true } derive-ir = { package = "air-derive-ir", path = "./derive-ir", version = "0.5" } miden-core = { package = "miden-core", version = "0.13", default-features = false } @@ -21,4 +22,4 @@ miden-diagnostics = { workspace = true } pretty_assertions = "1.4" rand = "0.9" thiserror = { workspace = true } -winter-math = { package = "winter-math", version = "0.12", default-features = false } \ No newline at end of file +winter-math = { package = "winter-math", version = "0.12", default-features = false } diff --git a/mir/derive-ir/src/builder.rs b/mir/derive-ir/src/builder.rs index 92857403b..fd3b8398b 100644 --- a/mir/derive-ir/src/builder.rs +++ b/mir/derive-ir/src/builder.rs @@ -564,20 +564,24 @@ fn make_build_method( match enum_wrapper { EnumWrapper::Op => quote! { pub fn build(&self) -> crate::ir::Link { - Op::#name( + let mut op = Op::#name( #name { #(#fields),* } - ).into() + ); + op.finalize_hook(); + op.into() } }, EnumWrapper::Root => quote! { pub fn build(&self) -> crate::ir::Link { - Root::#name( + let mut root = Root::#name( #name { #(#fields),* } - ).into() + ); + root.finalize_hook(); + root.into() } }, } diff --git a/mir/src/ir/link.rs b/mir/src/ir/link.rs index db1cab648..1cbdc4cd5 100644 --- a/mir/src/ir/link.rs +++ b/mir/src/ir/link.rs @@ -5,6 +5,8 @@ use std::{ rc::{Rc, Weak}, }; +use air_types::Typing; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; /// A wrapper around a `Rc>` to allow custom trait implementations. @@ -110,6 +112,24 @@ where } } +impl Typing for Link { + fn ty(&self) -> Option { + self.borrow().ty() + } +} + +impl ScalarTypeMut for Link { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Link { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_ty_unchecked(new_ty); + } +} + /// A wrapper around a `Option>>` to allow custom trait implementations. /// Used instead of `Link` where a `Link` would create a cyclIc reference. pub struct BackLink { diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index a23623eb2..05b4ba2ff 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -7,8 +7,11 @@ mod nodes; mod owner; mod quad_eval; mod utils; +pub extern crate air_types; pub extern crate derive_ir; +#[allow(unused_imports)] +pub use air_types::*; pub use bus::Bus; pub use derive_ir::Builder; pub use graph::Graph; @@ -113,3 +116,7 @@ pub trait Builder { /// Create a new empty builder that exposes all fields fn builder() -> Self::Empty; } + +pub trait BuilderHook { + fn finalize_hook(&mut self) {} +} diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 9bf47a696..6974893cc 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -1,8 +1,8 @@ use std::ops::Deref; -use miden_diagnostics::{SourceSpan, Spanned}; +use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root, Stale}; /// All the nodes that can be in the MIR Graph /// Combines all [Root] and [Op] variants @@ -33,7 +33,8 @@ pub enum Node { BusOp(BackLink), Parameter(BackLink), Value(BackLink), - None(SourceSpan), + Cast(BackLink), + None(Stale), } impl Default for Node { @@ -64,6 +65,7 @@ impl PartialEq for Node { (Node::BusOp(lhs), Node::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Parameter(lhs), Node::Parameter(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Value(lhs), Node::Value(rhs)) => lhs.to_link() == rhs.to_link(), + (Node::Cast(lhs), Node::Cast(rhs)) => lhs.to_link() == rhs.to_link(), (Node::None(_), Node::None(_)) => true, _ => false, } @@ -92,6 +94,7 @@ impl std::hash::Hash for Node { Node::BusOp(b) => b.to_link().hash(state), Node::Parameter(p) => p.to_link().hash(state), Node::Value(v) => v.to_link().hash(state), + Node::Cast(c) => c.to_link().hash(state), Node::None(_) => (), } } @@ -119,6 +122,7 @@ impl Parent for Node { Node::BusOp(b) => b.children(), Node::Parameter(_p) => Link::default(), Node::Value(_v) => Link::default(), + Node::Cast(c) => c.children(), Node::None(_) => Link::default(), } } @@ -146,6 +150,7 @@ impl Child for Node { Node::BusOp(b) => b.get_parents(), Node::Parameter(p) => p.get_parents(), Node::Value(v) => v.get_parents(), + Node::Cast(c) => c.get_parents(), Node::None(_) => Vec::default(), } } @@ -169,6 +174,7 @@ impl Child for Node { Node::BusOp(b) => b.add_parent(parent), Node::Parameter(p) => p.add_parent(parent), Node::Value(v) => v.add_parent(parent), + Node::Cast(c) => c.add_parent(parent), Node::None(_) => (), } } @@ -192,6 +198,7 @@ impl Child for Node { Node::BusOp(b) => b.remove_parent(parent), Node::Parameter(p) => p.remove_parent(parent), Node::Value(v) => v.remove_parent(parent), + Node::Cast(c) => c.remove_parent(parent), Node::None(_) => (), } } @@ -220,17 +227,18 @@ impl Link { Op::BusOp(_) => Node::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => Node::Parameter(BackLink::from(op_inner_val)), Op::Value(_) => Node::Value(BackLink::from(op_inner_val)), - Op::None(span) => Node::None(*span), + Op::Cast(_) => Node::Cast(BackLink::from(op_inner_val)), + Op::None(none) => Node::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { to_update = match root_inner_val.clone().borrow().deref() { Root::Function(_) => Node::Function(BackLink::from(root_inner_val)), Root::Evaluator(_) => Node::Evaluator(BackLink::from(root_inner_val)), - Root::None(span) => Node::None(*span), + Root::None(none) => Node::None(none.clone()), }; } else { // If the [Node] is stale, we set it to None - to_update = Node::None(self.span()); + to_update = Node::None(Stale { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; @@ -276,6 +284,7 @@ impl Link { Node::BusOp(_) => None, Node::Parameter(_) => None, Node::Value(_) => None, + Node::Cast(_) => None, Node::None(_) => None, } } @@ -301,6 +310,7 @@ impl Link { Node::BusOp(inner) => inner.to_link(), Node::Parameter(inner) => inner.to_link(), Node::Value(inner) => inner.to_link(), + Node::Cast(inner) => inner.to_link(), Node::None(_) => None, } } diff --git a/mir/src/ir/nodes/mod.rs b/mir/src/ir/nodes/mod.rs index 5ddf29733..2ca8f45bb 100644 --- a/mir/src/ir/nodes/mod.rs +++ b/mir/src/ir/nodes/mod.rs @@ -1,3 +1,4 @@ +pub mod stale; mod op; mod ops; mod root; @@ -5,6 +6,7 @@ mod roots; use std::cell::{Ref, RefMut}; +pub use stale::*; pub use op::Op; pub use ops::*; pub use root::Root; diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index 4d903dc5b..deaebf081 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -3,12 +3,13 @@ use std::{ ops::{Deref, DerefMut}, }; -use miden_diagnostics::{SourceSpan, Spanned}; +use air_types::*; +use miden_diagnostics::Spanned; use crate::ir::{ - Accessor, Add, BackLink, Boundary, BusOp, Call, Child, ConstantValue, Enf, Exp, Fold, For, If, - Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, SpannedMirValue, Sub, - Value, Vector, get_inner, get_inner_mut, + Accessor, Add, BackLink, Boundary, BuilderHook, BusOp, Call, Cast, Child, ConstantValue, Enf, + Exp, Fold, For, If, Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, + SpannedMirValue, Stale, Sub, Value, Vector, get_inner, get_inner_mut, }; /// The combined [Op]s and leaves of the MIR Graph. @@ -33,7 +34,33 @@ pub enum Op { BusOp(BusOp), Parameter(Parameter), Value(Value), - None(SourceSpan), + Cast(Cast), + None(Stale), +} + +impl BuilderHook for Op { + fn finalize_hook(&mut self) { + match self { + Op::Enf(e) => e.finalize_hook(), + Op::Boundary(b) => b.finalize_hook(), + Op::Add(a) => a.finalize_hook(), + Op::Sub(s) => s.finalize_hook(), + Op::Mul(m) => m.finalize_hook(), + Op::Exp(e) => e.finalize_hook(), + Op::If(i) => i.finalize_hook(), + Op::For(f) => f.finalize_hook(), + Op::Call(c) => c.finalize_hook(), + Op::Fold(f) => f.finalize_hook(), + Op::Vector(v) => v.finalize_hook(), + Op::Matrix(m) => m.finalize_hook(), + Op::Accessor(a) => a.finalize_hook(), + Op::BusOp(b) => b.finalize_hook(), + Op::Parameter(p) => p.finalize_hook(), + Op::Value(v) => v.finalize_hook(), + Op::Cast(c) => c.finalize_hook(), + Op::None(_) => {}, + } + } } impl Default for Op { @@ -62,6 +89,7 @@ impl Parent for Op { Op::BusOp(b) => b.children(), Op::Parameter(_) => Link::default(), Op::Value(_) => Link::default(), + Op::Cast(c) => c.children(), Op::None(_) => Link::default(), } } @@ -87,6 +115,7 @@ impl Child for Op { Op::BusOp(b) => b.get_parents(), Op::Parameter(p) => p.get_parents(), Op::Value(v) => v.get_parents(), + Op::Cast(c) => c.get_parents(), Op::None(_) => Default::default(), } } @@ -108,6 +137,7 @@ impl Child for Op { Op::BusOp(b) => b.add_parent(parent), Op::Parameter(p) => p.add_parent(parent), Op::Value(v) => v.add_parent(parent), + Op::Cast(c) => c.add_parent(parent), Op::None(_) => {}, } } @@ -129,11 +159,87 @@ impl Child for Op { Op::BusOp(b) => b.remove_parent(parent), Op::Parameter(p) => p.remove_parent(parent), Op::Value(v) => v.remove_parent(parent), + Op::Cast(c) => c.remove_parent(parent), Op::None(_) => {}, } } } +impl ScalarTypeMut for Op { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + match self { + Op::Enf(e) => e.update_scalar_ty_unchecked(new_ty), + Op::Boundary(b) => b.update_scalar_ty_unchecked(new_ty), + Op::Add(a) => a.update_scalar_ty_unchecked(new_ty), + Op::Sub(s) => s.update_scalar_ty_unchecked(new_ty), + Op::Mul(m) => m.update_scalar_ty_unchecked(new_ty), + Op::Exp(e) => e.update_scalar_ty_unchecked(new_ty), + Op::If(i) => i.update_scalar_ty_unchecked(new_ty), + Op::For(f) => f.update_scalar_ty_unchecked(new_ty), + Op::Call(c) => c.update_scalar_ty_unchecked(new_ty), + Op::Fold(f) => f.update_scalar_ty_unchecked(new_ty), + Op::Vector(v) => v.update_scalar_ty_unchecked(new_ty), + Op::Matrix(m) => m.update_scalar_ty_unchecked(new_ty), + Op::Accessor(a) => a.update_scalar_ty_unchecked(new_ty), + Op::BusOp(_) => {}, + Op::Parameter(p) => p.update_scalar_ty_unchecked(new_ty), + Op::Value(v) => v.update_scalar_ty_unchecked(new_ty), + Op::Cast(_) => {}, + Op::None(n) => n.update_scalar_ty_unchecked(new_ty), + } + } +} + +impl TypeMut for Op { + fn update_ty_unchecked(&mut self, new_ty: Option) { + match self { + Op::Enf(e) => e.update_ty_unchecked(new_ty), + Op::Boundary(b) => b.update_ty_unchecked(new_ty), + Op::Add(a) => a.update_ty_unchecked(new_ty), + Op::Sub(s) => s.update_ty_unchecked(new_ty), + Op::Mul(m) => m.update_ty_unchecked(new_ty), + Op::Exp(e) => e.update_ty_unchecked(new_ty), + Op::If(i) => i.update_ty_unchecked(new_ty), + Op::For(f) => f.update_ty_unchecked(new_ty), + Op::Call(c) => c.update_ty_unchecked(new_ty), + Op::Fold(f) => f.update_ty_unchecked(new_ty), + Op::Vector(v) => v.update_ty_unchecked(new_ty), + Op::Matrix(m) => m.update_ty_unchecked(new_ty), + Op::Accessor(a) => a.update_ty_unchecked(new_ty), + Op::BusOp(_) => {}, + Op::Parameter(p) => p.update_ty_unchecked(new_ty), + Op::Value(v) => v.update_ty_unchecked(new_ty), + Op::Cast(_) => {}, + Op::None(n) => n.update_ty_unchecked(new_ty), + } + } +} + +impl Typing for Op { + fn ty(&self) -> Option { + match self { + Op::Enf(e) => e.ty(), + Op::Boundary(b) => b.ty(), + Op::Add(a) => a.ty(), + Op::Sub(s) => s.ty(), + Op::Mul(m) => m.ty(), + Op::Exp(e) => e.ty(), + Op::If(i) => i.ty(), + Op::For(f) => f.ty(), + Op::Call(c) => c.ty(), + Op::Fold(f) => f.ty(), + Op::Vector(v) => v.ty(), + Op::Matrix(m) => m.ty(), + Op::Accessor(a) => a.ty(), + Op::BusOp(_) => ty!(?), + Op::Parameter(p) => p.ty(), + Op::Value(v) => v.ty(), + Op::Cast(c) => c.ty(), + Op::None(n) => n.ty(), + } + } +} + impl Link { /// Debug the current [Op], showing [std::cell::RefCell]'s `@{pointer}` and inner struct. /// This is useful to debug shared mutability issues. @@ -155,6 +261,7 @@ impl Link { Op::BusOp(b) => format!("Op::BusOp@{}({:#?})", self.get_ptr(), b), Op::Parameter(p) => format!("Op::Parameter@{}({:#?})", self.get_ptr(), p), Op::Value(v) => format!("Op::Value@{}({:#?})", self.get_ptr(), v), + Op::Cast(c) => format!("Op::Cast@{}({:#?})", self.get_ptr(), c), Op::None(_) => "Op::None".to_string(), } } @@ -230,6 +337,9 @@ impl Link { Op::Value(value) => { value._node = Singleton::from(node.clone()); }, + Op::Cast(cast) => { + cast._node = Singleton::from(node.clone()); + }, Op::None(_) => {}, } } @@ -280,6 +390,9 @@ impl Link { }, Op::Parameter(_parameter) => {}, Op::Value(_value) => {}, + Op::Cast(cast) => { + cast._owner = Singleton::from(owner.clone()); + }, Op::None(_) => {}, } } @@ -385,7 +498,13 @@ impl Link { value._node = Singleton::from(node.clone()); node }, - Op::None(span) => Node::None(*span).into(), + Op::Cast(Cast { _node: Singleton(Some(link)), .. }) => link.clone(), + Op::Cast(cast) => { + let node: Link = Node::Cast(back).into(); + cast._node = Singleton::from(node.clone()); + node + }, + Op::None(none) => Node::None(none.clone()).into(), } } /// Try getting the current [Op]'s [Owner] variant, @@ -479,6 +598,12 @@ impl Link { }, Op::Parameter(_) => None, Op::Value(_) => None, + Op::Cast(Cast { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), + Op::Cast(cast) => { + let owner: Link = Owner::Cast(back).into(); + cast._owner = Singleton::from(owner.clone()); + cast._owner.0.clone() + }, Op::None(_) => None, } } @@ -738,13 +863,29 @@ impl Link { _ => None, }) } + /// Try getting the current [Op]'s inner [Cast]. + /// Returns None if the current [Op] is not a [Cast] or the Rc count is zero. + pub fn as_cast(&self) -> Option> { + get_inner(self.borrow(), |op| match op { + Op::Cast(inner) => Some(inner), + _ => None, + }) + } + /// Try getting the current [Op]'s inner [Cast], borrowing mutably. + /// Returns None if the current [Op] is not a [Cast] or the Rc count is zero. + pub fn as_cast_mut(&self) -> Option> { + get_inner_mut(self.borrow_mut(), |op| match op { + Op::Cast(inner) => Some(inner), + _ => None, + }) + } } impl From for Link { fn from(value: i64) -> Self { Op::Value(Value { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(value as u64)), + value: MirValue::Constant(ConstantValue::Scalar(value as u64)), ..Default::default() }, ..Default::default() diff --git a/mir/src/ir/nodes/ops/accessor.rs b/mir/src/ir/nodes/ops/accessor.rs index aea462cf4..6f51c921b 100644 --- a/mir/src/ir/nodes/ops/accessor.rs +++ b/mir/src/ir/nodes/ops/accessor.rs @@ -1,9 +1,10 @@ use std::hash::Hash; -use air_parser::ast::AccessType; +use air_parser::ast::{Access, AccessType}; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent accessing a given op, `indexable`, in two different ways: /// - access_type: AccessType, which describes for example how to access a given index for a Vector @@ -20,6 +21,35 @@ pub struct Accessor { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Accessor { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self._ty.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Accessor { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._ty = new_ty; + } +} + +impl Typing for Accessor { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Accessor { + fn finalize_hook(&mut self) { + self._ty = self + .indexable + .borrow() + .ty() + .map(|ty| ty.access(self.access_type.clone()).unwrap()); + } } impl Accessor { @@ -29,14 +59,15 @@ impl Accessor { offset: usize, span: SourceSpan, ) -> Link { - Op::Accessor(Self { - access_type, + let mut accessor = Self { indexable, + access_type, offset, span, ..Default::default() - }) - .into() + }; + accessor.finalize_hook(); + Op::Accessor(accessor).into() } } diff --git a/mir/src/ir/nodes/ops/add.rs b/mir/src/ir/nodes/ops/add.rs index e5bfe1ebf..ecca45aa1 100644 --- a/mir/src/ir/nodes/ops/add.rs +++ b/mir/src/ir/nodes/ops/add.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the addition of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,46 @@ pub struct Add { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Add { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self._bin_ty.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Add { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._bin_ty.update_ty_unchecked(new_ty); + } +} + +impl Typing for Add { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Add { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Add(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_add().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Add { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Add(Self { lhs, rhs, span, ..Default::default() }).into() + let mut add = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + add.finalize_hook(); + Op::Add(add).into() } } diff --git a/mir/src/ir/nodes/ops/boundary.rs b/mir/src/ir/nodes/ops/boundary.rs index 8785fa018..d3582ab7b 100644 --- a/mir/src/ir/nodes/ops/boundary.rs +++ b/mir/src/ir/nodes/ops/boundary.rs @@ -1,9 +1,10 @@ use std::hash::Hash; use air_parser::ast::Boundary as BoundaryKind; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent bounding a given op, `expr`, to access either the first or last row /// @@ -18,6 +19,31 @@ pub struct Boundary { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Boundary { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self._ty.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Boundary { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._ty = new_ty; + } +} + +impl Typing for Boundary { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Boundary { + fn finalize_hook(&mut self) { + self._ty = self.expr.borrow().ty(); + } } impl Hash for Boundary { @@ -32,7 +58,9 @@ impl Hash for Boundary { impl Boundary { pub fn create(expr: Link, kind: BoundaryKind, span: SourceSpan) -> Link { - Op::Boundary(Self { expr, kind, span, ..Default::default() }).into() + let mut boundary = Self { expr, kind, span, ..Default::default() }; + boundary.finalize_hook(); + Op::Boundary(boundary).into() } } diff --git a/mir/src/ir/nodes/ops/bus_op.rs b/mir/src/ir/nodes/ops/bus_op.rs index a868275ca..fd986d37a 100644 --- a/mir/src/ir/nodes/ops/bus_op.rs +++ b/mir/src/ir/nodes/ops/bus_op.rs @@ -2,7 +2,9 @@ use std::hash::Hash; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{ + BackLink, Builder, BuilderHook, Bus, Child, Link, Node, Op, Owner, Parent, Singleton, +}; #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] pub enum BusOpKind { @@ -25,6 +27,8 @@ pub struct BusOp { pub span: SourceSpan, } +impl BuilderHook for BusOp {} + impl Hash for BusOp { fn hash(&self, state: &mut H) { self.bus.get_name().hash(state); diff --git a/mir/src/ir/nodes/ops/call.rs b/mir/src/ir/nodes/ops/call.rs index efa169328..6c2b26b44 100644 --- a/mir/src/ir/nodes/ops/call.rs +++ b/mir/src/ir/nodes/ops/call.rs @@ -1,6 +1,9 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{ + BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Root, Singleton, +}; /// A MIR operation to represent a call to a given function, a `Root` that represents either a /// `Function` or an `Evaluator` @@ -21,17 +24,43 @@ pub struct Call { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Call { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._ty.update_scalar_ty(new_sty); + } +} + +impl TypeMut for Call { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._ty = new_ty; + } +} + +impl Typing for Call { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Call { + fn finalize_hook(&mut self) { + self._ty = self.function.borrow().ty(); + } } impl Call { pub fn create(function: Link, arguments: Vec>, span: SourceSpan) -> Link { - Op::Call(Self { + let mut call = Self { function, arguments: Link::new(arguments), span, ..Default::default() - }) - .into() + }; + call.finalize_hook(); + Op::Call(call).into() } } diff --git a/mir/src/ir/nodes/ops/cast.rs b/mir/src/ir/nodes/ops/cast.rs new file mode 100644 index 000000000..8663638d8 --- /dev/null +++ b/mir/src/ir/nodes/ops/cast.rs @@ -0,0 +1,51 @@ +use air_types::*; +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; + +#[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] +#[enum_wrapper(Op)] +pub struct Cast { + pub parents: Vec>, + /// The value being cast + pub value: Link, + ty: Option, + pub _node: Singleton, + pub _owner: Singleton, + #[span] + pub span: SourceSpan, +} + +impl BuilderHook for Cast {} + +impl Cast { + pub fn create(value: Link, ty: Option, span: SourceSpan) -> Link { + let cast = Self { value, ty, span, ..Default::default() }; + Link::new(Op::Cast(cast)) + } +} +impl Typing for Cast { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl Parent for Cast { + type Child = Op; + fn children(&self) -> Link>> { + Link::new(vec![self.value.clone()]) + } +} + +impl Child for Cast { + type Parent = Owner; + fn get_parents(&self) -> Vec> { + self.parents.clone() + } + fn add_parent(&mut self, parent: Link) { + self.parents.push(parent.into()); + } + fn remove_parent(&mut self, parent: Link) { + self.parents.retain(|p| *p != parent.clone().into()); + } +} diff --git a/mir/src/ir/nodes/ops/enf.rs b/mir/src/ir/nodes/ops/enf.rs index 0206a4023..9bac06685 100644 --- a/mir/src/ir/nodes/ops/enf.rs +++ b/mir/src/ir/nodes/ops/enf.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to enforce that a given MIR op, `expr` equals zero #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -12,11 +13,38 @@ pub struct Enf { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Enf { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Enf { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._ty = new_ty; + } +} + +impl Typing for Enf { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Enf { + fn finalize_hook(&mut self) { + self._ty = self.expr.borrow().ty(); + } } impl Enf { pub fn create(expr: Link, span: SourceSpan) -> Link { - Op::Enf(Self { expr, span, ..Default::default() }).into() + let mut enf = Self { expr, span, ..Default::default() }; + enf.finalize_hook(); + Op::Enf(enf).into() } } diff --git a/mir/src/ir/nodes/ops/exp.rs b/mir/src/ir/nodes/ops/exp.rs index c888d1cf6..12615c068 100644 --- a/mir/src/ir/nodes/ops/exp.rs +++ b/mir/src/ir/nodes/ops/exp.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the exponentiation of a MIR op, `lhs` by another, `rhs` /// @@ -15,11 +16,46 @@ pub struct Exp { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Exp { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._bin_ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Exp { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._bin_ty.update_ty_unchecked(new_ty); + } +} + +impl Typing for Exp { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Exp { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Exp(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_exp().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Exp { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Exp(Self { lhs, rhs, span, ..Default::default() }).into() + let mut exp = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + exp.finalize_hook(); + Op::Exp(exp).into() } } diff --git a/mir/src/ir/nodes/ops/fold.rs b/mir/src/ir/nodes/ops/fold.rs index 0ca0a7c9d..343370c64 100644 --- a/mir/src/ir/nodes/ops/fold.rs +++ b/mir/src/ir/nodes/ops/fold.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent folding a given Vector operator according to a given operator and /// initial value @@ -21,8 +22,29 @@ pub struct Fold { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for Fold { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Fold { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for Fold { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for Fold {} + #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] pub enum FoldOperator { Add, diff --git a/mir/src/ir/nodes/ops/for_op.rs b/mir/src/ir/nodes/ops/for_op.rs index 9ab9a200d..25d39c798 100644 --- a/mir/src/ir/nodes/ops/for_op.rs +++ b/mir/src/ir/nodes/ops/for_op.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent list comprehensions. /// @@ -20,8 +21,29 @@ pub struct For { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for For { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for For { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for For { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for For {} + impl For { pub fn create( iterators: Link>>, diff --git a/mir/src/ir/nodes/ops/if_op.rs b/mir/src/ir/nodes/ops/if_op.rs index 023f65634..b32dfdc1a 100644 --- a/mir/src/ir/nodes/ops/if_op.rs +++ b/mir/src/ir/nodes/ops/if_op.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent conditional constraints /// @@ -18,8 +19,29 @@ pub struct If { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for If { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for If { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for If { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for If {} + #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] pub struct MatchArm { pub condition: Link, diff --git a/mir/src/ir/nodes/ops/matrix.rs b/mir/src/ir/nodes/ops/matrix.rs index 277783ca4..f6e69edac 100644 --- a/mir/src/ir/nodes/ops/matrix.rs +++ b/mir/src/ir/nodes/ops/matrix.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent a matrix of MIR ops of a given size #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -14,18 +15,44 @@ pub struct Matrix { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _kind: Option, +} + +impl ScalarTypeMut for Matrix { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._kind.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Matrix { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._kind.update_ty_unchecked(new_ty); + } +} + +impl Typing for Matrix { + fn ty(&self) -> Option { + self._kind.ty() + } +} + +impl BuilderHook for Matrix { + fn finalize_hook(&mut self) { + self._kind = self.elements.borrow().kind(); + } } impl Matrix { pub fn create(elements: Vec>, span: SourceSpan) -> Link { let size = elements.len(); - Op::Matrix(Self { + let mut mat = Self { size, elements: Link::new(elements), span, ..Default::default() - }) - .into() + }; + mat.finalize_hook(); + Op::Matrix(mat).into() } } diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 64d1674a3..496ccb4ea 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -3,6 +3,7 @@ mod add; mod boundary; mod bus_op; mod call; +mod cast; mod enf; mod exp; mod fold; @@ -20,6 +21,7 @@ pub use add::Add; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; pub use call::Call; +pub use cast::Cast; pub use enf::Enf; pub use exp::Exp; pub use fold::{Fold, FoldOperator}; @@ -30,7 +32,7 @@ pub use mul::Mul; pub use parameter::Parameter; pub use sub::Sub; pub use value::{ - BusAccess, ConstantValue, MirType, MirValue, PeriodicColumnAccess, PublicInputAccess, + BusAccess, ConstantValue, MirValue, PeriodicColumnAccess, PublicInputAccess, PublicInputTableAccess, SpannedMirValue, TraceAccess, TraceAccessBinding, Value, }; pub use vector::Vector; diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index 247123755..c239174d5 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the multiplication of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,48 @@ pub struct Mul { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Mul { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._bin_ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Mul { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._bin_ty.update_ty_unchecked(new_ty); + } +} + +impl Typing for Mul { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Mul { + fn finalize_hook(&mut self) { + let lty = self.lhs.borrow().infer_ty(); + let rty = self.rhs.borrow().infer_ty(); + self._bin_ty = BinType::Mul(lty.unwrap(), rty.unwrap(), None); + let res = self._bin_ty.infer_bin_ty_mul().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Mul { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Mul(Self { lhs, rhs, span, ..Default::default() }).into() + let mut mul = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + mul.finalize_hook(); + Op::Mul(mul).into() } } diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 939215d27..bac9204dc 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -1,9 +1,9 @@ use std::hash::{Hash, Hasher}; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use super::MirType; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a `Parameter` in a function or evaluator. /// Also used in If and For loops to represent declared parameters. @@ -15,20 +15,40 @@ pub struct Parameter { pub ref_node: BackLink, /// The position of the `Parameter` in the referred node's `Parameter` list pub position: usize, - /// The type of the `Parameter` - pub ty: MirType, pub _node: Singleton, #[span] pub span: SourceSpan, + /// The type of the `Parameter` + pub ty: Option, +} + +impl ScalarTypeMut for Parameter { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Parameter { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for Parameter { + fn ty(&self) -> Option { + self.ty.ty() + } } +impl BuilderHook for Parameter {} + impl Parameter { - pub fn create(position: usize, ty: MirType, span: SourceSpan) -> Link { + pub fn create(position: usize, ty: Type, span: SourceSpan) -> Link { Op::Parameter(Self { parents: Vec::default(), ref_node: BackLink::none(), position, - ty, + ty: Some(ty), _node: Singleton::none(), span, }) diff --git a/mir/src/ir/nodes/ops/sub.rs b/mir/src/ir/nodes/ops/sub.rs index 94e2c00ad..3782666fe 100644 --- a/mir/src/ir/nodes/ops/sub.rs +++ b/mir/src/ir/nodes/ops/sub.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the subtraction of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,46 @@ pub struct Sub { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Sub { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._bin_ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Sub { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._bin_ty.update_ty_unchecked(new_ty); + } +} + +impl Typing for Sub { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Sub { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Sub(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_sub().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Sub { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Sub(Self { lhs, rhs, span, ..Default::default() }).into() + let mut sub = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + sub.finalize_hook(); + Op::Sub(sub).into() } } diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index c102f5777..18aa45247 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -1,9 +1,8 @@ -use air_parser::ast::{ - self, BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId, -}; +use air_parser::ast::{BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId}; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Bus, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a known value, [Value]. /// @@ -15,8 +14,29 @@ pub struct Value { #[span] pub value: SpannedMirValue, pub _node: Singleton, + pub ty: Option, } +impl ScalarTypeMut for Value { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Value { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for Value { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for Value {} + impl Value { pub fn create(value: SpannedMirValue) -> Link { Op::Value(Self { value, ..Default::default() }).into() @@ -27,7 +47,7 @@ impl From for Value { fn from(value: i64) -> Self { Self { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(value as u64)), + value: MirValue::Constant(ConstantValue::Scalar(value as u64)), span: Default::default(), }, ..Default::default() @@ -104,11 +124,28 @@ impl BusAccess { #[derive(Debug, Eq, PartialEq, Clone, Hash)] pub enum ConstantValue { - Felt(u64), + Scalar(u64), Vector(Vec), Matrix(Vec>), } +impl Typing for ConstantValue { + fn ty(&self) -> Option { + match self { + ConstantValue::Scalar(_) => ty!(uint), + ConstantValue::Vector(v) => ty!(uint[v.len()]), + ConstantValue::Matrix(m) => { + let row_count = m.len(); + if row_count == 0 { + return ty!(uint[usize::MAX, usize::MAX]); + } + let col_count = m.iter().map(|r| r.len()).max().unwrap_or(usize::MAX); + ty!(uint[row_count, col_count]) + }, + } + } +} + /// [TraceAccess] is like SymbolAccess, but is used to describe an access to a specific trace /// column or columns. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -124,11 +161,30 @@ pub struct TraceAccess { /// For example, if accessing a trace column with `a'`, where `a` is bound to a single column, /// the row offset would be `1`, as the `'` modifier indicates the "next" row. pub row_offset: usize, + /// The type of the value being accessed, if known. + /// Defaults to None until the access is resolved. + /// This should only be a felt or [felt; n] type. + ty: Option, } impl TraceAccess { /// Creates a new [TraceAccess]. pub const fn new(segment: TraceSegmentId, column: TraceColumnIndex, row_offset: usize) -> Self { - Self { segment, column, row_offset } + Self { segment, column, row_offset, ty: None } + } +} +impl ScalarTypeMut for TraceAccess { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} +impl TypeMut for TraceAccess { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} +impl Typing for TraceAccess { + fn ty(&self) -> Option { + self.ty.ty() } } @@ -140,6 +196,15 @@ pub struct TraceAccessBinding { /// The number of columns which are bound pub size: usize, } +impl Typing for TraceAccessBinding { + fn ty(&self) -> Option { + if self.size == 1 { + ty!(felt) + } else { + ty!(felt[self.size]) + } + } +} /// Represents a typed value in the MIR. #[derive(Debug, Eq, PartialEq, Clone, Hash, Spanned)] @@ -149,24 +214,6 @@ pub struct SpannedMirValue { pub value: MirValue, } -#[derive(Debug, Default, Eq, PartialEq, Clone, Hash)] -pub enum MirType { - #[default] - Felt, - Vector(usize), - Matrix(usize, usize), -} - -impl From for MirType { - fn from(value: ast::Type) -> Self { - match value { - ast::Type::Felt => MirType::Felt, - ast::Type::Vector(n) => MirType::Vector(n), - ast::Type::Matrix(cols, rows) => MirType::Matrix(cols, rows), - } - } -} - /// Represents an access of a PeriodicColumn, similar in nature to [TraceAccess]. #[derive(Debug, Copy, Clone, PartialEq, PartialOrd, Ord, Eq, Hash)] pub struct PeriodicColumnAccess { @@ -186,10 +233,28 @@ pub struct PublicInputAccess { pub name: Identifier, /// The index of the element in the public input to access pub index: usize, + /// The type of the value being accessed, if known. + /// Defaults to None until the access is resolved. + ty: Option, } impl PublicInputAccess { pub const fn new(name: Identifier, index: usize) -> Self { - Self { name, index } + Self { name, index, ty: None } + } +} +impl ScalarTypeMut for PublicInputAccess { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} +impl TypeMut for PublicInputAccess { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} +impl Typing for PublicInputAccess { + fn ty(&self) -> Option { + self.ty.ty() } } @@ -223,8 +288,14 @@ impl PublicInputTableAccess { impl Default for SpannedMirValue { fn default() -> Self { Self { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: Default::default(), } } } + +impl Typing for PublicInputTableAccess { + fn ty(&self) -> Option { + ty!(felt[usize::MAX, self.num_cols]) + } +} diff --git a/mir/src/ir/nodes/ops/vector.rs b/mir/src/ir/nodes/ops/vector.rs index fc53c3d67..8e4f37e0c 100644 --- a/mir/src/ir/nodes/ops/vector.rs +++ b/mir/src/ir/nodes/ops/vector.rs @@ -1,6 +1,7 @@ +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent a vector of MIR ops of a given size #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,18 +14,44 @@ pub struct Vector { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _kind: Option, +} + +impl ScalarTypeMut for Vector { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self._kind.update_scalar_ty_unchecked(new_sty); + } +} + +impl TypeMut for Vector { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self._kind.update_ty_unchecked(new_ty); + } +} + +impl Typing for Vector { + fn ty(&self) -> Option { + self._kind.ty() + } +} + +impl BuilderHook for Vector { + fn finalize_hook(&mut self) { + self._kind = self.elements.borrow().kind(); + } } impl Vector { pub fn create(elements: Vec>, span: SourceSpan) -> Link { let size = elements.len(); - Op::Vector(Self { + let mut vec = Self { size, elements: Link::new(elements), span, ..Default::default() - }) - .into() + }; + vec.finalize_hook(); + Op::Vector(vec).into() } } diff --git a/mir/src/ir/nodes/root.rs b/mir/src/ir/nodes/root.rs index 0c8a4c9d5..53dc61b85 100644 --- a/mir/src/ir/nodes/root.rs +++ b/mir/src/ir/nodes/root.rs @@ -3,11 +3,12 @@ use std::{ ops::{Deref, DerefMut}, }; -use miden_diagnostics::{SourceSpan, Spanned}; +use air_types::Typing; +use miden_diagnostics::Spanned; use crate::ir::{ - BackLink, Evaluator, Function, Link, Node, Op, Owner, Parent, Singleton, get_inner, - get_inner_mut, + BackLink, BuilderHook, Evaluator, Function, Link, Node, Stale, Op, Owner, Parent, Singleton, + get_inner, get_inner_mut, }; /// The root nodes of the MIR Graph @@ -17,12 +18,22 @@ use crate::ir::{ pub enum Root { Function(Function), Evaluator(Evaluator), - None(SourceSpan), + None(Stale), +} + +impl BuilderHook for Root { + fn finalize_hook(&mut self) { + match self { + Root::Function(f) => f.finalize_hook(), + Root::Evaluator(e) => e.finalize_hook(), + Root::None(_) => {}, + } + } } impl Default for Root { fn default() -> Self { - Root::None(SourceSpan::default()) + Root::None(Stale::default()) } } @@ -37,6 +48,16 @@ impl Parent for Root { } } +impl Typing for Root { + fn ty(&self) -> Option { + match self { + Root::Function(f) => f.ty(), + Root::Evaluator(e) => e.ty(), + Root::None(n) => n.ty(), + } + } +} + impl Link { pub fn debug(&self) -> String { match self.borrow().deref() { @@ -70,7 +91,7 @@ impl Link { e._node = Singleton::from(node.clone()); node }, - Root::None(span) => Node::None(*span).into(), + Root::None(none) => Node::None(none.clone()).into(), } } /// Get the current [Root]'s [Owner] variant @@ -90,7 +111,7 @@ impl Link { e._owner = Singleton::from(owner.clone()); owner }, - Root::None(span) => Owner::None(*span).into(), + Root::None(none) => Owner::None(none.clone()).into(), } } /// Try getting the current [Root]'s inner [Function]. diff --git a/mir/src/ir/nodes/roots/evaluator.rs b/mir/src/ir/nodes/roots/evaluator.rs index 35e0fd6e7..67d1363e7 100644 --- a/mir/src/ir/nodes/roots/evaluator.rs +++ b/mir/src/ir/nodes/roots/evaluator.rs @@ -1,6 +1,7 @@ +use air_types::{FunctionType, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{Builder, BuilderHook, Link, Node, Op, Owner, Parent, Root, Singleton}; /// A MIR Root to represent a Evaluator definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -15,18 +16,29 @@ pub struct Evaluator { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub func_ty: FunctionType, } +impl Typing for Evaluator { + fn ty(&self) -> Option { + self.func_ty.result() + } +} + +impl BuilderHook for Evaluator {} + impl Evaluator { pub fn create( parameters: Vec>>, body: Vec>, span: SourceSpan, + func_ty: FunctionType, ) -> Link { Root::Evaluator(Self { parameters, body: Link::new(body), span, + func_ty, ..Default::default() }) .into() diff --git a/mir/src/ir/nodes/roots/function.rs b/mir/src/ir/nodes/roots/function.rs index 12051f466..251075c23 100644 --- a/mir/src/ir/nodes/roots/function.rs +++ b/mir/src/ir/nodes/roots/function.rs @@ -1,6 +1,7 @@ +use air_types::{FunctionType, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{Builder, BuilderHook, Link, Node, Op, Owner, Parent, Root, Singleton}; /// A MIR Root to represent a Function definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -16,20 +17,31 @@ pub struct Function { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub func_ty: FunctionType, } +impl Typing for Function { + fn ty(&self) -> Option { + self.func_ty.result() + } +} + +impl BuilderHook for Function {} + impl Function { pub fn create( parameters: Vec>, return_type: Link, body: Vec>, span: SourceSpan, + func_ty: FunctionType, ) -> Link { Root::Function(Self { parameters, return_type, body: Link::new(body), span, + func_ty, ..Default::default() }) .into() diff --git a/mir/src/ir/nodes/stale.rs b/mir/src/ir/nodes/stale.rs new file mode 100644 index 000000000..d096421b7 --- /dev/null +++ b/mir/src/ir/nodes/stale.rs @@ -0,0 +1,33 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; +use miden_diagnostics::{SourceSpan, Spanned}; + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Spanned)] +pub struct Stale { + #[span] + pub span: SourceSpan, + pub ty: Option, +} + +impl ScalarTypeMut for Stale { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.ty.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Stale { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} + +impl Typing for Stale { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl Default for Stale { + fn default() -> Self { + Stale { span: Default::default(), ty: None } + } +} diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index d1a2c513b..44d2c8a9e 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -1,8 +1,8 @@ use std::ops::Deref; -use miden_diagnostics::{SourceSpan, Spanned}; +use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root, Stale}; /// The nodes that can own [Op] nodes /// The [Owner] enum does not own it's inner struct to avoid reference cycles, @@ -30,7 +30,8 @@ pub enum Owner { Enf(BackLink), For(BackLink), If(BackLink), - None(SourceSpan), + Cast(BackLink), + None(Stale), } impl Parent for Owner { @@ -53,6 +54,7 @@ impl Parent for Owner { Owner::Matrix(m) => m.children(), Owner::Accessor(a) => a.children(), Owner::BusOp(b) => b.children(), + Owner::Cast(c) => c.children(), Owner::None(_) => Link::default(), } } @@ -78,6 +80,7 @@ impl Child for Owner { Owner::Matrix(m) => m.get_parents(), Owner::Accessor(a) => a.get_parents(), Owner::BusOp(b) => b.get_parents(), + Owner::Cast(c) => c.get_parents(), Owner::None(_) => Vec::default(), } } @@ -99,6 +102,7 @@ impl Child for Owner { Owner::Matrix(m) => m.add_parent(parent), Owner::Accessor(a) => a.add_parent(parent), Owner::BusOp(b) => b.add_parent(parent), + Owner::Cast(c) => c.add_parent(parent), Owner::None(_) => (), } } @@ -120,6 +124,7 @@ impl Child for Owner { Owner::Matrix(m) => m.remove_parent(parent), Owner::Accessor(a) => a.remove_parent(parent), Owner::BusOp(b) => b.remove_parent(parent), + Owner::Cast(c) => c.remove_parent(parent), Owner::None(_) => (), } } @@ -145,6 +150,7 @@ impl PartialEq for Owner { (Owner::Matrix(lhs), Owner::Matrix(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::Accessor(lhs), Owner::Accessor(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::BusOp(lhs), Owner::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), + (Owner::Cast(lhs), Owner::Cast(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::None(_), Owner::None(_)) => true, _ => false, } @@ -171,6 +177,7 @@ impl std::hash::Hash for Owner { Owner::Matrix(m) => m.to_link().hash(state), Owner::Accessor(a) => a.to_link().hash(state), Owner::BusOp(b) => b.to_link().hash(state), + Owner::Cast(c) => c.to_link().hash(state), Owner::None(s) => s.hash(state), } } @@ -199,17 +206,18 @@ impl Link { Op::BusOp(_) => Owner::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => unreachable!(), Op::Value(_) => unreachable!(), - Op::None(span) => Owner::None(*span), + Op::Cast(_) => Owner::Cast(BackLink::from(op_inner_val)), + Op::None(none) => Owner::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { to_update = match root_inner_val.clone().borrow().deref() { Root::Function(_) => Owner::Function(BackLink::from(root_inner_val)), Root::Evaluator(_) => Owner::Evaluator(BackLink::from(root_inner_val)), - Root::None(span) => Owner::None(*span), + Root::None(none) => Owner::None(none.clone()), }; } else { // If the [Owner] is stale, we set it to None - to_update = Owner::None(self.span()); + to_update = Owner::None(Stale { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; @@ -243,6 +251,7 @@ impl Link { Owner::Enf(_) => None, Owner::For(_) => None, Owner::If(_) => None, + Owner::Cast(_) => None, Owner::None(_) => None, } } @@ -266,6 +275,7 @@ impl Link { Owner::Enf(back) => back.to_link(), Owner::For(back) => back.to_link(), Owner::If(back) => back.to_link(), + Owner::Cast(back) => back.to_link(), Owner::None(_) => None, } } @@ -301,6 +311,7 @@ impl BackLink { Owner::Enf(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::For(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::If(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), + Owner::Cast(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::None(_) => 0, }) .unwrap_or(0) diff --git a/mir/src/ir/quad_eval.rs b/mir/src/ir/quad_eval.rs index d5b0774e6..9b37e8cad 100644 --- a/mir/src/ir/quad_eval.rs +++ b/mir/src/ir/quad_eval.rs @@ -113,7 +113,7 @@ impl RandomInputs { }, Op::Value(v) => { match &v.value.value { - MirValue::Constant(ConstantValue::Felt(c)) => { + MirValue::Constant(ConstantValue::Scalar(c)) => { let felt = Felt::new(*c); Ok(const_quad_felt(felt)) }, @@ -220,6 +220,7 @@ impl RandomInputs { ); Err(CompileError::Failed) }, + Op::Cast(cast) => self.eval(cast.value.clone()), } } } diff --git a/mir/src/passes/constant_propagation.rs b/mir/src/passes/constant_propagation.rs index 139eeab65..f44ce5fec 100644 --- a/mir/src/passes/constant_propagation.rs +++ b/mir/src/passes/constant_propagation.rs @@ -88,7 +88,7 @@ impl ConstantPropagation<'_> { match (get_inner_const(&lhs), get_inner_const(&rhs)) { (Some(0), _) | (_, Some(0)) => Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: mul_ref.span, }))), (Some(1), _) => Ok(Some(rhs)), @@ -109,12 +109,12 @@ impl ConstantPropagation<'_> { if let Some(0) = get_inner_const(&lhs) { Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: exp_ref.span, }))) } else if let Some(0) = get_inner_const(&rhs) { Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(1)), + value: MirValue::Constant(ConstantValue::Scalar(1)), span: exp_ref.span, }))) } else { @@ -169,6 +169,7 @@ impl Visitor for ConstantPropagation<'_> { | Node::BusOp(_) | Node::Value(_) | Node::Accessor(_) + | Node::Cast(_) | Node::None(_) => Ok(None), Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { unreachable!( @@ -219,7 +220,7 @@ fn get_inner_const(value: &Link) -> Option { Op::Value(Value { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(c)), + value: MirValue::Constant(ConstantValue::Scalar(c)), .. }, .. @@ -275,7 +276,7 @@ fn try_fold_const_binary_op( }; if let Some(folded) = folded { let new_value = Value::create(SpannedMirValue { - value: MirValue::Constant(crate::ir::ConstantValue::Felt(folded)), + value: MirValue::Constant(crate::ir::ConstantValue::Scalar(folded)), span, }); updated_binary_op = Some(new_value); diff --git a/mir/src/passes/inlining.rs b/mir/src/passes/inlining.rs index 817821477..303d0cfc9 100644 --- a/mir/src/passes/inlining.rs +++ b/mir/src/passes/inlining.rs @@ -7,8 +7,8 @@ use super::{duplicate_node_or_replace, visitor::Visitor}; use crate::{ CompileError, ir::{ - Accessor, Graph, Link, Mir, MirType, MirValue, Node, Op, Parameter, Parent, Root, - SpannedMirValue, TraceAccessBinding, Value, Vector, + Accessor, Graph, Link, Mir, MirValue, Node, Op, Parameter, Parent, Root, SpannedMirValue, + TraceAccessBinding, Type, Value, Vector, }, }; @@ -518,8 +518,8 @@ fn check_evaluator_argument_sizes( } else if let Some(parameter) = child.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; trace_segments_arg_vector_len += size; @@ -538,8 +538,8 @@ fn check_evaluator_argument_sizes( } else if let Some(parameter) = indexable.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; trace_segments_arg_vector_len += size; @@ -640,8 +640,8 @@ fn unpack_evaluator_arguments(args: &[Link]) -> Vec> { } else if let Some(parameter) = indexable.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let _size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index 74c72ffd0..dcb7406f0 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -5,6 +5,7 @@ mod unrolling; mod visitor; use std::{collections::HashMap, ops::Deref}; +use air_types::Typing; pub use constant_propagation::ConstantPropagation; pub use inlining::Inlining; use miden_diagnostics::Spanned; @@ -13,8 +14,8 @@ pub use unrolling::Unrolling; pub use visitor::Visitor; use crate::ir::{ - Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, Mul, - Node, Op, Owner, Parameter, Parent, Sub, Value, Vector, + Accessor, Add, Boundary, BusOp, Call, Cast, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, + Mul, Node, Op, Owner, Parameter, Parent, Sub, Value, Vector, }; /// Helper to duplicate a MIR node and its children recursively @@ -183,7 +184,7 @@ pub fn duplicate_node( .to_link() .unwrap_or_else(|| panic!("invalid ref_node for parameter {parameter:?}",)); let new_param = - Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); + Parameter::create(parameter.position, parameter.ty.unwrap(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref); @@ -200,7 +201,13 @@ pub fn duplicate_node( new_param }, Op::Value(value) => Value::create(value.value.clone()), - Op::None(span) => Op::None(*span).into(), + Op::Cast(cast) => { + let value = cast.value.clone(); + let ty = cast.ty(); + let new_expr = duplicate_node(value, current_replace_map); + Cast::create(new_expr, ty, cast.span()) + }, + Op::None(none) => Op::None(none.clone()).into(), } } @@ -428,7 +435,7 @@ pub fn duplicate_node_or_replace( current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); } else { let new_param = - Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); + Parameter::create(parameter.position, parameter.ty.unwrap(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref.clone()); @@ -447,6 +454,13 @@ pub fn duplicate_node_or_replace( let new_node = Value::create(value.value.clone()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, + Op::Cast(cast) => { + let value = cast.value.clone(); + let ty = cast.ty(); + let new_expr = current_replace_map.get(&value.get_ptr()).unwrap().1.clone(); + let new_node = Cast::create(new_expr, ty, cast.span()); + current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); + }, Op::None(_) => {}, } } diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 7aa947ed4..02b3fc909 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -3,19 +3,21 @@ use std::ops::Deref; use air_parser::{ LexicalScope, - ast::{self, AccessType, TraceSegmentId}, + ast::{self, Access, AccessType, TraceSegmentId}, symbols, }; use air_pass::Pass; +use air_types::*; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use crate::{ CompileError, ir::{ - Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, - Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, - MirType, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, - Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, Cast, + ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, + Matrix, Mir, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, + PublicInputTableAccess, Root, SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, + Type, Value, Vector, }, passes::duplicate_node, }; @@ -146,6 +148,7 @@ impl<'a> MirBuilder<'a> { ast_eval: &'a ast::EvaluatorFunction, ) -> Result, CompileError> { let mut all_params_flatten = Vec::new(); + let mut all_params_ty_flatten = Vec::new(); self.root_name = Some(ident); let mut ev = Evaluator::builder().span(ast_eval.span); @@ -158,16 +161,17 @@ impl<'a> MirBuilder<'a> { let span = binding.name.map_or(SourceSpan::UNKNOWN, |n| n.span()); let params = self.translate_params_ev(span, binding.name.as_ref(), &binding.ty, &mut i)?; - + let param_ty = binding.ty.access(AccessType::Index(0)).ok().or(ty!(felt)); for param in params { all_params_flatten_for_trace_segment.push(param.clone()); all_params_flatten.push(param.clone()); + all_params_ty_flatten.push(param_ty); } } ev = ev.parameters(all_params_flatten_for_trace_segment.clone()); } - let ev = ev.build(); + let ev = ev.func_ty(FunctionType::Evaluator(all_params_ty_flatten)).build(); set_all_ref_nodes(all_params_flatten.clone(), ev.as_owner()); @@ -196,7 +200,7 @@ impl<'a> MirBuilder<'a> { for binding in trace_segment.bindings.iter() { let name = binding.name.as_ref(); match &binding.ty { - ast::Type::Vector(size) => { + Type::Vector(_, size) => { let mut params_vec = Vec::new(); let mut span = SourceSpan::UNKNOWN; for _ in 0..*size { @@ -210,7 +214,7 @@ impl<'a> MirBuilder<'a> { let vector_node = Vector::create(params_vec, span); self.bindings.insert(name.unwrap(), vector_node.clone()); }, - ast::Type::Felt => { + Type::Scalar(_) => { let param = all_params_flatten_for_trace_segment[i].clone(); i += 1; self.bindings.insert(name.unwrap(), param.clone()); @@ -232,6 +236,7 @@ impl<'a> MirBuilder<'a> { ast_func: &'a ast::Function, ) -> Result, CompileError> { let mut params = Vec::new(); + let mut params_ty = Vec::new(); self.root_name = Some(ident); let mut func = Function::builder().span(ast_func.span()); @@ -240,13 +245,17 @@ impl<'a> MirBuilder<'a> { let name = Some(param_ident); let param = self.translate_params_fn(param_ident.span(), name, ty, &mut i)?; params.push(param.clone()); + params_ty.push(Some(*ty)); func = func.parameters(param.clone()); } i += 1; - let ret = Parameter::create(i, self.translate_type(&ast_func.return_type), ast_func.span()); + let ret = Parameter::create(i, ast_func.return_type, ast_func.span()); params.push(ret.clone()); - let func = func.return_type(ret).build(); + let func = func + .return_type(ret) + .func_ty(FunctionType::Function(params_ty, Some(ast_func.return_type))) + .build(); set_all_ref_nodes(params.clone(), func.as_owner()); self.mir.constraint_graph_mut().insert_function(*ident, func.clone())?; @@ -279,25 +288,25 @@ impl<'a> MirBuilder<'a> { &mut self, span: SourceSpan, name: Option<&'a ast::Identifier>, - ty: &ast::Type, + ty: &Type, i: &mut usize, ) -> Result>, CompileError> { match ty { - ast::Type::Felt => { - let param = Parameter::create(*i, MirType::Felt, span); + Type::Scalar(_) => { + let param = Parameter::create(*i, ty!(felt).unwrap(), span); *i += 1; Ok(vec![param]) }, - ast::Type::Vector(size) => { + Type::Vector(_, size) => { let mut params = Vec::new(); for _ in 0..*size { - let param = Parameter::create(*i, MirType::Felt, span); + let param = Parameter::create(*i, ty!(felt).unwrap(), span); *i += 1; params.push(param); } Ok(params) }, - ast::Type::Matrix(_rows, _cols) => { + Type::Matrix(..) => { let span = if let Some(name) = name { name.span() } else { @@ -317,21 +326,21 @@ impl<'a> MirBuilder<'a> { &mut self, span: SourceSpan, name: Option<&'a ast::Identifier>, - ty: &ast::Type, + ty: &Type, i: &mut usize, ) -> Result, CompileError> { match ty { - ast::Type::Felt => { - let param = Parameter::create(*i, MirType::Felt, span); + Type::Scalar(_) => { + let param = Parameter::create(*i, *ty, span); *i += 1; Ok(param) }, - ast::Type::Vector(size) => { - let param = Parameter::create(*i, MirType::Vector(*size), span); + Type::Vector(..) => { + let param = Parameter::create(*i, *ty, span); *i += 1; Ok(param) }, - ast::Type::Matrix(_rows, _cols) => { + Type::Matrix(..) => { let span = if let Some(name) = name { name.span() } else { @@ -371,12 +380,8 @@ impl<'a> MirBuilder<'a> { Ok(()) } - fn translate_type(&mut self, ty: &ast::Type) -> MirType { - match ty { - ast::Type::Felt => MirType::Felt, - ast::Type::Vector(size) => MirType::Vector(*size), - ast::Type::Matrix(rows, cols) => MirType::Matrix(*rows, *cols), - } + fn translate_type(&mut self, ty: &ast::Type) -> Type { + *ty } /// Translates a statement and returns the operation. @@ -481,15 +486,16 @@ impl<'a> MirBuilder<'a> { self.bindings.enter(); for (index, binding) in list_comp.bindings.iter().enumerate() { - let binding_node = Parameter::create(index, ast::Type::Felt.into(), binding.span()); + // TODO: extract the type from the bound variable + let binding_node = Parameter::create(index, ty!(felt).unwrap(), binding.span()); params.push(binding_node.clone()); self.bindings.insert(binding, binding_node); } let for_node = For::create( iterator_nodes.into(), - Op::None(list_comp.span()).into(), - Op::None(list_comp.span()).into(), + Op::None(Stale { span: list_comp.span(), ty: None }).into(), + Op::None(Stale { span: list_comp.span(), ty: None }).into(), list_comp.span(), ); set_all_ref_nodes(params, for_node.as_owner().unwrap()); @@ -703,6 +709,7 @@ impl<'a> MirBuilder<'a> { pc.period(), )), }) + .ty(pc.ty()) .build(); Ok(node) } else if let Some(bus) = self.mir.constraint_graph().get_bus_link(&qual_ident) { @@ -711,6 +718,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::BusAccess(BusAccess::new(bus.clone(), access.offset)), }) + .ty(ty!(?)) .build(); Ok(node) } else { @@ -792,7 +800,7 @@ impl<'a> MirBuilder<'a> { })?; }, } - return Ok(Op::None(bin_op.span()).into()); + return Ok(Op::None(Stale { span: bin_op.span(), ty: None }).into()); } } } @@ -824,36 +832,64 @@ impl<'a> MirBuilder<'a> { fn translate_call(&mut self, call: &'a ast::Call) -> Result, CompileError> { // First, resolve the callee, panic if it's not resolved let resolved_callee = call.callee.resolved().unwrap(); - if call.is_builtin() { // If it's a fold operator (Sum / Prod), handle it match call.callee.as_ref().name() { symbols::Sum => { assert_eq!(call.args.len(), 1); + let acc = ast::ConstantExpr::Scalar(0); let iterator_node = self.translate_expr(call.args.first().unwrap())?; - let accumulator_node = - self.translate_const(&ast::ConstantExpr::Scalar(0), call.span())?; + let accumulator_node = self.translate_const(&acc, call.span())?; let node = Fold::builder() .span(call.span()) .iterator(iterator_node) .operator(FoldOperator::Add) .initial_value(accumulator_node) + .ty(acc.ty()) .build(); Ok(node) }, symbols::Prod => { assert_eq!(call.args.len(), 1); + let acc = ast::ConstantExpr::Scalar(1); let iterator_node = self.translate_expr(call.args.first().unwrap())?; - let accumulator_node = - self.translate_const(&ast::ConstantExpr::Scalar(1), call.span())?; + let accumulator_node = self.translate_const(&acc, call.span())?; let node = Fold::builder() .span(call.span()) .iterator(iterator_node) .operator(FoldOperator::Mul) .initial_value(accumulator_node) + .ty(acc.ty()) .build(); Ok(node) }, + symbols::AsBool => { + assert_eq!(call.args.len(), 1); + let x = self.translate_expr(call.args.first().unwrap())?; + // enf x^2 = x + let enforced = + Sub::builder() + .lhs(x.clone()) + .rhs( + Exp::builder() + .lhs(x.clone()) + .rhs(self.translate_const( + &ast::ConstantExpr::Scalar(2), + call.span(), + )?) + .span(call.span()) + .build(), + ) + .span(call.span()) + .build(); + let node = Enf::builder().span(call.span()).expr(enforced).build(); + let _ = self.insert_enforce(node); + let bool_x = duplicate_node(x, &mut Default::default()); + // TODO: cast to a bool + let cast = + Cast::builder().value(bool_x).span(call.span()).ty(ty!(bool)).build(); + Ok(cast) + }, other => unimplemented!("unhandled builtin: {}", other), } } else { @@ -888,6 +924,7 @@ impl<'a> MirBuilder<'a> { } // safe to unwrap because we know it is a Function due to get_function let callee_ref = callee.as_function().unwrap(); + if callee_ref.parameters.len() != arg_nodes.len() { self.diagnostics .diagnostic(Severity::Error) @@ -910,6 +947,27 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } + + let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + let arg_kinds_refs = arg_kinds.iter().collect::>(); + if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("arguments typing mismatch") + .with_primary_label( + call.span(), + format!("called function with arguments {:?}", arg_kinds_refs), + ) + .with_secondary_label( + call.callee.span(), + format!( + "this functions has parameters: {:?}", + callee_ref.func_ty.params() + ), + ) + .emit(); + return Err(CompileError::Failed); + } } else if let Some(callee) = self.mir.constraint_graph().get_evaluator_root(&resolved_callee) { @@ -947,6 +1005,38 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } + let mut arg_kinds = vec![]; + if let Some(first_arg) = arg_nodes.first() { + let Some(v) = first_arg.as_vector() else { + unreachable!( + "expected first argument to be a vector, got {:#?}", + first_arg + ); + }; + for element in v.elements.borrow().iter() { + arg_kinds.push(element.kind().unwrap()); + } + } + // let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + let arg_kinds_refs = arg_kinds.iter().collect::>(); + if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("arguments typing mismatch") + .with_primary_label( + call.span(), + format!("called evaluator with arguments {:?}", arg_kinds_refs), + ) + .with_secondary_label( + call.callee.span(), + format!( + "this evaluator has parameters: {:?}", + callee_ref.func_ty.params() + ), + ) + .emit(); + return Err(CompileError::Failed); + } } else { panic!("Unknown function or evaluator: {:?}", resolved_callee); } @@ -971,7 +1061,8 @@ impl<'a> MirBuilder<'a> { self.bindings.enter(); let mut params = Vec::new(); for (index, binding) in list_comp.bindings.iter().enumerate() { - let binding_node = Parameter::create(index, ast::Type::Felt.into(), binding.span()); + // TODO: extract the type from the bound variable + let binding_node = Parameter::create(index, ty!(felt).unwrap(), binding.span()); params.push(binding_node.clone()); self.bindings.insert(binding, binding_node); } @@ -1008,7 +1099,9 @@ impl<'a> MirBuilder<'a> { scalar_expr: &'a ast::ScalarExpr, ) -> Result, CompileError> { match scalar_expr { - ast::ScalarExpr::Const(c) => self.translate_scalar_const(c.item, c.span()), + ast::ScalarExpr::Const(c) => { + self.translate_scalar_const(c.item, c.span(), scalar_expr.scalar_ty()) + }, ast::ScalarExpr::SymbolAccess(s) => self.translate_symbol_access(s), ast::ScalarExpr::BoundedSymbolAccess(s) => self.translate_bounded_symbol_access(s), ast::ScalarExpr::Binary(b) => self.translate_binary_op(b), @@ -1030,12 +1123,13 @@ impl<'a> MirBuilder<'a> { &mut self, c: u64, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let value = SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(c)), + value: MirValue::Constant(ConstantValue::Scalar(c)), span, }; - let node = Value::builder().value(value).build(); + let node = Value::builder().value(value).ty(ty!(sty)).build(); Ok(node) } @@ -1121,9 +1215,13 @@ impl<'a> MirBuilder<'a> { span: SourceSpan, ) -> Result, CompileError> { match c { - ast::ConstantExpr::Scalar(s) => self.translate_scalar_const(*s, span), - ast::ConstantExpr::Vector(v) => self.translate_vector_const(v.clone(), span), - ast::ConstantExpr::Matrix(m) => self.translate_matrix_const(m.clone(), span), + ast::ConstantExpr::Scalar(s) => self.translate_scalar_const(*s, span, c.scalar_ty()), + ast::ConstantExpr::Vector(v) => { + self.translate_vector_const(v.clone(), span, c.scalar_ty()) + }, + ast::ConstantExpr::Matrix(m) => { + self.translate_matrix_const(m.clone(), span, c.scalar_ty()) + }, } } @@ -1131,10 +1229,11 @@ impl<'a> MirBuilder<'a> { &mut self, v: Vec, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let mut node = Vector::builder().size(v.len()).span(span); for value in v.iter() { - let value_node = self.translate_scalar_const(*value, span)?; + let value_node = self.translate_scalar_const(*value, span, sty)?; node = node.elements(value_node); } Ok(node.build()) @@ -1144,10 +1243,11 @@ impl<'a> MirBuilder<'a> { &mut self, m: Vec>, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let mut node = Matrix::builder().size(m.len()).span(span); for row in m.iter() { - let row_node = self.translate_vector_const(row.clone(), span)?; + let row_node = self.translate_vector_const(row.clone(), span, sty)?; node = node.elements(row_node); } let node = node.build(); @@ -1169,15 +1269,18 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccess(trace_access), }) + .ty(trace_access.ty()) .build()); } if let Some(tab) = self.trace_access_binding(access) { + let typ = tab.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) + .ty(typ) .build()); } @@ -1195,9 +1298,9 @@ impl<'a> MirBuilder<'a> { // // In that case, replacing the default type (Felt) with the one from the access if let Some(mut param) = let_bound_access_expr.as_parameter_mut() - && let Some(access_ty) = &access.ty + && access.ty.is_some() { - param.ty = self.translate_type(access_ty); + param.ty = Some(self.translate_type(access.ty.as_ref().unwrap())); } let accessor: Link = Accessor::create( duplicate_node(let_bound_access_expr, &mut Default::default()), @@ -1215,16 +1318,19 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccess(trace_access), }) + .ty(trace_access.ty()) .build()); } // Otherwise, we check bindings, trace bindings, and public inputs, in that order if let Some(tab) = self.trace_access_binding(access) { + let typ = tab.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) + .ty(typ) .build()); } @@ -1235,14 +1341,17 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::PublicInput(public_input_access), }) + .ty(public_input_access.ty()) .build()); }, (None, Some(public_input_table_access)) => { + let typ = public_input_table_access.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::PublicInputTable(public_input_table_access), }) + .ty(typ) .build()); }, _ => {}, diff --git a/mir/src/passes/unrolling/match_optimizer.rs b/mir/src/passes/unrolling/match_optimizer.rs index d21ca9b65..e00117e92 100644 --- a/mir/src/passes/unrolling/match_optimizer.rs +++ b/mir/src/passes/unrolling/match_optimizer.rs @@ -227,7 +227,7 @@ impl<'a> MatchOptimizer<'a> { // representing `enf x = y` let zero_node = Value::create(SpannedMirValue { span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), }); // The following unwrap is safe as we always have at least one constraint above let new_node_with_sub_zero = Sub::create(cur_node.unwrap(), zero_node, span); diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs index 37b71c20b..49a39de7b 100644 --- a/mir/src/passes/unrolling/unrolling_first_pass.rs +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -1,14 +1,15 @@ use std::{collections::HashMap, ops::Deref}; use air_parser::ast::AccessType; +use air_types::{Type, Typing, ty}; use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; use crate::{ CompileError, ir::{ Accessor, Add, BackLink, Boundary, ConstantValue, Enf, Exp, FoldOperator, Graph, Link, - Matrix, MirType, MirValue, Mul, Node, Op, Owner, Parameter, Parent, RandomInputs, - SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + Matrix, MirValue, Mul, Node, Op, Owner, Parameter, Parent, RandomInputs, SpannedMirValue, + Sub, TraceAccess, TraceAccessBinding, Value, Vector, }, passes::{ Visitor, @@ -57,22 +58,22 @@ fn unroll_trace_access_binding( if trace_access_binding.size == 1 { Value::create(SpannedMirValue { span, - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access_binding.segment, + trace_access_binding.offset, + 0, + )), }) } else { let mut vec = vec![]; for index in 0..trace_access_binding.size { let val = Value::create(SpannedMirValue { span, - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset + index, - row_offset: 0, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access_binding.segment, + trace_access_binding.offset + index, + 0, + )), }); vec.push(val); } @@ -86,7 +87,7 @@ fn unroll_constant_vector(constant_vector: &Vec, span: SourceSpan) -> Link< for val in constant_vector { let val = Value::create(SpannedMirValue { span, - value: MirValue::Constant(ConstantValue::Felt(*val)), + value: MirValue::Constant(ConstantValue::Scalar(*val)), }); vec.push(val); } @@ -101,7 +102,7 @@ fn unroll_constant_matrix(constant_matrix: &Vec>, span: SourceSpan) -> for val in row { let val = Value::create(SpannedMirValue { span, - value: MirValue::Constant(ConstantValue::Felt(*val)), + value: MirValue::Constant(ConstantValue::Scalar(*val)), }); res_row.push(val); } @@ -154,11 +155,11 @@ fn unroll_accessor_default_access_type( if let MirValue::TraceAccess(trace_access) = mir_value { let new_node = Value::create(SpannedMirValue { span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access.segment, + trace_access.column, + trace_access.row_offset + accessor_offset, + )), }); return Some(new_node); } @@ -186,11 +187,11 @@ fn unroll_accessor_index_access_type( MirValue::TraceAccess(trace_access) => { let new_node = Value::create(SpannedMirValue { span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access.segment, + trace_access.column, + trace_access.row_offset + accessor_offset, + )), }); Some(new_node) }, @@ -263,7 +264,7 @@ impl UnrollingFirstPass<'_> { let mir_value = value_ref.value.value.clone(); match &mir_value { MirValue::Constant(c) => match c { - ConstantValue::Felt(_) => {}, + ConstantValue::Scalar(_) => {}, ConstantValue::Vector(v) => { return Ok(Some(unroll_constant_vector(v, value_ref.span()))); }, @@ -522,7 +523,7 @@ impl UnrollingFirstPass<'_> { let mut new_vec = vec![]; for i in 0..iterator_expected_len { let new_node = - Parameter::create(i, MirType::Felt, for_node.as_for().unwrap().deref().span()); + Parameter::create(i, ty!(felt).unwrap(), for_node.as_for().unwrap().deref().span()); new_vec.push(new_node.clone()); let iterators_i = iterators @@ -577,6 +578,14 @@ impl UnrollingFirstPass<'_> { ) -> Result>, CompileError> { Ok(None) // Matrix are already unrolled, we have nothing to do } + + fn visit_cast_bis( + &mut self, + _graph: &mut Graph, + _cast: Link, + ) -> Result>, CompileError> { + Ok(None) + } } impl Visitor for UnrollingFirstPass<'_> { @@ -639,6 +648,7 @@ impl Visitor for UnrollingFirstPass<'_> { to_link_and(p.clone(), graph, |g, el| self.visit_parameter_bis(g, el)) }, Node::Value(v) => to_link_and(v.clone(), graph, |g, el| self.visit_value_bis(g, el)), + Node::Cast(c) => to_link_and(c.clone(), graph, |g, el| self.visit_cast_bis(g, el)), Node::None(_) => Ok(None), Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { unreachable!( @@ -720,9 +730,11 @@ fn compute_iterator_len(iterator: Link) -> usize { AccessType::Matrix(..) => 1, }, Op::Parameter(parameter) => match parameter.ty { - MirType::Felt => 1, - MirType::Vector(l) => l, - MirType::Matrix(l, _) => l, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, l)) => l, + Some(Type::Matrix(_, l, _)) => l, + // NOTE: This should probably be unreachable + None => 1, }, _ => 1, } diff --git a/mir/src/passes/visitor.rs b/mir/src/passes/visitor.rs index 77fabc1e2..c733f857b 100644 --- a/mir/src/passes/visitor.rs +++ b/mir/src/passes/visitor.rs @@ -62,6 +62,7 @@ pub trait Visitor { Node::Accessor(a) => self.visit_accessor(graph, a.clone().into()), Node::BusOp(b) => self.visit_bus_op(graph, b.clone().into()), Node::Parameter(p) => self.visit_parameter(graph, p.clone().into()), + Node::Cast(c) => self.visit_cast(graph, c.clone().into()), Node::Value(v) => self.visit_value(graph, v.clone().into()), Node::None(_) => Ok(()), } @@ -154,6 +155,10 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } + /// Visit a `Cast` node + fn visit_cast(&mut self, _graph: &mut Graph, _cast: Link) -> Result<(), CompileError> { + Ok(()) + } /// Visit a `Value` node fn visit_value(&mut self, _graph: &mut Graph, _value: Link) -> Result<(), CompileError> { Ok(()) diff --git a/mir/src/tests/mod.rs b/mir/src/tests/mod.rs index 60f49f298..d4e26c612 100644 --- a/mir/src/tests/mod.rs +++ b/mir/src/tests/mod.rs @@ -12,6 +12,7 @@ mod pub_inputs; mod selectors; mod source_sections; mod trace; +mod typing; mod variables; use std::sync::Arc; diff --git a/mir/src/tests/trace.rs b/mir/src/tests/trace.rs index 0fb0819bd..710675be7 100644 --- a/mir/src/tests/trace.rs +++ b/mir/src/tests/trace.rs @@ -151,5 +151,5 @@ fn err_ic_trace_cols_group_used_as_scalar() { enf a[0]' = a + clk; }"; - expect_diagnostic(source, "type mismatch"); + expect_diagnostic(source, "invalid binary expression"); } diff --git a/mir/src/tests/typing.rs b/mir/src/tests/typing.rs new file mode 100644 index 000000000..5b8c93b3e --- /dev/null +++ b/mir/src/tests/typing.rs @@ -0,0 +1,33 @@ +use super::compile; + +#[test] +fn test_typing() { + let code = " + def test + + trace_columns { + main: [a, b], + } + + public_inputs { + stack_inputs: [16], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let b2 = as_bool(b); + let c = select(a, b2); + enf c = 42; + } + fn select(x: felt, selector: bool) -> felt { + return x * selector; + } + "; + let Ok(mir) = compile(code) else { + panic!("Failed to compile code: {}", code); + }; + dbg!(&mir.constraint_graph().integrity_constraints_roots); +} diff --git a/parser/Cargo.toml b/parser/Cargo.toml index c50738c40..6b0752713 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -16,6 +16,7 @@ lalrpop = { version = "0.20", default-features = false } [dependencies] air-pass = { package = "air-pass", path = "../pass", version = "0.5" } +air-types.workspace = true either = "1.12" lalrpop-util = "0.20" lazy_static = "1.4" diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index 3939c7860..4a2cc8896 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -144,11 +144,15 @@ impl Constant { pub const fn new(span: SourceSpan, name: Identifier, value: ConstantExpr) -> Self { Self { span, name, value } } - +} +impl Typing for Constant { /// Gets the type of the value associated with this constant - pub fn ty(&self) -> Type { + fn ty(&self) -> Option { self.value.ty() } + fn kind(&self) -> Option { + self.value.kind() + } } impl Eq for Constant {} impl PartialEq for Constant { @@ -168,25 +172,24 @@ pub enum ConstantExpr { Vector(Vec), Matrix(Vec>), } -impl ConstantExpr { +impl Typing for ConstantExpr { /// Gets the type of this expression - pub fn ty(&self) -> Type { + fn ty(&self) -> Option { match self { - Self::Scalar(_) => Type::Felt, - Self::Vector(elems) => Type::Vector(elems.len()), + Self::Scalar(_) => ty!(uint), + Self::Vector(elems) => ty!(uint[elems.len()]), Self::Matrix(rows) => { let num_rows = rows.len(); let num_cols = rows.first().unwrap().len(); - Type::Matrix(num_rows, num_cols) + ty!(uint[num_rows, num_cols]) }, } } - - /// Returns true if this expression is of aggregate type - pub fn is_aggregate(&self) -> bool { - matches!(self, Self::Vector(_) | Self::Matrix(_)) + fn kind(&self) -> Option { + self.ty().kind() } } + impl fmt::Display for ConstantExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -255,14 +258,21 @@ impl Export<'_> { Self::Evaluator(item) => item.name, } } - +} +impl Typing for Export<'_> { /// Returns the type of the value associated with this export /// /// NOTE: Evaluator functions have no return value, so they have no type associated. /// For this reason, this function returns `Option` rather than `Type`. - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::Constant(item) => Some(item.ty()), + Self::Constant(item) => item.ty(), + Self::Evaluator(_) => None, + } + } + fn kind(&self) -> Option { + match self { + Self::Constant(item) => item.kind(), Self::Evaluator(_) => None, } } @@ -296,6 +306,11 @@ impl PartialEq for PeriodicColumn { self.name == other.name && self.values == other.values } } +impl Typing for PeriodicColumn { + fn ty(&self) -> Option { + ty!(felt[self.period()]) + } +} /// Declaration of a public input for an AirScript program. /// @@ -373,16 +388,19 @@ pub struct EvaluatorFunction { pub name: Identifier, pub params: Vec, pub body: Vec, + pub fn_ty: FunctionType, } impl EvaluatorFunction { /// Creates a new function. - pub const fn new( + pub fn new( span: SourceSpan, name: Identifier, params: Vec, body: Vec, ) -> Self { - Self { span, name, params, body } + let param_tys = params.iter().map(|ty| ty.ty()).collect::>(); + let fn_ty = FunctionType::Evaluator(param_tys); + Self { span, name, params, body, fn_ty } } } impl Eq for EvaluatorFunction {} @@ -391,6 +409,14 @@ impl PartialEq for EvaluatorFunction { self.name == other.name && self.params == other.params && self.body == other.body } } +impl Typing for EvaluatorFunction { + fn ty(&self) -> Option { + None + } + fn kind(&self) -> Option { + Some(Kind::Callable(self.fn_ty.clone())) + } +} /// Functions take a group of expressions as parameters and returns a value. /// @@ -405,17 +431,28 @@ pub struct Function { pub params: Vec<(Identifier, Type)>, pub return_type: Type, pub body: Vec, + pub fn_ty: FunctionType, } impl Function { /// Creates a new function. - pub const fn new( + pub fn new( span: SourceSpan, name: Identifier, params: Vec<(Identifier, Type)>, return_type: Type, body: Vec, ) -> Self { - Self { span, name, params, return_type, body } + let p = params.iter().map(|(_, ty)| ty.ty()).collect::>(); + let r = return_type.ty(); + let fn_ty = FunctionType::Function(p, r); + Self { + span, + name, + params, + return_type, + body, + fn_ty, + } } pub fn param_types(&self) -> Vec { diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 49b0f285a..8c145f89d 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -160,7 +160,7 @@ impl QualifiedIdentifier { if self.module.name() == "$builtin" { match self.item { NamespacedIdentifier::Function(id) => { - matches!(id.name(), symbols::Sum | symbols::Prod) + matches!(id.name(), symbols::Sum | symbols::Prod | symbols::AsBool) }, _ => false, } @@ -318,31 +318,27 @@ impl Expr { _ => false, } } - +} +impl Typing for Expr { /// Returns the resolved type of this expression, if known - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::Const(constant) => Some(constant.ty()), + Self::Const(constant) => constant.ty(), Self::Range(range) => range.ty(), - Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) { - Some(Type::Felt) => Some(Type::Vector(vector.len())), - Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)), - Some(_) => None, - None => Some(Type::Vector(0)), - }, - Self::Matrix(matrix) => { - let rows = matrix.len(); - let cols = matrix[0].len(); - Some(Type::Matrix(rows, cols)) - }, + Self::Vector(vector) => vector.ty(), + Self::Matrix(matrix) => matrix.ty(), Self::SymbolAccess(access) => access.ty, - Self::Binary(_) => Some(Type::Felt), + Self::Binary(bin_expr) => bin_expr.ty(), Self::Call(call) => call.ty, Self::ListComprehension(lc) => lc.ty, Self::Let(let_expr) => let_expr.ty(), - Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => Some(Type::Felt), + Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => ty!(felt), } } + + fn kind(&self) -> Option { + self.ty().kind() + } } impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -529,30 +525,30 @@ impl ScalarExpr { _ => false, } } - +} +impl Typing for ScalarExpr { /// Returns the resolved type of this expression, if known. /// /// Returns `Ok(Some)` if the type could be resolved without conflict. /// Returns `Ok(None)` if type information was missing. /// Returns `Err` if the type could not be resolved due to a conflict, /// with a span covering the source of the conflict. - pub fn ty(&self) -> Result, SourceSpan> { + fn infer_ty(&self) -> Result, TypeError> { match self { - Self::Const(_) => Ok(Some(Type::Felt)), + Self::Const(_) => Ok(ty!(uint)), Self::SymbolAccess(sym) => Ok(sym.ty), Self::BoundedSymbolAccess(sym) => Ok(sym.column.ty), - Self::Binary(expr) => match (expr.lhs.ty()?, expr.rhs.ty()?) { - (None, _) | (_, None) => Ok(None), - (Some(lty), Some(rty)) if lty == rty => Ok(Some(lty)), - _ => Err(expr.span()), - }, + Self::Binary(expr) => expr.infer_ty(), Self::Call(expr) => Ok(expr.ty), Self::Let(expr) => Ok(expr.ty()), Self::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { - Ok(Some(Type::Felt)) + Ok(ty!(felt)) }, } } + fn ty(&self) -> Option { + self.infer_ty().ok()? + } } impl TryFrom for ScalarExpr { type Error = InvalidExprError; @@ -702,15 +698,6 @@ impl RangeExpr { self.try_into() .expect("attempted to convert non-constant range expression to constant") } - - pub fn ty(&self) -> Option { - match (&self.start, &self.end) { - (RangeBound::Const(start), RangeBound::Const(end)) => { - Some(Type::Vector(end.item.abs_diff(start.item))) - }, - _ => None, - } - } } impl From for RangeExpr { fn from(range: Range) -> Self { @@ -738,6 +725,16 @@ impl fmt::Display for RangeExpr { write!(f, "{}..{}", &self.start, &self.end) } } +impl Typing for RangeExpr { + fn ty(&self) -> Option { + match (&self.start, &self.end) { + (RangeBound::Const(start), RangeBound::Const(end)) => { + ty!(uint[end.item.abs_diff(start.item)]) + }, + _ => None, + } + } +} #[derive(Hash, Clone, Spanned, PartialEq, Eq, Debug)] pub enum RangeBound { @@ -776,15 +773,38 @@ pub struct BinaryExpr { pub op: BinaryOp, pub lhs: Box, pub rhs: Box, + pub bin_ty: Option, } impl BinaryExpr { pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self { - Self { + let mut res = Self { span, op, lhs: Box::new(lhs), rhs: Box::new(rhs), + bin_ty: None, + }; + res.update_bin_ty(); + res + } + pub fn update_bin_ty(&mut self) -> Option { + let lhs = self.lhs.as_ref(); + let rhs = self.rhs.as_ref(); + let op = self.op; + if !(lhs.ty().is_some() || rhs.ty().is_some() || (lhs.is_scalar() && rhs.is_scalar())) { + return None; } + let l_ty = lhs.ty(); + let r_ty = rhs.ty(); + let bin_ty = Some(match op { + BinaryOp::Eq => bty!(any:l_ty = any:r_ty), + BinaryOp::Add => bty!(any:l_ty + any:r_ty), + BinaryOp::Sub => bty!(any:l_ty - any:r_ty), + BinaryOp::Mul => bty!(any:l_ty * any:r_ty), + BinaryOp::Exp => bty!(any:l_ty ^ any:r_ty), + }); + self.bin_ty = bin_ty; + bin_ty } /// Returns true if this binary expression could expand to a block, e.g. due to a function call @@ -814,6 +834,34 @@ impl fmt::Display for BinaryExpr { write!(f, "{} {} {}", &self.lhs, &self.op, &self.rhs) } } +impl Typing for BinaryExpr { + fn infer_ty(&self) -> Result, TypeError> { + match self.bin_ty { + Some(ref bty) => bty.infer_ty(), + None => Ok(None), + } + } + fn scalar_ty(&self) -> Option { + self.bin_ty.scalar_ty() + } + fn ty(&self) -> Option { + self.bin_ty.ty() + } +} +impl ScalarTypeMut for BinaryExpr { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + if let Some(bty) = self.bin_ty.as_mut() { + bty.update_scalar_ty_unchecked(new_ty); + } + } +} +impl TypeMut for BinaryExpr { + fn update_ty_unchecked(&mut self, new_ty: Option) { + if let Some(bty) = self.bin_ty.as_mut() { + bty.update_ty_unchecked(new_ty); + } + } +} #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum BinaryOp { @@ -858,6 +906,12 @@ impl fmt::Display for Boundary { } } +pub trait Access { + type Accessed; + /// Return a new [Type] representing the type of the value produced by the given [AccessType] + fn access(&self, access_type: AccessType) -> Result; +} + /// Represents the way an identifier is accessed/referenced in the source. #[derive(Hash, Debug, Clone, Eq, PartialEq, Default)] pub enum AccessType { @@ -973,51 +1027,55 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Index(idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Index(idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, AccessType::Slice(range) => { let slice_range = range.to_slice_range(); let rlen = slice_range.end - slice_range.start; + // TODO: check if this is valid: + // let rlen = slice_range.end.abs_diff(slice_range.start); match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if slice_range.end > len => { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if slice_range.end > len => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Vector(_) => Ok(Self { + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Slice(range), - ty: Some(Type::Vector(rlen)), + ty: ty!(sty[rlen]), ..self.clone() }), - Type::Matrix(rows, _) if slice_range.end > rows => { + Type::Matrix(_, rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Slice(range), - ty: Some(Type::Matrix(rlen, cols)), + ty: ty!(sty[rlen, cols]), ..self.clone() }), } }, AccessType::Matrix(row, col) => match ty { - Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { + Type::Scalar(_) | Type::Vector(..) => Err(InvalidAccessError::IndexIntoScalar), + Type::Matrix(_, rows, cols) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(..) => Ok(Self { + Type::Matrix(sty, ..) => Ok(Self { access_type: AccessType::Matrix(row, col), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), }, @@ -1033,17 +1091,19 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => unreachable!(), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => unreachable!(), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Index(base_range.start + idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Index(base_range.start + idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, @@ -1059,33 +1119,33 @@ impl SymbolAccess { end: RangeBound::Const(Span::new(range.end.span(), end)), }; match ty { - Type::Felt => unreachable!(), - Type::Vector(_) if slice_range.end > blen => { + Type::Scalar(_) => unreachable!(), + Type::Vector(..) if slice_range.end > blen => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Vector(_) => Ok(Self { + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Slice(shifted), - ty: Some(Type::Vector(rlen)), + ty: ty!(sty[rlen]), ..self.clone() }), - Type::Matrix(rows, _) if slice_range.end > rows => { + Type::Matrix(_, rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Slice(shifted), - ty: Some(Type::Matrix(rlen, cols)), + ty: ty!(sty[rlen, cols]), ..self.clone() }), } }, AccessType::Matrix(row, col) => match ty { - Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { + Type::Scalar(_) | Type::Vector(..) => Err(InvalidAccessError::IndexIntoScalar), + Type::Matrix(_, rows, cols) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(..) => Ok(Self { + Type::Matrix(sty, ..) => Ok(Self { access_type: AccessType::Matrix(row, col), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), }, @@ -1101,17 +1161,19 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, @@ -1365,6 +1427,7 @@ impl Call { match callee.name() { symbols::Sum => Self::sum(span, args), symbols::Prod => Self::prod(span, args), + symbols::AsBool => Self::as_bool(span, args), _ => Self { span, callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)), @@ -1383,13 +1446,28 @@ impl Call { /// Constructs a function call for the `sum` reducer/fold #[inline] pub fn sum(span: SourceSpan, args: Vec) -> Self { - Self::new_builtin(span, "sum", args, Type::Felt) + // TODO: adapt to the new type system and use BinType instead of this type + Self::new_builtin(span, "sum", args, ty!(felt).unwrap()) } /// Constructs a function call for the `prod` reducer/fold #[inline] pub fn prod(span: SourceSpan, args: Vec) -> Self { - Self::new_builtin(span, "prod", args, Type::Felt) + // TODO: adapt to the new type system and use BinType instead of this type + Self::new_builtin(span, "prod", args, ty!(felt).unwrap()) + } + + /// Constructs a function call for `as_bool`. + /// An `as_bool(x)` is equivalent to an `enf x^2 = x plus a cast from felt to bool`. + #[inline] + pub fn as_bool(span: SourceSpan, args: Vec) -> Self { + //Self::new_builtin(span, "as_bool", args, ty!(felt).unwrap()) + let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin")); + let name = Identifier::new(span, Symbol::intern("as_bool")); + let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name)); + let callee = ResolvableIdentifier::Resolved(id); + let ty = ty!(bool); + Self { span, callee, args, ty } } fn new_builtin(span: SourceSpan, name: &str, args: Vec, ty: Type) -> Self { diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 795dbd8d2..3a806dcd7 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -200,7 +200,24 @@ impl Let { pub fn new(span: SourceSpan, name: Identifier, value: Expr, body: Vec) -> Self { Self { span, name, value, body } } +} +impl Eq for Let {} +impl PartialEq for Let { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.value == other.value && self.body == other.body + } +} +impl fmt::Debug for Let { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Let") + .field("name", &self.name) + .field("value", &self.value) + .field("body", &self.body) + .finish() + } +} +impl Typing for Let { /// Return the type of the overall `let` expression. /// /// A `let` with an empty body, or with a body that terminates with a non-expression statement @@ -210,7 +227,7 @@ impl Let { /// For `let` statements with a non-empty body that terminates with an expression, the `let` can /// be used in expression position, producing the value of the terminating expression in its /// body, and having the same type as that value. - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { let mut last = self.body.last(); while let Some(stmt) = last.take() { match stmt { @@ -228,18 +245,3 @@ impl Let { None } } -impl Eq for Let {} -impl PartialEq for Let { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.value == other.value && self.body == other.body - } -} -impl fmt::Debug for Let { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Let") - .field("name", &self.name) - .field("value", &self.value) - .field("body", &self.body) - .finish() - } -} diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index f4f420c17..8e81a195a 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -1,5 +1,6 @@ use std::fmt; +use air_types::{Kind, Typing, tty, ty}; use miden_diagnostics::{SourceSpan, Spanned}; use super::*; @@ -65,9 +66,14 @@ impl TraceSegment { for binding in raw_bindings.into_iter() { let (name, size) = binding.item; let ty = match size { - 1 => Type::Felt, - n => Type::Vector(n), - }; + 1 => tty!(name), + n => tty!(name[n]), + } + .unwrap_or_else(|| { + unreachable!( + "Trace segment binding types should always be known, but got None for {name} with size {size}" + ) + }); bindings.push(TraceBinding::new(binding.span(), name, id, offset, size, ty)); offset += size; } @@ -118,6 +124,14 @@ impl TraceSegment { self.size == 0 } } +impl Typing for TraceSegment { + fn ty(&self) -> Option { + match self.size { + 1 => self.bindings.first().map(|b| b.ty())?, + _ => ty!(felt[self.size]), + } + } +} impl fmt::Debug for TraceSegment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("TraceSegment") @@ -250,20 +264,22 @@ impl TraceBinding { ty, } } +} +impl Typing for TraceBinding { /// Returns a [Type] that describes what type of value this binding represents - #[inline] - pub fn ty(&self) -> Type { - self.ty + fn ty(&self) -> Option { + Some(self.ty) } - - #[inline] - pub fn is_scalar(&self) -> bool { - self.ty.is_scalar() + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) } +} +impl Access for TraceBinding { + type Accessed = Self; /// Derive a new [TraceBinding] derived from the current one given an [AccessType] - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match access_type { AccessType::Default => Ok(*self), AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), @@ -277,7 +293,7 @@ impl TraceBinding { Ok(Self { offset, size, - ty: Type::Vector(size), + ty: ty!(felt[size]).unwrap(), ..*self }) } @@ -286,7 +302,12 @@ impl TraceBinding { AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(idx) => { let offset = self.offset + idx; - Ok(Self { offset, size: 1, ty: Type::Felt, ..*self }) + Ok(Self { + offset, + size: 1, + ty: ty!(felt).unwrap(), + ..*self + }) }, AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), } diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index 3155180ec..7a86fc1f5 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -1,106 +1,45 @@ -use super::*; - -/// The types of values which can be represented in an AirScript program -#[derive(Hash, Debug, Copy, Clone, PartialEq, Eq)] -pub enum Type { - /// A field element - Felt, - /// A vector of N integers - Vector(usize), - /// A matrix of N rows and M columns - Matrix(usize, usize), -} -impl Type { - /// Returns true if this type is an aggregate - #[inline] - pub fn is_aggregate(&self) -> bool { - match self { - Self::Felt => false, - Self::Vector(_) | Self::Matrix(..) => true, - } - } +pub use air_types::{bty, fty, kind, sty, tty, ty, tys, *}; - /// Returns true if this type is a scalar - #[inline] - pub fn is_scalar(&self) -> bool { - matches!(self, Self::Felt) - } - - /// Returns true if this type is a valid iterable in a comprehension - #[inline] - pub fn is_iterable(&self) -> bool { - self.is_vector() - } - - /// Returns true if this type is a vector - #[inline] - pub fn is_vector(&self) -> bool { - matches!(self, Self::Vector(_)) - } +use super::*; +impl Access for Type { + type Accessed = Self; /// Return a new [Type] representing the type of the value produced by the given [AccessType] - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match *self { ty if access_type == AccessType::Default => Ok(ty), - Self::Felt => Err(InvalidAccessError::IndexIntoScalar), - Self::Vector(len) => match access_type { + Self::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Self::Vector(sty, len) => match access_type { AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > len { Err(InvalidAccessError::IndexOutOfBounds) } else { - Ok(Self::Vector(slice_range.len())) + Ok(Self::Vector(sty, slice_range.len())) } }, AccessType::Index(idx) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(_) => Ok(Self::Felt), + AccessType::Index(_) => Ok(Self::Scalar(sty)), AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), _ => unreachable!(), }, - Self::Matrix(rows, cols) => match access_type { + Self::Matrix(sty, rows, cols) => match access_type { AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > rows { Err(InvalidAccessError::IndexOutOfBounds) } else { - Ok(Self::Matrix(slice_range.len(), cols)) + Ok(Self::Matrix(sty, slice_range.len(), cols)) } }, AccessType::Index(idx) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(_) => Ok(Self::Vector(cols)), + AccessType::Index(_) => Ok(Self::Vector(sty, cols)), AccessType::Matrix(row, col) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - AccessType::Matrix(..) => Ok(Self::Felt), + AccessType::Matrix(..) => Ok(Self::Scalar(sty)), _ => unreachable!(), }, } } } -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Felt => f.write_str("felt"), - Self::Vector(n) => write!(f, "felt[{n}]"), - Self::Matrix(rows, cols) => write!(f, "felt[{rows}, {cols}]"), - } - } -} - -/// Represents the type signature of a function -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum FunctionType { - /// An evaluator function, which has no results, and has - /// a complex type signature due to the nature of trace bindings - Evaluator(Vec), - /// A standard function with one or more inputs, and a result - Function(Vec, Type), -} -impl FunctionType { - pub fn result(&self) -> Option { - match self { - Self::Evaluator(_) => None, - Self::Function(_, result) => Some(*result), - } - } -} diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index f354cf6f4..5683b30b9 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -153,9 +153,14 @@ pub enum Token { Match, Case, When, - Felt, With, + // SCALAR TYPES + // -------------------------------------------------------------------------------------------- + Felt, + Bool, + UInt, + // PUNCTUATION // -------------------------------------------------------------------------------------------- Quote, @@ -196,6 +201,8 @@ impl Token { "ev" => Self::Ev, "fn" => Self::Fn, "felt" => Self::Felt, + "bool" => Self::Bool, + "uint" => Self::UInt, "buses" => Self::Buses, "multiset" => Self::Multiset, "logup" => Self::Logup, @@ -275,6 +282,8 @@ impl fmt::Display for Token { Self::Ev => write!(f, "ev"), Self::Fn => write!(f, "fn"), Self::Felt => write!(f, "felt"), + Self::Bool => write!(f, "bool"), + Self::UInt => write!(f, "uint"), Self::Buses => write!(f, "buses"), Self::Multiset => write!(f, "multiset"), Self::Logup => write!(f, "logup"), diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index c1351b31c..ddd4cc94f 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -249,9 +249,15 @@ FunctionBinding: (Identifier, Type) = { } FunctionBindingType: Type = { - "felt" => Type::Felt, - "felt" => Type::Vector(size as usize), - "felt" "[" "," "]" => Type::Matrix(row_size as usize, col_size as usize), + => Type::Scalar(sty), + => Type::Vector(sty, size as usize), + "[" "," "]" => Type::Matrix(sty, row_size as usize, col_size as usize), +} + +ScalarType: Option = { + "felt" => Some(ScalarType::Felt), + "uint" => Some(ScalarType::UInt), + "bool" => Some(ScalarType::Bool), } FunctionBody: Vec = { @@ -685,6 +691,8 @@ extern { "when" => Token::When, "with" => Token::With, "felt" => Token::Felt, + "uint" => Token::UInt, + "bool" => Token::Bool, "'" => Token::Quote, "=" => Token::Equal, "+" => Token::Plus, diff --git a/parser/src/parser/tests/constant_propagation.rs b/parser/src/parser/tests/constant_propagation.rs index faaeb9d92..6e37da6a7 100644 --- a/parser/src/parser/tests/constant_propagation.rs +++ b/parser/src/parser/tests/constant_propagation.rs @@ -78,22 +78,22 @@ fn test_constant_propagation() { // enf a.first = 1 expected .boundary_constraints - .push(enforce!(eq!(bounded_access!(a, Boundary::First, Type::Felt), int!(1)))); + .push(enforce!(eq!(bounded_access!(a, Boundary::First, ty!(felt).unwrap()), int!(1)))); // When constant propagation is done, the integrity constraints should look like: // enf test_constraint(b) // enf a + 4 = c + 5 expected .integrity_constraints - .push(enforce!(call!(lib::test_constraint(expr!(access!(b, Type::Vector(2))))))); + .push(enforce!(call!(lib::test_constraint(expr!(access!(b, ty!(felt[2]).unwrap())))))); expected.integrity_constraints.push(enforce!(eq!( - add!(access!(a, Type::Felt), int!(4)), - add!(access!(c, Type::Felt), int!(5)) + add!(access!(a, ty!(felt).unwrap()), int!(4)), + add!(access!(c, ty!(felt).unwrap()), int!(5)) ))); // The test_constraint function should look like: // enf b0 + 2 = b1 + 4 let body = vec![enforce!(eq!( - add!(access!(b0, Type::Felt), int!(2)), - add!(access!(b1, Type::Felt), int!(4)) + add!(access!(b0, ty!(felt).unwrap()), int!(2)), + add!(access!(b1, ty!(felt).unwrap()), int!(4)) ))]; expected.evaluators.insert( function_ident!(lib, test_constraint), diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs index 93e0560f2..5fb542a09 100644 --- a/parser/src/parser/tests/functions.rs +++ b/parser/src/parser/tests/functions.rs @@ -21,8 +21,8 @@ fn fn_def_with_scalars() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_scalars), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), access!(b))))], ), ); @@ -44,8 +44,8 @@ fn fn_def_with_vectors() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_vectors), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Vector(12), + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt[12]).unwrap(), vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!(access!(x), access!(y)))))], ), @@ -85,8 +85,8 @@ fn fn_use_scalars_and_vectors() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_scalars_and_vectors), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!( lc!(((x, expr!(access!(b)))) => add!(access!(a), access!(x))) )))))], @@ -152,8 +152,8 @@ fn fn_call_in_fn() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_vec), - vec![(ident!(a), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -163,8 +163,8 @@ fn fn_call_in_fn() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_scalar_and_vec), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); @@ -234,8 +234,8 @@ fn fn_call_in_ev() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_vec), - vec![(ident!(a), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -245,8 +245,8 @@ fn fn_call_in_ev() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_scalar_and_vec), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); @@ -319,8 +319,8 @@ fn fn_as_lc_iterables() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt).unwrap())], + ty!(felt).unwrap(), vec![let_!(x = expr!(add!(exp!(access!(a), access!(b)), int!(1))) => return_!(expr!(exp!(access!(b), access!(x)))))], ), @@ -390,8 +390,8 @@ fn fn_call_in_binary_ops() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!( lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( access!(x), @@ -466,8 +466,8 @@ fn fn_call_in_vector_def() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Vector(12), + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt[12]).unwrap(), vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( access!(x), access!(y) diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index 80bc49c60..d3a7eed70 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -698,15 +698,16 @@ fn full_air_file() { // enf clk' = clk + 1 // } expected.integrity_constraints.push(enforce!(eq!( - access!(clk, 1, Type::Felt), - add!(access!(clk, Type::Felt), int!(1)) + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), int!(1)) ))); // boundary_constraints { // enf clk.first = 0 // } - expected - .boundary_constraints - .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(clk, Boundary::First, ty!(felt).unwrap()), + int!(0) + ))); ParseTest::new().expect_program_ast_from_file("src/parser/tests/input/system.air", expected); } diff --git a/parser/src/parser/tests/modules.rs b/parser/src/parser/tests/modules.rs index f9a124f76..492147ebd 100644 --- a/parser/src/parser/tests/modules.rs +++ b/parser/src/parser/tests/modules.rs @@ -70,10 +70,10 @@ fn modules_integration_test() { vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce_if!(match_arm!( eq!( - access!(clk, 1, Type::Felt), - add!(access!(clk, Type::Felt), access!(bar, k0, Type::Felt)) + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), access!(bar, k0, ty!(felt).unwrap())) ), - access!(bar, k0, Type::Felt) + access!(bar, k0, ty!(felt).unwrap()) ))], ), ); @@ -87,8 +87,11 @@ fn modules_integration_test() { ident!(foo_constraint), vec![trace_segment!(TraceSegmentId::Main, "%0", [(clk, 1)])], vec![enforce_if!(match_arm!( - eq!(access!(clk, 1, Type::Felt), add!(access!(clk, Type::Felt), int!(1))), - access!(foo, k0, Type::Felt) + eq!( + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), int!(1)) + ), + access!(foo, k0, ty!(felt).unwrap()) ))], ), ); @@ -97,13 +100,14 @@ fn modules_integration_test() { .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected .integrity_constraints - .push(enforce!(call!(foo::foo_constraint(vector!(access!(clk, Type::Felt)))))); + .push(enforce!(call!(foo::foo_constraint(vector!(access!(clk, ty!(felt).unwrap())))))); expected .integrity_constraints - .push(enforce!(call!(bar::bar_constraint(vector!(access!(clk, Type::Felt)))))); - expected - .boundary_constraints - .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); + .push(enforce!(call!(bar::bar_constraint(vector!(access!(clk, ty!(felt).unwrap())))))); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(clk, Boundary::First, ty!(felt).unwrap()), + int!(0) + ))); ParseTest::new() .expect_program_ast_from_file("src/parser/tests/input/import_example.air", expected); diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 2659634ef..6d529a0b6 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,6 +1,10 @@ use std::fmt; -use crate::ast::{AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, Type}; +use air_types::*; + +use crate::ast::{ + Access, AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, TraceSegment, Type, +}; /// This type provides type and contextual information about a binding, /// i.e. not only does it tell us the type of a binding, but what type @@ -20,6 +24,7 @@ pub enum BindingType { /// /// The result type is None if the function is an evaluator Function(FunctionType), + Evaluator(Vec), /// A binding to a bus definition Bus(BusType), /// A function parameter corresponding to trace columns @@ -33,22 +38,43 @@ pub enum BindingType { /// A direct reference to a periodic column PeriodicColumn(usize), } -impl BindingType { + +impl Typing for BindingType { + fn kind(&self) -> Option { + match self { + Self::Alias(aliased) => aliased.kind(), + Self::Local(ty) => ty.kind(), + Self::Constant(ty) => ty.kind(), + Self::Function(func) => func.kind(), + Self::Evaluator(ev) => { + Some(Kind::Callable(FunctionType::Evaluator(ev.iter().map(|tb| tb.ty()).collect()))) + }, + Self::Bus(_) => self.ty().kind(), + Self::TraceColumn(tb) | Self::TraceParam(tb) => tb.kind(), + Self::Vector(elems) => elems.kind(), + Self::PublicInput(ty) => ty.kind(), + Self::PeriodicColumn(_) => Some(kind!(felt)), + } + } /// Get the value type of this binding, if applicable - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::TraceColumn(tb) | Self::TraceParam(tb) => Some(tb.ty()), - Self::Vector(elems) => Some(Type::Vector(elems.len())), + Self::TraceColumn(tb) | Self::TraceParam(tb) => tb.ty(), + Self::Vector(elems) => elems.ty(), Self::Alias(aliased) => aliased.ty(), Self::Local(ty) | Self::Constant(ty) | Self::PublicInput(ty) => Some(*ty), - Self::PeriodicColumn(_) => Some(Type::Felt), - Self::Function(ty) => ty.result(), - Self::Bus(_) => Some(Type::Felt), + Self::PeriodicColumn(_) => ty!(felt), + Self::Function(_) => None, + Self::Evaluator(_) => None, + Self::Bus(_) => None, } } +} +impl Access for BindingType { + type Accessed = Self; /// Produce a new [BindingType] which represents accessing the current binding via `access_type` - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match self { Self::Alias(aliased) => aliased.access(access_type), Self::Local(ty) => ty.access(access_type).map(Self::Local), @@ -81,19 +107,22 @@ impl BindingType { AccessType::Default => Ok(Self::PeriodicColumn(*period)), _ => Err(InvalidAccessError::IndexIntoScalar), }, - Self::Function(_) => Err(InvalidAccessError::InvalidBinding), + Self::Function(_) | Self::Evaluator(_) => Err(InvalidAccessError::InvalidBinding), Self::Bus(bus) => Ok(Self::Bus(*bus)), } } } + impl fmt::Display for BindingType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // TODO: Update to reflect the type signature match self { Self::Alias(aliased) => write!(f, "{aliased}"), Self::Local(_) => f.write_str("local"), Self::Constant(_) => f.write_str("constant"), Self::Vector(_) => f.write_str("vector"), Self::Function(_) => f.write_str("function"), + Self::Evaluator(_) => f.write_str("evaluator"), Self::TraceColumn(_) | Self::TraceParam(_) => f.write_str("trace column(s)"), Self::PublicInput(_) => f.write_str("public input(s)"), Self::PeriodicColumn(_) => f.write_str("periodic column(s)"), diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index aa48175c9..4258ccf15 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -168,7 +168,7 @@ impl VisitMut for SemanticAnalysis<'_> { name: Some(segment.name), offset: 0, size: segment.size, - ty: Type::Vector(segment.size), + ty: ty!(felt[segment.size]).unwrap(), }) ), None @@ -194,7 +194,7 @@ impl VisitMut for SemanticAnalysis<'_> { assert_eq!( self.locals.insert( NamespacedIdentifier::Binding(input.name()), - BindingType::PublicInput(Type::Vector(input.size())) + BindingType::PublicInput(ty!(felt[input.size()]).unwrap()) ), None ); @@ -215,7 +215,8 @@ impl VisitMut for SemanticAnalysis<'_> { } // It should be impossible for there to be a local by this name at this point assert_eq!( - self.locals.insert(namespaced_name, BindingType::Constant(constant.ty())), + self.locals + .insert(namespaced_name, BindingType::Constant(constant.ty().unwrap())), None ); } @@ -229,10 +230,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.declaration_import_conflict(namespaced_name.span(), prev.span())?; } assert_eq!( - self.locals.insert( - namespaced_name, - BindingType::Function(FunctionType::Evaluator(function.params.clone())) - ), + self.locals + .insert(namespaced_name, BindingType::Function(function.fn_ty.clone())), None ); } @@ -243,13 +242,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.declaration_import_conflict(namespaced_name.span(), prev.span())?; } assert_eq!( - self.locals.insert( - namespaced_name, - BindingType::Function(FunctionType::Function( - function.param_types(), - function.return_type - )) - ), + self.locals + .insert(namespaced_name, BindingType::Function(function.fn_ty.clone())), None ); } @@ -569,19 +563,30 @@ impl VisitMut for SemanticAnalysis<'_> { let iterable = &expr.iterables[i]; let iterable_ty = iterable.ty().unwrap(); - if let Some(expected_ty) = result_ty.replace(iterable_ty) - && expected_ty != iterable_ty - { + let lowest_common_supertype = if result_ty.is_some() { + result_ty.lowest_common_supertype(&iterable_ty) + } else { + // If the result type is None, then we use the iterable type as the default + // This means that either: + // - we encountered an error previously, + // - or this is the first iterable we are processing + Some(iterable_ty) + }; + if lowest_common_supertype.is_none() { + // If the lowest common supertype is None, and the result type is Some, + // then the types are incompatible self.has_type_errors = true; // Note: We don't break here but at the end of the module's compilation, as we // want to continue to gather as many errors as possible let _ = self.type_mismatch( - Some(&iterable_ty), + result_ty.as_ref(), iterable.span(), - &expected_ty, + &iterable_ty, expr.iterables[0].span(), expr.span(), ); + } else { + result_ty = lowest_common_supertype; } match self.expr_binding_type(iterable) { Ok(iterable_binding_ty) => { @@ -620,7 +625,7 @@ impl VisitMut for SemanticAnalysis<'_> { // If we were unable to determine a type for any of the bindings, use a large vector as a // placeholder - let expected = BindingType::Local(result_ty.unwrap_or(Type::Vector(u32::MAX as usize))); + let expected = BindingType::Local(result_ty.unwrap_or(ty!(_[u32::MAX as usize]).unwrap())); // Bind everything now, resolving any deferred types using our fallback expected type for (binding, _, binding_ty) in binding_tys.drain(..) { @@ -644,8 +649,8 @@ impl VisitMut for SemanticAnalysis<'_> { // Store the result type of this comprehension result_ty = match result_ty { - Some(Type::Vector(_)) => result_ty, - Some(Type::Matrix(rows, _)) => Some(Type::Vector(rows)), + Some(Type::Vector(..)) => result_ty, + Some(Type::Matrix(sty, rows, _)) => ty!(sty[rows]), _ => None, }; expr.ty = result_ty; @@ -665,7 +670,7 @@ impl VisitMut for SemanticAnalysis<'_> { match callee_binding_ty { Ok(ref binding_ty) => { let derived_from = binding_ty.span(); - if let BindingType::Function(ref fty) = binding_ty.item { + if let Some(Kind::Callable(ref fty)) = binding_ty.item.kind() { // There must be an evaluator by this name let qid = expr.callee.resolved().unwrap(); // Builtin functions are ignored here @@ -719,7 +724,7 @@ impl VisitMut for SemanticAnalysis<'_> { // * Must be trace bindings or aliases of same // * Must match the type signature of the callee if let Ok(ty) = callee_binding_ty - && let BindingType::Function(FunctionType::Evaluator(ref params)) = ty.item + && let BindingType::Evaluator(params) = ty.item { for (arg, param) in expr.args.iter().zip(params.iter()) { self.validate_evaluator_argument(expr.span(), arg, param)?; @@ -735,27 +740,69 @@ impl VisitMut for SemanticAnalysis<'_> { ) -> ControlFlow { self.visit_mut_scalar_expr(expr.lhs.as_mut())?; self.visit_mut_scalar_expr(expr.rhs.as_mut())?; - + let _ = expr.update_bin_ty(); // Validate the operand types - match (expr.lhs.ty(), expr.rhs.ty()) { - (Ok(Some(lty)), Ok(Some(rty))) => { - if lty != rty { - self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we - // want to continue to gather as many errors as possible - let _ = self.type_mismatch( - Some(<y), + match expr.infer_ty() { + Ok(None) => { + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid binary expression") + .with_primary_label(expr.span(), "unable to infer type for binary expression") + .with_secondary_label( expr.lhs.span(), - &rty, + format!("this expression has type: {}", expr.lhs.show_ty()), + ) + .with_secondary_label( expr.rhs.span(), - expr.span(), - ); - } + format!("this expression has type: {}", expr.rhs.show_ty()), + ) + .emit(); + ControlFlow::Continue(()) + }, + Err(err) => { + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid binary expression") + .with_primary_label(expr.span(), format!("{err}")) + .with_secondary_label( + expr.lhs.span(), + format!("this expression has type: {}", expr.lhs.show_ty()), + ) + .with_secondary_label( + expr.rhs.span(), + format!("this expression has type: {}", expr.rhs.show_ty()), + ) + .emit(); ControlFlow::Continue(()) }, - _ => ControlFlow::Continue(()), + Ok(_) => ControlFlow::Continue(()), } } + // match (expr.lhs.ty(), expr.rhs.ty()) { + // (Some(lty), Some(rty)) => { + // if lty != rty { + // self.has_type_errors = true; + // // Note: We don't break here but at the end of the module's compilation, as + // we // want to continue to gather as many errors as possible + // let _ = self.type_mismatch( + // Some(<y), + // expr.lhs.span(), + // &rty, + // expr.rhs.span(), + // expr.span(), + // ); + // } + // ControlFlow::Continue(()) + // }, + // _ => ControlFlow::Continue(()), + // } + // } fn visit_mut_range_bound( &mut self, @@ -799,8 +846,8 @@ impl VisitMut for SemanticAnalysis<'_> { .with_primary_label( expr.span(), format!( - "constant is not a valid range bound: expected scalar, got {}", - const_expr.ty() + "constant is not a valid range bound: expected uint, got {}", + const_expr.show_ty() ), ) .emit(); @@ -942,7 +989,9 @@ impl VisitMut for SemanticAnalysis<'_> { // be captured as a vector of size 1 AccessType::Slice(ref range) => { let range = range.to_slice_range(); - assert_eq!(expr.ty.replace(Type::Vector(range.len())), None) + let sty = expr.ty.scalar_ty(); + let new_ty = ty!(sty[range.len()]).unwrap(); + assert_eq!(expr.ty.replace(new_ty), None) }, // All other access types can be derived from the binding type _ => assert_eq!(expr.ty.replace(binding_ty.ty().unwrap()), None), @@ -958,14 +1007,14 @@ impl VisitMut for SemanticAnalysis<'_> { .with_secondary_label(derived_from, "references this declaration") .emit(); // Continue with a fabricated type - let ty = match &expr.access_type { + let new_ty = match &expr.access_type { AccessType::Slice(range) => { let range = range.to_slice_range(); - Type::Vector(range.len()) + ty!(felt[range.len()]).unwrap() }, - _ => Type::Felt, + _ => ty!(felt).unwrap(), }; - assert_eq!(expr.ty.replace(ty), None); + assert_eq!(expr.ty.replace(new_ty), None); ControlFlow::Continue(()) }, } @@ -1010,6 +1059,7 @@ impl VisitMut for SemanticAnalysis<'_> { // These binding types are module-local declarations BindingType::Constant(_) | BindingType::Function(_) + | BindingType::Evaluator(_) | BindingType::PeriodicColumn(_) => { *expr = ResolvableIdentifier::Resolved(QualifiedIdentifier::new( current_module, @@ -1127,6 +1177,56 @@ impl SemanticAnalysis<'_> { }, } }, + // The known built-in cast functions - each takes a single argument, which + // must be a subtype of the expected type + symbols::AsBool => { + match call.args.as_slice() { + [arg] => { + match self.expr_binding_type(arg) { + Ok(binding_ty) => { + if !binding_ty.ty().map(|t| t.is_scalar()).unwrap_or(false) { + self.has_type_errors = true; + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid call") + .with_primary_label( + call.span(), + "this function expects an argument of scalar type", + ) + .with_secondary_label( + arg.span(), + format!( + "but this argument is a {}", + binding_ty.show_kind() + ), + ) + .emit(); + } + }, + Err(e) => { + eprintln!("error: {e}"); + // We've already raised a diagnostic for this when visiting the + // access expression + assert!(self.has_undefined_variables || self.has_type_errors); + }, + } + }, + _ => { + self.has_type_errors = true; + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid call") + .with_primary_label( + call.span(), + format!( + "the callee expects a single argument, but got {}", + call.args.len() + ), + ) + .emit(); + }, + } + }, other => unimplemented!("unrecognized builtin function: {}", other), } ControlFlow::Continue(()) @@ -1227,9 +1327,9 @@ impl SemanticAnalysis<'_> { // Note: We don't break here but at the end of the module's compilation, // as we want to continue to gather as many errors as possible let _ = self.type_mismatch( - Some(&Type::Vector(param.size)), + ty!(_[param.size]).as_ref(), arg.span(), - &Type::Vector(size), + &ty!(_[size]).unwrap(), param.span(), span, ); @@ -1243,7 +1343,7 @@ impl SemanticAnalysis<'_> { param.id, 0, param.size, - Type::Vector(param.size), + ty!(felt[param.size]).unwrap(), )); // Note: We don't break here but at the end of the module's compilation, as // we want to continue to gather as many errors as possible @@ -1367,9 +1467,9 @@ impl SemanticAnalysis<'_> { } else { let inferred = tb.ty(); return self.type_mismatch( - Some(&inferred), + inferred.as_ref(), access.span(), - &Type::Felt, + &ty!(felt).unwrap(), ty.span(), constraint_span, ); @@ -1386,7 +1486,7 @@ impl SemanticAnalysis<'_> { TraceSegmentId::Main, 0, 1, - Type::Felt, + ty!(felt).unwrap(), )); return self.binding_mismatch( &aty, @@ -1484,7 +1584,7 @@ impl SemanticAnalysis<'_> { self.type_mismatch( Some(ty), access.span(), - &Type::Felt, + &ty!(_).unwrap(), found.span(), constraint_span, )?; @@ -1503,7 +1603,7 @@ impl SemanticAnalysis<'_> { self.type_mismatch( access.ty.as_ref(), access.span(), - &Type::Felt, + &ty!(_).unwrap(), access.name.span(), constraint_span, )?; @@ -1586,6 +1686,7 @@ impl SemanticAnalysis<'_> { None => { // If the call was resolved, it must be to an imported function, // and we will have already validated the reference + dbg!(&id); let (import_id, module_id) = self.imported.get_key_value(&id).unwrap(); let module = self.library.get(module_id).unwrap(); if !module.evaluators.contains_key(&id.id()) { @@ -1776,9 +1877,11 @@ impl SemanticAnalysis<'_> { fn expr_binding_type(&self, expr: &Expr) -> Result { match expr { - Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), + Expr::Const(constant) => { + Ok(BindingType::Local(constant.ty().expect("constant type should be known"))) + }, Expr::Range(range) => { - Ok(BindingType::Local(Type::Vector(range.to_slice_range().len()))) + Ok(BindingType::Local(ty!(uint[range.to_slice_range().len()]).unwrap())) }, Expr::Vector(elems) => { let mut binding_tys = Vec::with_capacity(elems.len()); @@ -1789,14 +1892,12 @@ impl SemanticAnalysis<'_> { Ok(BindingType::Vector(binding_tys)) }, Expr::Matrix(expr) => { - let rows = expr.len(); - let columns = expr[0].len(); - Ok(BindingType::Local(Type::Matrix(rows, columns))) + Ok(BindingType::Local(expr.ty().expect("matrix type should be known"))) }, Expr::SymbolAccess(expr) => self.access_binding_type(expr), Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding), Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), - Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)), + Expr::Binary(be) => Ok(BindingType::Local(be.ty().or(ty!(felt)).unwrap())), Expr::ListComprehension(lc) => { match lc.ty { Some(ty) => Ok(BindingType::Local(ty)), @@ -1816,8 +1917,9 @@ impl SemanticAnalysis<'_> { .emit(); Err(InvalidAccessError::InvalidBinding) }, - Expr::BusOperation(_expr) => Ok(BindingType::Local(Type::Felt)), - Expr::Null(_) | Expr::Unconstrained(_) => Ok(BindingType::Local(Type::Felt)), + Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => { + Ok(BindingType::Local(ty!(felt).unwrap())) + }, } } @@ -1882,10 +1984,14 @@ impl SemanticAnalysis<'_> { // it elsewhere. For the time being, functions are not // implemented, so the only place this comes up is with these // list folding builtins - let folder_ty = - FunctionType::Function(vec![Type::Vector(usize::MAX)], Type::Felt); + let folder_ty = FunctionType::Function(vec![ty!(felt[usize::MAX])], ty!(felt)); Ok(Span::new(qid.span(), BindingType::Function(folder_ty))) }, + symbols::AsBool => { + // An `as_bool(x)` is equivalent to an `enf x^2 = x and + // a cast from felt to bool`. + Ok(Span::new(qid.span(), BindingType::Function(fty!(fn(felt) -> bool)))) + }, name => unimplemented!("unsupported builtin: {}", name), } } else { @@ -1896,14 +2002,17 @@ impl SemanticAnalysis<'_> { imported_from .constants .get(qid.as_ref()) - .map(|c| Span::new(c.span(), BindingType::Constant(c.ty()))) + .map(|c| { + Span::new( + c.span(), + BindingType::Constant(c.ty().expect("constant type should be known")), + ) + }) .or_else(|| { - imported_from.evaluators.get(qid.as_ref()).map(|e| { - Span::new( - e.span(), - BindingType::Function(FunctionType::Evaluator(e.params.clone())), - ) - }) + imported_from + .evaluators + .get(qid.as_ref()) + .map(|e| Span::new(e.span(), BindingType::Evaluator(e.params.clone()))) }) .ok_or(InvalidAccessError::UndefinedVariable) } diff --git a/parser/src/symbols.rs b/parser/src/symbols.rs index 0b5022fc0..982a25561 100644 --- a/parser/src/symbols.rs +++ b/parser/src/symbols.rs @@ -17,9 +17,16 @@ pub mod predefined { pub const Sum: Symbol = Symbol::new(2); /// The symbol `prod` pub const Prod: Symbol = Symbol::new(3); - - pub(super) const __SYMBOLS: &[(Symbol, &str)] = - &[(Main, "$main"), (Builtin, "$builtin"), (Sum, "sum"), (Prod, "prod")]; + /// The symbol `as_bool` + pub const AsBool: Symbol = Symbol::new(4); + + pub(super) const __SYMBOLS: &[(Symbol, &str)] = &[ + (Main, "$main"), + (Builtin, "$builtin"), + (Sum, "sum"), + (Prod, "prod"), + (AsBool, "as_bool"), + ]; } pub use self::predefined::*; diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 098025220..cf7e3b900 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -402,6 +402,26 @@ impl VisitMut for ConstantPropagation<'_> { } } }, + symbols::AsBool => { + assert_eq!(call.args.len(), 1); + match &call.args[0] { + // If the assertion is a constant 0 or 1, it's valid + // TODO: if we start allowing casts from a uint to a bool, we should + // fold the assertion to a rebind of type bool if it is 0 or 1, + // and raise a diagnostic if it is not + Expr::Const(Span { item: ConstantExpr::Scalar(0 | 1), .. }) => {}, + // If the assertion is not 0 or 1, emit an error + Expr::Const(Span { item: ConstantExpr::Scalar(_), .. }) => { + self.diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("assertion failed") + .with_primary_label(span, "assertion failed") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + }, + _ => {}, + } + }, invalid => unimplemented!("unknown builtin function: {invalid}"), } ControlFlow::Continue(()) @@ -443,14 +463,8 @@ impl VisitMut for ConstantPropagation<'_> { } if is_constant { - let ty = match vector.first().and_then(|e| e.ty()).unwrap() { - Type::Felt => Type::Vector(vector.len()), - Type::Vector(n) => Type::Matrix(vector.len(), n), - _ => unreachable!(), - }; - - let new_expr = match ty { - Type::Vector(_) => ConstantExpr::Vector( + let new_expr = match vector.ty().expect("vector type must be known") { + Type::Vector(..) => ConstantExpr::Vector( vector .iter() .map(|expr| match expr { diff --git a/types/Cargo.toml b/types/Cargo.toml new file mode 100644 index 000000000..e7b546375 --- /dev/null +++ b/types/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "air-types" +version = "1.0.0" +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dependencies] +miden-diagnostics = { workspace = true } + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/types/src/lib.rs b/types/src/lib.rs new file mode 100644 index 000000000..cef20fee8 --- /dev/null +++ b/types/src/lib.rs @@ -0,0 +1,1025 @@ +mod types; + +use std::{ + cell::{Ref, RefCell, RefMut}, + fmt::Debug, + ops::{Deref, DerefMut}, +}; + +use miden_diagnostics::{SourceSpan, Span}; +pub use types::*; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum TypeError { + IncompatibleScalarTypes { + lhs: Option, + rhs: Option, + span: Option, + }, + IncompatibleShapes { + lhs: Option, + rhs: Option, + span: Option, + }, + IncompatibleType { + lhs: Option, + rhs: Option, + span: Option, + }, + TypeAlreadySet { + lhs: Option, + rhs: Option, + span: Option, + }, + NotASubtype { + lhs: Option, + rhs: Option, + span: Option, + }, + IncompatibleBinOp { + bin_ty: BinType, + span: Option, + }, + NonConstantExponent { + bin_ty: BinType, + span: Option, + }, +} + +impl core::fmt::Display for TypeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeError::IncompatibleScalarTypes { lhs, rhs, .. } => { + write!(f, "incompatible scalar types: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleShapes { lhs, rhs, .. } => { + write!(f, "incompatible shapes: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleType { lhs, rhs, .. } => { + write!(f, "incompatible types: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::TypeAlreadySet { lhs, rhs, .. } => { + write!(f, "type already set: {} vs {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::NotASubtype { lhs, rhs, .. } => { + write!(f, "type {} is not a subtype of {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleBinOp { bin_ty, .. } => { + write!(f, "incompatible types for binary operation: {}", bin_ty.show_fn_ty())?; + Ok(()) + }, + TypeError::NonConstantExponent { bin_ty, .. } => { + write!(f, "expected exponent to be a constant, got: {}", bin_ty.show_fn_ty())?; + Ok(()) + }, + } + } +} + +pub trait Typing { + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) + } + fn ty(&self) -> Option; + fn shape(&self) -> Option { + self.ty().and_then(|t| match t { + Type::Scalar(_) => ty!(_), + Type::Vector(_, len) => ty!(_[len]), + Type::Matrix(_, rows, cols) => ty!(_[rows, cols]), + }) + } + fn scalar_ty(&self) -> Option { + self.ty().scalar_ty() + } + fn ty_with_shape(&self, shape: impl Typing) -> Option { + let sty = self.scalar_ty(); + let shape = shape.shape(); + if sty.is_none() { + return shape; + } + match (self.scalar_ty().unwrap(), shape) { + (sty, None | Some(Type::Scalar(_))) => Some(Type::Scalar(Some(sty))), + (sty, Some(Type::Vector(_, len))) => Some(Type::Vector(Some(sty), len)), + (sty, Some(Type::Matrix(_, rows, cols))) => Some(Type::Matrix(Some(sty), rows, cols)), + } + } + fn is_scalar_felt(&self) -> bool { + matches!(self.scalar_ty(), sty!(felt)) + } + fn is_scalar_bool(&self) -> bool { + matches!(self.scalar_ty(), sty!(bool)) + } + fn is_scalar_int(&self) -> bool { + matches!(self.scalar_ty(), sty!(uint)) + } + fn is_scalar(&self) -> bool { + matches!(self.ty(), Some(Type::Scalar(_))) + } + fn is_vector(&self) -> bool { + matches!(self.ty(), Some(Type::Vector(_, _))) + } + fn is_matrix(&self) -> bool { + matches!(self.ty(), Some(Type::Matrix(_, _, _))) + } + /// Returns true if this type is an aggregate + #[inline] + fn is_aggregate(&self) -> bool { + self.is_vector() || self.is_matrix() + } + + /// Returns true if this type is a valid iterable in a comprehension + #[inline] + fn is_iterable(&self) -> bool { + self.is_vector() + } + /// Returns true if the shape of `self` is a sub-shape of the shape of `other` + /// The shape of `self` is a sub-shape of the shape of `other` if: + /// - other is `?` (None) + /// - both are scalars + /// - both are vectors of the same length + /// - both are vectors with one of the lengths being `u32::MAX` + /// - both are matrices with the same number of rows and columns + /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair + /// (if any) being equal + /// + /// self\\other || ? | _ | _[l] | _[r,c] + /// ============||===|===|======|======== + /// ? || y | n | n | n + /// _ || y | y | n | n + /// _[l] || y | n | y | n + /// _[r,c] || y | n | n | y + fn is_subshape(&self, other: &impl Typing) -> bool { + match (self.ty(), other.ty()) { + (_, None) => true, + (Some(Type::Scalar(_)), Some(Type::Scalar(_))) => true, + (Some(Type::Vector(_, len1)), Some(Type::Vector(_, len2))) => { + len1 == len2 || len1 == u32::MAX as usize || len2 == u32::MAX as usize + }, + (Some(Type::Matrix(_, rows1, cols1)), Some(Type::Matrix(_, rows2, cols2))) => { + (rows1 == rows2 || rows1 == u32::MAX as usize || rows2 == u32::MAX as usize) + && (cols1 == cols2 || cols1 == u32::MAX as usize || cols2 == u32::MAX as usize) + }, + _ => false, + } + } + + /// Returns true if the shape of `self` is compatible with the shape of `other` + /// The shapes are compatible if: + /// - either is `?` (None) + /// - both are scalars + /// - both are vectors of the same length + /// - both are vectors with one of the lengths being `u32::MAX` + /// - both are matrices with the same number of rows and columns + /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair + /// (if any) being equal + /// + /// self\\other || ? | _ | _[l] | _[r,c] + /// ============||===|===|======|======== + /// ? || y | y | y | y + /// _ || y | y | n | n + /// _[l] || y | n | y | n + /// _[r,c] || y | n | n | y + /// + /// This is a more relaxed version of [Typing::is_subshape], + /// allowing for bi-directional compatibility checks. The only + /// difference is that it allows for `self` to be `?` (None). + fn is_shape_compatible(&self, other: &impl Typing) -> bool { + self.ty().is_none() || self.is_subshape(other) + } + /// Returns true if `self` is a subtype of `other` + /// Notation: + /// _ : ScalaType::Scalar(None) + /// Unknown scalar type + /// felt: ScalarType::Felt + /// Felt type + /// bool: ScalarType::Bool + /// Boolean type + /// uint: ScalarType::UInt + /// Integer type + /// + /// Subtyping rules: + /// - _ > felt > bool + /// - _ > felt > uint + /// + /// Which means: + /// - all scalar types are subtypes of `_` + /// - `bool` is a subtype of `felt`: a `bool` is a `felt with a `is_bool` property + /// - `uint` is a subtype of `felt`: a `uint` is a `felt` with the `constant` property + /// + /// self\\other || _ | felt | bool | uint | + /// ============||===|======|======|======| + /// _ || y | n | n | n | + /// felt || y | y | n | n | + /// bool || y | y | y | n | + /// uint || y | y | n | y | + fn is_scalar_subtype(&self, other: &impl Typing) -> bool { + !matches!( + (self.scalar_ty(), other.scalar_ty()), + (sty!(_), sty!(felt) | sty!(bool) | sty!(uint)) + | (sty!(felt), sty!(bool) | sty!(uint)) + | (sty!(bool), sty!(uint)) + | (sty!(uint), sty!(bool)) + ) + } + /// Returns true if `self` is a subtype of `other` + /// Notation: + /// ?: None + /// Unknown type + /// _: Type::Scalar(None) + /// Unknown scalar type + /// felt: Type::Scalar(Some(ScalarType::Felt)) + /// Felt type + /// bool: Type::Scalar(Some(ScalarType::Bool)) + /// Boolean type + /// uint: Type::Scalar(Some(ScalarType::UInt)) + /// Integer type + /// sty[len]: Type::Vector(Some(sty), len) + /// Vector of length `len` with scalar type `sty` + /// sty[rows, cols]: Type::Matrix(Some(sty), rows, cols) + /// Matrix with `rows` and `cols` with scalar type `sty` + /// + /// Subtyping rules: + /// ? > _ > felt > bool + /// ? > _ > felt > uint + /// ? > _[l] > felt[l] > bool[l] + /// ? > _[l] > felt[l] > uint[l] + /// ? > _[r, c] > felt[r, c] > bool[r, c] + /// ? > _[r, c] > felt[r, c] > uint[r, c] + /// Assuming the shape of `self` is a sub-shape of the shape of `other`, + /// this function checks if `self` is a subtype of `other`, + /// with the added case of `?`, which all types are subtypes of. + /// See [Typing::is_scalar_subtype] for a more detailed explanation + /// of the subtyping rules of scalar types. + /// + /// self\\other || ? | _ | felt | bool | uint | + /// ============||===|===|======|======|======| + /// ? || y | n | n | n | n | + /// _ || y |[y | n | n | n]| + /// felt || y |[y | y | n | n]| + /// bool || y |[y | y | y | n]| + /// uint || y |[y | y | n | y]| + /// + /// = self.is_scalar_subtype(other) | other == ? + /// [...] Denotes the result of the [Typing::is_scalar_subtype] method. + fn is_subtype(&self, other: &impl Typing) -> bool { + self.is_subshape(other) && self.is_scalar_subtype(other) + } + fn show_kind(&self) -> Show> { + Show(self.kind()) + } + fn show_fn_ty(&self) -> Show> { + match self.kind() { + Some(Kind::Callable(fn_ty)) => Show(Some(fn_ty)), + _ => Show(None), + } + } + fn show_ty(&self) -> Show> { + Show(self.ty()) + } + fn show_scalar_ty(&self) -> Show> { + Show(self.scalar_ty()) + } + /// Returns the type of the current object, if it is known or can be inferred. + /// If the type is not known, it returns `None`. + /// If the type can be inferred, it returns the inferred type. + /// If the type cannot be inferred, it returns an appropriate error. + fn infer_ty(&self) -> Result, TypeError> { + Ok(self.ty()) + } + fn lowest_common_supertype(&self, other: &impl Typing) -> Option { + match (self.ty(), other.ty()) { + (ty!(?), _) | (_, ty!(?)) => ty!(?), + (ty!(_), Some(Type::Scalar(_))) | (Some(Type::Scalar(_)), ty!(_)) => ty!(_), + (Some(Type::Vector(sty!(_), llen)), Some(Type::Vector(_, rlen))) + | (Some(Type::Vector(_, llen)), Some(Type::Vector(sty!(_), rlen))) => { + ty!(_[llen.max(rlen)]) + }, + (Some(Type::Matrix(sty!(_), lrows, lcols)), Some(Type::Matrix(_, rrows, rcols))) + | (Some(Type::Matrix(_, lrows, lcols)), Some(Type::Matrix(sty!(_), rrows, rcols))) => { + ty!(_[lrows.max(rrows), lcols.max(rcols)]) + }, + (lhs, rhs) if lhs.is_subtype(&rhs) => rhs, + (lhs, rhs) if rhs.is_subtype(&lhs) => lhs, + (ty!(uint), ty!(bool)) | (ty!(bool), ty!(uint)) => ty!(felt), + (Some(Type::Vector(sty!(uint), llen)), Some(Type::Vector(sty!(bool), rlen))) + | (Some(Type::Vector(sty!(bool), llen)), Some(Type::Vector(sty!(uint), rlen))) => { + ty!(felt[core::cmp::max(llen, rlen)]) + }, + ( + Some(Type::Matrix(sty!(uint), lrows, lcols)), + Some(Type::Matrix(sty!(bool), rrows, rcols)), + ) + | ( + Some(Type::Matrix(sty!(bool), lrows, lcols)), + Some(Type::Matrix(sty!(uint), rrows, rcols)), + ) => { + ty!(felt[core::cmp::max(lrows, rrows), core::cmp::max(lcols, rcols)]) + }, + _ => None, + } + } +} + +pub trait ScalarTypeMut: Typing { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option); + fn update_scalar_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { + let ty = self.scalar_ty(); + if ty.is_none() { + // WARN: This should only be true before type inference + // Any None type should raise a diagnostic after type inference + self.update_scalar_ty_unchecked(new_ty); + } else if ty.is_scalar_subtype(&new_ty) { + // Allow widening of types + self.update_scalar_ty_unchecked(new_ty); + } else { + return Err(TypeError::IncompatibleScalarTypes { lhs: ty, rhs: new_ty, span: None }); + } + Ok(()) + } +} + +pub trait TypeMut: Typing + ScalarTypeMut { + fn update_ty_unchecked(&mut self, new_ty: Option); + fn update_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { + let ty = self.ty(); + if ty.is_none() { + // WARN: This should only be true before type inference + // Any None type should raise a diagnostic after type inference + self.update_ty_unchecked(new_ty); + } else if ty.is_subtype(&new_ty) { + // Allow widening of types + self.update_ty_unchecked(new_ty); + } else { + return Err(TypeError::NotASubtype { lhs: ty, rhs: new_ty, span: None }); + } + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Show(T); + +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("!"), + Some(kind) => write!(f, "{kind}"), + } + } +} + +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("?"), + Some(fn_ty) => write!(f, "{fn_ty}"), + } + } +} + +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("?"), + Some(ty) => write!(f, "{ty}"), + } + } +} + +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("_"), + Some(sty) => write!(f, "{sty}"), + } + } +} + +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "[{}]", + self.0.iter().map(|t| t.show_ty().to_string()).collect::>().join(", ") + ) + } +} + +impl Typing for Show { + fn kind(&self) -> Option { + self.0.kind() + } + fn ty(&self) -> Option { + self.0.ty() + } + fn scalar_ty(&self) -> Option { + self.0.scalar_ty() + } + fn show_kind(&self) -> Show> { + self.0.show_kind() + } + fn show_fn_ty(&self) -> Show> { + self.0.show_fn_ty() + } + fn show_ty(&self) -> Show> { + self.0.show_ty() + } + fn show_scalar_ty(&self) -> Show> { + self.0.show_scalar_ty() + } +} + +impl Typing for ScalarType { + fn ty(&self) -> Option { + Some(Type::Scalar(Some(*self))) + } + fn scalar_ty(&self) -> Option { + Some(*self) + } +} + +impl Typing for Type { + fn ty(&self) -> Option { + Some(*self) + } + fn scalar_ty(&self) -> Option { + match self { + Type::Scalar(st) => *st, + Type::Vector(st, _) => *st, + Type::Matrix(st, ..) => *st, + } + } +} + +impl ScalarTypeMut for Type { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + match self { + Type::Scalar(st) => *st = new_ty, + Type::Vector(st, _) => *st = new_ty, + Type::Matrix(st, ..) => *st = new_ty, + } + } +} + +impl Typing for FunctionType { + fn kind(&self) -> Option { + Some(Kind::Callable(self.clone())) + } + fn ty(&self) -> Option { + panic!("FunctionType does not have a concrete type") + } + fn scalar_ty(&self) -> Option { + panic!("FunctionType does not have a concrete scalar type") + } +} + +impl ScalarTypeMut for BinType { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.result_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for BinType { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.result_mut().update_ty_unchecked(new_ty); + } +} + +impl Typing for BinType { + fn ty(&self) -> Option { + self.infer_ty().ok()? + } + fn infer_ty(&self) -> Result, TypeError> { + match self { + BinType::Eq(.., Some(ret)) + | BinType::Add(.., Some(ret)) + | BinType::Sub(.., Some(ret)) + | BinType::Mul(.., Some(ret)) + | BinType::Exp(.., Some(ret)) => Ok(Some(*ret)), + BinType::Eq(.., None) => self.infer_bin_ty_eq(), + BinType::Add(.., None) => self.infer_bin_ty_add(), + BinType::Sub(.., None) => self.infer_bin_ty_sub(), + BinType::Mul(.., None) => self.infer_bin_ty_mul(), + BinType::Exp(.., None) => self.infer_bin_ty_exp(), + } + } +} + +impl ScalarTypeMut for Kind { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + match self { + Kind::Value(ty) => ty.update_scalar_ty_unchecked(new_ty), + Kind::Aggregate(_) => panic!("Cannot mutate scalar type of an aggregate kind"), + Kind::Callable(_) => panic!("Cannot mutate scalar type of a callable kind"), + } + } +} + +impl TypeMut for Kind { + fn update_ty_unchecked(&mut self, new_ty: Option) { + match self { + Kind::Value(ty) => ty.update_ty_unchecked(new_ty), + Kind::Aggregate(_) => panic!("Cannot mutate type of an aggregate kind"), + Kind::Callable(_) => panic!("Cannot mutate type of a callable kind"), + } + } +} + +impl Typing for Kind { + fn kind(&self) -> Option { + Some(self.clone()) + } + fn ty(&self) -> Option { + match self { + Kind::Value(ty) => *ty, + Kind::Aggregate(a) => { + let mut inner_ty = a.first().and_then(|t| t.ty()); + for item in a.iter().skip(1) { + let item_ty = item.ty(); + inner_ty = item_ty.lowest_common_supertype(&inner_ty); + } + match inner_ty { + None => None, + Some(Type::Scalar(st)) => ty!(st[a.len()]), + Some(Type::Vector(st, cols)) => ty!(st[a.len(), cols]), + Some(Type::Matrix(..)) => { + // An aggregate of matrices is not supported + None + }, + } + }, + Kind::Callable(_) => None, + } + } +} + +impl ScalarTypeMut for Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + *self = new_ty; + } +} + +impl ScalarTypeMut for Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + match self { + Some(Type::Scalar(st)) => *st = new_ty, + Some(Type::Vector(st, _)) => *st = new_ty, + Some(Type::Matrix(st, ..)) => *st = new_ty, + None => panic!("Cannot mutate scalar type of None"), + } + } +} + +impl TypeMut for Option { + fn update_ty_unchecked(&mut self, new_ty: Option) { + *self = new_ty; + } +} + +impl ScalarTypeMut for Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + match self { + Some(Kind::Value(ty)) => ty.update_scalar_ty_unchecked(new_ty), + Some(Kind::Aggregate(_)) => panic!("Cannot mutate scalar type of an aggregate kind"), + Some(Kind::Callable(_)) => panic!("Cannot mutate scalar type of a callable kind"), + None => panic!("Cannot mutate scalar type of None"), + } + } +} + +impl TypeMut for Option { + fn update_ty_unchecked(&mut self, new_ty: Option) { + match self { + Some(Kind::Value(ty)) => ty.update_ty_unchecked(new_ty), + Some(Kind::Aggregate(_)) => panic!("Cannot mutate type of an aggregate kind"), + Some(Kind::Callable(_)) => panic!("Cannot mutate type of a callable kind"), + None => panic!("Cannot mutate type of None"), + } + } +} + +impl Typing for Option { + fn kind(&self) -> Option { + self.as_ref().and_then(|t| t.kind()) + } + fn ty(&self) -> Option { + self.as_ref().and_then(|t| t.ty()) + } + fn scalar_ty(&self) -> Option { + self.as_ref().and_then(|t| t.scalar_ty()) + } +} + +impl Typing for Box { + fn kind(&self) -> Option { + T::kind(self) + } + fn ty(&self) -> Option { + T::ty(self) + } +} + +impl ScalarTypeMut for RefCell { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for RefCell { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_ty_unchecked(new_ty); + } +} + +impl Typing for RefCell { + fn kind(&self) -> Option { + self.borrow().kind() + } + fn ty(&self) -> Option { + self.borrow().ty() + } + fn scalar_ty(&self) -> Option { + self.borrow().scalar_ty() + } +} + +impl ScalarTypeMut for RefMut<'_, T> { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.deref_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for RefMut<'_, T> { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.deref_mut().update_ty_unchecked(new_ty); + } +} + +impl Typing for RefMut<'_, T> { + fn kind(&self) -> Option { + self.deref().kind() + } + fn ty(&self) -> Option { + self.deref().ty() + } + fn scalar_ty(&self) -> Option { + self.deref().scalar_ty() + } +} + +impl Typing for Ref<'_, T> { + fn kind(&self) -> Option { + self.deref().kind() + } + fn ty(&self) -> Option { + self.deref().ty() + } + fn scalar_ty(&self) -> Option { + self.deref().scalar_ty() + } +} + +impl ScalarTypeMut for Span { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.item.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Span { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.item.update_ty_unchecked(new_ty); + } +} + +impl Typing for Span { + fn kind(&self) -> Option { + self.item.kind() + } + fn ty(&self) -> Option { + self.item.ty() + } +} + +impl Typing for Vec { + fn kind(&self) -> Option { + let agg = self.iter().map(|t| t.kind().map(Box::new)).collect(); + Some(Kind::Aggregate(agg)) + } + fn ty(&self) -> Option { + self.kind().ty() + } +} + +#[macro_export] +macro_rules! assert_subtype { + ($a:expr; !$b:expr) => { + eprintln!("assert_subtype!({}; !{})", stringify!($a), stringify!($b)); + let res = !$crate::Typing::is_subtype(&$a, &$b); + assert!( + res, + "{}: !{}\nError: {} is a subtype of {}", + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + ); + }; + ($a:expr; $b:expr) => { + eprintln!("assert_subtype!({}; {})", stringify!($a), stringify!($b)); + let res = $crate::Typing::is_subtype(&$a, &$b); + assert!( + res, + "{}: {}\nError: {} is a not subtype of {}", + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + ); + }; +} + +#[cfg(test)] +mod tests { + use pretty_assertions::assert_eq; + + use super::*; + use crate::{sty, ty}; + + #[test] + fn test_typing() { + assert_eq!(ty!(?).ty(), None); + assert_eq!(ty!(?).scalar_ty(), sty!(_)); + assert_eq!(ty!(_).ty(), Some(Type::Scalar(sty!(_)))); + assert_eq!(ty!(_).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt).ty(), Some(Type::Scalar(sty!(felt)))); + assert_eq!(ty!(felt).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool).ty(), Some(Type::Scalar(sty!(bool)))); + assert_eq!(ty!(bool).scalar_ty(), sty!(bool)); + assert_eq!(ty!(uint).ty(), Some(Type::Scalar(sty!(uint)))); + assert_eq!(ty!(uint).scalar_ty(), sty!(uint)); + assert_eq!(ty!(_[5]).ty(), Some(Type::Vector(sty!(_), 5))); + assert_eq!(ty!(_[5]).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt[5]).ty(), Some(Type::Vector(sty!(felt), 5))); + assert_eq!(ty!(felt[5]).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool[5]).ty(), Some(Type::Vector(sty!(bool), 5))); + assert_eq!(ty!(bool[5]).scalar_ty(), sty!(bool)); + assert_eq!(ty!(uint[5]).ty(), Some(Type::Vector(sty!(uint), 5))); + assert_eq!(ty!(uint[5]).scalar_ty(), sty!(uint)); + assert_eq!(ty!(_[3, 4]).ty(), Some(Type::Matrix(sty!(_), 3, 4))); + assert_eq!(ty!(_[3, 4]).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt[3, 4]).ty(), Some(Type::Matrix(sty!(felt), 3, 4))); + assert_eq!(ty!(felt[3, 4]).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool[3, 4]).ty(), Some(Type::Matrix(sty!(bool), 3, 4))); + assert_eq!(ty!(bool[3, 4]).scalar_ty(), sty!(bool)); + assert_eq!(ty!(uint[3, 4]).ty(), Some(Type::Matrix(sty!(uint), 3, 4))); + assert_eq!(ty!(uint[3, 4]).scalar_ty(), sty!(uint)); + } + + #[test] + fn test_typing_subtype() { + assert_subtype!(ty!(?); ty!(?)); + assert_subtype!(ty!(?); !ty!(_)); + assert_subtype!(ty!(?); !ty!(felt)); + assert_subtype!(ty!(?); !ty!(bool)); + assert_subtype!(ty!(?); !ty!(uint)); + assert_subtype!(ty!(?); !ty!(_[5])); + assert_subtype!(ty!(?); !ty!(felt[5])); + assert_subtype!(ty!(?); !ty!(bool[5])); + assert_subtype!(ty!(?); !ty!(uint[5])); + assert_subtype!(ty!(?); !ty!(_[3, 4])); + assert_subtype!(ty!(?); !ty!(felt[3, 4])); + assert_subtype!(ty!(?); !ty!(bool[3, 4])); + assert_subtype!(ty!(?); !ty!(uint[3, 4])); + + assert_subtype!(ty!(_); ty!(?)); + assert_subtype!(ty!(_); ty!(_)); + assert_subtype!(ty!(_); !ty!(felt)); + assert_subtype!(ty!(_); !ty!(bool)); + assert_subtype!(ty!(_); !ty!(uint)); + assert_subtype!(ty!(_); !ty!(_[5])); + assert_subtype!(ty!(_); !ty!(felt[5])); + assert_subtype!(ty!(_); !ty!(bool[5])); + assert_subtype!(ty!(_); !ty!(uint[5])); + assert_subtype!(ty!(_); !ty!(_[3, 4])); + assert_subtype!(ty!(_); !ty!(felt[3, 4])); + assert_subtype!(ty!(_); !ty!(bool[3, 4])); + assert_subtype!(ty!(_); !ty!(uint[3, 4])); + + assert_subtype!(ty!(felt); ty!(?)); + assert_subtype!(ty!(felt); ty!(_)); + assert_subtype!(ty!(felt); ty!(felt)); + assert_subtype!(ty!(felt); !ty!(bool)); + assert_subtype!(ty!(felt); !ty!(uint)); + assert_subtype!(ty!(felt); !ty!(_[5])); + assert_subtype!(ty!(felt); !ty!(felt[5])); + assert_subtype!(ty!(felt); !ty!(bool[5])); + assert_subtype!(ty!(felt); !ty!(uint[5])); + assert_subtype!(ty!(felt); !ty!(_[3, 4])); + assert_subtype!(ty!(felt); !ty!(felt[3, 4])); + assert_subtype!(ty!(felt); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt); !ty!(uint[3, 4])); + + assert_subtype!(ty!(bool); ty!(?)); + assert_subtype!(ty!(bool); ty!(_)); + assert_subtype!(ty!(bool); ty!(felt)); + assert_subtype!(ty!(bool); ty!(bool)); + assert_subtype!(ty!(bool); !ty!(uint)); + assert_subtype!(ty!(bool); !ty!(_[5])); + assert_subtype!(ty!(bool); !ty!(felt[5])); + assert_subtype!(ty!(bool); !ty!(bool[5])); + assert_subtype!(ty!(bool); !ty!(uint[5])); + assert_subtype!(ty!(bool); !ty!(_[3, 4])); + assert_subtype!(ty!(bool); !ty!(felt[3, 4])); + assert_subtype!(ty!(bool); !ty!(bool[3, 4])); + assert_subtype!(ty!(bool); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint); ty!(?)); + assert_subtype!(ty!(uint); ty!(_)); + assert_subtype!(ty!(uint); ty!(felt)); + assert_subtype!(ty!(uint); !ty!(bool)); + assert_subtype!(ty!(uint); ty!(uint)); + assert_subtype!(ty!(uint); !ty!(_[5])); + assert_subtype!(ty!(uint); !ty!(felt[5])); + assert_subtype!(ty!(uint); !ty!(bool[5])); + assert_subtype!(ty!(uint); !ty!(uint[5])); + assert_subtype!(ty!(uint); !ty!(_[3, 4])); + assert_subtype!(ty!(uint); !ty!(felt[3, 4])); + assert_subtype!(ty!(uint); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint); !ty!(uint[3, 4])); + + assert_subtype!(ty!(_[5]); ty!(?)); + assert_subtype!(ty!(_[5]); !ty!(_)); + assert_subtype!(ty!(_[5]); !ty!(felt)); + assert_subtype!(ty!(_[5]); !ty!(bool)); + assert_subtype!(ty!(_[5]); !ty!(uint)); + assert_subtype!(ty!(_[5]); ty!(_[5])); + assert_subtype!(ty!(_[5]); !ty!(felt[5])); + assert_subtype!(ty!(_[5]); !ty!(bool[5])); + assert_subtype!(ty!(_[5]); !ty!(uint[5])); + assert_subtype!(ty!(_[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(felt[5]); ty!(?)); + assert_subtype!(ty!(felt[5]); !ty!(_)); + assert_subtype!(ty!(felt[5]); !ty!(felt)); + assert_subtype!(ty!(felt[5]); !ty!(bool)); + assert_subtype!(ty!(felt[5]); !ty!(uint)); + assert_subtype!(ty!(felt[5]); ty!(_[5])); + assert_subtype!(ty!(felt[5]); ty!(felt[5])); + assert_subtype!(ty!(felt[5]); !ty!(bool[5])); + assert_subtype!(ty!(felt[5]); !ty!(uint[5])); + assert_subtype!(ty!(felt[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(bool[5]); ty!(?)); + assert_subtype!(ty!(bool[5]); !ty!(_)); + assert_subtype!(ty!(bool[5]); !ty!(felt)); + assert_subtype!(ty!(bool[5]); !ty!(bool)); + assert_subtype!(ty!(bool[5]); !ty!(uint)); + assert_subtype!(ty!(bool[5]); ty!(_[5])); + assert_subtype!(ty!(bool[5]); ty!(felt[5])); + assert_subtype!(ty!(bool[5]); ty!(bool[5])); + assert_subtype!(ty!(bool[5]); !ty!(uint[5])); + assert_subtype!(ty!(bool[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint[5]); ty!(?)); + assert_subtype!(ty!(uint[5]); !ty!(_)); + assert_subtype!(ty!(uint[5]); !ty!(felt)); + assert_subtype!(ty!(uint[5]); !ty!(bool)); + assert_subtype!(ty!(uint[5]); !ty!(uint)); + assert_subtype!(ty!(uint[5]); ty!(_[5])); + assert_subtype!(ty!(uint[5]); ty!(felt[5])); + assert_subtype!(ty!(uint[5]); !ty!(bool[5])); + assert_subtype!(ty!(uint[5]); ty!(uint[5])); + assert_subtype!(ty!(uint[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(_[3, 4]); ty!(?)); + assert_subtype!(ty!(_[3, 4]); !ty!(_)); + assert_subtype!(ty!(_[3, 4]); !ty!(felt)); + assert_subtype!(ty!(_[3, 4]); !ty!(bool)); + assert_subtype!(ty!(_[3, 4]); !ty!(uint)); + assert_subtype!(ty!(_[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(uint[5])); + assert_subtype!(ty!(_[3, 4]); ty!(_[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(felt[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(felt[3, 4]); ty!(?)); + assert_subtype!(ty!(felt[3, 4]); !ty!(_)); + assert_subtype!(ty!(felt[3, 4]); !ty!(felt)); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool)); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint)); + assert_subtype!(ty!(felt[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint[5])); + assert_subtype!(ty!(felt[3, 4]); ty!(_[3, 4])); + assert_subtype!(ty!(felt[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(bool[3, 4]); ty!(?)); + assert_subtype!(ty!(bool[3, 4]); !ty!(_)); + assert_subtype!(ty!(bool[3, 4]); !ty!(felt)); + assert_subtype!(ty!(bool[3, 4]); !ty!(bool)); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint)); + assert_subtype!(ty!(bool[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint[5])); + assert_subtype!(ty!(bool[3, 4]); ty!(_[3, 4])); + assert_subtype!(ty!(bool[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(bool[3, 4]); ty!(bool[3, 4])); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint[3, 4]); ty!(?)); + assert_subtype!(ty!(uint[3, 4]); !ty!(_)); + assert_subtype!(ty!(uint[3, 4]); !ty!(felt)); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool)); + assert_subtype!(ty!(uint[3, 4]); !ty!(uint)); + assert_subtype!(ty!(uint[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(uint[5])); + assert_subtype!(ty!(uint[3, 4]); ty!(_[3, 4])); + assert_subtype!(ty!(uint[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint[3, 4]); ty!(uint[3, 4])); + } + + macro_rules! assert_ty_eq { + ($a:expr, $b:expr) => {{ + eprintln!("{}: {} == {}", $a.show_kind(), $a.show_ty(), $b.show_ty()); + assert_eq!( + $a.ty(), + $b, + "Expected {} to be equal to {}, but it was not", + $a.ty().show_ty(), + $b.show_ty(), + ); + }}; + } + #[track_caller] + fn assert_tys_eq_with_rev(a: Vec, b: Option) { + assert_ty_eq!(a, b); + assert_ty_eq!(a.iter().rev().cloned().collect::>(), b); + } + #[test] + fn test_vec_typing() { + assert_ty_eq!(vec![ty!(felt), ty!(felt), ty!(felt)], ty!(felt[3])); + assert_tys_eq_with_rev(tys!([uint, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, uint]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, uint]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, uint]), ty!(?)); + assert_tys_eq_with_rev(tys!([felt[5], felt[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([uint[5], felt[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([bool[5], uint[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([_[5], uint[5]]), ty!(_[2, 5])); + assert_tys_eq_with_rev(tys!([bool[3], uint[8]]), ty!(felt[2, 8])); + assert_tys_eq_with_rev(tys!([_[3], uint[8]]), ty!(_[2, 8])); + assert_tys_eq_with_rev(tys!([?, uint[5]]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5, 2], _]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5, 2]]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, uint]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, uint]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, uint]), ty!(?)); + + assert_tys_eq_with_rev( + vec![tys!([uint, felt]), tys!([uint, felt]), tys!([uint, felt])], + ty!(felt[3, 2]), + ); + assert_tys_eq_with_rev( + vec![tys!([bool, uint]), tys!([bool, uint]), tys!([bool, uint])], + ty!(felt[3, 2]), + ); + assert_tys_eq_with_rev( + vec![tys!([_, uint]), tys!([_, uint]), tys!([_, uint])], + ty!(_[3, 2]), + ); + assert_tys_eq_with_rev(vec![tys!([?, uint]), tys!([?, uint]), tys!([?, uint])], ty!(?)); + assert_tys_eq_with_rev(tys!([felt[5], uint[5], bool[5]]), ty!(felt[3, 5])); + } +} diff --git a/types/src/types.rs b/types/src/types.rs new file mode 100644 index 000000000..d513f5ac3 --- /dev/null +++ b/types/src/types.rs @@ -0,0 +1,900 @@ +use crate::{TypeError, Typing}; + +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum ScalarType { + Felt, + Bool, + UInt, +} + +impl core::fmt::Display for ScalarType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Felt => f.write_str("felt"), + Self::Bool => f.write_str("bool"), + Self::UInt => f.write_str("uint"), + } + } +} + +#[macro_export] +macro_rules! sty { + // for pattern matching + // equivalent to a `_` in a match or let expression + (any) => { + _ + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any: $name:ident) => { + $name + }; + (_) => { + None + }; + (felt) => { + Some($crate::ScalarType::Felt) + }; + (bool) => { + Some($crate::ScalarType::Bool) + }; + (uint) => { + Some($crate::ScalarType::UInt) + }; + ($sty:ident) => { + $sty + }; +} + +/// The types of values which can be represented in an AirScript program +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum Type { + // annotation: sty + // where sty is the scalar type + Scalar(Option), + // annotation: `sty[len]` + // where len is the number of elements in the vector, + // and sty is the scalar type + Vector(Option, usize), + // annotation: `sty[rows, cols]` + // where rows and cols are the dimensions of the matrix, + // and sty is the scalar type + Matrix(Option, usize, usize), +} + +impl core::fmt::Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Scalar(None) => f.write_str("_"), + Self::Vector(None, len) => write!(f, "_[{len}]"), + Self::Matrix(None, rows, cols) => write!(f, "_[{rows}, {cols}]"), + Self::Scalar(Some(sty)) => f.write_str(&sty.to_string()), + Self::Vector(Some(sty), len) => write!(f, "{sty}[{len}]"), + Self::Matrix(Some(sty), rows, cols) => write!(f, "{sty}[{rows}, {cols}]"), + } + } +} + +#[macro_export] +macro_rules! ty { + // for pattern matching + // equivalent to a `_` in a match or let expression + (any) => { + _ + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any: $name:ident) => { + $name + }; + (?) => { + None::<$crate::Type> + }; + (_) => { + Some($crate::Type::Scalar(None)) + }; + ($sty:ident) => { + Some($crate::Type::Scalar($crate::sty!($sty))) + }; + (_[$len:expr]) => { + Some($crate::Type::Vector($crate::sty!(_), $len)) + }; + ($sty:ident[$len:expr]) => { + Some($crate::Type::Vector($crate::sty!($sty), $len)) + }; + (_[$rows:expr, $cols:expr]) => { + Some($crate::Type::Matrix($crate::sty!(_), $rows, $cols)) + }; + ($sty:ident[$rows:expr, $cols:expr]) => { + Some($crate::Type::Matrix($crate::sty!($sty), $rows, $cols)) + }; +} + +pub struct Push(pub Vec); +impl Push { + pub fn push(mut self, ty: T) -> Self { + self.0.push(ty); + self + } +} + +#[macro_export] +macro_rules! tys { + ([$($args:tt)+]) => { + tys!(RES: Push(vec![]); $($args)+).0 + }; + (RES: $res:expr; ) => { + $res + }; + (RES: $res:expr; ? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!(?)); $($($rest)+)?) + }; + (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!(_$([$($spec)+])?)); $($($rest)+)?) + }; + (RES: $res:expr; $name:ident$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!($name$([$($spec)+])?)); $($($rest)+)?) + }; +} + +#[macro_export] +macro_rules! kinds { + ([$($args:tt)+]) => { + kinds!(RES: Push(vec![]); $($args)+).0 + }; + (RES: $res:expr; ) => { + $res + }; + (RES: $res:expr; ?) => { + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!(?))));) + }; + (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!(_$([$($spec)+])?)))); $($($rest)+)?) + }; + (RES: $res:expr; $name:ident$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!($name$([$($spec)+])?)))); $($($rest)+)?) + }; +} + +#[macro_export] +macro_rules! tty { + ([$($n1:ident$([$l1:expr])?),*]) => { + Vec::>::from([ + $($crate::tty!($n1$([$l1])?)),* + ]) + }; + ($name:ident[$len:expr]) => { + match $len { + 1 => $crate::ty!(felt), + _ => $crate::ty!(felt[$len]), + } + }; + ($name:ident) => { + $crate::ty!(felt) + }; +} + +/// Represents the type signature of a function +#[derive(Hash, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum FunctionType { + /// An evaluator function, which has no results, and has + /// a complex type signature due to the nature of trace bindings + Evaluator(Vec>), + /// A standard function with one or more inputs, and a result + Function(Vec>, Option), +} + +impl Default for FunctionType { + fn default() -> Self { + Self::Evaluator(vec![]) + } +} + +impl FunctionType { + pub fn params(&self) -> &[Option] { + match self { + Self::Evaluator(params) => params, + Self::Function(params, _) => params, + } + } + + pub fn result(&self) -> Option { + match self { + Self::Evaluator(_) => None, + Self::Function(_, ret) => *ret, + } + } + + pub fn check_args_kinds(&self, args: &[&Kind]) -> bool { + match self { + Self::Function(params, _) => { + if params.len() != args.len() { + return false; + } + for (arg_ty, param_kind) in args.iter().zip(params.iter()) { + if !arg_ty.is_subtype(param_kind) { + return false; + } + } + true + }, + Self::Evaluator(params) => { + // Only check that the number of columns match + // since evaluator arguments get matched as a vec of felt + let mut params_len = 0; + for param in params.iter() { + match param { + Some(Type::Scalar(_)) => params_len += 1, + Some(Type::Vector(_, len)) => params_len += *len, + Some(Type::Matrix(_, _, _)) => { + return false; + }, + None => return false, + } + } + let mut args_len = 0; + for arg in args.iter() { + match arg.ty() { + Some(Type::Scalar(_)) => args_len += 1, + Some(Type::Vector(_, len)) => args_len += len, + Some(Type::Matrix(_, _, _)) => { + return false; + }, + None => return false, + } + } + if params_len != args_len { + return false; + } + true + }, + } + } +} + +impl core::fmt::Display for FunctionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Evaluator(args) => { + f.write_str("ev(")?; + write!( + f, + "[{}]", + args.iter().map(|ty| ty.show_ty().to_string()).collect::>().join(", ") + )?; + f.write_str(")") + }, + Self::Function(args, ret) => { + f.write_str("fn(")?; + f.write_str( + &args.iter().map(|ty| ty.show_ty().to_string()).collect::>().join(", "), + )?; + f.write_str(") -> ")?; + if let Some(ret_type) = ret { + write!(f, "{ret_type}") + } else { + f.write_str("?") + } + }, + } + } +} + +#[macro_export] +macro_rules! fty { + (ev ([])) => { + $crate::FunctionType::Evaluator(vec![]) + }; + (ev ([$($tty:tt)+])) => { + $crate::FunctionType::Evaluator($crate::tty!([$($tty)+])) + }; + (fn ($($arg:tt)*) -> $($ret:tt)+) => { + $crate::FunctionType::Function(tys!([$($arg)*]), $crate::ty!($($ret)+)) + }; +} + +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub enum BinType { + Eq(Option, Option, Option), + Add(Option, Option, Option), + Sub(Option, Option, Option), + Mul(Option, Option, Option), + Exp(Option, Option, Option), +} + +impl Default for BinType { + fn default() -> Self { + Self::Eq(None, None, None) + } +} + +impl BinType { + pub fn lhs(&self) -> Option { + match self { + Self::Eq(lhs, ..) + | Self::Add(lhs, ..) + | Self::Sub(lhs, ..) + | Self::Mul(lhs, ..) + | Self::Exp(lhs, ..) => *lhs, + } + } + + pub fn lhs_mut(&mut self) -> &mut Option { + match self { + Self::Eq(lhs, ..) + | Self::Add(lhs, ..) + | Self::Sub(lhs, ..) + | Self::Mul(lhs, ..) + | Self::Exp(lhs, ..) => lhs, + } + } + + pub fn rhs(&self) -> Option { + match self { + Self::Eq(_, rhs, _) + | Self::Add(_, rhs, _) + | Self::Sub(_, rhs, _) + | Self::Mul(_, rhs, _) + | Self::Exp(_, rhs, _) => *rhs, + } + } + + pub fn rhs_mut(&mut self) -> &mut Option { + match self { + Self::Eq(_, rhs, _) + | Self::Add(_, rhs, _) + | Self::Sub(_, rhs, _) + | Self::Mul(_, rhs, _) + | Self::Exp(_, rhs, _) => rhs, + } + } + + pub fn result(&self) -> Option { + match self { + Self::Eq(_, _, ret) + | Self::Add(_, _, ret) + | Self::Sub(_, _, ret) + | Self::Mul(_, _, ret) + | Self::Exp(_, _, ret) => *ret, + } + } + + pub fn result_mut(&mut self) -> &mut Option { + match self { + Self::Eq(_, _, ret) + | Self::Add(_, _, ret) + | Self::Sub(_, _, ret) + | Self::Mul(_, _, ret) + | Self::Exp(_, _, ret) => ret, + } + } + + pub fn as_fn(&self) -> FunctionType { + match self { + Self::Eq(lhs, rhs, ret) + | Self::Add(lhs, rhs, ret) + | Self::Sub(lhs, rhs, ret) + | Self::Mul(lhs, rhs, ret) + | Self::Exp(lhs, rhs, ret) => FunctionType::Function(vec![*lhs, *rhs], *ret), + } + } + /// Returns a new [BinType] with all types casted to their [Type::Scalar] equivalent: + /// - `?` -> `_` + /// - `sty` -> `sty` + /// - `sty[len]` -> `sty` + /// - `sty[rows, cols]` -> `sty` + /// + /// This corresponds to the shape `_`. + pub fn without_shape(&self) -> Self { + match self { + Self::Eq(lhs, rhs, ret) => Self::Eq( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Add(lhs, rhs, ret) => Self::Add( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Sub(lhs, rhs, ret) => Self::Sub( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Mul(lhs, rhs, ret) => Self::Mul( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Exp(lhs, rhs, ret) => Self::Exp( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + } + } +} + +impl core::fmt::Display for BinType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Eq(lhs, rhs, None) => write!(f, "{} = {}", lhs.show_ty(), rhs.show_ty()), + Self::Add(lhs, rhs, None) => write!(f, "{} + {}", lhs.show_ty(), rhs.show_ty()), + Self::Sub(lhs, rhs, None) => write!(f, "{} - {}", lhs.show_ty(), rhs.show_ty()), + Self::Mul(lhs, rhs, None) => write!(f, "{} * {}", lhs.show_ty(), rhs.show_ty()), + Self::Exp(lhs, rhs, None) => write!(f, "{} ^ {}", lhs.show_ty(), rhs.show_ty()), + Self::Eq(lhs, rhs, ret) => { + write!(f, "{} = {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Add(lhs, rhs, ret) => { + write!(f, "{} + {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Sub(lhs, rhs, ret) => { + write!(f, "{} - {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Mul(lhs, rhs, ret) => { + write!(f, "{} * {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Exp(lhs, rhs, ret) => { + write!(f, "{} ^ {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + } + } +} + +#[macro_export] +macro_rules! bty { + ($($bty:tt)+ -> $($ret:tt)+) => {{ + let b = $crate::bty!($($bty)+); + b.result_mut().replace($crate::ty!($($ret)+)); + b + }}; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; +} + +impl BinType { + /// Returns the type of the result of an equality based on the types + /// of the left-hand side and right-hand side operands. + /// If the types are not compatible, it returns a [TypeError::IncompatibleBinOp]. + /// + /// Assuming shapes are compatible, the following table shows the result type + /// based on the scalar types of the operands: + /// ? == ? || felt | bool | uint | _ | ? + /// =========||======|======|======|======|===== + /// felt || bool | bool | bool | bool | ? + /// bool || bool | bool | bool | bool | ? + /// uint || bool | bool | bool | bool | ? + /// _ || bool | bool | bool | bool | ? + /// ? || ? | ? | ? | ? | ? + /// + /// So, the result type of an equality is: + /// - an error if lhs or rhs don't have a compatible shape, + /// - symmetric over the operands, + /// - any == ? -> ?, + /// - always `bool` otherwise + pub fn infer_bin_ty_eq(&self) -> Result, TypeError> { + if let Some(ret) = self.result() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if lhs.is_none() || rhs.is_none() { + return Ok(ty!(?)); + } + if self.lhs().is_shape_compatible(&self.rhs()) { + Ok(ty!(bool)) + } else { + Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }) + } + } + + /// Returns the type of the result of an addition based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? + ? || felt | bool | uint | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | felt | felt | felt | felt + /// uint || felt | felt | uint | _ | ? + /// _ || felt | felt | _ | _ | ? + /// ? || felt | felt | ? | ? | ? + /// + /// So, the result type of an addition is: + /// - an error if lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt + any -> felt + /// - bool + any -> felt + /// - ? + any -> ? + /// - uint + uint -> uint + /// - everything else is an unknown scalar type `_` + pub fn infer_bin_ty_add(&self) -> Result, TypeError> { + if let Some(ret) = self.result() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleShapes { lhs, rhs, span: None }); + } + match self { + bty!(felt + any) | bty!(any + felt) => Ok(ty!(felt)), + bty!(bool + any) | bty!(any + bool) => Ok(ty!(felt)), + bty!(? + any) | bty!(any + ?) => Ok(ty!(?)), + bty!(uint + uint) => Ok(ty!(uint)), + _ => Ok(ty!(_)), + } + } + + /// Returns the type of the result of a substraction based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? - ? || felt | bool | uint | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | felt | felt | felt | felt + /// uint || felt | felt | uint | _ | ? + /// _ || felt | felt | _ | _ | ? + /// ? || felt | felt | ? | ? | ? + /// + /// So, the result type of a substraction is: + /// - an error if either lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt - any -> felt + /// - bool - any -> felt + /// - uint - uint -> uint + /// - ? - any -> ? + /// - everything else is an unknown scalar type `_` + /// + /// This is the same as [BinType::infer_bin_ty_add], so it reuses that method. + /// + /// NOTE: if we refine the types as described in #432, this method will need to be + /// updated to handle the substraction of `bool` and `uint` types correctly. + /// This will no longer be symmetric over the operands! + /// Because: + /// - 0 - bool = - bool -> felt + /// - bool - 0 -> bool + /// - 0 - uint = - uint -> uint (or error depending on the design) + /// - uint - 0 -> uint + /// - 1 - bool -> bool + /// - bool - 1 -> felt + pub fn infer_bin_ty_sub(&self) -> Result, TypeError> { + self.infer_bin_ty_add() + } + + /// Returns the type of the result of a multiplication based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? * ? || felt | bool | uint | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | bool | uint | _ | ? + /// uint || felt | uint | uint | _ | ? + /// _ || felt | _ | _ | _ | ? + /// ? || felt | ? | ? | ? | ? + /// + /// So, the result type of a multiplication is: + /// - an error if either lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt * any -> felt + /// - ? * any -> ? + /// - _ * any -> _ + /// - uint * uint -> uint + /// - bool * x -> x + /// - everything else is an unknown scalar type `_` + pub fn infer_bin_ty_mul(&self) -> Result, TypeError> { + if let Some(ret) = self.result() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleShapes { lhs, rhs, span: None }); + } + match self { + bty!(felt * any) | bty!(any * felt) => Ok(ty!(felt)), + bty!(? * any) | bty!(any * ?) => Ok(ty!(?)), + bty!(_ * any) | bty!(any * _) => Ok(ty!(_)), + bty!(uint * uint) => Ok(ty!(uint)), + bty!(bool * any:x) | bty!(any:x * bool) => Ok(*x), + _ => Ok(ty!(_)), + } + } + + /// Returns the type of the result of an exponentiation based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs is not a scalar type, or rhs is not `uint`, + /// it returns a [TypeError::IncompatibleBinOp]. + /// + /// based on the scalar types of the operands: + /// ? ^ ? || felt | bool | uint | _ | ? + /// =========||======|======|======|======|===== + /// felt || err | err | felt | err | err + /// bool || err | err | bool | err | err + /// uint || err | err | uint | err | err + /// _ || err | err | _ | err | err + /// ? || err | err | ? | err | err + /// + /// So, the result type of an exponentiation is: + /// - an error if either lhs or rhs isn't scalar types, + /// - an error if the rhs is not an uint + /// - the lhs type otherwise + /// + /// Because: + /// - it is an error if rhs is not an uint + /// - a bool to any power is still a bool: + /// - 0^n = 0 + /// - 1^n = 1 + /// - a felt to any power is still a felt + /// - an uint to any power is still an uint + /// - a _ to any power is still a _ + /// - a ? to any power is still a ? + pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { + if let Some(ret) = self.result() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }); + } + match self { + bty!(any ^ uint) => Ok(lhs), + bty!(any ^ felt) | bty!(any ^ bool) | bty!(any ^ _) | bty!(any ^ ?) => { + Err(TypeError::NonConstantExponent { bin_ty: *self, span: None }) + }, + _ => unreachable!("Undefined case for infer_bin_ty_exp: {self}"), + } + } +} + +#[derive(Hash, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] +pub enum Kind { + Value(Option), + Aggregate(Vec>>), + Callable(FunctionType), +} + +impl Default for Kind { + fn default() -> Self { + Self::Value(None) + } +} + +impl core::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Value(ty) => write!(f, "{}", ty.show_ty()), + Self::Aggregate(tys) => { + write!( + f, + "[{}]", + tys.iter() + .map(|ty| ty + .as_ref() + .map_or("?".to_string(), |k| k.show_kind().to_string())) + .collect::>() + .join(", ") + ) + }, + Self::Callable(fty) => write!(f, "{}", fty.show_fn_ty()), + } + } +} + +#[macro_export] +macro_rules! kind { + (ev $($spec:tt)+) => { + $crate::Kind::Callable($crate::fty!(ev $($spec)+)) + }; + (fn ($($args:tt)*) -> $($ret:tt)+) => { + $crate::Kind::Callable($crate::fty!(fn ($($args)*) -> $($ret)+)) + }; + ([$($spec:tt)+]) => { + $crate::Kind::Aggregate(kinds!([$($spec)+])) + }; + ($($spec:tt)+) => { + $crate::Kind::Value($crate::ty!($($spec)+)) + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_macro_scalar_type() { + assert_eq!(sty!(_), None::); + assert_eq!(sty!(felt), Some(ScalarType::Felt)); + assert_eq!(sty!(bool), Some(ScalarType::Bool)); + assert_eq!(sty!(uint), Some(ScalarType::UInt)); + } + + #[test] + fn test_macro_type() { + assert_eq!(ty!(?), None::); + assert_eq!(ty!(_), Some(Type::Scalar(None))); + assert_eq!(ty!(felt), Some(Type::Scalar(Some(ScalarType::Felt)))); + assert_eq!(ty!(bool), Some(Type::Scalar(Some(ScalarType::Bool)))); + assert_eq!(ty!(uint), Some(Type::Scalar(Some(ScalarType::UInt)))); + assert_eq!(ty!(_[5]), Some(Type::Vector(None, 5))); + assert_eq!(ty!(uint[5]), Some(Type::Vector(Some(ScalarType::UInt), 5))); + assert_eq!(ty!(_[3, 4]), Some(Type::Matrix(None, 3, 4))); + assert_eq!(ty!(felt[3, 4]), Some(Type::Matrix(Some(ScalarType::Felt), 3, 4))); + } + + #[test] + fn test_macro_trace_segment_type() { + assert_eq!(tty!(a), ty!(felt)); + assert_eq!(tty!(a[5]), ty!(felt[5])); + assert_eq!(tty!([]), Vec::>::new()); + assert_eq!(tty!([a]), vec![ty!(felt)]); + assert_eq!(tty!([a[5]]), vec![ty!(felt[5])]); + assert_eq!(tty!([a[1], b[3]]), vec![ty!(felt), ty!(felt[3])]); + } + + #[test] + fn test_macro_function_type() { + assert_eq!(fty!(ev([])), FunctionType::Evaluator(vec![])); + assert_eq!(fty!(ev([a])), FunctionType::Evaluator(vec![ty!(felt)])); + assert_eq!(fty!(ev([a[5]])), FunctionType::Evaluator(vec![ty!(felt[5])])); + assert_eq!(fty!(ev([a, b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); + assert_eq!(fty!(ev([a[1], b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); + assert_eq!(fty!(ev([a[1], b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); + + assert_eq!(fty!(fn(uint) -> felt), FunctionType::Function(vec![ty!(uint)], ty!(felt))); + assert_eq!( + fty!(fn(uint[5]) -> felt[3, 4]), + FunctionType::Function(vec![ty!(uint[5])], ty!(felt[3, 4]),) + ); + assert_eq!( + fty!(fn(uint[5], felt) -> felt[3, 4]), + FunctionType::Function(vec![ty!(uint[5]), ty!(felt)], ty!(felt[3, 4]),) + ); + assert_eq!( + fty!(fn(uint[5], felt, bool[3, 4]) -> felt[3, 4]), + FunctionType::Function( + vec![ty!(uint[5]), ty!(felt), ty!(bool[3, 4]),], + ty!(felt[3, 4]), + ) + ); + } + + #[test] + fn test_macro_bin_type() { + assert_eq!(bty!(uint + felt), BinType::Add(ty!(uint), ty!(felt), ty!(?))); + assert_eq!(bty!(_ - felt), BinType::Sub(ty!(_), ty!(felt), ty!(?))); + assert_eq!(bty!(? = felt), BinType::Eq(ty!(?), ty!(felt), ty!(?))); + assert_eq!(bty!(uint + ?), BinType::Add(ty!(uint), ty!(?), ty!(?))); + assert_eq!(bty!(uint - felt), BinType::Sub(ty!(uint), ty!(felt), ty!(?))); + assert_eq!(bty!(uint[2] * felt[2]), BinType::Mul(ty!(uint[2]), ty!(felt[2]), ty!(?))); + assert_eq!(bty!(uint[2, 3] ^ _), BinType::Exp(ty!(uint[2, 3]), ty!(_), ty!(?))); + assert_eq!(bty!(bool[5] = _[5]), BinType::Eq(ty!(bool[5]), ty!(_[5]), ty!(?))); + } + + #[test] + fn test_macro_kind() { + assert_eq!(kind!(ev([])), Kind::Callable(fty!(ev([])))); + assert_eq!(kind!(ev([a])), Kind::Callable(fty!(ev([a])))); + assert_eq!(kind!(fn(uint) -> felt), Kind::Callable(fty!(fn(uint) -> felt))); + assert_eq!(kind!(uint), Kind::Value(ty!(uint))); + assert_eq!(kind!(_), Kind::Value(ty!(_))); + assert_eq!(kind!(bool[3, 4]), Kind::Value(ty!(bool[3, 4]))); + } + + #[test] + fn test_fn_ty_check_param_kinds() { + // Scalar types + assert!(fty!(fn(uint, felt) -> felt).check_args_kinds(&[&kind!(uint), &kind!(felt)]),); + assert!(fty!(fn(felt, felt) -> felt).check_args_kinds(&[&kind!(felt), &kind!(bool)]),); + // Vector types + assert!(fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(_[3]), &kind!(felt[2])])); + assert!( + fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(bool[3]), &kind!(uint[2])]) + ); + // Aggregate types + assert!(fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, bool, bool]), + &kind!([uint, uint]), + ])); + + // Negative cases + + // Scalar types + assert!(!fty!(fn(uint, bool) -> felt).check_args_kinds(&[&kind!(bool), &kind!(felt)]),); + assert!(!fty!(fn(felt, bool) -> felt).check_args_kinds(&[&kind!(felt), &kind!(uint[2])]),); + // Vector types + assert!(!fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(_), &kind!(felt[2])])); + assert!( + !fty!(fn(_[3], felt[2]) -> felt) + .check_args_kinds(&[&kind!(bool[3, 5]), &kind!(uint[2])]) + ); + // Aggregate types + assert!(!fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, felt, uint]), + &kind!([uint, uint]), + ])); + assert!(!fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, bool]), + &kind!([uint, uint]), + ])); + } +}