diff --git a/pkg/remote/trans/default_server_handler.go b/pkg/remote/trans/default_server_handler.go index 6643c035ec..2502e8902f 100644 --- a/pkg/remote/trans/default_server_handler.go +++ b/pkg/remote/trans/default_server_handler.go @@ -171,7 +171,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) recvMsg.SetPayloadCodec(t.opt.PayloadCodec) ctx, err = t.transPipe.Read(ctx, conn, recvMsg) if err != nil { - t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true) + t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true, true) // t.OnError(ctx, err, conn) will be executed at outer function when transServer close the conn return err } @@ -187,7 +187,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) var methodInfo serviceinfo.MethodInfo if methodInfo, err = GetMethodInfo(ri, svcInfo); err != nil { // it won't be err, because the method has been checked in decode, err check here just do defensive inspection - t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true) + t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, true, true) // for proxy case, need read actual remoteAddr, error print must exec after writeErrorReplyIfNeeded, // t.OnError(ctx, err, conn) will be executed at outer function when transServer close the conn return err @@ -203,7 +203,7 @@ func (t *svrTransHandler) OnRead(ctx context.Context, conn net.Conn) (err error) // error cannot be wrapped to print here, so it must exec before NewTransError t.OnError(ctx, err, conn) err = remote.NewTransError(remote.InternalError, err) - if closeConn := t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, false); closeConn { + if closeConn := t.writeErrorReplyIfNeeded(ctx, recvMsg, conn, err, ri, false, false); closeConn { return err } // connection don't need to be closed when the error is return by the server handler @@ -272,7 +272,7 @@ func (t *svrTransHandler) SetPipeline(p *remote.TransPipeline) { } func (t *svrTransHandler) writeErrorReplyIfNeeded( - ctx context.Context, recvMsg remote.Message, conn net.Conn, err error, ri rpcinfo.RPCInfo, doOnMessage bool, + ctx context.Context, recvMsg remote.Message, conn net.Conn, err error, ri rpcinfo.RPCInfo, doOnMessage bool, connReset bool, ) (shouldCloseConn bool) { if cn, ok := conn.(remote.IsActive); ok && !cn.IsActive() { // conn is closed, no need reply @@ -297,6 +297,13 @@ func (t *svrTransHandler) writeErrorReplyIfNeeded( // if error happen before normal OnMessage, exec it to transfer header trans info into rpcinfo t.transPipe.OnMessage(ctx, recvMsg, errMsg) } + if connReset { + // if connection needs to be closed, set ConnResetTag to response header + // to ensure the client won't reuse the connection. + if ei := rpcinfo.AsTaggable(ri.To()); ei != nil { + ei.SetTag(rpcinfo.ConnResetTag, "1") + } + } ctx, err = t.transPipe.Write(ctx, conn, errMsg) if err != nil { klog.CtxErrorf(ctx, "KITEX: write error reply failed, remote=%s, error=%s", conn.RemoteAddr(), err.Error()) diff --git a/pkg/transmeta/ttheader.go b/pkg/transmeta/ttheader.go index 3b18537a58..b37519c1de 100644 --- a/pkg/transmeta/ttheader.go +++ b/pkg/transmeta/ttheader.go @@ -122,6 +122,11 @@ func (ch *clientTTHeaderHandler) ReadMeta(ctx context.Context, msg remote.Messag if setter, ok := ri.Invocation().(rpcinfo.InvocationSetter); ok && bizErr != nil { setter.SetBizStatusErr(bizErr) } + if val, ok := strInfo[transmeta.HeaderConnectionReadyToReset]; ok { + if ei := rpcinfo.AsTaggable(ri.To()); ei != nil { + ei.SetTag(rpcinfo.ConnResetTag, val) + } + } return ctx, nil } @@ -190,6 +195,9 @@ func (sh *serverTTHeaderHandler) WriteMeta(ctx context.Context, msg remote.Messa strInfo[bizExtra], _ = utils.Map2JSONStr(bizErr.BizExtra()) } } + if val, ok := ri.To().Tag(rpcinfo.ConnResetTag); ok { + strInfo[transmeta.HeaderConnectionReadyToReset] = val + } return ctx, nil }