From 4ee260850257735da52ba64ff990ccdb3c3b229b Mon Sep 17 00:00:00 2001
From: Pure White <wudi.daniel@bytedance.com>
Date: Wed, 7 Aug 2024 14:33:39 +0800
Subject: [PATCH] feat(volo-thrift): close connection when encounter error

---
 volo-thrift/src/codec/default/ttheader.rs     |  9 ++++-----
 volo-thrift/src/context.rs                    | 14 ++++++++------
 volo-thrift/src/transport/multiplex/server.rs | 17 ++++++++++-------
 volo-thrift/src/transport/pingpong/server.rs  | 14 ++++++++------
 4 files changed, 30 insertions(+), 24 deletions(-)

diff --git a/volo-thrift/src/codec/default/ttheader.rs b/volo-thrift/src/codec/default/ttheader.rs
index 5b882362..1f6e7161 100644
--- a/volo-thrift/src/codec/default/ttheader.rs
+++ b/volo-thrift/src/codec/default/ttheader.rs
@@ -324,7 +324,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
             }
             Role::Server => {
                 metainfo.get_all_backward_transients().is_some()
-                    || cx.encode_conn_reset().unwrap_or(false)
+                    || cx.encode_conn_reset()
                     || cx.stats().biz_error().is_some()
             }
         };
@@ -375,7 +375,7 @@ pub(crate) fn encode<Cx: ThriftContext>(
                             string_kv_len += 1;
                         }
                     }
-                    if cx.encode_conn_reset().unwrap_or(false) {
+                    if cx.encode_conn_reset() {
                         dst.put_u16(5);
                         dst.put_slice("crrst".as_bytes());
                         dst.put_u16(1);
@@ -582,8 +582,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
                 metainfo.get_all_persistents().is_some() || metainfo.get_all_transients().is_some()
             }
             Role::Server => {
-                metainfo.get_all_backward_transients().is_some()
-                    || thrift_cx.encode_conn_reset().unwrap_or(false)
+                metainfo.get_all_backward_transients().is_some() || thrift_cx.encode_conn_reset()
             }
         };
 
@@ -624,7 +623,7 @@ pub(crate) fn encode_size<Cx: ThriftContext>(cx: &mut Cx) -> Result<usize, Thrif
                             len += value.as_bytes().len();
                         }
                     }
-                    if thrift_cx.encode_conn_reset().unwrap_or(false) {
+                    if thrift_cx.encode_conn_reset() {
                         len += 2;
                         len += "crrst".as_bytes().len();
                         len += 2;
diff --git a/volo-thrift/src/context.rs b/volo-thrift/src/context.rs
index efc260e3..2786ef53 100644
--- a/volo-thrift/src/context.rs
+++ b/volo-thrift/src/context.rs
@@ -292,7 +292,7 @@ impl std::ops::DerefMut for ServerContext {
 }
 
 pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'static {
-    fn encode_conn_reset(&self) -> Option<bool>;
+    fn encode_conn_reset(&self) -> bool;
     fn set_conn_reset_by_ttheader(&mut self, reset: bool);
     fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier);
     fn seq_id(&self) -> i32;
@@ -307,8 +307,8 @@ pub trait ThriftContext: volo::context::Context<Config = Config> + Send + 'stati
 
 impl ThriftContext for ClientContext {
     #[inline]
-    fn encode_conn_reset(&self) -> Option<bool> {
-        None
+    fn encode_conn_reset(&self) -> bool {
+        false
     }
 
     #[inline]
@@ -342,12 +342,14 @@ impl ThriftContext for ClientContext {
 
 impl ThriftContext for ServerContext {
     #[inline]
-    fn encode_conn_reset(&self) -> Option<bool> {
-        Some(self.transport.is_conn_reset())
+    fn encode_conn_reset(&self) -> bool {
+        self.transport.is_conn_reset()
     }
 
     #[inline]
-    fn set_conn_reset_by_ttheader(&mut self, _reset: bool) {}
+    fn set_conn_reset_by_ttheader(&mut self, reset: bool) {
+        self.transport.set_conn_reset(reset)
+    }
 
     #[inline]
     fn handle_decoded_msg_ident(&mut self, ident: &TMessageIdentifier) {
diff --git a/volo-thrift/src/transport/multiplex/server.rs b/volo-thrift/src/transport/multiplex/server.rs
index 71b6f601..bd6b96e0 100644
--- a/volo-thrift/src/transport/multiplex/server.rs
+++ b/volo-thrift/src/transport/multiplex/server.rs
@@ -12,7 +12,7 @@ use volo::{context::Context, net::Address, volo_unreachable};
 
 use crate::{
     codec::{Decoder, Encoder},
-    context::ServerContext,
+    context::{ServerContext, ThriftContext as _},
     protocol::TMessageType,
     server_error_to_application_exception, thrift_exception_to_application_exception, DummyMessage,
     EntryMessage, ServerError, ThriftMessage,
@@ -40,7 +40,8 @@ pub async fn serve<Svc, Req, Resp, E, D>(
 
     // mpsc channel used to send responses to the loop
     let (send_tx, mut send_rx) = mpsc::channel(CHANNEL_SIZE);
-    let (error_send_tx, mut error_send_rx) = mpsc::channel(1);
+    let (error_send_tx, mut error_send_rx) =
+        mpsc::channel::<(ServerContext, ThriftMessage<DummyMessage>)>(1);
 
     tokio::spawn({
         let peer_addr = peer_addr.clone();
@@ -70,6 +71,9 @@ pub async fn serve<Svc, Req, Resp, E, D>(
                                             return;
                                         }
                                         stat_tracer.iter().for_each(|f| f(&cx));
+                                        if cx.encode_conn_reset() {
+                                            return;
+                                        }
                                     }
                                     None => {
                                         // log it
@@ -85,6 +89,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
                             error_msg = error_send_rx.recv() => {
                                 match error_msg {
                                     Some((mut cx, msg)) => {
+                                        cx.set_conn_reset_by_ttheader(true);
                                         if let Err(e) = encoder
                                             .encode::<DummyMessage, ServerContext>(&mut cx, msg)
                                             .await
@@ -185,11 +190,11 @@ pub async fn serve<Svc, Req, Resp, E, D>(
                             metainfo::METAINFO
                                 .scope(RefCell::new(mi), async move {
                                     cx.stats.record_process_start_at();
-                                    let resp = svc.call(&mut cx, req).await;
+                                    let resp = svc.call(&mut cx, req).await.map_err(Into::into);
                                     cx.stats.record_process_end_at();
 
                                     if exit_mark.load(Ordering::Relaxed) {
-                                        cx.transport.set_conn_reset(true);
+                                        cx.set_conn_reset_by_ttheader(true);
                                     }
                                     let req_msg_type =
                                         cx.req_msg_type.expect("`req_msg_type` should be set.");
@@ -200,9 +205,7 @@ pub async fn serve<Svc, Req, Resp, E, D>(
                                         });
                                         let msg = ThriftMessage::mk_server_resp(
                                             &cx,
-                                            resp.map_err(|e| {
-                                                server_error_to_application_exception(e.into())
-                                            }),
+                                            resp.map_err(server_error_to_application_exception),
                                         );
                                         let mi = metainfo::METAINFO.with(|m| m.take());
                                         let _ = send_tx.send((mi, cx, msg)).await;
diff --git a/volo-thrift/src/transport/pingpong/server.rs b/volo-thrift/src/transport/pingpong/server.rs
index 886bfac7..d7797449 100644
--- a/volo-thrift/src/transport/pingpong/server.rs
+++ b/volo-thrift/src/transport/pingpong/server.rs
@@ -12,7 +12,7 @@ use volo::{net::Address, volo_unreachable};
 
 use crate::{
     codec::{Decoder, Encoder},
-    context::{ServerContext, SERVER_CONTEXT_CACHE},
+    context::{ServerContext, ThriftContext, SERVER_CONTEXT_CACHE},
     protocol::TMessageType,
     server_error_to_application_exception, thrift_exception_to_application_exception,
     tracing::SpanProvider,
@@ -81,11 +81,11 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
                     match msg {
                         Ok(Some(ThriftMessage { data: Ok(req), .. })) => {
                             cx.stats.record_process_start_at();
-                            let resp = service.call(&mut cx, req).await;
+                            let resp = service.call(&mut cx, req).await.map_err(Into::into);
                             cx.stats.record_process_end_at();
 
                             if exit_mark.load(Ordering::Relaxed) {
-                                cx.transport.set_conn_reset(true);
+                                cx.set_conn_reset_by_ttheader(true);
                             }
 
                             let req_msg_type =
@@ -98,9 +98,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
                                 });
                                 let msg = ThriftMessage::mk_server_resp(
                                     &cx,
-                                    resp.map_err(|e| {
-                                        server_error_to_application_exception(e.into())
-                                    }),
+                                    resp.map_err(server_error_to_application_exception),
                                 );
                                 if let Err(e) = async {
                                     let result = encoder.encode(&mut cx, msg).await;
@@ -119,6 +117,9 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
                                     return Err(());
                                 }
                             }
+                            if cx.transport.is_conn_reset() {
+                                return Err(());
+                            }
                         }
                         Ok(Some(ThriftMessage { data: Err(_), .. })) => {
                             volo_unreachable!();
@@ -138,6 +139,7 @@ pub async fn serve<Svc, Req, Resp, E, D, SP>(
                                 e, cx, peer_addr
                             );
                             cx.msg_type = Some(TMessageType::Exception);
+                            cx.set_conn_reset_by_ttheader(true);
                             if !matches!(e, ThriftException::Transport(_)) {
                                 let msg = ThriftMessage::mk_server_resp(
                                     &cx,