Skip to content

Commit

Permalink
fix: defunctionalize pass on the caller runtime to apply (#7100)
Browse files Browse the repository at this point in the history
Co-authored-by: Tom French <[email protected]>
  • Loading branch information
aakoshh and TomAFrench authored Jan 17, 2025
1 parent 7705a62 commit dd70845
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 50 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::map::Id;
use super::types::Type;
use super::value::ValueId;

#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub(crate) enum RuntimeType {
// A noir function, to be compiled in ACIR and executed by ACVM
Acir(InlineType),
Expand Down
203 changes: 174 additions & 29 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use std::collections::{BTreeMap, BTreeSet, HashSet};

use acvm::FieldElement;
use iter_extended::vecmap;
use noirc_frontend::monomorphization::ast::InlineType;

use crate::ssa::{
function_builder::FunctionBuilder,
ir::{
basic_block::BasicBlockId,
function::{Function, FunctionId, Signature},
function::{Function, FunctionId, RuntimeType, Signature},
instruction::{BinaryOp, Instruction},
types::{NumericType, Type},
value::{Value, ValueId},
Expand Down Expand Up @@ -43,12 +44,15 @@ struct ApplyFunction {
dispatches_to_multiple_functions: bool,
}

type Variants = BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>;
type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>;

/// Performs defunctionalization on all functions
/// This is done by changing all functions as value to be a number (FieldElement)
/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants
#[derive(Debug, Clone)]
struct DefunctionalizationContext {
apply_functions: HashMap<Signature, ApplyFunction>,
apply_functions: ApplyFunctions,
}

impl Ssa {
Expand Down Expand Up @@ -104,7 +108,7 @@ impl DefunctionalizationContext {
};

// Find the correct apply function
let apply_function = self.get_apply_function(&signature);
let apply_function = self.get_apply_function(signature, func.runtime());

// Replace the instruction with a call to apply
let apply_function_value_id = func.dfg.import_function(apply_function.id);
Expand Down Expand Up @@ -152,19 +156,21 @@ impl DefunctionalizationContext {
}

/// Returns the apply function for the given signature
fn get_apply_function(&self, signature: &Signature) -> ApplyFunction {
*self.apply_functions.get(signature).expect("Could not find apply function")
fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction {
*self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function")
}
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: BTreeSet<Signature> = BTreeSet::new();
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

for function in ssa.functions.values() {
functions_as_values.extend(find_functions_as_values(function));
dynamic_dispatches.extend(find_dynamic_dispatches(function));
dynamic_dispatches.extend(
find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())),
);
}

let mut signature_to_functions_as_value: BTreeMap<Signature, Vec<FunctionId>> = BTreeMap::new();
Expand All @@ -174,16 +180,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
signature_to_functions_as_value.entry(signature).or_default().push(function_id);
}

let mut variants = BTreeMap::new();
let mut variants: Variants = BTreeMap::new();

for dispatch_signature in dynamic_dispatches {
let mut target_fns = vec![];
for (target_signature, functions) in &signature_to_functions_as_value {
if &dispatch_signature == target_signature {
target_fns.extend(functions);
}
}
variants.insert(dispatch_signature, target_fns);
for (dispatch_signature, caller_runtime) in dynamic_dispatches {
let target_fns =
signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default();
variants.insert((dispatch_signature, caller_runtime), target_fns);
}

variants
Expand Down Expand Up @@ -247,22 +249,23 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet<Signature> {

fn create_apply_functions(
ssa: &mut Ssa,
variants_map: BTreeMap<Signature, Vec<FunctionId>>,
) -> HashMap<Signature, ApplyFunction> {
variants_map: BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>,
) -> ApplyFunctions {
let mut apply_functions = HashMap::default();
for (signature, variants) in variants_map.into_iter() {
for ((signature, runtime), variants) in variants_map.into_iter() {
assert!(
!variants.is_empty(),
"ICE: at least one variant should exist for a dynamic call {signature:?}"
);
let dispatches_to_multiple_functions = variants.len() > 1;

let id = if dispatches_to_multiple_functions {
create_apply_function(ssa, signature.clone(), variants)
create_apply_function(ssa, signature.clone(), runtime, variants)
} else {
variants[0]
};
apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions });
apply_functions
.insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions });
}
apply_functions
}
Expand All @@ -275,13 +278,21 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement {
fn create_apply_function(
ssa: &mut Ssa,
signature: Signature,
caller_runtime: RuntimeType,
function_ids: Vec<FunctionId>,
) -> FunctionId {
assert!(!function_ids.is_empty());
let globals = ssa.functions[&function_ids[0]].dfg.globals.clone();
ssa.add_fn(|id| {
let mut function_builder = FunctionBuilder::new("apply".to_string(), id);
function_builder.set_globals(globals);

// We want to push for apply functions to be inlined more aggressively.
let runtime = match caller_runtime {
RuntimeType::Acir(_) => RuntimeType::Acir(InlineType::InlineAlways),
RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways),
};
function_builder.set_runtime(runtime);
let target_id = function_builder.add_parameter(Type::field());
let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ));

Expand Down Expand Up @@ -339,22 +350,156 @@ fn create_apply_function(
})
}

/// Crates a return block, if no previous return exists, it will create a final return
/// Else, it will create a bypass return block that points to the previous return block
/// If no previous return target exists, it will create a final return,
/// otherwise returns the existing return block to jump to.
fn build_return_block(
builder: &mut FunctionBuilder,
previous_block: BasicBlockId,
passed_types: &[Type],
target: Option<BasicBlockId>,
) -> BasicBlockId {
if let Some(return_block) = target {
return return_block;
}
let return_block = builder.insert_block();
builder.switch_to_block(return_block);

let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone()));
match target {
None => builder.terminate_with_return(params),
Some(target) => builder.terminate_with_jmp(target, params),
}
builder.terminate_with_return(params);
builder.switch_to_block(previous_block);
return_block
}

#[cfg(test)]
mod tests {
use crate::ssa::opt::assert_normalized_ssa_equals;

use super::Ssa;

#[test]
fn apply_inherits_caller_runtime() {
// Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig`
let src = "
brillig(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(f2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

let expected = "
brillig(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(Field 2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(Field 3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: Field, v1: u32):
v3 = call f4(v0, v1) -> u32
return v3
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline_always) fn apply f4 {
b0(v0: Field, v1: u32):
v4 = eq v0, Field 2
jmpif v4 then: b2, else: b1
b1():
constrain v0 == Field 3
v7 = call f3(v1) -> u32
jmp b3(v7)
b2():
v9 = call f2(v1) -> u32
jmp b3(v9)
b3(v2: u32):
return v2
}
";
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn apply_created_per_caller_runtime() {
let src = "
acir(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(f2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f4(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
acir(inline) fn wrapper_acir f4 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
acir(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

let applies = ssa.functions.values().filter(|f| f.name() == "apply").collect::<Vec<_>>();
assert_eq!(applies.len(), 2);
assert!(applies.iter().any(|f| f.runtime().is_acir()));
assert!(applies.iter().any(|f| f.runtime().is_brillig()));
}
}
17 changes: 11 additions & 6 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,27 @@ impl Ssa {
self.functions = btree_map(inline_sources, |entry_point| {
let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
let function = &ssa.functions[&called_func_id];
let callee = &ssa.functions[&called_func_id];
let caller_runtime = ssa.functions[entry_point].runtime();

match function.runtime() {
match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && function.is_no_predicates();
!inline_no_predicates_functions && callee.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[entry_point].runtime().is_brillig()
&& !inline_sources.contains(&called_func_id)
if caller_runtime.is_acir() {
// We never inline a brillig function into an ACIR function.
return false;
}

// Avoid inlining recursive functions.
!inline_sources.contains(&called_func_id)
}
}
};
Expand Down
Loading

0 comments on commit dd70845

Please sign in to comment.