Skip to content

Commit

Permalink
Fix defunctionalization to inherit runtime of caller
Browse files Browse the repository at this point in the history
  • Loading branch information
aakoshh committed Jan 17, 2025
1 parent 7f2cdef commit bd4f90b
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 24 deletions.
2 changes: 1 addition & 1 deletion compiler/noirc_evaluator/src/ssa/ir/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ use super::map::Id;
use super::types::Type;
use super::value::ValueId;

#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub(crate) enum RuntimeType {
// A noir function, to be compiled in ACIR and executed by ACVM
Acir(InlineType),
Expand Down
48 changes: 26 additions & 22 deletions compiler/noirc_evaluator/src/ssa/opt/defunctionalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use crate::ssa::{
function_builder::FunctionBuilder,
ir::{
basic_block::BasicBlockId,
function::{Function, FunctionId, Signature},
function::{Function, FunctionId, RuntimeType, Signature},
instruction::{BinaryOp, Instruction},
types::{NumericType, Type},
value::{Value, ValueId},
Expand Down Expand Up @@ -43,12 +43,15 @@ struct ApplyFunction {
dispatches_to_multiple_functions: bool,
}

type Variants = BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>;
type ApplyFunctions = HashMap<(Signature, RuntimeType), ApplyFunction>;

/// Performs defunctionalization on all functions
/// This is done by changing all functions as value to be a number (FieldElement)
/// And creating apply functions that dispatch to the correct target by runtime comparisons with constants
#[derive(Debug, Clone)]
struct DefunctionalizationContext {
apply_functions: HashMap<Signature, ApplyFunction>,
apply_functions: ApplyFunctions,
}

impl Ssa {
Expand Down Expand Up @@ -104,7 +107,7 @@ impl DefunctionalizationContext {
};

// Find the correct apply function
let apply_function = self.get_apply_function(&signature);
let apply_function = self.get_apply_function(signature, func.runtime());

// Replace the instruction with a call to apply
let apply_function_value_id = func.dfg.import_function(apply_function.id);
Expand Down Expand Up @@ -152,19 +155,21 @@ impl DefunctionalizationContext {
}

/// Returns the apply function for the given signature
fn get_apply_function(&self, signature: &Signature) -> ApplyFunction {
*self.apply_functions.get(signature).expect("Could not find apply function")
fn get_apply_function(&self, signature: Signature, runtime: RuntimeType) -> ApplyFunction {
*self.apply_functions.get(&(signature, runtime)).expect("Could not find apply function")
}
}

/// Collects all functions used as values that can be called by their signatures
fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
let mut dynamic_dispatches: BTreeSet<Signature> = BTreeSet::new();
fn find_variants(ssa: &Ssa) -> Variants {
let mut dynamic_dispatches: BTreeSet<(Signature, RuntimeType)> = BTreeSet::new();
let mut functions_as_values: BTreeSet<FunctionId> = BTreeSet::new();

for function in ssa.functions.values() {
functions_as_values.extend(find_functions_as_values(function));
dynamic_dispatches.extend(find_dynamic_dispatches(function));
dynamic_dispatches.extend(
find_dynamic_dispatches(function).into_iter().map(|sig| (sig, function.runtime())),
);
}

let mut signature_to_functions_as_value: BTreeMap<Signature, Vec<FunctionId>> = BTreeMap::new();
Expand All @@ -174,16 +179,12 @@ fn find_variants(ssa: &Ssa) -> BTreeMap<Signature, Vec<FunctionId>> {
signature_to_functions_as_value.entry(signature).or_default().push(function_id);
}

let mut variants = BTreeMap::new();
let mut variants: Variants = BTreeMap::new();

for dispatch_signature in dynamic_dispatches {
let mut target_fns = vec![];
for (target_signature, functions) in &signature_to_functions_as_value {
if &dispatch_signature == target_signature {
target_fns.extend(functions);
}
}
variants.insert(dispatch_signature, target_fns);
for (dispatch_signature, caller_runtime) in dynamic_dispatches {
let target_fns =
signature_to_functions_as_value.get(&dispatch_signature).cloned().unwrap_or_default();
variants.insert((dispatch_signature, caller_runtime), target_fns);
}

variants
Expand Down Expand Up @@ -247,22 +248,23 @@ fn find_dynamic_dispatches(func: &Function) -> BTreeSet<Signature> {

fn create_apply_functions(
ssa: &mut Ssa,
variants_map: BTreeMap<Signature, Vec<FunctionId>>,
) -> HashMap<Signature, ApplyFunction> {
variants_map: BTreeMap<(Signature, RuntimeType), Vec<FunctionId>>,
) -> ApplyFunctions {
let mut apply_functions = HashMap::default();
for (signature, variants) in variants_map.into_iter() {
for ((signature, runtime), variants) in variants_map.into_iter() {
assert!(
!variants.is_empty(),
"ICE: at least one variant should exist for a dynamic call {signature:?}"
);
let dispatches_to_multiple_functions = variants.len() > 1;

let id = if dispatches_to_multiple_functions {
create_apply_function(ssa, signature.clone(), variants)
create_apply_function(ssa, signature.clone(), runtime, variants)
} else {
variants[0]
};
apply_functions.insert(signature, ApplyFunction { id, dispatches_to_multiple_functions });
apply_functions
.insert((signature, runtime), ApplyFunction { id, dispatches_to_multiple_functions });
}
apply_functions
}
Expand All @@ -275,13 +277,15 @@ fn function_id_to_field(function_id: FunctionId) -> FieldElement {
fn create_apply_function(
ssa: &mut Ssa,
signature: Signature,
runtime: RuntimeType,
function_ids: Vec<FunctionId>,
) -> FunctionId {
assert!(!function_ids.is_empty());
let globals = ssa.functions[&function_ids[0]].dfg.globals.clone();
ssa.add_fn(|id| {
let mut function_builder = FunctionBuilder::new("apply".to_string(), id);
function_builder.set_globals(globals);
function_builder.set_runtime(runtime);
let target_id = function_builder.add_parameter(Type::field());
let params_ids = vecmap(signature.params, |typ| function_builder.add_parameter(typ));

Expand Down
4 changes: 3 additions & 1 deletion compiler/noirc_frontend/src/monomorphization/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,9 @@ pub type Parameters = Vec<(LocalId, /*mutable:*/ bool, /*name:*/ String, Type)>;

/// Represents how an Acir function should be inlined.
/// This type is only relevant for ACIR functions as we do not inline any Brillig functions
#[derive(Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize)]
#[derive(
Default, Clone, Copy, PartialEq, Eq, Debug, Hash, Serialize, Deserialize, PartialOrd, Ord,
)]
pub enum InlineType {
/// The most basic entry point can expect all its functions to be inlined.
/// All function calls are expected to be inlined into a single ACIR.
Expand Down

0 comments on commit bd4f90b

Please sign in to comment.