Skip to content

Commit b02a4e0

Browse files
authored
fix: tls server recv function shouldn't contain await (#96)
1 parent a8a7d06 commit b02a4e0

File tree

3 files changed

+44
-19
lines changed

3 files changed

+44
-19
lines changed

bin/relayer/src/agent/quic.rs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ pub struct AgentQuicListener<VALIDATE, HANDSHAKE: ClusterRequest> {
3333

3434
impl<VALIDATE, HANDSHAKE: ClusterRequest> AgentQuicListener<VALIDATE, HANDSHAKE> {
3535
pub async fn new(addr: SocketAddr, priv_key: PrivatePkcs8KeyDer<'static>, cert: CertificateDer<'static>, validate: VALIDATE) -> anyhow::Result<Self> {
36+
log::info!("[AgentQuic] starting with addr {addr}");
3637
let endpoint = make_server_endpoint(addr, priv_key, cert)?;
3738
let (internal_tx, internal_rx) = channel(10);
3839

@@ -51,7 +52,15 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
5152
loop {
5253
select! {
5354
incoming = self.endpoint.accept() => {
54-
tokio::spawn(run_connection(self.validate.clone(), incoming.ok_or(anyhow!("quinn crash"))?, self.internal_tx.clone()));
55+
let validate = self.validate.clone();
56+
let internal_tx = self.internal_tx.clone();
57+
let incoming = incoming.ok_or(anyhow!("quinn crash"))?;
58+
let remote = incoming.remote_address();
59+
tokio::spawn(async move {
60+
if let Err(e) = run_connection(validate, incoming, internal_tx).await {
61+
log::error!("[AgentQuic] connection {remote} error {e:?}");
62+
}
63+
});
5564
},
5665
event = self.internal_rx.recv() => break Ok(event.expect("should work")),
5766
}

bin/relayer/src/agent/tcp.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ pub struct AgentTcpListener<VALIDATE, HANDSHAKE: ClusterRequest> {
3131

3232
impl<VALIDATE, HANDSHAKE: ClusterRequest> AgentTcpListener<VALIDATE, HANDSHAKE> {
3333
pub async fn new(addr: SocketAddr, validate: VALIDATE) -> anyhow::Result<Self> {
34+
log::info!("[AgentTcp] starting with addr {addr}");
3435
let (internal_tx, internal_rx) = channel(10);
3536

3637
Ok(Self {
@@ -49,7 +50,13 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
4950
select! {
5051
incoming = self.listener.accept() => {
5152
let (stream, remote) = incoming?;
52-
tokio::spawn(run_connection(self.validate.clone(), stream, remote, self.internal_tx.clone()));
53+
let validate = self.validate.clone();
54+
let internal_tx = self.internal_tx.clone();
55+
tokio::spawn(async move {
56+
if let Err(e) = run_connection(validate, stream, remote, internal_tx).await {
57+
log::error!("[AgentTcp] connection {remote} error {e:?}");
58+
}
59+
});
5360
},
5461
event = self.internal_rx.recv() => break Ok(event.expect("should receive event from internal channel")),
5562
}

bin/relayer/src/agent/tls.rs

Lines changed: 26 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ use tokio::{
1313
select,
1414
sync::mpsc::{channel, Receiver, Sender},
1515
};
16-
use tokio_rustls::server::TlsStream;
1716
use tokio_rustls::TlsAcceptor;
1817
use tokio_yamux::{Session, StreamHandle};
1918

@@ -24,7 +23,7 @@ use super::{AgentListener, AgentListenerEvent};
2423

2524
pub type TunnelTlsStream = StreamHandle;
2625
pub struct AgentTlsListener<VALIDATE, HANDSHAKE: ClusterRequest> {
27-
tls_acceptor: TlsAcceptor,
26+
tls_acceptor: Arc<TlsAcceptor>,
2827
validate: Arc<VALIDATE>,
2928
listener: TcpListener,
3029
internal_tx: Sender<AgentListenerEvent<HANDSHAKE::Context, TunnelTlsStream>>,
@@ -34,12 +33,13 @@ pub struct AgentTlsListener<VALIDATE, HANDSHAKE: ClusterRequest> {
3433

3534
impl<VALIDATE, HANDSHAKE: ClusterRequest> AgentTlsListener<VALIDATE, HANDSHAKE> {
3635
pub async fn new(addr: SocketAddr, validate: VALIDATE, key: PrivatePkcs8KeyDer<'static>, cert: CertificateDer<'static>) -> anyhow::Result<Self> {
36+
log::info!("[AgentTls] starting with addr {addr}");
3737
let (internal_tx, internal_rx) = channel(10);
3838
let config = rustls::ServerConfig::builder().with_no_client_auth().with_single_cert(vec![cert], PrivateKeyDer::Pkcs8(key))?;
3939
let tls_acceptor = TlsAcceptor::from(Arc::new(config));
4040

4141
Ok(Self {
42-
tls_acceptor,
42+
tls_acceptor: Arc::new(tls_acceptor),
4343
listener: TcpListener::bind(addr).await?,
4444
internal_tx,
4545
internal_rx,
@@ -57,8 +57,14 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
5757
event = self.internal_rx.recv() => break Ok(event.expect("should receive event from internal channel")),
5858
};
5959

60-
let tls_stream = self.tls_acceptor.accept(stream).await?;
61-
tokio::spawn(run_connection(self.validate.clone(), tls_stream, remote, self.internal_tx.clone()));
60+
let tls_acceptor = self.tls_acceptor.clone();
61+
let validate = self.validate.clone();
62+
let internal_tx = self.internal_tx.clone();
63+
tokio::spawn(async move {
64+
if let Err(e) = run_connection(validate, tls_acceptor, stream, remote, internal_tx).await {
65+
log::error!("[AgentTls] connection {remote} error {e:?}");
66+
}
67+
});
6268
}
6369
}
6470

@@ -67,25 +73,28 @@ impl<VALIDATE: ClusterValidator<REQ>, REQ: DeserializeOwned + Send + Sync + 'sta
6773

6874
async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
6975
validate: Arc<VALIDATE>,
70-
mut in_stream: TlsStream<TcpStream>,
76+
tls_acceptor: Arc<TlsAcceptor>,
77+
stream: TcpStream,
7178
remote: SocketAddr,
7279
internal_tx: Sender<AgentListenerEvent<REQ::Context, TunnelTlsStream>>,
7380
) -> anyhow::Result<()> {
7481
let started = Instant::now();
75-
log::info!("[AgentTcp] new connection from {}", remote);
82+
log::info!("[AgentTls] new connection from {remote}, handshaking tls");
83+
let mut in_stream = tls_acceptor.accept(stream).await?;
84+
log::info!("[AgentTls] new connection from {remote}, handshake tls success");
7685

7786
let mut buf = [0u8; 4096];
7887
let buf_len = in_stream.read(&mut buf).await?;
7988

80-
log::info!("[AgentTcp] new connection got handhsake data {buf_len} bytes");
89+
log::info!("[AgentTls] new connection from {remote} got handhsake data {buf_len} bytes");
8190

8291
let req = validate.validate_connect_req(&buf[0..buf_len])?;
8392
let domain = validate.generate_domain(&req)?;
8493
let agent_id = AgentId::try_from_domain(&domain)?;
8594
let session_id = AgentSessionId::rand();
8695
let agent_ctx = req.context();
8796

88-
log::info!("[AgentTcp] new connection validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}");
97+
log::info!("[AgentTls] new connection from {remote} validated with domain {domain} agent_id: {agent_id}, session uuid: {session_id}");
8998

9099
let res_buf = validate.sign_response_res(&req, None);
91100
in_stream.write_all(&res_buf).await?;
@@ -96,7 +105,7 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
96105
.await
97106
.expect("should send to main loop");
98107

99-
log::info!("[AgentTcp] new connection {agent_id} {session_id} started loop");
108+
log::info!("[AgentTls] new connection {agent_id} {session_id} started loop");
100109
let mut session = Session::new_client(in_stream, Default::default());
101110
histogram!(METRICS_AGENT_HISTOGRAM).record(started.elapsed().as_millis() as f32 / 1000.0);
102111

@@ -105,17 +114,17 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
105114
control = control_rx.recv() => match control {
106115
Some(control) => match control {
107116
AgentSessionControl::CreateStream(tx) => {
108-
log::info!("[AgentTcp] agent {agent_id} {session_id} create stream request");
117+
log::info!("[AgentTls] agent {agent_id} {session_id} create stream request");
109118
match session.open_stream() {
110119
Ok(stream) => {
111-
log::info!("[AgentTcp] agent {agent_id} {session_id} created stream");
120+
log::info!("[AgentTls] agent {agent_id} {session_id} created stream");
112121
if let Err(_e) = tx.send(Ok(stream)) {
113-
log::error!("[AgentTcp] agent {agent_id} {session_id} send created stream error");
122+
log::error!("[AgentTls] agent {agent_id} {session_id} send created stream error");
114123
}
115124
},
116125
Err(err) => {
117126
if let Err(_e) = tx.send(Err(err.into())) {
118-
log::error!("[AgentTcp] agent {agent_id} {session_id} send create stream's error, may be internal channel failed");
127+
log::error!("[AgentTls] agent {agent_id} {session_id} send create stream's error, may be internal channel failed");
119128
}
120129
},
121130
}
@@ -134,18 +143,18 @@ async fn run_connection<VALIDATE: ClusterValidator<REQ>, REQ: ClusterRequest>(
134143
});
135144
},
136145
Some(Err(err)) => {
137-
log::error!("[AgentTcp] agent {agent_id} {session_id} Tcp connection error {err:?}");
146+
log::error!("[AgentTls] agent {agent_id} {session_id} Tcp connection error {err:?}");
138147
break;
139148
},
140149
None => {
141-
log::error!("[AgentTcp] agent {agent_id} {session_id} Tcp connection broken with None");
150+
log::error!("[AgentTls] agent {agent_id} {session_id} Tcp connection broken with None");
142151
break;
143152
}
144153
}
145154
}
146155
}
147156

148-
log::info!("[AgentTcp] agent {agent_id} {session_id} stopped loop");
157+
log::info!("[AgentTls] agent {agent_id} {session_id} stopped loop");
149158

150159
internal_tx.send(AgentListenerEvent::Disconnected(agent_id, session_id)).await.expect("should send to main loop");
151160

0 commit comments

Comments
 (0)