diff --git a/consumer/option.go b/consumer/option.go index 24acf7c3..2e08163d 100644 --- a/consumer/option.go +++ b/consumer/option.go @@ -382,6 +382,15 @@ func WithLimiter(limiter Limiter) Option { } } +// WithRemotingTimeout set remote client timeout options +func WithRemotingTimeout(connectionTimeout, readTimeout, writeTimeout time.Duration) Option { + return func(opts *consumerOptions) { + opts.ClientOptions.RemotingClientConfig.ConnectionTimeout = connectionTimeout + opts.ClientOptions.RemotingClientConfig.ReadTimeout = readTimeout + opts.ClientOptions.RemotingClientConfig.WriteTimeout = writeTimeout + } +} + func WithTls(useTls bool) Option { return func(opts *consumerOptions) { opts.ClientOptions.RemotingClientConfig.UseTls = useTls diff --git a/consumer/option_test.go b/consumer/option_test.go index ab99b632..4db5b93b 100644 --- a/consumer/option_test.go +++ b/consumer/option_test.go @@ -3,6 +3,7 @@ package consumer import ( "reflect" "testing" + "time" ) func getFieldString(obj interface{}, field string) string { @@ -12,6 +13,20 @@ func getFieldString(obj interface{}, field string) string { }).String() } +func TestWithRemotingTimeout(t *testing.T) { + opt := defaultPushConsumerOptions() + WithRemotingTimeout(3*time.Second, 4*time.Second, 5*time.Second)(&opt) + if timeout := opt.RemotingClientConfig.ConnectionTimeout; timeout != 3*time.Second { + t.Errorf("consumer option WithRemotingTimeout connectionTimeout. want:%s, got=%s", 3*time.Second, timeout) + } + if timeout := opt.RemotingClientConfig.ReadTimeout; timeout != 4*time.Second { + t.Errorf("consumer option WithRemotingTimeout readTimeout. want:%s, got=%s", 4*time.Second, timeout) + } + if timeout := opt.RemotingClientConfig.WriteTimeout; timeout != 5*time.Second { + t.Errorf("consumer option WithRemotingTimeout writeTimeout. want:%s, got=%s", 5*time.Second, timeout) + } +} + func TestWithUnitName(t *testing.T) { opt := defaultPushConsumerOptions() unitName := "unsh" diff --git a/internal/remote/remote_client.go b/internal/remote/remote_client.go index 45dfbbf5..36fbea7a 100644 --- a/internal/remote/remote_client.go +++ b/internal/remote/remote_client.go @@ -102,7 +102,7 @@ func (c *remotingClient) InvokeSync(ctx context.Context, addr string, request *R c.responseTable.Store(resp.Opaque, resp) defer c.responseTable.Delete(request.Opaque) - err = c.sendRequest(conn, request) + err = c.sendRequest(ctx, conn, request) if err != nil { return nil, err } @@ -120,7 +120,7 @@ func (c *remotingClient) InvokeAsync(ctx context.Context, addr string, request * resp := NewResponseFuture(ctx, request.Opaque, callback) c.responseTable.Store(resp.Opaque, resp) - err = c.sendRequest(conn, request) + err = c.sendRequest(ctx, conn, request) if err != nil { c.responseTable.Delete(request.Opaque) return err @@ -146,11 +146,11 @@ func (c *remotingClient) InvokeOneWay(ctx context.Context, addr string, request if err != nil { return err } - return c.sendRequest(conn, request) + return c.sendRequest(ctx, conn, request) } func (c *remotingClient) connect(ctx context.Context, addr string) (*tcpConnWrapper, error) { - //it needs additional locker. + // it needs additional locker. c.connectionLocker.Lock() defer c.connectionLocker.Unlock() conn, ok := c.connectionTable.Load(addr) @@ -246,7 +246,7 @@ func (c *remotingClient) processCMD(cmd *RemotingCommand, r *tcpConnWrapper) { if res != nil { res.Opaque = cmd.Opaque res.Flag |= 1 << 0 - err := c.sendRequest(r, res) + err := c.sendRequest(context.Background(), r, res) if err != nil { rlog.Warning("send response to broker error", map[string]interface{}{ rlog.LogKeyUnderlayError: err, @@ -297,23 +297,27 @@ func (c *remotingClient) createScanner(r io.Reader) *bufio.Scanner { return scanner } -func (c *remotingClient) sendRequest(conn *tcpConnWrapper, request *RemotingCommand) error { +func (c *remotingClient) sendRequest(ctx context.Context, conn *tcpConnWrapper, request *RemotingCommand) error { var err error if c.interceptor != nil { - err = c.interceptor(context.Background(), request, nil, func(ctx context.Context, req, reply interface{}) error { - return c.doRequest(conn, request) + err = c.interceptor(ctx, request, nil, func(ctx context.Context, req, reply interface{}) error { + return c.doRequest(ctx, conn, request) }) } else { - err = c.doRequest(conn, request) + err = c.doRequest(ctx, conn, request) } return err } -func (c *remotingClient) doRequest(conn *tcpConnWrapper, request *RemotingCommand) error { +func (c *remotingClient) doRequest(ctx context.Context, conn *tcpConnWrapper, request *RemotingCommand) error { conn.Lock() defer conn.Unlock() - err := conn.Conn.SetWriteDeadline(time.Now().Add(c.config.WriteTimeout)) + deadline, ok := ctx.Deadline() + if !ok { + deadline = time.Now().Add(c.config.WriteTimeout) + } + err := conn.Conn.SetWriteDeadline(deadline) if err != nil { rlog.Error("conn error, close connection", map[string]interface{}{ rlog.LogKeyUnderlayError: err, diff --git a/producer/option.go b/producer/option.go index 6e43cc25..72af3c6a 100644 --- a/producer/option.go +++ b/producer/option.go @@ -179,6 +179,15 @@ func WithCompressLevel(level int) Option { } } +// WithRemotingTimeout set remote client timeout options +func WithRemotingTimeout(connectionTimeout, readTimeout, writeTimeout time.Duration) Option { + return func(opts *producerOptions) { + opts.ClientOptions.RemotingClientConfig.ConnectionTimeout = connectionTimeout + opts.ClientOptions.RemotingClientConfig.ReadTimeout = readTimeout + opts.ClientOptions.RemotingClientConfig.WriteTimeout = writeTimeout + } +} + func WithTls(useTls bool) Option { return func(opts *producerOptions) { opts.ClientOptions.RemotingClientConfig.UseTls = useTls diff --git a/producer/option_test.go b/producer/option_test.go index 723da031..9b6ee133 100644 --- a/producer/option_test.go +++ b/producer/option_test.go @@ -3,6 +3,7 @@ package producer import ( "reflect" "testing" + "time" ) func getFieldString(obj interface{}, field string) string { @@ -12,6 +13,20 @@ func getFieldString(obj interface{}, field string) string { }).String() } +func TestWithRemotingTimeout(t *testing.T) { + opt := defaultProducerOptions() + WithRemotingTimeout(3*time.Second, 4*time.Second, 5*time.Second)(&opt) + if timeout := opt.RemotingClientConfig.ConnectionTimeout; timeout != 3*time.Second { + t.Errorf("consumer option WithRemotingTimeout connectionTimeout. want:%s, got=%s", 3*time.Second, timeout) + } + if timeout := opt.RemotingClientConfig.ReadTimeout; timeout != 4*time.Second { + t.Errorf("consumer option WithRemotingTimeout readTimeout. want:%s, got=%s", 4*time.Second, timeout) + } + if timeout := opt.RemotingClientConfig.WriteTimeout; timeout != 5*time.Second { + t.Errorf("consumer option WithRemotingTimeout writeTimeout. want:%s, got=%s", 5*time.Second, timeout) + } +} + func TestWithUnitName(t *testing.T) { opt := defaultProducerOptions() unitName := "unsh" diff --git a/producer/producer.go b/producer/producer.go index f8238843..70e8d013 100644 --- a/producer/producer.go +++ b/producer/producer.go @@ -26,14 +26,15 @@ import ( "sync/atomic" "time" + "github.com/google/uuid" + "github.com/pkg/errors" + errors2 "github.com/apache/rocketmq-client-go/v2/errors" "github.com/apache/rocketmq-client-go/v2/internal" "github.com/apache/rocketmq-client-go/v2/internal/remote" "github.com/apache/rocketmq-client-go/v2/internal/utils" "github.com/apache/rocketmq-client-go/v2/primitive" "github.com/apache/rocketmq-client-go/v2/rlog" - "github.com/google/uuid" - "github.com/pkg/errors" ) type defaultProducer struct { @@ -355,7 +356,7 @@ func (p *defaultProducer) sendSync(ctx context.Context, msg *primitive.Message, producerCtx.MQ = *mq } - res, _err := p.client.InvokeSync(ctx, addr, p.buildSendRequest(mq, msg), 3*time.Second) + res, _err := p.client.InvokeSync(ctx, addr, p.buildSendRequest(mq, msg), p.options.SendMsgTimeout) if _err != nil { err = _err continue @@ -400,7 +401,7 @@ func (p *defaultProducer) sendAsync(ctx context.Context, msg *primitive.Message, return errors.Errorf("topic=%s route info not found", mq.Topic) } - ctx, cancel := context.WithTimeout(ctx, 3*time.Second) + ctx, cancel := context.WithTimeout(ctx, p.options.SendMsgTimeout) err := p.client.InvokeAsync(ctx, addr, p.buildSendRequest(mq, msg), func(command *remote.RemotingCommand, err error) { cancel() if err != nil { @@ -465,7 +466,7 @@ func (p *defaultProducer) sendOneWay(ctx context.Context, msg *primitive.Message return fmt.Errorf("topic=%s route info not found", mq.Topic) } - _err := p.client.InvokeOneWay(ctx, addr, p.buildSendRequest(mq, msg), 3*time.Second) + _err := p.client.InvokeOneWay(ctx, addr, p.buildSendRequest(mq, msg), p.options.SendMsgTimeout) if _err != nil { err = _err continue