additional changes related to async eval #266
Conversation
- add async eval to llm-tool
- add maxTokens to GenerateParameters -- otherwise you can't cap the number of tokens and obtain the info
- switch throttle to use swift-async-algorithms in the example (front end) code -- I think this is the right level
- make GenerateCompletionInfo and GenerateResult have identical naming
```diff
-    let generateParameters = GenerateParameters(temperature: 0.6)
-    let maxTokens = 240
+    let updateInterval = 0.25
+    let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
```
maxTokens added here -- see LLMTool for an explanation why
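For readers skimming the diff, here is a minimal sketch of the shape this implies for `GenerateParameters` (only the relevant fields are shown; the defaults here are assumptions for illustration, not copied from the library source):

```swift
// Sketch of GenerateParameters after this PR. Field list and defaults
// are illustrative assumptions, not the library's actual source.
public struct GenerateParameters: Sendable {
    /// Maximum number of tokens to generate; nil means run until EOS.
    public var maxTokens: Int?

    /// Sampling temperature.
    public var temperature: Float

    public init(maxTokens: Int? = nil, temperature: Float = 0.6) {
        self.maxTokens = maxTokens
        self.temperature = temperature
    }
}
```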
```diff
@@ -254,36 +254,23 @@ class LLMEvaluator {
         let stream = try MLXLMCommon.generate(
             input: lmInput, parameters: generateParameters, context: context)

-        var tokenCount = 0
```
This diff is kind of messy -- here is the complete version:
```swift
// generate and output in batches
for await batch in stream._throttle(
    for: updateInterval, reducing: Generation.collect)
{
    let output = batch.compactMap { $0.chunk }.joined(separator: "")
    if !output.isEmpty {
        Task { @MainActor [output] in
            self.output += output
        }
    }

    if let completion = batch.compactMap({ $0.info }).first {
        Task { @MainActor in
            self.stat = "\(completion.tokensPerSecond) tokens/s"
        }
    }
}
```
Per discussion in this PR: #256 (comment), I think this is the right place for it -- we were doing Date-based throttling before, and I just replaced it with a smaller swift-async-algorithms version. I like it here (rather than in the library) because throttling is very much a concern for the presentation layer, not the generation. We want developers to look at this for inspiration, so it should have this feature (as it did before I changed it).
What do you think of this approach? I simplified and cleaned it up a bit -- I think it turned out nice.
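One note on the `reducing: Generation.collect` piece: `_throttle(for:reducing:)` from swift-async-algorithms folds all the elements that arrive during each interval into one value, so `collect` only needs to append into an optional batch. A sketch of what that reducer could look like (the library's actual implementation may differ):

```swift
extension Generation {
    /// Fold stream elements into an array so a throttled interval delivers
    /// everything that arrived, not just the latest element.
    /// Sketch only -- the actual implementation may differ.
    public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
        (batch ?? []) + [element]
    }
}
```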
```diff
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
     let temp: MLXArray
     let topP: MLXArray

-    init(temperature: Float, topP: Float) {
+    public init(temperature: Float, topP: Float) {
```
Not exactly the same issue, but I noticed that some of these other methods were not public and should be.
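In practice that means client code outside MLXLMCommon can now construct a sampler directly, e.g.:

```swift
// Previously this failed outside the module because init was internal.
let sampler = TopPSampler(temperature: 0.7, topP: 0.9)
```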
```swift
    public var promptTokenCount: Int { inputText.tokens.size }

    /// The number of tokens generated by the language model.
    public var generationTokenCount: Int { tokens.count }
```
Conform the properties to match GenerateCompletionInfo
```diff
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
     public let promptTime: TimeInterval

     /// The time interval (in seconds) taken to generate the output tokens.
-    public let generationTime: TimeInterval
+    public let generateTime: TimeInterval
```
Conform this to match GenerateResult
```swift
        self.generateTime = generationTime
    }

    public func summary() -> String {
```
They both have a summary method now
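So a call site can report stats the same way regardless of which type it has in hand, for example (illustrative only):

```swift
// Illustrative: with matching names, either GenerateResult or
// GenerateCompletionInfo can report its stats via summary().
if let completion = batch.compactMap({ $0.info }).first {
    print(completion.summary())
}
```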
```swift
            case .chunk: nil
            case .info(let info): info
        }
    }
```
Some utility properties that let us map over these.
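Concretely (and as the fragment above shows for `info`), these are computed properties on the `Generation` enum that surface each case as an optional, which is what makes the `compactMap { $0.chunk }` / `compactMap { $0.info }` calls in the throttled loop work. A sketch, with the `chunk` counterpart assumed symmetric to the `info` property in the diff:

```swift
extension Generation {
    /// The text fragment, if this element is a `.chunk`; nil otherwise.
    /// (Assumed symmetric to the `info` property shown in the diff.)
    public var chunk: String? {
        switch self {
        case .chunk(let string): string
        case .info: nil
        }
    }

    /// The completion stats, if this element is the final `.info`.
    public var info: GenerateCompletionInfo? {
        switch self {
        case .chunk: nil
        case .info(let info): info
        }
    }
}
```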
```swift
            }
        }
        fatalError("exited loop without seeing .info")
```
OK, so now: why maxTokens? It is a problem in the apps too, which you can see if you set the maxTokens count low -- if the loop breaks early it never receives the info and won't report the stats. Not as obvious in the apps, but very obvious here.

I tried a variety of things, but didn't have a good way to break the loop AND signal to the AsyncStream that we were done AND wait for the info:
```swift
let sequence = try MLXLMCommon.generate(
    input: input, parameters: generateParameters, context: context)

var tokenCount = 0
loop: for await item in sequence {
    switch item {
    case .chunk(let string):
        print(string, terminator: "")
        tokenCount += 1
        if tokenCount >= maxTokens {
            // what now?
            break loop
        }
    case .info(let info):
        return info
    }
}

// now what? continue to iterate the sequence? but it doesn't know we are
// stopping. put it in a `Task`? circular reference
```
Just adding the maxTokens seems like the right way to go, and it simplifies all the call sites.
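For comparison, with `maxTokens` inside `GenerateParameters` the stream ends itself and still emits `.info`, so the call site collapses to a plain loop. A sketch under that assumption:

```swift
// Sketch: the stream now stops itself at maxTokens and still emits
// .info, so there is no manual tokenCount or labeled break.
let parameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
let sequence = try MLXLMCommon.generate(
    input: input, parameters: parameters, context: context)

for await item in sequence {
    switch item {
    case .chunk(let string):
        print(string, terminator: "")
    case .info(let info):
        print("\n\(info.tokensPerSecond) tokens/s")
    }
}
```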
I’m totally on board, putting maxTokens into GenerateParameters is something I’ve been considering for quite a while now. I just wanted to give it a bit more thought. On one hand, I think the name could be better since these are sampling parameters; maybe SamplingParameters would work. On the other hand, this structure is all about setting up the LogitSampler and LogitProcessor, and I’ve been thinking about how to enable users to create their own samplers.
But they can just initialise the TokenIterator with them, and/or create a custom TokenIterator, so maybe there's no need to complicate things further? In fact, I usually skip the use of GenerateParameters:
```swift
// Current TokenIterator init
public init(
    input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
    processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
) throws {
```
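For instance, something like this skips GenerateParameters entirely (a sketch; `GreedySampler` is a hypothetical custom conformance, not a library type):

```swift
// Hypothetical custom sampler: pick the argmax token at every step.
struct GreedySampler: LogitSampler {
    func sample(logits: MLXArray) -> MLXArray {
        argMax(logits, axis: -1)
    }
}

// Build the iterator directly with the custom sampler.
let iterator = try TokenIterator(
    input: input, model: model,
    processor: nil, sampler: GreedySampler())
```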
But yeah, I totally agree that the maximum token count to sample definitely belongs in these sampling parameters.
FYI @ronaldmannak and @Alessan-git -- this builds on your recent work.