diff --git a/Cargo.toml b/Cargo.toml index 6c81127a..cbb69d56 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,9 +6,21 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[profile.samply] +inherits = "release" +debug = true + +[[bin]] +name = "descendc" +path = "src/main.rs" + [dependencies.peg] version = "0.8.0" +[dependencies.bumpalo] +version="3.17.0" +features=["collections", "boxed"] + [dependencies.annotate-snippets] version = "0.9.0" features = ["color"] @@ -16,5 +28,24 @@ features = ["color"] [dependencies.descend_derive] path = "./descend_derive" +[dependencies.clap] +version = "4.3" +features = ["derive"] + +[dependencies.which] +version = "7.0.2" + +[dependencies.log] +version = "0.4.27" + +[dependencies.env_logger] +version = "0.11.7" + +[dependencies.predicates] +version= "3.1.3" + +[dependencies.assert_cmd] +version = "2.0.16" + [workspace] -members = ["descend_derive"] \ No newline at end of file +members = ["descend_derive"] diff --git a/README.md b/README.md index 6a1773a1..2281bb7d 100644 --- a/README.md +++ b/README.md @@ -110,7 +110,7 @@ codegen cuda-examples/ --------------------- -* Contains handwritte or generated CUDA programs +* Contains handwritten or generated CUDA programs * Contains `descend.cuh`; the header file which is required in order to compile Descend programs, that were translated to CUDA, with `nvcc` (contains for example the implementation of `exec`) diff --git a/descend_derive/src/lib.rs b/descend_derive/src/lib.rs index 21fb372e..c4589d30 100644 --- a/descend_derive/src/lib.rs +++ b/descend_derive/src/lib.rs @@ -1,7 +1,6 @@ -use std::collections::HashSet; - use proc_macro::TokenStream; use quote::quote; +use std::collections::HashSet; use syn::{ parse::{Parse, ParseStream}, parse_macro_input, @@ -9,6 +8,154 @@ use syn::{ Fields, FieldsNamed, Ident, ItemStruct, Token, Type, TypeReference, }; +struct Args { + vars: HashSet, +} + +impl Parse for Args { + fn 
parse(input: ParseStream) -> syn::parse::Result { + let vars = Punctuated::::parse_terminated(input)?; + Ok(Args { + vars: vars.into_iter().collect(), + }) + } +} + +const IGNORE_ATTR_NAME: &str = "span_derive_ignore"; + +#[proc_macro_attribute] +pub fn span_derive(attr: TokenStream, input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + let args = parse_macro_input!(attr as Args); + if args.vars.is_empty() { + panic!("Need at least one attribute in span_derive!"); + } + + let original_name = &input.ident; + let original_generics = input.generics.clone(); + let (impl_generics, ty_generics, where_clause) = original_generics.split_for_impl(); + + let original_fields = match &input.fields { + syn::Fields::Named(fields) => fields, + _ => unreachable!(), + }; + + let mut new_fields = original_fields.clone(); + for field in new_fields.named.iter_mut() { + field + .attrs + .retain(|attr| !attr.path.is_ident(IGNORE_ATTR_NAME)); + } + + let mut original = input.clone(); + original.fields = Fields::Named(new_fields); + + let mut helper = input.clone(); + + let mut helper_generics = original_generics.clone(); + let mut new_params = Punctuated::new(); + new_params.push(syn::parse_quote!('helper)); + for param in original_generics.params.iter() { + new_params.push(param.clone()); + } + helper_generics.params = new_params; + helper.generics = helper_generics.clone(); + + let mut helper_fields = FieldsNamed { + brace_token: original_fields.brace_token.clone(), + named: Punctuated::new(), + }; + for mut field in original_fields.named.clone().into_iter() { + if field + .attrs + .iter() + .any(|attr| attr.path.is_ident(IGNORE_ATTR_NAME)) + { + continue; + } + let mut typ: TypeReference = syn::parse_str("&'helper i64").unwrap(); + typ.elem = Box::new(field.ty); + field.ty = Type::Reference(typ); + helper_fields.named.push(field); + } + helper.fields = Fields::Named(helper_fields); + + let helper_name = Ident::new( + &format!("__{}SpanHelper", 
original_name.to_string()), + input.ident.span(), + ); + helper.ident = helper_name.clone(); + helper.attrs.clear(); + + let into_fields = helper + .fields + .iter() + .map(|field| field.ident.as_ref().unwrap().clone()) + .collect::>(); + + let mut derive_args: Vec<_> = args.vars.iter().collect(); + derive_args.sort_by(|a, b| a.to_string().cmp(&b.to_string())); + + let (helper_impl_generics, helper_ty_generics, helper_where_clause) = + helper_generics.split_for_impl(); + + let mut output = quote! { + #original + + #[derive(#(#derive_args),*)] + #helper + + impl #helper_impl_generics From<&'helper #original_name #ty_generics> + for #helper_name #helper_ty_generics #helper_where_clause { + fn from(orig: &'helper #original_name #ty_generics) -> Self { + #helper_name { + #(#into_fields: &orig.#into_fields),* + } + } + } + }; + + for derive_arg in args.vars.iter() { + match &derive_arg.to_string() as &str { + "PartialEq" => { + output.extend(quote! { + impl #impl_generics ::core::cmp::PartialEq for #original_name #ty_generics #where_clause { + fn eq(&self, other: &Self) -> bool { + let helper = #helper_name::from(self); + let helper_other = #helper_name::from(other); + helper == helper_other + } + } + }); + } + "Eq" => { + output.extend(quote! { + impl #impl_generics ::core::cmp::Eq for #original_name #ty_generics #where_clause {} + }); + } + "Hash" => { + output.extend(quote! 
{ + impl #impl_generics ::core::hash::Hash for #original_name #ty_generics #where_clause { + fn hash(&self, state: &mut H) { + let helper = #helper_name::from(self); + helper.hash(state); + } + } + }); + } + other => panic!("span_derive not implemented for {}", other), + } + } + + eprintln!( + "--- Generated code ---\n{}\n------------------------", + output + ); + + TokenStream::from(output) +} + +/*** OLD WAY // Copy paste from syn example struct Args { vars: HashSet, @@ -148,3 +295,4 @@ pub fn span_derive(attr: TokenStream, input: TokenStream) -> TokenStream { TokenStream::from(output) } +*/ diff --git a/profile.json.gz b/profile.json.gz new file mode 100644 index 00000000..86077875 Binary files /dev/null and b/profile.json.gz differ diff --git a/profile_memory.sh b/profile_memory.sh new file mode 100755 index 00000000..40c44ae7 --- /dev/null +++ b/profile_memory.sh @@ -0,0 +1,6 @@ +#!/bin/bash +INPUT="examples/infer/matmul.desc" + +cargo build --profile samply + +samply record -- ./target/samply/descendc emit "$INPUT" \ No newline at end of file diff --git a/src/arena_ast/internal.rs b/src/arena_ast/internal.rs new file mode 100644 index 00000000..e0aaa2e4 --- /dev/null +++ b/src/arena_ast/internal.rs @@ -0,0 +1,219 @@ +// Constructs in this module are part of the AST but not part of the user facing syntax. 
+// These are also used in typechecking and ty_check::ctxs + +// TODO specific access modifiers + +use super::{Ident, Ownership, PlaceExpr, Ty}; +use crate::arena_ast::{ExecExpr, Mutability, Nat, PlaceExprKind, View}; +use bumpalo::collections::Vec as BumpVec; +use std::collections::HashSet; + +// TODO: Removed the Default trait here, see what kind of consequences has this later +// Otherwise implement the trait +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct Frame<'a> { + pub bindings: BumpVec<'a, FrameEntry<'a>>, +} + +impl<'a> Frame<'a> { + pub fn new_in(bump: &'a bumpalo::Bump) -> Self { + Self { + bindings: BumpVec::new_in(bump), + } + } + + pub fn append_idents_typed(&mut self, idents_typed: I) + where + I: IntoIterator>, + { + for ident in idents_typed { + self.bindings.push(FrameEntry::Var(ident)); + } + } +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum FrameEntry<'a> { + Var(IdentTyped<'a>), + ExecMapping(ExecMapping<'a>), + PrvMapping(PrvMapping<'a>), +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct IdentTyped<'a> { + pub ident: Ident<'a>, + pub ty: Ty<'a>, + pub mutbl: Mutability, + pub exec: ExecExpr<'a>, +} + +impl<'a> IdentTyped<'a> { + pub fn new_in( + arena: &'a bumpalo::Bump, + ident: &'a str, + ty: Ty<'a>, + mutbl: Mutability, + exec: ExecExpr<'a>, + ) -> Self { + IdentTyped { + ident: Ident::new(arena, ident), + ty, + mutbl, + exec, + } + } +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct ExecMapping<'a> { + pub ident: Ident<'a>, + pub exec_expr: ExecExpr<'a>, +} + +impl<'a> ExecMapping<'a> { + pub fn new(ident: Ident<'a>, exec_expr: ExecExpr<'a>) -> Self { + ExecMapping { ident, exec_expr } + } +} + +// TODO: Problems with HashSet and String in the Arena implementation --> Find a work +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct PrvMapping<'a> { + pub prv: String, + pub loans: HashSet>, +} + +impl<'a> PrvMapping<'a> { + pub fn new(name: &str) -> Self { + PrvMapping { + prv: name.to_string(), + loans: 
HashSet::new(), + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct Loan<'a> { + pub place_expr: PlaceExpr<'a>, + pub own: Ownership, +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum PathElem<'a> { + Proj(usize), + FieldProj(&'a Ident<'a>), +} +pub type Path<'a> = BumpVec<'a, PathElem<'a>>; + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct Place<'a> { + pub ident: Ident<'a>, + pub path: Path<'a>, +} +impl<'a> Place<'a> { + pub fn new(ident: Ident<'a>, path: Path<'a>) -> Self { + Place { ident, path } + } + + pub fn to_place_expr(&self, arena: &'a bumpalo::Bump) -> PlaceExpr { + self.path.iter().fold( + PlaceExpr::new(PlaceExprKind::Ident(self.ident.clone())), + |pl_expr, path_entry| match path_entry { + PathElem::Proj(n) => PlaceExpr::new(PlaceExprKind::Proj(arena.alloc(pl_expr), *n)), + PathElem::FieldProj(field) => { + PlaceExpr::new(PlaceExprKind::FieldProj(arena.alloc(pl_expr), field)) + } + }, + ) + } + + /** + pub fn prefix_of(&self, other: &Self) -> bool { + if self.path.len() > other.path.len() { + return false; + } + self.ident == other.ident && &self.path == &other.path[..self.path.len()] + }*/ + + pub fn prefix_of(&self, other: &Self) -> bool { + if self.ident != other.ident || self.path.len() > other.path.len() { + return false; + } + + other.path.iter().zip(&self.path).all(|(a, b)| a == b) + } +} + +pub enum PlaceCtx<'a> { + Proj(&'a PlaceCtx<'a>, usize), + FieldProj(&'a PlaceCtx<'a>, Ident<'a>), + Deref(&'a PlaceCtx<'a>), + Select(&'a PlaceCtx<'a>, &'a ExecExpr<'a>), + View(&'a PlaceCtx<'a>, &'a View<'a>), + Idx(&'a PlaceCtx<'a>, &'a Nat<'a>), + Hole, +} + +impl<'a> PlaceCtx<'a> { + pub fn insert_pl_expr( + &'a self, + arena: &'a bumpalo::Bump, + pl_expr: PlaceExpr<'a>, + ) -> PlaceExpr<'a> { + match self { + Self::Hole => pl_expr, + Self::Proj(pl_ctx, n) => PlaceExpr::new(PlaceExprKind::Proj( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + *n, + )), + Self::FieldProj(pl_ctx, field) => 
PlaceExpr::new(PlaceExprKind::FieldProj( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + field, + )), + Self::Deref(pl_ctx) => PlaceExpr::new(PlaceExprKind::Deref( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + )), + Self::Select(pl_ctx, exec) => PlaceExpr::new(PlaceExprKind::Select( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + exec.clone(), + )), + Self::View(pl_ctx, view) => PlaceExpr::new(PlaceExprKind::View( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + view.clone(), + )), + Self::Idx(pl_ctx, idx) => PlaceExpr::new(PlaceExprKind::Idx( + arena.alloc(pl_ctx.insert_pl_expr(arena, pl_expr)), + idx.clone(), + )), + } + } + + pub fn without_innermost_deref(&'a self, arena: &'a bumpalo::Bump) -> &'a PlaceCtx<'a> { + match self { + PlaceCtx::Hole => self, + PlaceCtx::Proj(pl_ctx, idx) => { + arena.alloc(PlaceCtx::Proj(pl_ctx.without_innermost_deref(arena), *idx)) + } + PlaceCtx::FieldProj(pl_ctx, ident) => arena.alloc(PlaceCtx::FieldProj( + pl_ctx.without_innermost_deref(arena), + ident.clone(), + )), + PlaceCtx::Deref(pl_ctx) => match **pl_ctx { + PlaceCtx::Hole => arena.alloc(PlaceCtx::Hole), + _ => arena.alloc(PlaceCtx::Deref(pl_ctx.without_innermost_deref(arena))), + }, + PlaceCtx::Select(pl_ctx, exec) => arena.alloc(PlaceCtx::Select( + pl_ctx.without_innermost_deref(arena), + exec.clone(), + )), + PlaceCtx::View(pl_ctx, view) => arena.alloc(PlaceCtx::View( + pl_ctx.without_innermost_deref(arena), + view.clone(), + )), + PlaceCtx::Idx(pl_ctx, idx) => arena.alloc(PlaceCtx::Idx( + pl_ctx.without_innermost_deref(arena), + idx.clone(), + )), + } + } +} diff --git a/src/arena_ast/mod.rs b/src/arena_ast/mod.rs new file mode 100644 index 00000000..8bdfba3a --- /dev/null +++ b/src/arena_ast/mod.rs @@ -0,0 +1,2229 @@ +use std::fmt; + +use crate::arena_ast::internal::PathElem; +use bumpalo::{collections::Vec as BumpVec, Bump}; +use descend_derive::span_derive; + +use crate::ast::Span; +use crate::parser::SourceCode; +pub mod internal; + 
+pub mod printer; +pub mod utils; +pub mod visit; +pub mod visit_mut; + +use std::cell::OnceCell; + +#[derive(Debug)] +pub struct CompilUnit<'a> { + pub items: BumpVec<'a, Item<'a>>, + pub source: &'a SourceCode<'a>, +} + +impl<'a> CompilUnit<'a> { + pub fn new(items: BumpVec<'a, Item<'a>>, source: &'a SourceCode<'a>) -> Self { + CompilUnit { items, source } + } +} + +#[derive(Debug)] +pub enum Item<'a> { + FunDef(&'a FunDef<'a>), + FunDecl(&'a FunDecl<'a>), + StructDecl(&'a StructDecl<'a>), +} + +#[derive(Debug, PartialEq)] +pub struct FunDecl<'a> { + pub ident: Ident<'a>, + pub generic_params: BumpVec<'a, IdentKinded<'a>>, + pub generic_exec: Option>, + pub param_decls: BumpVec<'a, ParamDecl<'a>>, + pub ret_dty: &'a DataTy<'a>, + pub exec: ExecExpr<'a>, + pub prv_rels: BumpVec<'a, PrvRel<'a>>, +} + +impl<'a> FunDecl<'a> { + pub fn fn_ty(&self, bump: &'a Bump) -> FnTy<'a> { + let mut param_sigs = BumpVec::new_in(bump); + for p_decl in &self.param_decls { + let exec_expr = p_decl.exec_expr.as_ref().unwrap_or(&self.exec).clone(); + let ty = p_decl.ty.as_ref().unwrap().clone(); // This may need arena allocation too + param_sigs.push(ParamSig::new(exec_expr, ty)); + } + + let mut generics = BumpVec::new_in(bump); + generics.extend(self.generic_params.iter().cloned()); + + FnTy::new( + bump, + generics, + self.generic_exec.clone(), + param_sigs, + self.exec.clone(), + bump.alloc(Ty { + ty: TyKind::Data(self.ret_dty), + span: None, + }), + [], + ) + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> FunDecl<'a> { + let mut generic_params = BumpVec::new_in(arena); + generic_params.extend(self.generic_params.iter().cloned()); + + let generic_exec = self.generic_exec.clone(); + + let mut param_decls = BumpVec::new_in(arena); + param_decls.extend(self.param_decls.iter().cloned()); + + let mut prv_rels = BumpVec::new_in(arena); + prv_rels.extend(self.prv_rels.iter().cloned()); + + FunDecl { + ident: self.ident.clone(), + generic_params, + generic_exec, + param_decls, 
+ ret_dty: self.ret_dty, // copy pointer; visitor will re-point if needed + exec: self.exec.clone(), + prv_rels, + } + } +} + +#[derive(Debug, Clone, Eq, Hash, PartialEq)] +pub struct StructDecl<'a> { + pub ident: Ident<'a>, + pub generic_params: BumpVec<'a, IdentKinded<'a>>, + pub fields: BumpVec<'a, (Ident<'a>, DataTy<'a>)>, +} + +// TODO refactor to make use of FunDecl +#[derive(Debug, Clone, PartialEq)] +pub struct FunDef<'a> { + pub ident: Ident<'a>, + pub generic_params: BumpVec<'a, IdentKinded<'a>>, + pub generic_exec: Option>, + pub param_decls: BumpVec<'a, ParamDecl<'a>>, + pub ret_dty: &'a DataTy<'a>, + pub exec: ExecExpr<'a>, + pub prv_rels: BumpVec<'a, PrvRel<'a>>, + pub body: &'a Block<'a>, +} + +impl<'a> FunDef<'a> { + pub fn fn_ty(&self, bump: &'a Bump) -> FnTy<'a> { + let mut param_sigs = BumpVec::new_in(bump); + for p_decl in &self.param_decls { + let exec_expr = p_decl.exec_expr.as_ref().unwrap_or(&self.exec).clone(); + let ty = p_decl.ty.expect("Missing parameter type"); + let ty_ref = bump.alloc(ty.clone()); + param_sigs.push(ParamSig::new(exec_expr, ty_ref)); + } + + let mut generics = BumpVec::new_in(bump); + generics.extend(self.generic_params.iter().cloned()); + + let ret_ty = bump.alloc(Ty { + ty: TyKind::Data(self.ret_dty), + span: None, + }); + + FnTy::new( + bump, + generics, + self.generic_exec.clone(), + param_sigs, + self.exec.clone(), + ret_ty, + [], + ) + } + + pub fn clone_in(&self, arena: &'a Bump) -> FunDef<'a> { + let mut generic_params = BumpVec::new_in(arena); + generic_params.extend(self.generic_params.iter().cloned()); + + let mut param_decls = BumpVec::new_in(arena); + param_decls.extend(self.param_decls.iter().cloned()); + + let mut prv_rels = BumpVec::new_in(arena); + prv_rels.extend(self.prv_rels.iter().cloned()); + + FunDef { + ident: self.ident.clone(), + generic_params, + generic_exec: self.generic_exec.clone(), + param_decls, + ret_dty: arena.alloc(self.ret_dty.clone_in(arena)), + exec: self.exec.clone_in(arena), + 
prv_rels, + body: arena.alloc(self.body.clone_in(arena)), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub struct IdentExec<'a> { + pub ident: Ident<'a>, + pub ty: &'a ExecTy<'a>, +} + +impl<'a> IdentExec<'a> { + pub fn new_in(bump: &'a bumpalo::Bump, ident: Ident<'a>, exec_ty: ExecTy<'a>) -> Self { + IdentExec { + ident, + ty: bump.alloc(exec_ty), + } + } + + pub fn clone_in(&self, arena: &'a Bump) -> Self { + let cloned = self.ty.clone_in(arena); // ExecTy<'a> + IdentExec { + ident: self.ident.clone(), // shallow clone of Ident<'a> is fine + ty: arena.alloc(cloned), // store &'a ExecTy<'a> + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct ParamDecl<'a> { + pub ident: Ident<'a>, + pub ty: Option<&'a Ty<'a>>, + pub mutbl: Mutability, + pub exec_expr: Option>, +} + +#[span_derive(PartialEq)] +#[derive(Debug, Clone)] +pub struct Expr<'a> { + pub expr: ExprKind<'a>, + // FIXME misusing span_derive_ignore to ignore type on equality checks + #[span_derive_ignore] + pub ty: Option<&'a Ty<'a>>, + #[span_derive_ignore] + pub span: Option, +} + +impl<'a> Expr<'a> { + pub fn new(expr: ExprKind<'a>) -> Self { + Expr { + expr, + ty: None, + span: None, + } + } + + pub fn with_span(expr: ExprKind<'a>, span: Span) -> Self { + Expr { + expr, + ty: None, + span: Some(span), + } + } + + pub fn with_type(expr: ExprKind<'a>, ty: &'a Ty<'a>) -> Self { + Expr { + expr, + ty: Some(ty), + span: None, + } + } + + // pub fn subst_idents(&mut self, subst_map: &HashMap<&str, &Expr>) { + // fn pl_expr_contains_name_in<'a, I>(pl_expr: &PlaceExpr, mut idents: I) -> bool + // where + // I: Iterator, + // { + // match &pl_expr.pl_expr { + // PlaceExprKind::Ident(ident) => idents.any(|name| ident.name.as_ref() == *name), + // PlaceExprKind::Proj(tuple, _) => pl_expr_contains_name_in(tuple, idents), + // PlaceExprKind::Deref(deref) => pl_expr_contains_name_in(deref, idents), + // PlaceExprKind::Select(pl_expr, _) => pl_expr_contains_name_in(pl_expr, idents), + // 
PlaceExprKind::SplitAt(_, pl_expr) => pl_expr_contains_name_in(pl_expr, idents), + // PlaceExprKind::View(pl_expr, _) => pl_expr_contains_name_in(pl_expr, idents), + // PlaceExprKind::Idx(pl_expr, _) => pl_expr_contains_name_in(pl_expr, idents), + // } + // } + // + // struct SubstIdents<'a> { + // subst_map: &'a HashMap<&'a str, &'a Expr>, + // } + // impl VisitMut for SubstIdents<'_> { + // fn visit_pl_expr(&mut self, pl_expr: &mut PlaceExpr) { + // if pl_expr_contains_name_in(pl_expr, self.subst_map.keys()) { + // match &pl_expr.pl_expr { + // PlaceExprKind::Ident(ident) => { + // let subst_expr = + // self.subst_map.get::(ident.name.as_ref()).unwrap(); + // if let ExprKind::PlaceExpr(pl_e) = &subst_expr.expr { + // *pl_expr = pl_e.as_ref().clone(); + // } else { + // // TODO can this happen? + // panic!("How did this happen?") + // } + // } + // _ => visit_mut::walk_pl_expr(self, pl_expr), + // } + // } + // } + // + // fn visit_expr(&mut self, expr: &mut Expr) { + // match &expr.expr { + // ExprKind::PlaceExpr(pl_expr) => { + // if pl_expr_contains_name_in(pl_expr, self.subst_map.keys()) { + // match &pl_expr.pl_expr { + // PlaceExprKind::Ident(ident) => { + // if let Some(&subst_expr) = + // self.subst_map.get::(ident.name.as_ref()) + // { + // *expr = subst_expr.clone(); + // } + // } + // PlaceExprKind::Proj(tuple, i) => { + // let mut tuple_expr = Expr::new(ExprKind::PlaceExpr(Box::new( + // tuple.as_ref().clone(), + // ))); + // self.visit_expr(&mut tuple_expr); + // *expr = Expr::new(ExprKind::Proj(Box::new(tuple_expr), *i)); + // } + // PlaceExprKind::Deref(deref_expr) => { + // let mut ref_expr = Expr::new(ExprKind::PlaceExpr(Box::new( + // deref_expr.as_ref().clone(), + // ))); + // self.visit_expr(&mut ref_expr); + // *expr = Expr::new(ExprKind::Deref(Box::new(ref_expr))); + // } + // PlaceExprKind::Select(_, _) + // | PlaceExprKind::SplitAt(_, _) + // | PlaceExprKind::Idx(_, _) + // | PlaceExprKind::View(_, _) => { + // unimplemented!() + // } + // 
} + // } + // } + // _ => visit_mut::walk_expr(self, expr), + // } + // } + // } + // let mut subst_idents = SubstIdents { subst_map }; + // subst_idents.visit_expr(self); + // } +} + +#[derive(PartialEq, Debug, Clone)] +pub struct Sched<'a> { + pub dim: DimCompo, + pub inner_exec_ident: Option>, + pub sched_exec: &'a ExecExpr<'a>, + pub body: &'a Block<'a>, +} + +impl<'a> Sched<'a> { + pub fn new_in( + arena: &'a bumpalo::Bump, + dim: DimCompo, + inner_exec_ident: Option>, + sched_exec: ExecExpr<'a>, + body: Block<'a>, + ) -> Self { + Sched { + dim, + inner_exec_ident, + sched_exec: arena.alloc(sched_exec), + body: arena.alloc(body), + } + } +} + +#[derive(PartialEq, Debug)] +pub struct Split<'a> { + pub dim_compo: DimCompo, + pub pos: Nat<'a>, + pub split_exec: &'a ExecExpr<'a>, + pub branch_idents: BumpVec<'a, Ident<'a>>, + pub branch_bodies: BumpVec<'a, Expr<'a>>, +} + +impl<'a> Split<'a> { + pub fn new( + bump: &'a bumpalo::Bump, + dim_compo: DimCompo, + pos: Nat<'a>, + split_exec: ExecExpr<'a>, + branch_idents: impl IntoIterator>, + branch_bodies: impl IntoIterator>, + ) -> Self { + let split_exec = bump.alloc(split_exec); + + let mut idents = BumpVec::new_in(bump); + idents.extend(branch_idents); + + let mut bodies = BumpVec::new_in(bump); + bodies.extend(branch_bodies); + + Split { + dim_compo, + pos, + split_exec, + branch_idents: idents, + branch_bodies: bodies, + } + } +} + +#[derive(PartialEq, Debug, Clone)] +pub struct Block<'a> { + pub prvs: BumpVec<'a, String>, + pub body: &'a Expr<'a>, +} + +impl<'a> Block<'a> { + pub fn new(bump: &'a bumpalo::Bump, body: Expr<'a>) -> Self { + Block { + prvs: BumpVec::new_in(bump), + body: bump.alloc(body), + } + } + + pub fn with_prvs( + bump: &'a bumpalo::Bump, + prvs: impl IntoIterator, + body: Expr<'a>, + ) -> Self { + let mut prvs_vec = BumpVec::new_in(bump); + prvs_vec.extend(prvs); + Block { + prvs: prvs_vec, + body: bump.alloc(body), + } + } +} + +#[derive(PartialEq, Debug)] +pub struct AppKernel<'a> { + pub 
grid_dim: Dim<'a>, + pub block_dim: Dim<'a>, + pub shared_mem_dtys: BumpVec<'a, DataTy<'a>>, + pub shared_mem_prvs: BumpVec<'a, String>, + pub fun_ident: &'a Ident<'a>, + pub gen_args: BumpVec<'a, ArgKinded<'a>>, + pub args: BumpVec<'a, Expr<'a>>, +} + +#[derive(PartialEq, Debug, Clone)] +pub enum ExprKind<'a> { + Hole, + Lit(Lit), + // An l-value equivalent: *p, p.n, x + PlaceExpr(&'a PlaceExpr<'a>), + // e.g., [1, 2 + 3, 4] + Array(BumpVec<'a, Expr<'a>>), + Tuple(BumpVec<'a, Expr<'a>>), + // Borrow Expressions + Ref(Option, Ownership, &'a PlaceExpr<'a>), + Block(&'a Block<'a>), + // Variable declaration + // let mut x: ty; + LetUninit(Option<&'a ExecExpr<'a>>, Ident<'a>, &'a Ty<'a>), + // let w x: ty = e1 + Let(Pattern<'a>, Option<&'a Ty<'a>>, &'a Expr<'a>), + // Assignment to existing place [expression] + Assign(&'a PlaceExpr<'a>, &'a Expr<'a>), + // e1[i] = e2 + IdxAssign(&'a PlaceExpr<'a>, Nat<'a>, &'a Expr<'a>), + // e1 ; e2 + Seq(BumpVec<'a, Expr<'a>>), + // Anonymous function which can capture its surrounding context + // | x_n: d_1, ..., x_n: d_n | [exec]-> d_r { e } + // TODO body expression should always be block?! No but treated like one. + //Lambda(Vec, Ident<'a>Exec, Box, Box), + // Function application + // e_f(e_1, ..., e_n) + App( + &'a Ident<'a>, + BumpVec<'a, ArgKinded<'a>>, + BumpVec<'a, Expr<'a>>, + ), + DepApp(Ident<'a>, BumpVec<'a, ArgKinded<'a>>), + //AppKernel(&'a AppKernel<'a>), + AppKernel(&'a AppKernel<'a>), + // TODO branches must be blocks + IfElse(&'a Expr<'a>, &'a Expr<'a>, &'a Expr<'a>), + // TODO branch must be block + If(&'a Expr<'a>, &'a Expr<'a>), + // For-each loop. + // for x in e_1 { e_2 } + // TODO body must be block + For(Ident<'a>, &'a Expr<'a>, &'a Expr<'a>), + // for n in range(..) 
{ e } + // TODO body must be block + ForNat(Ident<'a>, &'a NatRange<'a>, &'a Expr<'a>), + // while( e_1 ) { e_2 } + // TODO body must be block + While(&'a Expr<'a>, &'a Expr<'a>), + BinOp(BinOp, &'a Expr<'a>, &'a Expr<'a>), + UnOp(UnOp, &'a Expr<'a>), + Cast(&'a Expr<'a>, &'a DataTy<'a>), + // TODO branches must be blocks or treated like blocks + Split(&'a Split<'a>), + Sched(&'a Sched<'a>), + Sync(Option>), + Unsafe(&'a Expr<'a>), + Range(&'a Expr<'a>, &'a Expr<'a>), +} + +#[derive(Clone, Debug)] +#[span_derive(PartialEq, Eq, Hash)] +pub struct Ident<'a> { + // Identifier names never change. Instead a new identifier is created. Therefore it is not + // necessary to keep the capacity that is stored in a String for efficient appending. + pub name: &'a str, + #[span_derive_ignore] + pub span: Option, + pub is_implicit: bool, +} +impl<'a> Ident<'a> { + pub fn new(bump: &'a bumpalo::Bump, name: &'a str) -> Self { + Self { + name: bump.alloc_str(name), + span: None, + is_implicit: false, + } + } + + pub fn new_impli(bump: &'a bumpalo::Bump, name: &'a str) -> Self { + Self { + name: bump.alloc_str(name), + span: None, + is_implicit: true, + } + } + + pub fn with_span(bump: &'a bumpalo::Bump, name: &'a str, span: Span) -> Self { + Self { + name: bump.alloc_str(name), + span: Some(span), + is_implicit: false, + } + } +} + +#[derive(Debug, Clone, PartialEq)] +pub enum Pattern<'a> { + Ident(Mutability, Ident<'a>), + Tuple(BumpVec<'a, Pattern<'a>>), + Wildcard, +} + +#[derive(Debug, Copy, Clone, PartialEq)] +pub enum Lit { + Unit, + Bool(bool), + I32(i32), + U8(u8), + U32(u32), + U64(u64), + F32(f32), + F64(f64), +} + +// impl PartialEq for Lit{ +// fn eq(&self, other:&Self) -> bool { +// let b = match (self, other) { +// (Self::Unit, Self::Unit) => true, +// (Self::Bool(x), Self::Bool(y)) => if x == y {true} else {false}, +// (Self::Int(x), Self::Int(y)) => if x == y {true} else {false}, +// (Self::Float(x), Self::Float(y)) => if x == y {true} else {false}, +// _ => false 
+// }; +// b +// } +// } + +impl fmt::Display for Lit { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Unit => write!(f, "()"), + Self::Bool(b) => write!(f, "{}", b), + Self::I32(i) => write!(f, "{}", i), + Self::U8(uc) => write!(f, "{}", uc), + Self::U32(u) => write!(f, "{}", u), + Self::U64(ul) => write!(f, "{}", ul), + Self::F32(fl) => write!(f, "{}f", fl), + Self::F64(d) => write!(f, "{}", d), + } + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum Mutability { + Const, + Mut, +} + +impl fmt::Display for Mutability { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let str = match self { + Self::Const => "const", + Self::Mut => "mut", + }; + write!(f, "{}", str) + } +} + +#[derive(PartialEq, Eq, PartialOrd, Ord, Hash, Debug, Copy, Clone)] +pub enum Ownership { + Shrd, + Uniq, +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum UnOp { + Not, + Neg, +} + +impl fmt::Display for UnOp { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let str = match self { + Self::Not => "!", + Self::Neg => "-", + }; + write!(f, "{}", str) + } +} + +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] +pub enum BinOp { + Add, + Sub, + Mul, + Div, + Mod, + And, + Or, + Eq, + Lt, + Le, + Gt, + Ge, + Neq, + Shl, + Shr, + BitOr, + BitAnd, +} + +impl fmt::Display for BinOp { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let str = match self { + Self::Add => "+", + Self::Sub => "-", + Self::Mul => "*", + Self::Div => "/", + Self::Mod => "%", + Self::And => "&&", + Self::Or => "||", + Self::Eq => "=", + Self::Lt => "<", + Self::Le => "<=", + Self::Gt => ">", + Self::Ge => ">=", + Self::Neq => "!=", + Self::Shl => "<<", + Self::Shr => ">>", + Self::BitOr => "|", + Self::BitAnd => "&", + }; + write!(f, "{}", str) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum Kind { + Nat, + Memory, + DataTy, + Provenance, +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +pub enum 
ArgKinded<'a> { + Ident(Ident<'a>), + Nat(Nat<'a>), + Memory(Memory<'a>), + DataTy(DataTy<'a>), + Provenance(Provenance<'a>), +} + +impl<'a> ArgKinded<'a> { + pub fn kind(&self) -> Kind { + match self { + ArgKinded::Ident(_) => { + panic!("Unexpected: unkinded identifier should have been removed after parsing") + } + ArgKinded::DataTy(_) => Kind::DataTy, + ArgKinded::Provenance(_) => Kind::Provenance, + ArgKinded::Memory(_) => Kind::Memory, + ArgKinded::Nat(_) => Kind::Nat, + } + } + + pub fn equal(&'a self, nat_ctx: &'a NatCtx<'a>, other: &'a Self) -> NatEvalResult<'a, bool> { + match (self, other) { + (ArgKinded::Ident(i), ArgKinded::Ident(o)) => Ok(i == o), + (ArgKinded::Nat(n), ArgKinded::Nat(no)) => Ok(n.eval(nat_ctx)? == no.eval(nat_ctx)?), + (ArgKinded::Provenance(r), ArgKinded::Provenance(ro)) => Ok(r == ro), + (ArgKinded::DataTy(dty), ArgKinded::DataTy(dtyo)) => dty.equal(nat_ctx, dtyo), + (ArgKinded::Memory(mem), ArgKinded::Memory(memo)) => Ok(mem == memo), + _ => Ok(false), + } + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> Self { + match self { + ArgKinded::Ident(i) => ArgKinded::Ident(i.clone()), + ArgKinded::Nat(n) => ArgKinded::Nat(n.clone_in(arena)), + ArgKinded::Memory(m) => ArgKinded::Memory(m.clone_in(arena)), + ArgKinded::DataTy(d) => ArgKinded::DataTy(d.clone_in(arena)), + ArgKinded::Provenance(p) => ArgKinded::Provenance(p.clone_in(arena)), + } + } +} + +#[span_derive(PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] +pub struct PlaceExpr<'a> { + pub pl_expr: PlaceExprKind<'a>, + // FIXME misusing span_derive_ignore to ignore type on equality checks + #[span_derive_ignore] + pub ty: OnceCell<&'a Ty<'a>>, + #[span_derive_ignore] + pub span: Option, +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct View<'a> { + pub name: Ident<'a>, + pub gen_args: BumpVec<'a, ArgKinded<'a>>, + pub args: BumpVec<'a, View<'a>>, +} + +impl<'a> View<'a> { + pub fn equal( + &'a self, + nat_ctx: &'a NatCtx<'a>, + other: &'a View<'a>, + ) -> 
NatEvalResult<'a, bool> { + if self.name.name != other.name.name { + return Ok(false); + } + + if self.gen_args.len() != other.gen_args.len() { + return Ok(false); + } + + for (ga, go) in self.gen_args.iter().zip(other.gen_args.iter()) { + if !ga.equal(nat_ctx, go)? { + return Ok(false); + } + } + + if self.args.len() != other.args.len() { + return Ok(false); + } + + for (v, vo) in self.args.iter().zip(other.args.iter()) { + if !v.equal(nat_ctx, vo)? { + return Ok(false); + } + } + + Ok(true) + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> Self { + // Clone generic arguments into the arena + let mut gen_args = bumpalo::collections::Vec::new_in(arena); + gen_args.extend(self.gen_args.iter().map(|ga| ga.clone_in(arena))); + + // Clone nested views into the arena + let mut args = bumpalo::collections::Vec::new_in(arena); + args.extend(self.args.iter().map(|v| v.clone_in(arena))); + + View { + name: self.name.clone(), + gen_args, + args, + } + } +} + +// TODO create generic View struct to enable easier extensibility by introducing only +// new predeclared types +// #[derive(PartialEq, Eq, Hash, Debug, Clone)] +// pub enum View { +// ToView, +// Group(Nat), +// SplitAt(Nat), +// Transpose, +// Rev, +// Map(Box), +// } + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum PlaceExprKind<'a> { + View(&'a PlaceExpr<'a>, &'a View<'a>), + // similar to a projection, but it projects an element for each provided execution resource + // (similar to indexing) + // p[[x]] + Select(&'a PlaceExpr<'a>, &'a ExecExpr<'a>), + // p.0 | p.1 + Proj(&'a PlaceExpr<'a>, usize), + FieldProj(&'a PlaceExpr<'a>, &'a Ident<'a>), + // *p + Deref(&'a PlaceExpr<'a>), + // Index into array, e.g., arr[i] + Idx(&'a PlaceExpr<'a>, &'a Nat<'a>), + // x + Ident(Ident<'a>), +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum PlExprPathElem<'a> { + View(View<'a>), + Select(&'a ExecExpr<'a>), + Proj(usize), + FieldProj(Ident<'a>), + Deref, + Idx(&'a Nat<'a>), + RangeSelec(&'a Nat<'a>, 
&'a Nat<'a>), +} + +impl<'a> PlaceExpr<'a> { + pub fn new(pl_expr: PlaceExprKind<'a>) -> Self { + PlaceExpr { + pl_expr, + ty: OnceCell::new(), + span: None, + } + } + + pub fn with_span(pl_expr: PlaceExprKind<'a>, span: Span) -> Self { + PlaceExpr { + pl_expr, + ty: OnceCell::new(), + span: Some(span), + } + } + + pub fn set_ty(&self, arena: &'a bumpalo::Bump, ty: Ty<'a>) { + let ty_ref = arena.alloc(ty); + // Ignore the error if already set, or assert if you prefer: + let _ = self.ty.set(ty_ref); + // or: self.ty.get_or_init(|| arena.alloc(ty)); + } + + /// Read-only access to the inferred type (if already set). + pub fn get_ty(&self) -> Option<&'a Ty<'a>> { + self.ty.get().copied() + } + + pub fn is_place(&self) -> bool { + match &self.pl_expr { + PlaceExprKind::Ident(_) => true, + PlaceExprKind::Proj(ple, _) | PlaceExprKind::FieldProj(ple, _) => ple.is_place(), + PlaceExprKind::Select(_, _) + | PlaceExprKind::Deref(_) + | PlaceExprKind::Idx(_, _) + | PlaceExprKind::View(_, _) => false, + } + } + + // TODO refactor. 
Places are only needed during typechecking and codegen + pub fn to_place(&'a self, arena: &'a bumpalo::Bump) -> Option> { + if self.is_place() { + Some(self.to_pl_ctx_and_most_specif_pl(arena).1) + } else { + None + } + } + + // TODO refactor see to_place + pub fn to_pl_ctx_and_most_specif_pl( + &'a self, + arena: &'a bumpalo::Bump, + ) -> (internal::PlaceCtx<'a>, internal::Place<'a>) { + match &self.pl_expr { + PlaceExprKind::Select(inner_ple, exec_idents) => { + let (pl_ctx, pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + ( + internal::PlaceCtx::Select(arena.alloc(pl_ctx), exec_idents.clone()), + pl, + ) + } + PlaceExprKind::Deref(inner_ple) => { + let (pl_ctx, pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + (internal::PlaceCtx::Deref(arena.alloc(pl_ctx)), pl) + } + PlaceExprKind::View(inner_ple, view) => { + let (pl_ctx, pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + ( + internal::PlaceCtx::View(arena.alloc(pl_ctx), view.clone()), + pl, + ) + } + PlaceExprKind::Proj(inner_ple, n) => { + let (pl_ctx, mut pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + match pl_ctx { + internal::PlaceCtx::Hole => { + pl.path.push(PathElem::Proj(*n)); + (pl_ctx, internal::Place::new(pl.ident, pl.path)) + } + _ => (internal::PlaceCtx::Proj(arena.alloc(pl_ctx), *n), pl), + } + } + PlaceExprKind::FieldProj(inner_ple, field_name) => { + let (pl_ctx, mut pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + match pl_ctx { + internal::PlaceCtx::Hole => { + pl.path.push(PathElem::FieldProj(field_name.clone())); + (pl_ctx, internal::Place::new(pl.ident, pl.path)) + } + _ => ( + internal::PlaceCtx::FieldProj( + arena.alloc(pl_ctx), + field_name.clone().clone(), + ), + pl, + ), + } + } + PlaceExprKind::Idx(inner_ple, idx) => { + let (pl_ctx, pl) = inner_ple.to_pl_ctx_and_most_specif_pl(arena); + ( + internal::PlaceCtx::Idx(arena.alloc(pl_ctx), idx.clone()), + pl, + ) + } + PlaceExprKind::Ident(ident) => ( + internal::PlaceCtx::Hole, + 
internal::Place::new(ident.clone(), BumpVec::new_in(arena)), + ), + } + } + + pub fn equiv(&'a self, arena: &'a bumpalo::Bump, place: &'a internal::Place) -> bool { + if let (internal::PlaceCtx::Hole, pl) = self.to_pl_ctx_and_most_specif_pl(arena) { + &pl == place + } else { + false + } + } + + pub fn as_ident_and_path( + &'a self, + arena: &'a bumpalo::Bump, + ) -> (Ident<'a>, BumpVec<'a, PlExprPathElem<'a>>) { + fn as_ident_and_path_rec<'a>( + pl_expr: &'a PlaceExpr<'a>, + mut path: BumpVec<'a, PlExprPathElem<'a>>, + ) -> (Ident<'a>, BumpVec<'a, PlExprPathElem<'a>>) { + match &pl_expr.pl_expr { + PlaceExprKind::Ident(i) => { + path.reverse(); + (i.clone(), path) + } + PlaceExprKind::Select(inner_ple, exec_idents) => { + path.push(PlExprPathElem::Select(exec_idents.clone())); + as_ident_and_path_rec(inner_ple, path) + } + PlaceExprKind::Deref(inner_ple) => { + path.push(PlExprPathElem::Deref); + as_ident_and_path_rec(inner_ple, path) + } + PlaceExprKind::View(inner_ple, view) => { + path.push(PlExprPathElem::View(view.clone().clone())); // formerly as_ref().clone() ? Can that just work with double cloning? + as_ident_and_path_rec(inner_ple, path) + } + PlaceExprKind::Proj(inner_ple, n) => { + path.push(PlExprPathElem::Proj(*n)); + as_ident_and_path_rec(inner_ple, path) + } + PlaceExprKind::FieldProj(inner_ple, ident) => { + path.push(PlExprPathElem::FieldProj(ident.clone().clone())); // formerly as_ref().clone() ? Can that just work with double cloning? 
+ as_ident_and_path_rec(inner_ple, path) + } + PlaceExprKind::Idx(inner_ple, idx) => { + path.push(PlExprPathElem::Idx(idx.clone())); + as_ident_and_path_rec(inner_ple, path) + } + } + } + as_ident_and_path_rec(self, BumpVec::new_in(arena)) // BumpVec Stuff into it + } +} + +#[span_derive(PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] +pub struct ExecExpr<'a> { + pub exec: &'a ExecExprKind<'a>, + #[span_derive_ignore] + pub ty: Option<&'a ExecTy<'a>>, + #[span_derive_ignore] + pub span: Option, +} +impl<'a> ExecExpr<'a> { + pub fn new(arena: &'a bumpalo::Bump, exec: ExecExprKind<'a>) -> Self { + Self { + exec: arena.alloc(exec), + ty: None, + span: None, + } + } + + // TODO how does this relate to is_prefix_of. Refactor. + pub fn is_sub_exec_of(&self, exec: &ExecExpr) -> bool { + if self.exec.path.len() > exec.exec.path.len() { + return self.exec.path[..exec.exec.path.len()] == exec.exec.path[..]; + } + false + } + + pub fn remove_last_distrib(&self, arena: &'a bumpalo::Bump) -> ExecExpr { + let last_distrib_pos = self + .exec + .path + .iter() + .rposition(|e| matches!(e, ExecPathElem::ForAll(_))); + + // What did i do here? + let removed_distrib_path = if let Some(ldp) = last_distrib_pos { + let mut vec = BumpVec::new_in(arena); + // self.exec.path[..ldp].to_vec() --> changed this to BumpVec + vec.extend_from_slice(&self.exec.path[..ldp]); + vec + } else { + //vec![] --> changed this to BumpVec + BumpVec::new_in(arena) + }; + + ExecExpr::new( + arena, + ExecExprKind::with_path(self.exec.base.clone(), removed_distrib_path), + ) + } + + pub fn equal(&self, nat_ctx: &'a NatCtx<'a>, other: &Self) -> NatEvalResult { + match (&self.exec.base, &other.exec.base) { + (BaseExec::Ident(i), BaseExec::Ident(o)) => { + if i != o { + return Ok(false); + } + } + (BaseExec::CpuThread, BaseExec::CpuThread) => (), + (BaseExec::GpuGrid(gdim, bdim), BaseExec::GpuGrid(gdimo, bdimo)) => { + if !(gdim.equal(nat_ctx, gdimo)? && bdim.equal(nat_ctx, bdimo)?) 
{ + return Ok(false); + } + } + _ => return Ok(false), + } + if self.exec.path.len() != other.exec.path.len() { + return Ok(false); + } + for path_elems in self.exec.path.iter().zip(&other.exec.path) { + match path_elems { + (ExecPathElem::ToWarps, ExecPathElem::ToWarps) => (), + (ExecPathElem::ForAll(d), ExecPathElem::ForAll(o)) => { + if d != o { + return Ok(false); + } + } + (ExecPathElem::ToThreads(d), ExecPathElem::ToThreads(o)) => { + if d != o { + return Ok(false); + } + } + (ExecPathElem::TakeRange(r), ExecPathElem::TakeRange(ro)) => { + if !(r.split_dim == ro.split_dim + && r.left_or_right == ro.left_or_right + && r.pos.eval(nat_ctx)? == ro.pos.eval(nat_ctx)?) + { + return Ok(false); + } + } + _ => return Ok(false), + } + } + Ok(true) + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> Self { + ExecExpr { + exec: arena.alloc(self.exec.clone_in(arena)), + ty: self.ty.as_ref().map(|t| { + let t_mut: &mut ExecTy<'a> = arena.alloc(t.clone_in(arena)); + &*t_mut // Reborrow as immutable reference, somehow feels illegal + }), + span: self.span, + } + } +} + +#[test] +fn equal_exec_exprs() { + let arena = Bump::new(); + + let exec1 = ExecExpr::new( + &arena, + ExecExprKind::with_path( + BaseExec::Ident(Ident::new(&arena, "grid")), + bumpalo::collections::Vec::from_iter_in([ExecPathElem::ForAll(DimCompo::X)], &arena), + ), + ); + + let exec2 = ExecExpr::new( + &arena, + ExecExprKind::with_path( + BaseExec::Ident(Ident::new(&arena, "grid")), + bumpalo::collections::Vec::from_iter_in([ExecPathElem::ForAll(DimCompo::X)], &arena), + ), + ); + + assert_eq!(exec1, exec2, "Unequal execs that should be equal"); +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum LeftOrRight { + Left, + Right, +} + +impl fmt::Display for LeftOrRight { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + LeftOrRight::Left => write!(f, "left"), + LeftOrRight::Right => write!(f, "right"), + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] 
+pub struct TakeRange<'a> { + pub split_dim: DimCompo, + pub pos: Nat<'a>, + pub left_or_right: LeftOrRight, +} + +impl<'a> TakeRange<'a> { + pub fn new(split_dim: DimCompo, pos: Nat<'a>, proj: LeftOrRight) -> Self { + TakeRange { + split_dim, + pos, + left_or_right: proj, + } + } + + pub fn clone_in(&self, arena: &'a Bump) -> &'a TakeRange<'a> { + arena.alloc(TakeRange::new( + self.split_dim, + self.pos.clone_in(arena), + self.left_or_right, + )) + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct ExecExprKind<'a> { + pub base: BaseExec<'a>, + pub path: BumpVec<'a, ExecPathElem<'a>>, +} + +impl<'a> ExecExprKind<'a> { + pub fn new(arena: &'a bumpalo::Bump, base: BaseExec<'a>) -> Self { + ExecExprKind { + base, + path: BumpVec::new_in(arena), + } + } + + pub fn clone_in(&self, arena: &'a Bump) -> Self { + let mut new_path = BumpVec::new_in(arena); + for e in self.path.iter() { + new_path.push(e.clone_in(arena)); + } + ExecExprKind { + base: self.base.clone_in(arena), + path: new_path, + } + } + + pub fn with_path(base: BaseExec<'a>, path: BumpVec<'a, ExecPathElem<'a>>) -> Self { + ExecExprKind { base, path } + } + + /** + pub fn with_path(base: BaseExec, path: impl IntoIterator>, arena: &'a Bump) -> Self { + let mut bump_vec = BumpVec::new_in(arena); + bump_vec.extend(path); + Self { base, path: bump_vec } + }*/ + + pub fn split_proj( + mut self, + arena: &'a bumpalo::Bump, + dim_compo: DimCompo, + pos: Nat<'a>, + proj: LeftOrRight, + ) -> Self { + self.path.push(ExecPathElem::TakeRange( + arena.alloc(TakeRange::new(dim_compo, pos, proj)), + )); + self + } + + pub fn forall(mut self, dim_compo: DimCompo) -> Self { + self.path.push(ExecPathElem::ForAll(dim_compo)); + self + } + + pub fn active_distrib_dim(&self) -> Option { + for e in self.path.iter().rev() { + if let ExecPathElem::ForAll(dim) = e { + return Some(*dim); + } + } + None + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum BaseExec<'a> { + Ident(Ident<'a>), + CpuThread, + 
GpuGrid(&'a Dim<'a>, &'a Dim<'a>), +} + +impl<'a> BaseExec<'a> { + pub fn clone_in(&self, arena: &'a Bump) -> Self { + match self { + BaseExec::Ident(ie) => BaseExec::Ident(ie.clone()), + BaseExec::CpuThread => BaseExec::CpuThread, + BaseExec::GpuGrid(gdim, bdim) => { + let ng = arena.alloc(gdim.clone_in(arena)); + let nb = arena.alloc(bdim.clone_in(arena)); + BaseExec::GpuGrid(ng, nb) + } // add other variants here if you have them + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum ExecPathElem<'a> { + TakeRange(&'a TakeRange<'a>), + ForAll(DimCompo), + ToWarps, + ToThreads(DimCompo), +} + +impl<'a> ExecPathElem<'a> { + pub fn clone_in(&self, arena: &'a Bump) -> Self { + match self { + ExecPathElem::TakeRange(tr) => ExecPathElem::TakeRange(tr.clone_in(arena)), + ExecPathElem::ForAll(d) => ExecPathElem::ForAll(*d), + ExecPathElem::ToWarps => ExecPathElem::ToWarps, + ExecPathElem::ToThreads(d) => ExecPathElem::ToThreads(*d), + } + } +} + +// ExecTy +// fn size(DimCompo) -> usize +// fn take_range(DimCompo, Nat) -> ExecTy +// fn elem_type(DimCompo) -> ExecTy +#[span_derive(PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] +pub struct ExecTy<'a> { + pub ty: ExecTyKind<'a>, + #[span_derive_ignore] + pub span: Option, +} + +impl<'a> ExecTy<'a> { + pub fn new(exec: ExecTyKind<'a>) -> Self { + ExecTy { + ty: exec, + span: None, + } + } + + pub fn clone_in(&self, _arena: &'a bumpalo::Bump) -> ExecTy<'a> { + ExecTy { + ty: self.ty.clone(), + span: self.span, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum ExecTyKind<'a> { + CpuThread, + GpuThread, + GpuWarp, + GpuBlock(Dim<'a>), + GpuGrid(Dim<'a>, Dim<'a>), + GpuToThreads(Dim<'a>, &'a ExecTy<'a>), + GpuThreadGrp(Dim<'a>), + GpuWarpGrp(Nat<'a>), + GpuBlockGrp(Dim<'a>, Dim<'a>), + Any, +} + +#[span_derive(PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] +pub struct Ty<'a> { + pub ty: TyKind<'a>, + #[span_derive_ignore] + pub span: Option, +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] 
+pub struct ParamSig<'a> { + pub exec_expr: ExecExpr<'a>, + pub ty: &'a Ty<'a>, +} + +impl<'a> ParamSig<'a> { + pub fn new(exec_expr: ExecExpr<'a>, ty: &'a Ty<'a>) -> Self { + ParamSig { exec_expr, ty } + } + + pub fn clone_in(&self, arena: &'a Bump) -> Self { + ParamSig { + exec_expr: self.exec_expr.clone_in(arena), + ty: { + let t = self.ty.clone_in(arena); + arena.alloc(t) + }, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug)] +pub struct FnTy<'a> { + pub generics: BumpVec<'a, IdentKinded<'a>>, + pub generic_exec: Option>, + pub param_sigs: BumpVec<'a, ParamSig<'a>>, + pub exec: ExecExpr<'a>, + pub ret_ty: &'a Ty<'a>, + pub nat_constrs: BumpVec<'a, NatConstr<'a>>, +} + +impl<'a> FnTy<'a> { + pub fn new( + arena: &'a Bump, + generics: impl IntoIterator>, + generic_exec: Option>, + param_sigs: impl IntoIterator>, + exec: ExecExpr<'a>, + ret_ty: &'a Ty<'a>, + nat_constrs: impl IntoIterator>, + ) -> Self { + let mut generics_vec = BumpVec::new_in(arena); + generics_vec.extend(generics); + + let mut param_vec = BumpVec::new_in(arena); + param_vec.extend(param_sigs); + + let mut nat_vec = BumpVec::new_in(arena); + nat_vec.extend(nat_constrs); + + FnTy { + generics: generics_vec, + generic_exec, + param_sigs: param_vec, + exec, + ret_ty: arena.alloc(ret_ty), + nat_constrs: nat_vec, + } + } + + pub fn clone_in(&self, arena: &'a Bump) -> FnTy<'a> { + // generics + let mut generics = BumpVec::new_in(arena); + generics.extend(self.generics.iter().cloned()); + + // optional generic exec + let generic_exec = self.generic_exec.as_ref().map(|ge| ge.clone_in(arena)); + + // param sigs + let mut param_sigs = BumpVec::new_in(arena); + for ps in self.param_sigs.iter() { + param_sigs.push(ps.clone_in(arena)); + } + + // exec expression + let exec = self.exec.clone_in(arena); + + // return type (allocate cloned Ty in the arena and store the &'a Ty) + let ret_ty_ref: &'a Ty<'a> = { + let ret = self.ret_ty.clone_in(arena); + arena.alloc(ret) + }; + + // nat constraints + let mut 
nat_constrs = BumpVec::new_in(arena); + for c in self.nat_constrs.iter() { + nat_constrs.push(c.clone_in(arena)); + } + + FnTy { + generics, + generic_exec, + param_sigs, + exec, + ret_ty: ret_ty_ref, + nat_constrs, + } + } +} + +impl<'a> Clone for FnTy<'a> { + fn clone(&self) -> Self { + let mut generics = BumpVec::new_in(self.generics.bump()); + generics.extend(self.generics.iter().cloned()); + + let mut param_sigs = BumpVec::new_in(self.param_sigs.bump()); + param_sigs.extend(self.param_sigs.iter().cloned()); + + let mut nat_constrs = BumpVec::new_in(self.nat_constrs.bump()); + nat_constrs.extend(self.nat_constrs.iter().cloned()); + + FnTy { + generics, + generic_exec: self.generic_exec.clone(), + param_sigs, + exec: self.exec.clone(), + ret_ty: self.ret_ty, + nat_constrs, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum NatConstr<'a> { + True, + Eq(&'a Nat<'a>, &'a Nat<'a>), + Lt(&'a Nat<'a>, &'a Nat<'a>), + And(&'a NatConstr<'a>, &'a NatConstr<'a>), + Or(&'a NatConstr<'a>, &'a NatConstr<'a>), +} + +impl<'a> NatConstr<'a> { + pub fn clone_in(&self, arena: &'a Bump) -> NatConstr<'a> { + match self { + NatConstr::True => NatConstr::True, + + NatConstr::Eq(l, r) => { + let l2 = arena.alloc((*l).clone_in(arena)); + let r2 = arena.alloc((*r).clone_in(arena)); + NatConstr::Eq(l2, r2) + } + + NatConstr::Lt(l, r) => { + let l2 = arena.alloc((*l).clone_in(arena)); + let r2 = arena.alloc((*r).clone_in(arena)); + NatConstr::Lt(l2, r2) + } + + NatConstr::And(a, b) => { + let a2 = arena.alloc((*a).clone_in(arena)); + let b2 = arena.alloc((*b).clone_in(arena)); + NatConstr::And(a2, b2) + } + + NatConstr::Or(a, b) => { + let a2 = arena.alloc((*a).clone_in(arena)); + let b2 = arena.alloc((*b).clone_in(arena)); + NatConstr::Or(a2, b2) + } + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum TyKind<'a> { + Data(&'a DataTy<'a>), + // (ty..) 
-[x:exec]-> ty + FnTy(&'a FnTy<'a>), +} + +// TODO remove +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum Constraint { + Copyable, +} + +impl<'a> Ty<'a> { + pub fn new(ty: TyKind<'a>) -> Self { + Ty { ty, span: None } + } + + pub fn with_span(ty: TyKind<'a>, span: Span) -> Ty<'a> { + Ty { + ty, + span: Some(span), + } + } + + pub fn dty(&self) -> &'a DataTy<'a> { + match &self.ty { + TyKind::Data(dty) => dty, + _ => panic!("Expected data type but found {:?}", self), + } + } + + pub fn copyable(&self) -> bool { + match &self.ty { + TyKind::Data(dty) => dty.copyable(), + TyKind::FnTy(_) => true, + } + } + + pub fn is_fully_alive(&self) -> bool { + match &self.ty { + TyKind::Data(dty) => dty.is_fully_alive(), + TyKind::FnTy(_) => true, + } + } + + pub fn contains_ref_to_prv(&self, prv_val_name: &str) -> bool { + match &self.ty { + TyKind::Data(dty) => dty.contains_ref_to_prv(prv_val_name), + TyKind::FnTy(fn_ty) => { + fn_ty + .param_sigs + .iter() + .any(|param_sig| param_sig.ty.contains_ref_to_prv(prv_val_name)) + || fn_ty.ret_ty.contains_ref_to_prv(prv_val_name) + } + } + } + + pub fn clone_in(&self, _arena: &'a bumpalo::Bump) -> Ty<'a> { + Ty { + ty: self.ty.clone(), + span: self.span, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct Dim1d<'a>(pub Nat<'a>); +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct Dim2d<'a>(pub Nat<'a>, pub Nat<'a>); +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct Dim3d<'a>(pub Nat<'a>, pub Nat<'a>, pub Nat<'a>); +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Dim<'a> { + XYZ(&'a Dim3d<'a>), + XY(&'a Dim2d<'a>), + XZ(&'a Dim2d<'a>), + YZ(&'a Dim2d<'a>), + X(&'a Dim1d<'a>), + Y(&'a Dim1d<'a>), + Z(&'a Dim1d<'a>), +} + +impl<'a> Dim<'a> { + pub fn new_3d(arena: &'a Bump, n1: Nat<'a>, n2: Nat<'a>, n3: Nat<'a>) -> Self { + Dim::XYZ(arena.alloc(Dim3d(n1, n2, n3))) + } + + pub fn new_2d(arena: &'a Bump, constr: F, n1: Nat<'a>, n2: Nat<'a>) -> Self + where + F: Fn(&'a Dim2d<'a>) -> 
Self + 'a, + { + constr(arena.alloc(Dim2d(n1, n2))) + } + + pub fn new_1d(arena: &'a Bump, constr: F, n: Nat<'a>) -> Self + where + F: Fn(&'a Dim1d<'a>) -> Self + 'a, + { + constr(arena.alloc(Dim1d(n))) + } + + pub fn equal(&self, nat_ctx: &'a NatCtx<'a>, other: &Self) -> NatEvalResult { + match (self, other) { + (Dim::XYZ(d), Dim::XYZ(o)) => Ok(d.0.eval(nat_ctx)? == o.0.eval(nat_ctx)? + && d.1.eval(nat_ctx)? == o.1.eval(nat_ctx)? + && d.2.eval(nat_ctx)? == o.2.eval(nat_ctx)?), + (Dim::XY(d), Dim::XY(o)) | (Dim::XZ(d), Dim::XZ(o)) | (Dim::YZ(d), Dim::YZ(o)) => { + Ok(d.0.eval(nat_ctx)? == o.0.eval(nat_ctx)? + && d.1.eval(nat_ctx)? == o.1.eval(nat_ctx)?) + } + (Dim::X(d), Dim::X(o)) | (Dim::Y(d), Dim::Y(o)) | (Dim::Z(d), Dim::Z(o)) => { + Ok(d.0.eval(nat_ctx)? == o.0.eval(nat_ctx)?) + } + _ => Ok(false), + } + } + + pub fn clone_in(&self, arena: &'a Bump) -> Self { + match *self { + Dim::XYZ(d) => Dim::new_3d(arena, d.0.clone(), d.1.clone(), d.2.clone()), + Dim::XY(d) => Dim::new_2d(arena, Dim::XY, d.0.clone(), d.1.clone()), + Dim::XZ(d) => Dim::new_2d(arena, Dim::XZ, d.0.clone(), d.1.clone()), + Dim::YZ(d) => Dim::new_2d(arena, Dim::YZ, d.0.clone(), d.1.clone()), + Dim::X(d) => Dim::new_1d(arena, Dim::X, d.0.clone()), + Dim::Y(d) => Dim::new_1d(arena, Dim::Y, d.0.clone()), + Dim::Z(d) => Dim::new_1d(arena, Dim::Z, d.0.clone()), + } + } +} + +#[derive(PartialEq, Eq, PartialOrd, Hash, Debug, Copy, Clone)] +pub enum DimCompo { + X, + Y, + Z, +} + +#[span_derive(PartialEq, Eq, Hash)] +#[derive(Debug, Clone)] +pub struct DataTy<'a> { + pub dty: DataTyKind<'a>, + // TODO remove with introduction of traits + pub constraints: BumpVec<'a, Constraint>, + #[span_derive_ignore] + pub span: Option, +} + +impl<'a> DataTy<'a> { + pub fn new(bump: &'a bumpalo::Bump, dty: DataTyKind<'a>) -> Self { + DataTy { + dty, + constraints: BumpVec::new_in(bump), + span: None, + } + } + + pub fn with_constr( + bump: &'a bumpalo::Bump, + dty: DataTyKind<'a>, + constraints: impl IntoIterator, + 
) -> Self { + let mut v = BumpVec::new_in(bump); + v.extend(constraints); + DataTy { + dty, + constraints: v, + span: None, + } + } + + pub fn with_span(bump: &'a bumpalo::Bump, dty: DataTyKind<'a>, span: Span) -> Self { + DataTy { + dty, + constraints: BumpVec::new_in(bump), + span: Some(span), + } + } + + pub fn non_copyable(&'a self) -> bool { + use DataTyKind::*; + + match &self.dty { + Scalar(_) => false, + Atomic(_) => false, + Ident(_) => true, + Ref(reff) => reff.own == Ownership::Uniq, + At(_, _) => true, + ArrayShape(_, _) => true, + Tuple(elem_tys) => elem_tys.iter().any(|ty| ty.non_copyable()), + Array(_, _) => false, + RawPtr(_) => true, + Range => true, + Dead(_) => panic!( + "This case is not expected to mean anything.\ + The type is dead. There is nothing we can do with it." + ), + } + } + + pub fn copyable(&'a self) -> bool { + !self.non_copyable() + } + + pub fn is_fully_alive(&'a self) -> bool { + use DataTyKind::*; + match &self.dty { + Scalar(_) + | RawPtr(_) + | Atomic(_) + | Ident(_) + | Ref(_) + | At(_, _) + | Array(_, _) + | ArrayShape(_, _) => true, + Tuple(elem_tys) => elem_tys + .iter() + .fold(true, |acc, dty| acc & dty.is_fully_alive()), + Struct(struct_decl) => struct_decl + .fields + .iter() + .fold(true, |acc, (_, dty)| acc & dty.is_fully_alive()), + Dead(_) => false, + } + } + + pub fn occurs_in(&'a self, dty: &DataTy) -> bool { + if self == dty { + return true; + } + match &dty.dty { + DataTyKind::Scalar(_) | DataTyKind::Ident(_) => false, + DataTyKind::Dead(_) => panic!("unexpected"), + DataTyKind::Atomic(aty) => &self.dty == &DataTyKind::Atomic(aty.clone()), + DataTyKind::Ref(reff) => self.occurs_in(&reff.dty), + DataTyKind::RawPtr(elem_dty) => self.occurs_in(elem_dty), + DataTyKind::Tuple(elem_dtys) => { + let mut found = false; + for elem_dty in elem_dtys { + found = self.occurs_in(elem_dty); + } + found + } + DataTyKind::Struct(struct_decl) => { + let mut found = false; + for (_, dty) in &struct_decl.fields { + found = 
self.occurs_in(dty) + } + found + } + DataTyKind::Array(elem_dty, _) => self.occurs_in(elem_dty), + DataTyKind::ArrayShape(elem_dty, _) => self.occurs_in(elem_dty), + DataTyKind::At(elem_dty, _) => self.occurs_in(elem_dty), + } + } + + pub fn contains_ref_to_prv(&'a self, prv_val_name: &str) -> bool { + use DataTyKind::*; + match &self.dty { + Scalar(_) | Atomic(_) | Ident(_) | Dead(_) => false, + Ref(reff) => { + let found_reference = if let Provenance::Value(prv_val_n) = &reff.rgn { + prv_val_name == *prv_val_n + } else { + false + }; + found_reference || reff.dty.contains_ref_to_prv(prv_val_name) + } + RawPtr(dty) => dty.contains_ref_to_prv(prv_val_name), + At(dty, _) => dty.contains_ref_to_prv(prv_val_name), + Array(dty, _) => dty.contains_ref_to_prv(prv_val_name), + ArrayShape(dty, _) => dty.contains_ref_to_prv(prv_val_name), + Tuple(elem_tys) => elem_tys + .iter() + .any(|ty| ty.contains_ref_to_prv(prv_val_name)), + Struct(struct_decl) => struct_decl + .fields + .iter() + .any(|(_, dty)| dty.contains_ref_to_prv(prv_val_name)), + } + } + + pub fn equal(&'a self, nat_ctx: &'a NatCtx<'a>, other: &'a Self) -> NatEvalResult<'a, bool> { + match (&self.dty, &other.dty) { + (DataTyKind::Ident(i), DataTyKind::Ident(o)) => Ok(i == o), + (DataTyKind::Tuple(dtys), DataTyKind::Tuple(dtyos)) => { + for (d, o) in dtys.iter().zip(dtyos) { + if !d.equal(nat_ctx, o)? { + return Ok(false); + } + } + Ok(true) + } + (DataTyKind::Ref(ref_dty), DataTyKind::Ref(ref_dtyo)) => Ok(ref_dty.own + == ref_dtyo.own + && ref_dty.rgn == ref_dtyo.rgn + && ref_dty.mem == ref_dtyo.mem + && ref_dty.dty.equal(nat_ctx, &ref_dtyo.dty)?), + (DataTyKind::Array(dty, n), DataTyKind::Array(dtyo, no)) + | (DataTyKind::ArrayShape(dty, n), DataTyKind::ArrayShape(dtyo, no)) => { + Ok(dty.equal(nat_ctx, dtyo)? && n.eval(nat_ctx)? == no.eval(nat_ctx)?) + } + (DataTyKind::At(dty, mem), DataTyKind::At(dtyo, memo)) => { + Ok(dty.equal(nat_ctx, dtyo)? 
&& mem == memo) + } + (DataTyKind::Struct(struct_decl), DataTyKind::Struct(struct_declo)) => { + Ok(struct_decl.ident == struct_declo.ident) + } + (DataTyKind::Atomic(aty), DataTyKind::Atomic(atyo)) => Ok(aty == atyo), + (DataTyKind::Scalar(sty), DataTyKind::Scalar(styo)) => Ok(sty == styo), + _ => Ok(false), + } + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> DataTy<'a> { + use DataTyKind::*; + + let dty: DataTyKind<'a> = match &self.dty { + // leaf cases (Copy/Clone-by-value) + Scalar(s) => Scalar(*s), + Atomic(a) => Atomic(*a), + Ident(id) => Ident(id.clone()), + + // pointer-carrying cases: clone payloads, allocate where needed + RawPtr(elem) => { + let cloned = elem.clone_in(arena); + RawPtr(arena.alloc(cloned)) + } + + Array(elem, n) => { + let cloned_e = elem.clone_in(arena); + let cloned_n = n.clone_in(arena); + Array(arena.alloc(cloned_e), cloned_n) + } + + ArrayShape(elem, n) => { + let cloned_e = elem.clone_in(arena); + let cloned_n = n.clone_in(arena); + ArrayShape(arena.alloc(cloned_e), cloned_n) + } + + At(elem, mem) => { + let cloned_e = elem.clone_in(arena); + let mem_val = (*mem).clone(); + At(arena.alloc(cloned_e), mem_val) + } + + Ref(r) => { + let inner = r.dty.clone_in(arena); + let reff = RefDty::new(arena, r.rgn.clone(), r.own, r.mem.clone(), inner); + Ref(arena.alloc(reff)) + } + + Tuple(elems) => { + let mut out = bumpalo::collections::Vec::new_in(arena); + out.extend(elems.iter().map(|e| e.clone_in(arena))); + Tuple(out) + } + + Struct(sd) => Struct(*sd), + + Dead(inner) => { + let cloned = inner.clone_in(arena); + Dead(arena.alloc(cloned)) + } + }; + + // clone constraints into this arena + let mut constraints = bumpalo::collections::Vec::new_in(arena); + constraints.extend(self.constraints.iter().cloned()); + + DataTy { + dty, + constraints, + span: self.span, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct RefDty<'a> { + pub rgn: Provenance<'a>, + pub own: Ownership, + pub mem: Memory<'a>, + pub dty: 
&'a DataTy<'a>, +} + +impl<'a> RefDty<'a> { + pub fn new( + arena: &'a Bump, + rgn: Provenance<'a>, + own: Ownership, + mem: Memory<'a>, + dty: DataTy<'a>, + ) -> Self { + RefDty { + rgn, + own, + mem, + dty: arena.alloc(dty), + } + } + + pub fn clone_in(&self, arena: &'a bumpalo::Bump) -> Self { + let cloned_dty: &'a DataTy<'a> = arena.alloc(self.dty.clone_in(arena)); + + RefDty { + rgn: self.rgn.clone_in(arena), + own: self.own, // Copy + mem: self.mem.clone_in(arena), + dty: cloned_dty, + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum DataTyKind<'a> { + Ident(Ident<'a>), + Scalar(ScalarTy), + Atomic(AtomicTy), + Array(&'a DataTy<'a>, Nat<'a>), + ArrayShape(&'a DataTy<'a>, Nat<'a>), + Tuple(BumpVec<'a, DataTy<'a>>), + Struct(&'a StructDecl<'a>), + At(&'a DataTy<'a>, Memory<'a>), + Ref(&'a RefDty<'a>), + RawPtr(&'a DataTy<'a>), + //Range, + // TODO remove. This is an attribute of a typing context entry, not the type. + // Only for type checking purposes. + Dead(&'a DataTy<'a>), +} + +#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)] +pub enum ScalarTy { + Unit, + U8, + U32, + U64, + I32, + I64, + F32, + F64, + Bool, + Gpu, +} + +#[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)] +pub enum AtomicTy { + AtomicU32, + AtomicI32, +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Provenance<'a> { + Value(&'a str), + Ident(Ident<'a>), +} + +impl<'a> Provenance<'a> { + pub fn clone_in(&self, arena: &'a Bump) -> Self { + match self { + Provenance::Value(s) => Provenance::Value(arena.alloc_str(s)), + Provenance::Ident(id) => Provenance::Ident(id.clone()), + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Memory<'a> { + CpuMem, + GpuGlobal, + GpuShared, + GpuLocal, + Ident(Ident<'a>), +} + +impl<'a> Memory<'a> { + pub fn clone_in(&self, _arena: &'a Bump) -> Self { + match self { + Memory::CpuMem => Memory::CpuMem, + Memory::GpuGlobal => Memory::GpuGlobal, + Memory::GpuShared => Memory::GpuShared, + Memory::GpuLocal => 
Memory::GpuLocal, + Memory::Ident(id) => Memory::Ident(id.clone()), + } + } +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub struct PrvRel<'a> { + pub longer: Ident<'a>, + pub shorter: Ident<'a>, +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub struct IdentKinded<'a> { + pub ident: Ident<'a>, + pub kind: Kind, +} + +impl<'a> IdentKinded<'a> { + pub fn new(ident: &Ident<'a>, kind: Kind) -> Self { + IdentKinded { + ident: ident.clone(), + kind, + } + } +} + +#[derive(PartialEq, Eq, Debug, Clone)] +pub enum NatRange<'a> { + Simple { lower: Nat<'a>, upper: Nat<'a> }, + Halved { upper: Nat<'a> }, + Doubled { upper: Nat<'a> }, +} + +impl<'a> NatRange<'a> { + pub fn lift(&self, arena: &'a Bump, nat_ctx: &'a NatCtx<'a>) -> NatEvalResult { + let range_iter = match self { + NatRange::Simple { lower, upper } => { + let lower = lower.eval(nat_ctx)?; + let upper = upper.eval(nat_ctx)?; + NatRangeIter::new( + lower, + arena.alloc(|x| x + 1), + arena.alloc(move |c| c >= upper), + ) + } + NatRange::Halved { upper } => { + let upper = upper.eval(nat_ctx)?; + NatRangeIter::new(upper, arena.alloc(|x| x / 2), arena.alloc(|c| c == 0)) + } + NatRange::Doubled { upper } => { + let upper = upper.eval(nat_ctx)?; + NatRangeIter::new(1, arena.alloc(|x| x * 2), arena.alloc(move |c| c >= upper)) + } + }; + Ok(range_iter) + } +} + +pub struct NatRangeIter<'a> { + current: usize, + // go from current to next value + step_fun: &'a dyn Fn(usize) -> usize, + // determine whether the current value is still within range + end_cond: &'a dyn Fn(usize) -> bool, +} + +impl<'a> NatRangeIter<'a> { + fn new( + start: usize, + step_fun: &'a dyn Fn(usize) -> usize, + end_cond: &'a dyn Fn(usize) -> bool, + ) -> Self { + NatRangeIter { + current: start, + step_fun, + end_cond, + } + } +} + +impl<'a> Iterator for NatRangeIter<'a> { + type Item = usize; + + fn next(&mut self) -> Option { + self.current = (self.step_fun)(self.current); + + if !(self.end_cond)(self.current) { + Some(self.current) + } else { 
+ None + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone)] +pub enum Nat<'a> { + Ident(Ident<'a>), + Lit(usize), + ThreadIdx(DimCompo), + BlockIdx(DimCompo), + BlockDim(DimCompo), + WarpGrpIdx, + WarpIdx, + LaneIdx, + // Dummy that is always 0, i.e. equivalent to Lit(0) + GridIdx, + BinOp(BinOpNat, &'a Nat<'a>, &'a Nat<'a>), + // Use Box<[Nat]> to safe 8 bytes compared to Vec + App(Ident<'a>, BumpVec<'a, Nat<'a>>), +} + +pub struct NatCtx<'a> { + frames: BumpVec<'a, BumpVec<'a, (&'a str, usize)>>, +} + +impl<'a> NatCtx<'a> { + pub fn new(arena: &'a Bump) -> Self { + let mut frames = BumpVec::new_in(arena); + frames.push(BumpVec::new_in(arena)); + NatCtx { frames } + } + + pub fn with_frame(arena: &'a Bump, frame: BumpVec<'a, (&'a str, usize)>) -> Self { + let mut frames = BumpVec::new_in(arena); + frames.push(frame); + NatCtx { frames } + } + + pub fn append(&mut self, nat_name: &str, val: usize, arena: &'a Bump) -> &mut Self { + let interned: &'a str = arena.alloc_str(nat_name); + let frame = self.frames.last_mut().unwrap(); + frame.push((interned, val)); + self + } + + pub fn find(&self, name: &str) -> Option { + self.frames.iter().flatten().rev().find_map( + |(i, n)| { + if *i == name { + Some(*n) + } else { + None + } + }, + ) + } + + pub fn push_empty_frame(&mut self, arena: &'a Bump) -> &mut Self { + self.frames.push(BumpVec::new_in(arena)); + self + } + + fn push_frame(&mut self, frame: BumpVec<'a, (&'a str, usize)>) -> &mut Self { + self.frames.push(frame); + self + } + + pub fn pop_frame(&mut self) -> &mut Self { + self.frames.pop().expect("There must always be a scope."); + self + } +} + +#[derive(Debug)] +pub struct NatEvalError<'a> { + pub unevaluable: Nat<'a>, +} + +pub type NatEvalResult<'a, T> = Result>; + +impl<'a> Nat<'a> { + pub fn eval(&self, nat_ctx: &'a NatCtx<'a>) -> NatEvalResult { + match self { + Nat::GridIdx + | Nat::BlockIdx(_) + | Nat::BlockDim(_) + | Nat::ThreadIdx(_) + | Nat::WarpGrpIdx + | Nat::WarpIdx + | Nat::LaneIdx => 
Err(NatEvalError { + unevaluable: self.clone(), + }), + Nat::Ident(i) => { + if let Some(n) = nat_ctx.find(&i.name) { + Ok(n) + } else { + Err(NatEvalError { + unevaluable: self.clone(), + }) + } + } + Nat::Lit(n) => Ok(*n), + Nat::BinOp(op, l, r) => match op { + BinOpNat::Add => Ok(l.eval(nat_ctx)? + r.eval(nat_ctx)?), + BinOpNat::Sub => Ok(l.eval(nat_ctx)? - r.eval(nat_ctx)?), + BinOpNat::Mul => Ok(l.eval(nat_ctx)? * r.eval(nat_ctx)?), + BinOpNat::Div => Ok(l.eval(nat_ctx)? / r.eval(nat_ctx)?), + BinOpNat::Mod => Ok(l.eval(nat_ctx)? % r.eval(nat_ctx)?), + }, + Nat::App(_, _) => unimplemented!(), + } + } + + pub fn new_binop_ref(arena: &'a Bump, op: BinOpNat, lhs: Nat<'a>, rhs: Nat<'a>) -> Self { + let l_ref = arena.alloc(lhs); + let r_ref = arena.alloc(rhs); + Nat::BinOp(op, l_ref, r_ref) + } + + pub fn clone_in(&self, arena: &'a Bump) -> Nat<'a> { + match self { + Nat::Ident(id) => Nat::Ident(id.clone()), + Nat::Lit(n) => Nat::Lit(*n), + + Nat::ThreadIdx(c) => Nat::ThreadIdx(*c), + Nat::BlockIdx(c) => Nat::BlockIdx(*c), + Nat::BlockDim(c) => Nat::BlockDim(*c), + Nat::WarpGrpIdx => Nat::WarpGrpIdx, + Nat::WarpIdx => Nat::WarpIdx, + Nat::LaneIdx => Nat::LaneIdx, + Nat::GridIdx => Nat::GridIdx, + + Nat::BinOp(op, l, r) => { + let lc: &'a Nat<'a> = arena.alloc(l.clone_in(arena)); + let rc: &'a Nat<'a> = arena.alloc(r.clone_in(arena)); + Nat::BinOp(*op, lc, rc) + } + + Nat::App(id, args) => { + let mut new_args: BumpVec<'a, Nat<'a>> = BumpVec::new_in(arena); + new_args.extend(args.iter().map(|a| a.clone_in(arena))); + Nat::App(id.clone(), new_args) + } + } + } +} + +#[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] +pub enum BinOpNat { + Add, + Sub, + Mul, + Div, + Mod, +} + +// When changing the AST, the types can quickly grow and lead to stack overflows in the different +// compiler stages. +// +// Taken from the rustc implementation and adjusted for this AST: +// Some nodes are used a lot. Make sure they don't unintentionally get bigger. 
+#[cfg(all(target_arch = "x86_64", target_pointer_width = "64"))] +mod size_asserts { + use super::*; + // Type size assertion. The first argument is a type and the second argument is its expected size. + macro_rules! static_assert_size { + ($ty:ty, $size:expr) => { + const _: [(); $size] = [(); ::std::mem::size_of::<$ty>()]; + }; + } + static_assert_size!(Dim, 16); + static_assert_size!(DataTy, 128); + static_assert_size!(DataTyKind, 80); + static_assert_size!(ExecExpr, 32); + static_assert_size!(ExecExprKind, 64); + static_assert_size!(ExecPathElem, 16); + static_assert_size!(ExecTy, 80); + static_assert_size!(ExecTyKind, 64); + static_assert_size!(Expr, 112); + static_assert_size!(ExprKind, 88); + static_assert_size!(FunDef, 216); + static_assert_size!(Ident, 32); // maybe too large? + static_assert_size!(IdentExec, 40); + static_assert_size!(Lit, 16); + static_assert_size!(Memory, 32); + static_assert_size!(Nat, 64); + static_assert_size!(ParamDecl, 80); + static_assert_size!(Pattern, 40); + static_assert_size!(PlaceExpr, 56); + static_assert_size!(PlaceExprKind, 32); + static_assert_size!(ScalarTy, 1); + static_assert_size!(Ty, 32); + static_assert_size!(TyKind, 16); +} diff --git a/src/arena_ast/printer.rs b/src/arena_ast/printer.rs new file mode 100644 index 00000000..bbb2d20b --- /dev/null +++ b/src/arena_ast/printer.rs @@ -0,0 +1,380 @@ +use crate::arena_ast::{ + AtomicTy, BaseExec, BinOpNat, DataTy, DataTyKind, Dim, DimCompo, ExecExpr, ExecPathElem, + ExecTy, ExecTyKind, FnTy, Ident, IdentExec, IdentKinded, Kind, LeftOrRight, Memory, Nat, + Ownership, ParamSig, Provenance, ScalarTy, TakeRange, Ty, TyKind, +}; +use std::fmt::Write; + +pub struct PrintState { + string: String, +} + +macro_rules! print_list { + ($print_state: ident, $print_fun: path, $list: expr) => { + for elem in $list { + $print_fun($print_state, elem); + $print_state.string.push(','); + } + }; +} + +macro_rules! 
print_static_list { + ($print_state: ident, $print_fun: path, $($item: expr),*) => { + $( + $print_fun($print_state, $item); + $print_state.string.push(','); + )* + }; +} + +impl PrintState { + pub fn new() -> Self { + PrintState { + string: String::new(), + } + } + + pub fn get(&self) -> String { + self.string.clone() + } + + fn print_ident(&mut self, ident: &Ident) { + self.string.push_str(&ident.name); + } + + fn print_ty(&mut self, ty: &Ty) { + match &ty.ty { + TyKind::FnTy(fn_ty) => { + self.print_fn_ty(fn_ty); + } + TyKind::Data(dty) => self.print_dty(dty), + } + } + + fn print_fn_ty(&mut self, fn_ty: &FnTy) { + self.string.push('<'); + print_list!(self, Self::print_ident_kinded, &fn_ty.generics); + self.string.push_str(">("); + print_list!(self, Self::print_param_sig, &fn_ty.param_sigs); + self.string.push_str(") -["); + if let Some(ident_exec) = &fn_ty.generic_exec { + self.print_ident_exec(ident_exec); + } else { + self.print_exec_expr(&fn_ty.exec); + } + self.string.push_str("]-> "); + self.print_ty(&fn_ty.ret_ty); + } + + fn print_param_sig(&mut self, param_sig: &ParamSig) { + self.print_exec_expr(¶m_sig.exec_expr); + self.print_ty(¶m_sig.ty); + } + + fn print_ident_exec(&mut self, ident_exec: &IdentExec) { + self.print_ident(&ident_exec.ident); + self.print_exec_ty(&ident_exec.ty); + } + + fn print_ident_kinded(&mut self, ident_kinded: &IdentKinded) { + self.print_ident(&ident_kinded.ident); + self.print_kind(ident_kinded.kind); + } + + fn print_kind(&mut self, kind: Kind) { + let kind_str = match kind { + Kind::DataTy => "dty", + Kind::Provenance => "prv", + Kind::Memory => "mem", + Kind::Nat => "nat", + }; + self.string.push_str(kind_str); + } + + pub fn print_exec_ty(&mut self, exec_ty: &ExecTy) { + match &exec_ty.ty { + ExecTyKind::CpuThread => self.string.push_str("cpu.thread"), + ExecTyKind::GpuGrid(gdim, bdim) => { + self.string.push_str("gpu.grid<"); + print_static_list!(self, Self::print_dim, gdim, bdim); + self.string.push('>'); + } + 
ExecTyKind::GpuToThreads(dim, exec_ty) => { + self.string.push_str("gpu.global_threads<"); + self.print_dim(dim); + self.string.push_str(", "); + self.print_exec_ty(exec_ty); + self.string.push('>'); + } + ExecTyKind::GpuBlock(bdim) => { + self.string.push_str("gpu.block<"); + self.print_dim(bdim); + self.string.push('>'); + } + ExecTyKind::GpuThread => self.string.push_str("gpu.thread"), + ExecTyKind::GpuBlockGrp(gdim, bdim) => { + self.string.push_str("gpu.block_grp<"); + print_static_list!(self, Self::print_dim, gdim, bdim); + self.string.push('>'); + } + ExecTyKind::GpuThreadGrp(dim) => { + self.string.push_str("gpu.thread_grp<"); + self.print_dim(dim); + self.string.push('>'); + } + ExecTyKind::Any => self.string.push_str("view"), + ExecTyKind::GpuWarpGrp(n) => { + self.string.push_str("gpu.warp_grp<"); + self.print_nat(n); + self.string.push('>'); + } + ExecTyKind::GpuWarp => {} + } + } + + pub fn print_exec_expr(&mut self, exec_expr: &ExecExpr) { + match &exec_expr.exec.base { + BaseExec::Ident(ident) => self.print_ident(ident), + BaseExec::CpuThread => self.string.push_str("cpu.thread"), + BaseExec::GpuGrid(gdim, bdim) => { + self.string.push_str("gpu.grid<"); + self.print_dim(gdim); + self.string.push_str(", "); + self.print_dim(bdim); + self.string.push('>'); + } + } + for pe in &exec_expr.exec.path { + self.string.push('.'); + match pe { + ExecPathElem::TakeRange(split_proj) => self.print_take_range(split_proj), + ExecPathElem::ForAll(dim_compo) => { + self.string.push_str("forall("); + self.print_dim_compo(dim_compo); + self.string.push(')'); + } + ExecPathElem::ToThreads(dim_compo) => { + self.string.push_str("to_threads("); + self.print_dim_compo(dim_compo); + self.string.push(')'); + } + ExecPathElem::ToWarps => self.string.push_str("to_warps"), + } + } + } + + fn print_take_range(&mut self, take_range: &TakeRange) { + self.string.push('['); + self.print_dim_compo(&take_range.split_dim); + self.string.push_str("; "); + match &take_range.left_or_right 
{ + LeftOrRight::Left => { + self.string.push_str(".."); + self.print_nat(&take_range.pos); + } + LeftOrRight::Right => { + self.print_nat(&take_range.pos); + self.string.push_str(".."); + } + } + self.string.push(']'); + } + + fn print_dim(&mut self, dim: &Dim) { + match dim { + Dim::XYZ(dim3d) => { + self.string.push_str("XYZ<"); + print_static_list!(self, Self::print_nat, &dim3d.0, &dim3d.1, &dim3d.2); + } + Dim::XY(dim2d) => { + self.string.push_str("XY<"); + print_static_list!(self, Self::print_nat, &dim2d.0, &dim2d.1); + } + Dim::XZ(dim2d) => { + self.string.push_str("XZ<"); + print_static_list!(self, Self::print_nat, &dim2d.0, &dim2d.1); + } + Dim::YZ(dim2d) => { + self.string.push_str("YZ<"); + print_static_list!(self, Self::print_nat, &dim2d.0, &dim2d.1); + } + Dim::X(dim1d) => { + self.string.push_str("X<"); + self.print_nat(&dim1d.0) + } + Dim::Y(dim1d) => { + self.string.push_str("Y<"); + self.print_nat(&dim1d.0) + } + Dim::Z(dim1d) => { + self.string.push_str("Z<"); + self.print_nat(&dim1d.0) + } + } + self.string.push('>'); + } + + fn print_dim_compo(&mut self, dim_compo: &DimCompo) { + match dim_compo { + DimCompo::X => self.string.push('X'), + DimCompo::Y => self.string.push('Y'), + DimCompo::Z => self.string.push('Z'), + } + } + + pub fn print_aty(&mut self, aty: &AtomicTy) { + match &aty { + AtomicTy::AtomicU32 => self.string.push_str("AtomicU32"), + AtomicTy::AtomicI32 => self.string.push_str("AtomicI32"), + } + } + + pub fn print_dty(&mut self, dty: &DataTy) { + match &dty.dty { + DataTyKind::Ident(ident) => self.print_ident(ident), + DataTyKind::Scalar(sty) => self.print_sty(sty), + DataTyKind::Atomic(aty) => self.print_aty(aty), + DataTyKind::Array(dty, n) => { + self.string.push('['); + self.print_dty(dty); + self.string.push(';'); + self.print_nat(n); + self.string.push(']'); + } + DataTyKind::ArrayShape(dty, n) => { + self.string.push_str("[["); + self.print_dty(dty); + self.string.push(';'); + self.print_nat(n); + 
self.string.push_str("]]"); + } + DataTyKind::Tuple(dtys) => { + self.string.push('('); + print_list!(self, Self::print_dty, dtys); + self.string.push(')'); + } + DataTyKind::Struct(struct_decl) => { + self.string.push_str("struct "); + self.print_ident(&struct_decl.ident); + self.string.push_str(" { "); + print_list!(self, Self::print_field, &struct_decl.fields); + self.string.push_str(" }"); + } + DataTyKind::At(dty, mem) => { + self.print_dty(dty); + self.string.push('@'); + self.print_mem(mem); + } + DataTyKind::Ref(ref_dty) => { + self.string.push('&'); + self.print_prv(&ref_dty.rgn); + self.string.push(' '); + self.print_own(ref_dty.own); + self.string.push(' '); + self.print_mem(&ref_dty.mem); + self.string.push(' '); + self.print_dty(&ref_dty.dty); + } + DataTyKind::RawPtr(_) => { + unimplemented!() + } + DataTyKind::Dead(dty) => self.print_dty(dty), + } + } + + fn print_field(&mut self, field: &(Ident, DataTy)) { + self.print_ident(&field.0); + self.string.push_str(" : "); + self.print_dty(&field.1) + } + + fn print_sty(&mut self, sty: &ScalarTy) { + match sty { + ScalarTy::Unit => self.string.push_str("()"), + ScalarTy::U32 => self.string.push_str("u32"), + ScalarTy::U64 => self.string.push_str("u64"), + ScalarTy::I32 => self.string.push_str("i32"), + ScalarTy::I64 => self.string.push_str("i64"), + ScalarTy::F32 => self.string.push_str("f32"), + ScalarTy::F64 => self.string.push_str("f64"), + ScalarTy::Bool => self.string.push_str("bool"), + ScalarTy::Gpu => self.string.push_str("Gpu"), + ScalarTy::U8 => self.string.push_str("u8"), + } + } + + fn print_mem(&mut self, mem: &Memory) { + match mem { + Memory::CpuMem => self.string.push_str("cpu.mem"), + Memory::GpuGlobal => self.string.push_str("gpu.global"), + Memory::GpuShared => self.string.push_str("gpu.shared"), + Memory::GpuLocal => self.string.push_str("gpu.local"), + Memory::Ident(x) => self.print_ident(x), + } + } + + fn print_prv(&mut self, prv: &Provenance) { + match prv { + 
Provenance::Value(name) => self.string.push_str(&name), + Provenance::Ident(ident) => self.print_ident(ident), + } + } + + fn print_own(&mut self, own: Ownership) { + match own { + Ownership::Shrd => self.string.push_str("shrd"), + Ownership::Uniq => self.string.push_str("uniq"), + } + } + + fn print_nat(&mut self, n: &Nat) { + match n { + Nat::Ident(ident) => self.print_ident(ident), + Nat::GridIdx => {} // print nothing + Nat::BlockIdx(d) => { + self.string.push_str("blockIdx."); + self.print_dim_compo(d); + } + Nat::BlockDim(d) => { + self.string.push_str("blockDim."); + self.print_dim_compo(d); + } + Nat::ThreadIdx(d) => { + self.string.push_str("threadIdx."); + self.print_dim_compo(d); + } + Nat::Lit(n) => write!(&mut self.string, "{}", n).unwrap(), + Nat::BinOp(op, lhs, rhs) => { + self.string.push('('); + self.print_nat(lhs); + self.print_bin_op_nat(op); + self.print_nat(rhs); + self.string.push(')'); + } + Nat::App(func, args) => { + self.string.push_str("{}("); + self.print_ident(func); + if let Some((last, leading)) = args.split_last() { + for arg in leading { + self.print_nat(arg); + } + self.print_nat(last); + self.string.push(')'); + } + } + Nat::WarpGrpIdx => self.string.push_str("warpGrpIdx"), + Nat::WarpIdx => self.string.push_str("warpIdx"), + Nat::LaneIdx => self.string.push_str("laneIdx"), + } + } + + fn print_bin_op_nat(&mut self, bin_op_nat: &BinOpNat) { + match bin_op_nat { + BinOpNat::Add => self.string.push('+'), + BinOpNat::Sub => self.string.push('-'), + BinOpNat::Mul => self.string.push('*'), + BinOpNat::Div => self.string.push('/'), + BinOpNat::Mod => self.string.push('%'), + } + } +} diff --git a/src/arena_ast/utils.rs b/src/arena_ast/utils.rs new file mode 100644 index 00000000..e458f738 --- /dev/null +++ b/src/arena_ast/utils.rs @@ -0,0 +1,469 @@ +use crate::arena_ast::visit::walk_list; +use crate::arena_ast::visit::Visit; +use crate::arena_ast::visit_mut::walk_list as walk_list_mut; +use crate::arena_ast::visit_mut::VisitMut; +use 
crate::arena_ast::{ + visit, visit_mut, ArgKinded, BaseExec, DataTy, DataTyKind, Dim, ExecExpr, ExecExprKind, ExecTy, + Expr, ExprKind, FnTy, FunDef, Ident, IdentExec, IdentKinded, Kind, Memory, Nat, ParamSig, + Provenance, Ty, TyKind, +}; +use std::collections::{HashMap, HashSet}; +use std::sync::atomic::{AtomicI32, Ordering}; + +static mut COUNTER: AtomicI32 = AtomicI32::new(0); + +/** +pub(crate) fn fresh_ident<'a, F, R>(arena: &'a bumpalo::Bump, name: &str, ident_constr: F) -> R +where + F: Fn(Ident) -> R, +{ + ident_constr(Ident::new_impli(&arena, &fresh_name(name))) +} +*/ + +pub(crate) fn fresh_ident<'a, F, R>(arena: &'a bumpalo::Bump, name: &'a str, ident_constr: F) -> R +where + F: FnOnce(Ident<'a>) -> R, +{ + let id = Ident::new(arena, name); + ident_constr(id) +} + +pub(crate) fn fresh_name(name: &str) -> String { + let prefix = format!("${}", name); + let i; + unsafe { + i = COUNTER.fetch_add(1, Ordering::SeqCst); + } + format!("{}_{}", prefix, i) +} + +pub fn implicit_idents<'a>(f: &FunDef<'a>) -> Option>> { + struct ImplicitIdents<'b>(HashSet>); + impl<'b> Visit<'b> for ImplicitIdents<'b> { + fn visit_ident(&mut self, ident: &Ident<'b>) { + if ident.is_implicit { + self.0.insert(ident.clone()); + } + } + } + + let mut impl_idents = ImplicitIdents(HashSet::new()); + impl_idents.visit_fun_def(f); + if impl_idents.0.is_empty() { + None + } else { + Some(impl_idents.0) + } +} + +pub trait VisitableMut<'a> { + fn visit_mut>(&mut self, visitor: &mut V, arena: &'a bumpalo::Bump); +} + +macro_rules! 
visitable_mut { + ($t:ident, $f:ident) => { + impl<'a> VisitableMut<'a> for $t<'a> { + fn visit_mut>(&mut self, visitor: &mut V, arena: &'a bumpalo::Bump) { + visitor.$f(arena, self); + } + } + }; +} + +visitable_mut!(Ty, visit_ty); +visitable_mut!(Expr, visit_expr); +visitable_mut!(ExecExpr, visit_exec_expr); +visitable_mut!(IdentExec, visit_ident_exec); +visitable_mut!(ParamSig, visit_param_sig); +visitable_mut!(FnTy, visit_fn_ty); + +/* + * gen_idents: a list of generic identifiers to be substituted (this list can be longer than + * gen_args. In that case, only the first gen_args.len() many identifiers are substituted. + * gen_args: the kinded expressions that are substituting the generic identifiers + * t: the term to substitute in + */ +pub fn subst_idents_kinded<'a, I, J, T: VisitableMut<'a>>( + arena: &'a bumpalo::Bump, + gen_idents: I, + gen_args: J, + t: &mut T, +) where + I: IntoIterator>, + J: IntoIterator>, +{ + let subst_map: HashMap<&'a str, &'a ArgKinded<'a>> = gen_idents + .into_iter() + .map(|p| p.ident.name.as_ref()) + .zip(gen_args) + .collect(); + + let mut v = SubstIdentsKinded::new(&subst_map); + t.visit_mut(&mut v, arena); +} + +pub fn subst_ident_exec<'a, T: VisitableMut<'a>>( + arena: &'a bumpalo::Bump, + ident: &'a Ident<'a>, + exec: &'a ExecExpr<'a>, + t: &mut T, +) { + let mut subst_ident_exec = SubstIdentExec::new(ident, exec); + t.visit_mut(&mut subst_ident_exec, arena); +} + +/* + * substitute kinded arguments for free identifiers + * + * When substituting within a function definition or function type, the generic parameters are + * bound. In order to substitute generic identifiers with their arguments, the relevant generic + * identifiers must be removed from the list, first. 
+ */ +struct SubstIdentsKinded<'a, 'm> { + pub subst_map: &'m HashMap<&'a str, &'a ArgKinded<'a>>, + pub bound_idents: HashSet>, +} + +impl<'a, 'm> SubstIdentsKinded<'a, 'm> { + fn new(subst_map: &'m HashMap<&'a str, &'a ArgKinded<'a>>) -> Self { + Self { + subst_map, + bound_idents: HashSet::new(), + } + } + + fn with_bound_idents(&self, bound_idents: HashSet>) -> Self { + Self { + subst_map: self.subst_map, + bound_idents, + } + } +} + +impl<'a, 'm> VisitMut<'a> for SubstIdentsKinded<'a, 'm> { + fn visit_nat(&mut self, arena: &'a bumpalo::Bump, nat: &mut Nat<'a>) { + match nat { + Nat::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Nat); + if !self.bound_idents.contains(&ident_kinded) { + if let Some(ArgKinded::Nat(nat_arg)) = + self.subst_map.get::(ident.name.as_ref()) + { + *nat = nat_arg.clone() + } + } + } + _ => visit_mut::walk_nat(self, arena, nat), + } + } + + fn visit_mem(&mut self, arena: &'a bumpalo::Bump, mem: &mut Memory<'a>) { + match mem { + Memory::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Memory); + if !self.bound_idents.contains(&ident_kinded) { + if let Some(ArgKinded::Memory(mem_arg)) = + self.subst_map.get::(ident.name.as_ref()) + { + *mem = mem_arg.clone() + } + } + } + _ => visit_mut::walk_mem(self, arena, mem), + } + } + + fn visit_prv(&mut self, arena: &'a bumpalo::Bump, prv: &mut Provenance<'a>) { + match prv { + Provenance::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Provenance); + if !self.bound_idents.contains(&ident_kinded) { + if let Some(ArgKinded::Provenance(prv_arg)) = + self.subst_map.get::(ident.name.as_ref()) + { + *prv = prv_arg.clone() + } + } + } + _ => visit_mut::walk_prv(self, arena, prv), + } + } + + fn visit_dty(&mut self, arena: &'a bumpalo::Bump, dty: &mut DataTy<'a>) { + match &mut dty.dty { + DataTyKind::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::DataTy); + if !self.bound_idents.contains(&ident_kinded) { + if let 
Some(ArgKinded::DataTy(dty_arg)) = + self.subst_map.get::(ident.name.as_ref()) + { + *dty = dty_arg.clone() + } + } + } + _ => visit_mut::walk_dty(self, arena, dty), + } + } + + // add generic paramters to list of bound identifiers + fn visit_fn_ty(&mut self, arena: &'a bumpalo::Bump, fn_ty: &mut FnTy<'a>) { + let fun_bound_idents = fn_ty.generics.clone(); + let mut all_bound_idents = self.bound_idents.clone(); + all_bound_idents.extend(fun_bound_idents); + let mut visitor_subst_generic_ident = + SubstIdentsKinded::with_bound_idents(self, all_bound_idents); + walk_list_mut!( + &mut visitor_subst_generic_ident, + visit_param_sig, + &mut fn_ty.param_sigs.as_mut_slice(), + arena + ); + if let Some(ident_exec) = &mut fn_ty.generic_exec { + let mut owned = (*ident_exec.ty).clone(); + self.visit_exec_ty(&mut owned); + ident_exec.ty = arena.alloc(owned); + } + + visitor_subst_generic_ident.visit_exec_expr(arena, &mut fn_ty.exec); + let mut ret_owned = (*fn_ty.ret_ty).clone(); + self.visit_ty(arena, &mut ret_owned); + fn_ty.ret_ty = arena.alloc(ret_owned); + } + + // only required to introduce a new scope of bound identifiers + fn visit_expr(&mut self, arena: &'a bumpalo::Bump, expr: &mut Expr<'a>) { + match &mut expr.expr { + ExprKind::ForNat(ident, collec, body) => { + let mut range_owned = (**collec).clone(); + self.visit_nat_range(arena, &mut range_owned); + *collec = arena.alloc(range_owned); + let mut scoped_bound_idents = self.bound_idents.clone(); + scoped_bound_idents.extend(std::iter::once(IdentKinded::new(ident, Kind::Nat))); + let mut subst_inner_kinded_idents = + SubstIdentsKinded::with_bound_idents(self, scoped_bound_idents); + let mut body_owned = (**body).clone(); + subst_inner_kinded_idents.visit_expr(arena, &mut body_owned); + *body = arena.alloc(body_owned); + } + _ => visit_mut::walk_expr(self, arena, expr), + } + } + + // add generic paramters to list of bound identifiers + fn visit_fun_def(&mut self, arena: &'a bumpalo::Bump, fun_def: &mut FunDef<'a>) 
{ + let fun_bound_idents = fun_def.generic_params.clone(); + let mut all_bound_idents = self.bound_idents.clone(); + all_bound_idents.extend(fun_bound_idents); + let mut subst_fun_free_kind_idents = + SubstIdentsKinded::with_bound_idents(self, all_bound_idents); + walk_list_mut!( + &mut subst_fun_free_kind_idents, + visit_param_decl, + &mut fun_def.param_decls.as_mut_slice(), + arena + ); + let mut ret_owned = (*fun_def.ret_dty).clone(); + self.visit_dty(arena, &mut ret_owned); + fun_def.ret_dty = arena.alloc(ret_owned); + for ident_exec in &mut fun_def.generic_exec { + subst_fun_free_kind_idents.visit_ident_exec(arena, ident_exec); + } + subst_fun_free_kind_idents.visit_exec_expr(arena, &mut fun_def.exec); + walk_list_mut!( + subst_fun_free_kind_idents, + visit_prv_rel, + &mut fun_def.prv_rels.as_mut_slice(), + arena + ); + let mut body_owned = (*fun_def.body).clone(); + self.visit_block(arena, &mut body_owned); + fun_def.body = arena.alloc(body_owned); + } +} + +/* + * Substitue a generic exec identifier with specific exec. + * This substitution ignores whehter an execution identifier is bound by a function type. 
+ */ +struct SubstIdentExec<'a> { + pub ident: &'a Ident<'a>, + pub exec: &'a ExecExpr<'a>, +} + +impl<'a> SubstIdentExec<'a> { + fn new(ident: &'a Ident<'a>, exec: &'a ExecExpr<'a>) -> Self { + SubstIdentExec { ident, exec } + } +} + +impl<'a> VisitMut<'a> for SubstIdentExec<'a> { + fn visit_exec_expr(&mut self, arena: &'a bumpalo::Bump, exec_expr: &mut ExecExpr<'a>) { + insert_for_ident(arena, self.exec, self.ident, exec_expr) + } +} + +fn insert_for_ident<'a>( + bump: &'a bumpalo::Bump, + exec: &ExecExpr<'a>, + ident: &Ident<'a>, + in_exec: &mut ExecExpr<'a>, +) { + if let BaseExec::Ident(i) = &in_exec.exec.base { + if i == ident { + // Build merged path in this arena + let mut merged = bumpalo::collections::Vec::new_in(bump); + merged.extend(exec.exec.path.iter().cloned()); + merged.extend(in_exec.exec.path.iter().cloned()); + + // New exec node allocated in arena + let new_kind = bump.alloc(ExecExprKind { + base: exec.exec.base.clone(), + path: merged, + }); + + // Keep or drop the cached type + // let new_ty = in_exec.ty; // keep it (may be stale) + let new_ty = None; // safer: force re-tycheck later + + *in_exec = ExecExpr { + exec: new_kind, + ty: new_ty, + span: in_exec.span, + }; + } + } +} + +pub trait Visitable<'a> { + fn visit>(&self, visitor: &mut V); +} +macro_rules! 
visitable { + ($t:ident, $f:ident) => { + impl<'a> Visitable<'a> for $t<'a> { + fn visit>(&self, visitor: &mut V) { + visitor.$f(self); + } + } + }; +} +visitable!(Ty, visit_ty); +visitable!(FnTy, visit_fn_ty); +visitable!(ParamSig, visit_param_sig); +visitable!(DataTy, visit_dty); +visitable!(Memory, visit_mem); +visitable!(Provenance, visit_prv); +visitable!(ExecExpr, visit_exec_expr); +visitable!(ExecTy, visit_exec_ty); +visitable!(Dim, visit_dim); +visitable!(Expr, visit_expr); +visitable!(Nat, visit_nat); + +pub fn free_kinded_idents<'a, T: Visitable<'a>>(t: &T) -> HashSet> { + let mut free_kinded_idents = FreeKindedIdents::new(); + t.visit(&mut free_kinded_idents); + free_kinded_idents.set +} + +pub struct FreeKindedIdents<'a> { + pub set: HashSet>, + pub bound_idents: HashSet>, +} + +impl<'a> FreeKindedIdents<'a> { + fn new() -> Self { + FreeKindedIdents { + set: HashSet::new(), + bound_idents: HashSet::new(), + } + } + + fn with_bound_idents(idents: HashSet>) -> Self { + FreeKindedIdents { + set: HashSet::new(), + bound_idents: idents, + } + } +} + +impl<'a> Visit<'a> for FreeKindedIdents<'a> { + fn visit_nat(&mut self, nat: &Nat<'a>) { + match nat { + Nat::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Nat); + if !self.bound_idents.contains(&ident_kinded) { + self.set.extend(std::iter::once(ident_kinded)) + } + } + _ => visit::walk_nat(self, nat), + } + } + + fn visit_mem(&mut self, mem: &Memory<'a>) { + match mem { + Memory::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Memory); + if !self.bound_idents.contains(&ident_kinded) { + self.set.extend(std::iter::once(ident_kinded)) + } + } + _ => visit::walk_mem(self, mem), + } + } + + fn visit_prv(&mut self, prv: &Provenance<'a>) { + match prv { + Provenance::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::Provenance); + if !self.bound_idents.contains(&ident_kinded) { + self.set.extend(std::iter::once(ident_kinded)) + } + } + _ => 
visit::walk_prv(self, prv), + } + } + + fn visit_dty(&mut self, dty: &DataTy<'a>) { + match &dty.dty { + DataTyKind::Ident(ident) => { + let ident_kinded = IdentKinded::new(ident, Kind::DataTy); + if !self.bound_idents.contains(&ident_kinded) { + self.set.extend(std::iter::once(ident_kinded)) + } + } + _ => visit::walk_dty(self, dty), + } + } + + fn visit_ty(&mut self, ty: &Ty<'a>) { + match &ty.ty { + TyKind::FnTy(fn_ty) => { + if !fn_ty.generics.is_empty() { + panic!( + "Generic function types can not appear,\ + only their instatiated counter parts." + ) + } + + walk_list!(self, visit_param_sig, &fn_ty.param_sigs); + self.visit_ty(fn_ty.ret_ty) + } + _ => visit::walk_ty(self, ty), + } + } + + fn visit_expr(&mut self, expr: &Expr<'a>) { + match &expr.expr { + ExprKind::ForNat(ident, collec, body) => { + self.visit_nat_range(collec); + let mut scoped_bound_idents = self.bound_idents.clone(); + scoped_bound_idents.extend(std::iter::once(IdentKinded::new(ident, Kind::Nat))); + let mut inner_free_idents = + FreeKindedIdents::with_bound_idents(scoped_bound_idents); + inner_free_idents.visit_expr(body); + self.set.extend(inner_free_idents.set) + } + _ => visit::walk_expr(self, expr), + } + } +} diff --git a/src/arena_ast/visit.rs b/src/arena_ast/visit.rs new file mode 100644 index 00000000..52c96f73 --- /dev/null +++ b/src/arena_ast/visit.rs @@ -0,0 +1,580 @@ +use crate::arena_ast::*; + +#[rustfmt::skip] +pub trait Visit<'a>: Sized { + fn visit_binary_op_nat(&mut self, _op: &BinOpNat) {} + fn visit_nat(&mut self, n: &Nat<'a>) { walk_nat(self, n) } + fn visit_nat_range(&mut self, nr: &NatRange<'a>) { walk_nat_range(self, nr) } + fn visit_ident_kinded(&mut self, id_kind: &IdentKinded<'a>) { walk_ident_kinded(self, id_kind) } + fn visit_ident_exec(&mut self, ident_exec: &IdentExec<'a>) { walk_ident_exec(self, ident_exec) } + fn visit_prv_rel(&mut self, prv_rel: &PrvRel<'a>) { walk_prv_rel(self, prv_rel) } + fn visit_exec_ty(&mut self, _exec: &ExecTy<'a>) {} + fn 
visit_mem(&mut self, mem: &Memory<'a>) { walk_mem(self, mem) } + fn visit_prv(&mut self, prv: &Provenance<'a>) { walk_prv(self, prv) } + fn visit_scalar_ty(&mut self, _sty: &ScalarTy) {} + fn visit_atomic_ty(&mut self, _aty: &AtomicTy) {} + fn visit_dim_compo(&mut self, _dim_compo: &DimCompo) { } + fn visit_dim(&mut self, dim: &Dim<'a>) { walk_dim(self, dim) } + fn visit_dim3d(&mut self, dim3d: &Dim3d<'a>) { walk_dim3d(self, dim3d) } + fn visit_dim2d(&mut self, dim2d: &Dim2d<'a>) { walk_dim2d(self, dim2d) } + fn visit_dim1d(&mut self, dim1d: &Dim1d<'a>) { walk_dim1d(self, dim1d) } + fn visit_ref(&mut self, reff: &RefDty<'a>) { walk_ref(self, reff) } + fn visit_dty(&mut self, dty: &DataTy<'a>) { walk_dty(self, dty) } + fn visit_fn_ty(&mut self, fn_ty: &FnTy<'a>) { walk_fn_ty(self, fn_ty) } + fn visit_nat_constr(&mut self, nat_constr: &NatConstr<'a>) { walk_nat_constr(self, nat_constr) } + fn visit_ty(&mut self, ty: &Ty<'a>) { walk_ty(self, ty) } + fn visit_view(&mut self, view: &View<'a>) { walk_view(self, view) } + fn visit_pl_expr(&mut self, pl_expr: &PlaceExpr<'a>) { walk_pl_expr(self, pl_expr) } + fn visit_arg_kinded(&mut self, arg_kinded: &ArgKinded<'a>) { walk_arg_kinded(self, arg_kinded) } + fn visit_kind(&mut self, _kind: &Kind) {} + fn visit_binary_op(&mut self, _op: &BinOp) {} + fn visit_unary_op(&mut self, _op: &UnOp) {} + fn visit_own(&mut self, _own: &Ownership) {} + fn visit_mutability(&mut self, _mutbl: &Mutability) {} + fn visit_lit(&mut self, _lit: &Lit) {} + fn visit_ident(&mut self, _ident: &Ident<'a>) {} + fn visit_pattern(&mut self, pattern: &Pattern<'a>) { walk_pattern(self, pattern) } + fn visit_split(&mut self, split: &Split<'a>) { walk_split(self, split) } + fn visit_sched(&mut self, par_for: &Sched<'a>) { walk_sched(self, par_for) } + fn visit_expr(&mut self, expr: &Expr<'a>) { walk_expr(self, expr) } + fn visit_app_kernel(&mut self, app_kernel: &AppKernel<'a>) { walk_app_kernel(self, app_kernel) } + fn visit_block(&mut self, block: 
&Block<'a>) { walk_block(self, block) }
    fn visit_split_proj(&mut self, exec_split: &TakeRange<'a>) { walk_split_proj(self, exec_split) }
    fn visit_exec_expr(&mut self, exec_expr: &ExecExpr<'a>) { walk_exec_expr(self, exec_expr) }
    fn visit_exec(&mut self, exec: &ExecExprKind<'a>) { walk_exec(self, exec) }
    fn visit_param_decl(&mut self, param_decl: &ParamDecl<'a>) { walk_param_decl(self, param_decl) }
    fn visit_fun_def(&mut self, fun_def: &FunDef<'a>) { walk_fun_def(self, fun_def) }
    fn visit_fun_decl(&mut self, fun_decl: &FunDecl<'a>) { walk_fun_decl(self, fun_decl) }
    fn visit_param_sig(&mut self, param_sig: &ParamSig<'a>) { walk_param_sig(self, param_sig) }
    fn visit_field(&mut self, field: &(Ident<'a>, DataTy<'a>)) { walk_field(self, field) }
}

// Invoke `$visitor.$method(elem)` on every element of `$list` (shared-borrow
// iteration; the mutable counterpart lives in visit_mut.rs).
macro_rules! walk_list {
    ($visitor: expr, $method: ident, $list: expr) => {
        for elem in $list.iter() {
            $visitor.$method(elem)
        }
    };
}
pub(crate) use walk_list;

/// Visit the children of a `Nat`: identifiers, both operands of a binary op,
/// and the callee/arguments of a `Nat::App`. Leaf index/literal variants have
/// no children.
pub fn walk_nat<'a, V: Visit<'a>>(visitor: &mut V, n: &Nat<'a>) {
    match n {
        Nat::Ident(ident) => visitor.visit_ident(ident),
        Nat::BinOp(op, l, r) => {
            visitor.visit_binary_op_nat(op);
            visitor.visit_nat(l);
            visitor.visit_nat(r)
        }
        Nat::GridIdx
        | Nat::BlockIdx(_)
        | Nat::BlockDim(_)
        | Nat::ThreadIdx(_)
        | Nat::WarpGrpIdx
        | Nat::WarpIdx
        | Nat::LaneIdx
        | Nat::Lit(_) => {}
        Nat::App(func, args) => {
            visitor.visit_ident(func);
            walk_list!(visitor, visit_nat, args)
        }
    }
}

/// Visit the bound nats of a range; `Halved`/`Doubled` only carry an upper bound.
pub fn walk_nat_range<'a, V: Visit<'a>>(visitor: &mut V, nr: &NatRange<'a>) {
    match nr {
        NatRange::Simple { lower, upper } => {
            visitor.visit_nat(lower);
            visitor.visit_nat(upper);
        }
        NatRange::Halved { upper } | NatRange::Doubled { upper } => visitor.visit_nat(upper),
    }
}

pub fn walk_ident_kinded<'a, V: Visit<'a>>(visitor: &mut V, id_kind: &IdentKinded<'a>) {
    let IdentKinded { ident, kind } = id_kind;
    visitor.visit_ident(ident);
    visitor.visit_kind(kind)
}

pub fn walk_ident_exec<'a, V: Visit<'a>>(visitor: &mut V, id_exec: &IdentExec<'a>) {
    let IdentExec { ident, ty } = id_exec;
    visitor.visit_ident(ident);
    visitor.visit_exec_ty(ty)
}

pub fn walk_prv_rel<'a, V: Visit<'a>>(visitor: &mut V, prv_rel: &PrvRel<'a>) {
    let PrvRel { longer, shorter } = prv_rel;
    visitor.visit_ident(longer);
    visitor.visit_ident(shorter)
}

/// Only `Memory::Ident` carries a child to visit.
pub fn walk_mem<'a, V: Visit<'a>>(visitor: &mut V, mem: &Memory<'a>) {
    if let Memory::Ident(ident) = mem {
        visitor.visit_ident(ident)
    }
}

pub fn walk_prv<'a, V: Visit<'a>>(visitor: &mut V, prv: &Provenance<'a>) {
    match prv {
        Provenance::Ident(ident) => visitor.visit_ident(ident),
        Provenance::Value(_) => {}
    }
}

pub fn walk_dim3d<'a, V: Visit<'a>>(visitor: &mut V, dim3d: &Dim3d<'a>) {
    let Dim3d(n1, n2, n3) = dim3d;
    visitor.visit_nat(n1);
    visitor.visit_nat(n2);
    visitor.visit_nat(n3);
}

pub fn walk_dim2d<'a, V: Visit<'a>>(visitor: &mut V, dim2d: &Dim2d<'a>) {
    let Dim2d(n1, n2) = dim2d;
    visitor.visit_nat(n1);
    visitor.visit_nat(n2);
}

pub fn walk_dim1d<'a, V: Visit<'a>>(visitor: &mut V, dim1d: &Dim1d<'a>) {
    let Dim1d(n) = dim1d;
    visitor.visit_nat(n);
}

/// Dispatch on dimensionality: every `Dim` variant wraps a 1-, 2-, or 3-component payload.
pub fn walk_dim<'a, V: Visit<'a>>(visitor: &mut V, dim: &Dim<'a>) {
    match dim {
        Dim::XYZ(dim3d) => {
            visitor.visit_dim3d(dim3d);
        }
        Dim::XY(dim2d) | Dim::XZ(dim2d) | Dim::YZ(dim2d) => {
            visitor.visit_dim2d(dim2d);
        }
        Dim::X(dim1d) | Dim::Y(dim1d) | Dim::Z(dim1d) => visitor.visit_dim1d(dim1d),
    }
}

pub fn walk_ref<'a, V: Visit<'a>>(visitor: &mut V, reff: &RefDty<'a>) {
    let RefDty { rgn, own, mem, dty } = reff;
    visitor.visit_prv(rgn);
    visitor.visit_own(own);
    visitor.visit_mem(mem);
    visitor.visit_dty(dty);
}

/// Visit the children of a data type, recursing through composite variants
/// (tuples, structs, arrays, references, …).
pub fn walk_dty<'a, V: Visit<'a>>(visitor: &mut V, dty: &DataTy<'a>) {
    match &dty.dty {
        DataTyKind::Ident(ident) => visitor.visit_ident(ident),
        DataTyKind::Scalar(sty) => visitor.visit_scalar_ty(sty),
        DataTyKind::Atomic(aty) => visitor.visit_atomic_ty(aty),
        DataTyKind::Tuple(elem_dtys) => walk_list!(visitor, visit_dty, elem_dtys),
        DataTyKind::Struct(struct_decl) => {
            visitor.visit_ident(&struct_decl.ident);
            walk_list!(visitor, visit_field, &struct_decl.fields)
        }
        DataTyKind::Array(dty, n) => {
            visitor.visit_dty(dty);
            visitor.visit_nat(n)
        }
        DataTyKind::ArrayShape(dty, n) => {
            visitor.visit_dty(dty);
            visitor.visit_nat(n);
        }
        DataTyKind::At(dty, mem) => {
            visitor.visit_dty(dty);
            visitor.visit_mem(mem)
        }
        DataTyKind::Ref(reff) => {
            visitor.visit_ref(reff);
        }
        DataTyKind::RawPtr(dty) => visitor.visit_dty(dty),
        DataTyKind::Dead(dty) => visitor.visit_dty(dty),
    }
}

/// Visit every component of a function type signature.
pub fn walk_fn_ty<'a, V: Visit<'a>>(visitor: &mut V, fn_ty: &FnTy<'a>) {
    let FnTy {
        generics,
        generic_exec,
        param_sigs,
        exec,
        ret_ty,
        nat_constrs,
    } = fn_ty;
    walk_list!(visitor, visit_ident_kinded, generics);
    for exec_decl in generic_exec {
        visitor.visit_ident_exec(exec_decl)
    }
    walk_list!(visitor, visit_param_sig, param_sigs);
    visitor.visit_exec_expr(exec);
    visitor.visit_ty(ret_ty);
    walk_list!(visitor, visit_nat_constr, nat_constrs);
}

pub fn walk_nat_constr<'a, V: Visit<'a>>(visitor: &mut V, nat_constr: &NatConstr<'a>) {
    match nat_constr {
        NatConstr::True => {}
        NatConstr::Eq(l, r) => {
            visitor.visit_nat(l);
            visitor.visit_nat(r);
        }
        NatConstr::Lt(l, r) => {
            visitor.visit_nat(l);
            visitor.visit_nat(r);
        }
        NatConstr::And(l, r) => {
            visitor.visit_nat_constr(l);
            visitor.visit_nat_constr(r);
        }
        NatConstr::Or(l, r) => {
            visitor.visit_nat_constr(l);
            visitor.visit_nat_constr(r);
        }
    }
}

pub fn walk_ty<'a, V: Visit<'a>>(visitor: &mut V, ty: &Ty<'a>) {
    match &ty.ty {
        TyKind::Data(dty) => visitor.visit_dty(dty),
        TyKind::FnTy(fn_ty) => {
            visitor.visit_fn_ty(fn_ty);
        }
    }
}

pub fn walk_view<'a, V: Visit<'a>>(visitor: &mut V, view: &View<'a>) {
    visitor.visit_ident(&view.name);
    walk_list!(visitor, visit_arg_kinded, &view.gen_args);
    for v in &view.args {
        visitor.visit_view(v)
    }
}

/// Visit the children of a place expression, recursing through the chain of
/// projections/derefs/selections down to the root identifier.
pub fn walk_pl_expr<'a, V: Visit<'a>>(visitor: &mut V, pl_expr: &PlaceExpr<'a>) {
    match &pl_expr.pl_expr {
        PlaceExprKind::Ident(ident) => visitor.visit_ident(ident),
        PlaceExprKind::Deref(pl_expr) => visitor.visit_pl_expr(pl_expr),
        PlaceExprKind::Select(p, distrib_exec) => {
            visitor.visit_pl_expr(p);
            visitor.visit_exec_expr(distrib_exec);
        }
        PlaceExprKind::Proj(pl_expr, _) => {
            visitor.visit_pl_expr(pl_expr);
        }
        PlaceExprKind::FieldProj(pl_expr, field_name) => {
            visitor.visit_pl_expr(pl_expr);
            visitor.visit_ident(field_name);
        }
        PlaceExprKind::View(pl_expr, view) => {
            visitor.visit_pl_expr(pl_expr);
            visitor.visit_view(view);
        }
        PlaceExprKind::Idx(pl_expr, n) => {
            visitor.visit_pl_expr(pl_expr);
            visitor.visit_nat(n)
        }
    }
}

pub fn walk_arg_kinded<'a, V: Visit<'a>>(visitor: &mut V, arg_kinded: &ArgKinded<'a>) {
    match arg_kinded {
        ArgKinded::Ident(ident) => visitor.visit_ident(ident),
        ArgKinded::Nat(n) => visitor.visit_nat(n),
        ArgKinded::Memory(mem) => visitor.visit_mem(mem),
        ArgKinded::DataTy(dty) => visitor.visit_dty(dty),
        ArgKinded::Provenance(prv) => visitor.visit_prv(prv),
    }
}

pub fn walk_pattern<'a, V: Visit<'a>>(visitor: &mut V, pattern: &Pattern<'a>) {
    match pattern {
        Pattern::Ident(mutab, ident) => {
            visitor.visit_mutability(mutab);
            visitor.visit_ident(ident);
        }
        Pattern::Tuple(patterns) => {
            walk_list!(visitor, visit_pattern, patterns)
        }
        Pattern::Wildcard => {}
    }
}

pub fn walk_split<'a, V: Visit<'a>>(visitor: &mut V, indep: &Split<'a>) {
    let Split {
        dim_compo,
        pos,
        split_exec,
        branch_idents,
        branch_bodies,
    } = indep;
    visitor.visit_dim_compo(dim_compo);
    visitor.visit_nat(pos);
    visitor.visit_exec_expr(split_exec);
    walk_list!(visitor, visit_ident, branch_idents);
    walk_list!(visitor, visit_expr, branch_bodies);
}

pub fn walk_sched<'a, V: Visit<'a>>(visitor: &mut V, sched: &Sched<'a>) {
    let Sched {
        dim,
        inner_exec_ident,
        sched_exec,
        body,
    } = sched;
    visitor.visit_dim_compo(dim);
    for ident in inner_exec_ident {
        visitor.visit_ident(ident)
    }
    visitor.visit_exec_expr(sched_exec);
    visitor.visit_block(body);
}

/// Visit every child of an expression node. Commented-out `Lambda` arm kept
/// from the original; `Range`/`Hole` have no children.
pub fn walk_expr<'a, V: Visit<'a>>(visitor: &mut V, expr: &Expr<'a>) {
    // For now, only visit ExprKind
    match &expr.expr {
        ExprKind::Lit(l) => visitor.visit_lit(l),
        ExprKind::PlaceExpr(pl_expr) => visitor.visit_pl_expr(pl_expr),
        ExprKind::Ref(_, own, pl_expr) => {
            visitor.visit_own(own);
            visitor.visit_pl_expr(pl_expr);
        }
        ExprKind::Block(block) => visitor.visit_block(block),
        ExprKind::LetUninit(maybe_exec_expr, ident, ty) => {
            for e in maybe_exec_expr {
                visitor.visit_exec_expr(e);
            }
            visitor.visit_ident(ident);
            visitor.visit_ty(ty);
        }
        ExprKind::Let(pattern, ty, e) => {
            visitor.visit_pattern(pattern);
            for ty in ty.as_ref() {
                visitor.visit_ty(ty);
            }
            visitor.visit_expr(e);
        }
        ExprKind::Assign(pl_expr, expr) => {
            visitor.visit_pl_expr(pl_expr);
            visitor.visit_expr(expr)
        }
        ExprKind::IdxAssign(pl_expr, idx, expr) => {
            visitor.visit_pl_expr(pl_expr);
            visitor.visit_nat(idx);
            visitor.visit_expr(expr);
        }
        ExprKind::Seq(es) => {
            for e in es {
                visitor.visit_expr(e)
            }
        }
        // ExprKind::Lambda(params, exec_decl, dty, expr) => {
        //     walk_list!(visitor, visit_param_decl, params);
        //     visitor.visit_ident_exec(exec_decl);
        //     visitor.visit_dty(dty);
        //     visitor.visit_expr(expr)
        // }
        ExprKind::App(f, gen_args, args) => {
            visitor.visit_ident(f);
            walk_list!(visitor, visit_arg_kinded, gen_args);
            walk_list!(visitor, visit_expr, args);
        }
        ExprKind::DepApp(f, gen_args) => {
            visitor.visit_ident(f);
            walk_list!(visitor, visit_arg_kinded, gen_args);
        }
        ExprKind::AppKernel(app_kernel) => {
            visitor.visit_app_kernel(app_kernel);
        }
        ExprKind::IfElse(cond, tt, ff) => {
            visitor.visit_expr(cond);
            visitor.visit_expr(tt);
            visitor.visit_expr(ff)
        }
        ExprKind::If(cond, tt) => {
            visitor.visit_expr(cond);
            visitor.visit_expr(tt)
        }
        ExprKind::Array(elems) => {
            walk_list!(visitor, visit_expr, elems);
        }
        ExprKind::Tuple(elems) => {
            walk_list!(visitor, visit_expr, elems);
        }
        ExprKind::For(ident, coll, body) => {
            visitor.visit_ident(ident);
            visitor.visit_expr(coll);
            visitor.visit_expr(body);
        }
        ExprKind::Split(par_branch) => {
            visitor.visit_split(par_branch);
        }
        ExprKind::Sched(sched) => {
            visitor.visit_sched(sched);
        }
        ExprKind::ForNat(ident, range, body) => {
            visitor.visit_ident(ident);
            visitor.visit_nat_range(range);
            visitor.visit_expr(body)
        }
        ExprKind::While(cond, body) => {
            visitor.visit_expr(cond);
            visitor.visit_expr(body);
        }
        ExprKind::BinOp(op, l, r) => {
            visitor.visit_binary_op(op);
            visitor.visit_expr(l);
            visitor.visit_expr(r)
        }
        ExprKind::UnOp(op, expr) => {
            visitor.visit_unary_op(op);
            visitor.visit_expr(expr)
        }
        ExprKind::Sync(exec) => {
            for e in exec {
                visitor.visit_exec_expr(e)
            }
        }
        ExprKind::Unsafe(expr) => visitor.visit_expr(expr),
        ExprKind::Cast(expr, dty) => {
            visitor.visit_expr(expr);
            visitor.visit_dty(dty)
        }
        ExprKind::Range(_, _) | ExprKind::Hole => (),
    }
}

pub fn walk_app_kernel<'a, V: Visit<'a>>(visitor: &mut V, app_kernel: &AppKernel<'a>) {
    let AppKernel {
        grid_dim,
        block_dim,
        shared_mem_dtys,
        shared_mem_prvs: _,
        fun_ident,
        gen_args,
        args,
    } = app_kernel;
    visitor.visit_dim(grid_dim);
    visitor.visit_dim(block_dim);
    for dty in shared_mem_dtys {
        visitor.visit_dty(dty);
    }
    visitor.visit_ident(fun_ident);
    for garg in gen_args {
        visitor.visit_arg_kinded(garg);
    }
    for arg in args {
        visitor.visit_expr(arg);
    }
}

pub fn walk_block<'a, V: Visit<'a>>(visitor: &mut V, block: &Block<'a>) {
    let Block { body, .. } = block;
    visitor.visit_expr(body);
}

pub fn walk_split_proj<'a, V: Visit<'a>>(visitor: &mut V, split_proj: &TakeRange<'a>) {
    let TakeRange {
        split_dim,
        pos,
        left_or_right: _,
    } = split_proj;
    visitor.visit_dim_compo(split_dim);
    visitor.visit_nat(pos);
}

pub fn walk_exec_expr<'a, V: Visit<'a>>(visitor: &mut V, exec_expr: &ExecExpr<'a>) {
    visitor.visit_exec(&exec_expr.exec);
    for t in &exec_expr.ty {
        visitor.visit_exec_ty(t);
    }
}

/// Visit an execution expression's base and then each path element in order.
pub fn walk_exec<'a, V: Visit<'a>>(visitor: &mut V, exec: &ExecExprKind<'a>) {
    let ExecExprKind { base, path } = exec;
    match base {
        BaseExec::CpuThread => (),
        BaseExec::Ident(ident) => visitor.visit_ident(ident),
        BaseExec::GpuGrid(gdim, bdim) => {
            visitor.visit_dim(gdim);
            visitor.visit_dim(bdim);
        }
    };
    for e in path {
        match e {
            ExecPathElem::TakeRange(split_proj) => visitor.visit_split_proj(split_proj),
            ExecPathElem::ForAll(dim_compo) => visitor.visit_dim_compo(dim_compo),
            ExecPathElem::ToThreads(dim_compo) => visitor.visit_dim_compo(dim_compo),
            ExecPathElem::ToWarps => {}
        }
    }
}

pub fn walk_param_decl<'a, V: Visit<'a>>(visitor: &mut V, param_decl: &ParamDecl<'a>) {
    let ParamDecl {
        ident,
        ty,
        mutbl,
        exec_expr,
    } = param_decl;
    visitor.visit_ident(ident);
    if let Some(tty) = ty {
        visitor.visit_ty(tty);
    }
    visitor.visit_mutability(mutbl);
    for ex in exec_expr {
        visitor.visit_exec_expr(ex);
    }
}

pub fn walk_fun_def<'a, V: Visit<'a>>(visitor: &mut V, fun_def: &FunDef<'a>) {
    let FunDef {
        ident: _,
        generic_params,
        generic_exec,
        param_decls: params,
        ret_dty,
        exec,
        prv_rels,
        body,
    } = fun_def;
    walk_list!(visitor, visit_ident_kinded, generic_params);
    for exec_decl in generic_exec {
        visitor.visit_ident_exec(exec_decl);
    }
    walk_list!(visitor, visit_param_decl, params);
    visitor.visit_dty(ret_dty);
    visitor.visit_exec_expr(exec);
    walk_list!(visitor, visit_prv_rel, prv_rels);
    visitor.visit_block(body)
}

/// Like `walk_fun_def`, but declarations carry no body.
pub fn walk_fun_decl<'a, V: Visit<'a>>(visitor: &mut V, fun_decl: &FunDecl<'a>) {
    let FunDecl {
        ident: _,
        generic_params,
        generic_exec,
        param_decls: params,
        ret_dty,
        exec,
        prv_rels,
    } = fun_decl;
    walk_list!(visitor, visit_ident_kinded, generic_params);
    for exec_decl in generic_exec {
        visitor.visit_ident_exec(exec_decl);
    }
    walk_list!(visitor, visit_param_decl, params);
    visitor.visit_dty(ret_dty);
    visitor.visit_exec_expr(exec);
    walk_list!(visitor, visit_prv_rel, prv_rels);
}

pub fn walk_param_sig<'a, V: Visit<'a>>(visitor: &mut V, param_sig: &ParamSig<'a>) {
    let ParamSig { exec_expr, ty } = param_sig;
    visitor.visit_exec_expr(exec_expr);
    visitor.visit_ty(ty);
}

pub fn walk_field<'a, V: Visit<'a>>(visitor: &mut V, field: &(Ident<'a>, DataTy<'a>)) {
    let (ident, dty) = field;
    visitor.visit_ident(ident);
    visitor.visit_dty(dty);
}
diff --git a/src/arena_ast/visit_mut.rs b/src/arena_ast/visit_mut.rs
new file mode 100644
index 00000000..d04ee72b
--- /dev/null
+++ b/src/arena_ast/visit_mut.rs
@@ -0,0 +1,1056 @@
use crate::arena_ast::*;

// Mutating visitor over the arena-allocated AST. Because shared nodes live in
// a `Bump` arena behind `&'a` references, "mutation" of a shared node means:
// clone it to an owned value, visit that, and re-allocate the result in the
// arena, swapping the reference slot to the new allocation.
#[rustfmt::skip]
pub trait VisitMut<'a>: Sized {
    fn visit_binary_op_nat(&mut self, _op: &mut BinOpNat) {}
    fn visit_nat(&mut self, arena: &'a Bump, n: &mut Nat<'a>) { walk_nat(self, arena, n) }
    fn visit_nat_ref(&mut self, arena: &'a Bump, slot: &mut &'a Nat<'a>) {
        walk_nat_ref(self, arena, slot)
    }
    fn visit_nat_range(&mut self, arena: &'a Bump, nr: &mut NatRange<'a>) { walk_nat_range(self, arena, nr) }
    fn visit_ident_kinded(&mut self, arena: &'a Bump, id_kind: &mut IdentKinded<'a>) { walk_ident_kinded(self, arena, id_kind) }
    fn visit_ident_exec(&mut self, arena: &'a Bump, id_exec: &mut IdentExec<'a>) { walk_ident_exec(self, arena, id_exec) }
    fn visit_prv_rel(&mut self, arena: &'a Bump, prv_rel: &mut PrvRel<'a>) { walk_prv_rel(self, arena, prv_rel) }
    fn visit_exec_ty(&mut self, _exec: &mut ExecTy<'a>) {}
    fn visit_exec_ty_ref(&mut self, _arena: &'a bumpalo::Bump, _slot: &mut
&'a ExecTy<'a>) {}
    fn visit_mem(&mut self, arena: &'a Bump, mem: &mut Memory<'a>) { walk_mem(self, arena, mem) }
    fn visit_prv(&mut self, arena: &'a Bump, prv: &mut Provenance<'a>) { walk_prv(self, arena, prv) }
    fn visit_scalar_ty(&mut self, _sty: &mut ScalarTy) {}
    fn visit_atomic_ty(&mut self, _aty: &mut AtomicTy) {}
    fn visit_dim_compo(&mut self, _dim_compo: &mut DimCompo) {}
    fn visit_dim(&mut self, arena: &'a Bump, dim: &mut Dim<'a>) { walk_dim(self, arena, dim) }
    fn visit_dim3d(&mut self, arena: &'a Bump, dim3d: &mut Dim3d<'a>) { walk_dim3d(self, arena, dim3d) }
    fn visit_dim2d(&mut self, arena: &'a Bump, dim2d: &mut Dim2d<'a>) { walk_dim2d(self, arena, dim2d) }
    fn visit_dim1d(&mut self, arena: &'a Bump, dim1d: &mut Dim1d<'a>) { walk_dim1d(self, arena, dim1d) }
    fn visit_ref(&mut self, arena: &'a Bump, reff: &mut RefDty<'a>) { walk_ref(self, arena, reff) }
    fn visit_dty(&mut self, arena: &'a Bump, dty: &mut DataTy<'a>) { walk_dty(self, arena, dty) }
    fn visit_fn_ty(&mut self, arena: &'a Bump, fn_ty: &mut FnTy<'a>) { walk_fn_ty(self, arena, fn_ty) }
    fn visit_nat_constr(&mut self, arena: &'a Bump, nat_constr: &mut NatConstr<'a>) { walk_nat_constr(self, arena, nat_constr) }
    fn visit_ty(&mut self, arena: &'a Bump, ty: &mut Ty<'a>) { walk_ty(self, arena, ty) }
    fn visit_view(&mut self, arena: &'a Bump, view: &mut View<'a>) { walk_view(self, arena, view) }
    fn visit_pl_expr(&mut self, arena: &'a Bump, pl_expr: &mut PlaceExpr<'a>) { walk_pl_expr(self, arena, pl_expr) }
    fn visit_arg_kinded(&mut self, arena: &'a Bump, arg_kinded: &mut ArgKinded<'a>) { walk_arg_kinded(self, arena, arg_kinded) }
    fn visit_kind(&mut self, _kind: &mut Kind) {}
    fn visit_binary_op(&mut self, _op: &mut BinOp) {}
    fn visit_unary_op(&mut self, _op: &mut UnOp) {}
    fn visit_own(&mut self, _own: &mut Ownership) {}
    fn visit_mutability(&mut self, _mutbl: &mut Mutability) {}
    fn visit_lit(&mut self, _lit: &mut Lit) {}
    fn visit_ident(&mut self, _arena: &'a Bump, _ident: &mut Ident<'a>) {}
    fn visit_pattern(&mut self, arena: &'a Bump, pattern: &mut Pattern<'a>) { walk_pattern(self, arena, pattern) }
    fn visit_split(&mut self, arena: &'a Bump, split: &mut Split<'a>) { walk_split(self, arena, split) }
    fn visit_sched(&mut self, arena: &'a Bump, sched: &mut Sched<'a>) { walk_sched(self, arena, sched) }
    fn visit_expr(&mut self, arena: &'a Bump, expr: &mut Expr<'a>) { walk_expr(self, arena, expr) }
    fn visit_app_kernel(&mut self, arena: &'a Bump, app_kernel: &mut AppKernel<'a>) { walk_app_kernel(self, arena, app_kernel) }
    fn visit_block(&mut self, arena: &'a Bump, block: &mut Block<'a>) { walk_block(self, arena, block) }
    fn visit_split_proj(&mut self, arena: &'a Bump, exec_split: &mut TakeRange<'a>) { walk_split_proj(self, arena, exec_split) }
    fn visit_exec_expr(&mut self, arena: &'a Bump, exec_expr: &mut ExecExpr<'a>) { walk_exec_expr(self, arena, exec_expr) }
    fn visit_exec(&mut self, arena: &'a Bump, exec: &mut ExecExprKind<'a>) { walk_exec(self, arena, exec) }
    fn visit_exec_path_elem(&mut self, arena: &'a Bump, exec_path_elem: &mut ExecPathElem<'a>) { walk_exec_path_elem(self, arena, exec_path_elem) }
    fn visit_param_decl(&mut self, arena: &'a Bump, param_decl: &mut ParamDecl<'a>) { walk_param_decl(self, arena, param_decl) }
    fn visit_fun_def(&mut self, arena: &'a Bump, fun_def: &mut FunDef<'a>) { walk_fun_def(self, arena, fun_def) }
    fn visit_fun_decl(&mut self, arena: &'a Bump, fun_decl: &mut FunDecl<'a>) { walk_fun_decl(self, arena, fun_decl) }
    fn visit_param_sig(&mut self, arena: &'a Bump, param_sig: &mut ParamSig<'a>) { walk_param_sig(self, arena, param_sig) }
    fn visit_field(&mut self, arena: &'a Bump, field: &mut (Ident<'a>, DataTy<'a>)) { walk_field(self, arena, field) }
}

// Taken from the Rust compiler
macro_rules! walk_list {
    ($visitor:expr, $method:ident, $list:expr, $arena:expr) => {
        for elem in $list.iter_mut() {
            $visitor.$method($arena, elem)
        }
    };
}
pub(crate) use walk_list;

/// Visit the children of an owned `Nat` in place; arena-shared operands of a
/// `BinOp` go through `visit_nat_ref` (rebuild-and-swap).
pub fn walk_nat<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, n: &mut Nat<'a>) {
    match n {
        Nat::Ident(ident) => visitor.visit_ident(arena, ident),
        Nat::BinOp(op, ref mut l, ref mut r) => {
            visitor.visit_binary_op_nat(op);
            visitor.visit_nat_ref(arena, l);
            visitor.visit_nat_ref(arena, r);
        }
        Nat::GridIdx
        | Nat::BlockIdx(_)
        | Nat::BlockDim(_)
        | Nat::ThreadIdx(_)
        | Nat::WarpGrpIdx
        | Nat::WarpIdx
        | Nat::LaneIdx
        | Nat::Lit(_) => {}
        Nat::App(func, args) => {
            visitor.visit_ident(arena, func);
            walk_list!(visitor, visit_nat, args.as_mut_slice(), arena)
        }
    }
}

/// "Mutate" a shared `&'a Nat` slot: rebuild the node from visited children
/// and point the slot at a fresh arena allocation. Leaf variants are left as-is.
pub fn walk_nat_ref<'a, V: VisitMut<'a>>(
    v: &mut V,
    arena: &'a bumpalo::Bump,
    slot: &mut &'a Nat<'a>,
) {
    match &*(*slot) {
        Nat::BinOp(op, l, r) => {
            let mut l_slot: &'a Nat<'a> = *l;
            let mut r_slot: &'a Nat<'a> = *r;

            v.visit_nat_ref(arena, &mut l_slot);
            v.visit_nat_ref(arena, &mut r_slot);

            let new = arena.alloc(Nat::BinOp(*op, l_slot, r_slot));
            *slot = new;
        }
        Nat::App(func, args) => {
            let mut rebuilt = bumpalo::collections::Vec::new_in(arena);
            rebuilt.reserve(args.len());
            for a in args.iter() {
                let mut owned = a.clone();
                v.visit_nat(arena, &mut owned);
                rebuilt.push(owned);
            }
            *slot = arena.alloc(Nat::App(func.clone(), rebuilt));
        }
        _ => {}
    }
}

pub fn walk_nat_range<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    nr: &mut NatRange<'a>,
) {
    match nr {
        NatRange::Simple { lower, upper } => {
            visitor.visit_nat(arena, lower);
            visitor.visit_nat(arena, upper);
        }
        NatRange::Halved { upper } | NatRange::Doubled { upper } => visitor.visit_nat(arena, upper),
    }
}

pub fn walk_ident_kinded<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    id_kind: &mut IdentKinded<'a>,
) {
    let IdentKinded { ident, kind } = id_kind;
    visitor.visit_ident(arena, ident);
    visitor.visit_kind(kind)
}

pub fn walk_ident_exec<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    id_exec: &mut IdentExec<'a>,
) {
    let IdentExec { ident, ty } = id_exec;
    visitor.visit_ident(arena, ident);
    visitor.visit_exec_ty_ref(arena, ty)
}

pub fn walk_prv_rel<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    prv_rel: &mut PrvRel<'a>,
) {
    let PrvRel { longer, shorter } = prv_rel;
    visitor.visit_ident(arena, longer);
    visitor.visit_ident(arena, shorter)
}

pub fn walk_mem<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, mem: &mut Memory<'a>) {
    if let Memory::Ident(ident) = mem {
        visitor.visit_ident(arena, ident)
    }
}

pub fn walk_prv<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, prv: &mut Provenance<'a>) {
    match prv {
        Provenance::Ident(ident) => visitor.visit_ident(arena, ident),
        Provenance::Value(_) => {}
    }
}

pub fn walk_dim3d<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, dim3d: &mut Dim3d<'a>) {
    let Dim3d(n1, n2, n3) = dim3d;
    visitor.visit_nat(arena, n1);
    visitor.visit_nat(arena, n2);
    visitor.visit_nat(arena, n3);
}

pub fn walk_dim2d<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, dim2d: &mut Dim2d<'a>) {
    let Dim2d(n1, n2) = dim2d;
    visitor.visit_nat(arena, n1);
    visitor.visit_nat(arena, n2);
}

pub fn walk_dim1d<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, dim1d: &mut Dim1d<'a>) {
    let Dim1d(n) = dim1d;
    visitor.visit_nat(arena, n);
}

/*
Superseded in-place version, kept for reference (Dim payloads are arena-shared,
so the active version below rebuilds them instead):
pub fn walk_dim<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, dim: &mut Dim<'a>) {
    match dim {
        Dim::XYZ(dim3d) => {
            visitor.visit_dim3d(arena, dim3d);
        }
        Dim::XY(dim2d) | Dim::XZ(dim2d) | Dim::YZ(dim2d) => {
            visitor.visit_dim2d(dim2d);
        }
        Dim::X(dim1d) | Dim::Y(dim1d) | Dim::Z(dim1d) => visitor.visit_dim1d(arena, dim1d),
    }
}
*/

/// Rebuild a `Dim`: clone the shared payload, visit the owned copy, then
/// re-allocate it in the arena via the matching `Dim::new_*` constructor.
pub fn walk_dim<'a, V: VisitMut<'a>>(v: &mut V, arena: &'a bumpalo::Bump, dim: &mut Dim<'a>) {
    let new_dim = match dim {
        Dim::XYZ(d3_ref) => {
            let mut d3 = (*(*d3_ref)).clone(); // owned, mutable
            v.visit_dim3d(arena, &mut d3);
            Dim::new_3d(arena, d3.0.clone(), d3.1.clone(), d3.2.clone())
        }
        Dim::XY(d2_ref) => {
            let mut d2 = (*(*d2_ref)).clone();
            v.visit_dim2d(arena, &mut d2);
            Dim::new_2d(arena, Dim::XY, d2.0.clone(), d2.1.clone())
        }
        Dim::XZ(d2_ref) => {
            let mut d2 = (*(*d2_ref)).clone();
            v.visit_dim2d(arena, &mut d2);
            Dim::new_2d(arena, Dim::XZ, d2.0.clone(), d2.1.clone())
        }
        Dim::YZ(d2_ref) => {
            let mut d2 = (*(*d2_ref)).clone();
            v.visit_dim2d(arena, &mut d2);
            Dim::new_2d(arena, Dim::YZ, d2.0.clone(), d2.1.clone())
        }
        Dim::X(d1_ref) => {
            let mut d1 = (*(*d1_ref)).clone();
            v.visit_dim1d(arena, &mut d1);
            Dim::new_1d(arena, Dim::X, d1.0.clone())
        }
        Dim::Y(d1_ref) => {
            let mut d1 = (*(*d1_ref)).clone();
            v.visit_dim1d(arena, &mut d1);
            Dim::new_1d(arena, Dim::Y, d1.0.clone())
        }
        Dim::Z(d1_ref) => {
            let mut d1 = (*(*d1_ref)).clone();
            v.visit_dim1d(arena, &mut d1);
            Dim::new_1d(arena, Dim::Z, d1.0.clone())
        }
    };
    *dim = new_dim;
}

/// Visit a reference type's components in place; the pointee `dty` is
/// arena-shared, so it is cloned, visited, and re-allocated.
pub fn walk_ref<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, reff: &mut RefDty<'a>) {
    // `dty` is handled separately below via rebuild-and-swap, hence `dty: _`.
    let RefDty { rgn, own, mem, dty: _ } = reff;
    visitor.visit_prv(arena, rgn);
    visitor.visit_own(own);
    visitor.visit_mem(arena, mem);

    let mut owned = (*reff.dty).clone();
    visitor.visit_dty(arena, &mut owned);
    reff.dty = arena.alloc(owned);
}

/// Visit the children of a data type; arena-shared children (struct decls,
/// array element types, ref/raw-ptr/dead pointees) are rebuilt and swapped.
pub fn walk_dty<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, dty: &mut DataTy<'a>) {
    match &mut dty.dty {
        DataTyKind::Ident(ident) => visitor.visit_ident(arena, ident),
        DataTyKind::Scalar(sty) => visitor.visit_scalar_ty(sty),
        DataTyKind::Atomic(aty) => visitor.visit_atomic_ty(aty),
        DataTyKind::Tuple(elem_dtys) => walk_list!(visitor, visit_dty, elem_dtys, arena),
        DataTyKind::Struct(struct_decl_ref) => {
            let mut owned = (**struct_decl_ref).clone();
            visitor.visit_ident(arena, &mut owned.ident);
            walk_list!(visitor, visit_field, &mut owned.fields, arena);
            *struct_decl_ref = arena.alloc(owned);
        }
        DataTyKind::Array(dty_ref, n_ref) => {
            let mut elem = (**dty_ref).clone();
            visitor.visit_dty(arena, &mut elem);
            *dty_ref = arena.alloc(elem);

            let mut n = (*n_ref).clone();
            visitor.visit_nat(arena, &mut n);
            *n_ref = n;
        }
        DataTyKind::ArrayShape(dty_ref, n_ref) => {
            let mut elem = (**dty_ref).clone();
            visitor.visit_dty(arena, &mut elem);
            *dty_ref = arena.alloc(elem);

            let mut n = (*n_ref).clone();
            visitor.visit_nat(arena, &mut n);
            *n_ref = n;
        }
        DataTyKind::At(dty_ref, mem) => {
            let mut elem = (**dty_ref).clone();
            visitor.visit_dty(arena, &mut elem);
            *dty_ref = arena.alloc(elem);

            visitor.visit_mem(arena, mem);
        }
        DataTyKind::Ref(reff_ref) => {
            let mut r = (**reff_ref).clone();
            visitor.visit_ref(arena, &mut r);
            *reff_ref = arena.alloc(r);
        }
        DataTyKind::RawPtr(dataty_ref) => {
            let mut elem = (**dataty_ref).clone();
            visitor.visit_dty(arena, &mut elem);
            *dataty_ref = arena.alloc(elem);
        }
        DataTyKind::Dead(dataty_ref) => {
            let mut elem = (**dataty_ref).clone();
            visitor.visit_dty(arena, &mut elem);
            *dataty_ref = arena.alloc(elem);
        }
    }
}

/// Visit the components of a function type; the shared return type is rebuilt.
pub fn walk_fn_ty<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, fn_ty: &mut FnTy<'a>) {
    let FnTy {
        generics,
        generic_exec,
        param_sigs,
        exec,
        ret_ty,
        nat_constrs,
    } = fn_ty;
    walk_list!(visitor, visit_ident_kinded, generics, arena);
    for exec_decl in generic_exec {
        visitor.visit_ident_exec(arena, exec_decl)
    }
    walk_list!(visitor, visit_param_sig, param_sigs, arena);
    visitor.visit_exec_expr(arena, exec);
    let mut ret_ty_owned = (**ret_ty).clone();
    visitor.visit_ty(arena, &mut ret_ty_owned);
    *ret_ty = arena.alloc(ret_ty_owned);
    walk_list!(visitor, visit_nat_constr, nat_constrs, arena);
}

/// Visit both operands of a nat constraint, rebuilding the shared nodes.
pub fn walk_nat_constr<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    nat_constr: &mut NatConstr<'a>,
) {
    match nat_constr {
        NatConstr::True => {}
        NatConstr::Eq(l, r) => {
            let mut owned_l = (**l).clone();
            let mut owned_r = (**r).clone();
            visitor.visit_nat(arena, &mut owned_l);
            visitor.visit_nat(arena, &mut owned_r);
            *l = arena.alloc(owned_l);
            *r = arena.alloc(owned_r);
        }
        NatConstr::Lt(l, r) => {
            let mut owned_l = (**l).clone();
            let mut owned_r = (**r).clone();
            visitor.visit_nat(arena, &mut owned_l);
            visitor.visit_nat(arena, &mut owned_r);
            *l = arena.alloc(owned_l);
            *r = arena.alloc(owned_r);
        }
        NatConstr::And(l, r) => {
            let mut owned_l = (**l).clone();
            let mut owned_r = (**r).clone();
            visitor.visit_nat_constr(arena, &mut owned_l);
            visitor.visit_nat_constr(arena, &mut owned_r);
            *l = arena.alloc(owned_l);
            *r = arena.alloc(owned_r);
        }
        NatConstr::Or(l, r) => {
            let mut owned_l = (**l).clone();
            let mut owned_r = (**r).clone();
            visitor.visit_nat_constr(arena, &mut owned_l);
            visitor.visit_nat_constr(arena, &mut owned_r);
            *l = arena.alloc(owned_l);
            *r = arena.alloc(owned_r);
        }
    }
}

/// Visit a type; both the data-type and fn-type payloads are arena-shared and
/// therefore rebuilt from visited copies.
pub fn walk_ty<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a bumpalo::Bump, ty: &mut Ty<'a>) {
    match &mut ty.ty {
        TyKind::Data(dty) => {
            let mut dty_owned = (**dty).clone();
            visitor.visit_dty(arena, &mut dty_owned);
            *dty = arena.alloc(dty_owned);
        }
        TyKind::FnTy(fn_slot) => {
            let src: &FnTy<'a> = *fn_slot;

            let mut owned = FnTy {
                generics: src.generics.clone(),
                generic_exec: src.generic_exec.clone(),
                param_sigs: src.param_sigs.clone(),
                exec: src.exec.clone(),
                ret_ty: src.ret_ty,
                nat_constrs: src.nat_constrs.clone(),
            };

            walk_list!(visitor, visit_ident_kinded, &mut owned.generics, arena);

            if let Some(ref mut ie) = owned.generic_exec {
                visitor.visit_ident_exec(arena, ie);
            }

            walk_list!(visitor, visit_param_sig, &mut owned.param_sigs, arena);
            visitor.visit_exec_expr(arena, &mut owned.exec);

            let mut ret_owned = (*owned.ret_ty).clone();
            visitor.visit_ty(arena, &mut ret_owned);
            owned.ret_ty = arena.alloc(ret_owned);

            walk_list!(visitor, visit_nat_constr, &mut owned.nat_constrs, arena);

            *fn_slot = arena.alloc(owned);
        }
    }
}

pub fn walk_view<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, view: &mut View<'a>) {
    visitor.visit_ident(arena, &mut view.name);
    walk_list!(visitor, visit_arg_kinded, &mut view.gen_args, arena);
    for v in &mut view.args {
        visitor.visit_view(arena, v)
    }
}

/*
Superseded in-place version, kept for reference (children are arena-shared,
so the active version below rebuilds them instead):
pub fn walk_pl_expr<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    pl_expr: &mut PlaceExpr<'a>,
) {
    match &mut pl_expr.pl_expr {
        PlaceExprKind::Ident(ident) => visitor.visit_ident(arena, ident),
        PlaceExprKind::Deref(pl_expr) => visitor.visit_pl_expr(arena, pl_expr),
        PlaceExprKind::Select(p, distrib_exec) => {
            visitor.visit_pl_expr(arena, p);
            visitor.visit_exec_expr(arena, distrib_exec);
        }
        PlaceExprKind::Proj(pl_expr, _) => {
            visitor.visit_pl_expr(arena, pl_expr);
        }
        PlaceExprKind::FieldProj(pl_expr, field_name) => {
            visitor.visit_pl_expr(arena, pl_expr);
            visitor.visit_ident(arena, field_name);
        }
        PlaceExprKind::View(pl_expr, view) => {
            visitor.visit_pl_expr(arena, pl_expr);
            visitor.visit_view(arena, view);
        }
        PlaceExprKind::Idx(pl_expr, n) => {
            visitor.visit_pl_expr(arena, pl_expr);
            visitor.visit_nat(arena, n)
        }
    }
}
*/

/// Visit a place expression; every nested child reference is cloned, visited,
/// and swapped for a fresh arena allocation.
pub fn walk_pl_expr<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a bumpalo::Bump,
    pl_expr: &mut PlaceExpr<'a>,
) {
    match &mut pl_expr.pl_expr {
        PlaceExprKind::Ident(ident) => {
            visitor.visit_ident(arena, ident);
        }

        PlaceExprKind::Deref(inner_ref) => {
            let mut owned = (**inner_ref).clone();
            visitor.visit_pl_expr(arena, &mut owned);
            *inner_ref = arena.alloc(owned);
        }

        PlaceExprKind::Select(p_ref, exec_ref) => {
            let mut p_owned = (**p_ref).clone();
            visitor.visit_pl_expr(arena, &mut p_owned);
            *p_ref = arena.alloc(p_owned);

            let mut exec_owned = (**exec_ref).clone();
            visitor.visit_exec_expr(arena, &mut exec_owned);
            *exec_ref = arena.alloc(exec_owned);
        }

        PlaceExprKind::Proj(p_ref, _k) => {
            let mut p_owned = (**p_ref).clone();
            visitor.visit_pl_expr(arena, &mut p_owned);
            *p_ref = arena.alloc(p_owned);
        }

        PlaceExprKind::FieldProj(p_ref, field_ref) => {
            let mut p_owned = (**p_ref).clone();
            visitor.visit_pl_expr(arena, &mut p_owned);
            *p_ref = arena.alloc(p_owned);

            let mut field_owned = (**field_ref).clone();
            visitor.visit_ident(arena, &mut field_owned);
            *field_ref = arena.alloc(field_owned);
        }

        PlaceExprKind::View(p_ref, view_ref) => {
            let mut p_owned = (**p_ref).clone();
            visitor.visit_pl_expr(arena, &mut p_owned);
            *p_ref = arena.alloc(p_owned);

            let mut view_owned = (**view_ref).clone();
            visitor.visit_view(arena, &mut view_owned);
            *view_ref = arena.alloc(view_owned);
        }

        PlaceExprKind::Idx(p_ref, n_ref) => {
            let mut p_owned = (**p_ref).clone();
            visitor.visit_pl_expr(arena, &mut p_owned);
            *p_ref = arena.alloc(p_owned);

            let mut n_owned = (**n_ref).clone();
            visitor.visit_nat(arena, &mut n_owned);
            *n_ref = arena.alloc(n_owned);
        }
    }
}

pub fn walk_arg_kinded<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    arg_kinded: &mut ArgKinded<'a>,
) {
    match arg_kinded {
        ArgKinded::Ident(ident) => visitor.visit_ident(arena, ident),
        ArgKinded::Nat(n) => visitor.visit_nat(arena, n),
        ArgKinded::Memory(mem) => visitor.visit_mem(arena, mem),
        ArgKinded::DataTy(dty) => visitor.visit_dty(arena, dty),
        ArgKinded::Provenance(prv) => visitor.visit_prv(arena, prv),
    }
}

pub fn walk_pattern<'a, V: VisitMut<'a>>(
    visitor: &mut V,
    arena: &'a Bump,
    pattern: &mut Pattern<'a>,
) {
    match pattern {
        Pattern::Ident(mutab, ident) => {
            visitor.visit_mutability(mutab);
            visitor.visit_ident(arena, ident);
        }
        Pattern::Tuple(patterns) => {
            walk_list!(visitor, visit_pattern, patterns, arena)
        }
        Pattern::Wildcard => {}
    }
}

/// Visit a parallel split; the shared split-exec expression is rebuilt.
pub fn walk_split<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, indep: &mut Split<'a>) {
    let Split {
        dim_compo,
        pos,
        split_exec,
        branch_idents,
        branch_bodies,
    } = indep;
    visitor.visit_dim_compo(dim_compo);
    visitor.visit_nat(arena, pos);
    let mut split_exec_owned = (**split_exec).clone();
    visitor.visit_exec_expr(arena, &mut split_exec_owned);
    *split_exec = arena.alloc(split_exec_owned);
    walk_list!(visitor, visit_ident, branch_idents, arena);
    walk_list!(visitor, visit_expr, branch_bodies, arena);
}

/// Visit a schedule; shared exec expression and body are rebuilt.
pub fn walk_sched<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, sched: &mut Sched<'a>) {
    let Sched {
        dim,
        inner_exec_ident,
        sched_exec,
        body,
    } = sched;
    visitor.visit_dim_compo(dim);
    for ident in inner_exec_ident {
        visitor.visit_ident(arena, ident)
    }
    let mut sched_exec_owned = (**sched_exec).clone();
    visitor.visit_exec_expr(arena, &mut sched_exec_owned);
    *sched_exec = arena.alloc(sched_exec_owned);

    let mut body_owned = (**body).clone();
    visitor.visit_block(arena, &mut body_owned);
    *body = arena.alloc(body_owned);
}

/// Visit every child of an expression; each shared child node is cloned,
/// visited, and swapped for a fresh arena allocation.
pub fn walk_expr<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, expr: &mut Expr<'a>) {
    // For now, only visit ExprKind
    match &mut expr.expr {
        ExprKind::Lit(l) => visitor.visit_lit(l),
        ExprKind::PlaceExpr(pl_ref) => {
            let mut owned = (**pl_ref).clone();
            visitor.visit_pl_expr(arena, &mut owned);
            *pl_ref = arena.alloc(owned);
        }
        ExprKind::Ref(_, own, pl_ref) => {
            visitor.visit_own(own);
            let mut owned = (**pl_ref).clone();
            visitor.visit_pl_expr(arena, &mut owned);
            *pl_ref = arena.alloc(owned);
        }
        ExprKind::Block(block_ref) => {
            let mut owned = (**block_ref).clone();
            visitor.visit_block(arena, &mut owned);
            *block_ref = arena.alloc(owned);
        }
        ExprKind::LetUninit(maybe_exec_expr, ident_ref, ty_ref) => {
            if let Some(slot) = maybe_exec_expr.as_mut() {
                let mut owned = (**slot).clone();
                visitor.visit_exec_expr(arena, &mut owned);
                *slot = arena.alloc(owned);
+ } + + visitor.visit_ident(arena, ident_ref); + + let mut ty_owned = (**ty_ref).clone(); + visitor.visit_ty(arena, &mut ty_owned); + *ty_ref = arena.alloc(ty_owned); + } + ExprKind::Let(pattern, ty_opt, e) => { + visitor.visit_pattern(arena, pattern); + if let Some(slot) = ty_opt.as_mut() { + let mut ty_owned = (**slot).clone(); + visitor.visit_ty(arena, &mut ty_owned); + *slot = arena.alloc(ty_owned); + } + let mut e_owned = (**e).clone(); + visitor.visit_expr(arena, &mut e_owned); + *e = arena.alloc(e_owned); + } + ExprKind::Assign(pl_expr, expr) => { + let mut pl_expr_owned = (**pl_expr).clone(); + let mut expr_owned = (**expr).clone(); + visitor.visit_pl_expr(arena, &mut pl_expr_owned); + visitor.visit_expr(arena, &mut expr_owned); + *pl_expr = arena.alloc(pl_expr_owned); + *expr = arena.alloc(expr_owned); + } + ExprKind::IdxAssign(pl_expr, idx, expr) => { + let mut pl_expr_owned = (**pl_expr).clone(); + let mut expr_owned = (**expr).clone(); + visitor.visit_pl_expr(arena, &mut pl_expr_owned); + visitor.visit_nat(arena, idx); + visitor.visit_expr(arena, &mut expr_owned); + *pl_expr = arena.alloc(pl_expr_owned); + *expr = arena.alloc(expr_owned); + } + ExprKind::Seq(es) => { + for e in es { + visitor.visit_expr(arena, e) + } + } + // ExprKind::Lambda(params, exec_decl, dty, expr) => { + // walk_list!(visitor, visit_param_decl, params); + // visitor.visit_ident_exec(exec_decl); + // visitor.visit_dty(dty); + // visitor.visit_expr(expr) + // } + ExprKind::App(f, gen_args, args) => { + let mut f_owned = (**f).clone(); + visitor.visit_ident(arena, &mut f_owned); + *f = arena.alloc(f_owned); + walk_list!(visitor, visit_arg_kinded, gen_args, arena); + walk_list!(visitor, visit_expr, args, arena); + } + ExprKind::DepApp(f, gen_args) => { + visitor.visit_ident(arena, f); + walk_list!(visitor, visit_arg_kinded, gen_args, arena); + } + ExprKind::AppKernel(app_kernel_ref) => { + let src: &AppKernel<'a> = *app_kernel_ref; + + let mut tmp = AppKernel { + grid_dim: 
src.grid_dim.clone(), + block_dim: src.block_dim.clone(), + shared_mem_dtys: src.shared_mem_dtys.clone(), + shared_mem_prvs: src.shared_mem_prvs.clone(), + fun_ident: { + let mut id = (*src.fun_ident).clone(); + visitor.visit_ident(arena, &mut id); + arena.alloc(id) + }, + gen_args: { + let mut v = src.gen_args.clone(); + for a in v.iter_mut() { + visitor.visit_arg_kinded(arena, a); + } + v + }, + args: { + let mut v = src.args.clone(); + for e in v.iter_mut() { + visitor.visit_expr(arena, e); + } + v + }, + }; + + visitor.visit_dim(arena, &mut tmp.grid_dim); + visitor.visit_dim(arena, &mut tmp.block_dim); + for d in tmp.shared_mem_dtys.iter_mut() { + visitor.visit_dty(arena, d); + } + + *app_kernel_ref = arena.alloc(tmp); + } + ExprKind::IfElse(cond_ref, tt_ref, ff_ref) => { + let mut cond = (**cond_ref).clone(); + let mut tt = (**tt_ref).clone(); + let mut ff = (**ff_ref).clone(); + visitor.visit_expr(arena, &mut cond); + visitor.visit_expr(arena, &mut tt); + visitor.visit_expr(arena, &mut ff); + *cond_ref = arena.alloc(cond); + *tt_ref = arena.alloc(tt); + *ff_ref = arena.alloc(ff) + } + ExprKind::If(cond, tt) => { + let mut cond_owned = (**cond).clone(); + let mut tt_owned = (**tt).clone(); + visitor.visit_expr(arena, &mut cond_owned); + visitor.visit_expr(arena, &mut tt_owned); + *cond = arena.alloc(cond_owned); + *tt = arena.alloc(tt_owned); + } + ExprKind::Array(elems) => { + walk_list!(visitor, visit_expr, elems, arena); + } + ExprKind::Tuple(elems) => { + walk_list!(visitor, visit_expr, elems, arena); + } + ExprKind::For(ident, coll_ref, body_ref) => { + visitor.visit_ident(arena, ident); + let mut coll = (**coll_ref).clone(); + let mut body = (**body_ref).clone(); + visitor.visit_expr(arena, &mut coll); + visitor.visit_expr(arena, &mut body); + *coll_ref = arena.alloc(coll); + *body_ref = arena.alloc(body); + } + ExprKind::Split(split_ref) => { + let src: &Split<'a> = *split_ref; + + let mut dim = src.dim_compo; + visitor.visit_dim_compo(&mut dim); + + 
let mut pos = src.pos.clone(); + visitor.visit_nat(arena, &mut pos); + + let mut exec_owned = (*src.split_exec).clone(); + visitor.visit_exec_expr(arena, &mut exec_owned); + let exec_ref: &'a ExecExpr<'a> = arena.alloc(exec_owned); + + let mut branch_idents = src.branch_idents.clone(); + for id in branch_idents.iter_mut() { + visitor.visit_ident(arena, id); + } + + let mut branch_bodies = src.branch_bodies.clone(); + for body in branch_bodies.iter_mut() { + visitor.visit_expr(arena, body); + } + + let new_split = Split { + dim_compo: dim, + pos, + split_exec: exec_ref, + branch_idents, + branch_bodies, + }; + *split_ref = arena.alloc(new_split); + } + ExprKind::Sched(sched) => { + let mut sched_owned = (**sched).clone(); + visitor.visit_sched(arena, &mut sched_owned); + *sched = arena.alloc(sched_owned); + } + ExprKind::ForNat(ident, range, body) => { + visitor.visit_ident(arena, ident); + let mut range_owned = (**range).clone(); + let mut body_owned = (**body).clone(); + visitor.visit_nat_range(arena, &mut range_owned); + visitor.visit_expr(arena, &mut body_owned); + *range = arena.alloc(range_owned); + *body = arena.alloc(body_owned); + } + ExprKind::While(cond, body) => { + let mut cond_owned = (**cond).clone(); + let mut body_owned = (**body).clone(); + visitor.visit_expr(arena, &mut cond_owned); + visitor.visit_expr(arena, &mut body_owned); + *cond = arena.alloc(cond_owned); + *body = arena.alloc(body_owned); + } + ExprKind::BinOp(op, l, r) => { + visitor.visit_binary_op(op); + let mut l_owned = (**l).clone(); + let mut r_owned = (**r).clone(); + visitor.visit_expr(arena, &mut l_owned); + visitor.visit_expr(arena, &mut r_owned); + *l = arena.alloc(l_owned); + *r = arena.alloc(r_owned); + } + ExprKind::UnOp(op, expr) => { + visitor.visit_unary_op(op); + let mut expr_owned = (**expr).clone(); + visitor.visit_expr(arena, &mut expr_owned); + *expr = arena.alloc(expr_owned); + } + ExprKind::Sync(exec) => { + for e in exec { + visitor.visit_exec_expr(arena, e) + } + 
} + ExprKind::Unsafe(expr) => { + let mut expr_owned = (**expr).clone(); + visitor.visit_expr(arena, &mut expr_owned); + *expr = arena.alloc(expr_owned); + } + ExprKind::Cast(expr, dty) => { + let mut expr_owned = (**expr).clone(); + let mut dty_owned = (**dty).clone(); + visitor.visit_expr(arena, &mut expr_owned); + visitor.visit_dty(arena, &mut dty_owned); + *expr = arena.alloc(expr_owned); + *dty = arena.alloc(dty_owned); + } + ExprKind::Range(_, _) | ExprKind::Hole => (), + } +} + +pub fn walk_app_kernel<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + app_kernel: &mut AppKernel<'a>, +) { + let AppKernel { + grid_dim, + block_dim, + shared_mem_dtys, + shared_mem_prvs: _, + fun_ident, + gen_args, + args, + } = app_kernel; + visitor.visit_dim(arena, grid_dim); + visitor.visit_dim(arena, block_dim); + for dty in shared_mem_dtys { + visitor.visit_dty(arena, dty); + } + let mut fun_owned = (**fun_ident).clone(); + visitor.visit_ident(arena, &mut fun_owned); + *fun_ident = arena.alloc(fun_owned); + for garg in gen_args { + visitor.visit_arg_kinded(arena, garg); + } + for arg in args { + visitor.visit_expr(arena, arg); + } +} + +pub fn walk_block<'a, V: VisitMut<'a>>(visitor: &mut V, arena: &'a Bump, block: &mut Block<'a>) { + let Block { body, .. 
} = block; + let mut body_owned = (**body).clone(); + visitor.visit_expr(arena, &mut body_owned); + *body = arena.alloc(body_owned); +} + +pub fn walk_split_proj<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + split_proj: &mut TakeRange<'a>, +) { + let TakeRange { + split_dim, + pos, + left_or_right: _, + } = split_proj; + visitor.visit_dim_compo(split_dim); + visitor.visit_nat(arena, pos); +} + +/** +pub fn walk_exec_expr<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + exec_expr: &mut ExecExpr<'a>, +) { + visitor.visit_exec(arena, &mut exec_expr.exec); + for t in &mut exec_expr.ty { + visitor.visit_exec_ty(t); + } +} +*/ + +pub fn walk_exec_expr<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a bumpalo::Bump, + exec_expr: &mut ExecExpr<'a>, +) { + let mut exec_owned = (*exec_expr.exec).clone(); + visitor.visit_exec(arena, &mut exec_owned); + exec_expr.exec = arena.alloc(exec_owned); + + if let Some(ty_ref) = &mut exec_expr.ty { + let mut ty_owned = (**ty_ref).clone(); + visitor.visit_exec_ty(&mut ty_owned); + *ty_ref = arena.alloc(ty_owned); + } +} + +pub fn walk_exec<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + exec: &mut ExecExprKind<'a>, +) { + let ExecExprKind { base, path } = exec; + match base { + BaseExec::CpuThread => (), + BaseExec::Ident(ident) => visitor.visit_ident(arena, ident), + BaseExec::GpuGrid(gdim, bdim) => { + let mut gdim_owned = (**gdim).clone(); + let mut bdim_owned = (**bdim).clone(); + visitor.visit_dim(arena, &mut gdim_owned); + visitor.visit_dim(arena, &mut bdim_owned); + *gdim = arena.alloc(gdim_owned); + *bdim = arena.alloc(bdim_owned); + } + }; + for e in path { + visitor.visit_exec_path_elem(arena, e) + } +} + +pub fn walk_exec_path_elem<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + exec_path_elem: &mut ExecPathElem<'a>, +) { + match exec_path_elem { + ExecPathElem::TakeRange(split_proj) => { + let mut split_proj_owned = (**split_proj).clone(); + 
visitor.visit_split_proj(arena, &mut split_proj_owned); + *split_proj = arena.alloc(split_proj_owned); + } + ExecPathElem::ForAll(dim_compo) => visitor.visit_dim_compo(dim_compo), + ExecPathElem::ToWarps => {} + ExecPathElem::ToThreads(dim_compo) => visitor.visit_dim_compo(dim_compo), + } +} + +pub fn walk_param_decl<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + param_decl: &mut ParamDecl<'a>, +) { + let ParamDecl { + ident, + ty, + mutbl, + exec_expr, + } = param_decl; + visitor.visit_ident(arena, ident); + if let Some(tty) = ty { + let mut tty_owned = (**tty).clone(); + visitor.visit_ty(arena, &mut tty_owned); + *tty = arena.alloc(tty_owned); + } + visitor.visit_mutability(mutbl); + for ex in exec_expr { + visitor.visit_exec_expr(arena, ex); + } +} + +pub fn walk_fun_def<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + fun_def: &mut FunDef<'a>, +) { + let FunDef { + ident: _, + generic_params, + generic_exec, + param_decls: params, + ret_dty, + exec, + prv_rels, + body, + } = fun_def; + walk_list!(visitor, visit_ident_kinded, generic_params, arena); + for exec_decl in generic_exec { + visitor.visit_ident_exec(arena, exec_decl); + } + walk_list!(visitor, visit_param_decl, params, arena); + let mut ret_dty_owned = (**ret_dty).clone(); + visitor.visit_dty(arena, &mut ret_dty_owned); + *ret_dty = arena.alloc(ret_dty_owned); + + visitor.visit_exec_expr(arena, exec); + walk_list!(visitor, visit_prv_rel, prv_rels, arena); + + let mut body_owned = (**body).clone(); + visitor.visit_block(arena, &mut body_owned); + *body = arena.alloc(body_owned); +} + +pub fn walk_fun_decl<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + fun_decl: &mut FunDecl<'a>, +) { + let FunDecl { + ident: _, + generic_params, + generic_exec, + param_decls: params, + ret_dty, + exec, + prv_rels, + } = fun_decl; + walk_list!(visitor, visit_ident_kinded, generic_params, arena); + for exec_decl in generic_exec { + visitor.visit_ident_exec(arena, exec_decl); + } 
+ walk_list!(visitor, visit_param_decl, params, arena); + let mut ret_dty_owned = (**ret_dty).clone(); + visitor.visit_dty(arena, &mut ret_dty_owned); + *ret_dty = arena.alloc(ret_dty_owned); + visitor.visit_exec_expr(arena, exec); + walk_list!(visitor, visit_prv_rel, prv_rels, arena); +} + +pub fn walk_param_sig<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + param_sig: &mut ParamSig<'a>, +) { + let ParamSig { exec_expr, ty } = param_sig; + visitor.visit_exec_expr(arena, exec_expr); + let mut ty_owned = (**ty).clone(); + visitor.visit_ty(arena, &mut ty_owned); + *ty = arena.alloc(ty_owned); +} + +pub fn walk_field<'a, V: VisitMut<'a>>( + visitor: &mut V, + arena: &'a Bump, + field: &mut (Ident<'a>, DataTy<'a>), +) { + let (ident, dty) = field; + visitor.visit_ident(arena, ident); + visitor.visit_dty(arena, dty); +} diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 0d095fab..280a2b0a 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -4,7 +4,9 @@ use crate::ast::internal::PathElem; use descend_derive::span_derive; pub use span::*; +use crate::arena_ast; use crate::parser::SourceCode; +use bumpalo::{collections::CollectIn, collections::Vec as BumpVec, Bump}; pub mod internal; @@ -13,6 +15,7 @@ mod span; pub mod utils; pub mod visit; pub mod visit_mut; +use std::cell::OnceCell; #[derive(Clone, Debug)] pub struct CompilUnit<'a> { @@ -33,6 +36,25 @@ pub enum Item { StructDecl(Box), } +impl Item { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Item<'a> { + match self { + Item::FunDef(f) => { + let arena_f = f.into_arena(arena); + arena_ast::Item::FunDef(arena.alloc(arena_f)) + } + Item::FunDecl(f) => { + let arena_fd = f.into_arena(arena); + arena_ast::Item::FunDecl(arena.alloc(arena_fd)) + } + Item::StructDecl(s) => { + let arena_s = s.into_arena(arena); + arena_ast::Item::StructDecl(arena.alloc(arena_s)) + } + } + } +} + #[derive(Debug, Clone, PartialEq)] pub struct FunDecl { pub ident: Ident, @@ -65,6 +87,36 @@ impl FunDecl { 
nat_constrs: vec![], } } + + pub fn into_arena<'a>(self: Box, arena: &'a Bump) -> arena_ast::FunDecl<'a> { + let generic_params = self + .generic_params + .into_iter() + .map(|g| g.into_arena(arena)) + .collect_in(arena); + + let param_decls = self + .param_decls + .into_iter() + .map(|p| p.into_arena(arena)) + .collect_in(arena); + + let prv_rels = self + .prv_rels + .into_iter() + .map(|r| r.into_arena(arena)) + .collect_in(arena); + + arena_ast::FunDecl { + ident: self.ident.into_arena(arena), + generic_params, + generic_exec: self.generic_exec.map(|g| g.into_arena(arena)), + param_decls, + ret_dty: arena.alloc(self.ret_dty.into_arena(arena)), + exec: self.exec.into_arena(arena), + prv_rels, + } + } } #[derive(Debug, Clone, Eq, Hash, PartialEq)] @@ -74,6 +126,28 @@ pub struct StructDecl { pub fields: Vec<(Ident, DataTy)>, } +impl StructDecl { + pub fn into_arena<'a>(self: Box, arena: &'a Bump) -> arena_ast::StructDecl<'a> { + let generic_params = self + .generic_params + .into_iter() + .map(|gp| gp.into_arena(arena)) + .collect_in(arena); + + let fields = self + .fields + .into_iter() + .map(|(i, dty)| (i.into_arena(arena), dty.into_arena(arena))) + .collect_in(arena); + + arena_ast::StructDecl { + ident: self.ident.into_arena(arena), + generic_params, + fields, + } + } +} + // TODO refactor to make use of FunDecl #[derive(Debug, Clone, PartialEq)] pub struct FunDef { @@ -108,6 +182,43 @@ impl FunDef { nat_constrs: vec![], } } + + pub fn into_arena<'a>(self: Box, arena: &'a Bump) -> arena_ast::FunDef<'a> { + let generic_params = self + .generic_params + .into_iter() + .map(|g| g.into_arena(arena)) + .collect_in(arena); + + let generic_exec = self.generic_exec.map(|g| g.into_arena(arena)); + + let param_decls = self + .param_decls + .into_iter() + .map(|p| p.into_arena(arena)) + .collect_in(arena); + + let ret_dty = arena.alloc(self.ret_dty.into_arena(arena)); + + let prv_rels = self + .prv_rels + .into_iter() + .map(|r| r.into_arena(arena)) + .collect_in(arena); 
+ + let body = arena.alloc(self.body.into_arena(arena)); + + arena_ast::FunDef { + ident: self.ident.into_arena(arena), + generic_params, + generic_exec, + param_decls, + ret_dty, + exec: self.exec.into_arena(arena), + prv_rels, + body, + } + } } #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -123,6 +234,13 @@ impl IdentExec { ty: Box::new(exec_ty), } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::IdentExec<'a> { + arena_ast::IdentExec { + ident: self.ident.into_arena(arena), + ty: arena.alloc(self.ty.into_arena(arena)), + } + } } #[derive(Debug, Clone, PartialEq, Eq)] @@ -133,6 +251,22 @@ pub struct ParamDecl { pub exec_expr: Option, } +impl ParamDecl { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ParamDecl<'a> { + let ty = self.ty.map(|t| &*arena.alloc(t.into_arena(arena))); + let exec_expr = self + .exec_expr + .map(|e| arena.alloc(e.into_arena(arena)).clone()); + + arena_ast::ParamDecl { + ident: self.ident.into_arena(arena), + ty, + mutbl: self.mutbl.into_arena(), + exec_expr, + } + } +} + #[span_derive(PartialEq)] #[derive(Debug, Clone)] pub struct Expr { @@ -169,6 +303,14 @@ impl Expr { } } + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Expr<'a> { + arena_ast::Expr { + expr: self.expr.into_arena(arena), + ty: self.ty.map(|t| &*arena.alloc(t.into_arena(arena))), + span: self.span, + } + } + // pub fn subst_idents(&mut self, subst_map: &HashMap<&str, &Expr>) { // fn pl_expr_contains_name_in<'a, I>(pl_expr: &PlaceExpr, mut idents: I) -> bool // where @@ -273,6 +415,15 @@ impl Sched { body: Box::new(body), } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Sched<'a> { + arena_ast::Sched { + dim: self.dim.into_arena(), + inner_exec_ident: self.inner_exec_ident.map(|id| id.into_arena(arena)), + sched_exec: arena.alloc(self.sched_exec.into_arena(arena)), + body: arena.alloc(self.body.into_arena(arena)), + } + } } #[derive(PartialEq, Debug, Clone)] @@ -300,6 +451,30 @@ impl Split { branch_bodies, } } + 
+ pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Split<'a> { + let mut branch_idents = bumpalo::collections::Vec::new_in(arena); + branch_idents.extend( + self.branch_idents + .into_iter() + .map(|ident| ident.into_arena(arena)), + ); + + let mut branch_bodies = bumpalo::collections::Vec::new_in(arena); + branch_bodies.extend( + self.branch_bodies + .into_iter() + .map(|expr| expr.into_arena(arena)), + ); + + arena_ast::Split { + dim_compo: self.dim_compo.into_arena(), + pos: self.pos.into_arena(arena), + split_exec: arena.alloc(self.split_exec.into_arena(arena)), + branch_idents, + branch_bodies, + } + } } #[derive(PartialEq, Debug, Clone)] @@ -316,6 +491,15 @@ impl Block { } } + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Block<'a> { + let prvs: bumpalo::collections::Vec<'a, String> = self.prvs.into_iter().collect_in(arena); + + arena_ast::Block { + prvs, + body: arena.alloc(self.body.into_arena(arena)), + } + } + pub fn with_prvs(prvs: Vec, body: Expr) -> Self { Block { prvs, @@ -335,6 +519,36 @@ pub struct AppKernel { pub args: Vec, } +impl AppKernel { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::AppKernel<'a> { + arena_ast::AppKernel { + grid_dim: self.grid_dim.into_arena(arena), + block_dim: self.block_dim.into_arena(arena), + shared_mem_dtys: self + .shared_mem_dtys + .iter() + .map(|dty| dty.clone().into_arena(arena)) + .collect_in(arena), + shared_mem_prvs: self + .shared_mem_prvs + .iter() + .map(|s| arena.alloc_str(s).to_string()) + .collect_in(arena), + fun_ident: arena.alloc(self.fun_ident.clone().into_arena(arena)), + gen_args: self + .gen_args + .iter() + .map(|arg| arg.into_arena(arena)) + .collect_in(arena), + args: self + .args + .iter() + .map(|arg| arg.clone().into_arena(arena)) + .collect_in(arena), + } + } +} + #[derive(PartialEq, Debug, Clone)] pub enum ExprKind { Hole, @@ -393,6 +607,125 @@ pub enum ExprKind { Range(Box, Box), } +impl ExprKind { + pub fn into_arena<'a>(self, arena: &'a Bump) -> 
arena_ast::ExprKind<'a> { + use ExprKind::*; + match self { + Hole => arena_ast::ExprKind::Hole, + Lit(l) => arena_ast::ExprKind::Lit(l.into_arena(arena)), // assuming `Lit` is Copy or doesn't need arena + PlaceExpr(p) => arena_ast::ExprKind::PlaceExpr(arena.alloc(p.into_arena(arena))), + Array(exprs) => { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + for e in exprs { + bump_vec.push(e.into_arena(arena)); + } + arena_ast::ExprKind::Array(bump_vec) + } + Tuple(exprs) => { + let mut bump_vec = bumpalo::collections::Vec::new_in(arena); + for e in exprs { + bump_vec.push(e.into_arena(arena)); + } + arena_ast::ExprKind::Tuple(bump_vec) + } + Ref(ann, own, pl) => { + arena_ast::ExprKind::Ref(ann, own.into_arena(), arena.alloc(pl.into_arena(arena))) + } + Block(b) => { + let b_ref = arena.alloc(b.into_arena(arena)); + arena_ast::ExprKind::Block(b_ref) + } + LetUninit(exec, ident, ty) => arena_ast::ExprKind::LetUninit( + exec.map(|e| &*arena.alloc(e.into_arena(arena))), + ident.into_arena(arena), + arena.alloc(ty.into_arena(arena)), + ), + Let(pat, ty, expr) => arena_ast::ExprKind::Let( + pat.into_arena(arena), + ty.map(|t| &*arena.alloc(t.into_arena(arena))), + arena.alloc(expr.into_arena(arena)), + ), + Assign(pl, val) => arena_ast::ExprKind::Assign( + arena.alloc(pl.into_arena(arena)), + arena.alloc(val.into_arena(arena)), + ), + IdxAssign(pl, nat, val) => arena_ast::ExprKind::IdxAssign( + arena.alloc(pl.into_arena(arena)), + nat.into_arena(arena), + arena.alloc(val.into_arena(arena)), + ), + Seq(exprs) => arena_ast::ExprKind::Seq( + exprs + .into_iter() + .map(|e| e.into_arena(arena)) + .collect_in(arena), + ), + App(ident, args, exprs) => arena_ast::ExprKind::App( + arena.alloc(ident.into_arena(arena)), + args.into_iter() + .map(|a| a.into_arena(arena)) + .collect_in(arena), + exprs + .into_iter() + .map(|e| e.into_arena(arena)) + .collect_in(arena), + ), + DepApp(ident, args) => arena_ast::ExprKind::DepApp( + ident.into_arena(arena), + 
args.into_iter() + .map(|a| a.into_arena(arena)) + .collect_in(arena), + ), + ExprKind::AppKernel(kern) => { + arena_ast::ExprKind::AppKernel(arena.alloc(kern.into_arena(arena))) + } + IfElse(cond, then_, else_) => arena_ast::ExprKind::IfElse( + arena.alloc(cond.into_arena(arena)), + arena.alloc(then_.into_arena(arena)), + arena.alloc(else_.into_arena(arena)), + ), + If(cond, body) => arena_ast::ExprKind::If( + arena.alloc(cond.into_arena(arena)), + arena.alloc(body.into_arena(arena)), + ), + For(ident, iter, body) => arena_ast::ExprKind::For( + ident.into_arena(arena), + arena.alloc(iter.into_arena(arena)), + arena.alloc(body.into_arena(arena)), + ), + ForNat(ident, range, body) => arena_ast::ExprKind::ForNat( + ident.into_arena(arena), + arena.alloc(range.into_arena(arena)), + arena.alloc(body.into_arena(arena)), + ), + While(cond, body) => arena_ast::ExprKind::While( + arena.alloc(cond.into_arena(arena)), + arena.alloc(body.into_arena(arena)), + ), + BinOp(op, lhs, rhs) => arena_ast::ExprKind::BinOp( + op.into_arena(), + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + UnOp(op, expr) => { + arena_ast::ExprKind::UnOp(op.into_arena(), arena.alloc(expr.into_arena(arena))) + } + Cast(expr, dty) => arena_ast::ExprKind::Cast( + arena.alloc(expr.into_arena(arena)), + arena.alloc(dty.into_arena(arena)), + ), + Split(split) => arena_ast::ExprKind::Split(arena.alloc(split.into_arena(arena))), + Sched(sched) => arena_ast::ExprKind::Sched(arena.alloc(sched.into_arena(arena))), + Sync(exec_opt) => arena_ast::ExprKind::Sync(exec_opt.map(|e| e.into_arena(arena))), + Unsafe(expr) => arena_ast::ExprKind::Unsafe(arena.alloc(expr.into_arena(arena))), + Range(start, end) => arena_ast::ExprKind::Range( + arena.alloc(start.into_arena(arena)), + arena.alloc(end.into_arena(arena)), + ), + } + } +} + #[span_derive(PartialEq, Eq, Hash)] #[derive(Clone, Debug)] pub struct Ident { @@ -428,6 +761,14 @@ impl Ident { is_implicit: false, } } + + pub fn 
into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Ident<'a> { + arena_ast::Ident { + name: arena.alloc_str(&self.name), + span: self.span, + is_implicit: self.is_implicit, + } + } } #[derive(Debug, Clone, PartialEq)] @@ -437,6 +778,23 @@ pub enum Pattern { Wildcard, } +impl Pattern { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::Pattern<'a> { + use Pattern::*; + match self { + Ident(mutability, ident) => { + arena_ast::Pattern::Ident(mutability.into_arena(), ident.clone().into_arena(arena)) + } + Tuple(pats) => { + let arena_vec: bumpalo::collections::Vec<'a, _> = + pats.iter().map(|p| p.into_arena(arena)).collect_in(arena); + arena_ast::Pattern::Tuple(arena_vec) + } + Wildcard => arena_ast::Pattern::Wildcard, + } + } +} + #[derive(Debug, Copy, Clone, PartialEq)] pub enum Lit { Unit, @@ -449,6 +807,24 @@ pub enum Lit { F64(f64), } +impl Lit { + pub fn into_arena<'a>(&self, _arena: &'a Bump) -> arena_ast::Lit { + use arena_ast::Lit as ALit; + use Lit::*; + + match self { + Unit => ALit::Unit, + Bool(b) => ALit::Bool(*b), + I32(i) => ALit::I32(*i), + U8(u) => ALit::U8(*u), + U32(u) => ALit::U32(*u), + U64(u) => ALit::U64(*u), + F32(f) => ALit::F32(*f), + F64(f) => ALit::F64(*f), + } + } +} + // impl PartialEq for Lit{ // fn eq(&self, other:&Self) -> bool { // let b = match (self, other) { @@ -483,6 +859,15 @@ pub enum Mutability { Mut, } +impl Mutability { + pub fn into_arena(&self) -> arena_ast::Mutability { + match self { + Mutability::Mut => arena_ast::Mutability::Mut, + Mutability::Const => arena_ast::Mutability::Const, + } + } +} + impl fmt::Display for Mutability { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let str = match self { @@ -499,12 +884,30 @@ pub enum Ownership { Uniq, } +impl Ownership { + pub fn into_arena(&self) -> arena_ast::Ownership { + match self { + Ownership::Shrd => arena_ast::Ownership::Shrd, + Ownership::Uniq => arena_ast::Ownership::Uniq, + } + } +} + #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub enum 
UnOp { Not, Neg, } +impl UnOp { + pub fn into_arena(&self) -> arena_ast::UnOp { + match self { + UnOp::Not => arena_ast::UnOp::Not, + UnOp::Neg => arena_ast::UnOp::Neg, + } + } +} + impl fmt::Display for UnOp { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let str = match self { @@ -536,6 +939,30 @@ pub enum BinOp { BitAnd, } +impl BinOp { + pub fn into_arena(&self) -> arena_ast::BinOp { + match self { + BinOp::Add => arena_ast::BinOp::Add, + BinOp::Sub => arena_ast::BinOp::Sub, + BinOp::Mul => arena_ast::BinOp::Mul, + BinOp::Div => arena_ast::BinOp::Div, + BinOp::Mod => arena_ast::BinOp::Mod, + BinOp::And => arena_ast::BinOp::And, + BinOp::Or => arena_ast::BinOp::Or, + BinOp::Eq => arena_ast::BinOp::Eq, + BinOp::Lt => arena_ast::BinOp::Lt, + BinOp::Le => arena_ast::BinOp::Le, + BinOp::Gt => arena_ast::BinOp::Gt, + BinOp::Ge => arena_ast::BinOp::Ge, + BinOp::Neq => arena_ast::BinOp::Neq, + BinOp::Shl => arena_ast::BinOp::Shl, + BinOp::Shr => arena_ast::BinOp::Shr, + BinOp::BitOr => arena_ast::BinOp::BitOr, + BinOp::BitAnd => arena_ast::BinOp::BitAnd, + } + } +} + impl fmt::Display for BinOp { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let str = match self { @@ -569,6 +996,17 @@ pub enum Kind { Provenance, } +impl Kind { + pub fn into_arena(&self, _arena: &Bump) -> arena_ast::Kind { + match self { + Kind::Nat => arena_ast::Kind::Nat, + Kind::Memory => arena_ast::Kind::Memory, + Kind::DataTy => arena_ast::Kind::DataTy, + Kind::Provenance => arena_ast::Kind::Provenance, + } + } +} + #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum ArgKinded { Ident(Ident), @@ -601,6 +1039,18 @@ impl ArgKinded { _ => Ok(false), } } + + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::ArgKinded<'a> { + use ArgKinded::*; + + match self { + Ident(ident) => arena_ast::ArgKinded::Ident(ident.clone().into_arena(arena)), + Nat(nat) => arena_ast::ArgKinded::Nat(nat.into_arena(arena)), + Memory(mem) => arena_ast::ArgKinded::Memory(mem.into_arena(arena)), + 
DataTy(dty) => arena_ast::ArgKinded::DataTy(dty.clone().into_arena(arena)), + Provenance(prv) => arena_ast::ArgKinded::Provenance(prv.into_arena(arena)), + } + } } #[span_derive(PartialEq, Eq, Hash)] @@ -614,6 +1064,24 @@ pub struct PlaceExpr { pub span: Option, } +impl PlaceExpr { + pub fn into_arena<'a>(self, arena: &'a bumpalo::Bump) -> arena_ast::PlaceExpr<'a> { + let out = arena_ast::PlaceExpr { + pl_expr: self.pl_expr.into_arena(arena), + ty: OnceCell::new(), + span: self.span, + }; + + if let Some(t) = self.ty { + let t_arena = t.into_arena(arena); + let t_ref: &'a arena_ast::Ty<'a> = arena.alloc(t_arena); + let _ = out.ty.set(t_ref); + } + + out + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct View { pub name: Ident, @@ -644,6 +1112,26 @@ impl View { } Ok(true) } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::View<'a> { + let gen_args = self + .gen_args + .into_iter() + .map(|g| g.into_arena(arena)) + .collect_in(arena); + + let args = self + .args + .into_iter() + .map(|v| v.into_arena(arena)) + .collect_in(arena); + + arena_ast::View { + name: self.name.into_arena(arena), + gen_args, + args, + } + } } // TODO create generic View struct to enable easier extensibility by introducing only @@ -676,6 +1164,34 @@ pub enum PlaceExprKind { Ident(Ident), } +impl PlaceExprKind { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::PlaceExprKind<'a> { + use PlaceExprKind::*; + + match self { + View(pl, view) => arena_ast::PlaceExprKind::View( + arena.alloc(pl.into_arena(arena)), + arena.alloc(view.into_arena(arena)), + ), + Select(pl, exec) => arena_ast::PlaceExprKind::Select( + arena.alloc(pl.into_arena(arena)), + arena.alloc(exec.into_arena(arena)), + ), + Proj(pl, idx) => arena_ast::PlaceExprKind::Proj(arena.alloc(pl.into_arena(arena)), idx), + FieldProj(pl, ident) => arena_ast::PlaceExprKind::FieldProj( + arena.alloc(pl.into_arena(arena)), + arena.alloc(ident.into_arena(arena)), + ), + Deref(pl) => 
arena_ast::PlaceExprKind::Deref(arena.alloc(pl.into_arena(arena))), + Idx(pl, nat) => arena_ast::PlaceExprKind::Idx( + arena.alloc(pl.into_arena(arena)), + arena.alloc(nat.into_arena(arena)), + ), + Ident(ident) => arena_ast::PlaceExprKind::Ident(ident.into_arena(arena)), + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub enum PlExprPathElem { View(View), @@ -844,6 +1360,17 @@ impl ExecExpr { } } + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ExecExpr<'a> { + arena_ast::ExecExpr { + exec: arena.alloc(self.exec.into_arena(arena)), + ty: self.ty.map(|t| { + let tmp: &mut arena_ast::ExecTy<'a> = arena.alloc(t.into_arena(arena)); + &*tmp + }), + span: self.span, + } + } + // TODO how does this relate to is_prefix_of. Refactor. pub fn is_sub_exec_of(&self, exec: &ExecExpr) -> bool { if self.exec.path.len() > exec.exec.path.len() { @@ -936,6 +1463,15 @@ pub enum LeftOrRight { Right, } +impl LeftOrRight { + pub fn into_arena(self) -> crate::arena_ast::LeftOrRight { + match self { + LeftOrRight::Left => crate::arena_ast::LeftOrRight::Left, + LeftOrRight::Right => crate::arena_ast::LeftOrRight::Right, + } + } +} + impl fmt::Display for LeftOrRight { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -960,6 +1496,14 @@ impl TakeRange { left_or_right: proj, } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::TakeRange<'a> { + arena_ast::TakeRange { + split_dim: self.split_dim.into_arena(), + pos: self.pos.into_arena(arena), + left_or_right: self.left_or_right.into_arena(), + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -998,6 +1542,18 @@ impl ExecExprKind { } None } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ExecExprKind<'a> { + let path = bumpalo::collections::Vec::from_iter_in( + self.path.into_iter().map(|elem| elem.into_arena(arena)), + arena, + ); + + arena_ast::ExecExprKind { + base: self.base.into_arena(arena), + path, + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ 
-1007,6 +1563,20 @@ pub enum BaseExec { GpuGrid(Dim, Dim), } +impl BaseExec { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::BaseExec<'a> { + use BaseExec::*; + match self { + Ident(ident) => arena_ast::BaseExec::Ident(ident.into_arena(arena)), + CpuThread => arena_ast::BaseExec::CpuThread, + GpuGrid(d1, d2) => arena_ast::BaseExec::GpuGrid( + arena.alloc(d1.into_arena(arena)), + arena.alloc(d2.into_arena(arena)), + ), + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub enum ExecPathElem { TakeRange(Box), @@ -1015,6 +1585,22 @@ pub enum ExecPathElem { ToThreads(DimCompo), } +impl ExecPathElem { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ExecPathElem<'a> { + use ExecPathElem::*; + + match self { + TakeRange(take_range) => { + let arena_take_range = arena.alloc(take_range.clone().into_arena(arena)); + arena_ast::ExecPathElem::TakeRange(arena_take_range) + } + ForAll(dim_compo) => arena_ast::ExecPathElem::ForAll(dim_compo.into_arena()), + ToWarps => arena_ast::ExecPathElem::ToWarps, + ToThreads(dim_compo) => arena_ast::ExecPathElem::ToThreads(dim_compo.into_arena()), + } + } +} + // ExecTy // fn size(DimCompo) -> usize // fn take_range(DimCompo, Nat) -> ExecTy @@ -1034,6 +1620,13 @@ impl ExecTy { span: None, } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ExecTy<'a> { + arena_ast::ExecTy { + ty: self.ty.into_arena(arena), + span: self.span, + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -1050,6 +1643,32 @@ pub enum ExecTyKind { Any, } +impl ExecTyKind { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::ExecTyKind<'a> { + use ExecTyKind::*; + + match self { + CpuThread => arena_ast::ExecTyKind::CpuThread, + GpuThread => arena_ast::ExecTyKind::GpuThread, + GpuWarp => arena_ast::ExecTyKind::GpuWarp, + GpuBlock(dim) => arena_ast::ExecTyKind::GpuBlock(dim.into_arena(arena)), + GpuGrid(d1, d2) => { + arena_ast::ExecTyKind::GpuGrid(d1.into_arena(arena), d2.into_arena(arena)) + } + 
GpuToThreads(dim, inner) => arena_ast::ExecTyKind::GpuToThreads( + dim.into_arena(arena), + arena.alloc(inner.clone().into_arena(arena)), + ), + GpuThreadGrp(dim) => arena_ast::ExecTyKind::GpuThreadGrp(dim.into_arena(arena)), + GpuWarpGrp(nat) => arena_ast::ExecTyKind::GpuWarpGrp(nat.into_arena(arena)), + GpuBlockGrp(d1, d2) => { + arena_ast::ExecTyKind::GpuBlockGrp(d1.into_arena(arena), d2.into_arena(arena)) + } + Any => arena_ast::ExecTyKind::Any, + } + } +} + #[span_derive(PartialEq, Eq, Hash)] #[derive(Debug, Clone)] pub struct Ty { @@ -1058,6 +1677,15 @@ pub struct Ty { pub span: Option, } +impl Ty { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::Ty<'a> { + arena_ast::Ty { + ty: self.ty.into_arena(arena), + span: self.span, + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct ParamSig { pub exec_expr: ExecExpr, @@ -1068,6 +1696,14 @@ impl ParamSig { pub fn new(exec_expr: ExecExpr, ty: Ty) -> Self { ParamSig { exec_expr, ty } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::ParamSig<'a> { + let ty = self.ty.into_arena(arena); + arena_ast::ParamSig { + exec_expr: self.exec_expr.into_arena(arena), + ty: arena.alloc(ty), + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -1098,6 +1734,37 @@ impl FnTy { nat_constrs, } } + + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::FnTy<'a> { + let generics = self + .generics + .into_iter() + .map(|g| g.into_arena(arena)) + .collect_in(arena); + + let generic_exec = self.generic_exec.map(|g| g.into_arena(arena)); + + let param_sigs = self + .param_sigs + .into_iter() + .map(|p| p.into_arena(arena)) + .collect_in(arena); + + let nat_constrs = self + .nat_constrs + .into_iter() + .map(|c| c.into_arena(arena)) + .collect_in(arena); + + arena_ast::FnTy { + generics, + generic_exec, + param_sigs, + exec: self.exec.into_arena(arena), + ret_ty: arena.alloc(self.ret_ty.into_arena(arena)), + nat_constrs, + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -1109,6 
+1776,32 @@ pub enum NatConstr { Or(Box, Box), } +impl NatConstr { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::NatConstr<'a> { + use NatConstr::*; + + match self { + True => arena_ast::NatConstr::True, + Eq(lhs, rhs) => arena_ast::NatConstr::Eq( + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + Lt(lhs, rhs) => arena_ast::NatConstr::Lt( + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + And(lhs, rhs) => arena_ast::NatConstr::And( + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + Or(lhs, rhs) => arena_ast::NatConstr::Or( + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub enum TyKind { Data(Box), @@ -1116,12 +1809,35 @@ pub enum TyKind { FnTy(Box), } +impl TyKind { + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::TyKind<'a> { + match self { + TyKind::Data(dty) => { + let boxed = bumpalo::boxed::Box::new_in(dty.into_arena(arena), arena); + arena_ast::TyKind::Data(arena.alloc(boxed)) + } + TyKind::FnTy(fn_ty) => { + let boxed = bumpalo::boxed::Box::new_in(fn_ty.into_arena(arena), arena); + arena_ast::TyKind::FnTy(arena.alloc(boxed)) + } + } + } +} + // TODO remove #[derive(PartialEq, Eq, Hash, Debug, Clone, Copy)] pub enum Constraint { Copyable, } +impl Constraint { + pub fn into_arena(&self) -> arena_ast::Constraint { + match self { + Constraint::Copyable => arena_ast::Constraint::Copyable, + } + } +} + impl Ty { pub fn new(ty: TyKind) -> Self { Ty { ty, span: None } @@ -1213,6 +1929,33 @@ impl Dim { _ => Ok(false), } } + + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::Dim<'a> { + use Dim::*; + + match self { + XYZ(boxed) => arena_ast::Dim::XYZ(arena.alloc(arena_ast::Dim3d( + boxed.0.into_arena(arena), + boxed.1.into_arena(arena), + boxed.2.into_arena(arena), + ))), + XY(boxed) => arena_ast::Dim::XY(arena.alloc(arena_ast::Dim2d( + 
boxed.0.into_arena(arena), + boxed.1.into_arena(arena), + ))), + XZ(boxed) => arena_ast::Dim::XZ(arena.alloc(arena_ast::Dim2d( + boxed.0.into_arena(arena), + boxed.1.into_arena(arena), + ))), + YZ(boxed) => arena_ast::Dim::YZ(arena.alloc(arena_ast::Dim2d( + boxed.0.into_arena(arena), + boxed.1.into_arena(arena), + ))), + X(boxed) => arena_ast::Dim::X(arena.alloc(arena_ast::Dim1d(boxed.0.into_arena(arena)))), + Y(boxed) => arena_ast::Dim::Y(arena.alloc(arena_ast::Dim1d(boxed.0.into_arena(arena)))), + Z(boxed) => arena_ast::Dim::Z(arena.alloc(arena_ast::Dim1d(boxed.0.into_arena(arena)))), + } + } } #[derive(PartialEq, Eq, PartialOrd, Hash, Debug, Copy, Clone)] @@ -1222,6 +1965,16 @@ pub enum DimCompo { Z, } +impl DimCompo { + pub fn into_arena(self) -> crate::arena_ast::DimCompo { + match self { + DimCompo::X => crate::arena_ast::DimCompo::X, + DimCompo::Y => crate::arena_ast::DimCompo::Y, + DimCompo::Z => crate::arena_ast::DimCompo::Z, + } + } +} + #[span_derive(PartialEq, Eq, Hash)] #[derive(Debug, Clone)] pub struct DataTy { @@ -1241,6 +1994,18 @@ impl DataTy { } } + pub fn into_arena<'a>(self, arena: &'a Bump) -> arena_ast::DataTy<'a> { + arena_ast::DataTy { + dty: self.dty.into_arena(arena), + constraints: self + .constraints + .into_iter() + .map(|c| c.into_arena()) + .collect_in(arena), + span: self.span, + } + } + pub fn with_constr(dty: DataTyKind, constraints: Vec) -> Self { DataTy { dty, @@ -1410,6 +2175,15 @@ impl RefDty { dty: Box::new(dty), } } + + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::RefDty<'a> { + arena_ast::RefDty { + rgn: self.rgn.into_arena(arena), + own: self.own.into_arena(), + mem: self.mem.into_arena(arena), + dty: arena.alloc(self.dty.clone().into_arena(arena)), + } + } } #[derive(PartialEq, Eq, Hash, Debug, Clone)] @@ -1431,6 +2205,56 @@ pub enum DataTyKind { Dead(Box), } +impl DataTyKind { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::DataTyKind<'a> { + use DataTyKind::*; + + match self { + Ident(ident) 
=> arena_ast::DataTyKind::Ident(ident.clone().into_arena(arena)), + Scalar(sty) => arena_ast::DataTyKind::Scalar(sty.into_arena()), + Atomic(aty) => arena_ast::DataTyKind::Atomic(aty.into_arena()), + + Array(dty, nat) => arena_ast::DataTyKind::Array( + arena.alloc(dty.clone().into_arena(arena)), + nat.into_arena(arena), + ), + + ArrayShape(dty, nat) => arena_ast::DataTyKind::ArrayShape( + arena.alloc(dty.clone().into_arena(arena)), + nat.into_arena(arena), + ), + + Tuple(elem_dtys) => { + let arena_vec = bumpalo::collections::Vec::from_iter_in( + elem_dtys.iter().map(|dty| dty.clone().into_arena(arena)), + arena, + ); + arena_ast::DataTyKind::Tuple(arena_vec) + } + + Struct(decl) => { + arena_ast::DataTyKind::Struct(arena.alloc(decl.clone().into_arena(arena))) + } + + At(dty, mem) => arena_ast::DataTyKind::At( + arena.alloc(dty.clone().into_arena(arena)), + mem.into_arena(arena), + ), + + Ref(refdty) => { + let refdty_in = arena.alloc(refdty.into_arena(arena)); + arena_ast::DataTyKind::Ref(refdty_in) + } + + RawPtr(dty) => { + arena_ast::DataTyKind::RawPtr(arena.alloc(dty.clone().into_arena(arena))) + } + + Dead(dty) => arena_ast::DataTyKind::Dead(arena.alloc(dty.clone().into_arena(arena))), + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)] pub enum ScalarTy { Unit, @@ -1445,18 +2269,55 @@ pub enum ScalarTy { Gpu, } +impl ScalarTy { + pub fn into_arena(self) -> arena_ast::ScalarTy { + match self { + ScalarTy::Unit => arena_ast::ScalarTy::Unit, + ScalarTy::U8 => arena_ast::ScalarTy::U8, + ScalarTy::U32 => arena_ast::ScalarTy::U32, + ScalarTy::U64 => arena_ast::ScalarTy::U64, + ScalarTy::I32 => arena_ast::ScalarTy::I32, + ScalarTy::I64 => arena_ast::ScalarTy::I64, + ScalarTy::F32 => arena_ast::ScalarTy::F32, + ScalarTy::F64 => arena_ast::ScalarTy::F64, + ScalarTy::Bool => arena_ast::ScalarTy::Bool, + ScalarTy::Gpu => arena_ast::ScalarTy::Gpu, + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Copy, Clone)] pub enum AtomicTy { AtomicU32, AtomicI32, } 
+impl AtomicTy { + pub fn into_arena(self) -> arena_ast::AtomicTy { + match self { + AtomicTy::AtomicU32 => arena_ast::AtomicTy::AtomicU32, + AtomicTy::AtomicI32 => arena_ast::AtomicTy::AtomicI32, + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub enum Provenance { Value(String), Ident(Ident), } +impl Provenance { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::Provenance<'a> { + match self { + Provenance::Value(s) => arena_ast::Provenance::Value(arena.alloc_str(s)), + Provenance::Ident(ident) => { + arena_ast::Provenance::Ident(ident.clone().into_arena(arena)) + } + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub enum Memory { CpuMem, @@ -1466,12 +2327,33 @@ pub enum Memory { Ident(Ident), } +impl Memory { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::Memory<'a> { + match self { + Memory::CpuMem => arena_ast::Memory::CpuMem, + Memory::GpuGlobal => arena_ast::Memory::GpuGlobal, + Memory::GpuShared => arena_ast::Memory::GpuShared, + Memory::GpuLocal => arena_ast::Memory::GpuLocal, + Memory::Ident(ident) => arena_ast::Memory::Ident(ident.clone().into_arena(arena)), + } + } +} + #[derive(PartialEq, Eq, Debug, Clone)] pub struct PrvRel { pub longer: Ident, pub shorter: Ident, } +impl PrvRel { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::PrvRel<'a> { + arena_ast::PrvRel { + longer: self.longer.clone().into_arena(arena), + shorter: self.shorter.clone().into_arena(arena), + } + } +} + #[derive(PartialEq, Eq, Hash, Debug, Clone)] pub struct IdentKinded { pub ident: Ident, @@ -1485,6 +2367,13 @@ impl IdentKinded { kind, } } + + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::IdentKinded<'a> { + arena_ast::IdentKinded { + ident: self.ident.clone().into_arena(arena), + kind: self.kind.into_arena(arena), + } + } } #[derive(PartialEq, Eq, Debug, Clone)] @@ -1513,6 +2402,23 @@ impl NatRange { }; Ok(range_iter) } + + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::NatRange<'a> { + use 
NatRange::*; + + match self { + Simple { lower, upper } => arena_ast::NatRange::Simple { + lower: lower.into_arena(arena), + upper: upper.into_arena(arena), + }, + Halved { upper } => arena_ast::NatRange::Halved { + upper: upper.into_arena(arena), + }, + Doubled { upper } => arena_ast::NatRange::Doubled { + upper: upper.into_arena(arena), + }, + } + } } pub struct NatRangeIter { @@ -1620,12 +2526,59 @@ impl NatCtx { #[derive(Debug)] pub struct NatEvalError { - unevaluable: Nat, + pub unevaluable: Nat, +} + +impl NatEvalError { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::NatEvalError<'a> { + arena_ast::NatEvalError { + unevaluable: self.unevaluable.clone().into_arena(arena), + } + } } pub type NatEvalResult = Result; +fn convert_result<'a>( + result: NatEvalResult, + arena: &'a Bump, +) -> arena_ast::NatEvalResult<'a, usize> { + result.map_err(|e| e.into_arena(arena)) +} + impl Nat { + pub fn into_arena<'a>(&self, arena: &'a Bump) -> arena_ast::Nat<'a> { + use Nat::*; + + match self { + Ident(ident) => arena_ast::Nat::Ident(ident.clone().into_arena(arena)), + + Lit(n) => arena_ast::Nat::Lit(*n), + + ThreadIdx(dc) => arena_ast::Nat::ThreadIdx((*dc).into_arena()), + BlockIdx(dc) => arena_ast::Nat::BlockIdx((*dc).into_arena()), + BlockDim(dc) => arena_ast::Nat::BlockDim((*dc).into_arena()), + + WarpGrpIdx => arena_ast::Nat::WarpGrpIdx, + WarpIdx => arena_ast::Nat::WarpIdx, + LaneIdx => arena_ast::Nat::LaneIdx, + GridIdx => arena_ast::Nat::GridIdx, + + BinOp(op, lhs, rhs) => arena_ast::Nat::BinOp( + op.clone().into_arena(), + arena.alloc(lhs.into_arena(arena)), + arena.alloc(rhs.into_arena(arena)), + ), + + App(ident, args) => { + let arena_args: BumpVec<'a, arena_ast::Nat<'a>> = + BumpVec::from_iter_in(args.iter().map(|n| n.into_arena(arena)), arena); + + arena_ast::Nat::App(ident.clone().into_arena(arena), arena_args) + } + } + } + pub fn eval(&self, nat_ctx: &NatCtx) -> NatEvalResult { match self { Nat::GridIdx @@ -1668,6 +2621,18 @@ pub enum 
BinOpNat { Mod, } +impl BinOpNat { + pub fn into_arena(self) -> crate::arena_ast::BinOpNat { + match self { + BinOpNat::Add => crate::arena_ast::BinOpNat::Add, + BinOpNat::Sub => crate::arena_ast::BinOpNat::Sub, + BinOpNat::Mul => crate::arena_ast::BinOpNat::Mul, + BinOpNat::Div => crate::arena_ast::BinOpNat::Div, + BinOpNat::Mod => crate::arena_ast::BinOpNat::Mod, + } + } +} + // When changing the AST, the types can quickly grow and lead to stack overflows in the different // compiler stages. // diff --git a/src/codegen/cu_ast.rs b/src/codegen/cu_ast.rs index 0e4320a0..2e0a4ada 100644 --- a/src/codegen/cu_ast.rs +++ b/src/codegen/cu_ast.rs @@ -1,27 +1,27 @@ -use crate::ast::Nat; +use crate::arena_ast::Nat; pub(super) enum Item<'a> { Include(String), - FunDecl(&'a FnSig), - FnDef(Box), + FunDecl(&'a FnSig<'a>), + FnDef(Box>), MultiLineComment(String), } #[derive(Clone)] -pub(super) struct FnSig { +pub(super) struct FnSig<'a> { pub(super) name: String, - pub(super) templ_params: Vec, - pub(super) params: Vec, - pub(super) ret_ty: Ty, + pub(super) templ_params: Vec>, + pub(super) params: Vec>, + pub(super) ret_ty: Ty<'a>, pub(super) exec_kind: ExecKind, } -impl FnSig { +impl<'a> FnSig<'a> { pub(super) fn new( name: String, - templ_params: Vec, - params: Vec, - ret_ty: Ty, + templ_params: Vec>, + params: Vec>, + ret_ty: Ty<'a>, exec_kind: ExecKind, ) -> Self { FnSig { @@ -42,135 +42,135 @@ pub(super) enum ExecKind { } #[derive(Clone)] -pub(super) struct FnDef { - pub(super) fn_sig: FnSig, - pub(super) body: Stmt, +pub(super) struct FnDef<'a> { + pub(super) fn_sig: FnSig<'a>, + pub(super) body: Stmt<'a>, } -impl FnDef { - pub(super) fn new(fn_sig: FnSig, body: Stmt) -> Self { +impl<'a> FnDef<'a> { + pub(super) fn new(fn_sig: FnSig<'a>, body: Stmt<'a>) -> Self { FnDef { fn_sig, body } } } #[derive(Clone, Debug)] -pub(super) struct ParamDecl { +pub(super) struct ParamDecl<'a> { pub(super) name: String, - pub(super) ty: Ty, + pub(super) ty: Ty<'a>, } #[derive(Clone, 
Debug)] -pub(super) enum Stmt { +pub(super) enum Stmt<'a> { Skip, VarDecl { name: String, - ty: Ty, + ty: Ty<'a>, addr_space: Option, - expr: Option, + expr: Option>, is_extern: bool, }, - Block(Box), - Seq(Vec), - Expr(Expr), + Block(Box>), + Seq(Vec>), + Expr(Expr<'a>), If { - cond: Expr, - body: Box, + cond: Expr<'a>, + body: Box>, }, IfElse { - cond: Expr, - true_body: Box, - false_body: Box, + cond: Expr<'a>, + true_body: Box>, + false_body: Box>, }, While { - cond: Expr, - stmt: Box, + cond: Expr<'a>, + stmt: Box>, }, ForLoop { - init: Box, - cond: Expr, - iter: Expr, - stmt: Box, + init: Box>, + cond: Expr<'a>, + iter: Expr<'a>, + stmt: Box>, }, - Return(Option), - ExecKernel(Box), + Return(Option>), + ExecKernel(Box>), } #[derive(Clone, Debug)] -pub(super) struct ExecKernel { +pub(super) struct ExecKernel<'a> { pub fun_name: String, - pub template_args: Vec, - pub grid_dim: Box, - pub block_dim: Box, - pub shared_mem_bytes: Box, - pub args: Vec, + pub template_args: Vec>, + pub grid_dim: Box>, + pub block_dim: Box>, + pub shared_mem_bytes: Box>, + pub args: Vec>, } #[derive(Clone, Debug)] -pub(super) enum Expr { +pub(super) enum Expr<'a> { Empty, Ident(String), Lit(Lit), Assign { - lhs: Box, - rhs: Box, + lhs: Box>, + rhs: Box>, }, Lambda { - captures: Vec, - params: Vec, - body: Box, - ret_ty: Ty, + captures: Vec>, + params: Vec>, + body: Box>, + ret_ty: Ty<'a>, is_dev_fun: bool, }, - FnCall(FnCall), + FnCall(FnCall<'a>), UnOp { op: UnOp, - arg: Box, + arg: Box>, }, BinOp { op: BinOp, - lhs: Box, - rhs: Box, + lhs: Box>, + rhs: Box>, }, Cast { - expr: Box, - ty: Ty, + expr: Box>, + ty: Ty<'a>, }, ArraySubscript { - array: Box, - index: Nat, + array: Box>, + index: Nat<'a>, }, Proj { - tuple: Box, + tuple: Box>, n: usize, }, FieldProj { - struct_expr: Box, + struct_expr: Box>, field_name: String, }, InitializerList { - elems: Vec, + elems: Vec>, }, AtomicRef { - expr: Box, - base_ty: Ty, + expr: Box>, + base_ty: Ty<'a>, }, - Ref(Box), - Deref(Box), - 
Tuple(Vec), + Ref(Box>), + Deref(Box>), + Tuple(Vec>), // The current plan for Nats is to simply print them with C syntax. // Instead generate a C/Cuda expression? - Nat(Nat), + Nat(Nat<'a>), } #[derive(Clone, Debug)] -pub(super) struct FnCall { - pub fun: Box, - pub template_args: Vec, - pub args: Vec, +pub(super) struct FnCall<'a> { + pub fun: Box>, + pub template_args: Vec>, + pub args: Vec>, } -impl FnCall { - pub fn new(fun: Expr, template_args: Vec, args: Vec) -> Self { +impl<'a> FnCall<'a> { + pub fn new(fun: Expr<'a>, template_args: Vec>, args: Vec>) -> Self { FnCall { fun: Box::new(fun), template_args, @@ -218,15 +218,15 @@ pub(super) enum BinOp { } #[derive(Clone)] -pub(super) enum TemplParam { - Value { param_name: String, ty: Ty }, +pub(super) enum TemplParam<'a> { + Value { param_name: String, ty: Ty<'a> }, TyName { name: String }, } #[derive(Clone, Debug)] -pub(super) enum TemplateArg { - Expr(Expr), - Ty(Ty), +pub(super) enum TemplateArg<'a> { + Expr(Expr<'a>), + Ty(Ty<'a>), } #[derive(Clone, Debug)] @@ -237,21 +237,21 @@ pub(super) enum GpuAddrSpace { } #[derive(Clone, Debug)] -pub(super) enum Ty { +pub(super) enum Ty<'a> { Scalar(ScalarTy), - Tuple(Vec), - Array(Box, Nat), - CArray(Box, Option), - Buffer(Box, BufferKind), + Tuple(Vec>), + Array(Box>, Nat<'a>), + CArray(Box>, Option>), + Buffer(Box>, BufferKind), // for now assume every pointer to be __restrict__ qualified // http://www.open-std.org/JTC1/SC22/WG14/www/docs/n1256.pdf#page=122&zoom=auto,-205,535 - Ptr(Box), + Ptr(Box>), // The pointer itself is mutable, but the underlying data is not. - PtrConst(Box), + PtrConst(Box>), // const in a parameter declaration changes the parameter type in a definition but not // "necessarily" the function signature ... 
https://abseil.io/tips/109 // Top-level const - Const(Box), + Const(Box>), // Template parameter identifer Ident(String), } diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 71931425..4b4bf4d7 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -1,9 +1,9 @@ mod cu_ast; mod printer; -use crate::ast as desc; -use crate::ast::visit::Visit; -use crate::ast::visit_mut::VisitMut; +use crate::arena_ast as desc; +use crate::arena_ast::visit::Visit; +use crate::arena_ast::visit_mut::VisitMut; use crate::ty_check; use cu_ast as cu; use std::collections::HashMap; @@ -14,7 +14,7 @@ pub(crate) static WARP_IDENT: &str = "$warp"; // Precondition. all function definitions are successfully typechecked and // therefore every subexpression stores a type -pub fn gen(comp_unit: &desc::CompilUnit, idx_checks: bool) -> String { +pub fn gen<'a>(comp_unit: &'a desc::CompilUnit<'a>, idx_checks: bool) -> String { let mut initial_fns_to_generate = collect_initial_fns_to_generate(comp_unit); let mut codegen_ctx = CodegenCtx::new( // CpuThread is only a dummy and will be set according to the generated function. 
@@ -74,7 +74,9 @@ pub fn gen(comp_unit: &desc::CompilUnit, idx_checks: bool) -> String { printer::print(&cu_program) } -fn collect_initial_fns_to_generate(comp_unit: &desc::CompilUnit) -> Vec { +fn collect_initial_fns_to_generate<'a>( + comp_unit: &'a desc::CompilUnit<'a>, +) -> Vec> { comp_unit .items .iter() @@ -142,11 +144,11 @@ fn collect_initial_fns_to_generate(comp_unit: &desc::CompilUnit) -> Vec cu::Stmt, +fn mv_shrd_mem_params_into_decls<'a>( + mut f: cu::FnDef<'a>, + unnamed_shrd_mem_decls: &dyn Fn(&[String]) -> cu::Stmt<'a>, num_shared_mem_decls: usize, -) -> cu::FnDef { +) -> cu::FnDef<'a> { if let cu::Stmt::Block(stmt) = f.body { let shrd_mem_params = f .fn_sig @@ -179,16 +181,16 @@ fn collect_fn_decls<'a>(items: &'a [cu::Item<'a>]) -> Vec> { } struct CodegenCtx<'a> { - view_ctx: ViewCtx, - inst_fn_ctx: HashMap, - exec_mapping: ExecMapping, - exec: desc::ExecExpr, - comp_unit: &'a [desc::Item], + view_ctx: ViewCtx<'a>, + inst_fn_ctx: HashMap>, + exec_mapping: ExecMapping<'a>, + exec: desc::ExecExpr<'a>, + comp_unit: &'a [desc::Item<'a>], kernel_infos: Vec, } impl<'a> CodegenCtx<'a> { - fn new(exec: desc::ExecExpr, comp_unit: &'a [desc::Item]) -> Self { + fn new(exec: desc::ExecExpr<'a>, comp_unit: &'a [desc::Item<'a>]) -> Self { CodegenCtx { view_ctx: ViewCtx::new(), inst_fn_ctx: HashMap::new(), @@ -216,8 +218,8 @@ struct KernelInfo { num_shrd_mem_decls: usize, } -type ViewCtx = ScopeCtx; -type ExecMapping = ScopeCtx; +type ViewCtx<'a> = ScopeCtx>; +type ExecMapping<'a> = ScopeCtx>; #[derive(Default, Clone, Debug)] struct ScopeCtx { @@ -269,7 +271,10 @@ impl ScopeCtx { } } -fn gen_fun_def(gl_fun: &desc::FunDef, codegen_ctx: &mut CodegenCtx) -> cu::FnDef { +fn gen_fun_def<'a>( + gl_fun: &'a desc::FunDef<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::FnDef<'a> { let desc::FunDef { ident: name, generic_params: ty_idents, @@ -312,7 +317,11 @@ fn gen_fun_def(gl_fun: &desc::FunDef, codegen_ctx: &mut CodegenCtx) -> cu::FnDef } // Generate CUDA code for 
Descend syntax that allows sequencing. -fn gen_stmt(expr: &desc::Expr, return_value: bool, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_stmt<'a>( + expr: &'a desc::Expr<'a>, + return_value: bool, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { use desc::ExprKind::*; match &expr.expr { Let(pattern, _, e) => { @@ -498,7 +507,11 @@ fn gen_stmt(expr: &desc::Expr, return_value: bool, codegen_ctx: &mut CodegenCtx) } } -fn gen_let(pattern: &desc::Pattern, e: &desc::Expr, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_let<'a>( + pattern: &'a desc::Pattern<'a>, + e: &'a desc::Expr<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { match pattern { desc::Pattern::Tuple(tuple_elems) => { let tuple_ident = desc::Ident::new(&desc::utils::fresh_name("tuple")); @@ -550,12 +563,12 @@ fn gen_let(pattern: &desc::Pattern, e: &desc::Expr, codegen_ctx: &mut CodegenCtx } } -fn gen_decl_init( - ident: &desc::Ident, +fn gen_decl_init<'a>( + ident: &'a desc::Ident<'a>, mutbl: desc::Mutability, - e: &desc::Expr, - codegen_ctx: &mut CodegenCtx, -) -> cu::Stmt { + e: &'a desc::Expr<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { //let gened_ty = gen_ty(&e.ty.as_ref().unwrap().ty, mutbl); let (init_expr, cu_ty) = if let desc::ExprKind::Ref(_, _, pl_expr) = &e.expr { match &pl_expr.ty.as_ref().unwrap().dty().dty { @@ -618,12 +631,12 @@ fn gen_decl_init( } } -fn gen_if_else( - cond: cu_ast::Expr, - e_tt: &desc::Expr, - e_ff: &desc::Expr, - codegen_ctx: &mut CodegenCtx, -) -> cu::Stmt { +fn gen_if_else<'a>( + cond: cu_ast::Expr<'a>, + e_tt: &'a desc::Expr<'a>, + e_ff: &'a desc::Expr<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { cu::Stmt::IfElse { cond: cond, true_body: Box::new(gen_stmt(e_tt, false, codegen_ctx)), @@ -631,21 +644,25 @@ fn gen_if_else( } } -fn gen_if(cond: cu_ast::Expr, e_tt: &desc::Expr, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_if<'a>( + cond: cu_ast::Expr<'a>, + e_tt: &'a desc::Expr<'a>, + 
codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { cu::Stmt::If { cond: cond, body: Box::new(gen_stmt(e_tt, false, codegen_ctx)), } } -fn gen_for_each( - ident: &desc::Ident, - coll_expr: &desc::Expr, - body: &desc::Block, - codegen_ctx: &mut CodegenCtx, -) -> cu::Stmt { +fn gen_for_each<'a>( + ident: &'a desc::Ident<'a>, + coll_expr: &'a desc::Expr<'a>, + body: &'a desc::Block<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { todo!(); - let i_name = crate::ast::utils::fresh_name("i__"); + let i_name = crate::arena_ast::utils::fresh_name("i__"); let i_decl = cu::Stmt::VarDecl { name: i_name.clone(), ty: cu::Ty::Scalar(cu::ScalarTy::SizeT), @@ -691,12 +708,12 @@ fn gen_for_each( // for_loop } -fn gen_for_range( - ident: &desc::Ident, - range: &desc::Expr, - body: &desc::Expr, - codegen_ctx: &mut CodegenCtx, -) -> cu::Stmt { +fn gen_for_range<'a>( + ident: &'a desc::Ident<'a>, + range: &'a desc::Expr<'a>, + body: &'a desc::Expr<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { if let desc::ExprKind::Range(l, u) = &range.expr { let lower = gen_expr(l, codegen_ctx); let upper = gen_expr(u, codegen_ctx); @@ -733,7 +750,10 @@ fn gen_for_range( } } -fn gen_app_kernel(app_kernel: &desc::AppKernel, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_app_kernel( + app_kernel: &'a desc::AppKernel<'a>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Stmt<'a> { let tmp_global_fn_call = gen_global_fn_call( &app_kernel.fun_ident, &app_kernel.gen_args, @@ -759,14 +779,16 @@ fn gen_app_kernel(app_kernel: &desc::AppKernel, codegen_ctx: &mut CodegenCtx) -> })) } -fn convert_to_fn_name(f_expr: &cu::Expr) -> String { +fn convert_to_fn_name<'a>(f_expr: &'a cu::Expr<'a>) -> String { match f_expr { cu::Expr::Ident(f_name) => f_name.clone(), _ => panic!("The expression does not refer to a function by its identifier."), } } -fn unnamed_shared_mem_decls(dtys: Vec) -> Box cu::Stmt> { +fn unnamed_shared_mem_decls<'a>( + dtys: Vec>, +) -> Box cu::Stmt<'a>> 
{ // Multiple shared memory arrays and Alignments: // Memory accesses require that the address be aligned to a multiple of the access size. // The access size of a memory instruction is the total number of bytes accessed in memory. @@ -841,7 +863,7 @@ fn unnamed_shared_mem_decls(dtys: Vec) -> Box }) } -fn size_of_dty(dty: &desc::DataTy) -> usize { +fn size_of_dty<'a>(dty: &'a desc::DataTy<'a>) -> usize { match &dty.dty { desc::DataTyKind::Scalar(desc::ScalarTy::Bool) => 1, desc::DataTyKind::Scalar(desc::ScalarTy::U32) @@ -855,7 +877,7 @@ fn size_of_dty(dty: &desc::DataTy) -> usize { } } -fn get_elem_ty_and_amount(dty: &desc::DataTy) -> (desc::Ty, desc::Nat) { +fn get_elem_ty_and_amount<'a>(dty: &'a desc::DataTy<'a>) -> (desc::Ty<'a>, desc::Nat<'a>) { let nat_1 = desc::Nat::Lit(1); match &dty.dty { desc::DataTyKind::Scalar(desc::ScalarTy::Bool) @@ -882,7 +904,7 @@ fn get_elem_ty_and_amount(dty: &desc::DataTy) -> (desc::Ty, desc::Nat) { } } -fn count_bytes(dtys: &[desc::DataTy]) -> desc::Nat { +fn count_bytes(dtys: &'a [desc::DataTy<'a>]) -> desc::Nat<'a> { let mut bytes = desc::Nat::Lit(0); for dty in dtys { let (elem_ty, amount) = get_elem_ty_and_amount(dty); @@ -896,7 +918,7 @@ fn count_bytes(dtys: &[desc::DataTy]) -> desc::Nat { bytes } -fn gen_indep(indep: &desc::Split, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_indep<'a>(indep: &'a desc::Split<'a>, codegen_ctx: &'a mut CodegenCtx<'a>) -> cu::Stmt<'a> { let outer_exec = codegen_ctx.exec.clone(); let expanded_outer_exec = expand_exec_expr(codegen_ctx, &outer_exec); codegen_ctx.push_scope(); @@ -947,7 +969,7 @@ fn gen_indep(indep: &desc::Split, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { } } -fn gen_sync_stmt(exec: &desc::ExecExpr) -> cu::Stmt { +fn gen_sync_stmt<'a>(exec: &'a desc::ExecExpr<'a>) -> cu::Stmt<'a> { let sync = cu::Stmt::Expr(cu::Expr::FnCall(cu::FnCall::new( cu::Expr::Ident("__syncthreads".to_string()), vec![], @@ -972,7 +994,7 @@ fn gen_sync_stmt(exec: &desc::ExecExpr) -> cu::Stmt { // } } 
-fn gen_sched(sched: &desc::Sched, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { +fn gen_sched<'a>(sched: &'a desc::Sched<'a>, codegen_ctx: &'a mut CodegenCtx<'a>) -> cu::Stmt<'a> { codegen_ctx.push_scope(); let expanded_sched_exec_expr = expand_exec_expr(codegen_ctx, sched.sched_exec.as_ref()); let inner_exec = desc::ExecExpr::new(expanded_sched_exec_expr.exec.clone().forall(sched.dim)); @@ -1055,7 +1077,7 @@ fn gen_sched(sched: &desc::Sched, codegen_ctx: &mut CodegenCtx) -> cu::Stmt { // } // } -fn gen_expr(expr: &desc::Expr, codegen_ctx: &mut CodegenCtx) -> cu::Expr { +fn gen_expr<'a>(expr: &'a desc::Expr<'a>, codegen_ctx: &'a mut CodegenCtx<'a>) -> cu::Expr<'a> { use desc::ExprKind::*; match &expr.expr { Hole => cu::Expr::Empty, @@ -1235,10 +1257,10 @@ fn gen_expr(expr: &desc::Expr, codegen_ctx: &mut CodegenCtx) -> cu::Expr { } fn gen_lambda_call( - fun: &desc::Expr, - args: &[desc::Expr], - codegen_ctx: &mut CodegenCtx, -) -> cu::Expr { + fun: &'a desc::Expr<'a>, + args: &'a [desc::Expr<'a>], + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Expr<'a> { unimplemented!( "The only case for which this would have to be generated is, when a lambda is called right\ where it is created. There is no way to bind a lambda with let.\ @@ -1246,12 +1268,12 @@ fn gen_lambda_call( ) } -fn gen_global_fn_call( - fun_ident: &desc::Ident, - gen_args: &[desc::ArgKinded], - args: &[desc::Expr], - codegen_ctx: &mut CodegenCtx, -) -> cu::FnCall { +fn gen_global_fn_call<'a>( + fun_ident: &'a desc::Ident<'a>, + gen_args: &'a [desc::ArgKinded<'a>], + args: &'a [desc::Expr<'a>], + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::FnCall<'a> { // Make sure that we do not accidentally add views conflicting to fun, // because during type checking the order is: check fun first then do the arguments. 
codegen_ctx.push_scope(); @@ -1279,7 +1301,10 @@ fn gen_global_fn_call( } // TODO generate different arguments for views or inline -fn gen_fn_call_args(args: &[desc::Expr], codegen_ctx: &mut CodegenCtx) -> Vec { +fn gen_fn_call_args<'a>( + args: &'a [desc::Expr<'a>], + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> Vec> { args.iter() .map(|arg| gen_expr(arg, codegen_ctx)) // GenState::Gened(cu_expr) => cu_expr, @@ -1290,7 +1315,7 @@ fn gen_fn_call_args(args: &[desc::Expr], codegen_ctx: &mut CodegenCtx) -> Vec desc::PlaceExpr { +fn basis_ref<'a>(view_expr: &'a desc::PlaceExpr<'a>) -> desc::PlaceExpr<'a> { let mut bref = view_expr.clone(); let mut current = view_expr.clone(); while !matches!(¤t.pl_expr, desc::PlaceExprKind::Ident(_)) { @@ -1312,7 +1337,7 @@ fn basis_ref(view_expr: &desc::PlaceExpr) -> desc::PlaceExpr { bref } -fn view_exprs_in_args(args: &[desc::Expr]) -> Vec<&desc::Expr> { +fn view_exprs_in_args<'a>(args: &'a [desc::Expr<'a>]) -> Vec<&'a desc::Expr<'a>> { let (views, _): (Vec<_>, Vec<_>) = args .iter() .partition(|a| is_view_dty(a.ty.as_ref().unwrap())); @@ -1335,7 +1360,7 @@ fn view_exprs_in_args(args: &[desc::Expr]) -> Vec<&desc::Expr> { fn separate_view_params_with_args_from_rest<'a>( param_decls: &'a [desc::ParamDecl], args: &'a [desc::Expr], -) -> Vec<(&'a desc::ParamDecl, &'a desc::Expr)> { +) -> Vec<(&'a desc::ParamDecl<'a>, &'a desc::Expr<'a>)> { let (view_params_with_args, _): (Vec<_>, Vec<_>) = param_decls .iter() .zip(args.iter()) @@ -1356,7 +1381,7 @@ fn separate_view_params_with_args_from_rest<'a>( // } // } -fn stringify_exec(exec: &desc::ExecExpr) -> String { +fn stringify_exec<'a>(exec: &'a desc::ExecExpr<'a>) -> String { let mut str = String::with_capacity(10); for e in &exec.exec.path { match e { @@ -1442,19 +1467,19 @@ fn stringify_exec(exec: &desc::ExecExpr) -> String { // str // } -fn create_named_fn_call( +fn create_named_fn_call<'a>( name: String, - gen_args: Vec, - args: Vec, -) -> cu::FnCall { + gen_args: Vec>, + args: Vec>, +) 
-> cu::FnCall<'a> { create_fn_call(cu::Expr::Ident(name), gen_args, args) } -fn create_fn_call( - fun: cu::Expr, +fn create_fn_call<'a>( + fun: cu::Expr<'a>, gen_args: Vec, params: Vec, -) -> cu::FnCall { +) -> cu::FnCall<'a> { cu::FnCall { fun: Box::new(fun), template_args: gen_args, @@ -1462,12 +1487,12 @@ fn create_fn_call( } } -fn gen_bin_op_expr( +fn gen_bin_op_expr<'a>( op: &desc::BinOp, - lhs: &desc::Expr, - rhs: &desc::Expr, + lhs: &'a desc::Expr<'a>, + rhs: &'a desc::Expr<'a>, codegen_ctx: &mut CodegenCtx, -) -> cu::Expr { +) -> cu::Expr<'a> { let op = match op { desc::BinOp::Add => cu::BinOp::Add, desc::BinOp::Sub => cu::BinOp::Sub, @@ -1494,7 +1519,7 @@ fn gen_bin_op_expr( } } -fn extract_fn_ident(ident: &desc::Expr) -> desc::Ident { +fn extract_fn_ident<'a>(ident: &'a desc::Expr<'a>) -> desc::Ident<'a> { if let desc::ExprKind::PlaceExpr(pl_expr) = &ident.expr { if let desc::PlaceExprKind::Ident(ident) = &pl_expr.pl_expr { ident.clone() @@ -1506,7 +1531,7 @@ fn extract_fn_ident(ident: &desc::Expr) -> desc::Ident { } } -fn contains_shape_expr(pl_expr: &desc::PlaceExpr, shape_ctx: &ViewCtx) -> bool { +fn contains_shape_expr<'a>(pl_expr: &'a desc::PlaceExpr<'a>, shape_ctx: &ViewCtx) -> bool { let (_, pl) = pl_expr.to_pl_ctx_and_most_specif_pl(); shape_ctx.contains_key(&pl.ident.name) } @@ -1530,7 +1555,7 @@ fn contains_shape_expr(pl_expr: &desc::PlaceExpr, shape_ctx: &ViewCtx) -> bool { // ) // } -fn gen_lit(l: desc::Lit) -> cu::Expr { +fn gen_lit<'a>(l: desc::Lit) -> cu::Expr<'a> { match l { desc::Lit::Bool(b) => cu::Expr::Lit(cu::Lit::Bool(b)), desc::Lit::I32(i) => cu::Expr::Lit(cu::Lit::I32(i)), @@ -1543,15 +1568,15 @@ fn gen_lit(l: desc::Lit) -> cu::Expr { } } -enum IdxOrProj { - Idx(desc::Nat), +enum IdxOrProj<'a> { + Idx(desc::Nat<'a>), Proj(usize), } -fn flattened_elem_counts_per_dim( - dty: &desc::DataTy, - mut elem_counts: Vec, -) -> Vec { +fn flattened_elem_counts_per_dim<'a>( + dty: &'a desc::DataTy<'a>, + mut elem_counts: Vec>, +) -> Vec> { 
match &dty.dty { desc::DataTyKind::Array(d, n) | desc::DataTyKind::ArrayShape(d, n) => { for elem_count in &mut elem_counts { @@ -1568,16 +1593,16 @@ fn flattened_elem_counts_per_dim( } } -fn gen_pl_expr( - pl_expr: &desc::PlaceExpr, - path: &mut Vec, - codegen_ctx: &mut CodegenCtx, -) -> cu::Expr { - fn gen_flat_indexing( - expr: cu::Expr, - path: &[desc::Nat], - operand_dty: &desc::DataTy, - ) -> cu::Expr { +fn gen_pl_expr<'a>( + pl_expr: &'a desc::PlaceExpr<'a>, + path: &'a mut Vec>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Expr<'a> { + fn gen_flat_indexing<'a>( + expr: cu::Expr<'a>, + path: &'a [desc::Nat<'a>], + operand_dty: &'a desc::DataTy<'a>, + ) -> cu::Expr<'a> { let elem_counts = flattened_elem_counts_per_dim(operand_dty, vec![]); let mut elem_counts_iter = elem_counts.iter(); // skip outermost dimension @@ -1668,7 +1693,10 @@ fn gen_pl_expr( } } -fn inline_view_expr(pl_expr: &desc::PlaceExpr, codegen_ctx: &CodegenCtx) -> desc::PlaceExpr { +fn inline_view_expr<'a>( + pl_expr: &'a desc::PlaceExpr<'a>, + codegen_ctx: &'a CodegenCtx<'a>, +) -> desc::PlaceExpr<'a> { let (_, most_spec_pl) = pl_expr.to_pl_ctx_and_most_specif_pl(); if codegen_ctx.view_ctx.contains_key(&most_spec_pl.ident.name) { insert_into_pl_expr( @@ -1680,11 +1708,14 @@ fn inline_view_expr(pl_expr: &desc::PlaceExpr, codegen_ctx: &CodegenCtx) -> desc } } -fn insert_into_pl_expr(mut pl_expr: desc::PlaceExpr, insert: &desc::PlaceExpr) -> desc::PlaceExpr { +fn insert_into_pl_expr<'a>( + mut pl_expr: desc::PlaceExpr<'a>, + insert: &'a desc::PlaceExpr<'a>, +) -> desc::PlaceExpr<'a> { struct InsertIntoPlExpr<'a> { - insert: &'a desc::PlaceExpr, + insert: &'a desc::PlaceExpr<'a>, } - impl VisitMut for InsertIntoPlExpr<'_> { + impl<'a> VisitMut<'a> for InsertIntoPlExpr<'_> { fn visit_pl_expr(&mut self, pl_expr: &mut desc::PlaceExpr) { match &mut pl_expr.pl_expr { desc::PlaceExprKind::Deref(ple) => { @@ -1779,7 +1810,10 @@ fn transform_path_with_view(view: &desc::View, path: &mut Vec) -> boo 
true } -fn transform_path_with_group(grp_size: &desc::Nat, path: &mut Vec) -> bool { +fn transform_path_with_group<'a>( + grp_size: &'a desc::Nat<'a>, + path: &'a mut Vec>, +) -> bool { let i = path.pop(); let j = path.pop(); match (i, j) { @@ -1807,7 +1841,7 @@ fn transform_path_with_group(grp_size: &desc::Nat, path: &mut Vec) -> } } -fn transform_path_with_rev(len: &desc::Nat, path: &mut Vec) -> bool { +fn transform_path_with_rev<'a>(len: &'a desc::Nat<'a>, path: &'a mut Vec>) -> bool { let i = path.pop(); match i { Some(i) => { @@ -1826,7 +1860,7 @@ fn transform_path_with_rev(len: &desc::Nat, path: &mut Vec) -> bool { } } -fn transform_path_with_transpose(path: &mut Vec) -> bool { +fn transform_path_with_transpose<'a>(path: &'a mut Vec>) -> bool { let i = path.pop(); let j = path.pop(); match (i, j) { @@ -1839,7 +1873,10 @@ fn transform_path_with_transpose(path: &mut Vec) -> bool { } } -fn transform_path_with_join(row_size: &desc::Nat, path: &mut Vec) -> bool { +fn transform_path_with_join<'a>( + row_size: &'a desc::Nat<'a>, + path: &'a mut Vec>, +) -> bool { let i = path.pop(); match i { Some(idx) => { @@ -1859,7 +1896,10 @@ fn transform_path_with_join(row_size: &desc::Nat, path: &mut Vec) -> } } -fn transform_path_with_select_range(lower_bound: &desc::Nat, path: &mut Vec) -> bool { +fn transform_path_with_select_range<'a>( + lower_bound: &'a desc::Nat<'a>, + path: &'a mut Vec>, +) -> bool { let idx = path.pop(); match idx { Some(i) => { @@ -1875,9 +1915,9 @@ fn transform_path_with_select_range(lower_bound: &desc::Nat, path: &mut Vec, +fn transform_path_with_take<'a>( + split_pos: &'a desc::Nat<'a>, + path: &'a mut Vec>, take_side: ty_check::pre_decl::TakeSide, ) -> bool { let idx = path.pop(); @@ -1900,7 +1940,7 @@ fn transform_path_with_take( } } -fn transform_path_with_map(f: &desc::View, path: &mut Vec) -> bool { +fn transform_path_with_map<'a>(f: &'a desc::View<'a>, path: &'a mut Vec>) -> bool { let i = path.pop(); match i { Some(idx) => { @@ -1912,11 
+1952,11 @@ fn transform_path_with_map(f: &desc::View, path: &mut Vec) -> bool { } } -fn gen_indep_branch_cond( +fn gen_indep_branch_cond<'a>( dim_compo: desc::DimCompo, - pos: &desc::Nat, - exec: &desc::ExecExprKind, -) -> cu::Expr { + pos: &'a desc::Nat<'a>, + exec: &'a desc::ExecExprKind<'a>, +) -> cu::Expr<'a> { cu::Expr::BinOp { op: cu::BinOp::Lt, lhs: Box::new(cu::Expr::Nat(parall_idx( @@ -1929,7 +1969,7 @@ fn gen_indep_branch_cond( } } -fn gen_templ_params(ty_idents: &[desc::IdentKinded]) -> Vec { +fn gen_templ_params<'a>(ty_idents: &'a [desc::IdentKinded<'a>]) -> Vec> { ty_idents .iter() .filter_map(|ty_ident| { @@ -1942,7 +1982,7 @@ fn gen_templ_params(ty_idents: &[desc::IdentKinded]) -> Vec { .collect() } -fn gen_templ_param(ty_ident: &desc::IdentKinded) -> cu::TemplParam { +fn gen_templ_param<'a>(ty_ident: &'a desc::IdentKinded<'a>) -> cu::TemplParam<'a> { let name = ty_ident.ident.name.clone(); match ty_ident.kind { desc::Kind::Nat => cu::TemplParam::Value { @@ -1963,11 +2003,11 @@ fn gen_templ_param(ty_ident: &desc::IdentKinded) -> cu::TemplParam { } } -fn gen_param_decls(param_decls: &[desc::ParamDecl]) -> Vec { +fn gen_param_decls<'a>(param_decls: &'a [desc::ParamDecl<'a>]) -> Vec> { param_decls.iter().map(gen_param_decl).collect() } -fn gen_param_decl(param_decl: &desc::ParamDecl) -> cu::ParamDecl { +fn gen_param_decl<'a>(param_decl: &'a desc::ParamDecl<'a>) -> cu::ParamDecl<'a> { let desc::ParamDecl { ident, ty, @@ -1980,11 +2020,11 @@ fn gen_param_decl(param_decl: &desc::ParamDecl) -> cu::ParamDecl { } } -fn gen_args_kinded(templ_args: &[desc::ArgKinded]) -> Vec { +fn gen_args_kinded<'a>(templ_args: &'a [desc::ArgKinded<'a>]) -> Vec> { templ_args.iter().filter_map(gen_arg_kinded).collect() } -fn gen_nat_as_u64(templ_args: &[desc::ArgKinded]) -> cu::Expr { +fn gen_nat_as_u64<'a>(templ_args: &'a [desc::ArgKinded<'a>]) -> cu::Expr<'a> { let generated_arg_expr = gen_arg_kinded(&templ_args[0]); if let Some(e) = generated_arg_expr { if let 
cu::TemplateArg::Expr(expr) = e { @@ -1997,15 +2037,18 @@ fn gen_nat_as_u64(templ_args: &[desc::ArgKinded]) -> cu::Expr { } } -fn gen_to_atomic_array(args: &Vec, codegen_ctx: &mut CodegenCtx) -> cu::Expr { +fn gen_to_atomic_array<'a>( + args: &'a Vec>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Expr<'a> { gen_fn_call_args(args, codegen_ctx)[0].clone() } -fn gen_shfl_up( - args: &Vec, - kinded_args: &Vec, - codegen_ctx: &mut CodegenCtx, -) -> cu::Expr { +fn gen_shfl_up<'a>( + args: &'a Vec>, + kinded_args: &'a Vec>, + codegen_ctx: &'a mut CodegenCtx<'a>, +) -> cu::Expr<'a> { cu::Expr::FnCall(create_fn_call( cu::Expr::Ident(format!("{}.shfl_up", WARP_IDENT)), gen_args_kinded(kinded_args), @@ -2013,7 +2056,7 @@ fn gen_shfl_up( )) } -fn gen_arg_kinded(templ_arg: &desc::ArgKinded) -> Option { +fn gen_arg_kinded<'a>(templ_arg: &'a desc::ArgKinded<'a>) -> Option> { match templ_arg { desc::ArgKinded::Nat(n) => Some(cu::TemplateArg::Expr(cu::Expr::Nat(n.clone()))), desc::ArgKinded::DataTy(dty) => Some(cu::TemplateArg::Ty(gen_ty( @@ -2030,7 +2073,7 @@ fn gen_arg_kinded(templ_arg: &desc::ArgKinded) -> Option { // in cu::Ty::Const. However, the formalism uses this, because it shows the generated code // as opposed to a Cuda-AST and there, the order of the const is different // when it comes to pointers (C things). 
-fn gen_ty(ty: &desc::TyKind, mutbl: desc::Mutability) -> cu::Ty { +fn gen_ty<'a>(ty: &'a desc::TyKind<'a>, mutbl: desc::Mutability) -> cu::Ty<'a> { use desc::DataTyKind as d; use desc::TyKind::*; @@ -2139,7 +2182,7 @@ fn gen_ty(ty: &desc::TyKind, mutbl: desc::Mutability) -> cu::Ty { } } -fn base_dty(dty: &desc::DataTy) -> desc::DataTy { +fn base_dty<'a>(dty: &desc::DataTy) -> desc::DataTy<'a> { if let desc::DataTyKind::Array(elem_dty, _) = &dty.dty { base_dty(elem_dty) } else { @@ -2161,7 +2204,10 @@ fn is_dev_fun(exec_ty: &desc::ExecTy) -> bool { } } -fn expand_exec_expr(codegen_ctx: &CodegenCtx, exec_expr: &desc::ExecExpr) -> desc::ExecExpr { +fn expand_exec_expr<'a>( + codegen_ctx: &CodegenCtx, + exec_expr: &desc::ExecExpr, +) -> desc::ExecExpr<'a> { match &exec_expr.exec.base { desc::BaseExec::CpuThread | desc::BaseExec::GpuGrid(_, _) => exec_expr.clone(), desc::BaseExec::Ident(ident) => { @@ -2176,7 +2222,9 @@ fn expand_exec_expr(codegen_ctx: &CodegenCtx, exec_expr: &desc::ExecExpr) -> des } } -fn to_parall_indices(exec: &desc::ExecExpr) -> (desc::Nat, desc::Nat, desc::Nat) { +fn to_parall_indices<'a>( + exec: &'a desc::ExecExpr<'a>, +) -> (desc::Nat<'a>, desc::Nat<'a>, desc::Nat<'a>) { let mut indices = match &exec.exec.base { desc::BaseExec::GpuGrid(_, _) => { (desc::Nat::GridIdx, desc::Nat::GridIdx, desc::Nat::GridIdx) @@ -2272,12 +2320,12 @@ fn to_parall_indices(exec: &desc::ExecExpr) -> (desc::Nat, desc::Nat, desc::Nat) indices } -fn contained_par_idx(n: &desc::Nat) -> Option { - struct ContainedParIdx { - par_idx: Option, +fn contained_par_idx<'a>(n: &'a desc::Nat<'a>) -> Option> { + struct ContainedParIdx<'a> { + par_idx: Option>, } - impl Visit for ContainedParIdx { - fn visit_nat(&mut self, n: &desc::Nat) { + impl<'a> Visit<'a> for ContainedParIdx<'a> { + fn visit_nat(&mut self, n: &desc::Nat<'a>) { match n { desc::Nat::GridIdx => self.par_idx = Some(n.clone()), desc::Nat::BlockIdx(_) => self.par_idx = Some(n.clone()), @@ -2292,16 +2340,20 @@ fn 
contained_par_idx(n: &desc::Nat) -> Option { contained.par_idx } -fn set_distrib_idx(idx: &mut desc::Nat, parall_idx: desc::Nat, shift: &mut desc::Nat) { +fn set_distrib_idx<'a>( + idx: &'a mut desc::Nat<'a>, + parall_idx: desc::Nat<'a>, + shift: &'a mut desc::Nat<'a>, +) { *idx = shift_idx_by(parall_idx, shift.clone()); *shift = desc::Nat::Lit(0); } -fn shift_idx_by(idx: desc::Nat, shift: desc::Nat) -> desc::Nat { +fn shift_idx_by<'a>(idx: desc::Nat<'a>, shift: desc::Nat<'a>) -> desc::Nat<'a> { desc::Nat::BinOp(desc::BinOpNat::Sub, Box::new(idx), Box::new(shift)) } -fn parall_idx(dim: desc::DimCompo, exec: &desc::ExecExpr) -> desc::Nat { +fn parall_idx<'a>(dim: desc::DimCompo, exec: &'a desc::ExecExpr<'a>) -> desc::Nat<'a> { match dim { desc::DimCompo::X => to_parall_indices(exec).0, desc::DimCompo::Y => to_parall_indices(exec).1, @@ -2309,7 +2361,7 @@ fn parall_idx(dim: desc::DimCompo, exec: &desc::ExecExpr) -> desc::Nat { } } -fn gen_dim3(dim: &desc::Dim) -> cu::Expr { +fn gen_dim3<'a>(dim: &'a desc::Dim<'a>) -> cu::Expr<'a> { let one = desc::Nat::Lit(1); let (nx, ny, nz) = match dim { desc::Dim::X(n) => (n.0.clone(), one.clone(), one), @@ -2328,7 +2380,7 @@ fn gen_dim3(dim: &desc::Dim) -> cu::Expr { }) } -fn is_view_dty(ty: &desc::Ty) -> bool { +fn is_view_dty<'a>(ty: &'a desc::Ty<'a>) -> bool { match &ty.ty { desc::TyKind::Data(dty) => match &dty.dty { desc::DataTyKind::Ref(reff) => { diff --git a/src/codegen/printer.rs b/src/codegen/printer.rs index 8dca7f2a..5d1560bd 100644 --- a/src/codegen/printer.rs +++ b/src/codegen/printer.rs @@ -2,7 +2,7 @@ use super::cu_ast::{ BinOp, BufferKind, ExecKind, Expr, FnDef, FnSig, GpuAddrSpace, Item, Lit, ParamDecl, ScalarTy, Stmt, TemplParam, TemplateArg, Ty, UnOp, }; -use crate::ast::{BinOpNat, DimCompo, Ident, Nat}; +use crate::arena_ast::{BinOpNat, DimCompo, Ident, Nat}; use std::env; use std::fmt::Formatter; @@ -60,7 +60,7 @@ impl<'a> std::fmt::Display for Item<'a> { } } -impl std::fmt::Display for FnSig { +impl<'a> 
std::fmt::Display for FnSig<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let FnSig { name, @@ -81,7 +81,7 @@ impl std::fmt::Display for FnSig { } } -impl std::fmt::Display for FnDef { +impl<'a> std::fmt::Display for FnDef<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { let FnDef { fn_sig, body } = self; write!(f, "{}", fn_sig)?; @@ -89,7 +89,7 @@ impl std::fmt::Display for FnDef { } } -impl std::fmt::Display for Ident { +impl<'a> std::fmt::Display for Ident<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self.name) } @@ -105,7 +105,7 @@ impl std::fmt::Display for ExecKind { } } -impl std::fmt::Display for Stmt { +impl<'a> std::fmt::Display for Stmt<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use Stmt::*; match self { @@ -201,7 +201,7 @@ impl std::fmt::Display for Stmt { } } -impl std::fmt::Display for Expr { +impl<'a> std::fmt::Display for Expr<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { use Expr::*; match self { @@ -293,13 +293,13 @@ impl std::fmt::Display for Lit { } } -impl std::fmt::Display for ParamDecl { +impl<'a> std::fmt::Display for ParamDecl<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{} {}", self.ty, self.name) } } -impl std::fmt::Display for TemplateArg { +impl<'a> std::fmt::Display for TemplateArg<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { TemplateArg::Expr(expr) => write!(f, "{}", expr), @@ -308,7 +308,7 @@ impl std::fmt::Display for TemplateArg { } } -impl std::fmt::Display for TemplParam { +impl<'a> std::fmt::Display for TemplParam<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { TemplParam::Value { param_name, ty } => write!(f, "{} {}", ty, param_name), @@ -361,7 +361,7 @@ impl std::fmt::Display for GpuAddrSpace { } } -impl std::fmt::Display for Ty { +impl<'a> std::fmt::Display for Ty<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> 
std::fmt::Result { use Ty::*; match self { @@ -415,7 +415,7 @@ impl std::fmt::Display for ScalarTy { } } -impl std::fmt::Display for Nat { +impl<'a> std::fmt::Display for Nat<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { match self { Self::Ident(ident) => write!(f, "{}", ident), diff --git a/src/error.rs b/src/error.rs index 64fe3ff0..a46b5ecf 100644 --- a/src/error.rs +++ b/src/error.rs @@ -94,3 +94,91 @@ impl std::fmt::Debug for ErrorReported { write!(f, "Aborting due to previous error.") } } + +impl std::fmt::Display for ErrorReported { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Aborting due to a previous error.") + } +} + +#[derive(Debug)] +pub struct NVCCError { + message: String, +} + +impl NVCCError { + pub fn new>(message: S) -> Self { + NVCCError { + message: message.into(), + } + } + + pub fn emit(&self) -> ErrorReported { + println!("{}", self.to_string()); + ErrorReported + } + + fn to_string(&self) -> String { + let label = format!("{}", self.message); + let snippet = Snippet { + title: Some(Annotation { + id: None, + label: Some(&label), + annotation_type: AnnotationType::Error, + }), + footer: vec![], + slices: vec![], + opt: default_format(), + }; + DisplayList::from(snippet).to_string() + } +} + +impl std::fmt::Display for NVCCError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NVCC Error: {}", self.message) + } +} + +impl std::error::Error for NVCCError {} + +#[derive(Debug)] +pub struct ExecutableError { + message: String, +} + +impl ExecutableError { + pub fn new>(message: S) -> Self { + ExecutableError { + message: message.into(), + } + } + + pub fn emit(&self) -> ErrorReported { + println!("{}", self.to_string()); + ErrorReported + } + + fn to_string(&self) -> String { + let label = format!("{}", self.message); + let snippet = Snippet { + title: Some(Annotation { + id: None, + label: Some(&label), + annotation_type: AnnotationType::Error, + }), 
+ footer: vec![], + slices: vec![], + opt: default_format(), + }; + DisplayList::from(snippet).to_string() + } +} + +impl std::fmt::Display for ExecutableError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Executable Error: {}", self.message) + } +} + +impl std::error::Error for ExecutableError {} diff --git a/src/lib.rs b/src/lib.rs index 87a0c26d..62c11265 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,8 @@ extern crate core; use crate::error::ErrorReported; - +use bumpalo::Bump; +mod arena_ast; mod ast; mod codegen; pub mod error; @@ -10,7 +11,8 @@ pub mod ty_check; pub fn compile(file_path: &str) -> Result { let source = parser::SourceCode::from_file(file_path)?; - let mut compil_unit = parser::parse(&source)?; - ty_check::ty_check(&mut compil_unit)?; + let arena = Bump::new(); + let mut compil_unit = parser::parse(&arena, &source); + ty_check::ty_check(&mut compil_unit, &arena)?; Ok(codegen::gen(&compil_unit, false)) } diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 00000000..5eef6abf --- /dev/null +++ b/src/main.rs @@ -0,0 +1,281 @@ +use clap::{Args, Parser, Subcommand}; +use descend::error::NVCCError; +use descend::{compile, error::ErrorReported, error::ExecutableError, error::FileIOError}; +use env_logger::Env; +use log::LevelFilter; +use log::{debug, error, info}; +use std::env; +use std::fs; +use std::fs::write; +use std::path::PathBuf; +use std::process; +use std::process::{exit, Command}; +use std::time::{SystemTime, UNIX_EPOCH}; +use which::which; + +#[derive(Parser, Debug)] +#[command(name = "descendc", version = "1.0", about = "Descend Compiler")] +struct Cli { + /// Enable debug mode. + #[arg(short, long)] + debug: bool, + + /// Suppress warning if nvcc (CUDA Toolkit) is not installed. + #[arg(long)] + suppress_cuda_warning: bool, + + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand, Debug)] +enum Commands { + /// Emit the generated CUDA code without compiling. 
+ Emit { + #[clap(flatten)] + common: CommonArgs, + }, + /// Compile the generated CUDA code into a binary. + Build { + #[clap(flatten)] + common: CommonArgs, + #[clap(flatten)] + build_run: BuildRunArgs, + }, + /// Compile the generated CUDA code and run the resulting binary. + Run { + #[clap(flatten)] + common: CommonArgs, + #[clap(flatten)] + build_run: BuildRunArgs, + }, +} + +/// Arguments common to all subcommands. +#[derive(Args, Debug)] +struct CommonArgs { + /// Input file (.desc) + input: String, + + /// Optionally write the CUDA code to a file (if not provided, uses default naming) + #[arg(short, long)] + output: Option, +} + +/// Arguments only applicable to Build and Run. +#[derive(Args, Debug)] +struct BuildRunArgs { + /// Specify CUDA architecture (e.g., sm_75, sm_80) + #[arg(long, default_value = "sm_75")] + arch: String, + + /// Optimization level for nvcc (0-3) + #[arg(long, default_value = "3")] + optimize: u8, + + /// Additional flags to pass directly to nvcc + #[arg(long, default_value = "")] + nvcc_flags: String, + + /// Only save the generated CUDA file if explicitly requested. + #[arg(long)] + save_cuda: bool, +} + +/// RAII wrapper for a temporary file. +struct TempFile { + path: PathBuf, +} + +impl TempFile { + fn new(path: PathBuf) -> Self { + TempFile { path } + } + + /// Returns the file path as a string. + fn path_string(&self) -> String { + self.path.to_string_lossy().into_owned() + } +} + +impl Drop for TempFile { + fn drop(&mut self) { + // Attempt to delete the file; report error if it fails. 
+ if let Err(e) = fs::remove_file(&self.path) { + error!( + "Warning: failed to remove temporary file {:?}: {}", + self.path, e + ); + } else { + debug!("Temporary file {:?} deleted.", self.path); + } + } +} + +/// Checks if a command exists using the which crate +fn command_exists(cmd: &str) -> bool { + which(cmd).is_ok() +} + +fn generate_cuda(input: &str) -> Result { + compile(input) +} + +fn write_cuda_file(cuda_code: &str, filename: &str) -> Result<(), ErrorReported> { + write(filename, cuda_code).map_err(|e| FileIOError::new(filename, e).emit()) +} + +fn build_cuda( + cuda_file: &str, + executable: &str, + optimize: u8, + arch: &str, + nvcc_flags: &str, +) -> Result<(), ErrorReported> { + let mut nvcc_cmd = Command::new("nvcc"); + nvcc_cmd + .arg(cuda_file) + .arg("-o") + .arg(executable) + .arg(format!("-O{}", optimize)) + .arg("-I") + .arg("cuda-examples/") + .args(nvcc_flags.split_whitespace()); + if arch != "none" { + nvcc_cmd.arg(format!("-arch={}", arch)); + } + debug!("Running NVCC command: {:?}", nvcc_cmd); + let output = nvcc_cmd + .output() + .map_err(|_e| NVCCError::new("Failed to run nvcc command").emit())?; + if !output.status.success() { + return Err(NVCCError::new(format!( + "nvcc compilation failed:\n{}", + String::from_utf8_lossy(&output.stderr) + )) + .emit()); + } + Ok(()) +} + +fn run_executable(executable: &str) -> Result<(), ErrorReported> { + let output = Command::new(format!("./{}", executable)) + .output() + .map_err(|_e| ExecutableError::new("Failed to run the executable").emit())?; + + info!( + "Program output:\n{}", + String::from_utf8_lossy(&output.stdout) + ); + error!( + "Program errors:\n{}", + String::from_utf8_lossy(&output.stderr) + ); + Ok(()) +} + +fn handle_emit(common: CommonArgs) -> Result<(), ErrorReported> { + let cuda_code = generate_cuda(&common.input)?; + if let Some(file) = common.output { + write_cuda_file(&cuda_code, &file)?; + info!("CUDA code written to {}", file); + } else { + info!("Generated CUDA 
Code:\n{}", cuda_code); + } + Ok(()) +} + +fn handle_build_run( + common: CommonArgs, + build_run: BuildRunArgs, + run_after: bool, + suppress_cuda_warning: bool, +) -> Result<(), ErrorReported> { + if !command_exists("nvcc") { + if suppress_cuda_warning { + info!("Warning: 'nvcc' not found, but warnings are suppressed. Compilation will likely fail."); + } else { + return Err( + NVCCError::new("Error: 'nvcc' is not installed. Please install the CUDA Toolkit to compile the code.") + .emit() + ); + } + } + let cuda_code = generate_cuda(&common.input)?; + + // Determine the file name based on the --save-cuda flag. If save_cuda is false, generate a temporary file path. + let (cuda_file, _temp_guard): (String, Option) = if build_run.save_cuda { + ( + common + .output + .unwrap_or_else(|| common.input.replace(".desc", ".cu")), + None, + ) + } else { + let temp_dir = env::temp_dir(); + let timestamp = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap_or_default() + .as_millis(); + let temp_filename = format!("descendc-{}-{}.cu", process::id(), timestamp); + let temp_path = temp_dir.join(temp_filename); + let file_str = temp_path.to_string_lossy().into_owned(); + (file_str, Some(TempFile::new(temp_path))) + }; + + write_cuda_file(&cuda_code, &cuda_file)?; + debug!("CUDA code written to {}", cuda_file); + + let executable = cuda_file.replace(".cu", ""); + build_cuda( + &cuda_file, + &executable, + build_run.optimize, + &build_run.arch, + &build_run.nvcc_flags, + )?; + debug!("Compilation successful: {}", executable); + + if run_after { + run_executable(&executable)?; + } + + Ok(()) +} + +fn main() { + let cli = Cli::parse(); + + let default_log_level = if cli.debug { + LevelFilter::Debug + } else { + LevelFilter::Info + }; + + env_logger::Builder::from_env(Env::default().default_filter_or(default_log_level.to_string())) + .init(); + + if cli.debug { + info!("Debug mode enabled."); + } + + if !command_exists("clang-format") { + error!("Error: 'clang-format' is not 
installed. Please install clang-format to proceed."); + exit(1); + } + + let result = match cli.command { + Commands::Emit { common } => handle_emit(common), + Commands::Build { common, build_run } => { + handle_build_run(common, build_run, false, cli.suppress_cuda_warning) + } + Commands::Run { common, build_run } => { + handle_build_run(common, build_run, true, cli.suppress_cuda_warning) + } + }; + + if let Err(e) = result { + error!("{:#}", e); + exit(1); + } +} diff --git a/src/parser/mod.rs b/src/parser/mod.rs index b2297a3c..aa87fb2c 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -2,45 +2,83 @@ pub mod source; mod utils; +use crate::arena_ast::visit_mut as arena_visit_mut; +use crate::arena_ast::visit_mut::VisitMut as ArenaVisitMut; +use crate::arena_ast::AppKernel as ArenaAppKernel; +use crate::arena_ast::ArgKinded as ArenaArgKinded; +use crate::arena_ast::BaseExec as ArenaBaseExec; +use crate::arena_ast::Block as ArenaBlock; +use crate::arena_ast::CompilUnit as ArenaCompilUnit; +use crate::arena_ast::DataTy as ArenaDataTy; +use crate::arena_ast::DataTyKind as ArenaDataTyKind; +use crate::arena_ast::ExecExpr as ArenaExecExpr; +use crate::arena_ast::ExecExprKind as ArenaExecExprKind; +use crate::arena_ast::ExecTyKind as ArenaExecTyKind; +use crate::arena_ast::Expr as ArenaExpr; +use crate::arena_ast::ExprKind as ArenaExprKind; +use crate::arena_ast::FunDef as ArenaFunDef; +use crate::arena_ast::Ident as ArenaIdent; +use crate::arena_ast::IdentKinded as ArenaIdentKinded; +use crate::arena_ast::Item as ArenaItem; +use crate::arena_ast::Kind as ArenaKind; +use crate::arena_ast::LeftOrRight as ArenaLeftOrRight; +use crate::arena_ast::Memory as ArenaMemory; +use crate::arena_ast::Nat as ArenaNat; +use crate::arena_ast::Provenance as ArenaProvenance; +use crate::arena_ast::Sched as ArenaSched; +use crate::arena_ast::Split as ArenaSplit; +use crate::arena_ast::StructDecl as ArenaStructDecl; +use crate::arena_ast::View as ArenaView; use crate::ast::*; + 
use core::iter; use error::ParseError; use std::collections::HashMap; use crate::error::ErrorReported; +use bumpalo::{collections::Vec as BumpVec, Bump}; pub use source::*; -use crate::ast::visit_mut::VisitMut; - -pub fn parse<'a>(source: &'a SourceCode<'a>) -> Result { +pub fn parse<'a>( + arena: &'a bumpalo::Bump, + source: &'a SourceCode<'a>, +) -> Result, ErrorReported> { let parser = Parser::new(source); - let mut items = parser.parse().map_err(|err| err.emit())?; - // TODO refactor to not require unnecessary copying out of items - let struct_copies = items - .iter() - .filter_map(|i| { - if let Item::StructDecl(struct_dty) = i { - Some(struct_dty.as_ref()) - } else { - None + let heap_items = parser.parse().map_err(|err| err.emit())?; + + // 1) heap -> arena, collect &'a StructDecl while we build the vec + let mut arena_items: bumpalo::collections::Vec<'a, ArenaItem<'a>> = + bumpalo::collections::Vec::new_in(arena); + let mut struct_refs: Vec<&'a ArenaStructDecl<'a>> = Vec::new(); + + for heap_item in heap_items { + match heap_item.into_arena(arena) { + ArenaItem::StructDecl(sd_ref) => { + struct_refs.push(sd_ref); + arena_items.push(ArenaItem::StructDecl(sd_ref)); + } + other => { + arena_items.push(other); } - }) - .cloned() - .collect::>(); - for fun_def in &mut items.iter_mut().filter_map(|i| { - if let Item::FunDef(fun_def) = i { - Some(fun_def) - } else { - None } - }) { - replace_arg_kinded_idents(fun_def); - replace_exec_idents_with_specific_execs(fun_def); } - for i in &mut items.iter_mut() { - replace_struct_idents_with_specific_struct_dtys(&struct_copies, i); + + // 2) mutate fun defs (safe; we’re not holding borrows into arena_items) + for item in arena_items.iter_mut() { + if let ArenaItem::FunDef(fun_def) = item { + let mut fun_def_owned = (**fun_def).clone(); + replace_arg_kinded_idents(&mut fun_def_owned, arena); + replace_exec_idents_with_specific_execs(arena, &mut fun_def_owned); + *fun_def = arena.alloc(fun_def_owned); + } + } + + // 3) 
resolve struct idents using the collected &'a decls + for item in arena_items.iter_mut() { + replace_struct_idents_with_specific_struct_dtys(arena, &struct_refs, item); } - Ok(CompilUnit::new(items, source)) + + Ok(ArenaCompilUnit::new(arena_items, source)) } #[derive(Debug)] @@ -58,24 +96,35 @@ impl<'a> Parser<'a> { } } -fn replace_arg_kinded_idents(fun_def: &mut FunDef) { +fn replace_arg_kinded_idents<'a>(fun_def: &mut ArenaFunDef<'a>, arena: &'a Bump) { struct ReplaceArgKindedIdents { - ident_names_to_kinds: HashMap, Kind>, + ident_names_to_kinds: HashMap, ArenaKind>, } - impl ReplaceArgKindedIdents { - fn subst_in_gen_args(&self, gen_args: &mut [ArgKinded]) { + impl<'a> ReplaceArgKindedIdents { + fn subst_in_gen_args(&self, arena: &'a Bump, gen_args: &mut [ArenaArgKinded<'a>]) { for gen_arg in gen_args { - if let ArgKinded::Ident(ident) = gen_arg { + if let ArenaArgKinded::Ident(ident) = gen_arg { let to_be_kinded = ident.clone(); - match self.ident_names_to_kinds.get(&ident.name).unwrap() { - Kind::Provenance => { - *gen_arg = ArgKinded::Provenance(Provenance::Ident(to_be_kinded)) - } - Kind::Memory => *gen_arg = ArgKinded::Memory(Memory::Ident(to_be_kinded)), - Kind::Nat => *gen_arg = ArgKinded::Nat(Nat::Ident(to_be_kinded)), - Kind::DataTy => { + match self + .ident_names_to_kinds + .get::(ident.name.as_ref()) + .unwrap() + { + ArenaKind::Provenance => { *gen_arg = - ArgKinded::DataTy(DataTy::new(DataTyKind::Ident(to_be_kinded))) + ArenaArgKinded::Provenance(ArenaProvenance::Ident(to_be_kinded)) + } + ArenaKind::Memory => { + *gen_arg = ArenaArgKinded::Memory(ArenaMemory::Ident(to_be_kinded)) + } + ArenaKind::Nat => { + *gen_arg = ArenaArgKinded::Nat(ArenaNat::Ident(to_be_kinded)) + } + ArenaKind::DataTy => { + *gen_arg = ArenaArgKinded::DataTy(ArenaDataTy::new( + arena, + ArenaDataTyKind::Ident(to_be_kinded), + )) } } } @@ -83,160 +132,228 @@ fn replace_arg_kinded_idents(fun_def: &mut FunDef) { } } - impl VisitMut for ReplaceArgKindedIdents { - fn 
visit_expr(&mut self, expr: &mut Expr) { + impl<'a> ArenaVisitMut<'a> for ReplaceArgKindedIdents { + fn visit_expr(&mut self, arena: &'a Bump, expr: &mut ArenaExpr<'a>) { match &mut expr.expr { - ExprKind::Block(block) => { - self.ident_names_to_kinds.extend( - block - .prvs - .iter() - .map(|prv| (prv.clone().into_boxed_str(), Kind::Provenance)), - ); - self.visit_expr(&mut block.body) + ArenaExprKind::Block(block_ref) => { + let src = *block_ref; + + let added_keys: Vec> = src + .prvs + .iter() + .map(|p| p.clone().into_boxed_str()) + .collect(); + for k in &added_keys { + self.ident_names_to_kinds + .insert(k.clone(), ArenaKind::Provenance); + } + + let mut body_owned = (*src.body).clone(); + self.visit_expr(arena, &mut body_owned); + let body_ref = arena.alloc(body_owned); + + let mut new_prvs = BumpVec::new_in(arena); + new_prvs.extend(src.prvs.iter().cloned()); + + let new_block = ArenaBlock { + prvs: new_prvs, + body: body_ref, + }; + + *block_ref = arena.alloc(new_block); + + for k in added_keys { + self.ident_names_to_kinds.remove(k.as_ref()); + } } - ExprKind::DepApp(fun_ident, gen_args) => { - self.visit_ident(fun_ident); - self.subst_in_gen_args(gen_args); + ArenaExprKind::DepApp(fun_ident, gen_args) => { + self.visit_ident(arena, fun_ident); + self.subst_in_gen_args(arena, gen_args); } - ExprKind::AppKernel(app_kernel) => { - let AppKernel { - fun_ident, + ArenaExprKind::AppKernel(app_kernel_ref) => { + let src = *app_kernel_ref; + let mut grid_dim = src.grid_dim.clone(); + let mut block_dim = src.block_dim.clone(); + self.visit_dim(arena, &mut grid_dim); + self.visit_dim(arena, &mut block_dim); + let mut fun_ident_owned = (*src.fun_ident).clone(); + self.visit_ident(arena, &mut fun_ident_owned); + let fun_ident_ref = arena.alloc(fun_ident_owned); + + let mut gen_args = BumpVec::new_in(arena); + gen_args.extend(src.gen_args.iter().cloned()); + self.subst_in_gen_args(arena, gen_args.as_mut_slice()); + + let mut args = BumpVec::new_in(arena); + for mut e 
in src.args.iter().cloned() { + self.visit_expr(arena, &mut e); + args.push(e); + } + + let mut shared_mem_dtys = BumpVec::new_in(arena); + for mut d in src.shared_mem_dtys.iter().cloned() { + self.visit_dty(arena, &mut d); + shared_mem_dtys.push(d); + } + + let mut shared_mem_prvs = BumpVec::new_in(arena); + shared_mem_prvs.extend(src.shared_mem_prvs.iter().cloned()); + + let new_node = ArenaAppKernel { + grid_dim, + block_dim, + shared_mem_dtys, + shared_mem_prvs, + fun_ident: fun_ident_ref, gen_args, args, - .. - } = app_kernel.as_mut(); - self.visit_ident(fun_ident); - self.subst_in_gen_args(gen_args); - visit_mut::walk_list!(self, visit_expr, args) + }; + *app_kernel_ref = arena.alloc(new_node); } - ExprKind::App(fun_ident, gen_args, args) => { - self.visit_ident(fun_ident); - self.subst_in_gen_args(gen_args); - visit_mut::walk_list!(self, visit_expr, args) + + ArenaExprKind::App(fun_ident, gen_args, args) => { + let mut fun_ident_owned = (**fun_ident).clone(); + self.visit_ident(arena, &mut fun_ident_owned); + *fun_ident = arena.alloc(fun_ident_owned); + self.subst_in_gen_args(arena, gen_args); + arena_visit_mut::walk_list!(self, visit_expr, args, arena) } - ExprKind::ForNat(ident, _, body) => { - self.ident_names_to_kinds - .extend(iter::once((ident.name.clone(), Kind::Nat))); - self.visit_expr(body) + ArenaExprKind::ForNat(ident, _, body) => { + self.ident_names_to_kinds.extend(iter::once(( + ident.name.to_owned().into_boxed_str(), + ArenaKind::Nat, + ))); + let mut body_owned = (**body).clone(); + self.visit_expr(arena, &mut body_owned); + *body = arena.alloc(body_owned); } - _ => visit_mut::walk_expr(self, expr), + _ => arena_visit_mut::walk_expr(self, arena, expr), } } - fn visit_view(&mut self, view: &mut View) { - self.subst_in_gen_args(&mut view.gen_args); + fn visit_view(&mut self, arena: &'a Bump, view: &mut ArenaView<'a>) { + self.subst_in_gen_args(arena, &mut view.gen_args); for v in &mut view.args { - self.visit_view(v) + self.visit_view(arena, 
v) } } - fn visit_fun_def(&mut self, fun_def: &mut FunDef) { + fn visit_fun_def(&mut self, arena: &'a Bump, fun_def: &mut ArenaFunDef<'a>) { self.ident_names_to_kinds = fun_def .generic_params .iter() - .map(|IdentKinded { ident, kind }| (ident.name.clone(), *kind)) + .map(|ArenaIdentKinded { ident, kind }| { + (ident.name.to_owned().into_boxed_str(), *kind) + }) .collect(); - visit_mut::walk_fun_def(self, fun_def) + arena_visit_mut::walk_fun_def(self, arena, fun_def) } } let mut replace = ReplaceArgKindedIdents { ident_names_to_kinds: HashMap::new(), }; - replace.visit_fun_def(fun_def); + replace.visit_fun_def(arena, fun_def); } -fn replace_exec_idents_with_specific_execs(fun_def: &mut FunDef) { - struct ReplaceExecIdents { - ident_names_to_exec_expr: Vec<(Box, ExecExpr)>, +fn replace_exec_idents_with_specific_execs<'a>(arena: &'a Bump, fun_def: &mut ArenaFunDef<'a>) { + struct ReplaceExecIdents<'a> { + ident_names_to_exec_expr: Vec<(Box, ArenaExecExpr<'a>)>, } - impl VisitMut for ReplaceExecIdents { - fn visit_split(&mut self, indep: &mut Split) { + + impl<'a> ArenaVisitMut<'a> for ReplaceExecIdents<'a> { + fn visit_split(&mut self, arena: &'a Bump, indep: &mut ArenaSplit<'a>) { // manually expand to keep scopes for different branches of split - expand_exec_expr(&self.ident_names_to_exec_expr, &mut indep.split_exec); + expand_exec_expr(arena, &self.ident_names_to_exec_expr, &mut indep.split_exec); for (i, (ident, branch)) in indep .branch_idents .iter() .zip(&mut indep.branch_bodies) .enumerate() { - let branch_exec_expr = ExecExpr::new(indep.split_exec.exec.clone().split_proj( - indep.dim_compo, - indep.pos.clone(), - if i == 0 { - LeftOrRight::Left - } else if i == 1 { - LeftOrRight::Right - } else { - panic!("Unexpected projection.") - }, - )); + let branch_exec_expr = ArenaExecExpr::new( + arena, + indep.split_exec.exec.clone().split_proj( + arena, + indep.dim_compo, + indep.pos.clone(), + if i == 0 { + ArenaLeftOrRight::Left + } else if i == 1 { + 
ArenaLeftOrRight::Right + } else { + panic!("Unexpected projection.") + }, + ), + ); self.ident_names_to_exec_expr - .push((ident.name.clone(), branch_exec_expr)); - self.visit_expr(branch); + .push((ident.name.to_owned().into_boxed_str(), branch_exec_expr)); + self.visit_expr(arena, branch); self.ident_names_to_exec_expr.pop(); } } - fn visit_sched(&mut self, sched: &mut Sched) { + fn visit_sched(&mut self, arena: &'a Bump, sched: &mut ArenaSched<'a>) { // manually expand to map inner_exec_ident to expanded exec - expand_exec_expr(&self.ident_names_to_exec_expr, &mut sched.sched_exec); - let body_exec = ExecExpr::new(sched.sched_exec.exec.clone().forall(sched.dim)); + expand_exec_expr(arena, &self.ident_names_to_exec_expr, &mut sched.sched_exec); + let body_exec = + ArenaExecExpr::new(arena, sched.sched_exec.exec.clone().forall(sched.dim)); if let Some(ident) = &sched.inner_exec_ident { self.ident_names_to_exec_expr - .push((ident.name.clone(), body_exec)); + .push((ident.name.to_owned().into_boxed_str(), body_exec)); } - visit_mut::walk_sched(self, sched); + arena_visit_mut::walk_sched(self, arena, sched); self.ident_names_to_exec_expr.pop(); // self.visit_block(&mut sched.body); } - fn visit_exec_expr(&mut self, exec_expr: &mut ExecExpr) { - expand_exec_expr(&self.ident_names_to_exec_expr, exec_expr) + fn visit_exec_expr(&mut self, arena: &'a Bump, exec_expr: &mut ArenaExecExpr<'a>) { + expand_exec_expr(arena, &self.ident_names_to_exec_expr, exec_expr); } - // fn visit_expr(&mut self, expr: &mut Expr) { - // match &mut expr.expr { - // ExprKind::Sync(exec) => { - // for exec in exec { - // expand_exec_expr(&self.ident_names_to_exec_expr, exec); - // } - // } - // _ => visit_mut::walk_expr(self, expr), - // } - // } - - fn visit_fun_def(&mut self, fun_def: &mut FunDef) { + fn visit_fun_def(&mut self, arena: &'a Bump, fun_def: &mut ArenaFunDef<'a>) { if let Some(ident_exec) = fun_def.generic_exec.as_ref() { match &ident_exec.ty.ty { - ExecTyKind::CpuThread => { + 
ArenaExecTyKind::CpuThread => { self.ident_names_to_exec_expr.push(( - ident_exec.ident.name.clone(), - ExecExpr::new(ExecExprKind::new(BaseExec::CpuThread)), + ident_exec.ident.name.to_owned().into_boxed_str(), + ArenaExecExpr::new( + arena, + ArenaExecExprKind::new(arena, ArenaBaseExec::CpuThread), + ), )); fun_def.generic_exec = None; } - ExecTyKind::GpuGrid(gdim, bdim) => { + ArenaExecTyKind::GpuGrid(gdim, bdim) => { self.ident_names_to_exec_expr.push(( - ident_exec.ident.name.clone(), - ExecExpr::new(ExecExprKind::new(BaseExec::GpuGrid( - gdim.clone(), - bdim.clone(), - ))), + ident_exec.ident.name.to_owned().into_boxed_str(), + ArenaExecExpr::new( + arena, + ArenaExecExprKind::new( + arena, + ArenaBaseExec::GpuGrid( + arena.alloc(gdim.clone()), + arena.alloc(bdim.clone()), + ), + ), + ), )); fun_def.generic_exec = None; } _ => {} } } - visit_mut::walk_fun_def(self, fun_def) + arena_visit_mut::walk_fun_def(self, arena, fun_def) } } - fn expand_exec_expr(exec_mapping: &[(Box, ExecExpr)], exec_expr: &mut ExecExpr) { + /** + fn expand_exec_expr<'a>( + exec_mapping: &'a [(Box, ArenaExecExpr)], + exec_expr: &mut ArenaExecExpr<'a>, + ) { match &exec_expr.exec.base { - BaseExec::CpuThread | BaseExec::GpuGrid(_, _) => {} - BaseExec::Ident(ident) => { + ArenaBaseExec::CpuThread | ArenaBaseExec::GpuGrid(_, _) => {} + ArenaBaseExec::Ident(ident) => { if let Some(exec) = get_exec_expr(exec_mapping, ident) { let new_base = exec.exec.base.clone(); let mut new_exec_path = exec.exec.path.clone(); @@ -246,11 +363,36 @@ fn replace_exec_idents_with_specific_execs(fun_def: &mut FunDef) { }; } } + }*/ + + fn expand_exec_expr<'a>( + arena: &'a bumpalo::Bump, + exec_mapping: &'a [(Box, ArenaExecExpr<'a>)], + exec_expr: &ArenaExecExpr<'a>, + ) -> ArenaExecExpr<'a> { + match &exec_expr.exec.base { + ArenaBaseExec::CpuThread | ArenaBaseExec::GpuGrid(_, _) => exec_expr.clone_in(arena), + ArenaBaseExec::Ident(ident) => { + if let Some(mapped_exec) = get_exec_expr(exec_mapping, ident) { + 
let new_base = mapped_exec.exec.base.clone(); + let mut new_path = BumpVec::new_in(arena); + new_path.extend(mapped_exec.exec.path.iter().cloned()); + new_path.extend(exec_expr.exec.path.iter().cloned()); + + ArenaExecExpr::new(arena, ArenaExecExprKind::with_path(new_base, new_path)) + } else { + exec_expr.clone_in(arena) + } + } + } } - fn get_exec_expr(exec_mapping: &[(Box, ExecExpr)], ident: &Ident) -> Option { + fn get_exec_expr<'a>( + exec_mapping: &'a [(Box, ArenaExecExpr)], + ident: &'a ArenaIdent<'a>, + ) -> Option> { for (i, exec) in exec_mapping.iter().rev() { - if i == &ident.name { + if i.as_ref() == ident.name { return Some(exec.clone()); } } @@ -260,33 +402,53 @@ fn replace_exec_idents_with_specific_execs(fun_def: &mut FunDef) { let mut replace_exec_idents = ReplaceExecIdents { ident_names_to_exec_expr: vec![], }; - replace_exec_idents.visit_fun_def(fun_def); + replace_exec_idents.visit_fun_def(arena, fun_def); } -fn replace_struct_idents_with_specific_struct_dtys(struct_dtys: &[StructDecl], item: &mut Item) { - struct ReplaceStructIdents<'a> { - struct_dtys: &'a [StructDecl], +fn replace_struct_idents_with_specific_struct_dtys<'a, 's>( + arena: &'a bumpalo::Bump, + struct_dtys: &'s [&'a ArenaStructDecl<'a>], // outer borrow 's, elements 'a + item: &mut ArenaItem<'a>, +) { + struct ReplaceStructIdents<'s, 'a> { + struct_dtys: &'s [&'a ArenaStructDecl<'a>], } - impl<'a> VisitMut for ReplaceStructIdents<'a> { - fn visit_dty(&mut self, dty: &mut DataTy) { - if let DataTyKind::Ident(ident) = &mut dty.dty { - if let Some(struct_decl) = self.struct_dtys.iter().find(|s| &s.ident == ident) { - dty.dty = DataTyKind::Struct(Box::new(struct_decl.clone())) + + impl<'s, 'a> ArenaVisitMut<'a> for ReplaceStructIdents<'s, 'a> { + fn visit_dty(&mut self, arena: &'a bumpalo::Bump, dty: &mut ArenaDataTy<'a>) { + if let ArenaDataTyKind::Ident(ident) = &mut dty.dty { + if let Some(sd) = self + .struct_dtys + .iter() + .copied() // &&T -> &T + .find(|sd| &sd.ident == 
ident) + { + dty.dty = ArenaDataTyKind::Struct(sd); + return; } - } else { - visit_mut::walk_dty(self, dty) } + arena_visit_mut::walk_dty(self, arena, dty) } } - let mut replace_struct_idents = ReplaceStructIdents { struct_dtys }; + let mut v = ReplaceStructIdents { struct_dtys }; match item { - Item::FunDef(fun_def) => replace_struct_idents.visit_fun_def(fun_def), - Item::FunDecl(fun_decl) => replace_struct_idents.visit_fun_decl(fun_decl), + ArenaItem::FunDef(fd) => { + let mut fd_owned = (**fd).clone(); + v.visit_fun_def(arena, &mut fd_owned); + *fd = arena.alloc(fd_owned); + } + ArenaItem::FunDecl(fd) => { + let mut fd_owned = (**fd).clone_in(arena); + v.visit_fun_decl(arena, &mut fd_owned); + *fd = arena.alloc(fd_owned); + } _ => {} } } +// + pub mod error { use crate::error::ErrorReported; use crate::parser::{Parser, SourceCode}; @@ -1063,7 +1225,6 @@ peg::parser! { mod tests { use super::*; - #[test] fn nat_literal() { assert_eq!(descend::nat("0"), Ok(Nat::Lit(0)), "cannot parse 0"); @@ -2471,12 +2632,20 @@ mod tests { #[test] fn empty_annotate_snippet() { let source = SourceCode::new("fn\n".to_string()); - assert!(parse(&source).is_err(), "Expected a parsing error and specifically not a panic!"); + let bump: Bump = Bump::new(); + assert!( + parse(&bump, &source).is_err(), + "Expected a parsing error and specifically not a panic!" + ); } #[test] fn empty_annotate_snippet2() { let source = SourceCode::new("fn ".to_string()); - assert!(parse(&source).is_err(), "Expected a parsing error and specifically not a panic!"); + let bump: Bump = Bump::new(); + assert!( + parse(&bump, &source).is_err(), + "Expected a parsing error and specifically not a panic!" 
+ ); } } diff --git a/src/ty_check/borrow_check.rs b/src/ty_check/borrow_check.rs index 91c4823a..4a82c0cb 100644 --- a/src/ty_check/borrow_check.rs +++ b/src/ty_check/borrow_check.rs @@ -1,33 +1,33 @@ use super::ctxs::TyCtx; -use crate::ast::internal::{Loan, PlaceCtx, PrvMapping}; -use crate::ast::*; -use crate::parser::descend::nat; +use crate::arena_ast::internal::{Loan, PlaceCtx, PrvMapping}; +use crate::arena_ast::*; use crate::ty_check::ctxs::{AccessCtx, GlobalCtx, KindCtx}; use crate::ty_check::error::BorrowingError; use crate::ty_check::exec::normalize; use crate::ty_check::{exec, pre_decl, ExprTyCtx}; +use bumpalo::{collections::Vec as BumpVec, Bump}; use std::collections::HashSet; -type OwnResult = Result; +type OwnResult<'a, T> = Result>; -pub(super) struct BorrowCheckCtx<'gl, 'src, 'ctxt> { +pub(super) struct BorrowCheckCtx<'a> { // TODO refactor: move into ctx module and remove public - pub gl_ctx: &'ctxt GlobalCtx<'gl, 'src>, - pub nat_ctx: &'ctxt NatCtx, - pub kind_ctx: &'ctxt KindCtx, - pub ident_exec: Option<&'ctxt IdentExec>, - pub ty_ctx: &'ctxt TyCtx, - pub access_ctx: &'ctxt AccessCtx, - pub exec: ExecExpr, - pub reborrows: Vec, + pub gl_ctx: &'a GlobalCtx<'a>, + pub nat_ctx: &'a NatCtx<'a>, + pub kind_ctx: &'a KindCtx<'a>, + pub ident_exec: Option<&'a IdentExec<'a>>, + pub ty_ctx: &'a TyCtx<'a>, + pub access_ctx: &'a AccessCtx<'a>, + pub exec: ExecExpr<'a>, + pub reborrows: Vec>, pub own: Ownership, pub unsafe_flag: bool, } -impl<'gl, 'src, 'ctxt> BorrowCheckCtx<'gl, 'src, 'ctxt> { +impl<'a> BorrowCheckCtx<'a> { pub(super) fn new( - expr_ty_ctx: &'ctxt ExprTyCtx<'gl, 'src, 'ctxt>, - reborrows: Vec, + expr_ty_ctx: &'a ExprTyCtx<'a>, + reborrows: Vec>, own: Ownership, ) -> Self { BorrowCheckCtx { @@ -46,7 +46,7 @@ impl<'gl, 'src, 'ctxt> BorrowCheckCtx<'gl, 'src, 'ctxt> { fn extend_reborrows(&self, iter: I) -> Self where - I: Iterator, + I: Iterator>, { let mut extended_reborrows = self.reborrows.clone(); extended_reborrows.extend(iter); @@ 
-69,51 +69,70 @@ impl<'gl, 'src, 'ctxt> BorrowCheckCtx<'gl, 'src, 'ctxt> { // Ownership Safety // //p is ω-safe under δ and γ, with reborrow exclusion list π , and may point to any of the loans in ωp -pub(super) fn access_safety_check(ctx: &BorrowCheckCtx, p: &PlaceExpr) -> OwnResult> { +pub(super) fn access_safety_check<'a>( + ctx: &'a BorrowCheckCtx<'a>, + p: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { if !ctx.unsafe_flag { - narrowing_check(ctx, p, &ctx.exec)?; - access_conflict_check(ctx, p)?; + narrowing_check(ctx, p, &ctx.exec, arena)?; + access_conflict_check(ctx, p, arena)?; } - borrow_check(ctx, p) + borrow_check(ctx, p, arena) } -pub(super) fn borrow_check(ctx: &BorrowCheckCtx, p: &PlaceExpr) -> OwnResult> { - let (pl_ctx, most_spec_pl) = p.to_pl_ctx_and_most_specif_pl(); +pub(super) fn borrow_check<'a>( + ctx: &'a BorrowCheckCtx<'a>, + p: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + let (pl_ctx, most_spec_pl) = p.to_pl_ctx_and_most_specif_pl(arena); + + let pl_ctx: &'a PlaceCtx<'a> = arena.alloc(pl_ctx); + let most_spec_pl = arena.alloc(most_spec_pl); + if p.is_place() { - ownership_safe_place(ctx, p) + ownership_safe_place(ctx, p, arena) } else { - let pl_ctx_no_deref = pl_ctx.without_innermost_deref(); + let pl_ctx_no_deref = pl_ctx.without_innermost_deref(arena); + // Γ(π) = &r ωπ τπ match &ctx.ty_ctx.place_dty(&most_spec_pl)?.dty { DataTyKind::Ref(reff) => match &reff.rgn { Provenance::Value(prv_val_name) => ownership_safe_deref( ctx, - &pl_ctx_no_deref, - &most_spec_pl, - prv_val_name.as_str(), + pl_ctx_no_deref, + most_spec_pl, + *prv_val_name, reff.own, + arena, ), Provenance::Ident(_) => { - ownership_safe_deref_abs(ctx, &pl_ctx_no_deref, &most_spec_pl, reff.own) + ownership_safe_deref_abs(ctx, pl_ctx_no_deref, most_spec_pl, reff.own, arena) } }, - DataTyKind::RawPtr(_) => ownership_safe_deref_raw(ctx, &pl_ctx_no_deref, &most_spec_pl), + DataTyKind::RawPtr(_) => { + 
ownership_safe_deref_raw(ctx, pl_ctx_no_deref, most_spec_pl, arena) + } // TODO improve error message - t => ownership_safe_place(ctx, p), //panic!("Is the type dead? `{:?}`\n {:?}", t, p), + t => ownership_safe_place(ctx, p, arena), //panic!("Is the type dead? `{:?}`\n {:?}", t, p), } } } // TODO remove? -fn ownership_safe_deref_raw( - ctx: &BorrowCheckCtx, - pl_ctx_no_deref: &PlaceCtx, - most_spec_pl: &internal::Place, -) -> OwnResult> { +/** +fn ownership_safe_deref_raw<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + most_spec_pl: &'a internal::Place<'a>, + arena: &'a Bump, +) -> OwnResult<'a, HashSet>> { // TODO is this correct? - let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr(PlaceExpr::new( - PlaceExprKind::Deref(Box::new(most_spec_pl.to_place_expr())), - )); + let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr( + arena, + PlaceExpr::new(PlaceExprKind::Deref(&most_spec_pl.to_place_expr(arena))), + ); let mut passed_through_prvs = HashSet::new(); passed_through_prvs.insert(Loan { place_expr: currently_checked_pl_expr, @@ -121,30 +140,62 @@ fn ownership_safe_deref_raw( }); Ok(passed_through_prvs) } +*/ + +fn ownership_safe_deref_raw<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + most_spec_pl: &'a internal::Place<'a>, + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + // 1) Build the inner PlaceExpr and bump-allocate it + let inner_pe: PlaceExpr<'a> = most_spec_pl.to_place_expr(arena); + let inner_ref: &'a PlaceExpr<'a> = arena.alloc(inner_pe); + + // 2) Create the Deref node around that borrowed reference + let deref_pe = PlaceExpr::new(PlaceExprKind::Deref(inner_ref)); + + // 3) Insert into the context to get back an &'a PlaceExpr<'a> + let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr(arena, deref_pe); + + // 4) Build the result + let mut passed_through_prvs = BumpVec::new_in(arena); + passed_through_prvs.push(Loan { + place_expr: 
currently_checked_pl_expr, + own: ctx.own, + }); + + Ok(passed_through_prvs) +} -fn ownership_safe_place(ctx: &BorrowCheckCtx, p: &PlaceExpr) -> OwnResult> { - ownership_safe_under_existing_borrows(ctx, p)?; - let mut loan_set = HashSet::new(); - loan_set.insert(Loan { +fn ownership_safe_place<'a>( + ctx: &'a BorrowCheckCtx<'a>, + p: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + ownership_safe_under_existing_borrows(ctx, p, arena)?; + let mut loan_set = BumpVec::new_in(arena); + loan_set.push(Loan { place_expr: p.clone(), own: ctx.own, }); Ok(loan_set) } -fn ownership_safe_deref( - ctx: &BorrowCheckCtx, - pl_ctx_no_deref: &PlaceCtx, - most_spec_pl: &internal::Place, +fn ownership_safe_deref<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + most_spec_pl: &'a internal::Place<'a>, prv_val_name: &str, ref_own: Ownership, -) -> OwnResult> { + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { // Γ(r) = { ω′pi } let loans_in_prv = ctx.ty_ctx.loans_in_prv(prv_val_name)?; // ω ≲ ωπ new_own_weaker_equal(ctx.own, ref_own)?; // List - let pl_ctxs_and_places_in_loans = pl_ctxs_and_places_in_loans(loans_in_prv); + let pl_ctxs_and_places_in_loans = pl_ctxs_and_places_in_loans(loans_in_prv, arena); // List<πe>, List<πi>, π let ext_reborrow_ctx = ctx.extend_reborrows( pl_ctxs_and_places_in_loans @@ -152,96 +203,184 @@ fn ownership_safe_deref( .chain(std::iter::once(most_spec_pl.clone())), ); // ext_reborrow_ctx.exec = from_exec.clone(); + let ext_ctx: &'a BorrowCheckCtx<'a> = arena.alloc(ext_reborrow_ctx); // ∀i ∈ {1...n}.Δ;Γ ⊢ω List<πe>,List<πi>,π p□[pi] ⇒ {ω pi′} - let mut potential_prvs_after_subst = subst_pl_with_potential_prvs_ownership_safe( - &ext_reborrow_ctx, - pl_ctx_no_deref, - loans_in_prv, - )?; - - let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr(PlaceExpr::new( - PlaceExprKind::Deref(Box::new(most_spec_pl.to_place_expr())), - )); - 
ownership_safe_under_existing_borrows(&ext_reborrow_ctx, ¤tly_checked_pl_expr)?; - potential_prvs_after_subst.insert(Loan { - place_expr: currently_checked_pl_expr, + let mut potential_prvs_after_subst = + subst_pl_with_potential_prvs_ownership_safe(ext_ctx, pl_ctx_no_deref, loans_in_prv, arena)?; + + let inner_pe: PlaceExpr<'a> = most_spec_pl.to_place_expr(arena); + let inner_ref: &'a PlaceExpr<'a> = arena.alloc(inner_pe); + + // 2) Wrap in a Deref and bump‐allocate that too + let wrapper = PlaceExpr::new(PlaceExprKind::Deref(inner_ref)); + let stripped_ref: &'a PlaceExpr<'a> = arena.alloc(wrapper); + + ownership_safe_under_existing_borrows(ext_ctx, stripped_ref, arena)?; + potential_prvs_after_subst.push(Loan { + place_expr: stripped_ref.clone(), own: ctx.own, }); Ok(potential_prvs_after_subst) } -fn subst_pl_with_potential_prvs_ownership_safe( - ctx: &BorrowCheckCtx, - pl_ctx_no_deref: &PlaceCtx, - loans_in_prv: &HashSet, -) -> OwnResult> { - let mut loans: HashSet = HashSet::new(); +/** +fn subst_pl_with_potential_prvs_ownership_safe<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + loans_in_prv: &HashSet>, + arena: &'a Bump, +) -> OwnResult<'a, HashSet>> { + let mut loans: HashSet> = HashSet::new(); + for pl_expr in loans_in_prv.iter().map(|loan| &loan.place_expr) { - let insert_dereferenced_pl_expr = pl_ctx_no_deref.insert_pl_expr(pl_expr.clone()); + let insert_dereferenced_pl_expr = pl_ctx_no_deref.insert_pl_expr(arena, pl_expr.clone()); let loans_for_possible_prv_pl_expr = - access_safety_check(ctx, &insert_dereferenced_pl_expr)?; + access_safety_check(ctx, &insert_dereferenced_pl_expr, arena)?; loans.extend(loans_for_possible_prv_pl_expr); } + Ok(loans) } +*/ -fn ownership_safe_deref_abs( - ctx: &BorrowCheckCtx, - pl_ctx_no_deref: &PlaceCtx, - most_spec_pl: &internal::Place, +fn subst_pl_with_potential_prvs_ownership_safe<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + loans_in_prv: &HashSet>, + arena: 
&'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + // 1) collect into a bump-vector + let mut collected = BumpVec::new_in(arena); + for loan in loans_in_prv { + let base_pe: PlaceExpr<'a> = loan.place_expr.clone(); + let inserted_pe = pl_ctx_no_deref.insert_pl_expr(arena, base_pe); + let pe_ref: &'a PlaceExpr<'a> = arena.alloc(inserted_pe); + + let new_loans = access_safety_check(ctx, pe_ref, arena)?; + for l in new_loans { + collected.push(l); + } + } + + // 2) deduplicate in place + let mut seen = HashSet::new(); + collected.retain(|loan| seen.insert(loan.clone())); + + Ok(collected) +} + +/** +fn ownership_safe_deref_abs<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + most_spec_pl: &'a internal::Place<'a>, ref_own: Ownership, -) -> OwnResult> { - let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr(PlaceExpr::new( - PlaceExprKind::Deref(Box::new(most_spec_pl.to_place_expr())), - )); + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + let currently_checked_pl_expr = pl_ctx_no_deref.insert_pl_expr( + arena, + PlaceExpr::new(PlaceExprKind::Deref(&most_spec_pl.to_place_expr(arena))), + ); // FIXME the type check should not have any effect, however guaranteeing that every place // expression even those which are formed recursively seem cleaner // pl_expr::ty_check(&PlExprTyCtx::from(ctx), &mut currently_checked_pl_expr)?; new_own_weaker_equal(ctx.own, ref_own)?; - ownership_safe_under_existing_borrows(ctx, ¤tly_checked_pl_expr)?; - let mut passed_through_prvs = HashSet::new(); - passed_through_prvs.insert(Loan { + ownership_safe_under_existing_borrows(ctx, ¤tly_checked_pl_expr, arena)?; + let mut passed_through_prvs = BumpVec::new_in(arena); + passed_through_prvs.push(Loan { place_expr: currently_checked_pl_expr, own: ctx.own, }); Ok(passed_through_prvs) } + */ + +fn ownership_safe_deref_abs<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_ctx_no_deref: &'a PlaceCtx<'a>, + most_spec_pl: &'a internal::Place<'a>, + 
ref_own: Ownership, + arena: &'a Bump, +) -> OwnResult<'a, BumpVec<'a, Loan<'a>>> { + // 1) Build and bump-allocate the inner PlaceExpr for the raw place + let base_pe: PlaceExpr<'a> = most_spec_pl.to_place_expr(arena); + let base_ref: &'a PlaceExpr<'a> = arena.alloc(base_pe); + + // 2) Insert into the context (still an owned PlaceExpr), then bump-allocate + let inserted_pe: PlaceExpr<'a> = pl_ctx_no_deref.insert_pl_expr(arena, base_ref.clone()); + let inserted_ref: &'a PlaceExpr<'a> = arena.alloc(inserted_pe); + + // 3) Wrap that in a Deref, bump-allocate again + let deref_pe = PlaceExpr::new(PlaceExprKind::Deref(inserted_ref)); + let deref_ref: &'a PlaceExpr<'a> = arena.alloc(deref_pe); + + // 4) Check ownership ordering + new_own_weaker_equal(ctx.own, ref_own)?; + + // 5) Run the “under existing borrows” check + ownership_safe_under_existing_borrows(ctx, deref_ref, arena)?; -fn narrowing_check( - ctx: &BorrowCheckCtx, - p: &PlaceExpr, - active_ctx_exec: &ExecExpr, -) -> OwnResult<()> { + // 6) Return a single‐element BumpVec of the resulting loan + let mut passed_through_prvs = BumpVec::new_in(arena); + passed_through_prvs.push(Loan { + place_expr: deref_ref.clone(), + own: ctx.own, + }); + + Ok(passed_through_prvs) +} + +fn narrowing_check<'a>( + ctx: &'a BorrowCheckCtx<'a>, + p: &'a PlaceExpr<'a>, + active_ctx_exec: &'a ExecExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, ()> { if ctx.own == Ownership::Shrd { return Ok(()); } match &p.pl_expr { PlaceExprKind::Ident(ident) => { - narrowable(&ctx.ty_ctx.ident_ty(ident)?.exec, active_ctx_exec) + narrowable(&ctx.ty_ctx.ident_ty(ident)?.exec, active_ctx_exec, arena) } PlaceExprKind::Select(pl_expr, select_exec) => { - narrowable(select_exec, active_ctx_exec)?; - let mut outer_exec = active_ctx_exec.remove_last_distrib(); - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut outer_exec)?; - narrowing_check(ctx, pl_expr, &outer_exec) + narrowable(select_exec, active_ctx_exec, arena)?; + + let mut outer_exec: 
ExecExpr<'a> = active_ctx_exec.remove_last_distrib(arena); + + exec::ty_check( + ctx.nat_ctx, + ctx.ty_ctx, + ctx.ident_exec, + &mut outer_exec, + arena, + )?; + + let outer_exec_ref: &'a ExecExpr<'a> = arena.alloc(outer_exec); + + narrowing_check(ctx, pl_expr, outer_exec_ref, arena) } PlaceExprKind::View(pl_expr, _) | PlaceExprKind::Deref(pl_expr) | PlaceExprKind::Proj(pl_expr, _) | PlaceExprKind::FieldProj(pl_expr, _) - | PlaceExprKind::Idx(pl_expr, _) => narrowing_check(ctx, pl_expr, active_ctx_exec), + | PlaceExprKind::Idx(pl_expr, _) => narrowing_check(ctx, pl_expr, active_ctx_exec, arena), } } -fn narrowable(from: &ExecExpr, to: &ExecExpr) -> OwnResult<()> { - let normal_from = normalize(from.clone()); - let normal_to = normalize(to.clone()); - exec_is_prefix_of(&normal_from, &normal_to)?; - no_forall_in_diff(&normal_from, &normal_to) +fn narrowable<'a>( + from: &'a ExecExpr<'a>, + to: &'a ExecExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, ()> { + let normal_from = arena.alloc(normalize(from.clone(), arena)); + let normal_to = arena.alloc(normalize(to.clone(), arena)); + exec_is_prefix_of(normal_from, normal_to)?; + no_forall_in_diff(normal_from, normal_to) } -fn exec_is_prefix_of(prefix: &ExecExpr, of: &ExecExpr) -> OwnResult<()> { +fn exec_is_prefix_of<'a>(prefix: &'a ExecExpr<'a>, of: &'a ExecExpr<'a>) -> OwnResult<'a, ()> { if prefix.exec.base != of.exec.base { return Err(BorrowingError::WrongDevice( of.exec.base.clone(), @@ -259,9 +398,13 @@ fn exec_is_prefix_of(prefix: &ExecExpr, of: &ExecExpr) -> OwnResult<()> { Ok(()) } -fn access_conflict_check(ctx: &BorrowCheckCtx, p: &PlaceExpr) -> OwnResult<()> { +fn access_conflict_check<'a>( + ctx: &'a BorrowCheckCtx<'a>, + p: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, ()> { for loan in ctx.access_ctx.hash_set() { - if possible_conflict_with_previous_access(ctx.nat_ctx, ctx.own, p, loan)? { + if possible_conflict_with_previous_access(ctx.nat_ctx, ctx.own, p, loan, arena)? 
{ return Err(BorrowingError::Conflict { checked: p.clone(), existing: loan.place_expr.clone(), @@ -271,22 +414,29 @@ fn access_conflict_check(ctx: &BorrowCheckCtx, p: &PlaceExpr) -> OwnResult<()> { Ok(()) } -fn possible_conflict_with_previous_access( - nat_ctx: &NatCtx, +fn possible_conflict_with_previous_access<'a>( + nat_ctx: &'a NatCtx<'a>, own: Ownership, - p: &PlaceExpr, - previous: &Loan, -) -> NatEvalResult { + p: &'a PlaceExpr<'a>, + previous: &'a Loan<'a>, + arena: &'a Bump, +) -> NatEvalResult<'a, bool> { if own == Ownership::Shrd && previous.own == Ownership::Shrd { return Ok(false); } - let (p_ident, p_path) = p.as_ident_and_path(); - let (l_ident, l_path) = previous.place_expr.as_ident_and_path(); + + let (p_ident, p_path_local) = p.as_ident_and_path(arena); + let (l_ident, l_path_local) = previous.place_expr.as_ident_and_path(arena); + + let p_path: &'a [PlExprPathElem<'a>] = arena.alloc_slice_clone(&p_path_local); + let l_path: &'a [PlExprPathElem<'a>] = arena.alloc_slice_clone(&l_path_local); + if p_ident != l_ident { return Ok(false); } - for path_elems in p_path.iter().zip(&l_path) { - match path_elems { + + for (pe_p, pe_l) in p_path.iter().zip(l_path.iter()) { + match (pe_p, pe_l) { (PlExprPathElem::Deref, PlExprPathElem::Deref) => {} (PlExprPathElem::Proj(kp), PlExprPathElem::Proj(kl)) => { if kp != kl { @@ -335,20 +485,20 @@ fn possible_conflict_with_previous_access( Ok(false) } -fn range_intersects( - nat_ctx: &NatCtx, - lower_left: &Nat, - upper_left: &Nat, - lower_right: &Nat, - upper_right: &Nat, -) -> NatEvalResult { +fn range_intersects<'a>( + nat_ctx: &'a NatCtx<'a>, + lower_left: &'a Nat<'a>, + upper_left: &'a Nat<'a>, + lower_right: &'a Nat<'a>, + upper_right: &'a Nat<'a>, +) -> NatEvalResult<'a, bool> { Ok((lower_left.eval(nat_ctx)? < lower_right.eval(nat_ctx)? && upper_left.eval(nat_ctx)? <= lower_right.eval(nat_ctx)?) || (lower_left.eval(nat_ctx)? >= upper_right.eval(nat_ctx)? && upper_left.eval(nat_ctx)? 
> upper_right.eval(nat_ctx)?)) } -fn no_forall_in_diff(from: &ExecExpr, under: &ExecExpr) -> OwnResult<()> { +fn no_forall_in_diff<'a>(from: &'a ExecExpr<'a>, under: &'a ExecExpr<'a>) -> OwnResult<'a, ()> { if from.exec.path.len() > under.exec.path.len() { return Err(BorrowingError::CannotNarrow); } @@ -360,16 +510,33 @@ fn no_forall_in_diff(from: &ExecExpr, under: &ExecExpr) -> OwnResult<()> { Ok(()) } -fn pl_ctxs_and_places_in_loans( - loans: &HashSet, -) -> impl Iterator + '_ { +/** +fn pl_ctxs_and_places_in_loans<'a>( + loans: &HashSet>, + arena: &'a Bump, +) -> impl Iterator, internal::Place<'a>)> + 'a { + // was '_ before, what does that mean loans .iter() .map(|loan| &loan.place_expr) - .map(|pl_expr| pl_expr.to_pl_ctx_and_most_specif_pl()) + .map(|pl_expr| pl_expr.to_pl_ctx_and_most_specif_pl(arena)) } +*/ -fn new_own_weaker_equal(checked_own: Ownership, ref_own: Ownership) -> OwnResult<()> { +fn pl_ctxs_and_places_in_loans<'a, I>( + loans: I, + arena: &'a Bump, +) -> impl Iterator, internal::Place<'a>)> + 'a +where + I: IntoIterator> + 'a, + ::IntoIter: 'a, +{ + loans + .into_iter() + .map(move |loan| loan.place_expr.to_pl_ctx_and_most_specif_pl(arena)) +} + +fn new_own_weaker_equal<'a>(checked_own: Ownership, ref_own: Ownership) -> OwnResult<'a, ()> { if ref_own < checked_own { Err(BorrowingError::ConflictingOwnership) } else { @@ -377,19 +544,21 @@ fn new_own_weaker_equal(checked_own: Ownership, ref_own: Ownership) -> OwnResult } } -fn ownership_safe_under_existing_borrows( - ctx: &BorrowCheckCtx, - pl_expr: &PlaceExpr, -) -> OwnResult<()> { +fn ownership_safe_under_existing_borrows<'a>( + ctx: &'a BorrowCheckCtx<'a>, + pl_expr: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> OwnResult<'a, ()> { if !ctx.unsafe_flag { for prv_mapping in ctx.ty_ctx.prv_mappings() { let PrvMapping { prv, loans } = prv_mapping; - let no_uniq_overlap = no_uniq_loan_overlap(ctx.own, pl_expr, loans).is_none(); + let no_uniq_overlap = no_uniq_loan_overlap(ctx.own, pl_expr, loans, 
arena).is_none(); if !no_uniq_overlap { return at_least_one_borrowing_place_and_all_in_reborrow( ctx.ty_ctx, prv, &ctx.reborrows, + arena, ); } } @@ -397,26 +566,55 @@ fn ownership_safe_under_existing_borrows( Ok(()) } +/** // returns None if there is no unique loan overlap or Some with the existing overlapping loan -fn no_uniq_loan_overlap( +fn no_uniq_loan_overlap<'a>( own: Ownership, - pl_expr: &PlaceExpr, - loans: &HashSet, -) -> Option { + pl_expr: &'a PlaceExpr<'a>, + loans: &HashSet>, + arena: &'a Bump, +) -> Option> { for l in loans { - if (own == Ownership::Uniq || l.own == Ownership::Uniq) && overlap(&l.place_expr, pl_expr) { + if (own == Ownership::Uniq || l.own == Ownership::Uniq) + && overlap(&l.place_expr, pl_expr, arena) + { return Some(l.clone()); } } None } +*/ + +/// Returns `Some(clashing_loan)` if there is any unique‐ownership overlap, +/// or `None` if none of the loans conflict. +fn no_uniq_loan_overlap<'a, I>( + own: Ownership, + pl_expr: &'a PlaceExpr<'a>, + loans: I, + arena: &'a Bump, +) -> Option> +where + I: IntoIterator>, +{ + loans.into_iter().find_map(|l| { + // must be at least one Unique side, and the places must overlap + if (own == Ownership::Uniq || l.own == Ownership::Uniq) + && overlap(&l.place_expr, pl_expr, arena) + { + Some(l.clone()) + } else { + None + } + }) +} -fn at_least_one_borrowing_place_and_all_in_reborrow( - ty_ctx: &TyCtx, +fn at_least_one_borrowing_place_and_all_in_reborrow<'a>( + ty_ctx: &'a TyCtx<'a>, prv_name: &str, - reborrows: &[internal::Place], -) -> OwnResult<()> { - let all_places = ty_ctx.all_places(); + reborrows: &[internal::Place<'a>], + arena: &'a Bump, +) -> OwnResult<'a, ()> { + let all_places = ty_ctx.all_places(arena); // check that a borrow with given provenance exists. // It could not exist for example in case it is used for a parameter // during function application. 
The second part of this function would succeed in this case @@ -439,19 +637,19 @@ fn at_least_one_borrowing_place_and_all_in_reborrow( Ok(()) } -fn conflicting_path(pathl: &[PlExprPathElem], pathr: &[PlExprPathElem]) -> bool { +fn conflicting_path<'a>(pathl: &[PlExprPathElem<'a>], pathr: &[PlExprPathElem<'a>]) -> bool { for lr in pathl.iter().zip(pathr) { match lr { (PlExprPathElem::Idx(_), _) => return true, (v @ PlExprPathElem::View(iv), path_elem) - if v != path_elem && iv.name.name.as_ref() != pre_decl::SELECT_RANGE => + if v != path_elem && iv.name.name != pre_decl::SELECT_RANGE => { return true } (PlExprPathElem::View(ivl), PlExprPathElem::View(ivr)) if ivl != ivr - && ivl.name.name.as_ref() == pre_decl::SELECT_RANGE - && ivr.name.name.as_ref() == pre_decl::SELECT_RANGE => + && ivl.name.name == pre_decl::SELECT_RANGE + && ivr.name.name == pre_decl::SELECT_RANGE => { match ( &ivr.gen_args[0], @@ -460,10 +658,10 @@ fn conflicting_path(pathl: &[PlExprPathElem], pathr: &[PlExprPathElem]) -> bool &ivr.gen_args[1], ) { ( - ArgKinded::Nat(lower_left), - ArgKinded::Nat(upper_left), - ArgKinded::Nat(lower_right), - ArgKinded::Nat(upper_right), + ArgKinded::Nat(_lower_left), + ArgKinded::Nat(_upper_left), + ArgKinded::Nat(_lower_right), + ArgKinded::Nat(_upper_right), ) => { // intersecting ranges // TAKE CARE: the comparisons are partial and return false in case the @@ -483,9 +681,9 @@ fn conflicting_path(pathl: &[PlExprPathElem], pathr: &[PlExprPathElem]) -> bool true } -fn overlap(pll: &PlaceExpr, plr: &PlaceExpr) -> bool { - let (pl_ident, pl_path) = pll.as_ident_and_path(); - let (pr_ident, pr_path) = plr.as_ident_and_path(); +fn overlap<'a>(pll: &'a PlaceExpr<'a>, plr: &'a PlaceExpr<'a>, arena: &'a Bump) -> bool { + let (pl_ident, pl_path) = pll.as_ident_and_path(arena); + let (pr_ident, pr_path) = plr.as_ident_and_path(arena); if pl_ident == pr_ident { conflicting_path(&pl_path, &pr_path) || conflicting_path(&pr_path, &pl_path) } else { diff --git 
a/src/ty_check/ctxs.rs b/src/ty_check/ctxs.rs index 3283105c..4f31e23a 100644 --- a/src/ty_check/ctxs.rs +++ b/src/ty_check/ctxs.rs @@ -1,26 +1,28 @@ -use crate::ast::internal::{ +use crate::arena_ast::internal::{ ExecMapping, Frame, FrameEntry, IdentTyped, Loan, PathElem, PrvMapping, }; -use crate::ast::*; +use crate::arena_ast::*; use crate::ty_check::error::CtxError; +use bumpalo::collections::CollectIn; +use bumpalo::{collections::Vec as BumpVec, Bump}; use std::collections::HashSet; // TODO introduce proper struct -pub(super) type TypedPlace = (internal::Place, DataTy); +pub(super) type TypedPlace<'a> = (internal::Place<'a>, DataTy<'a>); #[derive(PartialEq, Eq, Debug, Clone)] -pub(super) struct TyCtx { - frames: Vec, +pub(super) struct TyCtx<'a> { + frames: BumpVec<'a, Frame<'a>>, } -impl TyCtx { - pub fn new() -> Self { - TyCtx { - frames: vec![Frame::new()], - } +impl<'a> TyCtx<'a> { + pub fn new(arena: &'a Bump) -> Self { + let mut frames = BumpVec::new_in(arena); + frames.push(Frame::new_in(arena)); + TyCtx { frames } } - pub fn get_exec_expr_for_exec_ident(&self, ident: &Ident) -> CtxResult<&ExecExpr> { + pub fn get_exec_expr_for_exec_ident(&self, ident: &Ident<'a>) -> CtxResult<'a, &ExecExpr<'a>> { let exec_expr = self.flat_bindings().rev().find_map(|entry| match entry { FrameEntry::ExecMapping(em) if &em.ident == ident => Some(&em.exec_expr), _ => None, @@ -30,59 +32,55 @@ impl TyCtx { None => Err(CtxError::IdentNotFound(ident.clone())), } } - // - // pub fn last_frame(&self) -> &Frame { - // self.frames.last().unwrap() - // } - pub fn last_frame_mut(&mut self) -> &mut Frame { + pub fn last_frame_mut(&mut self) -> &mut Frame<'a> { self.frames.last_mut().unwrap() } - pub fn flat_bindings_mut(&mut self) -> impl DoubleEndedIterator { - self.frames.iter_mut().flat_map(|frm| &mut frm.bindings) + fn flat_bindings(&self) -> impl DoubleEndedIterator> { + self.frames.iter().flat_map(|f| &f.bindings) } - - pub fn flat_bindings(&self) -> impl DoubleEndedIterator { 
- self.frames.iter().flat_map(|frm| &frm.bindings) + fn flat_bindings_mut(&mut self) -> impl DoubleEndedIterator> { + self.frames.iter_mut().flat_map(|f| &mut f.bindings) } - pub fn push_empty_frame(&mut self) -> &mut Self { - self.frames.push(Frame::new()); + pub fn push_empty_frame(&mut self, arena: &'a Bump) -> &mut Self { + self.frames.push(Frame::new_in(arena)); self } - pub fn push_frame(&mut self, frame: Frame) -> &mut Self { + pub fn push_frame(&mut self, frame: Frame<'a>) -> &mut Self { self.frames.push(frame); self } - pub fn pop_frame(&mut self) -> Frame { - self.frames.pop().expect("There must always be a scope.") + pub fn pop_frame(&mut self) -> Frame<'a> { + assert!(self.frames.len() > 1, "Cannot pop the last frame"); + self.frames.pop().unwrap() } - pub fn append_ident_typed(&mut self, id_typed: IdentTyped) -> &mut Self { + pub fn append_ident_typed(&mut self, id_typed: IdentTyped<'a>) -> &mut Self { self.last_frame_mut() .bindings .push(FrameEntry::Var(id_typed)); self } - pub fn append_exec_mapping(&mut self, ident: Ident, exec: ExecExpr) -> &mut Self { + pub fn append_exec_mapping(&mut self, ident: Ident<'a>, exec: ExecExpr<'a>) -> &mut Self { self.last_frame_mut() .bindings .push(FrameEntry::ExecMapping(ExecMapping::new(ident, exec))); self } - pub fn append_prv_mapping(&mut self, prv_mapping: PrvMapping) -> &mut Self { + pub fn append_prv_mapping(&mut self, prv_mapping: PrvMapping<'a>) -> &mut Self { self.last_frame_mut() .bindings .push(FrameEntry::PrvMapping(prv_mapping)); self } - fn idents_typed(&self) -> impl DoubleEndedIterator { + fn idents_typed(&self) -> impl DoubleEndedIterator> { self.flat_bindings().filter_map(|fe| { if let FrameEntry::Var(ident_typed) = fe { Some(ident_typed) @@ -92,7 +90,7 @@ impl TyCtx { }) } - fn idents_typed_mut(&mut self) -> impl DoubleEndedIterator { + fn idents_typed_mut(&mut self) -> impl DoubleEndedIterator> { self.flat_bindings_mut().filter_map(|fe| { if let FrameEntry::Var(ident_typed) = fe { 
Some(ident_typed) @@ -102,7 +100,7 @@ impl TyCtx { }) } - pub(crate) fn prv_mappings(&self) -> impl DoubleEndedIterator { + pub(crate) fn prv_mappings(&self) -> impl DoubleEndedIterator> { self.flat_bindings().filter_map(|fe| { if let FrameEntry::PrvMapping(prv_mapping) = fe { Some(prv_mapping) @@ -112,7 +110,7 @@ impl TyCtx { }) } - fn prv_mappings_mut(&mut self) -> impl DoubleEndedIterator { + fn prv_mappings_mut(&mut self) -> impl DoubleEndedIterator> { self.flat_bindings_mut().filter_map(|fe| { if let FrameEntry::PrvMapping(prv_mapping) = fe { Some(prv_mapping) @@ -125,7 +123,7 @@ impl TyCtx { pub fn update_loan_set( &mut self, prv_val_name: &str, - loan_set: HashSet, + loan_set: HashSet>, ) -> CtxResult<&mut Self> { let mut found = false; for prv_mapping in self.prv_mappings_mut().rev() { @@ -142,16 +140,20 @@ impl TyCtx { } } - pub fn extend_loans_for_prv(&mut self, base: &str, extension: I) -> CtxResult<&mut TyCtx> + pub fn extend_loans_for_prv( + &mut self, + base: &str, + extension: I, + ) -> CtxResult<'a, &mut TyCtx<'a>> where - I: IntoIterator, + I: IntoIterator>, { let base_loans = self.loans_for_prv_mut(base)?; base_loans.extend(extension); Ok(self) } - pub fn loans_in_prv(&self, prv_val_name: &str) -> CtxResult<&HashSet> { + pub fn loans_in_prv(&self, prv_val_name: &str) -> CtxResult<'a, &HashSet>> { match self .prv_mappings() .rev() @@ -162,7 +164,22 @@ impl TyCtx { } } - pub fn loans_for_prv_mut(&mut self, prv_val_name: &str) -> CtxResult<&mut HashSet> { + /// Return an arena-owned snapshot of the loans for `prv_val_name`. 
+ pub fn loans_in_prv_snapshot( + &self, + prv_val_name: &str, + arena: &'a bumpalo::Bump, + ) -> CtxResult<'a, bumpalo::collections::Vec<'a, Loan<'a>>> { + let set = self.loans_in_prv(prv_val_name)?; + let mut out = bumpalo::collections::Vec::new_in(arena); + out.extend(set.iter().cloned()); + Ok(out) + } + + pub fn loans_for_prv_mut( + &mut self, + prv_val_name: &str, + ) -> CtxResult<'a, &mut HashSet>> { match self .prv_mappings_mut() .rev() @@ -187,26 +204,34 @@ impl TyCtx { } // ∀π:τ ∈ Γ - pub fn all_places(&self) -> Vec { + pub fn all_places(&self, arena: &'a Bump) -> BumpVec<'a, TypedPlace<'a>> { self.idents_typed() .filter_map(|IdentTyped { ident, ty, .. }| { if let TyKind::Data(dty) = &ty.ty { - Some(TyCtx::explode_places(ident, dty)) + Some(TyCtx::explode_places(ident, dty, arena)) } else { None } }) .flatten() - .collect() + .collect_in(arena) } - fn explode_places(ident: &Ident, dty: &DataTy) -> Vec { - fn proj(mut pl: internal::Place, idx: PathElem) -> internal::Place { + fn explode_places( + ident: &Ident<'a>, + dty: &DataTy<'a>, + arena: &'a Bump, + ) -> BumpVec<'a, TypedPlace<'a>> { + fn proj<'a>(mut pl: internal::Place<'a>, idx: PathElem<'a>) -> internal::Place<'a> { pl.path.push(idx); pl } - fn explode(pl: internal::Place, dty: DataTy) -> Vec { + fn explode<'a>( + pl: internal::Place<'a>, + dty: DataTy<'a>, + arena: &'a Bump, + ) -> BumpVec<'a, TypedPlace<'a>> { use DataTyKind as d; match &dty.dty { @@ -218,22 +243,29 @@ impl TyCtx { | d::Ref(_) | d::RawPtr(_) | d::Ident(_) - | d::Dead(_) => vec![(pl, dty.clone())], + | d::Dead(_) => BumpVec::from_iter_in([(pl.clone(), dty.clone())], arena), //vec![(pl, dty.clone())], d::Tuple(tys) => { - let mut place_frame = vec![(pl.clone(), dty.clone())]; + let mut place_frame = BumpVec::from_iter_in([(pl.clone(), dty.clone())], arena); for (index, proj_ty) in tys.iter().enumerate() { - let mut exploded_index = - explode(proj(pl.clone(), PathElem::Proj(index)), proj_ty.clone()); + let mut exploded_index = 
explode( + proj(pl.clone(), PathElem::Proj(index)), + proj_ty.clone(), + arena, + ); place_frame.append(&mut exploded_index); } place_frame } d::Struct(sdecl) => { - let mut place_frame = vec![(pl.clone(), dty.clone())]; + let mut place_frame = BumpVec::from_iter_in([(pl.clone(), dty.clone())], arena); for field in sdecl.fields.iter() { let mut exploded_field = explode( - proj(pl.clone(), PathElem::FieldProj(Box::new(field.0.clone()))), + proj( + pl.clone(), + PathElem::FieldProj(arena.alloc(field.0.clone())), + ), field.1.clone(), + arena, ); place_frame.append(&mut exploded_field); } @@ -242,14 +274,18 @@ impl TyCtx { } } - explode(internal::Place::new(ident.clone(), vec![]), dty.clone()) + explode( + internal::Place::new(ident.clone(), BumpVec::new_in(arena)), + dty.clone(), + arena, + ) } - pub fn ty_of_ident(&self, ident: &Ident) -> CtxResult<&Ty> { + pub fn ty_of_ident(&self, ident: &Ident<'a>) -> CtxResult<'a, &Ty<'a>> { Ok(&self.ident_ty(ident)?.ty) } - pub fn ident_ty(&self, ident: &Ident) -> CtxResult<&IdentTyped> { + pub fn ident_ty(&self, ident: &Ident<'a>) -> CtxResult<'a, &IdentTyped<'a>> { match self .idents_typed() .rev() @@ -260,12 +296,12 @@ impl TyCtx { } } - pub fn contains(&self, ident: &Ident) -> bool { + pub fn contains(&self, ident: &Ident<'a>) -> bool { self.idents_typed().any(|i| i.ident.name == ident.name) } - pub fn place_dty(&self, place: &internal::Place) -> CtxResult { - fn proj_ty(dty: DataTy, path: &[PathElem]) -> CtxResult { + pub fn place_dty(&self, place: &internal::Place<'a>) -> CtxResult<'a, DataTy<'a>> { + fn proj_ty<'a>(dty: DataTy<'a>, path: &[PathElem<'a>]) -> CtxResult<'a, DataTy<'a>> { let mut res_dty = dty; for pe in path { match (&res_dty.dty, pe) { @@ -277,7 +313,7 @@ impl TyCtx { } (DataTyKind::Struct(struct_decl), PathElem::FieldProj(ident)) => { res_dty = if let Some(field) = - struct_decl.fields.iter().find(|f| &f.0 == ident.as_ref()) + struct_decl.fields.iter().find(|f| &f.0 == *ident) { field.1.clone() } else { 
@@ -297,18 +333,24 @@ impl TyCtx { } let ident_ty = self.ty_of_ident(&place.ident)?; if let TyKind::Data(dty) = &ident_ty.ty { - proj_ty(dty.as_ref().clone(), &place.path) + proj_ty((**dty).clone(), &place.path) } else { panic!("This place is not of a data type.") } } - pub fn set_place_dty(&mut self, pl: &internal::Place, pl_ty: DataTy) -> &mut Self { - fn set_dty_for_path_in_dty( - orig_dty: DataTy, - path: &[PathElem], - part_dty: DataTy, - ) -> DataTy { + pub fn set_place_dty( + &mut self, + pl: &internal::Place<'a>, + pl_ty: DataTy<'a>, + arena: &'a Bump, + ) -> &mut Self { + fn set_dty_for_path_in_dty<'a>( + arena: &'a Bump, + orig_dty: DataTy<'a>, + path: &[PathElem<'a>], + part_dty: DataTy<'a>, + ) -> DataTy<'a> { if path.is_empty() { return part_dty; } @@ -317,17 +359,15 @@ impl TyCtx { match (orig_dty.dty, pe) { (DataTyKind::Tuple(mut elem_tys), PathElem::Proj(n)) => { elem_tys[*n] = - set_dty_for_path_in_dty(elem_tys[*n].clone(), &path[1..], part_dty); - DataTy::new(DataTyKind::Tuple(elem_tys)) + set_dty_for_path_in_dty(arena, elem_tys[*n].clone(), &path[1..], part_dty); + DataTy::new(arena, DataTyKind::Tuple(elem_tys)) } - (DataTyKind::Struct(mut struct_decl), PathElem::FieldProj(ident)) => { - if let Some(field) = struct_decl - .fields - .iter_mut() - .find(|f| &f.0 == ident.as_ref()) - { - field.1 = set_dty_for_path_in_dty(field.1.clone(), &path[1..], part_dty); - DataTy::new(DataTyKind::Struct(struct_decl)) + (DataTyKind::Struct(struct_decl), PathElem::FieldProj(ident)) => { + let struct_decl = arena.alloc(struct_decl.clone()); + if let Some(field) = struct_decl.fields.iter_mut().find(|f| &f.0 == *ident) { + field.1 = + set_dty_for_path_in_dty(arena, field.1.clone(), &path[1..], part_dty); + DataTy::new(arena, DataTyKind::Struct(struct_decl)) } else { panic!("Struct field with name `{}` does not exist.", ident.name) } @@ -336,23 +376,28 @@ impl TyCtx { } } - let mut ident_typed = self + let ident_typed = self .idents_typed_mut() .rev() 
.find(|ident_typed| ident_typed.ident == pl.ident) .unwrap(); if let TyKind::Data(dty) = &ident_typed.ty.ty { - let updated_dty = set_dty_for_path_in_dty(*dty.clone(), pl.path.as_slice(), pl_ty); - ident_typed.ty = Ty::new(TyKind::Data(Box::new(updated_dty))); + let updated_dty = + set_dty_for_path_in_dty(arena, (**dty).clone(), pl.path.as_slice(), pl_ty); + ident_typed.ty = Ty::new(TyKind::Data(arena.alloc(updated_dty))); self } else { panic!("Trying to set data type for identifier without data type.") } } - pub fn kill_place(&mut self, pl: &internal::Place) -> &mut Self { + pub fn kill_place(&mut self, pl: &internal::Place<'a>, arena: &'a Bump) -> &mut Self { if let Ok(pl_dty) = self.place_dty(pl) { - self.set_place_dty(pl, DataTy::new(DataTyKind::Dead(Box::new(pl_dty)))) + self.set_place_dty( + pl, + DataTy::new(arena, DataTyKind::Dead(arena.alloc(pl_dty))), + arena, + ) } else { panic!("Trying to kill the type of a place that doesn't exist.") } @@ -382,7 +427,7 @@ impl TyCtx { } // Γ ▷- p = Γ′ - pub(super) fn without_reborrow_loans(&mut self, pl_expr: &PlaceExpr) -> &mut Self { + pub(super) fn without_reborrow_loans(&mut self, pl_expr: &PlaceExpr<'a>) -> &mut Self { for frame_entry in self.flat_bindings_mut() { if let FrameEntry::PrvMapping(PrvMapping { prv: _, loans }) = frame_entry { // FIXME not prefix_of but *x within p? 
@@ -407,26 +452,31 @@ impl TyCtx { } } -pub(super) struct AccessCtx { - ctx: HashSet, +pub(super) struct AccessCtx<'a> { + ctx: BumpVec<'a, Loan<'a>>, } -impl AccessCtx { - pub fn new() -> Self { +impl<'a> AccessCtx<'a> { + pub fn new(arena: &'a Bump) -> Self { AccessCtx { - ctx: HashSet::new(), + ctx: BumpVec::new_in(arena), } } - pub fn insert(&mut self, loans: HashSet) { + pub fn insert(&mut self, loans: BumpVec>) { self.ctx.extend(loans.into_iter()) } - pub fn hash_set(&self) -> &HashSet { + pub fn hash_set(&self) -> &BumpVec> { &self.ctx } - pub fn clear_sync_for(&mut self, ty_ctx: &TyCtx, exec: &ExecExpr) { + pub fn clear_sync_for( + &mut self, + ty_ctx: &'a TyCtx<'a>, + exec: &'a ExecExpr<'a>, + arena: &'a Bump, + ) { self.ctx = self .ctx .iter() @@ -436,26 +486,53 @@ impl AccessCtx { place_expr, }) }) - .collect(); + .collect_in(arena); + } + + // a tiny helper that drills down a PlaceExpr to its `Ident` + // and returns it by value (i.e. clones the Box inside Ident) + // maybe move this one out ? + fn root_ident_of_expr(pe: &PlaceExpr<'a>) -> Ident<'a> { + match &pe.pl_expr { + PlaceExprKind::Ident(id) => id.clone(), + PlaceExprKind::Select(inner, _) + | PlaceExprKind::View(inner, _) + | PlaceExprKind::Proj(inner, _) + | PlaceExprKind::FieldProj(inner, _) + | PlaceExprKind::Idx(inner, _) + | PlaceExprKind::Deref(inner) => { + // recursive descent + Self::root_ident_of_expr(inner) + } + } } - pub fn garbage_collect(&mut self, ty_ctx: &TyCtx) { - // TODO make more efficient - // drain is unstable for HashSet, use Vec anyway? 
- let mut cleaned_up_set = HashSet::new(); - for l in &self.ctx { - let ident = &l.place_expr.as_ident_and_path().0; - if ty_ctx.contains(ident) { - cleaned_up_set.insert(l.clone()); + pub fn garbage_collect(&mut self, ty_ctx: &TyCtx<'a>, arena: &'a Bump) { + // 1) take ownership of the old loans + let old_loans = std::mem::replace(&mut self.ctx, BumpVec::new_in(arena)); + + // 2) build a fresh vec of only the “alive” loans + let mut new_loans = BumpVec::new_in(arena); + for loan in old_loans.into_iter() { + // extract root ident by *value* (no long‐lived borrow) + let ident = Self::root_ident_of_expr(&loan.place_expr); + if ty_ctx.contains(&ident) { + new_loans.push(loan); } } - self.ctx = cleaned_up_set; + + // 3) store it back + self.ctx = new_loans; } } -fn trim_after_select_of(ty_ctx: &TyCtx, exec: &ExecExpr, pl_expr: PlaceExpr) -> Option { +fn trim_after_select_of<'a>( + ty_ctx: &'a TyCtx<'a>, + exec: &'a ExecExpr<'a>, + pl_expr: PlaceExpr<'a>, +) -> Option> { match pl_expr.pl_expr { - PlaceExprKind::Select(p, sel_exec) if sel_exec.as_ref() == exec => { + PlaceExprKind::Select(p, sel_exec) if sel_exec == exec => { Some(PlaceExpr::new(PlaceExprKind::Select(p, sel_exec))) } PlaceExprKind::Select(ipl, _) @@ -463,7 +540,7 @@ fn trim_after_select_of(ty_ctx: &TyCtx, exec: &ExecExpr, pl_expr: PlaceExpr) -> | PlaceExprKind::Proj(ipl, _) | PlaceExprKind::FieldProj(ipl, _) | PlaceExprKind::Idx(ipl, _) - | PlaceExprKind::Deref(ipl) => trim_after_select_of(ty_ctx, exec, *ipl), + | PlaceExprKind::Deref(ipl) => trim_after_select_of(ty_ctx, exec, ipl.clone()), PlaceExprKind::Ident(ident) => { let ident_exec = &ty_ctx .ident_ty(&ident) @@ -482,32 +559,38 @@ fn trim_after_select_of(ty_ctx: &TyCtx, exec: &ExecExpr, pl_expr: PlaceExpr) -> } #[derive(PartialEq, Eq, Debug, Clone)] -enum KindingCtxEntry { - Ident(IdentKinded), - PrvRel(PrvRel), +enum KindingCtxEntry<'a> { + Ident(IdentKinded<'a>), + PrvRel(PrvRel<'a>), } -pub(super) type CtxResult = Result; +pub(super) type 
CtxResult<'a, T> = Result>; #[derive(PartialEq, Eq, Debug, Clone)] -pub(super) struct KindCtx { - ctx: Vec>, +pub(super) struct KindCtx<'a> { + ctx: BumpVec<'a, BumpVec<'a, KindingCtxEntry<'a>>>, } -impl KindCtx { - pub fn new() -> Self { - KindCtx { ctx: vec![vec![]] } +impl<'a> KindCtx<'a> { + pub fn new(arena: &'a Bump) -> Self { + let mut scopes = BumpVec::new_in(arena); + scopes.push(BumpVec::new_in(arena)); + KindCtx { ctx: scopes } } - pub fn gl_fun_kind_ctx(idents: Vec, prv_rels: Vec) -> CtxResult { - let mut kind_ctx: Self = KindCtx::new(); + pub fn gl_fun_kind_ctx( + idents: BumpVec<'a, IdentKinded<'a>>, + prv_rels: BumpVec<'a, PrvRel<'a>>, + arena: &'a Bump, + ) -> CtxResult<'a, Self> { + let mut kind_ctx: Self = KindCtx::new(arena); kind_ctx.append_idents(idents); kind_ctx.append_prv_rels(prv_rels)?; Ok(kind_ctx) } - pub fn push_empty_scope(&mut self) -> &mut Self { - self.ctx.push(vec![]); + pub fn push_empty_scope(&mut self, arena: &'a Bump) -> &mut Self { + self.ctx.push(BumpVec::new_in(arena)); self } @@ -515,7 +598,10 @@ impl KindCtx { self.ctx.pop(); } - pub fn append_idents>(&mut self, idents: I) -> &mut Self { + pub fn append_idents>>( + &mut self, + idents: I, + ) -> &mut Self { let entries = idents.into_iter().map(KindingCtxEntry::Ident); for e in entries { self.ctx.last_mut().unwrap().push(e); @@ -523,10 +609,10 @@ impl KindCtx { self } - pub fn append_prv_rels + Clone>( + pub fn append_prv_rels> + Clone>( &mut self, prv_rels: I, - ) -> CtxResult<&mut Self> { + ) -> CtxResult<'a, &mut Self> { self.well_kinded_prv_rels(prv_rels.clone())?; for prv_rel in prv_rels { self.ctx @@ -537,10 +623,10 @@ impl KindCtx { Ok(self) } - pub fn well_kinded_prv_rels>( + pub fn well_kinded_prv_rels>>( &self, prv_rels: I, - ) -> CtxResult<()> { + ) -> CtxResult<'a, ()> { let mut prv_idents = self.get_idents(Kind::Provenance); for prv_rel in prv_rels.into_iter() { if !prv_idents.any(|prv_ident| &prv_rel.longer == prv_ident) { @@ -553,7 +639,7 @@ impl KindCtx { 
Ok(()) } - pub fn get_idents(&self, kind: Kind) -> impl Iterator { + pub fn get_idents(&'a self, kind: Kind) -> impl Iterator> + 'a { self.ctx.iter().flatten().filter_map(move |entry| { if let KindingCtxEntry::Ident(IdentKinded { ident, kind: k }) = entry { if k == &kind { @@ -567,11 +653,11 @@ impl KindCtx { }) } - pub fn ident_of_kind_exists(&self, ident: &Ident, kind: Kind) -> bool { + pub fn ident_of_kind_exists(&self, ident: &'a Ident<'a>, kind: Kind) -> bool { self.get_idents(kind).any(|id| ident == id) } - pub fn outlives(&self, l: &Ident, s: &Ident) -> CtxResult<()> { + pub fn outlives(&self, l: &'a Ident<'a>, s: &'a Ident<'a>) -> CtxResult<'a, ()> { if self.ctx.iter().flatten().any(|entry| match entry { KindingCtxEntry::PrvRel(PrvRel { longer, shorter }) => longer == l && shorter == s, _ => false, @@ -584,38 +670,55 @@ impl KindCtx { } #[derive(Debug, Clone)] -pub(super) enum GlobalDecl { - FnDecl(Box, Box), - StructDecl(Box), +pub(super) enum GlobalDecl<'a> { + FnDecl(&'a str, &'a FnTy<'a>), + StructDecl(&'a StructDecl<'a>), } #[derive(Debug)] -pub(super) struct GlobalCtx<'src, 'compil> { - compil_unit: &'compil mut CompilUnit<'src>, - checked_funs: Vec<(Box, Box<[usize]>)>, - decls: Vec, - //items: HashMap, GlobalItem>, +pub(super) struct GlobalCtx<'a> { + pub compil_unit: &'a mut CompilUnit<'a>, + checked_funs: BumpVec<'a, (&'a str, &'a [usize])>, + decls: BumpVec<'a, GlobalDecl<'a>>, } -impl<'src, 'compil> GlobalCtx<'src, 'compil> { - pub fn new(compil_unit: &'compil mut CompilUnit<'src>, mut decls: Vec) -> Self { - let mut compil_unit_decls = compil_unit - .items - .iter() - .map(|item| match item { +impl<'a> GlobalCtx<'a> { + pub fn new( + compil_unit: &'a mut CompilUnit<'a>, + mut decls: BumpVec<'a, GlobalDecl<'a>>, + arena: &'a Bump, + ) -> Self { + // 1) grab a raw pointer + length; this does NOT borrow. 
+ let items_ptr = compil_unit.items.as_ptr(); + let len = compil_unit.items.len(); + + // 2) iterate by pointer offsets + for i in 0..len { + // SAFETY: `i < len` so ptr.add(i) is in-bounds, and we never touch compil_unit.items mutably. + let item: &Item<'a> = unsafe { &*items_ptr.add(i) }; + match item { Item::FunDef(fun_def) => { - GlobalDecl::FnDecl(fun_def.ident.name.clone(), Box::new(fun_def.fn_ty())) + let name: &str = &fun_def.ident.name; + let ty_ref: &FnTy<'a> = arena.alloc(fun_def.fn_ty(arena)); + decls.push(GlobalDecl::FnDecl(name, ty_ref)); } Item::FunDecl(fun_decl) => { - GlobalDecl::FnDecl(fun_decl.ident.name.clone(), Box::new(fun_decl.fn_ty())) + let name: &str = &fun_decl.ident.name; + let ty_ref: &FnTy<'a> = arena.alloc(fun_decl.fn_ty(arena)); + decls.push(GlobalDecl::FnDecl(name, ty_ref)); } - Item::StructDecl(struct_decl) => GlobalDecl::StructDecl(struct_decl.clone()), - }) - .collect(); - decls.append(&mut compil_unit_decls); + Item::StructDecl(struct_decl) => { + // We can safely store the reference here, + // because `struct_decl` lives inside `compil_unit` for 'a. 
+ decls.push(GlobalDecl::StructDecl(struct_decl)); + } + } + } + + // 3) now that we never held any &borrows of items, we can store the &mut GlobalCtx { compil_unit, - checked_funs: vec![], + checked_funs: BumpVec::new_in(arena), decls, } } @@ -623,26 +726,32 @@ impl<'src, 'compil> GlobalCtx<'src, 'compil> { pub fn has_been_checked(&self, name: &str, nat_args: &[usize]) -> bool { self.checked_funs .iter() - .any(|(fun_name, nargs)| fun_name.as_ref() == name && nargs.as_ref() == nat_args) + .any(|(fun_name, nargs)| *fun_name == name && *nargs == nat_args) } - pub fn push_fun_checked_under_nats(&mut self, fun_def: Box, nat_vals: Box<[usize]>) { - let fun_name = fun_def.ident.name.clone(); - self.compil_unit.items.push(Item::FunDef(fun_def)); - self.checked_funs.push((fun_name, nat_vals)) + pub fn push_fun_checked_under_nats( + &mut self, + arena: &'a bumpalo::Bump, + fun_def_owned: FunDef<'a>, + nat_vals: &'a [usize], + ) { + let fun_name = fun_def_owned.ident.name.clone(); + let fd_ref: &'a FunDef<'a> = arena.alloc(fun_def_owned); + self.compil_unit.items.push(Item::FunDef(fd_ref)); + self.checked_funs.push((fun_name, nat_vals)); } - pub fn pop_fun_def(&mut self, name: &str) -> Option> { + pub fn pop_fun_def(&mut self, name: &'a str) -> Option> { let index = self.compil_unit.items.iter().position(|item| { if let Item::FunDef(fun_def) = item { - fun_def.ident.name.as_ref() == name + fun_def.ident.name == name } else { false } }); if let Some(i) = index { if let Item::FunDef(fun_def) = self.compil_unit.items.remove(i) { - Some(fun_def) + Some((*fun_def).clone()) } else { None } @@ -651,7 +760,7 @@ impl<'src, 'compil> GlobalCtx<'src, 'compil> { } } - pub fn fn_ty_by_ident(&self, ident: &Ident) -> CtxResult<&FnTy> { + pub fn fn_ty_by_ident(&self, ident: &'a Ident<'a>) -> CtxResult<'a, &'a FnTy<'a>> { if let Some(fn_ty) = self.decls.iter().find_map(|decl| match decl { GlobalDecl::FnDecl(name, fn_ty) if name == &ident.name => Some(fn_ty), GlobalDecl::FnDecl(_, _) | 
GlobalDecl::StructDecl(_) => None, @@ -665,17 +774,23 @@ impl<'src, 'compil> GlobalCtx<'src, 'compil> { #[test] fn test_kill_place_ident() { - let mut ty_ctx = TyCtx::new(); - let x = IdentTyped::new( - Ident::new("x"), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), - Mutability::Const, - ExecExpr::new(ExecExprKind::new(BaseExec::Ident(Ident::new("exec")))), + let arena = Bump::new(); + + let mut ty_ctx = TyCtx::new(&arena); + let x_ident = Ident::new(&arena, "x"); + let exec_ident = Ident::new(&arena, "exec"); + let exec_expr = ExecExpr::new( + &arena, + ExecExprKind::new(&arena, BaseExec::Ident(exec_ident.clone())), ); - let place = internal::Place::new(x.ident.clone(), vec![]); - ty_ctx.append_ident_typed(x).kill_place(&place); + let scalar_ty = arena.alloc(DataTy::new(&arena, DataTyKind::Scalar(ScalarTy::I32))); + let ty_kind = TyKind::Data(scalar_ty); + let x_typed = IdentTyped::new_in(&arena, "x", Ty::new(ty_kind), Mutability::Const, exec_expr); + + ty_ctx.append_ident_typed(x_typed); + let place = internal::Place::new(x_ident.clone(), BumpVec::new_in(&arena)); + ty_ctx.kill_place(&place, &arena); + assert!(matches!( ty_ctx.idents_typed().next().unwrap().ty.dty(), DataTy { diff --git a/src/ty_check/error.rs b/src/ty_check/error.rs index 69174329..66d2428e 100644 --- a/src/ty_check/error.rs +++ b/src/ty_check/error.rs @@ -1,7 +1,7 @@ use super::Ty; -use crate::ast::internal::Place; -use crate::ast::printer::PrintState; -use crate::ast::{BaseExec, DataTy, Expr, Ident, NatEvalError, Ownership, PlaceExpr, TyKind}; +use crate::arena_ast::internal::Place; +use crate::arena_ast::printer::PrintState; +use crate::arena_ast::{BaseExec, DataTy, Expr, Ident, NatEvalError, Ownership, PlaceExpr, TyKind}; use crate::error; use crate::error::{default_format, ErrorReported}; use crate::parser::SourceCode; @@ -10,16 +10,16 @@ use annotate_snippets::snippet::{Annotation, AnnotationType, Slice, Snippet}; #[must_use] #[derive(Debug)] 
-pub enum TyError { - MultiError(Vec), - MutabilityNotAllowed(Ty), - CtxError(CtxError), - SubTyError(SubTyError), +pub enum TyError<'a> { + MultiError(Vec>), + MutabilityNotAllowed(Ty<'a>), + CtxError(CtxError<'a>), + SubTyError(SubTyError<'a>), // Standard data type mismatch, expected type followed by actual type - MismatchedDataTypes(DataTy, DataTy, Expr), + MismatchedDataTypes(DataTy<'a>, DataTy<'a>, Expr<'a>), // "Trying to violate existing borrow of {:?}.", // p1 under own1 is in conflict because of BorrowingError - ConflictingBorrow(Box, Ownership, BorrowingError), + ConflictingBorrow(Box>, Ownership, BorrowingError<'a>), PrvValueAlreadyInUse(String), // No loan the reference points to has a type that fits the reference element type ReferenceToIncompatibleType, @@ -29,15 +29,15 @@ pub enum TyError { // out from under the reference which is forbidden. ReferenceToDeadTy, // Assignment to a constant place expression. - AssignToConst(PlaceExpr), //, Box), + AssignToConst(PlaceExpr<'a>), //, Box), // Assigning to a view is forbidden AssignToView, // Trying to split a non-view array. 
SplittingNonViewArray, // Expected a different type - ExpectedTupleType(TyKind, PlaceExpr), + ExpectedTupleType(TyKind<'a>, PlaceExpr<'a>), // Trying to borrow uniquely but place is not mutable - ConstBorrow(PlaceExpr), + ConstBorrow(PlaceExpr<'a>), // The borrowed view type is at least paritally dead BorrowingDeadView, IllegalExec, @@ -53,23 +53,23 @@ pub enum TyError { UnexpectedType, // The thread hierarchy dimension referred to does not exist IllegalDimension, - UnifyError(UnifyError), + UnifyError(UnifyError<'a>), MissingMain, - NatEvalError(NatEvalError), - CannotInferGenericArg(Ident), + NatEvalError(NatEvalError<'a>), + CannotInferGenericArg(Ident<'a>), UnsafeRequired, // TODO remove as soon as possible String(String), } -impl<'a> FromIterator for TyError { - fn from_iter>(iter: T) -> Self { +impl<'a> FromIterator> for TyError<'a> { + fn from_iter>>(iter: T) -> Self { TyError::MultiError(iter.into_iter().collect()) } } -impl TyError { - pub fn emit(&self, source: &SourceCode) -> ErrorReported { +impl<'a> TyError<'a> { + pub fn emit(&self, source: &'a SourceCode<'a>) -> ErrorReported { match &self { TyError::MultiError(errs) => { for err in errs { @@ -229,31 +229,31 @@ impl TyError { } } -impl From for TyError { - fn from(err: CtxError) -> Self { +impl<'a> From> for TyError<'a> { + fn from(err: CtxError<'a>) -> Self { TyError::CtxError(err) } } -impl From for TyError { - fn from(err: SubTyError) -> Self { +impl<'a> From> for TyError<'a> { + fn from(err: SubTyError<'a>) -> Self { TyError::SubTyError(err) } } -impl From for TyError { - fn from(err: UnifyError) -> Self { +impl<'a> From> for TyError<'a> { + fn from(err: UnifyError<'a>) -> Self { TyError::UnifyError(err) } } -impl From for TyError { - fn from(err: NatEvalError) -> Self { +impl<'a> From> for TyError<'a> { + fn from(err: NatEvalError<'a>) -> Self { TyError::NatEvalError(err) } } #[must_use] #[derive(Debug)] -pub enum SubTyError { - CtxError(CtxError), +pub enum SubTyError<'a> { + 
 CtxError(CtxError<'a>), // format!("{} lives longer than {}.", shorter, longer) NotOutliving(String, String), // format!("No loans bound to provenance.") @@ -268,79 +268,79 @@ pub enum SubTyError { #[must_use] #[derive(Debug)] -pub enum UnifyError { +pub enum UnifyError<'a> { // Cannot unify the two terms CannotUnify, // A type variable has to be equal to a term that is referring to the same type variable InfiniteType, - SubTyError(SubTyError), + SubTyError(SubTyError<'a>), } -impl From for UnifyError { - fn from(err: SubTyError) -> Self { +impl<'a> From> for UnifyError<'a> { + fn from(err: SubTyError<'a>) -> Self { UnifyError::SubTyError(err) } } #[must_use] #[derive(Debug)] -pub enum CtxError { +pub enum CtxError<'a> { //format!("Identifier: {} not found in context.", ident)), - IdentNotFound(Ident), + IdentNotFound(Ident<'a>), //"Cannot find identifier {} in kinding context", - KindedIdentNotFound(Ident), + KindedIdentNotFound(Ident<'a>), // "Typing Context is missing the provenance value {}", PrvValueNotFound(String), // format!("{} is not declared", prv_rel.longer)); - PrvIdentNotFound(Ident), - // format!("{} is not defined as outliving {}.", l, s) - OutlRelNotDefined(Ident, Ident), + PrvIdentNotFound(Ident<'a>), + // format!("{} is not defined as outliving {}.", l, s) + OutlRelNotDefined(Ident<'a>, Ident<'a>), // TODO move to TyError IllegalProjection, } -impl From for SubTyError { - fn from(err: CtxError) -> Self { +impl<'a> From> for SubTyError<'a> { + fn from(err: CtxError<'a>) -> Self { SubTyError::CtxError(err) } } #[must_use] #[derive(Debug)] -pub enum BorrowingError { +pub enum BorrowingError<'a> { Conflict { - checked: PlaceExpr, - existing: PlaceExpr, + checked: PlaceExpr<'a>, + existing: PlaceExpr<'a>, }, - CtxError(CtxError), + CtxError(CtxError<'a>), // "Trying to use place expression with {} capability while it refers to a \ // loan with {} capability.", // checked_own, ref_own ConflictingOwnership, ConflictingAccess, // The borrowing place 
is not in the reborrow list - BorrowNotInReborrowList(Place), + BorrowNotInReborrowList(Place<'a>), TemporaryConflictingBorrow(String), - WrongDevice(BaseExec, BaseExec), + WrongDevice(BaseExec<'a>, BaseExec<'a>), MultipleDistribs, CannotNarrow, DivergingExec, - TyError(Box), - NatEvalError(NatEvalError), + TyError(Box>), + NatEvalError(NatEvalError<'a>), } -impl From for BorrowingError { - fn from(err: TyError) -> Self { +impl<'a> From> for BorrowingError<'a> { + fn from(err: TyError<'a>) -> Self { BorrowingError::TyError(Box::new(err)) } } -impl From for BorrowingError { - fn from(err: CtxError) -> Self { +impl<'a> From> for BorrowingError<'a> { + fn from(err: CtxError<'a>) -> Self { BorrowingError::CtxError(err) } } -impl From for BorrowingError { - fn from(err: NatEvalError) -> Self { +impl<'a> From> for BorrowingError<'a> { + fn from(err: NatEvalError<'a>) -> Self { BorrowingError::NatEvalError(err) } } diff --git a/src/ty_check/exec.rs b/src/ty_check/exec.rs index cb031d39..a352bfe6 100644 --- a/src/ty_check/exec.rs +++ b/src/ty_check/exec.rs @@ -2,94 +2,107 @@ use super::{ BaseExec, BinOpNat, Dim, Dim1d, Dim2d, DimCompo, ExecExpr, ExecPathElem, ExecTy, ExecTyKind, IdentExec, Nat, TyCtx, TyError, TyResult, }; -use crate::ast::{LeftOrRight, NatCtx}; +use crate::arena_ast::{ExecExprKind, LeftOrRight, NatCtx}; +use bumpalo::{collections::Vec as BumpVec, Bump}; -pub(super) fn ty_check( - nat_ctx: &NatCtx, - ty_ctx: &TyCtx, - ident_exec: Option<&IdentExec>, - exec_expr: &mut ExecExpr, -) -> TyResult<()> { - let mut exec_ty = match &exec_expr.exec.base { +pub(super) fn ty_check<'a>( + nat_ctx: &'a NatCtx<'a>, + ty_ctx: &'a TyCtx<'a>, + ident_exec: Option<&'a IdentExec<'a>>, + exec_expr: &mut ExecExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, ()> { + // 1) compute the base kind + let exec_kind = match &exec_expr.exec.base { BaseExec::Ident(ident) => { if let Some(ie) = ident_exec { if ident == &ie.ident { ie.ty.ty.clone() } else { - let inline_exec = 
ty_ctx.get_exec_expr_for_exec_ident(ident)?; - inline_exec.ty.as_ref().unwrap().ty.clone() + let inline = ty_ctx.get_exec_expr_for_exec_ident(ident)?; + inline.ty.as_ref().unwrap().ty.clone() } } else { return Err(TyError::IllegalExec); } } BaseExec::CpuThread => ExecTyKind::CpuThread, - BaseExec::GpuGrid(gdim, bdim) => ExecTyKind::GpuGrid(gdim.clone(), bdim.clone()), + BaseExec::GpuGrid(gdim, bdim) => ExecTyKind::GpuGrid((**gdim).clone(), (**bdim).clone()), }; - for e in &exec_expr.exec.path { - match e { - ExecPathElem::ForAll(d) => { - exec_ty = ty_check_exec_forall(*d, &exec_ty)?; - } - ExecPathElem::TakeRange(exec_split) => { - exec_ty = ty_check_exec_take_range( - exec_split.split_dim, - &exec_split.pos, - exec_split.left_or_right, - &exec_ty, - )?; - } - ExecPathElem::ToThreads(d) => { - exec_ty = ty_check_exec_to_threads(*d, &exec_ty)?; + // 2) Bump‐allocate that base kind so we have &'a ExecTyKind + let mut kind_ref: &'a ExecTyKind = &arena.alloc(ExecTy::new(exec_kind.clone())).ty; + + // 3) For each step, work entirely with arena‐owned refs + for step in &exec_expr.exec.path { + // call the helper on an arena‐live reference + let next_kind: ExecTyKind = match step { + ExecPathElem::ForAll(d) => ty_check_exec_forall(*d, kind_ref, arena)?, + ExecPathElem::TakeRange(sp) => { + ty_check_exec_take_range(sp.split_dim, &sp.pos, sp.left_or_right, kind_ref, arena)? 
} - ExecPathElem::ToWarps => exec_ty = ty_check_exec_to_warps(nat_ctx, &exec_ty)?, - } + ExecPathElem::ToThreads(d) => ty_check_exec_to_threads(*d, kind_ref, arena)?, + ExecPathElem::ToWarps => ty_check_exec_to_warps(nat_ctx, kind_ref, arena)?, + }; + // Now bump‐allocate the returned kind, update our ref + let boxed = arena.alloc(ExecTy::new(next_kind)); + kind_ref = &boxed.ty; } - exec_expr.ty = Some(Box::new(ExecTy::new(exec_ty))); + + // 4) Finally write out the fully‐elaborated type + exec_expr.ty = Some(arena.alloc(ExecTy::new(kind_ref.clone()))); + Ok(()) } -fn ty_check_exec_to_threads(d: DimCompo, exec_ty: &ExecTyKind) -> TyResult { +fn ty_check_exec_to_threads<'a>( + d: DimCompo, + exec_ty: &'a ExecTyKind<'a>, + arena: &'a Bump, +) -> TyResult<'a, ExecTyKind<'a>> { if let ExecTyKind::GpuGrid(gdim, bdim) = exec_ty { - let (rest_gdim, rem_gdim) = remove_dim(gdim, d)?; - let (rest_bdim, rem_bdim) = remove_dim(bdim, d)?; - let global_dim = match (rem_gdim, rem_bdim) { - (Dim::X(g), Dim::X(b)) => Dim::X(Box::new(Dim1d(Nat::BinOp( - BinOpNat::Mul, - Box::new(g.0), - Box::new(b.0), - )))), - (Dim::Y(g), Dim::Y(b)) => Dim::Y(Box::new(Dim1d(Nat::BinOp( - BinOpNat::Mul, - Box::new(g.0), - Box::new(b.0), - )))), - (Dim::Z(g), Dim::Z(b)) => Dim::Z(Box::new(Dim1d(Nat::BinOp( - BinOpNat::Mul, - Box::new(g.0), - Box::new(b.0), - )))), + let (rest_gdim, rem_gdim) = remove_dim(gdim, d, arena)?; + let (rest_bdim, rem_bdim) = remove_dim(bdim, d, arena)?; + + let global_dim: Dim<'a> = match (rem_gdim, rem_bdim) { + (Dim::X(g), Dim::X(b)) => { + let combined = Nat::new_binop_ref(arena, BinOpNat::Mul, g.0.clone(), b.0.clone()); + // use a zero-capture closure so that its inferred lifetime is `'a` + Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), combined) + } + (Dim::Y(g), Dim::Y(b)) => { + let combined = Nat::new_binop_ref(arena, BinOpNat::Mul, g.0.clone(), b.0.clone()); + Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), combined) + } + (Dim::Z(g), Dim::Z(b)) => { + let 
combined = Nat::new_binop_ref(arena, BinOpNat::Mul, g.0.clone(), b.0.clone()); + Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Z(d1), combined) + } _ => { return Err(TyError::String(format!( - "Provided dimension {} does not exist", - d + "Cannot thread-map dimension {:?} on {:?}/{:?}", + d, gdim, bdim ))) } }; - match (rest_gdim, rest_bdim) { - (Some(rest_gdim), Some(rest_bdim)) => Ok(ExecTyKind::GpuToThreads( - global_dim, - Box::new(ExecTy::new(ExecTyKind::GpuBlockGrp(rest_gdim, rest_bdim))), - )), - _ => unimplemented!(), + + if let (Some(rg), Some(rb)) = (rest_gdim, rest_bdim) { + let inner = ExecTyKind::GpuBlockGrp(rg, rb); + let inner_ty: &'a ExecTy<'a> = arena.alloc(ExecTy::new(inner)); + Ok(ExecTyKind::GpuToThreads(global_dim, inner_ty)) + } else { + Err(TyError::UnexpectedType) } } else { Err(TyError::UnexpectedType) } } -fn ty_check_exec_to_warps(nat_ctx: &NatCtx, exec_ty: &ExecTyKind) -> TyResult { +/** +fn ty_check_exec_to_warps<'a>( + nat_ctx: &NatCtx, + exec_ty: &'a ExecTyKind<'a>, +) -> TyResult<'a, ExecTyKind<'a>> { match exec_ty { ExecTyKind::GpuBlock(dim) => match dim.clone() { Dim::X(d) => { @@ -117,32 +130,72 @@ fn ty_check_exec_to_warps(nat_ctx: &NatCtx, exec_ty: &ExecTyKind) -> TyResult( + nat_ctx: &'a NatCtx<'a>, + exec_ty: &'a ExecTyKind<'a>, + arena: &'a Bump, +) -> TyResult<'a, ExecTyKind<'a>> { + // Only valid if we're looking at a single‐dimension block + if let ExecTyKind::GpuBlock(block_dim) = exec_ty { + match block_dim { + Dim::X(d1) | Dim::Y(d1) | Dim::Z(d1) => { + // d1: &Dim1d<'a>, so d1.0: Nat<'a> + let len = d1.0.eval(nat_ctx)?; + if len % 32 != 0 { + return Err(TyError::String(format!( + "Block size must be divisible by 32 to form warps, got {} in {:?}", + len, exec_ty + ))); + } + // compute len / 32 in the arena + let warp_nat: Nat<'a> = + Nat::new_binop_ref(arena, BinOpNat::Div, d1.0.clone(), Nat::Lit(32)); + Ok(ExecTyKind::GpuWarpGrp(warp_nat)) + } + _ => Err(TyError::String(format!( + "GpuBlock must be 1-D to form warps, 
got {:?}", + exec_ty + ))), + } + } else { + Err(TyError::String(format!( + "Cannot form warps from non-GpuBlock type {:?}", + exec_ty + ))) + } +} -fn ty_check_exec_forall(d: DimCompo, exec_ty: &ExecTyKind) -> TyResult { +fn ty_check_exec_forall<'a>( + d: DimCompo, + exec_ty: &'a ExecTyKind<'a>, + arena: &'a Bump, +) -> TyResult<'a, ExecTyKind<'a>> { let res_ty = match exec_ty { ExecTyKind::GpuGrid(gdim, bdim) => { - let inner_dim = remove_dim(gdim, d)?.0; + let inner_dim = remove_dim(gdim, d, arena)?.0; match inner_dim { Some(dim) => ExecTyKind::GpuGrid(dim, bdim.clone()), None => ExecTyKind::GpuBlock(bdim.clone()), } } ExecTyKind::GpuBlockGrp(gdim, bdim) => { - let inner_dim = remove_dim(gdim, d)?.0; + let inner_dim = remove_dim(gdim, d, arena)?.0; match inner_dim { Some(dim) => ExecTyKind::GpuBlockGrp(dim, bdim.clone()), None => ExecTyKind::GpuBlock(bdim.clone()), } } ExecTyKind::GpuBlock(bdim) => { - let inner_dim = remove_dim(bdim, d)?.0; + let inner_dim = remove_dim(bdim, d, arena)?.0; match inner_dim { Some(dim) => ExecTyKind::GpuBlock(dim), None => ExecTyKind::GpuThread, } } ExecTyKind::GpuThreadGrp(tdim) => { - let inner_dim = remove_dim(tdim, d)?.0; + let inner_dim = remove_dim(tdim, d, arena)?.0; match inner_dim { Some(dim) => ExecTyKind::GpuThreadGrp(dim), None => ExecTyKind::GpuThread, @@ -154,8 +207,10 @@ fn ty_check_exec_forall(d: DimCompo, exec_ty: &ExecTyKind) -> TyResult = ExecTy::new(forall_inner); + let new_ref: &'a ExecTy<'a> = arena.alloc(new_ty); + ExecTyKind::GpuToThreads(dim.clone(), new_ref) } } ex @ ExecTyKind::CpuThread | ex @ ExecTyKind::GpuThread | ex @ ExecTyKind::Any => { @@ -165,7 +220,11 @@ fn ty_check_exec_forall(d: DimCompo, exec_ty: &ExecTyKind) -> TyResult TyResult<(Option, Dim)> { +/** +pub fn remove_dim<'a>( + dim: &'a Dim<'a>, + dim_compo: DimCompo, +) -> TyResult<'a, (Option>, Dim<'a>)> { match (dim, dim_compo) { (Dim::XYZ(dim3d), DimCompo::X) => Ok(( Some(Dim::YZ(Box::new(Dim2d( @@ -218,73 +277,183 @@ pub fn 
remove_dim(dim: &Dim, dim_compo: DimCompo) -> TyResult<(Option, Dim) _ => Err(TyError::IllegalDimension), } } +*/ + +pub fn remove_dim<'a>( + dim: &'a Dim<'a>, + dim_compo: DimCompo, + arena: &'a Bump, +) -> TyResult<'a, (Option>, Dim<'a>)> { + use DimCompo::*; + let result = match (dim, dim_compo) { + // 3D → leftover 2D + removed 1D + (Dim::XYZ(d3), X) => { + let rest = Dim::new_2d( + arena, + |d2: &'a Dim2d<'a>| Dim::YZ(d2), + d3.1.clone(), + d3.2.clone(), + ); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), d3.0.clone()); + (Some(rest), rem) + } + (Dim::XYZ(d3), Y) => { + let rest = Dim::new_2d( + arena, + |d2: &'a Dim2d<'a>| Dim::XZ(d2), + d3.0.clone(), + d3.2.clone(), + ); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), d3.1.clone()); + (Some(rest), rem) + } + (Dim::XYZ(d3), Z) => { + let rest = Dim::new_2d( + arena, + |d2: &'a Dim2d<'a>| Dim::XY(d2), + d3.0.clone(), + d3.1.clone(), + ); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Z(d1), d3.2.clone()); + (Some(rest), rem) + } + + // 2D → leftover 1D + removed 1D + (Dim::XY(d2), X) => { + let rest = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), d2.1.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), d2.0.clone()); + (Some(rest), rem) + } + (Dim::XY(d2), Y) => { + let rest = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), d2.0.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), d2.1.clone()); + (Some(rest), rem) + } -fn ty_check_exec_take_range( + (Dim::XZ(d2), X) => { + let rest = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Z(d1), d2.1.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), d2.0.clone()); + (Some(rest), rem) + } + (Dim::XZ(d2), Z) => { + let rest = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::X(d1), d2.0.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Z(d1), d2.1.clone()); + (Some(rest), rem) + } + + (Dim::YZ(d2), Y) => { + let rest = Dim::new_1d(arena, |d1: &'a 
Dim1d<'a>| Dim::Z(d1), d2.1.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), d2.0.clone()); + (Some(rest), rem) + } + (Dim::YZ(d2), Z) => { + let rest = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Y(d1), d2.0.clone()); + let rem = Dim::new_1d(arena, |d1: &'a Dim1d<'a>| Dim::Z(d1), d2.1.clone()); + (Some(rest), rem) + } + + // 1D → nothing + same 1D + (Dim::X(d1), X) => ( + None, + Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::X(d), d1.0.clone()), + ), + (Dim::Y(d1), Y) => ( + None, + Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Y(d), d1.0.clone()), + ), + (Dim::Z(d1), Z) => ( + None, + Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Z(d), d1.0.clone()), + ), + + // anything else is illegal + _ => return Err(TyError::IllegalDimension), + }; + + Ok(result) +} + +fn ty_check_exec_take_range<'a>( d: DimCompo, - n: &Nat, + n: &'a Nat<'a>, proj: LeftOrRight, - exec_ty: &ExecTyKind, -) -> TyResult { - // TODO check well-formedness of Nats - let (lexec_ty, rexec_ty) = match exec_ty { + exec_ty: &'a ExecTyKind<'a>, + arena: &'a Bump, +) -> TyResult<'a, ExecTyKind<'a>> { + // split into two ExecTyKind variants, left & right + let (left_ty, right_ty) = match exec_ty { + // Grid or BlockGrp: we split the grid‐dim, keep block‐dim the same ExecTyKind::GpuGrid(gdim, bdim) | ExecTyKind::GpuBlockGrp(gdim, bdim) => { - let (ldim, rdim) = split_dim(d, n.clone(), gdim.clone())?; - ( - ExecTyKind::GpuBlockGrp(ldim, bdim.clone()), - ExecTyKind::GpuBlockGrp(rdim, bdim.clone()), - ) + let (ldim, rdim) = split_dim(d, n, gdim.clone(), arena)?; + let left = ExecTyKind::GpuBlockGrp(ldim, bdim.clone()); + let right = ExecTyKind::GpuBlockGrp(rdim, bdim.clone()); + (left, right) } + + // Block or ThreadGrp: split the single dim ExecTyKind::GpuBlock(dim) | ExecTyKind::GpuThreadGrp(dim) => { - let (ldim, rdim) = split_dim(d, n.clone(), dim.clone())?; - ( - ExecTyKind::GpuThreadGrp(ldim), - ExecTyKind::GpuThreadGrp(rdim), - ) + let (ldim, rdim) = split_dim(d, n, dim.clone(), arena)?; 
+ let left = ExecTyKind::GpuThreadGrp(ldim); + let right = ExecTyKind::GpuThreadGrp(rdim); + (left, right) } - ExecTyKind::GpuToThreads(dim, inner) => { + + // ToThreads: either it splits on the ToThreads dim, or we descend into the inner ExecTy + ExecTyKind::GpuToThreads(dim, inner_ref) => { if dim_compo_matches_dim(d, dim) { - let (ldim, rdim) = split_dim(d, n.clone(), dim.clone())?; + // slice the ToThreads dimension itself + let (ldim, rdim) = split_dim(d, n, dim.clone(), arena)?; ( - ExecTyKind::GpuToThreads(ldim, inner.clone()), - ExecTyKind::GpuToThreads(rdim, inner.clone()), + ExecTyKind::GpuToThreads(ldim, inner_ref.clone()), + ExecTyKind::GpuToThreads(rdim, inner_ref.clone()), ) - } else if let ExecTyKind::GpuBlockGrp(gdim, bdim) = &inner.ty { - let (ldim, rdim) = split_dim(d, n.clone(), gdim.clone())?; + } else if let ExecTyKind::GpuBlockGrp(gdim2, bdim2) = &inner_ref.ty { + // otherwise split inside the inner block‐group + let (ldim, rdim) = split_dim(d, n, gdim2.clone(), arena)?; + + // bump‐allocate the two new inner ExecTy values + let left_inner = ExecTyKind::GpuBlockGrp(ldim, bdim2.clone()); + let right_inner = ExecTyKind::GpuBlockGrp(rdim, bdim2.clone()); + let left_ref: &'a ExecTy<'a> = arena.alloc(ExecTy::new(left_inner)); + let right_ref: &'a ExecTy<'a> = arena.alloc(ExecTy::new(right_inner)); + ( - ExecTyKind::GpuToThreads( - dim.clone(), - Box::new(ExecTy::new(ExecTyKind::GpuBlockGrp(ldim, bdim.clone()))), - ), - ExecTyKind::GpuToThreads( - dim.clone(), - Box::new(ExecTy::new(ExecTyKind::GpuBlockGrp(rdim, bdim.clone()))), - ), + ExecTyKind::GpuToThreads(dim.clone(), left_ref), + ExecTyKind::GpuToThreads(dim.clone(), right_ref), ) } else { panic!("GpuToThreads is not well-formed.") } } - ex => { + + other => { return Err(TyError::String(format!( "Trying to split non-splittable execution resource: {:?}", - ex + other ))) } }; + + // pick the left or right projection Ok(if proj == LeftOrRight::Left { - lexec_ty + left_ty } else { - rexec_ty + 
right_ty }) } -fn dim_compo_matches_dim(d: DimCompo, dim: &Dim) -> bool { +fn dim_compo_matches_dim<'a>(d: DimCompo, dim: &'a Dim<'a>) -> bool { (matches!(dim, Dim::X(_)) && d == DimCompo::X) | (matches!(dim, Dim::Y(_)) && d == DimCompo::Y) | (matches!(dim, Dim::Z(_)) && d == DimCompo::Z) } -fn split_dim(split_dim: DimCompo, pos: Nat, dim: Dim) -> TyResult<(Dim, Dim)> { +/** +fn split_dim<'a>( + split_dim: DimCompo, + pos: Nat<'a>, + dim: Dim<'a>, +) -> TyResult<'a, (Dim<'a>, Dim<'a>)> { Ok(match dim { Dim::XYZ(d) => match split_dim { DimCompo::X => ( @@ -410,8 +579,176 @@ fn split_dim(split_dim: DimCompo, pos: Nat, dim: Dim) -> TyResult<(Dim, Dim)> { } }) } +*/ + +pub fn split_dim<'a>( + split_dim: DimCompo, + pos: &'a Nat<'a>, + dim: Dim<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Dim<'a>, Dim<'a>)> { + use DimCompo::*; + + let result = match dim { + // 3D case: peel off one axis, leave a 3D on both sides + Dim::XYZ(d3) => match split_dim { + X => { + let left = Dim::new_3d(arena, d3.0.clone(), d3.1.clone(), d3.2.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d3.0.clone(), pos.clone()); + let right = Dim::new_3d(arena, right_nat, d3.1.clone(), d3.2.clone()); + (left, right) + } + Y => { + let left = Dim::new_3d(arena, d3.0.clone(), d3.1.clone(), d3.2.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d3.1.clone(), pos.clone()); + let right = Dim::new_3d(arena, d3.0.clone(), right_nat, d3.2.clone()); + (left, right) + } + Z => { + let left = Dim::new_3d(arena, d3.0.clone(), d3.1.clone(), d3.2.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d3.2.clone(), pos.clone()); + let right = Dim::new_3d(arena, d3.0.clone(), d3.1.clone(), right_nat); + (left, right) + } + }, + + // 2D cases: peel off one axis, leave a 2D on both sides + Dim::XY(d2) => match split_dim { + X => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XY(d), + pos.clone(), + d2.1.clone(), + ); + let right_nat = 
Nat::new_binop_ref(arena, BinOpNat::Sub, d2.0.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XY(d), + right_nat, + d2.1.clone(), + ); + (left, right) + } + Y => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XY(d), + d2.0.clone(), + pos.clone(), + ); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d2.1.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XY(d), + d2.0.clone(), + right_nat, + ); + (left, right) + } + Z => return Err(TyError::IllegalDimension), + }, + + Dim::XZ(d2) => match split_dim { + X => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XZ(d), + pos.clone(), + d2.1.clone(), + ); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d2.0.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XZ(d), + right_nat, + d2.1.clone(), + ); + (left, right) + } + Z => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XZ(d), + d2.0.clone(), + pos.clone(), + ); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d2.1.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::XZ(d), + d2.0.clone(), + right_nat, + ); + (left, right) + } + Y => return Err(TyError::IllegalDimension), + }, + + Dim::YZ(d2) => match split_dim { + Y => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::YZ(d), + pos.clone(), + d2.1.clone(), + ); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d2.0.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::YZ(d), + right_nat, + d2.1.clone(), + ); + (left, right) + } + Z => { + let left = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::YZ(d), + d2.0.clone(), + pos.clone(), + ); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d2.1.clone(), pos.clone()); + let right = Dim::new_2d( + arena, + |d: &'a Dim2d<'a>| Dim::YZ(d), + d2.0.clone(), + right_nat, + ); + (left, right) + } + X => return 
Err(TyError::IllegalDimension), + }, -pub(super) fn normalize(mut exec: ExecExpr) -> ExecExpr { + // 1D cases: peel off the only axis + Dim::X(d1) if split_dim == X => { + let left = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::X(d), pos.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d1.0.clone(), pos.clone()); + let right = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::X(d), right_nat); + (left, right) + } + Dim::Y(d1) if split_dim == Y => { + let left = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Y(d), pos.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d1.0.clone(), pos.clone()); + let right = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Y(d), right_nat); + (left, right) + } + Dim::Z(d1) if split_dim == Z => { + let left = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Z(d), pos.clone()); + let right_nat = Nat::new_binop_ref(arena, BinOpNat::Sub, d1.0.clone(), pos.clone()); + let right = Dim::new_1d(arena, |d: &'a Dim1d<'a>| Dim::Z(d), right_nat); + (left, right) + } + + _ => return Err(TyError::IllegalDimension), + }; + + Ok(result) +} + +/** +pub(super) fn normalize<'a>(mut exec: ExecExpr<'a>) -> ExecExpr<'a> { assert!(exec.ty.is_some()); let mut exec_path = exec.exec.path; if !exec_path.is_empty() { @@ -421,10 +758,42 @@ pub(super) fn normalize(mut exec: ExecExpr) -> ExecExpr { exec.exec.path = exec_path; exec } +*/ + +pub(super) fn normalize<'a>(old: ExecExpr<'a>, arena: &'a bumpalo::Bump) -> ExecExpr<'a> { + assert!(old.ty.is_some()); + + // 1) Copy the old path into a fresh bump-Vec + let mut new_path: BumpVec<'a, ExecPathElem<'a>> = BumpVec::new_in(arena); + new_path.extend(old.exec.path.iter().cloned()); + + // 2) Sort if needed + if !new_path.is_empty() { + // We can still use the slice-based sorter: + let slice = new_path.as_mut_slice(); + let boundaries = level_boundaries(slice); + sort_within_boundaries(slice, &boundaries); + } + + // 3) Build a new ExecExprKind<'a> in the arena + let new_exec_kind = ExecExprKind { + 
base: old.exec.base.clone(), + path: new_path, + }; + let new_exec_ref: &'a ExecExprKind<'a> = arena.alloc(new_exec_kind); + + // 4) Return a fresh ExecExpr pointing at it, but re‐use the old ty/span + ExecExpr { + exec: new_exec_ref, + ty: old.ty, + span: old.span, + } +} +// Allocating a handful of usizes on the heap here is negligible // FIXME: not correct if first take_range on dimension of lower level followed by forall on different dimension in upper level // for fix: see formalism -fn level_boundaries(exec_path: &[ExecPathElem]) -> Vec { +fn level_boundaries<'a>(exec_path: &'a [ExecPathElem<'a>]) -> Vec { let mut forall_dims_encountered = Vec::with_capacity(3); let mut boundaries = Vec::with_capacity(3); for (i, elem) in exec_path.iter().enumerate() { @@ -454,7 +823,8 @@ fn level_boundaries(exec_path: &[ExecPathElem]) -> Vec { boundaries } -fn sort_within_boundaries(exec_path: &mut Vec, boundaries: &[usize]) { +/** +fn sort_within_boundaries<'a>(exec_path: &'a mut Vec>, boundaries: &[usize]) { let mut lower_bound = 0; for b in boundaries { for i in lower_bound..*b { @@ -467,8 +837,23 @@ fn sort_within_boundaries(exec_path: &mut Vec, boundaries: &[usize lower_bound = *b; } } +*/ + +fn sort_within_boundaries<'a>(exec_path: &mut [ExecPathElem<'a>], boundaries: &[usize]) { + let mut lower = 0; + for &b in boundaries { + for i in lower..b { + for j in lower..(b - 1) { + if swappable_exec_path_elems(&exec_path[j], &exec_path[j + 1]) { + exec_path.swap(j, j + 1); + } + } + } + lower = b; + } +} -fn swappable_exec_path_elems(lhs: &ExecPathElem, rhs: &ExecPathElem) -> bool { +fn swappable_exec_path_elems<'a>(lhs: &'a ExecPathElem<'a>, rhs: &'a ExecPathElem<'a>) -> bool { match (lhs, rhs) { (ExecPathElem::ForAll(dl), ExecPathElem::ForAll(dr)) => dl > dr, (ExecPathElem::ForAll(_), ExecPathElem::TakeRange(_)) => true, diff --git a/src/ty_check/infer_kinded_args.rs b/src/ty_check/infer_kinded_args.rs index 94bb916c..56008e1b 100644 --- a/src/ty_check/infer_kinded_args.rs 
+++ b/src/ty_check/infer_kinded_args.rs @@ -1,8 +1,9 @@ use super::{TyError, TyResult}; -use crate::ast::{ +use crate::arena_ast::{ ArgKinded, BaseExec, DataTy, DataTyKind, Dim, ExecExpr, ExecTy, ExecTyKind, FnTy, Ident, Memory, Nat, ParamSig, Provenance, Ty, TyKind, }; +use bumpalo::collections::Vec as BumpVec; use std::collections::HashMap; // mono_ty is function type, @@ -10,30 +11,41 @@ use std::collections::HashMap; // introduced by the polymorphic function, therefore finding an identifier on the poly type // means that it was introduced by the polymorphic function (even though the identifier may be an // instantiation of a bound identifier -pub fn infer_kinded_args(poly_fn_ty: &FnTy, mono_fn_ty: &FnTy) -> TyResult> { +pub fn infer_kinded_args<'a>( + poly_fn_ty: &'a FnTy<'a>, + mono_fn_ty: &'a FnTy<'a>, + arena: &'a bumpalo::Bump, +) -> TyResult<'a, BumpVec<'a, ArgKinded<'a>>> { if poly_fn_ty.param_sigs.len() != mono_fn_ty.param_sigs.len() { - panic!("Unexpected difference in amount of paramters.") + panic!("Unexpected difference in amount of parameters."); } - let mut res_map = HashMap::new(); - for (subst_ty, mono_ty) in poly_fn_ty.param_sigs.iter().zip(&mono_fn_ty.param_sigs) { - infer_kargs_param_sig(&mut res_map, subst_ty, mono_ty) + + let mut res_map: HashMap<&'a Ident<'a>, ArgKinded<'a>> = HashMap::new(); + + for (subst_ps, mono_ps) in poly_fn_ty.param_sigs.iter().zip(&mono_fn_ty.param_sigs) { + infer_kargs_param_sig(&mut res_map, subst_ps, mono_ps); } + infer_kargs_exec_expr(&mut res_map, &poly_fn_ty.exec, &mono_fn_ty.exec); - infer_kargs_tys(&mut res_map, &poly_fn_ty.ret_ty, &mono_fn_ty.ret_ty); - let mut res_vec = Vec::new(); - for gen_arg in &poly_fn_ty.generics { - // FIXME unwrap leads to panic when the value for ident could not be inferred - // as does happen when the identifier is not used in the argument type or part of - // an expression in the case of nats - if let Some(res_karg) = res_map.get(&gen_arg.ident) { - if gen_arg.kind != 
res_karg.kind() { - panic!("Unexpected: Kinds of identifier and argument do not match.") + infer_kargs_tys(&mut res_map, poly_fn_ty.ret_ty, mono_fn_ty.ret_ty); + + let mut res_vec = BumpVec::new_in(arena); + res_vec.reserve(poly_fn_ty.generics.len()); + + for g in &poly_fn_ty.generics { + match res_map.remove(&g.ident) { + Some(arg) => { + if g.kind != arg.kind() { + panic!("Unexpected: Kinds of identifier and argument do not match."); + } + res_vec.push(arg); + } + None => { + return Err(TyError::CannotInferGenericArg(g.ident.clone())); } - res_vec.push(res_karg.clone()); - } else { - return Err(TyError::CannotInferGenericArg(gen_arg.ident.clone())); } } + Ok(res_vec) } @@ -48,7 +60,7 @@ macro_rules! infer_from_lists { macro_rules! insert_checked { ($map: expr, $constr: path, $id_ref: expr, $mono_ref: expr) => {{ let arg_kinded = $constr($mono_ref.clone()); - if let Some(old) = $map.insert($id_ref.clone(), arg_kinded.clone()) { + if let Some(old) = $map.insert($id_ref, arg_kinded.clone()) { if old != arg_kinded { println!("old: {:?}", old); println!("new: {:?}", arg_kinded); @@ -63,6 +75,7 @@ macro_rules! panic_no_inst { panic!("Unexpected: mono type is not an instantiation of poly type") }; } + macro_rules! panic_if_neq { ($lhs: expr, $rhs: expr) => { if $lhs != $rhs { @@ -71,7 +84,11 @@ macro_rules! 
panic_if_neq { }; } -fn infer_kargs_tys(map: &mut HashMap, poly_ty: &Ty, mono_ty: &Ty) { +fn infer_kargs_tys<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_ty: &'a Ty<'a>, + mono_ty: &'a Ty<'a>, +) { match (&poly_ty.ty, &mono_ty.ty) { (TyKind::Data(dty1), TyKind::Data(dty2)) => infer_kargs_dtys(map, dty1, dty2), (TyKind::FnTy(fn_ty1), TyKind::FnTy(fn_ty2)) => { @@ -94,19 +111,19 @@ fn infer_kargs_tys(map: &mut HashMap, poly_ty: &Ty, mono_ty: & } } -fn infer_kargs_param_sig( - map: &mut HashMap, - poly_param_sig: &ParamSig, - mono_param_sig: &ParamSig, +fn infer_kargs_param_sig<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_param_sig: &'a ParamSig<'a>, + mono_param_sig: &'a ParamSig<'a>, ) { infer_kargs_exec_expr(map, &poly_param_sig.exec_expr, &mono_param_sig.exec_expr); infer_kargs_tys(map, &poly_param_sig.ty, &mono_param_sig.ty); } -fn infer_kargs_exec_expr( - map: &mut HashMap, - poly_exec_expr: &ExecExpr, - mono_exec_expr: &ExecExpr, +fn infer_kargs_exec_expr<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_exec_expr: &'a ExecExpr<'a>, + mono_exec_expr: &'a ExecExpr<'a>, ) { match (&poly_exec_expr.exec.base, &mono_exec_expr.exec.base) { (BaseExec::Ident(i1), BaseExec::Ident(i2)) if i1 == i2 => (), @@ -119,10 +136,10 @@ fn infer_kargs_exec_expr( } } -fn infer_kargs_exec_level( - map: &mut HashMap, - poly_exec_level: &ExecTy, - mono_exec_level: &ExecTy, +fn infer_kargs_exec_level<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_exec_level: &'a ExecTy<'a>, + mono_exec_level: &'a ExecTy<'a>, ) { match (&poly_exec_level.ty, &mono_exec_level.ty) { (ExecTyKind::GpuGrid(gdim1, bdim1), ExecTyKind::GpuGrid(gdim2, bdim2)) @@ -141,7 +158,11 @@ fn infer_kargs_exec_level( } } -fn infer_kargs_dims(map: &mut HashMap, poly_dim: &Dim, mono_dim: &Dim) { +fn infer_kargs_dims<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_dim: &'a Dim<'a>, + mono_dim: &'a Dim<'a>, +) { 
match (poly_dim, mono_dim) { (Dim::XYZ(d3d1), Dim::XYZ(d3d2)) => { infer_kargs_nats(map, &d3d1.0, &d3d2.0); @@ -163,15 +184,19 @@ fn infer_kargs_dims(map: &mut HashMap, poly_dim: &Dim, mono_di } } -fn infer_kargs_field( - map: &mut HashMap, - poly_field: &(Ident, DataTy), - mono_field: &(Ident, DataTy), +fn infer_kargs_field<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_field: &'a (Ident<'a>, DataTy<'a>), + mono_field: &'a (Ident<'a>, DataTy<'a>), ) { infer_kargs_dtys(map, &poly_field.1, &mono_field.1) } -fn infer_kargs_dtys(map: &mut HashMap, poly_dty: &DataTy, mono_dty: &DataTy) { +fn infer_kargs_dtys<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_dty: &'a DataTy<'a>, + mono_dty: &'a DataTy<'a>, +) { match (&poly_dty.dty, &mono_dty.dty) { (DataTyKind::Ident(id), _) => insert_checked!(map, ArgKinded::DataTy, id, mono_dty), (DataTyKind::Scalar(sty1), DataTyKind::Scalar(sty2)) => { @@ -215,12 +240,14 @@ fn infer_kargs_dtys(map: &mut HashMap, poly_dty: &DataTy, mono } } -fn infer_kargs_nats(map: &mut HashMap, poly_nat: &Nat, mono_nat: &Nat) { +fn infer_kargs_nats<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_nat: &'a Nat<'a>, + mono_nat: &'a Nat<'a>, +) { match (poly_nat, mono_nat) { (Nat::Ident(id), _) => { - if let Some(ArgKinded::Nat(old)) = - map.insert(id.clone(), ArgKinded::Nat(mono_nat.clone())) - { + if let Some(ArgKinded::Nat(old)) = map.insert(id, ArgKinded::Nat(mono_nat.clone())) { if &old != mono_nat { panic!( "not able to check equality of Nats `{}` and `{}`", @@ -243,17 +270,21 @@ fn infer_kargs_nats(map: &mut HashMap, poly_nat: &Nat, mono_na } } -fn infer_kargs_mems(map: &mut HashMap, poly_mem: &Memory, mono_mem: &Memory) { +fn infer_kargs_mems<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_mem: &'a Memory<'a>, + mono_mem: &'a Memory<'a>, +) { match (poly_mem, mono_mem) { (Memory::Ident(id), _) => insert_checked!(map, ArgKinded::Memory, id, mono_mem), _ => 
panic_if_neq!(poly_mem, mono_mem), } } -fn infer_kargs_prvs( - map: &mut HashMap, - poly_prv: &Provenance, - mono_prv: &Provenance, +fn infer_kargs_prvs<'m, 'a>( + map: &'m mut HashMap<&'a Ident<'a>, ArgKinded<'a>>, + poly_prv: &'a Provenance<'a>, + mono_prv: &'a Provenance<'a>, ) { match (poly_prv, mono_prv) { (Provenance::Ident(id), _) => insert_checked!(map, ArgKinded::Provenance, id, mono_prv), diff --git a/src/ty_check/mod.rs b/src/ty_check/mod.rs index 09511b08..b3b68b63 100644 --- a/src/ty_check/mod.rs +++ b/src/ty_check/mod.rs @@ -9,26 +9,29 @@ mod subty; mod unify; use self::pl_expr::PlExprTyCtx; -use crate::ast::internal::{Frame, IdentTyped, Loan, Place, PrvMapping}; -use crate::ast::utils; -use crate::ast::*; +use crate::arena_ast::internal::{Frame, IdentTyped, Loan, Place, PrvMapping}; +use crate::arena_ast::{utils}; +use crate::arena_ast::*; use crate::error::ErrorReported; +use bumpalo::collections::CollectIn; +use bumpalo::collections::Vec as BumpVec; +use bumpalo::Bump; use ctxs::{AccessCtx, GlobalCtx, KindCtx, TyCtx}; use error::*; use std::collections::HashSet; -type TyResult = Result; +type TyResult<'a, T> = Result>; macro_rules! matches_dty { ($ty: expr, $dty_pat: pat_param) => { - if let crate::ast::TyKind::Data(d) = &$ty.ty { - matches!(d.as_ref(), $dty_pat) + if let crate::arena_ast::TyKind::Data(d) = &$ty.ty { + matches!(d, $dty_pat) } else { false } }; } -use crate::ast::printer::PrintState; +use crate::arena_ast::printer::PrintState; use crate::ty_check::borrow_check::BorrowCheckCtx; use crate::ty_check::ctxs::GlobalDecl; pub(crate) use matches_dty; @@ -36,21 +39,25 @@ pub(crate) use matches_dty; // ∀ε ∈ Σ. 
Σ ⊢ ε // -------------- // ⊢ Σ -pub fn ty_check(compil_unit: &mut CompilUnit) -> Result<(), ErrorReported> { +/** +pub fn ty_check<'a>(compil_unit: &'a mut CompilUnit<'a>, arena: &'a Bump) -> Result<(), ErrorReported> { let mut gl_ctx = GlobalCtx::new( compil_unit, - pre_decl::fun_decls() + pre_decl::fun_decls(arena) .into_iter() - .map(|(fname, fty)| GlobalDecl::FnDecl(Box::from(fname), Box::new(fty))) - .collect(), + .map(|(fname, fty)| GlobalDecl::FnDecl(fname, arena.alloc(fty))) + .collect_in(arena), + arena ); - let mut nat_ctx = NatCtx::new(); - if let Some(mut main_fun) = gl_ctx.pop_fun_def("main") { - if let Err(err) = ty_check_global_fun_def(&mut gl_ctx, &mut nat_ctx, &mut main_fun) { + let mut nat_ctx = NatCtx::new(arena); + if let Some(main_fun) = &mut gl_ctx.pop_fun_def("main") { + if let Err(err) = ty_check_global_fun_def(&mut gl_ctx, &mut nat_ctx, main_fun, arena) + { + drop(gl_ctx); err.emit(compil_unit.source); Err(ErrorReported) } else { - gl_ctx.push_fun_checked_under_nats(main_fun, Box::from(vec![])); + gl_ctx.push_fun_checked_under_nats(arena, (*main_fun).clone(), &[]); Ok(()) } } else { @@ -58,36 +65,93 @@ pub fn ty_check(compil_unit: &mut CompilUnit) -> Result<(), ErrorReported> { Err(ErrorReported) } } +*/ + + +pub fn ty_check<'a>( + compil_unit: &'a mut CompilUnit<'a>, + arena: &'a Bump, +) -> Result<(), ErrorReported> { + // Run the pass while holding &mut CompilUnit via gl_ctx. + // Return a TyError if something goes wrong, but DON'T emit yet. 
+ let result: Result<(), TyError<'a>> = { + let predecls = pre_decl::fun_decls(arena) + .into_iter() + .map(|(fname, fty)| GlobalDecl::FnDecl(fname, arena.alloc(fty))) + .collect_in(arena); + + let mut gl_ctx = GlobalCtx::new(compil_unit, predecls, arena); + let mut nat_ctx = NatCtx::new(arena); + + if let Some(mut main_fun) = gl_ctx.pop_fun_def("main") { + match ty_check_global_fun_def(&mut gl_ctx, &mut nat_ctx, &mut main_fun, arena) { + Ok(()) => { + gl_ctx.push_fun_checked_under_nats(arena, main_fun, &[]); + Ok(()) + } + Err(err) => Err(err), + } + } else { + Err(TyError::MissingMain) + } + }; // <- gl_ctx (& its &mut CompilUnit borrow) is dropped here + + // Now it's safe to immutably borrow compil_unit.source to emit. + match result { + Ok(()) => Ok(()), + Err(err) => { + err.emit(compil_unit.source); + Err(ErrorReported) + } + } +} + -struct ExprTyCtx<'src, 'compil, 'ctxt> { - gl_ctx: &'ctxt mut GlobalCtx<'src, 'compil>, - nat_ctx: &'ctxt mut NatCtx, - ident_exec: Option<&'ctxt IdentExec>, - kind_ctx: &'ctxt mut KindCtx, - exec: ExecExpr, - ty_ctx: &'ctxt mut TyCtx, - access_ctx: &'ctxt mut AccessCtx, +struct ExprTyCtx<'a> { + gl_ctx: &'a mut GlobalCtx<'a>, + nat_ctx: &'a mut NatCtx<'a>, + ident_exec: Option<&'a IdentExec<'a>>, + kind_ctx: &'a mut KindCtx<'a>, + exec: ExecExpr<'a>, + ty_ctx: &'a mut TyCtx<'a>, + access_ctx: &'a mut AccessCtx<'a>, unsafe_flag: bool, } // Σ ⊢ fn f (x1: τ1, ..., xn: τn) → τr where List[ρ1:ρ2] { e } -fn ty_check_global_fun_def( +fn ty_check_global_fun_def<'a>( gl_ctx: &mut GlobalCtx, nat_ctx: &mut NatCtx, - gf: &mut FunDef, -) -> TyResult<()> { + gf: &mut FunDef<'a>, + arena: &'a Bump, +) -> TyResult<'a, ()> { // TODO check that every prv_rel only uses provenance variables bound in generic_params - let mut kind_ctx = KindCtx::gl_fun_kind_ctx(gf.generic_params.clone(), gf.prv_rels.clone())?; - let mut ty_ctx = TyCtx::new(); + let mut kind_ctx = + KindCtx::gl_fun_kind_ctx(gf.generic_params.clone(), gf.prv_rels.clone(), arena)?; + 
let mut ty_ctx = TyCtx::new(arena); // Build frame typing for this function // TODO give Frame its own type and move this into frame and/or ParamDecl if let Some(ident_exec) = &gf.generic_exec { - let mut exec_ident = - ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); - exec::ty_check(nat_ctx, &ty_ctx, gf.generic_exec.as_ref(), &mut exec_ident)?; + let mut exec_ident = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + exec::ty_check( + nat_ctx, + &ty_ctx, + gf.generic_exec.as_ref(), + &mut exec_ident, + arena, + )?; ty_ctx.append_exec_mapping(ident_exec.ident.clone(), exec_ident); } - exec::ty_check(nat_ctx, &ty_ctx, gf.generic_exec.as_ref(), &mut gf.exec)?; + exec::ty_check( + nat_ctx, + &ty_ctx, + gf.generic_exec.as_ref(), + &mut gf.exec, + arena, + )?; let param_idents_ty = gf .param_decls @@ -100,7 +164,7 @@ fn ty_check_global_fun_def( exec_expr, }| { let mut exec = exec_expr.as_ref().unwrap_or(&gf.exec).clone(); - exec::ty_check(nat_ctx, &ty_ctx, gf.generic_exec.as_ref(), &mut exec)?; + exec::ty_check(nat_ctx, &ty_ctx, gf.generic_exec.as_ref(), &mut exec, arena)?; Ok(IdentTyped { ident: ident.clone(), ty: ty.as_ref().unwrap().clone(), @@ -117,7 +181,7 @@ fn ty_check_global_fun_def( ty_ctx.append_prv_mapping(PrvMapping::new(prv)); } - let mut access_ctx = AccessCtx::new(); + let mut access_ctx = AccessCtx::new(arena); let mut ctx = ExprTyCtx { gl_ctx: &mut *gl_ctx, nat_ctx: &mut *nat_ctx, @@ -129,7 +193,7 @@ fn ty_check_global_fun_def( unsafe_flag: false, }; - ty_check_expr(&mut ctx, &mut gf.body.body)?; + ty_check_expr(&mut ctx, &mut gf.body.body, arena)?; // t <= t_f // unify::constrain( // gf.body_expr.ty.as_ref().unwrap(), @@ -137,12 +201,13 @@ fn ty_check_global_fun_def( // )?; //coalesce::coalesce_ty(&mut self.term_constr.constr_map, &mut body_ctx, ) - let mut empty_ty_ctx = TyCtx::new(); + let mut empty_ty_ctx = TyCtx::new(arena); subty::check( &kind_ctx, &mut empty_ty_ctx, 
gf.body.body.ty.as_ref().unwrap().dty(), &gf.ret_dty, + arena, )?; #[cfg(debug_assertions)] @@ -162,100 +227,110 @@ fn ty_check_global_fun_def( // type τ is well-formed under well-formed GlFunCtxt, kinding ctx, output context Γ'. // Σ; Δ; Γ ⊢ e :^exec τ ⇒ Γ′, side conditions: ⊢ Σ;Δ;Γ and Σ;Δ;Γ′ ⊢ τ // This never returns a dead type, because typing an expression with a dead type is not possible. -fn ty_check_expr(ctx: &mut ExprTyCtx, expr: &mut Expr) -> TyResult<()> { +fn ty_check_expr<'a>( + ctx: &'a mut ExprTyCtx, + expr: &'a mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, ()> { let ty = match &mut expr.expr { ExprKind::PlaceExpr(pl_expr) => { if pl_expr.is_place() { - ty_check_place(ctx, pl_expr)? + ty_check_place(ctx, pl_expr, arena)? } else { - ty_check_non_place(ctx, pl_expr)? + ty_check_non_place(ctx, pl_expr, arena)? } } - ExprKind::Block(block) => ty_check_block(ctx, block)?, - ExprKind::Let(pattern, ty, e) => ty_check_let(ctx, pattern, ty, e)?, + ExprKind::Block(block) => ty_check_block(ctx, block, arena)?, + ExprKind::Let(pattern, ty, e) => ty_check_let(ctx, pattern, ty, e, arena)?, ExprKind::LetUninit(annot_exec, ident, ty) => { - ty_check_let_uninit(ctx, annot_exec, ident, ty)? + ty_check_let_uninit(ctx, annot_exec, ident, ty, arena)? 
} - ExprKind::Seq(es) => ty_check_seq(ctx, es)?, - ExprKind::Lit(l) => ty_check_literal(l), - ExprKind::Array(elems) => ty_check_array(ctx, elems)?, - ExprKind::Tuple(elems) => ty_check_tuple(ctx, elems)?, + ExprKind::Seq(es) => ty_check_seq(ctx, es, arena)?, + ExprKind::Lit(l) => ty_check_literal(l, arena), + ExprKind::Array(elems) => ty_check_array(ctx, elems, arena)?, + ExprKind::Tuple(elems) => ty_check_tuple(ctx, elems, arena)?, // ExprKind::Proj(e, i) => ty_check_proj(ctx, e, *i)?, - ExprKind::App(fn_ident, gen_args, args) => ty_check_app(ctx, fn_ident, gen_args, args)?, - ExprKind::DepApp(fn_ident, gen_args) => Ty::new(TyKind::FnTy(Box::new(ty_check_dep_app( - ctx, fn_ident, gen_args, - )?))), - ExprKind::AppKernel(app_kernel) => ty_check_app_kernel(ctx, app_kernel)?, - ExprKind::Ref(prv, own, pl_expr) => ty_check_borrow(ctx, prv, *own, pl_expr)?, + ExprKind::App(fn_ident, gen_args, args) => ty_check_app(ctx, fn_ident, gen_args, args, arena)?, + ExprKind::DepApp(fn_ident, gen_args) => Ty::new(TyKind::FnTy( + arena.alloc(ty_check_dep_app(ctx, fn_ident, gen_args, arena)?), + )), + ExprKind::AppKernel(app_kernel) => ty_check_app_kernel(ctx, app_kernel, arena)?, + ExprKind::Ref(prv, own, pl_expr) => ty_check_borrow(ctx, prv, *own, pl_expr, arena)?, ExprKind::Assign(pl_expr, e) => { if pl_expr.is_place() { - ty_check_assign_place(ctx, pl_expr, e)? + ty_check_assign_place(ctx, pl_expr, e, arena)? } else { - ty_check_assign_non_place(ctx, pl_expr, e)? + ty_check_assign_non_place(ctx, pl_expr, e, arena)? 
} } - ExprKind::IdxAssign(pl_expr, idx, e) => ty_check_idx_assign(ctx, pl_expr, idx, e)?, - ExprKind::Split(split) => ty_check_split(ctx, split)?, - ExprKind::Sched(sched) => ty_check_sched(ctx, sched)?, - ExprKind::ForNat(var, range, body) => ty_check_for_nat(ctx, var, range, body)?, - ExprKind::For(ident, collec, body) => ty_check_for(ctx, ident, collec, body)?, + ExprKind::IdxAssign(pl_expr, idx, e) => ty_check_idx_assign(ctx, pl_expr, idx, e, arena)?, + ExprKind::Split(split) => ty_check_split(ctx, split, arena)?, + ExprKind::Sched(sched) => ty_check_sched(ctx, sched, arena)?, + ExprKind::ForNat(var, range, body) => ty_check_for_nat(ctx, var, range, body, arena)?, + ExprKind::For(ident, collec, body) => ty_check_for(ctx, ident, collec, body, arena)?, ExprKind::IfElse(cond, case_true, case_false) => { - ty_check_if_else(ctx, cond, case_true, case_false)? + ty_check_if_else(ctx, cond, case_true, case_false, arena)? } - ExprKind::If(cond, case_true) => ty_check_if(ctx, cond, case_true)?, - ExprKind::While(cond, body) => ty_check_while(ctx, cond, body)?, + ExprKind::If(cond, case_true) => ty_check_if(ctx, cond, case_true, arena)?, + ExprKind::While(cond, body) => ty_check_while(ctx, cond, body, arena)?, // ExprKind::Lambda(params, lambda_exec_ident, ret_ty, body) => { // ty_check_lambda(ctx, params, lambda_exec_ident, ret_ty, body)? 
// } - ExprKind::BinOp(bin_op, lhs, rhs) => ty_check_binary_op(ctx, bin_op, lhs, rhs)?, - ExprKind::UnOp(un_op, e) => ty_check_unary_op(ctx, un_op, e)?, - ExprKind::Sync(exec) => ty_check_sync(ctx, exec)?, + ExprKind::BinOp(bin_op, lhs, rhs) => ty_check_binary_op(ctx, bin_op, lhs, rhs, arena)?, + ExprKind::UnOp(un_op, e) => ty_check_unary_op(ctx, un_op, e, arena)?, + ExprKind::Sync(exec) => ty_check_sync(ctx, exec, arena)?, ExprKind::Unsafe(e) => { ctx.unsafe_flag = true; - ty_check_expr(ctx, e)?; + ty_check_expr(ctx, e, arena)?; ctx.unsafe_flag = false; e.ty.as_ref().unwrap().as_ref().clone() } - ExprKind::Cast(expr, dty) => ty_check_cast(ctx, expr, dty)?, + ExprKind::Cast(expr, dty) => ty_check_cast(ctx, expr, dty, arena)?, ExprKind::Range(_, _) => unimplemented!(), //ty_check_range(ctx, l, u)?, - ExprKind::Hole => ty_check_hole(ctx)?, + ExprKind::Hole => ty_check_hole(ctx, arena)?, }; // TODO reintroduce!!!! //if let Err(err) = self.ty_well_formed(kind_ctx, &res_ty_ctx, exec, &ty) { // panic!("{:?}", err); //} - expr.ty = Some(Box::new(ty)); + expr.ty = Some(arena.alloc(ty)); Ok(()) } -fn ty_check_hole(ctx: &ExprTyCtx) -> TyResult { +fn ty_check_hole<'a>(ctx: &ExprTyCtx, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { if ctx.unsafe_flag { - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Ident(Ident::new_impli(&utils::fresh_name("hole"))), + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, + DataTyKind::Ident(Ident::new_impli(arena, &utils::fresh_name("hole"))), ))))) } else { Err(TyError::UnsafeRequired) } } -fn ty_check_sync(ctx: &mut ExprTyCtx, exec: &mut Option) -> TyResult { +fn ty_check_sync<'a>( + ctx: &mut ExprTyCtx, + exec: &'a mut Option>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { let synced = match exec { Some(exec) => { - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, exec)?; + exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, exec, arena)?; exec } None => &ctx.exec, }; syncable_under_exec(synced, &ctx.exec)?; - 
ctx.access_ctx.clear_sync_for(ctx.ty_ctx, synced); - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + ctx.access_ctx.clear_sync_for(ctx.ty_ctx, synced, arena); + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } // assumes fully typed ExecExpr as input -fn syncable_under_exec(synced: &ExecExpr, under: &ExecExpr) -> TyResult<()> { +fn syncable_under_exec<'a>(synced: &'a ExecExpr<'a>, under: &'a ExecExpr<'a>) -> TyResult<'a, ()> { if !syncable_exec_ty(synced.ty.as_ref().unwrap()) { return Err(TyError::String( "trying to synchronize non-synchronizable execution resource".to_string(), @@ -277,7 +352,7 @@ fn syncable_under_exec(synced: &ExecExpr, under: &ExecExpr) -> TyResult<()> { } } -fn syncable_exec_ty(exec_ty: &ExecTy) -> bool { +fn syncable_exec_ty<'a>(exec_ty: &'a ExecTy<'a>) -> bool { match &exec_ty.ty { ExecTyKind::GpuBlock(_) | ExecTyKind::GpuWarp => true, ExecTyKind::CpuThread @@ -291,7 +366,7 @@ fn syncable_exec_ty(exec_ty: &ExecTy) -> bool { } } -fn infer_and_append_prv(ty_ctx: &mut TyCtx, prv_name: &Option) -> String { +fn infer_and_append_prv<'a>(ty_ctx: &'a mut TyCtx<'a>, prv_name: &Option) -> String { if let Some(prv) = prv_name.as_ref() { prv.clone() } else { @@ -301,22 +376,23 @@ fn infer_and_append_prv(ty_ctx: &mut TyCtx, prv_name: &Option) -> String } } -fn ty_check_for_nat( +fn ty_check_for_nat<'a>( ctx: &mut ExprTyCtx, - ident: &Ident, - range: &NatRange, + ident: &Ident<'a>, + range: &NatRange<'a>, // TODO make this a block - body: &mut Expr, -) -> TyResult { + body: &mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { let compare_ty_ctx = ctx.ty_ctx.clone(); - let lifted_range = range.lift(ctx.nat_ctx)?; + let lifted_range = range.lift(arena, ctx.nat_ctx)?; for i in lifted_range { - ctx.ty_ctx.push_empty_frame(); - ctx.nat_ctx.push_empty_frame(); - ctx.nat_ctx.append(&ident.name, i); + ctx.ty_ctx.push_empty_frame(arena); + ctx.nat_ctx.push_empty_frame(arena); + 
ctx.nat_ctx.append(&ident.name, i, arena); - ty_check_expr(ctx, body)?; + ty_check_expr(ctx, body, arena)?; ctx.nat_ctx.pop_frame(); ctx.ty_ctx.pop_frame(); @@ -330,18 +406,20 @@ fn ty_check_for_nat( return Err(TyError::UnexpectedType); } } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } -fn ty_check_for( +fn ty_check_for<'a>( ctx: &mut ExprTyCtx, - ident: &Ident, - collec: &mut Expr, - body: &mut Expr, -) -> TyResult { - ty_check_expr(ctx, collec)?; + ident: &Ident<'a>, + collec: &mut Expr<'a>, + body: &mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, collec, arena)?; let collec_dty = if let TyKind::Data(collec_dty) = &collec.ty.as_ref().unwrap().ty { collec_dty.as_ref() } else { @@ -355,13 +433,15 @@ fn ty_check_for( // TODO DataTyKind::Array(elem_dty, n) => unimplemented!(), DataTyKind::Ref(reff) => match &reff.dty.as_ref().dty { - DataTyKind::Array(elem_dty, _) => DataTyKind::Ref(Box::new(RefDty::new( + DataTyKind::Array(elem_dty, _) => DataTyKind::Ref(arena.alloc(RefDty::new( + arena, reff.rgn.clone(), reff.own, reff.mem.clone(), elem_dty.as_ref().clone(), ))), - DataTyKind::ArrayShape(elem_dty, _) => DataTyKind::Ref(Box::new(RefDty::new( + DataTyKind::ArrayShape(elem_dty, _) => DataTyKind::Ref(arena.alloc(RefDty::new( + arena, reff.rgn.clone(), reff.own, reff.mem.clone(), @@ -383,41 +463,47 @@ fn ty_check_for( } }; let compare_ty_ctx = ctx.ty_ctx.clone(); - let mut frame = Frame::new(); - frame.append_idents_typed(vec![IdentTyped::new( + let mut frame = Frame::new_in(arena); + frame.append_idents_typed(vec![IdentTyped::new_in(arena, ident.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(ident_dty)))), + Ty::new(TyKind::Data(arena.alloc(DataTy::new(arena, ident_dty)))), Mutability::Const, ctx.exec.clone(), )]); ctx.ty_ctx.push_frame(frame); - ty_check_expr(ctx, body)?; + ty_check_expr(ctx, body, arena)?;
ctx.ty_ctx.pop_frame(); if ctx.ty_ctx != &compare_ty_ctx { return Err(TyError::String( "Using a data type in loop that can only be used once.".to_string(), )); } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } -fn ty_check_while(ctx: &mut ExprTyCtx, cond: &mut Expr, body: &mut Expr) -> TyResult { - ty_check_expr(&mut *ctx, cond)?; - ctx.ty_ctx.push_empty_frame(); - ty_check_expr(ctx, body)?; +fn ty_check_while<'a>( + ctx: &mut ExprTyCtx, + cond: &'a mut Expr<'a>, + body: &'a mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(&mut *ctx, cond, arena)?; + ctx.ty_ctx.push_empty_frame(arena); + ty_check_expr(ctx, body, arena)?; ctx.ty_ctx.pop_frame(); let compare_ty_ctx = ctx.ty_ctx.clone(); // Is it better/more correct to push and pop scope around this as well? - ty_check_expr(ctx, cond)?; + ty_check_expr(ctx, cond, arena)?; if ctx.ty_ctx != &compare_ty_ctx { return Err(TyError::String( "Context should have stayed the same".to_string(), )); } - ctx.ty_ctx.push_empty_frame(); - ty_check_expr(ctx, body)?; + ctx.ty_ctx.push_empty_frame(arena); + ty_check_expr(ctx, body, arena)?; ctx.ty_ctx.pop_frame(); if ctx.ty_ctx != &compare_ty_ctx { return Err(TyError::String( @@ -452,19 +538,21 @@ fn ty_check_while(ctx: &mut ExprTyCtx, cond: &mut Expr, body: &mut Expr) -> TyRe body_ty ))); } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } -fn ty_check_if_else( +fn ty_check_if_else<'a>( ctx: &mut ExprTyCtx, - cond: &mut Expr, - case_true: &mut Expr, - case_false: &mut Expr, -) -> TyResult { + cond: &'a mut Expr<'a>, + case_true: &'a mut Expr<'a>, + case_false: &'a mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { // TODO deal with provenances in cases - ty_check_expr(ctx, cond)?; + ty_check_expr(ctx, cond, arena)?; // TODO acccess_ctx clone 
let mut ty_ctx_clone = ctx.ty_ctx.clone(); let mut ctx_clone = ExprTyCtx { @@ -477,9 +565,9 @@ fn ty_check_if_else( access_ctx: &mut *ctx.access_ctx, unsafe_flag: ctx.unsafe_flag, }; - let _case_true_ty_ctx = ty_check_expr(&mut ctx_clone, case_true)?; - ctx.ty_ctx.push_empty_frame(); - ty_check_expr(ctx, case_false)?; + let _case_true_ty_ctx = ty_check_expr(&mut ctx_clone, case_true, arena)?; + ctx.ty_ctx.push_empty_frame(arena); + ty_check_expr(ctx, case_false, arena)?; ctx.ty_ctx.pop_frame(); let cond_ty = cond.ty.as_ref().unwrap(); @@ -523,16 +611,22 @@ fn ty_check_if_else( ))); } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } -fn ty_check_if(ctx: &mut ExprTyCtx, cond: &mut Expr, case_true: &mut Expr) -> TyResult { +fn ty_check_if<'a>( + ctx: &mut ExprTyCtx, + cond: &'a mut Expr<'a>, + case_true: &'a mut Expr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { // TODO deal with provenances in cases - ty_check_expr(ctx, cond)?; - ctx.ty_ctx.push_empty_frame(); - ty_check_expr(ctx, case_true)?; + ty_check_expr(ctx, cond, arena)?; + ctx.ty_ctx.push_empty_frame(arena); + ty_check_expr(ctx, case_true, arena)?; ctx.ty_ctx.pop_frame(); let cond_ty = cond.ty.as_ref().unwrap(); @@ -563,12 +657,17 @@ fn ty_check_if(ctx: &mut ExprTyCtx, cond: &mut Expr, case_true: &mut Expr) -> Ty ))); } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, DataTyKind::Scalar(ScalarTy::Unit), ))))) } -fn ty_check_split(ctx: &mut ExprTyCtx, indep: &mut Split) -> TyResult { +fn ty_check_split<'a>( + ctx: &'a mut ExprTyCtx<'a>, + indep: &'a mut Split<'a>, + arena: &'a Bump, +) -> TyResult<'a, Ty<'a>> { // exec::ty_check( // ctx.kind_ctx, // ctx.ty_ctx, @@ -580,9 +679,10 @@ fn ty_check_split(ctx: &mut ExprTyCtx, indep: &mut Split) -> TyResult { ctx.ty_ctx, ctx.ident_exec, &mut indep.split_exec, + arena, )?; - 
legal_exec_under_current(ctx, &indep.split_exec)?; - let expanded_exec_expr = expand_exec_expr(ctx, &indep.split_exec)?; + legal_exec_under_current(ctx, &indep.split_exec, arena)?; + let expanded_exec_expr = expand_exec_expr(ctx, &indep.split_exec, arena)?; if indep.branch_idents.len() != indep.branch_bodies.len() { panic!( "Amount of branch identifiers and amount of branches do not match:\ @@ -599,18 +699,28 @@ fn ty_check_split(ctx: &mut ExprTyCtx, indep: &mut Split) -> TyResult { } for i in 0..indep.branch_bodies.len() { - let mut branch_exec = ExecExpr::new(expanded_exec_expr.exec.clone().split_proj( - indep.dim_compo, - indep.pos.clone(), - if i == 0 { - LeftOrRight::Left - } else if i == 1 { - LeftOrRight::Right - } else { - panic!("Unexepected projection.") - }, - )); - exec::ty_check(&ctx.nat_ctx, &ctx.ty_ctx, ctx.ident_exec, &mut branch_exec)?; + let mut branch_exec = ExecExpr::new( + arena, + expanded_exec_expr.exec.clone().split_proj( + arena, + indep.dim_compo, + indep.pos.clone(), + if i == 0 { + LeftOrRight::Left + } else if i == 1 { + LeftOrRight::Right + } else { + panic!("Unexepected projection.") + }, + ), + ); + exec::ty_check( + &ctx.nat_ctx, + &ctx.ty_ctx, + ctx.ident_exec, + &mut branch_exec, + arena, + )?; let mut branch_expr_ty_ctx = ExprTyCtx { gl_ctx: ctx.gl_ctx, nat_ctx: &mut *ctx.nat_ctx, @@ -623,11 +733,11 @@ fn ty_check_split(ctx: &mut ExprTyCtx, indep: &mut Split) -> TyResult { }; branch_expr_ty_ctx .ty_ctx - .push_empty_frame() + .push_empty_frame(arena) .append_exec_mapping(indep.branch_idents[i].clone(), branch_exec.clone()); - ty_check_expr(&mut branch_expr_ty_ctx, &mut indep.branch_bodies[i])?; + ty_check_expr(&mut branch_expr_ty_ctx, &mut indep.branch_bodies[i], arena)?; if indep.branch_bodies[i].ty.as_ref().unwrap().ty - != TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar(ScalarTy::Unit)))) + != TyKind::Data(arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit)))) { return Err(TyError::String( "A par_branch 
branch must not return a value.".to_string(), @@ -635,22 +745,22 @@ fn ty_check_split(ctx: &mut ExprTyCtx, indep: &mut Split) -> TyResult { } branch_expr_ty_ctx.ty_ctx.pop_frame(); } - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } -fn ty_check_sched(ctx: &mut ExprTyCtx, sched: &mut Sched) -> TyResult { +fn ty_check_sched<'a>(ctx: &'a mut ExprTyCtx<'a>, sched: &'a mut Sched<'a>, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { exec::ty_check( ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, - &mut sched.sched_exec, + &mut sched.sched_exec, arena )?; - legal_exec_under_current(ctx, &sched.sched_exec)?; - let expanded_exec_expr = expand_exec_expr(ctx, &sched.sched_exec)?; - let mut body_exec = ExecExpr::new(expanded_exec_expr.exec.forall(sched.dim)); - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut body_exec)?; + legal_exec_under_current(ctx, &sched.sched_exec, arena)?; + let expanded_exec_expr = expand_exec_expr(ctx, &sched.sched_exec, arena)?; + let mut body_exec = ExecExpr::new(arena, expanded_exec_expr.exec.forall(sched.dim)); + exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut body_exec, arena)?; let mut schedule_body_ctx = ExprTyCtx { gl_ctx: ctx.gl_ctx, nat_ctx: &mut *ctx.nat_ctx, @@ -661,7 +771,7 @@ fn ty_check_sched(ctx: &mut ExprTyCtx, sched: &mut Sched) -> TyResult { access_ctx: &mut *ctx.access_ctx, unsafe_flag: ctx.unsafe_flag, }; - schedule_body_ctx.ty_ctx.push_empty_frame(); + schedule_body_ctx.ty_ctx.push_empty_frame(arena); if let Some(ident) = &sched.inner_exec_ident { schedule_body_ctx .ty_ctx @@ -672,48 +782,49 @@ fn ty_check_sched(ctx: &mut ExprTyCtx, sched: &mut Sched) -> TyResult { .ty_ctx .append_prv_mapping(PrvMapping::new(prv)); } - ty_check_expr(&mut schedule_body_ctx, &mut sched.body.body)?; + ty_check_expr(&mut schedule_body_ctx, &mut sched.body.body, arena)?;
schedule_body_ctx.ty_ctx.pop_frame(); - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } -fn ty_check_block(ctx: &mut ExprTyCtx, block: &mut Block) -> TyResult { - ctx.ty_ctx.push_empty_frame(); +fn ty_check_block<'a>(ctx: &'a mut ExprTyCtx<'a>, block: &'a mut Block<'a>, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { + ctx.ty_ctx.push_empty_frame(arena); for prv in &block.prvs { ctx.ty_ctx.append_prv_mapping(PrvMapping::new(prv)); } - ty_check_expr(ctx, &mut block.body)?; + ty_check_expr(ctx, &mut block.body, arena)?; ctx.ty_ctx.pop_frame(); - ctx.access_ctx.garbage_collect(ctx.ty_ctx); + ctx.access_ctx.garbage_collect(ctx.ty_ctx, arena); Ok(block.body.ty.as_ref().unwrap().as_ref().clone()) } -fn collect_valid_loans(ty_ctx: &TyCtx, mut loans: HashSet) -> HashSet { +fn collect_valid_loans<'a>(ty_ctx: &TyCtx, mut loans: HashSet>, arena: &'a Bump) -> HashSet> { // FIXME this implementations assumes unique names which is not the case loans.retain(|l| { - let root_ident = &l.place_expr.to_pl_ctx_and_most_specif_pl().1.ident; + let root_ident = &l.place_expr.to_pl_ctx_and_most_specif_pl(arena).1.ident; ty_ctx.contains(root_ident) }); loans } -fn check_mutable(ty_ctx: &TyCtx, pl: &Place) -> TyResult<()> { +fn check_mutable<'a>(ty_ctx: &'a TyCtx<'a>, pl: &'a Place<'a>, arena: &'a Bump) -> TyResult<'a, ()> { let ident_ty = ty_ctx.ident_ty(&pl.ident)?; if ident_ty.mutbl != Mutability::Mut { - return Err(TyError::AssignToConst(pl.to_place_expr())); + return Err(TyError::AssignToConst(pl.to_place_expr(arena))); } Ok(()) } -fn ty_check_assign_place( - ctx: &mut ExprTyCtx, - pl_expr: &mut PlaceExpr, - e: &mut Expr, -) -> TyResult { - ty_check_expr(ctx, e)?; - let pl = pl_expr.to_place().unwrap(); +fn ty_check_assign_place<'a>( + ctx: &'a mut ExprTyCtx<'a>, + pl_expr: &'a mut PlaceExpr<'a>, + e: &'a mut Expr<'a>, + arena: &'a Bump +) -> 
TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, e, arena)?; + let pl = pl_expr.to_place(arena).unwrap(); let mut place_ty = ctx.ty_ctx.place_dty(&pl)?; // FIXME this should be checked for ArrayViews as well // fn contains_view_dty(ty: &TyKind) -> bool { @@ -725,13 +836,13 @@ fn ty_check_assign_place( // e // ))); // } - check_mutable(ctx.ty_ctx, &pl)?; + check_mutable(ctx.ty_ctx, &pl, arena)?; // If the place is not dead, check that it is safe to use, otherwise it is safe to use anyway. if !matches!(&place_ty.dty, DataTyKind::Dead(_),) { - borrow_check::borrow_check(&BorrowCheckCtx::new(ctx, vec![], Ownership::Uniq), pl_expr) + borrow_check::borrow_check(&BorrowCheckCtx::new(ctx, vec![], Ownership::Uniq), pl_expr, arena) .map_err(|err| { - TyError::ConflictingBorrow(Box::new(pl_expr.clone()), Ownership::Uniq, err) + TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), Ownership::Uniq, err) })?; } @@ -740,7 +851,7 @@ fn ty_check_assign_place( } else { return Err(TyError::UnexpectedType); }; - let err = unify::sub_unify(ctx.kind_ctx, ctx.ty_ctx, e_dty, &mut place_ty); + let err = unify::sub_unify(ctx.kind_ctx, ctx.ty_ctx, e_dty, &mut place_ty, arena); if let Err(err) = err { return Err(match err { UnifyError::CannotUnify => { @@ -750,30 +861,32 @@ fn ty_check_assign_place( }); } ctx.ty_ctx - .set_place_dty(&pl, e_dty.clone()) + .set_place_dty(&pl, e_dty.clone(), arena) .without_reborrow_loans(pl_expr); // TODO remove: not required for correctness // removing this leads to problems in Codegen, because the pl_expr is not annotated with a // type which is required by gen_pl_expr - pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr)?; - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr, arena)?; + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } -fn ty_check_assign_non_place( - ctx: &mut 
ExprTyCtx, - deref_expr: &mut PlaceExpr, - e: &mut Expr, -) -> TyResult { - ty_check_expr(ctx, e)?; - pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), deref_expr)?; +fn ty_check_assign_non_place<'a>( + ctx: &'a mut ExprTyCtx<'a>, + deref_expr: &'a mut PlaceExpr<'a>, + e: &'a mut Expr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, e, arena)?; + pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), deref_expr, arena)?; let potential_accesses = borrow_check::access_safety_check( &BorrowCheckCtx::new(ctx, vec![], Ownership::Uniq), deref_expr, + arena ) .map_err(|err| { - TyError::ConflictingBorrow(Box::new(deref_expr.clone()), Ownership::Uniq, err) + TyError::ConflictingBorrow(arena.alloc(deref_expr.clone()), Ownership::Uniq, err) })?; ctx.access_ctx.insert(potential_accesses); let deref_ty = deref_expr.ty.as_mut().unwrap(); @@ -782,6 +895,7 @@ fn ty_check_assign_non_place( ctx.ty_ctx, e.ty.as_mut().unwrap().as_mut(), deref_ty, + arena )?; if !deref_ty.is_fully_alive() { return Err(TyError::String( @@ -790,9 +904,9 @@ fn ty_check_assign_non_place( } // FIXME needs subtyping check on p, e types if let TyKind::Data(_) = &deref_ty.ty { - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } else { Err(TyError::String( "Trying to dereference view type which is not allowed.".to_string(), @@ -800,14 +914,15 @@ fn ty_check_assign_non_place( } } -fn ty_check_idx_assign( - ctx: &mut ExprTyCtx, - pl_expr: &mut PlaceExpr, - idx: &Nat, - e: &mut Expr, -) -> TyResult { - ty_check_expr(ctx, e)?; - pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr)?; +fn ty_check_idx_assign<'a>( + ctx: &'a mut ExprTyCtx<'a>, + pl_expr: &'a mut PlaceExpr<'a>, + idx: &'a Nat<'a>, + e: &'a mut Expr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, e, arena)?; + 
pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr, arena)?; let pl_expr_dty = if let TyKind::Data(dty) = &pl_expr.ty.as_ref().unwrap().ty { dty } else { @@ -868,27 +983,31 @@ fn ty_check_idx_assign( let potential_accesses = borrow_check::access_safety_check( &BorrowCheckCtx::new(ctx, vec![], Ownership::Uniq), pl_expr, + arena ) - .map_err(|err| TyError::ConflictingBorrow(Box::new(pl_expr.clone()), Ownership::Shrd, err))?; + .map_err(|err| { + TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), Ownership::Shrd, err) + })?; ctx.access_ctx.insert(potential_accesses); - subty::check(ctx.kind_ctx, ctx.ty_ctx, e.ty.as_ref().unwrap().dty(), dty)?; - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + subty::check(ctx.kind_ctx, ctx.ty_ctx, e.ty.as_ref().unwrap().dty(), dty, arena)?; + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } // FIXME currently assumes that binary operators exist only for f32 and i32 and that both // arguments have to be of the same type -fn ty_check_binary_op( - ctx: &mut ExprTyCtx, +fn ty_check_binary_op<'a>( + ctx: &mut ExprTyCtx<'a>, bin_op: &BinOp, - lhs: &mut Expr, - rhs: &mut Expr, -) -> TyResult { + lhs: &mut Expr<'a>, + rhs: &mut Expr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { // FIXME certain operations should only be allowed for certain data types // true > false is currently valid - ty_check_expr(ctx, lhs)?; - ty_check_expr(ctx, rhs)?; + ty_check_expr(ctx, lhs, arena)?; + ty_check_expr(ctx, rhs, arena)?; let lhs_ty = lhs.ty.as_ref().unwrap(); let rhs_ty = rhs.ty.as_ref().unwrap(); let ret_dty = match bin_op { @@ -908,9 +1027,9 @@ fn ty_check_binary_op( | BinOp::Ge | BinOp::And | BinOp::Or - | BinOp::Neq => Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Bool, - ))))), + | BinOp::Neq => Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Bool))), + )), }; 
match bin_op { // Shift operators only allow integer values (lhs_ty and rhs_ty can differ!) @@ -981,8 +1100,13 @@ fn ty_check_binary_op( } } -fn ty_check_unary_op(ctx: &mut ExprTyCtx, un_op: &UnOp, e: &mut Expr) -> TyResult { - ty_check_expr(ctx, e)?; +fn ty_check_unary_op<'a>( + ctx: &mut ExprTyCtx, + un_op: &UnOp, + e: &'a mut Expr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, e, arena)?; let e_ty = e.ty.as_ref().unwrap(); let e_dty = if let TyKind::Data(dty) = &e_ty.ty { dty.as_ref() @@ -1003,8 +1127,13 @@ fn ty_check_unary_op(ctx: &mut ExprTyCtx, un_op: &UnOp, e: &mut Expr) -> TyResul } } -fn ty_check_cast(ctx: &mut ExprTyCtx, e: &mut Expr, dty: &DataTy) -> TyResult { - ty_check_expr(ctx, e)?; +fn ty_check_cast<'a>( + ctx: &'a mut ExprTyCtx<'a>, + e: &mut Expr<'a>, + dty: &'a DataTy<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, e, arena)?; let e_ty = e.ty.as_ref().unwrap(); match &e_ty.dty().dty { DataTyKind::Scalar(ScalarTy::F32) @@ -1019,7 +1148,7 @@ fn ty_check_cast(ctx: &mut ExprTyCtx, e: &mut Expr, dty: &DataTy) -> TyResult Ok(Ty::new(TyKind::Data(Box::new(dty.clone())))), + | DataTyKind::Scalar(ScalarTy::F64) => Ok(Ty::new(TyKind::Data(arena.alloc(dty.clone())))), _ => Err(TyError::String(format!( "Exected a number type (i.e. i32 or f32) to cast to from {:?}, but found {:?}", e_ty, dty @@ -1030,7 +1159,7 @@ fn ty_check_cast(ctx: &mut ExprTyCtx, e: &mut Expr, dty: &DataTy) -> TyResult Ok(Ty::new(TyKind::Data(Box::new(dty.clone())))), + | DataTyKind::Scalar(ScalarTy::U64) => Ok(Ty::new(TyKind::Data(arena.alloc(dty.clone())))), _ => Err(TyError::String(format!( "Exected an integer type (i.e. 
i32 or u32) to cast to from a bool, but found {:?}", dty @@ -1043,28 +1172,28 @@ fn ty_check_cast(ctx: &mut ExprTyCtx, e: &mut Expr, dty: &DataTy) -> TyResult, - args: &mut [Expr], -) -> TyResult { +fn ty_check_app<'a>( + ctx: &'a mut ExprTyCtx<'a>, + fn_ident: &'a mut Ident<'a>, + gen_args: &'a mut Vec>, + args: &'a mut [Expr<'a>], + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { // TODO check well-kinded: FrameTyping, Prv, Ty - let partially_applied_dep_fn_ty = ty_check_dep_app(ctx, fn_ident, gen_args)?; + let partially_applied_dep_fn_ty = ty_check_dep_app(ctx, fn_ident, gen_args, arena)?; for arg in args.iter_mut() { - ty_check_expr(ctx, arg)?; + ty_check_expr(ctx, arg,arena)?; } let param_sigs_for_args = args .iter() .map(|arg| ParamSig::new(ctx.exec.clone(), arg.ty.as_ref().unwrap().as_ref().clone())) .collect(); - let ret_dty_placeholder = Ty::new(TyKind::Data(Box::new(DataTy::new(utils::fresh_ident( - "ret_ty", - DataTyKind::Ident, - ))))); - let mut mono_fn_ty = unify::inst_fn_ty_scheme(&partially_applied_dep_fn_ty); + let ret_dty_placeholder = Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, utils::fresh_ident(arena, "ret_ty", DataTyKind::Ident))), + )); + let mut mono_fn_ty = unify::inst_fn_ty_scheme(&partially_applied_dep_fn_ty, arena); unify::unify( - &mut FnTy::new( + &mut FnTy::new(arena, vec![], None, param_sigs_for_args, @@ -1073,9 +1202,10 @@ fn ty_check_app( vec![], ), &mut mono_fn_ty, + arena )?; let mut inferred_gen_args = - infer_kinded_args::infer_kinded_args(&partially_applied_dep_fn_ty, &mono_fn_ty)?; + infer_kinded_args::infer_kinded_args(&partially_applied_dep_fn_ty, &mono_fn_ty, arena)?; gen_args.append(&mut inferred_gen_args); if let Some(mut fn_def) = ctx.gl_ctx.pop_fun_def(&fn_ident.name) { @@ -1090,10 +1220,10 @@ fn ty_check_app( } if !ctx.gl_ctx.has_been_checked(&fn_ident.name, &nat_vals) { let mut called_nat_ctx = - NatCtx::with_frame(nat_names.into_iter().zip(nat_vals.clone()).collect()); - ty_check_global_fun_def(ctx.gl_ctx, 
&mut called_nat_ctx, &mut fn_def)?; + NatCtx::with_frame(arena, nat_names.into_iter().zip(nat_vals.clone()).collect(), ); + ty_check_global_fun_def(ctx.gl_ctx, &mut called_nat_ctx, &mut fn_def, arena)?; ctx.gl_ctx - .push_fun_checked_under_nats(fn_def, Box::from(nat_vals)) + .push_fun_checked_under_nats(arena,fn_def, arena.alloc(nat_vals)) } } @@ -1122,14 +1252,15 @@ fn ty_check_app( // } // } -fn ty_check_dep_app( +fn ty_check_dep_app<'a>( ctx: &mut ExprTyCtx, - fn_ident: &Ident, - gen_args: &[ArgKinded], -) -> TyResult { + fn_ident: &'a Ident<'a>, + gen_args: &'a [ArgKinded<'a>], + arena: &'a Bump +) -> TyResult<'a, FnTy<'a>> { //ty_check_expr(ctx, ef)?; let fn_ty = ctx.gl_ctx.fn_ty_by_ident(fn_ident)?; - apply_gen_args_to_fn_ty_checked(ctx.kind_ctx, &ctx.exec, &fn_ty, gen_args) + apply_gen_args_to_fn_ty_checked(ctx.kind_ctx, &ctx.exec, &fn_ty, gen_args, arena) // } else { // Err(TyError::String(format!( // "The provided function expression\n {:?}\n does not have a function type.", @@ -1138,23 +1269,25 @@ fn ty_check_dep_app( // } } -fn apply_gen_args_to_fn_ty_checked( - kind_ctx: &KindCtx, - exec: &ExecExpr, - fn_ty: &FnTy, - gen_args: &[ArgKinded], -) -> TyResult { +fn apply_gen_args_to_fn_ty_checked<'a>( + kind_ctx: &'a KindCtx<'a>, + exec: &'a ExecExpr<'a>, + fn_ty: &'a FnTy<'a>, + gen_args: &'a [ArgKinded<'a>], + arena: &'a Bump +) -> TyResult<'a, FnTy<'a>> { let mut subst_fn_ty = fn_ty.clone(); - apply_gen_args_checked(kind_ctx, &mut subst_fn_ty, gen_args)?; - apply_exec_checked(&mut subst_fn_ty, exec)?; + apply_gen_args_checked(kind_ctx, &mut subst_fn_ty, gen_args, arena)?; + apply_exec_checked(&mut subst_fn_ty, exec, arena)?; Ok(subst_fn_ty) } -fn apply_gen_args_checked( - kind_ctx: &KindCtx, - fn_ty: &mut FnTy, - gen_args: &[ArgKinded], -) -> TyResult<()> { +fn apply_gen_args_checked<'a>( + kind_ctx: &'a KindCtx<'a>, + fn_ty: &'a mut FnTy<'a>, + gen_args: &'a [ArgKinded<'a>], + arena: &'a Bump +) -> TyResult<'a, ()> { if fn_ty.generics.len() < 
gen_args.len() { return Err(TyError::String(format!( "Wrong amount of generic arguments. Expected {}, found {}", @@ -1166,11 +1299,15 @@ fn apply_gen_args_checked( check_arg_has_correct_kind(kind_ctx, &gen_param.kind, gen_arg)?; } let substituted_gen_idents = fn_ty.generics.drain(..gen_args.len()).collect::>(); - utils::subst_idents_kinded(&substituted_gen_idents, gen_args, fn_ty); + utils::subst_idents_kinded(arena, &substituted_gen_idents, gen_args, fn_ty); Ok(()) } -fn check_arg_has_correct_kind(kind_ctx: &KindCtx, expected: &Kind, kv: &ArgKinded) -> TyResult<()> { +fn check_arg_has_correct_kind<'a>( + kind_ctx: &'a KindCtx<'a>, + expected: &'a Kind, + kv: &'a ArgKinded<'a>, +) -> TyResult<'a, ()> { if expected == &kv.kind() { Ok(()) } else { @@ -1183,7 +1320,7 @@ fn check_arg_has_correct_kind(kind_ctx: &KindCtx, expected: &Kind, kv: &ArgKinde // FIXME the correct way to do this is to unify execs and to unify an identifier with an exec_expr // only if the types match (i.e., the exec expr type check must happen within unify) -fn apply_exec_checked(fn_ty: &mut FnTy, exec: &ExecExpr) -> TyResult<()> { +fn apply_exec_checked<'a>(fn_ty: &'a mut FnTy<'a>, exec: &'a ExecExpr<'a>, arena: &'a Bump) -> TyResult<'a, ()> { // TODO reintroduce // exec::ty_check( // ctx.kind_ctx, @@ -1196,17 +1333,22 @@ fn apply_exec_checked(fn_ty: &mut FnTy, exec: &ExecExpr) -> TyResult<()> { unify::unify( &mut exec.ty.as_ref().unwrap().as_ref().clone(), &mut ge.ty.clone(), + arena )?; let gen_exec_ident = ge.ident.clone(); fn_ty.generic_exec = None; - utils::subst_ident_exec(&gen_exec_ident, exec, fn_ty); + utils::subst_ident_exec(arena, &gen_exec_ident, exec, fn_ty); } // if no generic exec was substituted, execs must still be unifable - unify::unify(&mut fn_ty.exec, &mut exec.clone())?; + unify::unify(&mut fn_ty.exec, &mut exec.clone(), arena)?; Ok(()) } -fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyResult { +fn ty_check_app_kernel<'a>( + ctx: &mut 
ExprTyCtx, + app_kernel: &'a mut AppKernel<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { // current exec = cpu.thread if !matches!(ctx.exec.ty.as_ref().unwrap().ty, ExecTyKind::CpuThread) { return Err(TyError::String( @@ -1215,13 +1357,13 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes } // type check argument list for arg in app_kernel.args.iter_mut() { - ty_check_expr(ctx, arg)?; + ty_check_expr(ctx, arg, arena)?; } - let mut kernel_exec = ExecExpr::new(ExecExprKind::new(BaseExec::GpuGrid( + let mut kernel_exec = ExecExpr::new(arena, ExecExprKind::new(arena,BaseExec::GpuGrid( app_kernel.grid_dim.clone(), app_kernel.block_dim.clone(), ))); - exec::ty_check(ctx.nat_ctx, &TyCtx::new(), None, &mut kernel_exec)?; + exec::ty_check(ctx.nat_ctx, &TyCtx::new(arena), None, &mut kernel_exec, arena)?; let mut kernel_ctx = ExprTyCtx { gl_ctx: &mut *ctx.gl_ctx, @@ -1230,7 +1372,7 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes kind_ctx: ctx.kind_ctx, exec: kernel_exec, ty_ctx: &mut *ctx.ty_ctx, - access_ctx: &mut AccessCtx::new(), + access_ctx: &mut AccessCtx::new(arena), unsafe_flag: ctx.unsafe_flag, }; exec::ty_check( @@ -1238,6 +1380,7 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes kernel_ctx.ty_ctx, None, &mut kernel_ctx.exec, + arena )?; // add explicit provenances to typing context (see ty_check_block) for prv in &app_kernel.shared_mem_prvs { @@ -1248,10 +1391,11 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes .shared_mem_dtys .iter() .map(|dty| { - IdentTyped::new( - Ident::new_impli(&utils::fresh_name("shared_mem")), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::At( - Box::new(dty.clone()), + IdentTyped::new_in( + arena, + Ident::new_impli(arena, &utils::fresh_name("shared_mem")), + Ty::new(TyKind::Data(arena.alloc(DataTy::new(arena, DataTyKind::At( + arena.alloc(dty.clone()), Memory::GpuShared, ))))), 
Mutability::Mut, @@ -1276,12 +1420,12 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes Expr::new(ExprKind::Ref( prv, Ownership::Uniq, - Box::new(PlaceExpr::new(PlaceExprKind::Ident(idt.ident.clone()))), + arena.alloc(PlaceExpr::new(PlaceExprKind::Ident(idt.ident.clone()))), )) }) .collect::>(); for shrd_mem_arg in refs_to_shrd.iter_mut() { - ty_check_expr(&mut kernel_ctx, shrd_mem_arg)?; + ty_check_expr(&mut kernel_ctx, shrd_mem_arg, arena)?; } // create extended argument list with references to shared memory let extended_arg_sigs = app_kernel @@ -1294,7 +1438,7 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes ) }) .chain(refs_to_shrd.into_iter().map(|a| { - let block_exec = exec_distrib_over_blocks(&kernel_ctx.exec); + let block_exec = exec_distrib_over_blocks(&kernel_ctx.exec, arena); ParamSig::new(block_exec, *a.ty.unwrap()) })) .collect::>(); @@ -1303,25 +1447,29 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes &mut kernel_ctx, &mut app_kernel.fun_ident, &mut app_kernel.gen_args, + arena )?; // build expected type to unify with - let unit_ty = Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Unit, - ))))); - let mut mono_fn_ty = unify::inst_fn_ty_scheme(&partially_applied_dep_fn_ty); + let unit_ty = Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + )); + let mut mono_fn_ty = unify::inst_fn_ty_scheme(&partially_applied_dep_fn_ty, arena); unify::unify( &mut FnTy::new( + arena, vec![], None, extended_arg_sigs, kernel_ctx.exec, unit_ty.clone(), vec![], + ), &mut mono_fn_ty, + arena )?; let mut inferred_k_args = - infer_kinded_args::infer_kinded_args(&partially_applied_dep_fn_ty, &mono_fn_ty)?; + infer_kinded_args::infer_kinded_args(&partially_applied_dep_fn_ty, &mono_fn_ty, arena)?; app_kernel.gen_args.append(&mut inferred_k_args); if let Some(mut fn_def) = 
ctx.gl_ctx.pop_fun_def(&app_kernel.fun_ident.name) { @@ -1339,17 +1487,17 @@ fn ty_check_app_kernel(ctx: &mut ExprTyCtx, app_kernel: &mut AppKernel) -> TyRes .has_been_checked(&app_kernel.fun_ident.name, &nat_vals) { let mut called_nat_ctx = - NatCtx::with_frame(nat_names.into_iter().zip(nat_vals.clone()).collect()); - ty_check_global_fun_def(ctx.gl_ctx, &mut called_nat_ctx, &mut fn_def)?; + NatCtx::with_frame(arena, nat_names.into_iter().zip(nat_vals.clone()).collect()); + ty_check_global_fun_def(ctx.gl_ctx, &mut called_nat_ctx, &mut fn_def, arena)?; ctx.gl_ctx - .push_fun_checked_under_nats(fn_def, Box::from(nat_vals)) + .push_fun_checked_under_nats(arena, fn_def, Box::from(nat_vals)) } } Ok(unit_ty) } -fn exec_distrib_over_blocks(exec_expr: &ExecExpr) -> ExecExpr { - let base_clone = ExecExprKind::new(exec_expr.exec.base.clone()); +fn exec_distrib_over_blocks<'a>(exec_expr: &'a ExecExpr<'a>, arena: &'a Bump) -> ExecExpr<'a> { + let base_clone = ExecExprKind::new(arena, exec_expr.exec.base.clone()); let distrib_over_blocks = if let BaseExec::GpuGrid(gdim, _) = &exec_expr.exec.base { match gdim { Dim::XYZ(_) => base_clone @@ -1366,12 +1514,12 @@ fn exec_distrib_over_blocks(exec_expr: &ExecExpr) -> ExecExpr { } else { panic!("Expected GPU grid.") }; - ExecExpr::new(distrib_over_blocks) + ExecExpr::new(arena, distrib_over_blocks) } -fn ty_check_tuple(ctx: &mut ExprTyCtx, elems: &mut [Expr]) -> TyResult { +fn ty_check_tuple<'a>(ctx: &mut ExprTyCtx, elems: &'a mut [Expr<'a>], arena: &'a Bump) -> TyResult<'a, Ty<'a>> { for elem in elems.iter_mut() { - ty_check_expr(ctx, elem)?; + ty_check_expr(ctx, elem, arena)?; } let elem_tys: TyResult> = elems .iter() @@ -1382,29 +1530,29 @@ fn ty_check_tuple(ctx: &mut ExprTyCtx, elems: &mut [Expr]) -> TyResult { )), }) .collect(); - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Tuple(elem_tys?), - ))))) + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Tuple(elem_tys?))), + ))) } -fn 
ty_check_proj(ctx: &mut ExprTyCtx, e: &mut Expr, i: usize) -> TyResult { +fn ty_check_proj<'a>(ctx: &mut ExprTyCtx, e: &'a mut Expr<'a>, i: usize, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { if let ExprKind::PlaceExpr(_) = e.expr { panic!("Place expression should have been typechecked by a different rule.") } - ty_check_expr(ctx, e)?; + ty_check_expr(ctx, e, arena)?; let e_dty = if let TyKind::Data(dty) = &e.ty.as_ref().unwrap().ty { dty.as_ref() } else { return Err(TyError::UnexpectedType); }; let elem_ty = proj_elem_dty(e_dty, i); - Ok(Ty::new(TyKind::Data(Box::new(elem_ty?)))) + Ok(Ty::new(TyKind::Data(arena.alloc(elem_ty?)))) } -fn ty_check_array(ctx: &mut ExprTyCtx, elems: &mut Vec) -> TyResult { +fn ty_check_array<'a>(ctx: &'a mut ExprTyCtx<'a>, elems: &'a mut Vec>, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { assert!(!elems.is_empty()); for elem in elems.iter_mut() { - ty_check_expr(ctx, elem)?; + ty_check_expr(ctx, elem, arena)?; } let ty = elems.first().unwrap().ty.as_ref(); if !matches!(&ty.unwrap().ty, TyKind::Data(_)) { @@ -1417,16 +1565,16 @@ fn ty_check_array(ctx: &mut ExprTyCtx, elems: &mut Vec) -> TyResult { "Not all provided elements have the same type.".to_string(), )) } else { - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( + Ok(Ty::new(TyKind::Data(arena.alloc(DataTy::new(arena, DataTyKind::Array( - Box::new(ty.as_ref().unwrap().dty().clone()), + arena.alloc(ty.as_ref().unwrap().dty().clone()), Nat::Lit(elems.len()), ), ))))) } } -fn ty_check_literal(l: &mut Lit) -> Ty { +fn ty_check_literal<'a>(l: &mut Lit, arena: &'a Bump) -> Ty<'a> { let scalar_data = match l { Lit::Unit => ScalarTy::Unit, Lit::Bool(_) => ScalarTy::Bool, @@ -1437,16 +1585,17 @@ fn ty_check_literal(l: &mut Lit) -> Ty { Lit::F32(_) => ScalarTy::F32, Lit::F64(_) => ScalarTy::F64, }; - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - scalar_data, - ))))) + Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(scalar_data))), + )) } -fn 
infer_pattern_ident_tys( +fn infer_pattern_ident_tys<'a>( ctx: &mut ExprTyCtx, - pattern: &Pattern, - pattern_ty: &Ty, -) -> TyResult<()> { + pattern: &'a Pattern<'a>, + pattern_ty: &'a Ty<'a>, + arena: &'a Bump +) -> TyResult<'a, ()> { let pattern_dty = if let TyKind::Data(dty) = &pattern_ty.ty { dty.as_ref() } else { @@ -1454,9 +1603,10 @@ fn infer_pattern_ident_tys( }; match (pattern, &pattern_dty.dty) { (Pattern::Ident(mutbl, ident), _) => { - let ident_with_annotated_ty = IdentTyped::new( + let ident_with_annotated_ty = IdentTyped::new_in( + arena, ident.clone(), - Ty::new(TyKind::Data(Box::new(pattern_dty.clone()))), + Ty::new(TyKind::Data(arena.alloc(pattern_dty.clone()))), *mutbl, ctx.exec.clone(), ); @@ -1466,7 +1616,7 @@ fn infer_pattern_ident_tys( (Pattern::Wildcard, _) => Ok(()), (Pattern::Tuple(patterns), DataTyKind::Tuple(elem_tys)) => { for (p, tty) in patterns.iter().zip(elem_tys) { - infer_pattern_ident_tys(ctx, p, &Ty::new(TyKind::Data(Box::new(tty.clone()))))?; + infer_pattern_ident_tys(ctx, p, &Ty::new(TyKind::Data(arena.alloc(tty.clone()))), arena)?; } Ok(()) } @@ -1474,42 +1624,45 @@ fn infer_pattern_ident_tys( } } -fn infer_tys_and_append_idents( - ctx: &mut ExprTyCtx, - pattern: &Pattern, - pattern_ty: &mut Option>, - assign_ty: &mut Ty, -) -> TyResult<()> { +fn infer_tys_and_append_idents<'a>( + ctx: &'a mut ExprTyCtx<'a>, + pattern: &'a Pattern<'a>, + pattern_ty: &'a mut Option>>, + assign_ty: &'a mut Ty<'a>, + arena: &'a Bump +) -> TyResult<'a, ()> { let pattern_ty = if let Some(pty) = pattern_ty { - unify::sub_unify(ctx.kind_ctx, ctx.ty_ctx, assign_ty, pty)?; + unify::sub_unify(ctx.kind_ctx, ctx.ty_ctx, assign_ty, pty, arena)?; pty.as_ref().clone() } else { assign_ty.clone() }; - infer_pattern_ident_tys(ctx, pattern, &pattern_ty) + infer_pattern_ident_tys(ctx, pattern, &pattern_ty, arena) } -fn ty_check_let( - ctx: &mut ExprTyCtx, - pattern: &Pattern, - pattern_ty: &mut Option>, - expr: &mut Expr, -) -> TyResult { - ty_check_expr(ctx, 
expr)?; +fn ty_check_let<'a>( + ctx: &'a mut ExprTyCtx<'a>, + pattern: &'a Pattern<'a>, + pattern_ty: &'a mut Option>>, + expr: &'a mut Expr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + ty_check_expr(ctx, expr, arena)?; let e_ty = expr.ty.as_mut().unwrap(); - infer_tys_and_append_idents(ctx, pattern, pattern_ty, e_ty)?; - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + infer_tys_and_append_idents(ctx, pattern, pattern_ty, e_ty, arena)?; + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } // TODO respect exec? -fn ty_check_let_uninit( - ctx: &mut ExprTyCtx, - annot_exec: &Option>, - ident: &Ident, - ty: &Ty, -) -> TyResult { +fn ty_check_let_uninit<'a>( + ctx: &'a mut ExprTyCtx<'a>, + annot_exec: &'a Option>>, + ident: &'a Ident<'a>, + ty: &'a Ty<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { // TODO is the type well-formed? if let TyKind::Data(dty) = &ty.ty { let mut exec_expr = if let Some(ex) = annot_exec { @@ -1517,34 +1670,39 @@ fn ty_check_let_uninit( } else { ctx.exec.clone() }; - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut exec_expr)?; - let ident_with_ty = IdentTyped::new( + exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut exec_expr, arena)?; + let ident_with_ty = IdentTyped::new_in( + arena, ident.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Dead( - dty.clone(), - ))))), + Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Dead(dty.clone()))), + )), Mutability::Mut, exec_expr, ); ctx.ty_ctx.append_ident_typed(ident_with_ty); - Ok(Ty::new(TyKind::Data(Box::new(DataTy::new( - DataTyKind::Scalar(ScalarTy::Unit), - ))))) + Ok(Ty::new(TyKind::Data( + arena.alloc(DataTy::new(arena, DataTyKind::Scalar(ScalarTy::Unit))), + ))) } else { Err(TyError::MutabilityNotAllowed(ty.clone())) } } -fn ty_check_seq(ctx: &mut ExprTyCtx, es: &mut [Expr]) -> TyResult { +fn ty_check_seq<'a>(ctx: &mut 
ExprTyCtx, es: &'a mut [Expr<'a>], arena: &'a Bump) -> TyResult<'a, Ty<'a>> { for e in &mut *es { - ty_check_expr(ctx, e)?; + ty_check_expr(ctx, e, arena)?; ctx.ty_ctx.garbage_collect_loans(); } Ok(es.last().unwrap().ty.as_ref().unwrap().as_ref().clone()) } -fn ty_check_non_place(ctx: &mut ExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult { - pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Shrd), pl_expr)?; +fn ty_check_non_place<'a>( + ctx: &mut ExprTyCtx, + pl_expr: &'a mut PlaceExpr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { + pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Shrd), pl_expr, arena)?; if !pl_expr.ty.as_ref().unwrap().is_fully_alive() { return Err(TyError::String(format!( "Part of Place {:?} was moved before.", @@ -1553,16 +1711,20 @@ fn ty_check_non_place(ctx: &mut ExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult< } unify::unify( pl_expr.ty.as_mut().unwrap().as_mut(), - &mut Ty::new(TyKind::Data(Box::new(DataTy::with_constr( - utils::fresh_ident("pl_deref", DataTyKind::Ident), + &mut Ty::new(TyKind::Data(arena.alloc(DataTy::with_constr( + arena, utils::fresh_ident(arena, "pl_deref", DataTyKind::Ident), vec![Constraint::Copyable], )))), + arena )?; let potential_accesses = borrow_check::access_safety_check( &BorrowCheckCtx::new(ctx, vec![], Ownership::Shrd), pl_expr, + arena ) - .map_err(|err| TyError::ConflictingBorrow(Box::new(pl_expr.clone()), Ownership::Shrd, err))?; + .map_err(|err| { + TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), Ownership::Shrd, err) + })?; ctx.access_ctx.insert(potential_accesses); if pl_expr.ty.as_ref().unwrap().copyable() { Ok(pl_expr.ty.as_ref().unwrap().as_ref().clone()) @@ -1571,9 +1733,9 @@ fn ty_check_non_place(ctx: &mut ExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult< } } -fn ty_check_place(ctx: &mut ExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult { - pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr)?; - let place = pl_expr.clone().to_place().unwrap(); +fn 
ty_check_place<'a>(ctx: &mut ExprTyCtx, pl_expr: &'a mut PlaceExpr<'a>, arena: &'a Bump) -> TyResult<'a, Ty<'a>> { + pl_expr::ty_check(&PlExprTyCtx::new(ctx, Ownership::Uniq), pl_expr, arena)?; + let place = pl_expr.clone().to_place(arena).unwrap(); let pl_ty = ctx.ty_ctx.place_dty(&place)?; if !pl_ty.is_fully_alive() { return Err(TyError::String(format!( @@ -1586,31 +1748,34 @@ fn ty_check_place(ctx: &mut ExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult borrow_check::access_safety_check( &BorrowCheckCtx::new(ctx, vec![], Ownership::Shrd), pl_expr, + arena ) .map_err(|err| { - TyError::ConflictingBorrow(Box::new(pl_expr.clone()), Ownership::Shrd, err) + TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), Ownership::Shrd, err) })?; } else { borrow_check::access_safety_check( &BorrowCheckCtx::new(ctx, vec![], Ownership::Uniq), pl_expr, + arena ) .map_err(|err| { - TyError::ConflictingBorrow(Box::new(pl_expr.clone()), Ownership::Uniq, err) + TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), Ownership::Uniq, err) })?; - ctx.ty_ctx.kill_place(&place); + ctx.ty_ctx.kill_place(&place, arena); }; - Ok(Ty::new(TyKind::Data(Box::new(pl_ty)))) + Ok(Ty::new(TyKind::Data(arena.alloc(pl_ty)))) } -fn ty_check_borrow( +fn ty_check_borrow<'a>( ctx: &mut ExprTyCtx, prv_val_name: &Option, own: Ownership, - pl_expr: &mut PlaceExpr, -) -> TyResult { + pl_expr: &'a mut PlaceExpr<'a>, + arena: &'a Bump +) -> TyResult<'a, Ty<'a>> { // If borrowing a place uniquely, is it mutable? 
- if let Some(place) = pl_expr.to_place() { + if let Some(place) = pl_expr.to_place(arena) { if own == Ownership::Uniq && ctx.ty_ctx.ident_ty(&place.ident)?.mutbl == Mutability::Const { return Err(TyError::ConstBorrow(pl_expr.clone())); } @@ -1619,9 +1784,9 @@ fn ty_check_borrow( if !ctx.ty_ctx.loans_in_prv(&prv_val_name)?.is_empty() { return Err(TyError::PrvValueAlreadyInUse(prv_val_name)); } - let mems = pl_expr::ty_check_and_passed_mems(&PlExprTyCtx::new(ctx, own), pl_expr)?; - let loans = borrow_check::access_safety_check(&BorrowCheckCtx::new(ctx, vec![], own), pl_expr) - .map_err(|err| TyError::ConflictingBorrow(Box::new(pl_expr.clone()), own, err))?; + let mems = pl_expr::ty_check_and_passed_mems(&PlExprTyCtx::new(ctx, own), pl_expr, arena)?; + let loans = borrow_check::access_safety_check(&BorrowCheckCtx::new(ctx, vec![], own), pl_expr, arena) + .map_err(|err| TyError::ConflictingBorrow(arena.alloc(pl_expr.clone()), own, err))?; mems.iter() .try_for_each(|mem| accessible_memory(ctx.exec.ty.as_ref().unwrap().as_ref(), mem))?; let pl_expr_ty = pl_expr.ty.as_ref().unwrap(); @@ -1655,17 +1820,18 @@ fn ty_check_borrow( "Trying to take reference of unaddressable gpu.local memory.".to_string(), )); } - let res_dty = DataTy::new(DataTyKind::Ref(Box::new(RefDty::new( + let res_dty = DataTy::new(arena, DataTyKind::Ref(arena.alloc(RefDty::new( + arena, Provenance::Value(prv_val_name.clone()), own, rmem, reffed_ty, )))); ctx.ty_ctx.extend_loans_for_prv(&prv_val_name, loans)?; - Ok(Ty::new(TyKind::Data(Box::new(res_dty)))) + Ok(Ty::new(TyKind::Data(arena.alloc(res_dty)))) } -fn allowed_mem_for_exec(exec_ty: &ExecTyKind) -> Vec { +fn allowed_mem_for_exec<'a>(exec_ty: &'a ExecTyKind<'a>) -> Vec> { match exec_ty { ExecTyKind::CpuThread => vec![Memory::CpuMem], ExecTyKind::GpuThread @@ -1682,7 +1848,7 @@ fn allowed_mem_for_exec(exec_ty: &ExecTyKind) -> Vec { } } -pub fn accessible_memory(exec_ty: &ExecTy, mem: &Memory) -> TyResult<()> { +pub fn accessible_memory<'a>(exec_ty: 
&'a ExecTy<'a>, mem: &'a Memory<'a>) -> TyResult<'a, ()> { if allowed_mem_for_exec(&exec_ty.ty).contains(mem) { Ok(()) } else { @@ -1694,7 +1860,13 @@ pub fn accessible_memory(exec_ty: &ExecTy, mem: &Memory) -> TyResult<()> { } // TODO respect memory -fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) -> TyResult<()> { +fn ty_well_formed<'a>( + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &'a TyCtx<'a>, + exec_ty: &'a ExecTy<'a>, + ty: &'a Ty<'a>, + arena: &'a Bump +) -> TyResult<'a, ()> { match &ty.ty { TyKind::Data(dty) => match &dty.dty { // TODO variables of Dead types can be reassigned. So why do we not have to check @@ -1756,14 +1928,14 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) } } } - ty_well_formed(kind_ctx, ty_ctx, exec_ty, &elem_ty)?; + ty_well_formed(kind_ctx, ty_ctx, exec_ty, &elem_ty, arena)?; } Provenance::Ident(ident) => { let elem_ty = Ty::new(TyKind::Data(reff.dty.clone())); if !kind_ctx.ident_of_kind_exists(ident, Kind::Provenance) { Err(CtxError::KindedIdentNotFound(ident.clone()))? 
} - ty_well_formed(kind_ctx, ty_ctx, exec_ty, &elem_ty)?; + ty_well_formed(kind_ctx, ty_ctx, exec_ty, &elem_ty, arena)?; } }; } @@ -1773,13 +1945,14 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) kind_ctx, ty_ctx, exec_ty, - &Ty::new(TyKind::Data(Box::new(elem_dty.clone()))), + &Ty::new(TyKind::Data(arena.alloc(elem_dty.clone()))), + arena )?; } } DataTyKind::Struct(struct_decl) => { for (_, dty) in &struct_decl.fields { - ty_well_formed(kind_ctx, ty_ctx, exec_ty, &Ty::new(TyKind::Data(Box::new(dty.clone()))))?; + ty_well_formed(kind_ctx, ty_ctx, exec_ty, &Ty::new(TyKind::Data(arena.alloc(dty.clone()))), arena)?; } } DataTyKind::Array(elem_dty, n) => { @@ -1788,6 +1961,7 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) ty_ctx, exec_ty, &Ty::new(TyKind::Data(elem_dty.clone())), + arena )?; // TODO well-formed nat } @@ -1797,6 +1971,7 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) ty_ctx, exec_ty, &Ty::new(TyKind::Data(elem_dty.clone())), + arena )? // TODO well-formed nat } @@ -1811,6 +1986,7 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) ty_ctx, exec_ty, &Ty::new(TyKind::Data(elem_dty.clone())), + arena )?; } DataTyKind::At(elem_dty, _) => { @@ -1819,6 +1995,7 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) ty_ctx, exec_ty, &Ty::new(TyKind::Data(elem_dty.clone())), + arena )?; } }, @@ -1826,27 +2003,31 @@ fn ty_well_formed(kind_ctx: &KindCtx, ty_ctx: &TyCtx, exec_ty: &ExecTy, ty: &Ty) TyKind::FnTy(fn_ty) => { let mut extended_kind_ctx = kind_ctx.clone(); extended_kind_ctx.append_idents(fn_ty.generics.clone()); - ty_well_formed(&extended_kind_ctx, ty_ctx, exec_ty, &fn_ty.ret_ty)?; + ty_well_formed(&extended_kind_ctx, ty_ctx, exec_ty, &fn_ty.ret_ty, arena)?; for param_sig in &fn_ty.param_sigs { // TODO which checks are necessary for the execution resource in // param_sig.exec_expr? 
- ty_well_formed(&extended_kind_ctx, ty_ctx, exec_ty, ¶m_sig.ty)?; + ty_well_formed(&extended_kind_ctx, ty_ctx, exec_ty, ¶m_sig.ty, arena)?; } } } Ok(()) } -pub fn callable_in(callee_exec_ty: &ExecTy, caller_exec_ty: &ExecTy) -> bool { +pub fn callable_in<'a>(callee_exec_ty: &'a ExecTy<'a>, caller_exec_ty: &'a ExecTy<'a>, arena: &'a Bump) -> bool { if &callee_exec_ty.ty == &ExecTyKind::Any { true } else { - let res = unify::unify(&mut callee_exec_ty.clone(), &mut caller_exec_ty.clone()); + let res = unify::unify(&mut callee_exec_ty.clone(), &mut caller_exec_ty.clone(), arena); res.is_ok() } } -fn expand_exec_expr(ctx: &ExprTyCtx, exec_expr: &ExecExpr) -> TyResult { +fn expand_exec_expr<'a>( + ctx: &'a ExprTyCtx<'a>, + exec_expr: &'a ExecExpr<'a>, + arena: &'a Bump +) -> TyResult<'a, ExecExpr<'a>> { match &exec_expr.exec.base { BaseExec::CpuThread | BaseExec::GpuGrid(_, _) => Ok(exec_expr.clone()), BaseExec::Ident(ident) => { @@ -1855,20 +2036,21 @@ fn expand_exec_expr(ctx: &ExprTyCtx, exec_expr: &ExecExpr) -> TyResult let mut new_exec_path = inner_exec_expr.exec.path.clone(); new_exec_path.append(&mut exec_expr.exec.path.clone()); let mut expanded_exec_expr: ExecExpr = - ExecExpr::new(ExecExprKind::with_path(new_base, new_exec_path)); + ExecExpr::new(arena, ExecExprKind::with_path(new_base, new_exec_path)); exec::ty_check( ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut expanded_exec_expr, + arena )?; Ok(expanded_exec_expr) } } } -fn legal_exec_under_current(ctx: &ExprTyCtx, exec: &ExecExpr) -> TyResult<()> { - let expanded_exec_expr = expand_exec_expr(ctx, exec)?; +fn legal_exec_under_current<'a>(ctx: &'a ExprTyCtx<'a>, exec: &'a ExecExpr<'a>, arena: &'a Bump) -> TyResult<'a, ()> { + let expanded_exec_expr = expand_exec_expr(ctx, exec, arena)?; if ctx.exec != expanded_exec_expr { let current_exec_ty = &ctx.exec.ty.as_ref().unwrap().ty; let expanded_exec_ty = expanded_exec_expr.ty.unwrap().ty; @@ -1889,7 +2071,7 @@ fn legal_exec_under_current(ctx: &ExprTyCtx, exec: 
&ExecExpr) -> TyResult<()> { } // TODO move into utility module (also used in codegen) -pub fn proj_elem_dty(dty: &DataTy, i: usize) -> TyResult { +pub fn proj_elem_dty<'a>(dty: &'a DataTy<'a>, i: usize) -> TyResult<'a, DataTy<'a>> { match &dty.dty { DataTyKind::Tuple(dtys) => match dtys.get(i) { Some(dty) => Ok(dty.clone()), diff --git a/src/ty_check/pl_expr.rs b/src/ty_check/pl_expr.rs index 3739792e..6e508acb 100644 --- a/src/ty_check/pl_expr.rs +++ b/src/ty_check/pl_expr.rs @@ -1,7 +1,7 @@ use super::borrow_check::BorrowCheckCtx; use super::error::TyError; use super::TyResult; -use crate::ast::{ +use crate::arena_ast::{ utils, DataTy, DataTyKind, ExecExpr, ExecTyKind, FnTy, Ident, IdentExec, Memory, Nat, NatCtx, Ownership, ParamSig, PlaceExpr, PlaceExprKind, Provenance, Ty, TyKind, View, }; @@ -10,20 +10,22 @@ use crate::ty_check::ctxs::{AccessCtx, GlobalCtx, KindCtx, TyCtx}; use crate::ty_check::unify; use crate::ty_check::unify::ConstrainMap; use crate::ty_check::{exec, ExprTyCtx}; +use bumpalo::collections::Vec as BumpVec; +use bumpalo::Bump; -pub(super) struct PlExprTyCtx<'gl, 'src, 'ctxt> { - gl_ctx: &'ctxt GlobalCtx<'gl, 'src>, - nat_ctx: &'ctxt NatCtx, - kind_ctx: &'ctxt KindCtx, - ident_exec: Option<&'ctxt IdentExec>, - exec: ExecExpr, - ty_ctx: &'ctxt TyCtx, - exec_borrow_ctx: &'ctxt AccessCtx, +pub(super) struct PlExprTyCtx<'a> { + gl_ctx: &'a GlobalCtx<'a>, + nat_ctx: &'a NatCtx<'a>, + kind_ctx: &'a KindCtx<'a>, + ident_exec: Option<&'a IdentExec<'a>>, + exec: ExecExpr<'a>, + ty_ctx: &'a TyCtx<'a>, + exec_borrow_ctx: &'a AccessCtx<'a>, own: Ownership, } -impl<'gl, 'src, 'ctxt> PlExprTyCtx<'gl, 'src, 'ctxt> { - pub(super) fn new(expr_ty_ctx: &'ctxt ExprTyCtx<'gl, 'src, 'ctxt>, own: Ownership) -> Self { +impl<'a> PlExprTyCtx<'a> { + pub(super) fn new(expr_ty_ctx: &'a ExprTyCtx<'a>, own: Ownership) -> Self { PlExprTyCtx { gl_ctx: &*expr_ty_ctx.gl_ctx, nat_ctx: &*expr_ty_ctx.nat_ctx, @@ -37,10 +39,8 @@ impl<'gl, 'src, 'ctxt> PlExprTyCtx<'gl, 'src, 'ctxt> 
{ } } -impl<'gl, 'src, 'ctxt> From<&'ctxt BorrowCheckCtx<'gl, 'src, 'ctxt>> - for PlExprTyCtx<'gl, 'src, 'ctxt> -{ - fn from(ctx: &'ctxt BorrowCheckCtx<'gl, 'src, 'ctxt>) -> Self { +impl<'a> From<&'a BorrowCheckCtx<'a>> for PlExprTyCtx<'a> { + fn from(ctx: &'a BorrowCheckCtx<'a>) -> Self { PlExprTyCtx { gl_ctx: ctx.gl_ctx, nat_ctx: ctx.nat_ctx, @@ -56,138 +56,204 @@ impl<'gl, 'src, 'ctxt> From<&'ctxt BorrowCheckCtx<'gl, 'src, 'ctxt>> // Δ; Γ ⊢ω p:τ // p in an ω context has type τ under Δ and Γ -pub(super) fn ty_check(ctx: &PlExprTyCtx, pl_expr: &mut PlaceExpr) -> TyResult<()> { - let _mem = ty_check_and_passed_mems(ctx, pl_expr)?; +pub(super) fn ty_check<'a>( + ctx: &'a PlExprTyCtx<'a>, + pl_expr: &'a mut PlaceExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, ()> { + let _mem = ty_check_and_passed_mems(ctx, pl_expr, arena)?; Ok(()) } -pub(super) fn ty_check_and_passed_mems( - ctx: &PlExprTyCtx, - pl_expr: &mut PlaceExpr, -) -> TyResult> { - let (mem, _) = ty_check_and_passed_mems_prvs(ctx, pl_expr)?; +pub(super) fn ty_check_and_passed_mems<'a>( + ctx: &'a PlExprTyCtx<'a>, + pl_expr: &'a mut PlaceExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, Vec>> { + let (mem, _) = ty_check_and_passed_mems_prvs(ctx, pl_expr, arena)?; Ok(mem) } // Δ; Γ ⊢ω p:τ,{ρ} // p in an ω context has type τ under Δ and Γ, passing through provenances in Vec<ρ> -fn ty_check_and_passed_mems_prvs( - ctx: &PlExprTyCtx, - pl_expr: &mut PlaceExpr, -) -> TyResult<(Vec, Vec)> { - let (ty, mem, prvs) = match &mut pl_expr.pl_expr { +fn ty_check_and_passed_mems_prvs<'a>( + ctx: &'a PlExprTyCtx<'a>, + pl_expr: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Vec>, Vec>)> { + let (ty, mem, prvs) = match &pl_expr.pl_expr { // TC-Var PlaceExprKind::Ident(ident) => ty_check_ident(ctx, ident)?, // TC-Proj - PlaceExprKind::Proj(tuple_expr, n) => ty_check_proj(ctx, tuple_expr, *n)?, + PlaceExprKind::Proj(tuple_expr, n) => ty_check_proj(ctx, *tuple_expr, *n, arena)?, // TC-Field 
PlaceExprKind::FieldProj(struct_expr, ident) => { - ty_check_field_proj(ctx, struct_expr, ident)? + ty_check_field_proj(ctx, *struct_expr, ident, arena)? } // TC-Deref - PlaceExprKind::Deref(borr_expr) => ty_check_deref(ctx, borr_expr)?, + PlaceExprKind::Deref(borr_expr) => ty_check_deref(ctx, *borr_expr, arena)?, // TC-Select - PlaceExprKind::Select(p, select_exec) => ty_check_select(ctx, p, select_exec)?, - PlaceExprKind::View(pl_expr, view) => ty_check_view_pl_expr(ctx, pl_expr, view)?, - PlaceExprKind::Idx(pl_expr, idx) => ty_check_index(ctx, pl_expr, idx)?, + PlaceExprKind::Select(pl_expr, select_exec) => { + ty_check_select(ctx, *pl_expr, *select_exec, arena)? + } + PlaceExprKind::View(pl_expr, view) => ty_check_view_pl_expr(ctx, *pl_expr, *view, arena)?, + PlaceExprKind::Idx(pl_expr, idx) => ty_check_index(ctx, *pl_expr, *idx, arena)?, }; - pl_expr.ty = Some(Box::new(ty)); + + let _ = pl_expr.ty.set(arena.alloc(ty)); Ok((mem, prvs)) } -fn ty_check_view_pl_expr( - ctx: &PlExprTyCtx, - pl_expr: &mut PlaceExpr, - view: &mut View, -) -> TyResult<(Ty, Vec, Vec)> { - let (mems, prvs) = ty_check_and_passed_mems_prvs(ctx, pl_expr)?; - let view_fn_ty = ty_check_view(ctx, view)?; - let in_dty = pl_expr.ty.as_ref().unwrap().dty().clone(); - let (res_dty, constr_map) = ty_check_app_view_fn_ty(ctx, &in_dty, view_fn_ty)?; - // substitute implicit identifiers in view, that were inferred from the input data type - unify::substitute(&constr_map, view); - Ok((Ty::new(TyKind::Data(Box::new(res_dty))), mems, prvs)) +fn ty_check_view_pl_expr<'a>( + ctx: &'a PlExprTyCtx<'a>, + pl_expr: &'a PlaceExpr<'a>, + view: &'a View<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let (mems, prvs) = ty_check_and_passed_mems_prvs(ctx, pl_expr, arena)?; + let mut v_tmp = (*view).clone_in(arena); + let view_fn_ty = ty_check_view(ctx, &mut v_tmp, arena)?; + let in_dty_ref: &'a DataTy<'a> = { + let tmp = pl_expr.ty.get().unwrap().dty().clone_in(arena); + arena.alloc(tmp) + }; 
+ let (res_dty, constr_map) = ty_check_app_view_fn_ty(ctx, in_dty_ref, view_fn_ty, arena)?; + unify::substitute(&constr_map, &mut v_tmp, arena); + Ok((Ty::new(TyKind::Data(arena.alloc(res_dty))), mems, prvs)) } -fn ty_check_app_view_fn_ty( - ctx: &PlExprTyCtx, - in_dty: &DataTy, - mut view_fn_ty: FnTy, -) -> TyResult<(DataTy, ConstrainMap)> { - let mut arg_dty_fn_ty = FnTy::new( - vec![], +fn ty_check_app_view_fn_ty<'a>( + ctx: &'a PlExprTyCtx<'a>, + in_dty: &'a DataTy<'a>, + view_fn_ty: FnTy<'a>, + arena: &'a Bump, +) -> TyResult<'a, (DataTy<'a>, ConstrainMap<'a>)> { + let arg_dty_fn_ty = FnTy::new( + arena, + std::iter::empty(), None, - vec![ParamSig::new( + [ParamSig::new( ctx.exec.clone(), - Ty::new(TyKind::Data(Box::new(in_dty.clone()))), + arena.alloc(Ty::new(TyKind::Data(arena.alloc(in_dty.clone_in(arena))))), )], ctx.exec.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ident( - Ident::new_impli("ret_dty"), + arena.alloc(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, + DataTyKind::Ident(Ident::new_impli(arena, "ret_dty")), ))))), - vec![], + std::iter::empty(), ); - let (constr_map, _) = unify::constrain(&mut arg_dty_fn_ty, &mut view_fn_ty)?; - let res_dty = arg_dty_fn_ty.ret_ty.dty().clone(); + + let arg_ptr = arena.alloc(arg_dty_fn_ty); + let view_ptr = arena.alloc(view_fn_ty); + + let (constr_map, _prv) = unify::constrain(arg_ptr, view_ptr, arena)?; + + let mut res_dty = DataTy::new(arena, DataTyKind::Ident(Ident::new_impli(arena, "ret_dty"))); + unify::substitute(&constr_map, &mut res_dty, arena); + Ok((res_dty, constr_map)) } -fn ty_check_view(ctx: &PlExprTyCtx, view: &mut View) -> TyResult { - let arg_tys = view - .args - .iter_mut() - .map(|v| Ok(Ty::new(TyKind::FnTy(Box::new(ty_check_view(ctx, v)?))))) - .collect::>>()?; - let view_fn_ty = ctx.gl_ctx.fn_ty_by_ident(&view.name)?; - let partially_applied_view_fn_ty = super::apply_gen_args_to_fn_ty_checked( +fn ty_check_view<'a, 'm>( + ctx: &'a PlExprTyCtx<'a>, + view: &'m mut 
View<'a>, + arena: &'a Bump, +) -> TyResult<'a, FnTy<'a>> { + let mut arg_tys = BumpVec::new_in(arena); + for v in view.args.iter_mut() { + let inner = ty_check_view(ctx, v, arena)?; + arg_tys.push(Ty::new(TyKind::FnTy(arena.alloc(inner)))); + } + + let name_ref: &'a Ident<'a> = arena.alloc(view.name.clone()); + let view_fn_ty = ctx.gl_ctx.fn_ty_by_ident(name_ref)?; + + let gen_args_ref = { + let mut tmp = BumpVec::new_in(arena); + for ga in view.gen_args.iter() { + tmp.push(ga.clone_in(arena)); + } + arena.alloc(tmp) + }; + + let partially_applied_view_fn_ty = arena.alloc(super::apply_gen_args_to_fn_ty_checked( ctx.kind_ctx, &ctx.exec, view_fn_ty, - &view.gen_args, + gen_args_ref, + arena, + )?); + + let actual_view_fn_ty = arena.alloc(create_view_ty_with_input_view_and_free_ret( + &ctx.exec, arg_tys, arena, + )); + + let mono_fn_ty = arena.alloc(unify::inst_fn_ty_scheme( + partially_applied_view_fn_ty, + arena, + )); + + let (constr_map, _) = unify::constrain(actual_view_fn_ty, mono_fn_ty, arena)?; + + unify::substitute(&constr_map, view, arena); + + let inferred_k_args = super::infer_kinded_args::infer_kinded_args( + partially_applied_view_fn_ty, + mono_fn_ty, + arena, )?; - let mut actual_view_fn_ty = create_view_ty_with_input_view_and_free_ret(&ctx.exec, arg_tys); - let mut mono_fn_ty = unify::inst_fn_ty_scheme(&partially_applied_view_fn_ty); - let (constr_map, _) = unify::constrain(&mut actual_view_fn_ty, &mut mono_fn_ty)?; - // substitute implicit identifiers in view, that were inferred without the input data type - unify::substitute(&constr_map, view); - let mut inferred_k_args = - super::infer_kinded_args::infer_kinded_args(&partially_applied_view_fn_ty, &mono_fn_ty)?; - view.gen_args.append(&mut inferred_k_args); + view.gen_args.extend(inferred_k_args.into_iter()); + let res_view_ty = FnTy::new( - vec![], + arena, + std::iter::empty(), actual_view_fn_ty.generic_exec.clone(), - vec![actual_view_fn_ty.param_sigs.pop().unwrap()], + 
std::iter::once(actual_view_fn_ty.param_sigs.pop().expect("param exists")), actual_view_fn_ty.exec.clone(), - actual_view_fn_ty.ret_ty.as_ref().clone(), - vec![], + actual_view_fn_ty.ret_ty, // &'a Ty<'a> + std::iter::empty(), ); + Ok(res_view_ty) } -fn create_view_ty_with_input_view_and_free_ret(exec: &ExecExpr, mut arg_tys: Vec) -> FnTy { - arg_tys.push(Ty::new(TyKind::Data(Box::new(DataTy::new( - utils::fresh_ident("in_view_dty", DataTyKind::Ident), +fn create_view_ty_with_input_view_and_free_ret<'a>( + exec: &ExecExpr<'a>, + mut arg_tys: BumpVec<'a, Ty<'a>>, + arena: &'a Bump, +) -> FnTy<'a> { + arg_tys.push(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, + utils::fresh_ident(arena, "in_view_dty", DataTyKind::Ident), + ))))); + + let mut param_sigs = BumpVec::new_in(arena); + for ty in arg_tys.into_iter() { + let ty_ref: &'a Ty<'a> = arena.alloc(ty); + param_sigs.push(ParamSig::new(exec.clone(), ty_ref)); + } + + let ret_ty_ref: &'a Ty<'a> = arena.alloc(Ty::new(TyKind::Data(arena.alloc(DataTy::new( + arena, + utils::fresh_ident(arena, "view_out_dty", DataTyKind::Ident), ))))); + FnTy::new( - vec![], + arena, + std::iter::empty(), None, - arg_tys - .into_iter() - .map(|ty| ParamSig::new(exec.clone(), ty)) - .collect(), + param_sigs, exec.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(utils::fresh_ident( - "view_out_dty", - DataTyKind::Ident, - ))))), - vec![], + ret_ty_ref, + std::iter::empty(), ) } -fn ty_check_ident( - ctx: &PlExprTyCtx, - ident: &Ident, -) -> TyResult<(Ty, Vec, Vec)> { +fn ty_check_ident<'a>( + ctx: &'a PlExprTyCtx<'a>, + ident: &'a Ident<'a>, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { // if let Ok(tty) = ctx.ty_ctx.ty_of_ident(ident) { let tty = ctx.ty_ctx.ty_of_ident(ident)?; if !&tty.is_fully_alive() { @@ -214,7 +280,7 @@ fn ty_check_ident( // } } -fn default_mem_by_exec(exec_ty: &ExecTyKind) -> Option { +fn default_mem_by_exec<'a>(exec_ty: &'a ExecTyKind<'a>) -> Option> { match exec_ty { ExecTyKind::CpuThread => 
Some(Memory::CpuMem), ExecTyKind::GpuThread => Some(Memory::GpuLocal), @@ -230,13 +296,14 @@ fn default_mem_by_exec(exec_ty: &ExecTyKind) -> Option { } // TODO refactor by fusing with ty_check_field_proj -fn ty_check_proj( - ctx: &PlExprTyCtx, - tuple_expr: &mut PlaceExpr, +fn ty_check_proj<'a>( + ctx: &'a PlExprTyCtx<'a>, + tuple_expr: &'a PlaceExpr<'a>, n: usize, -) -> TyResult<(Ty, Vec, Vec)> { - let (mem, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, tuple_expr)?; - let tuple_dty = match &tuple_expr.ty.as_ref().unwrap().ty { + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let (mem, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, tuple_expr, arena)?; + let tuple_dty = match &tuple_expr.ty.get().unwrap().ty { TyKind::Data(dty) => dty, ty_kind => { return Err(TyError::ExpectedTupleType( @@ -249,7 +316,7 @@ fn ty_check_proj( DataTyKind::Tuple(elem_dtys) => { if let Some(dty) = elem_dtys.get(n) { Ok(( - Ty::new(TyKind::Data(Box::new(dty.clone()))), + Ty::new(TyKind::Data(arena.alloc(dty.clone()))), mem, passed_prvs, )) @@ -260,19 +327,20 @@ fn ty_check_proj( } } dty_kind => Err(TyError::ExpectedTupleType( - TyKind::Data(Box::new(DataTy::new(dty_kind.clone()))), + TyKind::Data(arena.alloc(DataTy::new(arena, dty_kind.clone()))), tuple_expr.clone(), )), } } -fn ty_check_field_proj( - ctx: &PlExprTyCtx, - struct_expr: &mut PlaceExpr, - ident: &Ident, -) -> TyResult<(Ty, Vec, Vec)> { - let (mem, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, struct_expr)?; - let struct_dty = match &struct_expr.ty.as_ref().unwrap().ty { +fn ty_check_field_proj<'a>( + ctx: &'a PlExprTyCtx<'a>, + struct_expr: &'a PlaceExpr<'a>, + ident: &'a Ident<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let (mem, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, struct_expr, arena)?; + let struct_dty = match &struct_expr.ty.get().unwrap().ty { TyKind::Data(dty) => dty, ty_kind => { return Err(TyError::ExpectedTupleType( @@ -286,7 +354,7 @@ fn 
ty_check_field_proj( DataTyKind::Struct(struct_decl) => { if let Some(field) = struct_decl.fields.iter().find(|f| &f.0 == ident) { Ok(( - Ty::new(TyKind::Data(Box::new(field.1.clone()))), + Ty::new(TyKind::Data(arena.alloc(field.1.clone()))), mem, passed_prvs, )) @@ -297,18 +365,19 @@ fn ty_check_field_proj( } } dty_kind => Err(TyError::ExpectedTupleType( - TyKind::Data(Box::new(DataTy::new(dty_kind.clone()))), + TyKind::Data(arena.alloc(DataTy::new(arena, dty_kind.clone()))), struct_expr.clone(), )), } } -fn ty_check_deref( - ctx: &PlExprTyCtx, - borr_expr: &mut PlaceExpr, -) -> TyResult<(Ty, Vec, Vec)> { - let (mut inner_mem, mut passed_prvs) = ty_check_and_passed_mems_prvs(ctx, borr_expr)?; - let borr_dty = if let TyKind::Data(dty) = &borr_expr.ty.as_ref().unwrap().ty { +fn ty_check_deref<'a>( + ctx: &'a PlExprTyCtx<'a>, + borr_expr: &'a PlaceExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let (mut inner_mem, mut passed_prvs) = ty_check_and_passed_mems_prvs(ctx, borr_expr, arena)?; + let borr_dty = if let TyKind::Data(dty) = &borr_expr.ty.get().unwrap().ty { dty } else { return Err(TyError::String( @@ -325,7 +394,7 @@ fn ty_check_deref( passed_prvs.push(reff.rgn.clone()); inner_mem.push(reff.mem.clone()); Ok(( - Ty::new(TyKind::Data(Box::new(reff.dty.as_ref().clone()))), + Ty::new(TyKind::Data(arena.alloc(reff.dty.clone()))), inner_mem, passed_prvs, )) @@ -333,7 +402,7 @@ fn ty_check_deref( DataTyKind::RawPtr(dty) => { // TODO is anything of this correct? 
Ok(( - Ty::new(TyKind::Data(Box::new(dty.as_ref().clone()))), + Ty::new(TyKind::Data(arena.alloc(dty.clone()))), inner_mem, passed_prvs, )) @@ -344,12 +413,20 @@ fn ty_check_deref( } } -fn ty_check_select( - ctx: &PlExprTyCtx, - p: &mut PlaceExpr, - select_exec: &mut ExecExpr, -) -> TyResult<(Ty, Vec, Vec)> { - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, select_exec)?; +fn ty_check_select<'a>( + ctx: &'a PlExprTyCtx<'a>, + p: &'a PlaceExpr<'a>, + select_exec: &'a ExecExpr<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let mut select_exec_cloned = select_exec.clone_in(arena); + exec::ty_check( + ctx.nat_ctx, + ctx.ty_ctx, + ctx.ident_exec, + &mut select_exec_cloned, + arena, + )?; // FIXME this check is required for uniq accesses, but not for shared accesses because there // the duplication of accesses is fine. Move this check into ownership/borrow checking? // if &ctx.exec != select_exec { @@ -357,9 +434,15 @@ fn ty_check_select( // "Trying select memory for illegal combination of excution resources.".to_string(), // )); // } - let mut outer_exec = select_exec.remove_last_distrib(); - exec::ty_check(ctx.nat_ctx, ctx.ty_ctx, ctx.ident_exec, &mut outer_exec)?; - let outer_ctx = PlExprTyCtx { + let mut outer_exec = select_exec.remove_last_distrib(arena); + exec::ty_check( + ctx.nat_ctx, + ctx.ty_ctx, + ctx.ident_exec, + &mut outer_exec, + arena, + )?; + let outer_ctx = arena.alloc(PlExprTyCtx { gl_ctx: ctx.gl_ctx, nat_ctx: ctx.nat_ctx, kind_ctx: ctx.kind_ctx, @@ -368,43 +451,45 @@ fn ty_check_select( ty_ctx: ctx.ty_ctx, exec_borrow_ctx: ctx.exec_borrow_ctx, own: ctx.own, - }; - let (mems, prvs) = ty_check_and_passed_mems_prvs(&outer_ctx, p)?; - let mut p_dty = p.ty.as_ref().unwrap().dty().clone(); + }); + let (mems, prvs) = ty_check_and_passed_mems_prvs(outer_ctx, p, arena)?; + let mut p_dty = p.ty.get().unwrap().dty().clone_in(arena); match p_dty.dty { - DataTyKind::Array(elem_dty, n) | DataTyKind::ArrayShape(elem_dty, n) => { + 
DataTyKind::Array(elem_dty, _n) | DataTyKind::ArrayShape(elem_dty, _n) => { // TODO check sizes // if n != distrib_exec.active_distrib_size() { // return Err(TyError::String("There must be as many elements in the view // as there exist execution resources that select from it.".to_string())); // } - p_dty = *elem_dty; + p_dty = (*elem_dty).clone_in(arena); } _ => { return Err(TyError::String("Expected an array or view.".to_string())); } } - Ok((Ty::new(TyKind::Data(Box::new(p_dty))), mems, prvs)) + Ok((Ty::new(TyKind::Data(arena.alloc(p_dty))), mems, prvs)) } -fn ty_check_index( - ctx: &PlExprTyCtx, - pl_expr: &mut PlaceExpr, - idx: &mut Nat, -) -> TyResult<(Ty, Vec, Vec)> { - let (mems, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, pl_expr)?; - let pl_expr_dty = if let TyKind::Data(dty) = &pl_expr.ty.as_ref().unwrap().ty { +fn ty_check_index<'a>( + ctx: &'a PlExprTyCtx<'a>, + pl_expr: &'a PlaceExpr<'a>, + idx: &'a Nat<'a>, + arena: &'a Bump, +) -> TyResult<'a, (Ty<'a>, Vec>, Vec>)> { + let (mems, passed_prvs) = ty_check_and_passed_mems_prvs(ctx, pl_expr, arena)?; + + let pl_expr_dty = if let TyKind::Data(dty) = &pl_expr.ty.get().unwrap().ty { dty } else { return Err(TyError::String( "Trying to index into non array type.".to_string(), )); }; - let (elem_dty, n) = match pl_expr_dty.dty.clone() { + let (elem_dty_ref, n_ref): (&'a DataTy<'a>, &'a Nat<'a>) = match &pl_expr_dty.dty { DataTyKind::Array(elem_dty, n) | DataTyKind::ArrayShape(elem_dty, n) => (*elem_dty, n), DataTyKind::At(arr_dty, _) => { if let DataTyKind::Array(elem_ty, n) = &arr_dty.dty { - (elem_ty.as_ref().clone(), n.clone()) + (*elem_ty, n) } else { return Err(TyError::String( "Trying to index into non array type.".to_string(), @@ -418,11 +503,17 @@ fn ty_check_index( } }; - if n.eval(ctx.nat_ctx)? <= idx.eval(ctx.nat_ctx)? { + if n_ref.eval(ctx.nat_ctx)? <= idx.eval(ctx.nat_ctx)? 
{ return Err(TyError::String( "Trying to access array out-of-bounds.".to_string(), )); } - Ok((Ty::new(TyKind::Data(Box::new(elem_dty))), mems, passed_prvs)) + let elem_dty_owned = elem_dty_ref.clone_in(arena); + + Ok(( + Ty::new(TyKind::Data(arena.alloc(elem_dty_owned))), + mems, + passed_prvs, + )) } diff --git a/src/ty_check/pre_decl.rs b/src/ty_check/pre_decl.rs index 82f17996..d5ef6787 100644 --- a/src/ty_check/pre_decl.rs +++ b/src/ty_check/pre_decl.rs @@ -1,9 +1,11 @@ -use crate::ast::{ +use crate::arena_ast::{ AtomicTy, BaseExec, BinOpNat, DataTy, DataTyKind, DimCompo, ExecExpr, ExecExprKind, ExecTy, ExecTyKind, FnTy, Ident, IdentExec, IdentKinded, Kind, Memory, Nat, NatConstr, Ownership, ParamSig, Provenance, RefDty, ScalarTy, Ty, TyKind, }; +use bumpalo::{collections::Vec as BumpVec, Bump}; + pub static GPU_DEVICE: &str = "gpu_device"; pub static GPU_ALLOC: &str = "gpu_alloc_copy"; pub static COPY_TO_HOST: &str = "copy_to_host"; @@ -41,49 +43,95 @@ pub static TAKE_RIGHT: &str = "take_right"; pub static SELECT_RANGE: &str = "select_range"; pub static MAP: &str = "map"; -pub fn fun_decls() -> Vec<(&'static str, FnTy)> { - let decls = [ - // Built-in functions - (GPU_DEVICE, gpu_device_ty()), - (GPU_ALLOC, gpu_alloc_copy_ty()), - (COPY_TO_HOST, copy_to_host_ty()), - (COPY_TO_GPU, copy_to_gpu_ty()), - (CREATE_ARRAY, create_array_ty()), - (TO_RAW_PTR, to_raw_ptr_ty()), - (OFFSET_RAW_PTR, offset_raw_ptr_ty()), - (SHFL_SYNC, shfl_sync_ty()), - (SHFL_UP, shfl_up_ty()), - (BALLOT_SYNC, ballot_sync_ty()), - (THREAD_ID_X, thread_id_x_ty()), - (GET_WARP_ID, get_warp_id_ty()), - (GET_LANE_ID, get_lane_id_ty()), - (NAT_AS_U64, nat_as_u64_ty()), - // Built-in atomic functions - (ATOMIC_STORE, atomic_store_ty()), - (ATOMIC_LOAD, atomic_load_ty()), - (ATOMIC_FETCH_OR, atomic_fetch_or_ty()), - (ATOMIC_FETCH_ADD, atomic_fetch_add_ty()), - (ATOMIC_MIN, atomic_min_ty()), - (TO_ATOMIC_ARRAY, to_atomic_array_ty()), - (TO_ATOMIC, to_atomic_ty()), - // View constructors - 
(TO_VIEW, to_view_ty()), - (REVERSE, reverse_ty()), - (MAP, map_ty()), - (GROUP, group_ty()), - (JOIN, join_ty()), - (TRANSPOSE, transpose_ty()), - // (TAKE_LEFT, take_ty(TakeSide::Left)), - // (TAKE_RIGHT, take_ty(TakeSide::Right)), - (SELECT_RANGE, select_range_ty()), - ]; - - decls.to_vec() +pub fn fun_decls<'a>(arena: &'a Bump) -> BumpVec<'a, (&'static str, FnTy<'a>)> { + let mut decls = BumpVec::new_in(arena); + + decls.push((GPU_DEVICE, gpu_device_ty(arena))); + decls.push((GPU_ALLOC, gpu_alloc_copy_ty(arena))); + decls.push((COPY_TO_HOST, copy_to_host_ty(arena))); + decls.push((COPY_TO_GPU, copy_to_gpu_ty(arena))); + decls.push((CREATE_ARRAY, create_array_ty(arena))); + decls.push((TO_RAW_PTR, to_raw_ptr_ty(arena))); + decls.push((OFFSET_RAW_PTR, offset_raw_ptr_ty(arena))); + decls.push((SHFL_SYNC, shfl_sync_ty(arena))); + decls.push((SHFL_UP, shfl_up_ty(arena))); + decls.push((BALLOT_SYNC, ballot_sync_ty(arena))); + decls.push((THREAD_ID_X, thread_id_x_ty(arena))); + decls.push((GET_WARP_ID, get_warp_id_ty(arena))); + decls.push((GET_LANE_ID, get_lane_id_ty(arena))); + decls.push((NAT_AS_U64, nat_as_u64_ty(arena))); + decls.push((ATOMIC_STORE, atomic_store_ty(arena))); + decls.push((ATOMIC_LOAD, atomic_load_ty(arena))); + decls.push((ATOMIC_FETCH_OR, atomic_fetch_or_ty(arena))); + decls.push((ATOMIC_FETCH_ADD, atomic_fetch_add_ty(arena))); + decls.push((ATOMIC_MIN, atomic_min_ty(arena))); + decls.push((TO_ATOMIC_ARRAY, to_atomic_array_ty(arena))); + decls.push((TO_ATOMIC, to_atomic_ty(arena))); + decls.push((TO_VIEW, to_view_ty(arena))); + decls.push((REVERSE, reverse_ty(arena))); + decls.push((MAP, map_ty(arena))); + decls.push((GROUP, group_ty(arena))); + decls.push((JOIN, join_ty(arena))); + decls.push((TRANSPOSE, transpose_ty(arena))); + decls.push((SELECT_RANGE, select_range_ty(arena))); + + decls +} + +// DataTy helpers +fn d_ident<'a>(id: Ident<'a>, arena: &'a Bump) -> DataTy<'a> { + DataTy::new(arena, DataTyKind::Ident(id)) +} +fn d_scalar<'a>(s: 
ScalarTy, arena: &'a Bump) -> DataTy<'a> { + DataTy::new(arena, DataTyKind::Scalar(s)) +} +fn d_atomic<'a>(a: AtomicTy, arena: &'a Bump) -> DataTy<'a> { + DataTy::new(arena, DataTyKind::Atomic(a)) +} + +fn d_array<'a>(arena: &'a Bump, elem: DataTy<'a>, n: Nat<'a>) -> DataTy<'a> { + let elem_ref = arena.alloc(elem); + DataTy::new(arena, DataTyKind::Array(elem_ref, n)) +} +fn d_array_shape<'a>(arena: &'a Bump, elem: DataTy<'a>, n: Nat<'a>) -> DataTy<'a> { + let elem_ref = arena.alloc(elem); + DataTy::new(arena, DataTyKind::ArrayShape(elem_ref, n)) +} +fn d_ref<'a>( + arena: &'a Bump, + prv: Provenance<'a>, + own: Ownership, + mem: Memory<'a>, + ty: DataTy<'a>, +) -> DataTy<'a> { + let reff = arena.alloc(RefDty::new(arena, prv, own, mem, ty)); + DataTy::new(arena, DataTyKind::Ref(reff)) +} + +// Ty helpers +fn ty_data_ref<'a>(arena: &'a Bump, d: DataTy<'a>) -> &'a Ty<'a> { + let data_ty = arena.alloc(d); + arena.alloc(Ty { + ty: TyKind::Data(data_ty), + span: None, + }) +} + +// Exec helpers +fn exec_ident<'a>(arena: &'a Bump, id: Ident<'a>) -> ExecExpr<'a> { + let kind = arena.alloc(ExecExprKind::new(arena, BaseExec::Ident(id))); + ExecExpr { + exec: kind, + ty: None, + span: None, + } } -fn create_array_ty() -> FnTy { - let n = Ident::new("n"); - let d = Ident::new("d"); +fn create_array_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // (d) -[Any]-> [d; n] + let n = Ident::new(arena, "n"); + let d = Ident::new(arena, "d"); + let n_nat = IdentKinded { ident: n.clone(), kind: Kind::Nat, @@ -92,23 +140,26 @@ fn create_array_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // param: d + let param_ty = ty_data_ref(arena, 
d_ident(d.clone(), arena)); + + // return: [d; n] + let ret_dt = d_array(arena, d_ident(d, arena), Nat::Ident(n)); + let ret_ty = ty_data_ref(arena, ret_dt); + FnTy::new( - vec![n_nat, d_dty], + arena, + [n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ident( - d.clone(), - ))))), - )], + [ParamSig::new(exec_expr.clone(), param_ty)], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Array( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(n), - ))))), - vec![], + ret_ty, + [], ) } @@ -116,10 +167,11 @@ fn create_array_ty() -> FnTy { // ( // &r gpu.thread uniq m t // ) -[gpu.thread]-> RawPtr -fn to_raw_ptr_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); - let d = Ident::new("d"); +fn to_raw_ptr_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // (&r uniq m d) -[gpu.thread]-> RawPtr + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let d = Ident::new(arena, "d"); let r_prv = IdentKinded { ident: r.clone(), @@ -133,27 +185,37 @@ fn to_raw_ptr_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // param: &r uniq m d + let param_dt = d_ref( + arena, + Provenance::Ident(r.clone()), + Ownership::Uniq, + Memory::Ident(m.clone()), + d_ident(d.clone(), arena), + ); + let param_ty = ty_data_ref(arena, param_dt); + + // return: RawPtr + let ret_inner = d_ident(d, arena); + let ret_dt = DataTy::new(arena, DataTyKind::RawPtr(arena.alloc(ret_inner))); + let ret_ty = ty_data_ref(arena, ret_dt); + FnTy::new( - vec![r_prv, m_mem, d_dty], + arena, + [r_prv, m_mem, d_dty], 
Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Uniq, - Memory::Ident(m), - DataTy::new(DataTyKind::Ident(d.clone())), - )), - ))))), - )], + [ParamSig::new(exec_expr.clone(), param_ty)], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::RawPtr( - Box::new(DataTy::new(DataTyKind::Ident(d))), - ))))), - vec![], + ret_ty, + [], ) } @@ -161,233 +223,275 @@ fn to_raw_ptr_ty() -> FnTy { // ( // RawPtr, i32 // ) -[gpu.thread]-> RawPtr -fn offset_raw_ptr_ty() -> FnTy { - let d = Ident::new("d"); +// (RawPtr, i32) -[gpu.thread]-> RawPtr +fn offset_raw_ptr_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let d = Ident::new(arena, "d"); let d_dty = IdentKinded { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // param 1: RawPtr + let p1_dt = DataTy::new( + arena, + DataTyKind::RawPtr(arena.alloc(d_ident(d.clone(), arena))), + ); + let p1_ty = ty_data_ref(arena, p1_dt); + + // param 2: i32 + let p2_ty = ty_data_ref(arena, d_scalar(ScalarTy::I32, arena)); + + // return: RawPtr + let ret_dt = DataTy::new(arena, DataTyKind::RawPtr(arena.alloc(d_ident(d, arena)))); + let ret_ty = ty_data_ref(arena, ret_dt); + FnTy::new( - vec![d_dty], + arena, + [d_dty], Some(ident_exec), - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::RawPtr( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), - ), + [ + 
ParamSig::new(exec_expr.clone(), p1_ty), + ParamSig::new(exec_expr.clone(), p2_ty), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::RawPtr( - Box::new(DataTy::new(DataTyKind::Ident(d))), - ))))), - vec![], + ret_ty, + [], ) } // ballot_sync: // <>(bool) -[w: gpu.warp]-> u32 -fn ballot_sync_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("w"), ExecTy::new(ExecTyKind::GpuWarp)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); - let param_exec = ExecExpr::new(exec_expr.exec.clone().forall(DimCompo::X)); +// <>(bool) -[w: gpu.warp]-> u32 +fn ballot_sync_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "w"), + ExecTy::new(ExecTyKind::GpuWarp), + ); + + // body exec: w + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // param exec: w.forall(x) + let lane_kind = + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())).forall(DimCompo::X); + let param_exec = ExecExpr { + exec: arena.alloc(lane_kind), + ty: None, + span: None, + }; + + // param: bool + let p_ty = ty_data_ref(arena, d_scalar(ScalarTy::Bool, arena)); + // return: u32 + let r_ty = ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)); FnTy::new( - vec![], + arena, + [], Some(ident_exec), - vec![ParamSig::new( - param_exec, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Bool, - ))))), - )], + [ParamSig::new(param_exec, p_ty)], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + r_ty, + [], ) } // FIXME warp should have the type given in the comment below // shfl_sync: -// (u32, u32) -[w.forall]-> u32 -fn shfl_sync_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("w"), ExecTy::new(ExecTyKind::GpuWarp)); - let exec_expr_lane = ExecExpr::new( - ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone())).forall(DimCompo::X), +// (u32, u32) -[w.forall]-> 
u32 +fn shfl_sync_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generic exec: w : gpu.warp + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "w"), + ExecTy::new(ExecTyKind::GpuWarp), ); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // param exec = w.forall(X) + let lane_kind = + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())).forall(DimCompo::X); + let param_exec = ExecExpr { + exec: arena.alloc(lane_kind), + ty: None, + span: None, + }; + + // body exec = w + let body_exec = exec_ident(arena, ident_exec.ident.clone()); + + // types + let u32_ty = ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)); FnTy::new( - vec![], - Some(ident_exec), - vec![ - ParamSig::new( - exec_expr_lane.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - ), - ParamSig::new( - exec_expr_lane.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - ), + arena, + [], // generics (kinded Idents) + Some(ident_exec), // generic exec + [ + ParamSig::new(param_exec.clone(), u32_ty), + ParamSig::new(param_exec, u32_ty), ], - exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + body_exec, + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), + [], // nat constraints ) } // shfl_up: -// <>(u32, i32) -[gpu.warp]-> u32 -fn shfl_up_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuWarp)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); +// <>(u32, i32) -[gpu.warp]-> u32 +fn shfl_up_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generic exec: ex : gpu.warp + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuWarp), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // types + let u32_ty = ty_data_ref(arena, 
d_scalar(ScalarTy::U32, arena)); + let i32_ty = ty_data_ref(arena, d_scalar(ScalarTy::I32, arena)); FnTy::new( - vec![], + arena, + [], // no kinded generics Some(ident_exec), - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), u32_ty), + ParamSig::new(exec_expr.clone(), i32_ty), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), + [], // no nat constraints ) } // nat_as_u64: -// () -[view]-> u64 -fn nat_as_u64_ty() -> FnTy { - let n = Ident::new("n"); +// () -[Any]-> u64 +fn nat_as_u64_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generic nat parameter + let n = Ident::new(arena, "n"); let n_nat = IdentKinded { ident: n, kind: Kind::Nat, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // execution level: Any, carried via an exec identifier "ex" + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); FnTy::new( - vec![n_nat], - Some(ident_exec), - vec![], - exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U64, - ))))), - vec![], + arena, + [n_nat], // generics + Some(ident_exec), // generic exec + [], // params + exec_expr, // function exec + ty_data_ref(arena, d_scalar(ScalarTy::U64, arena)), // return type + [], // nat constraints ) } -// get_warp_id: <>() -[w: gpu.Warp]-> u32 -fn get_warp_id_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("w"), ExecTy::new(ExecTyKind::GpuWarp)); - let exec_expr = 
ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); +// get_warp_id: +// <>() -[w: gpu.warp]-> u32 +fn get_warp_id_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generic exec: w : gpu.warp + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "w"), + ExecTy::new(ExecTyKind::GpuWarp), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); FnTy::new( - vec![], - Some(ident_exec), - vec![], - exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + arena, + [], // no kinded generics + Some(ident_exec), // generic exec + [], // params + exec_expr, // function exec + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), // return type + [], // nat constraints ) } -// get_lane_id: <>() -[w: gpu.Thread]-> u32 -fn get_lane_id_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("t"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); +// get_lane_id: <>() -[t: gpu.thread]-> u32 +fn get_lane_id_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "t"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); FnTy::new( - vec![], - Some(ident_exec), - vec![], - exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + arena, + [], // no kinded generics + Some(ident_exec), // generic exec + [], // no params + exec_expr, // function exec + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), // return type + [], // no nat constraints ) } -// thread_id_x: -// <>() -[gpu.thread]-> u32 -fn thread_id_x_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); +// thread_id_x: <>() -[gpu.thread]-> u32 +fn 
thread_id_x_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); FnTy::new( - vec![], - Some(ident_exec), - vec![], - exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + arena, + [], // no kinded generics + Some(ident_exec), // generic exec + [], // no params + exec_expr, // function exec + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), // return type + [], // no nat constraints ) } // gpu: // <>(i32) -[cpu.thread]-> Gpu -fn gpu_device_ty() -> FnTy { - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::CpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); +fn gpu_device_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::CpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); FnTy::new( - vec![], + arena, + [], // no kinded generics Some(ident_exec), - vec![ParamSig::new( + [ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), + ty_data_ref(arena, d_scalar(ScalarTy::I32, arena)), )], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Gpu, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::Gpu, arena)), + [], // no nat constraints ) } // to_atomic_array: -// (ex: &r uniq m [u32; n]) -[x: Any]-> &r uniq m [AtomicU32; n] -fn to_atomic_array_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); - let n = Ident::new("n"); +// (ex: &r uniq m [u32; n]) -[ex: Any]-> &r uniq m [AtomicU32; n] +fn to_atomic_array_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let n = 
Ident::new(arena, "n"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -400,47 +504,49 @@ fn to_atomic_array_ty() -> FnTy { ident: n.clone(), kind: Kind::Nat, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r uniq m [u32; n] + let param_dty = d_ref( + arena, + Provenance::Ident(r.clone()), + Ownership::Uniq, + Memory::Ident(m.clone()), + d_array(arena, d_scalar(ScalarTy::U32, arena), Nat::Ident(n.clone())), + ); + + // &r uniq m [AtomicU32; n] + let ret_dty = d_ref( + arena, + Provenance::Ident(r), + Ownership::Uniq, + Memory::Ident(m), + d_array(arena, d_atomic(AtomicTy::AtomicU32, arena), Nat::Ident(n)), + ); FnTy::new( - vec![r_prv, m_mem, n_nat], + arena, + [r_prv, m_mem, n_nat], Some(ident_exec), - vec![ParamSig::new( + [ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r.clone()), - Ownership::Uniq, - Memory::Ident(m.clone()), - DataTy::new(DataTyKind::Array( - Box::new(DataTy::new(DataTyKind::Scalar(ScalarTy::U32))), - Nat::Ident(n.clone()), - )), - )), - ))))), + ty_data_ref(arena, param_dty), )], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Uniq, - Memory::Ident(m), - DataTy::new(DataTyKind::Array( - Box::new(DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32))), - Nat::Ident(n), - )), - )), - ))))), - vec![], + ty_data_ref(arena, ret_dty), + [], // no nat constraints ) } // to_atomic: -// (&r uniq x m u32) -[x: Any]-> &r uniq x m AtomicU32 -fn to_atomic_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +// (&r uniq m u32) -[ex: Any]-> 
&r uniq m AtomicU32 +fn to_atomic_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -449,41 +555,49 @@ fn to_atomic_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r uniq m u32 + let param_dty = d_ref( + arena, + Provenance::Ident(r.clone()), + Ownership::Uniq, + Memory::Ident(m.clone()), + d_scalar(ScalarTy::U32, arena), + ); + + // &r uniq m AtomicU32 + let ret_dty = d_ref( + arena, + Provenance::Ident(r), + Ownership::Uniq, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicU32, arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ParamSig::new( + [ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r.clone()), - Ownership::Uniq, - Memory::Ident(m.clone()), - DataTy::new(DataTyKind::Scalar(ScalarTy::U32)), - )), - ))))), + ty_data_ref(arena, param_dty), )], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Uniq, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32)), - )), - ))))), - vec![], + ty_data_ref(arena, ret_dty), + [], ) } // atomic_store: -// (&r shrd m AtomicU32, u32) -[gpu.thread]-> () -fn atomic_store_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +// (&r shrd m AtomicU32, u32) -[gpu.thread]-> () +fn atomic_store_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + 
let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -492,44 +606,46 @@ fn atomic_store_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r shrd m AtomicU32 + let ptr_arg = d_ref( + arena, + Provenance::Ident(r), + Ownership::Shrd, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicU32, arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, ptr_arg)), ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Shrd, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32)), - )), - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), ), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Unit, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::Unit, arena)), + [], ) } // atomic_fetch_or: // (&r shrd m AtomicU32, u32) -[gpu.thread]-> u32 -fn atomic_fetch_or_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +fn atomic_fetch_or_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -538,44 +654,46 @@ fn atomic_fetch_or_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("ex"), 
ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r shrd m AtomicU32 + let ptr_arg = d_ref( + arena, + Provenance::Ident(r), + Ownership::Shrd, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicU32, arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, ptr_arg)), ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Shrd, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32)), - )), - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), ), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), + [], ) } // atomic_min: // (&r shrd m AtomicI32, i32) -[gpu.thread]-> i32 -fn atomic_min_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +fn atomic_min_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -584,44 +702,46 @@ fn atomic_min_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("t"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "t"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let 
exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r shrd m AtomicI32 + let ptr_arg = d_ref( + arena, + Provenance::Ident(r), + Ownership::Shrd, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicI32, arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Shrd, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicI32)), - )), - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, ptr_arg)), ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), + ty_data_ref(arena, d_scalar(ScalarTy::I32, arena)), ), ], - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::I32, - ))))), - vec![], + exec_expr, + ty_data_ref(arena, d_scalar(ScalarTy::I32, arena)), + [], ) } // atomic_fetch_add: // (&r shrd m AtomicU32, u32) -[gpu.thread]-> u32 -fn atomic_fetch_add_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +fn atomic_fetch_add_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -630,44 +750,46 @@ fn atomic_fetch_add_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r shrd m AtomicU32 + let ptr_arg = d_ref( + arena, + Provenance::Ident(r), + Ownership::Shrd, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicU32, 
arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, ptr_arg)), ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r), - Ownership::Shrd, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32)), - )), - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), ), ], exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), + [], ) } // atomic_load: // (&r shrd m AtomicU32) -[gpu.thread]-> u32 -fn atomic_load_ty() -> FnTy { - let r = Ident::new("r"); - let m = Ident::new("m"); +fn atomic_load_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r = Ident::new(arena, "r"); + let m = Ident::new(arena, "m"); + let r_prv = IdentKinded { ident: r.clone(), kind: Kind::Provenance, @@ -676,39 +798,46 @@ fn atomic_load_ty() -> FnTy { ident: m.clone(), kind: Kind::Memory, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::GpuThread)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = IdentExec::new_in( + arena, + Ident::new(arena, "ex"), + ExecTy::new(ExecTyKind::GpuThread), + ); + let exec_expr = exec_ident(arena, ident_exec.ident.clone()); + + // &r shrd m AtomicU32 + let ptr_arg = d_ref( + arena, + Provenance::Ident(r), + Ownership::Shrd, + Memory::Ident(m), + d_atomic(AtomicTy::AtomicU32, arena), + ); FnTy::new( - vec![r_prv, m_mem], + arena, + [r_prv, m_mem], Some(ident_exec), - vec![ParamSig::new( + [ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - 
Provenance::Ident(r), - Ownership::Shrd, - Memory::Ident(m), - DataTy::new(DataTyKind::Atomic(AtomicTy::AtomicU32)), - )), - ))))), + ty_data_ref(arena, ptr_arg), )], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::U32, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::U32, arena)), + [], ) } -// gpu_alloc: +// gpu_alloc_copy: // ( -// &r1 uniq cpu.mem Gpu, &r2 shrd cpu.mem t -// ) -[cpu.thread]-> t @ gpu.global -fn gpu_alloc_copy_ty() -> FnTy { - let r1 = Ident::new("r1"); - let r2 = Ident::new("r2"); - let d = Ident::new("d"); +// &r1 uniq cpu.mem Gpu, &r2 shrd cpu.mem d +// ) -[cpu.thread]-> d @ gpu.global +fn gpu_alloc_copy_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r1 = Ident::new(arena, "r1"); + let r2 = Ident::new(arena, "r2"); + let d = Ident::new(arena, "d"); + let r1_prv = IdentKinded { ident: r1.clone(), kind: Kind::Provenance, @@ -721,51 +850,55 @@ fn gpu_alloc_copy_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::CpuThread)); + + let exec_expr = ExecExpr::new(arena, ExecExprKind::new(arena, BaseExec::CpuThread)); + + // &r1 uniq cpu.mem Gpu + let arg0 = d_ref( + arena, + Provenance::Ident(r1), + Ownership::Uniq, + Memory::CpuMem, + d_scalar(ScalarTy::Gpu, arena), + ); + + // &r2 shrd cpu.mem d + let arg1 = d_ref( + arena, + Provenance::Ident(r2), + Ownership::Shrd, + Memory::CpuMem, + d_ident(d.clone(), arena), + ); + + // d @ gpu.global + let ret_inner = d_ident(d, arena); + let ret_dty_ref = arena.alloc(ret_inner); + let ret_at = DataTy::new(arena, DataTyKind::At(ret_dty_ref, Memory::GpuGlobal)); FnTy::new( - vec![r1_prv, r2_prv, d_dty], + arena, + [r1_prv, r2_prv, d_dty], None, - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r1), - Ownership::Uniq, - Memory::CpuMem, - DataTy::new(DataTyKind::Scalar(ScalarTy::Gpu)), - )), - ))))), - ), 
- ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r2), - Ownership::Shrd, - Memory::CpuMem, - DataTy::new(DataTyKind::Ident(d.clone())), - )), - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg0)), + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg1)), ], - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::At( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Memory::GpuGlobal, - ))))), - vec![], + exec_expr, + ty_data_ref(arena, ret_at), + [], ) } // copy_to_host: -// (&r1 shrd gpu.global d, &r2 uniq cpu.mem d) -// -[cpu.thread]-> () -fn copy_to_host_ty() -> FnTy { - let r1 = Ident::new("r1"); - let r2 = Ident::new("r2"); - let d = Ident::new("d"); +// ( +// &r1 shrd gpu.global d, &r2 uniq cpu.mem d +// ) -[cpu.thread]-> () +fn copy_to_host_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r1 = Ident::new(arena, "r1"); + let r2 = Ident::new(arena, "r2"); + let d = Ident::new(arena, "d"); + let r1_prv = IdentKinded { ident: r1.clone(), kind: Kind::Provenance, @@ -778,50 +911,49 @@ fn copy_to_host_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::CpuThread)); + + let exec_expr = ExecExpr::new(arena, ExecExprKind::new(arena, BaseExec::CpuThread)); + + // &r1 shrd gpu.global d + let arg0 = d_ref( + arena, + Provenance::Ident(r1), + Ownership::Shrd, + Memory::GpuGlobal, + d_ident(d.clone(), arena), + ); + + // &r2 uniq cpu.mem d + let arg1 = d_ref( + arena, + Provenance::Ident(r2), + Ownership::Uniq, + Memory::CpuMem, + d_ident(d, arena), + ); FnTy::new( - vec![r1_prv, r2_prv, d_dty], + arena, + [r1_prv, r2_prv, d_dty], None, - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r1), - Ownership::Shrd, - Memory::GpuGlobal, - DataTy::new(DataTyKind::Ident(d.clone())), - 
)), - ))))), - ), - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r2), - Ownership::Uniq, - Memory::CpuMem, - DataTy::new(DataTyKind::Ident(d)), - )), - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg0)), + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg1)), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Unit, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::Unit, arena)), + [], ) } // copy_to_gpu: -// (& r1 uniq gpu.global d, -// & r2 shrd cpu.mem d) -[cpu.thread]-> () -fn copy_to_gpu_ty() -> FnTy { - let r1 = Ident::new("r1"); - let r2 = Ident::new("r2"); - let d = Ident::new("d"); +// (&r1 uniq gpu.global d, &r2 shrd cpu.mem d) +// -[cpu.thread]-> () +fn copy_to_gpu_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let r1 = Ident::new(arena, "r1"); + let r2 = Ident::new(arena, "r2"); + let d = Ident::new(arena, "d"); + let r1_prv = IdentKinded { ident: r1.clone(), kind: Kind::Provenance, @@ -834,48 +966,47 @@ fn copy_to_gpu_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::CpuThread)); + + let exec_expr = ExecExpr::new(arena, ExecExprKind::new(arena, BaseExec::CpuThread)); + + // &r1 uniq gpu.global d + let arg0 = d_ref( + arena, + Provenance::Ident(r1), + Ownership::Uniq, + Memory::GpuGlobal, + d_ident(d.clone(), arena), + ); + + // &r2 shrd cpu.mem d + let arg1 = d_ref( + arena, + Provenance::Ident(r2), + Ownership::Shrd, + Memory::CpuMem, + d_ident(d, arena), + ); FnTy::new( - vec![r1_prv, r2_prv, d_dty], + arena, + [r1_prv, r2_prv, d_dty], None, - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r1), - Ownership::Uniq, - Memory::GpuGlobal, - DataTy::new(DataTyKind::Ident(d.clone())), - )), - ))))), - ), - ParamSig::new( - 
exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ref( - Box::new(RefDty::new( - Provenance::Ident(r2), - Ownership::Shrd, - Memory::CpuMem, - DataTy::new(DataTyKind::Ident(d)), - )), - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg0)), + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arg1)), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Scalar( - ScalarTy::Unit, - ))))), - vec![], + ty_data_ref(arena, d_scalar(ScalarTy::Unit, arena)), + [], ) } // to_view: -// ([d; n]) -[view]-> [[d; n]] -fn to_view_ty() -> FnTy { - let n = Ident::new("n"); - let d = Ident::new("d"); +// ([d; n]) -[Any]-> [[d; n]] +fn to_view_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let n = Ident::new(arena, "n"); + let d = Ident::new(arena, "d"); + let n_nat = IdentKinded { ident: n.clone(), kind: Kind::Nat, @@ -884,68 +1015,79 @@ fn to_view_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [d; n] + let param_dty = d_array(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + + // return: [[d; n]] + let ret_dty = d_array_shape(arena, d_ident(d, arena), Nat::Ident(n)); FnTy::new( - vec![n_nat, d_dty], + arena, + [n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( + [ParamSig::new( exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Array( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - Nat::Ident(n.clone()), - ))))), + ty_data_ref(arena, param_dty), )], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - 
Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(n), - ))))), - vec![], + ty_data_ref(arena, ret_dty), + [], ) } -// rev/rev_mut: -// (&r W m [[d; n]]) -> &r W m [[d; n]] -fn reverse_ty() -> FnTy { - let n = Ident::new("n"); - let d = Ident::new("d"); +// rev / rev_mut +// ([[d; n]]) -[Any]-> [[d; n]] +fn reverse_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + let n = Ident::new(arena, "n"); + let d = Ident::new(arena, "d"); + let n_nat = IdentKinded { ident: n.clone(), kind: Kind::Nat, }; - let d_ty = IdentKinded { + let d_dty = IdentKinded { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [[d; n]] + let param = d_array_shape(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + // return: [[d; n]] + let ret = d_array_shape(arena, d_ident(d, arena), Nat::Ident(n)); FnTy::new( - vec![n_nat, d_ty], + arena, + [n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - Nat::Ident(n.clone()), - ))))), - )], + [ParamSig::new(exec_expr.clone(), ty_data_ref(arena, param))], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(n), - ))))), - vec![], + ty_data_ref(arena, ret), + [], ) } -//map_mut:(|d| -[ex]-> d2, [[d;n]]) -[ex: Any]-> [[d2; n]] -fn map_ty() -> FnTy { - let d = Ident::new("d"); - let d2 = Ident::new("d2"); - let n = Ident::new("n"); +// map_mut: +// (|d| -[ex]-> d2, [[d; n]]) -[ex: Any]-> [[d2; n]] +fn map_ty<'a>(arena: &'a 
Bump) -> FnTy<'a> { + let d = Ident::new(arena, "d"); + let d2 = Ident::new(arena, "d2"); + let n = Ident::new(arena, "n"); + let d_dty = IdentKinded { ident: d.clone(), kind: Kind::DataTy, @@ -958,54 +1100,69 @@ fn map_ty() -> FnTy { ident: n.clone(), kind: Kind::Nat, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // Build the inner function type: (d) -[ex]-> d2 + let inner_param_ty = ty_data_ref(arena, d_ident(d.clone(), arena)); + let inner_ret_ty = ty_data_ref(arena, d_ident(d2.clone(), arena)); + + let inner_fn = FnTy::new( + arena, + [], // no generics + None, // no generic exec + [ParamSig::new(exec_expr.clone(), inner_param_ty)], + exec_expr.clone(), + inner_ret_ty, + [], // no nat constraints + ); + let inner_fn_ref: &'a FnTy<'a> = arena.alloc(inner_fn); + let inner_fn_ty_ref: &'a Ty<'a> = arena.alloc(Ty { + ty: TyKind::FnTy(inner_fn_ref), + span: None, + }); + + // Second param: [[d; n]] + let arr_param = d_array_shape(arena, d_ident(d, arena), Nat::Ident(n.clone())); + + // Return type: [[d2; n]] + let ret = d_array_shape(arena, d_ident(d2, arena), Nat::Ident(n)); FnTy::new( - vec![d_dty, d2_dty, n_nat], + arena, + [d_dty, d2_dty, n_nat], Some(ident_exec), - vec![ - ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::FnTy(Box::new(FnTy::new( - vec![], - None, - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ident( - d.clone(), - ))))), - )], - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::Ident( - d2.clone(), - ))))), - vec![], - )))), - ), - ParamSig::new( - exec_expr.clone(), - 
Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(n.clone()), - ))))), - ), + [ + ParamSig::new(exec_expr.clone(), inner_fn_ty_ref), + ParamSig::new(exec_expr.clone(), ty_data_ref(arena, arr_param)), ], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d2))), - Nat::Ident(n), - ))))), - vec![], + ty_data_ref(arena, ret), + [], ) } +// Small Nat helper (arena-allocates both operands) +#[inline] +fn n_binop<'a>(arena: &'a Bump, op: BinOpNat, l: Nat<'a>, r: Nat<'a>) -> Nat<'a> { + let l_ref = arena.alloc(l); + let r_ref = arena.alloc(r); + Nat::BinOp(op, l_ref, r_ref) +} + // group/group_mut: -// ([[d; n]]) -> [[ [[d; size]]; n/size ]] -fn group_ty() -> FnTy { - let s = Ident::new("s"); - let n = Ident::new("n"); - let d = Ident::new("d"); +// ([[d; n]]) -[Any]-> [[ [[d; s]]; n/s ]] +fn group_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generics + let s = Ident::new(arena, "s"); + let n = Ident::new(arena, "n"); + let d = Ident::new(arena, "d"); + let s_nat = IdentKinded { ident: s.clone(), kind: Kind::Nat, @@ -1014,43 +1171,48 @@ fn group_ty() -> FnTy { ident: n.clone(), kind: Kind::Nat, }; - let d_ty = IdentKinded { + let d_dty = IdentKinded { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // exec + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [[d; n]] + let param = d_array_shape(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + + // inner element: [[d; s]] + let inner = d_array_shape(arena, d_ident(d, arena), Nat::Ident(s.clone())); + + // outer size: n / s 
(arena-allocated operands) + let n_div_s = n_binop( + arena, + BinOpNat::Div, + Nat::Ident(n.clone()), + Nat::Ident(s.clone()), + ); + + // return: [[ [[d; s]]; n/s ]] + let ret = d_array_shape(arena, inner, n_div_s); + + // constraint: (n % s) == 0 (all pieces arena-allocated) + let n_mod_s = n_binop(arena, BinOpNat::Mod, Nat::Ident(n), Nat::Ident(s)); + let constr = NatConstr::Eq(arena.alloc(n_mod_s), arena.alloc(Nat::Lit(0))); FnTy::new( - vec![s_nat, n_nat, d_ty], + arena, + [s_nat, n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - Nat::Ident(n.clone()), - ))))), - )], + [ParamSig::new(exec_expr.clone(), ty_data_ref(arena, param))], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(s.clone()), - ))), - Nat::BinOp( - BinOpNat::Div, - Box::new(Nat::Ident(n.clone())), - Box::new(Nat::Ident(s.clone())), - ), - ))))), - vec![NatConstr::Eq( - Box::new(Nat::BinOp( - BinOpNat::Mod, - Box::new(Nat::Ident(n)), - Box::new(Nat::Ident(s)), - )), - Box::new(Nat::Lit(0)), - )], + ty_data_ref(arena, ret), + [constr], ) } @@ -1112,11 +1274,13 @@ pub enum TakeSide { // } // select: ([[ d; n ]]) -[a: any]-> [[ d; u-l ]] -fn select_range_ty() -> FnTy { - let l = Ident::new("l"); - let u = Ident::new("u"); - let n = Ident::new("n"); - let d = Ident::new("d"); +fn select_range_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generics + let l = Ident::new(arena, "l"); + let u = Ident::new(arena, "u"); + let n = Ident::new(arena, "n"); + let d = Ident::new(arena, "d"); + let l_nat = IdentKinded { ident: l.clone(), kind: Kind::Nat, @@ -1129,48 +1293,52 @@ fn select_range_ty() -> FnTy { ident: n.clone(), kind: Kind::Nat, }; - let d_ty = IdentKinded { + let d_dty = IdentKinded { ident: d.clone(), kind: 
Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // exec + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [[d; n]] + let param = d_array_shape(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + + // return: [[d; u - l]] + let u_minus_l = n_binop( + arena, + BinOpNat::Sub, + Nat::Ident(u.clone()), + Nat::Ident(l.clone()), + ); + let ret = d_array_shape(arena, d_ident(d, arena), u_minus_l); + + // constraints: l < u AND (u < n OR u == n) + let c_lt_lu = NatConstr::Lt( + arena.alloc(Nat::Ident(l.clone())), + arena.alloc(Nat::Ident(u.clone())), + ); + let c_lt_un = NatConstr::Lt( + arena.alloc(Nat::Ident(u.clone())), + arena.alloc(Nat::Ident(n.clone())), + ); + let c_eq_un = NatConstr::Eq(arena.alloc(Nat::Ident(u)), arena.alloc(Nat::Ident(n))); + let c_or = NatConstr::Or(arena.alloc(c_lt_un), arena.alloc(c_eq_un)); + let constr = NatConstr::And(arena.alloc(c_lt_lu), arena.alloc(c_or)); FnTy::new( - vec![l_nat, u_nat, n_nat, d_ty], + arena, + [l_nat, u_nat, n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - Nat::Ident(n.clone()), - ))))), - )], + [ParamSig::new(exec_expr.clone(), ty_data_ref(arena, param))], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::BinOp( - BinOpNat::Sub, - Box::new(Nat::Ident(u.clone())), - Box::new(Nat::Ident(l.clone())), - ), - ))))), - vec![NatConstr::And( - Box::new(NatConstr::Lt( - Box::new(Nat::Ident(l)), - Box::new(Nat::Ident(u.clone())), - )), - Box::new(NatConstr::Or( 
- Box::new(NatConstr::Lt( - Box::new(Nat::Ident(u.clone())), - Box::new(Nat::Ident(n.clone())), - )), - Box::new(NatConstr::Eq( - Box::new(Nat::Ident(u)), - Box::new(Nat::Ident(n)), - )), - )), - )], + ty_data_ref(arena, ret), + [constr], ) } @@ -1193,10 +1361,12 @@ fn select_range_ty() -> FnTy { // join/join_mut: // (&r W m [[ [[d; n]]; o]]) -> [[d; n*o]] -fn join_ty() -> FnTy { - let n = Ident::new("n"); - let o = Ident::new("o"); - let d = Ident::new("d"); +fn join_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generics + let n = Ident::new(arena, "n"); + let o = Ident::new(arena, "o"); + let d = Ident::new(arena, "d"); + let n_nat = IdentKinded { ident: n.clone(), kind: Kind::Nat, @@ -1209,77 +1379,78 @@ fn join_ty() -> FnTy { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // exec + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [[ [[d; n]]; o ]] + let inner = d_array_shape(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + let param = d_array_shape(arena, inner, Nat::Ident(o.clone())); + + // return: [[d; n * o]] + let n_mul_o = n_binop(arena, BinOpNat::Mul, Nat::Ident(n), Nat::Ident(o)); + let ret = d_array_shape(arena, d_ident(d, arena), n_mul_o); FnTy::new( - vec![o_nat, n_nat, d_dty], + arena, + [o_nat, n_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - Nat::Ident(n.clone()), - ))), - Nat::Ident(o.clone()), - ))))), - )], + [ParamSig::new(exec_expr.clone(), ty_data_ref(arena, param))], exec_expr, - 
Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::BinOp( - BinOpNat::Mul, - Box::new(Nat::Ident(n)), - Box::new(Nat::Ident(o)), - ), - ))))), - vec![], + ty_data_ref(arena, ret), + [], ) } // transpose: // (&r W m [[ [[d; n]]; o]]) -> &r W m [[ [[d; o]]; n]] -fn transpose_ty() -> FnTy { - let n = Ident::new("n"); - let o = Ident::new("o"); - let d = Ident::new("d"); - let o_nat = IdentKinded { - ident: o.clone(), - kind: Kind::Nat, - }; +fn transpose_ty<'a>(arena: &'a Bump) -> FnTy<'a> { + // generics + let n = Ident::new(arena, "n"); + let o = Ident::new(arena, "o"); + let d = Ident::new(arena, "d"); + let n_nat = IdentKinded { ident: n.clone(), kind: Kind::Nat, }; - let d_ty = IdentKinded { + let o_nat = IdentKinded { + ident: o.clone(), + kind: Kind::Nat, + }; + let d_dty = IdentKinded { ident: d.clone(), kind: Kind::DataTy, }; - let ident_exec = IdentExec::new(Ident::new("ex"), ExecTy::new(ExecTyKind::Any)); - let exec_expr = ExecExpr::new(ExecExprKind::new(BaseExec::Ident(ident_exec.ident.clone()))); + + // exec + let ident_exec = + IdentExec::new_in(arena, Ident::new(arena, "ex"), ExecTy::new(ExecTyKind::Any)); + let exec_expr = ExecExpr::new( + arena, + ExecExprKind::new(arena, BaseExec::Ident(ident_exec.ident.clone())), + ); + + // param: [[ [[d; n]]; o ]] + let inner_param = d_array_shape(arena, d_ident(d.clone(), arena), Nat::Ident(n.clone())); + let param = d_array_shape(arena, inner_param, Nat::Ident(o.clone())); + + // return: [[ [[d; o]]; n ]] + let inner_ret = d_array_shape(arena, d_ident(d, arena), Nat::Ident(o)); + let ret = d_array_shape(arena, inner_ret, Nat::Ident(n)); FnTy::new( - vec![n_nat, o_nat, d_ty], + arena, + [n_nat, o_nat, d_dty], Some(ident_exec), - vec![ParamSig::new( - exec_expr.clone(), - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d.clone()))), - 
Nat::Ident(n.clone()), - ))), - Nat::Ident(o.clone()), - ))))), - )], + [ParamSig::new(exec_expr.clone(), ty_data_ref(arena, param))], exec_expr, - Ty::new(TyKind::Data(Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::ArrayShape( - Box::new(DataTy::new(DataTyKind::Ident(d))), - Nat::Ident(o), - ))), - Nat::Ident(n), - ))))), - vec![], + ty_data_ref(arena, ret), + [], ) } diff --git a/src/ty_check/subty.rs b/src/ty_check/subty.rs index 343095fa..e6a90f7a 100644 --- a/src/ty_check/subty.rs +++ b/src/ty_check/subty.rs @@ -1,25 +1,26 @@ use super::ctxs::{KindCtx, TyCtx}; -use crate::ast::internal::Loan; +use crate::arena_ast::internal::Loan; // // Subtyping and Provenance Subtyping from Oxide // use super::error::{CtxError, SubTyError}; -use crate::ast::*; -use std::collections::HashSet; +use crate::arena_ast::*; +use bumpalo::Bump; -type SubTyResult = Result; +type SubTyResult<'a, T> = Result>; -// FIXME respect memory alaways, somehow provenances can be different is this correct? +// FIXME respect memory always, somehow provenances can be different is this correct? 
// τ1 is subtype of τ2 under Δ and Γ, producing Γ′ // Δ; Γ ⊢ τ1 ≲ τ2 ⇒ Γ′ -pub(super) fn check( - kind_ctx: &KindCtx, - ty_ctx: &mut TyCtx, - sub_dty: &DataTy, - super_dty: &DataTy, -) -> SubTyResult<()> { +pub(super) fn check<'a>( + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &mut TyCtx<'a>, + sub_dty: &'a DataTy<'a>, + super_dty: &'a DataTy<'a>, + arena: &'a Bump, +) -> SubTyResult<'a, ()> { use super::Ownership::*; use DataTyKind::*; @@ -27,17 +28,17 @@ pub(super) fn check( // Δ; Γ ⊢ τ ≲ τ ⇒ Γ (sub, sup) if sub == sup => Ok(()), // Δ; Γ ⊢ [τ 1 ; n] ≲ [τ2 ; n] ⇒ Γ′ - (Array(sub_elem_ty, sub_size), Array(sup_elem_ty, sup_size)) - | (ArrayShape(sub_elem_ty, sub_size), ArrayShape(sup_elem_ty, sup_size)) => { - check(kind_ctx, ty_ctx, sub_elem_ty, sup_elem_ty) + (Array(sub_elem_ty, _sub_size), Array(sup_elem_ty, _sup_size)) + | (ArrayShape(sub_elem_ty, _sub_size), ArrayShape(sup_elem_ty, _sup_size)) => { + check(kind_ctx, ty_ctx, sub_elem_ty, sup_elem_ty, arena) } // Δ; Γ ⊢ &B ρ1 shrd τ1 ≲ &B ρ2 shrd τ2 ⇒ Γ′′ (Ref(lref), Ref(rref)) if lref.own == Shrd && rref.own == Shrd => { - outlives(kind_ctx, ty_ctx, &lref.rgn, &rref.rgn)?; + outlives(kind_ctx, ty_ctx, &lref.rgn, &rref.rgn, arena)?; if lref.mem != rref.mem { return Err(SubTyError::MemoryKindsNoMatch); } - check(kind_ctx, ty_ctx, lref.dty.as_ref(), rref.dty.as_ref()) + check(kind_ctx, ty_ctx, lref.dty, rref.dty, arena) } // Δ; Γ ⊢ &B ρ1 uniq τ1 ≲ &B ρ2 uniq τ2 ⇒ Γ'' (Ref(lref), Ref(rref)) => { @@ -47,18 +48,18 @@ pub(super) fn check( if lref.mem != rref.mem { return Err(SubTyError::MemoryKindsNoMatch); } - outlives(kind_ctx, ty_ctx, &lref.rgn, &rref.rgn)?; - check(kind_ctx, ty_ctx, lref.dty.as_ref(), rref.dty.as_ref()) + outlives(kind_ctx, ty_ctx, &lref.rgn, &rref.rgn, arena)?; + check(kind_ctx, ty_ctx, lref.dty, rref.dty, arena) } // Δ; Γ ⊢ (τ1, ..., τn) ≲ (τ1′, ..., τn′) ⇒ Γn (Tuple(sub_elems), Tuple(sup_elems)) => { for (sub, sup) in sub_elems.iter().zip(sup_elems) { - check(kind_ctx, ty_ctx, sub, sup)?; + 
check(kind_ctx, ty_ctx, sub, sup, arena)?; } Ok(()) } // Δ; Γ ⊢ \delta1 ≲ †\delta2 ⇒ Γ - (_, Dead(sup)) => check(kind_ctx, ty_ctx, sub_dty, sup), + (_, Dead(sup)) => check(kind_ctx, ty_ctx, sub_dty, sup, arena), //TODO add case for Transitiviy? // Δ; Γ ⊢ τ1 ≲ τ3 ⇒ Γ'' (sub, sup) => panic!( @@ -70,12 +71,13 @@ pub(super) fn check( // ρ1 outlives ρ2 under Δ and Γ, producing Γ′ // Δ; Γ ⊢ ρ1 :> ρ2 ⇒ Γ′ -fn outlives( - kind_ctx: &KindCtx, - ty_ctx: &mut TyCtx, - longer_prv: &Provenance, - shorter_prv: &Provenance, -) -> SubTyResult<()> { +pub(super) fn outlives<'a>( + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &mut TyCtx<'a>, + longer_prv: &'a Provenance<'a>, + shorter_prv: &'a Provenance<'a>, + arena: &'a Bump, +) -> SubTyResult<'a, ()> { use Provenance::*; match (longer_prv, shorter_prv) { @@ -97,9 +99,9 @@ fn outlives( Ok(()) } // OL-LocalProvenances - (Value(longer), Value(shorter)) => outl_check_val_prvs(ty_ctx, longer, shorter), + (Value(longer), Value(shorter)) => outl_check_val_prvs(ty_ctx, longer, shorter, arena), // OL-LocalProvAbsProv - (Value(longer_val), Ident(_)) => outl_check_val_ident_prv(ty_ctx, longer_val), + (Value(longer_val), Ident(_)) => outl_check_val_ident_prv(ty_ctx, longer_val, arena), // OL-AbsProvLocalProv (Ident(longer_ident), Value(shorter_val)) => { outl_check_ident_val_prv(kind_ctx, ty_ctx, longer_ident, shorter_val) @@ -109,13 +111,18 @@ fn outlives( // OL-LocalProvenances // Δ; Γ ⊢ r1 :> r2 ⇒ Γ[r2 ↦→ { Γ(r1) ∪ Γ(r2) }] -fn outl_check_val_prvs(ty_ctx: &mut TyCtx, longer: &str, shorter: &str) -> SubTyResult<()> { +fn outl_check_val_prvs<'tcx, 'a>( + ty_ctx: &'tcx mut TyCtx<'a>, + longer: &str, + shorter: &str, + arena: &'a Bump, +) -> SubTyResult<'a, ()> { // CHECK: // NOT CLEAR WHY a. IS NECESSARY // a. for every variable of reference type with r1 in ty_ctx: there must not exist a loan // dereferencing the variable for any provenance in ty_ctx. 
- if exists_deref_loan_with_prv(ty_ctx, longer) { + if exists_deref_loan_with_prv(ty_ctx, longer, arena) { // TODO better error msg return Err(SubTyError::Dummy); } @@ -134,7 +141,11 @@ fn outl_check_val_prvs(ty_ctx: &mut TyCtx, longer: &str, shorter: &str) -> SubTy Ok(()) } -fn longer_occurs_before_shorter(ty_ctx: &TyCtx, longer: &str, shorter: &str) -> bool { +fn longer_occurs_before_shorter<'tcx, 'a>( + ty_ctx: &'tcx TyCtx<'a>, + longer: &str, + shorter: &str, +) -> bool { for prv in ty_ctx .prv_mappings() .map(|prv_mappings| prv_mappings.prv.clone()) @@ -148,54 +159,65 @@ fn longer_occurs_before_shorter(ty_ctx: &TyCtx, longer: &str, shorter: &str) -> panic!("Neither provenance found in typing context") } -fn exists_deref_loan_with_prv(ty_ctx: &TyCtx, prv: &str) -> bool { +fn exists_deref_loan_with_prv<'a>(ty_ctx: &TyCtx<'a>, prv: &str, arena: &'a bumpalo::Bump) -> bool { ty_ctx - .all_places() + .all_places(arena) .into_iter() .filter(|(_, dty)| match &dty.dty { DataTyKind::Ref(reff) => match &reff.rgn { - Provenance::Value(prv_name) => prv_name == prv, + Provenance::Value(prv_name) => *prv_name == prv, _ => false, }, _ => false, }) - .any(|(place, _)| { - ty_ctx.prv_mappings().into_iter().any(|prv_mapping| { - for loan in prv_mapping.loans.iter() { - if let PlaceExprKind::Deref(pl_expr) = &loan.place_expr.pl_expr { - return pl_expr.equiv(&place); - } - } - false + .any(|(place_owned, _)| { + let place_ref = arena.alloc(place_owned); + ty_ctx.prv_mappings().into_iter().any(|pm| { + pm.loans.iter().any(|loan| match &loan.place_expr.pl_expr { + PlaceExprKind::Deref(pl_expr) => pl_expr.equiv(arena, place_ref), + _ => false, + }) }) }) } -fn outl_check_val_ident_prv(ty_ctx: &TyCtx, longer_val: &str) -> SubTyResult<()> { +fn outl_check_val_ident_prv<'a>( + ty_ctx: &TyCtx<'a>, + longer_val: &str, + arena: &'a Bump, +) -> SubTyResult<'a, ()> { // TODO how could the set ever be empty? 
- let loan_set = ty_ctx.loans_in_prv(longer_val)?; - if loan_set.is_empty() { + let loan_snapshot = arena.alloc(ty_ctx.loans_in_prv_snapshot(longer_val, arena)?); + if loan_snapshot.is_empty() { return Err(SubTyError::PrvNotUsedInBorrow(longer_val.to_string())); } - borrowed_pl_expr_no_ref_to_existing_pl(ty_ctx, loan_set); + borrowed_pl_expr_no_ref_to_existing_pl(ty_ctx, loan_snapshot.as_slice(), arena); panic!("Not yet implemented.") } // FIXME Makes no sense! -fn borrowed_pl_expr_no_ref_to_existing_pl(ty_ctx: &TyCtx, loan_set: &HashSet) -> bool { - ty_ctx - .all_places() - .iter() - .any(|(pl, _)| loan_set.iter().any(|loan| loan.place_expr.equiv(pl))) +fn borrowed_pl_expr_no_ref_to_existing_pl<'a>( + ty_ctx: &TyCtx<'a>, + loans: &'a [Loan<'a>], + arena: &'a bumpalo::Bump, +) -> bool { + let places = ty_ctx.all_places(arena); + + places.into_iter().any(|(pl_owned, _)| { + let pl_ref: &'a internal::Place<'a> = arena.alloc(pl_owned); + loans + .iter() + .any(|loan| loan.place_expr.equiv(arena, pl_ref)) + }) } -fn outl_check_ident_val_prv( - kind_ctx: &KindCtx, - ty_ctx: &TyCtx, - longer_ident: &Ident, +fn outl_check_ident_val_prv<'tcx, 'a>( + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &'tcx TyCtx<'a>, + longer_ident: &'a Ident<'a>, shorter_val: &str, -) -> SubTyResult<()> { +) -> SubTyResult<'a, ()> { if !kind_ctx.ident_of_kind_exists(longer_ident, Kind::Provenance) { return Err(SubTyError::CtxError(CtxError::PrvIdentNotFound( longer_ident.clone(), @@ -211,16 +233,16 @@ fn outl_check_ident_val_prv( // Δ; Γ ⊢ List[ρ1 :> ρ2] ⇒ Γ′ pub(super) fn multiple_outlives<'a, I>( - kind_ctx: &KindCtx, - ty_ctx: &mut TyCtx, + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &'a mut TyCtx<'a>, prv_rels: I, -) -> SubTyResult<()> + arena: &'a Bump, +) -> SubTyResult<'a, ()> where - I: IntoIterator, + I: IntoIterator, &'a Provenance<'a>)>, { - for prv_rel in prv_rels { - let (longer, shorter) = prv_rel; - outlives(kind_ctx, ty_ctx, longer, shorter)?; + for (p1, p2) in prv_rels { + outlives(kind_ctx, 
ty_ctx, p1, p2, arena)?; } Ok(()) } diff --git a/src/ty_check/unify.rs b/src/ty_check/unify.rs index 5b469759..ed574f91 100644 --- a/src/ty_check/unify.rs +++ b/src/ty_check/unify.rs @@ -1,45 +1,59 @@ -use crate::ast::utils; -use crate::ast::utils::Visitable; -use crate::ast::visit_mut::VisitMut; -use crate::ast::*; +use crate::arena_ast::utils; +use crate::arena_ast::utils::Visitable; +use crate::arena_ast::visit_mut::VisitMut; +use crate::arena_ast::*; use crate::ty_check::ctxs::{KindCtx, TyCtx}; use crate::ty_check::error::UnifyError; use crate::ty_check::subty; +use bumpalo::{collections::Vec as BumpVec, Bump}; use std::collections::HashMap; -type UnifyResult = Result; +#[inline] +fn arena_slice<'a, T>(v: &bumpalo::collections::Vec<'a, T>) -> &'a [T] { + unsafe { std::slice::from_raw_parts(v.as_ptr(), v.len()) } +} + +type UnifyResult<'a, T> = Result>; -pub(super) fn unify(t1: &mut C, t2: &mut C) -> UnifyResult<()> { - let (_, _) = constrain(t1, t2)?; +pub(super) fn unify<'a, C: Constrainable<'a>>( + t1: &'a mut C, + t2: &'a mut C, + arena: &'a Bump, +) -> UnifyResult<'a, ()> { + let (_, _) = constrain(t1, t2, arena)?; Ok(()) } -pub(super) fn sub_unify( - kind_ctx: &KindCtx, - ty_ctx: &mut TyCtx, - sub: &mut C, - sup: &mut C, -) -> UnifyResult<()> { - let (_, prv_rels) = constrain(sub, sup)?; +pub(super) fn sub_unify<'a, C: Constrainable<'a>>( + kind_ctx: &'a KindCtx<'a>, + ty_ctx: &'a mut TyCtx<'a>, + sub: &'a mut C, + sup: &'a mut C, + arena: &'a Bump, +) -> UnifyResult<'a, ()> { + let (_map, prv) = constrain(sub, sup, arena)?; subty::multiple_outlives( kind_ctx, ty_ctx, - prv_rels.iter().map(|PrvConstr(p1, p2)| (p1, p2)), + prv.iter().map(|PrvConstr(p1, p2)| (*p1, *p2)), + arena, )?; Ok(()) } -pub(super) fn constrain( - t1: &mut S, - t2: &mut S, -) -> UnifyResult<(ConstrainMap, Vec)> { +pub(super) fn constrain<'a, 'm, S: Constrainable<'a>>( + t1: &'m mut S, + t2: &'m mut S, + arena: &'a Bump, +) -> UnifyResult<'a, (ConstrainMap<'a>, BumpVec<'a, 
PrvConstr<'a>>)> { let mut constr_map = ConstrainMap::new(); - let mut prv_rels = Vec::new(); - t1.constrain(t2, &mut constr_map, &mut prv_rels)?; + let mut prv_rels = BumpVec::new_in(arena); + t1.constrain(t2, &mut constr_map, &mut prv_rels, arena)?; Ok((constr_map, prv_rels)) } -pub(super) fn inst_fn_ty_scheme(fn_ty: &FnTy) -> FnTy { +/** +pub(super) fn inst_fn_ty_scheme<'a>(fn_ty: &'a FnTy<'a>, arena: &'a Bump) -> FnTy<'a> { assert!( fn_ty.generic_exec.is_none(), "exec must be substituted before instantiation to make sure that it has the correct type" @@ -48,39 +62,86 @@ pub(super) fn inst_fn_ty_scheme(fn_ty: &FnTy) -> FnTy { .generics .iter() .map(|i| match i.kind { - Kind::DataTy => ArgKinded::DataTy(DataTy::new(utils::fresh_ident( - &i.ident.name, - DataTyKind::Ident, - ))), - Kind::Nat => ArgKinded::Nat(utils::fresh_ident(&i.ident.name, Nat::Ident)), - Kind::Memory => ArgKinded::Memory(utils::fresh_ident(&i.ident.name, Memory::Ident)), + Kind::DataTy => ArgKinded::DataTy(DataTy::new( + arena, + utils::fresh_ident(arena, &i.ident.name, DataTyKind::Ident), + )), + Kind::Nat => ArgKinded::Nat(utils::fresh_ident(arena, &i.ident.name, Nat::Ident)), + Kind::Memory => { + ArgKinded::Memory(utils::fresh_ident(arena, &i.ident.name, Memory::Ident)) + } Kind::Provenance => { - ArgKinded::Provenance(utils::fresh_ident(&i.ident.name, Provenance::Ident)) + ArgKinded::Provenance(utils::fresh_ident(arena, &i.ident.name, Provenance::Ident)) } }) .collect(); + let mut inst_fn_ty = fn_ty.clone(); let generics = inst_fn_ty.generics.drain(..).collect::>(); - utils::subst_idents_kinded(generics.iter(), mono_idents.iter(), &mut inst_fn_ty); + utils::subst_idents_kinded(arena, generics.iter(), mono_idents.iter(), &mut inst_fn_ty); inst_fn_ty } +*/ -#[derive(Debug, PartialEq, Eq, Clone)] -pub(super) struct PrvConstr(pub Provenance, pub Provenance); +pub(super) fn inst_fn_ty_scheme<'a>(fn_ty: &'a FnTy<'a>, arena: &'a bumpalo::Bump) -> FnTy<'a> { + assert!( + 
fn_ty.generic_exec.is_none(), + "exec must be substituted before instantiation to make sure that it has the correct type" + ); + + // 1) Build arena-resident args (so the *elements* live for 'a) + let mut mono_args = bumpalo::collections::Vec::new_in(arena); + for g in fn_ty.generics.iter() { + let arg = match g.kind { + Kind::DataTy => { + let dk = utils::fresh_ident(arena, &g.ident.name, |id| DataTyKind::Ident(id)); + ArgKinded::DataTy(DataTy::new(arena, dk)) + } + Kind::Nat => ArgKinded::Nat(utils::fresh_ident(arena, &g.ident.name, |id| { + Nat::Ident(id) + })), + Kind::Memory => ArgKinded::Memory(utils::fresh_ident(arena, &g.ident.name, |id| { + Memory::Ident(id) + })), + Kind::Provenance => { + ArgKinded::Provenance(utils::fresh_ident(arena, &g.ident.name, |id| { + Provenance::Ident(id) + })) + } + }; + mono_args.push(arg); + } + + // 2) Convert to a slice with lifetime 'a (elements are arena-owned) + let args_a: &'a [ArgKinded<'a>] = arena_slice(&mono_args); + + // 3) Make a working copy (whatever “clone” mechanism you have) + let mut inst = fn_ty.clone(); // or your `clone_in(arena)`/rebuild + + // 4) Substitute: domain = original generics (already 'a), codomain = args_a + utils::subst_idents_kinded(arena, fn_ty.generics.iter(), args_a.iter(), &mut inst); + + // 5) Monomorphic result has no generics/exec param + inst.generics.clear(); + inst.generic_exec = None; + inst +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub(super) struct PrvConstr<'a>(pub &'a Provenance<'a>, pub &'a Provenance<'a>); #[derive(Debug)] -pub(super) struct ConstrainMap { - // TODO swap Box for something more abstract, like Symbol or Identifier - pub dty_unifier: HashMap, DataTy>, - pub nat_unifier: HashMap, Nat>, - pub mem_unifier: HashMap, Memory>, - pub prv_unifier: HashMap, Provenance>, - pub exec_unifier: HashMap, ExecExpr>, +pub(super) struct ConstrainMap<'a> { + pub dty_unifier: HashMap<&'a str, DataTy<'a>>, + pub nat_unifier: HashMap<&'a str, Nat<'a>>, + pub mem_unifier: 
HashMap<&'a str, Memory<'a>>, + pub prv_unifier: HashMap<&'a str, Provenance<'a>>, + pub exec_unifier: HashMap<&'a str, ExecExpr<'a>>, } -impl ConstrainMap { +impl<'a> ConstrainMap<'a> { fn new() -> Self { - ConstrainMap { + Self { dty_unifier: HashMap::new(), nat_unifier: HashMap::new(), mem_unifier: HashMap::new(), @@ -90,8 +151,13 @@ impl ConstrainMap { } } -impl DataTy { - fn bind_to(&self, ident: &Ident, constr_map: &mut ConstrainMap) -> UnifyResult<()> { +impl<'a> DataTy<'a> { + fn bind_to( + &self, + ident: &'a Ident<'a>, + constr_map: &mut ConstrainMap<'a>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { if let DataTyKind::Ident(ty_id) = &self.dty { if ty_id == ident { return Ok(()); @@ -113,45 +179,51 @@ impl DataTy { ); } } + + let term_ref: &'a DataTy<'a> = arena.alloc(self.clone_in(arena)); constr_map .dty_unifier .values_mut() - .for_each(|dty| SubstIdent::new(ident, self).visit_dty(dty)); + .for_each(|dty| SubstIdent::new(ident, term_ref).visit_dty(arena, dty)); Ok(()) } } -pub(super) trait Substitutable { - fn substitute(&mut self, subst: &ConstrainMap); +pub(super) trait Substitutable<'a> { + fn substitute<'s>(&mut self, subst: &'s ConstrainMap<'a>, arena: &'a Bump); } -pub(super) trait Constrainable: Visitable + Substitutable { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()>; - fn occurs_check(ident_kinded: &IdentKinded, s: &S) -> bool { +pub(super) trait Constrainable<'a>: Visitable<'a> + Substitutable<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()>; + + fn occurs_check>(ident_kinded: &IdentKinded<'a>, s: &S) -> bool { utils::free_kinded_idents(s).contains(ident_kinded) } } -impl Constrainable for FnTy { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> 
UnifyResult<()> { +impl<'a> Constrainable<'a> for FnTy<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { assert!(self.generics.is_empty()); assert!(other.generics.is_empty()); assert!(self.generic_exec.is_none()); assert!(other.generic_exec.is_none()); - self.exec.constrain(&mut other.exec, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + self.exec + .constrain(&mut other.exec, constr_map, prv_rels, arena)?; + substitute(&*constr_map, self, arena); + substitute(&*constr_map, other, arena); if self.param_sigs.len() != other.param_sigs.len() { return Err(UnifyError::CannotUnify); @@ -164,189 +236,240 @@ impl Constrainable for FnTy { while let (Some((next_lhs, _)), Some((next_rhs, _))) = (remain_lhs.split_first_mut(), remain_rhs.split_first_mut()) { - next_lhs.constrain(next_rhs, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + next_lhs.constrain(next_rhs, constr_map, prv_rels, arena)?; + substitute(&*constr_map, self, arena); + substitute(&*constr_map, other, arena); i += 1; remain_lhs = &mut self.param_sigs[i..]; remain_rhs = &mut other.param_sigs[i..]; } - self.ret_ty - .constrain(&mut other.ret_ty, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + + // self.ret_ty.constrain(&mut other.ret_ty, constr_map, prv_rels, arena)?; + + let mut lhs_ret = self.ret_ty.clone_in(arena); + let mut rhs_ret = other.ret_ty.clone_in(arena); + + lhs_ret.constrain(&mut rhs_ret, constr_map, prv_rels, arena)?; + + self.ret_ty = arena.alloc(lhs_ret); + other.ret_ty = arena.alloc(rhs_ret); + + substitute(&*constr_map, self, arena); + substitute(&*constr_map, other, arena); Ok(()) } } -impl Substitutable for FnTy { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for FnTy<'a> { + fn 
substitute<'s>(&mut self, subst: &'s ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_fn_ty(self); + apply_subst.visit_fn_ty(arena, self); } } -impl Constrainable for ParamSig { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for ParamSig<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { self.exec_expr - .constrain(&mut other.exec_expr, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); - self.ty.constrain(&mut other.ty, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + .constrain(&mut other.exec_expr, constr_map, prv_rels, arena)?; + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); + + { + let mut lhs = self.ty.clone_in(arena); + let mut rhs = other.ty.clone_in(arena); + + lhs.constrain(&mut rhs, constr_map, prv_rels, arena)?; + + self.ty = arena.alloc(lhs); + other.ty = arena.alloc(rhs); + } + + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); Ok(()) } } -impl Substitutable for ParamSig { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for ParamSig<'a> { + fn substitute<'s>(&mut self, subst: &'s ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_param_sig(self); + apply_subst.visit_param_sig(arena, self); } } // TODO unification for exec expressions necessary for Nats? Can this be moved into a separate // equality check? 
-impl Constrainable for ExecExpr { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { - match (&mut self.exec.base, &mut other.exec.base) { - (BaseExec::Ident(i1), BaseExec::Ident(i2)) => { - assert!( - !i1.is_implicit, - "Implicit identifier for exec expression should not exist" - ); +impl<'a> Constrainable<'a> for ExecExpr<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { + use BaseExec as BE; + use ExecPathElem as EPE; + + match (&self.exec.base, &other.exec.base) { + (BE::Ident(i1), BE::Ident(i2)) => { assert!( - !i2.is_implicit, + !i1.is_implicit && !i2.is_implicit, "Implicit identifier for exec expression should not exist" ); - if i1 == i2 { - return Ok(()); - } else { + if i1 != i2 { return Err(UnifyError::CannotUnify); } } - (BaseExec::CpuThread, BaseExec::CpuThread) => {} - (BaseExec::GpuGrid(gdim1, bdim1), BaseExec::GpuGrid(gdim2, bdim2)) => { - gdim1.constrain(gdim2, constr_map, prv_rels)?; - bdim1.constrain(bdim2, constr_map, prv_rels)?; + (BE::CpuThread, BE::CpuThread) => { /* ok */ } + + (BE::GpuGrid(gd1, bd1), BE::GpuGrid(gd2, bd2)) => { + let mut lg = gd1.clone_in(arena); + let mut rg = gd2.clone_in(arena); + lg.constrain(&mut rg, constr_map, prv_rels, arena)?; + + let mut lb = bd1.clone_in(arena); + let mut rb = bd2.clone_in(arena); + lb.constrain(&mut rb, constr_map, prv_rels, arena)?; } + _ => return Err(UnifyError::CannotUnify), } - let mut i = 0; - let mut remain_lhs = &mut self.exec.path[i..]; - let mut remain_rhs = &mut other.exec.path[i..]; - while let (Some((next_lhs, tail_lhs)), Some((next_rhs, tail_rhs))) = - (remain_lhs.split_first_mut(), remain_rhs.split_first_mut()) - { - tail_lhs.iter_mut().for_each(|ep| { - let mut apply_subst = ApplySubst::new(constr_map); - apply_subst.visit_exec_path_elem(ep); - }); - 
tail_rhs.iter_mut().for_each(|ep| { - let mut apply_subst = ApplySubst::new(constr_map); - apply_subst.visit_exec_path_elem(ep); - }); - - match (next_lhs, next_rhs) { - (ExecPathElem::ForAll(dl), ExecPathElem::ForAll(dr)) - | (ExecPathElem::ToThreads(dl), ExecPathElem::ToThreads(dr)) => { + let l_len = self.exec.path.len(); + let r_len = other.exec.path.len(); + if l_len != r_len { + return Err(UnifyError::CannotUnify); + } + + for i in 0..l_len { + let mut le = self.exec.path[i].clone(); + let mut re = other.exec.path[i].clone(); + + { + let mut ap = ApplySubst::new(constr_map); + ap.visit_exec_path_elem(arena, &mut le); + ap.visit_exec_path_elem(arena, &mut re); + } + + match (&mut le, &mut re) { + (EPE::ForAll(dl), EPE::ForAll(dr)) | (EPE::ToThreads(dl), EPE::ToThreads(dr)) => { if dl != dr { return Err(UnifyError::CannotUnify); } } - (ExecPathElem::TakeRange(rl), ExecPathElem::TakeRange(rr)) => { - if rl.split_dim != rr.split_dim { - return Err(UnifyError::CannotUnify); - } - if rl.left_or_right != rr.left_or_right { + + (EPE::TakeRange(rl), EPE::TakeRange(rr)) => { + if rl.split_dim != rr.split_dim || rl.left_or_right != rr.left_or_right { return Err(UnifyError::CannotUnify); } - rl.pos.constrain(&mut rr.pos, constr_map, prv_rels)? + let mut lp = rl.pos.clone_in(arena); + let mut rp = rr.pos.clone_in(arena); + lp.constrain(&mut rp, constr_map, prv_rels, arena)?; } - (ExecPathElem::ToWarps, ExecPathElem::ToWarps) => {} + + (EPE::ToWarps, EPE::ToWarps) => { /* ok */ } + _ => return Err(UnifyError::CannotUnify), } - - i += 1; - remain_lhs = &mut self.exec.path[i..]; - remain_rhs = &mut other.exec.path[i..]; } + + // Optional: normalize by applying the final substitution to the whole exprs. + // (This rebuilds nodes in the arena and updates the &-fields atomically.) 
+ substitute(constr_map, self, arena); + substitute(constr_map, other, arena); + Ok(()) } } -impl Substitutable for ExecExpr { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for ExecExpr<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_exec_expr(self); + apply_subst.visit_exec_expr(arena, self); } } -impl Constrainable for Ty { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for Ty<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { match (&mut self.ty, &mut other.ty) { (TyKind::FnTy(fn_ty1), TyKind::FnTy(fn_ty2)) => { - fn_ty1.constrain(fn_ty2, constr_map, prv_rels) + { + let mut f1 = (*fn_ty1).clone_in(arena); + let mut f2 = (*fn_ty2).clone_in(arena); + f1.constrain(&mut f2, constr_map, prv_rels, arena)?; + } + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); + Ok(()) + } + (TyKind::Data(dty1), TyKind::Data(dty2)) => { + { + let mut d1 = (*dty1).clone_in(arena); + let mut d2 = (*dty2).clone_in(arena); + d1.constrain(&mut d2, constr_map, prv_rels, arena)?; + } + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); + Ok(()) } - (TyKind::Data(dty1), TyKind::Data(dty2)) => dty1.constrain(dty2, constr_map, prv_rels), _ => Err(UnifyError::CannotUnify), } } } -impl Substitutable for Ty { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for Ty<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_ty(self); + apply_subst.visit_ty(arena, self); } } -impl Constrainable for DataTy { - fn constrain( - &mut self, - other: &mut 
Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for DataTy<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { match (&mut self.dty, &mut other.dty) { (DataTyKind::Ident(i1), DataTyKind::Ident(i2)) => { if i1.is_implicit { - other.bind_to(i1, constr_map)? + let i1_ref: &'a Ident<'a> = arena.alloc(i1.clone()); + other.bind_to(i1_ref, constr_map, arena)? } else if i2.is_implicit { - self.bind_to(i2, constr_map)? + let i2_ref: &'a Ident<'a> = arena.alloc(i2.clone()); + self.bind_to(i2_ref, constr_map, arena)? } else if i1 == i2 { return Ok(()); } else { return Err(UnifyError::CannotUnify); } - substitute(constr_map, self); - substitute(constr_map, other); + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); } (DataTyKind::Ident(i), _) if i.is_implicit => { - other.bind_to(i, constr_map)?; - substitute(constr_map, other); + let i_ref: &'a Ident<'a> = arena.alloc(i.clone()); + other.bind_to(i_ref, constr_map, arena)?; + substitute(constr_map, other, arena); } (_, DataTyKind::Ident(i)) if i.is_implicit => { - self.bind_to(i, constr_map)?; - substitute(constr_map, self); + let i_ref: &'a Ident<'a> = arena.alloc(i.clone()); + self.bind_to(i_ref, constr_map, arena)?; + substitute(constr_map, self, arena); } (DataTyKind::Scalar(sty1), DataTyKind::Scalar(sty2)) => { if sty1 != sty2 { @@ -361,28 +484,42 @@ impl Constrainable for DataTy { own: own1, mem: mem1, dty: dty1, - } = ref1.as_mut(); + } = (**ref1).clone_in(arena); + let RefDty { rgn: rgn2, own: own2, mem: mem2, dty: dty2, - } = ref2.as_mut(); + } = (**ref2).clone_in(arena); if own1 != own2 { return Err(UnifyError::CannotUnify); } - rgn1.constrain(rgn2, constr_map, prv_rels)?; - substitute(constr_map, &mut **dty1); - substitute(constr_map, &mut **dty2); - mem1.constrain(mem2, 
constr_map, prv_rels)?; - substitute(constr_map, mem1); - substitute(constr_map, mem2); - substitute(constr_map, &mut **dty1); - substitute(constr_map, &mut **dty2); - dty1.constrain(dty2, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + + let mut rgn1 = rgn1; + let mut rgn2 = rgn2; + rgn1.constrain(&mut rgn2, constr_map, prv_rels, arena)?; + + let mut dty1_mut = (*dty1).clone(); + let mut dty2_mut = (*dty2).clone(); + + substitute(constr_map, &mut dty1_mut, arena); + substitute(constr_map, &mut dty2_mut, arena); + + let mut mem1 = mem1; + let mut mem2 = mem2; + mem1.constrain(&mut mem2, constr_map, prv_rels, arena)?; + + substitute(constr_map, &mut mem1, arena); + substitute(constr_map, &mut mem2, arena); + substitute(constr_map, &mut dty1_mut, arena); + substitute(constr_map, &mut dty1_mut, arena); + + dty1_mut.constrain(&mut dty2_mut, constr_map, prv_rels, arena)?; + + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); } (DataTyKind::Tuple(elem_dtys1), DataTyKind::Tuple(elem_dtys2)) => { // TODO figure out why the las three/two lines of the while loop enable borrowing @@ -393,10 +530,10 @@ impl Constrainable for DataTy { while let (Some((next_lhs, _)), Some((next_rhs, _))) = (remain_lhs.split_first_mut(), remain_rhs.split_first_mut()) { - next_lhs.constrain(next_rhs, constr_map, prv_rels)?; + next_lhs.constrain(next_rhs, constr_map, prv_rels, arena)?; for (dty1, dty2) in elem_dtys1.iter_mut().zip(elem_dtys2.iter_mut()) { - substitute(constr_map, dty1); - substitute(constr_map, dty2); + substitute(constr_map, dty1, arena); + substitute(constr_map, dty2, arena); } i += 1; @@ -405,6 +542,27 @@ impl Constrainable for DataTy { } } (DataTyKind::Struct(struct_decl1), DataTyKind::Struct(struct_decl2)) => { + if struct_decl1.fields.len() != struct_decl2.fields.len() { + return Err(UnifyError::CannotUnify); + } + + for ((lname, lty_ref), (rname, rty_ref)) in + 
struct_decl1.fields.iter().zip(struct_decl2.fields.iter()) + { + if lname != rname { + return Err(UnifyError::CannotUnify); + } + + let mut lty = lty_ref.clone_in(arena); + let mut rty = rty_ref.clone_in(arena); + + lty.constrain(&mut rty, constr_map, prv_rels, arena)?; + } + + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); + + /* let mut i = 0; let mut remain_lhs = &mut struct_decl1.fields[i..]; let mut remain_rhs = &mut struct_decl2.fields[i..]; @@ -416,37 +574,47 @@ impl Constrainable for DataTy { } next_lhs .1 - .constrain(&mut next_rhs.1, constr_map, prv_rels)?; + .constrain(&mut next_rhs.1, constr_map, prv_rels, arena)?; for ((_, dty1), (_, dty2)) in struct_decl1 .fields .iter_mut() .zip(struct_decl2.fields.iter_mut()) { - substitute(constr_map, dty1); - substitute(constr_map, dty2); + substitute(constr_map, dty1, arena); + substitute(constr_map, dty2, arena); } i += 1; remain_lhs = &mut struct_decl1.fields[i..]; remain_rhs = &mut struct_decl2.fields[i..]; } + */ } (DataTyKind::Array(dty1, n1), DataTyKind::Array(dty2, n2)) | (DataTyKind::ArrayShape(dty1, n1), DataTyKind::ArrayShape(dty2, n2)) => { - dty1.constrain(dty2, constr_map, prv_rels)?; - substitute(constr_map, &mut **dty1); - substitute(constr_map, &mut **dty2); - n1.constrain(n2, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + let mut dty1_owned = (**dty1).clone(); + let mut dty2_owned = (**dty2).clone(); + dty1_owned.constrain(&mut dty2_owned, constr_map, prv_rels, arena)?; + substitute(constr_map, &mut dty1_owned, arena); + substitute(constr_map, &mut dty2_owned, arena); + *dty1 = arena.alloc(dty1_owned); + *dty2 = arena.alloc(dty2_owned); + + n1.constrain(n2, constr_map, prv_rels, arena)?; + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); } (DataTyKind::At(dty1, mem1), DataTyKind::At(dty2, mem2)) => { - dty1.constrain(dty2, constr_map, prv_rels)?; - substitute(constr_map, &mut **dty1); - 
substitute(constr_map, &mut **dty2); - mem1.constrain(mem2, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + let mut dty1_owned = (**dty1).clone(); + let mut dty2_owned = (**dty2).clone(); + dty1_owned.constrain(&mut dty2_owned, constr_map, prv_rels, arena)?; + substitute(constr_map, &mut dty1_owned, arena); + substitute(constr_map, &mut dty2_owned, arena); + *dty1 = arena.alloc(dty1_owned); + *dty2 = arena.alloc(dty2_owned); + mem1.constrain(mem2, constr_map, prv_rels, arena)?; + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); } (DataTyKind::Atomic(sty1), DataTyKind::Atomic(sty2)) => { if sty1 != sty2 { @@ -462,9 +630,11 @@ impl Constrainable for DataTy { panic!() } (dty1, DataTyKind::Dead(dty2)) if !matches!(dty1, DataTyKind::Dead(_)) => { - self.constrain(dty2, constr_map, prv_rels)?; - substitute(constr_map, self); - substitute(constr_map, other); + let mut dty2_owned = (**dty2).clone(); + self.constrain(&mut dty2_owned, constr_map, prv_rels, arena)?; + *dty2 = arena.alloc(dty2_owned); + substitute(constr_map, self, arena); + substitute(constr_map, other, arena); } _ => return Err(UnifyError::CannotUnify), } @@ -472,99 +642,166 @@ impl Constrainable for DataTy { } } -impl Substitutable for DataTy { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for DataTy<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_dty(self); + apply_subst.visit_dty(arena, self); } } -impl Constrainable for ExecTy { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for ExecTy<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { match (&mut 
self.ty, &mut other.ty) { (ExecTyKind::CpuThread, ExecTyKind::CpuThread) | (ExecTyKind::GpuThread, ExecTyKind::GpuThread) | (ExecTyKind::GpuWarp, ExecTyKind::GpuWarp) | (_, ExecTyKind::Any) => Ok(()), - (ExecTyKind::GpuWarpGrp(nl), ExecTyKind::GpuWarpGrp(nr)) => { - nl.constrain(nr, constr_map, prv_rels) - } - (ExecTyKind::GpuGrid(lgdim, lbdim), ExecTyKind::GpuGrid(rgdim, rbdim)) - | (ExecTyKind::GpuBlockGrp(lgdim, lbdim), ExecTyKind::GpuBlockGrp(rgdim, rbdim)) => { - lgdim.constrain(rgdim, constr_map, prv_rels)?; - lbdim.constrain(rbdim, constr_map, prv_rels) - } - ( - ExecTyKind::GpuToThreads(ldim_compo, l_inner), - ExecTyKind::GpuToThreads(rdim_compo, r_inner), - ) => { - if ldim_compo != rdim_compo { + + (ExecTyKind::GpuWarpGrp(ln), ExecTyKind::GpuWarpGrp(rn)) => { + let mut l = (*ln).clone_in(arena); + let mut r = (*rn).clone_in(arena); + l.constrain(&mut r, constr_map, prv_rels, arena) + } + + (ExecTyKind::GpuGrid(lg, lb), ExecTyKind::GpuGrid(rg, rb)) + | (ExecTyKind::GpuBlockGrp(lg, lb), ExecTyKind::GpuBlockGrp(rg, rb)) => { + let mut lgc = (*lg).clone_in(arena); + let mut rgc = (*rg).clone_in(arena); + lgc.constrain(&mut rgc, constr_map, prv_rels, arena)?; + + let mut lbc = (*lb).clone_in(arena); + let mut rbc = (*rb).clone_in(arena); + lbc.constrain(&mut rbc, constr_map, prv_rels, arena) + } + + (ExecTyKind::GpuToThreads(ldc, l_inner), ExecTyKind::GpuToThreads(rdc, r_inner)) => { + if ldc != rdc { return Err(UnifyError::CannotUnify); } - l_inner.constrain(r_inner, constr_map, prv_rels) + let mut li = (*l_inner).clone_in(arena); + let mut ri = (*r_inner).clone_in(arena); + li.constrain(&mut ri, constr_map, prv_rels, arena) } - (ExecTyKind::GpuBlock(ldim), ExecTyKind::GpuBlock(rdim)) - | (ExecTyKind::GpuThreadGrp(ldim), ExecTyKind::GpuThreadGrp(rdim)) => { - ldim.constrain(rdim, constr_map, prv_rels) + + (ExecTyKind::GpuBlock(ld), ExecTyKind::GpuBlock(rd)) + | (ExecTyKind::GpuThreadGrp(ld), ExecTyKind::GpuThreadGrp(rd)) => { + let mut lc = 
(*ld).clone_in(arena); + let mut rc = (*rd).clone_in(arena); + lc.constrain(&mut rc, constr_map, prv_rels, arena) } + _ => Err(UnifyError::CannotUnify), } } } -impl Substitutable for ExecTy { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for ExecTy<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, _arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); apply_subst.visit_exec_ty(self); } } -impl Constrainable for Dim { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +/** +impl<'a> Constrainable<'a> for Dim<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { match (self, other) { (Dim::XYZ(ldim), Dim::XYZ(rdim)) => { - ldim.0.constrain(&mut rdim.0, constr_map, prv_rels)?; - ldim.1.constrain(&mut rdim.1, constr_map, prv_rels)?; - ldim.2.constrain(&mut rdim.2, constr_map, prv_rels) + ldim.0.constrain(&mut rdim.0, constr_map, prv_rels, arena)?; + ldim.1.constrain(&mut rdim.1, constr_map, prv_rels, arena)?; + ldim.2.constrain(&mut rdim.2, constr_map, prv_rels, arena) } (Dim::XY(ldim), Dim::XY(rdim)) | (Dim::XZ(ldim), Dim::XZ(rdim)) | (Dim::YZ(ldim), Dim::YZ(rdim)) => { - ldim.0.constrain(&mut rdim.0, constr_map, prv_rels)?; - ldim.1.constrain(&mut rdim.1, constr_map, prv_rels) + ldim.0.constrain(&mut rdim.0, constr_map, prv_rels, arena)?; + ldim.1.constrain(&mut rdim.1, constr_map, prv_rels, arena) } (Dim::X(ld), Dim::X(rd)) | (Dim::Y(ld), Dim::Y(rd)) | (Dim::Z(ld), Dim::Z(rd)) => { - ld.0.constrain(&mut rd.0, constr_map, prv_rels) + ld.0.constrain(&mut rd.0, constr_map, prv_rels, arena) + } + _ => Err(UnifyError::CannotUnify), + } + } +} +*/ + +impl<'a> Constrainable<'a> for Dim<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: 
&'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { + use Dim::*; + + match (self, other) { + (XYZ(ld), XYZ(rd)) => { + let mut lx = ld.0.clone_in(arena); + let mut rx = rd.0.clone_in(arena); + lx.constrain(&mut rx, constr_map, prv_rels, arena)?; + + let mut ly = ld.1.clone_in(arena); + let mut ry = rd.1.clone_in(arena); + ly.constrain(&mut ry, constr_map, prv_rels, arena)?; + + let mut lz = ld.2.clone_in(arena); + let mut rz = rd.2.clone_in(arena); + lz.constrain(&mut rz, constr_map, prv_rels, arena)?; + + Ok(()) + } + + (XY(ld), XY(rd)) | (XZ(ld), XZ(rd)) | (YZ(ld), YZ(rd)) => { + let mut l0 = ld.0.clone_in(arena); + let mut r0 = rd.0.clone_in(arena); + l0.constrain(&mut r0, constr_map, prv_rels, arena)?; + + let mut l1 = ld.1.clone_in(arena); + let mut r1 = rd.1.clone_in(arena); + l1.constrain(&mut r1, constr_map, prv_rels, arena)?; + + Ok(()) } + + (X(ld), X(rd)) | (Y(ld), Y(rd)) | (Z(ld), Z(rd)) => { + let mut ln = ld.0.clone_in(arena); + let mut rn = rd.0.clone_in(arena); + ln.constrain(&mut rn, constr_map, prv_rels, arena) + } + _ => Err(UnifyError::CannotUnify), } } } -impl Substitutable for Dim { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for Dim<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_dim(self); + apply_subst.visit_dim(arena, self); } } -impl Nat { +impl<'a> Nat<'a> { fn bind_to( &self, - ident: &Ident, - constr_map: &mut ConstrainMap, - _: &mut Vec, - ) -> UnifyResult<()> { + ident: &'a Ident<'a>, + constr_map: &mut ConstrainMap<'a>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { // No occurs check. // Nats can be equal to an expression in which the nat appears again. 
E.g., a = a * 1 if let Some(old) = constr_map @@ -578,14 +815,15 @@ impl Nat { ) } } + let term_ref: &'a Nat<'a> = arena.alloc(self.clone_in(arena)); constr_map .nat_unifier .values_mut() - .for_each(|n| SubstIdent::new(ident, self).visit_nat(n)); + .for_each(|n| SubstIdent::new(ident, term_ref).visit_nat(arena, n)); Ok(()) } - fn unify(n1: &Nat, n2: &Nat, _constr_map: &mut ConstrainMap) -> UnifyResult<()> { + fn unify<'m>(n1: &'m Nat<'a>, n2: &'m Nat<'a>) -> UnifyResult<'a, ()> { if n1 == n2 { Ok(()) } else { @@ -594,46 +832,68 @@ impl Nat { } } -impl Constrainable for Nat { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { - match (&mut *self, &mut *other) { +impl<'a> Constrainable<'a> for Nat<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { + match (&*self, &*other) { (Nat::Ident(n1i), Nat::Ident(n2i)) if n1i.is_implicit || n2i.is_implicit => { match (n1i.is_implicit, n2i.is_implicit) { - (true, _) => other.bind_to(n1i, constr_map, prv_rels), - (false, _) => self.bind_to(n2i, constr_map, prv_rels), + (true, _) => other.bind_to(arena.alloc(n1i.clone()), constr_map, arena), + (false, _) => self.bind_to(arena.alloc(n2i.clone()), constr_map, arena), } } - (Nat::Ident(n1i), _) if n1i.is_implicit => other.bind_to(n1i, constr_map, prv_rels), - (_, Nat::Ident(n2i)) if n2i.is_implicit => self.bind_to(n2i, constr_map, prv_rels), + (Nat::Ident(n1i), _) if n1i.is_implicit => { + other.bind_to(arena.alloc(n1i.clone()), constr_map, arena) + } + (_, Nat::Ident(n2i)) if n2i.is_implicit => { + self.bind_to(arena.alloc(n2i.clone()), constr_map, arena) + } (Nat::BinOp(op1, n1l, n1r), Nat::BinOp(op2, n2l, n2r)) if op1 == op2 => { - n1l.constrain(n2l, constr_map, prv_rels)?; - n1r.constrain(n2r, constr_map, prv_rels) + let mut l_left = 
(*n1l).clone_in(arena); + let mut r_left = (*n2l).clone_in(arena); + l_left.constrain(&mut r_left, constr_map, prv_rels, arena)?; + + let mut l_right = (*n1r).clone_in(arena); + let mut r_right = (*n2r).clone_in(arena); + l_right.constrain(&mut r_right, constr_map, prv_rels, arena)?; + + Ok(()) } (Nat::App(f1, ns1), Nat::App(f2, ns2)) if f1 == f2 => { - for (n1, n2) in ns1.iter_mut().zip(ns2.iter_mut()) { - n1.constrain(n2, constr_map, prv_rels)?; + if ns1.len() != ns2.len() { + return Err(UnifyError::CannotUnify); + } + for (n1, n2) in ns1.iter().zip(ns2.iter()) { + let mut l = n1.clone_in(arena); + let mut r = n2.clone_in(arena); + l.constrain(&mut r, constr_map, prv_rels, arena)?; } Ok(()) } - _ => Self::unify(self, other, constr_map), + _ => Self::unify(self, other), } } } -impl Substitutable for Nat { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for Nat<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_nat(self); + apply_subst.visit_nat(arena, self); } } -impl Memory { - fn bind_to(&self, ident: &Ident, constr_map: &mut ConstrainMap) -> UnifyResult<()> { +impl<'a> Memory<'a> { + fn bind_to( + &self, + ident: &'a Ident<'a>, + constr_map: &mut ConstrainMap<'a>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { if Self::occurs_check(&IdentKinded::new(ident, Kind::Memory), self) { return Err(UnifyError::InfiniteType); } @@ -655,44 +915,51 @@ impl Memory { ) } } + let term_ref: &'a Memory<'a> = arena.alloc(self.clone_in(arena)); constr_map .mem_unifier .values_mut() - .for_each(|m| SubstIdent::new(ident, self).visit_mem(m)); + .for_each(|m| SubstIdent::new(ident, term_ref).visit_mem(arena, m)); Ok(()) } } -impl Constrainable for Memory { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - _prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for Memory<'a> { + fn constrain<'m>( + &'m mut 
self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + _prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { match (&*self, &*other) { (Memory::Ident(i1), Memory::Ident(i2)) if i1 == i2 => Ok(()), (Memory::Ident(i1), Memory::Ident(i2)) => match (i1.is_implicit, i2.is_implicit) { - (true, _) => other.bind_to(i1, constr_map), - (false, _) => self.bind_to(i2, constr_map), + (true, _) => other.bind_to(arena.alloc(i1.clone()), constr_map, arena), + (false, _) => self.bind_to(arena.alloc(i2.clone()), constr_map, arena), }, - (Memory::Ident(i), o) => o.bind_to(i, constr_map), - (s, Memory::Ident(i)) => s.bind_to(i, constr_map), + (Memory::Ident(i), o) => o.bind_to(arena.alloc(i.clone()), constr_map, arena), + (s, Memory::Ident(i)) => s.bind_to(arena.alloc(i.clone()), constr_map, arena), (mem1, mem2) if mem1 == mem2 => Ok(()), _ => Err(UnifyError::CannotUnify), } } } -impl Substitutable for Memory { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for Memory<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_mem(self); + apply_subst.visit_mem(arena, self); } } -impl Provenance { - fn bind_to(&self, ident: &Ident, constr_map: &mut ConstrainMap) -> UnifyResult<()> { +impl<'a> Provenance<'a> { + fn bind_to( + &self, + ident: &'a Ident<'a>, + constr_map: &mut ConstrainMap<'a>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { // TODO not necessary, since no recursion possible if Self::occurs_check(&IdentKinded::new(ident, Kind::Provenance), self) { return Err(UnifyError::InfiniteType); @@ -715,154 +982,166 @@ impl Provenance { ) } } + let term_ref: &'a Provenance<'a> = arena.alloc(self.clone_in(arena)); constr_map .prv_unifier .values_mut() - .for_each(|m| SubstIdent::new(ident, self).visit_prv(m)); + .for_each(|m| SubstIdent::new(ident, term_ref).visit_prv(arena, m)); Ok(()) } } -impl Constrainable 
for Provenance { - fn constrain( - &mut self, - other: &mut Self, - constr_map: &mut ConstrainMap, - prv_rels: &mut Vec, - ) -> UnifyResult<()> { +impl<'a> Constrainable<'a> for Provenance<'a> { + fn constrain<'m>( + &'m mut self, + other: &'m mut Self, + constr_map: &'m mut ConstrainMap<'a>, + prv_rels: &'m mut BumpVec<'a, PrvConstr<'a>>, + arena: &'a Bump, + ) -> UnifyResult<'a, ()> { // TODO restructure cases for less? match (&*self, &*other) { (Provenance::Ident(i1), Provenance::Ident(i2)) if i1 == i2 => Ok(()), (Provenance::Ident(i), r) | (r, Provenance::Ident(i)) if i.is_implicit => { - r.bind_to(i, constr_map) + let i_ref: &'a Ident<'a> = arena.alloc(i.clone()); + r.bind_to(i_ref, constr_map, arena) } (Provenance::Ident(_), _) | (_, Provenance::Ident(_)) => { - prv_rels.push(PrvConstr(self.clone(), other.clone())); + let l_ref: &'a Provenance<'a> = arena.alloc(self.clone_in(arena)); + let r_ref: &'a Provenance<'a> = arena.alloc(other.clone_in(arena)); + prv_rels.push(PrvConstr(l_ref, r_ref)); Ok(()) } (Provenance::Value(_), Provenance::Value(_)) => { - prv_rels.push(PrvConstr(self.clone(), other.clone())); + let l_ref: &'a Provenance<'a> = arena.alloc(self.clone_in(arena)); + let r_ref: &'a Provenance<'a> = arena.alloc(other.clone_in(arena)); + prv_rels.push(PrvConstr(l_ref, r_ref)); Ok(()) } } } } -impl Substitutable for Provenance { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for Provenance<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_prv(self); + apply_subst.visit_prv(arena, self); } } -impl Substitutable for View { - fn substitute(&mut self, subst: &ConstrainMap) { +impl<'a> Substitutable<'a> for View<'a> { + fn substitute(&mut self, subst: &ConstrainMap<'a>, arena: &'a Bump) { let mut apply_subst = ApplySubst::new(subst); - apply_subst.visit_view(self); + apply_subst.visit_view(arena, self); } } -pub(super) fn 
substitute(subst: &ConstrainMap, s: &mut S) { - s.substitute(subst) +pub(super) fn substitute<'a, 's, S: Substitutable<'a>>( + subst: &'s ConstrainMap<'a>, + s: &mut S, + arena: &'a Bump, +) { + s.substitute(subst, arena) } -pub(super) struct ApplySubst<'a> { - subst: &'a ConstrainMap, +pub(super) struct ApplySubst<'s, 'a> { + subst: &'s ConstrainMap<'a>, } -impl<'a> ApplySubst<'a> { - pub(super) fn new(subst: &'a ConstrainMap) -> Self { +impl<'s, 'a> ApplySubst<'s, 'a> { + pub(super) fn new(subst: &'s ConstrainMap<'a>) -> Self { ApplySubst { subst } } } -impl<'a> VisitMut for ApplySubst<'a> { - fn visit_nat(&mut self, nat: &mut Nat) { +impl<'s, 'a> VisitMut<'a> for ApplySubst<'s, 'a> { + fn visit_nat(&mut self, arena: &'a Bump, nat: &mut Nat<'a>) { match nat { Nat::Ident(ident) if self.subst.nat_unifier.contains_key(&ident.name) => { *nat = self.subst.nat_unifier.get(&ident.name).unwrap().clone(); } - _ => visit_mut::walk_nat(self, nat), + _ => visit_mut::walk_nat(self, arena, nat), } } - fn visit_mem(&mut self, mem: &mut Memory) { + fn visit_mem(&mut self, arena: &'a Bump, mem: &mut Memory<'a>) { match mem { Memory::Ident(ident) if self.subst.mem_unifier.contains_key(&ident.name) => { *mem = self.subst.mem_unifier.get(&ident.name).unwrap().clone(); } - _ => visit_mut::walk_mem(self, mem), + _ => visit_mut::walk_mem(self, arena, mem), } } - fn visit_prv(&mut self, prv: &mut Provenance) { + fn visit_prv(&mut self, arena: &'a Bump, prv: &mut Provenance<'a>) { match prv { Provenance::Ident(ident) if self.subst.prv_unifier.contains_key(&ident.name) => { *prv = self.subst.prv_unifier.get(&ident.name).unwrap().clone() } - _ => visit_mut::walk_prv(self, prv), + _ => visit_mut::walk_prv(self, arena, prv), } } - fn visit_dty(&mut self, dty: &mut DataTy) { + fn visit_dty(&mut self, arena: &'a Bump, dty: &mut DataTy<'a>) { match &mut dty.dty { DataTyKind::Ident(ident) if self.subst.dty_unifier.contains_key(&ident.name) => { *dty = 
self.subst.dty_unifier.get(&ident.name).unwrap().clone() } - _ => visit_mut::walk_dty(self, dty), + _ => visit_mut::walk_dty(self, arena, dty), } } } -struct SubstIdent<'a, S: Constrainable> { - ident: &'a Ident, +struct SubstIdent<'a, S: Constrainable<'a>> { + ident: &'a Ident<'a>, term: &'a S, } -impl<'a, S: Constrainable> SubstIdent<'a, S> { +impl<'a, S: Constrainable<'a>> SubstIdent<'a, S> { fn new(ident: &'a Ident, term: &'a S) -> Self { SubstIdent { ident, term } } } -impl<'a> VisitMut for SubstIdent<'a, Nat> { - fn visit_nat(&mut self, nat: &mut Nat) { +impl<'a> VisitMut<'a> for SubstIdent<'a, Nat<'a>> { + fn visit_nat(&mut self, arena: &'a Bump, nat: &mut Nat<'a>) { match nat { Nat::Ident(ident) if ident.name == self.ident.name => *nat = self.term.clone(), - _ => visit_mut::walk_nat(self, nat), + _ => visit_mut::walk_nat(self, arena, nat), } } } -impl<'a> VisitMut for SubstIdent<'a, Memory> { - fn visit_mem(&mut self, mem: &mut Memory) { +impl<'a> VisitMut<'a> for SubstIdent<'a, Memory<'a>> { + fn visit_mem(&mut self, arena: &'a Bump, mem: &mut Memory<'a>) { match mem { Memory::Ident(ident) if ident.name == self.ident.name => *mem = self.term.clone(), - _ => visit_mut::walk_mem(self, mem), + _ => visit_mut::walk_mem(self, arena, mem), } } } -impl<'a> VisitMut for SubstIdent<'a, Provenance> { - fn visit_prv(&mut self, prv: &mut Provenance) { +impl<'a> VisitMut<'a> for SubstIdent<'a, Provenance<'a>> { + fn visit_prv(&mut self, arena: &'a Bump, prv: &mut Provenance<'a>) { match prv { Provenance::Ident(ident) if ident.name == self.ident.name => *prv = self.term.clone(), - _ => visit_mut::walk_prv(self, prv), + _ => visit_mut::walk_prv(self, arena, prv), } } } -impl<'a> VisitMut for SubstIdent<'a, DataTy> { - fn visit_dty(&mut self, dty: &mut DataTy) { +impl<'a> VisitMut<'a> for SubstIdent<'a, DataTy<'a>> { + fn visit_dty(&mut self, arena: &'a Bump, dty: &mut DataTy<'a>) { match &mut dty.dty { DataTyKind::Ident(ident) if ident.name == self.ident.name => *dty = 
self.term.clone(), - _ => visit_mut::walk_dty(self, dty), + _ => visit_mut::walk_dty(self, arena, dty), } } } -impl<'a> VisitMut for SubstIdent<'a, ExecExpr> { - fn visit_exec_expr(&mut self, exec: &mut ExecExpr) { +/** +impl<'a> VisitMut<'a> for SubstIdent<'a, ExecExpr<'a>> { + fn visit_exec_expr(&mut self, arena: &'a Bump, exec: &mut ExecExpr<'a>) { if let BaseExec::Ident(i) = &exec.exec.base { if i.name == self.ident.name { let mut subst_exec = self.term.clone(); @@ -872,83 +1151,159 @@ impl<'a> VisitMut for SubstIdent<'a, ExecExpr> { } } } +*/ + +impl<'a> VisitMut<'a> for SubstIdent<'a, ExecExpr<'a>> { + fn visit_exec_expr(&mut self, arena: &'a Bump, exec: &mut ExecExpr<'a>) { + use crate::arena_ast::{BaseExec, ExecExpr, ExecExprKind}; + + if let BaseExec::Ident(i) = &exec.exec.base { + if i.name == self.ident.name { + let mut merged = bumpalo::collections::Vec::new_in(arena); + + for e in self.term.exec.path.iter() { + merged.push(e.clone_in(arena)); + } + for e in exec.exec.path.iter() { + merged.push(e.clone_in(arena)); + } + + let new_kind = arena.alloc(ExecExprKind { + base: self.term.exec.base.clone_in(arena), + path: merged, + }); + + *exec = ExecExpr { + exec: new_kind, + ty: exec.ty, + span: exec.span, + }; + + return; + } + } + } +} #[cfg(test)] mod tests { use super::*; + use bumpalo::Bump; + + fn shrd_ref_ty<'a>(arena: &'a Bump) -> DataTy<'a> { + //Dim::X(Box::new(Dim1d(Nat::Lit(32)))); + let elem = DataTy::new(arena, DataTyKind::Scalar(ScalarTy::I32)); + let arr = DataTy::new( + arena, + DataTyKind::Array(arena.alloc(elem), Nat::Ident(Ident::new(arena, "n"))), + ); - fn shrd_ref_ty() -> DataTy { - Dim::X(Box::new(Dim1d(Nat::Lit(32)))); - DataTy::new(DataTyKind::Ref(Box::new(RefDty::new( - Provenance::Value("r".to_string()), + let ref_dty = RefDty::new( + arena, + Provenance::Value("r"), Ownership::Shrd, Memory::GpuGlobal, - DataTy::new(DataTyKind::Array( - Box::new(DataTy::new(DataTyKind::Scalar(ScalarTy::I32))), - Nat::Ident(Ident::new("n")), - 
)), - )))) + arr, + ); + + DataTy::new(arena, DataTyKind::Ref(arena.alloc(ref_dty))) } #[test] - fn scalar() -> UnifyResult<()> { - let mut i32 = DataTy::new(DataTyKind::Scalar(ScalarTy::I32)); - let mut t = DataTy::new(DataTyKind::Ident(Ident::new_impli("t"))); - let (subst, _) = constrain(&mut i32, &mut t)?; - substitute(&subst, &mut i32); - substitute(&subst, &mut t); - assert_eq!(i32, t); + fn scalar<'a>() -> UnifyResult<'a, ()> { + let arena = Bump::new(); + + let mut i32_ty = DataTy::new(&arena, DataTyKind::Scalar(ScalarTy::I32)); + let mut t = DataTy::new(&arena, DataTyKind::Ident(Ident::new_impli(&arena, "t"))); + + let lhs = arena.alloc(i32_ty.clone_in(&arena)); + let rhs = arena.alloc(t.clone_in(&arena)); + let (subst, _prv) = constrain(lhs, rhs, &arena).unwrap(); + + substitute(&subst, &mut i32_ty, &arena); + substitute(&subst, &mut t, &arena); + + assert_eq!(i32_ty, t); Ok(()) } - #[test] - fn shrd_reft() -> UnifyResult<()> { - let mut t = DataTy::new(DataTyKind::Ident(Ident::new_impli("t"))); - let mut shrd_ref = shrd_ref_ty(); - let (subst, _) = constrain(&mut shrd_ref, &mut t)?; - substitute(&subst, &mut shrd_ref); - substitute(&subst, &mut t); + fn shrd_reft<'a>() -> UnifyResult<'a, ()> { + let arena = Bump::new(); + + let mut t = DataTy::new(&arena, DataTyKind::Ident(Ident::new_impli(&arena, "t"))); + let mut shrd_ref = shrd_ref_ty(&arena); + + let lhs = arena.alloc(shrd_ref.clone_in(&arena)); + let rhs = arena.alloc(t.clone_in(&arena)); + let (subst, _prv) = constrain(lhs, rhs, &arena).unwrap(); + + substitute(&subst, &mut shrd_ref, &arena); + substitute(&subst, &mut t, &arena); + assert_eq!(shrd_ref, t); Ok(()) } #[test] - fn shrd_ref_inner_var() -> UnifyResult<()> { - let mut shrd_ref_t = DataTy::new(DataTyKind::Ref(Box::new(RefDty::new( - Provenance::Value("r".to_string()), + fn shrd_ref_inner_var<'a>() -> UnifyResult<'a, ()> { + use bumpalo::Bump; + + let arena = Bump::new(); + + let inner_t = DataTy::new(&arena, 
DataTyKind::Ident(Ident::new_impli(&arena, "t"))); + let ref_t = RefDty::new( + &arena, + Provenance::Value("r"), Ownership::Shrd, Memory::GpuGlobal, - DataTy::new(DataTyKind::Ident(Ident::new_impli("t"))), - )))); - let mut shrd_ref = shrd_ref_ty(); - let (subst, _) = constrain(&mut shrd_ref, &mut shrd_ref_t)?; + inner_t, + ); + let mut shrd_ref_t = DataTy::new(&arena, DataTyKind::Ref(arena.alloc(ref_t))); + + let mut shrd_ref = shrd_ref_ty(&arena); + + let lhs = arena.alloc(shrd_ref.clone_in(&arena)); + let rhs = arena.alloc(shrd_ref_t.clone_in(&arena)); + let (subst, _prv) = constrain(lhs, rhs, &arena).unwrap(); println!("{:?}", subst); - substitute(&subst, &mut shrd_ref); - substitute(&subst, &mut shrd_ref_t); + + substitute(&subst, &mut shrd_ref, &arena); + substitute(&subst, &mut shrd_ref_t, &arena); + assert_eq!(shrd_ref, shrd_ref_t); Ok(()) } #[test] - fn prv_val_ident() -> UnifyResult<()> { - let mut shrd_ref_t = DataTy::new(DataTyKind::Ref(Box::new(RefDty::new( - Provenance::Ident(Ident::new("a")), + fn prv_val_ident<'a>() -> UnifyResult<'a, ()> { + use bumpalo::Bump; + + let arena = Bump::new(); + + let inner_t = DataTy::new(&arena, DataTyKind::Ident(Ident::new_impli(&arena, "t"))); + let ref_t = RefDty::new( + &arena, + Provenance::Ident(Ident::new(&arena, "a")), Ownership::Shrd, Memory::GpuGlobal, - DataTy::new(DataTyKind::Ident(Ident::new_impli("t"))), - )))); - let mut shrd_ref = shrd_ref_ty(); - let (subst, prv_rels) = constrain(&mut shrd_ref, &mut shrd_ref_t)?; - println!("{:?}", subst); - substitute(&subst, &mut shrd_ref); - substitute(&subst, &mut shrd_ref_t); - assert_eq!( - prv_rels[0], - PrvConstr( - Provenance::Value("r".to_string()), - Provenance::Ident(Ident::new("a")) - ) + inner_t, + ); + let mut shrd_ref_t = DataTy::new(&arena, DataTyKind::Ref(arena.alloc(ref_t))); + + let mut shrd_ref = shrd_ref_ty(&arena); + + let lhs = arena.alloc(shrd_ref.clone_in(&arena)); + let rhs = arena.alloc(shrd_ref_t.clone_in(&arena)); + let (subst, prv_rels) 
= constrain(lhs, rhs, &arena).unwrap(); + + substitute(&subst, &mut shrd_ref, &arena); + substitute(&subst, &mut shrd_ref_t, &arena); + + let expected = PrvConstr( + arena.alloc(Provenance::Value("r")), + arena.alloc(Provenance::Ident(Ident::new(&arena, "a"))), ); + assert_eq!(prv_rels[0], expected); Ok(()) } } diff --git a/tests/cli_test.rs b/tests/cli_test.rs new file mode 100644 index 00000000..ee702360 --- /dev/null +++ b/tests/cli_test.rs @@ -0,0 +1,13 @@ +use assert_cmd::Command; +use predicates::prelude::*; + +#[test] +fn test_emit_cuda_on_transpose_desc() { + let mut cmd = Command::cargo_bin("descendc").expect("Failed to find descendc binary"); + + cmd.arg("emit").arg("examples/infer/transpose.desc"); + + cmd.assert() + .success() + .stderr(predicate::str::contains("Generated CUDA Code")); +}