Commit 1029eef

additional changes related to async eval (#266)
- add async eval to llm-tool
- add maxTokens to GenerateParameters - otherwise you can't cap the number of tokens and obtain the info
- switch throttle to use swift-async-algorithms in the example (front end) code -- I think this is the right level
- make GenerateCompletionInfo and GenerateResult have identical naming
1 parent 0c65930 commit 1029eef
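
As a rough usage sketch of what these changes add up to (hedged: `lmInput` and `context` stand in for an already-prepared `LMInput` and model context, and the code would need to run in an async, throwing scope), the token cap now moves into `GenerateParameters` and UI batching moves to swift-async-algorithms' throttle:

    import AsyncAlgorithms
    import MLXLMCommon

    // maxTokens now lives on GenerateParameters, so the token iterator itself
    // stops after 240 generated tokens instead of the caller cancelling a task.
    let parameters = GenerateParameters(maxTokens: 240, temperature: 0.6)

    let stream = try MLXLMCommon.generate(
        input: lmInput, parameters: parameters, context: context)

    // Collect Generation values over 0.25 s windows and receive them as batches.
    for await batch in stream._throttle(for: .seconds(0.25), reducing: Generation.collect) {
        let text = batch.compactMap { $0.chunk }.joined()
        if !text.isEmpty { print(text, terminator: "") }

        if let info = batch.compactMap({ $0.info }).first {
            print("\n\(info.tokensPerSecond) tokens/s")
        }
    }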

File tree

7 files changed, +146 −89 lines changed


Diff for: Applications/LLMEval/ContentView.swift

+15 −28

@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+import AsyncAlgorithms
 import MLX
 import MLXLLM
 import MLXLMCommon
@@ -169,9 +170,8 @@ class LLMEvaluator {
     let modelConfiguration = LLMRegistry.qwen2_5_1_5b
 
     /// parameters controlling the output
-    let generateParameters = GenerateParameters(temperature: 0.6)
-    let maxTokens = 240
-    let updateInterval = 0.25
+    let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
+    let updateInterval = Duration.seconds(0.25)
 
     /// A task responsible for handling the generation process.
     var generationTask: Task<Void, Error>?
@@ -254,36 +254,23 @@ class LLMEvaluator {
                let stream = try MLXLMCommon.generate(
                    input: lmInput, parameters: generateParameters, context: context)
 
-                var tokenCount = 0
-                var lastEmissionTime: Date = Date()
-                var chunks = ""
-
-                for await result in stream {
-                    switch result {
-                    case .chunk(let string):
-                        tokenCount += 1
-                        if tokenCount >= maxTokens { await generationTask?.cancel() }
-                        let now = Date()
-                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-                            lastEmissionTime = now
-                            let text = chunks
-                            chunks = ""
-                            Task { @MainActor in
-                                self.output += text
-                            }
-                        } else {
-                            chunks += string
+                // generate and output in batches
+                for await batch in stream._throttle(
+                    for: updateInterval, reducing: Generation.collect)
+                {
+                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+                    if !output.isEmpty {
+                        Task { @MainActor [output] in
+                            self.output += output
                        }
-                    case .info(let info):
+                    }
+
+                    if let completion = batch.compactMap({ $0.info }).first {
                        Task { @MainActor in
-                            self.stat = "\(info.tokensPerSecond) tokens/s"
+                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                        }
                    }
                }
-
-                Task { @MainActor in
-                    self.output += chunks
-                }
            }
 
        } catch {

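A tiny standalone sketch (not part of the diff) of what the `Generation.collect` reducer used above does: each throttle window folds the elements it saw into one array, and that array arrives in the loop as a single `batch`.

    // Hypothetical illustration; the strings and values are made up.
    var batch: [Generation]? = nil
    batch = Generation.collect(batch, .chunk("Hello"))
    batch = Generation.collect(batch, .chunk(", world"))
    // batch now holds [.chunk("Hello"), .chunk(", world")]; `_throttle(for:reducing:)`
    // emits the accumulated array once per update interval.
    let text = batch?.compactMap { $0.chunk }.joined() ?? ""   // "Hello, world"
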
Diff for: Applications/VLMEval/ContentView.swift

+16 −28

@@ -1,6 +1,7 @@
 // Copyright 2024 Apple Inc.
 
 import AVKit
+import AsyncAlgorithms
 import CoreImage
 import MLX
 import MLXLMCommon
@@ -338,9 +339,9 @@ class VLMEvaluator {
     let modelConfiguration = VLMRegistry.smolvlm
 
     /// parameters controlling the output – use values appropriate for the model selected above
-    let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.7, topP: 0.9)
-    let maxTokens = 800
-    let updateInterval = 0.25
+    let generateParameters = MLXLMCommon.GenerateParameters(
+        maxTokens: 800, temperature: 0.7, topP: 0.9)
+    let updateInterval = Duration.seconds(0.25)
 
     /// A task responsible for handling the generation process.
     var generationTask: Task<Void, Error>?
@@ -444,36 +445,23 @@ class VLMEvaluator {
                let stream = try MLXLMCommon.generate(
                    input: lmInput, parameters: generateParameters, context: context)
 
-                var tokenCount = 0
-                var lastEmissionTime: Date = Date()
-                var chunks = ""
-
-                for await result in stream {
-                    switch result {
-                    case .chunk(let string):
-                        tokenCount += 1
-                        if tokenCount >= maxTokens { await generationTask?.cancel() }
-                        let now = Date()
-                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-                            lastEmissionTime = now
-                            let text = chunks
-                            chunks = ""
-                            Task { @MainActor in
-                                self.output += text
-                            }
-                        } else {
-                            chunks += string
+                // generate and output in batches
+                for await batch in stream._throttle(
+                    for: updateInterval, reducing: Generation.collect)
+                {
+                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+                    if !output.isEmpty {
+                        Task { @MainActor [output] in
+                            self.output += output
                        }
-                    case .info(let info):
+                    }
+
+                    if let completion = batch.compactMap({ $0.info }).first {
                        Task { @MainActor in
-                            self.stat = "\(info.tokensPerSecond) tokens/s"
+                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                        }
                    }
                }
-
-                Task { @MainActor in
-                    self.output += chunks
-                }
            }
        } catch {
            output = "Failed: \(error)"

Diff for: Libraries/MLXLMCommon/Evaluate.swift

+67 −11

@@ -57,6 +57,9 @@ public struct GenerateParameters: Sendable {
     /// Step size for processing the prompt
     public var prefillStepSize = 512
 
+    /// Maximum tokens to generate
+    public var maxTokens: Int?
+
     /// sampling temperature
     public var temperature: Float = 0.6
 
@@ -70,9 +73,11 @@
     public var repetitionContextSize: Int = 20
 
     public init(
+        maxTokens: Int? = nil,
         temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
         repetitionContextSize: Int = 20
     ) {
+        self.maxTokens = maxTokens
         self.temperature = temperature
         self.topP = topP
         self.repetitionPenalty = repetitionPenalty
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
     let temp: MLXArray
     let topP: MLXArray
 
-    init(temperature: Float, topP: Float) {
+    public init(temperature: Float, topP: Float) {
         self.temp = MLXArray(temperature)
         self.topP = MLXArray(topP)
     }
@@ -149,7 +154,7 @@
 public struct CategoricalSampler: LogitSampler {
     let temp: MLXArray
 
-    init(temperature: Float) {
+    public init(temperature: Float) {
         self.temp = MLXArray(temperature)
     }
 
@@ -178,7 +183,7 @@ public struct RepetitionContext: LogitProcessor {
     /// number of tokens to consider for repetition penalty
     let repetitionContextSize: Int
 
-    init(repetitionPenalty: Float, repetitionContextSize: Int) {
+    public init(repetitionPenalty: Float, repetitionContextSize: Int) {
         precondition(repetitionContextSize > 0)
         self.repetitionPenalty = repetitionPenalty
         self.repetitionContextSize = repetitionContextSize
@@ -250,6 +255,9 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     var processor: LogitProcessor?
     let sampler: LogitSampler
 
+    var tokenCount = 0
+    let maxTokens: Int?
+
     /// Initialize a `TokenIterator` with the given tokens. Note: this has been
     /// replaced with ``init(input:model:cache:parameters:)``.
     ///
@@ -269,6 +277,7 @@
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
     }
@@ -295,6 +304,7 @@
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: input, windowSize: parameters.prefillStepSize)
     }
@@ -308,16 +318,19 @@
     ///   - processor: the logit processor
     ///   - sampler: the logit sampler
     ///   - prefillStepSize: optional prefill step size
+    ///   - maxTokens: maximum number of tokens to generate
     public init(
         input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
-        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
+        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512,
+        maxTokens: Int? = nil
     ) throws {
         self.model = model
         self.y = input.text
         self.cache = cache ?? model.newCache(parameters: nil)
 
         self.processor = processor
         self.sampler = sampler
+        self.maxTokens = maxTokens
 
         try prepare(input: input, windowSize: prefillStepSize)
     }
@@ -365,6 +378,10 @@
     }
 
     mutating public func next() -> Int? {
+        if let maxTokens, tokenCount >= maxTokens {
+            return nil
+        }
+
         // save current value -- this will be returned
         let previousY = y
 
@@ -373,6 +390,8 @@
         y = .init(tokens: token)
         asyncEval(token)
 
+        tokenCount += 1
+
         return previousY.tokens.item(Int.self)
     }
 }
@@ -413,24 +432,32 @@ public struct GenerateResult: Sendable {
     /// output text
     public let output: String
 
+    /// The number of tokens included in the input prompt.
+    public var promptTokenCount: Int { inputText.tokens.size }
+
+    /// The number of tokens generated by the language model.
+    public var generationTokenCount: Int { tokens.count }
+
     /// time to process the prompt / generate the first token
     public let promptTime: TimeInterval
 
     /// time to generate the remaining tokens
     public let generateTime: TimeInterval
 
+    /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {
         Double(inputText.tokens.size) / promptTime
     }
 
+    /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
         Double(tokens.count) / generateTime
     }
 
     public func summary() -> String {
         """
-        Prompt: \(inputText.tokens.size) tokens, \(promptTokensPerSecond.formatted()) tokens/s
-        Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
         """
     }
 }
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
     public let promptTime: TimeInterval
 
     /// The time interval (in seconds) taken to generate the output tokens.
-    public let generationTime: TimeInterval
+    public let generateTime: TimeInterval
 
     /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {
@@ -804,7 +831,7 @@
 
     /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
-        Double(generationTokenCount) / generationTime
+        Double(generationTokenCount) / generateTime
     }
 
     public init(
@@ -816,7 +843,14 @@
         self.promptTokenCount = promptTokenCount
         self.generationTokenCount = generationTokenCount
         self.promptTime = promptTime
-        self.generationTime = generationTime
+        self.generateTime = generationTime
+    }
+
+    public func summary() -> String {
+        """
+        Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        """
     }
 }
 
@@ -825,9 +859,31 @@ public struct GenerateCompletionInfo: Sendable {
 /// This enum distinguishes between the following:
 /// - `.chunk`: A decoded string from one or more tokens generated by the language model.
 /// - `.info`: Metadata and performance statistics about the generation process.
-public enum Generation {
-    /// A generated token represented as an integer.
+public enum Generation: Sendable {
+    /// A generated token represented as a String
     case chunk(String)
     /// Completion information summarizing token counts and performance metrics.
     case info(GenerateCompletionInfo)
+
+    /// Generated text or nil
+    public var chunk: String? {
+        switch self {
+        case .chunk(let string): string
+        case .info: nil
+        }
+    }
+
+    /// Completion info or nil
+    public var info: GenerateCompletionInfo? {
+        switch self {
+        case .chunk: nil
+        case .info(let info): info
+        }
+    }
+
+    /// Reducer that can be used with `throttle()` to gather elements into a batch
+    @Sendable
+    public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
+        (batch ?? []) + [element]
+    }
 }

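Because this diff also makes the `TopPSampler`, `CategoricalSampler`, and `RepetitionContext` initializers public, a `TokenIterator` can in principle be assembled directly instead of via `GenerateParameters`. A hedged sketch based only on the initializer shown above (`model` and `lmInput` are assumed to be an already-loaded `LanguageModel` and a prepared `LMInput`; the sampler settings are illustrative, not recommendations):

    // Sketch only -- mirrors the public init added in this commit.
    var iterator = try TokenIterator(
        input: lmInput,
        model: model,
        processor: RepetitionContext(repetitionPenalty: 1.1, repetitionContextSize: 20),
        sampler: TopPSampler(temperature: 0.7, topP: 0.9),
        maxTokens: 128  // new: next() returns nil once 128 tokens have been generated
    )
    while let token = iterator.next() {
        // detokenize / accumulate `token` here
    }
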