@@ -57,6 +57,9 @@ public struct GenerateParameters: Sendable {
57
57
/// Step size for processing the prompt
public var prefillStepSize = 512

/// Maximum tokens to generate; `nil` means no limit (see `TokenIterator.next()`)
public var maxTokens: Int?

/// Sampling temperature
public var temperature: Float = 0.6
62
65
@@ -70,9 +73,11 @@ public struct GenerateParameters: Sendable {
70
73
/// Number of recent tokens considered when applying the repetition penalty
public var repetitionContextSize: Int = 20

/// Creates generation parameters.
/// - Parameters:
///   - maxTokens: maximum number of tokens to generate; `nil` means no limit
///   - temperature: sampling temperature
///   - topP: nucleus (top-p) sampling cutoff
///   - repetitionPenalty: optional penalty applied to repeated tokens
///   - repetitionContextSize: window of recent tokens the penalty inspects
public init(
    maxTokens: Int? = nil,
    temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
    repetitionContextSize: Int = 20
) {
    self.maxTokens = maxTokens
    self.temperature = temperature
    self.topP = topP
    self.repetitionPenalty = repetitionPenalty
    // NOTE(review): the pasted diff hunk ends before this line; reconstructed from
    // the parameter list — confirm against the full file.
    self.repetitionContextSize = repetitionContextSize
}
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
111
116
let temp: MLXArray
let topP: MLXArray

/// Creates a top-p (nucleus) sampler.
/// - Parameters:
///   - temperature: sampling temperature applied to the logits
///   - topP: cumulative-probability cutoff — TODO confirm against the sampling
///     method, which is not visible in this chunk
public init(temperature: Float, topP: Float) {
    self.topP = MLXArray(topP)
    self.temp = MLXArray(temperature)
}
@@ -149,7 +154,7 @@ public struct TopPSampler: LogitSampler {
149
154
public struct CategoricalSampler : LogitSampler {
150
155
let temp: MLXArray

/// Creates a categorical sampler over the temperature-scaled logits.
/// - Parameter temperature: sampling temperature
public init(temperature: Float) {
    self.temp = MLXArray(temperature)
}
155
160
@@ -178,7 +183,7 @@ public struct RepetitionContext: LogitProcessor {
178
183
/// Number of tokens to consider for repetition penalty
let repetitionContextSize: Int

/// Creates a repetition-penalty logit processor.
/// - Parameters:
///   - repetitionPenalty: penalty applied to recently seen tokens — TODO confirm
///     exact application against the `process` method (not visible here)
///   - repetitionContextSize: number of most recent tokens considered; must be > 0
public init(repetitionPenalty: Float, repetitionContextSize: Int) {
    precondition(repetitionContextSize > 0)
    self.repetitionPenalty = repetitionPenalty
    self.repetitionContextSize = repetitionContextSize
// NOTE(review): closing brace falls past the pasted hunk; reconstructed.
}
@@ -250,6 +255,9 @@ public struct TokenIterator: Sequence, IteratorProtocol {
250
255
var processor : LogitProcessor ?
251
256
let sampler : LogitSampler
252
257
258
/// Count of tokens emitted so far by `next()`.
var tokenCount = 0
/// Optional generation cap; once reached, `next()` returns `nil`.
let maxTokens: Int?
260
+
253
261
/// Initialize a `TokenIterator` with the given tokens. Note: this has been
254
262
/// replaced with ``init(input:model:cache:parameters:)``.
255
263
///
@@ -269,6 +277,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
269
277
270
278
self . processor = parameters. processor ( )
271
279
self . sampler = parameters. sampler ( )
280
+ self . maxTokens = parameters. maxTokens
272
281
273
282
try prepare ( input: . init( text: y) , windowSize: parameters. prefillStepSize)
274
283
}
@@ -295,6 +304,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
295
304
296
305
self . processor = parameters. processor ( )
297
306
self . sampler = parameters. sampler ( )
307
+ self . maxTokens = parameters. maxTokens
298
308
299
309
try prepare ( input: input, windowSize: parameters. prefillStepSize)
300
310
}
@@ -308,16 +318,19 @@ public struct TokenIterator: Sequence, IteratorProtocol {
308
318
/// Initialize a `TokenIterator` with prepared input and explicit sampling pieces.
/// - Parameters:
///   - input: the language-model input to prefill from
///   - model: the language model to step
///   - cache: optional KV cache; a fresh one is created from the model when `nil`
///   - processor: the logit processor
///   - sampler: the logit sampler
///   - prefillStepSize: optional prefill step size
///   - maxTokens: maximum number of tokens to generate; `nil` means no limit
public init(
    input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
    processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512,
    maxTokens: Int? = nil
) throws {
    self.model = model
    self.y = input.text
    self.cache = cache ?? model.newCache(parameters: nil)

    self.sampler = sampler
    self.processor = processor
    self.maxTokens = maxTokens

    try prepare(input: input, windowSize: prefillStepSize)
}
@@ -365,6 +378,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
365
378
}
366
379
367
380
mutating public func next( ) -> Int ? {
381
+ if let maxTokens, tokenCount >= maxTokens {
382
+ return nil
383
+ }
384
+
368
385
// save current value -- this will be returned
369
386
let previousY = y
370
387
@@ -373,6 +390,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
373
390
y = . init( tokens: token)
374
391
asyncEval ( token)
375
392
393
+ tokenCount += 1
394
+
376
395
return previousY. tokens. item ( Int . self)
377
396
}
378
397
}
@@ -413,24 +432,32 @@ public struct GenerateResult: Sendable {
413
432
/// Output text
public let output: String

/// The number of tokens included in the input prompt.
public var promptTokenCount: Int { inputText.tokens.size }

/// The number of tokens generated by the language model.
public var generationTokenCount: Int { tokens.count }

/// Time to process the prompt / generate the first token
public let promptTime: TimeInterval

/// Time to generate the remaining tokens
public let generateTime: TimeInterval

/// The number of tokens processed per second during the prompt phase.
public var promptTokensPerSecond: Double {
    // Consistency fix: use the promptTokenCount computed property rather than
    // re-reading inputText.tokens.size directly (same value, one source of truth).
    Double(promptTokenCount) / promptTime
}

/// The number of tokens generated per second during the generation phase.
public var tokensPerSecond: Double {
    Double(generationTokenCount) / generateTime
}

/// A human-readable summary of prompt and generation throughput.
public func summary() -> String {
    """
    Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
    Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted()) s
    """
}
436
463
}
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
795
822
public let promptTime : TimeInterval
796
823
797
824
/// The time interval (in seconds) taken to generate the output tokens.
798
- public let generationTime : TimeInterval
825
+ public let generateTime : TimeInterval
799
826
800
827
/// The number of tokens processed per second during the prompt phase.
801
828
public var promptTokensPerSecond : Double {
@@ -804,7 +831,7 @@ public struct GenerateCompletionInfo: Sendable {
804
831
805
832
/// The number of tokens generated per second during the generation phase.
public var tokensPerSecond: Double {
    Double(generationTokenCount) / generateTime
}
809
836
810
837
public init (
@@ -816,7 +843,14 @@ public struct GenerateCompletionInfo: Sendable {
816
843
self . promptTokenCount = promptTokenCount
817
844
self . generationTokenCount = generationTokenCount
818
845
self . promptTime = promptTime
819
- self . generationTime = generationTime
846
+ self . generateTime = generationTime
847
+ }
848
+
849
/// A human-readable summary of prompt and generation throughput.
public func summary() -> String {
    """
    Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
    Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted()) s
    """
}
821
855
}
822
856
@@ -825,9 +859,31 @@ public struct GenerateCompletionInfo: Sendable {
825
859
/// One element of the token-generation stream.
///
/// This enum distinguishes between the following:
/// - `.chunk`: A decoded string from one or more tokens generated by the language model.
/// - `.info`: Metadata and performance statistics about the generation process.
public enum Generation: Sendable {
    /// A generated token represented as a String
    case chunk(String)

    /// Completion information summarizing token counts and performance metrics.
    case info(GenerateCompletionInfo)

    /// Generated text when this element is a `.chunk`, otherwise `nil`.
    public var chunk: String? {
        if case .chunk(let text) = self { return text }
        return nil
    }

    /// Completion info when this element is `.info`, otherwise `nil`.
    public var info: GenerateCompletionInfo? {
        if case .info(let details) = self { return details }
        return nil
    }

    /// Reducer that can be used with `throttle()` to gather elements into a batch.
    @Sendable
    public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
        var collected = batch ?? []
        collected.append(element)
        return collected
    }
}
0 commit comments