import Accelerate
import CoreML

+/// How to space timesteps for inference
+public enum TimeStepSpacing {
+    case linspace
+    case leading
+    case karras
+}
+
/// A scheduler used to compute a de-noised image
///
/// This implementation matches:
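For reference, a standalone sketch (not part of the diff) of the timestep grids the linspace and leading modes above produce. The linspace helper here is a local stand-in for the repo's own, and the concrete numbers are only an illustration for stepCount = 5.

import Foundation

func linspace(_ start: Float, _ end: Float, _ count: Int) -> [Float] {
    let scale = (end - start) / Float(count - 1)
    return (0..<count).map { Float($0) * scale + start }
}

let trainStepCount = 1000
let stepCount = 5

// .linspace: evenly spaced floats over [0, trainStepCount - 1], rounded, high to low.
let linspaceSteps: [Int] = linspace(0, Float(trainStepCount - 1), stepCount + 1)
    .dropFirst().reversed().map { Int(round($0)) }
// [999, 799, 599, 400, 200]

// .leading: integer strides counted from the start of the training schedule.
let stepRatio = (trainStepCount - 1) / (stepCount + 1)   // 166
let leadingSteps: [Int] = (0...stepCount).map { 1 + $0 * stepRatio }.dropFirst().reversed()
// [831, 665, 499, 333, 167]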
@@ -32,6 +39,8 @@ public final class DPMSolverMultistepScheduler: Scheduler {
    public let solverOrder = 2
    private(set) var lowerOrderStepped = 0

+    private var usingKarrasSigmas = false
+
    /// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
    /// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
    public let useLowerOrderFinal = true
@@ -47,13 +56,15 @@ public final class DPMSolverMultistepScheduler: Scheduler {
    /// - betaSchedule: Method to schedule betas from betaStart to betaEnd
    /// - betaStart: The starting value of beta for inference
    /// - betaEnd: The end value for beta for inference
+    /// - timeStepSpacing: How to space time steps
    /// - Returns: A scheduler ready for its first step
    public init(
        stepCount: Int = 50,
        trainStepCount: Int = 1000,
        betaSchedule: BetaSchedule = .scaledLinear,
        betaStart: Float = 0.00085,
-        betaEnd: Float = 0.012
+        betaEnd: Float = 0.012,
+        timeStepSpacing: TimeStepSpacing = .linspace
    ) {
        self.trainStepCount = trainStepCount
        self.inferenceStepCount = stepCount
@@ -72,20 +83,60 @@ public final class DPMSolverMultistepScheduler: Scheduler {
        }
        self.alphasCumProd = alphasCumProd

-        // Currently we only support VP-type noise shedule
-        self.alpha_t = vForce.sqrt(self.alphasCumProd)
-        self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
-        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
+        switch timeStepSpacing {
+        case .linspace:
+            self.timeSteps = linspace(0, Float(self.trainStepCount - 1), stepCount + 1).dropFirst().reversed().map { Int(round($0)) }
+            self.alpha_t = vForce.sqrt(self.alphasCumProd)
+            self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
+        case .leading:
+            let lastTimeStep = trainStepCount - 1
+            let stepRatio = lastTimeStep / (stepCount + 1)
+            // Creates integer timesteps by multiplying by ratio
+            self.timeSteps = (0...stepCount).map { 1 + $0 * stepRatio }.dropFirst().reversed()
+            self.alpha_t = vForce.sqrt(self.alphasCumProd)
+            self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
+        case .karras:
+            // sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
+            let scaled = vDSP.multiply(
+                subtraction: ([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd),
+                subtraction: (vDSP.divide(1, self.alphasCumProd), [Float](repeating: 0, count: self.alphasCumProd.count))
+            )
+            let sigmas = vForce.sqrt(scaled)
+            let logSigmas = sigmas.map { log($0) }
+
+            let sigmaMin = sigmas.first!
+            let sigmaMax = sigmas.last!
+            let rho: Float = 7
+            let ramp = linspace(0, 1, stepCount)
+            let minInvRho = pow(sigmaMin, (1 / rho))
+            let maxInvRho = pow(sigmaMax, (1 / rho))

-        self.timeSteps = linspace(0, Float(self.trainStepCount - 1), stepCount + 1).dropFirst().reversed().map { Int(round($0)) }
+            var karrasSigmas = ramp.map { pow(maxInvRho + $0 * (minInvRho - maxInvRho), rho) }
+            let karrasTimeSteps = karrasSigmas.map { sigmaToTimestep(sigma: $0, logSigmas: logSigmas) }
+            self.timeSteps = karrasTimeSteps
+
+            karrasSigmas.append(karrasSigmas.last!)
+
+            self.alpha_t = vDSP.divide(1, vForce.sqrt(vDSP.add(1, vDSP.square(karrasSigmas))))
+            self.sigma_t = vDSP.multiply(karrasSigmas, self.alpha_t)
+            usingKarrasSigmas = true
+        }
+
+        self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
+    }
+
+    func timestepToIndex(_ timestep: Int) -> Int {
+        guard usingKarrasSigmas else { return timestep }
+        return self.timeSteps.firstIndex(of: timestep) ?? 0
    }

    /// Convert the model output to the corresponding type the algorithm needs.
    /// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
    func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
        assert(modelOutput.scalarCount == sample.scalarCount)
        let scalarCount = modelOutput.scalarCount
-        let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])
+        let sigmaIndex = timestepToIndex(timestep)
+        let (alpha_t, sigma_t) = (self.alpha_t[sigmaIndex], self.sigma_t[sigmaIndex])

        return MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { scalars, _ in
            assert(scalars.count == scalarCount)
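As a companion to the .karras branch above, a plain-array sketch (not part of the diff) of the Karras et al. (2022) rho-schedule that the vDSP/vForce code implements: sigma_i = (sigmaMax^(1/rho) + ramp_i * (sigmaMin^(1/rho) - sigmaMax^(1/rho)))^rho with rho = 7. The sigmaMin/sigmaMax values below are placeholders; in the scheduler they come from the training schedule's (1 - alphasCumProd) / alphasCumProd sigmas.

import Foundation

func karrasSigmas(sigmaMin: Float, sigmaMax: Float, stepCount: Int, rho: Float = 7) -> [Float] {
    let minInvRho = pow(sigmaMin, 1 / rho)
    let maxInvRho = pow(sigmaMax, 1 / rho)
    return (0..<stepCount).map { i in
        let ramp = Float(i) / Float(stepCount - 1)   // 0 ... 1
        return pow(maxInvRho + ramp * (minInvRho - maxInvRho), rho)
    }
}

// Ten sigmas descending from sigmaMax to sigmaMin; each is then mapped back to a
// training timestep via sigmaToTimestep(sigma:logSigmas:) as in the code above.
let sigmas = karrasSigmas(sigmaMin: 0.03, sigmaMax: 14.6, stepCount: 10)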
@@ -108,9 +159,11 @@ public final class DPMSolverMultistepScheduler: Scheduler {
        prevTimestep: Int,
        sample: MLShapedArray<Float32>
    ) -> MLShapedArray<Float32> {
-        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
-        let p_alpha_t = Double(alpha_t[prevTimestep])
-        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
+        let prevIndex = timestepToIndex(prevTimestep)
+        let currIndex = timestepToIndex(timestep)
+        let (p_lambda_t, lambda_s) = (Double(lambda_t[prevIndex]), Double(lambda_t[currIndex]))
+        let p_alpha_t = Double(alpha_t[prevIndex])
+        let (p_sigma_t, sigma_s) = (Double(sigma_t[prevIndex]), Double(sigma_t[currIndex]))
        let h = p_lambda_t - lambda_s
        // x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
        let x_t = weightedSum(
@@ -130,9 +183,13 @@ public final class DPMSolverMultistepScheduler: Scheduler {
    ) -> MLShapedArray<Float32> {
        let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
        let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
-        let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
-        let p_alpha_t = Double(alpha_t[t])
-        let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
+        let (p_lambda_t, lambda_s0, lambda_s1) = (
+            Double(lambda_t[timestepToIndex(t)]),
+            Double(lambda_t[timestepToIndex(s0)]),
+            Double(lambda_t[timestepToIndex(s1)])
+        )
+        let p_alpha_t = Double(alpha_t[timestepToIndex(t)])
+        let (p_sigma_t, sigma_s0) = (Double(sigma_t[timestepToIndex(t)]), Double(sigma_t[timestepToIndex(s0)]))
        let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
        let r0 = h_0 / h
        let D0 = m0
@@ -186,3 +243,31 @@ public final class DPMSolverMultistepScheduler: Scheduler {
        return prevSample
    }
}
+
+func sigmaToTimestep(sigma: Float, logSigmas: [Float]) -> Int {
+    let logSigma = log(sigma)
+    let dists = logSigmas.map { logSigma - $0 }
+
+    // last index that is not negative, clipped to last index - 1
+    var lowIndex = dists.reduce(-1) { partialResult, dist in
+        return dist >= 0 && partialResult < dists.endIndex - 2 ? partialResult + 1 : partialResult
+    }
+    lowIndex = max(lowIndex, 0)
+    let highIndex = lowIndex + 1
+
+    let low = logSigmas[lowIndex]
+    let high = logSigmas[highIndex]
+
+    // Interpolate sigmas
+    let w = ((low - logSigma) / (low - high)).clipped(to: 0...1)
+
+    // Transform interpolated value to time range
+    let t = (1 - w) * Float(lowIndex) + w * Float(highIndex)
+    return Int(round(t))
+}
+
+extension FloatingPoint {
+    func clipped(to range: ClosedRange<Self>) -> Self {
+        return min(max(self, range.lowerBound), range.upperBound)
+    }
+}
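A hypothetical usage sketch (not part of the diff) showing how a caller opts into the new spacing; only the initializer parameters are taken from the code above, and the surrounding pipeline is assumed unchanged.

let scheduler = DPMSolverMultistepScheduler(
    stepCount: 25,
    timeStepSpacing: .karras
)
// With .karras, timeSteps comes from the Karras sigma grid and the solver looks up
// alpha_t / sigma_t / lambda_t by index via timestepToIndex(_:) rather than by raw
// timestep value; with the default .linspace the behavior matches the previous code.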