diff --git a/configs/neopp/neopp_dense_fp8.json b/configs/neopp/neopp_dense_fp8.json index 055b27345..3ca76cdeb 100644 --- a/configs/neopp/neopp_dense_fp8.json +++ b/configs/neopp/neopp_dense_fp8.json @@ -7,6 +7,7 @@ "timestep_shift": 3.0, "cfg_interval": [-1, 2], "enable_cfg": true, + "use_magi_compile": true, "dit_quantized": true, "dit_quant_scheme": "fp8-sgl" } diff --git a/configs/qwen_image/qwen_image_i2i_2511.json b/configs/qwen_image/qwen_image_i2i_2511.json index 9093a458a..6a183599c 100755 --- a/configs/qwen_image/qwen_image_i2i_2511.json +++ b/configs/qwen_image/qwen_image_i2i_2511.json @@ -5,6 +5,7 @@ "resize_mode": "adaptive", "attn_type": "flash_attn3", "enable_cfg": true, + "use_magi_compile": true, "sample_guide_scale": 4.0, "CONDITION_IMAGE_SIZE": 147456, "USE_IMAGE_ID_IN_PROMPT": true diff --git a/lightx2v/common/magi_custom_op_mode.py b/lightx2v/common/magi_custom_op_mode.py new file mode 100644 index 000000000..260b72534 --- /dev/null +++ b/lightx2v/common/magi_custom_op_mode.py @@ -0,0 +1,12 @@ +"""Global switch: use magi subgraph-boundary custom ops only when MAGI compile is active.""" + +_use_magi_custom_ops = False + + +def set_magi_custom_op_mode(enabled: bool) -> None: + global _use_magi_custom_ops + _use_magi_custom_ops = bool(enabled) + + +def use_magi_custom_ops() -> bool: + return _use_magi_custom_ops diff --git a/lightx2v/common/ops/norm/layer_norm_weight.py b/lightx2v/common/ops/norm/layer_norm_weight.py index add98de4d..732d9f273 100755 --- a/lightx2v/common/ops/norm/layer_norm_weight.py +++ b/lightx2v/common/ops/norm/layer_norm_weight.py @@ -10,6 +10,13 @@ from .triton_ops import norm_infer +try: + from magi_compiler import magi_register_custom_op +except ImportError: + magi_register_custom_op = None + +from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops + class LNWeightTemplate(metaclass=ABCMeta): def __init__( @@ -262,10 +269,8 @@ def __init__( ) def apply(self, input_tensor): - output_tensor = norm_infer( - input_tensor, - (self._get_actual_weight()), - self._get_actual_bias(), - self.eps, - ) - return output_tensor + w = self._get_actual_weight() + b = self._get_actual_bias() + if use_magi_custom_ops() and magi_register_custom_op is not None: + return torch.ops.lightx2v.triton_layer_norm(input_tensor, w, b, self.eps) + return norm_infer(input_tensor, w, b, self.eps) diff --git a/lightx2v/common/ops/norm/rms_norm_weight.py b/lightx2v/common/ops/norm/rms_norm_weight.py index c7ebe3bc9..340b747aa 100755 --- a/lightx2v/common/ops/norm/rms_norm_weight.py +++ b/lightx2v/common/ops/norm/rms_norm_weight.py @@ -16,6 +16,13 @@ except ImportError: sgl_kernel = None +try: + from magi_compiler import magi_register_custom_op +except ImportError: + magi_register_custom_op = None + +from lightx2v.common.magi_custom_op_mode import use_magi_custom_ops + class RMSWeightTemplate(metaclass=ABCMeta): def __init__( @@ -414,7 +421,10 @@ def __init__( ) def apply(self, input_tensor): - return rms_norm_kernel(input_tensor, (self._get_actual_weight()), self.eps) + w = self._get_actual_weight() + if use_magi_custom_ops() and magi_register_custom_op is not None: + return torch.ops.lightx2v.rms_norm(input_tensor, w, self.eps) + return rms_norm_kernel(input_tensor, w, self.eps) class RMSWeightFusedQKNorm3DRope: diff --git a/lightx2v/common/ops/norm/triton_ops.py b/lightx2v/common/ops/norm/triton_ops.py index a43ee4c08..b73ea505d 100644 --- a/lightx2v/common/ops/norm/triton_ops.py +++ b/lightx2v/common/ops/norm/triton_ops.py @@ -8,6 +8,11 @@ import triton.language as tl # type: ignore from torch import Tensor +try: + from magi_compiler import magi_register_custom_op +except ImportError: + magi_register_custom_op = None + @triton.autotune( configs=[ @@ -861,6 +866,25 @@ def norm_infer( return out +if magi_register_custom_op is not None: + + @magi_register_custom_op( + "lightx2v::triton_layer_norm", + infer_output_meta_fn=["x"], + is_subgraph_boundary=True, + ) + def _triton_layer_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float) -> torch.Tensor: + return norm_infer(x, weight, bias, eps, is_rms_norm=False) + + @magi_register_custom_op( + "lightx2v::rms_norm", + infer_output_meta_fn=["x"], + is_subgraph_boundary=True, + ) + def _rms_norm_custom_op(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor: + return rms_norm_kernel(x, weight, eps) + + def rms_norm_fn( x, weight, diff --git a/lightx2v/models/networks/neopp/infer/transformer_infer.py b/lightx2v/models/networks/neopp/infer/transformer_infer.py index 4ba1b2246..8c34f0120 100755 --- a/lightx2v/models/networks/neopp/infer/transformer_infer.py +++ b/lightx2v/models/networks/neopp/infer/transformer_infer.py @@ -1,5 +1,6 @@ import torch import torch.nn.functional as F +from loguru import logger # from flashinfer.activation import silu_and_mul as flashinfer_silu_and_mul try: @@ -9,9 +10,11 @@ try: from magi_compiler import magi_compile, magi_register_custom_op + from magi_compiler.config import CudaGraphMode except ImportError: magi_compile = None magi_register_custom_op = None + CudaGraphMode = None from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from lightx2v.models.networks.neopp.infer.kv_cache_manager import KVCacheManager @@ -83,6 +86,14 @@ def __init__(self, config): self.seq_p_group = None self.kv_cache = KVCacheManager() + # MagiCompiler enable/disable switch + self.use_magi_compile = config.get("use_magi_compile", False) + if self.use_magi_compile and magi_compile is None: + logger.warning("use_magi_compile=True but magi_compiler is not available, using eager mode") + self.use_magi_compile = False + if self.use_magi_compile: + logger.info("Using Magi Compile (per-layer decoder, split at kv_update)") + @torch.no_grad() def infer(self, weights, pre_infer_out, inputs): pass_key = "cond" if self.scheduler.infer_condition else "uncond" @@ -96,7 +107,7 @@ def infer(self, weights, pre_infer_out, inputs): hidden_states = self._fm_head(weights.fm_head, hidden_states) return hidden_states.unsqueeze(0) - def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_values): + def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values): seq_len_q = hidden_states.shape[0] kvcache_len = past_key_values.shape[2] seq_len_k = kvcache_len + seq_len_q @@ -106,40 +117,43 @@ def _infer_without_offload_impl(self, blocks, hidden_states, cos_sin, past_key_v self.kv_cache.clear() self.kv_cache.prepare(past_key_values, seq_len_q) - self._cu_seqlens_q = torch.tensor([0, seq_len_q], dtype=torch.int32) - self._cu_seqlens_k = torch.tensor([0, seq_len_k], dtype=torch.int32) - self._max_seqlen_q = seq_len_q - self._max_seqlen_k = seq_len_k - self._kvcache_len = kvcache_len - + kv_buf = self.kv_cache._kv_buf for layer_idx, block_weight in enumerate(blocks): - hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin) + if self.use_magi_compile: + hidden_states = self._decoder_layer_magi(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) + else: + hidden_states = self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) return hidden_states if magi_compile is not None: + import torch._dynamo as _dynamo - @magi_compile( - dynamic_arg_dims={"hidden_states": 0, "past_key_values": 2}, - config_patch=lambda c: c.model_copy( + _dynamo.config.capture_scalar_outputs = False + _dynamo.config.specialize_int = False + _dynamo.config.automatic_dynamic_shapes = True + + def _magi_config_patch(c): + return c.model_copy( update={ "enable_inductor_max_autotune": True, - "disable_cache": True, # Avoid pickle errors with custom op registrations + "disable_cache": True, + "cudagraph_mode": CudaGraphMode.NONE, } - ), - ) - def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values): - return self._infer_without_offload_impl(blocks, hidden_states, cos_sin, past_key_values) - else: + ) - def infer_without_offload(self, blocks, hidden_states, cos_sin, past_key_values): - return self._infer_without_offload_impl(blocks, hidden_states, cos_sin, past_key_values) + @magi_compile( + dynamic_arg_dims={"hidden_states": 0, "kv_buf": 2}, + config_patch=_magi_config_patch, + ) + def _decoder_layer_magi(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf): + return self._decoder_layer(block_weight, layer_idx, hidden_states, cos_sin, kv_buf) # @ProfilingContext4DebugL1("Decoder Layer") - def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin): + def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin, kv_buf=None): residual = hidden_states hidden_states = block_weight.input_layernorm_mot_gen.apply(hidden_states) - hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin) + hidden_states = self._self_attn(block_weight.self_attn, layer_idx, hidden_states, cos_sin, kv_buf) hidden_states = residual + hidden_states residual = hidden_states @@ -150,7 +164,7 @@ def _decoder_layer(self, block_weight, layer_idx, hidden_states, cos_sin): return hidden_states # @ProfilingContext4DebugL1("Self Attn") - def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin): + def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin, kv_buf=None): query_states = attn_w.q_proj_mot_gen.apply(hidden_states) query_states = query_states.view(-1, self.num_heads, self.head_dim) # [seq, num_heads, head_dim] @@ -167,9 +181,12 @@ def _self_attn(self, attn_w, layer_idx, hidden_states, cos_sin): value_states = attn_w.v_proj_mot_gen.apply(hidden_states) value_states = value_states.view(-1, self.num_kv_heads, self.head_dim) # [seq, num_kv_heads, head_dim] + if kv_buf is None: + kv_buf = self.kv_cache._kv_buf + # Custom op: forces MagiCompiler to split the FX graph at this op, # isolating the slice-scatter from the surrounding compiled regions. - key_states, value_states = torch.ops.neopp.kv_update(self.kv_cache._kv_buf, layer_idx, key_states, value_states) + key_states, value_states = torch.ops.neopp.kv_update(kv_buf, layer_idx, key_states, value_states) attn_output = self._compute_attn(attn_w, query_states, key_states, value_states) @@ -254,10 +271,10 @@ def _compute_attn(self, attn_w, query_states, key_states, value_states): q=query_states, k=key_states, v=value_states, - cu_seqlens_q=self._cu_seqlens_q, - cu_seqlens_kv=self._cu_seqlens_k, - max_seqlen_q=self._max_seqlen_q, - max_seqlen_kv=self._max_seqlen_k, + cu_seqlens_q=torch.tensor([0, seq_len_q], dtype=torch.int32), + cu_seqlens_kv=torch.tensor([0, seq_len_k], dtype=torch.int32), + max_seqlen_q=seq_len_q, + max_seqlen_kv=seq_len_k, ) return attn_output diff --git a/lightx2v/models/networks/qwen_image/infer/transformer_infer.py b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py index 67a6e5bf4..c5821d7dd 100755 --- a/lightx2v/models/networks/qwen_image/infer/transformer_infer.py +++ b/lightx2v/models/networks/qwen_image/infer/transformer_infer.py @@ -1,6 +1,10 @@ +import os + import torch import torch.nn.functional as F +from loguru import logger +from lightx2v.common.magi_custom_op_mode import set_magi_custom_op_mode from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer from .triton_ops import ( @@ -9,12 +13,12 @@ ) from .utils import apply_qwen_rope_with_flashinfer, apply_qwen_rope_with_torch, apply_qwen_rope_with_torch_naive - -def calculate_q_k_len(q, k_lens): - q_lens = torch.tensor([q.size(0)], dtype=torch.int32) - cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32) - cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32) - return cu_seqlens_q, cu_seqlens_k +try: + from magi_compiler import magi_compile + from magi_compiler.config import CudaGraphMode +except ImportError: + magi_compile = None + CudaGraphMode = None class QwenImageTransformerInfer(BaseTransformerInfer): @@ -22,7 +26,16 @@ def __init__(self, config): self.config = config self.infer_conditional = True self.clean_cuda_cache = self.config.get("clean_cuda_cache", False) + + self.use_magi_compile = config.get("use_magi_compile", False) + if self.use_magi_compile and magi_compile is None: + logger.warning("use_magi_compile=True but magi_compiler is not available, using eager mode") + self.use_magi_compile = False + set_magi_custom_op_mode(self.use_magi_compile) + if self.use_magi_compile: + logger.info("Using Magi Compile (split: pre-attn / cross-attn eager / post-attn)") self.infer_func = self.infer_calculating + self.attn_type = config.get("attn_type", "flash_attn3") self.zero_cond_t = config.get("zero_cond_t", False) if self.config["seq_parallel"]: @@ -95,35 +108,6 @@ def _modulate(self, x, mod_params, index=None): gate_result = gate.unsqueeze(0) return self.modulate_func(x, scale_result, shift_result).squeeze(0), gate_result.squeeze(0) - def infer_modulate( - self, - mod_phase, - hidden_states, - encoder_hidden_states, - temb_img_silu, - temb_txt_silu, - modulate_index=None, - ): - # Get modulation parameters for both streams - img_mod_params = mod_phase.img_mod.apply(temb_img_silu) - - txt_mod_params = mod_phase.txt_mod.apply(temb_txt_silu) - - # Split modulation parameters for norm1 and norm2 - img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1) - - txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1) - - # Process image stream - norm1 + modulation - img_normed = mod_phase.img_norm1.apply(hidden_states) - img_modulated, img_gate1 = self._modulate(img_normed, img_mod1, modulate_index) - - # Process text stream - norm1 + modulation - txt_normed = mod_phase.txt_norm1.apply(encoder_hidden_states) - txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1) - - return img_modulated, txt_modulated, img_gate1, txt_gate1, img_mod2, txt_mod2 - def infer_img_qkv( self, img_attn_phase, @@ -156,7 +140,6 @@ def infer_img_qkv( return img_query, img_key, img_value, img_gate1, img_mod2 def infer_txt_qkv(self, txt_attn_phase, encoder_hidden_states, temb_txt_silu, txt_freqs): - # Get sequence length from text hidden states seq_txt = encoder_hidden_states.shape[0] txt_mod_params = txt_attn_phase.txt_mod.apply(temb_txt_silu) @@ -198,8 +181,6 @@ def infer_cross_attn( hidden_states, encoder_hidden_states, ): - # Concatenate for joint attention - # Order: [text, image] joint_query = torch.cat([txt_query, img_query], dim=0) joint_key = torch.cat([txt_key, img_key], dim=0) joint_value = torch.cat([txt_value, img_value], dim=0) @@ -232,9 +213,8 @@ def infer_cross_attn( max_seqlen_kv=img_qkv_len, ) - # Split attention outputs back - txt_attn_output = joint_hidden_states[:seq_txt, :] # Text part - img_attn_output = joint_hidden_states[seq_txt:, :] # Image part + txt_attn_output = joint_hidden_states[:seq_txt, :] + img_attn_output = joint_hidden_states[seq_txt:, :] # Apply output projections img_attn_output = cross_attn_phase.to_out.apply(img_attn_output) @@ -278,6 +258,66 @@ def infer_ffn( return encoder_hidden_states, hidden_states + def infer_block_pre_attn( + self, + img_attn_phase, + txt_attn_phase, + hidden_states, + encoder_hidden_states, + temb_img_silu, + temb_txt_silu, + img_freqs, + txt_freqs, + modulate_index=None, + ): + """Norm1 + modulate + QKV (+ RoPE); hidden_states unchanged for cross-attn residual.""" + img_query, img_key, img_value, img_gate1, img_mod2 = self.infer_img_qkv( + img_attn_phase=img_attn_phase, + hidden_states=hidden_states, + temb_img_silu=temb_img_silu, + img_freqs=img_freqs, + modulate_index=modulate_index, + ) + + txt_query, txt_key, txt_value, _, txt_gate1, txt_mod2 = self.infer_txt_qkv( + txt_attn_phase=txt_attn_phase, + encoder_hidden_states=encoder_hidden_states, + temb_txt_silu=temb_txt_silu, + txt_freqs=txt_freqs, + ) + + return ( + img_query, + img_key, + img_value, + txt_query, + txt_key, + txt_value, + img_gate1, + txt_gate1, + img_mod2, + txt_mod2, + ) + + def infer_block_post_attn( + self, + ffn_phase, + hidden_states, + encoder_hidden_states, + img_mod2, + txt_mod2, + modulate_index=None, + ): + """Norm2 + modulate + FFN after cross-attention.""" + return self.infer_ffn( + ffn_phase=ffn_phase, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + img_mod2=img_mod2, + txt_mod2=txt_mod2, + modulate_index=modulate_index, + ) + def infer_block( self, block, @@ -288,6 +328,52 @@ def infer_block( image_rotary_emb, modulate_index=None, ): + if self.use_magi_compile: + ( + img_query, + img_key, + img_value, + txt_query, + txt_key, + txt_value, + img_gate1, + txt_gate1, + img_mod2, + txt_mod2, + ) = self.infer_block_pre_attn_magi( + img_attn_phase=block.compute_phases[0], + txt_attn_phase=block.compute_phases[1], + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_img_silu=temb_img_silu, + temb_txt_silu=temb_txt_silu, + img_freqs=image_rotary_emb[0], + txt_freqs=image_rotary_emb[1], + modulate_index=modulate_index, + ) + hidden_states, encoder_hidden_states = self.infer_cross_attn( + cross_attn_phase=block.compute_phases[2], + seq_txt=encoder_hidden_states.shape[0], + img_query=img_query, + img_key=img_key, + img_value=img_value, + txt_query=txt_query, + txt_key=txt_key, + txt_value=txt_value, + img_gate1=img_gate1, + txt_gate1=txt_gate1, + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + return self.infer_block_post_attn_magi( + ffn_phase=block.compute_phases[3], + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + img_mod2=img_mod2, + txt_mod2=txt_mod2, + modulate_index=modulate_index, + ) + img_query, img_key, img_value, img_gate1, img_mod2 = self.infer_img_qkv( img_attn_phase=block.compute_phases[0], hidden_states=hidden_states, @@ -339,6 +425,19 @@ def infer_calculating( image_rotary_emb, modulate_index, ): + trace_path = os.environ.get("QWEN_MAGI_PROFILE_TRACE") + if trace_path: + return self._infer_calculating_profiled( + blocks, + hidden_states, + encoder_hidden_states, + temb_img_silu, + temb_txt_silu, + image_rotary_emb, + modulate_index, + trace_path, + ) + for idx in range(len(blocks)): encoder_hidden_states, hidden_states = self.infer_block( block=blocks[idx], @@ -351,6 +450,119 @@ def infer_calculating( ) return hidden_states + def _infer_calculating_profiled( + self, + blocks, + hidden_states, + encoder_hidden_states, + temb_img_silu, + temb_txt_silu, + image_rotary_emb, + modulate_index, + trace_path, + ): + import torch.profiler as torch_profiler + from torch.profiler import ProfilerActivity, schedule + + my_schedule = schedule(wait=1, warmup=1, active=1, repeat=1) + with torch_profiler.profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + schedule=my_schedule, + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for step in range(3): + for idx in range(len(blocks)): + encoder_hidden_states, hidden_states = self.infer_block( + block=blocks[idx], + hidden_states=hidden_states, + encoder_hidden_states=encoder_hidden_states, + temb_img_silu=temb_img_silu, + temb_txt_silu=temb_txt_silu, + image_rotary_emb=image_rotary_emb, + modulate_index=modulate_index, + ) + prof.step() + prof.export_chrome_trace(trace_path) + return hidden_states + + if magi_compile is not None: + import torch._dynamo as _dynamo + + _dynamo.config.capture_scalar_outputs = False + _dynamo.config.specialize_int = False + _dynamo.config.automatic_dynamic_shapes = True + + def _magi_config_patch(c): + return c.model_copy( + update={ + "enable_inductor_max_autotune": False, + "disable_cache": True, + "cudagraph_mode": CudaGraphMode.NONE, + } + ) + + @magi_compile( + dynamic_arg_dims={ + "hidden_states": 0, + "encoder_hidden_states": 0, + "img_freqs": 0, + "txt_freqs": 0, + "modulate_index": 1, + }, + config_patch=_magi_config_patch, + ) + def infer_block_pre_attn_magi( + self, + img_attn_phase, + txt_attn_phase, + hidden_states, + encoder_hidden_states, + temb_img_silu, + temb_txt_silu, + img_freqs, + txt_freqs, + modulate_index=None, + ): + return self.infer_block_pre_attn( + img_attn_phase, + txt_attn_phase, + hidden_states, + encoder_hidden_states, + temb_img_silu, + temb_txt_silu, + img_freqs, + txt_freqs, + modulate_index, + ) + + @magi_compile( + dynamic_arg_dims={ + "hidden_states": 0, + "encoder_hidden_states": 0, + "modulate_index": 1, + }, + config_patch=_magi_config_patch, + ) + def infer_block_post_attn_magi( + self, + ffn_phase, + hidden_states, + encoder_hidden_states, + img_mod2, + txt_mod2, + modulate_index=None, + ): + return self.infer_block_post_attn( + ffn_phase, + hidden_states, + encoder_hidden_states, + img_mod2, + txt_mod2, + modulate_index, + ) + def infer(self, block_weights, pre_infer_out): hidden_states = pre_infer_out.hidden_states encoder_hidden_states = pre_infer_out.encoder_hidden_states diff --git a/lightx2v/models/networks/qwen_image/infer/utils.py b/lightx2v/models/networks/qwen_image/infer/utils.py index ce40e0277..82a1d2fcd 100755 --- a/lightx2v/models/networks/qwen_image/infer/utils.py +++ b/lightx2v/models/networks/qwen_image/infer/utils.py @@ -7,16 +7,27 @@ except ImportError: apply_rope_with_cos_sin_cache_inplace = None +try: + from magi_compiler import magi_register_custom_op +except ImportError: + magi_register_custom_op = None -def apply_qwen_rope_with_flashinfer( - xq: torch.Tensor, - xk: torch.Tensor, - cos_sin_cache: torch.Tensor, -): +# Re-export for transformer_infer init. +from lightx2v.common.magi_custom_op_mode import ( + set_magi_custom_op_mode, # noqa: F401 + use_magi_custom_ops, +) + + +def _qwen_rope_meta(xq, xk, cos_sin_cache): + return torch.empty_like(xq), torch.empty_like(xk) + + +def _apply_qwen_rope_with_flashinfer_impl(xq, xk, cos_sin_cache): L, H, D = xq.shape - query = xq.reshape(L, H * D).contiguous() - key = xk.reshape(L, H * D).contiguous() + query = xq.reshape(L, H * D).contiguous().clone() + key = xk.reshape(L, H * D).contiguous().clone() positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True) @@ -34,17 +45,55 @@ def apply_qwen_rope_with_flashinfer( return xq_out, xk_out +def _apply_qwen_rope_with_torch_impl(xq, xk, cos_sin_cache): + xq_rotated = torch.view_as_complex(xq.float().unflatten(-1, (-1, 2))) + xk_rotated = torch.view_as_complex(xk.float().unflatten(-1, (-1, 2))) + freqs_cis = cos_sin_cache.unsqueeze(1) + xq_out = torch.view_as_real(xq_rotated * freqs_cis).flatten(-2) + xk_out = torch.view_as_real(xk_rotated * freqs_cis).flatten(-2) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +if magi_register_custom_op is not None and apply_rope_with_cos_sin_cache_inplace is not None: + + @magi_register_custom_op( + "lightx2v::qwen_rope_flashinfer", + infer_output_meta_fn=_qwen_rope_meta, + is_subgraph_boundary=True, + ) + def _qwen_rope_flashinfer_custom_op(xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return _apply_qwen_rope_with_flashinfer_impl(xq, xk, cos_sin_cache) + + +if magi_register_custom_op is not None: + + @magi_register_custom_op( + "lightx2v::qwen_rope_torch", + infer_output_meta_fn=_qwen_rope_meta, + is_subgraph_boundary=True, + ) + def _qwen_rope_torch_custom_op(xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return _apply_qwen_rope_with_torch_impl(xq, xk, cos_sin_cache) + + +def apply_qwen_rope_with_flashinfer( + xq: torch.Tensor, + xk: torch.Tensor, + cos_sin_cache: torch.Tensor, +): + if use_magi_custom_ops() and magi_register_custom_op is not None and apply_rope_with_cos_sin_cache_inplace is not None: + return torch.ops.lightx2v.qwen_rope_flashinfer(xq, xk, cos_sin_cache) + return _apply_qwen_rope_with_flashinfer_impl(xq, xk, cos_sin_cache) + + def apply_qwen_rope_with_torch( xq: torch.Tensor, xk: torch.Tensor, cos_sin_cache: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: - xq_rotated = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)).squeeze(0) - xk_rotated = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)).squeeze(0) - freqs_cis = cos_sin_cache.unsqueeze(1) - xq_out = torch.view_as_real(xq_rotated * freqs_cis).flatten(-2) - xk_out = torch.view_as_real(xk_rotated * freqs_cis).flatten(-2) - return xq_out.type_as(xq), xk_out.type_as(xk) + if use_magi_custom_ops() and magi_register_custom_op is not None: + return torch.ops.lightx2v.qwen_rope_torch(xq, xk, cos_sin_cache) + return _apply_qwen_rope_with_torch_impl(xq, xk, cos_sin_cache) def apply_qwen_rope_with_torch_naive(