Skip to content

Commit ac645d4

Browse files
authored
Merge branch 'main' into export-D108082431
2 parents e3636c6 + e257a71 commit ac645d4

6 files changed

Lines changed: 242 additions & 18 deletions

File tree

extension/llm/runner/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -731,7 +731,7 @@ std::unordered_map<std::string, int64_t> get_llm_metadata(
731731
|-----------|------|---------|-------------|
732732
| `max_new_tokens` | `int32_t` | `-1` | Maximum new tokens to generate (-1 = use available context) |
733733
| `seq_len` | `int32_t` | `1024` | Total sequence length including prompt |
734-
| `temperature` | `float` | `0.8f` | Sampling temperature (0.0 = deterministic, 1.0+ = creative) |
734+
| `temperature` | `float` | `0.8f` | Sampling temperature in [0.0, 1.0] (0.0 = deterministic) |
735735
| `echo` | `bool` | `true` | Whether to echo the input prompt |
736736
| `num_bos` | `int8_t` | `1` | Number of beginning-of-sequence tokens |
737737
| `num_eos` | `int8_t` | `1` | Number of end-of-sequence tokens |
@@ -824,7 +824,7 @@ GenerationConfig config;
824824
config.temperature = 0.1f; // Very deterministic
825825
runner->generate(factual_prompt, config, callback);
826826

827-
config.temperature = 1.2f; // Very creative
827+
config.temperature = 1.0f; // Highest supported temperature
828828
runner->generate(creative_prompt, config, callback);
829829
```
830830

extension/llm/runner/test/test_text_llm_runner.cpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,26 @@ TEST_F(RunnerTest, TextTokenGeneratorProcessorChainMasksMultipleTokens) {
709709
EXPECT_EQ(generated_tokens, expected);
710710
}
711711

712+
TEST_F(RunnerTest, TextTokenGeneratorRejectsTemperatureOutOfRange) {
713+
auto tokenizer = createMockTokenizer();
714+
auto text_decoder_runner = createMockTextDecoderRunner();
715+
Stats stats;
716+
auto generator = createTextTokenGenerator(
717+
tokenizer.get(), text_decoder_runner.get(), &stats);
718+
719+
std::vector<uint64_t> tokens = {1, 2, 3};
720+
EXPECT_CALL(*text_decoder_runner, step(_, _)).Times(0);
721+
722+
EXPECT_EQ(
723+
generator->generate(tokens, 3, 3, -0.1f, [](const std::string&) {})
724+
.error(),
725+
Error::InvalidArgument);
726+
EXPECT_EQ(
727+
generator->generate(tokens, 3, 3, 1.1f, [](const std::string&) {})
728+
.error(),
729+
Error::InvalidArgument);
730+
}
731+
712732
// Without any processors, greedy argmax picks token 3 (zero-overhead path).
713733
TEST_F(RunnerTest, TextTokenGeneratorWithoutProcessorPicksArgmax) {
714734
auto tokenizer = createMockTokenizer();

extension/llm/runner/test/test_text_prefiller.cpp

Lines changed: 148 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,12 @@ class TextPrefillerTest : public Test {
8080
::executorch::runtime::Result<uint64_t>,
8181
prefill_chunk,
8282
(std::vector<uint64_t>&, int64_t&),
83-
());
83+
(override));
84+
MOCK_METHOD(
85+
::executorch::runtime::Result<uint64_t>,
86+
prefill_chunk,
87+
(std::vector<uint64_t>&, int64_t&, float),
88+
(override));
8489
};
8590

8691
// Create a mock TextPrefiller
@@ -112,27 +117,145 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
112117
int64_t start_pos = 0;
113118

114119
// Expect prefill_chunk to be called exactly once with the entire prompt
115-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
120+
constexpr float temperature = 0.7f;
121+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(temperature)))
116122
.Times(1)
117-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
123+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float temp) {
118124
// Verify the tokens passed to prefill_chunk
119125
EXPECT_EQ(tokens.size(), prompt_tokens.size());
120126
for (size_t i = 0; i < tokens.size(); i++) {
121127
EXPECT_EQ(tokens[i], prompt_tokens[i]);
122128
}
123129
// Verify the position
124130
EXPECT_EQ(pos, start_pos);
131+
EXPECT_EQ(temp, temperature);
125132
return Result<uint64_t>(42);
126133
});
127134

128135
// Call prefill
129-
auto result = prefiller->prefill(prompt_tokens, start_pos);
136+
auto result = prefiller->prefill(prompt_tokens, start_pos, temperature);
130137

131138
// Verify the result
132139
EXPECT_EQ(result.error(), Error::Ok);
133140
EXPECT_EQ(result.get(), 42);
134141
}
135142

143+
TEST_F(TextPrefillerTest, TwoArgumentPrefillUsesGreedyTemperature) {
144+
auto prefiller = createMockTextPrefiller(10);
145+
146+
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
147+
int64_t start_pos = 0;
148+
149+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(0.0f)))
150+
.Times(1)
151+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
152+
return Result<uint64_t>(42);
153+
});
154+
155+
auto result = prefiller->prefill(prompt_tokens, start_pos);
156+
157+
EXPECT_EQ(result.error(), Error::Ok);
158+
EXPECT_EQ(result.get(), 42);
159+
}
160+
161+
TEST_F(TextPrefillerTest, PrefillAcceptsTemperatureBounds) {
162+
auto prefiller = createMockTextPrefiller(10);
163+
164+
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
165+
int64_t start_pos = 0;
166+
167+
{
168+
InSequence seq;
169+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(0.0f)))
170+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
171+
return Result<uint64_t>(41);
172+
});
173+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(1.0f)))
174+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
175+
return Result<uint64_t>(42);
176+
});
177+
}
178+
179+
auto greedy = prefiller->prefill(prompt_tokens, start_pos, 0.0f);
180+
auto max_temp = prefiller->prefill(prompt_tokens, start_pos, 1.0f);
181+
182+
EXPECT_EQ(greedy.error(), Error::Ok);
183+
EXPECT_EQ(greedy.get(), 41);
184+
EXPECT_EQ(max_temp.error(), Error::Ok);
185+
EXPECT_EQ(max_temp.get(), 42);
186+
}
187+
188+
TEST_F(TextPrefillerTest, PrefillRejectsTemperatureOutOfRange) {
189+
auto prefiller = createMockTextPrefiller(10);
190+
191+
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
192+
int64_t start_pos = 0;
193+
194+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _)).Times(0);
195+
196+
EXPECT_EQ(
197+
prefiller->prefill(prompt_tokens, start_pos, -0.1f).error(),
198+
Error::InvalidArgument);
199+
EXPECT_EQ(
200+
prefiller->prefill(prompt_tokens, start_pos, 1.1f).error(),
201+
Error::InvalidArgument);
202+
}
203+
204+
TEST_F(TextPrefillerTest, TwoArgumentPrefillChunkOverrideStillDispatches) {
205+
class LegacyPrefiller final : public TextPrefiller {
206+
public:
207+
explicit LegacyPrefiller(TextDecoderRunner* text_decoder_runner)
208+
: TextPrefiller(text_decoder_runner, true, true, 10) {}
209+
210+
Result<uint64_t> prefill_chunk(std::vector<uint64_t>&, int64_t&) override {
211+
called = true;
212+
return Result<uint64_t>(42);
213+
}
214+
215+
bool called = false;
216+
};
217+
218+
LegacyPrefiller prefiller(&text_decoder_runner_);
219+
TextPrefiller* base = &prefiller;
220+
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
221+
int64_t start_pos = 0;
222+
223+
auto result = base->prefill_chunk(prompt_tokens, start_pos);
224+
225+
EXPECT_EQ(result.error(), Error::Ok);
226+
EXPECT_EQ(result.get(), 42);
227+
EXPECT_TRUE(prefiller.called);
228+
}
229+
230+
TEST_F(TextPrefillerTest, ChunkedPrefillSamplesOnlyLastChunkWithTemperature) {
231+
auto prefiller = createMockTextPrefiller(3);
232+
233+
std::vector<uint64_t> prompt_tokens = {1, 2, 3, 4, 5, 6, 7, 8};
234+
int64_t start_pos = 0;
235+
constexpr float temperature = 0.9f;
236+
237+
{
238+
InSequence seq;
239+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(0.0f)))
240+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
241+
return Result<uint64_t>(10);
242+
});
243+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(0.0f)))
244+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
245+
return Result<uint64_t>(11);
246+
});
247+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, FloatEq(temperature)))
248+
.WillOnce([](std::vector<uint64_t>&, int64_t&, float) {
249+
return Result<uint64_t>(12);
250+
});
251+
}
252+
253+
auto result = prefiller->prefill(prompt_tokens, start_pos, temperature);
254+
255+
EXPECT_EQ(result.error(), Error::Ok);
256+
EXPECT_EQ(result.get(), 12);
257+
}
258+
136259
// Test that prefill() calls prefill_chunk() multiple times when prompt tokens >
137260
// max_seq_len
138261
TEST_F(
@@ -217,14 +340,14 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
217340
InSequence seq;
218341

219342
// First chunk: tokens [1, 2, 3] - succeeds
220-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
221-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
343+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _))
344+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float) {
222345
return Result<uint64_t>(10);
223346
});
224347

225348
// Second chunk: tokens [4, 5] - fails
226-
EXPECT_CALL(*prefiller, prefill_chunk(_, _))
227-
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos) {
349+
EXPECT_CALL(*prefiller, prefill_chunk(_, _, _))
350+
.WillOnce([&](std::vector<uint64_t>& tokens, int64_t& pos, float) {
228351
return Result<uint64_t>(Error::InvalidArgument);
229352
});
230353
}
@@ -236,6 +359,23 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
236359
EXPECT_EQ(result.error(), Error::InvalidArgument);
237360
}
238361

362+
TEST_F(TextPrefillerTest, PrefillChunkRejectsTemperatureOutOfRange) {
363+
auto prefiller = createTextPrefiller(10, true, true);
364+
365+
std::vector<uint64_t> prompt_tokens = {1, 2, 3};
366+
int64_t start_pos = 0;
367+
368+
EXPECT_CALL(text_decoder_runner_, step(_, _)).Times(0);
369+
370+
EXPECT_EQ(
371+
prefiller->prefill_chunk(prompt_tokens, start_pos, -0.1f).error(),
372+
Error::InvalidArgument);
373+
EXPECT_EQ(
374+
prefiller->prefill_chunk(prompt_tokens, start_pos, 1.1f).error(),
375+
Error::InvalidArgument);
376+
EXPECT_EQ(start_pos, 0);
377+
}
378+
239379
// Test that prefill_chunk() works correctly with parallel prefill enabled
240380
TEST_F(TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
241381
// Create a TextPrefiller with parallel prefill enabled

extension/llm/runner/text_prefiller.cpp

Lines changed: 37 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,19 @@ TextPrefiller::TextPrefiller(
2929
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
3030
std::vector<uint64_t>& prompt_tokens,
3131
int64_t& start_pos) {
32+
return prefill(prompt_tokens, start_pos, 0.0f);
33+
}
34+
35+
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
36+
std::vector<uint64_t>& prompt_tokens,
37+
int64_t& start_pos,
38+
float temperature) {
3239
ET_CHECK_MSG(!prompt_tokens.empty(), "Prompt cannot be null");
40+
ET_CHECK_OR_RETURN_ERROR(
41+
temperature >= 0.0f && temperature <= 1.0f,
42+
InvalidArgument,
43+
"Temperature must be in [0, 1], got %f",
44+
static_cast<double>(temperature));
3345
if (!text_decoder_runner_->is_method_loaded()) {
3446
ET_CHECK_OK_OR_RETURN_ERROR(text_decoder_runner_->load());
3547
}
@@ -54,8 +66,14 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
5466
num_tokens_to_prefill_with,
5567
prompt_tokens_to_process.begin());
5668

57-
// Process this chunk
58-
auto chunk_result = prefill_chunk(prompt_tokens_to_process, start_pos);
69+
// Only the final chunk samples the first generated token.
70+
const bool is_last_chunk =
71+
num_tokens_to_process + num_tokens_to_prefill_with >=
72+
num_prompt_tokens;
73+
auto chunk_result = prefill_chunk(
74+
prompt_tokens_to_process,
75+
start_pos,
76+
is_last_chunk ? temperature : 0.0f);
5977
ET_CHECK_OK_OR_RETURN_ERROR(chunk_result.error());
6078
cur_token = chunk_result.get();
6179

@@ -65,13 +83,25 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill(
6583
return cur_token;
6684
} else {
6785
// If prompt tokens don't exceed max_seq_len_, process them directly
68-
return prefill_chunk(prompt_tokens, start_pos);
86+
return prefill_chunk(prompt_tokens, start_pos, temperature);
6987
}
7088
}
7189

7290
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
7391
std::vector<uint64_t>& prompt_tokens,
7492
int64_t& start_pos) {
93+
return prefill_chunk(prompt_tokens, start_pos, 0.0f);
94+
}
95+
96+
::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
97+
std::vector<uint64_t>& prompt_tokens,
98+
int64_t& start_pos,
99+
float temperature) {
100+
ET_CHECK_OR_RETURN_ERROR(
101+
temperature >= 0.0f && temperature <= 1.0f,
102+
InvalidArgument,
103+
"Temperature must be in [0, 1], got %f",
104+
static_cast<double>(temperature));
75105
// enable_parallel_prefill_ may be set even when not using kv cache
76106
// When kv cache is not used, start pos is ignored
77107
int32_t num_prompt_tokens = prompt_tokens.size();
@@ -92,7 +122,8 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
92122
Info, "Prefill token result numel(): %zu", outputs_res.get().numel());
93123

94124
start_pos += num_prompt_tokens;
95-
cur_token = text_decoder_runner_->logits_to_token(outputs_res.get());
125+
cur_token =
126+
text_decoder_runner_->logits_to_token(outputs_res.get(), temperature);
96127
} else { // sequential prefill
97128
int64_t pos = 0; // position in the sequence
98129
// NOLINTNEXTLINE(facebook-hte-ParameterUncheckedArrayBounds)
@@ -128,7 +159,8 @@ ::executorch::runtime::Result<uint64_t> TextPrefiller::prefill_chunk(
128159
start_pos++;
129160
}
130161

131-
cur_token = text_decoder_runner_->logits_to_token(logits_tensor);
162+
cur_token =
163+
text_decoder_runner_->logits_to_token(logits_tensor, temperature);
132164
}
133165
return cur_token;
134166
}

extension/llm/runner/text_prefiller.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,43 @@ class ET_EXPERIMENTAL TextPrefiller {
3232
* tokenizer.
3333
* @param start_pos The starting position in KV cache of the input in the LLM
3434
* Module.
35+
* @note Equivalent to `prefill(prompt_tokens, start_pos, 0.0f)`.
3536
* @return The next token of the LLM Module after prefill.
3637
*/
3738
virtual ::executorch::runtime::Result<uint64_t> prefill(
3839
std::vector<uint64_t>& prompt_tokens,
3940
int64_t& start_pos);
4041

42+
/**
43+
* Like `prefill(prompt_tokens, start_pos)`, but samples the first generated
44+
* token with `temperature` in [0.0, 1.0].
45+
*/
46+
virtual ::executorch::runtime::Result<uint64_t> prefill(
47+
std::vector<uint64_t>& prompt_tokens,
48+
int64_t& start_pos,
49+
float temperature);
50+
4151
/**
4252
* Helper method to prefill a chunk of tokens.
4353
* @param prompt_tokens The chunk of text prompt tokens to process.
4454
* @param start_pos The starting position in KV cache of the input in the LLM
4555
* Module.
56+
* @note Equivalent to `prefill_chunk(prompt_tokens, start_pos, 0.0f)`.
4657
* @return The next token of the LLM Module after prefilling this chunk.
4758
*/
4859
virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
4960
std::vector<uint64_t>& prompt_tokens,
5061
int64_t& start_pos);
5162

63+
/**
64+
* Like `prefill_chunk(prompt_tokens, start_pos)`, but samples the produced
65+
* token with `temperature` in [0.0, 1.0].
66+
*/
67+
virtual ::executorch::runtime::Result<uint64_t> prefill_chunk(
68+
std::vector<uint64_t>& prompt_tokens,
69+
int64_t& start_pos,
70+
float temperature);
71+
5272
/**
5373
* Load the necessary resources for the TextPrefiller.
5474
* This method should be called before using the prefill methods.

0 commit comments

Comments
 (0)