diff --git a/Applications/LLMEval/ContentView.swift b/Applications/LLMEval/ContentView.swift
index 74e1a9e8..5bf9bae1 100644
--- a/Applications/LLMEval/ContentView.swift
+++ b/Applications/LLMEval/ContentView.swift
@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
 
+import AsyncAlgorithms
 import MLX
 import MLXLLM
 import MLXLMCommon
@@ -169,9 +170,8 @@ class LLMEvaluator {
     let modelConfiguration = LLMRegistry.qwen2_5_1_5b
 
     /// parameters controlling the output
-    let generateParameters = GenerateParameters(temperature: 0.6)
-    let maxTokens = 240
-    let updateInterval = 0.25
+    let generateParameters = GenerateParameters(maxTokens: 240, temperature: 0.6)
+    let updateInterval = Duration.seconds(0.25)
 
     /// A task responsible for handling the generation process.
     var generationTask: Task<Void, Error>?
@@ -254,36 +254,23 @@ class LLMEvaluator {
                 let stream = try MLXLMCommon.generate(
                     input: lmInput, parameters: generateParameters, context: context)
 
-                var tokenCount = 0
-                var lastEmissionTime: Date = Date()
-                var chunks = ""
-
-                for await result in stream {
-                    switch result {
-                    case .chunk(let string):
-                        tokenCount += 1
-                        if tokenCount >= maxTokens { await generationTask?.cancel() }
-                        let now = Date()
-                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-                            lastEmissionTime = now
-                            let text = chunks
-                            chunks = ""
-                            Task { @MainActor in
-                                self.output += text
-                            }
-                        } else {
-                            chunks += string
+                // generate and output in batches
+                for await batch in stream._throttle(
+                    for: updateInterval, reducing: Generation.collect)
+                {
+                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+                    if !output.isEmpty {
+                        Task { @MainActor [output] in
+                            self.output += output
                         }
-                    case .info(let info):
+                    }
+
+                    if let completion = batch.compactMap({ $0.info }).first {
                         Task { @MainActor in
-                            self.stat = "\(info.tokensPerSecond) tokens/s"
+                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                         }
                     }
                 }
-
-                Task { @MainActor in
-                    self.output += chunks
-                }
             }
         } catch {
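Note: the batching loop above is the pattern this PR standardizes on. A minimal sketch of the same consumption loop in isolation, assuming the `AsyncStream<Generation>` element type and the `Generation.collect` / `.chunk` / `.info` / `summary()` helpers introduced in `Evaluate.swift` below; the `drain` function itself is hypothetical:

```swift
import AsyncAlgorithms
import MLXLMCommon

// Hypothetical consumer: gather whatever arrived during each 0.25 s window
// into one batch, then emit the text and any completion stats per batch.
func drain(_ stream: AsyncStream<Generation>) async {
    for await batch in stream._throttle(
        for: .seconds(0.25), reducing: Generation.collect)
    {
        // Join the text chunks produced during this interval.
        let text = batch.compactMap { $0.chunk }.joined()
        if !text.isEmpty {
            print(text, terminator: "")
        }
        // The stream ends with a single .info element carrying statistics.
        if let info = batch.compactMap({ $0.info }).first {
            print("\n" + info.summary())
        }
    }
}
```

The `[output]` capture in the app code above is what keeps each batch's text stable across the hop to the main actor.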
diff --git a/Applications/VLMEval/ContentView.swift b/Applications/VLMEval/ContentView.swift
index 903e2dd5..5b9dbfc9 100644
--- a/Applications/VLMEval/ContentView.swift
+++ b/Applications/VLMEval/ContentView.swift
@@ -1,6 +1,7 @@
 // Copyright 2024 Apple Inc.
 
 import AVKit
+import AsyncAlgorithms
 import CoreImage
 import MLX
 import MLXLMCommon
@@ -338,9 +339,9 @@ class VLMEvaluator {
     let modelConfiguration = VLMRegistry.smolvlm
 
     /// parameters controlling the output – use values appropriate for the model selected above
-    let generateParameters = MLXLMCommon.GenerateParameters(temperature: 0.7, topP: 0.9)
-    let maxTokens = 800
-    let updateInterval = 0.25
+    let generateParameters = MLXLMCommon.GenerateParameters(
+        maxTokens: 800, temperature: 0.7, topP: 0.9)
+    let updateInterval = Duration.seconds(0.25)
 
     /// A task responsible for handling the generation process.
     var generationTask: Task<Void, Error>?
@@ -444,36 +445,23 @@ class VLMEvaluator {
                 let stream = try MLXLMCommon.generate(
                     input: lmInput, parameters: generateParameters, context: context)
 
-                var tokenCount = 0
-                var lastEmissionTime: Date = Date()
-                var chunks = ""
-
-                for await result in stream {
-                    switch result {
-                    case .chunk(let string):
-                        tokenCount += 1
-                        if tokenCount >= maxTokens { await generationTask?.cancel() }
-                        let now = Date()
-                        if now.timeIntervalSince(lastEmissionTime) >= updateInterval {
-                            lastEmissionTime = now
-                            let text = chunks
-                            chunks = ""
-                            Task { @MainActor in
-                                self.output += text
-                            }
-                        } else {
-                            chunks += string
+                // generate and output in batches
+                for await batch in stream._throttle(
+                    for: updateInterval, reducing: Generation.collect)
+                {
+                    let output = batch.compactMap { $0.chunk }.joined(separator: "")
+                    if !output.isEmpty {
+                        Task { @MainActor [output] in
+                            self.output += output
                         }
-                    case .info(let info):
+                    }
+
+                    if let completion = batch.compactMap({ $0.info }).first {
                         Task { @MainActor in
-                            self.stat = "\(info.tokensPerSecond) tokens/s"
+                            self.stat = "\(completion.tokensPerSecond) tokens/s"
                         }
                     }
                 }
-
-                Task { @MainActor in
-                    self.output += chunks
-                }
             }
         } catch {
             output = "Failed: \(error)"
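With `maxTokens` folded into `GenerateParameters`, both apps drop the hand-rolled `tokenCount` / `generationTask?.cancel()` bookkeeping: the iterator now stops on its own. A sketch using the values from the two apps above:

```swift
import MLXLMCommon

// The token cap now travels with the sampling parameters.
let llmParams = GenerateParameters(maxTokens: 240, temperature: 0.6)
let vlmParams = GenerateParameters(maxTokens: 800, temperature: 0.7, topP: 0.9)
```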
diff --git a/Libraries/MLXLMCommon/Evaluate.swift b/Libraries/MLXLMCommon/Evaluate.swift
index ff005fea..26a27255 100644
--- a/Libraries/MLXLMCommon/Evaluate.swift
+++ b/Libraries/MLXLMCommon/Evaluate.swift
@@ -57,6 +57,9 @@ public struct GenerateParameters: Sendable {
     /// Step size for processing the prompt
     public var prefillStepSize = 512
 
+    /// Maximum tokens to generate
+    public var maxTokens: Int?
+
     /// sampling temperature
     public var temperature: Float = 0.6
 
@@ -70,9 +73,11 @@ public struct GenerateParameters: Sendable {
     public var repetitionContextSize: Int = 20
 
     public init(
+        maxTokens: Int? = nil,
         temperature: Float = 0.6, topP: Float = 1.0, repetitionPenalty: Float? = nil,
         repetitionContextSize: Int = 20
     ) {
+        self.maxTokens = maxTokens
        self.temperature = temperature
        self.topP = topP
        self.repetitionPenalty = repetitionPenalty
@@ -111,7 +116,7 @@ public struct TopPSampler: LogitSampler {
     let temp: MLXArray
     let topP: MLXArray
 
-    init(temperature: Float, topP: Float) {
+    public init(temperature: Float, topP: Float) {
         self.temp = MLXArray(temperature)
         self.topP = MLXArray(topP)
     }
@@ -149,7 +154,7 @@ public struct CategoricalSampler: LogitSampler {
     let temp: MLXArray
 
-    init(temperature: Float) {
+    public init(temperature: Float) {
         self.temp = MLXArray(temperature)
     }
@@ -178,7 +183,7 @@ public struct RepetitionContext: LogitProcessor {
     /// number of tokens to consider for repetition penalty
     let repetitionContextSize: Int
 
-    init(repetitionPenalty: Float, repetitionContextSize: Int) {
+    public init(repetitionPenalty: Float, repetitionContextSize: Int) {
         precondition(repetitionContextSize > 0)
         self.repetitionPenalty = repetitionPenalty
         self.repetitionContextSize = repetitionContextSize
@@ -250,6 +255,9 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     var processor: LogitProcessor?
     let sampler: LogitSampler
 
+    var tokenCount = 0
+    let maxTokens: Int?
+
     /// Initialize a `TokenIterator` with the given tokens. Note: this has been
     /// replaced with ``init(input:model:cache:parameters:)``.
     ///
@@ -269,6 +277,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: .init(text: y), windowSize: parameters.prefillStepSize)
     }
@@ -295,6 +304,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
 
         self.processor = parameters.processor()
         self.sampler = parameters.sampler()
+        self.maxTokens = parameters.maxTokens
 
         try prepare(input: input, windowSize: parameters.prefillStepSize)
     }
@@ -308,9 +318,11 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     ///   - processor: the logit processor
     ///   - sampler: the logit sampler
     ///   - prefillStepSize: optional prefill step size
+    ///   - maxTokens: maximum number of tokens to generate
     public init(
         input: LMInput, model: any LanguageModel, cache: [KVCache]? = nil,
-        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512
+        processor: LogitProcessor?, sampler: LogitSampler, prefillStepSize: Int = 512,
+        maxTokens: Int? = nil
     ) throws {
         self.model = model
         self.y = input.text
@@ -318,6 +330,7 @@ public struct TokenIterator: Sequence, IteratorProtocol {
 
         self.processor = processor
         self.sampler = sampler
+        self.maxTokens = maxTokens
 
         try prepare(input: input, windowSize: prefillStepSize)
     }
@@ -365,6 +378,10 @@ public struct TokenIterator: Sequence, IteratorProtocol {
     }
 
     mutating public func next() -> Int? {
+        if let maxTokens, tokenCount >= maxTokens {
+            return nil
+        }
+
         // save current value -- this will be returned
         let previousY = y
 
@@ -373,6 +390,8 @@ public struct TokenIterator: Sequence, IteratorProtocol {
         y = .init(tokens: token)
         asyncEval(token)
 
+        tokenCount += 1
+
         return previousY.tokens.item(Int.self)
     }
 }
@@ -413,24 +432,32 @@ public struct GenerateResult: Sendable {
     /// output text
     public let output: String
 
+    /// The number of tokens included in the input prompt.
+    public var promptTokenCount: Int { inputText.tokens.size }
+
+    /// The number of tokens generated by the language model.
+    public var generationTokenCount: Int { tokens.count }
+
     /// time to process the prompt / generate the first token
     public let promptTime: TimeInterval
 
     /// time to generate the remaining tokens
     public let generateTime: TimeInterval
 
+    /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {
         Double(inputText.tokens.size) / promptTime
     }
 
+    /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
         Double(tokens.count) / generateTime
     }
 
     public func summary() -> String {
         """
-        Prompt:     \(inputText.tokens.size) tokens, \(promptTokensPerSecond.formatted()) tokens/s
-        Generation: \(tokens.count) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        Prompt:     \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
         """
     }
 }
@@ -795,7 +822,7 @@ public struct GenerateCompletionInfo: Sendable {
     public let promptTime: TimeInterval
 
     /// The time interval (in seconds) taken to generate the output tokens.
-    public let generationTime: TimeInterval
+    public let generateTime: TimeInterval
 
     /// The number of tokens processed per second during the prompt phase.
     public var promptTokensPerSecond: Double {
@@ -804,7 +831,7 @@ public struct GenerateCompletionInfo: Sendable {
 
     /// The number of tokens generated per second during the generation phase.
     public var tokensPerSecond: Double {
-        Double(generationTokenCount) / generationTime
+        Double(generationTokenCount) / generateTime
     }
 
     public init(
@@ -816,7 +843,14 @@ public struct GenerateCompletionInfo: Sendable {
         self.promptTokenCount = promptTokenCount
         self.generationTokenCount = generationTokenCount
         self.promptTime = promptTime
-        self.generationTime = generationTime
+        self.generateTime = generationTime
+    }
+
+    public func summary() -> String {
+        """
+        Prompt:     \(promptTokenCount) tokens, \(promptTokensPerSecond.formatted()) tokens/s
+        Generation: \(generationTokenCount) tokens, \(tokensPerSecond.formatted()) tokens/s, \(generateTime.formatted())s
+        """
     }
 }
 
@@ -825,9 +859,31 @@
 /// This enum distinguishes between the following:
 /// - `.chunk`: A decoded string from one or more tokens generated by the language model.
 /// - `.info`: Metadata and performance statistics about the generation process.
-public enum Generation {
-    /// A generated token represented as an integer.
+public enum Generation: Sendable {
+    /// A generated token represented as a String
     case chunk(String)
 
     /// Completion information summarizing token counts and performance metrics.
     case info(GenerateCompletionInfo)
+
+    /// Generated text or nil
+    public var chunk: String? {
+        switch self {
+        case .chunk(let string): string
+        case .info: nil
+        }
+    }
+
+    /// Completion info or nil
+    public var info: GenerateCompletionInfo? {
+        switch self {
+        case .chunk: nil
+        case .info(let info): info
+        }
+    }
+
+    /// Reducer that can be used with `throttle()` to gather elements into a batch
+    @Sendable
+    public static func collect(_ batch: [Generation]?, _ element: Generation) -> [Generation] {
+        (batch ?? []) + [element]
+    }
 }
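Because the sampler and processor initializers are now `public`, a client could also assemble a capped `TokenIterator` directly instead of going through `GenerateParameters`. A hypothetical sketch (the `model` and `input` values are assumed to come from the caller):

```swift
import MLXLMCommon

// Hypothetical: hand-built iterator with a hard 128-token cap.
// next() returns nil once tokenCount reaches maxTokens.
func makeIterator(model: any LanguageModel, input: LMInput) throws -> TokenIterator {
    try TokenIterator(
        input: input, model: model,
        processor: RepetitionContext(repetitionPenalty: 1.1, repetitionContextSize: 20),
        sampler: TopPSampler(temperature: 0.7, topP: 0.9),
        maxTokens: 128)
}
```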
diff --git a/Tools/llm-tool/LLMTool.swift b/Tools/llm-tool/LLMTool.swift
index 1e50f2fd..fb87d3f6 100644
--- a/Tools/llm-tool/LLMTool.swift
+++ b/Tools/llm-tool/LLMTool.swift
@@ -87,6 +87,7 @@ struct GenerateArguments: ParsableArguments, Sendable {
 
     var generateParameters: GenerateParameters {
         GenerateParameters(
+            maxTokens: maxTokens,
             temperature: temperature, topP: topP, repetitionPenalty: repetitionPenalty,
             repetitionContextSize: repetitionContextSize)
     }
@@ -111,27 +112,18 @@ struct GenerateArguments: ParsableArguments, Sendable {
     func generate(
         input: LMInput, context: ModelContext
-    ) throws -> GenerateResult {
-        var detokenizer = NaiveStreamingDetokenizer(tokenizer: context.tokenizer)
-
-        return try MLXLMCommon.generate(
-            input: input, parameters: generateParameters, context: context
-        ) { tokens in
-            if let last = tokens.last {
-                detokenizer.append(token: last)
-            }
-
-            if let new = detokenizer.next() {
-                print(new, terminator: "")
-                fflush(stdout)
-            }
-
-            if tokens.count >= maxTokens {
-                return .stop
-            } else {
-                return .more
+    ) async throws -> GenerateCompletionInfo {
+        for await item in try MLXLMCommon.generate(
+            input: input, parameters: generateParameters, context: context)
+        {
+            switch item {
+            case .chunk(let string):
+                print(string, terminator: "")
+            case .info(let info):
+                return info
             }
         }
+        fatalError("exited loop without seeing .info")
     }
 }
@@ -317,7 +309,7 @@ struct EvaluateCommand: AsyncParsableCommand {
 
         let result = try await modelContainer.perform { [generate] context in
             let input = try await context.processor.prepare(input: userInput)
-            return try generate.generate(input: input, context: context)
+            return try await generate.generate(input: input, context: context)
         }
 
         if !generate.quiet {
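`GenerateArguments.generate` now streams chunks to stdout itself and returns the `GenerateCompletionInfo`, so call sites only add `await`. A hypothetical call site mirroring `EvaluateCommand` above (`modelContainer`, `userInput`, and `generate` assumed supplied by the surrounding command):

```swift
import MLXLMCommon

// Hypothetical wrapper: run generation, then print the stats summary
// (summary() on GenerateCompletionInfo is added in Evaluate.swift above).
func evaluate(
    modelContainer: ModelContainer, userInput: UserInput, generate: GenerateArguments
) async throws {
    let info = try await modelContainer.perform { [generate] context in
        let input = try await context.processor.prepare(input: userInput)
        return try await generate.generate(input: input, context: context)
    }
    print("\n" + info.summary())
}
```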
diff --git a/Tools/llm-tool/LoraCommands.swift b/Tools/llm-tool/LoraCommands.swift
index 60c16216..7a9a1b8c 100644
--- a/Tools/llm-tool/LoraCommands.swift
+++ b/Tools/llm-tool/LoraCommands.swift
@@ -307,7 +307,7 @@ struct LoRAEvalCommand: AsyncParsableCommand {
         // generate and print the result
         let result = try await modelContainer.perform { [generate] context in
             let input = try await context.processor.prepare(input: .init(prompt: prompt))
-            return try generate.generate(input: input, context: context)
+            return try await generate.generate(input: input, context: context)
         }
 
         if !generate.quiet {
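The project-file changes below wire `AsyncAlgorithms` into the two app targets. For a SwiftPM-based consumer, the equivalent is a single package dependency; a sketch with a hypothetical package and target name:

```swift
// swift-tools-version: 5.9
// Hypothetical Package.swift showing the SwiftPM equivalent of the
// swift-async-algorithms dependency added to the Xcode project below.
import PackageDescription

let package = Package(
    name: "MyLLMApp",
    platforms: [.macOS(.v14), .iOS(.v16)],
    dependencies: [
        .package(url: "https://github.com/apple/swift-async-algorithms.git", from: "1.0.3")
    ],
    targets: [
        .executableTarget(
            name: "MyLLMApp",
            dependencies: [
                .product(name: "AsyncAlgorithms", package: "swift-async-algorithms")
            ])
    ]
)
```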
diff --git a/mlx-swift-examples.xcodeproj/project.pbxproj b/mlx-swift-examples.xcodeproj/project.pbxproj
index 896e276b..daec81d7 100644
--- a/mlx-swift-examples.xcodeproj/project.pbxproj
+++ b/mlx-swift-examples.xcodeproj/project.pbxproj
@@ -36,6 +36,8 @@
 		C32A18482D00E1540092A5B6 /* MLX in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18472D00E1540092A5B6 /* MLX */; };
 		C32A184A2D00E1540092A5B6 /* MLXNN in Frameworks */ = {isa = PBXBuildFile; productRef = C32A18492D00E1540092A5B6 /* MLXNN */; };
 		C32A184C2D00E1540092A5B6 /* MLXOptimizers in Frameworks */ = {isa = PBXBuildFile; productRef = C32A184B2D00E1540092A5B6 /* MLXOptimizers */; };
+		C32B4C6D2DA7136000EF663D /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C32B4C6C2DA7136000EF663D /* AsyncAlgorithms */; };
+		C32B4C6F2DA71ADC00EF663D /* AsyncAlgorithms in Frameworks */ = {isa = PBXBuildFile; productRef = C32B4C6E2DA71ADC00EF663D /* AsyncAlgorithms */; };
 		C34E48F52B696F0B00FCB841 /* LLMTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E48F42B696F0B00FCB841 /* LLMTool.swift */; };
 		C34E49242B6A026F00FCB841 /* MNISTTool.swift in Sources */ = {isa = PBXBuildFile; fileRef = C34E49232B6A026F00FCB841 /* MNISTTool.swift */; };
 		C34E49292B6A028100FCB841 /* ArgumentParser in Frameworks */ = {isa = PBXBuildFile; productRef = C34E49282B6A028100FCB841 /* ArgumentParser */; };
@@ -254,6 +256,7 @@
 			buildActionMask = 2147483647;
 			files = (
 				0A8284222D13863900BEF338 /* MLXVLM in Frameworks */,
+				C32B4C6D2DA7136000EF663D /* AsyncAlgorithms in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -335,6 +338,7 @@
 			files = (
 				C32A18052CFFD19F0092A5B6 /* MLXLLM in Frameworks */,
 				81695B412BA373D300F260D8 /* MarkdownUI in Frameworks */,
+				C32B4C6F2DA71ADC00EF663D /* AsyncAlgorithms in Frameworks */,
 			);
 			runOnlyForDeploymentPostprocessing = 0;
 		};
@@ -611,6 +615,7 @@
 			name = VLMEval;
 			packageProductDependencies = (
 				0AC74ED92D136223003C90A7 /* MLXVLM */,
+				C32B4C6C2DA7136000EF663D /* AsyncAlgorithms */,
 			);
 			productName = VLMEval;
 			productReference = 0AC74EBB2D136221003C90A7 /* VLMEval.app */;
@@ -804,6 +809,7 @@
 			packageProductDependencies = (
 				81695B402BA373D300F260D8 /* MarkdownUI */,
 				C32A18042CFFD19F0092A5B6 /* MLXLLM */,
+				C32B4C6E2DA71ADC00EF663D /* AsyncAlgorithms */,
 			);
 			productName = LLMEval;
 			productReference = C3A8B3DC2B92A29E0002EFB8 /* LLMEval.app */;
@@ -864,6 +870,7 @@
 				C36BEFF02BC32A8C002D4AFE /* XCRemoteSwiftPackageReference "Progress" */,
 				C397D8F22CD2F60B00B87EE2 /* XCLocalSwiftPackageReference "Source/.." */,
 				C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */,
+				C32B4C6B2DA7132C00EF663D /* XCRemoteSwiftPackageReference "swift-async-algorithms" */,
 			);
 			productRefGroup = C39273752B606A0A00368D5D /* Products */;
 			projectDirPath = "";
@@ -2714,6 +2721,14 @@
 			minimumVersion = 0.21.2;
 		};
 	};
+	C32B4C6B2DA7132C00EF663D /* XCRemoteSwiftPackageReference "swift-async-algorithms" */ = {
+		isa = XCRemoteSwiftPackageReference;
+		repositoryURL = "https://github.com/apple/swift-async-algorithms.git";
+		requirement = {
+			kind = upToNextMajorVersion;
+			minimumVersion = 1.0.3;
+		};
+	};
 	C34E491A2B69C43600FCB841 /* XCRemoteSwiftPackageReference "GzipSwift" */ = {
 		isa = XCRemoteSwiftPackageReference;
 		repositoryURL = "https://github.com/1024jp/GzipSwift";
@@ -2811,6 +2826,16 @@
 		package = C32A18442D00E13E0092A5B6 /* XCRemoteSwiftPackageReference "mlx-swift" */;
 		productName = MLXOptimizers;
 	};
+	C32B4C6C2DA7136000EF663D /* AsyncAlgorithms */ = {
+		isa = XCSwiftPackageProductDependency;
+		package = C32B4C6B2DA7132C00EF663D /* XCRemoteSwiftPackageReference "swift-async-algorithms" */;
+		productName = AsyncAlgorithms;
+	};
+	C32B4C6E2DA71ADC00EF663D /* AsyncAlgorithms */ = {
+		isa = XCSwiftPackageProductDependency;
+		package = C32B4C6B2DA7132C00EF663D /* XCRemoteSwiftPackageReference "swift-async-algorithms" */;
+		productName = AsyncAlgorithms;
+	};
 	C34E49282B6A028100FCB841 /* ArgumentParser */ = {
 		isa = XCSwiftPackageProductDependency;
 		package = C392736E2B60699100368D5D /* XCRemoteSwiftPackageReference "swift-argument-parser" */;
diff --git a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
index 3ed84667..b2c346d5 100644
--- a/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
+++ b/mlx-swift-examples.xcodeproj/project.xcworkspace/xcshareddata/swiftpm/Package.resolved
@@ -1,5 +1,5 @@
 {
-  "originHash" : "347ce608ed233db4ed416d22692a515e7f4fd2fd3eed7904f75bb8b35eb5366c",
+  "originHash" : "369f2014f0f4b1785f2b2642d3b4a3cbd3220a38b18d03ac9d74965949a0f0ba",
   "pins" : [
     {
       "identity" : "gzipswift",
@@ -55,6 +55,15 @@
         "version" : "1.4.0"
       }
     },
+    {
+      "identity" : "swift-async-algorithms",
+      "kind" : "remoteSourceControl",
+      "location" : "https://github.com/apple/swift-async-algorithms.git",
+      "state" : {
+        "revision" : "4c3ea81f81f0a25d0470188459c6d4bf20cf2f97",
+        "version" : "1.0.3"
+      }
+    },
     {
       "identity" : "swift-cmark",
       "kind" : "remoteSourceControl",