
additional changes related to async eval #266

Merged: 1 commit, merged Apr 14, 2025
43 changes: 15 additions & 28 deletions Applications/LLMEval/ContentView.swift
@@ -1,5 +1,6 @@
// Copyright © 2024 Apple Inc.

import AsyncAlgorithms
import MLX
import MLXLLM
import MLXLMCommon
@@ -169,9 +170,8 @@ class LLMEvaluator {
let modelConfiguration = LLMRegistry.qwen2_5_1_5b

/// parameters controlling the output
-let generateParameters = GenerateParameters(temperature: 0.6)
-let maxTokens = 240
-let updateInterval = 0.25
+let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)

Collaborator Author:
maxTokens added here -- see LLMTool for an explanation why

+let updateInterval = Duration.seconds(0.25)

/// A task responsible for handling the generation process.
var generationTask: Task<Void, Error>?
@@ -254,36 +254,23 @@ class LLMEvaluator {
let stream = try MLXLMCommon.generate(
input: lmInput, parameters: generateParameters, context: context)

-var tokenCount = 0

Collaborator Author:
This diff is kind of messy -- here is the complete version:

                // generate and output in batches
                for await batch in stream._throttle(
                    for: updateInterval, reducing: Generation.collect)
                {
                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
                    if !output.isEmpty {
                        Task { @MainActor [output] in
                            self.output += output
                        }
                    }

                    if let completion = batch.compactMap({ $0.info }).first {
                        Task { @MainActor in
                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                        }
                    }
                }

Per discussion in this PR: #256 (comment) I think this is the right place for it -- we were doing Date-based throttling and I just replaced it with a smaller swift-async-algorithms version. I like it here (rather than in the library) because it is very much a concern for the presentation layer, not the generation. We want developers to look at this for inspiration, so it should have this feature (as it did before I changed it).

What do you think of this approach? I simplified and cleaned it up a bit -- I think it turned out nice.
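
To make the batching behavior concrete, here is a minimal standalone sketch of the same pattern on a toy stream (the makeTicks/demo names are made up for illustration; _throttle is the underscored swift-async-algorithms API that the new ContentView code calls):

    import AsyncAlgorithms

    // A toy stream that yields one integer roughly every 50 ms.
    func makeTicks() -> AsyncStream<Int> {
        AsyncStream { continuation in
            Task {
                for i in 0..<20 {
                    continuation.yield(i)
                    try? await Task.sleep(for: .milliseconds(50))
                }
                continuation.finish()
            }
        }
    }

    func demo() async {
        // Elements arriving within each 0.25 s window are reduced into a single
        // [Int] batch, so the consumer wakes up a few times per second instead
        // of once per element -- the same batching the UI code above wants.
        for await batch in makeTicks()._throttle(
            for: .seconds(0.25), reducing: { partial, element in (partial ?? []) + [element] })
        {
            print("batch:", batch)
        }
    }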

-var lastEmissionTime: Date = Date()
-var chunks = ""
-
-for await result in stream {
-    switch result {
-    case .chunk(let string):
-        tokenCount += 1
-        if tokenCount >= maxTokens { await generationTask?.cancel() }
-        let now = Date()
-        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-            lastEmissionTime = now
-            let text = chunks
-            chunks = ""
-            Task { @MainActor in
-                self.output += text
-            }
-        } else {
-            chunks += string
-        }
-    case .info(let info):
-        Task { @MainActor in
-            self.stat = "\(info.tokensPerSecond) tokens/s"
-        }
-    }
-}
-
-Task { @MainActor in
-    self.output += chunks
-}
+// generate and output in batches
+for await batch in stream._throttle(
+    for: updateInterval, reducing: Generation.collect)
+{
+    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+    if !output.isEmpty {
+        Task { @MainActor [output] in
+            self.output += output
+        }
+    }
+
+    if let completion = batch.compactMap({ $0.info }).first {
+        Task { @MainActor in
+            self.stat = "\(completion.tokensPerSecond) tokens/s"
+        }
+    }
+}
}

} catch {
44 changes: 16 additions & 28 deletions Applications/VLMEval/ContentView.swift
@@ -1,6 +1,7 @@
// Copyright 2024 Apple Inc.

import AVKit
import AsyncAlgorithms
import CoreImage
import MLX
import MLXLMCommon
@@ -338,9 +339,9 @@ class VLMEvaluator {
let modelConfiguration = VLMRegistry.smolvlm

/// parameters controlling the output – use values appropriate for the model selected above
-let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.7, topP: 0.9)
-let maxTokens = 800
-let updateInterval = 0.25
+let generateParameters = MLXLMCommon.GenerateParameters(
+    maxTokens: 800, temperature: 0.7, topP: 0.9)
+let updateInterval = Duration.seconds(0.25)

/// A task responsible for handling the generation process.
var generationTask: Task<Void, Error>?
@@ -444,36 +445,23 @@ class VLMEvaluator {
let stream = try MLXLMCommon.generate(
input: lmInput, parameters: generateParameters, context: context)

-var tokenCount = 0
-var lastEmissionTime: Date = Date()
-var chunks = ""
-
-for await result in stream {
-    switch result {
-    case .chunk(let string):
-        tokenCount += 1
-        if tokenCount >= maxTokens { await generationTask?.cancel() }
-        let now = Date()
-        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-            lastEmissionTime = now
-            let text = chunks
-            chunks = ""
-            Task { @MainActor in
-                self.output += text
-            }
-        } else {
-            chunks += string
-        }
-    case .info(let info):
-        Task { @MainActor in
-            self.stat = "\(info.tokensPerSecond) tokens/s"
-        }
-    }
-}
-
-Task { @MainActor in
-    self.output += chunks
-}
+// generate and output in batches
+for await batch in stream._throttle(
+    for: updateInterval, reducing: Generation.collect)
+{
+    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+    if !output.isEmpty {
+        Task { @MainActor [output] in
+            self.output += output
+        }
+    }
+
+    if let completion = batch.compactMap({ $0.info }).first {
+        Task { @MainActor in
+            self.stat = "\(completion.tokensPerSecond) tokens/s"
+        }
+    }
+}
}
} catch {
output = "Failed: \(error)"
78 changes: 67 additions & 11 deletions Libraries/MLXLMCommon/Evaluate.swift
@@ -57,6 +57,9 @@ public struct GenerateParameters: Sendable {
/// Step size for processing the prompt
public var prefillStepSize = 512

/// Maximum tokens to generate
public var maxTokens: Int?

/// sampling temperature
public var temperature: Float = 0.6

@@ -70,9 +73,11 @@ public struct GenerateParameters: Sendable {
public var repetitionContextSize: Int = 20

public init(
maxTokens: Int? = nil,
temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
repetitionContextSize: Int = 20
) {
self.maxTokens = maxTokens
self.temperature = temperature
self.topP = topP
self.repetitionPenalty = repetitionPenalty
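
As a rough usage sketch (not part of the diff), a caller can now set the token limit directly on the parameters; lmInput and context are assumed to exist as in the ContentView code above:

    let parameters = GenerateParameters(
        maxTokens: 240,   // new: stop after 240 generated tokens; nil keeps the old unbounded behavior
        temperature: 0.6,
        topP: 0.9)

    // Used exactly as before; the limit is enforced inside TokenIterator, so the
    // caller no longer has to count tokens and cancel the generation task itself.
    let stream = try MLXLMCommon.generate(
        input: lmInput, parameters: parameters, context: context)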
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
let temp: MLXArray
let topP: MLXArray

-init(temperature: Float, topP: Float) {
+public init(temperature: Float, topP: Float) {

Collaborator Author:
Not exactly the same issue, but I noticed that some of these other methods were not public and should be.

self.temp = MLXArray(temperature)
self.topP = MLXArray(topP)
}
@@ -149,7 +154,7 @@ public struct TopPSampler: LogitSampler {
public struct CategoricalSampler: LogitSampler {
let temp: MLXArray

-init(temperature: Float) {
+public init(temperature: Float) {
self.temp = MLXArray(temperature)
}

@@ -178,7 +183,7 @@ public struct RepetitionContext: LogitProcessor {
/// number of tokens to consider for repetition penalty
let repetitionContextSize: Int

-init(repetitionPenalty: Float, repetitionContextSize: Int) {
+public init(repetitionPenalty: Float, repetitionContextSize: Int) {
precondition(repetitionContextSize > 0)
self.repetitionPenalty = repetitionPenalty
self.repetitionContextSize = repetitionContextSize
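
With these initializers public, client code can assemble a sampler and processor directly and hand them to the TokenIterator initializer shown below; a minimal sketch, assuming input and model come from the usual model-loading path:

    // Build the sampling pieces directly now that their inits are public.
    let sampler = TopPSampler(temperature: 0.7, topP: 0.9)
    let processor = RepetitionContext(repetitionPenalty: 1.1, repetitionContextSize: 20)

    // cache and prefillStepSize use their defaults; maxTokens caps generation.
    var iterator = try TokenIterator(
        input: input, model: model,
        processor: processor, sampler: sampler,
        maxTokens: 240)

    while let token = iterator.next() {
        // detokenize / accumulate `token` here
        _ = token
    }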
@@ -250,6 +255,9 @@ public struct TokenIterator: Sequence, IteratorProtocol {
var processor: LogitProcessor?
let sampler: LogitSampler

var tokenCount = 0
let maxTokens: Int?

/// Initialize a `TokenIterator` with the given tokens. Note: this has been
/// replaced with ``init(input:model:cache:parameters:)``.
///
@@ -269,6 +277,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {

self.processor = parameters.processor()
self.sampler = parameters.sampler()
self.maxTokens = parameters.maxTokens

try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
}
@@ -295,6 +304,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {

self.processor = parameters.processor()
self.sampler = parameters.sampler()
self.maxTokens = parameters.maxTokens

try prepare(input: input, windowSize: parameters.prefillStepSize)
}
@@ -308,16 +318,19 @@ public struct TokenIterator: Sequence, IteratorProtocol {
/// - processor: the logit processor
/// - sampler: the logit sampler
/// - prefillStepSize: optional prefill step size
/// - maxTokens: maximum number of tokens to generate
public init(
input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
-processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
+processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512,
+maxTokens: Int? = nil
) throws {
self.model = model
self.y = input.text
self.cache = cache ?? model.newCache(parameters: nil)

self.processor = processor
self.sampler = sampler
self.maxTokens = maxTokens

try prepare(input: input, windowSize: prefillStepSize)
}
@@ -365,6 +378,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
}

mutating public func next() -> Int? {
if let maxTokens, tokenCount >= maxTokens {
return nil
}

// save current value -- this will be returned
let previousY = y

@@ -373,6 +390,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
y = .init(tokens: token)
asyncEval(token)

tokenCount += 1

return previousY.tokens.item(Int.self)
}
}
@@ -413,24 +432,32 @@ public struct GenerateResult: Sendable {
/// output text
public let output: String

/// The number of tokens included in the input prompt.
public var promptTokenCount: Int { inputText.tokens.size }

/// The number of tokens generated by the language model.
public var generationTokenCount: Int { tokens.count }

Collaborator Author:
Conform the properties to match GenerateCompletionInfo


/// time to process the prompt / generate the first token
public let promptTime: TimeInterval

/// time to generate the remaining tokens
public let generateTime: TimeInterval

/// The number of tokens processed per second during the prompt phase.
public var promptTokensPerSecond: Double {
Double(inputText.tokens.size) / promptTime
}

/// The number of tokens generated per second during the generation phase.
public var tokensPerSecond: Double {
Double(tokens.count) / generateTime
}

public func summary() -> String {
"""
-Prompt: \(inputText.tokens.size) tokens, \(promptTokensPerSecond.formatted()) tokens/s
-Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
"""
}
}
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
public let promptTime: TimeInterval

/// The time interval (in seconds) taken to generate the output tokens.
-public let generationTime: TimeInterval
+public let generateTime: TimeInterval

Collaborator Author:
Conform this to match GenerateResult


/// The number of tokens processed per second during the prompt phase.
public var promptTokensPerSecond: Double {
@@ -804,7 +831,7 @@ public struct GenerateCompletionInfo: Sendable {

/// The number of tokens generated per second during the generation phase.
public var tokensPerSecond: Double {
-Double(generationTokenCount) / generationTime
+Double(generationTokenCount) / generateTime
}

public init(
@@ -816,7 +843,14 @@ public struct GenerateCompletionInfo: Sendable {
self.promptTokenCount = promptTokenCount
self.generationTokenCount = generationTokenCount
self.promptTime = promptTime
-self.generationTime = generationTime
+self.generateTime = generationTime
}

public func summary() -> String {

Collaborator Author:
They both have a summary method now

"""
Prompt: \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
"""
}
}
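
For example, a streaming consumer can now report the same summary text that GenerateResult produces; a small sketch, where stream is the Generation sequence returned by MLXLMCommon.generate as in the app code above:

    var text = ""
    for await generation in stream {
        switch generation {
        case .chunk(let chunk):
            text += chunk
        case .info(let info):
            // Same format as GenerateResult.summary()
            print(info.summary())
        }
    }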

@@ -825,9 +859,31 @@ public struct GenerateCompletionInfo: Sendable {
/// This enum distinguishes between the following:
/// - `.chunk`: A decoded string from one or more tokens generated by the language model.
/// - `.info`: Metadata and performance statistics about the generation process.
-public enum Generation {
-/// A generated token represented as an integer.
+public enum Generation: Sendable {
+/// A generated token represented as a String
case chunk(String)
/// Completion information summarizing token counts and performance metrics.
case info(GenerateCompletionInfo)

/// Generated text or nil
public var chunk: String? {
switch self {
case .chunk(let string): string
case .info: nil
}
}

/// Completion info or nil
public var info: GenerateCompletionInfo? {
switch self {
case .chunk: nil
case .info(let info): info
}
}

Collaborator Author:
Some utility properties that let us map over these.


/// Reducer that can be used with `throttle()` to gather elements into a batch
@Sendable
public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
(batch ?? []) + [element]
}
}
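
A brief illustration of how these pieces compose (the values here are hypothetical): collect is a plain append-style reducer, and the chunk/info accessors split a collected batch back apart without a switch:

    // collect folds elements into a growing batch, starting from nil.
    let a = Generation.collect(nil, .chunk("Hello"))
    let b = Generation.collect(a, .chunk(", world"))
    // b is [.chunk("Hello"), .chunk(", world")]

    // The accessors make it easy to pull text (or completion info) out of a batch,
    // exactly as the throttled loops in the ContentView changes do:
    let text = b.compactMap { $0.chunk }.joined()   // "Hello, world"
    let info = b.compactMap { $0.info }.first       // nil here; present on the final element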