Allow TextDecoderRunner to reuse a persistent Sampler#19351

Open
kirklandsign wants to merge 1 commit into main from export-D103748189

Conversation

@kirklandsign
Contributor

Summary:
TextDecoderRunner now optionally takes a std::unique_ptr&lt;Sampler&gt; in its
constructor, owning the Sampler and reusing it across token generation steps.
This avoids creating a temporary Sampler on every logits_to_token() call,
preserving RNG state across tokens and laying groundwork for future grammar-
constrained decoding.

Changes:

  • sampler.h: add set_temperature() so the owned Sampler's temperature can be
    updated per-call without reconstructing it
  • util.h: extract sample_from_logits(Tensor, Sampler&) for dtype dispatch with
    an existing Sampler; logits_to_token delegates to it
  • text_decoder_runner.h/cpp: ctor takes optional unique_ptr&lt;Sampler&gt;; when set,
    logits_to_token updates its temperature and reuses it
  • multimodal_decoder_runner.h: forward Sampler to base class ctor
  • llm_runner_helper.cpp: factory creates Sampler and passes to TextDecoderRunner
  • text_llm_runner.cpp: temperature resolved to local variable per-call
  • test_text_decoder_runner.cpp: add tests for ctor-injected Sampler

Differential Revision: D103748189
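The ownership model described in the summary can be sketched as a minimal standalone mock. Only the names TextDecoderRunner, Sampler, set_temperature(), and logits_to_token() come from the PR; the signatures, the always-greedy sampling, and the construction counter are simplifying assumptions, not the real ExecuTorch API.

```cpp
#include <cassert>
#include <memory>

// Counts Sampler constructions so the demo can show the difference between
// the persistent path (one construction) and the legacy per-call path.
static int g_sampler_ctor_calls = 0;

struct Sampler {
  int vocab_size;
  float temperature;
  Sampler(int v, float t) : vocab_size(v), temperature(t) {
    ++g_sampler_ctor_calls;
  }
  void set_temperature(float t) { temperature = t; }
  int sample(const float* logits) const {
    // Always greedy argmax in this sketch; the real sampler also implements
    // temperature scaling and top-p.
    int best = 0;
    for (int i = 1; i < vocab_size; ++i) {
      if (logits[i] > logits[best]) best = i;
    }
    return best;
  }
};

struct TextDecoderRunner {
  std::unique_ptr<Sampler> sampler_;  // may be null: legacy fallback path
  explicit TextDecoderRunner(std::unique_ptr<Sampler> sampler = nullptr)
      : sampler_(std::move(sampler)) {}

  int logits_to_token(const float* logits, int vocab_size, float temperature) {
    if (sampler_) {
      sampler_->set_temperature(temperature);  // reuse: no reconstruction
      return sampler_->sample(logits);
    }
    Sampler tmp(vocab_size, temperature);  // old behavior: temporary per call
    return tmp.sample(logits);
  }
};
```

With a persistent sampler injected, repeated logits_to_token() calls construct no new Sampler; the default-constructed runner rebuilds one every call, which is exactly the overhead the PR removes.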

Copilot AI review requested due to automatic review settings May 6, 2026 23:03
@pytorch-bot

pytorch-bot Bot commented May 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/19351

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Cancelled Job, 11 Pending, 11 Unrelated Failures

As of commit 0a1b6b7 with merge base 851cffb:

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

BROKEN TRUNK - The following jobs failed but were present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla Bot added the CLA Signed label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) on May 6, 2026
@meta-codesync
Contributor

meta-codesync Bot commented May 6, 2026

@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D103748189.

@kirklandsign
Contributor Author

@claude review

@github-actions

github-actions Bot commented May 6, 2026

This PR needs a release notes: label

If your change should be included in the release notes (i.e. would users of this library care about this change?), please use a label starting with release notes:. This helps us keep track and include your important work in the next release notes.

To add a label, you can comment to pytorchbot, for example
@pytorchbot label "release notes: none"

For more information, see
https://github.com/pytorch/pytorch/wiki/PyTorch-AutoLabel-Bot#why-categorize-for-release-notes-and-how-does-it-work.

@claude

claude Bot commented May 6, 2026

Claude finished @kirklandsign's task in 1m 37s


Code Review: Allow TextDecoderRunner to reuse a persistent Sampler

  • Read changed files and understand the design
  • Review sampler.h changes
  • Review util.h changes
  • Review text_decoder_runner.h/cpp changes
  • Review multimodal_decoder_runner.h changes
  • Review llm_runner_helper.cpp changes
  • Review text_llm_runner.cpp changes
  • Review test file
  • Post review feedback

Overall: Clean, well-structured change. The design is sound — constructor-injected Sampler ownership with a fallback to per-call temporary creation preserves backward compatibility while enabling RNG state persistence. A few minor observations below.


sampler.h — set_temperature()

The new set_temperature() at line 47 uses static_cast&lt;bool&gt;(temperature) to match the constructor pattern in sampler.cpp:179. This is fine and consistent: a reader might initially expect temperature != 0.0f, but the static_cast&lt;bool&gt; form is idiomatic for this pattern in the codebase (it matches both constructors), so no issue.

One nit: the 0.0f literal in set_temperature vs 0 (no f suffix) in the constructor initializer lists (sampler.cpp:179,185). Functionally identical but worth noting for consistency.
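A sketch of what such a setter looks like, keeping the static_cast&lt;bool&gt; pattern the review describes. The member names (apply_softmax_) are assumptions about the real class, not copied from it.

```cpp
#include <cassert>

// Illustrative-only sketch of set_temperature() mirroring the constructor's
// static_cast<bool>(temperature) test; member names are assumptions.
class SamplerSketch {
 public:
  explicit SamplerSketch(float temperature)
      : temperature_(temperature),
        apply_softmax_(static_cast<bool>(temperature)) {}

  void set_temperature(float temperature) {
    temperature_ = temperature;
    // static_cast<bool>(x) is true for any nonzero x, so this is
    // equivalent to `temperature != 0.0f`.
    apply_softmax_ = static_cast<bool>(temperature);
  }

  bool apply_softmax() const { return apply_softmax_; }
  float temperature() const { return temperature_; }

 private:
  float temperature_;
  bool apply_softmax_;
};
```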

util.h — sample_from_logits() extraction

Good refactor. The extracted function cleanly separates dtype dispatch from Sampler construction. The retained logits_to_token() wrapper is minimal and correctly delegates.

text_decoder_runner.h — logits_to_token() dispatch

if (sampler_) {
  sampler_->set_temperature(temperature);
  return ::executorch::extension::llm::sample_from_logits(logits_tensor, *sampler_);
}
return ::executorch::extension::llm::logits_to_token(logits_tensor, temperature);

This is clean. The only subtlety: when sampler_ is set, the vocab_size used for sampling is fixed at construction time (in llm_runner_helper.cpp). The fallback path (logits_to_token) reads vocab_size from the tensor shape on every call. If the model's output vocab_size ever differed from metadata.at(kVocabSize), the persistent Sampler could misbehave — but this is an extremely unlikely edge case and not a real concern for production models.
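If that edge case ever mattered, it could be guarded with a cheap comparison; the helper below is an illustrative suggestion, not code from the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Illustrative guard for the edge case above (not from the PR): the injected
// Sampler's vocab_size is fixed at construction, while the fallback path
// reads it from the tensor's trailing dimension on every call.
inline bool vocab_size_matches(std::ptrdiff_t tensor_vocab_dim,
                               std::int32_t sampler_vocab_size) {
  return tensor_vocab_dim == static_cast<std::ptrdiff_t>(sampler_vocab_size);
}
```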

llm_runner_helper.cpp — Sampler factory

int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize));
float init_temp = temperature == -1.0f ? 0.0f : temperature;
auto sampler = std::make_unique<Sampler>(vocab_size, init_temp);

The -1.0f sentinel handling here mirrors what text_llm_runner.cpp does for the generate loop. The initial temperature doesn't matter much since logits_to_token() calls set_temperature() before every sample, but setting a reasonable initial value is still good practice.

For the multimodal path (line 344), the sampler is created with hardcoded 0.0f temperature. This is fine since it'll be overwritten per-call.
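The sentinel resolution discussed above amounts to a one-liner. The helper name is illustrative; the PR inlines this expression at the call sites.

```cpp
#include <cassert>

// -1.0f means "temperature not specified"; fall back to greedy (0.0f).
// The function name is an assumption for illustration only.
inline float resolve_temperature(float temperature) {
  return temperature == -1.0f ? 0.0f : temperature;
}
```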

text_llm_runner.cpp

Extracting resolved_temp to a local variable is a minor readability improvement. No concerns.

multimodal_decoder_runner.h

The constructor signature expansion correctly forwards all parameters to the base class. The default sampler = nullptr preserves backward compatibility.

Tests

Good coverage:

  • Basic injected sampler with argmax (greedy)
  • Temperature switching on a persistent sampler
  • 3D tensor handling with injected sampler

One observation: the tests pass nullptr for both module and io_manager when constructing TextDecoderRunner with an injected sampler — this works because logits_to_token doesn't touch those members. This is fine for unit testing the sampling path but worth noting that it implicitly tests that the sampling logic is decoupled from the model execution logic.
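Two of the listed cases can be sketched with plain asserts. The real suite uses ExecuTorch's test framework and Tensor types, so every name and signature here is a simplifying assumption.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Greedy stand-in for the real Sampler; always argmax regardless of
// temperature, which is enough to exercise the two cases below.
struct GreedySampler {
  int vocab_size;
  float temperature;
  GreedySampler(int v, float t) : vocab_size(v), temperature(t) {}
  void set_temperature(float t) { temperature = t; }
  int sample(const float* logits) const {
    int best = 0;
    for (int i = 1; i < vocab_size; ++i) {
      if (logits[i] > logits[best]) best = i;
    }
    return best;
  }
};

// "Basic injected sampler with argmax": temperature 0 picks the max logit.
inline int argmax_token(GreedySampler& s, const std::vector<float>& logits) {
  s.set_temperature(0.0f);
  return s.sample(logits.data());
}

// "3D tensor handling": logits shaped [batch=1, seq_len, vocab] flattened
// row-major; only the last position's row is sampled.
inline int last_position_token(GreedySampler& s,
                               const std::vector<float>& flat,
                               int seq_len) {
  const float* last =
      flat.data() + static_cast<std::size_t>(seq_len - 1) * s.vocab_size;
  return s.sample(last);
}
```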


Summary

This is a well-designed incremental change that:

  1. Preserves full backward compatibility (nullptr default for sampler)
  2. Enables RNG state persistence across tokens (key for reproducible generation)
  3. Lays groundwork for grammar-constrained decoding (persistent sampler state)
  4. Has good test coverage for the new code paths

No blocking issues. LGTM.


Contributor

Copilot AI left a comment


Pull request overview

This PR updates the LLM decoding path to optionally reuse a persistent Sampler inside TextDecoderRunner, avoiding per-token sampler reconstruction, preserving RNG state across tokens, and enabling future extensions like grammar-constrained decoding.

Changes:

  • Added Sampler::set_temperature() and refactored sampling utilities to support reusing an existing Sampler (sample_from_logits()).
  • Extended TextDecoderRunner (and MultimodalDecoderRunner) constructors to optionally accept and own a std::unique_ptr<Sampler> and reuse it during token selection.
  • Updated runner factories to construct and inject a Sampler, and added unit tests covering injected sampler behavior.

Reviewed changes

Copilot reviewed 8 out of 8 changed files in this pull request and generated 4 comments.

Summary per file:

  • extension/llm/sampler/util.h — Extracts sample_from_logits() and routes logits_to_token() through it.
  • extension/llm/sampler/sampler.h — Adds set_temperature() to update an existing sampler per call.
  • extension/llm/runner/text_llm_runner.cpp — Resolves temperature into a local variable before generation.
  • extension/llm/runner/text_decoder_runner.h — Adds optional owned Sampler and reuses it in logits_to_token().
  • extension/llm/runner/text_decoder_runner.cpp — Wires new constructor parameter into member initialization.
  • extension/llm/runner/test/test_text_decoder_runner.cpp — Adds tests for constructor-injected sampler reuse and temperature switching.
  • extension/llm/runner/multimodal_decoder_runner.h — Forwards optional sampler/method name to TextDecoderRunner base ctor.
  • extension/llm/runner/llm_runner_helper.cpp — Creates a Sampler in factories and injects it into decoder runners.


Comment on lines 24 to 52

@@ -41,25 +39,36 @@ inline int32_t logits_to_token(
     UInt16,
     logits_tensor.scalar_type(),
     ctx,
-    "logits_to_token",
+    "sample_from_logits",
     CTYPE,
     [&]() {
       // If the logit_tensor rank is 3, the shape is [batch, seq_length,
       // vocab_size], get the last logits, sample and return. Else the model
       // outputs the last logit, directly sample and return.
       auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
       ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
       if (logits_tensor.dim() == 3) {
         auto num_tokens = logits_tensor.size(1);
         logits += (num_tokens - 1) * vocab_size;
       }
-      // @lint-ignore CLANGTIDY facebook-hte-Deprecated
-      Sampler sampler(vocab_size, temperature);
       result = sampler.sample(logits);
     });
Comment on lines +269 to +277

 // Read vocab_size for Sampler
 int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize));
 float init_temp = temperature == -1.0f ? 0.0f : temperature;
 auto sampler = std::make_unique<Sampler>(vocab_size, init_temp);

 // Create text_decoder_runner
 ET_LOG(Info, "Using method: %s", method_name.c_str());
 auto text_decoder_runner = std::make_unique<TextDecoderRunner>(
-    module.get(), io_manager.get(), method_name);
+    module.get(), io_manager.get(), method_name, std::move(sampler));
Comment on lines 79 to +86
inline int32_t logits_to_token(
const executorch::aten::Tensor& logits_tensor,
const float temperature = 0.0f) {
if (sampler_) {
sampler_->set_temperature(temperature);
return ::executorch::extension::llm::sample_from_logits(
logits_tensor, *sampler_);
}

Labels

CLA Signed · fb-exported · meta-exported


2 participants