Skip to content

Commit 9be13c0

Browse files
added specialization for recursive functions
1 parent d6823c2 commit 9be13c0

File tree

6 files changed

+832
-73
lines changed

6 files changed

+832
-73
lines changed

crates/cairo-lang-lowering/src/lower/generated_test.rs

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@ use cairo_lang_defs::ids::TopLevelLanguageElementId;
55
use cairo_lang_filesystem::location_marks::get_location_marks;
66
use cairo_lang_semantic::test_utils::setup_test_function;
77
use cairo_lang_test_utils::parse_test_file::TestRunnerResult;
8-
use cairo_lang_utils::Intern;
98
use cairo_lang_utils::ordered_hash_map::OrderedHashMap;
9+
use cairo_lang_utils::{Intern, try_extract_matches};
1010

11-
use crate::LoweringStage;
1211
use crate::db::LoweringGroup;
1312
use crate::fmt::LoweredFormatter;
14-
use crate::ids::{ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, GeneratedFunction};
13+
use crate::ids::{
14+
ConcreteFunctionWithBodyId, ConcreteFunctionWithBodyLongId, FunctionLongId, GeneratedFunction,
15+
};
1516
use crate::test_utils::LoweringDatabaseForTesting;
17+
use crate::{LoweringStage, Statement};
1618

1719
cairo_lang_test_utils::test_file_test!(
1820
generated,
@@ -64,6 +66,34 @@ fn test_generated_function(
6466
lowering.debug(&LoweredFormatter::new(db, &lowering.variables))
6567
)
6668
.unwrap();
69+
// Collect the calls to the generated functions in the final lowering of the main function.
70+
let calls = lowering
71+
.blocks
72+
.iter()
73+
.flat_map(|(_, block)| {
74+
block
75+
.statements
76+
.iter()
77+
.filter_map(|statement| try_extract_matches!(statement, Statement::Call))
78+
})
79+
.flat_map(|call| match call.function.long(db) {
80+
FunctionLongId::Semantic(_) => None,
81+
FunctionLongId::Generated(generated) => Some((
82+
generated.key,
83+
ConcreteFunctionWithBodyLongId::Generated(*generated).intern(db),
84+
)),
85+
FunctionLongId::Specialized(specialized) => try_extract_matches!(
86+
specialized.base.function_id(db).unwrap().long(db),
87+
FunctionLongId::Generated
88+
)
89+
.map(|generated| {
90+
(
91+
generated.key,
92+
ConcreteFunctionWithBodyLongId::Specialized(specialized.clone()).intern(db),
93+
)
94+
}),
95+
})
96+
.collect::<OrderedHashMap<_, _>>();
6797

6898
for (key, lowering) in multi_lowering.generated_lowerings.iter() {
6999
let generated_id = ConcreteFunctionWithBodyLongId::Generated(GeneratedFunction {
@@ -103,6 +133,17 @@ fn test_generated_function(
103133
lowering.debug(&LoweredFormatter::new(db, &lowering.variables))
104134
)
105135
.unwrap();
136+
137+
if let Some(call) = calls.get(key) {
138+
let lowering = db.lowered_body(*call, LoweringStage::Final).unwrap();
139+
writeln!(
140+
&mut writer,
141+
"Final lowering of specialized call {:?}:\n{:?}",
142+
call.full_path(db),
143+
lowering.debug(&LoweredFormatter::new(db, &lowering.variables))
144+
)
145+
.unwrap();
146+
}
106147
}
107148
}
108149

crates/cairo-lang-lowering/src/lower/test_data/for

Lines changed: 122 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -54,41 +54,39 @@ Final lowering:
5454
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin
5555
blk0 (root):
5656
Statements:
57-
(v2: core::felt252) <- 1
58-
(v3: core::felt252) <- 2
59-
(v4: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
60-
(v5: core::felt252) <- 10
57+
(v2: core::array::Array::<core::felt252>) <- core::array::array_new::<core::felt252>()
58+
(v3: core::felt252) <- 10
59+
(v4: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v2, v3)
60+
(v5: core::felt252) <- 11
6161
(v6: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v4, v5)
62-
(v7: core::felt252) <- 11
62+
(v7: core::felt252) <- 12
6363
(v8: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v6, v7)
64-
(v9: core::felt252) <- 12
64+
(v9: core::felt252) <- 13
6565
(v10: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v8, v9)
66-
(v11: core::felt252) <- 13
67-
(v12: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v10, v11)
68-
(v13: core::array::Array::<core::felt252>, v14: @core::array::Array::<core::felt252>) <- snapshot(v12)
69-
(v15: core::array::Span::<core::felt252>) <- struct_construct(v14)
70-
(v16: core::array::SpanIter::<core::felt252>) <- struct_construct(v15)
71-
(v17: core::RangeCheck, v18: core::gas::GasBuiltin, v19: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[118-164](v0, v1, v16, v2, v3)
72-
End:
73-
Match(match_enum(v19) {
74-
PanicResult::Ok(v20) => blk1,
75-
PanicResult::Err(v21) => blk2,
66+
(v11: core::array::Array::<core::felt252>, v12: @core::array::Array::<core::felt252>) <- snapshot(v10)
67+
(v13: core::array::Span::<core::felt252>) <- struct_construct(v12)
68+
(v14: core::array::SpanIter::<core::felt252>) <- struct_construct(v13)
69+
(v15: core::RangeCheck, v16: core::gas::GasBuiltin, v17: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[118-164]{NotSpecialized, 1, 2, }(v0, v1, v14)
70+
End:
71+
Match(match_enum(v17) {
72+
PanicResult::Ok(v18) => blk1,
73+
PanicResult::Err(v19) => blk2,
7674
})
7775

7876
blk1:
7977
Statements:
80-
(v22: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v18)
81-
(v23: core::array::SpanIter::<core::felt252>, v24: core::felt252, v25: ()) <- struct_destructure(v20)
82-
(v26: (core::felt252,)) <- struct_construct(v24)
83-
(v27: core::panics::PanicResult::<(core::felt252,)>) <- PanicResult::Ok(v26)
78+
(v20: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v16)
79+
(v21: core::array::SpanIter::<core::felt252>, v22: core::felt252, v23: ()) <- struct_destructure(v18)
80+
(v24: (core::felt252,)) <- struct_construct(v22)
81+
(v25: core::panics::PanicResult::<(core::felt252,)>) <- PanicResult::Ok(v24)
8482
End:
85-
Return(v17, v22, v27)
83+
Return(v15, v20, v25)
8684

8785
blk2:
8886
Statements:
89-
(v28: core::panics::PanicResult::<(core::felt252,)>) <- PanicResult::Err(v21)
87+
(v26: core::panics::PanicResult::<(core::felt252,)>) <- PanicResult::Err(v19)
9088
End:
91-
Return(v17, v18, v28)
89+
Return(v15, v16, v26)
9290

9391

9492
Generated loop lowering for source location:
@@ -170,6 +168,56 @@ Statements:
170168
End:
171169
Return(v7, v8, v28)
172170

171+
172+
Final lowering of specialized call "test::foo[118-164]{NotSpecialized, 1, 2, }":
173+
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin, v2: core::array::SpanIter::<core::felt252>
174+
blk0 (root):
175+
Statements:
176+
End:
177+
Match(match core::gas::withdraw_gas(v0, v1) {
178+
Option::Some(v3, v4) => blk1,
179+
Option::None(v5, v6) => blk4,
180+
})
181+
182+
blk1:
183+
Statements:
184+
(v7: core::array::Span::<core::felt252>) <- struct_destructure(v2)
185+
(v8: @core::array::Array::<core::felt252>) <- struct_destructure(v7)
186+
End:
187+
Match(match core::array::array_snapshot_pop_front::<core::felt252>(v8) {
188+
Option::Some(v9, v10) => blk2,
189+
Option::None(v11) => blk3,
190+
})
191+
192+
blk2:
193+
Statements:
194+
(v12: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
195+
(v13: core::felt252) <- 3
196+
(v14: core::array::Span::<core::felt252>) <- struct_construct(v9)
197+
(v15: core::array::SpanIter::<core::felt252>) <- struct_construct(v14)
198+
(v16: core::RangeCheck, v17: core::gas::GasBuiltin, v18: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- test::foo[118-164]{NotSpecialized, NotSpecialized, 2, }(v3, v12, v15, v13)
199+
End:
200+
Return(v16, v17, v18)
201+
202+
blk3:
203+
Statements:
204+
(v19: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v4)
205+
(v20: core::felt252) <- 1
206+
(v21: core::array::Span::<core::felt252>) <- struct_construct(v11)
207+
(v22: core::array::SpanIter::<core::felt252>) <- struct_construct(v21)
208+
(v23: ()) <- struct_construct()
209+
(v24: (core::array::SpanIter::<core::felt252>, core::felt252, ())) <- struct_construct(v22, v20, v23)
210+
(v25: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- PanicResult::Ok(v24)
211+
End:
212+
Return(v3, v19, v25)
213+
214+
blk4:
215+
Statements:
216+
(v26: (core::panics::Panic, core::array::Array::<core::felt252>)) <- core::panic_with_const_felt252::<375233589013918064796019>()
217+
(v27: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::felt252, ())>) <- PanicResult::Err(v26)
218+
End:
219+
Return(v5, v6, v27)
220+
173221
//! > ==========================================================================
174222

175223
//! > Test calling function with generics Self.
@@ -424,3 +472,54 @@ Statements:
424472
(v29: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::array::Array::<core::felt252>, ())>) <- PanicResult::Err(v28)
425473
End:
426474
Return(v6, v7, v29)
475+
476+
477+
Final lowering of specialized call "test::foo[78-125]":
478+
Parameters: v0: core::RangeCheck, v1: core::gas::GasBuiltin, v2: core::array::SpanIter::<core::felt252>, v3: core::array::Array::<core::felt252>
479+
blk0 (root):
480+
Statements:
481+
End:
482+
Match(match core::gas::withdraw_gas(v0, v1) {
483+
Option::Some(v4, v5) => blk1,
484+
Option::None(v6, v7) => blk4,
485+
})
486+
487+
blk1:
488+
Statements:
489+
(v8: core::array::Span::<core::felt252>) <- struct_destructure(v2)
490+
(v9: @core::array::Array::<core::felt252>) <- struct_destructure(v8)
491+
End:
492+
Match(match core::array::array_snapshot_pop_front::<core::felt252>(v9) {
493+
Option::Some(v10, v11) => blk2,
494+
Option::None(v12) => blk3,
495+
})
496+
497+
blk2:
498+
Statements:
499+
(v13: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v5)
500+
(v14: @core::felt252) <- core::box::unbox::<@core::felt252>(v11)
501+
(v15: core::felt252) <- desnap(v14)
502+
(v16: core::array::Array::<core::felt252>) <- core::array::array_append::<core::felt252>(v3, v15)
503+
(v17: core::array::Span::<core::felt252>) <- struct_construct(v10)
504+
(v18: core::array::SpanIter::<core::felt252>) <- struct_construct(v17)
505+
(v19: core::RangeCheck, v20: core::gas::GasBuiltin, v21: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::array::Array::<core::felt252>, ())>) <- test::foo[78-125](v4, v13, v18, v16)
506+
End:
507+
Return(v19, v20, v21)
508+
509+
blk3:
510+
Statements:
511+
(v22: core::gas::GasBuiltin) <- core::gas::redeposit_gas(v5)
512+
(v23: core::array::Span::<core::felt252>) <- struct_construct(v12)
513+
(v24: core::array::SpanIter::<core::felt252>) <- struct_construct(v23)
514+
(v25: ()) <- struct_construct()
515+
(v26: (core::array::SpanIter::<core::felt252>, core::array::Array::<core::felt252>, ())) <- struct_construct(v24, v3, v25)
516+
(v27: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::array::Array::<core::felt252>, ())>) <- PanicResult::Ok(v26)
517+
End:
518+
Return(v4, v22, v27)
519+
520+
blk4:
521+
Statements:
522+
(v28: (core::panics::Panic, core::array::Array::<core::felt252>)) <- core::panic_with_const_felt252::<375233589013918064796019>()
523+
(v29: core::panics::PanicResult::<(core::array::SpanIter::<core::felt252>, core::array::Array::<core::felt252>, ())>) <- PanicResult::Err(v28)
524+
End:
525+
Return(v6, v7, v29)

0 commit comments

Comments
 (0)