From 8ca24957508dd9dc8aac410c03a496dd4b1a52b8 Mon Sep 17 00:00:00 2001 From: Pierre Tholoniat Date: Thu, 5 Mar 2026 15:22:08 -0800 Subject: [PATCH] Add InputSpec artifact URI and max_number_of_clients to Willow tasks. PiperOrigin-RevId: 879271513 --- willow/input_encoding/codec_bindings.cc | 21 +++++++++++++++----- willow/input_encoding/codec_factory.cc | 11 ++++++---- willow/input_encoding/codec_factory.h | 7 ++++++- willow/input_encoding/explicit_codec_test.cc | 6 +++--- 4 files changed, 32 insertions(+), 13 deletions(-) diff --git a/willow/input_encoding/codec_bindings.cc b/willow/input_encoding/codec_bindings.cc index 4c860f3..d076527 100644 --- a/willow/input_encoding/codec_bindings.cc +++ b/willow/input_encoding/codec_bindings.cc @@ -48,18 +48,29 @@ PYBIND11_MODULE(codec_bindings, m) { m.def( "CreateExplicitCodec", - [](const std::string& serialized_input_spec, - size_t max_flattened_domain_size) + [](const std::string& serialized_input_spec) -> absl::StatusOr> { InputSpec input_spec; if (!input_spec.ParseFromString(serialized_input_spec)) { return absl::InvalidArgumentError("Failed to parse InputSpec"); } + return CodecFactory::CreateExplicitCodec(input_spec); + }, + py::arg("serialized_input_spec")); + + m.def( + "ValidateExplicitCodecInputSpec", + [](const std::string& serialized_input_spec, + size_t max_flattened_domain_size) -> absl::Status { + InputSpec input_spec; + if (!input_spec.ParseFromString(serialized_input_spec)) { + return absl::InvalidArgumentError("Failed to parse InputSpec"); + } if (max_flattened_domain_size == 0) { - return CodecFactory::CreateExplicitCodec(input_spec); + return CodecFactory::ValidateExplicitCodecInputSpec(input_spec); } else { - return CodecFactory::CreateExplicitCodec(input_spec, - max_flattened_domain_size); + return CodecFactory::ValidateExplicitCodecInputSpec( + input_spec, max_flattened_domain_size); } }, py::arg("serialized_input_spec"), diff --git a/willow/input_encoding/codec_factory.cc b/willow/input_encoding/codec_factory.cc index 8ded392..b6e8ec3 100644 --- a/willow/input_encoding/codec_factory.cc +++ b/willow/input_encoding/codec_factory.cc @@ -373,10 +373,8 @@ absl::Status ExplicitCodecImpl::ValidateExampleQuery( return absl::OkStatus(); } -absl::StatusOr> CodecFactory::CreateExplicitCodec( - InputSpec input_spec, size_t max_flattened_domain_size) { - // Check that the combined size of the string domains is less than the - // maximum allowed size. +absl::Status CodecFactory::ValidateExplicitCodecInputSpec( + const InputSpec& input_spec, size_t max_flattened_domain_size) { size_t flattened_domain_size = 1; for (const auto& spec : input_spec.group_by_vector_specs()) { flattened_domain_size *= spec.domain_spec().string_values().values_size(); @@ -385,6 +383,11 @@ absl::StatusOr> CodecFactory::CreateExplicitCodec( "Global output domain size exceeds maximum threshold."); } } + return absl::OkStatus(); +} + +absl::StatusOr> CodecFactory::CreateExplicitCodec( + InputSpec input_spec) { // Check that specs include at least one metric vector. if (input_spec.metric_vector_specs().empty()) { return absl::InvalidArgumentError( diff --git a/willow/input_encoding/codec_factory.h b/willow/input_encoding/codec_factory.h index dc88727..5ef4bb6 100644 --- a/willow/input_encoding/codec_factory.h +++ b/willow/input_encoding/codec_factory.h @@ -34,7 +34,12 @@ class CodecFactory { public: // Creates an instance of ExplicitCodec. static absl::StatusOr> CreateExplicitCodec( - InputSpec input_spec, + InputSpec input_spec); + + // Check that the combined size of the string domains is less than the + // maximum allowed size. + static absl::Status ValidateExplicitCodecInputSpec( + const InputSpec& input_spec, size_t max_flattened_domain_size = kMaxGlobalOutputDomainSize); }; diff --git a/willow/input_encoding/explicit_codec_test.cc b/willow/input_encoding/explicit_codec_test.cc index 80e810a..eb46406 100644 --- a/willow/input_encoding/explicit_codec_test.cc +++ b/willow/input_encoding/explicit_codec_test.cc @@ -270,7 +270,7 @@ TEST(CodecFactoryTest, ValidateInputAndSpecGlobalDomainSizeExceeded) { ->add_values(std::to_string(i)); } - EXPECT_THAT(CodecFactory::CreateExplicitCodec(input_spec), + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Global output domain size exceeds"))); } @@ -292,10 +292,10 @@ TEST(CodecFactoryTest, ValidateInputAndSpecCustomGlobalDomainSize) { group_by_spec->mutable_domain_spec()->mutable_string_values()->add_values( "b"); // Domain size is 2. - EXPECT_THAT(CodecFactory::CreateExplicitCodec(input_spec, 1), + EXPECT_THAT(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 1), StatusIs(absl::StatusCode::kInvalidArgument, HasSubstr("Global output domain size exceeds"))); - SECAGG_EXPECT_OK(CodecFactory::CreateExplicitCodec(input_spec, 2)); + SECAGG_EXPECT_OK(CodecFactory::ValidateExplicitCodecInputSpec(input_spec, 2)); } TEST(CodecFactoryTest, EncodeSimpleGroupBy) {