Skip to content

Commit 6a718bd

Browse files
committed
feat(volo-thrift): close connection when encounter error
1 parent ea932cf commit 6a718bd

File tree

3 files changed

+26
-17
lines changed

3 files changed

+26
-17
lines changed

volo-thrift/src/context.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext {
292292
}
293293

294294
pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'static {
295-
fn encode_conn_reset(&self) -> Option<bool>;
295+
fn encode_conn_reset(&self) -> bool;
296296
fn set_conn_reset_by_ttheader(&mut self, reset: bool);
297297
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier);
298298
fn seq_id(&self) -> i32;
@@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'stati
307307

308308
impl ThriftContext for ClientContext {
309309
#[inline]
310-
fn encode_conn_reset(&self) -> Option<bool> {
311-
None
310+
fn encode_conn_reset(&self) -> bool {
311+
false
312312
}
313313

314314
#[inline]
@@ -342,12 +342,14 @@ impl ThriftContext for ClientContext {
342342

343343
impl ThriftContext for ServerContext {
344344
#[inline]
345-
fn encode_conn_reset(&self) -> Option<bool> {
346-
Some(self.transport.is_conn_reset())
345+
fn encode_conn_reset(&self) -> bool {
346+
self.transport.is_conn_reset()
347347
}
348348

349349
#[inline]
350-
fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {}
350+
fn set_conn_reset_by_ttheader(&mut self, reset: bool) {
351+
self.transport.set_conn_reset(reset)
352+
}
351353

352354
#[inline]
353355
fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) {

volo-thrift/src/transport/multiplex/server.rs

+10-5
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable};
1212

1313
use crate::{
1414
codec::{Decoder, Encoder},
15-
context::ServerContext,
15+
context::{ServerContext, ThriftContext as _},
1616
protocol::TMessageType,
1717
server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage,
1818
EntryMessage, ServerError, ThriftMessage,
@@ -40,7 +40,8 @@ pub async fn serve<Svc, Req, Resp, E, D>(
4040

4141
// mpsc channel used to send responses to the loop
4242
let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE);
43-
let (error_send_tx, mut error_send_rx) = mpsc::channel(1);
43+
let (error_send_tx, mut error_send_rx) =
44+
mpsc::channel::<(ServerContext, ThriftMessage<DummyMessage>)>(1);
4445

4546
tokio::spawn({
4647
let peer_addr = peer_addr.clone();
@@ -70,6 +71,9 @@ pub async fn serve<Svc, Req, Resp, E, D>(
7071
return;
7172
}
7273
stat_tracer.iter().for_each(|f| f(&cx));
74+
if cx.encode_conn_reset() {
75+
return;
76+
}
7377
}
7478
None => {
7579
// log it
@@ -85,6 +89,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
8589
error_msg = error_send_rx.recv() => {
8690
match error_msg {
8791
Some((mut cx, msg)) => {
92+
cx.set_conn_reset_by_ttheader(true);
8893
if let Err(e) = encoder
8994
.encode::<DummyMessage, ServerContext>(&mut cx, msg)
9095
.await
@@ -185,11 +190,11 @@ pub async fn serve<Svc, Req, Resp, E, D>(
185190
metainfo::METAINFO
186191
.scope(RefCell::new(mi), async move {
187192
cx.stats.record_process_start_at();
188-
let resp = svc.call(&mut cx, req).await;
193+
let resp = svc.call(&mut cx, req).await.map_err(Into::into);
189194
cx.stats.record_process_end_at();
190195

191196
if exit_mark.load(Ordering::Relaxed) {
192-
cx.transport.set_conn_reset(true);
197+
cx.set_conn_reset_by_ttheader(true);
193198
}
194199
let req_msg_type =
195200
cx.req_msg_type.expect("`req_msg_type` should be set.");
@@ -201,7 +206,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
201206
let msg = ThriftMessage::mk_server_resp(
202207
&cx,
203208
resp.map_err(|e| {
204-
server_error_to_application_exception(e.into())
209+
server_error_to_application_exception(e)
205210
}),
206211
);
207212
let mi = metainfo::METAINFO.with(|m| m.take());

volo-thrift/src/transport/pingpong/server.rs

+8-6
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable};
1212

1313
use crate::{
1414
codec::{Decoder, Encoder},
15-
context::{ServerContext, SERVER_CONTEXT_CACHE},
15+
context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE},
1616
protocol::TMessageType,
1717
server_error_to_application_exception, thrift_exception_to_application_exception,
1818
tracing::SpanProvider,
@@ -81,11 +81,11 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
8181
match msg {
8282
Ok(Some(ThriftMessage { data: Ok(req), .. })) => {
8383
cx.stats.record_process_start_at();
84-
let resp = service.call(&mut cx, req).await;
84+
let resp = service.call(&mut cx, req).await.map_err(Into::into);
8585
cx.stats.record_process_end_at();
8686

8787
if exit_mark.load(Ordering::Relaxed) {
88-
cx.transport.set_conn_reset(true);
88+
cx.set_conn_reset_by_ttheader(true);
8989
}
9090

9191
let req_msg_type =
@@ -98,9 +98,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
9898
});
9999
let msg = ThriftMessage::mk_server_resp(
100100
&cx,
101-
resp.map_err(|e| {
102-
server_error_to_application_exception(e.into())
103-
}),
101+
resp.map_err(|e| server_error_to_application_exception(e)),
104102
);
105103
if let Err(e) = async {
106104
let result = encoder.encode(&mut cx, msg).await;
@@ -119,6 +117,9 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
119117
return Err(());
120118
}
121119
}
120+
if cx.transport.is_conn_reset() {
121+
return Err(());
122+
}
122123
}
123124
Ok(Some(ThriftMessage { data: Err(_), .. })) => {
124125
volo_unreachable!();
@@ -138,6 +139,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
138139
e, cx, peer_addr
139140
);
140141
cx.msg_type = Some(TMessageType::Exception);
142+
cx.set_conn_reset_by_ttheader(true);
141143
if !matches!(e, ThriftException::Transport(_)) {
142144
let msg = ThriftMessage::mk_server_resp(
143145
&cx,

0 commit comments

Comments
 (0)