Skip to content

Commit 1f1ab0d

Browse files
committed
feat(rust): add possibility to change socket buffer size
1 parent 6779d86 commit 1f1ab0d

File tree

8 files changed

+126
-23
lines changed

8 files changed

+126
-23
lines changed

implementations/rust/ockam/ockam_command/src/environment/static/env_info.txt

+2
Original file line numberDiff line numberDiff line change
@@ -54,10 +54,12 @@ UDP
5454
TCP
5555
- OCKAM_PRIVILEGED: if variable is set, all TCP Inlets/Outlets will use eBPF (overrides `--privileged` argument for `ockam tcp-inlet create` and `ockam tcp-outlet create`). WARNING: This flag value should be equal on both ends of a portal (inlet and outlet)
5656
- OCKAM_TCP_PORTAL_PAYLOAD_LENGTH: size of the buffer into which TCP Portal reads the TCP stream. Default value: `128 * 1024`
57+
- OCKAM_TCP_PORTAL_SOCKET_LENGTH: // FIXME
5758
- OCKAM_TCP_PORTAL_SKIP_HANDSHAKE: skip Portal handshake for lower latency, but also lower throughput. WARNING: This flag value should be equal on both ends of a portal (inlet and outlet)
5859
- OCKAM_TCP_PORTAL_ENABLE_NAGLE: enable Nagle's algorithm for Portal TCP streams for potentially higher throughput, but higher latency
5960
- OCKAM_TCP_PORTAL_ENABLE_MPTCP: // FIXME
6061
- OCKAM_ENABLE_MPTCP: // FIXME
62+
- OCKAM_TCP_SOCKET_LENGTH: // FIXME
6163

6264
Devs Usage
6365
- OCKAM: a `string` that defines the path to the ockam binary to use.

implementations/rust/ockam/ockam_transport_core/src/error.rs

+4
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ pub enum TransportError {
8282
InvalidOckamPortalPacket(String),
8383
/// Connection timeout
8484
ConnectionTimeout,
85+
/// Invalid socket options
86+
SockOpt(String),
8587
}
8688

8789
impl ockam_core::compat::error::Error for TransportError {}
@@ -133,6 +135,7 @@ impl core::fmt::Display for TransportError {
133135
),
134136
Self::InvalidOckamPortalPacket(e) => write!(f, "invalid OckamPortalPacket: {}", e),
135137
Self::ConnectionTimeout => write!(f, "connection timed out"),
138+
Self::SockOpt(e) => write!(f, "socket options failed: {}", e),
136139
}
137140
}
138141
}
@@ -175,6 +178,7 @@ impl From<TransportError> for Error {
175178
IdentifierChanged => Kind::Conflict,
176179
InvalidOckamPortalPacket(_) => Kind::Invalid,
177180
ConnectionTimeout => Kind::Io,
181+
SockOpt(_) => Kind::Io,
178182
};
179183

180184
Error::new(Origin::Transport, kind, err)

implementations/rust/ockam/ockam_transport_tcp/src/options.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::workers::Addresses;
22
use ockam_core::compat::sync::Arc;
33
use ockam_core::compat::time::Duration;
4+
use ockam_core::env::get_env;
45
use ockam_core::flow_control::{FlowControlId, FlowControlOutgoingAccessControl, FlowControls};
56
use ockam_core::{Address, OutgoingAccessControl};
67

@@ -11,17 +12,20 @@ pub struct TcpConnectionOptions {
1112
pub(super) consumer: Vec<FlowControlId>,
1213
pub(crate) flow_control_id: FlowControlId,
1314
pub(crate) enable_mptcp: bool,
15+
pub(crate) buffer_size: Option<usize>,
1416
}
1517

1618
impl TcpConnectionOptions {
1719
#[allow(clippy::new_without_default)]
1820
/// Mark this Tcp Receiver as a Producer with a random [`FlowControlId`]
1921
pub fn new() -> Self {
22+
let buffer_size = get_env("OCKAM_TCP_SOCKET_LENGTH").ok().flatten();
2023
Self {
2124
timeout: None,
2225
consumer: vec![],
2326
flow_control_id: FlowControls::generate_flow_control_id(),
2427
enable_mptcp: false,
28+
buffer_size,
2529
}
2630
}
2731

@@ -65,6 +69,12 @@ impl TcpConnectionOptions {
6569
self.enable_mptcp = true;
6670
self
6771
}
72+
73+
/// Set socket buffer size
74+
pub fn set_buffer_size(mut self, buffer_size: Option<usize>) -> Self {
75+
self.buffer_size = buffer_size;
76+
self
77+
}
6878
}
6979

7080
impl TcpConnectionOptions {
@@ -98,6 +108,7 @@ impl TcpConnectionOptions {
98108
pub struct TcpListenerOptions {
99109
pub(crate) flow_control_id: FlowControlId,
100110
pub(crate) enable_mptcp: bool,
111+
pub(crate) buffer_size: Option<usize>,
101112
}
102113

103114
impl TcpListenerOptions {
@@ -106,9 +117,11 @@ impl TcpListenerOptions {
106117
/// with Spawner's [`FlowControlId`]
107118
#[allow(clippy::new_without_default)]
108119
pub fn new() -> Self {
120+
let buffer_size = get_env("OCKAM_TCP_SOCKET_LENGTH").ok().flatten();
109121
Self {
110122
flow_control_id: FlowControls::generate_flow_control_id(),
111123
enable_mptcp: false,
124+
buffer_size,
112125
}
113126
}
114127

@@ -128,6 +141,12 @@ impl TcpListenerOptions {
128141
self.enable_mptcp = true;
129142
self
130143
}
144+
145+
/// Set socket buffer size
146+
pub fn set_buffer_size(mut self, buffer_size: Option<usize>) -> Self {
147+
self.buffer_size = buffer_size;
148+
self
149+
}
131150
}
132151

133152
impl TcpListenerOptions {

implementations/rust/ockam/ockam_transport_tcp/src/portal/inlet_listener.rs

+8
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
11
use crate::portal::addresses::{Addresses, PortalType};
22
use crate::portal::tls_certificate::TlsCertificateProvider;
33
use crate::portal::{InletSharedState, ReadHalfMaybeTls, WriteHalfMaybeTls};
4+
use crate::transport::set_socket_buffer_size;
45
use crate::{portal::TcpPortalWorker, TcpInlet, TcpInletOptions, TcpRegistry};
56
use log::warn;
67
use ockam_core::compat::net::SocketAddr;
78
use ockam_core::compat::sync::{Arc, RwLock as SyncRwLock};
9+
use ockam_core::env::get_env;
810
use ockam_core::errcode::{Kind, Origin};
911
use ockam_core::{async_trait, compat::boxed::Box, Result};
1012
use ockam_core::{Address, Processor, Route};
1113
use ockam_node::Context;
1214
use ockam_transport_core::{HostnamePort, TransportError};
1315
use rustls::pki_types::CertificateDer;
1416
use std::io::BufReader;
17+
use std::os::fd::AsRawFd;
1518
use std::time::Duration;
1619
use tokio::net::TcpListener;
1720
use tokio::time::Instant;
@@ -64,6 +67,11 @@ impl TcpInletListenProcessor {
6467
return Err(TransportError::from(err))?;
6568
}
6669
};
70+
71+
if let Ok(Some(buffer_size)) = get_env::<usize>("OCKAM_TCP_PORTAL_SOCKET_LENGTH") {
72+
set_socket_buffer_size(inner.as_raw_fd(), buffer_size)?;
73+
}
74+
6775
let socket_addr = inner.local_addr().map_err(TransportError::from)?;
6876
let inlet_shared_state =
6977
InletSharedState::create(ctx, outlet_listener_route, options.is_paused)?;

implementations/rust/ockam/ockam_transport_tcp/src/portal/portal_worker.rs

+12-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use crate::portal::portal_worker::WriteHalfMaybeTls::{WriteHalfNoTls, WriteHalfW
55
use crate::transport::{connect_tcp, connect_tls};
66
use crate::{portal::TcpPortalRecvProcessor, PortalInternalMessage, PortalMessage, TcpRegistry};
77
use ockam_core::compat::{boxed::Box, sync::Arc};
8+
use ockam_core::env::get_env;
89
use ockam_core::{
910
async_trait, AllowAll, AllowOnwardAddress, AllowSourceAddress, Decodable, DenyAll,
1011
IncomingAccessControl, LocalInfoIdentifier, Mailbox, Mailboxes, OutgoingAccessControl,
@@ -469,10 +470,18 @@ impl TcpPortalWorker {
469470
}
470471

471472
async fn connect(&mut self) -> Result<()> {
473+
let buffer_size = get_env::<usize>("OCKAM_TCP_PORTAL_SOCKET_LENGTH")
474+
.ok()
475+
.flatten();
472476
if self.is_tls {
473477
debug!(portal_type = %self.portal_type, sender_internal = %self.addresses.sender_internal, "connect to {} via TLS", &self.hostname_port);
474-
let (rx, tx) =
475-
connect_tls(&self.hostname_port, self.enable_mptcp, self.enable_nagle).await?;
478+
let (rx, tx) = connect_tls(
479+
&self.hostname_port,
480+
self.enable_mptcp,
481+
self.enable_nagle,
482+
buffer_size,
483+
)
484+
.await?;
476485
self.write_half = Some(WriteHalfWithTls(tx));
477486
self.read_half = Some(ReadHalfWithTls(rx));
478487
} else {
@@ -482,6 +491,7 @@ impl TcpPortalWorker {
482491
self.enable_mptcp,
483492
self.enable_nagle,
484493
None,
494+
buffer_size,
485495
)
486496
.await?
487497
.into_split();

implementations/rust/ockam/ockam_transport_tcp/src/transport/common.rs

+79-19
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use ockam_core::{Error, Result};
55
use ockam_transport_core::{HostnamePort, TransportError};
66
use socket2::{SockRef, TcpKeepalive};
77
use std::net::SocketAddr;
8+
use std::os::fd::{AsRawFd, RawFd};
89
use std::sync::Arc;
910
use std::time::Duration;
1011
use tokio::io::{ReadHalf, WriteHalf};
@@ -14,47 +15,104 @@ use tokio_rustls::rustls::{ClientConfig, RootCertStore};
1415
use tokio_rustls::{TlsConnector, TlsStream};
1516
use tracing::{debug, instrument};
1617

17-
pub(crate) async fn bind_tcp_listener(at: SocketAddr, enable_mptcp: bool) -> Result<TcpListener> {
18-
if !enable_mptcp {
19-
return Ok(TcpListener::bind(&at).await.map_err(TransportError::from)?);
18+
pub fn set_socket_buffer_size(socket: RawFd, buffer_size: usize) -> Result<()> {
19+
let buffer_size = buffer_size as nix::libc::c_int;
20+
21+
let res = unsafe {
22+
#[allow(trivial_casts)]
23+
nix::libc::setsockopt(
24+
socket.as_raw_fd(),
25+
nix::libc::SOL_SOCKET,
26+
nix::libc::SO_RCVBUF,
27+
(&buffer_size as *const nix::libc::c_int) as *const nix::libc::c_void,
28+
size_of::<nix::libc::c_int>() as nix::libc::socklen_t,
29+
)
30+
};
31+
32+
assert!(res >= 0,);
33+
34+
if res < 0 {
35+
return Err(TransportError::SockOpt("TCP portal receive buffer error".to_string()).into());
2036
}
2137

22-
let socket = bind_mptcp(at).await.map_err(TransportError::from)?;
38+
let res = unsafe {
39+
#[allow(trivial_casts)]
40+
nix::libc::setsockopt(
41+
socket.as_raw_fd(),
42+
nix::libc::SOL_SOCKET,
43+
nix::libc::SO_SNDBUF,
44+
(&buffer_size as *const nix::libc::c_int) as *const nix::libc::c_void,
45+
size_of::<nix::libc::c_int>() as nix::libc::socklen_t,
46+
)
47+
};
48+
49+
if res < 0 {
50+
return Err(TransportError::SockOpt("TCP portal send buffer error".to_string()).into());
51+
}
2352

24-
Ok(socket)
53+
Ok(())
2554
}
2655

27-
async fn create_tcp_stream(to: &HostnamePort, enable_mptcp: bool) -> Result<TcpStream> {
28-
if !enable_mptcp {
29-
return Ok(TcpStream::connect(to.to_string())
30-
.await
31-
.map_err(TransportError::from)?);
56+
pub(crate) async fn bind_tcp_listener(
57+
at: SocketAddr,
58+
enable_mptcp: bool,
59+
buffer_size: Option<usize>,
60+
) -> Result<TcpListener> {
61+
let listener = if !enable_mptcp {
62+
TcpListener::bind(&at).await.map_err(TransportError::from)?
63+
} else {
64+
bind_mptcp(at).await.map_err(TransportError::from)?
65+
};
66+
67+
if let Some(buffer_size) = buffer_size {
68+
set_socket_buffer_size(listener.as_raw_fd(), buffer_size)?;
3269
}
3370

34-
// TODO: Add timeout
35-
let socket = connect_mptcp(to.to_string())
36-
.await
37-
.map_err(TransportError::from)?;
71+
Ok(listener)
72+
}
73+
74+
async fn create_tcp_stream(
75+
to: &HostnamePort,
76+
enable_mptcp: bool,
77+
buffer_size: Option<usize>,
78+
) -> Result<TcpStream> {
79+
let stream = if !enable_mptcp {
80+
TcpStream::connect(to.to_string())
81+
.await
82+
.map_err(TransportError::from)?
83+
} else {
84+
// TODO: Add timeout
85+
connect_mptcp(to.to_string())
86+
.await
87+
.map_err(TransportError::from)?
88+
};
89+
90+
if let Some(buffer_size) = buffer_size {
91+
set_socket_buffer_size(stream.as_raw_fd(), buffer_size)?;
92+
}
3893

39-
Ok(socket)
94+
Ok(stream)
4095
}
4196

4297
async fn create_tcp_stream_timeout(
4398
to: &HostnamePort,
4499
enable_mptcp: bool,
45100
timeout: Option<Duration>,
101+
buffer_size: Option<usize>,
46102
) -> Result<TcpStream> {
47103
match timeout {
48104
Some(timeout) => {
49-
match tokio::time::timeout(timeout, create_tcp_stream(to, enable_mptcp)).await {
105+
match tokio::time::timeout(timeout, create_tcp_stream(to, enable_mptcp, buffer_size))
106+
.await
107+
{
50108
Ok(result) => result,
51109
Err(_) => {
52110
debug!(addr = %to, timeout = %timeout.as_secs(), "Timeout");
53111
Err(TransportError::ConnectionTimeout)?
54112
}
55113
}
56114
}
57-
None => create_tcp_stream(to, enable_mptcp).await,
115+
None => create_tcp_stream(to, enable_mptcp, buffer_size).await,
58116
}
59117
}
60118

@@ -64,10 +122,11 @@ pub(crate) async fn connect_tcp(
64122
enable_mptcp: bool,
65123
enable_nagle: bool,
66124
timeout: Option<Duration>,
125+
buffer_size: Option<usize>,
67126
) -> Result<TcpStream> {
68127
debug!(addr = %to, "Connecting");
69128

70-
let result = create_tcp_stream_timeout(to, enable_mptcp, timeout).await;
129+
let result = create_tcp_stream_timeout(to, enable_mptcp, timeout, buffer_size).await;
71130

72131
let connection = match result {
73132
Ok(c) => {
@@ -109,14 +168,15 @@ pub(crate) async fn connect_tls(
109168
to: &HostnamePort,
110169
enable_mptcp: bool,
111170
enable_nagle: bool,
171+
buffer_size: Option<usize>,
112172
) -> Result<(
113173
ReadHalf<TlsStream<TcpStream>>,
114174
WriteHalf<TlsStream<TcpStream>>,
115175
)> {
116176
debug!(to = %to, "Trying to connect using TLS");
117177

118178
// create a tcp stream
119-
let connection = connect_tcp(to, enable_mptcp, enable_nagle, None).await?;
179+
let connection = connect_tcp(to, enable_mptcp, enable_nagle, None, buffer_size).await?;
120180

121181
// create a TLS connector
122182
let tls_connector = create_tls_connector().await?;

implementations/rust/ockam/ockam_transport_tcp/src/transport/connection.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ impl TcpTransport {
112112
debug!("Connecting to {}", peer.clone());
113113

114114
let (read_half, write_half) =
115-
connect_tcp(&peer, options.enable_mptcp, false, options.timeout)
115+
connect_tcp(&peer, options.enable_mptcp, false, options.timeout, None)
116116
.await?
117117
.into_split();
118118
let socket = read_half

implementations/rust/ockam/ockam_transport_tcp/src/workers/listener.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ impl TcpListenProcessor {
2929
options: TcpListenerOptions,
3030
) -> Result<(SocketAddr, Address)> {
3131
debug!("Binding TcpListener to {}", addr);
32-
let inner = bind_tcp_listener(addr, options.enable_mptcp).await?;
32+
let inner = bind_tcp_listener(addr, options.enable_mptcp, None).await?;
3333
let saddr = inner.local_addr().map_err(TransportError::from)?;
3434

3535
let address = Address::random_tagged("TcpListenProcessor");

0 commit comments

Comments
 (0)