Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions extension/llm/runner/llm_runner_helper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -266,10 +266,15 @@ std::unique_ptr<TextLLMRunner> create_text_llm_runner(
// Create IOManager
std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>(*module);

// Read vocab_size for Sampler
int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize));
float init_temp = temperature == -1.0f ? 0.0f : temperature;
auto sampler = std::make_unique<Sampler>(vocab_size, init_temp);

// Create text_decoder_runner
ET_LOG(Info, "Using method: %s", method_name.c_str());
auto text_decoder_runner = std::make_unique<TextDecoderRunner>(
module.get(), io_manager.get(), method_name);
module.get(), io_manager.get(), method_name, std::move(sampler));
Comment on lines +269 to +277

// Create text_prefiller
auto text_prefiller = std::make_unique<TextPrefiller>(
Expand Down Expand Up @@ -334,9 +339,13 @@ std::unique_ptr<MultimodalRunner> create_multimodal_runner(
// Create IOManager
std::unique_ptr<IOManager> io_manager = std::make_unique<IOManager>(*module);

// Read vocab_size for Sampler
int32_t vocab_size = static_cast<int32_t>(metadata.at(kVocabSize));
auto sampler = std::make_unique<Sampler>(vocab_size, 0.0f); // Default temp

// Create text_decoder_runner
auto text_decoder_runner =
std::make_unique<MultimodalDecoderRunner>(module.get(), io_manager.get());
auto text_decoder_runner = std::make_unique<MultimodalDecoderRunner>(
module.get(), io_manager.get(), "forward", std::move(sampler));

// Create multimodal_prefiller
auto multimodal_prefiller = std::make_unique<MultimodalPrefiller>(
Expand Down
12 changes: 10 additions & 2 deletions extension/llm/runner/multimodal_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,16 @@ namespace executorch::extension::llm {
class ET_EXPERIMENTAL MultimodalDecoderRunner
: public executorch::extension::llm::TextDecoderRunner {
public:
/**
 * Decoder runner for multimodal models.
 *
 * @param module Non-owning pointer to the loaded Module.
 * @param io_manager Non-owning pointer to the IOManager.
 * @param method_name Name of the method to execute; defaults to "forward".
 * @param sampler Optional pre-configured Sampler, forwarded to the base
 *     TextDecoderRunner; may be nullptr (base falls back to a per-call
 *     temporary sampler).
 */
explicit MultimodalDecoderRunner(
    Module* module,
    IOManager* io_manager,
    std::string method_name = "forward",
    std::unique_ptr<Sampler> sampler = nullptr)
    : TextDecoderRunner(
          module,
          io_manager,
          std::move(method_name),
          std::move(sampler)) {}

/**
* Step the LLM Decoder with the given tokens and start position.
Expand Down
58 changes: 58 additions & 0 deletions extension/llm/runner/test/test_text_decoder_runner.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -155,6 +155,64 @@
EXPECT_LT(token, 4);
}

// Greedy sampling (temp=0) through an injected Sampler must be
// deterministic: repeated calls return the argmax index every time.
TEST_F(TextDecoderRunnerTest, LogitsToTokenWithInjectedSampler) {
  auto greedy_sampler =
      std::make_unique<executorch::extension::llm::Sampler>(4, 0.0f);
  auto runner = std::make_unique<TextDecoderRunner>(
      nullptr, nullptr, "forward", std::move(greedy_sampler));

  TensorFactory<executorch::aten::ScalarType::Float> tf;
  for (int call = 0; call < 2; ++call) {
    auto logits = tf.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
    // Index 2 holds the largest logit (0.8f).
    EXPECT_EQ(runner->logits_to_token(logits, 0.0f), 2);
  }
}

// Switching temperature through logits_to_token(): temp=0 gives argmax,
// a positive temperature samples stochastically, and temp=0 afterwards
// restores deterministic argmax behavior on the same injected Sampler.
TEST_F(TextDecoderRunnerTest, LogitsToTokenInjectedSamplerTemperatureSwitch) {
  auto runner = std::make_unique<TextDecoderRunner>(
      nullptr,
      nullptr,
      "forward",
      std::make_unique<executorch::extension::llm::Sampler>(4, 0.0f));

  TensorFactory<executorch::aten::ScalarType::Float> tf;

  // Greedy pass: index 2 holds the max logit.
  auto greedy_logits = tf.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
  EXPECT_EQ(runner->logits_to_token(greedy_logits, 0.0f), 2);

  // Stochastic pass: any token inside [0, vocab_size) is acceptable.
  auto stochastic_logits = tf.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
  const int32_t sampled = runner->logits_to_token(stochastic_logits, 1.0f);
  EXPECT_GE(sampled, 0);
  EXPECT_LT(sampled, 4);

  // Greedy again: must return to deterministic argmax.
  auto greedy_again = tf.make({1, 4}, {0.1f, 0.2f, 0.8f, 0.4f});
  EXPECT_EQ(runner->logits_to_token(greedy_again, 0.0f), 2);
}

// For rank-3 logits of shape [batch, seq_length, vocab_size], the injected
// Sampler must draw from the last sequence position only.
TEST_F(TextDecoderRunnerTest, LogitsToTokenWithInjectedSampler3D) {
  auto runner = std::make_unique<TextDecoderRunner>(
      nullptr,
      nullptr,
      "forward",
      std::make_unique<executorch::extension::llm::Sampler>(4, 0.0f));

  TensorFactory<executorch::aten::ScalarType::Float> tf;
  // Row 0 is ignored; the max of the final row (0.9f) sits at index 2.
  auto logits = tf.make(
      {1, 2, 4},
      {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.9f, 0.8f});

  EXPECT_EQ(runner->logits_to_token(logits, 0.0f), 2);
}

// Test step() method with all available PTE models
TEST_F(TextDecoderRunnerTest, StepWithAllModels) {
// List of all environment variables for PTE models
Expand Down
6 changes: 4 additions & 2 deletions extension/llm/runner/text_decoder_runner.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ namespace llm {
/**
 * Construct a TextDecoderRunner.
 *
 * @param module Non-owning pointer to the Module used for decoding.
 * @param io_manager Non-owning pointer to the IOManager.
 * @param method_name Name of the method to execute on the module.
 * @param sampler Optional pre-configured Sampler; may be nullptr, in which
 *     case logits_to_token() falls back to constructing a temporary Sampler
 *     on each call.
 */
TextDecoderRunner::TextDecoderRunner(
    Module* module,
    IOManager* io_manager,
    std::string method_name,
    std::unique_ptr<Sampler> sampler)
    : module_(module),
      io_manager_(io_manager),
      method_name_(std::move(method_name)),
      sampler_(std::move(sampler)) {}

// This function is functional, meaning it shouldn't modify any state of the
// input. It should be safe to call multiple times with the same inputs. The
Expand Down
16 changes: 11 additions & 5 deletions extension/llm/runner/text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ class ET_EXPERIMENTAL TextDecoderRunner {
explicit TextDecoderRunner(
Module* module,
IOManager* io_manager,
std::string method_name = "forward");
std::string method_name = "forward",
std::unique_ptr<Sampler> sampler = nullptr);

virtual ~TextDecoderRunner() = default;

Expand Down Expand Up @@ -71,14 +72,18 @@ class ET_EXPERIMENTAL TextDecoderRunner {

/**
 * Sample the next token from the logits tensor.
 *
 * If a Sampler was passed in the constructor, it is reused (its temperature
 * is updated to match the argument). Otherwise a temporary Sampler is
 * created per call.
 *
 * @param logits_tensor The logits tensor.
 * @param temperature The temperature parameter used to control randomness in
 * sampling.
 * @return The next token.
 */
inline int32_t logits_to_token(
    const executorch::aten::Tensor& logits_tensor,
    const float temperature = 0.0f) {
  if (sampler_) {
    // Keep the injected sampler's temperature in sync with the argument
    // before delegating to it.
    sampler_->set_temperature(temperature);
    return ::executorch::extension::llm::sample_from_logits(
        logits_tensor, *sampler_);
  }
  // No injected sampler: fall back to the free function, which builds a
  // temporary Sampler for this call.
  return ::executorch::extension::llm::logits_to_token(
      logits_tensor, temperature);
}
Expand All @@ -94,6 +99,7 @@ class ET_EXPERIMENTAL TextDecoderRunner {
Module* module_;
IOManager* io_manager_;
std::string method_name_;
std::unique_ptr<Sampler> sampler_;
};

} // namespace llm
Expand Down
7 changes: 6 additions & 1 deletion extension/llm/runner/text_llm_runner.cpp
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -235,13 +235,18 @@
// Set ignore_eos based on config
text_token_generator_->set_ignore_eos(config.ignore_eos);

// Use the configuration's temperature
float resolved_temp =
temperature_ == -1.0f ? config.temperature : temperature_;

// Generate max_new_tokens - 1 because prefill already generated 1 token.
auto generate_result = text_token_generator_->generate(
prompt_tokens,
pos_,
max_new_tokens - 1,
temperature_ == -1.0f ? config.temperature : temperature_,
resolved_temp,
wrapped_callback);

if (!generate_result.ok()) {
return generate_result.error();
}
Expand Down
4 changes: 4 additions & 0 deletions extension/llm/sampler/sampler.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
Expand Down Expand Up @@ -44,6 +44,10 @@

Sampler(int32_t vocab_size, float temperature);

/**
 * Update the sampling temperature.
 *
 * A temperature of 0 selects greedy (argmax) sampling, stored internally as
 * inv_temperature_ == 0; any non-zero temperature stores its reciprocal.
 *
 * @param temperature The new temperature; 0 disables stochastic sampling.
 */
void set_temperature(float temperature) {
  // Explicit comparison instead of static_cast<bool>(temperature): same
  // behavior (exact zero disables scaling), clearer intent, and avoids the
  // implicit float->bool conversion flagged by clang-tidy.
  inv_temperature_ = temperature != 0.0f ? 1.0f / temperature : 0.0f;
}

// Enable top-k filtering. k <= 0 or k >= vocab_size disables top-k.
// When top-k is enabled, top-p is ignored — the two modes are mutually
// exclusive in this implementation.
Expand Down
35 changes: 22 additions & 13 deletions extension/llm/sampler/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,19 @@ namespace extension {
namespace llm {

/**
* Sample the next token from the logits tensor.
* Sample the next token from the logits tensor using a pre-configured Sampler.
* @param logits_tensor The logits tensor.
* @param temperature The temperature parameter used to control randomness in
* sampling.
* @param sampler The sampler to use for token selection.
* @return The next token.
*/
inline int32_t logits_to_token(
inline int32_t sample_from_logits(
const executorch::aten::Tensor& logits_tensor,
const float temperature = 0.0f) {
Sampler& sampler) {
int32_t result = 0;

// Create a minimal context for error handling in ET_SWITCH
struct {
[[noreturn]] void fail(torch::executor::Error /* error */) {
ET_CHECK_MSG(false, "Unsupported dtype in logits_to_token");
ET_CHECK_MSG(false, "Unsupported dtype in sample_from_logits");
}
} ctx;

Expand All @@ -41,25 +39,36 @@ inline int32_t logits_to_token(
UInt16,
logits_tensor.scalar_type(),
ctx,
"logits_to_token",
"sample_from_logits",
CTYPE,
[&]() {
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
// vocab_size], get the last logits, sample and return. Else the model
// outputs the last logit, directly sample and return.
auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
Comment on lines 24 to 46
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
logits += (num_tokens - 1) * vocab_size;
}
// @lint-ignore CLANGTIDY facebook-hte-Deprecated
Sampler sampler(vocab_size, temperature);
result = sampler.sample(logits);
});
Comment on lines 24 to 52
return result;
}

/**
 * Sample the next token from the logits tensor.
 *
 * Convenience wrapper: builds a one-shot Sampler for the requested
 * temperature and delegates to sample_from_logits().
 *
 * @param logits_tensor The logits tensor.
 * @param temperature The temperature parameter used to control randomness in
 * sampling.
 * @return The next token.
 */
inline int32_t logits_to_token(
    const executorch::aten::Tensor& logits_tensor,
    const float temperature = 0.0f) {
  // The vocabulary size is the extent of the trailing dimension.
  const ssize_t vocab_size = logits_tensor.size(logits_tensor.dim() - 1);
  // @lint-ignore CLANGTIDY facebook-hte-Deprecated
  Sampler one_shot_sampler(vocab_size, temperature);
  return sample_from_logits(logits_tensor, one_shot_sampler);
}

} // namespace llm
} // namespace extension
} // namespace executorch
Loading