Skip to content

Commit f08b725

Browse files
committed
theoretically add diffusion spacing
1 parent f66c75a commit f08b725

File tree

4 files changed

+110
-16
lines changed

4 files changed

+110
-16
lines changed

diffusion/diffusion.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
from collections import namedtuple
2-
from types import SimpleNamespace
32
from abc import ABC, abstractmethod
43

54
import torch as th
@@ -81,13 +80,21 @@ def get_eps_and_var(model_output, *, C):
8180

8281

8382
class GaussianDiffusion(ABC):
84-
def __init__(self, betas):
83+
def __init__(self, betas, timestep_map=None):
84+
# TODO: Get rid of this "check"?
85+
# It's never once caught a bug ...
86+
def check(x):
87+
assert x.shape == (self.n_timesteps,)
88+
8589
self.n_timesteps = betas.shape[0]
90+
self.betas = betas
91+
self.timestep_map = timestep_map if timestep_map else range(self.n_timesteps)
92+
assert len(self.timestep_map) == self.n_timesteps
93+
8694
alphas = 1 - betas
8795
alphas_cumprod = th.cumprod(alphas, dim=0)
88-
89-
def check(x):
90-
assert x.shape == (self.n_timesteps,)
96+
self.alphas_cumprod = alphas_cumprod
97+
check(self.alphas_cumprod)
9198

9299
# TODO(verify): By prepending 1, the 1st beta is 0
93100
# This represents the initial image, which as a mean but no variance (since it's ground truth)
@@ -231,7 +238,6 @@ def p_sample_loop_progressive(self, *, model, noise, shape, threshold, device):
231238
assert xor(
232239
noise, shape
233240
), f"Either noise or shape must be specified, but not both or neither"
234-
indices = list(range(self.n_timesteps))[::-1]
235241

236242
img = N = None
237243
if noise:
@@ -240,8 +246,8 @@ def p_sample_loop_progressive(self, *, model, noise, shape, threshold, device):
240246
img = th.randn(shape, device=device)
241247
N = img.shape[0]
242248

243-
for i in indices:
244-
t = th.tensor([i] * N, device=device)
249+
for _t in self.timestep_map[::-1]:
250+
t = th.tensor([_t] * N, device=device)
245251
with th.no_grad():
246252
img = self.p_sample(model=model, x_t=img, t=t, threshold=threshold)
247253
yield img

diffusion/spaced.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,78 @@
1+
from .diffusion import LearnedVarianceGaussianDiffusion
2+
3+
### Start OpenAI Code
4+
def space_timesteps(num_timesteps, section_counts):
5+
"""
6+
Create a list of timesteps to use from an original diffusion process,
7+
given the number of timesteps we want to take from equally-sized portions
8+
of the original process.
9+
For example, if there's 300 timesteps and the section counts are [10,15,20]
10+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
11+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
12+
If the stride is a string starting with "ddim", then the fixed striding
13+
from the DDIM paper is used, and only one section is allowed.
14+
:param num_timesteps: the number of diffusion steps in the original
15+
process to divide up.
16+
:param section_counts: either a list of numbers, or a string containing
17+
comma-separated numbers, indicating the step count
18+
per section. As a special case, use "ddimN" where N
19+
is a number of steps to use the striding from the
20+
DDIM paper.
21+
:return: a set of diffusion steps from the original process to use.
22+
"""
23+
if isinstance(section_counts, str):
24+
if section_counts.startswith("ddim"):
25+
desired_count = int(section_counts[len("ddim") :])
26+
for i in range(1, num_timesteps):
27+
if len(range(0, num_timesteps, i)) == desired_count:
28+
return set(range(0, num_timesteps, i))
29+
raise ValueError(
30+
f"cannot create exactly {num_timesteps} steps with an integer stride"
31+
)
32+
section_counts = [int(x) for x in section_counts.split(",")]
33+
size_per = num_timesteps // len(section_counts)
34+
extra = num_timesteps % len(section_counts)
35+
start_idx = 0
36+
all_steps = []
37+
for i, section_count in enumerate(section_counts):
38+
size = size_per + (1 if i < extra else 0)
39+
if size < section_count:
40+
raise ValueError(
41+
f"cannot divide section of {size} steps into {section_count}"
42+
)
43+
if section_count <= 1:
44+
frac_stride = 1
45+
else:
46+
frac_stride = (size - 1) / (section_count - 1)
47+
cur_idx = 0.0
48+
taken_steps = []
49+
for _ in range(section_count):
50+
taken_steps.append(start_idx + round(cur_idx))
51+
cur_idx += frac_stride
52+
all_steps += taken_steps
53+
start_idx += size
54+
return set(all_steps)
55+
56+
57+
### End OpenAI Code
58+
59+
60+
def create_map_and_betas(betas, use_timesteps):
61+
use_timesteps = set(use_timesteps)
62+
63+
# Doesn't matter what diffusion we use since the constructor
64+
# is defined in the base class
65+
base_diffusion = LearnedVarianceGaussianDiffusion(betas)
66+
as_t_m_1 = 1
67+
68+
map_generation_step_to_timestep = []
69+
70+
new_betas = []
71+
for i, as_t in enumerate(base_diffusion.alphas_cumprod):
72+
if i in use_timesteps:
73+
new_betas.append(1 - (as_t / as_t_m_1))
74+
as_t_m_1 = as_t
75+
map_generation_step_to_timestep.append(i)
76+
77+
assert len(new_betas) == len(map_generation_step_to_timestep)
78+
return map_generation_step_to_timestep, new_betas

iddpm.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ class DiffusionParams(hp.Hparams):
4949
schedule: str = hp.required("diffusion schedule")
5050
learn_sigma: bool = hp.required("whether to learn sigma")
5151

52-
def initialize_object(self):
52+
def initialize_object(self, diffusion_kwargs):
5353
assert self.schedule == "cosine", "Only cosine schedule is supported"
54-
betas = cosine_betas(self.steps)
54+
if not diffusion_kwargs:
55+
diffusion_kwargs = {"betas": cosine_betas(self.steps)}
5556
return (
56-
LearnedVarianceGaussianDiffusion(betas)
57+
LearnedVarianceGaussianDiffusion(**diffusion_kwargs)
5758
if self.learn_sigma
58-
else FixedSmallVarianceGaussianDiffusion(betas)
59+
else FixedSmallVarianceGaussianDiffusion(**diffusion_kwargs)
5960
)
6061

6162

@@ -64,10 +65,10 @@ class IDDPMConfig(hp.Hparams):
6465
unet: UNetParams = hp.required("the UNet model")
6566
diffusion: DiffusionParams = hp.required("Gaussian diffusion parameters")
6667

67-
def initialize_object(self):
68+
def initialize_object(self, diffusion_kwargs=None):
6869
unet, diffusion = (
6970
self.unet.initialize_object(),
70-
self.diffusion.initialize_object(),
71+
self.diffusion.initialize_object(diffusion_kwargs),
7172
)
7273
return IDDPM(unet, diffusion)
7374

sample.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44
import torchvision
55
import torch as th
66
import typer
7+
from diffusion.spaced import create_map_and_betas, space_timesteps
78

8-
from iddpm import IDDPMConfig, IDDPM
9+
from iddpm import IDDPMConfig
910

1011

1112
def img_to_bytes(img):
@@ -18,11 +19,19 @@ def run(
1819
out_dir: Path = typer.Option(...),
1920
checkpoint: Path = typer.Option(...),
2021
samples: int = typer.Option(...),
22+
spacing: str = typer.Option(default=None),
2123
):
2224
assert checkpoint.is_file(), f"Checkpoint file not found: {checkpoint}"
2325

2426
config = IDDPMConfig.create(config, None, cli_args=False)
25-
iddpm = config.initialize_object()
27+
28+
spacing = [1] if spacing is None else [int(x) for x in spacing.split(",")]
29+
spacing = space_timesteps(iddpm.diffusion.n_timesteps, spacing)
30+
timestep_map, betas = create_map_and_betas(iddpm.diffusion.betas, spacing)
31+
32+
iddpm = config.initialize_object(
33+
diffusion=dict(timestep_map=timestep_map, betas=betas)
34+
)
2635

2736
out_dir.mkdir(parents=True)
2837

0 commit comments

Comments
 (0)