Closed
Changes from all commits (64 commits)
00184d1
Add environment variable to opt out of #10302 (forced disablement of …
sfinktah Oct 24, 2025
995c073
Replace TORCH_BACKENDS_CUDNN_ENABLED with TORCH_AMD_CUDNN_ENABLED (mo…
sfinktah Oct 24, 2025
dd5af0c
convert Tripo API nodes to V3 schema (#10469)
bigcat88 Oct 24, 2025
426cde3
Remove useless function (#10472)
comfyanonymous Oct 24, 2025
e86b79a
convert Gemini API nodes to V3 schema (#10476)
bigcat88 Oct 25, 2025
098a352
Add warning for torch-directml usage (#10482)
comfyanonymous Oct 26, 2025
f6bbc1a
Fix mistake. (#10484)
comfyanonymous Oct 26, 2025
9d529e5
fix(api-nodes): random issues on Windows by capturing general OSError…
bigcat88 Oct 26, 2025
c170fd2
Bump portable deps workflow to torch cu130 python 3.13.9 (#10493)
comfyanonymous Oct 27, 2025
601ee17
Add a bat to run comfyui portable without api nodes. (#10504)
comfyanonymous Oct 28, 2025
c305dee
Update template to 0.2.3 (#10503)
comfyui-wiki Oct 28, 2025
55bad30
feat(api-nodes): add LTXV API nodes (#10496)
bigcat88 Oct 28, 2025
6abc30a
Update template to 0.2.4 (#10505)
comfyui-wiki Oct 28, 2025
614b8d3
frontend bump to 1.28.8 (#10506)
Kosinkadink Oct 28, 2025
f2bb323
ComfyUI version v0.3.67
comfyanonymous Oct 28, 2025
b61a40c
Bump stable portable to cu130 python 3.13.9 (#10508)
comfyanonymous Oct 28, 2025
8cf2ba4
Remove comfy api key from queue api. (#10502)
comfyanonymous Oct 28, 2025
3bea4ef
Tell users to update nvidia drivers if problem with portable. (#10510)
comfyanonymous Oct 28, 2025
22e40d2
Tell users to update their nvidia drivers if portable doesn't start. …
comfyanonymous Oct 28, 2025
8817f8f
Mixed Precision Quantization System (#10498)
contentis Oct 28, 2025
d202c2b
execution: Allow a subgraph nodes to execute multiple times (#10499)
rattus128 Oct 28, 2025
210f7a1
convert nodes_recraft.py to V3 schema (#10507)
bigcat88 Oct 28, 2025
3fa7a5c
Speed up offloading using pinned memory. (#10526)
comfyanonymous Oct 29, 2025
e525673
Fix issue. (#10527)
comfyanonymous Oct 29, 2025
a4eb32a
inserted missing is_amd() check
sfinktah Oct 29, 2025
6c14f3a
use new API client in Luma and Minimax nodes (#10528)
bigcat88 Oct 29, 2025
1a58087
Reduce memory usage for fp8 scaled op. (#10531)
comfyanonymous Oct 29, 2025
ec4fc2a
Fix case of weights not being unpinned. (#10533)
comfyanonymous Oct 29, 2025
ab7ab5b
Fix Race condition in --async-offload that can cause corruption (#10501)
rattus128 Oct 29, 2025
25de7b1
Try to fix slow load issue on low ram hardware with pinned mem. (#10536)
comfyanonymous Oct 29, 2025
906c089
Fix small performance regression with fp8 fast and scaled fp8. (#10537)
comfyanonymous Oct 29, 2025
998bf60
Add units/info for the numbers displayed on 'load completely' and 'lo…
Kosinkadink Oct 29, 2025
163b629
use new API client in Pixverse and Ideogram nodes (#10543)
bigcat88 Oct 30, 2025
dfac946
fix img2img operation in Dall2 node (#10552)
bigcat88 Oct 30, 2025
513b0c4
Add RAM Pressure cache mode (#10454)
rattus128 Oct 30, 2025
614cf98
Add a ScaleROPE node. Currently only works on WAN models. (#10559)
comfyanonymous Oct 31, 2025
27d1bd8
Fix rope scaling. (#10560)
comfyanonymous Oct 31, 2025
7f374e4
ScaleROPE now works on Lumina models. (#10578)
comfyanonymous Oct 31, 2025
c58c13b
Fix torch compile regression on fp8 ops. (#10580)
comfyanonymous Nov 1, 2025
5f109fe
added 12s-20s as available output durations for the LTXV API nodes (#…
bigcat88 Nov 1, 2025
20182a3
convert StabilityAI to use new API client (#10582)
bigcat88 Nov 1, 2025
44869ff
Fix issue with pinned memory. (#10597)
comfyanonymous Nov 1, 2025
135fa49
Small speed improvements to --async-offload (#10593)
rattus128 Nov 1, 2025
97ff9fa
Clarify help text for --fast argument (#10609)
comfyanonymous Nov 2, 2025
6d6a18b
fix(api-nodes-cloud): stop using sub-folder and absolute path for out…
bigcat88 Nov 3, 2025
88df172
fix(caching): treat bytes as hashable (#10567)
EverNebula Nov 3, 2025
1f3f7a2
convert nodes_hypernetwork.py to V3 schema (#10583)
bigcat88 Nov 3, 2025
e617cdd
convert nodes_openai.py to V3 schema (#10604)
bigcat88 Nov 3, 2025
4e2110c
feat(Pika-API-nodes): use new API client (#10608)
bigcat88 Nov 3, 2025
e974e55
chore: update embedded docs to v0.3.1 (#10614)
comfyui-wiki Nov 3, 2025
958a171
People should update their pytorch versions. (#10618)
comfyanonymous Nov 3, 2025
0652cb8
Speed up torch.compile (#10620)
comfyanonymous Nov 3, 2025
e199c8c
Fixes (#10621)
comfyanonymous Nov 3, 2025
6b88478
Bring back fp8 torch compile performance to what it should be. (#10622)
comfyanonymous Nov 4, 2025
0f4ef3a
This seems to slow things down slightly on Linux. (#10624)
comfyanonymous Nov 4, 2025
af4b7b5
More fp8 torch.compile regressions fixed. (#10625)
comfyanonymous Nov 4, 2025
9c71a66
chore: update workflow templates to v0.2.11 (#10634)
comfyui-wiki Nov 4, 2025
a389ee0
caching: Handle None outputs tuple case (#10637)
rattus128 Nov 4, 2025
7f3e4d4
Limit amount of pinned memory on windows to prevent issues. (#10638)
comfyanonymous Nov 4, 2025
265adad
ComfyUI version v0.3.68
comfyanonymous Nov 5, 2025
4cd8818
Use single apply_rope function across models (#10547)
contentis Nov 5, 2025
c4a6b38
Lower ltxv mem usage to what it was before previous pr. (#10643)
comfyanonymous Nov 5, 2025
58db886
Add env TORCH_AMD_CUDNN_ENABLED
sfinktah Nov 5, 2025
175c22d
Merge remote-tracking branch 'origin/sfink-amd-cudnn-envvar' into sfi…
sfinktah Nov 5, 2025
@@ -0,0 +1,3 @@
..\python_embeded\python.exe -s ..\ComfyUI\main.py --windows-standalone-build --disable-api-nodes
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
1 change: 1 addition & 0 deletions .ci/windows_nvidia_base_files/run_nvidia_gpu.bat
@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
@@ -1,2 +1,3 @@
.\python_embeded\python.exe -s ComfyUI\main.py --windows-standalone-build --fast fp16_accumulation
echo If you see this and ComfyUI did not start try updating your Nvidia Drivers to the latest.
pause
4 changes: 2 additions & 2 deletions .github/workflows/release-stable-all.yml
@@ -18,9 +18,9 @@ jobs:
uses: ./.github/workflows/stable-release.yml
with:
git_tag: ${{ inputs.git_tag }}
cache_tag: "cu129"
cache_tag: "cu130"
python_minor: "13"
python_patch: "6"
python_patch: "9"
rel_name: "nvidia"
rel_extra_name: ""
test_release: true
4 changes: 2 additions & 2 deletions .github/workflows/windows_release_dependencies.yml
@@ -17,7 +17,7 @@ on:
description: 'cuda version'
required: true
type: string
default: "129"
default: "130"

python_minor:
description: 'python minor version'
@@ -29,7 +29,7 @@ on:
description: 'python patch version'
required: true
type: string
default: "6"
default: "9"
# push:
# branches:
# - master
2 changes: 2 additions & 0 deletions README.md
@@ -176,6 +176,8 @@ Simply download, extract with [7-Zip](https://7-zip.org) and run. Make sure you

If you have trouble extracting it, right click the file -> properties -> unblock

Update your Nvidia drivers if it doesn't start.

#### Alternative Downloads:

[Experimental portable for AMD GPUs](https://github.com/comfyanonymous/ComfyUI/releases/latest/download/ComfyUI_windows_portable_amd.7z)
4 changes: 3 additions & 1 deletion comfy/cli_args.py
@@ -105,6 +105,7 @@ class LatentPreviewMethod(enum.Enum):
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
cache_group.add_argument("--cache-ram", nargs='?', const=4.0, type=float, default=0, help="Use RAM pressure caching with the specified headroom threshold. If available RAM drops below the threhold the cache remove large items to free RAM. Default 4GB")

attn_group = parser.add_mutually_exclusive_group()
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
@@ -144,8 +145,9 @@ class PerformanceFeature(enum.Enum):
Fp8MatrixMultiplication = "fp8_matrix_mult"
CublasOps = "cublas_ops"
AutoTune = "autotune"
PinnedMem = "pinned_memory"

parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))

parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
parser.add_argument("--disable-mmap", action="store_true", help="Don't use mmap when loading safetensors.")
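For context on the new --cache-ram flag above: with nargs='?' and const=4.0, bare --cache-ram keeps roughly 4 GB of headroom and --cache-ram 8 keeps 8 GB. The sketch below illustrates the eviction idea described in the help text (drop the largest cached results while free RAM sits under the headroom); it is not the PR's implementation, and the helper name and the use of psutil are assumptions.

import psutil

def evict_until_headroom(cache_sizes: dict, headroom_gb: float = 4.0) -> None:
    # Hypothetical helper: shrink the cache while available RAM is below the
    # requested headroom, largest entries first so each eviction frees the most.
    available_gb = psutil.virtual_memory().available / (1024 ** 3)
    for key, size_bytes in sorted(cache_sizes.items(), key=lambda kv: kv[1], reverse=True):
        if available_gb >= headroom_gb:
            break
        del cache_sizes[key]
        available_gb += size_bytes / (1024 ** 3)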
17 changes: 10 additions & 7 deletions comfy/controlnet.py
@@ -310,11 +310,13 @@ def __init__(self, in_features: int, out_features: int, bias: bool = True,
self.bias = None

def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
x = torch.nn.functional.linear(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias)
else:
return torch.nn.functional.linear(input, weight, bias)
x = torch.nn.functional.linear(input, weight, bias)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x

class Conv2d(torch.nn.Module, comfy.ops.CastWeightBiasOp):
def __init__(
@@ -350,12 +352,13 @@ def __init__(


def forward(self, input):
weight, bias = comfy.ops.cast_bias_weight(self, input)
weight, bias, offload_stream = comfy.ops.cast_bias_weight(self, input, offloadable=True)
if self.up is not None:
return torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
x = torch.nn.functional.conv2d(input, weight + (torch.mm(self.up.flatten(start_dim=1), self.down.flatten(start_dim=1))).reshape(self.weight.shape).type(input.dtype), bias, self.stride, self.padding, self.dilation, self.groups)
else:
return torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)

x = torch.nn.functional.conv2d(input, weight, bias, self.stride, self.padding, self.dilation, self.groups)
comfy.ops.uncast_bias_weight(self, weight, bias, offload_stream)
return x

class ControlLora(ControlNet):
def __init__(self, control_weights, global_average_pooling=False, model_options={}): #TODO? model_options
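Both forward() hunks above follow the same pattern: comfy.ops.cast_bias_weight(..., offloadable=True) now also returns an offload stream, and the cast weight and bias are handed back through comfy.ops.uncast_bias_weight once the linear/conv call has consumed them. A minimal sketch of that call shape, assuming only the two comfy.ops functions shown in the diff (the wrapper itself is illustrative, not part of the PR):

import comfy.ops

def run_with_cast_weights(module, input, layer_fn):
    # layer_fn receives (input, weight, bias) and returns the layer output.
    weight, bias, offload_stream = comfy.ops.cast_bias_weight(module, input, offloadable=True)
    try:
        return layer_fn(input, weight, bias)
    finally:
        # Hand the cast tensors back so the offload stream can release them.
        comfy.ops.uncast_bias_weight(module, weight, bias, offload_stream)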
4 changes: 2 additions & 2 deletions comfy/ldm/flux/layers.py
@@ -195,8 +195,8 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

# calculate the img bloks
img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)
img += apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
img += apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

# calculate the txt bloks
txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
10 changes: 1 addition & 9 deletions comfy/ldm/flux/math.py
@@ -7,15 +7,7 @@


def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor, mask=None, transformer_options={}) -> Tensor:
q_shape = q.shape
k_shape = k.shape

if pe is not None:
q = q.to(dtype=pe.dtype).reshape(*q.shape[:-1], -1, 1, 2)
k = k.to(dtype=pe.dtype).reshape(*k.shape[:-1], -1, 1, 2)
q = (pe[..., 0] * q[..., 0] + pe[..., 1] * q[..., 1]).reshape(*q_shape).type_as(v)
k = (pe[..., 0] * k[..., 0] + pe[..., 1] * k[..., 1]).reshape(*k_shape).type_as(v)

q, k = apply_rope(q, k, pe)
heads = q.shape[1]
x = optimized_attention(q, k, v, heads, skip_reshape=True, mask=mask, transformer_options=transformer_options)
return x
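The deletion above drops the inlined rotary-embedding math from attention() in favor of the shared apply_rope helper. For reference, the rotation the removed lines applied to each (even, odd) channel pair can be restated as below; this is a reference-only restatement of the deleted code, not the source of apply_rope itself.

import torch

def rope_rotate(x: torch.Tensor, pe: torch.Tensor) -> torch.Tensor:
    # x: (..., d); pe packs a 2x2 rotation per channel pair, broadcastable to (..., d//2, 2, 2).
    x_shape = x.shape
    x = x.to(dtype=pe.dtype).reshape(*x.shape[:-1], -1, 1, 2)
    # Per pair (x1, x2) the output is (cos*x1 - sin*x2, sin*x1 + cos*x2),
    # i.e. the [[cos, -sin], [sin, cos]] rotation applied to the pair.
    return (pe[..., 0] * x[..., 0] + pe[..., 1] * x[..., 1]).reshape(x_shape)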
86 changes: 35 additions & 51 deletions comfy/ldm/lightricks/model.py
@@ -3,12 +3,11 @@
import comfy.patcher_extension
import comfy.ldm.modules.attention
import comfy.ldm.common_dit
from einops import rearrange
import math
from typing import Dict, Optional, Tuple

from .symmetric_patchifier import SymmetricPatchifier, latent_to_pixel_coords

from comfy.ldm.flux.math import apply_rope1

def get_timestep_embedding(
timesteps: torch.Tensor,
@@ -238,20 +237,6 @@ def forward(self, x):
return self.net(x)


def apply_rotary_emb(input_tensor, freqs_cis): #TODO: remove duplicate funcs and pick the best/fastest one
cos_freqs = freqs_cis[0]
sin_freqs = freqs_cis[1]

t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2)
t1, t2 = t_dup.unbind(dim=-1)
t_dup = torch.stack((-t2, t1), dim=-1)
input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)")

out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs

return out


class CrossAttention(nn.Module):
def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., attn_precision=None, dtype=None, device=None, operations=None):
super().__init__()
@@ -281,8 +266,8 @@ def forward(self, x, context=None, mask=None, pe=None, transformer_options={}):
k = self.k_norm(k)

if pe is not None:
q = apply_rotary_emb(q, pe)
k = apply_rotary_emb(k, pe)
q = apply_rope1(q.unsqueeze(1), pe).squeeze(1)
k = apply_rope1(k.unsqueeze(1), pe).squeeze(1)

if mask is None:
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
@@ -306,12 +291,17 @@ def __init__(self, dim, n_heads, d_head, context_dim=None, attn_precision=None,
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)

x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, transformer_options=transformer_options) * gate_msa
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input

x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)

y = comfy.ldm.common_dit.rms_norm(x) * (1 + scale_mlp) + shift_mlp
x += self.ff(y) * gate_mlp
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
x.addcmul_(self.ff(y), gate_mlp)

return x

@@ -327,41 +317,35 @@ def get_fractional_positions(indices_grid, max_pos):


def precompute_freqs_cis(indices_grid, dim, out_dtype, theta=10000.0, max_pos=[20, 2048, 2048]):
dtype = torch.float32 #self.dtype
dtype = torch.float32
device = indices_grid.device

# Get fractional positions and compute frequency indices
fractional_positions = get_fractional_positions(indices_grid, max_pos)
indices = theta ** torch.linspace(0, 1, dim // 6, device=device, dtype=dtype) * math.pi / 2

start = 1
end = theta
device = fractional_positions.device

indices = theta ** (
torch.linspace(
math.log(start, theta),
math.log(end, theta),
dim // 6,
device=device,
dtype=dtype,
)
)
indices = indices.to(dtype=dtype)
# Compute frequencies and apply cos/sin
freqs = (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)).transpose(-1, -2).flatten(2)
cos_vals = freqs.cos().repeat_interleave(2, dim=-1)
sin_vals = freqs.sin().repeat_interleave(2, dim=-1)

# Pad if dim is not divisible by 6
if dim % 6 != 0:
padding_size = dim % 6
cos_vals = torch.cat([torch.ones_like(cos_vals[:, :, :padding_size]), cos_vals], dim=-1)
sin_vals = torch.cat([torch.zeros_like(sin_vals[:, :, :padding_size]), sin_vals], dim=-1)

indices = indices * math.pi / 2
# Reshape and extract one value per pair (since repeat_interleave duplicates each value)
cos_vals = cos_vals.reshape(*cos_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]
sin_vals = sin_vals.reshape(*sin_vals.shape[:2], -1, 2)[..., 0].to(out_dtype) # [B, N, dim//2]

freqs = (
(indices * (fractional_positions.unsqueeze(-1) * 2 - 1))
.transpose(-1, -2)
.flatten(2)
)
# Build rotation matrix [[cos, -sin], [sin, cos]] and add heads dimension
freqs_cis = torch.stack([
torch.stack([cos_vals, -sin_vals], dim=-1),
torch.stack([sin_vals, cos_vals], dim=-1)
], dim=-2).unsqueeze(1) # [B, 1, N, dim//2, 2, 2]

cos_freq = freqs.cos().repeat_interleave(2, dim=-1)
sin_freq = freqs.sin().repeat_interleave(2, dim=-1)
if dim % 6 != 0:
cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6])
sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6])
cos_freq = torch.cat([cos_padding, cos_freq], dim=-1)
sin_freq = torch.cat([sin_padding, sin_freq], dim=-1)
return cos_freq.to(out_dtype), sin_freq.to(out_dtype)
return freqs_cis


class LTXVModel(torch.nn.Module):
@@ -501,7 +485,7 @@ def block_wrap(args):
shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1]
x = self.norm_out(x)
# Modulation
x = x * (1 + scale) + shift
x = torch.addcmul(x, x, scale).add_(shift)
x = self.proj_out(x)

x = self.patchifier.unpatchify(
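Several hunks in this file swap x * (1 + scale) + shift for torch.addcmul(x, x, scale).add_(shift), and x += a * g for x.addcmul_(a, g). The math is unchanged, since addcmul(a, b, c) computes a + b * c; the gain is fewer temporary tensors. A quick illustrative check:

import torch

x = torch.randn(2, 8)
scale = torch.randn(2, 8)
shift = torch.randn(2, 8)

old = x * (1 + scale) + shift
new = torch.addcmul(x, x, scale).add_(shift)  # x + x*scale + shift
assert torch.allclose(old, new, atol=1e-6)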
20 changes: 16 additions & 4 deletions comfy/ldm/lumina/model.py
@@ -522,7 +522,7 @@ def patchify_and_embed(
max_cap_len = max(l_effective_cap_len)
max_img_len = max(l_effective_img_len)

position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.int32, device=device)
position_ids = torch.zeros(bsz, max_seq_len, 3, dtype=torch.float32, device=device)

for i in range(bsz):
cap_len = l_effective_cap_len[i]
@@ -531,10 +531,22 @@
H_tokens, W_tokens = H // pH, W // pW
assert H_tokens * W_tokens == img_len

position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
rope_options = transformer_options.get("rope_options", None)
h_scale = 1.0
w_scale = 1.0
h_start = 0
w_start = 0
if rope_options is not None:
h_scale = rope_options.get("scale_y", 1.0)
w_scale = rope_options.get("scale_x", 1.0)

h_start = rope_options.get("shift_y", 0.0)
w_start = rope_options.get("shift_x", 0.0)

position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.float32, device=device)
position_ids[i, cap_len:cap_len+img_len, 0] = cap_len
row_ids = torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
row_ids = (torch.arange(H_tokens, dtype=torch.float32, device=device) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
col_ids = (torch.arange(W_tokens, dtype=torch.float32, device=device) * w_scale + w_start).view(1, -1).repeat(H_tokens, 1).flatten()
position_ids[i, cap_len:cap_len+img_len, 1] = row_ids
position_ids[i, cap_len:cap_len+img_len, 2] = col_ids

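The rope_options block above is what the ScaleROPE commits earlier in this PR (614cf98, 7f374e4) feed into Lumina: row and column token positions become arange * scale + shift rather than plain integer indices, which is also why position_ids moved from int32 to float32. A small illustration of the resulting ids, using arbitrary example values rather than the PR's defaults:

import torch

H_tokens, W_tokens = 4, 3
h_scale, h_start = 0.5, 0.0  # example rope_options values
row_ids = (torch.arange(H_tokens, dtype=torch.float32) * h_scale + h_start).view(-1, 1).repeat(1, W_tokens).flatten()
print(row_ids)  # 0, 0, 0, 0.5, 0.5, 0.5, 1, 1, 1, 1.5, 1.5, 1.5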
36 changes: 19 additions & 17 deletions comfy/ldm/qwen_image/model.py
@@ -10,6 +10,7 @@
from comfy.ldm.flux.layers import EmbedND
import comfy.ldm.common_dit
import comfy.patcher_extension
from comfy.ldm.flux.math import apply_rope1

class GELU(nn.Module):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True, dtype=None, device=None, operations=None):
@@ -134,33 +135,34 @@ def forward(
image_rotary_emb: Optional[torch.Tensor] = None,
transformer_options={},
) -> Tuple[torch.Tensor, torch.Tensor]:
batch_size = hidden_states.shape[0]
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]

img_query = self.to_q(hidden_states).unflatten(-1, (self.heads, -1))
img_key = self.to_k(hidden_states).unflatten(-1, (self.heads, -1))
img_value = self.to_v(hidden_states).unflatten(-1, (self.heads, -1))
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_value = self.to_v(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2)

txt_query = self.add_q_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_key = self.add_k_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_value = self.add_v_proj(encoder_hidden_states).unflatten(-1, (self.heads, -1))
txt_query = self.add_q_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_key = self.add_k_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2).contiguous()
txt_value = self.add_v_proj(encoder_hidden_states).view(batch_size, seq_txt, self.heads, -1).transpose(1, 2)

img_query = self.norm_q(img_query)
img_key = self.norm_k(img_key)
txt_query = self.norm_added_q(txt_query)
txt_key = self.norm_added_k(txt_key)

joint_query = torch.cat([txt_query, img_query], dim=1)
joint_key = torch.cat([txt_key, img_key], dim=1)
joint_value = torch.cat([txt_value, img_value], dim=1)
joint_query = torch.cat([txt_query, img_query], dim=2)
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)

joint_query = apply_rotary_emb(joint_query, image_rotary_emb)
joint_key = apply_rotary_emb(joint_key, image_rotary_emb)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)

joint_query = joint_query.flatten(start_dim=2)
joint_key = joint_key.flatten(start_dim=2)
joint_value = joint_value.flatten(start_dim=2)

joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, attention_mask, transformer_options=transformer_options)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attention_mask, transformer_options=transformer_options,
skip_reshape=True)

txt_attn_output = joint_hidden_states[:, :seq_txt, :]
img_attn_output = joint_hidden_states[:, seq_txt:, :]
@@ -413,7 +415,7 @@ def _forward(
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).squeeze(1).unsqueeze(2).to(x.dtype)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids

hidden_states = self.img_in(hidden_states)
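The qwen_image changes above move the attention projections from (batch, seq, heads, head_dim) to the (batch, heads, seq, head_dim) layout, so the text and image streams are still concatenated along the sequence axis (now dim=2) and optimized_attention_masked is called with skip_reshape=True instead of re-flattening the heads first. A shapes-only illustration with made-up sizes:

import torch

batch, heads, head_dim = 2, 8, 64
seq_txt, seq_img = 77, 1024
txt_q = torch.randn(batch, heads, seq_txt, head_dim)  # BHND layout from the new reshape
img_q = torch.randn(batch, heads, seq_img, head_dim)
joint_q = torch.cat([txt_q, img_q], dim=2)  # concatenate along the sequence axis
print(joint_q.shape)  # torch.Size([2, 8, 1101, 64])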