@@ -80,7 +80,12 @@ class TextPrefillerTest : public Test {
8080 ::executorch::runtime::Result<uint64_t >,
8181 prefill_chunk,
8282 (std::vector<uint64_t >&, int64_t &),
83- ());
83+ (override ));
84+ MOCK_METHOD (
85+ ::executorch::runtime::Result<uint64_t >,
86+ prefill_chunk,
87+ (std::vector<uint64_t >&, int64_t &, float ),
88+ (override ));
8489 };
8590
8691 // Create a mock TextPrefiller
@@ -112,27 +117,145 @@ TEST_F(TextPrefillerTest, PrefillCallsPrefillChunkOnceWhenPromptFits) {
112117 int64_t start_pos = 0 ;
113118
114119 // Expect prefill_chunk to be called exactly once with the entire prompt
115- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
120+ constexpr float temperature = 0 .7f ;
121+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (temperature)))
116122 .Times (1 )
117- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
123+ .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos, float temp ) {
118124 // Verify the tokens passed to prefill_chunk
119125 EXPECT_EQ (tokens.size (), prompt_tokens.size ());
120126 for (size_t i = 0 ; i < tokens.size (); i++) {
121127 EXPECT_EQ (tokens[i], prompt_tokens[i]);
122128 }
123129 // Verify the position
124130 EXPECT_EQ (pos, start_pos);
131+ EXPECT_EQ (temp, temperature);
125132 return Result<uint64_t >(42 );
126133 });
127134
128135 // Call prefill
129- auto result = prefiller->prefill (prompt_tokens, start_pos);
136+ auto result = prefiller->prefill (prompt_tokens, start_pos, temperature );
130137
131138 // Verify the result
132139 EXPECT_EQ (result.error (), Error::Ok);
133140 EXPECT_EQ (result.get (), 42 );
134141}
135142
143+ TEST_F (TextPrefillerTest, TwoArgumentPrefillUsesGreedyTemperature) {
144+ auto prefiller = createMockTextPrefiller (10 );
145+
146+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
147+ int64_t start_pos = 0 ;
148+
149+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (0 .0f )))
150+ .Times (1 )
151+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
152+ return Result<uint64_t >(42 );
153+ });
154+
155+ auto result = prefiller->prefill (prompt_tokens, start_pos);
156+
157+ EXPECT_EQ (result.error (), Error::Ok);
158+ EXPECT_EQ (result.get (), 42 );
159+ }
160+
161+ TEST_F (TextPrefillerTest, PrefillAcceptsTemperatureBounds) {
162+ auto prefiller = createMockTextPrefiller (10 );
163+
164+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
165+ int64_t start_pos = 0 ;
166+
167+ {
168+ InSequence seq;
169+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (0 .0f )))
170+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
171+ return Result<uint64_t >(41 );
172+ });
173+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (1 .0f )))
174+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
175+ return Result<uint64_t >(42 );
176+ });
177+ }
178+
179+ auto greedy = prefiller->prefill (prompt_tokens, start_pos, 0 .0f );
180+ auto max_temp = prefiller->prefill (prompt_tokens, start_pos, 1 .0f );
181+
182+ EXPECT_EQ (greedy.error (), Error::Ok);
183+ EXPECT_EQ (greedy.get (), 41 );
184+ EXPECT_EQ (max_temp.error (), Error::Ok);
185+ EXPECT_EQ (max_temp.get (), 42 );
186+ }
187+
188+ TEST_F (TextPrefillerTest, PrefillRejectsTemperatureOutOfRange) {
189+ auto prefiller = createMockTextPrefiller (10 );
190+
191+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
192+ int64_t start_pos = 0 ;
193+
194+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, _)).Times (0 );
195+
196+ EXPECT_EQ (
197+ prefiller->prefill (prompt_tokens, start_pos, -0 .1f ).error (),
198+ Error::InvalidArgument);
199+ EXPECT_EQ (
200+ prefiller->prefill (prompt_tokens, start_pos, 1 .1f ).error (),
201+ Error::InvalidArgument);
202+ }
203+
204+ TEST_F (TextPrefillerTest, TwoArgumentPrefillChunkOverrideStillDispatches) {
205+ class LegacyPrefiller final : public TextPrefiller {
206+ public:
207+ explicit LegacyPrefiller (TextDecoderRunner* text_decoder_runner)
208+ : TextPrefiller(text_decoder_runner, true , true , 10 ) {}
209+
210+ Result<uint64_t > prefill_chunk (std::vector<uint64_t >&, int64_t &) override {
211+ called = true ;
212+ return Result<uint64_t >(42 );
213+ }
214+
215+ bool called = false ;
216+ };
217+
218+ LegacyPrefiller prefiller (&text_decoder_runner_);
219+ TextPrefiller* base = &prefiller;
220+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
221+ int64_t start_pos = 0 ;
222+
223+ auto result = base->prefill_chunk (prompt_tokens, start_pos);
224+
225+ EXPECT_EQ (result.error (), Error::Ok);
226+ EXPECT_EQ (result.get (), 42 );
227+ EXPECT_TRUE (prefiller.called );
228+ }
229+
230+ TEST_F (TextPrefillerTest, ChunkedPrefillSamplesOnlyLastChunkWithTemperature) {
231+ auto prefiller = createMockTextPrefiller (3 );
232+
233+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 , 4 , 5 , 6 , 7 , 8 };
234+ int64_t start_pos = 0 ;
235+ constexpr float temperature = 0 .9f ;
236+
237+ {
238+ InSequence seq;
239+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (0 .0f )))
240+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
241+ return Result<uint64_t >(10 );
242+ });
243+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (0 .0f )))
244+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
245+ return Result<uint64_t >(11 );
246+ });
247+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, FloatEq (temperature)))
248+ .WillOnce ([](std::vector<uint64_t >&, int64_t &, float ) {
249+ return Result<uint64_t >(12 );
250+ });
251+ }
252+
253+ auto result = prefiller->prefill (prompt_tokens, start_pos, temperature);
254+
255+ EXPECT_EQ (result.error (), Error::Ok);
256+ EXPECT_EQ (result.get (), 12 );
257+ }
258+
136259// Test that prefill() calls prefill_chunk() multiple times when prompt tokens >
137260// max_seq_len
138261TEST_F (
@@ -217,14 +340,14 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
217340 InSequence seq;
218341
219342 // First chunk: tokens [1, 2, 3] - succeeds
220- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
221- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
343+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, _ ))
344+ .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos, float ) {
222345 return Result<uint64_t >(10 );
223346 });
224347
225348 // Second chunk: tokens [4, 5] - fails
226- EXPECT_CALL (*prefiller, prefill_chunk (_, _))
227- .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos) {
349+ EXPECT_CALL (*prefiller, prefill_chunk (_, _, _ ))
350+ .WillOnce ([&](std::vector<uint64_t >& tokens, int64_t & pos, float ) {
228351 return Result<uint64_t >(Error::InvalidArgument);
229352 });
230353 }
@@ -236,6 +359,23 @@ TEST_F(TextPrefillerTest, PrefillHandlesPrefillChunkErrorsCorrectly) {
236359 EXPECT_EQ (result.error (), Error::InvalidArgument);
237360}
238361
362+ TEST_F (TextPrefillerTest, PrefillChunkRejectsTemperatureOutOfRange) {
363+ auto prefiller = createTextPrefiller (10 , true , true );
364+
365+ std::vector<uint64_t > prompt_tokens = {1 , 2 , 3 };
366+ int64_t start_pos = 0 ;
367+
368+ EXPECT_CALL (text_decoder_runner_, step (_, _)).Times (0 );
369+
370+ EXPECT_EQ (
371+ prefiller->prefill_chunk (prompt_tokens, start_pos, -0 .1f ).error (),
372+ Error::InvalidArgument);
373+ EXPECT_EQ (
374+ prefiller->prefill_chunk (prompt_tokens, start_pos, 1 .1f ).error (),
375+ Error::InvalidArgument);
376+ EXPECT_EQ (start_pos, 0 );
377+ }
378+
239379// Test that prefill_chunk() works correctly with parallel prefill enabled
240380TEST_F (TextPrefillerTest, PrefillChunkWorksWithParallelPrefill) {
241381 // Create a TextPrefiller with parallel prefill enabled
0 commit comments