diff --git a/src/commands/cmd_tdigest.cc b/src/commands/cmd_tdigest.cc index e7815256d11..4cdadc71580 100644 --- a/src/commands/cmd_tdigest.cc +++ b/src/commands/cmd_tdigest.cc @@ -176,6 +176,45 @@ class CommandTDigestAdd : public Commander { std::vector values_; }; +class CommandTDigestRevRank : public Commander { + public: + Status Parse(const std::vector &args) override { + key_name_ = args[1]; + inputs_.reserve(args.size() - 2); + for (size_t i = 2; i < args.size(); i++) { + auto value = ParseFloat(args[i]); + if (!value) { + return {Status::RedisParseErr, errValueIsNotFloat}; + } + inputs_.push_back(*value); + } + return Status::OK(); + } + Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { + TDigest tdigest(srv->storage, conn->GetNamespace()); + std::vector result; + result.reserve(inputs_.size()); + if (const auto s = tdigest.RevRank(ctx, key_name_, inputs_, result); !s.ok()) { + if (s.IsNotFound()) { + return {Status::RedisExecErr, errKeyNotFound}; + } + return {Status::RedisExecErr, s.ToString()}; + } + + std::vector rev_ranks; + rev_ranks.reserve(result.size()); + for (const auto v : result) { + rev_ranks.push_back(redis::Integer(v)); + } + *output = redis::Array(rev_ranks); + return Status::OK(); + } + + private: + std::string key_name_; + std::vector inputs_; +}; + class CommandTDigestMinMax : public Commander { public: explicit CommandTDigestMinMax(bool is_min) : is_min_(is_min) {} @@ -369,6 +408,7 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr("tdigest.crea MakeCmdAttr("tdigest.add", -3, "write", 1, 1, 1), MakeCmdAttr("tdigest.max", 2, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.min", 2, "read-only", 1, 1, 1), + MakeCmdAttr("tdigest.revrank", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.quantile", -3, "read-only", 1, 1, 1), MakeCmdAttr("tdigest.reset", 2, "write", 1, 1, 1), MakeCmdAttr("tdigest.merge", -4, "write", GetMergeKeyRange)); diff --git a/src/types/redis_tdigest.cc b/src/types/redis_tdigest.cc index f506ad92e2a..2d45a58d853 100644 --- a/src/types/redis_tdigest.cc +++ b/src/types/redis_tdigest.cc @@ -70,6 +70,8 @@ class DummyCentroids { return iter_ != centroids_.cend(); } + bool IsBegin() { return iter_ == centroids_.cbegin(); } + // The Prev function can only be called for item is not cend, // because we must guarantee the iterator to be inside the valid range before iteration. bool Prev() { @@ -186,8 +188,37 @@ rocksdb::Status TDigest::Add(engine::Context& ctx, const Slice& digest_name, con return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); } -rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice& digest_name, const std::vector& qs, - TDigestQuantitleResult* result) { +rocksdb::Status TDigest::mergeNodes(engine::Context& ctx, const std::string& ns_key, TDigestMetadata* metadata) { + if (metadata->unmerged_nodes == 0) { + return rocksdb::Status::OK(); + } + + auto batch = storage_->GetWriteBatchBase(); + WriteBatchLogData log_data(kRedisTDigest); + if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) { + return status; + } + + if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, metadata); !status.ok()) { + return status; + } + + std::string metadata_bytes; + metadata->Encode(&metadata_bytes); + if (auto status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); !status.ok()) { + return status; + } + + if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); !status.ok()) { + return status; + } + + ctx.RefreshLatestSnapshot(); + return rocksdb::Status::OK(); +} + +rocksdb::Status TDigest::RevRank(engine::Context& ctx, const Slice& digest_name, const std::vector& inputs, + std::vector& result) { auto ns_key = AppendNamespacePrefix(digest_name); TDigestMetadata metadata; { @@ -198,31 +229,45 @@ rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice& digest_name } if (metadata.total_observations == 0) { + result.resize(inputs.size(), -2); return rocksdb::Status::OK(); } - if (metadata.unmerged_nodes > 0) { - auto batch = storage_->GetWriteBatchBase(); - WriteBatchLogData log_data(kRedisTDigest); - if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) { - return status; - } + if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + } - if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata); !status.ok()) { - return status; - } + std::vector centroids; + if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { + return status; + } - std::string metadata_bytes; - metadata.Encode(&metadata_bytes); - if (auto status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); !status.ok()) { - return status; - } + auto dump_centroids = DummyCentroids(metadata, centroids); + auto status = TDigestRevRank(dump_centroids, inputs, result); + if (!status) { + return rocksdb::Status::InvalidArgument(status.Msg()); + } + return rocksdb::Status::OK(); +} - if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); !status.ok()) { - return status; - } +rocksdb::Status TDigest::Quantile(engine::Context& ctx, const Slice& digest_name, const std::vector& qs, + TDigestQuantitleResult* result) { + auto ns_key = AppendNamespacePrefix(digest_name); + TDigestMetadata metadata; + { + LockGuard guard(storage_->GetLockManager(), ns_key); - ctx.RefreshLatestSnapshot(); + if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) { + return status; + } + + if (metadata.total_observations == 0) { + return rocksdb::Status::OK(); + } + + if (auto status = mergeNodes(ctx, ns_key, &metadata); !status.ok()) { + return status; } } diff --git a/src/types/redis_tdigest.h b/src/types/redis_tdigest.h index 2026cf94d05..02ecc24c016 100644 --- a/src/types/redis_tdigest.h +++ b/src/types/redis_tdigest.h @@ -77,7 +77,8 @@ class TDigest : public SubKeyScanner { rocksdb::Status Merge(engine::Context& ctx, const Slice& dest_digest, const std::vector& source_digests, const TDigestMergeOptions& options); - + rocksdb::Status RevRank(engine::Context& ctx, const Slice& digest_name, const std::vector& inputs, + std::vector& result); rocksdb::Status GetMetaData(engine::Context& context, const Slice& digest_name, TDigestMetadata* metadata); private: @@ -117,6 +118,8 @@ class TDigest : public SubKeyScanner { std::string internalSegmentGuardPrefixKey(const TDigestMetadata& metadata, const std::string& ns_key, SegmentType seg) const; + rocksdb::Status mergeNodes(engine::Context& ctx, const std::string& ns_key, TDigestMetadata* metadata); + rocksdb::Status mergeCurrentBuffer(engine::Context& ctx, const std::string& ns_key, ObserverOrUniquePtr& batch, TDigestMetadata* metadata, const std::vector* additional_buffer = nullptr, diff --git a/src/types/tdigest.h b/src/types/tdigest.h index 1f416b48ecf..33f697d0324 100644 --- a/src/types/tdigest.h +++ b/src/types/tdigest.h @@ -22,6 +22,8 @@ #include +#include +#include #include #include "common/status.h" @@ -150,3 +152,76 @@ inline StatusOr TDigestQuantile(TD&& td, double q) { diff /= (lc.weight / 2 + rc.weight / 2); return Lerp(lc.mean, rc.mean, diff); } + +inline void AssignRankForEqualInputs(const std::vector& indices, double cumulative_weight, + std::vector& result) { + for (auto index : indices) { + result[index] = static_cast(cumulative_weight); + } +} + +template +inline Status TDigestRevRank(TD&& td, const std::vector& inputs, std::vector& result) { + std::map> value_to_indices; + for (size_t i = 0; i < inputs.size(); ++i) { + value_to_indices[inputs[i]].push_back(i); + } + + double cumulative_weight = 0; + result.resize(inputs.size()); + auto it = value_to_indices.rbegin(); + + // handle inputs larger than maximum + while (it != value_to_indices.rend() && it->first > td.Max()) { + AssignRankForEqualInputs(it->second, -1, result); + ++it; + } + + auto iter = td.End(); + while (iter->Valid() && it != value_to_indices.rend()) { + auto centroid = GET_OR_RET(iter->GetCentroid()); + auto input_value = it->first; + if (centroid.mean == input_value) { + auto current_mean = centroid.mean; + auto current_mean_cumulative_weight = cumulative_weight + centroid.weight / 2; + cumulative_weight += centroid.weight; + + // handle all the prev centroids which has the same mean + while (!iter->IsBegin() && iter->Prev()) { + auto next_centroid = GET_OR_RET(iter->GetCentroid()); + if (current_mean != next_centroid.mean) { + // move back to the last equal centroid, because we will process it in the next loop + iter->Next(); + break; + } + current_mean_cumulative_weight += next_centroid.weight / 2; + cumulative_weight += next_centroid.weight; + } + + // handle the prev inputs which has the same value + AssignRankForEqualInputs(it->second, current_mean_cumulative_weight, result); + ++it; + if (iter->IsBegin()) { + break; + } + iter->Prev(); + } else if (centroid.mean > input_value) { + cumulative_weight += centroid.weight; + if (iter->IsBegin()) { + break; + } + iter->Prev(); + } else { + AssignRankForEqualInputs(it->second, cumulative_weight, result); + ++it; + } + } + + // handle inputs less than minimum + while (it != value_to_indices.rend()) { + AssignRankForEqualInputs(it->second, td.TotalWeight(), result); + ++it; + } + + return Status::OK(); +} diff --git a/tests/cppunit/types/tdigest_test.cc b/tests/cppunit/types/tdigest_test.cc index 91d1f311f4c..09b27fd70cd 100644 --- a/tests/cppunit/types/tdigest_test.cc +++ b/tests/cppunit/types/tdigest_test.cc @@ -298,3 +298,79 @@ TEST_F(RedisTDigestTest, Quantile_returns_nan_on_empty_tdigest) { ASSERT_TRUE(status.ok()) << status.ToString(); ASSERT_FALSE(result.quantiles) << "should not have quantiles with empty tdigest"; } + +TEST_F(RedisTDigestTest, RevRank_on_the_set_contains_different_elements) { + std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + std::vector input{10, 20, 30, 40, 50, 60}; + status = tdigest_->Add(*ctx_, test_digest_name, input); + ASSERT_TRUE(status.ok()) << status.ToString(); + + std::vector result; + result.reserve(input.size()); + const std::vector value = {0, 10, 20, 30, 40, 50, 60, 70}; + status = tdigest_->RevRank(*ctx_, test_digest_name, value, result); + const auto expect_result = std::vector{6, 5, 4, 3, 2, 1, 0, -1}; + + for (size_t i = 0; i < result.size(); i++) { + auto got = result[i]; + EXPECT_EQ(got, expect_result[i]); + } + ASSERT_TRUE(status.ok()) << status.ToString(); +} + +TEST_F(RedisTDigestTest, RevRank_on_the_set_contains_several_identical_elements) { + std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + std::vector input{10, 10, 10, 20, 20}; + status = tdigest_->Add(*ctx_, test_digest_name, input); + ASSERT_TRUE(status.ok()) << status.ToString(); + + std::vector result; + result.reserve(input.size()); + const std::vector value = {10, 20}; + status = tdigest_->RevRank(*ctx_, test_digest_name, value, result); + const auto expect_result = std::vector{3, 1}; + for (size_t i = 0; i < result.size(); i++) { + auto got = result[i]; + EXPECT_EQ(got, expect_result[i]); + } + ASSERT_TRUE(status.ok()) << status.ToString(); + + status = tdigest_->Add(*ctx_, test_digest_name, std::vector{10}); + ASSERT_TRUE(status.ok()) << status.ToString(); + + result.clear(); + status = tdigest_->RevRank(*ctx_, test_digest_name, value, result); + const auto expect_result_new = std::vector{4, 1}; + for (size_t i = 0; i < result.size(); i++) { + auto got = result[i]; + EXPECT_EQ(got, expect_result_new[i]); + } + ASSERT_TRUE(status.ok()) << status.ToString(); +} + +TEST_F(RedisTDigestTest, RevRank_on_empty_tdigest) { + std::string test_digest_name = "test_digest_revrank" + std::to_string(util::GetTimeStampMS()); + bool exists = false; + auto status = tdigest_->Create(*ctx_, test_digest_name, {100}, &exists); + ASSERT_FALSE(exists); + ASSERT_TRUE(status.ok()); + + std::vector result; + result.reserve(2); + const std::vector value = {10, 20}; + status = tdigest_->RevRank(*ctx_, test_digest_name, value, result); + const auto expect_result = std::vector{-2, -2}; + for (size_t i = 0; i < result.size(); i++) { + auto got = result[i]; + EXPECT_EQ(got, expect_result[i]); + } + ASSERT_TRUE(status.ok()) << status.ToString(); +} \ No newline at end of file diff --git a/tests/gocase/unit/type/tdigest/tdigest_test.go b/tests/gocase/unit/type/tdigest/tdigest_test.go index 3c2565ac066..91c453bb33f 100644 --- a/tests/gocase/unit/type/tdigest/tdigest_test.go +++ b/tests/gocase/unit/type/tdigest/tdigest_test.go @@ -518,4 +518,73 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { validation(newDestKey1) validation(newDestKey2) }) + + t.Run("tdigest.revrank with different arguments", func(t *testing.T) { + keyPrefix := "tdigest_revrank_" + + // Test invalid arguments + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK").Err(), errMsgWrongNumberArg) + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK", keyPrefix+"nonexistent").Err(), errMsgWrongNumberArg) + + // Test Non-existent key + require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.REVRANK", keyPrefix+"nonexistent", "10").Err(), errMsgKeyNotExist) + + // Test with empty tdigest + key := keyPrefix + "test1" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key, "compression", "100").Err()) + rsp := rdb.Do(ctx, "TDIGEST.REVRANK", key, "10") + require.NoError(t, rsp.Err()) + vals, err := rsp.Slice() + require.NoError(t, err) + require.Len(t, vals, 1) + expected := []int64{-2} + for i, v := range vals { + rank, ok := v.(int64) + require.True(t, ok, "expected int64 but got %T at index %d", v, i) + require.EqualValues(t, rank, expected[i]) + } + + // Test with set_contains several identical elements + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10", "10", "10", "20", "20").Err()) + rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20") + require.NoError(t, rsp.Err()) + vals, err = rsp.Slice() + require.NoError(t, err) + require.Len(t, vals, 2) + expected = []int64{3, 1} + for i, v := range vals { + rank, ok := v.(int64) + require.True(t, ok, "expected int64 but got %T at index %d", v, i) + require.EqualValues(t, rank, expected[i]) + } + + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key, "10").Err()) + rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key, "10", "20") + require.NoError(t, rsp.Err()) + vals, err = rsp.Slice() + require.NoError(t, err) + require.Len(t, vals, 2) + expected = []int64{4, 1} + for i, v := range vals { + rank, ok := v.(int64) + require.True(t, ok, "expected int64 but got %T at index %d", v, i) + require.EqualValues(t, rank, expected[i]) + } + + // Test with set_contains different elements + key2 := keyPrefix + "test2" + require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", key2, "compression", "100").Err()) + require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", key2, "10", "20", "30", "40", "50", "60").Err()) + rsp = rdb.Do(ctx, "TDIGEST.REVRANK", key2, "0", "10", "20", "30", "40", "50", "60", "70") + require.NoError(t, rsp.Err()) + vals, err = rsp.Slice() + require.NoError(t, err) + require.Len(t, vals, 8) + expected = []int64{6, 5, 4, 3, 2, 1, 0, -1} + for i, v := range vals { + rank, ok := v.(int64) + require.True(t, ok, "expected int64 but got %T at index %d", v, i) + require.EqualValues(t, rank, expected[i]) + } + }) }