Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions configs/neopp/neopp_dense_fp8.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"timestep_shift": 3.0,
"cfg_interval": [-1, 2],
"enable_cfg": true,
"use_magi_compile": true,
"dit_quantized": true,
"dit_quant_scheme": "fp8-sgl"
}
1 change: 1 addition & 0 deletions configs/qwen_image/qwen_image_i2i_2511.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"resize_mode": "adaptive",
"attn_type": "flash_attn3",
"enable_cfg": true,
"use_magi_compile": true,
"sample_guide_scale": 4.0,
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": true
Expand Down
12 changes: 12 additions & 0 deletions lightx2v/common/magi_custom_op_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
"""Global switch: use magi subgraph-boundary custom ops only when MAGI compile is active."""

_use_magi_custom_ops = False


def set_magi_custom_op_mode(enabled: bool) -> None:
global _use_magi_custom_ops
_use_magi_custom_ops = bool(enabled)


def use_magi_custom_ops() -> bool:
return _use_magi_custom_ops
19 changes: 12 additions & 7 deletions lightx2v/common/ops/norm/layer_norm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

from .triton_ops import norm_infer

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None

from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops


class LNWeightTemplate(metaclass=ABCMeta):
def __init__(
Expand Down Expand Up @@ -262,10 +269,8 @@ def __init__(
)

def apply(self, input_tensor):
output_tensor = norm_infer(
input_tensor,
(self._get_actual_weight()),
self._get_actual_bias(),
self.eps,
)
return output_tensor
w = self._get_actual_weight()
b = self._get_actual_bias()
if use_magi_custom_ops() and magi_register_custom_op is not None:
return torch.ops.lightx2v.triton_layer_norm(input_tensor, w, b, self.eps)
return norm_infer(input_tensor, w, b, self.eps)
12 changes: 11 additions & 1 deletion lightx2v/common/ops/norm/rms_norm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
except ImportError:
sgl_kernel = None

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None

from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops


class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(
Expand Down Expand Up @@ -414,7 +421,10 @@ def __init__(
)

def apply(self, input_tensor):
return rms_norm_kernel(input_tensor, (self._get_actual_weight()), self.eps)
w = self._get_actual_weight()
if use_magi_custom_ops() and magi_register_custom_op is not None:
return torch.ops.lightx2v.rms_norm(input_tensor, w, self.eps)
return rms_norm_kernel(input_tensor, w, self.eps)


class RMSWeightFusedQKNorm3DRope:
Expand Down
24 changes: 24 additions & 0 deletions lightx2v/common/ops/norm/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import triton.language as tl # type: ignore
from torch import Tensor

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None


@triton.autotune(
configs=[
Expand Down Expand Up @@ -861,6 +866,25 @@ def norm_infer(
return out


if magi_register_custom_op is not None:

@magi_register_custom_op(
"lightx2v::triton_layer_norm",
infer_output_meta_fn=["x"],
is_subgraph_boundary=True,
)
def _triton_layer_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor:
return norm_infer(x, weight, bias, eps, is_rms_norm=False)

@magi_register_custom_op(
"lightx2v::rms_norm",
infer_output_meta_fn=["x"],
is_subgraph_boundary=True,
)
def _rms_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
return rms_norm_kernel(x, weight, eps)


def rms_norm_fn(
x,
weight,
Expand Down
71 changes: 44 additions & 27 deletions lightx2v/models/networks/neopp/infer/transformer_infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from loguru import logger

# from flashinfer.activation import silu_and_mul as flashinfer_silu_and_mul
try:
Expand All @@ -9,9 +10,11 @@

try:
from magi_compiler import magi_compile, magi_register_custom_op
from magi_compiler.config import CudaGraphMode
except ImportError:
magi_compile = None
magi_register_custom_op = None
CudaGraphMode = None

from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.models.networks.neopp.infer.kv_cache_manager import KVCacheManager
Expand Down Expand Up @@ -83,6 +86,14 @@ def __init__(self, config):
self.seq_p_group = None
self.kv_cache = KVCacheManager()

# MagiCompiler enable/disable switch
self.use_magi_compile = config.get("use_magi_compile", False)
if self.use_magi_compile and magi_compile is None:
logger.warning("use_magi_compile=True but magi_compiler is not available, using eager mode")
self.use_magi_compile = False
if self.use_magi_compile:
logger.info("Using Magi Compile (per-layer decoder, split at kv_update)")

@torch.no_grad()
def infer(self, weights, pre_infer_out, inputs):
pass_key = "cond" if self.scheduler.infer_condition else "uncond"
Expand All @@ -96,7 +107,7 @@ def infer(self, weights, pre_infer_out, inputs):
hidden_states = self._fm_head(weights.fm_head, hidden_states)
return hidden_states.unsqueeze(0)

def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_values):
def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values):
seq_len_q = hidden_states.shape[0]
kvcache_len = past_key_values.shape[2]
seq_len_k = kvcache_len + seq_len_q
Expand All @@ -106,40 +117,43 @@ def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_v
self.kv_cache.clear()
self.kv_cache.prepare(past_key_values, seq_len_q)

self._cu_seqlens_q = torch.tensor([0, seq_len_q], dtype=torch.int32)
self._cu_seqlens_k = torch.tensor([0, seq_len_k], dtype=torch.int32)
self._max_seqlen_q = seq_len_q
self._max_seqlen_k = seq_len_k
self._kvcache_len = kvcache_len

kv_buf = self.kv_cache._kv_buf
for layer_idx, block_weight in enumerate(blocks):
hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin)
if self.use_magi_compile:
hidden_states = self._decoder_layer_magi(block_weight, layer_idx, hidden_states, cos_sin, kv_buf)
else:
hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf)
return hidden_states

if magi_compile is not None:
import torch._dynamo as _dynamo

@magi_compile(
dynamic_arg_dims={"hidden_states": 0, "past_key_values": 2},
config_patch=lambda c: c.model_copy(
_dynamo.config.capture_scalar_outputs = False
_dynamo.config.specialize_int = False
_dynamo.config.automatic_dynamic_shapes = True

def _magi_config_patch(c):
return c.model_copy(
update={
"enable_inductor_max_autotune": True,
"disable_cache": True, # Avoid pickle errors with custom op registrations
"disable_cache": True,
"cudagraph_mode": CudaGraphMode.NONE,
}
),
)
def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values):
return self._infer_without_offload_impl(blocks, hidden_states, cos_sin, past_key_values)
else:
)

def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values):
return self._infer_without_offload_impl(blocks, hidden_states, cos_sin, past_key_values)
@magi_compile(
dynamic_arg_dims={"hidden_states": 0, "kv_buf": 2},
config_patch=_magi_config_patch,
)
def _decoder_layer_magi(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf):
return self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf)

# @ProfilingContext4DebugL1("Decoder Layer")
def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin):
def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None):
residual = hidden_states
hidden_states = block_weight.input_layernorm_mot_gen.apply(hidden_states)

hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin)
hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin, kv_buf)
hidden_states = residual + hidden_states

residual = hidden_states
Expand All @@ -150,7 +164,7 @@ def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin):
return hidden_states

# @ProfilingContext4DebugL1("Self Attn")
def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin):
def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin, kv_buf=None):
query_states = attn_w.q_proj_mot_gen.apply(hidden_states)
query_states = query_states.view(-1, self.num_heads, self.head_dim) # [seq, num_heads, head_dim]

Expand All @@ -167,9 +181,12 @@ def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin):
value_states = attn_w.v_proj_mot_gen.apply(hidden_states)
value_states = value_states.view(-1, self.num_kv_heads, self.head_dim) # [seq, num_kv_heads, head_dim]

if kv_buf is None:
kv_buf = self.kv_cache._kv_buf

# Custom op: forces MagiCompiler to split the FX graph at this op,
# isolating the slice-scatter from the surrounding compiled regions.
key_states, value_states = torch.ops.neopp.kv_update(self.kv_cache._kv_buf, layer_idx, key_states, value_states)
key_states, value_states = torch.ops.neopp.kv_update(kv_buf, layer_idx, key_states, value_states)

attn_output = self._compute_attn(attn_w, query_states, key_states, value_states)

Expand Down Expand Up @@ -254,10 +271,10 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states):
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=self._cu_seqlens_q,
cu_seqlens_kv=self._cu_seqlens_k,
max_seqlen_q=self._max_seqlen_q,
max_seqlen_kv=self._max_seqlen_k,
cu_seqlens_q=torch.tensor([0, seq_len_q], dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], dtype=torch.int32),
Comment on lines +274 to +275
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Creating cu_seqlens tensors inside the _compute_attn method (which is called for every layer) introduces unnecessary overhead. Additionally, these tensors are created on the CPU by default, which may cause device synchronization or host-to-device copies if the attention implementation expects them on the GPU.

Consider creating these tensors once outside the layer loop or explicitly specifying the device.

Suggested change
cu_seqlens_q=torch.tensor([0, seq_len_q], dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], dtype=torch.int32),
cu_seqlens_q=torch.tensor([0, seq_len_q], device=query_states.device, dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], device=query_states.device, dtype=torch.int32),

max_seqlen_q=seq_len_q,
max_seqlen_kv=seq_len_k,
)
return attn_output

Expand Down
Loading
Loading