From 0a1b6b727dba7fc82d31f2c8ff6468d568c69a34 Mon Sep 17 00:00:00 2001 From: Hansong Zhang Date: Wed, 6 May 2026 16:03:45 -0700 Subject: [PATCH] Allow TextDecoderRunner to reuse a persistent Sampler Summary: TextDecoderRunner now optionally takes a std::unique_ptr in its constructor, owning the Sampler and reusing it across token generation steps. This avoids creating a temporary Sampler on every logits_to_token() call, preserving RNG state across tokens and laying groundwork for future grammar- constrained decoding. Changes: - sampler.h: add set_temperature() so the owned Sampler's temperature can be updated per-call without reconstructing it - util.h: extract sample_from_logits(Tensor, Sampler&) for dtype dispatch with an existing Sampler; logits_to_token delegates to it - text_decoder_runner.h/cpp: ctor takes optional unique_ptr; when set, logits_to_token updates its temperature and reuses it - multimodal_decoder_runner.h: forward Sampler to base class ctor - llm_runner_helper.cpp: factory creates Sampler and passes to TextDecoderRunner - text_llm_runner.cpp: temperature resolved to local variable per-call - test_text_decoder_runner.cpp: add tests for ctor-injected Sampler Differential Revision: D103748189 --- extension/llm/runner/llm_runner_helper.cpp | 15 ++++- .../llm/runner/multimodal_decoder_runner.h | 12 +++- .../runner/test/test_text_decoder_runner.cpp | 58 +++++++++++++++++++ extension/llm/runner/text_decoder_runner.cpp | 6 +- extension/llm/runner/text_decoder_runner.h | 16 +++-- extension/llm/runner/text_llm_runner.cpp | 7 ++- extension/llm/sampler/sampler.h | 4 ++ extension/llm/sampler/util.h | 35 ++++++----- 8 files changed, 127 insertions(+), 26 deletions(-) diff --git a/extension/llm/runner/llm_runner_helper.cpp b/extension/llm/runner/llm_runner_helper.cpp index 4d34fd716e3..0744c09e641 100644 --- a/extension/llm/runner/llm_runner_helper.cpp +++ b/extension/llm/runner/llm_runner_helper.cpp @@ -266,10 +266,15 @@ std::unique_ptr 
create_text_llm_runner( // Create IOManager std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>(*module); + // Read vocab_size for Sampler + int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize)); + float init_temp = temperature == -1.0f ? 0.0f : temperature; + auto sampler = std::make_unique<Sampler>(vocab_size, init_temp); + // Create text_decoder_runner ET_LOG(Info, "Using method: %s", method_name.c_str()); auto text_decoder_runner = std::make_unique<TextDecoderRunner>( - module.get(), io_manager.get(), method_name); + module.get(), io_manager.get(), method_name, std::move(sampler)); // Create text_prefiller auto text_prefiller = std::make_unique<TextPrefiller>( @@ -334,9 +339,13 @@ std::unique_ptr<MultimodalRunner> create_multimodal_runner( // Create IOManager std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>(*module); + // Read vocab_size for Sampler + int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize)); + auto sampler = std::make_unique<Sampler>(vocab_size, 0.0f); // Default temp + // Create text_decoder_runner - auto text_decoder_runner = - std::make_unique<MultimodalDecoderRunner>(module.get(), io_manager.get()); + auto text_decoder_runner = std::make_unique<MultimodalDecoderRunner>( + module.get(), io_manager.get(), "forward", std::move(sampler)); + // Create multimodal_prefiller auto multimodal_prefiller = std::make_unique<MultimodalPrefiller>( diff --git a/extension/llm/runner/multimodal_decoder_runner.h b/extension/llm/runner/multimodal_decoder_runner.h index 5773e5ca909..b9be142c30d 100644 --- a/extension/llm/runner/multimodal_decoder_runner.h +++ b/extension/llm/runner/multimodal_decoder_runner.h @@ -15,8 +15,16 @@ namespace executorch::extension::llm { class ET_EXPERIMENTAL MultimodalDecoderRunner : public executorch::extension::llm::TextDecoderRunner { public: - explicit MultimodalDecoderRunner(Module* module, IOManager* io_manager) - : TextDecoderRunner(module, io_manager) {} + explicit MultimodalDecoderRunner( + Module* module, + IOManager* io_manager, + std::string method_name = "forward", + std::unique_ptr<Sampler> sampler = nullptr) + : TextDecoderRunner( + module, + io_manager, + 
std::move(method_name), + std::move(sampler)) {} /** * Step the LLM Decoder with the given tokens and start position. diff --git a/extension/llm/runner/test/test_text_decoder_runner.cpp b/extension/llm/runner/test/test_text_decoder_runner.cpp index 917467e31fd..af66bdc52e1 100644 --- a/extension/llm/runner/test/test_text_decoder_runner.cpp +++ b/extension/llm/runner/test/test_text_decoder_runner.cpp @@ -155,6 +155,64 @@ TEST_F(TextDecoderRunnerTest, LogitsToTokenWithTemperature) { EXPECT_LT(token, 4); } +// Test logits_to_token() with an injected Sampler (greedy, temp=0) +TEST_F(TextDecoderRunnerTest, LogitsToTokenWithInjectedSampler) { + TensorFactory<ScalarType::Float> tf_float; + auto logits = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + + auto sampler = std::make_unique<Sampler>(4, 0.0f); + auto runner = std::make_unique<TextDecoderRunner>( + nullptr, nullptr, "forward", std::move(sampler)); + + int32_t token = runner->logits_to_token(logits, 0.0f); + EXPECT_EQ(token, 2); + + auto logits2 = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + token = runner->logits_to_token(logits2, 0.0f); + EXPECT_EQ(token, 2); +} + +// Test that set_temperature works on an injected Sampler +TEST_F(TextDecoderRunnerTest, LogitsToTokenInjectedSamplerTemperatureSwitch) { + auto sampler = std::make_unique<Sampler>(4, 0.0f); + auto runner = std::make_unique<TextDecoderRunner>( + nullptr, nullptr, "forward", std::move(sampler)); + + TensorFactory<ScalarType::Float> tf_float; + + // temp=0 → argmax + auto logits1 = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + EXPECT_EQ(runner->logits_to_token(logits1, 0.0f), 2); + + // temp=1.0 → stochastic, result must be in valid range + auto logits2 = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + int32_t token = runner->logits_to_token(logits2, 1.0f); + EXPECT_GE(token, 0); + EXPECT_LT(token, 4); + + // temp=0 again → back to argmax + auto logits3 = tf_float.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f}); + EXPECT_EQ(runner->logits_to_token(logits3, 0.0f), 2); +} + +// Test logits_to_token() with an injected Sampler on a 3D tensor 
+TEST_F(TextDecoderRunnerTest, LogitsToTokenWithInjectedSampler3D) { + TensorFactory<ScalarType::Float> tf_float; + auto logits = tf_float.make( + {1, 2, 4}, + { + 0.1f, 0.2f, 0.3f, 0.4f, // first position + 0.5f, 0.6f, 0.9f, 0.8f // last position (used for sampling) + }); + + auto sampler = std::make_unique<Sampler>(4, 0.0f); + auto runner = std::make_unique<TextDecoderRunner>( + nullptr, nullptr, "forward", std::move(sampler)); + + int32_t token = runner->logits_to_token(logits, 0.0f); + EXPECT_EQ(token, 2); +} + // Test step() method with all available PTE models TEST_F(TextDecoderRunnerTest, StepWithAllModels) { // List of all environment variables for PTE models diff --git a/extension/llm/runner/text_decoder_runner.cpp b/extension/llm/runner/text_decoder_runner.cpp index 3eb4e346e05..8d234a4f306 100644 --- a/extension/llm/runner/text_decoder_runner.cpp +++ b/extension/llm/runner/text_decoder_runner.cpp @@ -25,10 +25,12 @@ namespace llm { TextDecoderRunner::TextDecoderRunner( Module* module, IOManager* io_manager, - std::string method_name) + std::string method_name, + std::unique_ptr<Sampler> sampler) : module_(module), io_manager_(io_manager), - method_name_(std::move(method_name)) {} + method_name_(std::move(method_name)), + sampler_(std::move(sampler)) {} // This function is functional, meaning it shouldn't modify any state of the // input. It should be safe to call multiple times with the same inputs. 
The diff --git a/extension/llm/runner/text_decoder_runner.h b/extension/llm/runner/text_decoder_runner.h index 6762b73b7ce..e92fcdaa8f6 100644 --- a/extension/llm/runner/text_decoder_runner.h +++ b/extension/llm/runner/text_decoder_runner.h @@ -23,7 +23,8 @@ class ET_EXPERIMENTAL TextDecoderRunner { explicit TextDecoderRunner( Module* module, IOManager* io_manager, - std::string method_name = "forward"); + std::string method_name = "forward", + std::unique_ptr<Sampler> sampler = nullptr); virtual ~TextDecoderRunner() = default; @@ -71,14 +72,18 @@ class ET_EXPERIMENTAL TextDecoderRunner { /** * Sample the next token from the logits tensor. - * @param logits_tensor The logits tensor. - * @param temperature The temperature parameter used to control randomness in - * sampling. - * @return The next token. + * If a Sampler was passed in the constructor, it is reused (its temperature + * will be updated to match the argument). Otherwise a temporary Sampler is + * created per call. */ inline int32_t logits_to_token( const executorch::aten::Tensor& logits_tensor, const float temperature = 0.0f) { + if (sampler_) { + sampler_->set_temperature(temperature); + return ::executorch::extension::llm::sample_from_logits( + logits_tensor, *sampler_); + } return ::executorch::extension::llm::logits_to_token( logits_tensor, temperature); } @@ -94,6 +99,7 @@ Module* module_; IOManager* io_manager_; std::string method_name_; + std::unique_ptr<Sampler> sampler_; }; } // namespace llm diff --git a/extension/llm/runner/text_llm_runner.cpp b/extension/llm/runner/text_llm_runner.cpp index 160b254460a..2e56e802cbc 100644 --- a/extension/llm/runner/text_llm_runner.cpp +++ b/extension/llm/runner/text_llm_runner.cpp @@ -235,13 +235,18 @@ Error TextLLMRunner::generate( // Set ignore_eos based on config text_token_generator_->set_ignore_eos(config.ignore_eos); + // Use the configuration's temperature + float resolved_temp = + temperature_ == -1.0f ? 
config.temperature : temperature_; + + // Generate max_new_tokens - 1 because prefill already generated 1 token. auto generate_result = text_token_generator_->generate( prompt_tokens, pos_, max_new_tokens - 1, - temperature_ == -1.0f ? config.temperature : temperature_, + resolved_temp, wrapped_callback); + if (!generate_result.ok()) { return generate_result.error(); } diff --git a/extension/llm/sampler/sampler.h b/extension/llm/sampler/sampler.h index 4a480edc1ef..ce5ecc528f3 100644 --- a/extension/llm/sampler/sampler.h +++ b/extension/llm/sampler/sampler.h @@ -44,6 +44,10 @@ class ET_EXPERIMENTAL Sampler { Sampler(int32_t vocab_size, float temperature); + void set_temperature(float temperature) { + inv_temperature_ = static_cast<bool>(temperature) ? 1.0f / temperature : 0.0f; + } + // Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k. // When top-k is enabled, top-p is ignored — the two modes are mutually // exclusive in this implementation. diff --git a/extension/llm/sampler/util.h b/extension/llm/sampler/util.h index 6a3a06355ca..37579d7433e 100644 --- a/extension/llm/sampler/util.h +++ b/extension/llm/sampler/util.h @@ -16,21 +16,19 @@ namespace extension { namespace llm { /** - * Sample the next token from the logits tensor. + * Sample the next token from the logits tensor using a pre-configured Sampler. * @param logits_tensor The logits tensor. - * @param temperature The temperature parameter used to control randomness in - * sampling. + * @param sampler The sampler to use for token selection. * @return The next token. 
*/ -inline int32_t logits_to_token( +inline int32_t sample_from_logits( const executorch::aten::Tensor& logits_tensor, - const float temperature = 0.0f) { + Sampler& sampler) { int32_t result = 0; - // Create a minimal context for error handling in ET_SWITCH struct { [[noreturn]] void fail(torch::executor::Error /* error */) { - ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token"); + ET_CHECK_MSG(false, "Unsupported dtype in sample_from_logits"); } } ctx; @@ -41,25 +39,36 @@ UInt16, logits_tensor.scalar_type(), ctx, - "logits_to_token", + "sample_from_logits", CTYPE, [&]() { - // If the logit_tensor rank is 3, the shape is [batch, seq_length, - // vocab_size], get the last logits, sample and return. Else the model - // outputs the last logit, directly sample and return. auto* logits = logits_tensor.mutable_data_ptr<CTYPE>(); ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); if (logits_tensor.dim() == 3) { auto num_tokens = logits_tensor.size(1); logits += (num_tokens - 1) * vocab_size; } - // @lint-ignore CLANGTIDY facebook-hte-Deprecated - Sampler sampler(vocab_size, temperature); result = sampler.sample(logits); }); return result; } +/** + * Sample the next token from the logits tensor. + * @param logits_tensor The logits tensor. + * @param temperature The temperature parameter used to control randomness in + * sampling. + * @return The next token. + */ +inline int32_t logits_to_token( + const executorch::aten::Tensor& logits_tensor, + const float temperature = 0.0f) { + ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1); + // @lint-ignore CLANGTIDY facebook-hte-Deprecated + Sampler sampler(vocab_size, temperature); + return sample_from_logits(logits_tensor, sampler); +} + } // namespace llm } // namespace extension } // namespace executorch