Skip to content

Commit 7644b4d

Browse files
committed
refactor: gateway compiler handle declare tx
1 parent d09a690 commit 7644b4d

File tree

8 files changed

+87
-65
lines changed

8 files changed

+87
-65
lines changed

Cargo.lock

+1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/gateway/src/compilation.rs

+55-45
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
55
use cairo_lang_starknet_classes::casm_contract_class::{
66
CasmContractClass, CasmContractEntryPoints,
77
};
8+
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
89
use starknet_api::core::CompiledClassHash;
910
use starknet_api::rpc_transaction::RPCDeclareTransaction;
1011
use starknet_sierra_compile::compile::compile_sierra_to_casm;
@@ -19,6 +20,7 @@ use crate::utils::is_subsequence;
1920
#[path = "compilation_test.rs"]
2021
mod compilation_test;
2122

23+
// TODO(Arni): Pass the compiler with dependancy injection.
2224
#[derive(Clone)]
2325
pub struct GatewayCompiler {
2426
#[allow(dead_code)]
@@ -29,64 +31,56 @@ impl GatewayCompiler {
2931
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
3032
/// class wrapped in a [`ClassInfo`].
3133
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
32-
pub fn compile_contract_class(
34+
pub fn process_declare_tx(
3335
&self,
3436
declare_tx: &RPCDeclareTransaction,
3537
) -> GatewayResult<ClassInfo> {
3638
let RPCDeclareTransaction::V3(tx) = declare_tx;
37-
let starknet_api_contract_class = &tx.contract_class;
38-
let cairo_lang_contract_class =
39-
into_contract_class_for_compilation(starknet_api_contract_class);
39+
let rpc_contract_class = &tx.contract_class;
40+
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);
4041

41-
// Compile Sierra to Casm.
42+
let casm_contract_class = self.compile(cairo_lang_contract_class)?;
43+
44+
validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
45+
validate_casm_class(&casm_contract_class)?;
46+
47+
Ok(ClassInfo::new(
48+
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
49+
rpc_contract_class.sierra_program.len(),
50+
rpc_contract_class.abi.len(),
51+
)?)
52+
}
53+
54+
// TODO(Arni): Pass the compilation args from the config.
55+
fn compile(
56+
&self,
57+
cairo_lang_contract_class: CairoLangContractClass,
58+
) -> Result<CasmContractClass, GatewayError> {
4259
let catch_unwind_result =
4360
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
44-
let casm_contract_class = match catch_unwind_result {
45-
Ok(compilation_result) => compilation_result?,
46-
Err(_) => {
47-
// TODO(Arni): Log the panic.
48-
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
49-
}
50-
};
51-
self.validate_casm_class(&casm_contract_class)?;
61+
let casm_contract_class =
62+
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;
5263

53-
let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
54-
if hash_result != tx.compiled_class_hash {
55-
return Err(GatewayError::CompiledClassHashMismatch {
56-
supplied: tx.compiled_class_hash,
57-
hash_result,
58-
});
59-
}
60-
61-
// Convert Casm contract class to Starknet contract class directly.
62-
let blockifier_contract_class =
63-
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
64-
let class_info = ClassInfo::new(
65-
&blockifier_contract_class,
66-
starknet_api_contract_class.sierra_program.len(),
67-
starknet_api_contract_class.abi.len(),
68-
)?;
69-
Ok(class_info)
64+
Ok(casm_contract_class)
7065
}
66+
}
7167

72-
// TODO(Arni): Add test.
73-
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
74-
let CasmContractEntryPoints { external, l1_handler, constructor } =
75-
&contract_class.entry_points_by_type;
76-
let entry_points_iterator =
77-
external.iter().chain(l1_handler.iter()).chain(constructor.iter());
68+
// TODO(Arni): Add test.
69+
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
70+
let CasmContractEntryPoints { external, l1_handler, constructor } =
71+
&contract_class.entry_points_by_type;
72+
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());
7873

79-
for entry_point in entry_points_iterator {
80-
let builtins = &entry_point.builtins;
81-
if !is_subsequence(builtins, supported_builtins()) {
82-
return Err(GatewayError::UnsupportedBuiltins {
83-
builtins: builtins.clone(),
84-
supported_builtins: supported_builtins().to_vec(),
85-
});
86-
}
74+
for entry_point in entry_points_iterator {
75+
let builtins = &entry_point.builtins;
76+
if !is_subsequence(builtins, supported_builtins()) {
77+
return Err(GatewayError::UnsupportedBuiltins {
78+
builtins: builtins.clone(),
79+
supported_builtins: supported_builtins().to_vec(),
80+
});
8781
}
88-
Ok(())
8982
}
83+
Ok(())
9084
}
9185

9286
// TODO(Arni): Add to a config.
@@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
10195
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
10296
})
10397
}
98+
99+
/// Validates that the compiled class hash of the compiled contract class matches the supplied
100+
/// compiled class hash.
101+
fn validate_compiled_class_hash(
102+
casm_contract_class: &CasmContractClass,
103+
supplied_compiled_class_hash: &CompiledClassHash,
104+
) -> Result<(), GatewayError> {
105+
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
106+
if compiled_class_hash != *supplied_compiled_class_hash {
107+
return Err(GatewayError::CompiledClassHashMismatch {
108+
supplied: *supplied_compiled_class_hash,
109+
hash_result: compiled_class_hash,
110+
});
111+
}
112+
Ok(())
113+
}

crates/gateway/src/compilation_test.rs

+10-10
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,23 @@ fn gateway_compiler() -> GatewayCompiler {
1515
GatewayCompiler { config: Default::default() }
1616
}
1717

18+
// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection.
1819
#[rstest]
19-
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
20+
fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) {
2021
let mut tx = assert_matches!(
2122
declare_tx(),
2223
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
2324
);
24-
let expected_hash_result = tx.compiled_class_hash;
25-
let supplied_hash = CompiledClassHash::default();
26-
27-
tx.compiled_class_hash = supplied_hash;
25+
let expected_hash = tx.compiled_class_hash;
26+
let wrong_supplied_hash = CompiledClassHash::default();
27+
tx.compiled_class_hash = wrong_supplied_hash;
2828
let declare_tx = RPCDeclareTransaction::V3(tx);
2929

30-
let result = gateway_compiler.compile_contract_class(&declare_tx);
30+
let result = gateway_compiler.process_declare_tx(&declare_tx);
3131
assert_matches!(
3232
result.unwrap_err(),
3333
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
34-
if supplied == supplied_hash && hash_result == expected_hash_result
34+
if supplied == wrong_supplied_hash && hash_result == expected_hash
3535
);
3636
}
3737

@@ -45,7 +45,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
4545
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
4646
let declare_tx = RPCDeclareTransaction::V3(tx);
4747

48-
let result = gateway_compiler.compile_contract_class(&declare_tx);
48+
let result = gateway_compiler.process_declare_tx(&declare_tx);
4949
assert_matches!(
5050
result.unwrap_err(),
5151
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
@@ -55,15 +55,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
5555
}
5656

5757
#[rstest]
58-
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
58+
fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) {
5959
let declare_tx = assert_matches!(
6060
declare_tx(),
6161
RPCTransaction::Declare(declare_tx) => declare_tx
6262
);
6363
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
6464
let contract_class = &declare_tx_v3.contract_class;
6565

66-
let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
66+
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
6767
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
6868
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
6969
assert_eq!(class_info.abi_length(), contract_class.abi.len());

crates/gateway/src/gateway.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ fn process_tx(
128128
// Compile Sierra to Casm.
129129
let optional_class_info = match &tx {
130130
RPCTransaction::Declare(declare_tx) => {
131-
Some(gateway_compiler.compile_contract_class(declare_tx)?)
131+
Some(gateway_compiler.process_declare_tx(declare_tx)?)
132132
}
133133
_ => None,
134134
};

crates/gateway/src/stateful_transaction_validator_test.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn test_stateful_tx_validator(
9797
let optional_class_info = match &external_tx {
9898
RPCTransaction::Declare(declare_tx) => Some(
9999
GatewayCompiler { config: GatewayCompilerConfig {} }
100-
.compile_contract_class(declare_tx)
100+
.process_declare_tx(declare_tx)
101101
.unwrap(),
102102
),
103103
_ => None,

crates/mempool_test_utils/Cargo.toml

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ license.workspace = true
1010
[dependencies]
1111
assert_matches.workspace = true
1212
blockifier = { workspace = true, features = ["testing"] }
13+
rstest.workspace = true
1314
starknet-types-core.workspace = true
1415
starknet_api.workspace = true
1516
serde_json.workspace = true

crates/mempool_test_utils/src/starknet_api_test_utils.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
9090
)
9191
}
9292

93-
pub fn declare_tx() -> RPCTransaction {
93+
/// Get the contract class used for testing.
94+
pub fn contract_class() -> ContractClass {
9495
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
9596
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
96-
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
97-
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
97+
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
98+
}
99+
100+
/// Get the compiled class hash corresponding to the contract class used for testing.
101+
pub fn compiled_class_hash() -> CompiledClassHash {
102+
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
103+
}
104+
105+
pub fn declare_tx() -> RPCTransaction {
106+
let contract_class = contract_class();
107+
let compiled_class_hash = compiled_class_hash();
98108

99109
let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
100110
let account_address = account_contract.get_instance_address(0);

crates/starknet_sierra_compile/src/utils.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,25 @@ use cairo_lang_starknet_classes::contract_class::{
66
};
77
use cairo_lang_utils::bigint::BigUintAsHex;
88
use starknet_api::rpc_transaction::{
9-
ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType,
9+
ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType,
1010
};
1111
use starknet_api::state::EntryPoint as StarknetApiEntryPoint;
1212
use starknet_types_core::felt::Felt;
1313

1414
/// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi`
1515
/// field is None as it is not relevant for the compilation.
1616
pub fn into_contract_class_for_compilation(
17-
starknet_api_contract_class: &StarknetApiContractClass,
17+
rpc_contract_class: &RpcContractClass,
1818
) -> CairoLangContractClass {
1919
let sierra_program =
20-
starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
20+
rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
2121
let entry_points_by_type =
22-
into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type);
22+
into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type);
2323

2424
CairoLangContractClass {
2525
sierra_program,
2626
sierra_program_debug_info: None,
27-
contract_class_version: starknet_api_contract_class.contract_class_version.clone(),
27+
contract_class_version: rpc_contract_class.contract_class_version.clone(),
2828
entry_points_by_type,
2929
abi: None,
3030
}

0 commit comments

Comments
 (0)