diff --git a/Cargo.lock b/Cargo.lock index a655110..b82b344 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -362,8 +362,10 @@ dependencies = [ "rustls", "serde", "serde_json", + "sgx-ocalls", "thiserror", "tiny-keccak", + "tls-enclave", "tracing", "tracing-subscriber", "webpki-roots", @@ -1146,6 +1148,13 @@ dependencies = [ "bindgen", ] +[[package]] +name = "sgx-ocalls" +version = "0.1.0" +dependencies = [ + "tracing", +] + [[package]] name = "sgx_build_helper" version = "2.0.0" @@ -1432,6 +1441,17 @@ dependencies = [ "zerovec", ] +[[package]] +name = "tls-enclave" +version = "0.1.0" +dependencies = [ + "rustls", + "serde_json", + "thiserror", + "tracing", + "webpki-roots", +] + [[package]] name = "tokio" version = "1.44.1" diff --git a/Cargo.toml b/Cargo.toml index 158deb1..f41d29a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [workspace] resolver = "2" -members = ["bin/zktls-pairs", "crates/enclave", "crates/untrusted-host"] +members = ["bin/zktls-pairs", "crates/enclave", "crates/sgx-ocalls", "crates/tls-enclave", "crates/untrusted-host"] [workspace.package] version = "0.1.0" @@ -26,6 +26,8 @@ thiserror = "2.0.12" clap = "3.2" untrusted-host = { path = "crates/untrusted-host" } +tls-enclave = { path = "crates/tls-enclave" } +sgx-ocalls = { path = "crates/sgx-ocalls" } [patch.crates-io] ring = { git = "https://github.com/automata-network/ring-sgx", rev = "e9b37b8f5a7c3331b21a6650f1ce6653d70d0923" } diff --git a/crates/enclave/Cargo.toml b/crates/enclave/Cargo.toml index 3bb5b7f..5379ca2 100644 --- a/crates/enclave/Cargo.toml +++ b/crates/enclave/Cargo.toml @@ -19,6 +19,7 @@ hex.workspace = true rustls.workspace = true webpki-roots.workspace = true thiserror.workspace = true +tls-enclave.workspace = true tracing.workspace = true tracing-subscriber = { workspace = true, features = ["env-filter"]} @@ -26,3 +27,5 @@ tracing-subscriber = { workspace = true, features = ["env-filter"]} serde = { workspace = true, features = ["derive"]} tiny-keccak = { workspace = true, features = ["sha3", "keccak"]} clap = { workspace = true, features = ["derive"] } + +sgx-ocalls.workspace = true diff --git a/crates/enclave/src/lib.rs b/crates/enclave/src/lib.rs index c736608..4605bee 100644 --- a/crates/enclave/src/lib.rs +++ b/crates/enclave/src/lib.rs @@ -2,7 +2,6 @@ extern crate core; mod error; mod parser; -mod tcp_stream_oc; mod tls; use std::{ffi::CString, fmt::Debug, string::String}; @@ -11,41 +10,18 @@ use automata_sgx_sdk::types::SgxStatus; use clap::Parser; use ethabi::{Token, Uint}; use serde_json::json; +use sgx_ocalls::bindings::*; use tiny_keccak::{Hasher, Keccak}; +use tls_enclave::tls_request; -use crate::{parser::get_filtered_items, tcp_stream_oc::UntrustedTcpStreamPtr, tls::tls_request}; - -extern "C" { - fn ocall_get_tcp_stream(server_address: *const u8, stream_ptr: *mut UntrustedTcpStreamPtr); - fn ocall_tcp_write(stream_ptr: UntrustedTcpStreamPtr, data: *const u8, data_len: usize); - fn ocall_tcp_read( - stream_ptr: UntrustedTcpStreamPtr, - buffer: *mut u8, - max_len: usize, - read_len: *mut usize, - ); - - fn ocall_write_to_file( - data_bytes: *const u8, - data_len: usize, - filename_bytes: *const u8, - filename_len: usize, - ); - - fn ocall_read_from_file( - filename_bytes: *const u8, - pairs_list_buffer: *mut u8, - pairs_list_buffer_len: usize, - pairs_list_actual_len: *mut usize, - ); -} +use crate::parser::get_filtered_items; pub(crate) const BINANCE_API_HOST: &str = "data-api.binance.vision"; pub(crate) const HARDCODED_DECIMALS: u32 = 8; #[derive(Parser)] #[clap(author = "Diffuse", version = "v0", about)] -struct ZkTlsPairs { +struct ZkTlsPairsCli { /// Path to the file with pairs #[clap(long, default_value = "pairs/list.txt")] pairs_file_path: String, @@ -53,7 +29,7 @@ struct ZkTlsPairs { #[no_mangle] pub unsafe extern "C" fn trusted_execution() -> SgxStatus { - let cli = ZkTlsPairs::parse(); + let cli = ZkTlsPairsCli::parse(); let env_filter = tracing_subscriber::EnvFilter::try_from_default_env() .unwrap_or_else(|_| tracing_subscriber::EnvFilter::new("info")); @@ -88,7 +64,9 @@ pub unsafe extern "C" fn trusted_execution() -> SgxStatus { let currency_pairs_bytes = json!(currency_pairs).to_string(); - let response_str = match tls_request(BINANCE_API_HOST, currency_pairs_bytes) { + let zk_tls_pairs = tls::ZkTlsPairs::new(BINANCE_API_HOST.to_string(), currency_pairs_bytes); + + let response_str = match tls_request(BINANCE_API_HOST, zk_tls_pairs) { Ok(response) => response, Err(e) => { tracing::error!("Error encountered in TLS request: {e}"); diff --git a/crates/enclave/src/tls.rs b/crates/enclave/src/tls.rs index d107861..5dbb994 100644 --- a/crates/enclave/src/tls.rs +++ b/crates/enclave/src/tls.rs @@ -1,121 +1,69 @@ -use std::{ - ffi::CString, - io::{Read, Write}, - ptr, - sync::Arc, -}; - -use automata_sgx_sdk::types::SgxStatus; -use rustls::{pki_types::ServerName, ClientConnection, RootCertStore, StreamOwned}; +use std::{ffi::CString, fmt::Debug, ptr}; -use crate::{ - error::SgxResult, ocall_get_tcp_stream, tcp_stream_oc::TcpStreamOc, UntrustedTcpStreamPtr, +use sgx_ocalls::{ + bindings::{ocall_get_tcp_stream, UntrustedTcpStreamPtr}, + tcp_stream::TcpStreamOc, +}; +use tls_enclave::{ + error::TlsResult, + traits::{RequestProvider, TcpProvider}, }; -pub(crate) fn tls_request, T: AsRef>( - server_host_name: S, - symbols: T, -) -> SgxResult { - let (mut connection, mut tcp_stream) = open_connection(server_host_name.as_ref())?; - - tracing::info!("Handshaking with {}", server_host_name.as_ref()); - handshake(&mut connection, &mut tcp_stream)?; - tracing::info!("Is handshake done : {:?}", !connection.is_handshaking()); - tracing::info!("TLS version: {:?}", connection.protocol_version()); - - let symbols_request_bytes = generate_request(symbols, server_host_name); - tracing::debug!( - "Sending request: {:?}", - String::from_utf8_lossy(&symbols_request_bytes) - ); - - let resp = request_symbols(connection, tcp_stream, &symbols_request_bytes); - tracing::info!("Is response fine : {}", resp.is_ok()); - resp +#[derive(Debug)] +pub(crate) struct ZkTlsPairs { + pub(crate) server_address: String, + pub(crate) requested_symbols: String, + pub(crate) stream_ptr: TcpStreamOc, } -/// We are using the default rustls configuration, including patched version of ['_ring_'] crate. -fn open_connection>( - server_host_name: S, -) -> SgxResult<(ClientConnection, TcpStreamOc)> { - let server_name = ServerName::try_from(server_host_name.as_ref().to_string())?; - let address_cstr = CString::new(server_host_name.as_ref().to_string() + ":443")?; - let mut stream_ptr: UntrustedTcpStreamPtr = ptr::null_mut(); - - // TODO: remove this call, create TcpStream directly - unsafe { - ocall_get_tcp_stream( - address_cstr.as_ptr() as *const u8, - &mut stream_ptr as *mut UntrustedTcpStreamPtr, - ); - } - if stream_ptr.is_null() { - tracing::error!("Failed to get tcp stream"); - return Err(SgxStatus::Unexpected.into()); - } - let tcp_stream_oc = TcpStreamOc::new(stream_ptr); - - let root_store = RootCertStore { - roots: webpki_roots::TLS_SERVER_ROOTS.into(), - }; - let config = rustls::ClientConfig::builder() - .with_root_certificates(root_store) - .with_no_client_auth(); - - let connection = ClientConnection::new(Arc::new(config), server_name)?; - - Ok((connection, tcp_stream_oc)) -} +impl ZkTlsPairs { + pub fn new(server_address: String, requested_symbols: String) -> Self { + let address_cstr = + CString::new(format!("{server_address}:443")).expect("Failed to create CString"); + let mut stream_ptr: UntrustedTcpStreamPtr = ptr::null_mut(); + + unsafe { + ocall_get_tcp_stream( + address_cstr.as_ptr() as *const u8, + &mut stream_ptr as *mut UntrustedTcpStreamPtr, + ); + } -fn handshake(connection: &mut ClientConnection, tcp_stream: &mut TcpStreamOc) -> SgxResult<()> { - while connection.is_handshaking() { - tracing::debug!("Handshake in progress..."); - if connection.wants_write() { - tracing::debug!("Handshake in progress... write"); - let mut buf = Vec::new(); - connection.write_tls(&mut buf)?; - tcp_stream.write_all(&buf)?; + if stream_ptr.is_null() { + panic!("Failed to create TCP stream"); } - if connection.wants_read() { - tracing::debug!("Handshake in progress... read"); - let mut buf = vec![0u8; 4096]; - let bytes_read = tcp_stream.read(&mut buf)?; - if bytes_read == 0 { - break; - } - buf.truncate(bytes_read); - connection.read_tls(&mut &buf[..])?; + + ZkTlsPairs { + server_address, + requested_symbols, + stream_ptr: TcpStreamOc::new(stream_ptr), } - connection.process_new_packets()?; } - - Ok(()) } -fn request_symbols( - connection: ClientConnection, - tcp_stream: TcpStreamOc, - request: &[u8], -) -> SgxResult { - let mut tls = StreamOwned::new(connection, tcp_stream); - - tls.write_all(request)?; - tls.flush()?; - - let mut resp_vec = Vec::new(); - tls.read_to_end(&mut resp_vec)?; - - Ok(String::from_utf8(resp_vec)?) +impl> RequestProvider for ZkTlsPairs { + fn get_request(&self, server_address: S) -> Vec { + format!( + "GET /api/v3/ticker/24hr?symbols={} HTTP/1.1\r\n\ + Host: {}\r\n\ + Accept: application/json\r\n\ + Connection: close\r\n\r\n", + self.requested_symbols, + server_address.as_ref() + ) + .into_bytes() + } } -fn generate_request, T: AsRef>(symbols: S, server_host_name: T) -> Vec { - format!( - "GET /api/v3/ticker/24hr?symbols={} HTTP/1.1\r\n\ - Host: {}\r\n\ - Accept: application/json\r\n\ - Connection: close\r\n\r\n", - symbols.as_ref(), - server_host_name.as_ref() - ) - .into_bytes() +impl> TcpProvider for ZkTlsPairs { + type Stream = TcpStreamOc; + + fn get(&mut self, server_address: S) -> TlsResult { + assert_eq!( + self.server_address, + server_address.as_ref(), + "Server address mismatch" + ); + Ok(std::mem::take(&mut self.stream_ptr)) + } } diff --git a/crates/sgx-ocalls/Cargo.toml b/crates/sgx-ocalls/Cargo.toml new file mode 100644 index 0000000..681a57f --- /dev/null +++ b/crates/sgx-ocalls/Cargo.toml @@ -0,0 +1,9 @@ +[package] +name = "sgx-ocalls" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +tracing.workspace = true diff --git a/crates/sgx-ocalls/src/bindings.rs b/crates/sgx-ocalls/src/bindings.rs new file mode 100644 index 0000000..63381b1 --- /dev/null +++ b/crates/sgx-ocalls/src/bindings.rs @@ -0,0 +1,30 @@ +use core::ffi::c_void; + +pub type UntrustedTcpStreamPtr = *mut c_void; + +extern "C" { + pub fn ocall_get_tcp_stream(server_address: *const u8, stream_ptr: *mut UntrustedTcpStreamPtr); + + pub fn ocall_tcp_write(stream_ptr: UntrustedTcpStreamPtr, data: *const u8, data_len: usize); + + pub fn ocall_tcp_read( + stream_ptr: UntrustedTcpStreamPtr, + buffer: *mut u8, + max_len: usize, + read_len: *mut usize, + ); + + pub fn ocall_write_to_file( + data_bytes: *const u8, + data_len: usize, + filename_bytes: *const u8, + filename_len: usize, + ); + + pub fn ocall_read_from_file( + filename_bytes: *const u8, + pairs_list_buffer: *mut u8, + pairs_list_buffer_len: usize, + pairs_list_actual_len: *mut usize, + ); +} diff --git a/crates/sgx-ocalls/src/lib.rs b/crates/sgx-ocalls/src/lib.rs new file mode 100644 index 0000000..05ff876 --- /dev/null +++ b/crates/sgx-ocalls/src/lib.rs @@ -0,0 +1,2 @@ +pub mod bindings; +pub mod tcp_stream; diff --git a/crates/enclave/src/tcp_stream_oc.rs b/crates/sgx-ocalls/src/tcp_stream.rs similarity index 74% rename from crates/enclave/src/tcp_stream_oc.rs rename to crates/sgx-ocalls/src/tcp_stream.rs index bd3b6c0..b1b1dbc 100644 --- a/crates/enclave/src/tcp_stream_oc.rs +++ b/crates/sgx-ocalls/src/tcp_stream.rs @@ -3,10 +3,9 @@ use std::{ io::{Read, Write}, }; -use crate::{ocall_tcp_read, ocall_tcp_write}; - -pub(crate) type UntrustedTcpStreamPtr = *mut core::ffi::c_void; +use crate::bindings::{ocall_tcp_read, ocall_tcp_write, UntrustedTcpStreamPtr}; +/// A safe wrapper for an untrusted TCP stream obtained via OCALLs. pub struct TcpStreamOc { stream_ptr: UntrustedTcpStreamPtr, } @@ -17,6 +16,14 @@ impl TcpStreamOc { } } +impl Default for TcpStreamOc { + fn default() -> Self { + TcpStreamOc { + stream_ptr: core::ptr::null_mut(), + } + } +} + impl Read for TcpStreamOc { fn read(&mut self, buf: &mut [u8]) -> std::io::Result { let mut read_len: usize = 0; @@ -31,6 +38,7 @@ impl Read for TcpStreamOc { Ok(read_len) } } + impl Write for TcpStreamOc { fn write(&mut self, buf: &[u8]) -> std::io::Result { unsafe { @@ -45,7 +53,7 @@ impl Write for TcpStreamOc { impl Debug for TcpStreamOc { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "TcpStreamOc") + write!(f, "TcpStreamOc {{ stream_ptr: {:?} }}", self.stream_ptr) } } diff --git a/crates/tls-enclave/Cargo.toml b/crates/tls-enclave/Cargo.toml new file mode 100644 index 0000000..07b7d15 --- /dev/null +++ b/crates/tls-enclave/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "tls-enclave" +version.workspace = true +edition.workspace = true +authors.workspace = true +homepage.workspace = true + +[dependencies] +rustls.workspace = true +tracing.workspace = true +webpki-roots.workspace = true +thiserror.workspace = true +serde_json.workspace = true diff --git a/crates/tls-enclave/src/error.rs b/crates/tls-enclave/src/error.rs new file mode 100644 index 0000000..1110149 --- /dev/null +++ b/crates/tls-enclave/src/error.rs @@ -0,0 +1,24 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum TlsError { + #[error(transparent)] + Io(#[from] std::io::Error), + + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), + + #[error(transparent)] + Tls(#[from] rustls::Error), + + #[error(transparent)] + Ffi(#[from] std::ffi::NulError), + + #[error(transparent)] + DnsName(#[from] rustls::pki_types::InvalidDnsNameError), + + #[error(transparent)] + SerdeJson(#[from] serde_json::Error), +} + +pub type TlsResult = Result; diff --git a/crates/tls-enclave/src/lib.rs b/crates/tls-enclave/src/lib.rs new file mode 100644 index 0000000..a15cd9e --- /dev/null +++ b/crates/tls-enclave/src/lib.rs @@ -0,0 +1,106 @@ +pub mod error; +pub mod traits; + +use std::{ + io::{Read, Write}, + sync::Arc, +}; + +use rustls::{pki_types::ServerName, ClientConnection, RootCertStore, StreamOwned}; + +use crate::{ + error::TlsResult, + traits::{RequestProvider, TcpProvider}, +}; + +pub fn tls_request + Clone, P>( + server_host_name: S, + mut provider: P, +) -> TlsResult +where + P: RequestProvider + TcpProvider, +{ + tracing::info!("Handshaking with {}", server_host_name.as_ref()); + let (mut connection, mut tcp_stream) = + open_connection(server_host_name.clone(), &mut provider)?; + + handshake(&mut connection, &mut tcp_stream)?; + tracing::info!("Is handshake done : {:?}", !connection.is_handshaking()); + tracing::info!("TLS version: {:?}", connection.protocol_version()); + + let symbols_request_bytes = provider.get_request(server_host_name); + + tracing::debug!( + "Sending request: {:?}", + String::from_utf8_lossy(&symbols_request_bytes) + ); + + let resp = request_symbols::(connection, tcp_stream, &symbols_request_bytes); + tracing::info!("Is response fine : {}", resp.is_ok()); + resp +} + +/// We are using the default rustls configuration, including patched version of ['_ring_'] crate. +fn open_connection, P>( + server_host_name: S, + tcp_provider: &mut P, +) -> TlsResult<(ClientConnection, P::Stream)> +where + P: TcpProvider, +{ + let server_name = ServerName::try_from(server_host_name.as_ref().to_string())?; + let tcp_stream = tcp_provider.get(server_host_name)?; + let root_store = RootCertStore { + roots: webpki_roots::TLS_SERVER_ROOTS.into(), + }; + let config = rustls::ClientConfig::builder() + .with_root_certificates(root_store) + .with_no_client_auth(); + let connection = ClientConnection::new(Arc::new(config), server_name)?; + + Ok((connection, tcp_stream)) +} + +fn handshake( + connection: &mut ClientConnection, + tcp_stream: &mut (impl Read + Write), +) -> TlsResult<()> { + while connection.is_handshaking() { + if connection.wants_write() { + let mut buf = Vec::new(); + connection.write_tls(&mut buf)?; + tcp_stream.write_all(&buf)?; + } + if connection.wants_read() { + let mut buf = vec![0u8; 4096]; + let bytes_read = tcp_stream.read(&mut buf)?; + if bytes_read == 0 { + break; + } + buf.truncate(bytes_read); + connection.read_tls(&mut &buf[..])?; + } + connection.process_new_packets()?; + } + Ok(()) +} + +fn request_symbols, T: TcpProvider>( + connection: ClientConnection, + tcp_stream: T::Stream, + request: &[u8], +) -> TlsResult { + let mut tls = StreamOwned::new(connection, tcp_stream); + + tls.write_all(request)?; + tls.flush()?; + + let mut resp_vec = Vec::new(); + match tls.read_to_end(&mut resp_vec) { + Ok(_) => {} + Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof && !resp_vec.is_empty() => {} + Err(e) => return Err(e.into()), + } + + Ok(String::from_utf8(resp_vec)?) +} diff --git a/crates/tls-enclave/src/traits.rs b/crates/tls-enclave/src/traits.rs new file mode 100644 index 0000000..3894386 --- /dev/null +++ b/crates/tls-enclave/src/traits.rs @@ -0,0 +1,20 @@ +use std::{ + fmt::Debug, + io::{Read, Write}, +}; + +use crate::error::TlsResult; + +pub trait TcpProvider>: Debug { + type Stream: Read + Write; + fn get(&mut self, server_address: S) -> TlsResult; +} + +pub trait FileProvider>: Debug { + fn write_to_file(&self, data: &[u8], filename: S); + fn read_from_file(&self, filename: S) -> Vec; +} + +pub trait RequestProvider>: Debug { + fn get_request(&self, server_address: S) -> Vec; +}