diff --git a/vllm_fl/__init__.py b/vllm_fl/__init__.py index 6280da89..4774eead 100644 --- a/vllm_fl/__init__.py +++ b/vllm_fl/__init__.py @@ -41,7 +41,6 @@ def register(): def register_model(): """Register the FL model.""" from vllm import ModelRegistry - import vllm.model_executor.models.qwen3_next as qwen3_next_module # Register Qwen3.5 MoE config try: diff --git a/vllm_fl/compilation/graph.py b/vllm_fl/compilation/graph.py index 909ec2d7..96316543 100644 --- a/vllm_fl/compilation/graph.py +++ b/vllm_fl/compilation/graph.py @@ -40,7 +40,8 @@ class Graph: elif current_platform.device_type == "npu": graph = torch.npu.NPUGraph else: - raise NotImplementedError("not support graph") + pass + # raise NotImplementedError("not support graph") @dataclasses.dataclass class GraphEntry: diff --git a/vllm_fl/dispatch/backends/vendor/txda/__init__.py b/vllm_fl/dispatch/backends/vendor/txda/__init__.py new file mode 100644 index 00000000..0c1cd8c8 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/txda/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Txda (Tsingmicro) backend for vllm-plugin-FL dispatch. +""" + +from .txda import TxdaBackend + +__all__ = ["TxdaBackend"] diff --git a/vllm_fl/dispatch/backends/vendor/txda/register_ops.py b/vllm_fl/dispatch/backends/vendor/txda/register_ops.py new file mode 100755 index 00000000..e322c181 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/txda/register_ops.py @@ -0,0 +1,78 @@ +# Copyright (c) 2026 BAAI. All rights reserved. + +""" +Txda backend operator registrations. + +This module registers all VENDOR (Txda) implementations. 
+ """ + +from __future__ import annotations + +import functools + +from vllm_fl.dispatch.types import OpImpl, BackendImplKind, BackendPriority + + +def _bind_is_available(fn, is_available_fn): + """Wrap a function and bind _is_available attribute for OpImpl.is_available() check.""" + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper._is_available = is_available_fn + return wrapper + + +def register_builtins(registry) -> None: + """ + Register all Txda (VENDOR) operator implementations. + + Args: + registry: Registry to register into + """ + from .txda import TxdaBackend + + backend = TxdaBackend() + is_avail = backend.is_available + + impls = [ + # # Activation + # OpImpl( + # op_name="silu_and_mul", + # impl_id="vendor.txda", + # kind=BackendImplKind.VENDOR, + # fn=_bind_is_available(backend.silu_and_mul, is_avail), + # vendor="txda", + # priority=BackendPriority.VENDOR, + # ), + # # Normalization + # OpImpl( + # op_name="rms_norm", + # impl_id="vendor.txda", + # kind=BackendImplKind.VENDOR, + # fn=_bind_is_available(backend.rms_norm, is_avail), + # vendor="txda", + # priority=BackendPriority.VENDOR, + # ), + # # Rotary Embedding + # OpImpl( + # op_name="rotary_embedding", + # impl_id="vendor.txda", + # kind=BackendImplKind.VENDOR, + # fn=_bind_is_available(backend.rotary_embedding, is_avail), + # vendor="txda", + # priority=BackendPriority.VENDOR, + # ), + # Attention Backend + OpImpl( + op_name="attention_backend", + impl_id="vendor.txda", + kind=BackendImplKind.VENDOR, + fn=_bind_is_available(backend.attention_backend, is_avail), + vendor="txda", + priority=BackendPriority.VENDOR, + ), + ] + + registry.register_many(impls) diff --git a/vllm_fl/dispatch/backends/vendor/txda/txda.py b/vllm_fl/dispatch/backends/vendor/txda/txda.py new file mode 100644 index 00000000..f2b59f81 --- /dev/null +++ b/vllm_fl/dispatch/backends/vendor/txda/txda.py @@ -0,0 +1,150 @@ +# Copyright (c) 2026 BAAI. All rights reserved. 
+ +""" +Txda backend implementation. + +This backend provides operator implementations for Tsingmicro Txda NPUs. +""" + +from __future__ import annotations + +from typing import Optional, Union + +import torch + +# from vllm_fl.dispatch.backends.flaggems import FlagGemsBackend +from vllm_fl.dispatch.backends.base import Backend + + +class TxdaBackend(Backend): + """ + Txda backend for operator implementations. + + This backend uses Txda CANN libraries to provide high-performance + operator implementations for Tsingmicro Txda NPUs. + """ + + _available: Optional[bool] = None + + @property + def name(self) -> str: + return "txda" + + @property + def vendor(self) -> Optional[str]: + return "txda" + + def is_available(self) -> bool: + """Check if Txda hardware and libraries are available.""" + if TxdaBackend._available is None: + try: + # Check for torch_txda (Txda PyTorch extension) + import torch_txda + + # Check if Txda device is available + if torch.txda.is_available() and torch.txda.device_count() > 0: + TxdaBackend._available = True + else: + TxdaBackend._available = False + except (ImportError, AttributeError): + TxdaBackend._available = False + return TxdaBackend._available + + # ==================== Operator Implementations ==================== + + # def silu_and_mul(self, obj, x: torch.Tensor) -> torch.Tensor: + # """ + # SiLU activation followed by element-wise multiplication. + + # Args: + # obj: The calling obj (for interface consistency) + # x: Input tensor of shape [..., 2*d] + + # Returns: + # Output tensor of shape [..., d] + # """ + # from .impl.activation import silu_and_mul_Txda + + # return silu_and_mul_Txda(obj, x) + + # def rms_norm( + # self, + # obj, + # x: torch.Tensor, + # residual: Optional[torch.Tensor] = None, + # ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: + # """ + # RMS normalization. 
+ + # Args: + # obj: The calling obj (e.g., RMSNorm layer) + # x: Input tensor + # residual: Optional residual tensor + + # Returns: + # Normalized tensor, or tuple of (normalized, residual) if residual is provided + # """ + # from .impl.normalization import rms_norm_Txda + + # return rms_norm_Txda(obj, x, residual) + + # def rotary_embedding( + # self, + # obj, + # query: torch.Tensor, + # key: torch.Tensor, + # cos: torch.Tensor, + # sin: torch.Tensor, + # position_ids: torch.Tensor, + # rotary_interleaved: bool = False, + # inplace: bool = True, + # ) -> tuple[torch.Tensor, torch.Tensor]: + # """ + # Apply rotary position embedding. + + # Args: + # obj: The calling obj (for interface consistency) + # query: Query tensor + # key: Key tensor + # cos: Cosine cache + # sin: Sine cache + # position_ids: Position indices + # rotary_interleaved: Whether to use interleaved rotary + # inplace: Whether to modify tensors in-place + + # Returns: + # Tuple of (embedded_query, embedded_key) + # """ + # from .impl.rotary import rotary_embedding_Txda + + # return rotary_embedding_Txda( + # obj, + # query, + # key, + # cos, + # sin, + # position_ids, + # rotary_interleaved=rotary_interleaved, + # inplace=inplace, + # ) + + def attention_backend(self, use_mla: bool = False) -> str: + """ + Get the attention backend class path for Txda NPU. + + This method returns the class path of the attention backend to + use on Txda devices; both the MLA and non-MLA branches currently + resolve to vllm_fl's flag_gems-based implementations. + + No vendor-specific attention kernels are invoked here directly; + dispatch happens purely through the returned class path. 
+ + Args: + use_mla: Whether to use Multi-head Latent Attention (MLA) + + Returns: + Fully qualified class path string + """ + if use_mla: + return "vllm_fl.dispatch.backends.flaggems.impl.mla.MLAFLBackend" + return "vllm_fl.dispatch.backends.flaggems.impl.attention.AttentionFLBackend" diff --git a/vllm_fl/distributed/device_communicators/flagcx.py b/vllm_fl/distributed/device_communicators/flagcx.py index d0db73d7..228a1e7d 100644 --- a/vllm_fl/distributed/device_communicators/flagcx.py +++ b/vllm_fl/distributed/device_communicators/flagcx.py @@ -72,7 +72,8 @@ def __init__( ### TODO(lms): simplify it if library_path is None: flagcx_path = os.getenv('FLAGCX_PATH') - library_path=os.path.join(flagcx_path, "build/lib/libflagcx.so") + #library_path=os.path.join(flagcx_path, "libflagcx.so") # rcy fix + library_path= "/usr/local/kuiper/lib/libflagcx.so" self.flagcx = FLAGCXLibrary(library_path) else: self.flagcx = FLAGCXLibrary(library_path) @@ -113,7 +114,8 @@ def __init__( # nccl communicator and stream will use this device # `torch.cuda.device` is a context manager that changes the # current cuda device to the specified one - with torch.cuda.device(device): + # + with torch.txda.device(device): self.comm = self.flagcx.flagcxCommInitRank( self.world_size, ctypes.byref(self.unique_id), self.rank) @@ -144,6 +146,12 @@ def all_reduce(self, if stream is None: stream = current_stream() flagcx_stream = self.flagcx.adaptor_stream_copy(stream) + change_type = False + if in_tensor.dtype == torch.bfloat16: + in_tensor = in_tensor.to(torch.float32) + out_tensor = out_tensor.to(torch.float32) + change_type = True + self.flagcx.flagcxAllReduce(buffer_type(in_tensor.data_ptr()), buffer_type(out_tensor.data_ptr()), in_tensor.numel(), @@ -151,6 +159,10 @@ def all_reduce(self, flagcxRedOpTypeEnum.from_torch(op), self.comm, flagcx_stream) self.flagcx.adaptor_stream_free(flagcx_stream) + if change_type: + in_tensor = in_tensor.to(torch.bfloat16) + out_tensor = 
out_tensor.to(torch.bfloat16) + return out_tensor def all_gather(self, diff --git a/vllm_fl/models/fla_ops.py b/vllm_fl/models/fla_ops.py index ca3c18dc..c1cabea2 100644 --- a/vllm_fl/models/fla_ops.py +++ b/vllm_fl/models/fla_ops.py @@ -7,50 +7,50 @@ ) logger = init_logger(__name__) -def fi_chunk_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: torch.LongTensor | None = None, - head_first: bool = False, - use_qk_l2norm_in_kernel: bool = True, -): - from flashinfer.gdn_prefill import ( - chunk_gated_delta_rule as chunk_gated_delta_rule_fi, - ) +# def fi_chunk_gated_delta_rule( +# q: torch.Tensor, +# k: torch.Tensor, +# v: torch.Tensor, +# g: torch.Tensor, +# beta: torch.Tensor, +# initial_state: torch.Tensor, +# output_final_state: bool, +# cu_seqlens: torch.LongTensor | None = None, +# head_first: bool = False, +# use_qk_l2norm_in_kernel: bool = True, +# ): +# from flashinfer.gdn_prefill import ( +# chunk_gated_delta_rule as chunk_gated_delta_rule_fi, +# ) - if use_qk_l2norm_in_kernel: - q = l2norm_fwd(q) - k = l2norm_fwd(k) +# if use_qk_l2norm_in_kernel: +# q = l2norm_fwd(q) +# k = l2norm_fwd(k) - # use flashinfer implementation - q = q.squeeze(0).contiguous() - k = k.squeeze(0).contiguous() - v = v.squeeze(0).contiguous() +# # use flashinfer implementation +# q = q.squeeze(0).contiguous() +# k = k.squeeze(0).contiguous() +# v = v.squeeze(0).contiguous() - g = g.squeeze(0).contiguous() - beta = beta.squeeze(0).contiguous() - fi_state = initial_state.to(torch.float32) - fi_g = g.to(torch.float32) - fi_beta = beta.to(torch.float32) - output, final_state = chunk_gated_delta_rule_fi( - q=q, - k=k, - v=v, - g=torch.exp(fi_g), - beta=fi_beta, - initial_state=fi_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - ) - # Unsqueeze back to 4D (1, L, H, D) to match fla output format - return output.unsqueeze(0), 
final_state +# g = g.squeeze(0).contiguous() +# beta = beta.squeeze(0).contiguous() +# fi_state = initial_state.to(torch.float32) +# fi_g = g.to(torch.float32) +# fi_beta = beta.to(torch.float32) +# output, final_state = chunk_gated_delta_rule_fi( +# q=q, +# k=k, +# v=v, +# g=torch.exp(fi_g), +# beta=fi_beta, +# initial_state=fi_state, +# output_final_state=output_final_state, +# cu_seqlens=cu_seqlens, +# ) +# # Unsqueeze back to 4D (1, L, H, D) to match fla output format +# return output.unsqueeze(0), final_state -@CustomOp.register("chunk_gated_delta_rule") +@CustomOp.register("chunk_gated_delta_rule_txda") class ChunkGatedDeltaRuleOp(CustomOp): def __init__(self) -> None: super().__init__() @@ -62,31 +62,31 @@ def __init__(self) -> None: else: self._forward_method = self.forward_native - def forward_cuda( - self, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: torch.LongTensor | None = None, - head_first: bool = False, - use_qk_l2norm_in_kernel: bool = True, - ): - return fi_chunk_gated_delta_rule( - q=q, - k=k, - v=v, - g=g, - beta=beta, - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - head_first=head_first, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) + # def forward_cuda( + # self, + # q: torch.Tensor, + # k: torch.Tensor, + # v: torch.Tensor, + # g: torch.Tensor, + # beta: torch.Tensor, + # initial_state: torch.Tensor, + # output_final_state: bool, + # cu_seqlens: torch.LongTensor | None = None, + # head_first: bool = False, + # use_qk_l2norm_in_kernel: bool = True, + # ): + # return fi_chunk_gated_delta_rule( + # q=q, + # k=k, + # v=v, + # g=g, + # beta=beta, + # initial_state=initial_state, + # output_final_state=output_final_state, + # cu_seqlens=cu_seqlens, + # head_first=head_first, + # use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + # ) def forward_native( self, diff 
--git a/vllm_fl/models/qwen3_5.py b/vllm_fl/models/qwen3_5.py index c4117d18..f9582650 100644 --- a/vllm_fl/models/qwen3_5.py +++ b/vllm_fl/models/qwen3_5.py @@ -60,7 +60,16 @@ ) # from vllm_fl.models.qwen3_next import ( -from vllm.model_executor.models.qwen3_next import ( +# from vllm.model_executor.models.qwen3_next import ( +# Qwen3NextAttention, +# Qwen3NextDecoderLayer, +# Qwen3NextGatedDeltaNet, +# Qwen3NextModel, +# Qwen3NextSparseMoeBlock, +# QwenNextMixtureOfExperts, +# ) +# rcy fix +from vllm_fl.models.qwen3_next import ( Qwen3NextAttention, Qwen3NextDecoderLayer, Qwen3NextGatedDeltaNet, @@ -122,7 +131,10 @@ def __init__(self, config, model_config=None, cache_config=None, quant_config=None, speculative_config=None, prefix=""): # Call grandparent init to skip Qwen3NextGatedDeltaNet.__init__ # but set up the same attributes - nn.Module.__init__(self) + # rcy fix + super().__init__(config, model_config, cache_config, + quant_config, speculative_config, prefix) + #nn.Module.__init__(self) from vllm.distributed import ( divide, get_tensor_model_parallel_rank, @@ -245,8 +257,9 @@ def __init__(self, config, model_config=None, cache_config=None, ) compilation_config = get_current_vllm_config().compilation_config - if prefix in compilation_config.static_forward_context: - raise ValueError(f"Duplicate layer name: {prefix}") + + # if prefix in compilation_config.static_forward_context: + # raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp() diff --git a/vllm_fl/models/qwen3_next.py b/vllm_fl/models/qwen3_next.py index e1695a1f..500906a1 100644 --- a/vllm_fl/models/qwen3_next.py +++ b/vllm_fl/models/qwen3_next.py @@ -89,8 +89,9 @@ make_layers, maybe_prefix, ) - -from vllm_fl.ops.fla import ChunkGatedDeltaRuleOp, FusedRecurrentGatedDeltaRuleOp +# +from vllm_fl.models.fla_ops import ChunkGatedDeltaRuleOp +from vllm_fl.ops.fla import 
FusedRecurrentGatedDeltaRuleOp logger = init_logger(__name__) @@ -356,10 +357,12 @@ def __init__( if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self - self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp( - output_final_state = True, - use_qk_l2norm_in_kernel=True, - ) + # self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp( + # output_final_state = True, + # use_qk_l2norm_in_kernel=True, + # ) rcy fix + self.chunk_gated_delta_rule = ChunkGatedDeltaRuleOp() + self.fused_recurrent_gated_delta_rule_multi_query = FusedRecurrentGatedDeltaRuleOp( inplace_final_state=True, @@ -655,7 +658,10 @@ def _forward_core( g=g_non_spec, beta=beta_non_spec, initial_state=initial_state, + output_final_state=True, cu_seqlens=non_spec_query_start_loc, + head_first=False, + use_qk_l2norm_in_kernel=True ) # Init cache ssm_state[non_spec_state_indices_tensor] = last_recurrent_state.to( diff --git a/vllm_fl/ops/fla/chunk.py b/vllm_fl/ops/fla/chunk.py index 403ba297..de02168f 100644 --- a/vllm_fl/ops/fla/chunk.py +++ b/vllm_fl/ops/fla/chunk.py @@ -6,11 +6,13 @@ from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.fla.ops.l2norm import l2norm_fwd -from vllm.model_executor.layers.fla.ops.utils import input_guard +# from vllm.model_executor.layers.fla.ops.utils import input_guard +from flag_gems.fused.FLA.utils import input_guard from vllm_fl.utils import use_flaggems_op -if use_flaggems_op("chunk_gated_delta_rule_fwd"): +# if use_flaggems_op("chunk_gated_delta_rule_fwd"): +if True: from flag_gems.fused.FLA import chunk_gated_delta_rule_fwd else: from vllm.model_executor.layers.fla.ops.chunk import chunk_gated_delta_rule_fwd @@ -19,7 +21,7 @@ class ChunkGatedDeltaRuleFunction(torch.autograd.Function): @staticmethod @input_guard - @torch.amp.custom_fwd(device_type="cuda") + @torch.amp.custom_fwd(device_type="txda") def forward( ctx, q: torch.Tensor, diff 
--git a/vllm_fl/ops/fla/fused_recurrent.py b/vllm_fl/ops/fla/fused_recurrent.py index 9519b807..ebd2a763 100644 --- a/vllm_fl/ops/fla/fused_recurrent.py +++ b/vllm_fl/ops/fla/fused_recurrent.py @@ -6,13 +6,13 @@ from vllm_fl.utils import use_flaggems_op -if use_flaggems_op("fused_recurrent_gated_delta_rule_fwd"): +if True: from flag_gems.fused.FLA import fused_recurrent_gated_delta_rule_fwd else: from vllm.model_executor.layers.fla.ops.fused_recurrent import ( fused_recurrent_gated_delta_rule_fwd, ) - +from flag_gems.fused.FLA.utils import input_guard class FusedRecurrentFunction(torch.autograd.Function): @staticmethod diff --git a/vllm_fl/platform.py b/vllm_fl/platform.py index cc2579fa..a28211da 100644 --- a/vllm_fl/platform.py +++ b/vllm_fl/platform.py @@ -46,21 +46,18 @@ class PlatformFL(Platform): dispatch_key = device_info.dispatch_key torch_device_fn = device_info.torch_device_fn vendor_name = device_info.vendor_name - ray_device_key: str = "GPU" + ray_device_key: str = "flagos" dist_backend: str = "flagcx" if "FLAGCX_PATH" in os.environ else "nccl" + device_control_env_var: str = "TXDA_VISIBLE_DEVICES" ### TODO(lms): dispatch device_control_env_var # device_control_env_var: str = "CUDA_VISIBLE_DEVICES" def is_cuda_alike(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" - if self.vendor_name == "iluvatar": - return False return self.device_type == "cuda" def is_cuda(self) -> bool: """Stateless version of [torch.cuda.is_available][].""" - if self.vendor_name == "iluvatar": - return False return self.device_type == "cuda" @property @@ -78,9 +75,10 @@ def check_if_supports_dtype(cls, torch_dtype: torch.dtype): def get_current_memory_usage( cls, device: Optional[torch.types.Device] = None ) -> float: - cls.torch_device_fn.empty_cache() - cls.torch_device_fn.reset_peak_memory_stats(device) - return cls.torch_device_fn.max_memory_allocated(device) + # cls.torch_device_fn.empty_cache() + # cls.torch_device_fn.reset_peak_memory_stats(device) + # 
return cls.torch_device_fn.max_memory_allocated(device) + return 1.0 @classmethod def set_device(cls, device: torch.device) -> None: @@ -91,7 +89,8 @@ def set_device(cls, device: torch.device) -> None: @classmethod def empty_cache(cls) -> None: - cls.torch_device_fn.empty_cache() + # cls.torch_device_fn.empty_cache() + pass @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -236,7 +235,7 @@ def get_static_graph_wrapper_cls(cls) -> str: @classmethod def support_static_graph_mode(cls) -> bool: - if cls.device_name in ["cuda", "npu"]: + if cls.device_name in ["cuda", "npu", "txda"]: return True return False @@ -284,6 +283,8 @@ def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: # TODO(yxa): For NPU/Ascend devices, return None (no capability version like CUDA) if cls.device_type == "npu": return None + elif cls.device_type == "txda": + return DeviceCapability(major=8, minor=1) # For CUDA devices major, minor = torch.cuda.get_device_capability(device_id) return DeviceCapability(major=major, minor=minor) diff --git a/vllm_fl/worker/model_runner.py b/vllm_fl/worker/model_runner.py index ae2ee0ef..6b2b64c0 100644 --- a/vllm_fl/worker/model_runner.py +++ b/vllm_fl/worker/model_runner.py @@ -646,7 +646,7 @@ def reset_mm_cache(self) -> None: if self.mm_budget: self.mm_budget.reset_cache() - @torch.inference_mode() + # @torch.inference_mode() def init_fp8_kv_scales(self) -> None: """ Re-initialize the KV cache and FP8 scales after waking from sleep. @@ -2934,7 +2934,7 @@ def _register_layerwise_nvtx_hooks(self) -> None: pyt_hooks.register_hooks(self.model, self.model.__class__.__name__) self.layerwise_nvtx_hooks_registered = True - @torch.inference_mode() + # @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", @@ -2986,7 +2986,8 @@ def execute_model( # returns True. before returning early here we call # dummy run to ensure coordinate_batch_across_dp # is called into to avoid out of sync issues. 
- self._dummy_run(1) + # self._dummy_run(1) + pass if not has_kv_transfer_group(): # Return empty ModelRunnerOutput if no work to do. return EMPTY_MODEL_RUNNER_OUTPUT @@ -3193,7 +3194,7 @@ def execute_model( self.kv_connector_output = kv_connector_output return None - @torch.inference_mode + # @torch.inference_mode def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput | IntermediateTensors: @@ -4000,7 +4001,7 @@ def _get_mm_dummy_batch( ) ) - @torch.inference_mode() + # @torch.inference_mode() def _dummy_run( self, num_tokens: int, @@ -4298,7 +4299,7 @@ def _dummy_run( ) return hidden_states, hidden_states[logit_indices_device] - @torch.inference_mode() + # @torch.inference_mode() def _dummy_sampler_run( self, hidden_states: torch.Tensor, @@ -4429,7 +4430,7 @@ def _dummy_pooler_run_task( else: raise e - @torch.inference_mode() + # @torch.inference_mode() def _dummy_pooler_run( self, hidden_states: torch.Tensor, @@ -4505,20 +4506,20 @@ def profile_run(self) -> None: self.encoder_cache[f"tmp_{i}"] = output # Add `is_profile` here to pre-allocate communication buffers - hidden_states, last_hidden_states = self._dummy_run( - self.max_num_tokens, is_profile=True - ) - if get_pp_group().is_last_rank: - if self.is_pooling_model: - output = self._dummy_pooler_run(hidden_states) - else: - output = self._dummy_sampler_run(last_hidden_states) - else: - output = None - self._sync_device() - del hidden_states, output - self.encoder_cache.clear() - gc.collect() + # hidden_states, last_hidden_states = self._dummy_run( + # self.max_num_tokens, is_profile=True + # ) + #if get_pp_group().is_last_rank: + #if self.is_pooling_model: + #output = self._dummy_pooler_run(hidden_states) + #else: + #output = self._dummy_sampler_run(last_hidden_states) + #else: + #output = None + #self._sync_device() + #del hidden_states, output + #self.encoder_cache.clear() + #gc.collect() def capture_model(self) -> int: if 
self.compilation_config.cudagraph_mode == CUDAGraphMode.NONE: diff --git a/vllm_fl/worker/worker.py b/vllm_fl/worker/worker.py index 48755cb9..e752db6e 100644 --- a/vllm_fl/worker/worker.py +++ b/vllm_fl/worker/worker.py @@ -29,18 +29,7 @@ get_kv_transfer_group, has_kv_transfer_group, ) -# from vllm.model_executor.warmup.kernel_warmup import kernel_warmup -try: - from vllm.model_executor.warmup.kernel_warmup import kernel_warmup -except ImportError: - # deep_gemm may be broken in some environments; provide a fallback - import logging as _logging - _logging.getLogger(__name__).warning( - "kernel_warmup import failed (likely deep_gemm issue), " - "using no-op kernel_warmup" - ) - def kernel_warmup(worker): - pass +from vllm.model_executor.warmup.kernel_warmup import kernel_warmup from vllm.distributed.parallel_state import ( get_pcp_group, get_pp_group, @@ -391,21 +380,21 @@ def init_device(self): ### TODO(lms): patch MemorySnapshot in other platform # take current memory snapshot - self.init_snapshot = MemorySnapshot() - self.requested_memory = ( - self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization - ) - if self.init_snapshot.free_memory < self.requested_memory: - GiB = lambda b: round(b / GiB_bytes, 2) - raise ValueError( - f"Free memory on device " - f"({GiB(self.init_snapshot.free_memory)}/" - f"{GiB(self.init_snapshot.total_memory)} GiB) on startup " - f"is less than desired GPU memory utilization " - f"({self.cache_config.gpu_memory_utilization}, " - f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " - f"utilization or reduce GPU memory used by other processes." 
- ) + # self.init_snapshot = MemorySnapshot() + # self.requested_memory = ( + # self.init_snapshot.total_memory * self.cache_config.gpu_memory_utilization + # ) + # if self.init_snapshot.free_memory < self.requested_memory: + # GiB = lambda b: round(b / GiB_bytes, 2) + # raise ValueError( + # f"Free memory on device " + # f"({GiB(self.init_snapshot.free_memory)}/" + # f"{GiB(self.init_snapshot.total_memory)} GiB) on startup " + # f"is less than desired GPU memory utilization " + # f"({self.cache_config.gpu_memory_utilization}, " + # f"{GiB(self.requested_memory)} GiB). Decrease GPU memory " + # f"utilization or reduce GPU memory used by other processes." + # ) # Initialize workspace manager num_ubatches = 2 if self.vllm_config.parallel_config.enable_dbo else 1 init_workspace_manager(self.device, num_ubatches) @@ -434,7 +423,7 @@ def update_config(self, overrides: dict[str, Any]) -> None: def reload_weights(self) -> None: self.model_runner.reload_weights() - @torch.inference_mode() + # @torch.inference_mode() def determine_available_memory(self) -> int: """Profiles the peak memory usage of the model to determine how much memory can be used for KV cache without OOMs. @@ -447,79 +436,80 @@ def determine_available_memory(self) -> int: You may limit the usage of GPU memory by adjusting the `gpu_memory_utilization` parameter. """ - GiB = lambda b: b / GiB_bytes - if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: - # still need a profile run which compiles the model for - # max_num_batched_tokens - self.model_runner.profile_run() - - msg = ( - f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} " - f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " - "KV Cache as specified by kv_cache_memory_bytes config and " - "skipped memory profiling. This does not respect the " - "gpu_memory_utilization config. Only use kv_cache_memory_bytes " - "config when you want manual control of KV cache memory " - "size. 
If OOM'ed, check the difference of initial free " - "memory between the current run and the previous run " - "where kv_cache_memory_bytes is suggested and update it " - "correspondingly." - ) - logger.info(msg) - return kv_cache_memory_bytes - - current_platform.empty_cache() - current_platform.torch_device_fn.reset_peak_memory_stats() - - # Execute a forward pass with dummy inputs to profile the memory usage - # of the model. - with memory_profiling_fl( - self.init_snapshot, - weights_memory=int(self.model_runner.model_memory_usage), - ) as profile_result: - self.model_runner.profile_run() - - self.non_torch_memory = profile_result.non_torch_increase - self.peak_activation_memory = profile_result.torch_peak_increase - - free_gpu_memory = profile_result.after_profile.free_memory - # NOTE(woosuk): Here we assume that the other processes using the same - # GPU did not change their memory usage during the profiling. - assert self.init_snapshot.free_memory > free_gpu_memory, ( - "Error in memory profiling. " - f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, " - f"current free memory {GiB(free_gpu_memory)} GiB. " - "This happens when other processes sharing the same container " - "release GPU memory while vLLM is profiling during initialization. " - "To fix this, ensure consistent GPU memory allocation or " - "isolate vLLM in its own container." 
- ) - self.available_kv_cache_memory_bytes = ( - self.requested_memory - profile_result.non_kv_cache_memory - ) - - unrequested_memory = self.init_snapshot.free_memory - self.requested_memory - logger.debug( - "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", - GiB(self.init_snapshot.free_memory), - self.cache_config.gpu_memory_utilization, - GiB(self.requested_memory), - ) - logger.debug( - "Free memory after profiling: %.2f GiB (total), " - "%.2f GiB (within requested)", - GiB(free_gpu_memory), - GiB(free_gpu_memory - unrequested_memory), - ) - logger.debug(profile_result) - logger.info_once( - "Available KV cache memory: %.2f GiB", - GiB(self.available_kv_cache_memory_bytes), - scope="local", - ) - gc.collect() - - return int(self.available_kv_cache_memory_bytes) + # GiB = lambda b: b / GiB_bytes + # if kv_cache_memory_bytes := self.cache_config.kv_cache_memory_bytes: + # # still need a profile run which compiles the model for + # # max_num_batched_tokens + # self.model_runner.profile_run() + + # msg = ( + # f"Initial free memory {GiB(self.init_snapshot.free_memory):.2f} " + # f"GiB, reserved {GiB(kv_cache_memory_bytes):.2f} GiB memory for " + # "KV Cache as specified by kv_cache_memory_bytes config and " + # "skipped memory profiling. This does not respect the " + # "gpu_memory_utilization config. Only use kv_cache_memory_bytes " + # "config when you want manual control of KV cache memory " + # "size. If OOM'ed, check the difference of initial free " + # "memory between the current run and the previous run " + # "where kv_cache_memory_bytes is suggested and update it " + # "correspondingly." + # ) + # logger.info(msg) + # return kv_cache_memory_bytes + + # current_platform.empty_cache() + # current_platform.torch_device_fn.reset_peak_memory_stats() + + # # Execute a forward pass with dummy inputs to profile the memory usage + # # of the model. 
+ # with memory_profiling_fl( + # self.init_snapshot, + # weights_memory=int(self.model_runner.model_memory_usage), + # ) as profile_result: + # self.model_runner.profile_run() + + # self.non_torch_memory = profile_result.non_torch_increase + # self.peak_activation_memory = profile_result.torch_peak_increase + + # free_gpu_memory = profile_result.after_profile.free_memory + # # NOTE(woosuk): Here we assume that the other processes using the same + # # GPU did not change their memory usage during the profiling. + # assert self.init_snapshot.free_memory > free_gpu_memory, ( + # "Error in memory profiling. " + # f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, " + # f"current free memory {GiB(free_gpu_memory)} GiB. " + # "This happens when other processes sharing the same container " + # "release GPU memory while vLLM is profiling during initialization. " + # "To fix this, ensure consistent GPU memory allocation or " + # "isolate vLLM in its own container." + # ) + # self.available_kv_cache_memory_bytes = ( + # self.requested_memory - profile_result.non_kv_cache_memory + # ) + + # unrequested_memory = self.init_snapshot.free_memory - self.requested_memory + # logger.debug( + # "Initial free memory: %.2f GiB; Requested memory: %.2f (util), %.2f GiB", + # GiB(self.init_snapshot.free_memory), + # self.cache_config.gpu_memory_utilization, + # GiB(self.requested_memory), + # ) + # logger.debug( + # "Free memory after profiling: %.2f GiB (total), " + # "%.2f GiB (within requested)", + # GiB(free_gpu_memory), + # GiB(free_gpu_memory - unrequested_memory), + # ) + # logger.debug(profile_result) + # logger.info_once( + # "Available KV cache memory: %.2f GiB", + # GiB(self.available_kv_cache_memory_bytes), + # scope="local", + # ) + # gc.collect() + + # return int(self.available_kv_cache_memory_bytes) + return 20*1024*1024*1024 # def get_kv_connector_handshake_metadata(self) -> dict | None: """Get KV connector metadata from this worker if available.""" @@ 
-667,15 +657,15 @@ def compile_or_warm_up_model(self) -> None: ) # We skip EPLB here since we don't want to record dummy metrics - hidden_states, last_hidden_states = self.model_runner._dummy_run( - num_tokens=max_num_reqs, - skip_eplb=True, - cudagraph_runtime_mode=CUDAGraphMode.NONE, - ) - if self.model_runner.is_pooling_model: - self.model_runner._dummy_pooler_run(hidden_states) - else: - self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) + # hidden_states, last_hidden_states = self.model_runner._dummy_run( + # num_tokens=max_num_reqs, + # skip_eplb=True, + # cudagraph_runtime_mode=CUDAGraphMode.NONE, + # ) + # if self.model_runner.is_pooling_model: + # self.model_runner._dummy_pooler_run(hidden_states) + # else: + # self.model_runner._dummy_sampler_run(hidden_states=last_hidden_states) # Reset the seed to ensure that the random state is not affected by # the model initialization and profiling. @@ -705,13 +695,13 @@ def annotate_profile(self, scheduler_output): f"execute_new_{num_new}_cached_{num_cached}" ) - @torch.inference_mode() + # @torch.inference_mode() def sample_tokens( self, grammar_output: "GrammarOutput | None" ) -> ModelRunnerOutput | AsyncModelRunnerOutput: return self.model_runner.sample_tokens(grammar_output) - @torch.inference_mode() + # @torch.inference_mode() def execute_model( self, scheduler_output: "SchedulerOutput", @@ -1064,10 +1054,10 @@ def init_worker_distributed_environment( local_rank: int = -1, backend: str = "nccl", ) -> None: + backend = "flagcx" #txda """Initialize the distributed environment.""" attention_config = vllm_config.attention_config parallel_config = vllm_config.parallel_config - set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) from vllm.model_executor.layers.batch_invariant import init_batch_invariance init_batch_invariance(attention_config.backend)