Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 16 additions & 5 deletions willow/input_encoding/codec_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,29 @@ PYBIND11_MODULE(codec_bindings, m) {

m.def(
"CreateExplicitCodec",
[](const std::string& serialized_input_spec,
size_t max_flattened_domain_size)
[](const std::string& serialized_input_spec)
-> absl::StatusOr<std::unique_ptr<Codec>> {
InputSpec input_spec;
if (!input_spec.ParseFromString(serialized_input_spec)) {
return absl::InvalidArgumentError("Failed to parse InputSpec");
}
return CodecFactory::CreateExplicitCodec(input_spec);
},
py::arg("serialized_input_spec"));

m.def(
"ValidateExplicitCodecInputSpec",
[](const std::string& serialized_input_spec,
size_t max_flattened_domain_size) -> absl::Status {
InputSpec input_spec;
if (!input_spec.ParseFromString(serialized_input_spec)) {
return absl::InvalidArgumentError("Failed to parse InputSpec");
}
if (max_flattened_domain_size == 0) {
return CodecFactory::CreateExplicitCodec(input_spec);
return CodecFactory::ValidateExplicitCodecInputSpec(input_spec);
} else {
return CodecFactory::CreateExplicitCodec(input_spec,
max_flattened_domain_size);
return CodecFactory::ValidateExplicitCodecInputSpec(
input_spec, max_flattened_domain_size);
}
},
py::arg("serialized_input_spec"),
Expand Down
11 changes: 7 additions & 4 deletions willow/input_encoding/codec_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,8 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery(
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<Codec>> CodecFactory::CreateExplicitCodec(
InputSpec input_spec, size_t max_flattened_domain_size) {
// Check that the combined size of the string domains is less than the
// maximum allowed size.
absl::Status CodecFactory::ValidateExplicitCodecInputSpec(
const InputSpec& input_spec, size_t max_flattened_domain_size) {
size_t flattened_domain_size = 1;
for (const auto& spec : input_spec.group_by_vector_specs()) {
flattened_domain_size *= spec.domain_spec().string_values().values_size();
Expand All @@ -385,6 +383,11 @@ absl::StatusOr<std::unique_ptr<Codec>> CodecFactory::CreateExplicitCodec(
"Global output domain size exceeds maximum threshold.");
}
}
return absl::OkStatus();
}

absl::StatusOr<std::unique_ptr<Codec>> CodecFactory::CreateExplicitCodec(
InputSpec input_spec) {
// Check that specs include at least one metric vector.
if (input_spec.metric_vector_specs().empty()) {
return absl::InvalidArgumentError(
Expand Down
7 changes: 6 additions & 1 deletion willow/input_encoding/codec_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,12 @@ class CodecFactory {
public:
// Creates an instance of ExplicitCodec.
static absl::StatusOr<std::unique_ptr<Codec>> CreateExplicitCodec(
InputSpec input_spec,
InputSpec input_spec);

// Check that the combined size of the string domains is less than the
// maximum allowed size.
static absl::Status ValidateExplicitCodecInputSpec(
const InputSpec& input_spec,
size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize);
};

Expand Down
6 changes: 3 additions & 3 deletions willow/input_encoding/explicit_codec_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) {
->add_values(std::to_string(i));
}

EXPECT_THAT(CodecFactory::CreateExplicitCodec(input_spec),
EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Global output domain size exceeds")));
}
Expand All @@ -292,10 +292,10 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) {
group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values(
"b");
// Domain size is 2.
EXPECT_THAT(CodecFactory::CreateExplicitCodec(input_spec, 1),
EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 1),
StatusIs(absl::StatusCode::kInvalidArgument,
HasSubstr("Global output domain size exceeds")));
SECAGG_EXPECT_OK(CodecFactory::CreateExplicitCodec(input_spec, 2));
SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2));
}

TEST(CodecFactoryTest, EncodeSimpleGroupBy) {
Expand Down