
additional changes related to async eval #266


Merged: 1 commit into main on Apr 14, 2025

Conversation

davidkoski (Collaborator)

  • add async eval to llm-tool
  • add maxTokens to GenerateParameters
    • otherwise you can't cap the number of tokens and obtain the info
  • switch throttle to use swift-async-algorithms
    • in the example (front end) code -- I think this is the right level
  • make GenerateCompletionInfo and GenerateResult have identical naming

FYI @ronaldmannak and @Alessan-git -- this builds on your recent work

@davidkoski davidkoski requested a review from awni April 9, 2025 22:22
let generateParameters = GenerateParameters(temperature: 0.6)
let maxTokens = 240
let updateInterval = 0.25
let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
Collaborator Author

maxTokens added here -- see LLMTool for an explanation why

@@ -254,36 +254,23 @@ class LLMEvaluator {
let stream = try MLXLMCommon.generate(
input: lmInput, parameters: generateParameters, context: context)

var tokenCount = 0
Collaborator Author

This diff is kind of messy -- here is the complete version:

                // generate and output in batches
                for await batch in stream._throttle(
                    for: updateInterval, reducing: Generation.collect)
                {
                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
                    if !output.isEmpty {
                        Task { @MainActor [output] in
                            self.output += output
                        }
                    }

                    if let completion = batch.compactMap({ $0.info }).first {
                        Task { @MainActor in
                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                        }
                    }
                }

Per the discussion in #256 (comment), I think this is the right place for it -- we were doing Date-based throttling and I just replaced it with a smaller swift-async-algorithms version. I like it here (rather than in the library) because it is very much a concern of the presentation layer, not the generation. We want developers to look at this for inspiration, so it should have this feature (as it did before I changed it).

What do you think of this approach? I simplified and cleaned it up a bit -- I think it turned out nice.
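
For reference, the reducing function is presumably nothing more than batching up whatever arrives during one throttle window -- a minimal sketch of what I mean by Generation.collect (not necessarily the exact implementation):

    extension Generation {
        // gather every element seen during one throttle window into an array,
        // so the UI can apply a whole batch of output at once
        public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
            (batch ?? []) + [element]
        }
    }

The optional first parameter is the partially reduced batch for the current window (nil when a new window starts), which is why the accumulator defaults to an empty array.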

@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
let temp: MLXArray
let topP: MLXArray

init(temperature: Float, topP: Float) {
public init(temperature: Float, topP: Float) {
Collaborator Author

Not exactly the same issue, but I noticed that some of these other methods were not public and should be.

public var promptTokenCount: Int { inputText.tokens.size }

/// The number of tokens generated by the language model.
public var generationTokenCount: Int { tokens.count }
Collaborator Author

Conform the properties to match GenerateCompletionInfo

@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
public let promptTime: TimeInterval

/// The time interval (in seconds) taken to generate the output tokens.
public let generationTime: TimeInterval
public let generateTime: TimeInterval
Collaborator Author

Conform this to match GenerateResult

self.generateTime = generationTime
}

public func summary() -> String {
Collaborator Author

They both have a summary method now
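
So callers can print the same kind of stats from either one, e.g. (variable names here are just for illustration):

    // both the streaming completion info and the final result can report stats
    print(completionInfo.summary())   // GenerateCompletionInfo
    print(result.summary())           // GenerateResult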

case .chunk: nil
case .info(let info): info
}
}
Collaborator Author

Some utility properties that let us map over these.
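
Spelled out, they are just computed properties on Generation, roughly like this (a sketch reconstructed from the fragment above -- chunk mirrors info):

    extension Generation {
        /// the text piece, or nil if this element is the final info
        public var chunk: String? {
            switch self {
            case .chunk(let string): string
            case .info: nil
            }
        }

        /// the completion info, or nil if this element is a text chunk
        public var info: GenerateCompletionInfo? {
            switch self {
            case .chunk: nil
            case .info(let info): info
            }
        }
    }

This is what lets the front end write batch.compactMap { $0.chunk } and batch.compactMap { $0.info } in the throttled loop above.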

}
}
fatalError("exited loop without seeing .info")
Collaborator Author

OK, so now why maxTokens. It is a problem in the apps too, which you can see if you set the maxTokens count low -- if the loop breaks early it never receives the info and so it can't report the stats. Not as obvious in the apps, but very obvious here.

I tried a variety of things, but didn't have a good way to break the loop AND signal to the AsyncStream we were done AND wait for the info:

            let sequence = try MLXLMCommon.generate(input: input, parameters: generateParameters, context: context)

            var tokenCount = 0

            loop:
            for await item in sequence {
                switch item {
                case .chunk(let string):
                    print(string, terminator: "")
                    tokenCount += 1
                    if tokenCount >= maxTokens {
                       // what now?
                       break loop
                    }
                case .info(let info):
                    return info
                }
            }

            // now what?  continue to iterate the sequence?  but it doesn't know we are stopping.  put it in a `Task`?  circular reference

Just adding the maxTokens seems like the right way to go and it simplifies all the call sites.
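
With maxTokens living in GenerateParameters, the call site can just consume the stream to completion -- roughly this shape (a sketch, details elided):

    // the generator enforces maxTokens itself, so there is no manual
    // token counting and no early break -- .info always arrives
    let parameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
    let sequence = try MLXLMCommon.generate(
        input: input, parameters: parameters, context: context)

    for await item in sequence {
        switch item {
        case .chunk(let string):
            print(string, terminator: "")
        case .info(let info):
            return info
        }
    }
    fatalError("exited loop without seeing .info")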

@Alessan-git (Contributor) Apr 10, 2025

I’m totally on board, putting maxTokens into GenerateParameters is something I’ve been considering for quite a while now. I just wanted to give it a bit more thought. On one hand, I think the name could be better since these are sampling parameters; maybe SamplingParameters would work. On the other hand, this structure is all about setting up the LogitSampler and LogitProcessor, and I’ve been thinking about how to enable users to create their own samplers.

But they can just initialise the TokenIterator with them and/or create a custom TokenIterator and maybe there's no need to complicate things more? In fact, I usually skip the use of GenerateParameters.

// Current TokenIterator init
public init(
        input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
    ) throws {

But yeah, totally agree that the max token count to sample definitely belongs to this sampling parameters.
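
For illustration, bypassing GenerateParameters and wiring a sampler straight into the iterator might look like this (a sketch -- model/input setup omitted, and the sampler values are arbitrary):

    // TopPSampler's init is public as of this PR, so this works outside the module
    let sampler = TopPSampler(temperature: 0.6, topP: 0.9)
    let iterator = try TokenIterator(
        input: lmInput, model: model, cache: nil,
        processor: nil, sampler: sampler, prefillStepSize: 512)
    // ...then drive the iterator yourself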

@davidkoski davidkoski merged commit 1029eef into main Apr 14, 2025
1 check passed
@davidkoski davidkoski deleted the async-eval-cont branch April 14, 2025 16:34