
Commit 55386e0

Authored by guopengf, with pre-commit-ci[bot], alkamid, and KumoLiu
Adding Tailored ControlNet Implementations into Generative Model Application (#7875)
Fixes #7874.

### Description

Integrates a tailored ControlNet model into the generative model application to enable training with high-dimensional 3D images (up to 512 x 512 x 768).

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Pengfei Guo <[email protected]>
Signed-off-by: alkamid <[email protected]>
Signed-off-by: YunLiu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Adam Klimont <[email protected]>
Co-authored-by: YunLiu <[email protected]>
1 parent 2f62b81 commit 55386e0

File tree

8 files changed: +380 −2 lines

.pre-commit-config.yaml (+1 −1)
```diff
@@ -9,7 +9,7 @@ ci:
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v4.4.0
+    rev: v4.6.0
     hooks:
       - id: end-of-file-fixer
       - id: trailing-whitespace
```

monai/apps/generation/__init__.py (+10)
```diff
@@ -0,0 +1,10 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
```
(+10) — a second new file, presumably an `__init__.py` of the new `maisi` subpackage, containing the same ten-line Apache license header as above.
(+10) — a third new file, likewise containing only the same license header.
(+178) — the new module defining the `ControlNetMaisi` network:

```python
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Sequence, cast

import torch

from monai.utils import optional_import

ControlNet, has_controlnet = optional_import("generative.networks.nets.controlnet", name="ControlNet")
get_timestep_embedding, has_get_timestep_embedding = optional_import(
    "generative.networks.nets.diffusion_model_unet", name="get_timestep_embedding"
)

if TYPE_CHECKING:
    from generative.networks.nets.controlnet import ControlNet as ControlNetType
else:
    ControlNetType = cast(type, ControlNet)


class ControlNetMaisi(ControlNetType):
    """
    Control network for diffusion models based on Zhang and Agrawala "Adding Conditional Control to Text-to-Image
    Diffusion Models" (https://arxiv.org/abs/2302.05543)

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: number of input channels.
        num_res_blocks: number of residual blocks (see ResnetBlock) per level.
        num_channels: tuple of block output channels.
        attention_levels: list of levels to add attention.
        norm_num_groups: number of groups for the normalization.
        norm_eps: epsilon for the normalization.
        resblock_updown: if True use residual blocks for up/downsampling.
        num_head_channels: number of channels in each attention head.
        with_conditioning: if True add spatial transformers to perform conditioning.
        transformer_num_layers: number of layers of Transformer blocks to use.
        cross_attention_dim: number of context dimensions to use.
        num_class_embeds: if specified (as an int), then this model will be class-conditional with `num_class_embeds`
            classes.
        upcast_attention: if True, upcast attention operations to full precision.
        use_flash_attention: if True, use flash attention for a memory efficient attention mechanism.
        conditioning_embedding_in_channels: number of input channels for the conditioning embedding.
        conditioning_embedding_num_channels: number of channels for the blocks in the conditioning embedding.
        use_checkpointing: if True, use activation checkpointing to save memory.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
        num_channels: Sequence[int] = (32, 64, 64, 64),
        attention_levels: Sequence[bool] = (False, False, True, True),
        norm_num_groups: int = 32,
        norm_eps: float = 1e-6,
        resblock_updown: bool = False,
        num_head_channels: int | Sequence[int] = 8,
        with_conditioning: bool = False,
        transformer_num_layers: int = 1,
        cross_attention_dim: int | None = None,
        num_class_embeds: int | None = None,
        upcast_attention: bool = False,
        use_flash_attention: bool = False,
        conditioning_embedding_in_channels: int = 1,
        conditioning_embedding_num_channels: Sequence[int] | None = (16, 32, 96, 256),
        use_checkpointing: bool = True,
    ) -> None:
        super().__init__(
            spatial_dims,
            in_channels,
            num_res_blocks,
            num_channels,
            attention_levels,
            norm_num_groups,
            norm_eps,
            resblock_updown,
            num_head_channels,
            with_conditioning,
            transformer_num_layers,
            cross_attention_dim,
            num_class_embeds,
            upcast_attention,
            use_flash_attention,
            conditioning_embedding_in_channels,
            conditioning_embedding_num_channels,
        )
        self.use_checkpointing = use_checkpointing

    def forward(
        self,
        x: torch.Tensor,
        timesteps: torch.Tensor,
        controlnet_cond: torch.Tensor,
        conditioning_scale: float = 1.0,
        context: torch.Tensor | None = None,
        class_labels: torch.Tensor | None = None,
    ) -> tuple[Sequence[torch.Tensor], torch.Tensor]:
        emb = self._prepare_time_and_class_embedding(x, timesteps, class_labels)
        h = self._apply_initial_convolution(x)
        if self.use_checkpointing:
            controlnet_cond = torch.utils.checkpoint.checkpoint(
                self.controlnet_cond_embedding, controlnet_cond, use_reentrant=False
            )
        else:
            controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
        h += controlnet_cond
        down_block_res_samples, h = self._apply_down_blocks(emb, context, h)
        h = self._apply_mid_block(emb, context, h)
        down_block_res_samples, mid_block_res_sample = self._apply_controlnet_blocks(h, down_block_res_samples)
        # scaling
        down_block_res_samples = [h * conditioning_scale for h in down_block_res_samples]
        mid_block_res_sample *= conditioning_scale

        return down_block_res_samples, mid_block_res_sample

    def _prepare_time_and_class_embedding(self, x, timesteps, class_labels):
        # 1. time
        t_emb = get_timestep_embedding(timesteps, self.block_out_channels[0])

        # timesteps does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=x.dtype)
        emb = self.time_embed(t_emb)

        # 2. class
        if self.num_class_embeds is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")
            class_emb = self.class_embedding(class_labels)
            class_emb = class_emb.to(dtype=x.dtype)
            emb = emb + class_emb

        return emb

    def _apply_initial_convolution(self, x):
        # 3. initial convolution
        h = self.conv_in(x)
        return h

    def _apply_down_blocks(self, emb, context, h):
        # 4. down
        if context is not None and self.with_conditioning is False:
            raise ValueError("model should have with_conditioning = True if context is provided")
        down_block_res_samples: list[torch.Tensor] = [h]
        for downsample_block in self.down_blocks:
            h, res_samples = downsample_block(hidden_states=h, temb=emb, context=context)
            for residual in res_samples:
                down_block_res_samples.append(residual)

        return down_block_res_samples, h

    def _apply_mid_block(self, emb, context, h):
        # 5. mid
        h = self.middle_block(hidden_states=h, temb=emb, context=context)
        return h

    def _apply_controlnet_blocks(self, h, down_block_res_samples):
        # 6. Control net blocks
        controlnet_down_block_res_samples = []
        for down_block_res_sample, controlnet_block in zip(down_block_res_samples, self.controlnet_down_blocks):
            down_block_res_sample = controlnet_block(down_block_res_sample)
            controlnet_down_block_res_samples.append(down_block_res_sample)

        mid_block_res_sample = self.controlnet_mid_block(h)

        return controlnet_down_block_res_samples, mid_block_res_sample
```
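The `use_checkpointing` path in `forward` wraps the conditioning embedding in `torch.utils.checkpoint.checkpoint`, which frees intermediate activations after the forward pass and recomputes them during backward; this is the memory/compute trade that makes training on very large 3D volumes feasible. A standalone sketch of the mechanism on a toy block (not the MAISI network):

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy 3D block standing in for the conditioning embedding.
block = nn.Sequential(
    nn.Conv3d(4, 8, kernel_size=3, padding=1),
    nn.GELU(),
    nn.Conv3d(8, 4, kernel_size=3, padding=1),
)
x = torch.randn(1, 4, 16, 16, 16, requires_grad=True)

# Activations inside `block` are dropped after forward and recomputed in
# backward: lower peak memory at roughly one extra forward's compute.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # gradients match the non-checkpointed computation
```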

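For orientation, a minimal usage sketch with hedged assumptions: the import path is inferred and may differ, `monai-generative` must be installed, and shapes are kept deliberately small rather than the 512 x 512 x 768 volumes the PR targets. The conditioning image is sized at 8x the latent grid on the assumption that the default conditioning embedding applies three stride-2 convolutions.

```python
import torch

# Assumed import path for the new module; adjust if the package layout differs.
from monai.apps.generation.maisi.networks.controlnet_maisi import ControlNetMaisi

# Small 3D configuration for illustration; the class defaults target larger models.
net = ControlNetMaisi(
    spatial_dims=3,
    in_channels=4,
    num_channels=(8, 16, 16, 16),
    norm_num_groups=8,
    use_checkpointing=True,  # trade recompute for lower peak memory
)

x = torch.randn(1, 4, 16, 16, 16)         # latent-space input
timesteps = torch.randint(0, 1000, (1,))  # one diffusion timestep per batch item
cond = torch.randn(1, 1, 128, 128, 128)   # conditioning image, 8x the latent grid

down_res, mid_res = net(x, timesteps, cond, conditioning_scale=1.0)
print(len(down_res), mid_res.shape)
```

The returned down-block and mid-block residuals are each multiplied by `conditioning_scale`, so a single scalar dials the control signal's influence on the downstream diffusion UNet up or down.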
monai/data/torchscript_utils.py (+1 −1)
```diff
@@ -116,7 +116,7 @@ def load_net_with_metadata(
     Returns:
         Triple containing loaded object, metadata dict, and extra files dict containing other file data if present
     """
-    extra_files = {f: "" for f in more_extra_files}
+    extra_files = dict.fromkeys(more_extra_files, "")
     extra_files[METADATA_FILENAME] = ""

     jit_obj = torch.jit.load(filename_prefix_or_stream, map_location, extra_files)
```
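The `dict.fromkeys` change is behavior-preserving; a quick check with hypothetical file names:

```python
# Hypothetical extra-file names; any iterable of strings behaves the same.
more_extra_files = ["metadata.json", "env.yaml"]

# The comprehension and dict.fromkeys build identical mappings of name -> "".
assert {f: "" for f in more_extra_files} == dict.fromkeys(more_extra_files, "")
```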

requirements-dev.txt (+1)
```diff
@@ -58,3 +58,4 @@ lpips==0.1.4
 nvidia-ml-py
 huggingface_hub
 pyamg>=5.0.0
+git+https://github.com/KumoLiu/GenerativeModels.git@cuda#egg=monai-generative
```
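Since `monai-generative` is only a development dependency, the new module guards its imports with `monai.utils.optional_import` (as shown in the diff above) rather than importing `generative` directly, so MONAI still imports cleanly when the package is absent. A minimal sketch of that pattern:

```python
from monai.utils import optional_import

# Returns the attribute and a success flag; if the generative package is
# missing, the first element is a lazy placeholder that raises on first use.
ControlNet, has_controlnet = optional_import(
    "generative.networks.nets.controlnet", name="ControlNet"
)

if not has_controlnet:
    print("monai-generative not installed; ControlNet-based features are unavailable.")
```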
