From dd7084545dfd93a07599fc10676b6c8ec1e3d458 Mon Sep 17 00:00:00 2001 From: Akosh Farkash Date: Fri, 17 Jan 2025 21:03:22 +0000 Subject: [PATCH] fix: defunctionalize pass on the caller runtime to apply (#7100) Co-authored-by: Tom French --- .../noirc_evaluator/src/ssa/ir/function.rs | 2 +- .../src/ssa/opt/defunctionalize.rs | 203 +++++++++++++++--- .../noirc_evaluator/src/ssa/opt/inlining.rs | 17 +- .../src/ssa/parser/into_ssa.rs | 38 ++-- .../src/monomorphization/ast.rs | 4 +- 5 files changed, 214 insertions(+), 50 deletions(-) diff --git a/compiler/noirc_evaluator/src/ssa/ir/function.rs b/compiler/noirc_evaluator/src/ssa/ir/function.rs index b59b0c18a1..b21a84d16d 100644 --- a/compiler/noirc_evaluator/src/ssa/ir/function.rs +++ b/compiler/noirc_evaluator/src/ssa/ir/function.rs @@ -12,7 +12,7 @@ use super::map::Id; use super::types::Type; use super::value::ValueId; -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)] pub(crate) enum RuntimeType { // A noir function, to be compiled in ACIR and executed by ACVM Acir(InlineType), diff --git a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs index 186f10c53e..a6e04332c0 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs @@ -8,12 +8,13 @@ use std::collections::{BTreeMap, BTreeSet, HashSet}; use acvm::FieldElement; use iter_extended::vecmap; +use noirc_frontend::monomorphization::ast::InlineType; use crate::ssa::{ function_builder::FunctionBuilder, ir::{ basic_block::BasicBlockId, - function::{Function, FunctionId, Signature}, + function::{Function, FunctionId, RuntimeType, Signature}, instruction::{BinaryOp, Instruction}, types::{NumericType, Type}, value::{Value, ValueId}, @@ -43,12 +44,15 @@ struct ApplyFunction { dispatches_to_multiple_functions: bool, } +type Variants = BTreeMap<(Signature, RuntimeType), Vec>; +type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>; + /// Performs defunctionalization on all functions /// This is done by changing all functions as value to be a number (FieldElement) /// And creating apply functions that dispatch to the correct target by runtime comparisons with constants #[derive(Debug, Clone)] struct DefunctionalizationContext { - apply_functions: HashMap, + apply_functions: ApplyFunctions, } impl Ssa { @@ -104,7 +108,7 @@ impl DefunctionalizationContext { }; // Find the correct apply function - let apply_function = self.get_apply_function(&signature); + let apply_function = self.get_apply_function(signature, func.runtime()); // Replace the instruction with a call to apply let apply_function_value_id = func.dfg.import_function(apply_function.id); @@ -152,19 +156,21 @@ impl DefunctionalizationContext { } /// Returns the apply function for the given signature - fn get_apply_function(&self, signature: &Signature) -> ApplyFunction { - *self.apply_functions.get(signature).expect("Could not find apply function") + fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction { + *self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function") } } /// Collects all functions used as values that can be called by their signatures -fn find_variants(ssa: &Ssa) -> BTreeMap> { - let mut dynamic_dispatches: BTreeSet = BTreeSet::new(); +fn find_variants(ssa: &Ssa) -> Variants { + let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new(); let mut functions_as_values: BTreeSet = BTreeSet::new(); for function in ssa.functions.values() { functions_as_values.extend(find_functions_as_values(function)); - dynamic_dispatches.extend(find_dynamic_dispatches(function)); + dynamic_dispatches.extend( + find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())), + ); } let mut signature_to_functions_as_value: BTreeMap> = BTreeMap::new(); @@ -174,16 +180,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap> { signature_to_functions_as_value.entry(signature).or_default().push(function_id); } - let mut variants = BTreeMap::new(); + let mut variants: Variants = BTreeMap::new(); - for dispatch_signature in dynamic_dispatches { - let mut target_fns = vec![]; - for (target_signature, functions) in &signature_to_functions_as_value { - if &dispatch_signature == target_signature { - target_fns.extend(functions); - } - } - variants.insert(dispatch_signature, target_fns); + for (dispatch_signature, caller_runtime) in dynamic_dispatches { + let target_fns = + signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default(); + variants.insert((dispatch_signature, caller_runtime), target_fns); } variants @@ -247,10 +249,10 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet { fn create_apply_functions( ssa: &mut Ssa, - variants_map: BTreeMap>, -) -> HashMap { + variants_map: BTreeMap<(Signature, RuntimeType), Vec>, +) -> ApplyFunctions { let mut apply_functions = HashMap::default(); - for (signature, variants) in variants_map.into_iter() { + for ((signature, runtime), variants) in variants_map.into_iter() { assert!( !variants.is_empty(), "ICE: at least one variant should exist for a dynamic call {signature:?}" @@ -258,11 +260,12 @@ fn create_apply_functions( let dispatches_to_multiple_functions = variants.len() > 1; let id = if dispatches_to_multiple_functions { - create_apply_function(ssa, signature.clone(), variants) + create_apply_function(ssa, signature.clone(), runtime, variants) } else { variants[0] }; - apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions }); + apply_functions + .insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions }); } apply_functions } @@ -275,6 +278,7 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement { fn create_apply_function( ssa: &mut Ssa, signature: Signature, + caller_runtime: RuntimeType, function_ids: Vec, ) -> FunctionId { assert!(!function_ids.is_empty()); @@ -282,6 +286,13 @@ fn create_apply_function( ssa.add_fn(|id| { let mut function_builder = FunctionBuilder::new("apply".to_string(), id); function_builder.set_globals(globals); + + // We want to push for apply functions to be inlined more aggressively. + let runtime = match caller_runtime { + RuntimeType::Acir(_) => RuntimeType::Acir(InlineType::InlineAlways), + RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways), + }; + function_builder.set_runtime(runtime); let target_id = function_builder.add_parameter(Type::field()); let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ)); @@ -339,22 +350,156 @@ fn create_apply_function( }) } -/// Crates a return block, if no previous return exists, it will create a final return -/// Else, it will create a bypass return block that points to the previous return block +/// If no previous return target exists, it will create a final return, +/// otherwise returns the existing return block to jump to. fn build_return_block( builder: &mut FunctionBuilder, previous_block: BasicBlockId, passed_types: &[Type], target: Option, ) -> BasicBlockId { + if let Some(return_block) = target { + return return_block; + } let return_block = builder.insert_block(); builder.switch_to_block(return_block); - let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone())); - match target { - None => builder.terminate_with_return(params), - Some(target) => builder.terminate_with_jmp(target, params), - } + builder.terminate_with_return(params); builder.switch_to_block(previous_block); return_block } + +#[cfg(test)] +mod tests { + use crate::ssa::opt::assert_normalized_ssa_equals; + + use super::Ssa; + + #[test] + fn apply_inherits_caller_runtime() { + // Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig` + let src = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let expected = " + brillig(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(Field 2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f1(Field 3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: Field, v1: u32): + v3 = call f4(v0, v1) -> u32 + return v3 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + brillig(inline_always) fn apply f4 { + b0(v0: Field, v1: u32): + v4 = eq v0, Field 2 + jmpif v4 then: b2, else: b1 + b1(): + constrain v0 == Field 3 + v7 = call f3(v1) -> u32 + jmp b3(v7) + b2(): + v9 = call f2(v1) -> u32 + jmp b3(v9) + b3(v2: u32): + return v2 + } + "; + assert_normalized_ssa_equals(ssa, expected); + } + + #[test] + fn apply_created_per_caller_runtime() { + let src = " + acir(inline) fn main f0 { + b0(v0: u32): + v3 = call f1(f2, v0) -> u32 + v5 = add v0, u32 1 + v6 = eq v3, v5 + constrain v3 == v5 + v9 = call f4(f3, v0) -> u32 + v10 = add v0, u32 1 + v11 = eq v9, v10 + constrain v9 == v10 + return + } + brillig(inline) fn wrapper f1 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + acir(inline) fn wrapper_acir f4 { + b0(v0: function, v1: u32): + v2 = call v0(v1) -> u32 + return v2 + } + brillig(inline) fn increment f2 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + acir(inline) fn increment_acir f3 { + b0(v0: u32): + v2 = add v0, u32 1 + return v2 + } + "; + + let ssa = Ssa::from_str(src).unwrap(); + let ssa = ssa.defunctionalize(); + + let applies = ssa.functions.values().filter(|f| f.name() == "apply").collect::>(); + assert_eq!(applies.len(), 2); + assert!(applies.iter().any(|f| f.runtime().is_acir())); + assert!(applies.iter().any(|f| f.runtime().is_brillig())); + } +} diff --git a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs index 88cf70e13c..7554ad64a9 100644 --- a/compiler/noirc_evaluator/src/ssa/opt/inlining.rs +++ b/compiler/noirc_evaluator/src/ssa/opt/inlining.rs @@ -66,22 +66,27 @@ impl Ssa { self.functions = btree_map(inline_sources, |entry_point| { let should_inline_call = |_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool { - let function = &ssa.functions[&called_func_id]; + let callee = &ssa.functions[&called_func_id]; + let caller_runtime = ssa.functions[entry_point].runtime(); - match function.runtime() { + match callee.runtime() { RuntimeType::Acir(inline_type) => { // If the called function is acir, we inline if it's not an entry point // If we have not already finished the flattening pass, functions marked // to not have predicates should be preserved. let preserve_function = - !inline_no_predicates_functions && function.is_no_predicates(); + !inline_no_predicates_functions && callee.is_no_predicates(); !inline_type.is_entry_point() && !preserve_function } RuntimeType::Brillig(_) => { - // If the called function is brillig, we inline only if it's into brillig and the function is not recursive - ssa.functions[entry_point].runtime().is_brillig() - && !inline_sources.contains(&called_func_id) + if caller_runtime.is_acir() { + // We never inline a brillig function into an ACIR function. + return false; + } + + // Avoid inlining recursive functions. + !inline_sources.contains(&called_func_id) } } }; diff --git a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs index 98e7586cab..e2eea234dc 100644 --- a/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs +++ b/compiler/noirc_evaluator/src/ssa/parser/into_ssa.rs @@ -26,7 +26,7 @@ impl ParsedSsa { struct Translator { builder: FunctionBuilder, - /// Maps function names to their IDs + /// Maps internal function names (e.g. "f1") to their IDs functions: HashMap, /// Maps block names to their IDs @@ -137,14 +137,14 @@ impl Translator { match block.terminator { ParsedTerminator::Jmp { destination, arguments } => { - let block_id = self.lookup_block(destination)?; + let block_id = self.lookup_block(&destination)?; let arguments = self.translate_values(arguments)?; self.builder.terminate_with_jmp(block_id, arguments); } ParsedTerminator::Jmpif { condition, then_block, else_block } => { let condition = self.translate_value(condition)?; - let then_destination = self.lookup_block(then_block)?; - let else_destination = self.lookup_block(else_block)?; + let then_destination = self.lookup_block(&then_block)?; + let else_destination = self.lookup_block(&else_block)?; self.builder.terminate_with_jmpif(condition, then_destination, else_destination); } ParsedTerminator::Return(values) => { @@ -189,8 +189,13 @@ impl Translator { let function_id = if let Some(id) = self.builder.import_intrinsic(&function.name) { id } else { - let function_id = self.lookup_function(function)?; - self.builder.import_function(function_id) + let maybe_func = + self.lookup_function(&function).map(|f| self.builder.import_function(f)); + + maybe_func.or_else(|e| { + // e.g. `v2 = call v0(v1) -> u32`, a lambda passed as a parameter + self.lookup_variable(&function).map_err(|_| e) + })? }; let arguments = self.translate_values(arguments)?; @@ -295,7 +300,14 @@ impl Translator { ParsedValue::NumericConstant { constant, typ } => { Ok(self.builder.numeric_constant(constant, typ.unwrap_numeric())) } - ParsedValue::Variable(identifier) => self.lookup_variable(identifier), + ParsedValue::Variable(identifier) => self.lookup_variable(&identifier).or_else(|e| { + self.lookup_function(&identifier) + .map(|f| { + // e.g. `v3 = call f1(f2, v0) -> u32` + self.builder.import_function(f) + }) + .map_err(|_| e) + }), } } @@ -316,27 +328,27 @@ impl Translator { Ok(()) } - fn lookup_variable(&mut self, identifier: Identifier) -> Result { + fn lookup_variable(&mut self, identifier: &Identifier) -> Result { if let Some(value_id) = self.variables[&self.current_function_id()].get(&identifier.name) { Ok(*value_id) } else { - Err(SsaError::UnknownVariable(identifier)) + Err(SsaError::UnknownVariable(identifier.clone())) } } - fn lookup_block(&mut self, identifier: Identifier) -> Result { + fn lookup_block(&mut self, identifier: &Identifier) -> Result { if let Some(block_id) = self.blocks[&self.current_function_id()].get(&identifier.name) { Ok(*block_id) } else { - Err(SsaError::UnknownBlock(identifier)) + Err(SsaError::UnknownBlock(identifier.clone())) } } - fn lookup_function(&mut self, identifier: Identifier) -> Result { + fn lookup_function(&mut self, identifier: &Identifier) -> Result { if let Some(function_id) = self.functions.get(&identifier.name) { Ok(*function_id) } else { - Err(SsaError::UnknownFunction(identifier)) + Err(SsaError::UnknownFunction(identifier.clone())) } } diff --git a/compiler/noirc_frontend/src/monomorphization/ast.rs b/compiler/noirc_frontend/src/monomorphization/ast.rs index 65bddcb680..621eb30e4f 100644 --- a/compiler/noirc_frontend/src/monomorphization/ast.rs +++ b/compiler/noirc_frontend/src/monomorphization/ast.rs @@ -228,7 +228,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>; /// Represents how an Acir function should be inlined. /// This type is only relevant for ACIR functions as we do not inline any Brillig functions -#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)] +#[derive( + Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord, +)] pub enum InlineType { /// The most basic entry point can expect all its functions to be inlined. /// All function calls are expected to be inlined into a single ACIR.