diff --git a/Cargo.toml b/Cargo.toml index 252e2eeda50..97999dc4b49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ exclude = [ "esp-metadata", "esp-println", "esp-riscv-rt", + "esp-rustls-provider", "esp-wifi", "esp-storage", "examples", diff --git a/esp-rustls-provider/Cargo.toml b/esp-rustls-provider/Cargo.toml new file mode 100644 index 00000000000..6ec038c3b8a --- /dev/null +++ b/esp-rustls-provider/Cargo.toml @@ -0,0 +1,40 @@ +[package] +name = "esp-rustls-provider" +version = "0.1.0" +edition = "2021" + +[dependencies] +chacha20poly1305 = { version = "0.10", default-features = false, features = [ + "alloc", +] } +der = "0.7" +ecdsa = "0.16.8" +hmac = "0.12" +p256 = { version = "0.13.2", default-features = false, features = [ + "alloc", + "ecdsa", + "pkcs8", +] } +pkcs8 = "0.10.2" +pki-types = { package = "rustls-pki-types", version = "1" } +rand_core = { version = "0.6", default-features = false } + +rustls = { version = "0.23.21", default-features = false, features = ["tls12", "custom-provider"] } +rsa = { version = "0.9", features = ["sha2"], default-features = false } +sha2 = { version = "0.10", default-features = false } +signature = "2" +webpki = { package = "rustls-webpki", version = "0.102", features = [ + "alloc", +], default-features = false } +x25519-dalek = "2" + +# should this be a feature? - not really needed by this crate but re-exported for convenience +webpki-roots = "0.26.1" + +esp-hal = { path = "../esp-hal", version = "0.23.1", default-features = false } +embedded-io = { version = "0.6.1", default-features = false } +log = "0.4.20" + +[features] +log = ["rustls/logging"] +defmt = [] diff --git a/esp-rustls-provider/README.md b/esp-rustls-provider/README.md new file mode 100644 index 00000000000..2c001aa9159 --- /dev/null +++ b/esp-rustls-provider/README.md @@ -0,0 +1,11 @@ +# esp-rustls-provider + +## NO support for targets w/o atomics + +While most dependencies can be used with `portable-atomic` that's unfortunately not true for Rustls itself. It needs `alloc::sync::Arc` in a lot of places. + +This means that ESP32-S2, ESP32-C2 and ESP32-C3 are NOT supported. + +## Status + +This crate is currently experimental/preview. It's not available on crates.io and might be limited in functionality. diff --git a/esp-rustls-provider/src/adapter/client.rs b/esp-rustls-provider/src/adapter/client.rs new file mode 100644 index 00000000000..c65e907a5ca --- /dev/null +++ b/esp-rustls-provider/src/adapter/client.rs @@ -0,0 +1,233 @@ +//! Client connection wrappers + +use super::ConnectionError; + +/// Wrapper for [embedded_io] to be used as a client connection +pub struct ClientConnection<'s, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + socket: S, + conn: rustls::client::UnbufferedClientConnection, + incoming_tls: &'s mut [u8], + outgoing_tls: &'s mut [u8], + incoming_used: usize, + outgoing_used: usize, + + plaintext_in: &'s mut [u8], + plaintext_in_used: usize, + plaintext_out: &'s mut [u8], + plaintext_out_used: usize, +} + +impl<'s, S> ClientConnection<'s, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + pub fn new( + config: alloc::sync::Arc, + server: rustls::pki_types::ServerName<'static>, + socket: S, + incoming_tls: &'s mut [u8], + outgoing_tls: &'s mut [u8], + plaintext_in: &'s mut [u8], + plaintext_out: &'s mut [u8], + ) -> Result)> { + match rustls::client::UnbufferedClientConnection::new(config, server) { + Ok(conn) => Ok(Self { + socket, + conn, + incoming_tls, + outgoing_tls, + incoming_used: 0, + outgoing_used: 0, + + plaintext_in, + plaintext_in_used: 0, + plaintext_out, + plaintext_out_used: 0, + }), + Err(err) => Err((socket, ConnectionError::Rustls(err))), + } + } + + pub fn free(self) -> S { + self.socket + } + + fn work(&mut self) -> Result<(), ConnectionError> { + use rustls::unbuffered::{AppDataRecord, ConnectionState}; + + let mut done = false; + loop { + if done { + debug!("Done work for now"); + break; + } + + debug!( + "Incoming used {}, outgoing used {}, plaintext_in used {}, plaintext_out used {}", + self.incoming_used, + self.outgoing_used, + self.plaintext_in_used, + self.plaintext_out_used + ); + + let rustls::unbuffered::UnbufferedStatus { mut discard, state } = self + .conn + .process_tls_records(&mut self.incoming_tls[..self.incoming_used]); + + debug!("State {:?}", state); + + match state.map_err(ConnectionError::Rustls)? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res.map_err(ConnectionError::Rustls)?; + discard += new_discard; + + self.plaintext_in[self.plaintext_in_used..][..payload.len()] + .copy_from_slice(payload); + self.plaintext_in_used += payload.len(); + + done = true; + } + } + + ConnectionState::EncodeTlsData(mut state) => { + let written = state + .encode(&mut self.outgoing_tls[self.outgoing_used..]) + .map_err(ConnectionError::RustlsEncodeError)?; + self.outgoing_used += written; + } + + ConnectionState::TransmitTlsData(mut state) => { + if let Some(_may_encrypt_early_data) = state.may_encrypt_early_data() { + panic!("Early data unsupported"); + } + + if let Some(_may_encrypt) = state.may_encrypt_app_data() { + debug!("time to sent app data") + } + + debug!("Send tls"); + self.socket + .write_all(&self.outgoing_tls[..self.outgoing_used])?; + self.socket.flush()?; + self.outgoing_used = 0; + state.done(); + } + + ConnectionState::BlockedHandshake { .. } => { + debug!("Receive tls"); + + let read = self + .socket + .read(&mut self.incoming_tls[self.incoming_used..])?; + debug!("Received {read}B of data"); + self.incoming_used += read; + } + + ConnectionState::WriteTraffic(mut may_encrypt) => { + if self.plaintext_out_used != 0 { + let written = may_encrypt + .encrypt( + &self.plaintext_out[..self.plaintext_out_used], + &mut self.outgoing_tls[self.outgoing_used..], + ) + .expect("encrypted request does not fit in `outgoing_tls`"); + self.outgoing_used += written; + self.plaintext_out_used = 0; + + debug!("Send tls"); + self.socket + .write_all(&self.outgoing_tls[..self.outgoing_used])?; + self.socket.flush()?; + self.outgoing_used = 0; + } + done = true; + } + + ConnectionState::Closed => { + // connection has been cleanly closed - these might be still plaintext data to + // be consumed + done = true; + } + + // other states are not expected here + _ => unreachable!(), + } + + if discard != 0 { + assert!(discard <= self.incoming_used); + + self.incoming_tls + .copy_within(discard..self.incoming_used, 0); + self.incoming_used -= discard; + + debug!("Discarded {discard}B from `incoming_tls`"); + } + } + + Ok(()) + } +} + +impl embedded_io::ErrorType for ClientConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + type Error = ConnectionError; +} + +impl embedded_io::Read for ClientConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + fn read(&mut self, buf: &mut [u8]) -> Result { + let tls_read_res = self + .socket + .read(&mut self.incoming_tls[self.incoming_used..]); + + let tls_read = if let Err(err) = tls_read_res { + if self.plaintext_in_used == 0 { + Err(err) + } else { + Ok(0) + } + } else { + tls_read_res + }?; + self.incoming_used += tls_read; + + self.work()?; + + let l = usize::min(buf.len(), self.plaintext_in_used); + buf[0..l].copy_from_slice(&self.plaintext_in[0..l]); + + self.plaintext_in.copy_within(l..self.plaintext_in_used, 0); + self.plaintext_in_used -= l; + + Ok(l) + } +} + +impl embedded_io::Write for ClientConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + fn write(&mut self, buf: &[u8]) -> Result { + self.plaintext_out[self.plaintext_out_used..][..buf.len()].copy_from_slice(buf); + self.plaintext_out_used += buf.len(); + self.work()?; + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<(), Self::Error> { + self.socket.flush()?; + self.work()?; + Ok(()) + } +} diff --git a/esp-rustls-provider/src/adapter/mod.rs b/esp-rustls-provider/src/adapter/mod.rs new file mode 100644 index 00000000000..eae06b77897 --- /dev/null +++ b/esp-rustls-provider/src/adapter/mod.rs @@ -0,0 +1,36 @@ +//! Useful wrappers + +pub mod client; +pub mod server; + +/// Errors returned by the adapters +#[derive(Debug)] +pub enum ConnectionError { + /// Error from embedded-io + Io(E), + /// Error from Rustls + Rustls(rustls::Error), + /// Error from Rustls' `encode`function + RustlsEncodeError(rustls::unbuffered::EncodeError), +} + +impl embedded_io::Error for ConnectionError +where + E: embedded_io::Error, +{ + fn kind(&self) -> embedded_io::ErrorKind { + match self { + ConnectionError::Io(err) => err.kind(), + _ => embedded_io::ErrorKind::Other, + } + } +} + +impl From for ConnectionError +where + E: embedded_io::Error, +{ + fn from(value: E) -> Self { + Self::Io(value) + } +} diff --git a/esp-rustls-provider/src/adapter/server.rs b/esp-rustls-provider/src/adapter/server.rs new file mode 100644 index 00000000000..192c548ad7a --- /dev/null +++ b/esp-rustls-provider/src/adapter/server.rs @@ -0,0 +1,240 @@ +use super::ConnectionError; + +/// Wrapper for [embedded_io] to be used as a server connection +pub struct ServerConnection<'s, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + socket: S, + conn: rustls::server::UnbufferedServerConnection, + incoming_tls: &'s mut [u8], + outgoing_tls: &'s mut [u8], + incoming_used: usize, + outgoing_used: usize, + + plaintext_in: &'s mut [u8], + plaintext_in_used: usize, + plaintext_out: &'s mut [u8], + plaintext_out_used: usize, +} + +impl<'s, S> ServerConnection<'s, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + pub fn new( + config: alloc::sync::Arc, + socket: S, + incoming_tls: &'s mut [u8], + outgoing_tls: &'s mut [u8], + plaintext_in: &'s mut [u8], + plaintext_out: &'s mut [u8], + ) -> Result)> { + match rustls::server::UnbufferedServerConnection::new(config) { + Ok(conn) => { + let mut this = Self { + socket, + conn, + incoming_tls, + outgoing_tls, + incoming_used: 0, + outgoing_used: 0, + + plaintext_in, + plaintext_in_used: 0, + plaintext_out, + plaintext_out_used: 0, + }; + + match this.work() { + Ok(_) => Ok(this), + Err(err) => Err((this.socket, err)), + } + } + Err(err) => Err((socket, ConnectionError::Rustls(err))), + } + } + + pub fn free(self) -> S { + self.socket + } + + fn work(&mut self) -> Result<(), ConnectionError> { + use rustls::unbuffered::{AppDataRecord, ConnectionState}; + + let mut done = false; + loop { + if done { + debug!("Done work for now"); + break; + } + + debug!( + "Incoming used {}, outgoing used {}, plaintext_in used {}, plaintext_out used {}", + self.incoming_used, + self.outgoing_used, + self.plaintext_in_used, + self.plaintext_out_used + ); + + let rustls::unbuffered::UnbufferedStatus { mut discard, state } = self + .conn + .process_tls_records(&mut self.incoming_tls[..self.incoming_used]); + + debug!("State {:?}", state); + + match state.map_err(ConnectionError::Rustls)? { + ConnectionState::ReadTraffic(mut state) => { + while let Some(res) = state.next_record() { + let AppDataRecord { + discard: new_discard, + payload, + } = res.map_err(ConnectionError::Rustls)?; + discard += new_discard; + + self.plaintext_in[self.plaintext_in_used..][..payload.len()] + .copy_from_slice(payload); + self.plaintext_in_used += payload.len(); + + done = true; + } + } + + ConnectionState::ReadEarlyData(_state) => { + panic!("Unsupported early-data"); + } + + ConnectionState::EncodeTlsData(mut state) => { + let written = state + .encode(&mut self.outgoing_tls[self.outgoing_used..]) + .map_err(ConnectionError::RustlsEncodeError)?; + self.outgoing_used += written; + } + + ConnectionState::TransmitTlsData(state) => { + debug!("Send tls"); + self.socket + .write_all(&self.outgoing_tls[..self.outgoing_used])?; + self.socket.flush()?; + self.outgoing_used = 0; + state.done(); + } + + ConnectionState::BlockedHandshake { .. } => { + debug!("Receive tls"); + + let read = self + .socket + .read(&mut self.incoming_tls[self.incoming_used..])?; + debug!("Received {read}B of data"); + self.incoming_used += read; + + if read == 0 { + return Ok(()); + } + } + + ConnectionState::WriteTraffic(mut may_encrypt) => { + if self.plaintext_out_used != 0 { + let written = may_encrypt + .encrypt( + &self.plaintext_out[..self.plaintext_out_used], + &mut self.outgoing_tls[self.outgoing_used..], + ) + .expect("encrypted request does not fit in `outgoing_tls`"); + self.outgoing_used += written; + self.plaintext_out_used = 0; + + debug!("Send tls"); + self.socket + .write_all(&self.outgoing_tls[..self.outgoing_used])?; + self.socket.flush()?; + self.outgoing_used = 0; + } + done = true; + } + + ConnectionState::Closed => { + // connection has been cleanly closed - these might be still plaintext data to + // be consumed + done = true; + } + + // other states are not expected here + _ => unreachable!(), + } + + if discard != 0 { + assert!(discard <= self.incoming_used); + + self.incoming_tls + .copy_within(discard..self.incoming_used, 0); + self.incoming_used -= discard; + + debug!("Discarded {discard}B from `incoming_tls`"); + } + } + + Ok(()) + } +} + +impl embedded_io::ErrorType for ServerConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + type Error = ConnectionError; +} + +impl embedded_io::Read for ServerConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write + embedded_io::ReadReady, +{ + fn read(&mut self, buf: &mut [u8]) -> Result { + let tls_read_res = if let Ok(true) = self.socket.read_ready() { + self.socket + .read(&mut self.incoming_tls[self.incoming_used..]) + } else { + Ok(0) + }; + + let tls_read = if let Err(err) = tls_read_res { + if self.plaintext_in_used == 0 { + Err(err) + } else { + Ok(0) + } + } else { + tls_read_res + }?; + self.incoming_used += tls_read; + + self.work()?; + + let l = usize::min(buf.len(), self.plaintext_in_used); + buf[0..l].copy_from_slice(&self.plaintext_in[0..l]); + + self.plaintext_in.copy_within(l..self.plaintext_in_used, 0); + self.plaintext_in_used -= l; + + Ok(l) + } +} + +impl embedded_io::Write for ServerConnection<'_, S> +where + S: embedded_io::Read + embedded_io::Write, +{ + fn write(&mut self, buf: &[u8]) -> Result { + self.plaintext_out[self.plaintext_out_used..][..buf.len()].copy_from_slice(buf); + self.plaintext_out_used += buf.len(); + self.work()?; + Ok(buf.len()) + } + + fn flush(&mut self) -> Result<(), Self::Error> { + self.socket.flush()?; + self.work()?; + Ok(()) + } +} diff --git a/esp-rustls-provider/src/aead.rs b/esp-rustls-provider/src/aead.rs new file mode 100644 index 00000000000..8e4263a4b41 --- /dev/null +++ b/esp-rustls-provider/src/aead.rs @@ -0,0 +1,249 @@ +use alloc::boxed::Box; + +use chacha20poly1305::{aead::Buffer, AeadInPlace, KeyInit, KeySizeUser}; +use rustls::{ + crypto::cipher::{ + make_tls12_aad, + make_tls13_aad, + AeadKey, + BorrowedPayload, + InboundOpaqueMessage, + InboundPlainMessage, + Iv, + KeyBlockShape, + MessageDecrypter, + MessageEncrypter, + Nonce, + OutboundOpaqueMessage, + OutboundPlainMessage, + PrefixedPayload, + Tls12AeadAlgorithm, + Tls13AeadAlgorithm, + UnsupportedOperationError, + NONCE_LEN, + }, + ConnectionTrafficSecrets, + ContentType, + ProtocolVersion, +}; + +pub struct Chacha20Poly1305; + +impl Tls13AeadAlgorithm for Chacha20Poly1305 { + fn encrypter(&self, key: AeadKey, iv: Iv) -> Box { + Box::new(Tls13Cipher( + chacha20poly1305::ChaCha20Poly1305::new_from_slice(key.as_ref()).unwrap(), + iv, + )) + } + + fn decrypter(&self, key: AeadKey, iv: Iv) -> Box { + Box::new(Tls13Cipher( + chacha20poly1305::ChaCha20Poly1305::new_from_slice(key.as_ref()).unwrap(), + iv, + )) + } + + fn key_len(&self) -> usize { + chacha20poly1305::ChaCha20Poly1305::key_size() + } + + fn extract_keys( + &self, + key: AeadKey, + iv: Iv, + ) -> Result { + Ok(ConnectionTrafficSecrets::Chacha20Poly1305 { key, iv }) + } +} + +impl Tls12AeadAlgorithm for Chacha20Poly1305 { + fn encrypter(&self, key: AeadKey, iv: &[u8], _: &[u8]) -> Box { + Box::new(Tls12Cipher( + chacha20poly1305::ChaCha20Poly1305::new_from_slice(key.as_ref()).unwrap(), + Iv::copy(iv), + )) + } + + fn decrypter(&self, key: AeadKey, iv: &[u8]) -> Box { + Box::new(Tls12Cipher( + chacha20poly1305::ChaCha20Poly1305::new_from_slice(key.as_ref()).unwrap(), + Iv::copy(iv), + )) + } + + fn key_block_shape(&self) -> KeyBlockShape { + KeyBlockShape { + enc_key_len: 32, + fixed_iv_len: 12, + explicit_nonce_len: 0, + } + } + + fn extract_keys( + &self, + key: AeadKey, + iv: &[u8], + _explicit: &[u8], + ) -> Result { + // This should always be true because KeyBlockShape and the Iv nonce len are in + // agreement. + debug_assert_eq!(NONCE_LEN, iv.len()); + Ok(ConnectionTrafficSecrets::Chacha20Poly1305 { + key, + iv: Iv::new(iv[..].try_into().unwrap()), + }) + } +} + +struct Tls13Cipher(chacha20poly1305::ChaCha20Poly1305, Iv); + +impl MessageEncrypter for Tls13Cipher { + fn encrypt( + &mut self, + m: OutboundPlainMessage, + seq: u64, + ) -> Result { + let total_len = self.encrypted_payload_len(m.payload.len()); + let mut payload = PrefixedPayload::with_capacity(total_len); + + payload.extend_from_chunks(&m.payload); + payload.extend_from_slice(&m.typ.to_array()); + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_tls13_aad(total_len); + + self.0 + .encrypt_in_place(&nonce, &aad, &mut EncryptBufferAdapter(&mut payload)) + .map_err(|_| rustls::Error::EncryptError) + .map(|_| { + OutboundOpaqueMessage::new( + ContentType::ApplicationData, + ProtocolVersion::TLSv1_2, + payload, + ) + }) + } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + 1 + CHACHAPOLY1305_OVERHEAD + } +} + +impl MessageDecrypter for Tls13Cipher { + fn decrypt<'a>( + &mut self, + mut m: InboundOpaqueMessage<'a>, + seq: u64, + ) -> Result, rustls::Error> { + let payload = &mut m.payload; + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_tls13_aad(payload.len()); + + self.0 + .decrypt_in_place(&nonce, &aad, &mut DecryptBufferAdapter(payload)) + .map_err(|_| rustls::Error::DecryptError)?; + + m.into_tls13_unpadded_message() + } +} + +struct Tls12Cipher(chacha20poly1305::ChaCha20Poly1305, Iv); + +impl MessageEncrypter for Tls12Cipher { + fn encrypt( + &mut self, + m: OutboundPlainMessage, + seq: u64, + ) -> Result { + let total_len = self.encrypted_payload_len(m.payload.len()); + let mut payload = PrefixedPayload::with_capacity(total_len); + + payload.extend_from_chunks(&m.payload); + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_tls12_aad(seq, m.typ, m.version, m.payload.len()); + + self.0 + .encrypt_in_place(&nonce, &aad, &mut EncryptBufferAdapter(&mut payload)) + .map_err(|_| rustls::Error::EncryptError) + .map(|_| OutboundOpaqueMessage::new(m.typ, m.version, payload)) + } + + fn encrypted_payload_len(&self, payload_len: usize) -> usize { + payload_len + CHACHAPOLY1305_OVERHEAD + } +} + +impl MessageDecrypter for Tls12Cipher { + fn decrypt<'a>( + &mut self, + mut m: InboundOpaqueMessage<'a>, + seq: u64, + ) -> Result, rustls::Error> { + let payload = &m.payload; + let nonce = chacha20poly1305::Nonce::from(Nonce::new(&self.1, seq).0); + let aad = make_tls12_aad( + seq, + m.typ, + m.version, + payload.len() - CHACHAPOLY1305_OVERHEAD, + ); + + let payload = &mut m.payload; + self.0 + .decrypt_in_place(&nonce, &aad, &mut DecryptBufferAdapter(payload)) + .map_err(|_| rustls::Error::DecryptError)?; + + Ok(m.into_plain_message()) + } +} + +const CHACHAPOLY1305_OVERHEAD: usize = 16; + +struct EncryptBufferAdapter<'a>(&'a mut PrefixedPayload); + +impl AsRef<[u8]> for EncryptBufferAdapter<'_> { + fn as_ref(&self) -> &[u8] { + self.0.as_ref() + } +} + +impl AsMut<[u8]> for EncryptBufferAdapter<'_> { + fn as_mut(&mut self) -> &mut [u8] { + self.0.as_mut() + } +} + +impl Buffer for EncryptBufferAdapter<'_> { + fn extend_from_slice(&mut self, other: &[u8]) -> chacha20poly1305::aead::Result<()> { + self.0.extend_from_slice(other); + Ok(()) + } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len) + } +} + +struct DecryptBufferAdapter<'a, 'p>(&'a mut BorrowedPayload<'p>); + +impl AsRef<[u8]> for DecryptBufferAdapter<'_, '_> { + fn as_ref(&self) -> &[u8] { + self.0 + } +} + +impl AsMut<[u8]> for DecryptBufferAdapter<'_, '_> { + fn as_mut(&mut self) -> &mut [u8] { + self.0 + } +} + +impl Buffer for DecryptBufferAdapter<'_, '_> { + fn extend_from_slice(&mut self, _: &[u8]) -> chacha20poly1305::aead::Result<()> { + unreachable!("not used by `AeadInPlace::decrypt_in_place`") + } + + fn truncate(&mut self, len: usize) { + self.0.truncate(len) + } +} diff --git a/esp-rustls-provider/src/fmt.rs b/esp-rustls-provider/src/fmt.rs new file mode 100644 index 00000000000..aa009027405 --- /dev/null +++ b/esp-rustls-provider/src/fmt.rs @@ -0,0 +1,226 @@ +#![macro_use] +#![allow(unused_macros)] + +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("You may not enable both `defmt` and `log` features."); + +macro_rules! assert { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert!($($x)*); + } + }; +} + +macro_rules! assert_eq { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert_eq!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert_eq!($($x)*); + } + }; +} + +macro_rules! assert_ne { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::assert_ne!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::assert_ne!($($x)*); + } + }; +} + +macro_rules! debug_assert { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert!($($x)*); + } + }; +} + +macro_rules! debug_assert_eq { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert_eq!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert_eq!($($x)*); + } + }; +} + +macro_rules! debug_assert_ne { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::debug_assert_ne!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::debug_assert_ne!($($x)*); + } + }; +} + +macro_rules! todo { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::todo!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::todo!($($x)*); + } + }; +} + +macro_rules! unreachable { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::unreachable!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::unreachable!($($x)*); + } + }; +} + +macro_rules! panic { + ($($x:tt)*) => { + { + #[cfg(not(feature = "defmt"))] + ::core::panic!($($x)*); + #[cfg(feature = "defmt")] + ::defmt::panic!($($x)*); + } + }; +} + +macro_rules! trace { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::trace!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::trace!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! debug { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::debug!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::debug!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! info { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::info!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::info!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! warn { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::warn!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::warn!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +macro_rules! error { + ($s:literal $(, $x:expr)* $(,)?) => { + { + #[cfg(feature = "log")] + ::log::error!($s $(, $x)*); + #[cfg(feature = "defmt")] + ::defmt::error!($s $(, $x)*); + #[cfg(not(any(feature = "log", feature="defmt")))] + let _ = ($( & $x ),*); + } + }; +} + +#[cfg(feature = "defmt")] +macro_rules! unwrap { + ($($x:tt)*) => { + ::defmt::unwrap!($($x)*) + }; +} + +#[cfg(not(feature = "defmt"))] +macro_rules! unwrap { + ($arg:expr) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {:?}", ::core::stringify!($arg), e); + } + } + }; + ($arg:expr, $($msg:expr),+ $(,)? ) => { + match $crate::fmt::Try::into_result($arg) { + ::core::result::Result::Ok(t) => t, + ::core::result::Result::Err(e) => { + ::core::panic!("unwrap of `{}` failed: {}: {:?}", ::core::stringify!($arg), ::core::format_args!($($msg,)*), e); + } + } + } +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +pub struct NoneError; + +pub trait Try { + type Ok; + type Error; + #[allow(unused)] + fn into_result(self) -> Result; +} + +impl Try for Option { + type Ok = T; + type Error = NoneError; + + #[inline] + fn into_result(self) -> Result { + self.ok_or(NoneError) + } +} + +impl Try for Result { + type Ok = T; + type Error = E; + + #[inline] + fn into_result(self) -> Self { + self + } +} diff --git a/esp-rustls-provider/src/hash.rs b/esp-rustls-provider/src/hash.rs new file mode 100644 index 00000000000..4f28b0f2e41 --- /dev/null +++ b/esp-rustls-provider/src/hash.rs @@ -0,0 +1,44 @@ +use alloc::boxed::Box; + +use rustls::crypto::hash; +use sha2::Digest; + +pub struct Sha256; + +impl hash::Hash for Sha256 { + fn start(&self) -> Box { + Box::new(Sha256Context(sha2::Sha256::new())) + } + + fn hash(&self, data: &[u8]) -> hash::Output { + hash::Output::new(&sha2::Sha256::digest(data)[..]) + } + + fn algorithm(&self) -> hash::HashAlgorithm { + hash::HashAlgorithm::SHA256 + } + + fn output_len(&self) -> usize { + 32 + } +} + +struct Sha256Context(sha2::Sha256); + +impl hash::Context for Sha256Context { + fn fork_finish(&self) -> hash::Output { + hash::Output::new(&self.0.clone().finalize()[..]) + } + + fn fork(&self) -> Box { + Box::new(Sha256Context(self.0.clone())) + } + + fn finish(self: Box) -> hash::Output { + hash::Output::new(&self.0.finalize()[..]) + } + + fn update(&mut self, data: &[u8]) { + self.0.update(data); + } +} diff --git a/esp-rustls-provider/src/hmac.rs b/esp-rustls-provider/src/hmac.rs new file mode 100644 index 00000000000..763dd71e092 --- /dev/null +++ b/esp-rustls-provider/src/hmac.rs @@ -0,0 +1,35 @@ +use alloc::boxed::Box; + +use hmac::{Hmac, Mac}; +use rustls::crypto; +use sha2::{Digest, Sha256}; + +pub struct Sha256Hmac; + +impl crypto::hmac::Hmac for Sha256Hmac { + fn with_key(&self, key: &[u8]) -> Box { + Box::new(Sha256HmacKey(Hmac::::new_from_slice(key).unwrap())) + } + + fn hash_output_len(&self) -> usize { + Sha256::output_size() + } +} + +struct Sha256HmacKey(Hmac); + +impl crypto::hmac::Key for Sha256HmacKey { + fn sign_concat(&self, first: &[u8], middle: &[&[u8]], last: &[u8]) -> crypto::hmac::Tag { + let mut ctx = self.0.clone(); + ctx.update(first); + for m in middle { + ctx.update(m); + } + ctx.update(last); + crypto::hmac::Tag::new(&ctx.finalize().into_bytes()[..]) + } + + fn tag_len(&self) -> usize { + Sha256::output_size() + } +} diff --git a/esp-rustls-provider/src/kx.rs b/esp-rustls-provider/src/kx.rs new file mode 100644 index 00000000000..edb912805d7 --- /dev/null +++ b/esp-rustls-provider/src/kx.rs @@ -0,0 +1,58 @@ +use alloc::boxed::Box; + +use crypto::SupportedKxGroup; +use rustls::{crypto, ffdhe_groups::FfdheGroup}; + +pub struct KeyExchange { + priv_key: x25519_dalek::EphemeralSecret, + pub_key: x25519_dalek::PublicKey, +} + +impl crypto::ActiveKeyExchange for KeyExchange { + fn complete( + self: Box, + peer: &[u8], + ) -> Result { + let peer_array: [u8; 32] = peer + .try_into() + .map_err(|_| rustls::Error::from(rustls::PeerMisbehaved::InvalidKeyShare))?; + let their_pub = x25519_dalek::PublicKey::from(peer_array); + let shared_secret = self.priv_key.diffie_hellman(&their_pub); + Ok(crypto::SharedSecret::from(&shared_secret.as_bytes()[..])) + } + + fn pub_key(&self) -> &[u8] { + self.pub_key.as_bytes() + } + + fn ffdhe_group(&self) -> Option> { + None + } + + fn group(&self) -> rustls::NamedGroup { + X25519.name() + } +} + +pub const ALL_KX_GROUPS: &[&dyn SupportedKxGroup] = &[&X25519 as &dyn SupportedKxGroup]; + +#[derive(Debug)] +pub struct X25519; + +impl crypto::SupportedKxGroup for X25519 { + fn start(&self) -> Result, rustls::Error> { + let priv_key = x25519_dalek::EphemeralSecret::random_from_rng(crate::ProbablyTrng); + Ok(Box::new(KeyExchange { + pub_key: (&priv_key).into(), + priv_key, + })) + } + + fn ffdhe_group(&self) -> Option> { + None + } + + fn name(&self) -> rustls::NamedGroup { + rustls::NamedGroup::X25519 + } +} diff --git a/esp-rustls-provider/src/lib.rs b/esp-rustls-provider/src/lib.rs new file mode 100644 index 00000000000..874b7ebc273 --- /dev/null +++ b/esp-rustls-provider/src/lib.rs @@ -0,0 +1,140 @@ +//! A Rustls [CryptoProvider] and useful utilities + +#![no_std] + +extern crate alloc; + +// MUST be the first module +mod fmt; + +use alloc::sync::Arc; + +use rand_core::RngCore; +use rustls::{crypto::CryptoProvider, pki_types::PrivateKeyDer}; + +mod aead; +mod hash; +mod hmac; +mod kx; +mod sign; +mod verify; + +pub mod adapter; + +// some re-exports to make the user's life easier +pub use rustls; +pub use webpki; +pub use webpki_roots; + +/// Creates a ready-to-use Rustls [CryptoProvider] +pub fn provider() -> CryptoProvider { + CryptoProvider { + cipher_suites: ALL_CIPHER_SUITES.to_vec(), + kx_groups: kx::ALL_KX_GROUPS.to_vec(), + signature_verification_algorithms: verify::ALGORITHMS, + secure_random: &Provider, + key_provider: &Provider, + } +} + +/// Assume the RNG is actually producing true random numbers - which is the case +/// when radio peripherals are enabled +struct ProbablyTrng; + +impl rand_core::RngCore for ProbablyTrng { + fn next_u32(&mut self) -> u32 { + esp_hal::rng::Rng::new(unsafe { esp_hal::peripherals::RNG::steal() }).next_u32() + } + + fn next_u64(&mut self) -> u64 { + esp_hal::rng::Rng::new(unsafe { esp_hal::peripherals::RNG::steal() }).next_u64() + } + + fn fill_bytes(&mut self, dest: &mut [u8]) { + esp_hal::rng::Rng::new(unsafe { esp_hal::peripherals::RNG::steal() }).fill_bytes(dest); + } + + fn try_fill_bytes(&mut self, dest: &mut [u8]) -> Result<(), rand_core::Error> { + esp_hal::rng::Rng::new(unsafe { esp_hal::peripherals::RNG::steal() }).try_fill_bytes(dest) + } +} + +impl rand_core::CryptoRng for ProbablyTrng {} + +/// A time provider to be used by Rustls +#[derive(Debug)] +pub struct EspTimeProvider { + offset: u32, +} + +impl EspTimeProvider { + pub fn new(offset: u32) -> Self { + Self { offset } + } +} + +impl rustls::time_provider::TimeProvider for EspTimeProvider { + fn current_time(&self) -> Option { + let now = esp_hal::time::now().duration_since_epoch().to_secs(); + Some(rustls::pki_types::UnixTime::since_unix_epoch( + core::time::Duration::from_secs(self.offset as u64 + now), + )) + } +} + +#[derive(Debug)] +struct Provider; + +impl rustls::crypto::SecureRandom for Provider { + fn fill(&self, bytes: &mut [u8]) -> Result<(), rustls::crypto::GetRandomFailed> { + let mut rng = ProbablyTrng; + rng.fill_bytes(bytes); + Ok(()) + } +} + +impl rustls::crypto::KeyProvider for Provider { + fn load_private_key( + &self, + key_der: PrivateKeyDer<'static>, + ) -> Result, rustls::Error> { + Ok(Arc::new( + sign::EcdsaSigningKeyP256::try_from(key_der) + .map_err(|err| rustls::Error::General(alloc::format!("{}", err)))?, + )) + } +} + +/// All supported cipher suites +pub static ALL_CIPHER_SUITES: &[rustls::SupportedCipherSuite] = &[ + TLS13_CHACHA20_POLY1305_SHA256, + TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, +]; + +static TLS13_CHACHA20_POLY1305_SHA256: rustls::SupportedCipherSuite = + rustls::SupportedCipherSuite::Tls13(&rustls::Tls13CipherSuite { + common: rustls::crypto::CipherSuiteCommon { + suite: rustls::CipherSuite::TLS13_CHACHA20_POLY1305_SHA256, + hash_provider: &hash::Sha256, + confidentiality_limit: u64::MAX, + }, + hkdf_provider: &rustls::crypto::tls13::HkdfUsingHmac(&hmac::Sha256Hmac), + aead_alg: &aead::Chacha20Poly1305, + quic: None, + }); + +static TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256: rustls::SupportedCipherSuite = + rustls::SupportedCipherSuite::Tls12(&rustls::Tls12CipherSuite { + common: rustls::crypto::CipherSuiteCommon { + suite: rustls::CipherSuite::TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256, + hash_provider: &hash::Sha256, + confidentiality_limit: u64::MAX, + }, + kx: rustls::crypto::KeyExchangeAlgorithm::ECDHE, + sign: &[ + rustls::SignatureScheme::RSA_PSS_SHA256, + rustls::SignatureScheme::RSA_PKCS1_SHA256, + ], + prf_provider: &rustls::crypto::tls12::PrfUsingHmac(&hmac::Sha256Hmac), + aead_alg: &aead::Chacha20Poly1305, + }); diff --git a/esp-rustls-provider/src/sign.rs b/esp-rustls-provider/src/sign.rs new file mode 100644 index 00000000000..a2b14543962 --- /dev/null +++ b/esp-rustls-provider/src/sign.rs @@ -0,0 +1,59 @@ +use alloc::{boxed::Box, sync::Arc, vec::Vec}; + +use pkcs8::DecodePrivateKey; +use rustls::{ + pki_types::PrivateKeyDer, + sign::{Signer, SigningKey}, + SignatureAlgorithm, + SignatureScheme, +}; +use signature::{RandomizedSigner, SignatureEncoding}; + +#[derive(Clone, Debug)] +pub struct EcdsaSigningKeyP256 { + key: Arc, + scheme: SignatureScheme, +} + +impl TryFrom> for EcdsaSigningKeyP256 { + type Error = pkcs8::Error; + + fn try_from(value: PrivateKeyDer<'_>) -> Result { + match value { + PrivateKeyDer::Pkcs8(der) => { + p256::ecdsa::SigningKey::from_pkcs8_der(der.secret_pkcs8_der()).map(|kp| Self { + key: Arc::new(kp), + scheme: SignatureScheme::ECDSA_NISTP256_SHA256, + }) + } + _ => panic!("unsupported private key format"), + } + } +} + +impl SigningKey for EcdsaSigningKeyP256 { + fn choose_scheme(&self, offered: &[SignatureScheme]) -> Option> { + if offered.contains(&self.scheme) { + Some(Box::new(self.clone())) + } else { + None + } + } + + fn algorithm(&self) -> SignatureAlgorithm { + SignatureAlgorithm::ECDSA + } +} + +impl Signer for EcdsaSigningKeyP256 { + fn sign(&self, message: &[u8]) -> Result, rustls::Error> { + self.key + .try_sign_with_rng(&mut crate::ProbablyTrng, message) + .map_err(|_| rustls::Error::General("signing failed".into())) + .map(|sig: p256::ecdsa::DerSignature| sig.to_vec()) + } + + fn scheme(&self) -> SignatureScheme { + self.scheme + } +} diff --git a/esp-rustls-provider/src/verify.rs b/esp-rustls-provider/src/verify.rs new file mode 100644 index 00000000000..b25823c6805 --- /dev/null +++ b/esp-rustls-provider/src/verify.rs @@ -0,0 +1,88 @@ +use der::Reader; +use rsa::{pkcs1v15, pss, signature::Verifier, BigUint, RsaPublicKey}; +use rustls::{ + crypto::WebPkiSupportedAlgorithms, + pki_types::{AlgorithmIdentifier, InvalidSignature, SignatureVerificationAlgorithm}, + SignatureScheme, +}; +use webpki::alg_id; + +pub static ALGORITHMS: WebPkiSupportedAlgorithms = WebPkiSupportedAlgorithms { + all: &[RSA_PSS_SHA256, RSA_PKCS1_SHA256], + mapping: &[ + (SignatureScheme::RSA_PSS_SHA256, &[RSA_PSS_SHA256]), + (SignatureScheme::RSA_PKCS1_SHA256, &[RSA_PKCS1_SHA256]), + ], +}; + +static RSA_PSS_SHA256: &dyn SignatureVerificationAlgorithm = &RsaPssSha256Verify; +static RSA_PKCS1_SHA256: &dyn SignatureVerificationAlgorithm = &RsaPkcs1Sha256Verify; + +#[derive(Debug)] +struct RsaPssSha256Verify; + +impl SignatureVerificationAlgorithm for RsaPssSha256Verify { + fn public_key_alg_id(&self) -> AlgorithmIdentifier { + alg_id::RSA_ENCRYPTION + } + + fn signature_alg_id(&self) -> AlgorithmIdentifier { + alg_id::RSA_PSS_SHA256 + } + + fn verify_signature( + &self, + public_key: &[u8], + message: &[u8], + signature: &[u8], + ) -> Result<(), InvalidSignature> { + let public_key = decode_spki_spk(public_key)?; + + let signature = pss::Signature::try_from(signature).map_err(|_| InvalidSignature)?; + + pss::VerifyingKey::::new(public_key) + .verify(message, &signature) + .map_err(|_| InvalidSignature) + } +} + +#[derive(Debug)] +struct RsaPkcs1Sha256Verify; + +impl SignatureVerificationAlgorithm for RsaPkcs1Sha256Verify { + fn public_key_alg_id(&self) -> AlgorithmIdentifier { + alg_id::RSA_ENCRYPTION + } + + fn signature_alg_id(&self) -> AlgorithmIdentifier { + alg_id::RSA_PKCS1_SHA256 + } + + fn verify_signature( + &self, + public_key: &[u8], + message: &[u8], + signature: &[u8], + ) -> Result<(), InvalidSignature> { + let public_key = decode_spki_spk(public_key)?; + + let signature = pkcs1v15::Signature::try_from(signature).map_err(|_| InvalidSignature)?; + + pkcs1v15::VerifyingKey::::new(public_key) + .verify(message, &signature) + .map_err(|_| InvalidSignature) + } +} + +fn decode_spki_spk(spki_spk: &[u8]) -> Result { + // public_key: unfortunately this is not a whole SPKI, but just the key + // material. decode the two integers manually. + let mut reader = der::SliceReader::new(spki_spk).map_err(|_| InvalidSignature)?; + let ne: [der::asn1::UintRef; 2] = reader.decode().map_err(|_| InvalidSignature)?; + + RsaPublicKey::new( + BigUint::from_bytes_be(ne[0].as_bytes()), + BigUint::from_bytes_be(ne[1].as_bytes()), + ) + .map_err(|_| InvalidSignature) +} diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 07e1266e0a9..c68d1258cc6 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -27,6 +27,7 @@ esp-alloc = { path = "../esp-alloc" } esp-backtrace = { path = "../esp-backtrace", features = ["exception-handler", "panic-handler", "println"] } esp-hal = { path = "../esp-hal", features = ["log"] } esp-hal-embassy = { path = "../esp-hal-embassy", optional = true } +esp-rustls-provider = { path = "../esp-rustls-provider", optional = true } esp-ieee802154 = { path = "../esp-ieee802154", optional = true } esp-println = { path = "../esp-println", features = ["log"] } esp-storage = { path = "../esp-storage", optional = true } @@ -38,6 +39,7 @@ ieee80211 = { version = "0.4.0", default-features = false } ieee802154 = "0.6.1" log = "0.4.22" nb = "1.1.0" +ntp-nostd = { version = "0.0.1", optional = true } portable-atomic = { version = "1.9.0", default-features = false } sha2 = { version = "0.10.8", default-features = false } smoltcp = { version = "0.12.0", default-features = false, features = [ "medium-ethernet", "socket-raw"] } diff --git a/examples/src/bin/wifi_rustls_client.rs b/examples/src/bin/wifi_rustls_client.rs new file mode 100644 index 00000000000..57a856bd8ea --- /dev/null +++ b/examples/src/bin/wifi_rustls_client.rs @@ -0,0 +1,300 @@ +//! Example using Rustls +//! +//! In general, if you can get away with < https://crates.io/crates/embedded-tls > you should prefer that. +//! +//! While Rustls can do more it's more of a heavy-weight dependency. This needs `alloc` +//! +//! +//! Set SSID and PASSWORD env variable before running this example. +//! +//! This gets an ip address via DHCP then performs an HTTPS request for "google.com" + +//% FEATURES: esp-wifi esp-wifi/wifi esp-wifi/utils esp-hal/unstable esp-rustls-provider ntp-nostd +//% CHIPS: esp32 esp32s3 esp32c6 + +#![no_std] +#![no_main] + +extern crate alloc; + +use blocking_network_stack::Stack; +use embedded_io::*; +use esp_backtrace as _; +use esp_hal::{clock::CpuClock, main, rng::Rng, time::Duration, timer::timg::TimerGroup}; +use esp_println::{print, println}; +use esp_rustls_provider::{ + adapter::client::ClientConnection, + rustls, + webpki_roots, + EspTimeProvider, +}; +use esp_wifi::{ + config::PowerSaveMode, + wifi::{ + utils::create_network_interface, + AccessPointInfo, + ClientConfiguration, + Configuration, + WifiDevice, + WifiError, + WifiStaDevice, + }, +}; +use smoltcp::{ + iface::{SocketSet, SocketStorage}, + wire::{DhcpOption, IpAddress, Ipv4Address}, +}; + +const SSID: &str = env!("SSID"); +const PASSWORD: &str = env!("PASSWORD"); + +const KB: usize = 1024; +const INCOMING_TLS_BUFSIZE: usize = 16 * KB; +const OUTGOING_TLS_INITIAL_BUFSIZE: usize = KB; + +#[main] +fn main() -> ! { + esp_println::logger::init_logger_from_env(); + let config = esp_hal::Config::default().with_cpu_clock(CpuClock::max()); + let peripherals = esp_hal::init(config); + + static mut HEAP: core::mem::MaybeUninit<[u8; 72 * 1024]> = core::mem::MaybeUninit::uninit(); + + #[link_section = ".dram2_uninit"] + static mut HEAP2: core::mem::MaybeUninit<[u8; 64 * 1024]> = core::mem::MaybeUninit::uninit(); + + #[allow(static_mut_refs)] + unsafe { + esp_alloc::HEAP.add_region(esp_alloc::HeapRegion::new( + HEAP.as_mut_ptr() as *mut u8, + core::mem::size_of_val(&*core::ptr::addr_of!(HEAP)), + esp_alloc::MemoryCapability::Internal.into(), + )); + + // COEX needs more RAM - add some more + esp_alloc::HEAP.add_region(esp_alloc::HeapRegion::new( + HEAP2.as_mut_ptr() as *mut u8, + core::mem::size_of_val(&*core::ptr::addr_of!(HEAP2)), + esp_alloc::MemoryCapability::Internal.into(), + )); + } + + let timg0 = TimerGroup::new(peripherals.TIMG0); + + let mut rng = Rng::new(peripherals.RNG); + + let init = esp_wifi::init(timg0.timer0, rng.clone(), peripherals.RADIO_CLK).unwrap(); + + let mut wifi = peripherals.WIFI; + let (iface, device, mut controller) = + create_network_interface(&init, &mut wifi, WifiStaDevice).unwrap(); + controller.set_power_saving(PowerSaveMode::None).unwrap(); + + let mut socket_set_entries: [SocketStorage; 3] = Default::default(); + let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); + let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); + // we can set a hostname here (or add other DHCP options) + dhcp_socket.set_outgoing_options(&[DhcpOption { + kind: 12, + data: b"esp-wifi", + }]); + socket_set.add(dhcp_socket); + + let now = || esp_hal::time::now().duration_since_epoch().to_millis(); + let stack = Stack::new(iface, device, socket_set, now, rng.random()); + + let client_config = Configuration::Client(ClientConfiguration { + ssid: SSID.try_into().unwrap(), + password: PASSWORD.try_into().unwrap(), + ..Default::default() + }); + let res = controller.set_configuration(&client_config); + println!("wifi_set_configuration returned {:?}", res); + + controller.start().unwrap(); + println!("is wifi started: {:?}", controller.is_started()); + + println!("Start Wifi Scan"); + let res: Result<(heapless::Vec, usize), WifiError> = controller.scan_n(); + if let Ok((res, _count)) = res { + for ap in res { + println!("{:?}", ap); + } + } + + println!("{:?}", controller.capabilities()); + println!("wifi_connect {:?}", controller.connect()); + + // wait to get connected + println!("Wait to get connected"); + loop { + let res = controller.is_connected(); + match res { + Ok(connected) => { + if connected { + break; + } + } + Err(err) => { + println!("{:?}", err); + loop {} + } + } + } + println!("{:?}", controller.is_connected()); + + // wait for getting an ip address + println!("Wait to get an ip address"); + loop { + stack.work(); + + if stack.is_iface_up() { + println!("got ip {:?}", stack.get_ip_info()); + break; + } + } + + let mut rx_meta1 = [smoltcp::socket::udp::PacketMetadata::EMPTY; 10]; + let mut rx_buffer1 = [0u8; 1536]; + let mut tx_meta1 = [smoltcp::socket::udp::PacketMetadata::EMPTY; 10]; + let mut tx_buffer1 = [0u8; 1536]; + let mut udp_socket = stack.get_udp_socket( + &mut rx_meta1, + &mut rx_buffer1, + &mut tx_meta1, + &mut tx_buffer1, + ); + udp_socket.bind(50123).unwrap(); + + let mut rx_buffer = [0u8; 1536]; + let mut tx_buffer = [0u8; 1536]; + let mut socket = stack.get_socket(&mut rx_buffer, &mut tx_buffer); + + let mut incoming_tls = [0; INCOMING_TLS_BUFSIZE]; + let mut outgoing_tls = [0; OUTGOING_TLS_INITIAL_BUFSIZE]; + let mut plaintext_in = [0u8; INCOMING_TLS_BUFSIZE]; + let mut plaintext_out = [0u8; KB]; + + let unix_ts_now = get_current_unix_ts(&mut udp_socket); + let time_provider = EspTimeProvider::new(unix_ts_now); + println!("unix_ts_now = {}", unix_ts_now); + + let root_store = + rustls::RootCertStore::from_iter(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + + let mut config = rustls::ClientConfig::builder_with_details( + esp_rustls_provider::provider().into(), + alloc::sync::Arc::new(time_provider), + ) + .with_safe_default_protocol_versions() + .unwrap() + .with_root_certificates(root_store) + .with_no_client_auth(); + config.enable_early_data = true; + + let config = alloc::sync::Arc::new(config); + + println!("Start busy loop on main"); + + loop { + println!("Making HTTP request"); + socket.work(); + + socket + .open(IpAddress::Ipv4(Ipv4Address::new(142, 251, 36, 164)), 443) + .unwrap(); + + let server_name = "google.com".try_into().unwrap(); + let tls = ClientConnection::new( + config.clone(), + server_name, + socket, + &mut incoming_tls, + &mut outgoing_tls, + &mut plaintext_in, + &mut plaintext_out, + ); + + match tls { + Ok(mut tls) => { + tls.write(b"GET / HTTP/1.0\r\nHost: google.com\r\n\r\n") + .unwrap(); + tls.flush().unwrap(); + + let deadline = esp_hal::time::now() + Duration::secs(20); + loop { + let mut buffer = [0u8; 512]; + + let res = tls.read(&mut buffer); + if let Ok(len) = res { + let to_print = unsafe { core::str::from_utf8_unchecked(&buffer[..len]) }; + print!("{}", to_print); + } else { + println!("{:?}", res); + break; + } + + if esp_hal::time::now() > deadline { + println!("Timeout"); + break; + } + } + println!(); + + socket = tls.free(); + } + Err((s, _)) => { + socket = s; + } + } + + socket.disconnect(); + + let deadline = esp_hal::time::now() + Duration::secs(5); + while esp_hal::time::now() < deadline { + socket.work(); + } + + println!(); + } +} + +fn get_current_unix_ts<'s, 'n, 'd>( + udp_socket: &mut blocking_network_stack::UdpSocket<'s, 'n, WifiDevice<'d, WifiStaDevice>>, +) -> u32 { + let req_data = ntp_nostd::get_client_request(); + let mut rcvd_data = [0_u8; 1536]; + + udp_socket + // using ip from https://tf.nist.gov/tf-cgi/servers.cgi (time-b-wwv.nist.gov) + .send(Ipv4Address::new(132, 163, 97, 2).into(), 123, &req_data) + .unwrap(); + let mut count = 0; + + loop { + count += 1; + let rcvd = udp_socket.receive(&mut rcvd_data); + if rcvd.is_ok() { + break; + } + + let deadline = esp_hal::time::now() + Duration::secs(1); + while esp_hal::time::now() < deadline {} + + if count > 10 { + udp_socket + // retry with another server + // using ip from https://tf.nist.gov/tf-cgi/servers.cgi (time-b-g.nist.gov) + .send(Ipv4Address::new(129, 6, 15, 29).into(), 123, &req_data) + .unwrap(); + println!("Trying another NTP server..."); + count = 0; + } + } + let response = ntp_nostd::NtpServerResponse::from(rcvd_data.as_ref()); + if response.headers.tx_time_seconds == 0 { + panic!("No timestamp received"); + } + + response.headers.get_unix_timestamp() +} diff --git a/examples/src/bin/wifi_rustls_server.rs b/examples/src/bin/wifi_rustls_server.rs new file mode 100644 index 00000000000..cdde39b507b --- /dev/null +++ b/examples/src/bin/wifi_rustls_server.rs @@ -0,0 +1,378 @@ +//! Example using Rustls as a server +//! +//! This needs `alloc` +//! +//! Set SSID and PASSWORD env variable before running this example. +//! +//! This gets an ip address via DHCP then runs an HTTPS server on port 4443 + +//% FEATURES: esp-wifi esp-wifi/wifi esp-wifi/utils esp-hal/unstable esp-rustls-provider ntp-nostd +//% CHIPS: esp32 esp32s3 esp32c6 + +#![no_std] +#![no_main] + +extern crate alloc; + +use alloc::sync::Arc; + +use blocking_network_stack::Stack; +use embedded_io::*; +use esp_backtrace as _; +use esp_hal::{clock::CpuClock, main, rng::Rng, time::Duration, timer::timg::TimerGroup}; +use esp_println::{print, println}; +use esp_rustls_provider::{adapter::server::ServerConnection, rustls, EspTimeProvider}; +use esp_wifi::{ + config::PowerSaveMode, + wifi::{ + utils::create_network_interface, + AccessPointInfo, + ClientConfiguration, + Configuration, + WifiDevice, + WifiError, + WifiStaDevice, + }, +}; +use smoltcp::{ + iface::{SocketSet, SocketStorage}, + wire::{DhcpOption, Ipv4Address}, +}; + +const SSID: &str = env!("SSID"); +const PASSWORD: &str = env!("PASSWORD"); + +const KB: usize = 1024; +const INCOMING_TLS_BUFSIZE: usize = 16 * KB; +const OUTGOING_TLS_INITIAL_BUFSIZE: usize = KB; + +#[main] +fn main() -> ! { + esp_println::logger::init_logger_from_env(); + let config = esp_hal::Config::default().with_cpu_clock(CpuClock::max()); + let peripherals = esp_hal::init(config); + + static mut HEAP: core::mem::MaybeUninit<[u8; 72 * 1024]> = core::mem::MaybeUninit::uninit(); + + #[link_section = ".dram2_uninit"] + static mut HEAP2: core::mem::MaybeUninit<[u8; 64 * 1024]> = core::mem::MaybeUninit::uninit(); + + #[allow(static_mut_refs)] + unsafe { + esp_alloc::HEAP.add_region(esp_alloc::HeapRegion::new( + HEAP.as_mut_ptr() as *mut u8, + core::mem::size_of_val(&*core::ptr::addr_of!(HEAP)), + esp_alloc::MemoryCapability::Internal.into(), + )); + + // COEX needs more RAM - add some more + esp_alloc::HEAP.add_region(esp_alloc::HeapRegion::new( + HEAP2.as_mut_ptr() as *mut u8, + core::mem::size_of_val(&*core::ptr::addr_of!(HEAP2)), + esp_alloc::MemoryCapability::Internal.into(), + )); + } + + let timg0 = TimerGroup::new(peripherals.TIMG0); + + let mut rng = Rng::new(peripherals.RNG); + + let init = esp_wifi::init(timg0.timer0, rng.clone(), peripherals.RADIO_CLK).unwrap(); + + let mut wifi = peripherals.WIFI; + let (iface, device, mut controller) = + create_network_interface(&init, &mut wifi, WifiStaDevice).unwrap(); + controller.set_power_saving(PowerSaveMode::None).unwrap(); + + let mut socket_set_entries: [SocketStorage; 3] = Default::default(); + let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); + let mut dhcp_socket = smoltcp::socket::dhcpv4::Socket::new(); + // we can set a hostname here (or add other DHCP options) + dhcp_socket.set_outgoing_options(&[DhcpOption { + kind: 12, + data: b"esp-wifi", + }]); + socket_set.add(dhcp_socket); + + let now = || esp_hal::time::now().duration_since_epoch().to_millis(); + let stack = Stack::new(iface, device, socket_set, now, rng.random()); + + let client_config = Configuration::Client(ClientConfiguration { + ssid: SSID.try_into().unwrap(), + password: PASSWORD.try_into().unwrap(), + ..Default::default() + }); + let res = controller.set_configuration(&client_config); + println!("wifi_set_configuration returned {:?}", res); + + controller.start().unwrap(); + println!("is wifi started: {:?}", controller.is_started()); + + println!("Start Wifi Scan"); + let res: Result<(heapless::Vec, usize), WifiError> = controller.scan_n(); + if let Ok((res, _count)) = res { + for ap in res { + println!("{:?}", ap); + } + } + + println!("{:?}", controller.capabilities()); + println!("wifi_connect {:?}", controller.connect()); + + // wait to get connected + println!("Wait to get connected"); + loop { + let res = controller.is_connected(); + match res { + Ok(connected) => { + if connected { + break; + } + } + Err(err) => { + println!("{:?}", err); + loop {} + } + } + } + println!("{:?}", controller.is_connected()); + + // wait for getting an ip address + println!("Wait to get an ip address"); + loop { + stack.work(); + + if stack.is_iface_up() { + println!("got ip {:?}", stack.get_ip_info()); + break; + } + } + + let mut rx_meta1 = [smoltcp::socket::udp::PacketMetadata::EMPTY; 10]; + let mut rx_buffer1 = [0u8; 1536]; + let mut tx_meta1 = [smoltcp::socket::udp::PacketMetadata::EMPTY; 10]; + let mut tx_buffer1 = [0u8; 1536]; + let mut udp_socket = stack.get_udp_socket( + &mut rx_meta1, + &mut rx_buffer1, + &mut tx_meta1, + &mut tx_buffer1, + ); + udp_socket.bind(50123).unwrap(); + + let mut rx_buffer = [0u8; 1536]; + let mut tx_buffer = [0u8; 1536]; + let mut socket = stack.get_socket(&mut rx_buffer, &mut tx_buffer); + + let mut incoming_tls = [0; INCOMING_TLS_BUFSIZE]; + let mut outgoing_tls = [0; OUTGOING_TLS_INITIAL_BUFSIZE]; + let mut plaintext_in = [0u8; INCOMING_TLS_BUFSIZE]; + let mut plaintext_out = [0u8; KB]; + + let unix_ts_now = get_current_unix_ts(&mut udp_socket); + let time_provider = EspTimeProvider::new(unix_ts_now); + println!("unix_ts_now = {}", unix_ts_now); + + let pki = TestPki::new(); + + let server_config = rustls::ServerConfig::builder_with_details( + esp_rustls_provider::provider().into(), + Arc::new(time_provider), + ) + .with_safe_default_protocol_versions() + .unwrap() + .with_no_client_auth() + .with_single_cert(alloc::vec![pki.server_cert_der], pki.server_key_der) + .unwrap(); + + let server_config = alloc::sync::Arc::new(server_config); + + println!("Open https://{}:4443/", stack.get_ip_info().unwrap().ip); + + socket.listen(4443).unwrap(); + + loop { + socket.work(); + + if !socket.is_open() { + socket.listen(4443).unwrap(); + } + + if socket.is_connected() { + println!("Connected"); + + let tls = ServerConnection::new( + server_config.clone(), + socket, + &mut incoming_tls, + &mut outgoing_tls, + &mut plaintext_in, + &mut plaintext_out, + ); + + match tls { + Ok(mut tls) => { + let mut time_out = false; + let mut err = false; + let mut got_data = false; + let deadline = esp_hal::time::now() + Duration::secs(4); + let mut buffer = [0u8; 1024]; + let mut pos = 0; + loop { + let r = tls.read(&mut buffer[pos..]); + if let Ok(len) = r { + if len > 0 { + got_data = true; + } + + let to_print = + unsafe { core::str::from_utf8_unchecked(&buffer[..(pos + len)]) }; + + if to_print.contains("\r\n\r\n") { + print!("{}", to_print); + println!(); + break; + } + + pos += len; + + if len == 0 && got_data { + break; + } + } else { + println!("{:?}", r); + err = true; + break; + } + + if esp_hal::time::now() > deadline { + println!("Timeout"); + time_out = true; + break; + } + } + + if !time_out && !err { + tls.write_all( + b"HTTP/1.0 200 OK\r\n\r\n\ + \ + \ +

Hello Rust! Hello Rustls!

\ + + \ + \r\n\ + " + ).ok(); + + tls.flush().ok(); + } + + socket = tls.free(); + } + Err((s, _)) => { + socket = s; + } + } + + let deadline = esp_hal::time::now() + Duration::secs(1); + while esp_hal::time::now() < deadline { + socket.work(); + } + + socket.disconnect(); + socket.close(); + + println!("Done\n"); + println!(); + } + } +} + +fn get_current_unix_ts<'s, 'n, 'd>( + udp_socket: &mut blocking_network_stack::UdpSocket<'s, 'n, WifiDevice<'d, WifiStaDevice>>, +) -> u32 { + let req_data = ntp_nostd::get_client_request(); + let mut rcvd_data = [0_u8; 1536]; + + udp_socket + // using ip from https://tf.nist.gov/tf-cgi/servers.cgi (time-b-wwv.nist.gov) + .send(Ipv4Address::new(132, 163, 97, 2).into(), 123, &req_data) + .unwrap(); + let mut count = 0; + + loop { + count += 1; + let rcvd = udp_socket.receive(&mut rcvd_data); + if rcvd.is_ok() { + break; + } + + let deadline = esp_hal::time::now() + Duration::secs(1); + while esp_hal::time::now() < deadline {} + + if count > 10 { + udp_socket + // retry with another server + // using ip from https://tf.nist.gov/tf-cgi/servers.cgi (time-b-g.nist.gov) + .send(Ipv4Address::new(129, 6, 15, 29).into(), 123, &req_data) + .unwrap(); + println!("Trying another NTP server..."); + count = 0; + } + } + let response = ntp_nostd::NtpServerResponse::from(rcvd_data.as_ref()); + if response.headers.tx_time_seconds == 0 { + panic!("No timestamp received"); + } + + response.headers.get_unix_timestamp() +} + +struct TestPki { + server_cert_der: rustls::pki_types::CertificateDer<'static>, + server_key_der: rustls::pki_types::PrivateKeyDer<'static>, +} + +impl TestPki { + fn new() -> Self { + static CERT: &[u8] = &[ + 48u8, 130, 1, 138, 48, 130, 1, 47, 160, 3, 2, 1, 2, 2, 20, 80, 89, 193, 6, 153, 91, 16, + 212, 128, 110, 89, 53, 108, 183, 139, 150, 172, 160, 238, 160, 48, 10, 6, 8, 42, 134, + 72, 206, 61, 4, 3, 2, 48, 55, 49, 19, 48, 17, 6, 3, 85, 4, 3, 12, 10, 69, 120, 97, 109, + 112, 108, 101, 32, 67, 65, 49, 32, 48, 30, 6, 3, 85, 4, 10, 12, 23, 80, 114, 111, 118, + 105, 100, 101, 114, 32, 83, 101, 114, 118, 101, 114, 32, 69, 120, 97, 109, 112, 108, + 101, 48, 32, 23, 13, 55, 53, 48, 49, 48, 49, 48, 48, 48, 48, 48, 48, 90, 24, 15, 52, + 48, 57, 54, 48, 49, 48, 49, 48, 48, 48, 48, 48, 48, 90, 48, 33, 49, 31, 48, 29, 6, 3, + 85, 4, 3, 12, 22, 114, 99, 103, 101, 110, 32, 115, 101, 108, 102, 32, 115, 105, 103, + 110, 101, 100, 32, 99, 101, 114, 116, 48, 89, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, + 6, 8, 42, 134, 72, 206, 61, 3, 1, 7, 3, 66, 0, 4, 242, 150, 10, 198, 8, 151, 136, 40, + 123, 104, 14, 246, 178, 151, 176, 193, 229, 222, 187, 216, 154, 223, 176, 221, 103, 87, + 75, 171, 64, 29, 30, 70, 47, 34, 93, 94, 18, 94, 19, 252, 33, 12, 249, 145, 104, 6, 75, + 46, 111, 147, 253, 99, 206, 64, 83, 126, 244, 34, 54, 112, 83, 87, 5, 18, 163, 45, 48, + 43, 48, 20, 6, 3, 85, 29, 17, 4, 13, 48, 11, 130, 9, 108, 111, 99, 97, 108, 104, 111, + 115, 116, 48, 19, 6, 3, 85, 29, 37, 4, 12, 48, 10, 6, 8, 43, 6, 1, 5, 5, 7, 3, 1, 48, + 10, 6, 8, 42, 134, 72, 206, 61, 4, 3, 2, 3, 73, 0, 48, 70, 2, 33, 0, 216, 8, 139, 33, + 131, 48, 10, 188, 169, 173, 34, 67, 18, 208, 37, 151, 185, 90, 170, 248, 53, 82, 136, + 220, 192, 89, 23, 88, 152, 107, 64, 171, 2, 33, 0, 142, 29, 9, 154, 225, 201, 34, 174, + 248, 51, 216, 50, 205, 46, 103, 243, 155, 190, 125, 115, 95, 8, 45, 40, 72, 146, 113, + 74, 226, 15, 56, 215, + ]; + + static KEY: &[u8] = &[ + 48u8, 129, 135, 2, 1, 0, 48, 19, 6, 7, 42, 134, 72, 206, 61, 2, 1, 6, 8, 42, 134, 72, + 206, 61, 3, 1, 7, 4, 109, 48, 107, 2, 1, 1, 4, 32, 238, 131, 128, 95, 255, 63, 243, 16, + 105, 169, 88, 28, 10, 118, 255, 188, 38, 9, 49, 102, 178, 232, 146, 94, 165, 117, 24, + 179, 118, 197, 4, 77, 161, 68, 3, 66, 0, 4, 242, 150, 10, 198, 8, 151, 136, 40, 123, + 104, 14, 246, 178, 151, 176, 193, 229, 222, 187, 216, 154, 223, 176, 221, 103, 87, 75, + 171, 64, 29, 30, 70, 47, 34, 93, 94, 18, 94, 19, 252, 33, 12, 249, 145, 104, 6, 75, 46, + 111, 147, 253, 99, 206, 64, 83, 126, 244, 34, 54, 112, 83, 87, 5, 18, + ]; + + Self { + server_cert_der: rustls::pki_types::CertificateDer::from(CERT), + server_key_der: rustls::pki_types::PrivateKeyDer::Pkcs8( + rustls::pki_types::PrivatePkcs8KeyDer::from(KEY), + ), + } + } +} diff --git a/xtask/src/lib.rs b/xtask/src/lib.rs index 0667dfece2e..4016b120b25 100644 --- a/xtask/src/lib.rs +++ b/xtask/src/lib.rs @@ -45,6 +45,7 @@ pub enum Package { EspMetadata, EspPrintln, EspRiscvRt, + EspRustlsProvider, EspStorage, EspWifi, Examples, diff --git a/xtask/src/main.rs b/xtask/src/main.rs index 1ecac248404..74e155db39d 100644 --- a/xtask/src/main.rs +++ b/xtask/src/main.rs @@ -791,6 +791,22 @@ fn lint_packages(workspace: &Path, args: LintPackagesArgs) -> Result<()> { } } + Package::EspRustlsProvider => { + if [Chip::Esp32, Chip::Esp32s3, Chip::Esp32c6].contains(chip) { + let features = format!("--features=esp-hal/{chip},esp-hal/unstable"); + lint_package( + chip, + &path, + &[ + "-Zbuild-std=core,alloc", + &features, + &format!("--target={}", chip.target()), + ], + args.fix, + )?; + } + } + Package::EspStorage => { lint_package( chip,