Skip to content

Commit 989780d

Browse files
committed
refactor: gateway compiler handle declare tx
1 parent ab3a562 commit 989780d

File tree

10 files changed

+8095
-67
lines changed

10 files changed

+8095
-67
lines changed

Cargo.lock

+2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/gateway/src/compilation.rs

+40-30
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl
55
use cairo_lang_starknet_classes::casm_contract_class::{
66
CasmContractClass, CasmContractEntryPoints,
77
};
8+
use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass;
89
use starknet_api::core::CompiledClassHash;
910
use starknet_api::rpc_transaction::RPCDeclareTransaction;
1011
use starknet_sierra_compile::compile::compile_sierra_to_casm;
@@ -29,44 +30,37 @@ impl GatewayCompiler {
2930
/// Formats the contract class for compilation, compiles it, and returns the compiled contract
3031
/// class wrapped in a [`ClassInfo`].
3132
/// Assumes the contract class is of a Sierra program which is compiled to Casm.
32-
pub fn compile_contract_class(
33+
pub fn process_declare_tx(
3334
&self,
3435
declare_tx: &RPCDeclareTransaction,
3536
) -> GatewayResult<ClassInfo> {
3637
let RPCDeclareTransaction::V3(tx) = declare_tx;
37-
let starknet_api_contract_class = &tx.contract_class;
38-
let cairo_lang_contract_class =
39-
into_contract_class_for_compilation(starknet_api_contract_class);
38+
let rpc_contract_class = &tx.contract_class;
39+
let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class);
4040

41-
// Compile Sierra to Casm.
42-
let catch_unwind_result =
43-
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
44-
let casm_contract_class = match catch_unwind_result {
45-
Ok(compilation_result) => compilation_result?,
46-
Err(_) => {
47-
// TODO(Arni): Log the panic.
48-
return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic));
49-
}
50-
};
41+
let casm_contract_class = self.compile(cairo_lang_contract_class)?;
42+
43+
validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?;
5144
self.validate_casm_class(&casm_contract_class)?;
5245

53-
let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash());
54-
if hash_result != tx.compiled_class_hash {
55-
return Err(GatewayError::CompiledClassHashMismatch {
56-
supplied: tx.compiled_class_hash,
57-
hash_result,
58-
});
59-
}
46+
Ok(ClassInfo::new(
47+
&ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?),
48+
rpc_contract_class.sierra_program.len(),
49+
rpc_contract_class.abi.len(),
50+
)?)
51+
}
6052

61-
// Convert Casm contract class to Starknet contract class directly.
62-
let blockifier_contract_class =
63-
ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?);
64-
let class_info = ClassInfo::new(
65-
&blockifier_contract_class,
66-
starknet_api_contract_class.sierra_program.len(),
67-
starknet_api_contract_class.abi.len(),
68-
)?;
69-
Ok(class_info)
53+
/// TODO(Arni): Pass the compilation args from the config.
54+
fn compile(
55+
&self,
56+
cairo_lang_contract_class: CairoLangContractClass,
57+
) -> Result<CasmContractClass, GatewayError> {
58+
let catch_unwind_result =
59+
panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class));
60+
let casm_contract_class =
61+
catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??;
62+
63+
Ok(casm_contract_class)
7064
}
7165

7266
// TODO(Arni): Add test.
@@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec<String> {
10195
SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::<Vec<String>>()
10296
})
10397
}
98+
99+
/// Validates that the compiled class hash of the compiled contract class matches the supplied
100+
/// compiled class hash.
101+
fn validate_compiled_class_hash(
102+
casm_contract_class: &CasmContractClass,
103+
supplied_compiled_class_hash: &CompiledClassHash,
104+
) -> Result<(), GatewayError> {
105+
let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash());
106+
if compiled_class_hash != *supplied_compiled_class_hash {
107+
return Err(GatewayError::CompiledClassHashMismatch {
108+
supplied: *supplied_compiled_class_hash,
109+
hash_result: compiled_class_hash,
110+
});
111+
}
112+
Ok(())
113+
}

crates/gateway/src/compilation_test.rs

+27-26
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,19 @@
11
use assert_matches::assert_matches;
22
use blockifier::execution::contract_class::ContractClass;
33
use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError;
4-
use mempool_test_utils::starknet_api_test_utils::declare_tx;
4+
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
5+
use mempool_test_utils::starknet_api_test_utils::{
6+
casm_contract_class, compiled_class_hash, contract_class, declare_tx,
7+
};
58
use rstest::{fixture, rstest};
69
use starknet_api::core::CompiledClassHash;
7-
use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction};
10+
use starknet_api::rpc_transaction::{
11+
ContractClass as RpcContractClass, RPCDeclareTransaction, RPCTransaction,
12+
};
813
use starknet_sierra_compile::errors::CompilationUtilError;
14+
use starknet_sierra_compile::utils::into_contract_class_for_compilation;
915

10-
use crate::compilation::GatewayCompiler;
16+
use crate::compilation::{validate_compiled_class_hash, GatewayCompiler};
1117
use crate::errors::GatewayError;
1218

1319
#[fixture]
@@ -16,36 +22,31 @@ fn gateway_compiler() -> GatewayCompiler {
1622
}
1723

1824
#[rstest]
19-
fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) {
20-
let mut tx = assert_matches!(
21-
declare_tx(),
22-
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
23-
);
24-
let expected_hash_result = tx.compiled_class_hash;
25-
let supplied_hash = CompiledClassHash::default();
26-
27-
tx.compiled_class_hash = supplied_hash;
28-
let declare_tx = RPCDeclareTransaction::V3(tx);
25+
fn test_compile_contract_class_compiled_class_hash_mismatch(
26+
casm_contract_class: CasmContractClass,
27+
compiled_class_hash: CompiledClassHash,
28+
) {
29+
let wrong_supplied_hash = CompiledClassHash::default();
30+
let expected_hash = compiled_class_hash;
2931

30-
let result = gateway_compiler.compile_contract_class(&declare_tx);
32+
let result = validate_compiled_class_hash(&casm_contract_class, &wrong_supplied_hash);
3133
assert_matches!(
3234
result.unwrap_err(),
3335
GatewayError::CompiledClassHashMismatch { supplied, hash_result }
34-
if supplied == supplied_hash && hash_result == expected_hash_result
36+
if supplied == wrong_supplied_hash && hash_result == expected_hash
3537
);
3638
}
3739

3840
#[rstest]
39-
fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
40-
let mut tx = assert_matches!(
41-
declare_tx(),
42-
RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx
43-
);
44-
// Truncate the sierra program to trigger an error.
45-
tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec();
46-
let declare_tx = RPCDeclareTransaction::V3(tx);
41+
fn test_compile_contract_class_bad_sierra(
42+
gateway_compiler: GatewayCompiler,
43+
mut contract_class: RpcContractClass,
44+
) {
45+
// Create a corrupted contract class.
46+
contract_class.sierra_program = contract_class.sierra_program[..100].to_vec();
4747

48-
let result = gateway_compiler.compile_contract_class(&declare_tx);
48+
let cairo_lang_contract_class = into_contract_class_for_compilation(&contract_class);
49+
let result = gateway_compiler.compile(cairo_lang_contract_class);
4950
assert_matches!(
5051
result.unwrap_err(),
5152
GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError(
@@ -55,15 +56,15 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) {
5556
}
5657

5758
#[rstest]
58-
fn test_compile_contract_class(gateway_compiler: GatewayCompiler) {
59+
fn test_process_declare_tx(gateway_compiler: GatewayCompiler) {
5960
let declare_tx = assert_matches!(
6061
declare_tx(),
6162
RPCTransaction::Declare(declare_tx) => declare_tx
6263
);
6364
let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx;
6465
let contract_class = &declare_tx_v3.contract_class;
6566

66-
let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap();
67+
let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap();
6768
assert_matches!(class_info.contract_class(), ContractClass::V1(_));
6869
assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len());
6970
assert_eq!(class_info.abi_length(), contract_class.abi.len());

crates/gateway/src/gateway.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ fn process_tx(
128128
// Compile Sierra to Casm.
129129
let optional_class_info = match &tx {
130130
RPCTransaction::Declare(declare_tx) => {
131-
Some(gateway_compiler.compile_contract_class(declare_tx)?)
131+
Some(gateway_compiler.process_declare_tx(declare_tx)?)
132132
}
133133
_ => None,
134134
};

crates/gateway/src/stateful_transaction_validator_test.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ fn test_stateful_tx_validator(
9797
let optional_class_info = match &external_tx {
9898
RPCTransaction::Declare(declare_tx) => Some(
9999
GatewayCompiler { config: GatewayCompilerConfig {} }
100-
.compile_contract_class(declare_tx)
100+
.process_declare_tx(declare_tx)
101101
.unwrap(),
102102
),
103103
_ => None,

crates/mempool_test_utils/Cargo.toml

+2
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ license.workspace = true
1010
[dependencies]
1111
assert_matches.workspace = true
1212
blockifier = { workspace = true, features = ["testing"] }
13+
cairo-lang-starknet-classes.workspace = true
14+
rstest.workspace = true
1315
starknet-types-core.workspace = true
1416
starknet_api.workspace = true
1517
serde_json.workspace = true

crates/mempool_test_utils/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod starknet_api_test_utils;
55

66
pub const TEST_FILES_FOLDER: &str = "crates/mempool_test_utils/test_files";
77
pub const CONTRACT_CLASS_FILE: &str = "contract_class.json";
8+
pub const CASM_CONTRACT_CLASS_FILE: &str = "casm_contract_class.json";
89
pub const COMPILED_CLASS_HASH_OF_CONTRACT_CLASS: &str =
910
"0x01e4f1248860f32c336f93f2595099aaa4959be515e40b75472709ef5243ae17";
1011
pub const FAULTY_ACCOUNT_CLASS_FILE: &str = "faulty_account.sierra.json";

crates/mempool_test_utils/src/starknet_api_test_utils.rs

+27-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use std::rc::Rc;
88
use assert_matches::assert_matches;
99
use blockifier::test_utils::contracts::FeatureContract;
1010
use blockifier::test_utils::{create_trivial_calldata, CairoVersion, NonceManager};
11+
use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass;
12+
use rstest::fixture;
1113
use serde_json::to_string_pretty;
1214
use starknet_api::core::{
1315
calculate_contract_address, ClassHash, CompiledClassHash, ContractAddress, Nonce,
@@ -26,7 +28,8 @@ use starknet_types_core::felt::Felt;
2628

2729
use crate::{
2830
declare_tx_args, deploy_account_tx_args, get_absolute_path, invoke_tx_args,
29-
COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE, TEST_FILES_FOLDER,
31+
CASM_CONTRACT_CLASS_FILE, COMPILED_CLASS_HASH_OF_CONTRACT_CLASS, CONTRACT_CLASS_FILE,
32+
TEST_FILES_FOLDER,
3033
};
3134

3235
pub const VALID_L1_GAS_MAX_AMOUNT: u64 = 203484;
@@ -90,11 +93,31 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping {
9093
)
9194
}
9295

93-
pub fn declare_tx() -> RPCTransaction {
96+
/// Get the contract class used for testing.
97+
#[fixture]
98+
pub fn contract_class() -> ContractClass {
9499
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
95100
let json_file_path = Path::new(CONTRACT_CLASS_FILE);
96-
let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap();
97-
let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS));
101+
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
102+
}
103+
104+
/// Get the casm contract class corresponding to the contract class used for testing.
105+
#[fixture]
106+
pub fn casm_contract_class() -> CasmContractClass {
107+
env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir.");
108+
let json_file_path = Path::new(CASM_CONTRACT_CLASS_FILE);
109+
serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap()
110+
}
111+
112+
/// Get the compiled class hash corresponding to the contract class used for testing.
113+
#[fixture]
114+
pub fn compiled_class_hash() -> CompiledClassHash {
115+
CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS))
116+
}
117+
118+
pub fn declare_tx() -> RPCTransaction {
119+
let contract_class = contract_class();
120+
let compiled_class_hash = compiled_class_hash();
98121

99122
let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1);
100123
let account_address = account_contract.get_instance_address(0);

0 commit comments

Comments
 (0)