|
1 | 1 | use std::sync::Arc;
|
2 | 2 |
|
| 3 | +use assert_matches::assert_matches; |
3 | 4 | use axum::body::{Bytes, HttpBody};
|
4 | 5 | use axum::extract::State;
|
5 | 6 | use axum::http::StatusCode;
|
6 | 7 | use axum::response::{IntoResponse, Response};
|
7 | 8 | use blockifier::context::ChainInfo;
|
| 9 | +use blockifier::execution::contract_class::ContractClass; |
8 | 10 | use blockifier::test_utils::CairoVersion;
|
| 11 | +use cairo_lang_starknet_classes::allowed_libfuncs::AllowedLibfuncsError; |
9 | 12 | use rstest::{fixture, rstest};
|
10 |
| -use starknet_api::rpc_transaction::RPCTransaction; |
| 13 | +use starknet_api::core::CompiledClassHash; |
| 14 | +use starknet_api::rpc_transaction::{RPCDeclareTransaction, RPCTransaction}; |
11 | 15 | use starknet_api::transaction::TransactionHash;
|
12 | 16 | use starknet_mempool::communication::create_mempool_server;
|
13 | 17 | use starknet_mempool::mempool::Mempool;
|
14 | 18 | use starknet_mempool_types::communication::{MempoolClientImpl, MempoolRequestAndResponseSender};
|
| 19 | +use starknet_sierra_compile::compile::CompilationUtilError; |
15 | 20 | use tokio::sync::mpsc::channel;
|
16 | 21 | use tokio::task;
|
17 | 22 |
|
18 | 23 | use crate::config::{StatefulTransactionValidatorConfig, StatelessTransactionValidatorConfig};
|
| 24 | +use crate::errors::GatewayError; |
19 | 25 | use crate::gateway::{add_tx, compile_contract_class, AppState, SharedMempoolClient};
|
20 | 26 | use crate::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
|
21 | 27 | use crate::state_reader_test_utils::{
|
@@ -103,6 +109,60 @@ async fn test_add_tx(
|
103 | 109 | assert_eq!(tx_hash, serde_json::from_slice(response_bytes).unwrap());
|
104 | 110 | }
|
105 | 111 |
|
| 112 | +#[test] |
| 113 | +fn test_compile_contract_class_compiled_class_hash_missmatch() { |
| 114 | + let mut tx = assert_matches!( |
| 115 | + declare_tx(), |
| 116 | + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx |
| 117 | + ); |
| 118 | + let expected_hash_result = tx.compiled_class_hash; |
| 119 | + let supplied_hash = CompiledClassHash::default(); |
| 120 | + |
| 121 | + tx.compiled_class_hash = supplied_hash; |
| 122 | + let declare_tx = RPCDeclareTransaction::V3(tx); |
| 123 | + |
| 124 | + let result = compile_contract_class(&declare_tx); |
| 125 | + assert_matches!( |
| 126 | + result.unwrap_err(), |
| 127 | + GatewayError::CompiledClassHashMismatch { supplied, hash_result } |
| 128 | + if supplied == supplied_hash && hash_result == expected_hash_result |
| 129 | + ); |
| 130 | +} |
| 131 | + |
| 132 | +#[test] |
| 133 | +fn test_compile_contract_class_bad_sierra() { |
| 134 | + let mut tx = assert_matches!( |
| 135 | + declare_tx(), |
| 136 | + RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx |
| 137 | + ); |
| 138 | + // Truncate the sierra program to trigger an error. |
| 139 | + tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); |
| 140 | + let declare_tx = RPCDeclareTransaction::V3(tx); |
| 141 | + |
| 142 | + let result = compile_contract_class(&declare_tx); |
| 143 | + assert_matches!( |
| 144 | + result.unwrap_err(), |
| 145 | + GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( |
| 146 | + AllowedLibfuncsError::SierraProgramError |
| 147 | + )) |
| 148 | + ) |
| 149 | +} |
| 150 | + |
| 151 | +#[test] |
| 152 | +fn test_compile_contract_class() { |
| 153 | + let declare_tx = assert_matches!( |
| 154 | + declare_tx(), |
| 155 | + RPCTransaction::Declare(declare_tx) => declare_tx |
| 156 | + ); |
| 157 | + let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; |
| 158 | + let contract_class = &declare_tx_v3.contract_class; |
| 159 | + |
| 160 | + let class_info = compile_contract_class(&declare_tx).unwrap(); |
| 161 | + assert_matches!(class_info.contract_class(), ContractClass::V1(_)); |
| 162 | + assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); |
| 163 | + assert_eq!(class_info.abi_length(), contract_class.abi.len()); |
| 164 | +} |
| 165 | + |
106 | 166 | async fn to_bytes(res: Response) -> Bytes {
|
107 | 167 | res.into_body().collect().await.unwrap().to_bytes()
|
108 | 168 | }
|
|
0 commit comments