3 changes: 3 additions & 0 deletions .gitignore
@@ -173,3 +173,6 @@ runs/
exps/
wandb
wandb/

# generated calib data
datasets/
3 changes: 3 additions & 0 deletions deepcompressor/app/diffusion/dataset/base.py
@@ -39,6 +39,9 @@ def __len__(self) -> int:
return len(self.filepaths)

def __getitem__(self, idx) -> dict[str, tp.Any]:

# TODO: verify ZImage data loading.

data = np.load(self.filepaths[idx], allow_pickle=True).item()
if isinstance(data["input_args"][0], str):
name = data["input_args"][0]
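For context, a hedged sketch (not part of this diff) of the per-sample .npy payload that __getitem__ reads; the keys mirror the collector in collect/utils.py below, while the tensor shapes are placeholders and not verified against the real ZImage model:

import numpy as np
import torch

# Hypothetical ZImage calibration sample; in practice these files land under datasets/
# (see the .gitignore change above).
sample = {
    "input_args": [
        torch.randn(1, 16, 64, 64),     # x (latent), placeholder shape
        torch.tensor([0.5]),            # t (timestep)
        torch.randn(1, 77, 2304),       # cap_feats (caption features), placeholder shape
    ],
    "input_kwargs": {},                 # any remaining keyword inputs
    "outputs": torch.randn(1, 16, 64, 64),
}
np.save("zimage_sample.npy", sample, allow_pickle=True)
data = np.load("zimage_sample.npy", allow_pickle=True).item()
assert not isinstance(data["input_args"][0], str)  # ZImage stores tensors, not a name string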
13 changes: 13 additions & 0 deletions deepcompressor/app/diffusion/dataset/calib.py
@@ -15,6 +15,7 @@
FluxSingleTransformerBlock,
FluxTransformerBlock,
)
from diffusers.models.transformers.transformer_z_image import ZImageTransformerBlock
from omniconfig import configclass

from deepcompressor.data.cache import (
@@ -172,6 +173,18 @@ def _init_cache(self, name: str, module: nn.Module) -> IOTensorsCache:
),
outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
)
elif isinstance(module, ZImageTransformerBlock):
return IOTensorsCache(
inputs=TensorsCache(
OrderedDict(
x=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
attn_mask=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
freqs_cis=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
# TODO: verify these keys match ZImageTransformerBlock's forward inputs.
)
),
outputs=TensorCache(channels_dim=-1, reshape=LinearReshapeFn()),
)
elif isinstance(module, Attention):
return IOTensorsCache(
inputs=TensorsCache(
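For context, a hedged illustration (not from this diff) of what channels_dim=-1 with a linear-style reshape amounts to for these block inputs; the interpretation of LinearReshapeFn here is an assumption:

import torch

# With channels on the last dimension, a (batch, tokens, channels) activation is in effect
# treated as batch * tokens channel vectors when collecting per-channel statistics.
x = torch.randn(2, 1024, 3072)        # placeholder (batch, tokens, channels)
flat = x.reshape(-1, x.shape[-1])
print(flat.shape)                     # torch.Size([2048, 3072])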
12 changes: 11 additions & 1 deletion deepcompressor/app/diffusion/dataset/collect/utils.py
@@ -10,6 +10,7 @@
FluxTransformer2DModel,
PixArtTransformer2DModel,
SanaTransformer2DModel,
ZImageTransformer2DModel,
)
from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel

@@ -58,10 +59,19 @@ def __call__(
new_args.append(input_kwargs.pop("hidden_states"))
elif isinstance(module, FluxTransformer2DModel):
new_args.append(input_kwargs.pop("hidden_states"))
elif isinstance(module, ZImageTransformer2DModel):
new_args.append(input_kwargs.pop("x"))
new_args.append(input_kwargs.pop("t"))
new_args.append(input_kwargs.pop("cap_feats"))
else:
raise ValueError(f"Unknown model: {module}")
cache = tree_map(lambda x: x.cpu(), {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": output})
split_cache = tree_split(cache)

if isinstance(module, ZImageTransformer2DModel):
# Assume batch size is 1; keep the cache as a single sample instead of splitting it.
split_cache = [cache]
else:
split_cache = tree_split(cache)

if isinstance(module, PixArtTransformer2DModel) and self.zero_redundancy:
for cache in split_cache:
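A hedged sketch (not part of this diff) of what the ZImage branch above amounts to: the three keyword inputs become positional args, and the cache is kept whole rather than split along the batch dimension; shapes are placeholders:

import torch

input_kwargs = {
    "x": torch.randn(1, 16, 64, 64),       # placeholder shapes, not real ZImage sizes
    "t": torch.tensor([0.5]),
    "cap_feats": torch.randn(1, 77, 2304),
}
new_args = [input_kwargs.pop(key) for key in ("x", "t", "cap_feats")]
cache = {"input_args": new_args, "input_kwargs": input_kwargs, "outputs": None}
split_cache = [cache]   # batch size assumed to be 1, so the cache is not split per sample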
10 changes: 10 additions & 0 deletions deepcompressor/app/diffusion/nn/patch.py
@@ -1,8 +1,10 @@
import torch.nn as nn
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock
from diffusers.models.transformers.transformer_z_image import ZImageTransformer2DModel, ZImageTransformerBlock

from deepcompressor.nn.patch.conv import ConcatConv2d, ShiftedConv2d
from deepcompressor.nn.patch.ff import convert_z_image_ff
from deepcompressor.nn.patch.linear import ConcatLinear, ShiftedLinear
from deepcompressor.utils import patch, tools

@@ -116,3 +118,11 @@ def replace_attn_processor(model: nn.Module) -> None:
logger.info(f"+ Replacing {name} processor with DiffusionAttentionProcessor.")
module.set_processor(DiffusionAttentionProcessor(module.processor))
tools.logging.Formatter.indent_dec()


def replace_zimage_feedforward(z_image_model: ZImageTransformer2DModel) -> None:
"""Replace custom FeedForward module in `ZImageTransformerBlock`s with standard FeedForward in diffusers lib."""
for _, module in z_image_model.named_modules():
if isinstance(module, ZImageTransformerBlock):
orig_ff = module.feed_forward
module.feed_forward = convert_z_image_ff(orig_ff)
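
A possible call site (hedged sketch, not taken from this PR), assuming the ZImage transformer has already been loaded via diffusers:

def prepare_zimage_transformer(transformer: ZImageTransformer2DModel) -> None:
    # Hypothetical helper: patch the model before calibration. The call order is an
    # assumption; replace_attn_processor is the existing patch defined above.
    replace_zimage_feedforward(transformer)
    replace_attn_processor(transformer)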