Skip to content

Commit 4ab9474

Browse files
committed
refactor: gateway compiler handle declare tx
1 parent c9f77e7 commit 4ab9474

File tree

8 files changed

+88
-60
lines changed

8 files changed

+88
-60
lines changed

Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/gateway/src/compilation.rs

+55-45
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
55
use cairo_lang_starknet_classes::casm_contract_class::{
66
CasmContractClass, CasmContractEntryPoints,
77
};
8+
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
89
use starknet_api::core::CompiledClassHash;
910
use starknet_api::rpc_transaction::RPCDeclareTransaction;
1011
use starknet_sierra_compile::compile::compile_sierra_to_casm;
@@ -29,64 +30,57 @@ impl GatewayCompiler {
2930
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
3031
/// class wrapped in a [`ClassInfo`].
3132
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
32-
pub fn compile_contract_class(
33+
pub fn process_declare_tx(
3334
&self,
3435
declare_tx: &RPCDeclareTransaction,
3536
) -> GatewayResult<ClassInfo> {
3637
let RPCDeclareTransaction::V3(tx) = declare_tx;
37-
let starknet_api_contract_class = &tx.contract_class;
38-
let cairo_lang_contract_class =
39-
into_contract_class_for_compilation(starknet_api_contract_class);
38+
let rpc_contract_class = &tx.contract_class;
39+
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);
4040

41-
// Compile Sierra to Casm.
41+
let casm_contract_class = self.compile(cairo_lang_contract_class)?;
42+
43+
validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
44+
validate_casm_class(&casm_contract_class)?;
45+
46+
Ok(ClassInfo::new(
47+
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
48+
rpc_contract_class.sierra_program.len(),
49+
rpc_contract_class.abi.len(),
50+
)?)
51+
}
52+
53+
// TODO(Arni): Pass the compilation args from the config.
54+
// TODO(Arni): Pass the compiler with dependancy injection.
55+
fn compile(
56+
&self,
57+
cairo_lang_contract_class: CairoLangContractClass,
58+
) -> Result<CasmContractClass, GatewayError> {
4259
let catch_unwind_result =
4360
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
44-
let casm_contract_class = match catch_unwind_result {
45-
Ok(compilation_result) => compilation_result?,
46-
Err(_) => {
47-
// TODO(Arni): Log the panic.
48-
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
49-
}
50-
};
51-
self.validate_casm_class(&casm_contract_class)?;
61+
let casm_contract_class =
62+
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;
5263

53-
let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
54-
if hash_result != tx.compiled_class_hash {
55-
return Err(GatewayError::CompiledClassHashMismatch {
56-
supplied: tx.compiled_class_hash,
57-
hash_result,
58-
});
59-
}
60-
61-
// Convert Casm contract class to Starknet contract class directly.
62-
let blockifier_contract_class =
63-
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
64-
let class_info = ClassInfo::new(
65-
&blockifier_contract_class,
66-
starknet_api_contract_class.sierra_program.len(),
67-
starknet_api_contract_class.abi.len(),
68-
)?;
69-
Ok(class_info)
64+
Ok(casm_contract_class)
7065
}
66+
}
7167

72-
// TODO(Arni): Add test.
73-
fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> {
74-
let CasmContractEntryPoints { external, l1_handler, constructor } =
75-
&contract_class.entry_points_by_type;
76-
let entry_points_iterator =
77-
external.iter().chain(l1_handler.iter()).chain(constructor.iter());
68+
// TODO(Arni): Add test.
69+
fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> {
70+
let CasmContractEntryPoints { external, l1_handler, constructor } =
71+
&contract_class.entry_points_by_type;
72+
let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter());
7873

79-
for entry_point in entry_points_iterator {
80-
let builtins = &entry_point.builtins;
81-
if !is_subsequence(builtins, supported_builtins()) {
82-
return Err(GatewayError::UnsupportedBuiltins {
83-
builtins: builtins.clone(),
84-
supported_builtins: supported_builtins().to_vec(),
85-
});
86-
}
74+
for entry_point in entry_points_iterator {
75+
let builtins = &entry_point.builtins;
76+
if !is_subsequence(builtins, supported_builtins()) {
77+
return Err(GatewayError::UnsupportedBuiltins {
78+
builtins: builtins.clone(),
79+
supported_builtins: supported_builtins().to_vec(),
80+
});
8781
}
88-
Ok(())
8982
}
83+
Ok(())
9084
}
9185

9286
// TODO(Arni): Add to a config.
@@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
10195
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
10296
})
10397
}
98+
99+
/// Validates that the compiled class hash of the compiled contract class matches the supplied
100+
/// compiled class hash.
101+
fn validate_compiled_class_hash(
102+
casm_contract_class: &CasmContractClass,
103+
supplied_compiled_class_hash: &CompiledClassHash,
104+
) -> Result<(), GatewayError> {
105+
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
106+
if compiled_class_hash != *supplied_compiled_class_hash {
107+
return Err(GatewayError::CompiledClassHashMismatch {
108+
supplied: *supplied_compiled_class_hash,
109+
hash_result: compiled_class_hash,
110+
});
111+
}
112+
Ok(())
113+
}

crates/gateway/src/compilation_test.rs

+6-5
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@ fn gateway_compiler() -> GatewayCompiler {
1515
GatewayCompiler { config: Default::default() }
1616
}
1717

18+
// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection.
1819
#[rstest]
19-
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
20+
fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) {
2021
let mut tx = assert_matches!(
2122
declare_tx(),
2223
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
@@ -27,7 +28,7 @@ fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: G
2728
tx.compiled_class_hash = supplied_hash;
2829
let declare_tx = RPCDeclareTransaction::V3(tx);
2930

30-
let result = gateway_compiler.compile_contract_class(&declare_tx);
31+
let result = gateway_compiler.process_declare_tx(&declare_tx);
3132
assert_matches!(
3233
result.unwrap_err(),
3334
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
@@ -45,7 +46,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
4546
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
4647
let declare_tx = RPCDeclareTransaction::V3(tx);
4748

48-
let result = gateway_compiler.compile_contract_class(&declare_tx);
49+
let result = gateway_compiler.process_declare_tx(&declare_tx);
4950
assert_matches!(
5051
result.unwrap_err(),
5152
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
@@ -55,15 +56,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
5556
}
5657

5758
#[rstest]
58-
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
59+
fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) {
5960
let declare_tx = assert_matches!(
6061
declare_tx(),
6162
RPCTransaction::Declare(declare_tx) => declare_tx
6263
);
6364
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
6465
let contract_class = &declare_tx_v3.contract_class;
6566

66-
let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
67+
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
6768
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
6869
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
6970
assert_eq!(class_info.abi_length(), contract_class.abi.len());

crates/gateway/src/gateway.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ fn process_tx(
128128
// Compile Sierra to Casm.
129129
let optional_class_info = match &tx {
130130
RPCTransaction::Declare(declare_tx) => {
131-
Some(gateway_compiler.compile_contract_class(declare_tx)?)
131+
Some(gateway_compiler.process_declare_tx(declare_tx)?)
132132
}
133133
_ => None,
134134
};

crates/gateway/src/stateful_transaction_validator_test.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn test_stateful_tx_validator(
9797
let optional_class_info = match &external_tx {
9898
RPCTransaction::Declare(declare_tx) => Some(
9999
GatewayCompiler { config: GatewayCompilerConfig {} }
100-
.compile_contract_class(declare_tx)
100+
.process_declare_tx(declare_tx)
101101
.unwrap(),
102102
),
103103
_ => None,

crates/mempool_test_utils/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ license.workspace = true
1010
[dependencies]
1111
assert_matches.workspace = true
1212
blockifier = { workspace = true, features = ["testing"] }
13+
cairo-lang-starknet-classes.workspace = true
14+
rstest.workspace = true
1315
starknet-types-core.workspace = true
1416
starknet_api.workspace = true
1517
serde_json.workspace = true

crates/mempool_test_utils/src/starknet_api_test_utils.rs

+16-3
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::rc::Rc;
88
use assert_matches::assert_matches;
99
use blockifier::test_utils::contracts::FeatureContract;
1010
use blockifier::test_utils::{create_trivial_calldata, CairoVersion, NonceManager};
11+
use rstest::fixture;
1112
use serde_json::to_string_pretty;
1213
use starknet_api::core::{
1314
calculate_contract_address, ClassHash, CompiledClassHash, ContractAddress, Nonce,
@@ -90,11 +91,23 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
9091
)
9192
}
9293

93-
pub fn declare_tx() -> RPCTransaction {
94+
/// Get the contract class used for testing.
95+
#[fixture]
96+
pub fn contract_class() -> ContractClass {
9497
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
9598
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
96-
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
97-
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
99+
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
100+
}
101+
102+
/// Get the compiled class hash corresponding to the contract class used for testing.
103+
#[fixture]
104+
pub fn compiled_class_hash() -> CompiledClassHash {
105+
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
106+
}
107+
108+
pub fn declare_tx() -> RPCTransaction {
109+
let contract_class = contract_class();
110+
let compiled_class_hash = compiled_class_hash();
98111

99112
let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
100113
let account_address = account_contract.get_instance_address(0);

crates/starknet_sierra_compile/src/utils.rs

+5-5
Original file line numberDiff line numberDiff line change
@@ -6,25 +6,25 @@ use cairo_lang_starknet_classes::contract_class::{
66
};
77
use cairo_lang_utils::bigint::BigUintAsHex;
88
use starknet_api::rpc_transaction::{
9-
ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType,
9+
ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType,
1010
};
1111
use starknet_api::state::EntryPoint as StarknetApiEntryPoint;
1212
use starknet_types_core::felt::Felt;
1313

1414
/// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi`
1515
/// field is None as it is not relevant for the compilation.
1616
pub fn into_contract_class_for_compilation(
17-
starknet_api_contract_class: &StarknetApiContractClass,
17+
rpc_contract_class: &RpcContractClass,
1818
) -> CairoLangContractClass {
1919
let sierra_program =
20-
starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
20+
rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect();
2121
let entry_points_by_type =
22-
into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type);
22+
into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type);
2323

2424
CairoLangContractClass {
2525
sierra_program,
2626
sierra_program_debug_info: None,
27-
contract_class_version: starknet_api_contract_class.contract_class_version.clone(),
27+
contract_class_version: rpc_contract_class.contract_class_version.clone(),
2828
entry_points_by_type,
2929
abi: None,
3030
}

0 commit comments

Comments
 (0)