diff --git a/src/commands/cmd_server.cc b/src/commands/cmd_server.cc index d1359087c0f..661df2cf599 100644 --- a/src/commands/cmd_server.cc +++ b/src/commands/cmd_server.cc @@ -392,7 +392,7 @@ class CommandClient : public Commander { public: Status Parse(const std::vector &args) override { subcommand_ = util::ToLower(args[1]); - // subcommand: getname id kill list info setname + // subcommand: getname id kill list info setname reply if ((subcommand_ == "id" || subcommand_ == "getname" || subcommand_ == "list" || subcommand_ == "info") && args.size() == 2) { return Status::OK(); @@ -412,6 +412,23 @@ class CommandClient : public Commander { return Status::OK(); } + if (subcommand_ == "reply") { + if (args.size() != 3) { + return {Status::RedisParseErr, errInvalidSyntax}; + } + auto mode_str = util::ToLower(args[2]); + if (mode_str == "on") { + reply_mode_ = redis::Connection::ReplyMode::ON; + } else if (mode_str == "off") { + reply_mode_ = redis::Connection::ReplyMode::OFF; + } else if (mode_str == "skip") { + reply_mode_ = redis::Connection::ReplyMode::SKIP; + } else { + return {Status::RedisParseErr, errInvalidSyntax}; + } + return Status::OK(); + } + if ((subcommand_ == "kill")) { if (args.size() == 2) { return {Status::RedisParseErr, errInvalidSyntax}; @@ -464,7 +481,7 @@ class CommandClient : public Commander { } return Status::OK(); } - return {Status::RedisInvalidCmd, "Syntax error, try CLIENT LIST|INFO|KILL ip:port|GETNAME|SETNAME"}; + return {Status::RedisInvalidCmd, "Syntax error, try CLIENT LIST|INFO|KILL ip:port|GETNAME|SETNAME|REPLY"}; } Status Execute([[maybe_unused]] engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { @@ -497,15 +514,22 @@ class CommandClient : public Commander { *output = redis::RESP_OK; } return Status::OK(); + } else if (subcommand_ == "reply") { + conn->SetReplyMode(reply_mode_); + if (reply_mode_ != redis::Connection::ReplyMode::SKIP) { + *output = redis::RESP_OK; + } + return Status::OK(); } - return {Status::RedisInvalidCmd, "Syntax error, try CLIENT LIST|INFO|KILL ip:port|GETNAME|SETNAME"}; + return {Status::RedisInvalidCmd, "Syntax error, try CLIENT LIST|INFO|KILL ip:port|GETNAME|SETNAME|REPLY"}; } private: std::string addr_; std::string conn_name_; std::string subcommand_; + redis::Connection::ReplyMode reply_mode_ = redis::Connection::ReplyMode::ON; bool skipme_ = false; int64_t kill_type_ = 0; uint64_t id_ = 0; diff --git a/src/server/redis_connection.cc b/src/server/redis_connection.cc index 39781087333..443f29a07fb 100644 --- a/src/server/redis_connection.cc +++ b/src/server/redis_connection.cc @@ -132,6 +132,13 @@ void Connection::OnEvent(bufferevent *bev, int16_t events) { } void Connection::Reply(const std::string &msg) { + if (reply_mode_ == ReplyMode::SKIP) { + reply_mode_ = ReplyMode::ON; + return; + } + if (reply_mode_ == ReplyMode::OFF) { + return; + } owner_->srv->stats.IncrOutboundBytes(msg.size()); redis::Reply(bufferevent_get_output(bev_), msg); } diff --git a/src/server/redis_connection.h b/src/server/redis_connection.h index e8b44d94144..d31fbce1c32 100644 --- a/src/server/redis_connection.h +++ b/src/server/redis_connection.h @@ -50,6 +50,12 @@ class Connection : public EvbufCallbackBase { kAsking = 1 << 10, }; + enum class ReplyMode { + ON, // Always reply to every command (default) + OFF, // Never reply to any command + SKIP // Skip reply for the next command, then automatically switch back to ON + }; + explicit Connection(bufferevent *bev, Worker *owner); ~Connection(); @@ -181,6 +187,10 @@ class Connection : public EvbufCallbackBase { std::set watched_keys; std::atomic watched_keys_modified = false; + // Reply mode getter/setter + void SetReplyMode(ReplyMode mode) { reply_mode_ = mode; } + ReplyMode GetReplyMode() const { return reply_mode_; } + private: uint64_t id_ = 0; std::atomic flags_ = 0; @@ -215,6 +225,8 @@ class Connection : public EvbufCallbackBase { bool importing_ = false; RESP protocol_version_ = RESP::v2; + + ReplyMode reply_mode_ = ReplyMode::ON; }; } // namespace redis diff --git a/tests/gocase/unit/introspection/introspection_test.go b/tests/gocase/unit/introspection/introspection_test.go index af15117dfbf..1b17a98caae 100644 --- a/tests/gocase/unit/introspection/introspection_test.go +++ b/tests/gocase/unit/introspection/introspection_test.go @@ -265,6 +265,40 @@ func TestIntrospection(t *testing.T) { require.NoError(t, rdb.Do(ctx, "SET", "key", "value").Err()) require.EqualValues(t, 1, rdb.Do(ctx, "MOVE", "key", "0").Val()) }) + + // Test CLIENT REPLY subcommand behaviors + t.Run("CLIENT REPLY mode switching", func(t *testing.T) { + c := srv.NewTCPClient() + defer func() { require.NoError(t, c.Close()) }() + + // Should reply by default + require.NoError(t, c.WriteArgs("ECHO", "default")) + c.MustReadBulkString(t, "default") + + // Set to OFF, following commands should not reply + require.NoError(t, c.WriteArgs("CLIENT", "REPLY", "OFF")) + require.NoError(t, c.WriteArgs("ECHO", "off")) + // No reply expected here, do not read + + // Set back to ON, commands should reply again + require.NoError(t, c.WriteArgs("CLIENT", "REPLY", "ON")) + c.MustRead(t, "+OK") + require.NoError(t, c.WriteArgs("ECHO", "on")) + c.MustReadBulkString(t, "on") + + // Set to SKIP, next command should not reply, then reply resumes + require.NoError(t, c.WriteArgs("CLIENT", "REPLY", "SKIP")) + // No reply expected here, do not read + + require.NoError(t, c.WriteArgs("ECHO", "skip1")) + // No reply expected here, do not read + + require.NoError(t, c.WriteArgs("ECHO", "skip2")) + c.MustReadBulkString(t, "skip2") + + require.NoError(t, c.WriteArgs("ECHO", "skip3")) + c.MustReadBulkString(t, "skip3") + }) } func TestMultiServerIntrospection(t *testing.T) { diff --git a/tests/gocase/util/tcp_client.go b/tests/gocase/util/tcp_client.go index 46ed3ac9af5..3cd9de2b1a9 100644 --- a/tests/gocase/util/tcp_client.go +++ b/tests/gocase/util/tcp_client.go @@ -87,6 +87,16 @@ func (c *TCPClient) MustReadStrings(t testing.TB, s []string) { } } +func (c *TCPClient) MustReadBulkString(t testing.TB, s string) { + r, err := c.ReadLine() + require.NoError(t, err) + require.Equal(t, "$"+strconv.Itoa(len(s)), r) + + r, err = c.ReadLine() + require.NoError(t, err) + require.Equal(t, s, r) +} + func (c *TCPClient) MustReadStringsWithKey(t testing.TB, key string, s []string) { r, err := c.ReadLine() require.NoError(t, err)