diff --git a/crates/cairo-lang-lowering/src/optimizations/reboxing.rs b/crates/cairo-lang-lowering/src/optimizations/reboxing.rs index 1a59a03d151..397a2bfbe4a 100644 --- a/crates/cairo-lang-lowering/src/optimizations/reboxing.rs +++ b/crates/cairo-lang-lowering/src/optimizations/reboxing.rs @@ -14,7 +14,9 @@ use cairo_lang_utils::ordered_hash_set::OrderedHashSet; use itertools::Itertools; use salsa::Database; +use super::var_renamer::VarRenamer; use crate::borrow_check::analysis::StatementLocation; +use crate::utils::RebuilderEx; use crate::{ BlockEnd, Lowered, Statement, StatementStructDestructure, VarUsage, Variable, VariableArena, VariableId, @@ -170,8 +172,32 @@ pub fn apply_reboxing_candidates<'db>( trace!("Applying {} reboxing optimization(s).", candidates.len()); + let mut renamer = VarRenamer::default(); + let mut stmts_to_remove = Vec::new(); + for candidate in candidates { - apply_reboxing_candidate(db, lowered, candidate); + match &candidate.source { + ReboxingValue::Revoked => {} + ReboxingValue::Unboxed(id) => { + renamer.renamed_vars.insert(candidate.reboxed_var, *id); + stmts_to_remove.push(candidate.into_box_location); + } + ReboxingValue::MemberOfUnboxed { source, member } => { + replace_into_box_call(db, lowered, candidate, source, *member); + } + } + } + + // We sort the candidates such that removal of statements can be done in reverse order. + // Statements are expected to be in order due to analysis being forward, but we do not require + // the ordered assumption. + stmts_to_remove.sort_by_key(|(block_id, stmt_idx)| (block_id.0, *stmt_idx)); + for (block_id, stmt_idx) in stmts_to_remove.into_iter().rev() { + lowered.blocks[block_id].statements.remove(stmt_idx); + } + + for block in lowered.blocks.iter_mut() { + *block = renamer.rebuild_block(block); } } @@ -190,11 +216,13 @@ pub fn apply_reboxing<'db>(db: &'db dyn Database, lowered: &mut Lowered<'db>) { } } -/// Applies a single reboxing optimization for the given candidate. -fn apply_reboxing_candidate<'db>( +/// Replaces the call to `into_box` with a call to `struct_boxed_deconstruct`. +fn replace_into_box_call<'db>( db: &'db dyn Database, lowered: &mut Lowered<'db>, candidate: &ReboxCandidate, + source: &Rc, + member: usize, ) { trace!( "Applying optimization: candidate={:?}, reboxed={}", @@ -202,13 +230,6 @@ fn apply_reboxing_candidate<'db>( candidate.reboxed_var.index() ); - // TODO(eytan-starkware): Handle snapshot of box (e.g., @Box). - // Only support MemberOfUnboxed where source is Unboxed for now. - let ReboxingValue::MemberOfUnboxed { source, member } = &candidate.source else { - // If source is not member of unboxed, we are reboxing original value which is not supported - // yet. - return; - }; let ReboxingValue::Unboxed(source_var) = **source else { // When source of the value is not `Unboxes`, it is a nested MemberOfUnboxed, which is not // supported yet. @@ -220,7 +241,7 @@ fn apply_reboxing_candidate<'db>( db, &mut lowered.variables, source_var, - *member, + member, candidate.reboxed_var, &lowered.blocks[into_box_block].statements[into_box_stmt_idx], ) { diff --git a/crates/cairo-lang-lowering/src/optimizations/test_data/reboxing b/crates/cairo-lang-lowering/src/optimizations/test_data/reboxing index a8e712cfd18..71246bc5ea4 100644 --- a/crates/cairo-lang-lowering/src/optimizations/test_data/reboxing +++ b/crates/cairo-lang-lowering/src/optimizations/test_data/reboxing @@ -229,8 +229,6 @@ test_reboxing_analysis //! > function_name rebox_whole -//! > TODO(eytan-starkware): Add support for whole var reboxing - //! > module_code struct Simple { value: felt252, @@ -262,9 +260,8 @@ Parameters: v0: core::box::Box:: blk0 (root): Statements: (v1: test::Simple) <- core::box::unbox::(v0) - (v2: core::box::Box::) <- core::box::into_box::(v1) End: - Return(v2) + Return(v0) //! > ========================================================================== @@ -751,3 +748,104 @@ Statements: (v7: (core::box::Box::<@core::felt252>, @test::NonDrop)) <- struct_construct(v4, v6) End: Return(v7) + +//! > ========================================================================== + +//! > Test mixed reboxing types in different blocks + +//! > test_runner_name +test_reboxing_analysis + +//! > function_name +mixed_reboxing + +//! > TODO(eytan-starkware): When removing demand for Copy, check no double +//! > into_box remain + +//! > module_code +#[derive(Drop, Copy)] +struct Point { + x: felt252, + y: felt252, +} + +//! > function_code +fn mixed_reboxing(p: Box, flag: bool) -> Box { + if flag { + let unboxed = p.unbox(); + BoxTrait::new(unboxed.x) + } else { + let unboxed = p.unbox(); + let boxed = BoxTrait::new(unboxed); + let unboxed2 = boxed.unbox(); + BoxTrait::new(unboxed2.y) + } +} + +//! > candidates +v5, v9, v14 + +//! > before +Parameters: v0: core::box::Box::, v1: core::bool +blk0 (root): +Statements: +End: + Match(match_enum(v1) { + bool::False(v2) => blk1, + bool::True(v3) => blk2, + }) + +blk1: +Statements: + (v4: test::Point) <- core::box::unbox::(v0) + (v5: core::box::Box::) <- core::box::into_box::(v4) + (v6: test::Point) <- core::box::unbox::(v5) + (v7: core::felt252, v8: core::felt252) <- struct_destructure(v6) + (v9: core::box::Box::) <- core::box::into_box::(v8) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: test::Point) <- core::box::unbox::(v0) + (v12: core::felt252, v13: core::felt252) <- struct_destructure(v11) + (v14: core::box::Box::) <- core::box::into_box::(v12) +End: + Goto(blk3, {v14 -> v10}) + +blk3: +Statements: +End: + Return(v10) + +//! > after +Parameters: v0: core::box::Box::, v1: core::bool +blk0 (root): +Statements: +End: + Match(match_enum(v1) { + bool::False(v2) => blk1, + bool::True(v3) => blk2, + }) + +blk1: +Statements: + (v4: test::Point) <- core::box::unbox::(v0) + (v6: test::Point) <- core::box::unbox::(v0) + (v7: core::felt252, v8: core::felt252) <- struct_destructure(v6) + (v15: core::box::Box::, v9: core::box::Box::) <- struct_destructure(v0) +End: + Goto(blk3, {v9 -> v10}) + +blk2: +Statements: + (v11: test::Point) <- core::box::unbox::(v0) + (v12: core::felt252, v13: core::felt252) <- struct_destructure(v11) + (v14: core::box::Box::, v16: core::box::Box::) <- struct_destructure(v0) +End: + Goto(blk3, {v14 -> v10}) + +blk3: +Statements: +End: + Return(v10)