Skip to content
Merged
10 changes: 9 additions & 1 deletion mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cc = "1.0"
default = ["std", "aesni", "time", "padlock"]
std = ["mbedtls-sys-auto/std", "serde/std", "yasna"]
debug = ["mbedtls-sys-auto/debug"]
no_std_deps = ["core_io", "spin"]
no_std_deps = ["core_io", "spin", "serde/alloc"]
force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"]
mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"]
rdrand = []
Expand All @@ -73,6 +73,14 @@ pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"]
name = "client"
required-features = ["std"]

[[example]]
name = "client_dtls"
required-features = ["std"]

[[example]]
name = "client_psk"
required-features = ["std"]

[[example]]
name = "server"
required-features = ["std"]
Expand Down
57 changes: 57 additions & 0 deletions mbedtls/examples/client_dtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::UdpSocket;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::x509::Certificate;
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;
use support::keys;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let cert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?);
let mut config = Config::new(Endpoint::Client, Transport::Datagram, Preset::Default);
config.set_rng(rng);
config.set_ca_list(cert, None);
let mut ctx = Context::new(Arc::new(config));
ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new()));

let sock = UdpSocket::bind("localhost:12345").unwrap();
let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap();
ctx.establish(sock, None).unwrap();

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
52 changes: 52 additions & 0 deletions mbedtls/examples/client_psk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::TcpStream;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client").unwrap();
let mut ctx = Context::new(Arc::new(config));

let conn = TcpStream::connect(addr).unwrap();
ctx.establish(conn, None)?;

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
16 changes: 15 additions & 1 deletion mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ impl Config {
self.dbg_callback = Some(Arc::new(cb));
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
}

/// Sets the PSK and the PSK-Identity
///
/// Only a single entry is supported at the moment. If another one was set before, it will be
/// overridden.
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
unsafe {
// This allocates and copies the buffers and does not store any pointer to them
let psk_identity = psk_identity.as_bytes();
ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len())
.into_result()
.map(|_| ())?;
}
Ok(())
}
}

// TODO
Expand All @@ -466,7 +481,6 @@ impl Config {
// ssl_conf_dtls_badmac_limit
// ssl_conf_handshake_timeout
// ssl_conf_session_cache
// ssl_conf_psk
// ssl_conf_psk_cb
// ssl_conf_sig_hashes
// ssl_conf_alpn_protocols
Expand Down
129 changes: 127 additions & 2 deletions mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,121 @@ impl<IO: Read + Write> IoCallback for IO {
}
}

#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
socket: std::net::UdpSocket,
}

#[cfg(feature = "std")]
impl ConnectedUdpSocket {
pub fn connect<A: std::net::ToSocketAddrs>(socket: std::net::UdpSocket, addr: A) -> std::result::Result<Self, (std::io::Error, std::net::UdpSocket)> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit:

We already use core::result::Result as StdResult in this file, and since here we have std in this block let's keep with StdResult.

Similarly, we can import std::io::Error as IoError and use that.

(note: https://github.com/rust-lang/rust/blob/5807fbbfde3ad04820f6fa0269711c81538057ec/src/libstd/lib.rs#L332-L333)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just giving evidence of the fact that std exports core.

match socket.connect(addr) {
Ok(_) => Ok(ConnectedUdpSocket {
socket,
}),
Err(e) => Err((e, socket)),
}
}
}

#[cfg(feature = "std")]
impl IoCallback for ConnectedUdpSocket {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) {
Ok(i) => i as c_int,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}

unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}

fn data_ptr(&mut self) -> *mut c_void {
self as *mut ConnectedUdpSocket as *mut c_void
}
}

pub trait TimerCallback: Send + Sync {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized;

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized;

fn data_ptr(&mut self) -> *mut c_void;
}

#[cfg(feature = "std")]
pub struct Timer {
timer_start: std::time::Instant,
timer_int_ms: u32,
timer_fin_ms: u32,
}

#[cfg(feature = "std")]
impl Timer {
pub fn new() -> Self {
Timer {
timer_start: std::time::Instant::now(),
timer_int_ms: 0,
timer_fin_ms: 0,
}
}
}

#[cfg(feature = "std")]
impl TimerCallback for Timer {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
slf.timer_start = std::time::Instant::now();
slf.timer_int_ms = int_ms;
slf.timer_fin_ms = fin_ms;
}

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
if slf.timer_int_ms == 0 || slf.timer_fin_ms == 0 {
return 0;
}
let passed = std::time::Instant::now() - slf.timer_start;
if passed.as_millis() >= slf.timer_fin_ms.into() {
2
} else if passed.as_millis() >= slf.timer_int_ms.into() {
1
} else {
0
}
}

fn data_ptr(&mut self) -> *mut mbedtls_sys::types::raw_types::c_void {
self as *mut _ as *mut _
}
}

define!(
#[c_ty(ssl_context)]
Expand All @@ -89,11 +204,13 @@ pub struct Context<T> {
// Base structure used in SNI callback where we cannot determine the io type.
inner: HandshakeContext,

// config is used read-only for mutliple contexts and is immutable once configured.
// config is used read-only for multiple contexts and is immutable once configured.
config: Arc<Config>,

// Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated.
io: Option<Box<T>>,

timer_callback: Option<Box<dyn TimerCallback>>,
}

impl<'a, T> Into<*const ssl_context> for &'a Context<T> {
Expand Down Expand Up @@ -128,6 +245,7 @@ impl<T> Context<T> {
},
config: config.clone(),
io: None,
timer_callback: None,
}
}

Expand Down Expand Up @@ -157,7 +275,7 @@ impl<T: IoCallback> Context<T> {
);

self.io = Some(io);
self.inner.reset_handshake();
self.inner.reset_handshake();
}

match self.handshake() {
Expand Down Expand Up @@ -298,6 +416,13 @@ impl<T> Context<T> {
}
}
}

pub fn set_timer_callback<F: TimerCallback + 'static>(&mut self, mut cb: Box<F>) {
unsafe {
ssl_set_timer_cb(self.into(), cb.data_ptr(), Some(F::set_timer), Some(F::get_timer));
}
self.timer_callback = Some(cb);
}
}

impl<T> Drop for Context<T> {
Expand Down