forked from ml-explore/mlx-swift-examples
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path MNISTTool.swift
102 lines (76 loc) · 2.79 KB
/
MNISTTool.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Copyright © 2024 Apple Inc.
import ArgumentParser
import Foundation
import MLX
import MLXNN
import MLXOptimizers
import MLXRandom
import MNIST
@main
struct MNISTTool: AsyncParsableCommand {
    /// Top-level command configuration: `train` is the only subcommand and also
    /// the default, so `mnist-tool` with no arguments runs training.
    ///
    /// Declared `let` rather than `var`: the configuration never changes after
    /// startup, and a mutable `static var` is shared mutable state that Swift 6
    /// strict-concurrency checking flags as unsafe.
    static let configuration = CommandConfiguration(
        abstract: "Command line tool for training mnist models",
        subcommands: [Train.self],
        defaultSubcommand: Train.self)
}
/// Lets `--device cpu` / `--device gpu` parse directly into `MLX.DeviceType`
/// by delegating to the enum's raw-value initializer; unknown strings fail
/// parsing (returns `nil`), which ArgumentParser reports as a usage error.
extension MLX.DeviceType: ExpressibleByArgument {
    public init?(argument: String) {
        guard let device = Self(rawValue: argument) else { return nil }
        self = device
    }
}
/// Trains a LeNet model on MNIST, printing test-set accuracy and wall-clock
/// time after every epoch. Downloads the dataset into `--data` if needed.
struct Train: AsyncParsableCommand {

    @Option(name: .long, help: "Directory with the training data")
    var data: String

    @Option(name: .long, help: "The PRNG seed")
    var seed: UInt64 = 0

    @Option var batchSize = 256
    @Option var epochs = 20
    @Option var learningRate: Float = 1e-1
    @Option var device = DeviceType.gpu
    @Flag var compile = false

    func run() async throws {
        Device.setDefault(device: Device(device))
        MLXRandom.seed(seed)
        var generator: RandomNumberGenerator = SplitMix64(seed: seed)

        // Fetch the MNIST archives into the requested directory (created if
        // missing), then load the train/test image and label tensors.
        // `dataset` deliberately does not reuse the name of the `data` option.
        let dataDirectory = URL(filePath: data)
        try FileManager.default.createDirectory(
            at: dataDirectory, withIntermediateDirectories: true)
        try await download(into: dataDirectory)
        let dataset = try load(from: dataDirectory)

        let trainImages = dataset[.init(.training, .images)]!
        let trainLabels = dataset[.init(.training, .labels)]!
        let testImages = dataset[.init(.test, .images)]!
        let testLabels = dataset[.init(.test, .labels)]!

        // Build the model and SGD optimizer; materialize the freshly
        // initialized parameters before training begins.
        let model = LeNet()
        eval(model.parameters())

        let lossAndGrad = valueAndGrad(model: model, loss)
        let optimizer = SGD(learningRate: learningRate)

        // One SGD step: compute loss + gradients for a batch and apply the
        // parameter update, returning the batch loss.
        func step(_ x: MLXArray, _ y: MLXArray) -> MLXArray {
            let (loss, grads) = lossAndGrad(model, x, y)
            optimizer.update(model: model, gradients: grads)
            return loss
        }

        // Optionally compile the step. Model and optimizer appear as both
        // inputs and outputs because the step mutates their state.
        let trainStep: (MLXArray, MLXArray) -> MLXArray
        if compile {
            trainStep = MLX.compile(inputs: [model, optimizer], outputs: [model, optimizer], step)
        } else {
            trainStep = step
        }

        for epoch in 0 ..< epochs {
            let start = Date.timeIntervalSinceReferenceDate

            for (x, y) in iterateBatches(
                batchSize: batchSize, x: trainImages, y: trainLabels, using: &generator)
            {
                _ = trainStep(x, y)

                // eval the parameters so the next iteration is independent
                eval(model, optimizer)
            }

            let accuracy = eval(model: model, x: testImages, y: testLabels)
            let end = Date.timeIntervalSinceReferenceDate

            print(
                """
                Epoch \(epoch): test accuracy \(accuracy.item(Float.self).formatted())
                Time: \((end - start).formatted())
                """
            )
        }
    }
}