-
Couldn't load subscription status.
- Fork 578
feat(tdigest): Implement TDIGEST.CDF command #3163
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: unstable
Are you sure you want to change the base?
Changes from all commits
3235803
f3bdd77
9e26d86
7366edb
d3a44cf
262574f
b64389d
4731484
f361fc4
6753870
cb86cb3
fcd7d5a
e0451ba
1d425a0
31423cd
67923f4
4d1988b
ec52575
d23ad76
cbff56e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -358,7 +358,56 @@ class CommandTDigestMerge : public Commander { | |
| std::vector<std::string> source_keys_; | ||
| TDigestMergeOptions options_; | ||
| }; | ||
| class CommandTDigestCDF : public Commander { | ||
| Status Parse(const std::vector<std::string> &args) override { | ||
| if (args.size() == 2) return {Status::RedisParseErr, errWrongNumOfArguments}; | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Maybe we could check the vector size at the beginning. |
||
| key_name_ = args[1]; | ||
| values_.reserve(args.size() - 2); | ||
| for (size_t i = 2; i < args.size(); i++) { | ||
| auto value = ParseFloat(args[i]); | ||
| if (!value) { | ||
| return {Status::RedisParseErr, errValueIsNotFloat}; | ||
| } | ||
| values_.push_back(*value); | ||
| } | ||
| return Status::OK(); | ||
| } | ||
| Status Execute(engine::Context &ctx, Server *srv, Connection *conn, std::string *output) override { | ||
| TDigest tdigest(srv->storage, conn->GetNamespace()); | ||
| std::vector<std::string> cdf_result; | ||
| TDigestCDFResult result; | ||
| TDigestMetadata metadata; | ||
| auto meta_status = tdigest.GetMetaData(ctx, key_name_, &metadata); | ||
| std::vector<std::string> nan_results(values_.size(), "nan"); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hi @SharonIV0x86 , We could move this |
||
| if (!meta_status.ok()) { | ||
| if (meta_status.IsNotFound()) { | ||
| return {Status::RedisExecErr, errKeyNotFound}; | ||
| } | ||
| *output = redis::MultiBulkString(RESP::v2, nan_results); | ||
| return Status::OK(); | ||
| } | ||
| if (metadata.total_observations == 0) { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have tested this with Redis in Docker; the result should be a ["nan"] vector with the same size as the input. |
||
| *output = redis::MultiBulkString(RESP::v2, nan_results); | ||
| return Status::OK(); | ||
| } | ||
| auto s = tdigest.CDF(ctx, key_name_, values_, &result); | ||
| if (!s.ok()) { | ||
| *output = redis::MultiBulkString(RESP::v2, nan_results); | ||
| return {Status::RedisExecErr, s.ToString()}; | ||
| } | ||
| if (result.cdf_values) { | ||
| for (const auto &val : *result.cdf_values) { | ||
| cdf_result.push_back(util::Float2String(val)); | ||
| } | ||
| } | ||
| *output = redis::MultiBulkString(RESP::v2, cdf_result); | ||
| return Status::OK(); | ||
| } | ||
|
|
||
| private: | ||
| std::string key_name_; | ||
| std::vector<double> values_; | ||
| }; | ||
| std::vector<CommandKeyRange> GetMergeKeyRange(const std::vector<std::string> &args) { | ||
| auto numkeys = ParseInt<int>(args[2], 10).ValueOr(0); | ||
| return {{1, 1, 1}, {3, 2 + numkeys, 1}}; | ||
|
|
@@ -371,5 +420,6 @@ REDIS_REGISTER_COMMANDS(TDigest, MakeCmdAttr<CommandTDigestCreate>("tdigest.crea | |
| MakeCmdAttr<CommandTDigestMin>("tdigest.min", 2, "read-only", 1, 1, 1), | ||
| MakeCmdAttr<CommandTDigestQuantile>("tdigest.quantile", -3, "read-only", 1, 1, 1), | ||
| MakeCmdAttr<CommandTDigestReset>("tdigest.reset", 2, "write", 1, 1, 1), | ||
| MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, "write", GetMergeKeyRange)); | ||
| MakeCmdAttr<CommandTDigestMerge>("tdigest.merge", -4, "write", GetMergeKeyRange), | ||
| MakeCmdAttr<CommandTDigestCDF>("tdigest.cdf", -3, "read-only", 1, 1, 1)); | ||
| } // namespace redis | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -414,6 +414,69 @@ rocksdb::Status TDigest::Merge(engine::Context& ctx, const Slice& dest_digest, | |||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| return storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| rocksdb::Status TDigest::CDF(engine::Context& ctx, const Slice& digest_name, const std::vector<double>& inputs, | ||||||||||||||||||||||||||||||||||||||||||||||
| TDigestCDFResult* result) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto ns_key = AppendNamespacePrefix(digest_name); | ||||||||||||||||||||||||||||||||||||||||||||||
| TDigestMetadata metadata; | ||||||||||||||||||||||||||||||||||||||||||||||
| { | ||||||||||||||||||||||||||||||||||||||||||||||
| LockGuard guard(storage_->GetLockManager(), ns_key); | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = getMetaDataByNsKey(ctx, ns_key, &metadata); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (metadata.unmerged_nodes > 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto batch = storage_->GetWriteBatchBase(); | ||||||||||||||||||||||||||||||||||||||||||||||
| WriteBatchLogData log_data(kRedisTDigest); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = batch->PutLogData(log_data.Encode()); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = mergeCurrentBuffer(ctx, ns_key, batch, &metadata); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| if (metadata.total_observations == 0) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return rocksdb::Status::OK(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| std::string metadata_bytes; | ||||||||||||||||||||||||||||||||||||||||||||||
| metadata.Encode(&metadata_bytes); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = batch->Put(metadata_cf_handle_, ns_key, metadata_bytes); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = storage_->Write(ctx, storage_->DefaultWriteOptions(), batch->GetWriteBatch()); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| ctx.RefreshLatestSnapshot(); | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<Centroid> centroids; | ||||||||||||||||||||||||||||||||||||||||||||||
| if (auto status = dumpCentroids(ctx, ns_key, metadata, ¢roids); !status.ok()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| return status; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| auto dump_centroids = DummyCentroids(metadata, centroids); | ||||||||||||||||||||||||||||||||||||||||||||||
| double total_weight = dump_centroids.TotalWeight(); | ||||||||||||||||||||||||||||||||||||||||||||||
| std::vector<double> results; | ||||||||||||||||||||||||||||||||||||||||||||||
| for (double val : inputs) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto iter_begin = dump_centroids.Begin(); | ||||||||||||||||||||||||||||||||||||||||||||||
| auto iter_end = dump_centroids.End(); | ||||||||||||||||||||||||||||||||||||||||||||||
| double eq_count = 0; | ||||||||||||||||||||||||||||||||||||||||||||||
| double smaller_count = 0; | ||||||||||||||||||||||||||||||||||||||||||||||
| for (; iter_begin->Valid(); iter_begin->Next()) { | ||||||||||||||||||||||||||||||||||||||||||||||
| auto current_centroid = iter_begin->GetCentroid(); | ||||||||||||||||||||||||||||||||||||||||||||||
| if (val > current_centroid->mean) { | ||||||||||||||||||||||||||||||||||||||||||||||
| smaller_count++; | ||||||||||||||||||||||||||||||||||||||||||||||
| } else if (val == current_centroid->mean) { | ||||||||||||||||||||||||||||||||||||||||||||||
| eq_count++; | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||||||||||||||||||
| double cdf_val = (smaller_count / total_weight) + ((eq_count / 2) / total_weight); | ||||||||||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+464
to
+474
|
||||||||||||||||||||||||||||||||||||||||||||||
| double eq_count = 0; | |
| double smaller_count = 0; | |
| for (; iter_begin->Valid(); iter_begin->Next()) { | |
| auto current_centroid = iter_begin->GetCentroid(); | |
| if (val > current_centroid->mean) { | |
| smaller_count++; | |
| } else if (val == current_centroid->mean) { | |
| eq_count++; | |
| } | |
| } | |
| double cdf_val = (smaller_count / total_weight) + ((eq_count / 2) / total_weight); | |
| double eq_weight = 0; | |
| double smaller_weight = 0; | |
| for (; iter_begin->Valid(); iter_begin->Next()) { | |
| auto current_centroid = iter_begin->GetCentroid(); | |
| if (val > current_centroid->mean) { | |
| smaller_weight += current_centroid->weight; | |
| } else if (val == current_centroid->mean) { | |
| eq_weight += current_centroid->weight; | |
| } | |
| } | |
| double cdf_val = (smaller_weight / total_weight) + ((eq_weight / 2) / total_weight); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @SharonIV0x86,
It seems that we are mistaking the count for the weight here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I will take a look into it.
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -39,6 +39,7 @@ const ( | |||||
| errMsgKeyNotExist = "key does not exist" | ||||||
| errNumkeysMustBePositive = "numkeys need to be a positive integer" | ||||||
| errCompressionParameterMustBePositive = "compression parameter needs to be a positive integer" | ||||||
| errValueIsNotFloat = "value is not a valid float" | ||||||
| ) | ||||||
|
|
||||||
| type tdigestInfo struct { | ||||||
|
|
@@ -518,4 +519,55 @@ func tdigestTests(t *testing.T, configs util.KvrocksServerConfigs) { | |||||
| validation(newDestKey1) | ||||||
| validation(newDestKey2) | ||||||
| }) | ||||||
| t.Run("tdigest.cdf with different arguments", func(t *testing.T) { | ||||||
| keyPrefix := "tdigest_cdf_" | ||||||
|
|
||||||
| require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.CDF").Err(), errMsgWrongNumberArg) | ||||||
| require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.CDF", keyPrefix+"key1").Err(), errMsgWrongNumberArg) | ||||||
|
|
||||||
| // non-existent key | ||||||
| require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.CDF", keyPrefix+"nonexistent", "1.0").Err(), errMsgKeyNotExist) | ||||||
|
|
||||||
| // invalid float value | ||||||
| require.ErrorContains(t, rdb.Do(ctx, "TDIGEST.CDF", keyPrefix+"key2", "invalid").Err(), errValueIsNotFloat) | ||||||
|
|
||||||
| // create a tdigest and add some data | ||||||
| tdigestKey := keyPrefix + "source" | ||||||
| require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", tdigestKey).Err()) | ||||||
| require.NoError(t, rdb.Do(ctx, "TDIGEST.ADD", tdigestKey, "1.0", "2.0", "3.0", "4.0", "5.0").Err()) | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Hi @SharonIV0x86, we'd better add some tests with duplicated values to create different weights for some centroids. |
||||||
|
|
||||||
| // single-value CDF query | ||||||
| rsp := rdb.Do(ctx, "TDIGEST.CDF", tdigestKey, "3.0") | ||||||
| require.NoError(t, rsp.Err()) | ||||||
| vals, err := rsp.Slice() | ||||||
| require.NoError(t, err) | ||||||
| require.Len(t, vals, 1) | ||||||
| require.NotEqual(t, "nan", vals[0]) | ||||||
|
|
||||||
| // multi-value CDF query | ||||||
| rsp = rdb.Do(ctx, "TDIGEST.CDF", tdigestKey, "0.0", "2.5", "5.0", "10.0") | ||||||
| require.NoError(t, rsp.Err()) | ||||||
| vals, err = rsp.Slice() | ||||||
| require.NoError(t, err) | ||||||
| require.Len(t, vals, 4) | ||||||
|
|
||||||
| // empty tdigest should return "nan" | ||||||
| emptyKey := keyPrefix + "empty" | ||||||
| require.NoError(t, rdb.Do(ctx, "TDIGEST.CREATE", emptyKey).Err()) | ||||||
| rsp = rdb.Do(ctx, "TDIGEST.CDF", emptyKey, "1.0") | ||||||
| require.NoError(t, rsp.Err()) | ||||||
| vals, err = rsp.Slice() | ||||||
| require.NoError(t, err) | ||||||
| require.Len(t, vals, 1) | ||||||
| require.Equal(t, "nan", vals[0]) | ||||||
|
|
||||||
| // testing with a empry digest with multi-valued CDF | ||||||
|
||||||
| // testing with a empry digest with multi-valued CDF | |
| // testing with an empty digest with multi-valued CDF |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The command registration specifies minimum 4 arguments (-4), but this validation only checks for exactly 2 arguments. It should validate that there are at least 3 arguments (command + key + at least one value).