From bd4f90b054038e52af56f21be828efebd4c0d005 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 01:01:31 +0000 Subject: [PATCH] Fix defunctionalization to inherit runtime of caller --- .../noirc_evaluator/src/ssa/ir/function.rs | 2 +- .../src/ssa/opt/defunctionalize.rs | 48 ++++++++++--------- .../src/monomorphization/ast.rs | 4 +- 3 files changed, 30 insertions(+), 24 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index b59b0c18a10..b21a84d16dc 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -12,7 +12,7 @@ use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir(InlineType), diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index da4cbfa89b5..135bbf84d09 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -13,7 +13,7 @@ use crate::ssa::{ function_builder::FunctionBuilder, ir::{ basic_block::BasicBlockId, - function::{Function, FunctionId, Signature}, + function::{Function, FunctionId, RuntimeType, Signature}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, value::{Value, ValueId}, @@ -43,12 +43,15 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants #[derive(Debug, Clone)] struct DefunctionalizationContext { - apply_functions: HashMap, + apply_functions: ApplyFunctions, } impl Ssa { @@ -104,7 +107,7 @@ impl DefunctionalizationContext { }; // Find the correct apply function - let apply_function = self.get_apply_function(&signature); + let apply_function = self.get_apply_function(signature, func.runtime()); // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); @@ -152,19 +155,21 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &Signature) -> ApplyFunction { - *self.apply_functions.get(signature).expect("Could not find apply function") + fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction { + *self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function") } } /// Collects all functions used as values that can be called by their signatures -fn find_variants(ssa: &Ssa) -> BTreeMap> { - let mut dynamic_dispatches: BTreeSet = BTreeSet::new(); +fn find_variants(ssa: &Ssa) -> Variants { + let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); for function in ssa.functions.values() { functions_as_values.extend(find_functions_as_values(function)); - dynamic_dispatches.extend(find_dynamic_dispatches(function)); + dynamic_dispatches.extend( + find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())), + ); } let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); @@ -174,16 +179,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap> { signature_to_functions_as_value.entry(signature).or_default().push(function_id); } - let mut variants = BTreeMap::new(); + let mut variants: Variants = BTreeMap::new(); - for dispatch_signature in dynamic_dispatches { - let mut target_fns = vec![]; - for (target_signature, functions) in &signature_to_functions_as_value { - if &dispatch_signature == target_signature { - target_fns.extend(functions); - } - } - variants.insert(dispatch_signature, target_fns); + for (dispatch_signature, caller_runtime) in dynamic_dispatches { + let target_fns = + signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); + variants.insert((dispatch_signature, caller_runtime), target_fns); } variants @@ -247,10 +248,10 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { fn create_apply_functions( ssa: &mut Ssa, - variants_map: BTreeMap>, -) -> HashMap { + variants_map: BTreeMap<(Signature, RuntimeType), Vec>, +) -> ApplyFunctions { let mut apply_functions = HashMap::default(); - for (signature, variants) in variants_map.into_iter() { + for ((signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" @@ -258,11 +259,12 @@ fn create_apply_functions( let dispatches_to_multiple_functions = variants.len() > 1; let id = if dispatches_to_multiple_functions { - create_apply_function(ssa, signature.clone(), variants) + create_apply_function(ssa, signature.clone(), runtime, variants) } else { variants[0] }; - apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions }); + apply_functions + .insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -275,6 +277,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, + runtime: RuntimeType, function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); @@ -282,6 +285,7 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); diff --git a/compiler/noirc_frontend/src/monomorphization/ast.rs b/compiler/noirc_frontend/src/monomorphization/ast.rs index d219e8f7c2d..05df3887848 100644 --- a/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -227,7 +227,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; /// Represents how an Acir function should be inlined. /// This type is only relevant for ACIR functions as we do not inline any Brillig functions -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive( + Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord, +)] pub enum InlineType { /// The most basic entry point can expect all its functions to be inlined. /// All function calls are expected to be inlined into a single ACIR.