Skip to content

Commit 94814cf

Browse files
authored
Sigma and timestep spacing configuration for DPMSolverMultistepScheduler (#265)
* `leading` timestep spacing for DPMSolverMultistepScheduler * Karras sigmas and timesteps * Make sigmas/timesteps externally configurable
1 parent d0d05ce commit 94814cf

4 files changed

+102
-15
lines changed

swift/StableDiffusion/pipeline/DPMSolverMultistepScheduler.swift

Lines changed: 98 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,13 @@
44
import Accelerate
55
import CoreML
66

7+
/// Selects how `DPMSolverMultistepScheduler` lays out its inference timesteps
/// (and, for `.karras`, the sigma table that goes with them).
public enum TimeStepSpacing {
    /// Evenly spaced values between 0 and `trainStepCount - 1`, rounded to the nearest integer.
    case linspace

    /// Integer multiples of a fixed step ratio (`(trainStepCount - 1) / (stepCount + 1)`), offset by one.
    case leading

    /// Timesteps looked up from a Karras-style sigma ramp (built with `rho = 7`).
    case karras
}
13+
714
/// A scheduler used to compute a de-noised image
815
///
916
/// This implementation matches:
@@ -32,6 +39,8 @@ public final class DPMSolverMultistepScheduler: Scheduler {
3239
public let solverOrder = 2
3340
private(set) var lowerOrderStepped = 0
3441

42+
private var usingKarrasSigmas = false
43+
3544
/// Whether to use lower-order solvers in the final steps. Only valid for less than 15 inference steps.
3645
/// We empirically find this trick can stabilize the sampling of DPM-Solver, especially with 10 or fewer steps.
3746
public let useLowerOrderFinal = true
@@ -47,13 +56,15 @@ public final class DPMSolverMultistepScheduler: Scheduler {
4756
/// - betaSchedule: Method to schedule betas from betaStart to betaEnd
4857
/// - betaStart: The starting value of beta for inference
4958
/// - betaEnd: The end value for beta for inference
59+
/// - timeStepSpacing: How to space time steps
5060
/// - Returns: A scheduler ready for its first step
5161
public init(
5262
stepCount: Int = 50,
5363
trainStepCount: Int = 1000,
5464
betaSchedule: BetaSchedule = .scaledLinear,
5565
betaStart: Float = 0.00085,
56-
betaEnd: Float = 0.012
66+
betaEnd: Float = 0.012,
67+
timeStepSpacing: TimeStepSpacing = .linspace
5768
) {
5869
self.trainStepCount = trainStepCount
5970
self.inferenceStepCount = stepCount
@@ -72,20 +83,60 @@ public final class DPMSolverMultistepScheduler: Scheduler {
7283
}
7384
self.alphasCumProd = alphasCumProd
7485

75-
// Currently we only support VP-type noise schedule
76-
self.alpha_t = vForce.sqrt(self.alphasCumProd)
77-
self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
78-
self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
86+
switch timeStepSpacing {
87+
case .linspace:
88+
self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount+1).dropFirst().reversed().map { Int(round($0)) }
89+
self.alpha_t = vForce.sqrt(self.alphasCumProd)
90+
self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
91+
case .leading:
92+
let lastTimeStep = trainStepCount - 1
93+
let stepRatio = lastTimeStep / (stepCount + 1)
94+
// Creates integer timesteps by multiplying by ratio
95+
self.timeSteps = (0...stepCount).map { 1 + $0 * stepRatio }.dropFirst().reversed()
96+
self.alpha_t = vForce.sqrt(self.alphasCumProd)
97+
self.sigma_t = vForce.sqrt(vDSP.subtract([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd))
98+
case .karras:
99+
// sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
100+
let scaled = vDSP.multiply(
101+
subtraction: ([Float](repeating: 1, count: self.alphasCumProd.count), self.alphasCumProd),
102+
subtraction: (vDSP.divide(1, self.alphasCumProd), [Float](repeating: 0, count: self.alphasCumProd.count))
103+
)
104+
let sigmas = vForce.sqrt(scaled)
105+
let logSigmas = sigmas.map { log($0) }
106+
107+
let sigmaMin = sigmas.first!
108+
let sigmaMax = sigmas.last!
109+
let rho: Float = 7
110+
let ramp = linspace(0, 1, stepCount)
111+
let minInvRho = pow(sigmaMin, (1 / rho))
112+
let maxInvRho = pow(sigmaMax, (1 / rho))
79113

80-
self.timeSteps = linspace(0, Float(self.trainStepCount-1), stepCount+1).dropFirst().reversed().map { Int(round($0)) }
114+
var karrasSigmas = ramp.map { pow(maxInvRho + $0 * (minInvRho - maxInvRho), rho) }
115+
let karrasTimeSteps = karrasSigmas.map { sigmaToTimestep(sigma: $0, logSigmas: logSigmas) }
116+
self.timeSteps = karrasTimeSteps
117+
118+
karrasSigmas.append(karrasSigmas.last!)
119+
120+
self.alpha_t = vDSP.divide(1, vForce.sqrt(vDSP.add(1, vDSP.square(karrasSigmas))))
121+
self.sigma_t = vDSP.multiply(karrasSigmas, self.alpha_t)
122+
usingKarrasSigmas = true
123+
}
124+
125+
self.lambda_t = zip(self.alpha_t, self.sigma_t).map { α, σ in log(α) - log(σ) }
126+
}
127+
128+
/// Translates a timestep into the position used to index `alpha_t`, `sigma_t` and `lambda_t`.
///
/// - Parameter timestep: A timestep value handed to the scheduler.
/// - Returns: `timestep` itself when Karras sigmas are not in use; otherwise the position of the
///   first matching entry in `timeSteps`, or `0` when the timestep is not present.
func timestepToIndex(_ timestep: Int) -> Int {
    // Without Karras sigmas the tables are indexed directly by timestep.
    if !usingKarrasSigmas {
        return timestep
    }

    // The Karras tables are built from the same sigma ramp as `timeSteps`,
    // so a timestep's position in `timeSteps` is also its position in the tables.
    if let position = timeSteps.firstIndex(of: timestep) {
        return position
    }
    return 0
}
82132

83133
/// Convert the model output to the corresponding type the algorithm needs.
84134
/// This implementation is for second-order DPM-Solver++ assuming epsilon prediction.
85135
func convertModelOutput(modelOutput: MLShapedArray<Float32>, timestep: Int, sample: MLShapedArray<Float32>) -> MLShapedArray<Float32> {
86136
assert(modelOutput.scalarCount == sample.scalarCount)
87137
let scalarCount = modelOutput.scalarCount
88-
let (alpha_t, sigma_t) = (self.alpha_t[timestep], self.sigma_t[timestep])
138+
let sigmaIndex = timestepToIndex(timestep)
139+
let (alpha_t, sigma_t) = (self.alpha_t[sigmaIndex], self.sigma_t[sigmaIndex])
89140

90141
return MLShapedArray(unsafeUninitializedShape: modelOutput.shape) { scalars, _ in
91142
assert(scalars.count == scalarCount)
@@ -108,9 +159,11 @@ public final class DPMSolverMultistepScheduler: Scheduler {
108159
prevTimestep: Int,
109160
sample: MLShapedArray<Float32>
110161
) -> MLShapedArray<Float32> {
111-
let (p_lambda_t, lambda_s) = (Double(lambda_t[prevTimestep]), Double(lambda_t[timestep]))
112-
let p_alpha_t = Double(alpha_t[prevTimestep])
113-
let (p_sigma_t, sigma_s) = (Double(sigma_t[prevTimestep]), Double(sigma_t[timestep]))
162+
let prevIndex = timestepToIndex(prevTimestep)
163+
let currIndex = timestepToIndex(timestep)
164+
let (p_lambda_t, lambda_s) = (Double(lambda_t[prevIndex]), Double(lambda_t[currIndex]))
165+
let p_alpha_t = Double(alpha_t[prevIndex])
166+
let (p_sigma_t, sigma_s) = (Double(sigma_t[prevIndex]), Double(sigma_t[currIndex]))
114167
let h = p_lambda_t - lambda_s
115168
// x_t = (sigma_t / sigma_s) * sample - (alpha_t * (torch.exp(-h) - 1.0)) * model_output
116169
let x_t = weightedSum(
@@ -130,9 +183,13 @@ public final class DPMSolverMultistepScheduler: Scheduler {
130183
) -> MLShapedArray<Float32> {
131184
let (s0, s1) = (timesteps[back: 1], timesteps[back: 2])
132185
let (m0, m1) = (modelOutputs[back: 1], modelOutputs[back: 2])
133-
let (p_lambda_t, lambda_s0, lambda_s1) = (Double(lambda_t[t]), Double(lambda_t[s0]), Double(lambda_t[s1]))
134-
let p_alpha_t = Double(alpha_t[t])
135-
let (p_sigma_t, sigma_s0) = (Double(sigma_t[t]), Double(sigma_t[s0]))
186+
let (p_lambda_t, lambda_s0, lambda_s1) = (
187+
Double(lambda_t[timestepToIndex(t)]),
188+
Double(lambda_t[timestepToIndex(s0)]),
189+
Double(lambda_t[timestepToIndex(s1)])
190+
)
191+
let p_alpha_t = Double(alpha_t[timestepToIndex(t)])
192+
let (p_sigma_t, sigma_s0) = (Double(sigma_t[timestepToIndex(t)]), Double(sigma_t[timestepToIndex(s0)]))
136193
let (h, h_0) = (p_lambda_t - lambda_s0, lambda_s0 - lambda_s1)
137194
let r0 = h_0 / h
138195
let D0 = m0
@@ -186,3 +243,31 @@ public final class DPMSolverMultistepScheduler: Scheduler {
186243
return prevSample
187244
}
188245
}
246+
247+
/// Maps a noise level to the (rounded) index of the closest entry in a table of log-sigmas.
///
/// The lookup happens in log space: the two neighbouring table entries that bracket
/// `log(sigma)` are located, the fractional position between them is computed and clamped
/// to `0...1`, and the interpolated index is rounded to the nearest integer.
///
/// - Parameters:
///   - sigma: The noise level to look up.
///   - logSigmas: Natural logarithms of the reference sigmas, one per timestep index.
///     The bracketing logic presumes ascending order — confirm against the caller.
/// - Returns: An index into `logSigmas`. Returns `0` when `logSigmas` has fewer than two
///   entries, because no bracketing pair exists to interpolate between.
func sigmaToTimestep(sigma: Float, logSigmas: [Float]) -> Int {
    // Interpolation needs `lowIndex` and `lowIndex + 1`; bail out instead of trapping on a short table.
    guard logSigmas.count >= 2 else { return 0 }

    let logSigma = log(sigma)

    // Last entry whose distance to `logSigma` is not negative (0 if there is none),
    // clamped to `count - 2` so that `highIndex` stays in bounds.
    let lastNotAbove = logSigmas.lastIndex { logSigma - $0 >= 0 } ?? 0
    let lowIndex = min(lastNotAbove, logSigmas.count - 2)
    let highIndex = lowIndex + 1

    let low = logSigmas[lowIndex]
    let high = logSigmas[highIndex]

    // Fractional position of `logSigma` between the two neighbours.
    // Equal neighbours would give 0/0 = NaN (which traps in `Int(_:)`), so stay on `lowIndex`.
    let span = low - high
    let rawWeight = span == 0 ? 0 : (low - logSigma) / span
    let w = min(max(rawWeight, 0), 1)

    // Transform the interpolated position into the index range.
    let t = (1 - w) * Float(lowIndex) + w * Float(highIndex)
    return Int(round(t))
}
268+
269+
extension FloatingPoint {
    /// Returns this value limited to `range`.
    ///
    /// - Parameter range: The closed interval the result must lie in.
    /// - Returns: `range.lowerBound` if the value is below the range, `range.upperBound`
    ///   if it is above, and the value itself otherwise.
    func clipped(to range: ClosedRange<Self>) -> Self {
        // Raise to the lower bound first, then cap at the upper bound.
        let raisedToFloor = Swift.max(self, range.lowerBound)
        let cappedAtCeiling = Swift.min(raisedToFloor, range.upperBound)
        return cappedAtCeiling
    }
}

swift/StableDiffusion/pipeline/StableDiffusionPipeline.Configuration.swift

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ public struct PipelineConfiguration: Hashable {
4444
public var useDenoisedIntermediates: Bool = false
4545
/// The type of Scheduler to use.
4646
public var schedulerType: StableDiffusionScheduler = .pndmScheduler
47+
/// The spacing to use for scheduler sigmas and time steps. Only supported when using `.dpmppScheduler`.
48+
public var schedulerTimestepSpacing: TimeStepSpacing = .linspace
4749
/// The type of RNG to use
4850
public var rngType: StableDiffusionRNG = .numpyRNG
4951
/// Scale factor to use on the latent after encoding

swift/StableDiffusion/pipeline/StableDiffusionPipeline.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ public struct StableDiffusionPipeline: StableDiffusionPipelineProtocol {
228228
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
229229
switch config.schedulerType {
230230
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
231-
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
231+
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount, timeStepSpacing: config.schedulerTimestepSpacing)
232232
}
233233
}
234234

swift/StableDiffusion/pipeline/StableDiffusionXLPipeline.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ public struct StableDiffusionXLPipeline: StableDiffusionPipelineProtocol {
175175
let scheduler: [Scheduler] = (0..<config.imageCount).map { _ in
176176
switch config.schedulerType {
177177
case .pndmScheduler: return PNDMScheduler(stepCount: config.stepCount)
178-
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount)
178+
case .dpmSolverMultistepScheduler: return DPMSolverMultistepScheduler(stepCount: config.stepCount, timeStepSpacing: config.schedulerTimestepSpacing)
179179
}
180180
}
181181

0 commit comments

Comments
 (0)