diff --git a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py index 97167e0e8..82f5ce599 100644 --- a/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py +++ b/byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py @@ -133,7 +133,7 @@ def signal_handler(signum, frame): logger.info(f"rank {local_rank} received signal {signum}, exiting...") if hasattr(model, 'finalize_inference'): model.finalize_inference() - os._exit(0) + sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -195,4 +195,98 @@ def mp_forward(self, *args): output_dict = self._output_queues.get(block=True) return output_dict - + +# ROCM_HIPGRAPH modify +class GpuMpEngineWithGraph(GpuMpEngine): + def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None: + super().__init__(world_size, model_impl, xpu_cfg) + logger.info("@@@@@@@@@@ GpuMpEngineWithGraph") + + @torch.no_grad() + def mp_loop_worker( + self, + local_rank: int, + world_size: int, + input_queue: Queue, + output_queue: Queue, + model_impl, + xpu_config + ): + try: + torch.manual_seed(1) + + # set rank and world_size + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + + # create and init model based on model_impl and xpu_config + model = model_impl(xpu_config) + if hasattr(model, 'init_inference'): + model.init_inference() + + def signal_handler(signum, frame): + logger.info(f"rank {local_rank} received signal {signum}, exiting...") + if hasattr(model, 'finalize_inference'): + model.finalize_inference() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # current rank is ready + output_queue.put("ready", block=True) + logger.info(f"{local_rank}/{world_size} rank is ready") + + graph = torch.cuda.CUDAGraph() + + # model process loop + while True: + ( + forward_inputs, + ) = input_queue.get(block=True) + + # this is the capture phase of graph + if 'capture' in forward_inputs: + graph.reset() # reset cuda graph each time + inputs_dict = self.build_inputs(forward_inputs) + # model.forward(inputs_dict) + torch.cuda.synchronize() + with torch.cuda.graph(graph): + model.forward(inputs_dict) + torch.cuda.synchronize() + continue + + log = forward_inputs.get("log", False) + workspace = forward_inputs.get("workspace", None) + + forward_inputs["log_file"] = None + if log and workspace is not None: + workspace_dir = workspace / f"rank_{local_rank}" + workspace_dir.mkdir(exist_ok=True, parents=True) + forward_inputs["log_file"] = open(workspace_dir / "run.log", "w") + + + inputs_dict = self.build_inputs(forward_inputs) + start_time = time.perf_counter_ns() + + # output_dict = model.forward(inputs_dict) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter_ns() + duration_ms = round((end_time - start_time) / 1e6, 3) + output_dict = dict() + output_dict["duration_ms"] = duration_ms + + # TP realization: rank0 send result back to main process + if local_rank == 0: + output_queue.put(output_dict) + + if log and workspace is not None: + forward_inputs["log_file"].close() + + except Exception as e: + logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}") + output_queue.put(RuntimeError("[BUG] fatal exception in model subprocess")) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/gpu_ckpt_loader.py b/byte_infer_perf/llm_perf/backends/ROCM/gpu_ckpt_loader.py new file mode 100644 index 000000000..8bbc799d3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/gpu_ckpt_loader.py @@ -0,0 +1,51 @@ +import torch +import torch.distributed as dist + +from llm_perf.core.ckpt_loader import CoreCkptLoader + +class GpuCkptLoader(CoreCkptLoader): + def __init__( + self, + prefix, model, + mp_size=1, mp_rank=0, + ckpt_path: str="" + ): + super().__init__(prefix, model, mp_size, mp_rank, ckpt_path) + + + def weight_to_device(self, weight : torch.Tensor, non_blocking=False): + if self.mp_rank == 0: + weight = weight.cuda(non_blocking=non_blocking) + else: + cur_device = torch.cuda.current_device() + weight = torch.empty_like(weight, device=f"cuda:{cur_device}") + return weight + + + def broadcast_weight(self, key, device='cpu', non_blocking=False): + if self.mp_rank != 0: + tensor_shape = self.state_dict[key]["shape"] + tensor_dtype = self.state_dict[key]["dtype"] + tensor = torch.empty(tensor_shape, dtype=tensor_dtype) + else: + tensor = self.state_dict[key].cpu() + tensor_gpu = self.weight_to_device(tensor, non_blocking=non_blocking) + dist.broadcast(tensor_gpu, src=0) + self.state_dict[key] = tensor_gpu + + + def scatter_weight(self, key, dim, split_mode='default', outter=1, device='cpu', non_blocking=False): + self.broadcast_weight(key, non_blocking=non_blocking) + weight = self.state_dict[key] + + if split_mode == 'default': + weight_split = self.split(weight, dim) + elif split_mode == 'with_outter': + weight_split = self.with_outter_split(weight, dim, outter) + elif split_mode == 'split_outter': + weight_split = self.split(weight, dim, outter) + else: + assert False, f"unknown split mode {split_mode}" + + weight_split = [x.contiguous() for x in weight_split] + self.state_dict[key] = weight_split[self.mp_rank] diff --git a/byte_infer_perf/llm_perf/backends/ROCM/gpu_inferencer.py b/byte_infer_perf/llm_perf/backends/ROCM/gpu_inferencer.py new file mode 100644 index 000000000..a7220c452 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/gpu_inferencer.py @@ -0,0 +1,131 @@ +import os +from typing import Dict, List, Any +from dataclasses import dataclass + +from llm_perf.core.generation import GenerateRequest +from llm_perf.core.inferencer import CoreInferencer +from llm_perf.backends.ROCM.gpu_mp_engine import GpuMpEngine +from llm_perf.utils.logger import logger + +class GpuInferencer(CoreInferencer): + def __init__(self, model_impl, xpu_cfg): + super().__init__() + + self.tp_size = xpu_cfg["tp_size"] + self.pad_token_id = xpu_cfg["pad_token_id"] + self.max_batch_size = xpu_cfg["max_batch_size"] + self.mp_engine = GpuMpEngine(self.tp_size, model_impl, xpu_cfg) + + def prepare_inputs( + self, + tasks: List[CoreInferencer.Task], + **kwargs + ): + input_dict = { + "input_ids": None, + "position_ids": None, + "attention_mask": None, + "all_q_len": None, + "all_kv_len": None, + "is_context": None, + "valid_slot_ids": None + } + + is_context = kwargs.get("is_context") if "is_context" in kwargs.keys() else False + valid_slot_ids = kwargs.get("valid_slot_ids") if "valid_slot_ids" in kwargs.keys() else [i for i in range(self.max_batch_size)] + + + get_input_logits = False + for task in tasks: + if task.request.generate_config.get_input_logits: + get_input_logits = True + break + + input_dict["is_context"] = is_context + input_dict["valid_slot_ids"] = valid_slot_ids + input_dict["get_input_logits"] = get_input_logits + + if is_context: + q_len = len(tasks[0].request.input_ids) + kv_len = len(tasks[0].request.input_ids) + + input_dict["input_ids"] = [ + tasks[0].request.input_ids + ] + input_dict["position_ids"] = [ + [i for i in range(q_len)] + ] + input_dict["attention_mask"] = [ + [1 for _ in range(q_len)] + ] + input_dict["all_q_len"] = [ + q_len + ] + input_dict["all_kv_len"] = [ + kv_len + ] + else: + all_input_ids = [] + all_position_ids = [] + all_attention_mask = [] + all_q_len = [] + all_kv_len = [] + + for task in tasks: + q_len = 1 + kv_len = 0 + + if task is None: + kv_len = 1 + + input_ids = [ + self.pad_token_id + ] + position_ids = [ + 0 + ] + attention_mask = [ + 0 + ] + else: + kv_len = len(task.request.input_ids) + len(task.generate_ids) - 1 + + input_ids = [ + task.generate_ids[-1] + ] + position_ids = [ + kv_len + ] + attention_mask = [ + 1 + ] + all_input_ids.append(input_ids) + all_position_ids.append(position_ids) + all_attention_mask.append(attention_mask) + all_q_len.append(q_len) + all_kv_len.append(kv_len) + + input_dict["input_ids"] = all_input_ids + input_dict["position_ids"] = all_position_ids + input_dict["attention_mask"] = all_attention_mask + input_dict["all_q_len"] = all_q_len + input_dict["all_kv_len"] = all_kv_len + + return input_dict + + + def infer( + self, + tasks: List[CoreInferencer.Task], + **kwargs + ): + input_dict = self.prepare_inputs(tasks, **kwargs) + output_dict = self.mp_engine.mp_forward(input_dict) + + logits = output_dict["logits"] + next_token_logits = logits[:, -1, :].contiguous() + infer_outputs = { + "logits": logits, + "last_logits": next_token_logits + } + return infer_outputs diff --git a/byte_infer_perf/llm_perf/backends/ROCM/gpu_mp_engine.py b/byte_infer_perf/llm_perf/backends/ROCM/gpu_mp_engine.py new file mode 100644 index 000000000..9311fe961 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/gpu_mp_engine.py @@ -0,0 +1,298 @@ +import os +import sys +import time +import signal +import pathlib +from multiprocessing import Queue +from typing import List + +import torch +import torch.nn as nn +import torch.distributed as dist + +from llm_perf.core.mp_engine import CoreMpEngine +from llm_perf.utils.logger import logger + + +# context: +# input_ids: [1, s_q] +# attention_mask = [1, s_q] +# full_attention_mask = [1, 1, s_q, s_kv] (sq == s_kv) +def get_context_masks( + input_ids : torch.Tensor, + padding_mask : torch.Tensor +): + # input_ids: [1, q_len] + # padding_mask = [1, q_len] + _, q_len = input_ids.shape + + # [1, q_len, q_len] + full_attention_mask = torch.ones( + 1, q_len, q_len, + device=input_ids.device + ) + full_attention_mask.tril_() + full_attention_mask = full_attention_mask * padding_mask.unsqueeze(1) + full_attention_mask -= padding_mask.unsqueeze(-1) - 1 + full_attention_mask = (full_attention_mask < 0.5).bool() + full_attention_mask.unsqueeze_(1) + return full_attention_mask + + +# decode +# input_ids: [bs, 1] +# attention_mask = [bs, 1] +# full_attention_mask = [bs, 1, 1, s_kv] +def get_decode_masks( + input_ids : torch.Tensor, + all_kv_len: List[int] +): + # input_ids: [batch_size, 1] + # padding_mask: [batch_size, 1 + max_kv_len] + batch_size, q_len = input_ids.shape + max_qkv_len = q_len + max(all_kv_len) + + # # [batch_size, 1, max_qkv_len] + # padding_mask = [] + # for i in range(batch_size): + # cur_qkv_len = q_len + all_kv_len[i] + # mask_per_batch = [1] * cur_qkv_len + [0] * (max_qkv_len - cur_qkv_len) + # padding_mask.append(mask_per_batch) + # full_attention_mask = torch.tensor( + # padding_mask, + # device=input_ids.device + # ).unsqueeze_(1) + # full_attention_mask = (full_attention_mask < 0.5).bool() + # full_attention_mask.unsqueeze_(1) + seq_lens = torch.tensor([1 + y for y in all_kv_len], + dtype=torch.int, device=input_ids.device) + return seq_lens + + + +# basic TP realization mp engine +# 1. main process send all inputs to all subprocesses +# 2. subprocesses process inputs with same logic simultaneously and collaboratively using TP mechanism +# 3. suppose tp = 8, rank 0-7 receive same data, +# computing each part of data, using allreduce or allgather to gather data. +# then rank 0 sends data back to main process +class GpuMpEngine(CoreMpEngine): + def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None: + super().__init__(world_size, model_impl, xpu_cfg) + + + def build_inputs(self, forward_inputs): + # list --> torch.Tensor --> cuda + forward_inputs["input_ids"] = torch.tensor( + forward_inputs["input_ids"] + ).cuda() + forward_inputs["position_ids"] = torch.tensor( + forward_inputs["position_ids"] + ).cuda() + forward_inputs["attention_mask"] = torch.tensor( + forward_inputs["attention_mask"] + ).cuda() + + is_context = forward_inputs["is_context"] + if is_context: + forward_inputs["full_attention_mask"] = get_context_masks( + forward_inputs["input_ids"], + forward_inputs["attention_mask"] + ) + else: + forward_inputs["seq_lens_tensor"] = get_decode_masks( + forward_inputs["input_ids"], + forward_inputs["all_kv_len"] + ) + return forward_inputs + + + @torch.no_grad() + def mp_loop_worker( + self, + local_rank: int, + world_size: int, + input_queue: Queue, + output_queue: Queue, + model_impl, + xpu_config + ): + try: + torch.manual_seed(1) + + # set rank and world_size + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + + # create and init model based on model_impl and xpu_config + model = model_impl(xpu_config) + if hasattr(model, 'init_inference'): + model.init_inference() + + def signal_handler(signum, frame): + logger.info(f"rank {local_rank} received signal {signum}, exiting...") + if hasattr(model, 'finalize_inference'): + model.finalize_inference() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # current rank is ready + output_queue.put("ready", block=True) + logger.info(f"{local_rank}/{world_size} rank is ready") + + # model process loop + while True: + ( + forward_inputs, + ) = input_queue.get(block=True) + + log = forward_inputs.get("log", False) + workspace = forward_inputs.get("workspace", None) + + forward_inputs["log_file"] = None + if log and workspace is not None: + workspace_dir = workspace / f"rank_{local_rank}" + workspace_dir.mkdir(exist_ok=True, parents=True) + forward_inputs["log_file"] = open(workspace_dir / "run.log", "w") + + + inputs_dict = self.build_inputs(forward_inputs) + start_time = time.perf_counter_ns() + + output_dict = model.forward(inputs_dict) + + torch.cuda.synchronize() + end_time = time.perf_counter_ns() + duration_ms = round((end_time - start_time) / 1e6, 3) + output_dict["duration_ms"] = duration_ms + + # TP realization: rank0 send result back to main process + if local_rank == 0: + output_queue.put(output_dict) + + if log and workspace is not None: + forward_inputs["log_file"].close() + + except Exception as e: + logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}") + output_queue.put(RuntimeError("[BUG] fatal exception in model subprocess")) + + + def mp_forward(self, *args): + # extra args + # workspace: pathlib.Path, where to save files for each rank + # log: bool, whether to save logs to file + # override_hidden_states: bool, whether to override hidden_states + # random_seed: int, random seed for torch.manual_seed + + # send inputs to all subprocesses + for _ in range(self.world_size): + self._input_queues.put(args, block=True) + + # wait for one subprocess send result back to main process + output_dict = self._output_queues.get(block=True) + + return output_dict + +# ROCM_HIPGRAPH modify +class GpuMpEngineWithGraph(GpuMpEngine): + def __init__(self, world_size: int, model_impl: nn.Module, xpu_cfg) -> None: + super().__init__(world_size, model_impl, xpu_cfg) + logger.info("@@@@@@@@@@ GpuMpEngineWithGraph") + + @torch.no_grad() + def mp_loop_worker( + self, + local_rank: int, + world_size: int, + input_queue: Queue, + output_queue: Queue, + model_impl, + xpu_config + ): + try: + torch.manual_seed(1) + + # set rank and world_size + os.environ["RANK"] = str(local_rank) + os.environ["LOCAL_RANK"] = str(local_rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_WORLD_SIZE"] = str(world_size) + + # create and init model based on model_impl and xpu_config + model = model_impl(xpu_config) + if hasattr(model, 'init_inference'): + model.init_inference() + + def signal_handler(signum, frame): + logger.info(f"rank {local_rank} received signal {signum}, exiting...") + if hasattr(model, 'finalize_inference'): + model.finalize_inference() + sys.exit(0) + + signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGTERM, signal_handler) + + # current rank is ready + output_queue.put("ready", block=True) + logger.info(f"{local_rank}/{world_size} rank is ready") + + graph = torch.cuda.CUDAGraph() + + # model process loop + while True: + ( + forward_inputs, + ) = input_queue.get(block=True) + + # this is the capture phase of graph + if 'capture' in forward_inputs: + graph.reset() # reset cuda graph each time + inputs_dict = self.build_inputs(forward_inputs) + + random_seed = inputs_dict.pop("random_seed", 1) + torch.manual_seed(random_seed) + # for i in range(5): + # model.forward(inputs_dict) + torch.cuda.synchronize() + with torch.cuda.graph(graph): + model.forward(inputs_dict) + torch.cuda.synchronize() + continue + + log = forward_inputs.get("log", False) + workspace = forward_inputs.get("workspace", None) + + forward_inputs["log_file"] = None + if log and workspace is not None: + workspace_dir = workspace / f"rank_{local_rank}" + workspace_dir.mkdir(exist_ok=True, parents=True) + forward_inputs["log_file"] = open(workspace_dir / "run.log", "w") + + + inputs_dict = self.build_inputs(forward_inputs) + start_time = time.perf_counter_ns() + + # output_dict = model.forward(inputs_dict) + graph.replay() + + torch.cuda.synchronize() + end_time = time.perf_counter_ns() + duration_ms = round((end_time - start_time) / 1e6, 3) + output_dict = dict() + output_dict["duration_ms"] = duration_ms + + # TP realization: rank0 send result back to main process + if local_rank == 0: + output_queue.put(output_dict) + + if log and workspace is not None: + forward_inputs["log_file"].close() + + except Exception as e: + logger.exception(f"[BUG] engine _load_and_listen failed, no more requests will be handled. {e}") + output_queue.put(RuntimeError("[BUG] fatal exception in model subprocess")) \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/__init__.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/__init__.py new file mode 100644 index 000000000..94342ee5f --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/__init__.py @@ -0,0 +1,26 @@ +## __all__ is a dict: +## key is model_name in `model_zoo/chatglm-xx.json` +## value is vendor specify model impl +# __all__ = { +# "chatglm" : ChatGLMForConditionalGeneration, +# "chatglm2" : ChatGLM2ForConditionalGeneration +# } + +from typing import Dict, Tuple, Any + +import torch +import torch.nn as nn + +# from .gpu_chatglm2 import GPUChatGLM2 +# from .gpu_llama3 import GPULlama +# from .gpu_falcon import GPUFalcon +from .rocm_mixtral import GPUMixtral + +from llm_perf.utils.logger import logger + +__all__ = { + # "chatglm2": GPUChatGLM2, + # "llama3": GPULlama, + # "falcon": GPUFalcon, + "mixtral": GPUMixtral +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py new file mode 100644 index 000000000..7dc4fe29d --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py @@ -0,0 +1,1472 @@ +# coding=utf-8 +# Copyright 2023 Mistral AI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" PyTorch Mixtral model.""" +import os +import inspect +import math +import warnings +from typing import List, Optional, Tuple, Union + +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch import nn +from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss +import torch.distributed as dist + +from transformers.activations import ACT2FN +from transformers.cache_utils import Cache, DynamicCache +from transformers.modeling_attn_mask_utils import ( + _prepare_4d_causal_attention_mask, + _prepare_4d_causal_attention_mask_for_sdpa, +) +from transformers.modeling_outputs import ( + MoeCausalLMOutputWithPast, + MoeModelOutputWithPast, + SequenceClassifierOutputWithPast, +) +from transformers.modeling_utils import PreTrainedModel +from transformers.pytorch_utils import is_torch_greater_or_equal_than_1_13 +from transformers.utils import ( + add_start_docstrings, + add_start_docstrings_to_model_forward, + is_flash_attn_2_available, + is_flash_attn_greater_or_equal_2_10, + logging, + replace_return_docstrings, +) +from transformers.utils.import_utils import is_torch_fx_available +from transformers.models.mixtral.configuration_mixtral import MixtralConfig + +from ..rocm_kernels.fused_moe import fused_moe +from ..rocm_kernels import paged_attn +from ..rocm_kernels.paged_attn import PagedAttention +import rocmKernels as ops +from ..rocm_kernels.tuned_gemm import tgemm +from ..rocm_kernels.rotary_embedding import get_rope +from ..rocm_kernels.dist.communication_op import tensor_model_parallel_all_reduce + +if is_flash_attn_2_available(): + from flash_attn import flash_attn_func, flash_attn_varlen_func + from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa + + _flash_supports_window_size = "window_size" in list(inspect.signature(flash_attn_func).parameters) + +# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph. +# It means that the function will not be traced through and simply appear as a node in the graph. +if is_torch_fx_available(): + if not is_torch_greater_or_equal_than_1_13: + import torch.fx + + _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask) + + +logger = logging.get_logger(__name__) + +_CONFIG_FOR_DOC = "MixtralConfig" + + +def load_balancing_loss_func( + gate_logits: torch.Tensor, num_experts: torch.Tensor = None, top_k=2, attention_mask: Optional[torch.Tensor] = None +) -> float: + r""" + Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch. + + See Switch Transformer (https://arxiv.org/abs/2101.03961) for more details. This function implements the loss + function presented in equations (4) - (6) of the paper. It aims at penalizing cases where the routing between + experts is too unbalanced. + + Args: + gate_logits (Union[`torch.Tensor`, Tuple[torch.Tensor]): + Logits from the `gate`, should be a tuple of model.config.num_hidden_layers tensors of + shape [batch_size X sequence_length, num_experts]. + attention_mask (`torch.Tensor`, None): + The attention_mask used in forward function + shape [batch_size X sequence_length] if not None. + num_experts (`int`, *optional*): + Number of experts + + Returns: + The auxiliary loss. + """ + if gate_logits is None or not isinstance(gate_logits, tuple): + return 0 + + if isinstance(gate_logits, tuple): + compute_device = gate_logits[0].device + concatenated_gate_logits = torch.cat([layer_gate.to(compute_device) for layer_gate in gate_logits], dim=0) + + routing_weights = torch.nn.functional.softmax(concatenated_gate_logits, dim=-1) + + _, selected_experts = torch.topk(routing_weights, top_k, dim=-1) + + expert_mask = torch.nn.functional.one_hot(selected_experts, num_experts) + + if attention_mask is None: + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.mean(expert_mask.float(), dim=0) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.mean(routing_weights, dim=0) + else: + batch_size, sequence_length = attention_mask.shape + num_hidden_layers = concatenated_gate_logits.shape[0] // (batch_size * sequence_length) + + # Compute the mask that masks all padding tokens as 0 with the same shape of expert_mask + expert_attention_mask = ( + attention_mask[None, :, :, None, None] + .expand((num_hidden_layers, batch_size, sequence_length, top_k, num_experts)) + .reshape(-1, top_k, num_experts) + .to(compute_device) + ) + + # Compute the percentage of tokens routed to each experts + tokens_per_expert = torch.sum(expert_mask.float() * expert_attention_mask, dim=0) / torch.sum( + expert_attention_mask, dim=0 + ) + + # Compute the mask that masks all padding tokens as 0 with the same shape of tokens_per_expert + router_per_expert_attention_mask = ( + attention_mask[None, :, :, None] + .expand((num_hidden_layers, batch_size, sequence_length, num_experts)) + .reshape(-1, num_experts) + .to(compute_device) + ) + + # Compute the average probability of routing to these experts + router_prob_per_expert = torch.sum(routing_weights * router_per_expert_attention_mask, dim=0) / torch.sum( + router_per_expert_attention_mask, dim=0 + ) + + overall_loss = torch.sum(tokens_per_expert * router_prob_per_expert.unsqueeze(0)) + return overall_loss * num_experts + + +# Copied from transformers.models.llama.modeling_llama._get_unpad_data +def _get_unpad_data(attention_mask): + seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) + indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() + max_seqlen_in_batch = seqlens_in_batch.max().item() + cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) + return ( + indices, + cu_seqlens, + max_seqlen_in_batch, + ) + + +# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Mixtral +class MixtralRMSNorm(nn.Module): + def __init__(self, hidden_size, eps=1e-6): + """ + MixtralRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.variance_epsilon = eps + + def forward(self, hidden_states, residual=None): + # # ori implementation + # input_dtype = hidden_states.dtype + # if residual is not None: + # hidden_states = hidden_states + residual + # residual = hidden_states + # hidden_states = hidden_states.to(torch.float32) + # variance = hidden_states.pow(2).mean(-1, keepdim=True) + # hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon) + # hidden_states = self.weight * hidden_states.to(input_dtype) + # if residual is None: + # return hidden_states + # else: + # return hidden_states, residual + + # optimized implementation + if residual is not None: + ops.fused_add_rms_norm( + hidden_states, + residual, + self.weight.data, + self.variance_epsilon, + ) + return hidden_states, residual + out = torch.empty_like(hidden_states) + ops.rms_norm( + out, + hidden_states, + self.weight.data, + self.variance_epsilon, + ) + return out + + + + +# Copied from transformers.models.mistral.modeling_mistral.MistralRotaryEmbedding with Mistral->Mixtral +class MixtralRotaryEmbedding(nn.Module): + def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None): + super().__init__() + + self.dim = dim + self.max_position_embeddings = max_position_embeddings + self.base = base + inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim)) + self.register_buffer("inv_freq", inv_freq, persistent=False) + + # Build here to make `torch.jit.trace` work. + self._set_cos_sin_cache( + seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype() + ) + + def _set_cos_sin_cache(self, seq_len, device, dtype): + self.max_seq_len_cached = seq_len + t = torch.arange(self.max_seq_len_cached, device=device, dtype=torch.int64).type_as(self.inv_freq) + + freqs = torch.outer(t, self.inv_freq) + # Different from paper, but it uses a different permutation in order to obtain the same calculation + emb = torch.cat((freqs, freqs), dim=-1) + self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False) + self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False) + + def forward(self, x, seq_len=None): + # x: [bs, num_attention_heads, seq_len, head_size] + if seq_len > self.max_seq_len_cached: + self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype) + + return ( + self.cos_cached[:seq_len].to(dtype=x.dtype), + self.sin_cached[:seq_len].to(dtype=x.dtype), + ) + + +# Copied from transformers.models.llama.modeling_llama.rotate_half +def rotate_half(x): + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +# Copied from transformers.models.mistral.modeling_mistral.apply_rotary_pos_emb +def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1): + """Applies Rotary Position Embedding to the query and key tensors. + + Args: + q (`torch.Tensor`): The query tensor. + k (`torch.Tensor`): The key tensor. + cos (`torch.Tensor`): The cosine part of the rotary embedding. + sin (`torch.Tensor`): The sine part of the rotary embedding. + position_ids (`torch.Tensor`): + The position indices of the tokens corresponding to the query and key tensors. For example, this can be + used to pass offsetted position ids when working with a KV-cache. + unsqueeze_dim (`int`, *optional*, defaults to 1): + The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and + sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note + that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and + k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes + cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have + the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. + Returns: + `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding. + """ + cos = cos[position_ids].unsqueeze(unsqueeze_dim) + sin = sin[position_ids].unsqueeze(unsqueeze_dim) + q_embed = (q * cos) + (rotate_half(q) * sin) + k_embed = (k * cos) + (rotate_half(k) * sin) + return q_embed, k_embed + + +# Copied from transformers.models.llama.modeling_llama.repeat_kv +def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: + """ + This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, + num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) + """ + batch, num_key_value_heads, slen, head_dim = hidden_states.shape + if n_rep == 1: + return hidden_states + hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim) + return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) + + +class Linear(nn.Module): + def __init__(self, in_features, out_features, bias=True, dtype=None): + super().__init__() + self.in_features = in_features + self.out_features = out_features + self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype)) + if bias: + self.bias = nn.Parameter(torch.empty(out_features, dtype=dtype)) + else: + self.register_parameter("bias", None) + self.reset_parameters() + + def reset_parameters(self): + nn.init.xavier_uniform_(self.weight) + if self.bias is not None: + nn.init.zeros_(self.bias) + + def forward(self, x): + # return F.linear(x, self.weight, self.bias) + return tgemm.mm(x, self.weight, self.bias) + +# Copied from transformers.models.mistral.modeling_mistral.MistralAttention with Mistral->Mixtral +class MixtralAttention(nn.Module): + """ + Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer + and "Generating Long Sequences with Sparse Transformers". + """ + + def __init__(self, config: MixtralConfig, layer_idx: Optional[int] = None): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.config = config + self.layer_idx = layer_idx + if layer_idx is None: + logger.warning_once( + f"Instantiating {self.__class__.__name__} without passing a `layer_idx` is not recommended and will " + "lead to errors during the forward call if caching is used. Please make sure to provide a `layer_idx` " + "when creating this class." + ) + + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.num_key_value_heads = config.num_key_value_heads + self.num_key_value_groups = self.num_heads // self.num_key_value_heads + self.max_position_embeddings = config.max_position_embeddings + self.rope_theta = config.rope_theta + self.is_causal = True + self.attention_dropout = config.attention_dropout + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}" + f" and `num_heads`: {self.num_heads})." + ) + self.q_proj = Linear(self.hidden_size, self.num_heads * self.head_dim // self.mp_size, bias=False) + self.k_proj = Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=False) + self.v_proj = Linear(self.hidden_size, self.num_key_value_heads * self.head_dim // self.mp_size, bias=False) + self.o_proj = Linear(self.num_heads * self.head_dim // self.mp_size, self.hidden_size, bias=False) + + self.rotary_emb = MixtralRotaryEmbedding( + self.head_dim, + max_position_embeddings=self.max_position_embeddings, + base=self.rope_theta, + ) + self.rotary_emb_fused = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=self.max_position_embeddings, + base=int(self.rope_theta), + is_neox_style=True, + ) + + self.num_heads = self.num_heads // self.mp_size + self.num_key_value_heads = self.num_key_value_heads // self.mp_size + + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len): + raise ValueError( + f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is" + f" {attn_weights.size()}" + ) + + if attention_mask is not None: + if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + raise ValueError( + f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + ) + + attn_weights = attn_weights + attention_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) + + if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): + raise ValueError( + f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" + f" {attn_output.size()}" + ) + + attn_output = attn_output.transpose(1, 2).contiguous() + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Copied from transformers.models.mistral.modeling_mistral.MistralFlashAttention2 with Mistral->Mixtral +class MixtralFlashAttention2(MixtralAttention): + """ + Mixtral flash attention module. This module inherits from `MixtralAttention` as the weights of the module stays + untouched. The only required change would be on the forward pass where it needs to correctly call the public API of + flash attention and deal with padding tokens in case the input contains any of them. + """ + + # Copied from transformers.models.llama.modeling_llama.LlamaFlashAttention2.__init__ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + # TODO: Should be removed once Flash Attention for RoCm is bumped to 2.1. + # flash_attn<2.1 generates top-left aligned causal mask, while what is needed here is bottom-right alignement, that was made default for flash_attn>=2.1. This attribute is used to handle this difference. Reference: https://github.com/Dao-AILab/flash-attention/releases/tag/v2.1.0. + # Beware that with flash_attn<2.1, using q_seqlen != k_seqlen (except for the case q_seqlen == 1) produces a wrong mask (top-left). + self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs, + ): + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + + # overwrite attention_mask with padding_mask + attention_mask = kwargs.pop("padding_mask") + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + kv_seq_len = key_states.shape[-2] + if past_key_value is not None: + if self.layer_idx is None: + raise ValueError( + f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} " + "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class " + "with a layer index." + ) + kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx) + + # Because the input can be padded, the absolute sequence length depends on the max position id. + rotary_seq_len = max(kv_seq_len, position_ids[:, -1].max().item()) + 1 + cos, sin = self.rotary_emb(value_states, seq_len=rotary_seq_len) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + use_sliding_windows = ( + _flash_supports_window_size + and getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + ) + + if not _flash_supports_window_size: + logger.warning_once( + "The current flash attention version does not support sliding window attention, for a more memory efficient implementation" + " make sure to upgrade flash-attn library." + ) + + if past_key_value is not None: + # Activate slicing cache only if the config has a value `sliding_windows` attribute + cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0 + if ( + getattr(self.config, "sliding_window", None) is not None + and kv_seq_len > self.config.sliding_window + and cache_has_contents + ): + slicing_tokens = 1 - self.config.sliding_window + + past_key = past_key_value[self.layer_idx][0] + past_value = past_key_value[self.layer_idx][1] + + past_key = past_key[:, :, slicing_tokens:, :].contiguous() + past_value = past_value[:, :, slicing_tokens:, :].contiguous() + + if past_key.shape[-2] != self.config.sliding_window - 1: + raise ValueError( + f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got" + f" {past_key.shape}" + ) + + if attention_mask is not None: + attention_mask = attention_mask[:, slicing_tokens:] + attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1) + + cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + # repeat k/v heads if n_kv_heads < n_heads + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + dropout_rate = 0.0 if not self.training else self.attention_dropout + + # In PEFT, usually we cast the layer norms in float32 for training stability reasons + # therefore the input hidden states gets silently casted in float32. Hence, we need + # cast them back in float16 just to be sure everything works as expected. + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + # Handle the case where the model is quantized + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = self.q_proj.weight.dtype + + logger.warning_once( + f"The input hidden states seems to be silently casted in float32, this might be related to" + f" the fact you have upcasted embedding or layer norm layers in float32. We will cast back the input in" + f" {target_dtype}." + ) + + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + # Reashape to the expected shape for Flash Attention + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + attn_output = self._flash_attention_forward( + query_states, + key_states, + value_states, + attention_mask, + q_len, + dropout=dropout_rate, + use_sliding_windows=use_sliding_windows, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + def _flash_attention_forward( + self, + query_states, + key_states, + value_states, + attention_mask, + query_length, + dropout=0.0, + softmax_scale=None, + use_sliding_windows=False, + ): + """ + Calls the forward method of Flash Attention - if the input hidden states contain at least one padding token + first unpad the input, then computes the attention scores and pad the final attention scores. + + Args: + query_states (`torch.Tensor`): + Input query states to be passed to Flash Attention API + key_states (`torch.Tensor`): + Input key states to be passed to Flash Attention API + value_states (`torch.Tensor`): + Input value states to be passed to Flash Attention API + attention_mask (`torch.Tensor`): + The padding mask - corresponds to a tensor of size `(batch_size, seq_len)` where 0 stands for the + position of padding tokens and 1 for the position of non-padding tokens. + dropout (`float`): + Attention dropout + softmax_scale (`float`, *optional*): + The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim) + use_sliding_windows (`bool`, *optional*): + Whether to activate sliding window attention. + """ + if not self._flash_attn_uses_top_left_mask: + causal = self.is_causal + else: + # TODO: Remove the `query_length != 1` check once Flash Attention for RoCm is bumped to 2.1. For details, please see the comment in LlamaFlashAttention2 __init__. + causal = self.is_causal and query_length != 1 + + # Contains at least one padding token in the sequence + if attention_mask is not None: + batch_size = query_states.shape[0] + query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input( + query_states, key_states, value_states, attention_mask, query_length + ) + + cu_seqlens_q, cu_seqlens_k = cu_seq_lens + max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens + + if not use_sliding_windows: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output_unpad = flash_attn_varlen_func( + query_states, + key_states, + value_states, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, + max_seqlen_q=max_seqlen_in_batch_q, + max_seqlen_k=max_seqlen_in_batch_k, + dropout_p=dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + attn_output = pad_input(attn_output_unpad, indices_q, batch_size, query_length) + else: + if not use_sliding_windows: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + ) + else: + attn_output = flash_attn_func( + query_states, + key_states, + value_states, + dropout, + softmax_scale=softmax_scale, + causal=causal, + window_size=(self.config.sliding_window, self.config.sliding_window), + ) + + return attn_output + + def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length): + batch_size, kv_seq_len, num_heads, head_dim = key_layer.shape + + # On the first iteration we need to properly re-create the padding mask + # by slicing it on the proper place + if kv_seq_len != attention_mask.shape[-1]: + attention_mask_num_tokens = attention_mask.shape[-1] + attention_mask = attention_mask[:, attention_mask_num_tokens - kv_seq_len :] + + indices_k, cu_seqlens_k, max_seqlen_in_batch_k = _get_unpad_data(attention_mask) + + key_layer = index_first_axis(key_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + value_layer = index_first_axis(value_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k) + + if query_length == kv_seq_len: + query_layer = index_first_axis( + query_layer.reshape(batch_size * kv_seq_len, num_heads, head_dim), indices_k + ) + cu_seqlens_q = cu_seqlens_k + max_seqlen_in_batch_q = max_seqlen_in_batch_k + indices_q = indices_k + elif query_length == 1: + max_seqlen_in_batch_q = 1 + cu_seqlens_q = torch.arange( + batch_size + 1, dtype=torch.int32, device=query_layer.device + ) # There is a memcpy here, that is very bad. + indices_q = cu_seqlens_q[:-1] + query_layer = query_layer.squeeze(1) + else: + # The -q_len: slice assumes left padding. + attention_mask = attention_mask[:, -query_length:] + query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask) + + return ( + query_layer, + key_layer, + value_layer, + indices_q, + (cu_seqlens_q, cu_seqlens_k), + (max_seqlen_in_batch_q, max_seqlen_in_batch_k), + ) + + +# Copied from transformers.models.mistral.modeling_mistral.MistralSdpaAttention with Mistral->Mixtral +class MixtralSdpaAttention(MixtralAttention): + """ + Mixtral attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from + `MixtralAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to + SDPA API. + """ + # Adapted from MixtralAttention.forward + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Cache] = None, + output_attentions: bool = False, + use_cache: bool = False, + **kwargs + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented. + logger.warning_once( + "MixtralModel is using MixtralSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, " + 'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.' + ) + return super().forward( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states = self.q_proj(hidden_states) + key_states = self.k_proj(hidden_states) + value_states = self.v_proj(hidden_states) + + + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + all_q_len = kwargs.get("all_q_len") + all_kv_len = kwargs.get("all_kv_len") + max_kv_len = max(all_kv_len) if is_context else max(all_q_len) + max(all_kv_len) + + # fused rope + query_states, key_states = self.rotary_emb_fused(position_ids, query_states, key_states) + # # old rope + # cos, sin = self.rotary_emb(value_states, seq_len=max_kv_len) + # query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids) + + block_tables, kv_caches = past_key_value + slot_mapping = kwargs.get("slot_mapping") + key_cache, value_cache = kv_caches[self.layer_idx] + key_states_cache = key_states.view(-1, self.num_key_value_heads, self.head_dim).contiguous() + value_states_cache = value_states.view(-1, self.num_key_value_heads, self.head_dim).contiguous() + PagedAttention.write_to_paged_cache(key_states_cache, + value_states_cache, + key_cache, + value_cache, + slot_mapping.view(-1), + "auto", 1.0, 1.0) + + if is_context: + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # if attention_mask is not None: + # if attention_mask.size() != (bsz, 1, q_len, kv_seq_len): + # raise ValueError( + # f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}" + # ) + + + if is_context: + attention_mask = kwargs.get("full_attention_mask") + + # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask, + # Reference: https://github.com/pytorch/pytorch/issues/112577. + if query_states.device.type == "cuda" and attention_mask is not None: + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=~attention_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + ).transpose(1, 2).contiguous() + else: + query_states = query_states.view(-1, self.num_heads, self.head_dim) + seq_lens = kwargs.get("seq_lens_tensor") + attn_output = PagedAttention.forward_decode( + query_states, + key_cache, + value_cache, + block_tables[0:bsz], + seq_lens, + int(self.max_position_embeddings), + "auto", + self.num_key_value_heads, + float(1.0 / (self.head_dim**0.5)), + None, + 1.0, + 1.0, + ).contiguous() + + attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +MIXTRAL_ATTENTION_CLASSES = { + "eager": MixtralAttention, + "flash_attention_2": MixtralFlashAttention2, + "sdpa": MixtralSdpaAttention, +} + + +class MixtralBlockSparseTop2MLP(nn.Module): + def __init__(self, config: MixtralConfig): + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.ffn_dim = config.intermediate_size + self.hidden_dim = config.hidden_size + + self.w1 = Linear(self.hidden_dim, self.ffn_dim // self.mp_size, bias=False) + self.w2 = Linear(self.ffn_dim // self.mp_size, self.hidden_dim, bias=False) + self.w3 = Linear(self.hidden_dim, self.ffn_dim // self.mp_size, bias=False) + + self.act_fn = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + current_hidden_states = self.act_fn(self.w1(hidden_states)) * self.w3(hidden_states) + current_hidden_states = self.w2(current_hidden_states) + return current_hidden_states + + +class MixtralBLockSparseTop2MLP(MixtralBlockSparseTop2MLP): + def __init__(self, *args, **kwargs): + logger.warning_once( + "MixtralBLockSparseTop2MLP is deprecated by MixtralBlockSparseTop2MLP and will be removed in v4.40." + ) + super().__init__(*args, **kwargs) + + +class MixtralSparseMoeBlock(nn.Module): + """ + This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__(self, config): + super().__init__() + self.hidden_dim = config.hidden_size + self.ffn_dim = config.intermediate_size + self.num_experts = config.num_local_experts + self.top_k = config.num_experts_per_tok + + # gating + self.gate = Linear(self.hidden_dim, self.num_experts, bias=False) + + # self.experts = nn.ModuleList([MixtralBlockSparseTop2MLP(config) for _ in range(self.num_experts)]) + + # Jitter parameters + self.jitter_noise = config.router_jitter_noise + + def forward( + self, + hidden_states: torch.Tensor, + **kwargs + ) -> torch.Tensor: + """ """ + batch_size, sequence_length, hidden_dim = hidden_states.shape + if self.training and self.jitter_noise > 0: + hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) + hidden_states = hidden_states.view(-1, hidden_dim) + + # router_logits: (batch * sequence_length, n_experts) + router_logits = self.gate(hidden_states) + log_file = kwargs.get("log_file", None) + + final_hidden_states = fused_moe(hidden_states, + self.w13_weight, + self.w2_weight, + router_logits, + self.top_k, + renormalize=True, + inplace=True) + # if log_file is not None: + # print(f"num_enabled_experts={non_zero_num}, tokens_distribution={tokens_list}", file=log_file, flush=True) + + final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim) + # if int(os.environ.get("LOCAL_RANK", "0"))==0: + # print(f'{final_hidden_states=}') + return final_hidden_states, router_logits + + +class MixtralDecoderLayer(nn.Module): + def __init__(self, config: MixtralConfig, layer_idx: int): + self.layer_idx=layer_idx + # dist info + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__() + self.hidden_size = config.hidden_size + + self.self_attn = MIXTRAL_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx) + + self.block_sparse_moe = MixtralSparseMoeBlock(config) + self.input_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + output_router_logits: Optional[bool] = False, + use_cache: Optional[bool] = False, + **kwargs, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + if "padding_mask" in kwargs: + warnings.warn( + "Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`" + ) + """ + Args: + hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)` + attention_mask (`torch.FloatTensor`, *optional*): attention mask of size + `(batch, sequence_length)` where padding elements are indicated by 0. + past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under + returned tensors for more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding + (see `past_key_values`). + """ + + # residual = hidden_states + # hidden_states = self.input_layernorm(hidden_states) + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm( + hidden_states, residual) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=False, + use_cache=False, + **kwargs, + ) + if self.mp_size > 1: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + # dist.all_reduce(hidden_states) + + + # Fully Connected + # hidden_states = residual + hidden_states + # residual = hidden_states + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states, router_logits = self.block_sparse_moe(hidden_states, **kwargs) + if (os.environ.get("LOCAL_RANK", "0"), self.layer_idx) == (): + pass + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{self.layer_idx} {hidden_states.shape=}') + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{self.layer_idx} {hidden_states=}') + if self.mp_size > 1: + hidden_states = tensor_model_parallel_all_reduce(hidden_states) + # dist.all_reduce(hidden_states) + + # hidden_states = residual + hidden_states + # outputs = (hidden_states,) + outputs = (hidden_states, residual) + + return outputs + + +MIXTRAL_START_DOCSTRING = r""" + This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the + library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads + etc.) + + This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. + Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage + and behavior. + + Parameters: + config ([`MixtralConfig`]): + Model configuration class with all the parameters of the model. Initializing with a config file does not + load the weights associated with the model, only the configuration. Check out the + [`~PreTrainedModel.from_pretrained`] method to load the model weights. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralPreTrainedModel with Mistral->Mixtral +class MixtralPreTrainedModel(PreTrainedModel): + config_class = MixtralConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["MixtralDecoderLayer"] + _skip_keys_device_placement = "past_key_values" + _supports_flash_attn_2 = True + _supports_sdpa = True + _supports_cache_class = True + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + +MIXTRAL_INPUTS_DOCSTRING = r""" + Args: + input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`): + Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide + it. + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + [What are input IDs?](../glossary#input-ids) + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + + [What are attention masks?](../glossary#attention-mask) + + Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and + [`PreTrainedTokenizer.__call__`] for details. + + If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see + `past_key_values`). + + If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`] + and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more + information on the default strategy. + + - 1 indicates the head is **not masked**, + - 0 indicates the head is **masked**. + position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0, + config.n_positions - 1]`. + + [What are position IDs?](../glossary#position-ids) + past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape + `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape + `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`. + + Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention + blocks) that can be used (see `past_key_values` input) to speed up sequential decoding. + + If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that + don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all + `decoder_input_ids` of shape `(batch_size, sequence_length)`. + inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This + is useful if you want more control over how to convert `input_ids` indices into associated vectors than the + model's internal embedding lookup matrix. + use_cache (`bool`, *optional*): + If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see + `past_key_values`). + output_attentions (`bool`, *optional*): + Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned + tensors for more detail. + output_hidden_states (`bool`, *optional*): + Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for + more detail. + output_router_logits (`bool`, *optional*): + Whether or not to return the logits of all the routers. They are useful for computing the router loss, and + should not be returned during inference. + return_dict (`bool`, *optional*): + Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple. +""" + + +@add_start_docstrings( + "The bare Mixtral Model outputting raw hidden-states without any specific head on top.", + MIXTRAL_START_DOCSTRING, +) +# Copied from transformers.models.mistral.modeling_mistral.MistralModel with MISTRAL->MIXTRAL,Mistral->Mixtral +class MixtralModel(MixtralPreTrainedModel): + """ + Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`MixtralDecoderLayer`] + + Args: + config: MixtralConfig + """ + + def __init__(self, config: MixtralConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList( + [MixtralDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + self._attn_implementation = config._attn_implementation + self.norm = MixtralRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + self.gradient_checkpointing = False + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + # Ignore copy + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MoeModelOutputWithPast]: + residual = None + bsz = input_ids.shape[0] + is_context = kwargs.get("is_context") + valid_slot_ids = kwargs.get("valid_slot_ids") + batch_offset = kwargs.get("cache_batch_offset") + if is_context: + slot_offset = torch.tensor([valid_slot_ids[0] * batch_offset], + device = position_ids.device, + dtype = position_ids.dtype).unsqueeze(1) + else: + slot_offset = torch.arange(0, bsz * batch_offset, batch_offset, + device = position_ids.device, + dtype = position_ids.dtype).unsqueeze(1) + kwargs["slot_mapping"] = position_ids + slot_offset + if kwargs.pop("override_hidden_states", False): + random_seed = kwargs.pop("random_seed", None) + layer_index = kwargs.pop("fixed_layer_index", -1) + layer_index = layer_index % len(self.layers) + + # create random input ids on cpu and copy to device + if random_seed is not None: + # RuntimeError: Cannot call CUDAGeneratorImpl::set_current_seed during CUDA graph capture. + torch.manual_seed(random_seed) + random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device) + + hidden_states = self.embed_tokens(random_input_ids) + + for _ in self.layers: + layer_outputs, residual = self.layers[layer_index]( + hidden_states, + residual=residual, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + output_router_logits=False, + use_cache=False, + **kwargs, + ) + else: + hidden_states = self.embed_tokens(input_ids) + for decoder_layer in self.layers: + hidden_states, residual = decoder_layer( + hidden_states, + residual=residual, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_values, + output_attentions=False, + output_router_logits=False, + use_cache=False, + **kwargs, + ) + + hidden_states, _ = self.norm(hidden_states, residual) + + return MoeModelOutputWithPast( + last_hidden_state=hidden_states + ) + + +class MixtralForCausalLM(MixtralPreTrainedModel): + _tied_weights_keys = ["lm_head.weight"] + + def __init__(self, config): + super().__init__(config) + + self.model = MixtralModel(config) + self.vocab_size = config.vocab_size + self.lm_head = Linear(config.hidden_size, config.vocab_size, bias=False) + self.router_aux_loss_coef = config.router_aux_loss_coef + self.num_experts = config.num_local_experts + self.num_experts_per_tok = config.num_experts_per_tok + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + @add_start_docstrings_to_model_forward(MIXTRAL_INPUTS_DOCSTRING) + @replace_return_docstrings(output_type=MoeCausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC) + # Ignore copy + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_router_logits: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs, + ) -> Union[Tuple, MoeCausalLMOutputWithPast]: + r""" + Args: + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Returns: + + Example: + + ```python + >>> from transformers import AutoTokenizer, MixtralForCausalLM + + >>> model = MixtralForCausalLM.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Mixtral-8x7B-v0.1") + + >>> prompt = "Hey, are you conscious? Can you talk to me?" + >>> inputs = tokenizer(prompt, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(inputs.input_ids, max_length=30) + >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." + ```""" + # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=None, + use_cache=False, + output_attentions=False, + output_hidden_states=False, + output_router_logits=False, + return_dict=True, + **kwargs, + ) + + # print(f'{os.environ.get("LOCAL_RANK", "0")} {outputs=}') + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + logits = logits.float() + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states.shape=}') + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states=}') + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{logits.shape=}') + # print(f'{os.environ.get("LOCAL_RANK", "0")}:{logits=}') + return MoeCausalLMOutputWithPast( + logits=logits + ) + + def prepare_inputs_for_generation( + self, + input_ids, + past_key_values=None, + attention_mask=None, + inputs_embeds=None, + output_router_logits=False, + **kwargs, + ): + # Omit tokens covered by past_key_values + if past_key_values is not None: + if isinstance(past_key_values, Cache): + cache_length = past_key_values.get_seq_length() + past_length = past_key_values.seen_tokens + max_cache_length = past_key_values.get_max_length() + else: + cache_length = past_length = past_key_values[0][0].shape[2] + max_cache_length = None + + # Keep only the unprocessed tokens: + # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where + # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as + # input) + if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]: + input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :] + # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard + # input_ids based on the past_length. + elif past_length < input_ids.shape[1]: + input_ids = input_ids[:, past_length:] + # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens. + + # If we are about to go beyond the maximum cache length, we need to crop the input attention mask. + if ( + max_cache_length is not None + and attention_mask is not None + and cache_length + input_ids.shape[1] > max_cache_length + ): + attention_mask = attention_mask[:, -max_cache_length:] + + position_ids = kwargs.get("position_ids", None) + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -input_ids.shape[1] :] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "position_ids": position_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + "output_router_logits": output_router_logits, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + reordered_past = () + for layer_past in past_key_values: + reordered_past += ( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past), + ) + return reordered_past + diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py new file mode 100644 index 000000000..48ad43bfb --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py @@ -0,0 +1,247 @@ +import os +import pathlib + +import torch +import torch.nn as nn +import torch.distributed as dist +import torch.nn.functional as F + +from typing import Dict, Any, List +from llm_perf.utils.logger import logger, setup_logger +from llm_perf.utils.ps_utils import check_memory_usage +from llm_perf.utils.dist_utils import check_dist + +from accelerate import init_empty_weights + +from llm_perf.backends.GPU.gpu_ckpt_loader import GpuCkptLoader +from llm_perf.core.ckpt_loader import Mixtral_ModelLoader +from transformers import MixtralConfig +from .modeling_mixtral import MixtralForCausalLM +from ..rocm_kernels.dist.parallel_state import (ensure_model_parallel_initialized, + init_distributed_environment, + set_custom_all_reduce, + destroy_model_parallel, + destroy_distributed_environment) +from ..rocm_kernels.dist.utils import (get_open_port, + get_distributed_init_method, + get_ip) +# setup_logger('info') + +class GPUMixtralLoader(GpuCkptLoader): + def __init__( + self, + model : MixtralForCausalLM, + model_config : MixtralConfig, + ckpt_path : str = "" + ): + mp_size = int(os.environ.get("WORLD_SIZE", "1")) + local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + super().__init__("", model, mp_size, local_rank, ckpt_path) + self.model_config = model_config + + def parallel_loader(self): + self.state_dict = {} + + model_dir = pathlib.Path(self.ckpt_path).absolute() + if not model_dir.exists() or not model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{model_dir} not exists or is not a directory") + return + + split_model_dir = model_dir.joinpath(f"TP{self.mp_size}") + if not split_model_dir.exists() or not split_model_dir.is_dir(): + if self.mp_rank == 0: + print(f"{split_model_dir} not exists or is not a directory, please split model first.") + return + + model_loader = Mixtral_ModelLoader(split_model_dir / f"device_{self.mp_rank}") + self.state_dict = model_loader.load_weight() + + def infusion_to_model(self): + self.model.model.embed_tokens.weight = self.to_parameter(self.state_dict["model.embed_tokens.weight"]) + for i in range(self.model_config.num_hidden_layers): + self.model.model.layers[i].input_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.input_layernorm.weight"]) + + self.model.model.layers[i].self_attn.q_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.q_proj.weight"]) + self.model.model.layers[i].self_attn.k_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.k_proj.weight"]) + self.model.model.layers[i].self_attn.v_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.v_proj.weight"]) + self.model.model.layers[i].self_attn.o_proj.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.self_attn.o_proj.weight"]) + + self.model.model.layers[i].post_attention_layernorm.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + + self.model.model.layers[i].block_sparse_moe.gate.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.gate.weight"]) + + tmpW=self.state_dict[f"model.layers.{0}.block_sparse_moe.experts.{0}.w1.weight"] + ffn_dim=tmpW.shape[0] + hidden_dim=tmpW.shape[1] + w13_weight = torch.empty(self.model_config.num_local_experts, + 2, ffn_dim, + hidden_dim, + dtype=tmpW.dtype) + w2_weight = torch.empty(self.model_config.num_local_experts, + hidden_dim, + ffn_dim, + dtype=tmpW.dtype) + for j in range(self.model_config.num_local_experts): + # self.model.model.layers[i].block_sparse_moe.experts[j].w1.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"]) + # self.model.model.layers[i].block_sparse_moe.experts[j].w2.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"]) + # self.model.model.layers[i].block_sparse_moe.experts[j].w3.weight = self.to_parameter(self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"]) + w13_weight[j, 0, :] = self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"] + w13_weight[j, 1, :] = self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"] + + w2_weight[j, :] = self.state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"] + w13_weight = w13_weight.view(self.model_config.num_local_experts, 2*ffn_dim, hidden_dim) + if bool(int(os.getenv("ENABLE_MOE_LDS_BYPASS", "1"))): + w13_weight = permute_weight(w13_weight) + w2_weight = permute_weight(w2_weight) + if bool(int(os.getenv("VLLM_MOE_PADDING", "1"))): + w13_weight = F.pad(w13_weight, (0, 128), "constant", 0) + torch.cuda.empty_cache() + w2_weight = F.pad(w2_weight, (0, 128), "constant", 0) + torch.cuda.empty_cache() + self.model.model.layers[i].block_sparse_moe.w13_weight = self.to_parameter(w13_weight) + self.model.model.layers[i].block_sparse_moe.w2_weight = self.to_parameter(w2_weight) + + self.model.model.norm.weight = self.to_parameter(self.state_dict["model.norm.weight"]) + self.model.lm_head.weight = self.to_parameter(self.state_dict["lm_head.weight"]) + + +def permute_weight(x: torch.Tensor) -> torch.Tensor: + # Hardcode BLOCK_K and BLOCK_N + BK = 128 + BN = 128 + x_ = x + x_ = x_.view(x.shape[0], + x.shape[1]//BN, BN//16, 16, + x.shape[2]//BK, BK//32, 4, 8) + x_ = x_.permute(0, 1, 5, 2, 6, 4, 3, 7) + x_ = x_.contiguous() + x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]) + return x_ + +class GPUMixtral(nn.Module): + def __init__(self, xpu_cfg: Dict[str, Any]) -> None: + super().__init__() + + self.xpu_cfg = xpu_cfg + self.model_config = xpu_cfg["model_config"] + + self.model_name = self.model_config["model_name"] + self.model_path = self.model_config["model_path"] + self.model_network = self.model_config["network"] + + self.mixtral_config : MixtralConfig = MixtralConfig(**self.model_network) + + # dist config + self.mp_size = int(os.environ.get("WORLD_SIZE", "1")) + self.local_rank = int(os.environ.get("LOCAL_RANK", "0")) + + self.transformer_model : MixtralForCausalLM = None + + + def init_inference(self): + torch.cuda.set_device(self.local_rank) + + if self.mp_size > 1: + set_custom_all_reduce(True) + + init_distributed_environment( + world_size=self.mp_size, + rank=self.local_rank, + distributed_init_method=get_distributed_init_method("127.0.0.1", get_open_port())) + # distributed_init_method=get_distributed_init_method(get_ip(), get_open_port())) + + ensure_model_parallel_initialized(self.mp_size, 1) + logger.info(f"RANK: {self.local_rank} {self.mp_size} init_process_group...") + # dist.init_process_group( + # backend="nccl", + # world_size=self.mp_size, + # rank=self.local_rank + # ) + check_dist() + + check_memory_usage("Begin") + + with init_empty_weights(): + self.transformer_model = MixtralForCausalLM(self.mixtral_config) + self.transformer_model.eval() + + check_memory_usage("After build model") + + self.load_weight(self.model_path) + + check_memory_usage("After load_weight") + + self.transformer_model.cuda() + + check_memory_usage("After model to device") + + self.block_tables, self.kv_cache = self.init_kvcache(self.mixtral_config.torch_dtype) + + if self.mp_size > 1: + dist.barrier() + + def finalize_inference(self): + if self.mp_size > 1 and dist.is_initialized(): + # dist.destroy_process_group() + destroy_model_parallel() + destroy_distributed_environment() + torch.cuda.empty_cache() + + def load_weight(self, ckpt_path): + p_loader = GPUMixtralLoader(self.transformer_model, self.mixtral_config, ckpt_path) + p_loader.parallel_loader() + p_loader.infusion_to_model() + + def init_kvcache(self, dtype): + max_batch_size = self.xpu_cfg["max_batch_size"] + num_layers = self.mixtral_config.num_hidden_layers + max_seq_len = self.mixtral_config.max_position_embeddings + hidden_size = self.mixtral_config.hidden_size + q_head_num = self.mixtral_config.num_attention_heads + kv_head_num = self.mixtral_config.num_key_value_heads + head_dim = hidden_size // q_head_num + + cur_device = self.transformer_model.device + + if self.xpu_cfg.get("perf_config", None) is not None: + max_seq_len = min(max_seq_len, + max(self.xpu_cfg["perf_config"]["seq_len_list"])*2) + self.block_size = 32 + max_num_blocks = 4096 + while max_num_blocks * self.block_size < max_seq_len * max_batch_size: + max_num_blocks += 4096 + self.max_num_blocks_per_seq = (max_seq_len + self.block_size - 1) // self.block_size + block_tables_lst: List[List[int]] = [] + for batch_idx in range(max_batch_size): + block_start = self.max_num_blocks_per_seq * batch_idx + block_table = [i + block_start for i in range(self.max_num_blocks_per_seq)] + block_tables_lst.append(block_table) + block_tables = torch.tensor(block_tables_lst, dtype=torch.int, device = cur_device) + x = 16 // torch.tensor([], dtype=dtype).element_size() + k_cache_shape = (max_num_blocks, kv_head_num // self.mp_size, head_dim // x, self.block_size, x) + v_cache_shape = (max_num_blocks, kv_head_num // self.mp_size, head_dim, self.block_size) + + past_key_values = () + for i in range(num_layers): + k_cache = torch.zeros(size=k_cache_shape, dtype=dtype, device=cur_device) + v_cache = torch.zeros(size=v_cache_shape, dtype=dtype, device=cur_device) + past_key_values += ((k_cache, v_cache),) + return block_tables, past_key_values + + def forward(self, inputs : Dict[str, torch.Tensor]): + inputs["cache_batch_offset"] = self.block_size * self.max_num_blocks_per_seq + model_outputs = self.transformer_model.forward( + **inputs, + past_key_values=(self.block_tables, self.kv_cache) + ) + + # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size] + # decode: [max_batch_size, 1] + logits = model_outputs.logits + + output_dict = { + "logits": logits + } + return output_dict \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/model_impl/split_mixtral.py b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/split_mixtral.py new file mode 100644 index 000000000..a4eceabb0 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/model_impl/split_mixtral.py @@ -0,0 +1,148 @@ +import os +import sys +import pathlib +import argparse +from tqdm import tqdm + +import torch +import torch.nn as nn +from typing import List, Optional, Union, Tuple + +from accelerate import init_empty_weights +from transformers import MixtralConfig + +FILE_DIR = pathlib.Path(__file__).parent.absolute() + +sys.path.insert(0, str(FILE_DIR.parents[3])) +from llm_perf.backends.GPU.model_impl.modeling_mixtral import MixtralForCausalLM +from llm_perf.core.ckpt_loader import Mixtral_ModelLoader + + +def to_parameter( + data : torch.Tensor, + dtype : torch.dtype = None +): + if dtype is not None: + data = data.to(dtype) + return nn.Parameter(data, requires_grad=False) + + +def split( + src : torch.Tensor, + mp_size : int, + dim : int, + chunks : List [int]=[] +): + if len(chunks) == 0: + split_arg = src.shape[dim] // mp_size + output_tensors = torch.split(src, split_arg, dim=dim) + else: + # for example + # chunks = [32, 2, 2], sum_chunks = 36, src.shape[dim] = (32 + 2 + 2) * 128, other_dim = 128 + # mp_size = 8 + # new_chunks = [4, 1, 1] + sum_chunks = sum(chunks) + other_dim_size = src.shape[dim] // sum_chunks + + split_arg = [i * other_dim_size for i in chunks] + split_tensors = torch.split(src, split_arg, dim=dim) + + output_split = [] + for i, tensor in enumerate(split_tensors): + if mp_size > chunks[i]: + tensor_shape = tensor.size()[:dim] + (chunks[i], 1, other_dim_size) + tensor.size()[dim+1:] + new_tensor_shape = tensor.size()[:dim] + (chunks[i], mp_size // chunks[i], other_dim_size) + tensor.size()[dim+1:] + output_tensor_shape = tensor.size()[:dim] + (mp_size * other_dim_size,) + tensor.size()[dim+1:] + + tensor = tensor.view(tensor_shape) + tensor = tensor.expand(*new_tensor_shape) + tensor = tensor.contiguous() + tensor = tensor.view(output_tensor_shape) + + cur_split = torch.split(tensor, tensor.shape[dim] // mp_size, dim=dim) + output_split.append(cur_split) + + output_tensors = [] + for i in range(mp_size): + temp_tensors = [output_split[j][i] for j in range(len(chunks))] + tp_tensors = torch.concat(temp_tensors, dim=dim) + output_tensors.append(tp_tensors) + + output_tensors = [tensor.contiguous() for tensor in output_tensors] + + return output_tensors + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--model_path", type=str, required=True) + parser.add_argument("--mp_size", type=int, default=8, choices=[2, 4, 8]) + args = parser.parse_args() + + os.environ["LOCAL_RANK"] = "0" + os.environ["WORLD_SIZE"] = str(args.mp_size) + + model_path = pathlib.Path(args.model_path).absolute() + model_config : MixtralConfig = MixtralConfig.from_pretrained(str(model_path)) + print(model_config) + + model_loader = Mixtral_ModelLoader(model_path) + state_dict = model_loader.load_weight() + + # model_config.num_hidden_layers = 4 + + p_bar = tqdm(total=model_config.num_hidden_layers, desc="split model") + for i in range(model_config.num_hidden_layers): + q = f"model.layers.{i}.self_attn.q_proj.weight" + k = f"model.layers.{i}.self_attn.k_proj.weight" + v = f"model.layers.{i}.self_attn.v_proj.weight" + o = f"model.layers.{i}.self_attn.o_proj.weight" + + state_dict[q] = split(state_dict[q], args.mp_size, 0) + state_dict[k] = split(state_dict[k], args.mp_size, 0) + state_dict[v] = split(state_dict[v], args.mp_size, 0) + state_dict[o] = split(state_dict[o], args.mp_size, 1) + + for j in range(model_config.num_local_experts): + w1 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight" + w2 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight" + w3 = f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight" + + state_dict[w1] = split(state_dict[w1], args.mp_size, 0) + state_dict[w2] = split(state_dict[w2], args.mp_size, 1) + state_dict[w3] = split(state_dict[w3], args.mp_size, 0) + + p_bar.update(1) + p_bar.close() + + split_model_path = model_path / f"TP{args.mp_size}" + split_model_path.mkdir(parents=True, exist_ok=True) + + with init_empty_weights(): + model = MixtralForCausalLM(model_config) + model.eval() + + p_bar = tqdm(total=args.mp_size, desc="save model") + for rank in range(args.mp_size): + output_dir = split_model_path / f"device_{rank}" + output_dir.mkdir(parents=True, exist_ok=True) + + model.model.embed_tokens.weight = to_parameter(state_dict["model.embed_tokens.weight"]) + for i in range(model_config.num_hidden_layers): + model.model.layers[i].self_attn.q_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.q_proj.weight"][rank]) + model.model.layers[i].self_attn.k_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.k_proj.weight"][rank]) + model.model.layers[i].self_attn.v_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.v_proj.weight"][rank]) + model.model.layers[i].self_attn.o_proj.weight = to_parameter(state_dict[f"model.layers.{i}.self_attn.o_proj.weight"][rank]) + model.model.layers[i].block_sparse_moe.gate.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.gate.weight"]) + for j in range(model_config.num_local_experts): + model.model.layers[i].block_sparse_moe.experts[j].w1.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w1.weight"][rank]) + model.model.layers[i].block_sparse_moe.experts[j].w2.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w2.weight"][rank]) + model.model.layers[i].block_sparse_moe.experts[j].w3.weight = to_parameter(state_dict[f"model.layers.{i}.block_sparse_moe.experts.{j}.w3.weight"][rank]) + model.model.layers[i].input_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.input_layernorm.weight"]) + model.model.layers[i].post_attention_layernorm.weight = to_parameter(state_dict[f"model.layers.{i}.post_attention_layernorm.weight"]) + model.model.norm.weight = to_parameter(state_dict["model.norm.weight"]) + model.lm_head.weight = to_parameter(state_dict["lm_head.weight"]) + + model.save_pretrained(str(output_dir)) + p_bar.update(1) + p_bar.close() diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocmKernels.cpython-39-x86_64-linux-gnu.so b/byte_infer_perf/llm_perf/backends/ROCM/rocmKernels.cpython-39-x86_64-linux-gnu.so new file mode 100755 index 000000000..9ceed5fac Binary files /dev/null and b/byte_infer_perf/llm_perf/backends/ROCM/rocmKernels.cpython-39-x86_64-linux-gnu.so differ diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=1024,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=1024,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..67ac2bd47 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=1024,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,519 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "5": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "6": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "7": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "9": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "10": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "11": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "12": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "13": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "14": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "15": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "17": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "18": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "19": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "21": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "22": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "23": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "25": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "26": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "27": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "28": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "29": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "30": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "31": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "40": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "56": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 64, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 32, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "72": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "80": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 1 + }, + "88": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 4, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "104": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "112": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "120": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 64, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=2048,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=2048,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..a172153cb --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=2048,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=4096,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=4096,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..7eca8f657 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=4096,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=8192,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=8192,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..a172153cb --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=32,N=8192,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1024,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1024,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..52a58b46c --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1024,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=14336,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=14336,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..5945e09ed --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=14336,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1792,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1792,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..49413a63b --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=1792,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=2048,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=2048,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..42139fefd --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=2048,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=3584,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=3584,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..c237432d5 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=3584,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,244 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8192": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=4096,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=4096,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..5945e09ed --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=4096,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=7168,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=7168,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..5945e09ed --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=7168,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=8192,device_name=AMD_Instinct_MI308X_OAM.json b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=8192,device_name=AMD_Instinct_MI308X_OAM.json new file mode 100644 index 000000000..28b467d1a --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/configs/E=8,N=8192,device_name=AMD_Instinct_MI308X_OAM.json @@ -0,0 +1,233 @@ +{ + "1": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "8": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "24": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "32": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "48": { + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "64": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "96": { + "BLOCK_SIZE_M": 32, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "128": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "256": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "512": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1024": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "1536": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "2048": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "3072": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "4096": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "16384": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "18432": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + }, + "20480": { + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "BLOCK_SIZE_K": 128, + "GROUP_SIZE_M": 1, + "num_warps": 8, + "num_stages": 0, + "waves_per_eu": 0, + "matrix_instr_nonkdim": 16, + "kpack": 2 + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/activation_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/activation_kernels.cu new file mode 100644 index 000000000..f6798dbab --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/activation_kernels.cu @@ -0,0 +1,150 @@ +#include +#include +#include + +#include + +#include "hip_compat.h" +#include "dispatch_utils.h" + +namespace vllm { + +// Activation and gating kernel template. +template +__global__ void act_and_mul_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., 2, d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * 2 * d + idx]); + const scalar_t y = VLLM_LDG(&input[token_idx * 2 * d + d + idx]); + out[token_idx * d + idx] = ACT_FN(x) * y; + } +} + +template +__device__ __forceinline__ T silu_kernel(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T gelu_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'none' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L36-L38 + const float f = (float)x; + constexpr float ALPHA = M_SQRT1_2; + return (T)(f * 0.5f * (1.0f + ::erf(f * ALPHA))); +} + +template +__device__ __forceinline__ T gelu_tanh_kernel(const T& x) { + // Equivalent to PyTorch GELU with 'tanh' approximation. + // Refer to: + // https://github.com/pytorch/pytorch/blob/8ac9b20d4b090c213799e81acf48a55ea8d437d6/aten/src/ATen/native/cuda/ActivationGeluKernel.cu#L25-L30 + const float f = (float)x; + constexpr float BETA = M_SQRT2 * M_2_SQRTPI * 0.5f; + constexpr float KAPPA = 0.044715; + float x_cube = f * f * f; + float inner = BETA * (f + KAPPA * x_cube); + return (T)(0.5f * f * (1.0f + ::tanhf(inner))); +} + +} // namespace vllm + +// Launch activation and gating kernel. +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL) \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +void silu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::silu_kernel); +} + +void gelu_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_kernel); +} + +void gelu_tanh_and_mul(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., 2 * d] +{ + LAUNCH_ACTIVATION_GATE_KERNEL(vllm::gelu_tanh_kernel); +} + +namespace vllm { + +// Element-wise activation kernel template. +template +__global__ void activation_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., d] + const int d) { + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + const scalar_t x = VLLM_LDG(&input[token_idx * d + idx]); + out[token_idx * d + idx] = ACT_FN(x); + } +} + +} // namespace vllm + +// Launch element-wise activation kernel. +#define LAUNCH_ACTIVATION_KERNEL(KERNEL) \ + int d = input.size(-1); \ + int64_t num_tokens = input.numel() / d; \ + dim3 grid(num_tokens); \ + dim3 block(std::min(d, 1024)); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "activation_kernel", [&] { \ + vllm::activation_kernel> \ + <<>>(out.data_ptr(), \ + input.data_ptr(), d); \ + }); + +namespace vllm { + +template +__device__ __forceinline__ T gelu_new_kernel(const T& x) { + const float x3 = (float)(x * x * x); + const T t = (T)tanhf((T)(0.79788456f * (float)(x + (T)(0.044715f * x3)))); + return ((T)0.5) * x * (((T)1.0) + t); +} + +template +__device__ __forceinline__ T gelu_fast_kernel(const T& x) { + const float f = (float)x; + const T t = + (T)tanhf(((T)(f * 0.79788456f)) * (((T)1.0) + (T)(0.044715f * f) * x)); + return ((T)0.5) * x * (((T)1.0) + t); +} + +} // namespace vllm + +void gelu_new(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_new_kernel); +} + +void gelu_fast(torch::Tensor& out, // [..., d] + torch::Tensor& input) // [..., d] +{ + LAUNCH_ACTIVATION_KERNEL(vllm::gelu_fast_kernel); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention.cu new file mode 100644 index 000000000..b5d4465bc --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention.cu @@ -0,0 +1,1120 @@ +/* + * Copyright (c) 2024, The vLLM team. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include "hip_compat.h" + +#include +#include "dtype_fp8.cuh" +#include "quant_utils.cuh" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support + + #define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32 + #define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16 + +using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; +using float16x4 = + __attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16; +typedef float16x4 _Half4; +typedef struct _Half8 { + _Half4 xy[2]; +} _Half8; + +using bit16_t = uint16_t; +using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t; +typedef bit16x4 _B16x4; +typedef struct _B16x8 { + _B16x4 xy[2]; +} _B16x8; + +using _B8x8 = uint2; + +////// Non temporal load stores /////// + +template +__device__ __forceinline__ T load(T* addr) { + return addr[0]; +} + +template +__device__ __forceinline__ void store(T value, T* addr) { + addr[0] = value; +} + +template +__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA, + const _B16x4& inpB, + const floatx4& inpC) { + if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid, + blgp); + } else if constexpr (std::is_same::value) { + return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid, + blgp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ float to_float(const T& inp) { + if constexpr (std::is_same::value) { + return (float)inp; + } else if constexpr (std::is_same::value) { + return __bfloat162float(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ T from_float(const float& inp) { + if constexpr (std::is_same::value) { + return (_Float16)inp; + } else if constexpr (std::is_same::value) { + return __float2bfloat16(inp); + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t16; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.f = (_Float16)inp[i]; + ret[i] = t16.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t16.b = __float2bfloat16(inp[i]); + ret[i] = t16.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1, + const _B16x4& inp2) { + union tmpcvt { + uint16_t u; + _Float16 f; + __hip_bfloat16 b; + } t1, t2, res; + _B16x4 ret; + if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.f = t1.f + t2.f; + ret[i] = res.u; + } + return ret; + } else if constexpr (std::is_same::value) { + #pragma unroll + for (int i = 0; i < 4; i++) { + t1.u = inp1[i]; + t2.u = inp2[i]; + res.b = t1.b + t2.b; + ret[i] = res.u; + } + return ret; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +template +__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input, + const float scale) { + union alignas(16) { + uint4 u4; + _B16x8 u16x8; + vllm::bf16_8_t b16x8; + } tmp; + if constexpr (std::is_same::value) { + tmp.u4 = vllm::fp8::scaled_convert(input, scale); + return tmp.u16x8; + } else if constexpr (std::is_same::value) { + tmp.b16x8 = vllm::fp8::scaled_convert( + input, scale); + return tmp.u16x8; + } else { + static_assert(false, "unsupported 16b dtype"); + } +} + +/////////////////////////////////////// + +// grid (num_seqs, num_partitions,num_heads/gqa_ratio) +// block (partition size) +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale) { + constexpr int NWARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + const int lane4id = laneid % 4; + + const int seq_idx = blockIdx.x; + const int partition_idx = blockIdx.y; + const int partition_size = blockDim.x; + const int max_num_partitions = gridDim.y; + + const int context_len = context_lens[seq_idx]; + const int partition_start_token_idx = partition_idx * partition_size; + // exit if partition is out of context for seq + if (partition_start_token_idx >= context_len) { + return; + } + constexpr int QHLOOP = + DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads, + // total qheads =8, so qhloop is 2 + constexpr int GQA_RATIO4 = 4 * QHLOOP; + __shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1]; + __shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1]; + _B16x8 Qlocal[QHLOOP]; + constexpr int x = 16 / sizeof(scalar_t); + constexpr int KHELOOP = HEAD_SIZE / x; + _B16x8 Klocal[KHELOOP]; + _B8x8 Klocalb8[KHELOOP]; + constexpr int VHELOOP = + HEAD_SIZE / + WARP_SIZE; // v head_size dimension is distributed across lanes + constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2 + // 8xtokens + _B16x8 Vlocal[VHELOOP][VTLOOP]; + _B8x8 Vlocalb8[VHELOOP][VTLOOP]; + floatx4 dout[QHLOOP]; + float qk_max[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = {0}; + qk_max[h] = -FLT_MAX; + } + + const int wg_start_head_idx = blockIdx.z * GQA_RATIO; + const int wg_start_kv_head_idx = blockIdx.z; + + const int warp_start_token_idx = + partition_start_token_idx + warpid * WARP_SIZE; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int h = 0; h < GQA_RATIO4; h++) { + shared_qk_max[warpid][h] = -FLT_MAX; + shared_exp_sum[warpid][h] = 0.0f; + } + } else { // warp within context + + const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE); + const int last_ctx_block = num_context_blocks - 1; + + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + const int local_token_idx = threadIdx.x; + const int global_token_idx = partition_start_token_idx + local_token_idx; + + const int block_idx = (global_token_idx < context_len) + ? global_token_idx / BLOCK_SIZE + : last_ctx_block; + // fetch block number for q and k + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // fetch vphysical block numbers up front + constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE; + int vphysical_blocks[VBLOCKS]; + + const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE; + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + const int vblock_idx = warp_start_block_idx + b; + const int vblock_idx_ctx = + (vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block; + vphysical_blocks[b] = block_table[vblock_idx_ctx]; + } + + // each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems + const scalar_t* q_ptr = + q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE; + const _B16x8* q_ptrh8 = reinterpret_cast(q_ptr); + const int qhead_elemh8 = laneid / 4; + #pragma unroll + for (int h = 0; h < QHLOOP - 1; h++) { + const int qhead_idx = h * 4 + lane4id; + Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } + const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id; + if (final_qhead_idx < GQA_RATIO) { + Qlocal[QHLOOP - 1] = + q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8]; + } else { + Qlocal[QHLOOP - 1].xy[0] = {0}; + Qlocal[QHLOOP - 1].xy[1] = {0}; + } + + const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride + + wg_start_kv_head_idx * kv_head_stride; + + const int physical_block_offset = + local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset + // is already cast as _H8 + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* k_ptrh8 = reinterpret_cast(k_ptr); + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset]; + } + } else { + constexpr int X = 16 / sizeof(cache_t); + const cache_t* k_ptr2 = k_ptr + physical_block_offset * X; + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + const int head_elem = d * 8; + const int offset1 = head_elem / X; + const int offset2 = head_elem % X; + const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2; + Klocalb8[d] = *reinterpret_cast(k_ptr3); + } + } + + float alibi_slope[QHLOOP]; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int qhead_idx = h * 4 + lane4id; + alibi_slope[h] = (qhead_idx < GQA_RATIO) + ? alibi_slopes[wg_start_head_idx + qhead_idx] + : 0.f; + } + } + + const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride; + if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) { + const _B16x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B16x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + } + } + } + } else { + const _B8x8* v_ptrh8 = reinterpret_cast(v_ptr); + // iterate over each v block + #pragma unroll + for (int b = 0; b < VBLOCKS; b++) { + // int32 physical_block_number leads to overflow when multiplied with + // kv_block_stride + const int64_t vphysical_block_number = + static_cast(vphysical_blocks[b]); + const _B8x8* v_ptrh8b = + v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8; + // iterate over each head elem (within head_size) + #pragma unroll + for (int h = 0; h < VHELOOP; h++) { + const int head_size_elem = h * WARP_SIZE + laneid; + const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8; + // iterate over all velems within block + #pragma unroll + for (int d = 0; d < BLOCK_SIZE / 8; d++) { + // Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d]; + const _B8x8 Vlocalb8 = v_ptrh8be[d]; + Vlocal[h][b * BLOCK_SIZE / 8 + d] = + scaled_convert_b8x8(Vlocalb8, v_scale); + } + } + } + } + + if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) { + #pragma unroll + for (int d = 0; d < KHELOOP; d++) { + Klocal[d] = + scaled_convert_b8x8(Klocalb8[d], k_scale); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[0].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[0].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[1].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[1].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[2].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[2].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[3].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[3].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[4].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[4].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[5].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[5].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[6].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[6].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[7].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[7].xy[1], dout[h]); + if constexpr (KHELOOP > 8) { + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[8].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[8].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[9].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[9].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[10].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[10].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[11].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[11].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[12].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[12].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[13].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[13].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[14].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[14].xy[1], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[0], + Klocal[15].xy[0], dout[h]); + dout[h] = gcn_mfma_instr(Qlocal[h].xy[1], + Klocal[15].xy[1], dout[h]); + } // KHELOOP>8 + dout[h] *= scale; + } + // transpose dout so that 4 token ids are in each lane, and 4 heads are across + // 4 lanes + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + floatx4 tmp = {0}; + #pragma unroll + for (int i = 0; i < 4; i++) { + const float B = (lane4id == i) ? 1.0f : 0.0f; + // const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f; + tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0); + // tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0); + } + dout[h] = tmp; + } + + const int lane4_token_idx = 4 * (global_token_idx >> 2); + const int alibi_offset = lane4_token_idx - context_len + 1; + if (alibi_slopes != nullptr) { + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] += alibi_slope[h] * (alibi_offset + i); + } + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + qk_max[h] = -FLT_MAX; + #pragma unroll + for (int i = 0; i < 4; i++) { + qk_max[h] = (lane4_token_idx + i < context_len) + ? fmaxf(qk_max[h], dout[h][i]) + : qk_max[h]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask)); + } + } + + float exp_sum[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + exp_sum[h] = 0.0f; + #pragma unroll + for (int i = 0; i < 4; i++) { + dout[h][i] = (lane4_token_idx + i < context_len) + ? __expf(dout[h][i] - qk_max[h]) + : 0.0f; + exp_sum[h] += dout[h][i]; + } + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) { + exp_sum[h] += __shfl_xor(exp_sum[h], mask); + } + } + + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + const int head_idx = 4 * h + lane4id; + shared_qk_max[warpid][head_idx] = qk_max[h]; + shared_exp_sum[warpid][head_idx] = exp_sum[h]; + } + } // warp within context + + __syncthreads(); + + const int num_heads = gridDim.z * GQA_RATIO; + float* max_logits_ptr = + max_logits + seq_idx * num_heads * max_num_partitions + partition_idx; + float* exp_sums_ptr = + exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + float global_qk_max = -FLT_MAX; + float warp_qk_max[NWARPS]; + const int head_idx = 4 * h + lane4id; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + warp_qk_max[w] = shared_qk_max[w][head_idx]; + global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]); + } + float global_exp_sum = 0.0f; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + global_exp_sum += + shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max); + } + if (head_idx < GQA_RATIO) { + max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_qk_max; + exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] = + global_exp_sum; + } + const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) * + __expf(qk_max[h] - global_qk_max); + dout[h] *= global_inv_sum_scale; + } + // logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there + // are 4x16 tokens across warp + _B16x4 logits[QHLOOP]; + #pragma unroll + for (int h = 0; h < QHLOOP; h++) { + logits[h] = from_floatx4(dout[h]); + } + + __shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1]; + + if (warp_start_token_idx >= context_len) { // warp out of context + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout_shared[qh][vh][laneid][warpid] = {0}; + } + } + } else { // warp in context + // iterate across heads + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + floatx4 acc = {0}; + // iterate over tokens + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][0].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][1].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][2].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][3].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[0], + acc); + acc = gcn_mfma_instr(logits[qh], Vlocal[vh][4].xy[1], + acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][5].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][6].xy[1], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[0], acc); + acc = gcn_mfma_instr(logits[qh], + Vlocal[vh][7].xy[1], acc); + vout_shared[qh][vh][laneid][warpid] = from_floatx4(acc); + } + } + } // warp in context + + __syncthreads(); + + if (warpid == 0) { + _B16x4 vout[QHLOOP][VHELOOP]; + // iterate across heads + scalar_t* out_ptr; + int out_num_partitions; + if (context_len > partition_size) { + out_num_partitions = max_num_partitions; + out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + partition_idx * HEAD_SIZE; + } else { + out_num_partitions = 1; + out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE; + } + #pragma unroll + for (int qh = 0; qh < QHLOOP; qh++) { + // iterate over each v head elem (within head_size) + #pragma unroll + for (int vh = 0; vh < VHELOOP; vh++) { + vout[qh][vh] = {0}; + #pragma unroll + for (int w = 0; w < NWARPS; w++) { + vout[qh][vh] = + addx4(vout[qh][vh], vout_shared[qh][vh][laneid][w]); + } + const int head_size_elem = vh * WARP_SIZE + laneid; + bit16_t* out_ptr_b16 = reinterpret_cast(out_ptr); + #pragma unroll + for (int i = 0; i < 4; i++) { + const int head_idx = 4 * qh + i; + if (head_idx < GQA_RATIO) { + out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions * + HEAD_SIZE + + head_size_elem] = vout[qh][vh][i]; + } + } + } + } + } +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int context_len = context_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE); + if (num_partitions == 1) { + // if num_partitions==1, main kernel will write to out directly, no work in + // reduction kernel + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warpid = threadIdx.x / WARP_SIZE; + const int laneid = threadIdx.x % WARP_SIZE; + + __shared__ float shared_global_exp_sum; + __shared__ float shared_exp_sums[2 * WARP_SIZE]; + + if (warpid == 0) { + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + // valid partition is the last valid partition in case threadid > num + // partitions + const int valid_partition = + (threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1; + const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions) + ? WARP_SIZE + threadIdx.x + : num_partitions - 1; + float reg_max_logit = max_logits_ptr[valid_partition]; + float reg_max_logit2 = max_logits_ptr[valid_partition2]; + float max_logit = fmaxf(reg_max_logit, reg_max_logit2); + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask)); + } + + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + + float global_exp_sum = 0.0f; + float rescaled_exp_sum = exp_sums_ptr[valid_partition]; + float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2]; + rescaled_exp_sum *= + (threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f; + rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions) + ? expf(reg_max_logit2 - max_logit) + : 0.0f; + global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2; + shared_exp_sums[threadIdx.x] = rescaled_exp_sum; + shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2; + + #pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + global_exp_sum += __shfl_xor(global_exp_sum, mask); + } + if (threadIdx.x == 0) { + shared_global_exp_sum = global_exp_sum; + } + } // warpid == 0 + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x; + constexpr int MAX_NPAR = 64; + scalar_t tmps[MAX_NPAR]; + const float dzero = 0.0f; + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + tmps[j] = from_float(dzero); + } + const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE; + const int num_partition_offset = (num_partitions)*HEAD_SIZE; + int idx = 0; + + constexpr int JCHUNK = 16; + + #pragma unroll + for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + __syncthreads(); + + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + } + } // num_partitions > JCHUNK + + // Aggregate tmp_out to out. + float acc = 0.0f; + #pragma unroll + for (int j = 0; j < JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > JCHUNK) { + #pragma unroll + for (int j = JCHUNK; j < 2 * JCHUNK; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + if (num_partitions > 2 * JCHUNK) { + #pragma unroll + for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j]; + } + } + } + + if (num_partitions > MAX_NPAR) { + idx = 0; + #pragma unroll + for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE; + j += HEAD_SIZE) { + // lastj is last valid partition + const int lastj_offset = + (j < num_partition_offset) ? j : last_partition_offset; + tmps[idx] = tmp_out_ptr[lastj_offset]; + idx++; + } + + #pragma unroll + for (int j = 0; j < MAX_NPAR; j++) { + acc += to_float(tmps[j]) * shared_exp_sums[j + MAX_NPAR]; + } + } + + const float inv_global_exp_sum = + __fdividef(1.0f, shared_global_exp_sum + 1e-6f); + acc *= inv_global_exp_sum; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + out_ptr[threadIdx.x] = from_float(acc); +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +template +__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel( + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size] + int max_ctx_blocks, float k_scale, float v_scale) { + UNREACHABLE_CODE +} + +// Grid: (num_heads, num_seqs). +template +__global__ +__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ context_lens, // [num_seqs] + const int max_num_partitions){UNREACHABLE_CODE} + +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \ + paged_attention_ll4mi_QKV_kernel \ + <<>>( \ + query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \ + block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \ + k_scale, v_scale); + +template +void paged_attention_custom_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, const int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& context_lens, + int max_context_len, const c10::optional& alibi_slopes, + float k_scale, float v_scale) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + KVT* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + KVT* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* context_lens_ptr = context_lens.data_ptr(); + + const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE); + const int max_num_partitions = + DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE); + const int gqa_ratio = num_heads / num_kv_heads; + assert(num_heads % num_kv_heads == 0); + assert(head_size == HEAD_SIZE); + assert(max_num_partitions <= 128); + + constexpr int NTHR = PARTITION_SIZE; + dim3 grid(num_seqs, max_num_partitions, num_kv_heads); + dim3 block(NTHR); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (gqa_ratio) { + case 1: + LAUNCH_CUSTOM_ATTENTION(1); + break; + case 2: + LAUNCH_CUSTOM_ATTENTION(2); + break; + case 3: + LAUNCH_CUSTOM_ATTENTION(3); + break; + case 4: + LAUNCH_CUSTOM_ATTENTION(4); + break; + case 5: + LAUNCH_CUSTOM_ATTENTION(5); + break; + case 6: + LAUNCH_CUSTOM_ATTENTION(6); + break; + case 7: + LAUNCH_CUSTOM_ATTENTION(7); + break; + case 8: + LAUNCH_CUSTOM_ATTENTION(8); + break; + case 9: + LAUNCH_CUSTOM_ATTENTION(9); + break; + case 10: + LAUNCH_CUSTOM_ATTENTION(10); + break; + case 11: + LAUNCH_CUSTOM_ATTENTION(11); + break; + case 12: + LAUNCH_CUSTOM_ATTENTION(12); + break; + case 13: + LAUNCH_CUSTOM_ATTENTION(13); + break; + case 14: + LAUNCH_CUSTOM_ATTENTION(14); + break; + case 15: + LAUNCH_CUSTOM_ATTENTION(15); + break; + case 16: + LAUNCH_CUSTOM_ATTENTION(16); + break; + default: + TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio); + break; + } + // dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG); + // dim3 block2(1024); + // LAUNCH_CUSTOM_ATTENTION2; + + // reduction kernel is only required if max_context_len > partition size, + // otherwise main kernel writes directly to final output + // note there are cases with graphing where max_context_len is the max + // supported by graphing, not the actual max among all the sequences: in that + // case reduction kernel will still run but return immediately + if (max_context_len > PARTITION_SIZE) { + dim3 reduce_grid(num_heads, num_seqs); + dim3 reduce_block(head_size); + paged_attention_ll4mi_reduce_kernel + <<>>( + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, + context_lens_ptr, max_num_partitions); + } +} + +#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \ + paged_attention_custom_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, context_lens, max_context_len, \ + alibi_slopes, k_scale, v_scale); + +#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \ + switch (block_size) { \ + case 16: \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \ + break; \ + case 32: \ + CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \ + switch (head_size) { \ + case 64: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \ + break; \ + case 128: \ + CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported head size: ", head_size); \ + break; \ + } + +void paged_attention_rocm( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& context_lens, // [num_seqs] + int64_t block_size, int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale) { + const int head_size = query.size(2); + if (kv_cache_dtype == "auto") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16, + vllm::Fp8KVCacheDataType::kAuto); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16, + vllm::Fp8KVCacheDataType::kAuto); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (query.dtype() == at::ScalarType::Half) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (query.dtype() == at::ScalarType::BFloat16) { + CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else { + TORCH_CHECK(false, "Unsupported data type: ", query.dtype()); + } + } else { + TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype); + } +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_dtypes.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_dtypes.h new file mode 100644 index 000000000..64f86381d --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_dtypes.h @@ -0,0 +1,7 @@ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float16.cuh" +#include "dtype_float32.cuh" +#include "dtype_bfloat16.cuh" +#include "dtype_fp8.cuh" diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_generic.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_generic.cuh new file mode 100644 index 000000000..62409c0cc --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_generic.cuh @@ -0,0 +1,65 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +namespace vllm { + +// A vector type to store Q, K, V elements. +template +struct Vec {}; + +// A vector type to store FP32 accumulators. +template +struct FloatVec {}; + +// Template vector operations. +template +inline __device__ Acc mul(A a, B b); + +template +inline __device__ float sum(T v); + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ float dot(T a, T b) { + return sum(mul(a, b)); +} + +template +inline __device__ void zero(T& dst) { + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; + +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_kernels.cu new file mode 100644 index 000000000..8afa088ee --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_kernels.cu @@ -0,0 +1,1010 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "attention_dtypes.h" +#include "attention_utils.cuh" + +#ifdef USE_ROCM + #include + #include "quant_utils.cuh" +typedef __hip_bfloat16 __nv_bfloat16; +#else + #include "../quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b)) + +namespace vllm { + +// Utility function for attention softmax. +template +inline __device__ float block_sum(float* red_smem, float sum) { + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + + // Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < NUM_WARPS) { + sum = red_smem[lane]; + } + + // Parallel reduction inside the warp. +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + sum += VLLM_SHFL_XOR_SYNC(sum, mask); + } + + // Broadcast to other threads. + return VLLM_SHFL_SYNC(sum, 0); +} + +// TODO(woosuk): Merge the last two dimensions of the grid. +// Grid: (num_heads, num_seqs, max_num_partitions). +template // Zero means no partitioning. +__device__ void paged_attention_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions, + // head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + const int seq_idx = blockIdx.y; + const int partition_idx = blockIdx.z; + const int max_num_partitions = gridDim.z; + constexpr bool USE_PARTITIONING = PARTITION_SIZE > 0; + const int seq_len = seq_lens[seq_idx]; + if (USE_PARTITIONING && partition_idx * PARTITION_SIZE >= seq_len) { + // No work to do. Terminate the thread block. + return; + } + + const int num_seq_blocks = DIVIDE_ROUND_UP(seq_len, BLOCK_SIZE); + const int num_blocks_per_partition = + USE_PARTITIONING ? PARTITION_SIZE / BLOCK_SIZE : num_seq_blocks; + + // [start_block_idx, end_block_idx) is the range of blocks to process. + const int start_block_idx = + USE_PARTITIONING ? partition_idx * num_blocks_per_partition : 0; + const int end_block_idx = + MIN(start_block_idx + num_blocks_per_partition, num_seq_blocks); + const int num_blocks = end_block_idx - start_block_idx; + + // [start_token_idx, end_token_idx) is the range of tokens to process. + const int start_token_idx = start_block_idx * BLOCK_SIZE; + const int end_token_idx = + MIN(start_token_idx + num_blocks * BLOCK_SIZE, seq_len); + const int num_tokens = end_token_idx - start_token_idx; + + constexpr int THREAD_GROUP_SIZE = MAX(WARP_SIZE / BLOCK_SIZE, 1); + constexpr int NUM_THREAD_GROUPS = + NUM_THREADS / THREAD_GROUP_SIZE; // Note: This assumes THREAD_GROUP_SIZE + // divides NUM_THREADS + assert(NUM_THREADS % THREAD_GROUP_SIZE == 0); + constexpr int NUM_TOKENS_PER_THREAD_GROUP = + DIVIDE_ROUND_UP(BLOCK_SIZE, WARP_SIZE); + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int thread_idx = threadIdx.x; + const int warp_idx = thread_idx / WARP_SIZE; + const int lane = thread_idx % WARP_SIZE; + + const int head_idx = blockIdx.x; + const int num_heads = gridDim.x; + const int num_queries_per_kv = num_heads / num_kv_heads; + const int kv_head_idx = head_idx / num_queries_per_kv; + const float alibi_slope = + alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx]; + + // A vector type to store a part of a key or a query. + // The vector size is configured in such a way that the threads in a thread + // group fetch or compute 16 bytes at a time. For example, if the size of a + // thread group is 4 and the data type is half, then the vector size is 16 / + // (4 * sizeof(half)) == 2. + constexpr int VEC_SIZE = MAX(16 / (THREAD_GROUP_SIZE * sizeof(scalar_t)), 1); + using K_vec = typename Vec::Type; + using Q_vec = typename Vec::Type; + using Quant_vec = typename Vec::Type; + + constexpr int NUM_ELEMS_PER_THREAD = HEAD_SIZE / THREAD_GROUP_SIZE; + constexpr int NUM_VECS_PER_THREAD = NUM_ELEMS_PER_THREAD / VEC_SIZE; + + const int thread_group_idx = thread_idx / THREAD_GROUP_SIZE; + const int thread_group_offset = thread_idx % THREAD_GROUP_SIZE; + + // Load the query to registers. + // Each thread in a thread group has a different part of the query. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the query, and the second thread + // has 1, 5, 9, ... th vectors of the query, and so on. NOTE(woosuk): Because + // q is split from a qkv tensor, it may not be contiguous. + const scalar_t* q_ptr = q + seq_idx * q_stride + head_idx * HEAD_SIZE; + __shared__ Q_vec q_vecs[THREAD_GROUP_SIZE][NUM_VECS_PER_THREAD]; +#pragma unroll + for (int i = thread_group_idx; i < NUM_VECS_PER_THREAD; + i += NUM_THREAD_GROUPS) { + const int vec_idx = thread_group_offset + i * THREAD_GROUP_SIZE; + q_vecs[thread_group_offset][i] = + *reinterpret_cast(q_ptr + vec_idx * VEC_SIZE); + } + __syncthreads(); // TODO(naed90): possible speedup if this is replaced with a + // memory wall right before we use q_vecs + + // Memory planning. + extern __shared__ char shared_mem[]; + // NOTE(woosuk): We use FP32 for the softmax logits for better accuracy. + float* logits = reinterpret_cast(shared_mem); + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // x == THREAD_GROUP_SIZE * VEC_SIZE + // Each thread group fetches x elements from the key at a time. + constexpr int x = 16 / sizeof(cache_t); + float qk_max = -FLT_MAX; + + // Iterate over the key blocks. + // Each warp fetches a block of keys for each iteration. + // Each thread group in a warp fetches a key from the block, and computes + // dot product with the query. + const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq; + + // blocksparse specific vars + int bs_block_offset; + int q_bs_block_id; + if constexpr (IS_BLOCK_SPARSE) { + // const int num_blocksparse_blocks = DIVIDE_ROUND_UP(seq_len, + // blocksparse_block_size); + q_bs_block_id = (seq_len - 1) / blocksparse_block_size; + if (blocksparse_head_sliding_step >= 0) + // sliding on q heads + bs_block_offset = + (tp_rank * num_heads + head_idx) * blocksparse_head_sliding_step + 1; + else + // sliding on kv heads + bs_block_offset = (tp_rank * num_kv_heads + kv_head_idx) * + (-blocksparse_head_sliding_step) + + 1; + } + + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + const int k_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + const bool is_remote = + ((k_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0); + const bool is_local = + (k_bs_block_id > q_bs_block_id - blocksparse_local_blocks); + if (!is_remote && !is_local) { + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + + if (thread_group_offset == 0) { + // NOTE(linxihui): assign very large number to skipped tokens to + // avoid contribution to the sumexp softmax normalizer. This will + // not be used at computing sum(softmax*v) as the blocks will be + // skipped. + logits[token_idx - start_token_idx] = -FLT_MAX; + } + } + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + + // Load a key to registers. + // Each thread in a thread group has a different part of the key. + // For example, if the the thread group size is 4, then the first thread in + // the group has 0, 4, 8, ... th vectors of the key, and the second thread + // has 1, 5, 9, ... th vectors of the key, and so on. + for (int i = 0; i < NUM_TOKENS_PER_THREAD_GROUP; i++) { + const int physical_block_offset = + (thread_group_idx + i * WARP_SIZE) % BLOCK_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + K_vec k_vecs[NUM_VECS_PER_THREAD]; + +#pragma unroll + for (int j = 0; j < NUM_VECS_PER_THREAD; j++) { + const cache_t* k_ptr = + k_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride + physical_block_offset * x; + const int vec_idx = thread_group_offset + j * THREAD_GROUP_SIZE; + const int offset1 = (vec_idx * VEC_SIZE) / x; + const int offset2 = (vec_idx * VEC_SIZE) % x; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + k_vecs[j] = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + } else { + // Vector conversion from Quant_vec to K_vec. + Quant_vec k_vec_quant = *reinterpret_cast( + k_ptr + offset1 * BLOCK_SIZE * x + offset2); + k_vecs[j] = fp8::scaled_convert( + k_vec_quant, k_scale); + } + } + + // Compute dot product. + // This includes a reduction across the threads in the same thread group. + float qk = scale * Qk_dot::dot( + q_vecs[thread_group_offset], k_vecs); + // Add the ALiBi bias if slopes are given. + qk += (alibi_slope != 0) ? alibi_slope * (token_idx - seq_len + 1) : 0; + + if (thread_group_offset == 0) { + // Store the partial reductions to shared memory. + // NOTE(woosuk): It is required to zero out the masked logits. + const bool mask = token_idx >= seq_len; + logits[token_idx - start_token_idx] = mask ? 0.f : qk; + // Update the max value. + qk_max = mask ? qk_max : fmaxf(qk_max, qk); + } + } + } + + // Perform reduction across the threads in the same warp to get the + // max qk value for each "warp" (not across the thread block yet). + // The 0-th thread of each thread group already has its max qk value. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREAD_GROUP_SIZE; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = qk_max; + } + __syncthreads(); + + // TODO(woosuk): Refactor this part. + // Get the max qk value for the sequence. + qk_max = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, VLLM_SHFL_XOR_SYNC(qk_max, mask)); + } + // Broadcast the max qk value to all threads. + qk_max = VLLM_SHFL_SYNC(qk_max, 0); + + // Get the sum of the exp values. + float exp_sum = 0.f; + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + float val = __expf(logits[i] - qk_max); + logits[i] = val; + exp_sum += val; + } + exp_sum = block_sum(&red_smem[NUM_WARPS], exp_sum); + + // Compute softmax. + const float inv_sum = __fdividef(1.f, exp_sum + 1e-6f); + for (int i = thread_idx; i < num_tokens; i += NUM_THREADS) { + logits[i] *= inv_sum; + } + __syncthreads(); + + // If partitioning is enabled, store the max logit and exp_sum. + if (USE_PARTITIONING && thread_idx == 0) { + float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *max_logits_ptr = qk_max; + float* exp_sums_ptr = exp_sums + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions + partition_idx; + *exp_sums_ptr = exp_sum; + } + + // Each thread will fetch 16 bytes from the value cache at a time. + constexpr int V_VEC_SIZE = MIN(16 / sizeof(scalar_t), BLOCK_SIZE); + using V_vec = typename Vec::Type; + using L_vec = typename Vec::Type; + using V_quant_vec = typename Vec::Type; + using Float_L_vec = typename FloatVec::Type; + + constexpr int NUM_V_VECS_PER_ROW = BLOCK_SIZE / V_VEC_SIZE; + constexpr int NUM_ROWS_PER_ITER = WARP_SIZE / NUM_V_VECS_PER_ROW; + constexpr int NUM_ROWS_PER_THREAD = + DIVIDE_ROUND_UP(HEAD_SIZE, NUM_ROWS_PER_ITER); + + // NOTE(woosuk): We use FP32 for the accumulator for better accuracy. + float accs[NUM_ROWS_PER_THREAD]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + accs[i] = 0.f; + } + + scalar_t zero_value; + zero(zero_value); + for (int block_idx = start_block_idx + warp_idx; block_idx < end_block_idx; + block_idx += NUM_WARPS) { + // NOTE(woosuk): The block number is stored in int32. However, we cast it to + // int64 because int32 can lead to overflow when this variable is multiplied + // by large numbers (e.g., kv_block_stride). + // For blocksparse attention: skip computation on blocks that are not + // attended + if constexpr (IS_BLOCK_SPARSE) { + int v_bs_block_id = block_idx * BLOCK_SIZE / blocksparse_block_size; + if (!((v_bs_block_id + bs_block_offset) % blocksparse_vert_stride == 0) && + !((v_bs_block_id > q_bs_block_id - blocksparse_local_blocks))) { + continue; + } + } + const int64_t physical_block_number = + static_cast(block_table[block_idx]); + const int physical_block_offset = (lane % NUM_V_VECS_PER_ROW) * V_VEC_SIZE; + const int token_idx = block_idx * BLOCK_SIZE + physical_block_offset; + L_vec logits_vec; + from_float(logits_vec, *reinterpret_cast(logits + token_idx - + start_token_idx)); + + const cache_t* v_ptr = v_cache + physical_block_number * kv_block_stride + + kv_head_idx * kv_head_stride; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE) { + const int offset = row_idx * BLOCK_SIZE + physical_block_offset; + V_vec v_vec; + + if constexpr (KV_DTYPE == Fp8KVCacheDataType::kAuto) { + v_vec = *reinterpret_cast(v_ptr + offset); + } else { + V_quant_vec v_quant_vec = + *reinterpret_cast(v_ptr + offset); + // Vector conversion from V_quant_vec to V_vec. + v_vec = fp8::scaled_convert(v_quant_vec, + v_scale); + } + if (block_idx == num_seq_blocks - 1) { + // NOTE(woosuk): When v_vec contains the tokens that are out of the + // context, we should explicitly zero out the values since they may + // contain NaNs. See + // https://github.com/vllm-project/vllm/issues/641#issuecomment-1682544472 + scalar_t* v_vec_ptr = reinterpret_cast(&v_vec); +#pragma unroll + for (int j = 0; j < V_VEC_SIZE; j++) { + v_vec_ptr[j] = token_idx + j < seq_len ? v_vec_ptr[j] : zero_value; + } + } + accs[i] += dot(logits_vec, v_vec); + } + } + } + + // Perform reduction within each warp. +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + float acc = accs[i]; +#pragma unroll + for (int mask = NUM_V_VECS_PER_ROW / 2; mask >= 1; mask /= 2) { + acc += VLLM_SHFL_XOR_SYNC(acc, mask); + } + accs[i] = acc; + } + + // NOTE(woosuk): A barrier is required because the shared memory space for + // logits is reused for the output. + __syncthreads(); + + // Perform reduction across warps. + float* out_smem = reinterpret_cast(shared_mem); +#pragma unroll + for (int i = NUM_WARPS; i > 1; i /= 2) { + int mid = i / 2; + // Upper warps write to shared memory. + if (warp_idx >= mid && warp_idx < i) { + float* dst = &out_smem[(warp_idx - mid) * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + dst[row_idx] = accs[i]; + } + } + } + __syncthreads(); + + // Lower warps update the output. + if (warp_idx < mid) { + const float* src = &out_smem[warp_idx * HEAD_SIZE]; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + accs[i] += src[row_idx]; + } + } + } + __syncthreads(); + } + + // Write the final output. + if (warp_idx == 0) { + scalar_t* out_ptr = + out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE + partition_idx * HEAD_SIZE; +#pragma unroll + for (int i = 0; i < NUM_ROWS_PER_THREAD; i++) { + const int row_idx = lane / NUM_V_VECS_PER_ROW + i * NUM_ROWS_PER_ITER; + if (row_idx < HEAD_SIZE && lane % NUM_V_VECS_PER_ROW == 0) { + from_float(*(out_ptr + row_idx), accs[i]); + } + } + } +} + +// Grid: (num_heads, num_seqs, 1). +template +__global__ void paged_attention_v1_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, + v_cache, num_kv_heads, scale, block_tables, seq_lens, + max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, + kv_head_stride, k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs, max_num_partitions). +template +__global__ void paged_attention_v2_kernel( + float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions] + float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size] + const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads, + // head_size/x, block_size, x] + const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads, + // head_size, block_size] + const int num_kv_heads, // [num_heads] + const float scale, + const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_blocks_per_seq, + const float* __restrict__ alibi_slopes, // [num_heads] + const int q_stride, const int kv_block_stride, const int kv_head_stride, + const float k_scale, const float v_scale, const int tp_rank, + const int blocksparse_local_blocks, const int blocksparse_vert_stride, + const int blocksparse_block_size, const int blocksparse_head_sliding_step) { + paged_attention_kernel( + exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale, + block_tables, seq_lens, max_num_blocks_per_seq, alibi_slopes, q_stride, + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step); +} + +// Grid: (num_heads, num_seqs). +template +__global__ void paged_attention_v2_reduce_kernel( + scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size] + const float* __restrict__ exp_sums, // [num_seqs, num_heads, + // max_num_partitions] + const float* __restrict__ max_logits, // [num_seqs, num_heads, + // max_num_partitions] + const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads, + // max_num_partitions, head_size] + const int* __restrict__ seq_lens, // [num_seqs] + const int max_num_partitions) { + const int num_heads = gridDim.x; + const int head_idx = blockIdx.x; + const int seq_idx = blockIdx.y; + const int seq_len = seq_lens[seq_idx]; + const int num_partitions = DIVIDE_ROUND_UP(seq_len, PARTITION_SIZE); + if (num_partitions == 1) { + // No need to reduce. Only copy tmp_out to out. + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + for (int i = threadIdx.x; i < HEAD_SIZE; i += blockDim.x) { + out_ptr[i] = tmp_out_ptr[i]; + } + // Terminate the thread block. + return; + } + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + const int warp_idx = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + + // Size: 2 * num_partitions. + extern __shared__ char shared_mem[]; + // Workspace for reduction. + __shared__ float red_smem[2 * NUM_WARPS]; + + // Load max logits to shared memory. + float* shared_max_logits = reinterpret_cast(shared_mem); + const float* max_logits_ptr = max_logits + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float max_logit = -FLT_MAX; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + const float l = max_logits_ptr[i]; + shared_max_logits[i] = l; + max_logit = fmaxf(max_logit, l); + } + __syncthreads(); + + // Get the global max logit. + // Reduce within the warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + if (lane == 0) { + red_smem[warp_idx] = max_logit; + } + __syncthreads(); + // Reduce across warps. + max_logit = lane < NUM_WARPS ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = NUM_WARPS / 2; mask >= 1; mask /= 2) { + max_logit = fmaxf(max_logit, VLLM_SHFL_XOR_SYNC(max_logit, mask)); + } + // Broadcast the max value to all threads. + max_logit = VLLM_SHFL_SYNC(max_logit, 0); + + // Load rescaled exp sums to shared memory. + float* shared_exp_sums = + reinterpret_cast(shared_mem + sizeof(float) * num_partitions); + const float* exp_sums_ptr = exp_sums + + seq_idx * num_heads * max_num_partitions + + head_idx * max_num_partitions; + float global_exp_sum = 0.0f; + for (int i = threadIdx.x; i < num_partitions; i += blockDim.x) { + float l = shared_max_logits[i]; + float rescaled_exp_sum = exp_sums_ptr[i] * expf(l - max_logit); + global_exp_sum += rescaled_exp_sum; + shared_exp_sums[i] = rescaled_exp_sum; + } + __syncthreads(); + global_exp_sum = block_sum(&red_smem[NUM_WARPS], global_exp_sum); + const float inv_global_exp_sum = __fdividef(1.0f, global_exp_sum + 1e-6f); + + // Aggregate tmp_out to out. + const scalar_t* tmp_out_ptr = + tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE + + head_idx * max_num_partitions * HEAD_SIZE; + scalar_t* out_ptr = + out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE; +#pragma unroll + for (int i = threadIdx.x; i < HEAD_SIZE; i += NUM_THREADS) { + float acc = 0.0f; + for (int j = 0; j < num_partitions; ++j) { + acc += to_float(tmp_out_ptr[j * HEAD_SIZE + i]) * shared_exp_sums[j] * + inv_global_exp_sum; + } + from_float(out_ptr[i], acc); + } +} + +} // namespace vllm + +#define LAUNCH_PAGED_ATTENTION_V1(HEAD_SIZE) \ + VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( \ + ((void*)vllm::paged_attention_v1_kernel), \ + shared_mem_size); \ + vllm::paged_attention_v1_kernel \ + <<>>( \ + out_ptr, query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, \ + scale, block_tables_ptr, seq_lens_ptr, max_num_blocks_per_seq, \ + alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +// TODO(woosuk): Tune NUM_THREADS. +template +#else + int NUM_THREADS = 128> +#endif +void paged_attention_v1_launcher( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int padded_max_seq_len = + DIVIDE_ROUND_UP(max_seq_len, BLOCK_SIZE) * BLOCK_SIZE; + int logits_size = padded_max_seq_len * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + // Python-side check in vllm.worker.worker._check_if_can_support_max_seq_len + // Keep that in sync with the logic here! + int shared_mem_size = std::max(logits_size, outputs_size); + + dim3 grid(num_heads, num_seqs, 1); + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V1(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V1(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V1(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V1(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V1(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V1(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V1(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V1(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v1_launcher( \ + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, \ + seq_lens, max_seq_len, alibi_slopes, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); + +#define CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V1_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V1_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V1_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v1( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V1_LAUNCHER_BLOCK_SIZE) +} + +#define LAUNCH_PAGED_ATTENTION_V2(HEAD_SIZE) \ + vllm::paged_attention_v2_kernel \ + <<>>( \ + exp_sums_ptr, max_logits_ptr, tmp_out_ptr, query_ptr, key_cache_ptr, \ + value_cache_ptr, num_kv_heads, scale, block_tables_ptr, \ + seq_lens_ptr, max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, \ + kv_block_stride, kv_head_stride, k_scale, v_scale, tp_rank, \ + blocksparse_local_blocks, blocksparse_vert_stride, \ + blocksparse_block_size, blocksparse_head_sliding_step); \ + vllm::paged_attention_v2_reduce_kernel \ + <<>>( \ + out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, seq_lens_ptr, \ + max_num_partitions); + +template +#else + int NUM_THREADS = 128, int PARTITION_SIZE = 512> +#endif +void paged_attention_v2_launcher( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int num_kv_heads, float scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int max_seq_len, + const c10::optional& alibi_slopes, float k_scale, + float v_scale, const int tp_rank, const int blocksparse_local_blocks, + const int blocksparse_vert_stride, const int blocksparse_block_size, + const int blocksparse_head_sliding_step) { + int num_seqs = query.size(0); + int num_heads = query.size(1); + int head_size = query.size(2); + int max_num_blocks_per_seq = block_tables.size(1); + int q_stride = query.stride(0); + int kv_block_stride = key_cache.stride(0); + int kv_head_stride = key_cache.stride(1); + + [[maybe_unused]] int thread_group_size = MAX(WARP_SIZE / BLOCK_SIZE, 1); + assert(head_size % thread_group_size == 0); + + // NOTE: alibi_slopes is optional. + const float* alibi_slopes_ptr = + alibi_slopes + ? reinterpret_cast(alibi_slopes.value().data_ptr()) + : nullptr; + + T* out_ptr = reinterpret_cast(out.data_ptr()); + float* exp_sums_ptr = reinterpret_cast(exp_sums.data_ptr()); + float* max_logits_ptr = reinterpret_cast(max_logits.data_ptr()); + T* tmp_out_ptr = reinterpret_cast(tmp_out.data_ptr()); + T* query_ptr = reinterpret_cast(query.data_ptr()); + CACHE_T* key_cache_ptr = reinterpret_cast(key_cache.data_ptr()); + CACHE_T* value_cache_ptr = reinterpret_cast(value_cache.data_ptr()); + int* block_tables_ptr = block_tables.data_ptr(); + int* seq_lens_ptr = seq_lens.data_ptr(); + + constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE; + int max_num_partitions = DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); + int logits_size = PARTITION_SIZE * sizeof(float); + int outputs_size = (NUM_WARPS / 2) * head_size * sizeof(float); + + // For paged attention v2 kernel. + dim3 grid(num_heads, num_seqs, max_num_partitions); + int shared_mem_size = std::max(logits_size, outputs_size); + // For paged attention v2 reduce kernel. + dim3 reduce_grid(num_heads, num_seqs); + int reduce_shared_mem_size = 2 * max_num_partitions * sizeof(float); + + dim3 block(NUM_THREADS); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + switch (head_size) { + // NOTE(woosuk): To reduce the compilation time, we only compile for the + // head sizes that we use in the model. However, we can easily extend this + // to support any head size which is a multiple of 16. + case 64: + LAUNCH_PAGED_ATTENTION_V2(64); + break; + case 80: + LAUNCH_PAGED_ATTENTION_V2(80); + break; + case 96: + LAUNCH_PAGED_ATTENTION_V2(96); + break; + case 112: + LAUNCH_PAGED_ATTENTION_V2(112); + break; + case 120: + LAUNCH_PAGED_ATTENTION_V2(120); + break; + case 128: + LAUNCH_PAGED_ATTENTION_V2(128); + break; + case 192: + LAUNCH_PAGED_ATTENTION_V2(192); + break; + case 256: + LAUNCH_PAGED_ATTENTION_V2(256); + break; + default: + TORCH_CHECK(false, "Unsupported head size: ", head_size); + break; + } +} + +#define CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, KV_DTYPE, IS_BLOCK_SPARSE) \ + paged_attention_v2_launcher( \ + out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \ + num_kv_heads, scale, block_tables, seq_lens, max_seq_len, alibi_slopes, \ + k_scale, v_scale, tp_rank, blocksparse_local_blocks, \ + blocksparse_vert_stride, blocksparse_block_size, \ + blocksparse_head_sliding_step); + +#define CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE) \ + switch (is_block_sparse) { \ + case true: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, true); \ + break; \ + case false: \ + CALL_V2_LAUNCHER(T, CACHE_T, BLOCK_SIZE, IS_FP8_KV_CACHE, false); \ + break; \ + } + +// NOTE(woosuk): To reduce the compilation time, we omitted block sizes +// 1, 2, 4, 64, 128, 256. +#define CALL_V2_LAUNCHER_BLOCK_SIZE(T, CACHE_T, KV_DTYPE) \ + switch (block_size) { \ + case 8: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 8, KV_DTYPE); \ + break; \ + case 16: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 16, KV_DTYPE); \ + break; \ + case 32: \ + CALL_V2_LAUNCHER_SPARSITY(T, CACHE_T, 32, KV_DTYPE); \ + break; \ + default: \ + TORCH_CHECK(false, "Unsupported block size: ", block_size); \ + break; \ + } + +void paged_attention_v2( + torch::Tensor& out, // [num_seqs, num_heads, head_size] + torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions] + torch::Tensor& + tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size] + torch::Tensor& query, // [num_seqs, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + int64_t num_kv_heads, // [num_heads] + double scale, + torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq] + torch::Tensor& seq_lens, // [num_seqs] + int64_t block_size, int64_t max_seq_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step) { + const bool is_block_sparse = (blocksparse_vert_stride > 1); + DISPATCH_BY_KV_CACHE_DTYPE(query.dtype(), kv_cache_dtype, + CALL_V2_LAUNCHER_BLOCK_SIZE) +} + +#undef WARP_SIZE +#undef MAX +#undef MIN +#undef DIVIDE_ROUND_UP diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_utils.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_utils.cuh new file mode 100644 index 000000000..f5a3f7039 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/attention_utils.cuh @@ -0,0 +1,57 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "hip_compat.h" +#include "attention_dtypes.h" + +#include +#include + +namespace vllm { + +// Q*K^T operation. +template +inline __device__ float qk_dot_(const Vec (&q)[N], const Vec (&k)[N]) { + using A_vec = typename FloatVec::Type; + // Compute the parallel products for Q*K^T (treat vector lanes separately). + A_vec qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = vllm::fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREAD_GROUP_SIZE / 2; mask >= 1; mask /= 2) { + qk += VLLM_SHFL_XOR_SYNC(qk, mask); + } + return qk; +} + +template +struct Qk_dot { + template + static inline __device__ float dot(const Vec (&q)[N], const Vec (&k)[N]) { + return qk_dot_(q, k); + } +}; + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache.h new file mode 100644 index 000000000..11c4c5001 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include +#include + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping); + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping); + +void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale); + +void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, + torch::Tensor& key_cache, + torch::Tensor& value_cache, + torch::Tensor& slot_mapping, + const std::string& kv_cache_dtype, + const double k_scale, const double v_scale); + +// Just for unittest +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache_kernels.cu new file mode 100644 index 000000000..bfd054811 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/cache_kernels.cu @@ -0,0 +1,405 @@ +#include +#include +#include + +#include "hip_compat.h" +#include "dispatch_utils.h" + +#ifdef USE_ROCM + #include "quant_utils.cuh" +#else + #include "quantization/fp8/nvidia/quant_utils.cuh" +#endif + +#include +#include +#include +#include + +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +void swap_blocks(torch::Tensor& src, torch::Tensor& dst, + const torch::Tensor& block_mapping) { + torch::Device src_device = src.device(); + torch::Device dst_device = dst.device(); + cudaMemcpyKind memcpy_type; + if (src_device.is_cuda() && dst_device.is_cuda()) { + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + memcpy_type = cudaMemcpyDeviceToDevice; + } else if (src_device.is_cuda() && dst_device.is_cpu()) { + memcpy_type = cudaMemcpyDeviceToHost; + } else if (src_device.is_cpu() && dst_device.is_cuda()) { + memcpy_type = cudaMemcpyHostToDevice; + } else { + TORCH_CHECK(false, "Invalid device combination"); + } + + // NOTE(youkaichao): keep in mind that `block_mapping` should be + // a cpu tensor, otherwise every `item` call will require a gpu-cpu + // synchronization. + TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU"); + + char* src_ptr = static_cast(src.data_ptr()); + char* dst_ptr = static_cast(dst.data_ptr()); + + const int64_t block_size_in_bytes = src.element_size() * src[0].numel(); + const at::cuda::OptionalCUDAGuard device_guard( + src_device.is_cuda() ? src_device : dst_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + // NOTE(woosuk): This can be slow if the number of blocks is large. + const int64_t num_blocks = block_mapping.size(0); + for (size_t i = 0; i < num_blocks; i++) { + int64_t src_block_number = block_mapping[i][0].item(); + int64_t dst_block_number = block_mapping[i][1].item(); + int64_t src_offset = src_block_number * block_size_in_bytes; + int64_t dst_offset = dst_block_number * block_size_in_bytes; + cudaMemcpyAsync(dst_ptr + dst_offset, src_ptr + src_offset, + block_size_in_bytes, memcpy_type, stream); + } +} + +namespace vllm { + +// Grid: (num_layers, num_pairs) +template +__global__ void copy_blocks_kernel(int64_t* key_cache_ptrs, + int64_t* value_cache_ptrs, + const int64_t* __restrict__ block_mapping, + const int numel_per_block) { + const int layer_idx = blockIdx.x; + const int pair_idx = blockIdx.y; + + scalar_t* key_cache = reinterpret_cast(key_cache_ptrs[layer_idx]); + scalar_t* value_cache = + reinterpret_cast(value_cache_ptrs[layer_idx]); + int64_t src_block_number = block_mapping[2 * pair_idx]; + int64_t dst_block_number = block_mapping[2 * pair_idx + 1]; + + const int64_t src_block_offset = src_block_number * numel_per_block; + const int64_t dst_block_offset = dst_block_number * numel_per_block; + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + key_cache[dst_offset] = key_cache[src_offset]; + } + for (int i = threadIdx.x; i < numel_per_block; i += blockDim.x) { + int64_t src_offset = src_block_offset + i; + int64_t dst_offset = dst_block_offset + i; + value_cache[dst_offset] = value_cache[src_offset]; + } +} + +} // namespace vllm + +// Note: the key_caches and value_caches vectors are constant but +// not the Tensors they contain. The vectors need to be const refs +// in order to satisfy pytorch's C++ operator registration code. +void copy_blocks(std::vector const& key_caches, + std::vector const& value_caches, + const torch::Tensor& block_mapping) { + int num_layers = key_caches.size(); + TORCH_CHECK(num_layers == value_caches.size()); + if (num_layers == 0) { + return; + } + torch::Device cache_device = key_caches[0].device(); + TORCH_CHECK(cache_device.is_cuda()); + + // Create data structures for the kernel. + // Create an array of pointers to the key and value caches. + int64_t key_cache_ptrs[num_layers]; + int64_t value_cache_ptrs[num_layers]; + for (int layer_idx = 0; layer_idx < num_layers; ++layer_idx) { + key_cache_ptrs[layer_idx] = + reinterpret_cast(key_caches[layer_idx].data_ptr()); + value_cache_ptrs[layer_idx] = + reinterpret_cast(value_caches[layer_idx].data_ptr()); + } + + // block_mapping is a 2D tensor with shape (num_pairs, 2). + int num_pairs = block_mapping.size(0); + + // Move the data structures to the GPU. + // NOTE: This synchronizes the CPU and GPU. + torch::Tensor key_cache_ptrs_tensor = + torch::from_blob(key_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + torch::Tensor value_cache_ptrs_tensor = + torch::from_blob(value_cache_ptrs, {num_layers}, torch::kInt64) + .to(cache_device); + + // Launch the kernel. + const int numel_per_block = key_caches[0][0].numel(); + dim3 grid(num_layers, num_pairs); + dim3 block(std::min(1024, numel_per_block)); + const at::cuda::OptionalCUDAGuard device_guard(cache_device); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES( + key_caches[0].scalar_type(), "copy_blocks_kernel", ([&] { + vllm::copy_blocks_kernel<<>>( + key_cache_ptrs_tensor.data_ptr(), + value_cache_ptrs_tensor.data_ptr(), + block_mapping.data_ptr(), numel_per_block); + })); +} + +namespace vllm { + +template +__global__ void reshape_and_cache_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, num_heads, head_size/x, + // block_size, x] + cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, + // block_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int key_stride, const int value_stride, const int num_heads, + const int head_size, const int block_size, const int x, const float k_scale, + const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + if (slot_idx < 0) { + // Padding token that should be ignored. + return; + } + + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int x_idx = head_offset / x; + const int x_offset = head_offset % x; + + const int64_t tgt_key_idx = + block_idx * num_heads * (head_size / x) * block_size * x + + head_idx * (head_size / x) * block_size * x + x_idx * block_size * x + + block_offset * x + x_offset; + const int64_t tgt_value_idx = + block_idx * num_heads * head_size * block_size + + head_idx * head_size * block_size + head_offset * block_size + + block_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_idx] = tgt_key; + value_cache[tgt_value_idx] = tgt_value; + } else { + key_cache[tgt_key_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } + } +} + +template +__global__ void reshape_and_cache_flash_kernel( + const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] + const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size] + cache_t* __restrict__ key_cache, // [num_blocks, block_size, num_heads, + // head_size] + cache_t* __restrict__ value_cache, // [num_blocks, block_size, num_heads, + // head_size] + const int64_t* __restrict__ slot_mapping, // [num_tokens] + const int block_stride, const int key_stride, const int value_stride, + const int num_heads, const int head_size, const int block_size, + const float k_scale, const float v_scale) { + const int64_t token_idx = blockIdx.x; + const int64_t slot_idx = slot_mapping[token_idx]; + // NOTE: slot_idx can be -1 if the token is padded + if (slot_idx < 0) { + return; + } + const int64_t block_idx = slot_idx / block_size; + const int64_t block_offset = slot_idx % block_size; + const int n = num_heads * head_size; + for (int i = threadIdx.x; i < n; i += blockDim.x) { + const int64_t src_key_idx = token_idx * key_stride + i; + const int64_t src_value_idx = token_idx * value_stride + i; + const int head_idx = i / head_size; + const int head_offset = i % head_size; + const int64_t tgt_key_value_idx = block_idx * block_stride + + block_offset * num_heads * head_size + + head_idx * head_size + head_offset; + scalar_t tgt_key = key[src_key_idx]; + scalar_t tgt_value = value[src_value_idx]; + if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) { + key_cache[tgt_key_value_idx] = tgt_key; + value_cache[tgt_key_value_idx] = tgt_value; + } else { + key_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_key, k_scale); + value_cache[tgt_key_value_idx] = + fp8::scaled_convert(tgt_value, v_scale); + } + } +} +} // namespace vllm + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), key_stride, value_stride, \ + num_heads, head_size, block_size, x, k_scale, v_scale); + +void reshape_and_cache( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& + key_cache, // [num_blocks, num_heads, head_size/x, block_size, x] + torch::Tensor& + value_cache, // [num_blocks, num_heads, head_size, block_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(3); + int x = key_cache.size(4); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE) +} + +// KV_T is the stored data type of kv-cache. +// CACHE_T is the data type of key and value tensors. +// KV_DTYPE is the real data type of kv-cache. +#define CALL_RESHAPE_AND_CACHE_FLASH(KV_T, CACHE_T, KV_DTYPE) \ + vllm::reshape_and_cache_flash_kernel \ + <<>>( \ + reinterpret_cast(key.data_ptr()), \ + reinterpret_cast(value.data_ptr()), \ + reinterpret_cast(key_cache.data_ptr()), \ + reinterpret_cast(value_cache.data_ptr()), \ + slot_mapping.data_ptr(), block_stride, key_stride, \ + value_stride, num_heads, head_size, block_size, k_scale, v_scale); + +void reshape_and_cache_flash( + torch::Tensor& key, // [num_tokens, num_heads, head_size] + torch::Tensor& value, // [num_tokens, num_heads, head_size] + torch::Tensor& key_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& + value_cache, // [num_blocks, block_size, num_heads, head_size] + torch::Tensor& slot_mapping, // [num_tokens] + const std::string& kv_cache_dtype, const double k_scale, + const double v_scale) { + int num_tokens = key.size(0); + int num_heads = key.size(1); + int head_size = key.size(2); + int block_size = key_cache.size(1); + + int key_stride = key.stride(0); + int value_stride = value.stride(0); + int block_stride = key_cache.stride(0); + TORCH_CHECK(key_cache.stride(0) == value_cache.stride(0)); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * head_size, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(key)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype, + CALL_RESHAPE_AND_CACHE_FLASH); +} + +namespace vllm { + +template +__global__ void convert_fp8_kernel(const Tin* __restrict__ src_cache, + Tout* __restrict__ dst_cache, + const float scale, + const int64_t block_stride) { + const int64_t block_idx = blockIdx.x; + for (int i = threadIdx.x; i < block_stride; i += blockDim.x) { + int64_t idx = block_idx * block_stride + i; + dst_cache[idx] = + fp8::scaled_convert(src_cache[idx], scale); + } +} + +} // namespace vllm + +#define CALL_CONVERT_FP8(Tout, Tin, KV_DTYPE) \ + vllm::convert_fp8_kernel<<>>( \ + reinterpret_cast(src_cache.data_ptr()), \ + reinterpret_cast(dst_cache.data_ptr()), scale, block_stride); + +// Only for testing. +void convert_fp8(torch::Tensor& dst_cache, torch::Tensor& src_cache, + const double scale, const std::string& kv_cache_dtype) { + torch::Device src_device = src_cache.device(); + torch::Device dst_device = dst_cache.device(); + TORCH_CHECK(src_device.is_cuda(), "src must be on a GPU") + TORCH_CHECK(dst_device.is_cuda(), "dst must be on a GPU") + TORCH_CHECK(src_device.index() == dst_device.index(), + "src and dst must be on the same GPU"); + at::cuda::OptionalCUDAGuard device_guard(src_device); + + int64_t num_blocks = src_cache.size(0); + int64_t block_stride = src_cache.stride(0); + + dim3 grid(num_blocks); + dim3 block(std::min(block_stride, int64_t(512))); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + if (kv_cache_dtype == "auto") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kAuto); + } + } else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") { + if (src_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(uint8_t, float, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint8_t, uint16_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (src_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(uint8_t, __nv_bfloat16, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Float) { + CALL_CONVERT_FP8(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::Half) { + CALL_CONVERT_FP8(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); + } else if (dst_cache.dtype() == at::ScalarType::BFloat16) { + CALL_CONVERT_FP8(__nv_bfloat16, uint8_t, + vllm::Fp8KVCacheDataType::kFp8E4M3); + } + } else { + TORCH_CHECK(false, "Unsupported data type: ", kv_cache_dtype); + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/ck_py_interface.cpp b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/ck_py_interface.cpp new file mode 100644 index 000000000..d49e73d2a --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/ck_py_interface.cpp @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2024 Advanced Micro Devices, Inc. All rights reserved. + * + * @Script: ck_py_interface.cpp + * @Author: valarLip + * @Email: lingpeng.jin@amd.com + * @Create At: 2024-10-24 12:11:58 + * @Last Modified By: valarLip + * @Last Modified At: 2024-10-24 14:39:16 + * @Description: This is description. + */ + +#include +#include + +#if defined(FIND_CK) +#include "layernorm2d_fwd.hpp" + +// common utility functions +#define FOREACH_BUFFER_TORCH_TYPE_MAP(F) \ + F("fp32", torch::kFloat) \ + F("fp16", torch::kHalf) \ + F("bf16", torch::kBFloat16) + +inline std::string torchDTypeToStr(caffe2::TypeMeta dtype) +{ +#define TYPE_CASE(type, torch_type) \ + case torch_type: \ + { \ + return type; \ + } + + switch (dtype.toScalarType()) + { + FOREACH_BUFFER_TORCH_TYPE_MAP(TYPE_CASE); + default: + throw std::runtime_error("Unsupported data type " + std::to_string((int8_t)(dtype.toScalarType()))); + } + +#undef TYPE_CASE +} + +void layernorm2d(torch::Tensor &out, // [hidden_size] + torch::Tensor &input, // [hidden_size] + torch::Tensor &weight, // [hidden_size] + torch::Tensor &bias, // [hidden_size] + double epsilon) +{ + // auto dtype = input.dtype(); + // TORCH_CHECK(dtype == torch::kFloat16 || dtype == torch::kBFloat16, + // "FlashAttention only support fp16 and bf16 data type"); + + std::string dtype_str = torchDTypeToStr(input.dtype()); + int n = input.size(-1); + int m = input.numel() / n; + int stride = n; + bool SaveMeanVar = false; + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + layernorm2d_fwd({dtype_str, SaveMeanVar}, + {input.data_ptr(), + weight.data_ptr(), + bias.data_ptr(), + out.data_ptr(), + nullptr, + nullptr, + static_cast(epsilon), + m, + n, + stride}, + {stream}); +} +#endif // diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom.cu new file mode 100644 index 000000000..a01d9a284 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom.cu @@ -0,0 +1,78 @@ +#include +#include +#include + +// declare templates for front (cpp) and back (cuda) sides of function: +// template + +// void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, +// cudaStream_t stream, const int rows_per_block); +// void LLMM_Silu(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, +// const int64_t rows_per_block) { +// auto M = in_a.size(0); +// auto K = in_a.size(1); +// LLGemm_Silu(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, +// at::cuda::getCurrentCUDAStream(), rows_per_block); +// } + +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block); + +// template +void LLMM1(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block) { + auto M = in_a.size(0); + auto K = in_a.size(1); + // if (N != in_b.numel()) + // throw std::invalid_argument("Size mismatch A.numel(): " + + // std::to_string(in_a.numel()) + // + ", B.numel(): " + + // std::to_string(in_b.numel())); + + // out_c.resize_({N}); + + // call the kernel function... + LLGemm1(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), rows_per_block); +} + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M, const int K, + const int N, cudaStream_t stream, const int CuCount); + +void wvSpltK(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t N_in, const int64_t CuCount) { + auto M = in_a.size(0); + auto K = in_a.size(1); + int N = N_in; + wvSpltK_(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, N, + at::cuda::getCurrentCUDAStream(), CuCount); +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx); + +void LLZZ(at::Tensor in_a, at::Tensor in_b, at::Tensor out_c, + const int64_t solidx = 0) { + auto M = in_a.size(0); + auto K = in_a.size(1); + + LLGemmZZ(in_a.data_ptr(), in_b.data_ptr(), out_c.data_ptr(), M, K, + at::cuda::getCurrentCUDAStream(), solidx); +} +// instantiate the CPP template for T=float: +// template void AddGPU(at::Tensor in_a, at::Tensor in_b, at::Tensor +// out_c); + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream); + +void MMCustomGPU(at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c) { + auto matA_sizes{in_a.sizes()}; + auto matB_sizes{in_b.sizes()}; + auto matO_sizes{out_c.sizes()}; + MMGPUKernel(in_a.data_ptr(), in_b.data_ptr(), + out_c.data_ptr(), matA_sizes[0], matA_sizes[1], + matB_sizes[0], matB_sizes[1], matO_sizes[0], matO_sizes[1], + at::cuda::getCurrentCUDAStream()); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cu new file mode 100644 index 000000000..8152cb3d3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cu @@ -0,0 +1,175 @@ +#include +#include +#include +#include + +#include "custom_all_reduce.cuh" + +// fake pointer type, must match fptr_t type in ops.h +using fptr_t = int64_t; +static_assert(sizeof(void*) == sizeof(fptr_t)); + +fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data, + const std::vector& handles, + const std::vector& offsets, int64_t rank, + bool full_nvlink) { + int world_size = offsets.size(); + if (world_size > 8) + throw std::invalid_argument("world size > 8 is not supported"); + if (world_size % 2 != 0) + throw std::invalid_argument("Odd num gpus is not supported for now"); + if (world_size != handles.size()) + throw std::invalid_argument( + "handles length should equal to offsets length"); + if (rank < 0 || rank >= world_size) + throw std::invalid_argument("invalid rank passed in"); + + cudaIpcMemHandle_t ipc_handles[8]; + for (int i = 0; i < world_size; i++) { + std::memcpy(&ipc_handles[i], handles[i].data(), sizeof(cudaIpcMemHandle_t)); + } + return (fptr_t) new vllm::CustomAllreduce( + reinterpret_cast(meta.data_ptr()), rank_data.data_ptr(), + rank_data.numel(), ipc_handles, offsets, rank, full_nvlink); +} + +/** + * Make sure tensor t's data lies completely within ((char)t.data_ptr()) + + * t.numel() * t.element_size(). This is slightly weaker than t.is_contiguous() + * because it allows transpose of contiguous slice (i.e. slicing the first + * dimension). Currently, we require this because stride information is not + * passed into the kernels and we treat input tensors as flat. + * + * Examples + * A = torch.zeros(3, 3, 3) + * 1. A: OK + * 2. A[1:]: OK + * 3. A.permute(2, 0, 1): OK + * 4. A[1:].permute(2, 0, 1): OK + * 5. A[None].expand(2, -1, -1, -1): Not OK + * 6. A[:, 1:, 1:]: Not OK + */ +bool _is_weak_contiguous(torch::Tensor& t) { + return t.is_contiguous() || + (t.storage().nbytes() - t.storage_offset() * t.element_size() == + t.numel() * t.element_size()); +} + +void _all_reduce(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out, + cudaStream_t stream) { + auto fa = reinterpret_cast(_fa); + TORCH_CHECK(_is_weak_contiguous(out)); + switch (out.scalar_type()) { + case at::ScalarType::Float: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), + out.numel()); + break; + } + case at::ScalarType::Half: { + fa->allreduce(stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) + case at::ScalarType::BFloat16: { + fa->allreduce( + stream, reinterpret_cast(inp.data_ptr()), + reinterpret_cast(out.data_ptr()), out.numel()); + break; + } +#endif + default: + throw std::runtime_error( + "custom allreduce only supports float32, float16 and bfloat16"); + } +} + +void all_reduce_reg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + _all_reduce(_fa, inp, out, stream); +} + +void all_reduce_unreg(fptr_t _fa, torch::Tensor& inp, torch::Tensor& reg_buffer, + torch::Tensor& out) { + const at::cuda::OptionalCUDAGuard device_guard(device_of(inp)); + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + + auto input_size = inp.numel() * inp.element_size(); + TORCH_CHECK_EQ(inp.scalar_type(), out.scalar_type()); + TORCH_CHECK_EQ(inp.numel(), out.numel()); + TORCH_CHECK(input_size <= reg_buffer.numel() * reg_buffer.element_size(), + "registered buffer is too small to contain the input"); + AT_CUDA_CHECK(cudaMemcpyAsync(reg_buffer.data_ptr(), inp.data_ptr(), + input_size, cudaMemcpyDeviceToDevice, stream)); + _all_reduce(_fa, reg_buffer, out, stream); +} + +void dispose(fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + delete fa; +} + +int64_t meta_size() { return sizeof(vllm::Signal); } + +void register_buffer(fptr_t _fa, torch::Tensor& t, + const std::vector& handles, + const std::vector& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_buffer(handles, offsets, t.data_ptr()); +} + +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa) { + auto fa = reinterpret_cast(_fa); + auto [handle_bytes, offsets] = fa->get_graph_buffer_ipc_meta(); + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto handles = + torch::empty({static_cast(handle_bytes.size())}, options); + std::memcpy(handles.data_ptr(), handle_bytes.data(), handle_bytes.size()); + return {handles, std::move(offsets)}; +} + +void register_graph_buffers(fptr_t _fa, const std::vector& handles, + const std::vector>& offsets) { + auto fa = reinterpret_cast(_fa); + fa->register_graph_buffers(handles, offsets); +} + +#ifdef USE_ROCM + +void free_meta_buffer(void* buffer) { CUDACHECK(cudaFree(buffer)); } + +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor& inp) { + auto options = + torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU); + auto data_handle = + torch::empty({static_cast(sizeof(cudaIpcMemHandle_t))}, options); + CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data_ptr(), + inp.data_ptr())); + return data_handle; +} + +torch::Tensor allocate_meta_buffer(int64_t size) { + auto device_index = c10::cuda::current_device(); + at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index)); + void* buffer; + cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed; + auto stream = c10::cuda::getCurrentCUDAStream().stream(); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + AT_CUDA_CHECK( + hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached)); + AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream)); + AT_CUDA_CHECK(cudaStreamSynchronize(stream)); + AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode)); + auto options = torch::TensorOptions() + .dtype(torch::kI8) + .device(torch::kCUDA, device_index); + return torch::from_blob(buffer, {size}, free_meta_buffer, options); +} + +#endif diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cuh new file mode 100644 index 000000000..c640b15a2 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_all_reduce.cuh @@ -0,0 +1,534 @@ +#pragma once + +#include +#ifdef USE_ROCM + #include +typedef __hip_bfloat16 nv_bfloat16; +#else + #include +#endif +#include +#include + +#include +#include +#include +#include +#include + +#define CUDACHECK(cmd) \ + do { \ + cudaError_t e = cmd; \ + if (e != cudaSuccess) { \ + printf("Failed: Cuda error %s:%d '%s'\n", __FILE__, __LINE__, \ + cudaGetErrorString(e)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +namespace vllm { + +constexpr int kMaxBlocks = 64; +// note: we don't want to use atomics for signals because peer atomics are no +// supported on PCIe links +struct Signal { + alignas(128) uint32_t start[kMaxBlocks][8]; + alignas(128) uint32_t end[kMaxBlocks][8]; + alignas(128) uint32_t _flag[kMaxBlocks]; // incremental flags for each rank +}; + +#ifdef USE_ROCM +struct __align__(16) RankData { const void* ptrs[8]; }; +#else +struct __align__(16) RankData { const void* __restrict__ ptrs[8]; }; +#endif + +struct __align__(16) RankSignals { volatile Signal* signals[8]; }; + +// like std::array, but aligned +template +struct __align__(alignof(T) * sz) array_t { + T data[sz]; + using type = T; + static constexpr int size = sz; +}; + +// use packed type to maximize memory efficiency +// goal: generate ld.128 and st.128 instructions +template +struct packed_t { + // the (P)acked type for load/store + using P = array_t; + // the (A)ccumulator type for reduction + using A = array_t; +}; + +#define DINLINE __device__ __forceinline__ + +// scalar cast functions +DINLINE float upcast_s(half val) { return __half2float(val); } + +template +DINLINE T downcast_s(float val); +template <> +DINLINE half downcast_s(float val) { + return __float2half(val); +} + +// scalar add functions +// for some reason when compiling with Pytorch, the + operator for half and +// bfloat is disabled so we call the intrinsics directly +DINLINE half& assign_add(half& a, half b) { + a = __hadd(a, b); + return a; +} +DINLINE float& assign_add(float& a, float b) { return a += b; } + +#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) +DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); } +template <> +DINLINE nv_bfloat16 downcast_s(float val) { + return __float2bfloat16(val); +} +DINLINE nv_bfloat16& assign_add(nv_bfloat16& a, nv_bfloat16 b) { + a = __hadd(a, b); + return a; +} +#endif + +template +DINLINE array_t& packed_assign_add(array_t& a, array_t b) { +#pragma unroll + for (int i = 0; i < N; i++) { + assign_add(a.data[i], b.data[i]); + } + return a; +} + +template +DINLINE array_t upcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + array_t out; +#pragma unroll + for (int i = 0; i < N; i++) { + out.data[i] = upcast_s(val.data[i]); + } + return out; + } +} + +template +DINLINE O downcast(array_t val) { + if constexpr (std::is_same::value) { + return val; + } else { + O out; +#pragma unroll + for (int i = 0; i < O::size; i++) { + out.data[i] = downcast_s(val.data[i]); + } + return out; + } +} + +// This function is meant to be used as the first synchronization in the all +// reduce kernel. Thus, it doesn't need to make any visibility guarantees for +// prior memory accesses. Note: volatile writes will not be reordered against +// other volatile writes. +template +DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg, + int rank) { +#ifdef USE_ROCM + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, + __ATOMIC_RELAXED); + // wait until we got true from all ranks + while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], + __ATOMIC_RELAXED) < flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +#else + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->end[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->start[blockIdx.x][threadIdx.x]); + } + __syncthreads(); +#endif +} + +// This function is meant to be used as the second or the final synchronization +// barrier in the all reduce kernel. If it's the final synchronization barrier, +// we don't need to make any visibility guarantees for prior memory accesses. +template +DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg, + int rank) { +#ifdef USE_ROCM + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + uint32_t flag = self_sg->_flag[blockIdx.x] + 1; + if (threadIdx.x < ngpus) { + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag, + final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE); + // wait until we got true from all ranks + while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], + final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) < + flag); + } + __syncthreads(); + // use one thread to update flag + if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag; +#else + __syncthreads(); + // eliminate the case that prior writes are not visible after signals become + // visible. Note that I did not managed to make this happen through a lot of + // testing. Might be the case that hardware provides stronger guarantee than + // the memory model. + if constexpr (!final_sync) __threadfence_system(); + if (threadIdx.x < ngpus) { + // reset flag for next time + self_sg->start[blockIdx.x][threadIdx.x] = 0; + // simultaneously write to the corresponding flag of all ranks. + // Latency = 1 p2p write + sg.signals[threadIdx.x]->end[blockIdx.x][rank] = 1; + // wait until we got true from all ranks + while (!self_sg->end[blockIdx.x][threadIdx.x]); + } + if constexpr (!final_sync) __syncthreads(); +#endif +} + +template +DINLINE P packed_reduce(const P* ptrs[], int idx) { + A tmp = upcast(ptrs[0][idx]); +#pragma unroll + for (int i = 1; i < ngpus; i++) { + packed_assign_add(tmp, upcast(ptrs[i][idx])); + } + return downcast

(tmp); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_1stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, + int rank, int size) { + using P = typename packed_t::P; + using A = typename packed_t::A; + // note: we don't reorder the address so the accumulation order is the same + // for all ranks, ensuring bitwise identical results + auto dp = *_dp; + start_sync(sg, self_sg, rank); + // do the actual reduction + for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; + idx += gridDim.x * blockDim.x) { + ((P*)result)[idx] = packed_reduce((const P**)&dp.ptrs[0], idx); + } + end_sync(sg, self_sg, rank); +} + +template +DINLINE P* get_tmp_buf(volatile Signal* sg) { + return (P*)(((Signal*)sg) + 1); +} + +template +__global__ void __launch_bounds__(512, 1) + cross_device_reduce_2stage(RankData* _dp, RankSignals sg, + volatile Signal* self_sg, T* __restrict__ result, + int rank, int size) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = gridDim.x * blockDim.x; + using P = typename packed_t::P; + using A = typename packed_t::A; + int part = size / ngpus; + int start = rank * part; + int end = rank == ngpus - 1 ? size : start + part; + int largest_part = part + size % ngpus; + const P* ptrs[ngpus]; + P* tmps[ngpus]; +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int target = (rank + i) % ngpus; + ptrs[i] = (const P*)_dp->ptrs[target]; + tmps[i] = get_tmp_buf

(sg.signals[target]); + } + auto tmp_out = tmps[0]; + start_sync(sg, self_sg, rank); + // stage 1: reduce scatter + for (int idx = start + tid; idx < end; idx += stride) { + tmp_out[idx - start] = packed_reduce(ptrs, idx); + } + end_sync(sg, self_sg, rank); + + // stage 2: allgather. Note: it's important to match the tid between + // the two stages, because visibility across devices is only guaranteed + // between threads that have the same tid. If thread i computes the sum of + // start + i in the first stage, then thread i also gathers start + i from all + // ranks. + for (int idx = tid; idx < largest_part; idx += stride) { +#pragma unroll + for (int i = 0; i < ngpus; i++) { + int gather_from_rank = ((rank + i) % ngpus); + if (gather_from_rank == ngpus - 1 || idx < part) { + int dst_idx = gather_from_rank * part + idx; + ((P*)result)[dst_idx] = tmps[i][idx]; + } + } + } +} + +using IPC_KEY = std::array; +static_assert(sizeof(IPC_KEY) == sizeof(cudaIpcMemHandle_t)); +static_assert(alignof(IPC_KEY) == alignof(cudaIpcMemHandle_t)); + +class CustomAllreduce { + public: + int rank_; + int world_size_; + bool full_nvlink_; + + // below are device pointers + RankSignals sg_; + std::unordered_map buffers_; + Signal* self_sg_; + + // stores the registered device pointers from all ranks + RankData *d_rank_data_base_, *d_rank_data_end_; + std::vector graph_unreg_buffers_; + // a map from IPC handles to opened IPC pointers + std::map ipc_handles_; + + /** + * meta is a pointer to device metadata and temporary buffer for allreduce. + * + * There's a total of sizeof(Signal) of prefix before the actual data, + * so meta + 1 points to actual temporary buffer. + * + * note: this class does not own any device memory. Any required buffers + * are passed in from the constructor + */ + CustomAllreduce(Signal* meta, void* rank_data, size_t rank_data_sz, + const cudaIpcMemHandle_t* handles, + const std::vector& offsets, int rank, + bool full_nvlink = true) + : rank_(rank), + world_size_(offsets.size()), + full_nvlink_(full_nvlink), + self_sg_(meta), + d_rank_data_base_(reinterpret_cast(rank_data)), + d_rank_data_end_(d_rank_data_base_ + rank_data_sz / sizeof(RankData)) { + for (int i = 0; i < world_size_; i++) { + Signal* rank_sg; + if (i != rank_) { + char* handle = open_ipc_handle(&handles[i]); + handle += offsets[i]; + rank_sg = (Signal*)handle; + } else { + rank_sg = self_sg_; + } + sg_.signals[i] = rank_sg; + } + } + + char* open_ipc_handle(const void* ipc_handle) { + auto [it, new_handle] = + ipc_handles_.insert({*((IPC_KEY*)ipc_handle), nullptr}); + if (new_handle) { + char* ipc_ptr; + CUDACHECK(cudaIpcOpenMemHandle((void**)&ipc_ptr, + *((const cudaIpcMemHandle_t*)ipc_handle), + cudaIpcMemLazyEnablePeerAccess)); + it->second = ipc_ptr; + } + return it->second; + } + + std::pair, std::vector> + get_graph_buffer_ipc_meta() { + auto num_buffers = graph_unreg_buffers_.size(); + auto handle_sz = sizeof(cudaIpcMemHandle_t); + std::vector handles(handle_sz * num_buffers, 0); + std::vector offsets(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto ptr = graph_unreg_buffers_[i]; + void* base_ptr; + // note: must share the base address of each allocation, or we get wrong + // address + if (cuPointerGetAttribute(&base_ptr, +#ifdef USE_ROCM + HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#else + CU_POINTER_ATTRIBUTE_RANGE_START_ADDR, +#endif + (CUdeviceptr)ptr) != CUDA_SUCCESS) + throw std::runtime_error("failed to get pointer attr"); + CUDACHECK(cudaIpcGetMemHandle( + (cudaIpcMemHandle_t*)&handles[i * handle_sz], base_ptr)); + offsets[i] = ((char*)ptr) - ((char*)base_ptr); + } + return std::make_pair(handles, offsets); + } + + void check_rank_data_capacity(size_t num = 1) { + if (d_rank_data_base_ + num > d_rank_data_end_) + throw std::runtime_error( + "Rank data buffer is overflowed by " + + std::to_string(d_rank_data_base_ + num - d_rank_data_end_)); + } + + void register_buffer(const std::vector& handles, + const std::vector& offsets, void* self) { + check_rank_data_capacity(); + RankData data; + for (int i = 0; i < world_size_; i++) { + if (i != rank_) { + char* handle = open_ipc_handle(handles[i].data()); + handle += offsets[i]; + data.ptrs[i] = handle; + } else { + data.ptrs[i] = self; + } + } + auto d_data = d_rank_data_base_++; + CUDACHECK( + cudaMemcpy(d_data, &data, sizeof(RankData), cudaMemcpyHostToDevice)); + buffers_[self] = d_data; + } + + // note: when registering graph buffers, we intentionally choose to not + // deduplicate the addresses. That means if the allocator reuses some + // addresses, they will be registered again. This is to account for the remote + // possibility of different allocation patterns between ranks. For example, + // rank 1 may get the same input address for the second allreduce, but rank 2 + // got a different address. IPC handles have internal reference counting + // mechanism so overhead should be small. + void register_graph_buffers( + const std::vector& handles, + const std::vector>& offsets) { + auto num_buffers = graph_unreg_buffers_.size(); + check_rank_data_capacity(num_buffers); + std::vector rank_data(num_buffers); + for (int i = 0; i < num_buffers; i++) { + auto self_ptr = graph_unreg_buffers_[i]; + auto& rd = rank_data[i]; + for (int j = 0; j < world_size_; j++) { + if (j != rank_) { + char* handle = + open_ipc_handle(&handles[j][i * sizeof(cudaIpcMemHandle_t)]); + handle += offsets[j][i]; + rd.ptrs[j] = handle; + } else { + rd.ptrs[j] = self_ptr; + } + } + } + CUDACHECK(cudaMemcpy(d_rank_data_base_, rank_data.data(), + sizeof(RankData) * num_buffers, + cudaMemcpyHostToDevice)); + d_rank_data_base_ += num_buffers; + graph_unreg_buffers_.clear(); + } + + /** + * This is the result after careful grid search. Using 36 blocks give the best + * or close to the best runtime on the devices I tried: A100, A10, A30, T4, + * V100. You'll notice that NCCL kernels also only take a small amount of SMs. + * Not quite sure the underlying reason, but my guess is that too many SMs + * will cause contention on NVLink bus. + */ + template + void allreduce(cudaStream_t stream, T* input, T* output, int size, + int threads = 512, int block_limit = 36) { + auto d = packed_t::P::size; + if (size % d != 0) + throw std::runtime_error( + "custom allreduce currently requires input length to be multiple " + "of " + + std::to_string(d)); + if (block_limit > kMaxBlocks) + throw std::runtime_error("max supported block limit is " + + std::to_string(kMaxBlocks) + ". Got " + + std::to_string(block_limit)); + + RankData* ptrs; + cudaStreamCaptureStatus status; + CUDACHECK(cudaStreamIsCapturing(stream, &status)); + if (status == cudaStreamCaptureStatusActive) { + ptrs = d_rank_data_base_ + graph_unreg_buffers_.size(); + graph_unreg_buffers_.push_back(input); + } else { + auto it = buffers_.find(input); + if (it == buffers_.end()) + throw std::runtime_error( + "buffer address " + + std::to_string(reinterpret_cast(input)) + + " is not registered!"); + ptrs = it->second; + } + + size /= d; + auto bytes = size * sizeof(typename packed_t::P); + int blocks = std::min(block_limit, (size + threads - 1) / threads); +#define KL(ngpus, name) \ + name<<>>(ptrs, sg_, self_sg_, output, \ + rank_, size); +#define REDUCE_CASE(ngpus) \ + case ngpus: { \ + if (world_size_ == 2) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else if (full_nvlink_) { \ + if ((world_size_ <= 4 && bytes < 512 * 1024) || \ + (world_size_ <= 8 && bytes < 256 * 1024)) { \ + KL(ngpus, cross_device_reduce_1stage); \ + } else { \ + KL(ngpus, cross_device_reduce_2stage); \ + } \ + } \ + break; \ + } + + switch (world_size_) { + REDUCE_CASE(2) + REDUCE_CASE(4) + REDUCE_CASE(6) + REDUCE_CASE(8) + default: + throw std::runtime_error( + "custom allreduce only supports num gpus in (2,4,6,8). Actual num " + "gpus = " + + std::to_string(world_size_)); + } +#undef REDUCE_CASE +#undef KL + } + + ~CustomAllreduce() { + for (auto [_, ptr] : ipc_handles_) { + CUDACHECK(cudaIpcCloseMemHandle(ptr)); + } + } +}; +/** + * To inspect PTX/SASS, copy paste this header file to compiler explorer and add + a template instantiation: + * template void vllm::CustomAllreduce::allreduce(cudaStream_t, half *, + half *, int, int, int); +*/ +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_kernels.cu new file mode 100644 index 000000000..c9eb32c48 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/custom_kernels.cu @@ -0,0 +1,1309 @@ +#include +#include +#include +#include +#include "hip_compat.h" + +#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \ + defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300_MI250__ +#endif + +#if defined(NDEBUG) + #undef NDEBUG + #include + #define UNREACHABLE_CODE assert(false); + #define NDEBUG +#else + #define UNREACHABLE_CODE assert(false); +#endif + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm1_kernel(float4* af4, __half2* bf4, __half2* c, + const int K) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK * K / 8; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK] = {0.0}; + __half2 acch2; + __half2 oval; + + // As we later use warp shuffle operations, we may have more threads in the + // block than the actual available data, hence the if guard here. + if (threadid * 8 < K) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // rowA_elem4[i] holds 8 * half numbers seen as a single float4. + rowA_elem4[i] = load_ntmprl(&af4[row_addr + threadid + K / 8 * i]); + } + } + + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + __half2 Af2; + __half2 Bf2; + float2 S; + + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; + +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + // Multiply-add on 8 half. + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + + // See comment above concerning the if guard. + if (threadid * 8 < K) { + acc[i] = S.x + S.y; // accumulation on float + } + } + +// all reduce across warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + if (lane < NUM_A_ROWS_PER_BLOCK) { + red_smem[lane][warp] = acc[lane]; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + if (qwarpid < NUM_A_ROWS_PER_BLOCK) { + acc[qwarpid] = qthreadid < num_warps ? red_smem[qwarpid][qthreadid] : 0.f; +#pragma unroll + for (int mask = 16 / 2; mask >= 1; mask /= 2) { + acc[qwarpid] += __shfl_xor(acc[qwarpid], mask); + } + float oval2 = __shfl_xor(acc[qwarpid], 16); + + if (threadid % WARP_SIZE == 0 or threadid % WARP_SIZE == 32) { + oval = __float22half2_rn(make_float2(acc[qwarpid], oval2)); + c[blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 + qwarpid / 2] = oval; + } + } +} + +// define the kernel calling code: +// template +void LLGemm1(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<__half2*>(out_c); + + // NUM_TREADS need to be a multiple of WARP_SIZE, as we are using warp shuffle + // operations. + const int NUM_THREADS = + K * 2 / 16 % WARP_SIZE == 0 + ? K * 2 / 16 + : K * 2 / 16 + (WARP_SIZE - K * 2 / 16 % WARP_SIZE); + + int NUM_BLOCKS = M / rows_per_block; + + if (rows_per_block == 2) { + LLGemm1_kernel<2><<>>(af4, bf4, c, K); + } else if (rows_per_block == 4) { + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } else if (rows_per_block == 8) { + LLGemm1_kernel<8><<>>(af4, bf4, c, K); + } else if (rows_per_block == 16) { + LLGemm1_kernel<16><<>>(af4, bf4, c, K); + } else { + NUM_BLOCKS = M / 4; + LLGemm1_kernel<4><<>>(af4, bf4, c, K); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +// instantiate the kernel template for T=float: +// template void AddGPUKernel(float *in_a, float *in_b, float *out_c, +// const int M, const int K, cudaStream_t stream); + +const unsigned int TILE_WIDTH = 32; + +// Compute C = A * B +__global__ void matrixMultiplyShared(float* A, float* B, float* C, int numARows, + int numAColumns, int numBRows, + int numBColumns, int numCRows, + int numCColumns) { + __shared__ float sA[TILE_WIDTH][TILE_WIDTH]; // Tile size of 32x32 + __shared__ float sB[TILE_WIDTH][TILE_WIDTH]; + + int Row = blockDim.y * blockIdx.y + threadIdx.y; + int Col = blockDim.x * blockIdx.x + threadIdx.x; + float Cvalue = 0.0; + sA[threadIdx.y][threadIdx.x] = 0.0; + sB[threadIdx.y][threadIdx.x] = 0.0; + + for (int ph = 0; ph < (((numAColumns - 1) / TILE_WIDTH) + 1); ph++) { + if ((Row < numARows) && (threadIdx.x + (ph * TILE_WIDTH)) < numAColumns) { + sA[threadIdx.y][threadIdx.x] = + A[(Row * numAColumns) + threadIdx.x + (ph * TILE_WIDTH)]; + } else { + sA[threadIdx.y][threadIdx.x] = 0.0; + } + if (Col < numBColumns && (threadIdx.y + ph * TILE_WIDTH) < numBRows) { + sB[threadIdx.y][threadIdx.x] = + B[(threadIdx.y + ph * TILE_WIDTH) * numBColumns + Col]; + } else { + sB[threadIdx.y][threadIdx.x] = 0.0; + } + __syncthreads(); + for (int j = 0; j < TILE_WIDTH; ++j) { + Cvalue += sA[threadIdx.y][j] * sB[j][threadIdx.x]; + } + } + if (Row < numCRows && Col < numCColumns) { + C[Row * numCColumns + Col] = Cvalue; + } +} + +void MMGPUKernel(float* in_a, float* in_b, float* out_c, int numARows, + int numAColumns, int numBRows, int numBColumns, int numCRows, + int numCColumns, cudaStream_t stream) { + // Initialize the grid and block dimensions + dim3 dimBlock(TILE_WIDTH, TILE_WIDTH, 1); + dim3 dimGrid((numCColumns / TILE_WIDTH) + 1, (numCRows / TILE_WIDTH) + 1, 1); + //@@ Launch the GPU Kernel here + matrixMultiplyShared<<>>( + in_a, in_b, out_c, numARows, numAColumns, numBRows, numBColumns, numCRows, + numCColumns); + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +template +__global__ __launch_bounds__(512) void HGEMV_WFPerRow( + int m, int n, const _Float16* A, int lda, const _Float16* x, _Float16* y) { + int num_row_per_block = CTA / nThreads_per_row; + int row_id = (blockIdx.x * num_row_per_block + threadIdx.y) * MT0; + int inc = (gridDim.x * num_row_per_block) * MT0; + + while (row_id < m) { + float2 sum2[MT0]; + +#pragma unroll + for (int i = 0; i < MT0; ++i) { + sum2[i] = {0.0, 0.0}; + } + + for (int j = threadIdx.x; j < n; j += (nThreads_per_row * MT1)) { + bool is_active = j < n; + if (is_active) { + float2 x2[MT1 >> 1]; +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + x2[offset >> 1] = {x[j + nThreads_per_row * offset], + x[j + nThreads_per_row * (offset + 1)]}; + } + float2 a2[MT0][MT1 >> 1]; +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < MT1; offset += 2) { + a2[i][offset >> 1] = { + A[(row_id + i) * n + j + nThreads_per_row * offset], + A[(row_id + i) * n + j + nThreads_per_row * (offset + 1)]}; + } + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = 0; offset < (MT1 >> 1); offset++) { + sum2[i] += a2[i][offset] * x2[offset]; + } + } + } + } + float sum[MT0]; +#pragma unroll + for (int i = 0; i < MT0; i++) { + sum[i] = sum2[i].x + sum2[i].y; + } + +#pragma unroll + for (int i = 0; i < MT0; i++) { +#pragma unroll + for (int offset = nThreads_per_row >> 1; offset >= 1; + offset = offset >> 1) { + sum[i] += __shfl_down(sum[i], offset, nThreads_per_row); + } + } + if (threadIdx.x == 0) { +#pragma unroll + for (int i = 0; i < MT0; i++) { + y[row_id + i] = sum[i]; + } + } + row_id += inc; + } +} + +void LLGemmZZ(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int solidx = 0) { + // m -> M, n-> K + dim3 grid(1024); + dim3 block(64, 8); + if (solidx == 0) { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 1) { + HGEMV_WFPerRow<64, 512, 2, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else if (solidx == 2) { + HGEMV_WFPerRow<64, 512, 1, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } else { + HGEMV_WFPerRow<64, 512, 4, 8><<>>( + M, K, reinterpret_cast(in_a), K, + reinterpret_cast(in_b), + reinterpret_cast<_Float16*>(out_c)); + } + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} + +///////////////////////////////////////////// + +#define DTYPE half + +__device__ __forceinline__ int mindiv(int N, int div1, int div2) { + int nPrRnd = div1 * div2; + int rnds0 = N / nPrRnd; + nPrRnd -= div1 * 3; + int rnds3 = N / nPrRnd; + nPrRnd -= div1; + int rnds4 = N / nPrRnd; + nPrRnd -= div1; + int rnds5 = N / nPrRnd; + nPrRnd -= div1; + int rnds6 = N / nPrRnd; + nPrRnd -= div1; + int rnds7 = N / nPrRnd; + nPrRnd -= div1; + int rnds8 = N / nPrRnd; + nPrRnd -= div1; + int rnds9 = N / nPrRnd; + nPrRnd -= div1; + int rtn = div2; + if (rnds0 == rnds3) rtn = div2 - 3; + if (rnds0 == rnds4) rtn = div2 - 4; + if (rnds0 == rnds5) rtn = div2 - 5; + if (rnds0 == rnds6) rtn = div2 - 6; + if (rnds0 == rnds7) rtn = div2 - 7; + if (rnds0 == rnds8) rtn = div2 - 8; + if (rnds0 == rnds9) rtn = div2 - 9; + return rtn; +} + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] fits LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + // uint32_t commitColumn[YTILE]; + // for (uint32_t i = 0; i < YTILE; i++) { + // commitColumn[i] = 1; + //} + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + // if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + // else + // bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + // if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + // if (n < N && (n + YTILE) >= N) { + // uint32_t startColumn = N - YTILE; + // for (uint32_t i = 0; i < (n - startColumn); i++) { + // commitColumn[i] = 0; + // } + // n = startColumn; + //} + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_sml_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets cases where A[] marginally exceeds LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + + if (threadIdx.y >= _WvPrGrp) return; + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + while (n < N) { + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} + +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support +// This version targets big A[] cases, where it is much larger than LDS capacity +template +__global__ void __launch_bounds__(WvPrGrp* THRDS) + wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, const int CuCount) { + using half8 = + __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; + + union bigType { + DTYPE h[A_CHUNK]; + float f[A_CHUNK / 2]; + float2 f2[A_CHUNK / 4]; + double d[A_CHUNK / 4]; + half8 h8; + }; + + //---------------------------------------------------- + // Reserving 64 KB of LDS to have 1 WG / CU + // Goal is to bring the activation matrix A to the LDS + // and use it across the lifetime of the work group + // TODO: When activation matrix is larger than 64 KB + // then this is not goint to work! + //---------------------------------------------------- + __shared__ half s[1024 * 32]; + + //---------------------------------------------------- + // Computation of columns that need to be committed to memory! + //---------------------------------------------------- + uint32_t commitColumn[YTILE]; + for (uint32_t i = 0; i < YTILE; i++) { + commitColumn[i] = 1; + } + + // It's worth trying to load-balance... + int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); + if (threadIdx.y >= _WvPrGrp) return; + + //---------------------------------------------------- + // Indexing function into the column of weight matrix B + // Algorithm does 64 lane k-splitting / wave and uses + // WG ID and Thread ID to find the index. + //---------------------------------------------------- + uint32_t n = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + + //---------------------------------------------------- + // Fetch the activation matrix to LDS + // Loop iteration: + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements + // - Each WG will fetch 512 * 16 => 8K elements + // - Then the WG will move to another 8 K elements + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + #define PCML + #ifndef PCML + for (uint32_t k = 0; k < min(K * M, 32 * 1024); + k += THRDS * WvPrGrp * A_CHUNK) { + uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + + // Transpose of A implementation + // uint32_t k_ot = (k_in / M) + (k_in % M) * K; // transopse for + // bank-conflict-free readback + + if (k_in >= min(K * M, 32 * 1024)) break; + + //((bigType*)(&s[k_in]))->b128 = ((bigType*)(&A[k_in]))->b128; + *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + //((bigType*)(&s[k_ot]))->b128 = ((bigType*)(&A[k_in]))->b128; + } + __syncthreads(); + #endif + + #define TUC (THRDS * UNRL * A_CHUNK) + uint32_t kBase = 0; + // find biggest k size that fits in LDS + uint32_t kFit = (32 * 1024) / M; + // kFit = (kFit%TWC==0) ? kFit : (kFit-kFit%TWC+TWC); //round up to multiple + // of TUC + kFit = (kFit % TUC == 0) + ? kFit + : (kFit - kFit % TUC); // round up to multiple of TUC + // if (kFit == 0) kFit = TUC; + kFit = min(kFit, K); + + float sum[M][YTILE]; + + //---------------------------------------------------- + // Each wave works on a single column of weight matrix. + // There are 16 waves per WG, and hence, each WG is + // working on 16 columns of weight matrix. Moreover, + // we tile in column direction by YTILE, so when YTILE=1 + // the above math is right, however, when YTILE=2 then + // each wave will be working on 2 columns and WG will + // be working on 32 columns. + // + // Top level loop that makes WGs persistent! + // - WGs iterates across columns of weight matrix + // - Each wave within WG works on a given column(s) + // - After completing first set of columns, WGs start + // working on the next set of available columns + //---------------------------------------------------- + #ifdef PCML + int YW = (YTILE * _WvPrGrp); + uint32_t Nrndp = (N % YW == 0) ? N : (N - N % YW + YW); + while (n < Nrndp) { + #else + while (n < N) { + #endif + //---------------------------------------------------- + // 'sum' accumulates the matrix A x B computation + // split across 64 lanes. + // + // YTILE represents how many column of weight matrix + // are being worked on by each wave. + //---------------------------------------------------- + for (int i = 0; i < YTILE; i++) + for (int m = 0; m < M; m++) sum[m][i] = 0; + + bigType bigA[M][UNRL]; + bigType bigB0[UNRL]; + bigType bigB1[UNRL]; + bigType bigB2[UNRL]; + bigType bigB3[UNRL]; + bigType bigB4[UNRL]; + bigType bigB5[UNRL]; + bigType bigB6[UNRL]; + bigType bigB7[UNRL]; + bigType bigB8[UNRL]; + bigType bigB9[UNRL]; + bigType bigB10[UNRL]; + //---------------------------------------------------- + // Fetch weight matrix B in interleaved K-split! + // - Each thread (lane) is fetching 8 elements (A_Chunk) + // - Each wave will fetch 64*8=> 512 elements (1024B) + // - YTILE represents the number of column being serviced + // by wave + // - Loop for fetching weight matrix (B) are unrolled + // + // Fetch activation matrix A from LDS + // - Loop for fetching activation matrix (A) are unrolled + // + // Finally, do the matrix multiplication in an unrolled + // fashion. This provides lot of food for compiler + // scheduling. + // + // TODO: Logic below will only work when K is multiple of 8 + //---------------------------------------------------- + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + #ifdef PCML + if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS + if (k1 != 0) kBase += kFit; + __syncthreads(); + for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { + uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); + if (kBase + kOff >= K) break; + if (kOff >= kFit) break; + for (uint32_t m = 0; m < M; m++) { + uint32_t k_in = kBase + m * K + kOff; + uint32_t k_ot = m * kFit + kOff; + *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + } + } + __syncthreads(); + } + if (n >= N) continue; + #endif + + // Fetch the weight matrix from memory! + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + const half* B_ = &B[(n + 0) * K + k_]; + bigB0[k2].h8 = (loadnt((half8*)(&B_[0 * K]))); + //---------------------------------------------------- + // The following code with YTILE > 1 has to be deleted + //---------------------------------------------------- + if (YTILE >= 2) bigB1[k2].h8 = (loadnt((half8*)(&B_[1 * K]))); + if (YTILE >= 3) bigB2[k2].h8 = (loadnt((half8*)(&B_[2 * K]))); + if (YTILE >= 4) bigB3[k2].h8 = (loadnt((half8*)(&B_[3 * K]))); + if (YTILE >= 5) bigB4[k2].h8 = (loadnt((half8*)(&B_[4 * K]))); + if (YTILE >= 6) bigB5[k2].h8 = (loadnt((half8*)(&B_[5 * K]))); + if (YTILE >= 7) bigB6[k2].h8 = (loadnt((half8*)(&B_[6 * K]))); + if (YTILE >= 8) bigB7[k2].h8 = (loadnt((half8*)(&B_[7 * K]))); + } + + // Fetch activation matrix from either just LDS or from both LDS / memory + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + + // Fetch A activation matrix in interleaved fashion from LDS or memory + + for (int m = 0; m < M; m++) { + #ifdef PCML + bigA[m][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * m]))); + #else + if (k_ + K * m < 32 * 1024) + bigA[m][k2] = *((const bigType*)(&(s[k_ + K * m]))); + else + bigA[m][k2] = *((const bigType*)(&(A[k_ + K * m]))); + #endif + } + } + + // Do the matrix multiplication in interleaved manner + #pragma unroll + for (uint32_t k2 = 0; k2 < UNRL; k2++) { + uint32_t k = k1 + k2 * THRDS * A_CHUNK; + uint32_t k_ = k + threadIdx.x * A_CHUNK; + if (k_ >= K) break; + #pragma unroll + for (uint32_t m = 0; m < M; m++) { + // Do the matrix multiplication of activation and weight matrix + // - Remember the accumulation is happening for K-split of 64! + #pragma unroll + for (uint32_t b = 0; b < A_CHUNK / 2; b++) { + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][0]) + : "0"(sum[m][0]), "v"(bigA[m][k2].f[b]), "v"(bigB0[k2].f[b])); + + //---------------------------------------------------- + // The following code with YTILE > 1 + //---------------------------------------------------- + if (YTILE >= 2) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][1]) + : "0"(sum[m][1]), "v"(bigA[m][k2].f[b]), "v"(bigB1[k2].f[b])); + if (YTILE >= 3) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][2]) + : "0"(sum[m][2]), "v"(bigA[m][k2].f[b]), "v"(bigB2[k2].f[b])); + if (YTILE >= 4) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][3]) + : "0"(sum[m][3]), "v"(bigA[m][k2].f[b]), "v"(bigB3[k2].f[b])); + if (YTILE >= 5) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][4]) + : "0"(sum[m][4]), "v"(bigA[m][k2].f[b]), "v"(bigB4[k2].f[b])); + if (YTILE >= 6) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][5]) + : "0"(sum[m][5]), "v"(bigA[m][k2].f[b]), "v"(bigB5[k2].f[b])); + if (YTILE >= 7) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][6]) + : "0"(sum[m][6]), "v"(bigA[m][k2].f[b]), "v"(bigB6[k2].f[b])); + if (YTILE >= 8) + asm("v_dot2c_f32_f16 %0, %2, %3" + : "=v"(sum[m][7]) + : "0"(sum[m][7]), "v"(bigA[m][k2].f[b]), "v"(bigB7[k2].f[b])); + } + } + } + } + + #ifdef PCML + if (n >= N) { + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + continue; + } + #endif + + //---------------------------------------------------- + // Final reduction step using shuffle + //---------------------------------------------------- + for (int m = 0; m < M; m++) { + for (int y = 0; y < YTILE; y++) { + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" + : "=v"(sum[m][y]) + : "0"(sum[m][y]), "v"(sum[m][y]), "v"(sum[m][y])); + } + } + + if (threadIdx.x == 63) { + for (int m = 0; m < M; m++) { + for (int i = 0; i < YTILE; i++) { + if (commitColumn[i]) C[n + i + m * N] = __float2half(sum[m][i]); + } + } + } + + n += CuCount * _WvPrGrp * YTILE; + kBase = 0; + + // Check whether there will be fragmenation! + // This will happen only for the last wave! + if (n < N && (n + YTILE) >= N) { + uint32_t startColumn = N - YTILE; + for (uint32_t i = 0; i < (n - startColumn); i++) { + commitColumn[i] = 0; + } + n = startColumn; + } + } +} +#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support +template +__global__ void wvSpltK_hf_big_(const int K, const int N, const DTYPE* B, + const DTYPE* __restrict__ A, DTYPE* C, + const int CuCount) { + UNREACHABLE_CODE +} +#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support + +void wvSpltK_(void* in_a, void* in_b, void* out_c, const int M_in, + const int K_in, const int N_in, cudaStream_t stream, + const int CuCount = 0) { + dim3 grid(CuCount); + half* af4 = reinterpret_cast(in_a); + const half* bf4 = reinterpret_cast(in_b); + auto* c = reinterpret_cast(out_c); + +#define WVSPLTK(_WvPrGrp, _YTILEs, _YTILEm, _YTILEb, _UNRLs, _UNRLm, _UNRLb, \ + _N) \ + { \ + dim3 block(64, _WvPrGrp); \ + /*wvSpltK_hf:*/ \ + if ((K_in * N_in <= 32 * 1024) && (M_in % _YTILEs == 0)) { \ + wvSpltK_hf_sml_<64, _YTILEs, _WvPrGrp, 8, _UNRLs, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else if (K_in * N_in <= 32 * 1024 * 1.2) { \ + wvSpltK_hf_<64, _YTILEm, _WvPrGrp, 8, _UNRLm, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } else { \ + wvSpltK_hf_big_<64, _YTILEb, _WvPrGrp, 8, _UNRLb, _N> \ + <<>>(K_in, M_in, af4, bf4, c, CuCount); \ + } \ + } + + switch (N_in) { + case 1: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 1) // MI308 + break; + case 2: + WVSPLTK(16, 2, 2, 2, 2, 2, 2, 2) // MI308 + break; + case 3: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 3) // MI308 + break; + case 4: + WVSPLTK(16, 4, 7, 7, 1, 1, 1, 4) // MI308 + break; + default: + throw std::runtime_error("Unsupported N value: " + std::to_string(M_in) + + "," + std::to_string(K_in) + "," + + std::to_string(N_in)); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) { + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); + } +} \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dispatch_utils.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dispatch_utils.h new file mode 100644 index 000000000..3ecea0324 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dispatch_utils.h @@ -0,0 +1,35 @@ +/* + * Adapted from + * https://github.com/pytorch/pytorch/blob/v2.0.1/aten/src/ATen/Dispatch.h + */ +#pragma once + +#include + +#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) + +#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, \ + VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__)) + +#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \ + AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \ + AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) + +#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__)) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_bfloat16.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_bfloat16.cuh new file mode 100644 index 000000000..97a25baa1 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_bfloat16.cuh @@ -0,0 +1,463 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include + +typedef __hip_bfloat162 __nv_bfloat162; +typedef __hip_bfloat16 __nv_bfloat16; +#endif + +#include + +namespace vllm { + +// Define custom BF16 vector data types. +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; + +// BF16 vector types for Q, K, V. +template <> +struct Vec<__nv_bfloat16, 1> { + using Type = __nv_bfloat16; +}; +template <> +struct Vec<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template <> +struct Vec<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template <> +struct Vec<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec<__nv_bfloat16> { + using Type = float; +}; +template <> +struct FloatVec<__nv_bfloat162> { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat1622float2(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __bfloat162bfloat162(val); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +// Vector addition. +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + #ifndef USE_ROCM + return a + b; + #else + return __hadd(a, b); + #endif +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hadd2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) { + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template <> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hmul2(a, b); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +template <> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +template <> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) { + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +template <> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +template <> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) { + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +template <> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +template <> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) { + float fa = __bfloat162float(a); + float fb = __bfloat162float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) { + return mul(bf162bf162(a), b); +} + +template <> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(a, b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, + __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + return __hfma2(bf162bf162(a), b, c); +#endif + __builtin_unreachable(); // Suppress missing return statement warning +} + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) { + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) { + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) { + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) { + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) { + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) { + return fma(bf162bf162(a), b, fc); +} + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) { + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(__nv_bfloat16 v) { + return __bfloat162float(v); +} + +template <> +inline __device__ float sum(__nv_bfloat162 v) { + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +template <> +inline __device__ float sum(bf16_4_t v) { + return sum(v.x) + sum(v.y); +} + +template <> +inline __device__ float sum(bf16_8_t v) { + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} + +// From float32 to bfloat16. +inline __device__ void from_float(__nv_bfloat16& dst, float src) { + dst = __float2bfloat16(src); +} + +inline __device__ void from_float(__nv_bfloat162& dst, float2 src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst = __float22bfloat162_rn(src); +#endif +} + +inline __device__ void from_float(bf16_4_t& dst, Float4_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#endif +} + +inline __device__ void from_float(bf16_8_t& dst, Float8_ src) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#endif +} + +// From bfloat16 to float32. +inline __device__ float to_float(__nv_bfloat16 u) { + return __bfloat162float(u); +} + +// Zero-out a variable. +inline __device__ void zero(__nv_bfloat16& dst) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + assert(false); +#else + // Same as CUDART_ZERO_BF16 introduced in CUDA 12.2. + dst = __ushort_as_bfloat16((unsigned short)0x0000U); +#endif +} + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float16.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float16.cuh new file mode 100644 index 000000000..3a1815f0e --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float16.cuh @@ -0,0 +1,504 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" +#include "dtype_float32.cuh" + +#ifdef USE_ROCM + #include +#endif + +#include + +namespace vllm { + +// FP16 vector types for Q, K, V. +template <> +struct Vec { + using Type = uint16_t; +}; +template <> +struct Vec { + using Type = uint32_t; +}; +template <> +struct Vec { + using Type = uint2; +}; +template <> +struct Vec { + using Type = uint4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = Float4_; +}; +template <> +struct FloatVec { + using Type = Float8_; +}; + +// Utility functions for type conversions. +inline __device__ uint32_t h0_h0(uint16_t a) { +#ifndef USE_ROCM + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = a; + tmp.u16[1] = a; + return tmp.u32; +#endif +} + +inline __device__ float half_to_float(uint16_t h) { + float f; +#ifndef USE_ROCM + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); +#else + asm volatile("v_cvt_f32_f16 %0, %1;" : "=v"(f) : "v"(h)); +#endif + return f; +} + +inline __device__ float2 half2_to_float2(uint32_t v) { +#ifndef USE_ROCM + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +#else + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = v; + float2 ret; + ret.x = half_to_float(tmp.u16[0]); + ret.y = half_to_float(tmp.u16[1]); + return ret; +#endif +} + +inline __device__ uint16_t float_to_half(float f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#else + asm volatile("v_cvt_f16_f32 %0, %1;\n" : "=v"(tmp.u32) : "v"(f)); +#endif + return tmp.u16[0]; +} + +inline __device__ uint32_t float2_to_half2(float2 f) { + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#ifndef USE_ROCM + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" + : "=r"(tmp.u32) + : "f"(f.y), "f"(f.x)); + #else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); + #endif +#else + tmp.u16[0] = float_to_half(f.x); + tmp.u16[1] = float_to_half(f.y); +#endif + return tmp.u32; +} + +// Vector addition. +inline __device__ uint16_t add(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint32_t add(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_add_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +inline __device__ uint2 add(uint2 a, uint2 b) { + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ uint4 add(uint4 a, uint4 b) { + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +inline __device__ float2 add(uint32_t a, float2 fb) { + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +inline __device__ Float4_ add(uint2 a, Float4_ fb) { + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +inline __device__ Float8_ add(uint4 a, Float8_ fb) { + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +// Vector multiplication. +template <> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) { + uint16_t c; +#ifndef USE_ROCM + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); +#else + asm volatile("v_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) { + uint32_t c; +#ifndef USE_ROCM + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); +#else + asm volatile("v_pk_mul_f16 %0, %1, %2;\n" : "=v"(c) : "v"(a), "v"(b)); +#endif + return c; +} + +template <> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ uint2 mul(uint2 a, uint2 b) { + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +template <> +inline __device__ uint2 mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +template <> +inline __device__ uint4 mul(uint4 a, uint4 b) { + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +template <> +inline __device__ uint4 mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +template <> +inline __device__ float mul(uint16_t a, uint16_t b) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +template <> +inline __device__ float2 mul(uint32_t a, uint32_t b) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +template <> +inline __device__ float2 mul(uint16_t a, uint32_t b) { + return mul(h0_h0(a), b); +} + +template <> +inline __device__ Float4_ mul(uint2 a, uint2 b) { + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +template <> +inline __device__ Float4_ mul(uint16_t a, uint2 b) { + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint4 a, uint4 b) { + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +template <> +inline __device__ Float8_ mul(uint16_t a, uint4 b) { + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +// Vector fused multiply-add. +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) { + uint32_t d; +#ifndef USE_ROCM + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + asm volatile("v_pk_fma_f16 %0, %1, %2, %3;\n" + : "=v"(d) + : "v"(a), "v"(b), "v"(c)); +#endif + return d; +} + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) { + return fma(h0_h0(a), b, c); +} + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) { + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) { + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) { + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) { + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) { + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) { + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) { + return fma(h0_h0(a), b, fc); +} + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) { + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) { + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) { + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) { + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +// Vector sum. +template <> +inline __device__ float sum(uint16_t v) { + return half_to_float(v); +} + +template <> +inline __device__ float sum(uint32_t v) { + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +template <> +inline __device__ float sum(uint2 v) { + uint32_t c = add(v.x, v.y); + return sum(c); +} + +template <> +inline __device__ float sum(uint4 v) { + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); + return sum(c); +} + +// From float32 to float16. +inline __device__ void from_float(uint16_t& dst, float src) { + dst = float_to_half(src); +} + +inline __device__ void from_float(uint32_t& dst, float2 src) { + dst = float2_to_half2(src); +} + +inline __device__ void from_float(uint2& dst, Float4_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +inline __device__ void from_float(uint4& dst, Float8_ src) { + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +// From float16 to float32. +inline __device__ float to_float(uint16_t u) { return half_to_float(u); } + +inline __device__ float2 to_float(uint32_t u) { return half2_to_float2(u); } + +inline __device__ Float4_ to_float(uint2 u) { + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +inline __device__ Float8_ to_float(uint4 u) { + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +// Zero-out a variable. +inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); } + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float32.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float32.cuh new file mode 100644 index 000000000..7c6a686db --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_float32.cuh @@ -0,0 +1,251 @@ +/* + * Adapted from + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp + * and + * https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h + * Copyright (c) 2023, The vLLM team. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "attention_generic.cuh" + +#include + +namespace vllm { + +// Define custom FP32 vector data types. +struct Float4_ { + float2 x; + float2 y; +}; + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +// FP32 vector types for Q, K, V. +template <> +struct Vec { + using Type = float; +}; +template <> +struct Vec { + using Type = float2; +}; +template <> +struct Vec { + using Type = float4; +}; + +// FP32 accumulator vector types corresponding to Vec. +template <> +struct FloatVec { + using Type = float; +}; +template <> +struct FloatVec { + using Type = float2; +}; +template <> +struct FloatVec { + using Type = float4; +}; + +// Vector addition. +inline __device__ float add(float a, float b) { return a + b; } + +inline __device__ float2 add(float2 a, float2 b) { + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +inline __device__ float4 add(float4 a, float4 b) { + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +// Vector multiplication. +template <> +inline __device__ float mul(float a, float b) { + return a * b; +} + +template <> +inline __device__ float2 mul(float2 a, float2 b) { + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +template <> +inline __device__ float2 mul(float a, float2 b) { + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +template <> +inline __device__ float4 mul(float4 a, float4 b) { + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +template <> +inline __device__ float4 mul(float a, float4 b) { + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +// Vector fused multiply-add. +inline __device__ float fma(float a, float b, float c) { return a * b + c; } + +inline __device__ float2 fma(float2 a, float2 b, float2 c) { + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +inline __device__ float2 fma(float a, float2 b, float2 c) { + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ float4 fma(float4 a, float4 b, float4 c) { + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +inline __device__ float4 fma(float a, float4 b, float4 c) { + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) { + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) { + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +// Vector sum. +template <> +inline __device__ float sum(float v) { + return v; +} + +template <> +inline __device__ float sum(float2 v) { + return v.x + v.y; +} + +template <> +inline __device__ float sum(float4 v) { + return v.x + v.y + v.z + v.w; +} + +template <> +inline __device__ float sum(Float4_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +template <> +inline __device__ float sum(Float8_ v) { + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +// Vector dot product. +inline __device__ float dot(float a, float b) { return a * b; } + +inline __device__ float dot(float2 a, float2 b) { + float2 c = mul(a, b); + return c.x + c.y; +} + +inline __device__ float dot(Float4_ a, Float4_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + return acc.x + acc.y; +} + +inline __device__ float dot(Float8_ a, Float8_ b) { + float2 acc = mul(a.x, b.x); + acc = fma(a.y, b.y, acc); + acc = fma(a.z, b.z, acc); + acc = fma(a.w, b.w, acc); + return acc.x + acc.y; +} + +// From float to float. +inline __device__ void from_float(float& dst, float src) { dst = src; } + +inline __device__ void from_float(float2& dst, float2 src) { dst = src; } + +inline __device__ void from_float(float4& dst, float4 src) { dst = src; } + +// From float to float. +inline __device__ float to_float(float u) { return u; } + +inline __device__ float2 to_float(float2 u) { return u; } + +inline __device__ float4 to_float(float4 u) { return u; } + +inline __device__ Float4_ to_float(Float4_ u) { return u; } + +inline __device__ Float8_ to_float(Float8_ u) { return u; } + +// Zero-out a variable. +inline __device__ void zero(float& dst) { dst = 0.f; } + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_fp8.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_fp8.cuh new file mode 100644 index 000000000..e714e321b --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/dtype_fp8.cuh @@ -0,0 +1,41 @@ +#pragma once + +#include "attention_generic.cuh" + +#include +#ifdef ENABLE_FP8 + #ifndef USE_ROCM + #include + #endif // USE_ROCM +#endif // ENABLE_FP8 + +namespace vllm { + +enum class Fp8KVCacheDataType { + kAuto = 0, + kFp8E4M3 = 1, + kFp8E5M2 = 2, +}; + +// fp8 vector types for quantization of kv cache +template <> +struct Vec { + using Type = uint8_t; +}; + +template <> +struct Vec { + using Type = uint16_t; +}; + +template <> +struct Vec { + using Type = uint32_t; +}; + +template <> +struct Vec { + using Type = uint2; +}; + +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/fused_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/fused_kernels.cu new file mode 100644 index 000000000..4f3eea456 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/fused_kernels.cu @@ -0,0 +1,195 @@ +#include +#include +#include +#include + +constexpr int WARP_SIZE = 64; + +template +__device__ __forceinline__ T silu(const T& x) { + // x * sigmoid(x) + return (T)(((float)x) / (1.0f + expf((float)-x))); +} + +template +__device__ __forceinline__ T loadnt(T* addr) { + return __builtin_nontemporal_load(addr); +} + +__device__ __forceinline__ float4 load_ntmprl(const float4* addr) { + auto addr_alias = reinterpret_cast(addr); + auto dat0 = loadnt(addr_alias); + auto dat1 = loadnt(addr_alias + 1); + auto dat2 = loadnt(addr_alias + 2); + auto dat3 = loadnt(addr_alias + 3); + // auto dat0 = *(addr_alias); + // auto dat1 = *(addr_alias+1); + // auto dat2 = *(addr_alias+2); + // auto dat3 = *(addr_alias+3); + return make_float4(dat0, dat1, dat2, dat3); +} + +// TBlock fetches entire rows of A, and entire col of B (K dimension); assume +// N=1 for time being grid is M/A_NUM_ROWS blocks +template +__global__ void LLGemm_Silu_kernel(float4* af4, __half2* bf4, _Float16* c, + const int d) { + __shared__ float red_smem[NUM_A_ROWS_PER_BLOCK][WARP_SIZE]; + const int row_addr = blockIdx.x * NUM_A_ROWS_PER_BLOCK / 2 * blockDim.x; + const int row_addr_d = row_addr + d * blockDim.x; + // int row_addr_1 = row_addr + CUDA_NUM_THREADS; + // int row_addr_2 = row_addr_1 + CUDA_NUM_THREADS; + // int row_addr_3 = row_addr_2 + CUDA_NUM_THREADS; + const int threadid = threadIdx.x; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int num_warps = blockDim.x / WARP_SIZE; + const int qwarpid = threadid / 16; + const int qthreadid = threadid % 16; + float4 rowA_elem4[NUM_A_ROWS_PER_BLOCK]; + // float4 colB_elem4; + __half2 colB_elem4x, colB_elem4y, colB_elem4z, colB_elem4w; + float4 sum4; //[NUM_A_ROWS_PER_BLOCK]; + float acc[NUM_A_ROWS_PER_BLOCK]; //= 0.0; + __half2 acch2; + __half2 oval; + + // rowA_elem4 = af4[row_addr + threadid]; + //__syncthreads(); + // rowA_elem4_1 = af4[row_addr_1 + threadid]; + // rowA_elem4_2 = af4[row_addr_2 + threadid]; + // rowA_elem4_3 = af4[row_addr_3 + threadid]; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK / 2; i++) { + rowA_elem4[2 * i] = load_ntmprl(&af4[row_addr + i * blockDim.x + threadid]); + rowA_elem4[2 * i + 1] = + load_ntmprl(&af4[row_addr_d + i * blockDim.x + threadid]); + // rowA_elem4[i] = af4[row_addr + i*blockDim.x + threadid]; + //__syncthreads(); + } + colB_elem4x = bf4[threadid * 4 + 0]; + colB_elem4y = bf4[threadid * 4 + 1]; + colB_elem4z = bf4[threadid * 4 + 2]; + colB_elem4w = bf4[threadid * 4 + 3]; + + // __syncthreads(); + __half2 Af2; + __half2 Bf2; + float2 S; + // auto Bh2ptr = reinterpret_cast<__half2 *>(&colB_elem4); + // auto Bf2x = *Bh2ptr; + // auto Bf2y = *(Bh2ptr+1); + // auto Bf2z = *(Bh2ptr+2); + // auto Bf2w = *(Bh2ptr+3); + auto Ah2ptr = reinterpret_cast<__half2*>(&rowA_elem4); + __half2* ah2lptr; +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + ah2lptr = Ah2ptr + i * 4; + Af2 = *(ah2lptr); + acch2 = __hmul2(Af2, colB_elem4x); + Af2 = *(ah2lptr + 1); + acch2 = __hfma2(Af2, colB_elem4y, acch2); + Af2 = *(ah2lptr + 2); + acch2 = __hfma2(Af2, colB_elem4z, acch2); + Af2 = *(ah2lptr + 3); + acch2 = __hfma2(Af2, colB_elem4w, acch2); + S = __half22float2(acch2); + acc[i] = S.x + S.y; + } + +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { +#pragma unroll + for (int i = 0; i < NUM_A_ROWS_PER_BLOCK; i++) { + acc[i] += __shfl_xor(acc[i], mask); + } + } + + // Warp leaders store the data to shared memory. + // if (lane == 0) { + // #pragma unroll + // for (int i=0; i= 1; mask /= 2) { + // #pragma unroll + // for (int i=0; i +void LLGemm_Silu(void* in_a, void* in_b, void* out_c, const int M, const int K, + cudaStream_t stream, const int rows_per_block = 4) { + float4* af4 = reinterpret_cast(in_a); + auto* bf4 = reinterpret_cast<__half2*>(in_b); + auto* c = reinterpret_cast<_Float16*>(out_c); + const int d = M / 2; + const int NUM_THREADS = K * 2 / 16; + int NUM_BLOCKS = M / rows_per_block; + if (rows_per_block == 2) { + LLGemm_Silu_kernel<2> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 4) { + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 8) { + LLGemm_Silu_kernel<8> + <<>>(af4, bf4, c, d); + } else if (rows_per_block == 16) { + LLGemm_Silu_kernel<16> + <<>>(af4, bf4, c, d); + } else { + NUM_BLOCKS = M / 4; + LLGemm_Silu_kernel<4> + <<>>(af4, bf4, c, d); + } + + cudaError_t err = cudaGetLastError(); + if (cudaSuccess != err) + throw std::runtime_error("CUDA kernel failed : " + std::to_string(err)); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.cu new file mode 100644 index 000000000..445315e61 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.cu @@ -0,0 +1,165 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" +#include "gemm_a8w8_manifest.h" +#include "gemm_a8w8_lookup.h" + +using RowwiseKernel = std::function< + torch::Tensor(torch::Tensor &, torch::Tensor &, + torch::Tensor &, torch::Tensor &, torch::Tensor &)>; + +// Define a custom hash function for std::tuple +struct IntTupleHash +{ + size_t operator()(const std::tuple &t) const + { + auto hash1 = std::hash{}(std::get<0>(t)); + auto hash2 = std::hash{}(std::get<1>(t)); + auto hash3 = std::hash{}(std::get<2>(t)); + return hash1 ^ hash2 ^ hash3; + } +}; + +// For certain high priority shapes, we directly use the best kernel rather +// than use heuristics. +using RowwiseKernelMap = std::unordered_map< + std::tuple, + RowwiseKernel, + IntTupleHash>; + +template +RowwiseKernel rowwise_heuristic_dispatch(int M, int N, int K) +{ + // Apply shape heuristics to find a suitable kernel implementation. + if (M < 64 && N < 2048 && K < 2048) + { + // Kernel that generally works well on small shapes. + return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2; + } + else if (M < 64 && K < 2048) + { + // Kernel that works well for small batch size and small K. + return a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2; + } + else if (M < 64 && N < 2048) + { + // Kernel that works well for small batch size and small N. + return a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2; + } + else if (M < 64 && N > 2048 && K > 2048) + { + // Kernel that works well for small M but larger N and K. + return a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1; + } + else if (M < 64) + { + // Fallback to generic small batch kernel if we cant find a good match. + return a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2; + /* } else if (((M < 512 && K < 8192) || (N <= 2048 && K <= 8192) || (K <= 2048 && N <= 8192)) && K >= 1024) { + // Kernel that is optimized for larger batch sizes but otherwise small + // tensors. + return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5; */ + } + else if (K < 1024) + { + // Special case for small K. + return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1; + } + else if (M < 1024) + { + // Kernel for generic medium batch sizes. + return a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3; + } + else if (M >= 1024 && N >= 1024 && K >= 1024) + { + // Kernel for very large gemm + // return a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; + return a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1; + } + else + { + // Fallback large kernel. + return a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3; + } +} + +// Helper function to return the next largest power of 2 +static constexpr int nextPow2(unsigned int num) +{ + if (num <= 1) + return 1; + return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1)); +} + +template +RowwiseKernel rowwise_dispatch(int M, int N, int K) +{ + // For a given shape, either find the best kernel via lookup or heuristic. + // For many small M shapes, we bucket them to the next largest kernel. + // This is fine since kernels are padded anyway. + int padded_m = M; + if (M > 1 && M <= 16) + { + padded_m = 16; + } + else if (M <= 16384) + { + padded_m = nextPow2(M); + } + else if (M <= 20480) + { + padded_m = 20480; + } + // First check if this shape is available in the direct lookup. + static const auto lookup = [] + { + if constexpr (std::is_same_v) { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(F16)}; + } else if constexpr (std::is_same_v) { + return RowwiseKernelMap{GENERATE_LOOKUP_TABLE(B16)}; + } else { + static_assert(false, "rowwise_dispatch used with unsupported dtype!"); + } }(); + + auto it = lookup.find({padded_m, N, K}); + // If we found an optimal kernel, use it. + if (it != lookup.end()) + { + return it->second; + } + // Otherwise, use heuristics. + return rowwise_heuristic_dispatch(M, N, K); +} + +//torch::Tensor gemm_a8w8( +void gemm_a8w8( + torch::Tensor &XQ, + torch::Tensor &WQ, + torch::Tensor &x_scale, + torch::Tensor &w_scale, + torch::Tensor &Y) +{ + TORCH_CHECK(XQ.dtype() == at::ScalarType::Char && XQ.dtype() == WQ.dtype(), + "Weights and activations should both be int8!"); + TORCH_CHECK(x_scale.dtype() == Y.dtype() && w_scale.dtype() == Y.dtype(), + "Scales and output should have the same dtype!"); + + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + if (Y.dtype() == at::ScalarType::Half) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + } + else if (Y.dtype() == at::ScalarType::BFloat16) + { + rowwise_dispatch(M, N, K)(XQ, WQ, x_scale, w_scale, Y); + } + else + { + TORCH_CHECK(false, "Unsupported scales/output dtype!"); + } + //return Y; +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.h new file mode 100644 index 000000000..903aa4476 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8.h @@ -0,0 +1,9 @@ +#pragma once +#include +//torch::Tensor gemm_a8w8( +void gemm_a8w8( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_common.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_common.cuh new file mode 100644 index 000000000..3e00ed247 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_common.cuh @@ -0,0 +1,233 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef USE_ROCM + +#undef __HIP_NO_HALF_OPERATORS__ +#undef __HIP_NO_HALF_CONVERSIONS__ + +#include +#include +#include +#include + +#include +#include +#include +#include + +#include "ck/ck.hpp" +#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp" +#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp" +#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" +#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp" + +#include "ck/library/utility/device_memory.hpp" +#include "ck/library/utility/host_tensor.hpp" +#include "ck/library/utility/host_tensor_generator.hpp" +#include "ck/library/utility/literals.hpp" +#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp" +#include "ck/library/utility/check_err.hpp" + +#include "ck/utility/blkgemmpipe_scheduler.hpp" + +template +using S = ck::Sequence; + +using I8 = int8_t; +using I32 = int; +using F16 = ck::half_t; +using B16 = ck::bhalf_t; +using FP8 = ck::f8_t; +using F32 = float; + +using Row = ck::tensor_layout::gemm::RowMajor; +using Col = ck::tensor_layout::gemm::ColumnMajor; + +using ADataType = I8; +using BDataType = I8; +using AccDataType = I32; +using CShuffleDataType = I32; +using ComputeDataType = I8; + +using ALayout = Row; +using BLayout = Col; +using D0Layout = Row; +using D1Layout = Col; +using DsLayout = ck::Tuple; +using ELayout = Row; + +using PassThrough = ck::tensor_operation::element_wise::PassThrough; + +using AElementOp = PassThrough; +using BElementOp = PassThrough; + +struct RowwiseScale +{ + template + __host__ __device__ constexpr void + operator()(E& e, const C& c, const D0& d0, const D1& d1) const; + + template <> + __host__ __device__ constexpr void operator()( + F16& e, const AccDataType& c, const F16& d0, const F16& d1) const + { + const F32 x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } + + template <> + __host__ __device__ constexpr void operator()( + B16& e, const AccDataType& c, const B16& d0, const B16& d1) const + { + const F32 x0_f = + ck::type_convert(c) * ck::type_convert(d0) * ck::type_convert(d1); + + e = ck::type_convert(x0_f); + } +}; + +using CDEElementOp = RowwiseScale; + +template +using DsDataType = ck::Tuple; + +#if 0 +template +using DeviceOpInstance = ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3 +// clang-format off +///######| ALayout| BLayout| DsLayout| ELayout| AData| BData| DsData| EData| AccData| CShuffle| A| B| CDE| GEMM| Block| MPer| NPer| KPer| AK1| BK1| MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| +///######| | | | | Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Spacialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| +///######| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| +///######| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | S| +///###### RRR +/// < Row, Row, DsLayout, ELayout, ADataType, BDataType, DsDataType, DEDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 4, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; +///###### RCR + < Row, Col, DsLayout, ELayout, ADataType, BDataType, DsDataType, DEDataType, AccDataType, CShuffleDataType, AElementOp, BElementOp, CDEElementOp, GemmSpec, 256, 256, 128, 64, 16, 16, 32, 32, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, S<8, 8, 1>, ck::BlockGemmPipelineScheduler::Interwave, ck::BlockGemmPipelineVersion::v1, I8>; +// clang-format on +#endif + +template < + typename DEDataType, + int BLOCK_SIZE, + int MBLOCK, + int NBLOCK, + int KBLOCK, + int WAVE_TILE_M, + int WAVE_TILE_N, + int WAVE_MAP_M, + int WAVE_MAP_N, + typename ABLOCK_TRANSFER, + typename BBLOCK_TRANSFER, + typename CBLOCK_TRANSFER, + typename CBLOCK_SPV, + int CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, + int CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + ck::BlockGemmPipelineScheduler LOOP_SCHED, + ck::BlockGemmPipelineVersion PIPELINE_VERSION, + auto GEMM_SPEC = + ck::tensor_operation::device::GemmSpecialization::MNPadding> + using DeviceGemmHelper = + ck::tensor_operation::device::DeviceGemmMultiD_Xdl_CShuffle_V3< + ALayout, + BLayout, + DsLayout, + ELayout, + ADataType, + BDataType, + DsDataType, + DEDataType, + AccDataType, + CShuffleDataType, + AElementOp, + BElementOp, + CDEElementOp, + GEMM_SPEC, + BLOCK_SIZE, // Block Size + MBLOCK, // M per Block + NBLOCK, // N per Block + KBLOCK, // K per Block + 16, // AK1 + 16, // BK1 + WAVE_TILE_M, // M per Xdl + WAVE_TILE_N, // N per Xdl + WAVE_MAP_M, // Mxdl per Wave + WAVE_MAP_N, // Nxdl per Wave + ABLOCK_TRANSFER, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 0, + BBLOCK_TRANSFER, + S<1, 0, 2>, + S<1, 0, 2>, + 2, + 16, + 16, + 0, + CSHUFFLE_MX_PER_WAVE_PERSHUFFLE, + CSHUFFLE_NX_PER_WAVE_PERSHUFFLE, + CBLOCK_TRANSFER, + CBLOCK_SPV, + LOOP_SCHED, + PIPELINE_VERSION, + ComputeDataType>; + + +template +__forceinline__ torch::Tensor gemm_a8w8_rowwise_impl( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) +{ + int M = XQ.size(0); + int N = WQ.size(0); + int K = XQ.size(1); + + int StrideA = K; + int StrideB = K; + int StrideE = N; + + auto device_gemm = DeviceGemmInstance{}; + auto invoker = device_gemm.MakeInvoker(); + + auto a_element_op = AElementOp{}; + auto b_element_op = BElementOp{}; + auto cde_element_op = CDEElementOp{}; + + constexpr ck::index_t NumDTensor = DeviceGemmInstance::NumDTensor; + + auto argument = device_gemm.MakeArgument( + reinterpret_cast(XQ.data_ptr()), + reinterpret_cast(WQ.data_ptr()), + std::array{ + reinterpret_cast(w_scale.data_ptr()), + reinterpret_cast(x_scale.data_ptr())}, + reinterpret_cast(Y.data_ptr()), + M, + N, + K, + StrideA, + StrideB, + std::array{0, 0}, + StrideE, + 1, + a_element_op, + b_element_op, + cde_element_op + ); + TORCH_CHECK(device_gemm.IsSupportedArgument(argument), "This GEMM is not supported!"); + + invoker.Run(argument, StreamConfig{at::cuda::getCurrentCUDAStream().stream()}); + return Y; +} + +#endif // USE_ROCM diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_lookup.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_lookup.h new file mode 100644 index 000000000..dbc678065 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_lookup.h @@ -0,0 +1,115 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef USE_ROCM + +#define GENERATE_LOOKUP_TABLE(DTYPE) \ + { \ + /* QWen-57B \ + NK= 4608, 3584 */ \ + {{1, 4608, 3584}, \ + a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2}, \ + {{32, 4608, 3584}, \ + a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2}, \ + {{64, 4608, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{128, 4608, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{256, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{512, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{1024, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{2048, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{4096, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{8192, 4608, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{16384, 4608, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{20480, 4608, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + /* QWen-57B \ + NK= 3584, 3584 */ \ + {{1, 3584, 3584}, \ + a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2}, \ + {{32, 3584, 3584}, \ + a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2}, \ + {{64, 3584, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{128, 3584, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{256, 3584, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{512, 3584, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{1024, 3584, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{2048, 3584, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{4096, 3584, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{8192, 3584, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{16384, 3584, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{20480, 3584, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + /* QWen-57B \ + NK= 3584, 20480 */ \ + {{1, 3584, 20480}, \ + a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2}, \ + {{32, 3584, 20480}, \ + a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2}, \ + {{64, 3584, 20480}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{128, 3584, 20480}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{256, 3584, 20480}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{512, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{1024, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{2048, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{4096, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{8192, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{16384, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{20480, 3584, 20480}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + /* QWen-57B \ + NK= 40960, 3584 */ \ + {{1, 40960, 3584}, \ + a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2}, \ + {{32, 40960, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{64, 40960, 3584}, \ + a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{128, 40960, 3584}, \ + a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3}, \ + {{256, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{512, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{1024, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{2048, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{4096, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{8192, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{16384, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + {{20480, 40960, 3584}, \ + a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4}, \ + } + +#endif // USE_ROCM diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_manifest.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_manifest.h new file mode 100644 index 000000000..840722507 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/gemm_a8w8_manifest.h @@ -0,0 +1,254 @@ +#pragma once +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#ifdef USE_ROCM + +#include + +#include + +template +torch::Tensor +a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template +torch::Tensor +a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +#endif // USE_ROCM \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_compat.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_compat.h new file mode 100644 index 000000000..bfa8c0071 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_compat.h @@ -0,0 +1,49 @@ +#pragma once + +#ifdef USE_ROCM + #include +#endif + +#ifndef USE_ROCM + #define WARP_SIZE 32 +#else + #define WARP_SIZE warpSize +#endif + +#ifndef USE_ROCM + #define VLLM_LDG(arg) __ldg(arg) +#else + #define VLLM_LDG(arg) *(arg) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor_sync(uint32_t(-1), var, lane_mask, width) +#else + #define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask) + #define VLLM_SHFL_XOR_SYNC_WIDTH(var, lane_mask, width) \ + __shfl_xor(var, lane_mask, width) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_SYNC(var, src_lane) __shfl_sync(uint32_t(-1), var, src_lane) +#else + #define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane) +#endif + +#ifndef USE_ROCM + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \ + __shfl_down_sync(uint32_t(-1), var, lane_delta) +#else + #define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta) +#endif + +#ifndef USE_ROCM + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#else + #define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \ + hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL) +#endif diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8.h new file mode 100644 index 000000000..f9c80fcde --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8.h @@ -0,0 +1,137 @@ +#pragma once + +#ifdef __HIPCC__ + #include +#else + #include + #include + #include + #include +#endif + +#include "hip_float8_impl.h" + +struct alignas(1) hip_fp8 { + struct from_bits_t {}; + HIP_FP8_HOST_DEVICE static constexpr from_bits_t from_bits() { + return from_bits_t(); + } + uint8_t data; + + hip_fp8() = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(const hip_fp8&) = default; + HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v) = delete; + explicit HIP_FP8_HOST_DEVICE constexpr hip_fp8(uint8_t v, from_bits_t) + : data(v) {} + +#ifdef __HIP__MI300__ + // NOTE: ON-DEVICE... always optimal bias + explicit HIP_FP8_DEVICE hip_fp8(float v) + : data(hip_fp8_impl::to_fp8_from_fp32(v)) {} + + explicit HIP_FP8_DEVICE hip_fp8(_Float16 v) + : hip_fp8(static_cast(v)) {} + + // Host only implementation using s/w simulation + explicit HIP_FP8_HOST +#else // __HIP__MI300__ + // both Host and DEVICE for non-MI300 using s/w simulation + explicit HIP_FP8_HOST_DEVICE +#endif // __HIP__MI300__ + hip_fp8(float v) { + data = hip_fp8_impl::to_float8<4, 3, float, true /*negative_zero_nan*/, + true /*clip*/>(v); + } + + explicit HIP_FP8_HOST_DEVICE hip_fp8(double v) + : hip_fp8(static_cast(v)) {} + +#ifdef __HIP__MI300__ + // upcast using device specific intrinsic + explicit inline HIP_FP8_DEVICE operator float() const { + float fval; + uint32_t i32val = static_cast(data); + + // upcast + asm volatile("v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0" + : "=v"(fval) + : "v"(i32val)); + + return fval; + } + + explicit inline HIP_FP8_HOST operator float() const +#else // __HIP__MI300__ + explicit inline HIP_FP8_HOST_DEVICE operator float() const +#endif // __HIP__MI300__ + { + return hip_fp8_impl::from_float8<4, 3, float, true /*negative_zero_nan*/>( + data); + } +}; + +namespace std { +inline hip_fp8 sin(hip_fp8 a) { return hip_fp8(sinf(float(a))); } +inline hip_fp8 cos(hip_fp8 a) { return hip_fp8(cosf(float(a))); } +HIP_FP8_HOST_DEVICE constexpr hip_fp8 real(const hip_fp8& a) { return a; } +} // namespace std + +// Special operator overloading +inline std::ostream& operator<<(std::ostream& os, const hip_fp8& f8) { + return os << float(f8); +} + +// all + operator overloading with mixed types +// mixed types, always converts to f32, does computation in f32, and returns +// float +inline HIP_FP8_HOST_DEVICE float operator+(const float fa, hip_fp8 b) { + return (fa + float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator+(hip_fp8 a, const float fb) { + return (float(a) + fb); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8 operator+(hip_fp8 a, hip_fp8 b) { + return hip_fp8(float(a) + float(b)); +} + +inline HIP_FP8_HOST_DEVICE hip_fp8& operator+=(hip_fp8& a, hip_fp8 b) { + return a = hip_fp8(float(a) + float(b)); +} + +// overloading multiplication, always returns float, +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, hip_fp8 b) { + return float(a) * float(b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(float a, hip_fp8 b) { + return (a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(hip_fp8 a, float b) { + return (float(a) * b); +} + +inline HIP_FP8_HOST_DEVICE float operator*(int32_t a, hip_fp8 b) { + return ((float)a * float(b)); +} + +inline HIP_FP8_HOST_DEVICE float operator*(double a, hip_fp8 b) { + return ((float)a * float(b)); +} + +// overloading for compare +inline HIP_FP8_HOST_DEVICE bool operator==(hip_fp8 a, hip_fp8 b) { + return (a.data == b.data); +} +inline HIP_FP8_HOST_DEVICE bool operator!=(hip_fp8 a, hip_fp8 b) { + return (a.data != b.data); +} + +inline HIP_FP8_HOST_DEVICE bool operator>=(hip_fp8 a, hip_fp8 b) { + return static_cast(a) >= static_cast(b); +} +inline HIP_FP8_HOST_DEVICE bool operator>(hip_fp8 a, hip_fp8 b) { + return static_cast(a) > static_cast(b); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8_impl.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8_impl.h new file mode 100644 index 000000000..90251c353 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/hip_float8_impl.h @@ -0,0 +1,316 @@ +#pragma once + +#if defined(__HIPCC__) && \ + (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)) + #define __HIP__MI300__ +#endif + +#ifdef __HIPCC__ + #define HIP_FP8_HOST_DEVICE __host__ __device__ + #define HIP_FP8_HOST __host__ + #define HIP_FP8_DEVICE __device__ +#else + #define HIP_FP8_HOST_DEVICE + #define HIP_FP8_HOST + #define HIP_FP8_DEVICE +#endif + +namespace hip_fp8_impl { + +#ifdef __HIP__MI300__ +HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v) { + uint8_t i8data; + union { + float fval; + uint32_t i32val; + uint8_t i8val[4]; // NOTE: not endian independent + } val; + + uint32_t ival = 0; + val.fval = v; + + if ((val.i32val & 0x7F800000) != + 0x7F800000) { /// propagate NAN/INF, no clipping + val.fval = __builtin_amdgcn_fmed3f(val.fval, 240.0, -240.0); + } + + ival = __builtin_amdgcn_cvt_pk_fp8_f32(val.fval, val.fval, ival, + false); // false -> WORD0 + val.i32val = ival; + i8data = val.i8val[0]; + + return i8data; +} +#endif // __HIP__MI300__ + +HIP_FP8_HOST inline int clz(uint32_t x) { return __builtin_clz(x); } +#if defined(__HIPCC__) || defined(__CUDA_ARCH__) +HIP_FP8_DEVICE inline int clz(uint32_t x) { return __clz(x); } +#endif + +template +HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, + uint32_t rng = 0) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(wm + we == 7, "wm+we==7"); + static_assert(is_half || is_float, "Only half and float can be cast to f8"); + + const int mfmt = (sizeof(T) == 4) ? 23 : 10; + uint32_t x; + if (sizeof(T) == 4) { + x = reinterpret_cast(_x); + } else { + x = reinterpret_cast(_x); + } + + uint32_t head, mantissa; + int exponent, bias; + uint32_t sign; + + if (sizeof(T) == 4) { + head = x & 0xFF800000; + mantissa = x & 0x7FFFFF; + exponent = (head >> 23) & 0xFF; + sign = head >> 31; + bias = 127; + } else { + head = x & 0xFC00; + mantissa = x & 0x3FF; + exponent = (head >> 10) & 0x1F; + sign = head >> 15; + bias = 15; + } + + uint32_t signed_inf = (sign << 7) + (((1 << we) - 1) << wm); + + // Deal with inf and NaNs + if (negative_zero_nan) { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return 0x80; + } + } else { + // if(__hisinf(x) || __hisnan(x)) + if ((x & 0x7C00) == 0x7C00) { + return 0x80; + } + } + } else { + if (sizeof(T) == 4) { + if ((x & 0x7F800000) == 0x7F800000) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } else { + if ((x & 0x7C00) == 0x7C00) { + return signed_inf + (mantissa != 0 ? 1 : 0); + } + } + } + if (x == 0) { + return 0; + } + + // First need to check if it is normal or denorm as there is a difference of + // implicit 1 Then need to adjust the exponent to align with the F8 exponent, + // in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng + // to mantissa and truncate. And for RNE, no need to add rng. Then probably + // need to check whether there is carry and adjust exponent and mantissa again + + // For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent + // bits + const int f8_bias = (1 << (we - 1)) - 1 + (negative_zero_nan ? 1 : 0); + const int f8_denormal_act_exponent = + 1 - f8_bias; // actual exponent of f8 denormal + // act_exponent is the actual exponent of fp32/fp16 (after subtracting bias) + // f8_exponent is the converted f8 exponent with bias encoding + // exponent_diff is the diff between fp32/fp16 exponent and f8 exponent, + // the difference needs to be adjusted and mantissa shifted + int act_exponent, f8_exponent, exponent_diff; + + if (exponent == 0) { // fp32/fp16 is in denormal. + /* fp32 denormal is below 2^-127 so it is usually not a concern here, we +mostly concern fp16 here. In this case, f8 is usually in denormal. But there +could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has +exponent bias 16. It means that there are some numbers in fp16 denormal but they +are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers +where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8 +(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */ + act_exponent = exponent - bias + 1; + exponent_diff = + f8_denormal_act_exponent - + act_exponent; // actual exponent is exponent-bias+1 as it is denormal + } else { // fp32/fp16 is normal with implicit 1 + act_exponent = exponent - bias; + if (act_exponent <= f8_denormal_act_exponent) { + /* This is the case where fp32/fp16 is normal but it is in f8 denormal +range. For example fp8 nanoo mode, denormal exponent is -7, but if the +fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1, +Therefore it needs to be adjust to -6 and mantissa shift right by 1. +So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */ + exponent_diff = f8_denormal_act_exponent - act_exponent; + } else { // both fp32/fp16 and f8 are in normal range + exponent_diff = 0; // exponent_diff=0 does not mean there is no + // difference for this case, act_exponent could be + // larger. Just that it does not need shift mantissa + } + mantissa += (1 << mfmt); // Add the implicit 1 into mantissa + } + + bool midpoint = (mantissa & ((1 << (mfmt - wm + exponent_diff)) - 1)) == + static_cast(1 << (mfmt - wm + exponent_diff - 1)); + /* This part is a bit tricky. The judgment of whether it is a tie needs to be + done before we shift right as shift right could rip off some residual part + and make something not midpoint look like midpoint. For example, the fp16 + number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after + shift right by 4 bits, it would look like midpoint. +*/ + + if (exponent_diff > 0) { + mantissa >>= exponent_diff; + } else if (exponent_diff == -1) { + mantissa <<= -exponent_diff; + } + bool implicit_one = mantissa & (1 << mfmt); + // if there is no implicit 1, it means the f8 is denormal and need to adjust + // to denorm exponent + f8_exponent = (act_exponent + exponent_diff) /*actual f8 exponent*/ + + f8_bias - (implicit_one ? 0 : 1); + + // Now we have the exponent and mantissa adjusted + uint32_t drop_mask = (1 << (mfmt - wm)) - 1; + bool odd = mantissa & (1 << (mfmt - wm)); // if the least significant bit + // that is not truncated is 1 + mantissa += + (stoch ? rng : (midpoint ? (odd ? mantissa : mantissa - 1) : mantissa)) & + drop_mask; + + // Now we deal with overflow + if (f8_exponent == 0) { + if ((1 << mfmt) & mantissa) { + f8_exponent = 1; // denormal overflow to become normal, promote exponent + } + } else { + if ((1 << (mfmt + 1)) & mantissa) { + mantissa >>= 1; + f8_exponent++; + } + } + + mantissa >>= (mfmt - wm); + + // above range: quantize to maximum possible float of the same sign + const int max_exp = (1 << we) - (negative_zero_nan ? 1 : 2); + if (f8_exponent > max_exp) { + if (clip) { + mantissa = (1 << wm) - 1; + f8_exponent = max_exp; + } else { + return signed_inf; + } + } + + if (f8_exponent == 0 && mantissa == 0) { + return negative_zero_nan ? 0 : (sign << 7); + } + mantissa &= (1 << wm) - 1; + return (sign << 7) | (f8_exponent << wm) | mantissa; +} + +template +inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x) { +#ifdef __HIPCC__ + constexpr bool is_half = std::is_same::value; +#else + constexpr bool is_half = false; +#endif + constexpr bool is_float = std::is_same::value; + static_assert(is_half || is_float, "only half and float are supported"); + + constexpr int weo = is_half ? 5 : 8; + constexpr int wmo = is_half ? 10 : (is_float ? 23 : 7); + + T fInf, fNegInf, fNaN, fNeg0; + +#ifdef __HIPCC__ + if (is_half) { + const uint16_t ihInf = 0x7C00; + const uint16_t ihNegInf = 0xFC00; + const uint16_t ihNaN = 0x7C01; + const uint16_t ihNeg0 = 0x8000; + fInf = reinterpret_cast(ihInf); + fNegInf = reinterpret_cast(ihNegInf); + fNaN = reinterpret_cast(ihNaN); + fNeg0 = reinterpret_cast(ihNeg0); + } else +#endif + if (is_float) { + const uint32_t ifInf = 0x7F800000; + const uint32_t ifNegInf = 0xFF800000; + const uint32_t ifNaN = 0x7F800001; + const uint32_t ifNeg0 = 0x80000000; + fInf = reinterpret_cast(ifInf); + fNegInf = reinterpret_cast(ifNegInf); + fNaN = reinterpret_cast(ifNaN); + fNeg0 = reinterpret_cast(ifNeg0); + } + + if (x == 0) { + return 0; + } + + uint32_t sign = x >> 7; + uint32_t mantissa = x & ((1 << wm) - 1); + int exponent = (x & 0x7F) >> wm; + if (negative_zero_nan) { + if (x == 0x80) { + return fNaN; + } + } else { + if (x == 0x80) { + return fNeg0; + } + if (exponent == ((1 << we) - 1)) { + return (mantissa == 0) ? (sign ? fNegInf : fInf) : fNaN; + } + } + typename std::conditional::type retval; + if (we == 5 && is_half && !negative_zero_nan) { + retval = x << 8; + return reinterpret_cast(retval); + } + + const int exp_low_cutoff = + (1 << (weo - 1)) - (1 << (we - 1)) + 1 - (negative_zero_nan ? 1 : 0); + + // subnormal input + if (exponent == 0) { + // guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above + int sh = 1 + clz(mantissa) - (32 - wm); + mantissa <<= sh; + exponent += 1 - sh; + mantissa &= ((1 << wm) - 1); + } + exponent += exp_low_cutoff - 1; + mantissa <<= wmo - wm; + + // subnormal output (occurs when T=half, we=5, negative_zero_nan=true) + if (exponent <= 0) { + mantissa |= 1 << wmo; + mantissa >>= 1 - exponent; + exponent = 0; + } + + if (sizeof(T) == 2) { + retval = (sign << 15) | (exponent << 10) | mantissa; + } else { + retval = (sign << 31) | (exponent << 23) | mantissa; + } + return reinterpret_cast(retval); +} + +} // namespace hip_fp8_impl diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..1472a7279 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that works well on small but not super tiny shapes. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 128, + 16, + 128, + 16, + 16, + 4, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x128x16x128_16x16_4x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu new file mode 100644 index 000000000..c1c61eb45 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 128 != 0) || (N % 32 != 0) || (K % 128 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 128, + 32, + 128, + 32, + 32, + 2, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x128x32x128_32x32_2x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..df5494f46 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + + } +} + +template torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu new file mode 100644 index 000000000..ac9089541 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu @@ -0,0 +1,81 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 16 != 0) || (N % 32 != 0) || (K % 128 != 0); + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else{ + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 16, + 32, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x16x32x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..f18ed38f8 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 16, + 32, + 512, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x16x32x512_16x16_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..1dbcb0514 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 128, + 128, + 32, + 32, + 1, + 2, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x32x128x128_32x32_1x2_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..d3e728914 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A small kernel for small but not tiny shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 32 != 0) || (N % 16 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<2, 2, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x32x16x128_16x16_1x1_8x16x1_8x16x1_1x16x1x8_2x2x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..6c267cba1 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that works well on small but not super tiny shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 32 != 0) || (N % 64 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.cu new file mode 100644 index 000000000..9ebcb72a3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that works well on small but not super tiny shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 32 != 0) || (N % 64 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 32, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x32x64x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_8x8x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu new file mode 100644 index 000000000..8bd4ae239 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A small kernel for small but not tiny shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 64 != 0) || (N % 32 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 128, + 64, + 32, + 128, + 32, + 32, + 1, + 1, + S<8, 16, 1>, + S<8, 16, 1>, + S<1, 16, 1, 8>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_128x64x32x128_32x32_1x1_8x16x1_8x16x1_1x16x1x8_4x4x1_1x1_intrawave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu new file mode 100644 index 000000000..3a3a49427 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu new file mode 100644 index 000000000..7d261cf55 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // V5 kernel that works well on some medium shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.cu new file mode 100644 index 000000000..9bb59a793 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5.cu @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // V5 kernel that works well on some medium shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v5, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 128, + 32, + 32, + 2, + 2, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v5, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x128x128x128_32x32_2x2_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v5( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu new file mode 100644 index 000000000..9734c09dc --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 64 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 128, + 64, + 32, + 32, + 2, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x128x128x64_32x32_2x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu new file mode 100644 index 000000000..ab7aa472b --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 128 != 0) || (N % 64 != 0) || (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 128, + 64, + 128, + 32, + 32, + 2, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x128x64x128_32x32_2x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu new file mode 100644 index 000000000..d5f8c8c2c --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu @@ -0,0 +1,85 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // This kernel works well for many medium to large shapes. + + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + + bool kpad = K % 128 != 0; + + if (kpad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 224, + 256, + 128, + 16, + 16, + 7, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x224x256x128_16x16_7x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu new file mode 100644 index 000000000..85c419538 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A small kernel for small but not tiny shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 256 != 0) || (N % 128 != 0) || (K % 64 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 128, + 64, + 32, + 32, + 4, + 2, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x256x128x64_32x32_4x2_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_interwave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.cu new file mode 100644 index 000000000..4af680193 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 256 != 0) || (N % 224 != 0) || (K % 128 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 224, + 128, + 16, + 16, + 8, + 7, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 64, 1, 4>, + S<8, 8, 1>, + 2, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x256x224x128_16x16_8x7_8x32x1_8x32x1_1x64x1x4_8x8x1_2x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu new file mode 100644 index 000000000..d6eb900e0 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 128, + 16, + 16, + 8, + 8, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x256x256x128_16x16_8x8_8x32x1_8x32x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu new file mode 100644 index 000000000..ec0bf5333 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 256 != 0) || (N % 256 != 0) || (K % 64 != 0); + + // This kernel seems optimal in the most purely compute bound tasks. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 64, + 16, + 16, + 8, + 8, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 2, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x256x256x64_16x16_8x8_4x64x1_4x64x1_1x32x1x8_8x8x1_1x2_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu new file mode 100644 index 000000000..74008a4e3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4.cu @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (K % 64 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::KPadding>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 256, + 256, + 64, + 32, + 32, + 4, + 4, + S<4, 64, 1>, + S<4, 64, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v4, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl( + XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x256x256x64_32x32_4x4_4x64x1_4x64x1_1x32x1x8_8x8x1_1x1_intrawave_v4( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu new file mode 100644 index 000000000..61b1ddbc9 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3.cu @@ -0,0 +1,83 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // A kernel that seems to work well on mid sized tensors. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 64 != 0) || (N % 64 != 0) || (K % 128 != 0); + + // Dispatch based on whether padding is needed or not. + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 256, + 64, + 64, + 128, + 32, + 32, + 1, + 1, + S<8, 32, 1>, + S<8, 32, 1>, + S<1, 32, 1, 8>, + S<8, 8, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v3, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_256x64x64x128_32x32_1x1_8x32x1_8x32x1_1x32x1x8_8x8x1_1x1_intrawave_v3( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..1aab00230 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,82 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + + // Check if this input needs to be padded. + int M = size_to_dim_(XQ.dim() - 1, XQ.sizes()); + int N = WQ.size(0); + int K = WQ.size(1); + bool pad = (M % 16 != 0) || (N % 16 != 0) || (K % 128 != 0); + + if (pad) { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } else { + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 128, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); + } +} + +template torch::Tensor +a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_64x16x16x128_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.cu new file mode 100644 index 000000000..49d09a615 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1.cu @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // Secret kernel that seems good with small M but large N and K. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 256, + 16, + 16, + 1, + 1, + S<16, 4, 1>, + S<16, 4, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Intrawave, + ck::BlockGemmPipelineVersion::v1, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_64x16x16x256_16x16_1x1_16x4x1_16x4x1_1x4x1x16_4x4x1_1x1_intrawave_v1( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..997eeb8c6 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<32, 2, 1>, + S<32, 2, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2, + ck::tensor_operation::device::GemmSpecialization::Default>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_32x2x1_32x2x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..caa06fd3c --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 512, + 16, + 16, + 1, + 1, + S<8, 8, 1>, + S<8, 8, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_64x16x16x512_16x16_1x1_8x8x1_8x8x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu new file mode 100644 index 000000000..38d6de058 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/impl/a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2.cu @@ -0,0 +1,51 @@ +// SPDX-License-Identifier: MIT +// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. + +#include "gemm_a8w8_common.cuh" + +template +torch::Tensor +a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y) { + // The smallest kernel we have available. Works well for memory bound shapes. + using DeviceGemmInstance = DeviceGemmHelper< + DEDataType, + 64, + 16, + 16, + 64, + 16, + 16, + 1, + 1, + S<4, 16, 1>, + S<4, 16, 1>, + S<1, 16, 1, 4>, + S<4, 4, 1>, + 1, + 1, + ck::BlockGemmPipelineScheduler::Interwave, + ck::BlockGemmPipelineVersion::v2>; + // Run kernel instance. + return gemm_a8w8_rowwise_impl(XQ, WQ, x_scale, w_scale, Y); +} + +template torch::Tensor +a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); + +template torch::Tensor +a8w8_rowwise_64x16x16x64_16x16_1x1_4x16x1_4x16x1_1x16x1x4_4x4x1_1x1_interwave_v2( + torch::Tensor& XQ, + torch::Tensor& WQ, + torch::Tensor& x_scale, + torch::Tensor& w_scale, + torch::Tensor& Y); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/layernorm_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/layernorm_kernels.cu new file mode 100644 index 000000000..681f2bc59 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/layernorm_kernels.cu @@ -0,0 +1,627 @@ +#include +#include +#include + +#include "dispatch_utils.h" +// #include "attention/attention_dtypes.h" +#ifndef USE_ROCM + #include + #include + #include + #include +#else + #include + #include + #include + #include + // #include "quantization/fp8/amd/hip_float8.h" + // #include "quantization/fp8/amd/quant_utils.cuh" + +using __nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat162 = __hip_bfloat162; +#endif + + +namespace vllm { + +template +struct __align__(16) vec8_t { + scalar_t x, y, z, w, u, v, s, t; + + __device__ vec8_t() : x(0), y(0), z(0), w(0), u(0), v(0), s(0), t(0) {} + __device__ vec8_t(scalar_t x, scalar_t y, scalar_t z, scalar_t w, scalar_t u, + scalar_t v, scalar_t s, scalar_t t) + : x(x), y(y), z(z), w(w), u(u), v(v), s(s), t(t) {} + + __device__ vec8_t operator*(const vec8_t& other) const { + return vec8_t(x * other.x, y * other.y, z * other.z, w * other.w, + u * other.u, v * other.v, s * other.s, t * other.t); + } + + __device__ vec8_t operator*(const float& scale) const { + return vec8_t(x * scale, y * scale, z * scale, w * scale, u * scale, + v * scale, s * scale, t * scale); + } + + __device__ vec8_t operator+(const vec8_t& other) const { + return vec8_t(x + other.x, y + other.y, z + other.z, w + other.w, + u + other.u, v + other.v, s + other.s, t + other.t); + } + + __device__ void operator+=(const vec8_t& other) { + x += other.x; + y += other.y; + z += other.z; + w += other.w; + u += other.u; + v += other.v; + s += other.s; + t += other.t; + } + + __device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; } +}; + +// TODO(woosuk): Further optimize this kernel. +template +__global__ void rms_norm_kernel( + scalar_t* __restrict__ out, // [..., hidden_size] + const scalar_t* __restrict__ input, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + + vec8_t v8_variance = {0, 0, 0, 0, 0, 0, 0, 0}; + + vec8_t* vectorized_out = reinterpret_cast*>(out); + vec8_t const* vectorized_in = + reinterpret_cast const*>(input); + vec8_t const* vectorized_weight = + reinterpret_cast const*>(weight); + const int vec_hidden_size = hidden_size >> 3; + + // Compute variance. Be careful, hidden_size should multiple of 4. + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t x = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + v8_variance += x * x; + } + float v8_variance_sum = v8_variance.sum(); + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + float variance = + BlockReduce(reduceStore).Reduce(v8_variance_sum, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + vec8_t v8_in = vectorized_in[blockIdx.x * vec_hidden_size + idx]; + vec8_t v8_w = vectorized_weight[idx]; + vectorized_out[blockIdx.x * vec_hidden_size + idx] = + v8_in * s_variance * v8_w; + } +} + +// template +// __global__ void scaled_rms_norm_kernel( +// c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size] +// const scalar_t* __restrict__ input, // [..., hidden_size] +// const scalar_t* __restrict__ weight, // [hidden_size] +// const float scale, const float epsilon, const int num_tokens, +// const int hidden_size) { +// __shared__ float s_variance; +// float variance = 0.0f; + +// for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +// const float x = (float)input[blockIdx.x * hidden_size + idx]; +// variance += x * x; +// } + +// using BlockReduce = cub::BlockReduce; +// __shared__ typename BlockReduce::TempStorage reduceStore; +// variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + +// if (threadIdx.x == 0) { +// s_variance = rsqrtf(variance / hidden_size + epsilon); +// } +// __syncthreads(); + +// for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +// float x = (float)input[blockIdx.x * hidden_size + idx]; +// float r = (x * s_variance) * weight[idx] * scale; +// out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz( +// hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits()); +// } +// } + +/* Converter structs for the conversion from torch types to HIP/CUDA types, + and the associated type conversions within HIP/CUDA. These helpers need + to be implemented for now because the relevant type conversion + operators/constructors are not consistently implemented by HIP/CUDA, so + a generic conversion via type casts cannot be implemented. + + Each struct should have the member static constexpr bool `exists`: + If false, the optimized kernel is not used for the corresponding torch type. + If true, the struct should be fully defined as shown in the examples below. + */ +template +struct _typeConvert { + static constexpr bool exists = false; +}; + +#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)) +// CUDA < 12.0 runs into issues with packed type conversion +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __half; + using packed_hip_type = __half2; + + __device__ static inline float convert(hip_type x) { return __half2float(x); } + __device__ static inline float2 convert(packed_hip_type x) { + return __half22float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2half_rn(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22half2_rn(x); + } +}; + + #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +// CUDA_ARCH < 800 does not have BF16 support +// TODO: Add in ROCm support once public headers handle bf16 maturely +template <> +struct _typeConvert { + static constexpr bool exists = true; + using hip_type = __nv_bfloat16; + using packed_hip_type = __nv_bfloat162; + + __device__ static inline float convert(hip_type x) { + return __bfloat162float(x); + } + __device__ static inline float2 convert(packed_hip_type x) { + return __bfloat1622float2(x); + } + __device__ static inline hip_type convert(float x) { + return __float2bfloat16(x); + } + __device__ static inline packed_hip_type convert(float2 x) { + return __float22bfloat162_rn(x); + } +}; + #endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 +#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= + // 12000)) + +/* Vector POD struct to generate vectorized and packed FP16/BF16 ops + for appropriate specializations of fused_add_rms_norm_kernel. + Only functions that are necessary in that kernel are implemented. + Alignment to 16 bytes is required to use 128-bit global memory ops. + */ +template +struct alignas(16) _f16Vec { + /* Not theoretically necessary that width is a power of 2 but should + almost always be the case for optimization purposes */ + static_assert(width > 0 && (width & (width - 1)) == 0, + "Width is not a positive power of 2!"); + using Converter = _typeConvert; + using T1 = typename Converter::hip_type; + using T2 = typename Converter::packed_hip_type; + T1 data[width]; + + __device__ _f16Vec& operator+=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp += T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] += other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const _f16Vec& other) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + T2 temp{data[i], data[i + 1]}; + temp *= T2{other.data[i], other.data[i + 1]}; + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) data[i] *= other.data[i]; + } + return *this; + } + + __device__ _f16Vec& operator*=(const float scale) { + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 temp_f = Converter::convert(T2{data[i], data[i + 1]}); + temp_f.x *= scale; + temp_f.y *= scale; + T2 temp = Converter::convert(temp_f); + data[i] = temp.x; + data[i + 1] = temp.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float temp = Converter::convert(data[i]) * scale; + data[i] = Converter::convert(temp); + } + } + return *this; + } + + __device__ float sum_squares() const { + float result = 0.0f; + if constexpr (width % 2 == 0) { +#pragma unroll + for (int i = 0; i < width; i += 2) { + float2 z = Converter::convert(T2{data[i], data[i + 1]}); + result += z.x * z.x + z.y * z.y; + } + } else { +#pragma unroll + for (int i = 0; i < width; ++i) { + float x = Converter::convert(data[i]); + result += x * x; + } + } + return result; + } +}; + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ +template +__global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + // Sanity checks on our vector struct and type-punned pointer arithmetic + static_assert(std::is_pod_v<_f16Vec>); + static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + + const int vec_hidden_size = hidden_size / width; + __shared__ float s_variance; + float variance = 0.0f; + /* These and the argument pointers are all declared `restrict` as they are + not aliased in practice. Argument pointers should not be dereferenced + in this kernel as that would be undefined behavior */ + auto* __restrict__ input_v = + reinterpret_cast<_f16Vec*>(input); + auto* __restrict__ residual_v = + reinterpret_cast<_f16Vec*>(residual); + auto* __restrict__ weight_v = + reinterpret_cast*>(weight); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = input_v[id]; + temp += residual_v[id]; + variance += temp.sum_squares(); + residual_v[id] = temp; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { + int id = blockIdx.x * vec_hidden_size + idx; + _f16Vec temp = residual_v[id]; + temp *= s_variance; + temp *= weight_v[idx]; + input_v[id] = temp; + } +} + +/* Generic fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +template +__global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +fused_add_rms_norm_kernel( + scalar_t* __restrict__ input, // [..., hidden_size] + scalar_t* __restrict__ residual, // [..., hidden_size] + const scalar_t* __restrict__ weight, // [hidden_size] + const float epsilon, const int num_tokens, const int hidden_size) { + __shared__ float s_variance; + float variance = 0.0f; + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + scalar_t z = input[blockIdx.x * hidden_size + idx]; + z += residual[blockIdx.x * hidden_size + idx]; + float x = (float)z; + variance += x * x; + residual[blockIdx.x * hidden_size + idx] = z; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStore; + variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + + if (threadIdx.x == 0) { + s_variance = rsqrtf(variance / hidden_size + epsilon); + } + __syncthreads(); + + for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { + float x = (float)residual[blockIdx.x * hidden_size + idx]; + input[blockIdx.x * hidden_size + idx] = + ((scalar_t)(x * s_variance)) * weight[idx]; + } +} + +/* Function specialization in the case of FP16/BF16 tensors. + Additional optimizations we can make in this case are + packed and vectorized operations, which help with the + memory latency bottleneck. */ + +// template <> +// struct Vec { +// using Type = uint2; +// }; + +// template <> +// struct Vec { +// using Type = uint4; +// }; + +// template <> +// struct Vec { +// using Type = bf16_8_t; +// }; + +// template +// __global__ std::enable_if_t<(width > 0) && _typeConvert::exists> +// scaled_fused_add_rms_norm_kernel( +// c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size] +// scalar_t* __restrict__ input, // [..., hidden_size] +// scalar_t* __restrict__ residual, // [..., hidden_size] +// const scalar_t* __restrict__ weight, // [hidden_size] +// const float epsilon, const float scale, const int num_tokens, +// const int hidden_size) { +// using in_v_t = typename Vec::Type; +// using out_v_t = typename Vec::Type; +// // Sanity checks on our vector struct and type-punned pointer arithmetic +// static_assert(std::is_pod_v<_f16Vec>); +// static_assert(sizeof(_f16Vec) == sizeof(scalar_t) * width); + +// const int vec_hidden_size = hidden_size / width; +// __shared__ float s_variance; +// float variance = 0.0f; +// /* These and the argument pointers are all declared `restrict` as they are +// not aliased in practice. Argument pointers should not be dereferenced +// in this kernel as that would be undefined behavior */ +// auto* __restrict__ out_v = reinterpret_cast(out); +// auto* __restrict__ input_v = +// reinterpret_cast<_f16Vec*>(input); +// auto* __restrict__ residual_v = +// reinterpret_cast<_f16Vec*>(residual); +// auto* __restrict__ weight_v = +// reinterpret_cast*>(weight); + +// for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { +// int id = blockIdx.x * vec_hidden_size + idx; +// _f16Vec temp = input_v[id]; +// temp += residual_v[id]; +// variance += temp.sum_squares(); +// residual_v[id] = temp; +// } + +// using BlockReduce = cub::BlockReduce; +// __shared__ typename BlockReduce::TempStorage reduceStore; +// variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + +// if (threadIdx.x == 0) { +// s_variance = rsqrtf(variance / hidden_size + epsilon); +// } +// __syncthreads(); + +// for (int idx = threadIdx.x; idx < vec_hidden_size; idx += blockDim.x) { +// int id = blockIdx.x * vec_hidden_size + idx; +// _f16Vec temp = residual_v[id]; +// temp *= s_variance; +// temp *= weight_v[idx]; +// out_v_t temp_quant = fp8::scaled_vec_conversion( +// *reinterpret_cast(&temp), scale); +// out_v[id] = temp_quant; +// } +// } + +/* Generic scaled_fused_add_rms_norm_kernel + The width field is not used here but necessary for other specializations. + */ +// template +// __global__ std::enable_if_t<(width == 0) || !_typeConvert::exists> +// scaled_fused_add_rms_norm_kernel( +// c10::Float8_e4m3fnuz* __restrict__ out, // [..., hidden_size] +// scalar_t* __restrict__ input, // [..., hidden_size] +// scalar_t* __restrict__ residual, // [..., hidden_size] +// const scalar_t* __restrict__ weight, // [hidden_size] +// const float epsilon, const float scale, const int num_tokens, +// const int hidden_size) { +// __shared__ float s_variance; +// float variance = 0.0f; + +// for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +// scalar_t z = input[blockIdx.x * hidden_size + idx]; +// z += residual[blockIdx.x * hidden_size + idx]; +// float x = (float)z; +// variance += x * x; +// residual[blockIdx.x * hidden_size + idx] = z; +// } + +// using BlockReduce = cub::BlockReduce; +// __shared__ typename BlockReduce::TempStorage reduceStore; +// variance = BlockReduce(reduceStore).Reduce(variance, cub::Sum{}, blockDim.x); + +// if (threadIdx.x == 0) { +// s_variance = rsqrtf(variance / hidden_size + epsilon); +// } +// __syncthreads(); + +// for (int idx = threadIdx.x; idx < hidden_size; idx += blockDim.x) { +// float x = (float)residual[blockIdx.x * hidden_size + idx]; +// float r = (x * s_variance) * (float)weight[idx] / scale; +// out[blockIdx.x * hidden_size + idx] = c10::Float8_e4m3fnuz( +// hip_fp8(r).data, c10::Float8_e4m3fnuz::from_bits()); +// } +// } + +} // namespace vllm + +void rms_norm(torch::Tensor& out, // [..., hidden_size] + torch::Tensor& input, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "rms_norm_kernel", [&] { + vllm::rms_norm_kernel<<>>( + out.data_ptr(), input.data_ptr(), + weight.data_ptr(), epsilon, num_tokens, hidden_size); + }); +} + +// void scaled_rms_norm(torch::Tensor& out, // [..., hidden_size] +// torch::Tensor& input, // [..., hidden_size] +// torch::Tensor& weight, // [hidden_size] +// torch::Tensor& scale, double epsilon) { +// int hidden_size = input.size(-1); +// int num_tokens = input.numel() / hidden_size; + +// dim3 grid(num_tokens); +// dim3 block(std::min(hidden_size, 1024)); +// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); +// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// VLLM_DISPATCH_FLOATING_TYPES( +// input.scalar_type(), "scaled_rms_norm_kernel", [&] { +// vllm::scaled_rms_norm_kernel<<>>( +// out.data_ptr(), input.data_ptr(), +// weight.data_ptr(), 1.0 / (*scale.data_ptr()), +// epsilon, num_tokens, hidden_size); +// }); +// } + +#define LAUNCH_FUSED_ADD_RMS_NORM(width) \ + VLLM_DISPATCH_FLOATING_TYPES( \ + input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \ + vllm::fused_add_rms_norm_kernel \ + <<>>(input.data_ptr(), \ + residual.data_ptr(), \ + weight.data_ptr(), epsilon, \ + num_tokens, hidden_size); \ + }); + +void fused_add_rms_norm(torch::Tensor& input, // [..., hidden_size] + torch::Tensor& residual, // [..., hidden_size] + torch::Tensor& weight, // [hidden_size] + double epsilon) { + int hidden_size = input.size(-1); + int num_tokens = input.numel() / hidden_size; + + dim3 grid(num_tokens); + /* This kernel is memory-latency bound in many scenarios. + When num_tokens is large, a smaller block size allows + for increased block occupancy on CUs and better latency + hiding on global mem ops. */ + const int max_block_size = (num_tokens < 256) ? 1024 : 256; + dim3 block(std::min(hidden_size, max_block_size)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + /*If the tensor types are FP16/BF16, try to use the optimized kernel + with packed + vectorized ops. + Max optimization is achieved with a width-8 vector of FP16/BF16s + since we can load at most 128 bits at once in a global memory op. + However, this requires each tensor's data to be aligned to 16 + bytes. + */ + auto inp_ptr = reinterpret_cast(input.data_ptr()); + auto res_ptr = reinterpret_cast(residual.data_ptr()); + auto wt_ptr = reinterpret_cast(weight.data_ptr()); + bool ptrs_are_aligned = + inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; + if (ptrs_are_aligned && hidden_size % 8 == 0) { + LAUNCH_FUSED_ADD_RMS_NORM(8); + } else { + LAUNCH_FUSED_ADD_RMS_NORM(0); + } +} + +// #define LAUNCH_SCALED_FUSED_ADD_RMS_NORM(width) \ +// VLLM_DISPATCH_FLOATING_TYPES( \ +// input.scalar_type(), "scaled_fused_add_rms_norm_kernel", [&] { \ +// vllm::scaled_fused_add_rms_norm_kernel \ +// <<>>( \ +// out.data_ptr(), \ +// input.data_ptr(), residual.data_ptr(), \ +// weight.data_ptr(), epsilon, \ +// *scale.data_ptr(), num_tokens, hidden_size); \ +// }); + +// void scaled_fused_add_rms_norm(torch::Tensor& out, // [..., hidden_size] +// torch::Tensor& input, // [..., hidden_size] +// torch::Tensor& residual, // [..., hidden_size] +// torch::Tensor& weight, // [hidden_size] +// torch::Tensor& scale, double epsilon) { +// int hidden_size = input.size(-1); +// int num_tokens = input.numel() / hidden_size; + +// dim3 grid(num_tokens); +// /* This kernel is memory-latency bound in many scenarios. +// When num_tokens is large, a smaller block size allows +// for increased block occupancy on CUs and better latency +// hiding on global mem ops. */ +// const int max_block_size = (num_tokens < 256) ? 1024 : 256; +// dim3 block(std::min(hidden_size, max_block_size)); +// const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); +// const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); +// /*If the tensor types are FP16/BF16, try to use the optimized kernel +// with packed + vectorized ops. +// Max optimization is achieved with a width-8 vector of FP16/BF16s +// since we can load at most 128 bits at once in a global memory op. +// However, this requires each tensor's data to be aligned to 16 +// bytes. +// */ +// auto inp_ptr = reinterpret_cast(input.data_ptr()); +// auto res_ptr = reinterpret_cast(residual.data_ptr()); +// auto wt_ptr = reinterpret_cast(weight.data_ptr()); +// bool ptrs_are_aligned = +// inp_ptr % 16 == 0 && res_ptr % 16 == 0 && wt_ptr % 16 == 0; +// if (ptrs_are_aligned && hidden_size % 8 == 0) { +// LAUNCH_SCALED_FUSED_ADD_RMS_NORM(8); +// } else { +// LAUNCH_SCALED_FUSED_ADD_RMS_NORM(0); +// } +// } diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_align_block_size_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_align_block_size_kernels.cu new file mode 100644 index 000000000..94aec6f3b --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_align_block_size_kernels.cu @@ -0,0 +1,140 @@ +#include +#include + +#include +#include + +#include "hip_compat.h" +#include "dispatch_utils.h" + +#define CEILDIV(x, y) (((x) + (y) - 1) / (y)) + +namespace vllm { + +namespace { +__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row, + int32_t col) { + // don't worry about overflow because num_experts is relatively small + return row * total_col + col; +} +} // namespace + +template +__global__ void moe_align_block_size_kernel(scalar_t* __restrict__ topk_ids, + int32_t* sorted_token_ids, + int32_t* expert_ids, + int32_t* token_nums, + int32_t* total_tokens_post_pad, + int32_t num_experts, + int32_t block_size, size_t numel) { + const size_t tokens_per_thread = CEILDIV(numel, blockDim.x); + const size_t start_idx = threadIdx.x * tokens_per_thread; + + extern __shared__ int32_t shared_mem[]; + + int32_t* tokens_cnts = + shared_mem; // 2d tensor with shape (num_experts + 1, num_experts) + int32_t* cumsum = + shared_mem + (num_experts + 1) * + num_experts; // 1d tensor with shape (num_experts + 1) + + for (int i = 0; i < num_experts; ++i) { + tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0; + } + + /** + * In the first step we compute token_cnts[thread_index + 1][expert_index], + * which counts how many tokens in the token shard of thread_index are + * assigned to expert expert_index. + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + ++tokens_cnts[index(num_experts, threadIdx.x + 1, topk_ids[i])]; + } + + __syncthreads(); + + // For each expert we accumulate the token counts from the different threads. + tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0; + for (int i = 1; i <= blockDim.x; ++i) { + tokens_cnts[index(num_experts, i, threadIdx.x)] += + tokens_cnts[index(num_experts, i - 1, threadIdx.x)]; + } + + __syncthreads(); + + // We accumulate the token counts of all experts in thread 0. + if (threadIdx.x == 0) { + cumsum[0] = 0; + for (int i = 1; i <= num_experts; ++i) { + cumsum[i] = cumsum[i - 1] + + CEILDIV(tokens_cnts[index(num_experts, blockDim.x, i - 1)], + block_size); + } + *total_tokens_post_pad = cumsum[num_experts] * block_size; + } + + __syncthreads(); + + /** + * For each expert, each thread processes the tokens of the corresponding + * blocks and stores the corresponding expert_id for each block. + */ + auto num = tokens_cnts[index(num_experts, blockDim.x, threadIdx.x)]; + for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1]; i++) + { + expert_ids[i] = threadIdx.x; + token_nums[i] = num; + num-=block_size; + } + + /** + * Each thread processes a token shard, calculating the index of each token + * after sorting by expert number. Given the example topk_ids = + * [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *, + * *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a + * padding value(preset in python). + */ + for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) { + int32_t expert_id = topk_ids[i]; + /** The cumsum[expert_id] stores the starting index of the tokens that the + * expert with expert_id needs to process, and + * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens + * processed by the expert with expert_id within the current thread's token + * shard. + */ + int32_t rank_post_pad = + tokens_cnts[index(num_experts, threadIdx.x, expert_id)] + + cumsum[expert_id] * block_size; + sorted_token_ids[rank_post_pad] = i; + ++tokens_cnts[index(num_experts, threadIdx.x, expert_id)]; + } +} +} // namespace vllm + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor token_nums, + torch::Tensor num_tokens_post_pad) +{ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_INTEGRAL_TYPES( + topk_ids.scalar_type(), "moe_align_block_size_kernel", [&] { + // calc needed amount of shared mem for `tokens_cnts` and `cumsum` + // tensors + const int32_t shared_mem = + ((num_experts + 1) * num_experts + (num_experts + 1)) * + sizeof(int32_t); + + // set dynamic shared mem + auto kernel = vllm::moe_align_block_size_kernel; + AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize( + (void*)kernel, shared_mem)); + kernel<<<1, num_experts, shared_mem, stream>>>( + topk_ids.data_ptr(), sorted_token_ids.data_ptr(), + experts_ids.data_ptr(), + token_nums.data_ptr(), + num_tokens_post_pad.data_ptr(), num_experts, block_size, + topk_ids.numel()); + }); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_ops.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_ops.h new file mode 100644 index 000000000..99f9395eb --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/moe_ops.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +void topk_softmax(torch::Tensor &topk_weights, torch::Tensor &topk_indices, + torch::Tensor &token_expert_indices, + torch::Tensor &gating_output, + bool need_renorm); + +void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, + int64_t block_size, torch::Tensor sorted_token_ids, + torch::Tensor experts_ids, + torch::Tensor token_nums, + torch::Tensor num_tokens_post_pad); + +void silu_and_mul(torch::Tensor &out, torch::Tensor &input); + +void rms_norm(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, + double epsilon); + +void fused_add_rms_norm(torch::Tensor &input, torch::Tensor &residual, + torch::Tensor &weight, double epsilon); + +// ck kernel +void layernorm2d(torch::Tensor &out, torch::Tensor &input, torch::Tensor &weight, torch::Tensor &bias, + double epsilon); + +void wvSpltK(at::Tensor &in_a, at::Tensor &in_b, at::Tensor &out_c, + const int64_t N_in, const int64_t CuCount); + +void rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int64_t head_size, + torch::Tensor &cos_sin_cache, bool is_neox); + +void batched_rotary_embedding(torch::Tensor &positions, torch::Tensor &query, + torch::Tensor &key, int64_t head_size, + torch::Tensor &cos_sin_cache, bool is_neox, + int64_t rot_dim, + torch::Tensor &cos_sin_cache_offsets); + +void moe_sum(torch::Tensor &input, torch::Tensor &output); + +// all reduce +using fptr_t = int64_t; +fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data, + const std::vector &handles, + const std::vector &offsets, int64_t rank, + bool full_nvlink); +void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out); +void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer, + torch::Tensor &out); +void dispose(fptr_t _fa); +int64_t meta_size(); +void register_buffer(fptr_t _fa, torch::Tensor &t, + const std::vector &handles, + const std::vector &offsets); +std::tuple> get_graph_buffer_ipc_meta( + fptr_t _fa); +void register_graph_buffers(fptr_t _fa, const std::vector &handles, + const std::vector> &offsets); +#ifdef USE_ROCM +torch::Tensor allocate_meta_buffer(int64_t size); +torch::Tensor get_meta_buffer_ipc_handle(torch::Tensor &inp); +#endif \ No newline at end of file diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/paged_attn_ops.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/paged_attn_ops.h new file mode 100644 index 000000000..376c02ecc --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/paged_attn_ops.h @@ -0,0 +1,41 @@ +#pragma once + +#include + +void LLMM1( + at::Tensor& in_a, at::Tensor& in_b, at::Tensor& out_c, + const int64_t rows_per_block); + + +void paged_attention_rocm( + torch::Tensor& out, torch::Tensor& exp_sums, + torch::Tensor& max_logits, torch::Tensor& tmp_out, + torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, + double scale, torch::Tensor& block_tables, + torch::Tensor& context_lens, int64_t block_size, + int64_t max_context_len, + const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, + double v_scale); + +void paged_attention_v1( + torch::Tensor& out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); + +void paged_attention_v2( + torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits, + torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache, + torch::Tensor& value_cache, int64_t num_kv_heads, double scale, + torch::Tensor& block_tables, torch::Tensor& seq_lens, int64_t block_size, + int64_t max_seq_len, const c10::optional& alibi_slopes, + const std::string& kv_cache_dtype, double k_scale, double v_scale, + const int64_t tp_rank, const int64_t blocksparse_local_blocks, + const int64_t blocksparse_vert_stride, const int64_t blocksparse_block_size, + const int64_t blocksparse_head_sliding_step); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/pos_encoding_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/pos_encoding_kernels.cu new file mode 100644 index 000000000..4df07af7e --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/pos_encoding_kernels.cu @@ -0,0 +1,203 @@ +#include +#include +#include + +#include "hip_compat.h" +#include "dispatch_utils.h" + +namespace vllm { + +template +inline __device__ void apply_token_rotary_embedding( + scalar_t* __restrict__ arr, const scalar_t* __restrict__ cos_ptr, + const scalar_t* __restrict__ sin_ptr, int rot_offset, int embed_dim) { + int x_index, y_index; + scalar_t cos, sin; + if (IS_NEOX) { + // GPT-NeoX style rotary embedding. + x_index = rot_offset; + y_index = embed_dim + rot_offset; + cos = VLLM_LDG(cos_ptr + x_index); + sin = VLLM_LDG(sin_ptr + x_index); + } else { + // GPT-J style rotary embedding. + x_index = 2 * rot_offset; + y_index = 2 * rot_offset + 1; + cos = VLLM_LDG(cos_ptr + x_index / 2); + sin = VLLM_LDG(sin_ptr + x_index / 2); + } + + const scalar_t x = arr[x_index]; + const scalar_t y = arr[y_index]; + arr[x_index] = x * cos - y * sin; + arr[y_index] = y * cos + x * sin; +} + +template +inline __device__ void apply_rotary_embedding( + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* cache_ptr, const int head_size, const int num_heads, + const int num_kv_heads, const int rot_dim, const int token_idx, + const int64_t query_stride, const int64_t key_stride) { + const int embed_dim = rot_dim / 2; + const scalar_t* cos_ptr = cache_ptr; + const scalar_t* sin_ptr = cache_ptr + embed_dim; + + const int nq = num_heads * embed_dim; + for (int i = threadIdx.x; i < nq; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * query_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } + + const int nk = num_kv_heads * embed_dim; + for (int i = threadIdx.x; i < nk; i += blockDim.x) { + const int head_idx = i / embed_dim; + const int64_t token_head = token_idx * key_stride + head_idx * head_size; + const int rot_offset = i % embed_dim; + apply_token_rotary_embedding( + key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim); + } +} + +template +__global__ void rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim; + + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); +} + +template +__global__ void batched_rotary_embedding_kernel( + const int64_t* __restrict__ positions, // [batch_size, seq_len] or + // [num_tokens] + scalar_t* __restrict__ query, // [batch_size, seq_len, num_heads, + // head_size] or [num_tokens, num_heads, + // head_size] + scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, + // head_size] or [num_tokens, num_kv_heads, + // head_size] + const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // + // 2] + const int64_t* __restrict__ cos_sin_cache_offsets, // [batch_size, seq_len] + // or [num_tokens] + const int rot_dim, const int64_t query_stride, const int64_t key_stride, + const int num_heads, const int num_kv_heads, const int head_size) { + // Each thread block is responsible for one token. + const int token_idx = blockIdx.x; + int64_t pos = positions[token_idx]; + int64_t cos_sin_cache_offset = cos_sin_cache_offsets[token_idx]; + const scalar_t* cache_ptr = + cos_sin_cache + (cos_sin_cache_offset + pos) * rot_dim; + + apply_rotary_embedding( + query, key, cache_ptr, head_size, num_heads, num_kv_heads, rot_dim, + token_idx, query_stride, key_stride); +} + +} // namespace vllm + +void rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox) { + int64_t num_tokens = query.numel() / query.size(-1); + int rot_dim = cos_sin_cache.size(1); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::rotary_embedding_kernel<<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), rot_dim, + query_stride, key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + rot_dim, query_stride, key_stride, num_heads, num_kv_heads, + head_size); + } + }); +} + +/* +Batched version of rotary embedding, pack multiple LoRAs together +and process in batched manner. +*/ +void batched_rotary_embedding( + torch::Tensor& positions, // [batch_size, seq_len] or [num_tokens] + torch::Tensor& query, // [batch_size, seq_len, num_heads * head_size] or + // [num_tokens, num_heads * head_size] + torch::Tensor& key, // [batch_size, seq_len, num_kv_heads * head_size] or + // [num_tokens, num_kv_heads * head_size] + int64_t head_size, + torch::Tensor& cos_sin_cache, // [max_position, rot_dim] + bool is_neox, int64_t rot_dim, + torch::Tensor& cos_sin_cache_offsets // [num_tokens] +) { + int64_t num_tokens = cos_sin_cache_offsets.size(0); + int num_heads = query.size(-1) / head_size; + int num_kv_heads = key.size(-1) / head_size; + int64_t query_stride = query.stride(-2); + int64_t key_stride = key.stride(-2); + + dim3 grid(num_tokens); + dim3 block(std::min(num_heads * rot_dim / 2, 512)); + const at::cuda::OptionalCUDAGuard device_guard(device_of(query)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES(query.scalar_type(), "rotary_embedding", [&] { + if (is_neox) { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } else { + vllm::batched_rotary_embedding_kernel + <<>>( + positions.data_ptr(), query.data_ptr(), + key.data_ptr(), cos_sin_cache.data_ptr(), + cos_sin_cache_offsets.data_ptr(), rot_dim, query_stride, + key_stride, num_heads, num_kv_heads, head_size); + } + }); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/quant_utils.cuh b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/quant_utils.cuh new file mode 100644 index 000000000..25cab5a91 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/quant_utils.cuh @@ -0,0 +1,710 @@ +#pragma once +#include "hip_float8.h" + +#include +#include +#include + +#include "attention_dtypes.h" + +namespace vllm { +#ifdef USE_ROCM + +namespace fp8 { + #ifdef ENABLE_FP8 + +template +__inline__ __device__ Tout vec_conversion(const Tin& x) { + return x; +} + +template +__inline__ __device__ Tout scaled_vec_conversion(const Tin& x, + const float scale) { + return x; +} + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +vec_conversion(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8); + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0]; + tmp.h2r.y.data = f2[1]; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = vec_conversion(static_cast(a)); + tmp.u16[1] = vec_conversion(static_cast(a >> 8U)); + return tmp.u32; + #endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const uint32_t& a) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = vec_conversion((uint16_t)a); + tmp.u32[1] = vec_conversion((uint16_t)(a >> 16U)); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const uint2& a) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = vec_conversion(a.x); + tmp.u64[1] = vec_conversion(a.y); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f); +} + +using __nv_bfloat162 = __hip_bfloat162; + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a) { + __nv_bfloat162 res; + res.x = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a); + res.y = vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U)); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const uint32_t& a) { + bf16_4_t res; + res.x = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a); + res.y = vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t vec_conversion(const uint2& a) { + bf16_4_t tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float vec_conversion(const uint8_t& a) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8); +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +vec_conversion(const uint16_t& a) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0]; + res.y = f2[1]; + return res; + #else + float2 res; + res.x = vec_conversion(static_cast(a)); + res.y = vec_conversion(static_cast(a >> 8U)); + return res; + #endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +vec_conversion(const uint32_t& a) { + Float4_ res; + res.x = vec_conversion((uint16_t)a); + res.y = vec_conversion((uint16_t)(a >> 16U)); + return res; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ vec_conversion(const uint2& a) { + Float4_ tmp1, tmp2; + tmp1 = vec_conversion(a.x); + tmp2 = vec_conversion(a.y); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const uint16_t& a) { + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data)}; + return f8.data; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t +vec_conversion(const __nv_bfloat16& a) { + hip_fp8 res{__bfloat162float(a)}; + return res.data; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t vec_conversion(const float& a) { + hip_fp8 f8(a); + return f8.data; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +vec_conversion(const uint32_t& a) { + Float4_ tmp = vec_conversion(a); + float4 res = make_float4(tmp.x.x, tmp.x.y, tmp.y.x, tmp.y.y); + return res; +} + +// float2 -> half2 +template <> +__inline__ __device__ uint32_t +vec_conversion(const float2& a) { + union { + half2 float16; + uint32_t uint32; + }; + + float16 = __float22half2_rn(a); + return uint32; +} + +// Float4 -> half2x2 +template <> +__inline__ __device__ uint2 vec_conversion(const Float4_& a) { + uint2 b; + float2 val; + val.x = a.x.x; + val.y = a.x.y; + b.x = vec_conversion(val); + + val.x = a.y.x; + val.y = a.y.y; + b.y = vec_conversion(val); + return b; +} + +// Float4 -> float4 +template <> +__inline__ __device__ float4 vec_conversion(const Float4_& a) { + float4 b; + b.x = a.x.x; + b.y = a.x.y; + b.z = a.y.x; + b.w = a.y.y; + return b; +} + +// Float8 -> half2x4 +template <> +__inline__ __device__ uint4 vec_conversion(const Float8_& a) { + uint4 b; + b.x = vec_conversion(a.x); + b.y = vec_conversion(a.y); + b.z = vec_conversion(a.z); + b.w = vec_conversion(a.w); + return b; +} + +// float2 -> bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +vec_conversion<__nv_bfloat162, float2>(const float2& a) { + __nv_bfloat162 b = __float22bfloat162_rn(a); + return b; +} + +// Float4 -> bfloat162x2 +template <> +__inline__ __device__ bf16_4_t +vec_conversion(const Float4_& a) { + bf16_4_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + return b; +} + +// Float8 -> bfloat162x4 +template <> +__inline__ __device__ bf16_8_t +vec_conversion(const Float8_& a) { + bf16_8_t b; + b.x = __float22bfloat162_rn(a.x); + b.y = __float22bfloat162_rn(a.y); + b.z = __float22bfloat162_rn(a.z); + b.w = __float22bfloat162_rn(a.w); + return b; +} + +/* Scaled and vectorized conversions, for data exchange between high and low + precision domains + + Convention of the scale in API, e.g: FP8_data = Quantization( + High_Precision_data / scale ) s.t. Quantize(HP / scale) => FP8 Dequant(FP8) * + scale => HP + + */ + +// fp8 -> half +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + __half_raw res; + res.data = static_cast(f8) * scale; + return res.x; +} + +// fp8x2 -> half2 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + union { + __half2_raw h2r; + uint32_t ui32; + } tmp; + tmp.h2r.x.data = f2[0] * scale; + tmp.h2r.y.data = f2[1] * scale; + return tmp.ui32; + #else + union { + uint16_t u16[2]; + uint32_t u32; + } tmp; + + tmp.u16[0] = + scaled_vec_conversion(static_cast(a), scale); + tmp.u16[1] = scaled_vec_conversion( + static_cast(a >> 8U), scale); + return tmp.u32; + #endif +} + +// fp8x4 -> half2x2 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const uint32_t& a, float scale) { + union { + uint2 u32x2; + uint32_t u32[2]; + } tmp; + tmp.u32[0] = scaled_vec_conversion((uint16_t)a, scale); + tmp.u32[1] = + scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return tmp.u32x2; +} + +// fp8x8 -> half2x4 +template <> +__inline__ __device__ uint4 scaled_vec_conversion(const uint2& a, + float scale) { + union { + uint4 u64x2; + uint2 u64[2]; + } tmp; + tmp.u64[0] = scaled_vec_conversion(a.x, scale); + tmp.u64[1] = scaled_vec_conversion(a.y, scale); + return tmp.u64x2; +} + +using __nv_bfloat16 = __hip_bfloat16; + +// fp8 -> __nv_bfloat16 +template <> +__inline__ __device__ __nv_bfloat16 +scaled_vec_conversion<__nv_bfloat16, uint8_t>(const uint8_t& a, float scale) { + hip_fp8 f8{a, hip_fp8::from_bits()}; + float f{f8}; + return __float2bfloat16(f * scale); +} + +// fp8x2 -> __nv_bfloat162 +template <> +__inline__ __device__ __nv_bfloat162 +scaled_vec_conversion<__nv_bfloat162, uint16_t>(const uint16_t& a, + float scale) { + __nv_bfloat162 res; + res.x = scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)a, scale); + res.y = + scaled_vec_conversion<__nv_bfloat16, uint8_t>((uint8_t)(a >> 8U), scale); + return res; +} + +// fp8x4 -> bf16_4_t +template <> +__inline__ __device__ bf16_4_t +scaled_vec_conversion(const uint32_t& a, float scale) { + bf16_4_t res; + res.x = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)a, scale); + res.y = scaled_vec_conversion<__nv_bfloat162, uint16_t>((uint16_t)(a >> 16U), + scale); + return res; +} + +// fp8x8 -> bf16_8_t +template <> +__inline__ __device__ bf16_8_t +scaled_vec_conversion(const uint2& a, float scale) { + bf16_4_t tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + bf16_8_t res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// fp8 -> float +template <> +__inline__ __device__ float scaled_vec_conversion( + const uint8_t& a, float scale) { + hip_fp8 fp8{a, hip_fp8::from_bits()}; + return static_cast(fp8) * scale; +} + +// fp8x2 -> float2 +template <> +__inline__ __device__ float2 +scaled_vec_conversion(const uint16_t& a, float scale) { + #if defined(__HIP__MI300__) + float2 res; + const auto& f2 = __builtin_amdgcn_cvt_pk_f32_fp8(a, 0); + res.x = f2[0] * scale; + res.y = f2[1] * scale; + return res; + #else + float2 res; + res.x = scaled_vec_conversion(static_cast(a), scale); + res.y = scaled_vec_conversion(static_cast(a >> 8U), + scale); + return res; + #endif +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ Float4_ +scaled_vec_conversion(const uint32_t& a, const float scale) { + Float4_ res; + res.x = scaled_vec_conversion((uint16_t)a, scale); + res.y = scaled_vec_conversion((uint16_t)(a >> 16U), scale); + return res; +} + +// fp8x4 -> float4 +template <> +__inline__ __device__ float4 +scaled_vec_conversion(const uint32_t& a, float scale) { + Float4_ res = scaled_vec_conversion(a, scale); + return {res.x.x, res.x.y, res.y.x, res.y.y}; +} + +// fp8x8 -> float8 +template <> +__inline__ __device__ Float8_ +scaled_vec_conversion(const uint2& a, float scale) { + Float4_ tmp1, tmp2; + tmp1 = scaled_vec_conversion(a.x, scale); + tmp2 = scaled_vec_conversion(a.y, scale); + Float8_ res; + res.x = tmp1.x; + res.y = tmp1.y; + res.z = tmp2.x; + res.w = tmp2.y; + return res; +} + +// half -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const uint16_t& a, float scale) { + __half_raw tmp; + tmp.x = a; + + hip_fp8 f8{static_cast(tmp.data / scale)}; + return f8.data; +} + +// halfx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const uint32_t& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = tmp.h2r.x.data / scale; + f2.f = tmp.h2r.y.data / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint32_t ui32; + __half2_raw h2r; + } tmp; + tmp.ui32 = a; + + union { + uint8_t ui8[2]; + uint16_t ui16; + } res; + res.ui8[0] = scaled_vec_conversion(tmp.h2r.x.x, scale); + res.ui8[1] = scaled_vec_conversion(tmp.h2r.y.x, scale); + return res.ui16; + #endif +} + +// half2x2 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const uint2& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// half2x4 -> fp8x8 +template <> +__inline__ __device__ uint2 scaled_vec_conversion(const uint4& a, + float scale) { + union { + uint2 ui2[2]; + uint4 ui4; + } tmp; + tmp.ui4 = a; + uint2 res; + res.x = scaled_vec_conversion(tmp.ui2[0], scale); + res.y = scaled_vec_conversion(tmp.ui2[1], scale); + return res; +} + +// bf16 -> fp8 +template <> +__inline__ __device__ uint8_t scaled_vec_conversion( + const __nv_bfloat16& a, float scale) { + hip_fp8 res{__bfloat162float(a) / scale}; + return res.data; +} + +// bf16x2 -> fp8x2 +template <> +__inline__ __device__ uint16_t scaled_vec_conversion( + const __nv_bfloat162& a, float scale) { + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; +} + +// bf16x4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const bf16_4_t& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion(a.x, scale); + tmp.ui16[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui32; +} + +// bf16x8 -> fp8x8 +template <> +__inline__ __device__ uint2 +scaled_vec_conversion(const bf16_8_t& a, float scale) { + uint2 res; + res.x = scaled_vec_conversion({a.x, a.y}, scale); + res.y = scaled_vec_conversion({a.z, a.w}, scale); + return res; +} + +// float -> fp8 +template <> +__inline__ __device__ uint8_t +scaled_vec_conversion(const float& a, float scale) { + hip_fp8 f8(a); + return f8.data; +} + +// floatx2 -> fp8x2 +template <> +__inline__ __device__ uint16_t +scaled_vec_conversion(const float2& a, float scale) { + #ifdef __HIP__MI300__ + union { + uint32_t ui32; + float f; + } f1, f2; + f1.f = a.x / scale; + f2.f = a.y / scale; + if ((f1.ui32 & 0x7F800000) != 0x7F800000) { + f1.f = __builtin_amdgcn_fmed3f(f1.f, 240.0, -240.0); + } + if ((f2.ui32 & 0x7F800000) != 0x7F800000) { + f2.f = __builtin_amdgcn_fmed3f(f2.f, 240.0, -240.0); + } + return __builtin_amdgcn_cvt_pk_fp8_f32(f1.f, f2.f, 0, 0); + #else + union { + uint8_t ui8[2]; + uint16_t ui16; + } tmp; + tmp.ui8[0] = scaled_vec_conversion(a.x, scale); + tmp.ui8[1] = scaled_vec_conversion(a.y, scale); + return tmp.ui16; + #endif +} + +// floatx4 -> fp8x4 +template <> +__inline__ __device__ uint32_t +scaled_vec_conversion(const float4& a, float scale) { + union { + uint16_t ui16[2]; + uint32_t ui32; + } tmp; + tmp.ui16[0] = scaled_vec_conversion({a.x, a.y}, scale); + tmp.ui16[1] = scaled_vec_conversion({a.z, a.w}, scale); + return tmp.ui32; +} + #endif // ENABLE_FP8 + +template +__inline__ __device__ Tout convert(const Tin& x) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return vec_conversion(x); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + +template +__inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) { + #ifdef ENABLE_FP8 + if constexpr (kv_dt == Fp8KVCacheDataType::kFp8E4M3) { + return scaled_vec_conversion(x, scale); + } + #endif + assert(false); + return {}; // Squash missing return statement warning +} + + // The following macro is used to dispatch the conversion function based on + // the data type of the key and value cache. The FN is a macro that calls a + // function with template. + #define DISPATCH_BY_KV_CACHE_DTYPE(SRC_DTYPE, KV_DTYPE, FN) \ + if (KV_DTYPE == "auto") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, float, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint16_t, vllm::Fp8KVCacheDataType::kAuto); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, __nv_bfloat16, vllm::Fp8KVCacheDataType::kAuto); \ + } else { \ + TORCH_CHECK(false, "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + if (KV_DTYPE == "fp8" || KV_DTYPE == "fp8_e4m3") { \ + if (SRC_DTYPE == at::ScalarType::Float) { \ + FN(float, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::Half) { \ + FN(uint16_t, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else if (SRC_DTYPE == at::ScalarType::BFloat16) { \ + FN(__nv_bfloat16, uint8_t, vllm::Fp8KVCacheDataType::kFp8E4M3); \ + } else { \ + TORCH_CHECK(false, \ + "Unsupported input type of kv cache: ", SRC_DTYPE); \ + } \ + } else { \ + TORCH_CHECK(false, "Unsupported data type of kv cache: ", KV_DTYPE); \ + } \ + } + +} // namespace fp8 +#endif // USE_ROCM +} // namespace vllm diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/rocm_ops.cpp b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/rocm_ops.cpp new file mode 100644 index 000000000..974ceefe7 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/rocm_ops.cpp @@ -0,0 +1,112 @@ +#include "moe_ops.h" +#include "paged_attn_ops.h" +#include "gemm_a8w8.h" +#include "cache.h" +#include + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) +{ + m.def("topk_softmax", &topk_softmax, + "Apply topk softmax to the gating outputs."); + m.def("moe_align_block_size", &moe_align_block_size, + "Aligning the number of tokens to be processed by each expert such " + "that it is divisible by the block size."); + m.def("silu_and_mul", &silu_and_mul, "Activation function used in SwiGLU."); + m.def("rms_norm", &rms_norm, "Apply Root Mean Square (RMS) Normalization to the input tensor."); + m.def("fused_add_rms_norm", &fused_add_rms_norm, "In-place fused Add and RMS Normalization"); + m.def("wvSpltK", &wvSpltK, "wvSpltK(Tensor in_a, Tensor in_b, Tensor! out_c, int N_in," + " int CuCount) -> ()"); + m.def("LLMM1", &LLMM1, "LLMM1(Tensor in_a, Tensor in_b, Tensor! out_c, int rows_per_block) -> " + "()"); + m.def("rotary_embedding", &rotary_embedding, "rotary_embedding"); + m.def("batched_rotary_embedding", &batched_rotary_embedding, "batched_rotary_embedding"); + m.def("moe_sum", &moe_sum, "moe_sum(Tensor! input, Tensor output) -> ()"); + m.def("paged_attention_rocm", &paged_attention_rocm, + "paged_attention_rocm(Tensor! out, Tensor exp_sums," + " Tensor max_logits, Tensor tmp_out," + " Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads," + " float scale, Tensor block_tables," + " Tensor context_lens, int block_size," + " int max_context_len," + " Tensor? alibi_slopes," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + m.def("paged_attention_v1", &paged_attention_v1, + "paged_attention_v1(" + " Tensor! out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + m.def("paged_attention_v2", &paged_attention_v2, + "paged_attention_v2(" + " Tensor! out, Tensor! exp_sums, Tensor! max_logits," + " Tensor! tmp_out, Tensor query, Tensor key_cache," + " Tensor value_cache, int num_kv_heads, float scale," + " Tensor block_tables, Tensor seq_lens, int block_size," + " int max_seq_len, Tensor? alibi_slopes," + " str kv_cache_dtype, float k_scale, float v_scale," + " int tp_rank, int blocksparse_local_blocks," + " int blocksparse_vert_stride, int blocksparse_block_size," + " int blocksparse_head_sliding_step) -> ()"); + + m.def("gemm_a8w8", &gemm_a8w8, "gemm_a8w8"); + m.def("swap_blocks", &swap_blocks, + "swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()"); + m.def("copy_blocks", ©_blocks, + "copy_blocks(Tensor(a!)[] key_caches, Tensor[](b!) value_caches, " + "Tensor block_mapping) -> ()"); + + m.def("reshape_and_cache", &reshape_and_cache, + "reshape_and_cache(Tensor key, Tensor value," + " Tensor! key_cache, Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + m.def("reshape_and_cache_flash", &reshape_and_cache_flash, + "reshape_and_cache_flash(Tensor key, Tensor value," + " Tensor! key_cache," + " Tensor! value_cache," + " Tensor slot_mapping," + " str kv_cache_dtype," + " float k_scale, float v_scale) -> ()"); + m.def("convert_fp8", &convert_fp8, + "convert_fp8(Tensor! dst_cache, Tensor src_cache, float scale, " + "str kv_cache_dtype) -> ()"); + + // Custom all-reduce kernels + m.def("init_custom_ar", &init_custom_ar, + "init_custom_ar(Tensor meta, Tensor rank_data, " + "str[] handles, int[] offsets, int rank, " + "bool full_nvlink) -> int"); + + m.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg(int fa, Tensor inp, Tensor! out) -> ()"); + + m.def("all_reduce_unreg", &all_reduce_unreg, + "all_reduce_unreg(int fa, Tensor inp, Tensor reg_buffer, Tensor! out) -> " + "()"); + + m.def("dispose", &dispose); + m.def("meta_size", &meta_size); + + m.def("register_buffer", ®ister_buffer, + "register_buffer(int fa, Tensor t, str[] handles, " + "int[] offsets) -> ()"); + + m.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta); + m.def("register_graph_buffers", ®ister_graph_buffers); +#ifdef USE_ROCM + m.def("allocate_meta_buffer", &allocate_meta_buffer); + m.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle); +#endif + +#if defined(FIND_CK) + // ck staff start + m.def("layernorm2d_fwd", &layernorm2d); + // ck staff end +#endif +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/topk_softmax_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/topk_softmax_kernels.cu new file mode 100644 index 000000000..29844a1aa --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/csrc/topk_softmax_kernels.cu @@ -0,0 +1,600 @@ +/* + * Adapted from https://github.com/NVIDIA/TensorRT-LLM/blob/v0.7.1/cpp/tensorrt_llm/kernels/mixtureOfExperts/moe_kernels.cu + * Copyright (c) 2024, The vLLM team. + * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: Apache-2.0 + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include "hip_compat.h" +#include "dispatch_utils.h" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) + +namespace vllm { +namespace moe { + +/// Aligned array type +template < + typename T, + /// Number of elements in the array + int N, + /// Alignment requirement in bytes + int Alignment = sizeof(T) * N +> +class alignas(Alignment) AlignedArray { + float data[N]; +}; + +// ====================== Softmax things =============================== +// We have our own implementation of softmax here so we can support transposing the output +// in the softmax kernel when we extend this module to support expert-choice routing. +template +__launch_bounds__(TPB) __global__ + void moeSoftmax(const float* input, const bool* finished, float* output, const int num_cols) +{ + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + __shared__ float normalizing_factor; + __shared__ float float_max; + + const int thread_row_offset = blockIdx.x * num_cols; + + cub::Sum sum; + float threadData(-FLT_MAX); + + // Don't touch finished rows. + if ((finished != nullptr) && finished[blockIdx.x]) + { + return; + } + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData = max(static_cast(input[idx]), threadData); + } + + const float maxElem = BlockReduce(tmpStorage).Reduce(threadData, cub::Max()); + if (threadIdx.x == 0) + { + float_max = maxElem; + } + __syncthreads(); + + threadData = 0; + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + threadData += exp((static_cast(input[idx]) - float_max)); + } + + const auto Z = BlockReduce(tmpStorage).Reduce(threadData, sum); + + if (threadIdx.x == 0) + { + normalizing_factor = 1.f / Z; + } + __syncthreads(); + + for (int ii = threadIdx.x; ii < num_cols; ii += TPB) + { + const int idx = thread_row_offset + ii; + const float val = exp((static_cast(input[idx]) - float_max)) * normalizing_factor; + output[idx] = val; + } +} + +template +__launch_bounds__(TPB) __global__ void moeTopK(const float* inputs_after_softmax, const bool* finished, float* output, + int* indices, int* source_rows, const int num_experts, const int k, const int start_expert, const int end_expert, const bool need_renorm) +{ + + using cub_kvp = cub::KeyValuePair; + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmpStorage; + + cub_kvp thread_kvp; + cub::ArgMax arg_max; + + const int num_rows = gridDim.x; + const int block_row = blockIdx.x; + + float renorm_value = 0.0f; + const bool row_is_active = finished ? !finished[block_row] : true; + const int thread_read_offset = blockIdx.x * num_experts; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + thread_kvp.key = 0; + thread_kvp.value = -1.f; // This is OK because inputs are probabilities + + cub_kvp inp_kvp; + for (int expert = threadIdx.x; expert < num_experts; expert += TPB) + { + const int idx = thread_read_offset + expert; + inp_kvp.key = expert; + inp_kvp.value = inputs_after_softmax[idx]; + + for (int prior_k = 0; prior_k < k_idx; ++prior_k) + { + const int prior_winning_expert = indices[k * block_row + prior_k]; + + if (prior_winning_expert == expert) + { + inp_kvp = thread_kvp; + } + } + + thread_kvp = arg_max(inp_kvp, thread_kvp); + } + + const cub_kvp result_kvp = BlockReduce(tmpStorage).Reduce(thread_kvp, arg_max); + if (threadIdx.x == 0) + { + // Ignore experts the node isn't responsible for with expert parallelism + const int expert = result_kvp.key; + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + const int idx = k * block_row + k_idx; + output[idx] = result_kvp.value; + indices[idx] = should_process_row ? (expert - start_expert) : num_experts; + assert(indices[idx] >= 0); + source_rows[idx] = k_idx * num_rows + block_row; + + if (need_renorm) + { + renorm_value += result_kvp.value; + } + } + __syncthreads(); + } + + if (need_renorm && threadIdx.x == 0 && renorm_value != 0.f) + { + renorm_value = 1 / renorm_value; + for (int k_idx = 0; k_idx < k; k_idx++) + { + int64_t const idx = k * block_row + k_idx; + output[idx] *= renorm_value; + } + } +} + +// ====================== TopK softmax things =============================== + +/* + A Top-K gating softmax written to exploit when the number of experts in the MoE layers + are a small power of 2. This allows us to cleanly share the rows among the threads in + a single warp and eliminate communication between warps (so no need to use shared mem). + + It fuses the softmax, max and argmax into a single kernel. + + Limitations: + 1) This implementation is intended for when the number of experts is a small power of 2. + 2) This implementation assumes k is small, but will work for any k. +*/ + +template +__launch_bounds__(WARPS_PER_CTA *WARP_SIZE) __global__ + void topkGatingSoftmax(const float *input, const bool *finished, float *output, const int num_rows, int *indices, + int *source_rows, const int k, const int start_expert, const int end_expert, const bool need_renorm) +{ + // We begin by enforcing compile time assertions and setting up compile time constants. + static_assert(VPT == (VPT & -VPT), "VPT must be power of 2"); + static_assert(NUM_EXPERTS == (NUM_EXPERTS & -NUM_EXPERTS), "NUM_EXPERTS must be power of 2"); + static_assert(BYTES_PER_LDG == (BYTES_PER_LDG & -BYTES_PER_LDG), "BYTES_PER_LDG must be power of 2"); + static_assert(BYTES_PER_LDG <= 16, "BYTES_PER_LDG must be leq 16"); + + // Number of bytes each thread pulls in per load + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static constexpr int ELTS_PER_ROW = NUM_EXPERTS; + static constexpr int THREADS_PER_ROW = ELTS_PER_ROW / VPT; + static constexpr int LDG_PER_THREAD = VPT / ELTS_PER_LDG; + + // Restrictions based on previous section. + static_assert(VPT % ELTS_PER_LDG == 0, "The elements per thread must be a multiple of the elements per ldg"); + static_assert(WARP_SIZE % THREADS_PER_ROW == 0, "The threads per row must cleanly divide the threads per warp"); + static_assert(THREADS_PER_ROW == (THREADS_PER_ROW & -THREADS_PER_ROW), "THREADS_PER_ROW must be power of 2"); + static_assert(THREADS_PER_ROW <= WARP_SIZE, "THREADS_PER_ROW can be at most warp size"); + + // We have NUM_EXPERTS elements per row. We specialize for small #experts + static constexpr int ELTS_PER_WARP = WARP_SIZE * VPT; + static constexpr int ROWS_PER_WARP = ELTS_PER_WARP / ELTS_PER_ROW; + static constexpr int ROWS_PER_CTA = WARPS_PER_CTA * ROWS_PER_WARP; + + // Restrictions for previous section. + static_assert(ELTS_PER_WARP % ELTS_PER_ROW == 0, "The elts per row must cleanly divide the total elt per warp"); + + // ===================== From this point, we finally start computing run-time variables. ======================== + + // Compute CTA and warp rows. We pack multiple rows into a single warp, and a block contains WARPS_PER_CTA warps. + // This, each block processes a chunk of rows. We start by computing the start row for each block. + const int cta_base_row = blockIdx.x * ROWS_PER_CTA; + + // Now, using the base row per thread block, we compute the base row per warp. + const int warp_base_row = cta_base_row + threadIdx.y * ROWS_PER_WARP; + + // The threads in a warp are split into sub-groups that will work on a row. + // We compute row offset for each thread sub-group + const int thread_row_in_warp = threadIdx.x / THREADS_PER_ROW; + const int thread_row = warp_base_row + thread_row_in_warp; + + // Threads with indices out of bounds should early exit here. + if (thread_row >= num_rows) + { + return; + } + const bool row_is_active = finished ? !finished[thread_row] : true; + + // We finally start setting up the read pointers for each thread. First, each thread jumps to the start of the + // row it will read. + const float* thread_row_ptr = input + thread_row * ELTS_PER_ROW; + + // Now, we compute the group each thread belong to in order to determine the first column to start loads. + const int thread_group_idx = threadIdx.x % THREADS_PER_ROW; + const int first_elt_read_by_thread = thread_group_idx * ELTS_PER_LDG; + const float* thread_read_ptr = thread_row_ptr + first_elt_read_by_thread; + + // Determine the pointer type to use to read in the data depending on the BYTES_PER_LDG template param. In theory, + // this can support all powers of 2 up to 16. + // NOTE(woosuk): The original implementation uses CUTLASS aligned array here. + // We defined our own aligned array and use it here to avoid the dependency on CUTLASS. + using AccessType = AlignedArray; + + // Finally, we pull in the data from global mem + float row_chunk[VPT]; + AccessType* row_chunk_vec_ptr = reinterpret_cast(&row_chunk); + const AccessType* vec_thread_read_ptr = reinterpret_cast(thread_read_ptr); +#pragma unroll + for (int ii = 0; ii < LDG_PER_THREAD; ++ii) + { + row_chunk_vec_ptr[ii] = vec_thread_read_ptr[ii * THREADS_PER_ROW]; + } + + // First, we perform a max reduce within the thread. We can do the max in fp16 safely (I think) and just + // convert to float afterwards for the exp + sum reduction. + float thread_max = row_chunk[0]; +#pragma unroll + for (int ii = 1; ii < VPT; ++ii) + { + thread_max = max(thread_max, row_chunk[ii]); + } + +// Now, we find the max within the thread group and distribute among the threads. We use a butterfly reduce. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + thread_max = max(thread_max, VLLM_SHFL_XOR_SYNC_WIDTH(thread_max, mask, THREADS_PER_ROW)); + } + + // From this point, thread max in all the threads have the max within the row. + // Now, we subtract the max from each element in the thread and take the exp. We also compute the thread local sum. + float row_sum = 0; +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = expf(row_chunk[ii] - thread_max); + row_sum += row_chunk[ii]; + } + +// Now, we perform the sum reduce within each thread group. Similar to the max reduce, we use a bufferfly pattern. +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + row_sum += VLLM_SHFL_XOR_SYNC_WIDTH(row_sum, mask, THREADS_PER_ROW); + } + + // From this point, all threads have the max and the sum for their rows in the thread_max and thread_sum variables + // respectively. Finally, we can scale the rows for the softmax. Technically, for top-k gating we don't need to + // compute the entire softmax row. We can likely look at the maxes and only compute for the top-k values in the row. + // However, this kernel will likely not be a bottle neck and it seems better to closer match torch and find the + // argmax after computing the softmax. + const float reciprocal_row_sum = 1.f / row_sum; + +#pragma unroll + for (int ii = 0; ii < VPT; ++ii) + { + row_chunk[ii] = row_chunk[ii] * reciprocal_row_sum; + } + + // Now, softmax_res contains the softmax of the row chunk. Now, I want to find the topk elements in each row, along + // with the max index. + int start_col = first_elt_read_by_thread; + static constexpr int COLS_PER_GROUP_LDG = ELTS_PER_LDG * THREADS_PER_ROW; + + float renorm_value = 0.0f; + for (int k_idx = 0; k_idx < k; ++k_idx) + { + // First, each thread does the local argmax + float max_val = row_chunk[0]; + int expert = start_col; +#pragma unroll + for (int ldg = 0, col = start_col; ldg < LDG_PER_THREAD; ++ldg, col += COLS_PER_GROUP_LDG) + { +#pragma unroll + for (int ii = 0; ii < ELTS_PER_LDG; ++ii) + { + float val = row_chunk[ldg * ELTS_PER_LDG + ii]; + + // No check on the experts here since columns with the smallest index are processed first and only + // updated if > (not >=) + if (val > max_val) + { + max_val = val; + expert = col + ii; + } + } + } + +// Now, we perform the argmax reduce. We use the butterfly pattern so threads reach consensus about the max. +// This will be useful for K > 1 so that the threads can agree on "who" had the max value. That thread can +// then blank out their max with -inf and the warp can run more iterations... +#pragma unroll + for (int mask = THREADS_PER_ROW / 2; mask > 0; mask /= 2) + { + float other_max = VLLM_SHFL_XOR_SYNC_WIDTH(max_val, mask, THREADS_PER_ROW); + int other_expert = VLLM_SHFL_XOR_SYNC_WIDTH(expert, mask, THREADS_PER_ROW); + + // We want lower indices to "win" in every thread so we break ties this way + if (other_max > max_val || (other_max == max_val && other_expert < expert)) + { + max_val = other_max; + expert = other_expert; + } + } + + // Write the max for this k iteration to global memory. + if (thread_group_idx == 0) + { + // Add a guard to ignore experts not included by this node + const bool node_uses_expert = expert >= start_expert && expert < end_expert; + const bool should_process_row = row_is_active && node_uses_expert; + + // The lead thread from each sub-group will write out the final results to global memory. (This will be a + // single) thread per row of the input/output matrices. + const int idx = k * thread_row + k_idx; + output[idx] = max_val; + indices[idx] = should_process_row ? (expert - start_expert) : NUM_EXPERTS; + source_rows[idx] = k_idx * num_rows + thread_row; + + // Accumulate renorm scalar + if (need_renorm) + { + renorm_value += max_val; + } + } + + // Finally, we clear the value in the thread with the current max if there is another iteration to run. + if (k_idx + 1 < k) + { + const int ldg_group_for_expert = expert / COLS_PER_GROUP_LDG; + const int thread_to_clear_in_group = (expert / ELTS_PER_LDG) % THREADS_PER_ROW; + + // Only the thread in the group which produced the max will reset the "winning" value to -inf. + if (thread_group_idx == thread_to_clear_in_group) + { + const int offset_for_expert = expert % ELTS_PER_LDG; + // Safe to set to any negative value since row_chunk values must be between 0 and 1. + row_chunk[ldg_group_for_expert * ELTS_PER_LDG + offset_for_expert] = -10000.f; + } + } + } + + if (need_renorm && thread_group_idx == 0 && renorm_value != 0.f) + { + renorm_value = 1 / renorm_value; + for (int k_idx = 0; k_idx < k; k_idx++) + { + int64_t const idx = k * thread_row + k_idx; + output[idx] *= renorm_value; + } + } +} + +namespace detail +{ +// Constructs some constants needed to partition the work across threads at compile time. +template +struct TopkConstants +{ + static constexpr int ELTS_PER_LDG = BYTES_PER_LDG / sizeof(float); + static_assert(EXPERTS / (ELTS_PER_LDG * WARP_SIZE) == 0 || EXPERTS % (ELTS_PER_LDG * WARP_SIZE) == 0, ""); + static constexpr int VECs_PER_THREAD = MAX(1, EXPERTS / (ELTS_PER_LDG * WARP_SIZE)); + static constexpr int VPT = VECs_PER_THREAD * ELTS_PER_LDG; + static constexpr int THREADS_PER_ROW = EXPERTS / VPT; + static constexpr int ROWS_PER_WARP = WARP_SIZE / THREADS_PER_ROW; +}; +} // namespace detail + +template +void topkGatingSoftmaxLauncherHelper(const float *input, const bool *finished, float *output, int *indices, + int *source_row, const int num_rows, const int k, const int start_expert, const int end_expert, const bool need_renorm, cudaStream_t stream) +{ + static constexpr std::size_t MAX_BYTES_PER_LDG = 16; + + static constexpr int BYTES_PER_LDG = MIN(MAX_BYTES_PER_LDG, sizeof(float) * EXPERTS); + using Constants = detail::TopkConstants; + static constexpr int VPT = Constants::VPT; + static constexpr int ROWS_PER_WARP = Constants::ROWS_PER_WARP; + const int num_warps = (num_rows + ROWS_PER_WARP - 1) / ROWS_PER_WARP; + const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB; + + dim3 block_dim(WARP_SIZE, WARPS_PER_TB); + topkGatingSoftmax<<>>( + input, finished, output, num_rows, indices, source_row, k, start_expert, end_expert, need_renorm); +} + +#define LAUNCH_SOFTMAX(NUM_EXPERTS, WARPS_PER_TB) \ + topkGatingSoftmaxLauncherHelper( \ + gating_output, nullptr, topk_weights, topk_indicies, \ + token_expert_indices, num_tokens, topk, 0, num_experts, need_renorm, \ + stream); + +void topkGatingSoftmaxKernelLauncher( + const float* gating_output, + float* topk_weights, + int* topk_indicies, + int* token_expert_indices, + float* softmax_workspace, + const int num_tokens, + const int num_experts, + const int topk, + const bool need_renorm, + cudaStream_t stream) { + static constexpr int WARPS_PER_TB = 4; + switch (num_experts) { + case 1: + LAUNCH_SOFTMAX(1, WARPS_PER_TB); + break; + case 2: + LAUNCH_SOFTMAX(2, WARPS_PER_TB); + break; + case 4: + LAUNCH_SOFTMAX(4, WARPS_PER_TB); + break; + case 8: + LAUNCH_SOFTMAX(8, WARPS_PER_TB); + break; + case 16: + LAUNCH_SOFTMAX(16, WARPS_PER_TB); + break; + case 32: + LAUNCH_SOFTMAX(32, WARPS_PER_TB); + break; + case 64: + LAUNCH_SOFTMAX(64, WARPS_PER_TB); + break; + case 128: + LAUNCH_SOFTMAX(128, WARPS_PER_TB); + break; + case 256: + LAUNCH_SOFTMAX(256, WARPS_PER_TB); + break; + default: { + TORCH_CHECK(softmax_workspace != nullptr, + "softmax_workspace must be provided for num_experts that are not a power of 2."); + static constexpr int TPB = 256; + moeSoftmax<<>>( + gating_output, nullptr, softmax_workspace, num_experts); + moeTopK<<>>( + softmax_workspace, nullptr, topk_weights, topk_indicies, token_expert_indices, + num_experts, topk, 0, num_experts, need_renorm); + } + } +} + +template +__global__ void moe_sum_kernel( + scalar_t* __restrict__ out, // [..., d] + const scalar_t* __restrict__ input, // [..., topk, d] + const int d) +{ + const int64_t token_idx = blockIdx.x; + for (int64_t idx = threadIdx.x; idx < d; idx += blockDim.x) { + scalar_t x = 0.0; + #pragma unroll + for (int k = 0; k < TOPK; ++k) { + x += VLLM_LDG(&input[token_idx * TOPK * d + k * d + idx]); + } + out[token_idx * d + idx] = x; + } +} + +} // namespace moe +} // namespace vllm + +void topk_softmax( + torch::Tensor& topk_weights, // [num_tokens, topk] + torch::Tensor& topk_indices, // [num_tokens, topk] + torch::Tensor& token_expert_indices, // [num_tokens, topk] + torch::Tensor& gating_output, // [num_tokens, num_experts] + bool need_renorm) +{ + const int num_experts = gating_output.size(-1); + const int num_tokens = gating_output.numel() / num_experts; + const int topk = topk_weights.size(-1); + + const bool is_pow_2 = (num_experts != 0) && ((num_experts & (num_experts - 1)) == 0); + const bool needs_workspace = !is_pow_2 || num_experts > 256; + const int64_t workspace_size = needs_workspace ? num_tokens * num_experts : 0; + + const at::cuda::OptionalCUDAGuard device_guard(device_of(gating_output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + torch::Tensor softmax_workspace = torch::empty({workspace_size}, gating_output.options()); + vllm::moe::topkGatingSoftmaxKernelLauncher( + gating_output.data_ptr(), + topk_weights.data_ptr(), + topk_indices.data_ptr(), + token_expert_indices.data_ptr(), + softmax_workspace.data_ptr(), + num_tokens, + num_experts, + topk, + need_renorm, + stream); +} + + +void moe_sum( + torch::Tensor& input, // [num_tokens, topk, hidden_size] + torch::Tensor& output) // [num_tokens, hidden_size] +{ + const int hidden_size = input.size(-1); + const int num_tokens = output.numel() / hidden_size; + const int topk = input.size(1); + + dim3 grid(num_tokens); + dim3 block(std::min(hidden_size, 1024)); + + const at::cuda::OptionalCUDAGuard device_guard(device_of(output)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + switch (topk) { + case 2: + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel + <<>>(output.data_ptr(), + input.data_ptr(), hidden_size); + }); + break; + + case 4: + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "moe_sum_kernel", [&] { + vllm::moe::moe_sum_kernel + <<>>(output.data_ptr(), + input.data_ptr(), hidden_size); + }); + break; + + default: + at::sum_out(output, input, 1); + break; + } +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/communication_op.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/communication_op.py new file mode 100644 index 000000000..e13505dc3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/communication_op.py @@ -0,0 +1,32 @@ +from typing import Any, Dict, Optional, Union + +import torch +import torch.distributed + +from .parallel_state import get_tp_group + + +def tensor_model_parallel_all_reduce(input_: torch.Tensor) -> torch.Tensor: + """All-reduce the input tensor across model parallel group.""" + return get_tp_group().all_reduce(input_) + + +def tensor_model_parallel_all_gather(input_: torch.Tensor, + dim: int = -1) -> torch.Tensor: + """All-gather the input tensor across model parallel group.""" + return get_tp_group().all_gather(input_, dim) + + +def tensor_model_parallel_gather(input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """Gather the input tensor across model parallel group.""" + return get_tp_group().gather(input_, dst, dim) + + +def broadcast_tensor_dict(tensor_dict: Optional[Dict[Any, Union[torch.Tensor, + Any]]] = None, + src: int = 0): + if not torch.distributed.is_initialized(): + return tensor_dict + return get_tp_group().broadcast_tensor_dict(tensor_dict, src) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/cuda_wrapper.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/cuda_wrapper.py new file mode 100644 index 000000000..ce1e19a83 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/cuda_wrapper.py @@ -0,0 +1,172 @@ +"""This file is a pure Python wrapper for the cudart library. +It avoids the need to compile a separate shared library, and is +convenient for use when we just need to call a few functions. +""" + +import ctypes +from dataclasses import dataclass +from typing import Any, Dict, List, Optional + +# this line makes it possible to directly load `libcudart.so` using `ctypes` +import torch # noqa + +# from vllm.logger import init_logger + +# logger = init_logger(__name__) + +# === export types and functions from cudart to Python === +# for the original cudart definition, please check +# https://docs.nvidia.com/cuda/cuda-runtime-api/index.html + +cudaError_t = ctypes.c_int +cudaMemcpyKind = ctypes.c_int + + +class cudaIpcMemHandle_t(ctypes.Structure): + _fields_ = [("internal", ctypes.c_byte * 128)] + + +@dataclass +class Function: + name: str + restype: Any + argtypes: List[Any] + + +def find_loaded_library(lib_name) -> Optional[str]: + """ + According to according to https://man7.org/linux/man-pages/man5/proc_pid_maps.5.html, + the file `/proc/self/maps` contains the memory maps of the process, which includes the + shared libraries loaded by the process. We can use this file to find the path of the + a loaded library. + """ # noqa + found = False + with open("/proc/self/maps") as f: + for line in f: + if lib_name in line: + found = True + break + if not found: + # the library is not loaded in the current process + return None + # if lib_name is libcudart, we need to match a line with: + # address /path/to/libcudart-hash.so.11.0 + start = line.index("/") + path = line[start:].strip() + filename = path.split("/")[-1] + assert filename.rpartition(".so")[0].startswith(lib_name), \ + f"Unexpected filename: {filename} for library {lib_name}" + return path + + +class CudaRTLibrary: + exported_functions = [ + # ​cudaError_t cudaSetDevice ( int device ) + Function("cudaSetDevice", cudaError_t, [ctypes.c_int]), + # cudaError_t cudaDeviceSynchronize ( void ) + Function("cudaDeviceSynchronize", cudaError_t, []), + # ​cudaError_t cudaDeviceReset ( void ) + Function("cudaDeviceReset", cudaError_t, []), + + # const char* cudaGetErrorString ( cudaError_t error ) + Function("cudaGetErrorString", ctypes.c_char_p, [cudaError_t]), + + # ​cudaError_t cudaMalloc ( void** devPtr, size_t size ) + Function("cudaMalloc", cudaError_t, + [ctypes.POINTER(ctypes.c_void_p), ctypes.c_size_t]), + # ​cudaError_t cudaFree ( void* devPtr ) + Function("cudaFree", cudaError_t, [ctypes.c_void_p]), + # ​cudaError_t cudaMemset ( void* devPtr, int value, size_t count ) + Function("cudaMemset", cudaError_t, + [ctypes.c_void_p, ctypes.c_int, ctypes.c_size_t]), + # ​cudaError_t cudaMemcpy ( void* dst, const void* src, size_t count, cudaMemcpyKind kind ) # noqa + Function("cudaMemcpy", cudaError_t, [ + ctypes.c_void_p, ctypes.c_void_p, ctypes.c_size_t, cudaMemcpyKind + ]), + + # cudaError_t cudaIpcGetMemHandle ( cudaIpcMemHandle_t* handle, void* devPtr ) # noqa + Function("cudaIpcGetMemHandle", cudaError_t, + [ctypes.POINTER(cudaIpcMemHandle_t), ctypes.c_void_p]), + # ​cudaError_t cudaIpcOpenMemHandle ( void** devPtr, cudaIpcMemHandle_t handle, unsigned int flags ) # noqa + Function("cudaIpcOpenMemHandle", cudaError_t, [ + ctypes.POINTER(ctypes.c_void_p), cudaIpcMemHandle_t, ctypes.c_uint + ]), + ] + + # class attribute to store the mapping from the path to the library + # to avoid loading the same library multiple times + path_to_library_cache: Dict[str, Any] = {} + + # class attribute to store the mapping from library path + # to the corresponding dictionary + path_to_dict_mapping: Dict[str, Dict[str, Any]] = {} + + def __init__(self, so_file: Optional[str] = None): + if so_file is None: + so_file = find_loaded_library("libcudart") + assert so_file is not None, \ + "libcudart is not loaded in the current process" + if so_file not in CudaRTLibrary.path_to_library_cache: + lib = ctypes.CDLL(so_file) + CudaRTLibrary.path_to_library_cache[so_file] = lib + self.lib = CudaRTLibrary.path_to_library_cache[so_file] + + if so_file not in CudaRTLibrary.path_to_dict_mapping: + _funcs = {} + for func in CudaRTLibrary.exported_functions: + f = getattr(self.lib, func.name) + f.restype = func.restype + f.argtypes = func.argtypes + _funcs[func.name] = f + CudaRTLibrary.path_to_dict_mapping[so_file] = _funcs + self.funcs = CudaRTLibrary.path_to_dict_mapping[so_file] + + def CUDART_CHECK(self, result: cudaError_t) -> None: + if result != 0: + error_str = self.cudaGetErrorString(result) + raise RuntimeError(f"CUDART error: {error_str}") + + def cudaGetErrorString(self, error: cudaError_t) -> str: + return self.funcs["cudaGetErrorString"](error).decode("utf-8") + + def cudaSetDevice(self, device: int) -> None: + self.CUDART_CHECK(self.funcs["cudaSetDevice"](device)) + + def cudaDeviceSynchronize(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceSynchronize"]()) + + def cudaDeviceReset(self) -> None: + self.CUDART_CHECK(self.funcs["cudaDeviceReset"]()) + + def cudaMalloc(self, size: int) -> ctypes.c_void_p: + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaMalloc"](ctypes.byref(devPtr), size)) + return devPtr + + def cudaFree(self, devPtr: ctypes.c_void_p) -> None: + self.CUDART_CHECK(self.funcs["cudaFree"](devPtr)) + + def cudaMemset(self, devPtr: ctypes.c_void_p, value: int, + count: int) -> None: + self.CUDART_CHECK(self.funcs["cudaMemset"](devPtr, value, count)) + + def cudaMemcpy(self, dst: ctypes.c_void_p, src: ctypes.c_void_p, + count: int) -> None: + cudaMemcpyDefault = 4 + kind = cudaMemcpyDefault + self.CUDART_CHECK(self.funcs["cudaMemcpy"](dst, src, count, kind)) + + def cudaIpcGetMemHandle(self, + devPtr: ctypes.c_void_p) -> cudaIpcMemHandle_t: + handle = cudaIpcMemHandle_t() + self.CUDART_CHECK(self.funcs["cudaIpcGetMemHandle"]( + ctypes.byref(handle), devPtr)) + return handle + + def cudaIpcOpenMemHandle(self, + handle: cudaIpcMemHandle_t) -> ctypes.c_void_p: + cudaIpcMemLazyEnablePeerAccess = 1 + devPtr = ctypes.c_void_p() + self.CUDART_CHECK(self.funcs["cudaIpcOpenMemHandle"]( + ctypes.byref(devPtr), handle, cudaIpcMemLazyEnablePeerAccess)) + return devPtr diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce.py new file mode 100644 index 000000000..90e70aac0 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce.py @@ -0,0 +1,318 @@ +from contextlib import contextmanager +from typing import Any, List, Optional, Union + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup + +# import vllm.envs as envs +# from vllm import _custom_ops as ops +import rocmKernels as ops +import os +from .custom_all_reduce_utils import ( + gpu_p2p_access_check) +from .parallel_state import in_the_same_node_as +from llm_perf.utils.logger import logger +# from vllm.logger import init_logger +# from vllm.platforms import current_platform + +try: + ops.meta_size() + custom_ar = True +except Exception: + # For CPUs + custom_ar = False + +# logger = init_logger(__name__) + + +def _can_p2p(rank: int, world_size: int) -> bool: + for i in range(world_size): + if i == rank: + continue + if not gpu_p2p_access_check(rank, i): + return False + return True + + +def is_weak_contiguous(inp: torch.Tensor): + return inp.is_contiguous() or (inp.storage().nbytes() - + inp.storage_offset() * inp.element_size() + == inp.numel() * inp.element_size()) + + +class CustomAllreduce: + + _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8] + + # max_size: max supported allreduce size + def __init__(self, + group: ProcessGroup, + device: Union[int, str, torch.device], + max_size=8192 * 1024 * 2) -> None: + """ + Args: + group: the process group to work on. If None, it will use the + default process group. + device: the device to bind the CustomAllreduce to. If None, + it will be bind to f"cuda:{local_rank}". + It is the caller's responsibility to make sure each communicator + is bind to a unique device, and all communicators in this group + are in the same node. + """ + self._IS_CAPTURING = False + self.disabled = True + + if not custom_ar: + # disable because of missing custom allreduce library + # e.g. in a non-cuda environment + return + + self.group = group + + assert dist.get_backend(group) != dist.Backend.NCCL, ( + "CustomAllreduce should be attached to a non-NCCL group.") + + if not all(in_the_same_node_as(group, source_rank=0)): + # No need to initialize custom allreduce for multi-node case. + logger.warning( + "Custom allreduce is disabled because this process group" + " spans across nodes.") + return + + rank = dist.get_rank(group=self.group) + world_size = dist.get_world_size(group=self.group) + if world_size == 1: + # No need to initialize custom allreduce for single GPU case. + return + + if world_size not in CustomAllreduce._SUPPORTED_WORLD_SIZES: + logger.warning( + "Custom allreduce is disabled due to an unsupported world" + " size: %d. Supported world sizes: %s. To silence this " + "warning, specify disable_custom_all_reduce=True explicitly.", + world_size, str(CustomAllreduce._SUPPORTED_WORLD_SIZES)) + return + + if isinstance(device, int): + device = torch.device(f"cuda:{device}") + elif isinstance(device, str): + device = torch.device(device) + # now `device` is a `torch.device` object + assert isinstance(device, torch.device) + self.device = device + + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "-1") + # cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices: + device_ids = list(map(int, cuda_visible_devices.split(","))) + else: + from vllm.utils import cuda_device_count_stateless + device_ids = list(range(cuda_device_count_stateless())) + + physical_device_id = device_ids[device.index] + tensor = torch.tensor([physical_device_id], + dtype=torch.int, + device="cpu") + gather_list = [ + torch.tensor([0], dtype=torch.int, device="cpu") + for _ in range(world_size) + ] + dist.all_gather(gather_list, tensor, group=self.group) + physical_device_ids = [t.item() for t in gather_list] + + # test nvlink first, this will filter out most of the cases + # where custom allreduce is not supported + # this checks hardware and driver support for NVLink + # assert current_platform.is_cuda() or current_platform.is_rocm() + # full_nvlink = current_platform.is_full_nvlink(physical_device_ids) + full_nvlink=True + if world_size > 2 and not full_nvlink: + logger.warning( + "Custom allreduce is disabled because it's not supported on" + " more than two PCIe-only GPUs. To silence this warning, " + "specify disable_custom_all_reduce=True explicitly.") + return + # test P2P capability, this checks software/cudaruntime support + # this is expensive to compute at the first time + # then we cache the result + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + # if not current_platform.is_rocm() and not _can_p2p(rank, world_size): + # logger.warning( + # "Custom allreduce is disabled because your platform lacks " + # "GPU P2P capability or P2P test failed. To silence this " + # "warning, specify disable_custom_all_reduce=True explicitly.") + # return + + self.disabled = False + # buffers memory are owned by this Python class and passed to C++ + # meta data composes of two parts: meta data for synchronization + # (256 bytes) and a temporary buffer for storing intermediate + # allreduce results. + # if current_platform.is_rocm(): + if 1: + # meta data buffers need to be "uncached" for signal on MI200 + self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size) + else: + self.meta = torch.zeros(ops.meta_size() + max_size, + dtype=torch.uint8, + device=self.device) + # This is a pre-registered IPC buffer. In eager mode, input tensors + # are first copied into this buffer before allreduce is performed + self.buffer = torch.empty(max_size, + dtype=torch.uint8, + device=self.device) + # This is a buffer for storing the tuples of pointers pointing to + # IPC buffers from all ranks. Each registered tuple has size of + # 8*world_size bytes where world_size is at most 8. Allocating 8MB + # is enough for 131072 such tuples. The largest model I've seen only + # needs less than 10000 of registered tuples. + self.rank_data = torch.empty(8 * 1024 * 1024, + dtype=torch.uint8, + device=self.device) + self.max_size = max_size + self.rank = rank + self.world_size = world_size + # if current_platform.is_rocm(): + if 1: + # _share_cuda_() doesn't accept meta buffer not allocated from + # PyTorch cache allocator, use direct HIP call to get IPC handle + handle = ops.get_meta_buffer_ipc_handle(self.meta) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + handles, offsets = self._gather_ipc_meta(shard_data) + else: + handles, offsets = self._get_ipc_meta(self.meta) + self.full_nvlink = full_nvlink + self._ptr = ops.init_custom_ar(self.meta, self.rank_data, handles, + offsets, rank, self.full_nvlink) + self.register_buffer(self.buffer) + + @contextmanager + def capture(self): + """ + The main responsibility of this context manager is the + `register_graph_buffers` call at the end of the context. + It records all the buffer addresses used in the CUDA graph. + """ + try: + self._IS_CAPTURING = True + yield + finally: + self._IS_CAPTURING = False + if not self.disabled: + self.register_graph_buffers() + + def _get_ipc_meta(self, inp: torch.Tensor): + # if current_platform.is_rocm(): + if 1: + # _share_cuda_() doesn't accept meta buffer not allocated from + # PyTorch cache allocator, use direct HIP call to get IPC handle + handle = ops.get_meta_buffer_ipc_handle(inp) + shard_data = ( + bytes(handle), # ipc handle to base ptr + 0, # offset of base ptr + ) + else: + data = inp.untyped_storage()._share_cuda_() + shard_data = ( + data[1], # ipc handle to base ptr + data[3], # offset of base ptr + ) + return self._gather_ipc_meta(shard_data) + + def _gather_ipc_meta(self, shard_data): + # Note: don't use `[[None]] * self.world_size` here + # because it will create a list of the same reference + all_data: List[Optional[Any]] = [[None] + for i in range(self.world_size)] + all_data[self.rank][0] = shard_data + + ranks = dist.get_process_group_ranks(group=self.group) + ranks.sort() + for i, rank in enumerate(ranks): + dist.broadcast_object_list(all_data[i], + src=rank, + group=self.group, + device="cpu") + + # we cannot directly use `dist.all_gather_object` here + # because it is incompatible with `gloo` backend under inference mode. + # see https://github.com/pytorch/pytorch/issues/126032 for details. + + handles = [] + offsets = [] + for i in range(len(all_data)): + handles.append(all_data[i][0][0]) # type: ignore + offsets.append(all_data[i][0][1]) # type: ignore + return handles, offsets + + def register_buffer(self, inp: torch.Tensor): + handles, offsets = self._get_ipc_meta(inp) + ops.register_buffer(self._ptr, inp, handles, offsets) + + def register_graph_buffers(self): + handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr) + handles, offsets = self._gather_ipc_meta((bytes(handle), offset)) + logger.info("Registering %d cuda graph addresses", len(offset)) + ops.register_graph_buffers(self._ptr, handles, offsets) + + def should_custom_ar(self, inp: torch.Tensor): + if self.disabled: + return False + inp_size = inp.numel() * inp.element_size() + # custom allreduce requires input byte size to be multiples of 16 + if inp_size % 16 != 0: + return False + if not is_weak_contiguous(inp): + return False + # for 4 or more non NVLink-capable GPUs, custom allreduce provides + # little performance improvement over NCCL. + if self.world_size == 2 or self.full_nvlink: + return inp_size < self.max_size + return False + + # all reduce, assuming inp tensor is IPC registered with register_buffer, + # or, in the context of cuda graphs, register_graph_buffers + def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_reg(self._ptr, inp, out) + return out + + # all reduce, assuming inp tensor is NOT IPC registered + def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None): + if out is None: + out = torch.empty_like(inp) + ops.all_reduce_unreg(self._ptr, inp, self.buffer, out) + return out + + def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]: + # when custom allreduce is disabled, this will be None + if self.disabled or not self.should_custom_ar(input): + return None + if self._IS_CAPTURING: + if torch.cuda.is_current_stream_capturing(): + return self.all_reduce_reg(input) + else: + # if warm up, mimic the allocation pattern + # since custom allreduce is out-of-place + return torch.empty_like(input) + else: + # note: outside of cuda graph context, + # custom allreduce incurs a cost of cudaMemcpy, which should + # be small(<=1% of overall latency) compared to the performance + # gains of using custom kernels + return self.all_reduce_unreg(input) + + return None + + def close(self): + if not self.disabled and self._ptr: + ops.dispose(self._ptr) + self._ptr = 0 + + def __del__(self): + self.close() diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce_utils.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce_utils.py new file mode 100644 index 000000000..80de95138 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/custom_all_reduce_utils.py @@ -0,0 +1,260 @@ +import ctypes +import json +import os +import pickle +import subprocess +import sys +import tempfile +from itertools import product +from typing import Dict, List, Optional, Sequence + +import torch.distributed as dist +import torch.multiprocessing as mp + +# import vllm.envs as envs +VLLM_CACHE_ROOT=os.path.expanduser("~/.cache/vllm") +from .cuda_wrapper import CudaRTLibrary +# from vllm.logger import init_logger +from llm_perf.utils.logger import logger +from .utils import (cuda_device_count_stateless, + update_environment_variables) +from functools import lru_cache, partial, wraps + +# logger = init_logger(__name__) + + +def producer(batch_src: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for i in batch_src: + lib.cudaSetDevice(i) + pointer = lib.cudaMalloc(1024) + lib.cudaMemset(pointer, 1, 1024) + lib.cudaDeviceSynchronize() + handle = lib.cudaIpcGetMemHandle(pointer) + producer_queue.put(handle) + open_success = consumer_queue.get() + if open_success: + # use two queues to simulate barrier + producer_queue.put(0) + consumer_queue.get() + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def consumer(batch_tgt: Sequence[int], + producer_queue, + consumer_queue, + result_queue, + cuda_visible_devices: Optional[str] = None): + if cuda_visible_devices is not None: + update_environment_variables( + {"CUDA_VISIBLE_DEVICES": cuda_visible_devices}) + + lib = CudaRTLibrary() + for j in batch_tgt: + lib.cudaSetDevice(j) + handle = producer_queue.get() + open_success = False + try: + pointer = lib.cudaIpcOpenMemHandle(handle) # type: ignore + open_success = True + except RuntimeError: + # cannot error out here, because the producer process + # is still waiting for the response. + pass + consumer_queue.put(open_success) + if open_success: + # modify the memory + lib.cudaMemset(pointer, 2, 1024) + lib.cudaDeviceSynchronize() + # use two queues to simulate barrier + producer_queue.get() + consumer_queue.put(0) + # check if the memory is modified + host_data = (ctypes.c_char * 1024)() + lib.cudaMemcpy(host_data, pointer, 1024) # type: ignore + for i in range(1024): + if ord(host_data[i]) != 2: + open_success = False + break + result_queue.put(open_success) + lib.cudaDeviceReset() + + +def can_actually_p2p( + batch_src: Sequence[int], + batch_tgt: Sequence[int], +) -> Sequence[bool]: + """ + Usually, checking if P2P access is enabled can be done by + `torch.cuda.can_device_access_peer(src, tgt)`. However, sometimes + the driver might be broken, and `torch.cuda.can_device_access_peer(src, tgt)` + returns `True` even if P2P access is not actually possible. + See https://github.com/vllm-project/vllm/issues/2728 and + https://forums.developer.nvidia.com/t/direct-gpu-gpu-communication-does-not-seem-to-work-properly/283264/10 + Therefore, we have to perform a real P2P access to check if it is actually + possible. + + Note on p2p and cuda IPC: + Usually, one process uses one GPU: + GPU src --> cuda context src --> tensor src --> process src + + We need to combine p2p and cuda IPC, so that: + GPU src --> cuda context src --> tensor src --> process src + |shared| + GPU tgt --> cuda context tgt --> tensor tgt --> process tgt + That is to say, process src creates a tensor in GPU src, passes IPC handle to + process tgt, and process tgt accesses the tensor in GPU tgt. Any operation on the + tensor in process tgt will be reflected in the tensor in process src, because + they are the same memory segment. + It is important to note that process tgt accesses the tensor in GPU tgt, not + GPU src. That's why we need p2p access. + + The most time-consuming part is the process creation. To avoid creating + processes for every pair of GPUs, we use batched testing. We create two + processes for testing all pairs of GPUs in batch. The trick is to reset + the device after each test (which is not available in PyTorch). + """ # noqa + cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", "-1") + # cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + # pass the CUDA_VISIBLE_DEVICES to the child process + # to make sure they see the same set of GPUs + + # make sure the processes are spawned + smp = mp.get_context("spawn") + producer_queue = smp.Queue() + consumer_queue = smp.Queue() + result_queue = smp.Queue() + p_src = smp.Process(target=producer, + args=(batch_src, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_tgt = smp.Process(target=consumer, + args=(batch_tgt, producer_queue, consumer_queue, + result_queue, cuda_visible_devices)) + p_src.start() + p_tgt.start() + p_src.join() + p_tgt.join() + assert p_src.exitcode == 0 and p_tgt.exitcode == 0 + result: List[bool] = [] + for src, tgt in zip(batch_src, batch_tgt): + a = result_queue.get() + b = result_queue.get() + if a != b: + logger.warning( + "Two processes do not agree on the P2P access" + " status on %d -> %d, treat as disabled.", src, tgt) + result.append(False) + else: + result.append(a) + return result + + +# why do we need this cache? +# we are testing peer-to-peer (p2p) access between GPUs,across processes. +# if we test it every time, it will be very slow, because we need to create +# N * N * 2 processes, where N is the world size. This is very slow. +# to reduce the time, we use a cache file to store the p2p access status. +# the cache file is generated by the master process if it does not exist. +# then all the processes can read the cache file to check the p2p access status. +# Note that the cache file is suffixed by the CUDA_VISIBLE_DEVICES, so that we +# can have different cache files for different CUDA_VISIBLE_DEVICES settings, +# e.g. used by different vllm engines. The device id in the cache file is a +# **local** device id, i.e. from 0 to num_dev-1, where num_dev is the number +# of visible devices in the vllm engine. +_gpu_p2p_access_cache: Optional[Dict[str, bool]] = None + + +def gpu_p2p_access_check(src: int, tgt: int) -> bool: + """Check if GPU src can access GPU tgt.""" + + # if the cache variable is already calculated, + # read from the cache instead of checking it again + global _gpu_p2p_access_cache + if _gpu_p2p_access_cache is not None: + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + is_distributed = dist.is_initialized() + + num_dev = cuda_device_count_stateless() + cuda_visible_devices = int(os.environ.get("CUDA_VISIBLE_DEVICES", "-1")) + # cuda_visible_devices = envs.CUDA_VISIBLE_DEVICES + if cuda_visible_devices is None: + cuda_visible_devices = ",".join(str(i) for i in range(num_dev)) + + path = os.path.join( + VLLM_CACHE_ROOT, + f"gpu_p2p_access_cache_for_{cuda_visible_devices}.json") + os.makedirs(os.path.dirname(path), exist_ok=True) + from vllm.distributed.parallel_state import get_world_group + if ((not is_distributed or get_world_group().local_rank == 0) + and (not os.path.exists(path))): + # only the local master process (with local_rank == 0) can + # enter this block to calculate the cache + logger.info("generating GPU P2P access cache in %s", path) + cache: Dict[str, bool] = {} + ids = list(range(num_dev)) + # batch of all pairs of GPUs + batch_src, batch_tgt = zip(*list(product(ids, ids))) + # NOTE: we use `subprocess` rather than `multiprocessing` here + # because the caller might not have `if __name__ == "__main__":`, + # in that case we cannot use spawn method in multiprocessing. + # However, `can_actually_p2p` requires spawn method. + # The fix is, we use `subprocess` to call the function, + # where we have `if __name__ == "__main__":` in this file. + + # use a temporary file to store the result + # we don't use the output of the subprocess directly, + # because the subprocess might produce logging output + with tempfile.NamedTemporaryFile() as output_file: + input_bytes = pickle.dumps( + (batch_src, batch_tgt, output_file.name)) + returned = subprocess.run([sys.executable, __file__], + input=input_bytes, + capture_output=True) + # check if the subprocess is successful + try: + returned.check_returncode() + except Exception as e: + # wrap raised exception to provide more information + raise RuntimeError( + f"Error happened when batch testing " + f"peer-to-peer access from {batch_src} to {batch_tgt}:\n" + f"{returned.stderr.decode()}") from e + with open(output_file.name, "rb") as f: + result = pickle.load(f) + for _i, _j, r in zip(batch_src, batch_tgt, result): + cache[f"{_i}->{_j}"] = r + with open(path, "w") as f: + json.dump(cache, f, indent=4) + if is_distributed: + get_world_group().barrier() + logger.info("reading GPU P2P access cache from %s", path) + with open(path, "r") as f: + cache = json.load(f) + _gpu_p2p_access_cache = cache + return _gpu_p2p_access_cache[f"{src}->{tgt}"] + + +__all__ = ["gpu_p2p_access_check"] + +if __name__ == "__main__": + batch_src, batch_tgt, output_file = pickle.loads(sys.stdin.buffer.read()) + result = can_actually_p2p(batch_src, batch_tgt) + with open(output_file, "wb") as f: + f.write(pickle.dumps(result)) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/parallel_state.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/parallel_state.py new file mode 100644 index 000000000..76b55e3f5 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/parallel_state.py @@ -0,0 +1,1203 @@ +# Copyright 2023 The vLLM team. +# Adapted from +# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +"""vLLM distributed state. +It takes over the control of the distributed environment from PyTorch. +The typical workflow is: + +- call `init_distributed_environment` to initialize the distributed environment. +- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to + initialize the model parallel groups. + +- any code dealing with the distributed stuff + +- call `destroy_model_parallel` to destroy the model parallel groups. +- call `destroy_distributed_environment` to destroy the distributed environment. + +If you only need to use the distributed environment without model/pipeline + parallelism, you can skip the model parallel initialization and destruction + steps. +""" +import contextlib +import pickle +import weakref +from collections import namedtuple +from contextlib import contextmanager, nullcontext +from dataclasses import dataclass +from multiprocessing import shared_memory +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +from unittest.mock import patch + +import torch +import torch.distributed +from torch.distributed import Backend, ProcessGroup + +import os +# import vllm.envs as envs +# from vllm.logger import init_logger +# import logging +# init_logger = logging.getLogger +from llm_perf.utils.logger import logger +# from vllm.platforms import current_platform +# from vllm.utils import supports_custom_op +supports_custom_op=lambda:True + + +@dataclass +class GraphCaptureContext: + stream: torch.cuda.Stream + + +TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"]) + + +def _split_tensor_dict( + tensor_dict: Dict[str, Union[torch.Tensor, Any]] +) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]: + """Split the tensor dictionary into two parts: + 1. A list of (key, value) pairs. If the value is a tensor, it is replaced + by its metadata. + 2. A list of tensors. + """ + metadata_list: List[Tuple[str, Any]] = [] + tensor_list: List[torch.Tensor] = [] + for key, value in tensor_dict.items(): + if isinstance(value, torch.Tensor): + # Note: we cannot use `value.device` here, + # because it contains not only the device type but also the device + # index (e.g. "cuda:0"). We only need the device type. + # receiving side will set the device index. + device = value.device.type + metadata_list.append( + (key, TensorMetadata(device, value.dtype, value.size()))) + tensor_list.append(value) + else: + metadata_list.append((key, value)) + return metadata_list, tensor_list + + +_group_name_counter: Dict[str, int] = {} + + +def _get_unique_name(name: str) -> str: + """Get a unique name for the group. + Example: + _get_unique_name("tp") -> "tp:0" + _get_unique_name("tp") -> "tp:1" + """ + if name not in _group_name_counter: + _group_name_counter[name] = 0 + newname = f"{name}:{_group_name_counter[name]}" + _group_name_counter[name] += 1 + return newname + + +_groups: Dict[str, Callable[[], "GroupCoordinator"]] = {} + + +def _register_group(group: "GroupCoordinator") -> None: + # looks like Python 3.8 does not understand `ReferenceType` + _groups[group.unique_name] = weakref.ref(group) # type: ignore + + +if supports_custom_op(): + # @torch.library.custom_op("vllm::inplace_all_reduce", + # mutates_args=["tensor"]) + def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + group._all_reduce_in_place(tensor) + + # @inplace_all_reduce.register_fake + # def _(tensor: torch.Tensor, group_name: str) -> None: + # return + + # @torch.library.custom_op("vllm::outplace_all_reduce", mutates_args=[]) + def outplace_all_reduce(tensor: torch.Tensor, + group_name: str) -> torch.Tensor: + assert group_name in _groups, f"Group {group_name} is not found." + group = _groups[group_name]() + if group is None: + raise ValueError(f"Group {group_name} is destroyed.") + return group._all_reduce_out_place(tensor) + + # @outplace_all_reduce.register_fake + # def _(tensor: torch.Tensor, group_name: str) -> torch.Tensor: + # return torch.empty_like(tensor) + + +class GroupCoordinator: + """ + PyTorch ProcessGroup wrapper for a group of processes. + PyTorch ProcessGroup is bound to one specific communication backend, + e.g. NCCL, Gloo, MPI, etc. + GroupCoordinator takes charge of all the communication operations among + the processes in the group. It can route the communication to + a specific implementation (e.g. switch allreduce implementation + based on the tensor size and cuda graph mode). + """ + + # available attributes: + rank: int # global rank + ranks: List[int] # global ranks in the group + world_size: int # size of the group + # difference between `local_rank` and `rank_in_group`: + # if we have a group of size 4 across two nodes: + # Process | Node | Rank | Local Rank | Rank in Group + # 0 | 0 | 0 | 0 | 0 + # 1 | 0 | 1 | 1 | 1 + # 2 | 1 | 2 | 0 | 2 + # 3 | 1 | 3 | 1 | 3 + local_rank: int # local rank used to assign devices + rank_in_group: int # rank inside the group + cpu_group: ProcessGroup # group for CPU communication + device_group: ProcessGroup # group for device communication + use_pynccl: bool # a hint of whether to use PyNccl + use_custom_allreduce: bool # a hint of whether to use CustomAllreduce + # communicators are only created for world size > 1 + pynccl_comm: Optional[Any] # PyNccl communicator + ca_comm: Optional[Any] # Custom allreduce communicator + mq_broadcaster: Optional[Any] # shared memory broadcaster + + def __init__( + self, + group_ranks: List[List[int]], + local_rank: int, + torch_distributed_backend: Union[str, Backend], + use_pynccl: bool, + use_custom_allreduce: bool, + use_tpu_communicator: bool, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, + ): + group_name = group_name or "anonymous" + self.unique_name = _get_unique_name(group_name) + _register_group(self) + + self.rank = torch.distributed.get_rank() + self.local_rank = local_rank + self.device_group = None + self.cpu_group = None + + for ranks in group_ranks: + device_group = torch.distributed.new_group( + ranks, backend=torch_distributed_backend) + # a group with `gloo` backend, to allow direct coordination between + # processes through the CPU. + cpu_group = torch.distributed.new_group(ranks, backend="gloo") + if self.rank in ranks: + self.ranks = ranks + self.world_size = len(ranks) + self.rank_in_group = ranks.index(self.rank) + self.device_group = device_group + self.cpu_group = cpu_group + + assert self.cpu_group is not None + assert self.device_group is not None + + self.device = torch.device(f"cuda:{local_rank}") + # if current_platform.is_cuda_alike(): + # self.device = torch.device(f"cuda:{local_rank}") + # else: + # self.device = torch.device("cpu") + + self.use_pynccl = use_pynccl + self.use_custom_allreduce = use_custom_allreduce + self.use_tpu_communicator = use_tpu_communicator + + # lazy import to avoid documentation build error + from .custom_all_reduce import ( + CustomAllreduce) + # from vllm.distributed.device_communicators.pynccl import ( + # PyNcclCommunicator) + + self.pynccl_comm = None + if use_pynccl and self.world_size > 1: + self.pynccl_comm = PyNcclCommunicator( + group=self.cpu_group, + device=self.device, + ) + + self.ca_comm: Optional[CustomAllreduce] = None + if use_custom_allreduce and self.world_size > 1: + # Initialize a custom fast all-reduce implementation. + self.ca_comm = CustomAllreduce( + group=self.cpu_group, + device=self.device, + ) + + # from vllm.distributed.device_communicators.tpu_communicator import ( + # TpuCommunicator) + self.tpu_communicator = None + if use_tpu_communicator and self.world_size > 1: + self.tpu_communicator = TpuCommunicator(group=self.cpu_group) + + from .shm_broadcast import (MessageQueue) + self.mq_broadcaster = None + if use_message_queue_broadcaster and self.world_size > 1: + self.mq_broadcaster = MessageQueue.create_from_process_group( + self.cpu_group, 1 << 22, 6) + + @property + def first_rank(self): + """Return the global rank of the first process in the group""" + return self.ranks[0] + + @property + def last_rank(self): + """Return the global rank of the last process in the group""" + return self.ranks[-1] + + @property + def is_first_rank(self): + """Return whether the caller is the first process in the group""" + return self.rank == self.first_rank + + @property + def is_last_rank(self): + """Return whether the caller is the last process in the group""" + return self.rank == self.last_rank + + @property + def next_rank(self): + """Return the global rank of the process that follows the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group + 1) % world_size] + + @property + def prev_rank(self): + """Return the global rank of the process that precedes the caller""" + rank_in_group = self.rank_in_group + world_size = self.world_size + return self.ranks[(rank_in_group - 1) % world_size] + + @contextmanager + def graph_capture( + self, graph_capture_context: Optional[GraphCaptureContext] = None): + if graph_capture_context is None: + stream = torch.cuda.Stream() + graph_capture_context = GraphCaptureContext(stream) + else: + stream = graph_capture_context.stream + + ca_comm = self.ca_comm + maybe_ca_context = nullcontext( + ) if ca_comm is None else ca_comm.capture() + + # ensure all initialization operations complete before attempting to + # capture the graph on another stream + curr_stream = torch.cuda.current_stream() + if curr_stream != stream: + stream.wait_stream(curr_stream) + + with torch.cuda.stream(stream), maybe_ca_context: + # In graph mode, we have to be very careful about the collective + # operations. The current status is: + # allreduce \ Mode | Eager | Graph | + # -------------------------------------------- + # custom allreduce | enabled | enabled | + # PyNccl | disabled| enabled | + # torch.distributed | enabled | disabled| + # + # Note that custom allreduce will have a runtime check, if the + # tensor size is too large, it will fallback to the next + # available option. + # In summary: When using CUDA graph, we use + # either custom all-reduce kernel or pynccl. When not using + # CUDA graph, we use either custom all-reduce kernel or + # PyTorch NCCL. We always prioritize using custom all-reduce + # kernel but fall back to PyTorch or pynccl if it is + # disabled or not supported. + pynccl_comm = self.pynccl_comm + maybe_pynccl_context: Any + if not pynccl_comm: + maybe_pynccl_context = nullcontext() + else: + maybe_pynccl_context = pynccl_comm.change_state( + enable=True, stream=torch.cuda.current_stream()) + with maybe_pynccl_context: + yield graph_capture_context + + def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: + """ + User-facing all-reduce function before we actually call the + all-reduce operation. + + We need this because Dynamo does not support passing an arbitrary + object (`self` in this case) to a custom op. We need to pass the + group name as a string, and then look up the group coordinator from + the group name, dispatch the all-reduce operation to the group + coordinator. + + In addition, PyTorch custom ops do not support mutation or returning + a new tensor in the same op. So we need to figure out if the op is + in-place or out-of-place ahead of time. + """ + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + + if not supports_custom_op(): + self._all_reduce_in_place(input_) + return input_ + + if self.tpu_communicator is not None and \ + not self.tpu_communicator.disabled: + # TPU handles Dynamo with its own logic. + return self.tpu_communicator.all_reduce(input_) + + if self.ca_comm is not None and \ + not self.ca_comm.disabled and \ + self.ca_comm.should_custom_ar(input_): + return outplace_all_reduce( + input_, group_name=self.unique_name) + else: + inplace_all_reduce(input_, + group_name=self.unique_name) + return input_ + + def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor: + ca_comm = self.ca_comm + assert ca_comm is not None + assert not ca_comm.disabled + out = ca_comm.custom_all_reduce(input_) + assert out is not None + return out + + def _all_reduce_in_place(self, input_: torch.Tensor) -> None: + pynccl_comm = self.pynccl_comm + if (pynccl_comm is not None and not pynccl_comm.disabled): + pynccl_comm.all_reduce(input_) + elif input_.is_cpu: + import intel_extension_for_pytorch as ipex + ipex.distributed.all_reduce(input_, group=self.device_group) + else: + torch.distributed.all_reduce(input_, group=self.device_group) + + def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + + # For TPUs, use TPU communicator. + tpu_comm = self.tpu_communicator + if tpu_comm is not None and not tpu_comm.disabled: + return tpu_comm.all_gather(input_, dim) + + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + input_size = input_.size() + # Allocate output tensor. + output_tensor = torch.empty((world_size, ) + input_size, + dtype=input_.dtype, + device=input_.device) + # All-gather. + torch.distributed.all_gather_into_tensor(output_tensor, + input_, + group=self.device_group) + # Reshape + output_tensor = output_tensor.movedim(0, dim) + output_tensor = output_tensor.reshape(input_size[:dim] + + (world_size * + input_size[dim], ) + + input_size[dim + 1:]) + return output_tensor + + def gather(self, + input_: torch.Tensor, + dst: int = 0, + dim: int = -1) -> Optional[torch.Tensor]: + """ + NOTE: We assume that the input tensor is on the same device across + all the ranks. + NOTE: `dst` is the local rank of the destination rank. + """ + world_size = self.world_size + # Bypass the function if we are using only 1 GPU. + if world_size == 1: + return input_ + assert -input_.dim() <= dim < input_.dim(), ( + f"Invalid dim ({dim}) for input tensor with shape {input_.size()}") + if dim < 0: + # Convert negative dim to positive. + dim += input_.dim() + # Allocate output tensor. + if self.rank_in_group == dst: + gather_list = [torch.empty_like(input_) for _ in range(world_size)] + else: + gather_list = None + # Gather. + torch.distributed.gather(input_, + gather_list, + dst=self.ranks[dst], + group=self.device_group) + if self.rank_in_group == dst: + output_tensor = torch.cat(gather_list, dim=dim) + else: + output_tensor = None + return output_tensor + + def broadcast(self, input_: torch.Tensor, src: int = 0): + """Broadcast the input tensor. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return input_ + # Broadcast. + torch.distributed.broadcast(input_, + src=self.ranks[src], + group=self.device_group) + return input_ + + def broadcast_object(self, obj: Optional[Any] = None, src: int = 0): + """Broadcast the input object. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj + if self.mq_broadcaster is not None: + assert src == 0, "Message queue broadcaster only supports src=0" + return self.mq_broadcaster.broadcast_object(obj) + if self.rank_in_group == src: + torch.distributed.broadcast_object_list([obj], + src=self.ranks[src], + group=self.cpu_group) + return obj + else: + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=self.ranks[src], + group=self.cpu_group) + return recv[0] + + def broadcast_object_list(self, + obj_list: List[Any], + src: int = 0, + group: Optional[ProcessGroup] = None): + """Broadcast the input object list. + NOTE: `src` is the local rank of the source rank. + """ + assert src < self.world_size, f"Invalid src rank ({src})" + + # Bypass the function if we are using only 1 GPU. + if self.world_size == 1: + return obj_list + # Broadcast. + torch.distributed.broadcast_object_list(obj_list, + src=self.ranks[src], + group=self.device_group) + return obj_list + + def send_object(self, obj: Any, dst: int) -> None: + """Send the input object list to the destination rank.""" + """NOTE: `dst` is the local rank of the destination rank.""" + + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + assert dst != self.rank_in_group, ( + "Invalid destination rank. Destination rank is the same " + "as the current rank.") + + # Serialize object to tensor and get the size as well + object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8) + + size_tensor = torch.tensor([object_tensor.numel()], + dtype=torch.long, + device="cpu") + + # Send object size + + torch.distributed.send(size_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + # Send object + torch.distributed.send(object_tensor, + dst=self.ranks[dst], + group=self.cpu_group) + + return None + + def recv_object(self, src: int) -> Any: + """Receive the input object list from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + + assert src < self.world_size, f"Invalid src rank ({src})" + + assert src != self.rank_in_group, ( + "Invalid source rank. Source rank is the same as the current rank." + ) + + size_tensor = torch.empty(1, dtype=torch.long, device="cpu") + + # Receive object size + rank_size = torch.distributed.recv(size_tensor, + src=self.ranks[src], + group=self.cpu_group) + + # Tensor to receive serialized objects into. + object_tensor = torch.empty( # type: ignore[call-overload] + size_tensor.item(), # type: ignore[arg-type] + dtype=torch.uint8, + device="cpu") + + rank_object = torch.distributed.recv(object_tensor, + src=self.ranks[src], + group=self.cpu_group) + + assert rank_object == rank_size, ( + "Received object sender rank does not match the size sender rank.") + + obj = pickle.loads(object_tensor.numpy().tobytes()) + + return obj + + def broadcast_tensor_dict( + self, + tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None, + src: int = 0, + group: Optional[ProcessGroup] = None, + metadata_group: Optional[ProcessGroup] = None + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Broadcast the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if (not torch.distributed.is_initialized() or self.world_size == 1): + return tensor_dict + + group = self.device_group + metadata_group = self.cpu_group + assert src < self.world_size, f"Invalid src rank ({src})" + + rank_in_group = self.rank_in_group + if rank_in_group == src: + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), (f"Expecting a dictionary, got {type(tensor_dict)}") + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `broadcast_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.broadcast_object(metadata_list, src=src) + async_handles = [] + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast(tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + for async_handle in async_handles: + async_handle.wait() + + else: + metadata_list = self.broadcast_object(None, src=src) + tensor_dict = {} + async_handles = [] + for key, value in metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + if tensor.is_cpu: + # use metadata_group for CPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=metadata_group, + async_op=True) + else: + # use group for GPU tensors + handle = torch.distributed.broadcast( + tensor, + src=self.ranks[src], + group=group, + async_op=True) + async_handles.append(handle) + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + for async_handle in async_handles: + async_handle.wait() + return tensor_dict + + def send_tensor_dict( + self, + tensor_dict: Dict[str, Union[torch.Tensor, Any]], + dst: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Send the input tensor dictionary. + NOTE: `dst` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return tensor_dict + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + assert dst < self.world_size, f"Invalid dst rank ({dst})" + + metadata_list: List[Tuple[Any, Any]] = [] + assert isinstance( + tensor_dict, + dict), f"Expecting a dictionary, got {type(tensor_dict)}" + metadata_list, tensor_list = _split_tensor_dict(tensor_dict) + # `metadata_list` lives in CPU memory. + # `send_object_list` has serialization & deserialization, + # all happening on CPU. Therefore, we can use the CPU group. + self.send_object(metadata_list, dst=dst) + for tensor in tensor_list: + if tensor.numel() == 0: + # Skip sending empty tensors. + continue + + # send-allgather: send only a slice, then do allgather. + if (all_gather_group is not None + and tensor.numel() % all_gather_size == 0): + tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.send(tensor, + dst=self.ranks[dst], + group=group) + return None + + def recv_tensor_dict( + self, + src: Optional[int] = None, + all_gather_group: Optional["GroupCoordinator"] = None, + ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]: + """Recv the input tensor dictionary. + NOTE: `src` is the local rank of the source rank. + """ + # Bypass the function if we are using only 1 GPU. + if not torch.distributed.is_initialized() or self.world_size == 1: + return None + + all_gather_size = (1 if all_gather_group is None else + all_gather_group.world_size) + all_gather_rank = (0 if all_gather_group is None else + all_gather_group.rank_in_group) + + group = self.device_group + metadata_group = self.cpu_group + + if src is None: + src = (self.rank_in_group - 1) % self.world_size + assert src < self.world_size, f"Invalid src rank ({src})" + + recv_metadata_list = self.recv_object(src=src) + tensor_dict: Dict[str, Any] = {} + for key, value in recv_metadata_list: + if isinstance(value, TensorMetadata): + tensor = torch.empty(value.size, + dtype=value.dtype, + device=value.device) + if tensor.numel() == 0: + # Skip broadcasting empty tensors. + tensor_dict[key] = tensor + continue + + # send-allgather: send only a slice, then do allgather. + use_all_gather = (all_gather_group is not None + and tensor.numel() % all_gather_size == 0) + + if use_all_gather: + orig_shape = tensor.shape + tensor = tensor.reshape(all_gather_size, + -1)[all_gather_rank] + + if tensor.is_cpu: + # use metadata_group for CPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=metadata_group) + else: + # use group for GPU tensors + torch.distributed.recv(tensor, + src=self.ranks[src], + group=group) + if use_all_gather: + # do the allgather + tensor = all_gather_group.all_gather( # type: ignore + tensor, dim=0) + tensor = tensor.reshape(orig_shape) + + tensor_dict[key] = tensor + else: + tensor_dict[key] = value + return tensor_dict + + def barrier(self): + """Barrier synchronization among the group. + NOTE: don't use `device_group` here! `barrier` in NCCL is + terrible because it is internally a broadcast operation with + secretly created GPU tensors. It is easy to mess up the current + device. Use the CPU group instead. + """ + torch.distributed.barrier(group=self.cpu_group) + + def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None: + """Sends a tensor to the destination rank in a non-blocking way""" + """NOTE: `dst` is the local rank of the destination rank.""" + if dst is None: + dst = (self.rank_in_group + 1) % self.world_size + + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.send(tensor, dst) + else: + torch.distributed.send(tensor, self.ranks[dst], self.device_group) + + def recv(self, + size: torch.Size, + dtype: torch.dtype, + src: Optional[int] = None) -> torch.Tensor: + """Receives a tensor from the source rank.""" + """NOTE: `src` is the local rank of the source rank.""" + if src is None: + src = (self.rank_in_group - 1) % self.world_size + + tensor = torch.empty(size, dtype=dtype, device=self.device) + pynccl_comm = self.pynccl_comm + if pynccl_comm is not None and not pynccl_comm.disabled: + pynccl_comm.recv(tensor, src) + else: + torch.distributed.recv(tensor, self.ranks[src], self.device_group) + return tensor + + def destroy(self): + if self.device_group is not None: + torch.distributed.destroy_process_group(self.device_group) + self.device_group = None + if self.cpu_group is not None: + torch.distributed.destroy_process_group(self.cpu_group) + self.cpu_group = None + if self.pynccl_comm is not None: + self.pynccl_comm = None + if self.ca_comm is not None: + self.ca_comm = None + if self.mq_broadcaster is not None: + self.mq_broadcaster = None + + +_WORLD: Optional[GroupCoordinator] = None + + +def get_world_group() -> GroupCoordinator: + assert _WORLD is not None, ("world group is not initialized") + return _WORLD + + +def init_world_group(ranks: List[int], local_rank: int, + backend: str) -> GroupCoordinator: + return GroupCoordinator( + group_ranks=[ranks], + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=False, + use_tpu_communicator=False, + group_name="world", + ) + + +def init_model_parallel_group( + group_ranks: List[List[int]], + local_rank: int, + backend: str, + use_custom_allreduce: Optional[bool] = None, + use_message_queue_broadcaster: bool = False, + group_name: Optional[str] = None, +) -> GroupCoordinator: + if use_custom_allreduce is None: + use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE + return GroupCoordinator( + group_ranks=group_ranks, + local_rank=local_rank, + torch_distributed_backend=backend, + use_pynccl=False, + use_custom_allreduce=use_custom_allreduce, + use_tpu_communicator=False, + use_message_queue_broadcaster=use_message_queue_broadcaster, + group_name=group_name, + ) + + +_TP: Optional[GroupCoordinator] = None + + +def get_tp_group() -> GroupCoordinator: + assert _TP is not None, ("tensor model parallel group is not initialized") + return _TP + + +# kept for backward compatibility +get_tensor_model_parallel_group = get_tp_group + +_PP: Optional[GroupCoordinator] = None + + +def get_pp_group() -> GroupCoordinator: + assert _PP is not None, ( + "pipeline model parallel group is not initialized") + return _PP + + +# kept for backward compatibility +get_pipeline_model_parallel_group = get_pp_group + + +@contextmanager +def graph_capture(): + """ + `graph_capture` is a context manager which should surround the code that + is capturing the CUDA graph. Its main purpose is to ensure that the + some operations will be run after the graph is captured, before the graph + is replayed. It returns a `GraphCaptureContext` object which contains the + necessary data for the graph capture. Currently, it only contains the + stream that the graph capture is running on. This stream is set to the + current CUDA stream when the context manager is entered and reset to the + default stream when the context manager is exited. This is to ensure that + the graph capture is running on a separate stream from the default stream, + in order to explicitly distinguish the kernels to capture + from other kernels possibly launched on background in the default stream. + """ + with get_tp_group().graph_capture() as context, get_pp_group( + ).graph_capture(context): + yield context + + +# logger = init_logger(__name__) + +_ENABLE_CUSTOM_ALL_REDUCE = True + + +def set_custom_all_reduce(enable: bool): + global _ENABLE_CUSTOM_ALL_REDUCE + _ENABLE_CUSTOM_ALL_REDUCE = enable + + +def init_distributed_environment( + world_size: int = -1, + rank: int = -1, + distributed_init_method: str = "env://", + local_rank: int = -1, + backend: str = "nccl", +): + logger.debug( + "world_size=%d rank=%d local_rank=%d " + "distributed_init_method=%s backend=%s", world_size, rank, local_rank, + distributed_init_method, backend) + if not torch.distributed.is_initialized(): + assert distributed_init_method is not None, ( + "distributed_init_method must be provided when initializing " + "distributed environment") + if "CUDA_VISIBLE_DEVICES" not in os.environ: + from .utils import update_environment_variables + update_environment_variables({ + "CUDA_VISIBLE_DEVICES": (",".join(map(str, range(world_size)))) + }) + # this backend is used for WORLD + torch.distributed.init_process_group( + backend=backend, + # init_method=distributed_init_method, + world_size=world_size, + rank=rank) + # set the local rank + # local_rank is not available in torch ProcessGroup, + # see https://github.com/pytorch/pytorch/issues/122816 + if local_rank == -1: + # local rank not set, this usually happens in single-node + # setting, where we can use rank as local rank + if distributed_init_method == "env://": + # local_rank = envs.LOCAL_RANK + local_rank = os.environ.get("LOCAL_RANK", "0") + else: + local_rank = rank + global _WORLD + if _WORLD is None: + ranks = list(range(torch.distributed.get_world_size())) + _WORLD = init_world_group(ranks, local_rank, backend) + else: + assert _WORLD.world_size == torch.distributed.get_world_size(), ( + "world group already initialized with a different world size") + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: Optional[str] = None, +) -> None: + """ + Initialize model parallel groups. + + Arguments: + tensor_model_parallel_size: number of GPUs used for tensor model + parallelism. + pipeline_model_parallel_size: number of GPUs used for pipeline model + parallelism. + + Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we + use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize + the model pipeline. The present function will + create 4 tensor model-parallel groups and 2 pipeline model-parallel groups: + 4 tensor model-parallel groups: + [g0, g1], [g2, g3], [g4, g5], [g6, g7] + 2 pipeline model-parallel groups: + [g0, g2, g4, g6], [g1, g3, g5, g7] + Note that for efficiency, the caller should make sure adjacent ranks + are on the same DGX box. For example if we are using 2 DGX-1 boxes + with a total of 16 GPUs, rank 0 to 7 belong to the first box and + ranks 8 to 15 belong to the second box. + """ + # Get world size and rank. Ensure some consistencies. + assert torch.distributed.is_initialized() + world_size: int = torch.distributed.get_world_size() + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + + if (world_size != + tensor_model_parallel_size * pipeline_model_parallel_size): + raise RuntimeError( + f"world_size ({world_size}) is not equal to " + f"tensor_model_parallel_size ({tensor_model_parallel_size}) x " + f"pipeline_model_parallel_size ({pipeline_model_parallel_size})") + + # Build the tensor model-parallel groups. + num_tensor_model_parallel_groups: int = (world_size // + tensor_model_parallel_size) + global _TP + assert _TP is None, ("tensor model parallel group is already initialized") + group_ranks = [] + for i in range(num_tensor_model_parallel_groups): + ranks = list( + range(i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size)) + group_ranks.append(ranks) + + # message queue broadcaster is only used in tensor model parallel group + _TP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_message_queue_broadcaster=True, + group_name="tp") + + # Build the pipeline model-parallel groups. + num_pipeline_model_parallel_groups: int = (world_size // + pipeline_model_parallel_size) + global _PP + assert _PP is None, ( + "pipeline model parallel group is already initialized") + group_ranks = [] + for i in range(num_pipeline_model_parallel_groups): + ranks = list(range(i, world_size, num_pipeline_model_parallel_groups)) + group_ranks.append(ranks) + # pipeline parallel does not need custom allreduce + _PP = init_model_parallel_group(group_ranks, + get_world_group().local_rank, + backend, + use_custom_allreduce=False, + group_name="pp") + + +def ensure_model_parallel_initialized( + tensor_model_parallel_size: int, + pipeline_model_parallel_size: int, + backend: Optional[str] = None, +) -> None: + """Helper to initialize model parallel groups if they are not initialized, + or ensure tensor-parallel and pipeline-parallel sizes are equal to expected + values if the model parallel groups are initialized. + """ + backend = backend or torch.distributed.get_backend( + get_world_group().device_group) + if not model_parallel_is_initialized(): + initialize_model_parallel(tensor_model_parallel_size, + pipeline_model_parallel_size, backend) + return + + assert ( + get_tensor_model_parallel_world_size() == tensor_model_parallel_size + ), ("tensor parallel group already initialized, but of unexpected size: " + f"{get_tensor_model_parallel_world_size()=} vs. " + f"{tensor_model_parallel_size=}") + pp_world_size = get_pp_group().world_size + assert (pp_world_size == pipeline_model_parallel_size), ( + "pipeline parallel group already initialized, but of unexpected size: " + f"{pp_world_size=} vs. " + f"{pipeline_model_parallel_size=}") + + +def model_parallel_is_initialized(): + """Check if tensor and pipeline parallel groups are initialized.""" + return (_TP is not None and _PP is not None) + + +_TP_STATE_PATCHED = False + + +@contextmanager +def patch_tensor_parallel_group(tp_group: GroupCoordinator): + """Patch the tp group temporarily until this function ends. + + This method is for draft workers of speculative decoding to run draft model + with different tp degree from that of target model workers. + + Args: + tp_group (GroupCoordinator): the tp group coordinator + """ + global _TP_STATE_PATCHED + assert not _TP_STATE_PATCHED, "Should not call when it's already patched" + + _TP_STATE_PATCHED = True + old_tp_group = get_tp_group() + global _TP + _TP = tp_group + try: + yield + finally: + # restore the original state + _TP_STATE_PATCHED = False + _TP = old_tp_group + + +def get_tensor_model_parallel_world_size(): + """Return world size for the tensor model parallel group.""" + return get_tp_group().world_size + + +def get_tensor_model_parallel_rank(): + """Return my rank for the tensor model parallel group.""" + return get_tp_group().rank_in_group + + +def destroy_model_parallel(): + """Set the groups to none and destroy them.""" + global _TP + if _TP: + _TP.destroy() + _TP = None + + global _PP + if _PP: + _PP.destroy() + _PP = None + + +def destroy_distributed_environment(): + global _WORLD + if _WORLD: + _WORLD.destroy() + _WORLD = None + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + +def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]: + """ + This is a collective operation that returns if each rank is in the same node + as the source rank. It tests if processes are attached to the same + memory system (shared access to shared memory). + """ + assert torch.distributed.get_backend( + pg) != torch.distributed.Backend.NCCL, ( + "in_the_same_node_as should be tested with a non-NCCL group.") + # local rank inside the group + rank = torch.distributed.get_rank(group=pg) + world_size = torch.distributed.get_world_size(group=pg) + + # local tensor in each process to store the result + is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32) + + # global ranks of the processes in the group + ranks = torch.distributed.get_process_group_ranks(pg) + + magic_message = b"magic_message" + shm = None + + try: + with contextlib.suppress(OSError): + if rank == source_rank: + # create a shared memory segment + shm = shared_memory.SharedMemory(create=True, size=128) + shm.buf[:len(magic_message)] = magic_message + torch.distributed.broadcast_object_list([shm.name], + src=ranks[source_rank], + group=pg) + is_in_the_same_node[rank] = 1 + else: + # try to open the shared memory segment + recv = [None] + torch.distributed.broadcast_object_list(recv, + src=ranks[source_rank], + group=pg) + name = recv[0] + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + shm = shared_memory.SharedMemory(name=name) + if shm.buf[:len(magic_message)] == magic_message: + is_in_the_same_node[rank] = 1 + except Exception as e: + logger.error("Error ignored in is_in_the_same_node: %s", e) + finally: + if shm: + shm.close() + + torch.distributed.barrier(group=pg) + + # clean up the shared memory segment + with contextlib.suppress(OSError): + if rank == source_rank and shm: + shm.unlink() + torch.distributed.all_reduce(is_in_the_same_node, group=pg) + + return [x == 1 for x in is_in_the_same_node.tolist()] diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/shm_broadcast.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/shm_broadcast.py new file mode 100644 index 000000000..34a23256c --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/shm_broadcast.py @@ -0,0 +1,492 @@ +import pickle +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from multiprocessing import shared_memory +from typing import List, Optional +from unittest.mock import patch + +import torch +import torch.distributed as dist +from torch.distributed import ProcessGroup +from zmq import IPV6 # type: ignore +from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore + +# import vllm.envs as envs +from llm_perf.utils.logger import logger +# from vllm.logger import init_logger +from .utils import get_ip, get_open_port, is_valid_ipv6_address + +VLLM_RINGBUFFER_WARNING_INTERVAL = 60 + +# time to wait if the queue is full or empty +# if we sleep for too short, it will consume too much CPU +# if we sleep for too long, it will slow down the writer/reader +# 0.1 us is a good balance +RINGBUFFER_SLEEP_INTERVAL = 1e-7 + +# logger = init_logger(__name__) + + +class ShmRingBuffer: + + def __init__(self, + n_reader: int, + max_chunk_bytes: int, + max_chunks: int, + name: Optional[str] = None): + """ + A shared memory ring buffer implementation for broadcast communication. + Essentially, it is a queue where only one will `enqueue` and multiple + will `dequeue`. The max size of each item, together with the max number + of items that can be stored in the buffer are known in advance. + In this case, we don't need to synchronize the access to + the buffer. + + Buffer memory layout: + data metadata + | | + | (current_idx) | (current_idx) + v v + +-------------------------------+----------------------------------------+ + | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata | + +-------------------------------+----------------------------------------+ + | max_chunks x max_chunk_bytes | max_chunks x (1 + n_reader) bytes | + + metadata memory layout: each byte is a flag, the first byte is the written + flag, and the rest are reader flags. The flags are set to 0 by default. + +--------------+--------------+--------------+-----+--------------+ + | written_flag | reader0_flag | reader1_flag | ... | readerN_flag | + +--------------+--------------+--------------+-----+--------------+ + + The state of metadata is as follows: + + (case 1) 0???...???: the block is not written yet, cannot read, can write + (case 2) 1000...000: the block is just written, can read, cannot write + (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write + (case 4) 1111...111: the block is written and read by all readers, cannot read, can write + + State transition for readers: + + When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read. + Only after the caller finishes reading the block, the reader can mark the block as read. + Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0). + + State transition for writer: + + When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case + to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer + can reset the reader flags to 0, and mark the block as written (from 0 to 1). + NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct. + + During creation, `name` is None and the buffer is created. We can pass the + created object to other processes by pickling it. The other processes will + get the name of the shared memory and open it, so that they can access the + same shared memory buffer. + """# noqa + self.n_reader = n_reader + self.metadata_size = 1 + n_reader + self.max_chunk_bytes = max_chunk_bytes + self.max_chunks = max_chunks + self.total_bytes_of_buffer = (self.max_chunk_bytes + + self.metadata_size) * self.max_chunks + self.data_offset = 0 + self.metadata_offset = self.max_chunk_bytes * self.max_chunks + + if name is None: + # we are creating a buffer + self.is_creator = True + self.shared_memory = shared_memory.SharedMemory( + create=True, size=self.total_bytes_of_buffer) + # initialize the metadata section to 0 + with memoryview(self.shared_memory.buf[self.metadata_offset:] + ) as metadata_buffer: + torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0) + else: + # we are opening an existing buffer + self.is_creator = False + # fix to https://stackoverflow.com/q/62748654/9191338 + # Python incorrectly tracks shared memory even if it is not + # created by the process. The following patch is a workaround. + with patch("multiprocessing.resource_tracker.register", + lambda *args, **kwargs: None): + try: + self.shared_memory = shared_memory.SharedMemory(name=name) + assert ( + self.shared_memory.size == self.total_bytes_of_buffer) + except FileNotFoundError: + # we might deserialize the object in a different node + # in this case, this object is not used, + # and we should suppress the error + pass + + def __reduce__(self): + return ( + self.__class__, + (self.n_reader, self.max_chunk_bytes, self.max_chunks, + self.shared_memory.name), + ) + + def __del__(self): + if hasattr(self, "shared_memory"): + self.shared_memory.close() + if self.is_creator: + self.shared_memory.unlink() + + @contextmanager + def get_data(self, current_idx: int): + start = self.data_offset + current_idx * self.max_chunk_bytes + end = start + self.max_chunk_bytes + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + @contextmanager + def get_metadata(self, current_idx: int): + start = self.metadata_offset + current_idx * self.metadata_size + end = start + self.metadata_size + with memoryview(self.shared_memory.buf[start:end]) as buf: + yield buf + + +@dataclass +class Handle: + connect_ip: str + local_reader_ranks: List[int] = field(default_factory=list) + + buffer: Optional[ShmRingBuffer] = None + local_subscribe_port: Optional[int] = None + remote_subscribe_port: Optional[int] = None + + +class MessageQueue: + + def __init__( + self, + n_reader, # number of all readers + n_local_reader, # number of local readers through shared memory + local_reader_ranks: Optional[List[int]] = None, + max_chunk_bytes: int = 1024 * 1024 * 10, + max_chunks: int = 10, + connect_ip: Optional[str] = None, + ): + if local_reader_ranks is None: + local_reader_ranks = list(range(n_local_reader)) + else: + assert len(local_reader_ranks) == n_local_reader + self.n_local_reader = n_local_reader + n_remote_reader = n_reader - n_local_reader + self.n_remote_reader = n_remote_reader + + if connect_ip is None: + connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1" + + context = Context() + + if n_local_reader > 0: + # for local readers, we will: + # 1. create a shared memory ring buffer to communicate small data + # 2. create a publish-subscribe socket to communicate large data + self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, + max_chunks) + + # XPUB is very similar to PUB, + # except that it can receive subscription messages + # to confirm the number of subscribers + self.local_socket = context.socket(XPUB) + # set the verbose option so that we can receive every subscription + # message. otherwise, we will only receive the first subscription + # see http://api.zeromq.org/3-3:zmq-setsockopt for more details + self.local_socket.setsockopt(XPUB_VERBOSE, True) + local_subscribe_port = get_open_port() + socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}" + logger.debug("Binding to %s", socket_addr) + self.local_socket.bind(socket_addr) + + self.current_idx = 0 + + else: + self.buffer = None # type: ignore + local_subscribe_port = None + self.local_socket = None + self.current_idx = -1 + + if n_remote_reader > 0: + # for remote readers, we will: + # create a publish-subscribe socket to communicate large data + self.remote_socket = context.socket(XPUB) + self.remote_socket.setsockopt(XPUB_VERBOSE, True) + remote_subscribe_port = get_open_port() + if is_valid_ipv6_address(connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://*:{remote_subscribe_port}" + self.remote_socket.bind(socket_addr) + + else: + remote_subscribe_port = None + self.remote_socket = None + + self._is_writer = True + self._is_local_reader = False + self.local_reader_rank = -1 + # rank does not matter for remote readers + self._is_remote_reader = False + + self.handle = Handle( + connect_ip=connect_ip, + local_reader_ranks=local_reader_ranks, + buffer=self.buffer, + local_subscribe_port=local_subscribe_port, + remote_subscribe_port=remote_subscribe_port, + ) + + logger.info("vLLM message queue communication handle: %s", self.handle) + + def export_handle(self) -> Handle: + return self.handle + + @staticmethod + def create_from_handle(handle: Handle, rank) -> "MessageQueue": + self = MessageQueue.__new__(MessageQueue) + self.handle = handle + self._is_writer = False + + context = Context() + + if rank in handle.local_reader_ranks: + assert handle.buffer is not None + self.buffer = handle.buffer + self.current_idx = 0 + self.local_reader_rank = handle.local_reader_ranks.index(rank) + self._is_local_reader = True + self._is_remote_reader = False + + self.local_socket = context.socket(SUB) + self.local_socket.setsockopt_string(SUBSCRIBE, "") + socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.local_socket.connect(socket_addr) + + self.remote_socket = None + else: + self.buffer = None # type: ignore + self.current_idx = -1 + self.local_reader_rank = -1 + self._is_local_reader = False + self._is_remote_reader = True + + self.local_socket = None + + self.remote_socket = context.socket(SUB) + self.remote_socket.setsockopt_string(SUBSCRIBE, "") + if is_valid_ipv6_address(handle.connect_ip): + self.remote_socket.setsockopt(IPV6, 1) + socket_addr = f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}" + logger.debug("Connecting to %s", socket_addr) + self.remote_socket.connect(socket_addr) + + return self + + def wait_until_ready(self): + """This is a collective operation. All processes (including the + readers and the writer) should call this function. + """ + if self._is_writer: + # wait for all readers to connect + + # local readers + for i in range(self.n_local_reader): + # wait for subscription messages from all local readers + self.local_socket.recv() + if self.n_local_reader > 0: + # send a message to all local readers + # to make sure the publish channel is working + self.local_socket.send(b"READY") + + # remote readers + for i in range(self.n_remote_reader): + # wait for subscription messages from all remote readers + self.remote_socket.recv() + if self.n_remote_reader > 0: + # send a message to all remote readers + # to make sure the publish channel is working + self.remote_socket.send(b"READY") + elif self._is_local_reader: + # wait for the writer to send a message + recv = self.local_socket.recv() + assert recv == b"READY" + elif self._is_remote_reader: + # wait for the writer to send a message + recv = self.remote_socket.recv() + assert recv == b"READY" + + @contextmanager + def acquire_write(self): + assert self._is_writer, "Only writers can acquire write" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_count = sum(metadata_buffer[1:]) + written_flag = metadata_buffer[0] + if written_flag and read_count != self.buffer.n_reader: + # this block is written and not read by all readers + # for writers, `self.current_idx` is the next block to write + # if this block is not ready to write, + # we need to wait until it is read by all readers + + # wait for a while + time.sleep(RINGBUFFER_SLEEP_INTERVAL) + + # if we wait for a long time, we should warn the user + if (time.monotonic() - start_time > + VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + + continue + # found a block that is either + # (1) not written + # (2) read by all readers + + # mark the block as not written + metadata_buffer[0] = 0 + # let caller write to the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has written to the buffer + # NOTE: order is important here + # first set the read flags to 0 + # then set the written flag to 1 + # otherwise, the readers may think they already read the block + for i in range(1, self.buffer.n_reader + 1): + # set read flag to 0, meaning it is not read yet + metadata_buffer[i] = 0 + # mark the block as written + metadata_buffer[0] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + @contextmanager + def acquire_read(self): + assert self._is_local_reader, "Only readers can acquire read" + start_time = time.monotonic() + n_warning = 1 + while True: + with self.buffer.get_metadata(self.current_idx) as metadata_buffer: + read_flag = metadata_buffer[self.local_reader_rank + 1] + written_flag = metadata_buffer[0] + if not written_flag or read_flag: + # this block is either + # (1) not written + # (2) already read by this reader + + # for readers, `self.current_idx` is the next block to read + # if this block is not ready, + # we need to wait until it is written + + # wait for a while + time.sleep(RINGBUFFER_SLEEP_INTERVAL) + + # if we wait for a long time, we should warn the user + if (time.monotonic() - start_time > + VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning): + logger.warning( + "No available block found in %s second. ", + VLLM_RINGBUFFER_WARNING_INTERVAL) + n_warning += 1 + + continue + # found a block that is not read by this reader + # let caller read from the buffer + with self.buffer.get_data(self.current_idx) as buf: + yield buf + + # caller has read from the buffer + # set the read flag + metadata_buffer[self.local_reader_rank + 1] = 1 + self.current_idx = (self.current_idx + + 1) % self.buffer.max_chunks + break + + def enqueue(self, obj): + assert self._is_writer, "Only writers can enqueue" + serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL) + if self.n_local_reader > 0: + if len(serialized_obj) >= self.buffer.max_chunk_bytes: + with self.acquire_write() as buf: + buf[0] = 1 # overflow + self.local_socket.send(serialized_obj) + else: + with self.acquire_write() as buf: + buf[0] = 0 # not overflow + buf[1:len(serialized_obj) + 1] = serialized_obj + if self.n_remote_reader > 0: + self.remote_socket.send(serialized_obj) + + def dequeue(self): + if self._is_local_reader: + with self.acquire_read() as buf: + overflow = buf[0] == 1 + if not overflow: + # no need to know the size of serialized object + # pickle format contains the size information internally + # see https://docs.python.org/3/library/pickle.html + obj = pickle.loads(buf[1:]) + if overflow: + recv = self.local_socket.recv() + obj = pickle.loads(recv) + elif self._is_remote_reader: + recv = self.remote_socket.recv() + obj = pickle.loads(recv) + else: + raise RuntimeError("Only readers can dequeue") + return obj + + def broadcast_object(self, obj=None): + if self._is_writer: + self.enqueue(obj) + return obj + else: + return self.dequeue() + + @staticmethod + def create_from_process_group(pg: ProcessGroup, + max_chunk_bytes, + max_chunks, + writer_rank=0) -> "MessageQueue": + group_rank = dist.get_rank(pg) + group_world_size = dist.get_world_size(pg) + global_ranks = dist.get_process_group_ranks(pg) + + from .parallel_state import in_the_same_node_as + status = in_the_same_node_as(pg, source_rank=writer_rank) + same_node_ranks = [i for i, s in enumerate(status) if s] + n_reader = group_world_size - 1 + n_local_reader = len(same_node_ranks) - 1 + local_reader_ranks = [i for i in same_node_ranks if i != writer_rank] + buffer_io: MessageQueue + if group_rank == writer_rank: + buffer_io = MessageQueue( + n_reader=n_reader, + n_local_reader=n_local_reader, + local_reader_ranks=local_reader_ranks, + max_chunk_bytes=max_chunk_bytes, + max_chunks=max_chunks, + ) + handle = buffer_io.export_handle() + dist.broadcast_object_list([handle], + src=global_ranks[writer_rank], + group=pg) + else: + recv = [None] + dist.broadcast_object_list(recv, + src=global_ranks[writer_rank], + group=pg) + handle = recv[0] # type: ignore + buffer_io = MessageQueue.create_from_handle(handle, group_rank) + buffer_io.wait_until_ready() + return buffer_io diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/utils.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/utils.py new file mode 100644 index 000000000..c062ceedd --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/dist/utils.py @@ -0,0 +1,1551 @@ +import argparse +import asyncio +import contextlib +import datetime +import enum +import gc +import inspect +import ipaddress +import os +import random +import socket +import subprocess +import sys +import tempfile +import threading +import uuid +import warnings +import weakref +from asyncio import FIRST_COMPLETED, ensure_future +from functools import lru_cache, partial, wraps +from platform import uname +from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, + Hashable, List, Literal, Optional, OrderedDict, Set, Tuple, + Type, TypeVar, Union, overload) +from uuid import uuid4 + +import numpy as np +import numpy.typing as npt +import psutil +import torch +import torch.types +import yaml +from packaging.version import Version +from typing_extensions import ParamSpec, TypeIs, assert_never + +# import vllm.envs as envs +# from vllm.logger import enable_trace_function_call, init_logger +# from vllm.platforms import current_platform +from llm_perf.utils.logger import logger + +# logger = init_logger(__name__) + +# Exception strings for non-implemented encoder/decoder scenarios + +STR_NOT_IMPL_ENC_DEC_SWA = \ + "Sliding window attention for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE = \ + "Prefix caching for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL = \ + "Chunked prefill for encoder/decoder models " + \ + "is not currently supported." + +STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP = ( + "Models with logits_soft_cap " + "require FlashInfer backend, which is " + "currently not supported for encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_LORA = ("LoRA is currently not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_PP = ("Pipeline parallelism is not " + "currently supported with " + "encoder/decoder models.") + +STR_NOT_IMPL_ENC_DEC_MM = ("Multimodal is not currently " + "supported with encoder/decoder " + "models.") + +STR_NOT_IMPL_ENC_DEC_SPEC_DEC = ("Speculative decoding is not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_BACKEND = ("XFormers is the only backend " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER = ("Prompt adapters are not " + "currently supported with encoder/" + "decoder models.") + +STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " + "encoder/decoder models.") + +# Efficiently import all enc/dec error strings +# rather than having to import all of the above +STR_NOT_IMPL_ENC_DEC_ERR_STRS = { + "STR_NOT_IMPL_ENC_DEC_SWA": STR_NOT_IMPL_ENC_DEC_SWA, + "STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE": STR_NOT_IMPL_ENC_DEC_PREFIX_CACHE, + "STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL": + STR_NOT_IMPL_ENC_DEC_CHUNKED_PREFILL, + "STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP": STR_NOT_IMPL_ENC_DEC_LOGIT_SOFTCAP, + "STR_NOT_IMPL_ENC_DEC_LORA": STR_NOT_IMPL_ENC_DEC_LORA, + "STR_NOT_IMPL_ENC_DEC_PP": STR_NOT_IMPL_ENC_DEC_PP, + "STR_NOT_IMPL_ENC_DEC_MM": STR_NOT_IMPL_ENC_DEC_MM, + "STR_NOT_IMPL_ENC_DEC_SPEC_DEC": STR_NOT_IMPL_ENC_DEC_SPEC_DEC, + "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, + "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, + "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU +} + +# Constants related to forcing the attention backend selection + +# String name of register which may be set in order to +# force auto-selection of attention backend by Attention +# wrapper +STR_BACKEND_ENV_VAR: str = "VLLM_ATTENTION_BACKEND" + +# Possible string values of STR_BACKEND_ENV_VAR +# register, corresponding to possible backends +STR_FLASHINFER_ATTN_VAL: str = "FLASHINFER" +STR_TORCH_SDPA_ATTN_VAL: str = "TORCH_SDPA" +STR_ROCM_FLASH_ATTN_VAL: str = "ROCM_FLASH" +STR_XFORMERS_ATTN_VAL: str = "XFORMERS" +STR_FLASH_ATTN_VAL: str = "FLASH_ATTN" +STR_INVALID_VAL: str = "INVALID" + +GB_bytes = 1_000_000_000 +"""The number of bytes in one gigabyte (GB).""" + +GiB_bytes = 1 << 30 +"""The number of bytes in one gibibyte (GiB).""" + +STR_DTYPE_TO_TORCH_DTYPE = { + "half": torch.half, + "bfloat16": torch.bfloat16, + "float": torch.float, + "fp8": torch.uint8, + "fp8_e4m3": torch.uint8, + "fp8_e5m2": torch.uint8, +} + +TORCH_DTYPE_TO_NUMPY_DTYPE = { + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.uint8: np.uint8, + torch.int32: np.int32, + torch.int64: np.int64, +} + +P = ParamSpec('P') +K = TypeVar("K") +T = TypeVar("T") +U = TypeVar("U") + + +class _Sentinel: + ... + + +ALL_PINNED_SENTINEL = _Sentinel() + + +class rpd_trace(): + + def __init__(self, + filename=None, + name=None, + nvtx=False, + args=None, + skip=False): + self.skip = skip + if not self.skip: + self.name = name + self.args = args if args else "" + self.rpd = self.initialize_rpd_tracer(filename, nvtx) + + def _recreate_cm(self): + return self + + def __call__(self, func): + if not self.skip: + if self.name: + self.name += f"{func.__name__}" + else: + self.name = f"{func.__qualname__}" + + @wraps(func) + def inner(*args, **kwds): + with self._recreate_cm(): + return func(*args, **kwds) + + return inner + return func + + def __enter__(self): + if not self.skip: + self.rpd.__enter__() + self.rpd.rangePush("python", f"{self.name}", f"{self.args}") + return self + + def __exit__(self, *exc): + if not self.skip: + self.rpd.rangePop() + self.rpd.__exit__(None, None, None) + return False + + @staticmethod + def setup_environment_variables(filename): + os.environ['RPDT_AUTOSTART'] = '0' + os.environ['RPDT_FILENAME'] = filename + + def initialize_rpd_tracer(self, filename, nvtx): + try: + from rpdTracerControl import rpdTracerControl + rpd_trace.setup_environment_variables(filename) + rpdTracerControl.setFilename(name=filename, append=True) + return rpdTracerControl(nvtx=nvtx) + except Exception as e: + print(f"Error initializing rpdTracerControl: {e}") + raise + + @staticmethod + def create_file(filename): + import sqlite3 + + from rocpd.schema import RocpdSchema + try: + print("Creating empty rpd schema file ...") + filename = str(filename) + with sqlite3.connect(filename) as connection: + schema = RocpdSchema() + schema.writeSchema(connection) + connection.commit() + except sqlite3.OperationalError as e: + print(f"SQLite operational error: {e}") + except Exception as e: + print(f"An error occurred while creating the filename: {e}") + + +@lru_cache(maxsize=None) +def is_hipScopedMarker_available(): + try: + from hipScopedMarker import hipScopedMarker + except ImportError: + hipScopedMarker = None + return hipScopedMarker is not None + + +class rpd_mark(): + + def __init__(self, name=None): + self.name = name + + def __call__(self, func): + + if is_hipScopedMarker_available(): + from hipScopedMarker import hipScopedMarker + + @wraps(func) + def inner(*args, **kwds): + marker_name = self.name if self.name else f"{func.__name__}" + with hipScopedMarker(f"{marker_name}"): + return func(*args, **kwds) + + return inner + + else: + return func + + +class Device(enum.Enum): + GPU = enum.auto() + CPU = enum.auto() + + +class Counter: + + def __init__(self, start: int = 0) -> None: + self.counter = start + + def __next__(self) -> int: + i = self.counter + self.counter += 1 + return i + + def reset(self) -> None: + self.counter = 0 + + +class LRUCache(Generic[T]): + + def __init__(self, capacity: int): + self.cache: OrderedDict[Hashable, T] = OrderedDict() + self.pinned_items: Set[Hashable] = set() + self.capacity = capacity + + def __contains__(self, key: Hashable) -> bool: + return key in self.cache + + def __len__(self) -> int: + return len(self.cache) + + def __getitem__(self, key: Hashable) -> T: + value = self.cache[key] # Raise KeyError if not exists + self.cache.move_to_end(key) + return value + + def __setitem__(self, key: Hashable, value: T) -> None: + self.put(key, value) + + def __delitem__(self, key: Hashable) -> None: + self.pop(key) + + def touch(self, key: Hashable) -> None: + self.cache.move_to_end(key) + + def get(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: + value: Optional[T] + if key in self.cache: + value = self.cache[key] + self.cache.move_to_end(key) + else: + value = default_value + return value + + def put(self, key: Hashable, value: T) -> None: + self.cache[key] = value + self.cache.move_to_end(key) + self._remove_old_if_needed() + + def pin(self, key: Hashable) -> None: + """ + Pins a key in the cache preventing it from being + evicted in the LRU order. + """ + if key not in self.cache: + raise ValueError(f"Cannot pin key: {key} not in cache.") + self.pinned_items.add(key) + + def _unpin(self, key: Hashable) -> None: + self.pinned_items.remove(key) + + def _on_remove(self, key: Hashable, value: Optional[T]): + pass + + def remove_oldest(self, remove_pinned=False): + if not self.cache: + return + + if not remove_pinned: + # pop the oldest item in the cache that is not pinned + lru_key = next( + (key for key in self.cache if key not in self.pinned_items), + ALL_PINNED_SENTINEL) + if lru_key is ALL_PINNED_SENTINEL: + raise RuntimeError("All items are pinned, " + "cannot remove oldest from the cache.") + else: + lru_key = next(iter(self.cache)) + self.pop(lru_key) + + def _remove_old_if_needed(self) -> None: + while len(self.cache) > self.capacity: + self.remove_oldest() + + def pop(self, + key: Hashable, + default_value: Optional[T] = None) -> Optional[T]: + run_on_remove = key in self.cache + value: Optional[T] = self.cache.pop(key, default_value) + # remove from pinned items + if key in self.pinned_items: + self._unpin(key) + if run_on_remove: + self._on_remove(key, value) + return value + + def clear(self): + while len(self.cache) > 0: + self.remove_oldest(remove_pinned=True) + self.cache.clear() + + +class PyObjectCache: + """Used to cache python objects to avoid object allocations + across scheduler iterations. + """ + + def __init__(self, obj_builder): + self._obj_builder = obj_builder + self._index = 0 + + self._obj_cache = [] + for _ in range(128): + self._obj_cache.append(self._obj_builder()) + + def _grow_cache(self): + # Double the size of the cache + num_objs = len(self._obj_cache) + for _ in range(num_objs): + self._obj_cache.append(self._obj_builder()) + + def get_object(self): + """Returns a pre-allocated cached object. If there is not enough + objects, then the cache size will double. + """ + if self._index >= len(self._obj_cache): + self._grow_cache() + assert self._index < len(self._obj_cache) + + obj = self._obj_cache[self._index] + self._index += 1 + + return obj + + def reset(self): + """Makes all cached-objects available for the next scheduler iteration. + """ + self._index = 0 + + +def is_hip() -> bool: + return torch.version.hip is not None + + +@lru_cache(maxsize=None) +def is_cpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "cpu" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_openvino() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + return "openvino" in version("vllm") + except PackageNotFoundError: + return False + + +@lru_cache(maxsize=None) +def is_neuron() -> bool: + try: + import transformers_neuronx + except ImportError: + transformers_neuronx = None + return transformers_neuronx is not None + + +@lru_cache(maxsize=None) +def is_xpu() -> bool: + from importlib.metadata import PackageNotFoundError, version + try: + is_xpu_flag = "xpu" in version("vllm") + except PackageNotFoundError: + return False + # vllm is not build with xpu + if not is_xpu_flag: + return False + try: + import intel_extension_for_pytorch as ipex # noqa: F401 + _import_ipex = True + except ImportError as e: + logger.warning("Import Error for IPEX: %s", e.msg) + _import_ipex = False + # ipex dependency is not ready + if not _import_ipex: + logger.warning("not found ipex lib") + return False + return hasattr(torch, "xpu") and torch.xpu.is_available() + + +@lru_cache(maxsize=None) +def get_max_shared_memory_bytes(gpu: int = 0) -> int: + """Returns the maximum shared memory per thread block in bytes.""" + from vllm import _custom_ops as ops + max_shared_mem = ( + ops.get_max_shared_memory_per_block_device_attribute(gpu)) + # value 0 will cause MAX_SEQ_LEN become negative and test_attention.py + # will fail + assert max_shared_mem > 0, "max_shared_mem can not be zero" + return int(max_shared_mem) + + +def get_cpu_memory() -> int: + """Returns the total CPU memory of the node in bytes.""" + return psutil.virtual_memory().total + + +def seed_everything(seed: int) -> None: + """ + Set the seed of each random module. + + Loosely based on: https://github.com/Lightning-AI/pytorch-lightning/blob/2.4.0/src/lightning/fabric/utilities/seed.py#L20 + """ + random.seed(seed) + np.random.seed(seed) + + if current_platform.is_cuda_alike(): + torch.cuda.manual_seed_all(seed) + + if is_xpu(): + torch.xpu.manual_seed_all(seed) + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +@lru_cache(maxsize=None) +def get_vllm_instance_id() -> str: + """ + If the environment variable VLLM_INSTANCE_ID is set, return it. + Otherwise, return a random UUID. + Instance id represents an instance of the VLLM. All processes in the same + instance should have the same instance id. + """ + return envs.VLLM_INSTANCE_ID or f"vllm-instance-{random_uuid()}" + + +@lru_cache(maxsize=None) +def in_wsl() -> bool: + # Reference: https://github.com/microsoft/WSL/issues/4071 + return "microsoft" in " ".join(uname()).lower() + + +def make_async(func: Callable[P, T]) -> Callable[P, Awaitable[T]]: + """Take a blocking function, and run it on in an executor thread. + + This function prevents the blocking function from blocking the + asyncio event loop. + The code in this function needs to be thread safe. + """ + + def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: + loop = asyncio.get_event_loop() + p_func = partial(func, *args, **kwargs) + return loop.run_in_executor(executor=None, func=p_func) + + return _async_wrapper + + +async def iterate_with_cancellation( + iterator: AsyncGenerator[T, None], + is_cancelled: Callable[[], Awaitable[bool]], +) -> AsyncGenerator[T, None]: + """Convert async iterator into one that polls the provided function + at least once per second to check for client cancellation. + """ + + # Can use anext() in python >= 3.10 + awaits = [ensure_future(iterator.__anext__())] + while True: + done, pending = await asyncio.wait(awaits, timeout=1) + if await is_cancelled(): + with contextlib.suppress(BaseException): + awaits[0].cancel() + await iterator.aclose() + raise asyncio.CancelledError("client cancelled") + if done: + try: + item = await awaits[0] + awaits[0] = ensure_future(iterator.__anext__()) + yield item + except StopAsyncIteration: + # we are done + return + + +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], + is_cancelled: Optional[Callable[[], Awaitable[bool]]] = None, +) -> AsyncGenerator[Tuple[int, T], None]: + """Merge multiple asynchronous iterators into a single iterator. + + This method handle the case where some iterators finish before others. + When it yields, it yields a tuple (i, item) where i is the index of the + iterator that yields the item. + + It also optionally polls a provided function at least once per second + to check for client cancellation. + """ + + # Can use anext() in python >= 3.10 + awaits = { + ensure_future(pair[1].__anext__()): pair + for pair in enumerate(iterators) + } + timeout = None if is_cancelled is None else 1 + try: + while awaits: + done, pending = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED, + timeout=timeout) + if is_cancelled is not None and await is_cancelled(): + raise asyncio.CancelledError("client cancelled") + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[ensure_future(it.__anext__())] = pair + yield i, item + except StopAsyncIteration: + pass + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + f.cancel() + await it.aclose() + + +async def collect_from_async_generator( + iterator: AsyncGenerator[T, None]) -> List[T]: + """Collect all items from an async generator into a list.""" + items = [] + async for item in iterator: + items.append(item) + return items + + +def get_ip() -> str: + # host_ip = envs.VLLM_HOST_IP + # if host_ip: + # return host_ip + + # IP is not set, try to get it from the network interface + + # try ipv4 + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("8.8.8.8", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + # try ipv6 + try: + s = socket.socket(socket.AF_INET6, socket.SOCK_DGRAM) + # Google's public DNS server, see + # https://developers.google.com/speed/public-dns/docs/using#addresses + s.connect(("2001:4860:4860::8888", 80)) # Doesn't need to be reachable + return s.getsockname()[0] + except Exception: + pass + + warnings.warn( + "Failed to get the IP address, using 0.0.0.0 by default." + "The value can be set by the environment variable" + " VLLM_HOST_IP or HOST_IP.", + stacklevel=2) + return "0.0.0.0" + + +def is_valid_ipv6_address(address: str) -> bool: + try: + ipaddress.IPv6Address(address) + return True + except ValueError: + return False + + +def get_distributed_init_method(ip: str, port: int) -> str: + # Brackets are not permitted in ipv4 addresses, + # see https://github.com/python/cpython/issues/103848 + return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}" + + +def get_open_zmq_ipc_path() -> str: + base_rpc_path = envs.VLLM_RPC_BASE_PATH + return f"ipc://{base_rpc_path}/{uuid4()}" + + +def get_open_port() -> int: + # port = envs.VLLM_PORT + port = None + if port is not None: + while True: + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", port)) + return port + except OSError: + port += 1 # Increment port number if already in use + logger.info("Port %d is already in use, trying port %d", + port - 1, port) + # try ipv4 + try: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + except OSError: + # try ipv6 + with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s: + s.bind(("", 0)) + return s.getsockname()[1] + + +def find_process_using_port(port: int) -> Optional[psutil.Process]: + for conn in psutil.net_connections(): + if conn.laddr.port == port: + try: + return psutil.Process(conn.pid) + except psutil.NoSuchProcess: + return None + return None + + +def update_environment_variables(envs: Dict[str, str]): + for k, v in envs.items(): + if k in os.environ and os.environ[k] != v: + logger.warning( + "Overwriting environment variable %s " + "from '%s' to '%s'", k, os.environ[k], v) + os.environ[k] = v + + +def chunk_list(lst: List[T], chunk_size: int): + """Yield successive chunk_size chunks from lst.""" + for i in range(0, len(lst), chunk_size): + yield lst[i:i + chunk_size] + + +def cdiv(a: int, b: int) -> int: + """Ceiling division.""" + return -(a // -b) + + +def _generate_random_fp8( + tensor: torch.Tensor, + low: float, + high: float, +) -> None: + # NOTE(zhaoyang): Due to NaN and Inf representation for fp8 data type, + # it may occur Inf or NaN if we directly use torch.randint + # to generate random data for fp8 data. + # For example, s.11111.00 in fp8e5m2 format represents Inf. + # | E4M3 | E5M2 + #-----|-------------|------------------- + # Inf | N/A | s.11111.00 + # NaN | s.1111.111 | s.11111.{01,10,11} + from vllm import _custom_ops as ops + tensor_tmp = torch.empty_like(tensor, dtype=torch.float16) + tensor_tmp.uniform_(low, high) + ops.convert_fp8(tensor, tensor_tmp) + del tensor_tmp + + +def get_kv_cache_torch_dtype( + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None) -> torch.dtype: + if isinstance(cache_dtype, str): + if cache_dtype == "auto": + if isinstance(model_dtype, str): + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[model_dtype] + elif isinstance(model_dtype, torch.dtype): + torch_dtype = model_dtype + else: + raise ValueError(f"Invalid model dtype: {model_dtype}") + elif cache_dtype in ["half", "bfloat16", "float"]: + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_dtype] + elif cache_dtype == "fp8": + torch_dtype = torch.uint8 + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + elif isinstance(cache_dtype, torch.dtype): + torch_dtype = cache_dtype + else: + raise ValueError(f"Invalid kv cache dtype: {cache_dtype}") + return torch_dtype + + +def create_kv_caches_with_random_flash( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + key_value_cache_shape = (num_blocks, 2, block_size, num_heads, head_size) + scale = head_size**-0.5 + + key_caches: List[torch.Tensor] = [] + value_caches: List[torch.Tensor] = [] + + for _ in range(num_layers): + key_value_cache = torch.empty(size=key_value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_value_cache[:, 0]) + value_caches.append(key_value_cache[:, 1]) + return key_caches, value_caches + + +def create_kv_caches_with_random( + num_blocks: int, + block_size: int, + num_layers: int, + num_heads: int, + head_size: int, + cache_dtype: Optional[Union[str, torch.dtype]], + model_dtype: Optional[Union[str, torch.dtype]] = None, + seed: int = 0, + device: Optional[str] = "cuda", +) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: + + if cache_dtype == "fp8" and head_size % 16: + raise ValueError( + f"Does not support key cache of type fp8 with head_size {head_size}" + ) + + seed_everything(seed) + + torch_dtype = get_kv_cache_torch_dtype(cache_dtype, model_dtype) + + scale = head_size**-0.5 + x = 16 // torch.tensor([], dtype=torch_dtype).element_size() + key_cache_shape = (num_blocks, num_heads, head_size // x, block_size, x) + key_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + key_cache = torch.empty(size=key_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + key_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(key_cache, -scale, scale) + else: + raise ValueError( + f"Does not support key cache of type {cache_dtype}") + key_caches.append(key_cache) + + value_cache_shape = (num_blocks, num_heads, head_size, block_size) + value_caches: List[torch.Tensor] = [] + for _ in range(num_layers): + value_cache = torch.empty(size=value_cache_shape, + dtype=torch_dtype, + device=device) + if cache_dtype in ["auto", "half", "bfloat16", "float"]: + value_cache.uniform_(-scale, scale) + elif cache_dtype == 'fp8': + _generate_random_fp8(value_cache, -scale, scale) + else: + raise ValueError( + f"Does not support value cache of type {cache_dtype}") + value_caches.append(value_cache) + return key_caches, value_caches + + +@lru_cache +def print_warning_once(msg: str) -> None: + # Set the stacklevel to 2 to print the caller's line info + logger.warning(msg, stacklevel=2) + + +@lru_cache(maxsize=None) +def is_pin_memory_available() -> bool: + + if in_wsl(): + # Pinning memory in WSL is not supported. + # https://docs.nvidia.com/cuda/wsl-user-guide/index.html#known-limitations-for-linux-cuda-applications + print_warning_once("Using 'pin_memory=False' as WSL is detected. " + "This may slow down the performance.") + return False + elif is_xpu(): + print_warning_once("Pin memory is not supported on XPU.") + return False + elif is_neuron(): + print_warning_once("Pin memory is not supported on Neuron.") + return False + elif is_cpu() or is_openvino(): + return False + return True + + +class DeviceMemoryProfiler: + + def __init__(self, device: Optional[torch.types.Device] = None): + self.device = device + + def current_memory_usage(self) -> float: + # Return the memory usage in bytes. + if current_platform.is_cuda_alike(): + torch.cuda.reset_peak_memory_stats(self.device) + mem = torch.cuda.max_memory_allocated(self.device) + elif is_xpu(): + torch.xpu.reset_peak_memory_stats(self.device) # type: ignore + mem = torch.xpu.max_memory_allocated(self.device) # type: ignore + return mem + + def __enter__(self): + self.initial_memory = self.current_memory_usage() + # This allows us to call methods of the context manager if needed + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.final_memory = self.current_memory_usage() + self.consumed_memory = self.final_memory - self.initial_memory + + # Force garbage collection + gc.collect() + + +def make_ndarray_with_pad( + x: List[List[T]], + pad: T, + dtype: npt.DTypeLike, + *, + max_len: Optional[int] = None, +) -> npt.NDArray: + """ + Make a padded array from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + if max_len is None: + # Unlike for most functions, map is faster than a genexpr over `len` + max_len = max(map(len, x), default=0) + + padded_x = np.full((len(x), max_len), pad, dtype=dtype) + for ind, blocktb in enumerate(x): + assert len(blocktb) <= max_len + padded_x[ind, :len(blocktb)] = blocktb + + return padded_x + + +def make_tensor_with_pad( + x: List[List[T]], + pad: T, + dtype: torch.dtype, + *, + max_len: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + pin_memory: bool = False, +) -> torch.Tensor: + """ + Make a padded tensor from 2D inputs. + + The padding is applied to the end of each inner list until it reaches + `max_len`. + """ + np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype] + padded_x = make_ndarray_with_pad(x, pad, np_dtype, max_len=max_len) + + tensor = torch.from_numpy(padded_x).to(device) + if pin_memory: + tensor = tensor.pin_memory() + + return tensor + + +def async_tensor_h2d( + data: list, + dtype: torch.dtype, + target_device: Union[str, torch.device], + pin_memory: bool, +) -> torch.Tensor: + """Asynchronously create a tensor and copy it from host to device.""" + t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu") + return t.to(device=target_device, non_blocking=True) + + +def get_dtype_size(dtype: torch.dtype) -> int: + """Get the size of the data type in bytes.""" + return torch.tensor([], dtype=dtype).element_size() + + +# `collections` helpers +def is_list_of( + value: object, + typ: Type[T], + *, + check: Literal["first", "all"] = "first", +) -> TypeIs[List[T]]: + if not isinstance(value, list): + return False + + if check == "first": + return len(value) == 0 or isinstance(value[0], typ) + elif check == "all": + return all(isinstance(v, typ) for v in value) + + assert_never(check) + + +JSONTree = Union[Dict[str, "JSONTree[T]"], List["JSONTree[T]"], + Tuple["JSONTree[T]", ...], T] +"""A nested JSON structure where the leaves need not be JSON-serializable.""" + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Dict[str, JSONTree[T]], +) -> Dict[str, JSONTree[U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: List[JSONTree[T]], +) -> List[JSONTree[U]]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: Tuple[JSONTree[T], ...], +) -> Tuple[JSONTree[U], ...]: + ... + + +@overload +def json_map_leaves( + func: Callable[[T], U], + value: JSONTree[T], +) -> JSONTree[U]: + ... + + +def json_map_leaves(func: Callable[[T], U], value: JSONTree[T]) -> JSONTree[U]: + if isinstance(value, dict): + return {k: json_map_leaves(func, v) for k, v in value.items()} + elif isinstance(value, list): + return [json_map_leaves(func, v) for v in value] + elif isinstance(value, tuple): + return tuple(json_map_leaves(func, v) for v in value) + else: + return func(value) + + +def flatten_2d_lists(lists: List[List[T]]) -> List[T]: + """Flatten a list of lists to a single list.""" + return [item for sublist in lists for item in sublist] + + +def init_cached_hf_modules() -> None: + """ + Lazy initialization of the Hugging Face modules. + """ + from transformers.dynamic_module_utils import init_hf_modules + init_hf_modules() + + +@lru_cache(maxsize=None) +def find_library(lib_name: str) -> str: + """ + Find the library file in the system. + `lib_name` is full filename, with both prefix and suffix. + This function resolves `lib_name` to the full path of the library. + """ + # Adapted from https://github.com/openai/triton/blob/main/third_party/nvidia/backend/driver.py#L19 # noqa + # According to https://en.wikipedia.org/wiki/Filesystem_Hierarchy_Standard + # `/sbin/ldconfig` should exist in all Linux systems. + # `/sbin/ldconfig` searches the library in the system + libs = subprocess.check_output(["/sbin/ldconfig", "-p"]).decode() + # each line looks like the following: + # libcuda.so.1 (libc6,x86-64) => /lib/x86_64-linux-gnu/libcuda.so.1 + locs = [line.split()[-1] for line in libs.splitlines() if lib_name in line] + # `LD_LIBRARY_PATH` searches the library in the user-defined paths + env_ld_library_path = envs.LD_LIBRARY_PATH + if not locs and env_ld_library_path: + locs = [ + os.path.join(dir, lib_name) + for dir in env_ld_library_path.split(":") + if os.path.exists(os.path.join(dir, lib_name)) + ] + if not locs: + raise ValueError(f"Cannot find {lib_name} in the system.") + return locs[0] + + +def find_nccl_library() -> str: + """ + We either use the library file specified by the `VLLM_NCCL_SO_PATH` + environment variable, or we find the library file brought by PyTorch. + After importing `torch`, `libnccl.so.2` or `librccl.so.1` can be + found by `ctypes` automatically. + """ + so_file = envs.VLLM_NCCL_SO_PATH + + # manually load the nccl library + if so_file: + logger.info( + "Found nccl from environment variable VLLM_NCCL_SO_PATH=%s", + so_file) + else: + if torch.version.cuda is not None: + so_file = "libnccl.so.2" + elif torch.version.hip is not None: + so_file = "librccl.so.1" + else: + raise ValueError("NCCL only supports CUDA and ROCm backends.") + logger.info("Found nccl from library %s", so_file) + return so_file + + +def enable_trace_function_call_for_thread() -> None: + """Set up function tracing for the current thread, + if enabled via the VLLM_TRACE_FUNCTION environment variable + """ + + if envs.VLLM_TRACE_FUNCTION: + tmp_dir = tempfile.gettempdir() + filename = (f"VLLM_TRACE_FUNCTION_for_process_{os.getpid()}" + f"_thread_{threading.get_ident()}_" + f"at_{datetime.datetime.now()}.log").replace(" ", "_") + log_path = os.path.join(tmp_dir, "vllm", get_vllm_instance_id(), + filename) + os.makedirs(os.path.dirname(log_path), exist_ok=True) + enable_trace_function_call(log_path) + + +# `functools` helpers +def identity(value: T) -> T: + return value + + +F = TypeVar('F', bound=Callable[..., Any]) + + +def deprecate_kwargs( + *kws: str, + is_deprecated: Union[bool, Callable[[], bool]] = True, + additional_message: Optional[str] = None) -> Callable[[F], F]: + deprecated_kws = set(kws) + + if not callable(is_deprecated): + is_deprecated = partial(identity, is_deprecated) + + def wrapper(fn: F) -> F: + + @wraps(fn) + def inner(*args, **kwargs): + if is_deprecated(): + deprecated_kwargs = kwargs.keys() & deprecated_kws + if deprecated_kwargs: + msg = ( + f"The keyword arguments {deprecated_kwargs} are " + "deprecated and will be removed in a future update.") + if additional_message is not None: + msg += f" {additional_message}" + + warnings.warn( + DeprecationWarning(msg), + stacklevel=3, # The inner function takes up one level + ) + + return fn(*args, **kwargs) + + return inner # type: ignore + + return wrapper + + +@lru_cache(maxsize=8) +def _cuda_device_count_stateless( + cuda_visible_devices: Optional[str] = None) -> int: + # Note: cuda_visible_devices is not used, but we keep it as an argument for + # LRU Cache purposes. + + # Code below is based on + # https://github.com/pytorch/pytorch/blob/ + # c1cd946818442aca8c7f812b16d187ce1586c3bc/ + # torch/cuda/__init__.py#L831C1-L831C17 + import torch.cuda + import torch.version + + if not torch.cuda._is_compiled(): + return 0 + if is_hip(): + # ROCm uses amdsmi instead of nvml for stateless device count + # This requires a sufficiently modern version of Torch 2.4.0 + raw_count = torch.cuda._device_count_amdsmi() if (hasattr( + torch.cuda, "_device_count_amdsmi")) else -1 + else: + raw_count = torch.cuda._device_count_nvml() + r = torch._C._cuda_getDeviceCount() if raw_count < 0 else raw_count + return r + + +def cuda_device_count_stateless() -> int: + """Get number of CUDA devices, caching based on the value of + CUDA_VISIBLE_DEVICES at the time of call. + + This should be used instead of torch.cuda.device_count() + unless CUDA_VISIBLE_DEVICES has already been set to the desired + value.""" + + # This can be removed and simply replaced with torch.cuda.get_device_count + # after https://github.com/pytorch/pytorch/pull/122815 is released. + return _cuda_device_count_stateless(envs.CUDA_VISIBLE_DEVICES) + + +def cuda_is_initialized() -> bool: + """Check if CUDA is initialized.""" + if not torch.cuda._is_compiled(): + return False + return torch.cuda.is_initialized() + + +def weak_bind(bound_method: Callable[..., Any], ) -> Callable[..., None]: + """Make an instance method that weakly references + its associated instance and no-ops once that + instance is collected.""" + ref = weakref.ref(bound_method.__self__) # type: ignore[attr-defined] + unbound = bound_method.__func__ # type: ignore[attr-defined] + + def weak_bound(*args, **kwargs) -> None: + if inst := ref(): + unbound(inst, *args, **kwargs) + + return weak_bound + + +#From: https://stackoverflow.com/a/4104188/2749989 +def run_once(f: Callable[P, None]) -> Callable[P, None]: + + def wrapper(*args: P.args, **kwargs: P.kwargs) -> None: + if not wrapper.has_run: # type: ignore[attr-defined] + wrapper.has_run = True # type: ignore[attr-defined] + return f(*args, **kwargs) + + wrapper.has_run = False # type: ignore[attr-defined] + return wrapper + + +class FlexibleArgumentParser(argparse.ArgumentParser): + """ArgumentParser that allows both underscore and dash in names.""" + + def parse_args(self, args=None, namespace=None): + if args is None: + args = sys.argv[1:] + + if '--config' in args: + args = FlexibleArgumentParser._pull_args_from_config(args) + + # Convert underscores to dashes and vice versa in argument names + processed_args = [] + for arg in args: + if arg.startswith('--'): + if '=' in arg: + key, value = arg.split('=', 1) + key = '--' + key[len('--'):].replace('_', '-') + processed_args.append(f'{key}={value}') + else: + processed_args.append('--' + + arg[len('--'):].replace('_', '-')) + else: + processed_args.append(arg) + + return super().parse_args(processed_args, namespace) + + @staticmethod + def _pull_args_from_config(args: List[str]) -> List[str]: + """Method to pull arguments specified in the config file + into the command-line args variable. + + The arguments in config file will be inserted between + the argument list. + + example: + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + ```python + $: vllm {serve,chat,complete} "facebook/opt-12B" \ + --config config.yaml -tp 2 + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--config', 'config.yaml', + '-tp', '2' + ] + $: args = [ + "serve,chat,complete", + "facebook/opt-12B", + '--port', '12323', + '--tensor-parallel-size', '4', + '-tp', '2' + ] + ``` + + Please note how the config args are inserted after the sub command. + this way the order of priorities is maintained when these are args + parsed by super(). + """ + assert args.count( + '--config') <= 1, "More than one config file specified!" + + index = args.index('--config') + if index == len(args) - 1: + raise ValueError("No config file specified! \ + Please check your command-line arguments.") + + file_path = args[index + 1] + + config_args = FlexibleArgumentParser._load_config_file(file_path) + + # 0th index is for {serve,chat,complete} + # followed by model_tag (only for serve) + # followed by config args + # followed by rest of cli args. + # maintaining this order will enforce the precedence + # of cli > config > defaults + if args[0] == "serve": + if index == 1: + raise ValueError( + "No model_tag specified! Please check your command-line" + " arguments.") + args = [args[0]] + [ + args[1] + ] + config_args + args[2:index] + args[index + 2:] + else: + args = [args[0]] + config_args + args[1:index] + args[index + 2:] + + return args + + @staticmethod + def _load_config_file(file_path: str) -> List[str]: + """Loads a yaml file and returns the key value pairs as a + flattened list with argparse like pattern + ```yaml + port: 12323 + tensor-parallel-size: 4 + ``` + returns: + processed_args: list[str] = [ + '--port': '12323', + '--tensor-parallel-size': '4' + ] + + """ + + extension: str = file_path.split('.')[-1] + if extension not in ('yaml', 'yml'): + raise ValueError( + "Config file must be of a yaml/yml type.\ + %s supplied", extension) + + # only expecting a flat dictionary of atomic types + processed_args: List[str] = [] + + config: Dict[str, Union[int, str]] = {} + try: + with open(file_path, 'r') as config_file: + config = yaml.safe_load(config_file) + except Exception as ex: + logger.error( + "Unable to read the config file at %s. \ + Make sure path is correct", file_path) + raise ex + + for key, value in config.items(): + processed_args.append('--' + key) + processed_args.append(str(value)) + + return processed_args + + +async def _run_task_with_lock(task: Callable, lock: asyncio.Lock, *args, + **kwargs): + """Utility function to run async task in a lock""" + async with lock: + return await task(*args, **kwargs) + + +def supports_kw( + callable: Callable[..., object], + kw_name: str, + requires_kw_only: bool = False, + allow_var_kwargs: bool = True, +) -> bool: + """Check if a keyword is a valid kwarg for a callable; if requires_kw_only + disallows kwargs names that can also be positional arguments. + """ + params = inspect.signature(callable).parameters + if not params: + return False + + param_val = params.get(kw_name) + + # Types where the it may be valid, i.e., explicitly defined & nonvariadic + passable_kw_types = set((inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY)) + + if param_val: + is_sig_param = param_val.kind in passable_kw_types + # We want kwargs only, but this is passable as a positional arg + if (requires_kw_only and is_sig_param + and param_val.kind != inspect.Parameter.KEYWORD_ONLY): + return False + if ((requires_kw_only + and param_val.kind == inspect.Parameter.KEYWORD_ONLY) + or (not requires_kw_only and is_sig_param)): + return True + + # If we're okay with var-kwargs, it's supported as long as + # the kw_name isn't something like *args, **kwargs + if allow_var_kwargs: + # Get the last param; type is ignored here because params is a proxy + # mapping, but it wraps an ordered dict, and they appear in order. + # Ref: https://docs.python.org/3/library/inspect.html#inspect.Signature.parameters + last_param = params[next(reversed(params))] # type: ignore + return (last_param.kind == inspect.Parameter.VAR_KEYWORD + and last_param.name != kw_name) + return False + + +def resolve_mm_processor_kwargs( + init_kwargs: Optional[Dict[str, Any]], + inference_kwargs: Optional[Dict[str, Any]], + callable: Callable[..., object], + allow_var_kwargs: bool = False, +) -> Dict[str, Any]: + """Applies filtering to eliminate invalid mm_processor_kwargs, i.e., + those who are not explicit keywords to the given callable (of one is + given; otherwise no filtering is done), then merges the kwarg dicts, + giving priority to inference_kwargs if there are any collisions. + + In the case that no kwarg overrides are provided, returns an empty + dict so that it can still be kwarg expanded into the callable later on. + + If allow_var_kwargs=True, allows for things that can be expanded into + kwargs as long as they aren't naming collision for var_kwargs or potential + positional arguments. + """ + # Filter inference time multimodal processor kwargs provided + runtime_mm_kwargs = get_allowed_kwarg_only_overrides( + callable, + overrides=inference_kwargs, + allow_var_kwargs=allow_var_kwargs) + + # Filter init time multimodal processor kwargs provided + init_mm_kwargs = get_allowed_kwarg_only_overrides( + callable, overrides=init_kwargs, allow_var_kwargs=allow_var_kwargs) + + # Merge the final processor kwargs, prioritizing inference + # time values over the initialization time values. + mm_processor_kwargs = {**init_mm_kwargs, **runtime_mm_kwargs} + return mm_processor_kwargs + + +def get_allowed_kwarg_only_overrides( + callable: Callable[..., object], + overrides: Optional[Dict[str, Any]], + allow_var_kwargs: bool = False, +) -> Dict[str, Any]: + """ + Given a callable which has one or more keyword only params and a dict + mapping param names to values, drop values that can be not be kwarg + expanded to overwrite one or more keyword-only args. This is used in a + few places to handle custom processor overrides for multimodal models, + e.g., for profiling when processor options provided by the user + may affect the number of mm tokens per instance. + + Args: + callable: Callable which takes 0 or more keyword only arguments. + If None is provided, all overrides names are allowed. + overrides: Potential overrides to be used when invoking the callable. + allow_var_kwargs: Allows overrides that are expandable for var kwargs. + + Returns: + Dictionary containing the kwargs to be leveraged which may be used + to overwrite one or more keyword only arguments when invoking the + callable. + """ + if not overrides: + return {} + + # Drop any mm_processor_kwargs provided by the user that + # are not kwargs, unless it can fit it var_kwargs param + filtered_overrides = { + kwarg_name: val + for kwarg_name, val in overrides.items() + if supports_kw(callable, + kwarg_name, + requires_kw_only=True, + allow_var_kwargs=allow_var_kwargs) + } + + # If anything is dropped, log a warning + dropped_keys = overrides.keys() - filtered_overrides.keys() + if dropped_keys: + logger.warning( + "The following intended overrides are not keyword-only args " + "and and will be dropped: %s", dropped_keys) + + return filtered_overrides + + +# Using dynamo with vLLM doesn't really work well with PyTorch versions < 2.4.0. +# In particular, the FakeScalarType is not supported for earlier versions of +# PyTorch which breaks dynamo for any ops registered using ScalarType. +def supports_dynamo() -> bool: + base_torch_version = Version(Version(torch.__version__).base_version) + return base_torch_version >= Version("2.4.0") + + +# Some backends use pytorch version < 2.4.0 which doesn't +# support `torch.library.custom_op`. +def supports_custom_op() -> bool: + return hasattr(torch.library, "custom_op") + + +class AtomicCounter: + """An atomic, thread-safe counter""" + + def __init__(self, initial=0): + """Initialize a new atomic counter to given initial value""" + self._value = initial + self._lock = threading.Lock() + + def inc(self, num=1): + """Atomically increment the counter by num and return the new value""" + with self._lock: + self._value += num + return self._value + + def dec(self, num=1): + """Atomically decrement the counter by num and return the new value""" + with self._lock: + self._value -= num + return self._value + + @property + def value(self): + return self._value diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe.py new file mode 100644 index 000000000..470153b14 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe.py @@ -0,0 +1,914 @@ +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import rocmKernels as moe_kernels +# import vllm.envs as envs +# from vllm import _custom_ops as moe_kernels +# from vllm.logger import init_logger +# from vllm.platforms import current_platform + +import logging +logger = logging.getLogger(__name__) +# logger = init_logger(__name__) +VLLM_MOE_PADDING = bool(int(os.getenv("VLLM_MOE_PADDING", "1"))) +FUSED_MOE_PERSISTENT = bool(int(os.getenv("FUSED_MOE_PERSISTENT", "0"))) +ENABLE_MOE_LDS_BYPASS = bool(int(os.getenv("ENABLE_MOE_LDS_BYPASS", "1"))) +print(f'{FUSED_MOE_PERSISTENT=}, {ENABLE_MOE_LDS_BYPASS=}, {VLLM_MOE_PADDING=}') +VLLM_FUSED_MOE_CHUNK_SIZE = 65536 +padding_size = 128 if VLLM_MOE_PADDING else 0 + + +@triton.jit +def fused_moe_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + token_nums_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + block_token_num = tl.load(token_nums_ptr + pid_m) + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + blk_m_range = tl.arange(0, BLOCK_SIZE_M) + token_mask = blk_m_range < block_token_num + offs_token_id = pid_m * BLOCK_SIZE_M + blk_m_range + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id, mask=token_mask) + # token_mask = offs_token < num_valid_tokens + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + if use_int8_w8a16: + b_scale_ptrs = b_scale_ptr + off_experts * stride_bse + offs_bn[ + None, :] * stride_bsn + b_scale = tl.load(b_scale_ptrs) + + if use_fp8_w8a8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0) + # We accumulate along the K dimension. + if use_int8_w8a16: + accumulator = tl.dot(a, b.to(compute_type), acc=accumulator) + elif use_fp8_w8a8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator = tl.dot(a, b, acc=accumulator) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + if use_int8_w8a16: + accumulator = (accumulator * b_scale).to(compute_type) + elif use_fp8_w8a8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def fused_moe_persistent_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, + NUM_SMS: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + This is the persistent version of the fused_moe kernel. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Simply compute how many iterations each persistent block needs to do + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # num_tiles = num_pid_m * num_pid_n + tile_id = start_pid + + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + # token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) + + # Load tile-invariant runtime constant + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Compute how many tiles are outside the padding region + num_pid_in_group = GROUP_SIZE_M * num_pid_n + pid_m = 0 + tile_id2 = start_pid - NUM_SMS + num_valid_tiles = -1 + while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + num_valid_tiles += 1 + tile_id2 += NUM_SMS + group_id = tile_id2 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) + + for _ in range(0, num_valid_tiles): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + # Compute the mask + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + # Compute the A pointer + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + # Compute the B pointer + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + if EVEN_K: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0 + ) + # We accumulate along the K dimension. + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + # advance tile_id + tile_id += NUM_SMS + + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + # sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + token_nums = torch.empty((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + moe_kernels.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, token_nums, num_tokens_post_pad) + return sorted_ids, expert_ids, token_nums, num_tokens_post_pad + + +def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: Optional[torch.Tensor], + B_scale: Optional[torch.Tensor], + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + token_nums: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + if use_fp8_w8a8: + A, A_scale = moe_kernels.scaled_fp8_quant(A, A_scale) + assert B_scale is not None + elif use_int8_w8a16: + assert B_scale is not None + else: + assert A_scale is None + assert B_scale is None + + if not FUSED_MOE_PERSISTENT: + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + + fused_moe_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + token_nums, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0) if B_scale is not None and use_int8_w8a16 else 0, + B_scale.stride(1) if B_scale is not None and use_int8_w8a16 else 0, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + enable_moe_lds_bypass=ENABLE_MOE_LDS_BYPASS + ) + else: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * 2 + grid = lambda META: (min( + NUM_SMS, + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * + triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) + ), ) + + fused_moe_persistent_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + NUM_SMS=NUM_SMS, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8_w8a8, + **config, + enable_moe_lds_bypass=ENABLE_MOE_LDS_BYPASS + ) + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + # device_name = current_platform.get_device_name().replace(" ", "_") + device_name = 'AMD_Instinct_MI308X_OAM' # TODO: need to update + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.info("---> MOE tuned file not found at %s",config_file_path) + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: + config = { + 'BLOCK_SIZE_M': 64, + 'BLOCK_SIZE_N': 128, # reqd. for MOE shuffle + 'BLOCK_SIZE_K': 128, # reqd. for MOE shuffle + 'GROUP_SIZE_M': 8 + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 128, # reqd. for MOE shuffle + 'BLOCK_SIZE_K': 128, # reqd. for MOE shuffle + 'GROUP_SIZE_M': 1 + } + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) + return config + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + M, _ = hidden_states.shape + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + renormalize + ) + del token_expert_indicies # Not used. Will be used in the future. + + # if renormalize: + # topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 model +def grouped_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str(dtype: torch.dtype, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def fused_experts(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size), + topk_ids.shape[1], + config_dtype, + override_config=override_config, + ) + + config = get_config_func(M) + + intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=hidden_states.dtype) + intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=hidden_states.dtype) + + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + # print("init config:", config) + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + # print("inside config:", config) + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, token_nums, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + + invoke_fused_moe_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + token_nums, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + + moe_kernels.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + invoke_fused_moe_kernel(intermediate_cache2, + w2, + intermediate_cache3, + a2_scale, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + token_nums, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + + moe_kernels.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), + out_hidden_states[begin_chunk_idx:end_chunk_idx]) + + return out_hidden_states + + +def fused_moe( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, + w1_scale: Optional[torch.Tensor] = None, + w2_scale: Optional[torch.Tensor] = None, + a1_scale: Optional[torch.Tensor] = None, + a2_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) + + return fused_experts(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + inplace=inplace, + override_config=override_config, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe_int8_a8w8.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe_int8_a8w8.py new file mode 100644 index 000000000..c044abdff --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/fused_moe_int8_a8w8.py @@ -0,0 +1,1014 @@ +"""Fused MoE kernel.""" +import functools +import json +import os +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import triton +import triton.language as tl + +import rocmKernels as moe_kernels + +import logging +logger = logging.getLogger(__name__) +# logger = init_logger(__name__) +VLLM_MOE_PADDING = bool(int(os.getenv("VLLM_MOE_PADDING", "1"))) +FUSED_MOE_PERSISTENT = bool(int(os.getenv("FUSED_MOE_PERSISTENT", "0"))) +ENABLE_MOE_LDS_BYPASS = bool(int(os.getenv("ENABLE_MOE_LDS_BYPASS", "1"))) +print(f'{FUSED_MOE_PERSISTENT=}, {ENABLE_MOE_LDS_BYPASS=}, {VLLM_MOE_PADDING=}') +VLLM_FUSED_MOE_CHUNK_SIZE = 65536 +#padding_size = 128 if VLLM_MOE_PADDING else 0 +padding_size = 0 + + +@triton.jit +def fused_moe_int8_a8w8_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + stride_bse, + stride_bsn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8_w8a8: tl.constexpr, + use_int8_w8a16: tl.constexpr): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Map program ids `pid` to the block of C it should compute. + # This is done in a grouped ordering to promote L2 data reuse. + pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + num_pid_in_group = GROUP_SIZE_M * num_pid_n + group_id = pid // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) + pid_n = (pid % num_pid_in_group) // group_size_m + + # ---------------------------------------------------------- + # Create pointers for the first blocks of A and B. + # We will advance this pointer as we move in the K direction + # and accumulate + # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers + # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded: + return + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + off_experts = tl.load(expert_ids_ptr + pid_m) + a_scale_ptr = a_scale_ptr + offs_token // top_k + b_scale_ptr = b_scale_ptr + off_experts * stride_bse + offs_cn + _ALPHA0 = tl.zeros([1], dtype=a_scale_ptr.dtype.element_ty) + + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + + a_scale = tl.load(a_scale_ptr, mask=token_mask, other=_ALPHA0).to(tl.float32) + b_scale = tl.load(b_scale_ptr, mask=offs_cn < N, other=_ALPHA0).to(tl.float32) + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + + offs_bn[None, :] * stride_bn) + + # ----------------------------------------------------------- + # Iterate to compute a block of the C matrix. + # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block + # of fp32 values for higher accuracy. + # `accumulator` will be converted back to fp16 after the loop. + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32) + _A0 = tl.zeros([1, 1], dtype=a_ptrs.dtype.element_ty) + _B0 = tl.zeros([1, 1], dtype=b_ptrs.dtype.element_ty) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + a = tl.load(a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=_A0) + b = tl.load(b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=_B0) + # We accumulate along the K dimension. + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + accumulator = accumulator * a_scale[:,None] * b_scale[None,:] + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + accumulator = accumulator.to(tl.float16) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[ + None, :] + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + +@triton.heuristics({ + 'EVEN_K': lambda args: args['K'] % args['BLOCK_SIZE_K'] == 0, +}) +@triton.jit +def fused_moe_persistent_kernel( + # Pointers to matrices + a_ptr, + b_ptr, + c_ptr, + a_scale_ptr, + b_scale_ptr, + topk_weights_ptr, + sorted_token_ids_ptr, + expert_ids_ptr, + num_tokens_post_padded_ptr, + # Matrix dimensions + N, + K, + EM, + num_valid_tokens, + # The stride variables represent how much to increase the ptr by when + # moving by 1 element in a particular dimension. E.g. `stride_am` is + # how much to increase `a_ptr` by to get the element one row down + # (A has M rows). + stride_am, + stride_ak, + stride_be, + stride_bk, + stride_bn, + stride_cm, + stride_cn, + # Meta-parameters + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr, + GROUP_SIZE_M: tl.constexpr, + EVEN_K: tl.constexpr, + NUM_SMS: tl.constexpr, + MUL_ROUTED_WEIGHT: tl.constexpr, + top_k: tl.constexpr, + compute_type: tl.constexpr, + use_fp8: tl.constexpr, +): + """ + Implements the fused computation for a Mixture of Experts (MOE) using + token and expert matrices. + This is the persistent version of the fused_moe kernel. + + Key Parameters: + - A: The input tensor representing tokens with shape (*, K), where '*' can + be any shape representing batches and K is the feature dimension of + each token. + - B: The stacked MOE weight tensor with shape (E, N, K), where E is + the number of experts, K is the input feature dimension, and N is + the output feature dimension. + - C: The output cache tensor with shape (M, topk, N), where M is the + total number of tokens post padding, topk is the number of times + each token is repeated, and N is the output feature dimension. + - sorted_token_ids: A tensor containing the sorted indices of tokens, + repeated topk times and arranged by the expert index they are + assigned to. + - expert_ids: A tensor containing the indices of the expert for each + block. It determines which expert matrix from B should be used for + each block in A. + This kernel performs the multiplication of a token by its corresponding + expert matrix as determined by `expert_ids`. The sorting of + `sorted_token_ids` by expert index and padding ensures divisibility by + BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix + multiplication across different blocks processed by the same expert. + """ + # ----------------------------------------------------------- + # Simply compute how many iterations each persistent block needs to do + start_pid = tl.program_id(axis=0) + num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) + num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) + # num_tiles = num_pid_m * num_pid_n + tile_id = start_pid + + offs_k = tl.arange(0, BLOCK_SIZE_K) + # offs_token = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int32) + # token_mask = tl.zeros((BLOCK_SIZE_M,), dtype=tl.int1) + + # Load tile-invariant runtime constant + num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr) + + # Compute how many tiles are outside the padding region + num_pid_in_group = GROUP_SIZE_M * num_pid_n + pid_m = 0 + tile_id2 = start_pid - NUM_SMS + num_valid_tiles = -1 + while pid_m * BLOCK_SIZE_M < num_tokens_post_padded: + num_valid_tiles += 1 + tile_id2 += NUM_SMS + group_id = tile_id2 // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id2 % num_pid_in_group) % group_size_m) + + for _ in range(0, num_valid_tiles): + if GROUP_SIZE_M == 1: + pid_m = tile_id // num_pid_n + pid_n = tile_id % num_pid_n + else: + group_id = tile_id // num_pid_in_group + first_pid_m = group_id * GROUP_SIZE_M + group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + # Compute the mask + offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) + token_mask = offs_token < num_valid_tokens + # Compute the A pointer + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + + offs_k[None, :] * stride_ak) + # Compute the B pointer + offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + off_experts = tl.load(expert_ids_ptr + pid_m) + b_ptrs = (b_ptr + off_experts * stride_be + + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)) + + if use_fp8: + a_scale = tl.load(a_scale_ptr) + b_scale = tl.load(b_scale_ptr + off_experts) + + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)): + # Load the next block of A and B, generate a mask by checking the + # K dimension. + if EVEN_K: + a = tl.load(a_ptrs, mask=token_mask[:, None], other=0.0) + b = tl.load(b_ptrs) + else: + a = tl.load( + a_ptrs, + mask=token_mask[:, None] & + (offs_k[None, :] < K - k * BLOCK_SIZE_K), + other=0.0 + ) + b = tl.load( + b_ptrs, + mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, + other=0.0 + ) + # We accumulate along the K dimension. + if use_fp8: + accumulator = tl.dot(a, b, acc=accumulator) + else: + accumulator += tl.dot(a, b) + # Advance the ptrs to the next K block. + a_ptrs += BLOCK_SIZE_K * stride_ak + b_ptrs += BLOCK_SIZE_K * stride_bk + + if MUL_ROUTED_WEIGHT: + moe_weight = tl.load(topk_weights_ptr + offs_token, + mask=token_mask, + other=0) + accumulator = accumulator * moe_weight[:, None] + + if use_fp8: + accumulator = (accumulator * a_scale * b_scale).to(compute_type) + else: + accumulator = accumulator.to(compute_type) + # ----------------------------------------------------------- + # Write back the block of the output + offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = (c_ptr + stride_cm * offs_token[:, None] + + stride_cn * offs_cn[None, :]) + c_mask = token_mask[:, None] & (offs_cn[None, :] < N) + tl.store(c_ptrs, accumulator, mask=c_mask) + + # advance tile_id + tile_id += NUM_SMS + +@triton.jit +def _abs_max(val1, val2): + val1_abs = tl.abs(val1) + val2_abs = tl.abs(val2) + if val1_abs >= val2_abs: + return val1_abs + else: + return val2_abs + + +_dynamic_quant_configs = [ + triton.Config( + {}, + num_warps=warps, + ) + for warps in [2, 4, 8, 16] +] + + +@triton.jit +def _triton_dynamic_quantize_kernel( + output_ptr, + input_ptr, + scale_ptr, + stride_outputm, + stride_outputn, + stride_inputm, + stride_inputn, + n_elements, + N: tl.constexpr, +): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, N) + mask = offsets < n_elements + input_ptrs = input_ptr + pid * stride_inputm + offsets + input_vals = tl.load(input_ptrs, mask=mask, other=1e-6) + abs_max_f = tl.reduce(input_vals, 0, _abs_max) + dynamic_per_token_scale = 127.0 / abs_max_f + precison_mask = tl.where(input_vals > 0, 0.5, -0.5) + output_vals = (input_vals * dynamic_per_token_scale + precison_mask).to(tl.int8) + output_ptrs = output_ptr + pid * stride_outputm + offsets + tl.store(output_ptrs, output_vals, mask=mask) + tl.store(scale_ptr + pid, abs_max_f / 127.0) + +def triton_dynamic_quantize(out, input, scale): + assert input.is_contiguous(), "input must be contiguous" + num_tokens = input.size(0) + hidden_size = input.size(1) + # tl.reduce requires the number of elements + # must be power-of-two + hidden_size_padded = triton.next_power_of_2(int(hidden_size)) + kwargs = [ + out, + input, + scale, + out.stride(0), + out.stride(1), + input.stride(0), + input.stride(1), + input.size(1), + ] + grid = (num_tokens, 1, 1) + const_kwargs = {"N": hidden_size_padded} + method_name = "dynamic_quant_" + str(hidden_size_padded) + if 0: + dynamic_quant = triton.autotune(configs=_dynamic_quant_configs, key=['N'])(_triton_dynamic_quantize_kernel) + else: + dynamic_quant = _triton_dynamic_quantize_kernel + dynamic_quant[grid](*kwargs,**const_kwargs) + +def moe_align_block_size( + topk_ids: torch.Tensor, block_size: int, + num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Aligns the token distribution across experts to be compatible with block + size for matrix multiplication. + + Parameters: + - topk_ids: A tensor of shape [total_tokens, top_k] representing the + top-k expert indices for each token. + - block_size: The block size used in block matrix multiplication. + - num_experts: The total number of experts. + + Returns: + - sorted_token_ids: A tensor containing the sorted token indices according + to their allocated expert. + - expert_ids: A tensor indicating the assigned expert index for each block. + - num_tokens_post_padded: The total number of tokens after padding, + ensuring divisibility by block_size. + + This function pads the number of tokens that each expert needs to process + so that it is divisible by block_size. + Padding ensures that during block matrix multiplication, the dimensions + align correctly. + + Example: + Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], + block_size = 4, and num_experts = 4: + - We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, + with each expert needing to process 3 tokens. + - As block_size is 4, we pad 1 token for each expert. + - First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3]. + - Then append padding tokens [12, 12, 12, 12] for each block. + - After sorting by expert index, we obtain token_ids + [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12]. + Tokens 12 are non-existent (padding) and are ignored in + the subsequent matrix multiplication. + - The padding ensures that the total number of tokens is now divisible + by block_size for proper block matrix operations. + """ + #print("BLOCK_SIZE_M in aligh", block_size) + max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1) + sorted_ids = torch.empty((max_num_tokens_padded, ), + dtype=torch.int32, + device=topk_ids.device) + sorted_ids.fill_(topk_ids.numel()) + max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) + expert_ids = torch.zeros((max_num_m_blocks, ), + dtype=torch.int32, + device=topk_ids.device) + num_tokens_post_pad = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + token_num = torch.empty((1), + dtype=torch.int32, + device=topk_ids.device) + moe_kernels.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids, + expert_ids, token_num, num_tokens_post_pad) + return sorted_ids, expert_ids, num_tokens_post_pad + +# int8 +def scaled_int8_quant( + input: torch.Tensor, + scale: Optional[torch.Tensor] = None, + azp: Optional[torch.Tensor] = None, + symmetric: bool = True +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Quantize the input tensor to int8 and return the quantized tensor and scale, and maybe azp. + + Args: + input: The input tensor to be quantized to int8. + scale: Optional scaling factor for the int8 quantization. + When not provided, we invoke dynamic-per-token quantization. + azp: Optional zero-point for the int8 quantization. + Must be provided for asymmetric quantization if `scale` is provided. + symmetric: Whether to use symmetric quantization (scale only, azp ignored). + + Returns: + Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] : Output int8 tensor, scales, and optionally azp. + """ + output = torch.empty_like(input, dtype=torch.int8) +# if scale is not None: +# # static-per-tensor quantization. +# assert symmetric == ( +# azp is +# None), "azp must only be provided for asymmetric quantization." +# torch.ops._C.static_scaled_int8_quant(output, input, scale, azp) +# return output, scale, None + + # dynamic-per-token quantization. + input_scales = torch.empty((input.numel() // input.shape[-1], 1), + device=input.device, + dtype=torch.float32) + input_azp = None if symmetric else torch.empty_like(input_scales, + dtype=torch.int32) + moe_kernels.dynamic_scaled_int8_quant(output, input, input_scales, + input_azp) + return output, input_scales + +def invoke_fused_moe_int8_a8w8_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + topk_weights: torch.Tensor, topk_ids: torch.Tensor, + sorted_token_ids: torch.Tensor, + expert_ids: torch.Tensor, + num_tokens_post_padded: torch.Tensor, + mul_routed_weight: bool, top_k: int, + config: Dict[str, Any], compute_type: tl.dtype, + use_fp8_w8a8: bool, use_int8_w8a16: bool) -> None: + assert topk_weights.stride(1) == 1 + assert sorted_token_ids.stride(0) == 1 + + assert B_scale is not None + B_scale = B_scale.view(-1, B.shape[1]) + # print("llm Bscale shape:",A_scale) + + if not FUSED_MOE_PERSISTENT: + grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[ + "BLOCK_SIZE_M"]) * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]), ) + + fused_moe_int8_a8w8_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + #B.shape[2] - padding_size, + B.shape[2] , + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + B_scale.stride(0), + B_scale.stride(1), + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + **config, + # enable_moe_lds_bypass=ENABLE_MOE_LDS_BYPASS + ) + else: + NUM_SMS = torch.cuda.get_device_properties("cuda").multi_processor_count * 2 + grid = lambda META: (min( + NUM_SMS, + triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"]) * + triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]) + ), ) + + fused_moe_persistent_kernel[grid]( + A, + B, + C, + A_scale, + B_scale, + topk_weights, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + B.shape[1], + B.shape[2] - padding_size, + sorted_token_ids.shape[0], + topk_ids.numel(), + A.stride(0), + A.stride(1), + B.stride(0), + B.stride(2), + B.stride(1), + C.stride(1), + C.stride(2), + NUM_SMS=NUM_SMS, + MUL_ROUTED_WEIGHT=mul_routed_weight, + top_k=top_k, + compute_type=compute_type, + use_fp8=use_fp8_w8a8, + **config, + # enable_moe_lds_bypass=ENABLE_MOE_LDS_BYPASS + ) + + +def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str: + # device_name = current_platform.get_device_name().replace(" ", "_") + device_name = 'AMD_Instinct_MI308X_OAM' # TODO: need to update + dtype_selector = "" if not dtype else f",dtype={dtype}" + return f"E={E},N={N},device_name={device_name}{dtype_selector}.json" + + +@functools.lru_cache +def get_moe_configs(E: int, N: int, + dtype: Optional[str]) -> Optional[Dict[int, Any]]: + """ + Return optimized configurations for the fused MoE kernel. + + The return value will be a dictionary that maps an irregular grid of + batch sizes to configurations of the fused_moe kernel. To evaluate the + kernel on a given batch size bs, the closest batch size in the grid should + be picked and the associated configuration chosen to invoke the kernel. + """ + + # First look up if an optimized configuration is available in the configs + # directory + json_file_name = get_config_file_name(E, N, dtype) + + config_file_path = os.path.join( + os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) + if os.path.exists(config_file_path): + with open(config_file_path) as f: + logger.info("Using configuration from %s for MoE layer.", + config_file_path) + # If a configuration has been found, return it + return {int(key): val for key, val in json.load(f).items()} + + # If no optimized configuration is available, we will use the default + # configuration + logger.info("---> MOE tuned file not found at %s",config_file_path) + return None + + +def get_default_config( + M: int, + E: int, + N: int, + K: int, + topk: int, + dtype: Optional[str], + is_marlin: bool, +) -> Dict[str, int]: + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 128, # reqd. for MOE shuffle + 'BLOCK_SIZE_K': 128, # reqd. for MOE shuffle + 'GROUP_SIZE_M': 8 + } + # A heuristic: fused marlin works faster with this config for small M + if M <= E or (is_marlin and M <= 32): + config = { + 'BLOCK_SIZE_M': 16, + 'BLOCK_SIZE_N': 128, # reqd. for MOE shuffle + 'BLOCK_SIZE_K': 128, # reqd. for MOE shuffle + 'GROUP_SIZE_M': 1 + } + return config + + +def try_get_optimal_moe_config( + w1_shape: Tuple[int, ...], + w2_shape: Tuple[int, ...], + top_k: int, + dtype: Optional[str], + M: int, + override_config: Optional[Dict[str, Any]] = None, + is_marlin: bool = False, +): + if override_config: + config = override_config + else: + # First try to load optimal config from the file + E, _, N = w2_shape + configs = get_moe_configs(E, N, dtype) + + if configs: + # If an optimal configuration map has been found, look up the + # optimal config + config = configs[min(configs.keys(), key=lambda x: abs(x - M))] + else: + # Else use the default config + config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, + is_marlin) + return config + + +def fused_topk( + hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, +): + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + M, _ = hidden_states.shape + + topk_weights = torch.empty(M, + topk, + dtype=torch.float32, + device=hidden_states.device) + topk_ids = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + token_expert_indicies = torch.empty(M, + topk, + dtype=torch.int32, + device=hidden_states.device) + + moe_kernels.topk_softmax( + topk_weights, + topk_ids, + token_expert_indicies, + gating_output.float(), # TODO(woosuk): Optimize this. + renormalize + ) + del token_expert_indicies # Not used. Will be used in the future. + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights, topk_ids + + +# This is used by the Deepseek-V2 model +def grouped_topk(hidden_states: torch.Tensor, + gating_output: torch.Tensor, + topk: int, + renormalize: bool, + num_expert_group: int = 0, + topk_group: int = 0): + + assert hidden_states.shape[0] == gating_output.shape[0], ( + "Number of tokens mismatch") + + scores = torch.softmax(gating_output, dim=-1) + num_token = scores.shape[0] + group_scores = scores.view(num_token, num_expert_group, + -1).max(dim=-1).values # [n, n_group] + group_idx = torch.topk(group_scores, k=topk_group, dim=-1, + sorted=False)[1] # [n, top_k_group] + group_mask = torch.zeros_like(group_scores) # [n, n_group] + group_mask.scatter_(1, group_idx, 1) # [n, n_group] + score_mask = group_mask.unsqueeze(-1).expand( + num_token, num_expert_group, + scores.shape[-1] // num_expert_group).reshape(num_token, -1) # [n, e] + tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e] + topk_weights, topk_ids = torch.topk(tmp_scores, + k=topk, + dim=-1, + sorted=False) + + if renormalize: + topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True) + + return topk_weights.to(torch.float32), topk_ids.to(torch.int32) + + +def get_config_dtype_str(dtype: torch.dtype, + use_int8_w8a16: Optional[bool] = False, + use_fp8_w8a8: Optional[bool] = False): + if use_fp8_w8a8: + return "fp8_w8a8" + elif use_int8_w8a16: + return "int8_w8a16" + elif dtype == torch.float: + # avoiding cases where kernel fails when float32 MoE + # use fp16/bfloat16 configs + return "float32" + return None + + +def fused_experts_int8_a8w8(hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False): + # Check constraints. + assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch" + assert topk_weights.shape == topk_ids.shape, "topk shape mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + num_tokens, _ = hidden_states.shape + E, N, _ = w1.shape + # We execute the fused_moe kernel in chunks to circumvent this issue: + # https://github.com/vllm-project/vllm/issues/5938 + CHUNK_SIZE = VLLM_FUSED_MOE_CHUNK_SIZE + M = min(num_tokens, CHUNK_SIZE) + config_dtype = get_config_dtype_str(use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16, + dtype=hidden_states.dtype) + + get_config_func = functools.partial( + try_get_optimal_moe_config, + w1.shape, + (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size), + topk_ids.shape[1], + config_dtype, + override_config=override_config, + ) + + config = get_config_func(M) + + intermediate_cache1 = torch.zeros((M, topk_ids.shape[1], N), + device=hidden_states.device, + dtype=torch.half) + intermediate_cache2 = torch.zeros((M * topk_ids.shape[1], N // 2), + device=hidden_states.device, + dtype=torch.half) + intermediate_cache3 = torch.zeros((M, topk_ids.shape[1], w2.shape[1]), + device=hidden_states.device, + dtype=torch.half) + + compute_type = (tl.bfloat16 + if hidden_states.dtype == torch.bfloat16 else tl.float16) + + if inplace: + out_hidden_states = hidden_states + else: + out_hidden_states = torch.empty_like(hidden_states) + + # print("init config:", config) + for chunk in range((num_tokens // CHUNK_SIZE) + 1): + begin_chunk_idx, end_chunk_idx = (chunk * CHUNK_SIZE, + min((chunk + 1) * CHUNK_SIZE, + num_tokens)) + curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx] + tokens_in_chunk, _ = curr_hidden_states.shape + + if tokens_in_chunk == 0: + break + print("tokens_in_chunk,", tokens_in_chunk, "CHUNK_SIZE",CHUNK_SIZE, "num_tokens",num_tokens) + if tokens_in_chunk < CHUNK_SIZE and chunk > 0: + # Adjust the intermediate cache size and config for the last + # chunk. Note that in most cases we only have one chunk + # so the cache size and config are already set correctly and + # do not need to be adjusted. + intermediate_cache1 = intermediate_cache1[:tokens_in_chunk] + intermediate_cache2 = intermediate_cache2[:tokens_in_chunk] + intermediate_cache3 = intermediate_cache3[:tokens_in_chunk] + config = get_config_func(tokens_in_chunk) + print("inside branch:") + + + curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx] + curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx] + + sorted_token_ids, expert_ids, num_tokens_post_padded = ( + moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], E)) + # print("llm : hidden_state:", curr_hidden_states) + # print("llm: w1:", w1) + # print("llm w1+scale", w1_scale) + invoke_fused_moe_int8_a8w8_kernel(curr_hidden_states, + w1, + intermediate_cache1, + a1_scale, + w1_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + False, + topk_ids.shape[1], + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + # print("llm : intermediate_cache1", intermediate_cache1) + moe_kernels.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N)) + + intermediate_cache2_quant = torch.empty( + intermediate_cache2.shape, dtype=torch.int8, device=hidden_states.device + ) + intermediate_cache2_scales = torch.empty( + intermediate_cache2.shape[0], dtype=torch.half, device=hidden_states.device + ) + triton_dynamic_quantize( + intermediate_cache2_quant, intermediate_cache2, intermediate_cache2_scales + ) +# print("llm: cache2:", intermediate_cache2_quant) +# print("llm: cache2_sclae:", intermediate_cache2_scales) + + invoke_fused_moe_int8_a8w8_kernel(intermediate_cache2_quant, + w2, + intermediate_cache3, + intermediate_cache2_scales, + w2_scale, + curr_topk_weights, + curr_topk_ids, + sorted_token_ids, + expert_ids, + num_tokens_post_padded, + True, + 1, + config, + compute_type=compute_type, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) + #print("dtype:",intermediate_cache3.dtype) +# moe_kernels.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), +# out_hidden_states[begin_chunk_idx:end_chunk_idx]) +# + #return out_hidden_states + # return intermediate_cache3 + # print("before sum : intermediate_cache3",intermediate_cache3) + return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1) + + +def fused_moe_int8_a8w8( + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + gating_output: torch.Tensor, + w1_scale: torch.Tensor, + w2_scale: torch.Tensor, + a1_scale: torch.Tensor, + a2_scale: torch.Tensor, + topk: int, + renormalize: bool, + inplace: bool = False, + override_config: Optional[Dict[str, Any]] = None, + use_grouped_topk: bool = False, + num_expert_group: Optional[int] = None, + topk_group: Optional[int] = None, + custom_routing_function: Optional[Callable] = None, + use_fp8_w8a8: bool = False, + use_int8_w8a16: bool = False, +) -> torch.Tensor: + """ + This function computes a Mixture of Experts (MoE) layer using two sets of + weights, w1 and w2, and top-k gating mechanism. + + Parameters: + - hidden_states (torch.Tensor): The input tensor to the MoE layer. + - w1 (torch.Tensor): The first set of expert weights. + - w2 (torch.Tensor): The second set of expert weights. + - gating_output (torch.Tensor): The output of the gating operation + (before softmax). + - topk (int): The number of top-k experts to select. + - renormalize (bool): If True, renormalize the top-k weights to sum to 1. + - inplace (bool): If True, perform the operation in-place. + Defaults to False. + - override_config (Optional[Dict[str, Any]]): Optional override + for the kernel configuration. + - num_expert_group: Optional[int]: additional parameter for grouped_topk + - topk_group: Optional[int]: additional parameter for grouped_topk + - use_grouped_topk: If True, use grouped_topk instead of fused_topk + note: Deepseekv2 model uses grouped_topk + - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner + products for w1 and w2. Defaults to False. + - w1_scale (Optional[torch.Tensor]): Optional scale to be used for + w1. + - w2_scale (Optional[torch.Tensor]): Optional scale to be used for + w2. + + Returns: + - torch.Tensor: The output tensor after applying the MoE layer. + """ + # Check constraints. + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + + if use_grouped_topk: + assert num_expert_group is not None and topk_group is not None + topk_weights, topk_ids = grouped_topk(hidden_states, gating_output, + topk, renormalize, + num_expert_group, topk_group) + elif custom_routing_function is None: +# print("fused topk") + topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, + renormalize = False) + else: + topk_weights, topk_ids = custom_routing_function( + hidden_states, gating_output, topk, renormalize) +# print("moe topk_weights",topk_weights,"topk_ids",topk_ids) + return fused_experts_int8_a8w8(hidden_states, + w1, + w2, + topk_weights, + topk_ids, + w1_scale=w1_scale, + w2_scale=w2_scale, + a1_scale=a1_scale, + a2_scale=a2_scale, + inplace=inplace, + override_config=override_config, + use_fp8_w8a8=use_fp8_w8a8, + use_int8_w8a16=use_int8_w8a16) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_quant_kernels.cu b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_quant_kernels.cu new file mode 100644 index 000000000..2266dee46 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_quant_kernels.cu @@ -0,0 +1,272 @@ +#include +#include +#include + +#include "dispatch_utils.h" + +#ifndef USE_ROCM + #include + #include +#else + #include + #include +#endif + +static inline __device__ int8_t float_to_int8_rn(float x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate + dst = std::clamp(dst, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int32_t float_to_int32_rn(float x) { +#ifdef USE_ROCM + // int32_max is not exactly representable as float. + // Therefore, we need to be careful and manually return int32_max on overflow. + // For symmetry, we also do the same for int32_min, even though it is exactly + // representable as float and the conversion should be exact. + static constexpr auto i32_min = std::numeric_limits::min(); + static constexpr auto i32_min_f = static_cast(i32_min); + static constexpr auto i32_max = std::numeric_limits::max(); + static constexpr auto i32_max_f = static_cast(i32_max); + + // To match the rounding mode of CUDA, we use nearbyint. + // It uses the current rounding mode, which is always FE_TONEAREST on HIP. + // If that changes in the future, we may need to set the rounding mode + // explicitly, either at runtime or compile time. + float dst = std::nearbyint(x); + + // saturate on the higher end. + if (dst >= i32_max_f) { + return i32_max; + } + // saturate on the lower end. + if (dst <= i32_min_f) { + return i32_min; + } + + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.rni.sat.s32.f32 %0, %1;" : "=r"(dst) : "f"(x)); + return reinterpret_cast(dst); +#endif +} + +static inline __device__ int8_t int32_to_int8(int32_t x) { +#ifdef USE_ROCM + static constexpr auto i8_min = + static_cast(std::numeric_limits::min()); + static constexpr auto i8_max = + static_cast(std::numeric_limits::max()); + + // saturate + int32_t dst = std::clamp(x, i8_min, i8_max); + return static_cast(dst); +#else + // CUDA path + uint32_t dst; + asm volatile("cvt.sat.s8.s32 %0, %1;" : "=r"(dst) : "r"(x)); + return reinterpret_cast(dst); +#endif +} + +namespace vllm { + +template +__global__ void static_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) / scale); + } +} + +template +__global__ void static_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type const* scale_ptr, azp_type const* azp_ptr, + const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + scale_type const scale = *scale_ptr; + azp_type const azp = *azp_ptr; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = int32_to_int8(float_to_int32_rn(val / scale) + azp); + out[token_idx * hidden_size + i] = quant_val; + } +} + +template +__global__ void dynamic_scaled_int8_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, const int hidden_size) { + int const tid = threadIdx.x; + int const token_idx = blockIdx.x; + float absmax_val = 0.0f; + float const zero = 0.0f; + + for (int i = tid; i < hidden_size; i += blockDim.x) { + float val = static_cast(input[token_idx * hidden_size + i]); + val = val > zero ? val : -val; + absmax_val = val > absmax_val ? val : absmax_val; + } + + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + float const block_absmax_val_maybe = + BlockReduce(reduceStorage).Reduce(absmax_val, cub::Max{}, blockDim.x); + __shared__ float block_absmax_val; + if (tid == 0) { + block_absmax_val = block_absmax_val_maybe; + scale[token_idx] = block_absmax_val / 127.0f; + } + __syncthreads(); + + float const tmp_scale = 127.0f / block_absmax_val; + for (int i = tid; i < hidden_size; i += blockDim.x) { + out[token_idx * hidden_size + i] = float_to_int8_rn( + static_cast(input[token_idx * hidden_size + i]) * tmp_scale); + } +} + +template +__global__ void dynamic_scaled_int8_azp_quant_kernel( + scalar_t const* __restrict__ input, int8_t* __restrict__ out, + scale_type* scale, azp_type* azp, const int hidden_size) { + int const token_idx = blockIdx.x; + + // Scan for the min and max value for this token + float max_val = std::numeric_limits::min(); + float min_val = std::numeric_limits::max(); + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto val = static_cast(input[token_idx * hidden_size + i]); + max_val = std::max(max_val, val); + min_val = std::min(min_val, val); + } + + // Reduce the max and min values across the block + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage reduceStorage; + max_val = BlockReduce(reduceStorage).Reduce(max_val, cub::Max{}, blockDim.x); + __syncthreads(); // Make sure min doesn't mess with max shared memory + min_val = BlockReduce(reduceStorage).Reduce(min_val, cub::Min{}, blockDim.x); + + __shared__ scale_type scale_sh; + __shared__ azp_type azp_sh; + + // Compute the scale and zero point and store them, only on the first thread + if (threadIdx.x == 0) { + float const scale_val = (max_val - min_val) / 255.0f; + // Use rounding to even (same as torch.round) + auto const azp_float = std::nearbyint(-128.0f - min_val / scale_val); + auto const azp_val = static_cast(azp_float); + + // Store the scale and azp into shared and global + scale[token_idx] = scale_sh = scale_val; + azp[token_idx] = azp_sh = azp_val; + } + + // Wait for the scale and azp to be computed + __syncthreads(); + + float const scale_val = scale_sh; + azp_type const azp_val = azp_sh; + + // Quantize the values + for (int i = threadIdx.x; i < hidden_size; i += blockDim.x) { + auto const val = static_cast(input[token_idx * hidden_size + i]); + auto const quant_val = + int32_to_int8(float_to_int32_rn(val / scale_val) + azp_val); + out[token_idx * hidden_size + i] = quant_val; + } +} + +} // namespace vllm + +void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor const& scale, + c10::optional const& azp) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scale.numel() == 1); + TORCH_CHECK(!azp || azp->numel() == 1); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "static_scaled_int8_quant_kernel", [&] { + if (!azp) { + vllm::static_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), hidden_size); + } else { + vllm::static_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scale.data_ptr(), azp->data_ptr(), + hidden_size); + } + }); +} + +void dynamic_scaled_int8_quant( + torch::Tensor& out, // [..., hidden_size] + torch::Tensor const& input, // [..., hidden_size] + torch::Tensor& scales, c10::optional const& azp) { + TORCH_CHECK(input.is_contiguous()); + TORCH_CHECK(out.is_contiguous()); + TORCH_CHECK(scales.is_contiguous()); + TORCH_CHECK(!azp || azp->is_contiguous()); + + int const hidden_size = input.size(-1); + int const num_tokens = input.numel() / hidden_size; + dim3 const grid(num_tokens); + dim3 const block(std::min(hidden_size, 1024)); + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + VLLM_DISPATCH_FLOATING_TYPES( + input.scalar_type(), "dynamic_scaled_int8_quant_kernel", [&] { + if (!azp) { + vllm::dynamic_scaled_int8_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), hidden_size); + } else { + vllm::dynamic_scaled_int8_azp_quant_kernel + <<>>( + input.data_ptr(), out.data_ptr(), + scales.data_ptr(), azp->data_ptr(), + hidden_size); + } + }); +} diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_scale_quant.h b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_scale_quant.h new file mode 100644 index 000000000..6c9367d5d --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/int8_scale_quant.h @@ -0,0 +1,5 @@ +#pragma once + +#include +void dynamic_scaled_int8_quant( + torch::Tensor& out, torch::Tensor const& input, torch::Tensor& scale,c10::optional const& azp); diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/op_tests/test_layernorm2d.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/op_tests/test_layernorm2d.py new file mode 100644 index 000000000..73f1af02c --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/op_tests/test_layernorm2d.py @@ -0,0 +1,66 @@ +import torch +import torch.nn.functional as F +import rocmKernels + +num_iters = 100 + + +def run_torch(input, normalized_shape, weight, bias, eps): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + latencies = [] + for i in range(num_iters): + start_event.record() + output = F.layer_norm( + input=input, + normalized_shape=(input.shape[-1],), + weight=weight, + bias=bias, + eps=eps + ) + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + # print(f"run_torch avg time: {avg} us") + return output, avg + + +def run_ck(input, normalized_shape, weight, bias, eps): + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + latencies = [] + for i in range(num_iters): + start_event.record() + output = torch.empty_like(input) + rocmKernels.layernorm2d_fwd( + output, + input, + weight, + bias, + eps + ) + end_event.record() + end_event.synchronize() + latencies.append(start_event.elapsed_time(end_event)) + avg = sum(latencies) / (num_iters * 10) * 1000 # us + # print(f"run_ck avg time: {avg} us") + return output, avg + + +def checkAllclose(a, b, rtol, atol): + assert torch.allclose( + a, b, rtol, atol), f"torch and ck results are not close\n{a.shape}\n{a}\n{b.shape}\n{b}\nmax delta:{(a-b).max()}" + + +for dtype in [torch.float16, torch.bfloat16]: + # for dtype in [torch.float16]: + for dim in [4096, 8192, 16384, 32768, 65536]: + input = torch.randn(dim, dtype=dtype, device="cuda") + weight = torch.randn(dim, dtype=dtype, device="cuda") + bias = torch.randn(dim, dtype=dtype, device="cuda") + a, avg_a = run_torch(input, (dim,), weight, bias, 1e-5) + b, avg_b = run_ck(input, (dim,), weight, bias, 1e-5) + print( + f"dim: {dim}, dtype: {dtype}, torch avg: {avg_a:.2f} us, ck avg: {avg_b:.2f} us, uplift: {avg_a/avg_b-1:.1%}") + # checkAllclose(a, b, 1e-3, 1e-3) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/paged_attn.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/paged_attn.py new file mode 100644 index 000000000..8c6f23d5a --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/paged_attn.py @@ -0,0 +1,394 @@ +from typing import List, Optional, Tuple, Union +import torch +import rocmKernels as ops + +from dataclasses import dataclass +# from vllm.utils import is_hip +def is_hip(): + return True +# if HAS_TRITON: +# from vllm.attention.ops.prefix_prefill import context_attention_fwd + +# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`. +_PARTITION_SIZE = 512 if not is_hip() else 1024 +_PARTITION_SIZE_ROCM = 512 +_DEVICE_PROPERTIES = torch.cuda.get_device_properties("cuda") +_ON_NAVI = hasattr(_DEVICE_PROPERTIES, "gcnArchName") and \ + "gfx1" in torch.cuda.get_device_properties("cuda").gcnArchName + + +# page attention ops +def paged_attention_v1( + out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v1( + out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, + k_scale, v_scale, tp_rank, blocksparse_local_blocks, + blocksparse_vert_stride, blocksparse_block_size, + blocksparse_head_sliding_step) + + +def paged_attention_v2( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, +) -> None: + ops.paged_attention_v2( + out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, + num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, + alibi_slopes, kv_cache_dtype, k_scale, v_scale, tp_rank, + blocksparse_local_blocks, blocksparse_vert_stride, + blocksparse_block_size, blocksparse_head_sliding_step) + + +def paged_attention_rocm( + out: torch.Tensor, + exp_sum: torch.Tensor, + max_logits: torch.Tensor, + tmp_out: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + num_kv_heads: int, + scale: float, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + block_size: int, + max_seq_len: int, + alibi_slopes: Optional[torch.Tensor], + kv_cache_dtype: str, + k_scale: float, + v_scale: float, +) -> None: + ops.paged_attention_rocm(out, exp_sum, max_logits, tmp_out, query, + key_cache, value_cache, num_kv_heads, + scale, block_tables, seq_lens, + block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale) + +@dataclass +class PagedAttentionMetadata: + """Metadata for PagedAttention.""" + # (batch_size,). The length of sequences (entire tokens seen so far) per + # sequence. + seq_lens_tensor: Optional[torch.Tensor] + # Maximum sequence length in the batch. 0 if it is prefill-only batch. + max_decode_seq_len: int + # (batch_size, max_blocks_per_seq). + # Block addresses per sequence. (Seq id -> list of physical block) + # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks + # in the kv cache. Each block can contain up to block_size tokens. + # 2nd dimensions are padded up to max_blocks_per_seq if it is cuda-graph + # captured. + block_tables: Optional[torch.Tensor] + + +def _use_rocm_custom_paged_attention(qtype: torch.dtype, head_size: int, + block_size: int, gqa_ratio: int, + max_seq_len: int) -> bool: + # rocm custom page attention not support on navi (gfx1*) + return (not _ON_NAVI and (qtype == torch.half or qtype == torch.bfloat16) + and (head_size == 64 or head_size == 128) + and (block_size == 16 or block_size == 32) + and (gqa_ratio >= 1 and gqa_ratio <= 16) and max_seq_len <= 65536) + +class PagedAttention: + @staticmethod + def get_supported_head_sizes() -> List[int]: + return [64, 80, 96, 112, 120, 128, 192, 256] + + @staticmethod + def get_kv_cache_shape( + num_blocks: int, + block_size: int, + num_kv_heads: int, + head_size: int, + ) -> Tuple[int, ...]: + return (2, num_blocks, block_size * num_kv_heads * head_size) + + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + max_seq_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + tp_rank: int = 0, + blocksparse_local_blocks: int = 0, + blocksparse_vert_stride: int = 0, + blocksparse_block_size: int = 64, + blocksparse_head_sliding_step: int = 0, + ) -> torch.Tensor: + # Whether to use rocm custom paged attention or not + num_seqs, num_heads, head_size = query.shape + block_size = value_cache.shape[3] + gqa_ratio = num_heads // num_kv_heads + use_custom = _use_rocm_custom_paged_attention( + query.dtype, head_size, block_size, gqa_ratio, + max_seq_len) + output = torch.empty_like(query) + if use_custom: + max_num_partitions = ( + (max_seq_len + _PARTITION_SIZE_ROCM - 1) // + _PARTITION_SIZE_ROCM) + assert _PARTITION_SIZE_ROCM % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + out = output + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + else: + max_num_partitions = ((max_seq_len + _PARTITION_SIZE - 1) // + _PARTITION_SIZE) + if blocksparse_vert_stride is not None and blocksparse_vert_stride > 1: + # use blocksparse paged attention + block_size = value_cache.size(-1) + assert (blocksparse_block_size > 0 and + blocksparse_block_size % block_size == 0), \ + (f"{blocksparse_block_size=} needs to be a multiple of" + f"{block_size=} used in block_tables.") + + # NOTE(woosuk): We use a simple heuristic to decide whether to use + # PagedAttention V1 or V2. If the number of partitions is 1, we use + # V1 to avoid the overhead of reduction. Also, if the number of + # sequences or heads is large, we use V1 since there is enough work + # to parallelize. + # TODO(woosuk): Tune this heuristic. + # For context len > 8192, use V2 kernel to avoid shared memory shortage. + use_v1 = (max_seq_len <= 8192 + and (max_num_partitions == 1 or num_seqs * num_heads > 512)) + + if use_v1: + # Run PagedAttention V1. + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + else: + # Run PagedAttention V2. + assert _PARTITION_SIZE % block_size == 0 + tmp_output = torch.empty( + size=(num_seqs, num_heads, max_num_partitions, head_size), + dtype=output.dtype, + device=output.device, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, max_num_partitions), + dtype=torch.float32, + device=output.device, + ) + max_logits = torch.empty_like(exp_sums) + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + return output + + # @staticmethod + # def forward_prefix( + # query: torch.Tensor, + # key: torch.Tensor, + # value: torch.Tensor, + # kv_cache_dtype: str, + # key_cache: torch.Tensor, + # value_cache: torch.Tensor, + # block_tables: torch.Tensor, + # query_start_loc: torch.Tensor, + # seq_lens_tensor: torch.Tensor, + # context_lens: torch.Tensor, + # max_query_len: int, + # alibi_slopes: Optional[torch.Tensor], + # sliding_window: Optional[int], + # k_scale: float, + # v_scale: float, + # ) -> torch.Tensor: + # output = torch.empty_like(query) + # context_attention_fwd( + # query, + # key, + # value, + # output, + # kv_cache_dtype, + # key_cache, + # value_cache, + # block_tables, + # # query_start_loc is (batch_size + 1,) + # query_start_loc[:-1], + # seq_lens_tensor, + # context_lens, + # max_query_len, + # k_scale, + # v_scale, + # alibi_slopes, + # sliding_window, + # ) + # return output + + @staticmethod + def swap_blocks( + src_kv_cache: torch.Tensor, + dst_kv_cache: torch.Tensor, + src_to_dst: torch.Tensor, + ) -> None: + src_key_cache = src_kv_cache[0] + dst_key_cache = dst_kv_cache[0] + ops.swap_blocks(src_key_cache, dst_key_cache, src_to_dst) + + src_value_cache = src_kv_cache[1] + dst_value_cache = dst_kv_cache[1] + ops.swap_blocks(src_value_cache, dst_value_cache, src_to_dst) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: torch.Tensor, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/rotary_embedding.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/rotary_embedding.py new file mode 100644 index 000000000..06974a146 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/rotary_embedding.py @@ -0,0 +1,1008 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.33.2/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Rotary Positional Embeddings.""" +import math +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn + +# from custom_op import CustomOp + + +def _rotate_neox(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., :x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2:] + return torch.cat((-x2, x1), dim=-1) + + +def _rotate_gptj(x: torch.Tensor) -> torch.Tensor: + x1 = x[..., ::2] + x2 = x[..., 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) + + +def _apply_rotary_emb( + x: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + is_neox_style: bool, +) -> torch.Tensor: + """ + Args: + x: [num_tokens, num_heads, head_size] + cos: [num_tokens, head_size // 2] + sin: [num_tokens, head_size // 2] + is_neox_style: Whether to use the Neox-style or GPT-J-style rotary + positional embeddings. + """ + cos = cos.unsqueeze(-2).to(x.dtype) + sin = sin.unsqueeze(-2).to(x.dtype) + if is_neox_style: + x1, x2 = torch.chunk(x, 2, dim=-1) + else: + x1 = x[..., ::2] + x2 = x[..., 1::2] + o1 = x1 * cos - x2 * sin + o2 = x2 * cos + x1 * sin + if is_neox_style: + return torch.cat((o1, o2), dim=-1) + else: + return torch.stack((o1, o2), dim=-1).flatten(-2) + + +# class RotaryEmbedding(CustomOp): +class RotaryEmbedding(nn.Module): + """Original rotary positional embedding.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + ) -> None: + super().__init__() + self.head_size = head_size + self.rotary_dim = rotary_dim + self.max_position_embeddings = max_position_embeddings + self.base = base + self.is_neox_style = is_neox_style + self.dtype = dtype + + cache = self._compute_cos_sin_cache() + cache = cache.to(dtype) + self.cos_sin_cache: torch.Tensor + self.register_buffer("cos_sin_cache", cache, persistent=False) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + """Compute the inverse frequency.""" + # NOTE(woosuk): To exactly match the HF implementation, we need to + # use CPU to compute the cache and then move it to GPU. However, we + # create the cache on GPU for faster initialization. This may cause + # a slight numerical difference between the HF implementation and ours. + inv_freq = 1.0 / (base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim)) + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + """Compute the cos and sin cache.""" + inv_freq = self._compute_inv_freq(self.base) + t = torch.arange(self.max_position_embeddings, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward_native( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """A PyTorch-native implementation of forward().""" + if offsets is not None: + positions = positions + offsets + positions = positions.flatten() + num_tokens = positions.shape[0] + cos_sin = self.cos_sin_cache.index_select(0, positions) + cos, sin = cos_sin.chunk(2, dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + # def forward_cuda( + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + # from vllm import _custom_ops as ops + import rocmKernels as ops + + self.cos_sin_cache = self.cos_sin_cache.to(query.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def forward_xpu( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + from vllm._ipex_ops import ipex_ops as ops + + self.cos_sin_cache = self.cos_sin_cache.to(positions.device, + dtype=query.dtype) + # ops.rotary_embedding()/batched_rotary_embedding() + # are in-place operations that update the query and key tensors. + if offsets is not None: + ops.batched_rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, + self.is_neox_style, self.rotary_dim, + offsets) + else: + ops.rotary_embedding(positions, query, key, self.head_size, + self.cos_sin_cache, self.is_neox_style) + return query, key + + def extra_repr(self) -> str: + s = f"head_size={self.head_size}, rotary_dim={self.rotary_dim}" + s += f", max_position_embeddings={self.max_position_embeddings}" + s += f", base={self.base}, is_neox_style={self.is_neox_style}" + return s + + +class LinearScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with linear scaling. + + It supports multiple scaling factors. Since multiple LoRA adapters may have + different scaling factors, we need multiple cos/sin caches. In this way, + instead of running rotary embedding kernel per lora, we can run multiple + lora in a batched way. + + In addition to that, we also keep the cos/sin cache for the scaling factor + of 1 (default) at all times. + + Exemplary for two scaling factors x=1, y and z with embeddings + [[x11, x12, ... x1m], ..., [xn1, xn2, ..., xnm]] and + [[y11, y12, ... y1o], ..., [yn1, yn2, ..., yno]], and + [[z11, z12, ... z1p], ..., [zn1, zn2, ..., znp]], + + we construct the cos/sin cache as follows: + [[x11, x12, ... x1m, y11, y12, ... y1o, z11, z12, ... z1p], + ... + [xn1, xn2, ... xnm, yn1, yn2, ... yno, zn1, zn2, ... znp]] + + We then use offsets to index into the cos/sin cache for + the respective scaling factors. + + The offset to cache can be accessed via `scaling_factor_to_offset` API. + + Credits to the Reddit user /u/kaiokendev + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factors: Union[List[float], float], + dtype: torch.dtype, + ) -> None: + if isinstance(scaling_factors, float): + scaling_factors = [scaling_factors] + self.scaling_factors: List[float] = scaling_factors # noqa + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + # Lazy initialized. + self._scaling_factor_to_offset: Dict[float, int] + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.base) + cache_list: List[torch.Tensor] = [] + # offsets to the next cache in a tensor. + # Each offset corresponds to the same index in scaling_factors. + offsets: List[int] = [] + for scaling_factor in self.scaling_factors: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * scaling_factor + t = torch.arange(max_len, dtype=torch.float) + t = t / scaling_factor + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + if not cache_list: + offset = 0 + else: + last_offset = offsets[-1] + next_max_len = cache_list[-1].shape[0] + offset = last_offset + next_max_len + offsets.append(offset) + cache_list.append(cache) + self._scaling_factor_to_offset = { + float(scaling_factor): offsets[i] + for i, scaling_factor in enumerate(self.scaling_factors) + } + assert len(self.scaling_factors) == len(offsets) + return torch.cat(cache_list, dim=0) + + @property + def scaling_factor_to_offset(self) -> Dict[float, int]: + return self._scaling_factor_to_offset + + +class DynamicNTKScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with Dynamic NTK scaling. + + Credits to the Reddit users /u/bloc97 and /u/emozilla + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + ) -> None: + self.scaling_factor = scaling_factor + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_cos_sin_cache(self) -> torch.Tensor: + # NOTE(woosuk): self.max_position_embeddings is the original + # maximum length before applying the rope scaling. + # Thus, the maximum length after applying the rope scaling is + # self.max_position_embeddings * self.scaling_factor. + max_len = self.max_position_embeddings * self.scaling_factor + base = self.base * ( + (self.scaling_factor * max_len / self.max_position_embeddings) - + (self.scaling_factor - 1))**(self.rotary_dim / + (self.rotary_dim - 2)) + inv_freq = self._compute_inv_freq(base) + t = torch.arange(max_len, dtype=torch.float) + + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() + sin = freqs.sin() + cache = torch.cat((cos, sin), dim=-1) + return cache + + +# Inverse dim formula to find dim based on number of rotations +def _yarn_find_correction_dim(num_rotations: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> float: + return (dim * math.log(max_position_embeddings / + (num_rotations * 2 * math.pi))) / (2 * + math.log(base)) + + +# Find dim range bounds based on rotations +def _yarn_find_correction_range( + low_rot: int, + high_rot: int, + dim: int, + base: float = 10000, + max_position_embeddings: int = 2048) -> Tuple[int, int]: + low = math.floor( + _yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)) + high = math.ceil( + _yarn_find_correction_dim(high_rot, dim, base, + max_position_embeddings)) + return max(low, 0), min(high, dim - 1) # Clamp values just in case + + +def _yarn_linear_ramp_mask(low: float, high: float, dim: int, + dtype: torch.dtype) -> torch.Tensor: + if low == high: + high += 0.001 # Prevent singularity + + linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low) + ramp_func = torch.clamp(linear_func, 0, 1) + return ramp_func + + +def _yarn_get_mscale(scale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * math.log(scale) + 1.0 + + +class YaRNScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation + self.mscale = float( + _yarn_get_mscale(self.scaling_factor) * attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**( + torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + return cache + + +class Phi3LongRoPEScaledRotaryEmbedding(nn.Module): + """Phi3 family of models scaled rotary embedding. + + Based on the original RotaryEmbedding implementation. + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + original_max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + short_factor: List[float], + long_factor: List[float], + short_mscale: Optional[float] = None, + long_mscale: Optional[float] = None, + ): + super().__init__() + + if rotary_dim != head_size: + raise ValueError( + f"`Phi3LongRoPEScaledRotaryEmbedding` does not support \ + rotary_dim != head_size ({rotary_dim}!={head_size}).") + if is_neox_style is False: + raise ValueError( + "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style." + ) + + self.head_size = head_size + self.max_position_embeddings = max_position_embeddings + self.original_max_position_embeddings = original_max_position_embeddings + self.base = base + self.short_factor = short_factor + self.long_factor = long_factor + + scale = self.max_position_embeddings / \ + self.original_max_position_embeddings + if scale <= 1.0: + scaling_factor = 1.0 + else: + scaling_factor = math.sqrt( + 1 + math.log(scale) / + math.log(self.original_max_position_embeddings)) + if short_mscale is None: + short_mscale = scaling_factor + if long_mscale is None: + long_mscale = scaling_factor + + self.short_mscale = short_mscale + self.long_mscale = long_mscale + + short_cache = self._compute_cos_sin_cache( + original_max_position_embeddings, short_factor, short_mscale) + short_cache = short_cache.to(dtype) + self.register_buffer("short_cos_sin_cache", + short_cache, + persistent=False) + + long_cache = self._compute_cos_sin_cache(max_position_embeddings, + long_factor, long_mscale) + long_cache = long_cache.to(dtype) + self.register_buffer("long_cos_sin_cache", + long_cache, + persistent=False) + + long_short_cache = torch.cat( + [self.short_cos_sin_cache, self.long_cos_sin_cache], dim=0) + self.register_buffer("long_short_cos_sin_cache", + long_short_cache, + persistent=False) + + def _compute_inv_freq(self, rescale_factors: List[float]) -> torch.Tensor: + rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32) + inv_freq = 1.0 / (rescale_factors * (self.base**(torch.arange( + 0, self.head_size, 2, dtype=torch.float) / self.head_size))) + return inv_freq + + def _compute_cos_sin_cache( + self, + max_position_embeddings: int, + rescale_factors: List[float], + mscale: float, + ) -> torch.Tensor: + inv_freq = self._compute_inv_freq(rescale_factors) + t = torch.arange(max_position_embeddings, dtype=torch.float) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = freqs.cos() * mscale + sin = freqs.sin() * mscale + cache = torch.cat((cos, sin), dim=-1) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + query = query.view(*query.shape[:-1], -1, self.head_size) + key = key.view(*key.shape[:-1], -1, self.head_size) + + k = self.original_max_position_embeddings + long_prompt_offset = (torch.any(positions > k).float() * + torch.full_like(positions, k)).long() + idx = (torch.add(positions, long_prompt_offset) + if long_prompt_offset is not None else positions) + self.long_short_cos_sin_cache: torch.Tensor = ( + self.long_short_cos_sin_cache.to(idx.device)) + idx = torch.add(idx, offsets) if offsets is not None else idx + cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx) + + cos, sin = cos_sin.chunk(2, dim=-1) + cos = cos.repeat(1, 2).unsqueeze(-2) + sin = sin.repeat(1, 2).unsqueeze(-2) + + query = query * cos + _rotate_neox(query) * sin + key = key * cos + _rotate_neox(key) * sin + + return query.flatten(-2), key.flatten(-2) + + +def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float: + if scale <= 1: + return 1.0 + return 0.1 * mscale * math.log(scale) + 1.0 + + +class DeepseekScalingRotaryEmbedding(RotaryEmbedding): + """RotaryEmbedding extended with YaRN method. + + Credits to Peng et al. github.com/jquesnelle/yarn + """ + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + scaling_factor: float, + dtype: torch.dtype, + *, + extrapolation_factor: float = 1, + attn_factor: float = 1, + beta_fast: int = 32, + beta_slow: int = 1, + mscale: float = 1, + mscale_all_dim: float = 0, + ) -> None: + self.scaling_factor = scaling_factor + self.extrapolation_factor = extrapolation_factor + self.attn_factor = attn_factor + self.beta_fast = beta_fast + self.beta_slow = beta_slow + # Get n-d magnitude scaling corrected for interpolation. + self.mscale = float( + yarn_get_mscale(self.scaling_factor, float(mscale)) / + yarn_get_mscale(self.scaling_factor, float(mscale_all_dim)) * + attn_factor) + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor: + pos_freqs = self.base**(torch.arange( + 0, self.rotary_dim, 2, dtype=torch.float, device="cuda") / + self.rotary_dim) + inv_freq_extrapolation = 1.0 / pos_freqs + inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs) + + low, high = _yarn_find_correction_range(self.beta_fast, self.beta_slow, + self.rotary_dim, self.base, + self.max_position_embeddings) + # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = (1 - _yarn_linear_ramp_mask( + low, high, self.rotary_dim // 2, + dtype=torch.float)) * self.extrapolation_factor + inv_freq = inv_freq_interpolation * ( + 1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask + return inv_freq + + def _compute_cos_sin_cache(self) -> torch.Tensor: + inv_freq = self._compute_inv_freq(self.scaling_factor) + t = torch.arange(self.max_position_embeddings * self.scaling_factor, + device="cuda", + dtype=torch.float32) + freqs = torch.einsum("i,j -> ij", t, inv_freq) + cos = (freqs.cos() * self.mscale) + sin = (freqs.sin() * self.mscale) + cache = torch.cat((cos, sin), dim=-1) + print("Cache shape", cache.shape) + return cache + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + offsets: Optional[torch.Tensor] = None, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward().""" + query_rot = query[..., :self.rotary_dim] + key_rot = key[..., :self.rotary_dim] + if self.rotary_dim < self.head_size: + query_pass = query[..., self.rotary_dim:] + key_pass = key[..., self.rotary_dim:] + + self.cos_sin_cache: torch.Tensor = self.cos_sin_cache.to( + positions.device) + cos_sin = self.cos_sin_cache[torch.add(positions, offsets) + if offsets is not None else positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if self.is_neox_style: + # NOTE(woosuk): Here we assume that the positions tensor has the + # shape [batch_size, seq_len]. + cos = cos.repeat(1, 1, 2).unsqueeze(-2) + sin = sin.repeat(1, 1, 2).unsqueeze(-2) + else: + cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2) + sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2) + + rotate_fn = _rotate_neox if self.is_neox_style else _rotate_gptj + query_rot = query_rot * cos + rotate_fn(query_rot) * sin + key_rot = key_rot * cos + rotate_fn(key_rot) * sin + + if self.rotary_dim < self.head_size: + query = torch.cat((query_rot, query_pass), dim=-1) + key = torch.cat((key_rot, key_pass), dim=-1) + else: + query = query_rot + key = key_rot + return query, key + + +class Llama3RotaryEmbedding(RotaryEmbedding): + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + scaling_factor: float, + low_freq_factor: float, + high_freq_factor: float, + orig_max_position: int, + ) -> None: + self.scaling_factor = scaling_factor + self.low_freq_factor = low_freq_factor + self.high_freq_factor = high_freq_factor + self.orig_max_position = orig_max_position + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor: + inv_freqs = super()._compute_inv_freq(base) + low_freq_wavelen = self.orig_max_position / self.low_freq_factor + high_freq_wavelen = self.orig_max_position / self.high_freq_factor + + wave_len = 2 * math.pi / inv_freqs + if self.low_freq_factor != self.high_freq_factor: + smooth = (self.orig_max_position / wave_len - self.low_freq_factor + ) / (self.high_freq_factor - self.low_freq_factor) + else: + smooth = 0 + new_freqs = torch.where( + wave_len < high_freq_wavelen, + inv_freqs, + torch.where( + wave_len > low_freq_wavelen, + inv_freqs / self.scaling_factor, + (1 - smooth) * inv_freqs / self.scaling_factor + + smooth * inv_freqs, + ), + ) + return new_freqs + + +class MRotaryEmbedding(RotaryEmbedding): + """Rotary Embedding with Multimodal Sections.""" + + def __init__( + self, + head_size: int, + rotary_dim: int, + max_position_embeddings: int, + base: int, + is_neox_style: bool, + dtype: torch.dtype, + mrope_section: Optional[List[int]] = None, + ) -> None: + super().__init__(head_size, rotary_dim, max_position_embeddings, base, + is_neox_style, dtype) + + self.mrope_section = mrope_section + if self.mrope_section: + assert sum(self.mrope_section) == rotary_dim // 2 + + def forward( + self, + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """PyTorch-native implementation equivalent to forward(). + + Args: + positions: + [num_tokens,] (text only) or + [3, num_tokens] (T/H/W positions with multimodal inputs) + query: [num_tokens, num_heads * head_size] + key: [num_tokens, num_kv_heads * head_size] + """ + assert positions.ndim == 1 or positions.ndim == 2 + + num_tokens = positions.shape[-1] + cos_sin = self.cos_sin_cache[positions] + cos, sin = cos_sin.chunk(2, dim=-1) + if positions.ndim == 2: + assert self.mrope_section + + cos = torch.cat([ + m[i] + for i, m in enumerate(cos.split(self.mrope_section, dim=-1)) + ], + dim=-1) + sin = torch.cat([ + m[i] + for i, m in enumerate(sin.split(self.mrope_section, dim=-1)) + ], + dim=-1) + + query_shape = query.shape + query = query.view(num_tokens, -1, self.head_size) + query_rot = query[..., :self.rotary_dim] + query_pass = query[..., self.rotary_dim:] + query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style) + query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) + + key_shape = key.shape + key = key.view(num_tokens, -1, self.head_size) + key_rot = key[..., :self.rotary_dim] + key_pass = key[..., self.rotary_dim:] + key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style) + key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape) + return query, key + + @staticmethod + def get_input_positions( + input_tokens: List[int], + image_grid_thw: Union[List[List[int]], torch.Tensor], + video_grid_thw: Union[List[List[int]], torch.Tensor], + image_token_id: int, + video_token_id: int, + vision_start_token_id: int, + vision_end_token_id: int, + spatial_merge_size: int, + context_len: int = 0, + ) -> Tuple[List[List[int]], int]: + """Get mrope input positions and delta value.""" + + if isinstance(image_grid_thw, torch.Tensor): + image_grid_thw = image_grid_thw.tolist() + if isinstance(video_grid_thw, torch.Tensor): + video_grid_thw = video_grid_thw.tolist() + + input_tokens_tensor = torch.tensor(input_tokens) + vision_start_indices = torch.argwhere( + input_tokens_tensor == vision_start_token_id).squeeze(1) + vision_tokens = input_tokens_tensor[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + llm_pos_ids_list: list = [] + + st = 0 + remain_images, remain_videos = image_nums, video_nums + + image_index, video_index = 0, 0 + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = \ + t, h // spatial_merge_size, w // spatial_merge_size + text_len = ed - st + + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + t_index = torch.arange(llm_grid_t).view(-1, 1).expand( + -1, llm_grid_h * llm_grid_w).flatten() + h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand( + llm_grid_t, -1, llm_grid_w).flatten() + w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand( + llm_grid_t, llm_grid_h, -1).flatten() + llm_pos_ids_list.append( + torch.stack([t_index, h_index, w_index]) + text_len + st_idx) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = llm_pos_ids_list[-1].max() + 1 if len( + llm_pos_ids_list) > 0 else 0 + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx) + + llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1) + llm_positions = llm_positions[:, context_len:] + mrope_position_delta = (llm_positions.max() + 1 - + len(input_tokens)).item() + + return llm_positions.tolist(), mrope_position_delta + + @staticmethod + def get_next_input_positions( + mrope_position_delta: int, + context_len: int, + seq_len: int, + ) -> List[List[int]]: + return [ + list( + range(context_len + mrope_position_delta, + seq_len + mrope_position_delta)) for _ in range(3) + ] + + +_ROPE_DICT: Dict[Tuple, RotaryEmbedding] = {} + + +def get_rope( + head_size: int, + rotary_dim: int, + max_position: int, + base: int, + is_neox_style: bool = True, + rope_scaling: Optional[Dict[str, Any]] = None, + dtype: Optional[torch.dtype] = None, + partial_rotary_factor: float = 1.0, +) -> RotaryEmbedding: + if dtype is None: + dtype = torch.get_default_dtype() + if rope_scaling is not None: + # Transforms every value that is a list into a tuple for caching calls + rope_scaling_tuple = { + k: tuple(v) if isinstance(v, list) else v + for k, v in rope_scaling.items() + } + rope_scaling_args = tuple(rope_scaling_tuple.items()) + else: + rope_scaling_args = None + if partial_rotary_factor < 1.0: + rotary_dim = int(rotary_dim * partial_rotary_factor) + key = (head_size, rotary_dim, max_position, base, is_neox_style, + rope_scaling_args, dtype) + if key in _ROPE_DICT: + return _ROPE_DICT[key] + + if rope_scaling is None: + rotary_emb = RotaryEmbedding(head_size, rotary_dim, max_position, base, + is_neox_style, dtype) + else: + scaling_type = rope_scaling[ + "type"] if "type" in rope_scaling else rope_scaling["rope_type"] + # The correct one should be "longrope" but keep "su" here + # for backward compatible + if scaling_type not in {"su", "longrope"}: + scaling_factor = rope_scaling.get("factor", 1.0) + if scaling_type == "llama3": + low_freq_factor = rope_scaling["low_freq_factor"] + high_freq_factor = rope_scaling["high_freq_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + rotary_emb = Llama3RotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, dtype, + scaling_factor, low_freq_factor, + high_freq_factor, + original_max_position) + elif scaling_type == "linear": + rotary_emb = LinearScalingRotaryEmbedding(head_size, rotary_dim, + max_position, base, + is_neox_style, + scaling_factor, dtype) + elif scaling_type == "dynamic": + rotary_emb = DynamicNTKScalingRotaryEmbedding( + head_size, rotary_dim, max_position, base, is_neox_style, + scaling_factor, dtype) + elif scaling_type == "yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow") + } + rotary_emb = YaRNScalingRotaryEmbedding(head_size, rotary_dim, + original_max_position, + base, is_neox_style, + scaling_factor, dtype, + **extra_kwargs) + elif scaling_type == "deepseek_yarn": + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + # assert max_position == original_max_position * scaling_factor + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("extrapolation_factor", "attn_factor", "beta_fast", + "beta_slow", "mscale", "mscale_all_dim") + } + rotary_emb = DeepseekScalingRotaryEmbedding( + head_size, rotary_dim, original_max_position, base, + is_neox_style, scaling_factor, dtype, **extra_kwargs) + # The correct one should be "longrope" but keep "su" here + # for backward compatible + elif scaling_type == "su" or scaling_type == "longrope": + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position = rope_scaling[ + "original_max_position_embeddings"] + extra_kwargs = { + k: v + for k, v in rope_scaling.items() + if k in ("short_mscale", "long_mscale") + } + rotary_emb = Phi3LongRoPEScaledRotaryEmbedding( + head_size, rotary_dim, max_position, original_max_position, + base, is_neox_style, dtype, short_factor, long_factor, + **extra_kwargs) + elif scaling_type == "mrope": + rotary_emb = MRotaryEmbedding( + head_size, + rotary_dim, + max_position, + base, + is_neox_style, + dtype, + mrope_section=rope_scaling["mrope_section"], + ) + else: + raise ValueError(f"Unknown RoPE scaling type {scaling_type}") + _ROPE_DICT[key] = rotary_emb + return rotary_emb diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/setup.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/setup.py new file mode 100644 index 000000000..e38df9caf --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/setup.py @@ -0,0 +1,266 @@ +import sys +import warnings +import os +import re +import ast +import shutil +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import torch +from torch.utils.cpp_extension import ( + BuildExtension, + CppExtension, + CUDAExtension, + CUDA_HOME, + ROCM_HOME, + IS_HIP_EXTENSION, +) + + +ck_dir = os.environ.get("CK_DIR", "/mnt/raid0/shengnxu/composable_kernel") +this_dir = os.path.dirname(os.path.abspath(__file__)) +bd_dir = f"{this_dir}/build" +PACKAGE_NAME = 'rocmKernels' +BUILD_TARGET = os.environ.get("BUILD_TARGET", "auto") + +if BUILD_TARGET == "auto": + if IS_HIP_EXTENSION: + IS_ROCM = True + else: + IS_ROCM = False +else: + if BUILD_TARGET == "cuda": + IS_ROCM = False + elif BUILD_TARGET == "rocm": + IS_ROCM = True + +FORCE_CXX11_ABI = False + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith("linux"): + return f'linux_{platform.uname().machine}' + elif sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + elif sys.platform == "win32": + return "win_amd64" + else: + raise ValueError("Unsupported platform: {}".format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so nvcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def check_if_rocm_home_none(global_option: str) -> None: + if ROCM_HOME is not None: + return + # warn instead of error because user could be downloading prebuilt wheels, so hipcc won't be necessary + # in that case. + warnings.warn( + f"{global_option} was requested, but hipcc was not found." + ) + + +def append_nvcc_threads(nvcc_extra_args): + nvcc_threads = os.getenv("NVCC_THREADS") or "4" + return nvcc_extra_args + ["--threads", nvcc_threads] + + +def rename_cpp_to_cu(pths): + ret = [] + dst = bd_dir + for pth in pths: + if not os.path.exists(pth): + continue + for entry in os.listdir(pth): + if os.path.isdir(f'{pth}/{entry}'): + continue + newName = entry + if entry.endswith(".cpp") or entry.endswith(".cu"): + newName = entry.replace(".cpp", ".cu") + ret.append(f'{dst}/{newName}') + shutil.copy(f'{pth}/{entry}', f'{dst}/{newName}') + return ret + + +def validate_and_update_archs(archs): + # List of allowed architectures + allowed_archs = ["native", "gfx90a", + "gfx940", "gfx941", "gfx942", "gfx1100"] + + # Validate if each element in archs is in allowed_archs + assert all( + arch in allowed_archs for arch in archs + ), f"One of GPU archs of {archs} is invalid or not supported by Flash-Attention" + + +cmdclass = {} +ext_modules = [] + +if IS_ROCM: + # use codegen get code dispatch + if not os.path.exists(bd_dir): + os.makedirs(bd_dir) + + print(f"\n\ntorch.__version__ = {torch.__version__}\n\n") + TORCH_MAJOR = int(torch.__version__.split(".")[0]) + TORCH_MINOR = int(torch.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pytorch/pytorch/pull/70650 + generator_flag = [] + torch_dir = torch.__path__[0] + if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag.append("-DOLD_GENERATOR_PATH") + if os.path.exists(ck_dir): + generator_flag.append("-DFIND_CK") + + cc_flag = [] + + archs = os.getenv("GPU_ARCHS", "native").split(";") + validate_and_update_archs(archs) + + cc_flag = [f"--offload-arch={arch}" for arch in archs] + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # torch._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True + + renamed_sources = rename_cpp_to_cu([f"{this_dir}/csrc"]) + renamed_ck_srcs = rename_cpp_to_cu( + [f"{ck_dir}/example/ck_tile/02_layernorm2d/instances", + f"{this_dir}/csrc/impl/", + # f'for other kernels' + ]) + + extra_compile_args = { + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": + [ + "-O3", "-std=c++17", + "-mllvm", "-enable-post-misched=0", + "-DUSE_PROF_API=1", + "-D__HIP_PLATFORM_HCC__=1", + "-D__HIP_PLATFORM_AMD__=1", + # "-DLEGACY_HIPBLAS_DIRECT", + "-U__HIP_NO_HALF_CONVERSIONS__", + "-U__HIP_NO_HALF_OPERATORS__", + ] + + generator_flag + + cc_flag, + } + + include_dirs = [ + f"{this_dir}/build", + f"{ck_dir}/include", + f"{ck_dir}/library/include", + f"{ck_dir}/example/ck_tile/02_layernorm2d", + ] + + ext_modules.append( + CUDAExtension( + name=PACKAGE_NAME, + sources=renamed_sources+renamed_ck_srcs, + extra_compile_args=extra_compile_args, + include_dirs=include_dirs, + ) + ) +else: + raise NotImplementedError("Only ROCM is supported") + + +class NinjaBuildExtension(BuildExtension): + def __init__(self, *args, **kwargs) -> None: + # do not override env MAX_JOBS if already exists + if not os.environ.get("MAX_JOBS"): + import psutil + + # calculate the maximum allowed NUM_JOBS based on cores + max_num_jobs_cores = max(1, os.cpu_count() // 2) + + # calculate the maximum allowed NUM_JOBS based on free memory + free_memory_gb = psutil.virtual_memory().available / \ + (1024 ** 3) # free memory in GB + # each JOB peak memory cost is ~8-9GB when threads = 4 + max_num_jobs_memory = int(free_memory_gb / 9) + + # pick lower value of jobs based on cores vs memory metric to minimize oom and swap usage during compilation + max_jobs = max(1, min(max_num_jobs_cores, max_num_jobs_memory)) + os.environ["MAX_JOBS"] = str(max_jobs) + + super().__init__(*args, **kwargs) + + +setup( + name=PACKAGE_NAME, + version="0.1.0", + packages=find_packages( + exclude=( + "build", + "csrc", + "include", + "tests", + "dist", + "docs", + "benchmarks", + ) + ), + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={"build_ext": NinjaBuildExtension}, + python_requires=">=3.8", + install_requires=[ + "torch", + "einops", + ], + setup_requires=[ + "packaging", + "psutil", + "ninja", + ], +) +if os.path.exists(bd_dir): + shutil.rmtree(bd_dir) + shutil.rmtree(f"./.eggs") + shutil.rmtree(f"./{PACKAGE_NAME}.egg-info") + +if os.path.exists('./build'): + shutil.rmtree(f"./build") diff --git a/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/tuned_gemm.py b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/tuned_gemm.py new file mode 100644 index 000000000..94483b305 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/rocm_kernels/tuned_gemm.py @@ -0,0 +1,139 @@ +import os +from pathlib import Path + +import pandas as pd +import torch +import torch.nn.functional as F +# from hipbsolidxgemm import hipb_create_extension, hipb_mm +# from rocsolidxgemm import rocb_create_extension, rocb_mm + +import rocmKernels as ops +# from vllm import _custom_ops as ops +# from vllm.envs import VLLM_USE_ROCM_SKINNY_GEMM +# from vllm.utils import is_hip + + +class TunedGemm: + + def __init__(self): + #rocb_create_extension() + #hipb_create_extension() + self.extensions_created = False + self.save_gemm = int(os.environ.get('VLLM_TUNE_GEMM', 0)) + self.untune_path = os.environ.get('VLLM_UNTUNE_FILE', + "/tmp/vllm_untuned.csv") + self.tune_path = os.environ.get('VLLM_TUNE_FILE', "tuned.csv") + self.bestsols = {} + self.load_best_sols() + self.create_ds() + self.cu_count = torch.cuda.get_device_properties( + device='cuda').multi_processor_count + + # self.use_skinny = is_hip() and VLLM_USE_ROCM_SKINNY_GEMM and \ + # "gfx1" not in torch.cuda.get_device_properties('cuda').gcnArchName + self.use_skinny = True + + if (self.save_gemm == 1): + self.tuned_df = pd.DataFrame( + columns=['M', 'N', 'K', 'bias', 'dtype']) + else: + self.tuned_df = None + + def load_best_sols(self): + if self.tune_path is not None and Path(self.tune_path).is_file(): + self.bestsols = pd.read_csv(self.tune_path) + + def create_ds(self): + df: pd.DataFrame = self.bestsols + solds = {} + for i in range(len(df)): + ds = df.iloc[i] + key = (ds['M'], ds['N'], ds['K'], ds['bias'], ds['dtype']) + if ds['libtype'] == 'hipblaslt': + soltype = 1 + elif ds['libtype'] == 'rocblas': + soltype = 2 + solds[key] = (soltype, int(ds['solidx'])) + self.solids = solds + #print('>>>',solds) + def query_sol(self, m, n, k, bias, dtype): + return self.solids.get((m, n, k, bias, str(dtype)), (0, 0)) + + def apply_skinny(self, m, n, k, inp_view, weights): + if not self.use_skinny: + return None + if inp_view.dtype != torch.float16 or k % 8 != 0: + return None + if m > 8 and 0 < n <= 4: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.wvSpltK(weights, inp_view, out, n, self.cu_count) + return out + elif m % 4 == 0 and n == 1 and k <= 8192: + out = torch.empty(inp_view.shape[0], + weights.shape[0], + dtype=inp_view.dtype, + device='cuda') + ops.LLMM1(weights, inp_view, out, 4) + return out + else: + return None + + def mm(self, inp, weights, bias=None): + # F.Linear can take a 3 dimensional input. vllm + # uses this for linear units. However, sampler + # will use torch.matmul with 2 dimensions only + if inp.dim() == 3: + inp_view = inp.view(-1, inp.size(-1)) + batched = True + else: + inp_view = inp + batched = False + if self.extensions_created is False: + # rocb_create_extension() + # hipb_create_extension() + self.extensions_created = True + m = weights.shape[0] + n = inp_view.shape[0] + k = inp_view.shape[1] + use_bias = bias is not None + soltype, solidx = self.query_sol(m=m, + n=n, + k=k, + bias=use_bias, + dtype=inp.dtype) + out = self.apply_skinny(m, n, k, inp_view, weights) + if out is not None: + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + if bias is not None: + return out + bias + return out + elif soltype == 1: + out = hipb_mm(inp_view, weights.t(), solidx, bias=bias) + elif soltype == 2: + out = rocb_mm(inp_view, weights.t(), solidx) + if bias is not None: + out = out + bias + else: + if (self.save_gemm == 1): + self.tuned_df = pd.concat([ + self.tuned_df, + pd.DataFrame({ + 'M': [m], + 'N': [n], + 'K': [k], + 'bias': [bias is not None], + 'dtype': [inp.dtype], + }) + ]).drop_duplicates() + self.tuned_df.to_csv(self.untune_path, index=False) + return F.linear(inp, weights, bias) + if batched: + out = out.view(inp.shape[0], inp.shape[1], weights.shape[0]) + return out + + +tgemm = TunedGemm() diff --git a/byte_infer_perf/llm_perf/backends/ROCM/setup.py b/byte_infer_perf/llm_perf/backends/ROCM/setup.py new file mode 100644 index 000000000..58013b889 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/setup.py @@ -0,0 +1,74 @@ +import torch +import importlib +from typing import Any, Dict + +from llm_perf.core.scheduler import CoreScheduler +from llm_perf.backends.ROCM.gpu_inferencer import GpuInferencer +from llm_perf.backends.GPU.gpu_sampler import GpuSampler +from llm_perf.backends.GPU.gpu_scheduler import GpuScheduler +from llm_perf.backends.ROCM.gpu_mp_engine import GpuMpEngine +from llm_perf.backends.ROCM.gpu_mp_engine import GpuMpEngineWithGraph +from llm_perf.utils.logger import logger +import os + +def get_device_name(): + return torch.cuda.get_device_name(0) + + +def get_engine(xpu_cfg) -> CoreScheduler: + # get model impl + hardware_type = xpu_cfg["hardware_type"] + model_config = xpu_cfg["model_config"] + model_name = model_config["model_name"] + + vendor_model_path = f"llm_perf/backends/{hardware_type}/model_impl" + vendor_model_impl = importlib.import_module( + ".", package=vendor_model_path.replace("/", ".") + ) + vendor_model = vendor_model_impl.__all__[model_name] + + is_graph = int(os.environ.get("ENABLE_GRAPH", "0")) + + if is_graph: + mp_engine = GpuMpEngineWithGraph( + world_size=xpu_cfg["tp_size"], + model_impl=vendor_model, + xpu_cfg=xpu_cfg + ) + return mp_engine + else: + mp_engine = GpuMpEngine( + world_size=xpu_cfg["tp_size"], + model_impl=vendor_model, + xpu_cfg=xpu_cfg + ) + return mp_engine + + +def setup_scheduler(xpu_cfg) -> CoreScheduler: + + # get model impl + hardware_type = xpu_cfg["hardware_type"] + model_config = xpu_cfg["model_config"] + model_name = model_config["model_name"] + + vendor_model_path = f"llm_perf/backends/{hardware_type}/model_impl" + vendor_model_impl = importlib.import_module( + ".", package=vendor_model_path.replace("/", ".") + ) + vendor_model = vendor_model_impl.__all__[model_name] + + # create inferencer + inferencer = GpuInferencer(vendor_model, xpu_cfg) + + # create sampler + sampler = GpuSampler() + + # create scheduler + scheduler = GpuScheduler( + inferencer=inferencer, + sampler=sampler, + xpu_cfg=xpu_cfg + ) + + return scheduler diff --git a/byte_infer_perf/llm_perf/backends/ROCM/test_gemm_int8/benchmark_mixtral_gemm.py b/byte_infer_perf/llm_perf/backends/ROCM/test_gemm_int8/benchmark_mixtral_gemm.py new file mode 100644 index 000000000..a193b14f3 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/test_gemm_int8/benchmark_mixtral_gemm.py @@ -0,0 +1,144 @@ +import argparse +import json +import os +import sys +import unittest +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from tqdm import tqdm +_path = os.path.abspath(os.path.dirname(__file__)) +sys.path.insert(0, f'{_path}/../') +#import vllm._moe_C as moe_kernels +import rocmKernels as ops +# print(ops.__file__) +# exit() +from rocmKernels import gemm_a8w8 +def torch_gemma8w8(a, b, alpha_row, alpha_col): + b = b.transpose(0, 1) + x = torch.matmul(a.to(torch.float32), b.to(torch.float32)) + scale = torch.matmul(alpha_row, alpha_col) + out = torch.mul(x, scale) + return out.to(torch.half) +def get_MNK_shapes(): + MNK_SHAPES = [ + (1,4608,3584), + (32,4608,3584), + (64,4608,3584), + (128,4608,3584), + (256,4608,3584), + (512,4608,3584), + (1024,4608,3584), + (2048,4608,3584), + (4096,4608,3584), + (8192,4608,3584), + (16384,4608,3584), + (20480,4608,3584), + (1,3584,3584), + (32,3584,3584), + (64,3584,3584), + (128,3584,3584), + (256,3584,3584), + (512,3584,3584), + (1024,3584,3584), + (2048,3584,3584), + (4096,3584,3584), + (8192,3584,3584), + (16384,3584,3584), + (20480,3584,3584), + (1,3584,20480), + (32,3584,20480), + (64,3584,20480), + (128,3584,20480), + (256,3584,20480), + (512,3584,20480), + (1024,3584,20480), + (2048,3584,20480), + (4096,3584,20480), + (8192,3584,20480), + (16384,3584,20480), + (20480,3584,20480), + (1,40960,3584), + (32,40960,3584), + (64,40960,3584), + (128,40960,3584), + (256,40960,3584), + (512,40960,3584), + (1024,40960,3584), + (2048,40960,3584), + (4096,40960,3584), + (8192,40960,3584), + (16384,40960,3584), + (20480,40960,3584), + ] + return MNK_SHAPES + + +def get_M_shapes(): + # Start M from 1 and init gemm_a8w8 at very beginning will cause inf error + M_SHAPES = [2**i for i in range(4, 13)] + [1, 10, 20, 30, 40] + # M_SHAPES = [2**i for i in range(4, model_config.exponent_of_max_seq_len + 1)] + return M_SHAPES +def main(): + for m, n,k in get_MNK_shapes(): + # start_event = torch.cuda.Event(enable_timing=True) + # end_event = torch.cuda.Event(enable_timing=True) + # start_event.record() + a_s = [] + b_s = [] + num_calls = 200 + for i in range(num_calls): + a = torch.randint(-20, 20, (m, k), dtype=torch.int8).cuda() + b = torch.randint(-20, 20, (n, k), dtype=torch.int8).cuda() + a_s.append(a) + b_s.append(b) + + alpha_row = torch.rand([m, 1], dtype=torch.half).cuda() + alpha_col = torch.rand([1, n], dtype=torch.half).cuda() + out_gemm = torch.empty([m,n],dtype = torch.half).cuda() + out_ref = torch.empty([m,n],dtype = torch.half).cuda() + print("self m",m,"n",n,"k",k) + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA, ], on_trace_ready=torch.profiler.tensorboard_trace_handler("./"), with_modules=False, with_stack=False) as p: + for i in range(num_calls): + gemm_a8w8(a_s[i],b_s[i],alpha_row,alpha_col,out_gemm) + print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=20)) + #end_event.record() + #end_event.synchronize() + #dur_us = start_event.elapsed_time(end_event)*1000 / num_calls + #print("m",m,"n",n,"k",k, "dur_us",dur_us) + out_ref = torch_gemma8w8(a,b,alpha_row,alpha_col) + assert torch.allclose( + out_ref, out_gemm, 1e-03, 1000), f"torch and ck results are not close\n{out_ref.shape}\n{out_ref}\n{out_gemm.shape}\n{out_gemm}\nmax delta:{(out_gemm-out_ref).max()}" +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="benchmark_mixtral_gemm", + description="Tune the gemm kernel for mixtral.") + parser.add_argument( + "--TP", + type=int, + choices=[8, 4, 2, 1], + help="Specify the TP value that the actual model will run on", + default="0", + ) + parser.add_argument( + "--GPUID", + type=str, + help="This script uses single GPU. Specify the GPU to use for tuning", + default="0", + ) + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + default="8x7B", + ) + + args = parser.parse_args() + + print(f"Running tuning for {args.model} model") + print(f"TP is set to: {args.TP}") + print(f"GPU-ID being used for tuning: {args.GPUID}") + sys.exit(main()) + + + diff --git a/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_gemm.py b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_gemm.py new file mode 100755 index 000000000..d516886a4 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_gemm.py @@ -0,0 +1,408 @@ +import argparse +import json +import os +import sys +import unittest +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from tqdm import tqdm +_path = os.path.abspath(os.path.dirname(__file__)) +sys.path.insert(0, f'{_path}/../') +#import vllm._moe_C as moe_kernels +import rocmKernels as ops +# print(ops.__file__) +# exit() +from rocm_kernels.fused_moe_int8_a8w8 import (fused_moe_int8_a8w8, + get_config_file_name, + scaled_int8_quant) + + +def main(args): + os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID + os.environ["HIP_FORCE_DEV_KERNARG"] = "1" + os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" + os.environ["OPTIMIZE_EPILOGUE"] = "1" + + for bs in [ +# 1, +# 2, +# 4, +# 8, +# 16, +# 24, + 32, +# 48, +# 64, +# 96, +# 128, +# 256, +# 512, +# 1024, +# 1536, +# 2048, +# 3072, +# 4096, + ]: + run_grid(bs, model=args.model, TP=args.TP) + + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [32, 64, 128, 256] + split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [ 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + #for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for (matrix_instr_nonkdim) in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + return configs + + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs): + pruned_configs = [] + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + +def torch_moe(hidden_states, w1, w2, score, topk): + B, D = hidden_states.shape + hidden_states = hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros( + B * topk, + w2.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + silu_input = hidden_states[mask] @ (w1[i].transpose(0, 1)) + d = silu_input.shape[-1] // 2 + silu_output_shape = silu_input.shape[:-1] + (d,) + silu_out = torch.empty( + silu_output_shape, dtype=silu_input.dtype, device=silu_input.device + ) + ops.silu_and_mul(silu_out, silu_input) + out[mask] = silu_out @ (w2[i].transpose(0, 1)) + #out = out + 2.0 + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) +def dynamic_quan_torch_impl(input): + max_input = input.abs().max(-1, keepdim=True)[0] + scale = max_input / 127.0 + out = torch.round(input / scale) + return out.to(torch.int8), scale.half().squeeze(-1) +def run_grid(bs, model, TP): + if model == '8x7B': + d_model = 4096 + #d_model = 32 + model_intermediate_size = 14336 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + else: + raise ValueError(f'Unsupported Mixtral model {model}') + + num_total_experts = 8 + top_k = 2 + tp_size = TP + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + full_configs = get_full_tuning_space() + M1 = bs * 2 + N1 = model_intermediate_size * 2 // tp_size + K1 = d_model + prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + + M2 = bs * 2 + N2 = d_model + K2 = model_intermediate_size // tp_size + prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + + configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ + {len(prune_configs_2)=} | {len(configs)=}") + + best_config = None + best_time_us = 1e20 + + for config in tqdm(configs): + print("have config") + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # benchmark + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + # model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + filename = get_config_file_name(num_total_experts, + model_intermediate_size // tp_size, + dtype=None) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing( + num_calls: int, + bs: int, + d_model: int, + num_total_experts: int, + top_k: int, + tp_size: int, + model_intermediate_size: int, + config, +) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + print("run timing") + hidden_states = torch.rand( + (bs, d_model), + device="cuda", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + a2_scales = torch.rand((hidden_states.shape[1]), + device = hidden_states.device, + dtype=hidden_states.dtype) + gating_output = F.softmax( + torch.rand( + # (num_calls, bs, num_total_experts), # THIS + (bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1, + ) + + ###### Stuff from fused moe ###### + hidden_states_quant,hidden_states_scales = dynamic_quan_torch_impl(hidden_states) + w1_quant, w1_scales = dynamic_quan_torch_impl(w1) + w2_quant, w2_scales = dynamic_quan_torch_impl(w2) + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + output = fused_moe_int8_a8w8(hidden_states_quant, + w1_quant, + w2_quant, + gating_output, + w1_scales, + w2_scales, + hidden_states_scales, + a2_scales, + top_k, + renormalize=False, + inplace=False) + hidden_states_dequant = hidden_states_quant * hidden_states_scales[:, None] + w1_dequant = w1_quant * w1_scales[:, :, None] + w2_dequant = w2_quant * w2_scales[:, :, None] + out_ref = torch_moe(hidden_states_dequant, + w1_dequant, + w2_dequant, + gating_output, + top_k, + ) + diff = ~torch.isclose( + output.half().cpu(), out_ref.half().cpu(), rtol=1, atol=1 + ) + #print("output:",output) + #print("out_ref:",out_ref) + assert(diff.sum() < 10) + #print("diff sum :",diff.sum()) + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="benchmark_mixtral_moe_rocm", + description="Tune the fused_moe kernel for mixtral.") + parser.add_argument( + "--TP", + type=int, + choices=[8, 4, 2, 1], + help="Specify the TP value that the actual model will run on", + required=True, + ) + parser.add_argument( + "--GPUID", + type=str, + help="This script uses single GPU. Specify the GPU to use for tuning", + default="0", + ) + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + + args = parser.parse_args() + + print(f"Running tuning for {args.model} model") + print(f"TP is set to: {args.TP}") + print(f"GPU-ID being used for tuning: {args.GPUID}") + sys.exit(main(args)) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm.py b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm.py new file mode 100755 index 000000000..d516886a4 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm.py @@ -0,0 +1,408 @@ +import argparse +import json +import os +import sys +import unittest +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from tqdm import tqdm +_path = os.path.abspath(os.path.dirname(__file__)) +sys.path.insert(0, f'{_path}/../') +#import vllm._moe_C as moe_kernels +import rocmKernels as ops +# print(ops.__file__) +# exit() +from rocm_kernels.fused_moe_int8_a8w8 import (fused_moe_int8_a8w8, + get_config_file_name, + scaled_int8_quant) + + +def main(args): + os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID + os.environ["HIP_FORCE_DEV_KERNARG"] = "1" + os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" + os.environ["OPTIMIZE_EPILOGUE"] = "1" + + for bs in [ +# 1, +# 2, +# 4, +# 8, +# 16, +# 24, + 32, +# 48, +# 64, +# 96, +# 128, +# 256, +# 512, +# 1024, +# 1536, +# 2048, +# 3072, +# 4096, + ]: + run_grid(bs, model=args.model, TP=args.TP) + + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(): + configs = [] + + block_mn_range = [16, 32, 64, 128, 256] + block_k_range = [32, 64, 128, 256] + split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + num_warps_range = [1, 2, 4, 8] + group_m_range = [1, 4, 8, 16, 32] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16, 32] + kpack_range = [ 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + #for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for (matrix_instr_nonkdim) in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + return configs + + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs): + pruned_configs = [] + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + +def torch_moe(hidden_states, w1, w2, score, topk): + B, D = hidden_states.shape + hidden_states = hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros( + B * topk, + w2.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + silu_input = hidden_states[mask] @ (w1[i].transpose(0, 1)) + d = silu_input.shape[-1] // 2 + silu_output_shape = silu_input.shape[:-1] + (d,) + silu_out = torch.empty( + silu_output_shape, dtype=silu_input.dtype, device=silu_input.device + ) + ops.silu_and_mul(silu_out, silu_input) + out[mask] = silu_out @ (w2[i].transpose(0, 1)) + #out = out + 2.0 + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) +def dynamic_quan_torch_impl(input): + max_input = input.abs().max(-1, keepdim=True)[0] + scale = max_input / 127.0 + out = torch.round(input / scale) + return out.to(torch.int8), scale.half().squeeze(-1) +def run_grid(bs, model, TP): + if model == '8x7B': + d_model = 4096 + #d_model = 32 + model_intermediate_size = 14336 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + else: + raise ValueError(f'Unsupported Mixtral model {model}') + + num_total_experts = 8 + top_k = 2 + tp_size = TP + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + full_configs = get_full_tuning_space() + M1 = bs * 2 + N1 = model_intermediate_size * 2 // tp_size + K1 = d_model + prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + + M2 = bs * 2 + N2 = d_model + K2 = model_intermediate_size // tp_size + prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + + configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ + {len(prune_configs_2)=} | {len(configs)=}") + + best_config = None + best_time_us = 1e20 + + for config in tqdm(configs): + print("have config") + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # benchmark + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + # model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + filename = get_config_file_name(num_total_experts, + model_intermediate_size // tp_size, + dtype=None) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing( + num_calls: int, + bs: int, + d_model: int, + num_total_experts: int, + top_k: int, + tp_size: int, + model_intermediate_size: int, + config, +) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + print("run timing") + hidden_states = torch.rand( + (bs, d_model), + device="cuda", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + a2_scales = torch.rand((hidden_states.shape[1]), + device = hidden_states.device, + dtype=hidden_states.dtype) + gating_output = F.softmax( + torch.rand( + # (num_calls, bs, num_total_experts), # THIS + (bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1, + ) + + ###### Stuff from fused moe ###### + hidden_states_quant,hidden_states_scales = dynamic_quan_torch_impl(hidden_states) + w1_quant, w1_scales = dynamic_quan_torch_impl(w1) + w2_quant, w2_scales = dynamic_quan_torch_impl(w2) + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + output = fused_moe_int8_a8w8(hidden_states_quant, + w1_quant, + w2_quant, + gating_output, + w1_scales, + w2_scales, + hidden_states_scales, + a2_scales, + top_k, + renormalize=False, + inplace=False) + hidden_states_dequant = hidden_states_quant * hidden_states_scales[:, None] + w1_dequant = w1_quant * w1_scales[:, :, None] + w2_dequant = w2_quant * w2_scales[:, :, None] + out_ref = torch_moe(hidden_states_dequant, + w1_dequant, + w2_dequant, + gating_output, + top_k, + ) + diff = ~torch.isclose( + output.half().cpu(), out_ref.half().cpu(), rtol=1, atol=1 + ) + #print("output:",output) + #print("out_ref:",out_ref) + assert(diff.sum() < 10) + #print("diff sum :",diff.sum()) + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="benchmark_mixtral_moe_rocm", + description="Tune the fused_moe kernel for mixtral.") + parser.add_argument( + "--TP", + type=int, + choices=[8, 4, 2, 1], + help="Specify the TP value that the actual model will run on", + required=True, + ) + parser.add_argument( + "--GPUID", + type=str, + help="This script uses single GPU. Specify the GPU to use for tuning", + default="0", + ) + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + + args = parser.parse_args() + + print(f"Running tuning for {args.model} model") + print(f"TP is set to: {args.TP}") + print(f"GPU-ID being used for tuning: {args.GPUID}") + sys.exit(main(args)) diff --git a/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm_backup.py b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm_backup.py new file mode 100755 index 000000000..831aca849 --- /dev/null +++ b/byte_infer_perf/llm_perf/backends/ROCM/test_moe_int8/benchmark_mixtral_moe_rocm_backup.py @@ -0,0 +1,434 @@ +import argparse +import json +import os +import sys +import unittest +import torch +import torch.nn.functional as F +import triton +import triton.language as tl +from tqdm import tqdm +_path = os.path.abspath(os.path.dirname(__file__)) +sys.path.insert(0, f'{_path}/../') +#import vllm._moe_C as moe_kernels +import rocmKernels as ops +from rocm_kernels.fused_moe_int8_a8w8 import (fused_moe_int8_a8w8, + scaled_int8_quant) +from rocm_kernels.fused_moe_custom import (fused_moe_int8_a8w8_custom, + get_config_file_name, + triton_dynamic_quantize) + + +def main(args): + os.environ["HIP_VISIBLE_DEVICES"] = args.GPUID + os.environ["HIP_FORCE_DEV_KERNARG"] = "1" + os.environ["DEBUG_CLR_GRAPH_PACKET_CAPTURE"] = "1" + os.environ["OPTIMIZE_EPILOGUE"] = "1" + + for bs in [ +# 1, +# 2, +# 4, +# 8, + 16, +# 24, +# 32, +# 48, +# 64, +# 96, +# 128, +# 256, +# 512, +# 1024, +# 1536, +# 2048, +# 3072, +# 4096, + ]: + run_grid(bs, model=args.model, TP=args.TP) + + +## Utilize method from rocm/Triton tuning script +def get_full_tuning_space(): + configs = [] + + #block_mn_range = [16, 32, 64, 128, 256] + block_mn_range = [32] + #block_k_range = [32, 64, 128, 256] + block_k_range = [32] + block_mn_range = [16] + # block_k_range = [64] + #split_k_range = [1, 2, 4, 5, 6, 8, 10, 12, 16, 18, 24] + split_k_range = [4] + #num_warps_range = [1, 2, 4, 8] + num_warps_range = [2] + # num_warps_range = [1] + #group_m_range = [1, 4, 8, 16, 32] + group_m_range = [1] + # group_m_range = [1] + # For now we see better perf with num_stages=0 for all gemm configs we care + # But keep this explicit so that we do not forget we may need to set it to + # other values in the future + num_stage_range = [0] + waves_per_eu_range = [0] + matrix_instr_nonkdim_range = [16] + #matrix_instr_nonkdim_range = [16, 32] + kpack_range = [ 2] + + for block_m in block_mn_range: + for block_n in block_mn_range: + for block_k in block_k_range: + for num_warps in num_warps_range: + for group_m in group_m_range: + # for split_k in split_k_range: + for num_stages in num_stage_range: + for waves_per_eu in waves_per_eu_range: + for (matrix_instr_nonkdim + ) in matrix_instr_nonkdim_range: + for kpack in kpack_range: + configs.append({ + "BLOCK_SIZE_M": block_m, + "BLOCK_SIZE_N": block_n, + "BLOCK_SIZE_K": block_k, + "GROUP_SIZE_M": group_m, + "num_warps": num_warps, + "num_stages": num_stages, + }) + + return configs + + +## Utilize method from rocm/Triton tuning script +def prune_configs(M, N, K, configs): + pruned_configs = [] + elemBytes_a = 2 # [DV Note] Hard-coded for float16 (2 bytes) + elemBytes_b = 2 # [DV Note] Hard-coded for float16 (2 bytes) + + mfma = 16 if M < 32 or N < 32 else 32 + + # TODO (zhanglx): figure out the boundary between large and small gemms + large_gemm = False + if M >= 2048 and N >= 2048: + large_gemm = True + + for config in configs: + BLOCK_SIZE_M = config.get("BLOCK_SIZE_M") + BLOCK_SIZE_N = config.get("BLOCK_SIZE_N") + BLOCK_SIZE_K = config.get("BLOCK_SIZE_K") + num_warps = config.get("num_warps") + if mfma == 4 and BLOCK_SIZE_K < 64: + continue + if BLOCK_SIZE_M * BLOCK_SIZE_N < 64: + continue + SPLIT_K = 1 # config.get("SPLIT_K") + GROUP_M = config.get("GROUP_SIZE_M") + # Skip BLOCK_SIZE that is too large compare to M/N + # unless BLOCK_SIZE is already small enough + if M * 2 < BLOCK_SIZE_M and BLOCK_SIZE_M != 16: + continue + if N * 2 < BLOCK_SIZE_N and BLOCK_SIZE_N != 16: + continue + # skip large split_k when not necessary + if SPLIT_K != 1 and not need_split_k(M, N, K): + continue + # skip split_k that leads to EVEN_K = false + leap = SPLIT_K * BLOCK_SIZE_K + modv = K % leap + if modv != 0: + continue + # skip large GROUP_M + if GROUP_M * BLOCK_SIZE_M > M and GROUP_M != 1: + continue + # out of shared memory resource + # TODO (zhanglx): This does not consider the LDS usage in the epilogue + LDS = (BLOCK_SIZE_K * BLOCK_SIZE_M * elemBytes_a + + BLOCK_SIZE_K * BLOCK_SIZE_N * elemBytes_b) + if LDS > 65536: + continue + # Skip small block sizes and num_warps for large gemm + # For fp16 and f8, we want to only use BLOCK_SIZE >= 64 + if large_gemm: + if BLOCK_SIZE_M < 64 or BLOCK_SIZE_N < 64: + continue + if BLOCK_SIZE_K < 64: + continue + if num_warps < 4: + continue + + pruned_configs.append(config) + + return pruned_configs + + +def union_of_list_of_dicts(l1, l2): + result = [] + temp_list = l1.copy() + temp_list.extend(l2) + for myDict in temp_list: + if myDict not in result: + result.append(myDict) + + return result + + +def need_split_k(SIZE_M, SIZE_N, SIZE_K): + return (SIZE_M < 64 or SIZE_N < 64) and SIZE_K > 1024 + +def torch_moe(hidden_states, w1, w2, score, topk): + #print("in side torch moe w1, w2", w1, w2, hidden_states) + B, D = hidden_states.shape + hidden_states = hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros( + B * topk, + w2.shape[1], + dtype=hidden_states.dtype, + device=hidden_states.device, + ) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + #print("torch topk_weight",topk_weight,"topk_ids",topk_ids) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + silu_input = hidden_states[mask] @ (w1[i].transpose(0, 1)) + # + d = silu_input.shape[-1] // 2 + silu_output_shape = silu_input.shape[:-1] + (d,) + silu_out = torch.empty( + silu_output_shape, dtype=silu_input.dtype, device=silu_input.device + ) + ops.silu_and_mul(silu_out, silu_input) + # + out[mask] = silu_out @ (w2[i].transpose(0, 1)) + #out = out + 2.0 + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) +def dynamic_quan_torch_impl(input): + max_input = input.abs().max(-1, keepdim=True)[0] + scale = max_input / 127.0 + out = torch.round(input / scale) + return out.to(torch.int8), scale.half().squeeze(-1) +def run_grid(bs, model, TP): + if model == '8x7B': + d_model = 4096 + #d_model = 32 + model_intermediate_size = 14336 + elif model == '8x22B': + d_model = 6144 + model_intermediate_size = 16384 + else: + raise ValueError(f'Unsupported Mixtral model {model}') + + num_total_experts = 8 + top_k = 2 + tp_size = TP + num_calls = 100 + + num_warmup_trials = 1 + num_trials = 1 + + full_configs = get_full_tuning_space() + M1 = bs * 2 + N1 = model_intermediate_size * 2 // tp_size + K1 = d_model + prune_configs_1 = prune_configs(M1, N1, K1, full_configs) + + M2 = bs * 2 + N2 = d_model + K2 = model_intermediate_size // tp_size + prune_configs_2 = prune_configs(M2, N2, K2, full_configs) + + configs = union_of_list_of_dicts(prune_configs_1, prune_configs_2) + print(f"{bs=} || {len(full_configs)=} | {len(prune_configs_1)=} | \ + {len(prune_configs_2)=} | {len(configs)=}") + + best_config = None + best_time_us = 1e20 + + for config in tqdm(configs): + print("have config") + # warmup + try: + for _ in range(num_warmup_trials): + run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + except triton.runtime.autotuner.OutOfResources: + continue + + # benchmark + for _ in range(num_trials): + kernel_dur_ms = run_timing( + num_calls=num_calls, + bs=bs, + d_model=d_model, + num_total_experts=num_total_experts, + top_k=top_k, + tp_size=tp_size, + model_intermediate_size=model_intermediate_size, + config=config, + ) + + kernel_dur_us = 1000 * kernel_dur_ms + # model_dur_ms = kernel_dur_ms * num_layers + + if kernel_dur_us < best_time_us: + best_config = config + best_time_us = kernel_dur_us + + filename = get_config_file_name(num_total_experts, + model_intermediate_size // tp_size, + dtype=None) + print(f"writing config to file {filename}") + existing_content = {} + if os.path.exists(filename): + with open(filename, "r") as f: + existing_content = json.load(f) + existing_content[str(bs)] = best_config + with open(filename, "w") as f: + json.dump(existing_content, f, indent=4) + f.write("\n") + + +def run_timing( + num_calls: int, + bs: int, + d_model: int, + num_total_experts: int, + top_k: int, + tp_size: int, + model_intermediate_size: int, + config, +) -> float: + shard_intermediate_size = model_intermediate_size // tp_size + print("run timing") + hidden_states = torch.rand( + (bs, d_model), + device="cuda", + dtype=torch.float16, + ) + + w1 = torch.rand( + (num_total_experts, 2 * shard_intermediate_size, d_model), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + + w2 = torch.rand( + (num_total_experts, d_model, shard_intermediate_size), + device=hidden_states.device, + dtype=hidden_states.dtype, + )/100 + a2_scales = torch.rand((hidden_states.shape[1]), + device = hidden_states.device, + dtype=hidden_states.dtype) + gating_output = F.softmax( + torch.rand( + # (num_calls, bs, num_total_experts), # THIS + (bs, num_total_experts), + device=hidden_states.device, + dtype=torch.float32, + ), + dim=-1, + ) + + ###### Stuff from fused moe ###### + hidden_states_quant,hidden_states_scales = dynamic_quan_torch_impl(hidden_states) + w1_quant, w1_scales = dynamic_quan_torch_impl(w1) + w2_quant, w2_scales = dynamic_quan_torch_impl(w2) + assert (hidden_states.shape[0] == gating_output.shape[0] + ), "Number of tokens mismatch" + assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" + assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch" + assert hidden_states.is_contiguous(), "Hidden_states must be contiguous" + assert w1.is_contiguous(), "Expert weights1 must be contiguous" + assert w2.is_contiguous(), "Expert weights2 must be contiguous" + assert hidden_states.dtype in [ + torch.float32, torch.float16, torch.bfloat16 + ] + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + output = fused_moe_int8_a8w8(hidden_states_quant, + w1_quant, + w2_quant, + gating_output, + w1_scales, + w2_scales, + hidden_states_scales, + a2_scales, + top_k, + renormalize=False, + inplace=False) + output_custom = fused_moe_int8_a8w8_custom(hidden_states_quant, + hidden_states_scales, + w1_quant, + w1_scales, + w2_quant, + w2_scales, + gating_output, + top_k, + renormalize=False, + inplace=False) + hidden_states_dequant = hidden_states_quant * hidden_states_scales[:, None] + w1_dequant = w1_quant * w1_scales[:, :, None] + w2_dequant = w2_quant * w2_scales[:, :, None] + out_ref = torch_moe(hidden_states_dequant, + w1_dequant, + w2_dequant, + gating_output, + top_k, + ) + diff = ~torch.isclose( + output.half().cpu(), out_ref.half().cpu(), rtol=1, atol=1 + ) + print("output:",output) +# print("output custom:",output_custom) + print("out_ref:",out_ref) + assert(diff.sum() < 10) + print("diff sum :",diff.sum()) + end_event.record() + end_event.synchronize() + + dur_ms = start_event.elapsed_time(end_event) / num_calls + return dur_ms + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + prog="benchmark_mixtral_moe_rocm", + description="Tune the fused_moe kernel for mixtral.") + parser.add_argument( + "--TP", + type=int, + choices=[8, 4, 2, 1], + help="Specify the TP value that the actual model will run on", + required=True, + ) + parser.add_argument( + "--GPUID", + type=str, + help="This script uses single GPU. Specify the GPU to use for tuning", + default="0", + ) + parser.add_argument('--model', + type=str, + choices=['8x7B', '8x22B'], + help='The Mixtral model to benchmark') + + args = parser.parse_args() + + print(f"Running tuning for {args.model} model") + print(f"TP is set to: {args.TP}") + print(f"GPU-ID being used for tuning: {args.GPUID}") + sys.exit(main(args)) diff --git a/byte_infer_perf/llm_perf/bench_model.py b/byte_infer_perf/llm_perf/bench_model.py index b1247e893..ce3463443 100644 --- a/byte_infer_perf/llm_perf/bench_model.py +++ b/byte_infer_perf/llm_perf/bench_model.py @@ -129,6 +129,13 @@ def update_template(mode, batch_size, seq_len): # warmup num_warm_iter = 10 input_template = update_template("context", 1, 1024) + is_graph = int(os.environ.get("ENABLE_GRAPH", "0")) + + if is_graph: + #ROCM_HIPGRAPH modify + input_template['capture'] = 1 + engine.mp_forward(input_template) + input_template.pop('capture') start_time = time.perf_counter_ns() for _ in range(num_warm_iter): @@ -158,7 +165,6 @@ def results_to_csv(file_path, results): row.append(f"{results[batch_size][seq_len]}") csv_writer.writerow(row) - log_results = [] if xpu_config["perf_config"]["perf_context"]: batch_size_list = [1] @@ -169,11 +175,18 @@ def results_to_csv(file_path, results): context_results[batch_size] = {} for seq_len in seq_len_list: input_template = update_template("context", 1, seq_len) + if is_graph: + #ROCM_HIPGRAPH modify + input_template['capture'] = 1 + engine.mp_forward(input_template) + input_template.pop('capture') + + total_test_iter = 20 start_iters = 2 test_iter = 0 duration_ms = 0. - while duration_ms < 5000. and test_iter < 100: + while test_iter < total_test_iter: result = engine.mp_forward(input_template) if start_iters > 0: start_iters -= 1 @@ -205,11 +218,18 @@ def results_to_csv(file_path, results): decode_results[batch_size] = {} for seq_len in seq_len_list: input_template = update_template("decode", batch_size, seq_len) + if is_graph: + #ROCM_HIPGRAPH modify + input_template['capture'] = 1 + engine.mp_forward(input_template) + input_template.pop('capture') + total_test_iter = 20 start_iters = 2 test_iter = 0 + duration_ms = 0. - while duration_ms < 5000. and test_iter < 100: + while test_iter < total_test_iter: result = engine.mp_forward(input_template) if start_iters > 0: start_iters -= 1 diff --git a/byte_infer_perf/llm_perf/core/mp_engine.py b/byte_infer_perf/llm_perf/core/mp_engine.py index 24f0da17e..7a8176bb4 100644 --- a/byte_infer_perf/llm_perf/core/mp_engine.py +++ b/byte_infer_perf/llm_perf/core/mp_engine.py @@ -32,7 +32,7 @@ def __init__( def signal_handler(signum, frame): logger.info(f"Received signal {signum}, exiting...") self.clean_subprocess() - os._exit(0) + sys.exit(0) signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGTERM, signal_handler) @@ -105,4 +105,4 @@ def mp_loop_worker( @abstractmethod def mp_forward(self, *args): - raise NotImplementedError \ No newline at end of file + raise NotImplementedError