Skip to content

Commit a7bf219

Browse files
Improve BPETokenizer.readMerges performance (#169)
From 240ms to 19ms on M1 MacBook Pro Co-authored-by: Alejandro Isaza <[email protected]>
1 parent fbef6c3 commit a7bf219

File tree

1 file changed

+35
-12
lines changed

1 file changed

+35
-12
lines changed

swift/StableDiffusion/tokenizer/BPETokenizer+Reading.swift

+35-12
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,42 @@ extension BPETokenizer {
1717

1818
/// Read merges.txt file at URL into a dictionary mapping bigrams to the line number/rank/priority
1919
static func readMerges(url: URL) throws -> [TokenPair: Int] {
20-
let content = try String(contentsOf: url)
21-
let lines = content.split(separator: "\n")
22-
23-
let merges: [(TokenPair, Int)] = try lines.enumerated().compactMap { (index, line) in
24-
if line.hasPrefix("#") {
25-
return nil
26-
}
27-
let pair = line.split(separator: " ")
28-
if pair.count != 2 {
29-
throw FileReadError.invalidMergeFileLine(index+1)
20+
let data = try Data(contentsOf: url)
21+
var merges = [TokenPair: Int]()
22+
var index = 0
23+
var line = [UInt8]()
24+
for byte in data {
25+
if byte == UInt8(ascii: "\n") {
26+
if let pair = try parseMergesLine(line, index: index) {
27+
merges[pair] = index
28+
}
29+
line.removeAll(keepingCapacity: true)
30+
index += 1
31+
} else {
32+
line.append(byte)
3033
}
31-
return (TokenPair(String(pair[0]), String(pair[1])),index)
3234
}
33-
return [TokenPair : Int](uniqueKeysWithValues: merges)
35+
36+
return merges
37+
}
38+
39+
static func parseMergesLine(_ line: [UInt8], index: Int) throws -> TokenPair? {
40+
if line.isEmpty || line.first == UInt8(ascii: "#") {
41+
return nil
42+
}
43+
let pair = line.split(separator: UInt8(ascii: " "))
44+
if pair.count != 2 {
45+
throw FileReadError.invalidMergeFileLine(index + 1)
46+
}
47+
return TokenPair( String(bytes: pair[0]), String(bytes: pair[1]))
48+
}
49+
}
50+
51+
extension String {
52+
init(bytes: some Collection<UInt8>) {
53+
self.init(unsafeUninitializedCapacity: bytes.count) { pointer in
54+
_ = pointer.initialize(fromContentsOf: bytes)
55+
return bytes.count
56+
}
3457
}
3558
}

0 commit comments

Comments
 (0)