Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: defunctionalize pass on the caller runtime to apply #7100

Merged
merged 9 commits into from
Jan 17, 2025
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::map::Id;
use super::types::Type;
use super::value::ValueId;

#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub(crate) enum RuntimeType {
// A noir function, to be compiled in ACIR and executed by ACVM
Acir(InlineType),
Expand Down
203 changes: 174 additions & 29 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@ use std::collections::{BTreeMap, BTreeSet, HashSet};

use acvm::FieldElement;
use iter_extended::vecmap;
use noirc_frontend::monomorphization::ast::InlineType;

use crate::ssa::{
function_builder::FunctionBuilder,
ir::{
basic_block::BasicBlockId,
function::{Function, FunctionId, Signature},
function::{Function, FunctionId, RuntimeType, Signature},
instruction::{BinaryOp, Instruction},
types::{NumericType, Type},
value::{Value, ValueId},
Expand Down Expand Up @@ -43,12 +44,15 @@ struct ApplyFunction {
dispatches_to_multiple_functions: bool,
}

type Variants = BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>;
type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>;

/// Performs defunctionalization on all functions
/// This is done by changing all functions as value to be a number (FieldElement)
/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants
#[derive(Debug, Clone)]
struct DefunctionalizationContext {
apply_functions: HashMap<Signature, ApplyFunction>,
apply_functions: ApplyFunctions,
}

impl Ssa {
Expand Down Expand Up @@ -104,7 +108,7 @@ impl DefunctionalizationContext {
};

// Find the correct apply function
let apply_function = self.get_apply_function(&signature);
let apply_function = self.get_apply_function(signature, func.runtime());

// Replace the instruction with a call to apply
let apply_function_value_id = func.dfg.import_function(apply_function.id);
Expand Down Expand Up @@ -152,19 +156,21 @@ impl DefunctionalizationContext {
}

/// Returns the apply function for the given signature
fn get_apply_function(&self, signature: &Signature) -> ApplyFunction {
*self.apply_functions.get(signature).expect("Could not find apply function")
fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction {
*self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function")
}
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: BTreeSet<Signature> = BTreeSet::new();
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

for function in ssa.functions.values() {
functions_as_values.extend(find_functions_as_values(function));
dynamic_dispatches.extend(find_dynamic_dispatches(function));
dynamic_dispatches.extend(
find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())),
);
}

let mut signature_to_functions_as_value: BTreeMap<Signature, Vec<FunctionId>> = BTreeMap::new();
Expand All @@ -174,16 +180,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
signature_to_functions_as_value.entry(signature).or_default().push(function_id);
}

let mut variants = BTreeMap::new();
let mut variants: Variants = BTreeMap::new();

for dispatch_signature in dynamic_dispatches {
let mut target_fns = vec![];
for (target_signature, functions) in &signature_to_functions_as_value {
if &dispatch_signature == target_signature {
target_fns.extend(functions);
}
}
variants.insert(dispatch_signature, target_fns);
for (dispatch_signature, caller_runtime) in dynamic_dispatches {
let target_fns =
signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default();
variants.insert((dispatch_signature, caller_runtime), target_fns);
}

variants
Expand Down Expand Up @@ -247,22 +249,23 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet<Signature> {

fn create_apply_functions(
ssa: &mut Ssa,
variants_map: BTreeMap<Signature, Vec<FunctionId>>,
) -> HashMap<Signature, ApplyFunction> {
variants_map: BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>,
) -> ApplyFunctions {
let mut apply_functions = HashMap::default();
for (signature, variants) in variants_map.into_iter() {
for ((signature, runtime), variants) in variants_map.into_iter() {
assert!(
!variants.is_empty(),
"ICE: at least one variant should exist for a dynamic call {signature:?}"
);
let dispatches_to_multiple_functions = variants.len() > 1;

let id = if dispatches_to_multiple_functions {
create_apply_function(ssa, signature.clone(), variants)
create_apply_function(ssa, signature.clone(), runtime, variants)
} else {
variants[0]
};
apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions });
apply_functions
.insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions });
}
apply_functions
}
Expand All @@ -275,13 +278,21 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement {
fn create_apply_function(
ssa: &mut Ssa,
signature: Signature,
caller_runtime: RuntimeType,
function_ids: Vec<FunctionId>,
) -> FunctionId {
assert!(!function_ids.is_empty());
let globals = ssa.functions[&function_ids[0]].dfg.globals.clone();
ssa.add_fn(|id| {
let mut function_builder = FunctionBuilder::new("apply".to_string(), id);
function_builder.set_globals(globals);

// We want to push for apply functions to be inlined more aggressively.
let runtime = match caller_runtime {
michaeljklein marked this conversation as resolved.
Show resolved Hide resolved
RuntimeType::Acir(_) => RuntimeType::Acir(InlineType::InlineAlways),
RuntimeType::Brillig(_) => RuntimeType::Brillig(InlineType::InlineAlways),
};
function_builder.set_runtime(runtime);
let target_id = function_builder.add_parameter(Type::field());
let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ));

Expand Down Expand Up @@ -339,22 +350,156 @@ fn create_apply_function(
})
}

/// Crates a return block, if no previous return exists, it will create a final return
/// Else, it will create a bypass return block that points to the previous return block
/// If no previous return target exists, it will create a final return,
/// otherwise returns the existing return block to jump to.
fn build_return_block(
builder: &mut FunctionBuilder,
previous_block: BasicBlockId,
passed_types: &[Type],
target: Option<BasicBlockId>,
) -> BasicBlockId {
if let Some(return_block) = target {
return return_block;
}
let return_block = builder.insert_block();
builder.switch_to_block(return_block);

let params = vecmap(passed_types, |typ| builder.add_block_parameter(return_block, typ.clone()));
match target {
None => builder.terminate_with_return(params),
Some(target) => builder.terminate_with_jmp(target, params),
}
builder.terminate_with_return(params);
builder.switch_to_block(previous_block);
return_block
}

#[cfg(test)]
mod tests {
use crate::ssa::opt::assert_normalized_ssa_equals;

use super::Ssa;

#[test]
fn apply_inherits_caller_runtime() {
// Extracted from `execution_success/brillig_fns_as_values` with `--force-brillig`
let src = "
brillig(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(f2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

let expected = "
brillig(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(Field 2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f1(Field 3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: Field, v1: u32):
v3 = call f4(v0, v1) -> u32
return v3
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
brillig(inline_always) fn apply f4 {
b0(v0: Field, v1: u32):
v4 = eq v0, Field 2
jmpif v4 then: b2, else: b1
b1():
constrain v0 == Field 3
v7 = call f3(v1) -> u32
jmp b3(v7)
b2():
v9 = call f2(v1) -> u32
jmp b3(v9)
b3(v2: u32):
return v2
}
";
assert_normalized_ssa_equals(ssa, expected);
}

#[test]
fn apply_created_per_caller_runtime() {
let src = "
acir(inline) fn main f0 {
b0(v0: u32):
v3 = call f1(f2, v0) -> u32
v5 = add v0, u32 1
v6 = eq v3, v5
constrain v3 == v5
v9 = call f4(f3, v0) -> u32
v10 = add v0, u32 1
v11 = eq v9, v10
constrain v9 == v10
return
}
brillig(inline) fn wrapper f1 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
acir(inline) fn wrapper_acir f4 {
b0(v0: function, v1: u32):
v2 = call v0(v1) -> u32
return v2
}
brillig(inline) fn increment f2 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
acir(inline) fn increment_acir f3 {
b0(v0: u32):
v2 = add v0, u32 1
return v2
}
";

let ssa = Ssa::from_str(src).unwrap();
let ssa = ssa.defunctionalize();

let applies = ssa.functions.values().filter(|f| f.name() == "apply").collect::<Vec<_>>();
assert_eq!(applies.len(), 2);
assert!(applies.iter().any(|f| f.runtime().is_acir()));
assert!(applies.iter().any(|f| f.runtime().is_brillig()));
}
}
17 changes: 11 additions & 6 deletions compiler/noirc_evaluator/src/ssa/opt/inlining.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,22 +66,27 @@ impl Ssa {
self.functions = btree_map(inline_sources, |entry_point| {
let should_inline_call =
|_context: &PerFunctionContext, ssa: &Ssa, called_func_id: FunctionId| -> bool {
let function = &ssa.functions[&called_func_id];
let callee = &ssa.functions[&called_func_id];
let caller_runtime = ssa.functions[entry_point].runtime();

match function.runtime() {
match callee.runtime() {
RuntimeType::Acir(inline_type) => {
// If the called function is acir, we inline if it's not an entry point

// If we have not already finished the flattening pass, functions marked
// to not have predicates should be preserved.
let preserve_function =
!inline_no_predicates_functions && function.is_no_predicates();
!inline_no_predicates_functions && callee.is_no_predicates();
!inline_type.is_entry_point() && !preserve_function
}
RuntimeType::Brillig(_) => {
// If the called function is brillig, we inline only if it's into brillig and the function is not recursive
ssa.functions[entry_point].runtime().is_brillig()
&& !inline_sources.contains(&called_func_id)
if caller_runtime.is_acir() {
// We never inline a brillig function into an ACIR function.
return false;
}

// Avoid inlining recursive functions.
!inline_sources.contains(&called_func_id)
}
}
};
Expand Down
Loading
Loading