additional changes related to async eval #266
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+import AsyncAlgorithms
 import MLX
 import MLXLLM
 import MLXLMCommon
@@ -169,9 +170,8 @@ class LLMEvaluator {
     let modelConfiguration = LLMRegistry.qwen2_5_1_5b
 
     /// parameters controlling the output
-    let generateParameters = GenerateParameters(temperature: 0.6)
-    let maxTokens = 240
-    let updateInterval = 0.25
+    let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
+    let updateInterval = Duration.seconds(0.25)
 
     /// A task responsible for handling the generation process.
     var generationTask: Task<Void, Error>?
@@ -254,36 +254,23 @@ class LLMEvaluator {
         let stream = try MLXLMCommon.generate(
             input: lmInput, parameters: generateParameters, context: context)
 
-        var tokenCount = 0
-        var lastEmissionTime: Date = Date()
-        var chunks = ""
-
-        for await result in stream {
-            switch result {
-            case .chunk(let string):
-                tokenCount += 1
-                if tokenCount >= maxTokens { await generationTask?.cancel() }
-                let now = Date()
-                if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-                    lastEmissionTime = now
-                    let text = chunks
-                    chunks = ""
-                    Task { @MainActor in
-                        self.output += text
-                    }
-                } else {
-                    chunks += string
-                }
-            case .info(let info):
-                Task { @MainActor in
-                    self.stat = "\(info.tokensPerSecond) tokens/s"
-                }
-            }
-        }
-
-        Task { @MainActor in
-            self.output += chunks
-        }
+        // generate and output in batches
+        for await batch in stream._throttle(
+            for: updateInterval, reducing: Generation.collect)
+        {
+            let output = batch.compactMap { $0.chunk }.joined(separator: "")
+            if !output.isEmpty {
+                Task { @MainActor [output] in
+                    self.output += output
+                }
+            }
+
+            if let completion = batch.compactMap({ $0.info }).first {
+                Task { @MainActor in
+                    self.stat = "\(completion.tokensPerSecond) tokens/s"
+                }
+            }
+        }
     } catch {

Review comment: This diff is kind of messy -- here is the complete version:

    // generate and output in batches
    for await batch in stream._throttle(
        for: updateInterval, reducing: Generation.collect)
    {
        let output = batch.compactMap { $0.chunk }.joined(separator: "")
        if !output.isEmpty {
            Task { @MainActor [output] in
                self.output += output
            }
        }
        if let completion = batch.compactMap({ $0.info }).first {
            Task { @MainActor in
                self.stat = "\(completion.tokensPerSecond) tokens/s"
            }
        }
    }

Per discussion in this PR: #256 (comment) -- I think this is the right place for it. We were doing the Date based throttling and I just replaced it with a smaller swift-async-algorithms version. I like it here (rather than in the library) because it is very much a concern for the presentation layer, not the generation. We want developers to look at this for inspiration, so it should have this feature (as it did before I changed it). What do you think of this approach? I simplified and cleaned it up a bit -- I think it turned out nice.
@@ -57,6 +57,9 @@ public struct GenerateParameters: Sendable {
     /// Step size for processing the prompt
     public var prefillStepSize = 512
 
+    /// Maximum tokens to generate
+    public var maxTokens: Int?
+
     /// sampling temperature
     public var temperature: Float = 0.6
@@ -70,9 +73,11 @@ public struct GenerateParameters: Sendable {
     public var repetitionContextSize: Int = 20
 
     public init(
+        maxTokens: Int? = nil,
         temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
         repetitionContextSize: Int = 20
     ) {
+        self.maxTokens = maxTokens
         self.temperature = temperature
         self.topP = topP
         self.repetitionPenalty = repetitionPenalty
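For context, a minimal sketch of how a caller might use the new `maxTokens` parameter with the stream-based `generate` shown above. The `lmInput` and `context` values, the function name, and the output handling are placeholders standing in for an already-prepared model and input (as in the LLMEval example), not part of this PR.

```swift
import MLXLMCommon

// Sketch only: assumes `lmInput: LMInput` and `context: ModelContext`
// were prepared elsewhere (e.g. from a loaded model container).
func generateCapped(lmInput: LMInput, context: ModelContext) async throws {
    // Generation now stops itself after 240 tokens; the caller no longer
    // needs to count tokens and cancel the task manually.
    let parameters = GenerateParameters(maxTokens: 240, temperature: 0.6)

    let stream = try MLXLMCommon.generate(
        input: lmInput, parameters: parameters, context: context)

    for await item in stream {
        // `chunk` is the new convenience accessor on Generation (see below).
        if let text = item.chunk {
            print(text, terminator: "")
        }
    }
}
```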
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
     let temp: MLXArray
     let topP: MLXArray
 
-    init(temperature: Float, topP: Float) {
+    public init(temperature: Float, topP: Float) {
         self.temp = MLXArray(temperature)
         self.topP = MLXArray(topP)
     }

Review comment: Not exactly the same issue, but I noticed that some of these other methods were not public and should be.
@@ -149,7 +154,7 @@ public struct TopPSampler: LogitSampler {
 public struct CategoricalSampler: LogitSampler {
     let temp: MLXArray
 
-    init(temperature: Float) {
+    public init(temperature: Float) {
         self.temp = MLXArray(temperature)
     }
@@ -178,7 +183,7 @@ public struct RepetitionContext: LogitProcessor {
     /// number of tokens to consider for repetition penalty
     let repetitionContextSize: Int
 
-    init(repetitionPenalty: Float, repetitionContextSize: Int) {
+    public init(repetitionPenalty: Float, repetitionContextSize: Int) {
         precondition(repetitionContextSize > 0)
         self.repetitionPenalty = repetitionPenalty
         self.repetitionContextSize = repetitionContextSize
@@ -250,6 +255,9 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     var processor: LogitProcessor?
     let sampler: LogitSampler
 
+    var tokenCount = 0
+    let maxTokens: Int?
+
     /// Initialize a `TokenIterator` with the given tokens. Note: this has been
     /// replaced with ``init(input:model:cache:parameters:)``.
     ///
@@ -269,6 +277,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
     }
@@ -295,6 +304,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: input, windowSize: parameters.prefillStepSize)
     }
@@ -308,16 +318,19 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     ///   - processor: the logit processor
     ///   - sampler: the logit sampler
     ///   - prefillStepSize: optional prefill step size
+    ///   - maxTokens: maximum number of tokens to generate
     public init(
         input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
-        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
+        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512,
+        maxTokens: Int? = nil
     ) throws {
         self.model = model
         self.y = input.text
         self.cache = cache ?? model.newCache(parameters: nil)
 
         self.processor = processor
         self.sampler = sampler
+        self.maxTokens = maxTokens
 
         try prepare(input: input, windowSize: prefillStepSize)
     }
@@ -365,6 +378,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     }
 
     mutating public func next() -> Int? {
+        if let maxTokens, tokenCount >= maxTokens {
+            return nil
+        }
+
         // save current value -- this will be returned
         let previousY = y
@@ -373,6 +390,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = .init(tokens: token)
         asyncEval(token)
 
+        tokenCount += 1
+
         return previousY.tokens.item(Int.self)
     }
 }
@@ -413,24 +432,32 @@ public struct GenerateResult: Sendable {
     /// output text
     public let output: String
 
+    /// The number of tokens included in the input prompt.
+    public var promptTokenCount: Int { inputText.tokens.size }
+
+    /// The number of tokens generated by the language model.
+    public var generationTokenCount: Int { tokens.count }
+
     /// time to process the prompt / generate the first token
     public let promptTime: TimeInterval
 
     /// time to generate the remaining tokens
     public let generateTime: TimeInterval
 
     /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {
         Double(inputText.tokens.size) / promptTime
     }
 
     /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
         Double(tokens.count) / generateTime
     }
 
     public func summary() -> String {
         """
-        Prompt: \(inputText.tokens.size) tokens, \(promptTokensPerSecond.formatted()) tokens/s
-        Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
         """
     }
 }

Review comment: Conform the properties to match GenerateCompletionInfo.
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
     public let promptTime: TimeInterval
 
     /// The time interval (in seconds) taken to generate the output tokens.
-    public let generationTime: TimeInterval
+    public let generateTime: TimeInterval
 
     /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {

Review comment: Conform this to match GenerateResult.
@@ -804,7 +831,7 @@ public struct GenerateCompletionInfo: Sendable {
 
     /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
-        Double(generationTokenCount) / generationTime
+        Double(generationTokenCount) / generateTime
     }
 
     public init(
@@ -816,7 +843,14 @@ public struct GenerateCompletionInfo: Sendable {
         self.promptTokenCount = promptTokenCount
         self.generationTokenCount = generationTokenCount
         self.promptTime = promptTime
-        self.generationTime = generationTime
+        self.generateTime = generationTime
     }
+
+    public func summary() -> String {
+        """
+        Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        """
+    }
 }

Review comment: They both have a summary method now.
@@ -825,9 +859,31 @@ public struct GenerateCompletionInfo: Sendable {
 /// This enum distinguishes between the following:
 /// - `.chunk`: A decoded string from one or more tokens generated by the language model.
 /// - `.info`: Metadata and performance statistics about the generation process.
-public enum Generation {
-    /// A generated token represented as an integer.
+public enum Generation: Sendable {
+    /// A generated token represented as a String
     case chunk(String)
 
     /// Completion information summarizing token counts and performance metrics.
     case info(GenerateCompletionInfo)
+
+    /// Generated text or nil
+    public var chunk: String? {
+        switch self {
+        case .chunk(let string): string
+        case .info: nil
+        }
+    }
+
+    /// Completion info or nil
+    public var info: GenerateCompletionInfo? {
+        switch self {
+        case .chunk: nil
+        case .info(let info): info
+        }
+    }
+
+    /// Reducer that can be used with `throttle()` to gather elements into a batch
+    @Sendable
+    public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
+        (batch ?? []) + [element]
+    }
 }

Review comment: Some utility properties that let us map over these.
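To show the new accessors outside the throttled LLMEval loop above, here is a small sketch of a plain consumer that uses `chunk`/`info` together with the new `summary()` method. It assumes `stream` is the `AsyncStream<Generation>` returned by `MLXLMCommon.generate(input:parameters:context:)` as used earlier; the function name is illustrative only.

```swift
// Sketch only: `stream` stands for the stream returned by
// MLXLMCommon.generate(input:parameters:context:) shown earlier.
func printGeneration(_ stream: AsyncStream<Generation>) async {
    for await item in stream {
        // `chunk` is non-nil for generated text, `info` for the final statistics.
        if let text = item.chunk {
            print(text, terminator: "")
        }
        if let info = item.info {
            print("\n" + info.summary())
        }
    }
}
```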
Review comment: maxTokens added here -- see LLMTool for an explanation why.