Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 18 additions & 7 deletions crates/cairo-lang-sierra-generator/src/block_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
mod test;

use cairo_lang_diagnostics::Maybe;
use cairo_lang_filesystem::flag::flag_future_sierra;
use cairo_lang_lowering::BlockId;
use cairo_lang_lowering::db::LoweringGroup;
use cairo_lang_lowering::ids::LocationId;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use itertools::{chain, enumerate, zip_eq};
Expand All @@ -24,9 +26,9 @@ use crate::utils::{
drop_libfunc_id, dup_libfunc_id, enable_ap_tracking_libfunc_id,
enum_from_bounded_int_libfunc_id, enum_init_libfunc_id, get_concrete_libfunc_id,
get_libfunc_signature, into_box_libfunc_id, jump_libfunc_id, jump_statement,
match_enum_libfunc_id, rename_libfunc_id, return_statement, simple_basic_statement,
snapshot_take_libfunc_id, struct_construct_libfunc_id, struct_deconstruct_libfunc_id,
unbox_libfunc_id,
local_into_box_libfunc_id, match_enum_libfunc_id, rename_libfunc_id, return_statement,
simple_basic_statement, snapshot_take_libfunc_id, struct_construct_libfunc_id,
struct_deconstruct_libfunc_id, unbox_libfunc_id,
};

/// Generates Sierra code for the body of the given [lowering::Block].
Expand Down Expand Up @@ -564,11 +566,20 @@ fn generate_statement_into_box<'db>(
statement_location: &StatementLocation,
) -> Maybe<()> {
let input = maybe_add_dup_statement(context, statement_location, 0, &statement.input)?;
let ty = context.get_variable_sierra_type(statement.input.var_id)?;
let db = context.get_db();
let semantic_ty = context.get_lowered_variable(statement.input.var_id).ty;
// When size < 3, into_box is cheaper.
let use_local_into_box = flag_future_sierra(db)
&& context.is_non_ap_based(statement.input.var_id)
&& db.type_size(semantic_ty) >= 3;
let libfunc_id = if use_local_into_box {
local_into_box_libfunc_id(db, ty)
} else {
into_box_libfunc_id(db, ty)
};
let stmt = simple_basic_statement(
into_box_libfunc_id(
context.get_db(),
context.get_variable_sierra_type(statement.input.var_id)?,
),
libfunc_id,
&[input],
&[context.get_sierra_variable(statement.output)],
);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ fn block_generator_test(
function_id,
&lifetime,
crate::ap_tracking::ApTrackingConfiguration::default(),
Default::default(),
);

let mut expected_sierra_code = String::default();
Expand Down
14 changes: 14 additions & 0 deletions crates/cairo-lang-sierra-generator/src/expr_generator_context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use cairo_lang_sierra::extensions::uninitialized::UninitializedType;
use cairo_lang_sierra::program::{ConcreteTypeLongId, GenericArg};
use cairo_lang_utils::Intern;
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use lowering::ids::ConcreteFunctionWithBodyId;
use lowering::{BlockId, Lowered};
use salsa::Database;
Expand All @@ -26,6 +27,7 @@ pub struct ExprGeneratorContext<'db, 'a> {
var_id_allocator: IdAllocator,
label_id_allocator: IdAllocator,
variables: OrderedHashMap<SierraGenVar, cairo_lang_sierra::ids::VarId>,
non_ap_based_variables: UnorderedHashSet<VariableId>,
/// Allocated Sierra variables and their locations.
variable_locations: Vec<(cairo_lang_sierra::ids::VarId, LocationId<'db>)>,
block_labels: OrderedHashMap<BlockId, pre_sierra::LabelId<'db>>,
Expand All @@ -48,6 +50,7 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
function_id: ConcreteFunctionWithBodyId<'db>,
lifetime: &'a VariableLifetimeResult,
ap_tracking_configuration: ApTrackingConfiguration,
non_ap_based_variables: UnorderedHashSet<VariableId>,
) -> Self {
ExprGeneratorContext {
db,
Expand All @@ -61,6 +64,7 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
block_labels: OrderedHashMap::default(),
ap_tracking_enabled: true,
ap_tracking_configuration,
non_ap_based_variables,
statements: vec![],
curr_cairo_location: None,
}
Expand Down Expand Up @@ -189,6 +193,16 @@ impl<'db, 'a> ExprGeneratorContext<'db, 'a> {
&& self.ap_tracking_configuration.disable_ap_tracking.contains(block_id)
}

/// Returns true if the variable is non-AP-based.
pub fn is_non_ap_based(&self, var_id: VariableId) -> bool {
self.non_ap_based_variables.contains(&var_id)
}

/// Returns the lowered variable for the given variable id.
pub fn get_lowered_variable(&self, var_id: VariableId) -> &lowering::Variable<'db> {
&self.lowered.variables[var_id]
}

/// Adds a statement for the expression.
pub fn push_statement(&mut self, statement: pre_sierra::Statement<'db>) {
self.statements.push(pre_sierra::StatementWithLocation {
Expand Down
10 changes: 8 additions & 2 deletions crates/cairo-lang-sierra-generator/src/function_generator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ fn get_function_ap_change_and_code<'db>(
analyze_ap_change_result: AnalyzeApChangesResult,
) -> Maybe<pre_sierra::Function<'db>> {
let root_block = lowered_function.blocks.root_block()?;
let AnalyzeApChangesResult { known_ap_change, local_variables, ap_tracking_configuration } =
analyze_ap_change_result;
let AnalyzeApChangesResult {
known_ap_change,
local_variables,
ap_tracking_configuration,
non_ap_based_variables,
} = analyze_ap_change_result;

// Get lifetime information.
let lifetime = find_variable_lifetime(lowered_function, &local_variables)?;
Expand All @@ -77,6 +81,7 @@ fn get_function_ap_change_and_code<'db>(
function_id,
&lifetime,
ap_tracking_configuration,
non_ap_based_variables,
);

// If the function starts with `revoke_ap_tracking` then we can avoid
Expand Down Expand Up @@ -159,6 +164,7 @@ pub fn priv_get_dummy_function<'db>(
function_id,
&lifetime,
ap_tracking_configuration,
Default::default(),
);

// Generate a label for the function's body.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cairo_lang_test_utils::test_file_test!(
function_generator,
"src/function_generator_test_data",
{
boxing: "boxing",
inline: "inline",
struct_: "struct",
match_: "match",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
//! > Test local_into_box for large struct parameter (size >= 3).

//! > test_runner_name
test_function_generator

//! > function_code
fn foo(a: MyStruct) -> Box<MyStruct> {
BoxTrait::new(a)
}

//! > function_name
foo

//! > module_code
struct MyStruct {
x: felt252,
y: felt252,
z: felt252,
}

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > sierra_gen_diagnostics

//! > sierra_code
label_test::foo::0:
local_into_box<test::MyStruct>([0]) -> ([1])
store_temp<Box<test::MyStruct>>([1]) -> ([1])
return([1])

//! > ==========================================================================

//! > Test into_box for small struct parameter (size < 3).

//! > test_runner_name
test_function_generator

//! > function_code
fn foo(a: SmallStruct) -> Box<SmallStruct> {
BoxTrait::new(a)
}

//! > function_name
foo

//! > module_code
struct SmallStruct {
x: felt252,
y: felt252,
}

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > sierra_gen_diagnostics

//! > sierra_code
label_test::foo::0:
into_box<test::SmallStruct>([0]) -> ([1])
return([1])

//! > ==========================================================================

//! > Test chained boxing: inner uses local_into_box, outer uses into_box (Box has size 1).

//! > test_runner_name
test_function_generator

//! > function_code
fn foo(a: MyStruct) -> Box<Box<MyStruct>> {
BoxTrait::new(BoxTrait::new(a))
}

//! > function_name
foo

//! > module_code
struct MyStruct {
x: felt252,
y: felt252,
z: felt252,
}

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > sierra_gen_diagnostics

//! > sierra_code
label_test::foo::0:
local_into_box<test::MyStruct>([0]) -> ([1])
store_temp<Box<test::MyStruct>>([1]) -> ([1])
into_box<Box<test::MyStruct>>([1]) -> ([2])
return([2])

//! > ==========================================================================

//! > Test local_into_box for variable that becomes local due to revoke.

//! > test_runner_name
test_function_generator

//! > function_code
fn foo() -> Box<MyStruct> {
let x = create_struct();
revoke_ap();
BoxTrait::new(x)
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
struct MyStruct {
x: felt252,
y: felt252,
z: felt252,
}

#[inline(never)]
fn create_struct() -> MyStruct {
MyStruct { x: 1, y: 2, z: 3 }
}

fn revoke_ap() {
revoke_ap()
}

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > sierra_gen_diagnostics

//! > sierra_code
label_test::foo::0:
alloc_local<test::MyStruct>() -> ([1])
finalize_locals() -> ()
disable_ap_tracking() -> ()
function_call<user@test::create_struct>() -> ([0])
store_local<test::MyStruct>([1], [0]) -> ([0])
function_call<user@test::revoke_ap>() -> ()
local_into_box<test::MyStruct>([0]) -> ([2])
store_temp<Box<test::MyStruct>>([2]) -> ([2])
return([2])

//! > ==========================================================================

//! > Test local_into_box for snapshot of parameter (alias is non-AP-based).

//! > test_runner_name
test_function_generator

//! > function_code
fn foo(a: MyStruct) -> Box<@MyStruct> {
BoxTrait::new(@a)
}

//! > function_name
foo

//! > module_code
#[derive(Drop)]
struct MyStruct {
x: felt252,
y: felt252,
z: felt252,
}

//! > semantic_diagnostics

//! > lowering_diagnostics

//! > sierra_gen_diagnostics

//! > sierra_code
label_test::foo::0:
snapshot_take<test::MyStruct>([0]) -> ([1], [2])
drop<test::MyStruct>([1]) -> ()
local_into_box<test::MyStruct>([2]) -> ([3])
store_temp<Box<test::MyStruct>>([3]) -> ([3])
return([3])
13 changes: 12 additions & 1 deletion crates/cairo-lang-sierra-generator/src/local_variables.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ use crate::utils::{

/// Information returned by [analyze_ap_changes].
pub struct AnalyzeApChangesResult {
/// True if the function has a known_ap_change
/// True if the function has a known_ap_change.
pub known_ap_change: bool,
/// The variables that should be stored in locals as they are revoked during the function.
pub local_variables: OrderedHashSet<VariableId>,
/// Information about where ap tracking should be enabled and disabled.
pub ap_tracking_configuration: ApTrackingConfiguration,
/// Variables that are known to be non-AP-based (expanded to include all aliases).
pub non_ap_based_variables: UnorderedHashSet<VariableId>,
}

/// Does ap change related analysis for a given function.
Expand Down Expand Up @@ -98,6 +100,14 @@ pub fn analyze_ap_changes<'db>(
}
}

// Expand non_ap_based to include all aliases.
let non_ap_based_variables: UnorderedHashSet<_> = lowered_function
.variables
.iter()
.map(|(id, _)| id)
.filter(|v| ctx.non_ap_based.contains(ctx.peel_aliases(v)))
.collect();

Ok(AnalyzeApChangesResult {
known_ap_change: root_info.known_ap_change,
local_variables: locals,
Expand All @@ -106,6 +116,7 @@ pub fn analyze_ap_changes<'db>(
root_info.known_ap_change,
need_ap_alignment,
),
non_ap_based_variables,
})
}

Expand Down
8 changes: 8 additions & 0 deletions crates/cairo-lang-sierra-generator/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,14 @@ pub fn unbox_libfunc_id(
get_libfunc_id_with_generic_arg(db, "unbox", ty)
}

/// Returns the [ConcreteLibfuncId] associated with `local_into_box`.
pub fn local_into_box_libfunc_id(
db: &dyn Database,
ty: cairo_lang_sierra::ids::ConcreteTypeId,
) -> cairo_lang_sierra::ids::ConcreteLibfuncId {
get_libfunc_id_with_generic_arg(db, "local_into_box", ty)
}

pub fn enum_init_libfunc_id(
db: &dyn Database,
ty: cairo_lang_sierra::ids::ConcreteTypeId,
Expand Down
Loading
Loading