Skip to content
Merged
1 change: 1 addition & 0 deletions configs/neopp/neopp_dense_fp8.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"timestep_shift": 3.0,
"cfg_interval": [-1, 2],
"enable_cfg": true,
"use_magi_compile": true,
"dit_quantized": true,
"dit_quant_scheme": "fp8-sgl"
}
1 change: 1 addition & 0 deletions configs/qwen_image/qwen_image_i2i_2511.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"resize_mode": "adaptive",
"attn_type": "flash_attn3",
"enable_cfg": true,
"use_magi_compile": false,
"sample_guide_scale": 4.0,
"CONDITION_IMAGE_SIZE": 147456,
"USE_IMAGE_ID_IN_PROMPT": true
Expand Down
21 changes: 21 additions & 0 deletions lightx2v/common/magi_custom_op_mode.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Global switch: use magi subgraph-boundary custom ops only when MAGI compile is active."""

# Module-level switch; starts disabled and is flipped via set_magi_custom_op_mode().
_use_magi_custom_ops = False


def set_magi_custom_op_mode(enabled: bool) -> None:
    """Record the truthiness of ``enabled`` as the module-wide custom-op switch."""
    global _use_magi_custom_ops
    _use_magi_custom_ops = True if enabled else False


def use_magi_custom_ops() -> bool:
    """Return the current value of the module-wide custom-op switch."""
    return _use_magi_custom_ops


def configure_dynamo_for_magi_compile() -> None:
    """Set the Dynamo options that magi_compiler relies on.

    Meant to be called when ``use_magi_compile=True``.
    """
    # Imported lazily so that merely importing this module does not pull in Dynamo.
    import torch._dynamo as _dynamo

    dynamo_cfg = _dynamo.config
    for option, value in (
        ("capture_scalar_outputs", False),
        ("specialize_int", False),
        ("automatic_dynamic_shapes", True),
    ):
        setattr(dynamo_cfg, option, value)
19 changes: 12 additions & 7 deletions lightx2v/common/ops/norm/layer_norm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,13 @@

from .triton_ops import norm_infer

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None

from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops


class LNWeightTemplate(metaclass=ABCMeta):
def __init__(
Expand Down Expand Up @@ -262,10 +269,8 @@ def __init__(
)

def apply(self, input_tensor):
    """Layer-normalize ``input_tensor`` using this module's weight, bias and eps.

    When MAGI custom-op mode is switched on and ``magi_compiler`` could be
    imported, the call is routed through the ``lightx2v::triton_layer_norm``
    custom op; otherwise ``norm_infer`` is called directly.

    Args:
        input_tensor: Tensor to normalize.

    Returns:
        The normalized tensor from whichever path was taken.
    """
    w = self._get_actual_weight()
    b = self._get_actual_bias()
    # Both conditions are needed: the global switch says MAGI compile is in
    # use, and the import guard says the decorator that registers the custom
    # op was available.
    if use_magi_custom_ops() and magi_register_custom_op is not None:
        return torch.ops.lightx2v.triton_layer_norm(input_tensor, w, b, self.eps)
    return norm_infer(input_tensor, w, b, self.eps)
12 changes: 11 additions & 1 deletion lightx2v/common/ops/norm/rms_norm_weight.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,13 @@
except ImportError:
sgl_kernel = None

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None

from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops


class RMSWeightTemplate(metaclass=ABCMeta):
def __init__(
Expand Down Expand Up @@ -414,7 +421,10 @@ def __init__(
)

def apply(self, input_tensor):
    """RMS-normalize ``input_tensor`` using this module's weight and eps.

    When MAGI custom-op mode is switched on and ``magi_compiler`` could be
    imported, the call is routed through the ``lightx2v::rms_norm`` custom op;
    otherwise ``rms_norm_kernel`` is called directly.

    Args:
        input_tensor: Tensor to normalize.

    Returns:
        The normalized tensor from whichever path was taken.
    """
    w = self._get_actual_weight()
    # Both conditions are needed: the global switch says MAGI compile is in
    # use, and the import guard says the decorator that registers the custom
    # op was available.
    # NOTE(review): the op is registered in triton_ops.py; confirm that module
    # is imported before this path runs.
    if use_magi_custom_ops() and magi_register_custom_op is not None:
        return torch.ops.lightx2v.rms_norm(input_tensor, w, self.eps)
    return rms_norm_kernel(input_tensor, w, self.eps)


class RMSWeightFusedQKNorm3DRope:
Expand Down
24 changes: 24 additions & 0 deletions lightx2v/common/ops/norm/triton_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@
import triton.language as tl # type: ignore
from torch import Tensor

try:
from magi_compiler import magi_register_custom_op
except ImportError:
magi_register_custom_op = None


@triton.autotune(
configs=[
Expand Down Expand Up @@ -861,6 +866,25 @@ def norm_infer(
return out


# Register the norm kernels as custom ops only when magi_compiler could be
# imported; otherwise the decorator name is None and nothing is registered.
if magi_register_custom_op is not None:

    # Layer-norm wrapper registered under the name "lightx2v::triton_layer_norm".
    # NOTE(review): infer_output_meta_fn=["x"] presumably means the output meta
    # is taken from ``x``, and is_subgraph_boundary=True presumably makes the
    # compiler split the graph at this op; confirm against magi_compiler.
    @magi_register_custom_op(
        "lightx2v::triton_layer_norm",
        infer_output_meta_fn=["x"],
        is_subgraph_boundary=True,
    )
    def _triton_layer_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor:
        # is_rms_norm=False is passed explicitly to select the layer-norm path.
        return norm_infer(x, weight, bias, eps, is_rms_norm=False)

    # RMS-norm wrapper registered under the name "lightx2v::rms_norm", with the
    # same registration options as the layer-norm op above.
    @magi_register_custom_op(
        "lightx2v::rms_norm",
        infer_output_meta_fn=["x"],
        is_subgraph_boundary=True,
    )
    def _rms_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
        # NOTE(review): rms_norm_kernel is not defined in the visible part of
        # this module; confirm it is in scope here.
        return rms_norm_kernel(x, weight, eps)


def rms_norm_fn(
x,
weight,
Expand Down
100 changes: 72 additions & 28 deletions lightx2v/models/networks/neopp/infer/transformer_infer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import torch
import torch.nn.functional as F
from loguru import logger

# from flashinfer.activation import silu_and_mul as flashinfer_silu_and_mul
try:
Expand All @@ -13,6 +14,7 @@
magi_compile = None
magi_register_custom_op = None

from lightx2v.common.magi_custom_op_mode import configure_dynamo_for_magi_compile
from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
from lightx2v.models.networks.neopp.infer.kv_cache_manager import KVCacheManager
from lightx2v.utils.profiler import *
Expand Down Expand Up @@ -83,6 +85,15 @@ def __init__(self, config):
self.seq_p_group = None
self.kv_cache = KVCacheManager()

# MagiCompiler enable/disable switch
self.use_magi_compile = config.get("use_magi_compile", False)
if self.use_magi_compile and magi_compile is None:
logger.warning("use_magi_compile=True but magi_compiler is not available, using eager mode")
self.use_magi_compile = False
if self.use_magi_compile:
configure_dynamo_for_magi_compile()
logger.info("Using Magi Compile (per-layer decoder, split at kv_update)")

@torch.no_grad()
def infer(self, weights, pre_infer_out, inputs):
pass_key = "cond" if self.scheduler.infer_condition else "uncond"
Expand All @@ -96,7 +107,7 @@ def infer(self, weights, pre_infer_out, inputs):
hidden_states = self._fm_head(weights.fm_head, hidden_states)
return hidden_states.unsqueeze(0)

def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_values):
def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values):
seq_len_q = hidden_states.shape[0]
kvcache_len = past_key_values.shape[2]
seq_len_k = kvcache_len + seq_len_q
Expand All @@ -106,40 +117,70 @@ def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_v
self.kv_cache.clear()
self.kv_cache.prepare(past_key_values, seq_len_q)

self._cu_seqlens_q = torch.tensor([0, seq_len_q], dtype=torch.int32)
self._cu_seqlens_k = torch.tensor([0, seq_len_k], dtype=torch.int32)
self._max_seqlen_q = seq_len_q
self._max_seqlen_k = seq_len_k
self._kvcache_len = kvcache_len

kv_buf = self.kv_cache._kv_buf
cos_t, sin_t, cos_h, sin_h, cos_w, sin_w = cos_sin
for layer_idx, block_weight in enumerate(blocks):
hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin)
if self.use_magi_compile:
hidden_states = self._decoder_layer_magi(
block_weight,
layer_idx,
hidden_states,
cos_t,
sin_t,
cos_h,
sin_h,
cos_w,
sin_w,
kv_buf,
)
else:
hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf)
return hidden_states

if magi_compile is not None:

    def _magi_config_patch(c):
        """Return a copy of the MagiCompiler config ``c`` with this model's overrides."""
        return c.model_copy(
            update={
                "enable_inductor_max_autotune": True,
                # Cache is disabled to avoid pickle errors with custom op registrations.
                "disable_cache": True,
            }
        )

    @magi_compile(
        dynamic_arg_dims={
            "hidden_states": 0,
            "kv_buf": 2,
            "cos_t": 1,
            "sin_t": 1,
            "cos_h": 1,
            "sin_h": 1,
            "cos_w": 1,
            "sin_w": 1,
        },
        config_patch=_magi_config_patch,
    )
    def _decoder_layer_magi(
        self,
        block_weight,
        layer_idx,
        hidden_states,
        cos_t,
        sin_t,
        cos_h,
        sin_h,
        cos_w,
        sin_w,
        kv_buf,
    ):
        """Compiled per-layer entry point that forwards to ``_decoder_layer``.

        The six rotary tensors arrive as separate arguments (presumably so each
        can be named in ``dynamic_arg_dims`` -- confirm against magi_compiler)
        and are re-packed into the ``cos_sin`` tuple that ``_decoder_layer``
        expects.

        Args:
            block_weight: Weights of the decoder block to run.
            layer_idx: Index of the layer, passed through to ``_decoder_layer``.
            hidden_states: Input activations; dim 0 is declared dynamic.
            cos_t, sin_t, cos_h, sin_h, cos_w, sin_w: Rotary cos/sin tensors;
                dim 1 of each is declared dynamic.
            kv_buf: KV cache buffer; dim 2 is declared dynamic.

        Returns:
            The hidden states produced by ``_decoder_layer``.
        """
        cos_sin = (cos_t, sin_t, cos_h, sin_h, cos_w, sin_w)
        return self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf)

# @ProfilingContext4DebugL1("Decoder Layer")
def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin):
def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None):
residual = hidden_states
hidden_states = block_weight.input_layernorm_mot_gen.apply(hidden_states)

hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin)
hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin, kv_buf)
hidden_states = residual + hidden_states

residual = hidden_states
Expand All @@ -150,7 +191,7 @@ def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin):
return hidden_states

# @ProfilingContext4DebugL1("Self Attn")
def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin):
def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin, kv_buf=None):
query_states = attn_w.q_proj_mot_gen.apply(hidden_states)
query_states = query_states.view(-1, self.num_heads, self.head_dim) # [seq, num_heads, head_dim]

Expand All @@ -167,9 +208,12 @@ def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin):
value_states = attn_w.v_proj_mot_gen.apply(hidden_states)
value_states = value_states.view(-1, self.num_kv_heads, self.head_dim) # [seq, num_kv_heads, head_dim]

if kv_buf is None:
kv_buf = self.kv_cache._kv_buf

# Custom op: forces MagiCompiler to split the FX graph at this op,
# isolating the slice-scatter from the surrounding compiled regions.
key_states, value_states = torch.ops.neopp.kv_update(self.kv_cache._kv_buf, layer_idx, key_states, value_states)
key_states, value_states = torch.ops.neopp.kv_update(kv_buf, layer_idx, key_states, value_states)

attn_output = self._compute_attn(attn_w, query_states, key_states, value_states)

Expand Down Expand Up @@ -254,10 +298,10 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states):
q=query_states,
k=key_states,
v=value_states,
cu_seqlens_q=self._cu_seqlens_q,
cu_seqlens_kv=self._cu_seqlens_k,
max_seqlen_q=self._max_seqlen_q,
max_seqlen_kv=self._max_seqlen_k,
cu_seqlens_q=torch.tensor([0, seq_len_q], dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], dtype=torch.int32),
Comment on lines +301 to +302
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Creating cu_seqlens tensors inside the _compute_attn method (which is called for every layer) introduces unnecessary overhead. Additionally, these tensors are created on the CPU by default, which may cause device synchronization or host-to-device copies if the attention implementation expects them on the GPU.

Consider creating these tensors once outside the layer loop or explicitly specifying the device.

Suggested change
cu_seqlens_q=torch.tensor([0, seq_len_q], dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], dtype=torch.int32),
cu_seqlens_q=torch.tensor([0, seq_len_q], device=query_states.device, dtype=torch.int32),
cu_seqlens_kv=torch.tensor([0, seq_len_k], device=query_states.device, dtype=torch.int32),

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

cu_seqlens_q和cu_seqlens_kv放在cpu是有意设计。

max_seqlen_q=seq_len_q,
max_seqlen_kv=seq_len_k,
)
return attn_output

Expand Down
Loading
Loading