Support for Chroma - Flux1 Schnell distilled with CFG #7355

Merged · 46 commits · May 1, 2025

Commits (46)
0a4bc66
Upload files for Chroma Implementation
silveroxides Mar 22, 2025
79f4601
Remove trailing whitespace
silveroxides Mar 22, 2025
2710f77
trim more trailing whitespace..oops
silveroxides Mar 22, 2025
ca73500
remove unused imports
silveroxides Mar 22, 2025
fda35f3
Add supported_inference_dtypes
silveroxides Mar 22, 2025
9fa34e7
Set min_length to 0 and remove attention_mask=True
silveroxides Mar 23, 2025
4fa663d
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Mar 23, 2025
fe6e1fa
Set min_length to 1
silveroxides Mar 25, 2025
9f70cfb
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Mar 25, 2025
f04b502
get_mdulations added from blepping and minor changes
silveroxides Mar 25, 2025
159df22
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Mar 25, 2025
b7da8e2
Add lora conversion if statement in lora.py
silveroxides Mar 27, 2025
bf339e8
Update supported_models.py
silveroxides Mar 27, 2025
2378c63
update model_base.py
silveroxides Mar 27, 2025
73f10ea
add uptream commits
silveroxides Mar 27, 2025
a737f2f
Merge branch 'master' into chroma-support
silveroxides Mar 27, 2025
cb3e388
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 1, 2025
7012015
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 3, 2025
4c33a57
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 4, 2025
1ca0353
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 5, 2025
de3f3ea
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 6, 2025
cb6ece9
set modelType.FLOW, will cause beta scheduler to work properly
silveroxides Apr 6, 2025
c4f6874
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 7, 2025
e1af413
Adjust memory usage factor and remove unnecessary code
silveroxides Apr 8, 2025
3d375a1
fix mistake
silveroxides Apr 8, 2025
caa52b9
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 8, 2025
1c76711
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 10, 2025
fc978a7
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 11, 2025
2c24819
reduce code duplication
silveroxides Apr 11, 2025
cd3d2d5
remove unused imports
silveroxides Apr 11, 2025
d781e17
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 12, 2025
7b39ea6
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 13, 2025
69a6147
merge chroma-support with upstream
silveroxides Apr 16, 2025
c791e95
refactor for upstream sync
silveroxides Apr 16, 2025
acaeca4
Merge pull request #3 from silveroxides/intermediary
silveroxides Apr 16, 2025
7e98854
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 17, 2025
a9f208d
sync chroma-support with upstream via syncbranch patch
silveroxides Apr 18, 2025
5f7be81
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 18, 2025
ea9c851
Merge branch 'master' into chroma-support
silveroxides Apr 20, 2025
6c3e841
Update sd.py
silveroxides Apr 20, 2025
9038fe3
Merge branch 'chroma-support' into syncbranch-04-23
silveroxides Apr 23, 2025
38858a6
Merge pull request #4 from silveroxides/syncbranch-04-23
silveroxides Apr 23, 2025
48f6b97
Add Chroma as option for the OptimalStepsScheduler node
silveroxides Apr 23, 2025
79acf8f
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 26, 2025
6931228
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 29, 2025
e7d4935
Merge branch 'comfyanonymous:master' into chroma-support
silveroxides Apr 30, 2025
183 changes: 183 additions & 0 deletions comfy/ldm/chroma/layers.py
@@ -0,0 +1,183 @@
import torch
from torch import Tensor, nn

from .math import attention
from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    QKNorm,
    SelfAttention,
    ModulationOut,
)


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )


class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)

        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))

        x = self.out_proj(x)

        return x


class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1] :]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt


class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        mod = vec
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x


class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
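
Unlike the Flux versions of these blocks, the Chroma blocks above do not derive shift/scale/gate from a conditioning vector inside the block; they receive them precomputed through the vec argument, and ChromaModulationOut.from_offset simply slices three consecutive rows out of a modulation tensor. A minimal sketch of that slicing, with tensor sizes invented purely for illustration (they are not taken from this PR):

import torch
from comfy.ldm.chroma.layers import ChromaModulationOut  # module path as added by this PR

# hypothetical modulation tensor: (batch, modulation_rows, hidden_size)
mod_vectors = torch.randn(2, 6, 3072)

mod = ChromaModulationOut.from_offset(mod_vectors, offset=0)
# each field keeps a singleton sequence dim so it broadcasts over the token axis
print(mod.shift.shape, mod.scale.shape, mod.gate.shape)  # each torch.Size([2, 1, 3072])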
44 changes: 44 additions & 0 deletions comfy/ldm/chroma/math.py
@@ -0,0 +1,44 @@
import torch
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
Review comment (Contributor) on lines +13 to +14: Any reason to cast these + the one in apply_rope below to float? This part could be reused from comfy/ldm/flux/math.py if not, since it seems to be the same otherwise.

    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim//2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
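
On the review question above: the apparent difference from comfy/ldm/flux/math.py is the explicit .float() casts, which run the rotary rotation in float32 before .type_as(...) narrows the result back to the activation dtype. If that extra precision proves unnecessary, the functions could simply be reused from the Flux module, as the reviewer suggests. A small round-trip sketch of the functions defined above (toy shapes chosen only for illustration; the commented re-export assumes flux/math.py keeps its current names):

import torch
from comfy.ldm.chroma.math import rope, apply_rope  # module path as added by this PR

q = torch.randn(1, 2, 4, 8, dtype=torch.float16)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 2, 4, 8, dtype=torch.float16)
pe = rope(torch.arange(4)[None, :], dim=8, theta=10000)  # (1, 4, 4, 2, 2) rotation table
q_rot, k_rot = apply_rope(q, k, pe)  # rotation computed in float32, returned as float16

# if the float32 cast turns out to be unnecessary, this file could reduce to re-exports:
# from comfy.ldm.flux.math import attention, rope, apply_rope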
