Allow TextDecoderRunner to reuse a persistent Sampler #19351
kirklandsign wants to merge 1 commit into main from
Conversation
Summary: TextDecoderRunner now optionally takes a `std::unique_ptr<Sampler>` in its constructor, owning the Sampler and reusing it across token generation steps. This avoids creating a temporary Sampler on every `logits_to_token()` call, preserves RNG state across tokens, and lays groundwork for future grammar-constrained decoding.

Changes:
- sampler.h: add `set_temperature()` so the owned Sampler's temperature can be updated per call without reconstructing it
- util.h: extract `sample_from_logits(Tensor, Sampler&)` for dtype dispatch with an existing Sampler; `logits_to_token()` delegates to it
- text_decoder_runner.h/.cpp: constructor takes an optional `unique_ptr<Sampler>`; when set, `logits_to_token()` updates its temperature and reuses it
- multimodal_decoder_runner.h: forward the Sampler to the base-class constructor
- llm_runner_helper.cpp: factory creates a Sampler and passes it to TextDecoderRunner
- text_llm_runner.cpp: temperature resolved to a local variable per call
- test_text_decoder_runner.cpp: add tests for the constructor-injected Sampler

Differential Revision: D103748189
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19351
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New Failures, 1 Cancelled Job, 11 Pending, 11 Unrelated Failures. As of commit 0a1b6b7 with merge base 851cffb.

NEW FAILURES - The following jobs have failed:
CANCELLED JOB - The following job was cancelled. Please retry:
BROKEN TRUNK - The following jobs failed but were present on the merge base: 👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D103748189.
@claude review

This PR needs a
Claude finished @kirklandsign's task in 1m 37s — View job

Code Review: Allow TextDecoderRunner to reuse a persistent Sampler
Overall: Clean, well-structured change. The design is sound — constructor-injected sampler.h —
Pull request overview
This PR updates the LLM decoding path to optionally reuse a persistent Sampler inside TextDecoderRunner, avoiding per-token sampler reconstruction, preserving RNG state across tokens, and enabling future extensions such as grammar-constrained decoding.
Changes:
- Added `Sampler::set_temperature()` and refactored sampling utilities to support reusing an existing `Sampler` (`sample_from_logits()`).
- Extended the `TextDecoderRunner` (and `MultimodalDecoderRunner`) constructors to optionally accept and own a `std::unique_ptr<Sampler>` and reuse it during token selection.
- Updated runner factories to construct and inject a `Sampler`, and added unit tests covering injected-sampler behavior.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| extension/llm/sampler/util.h | Extracts sample_from_logits() and routes logits_to_token() through it. |
| extension/llm/sampler/sampler.h | Adds set_temperature() to update an existing sampler per call. |
| extension/llm/runner/text_llm_runner.cpp | Resolves temperature into a local variable before generation. |
| extension/llm/runner/text_decoder_runner.h | Adds optional owned Sampler and reuses it in logits_to_token(). |
| extension/llm/runner/text_decoder_runner.cpp | Wires new constructor parameter into member initialization. |
| extension/llm/runner/test/test_text_decoder_runner.cpp | Adds tests for constructor-injected sampler reuse and temperature switching. |
| extension/llm/runner/multimodal_decoder_runner.h | Forwards optional sampler/method name to TextDecoderRunner base ctor. |
| extension/llm/runner/llm_runner_helper.cpp | Creates a Sampler in factories and injects it into decoder runners. |
```diff
@@ -41,25 +39,36 @@ inline int32_t logits_to_token(
       UInt16,
       logits_tensor.scalar_type(),
       ctx,
-      "logits_to_token",
+      "sample_from_logits",
       CTYPE,
       [&]() {
         // If the logit_tensor rank is 3, the shape is [batch, seq_length,
         // vocab_size], get the last logits, sample and return. Else the model
         // outputs the last logit, directly sample and return.
         auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
         ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
         if (logits_tensor.dim() == 3) {
           auto num_tokens = logits_tensor.size(1);
           logits += (num_tokens - 1) * vocab_size;
         }
         // @lint-ignore CLANGTIDY facebook-hte-Deprecated
         Sampler sampler(vocab_size, temperature);
         result = sampler.sample(logits);
       });
```
```diff
+  // Read vocab_size for Sampler
+  int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize));
+  float init_temp = temperature == -1.0f ? 0.0f : temperature;
+  auto sampler = std::make_unique<Sampler>(vocab_size, init_temp);

   // Create text_decoder_runner
   ET_LOG(Info, "Using method: %s", method_name.c_str());
   auto text_decoder_runner = std::make_unique<TextDecoderRunner>(
-      module.get(), io_manager.get(), method_name);
+      module.get(), io_manager.get(), method_name, std::move(sampler));
```
```cpp
inline int32_t logits_to_token(
    const executorch::aten::Tensor& logits_tensor,
    const float temperature = 0.0f) {
  if (sampler_) {
    sampler_->set_temperature(temperature);
    return ::executorch::extension::llm::sample_from_logits(
        logits_tensor, *sampler_);
  }
```
Summary:
TextDecoderRunner now optionally takes a std::unique_ptr<Sampler> in its
constructor, owning the Sampler and reusing it across token generation steps.
This avoids creating a temporary Sampler on every logits_to_token() call,
preserving RNG state across tokens and laying groundwork for future
grammar-constrained decoding.

Changes:
- sampler.h: add set_temperature() so the owned Sampler's temperature can be
  updated per call without reconstructing it
- util.h: extract sample_from_logits() for dtype dispatch with
  an existing Sampler; logits_to_token delegates to it
- text_decoder_runner.h/.cpp: constructor takes an optional unique_ptr<Sampler>;
  when set, logits_to_token updates its temperature and reuses it

Differential Revision: D103748189