Skip to content
Merged
10 changes: 9 additions & 1 deletion mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cc = "1.0"
default = ["std", "aesni", "time", "padlock"]
std = ["mbedtls-sys-auto/std", "serde/std", "yasna"]
debug = ["mbedtls-sys-auto/debug"]
no_std_deps = ["core_io", "spin"]
no_std_deps = ["core_io", "spin", "serde/alloc"]
force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"]
mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"]
rdrand = []
Expand All @@ -73,6 +73,14 @@ pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"]
name = "client"
required-features = ["std"]

[[example]]
name = "client_dtls"
required-features = ["std"]

[[example]]
name = "client_psk"
required-features = ["std"]

[[example]]
name = "server"
required-features = ["std"]
Expand Down
57 changes: 57 additions & 0 deletions mbedtls/examples/client_dtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::UdpSocket;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::x509::Certificate;
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;
use support::keys;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let cert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?);
let mut config = Config::new(Endpoint::Client, Transport::Datagram, Preset::Default);
config.set_rng(rng);
config.set_ca_list(cert, None);
let mut ctx = Context::new(Arc::new(config));
ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new()));

let sock = UdpSocket::bind("localhost:12345").unwrap();
let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap();
ctx.establish(sock, None).unwrap();

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
52 changes: 52 additions & 0 deletions mbedtls/examples/client_psk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::TcpStream;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client").unwrap();
let mut ctx = Context::new(Arc::new(config));

let conn = TcpStream::connect(addr).unwrap();
ctx.establish(conn, None)?;

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
16 changes: 15 additions & 1 deletion mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ impl Config {
self.dbg_callback = Some(Arc::new(cb));
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
}

/// Sets the PSK and the PSK-Identity
///
/// Only a single entry is supported at the moment. If another one was set before, it will be
/// overridden.
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
unsafe {
// This allocates and copies the buffers and does not store any pointer to them
let psk_identity = psk_identity.as_bytes();
ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len())
.into_result()
.map(|_| ())?;
}
Ok(())
}
}

// TODO
Expand All @@ -466,7 +481,6 @@ impl Config {
// ssl_conf_dtls_badmac_limit
// ssl_conf_handshake_timeout
// ssl_conf_session_cache
// ssl_conf_psk
// ssl_conf_psk_cb
// ssl_conf_sig_hashes
// ssl_conf_alpn_protocols
Expand Down
131 changes: 128 additions & 3 deletions mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use core::result::Result as StdResult;

#[cfg(feature = "std")]
use {
std::io::{Read, Write, Result as IoResult},
std::io::{Read, Write, Result as IoResult, Error as IoError},
std::sync::Arc,
};

Expand Down Expand Up @@ -67,6 +67,121 @@ impl<IO: Read + Write> IoCallback for IO {
}
}

#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
socket: std::net::UdpSocket,
}

#[cfg(feature = "std")]
impl ConnectedUdpSocket {
pub fn connect<A: std::net::ToSocketAddrs>(socket: std::net::UdpSocket, addr: A) -> StdResult<Self, (IoError, std::net::UdpSocket)> {
match socket.connect(addr) {
Ok(_) => Ok(ConnectedUdpSocket {
socket,
}),
Err(e) => Err((e, socket)),
}
}
}

#[cfg(feature = "std")]
impl IoCallback for ConnectedUdpSocket {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) {
Ok(i) => i as c_int,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}

unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}

fn data_ptr(&mut self) -> *mut c_void {
self as *mut ConnectedUdpSocket as *mut c_void
}
}

pub trait TimerCallback: Send + Sync {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized;

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized;

fn data_ptr(&mut self) -> *mut c_void;
}

#[cfg(feature = "std")]
pub struct Timer {
timer_start: std::time::Instant,
timer_int_ms: u32,
timer_fin_ms: u32,
}

#[cfg(feature = "std")]
impl Timer {
pub fn new() -> Self {
Timer {
timer_start: std::time::Instant::now(),
timer_int_ms: 0,
timer_fin_ms: 0,
}
}
}

#[cfg(feature = "std")]
impl TimerCallback for Timer {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
slf.timer_start = std::time::Instant::now();
slf.timer_int_ms = int_ms;
slf.timer_fin_ms = fin_ms;
}

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
if slf.timer_int_ms == 0 || slf.timer_fin_ms == 0 {
return 0;
}
let passed = std::time::Instant::now() - slf.timer_start;
if passed.as_millis() >= slf.timer_fin_ms.into() {
2
} else if passed.as_millis() >= slf.timer_int_ms.into() {
1
} else {
0
}
}

fn data_ptr(&mut self) -> *mut mbedtls_sys::types::raw_types::c_void {
self as *mut _ as *mut _
}
}

define!(
#[c_ty(ssl_context)]
Expand All @@ -89,11 +204,13 @@ pub struct Context<T> {
// Base structure used in SNI callback where we cannot determine the io type.
inner: HandshakeContext,

// config is used read-only for mutliple contexts and is immutable once configured.
// config is used read-only for multiple contexts and is immutable once configured.
config: Arc<Config>,

// Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated.
io: Option<Box<T>>,

timer_callback: Option<Box<dyn TimerCallback>>,
}

impl<'a, T> Into<*const ssl_context> for &'a Context<T> {
Expand Down Expand Up @@ -128,6 +245,7 @@ impl<T> Context<T> {
},
config: config.clone(),
io: None,
timer_callback: None,
}
}

Expand Down Expand Up @@ -157,7 +275,7 @@ impl<T: IoCallback> Context<T> {
);

self.io = Some(io);
self.inner.reset_handshake();
self.inner.reset_handshake();
}

match self.handshake() {
Expand Down Expand Up @@ -298,6 +416,13 @@ impl<T> Context<T> {
}
}
}

pub fn set_timer_callback<F: TimerCallback + 'static>(&mut self, mut cb: Box<F>) {
unsafe {
ssl_set_timer_cb(self.into(), cb.data_ptr(), Some(F::set_timer), Some(F::get_timer));
}
self.timer_callback = Some(cb);
}
}

impl<T> Drop for Context<T> {
Expand Down