Support for Chroma - Flux1 Schnell distilled with CFG #7355
Merged: comfyanonymous merged 46 commits into comfyanonymous:master from silveroxides:chroma-support on May 1, 2025 (+667 −4).
Commits (46, all authored by silveroxides):
0a4bc66  Upload files for Chroma Implementation
79f4601  Remove trailing whitespace
2710f77  trim more trailing whitespace..oops
ca73500  remove unused imports
fda35f3  Add supported_inference_dtypes
9fa34e7  Set min_length to 0 and remove attention_mask=True
4fa663d  Merge branch 'comfyanonymous:master' into chroma-support
fe6e1fa  Set min_length to 1
9f70cfb  Merge branch 'comfyanonymous:master' into chroma-support
f04b502  get_mdulations added from blepping and minor changes
159df22  Merge branch 'comfyanonymous:master' into chroma-support
b7da8e2  Add lora conversion if statement in lora.py
bf339e8  Update supported_models.py
2378c63  update model_base.py
73f10ea  add uptream commits
a737f2f  Merge branch 'master' into chroma-support
cb3e388  Merge branch 'comfyanonymous:master' into chroma-support
7012015  Merge branch 'comfyanonymous:master' into chroma-support
4c33a57  Merge branch 'comfyanonymous:master' into chroma-support
1ca0353  Merge branch 'comfyanonymous:master' into chroma-support
de3f3ea  Merge branch 'comfyanonymous:master' into chroma-support
cb6ece9  set modelType.FLOW, will cause beta scheduler to work properly
c4f6874  Merge branch 'comfyanonymous:master' into chroma-support
e1af413  Adjust memory usage factor and remove unnecessary code
3d375a1  fix mistake
caa52b9  Merge branch 'comfyanonymous:master' into chroma-support
1c76711  Merge branch 'comfyanonymous:master' into chroma-support
fc978a7  Merge branch 'comfyanonymous:master' into chroma-support
2c24819  reduce code duplication
cd3d2d5  remove unused imports
d781e17  Merge branch 'comfyanonymous:master' into chroma-support
7b39ea6  Merge branch 'comfyanonymous:master' into chroma-support
69a6147  merge chroma-support with upstream
c791e95  refactor for upstream sync
acaeca4  Merge pull request #3 from silveroxides/intermediary
7e98854  Merge branch 'comfyanonymous:master' into chroma-support
a9f208d  sync chroma-support with upstream via syncbranch patch
5f7be81  Merge branch 'comfyanonymous:master' into chroma-support
ea9c851  Merge branch 'master' into chroma-support
6c3e841  Update sd.py
9038fe3  Merge branch 'chroma-support' into syncbranch-04-23
38858a6  Merge pull request #4 from silveroxides/syncbranch-04-23
48f6b97  Add Chroma as option for the OptimalStepsScheduler node
79acf8f  Merge branch 'comfyanonymous:master' into chroma-support
6931228  Merge branch 'comfyanonymous:master' into chroma-support
e7d4935  Merge branch 'comfyanonymous:master' into chroma-support
@@ -0,0 +1,183 @@
import torch
from torch import Tensor, nn

from .math import attention
from comfy.ldm.flux.layers import (
    MLPEmbedder,
    RMSNorm,
    QKNorm,
    SelfAttention,
    ModulationOut,
)


class ChromaModulationOut(ModulationOut):
    @classmethod
    def from_offset(cls, tensor: torch.Tensor, offset: int = 0) -> ModulationOut:
        return cls(
            shift=tensor[:, offset : offset + 1, :],
            scale=tensor[:, offset + 1 : offset + 2, :],
            gate=tensor[:, offset + 2 : offset + 3, :],
        )
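ChromaModulationOut.from_offset reads three consecutive rows of a stacked modulation tensor as a shift/scale/gate triple, keeping each row as (batch, 1, hidden) so it broadcasts over the token dimension. A minimal sketch of the slicing, with made-up batch, row-count, and hidden sizes (the real values come from the model config, which is not part of this diff):

import torch

# Hypothetical sizes for illustration only.
batch, n_mod_rows, hidden = 1, 344, 3072
mod_vectors = torch.randn(batch, n_mod_rows, hidden)

# Rows 6, 7 and 8 become shift, scale and gate, each kept as (batch, 1, hidden).
mod = ChromaModulationOut.from_offset(mod_vectors, offset=6)
print(mod.shift.shape, mod.scale.shape, mod.gate.shape)  # torch.Size([1, 1, 3072]) each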
class Approximator(nn.Module):
    def __init__(self, in_dim: int, out_dim: int, hidden_dim: int, n_layers=5, dtype=None, device=None, operations=None):
        super().__init__()
        self.in_proj = operations.Linear(in_dim, hidden_dim, bias=True, dtype=dtype, device=device)
        self.layers = nn.ModuleList([MLPEmbedder(hidden_dim, hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.norms = nn.ModuleList([RMSNorm(hidden_dim, dtype=dtype, device=device, operations=operations) for _ in range(n_layers)])
        self.out_proj = operations.Linear(hidden_dim, out_dim, dtype=dtype, device=device)

    @property
    def device(self):
        # Get the device of the module (assumes all parameters are on the same device)
        return next(self.parameters()).device

    def forward(self, x: Tensor) -> Tensor:
        x = self.in_proj(x)

        for layer, norm in zip(self.layers, self.norms):
            x = x + layer(norm(x))

        x = self.out_proj(x)

        return x
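Where Flux derives shift/scale/gate inside each block from the conditioning vector, Chroma produces all modulation rows up front with this single Approximator MLP; they are then, presumably in the model file not shown in this excerpt, sliced per block via ChromaModulationOut.from_offset. A standalone shape sketch, assuming plain torch.nn can stand in for ComfyUI's operations wrapper here; every dimension is made up and the weights are uninitialized, so only the shapes are meaningful:

import torch
from torch import nn

# torch.nn stands in for ComfyUI's `operations` wrapper in this sketch.
approx = Approximator(in_dim=64, out_dim=3072, hidden_dim=5120, n_layers=5, operations=nn)

timestep_features = torch.randn(1, 344, 64)   # one input row per modulation vector (hypothetical count)
mod_vectors = approx(timestep_features)
print(mod_vectors.shape)                      # torch.Size([1, 344, 3072])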
class DoubleStreamBlock(nn.Module):
    def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False, flipped_img_txt=False, dtype=None, device=None, operations=None):
        super().__init__()

        mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.img_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.img_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.img_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )

        self.txt_norm1 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_attn = SelfAttention(dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias, dtype=dtype, device=device, operations=operations)

        self.txt_norm2 = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.txt_mlp = nn.Sequential(
            operations.Linear(hidden_size, mlp_hidden_dim, bias=True, dtype=dtype, device=device),
            nn.GELU(approximate="tanh"),
            operations.Linear(mlp_hidden_dim, hidden_size, bias=True, dtype=dtype, device=device),
        )
        self.flipped_img_txt = flipped_img_txt

    def forward(self, img: Tensor, txt: Tensor, pe: Tensor, vec: Tensor, attn_mask=None):
        (img_mod1, img_mod2), (txt_mod1, txt_mod2) = vec

        # prepare image for attention
        img_modulated = self.img_norm1(img)
        img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
        img_qkv = self.img_attn.qkv(img_modulated)
        img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

        # prepare txt for attention
        txt_modulated = self.txt_norm1(txt)
        txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
        txt_qkv = self.txt_attn.qkv(txt_modulated)
        txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)

        # run actual attention
        attn = attention(torch.cat((txt_q, img_q), dim=2),
                         torch.cat((txt_k, img_k), dim=2),
                         torch.cat((txt_v, img_v), dim=2),
                         pe=pe, mask=attn_mask)

        txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

        # calculate the img blocks
        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
        img = img + img_mod2.gate * self.img_mlp((1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift)

        # calculate the txt blocks
        txt += txt_mod1.gate * self.txt_attn.proj(txt_attn)
        txt += txt_mod2.gate * self.txt_mlp((1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift)

        if txt.dtype == torch.float16:
            txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)

        return img, txt
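Unlike the Flux DoubleStreamBlock, which computes its modulation from the conditioning vector inside the block, this forward expects vec to already be a nested tuple ((img_mod1, img_mod2), (txt_mod1, txt_mod2)) of ModulationOut triples. A sketch of assembling such a tuple from a stacked modulation tensor, with made-up offsets (the real per-block offsets are assigned in the model file, which is not part of this excerpt):

import torch

# Hypothetical stacked modulation tensor and offsets, for illustration only.
mod_vectors = torch.randn(1, 344, 3072)
double_block_vec = (
    (ChromaModulationOut.from_offset(mod_vectors, 0),   # img_mod1
     ChromaModulationOut.from_offset(mod_vectors, 3)),  # img_mod2
    (ChromaModulationOut.from_offset(mod_vectors, 6),   # txt_mod1
     ChromaModulationOut.from_offset(mod_vectors, 9)),  # txt_mod2
)
# Each sub-layer then follows the pattern: x + gate * f((1 + scale) * norm(x) + shift)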
class SingleStreamBlock(nn.Module):
    """
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442 and adapted modulation interface.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float = None,
        dtype=None,
        device=None,
        operations=None
    ):
        super().__init__()
        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        head_dim = hidden_size // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        # qkv and mlp_in
        self.linear1 = operations.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim, dtype=dtype, device=device)
        # proj and mlp_out
        self.linear2 = operations.Linear(hidden_size + self.mlp_hidden_dim, hidden_size, dtype=dtype, device=device)

        self.norm = QKNorm(head_dim, dtype=dtype, device=device, operations=operations)

        self.hidden_size = hidden_size
        self.pre_norm = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)

        self.mlp_act = nn.GELU(approximate="tanh")

    def forward(self, x: Tensor, pe: Tensor, vec: Tensor, attn_mask=None) -> Tensor:
        mod = vec
        x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
        qkv, mlp = torch.split(self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1)

        q, k, v = qkv.view(qkv.shape[0], qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k = self.norm(q, k, v)

        # compute attention
        attn = attention(q, k, v, pe=pe, mask=attn_mask)
        # compute activation in mlp stream, cat again and run second linear layer
        output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))
        x += mod.gate * output
        if x.dtype == torch.float16:
            x = torch.nan_to_num(x, nan=0.0, posinf=65504, neginf=-65504)
        return x
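In the single-stream block, linear1 fuses the qkv projection with the MLP input projection and linear2 fuses the attention output projection with the MLP output, so attention and MLP run in parallel on the same pre-normed input. A quick shape check with made-up sizes:

import torch
from torch import nn

# Hypothetical dimensions for illustration only.
hidden_size, mlp_ratio = 3072, 4.0
mlp_hidden_dim = int(hidden_size * mlp_ratio)          # 12288
tokens = torch.randn(1, 4352, hidden_size)             # txt + img tokens, count made up

linear1 = nn.Linear(hidden_size, hidden_size * 3 + mlp_hidden_dim)
qkv, mlp = torch.split(linear1(tokens), [3 * hidden_size, mlp_hidden_dim], dim=-1)
print(qkv.shape, mlp.shape)  # torch.Size([1, 4352, 9216]) torch.Size([1, 4352, 12288])

# The attention output (hidden_size wide) and the activated mlp stream are
# concatenated again for the second fused projection.
linear2 = nn.Linear(hidden_size + mlp_hidden_dim, hidden_size)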
class LastLayer(nn.Module):
    def __init__(self, hidden_size: int, patch_size: int, out_channels: int, dtype=None, device=None, operations=None):
        super().__init__()
        self.norm_final = operations.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
        self.linear = operations.Linear(hidden_size, out_channels, bias=True, dtype=dtype, device=device)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = vec
        shift = shift.squeeze(1)
        scale = scale.squeeze(1)
        x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
        x = self.linear(x)
        return x
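The final layer likewise takes its shift/scale pair as pre-computed rows rather than projecting them from a conditioning vector as Flux's LastLayer does. A sketch with plain torch.nn standing in for the operations wrapper and made-up sizes:

import torch
from torch import nn

# torch.nn stands in for ComfyUI's `operations`; all sizes are hypothetical.
final = LastLayer(hidden_size=3072, patch_size=2, out_channels=64, operations=nn)
img_tokens = torch.randn(1, 4096, 3072)
shift = torch.randn(1, 1, 3072)   # hypothetical final modulation rows
scale = torch.randn(1, 1, 3072)
out = final(img_tokens, (shift, scale))
print(out.shape)  # torch.Size([1, 4096, 64])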
@@ -0,0 +1,44 @@
import torch
from einops import rearrange
from torch import Tensor

from comfy.ldm.modules.attention import optimized_attention
import comfy.model_management


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None) -> Tensor:
    q_shape = q.shape
    k_shape = k.shape

    q = q.float().reshape(*q.shape[:-1], -1, 1, 2)
    k = k.float().reshape(*k.shape[:-1], -1, 1, 2)
    q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
    k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

    heads = q.shape[1]
    x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)
    return x


def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
    assert dim % 2 == 0
    if comfy.model_management.is_device_mps(pos.device) or comfy.model_management.is_intel_xpu() or comfy.model_management.is_directml_enabled():
        device = torch.device("cpu")
    else:
        device = pos.device

    scale = torch.linspace(0, (dim - 2) / dim, steps=dim // 2, dtype=torch.float64, device=device)
    omega = 1.0 / (theta**scale)
    out = torch.einsum("...n,d->...nd", pos.to(dtype=torch.float32, device=device), omega)
    out = torch.stack([torch.cos(out), -torch.sin(out), torch.sin(out), torch.cos(out)], dim=-1)
    out = rearrange(out, "b n d (i j) -> b n d i j", i=2, j=2)
    return out.to(dtype=torch.float32, device=pos.device)


def apply_rope(xq: Tensor, xk: Tensor, freqs_cis: Tensor):
    xq_ = xq.float().reshape(*xq.shape[:-1], -1, 1, 2)
    xk_ = xk.float().reshape(*xk.shape[:-1], -1, 1, 2)
    xq_out = freqs_cis[..., 0] * xq_[..., 0] + freqs_cis[..., 1] * xq_[..., 1]
    xk_out = freqs_cis[..., 0] * xk_[..., 0] + freqs_cis[..., 1] * xk_[..., 1]
    return xq_out.reshape(*xq.shape).type_as(xq), xk_out.reshape(*xk.shape).type_as(xk)
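Here rope builds a per-position 2x2 rotation table and apply_rope (or the inline float-cast block in attention above) applies it to the query and key heads. A small end-to-end sketch using the functions above, with made-up head count and sequence length:

import torch

# Hypothetical sizes for illustration only.
b, heads, n, head_dim = 1, 24, 16, 128
q = torch.randn(b, heads, n, head_dim)
k = torch.randn(b, heads, n, head_dim)
v = torch.randn(b, heads, n, head_dim)

pos = torch.arange(n)[None, :]          # (1, n) position ids
pe = rope(pos, head_dim, theta=10000)   # (1, n, head_dim / 2, 2, 2)
pe = pe.unsqueeze(1)                    # add a heads axis so it broadcasts over q/k

q_rot, k_rot = apply_rope(q, k, pe)     # the same rotation attention() performs inline
out = attention(q, k, v, pe=pe)         # (b, n, heads * head_dim)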
Review comment: Any reason to cast these + the one in apply_rope below to float? This part could be reused from comfy/ldm/flux/math.py if not, since it seems to be the same otherwise.
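For reference, the inline block the comment points at performs the same rotation as apply_rope in comfy/ldm/flux/math.py (up to the dtype it casts back to), so the reuse being suggested would look roughly like this sketch (not the code that was merged):

from comfy.ldm.flux.math import apply_rope as flux_apply_rope
from comfy.ldm.modules.attention import optimized_attention


def attention_via_flux_helpers(q, k, v, pe, mask=None):
    # Rotate q/k with the existing Flux helper instead of the inline float cast.
    q, k = flux_apply_rope(q, k, pe)
    heads = q.shape[1]
    return optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask)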