diff --git a/openseek/competition/LongContext-ICL-Annotation/READMD_zh_cn.md b/openseek/competition/LongContext-ICL-Annotation/READMD_zh_cn.md index 735787e..82c61fb 100644 --- a/openseek/competition/LongContext-ICL-Annotation/READMD_zh_cn.md +++ b/openseek/competition/LongContext-ICL-Annotation/READMD_zh_cn.md @@ -1 +1,90 @@ -# LongContext-ICL-Annotation (长上下文场景中LLM自动数据标注挑战赛) +# 超长长上下文场景中LLM自动数据标注挑战赛 + +--- + +## 消息 + +- **[2026-01-20] `发布`:** 赛事信息已在 **Kaggle** 正式上线。详情见:[FlagOS Open Computing Global Challenge](https://www.kaggle.com/competitions/flag-os-open-computing-global-challenge). +- **[2026-01-06] `发布`:** 由 **众智 FlagOS 社区**、**北京智源人工智能研究院(BAAI)** 与 **CCF ODTC** 联合主办的综合性大赛 **FlagOS 开放计算全球挑战赛** 正式发布。详情见: + [FlagOS开放计算全球挑战赛- AI赛事通 | 数据算法赛](https://www.competehub.dev/zh/competitions/modelscope180) + + +--- + +## 简介 +长上下文场景中LLM自动数据标注挑战赛基于Qwen3-4B大语言模型,采用上下文(In-context Learning, ICL)范式开展自动化数据标注任务研究。参赛团队必须使用组委会统一提供的数据集,围绕超长上下文场景设计有效的 ICL 标注方案,并在统一评测集上完成推理与评测。组委会将依据标准化评测结果,对参赛方案进行综合评估并确定最终排名。 + +### 赛题目标 +本赛题以大语言模型(Large Language Models,LLMs)为核心驱动力,面向超长上下文条件下的自动化数据标注问题,探索兼具效率与精度的新型技术范式。赛题重点聚焦以下三个关键科学与工程问题: +1. 在超长上下文场景下,如何设计有效的模型指令与提示策略,引导 LLM 稳定、高质量地完成数据标注任务? +2. 当可用标注示例数量显著超过模型上下文容量时,如何为待标注数据构造信息密集、结构合理的超长上下文输入? +3. 在自动多轮对话或持续交互场景中,如何高效利用超长上下文,实现一致性与可扩展性兼顾的数据标注? + +### 赛题详情 +- 参赛团队自主设计一套完整的LLM自动数据标注方案,并在统一的数据集与评测设置下进行实验验证。比赛将以标准化榜单形式公布各参赛方案的评测分数及排名。 +- 除评测结果外,参赛团队还需按照赛事要求提交技术方案文档与可复现源代码。组委会将对提交方案进行复现验证,并对技术方案本身进行评审。最终成绩将由预测结果成绩与技术方案成绩加权计算得出,具体规则如下。 +- 参赛队伍需按照赛题和赛制要求,提交技术方案和完整代码至Github OpenSeek官方开源项目下。 +- 更多具体细节请参考[赛事平台](https://flagos.io/RaceDetail?id=296fmsd8&lang=cn)。 +- 关于赛事信息,一切以赛事平台公布信息为准。 + +--- + +## 快速开始 +### 1. 环境 + +```bash +openai +torch +flagScale +``` + +推荐在NVIDIA平台使用 `cd src && bash create_env_nvidia.sh` 创建环境。 + +### 2. 下载模型权重 +```bash +hf download Qwen/Qwen3-4B --local-dir Qwen3-4B +# or +modelscope download --model Qwen/Qwen3-4B +``` +### 3. 长文本配置 +在`Qwen3-4B/config.json`将原有配置替换为: +```json +"rope_scaling": { + "rope_type": "yarn", + "factor": 4.0, + "original_max_position_embeddings": 32768 +} +``` +### 4. 模型部署 + +请根据实际需求,配置 `llm_config.yaml` 文件。启动配置 + +```bash +cd FlagScale +python run.py --config-path .. --config-name llm_config action=run +``` + +在模型服务启动后,可通过以下方式测试本地 API: + +```bash +python api_test.py +``` + +如需停止服务,请执行: + +```bash +python run.py --config-path .. --config-name llm_config action=stop +``` + +### 5. 运行/改进基线方法(Baseline) + +启动如下命令开始模型标注 +```bash +python main.py +``` + +实现新的标注方法,请修改`method.py`文件。你可以在该文件中: +* 定义新的指令模板、 +* 定义新的上下文示例选择策略 +* 定义新的模型推理、标注方案 +* 添加自定义后处理逻辑 diff --git a/openseek/competition/LongContext-ICL-Annotation/README.md b/openseek/competition/LongContext-ICL-Annotation/README.md index 4eccf6b..25d1082 100644 --- a/openseek/competition/LongContext-ICL-Annotation/README.md +++ b/openseek/competition/LongContext-ICL-Annotation/README.md @@ -1,2 +1,108 @@ # LongContext-ICL-Annotation +Large Language Models Automatic Data Annotation under Long-Context Scenarios. + +## News + +- **[2026-01-20] `Release`:** The competition is now officially live on **Kaggle**. See details: [FlagOS Open Computing Global Challenge](https://www.kaggle.com/competitions/flag-os-open-computing-global-challenge). +- **[2026-01-06] `Release`:** The comprehensive competition **FlagOS Open Computing Global Challenge** was officially announced, co-hosted by the **FlagOS Community**, the **Beijing Academy of Artificial Intelligence (BAAI)**, and **CCF ODTC**. See details: + [FlagOS开放计算全球挑战赛- AI赛事通 | 数据算法赛](https://www.competehub.dev/zh/competitions/modelscope180) + + +## Introduction + +The LongContext-ICL-Annotation Challenge focuses on automatic data annotation under long-context settings using Large Language Models (LLMs). The competition is built upon the Qwen3-4B model and adopts the In-context Learning (ICL) paradigm to investigate scalable and high-quality automated annotation methods. + +Participating teams are required to use the officially provided datasets and design effective ICL-based annotation solutions tailored for ultra-long context scenarios. All submissions will be evaluated on a unified benchmark dataset. The Organizing Committee will conduct standardized evaluations and determine the final rankings based on the official evaluation results. + +## Objectives + +This challenge takes Large Language Models (LLMs) as the core technical foundation and targets automated data annotation under ultra-long context constraints, aiming to explore novel paradigms that balance annotation efficiency and annotation accuracy. The competition focuses on the following key scientific and engineering challenges: + +- 1. Instruction and Prompt Design: + + How can effective model instructions and prompt strategies be designed in ultra-long context scenarios to guide LLMs toward stable and high-quality data annotation? +- 2. Ultra-Long Context Construction: + + When the number of available annotation examples significantly exceeds the model’s context capacity, how can information-dense and structurally coherent ultra-long context inputs be constructed for target data annotation? +- 3. Multi-Turn and Continuous Annotation: + + In automated multi-round dialogue or continuous interaction settings, how can ultra-long contexts be efficiently leveraged to achieve both consistency and scalability in data annotation? + +## Challenge Details + +- Participating teams are expected to independently design a complete LLM-based automatic data annotation pipeline and validate their approach under a unified dataset and evaluation protocol. Evaluation scores and rankings will be published on a standardized leaderboard. + +- In addition to prediction results, teams must submit a technical report and fully reproducible source code in accordance with the competition requirements. The Organizing Committee will reproduce submitted solutions and review the technical design. The final score will be calculated as a weighted combination of prediction performance and technical solution evaluation, with detailed rules specified by the competition. + +- Teams are required to submit their technical reports and complete source code to the official OpenSeek GitHub repository designated by the competition. + +- For additional details, please refer to [FlagOS platform](https://flagos.io/RaceDetail?id=296fmsd8&lang=en). All competition-related information is subject to the announcements published on the official platform. + +## Quick Start + +### 1. Environment Setup + +```bash +openai +torch +flagScale +``` + +On NVIDIA platforms, it is recommended to create the environment using: `cd src && bash create_env_nvidia.sh` + +### 2. Download Model Weights + +```bash +hf download Qwen/Qwen3-4B --local-dir Qwen3-4B +# or +modelscope download --model Qwen/Qwen3-4B +``` + +### 3. Long-Context Configuration + +In `Qwen3-4B/config.json`, replace the original configuration with the following settings: + +```json +"rope_scaling": { + "rope_type": "yarn", + "factor": 4.0, + "original_max_position_embeddings": 32768 +} +``` + +### 4. Model Deployment + +Configure the `llm_config.yaml` file according to your actual requirements. Then start the service with: + +```bash +cd FlagScale +python run.py --config-path .. --config-name llm_config action=run +``` + +After the model service is launched, you can test the local API using: + +```bash +python api_test.py +``` + +To stop the service, run: + +```bash +python run.py --config-path .. --config-name llm_config action=stop +``` + +### 5. Run or Extend the Baseline Method + +Start the baseline annotation pipeline with: + +```bash +python main.py +``` + +To implement a new annotation method, modify the `method.py` file. Within this file, you may: + +- Define new instruction or prompt templates +- Design new context example selection strategies +- Implement alternative model inference and annotation pipelines +- Add custom post-processing logic diff --git a/openseek/competition/LongContext-ICL-Annotation/data/README.md b/openseek/competition/LongContext-ICL-Annotation/data/README.md index 7623f1b..f30f56b 100644 --- a/openseek/competition/LongContext-ICL-Annotation/data/README.md +++ b/openseek/competition/LongContext-ICL-Annotation/data/README.md @@ -1,5 +1,18 @@ +# Datasets -| 任务ID | 任务名称 | 上下文长度最短 | 测试样本数 | +This repository provides the official datasets for the **LLM Automatic Data Annotation**. +The datasets are specifically designed to evaluate the capability of Large Language Models (LLMs) to perform **automatic data annotation under ultra-long context settings** using the In-context Learning (ICL) paradigm. + + +## Overview + +- Most tasks require a **minimum ICL context length of 30K tokens**, deliberately exceeding standard context limits to evaluate long-context understanding, prompt engineering, and example selection strategies. +- Task **openseek-8** is configured with a **shorter minimum context length (15K tokens)** and a **smaller test set**, reflecting the unique challenges of **kernel generation**. +- All datasets are released with **fixed and standardized test splits** to ensure fair comparison and reproducibility across submissions. +- The task suite covers a **diverse range of domains and reasoning types**, including symbolic reasoning, linguistic analysis, natural language inference, code-related tasks, and open-ended generation. + + +| Task ID | task name | Minimum ICL context | Test sample number | | --- | --- | --- | --- | | openseek-1 | closest_integers | 30K | 500 | | openseek-2 | count_nouns_verbs | 30K | 500 | @@ -8,4 +21,22 @@ | openseek-5 | semeval_2018_task1_tweet_sadness_detection | 30K | 500 | | openseek-6 | mnli_same_genre_classification | 30K | 500 | | openseek-7 | jeopardy_answer_generation_all | 30K | 500 | -| openseek-8 | kernel_genernation | 30K | 500 | \ No newline at end of file +| openseek-8 | kernel_genernation | 15K | 166 | + + +## Data Structure +The datasets are organized in JSON format, with each task having its own json file. Here's a brief overview of the data structure: + +- `task_id`: A unique identifier for the task. +- "task_name": A short human-readable name of the task. +- `Definition`: A detailed description of what the model should do. +- `examples`: Demonstration samples intended for understanding the task format (not necessarily used for scoring). Each example typically includes: `id`, `input` and `output`. +- `test_samples`: The samples to be predicted by participants. Labels/ground truth is hidden. Each test sample typically includes: `id` and `input`. +- `License`: The dataset license name and/or a URL to the license text, describing allowed use and redistribution. + + +## Usage Notes + +- Participants must use the **official datasets as provided**, without altering test splits or labels, for leaderboard evaluation. +- Any preprocessing steps, context construction strategies, or example selection mechanisms should be clearly described in the accompanying technical report. +- All experimental results must be **fully reproducible** using the datasets in this repository. \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/data/README_zh.md b/openseek/competition/LongContext-ICL-Annotation/data/README_zh.md new file mode 100644 index 0000000..cc565fe --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/data/README_zh.md @@ -0,0 +1,46 @@ +# 数据集说明 + +本仓库提供 **LLM Automatic Data Annotation** 的官方数据集。 + +这些数据集专门用于评估大语言模型(LLMs)在 超长上下文设置 下,使用 In-context Learning(ICL)范式进行 自动数据标注 的能力。 + +--- + +## 概览 + +- 大多数任务要求 **最小 ICL 上下文长度为 30K tokens**,接近 **Qwen3-4B** 的标准上下文限制,以评估长上下文理解、提示工程以及示例选择策略。 +- 任务 **openseek-8** 配置了 **更短的最小上下文长度(15K tokens)** 和 **更小的测试集**,以反映**算子生成**的独特挑战。 +- 所有数据集均以 **固定且标准化的测试划分** 发布,以确保提交之间的公平比较与可复现性。 +- 任务集合覆盖 **多样的领域与推理类型**,包括符号推理、语言学分析、自然语言推断、代码相关任务以及开放式生成。 + +| Task ID | task name | Minimum ICL context | Test sample number | +| --- | --- | --- | --- | +| openseek-1 | closest_integers | 30K | 500 | +| openseek-2 | count_nouns_verbs | 30K | 500 | +| openseek-3 | collatz_conjecture | 30K | 500 | +| openseek-4 | conala_concat_strings | 30K | 500 | +| openseek-5 | semeval_2018_task1_tweet_sadness_detection | 30K | 500 | +| openseek-6 | mnli_same_genre_classification | 30K | 500 | +| openseek-7 | jeopardy_answer_generation_all | 30K | 500 | +| openseek-8 | kernel_genernation | 15K | 166 | + +--- + +## 数据结构 + +数据集以 `JSON` 格式组织,每个任务对应一个独立的 `.json` 文件。以下是数据结构的简要说明: + +- `task_id`: 任务的唯一标识符。 +- `task_name`: 任务的简短、便于理解的人类可读名称。 +- `Definition`: 对模型应执行内容的详细描述。 +- `examples`: 用于理解任务格式的演示样本(不一定用于计分)。每个示例包含:`id`、`input` 和 `output`。 +- `test_samples`: 参赛者需要预测的样本。标签/真实值被隐藏。每个测试样本包含:`id` 和 `input`。 +- `License`: 数据集许可证名称和/或许可证文本的 URL,用于说明允许的使用方式与再分发规则。 + +--- + +## 使用说明 + +- 参赛者必须使用**按原样提供的官方数据集**,不得更改测试划分或标签,以用于排行榜评测。 +- 任何预处理步骤、上下文构建策略或示例选择机制,都应在随附的技术报告中清晰描述。 +- 所有实验结果必须能够使用本仓库中的数据集**完全复现**。 \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/data/openseek-8_kernel_generation.json b/openseek/competition/LongContext-ICL-Annotation/data/openseek-8_kernel_generation.json new file mode 100644 index 0000000..29a3780 --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/data/openseek-8_kernel_generation.json @@ -0,0 +1,1964 @@ +{ + "task_id": "openseek-8", + "task_name": "kernel generation", + "Definition": [ + "Implementing custom algorithms or functions using Triton, and ensuring correct block masking and stride handling for memory safety." + ], + "examples": [ + { + "id": "openseek-8-5ef60b9ee4584b80abe4b409616a2129", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code defines a custom attention mechanism in PyTorch using the Triton library. This attention mechanism is implemented as a custom autograd function `LightningAttention2NoDecay` with `forward` and `backward` methods. The forward method computes the attention output given input tensors Q (queries), K (keys), and V (values), while the backward method computes gradients for Q, K, and V given the gradient of the output.\n\n The `_fwd_kernel` is responsible for the forward pass computation. It calculates the attention output by processing Q, K, and V in blocks of size `BLOCK` (64). It uses `NUM_BLOCK` to determine how many such blocks exist along the sequence dimension. The kernel loads segments of Q, K, and V, computes their dot product, and uses the result to calculate the output by combining intra-block (within the block) and inter-block (between blocks) interactions.\n\n The `_bwd_intra_kernel` is used in the backward pass to compute gradients within each block. It processes the gradient of the output (`DO`) and calculates the gradients `DQ`, `DK`, and `DV` for each of the input tensors. It uses a block size of `CBLOCK` (32) for sub-block computations, iterating over `NUM_BLOCK` blocks.\n\n The `_bwd_inter_kernel` computes gradients involving interactions between blocks. It iteratively updates the accumulated gradients for the entire input sequence. It uses the computed values from the `_bwd_intra_kernel` to adjust gradients for keys (K) and values (V).\n\n The code uses a grid launch strategy for parallel computation across batches and heads, defined by `b * h`, and sequence dimension divided into blocks.\n\n Important parameters and settings include:\n - `BLOCK`: Main block size (64) used in computations.\n - `NUM_BLOCK`: Number of blocks along the sequence dimension.\n - `CBLOCK`: Sub-block size (32) used for intra-block gradient calculations.\n - `NUM_CBLOCK`: Number of sub-blocks within each block for intra operations.\n\n These kernels are called using a grid defined by `(b * h, cdiv(e, BLOCK_MODEL))` for the forward pass and intra-block backward pass, and `(b * h,)` for the inter-block backward pass. The context saves Q, K, and V during the forward pass to facilitate efficient gradient computation during the backward pass.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n Out,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n BLOCK_MODEL: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_bh % h\n off_e = tl.program_id(1)\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n # channel offset\n e_offset = off_e * BLOCK_MODEL\n\n ##### get block ptr\n Q_block_ptr = Q + qk_offset + tl.arange(0, d)[None, :]\n K_trans_block_ptr = K + qk_offset + tl.arange(0, d)[:, None]\n V_block_ptr = V + v_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]\n O_block_ptr = Out + o_offset + e_offset + tl.arange(0, BLOCK_MODEL)[None, :]\n\n ##### init diag decay(Lambda); q, k decay; kv\n # q, k decay\n off_block = tl.arange(\n 0, BLOCK\n ) # Not bug, this is a bit different from algorithm 1, but is mathematically equivalent\n # diag decay\n index = off_block[:, None] - off_block[None, :]\n kv = tl.zeros([d, BLOCK_MODEL], dtype=tl.float32)\n\n ##### compute\n for i in range(NUM_BLOCK):\n # load\n q = tl.load(\n Q_block_ptr + off_block[:, None] * d, mask=off_block[:, None] < n, other=0.0\n ).to(tl.float32)\n k_trans = tl.load(\n K_trans_block_ptr + off_block[None, :] * d,\n mask=off_block[None, :] < n,\n other=0.0,\n ).to(tl.float32)\n v = tl.load(\n V_block_ptr + off_block[:, None] * e, mask=off_block[:, None] < n, other=0.0\n ).to(tl.float32)\n\n # compute\n qk = tl.dot(q, k_trans)\n qk = tl.where(index >= 0, qk, 0)\n o_intra = tl.dot(qk, v)\n o_inter = tl.dot(q, kv)\n o = o_intra + o_inter\n\n # save and update\n tl.store(\n O_block_ptr + off_block[:, None] * e,\n o.to(O_block_ptr.dtype.element_ty),\n mask=off_block[:, None] < n,\n )\n kv += tl.dot(k_trans, v)\n off_block += BLOCK\n\n\n@triton.jit\ndef _bwd_intra_kernel(\n Q,\n K,\n V,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_block = tl.program_id(1)\n off_bh % h\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n block_offset = off_block * BLOCK + tl.arange(0, BLOCK)\n\n ##### get block ptr\n Q_trans_block_ptr = (\n Q + qk_offset + block_offset[None, :] * d + tl.arange(0, d)[:, None]\n )\n K_block_ptr = K + qk_offset + block_offset[:, None] * d + tl.arange(0, d)[None, :]\n V_trans_block_ptr = (\n V + v_offset + block_offset[None, :] * e + tl.arange(0, e)[:, None]\n )\n\n DQ_block_ptr = DQ + qk_offset + block_offset[:, None] * d + tl.arange(0, d)[None, :]\n DK_trans_block_ptr = (\n DK + qk_offset + block_offset[None, :] * d + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = DV + v_offset + block_offset[:, None] * e + tl.arange(0, e)[None, :]\n DO_block_ptr = DO + o_offset + block_offset[:, None] * e + tl.arange(0, e)[None, :]\n\n ##### init diag decay(Lambda)\n array = tl.arange(0, BLOCK).to(tl.float32)\n # diag\n index = array[:, None] - array[None, :]\n\n ##### load block\n k = tl.load(K_block_ptr, mask=block_offset[:, None] < n, other=0.0).to(tl.float32)\n v_trans = tl.load(V_trans_block_ptr, mask=block_offset[None, :] < n, other=0.0).to(\n tl.float32\n )\n do = tl.load(DO_block_ptr, mask=block_offset[:, None] < n, other=0.0).to(tl.float32)\n q_trans = tl.load(Q_trans_block_ptr, mask=block_offset[None, :] < n, other=0.0).to(\n tl.float32\n )\n\n ##### compute\n dqk = tl.dot(do, v_trans)\n dqk = tl.where(index >= 0, dqk, 0)\n dq_intra = tl.dot(dqk, k)\n\n dk_intra_trans = tl.dot(q_trans, dqk)\n\n qk_trans = tl.dot(k, q_trans)\n qk_trans = tl.where(index <= 0, qk_trans, 0)\n dv_intra = tl.dot(qk_trans, do)\n\n dq = dq_intra\n dk_trans = dk_intra_trans\n dv = dv_intra\n\n # save\n tl.store(\n DQ_block_ptr,\n dq.to(DQ_block_ptr.dtype.element_ty),\n mask=block_offset[:, None] < n,\n )\n tl.store(\n DK_trans_block_ptr,\n dk_trans.to(DK_trans_block_ptr.dtype.element_ty),\n mask=block_offset[None, :] < n,\n )\n tl.store(\n DV_block_ptr,\n dv.to(DV_block_ptr.dtype.element_ty),\n mask=block_offset[:, None] < n,\n )\n\n\n@triton.jit\ndef _bwd_inter_kernel(\n Q,\n K,\n V,\n DO,\n DQ,\n DK,\n DV,\n b: tl.constexpr,\n h: tl.constexpr,\n n: tl.constexpr,\n d: tl.constexpr,\n e: tl.constexpr,\n BLOCK: tl.constexpr,\n NUM_BLOCK: tl.constexpr,\n CBLOCK: tl.constexpr,\n NUM_CBLOCK: tl.constexpr,\n):\n ##### get offset\n off_bh = tl.program_id(0)\n off_bh % h\n\n qk_offset = off_bh * n * d\n v_offset = off_bh * n * e\n o_offset = off_bh * n * e\n\n ##### get block ptr\n DQ_block_ptr = (\n DQ + qk_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n K_block_ptr = (\n K + qk_offset + tl.arange(0, CBLOCK)[:, None] * d + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V + v_offset + tl.arange(0, CBLOCK)[None, :] * e + tl.arange(0, e)[:, None]\n )\n DO_block_ptr = (\n DO + o_offset + tl.arange(0, CBLOCK)[:, None] * e + tl.arange(0, e)[None, :]\n )\n # mask\n off_block1 = tl.arange(0, CBLOCK)\n off_block2 = tl.arange(0, CBLOCK)\n\n ##### init lambda; kv\n kv_trans = tl.zeros([e, d], dtype=tl.float32)\n\n ##### compute dq inter\n for i in range(NUM_BLOCK):\n # compute in subblock\n for j in range(NUM_CBLOCK):\n if i > 0: # if not add this, may have bug\n do = tl.load(DO_block_ptr, mask=off_block1[:, None] < n, other=0.0).to(\n tl.float32\n )\n dq_inter = tl.dot(do, kv_trans)\n dq = dq_inter + tl.load(\n DQ_block_ptr, mask=off_block1[:, None] < n, other=0.0\n )\n tl.store(\n DQ_block_ptr,\n dq.to(DQ_block_ptr.dtype.element_ty),\n mask=off_block1[:, None] < n,\n )\n\n DQ_block_ptr += CBLOCK * d\n DO_block_ptr += CBLOCK * e\n off_block1 += CBLOCK\n\n # update kv in subblock\n kv_trans_current = tl.zeros([e, d], dtype=tl.float32)\n for j in range(NUM_CBLOCK):\n v_trans = tl.load(\n V_trans_block_ptr, mask=off_block2[None, :] < n, other=0.0\n ).to(tl.float32)\n k = tl.load(K_block_ptr, mask=off_block2[:, None] < n, other=0.0).to(\n tl.float32\n )\n kv_trans_current += tl.dot(v_trans, k)\n\n K_block_ptr += CBLOCK * d\n V_trans_block_ptr += CBLOCK * e\n off_block2 += CBLOCK\n\n kv_trans += kv_trans_current\n\n ##### get block ptr\n m = NUM_BLOCK * BLOCK\n off_block1 = m + tl.arange(0, CBLOCK)\n off_block2 = m + tl.arange(0, CBLOCK)\n\n Q_trans_block_ptr = (\n Q\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n K_block_ptr = (\n K\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[:, None] * d\n + tl.arange(0, d)[None, :]\n )\n V_trans_block_ptr = (\n V\n + v_offset\n + m * e\n + tl.arange(0, CBLOCK)[None, :] * e\n + tl.arange(0, e)[:, None]\n )\n\n DK_trans_block_ptr = (\n DK\n + qk_offset\n + m * d\n + tl.arange(0, CBLOCK)[None, :] * d\n + tl.arange(0, d)[:, None]\n )\n DV_block_ptr = (\n DV\n + v_offset\n + m * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n DO_block_ptr = (\n DO\n + o_offset\n + m * e\n + tl.arange(0, CBLOCK)[:, None] * e\n + tl.arange(0, e)[None, :]\n )\n\n ##### init dkv\n dkv = tl.zeros([d, e], dtype=tl.float32)\n\n ##### compute dk, dv inter\n for i in range(NUM_BLOCK - 1, -1, -1):\n # compute in subblock\n for j in range(NUM_CBLOCK - 1, -1, -1):\n K_block_ptr -= CBLOCK * d\n V_trans_block_ptr -= CBLOCK * e\n DK_trans_block_ptr -= CBLOCK * d\n DV_block_ptr -= CBLOCK * e\n off_block1 -= CBLOCK\n\n if i < NUM_BLOCK - 1: # if not add this, may have bug\n k = tl.load(K_block_ptr, mask=off_block1[:, None] < n, other=0.0).to(\n tl.float32\n )\n v_trans = tl.load(\n V_trans_block_ptr, mask=off_block1[None, :] < n, other=0.0\n ).to(tl.float32)\n\n dk_inter_trans = tl.dot(dkv, v_trans)\n dv_inter = tl.dot(k, dkv)\n\n dk_trans = dk_inter_trans + tl.load(\n DK_trans_block_ptr, mask=off_block1[None, :] < n, other=0.0\n )\n dv = dv_inter + tl.load(\n DV_block_ptr, mask=off_block1[:, None] < n, other=0.0\n )\n\n tl.store(\n DK_trans_block_ptr,\n dk_trans.to(DK_trans_block_ptr.dtype.element_ty),\n mask=off_block1[None, :] < n,\n )\n tl.store(\n DV_block_ptr,\n dv.to(DV_block_ptr.dtype.element_ty),\n mask=off_block1[:, None] < n,\n )\n\n # update dkv in subblock\n dkv_current = tl.zeros([d, e], dtype=tl.float32)\n for j in range(NUM_CBLOCK - 1, -1, -1):\n DO_block_ptr -= CBLOCK * e\n Q_trans_block_ptr -= CBLOCK * d\n off_block2 -= CBLOCK\n\n do = tl.load(DO_block_ptr, mask=off_block2[:, None] < n, other=0.0).to(\n tl.float32\n )\n q_trans = tl.load(\n Q_trans_block_ptr, mask=off_block2[None, :] < n, other=0.0\n ).to(tl.float32)\n dkv_current += tl.dot(q_trans, do)\n\n dkv += dkv_current\n\n\nclass LightningAttention2NoDecay(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v):\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n o = torch.empty((b, h, n, e), dtype=q.dtype, device=q.device)\n\n BLOCK = 64\n NUM_BLOCK = triton.cdiv(q.shape[2], BLOCK)\n # parallel over channel\n BLOCK_MODEL = min(triton.next_power_of_2(e), 32)\n grid = (b * h, triton.cdiv(e, BLOCK_MODEL))\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n o,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n BLOCK_MODEL=BLOCK_MODEL,\n )\n\n ctx.save_for_backward(q, k, v)\n\n return o\n\n @staticmethod\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n\n q = q.contiguous()\n k = k.contiguous()\n v = v.contiguous()\n do = do.contiguous()\n\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n\n b, h, n, d = q.shape\n e = v.shape[-1]\n\n # block size\n BLOCK = 64\n NUM_BLOCK = triton.cdiv(n, BLOCK)\n # compute block size\n CBLOCK = 32\n NUM_CBLOCK = BLOCK // CBLOCK\n\n # for intra part, compute in parallel\n grid = (b * h, NUM_BLOCK)\n _bwd_intra_kernel[grid](\n q,\n k,\n v,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n # for inter part, compute in sequencial\n grid = (b * h,)\n _bwd_inter_kernel[grid](\n q,\n k,\n v,\n do,\n dq,\n dk,\n dv,\n b,\n h,\n n,\n d,\n e,\n BLOCK=BLOCK,\n NUM_BLOCK=NUM_BLOCK,\n CBLOCK=CBLOCK,\n NUM_CBLOCK=NUM_CBLOCK,\n )\n\n return dq, dk, dv\n\n\nlightning_attn2_no_decay = LightningAttention2NoDecay.apply\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1e4c7d4bc335410e9dfaf6efee92b213", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel '_dequantize_rowwise' is designed for row-wise dequantization of the input tensor 'x'. The kernel takes pointers to the input tensor 'x_ptr', the state tensor 'state_x', the output tensor 'output_ptr', a precomputed inverse of 127, the total number of elements 'n_elements', a block size 'BLOCK_SIZE', and 'P2', which is the nearest power of two of the number of columns. Each kernel instance processes a single row identified by 'pid', the program ID in axis 0. The 'block_start' variable calculates the starting index for the current block, and 'arange' creates an index range up to 'P2'. 'offsets' computes the global memory offsets for loading the input tensor. The input values are loaded with 'tl.load' using these offsets and a mask to ensure out-of-bound accesses are ignored. The maximum value for the row is loaded from 'state_x', and each element is dequantized by multiplying with the maximum value and 'inv_127'. The results are stored back to 'output_ptr'. The Python function 'dequantize_rowwise' is a wrapper that prepares the output tensor, computes 'P2' for efficient memory alignment, and sets up the execution grid to invoke the Triton kernel. It assumes the input tensor is on CUDA, ensuring compatibility with Triton's GPU-based execution.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for rowwise dequantization\n@triton.jit\ndef _dequantize_rowwise(\n x_ptr,\n state_x,\n output_ptr,\n inv_127,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n max_val = tl.load(state_x + pid)\n output = max_val * x * inv_127\n tl.store(output_ptr + offsets, output, mask=row_mask)\n\n# Function to call the Triton kernel\ndef dequantize_rowwise(x: torch.Tensor, state_x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _dequantize_rowwise[grid](x, state_x, output, 1.0 / 127, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-63c2f84f19784171b5f4d7c7edee9418", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel function `kldivergence_kernel` performs element-wise computation of Kullback-Leibler (KL) divergence between two input vectors, `x_ptr` and `y_ptr`, storing results in `output_ptr`. It utilizes a block-wise processing approach, specified by `BLOCK_SIZE`, where each program (or kernel instance) computes over a distinct block of elements. The kernel starts by obtaining a unique program identifier (`pid`) for each execution using `tl.program_id(axis=0)`, which allows computation over a 1D grid. For each block, indices are determined by calculating `block_start = pid * BLOCK_SIZE` and creating an `offsets` tensor as `block_start + tl.arange(0, BLOCK_SIZE)`. A `mask` is used to ensure memory accesses are within bounds, protecting against out-of-bounds errors when loading from `x_ptr` and `y_ptr`. Elements of `x` and `y` within the current block are loaded using `tl.load()`, and KL divergence is computed with `output = x * tl.log(x / y)`. This result is conditionally stored in `output_ptr` using `tl.store(output_ptr + offsets, output, mask=mask)`. The function `kldivergence` serves as the Python API, enforcing GPU presence of input tensors and setting up the grid for kernel execution. It defines a grid function, `grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )`, which calculates the number of blocks needed to cover all elements. The kernel is then launched with specified block size, processing inputs `x` and `y` and returning the resulting `output`.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kldivergence_kernel(x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n # NOTE: `constexpr` so it can be used as a shape value.\n ):\n\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n mask = offsets < n_elements\n\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x * tl.log(x / y)\n\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef kldivergence(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n\n kldivergence_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-49b28b6a3b3c469db033c186caad88fe", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel is designed to perform the forward pass of a context attention mechanism, primarily used in Transformer architectures. This kernel is tailored for execution on a GPU, utilizing Triton's capabilities for efficient parallel processing.\n\n The key Triton function, '_fwd_kernel', is executed over a grid defined by batch, head, and input length dimensions. It processes multiple queries, keys, and values, computing attention scores and subsequently deriving an output tensor. This function supports batched and multi-headed attention, allowing flexibility in model architectures.\n\n Within '_fwd_kernel', queries (Q) are loaded for each block, and their dot product with keys (K) is calculated. This product is scaled by a factor derived from the head dimension, followed by the application of the softmax function to produce attention weights. These weights are then multiplied with values (V) to accumulate the output, which represents the weighted sum based on attention scores.\n\n Special attention is given to handling different sequence lengths, batching, and the multi-head structure, with each kernel instance operating independently across these dimensions. Stride parameters ensure that memory is accessed correctly based on input tensor shapes.\n\n The 'context_attention_fwd' function orchestrates this process by setting up the necessary kernel arguments, computing grid dimensions based on the input size, and selecting appropriate block sizes for optimal performance. It accounts for hardware specifics, such as different configurations for Tesla GPUs, ensuring that the kernel runs efficiently across various setups.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\nimport math\nimport torch.nn.functional as F\n\nTESLA = \"Tesla\" in torch.cuda.get_device_name(0)\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度\n Out,\n Req_to_tokens,\n B_req_idx,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n b_prompt_cache_len,\n head_dim: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n\n q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n k = tl.load(\n K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0\n )\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float(\"-100000000.0\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0)\n acc = acc * acc_scale[:, None]\n # update acc\n off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v = tl.load(\n V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0\n )\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim))\n return\n\n\n@torch.no_grad()\ndef context_attention_fwd(\n q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs\n):\n BLOCK = 128 if not TESLA else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n head_dim = Lq\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n\n sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n req_to_token_indexs,\n b_req_idx,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n kv_group_num=kv_group_num,\n b_prompt_cache_len=b_prompt_cache_len,\n head_dim=head_dim,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-cbe4869e42de4b358c617e0d7d53b70d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_swiglu_fwd_kernel` is designed to perform the Swiglu operation for the forward pass, where each element of the input `X` is multiplied by its sigmoid activation and then multiplied by the corresponding element in `Y`. The kernel is implemented with `@triton.jit`, allowing it to run on the GPU efficiently. It is autotuned over various configurations to handle different block sizes, optimizing performance based on the `ncols` parameter. The kernel uses Triton's `program_id` to determine which row and block of columns to process. The `load` function fetches data from the input tensors, and the `store` function writes the result to the `OUT` tensor. The function `_swiglu_fwd` prepares input tensors `x` and `y`, ensuring they are contiguous and splitting them from the input `xy`. It also reshapes the tensors as needed and initializes the output tensor. The grid is defined based on the number of rows (`M`) and columns (`N`) using `triton.cdiv` to calculate blocks per column. The kernel is then launched with the specified grid configuration.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Forward Triton kernel for Swish-Gated Linear Units (Swiglu)\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.jit\ndef _swiglu_fwd_kernel(\n X, Y, OUT, stride_x_row, stride_y_row, stride_out_row, ncols, BLOCK_N: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n OUT += row * stride_out_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n out = x * tl.sigmoid(x) * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n# Function to invoke the forward kernel\ndef _swiglu_fwd(xy, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_fwd_kernel[grid](x, y, out, x.stride(0), y.stride(0), out.stride(0), N)\n return out.reshape(*batch_shape, out.shape[-1])\n\n\n\n" + ] + }, + { + "id": "openseek-8-c28358f2ed45466197100c7916cc09f1", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe provided code is a Triton-based implementation for efficient matrix multiplication of two matrices `a` and `b`. It utilizes `triton.jit` to compile the `matmul_kernel`, which computes the product using block-wise operations for parallel execution on GPUs. The function `matmul` serves as a wrapper around this kernel to facilitate its usage.\n\n**matmul_kernel Function:**\n- **Inputs:** \n - Pointers to matrices `a_ptr`, `b_ptr`, and `c_ptr`.\n - Dimensions `M`, `N`, and `K`.\n - Strides for each matrix to navigate through memory.\n - Constants for blocking: `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`, `GROUP_SIZE_M`.\n- **Outputs:** \n - Writes the result to the matrix pointed by `c_ptr`.\n- **Logic:**\n - Calculates thread and block IDs for distribution of work among threads.\n - Computes the offsets for accessing blocks of `a` and `b`.\n - Accumulates dot products of int8 elements using four stages of inner loop unrolling, ensuring correctness by shifting and masking operations.\n - The final result is stored in `c`.\n\n**matmul Function:**\n- **Inputs:** \n - `a`: A 2D tensor with shape `(M, 4*K)` in `int8` type.\n - `b`: A 2D tensor with shape `(K, N)` in `uint8` type, packed for efficiency.\n- **Outputs:** \n - Returns the resulting matrix `c` of shape `(M, N)` with `int32` type.\n- **Functionality:**\n - Checks for dimensional compatibility and matrix continuity.\n - Initializes an empty tensor `c` for output.\n - Defines a lambda `grid` for determining the execution grid size based on block configurations.\n - Launches the `matmul_kernel` with the calculated grid size and required metadata.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\ndef get_autotune_config():\n return [\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 128,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 256,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 128,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 256,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 128,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 128,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 128,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 64,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 64,\n \"BLOCK_SIZE_N\": 128,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=4,\n num_warps=4,\n ),\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 32,\n \"BLOCK_SIZE_N\": 32,\n \"BLOCK_SIZE_K\": 32,\n \"GROUP_SIZE_M\": 4,\n },\n num_stages=4,\n num_warps=4,\n ),\n ]\n\n@triton.autotune(\n configs=get_autotune_config(),\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr,\n b_ptr,\n c_ptr,\n M,\n N,\n K: tl.constexpr,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_SIZE_M: tl.constexpr,\n BLOCK_SIZE_N: tl.constexpr,\n BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n tl.static_assert(\n K % (4 * BLOCK_SIZE_K) == 0,\n \"K / 4 must be divisible by BLOCK_SIZE_K => K divisible by 4*BLOCK_SIZE_K\",\n )\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)\n for i in range(4):\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n for j in range(0, tl.cdiv(K // 4, BLOCK_SIZE_K)):\n k = i * tl.cdiv(K // 4, BLOCK_SIZE_K) + j\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0).to(tl.int8) # 转换为 int8 类型\n b_uint8 = tl.load(b_ptrs, mask=offs_k[:, None] < K, other=0)\n mask = 3 << (2 * i)\n b = ((b_uint8 & mask) >> (2 * i)).to(tl.int8) # 转换为 int8 类型\n tensor_full = tl.full((1,), 1, dtype=tl.int8) # 使用 int8 类型\n accumulator += tl.dot(a, (b - tensor_full), out_dtype=tl.int32) # 保持 a 和 b 为 int8\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n c = accumulator\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\ndef matmul(a, b):\n assert (\n a.shape[1] == b.shape[0] * 4\n ), \"Incompatible dimensions, the weight matrix need to be packed\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n _, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.int32)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n matmul_kernel[grid](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n )\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-39503ef966594151afbe885a6ed8f660", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_aligned` function is a Triton JIT-compiled kernel designed to perform attention operations incorporating relative positional embeddings between query (Q), key (K), and value (V) tensors. It uses block-wise loading and processing for efficiency, leveraging blocks defined by dimensions `BLOCK_M`, `BLOCK_N`, and `BLOCK_DMODEL`. The kernel applies a scaled dot-product mechanism with a temperature scale `sm_scale`, accounting for additional bias from the B0 tensor. The forward pass of the attention mechanism iterates over the sequence dimension (N_CTX + P_SEQ), calculating softmax scores and accumulating weighted values into `acc`. The outer function `_attention_rel_h_rel_w_kernel_aligned_device` configures the execution environment, specifying block dimensions, warps, and stages. It verifies shape and type consistency of inputs and invokes the kernel using a 3D grid setup that parallels the batch, head, and sequence dimensions. The resulting output is computed in `OUT_DTYPE` format, ensuring the data type compatibility based on input precision (float16 or bfloat16).\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_aligned(\n Q, K, V, B0, sm_scale,\n Out,\n stride_qh, stride_qm, stride_qk,\n stride_kh, stride_kn, stride_kk,\n stride_vh, stride_vk, stride_vn,\n stride_oh, stride_om, stride_on,\n stride_b0h, stride_b0m,\n Z,\n H,\n N_CTX,\n P_SEQ,\n OUT_DTYPE: tl.constexpr,\n BIAS_LAST_SIZE: tl.constexpr,\n B0_NUMEL: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n q_offset = off_hz * stride_qh\n kv_offset = off_hz * stride_kh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + kv_offset,\n shape=(BLOCK_DMODEL, N_CTX + P_SEQ),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + kv_offset,\n shape=(N_CTX + P_SEQ, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(OUT_DTYPE)\n lo = 0\n hi = N_CTX + P_SEQ\n\n b_ptr_offsets_m = tl.arange(0, BLOCK_M)\n\n b_offset = off_hz * stride_b0h\n b_ptr_offsets_n_1 = (tl.arange(0, BLOCK_N) %\n BIAS_LAST_SIZE) + BIAS_LAST_SIZE\n b1 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)\n * stride_b0m)[:, None] + b_ptr_offsets_n_1[None, :])\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE)\n qk += tl.dot(q, k, out_dtype=OUT_DTYPE)\n\n b0 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)\n * stride_b0m)[:, None] + start_n // BLOCK_N)\n qk += (b0 + b1)\n\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc *= alpha[:, None]\n acc += tl.dot(p.to(OUT_DTYPE), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n\n acc = acc / l_i[:, None]\n\n O_block_ptr = tl.make_block_ptr(\n base=Out + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(O_block_ptr, acc.to(OUT_DTYPE))\n\n\ndef _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,\n BLOCK_M,\n BLOCK_N,\n num_warps,\n num_stages):\n _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]\n assert q.size() == k.size()\n assert q.size() == v.size()\n assert q.size(-2) == rel_h_w.size(-2)\n assert (q.dtype == torch.bfloat16 or q.dtype == torch.float16)\n assert k.dtype == q.dtype\n assert v.dtype == k.dtype\n assert o.dtype == v.dtype\n assert rel_h_w.dtype == q.dtype\n assert rel_h_w.size(-1) == 128\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]\n assert P_SEQ == 0\n assert rel_h_w.is_contiguous(), str(rel_h_w.stride())\n _fwd_kernel_aligned[grid](\n q, k, v,\n rel_h_w,\n sm_scale,\n o,\n q.stride(1), q.stride(2), q.stride(3),\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n o.stride(1), o.stride(2), o.stride(3),\n rel_h_w.stride(1), rel_h_w.stride(2),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n P_SEQ,\n OUT_DTYPE=tl.float16 if q.dtype == torch.float16 else tl.bfloat16,\n BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),\n B0_NUMEL=rel_h_w.size(-1),\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=num_stages)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-cadeb667637749ff98659d925c25ad58", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe 'ff_llama' Triton kernel performs a complex fused operation involving matrix multiplications and element-wise activation functions. The operation specifically computes F.silu(w1(x)) * w3(x), where 'w1' and 'w3' are linear layer weights. The kernel integrates root mean square (RMS) scaling using an auxiliary weight 'rms_w', ensuring numerical stability with an epsilon parameter 'EPS' during normalization.\n\nThe kernel's loop iteratively loads chunks of the input matrix 'x', weights 'w1', 'w3', and RMS weights. It computes the accumulated sums for two different matrix multiplications ('acc1' and 'acc2') within tiled regions defined by 'BLOCK_SIZE_M', 'BLOCK_SIZE_N', and 'BLOCK_SIZE_K'. The kernel considers whether to use FP8 precision based on weight data types.\n\nNormalization is applied using L2-norm to stabilize the results, followed by combining both accumulators with a scaled sigmoid activation to produce the final output. The output is stored conditionally, ensuring it respects the tensor dimensions.\n\nThe 'kernel_ff' function wraps this kernel, handling the reshaping and preparation of inputs. It asserts the correct types and shapes, transposes weight matrices, and sets up grid dimensions for parallel execution before invoking the Triton kernel. The grid lambda function calculates the grid size required based on input sizes and block sizes, ensuring efficient tiling for GPU execution.\n", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef ff_llama(\n a_ptr, w1_ptr, w3_ptr, out_ptr, rms_w_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_w1k, stride_w1n,\n stride_w3k, stride_w3n,\n stride_outm, stride_outn,\n stride_rms_w,\n USE_FP8: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"\n w1 and w3 are weights (linear layers)\n F.silu(w1(x)) * w3(x)\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)\n pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n w1_ptrs = w1_ptr + (offs_k[:, None] * stride_w1k + offs_bn[None, :] * stride_w1n)\n w3_ptrs = w3_ptr + (offs_k[:, None] * stride_w3k + offs_bn[None, :] * stride_w3n)\n acc1 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n acc2 = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w\n a_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs)\n a_sum += tl.extra.cuda.libdevice.pow(a.to(tl.float32), 2)\n rms_w = tl.load(rms_w_ptrs)\n if USE_FP8:\n rms_w = rms_w.to(tl.float8e5, bitcast=True)\n rms_w = rms_w.to(tl.float16)\n a = a * rms_w\n b = tl.load(w1_ptrs)\n if USE_FP8:\n b = b.to(tl.float8e5, bitcast=True)\n b = b.to(tl.float32)\n b = b.to(tl.float16)\n acc1 += tl.dot(a, b)\n c = tl.load(w3_ptrs)\n if USE_FP8:\n c = c.to(tl.float8e5, bitcast=True)\n c = c.to(tl.float32)\n c = c.to(tl.float16)\n acc2 += tl.dot(a, c)\n\n a_ptrs += BLOCK_SIZE_K * stride_ak\n w1_ptrs += BLOCK_SIZE_K * stride_w1k\n w3_ptrs += BLOCK_SIZE_K * stride_w3k\n\n rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w\n\n a_mean = tl.sum(a_sum, axis=1) / K + EPS\n a_norm = tl.math.rsqrt(a_mean)\n acc1 = acc1 * a_norm[:, None]\n acc2 = acc2 * a_norm[:, None]\n accumulator = (acc1 * tl.sigmoid(acc1)) * acc2\n\n offs_outm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_outn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n out_ptrs = out_ptr + (stride_outm * offs_outm[:, None] + stride_outn * offs_outn[None, :])\n out_mask = (offs_outm[:, None] < M) & (offs_outn[None, :] < N)\n tl.store(out_ptrs, accumulator, mask=out_mask)\n\n\ndef kernel_ff(x: torch.Tensor, w1: torch.Tensor, w3: torch.Tensor, rms_w: torch.Tensor) -> torch.Tensor:\n assert x.dtype == torch.float16\n assert w1.dtype == w3.dtype == rms_w.dtype\n assert w1.dtype in [torch.int8, torch.float16]\n assert w1.shape == w3.shape\n\n w1_t = w1.t()\n w3_t = w3.t()\n\n batch, seq_len, dim = x.shape\n M, K = batch * seq_len, dim\n\n N = w1_t.shape[1]\n assert K == w1_t.shape[0]\n assert w1_t.shape == w3_t.shape\n x_reshape = x.reshape(M, K)\n out = torch.empty((M, N), dtype=x.dtype, device=x.device)\n grid = lambda META: (triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"N\"], META[\"BLOCK_SIZE_N\"]),)\n ff_llama[grid](\n x_reshape, w1_t, w3_t, out, rms_w,\n M, N, K,\n *x_reshape.stride(),\n *w1_t.stride(),\n *w3_t.stride(),\n *out.stride(),\n *rms_w.stride(),\n USE_FP8=w1_t.dtype != torch.float16,\n EPS=1e-6,\n BLOCK_SIZE_M=16, BLOCK_SIZE_N=16, BLOCK_SIZE_K=64,\n num_stages=2, num_warps=4\n )\n out = out.view(batch, seq_len, -1)\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-7eade69cbe6b46209f7c21a02fee324a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_fwd_kernel_destindex_copy_quantize_kv` is designed for quantizing and copying key-value (KV) tensor data using specified destination indices. It processes a tensor `K` where each element is accessed using calculated offsets and destination indices from `Dest_loc`. The main operations include:\n 1. Loading source data from tensor `K` for a specific group and dimension.\n 2. Computing the absolute maximum values within each group to determine scaling factors needed to normalize the data to fit into int8 format.\n 3. Quantizing the data by dividing each element by its group's scaling factor and casting it to int8.\n 4. Storing quantized data to the `Out` tensor at positions specified by `dest_index` derived from `Dest_loc`.\n 5. Saving the scaling factors into the `Out_scale` tensor for later dequantization.\n\n The `destindex_copy_quantize_kv` function acts as a higher-level interface for this kernel, handling input reshaping and invoking the kernel. It sets up the grid size based on sequence length and head count, ensures that the head dimension is divisible by the quantization group dimension, and calls the kernel with the correct parameters. The head dimension is divided into groups for more efficient processing.\n\n Key parameters and functions used:\n - `K`: Source tensor of shape (batch, head, head_dim).\n - `DestLoc`: Destination indices tensor indicating where each sequence element's output should be written.\n - `Out`: Output tensor to store quantized data.\n - `Out_scale`: Tensor to store scaling factors for each group.\n - `group_size`, `BLOCK_GROUP_NUM`, `BLOCK_GROUP_DIM`: Parameters defining the size and dimensions of groups used in quantization.\n - `tl.load`, `tl.store`: Triton operations to load from and store into GPU memory.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K,\n Dest_loc,\n Out,\n Out_scale,\n stride_k_bs,\n stride_k_h,\n stride_k_g,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_g,\n stride_o_d,\n stride_os_bs,\n stride_os_h,\n stride_os_g,\n group_size,\n BLOCK_GROUP_NUM: tl.constexpr,\n BLOCK_GROUP_DIM: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_g = tl.arange(0, BLOCK_GROUP_NUM)\n offs_d = tl.arange(0, BLOCK_GROUP_DIM)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n src_data = tl.load(\n K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],\n mask=offs_g[:, None] < group_size,\n other=0.0,\n )\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)\n q_src_data = (src_data / data_scale[:, None]).to(tl.int8)\n\n o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g\n tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size)\n tl.store(os_ptrs, data_scale, mask=offs_g < group_size)\n return\n\n\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n quant_group_dim = 8\n\n assert head_dim % quant_group_dim == 0, \"error head dim, can not been supported to copy quant kv\"\n grid = (seq_len, head_num)\n num_warps = 1\n\n group_size = head_dim // quant_group_dim\n group_dim = quant_group_dim\n\n K = K.view((K.shape[0], K.shape[1], group_size, group_dim))\n Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim)\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K,\n DestLoc,\n Out,\n Out_scale,\n K.stride(0),\n K.stride(1),\n K.stride(2),\n K.stride(3),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n Out.stride(3),\n Out_scale.stride(0),\n Out_scale.stride(1),\n Out_scale.stride(2),\n group_size,\n BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),\n BLOCK_GROUP_DIM=group_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c1506f634e13471f8225abc222b14943", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThis Triton code provides an implementation of the RoPE (Rotary Position Embeddings) using a combination of Python and a Triton JIT-compiled kernel. It is intended to perform this computation efficiently on a GPU. The key components include:\n\n1. `calculate_settings(n)` function:\n - Purpose: Computes optimal block size and number of warps for executing the Triton kernel based on the input dimension `n`.\n - It calculates the next power of two for `n` to determine `BLOCK_SIZE` and decides the appropriate number of `num_warps` based on this block size. It raises a runtime error if the block size exceeds the maximum allowed size `MAX_FUSED_SIZE`.\n\n2. `_rope_embedding` Triton kernel:\n - Signature: `_rope_embedding(Q, Q_row_stride, cos, cos_row_stride, sin, sin_row_stride, seqlen, head_dim, n_heads, BACKWARD_PASS, BLOCK_SIZE, ROPE_GROUP_SIZE)`\n - Role: Executes the computation of the RoPE embeddings. This involves multiplying the input `Q` with `cos` and `sin` matrices, applying a rotation to half of `Q`, and adjusting based on the `BACKWARD_PASS` flag.\n - The kernel leverages the Triton language to parallelize computations across rows and groups of heads. It uses `tl.load` and `tl.store` for memory operations with specified masks to ensure only valid operations are performed.\n\n3. `_rope_embedding_forward_impl(Q, cos, sin)`:\n - Functionality: Handles the forward pass for RoPE embedding. This involves preparing data, calculating necessary settings using `calculate_settings`, and launching the Triton kernel `_rope_embedding`.\n - Prepares the data by reshaping and transposing `Q` for compatibility with the kernel's expectations. Determines the number of groups (`n_groups`) by dividing the number of heads by `ROPE_GROUP_SIZE`.\n\n4. `_rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps)`:\n - Purpose: Facilitates the backward pass (gradient computation) for the RoPE operation, useful in training scenarios.\n - Similarly reshapes and transposes `dY` to call the Triton kernel with appropriate parameters, setting `BACKWARD_PASS` to `True` to reverse the transformation logic.\n\nCommon variables:\n- `Q`: The input tensor for the forward pass.\n- `cos`, `sin`: Tensors holding the cosine and sine values used in the computation.\n- `seqlen`, `head_dim`, `n_heads`: Dimensions describing sequence length, head dimension, and number of attention heads.\n- `BLOCK_SIZE`, `num_warps`: Parameters determined through heuristics to optimize GPU execution.\n", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\nMAX_FUSED_SIZE = 65536 # 2**16\nnext_power_of_2 = triton.next_power_of_2\n\ndef calculate_settings(n):\n BLOCK_SIZE = next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n num_warps = 4\n if BLOCK_SIZE >= 32768: num_warps = 32\n elif BLOCK_SIZE >= 8192: num_warps = 16\n elif BLOCK_SIZE >= 2048: num_warps = 8\n return BLOCK_SIZE, num_warps\n\n\nROPE_GROUP_SIZE = 4\n\n@triton.heuristics({\"BACKWARD_PASS\": lambda args: args[\"BACKWARD_PASS\"],})\n@triton.jit\ndef _rope_embedding(\n Q, Q_row_stride,\n cos, cos_row_stride,\n sin, sin_row_stride,\n seqlen,\n head_dim : tl.constexpr,\n n_heads : tl.constexpr,\n BACKWARD_PASS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n ROPE_GROUP_SIZE : tl.constexpr = 4,\n):\n \"\"\"\n Calculates the RoPE Embedding quickly\n RoPE is Q * cos + rotate_half(Q) * sin\n See our blog post for more info\n \"\"\"\n row_position = tl.program_id(0)\n group_head_position = tl.program_id(1)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n half_head_dim = head_dim // 2\n mask = col_offsets < half_head_dim\n\n sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n\n if BACKWARD_PASS:\n # See our blog post for more info.\n sin1 = -sin1\n \n # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8\n head_start = group_head_position * ROPE_GROUP_SIZE\n head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)\n for k in range(head_start, head_end):\n offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n\n # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n\n tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n \n\n\ndef _rope_embedding_forward_impl(Q, cos, sin):\n Q = Q.transpose(1, 2).clone()\n cos, sin = cos.squeeze(), sin.squeeze()\n batch, seq_len, n_heads, head_dim = Q.shape\n Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = Q.shape\n assert(seq_len <= cos.shape[0])\n\n # [TODO] Changing blocksize to head_dim//2 seems to have\n # some concurrency / un-deterministic issues.\n BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n\n # group_size = 4 # 4 or 8, too large group_size can hurt performance.\n div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n n_groups = div + (mod != 0)\n\n _rope_embedding[(n_rows, n_groups, )](\n Q, Q.stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len,\n head_dim, n_heads,\n BACKWARD_PASS = False,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n Q = Q.view(batch, seq_len, n_heads, head_dim)\n Q = Q.transpose(1, 2)\n return Q, cos, sin, n_groups, BLOCK_SIZE, num_warps\n\n\ndef _rope_embedding_backward_impl(dY, cos, sin, n_groups, BLOCK_SIZE, num_warps):\n dY = dY.transpose(1, 2)\n batch, seq_len, n_heads, head_dim = dY.shape\n dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n # Must be reshape not view\n n_rows, n_cols = dY.shape\n\n _rope_embedding[(n_rows, n_groups, )](\n dY, dY .stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len, head_dim, n_heads,\n BACKWARD_PASS = True,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n dY = dY.view(batch, seq_len, n_heads, head_dim)\n dY = dY.transpose(1, 2)\n return dY\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-56744023c5884d7bab7930b62a448874", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code implements a matrix-vector multiplication kernel, `mv_kernel`, designed for efficient execution on NVIDIA GPUs. It leverages Triton's Just-In-Time (JIT) compilation and auto-tuning features. The kernel function multiplies a matrix `A` of size N x M with a vector `B` of size M, resulting in an output vector `C` of size N. The kernel is parameterized by constants `BLOCK_N` and `BLOCK_M` which determine the tile sizes for the computation, allowing for optimal parallel processing. Within `mv_kernel`, a loop iterates over the matrix A in blocks of `BLOCK_M` to load sub-matrices and the corresponding block from vector B, performing element-wise multiplication and accumulation in shared memory. The partial results are then reduced along the rows and stored in the output vector `C`. The `mv` function wraps this kernel execution, asserting dimension compatibility and managing CUDA resources for launching the kernel with calculated grid dimensions.\n ", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": m, \"BLOCK_N\": n}, num_stages=s, num_warps=w)\n for m in [32, 64, 128]\n for n in [1, 2, 4, 8]\n for s in [3, 4]\n for w in [4, 8]\n ],\n key=[\"M\", \"N\"],\n)\n@triton.jit\ndef mv_kernel(\n A,\n B,\n C,\n N,\n M,\n stride_an,\n stride_am,\n stride_bm,\n stride_cn,\n BLOCK_N: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset_n = pid * BLOCK_N + tl.arange(0, BLOCK_N)[:, None]\n offset_m = tl.arange(0, BLOCK_M)[None, :]\n n_mask = offset_n < N\n A_ptrs = A + offset_n * stride_an + offset_m * stride_am\n B_ptrs = B + offset_m * stride_bm\n acc = tl.zeros((BLOCK_N, BLOCK_M), dtype=tl.float32)\n for m in range(0, M, BLOCK_M):\n m_mask = m + offset_m < M\n a = tl.load(A_ptrs, mask=n_mask & m_mask, other=0.0).to(tl.float32)\n b = tl.load(B_ptrs, mask=m_mask, other=0.0).to(tl.float32)\n acc += a * b\n A_ptrs += BLOCK_M * stride_am\n B_ptrs += BLOCK_M * stride_bm\n\n acc = tl.sum(acc, axis=1)\n C_ptrs = C + offset_n * stride_cn\n tl.store(C_ptrs, acc[:, None], mask=n_mask)\n\n\ndef mv(inp, vec):\n logging.debug(\"GEMS MV\")\n assert inp.shape[1] == vec.shape[0], \"incompatible dimensions\"\n N, M = inp.shape\n out = torch.empty((N,), device=inp.device, dtype=inp.dtype)\n grid = lambda META: (triton.cdiv(N, META[\"BLOCK_N\"]),)\n with torch.cuda.device(inp.device):\n mv_kernel[grid](\n inp,\n vec,\n out,\n N,\n M,\n inp.stride(0),\n inp.stride(1),\n vec.stride(0),\n out.stride(0),\n )\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-340610f899f9449aa30153302236fd63", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `rms_matmul_rbe` executes a specialized matrix multiplication that incorporates RMS normalization and optionally applies rotary embeddings. The kernel takes in pointers to the input tensor `x`, the transposed weight matrix `w`, and an auxiliary RMS weight `rms_w`. The operation proceeds in a block-wise manner, defined by `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K`, iterating over batches and computing each element of the output matrix. RMS normalization is applied to input rows before matrix multiplication, by computing the sum of squares, deriving the mean, and applying the reciprocal square root to normalize. If specified, rotary embeddings adjust results based on token position and angle `THETA`. The wrapper function `rms_matmul_rbe_wrapper` sets up parameters, manages data types, and launches the Triton kernel, ensuring the input and output tensors are correctly formatted and aligned for transformer layers, considering the model's dimensions like batch size, number of attention heads, and head size.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\ntorch.manual_seed(1234)\n\n\n@triton.jit\ndef rbe_triton(x_ptr, out_ptr,\n M, K,\n stride_x_batch, stride_x_m, stride_x_n,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n THETA: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)\n pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])\n x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)\n x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)\n tl.debug_barrier()\n start_block = start_token_position + pid_m * BLOCK_SIZE_M\n cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)\n\n out_real = real * cos - imag * sin\n out_imag = real * sin + imag * cos\n tl.debug_barrier()\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :])\n out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n tl.store(out_ptrs, out_real, mask=out_real_mask)\n out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)\n\n\n@triton.jit\ndef rms_matmul_rbe(\n x_ptr, w_ptr, rms_w_ptr, out_ptr,\n M, N, K,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_w_k, stride_w_n,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n USE_FP8: tl.constexpr,\n RBE_EPILOGUE: tl.constexpr,\n THETA: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"\n Prologue: RMS\n Epilogue: nothing or Rotary embeddings\n c = ROBE((rms(a) * rms_w) @ b)\n \"\"\"\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)\n pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)\n\n offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k)\n w_ptrs = w_ptr + (offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w\n x_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n x = tl.load(x_ptrs)\n x_sum += tl.extra.cuda.libdevice.pow(x.to(tl.float32), 2)\n rms_w = tl.load(rms_w_ptrs) # TODO add an assert that rms_w is a multiple of BLOCK SIZE K\n if USE_FP8:\n rms_w = rms_w.to(tl.float8e5, bitcast=True)\n rms_w = rms_w.to(tl.float16)\n x = x * rms_w\n w = tl.load(w_ptrs) # TODO add an assert that w is a multiple of BLOCK SIZE K\n if USE_FP8:\n w = w.to(tl.float8e5, bitcast=True)\n w = w.to(tl.float32)\n w = w.to(tl.float16)\n accumulator += tl.dot(x, w)\n x_ptrs += BLOCK_SIZE_K * stride_x_k\n w_ptrs += BLOCK_SIZE_K * stride_w_k\n rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w\n x_mean = tl.sum(x_sum, axis=1) / K + EPS\n x_norm = tl.math.rsqrt(x_mean)\n accumulator = accumulator * x_norm[:, None]\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n)\n out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n\n tl.store(out_ptrs, accumulator, mask=out_mask)\n\n\ndef rms_matmul_rbe_wrapper(x: torch.Tensor, weight: torch.Tensor, rms_w: torch.Tensor, use_rbe: bool, start_pos: int,\n n_heads: int, head_dim: int):\n # 确保 weight 和 rms_w 的数据类型一致\n assert weight.dtype in [torch.float16, torch.int8], \"Only torch.float16 or torch.int8 are supported for weight\"\n \n # 确保 rms_w 和 weight 的 dtype 一致\n if rms_w.dtype != weight.dtype:\n # print(f\"rms_w dtype: {rms_w.dtype}, weight dtype: {weight.dtype}\")\n rms_w = rms_w.to(weight.dtype) # 如果类型不一致,将 rms_w 转换为与 weight 一致的类型\n\n # 创建输出张量\n batch, M, K = x.shape\n weight_t = weight.t()\n K_W, N = weight_t.shape\n assert K == K_W\n\n out = torch.empty((batch, M, N), dtype=weight_t.dtype, device=weight_t.device)\n out_ptr = triton.reinterpret(out, tl.float8e5 if out.dtype == torch.int8 else tl.float16)\n\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"N\"], META[\"BLOCK_SIZE_N\"]))\n\n rms_matmul_rbe[grid](\n x_ptr=x,\n w_ptr=weight_t, rms_w_ptr=rms_w, out_ptr=out_ptr,\n M=M, N=N, K=K,\n stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2),\n stride_w_k=weight_t.stride(0), stride_w_n=weight_t.stride(1),\n stride_rms_w=rms_w.stride(0),\n stride_out_batch=out.stride(0), stride_out_m=out.stride(1), stride_out_n=out.stride(2),\n start_token_position=start_pos,\n USE_FP8=weight_t.dtype == torch.int8,\n RBE_EPILOGUE=use_rbe,\n THETA=10000.,\n EPS=1e-6,\n BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64,\n num_stages=4, num_warps=4\n )\n out = out.view(batch, M, n_heads, head_dim)\n return out\n\n\n" + ] + }, + { + "id": "openseek-8-8def6542013448a6812c2fb195d0f7aa", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The given code implements layer normalization using Triton's just-in-time (JIT) compiler, allowing efficient GPU execution. The layer normalization operation is split into two phases: forward and backward passes, each implemented with its Triton kernel.\n\n `_layer_norm_forward_kernel` calculates layer normalization for each row of the input matrix `X`. It loads input, weight, and bias tensors, computes the mean and variance, and stores the results in `Y` after applying normalization using `W` and `B`. This kernel is called within `layer_norm_forward`, which manages input reshaping, output initialization, and kernel configuration.\n\n `_layer_norm_backward_kernel` computes gradients with respect to inputs, weights, and biases. It processes a block of rows per program instance, computes gradients iteratively, and stores results in `DX`, `DW`, and `DB`. The kernel considers the number of streaming multiprocessors available (`sm_count`) for efficient distribution of workload.\n\n `layer_norm_forward` initializes the forward process by reshaping inputs and calling the forward kernel. It checks input dimensions and creates output placeholders for `Y`, `Mean`, and `RSTD`, setting up kernel launch parameters like `BLOCK_SIZE` and `num_warps`.\n\n `layer_norm_backward` prepares and launches the backward kernel, consolidating partial derivatives computed per GPU program into `DW` and `DB` to obtain gradients w.r.t weights and biases.\n\n `LigerLayerNormFunction` integrates these operations within a PyTorch custom autograd function. It wraps `layer_norm_forward` and `layer_norm_backward` to support end-to-end differentiation. The `forward` method executes `layer_norm_forward`, saving inputs and results. The `backward` method uses stored values to compute and return gradients via `layer_norm_backward`.\n\n Relevant details:\n - `calculate_settings`: Determines appropriate `BLOCK_SIZE` and warp count for kernel execution.\n - `tl.constexpr` arguments: Used to pass constants like block size to kernels.\n - Kernel launch: Parameters like strides and dimensions are crucial for indexing and data handling within kernels.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\nimport functools\n\ndef is_hip() -> bool:\n return torch.version.hip is not None\n\ndef ensure_contiguous(fn):\n @functools.wraps(fn)\n def wrapper(ctx, *args, **kwargs):\n def maybe_to_contiguous(x):\n return x.contiguous() if isinstance(x, torch.Tensor) else x\n\n args = [maybe_to_contiguous(arg) for arg in args]\n kwargs = {k: maybe_to_contiguous(v) for k, v in kwargs.items()}\n return fn(ctx, *args, **kwargs)\n\n return wrapper\n\n\ndef calculate_settings(n):\n MAX_FUSED_SIZE = 65536\n BLOCK_SIZE = triton.next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(\n f\"Cannot launch Triton kernel since n = {n} exceeds \"\n f\"the recommended Triton blocksize = {MAX_FUSED_SIZE}.\"\n )\n\n num_warps = 4\n if BLOCK_SIZE >= 32768:\n num_warps = 32 if not is_hip() else 16\n elif BLOCK_SIZE >= 8192:\n num_warps = 16\n elif BLOCK_SIZE >= 2048:\n num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef _layer_norm_forward_kernel(\n Y_ptr, # pointer to output, shape (n_rows, n_cols)\n Y_row_stride, # stride of each row in output\n X_ptr, # pointer to input, shape (n_rows, n_cols)\n X_row_stride, # stride of each row in input\n W_ptr, # pointer to weights, shape (n_cols,)\n W_row_stride, # stride of each row in weights\n B_ptr, # pointer to bias, shape (n_cols,)\n B_row_stride, # stride of each row in bias\n Mean_ptr, # pointer to mean, shape (n_rows,)\n Mean_row_stride, # stride of each row in mean\n RSTD_ptr, # pointer to rstd, shape (n_rows,)\n RSTD_row_stride, # stride of each row in rstd\n n_cols,\n eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y_ptr += row_idx * Y_row_stride\n X_ptr += row_idx * X_row_stride\n Mean_ptr += row_idx * Mean_row_stride\n RSTD_ptr += row_idx * RSTD_row_stride\n\n X_row = tl.load(X_ptr + col_offsets, mask=mask, other=0)\n W_row = tl.load(W_ptr + col_offsets, mask=mask, other=0)\n B_row = tl.load(B_ptr + col_offsets, mask=mask, other=0)\n\n mean = tl.sum(X_row, axis=0) / n_cols\n var = tl.sum((X_row - mean) * (X_row - mean), axis=0) / n_cols\n rstd = tl.rsqrt(var + eps)\n\n tl.store(Mean_ptr, mean)\n tl.store(RSTD_ptr, rstd)\n\n Y_row = (X_row - mean) * rstd * W_row + B_row\n\n tl.store(Y_ptr + col_offsets, Y_row, mask=mask)\n\n@triton.jit\ndef _layer_norm_backward_kernel(\n X_ptr, # pointer to input, shape (n_rows, n_cols)\n W_ptr, # pointer to weights, shape (n_cols,)\n Mean_ptr, # pointer to mean, shape (n_rows,)\n RSTD_ptr, # pointer to rstd, shape (n_rows,)\n DX_ptr, # pointer to input grad, shape (n_rows, n_cols)\n DW_ptr, # pointer to weights grad, shape (n_cols,)\n DB_ptr, # pointer to bias grad, shape (n_cols,)\n DY_ptr, # pointer to output grad, shape (n_rows, n_cols)\n stride_x, # stride of each row in input\n stride_dx, # stride of each row in input grad\n stride_dw, # stride of each row in weights grad\n stride_db, # stride of each row in bias grad\n stride_dy, # stride of each row in output grad\n n_rows,\n n_cols,\n rows_per_program: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n dtype: tl.constexpr,\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n row_end = min((row_block_id + 1) * rows_per_program, n_rows)\n cols = tl.arange(0, BLOCK_SIZE)\n mask = cols < n_cols\n\n dw_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n db_row = tl.zeros((BLOCK_SIZE,), dtype=tl.float32)\n\n X_ptr += row_start * stride_x\n Mean_ptr += row_start\n RSTD_ptr += row_start\n DX_ptr += row_start * stride_dx\n DY_ptr += row_start * stride_dy\n\n for _ in range(row_start, row_end):\n x = tl.load(X_ptr + cols, mask=mask, other=0.0)\n w = tl.load(W_ptr + cols, mask=mask, other=0.0)\n dy = tl.load(DY_ptr + cols, mask=mask, other=0.0)\n mean = tl.load(Mean_ptr)\n rstd = tl.load(RSTD_ptr)\n\n x_hat = (x - mean) * rstd\n wdy = w * dy\n c1 = tl.sum(x_hat * wdy, axis=0) / n_cols\n c2 = tl.sum(wdy, axis=0) / n_cols\n dx = (wdy - (x_hat * c1 + c2)) * rstd\n tl.store(DX_ptr + cols, dx.to(dtype), mask=mask)\n\n dw_row += dy * x_hat\n db_row += dy\n\n X_ptr += stride_x\n Mean_ptr += 1\n RSTD_ptr += 1\n DX_ptr += stride_dx\n DY_ptr += stride_dy\n\n tl.store(DW_ptr + row_block_id * stride_dw + cols, dw_row.to(dtype), mask=mask)\n tl.store(DB_ptr + row_block_id * stride_db + cols, db_row.to(dtype), mask=mask)\n\ndef layer_norm_forward(X, W, B, eps):\n shape = X.shape\n dim = shape[-1]\n X = X.view(-1, dim)\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)\n Mean = torch.empty(n_rows, dtype=X.dtype, device=X.device)\n RSTD = torch.empty(n_rows, dtype=X.dtype, device=X.device)\n assert (\n X.shape[1] == W.shape[0]\n ), f\"Incompatible hidden size dimension between input tensor with shape[1] = {X.shape[1]} and weight tensor with shape[0] = {W.shape[0]}\"\n\n _layer_norm_forward_kernel[(n_rows,)](\n Y,\n Y.stride(0),\n X,\n X.stride(0),\n W,\n W.stride(0),\n B,\n B.stride(0),\n Mean,\n Mean.stride(0),\n RSTD,\n RSTD.stride(0),\n n_cols,\n eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return Y.view(*shape), X, Mean, RSTD, BLOCK_SIZE, num_warps\n\ndef layer_norm_backward(dY, X, W, B, Mean, RSTD):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.view(-1, dim)\n n_rows, n_cols = dY.shape\n\n DX = torch.empty((n_rows, n_cols), dtype=X.dtype, device=X.device)\n sm_count = torch.cuda.get_device_properties(X.device).multi_processor_count\n _DW = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)\n _DB = torch.empty((sm_count, n_cols), dtype=W.dtype, device=W.device)\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n if n_cols > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n\n rows_per_program = math.ceil(n_rows / sm_count)\n grid = (sm_count,)\n triton_dtype = tl.float32 if X.dtype == torch.float32 else tl.bfloat16\n _layer_norm_backward_kernel[grid](\n X,\n W,\n Mean,\n RSTD,\n DX,\n _DW,\n _DB,\n dY,\n X.stride(0),\n DX.stride(0),\n _DW.stride(0),\n _DB.stride(0),\n dY.stride(0),\n n_rows,\n n_cols,\n rows_per_program,\n BLOCK_SIZE=BLOCK_SIZE,\n dtype=triton_dtype,\n )\n\n DW = _DW.sum(dim=0).to(W.dtype)\n DB = _DB.sum(dim=0).to(W.dtype)\n\n DX = DX.view(*shape)\n return DX, DW, DB\n\nclass LigerLayerNormFunction(torch.autograd.Function):\n @staticmethod\n @ensure_contiguous\n def forward(ctx, X, W, B, eps):\n Y, X, Mean, RSTD, BLOCK_SIZE, num_warps = layer_norm_forward(X, W, B, eps)\n ctx.save_for_backward(X, W, B, Mean, RSTD)\n return Y\n\n @staticmethod\n @ensure_contiguous\n def backward(ctx, dY):\n X, W, B, Mean, RSTD = ctx.saved_tensors\n DX, DW, DB = layer_norm_backward(dY, X, W, B, Mean, RSTD)\n return DX, DW, DB, None\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-a32285e44de1481ea5bf7d717cdd9d3d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_fwd_kernel` is designed to execute a context attention mechanism, essential in transformer architectures. \n It accepts tensors Q, K, and V representing queries, keys, and values, respectively, alongside several metadata and stride parameters to account for complex batching and sequence operations. \n The kernel handles operations for attention computation: it initially calculates dot products between Q and K, scales these products by `sm_scale`, and applies an exponential function to obtain softmax values. \n The calculated attention scores are then utilized to weight the values in V, and the results are accumulated and stored in the `Out` tensor. \n This process involves masking techniques to handle variable sequence lengths and prompt caching, ensuring causal ordering. \n The kernel leverages Triton's parallel execution capabilities through grid and block configurations determined by inputs such as `BLOCK_M`, `BLOCK_N`, and `kv_group_num`.\n\n The `context_attention_fwd` function orchestrates the execution of the Triton kernel by setting up the computational environment. \n It calculates the appropriate block size `BLOCK_M` based on whether the Tesla architecture is in use, establishes shape constraints, and determines the scaling factor `sm_scale` using the inverse square root of the query dimension `Lq`, adjusted by a constant for compatibility with Triton's exponential function. \n The function configures execution parameters, including batch and head dimensions, through a grid lambda function, and invokes the Triton kernel with all required inputs and meta parameters, such as tensor strides, kv head grouping, and block dimensionality. \n This setup ensures efficient parallel attention computation across potentially large input sequences.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\nTESLA = \"Tesla\" in torch.cuda.get_device_name(0)\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n Out,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens,\n B_req_idx,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n kv_group_num,\n b_prompt_cache_len,\n H: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n cur_bh = tl.program_id(1)\n cur_batch = cur_bh // H\n cur_head = cur_bh % H\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = block_start_loc + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n\n # causal mask\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n kv_loc = tl.load(\n Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n),\n mask=(start_n + offs_n) < block_end_loc,\n other=0,\n )\n off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n qk = tl.dot(q, k)\n\n mask = offs_m[:, None] + prompt_cache_len >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk * sm_scale, -1.0e8)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n\n # -- update m_i and l_i\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0)\n p = p.to(v.dtype)\n acc = tl.dot(p, v, acc)\n # update m_i and l_i\n m_i = m_ij\n\n acc = acc / l_i[:, None]\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n\n@torch.no_grad()\ndef context_attention_fwd(\n q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs\n):\n BLOCK_M = 128 if not TESLA else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128, 256}\n\n # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634,\n # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。\n sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = lambda meta: (triton.cdiv(max_input_len, meta[\"BLOCK_M\"]), batch * head, 1)\n\n BLOCK_N = BLOCK_M\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n b_start_loc,\n b_seq_len,\n req_to_token_indexs,\n b_req_idx,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n req_to_token_indexs.stride(0),\n req_to_token_indexs.stride(1),\n kv_group_num=kv_group_num,\n b_prompt_cache_len=b_prompt_cache_len,\n H=head,\n BLOCK_DMODEL=Lk,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0e68e01a209e4adc92ccfab00d418fb5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This script provides a highly optimized cross-entropy loss function using Triton for efficient GPU execution, particularly suitable for large-scale distributed machine learning tasks. It introduces two main kernels, `cross_entropy_fwd_kernel` and `cross_entropy_bwd_kernel`, written in Triton language that enable high-performance computation with fine-grained control over execution.\n\n The forward kernel, `cross_entropy_fwd_kernel`, calculates the LSE and the smoothed cross-entropy loss. It uses block-level parallelism and supports label smoothing, where a proportion of the true label probability is distributed among all classes. The kernel uses Triton’s `program_id` to identify the executing thread block and manage access to specific data partitions. It includes heuristics to determine whether label smoothing is applied and whether the dataset is split across multiple devices (useful for tensor parallelism).\n\n The backward kernel, `cross_entropy_bwd_kernel`, computes the gradient of the cross-entropy loss with respect to the logits. This involves calculating the gradient of LSE and adjusting probabilities based on the presence of label smoothing. It uses similar block-level indexing to parallelize this computation efficiently across GPU cores.\n\n The `CrossEntropyLoss` class encapsulates these kernels within a PyTorch autograd function. Its static `forward` method sets up the context, computes losses, manages optional distributed operations, and saves necessary tensors for the backward pass. The `backward` method applies the backward kernel to compute gradients.\n\n The auxiliary function `cross_entropy_loss` is a user-friendly wrapper around the `CrossEntropyLoss.apply` method. It handles parameter passing for common usage patterns, such as enabling/disabling label smoothing and configuring tensor parallelism.\n\n Parameters for these functions include:\n - `logits`: a 2D tensor containing model predictions before softmax.\n - `labels`: a 1D tensor with the actual class labels.\n - `smoothing`: a float controlling label smoothing intensity.\n - `lse_square_scale`: controls LSE regularization.\n - `ignored_index`: specifies label indices to ignore in loss computation.\n - `process_group`: defines the communication group for distributed settings.\n\n This module ensures efficient computation and gradient propagation in neural networks, especially when working with large vocabularies or extensive multi-GPU setups.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n # if SPLIT (e.g. tensor parallel), don't include the LSE in the loss since it's not the final LSE\n SPLIT: tl.constexpr,\n):\n # Triton kernel implementation for the forward pass of cross-entropy with label smoothing.\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(tl.float32)\n max_logits = tl.max(logits, 0)\n if HAS_SMOOTHING:\n sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)\n lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits\n tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)\n if label_idx == ignored_index:\n loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(n_cols, (col_block_idx + 1) * BLOCK_SIZE):\n logits_label = tl.load(logits_ptr + label_idx)\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n loss += lse_square_scale * lse * lse\n tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n # Triton kernel implementation for the backward pass of cross-entropy with label smoothing.\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignored_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(tl.float32)\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, dloss * probs, mask=col_offsets < n_cols)\n\nclass CrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n logits,\n labels,\n smoothing,\n lse_square_scale=0.0,\n ignored_index=-100,\n inplace_backward=False,\n process_group=None,\n ):\n # CrossEntropyLoss forward function leveraging the Triton kernel.\n n_rows, n_cols = logits.shape\n assert labels.shape == (n_rows,)\n world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n total_classes = world_size * n_cols\n rank = 0 if process_group is None else torch.distributed.get_rank(process_group)\n class_start_idx = rank * n_cols\n\n if logits.stride(-1) != 1:\n logits = logits.contiguous()\n MAX_BLOCK_SIZE = 64 * 1024\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))\n split = world_size > 1 or n_cols > MAX_BLOCK_SIZE\n n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE\n loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)\n losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n with torch.cuda.device(logits.device.index):\n cross_entropy_fwd_kernel[(n_rows, n_splits)](\n losses, # data ptrs\n lse,\n logits,\n labels,\n smoothing,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx,\n n_cols, # shapes\n n_rows,\n logits.stride(0), # strides\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n SPLIT=split,\n )\n\n if split:\n if world_size > 1:\n lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)\n torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)\n handle_losses = torch.distributed.all_reduce(\n losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n )\n lse = torch.logsumexp(lse_allgather, dim=0)\n handle_losses.wait()\n else:\n lse = torch.logsumexp(lse, dim=0)\n losses = losses.sum(dim=0)\n losses += lse\n if lse_square_scale != 0.0:\n losses += lse_square_scale * lse.square()\n losses.masked_fill_(labels == ignored_index, 0.0)\n\n ctx.save_for_backward(logits, lse, labels)\n ctx.smoothing = smoothing\n ctx.lse_square_scale = lse_square_scale\n ctx.ignored_index = ignored_index\n ctx.total_classes = total_classes\n ctx.class_start_idx = class_start_idx\n ctx.inplace_backward = inplace_backward\n return losses\n\n @staticmethod\n def backward(ctx, grad_losses):\n logits, lse, labels = ctx.saved_tensors\n dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n n_rows, n_cols = logits.shape\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n grid = lambda META: (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"])) # noqa\n with torch.cuda.device(logits.device.index):\n cross_entropy_bwd_kernel[grid](\n dlogits, # data ptrs\n grad_losses,\n logits,\n lse,\n labels,\n ctx.smoothing,\n ctx.lse_square_scale,\n ctx.ignored_index,\n ctx.total_classes,\n ctx.class_start_idx,\n n_cols, # shapes\n logits.stride(0), # strides\n dlogits.stride(0),\n grad_losses.stride(0),\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n )\n return dlogits, None, None, None, None, None, None, None\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n label_smoothing: float = 0.0,\n lse_square_scale: float = 0.0,\n ignored_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n \"\"\"\n CrossEntropyLoss wrapper function for calling the custom autograd Function.\n \"\"\"\n return CrossEntropyLoss.apply(\n logits,\n labels,\n label_smoothing,\n lse_square_scale,\n ignored_index,\n inplace_backward,\n process_group,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-8225b27a77f54908ba3f22cc3ed807d5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code snippet implements a masked select operation using Triton to enable high-performance parallel computations on GPUs. The core logic is encapsulated in a Triton JIT-compiled kernel `masked_select_kernel`, which processes data in parallel across many threads organized in blocks. The kernel takes pointers to input data (`inp_ptr`), a selection mask (`select_mask_ptr`), a prefix sum array (`prefix_sum_ptr`), and an output buffer (`out_ptr`). Each thread computes its global ID using `tl.program_id` and `tl.arange`, checking if it's within the bounds of `n_elements`. It loads the input and mask values, converts the mask to a Boolean type `tl.int1`, and calculates output positions using a prefix sum. The selected input values are stored in the output buffer where both the mask and bounds conditions are met. The `masked_select` function wraps the kernel invocation, performing necessary pre-processing steps such as broadcasting tensors to compatible shapes, flattening the mask, and computing prefix sums for determining output positions. The kernel's grid size is dynamically computed based on the number of elements and the block size (`BLOCK_SIZE`). The `cfggen` function generates various Triton configurations to optimize execution by testing different block sizes and warp numbers, enhancing performance through autotuning. This ensures efficient parallel computation by exploring multiple execution configurations. The `broadcastable` function provides a utility to ensure that two tensor shapes can be broadcasted, adhering to broadcasting rules essential for element-wise operations in multi-dimensional arrays.\n ", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n\nShape = Tuple[int]\n\ndef broadcastable(s1: Shape, s2: Shape) -> bool:\n r1 = len(s1)\n if r1 == 0:\n return True\n r2 = len(s2)\n if r2 == 0:\n return True\n\n s1, s2 = (s1, s2) if r1 >= r2 else (s2, s1)\n r1, r2 = (r1, r2) if r1 >= r2 else (r2, r1)\n\n d = r1 - r2\n for i in range(r2):\n if s1[d + i] == 1 or s2[i] == 1 or s1[d + i] == s2[i]:\n continue\n return False\n return True\n\n\ndef cfggen():\n configs = [\n triton.Config({\"BLOCK_SIZE\": bs}, num_warps=w)\n for w in [4, 8, 16, 32]\n for bs in [256, 512, 1024, 2048, 4096]\n ]\n return configs\n\n\n\n@triton.autotune(configs=cfggen(), key=[\"n_elements\"])\n@triton.jit\ndef masked_select_kernel(\n inp_ptr,\n select_mask_ptr,\n prefix_sum_ptr,\n out_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n\n inp = tl.load(inp_ptr + offsets, mask=mask, other=0.0)\n select_mask = tl.load(select_mask_ptr + offsets, mask=mask, other=0.0).to(tl.int1)\n out_offset = tl.load(prefix_sum_ptr + offsets, mask=mask, other=0.0) - 1\n\n tl.store(out_ptr + out_offset, inp, mask=(select_mask and mask))\n\n\ndef masked_select(inp, mask):\n logging.debug(\"GEMS MASKED SELECT\")\n\n inp_shape = tuple(inp.shape)\n mask_shape = tuple(mask.shape)\n\n assert broadcastable(\n inp_shape, mask_shape\n ), \"The shapes of the `mask` and the `input` tensor must be broadcastable\"\n inp, mask = torch.broadcast_tensors(inp, mask)\n\n inp = inp.contiguous()\n mask = mask.contiguous()\n\n mask_flattened = mask.ravel()\n\n prefix_sum = mask_flattened.cumsum(axis=0)\n out = torch.empty(prefix_sum[-1].item(), dtype=inp.dtype, device=inp.device)\n\n n_elements = inp.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n with torch.cuda.device(inp.device):\n masked_select_kernel[grid](inp, mask_flattened, prefix_sum, out, n_elements)\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-2cc0099f8270412983a42f8dd3762748", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel `_fwd_kernel` implements a scalable sliding window mechanism for performing attention operations in a Transformer architecture. It takes tensors `Q`, `K`, and `V` as inputs along with other parameters like `sm_scale` (to scale query-key scores), and additional metadata (`B_Start_Loc`, `B_Seqlen`) for handling batch sequences efficiently.\n\nKey aspects of `_fwd_kernel`:\n- It operates on a grid defined by `(batch, head, sequence blocks)`, processed in parallel using Triton.\n- `BLOCK_M`, `BLOCK_DMODEL`, and `BLOCK_N` are compile-time constants defining the dimensions of processing blocks.\n- The algorithm iteratively processes `K` and `V` blocks, computes attention scores (`qk`) with a sliding window constraint to focus attention within a localized sequence window.\n- Uses dynamic memory access patterns with offsets to load and store data efficiently.\n- Employs numerically stable operations for computing exponential weights, preventing underflows/overflows.\n- Constructs final output via scaling and accumulating partial results, stored to the output tensor `Out`.\n\nThe `context_attention_fwd` function serves as the interface, defining the grid dimensions, ensuring correct input data preparation, and invoking the Triton kernel with proper execution parameters for optimized performance.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\nimport math\nimport torch.nn.functional as F\n\n\n@triton.jit\ndef _fwd_kernel(\n Q,\n K,\n V,\n sm_scale,\n B_Start_Loc,\n B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度\n Out,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n sliding_window,\n BLOCK_M: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_kv_head = cur_head // kv_group_num\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(\n k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len,\n other=0.0,\n )\n \n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n # [SYM] mask outside of windows,使用大负数代替 -inf\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, -1e9)\n qk = tl.where((start_n + offs_n[None, :]) > (offs_m[:, None] - sliding_window), qk, -1e9)\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n # 防止 m_ij 为 -1e9 导致的数值问题\n m_ij = tl.where(m_ij == -1e9, 0.0, m_ij)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n \n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n l_i_new = tl.where(l_i_new == 0.0, 1e-9, l_i_new) # 防止除零\n \n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(\n v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len,\n other=0.0,\n )\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n\n@torch.no_grad()\ndef context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len, sliding_window):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q,\n k,\n v,\n sm_scale,\n b_start_loc,\n b_seq_len,\n o,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n kv_group_num=kv_group_num,\n sliding_window=sliding_window,\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-071752c1682c40839cae5398ee7c6bbc", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-based kernel performs a forward pass of a blockwise attention mechanism, specifically designed to handle the query (Q), key (K), and value (V) matrices in blocks for efficient computation on GPUs.\n\n - Function `_attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, start_m, BLOCK_M, HEAD_DIM, BLOCK_N, STAGE, offs_m, offs_n, N_CTX)`:\n This function executes the core attention mechanism in two main stages depending on the STAGE parameter. In STAGE 1, it initializes by loading key vectors (k) and their scaling factors (k_scale). It computes the dot product of Q and K, applies scaling, and manages masking for handling causal attention. In both stages, it uses exponential operations to perform the softmax calculation in a numerically stable manner by leveraging maximum value adjustment (m_ij). It accumulates the weighted value vectors to compute the attention output.\n\n - Function `_attn_fwd(Q, K, V, Q_scale, K_scale, Out, stride_qz, stride_qh, stride_qm, stride_qk, stride_kz, stride_kh, stride_kn, stride_kk, stride_vz, stride_vh, stride_vk, stride_vn, stride_oz, stride_oh, stride_om, stride_on, Z, H, N_CTX, HEAD_DIM, BLOCK_M, BLOCK_N, STAGE)`:\n This function sets up the parameters for block processing and handles the execution logic of the forward pass by allocating memory pointers, computing required indices, and setting up the grid for parallel execution using the Triton programming model. It computes linear indices for each block based on offsets calculated from tensor strides, invokes `_attn_fwd_inner` to perform blockwise computation of attention scores, and stores the results back into the output tensor.\n\n - Function `forward(q, k, v, q_scale, k_scale)`:\n This acts as a wrapper function to manage inputs and invoke the Triton kernel `_attn_fwd`. It determines the dimensions of the input tensors, configures the grid for kernel launch, and calls `_attn_fwd` with the appropriate settings. It ensures that inputs are appropriately batched and shaped to match the expected dimensions, manages output allocation, and performs assertions to guarantee consistency across input feature dimensions.\n\n The kernel is optimized for parallel execution by leveraging Triton’s capabilities to handle complex memory access patterns and perform efficient tensor computations in a blockwise manner. Key parameters such as BLOCK_M and BLOCK_N determine the block size, while HEAD_DIM specifies the feature dimensionality in attention computation. The entire setup enables efficient processing of sequences within attention mechanisms typically used in transformer models.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, q_scale, \n K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, \n N_CTX: tl.constexpr):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n K_scale_ptr += lo // BLOCK_N\n K_ptrs += HEAD_DIM * lo\n V_ptrs += HEAD_DIM * lo\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k_mask = (offs_n[None, :] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[:, None])\n k = tl.load(K_ptrs, mask = k_mask)\n k_scale = tl.load(K_scale_ptr)\n qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_ptrs, mask = (offs_n[:, None] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[None, :]))\n p = p.to(tl.float16)\n acc += tl.dot(p, v, out_dtype=tl.float16) \n m_i = m_ij\n K_ptrs += BLOCK_N * HEAD_DIM\n K_scale_ptr += 1\n V_ptrs += BLOCK_N * HEAD_DIM\n return acc, l_i, m_i\n\n@triton.jit\ndef _attn_fwd(Q, K, V, Q_scale, K_scale, Out, \n stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vk, stride_vn, \n stride_oz, stride_oh, stride_om, stride_on, \n Z, H, N_CTX, \n HEAD_DIM: tl.constexpr, \n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr \n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n \n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n vk_offset = qvk_offset // stride_qm\n q_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_M)\n k_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_N) \n \n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, 128)\n Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n Q_scale_ptr = Q_scale + q_scale_offset + start_m\n K_ptrs = K + qvk_offset + offs_k[:, None] + offs_n[None, :] * stride_kn\n K_scale_ptr = K_scale + k_scale_offset\n V_ptrs = V + qvk_offset + offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk\n O_block_ptr = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, 128], dtype=tl.float32)\n q = tl.load(Q_ptrs, mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n q_scale = tl.load(Q_scale_ptr)\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 4 - STAGE, offs_m, offs_n, N_CTX \n )\n acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 2, offs_m, offs_n, N_CTX \n )\n acc = acc / l_i[:, None]\n tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n\ndef forward(q, k, v, q_scale, k_scale):\n BLOCK_M = 128\n BLOCK_N = 64\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n o = torch.empty_like(q, dtype=torch.float16)\n stage = 3\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, q_scale, k_scale, o, \n q.stride(0), q.stride(1), q.stride(2), q.stride(3), \n k.stride(0), k.stride(1), k.stride(2), k.stride(3), \n v.stride(0), v.stride(1), v.stride(2), v.stride(3), \n o.stride(0), o.stride(1), o.stride(2), o.stride(3), \n q.shape[0], q.shape[1], \n N_CTX=q.shape[2], \n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, \n STAGE=stage, \n num_warps=8, \n num_stages=4)\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-a7208a205882433db6034019b7ceef97", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_bgmv_shrink_kernel` is designed for executing a batched generalized matrix-vector multiplication (BGMV) using low-rank adaptation (LORA). It is optimized for GPU execution using Triton's just-in-time (JIT) compilation.\n\n The kernel operates as follows:\n - It processes multiple batches of input data, with each batch associated with a possibly different LORA matrix indicated by `lora_indices`. If a batch's index in `lora_indices` is -1, it is skipped.\n - The computation involves splitting the `K` dimension into manageable blocks of size `BLOCK_K`. `SPLIT_K` determines how many such blocks are processed by each kernel instance.\n - `input_ptr` and `lora_ptr` provide the starting addresses of input data and LORA matrices, respectively. Strides define how elements in these tensors are laid out in memory, facilitating efficient data loading.\n - The kernel initializes a zeroed accumulator for each block and iteratively loads blocks of data from the input tensor and LORA matrix. It performs element-wise multiplication and reduction, summing the results along the `K` dimension.\n - The `scaling` factor is applied to the accumulated result before storing it back in the `out_ptr` location, with atomic addition used if reductions across multiple kernel instances are required (`SPLIT_K` > 1).\n\n The wrapper function `_bgmv_shrink` is responsible for configuring and launching the kernel:\n - It ensures the input, LORA weight matrix, and output tensor are contiguous in memory to enhance access speed and reduces risk of memory-bound operations.\n - It extracts the batch count and the dimensions `N` and `K` from the LORA weight tensor.\n - `BLOCK_N` is computed as the next power of two greater than or equal to `N` to ensure optimal performance on the hardware.\n - A grid of threads is configured to handle the number of batches and the split of the `K` dimension, ensuring efficient parallel execution.\n - Finally, the kernel is invoked with the prepared arguments and configuration.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n# from .utils import get_lora_op_configs\n\n@triton.jit\ndef _bgmv_shrink_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n scaling,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_K: tl.constexpr,\n):\n pid_sk = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n\n offset_n = tl.arange(0, BLOCK_N)\n offset_k = tl.arange(0, BLOCK_K) + pid_sk * BLOCK_K\n a_ptr = input_ptr + cur_batch * xm_stride\n b_ptr = lora_ptr + l0_stride * lora_index\n accumulator = tl.zeros((BLOCK_N, ), dtype=tl.float32)\n for k in range(0, K, BLOCK_K * SPLIT_K):\n current_k = k + offset_k\n current_k_c = tl.max_contiguous(current_k, BLOCK_K)\n tiled_a = tl.load(\n a_ptr + current_k_c,\n mask=current_k < K,\n other=0.0,\n )\n b_ptr_mask = (offset_n[:, None] < N) & (current_k[None, :] < K)\n\n tiled_b = tl.load(\n b_ptr + offset_n[:, None] * lora_k_stride +\n current_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n )\n\n accumulator += tl.sum(tiled_a * tiled_b, 1)\n accumulator *= scaling\n offset_cn = tl.arange(0, BLOCK_N)\n c_ptr = out_ptr + cur_batch * cm_stride + offset_cn * cn_stride\n c_mask = offset_cn < N\n if SPLIT_K == 1:\n tl.store(c_ptr, accumulator, mask=c_mask)\n else:\n tl.atomic_add(c_ptr, accumulator, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _bgmv_shrink(\n inputs: torch.Tensor,\n lora_a_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n scaling: float = 1.0,\n) -> None:\n assert inputs.dtype == lora_a_weights.dtype\n assert inputs.dtype in [torch.float16, torch.bfloat16]\n assert lora_a_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_a_weights.size(-1)\n assert inputs.is_contiguous()\n\n if lora_a_weights.ndim == 4:\n assert lora_a_weights.size(1) == 1\n lora_a_weights = lora_a_weights.squeeze(dim=1)\n else:\n assert lora_a_weights.ndim == 3\n assert lora_a_weights.is_contiguous()\n assert output_tensor.is_contiguous()\n\n batches = lora_indices_tensor.size(0)\n N, K = lora_a_weights.shape[-2:]\n BLOCK_N = triton.next_power_of_2(N)\n # config = get_lora_op_configs(\"bgmv_shrink\", batches, K)\n\n grid = lambda META: (\n META[\"SPLIT_K\"],\n batches,\n )\n _bgmv_shrink_kernel[grid](\n inputs,\n lora_a_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n scaling,\n inputs.stride(0),\n inputs.stride(1),\n lora_a_weights.stride(0),\n lora_a_weights.stride(1),\n lora_a_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_N=BLOCK_N,\n BLOCK_K=256,\n SPLIT_K=64,\n\n # **config,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-46a469ca02a6461f982c2db2cba0b711", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code implements a custom Triton kernel for calculating softmax probabilities from logits in token sequences. The kernel function '_fwd_kernel_token_softmax' is designed to process each batch and head independently, leveraging Triton's just-in-time (JIT) compilation capabilities. Inputs include 'Logics', a tensor representing the logits; 'B_Start_Loc', indicating the starting indices of each sequence in the batch; 'B_Seqlen', giving the length of each sequence; and 'Prob_Out', an output tensor where the computed probabilities are stored.\n\n The kernel works by first determining the current batch and head using 'tl.program_id', then loading the relevant portion of the logits using these indices. It uses the 'BLOCK_SIZE' to handle data in chunks, ensuring it respects sequence boundaries with a mask based on 'B_Seqlen'. Softmax computation is done in a numerically stable way by subtracting the max value from each element in the row before exponentiation, followed by a normalization step with the sum of the exponentials.\n\n The wrapper function 'token_softmax_fwd' is responsible for setting up the kernel launch. It calculates an appropriate 'BLOCK_SIZE' based on 'max_input_len', adapts the number of warps used based on 'BLOCK_SIZE' to optimize performance, and orchestrates the batch and head dimension handling by launching the Triton kernel with these parameters. This function does not compute gradients, indicated by '@torch.no_grad()', making it suitable for inference tasks where gradients are not needed.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for forward token softmax\n@triton.jit\ndef _fwd_kernel_token_softmax(\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n stride_logic_h, stride_logic_bs,\n stride_prob_h, stride_prob_bs,\n BLOCK_SIZE: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,\n mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32)\n\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n\n tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets)\n * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len)\n return\n\n# Function to launch the Triton kernel\n@torch.no_grad()\ndef token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):\n BLOCK_SIZE = triton.next_power_of_2(max_input_len)\n batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]\n\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n\n _fwd_kernel_token_softmax[(batch, head_num)](\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n Logics.stride(0), Logics.stride(1),\n Prob_Out.stride(0), Prob_Out.stride(1),\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-9814bac70f50489a81ad75f65e52a2ee", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `matmul_kernel`, is a specialized GPU matrix multiplication operation. \n It employs a blocked tiling strategy for efficient computation of the result matrix `c` from input matrices `a` and `b`. \n Within this kernel, operations are parallelized across blocks defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. \n These blocks allow the kernel to load sub-matrices, perform computations, and manage memory more efficiently.\n\n The kernel begins by computing indices for thread execution, segmenting the operation across various program IDs derived from the grid dimensions. \n For each thread block, it computes offsets `offs_am`, `offs_bn`, and `offs_k` to read data from the input matrices.\n\n In a loop iterating over slices of the K dimension, sub-matrices are loaded using `tl.load` with masks to handle boundary conditions. \n These matrices are then multiplied using `tl.dot`, accumulating results in a local accumulator. \n Memory access patterns are optimized using `tl.max_contiguous` and `tl.multiple_of` to align data in cache-friendly ways.\n\n The function finally writes the accumulated results to the output matrix `c`, with care taken to respect bounds and using conditional storage via `tl.store`.\n\n The `matmul` function wraps this kernel, preparing inputs and meta-parameters based on the matrix data types and dimensions. \n It enforces input compatibility, establishes execution grid dimensions, and sets device memory for output. \n Configuration parameters such as BLOCK_SIZE_M, num_stages, and num_warps are determined per data type, \n ensuring optimal kernel execution tailored for either float16 or Triton's experimental float8 types.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef _matmul_launch_metadata(grid, kernel, args):\n ret = {}\n M, N, K = args[\"M\"], args[\"N\"], args[\"K\"]\n ret[\"name\"] = f\"{kernel.name} [M={M}, N={N}, K={K}]\"\n if \"c_ptr\" in args:\n bytes_per_elem = args[\"c_ptr\"].element_size()\n else:\n bytes_per_elem = 1 if args[\"FP8_OUTPUT\"] else 2\n ret[f\"flops{bytes_per_elem * 8}\"] = 2. * M * N * K\n ret[\"bytes\"] = bytes_per_elem * (M * K + N * K + M * N)\n return ret\n\n\n@triton.jit(launch_metadata=_matmul_launch_metadata)\ndef matmul_kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, #\n BLOCK_SIZE_N: tl.constexpr, #\n BLOCK_SIZE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, #\n ):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n start_m = pid_m * BLOCK_SIZE_M\n start_n = pid_n * BLOCK_SIZE_N\n\n offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n offs_am = tl.where(offs_am < M, offs_am, 0)\n offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n\n offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if (c_ptr.dtype.element_ty == tl.float8e4nv):\n c = accumulator.to(tl.float8e4nv)\n else:\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef matmul(a, b):\n configs = {\n torch.float8_e4m3fn: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 128, \"GROUP_SIZE_M\": 8, \"num_stages\": 4,\n \"num_warps\": 8\n }, torch.float16: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64, \"GROUP_SIZE_M\": 8, \"num_stages\": 3,\n \"num_warps\": 8\n }\n }\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.dtype == b.dtype, \"Incompatible dtypes\"\n M, K = a.shape\n K, N = b.shape\n dtype = a.dtype\n\n c = torch.empty((M, N), device=a.device, dtype=dtype)\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]), )\n matmul_kernel[grid](\n a, b, c, #\n M, N, K, #\n a.stride(0), a.stride(1), #\n b.stride(0), b.stride(1), #\n c.stride(0), c.stride(1), #\n BLOCK_SIZE_M=configs[dtype][\"BLOCK_SIZE_M\"], #\n BLOCK_SIZE_N=configs[dtype][\"BLOCK_SIZE_N\"], #\n BLOCK_SIZE_K=configs[dtype][\"BLOCK_SIZE_K\"], #\n GROUP_SIZE_M=configs[dtype][\"GROUP_SIZE_M\"], #\n num_stages=configs[dtype][\"num_stages\"], #\n num_warps=configs[dtype][\"num_warps\"], #\n )\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0a96c628c587431284338d9ffd76deed", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel `bmm_kernel` is designed to perform batched matrix multiplication on input tensors A and B, producing an output tensor O. This kernel is optimized for GPU execution, utilizing tiling and parallel processing.\n The key parameters include:\n - `A`, `B`: Input tensors with shapes (batch, M, K) and (batch, K, N). These tensors are assumed to be contiguous in memory.\n - `O`: Output tensor with shape (batch, M, N) where the result of the multiplication is stored.\n - `M`, `N`, `K`: Dimensions representing the sizes of the matrices to be multiplied.\n - `TILE_M`, `TILE_N`, `TILE_K`: Compile-time constants defining the tile sizes for the matrix multiplication. This tiling strategy divides the matrices into smaller blocks that can be processed in parallel.\n - `GROUP_M`: Influences the order of computation across different tiles.\n - `DIVISIBLE_M`, `DIVISIBLE_N`, `DIVISIBLE_K`: Compile-time booleans determining if the dimensions M, N, and K are perfectly divisible by the respective tile sizes, optimizing boundary handling.\n \n The kernel calculates grid indices and offsets for processing each tile. Masking is used to handle partial tiles when dimensions aren't perfectly divisible by tile sizes. The main computation involves loading tiles of A and B, performing a matrix multiplication using `tl.dot`, and accumulating the results into the output tile.\n \n The `bmm` function is a wrapper that initializes the output tensor, determines grid dimensions based on input sizes, and launches the `bmm_kernel` using Triton's `autotune` to find optimal configurations for TILE_M, TILE_N, TILE_K, and other parameters. It ensures the correct execution context on the GPU device and calls the kernel with appropriate arguments.\n ", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef heur_divisible_m(args):\n return args[\"M\"] % args[\"TILE_M\"] == 0\n\n\ndef heur_divisible_n(args):\n return args[\"N\"] % args[\"TILE_N\"] == 0\n\n\ndef heur_divisible_k(args):\n return args[\"K\"] % args[\"TILE_K\"] == 0\n\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\"TILE_M\": 32, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 1},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 64, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 64, \"TILE_N\": 64, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 64, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 128, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=2,\n ),\n triton.Config(\n {\"TILE_M\": 32, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 1},\n num_warps=4,\n num_stages=3,\n ),\n triton.Config(\n {\"TILE_M\": 64, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=3,\n ),\n triton.Config(\n {\"TILE_M\": 64, \"TILE_N\": 64, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=3,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 32, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=3,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 64, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=3,\n ),\n triton.Config(\n {\"TILE_M\": 128, \"TILE_N\": 128, \"TILE_K\": 32, \"GROUP_M\": 2},\n num_warps=4,\n num_stages=3,\n ),\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.heuristics(\n {\n \"DIVISIBLE_M\": heur_divisible_m,\n \"DIVISIBLE_N\": heur_divisible_n,\n \"DIVISIBLE_K\": heur_divisible_k,\n }\n)\n@triton.jit\ndef bmm_kernel(\n A,\n B,\n O,\n M,\n N,\n K,\n TILE_M: tl.constexpr,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n DIVISIBLE_M: tl.constexpr,\n DIVISIBLE_N: tl.constexpr,\n DIVISIBLE_K: tl.constexpr,\n):\n # batch offsets\n pid_b = tl.program_id(2)\n A += pid_b * M * K\n B += pid_b * K * N\n O += pid_b * M * N\n\n pidx = tl.program_id(0)\n pidy = tl.program_id(1)\n\n if GROUP_M == 1:\n pid_m, pid_n = pidx, pidy\n else:\n # reorder CTAs\n gridx = tl.num_programs(0)\n gridy = tl.num_programs(1)\n pid = pidx + pidy * gridx\n\n num_CTA_per_group = gridy * GROUP_M\n\n group_id = pid // num_CTA_per_group\n inner_group_id = pid % num_CTA_per_group\n if (group_id * GROUP_M + GROUP_M) > gridx:\n GROUP_SIZE = gridx % GROUP_M\n else:\n GROUP_SIZE = GROUP_M\n pid_m = group_id * GROUP_M + inner_group_id % GROUP_SIZE\n pid_n = inner_group_id // GROUP_SIZE\n\n offs_m = pid_m * TILE_M + tl.arange(0, TILE_M)\n offs_n = pid_n * TILE_N + tl.arange(0, TILE_N)\n offs_k = tl.arange(0, TILE_K)\n\n if not DIVISIBLE_M:\n mask_m = offs_m < M\n if not DIVISIBLE_N:\n mask_n = offs_n < N\n\n a_ptrs = A + offs_m[:, None] * K + offs_k[None, :]\n b_ptrs = B + offs_k[:, None] * N + offs_n[None, :]\n o_ptrs = O + offs_m[:, None] * N + offs_n[None, :]\n\n num_iters = tl.cdiv(K, TILE_K)\n o = tl.zeros((TILE_M, TILE_N), dtype=tl.float32)\n for _ in range(num_iters):\n if DIVISIBLE_K:\n if DIVISIBLE_M:\n mask_a = None\n else:\n mask_a = mask_m[:, None]\n if DIVISIBLE_N:\n mask_b = None\n else:\n mask_b = mask_n[None, :]\n else:\n mask_k = offs_k < K\n if DIVISIBLE_M:\n mask_a = mask_k[None, :]\n else:\n mask_a = mask_m[:, None] & mask_k[None, :]\n if DIVISIBLE_N:\n mask_b = mask_k[:, None]\n else:\n mask_b = mask_k[:, None] & mask_n[None, :]\n\n a = tl.load(a_ptrs, mask_a)\n b = tl.load(b_ptrs, mask_b)\n\n offs_k += TILE_K\n a_ptrs += TILE_K\n b_ptrs += TILE_K * N\n\n o += tl.dot(a, b, allow_tf32=False)\n\n if DIVISIBLE_M and DIVISIBLE_N:\n mask_c = None\n elif DIVISIBLE_M and not DIVISIBLE_N:\n mask_c = mask_n[None, :]\n elif not DIVISIBLE_M and DIVISIBLE_N:\n mask_c = mask_m[:, None]\n else:\n mask_c = mask_m[:, None] & mask_n[None, :]\n tl.store(o_ptrs, o, mask_c)\n\n\ndef bmm(A, B):\n logging.debug(\"GEMS BMM\")\n batch, M, K = A.shape\n _, _, N = B.shape\n A = A.contiguous()\n B = B.contiguous()\n out = torch.empty((batch, M, N), dtype=A.dtype, device=A.device)\n\n grid_fn = lambda meta: (\n triton.cdiv(meta[\"M\"], meta[\"TILE_M\"]),\n triton.cdiv(meta[\"N\"], meta[\"TILE_N\"]),\n batch,\n )\n with torch.cuda.device(A.device):\n bmm_kernel[grid_fn](A, B, out, M, N, K)\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-225bc7ae357e4955bd447331a6034c3a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton implementation consists of functions designed to optimize the scaled dot-product attention calculation found in transformer models. The core components are `_fwd_kernel`, `_bwd_preprocess`, `_bwd_kernel`, and the wrapper class `_attention`.\n\n `_fwd_kernel` is the forward computation kernel that performs matrix multiplication between query `Q` and key `K` to produce attention scores, which are then scaled by `sm_scale` and masked where necessary. Softmax is applied to these scores to obtain the attention weights. These weights are used to compute the weighted sum of the value `V` tensor, yielding the final output stored in `Out`. The function manages data using block pointers (`q_tile_ptr`, `k_tile_ptr`, `v_tile_ptr`) which efficiently load and store data in smaller manageable blocks. Block shapes (`BLOCK_M`, `BLOCK_N`, `BLOCK_DMODEL`) define the granularity of these operations.\n\n `_bwd_preprocess` prepares the gradient of the output (`DO`) for backpropagation, ensuring numerical stability by scaling it with the normalization constant `L`, and computes a delta factor for later use in gradient computation.\n\n `_bwd_kernel` computes the gradients for the input tensors `Q`, `K`, `V` using the chain rule. It processes data in tiles, recalculates necessary softmax probabilities, and uses them to compute gradients with respect to the inputs. The processed gradients are accumulated and stored back in their respective tensors.\n\n The `_attention` class interfaces with PyTorch's autograd to manage forward and backward passes. It configures the Triton grid to split the data for parallel processing, ensures appropriate memory allocation for intermediate variables like `L`, `m`, and sets parameters such as `BLOCK`, `num_warps`, and `num_stages` to balance the workload across available resources.\n\n Constants like `BLOCK` (often set to 128) define the tile size for the kernels, impacting both performance and memory efficiency. `ctx` is used to pass necessary information between forward and backward functions, enabling efficient gradient calculation in PyTorch's automatic differentiation system.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n L, M,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n Z, H, N_CTX, D0,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n\n # initialize offsets for store\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n stride_qh_2d = stride_qh // stride_qm // stride_qk\n\n q_tile_ptr = tl.make_block_ptr(base=Q,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_hz * stride_qh_2d + start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n k_tile_ptr = tl.make_block_ptr(base=K,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(off_hz * stride_qh_2d, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0))\n v_tile_ptr = tl.make_block_ptr(base=V,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(off_hz * stride_qh_2d, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0))\n out_tile_ptr = tl.make_block_ptr(base=Out,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(off_hz * stride_qh_2d + start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n # load q: it will stay in SRAM throughout\n q = tl.load(q_tile_ptr)\n\n # loop over k, v and update accumulators\n for start_n in range(0, (start_m + 1) * BLOCK_M, BLOCK_N):\n # -- compute qk ----\n k = tl.load(k_tile_ptr, boundary_check=(0, 1))\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (\n start_n + offs_n[None, :]), qk, float(\"-inf\"))\n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(tl.float16)\n v = tl.load(v_tile_ptr, boundary_check=(0, 1))\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_N, 0])\n v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_N, 0])\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # write back l and m\n l_ptrs = L + off_hz * N_CTX + offs_m\n m_ptrs = M + off_hz * N_CTX + offs_m\n tl.store(l_ptrs, l_prev)\n tl.store(m_ptrs, m_prev)\n\n acc = acc.to(tl.float16)\n tl.store(out_tile_ptr, acc, boundary_check=(0, 1))\n\n\n@triton.jit\ndef _bwd_preprocess(\n Out, DO, L,\n NewDO, Delta,\n BLOCK_M: tl.constexpr, D_HEAD: tl.constexpr,\n):\n off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, D_HEAD)\n # load\n o = tl.load(Out + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n do = tl.load(DO + off_m[:, None] * D_HEAD + off_n[None, :]).to(tl.float32)\n denom = tl.load(L + off_m).to(tl.float32)\n # compute\n do = do / denom[:, None]\n delta = tl.sum(o * do, axis=1)\n # write-back\n tl.store(NewDO + off_m[:, None] * D_HEAD + off_n[None, :], do)\n tl.store(Delta + off_m, delta)\n\n\n@triton.jit\ndef _bwd_kernel(\n Q, K, V, sm_scale, Out, DO,\n DQ, DK, DV,\n L, M,\n D,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n Z, H, N_CTX, D0,\n num_block,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n off_hz = tl.program_id(0)\n off_z = off_hz // H\n off_h = off_hz % H\n # init tile_ptr\n stride_qz_2d = stride_qz // stride_qm // stride_qk\n stride_qh_2d = stride_qh // stride_qm // stride_qk\n\n q_tile_ptr = tl.make_block_ptr(base=Q,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n k_tile_ptr = tl.make_block_ptr(base=K,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_kn, stride_kk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n v_tile_ptr = tl.make_block_ptr(base=V,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n do_tile_ptr = tl.make_block_ptr(base=DO,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n dq_tile_ptr = tl.make_block_ptr(base=DQ,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n dk_tile_ptr = tl.make_block_ptr(base=DK,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n dv_tile_ptr = tl.make_block_ptr(base=DV,\n shape=(D0, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(\n off_z * stride_qz_2d + off_h * stride_qh_2d, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0))\n # offset pointers for batch/head\n DQ += off_z * stride_qz + off_h * stride_qh\n for start_n in range(0, num_block):\n lo = start_n * BLOCK_M\n # initialize row/col offsets\n offs_qm = lo + tl.arange(0, BLOCK_M)\n offs_n = start_n * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_m = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, BLOCK_DMODEL)\n # initialize pointers to value-like data\n dq_ptrs = DQ + (offs_qm[:, None] * stride_qm + offs_k[None, :] * stride_qk)\n # pointer to row-wise quantities in value-like data\n D_ptrs = D + off_hz * N_CTX\n m_ptrs = M + off_hz * N_CTX\n # initialize dv amd dk\n dv = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n dk = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # k and v stay in SRAM throughout\n k = tl.load(k_tile_ptr, boundary_check=(0, 1))\n v = tl.load(v_tile_ptr, boundary_check=(0, 1))\n # loop over rows\n for start_m in range(lo, num_block * BLOCK_M, BLOCK_M):\n offs_m_curr = start_m + offs_m\n # load q, k, v, do on-chip\n q = tl.load(q_tile_ptr, boundary_check=(0, 1))\n # recompute p = softmax(qk, dim=-1).T\n # NOTE: `do` is pre-divided by `l`; no normalization here\n qk = tl.dot(q, tl.trans(k))\n qk = tl.where(offs_m_curr[:, None] >= (\n offs_n[None, :]), qk, float(\"-inf\"))\n m = tl.load(m_ptrs + offs_m_curr)\n p = tl.exp(qk * sm_scale - m[:, None])\n # compute dv\n do = tl.load(do_tile_ptr, boundary_check=(0, 1))\n dv += tl.dot(tl.trans(p.to(tl.float16)), do)\n # compute dp = dot(v, do)\n Di = tl.load(D_ptrs + offs_m_curr)\n dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]\n dp += tl.dot(do, tl.trans(v))\n # compute ds = p * (dp - delta[:, None])\n ds = p * dp * sm_scale\n # compute dk = dot(ds.T, q)\n dk += tl.dot(tl.trans(ds.to(tl.float16)), q)\n # compute dq\n dq = tl.load(dq_tile_ptr)\n dq += tl.dot(ds.to(tl.float16), k)\n tl.store(dq_tile_ptr, dq)\n # increment pointers\n dq_ptrs += BLOCK_M * stride_qm\n q_tile_ptr = tl.advance(q_tile_ptr, [BLOCK_M, 0])\n do_tile_ptr = tl.advance(do_tile_ptr, [BLOCK_M, 0])\n dq_tile_ptr = tl.advance(dq_tile_ptr, [BLOCK_M, 0])\n q_tile_ptr = tl.advance(q_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n do_tile_ptr = tl.advance(do_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n dq_tile_ptr = tl.advance(dq_tile_ptr, [lo + (1 - num_block) * BLOCK_M, 0])\n # increment tile pointers\n k_tile_ptr = tl.advance(k_tile_ptr, [BLOCK_M, 0])\n v_tile_ptr = tl.advance(v_tile_ptr, [BLOCK_M, 0])\n # write-back\n tl.store(dv_tile_ptr, dv.to(tl.float16), boundary_check=(0, 1))\n tl.store(dk_tile_ptr, dk.to(tl.float16), boundary_check=(0, 1))\n dv_tile_ptr = tl.advance(dv_tile_ptr, [BLOCK_M, 0])\n dk_tile_ptr = tl.advance(dk_tile_ptr, [BLOCK_M, 0])\n\n\nclass _attention(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, sm_scale):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n grid = (triton.cdiv(q.shape[2], BLOCK), q.shape[0] * q.shape[1], 1)\n L = torch.empty(\n (q.shape[0] * q.shape[1], q.shape[2]),\n device=q.device,\n dtype=torch.float32)\n m = torch.empty(\n (q.shape[0] * q.shape[1], q.shape[2]),\n device=q.device,\n dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n D0 = q.shape[0] * q.shape[1] * q.shape[2]\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n L, m,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2], D0,\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk, num_warps=num_warps,\n num_stages=2,\n )\n\n ctx.save_for_backward(q, k, v, o, L, m)\n ctx.grid = grid\n ctx.sm_scale = sm_scale\n ctx.BLOCK_DMODEL = Lk\n return o\n\n @staticmethod\n def backward(ctx, do):\n BLOCK = 128\n q, k, v, o, l, m = ctx.saved_tensors\n do = do.contiguous()\n dq = torch.zeros_like(q, dtype=torch.float32)\n dk = torch.empty_like(k)\n dv = torch.empty_like(v)\n do_scaled = torch.empty_like(do)\n delta = torch.empty_like(l)\n D0 = q.shape[0] * q.shape[1] * q.shape[2]\n _bwd_preprocess[(ctx.grid[0] * ctx.grid[1], )](\n o, do, l,\n do_scaled, delta,\n BLOCK_M=BLOCK, D_HEAD=ctx.BLOCK_DMODEL,\n )\n _bwd_kernel[(ctx.grid[1],)](\n q, k, v, ctx.sm_scale,\n o, do_scaled,\n dq, dk, dv,\n l, m,\n delta,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n q.shape[0], q.shape[1], q.shape[2], D0,\n ctx.grid[0],\n BLOCK_M=BLOCK, BLOCK_N=BLOCK,\n BLOCK_DMODEL=ctx.BLOCK_DMODEL, num_warps=8,\n num_stages=1,\n )\n return dq, dk, dv, None\n\n\nattention = _attention.apply\n\n\n\n" + ] + }, + { + "id": "openseek-8-89491986d0af49568b39e6d587526a6e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel code implements the `rotary_embedding` function, which provides rotary positional embeddings for transformer architectures. It operates in two modes based on the presence of `k_cache`. The `rotary_embedding_kernel` is invoked when `k_cache` is None and processes the query (`q`) and key (`k`) tensors by loading precomputed cosine (`cos`) and sine (`sin`) values according to their strides and positions. It then applies these trigonometric values to the tensors to perform a rotation operation defined as: out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin and out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos for the queries and similarly for the keys if applicable. The `fused_rotary_embedding_kernel_v2` handles the case when `k_cache` is provided. This involves using block-wise operations for cache management. For this, it utilizes additional input tensors like `block_tables` to determine storage positions within the cache and `kv_lengths` for past sequence length information. Both kernels leverage trigonometric identities to rotate the embeddings and conditionally store results based on the compute strategy defined by `grid` settings which scale based on input tensor dimensions and number of threads (`num_warps`). Essential inputs include `q`, `k`, `cos`, `sin`, and optionally `k_cache`, `block_tables`, and `kv_lengths`. Important constants like `Q_HEAD_NUM` and `HEAD_DIM` are used to define structural properties of the input tensors.\n ", + "output": [ + "from typing import Optional\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef rotary_embedding_kernel(\n q,\n k,\n cos,\n sin,\n q_token_stride,\n q_head_stride,\n k_token_stride,\n k_head_stride,\n head_dim_stride,\n cos_token_stride,\n cos_stride,\n q_total_tokens,\n Q_HEAD_NUM: tl.constexpr,\n KV_GROUP_NUM: tl.constexpr,\n HEAD_DIM: tl.constexpr,\n BLOCK_TOKENS: tl.constexpr, # token range length\n):\n cur_head_idx = tl.program_id(0)\n cur_token_block_idx = tl.program_id(1)\n\n tokens_range = cur_token_block_idx * BLOCK_TOKENS + tl.arange(0, BLOCK_TOKENS)\n dim_range0 = tl.arange(0, HEAD_DIM // 2)\n dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)\n\n off_cos_sin = tokens_range[:, None] * cos_token_stride + dim_range0[None, :] * cos_stride\n loaded_cos = tl.load(cos + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)\n loaded_sin = tl.load(sin + off_cos_sin, mask=(tokens_range[:, None] < q_total_tokens), other=0.0)\n\n off_q0 = (\n tokens_range[:, None, None] * q_token_stride\n + cur_head_idx * q_head_stride\n + dim_range0[None, None, :] * head_dim_stride\n )\n off_q1 = (\n tokens_range[:, None, None] * q_token_stride\n + cur_head_idx * q_head_stride\n + dim_range1[None, None, :] * head_dim_stride\n )\n loaded_q0 = tl.load(\n q + off_q0,\n mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),\n other=0.0,\n )\n loaded_q1 = tl.load(\n q + off_q1,\n mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),\n other=0.0,\n )\n out_q0 = loaded_q0 * loaded_cos[:, None, :] - loaded_q1 * loaded_sin[:, None, :]\n out_q1 = loaded_q0 * loaded_sin[:, None, :] + loaded_q1 * loaded_cos[:, None, :]\n\n tl.store(\n q + off_q0,\n out_q0,\n mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),\n )\n tl.store(\n q + off_q1,\n out_q1,\n mask=((cur_head_idx < Q_HEAD_NUM) & (tokens_range[:, None, None] < q_total_tokens)),\n )\n\n handle_kv = cur_head_idx % KV_GROUP_NUM == 0\n if handle_kv:\n k_head_idx = cur_head_idx // KV_GROUP_NUM\n off_k0 = (\n tokens_range[:, None, None] * k_token_stride\n + k_head_idx * k_head_stride\n + dim_range0[None, None, :] * head_dim_stride\n )\n off_k1 = (\n tokens_range[:, None, None] * k_token_stride\n + k_head_idx * k_head_stride\n + dim_range1[None, None, :] * head_dim_stride\n )\n loaded_k0 = tl.load(\n k + off_k0,\n mask=(tokens_range[:, None, None] < q_total_tokens),\n other=0.0,\n )\n loaded_k1 = tl.load(\n k + off_k1,\n mask=(tokens_range[:, None, None] < q_total_tokens),\n other=0.0,\n )\n out_k0 = loaded_k0 * loaded_cos[:, None, :] - loaded_k1 * loaded_sin[:, None, :]\n out_k1 = loaded_k0 * loaded_sin[:, None, :] + loaded_k1 * loaded_cos[:, None, :]\n tl.store(\n k + off_k0,\n out_k0,\n mask=(tokens_range[:, None, None] < q_total_tokens),\n )\n tl.store(\n k + off_k1,\n out_k1,\n mask=(tokens_range[:, None, None] < q_total_tokens),\n )\n\n\n@triton.jit\ndef fused_rotary_embedding_kernel_v2(\n q,\n k,\n cos,\n sin,\n kv_cache,\n BLOCK_TABLES,\n context_lengths,\n q_token_stride,\n q_head_stride,\n k_token_stride,\n k_head_stride,\n head_dim_stride,\n cos_token_stride,\n cos_stride,\n cacheb_stride,\n cacheh_stride,\n cachebs_stride,\n cached_stride,\n bts_stride,\n btb_stride,\n block_size,\n q_total_tokens,\n Q_HEAD_NUM: tl.constexpr,\n HEAD_DIM: tl.constexpr,\n):\n block_head_index = tl.program_id(0)\n if block_head_index >= Q_HEAD_NUM:\n return\n block_token_index = tl.program_id(1)\n\n dim_range0 = tl.arange(0, HEAD_DIM // 2)\n dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)\n\n off_q0 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range0 * head_dim_stride\n off_q1 = block_token_index * q_token_stride + block_head_index * q_head_stride + dim_range1 * head_dim_stride\n off_k0 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range0 * head_dim_stride\n off_k1 = block_token_index * k_token_stride + block_head_index * k_head_stride + dim_range1 * head_dim_stride\n\n loaded_q0 = tl.load(\n q + off_q0,\n )\n loaded_q1 = tl.load(\n q + off_q1,\n )\n\n loaded_k0 = tl.load(\n k + off_k0,\n )\n\n loaded_k1 = tl.load(\n k + off_k1,\n )\n\n off_cos_sin = block_token_index * cos_token_stride + dim_range0 * cos_stride\n\n loaded_cos = tl.load(cos + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)\n loaded_sin = tl.load(sin + off_cos_sin, mask=(block_token_index < q_total_tokens), other=0.0)\n\n out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin\n out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos\n\n out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin\n out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos # total_tokens, head_num, head_dim\n\n past_kv_seq_len = tl.load(context_lengths + block_token_index) - 1\n\n last_block_idx = past_kv_seq_len // block_size\n block_table_ptr = BLOCK_TABLES + block_token_index * bts_stride\n block_ids = tl.load(block_table_ptr + last_block_idx * btb_stride, mask=(block_token_index < q_total_tokens))\n offsets_in_last_block = (past_kv_seq_len % block_size) * cachebs_stride\n\n kv_range0 = (\n block_ids * cacheb_stride\n + block_head_index * cacheh_stride\n + offsets_in_last_block\n + dim_range0 * cached_stride\n )\n kv_range1 = (\n block_ids * cacheb_stride\n + block_head_index * cacheh_stride\n + offsets_in_last_block\n + dim_range1 * cached_stride\n )\n\n tl.store(\n kv_cache + kv_range0,\n out_k0,\n )\n tl.store(\n kv_cache + kv_range1,\n out_k1,\n )\n\n # concat\n tl.store(\n q + off_q0,\n out_q0,\n )\n tl.store(\n q + off_q1,\n out_q1,\n )\n\n\ndef rotary_embedding(\n q: torch.Tensor,\n k: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n k_cache: Optional[torch.Tensor] = None,\n block_tables: Optional[torch.Tensor] = None,\n kv_lengths: Optional[torch.Tensor] = None,\n):\n \"\"\"\n Args:\n q: query tensor, [total_tokens, head_num, head_dim]\n k: key tensor, [total_tokens, kv_head_num, head_dim]\n cos: cosine for rotary embedding, [max_position_len, head_dim]\n sin: sine for rotary embedding, [max_position_len, head_dim]\n k_cache (torch.Tensor): Blocked key cache. [num_blocks, num_kv_heads, block_size, head_dim]\n kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]\n block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]\n \"\"\"\n q_total_tokens, q_head_num, head_dim = q.shape\n assert q.size(0) == k.size(0)\n BLOCK_TOKENS = 4\n\n if head_dim >= 512:\n num_warps = 16\n elif head_dim >= 256:\n num_warps = 8\n else:\n num_warps = 4\n\n k_head_num = k.size(1)\n q_token_stride, q_head_stride, head_dim_stride = q.stride()\n k_token_stride, k_head_stride, _ = k.stride()\n cos_token_stride, cos_stride = cos.stride()\n\n assert q_head_num % k_head_num == 0\n kv_group_num = q_head_num // k_head_num\n\n if k_cache == None:\n grid = lambda META: (\n q_head_num,\n triton.cdiv(q_total_tokens, META[\"BLOCK_TOKENS\"]),\n )\n rotary_embedding_kernel[grid](\n q,\n k,\n cos,\n sin,\n q_token_stride,\n q_head_stride,\n k_token_stride,\n k_head_stride,\n head_dim_stride,\n cos_token_stride,\n cos_stride,\n q_total_tokens,\n Q_HEAD_NUM=q_head_num,\n KV_GROUP_NUM=kv_group_num,\n HEAD_DIM=head_dim,\n BLOCK_TOKENS=BLOCK_TOKENS,\n num_warps=num_warps,\n )\n else:\n grid = (triton.next_power_of_2(q_head_num), q_total_tokens)\n fused_rotary_embedding_kernel_v2[grid](\n q,\n k,\n cos,\n sin,\n k_cache,\n block_tables,\n kv_lengths,\n q_token_stride,\n q_head_stride,\n k_token_stride,\n k_head_stride,\n head_dim_stride,\n cos_token_stride,\n cos_stride,\n k_cache.stride(0),\n k_cache.stride(1),\n k_cache.stride(2),\n k_cache.stride(3),\n block_tables.stride(0),\n block_tables.stride(1),\n k_cache.size(-2),\n q_total_tokens,\n Q_HEAD_NUM=q_head_num,\n HEAD_DIM=head_dim,\n num_warps=num_warps,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-6f557325ac024cc0b434cb90fc0eef2b", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `chunk_global_reversed_cumsum_scalar_kernel` is a Triton kernel that computes a reversed cumulative sum for a given 3D tensor `s` of shape (B, H, T) where B is the batch dimension, H is the number of heads, and T is the sequence length. The kernel executes for each (B, H) pair independently, by utilizing a grid where the grid size is (B * H). For each block, it initializes an accumulation variable `b_z` to zero. It then iterates over the dimension T in blocks of size `BT`, moving backwards from the end to the start. In each iteration, it computes the sum of the current block, updates the accumulation variable, calculates the cumulative sum by subtracting it from the sum, and stores the result in the output tensor `o`. The function `chunk_global_reversed_cumsum_scalar` serves as a wrapper to handle PyTorch tensor input and output, and it defines the kernel configuration parameters such as grid size and tensor data types.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=4),\n ],\n key=[]\n)\n@triton.jit\ndef chunk_global_reversed_cumsum_scalar_kernel(\n s,\n o,\n T: tl.constexpr,\n BT: tl.constexpr,\n):\n i_bh = tl.program_id(0)\n b_z = tl.zeros([], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)\n b_zz = tl.sum(b_s, axis=0)\n b_z += b_zz\n b_o = b_s - tl.cumsum(b_s, axis=0) + b_z[None]\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))\n\n\ndef chunk_global_reversed_cumsum_scalar(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T = s.shape\n dtype = dtype or s.dtype\n grid = (B * H,)\n z = torch.empty_like(s, dtype=dtype)\n chunk_global_reversed_cumsum_scalar_kernel[grid](\n s, z,\n T=T\n )\n return z\n\n\n\n" + ] + }, + { + "id": "openseek-8-22a61cb2e60444bf811a6e051d8e9ff5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `flash_attn_triton` function is designed to compute the flash attention mechanism efficiently using the Triton library. It accepts input tensors `q`, `k`, and `v` representing the query, key, and value matrices with shape `[batch_size, heads, sequence_length, dimensions]`. The function handles optional causal masking and scales the dot products by `sm_scale`. It calculates the output using the `_fwd_kernel`, a specialized Triton kernel.\n\n The Triton kernel `_fwd_kernel` performs operations in block-sized chunks defined by `BLOCK_M` and `BLOCK_N`, iterating over the sequence length in these fixed-size blocks. For each block, it loads segments of `Q`, `K`, and `V`, computes the dot products `qk`, applies softmax scaling with corrections for numerical stability using exp2 and logarithmic transformations. If `IS_CAUSAL` is enabled, it applies causal masking ensuring future information is not used.\n\n The maximum values are tracked for each block to maintain precision during the softmax computation, followed by accumulation of scaled values into an output buffer. Once all blocks are processed, the output is normalized and stored in tensor `o`. The intermediate maximum and denominator values are also stored for each query's position.\n\n Stride parameters for each tensor define memory layout, facilitating batched operations. `num_warps` controls the number of Triton warps used, determined based on the `Lk` dimension size. The function is specialized for `Lk` sizes of 16, 32, 64, or 128, ensuring optimal performance with these typical transformer sizes. Each Triton kernel launch is controlled by a 3D grid where dimensions correspond to blocks in the sequence length, the batch and head dimensions, and a single block in depth, effectively parallelizing the attention computation over the input dimensions.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\ndef flash_attn_triton(q, k, v, causal=True, sm_scale=1):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n o = torch.empty_like(q)\n\n BLOCK_M = 128\n BLOCK_N = 64\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n L,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, DIM=Lk,\n IS_CAUSAL=causal,\n num_warps=num_warps,\n num_stages=4)\n\n return o\n\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n L,\n O,\n stride_q_bs, stride_q_head, stride_q_seqlen, stride_q_dim,\n stride_k_bs, stride_k_head, stride_k_seqlen, stride_k_dim,\n stride_v_bs, stride_v_head, stride_v_seqlen, stride_v_dim,\n stride_o_bs, stride_o_head, stride_o_seqlen, stride_o_dim,\n BS, HEAD, SEQLEN,\n BLOCK_M: tl.constexpr,\n DIM: tl.constexpr,\n BLOCK_N: tl.constexpr,\n IS_CAUSAL: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_bs_head = tl.program_id(1)\n\n qkv_base_offset = off_bs_head * stride_q_head\n Q_block_ptr = tl.make_block_ptr(\n base=Q + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_q_seqlen, stride_q_dim),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, DIM),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + qkv_base_offset,\n shape=(DIM, SEQLEN),\n strides=(stride_k_dim, stride_k_seqlen),\n offsets=(0, 0),\n block_shape=(DIM, BLOCK_N),\n order=(0, 1),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_k_seqlen, stride_v_dim),\n offsets=(0, 0),\n block_shape=(BLOCK_N, DIM),\n order=(1, 0),\n )\n off_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_n = tl.arange(0, BLOCK_N)\n max = tl.zeros([BLOCK_M], dtype=tl.float32) - float('inf')\n denom = tl.zeros([BLOCK_M], dtype=tl.float32)\n out_buffer = tl.zeros([BLOCK_M, DIM], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(Q_block_ptr)\n q = (q * qk_scale).to(tl.float16)\n lo = 0\n hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else SEQLEN\n for start_n in range(lo, hi, BLOCK_N):\n k = tl.load(K_block_ptr)\n v = tl.load(V_block_ptr)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n if IS_CAUSAL:\n qk = tl.where(off_m[:, None] >= (start_n + off_n[None, :]), qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n\n max_new = tl.maximum(max, tl.max(qk, 1))\n alpha = tl.math.exp2(max - max_new)\n nume = tl.math.exp2(qk - max_new[:, None])\n out_scale = denom * 0 + alpha\n out_buffer *= out_scale[:, None]\n out_buffer += tl.dot(nume.to(tl.float16), v)\n denom = denom * alpha + tl.sum(nume, 1)\n max = max_new\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n\n out_buffer = out_buffer / denom[:, None]\n l_ptr = L + off_bs_head * SEQLEN + off_m\n tl.store(l_ptr, max + tl.math.log2(denom))\n O_block_ptr = tl.make_block_ptr(\n base=O + qkv_base_offset,\n shape=(SEQLEN, DIM),\n strides=(stride_o_seqlen, stride_o_dim),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, DIM),\n order=(1, 0),\n )\n tl.store(O_block_ptr, out_buffer.to(tl.float16))\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c8895663eb9746448709fbc5319a00e4", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code includes Triton kernels designed to convert query and key matrices into int8 format, suitable for applications requiring low precision and optimized memory usage. The process is encapsulated in two specialized kernels, q_kernel_per_block_int8 and k_kernel_per_block_int8, each responsible for quantizing a block of the respective matrices.\n\n The q_kernel_per_block_int8 function loads a block of the query matrix, normalizes it by a scaling factor derived from its max absolute value, quantizes to int8 while rounding, and stores both the quantized data and the scaling factor for later reconstruction. The factor ensures precision is maintained post quantization.\n\n Similarly, the k_kernel_per_block_int8 processes key matrices, performing an identical operation tailored for the key matrix dimensions and blocks.\n\n Both functions are invoked by per_block_int8, which handles the necessary reshaping of input matrices based on dimensionality. It initializes empty int8 and scaling tensors, calculates the grid configuration for kernel execution, and launches the respective Triton kernels with appropriate parameters.\n\n Inputs and Outputs:\n - q (torch.Tensor): Input query matrix.\n - k (torch.Tensor): Input key matrix.\n - BLKQ, BLKK (int): Block sizes for the kernels' operation.\n - q_int8, k_int8 (torch.Tensor): Outputs holding the int8 representations.\n - q_scale, k_scale (torch.Tensor): Output scaling factors for precision recovery.\n\n Key Parameters and Code Flow:\n - The Triton kernels utilize program_id to distinguish thread execution contexts, calculate offsets for memory access, and utilize block-wise processing via arithmetic with block sizes BLKQ and BLKK.\n - Input tensors are reshaped for uniform handling irrespective of the original dimensions (3D or 4D tensors).\n - The scaling ensures that the largest absolute value in each block maps to the range of representable int8 values, with the quantization mimicking nearest integer rounding.\n - Overall, the implementation serves to efficiently convert matrices to a space-efficient int8 format while storing necessary metadata (scales) for accurate de-quantization.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef q_kernel_per_block_int8(X, X_int8, BLK: tl.constexpr, Scale, L, C: tl.constexpr, scale_stride):\n # Triton kernel for processing block of query matrix in int8\n off_b = tl.program_id(1) \n off_blk = tl.program_id(0)\n x_offset = off_b * L * C \n offs_m = off_blk*BLK + tl.arange(0, BLK)\n offs_k = tl.arange(0, C)\n\n x_ptrs = X + x_offset + offs_m[:, None] * C + offs_k[None, :]\n x_int8_ptrs = X_int8 + x_offset + offs_m[:, None] * C + offs_k[None, :]\n scale_ptrs = Scale + off_b * scale_stride + off_blk \n\n x = tl.load(x_ptrs, mask=offs_m[:, None] < L)\n x *= (C**-0.5 * 1.44269504)\n scale = tl.max(tl.abs(x)) / 127.\n x_int8 = x / scale\n x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)\n x_int8 = x_int8.to(tl.int8)\n tl.store(x_int8_ptrs, x_int8, mask=offs_m[:, None] < L)\n tl.store(scale_ptrs, scale)\n\n@triton.jit\ndef k_kernel_per_block_int8(X, X_int8, BLK: tl.constexpr, Scale, L, C: tl.constexpr, scale_stride):\n # Triton kernel for processing block of key matrix in int8\n off_b = tl.program_id(1) \n off_blk = tl.program_id(0)\n x_offset = off_b * L * C \n offs_m = off_blk*BLK + tl.arange(0, BLK)\n offs_k = tl.arange(0, C)\n\n x_ptrs = X + x_offset + offs_m[:, None] * C + offs_k[None, :]\n x_int8_ptrs = X_int8 + x_offset + offs_m[:, None] * C + offs_k[None, :]\n scale_ptrs = Scale + off_b * scale_stride + off_blk \n\n x = tl.load(x_ptrs, mask=offs_m[:, None] < L)\n scale = tl.max(tl.abs(x)) / 127.\n x_int8 = x / scale\n x_int8 += 0.5 * tl.where(x_int8 >= 0, 1, -1)\n x_int8 = x_int8.to(tl.int8)\n tl.store(x_int8_ptrs, x_int8, mask=offs_m[:, None] < L)\n tl.store(scale_ptrs, scale)\n\n\ndef per_block_int8(q, k, BLKQ=128, BLKK=64):\n # Function to initialize and launch Triton kernels for processing q and k\n q_int8 = torch.empty_like(q, dtype=torch.int8)\n k_int8 = q_int8.clone()\n\n if q.dim() == 3:\n q_scale = torch.empty((q.shape[-3], (q.shape[-2] + BLKQ - 1) // BLKQ, 1), device=q.device, dtype=torch.float32)\n k_scale = torch.empty((k.shape[-3], (k.shape[-2] + BLKK - 1) // BLKK, 1), device=q.device, dtype=torch.float32)\n elif q.dim() == 4:\n q_scale = torch.empty((q.shape[-4], q.shape[-3], (q.shape[-2] + BLKQ - 1) // BLKQ, 1), device=q.device, dtype=torch.float32)\n k_scale = torch.empty((k.shape[-4], k.shape[-3], (k.shape[-2] + BLKK - 1) // BLKK, 1), device=q.device, dtype=torch.float32)\n\n q = q.view(-1, q.shape[-2], q.shape[-1])\n k = k.view(-1, k.shape[-2], k.shape[-1])\n\n B, L, C = q.shape\n grid = ((L+BLKQ-1)//BLKQ, B, )\n q_kernel_per_block_int8[grid](\n q, \n q_int8,\n BLKQ,\n q_scale,\n L, C, q_scale.stride(0) if q_scale.dim() == 3 else q_scale.stride(1),\n )\n\n grid = ((L+BLKK-1)//BLKK, B, )\n k_kernel_per_block_int8[grid](\n k, \n k_int8,\n BLKK,\n k_scale,\n L, C, k_scale.stride(0) if k_scale.dim() == 3 else k_scale.stride(1),\n )\n\n return q_int8, q_scale, k_int8, k_scale\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0eb0c4c49a4f4f4a9d534064abbe9b5b", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code implements a Triton-based kernel to efficiently compute the forward pass of the attention mechanism, a core component of Transformer models. The central routine `forward(q, k, v, q_scale, k_scale)` utilizes this kernel to apply attention across query (`q`), key (`k`), and value (`v`) matrices with respective scaling factors (`q_scale`, `k_scale`). It creates an empty output tensor `o` to store the results.\n\n The kernel `_attn_fwd` is invoked with a specific grid configuration that distributes computation across blocks. It processes segments of the input matrices defined by `BLOCK_M` and `BLOCK_N`, effectively handling different stages of the attention mechanism. The kernel reads slices of the input matrices into shared memory, computes scaled dot-products between queries and keys, applies exponential scaling to derive attention weights, and finally aggregates these weights to produce the attention output using the loaded value matrix.\n\n The `_attn_fwd_inner` function operates in multiple stages based on the `STAGE` parameter, designed to progressively refine the attention computation by adjusting scores, applying softmax transformations, and accumulating results into the `acc` tensor. Each iteration dynamically updates scaling offsets and memory pointers, ensuring proper alignment and data coherence for efficient processing.\n\n In summary, this implementation harnesses the capabilities of Triton to parallelize the attention computation, crucial for optimizing the performance of models reliant on self-attention mechanisms.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, q_scale, \n K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, \n N_CTX: tl.constexpr):\n if STAGE == 1:\n lo, hi = 0, start_m * BLOCK_M\n elif STAGE == 2:\n lo, hi = start_m * BLOCK_M, (start_m + 1) * BLOCK_M\n lo = tl.multiple_of(lo, BLOCK_M)\n K_scale_ptr += lo // BLOCK_N\n K_ptrs += HEAD_DIM * lo\n V_ptrs += HEAD_DIM * lo\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k_mask = (offs_n[None, :] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[:, None])\n k = tl.load(K_ptrs, mask=k_mask)\n k_scale = tl.load(K_scale_ptr)\n qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale\n if STAGE == 2:\n mask = offs_m[:, None] >= (start_n + offs_n[None, :])\n qk = qk + tl.where(mask, 0, -1.0e6)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n else:\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_ptrs, mask=(offs_n[:, None] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[None, :]))\n p = p.to(tl.float16)\n \n acc += tl.dot(p, v.to(tl.float16), out_dtype=tl.float16) \n m_i = m_ij\n K_ptrs += BLOCK_N * HEAD_DIM\n K_scale_ptr += 1\n V_ptrs += BLOCK_N * HEAD_DIM\n return acc, l_i, m_i\n\n@triton.jit\ndef _attn_fwd(Q, K, V, Q_scale, K_scale, Out, \n stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vk, stride_vn, \n stride_oz, stride_oh, stride_om, stride_on, \n Z, H, N_CTX, \n HEAD_DIM: tl.constexpr, \n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n \n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n vk_offset = qvk_offset // stride_qm\n q_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_M)\n k_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_N)\n \n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, 128)\n Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n Q_scale_ptr = Q_scale + q_scale_offset + start_m\n K_ptrs = K + qvk_offset + offs_k[:, None] + offs_n[None, :] * stride_kn\n K_scale_ptr = K_scale + k_scale_offset\n V_ptrs = V + qvk_offset + offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk\n O_block_ptr = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, 128], dtype=tl.float32)\n q = tl.load(Q_ptrs, mask=(offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n q_scale = tl.load(Q_scale_ptr)\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 4 - STAGE, offs_m, offs_n, N_CTX)\n acc, l_i, _ = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 2, offs_m, offs_n, N_CTX)\n acc = acc / l_i[:, None]\n tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask=(offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n\ndef forward(q, k, v, q_scale, k_scale):\n BLOCK_M = 128\n BLOCK_N = 64\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n o = torch.empty_like(q, dtype=torch.bfloat16)\n stage = 3\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, q_scale, k_scale, o, \n q.stride(0), q.stride(1), q.stride(2), q.stride(3), \n k.stride(0), k.stride(1), k.stride(2), k.stride(3), \n v.stride(0), v.stride(1), v.stride(2), v.stride(3), \n o.stride(0), o.stride(1), o.stride(2), o.stride(3), \n q.shape[0], q.shape[1], \n N_CTX=q.shape[2], \n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, \n STAGE=stage, \n num_warps=8, \n num_stages=3)\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-bd2770523a0d4451afe800ef9115835e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel 'rotary_kernel' applies rotary positional encoding on the input tensor X using cosine (COS) and sine (SIN) tensors. The kernel supports both variable and fixed sequence lengths, controlled by IS_VARLEN. It processes data in blocks, with dimensions controlled by BLOCK_M and BLOCK_K, determining the workload per thread. The kernel adjusts input and output pointers based on batch, head, and sequence index calculations. The interleaved data layout is supported via INTERLEAVED, affecting how input data is accessed and processed. The optional conjugate transformation is controlled by CONJUGATE, affecting the sine component. The 'apply_rotary' function is a high-level Python wrapper preparing the input tensors and invoking the kernel with appropriate grid and block configurations. It handles data type checks, ensures tensor contiguity, and manages inplace operations. It calculates grid sizes based on sequence lengths, and sets up strides for various tensor dimensions (batch, sequence, heads, headdim). The function ensures the kernel operates correctly, given the input tensor shapes and specified operation parameters.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Union, Optional\n\n\n@triton.jit\ndef rotary_kernel(\n OUT, # Pointers to matrices\n X,\n COS,\n SIN,\n CU_SEQLENS,\n SEQLEN_OFFSETS, # this could be int or a pointer\n # Matrix dimensions\n seqlen,\n rotary_dim,\n seqlen_ro,\n # strides\n stride_out_batch,\n stride_out_seqlen,\n stride_out_nheads,\n stride_out_headdim,\n stride_x_batch,\n stride_x_seqlen,\n stride_x_nheads,\n stride_x_headdim,\n # Meta-parameters\n BLOCK_K: tl.constexpr,\n IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr,\n INTERLEAVED: tl.constexpr,\n CONJUGATE: tl.constexpr,\n BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(\n COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0\n ).to(tl.float32)\n sin = tl.load(\n SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x0 = tl.load(\n X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0\n ).to(tl.float32)\n x1 = tl.load(\n X + rotary_dim_half * stride_x_headdim,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n tl.store(\n OUT + rotary_dim_half * stride_out_headdim,\n o1,\n mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half),\n )\n else:\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1 # 1, 0, 3, 2, 5, 4, ...\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(\n COS,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=1.0,\n ).to(tl.float32)\n sin = tl.load(\n SIN,\n mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half),\n other=0.0,\n ).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(\n tl.float32\n )\n x1 = tl.load(\n X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0\n ).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\ndef apply_rotary(\n x: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None,\n max_seqlen: Optional[int] = None,\n interleaved=False,\n inplace=False,\n conjugate=False,\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n assert max_seqlen is not None, \"If cu_seqlens is passed in, then max_seqlen must be passed\"\n total_seqlen, nheads, headdim = x.shape\n batch_p_1 = cu_seqlens.shape[0]\n batch = batch_p_1 - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n assert sin.shape == cos.shape\n rotary_dim *= 2\n assert rotary_dim <= headdim, \"rotary_dim must be <= headdim\"\n assert headdim <= 256, \"Only support headdim <= 256\"\n assert seqlen_ro >= seqlen, \"seqlen_ro must be >= seqlen\"\n\n assert (\n cos.dtype == sin.dtype\n ), f\"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}\"\n assert (\n x.dtype == cos.dtype\n ), f\"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}\"\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n assert seqlen_offsets.shape == (batch,)\n assert seqlen_offsets.dtype in [torch.int32, torch.int64]\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n assert seqlen_offsets + seqlen <= seqlen_ro\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = (\n 32\n if rotary_dim <= 32\n else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n )\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads)\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 128 else 4)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, # data ptrs\n x,\n cos,\n sin,\n cu_seqlens,\n seqlen_offsets,\n seqlen, # shapes\n rotary_dim,\n seqlen_ro,\n output.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0\n output.stride(-3), # seqlen_stride or total_seqlen_stride\n output.stride(-2), # nheads_stride\n output.stride(-1), # headdim_stride\n x.stride(0) if not is_varlen else 0, # batch_strides if not varlen else 0\n x.stride(-3), # seqlen stride or total_seqlen_stride\n x.stride(-2), # nheads stride\n x.stride(-1), # headdim stride\n BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor),\n is_varlen,\n interleaved,\n conjugate,\n BLOCK_M,\n )\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-78899ea17fdc4c20b19bfdb4e692b4bd", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton-based implementation of the softmax function is optimized for GPU execution by utilizing parallelization across matrix rows. The kernel, `softmax_kernel`, is compiled just-in-time using Triton's `@triton.jit`, making it highly efficient for large-scale data processing.\n\n `softmax_kernel` is defined to perform the softmax computation for each row independently. The function parameters include pointers to the input and output matrices (`input_ptr`, `output_ptr`), the strides for row advancement (`input_row_stride`, `output_row_stride`), the number of columns (`n_cols`), and a compile-time constant `BLOCK_SIZE` to manage block-wise operations.\n\n The kernel execution logic involves:\n - Identifying the current row using `tl.program_id(0)`.\n - Calculating the starting pointer for this row.\n - Utilizing a block of threads defined by `BLOCK_SIZE` to handle possible overreach beyond the actual column count.\n - Loading the row into on-chip memory with masking to handle cases where the block size exceeds column count.\n - Performing numerical stabilization by subtracting the maximum value from the row elements to ensure stable computation.\n - Computing the exponentials, followed by normalization to derive the softmax probabilities.\n - Storing the result back in the output matrix.\n\n The `softmax` function facilitates kernel execution. It calculates the optimal `BLOCK_SIZE` by finding the smallest power of two greater than the column count and adjusts `num_warps` according to `BLOCK_SIZE` to ensure efficient parallel execution. The function then initializes an output tensor, launches the Triton kernel with one block per input matrix row, and returns the computed softmax result.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Subtract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o\n # f the input matrix\n softmax_kernel[(n_rows, )](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y\n\n\n\n" + ] + }, + { + "id": "openseek-8-a6148578973b4f89ab5525843e7c6321", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The function `fused_native_layer_norm_no_welford` implements a layer normalization operation without Welford's method using Triton to accelerate the computation on a GPU. The Triton kernel `triton_red_fused_native_layer_norm_no_welford` is used to compute the layer normalization in parallel. The kernel is defined with two autotuning configurations that adjust the block sizes (`XBLOCK` and `RBLOCK`) for performance tuning. The kernel operates by first loading input data and computing the mean across a dimension specified by `xnumel` and `rnumel`. It accumulates data in blocks and computes the mean using Triton's `tl.sum` function. After computing the mean, it calculates the variance and then the inverse standard deviation using the reciprocal square root function `libdevice.rsqrt`. These intermediate results (mean and inverse standard deviation) are stored in output buffers `in_out_ptr0` and `in_out_ptr1`. Finally, the kernel normalizes the input data by subtracting the mean, multiplying by the inverse standard deviation, and applying scale and shift using additional input parameters `in_ptr1` and `in_ptr2`. This processed data is stored in `out_ptr0`, which is part of the function's return values. The function manages CUDA devices and streams explicitly and utilizes Triton's grid function to define the launch configuration for the kernel.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.runtime.triton_helpers import libdevice\n\nempty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda\nreinterpret_tensor = torch.ops.inductor._reinterpret_tensor\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"XBLOCK\": 1,\n \"RBLOCK\": 1024,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"XBLOCK\": 1,\n \"RBLOCK\": 2048,\n },\n num_stages=1,\n num_warps=8,\n ),\n ],\n key=[\"xnumel\", \"rnumel\"],\n)\n@triton.jit\ndef triton_red_fused_native_layer_norm_no_welford(\n in_out_ptr0,\n in_out_ptr1,\n in_ptr0,\n in_ptr1,\n in_ptr2,\n out_ptr0,\n xnumel,\n rnumel,\n XBLOCK: tl.constexpr,\n RBLOCK: tl.constexpr,\n):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n _tmp3 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(\n in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy=\"evict_last\"\n ).to(tl.float32)\n tmp1 = tmp0.to(tl.float32)\n tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])\n tmp4 = _tmp3 + tmp2\n _tmp3 = tmp4\n tmp3 = tl.sum(_tmp3, 1)[:, None]\n tmp5 = rnumel # 4096.0\n tmp6 = tmp3 / tmp5\n tl.debug_barrier()\n tl.store(in_out_ptr0 + (x0), tmp6, None)\n _tmp12 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp7 = tl.load(\n in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy=\"evict_last\"\n ).to(tl.float32)\n tmp8 = tmp7.to(tl.float32)\n tmp9 = tmp8 - tmp6\n tmp10 = tmp9 * tmp9\n tmp11 = tl.broadcast_to(tmp10, [XBLOCK, RBLOCK])\n tmp13 = _tmp12 + tmp11\n _tmp12 = tmp13\n tmp12 = tl.sum(_tmp12, 1)[:, None]\n tmp14 = rnumel # 4096.0\n tmp15 = tmp12 / tmp14\n tmp16 = 1e-05\n tmp17 = tmp15 + tmp16\n tmp18 = libdevice.rsqrt(tmp17)\n tl.debug_barrier()\n tl.store(in_out_ptr1 + (x0), tmp18, None)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp19 = tl.load(\n in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy=\"evict_first\"\n ).to(tl.float32)\n tmp23 = tl.load(in_ptr1 + (r1), rmask, eviction_policy=\"evict_last\").to(\n tl.float32\n )\n tmp26 = tl.load(in_ptr2 + (r1), rmask, eviction_policy=\"evict_last\").to(\n tl.float32\n )\n tmp20 = tmp19.to(tl.float32)\n tmp21 = tmp20 - tmp6\n tmp22 = tmp21 * tmp18\n tmp24 = tmp23.to(tl.float32)\n tmp25 = tmp22 * tmp24\n tmp27 = tmp26.to(tl.float32)\n tmp28 = tmp25 + tmp27\n tmp29 = tmp28.to(tl.float32)\n tl.store(out_ptr0 + (r1 + (rnumel * x0)), tmp29, rmask)\n\ndef fused_native_layer_norm_no_welford(primals_1, primals_2, primals_3):\n S, D = primals_3.shape\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided_cuda((S, 1), (1, S), torch.float32)\n buf1 = buf0\n del buf0 # reuse\n buf2 = empty_strided_cuda((S, 1), (1, S), torch.float32)\n buf3 = reinterpret_tensor(buf2, (S, 1), (1, 1), 0)\n del buf2 # reuse\n buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)\n stream0 = get_raw_stream(0)\n grid = lambda META: (triton.cdiv(S, META[\"XBLOCK\"]),)\n triton_red_fused_native_layer_norm_no_welford[grid](\n buf1, buf3, primals_3, primals_1, primals_2, buf4, S, D\n )\n return (\n buf4,\n primals_3,\n buf1,\n buf3,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-813d9bf606a448b7a706d5483a96c79a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton kernel `mul_kernel` that multiplies each element of a given source tensor `src` with a constant value referred to as the exponent compensator and stores the result in the destination tensor `dst`. The kernel is designed to work on a portion of the tensor at a time, where the portion size is controlled by a parameter `BLOCK_SIZE`, which is a compile-time constant. The constant exponent compensator is defined as `2.0 ** (127 - 15)` to adjust for exponent bias in floating-point representation. The kernel computes indices for accessing elements in the source tensor using `tl.program_id(0)` to identify the current block and `tl.arange(0, BLOCK_SIZE)` for offsets within the block. Using these indices, it loads the corresponding elements from `src`, multiplies them by the compensator, and stores the results in `dst`.\n\n The wrapper function `launch_mul_kernel` is provided to set up and invoke this Triton kernel. It takes a PyTorch tensor `src` and an optional block size `BLOCK_SIZE` as arguments. It creates an empty destination tensor `dst` on the CUDA device with the same shape as `src`. It then launches the kernel with the appropriate grid size `(src.shape[0] // BLOCK_SIZE,)`, passing `src`, `dst`, and `BLOCK_SIZE` as parameters. Finally, the function returns the filled destination tensor `dst` containing the results of the multiplication operation.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel to multiply each element in the source tensor by a constant exponent compensator\n@triton.jit\ndef mul_kernel(src, dst, BLOCK_SIZE: tl.constexpr):\n # Define a constant exponent compensator\n exponent_compensator: tl.constexpr = 2.0 ** (127 - 15)\n # Calculate the indices for the current program ID\n idxs = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Load elements from the source tensor\n x = tl.load(src + idxs)\n # Multiply each element by the exponent compensator\n y = x * exponent_compensator\n # Store the result in the destination tensor\n tl.store(dst + idxs, y)\n\n# Function to launch the Triton kernel\ndef launch_mul_kernel(src, BLOCK_SIZE=1):\n # Create an empty tensor for the result\n dst = torch.empty(src.shape, dtype=torch.float32, device='cuda')\n # Launch the Triton kernel\n mul_kernel[(src.shape[0] // BLOCK_SIZE,)](src, dst, BLOCK_SIZE)\n return dst\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-6da7d83140b3478286ae117e960b0264", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `dropout` function is implemented using a Triton kernel named `_dropout`. The purpose is to apply dropout to an input tensor `x`, controlled by a mask tensor `x_keep`, and store the result in `output`. The kernel operates on blocks of elements (`BLOCK_SIZE=1024`) and for each block, calculates an offset based on the program ID. The elements are loaded from `x_ptr` and `x_keep_ptr` using these offsets and a mask to ensure within bounds. The core operation modifies the elements of `x` based on the mask `x_keep`, using the formula `tl.where(x_keep, x / (1 - p), 0.0)` to scale retained elements and zero-out others. Results are stored back to `output_ptr`. The host function `dropout` ensures input tensor `x` is contiguous, computes the grid size for dispatching the Triton kernel, and calls `_dropout` with necessary pointers and parameters, enabling parallel execution over the elements of `x`.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _dropout(\n x_ptr, # pointer to the input\n x_keep_ptr, # pointer to a mask of 0s and 1s\n output_ptr, # pointer to the output\n n_elements, # number of elements in the `x` tensor\n p, # probability that an element of `x` is changed to zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n # Load data\n x = tl.load(x_ptr + offsets, mask=mask)\n x_keep = tl.load(x_keep_ptr + offsets, mask=mask)\n # The line below is the crucial part, described in the paragraph above!\n output = tl.where(x_keep, x / (1 - p), 0.0)\n # Write-back output\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef dropout(x, x_keep, p):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)\n return output\n\n\n\n" + ] + }, + { + "id": "openseek-8-21bc2972286640ef85823f94ef7e9e22", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe code defines a Triton kernel `_fwd_kernel_destindex_copy_quantize_kv` and a Python wrapper function `destindex_copy_quantize_kv`. The kernel processes the input tensor `K` by loading it according to the block size specified by `BLOCK_DMODEL` and `BLOCK_HEAD`. The function operates across multiple heads as specified by `head_num`. The kernel uses the `DestLoc` array to determine the destination index where each sequence element should be copied to in the output tensor `Out`. It calculates the absolute maximum for each head slice to determine a scaling factor, which is then used to quantize the data to int8. This quantized data is stored in `Out`, and the scaling factors are stored in `Out_scale`. The kernel uses 1 warp, and the grid size is determined by the sequence length of `DestLoc`. The function assumes `K`, `Out`, and `Out_scale` have consistent dimensions for heads and the model. The wrapper function sets up this operation by defining grid and block sizes, strides, and invoking the kernel.\n", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K, Dest_loc, Out, Out_scale,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n stride_os_bs, stride_os_h, stride_os_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], \n mask=offs_h[:, None] < head_num, other=0.0)\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=offs_h[:, None] < head_num)\n tl.store(os_ptrs, data_scale, mask=offs_h[:, None] < head_num)\n\n\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K, DestLoc, Out, Out_scale,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n" + ] + }, + { + "id": "openseek-8-8b73cbaf8079483daf37c0606ea78e15", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_fwd_kernel_token_att2` is designed to compute attention over tokens in a parallelized fashion. It takes several inputs: probability matrix `Prob`, value tensor `V`, and an output tensor `Out` among other auxiliary inputs such as `Req_to_tokens` that maps requests to tokens, and several batch-related indices and strides. \n\n - The kernel computes attention by iterating over token sequences (`BLOCK_N` size blocks) within a batch and head context (`cur_batch` and `cur_head`). The indices are adjusted based on the current position, head, and sequence length to ensure that only valid indices are processed.\n \n - For each block of tokens, it loads relevant slices of the `Prob` and `V` tensors into `p_value` and `v_value`. These are multiplied and accumulated into `acc`, the accumulator for the final result.\n\n - Once all blocks are processed, `acc` is cast to the output's data type and stored in the output tensor `Out` at the corresponding position calculated via strides.\n\n The function `token_att_fwd2` serves as a Python interface for this Triton kernel. It initializes parameters like block size (`BLOCK`), grid dimensions, and computes `kv_group_num` to determine the head groups for key-value pairs. It then calls the Triton kernel with these parameters. No gradients are computed as the function is decorated with `@torch.no_grad()`.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob,\n V,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n B_Att_Start_Loc,\n B_Att_Seqlen,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_ph,\n stride_pbs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n sliding_window,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # Triton kernel for computing token attention\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n cur_kv_head = cur_head // kv_group_num\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0)\n cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch)\n\n v_loc_off = (\n cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s\n )\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_att_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_att_seq_len, other=0.0)\n v_loc = tl.load(\n Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s,\n mask=(start_n + offs_n + cur_batch_start_index) < cur_batch_seq_len,\n other=0.0,\n )\n v_value = tl.load(\n V + v_offs + v_loc[:, None] * stride_vbs,\n mask=(start_n + offs_n[:, None] + cur_batch_start_index) < cur_batch_seq_len,\n other=0.0,\n )\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(Out.dtype.element_ty)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd2(\n prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window\n):\n # Launch the Triton kernel for token attention\n BLOCK = 128\n batch, head = B_req_idx.shape[0], prob.shape[0]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n kv_group_num = prob.shape[0] // v.shape[1]\n\n _fwd_kernel_token_att2[grid](\n prob,\n v,\n out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n B_Att_Start_Loc,\n B_Att_Seqlen,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n prob.stride(0),\n prob.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n out.stride(0),\n out.stride(1),\n out.stride(2),\n kv_group_num=kv_group_num,\n sliding_window=sliding_window,\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-ac2812a82dcd40d79f3d2be2fc877115", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `var_len_copy_kernel_triton` is a kernel designed to perform variable-length copy operations from a source array to a destination array. The kernel function accepts pointers to arrays `old_a_start`, `old_a_len`, `old_a_location`, `new_a_start`, `new_a_location`, and a constant `BLOCK_SIZE`. The `a_id` is the unique identifier for the current program instance, determining which segment of data to copy. For each segment, defined by `a_id`, the function reads the starting index and length from `old_a_start` and `old_a_len`, respectively, and performs a looped copy using Triton's `tl.load` and `tl.store` commands. The `old_offset` and `new_offset` are used to iterate over elements within the block, while masks ensure that out-of-bound accesses are prevented. The function `launch_var_len_copy_triton` prepares and launches the kernel with the grid size equal to the number of segments to copy, and specifies `BLOCK_SIZE` as 256.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef var_len_copy_kernel_triton(old_a_start, old_a_len, old_a_location, new_a_start, new_a_location,\n BLOCK_SIZE: tl.constexpr):\n a_id = tl.program_id(0)\n length = tl.load(old_a_len + a_id)\n old_start = tl.load(old_a_start + a_id)\n new_start = tl.load(new_a_start + a_id)\n old_offset = tl.arange(0, BLOCK_SIZE)\n new_offset = tl.arange(0, BLOCK_SIZE)\n for i in range(0, length, BLOCK_SIZE):\n v = tl.load(old_a_location + old_start + i + old_offset, mask=old_offset < length)\n tl.store(new_a_location + new_start + i + new_offset, v, mask=new_offset < length)\n\ndef launch_var_len_copy_triton(old_a_start, old_a_len, old_location, new_a_start, new_a_location):\n BLOCK_SIZE = 256\n grid_size = (len(old_a_start),)\n\n var_len_copy_kernel_triton[grid_size](\n old_a_start, old_a_len, old_location, new_a_start, new_a_location, BLOCK_SIZE)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-d690e6abe2fc4ba68205c658697c208e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_score_kernel` is a Triton-based GPU kernel designed to compute attention scores in transformer models. It uses a block-wise approach to leverage GPU parallelism and optimize memory access. The main inputs are the matrices `Q` (query), `K` (key), and `M` (mask), with the result being stored in `Out`. The kernel iterates over blocks defined by `BLOCK_M` and `BLOCK_N`, performing operations in these fixed-size submatrices.\n\n Key computations involve loading query blocks `Q` and key blocks `K`, calculating the dot product to get the score matrix `qk`, and scaling this matrix by a factor `sm_scale` which accounts for the softmax operation. It includes conditions for sliding window attention, checking bounds to handle non-even dimensions, and reducing results into the output vector `o`.\n\n The function `get_score` is a Python wrapper function that prepares inputs and executes the kernel. It determines grid size based on the dimensions of K and Q. If there is a resource constraint error due to large block sizes, it reduces `BLOCK_M` and `BLOCK_N` by half and retries execution. This function also calculates the scale factor for attention (`sm_scale`) and manages additional parameters like sliding window configurations.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n_BLOCK_N = 64\n_BLOCK_M = 64\n\n@triton.heuristics(\n {\n \"IS_EVEN_M\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_M\"] == 0,\n \"IS_EVEN_N\": lambda args: args[\"NKV_CTX\"] % args[\"BLOCK_N\"] == 0,\n }\n)\n@triton.jit\ndef _score_kernel(\n Q, K, M, sm_scale, Out,\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_oz, stride_oh, stride_on,\n Z, H, H_KV, #\n N_CTX, #\n ROUND_CTX,\n NKV_CTX,\n sliding_window_offset,\n sliding_window_size,\n SLIDING_WINDOW: tl.constexpr,\n COMPLEMENT_SLIDING_WINDOW: tl.constexpr,\n IS_EVEN_M: tl.constexpr,\n IS_EVEN_N: tl.constexpr,\n BLOCK_M: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n):\n start_n = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n off_hkv = off_h // (H//H_KV)\n q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh\n m_ptrs = M + off_hz * ROUND_CTX + tl.arange(0, BLOCK_M)\n o = tl.zeros([BLOCK_M], dtype=tl.float32)\n\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, NKV_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, start_n * BLOCK_N),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n\n if IS_EVEN_N:\n k = tl.load(K_block_ptr)\n else:\n k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n\n\n lo = 0\n hi = ROUND_CTX\n qk_scale = sm_scale\n qk_scale *= 1.4426950408889634 # 1/log(2)\n\n for start_m in range(lo, hi, BLOCK_M):\n start_m = tl.multiple_of(start_m, BLOCK_M)\n if IS_EVEN_M:\n q = tl.load(Q_block_ptr)\n else:\n q = tl.load(Q_block_ptr, boundary_check=(0,1), padding_option=\"zero\")\n\n m = tl.load(m_ptrs)\n\n # calc qk\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = qk * qk_scale\n\n if SLIDING_WINDOW:\n dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \\\n + start_m - start_n * BLOCK_N + sliding_window_offset\n\n if COMPLEMENT_SLIDING_WINDOW:\n mask = (dist >= sliding_window_size)\n else:\n mask = (dist >= 0) & (dist < sliding_window_size)\n\n qk = qk - m[:, None]\n p = tl.math.exp2(qk) # (BLOCK_M, BLOCK_N)\n\n if SLIDING_WINDOW:\n p = tl.where(mask, p, 0)\n\n if not IS_EVEN_N:\n p = tl.where(\n ((tl.arange(0, BLOCK_M) + start_m) < N_CTX)[:, None],\n p, 0\n )\n\n o += tl.sum(p, axis=0)\n\n\n Q_block_ptr = tl.advance(Q_block_ptr, offsets=(BLOCK_M, 0))\n m_ptrs = m_ptrs + BLOCK_M\n\n o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh\n o_range = tl.arange(0, BLOCK_N) + start_n * BLOCK_N # orange\n o_ptrs = Out + o_offset + o_range\n tl.store(o_ptrs, o.to(Out.type.element_ty), mask = o_range < NKV_CTX)\n\ndef get_score(q, k, m, sliding_window, complement_sliding_window):\n N_CTX = q.size(-2)\n NKV_CTX = k.size(-2)\n ROUND_CTX = m.size(-1)\n ret = torch.zeros(\n (q.size(0), q.size(1), k.size(2)),\n dtype=k.dtype, device=k.device\n )\n if sliding_window is not None:\n sliding_window_offset, sliding_window_size = sliding_window\n else:\n sliding_window_offset, sliding_window_size = None, None\n\n grid = lambda META: (\n triton.cdiv(k.shape[2], META[\"BLOCK_N\"]),\n q.shape[0] * q.shape[1]\n )\n sm_scale = 1 / math.sqrt(q.size(-1))\n\n global _BLOCK_N\n global _BLOCK_M\n\n try:\n _score_kernel[grid](\n q, k, m, sm_scale, ret,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n ret.stride(0), ret.stride(1), ret.stride(2),\n q.size(0), q.size(1), k.size(1),\n N_CTX, ROUND_CTX, NKV_CTX,\n sliding_window_offset,\n sliding_window_size,\n SLIDING_WINDOW=(sliding_window is not None),\n COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,\n BLOCK_M=_BLOCK_M,\n BLOCK_N=_BLOCK_N,\n BLOCK_DMODEL=q.size(-1)\n )\n except triton.OutOfResources as E:\n from warnings import warn\n _BLOCK_N = _BLOCK_N // 2\n _BLOCK_M = _BLOCK_M // 2\n warn(f\"Triton Attention Output Resources. {E}\\nUse smaller block size {_BLOCK_N}.\")\n _score_kernel[grid](\n q, k, m, sm_scale, ret,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n ret.stride(0), ret.stride(1), ret.stride(2),\n q.size(0), q.size(1), k.size(1),\n N_CTX, ROUND_CTX, NKV_CTX,\n sliding_window_offset,\n sliding_window_size,\n SLIDING_WINDOW=(sliding_window is not None),\n COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,\n BLOCK_M=_BLOCK_M,\n BLOCK_N=_BLOCK_N,\n BLOCK_DMODEL=q.size(-1)\n )\n\n return ret\n\n\n\n" + ] + }, + { + "id": "openseek-8-96ac98716f4d443ea073a9d1cc16cd40", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a batched vector-matrix multiplication using Triton. The primary kernel function is `batched_vecmat_kernel`, which performs the core operations. This kernel function is designed to handle inputs `A` (shape [dim_m, dim_k]) and `B` (shape [dim_m, dim_n, dim_k]), where the goal is to compute the vector-matrix product for each vector-matrix pair defined by corresponding slices in A and B. The kernel operates in blocks, defined by `block_m`, `block_n`, and `block_k`, which partition the matrices for parallel processing.\n\n Inside the kernel, `m_index` and `n_index` are the block indices for the current program instance, determined by Triton's grid structure. `output_tile` computes the output matrix indices that this block will modify. The `vecmat` variable accumulates the results for the block. The loop iterates over `k_blocks`, partitioning the K dimension, loading slices of `A` and `B`, broadcasting `a` to match `b`'s shape, and computing the dot product for the block using `tl.sum`. The result for each block is stored using `tl.store`.\n\n The `batched_vecmat` function initializes tensors A, B, and the output on the GPU using PyTorch. It checks that the dimensions are divisible by the respective block sizes and computes the grid dimensions. The Triton kernel is then launched over this grid, passing the initialized tensors and configuration parameters, ensuring that the vector-matrix multiplication is performed for all specified blocks in the input dimensions, and the result is stored in `output`.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef batched_vecmat_kernel(\n A, # shape: [dim_m, dim_k]\n B, # shape: [dim_m, dim_n, dim_k]\n dim_m, dim_n, dim_k,\n output,\n block_m: tl.constexpr, block_n: tl.constexpr, block_k: tl.constexpr):\n m_index = tl.program_id(0)\n n_index = tl.program_id(1)\n output_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_n \\\n + (n_index * block_n + tl.arange(0, block_n))[None, :]\n\n vecmat = tl.zeros([block_m, block_n], dtype=A.dtype.element_ty)\n k_blocks = dim_k // block_k\n for k_index in range(k_blocks):\n a_tile = (m_index * block_m + tl.arange(0, block_m))[:, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, :]\n a = tl.load(A + a_tile)\n\n b_tile = (m_index * block_m + tl.arange(0, block_m))[None, :, None] * dim_n * dim_k \\\n + (n_index * block_n + tl.arange(0, block_n))[:, None, None] * dim_k \\\n + (k_index * block_k + tl.arange(0, block_k))[None, None, :]\n b = tl.load(B + b_tile)\n\n expanded_a, _ = tl.broadcast(a, b)\n vecmat += tl.trans(tl.sum(expanded_a * b, axis=2))\n\n tl.store(output + output_tile, vecmat)\n\n\ndef batched_vecmat(\n M, N, K, block_m, block_n, block_k, num_warps=4, num_stages=1\n):\n\n A = torch.randn(M, K, device='cuda', dtype=torch.float32) # shape: [M, K]\n B = torch.randn(M, N, K, device='cuda', dtype=torch.float32) # shape: [M, N, K]\n output = torch.zeros(M, N, device='cuda', dtype=torch.float32) # 输出张量,shape: [M, N]\n\n assert K % block_k == 0, \"\"\n assert M % block_m == 0, \"\"\n assert N % block_n == 0, \"\"\n\n grid = (M // block_m, N // block_n)\n\n # 调用 Triton Kernel\n batched_vecmat_kernel[grid](\n A,\n B,\n M, N, K,\n output,\n block_m=block_m,\n block_n=block_n,\n block_k=block_k,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return output\n\n\n\n" + ] + }, + { + "id": "openseek-8-b31e29a78632426c85b9255e288fff1d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton implementation provides a mechanism to calculate cross-entropy loss and its gradients efficiently using GPU parallelism. It involves two main kernels: `cross_entropy_fwd_kernel` for the forward pass and `cross_entropy_bwd_kernel` for the backward pass.\n\n The `cross_entropy_fwd_kernel` is executed for each row (determined by `tl.program_id(0)`) and for each block of columns within that row (determined by `tl.program_id(1)`). The kernel computes the logits for each block, applies logit scaling, and calculates the log-sum-exp (lse). It also handles scenarios with label smoothing, optionally computes a scaled squared value of the lse (z_loss), and stores both the loss and z_loss in the provided pointers.\n\n Specific checks are included for whether the label at the current index matches an `ignored_index`, which sets the loss and z_loss to zero if true. Moreover, conditional compilation flags like `HAS_SMOOTHING` and `SPLIT` enable or disable features such as smoothing and splitting the loss computation.\n\n The `cross_entropy_bwd_kernel` computes the gradient of the loss with respect to the logits (dlogits). It calculates probabilities from the logits and adjusts them based on the lse. Depending on whether label smoothing is enabled, it modifies the gradients accordingly and stores them. The grid configuration uses row index and block index to partition the computation efficiently.\n\n The wrapper functions, `cross_entropy_fwd` and `cross_entropy_bwd`, prepare and dispatch these kernels with the specified grid dimensions based on the input tensor shapes. They also print intermediate results such as `loss`, `lse`, `z_loss`, and `dlogits` for debugging purposes.\n\n Parameters:\n - logits: Input tensor containing the logit values.\n - labels: Tensor with the target class labels.\n - smoothing, logit_scale, lse_square_scale: Scalars adjusting the loss computation.\n - ignored_index: Specifies any label to ignore.\n - total_classes, class_start_idx: Help manage subsetting classes in parallel scenarios.\n - BLOCK_SIZE: Constant specifying the number of columns processed in parallel.\n - HAS_SMOOTHING, SPLIT: Boolean flags for feature toggling.\n\n Outputs:\n - Forward pass: Returns tensors for `loss`, `lse`, and `z_loss`.\n - Backward pass: Returns tensor `dlogits` containing gradients of the logits.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n z_loss_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n SPLIT: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n max_logits = tl.max(logits, 0)\n if HAS_SMOOTHING:\n sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)\n lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits\n tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)\n if label_idx == ignored_index:\n loss = 0.0\n z_loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(\n n_cols, (col_block_idx + 1) * BLOCK_SIZE\n ):\n logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n z_loss = lse_square_scale * lse * lse\n loss += z_loss\n else:\n z_loss = 0.0\n tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)\n if not SPLIT:\n tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)\n\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignored_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\ndef cross_entropy_fwd(\n logits, labels, smoothing, logit_scale, lse_square_scale, ignored_index, total_classes, class_start_idx, BLOCK_SIZE, HAS_SMOOTHING, SPLIT\n):\n n_rows, n_cols = logits.shape\n loss = torch.empty((n_rows, n_cols), dtype=torch.float32, device=logits.device)\n lse = torch.empty((n_rows, n_cols), dtype=torch.float32, device=logits.device)\n z_loss = torch.empty((n_rows, n_cols), dtype=torch.float32, device=logits.device)\n \n grid = (n_rows, (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE)\n \n # 调用前向内核,传递相关参数\n cross_entropy_fwd_kernel[grid](\n loss, lse, z_loss, logits, labels, smoothing, logit_scale, lse_square_scale, ignored_index, total_classes, class_start_idx, n_cols, n_rows, logits.stride(0), BLOCK_SIZE, HAS_SMOOTHING, SPLIT\n )\n \n # 打印损失、LSE和z_loss,帮助调试\n print(f\"Forward loss: {loss}\")\n print(f\"Forward LSE: {lse}\")\n print(f\"Forward z_loss: {z_loss}\")\n \n return loss, lse, z_loss\n\ndef cross_entropy_bwd(\n dloss, logits, lse, labels, smoothing, logit_scale, lse_square_scale, ignored_index, total_classes, class_start_idx, BLOCK_SIZE, HAS_SMOOTHING\n):\n n_rows, n_cols = logits.shape\n dlogits = torch.empty_like(logits)\n \n grid = (n_rows, (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE)\n \n # 调用反向内核,传递相关参数\n cross_entropy_bwd_kernel[grid](\n dlogits, dloss, logits, lse, labels, smoothing, logit_scale, lse_square_scale, ignored_index, total_classes, class_start_idx, n_cols, logits.stride(0), dlogits.stride(0), dloss.stride(0), BLOCK_SIZE, HAS_SMOOTHING\n )\n \n # 打印反向梯度,帮助调试\n print(f\"Backward dlogits: {dlogits}\")\n \n return dlogits\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4d1e0dc67eba4ecd97100475f761157f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel function `_fwd_kernel_int8kv` is designed to perform a scaled dot-product attention operation on inputs Q, K, and V, where Q represents the query tensor, K represents the key tensor, and V represents the value tensor. The function computes attention scores by taking the dot product of Q and K, scales them using a softmax scaling factor (derived from the inverse square root of the model dimension and a logarithmic factor), and applies a causal mask to enforce autoregressive behavior. The resulting probabilities are then used to weigh the values in V, producing the final attention output stored in Out. The function uses constant expression parameters (`H`, `BLOCK_DMODEL`, `BLOCK_M`, `BLOCK_N`) to define block sizes and tensor dimensions for efficient parallel execution. It accommodates prompt caching and supports multiple head and batch sizes. The wrapper function `context_attention_fwd_ppl_int8kv` configures the execution grid, sets parameters based on input tensor dimensions, and invokes the kernel. The wrapper ensures the block size is adjusted based on the GPU architecture (e.g., Tesla GPUs). This implementation is tailored for high-performance operations on compatible hardware, particularly optimized for workloads requiring efficient memory and computational throughput.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\nTESLA = \"Tesla\" in torch.cuda.get_device_name(0)\n\n@triton.jit\ndef _fwd_kernel_int8kv(\n Q,\n K,\n V,\n sm_scale,\n Out,\n B_Start_Loc,\n B_Seqlen,\n b_prompt_cache_len,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kb,\n stride_kh,\n stride_ks,\n stride_kd,\n stride_vb,\n stride_vh,\n stride_vs,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n H: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n cur_bh = tl.program_id(1)\n cur_batch = cur_bh // H\n cur_head = cur_bh % H\n\n cur_kv_head = cur_head // kv_group_num\n prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = block_start_loc + tl.arange(0, BLOCK_M)\n off_q = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs\n + cur_head * stride_qh\n + offs_d[None, :] * stride_qd\n )\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n block_end_loc = tl.minimum(block_start_loc + BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len)\n # causal mask\n for start_n in range(0, block_mask * block_end_loc, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n # k = tl.load(\n # k_ptrs + (start_n + offs_n[None, :]) * stride_ks,\n # mask=(start_n + offs_n[None, :]) < block_end_loc,\n # other=0,\n # )\n off_k = (\n cur_batch * stride_kb\n + (start_n + offs_n[None, :]) * stride_ks\n + cur_kv_head * stride_kh\n + offs_d[:, None] * stride_kd\n )\n k = tl.load(K + off_k, mask=(start_n + offs_n[None, :]) < block_end_loc, other=0.0)\n\n qk = tl.dot(q, k)\n mask = (offs_m[:, None] + prompt_cache_len) >= (start_n + offs_n[None, :])\n qk = tl.where(mask, qk * sm_scale, -1.0e8)\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk -= m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n\n # -- update m_i and l_i\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n # v = tl.load(\n # v_ptrs + (start_n + offs_n[:, None]) * stride_vs,\n # mask=(start_n + offs_n[:, None]) < block_end_loc,\n # other=0.0,\n # )\n off_v = (\n cur_batch * stride_vb\n + (start_n + offs_n[:, None]) * stride_vs\n + cur_kv_head * stride_vh\n + offs_d[None, :] * stride_vd\n )\n v = tl.load(V + off_v, mask=(start_n + offs_n[:, None]) < block_end_loc, other=0.0)\n\n p = p.to(v.dtype)\n acc = tl.dot(p, v, acc)\n # update m_i\n m_i = m_ij\n\n acc = acc / l_i[:, None]\n off_o = (\n (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs\n + cur_head * stride_oh\n + offs_d[None, :] * stride_od\n )\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n\n\n@torch.no_grad()\ndef context_attention_fwd_ppl_int8kv(q, k, v, o, b_start_loc, b_seq_len, max_input_len, b_prompt_cache_len):\n BLOCK_M = 128 if not TESLA else 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128, 256}\n\n # 计算scale系数, 并乘以 1/log(2) = 1.4426950408889634,\n # 算子内部使用 tl.math.exp2 来使计算与标准attention等价。\n sm_scale = 1.0 / (Lq ** 0.5) * 1.4426950408889634\n batch, head = b_seq_len.shape[0], q.shape[1]\n kv_group_num = q.shape[1] // k.shape[1]\n\n grid = lambda meta: (triton.cdiv(max_input_len, meta[\"BLOCK_M\"]), batch * head, 1)\n BLOCK_N = BLOCK_M\n num_warps = 4 if Lk <= 64 else 8\n num_stages = 1\n\n _fwd_kernel_int8kv[grid](\n q,\n k,\n v,\n sm_scale,\n o,\n b_start_loc,\n b_seq_len,\n b_prompt_cache_len,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k.stride(3),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n v.stride(3),\n o.stride(0),\n o.stride(1),\n o.stride(2),\n kv_group_num=kv_group_num,\n H=head,\n BLOCK_DMODEL=Lk,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-5c240bd91c81494a958dc07d1405cf91", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements a Triton kernel called `masked_add_kernel` and a function `masked_add` to invoke the kernel. The kernel operates on CUDA tensors `grad_ptr`, `p_ptr`, and `p_mask_ptr` over blocks of size `BLOCK_SIZE`. It calculates indices to work on using `tl.program_id` and checks bounds using a `mask`. It then loads elements from `p_mask_ptr` and converts them to boolean masks using `tl.int1`. The kernel performs an element-wise addition on `grad_ptr` with `p_ptr` multiplied by `alpha`, masked by `p_mask`. The result is stored back in `grad_ptr`. The wrapper function `masked_add` asserts that all tensors are on CUDA, with the same layout and strides. It determines the grid size based on `n_elements` and launches the kernel with the specified block size, passing tensor data and parameters for computation.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef masked_add_kernel(grad_ptr,\n p_ptr,\n p_mask_ptr,\n n_elements,\n alpha,\n BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n p_mask = tl.load(p_mask_ptr + offsets, mask=mask).to(tl.int1)\n mask = mask & ~p_mask\n p = tl.load(p_ptr + offsets, mask=mask)\n grad = tl.load(grad_ptr + offsets, mask=mask)\n grad += p * alpha\n tl.store(grad_ptr + offsets, grad, mask=mask)\n\ndef masked_add(grad: torch.Tensor, p_data: torch.Tensor, p_mask: torch.Tensor, alpha: float = 0):\n '''\n equivalent to\n grad.add_(p.data * (1 - p.mask), alpha=decay)\n '''\n assert grad.is_cuda and p_data.is_cuda and p_mask.is_cuda\n assert (grad.layout, p_data.layout, p_mask.layout) == (torch.strided, torch.strided, torch.strided)\n assert grad.stride() == p_data.stride() == p_mask.stride()\n n_elements = grad.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n masked_add_kernel[grid](grad, p_data, p_mask, n_elements, alpha, BLOCK_SIZE=1024)\n\n\n\n" + ] + }, + { + "id": "openseek-8-593da267576a492e8ec3dbd15cb38023", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel '_triton_rope' is a JIT-compiled function designed for efficient application of rotary position embeddings to input matrices q (queries) and k (keys). The function takes pointers to these matrices, their row strides, cosine and sine rotation vectors, and several configuration parameters as inputs. Inside the kernel, program_id (pid) is used to distribute work across elements in a batch-sequence space.\n\nThe kernel reads slices of the q and k matrices, applies rotary transformations using cosine and sine components, and stores the transformed slices back. The transformation is effectively a 2D rotation in each head's dimensional space, achieved through vectorized operations. The choice of forward or backward transformation is determined by the BACKWARD_PASS constant, allowing flexible operations depending on the training phase.\n\nThe function 'rope_forward' serves as a wrapper around '_triton_rope'. It first transposes the query and key matrices to the appropriate format, computes the necessary paddings using Triton's utility functions to ensure efficient processing, and calls the kernel on each batch-sequence pair. It configures the execution grid based on the number of rows (batch_size * seq_len) and ensures the inputs are contiguous, optimizing memory access patterns. Finally, it returns the matrices to their original shapes and provides the updated q, k, cos, and sin matrices.\n", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_rope(\n q_ptr,\n q_row_stride,\n k_ptr,\n k_row_stride,\n cos,\n cos_row_stride,\n sin,\n sin_row_stride,\n sl,\n bs: tl.constexpr,\n n_qh: tl.constexpr,\n n_kh: tl.constexpr,\n hd: tl.constexpr,\n pad_n_qh: tl.constexpr,\n pad_n_kh: tl.constexpr,\n pad_hd: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n BACKWARD_PASS: tl.constexpr = False,\n):\n pid = tl.program_id(0)\n\n q_ptr = q_ptr + pid * q_row_stride\n k_ptr = k_ptr + pid * k_row_stride\n\n cos_row_idx = pid % (sl)\n cos = cos + cos_row_idx * cos_row_stride\n sin = sin + cos_row_idx * sin_row_stride\n cos_offsets = tl.arange(0, pad_hd // 2)\n cos_mask = cos_offsets < hd // 2\n cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)\n sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)\n\n first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)\n k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)\n\n second_half_q_offsets = first_half_q_offsets + (hd // 2)\n second_half_k_offsets = first_half_k_offsets + (hd // 2)\n second_q_mask = first_q_mask\n second_k_mask = first_k_mask\n q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)\n k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)\n\n if not BACKWARD_PASS:\n new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n else:\n new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n\n\ndef rope_forward(q, k, cos, sin):\n q = q.transpose(1, 2)\n k = k.transpose(1, 2)\n\n batch_size, seq_len, n_q_head, head_dim = q.shape\n n_kv_head = k.shape[2]\n pad_hd = triton.next_power_of_2(head_dim)\n pad_n_q_head = triton.next_power_of_2(n_q_head)\n pad_n_kv_head = triton.next_power_of_2(n_kv_head)\n BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)\n\n n_row = batch_size * seq_len\n\n q = q.contiguous()\n k = k.contiguous()\n cos = cos.contiguous()\n sin = sin.contiguous()\n\n _triton_rope[(n_row,)](\n q,\n q.stride(1),\n k,\n k.stride(1),\n cos,\n cos.stride(-2),\n sin,\n sin.stride(-2),\n seq_len,\n batch_size,\n n_q_head,\n n_kv_head,\n head_dim,\n pad_n_q_head,\n pad_n_kv_head,\n pad_hd,\n BLOCK_SIZE=BLOCK_SIZE,\n BACKWARD_PASS=False,\n )\n return q.transpose(1, 2), k.transpose(1, 2), cos, sin\n\n\n\n" + ] + }, + { + "id": "openseek-8-708fc685b45e42bda23b94eeb2a8101b", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel _bmm_chunk_bwd_kernel is designed to compute the backward pass for a batched matrix multiplication operation, particularly in scenarios involving chunking of one of the matrices. The kernel is decorated with several configurations for autotuning, each specifying the number of stages, warps, and block sizes (BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_CS) for optimal execution.\n\n Inputs include pointers to matrices (a_ptr, dout_ptr) and their respective strides and dimensions, such as stride_a_batch, stride_dout_csize_m, etc. The kernel calculates the partial dot product accumulation using tl.dot and can add a residual from res_ptr if HAS_RESIDUAL is set to True. The result is stored in db_ptr.\n\n The auxiliary function _bmm_chunk_bwd prepares tensors, ensuring they are contiguous and appropriately shaped, while determining the computation grid size based on the meta-parameters. It invokes the Triton kernel using triton.jit for execution on the CUDA device associated with the input tensor. This function ensures alignment of tensor dimensions, strides, and types, converting them as necessary for the underlying Triton operations.\n ", + "output": [ + "import math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_CS': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K'],\n)\n@triton.jit\ndef _bmm_chunk_bwd_kernel(\n a_ptr, dout_ptr, db_ptr, res_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_dout_batch, stride_dout_chunk, stride_dout_head, stride_dout_csize_m, stride_dout_csize_n,\n stride_db_batch, stride_db_seqlen, stride_db_head, stride_db_k,\n stride_res_batch, stride_res_seqlen, stride_res_head, stride_res_k,\n dot_dtype: tl.constexpr,\n HAS_RESIDUAL: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_CS: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(K, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n dout_ptr += pid_b * stride_dout_batch + pid_c * stride_dout_chunk + pid_h * stride_dout_head\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_cs = tl.arange(0, BLOCK_SIZE_CS)\n dout_ptrs = dout_ptr + (offs_m[:, None] * stride_dout_csize_n + offs_cs[None, :] * stride_dout_csize_m)\n a_ptrs = a_ptr + (offs_cs[:, None] * stride_a_seqlen + offs_n[None, :] * stride_ak)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for cs in range(0, tl.cdiv(chunk_size_limit, BLOCK_SIZE_CS)):\n dout = tl.load(dout_ptrs, mask=(offs_m[:, None] < chunk_size) & (offs_cs[None, :] < chunk_size_limit - cs * BLOCK_SIZE_CS), other=0.0).to(dot_dtype)\n a = tl.load(a_ptrs, mask=(offs_cs[:, None] < chunk_size_limit - cs * BLOCK_SIZE_CS) & (offs_n[None, :] < K), other=0.0).to(dot_dtype)\n acc += tl.dot(dout, a)\n dout_ptrs += BLOCK_SIZE_CS * stride_dout_csize_m\n a_ptrs += BLOCK_SIZE_CS * stride_a_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_RESIDUAL:\n res_ptr += pid_b * stride_res_batch + pid_c * chunk_size * stride_res_seqlen + pid_h * stride_res_head\n res_ptrs = res_ptr + (offs_m[:, None] * stride_res_seqlen + offs_n[None, :] * stride_res_k)\n res = tl.load(res_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K)).to(tl.float32)\n acc += res\n db = acc.to(db_ptr.dtype.element_ty)\n\n db_ptr += pid_b * stride_db_batch + pid_c * chunk_size * stride_db_seqlen + pid_h * stride_db_head\n db_ptrs = db_ptr + (offs_m[:, None] * stride_db_seqlen + offs_n[None, :] * stride_db_k)\n tl.store(db_ptrs, db, mask=(offs_m[:, None] < chunk_size_limit) & (offs_n[None, :] < K))\n\ndef _bmm_chunk_bwd(a, dout, residual=None, out=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n nchunks, chunk_size = dout.shape[1], dout.shape[-1]\n if a.stride(-1) != 1 and a.stride(-2) != 1:\n a = a.contiguous()\n if dout.stride(-1) != 1 and dout.stride(-2) != 1:\n dout = dout.contiguous()\n if residual is not None:\n assert residual.shape == (batch, seqlen, k) if not has_groups else (batch, seqlen, ngroups, k)\n if residual.stride(-1) != 1 and residual.stride(1) != 1:\n residual = residual.contiguous()\n if out is not None:\n assert out.shape == a.shape\n assert out.stride(-1) == 1 or out.stride(1) == 1\n else:\n out = torch.empty_like(a)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or dout.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or dout.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(k, META['BLOCK_SIZE_N']), batch,\n nchunks if not has_groups else nchunks * ngroups)\n residual_strides = ((residual.stride(0), residual.stride(1), 0 if not has_groups else residual.stride(2),\n residual.stride(-1))\n if residual is not None else (0, 0, 0, 0))\n with torch.cuda.device(a.device.index):\n _bmm_chunk_bwd_kernel[grid](\n a, dout, out, residual,\n int(seqlen), int(chunk_size), int(k), int(ngroups if has_groups else 1),\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n dout.stride(0), dout.stride(1), 0 if not has_groups else dout.stride(2), dout.stride(-2), dout.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-1),\n residual_strides[0], residual_strides[1], residual_strides[2], residual_strides[3],\n dot_dtype,\n HAS_RESIDUAL=residual is not None,\n )\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-11686975255945f798bdbe23b1d632ef", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton implementation provides a specialized operator for computing the Kullback-Leibler divergence between predicted and target distributions, with support for various reduction modes and optimized gradient computation.\n\n The `kldiv_forward_triton` function orchestrates the forward computation of the KL divergence. Inputs to this function include `y_pred` (a tensor of predicted log-probabilities with shape [BT, V]), `y_true` (ground truth tensor of the same shape), `log_target` (a boolean flag indicating whether `y_true` is in log-space), `reduction` (a string determining how the results are aggregated: \"none\", \"sum\", \"mean\", or \"batchmean\"), and `eps` (a small value added for numerical stability in logarithmic operations).\n\n The core computation is executed by the `_kldiv_kernel_forward` kernel. This kernel processes each batch independently. It calculates the KL divergence using the formula `KL(y_true || y) = y_true * (log(y_true) - log(y))` when `log_target` is False, and uses `loss = exp(y_true) * (y_true - y)` otherwise. The kernel iterates over columns in blocks of size `BLOCK_SIZE`. If `reduction` is \"none\", the computed losses for each position are stored directly in `loss_ptr`. For other reduction modes, the losses are summed over the batch or entire tensor as specified.\n\n In the backward pass, handled by `kldiv_backward_triton`, the gradients with respect to the inputs are calculated. The inputs are `target` (ground truth values), `grad_output` (incoming gradient from the network), `new_grads` (a tensor to hold the computed gradients), and `log_target`. The `_kldiv_kernel_backward` kernel performs the differentiation step, outputting gradients adjusted according to whether `log_target` is true (where the exponential form is used).\n\n Parameters such as `BLOCK_SIZE` and `num_warps` are crucial for tuning the computation's performance, dictating the block of data processed per iteration and parallelism level, respectively.\n ", + "output": [ + "from typing import Literal\nimport torch\nimport triton\nimport triton.language as tl\n\n\nMAX_FUSED_SIZE = 65536 // 4 # 65536 // 4 or 8 works the best\nREDUCTION_LITERAL = Literal[\"none\", \"sum\", \"mean\", \"batchmean\"]\n_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)\n\n@triton.jit\ndef _kldiv_kernel_forward(\n y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space\n y_stride, # int, prediction stride\n gt_ptr, # [B, S], ground truth ptr\n gt_stride, # int, ground truth stride\n loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr\n loss_stride, # int, output stride\n n_cols, # int, number of columns in the input tensor\n eps,\n BLOCK_SIZE: tl.constexpr,\n log_target: tl.constexpr = False,\n reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,\n):\n pid = tl.program_id(0).to(tl.int64)\n y_ptr += pid * y_stride\n gt_ptr += pid * gt_stride\n loss_ptr += pid * loss_stride\n\n base_offsets = tl.arange(0, BLOCK_SIZE)\n\n loss_sum = 0.0\n for i in range(0, n_cols, BLOCK_SIZE):\n offsets = i + base_offsets\n mask = offsets < n_cols\n y = tl.load(y_ptr + offsets, mask=mask, other=0.0)\n y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)\n\n # KL(y_true || y) = y_true * (log(y_true) - log(y))\n # We compute KL(y_true || y) with y in the log-space\n if not log_target:\n loss = y_true * (tl.log(tl.maximum(y_true, eps)) - y)\n else:\n loss = tl.exp(y_true) * (y_true - y)\n\n if reduction == 0: # _REDUCTION_MODE_NONE\n tl.store(loss_ptr + offsets, loss, mask=mask)\n else:\n loss_sum += tl.sum(loss, axis=0)\n\n if reduction != 0:\n tl.store(loss_ptr, loss_sum)\n\n\n@triton.jit\ndef _kldiv_kernel_backward(\n target_ptr,\n target_stride,\n new_grads_ptr,\n new_grads_stride,\n n_cols,\n BLOCK_SIZE: tl.constexpr,\n log_target: tl.constexpr = False,\n):\n pid = tl.program_id(0).to(tl.int64)\n\n target_ptr += pid * target_stride\n new_grads_ptr += pid * new_grads_stride\n\n offsets = tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_cols\n\n for i in range(0, n_cols, BLOCK_SIZE):\n offsets = i + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_cols\n\n target = tl.load(target_ptr + offsets, mask=mask, other=0.0)\n\n if not log_target:\n res = target * -1\n else:\n res = -tl.exp(target)\n\n tl.store(new_grads_ptr + offsets, res, mask=mask)\n\n\ndef kldiv_forward_triton(y_pred, y_true, log_target, reduction, eps): # [BT, V]\n BT, V = y_pred.shape\n\n BLOCK_SIZE = min(16384, triton.next_power_of_2(V))\n num_warps = 4 if BLOCK_SIZE < 2048 else 8 if BLOCK_SIZE < 8192 else 16 if BLOCK_SIZE < 32768 else 32\n\n grid = (BT,)\n reduction = {\"none\": 0, \"sum\": 1, \"mean\": 2, \"batchmean\": 3}[reduction]\n\n out_size = (BT, V) if reduction == 0 else (BT,)\n output_tensor = torch.zeros(out_size, device=y_pred.device, dtype=torch.float32)\n\n _kldiv_kernel_forward[grid](\n y_pred,\n y_pred.stride(0),\n y_true,\n y_true.stride(0),\n output_tensor,\n output_tensor.stride(0),\n V,\n eps=eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n log_target=log_target,\n reduction=reduction,\n )\n\n if reduction == 3: # _REDUCTION_MODE_BATCHMEAN\n return output_tensor.sum() / BT\n elif reduction == 1: # _REDUCTION_MODE_SUM\n return output_tensor.sum(dim=0)\n elif reduction == 2: # _REDUCTION_MODE_MEAN\n return output_tensor.sum() / (BT * V)\n else:\n return output_tensor\n\n\ndef kldiv_backward_triton(target, grad_output, new_grads, log_target):\n BT, V = target.shape\n\n BLOCK_SIZE = min(16384, triton.next_power_of_2(V))\n num_warps = 4 if BLOCK_SIZE < 2048 else 8 if BLOCK_SIZE < 8192 else 16 if BLOCK_SIZE < 32768 else 32\n\n grid = (BT,)\n\n _kldiv_kernel_backward[grid](\n target,\n target.stride(0),\n new_grads,\n new_grads.stride(0),\n V,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n log_target=log_target,\n )\n\n if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)):\n return new_grads\n\n return new_grads * grad_output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-ab9912dc3e0f4deeb6fdd996a1d71e8a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-based softmax operator consists of two main functions: `softmax_kernel` and `softmax`.\n The `softmax_kernel` is decorated with `@triton.jit`, indicating it is compiled for execution on Triton-compatible GPUs. It takes pointers to input and output data, stride information for row access, and a constant `BLOCK_SIZE`.\n The kernel function computes the softmax transformation for each row independently. It uses `tl.program_id(0)` to get the current row index for the 1D grid. It calculates the input pointers for the row, loads data into SRAM with `tl.load`, subtracts the maximum value to ensure numerical stability, and applies the exponential function to compute the numerator of the softmax.\n The denominator is calculated by summing the exponentiated values. The output is computed as the element-wise division of the numerator by the denominator and stored back using `tl.store`.\n The `softmax` function configures execution parameters such as grid size and block size based on input dimensions. It creates an empty output tensor and invokes the `softmax_kernel` using Triton’s grid execution syntax, where each row of the input tensor is processed by an individual Triton kernel instance.\n It computes the `BLOCK_SIZE` as the next power of two of the number of columns to optimize for memory alignment and performance, adjusting the number of warps for larger blocks. \n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n # Subtract maximum for numerical stability\n row_minus_max = row - tl.max(row, axis=0)\n # Note that exponentiation in Triton is fast but approximate (i.e., think __expf in CUDA)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row of the input matrix\n softmax_kernel[(n_rows, )](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c2d868d8b5a646b79275ff3acd648064", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton operator is designed to compute the power of each element in a tensor by a scalar value, using a specialized kernel function `pow_func_scalar_tensor_kernel_rank_1`. The wrapper function `pow_func_scalar_tensor_wrapper_rank_1` sets up the computation by ensuring input-output shape alignment and using heuristics to determine optimal computation parameters, such as tile size, number of warps, and grid configuration. It handles both torch `Tensor` and `StridedBuffer` inputs. The kernel function processes data in parallel, managing memory offsets and data loading using block pointers, and performs boundary checks to ensure safe memory operations. The operation is divided into tiles, and depending on the configuration, it either uses a monolithic or a grid-stride-loop style kernel execution. The use of Triton's `pow` function ensures compatibility and efficiency in computing the exponentiation on the GPU.\n ", + "output": [ + "import math\nfrom typing import Union\nimport torch\nimport triton\nfrom triton import language as tl\n\ntry:\n from triton.language.extra.cuda.libdevice import pow as _pow\nexcept ImportError:\n try:\n from triton.language.math import pow as _pow\n except ImportError:\n from triton.language.libdevice import pow as _pow\n\ndef heuristics_for_tile_size(max_tile_size, *sizes):\n ndim = len(sizes)\n tile_sizes = [0 for _ in range(ndim)]\n for i in range(ndim):\n size = sizes[ndim - 1 - i]\n tile_size = min(max_tile_size, triton.next_power_of_2(size))\n tile_sizes[ndim - 1 - i] = tile_size\n max_tile_size = max(1, max_tile_size // tile_size)\n return tuple(tile_sizes)\n\ndef heuristics_for_num_warps(tile_size):\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n \n\nclass StridedBuffer:\n \"\"\"A drop-in replacement of torch.Tensor that can be used in wrapper generated by\n PointwiseDynamicFunction. It allows us to use a different shape, stride, data\n pointer that that of the base tensor.\n\n It is a kind of reinterpretation of the base tensor. We make this class since we\n cannot get a Tensor view with negative strides via torch APIs, while we need this\n to implement flip op.\n\n Although generated code can accept torch.Tensor & StridedBuffer, but StridedBuffer\n may not have all the methods as torch.Tensors do. We add some attributes & methods\n with the same name as torch.Tensor, which are used in the generated code. But we\n may not cover all the methods, add one if what you need is missing here.\n\n And can also be used in triton kernels since it also has dtype & data_ptr().\n \"\"\"\n\n def __init__(\n self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0\n ):\n self._base = base\n self.dtype = dtype or base.dtype\n if offset == 0:\n self._data_ptr = self._base.data_ptr()\n else:\n offset = self.dtype.itemsize * offset\n self._data_ptr = self._base.data_ptr() + offset\n self.shape = tuple(shape if shape is not None else self._base.shape)\n self._strides = tuple(strides if strides is not None else self._base.stride())\n self.device = self._base.device\n self.ndim = len(self.shape)\n\n def stride(self):\n return self._strides\n\n def size(self):\n return self.shape\n\n def element_size(self):\n return self.dtype.itemsize\n\n def numel(self):\n return math.prod(self.shape)\n\n def dim(self):\n return self.ndim\n\n def unwrap(self):\n return self._base\n\n def data_ptr(self):\n return self._data_ptr\n\n\ndef pow_func_scalar_tensor_wrapper_rank_1(val0, in0: Union[torch.Tensor, StridedBuffer], /, *, out0: Union[torch.Tensor, StridedBuffer]): \n \"\"\"Generated wrapper function with Pointwise: scalar, StridedBuffer, StridedBuffer(a1!) -> StridedBuffer(a1!)\"\"\"\n assert in0.shape == out0.shape, 'operand shapes mismatch'\n # task partitioning\n shape = out0.shape\n num_tasks = out0.numel()\n tile_sizes = heuristics_for_tile_size(512, *shape)\n tile_size = math.prod(tile_sizes)\n num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))\n num_ctas = min(65536, num_tiles)\n tiles_per_cta = triton.cdiv(num_tiles, num_ctas)\n num_warps = heuristics_for_num_warps(tile_size)\n one_tile_per_cta = tiles_per_cta==1\n grid = (num_ctas, 1, 1)\n # kernel launch\n in0_strides = in0.stride()\n in0_stride_order = (0,)\n out0_strides = out0.stride()\n out0_stride_order = (0,)\n with torch.cuda._DeviceGuard(in0.device.index):\n pow_func_scalar_tensor_kernel_rank_1[grid](\n val0, in0, out0,\n in0_strides[0], # stride for in0\n in0_stride_order[0], # stride order for in0\n out0_strides[0], # stride for out0\n out0_stride_order[0], # stride orderfor out0\n shape[0], # task indexing space\n num_tasks, # num tasks\n tiles_per_cta=tiles_per_cta, # tiles_per_cta\n tile_size0=tile_sizes[0],\n one_tile_per_cta=one_tile_per_cta,\n num_warps=num_warps,\n )\n return out0\n\n@triton.jit\ndef pow_func_scalar_tensor(x, exponent):\n return _pow(x.to(tl.float32), exponent)\n\n@triton.jit(do_not_specialize=['val0'])\ndef pow_func_scalar_tensor_kernel_rank_1(\n val0,\n in0_ptr: tl.tensor, # of tl.pointer_type\n out0_ptr: tl.tensor, # of tl.pointer_type\n in0_stride0: int, # strides for in0\n in0_stride_order0: tl.constexpr, # stride order for in0\n out0_stride0: int, # strides for out0\n out0_stride_order0: tl.constexpr, # stride order for out0\n s0: int, # task_space\n num_tasks: int,\n tiles_per_cta: int,\n tile_size0: tl.constexpr,\n one_tile_per_cta: tl.constexpr,\n):\n pid = tl.program_id(0)\n num_tiles0 = tl.cdiv(s0, tile_size0)\n if one_tile_per_cta: # monolitic kernel style\n tile_id = pid\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = pow_func_scalar_tensor(val0, in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n else: # grid-stride-loop style kernel\n num_ctas = tl.num_programs(0)\n for j in range(0, tiles_per_cta):\n tile_id = pid + j * num_ctas\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = pow_func_scalar_tensor(val0, in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-9eab71e5a4ea4fe7b0d785392ba1e902", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code defines a Triton-accelerated L2 normalization routine, consisting of two parts: a forward pass `_l2_norm_fwd` and a backward pass `_l2_norm_bwd`. Both parts utilize their respective kernels `_l2_norm_fwd_1pass_kernel` and `_l2_norm_bwd_kernel` for GPU-accelerated computations.\n\n - `_l2_norm_fwd_1pass_kernel`: This function is a Triton kernel that performs the forward L2 normalization on a per-row basis for a 2D input tensor `X`. The input tensor is expected to have its last dimension stride as 1. The kernel computes the L2 norm by:\n * Loading a block of data from the current row specified by `row = tl.program_id(0)`.\n * Calculating the sum of squares to obtain the variance.\n * Computing the reciprocal square root of the variance (rstd) to normalize `x`.\n * Storing the normalized result into the output tensor `Y`.\n\n - `_l2_norm_bwd_kernel`: This kernel performs the backward pass necessary for gradient computation during L2 normalization. It takes the input tensor `X`, its gradients `DY`, and outputs the gradient with respect to `X` in `DX`. The computations include:\n * Loading the inputs and gradients for the row defined by `row = tl.program_id(0)`.\n * Using the precomputed variance and rstd from the forward pass to determine gradients.\n * Applying normalization logic to compute the gradient of the loss with respect to the input `X`.\n * Storing the computed gradient into `DX`.\n\n Both `_l2_norm_fwd` and `_l2_norm_bwd` handle tensor reshaping and preparation before kernel invocation. They also enforce a constraint on the feature dimension size, ensuring it remains below 64KB to fit Triton's kernel memory limits. These routines are designed to work in environments with PyTorch and Triton, aiming at efficient parallel execution on modern GPUs.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n # Compute mean and variance\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) \n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n mask = cols < N\n y = x * rstd\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x) \n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\ndef _l2_norm_fwd(\n x, eps=1e-6\n):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n assert x.stride(-1) == 1 \n # allocate output\n y = torch.empty_like(x)\n assert y.stride(-1) == 1\n N = x.shape[-1]\n M = x.shape[0]\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\ndef _l2_norm_bwd(\n x, dy, eps=1e-5,\n):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n assert dy.shape == x.shape\n # allocate output\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-aff8fcb956d8465a89f1901b17b7399f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code defines a custom layer normalization operation using Triton, a library designed for writing high-performance custom GPU kernels. The `calculate_settings` function computes the optimal block size and number of warps for a given number of columns `n`. It ensures the block size does not exceed the maximum allowed size (`MAX_FUSED_SIZE`). The `layernorm_forward` function is a JIT-compiled kernel for computing the forward pass of layer normalization. It calculates the mean and variance of each row of the input `X`, applies normalization using these statistics, and scales/shifts the result with learned parameters `W` and `b`. The `layernorm_backward` function computes the gradients of the input with respect to the output loss gradient `dY`, using stored intermediate results (`r`, `mu`) from the forward pass. `Fast_Layernorm` class is a PyTorch `autograd.Function` implementing this custom layer normalization. The `forward` method sets up the computation and calls `layernorm_forward`, while `backward` calls `layernorm_backward` to compute the gradient of `X` with respect to `dY`. The `fast_layernorm` function applies this custom layer normalization by interfacing with PyTorch's `nn.LayerNorm` parameters, ensuring that the operation is done with `W` and `b` derived from the PyTorch module.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\nnext_power_of_2 = triton.next_power_of_2\nMAX_FUSED_SIZE : int = 65536\n\ndef calculate_settings(n : int) -> (int, int,):\n BLOCK_SIZE : int = next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n num_warps : int = 4\n if BLOCK_SIZE >= 32768: num_warps = 32\n elif BLOCK_SIZE >= 8192: num_warps = 16\n elif BLOCK_SIZE >= 2048: num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W,\n b,\n r,\n mu,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr\n):\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx\n mu += row_idx\n\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n mean_X = tl.sum(X_row, axis = 0) / n_cols\n XX = X_row - mean_X\n row_var = tl.sum(XX * XX, axis = 0) / n_cols\n inv_var = tl.math.rsqrt(row_var + eps)\n tl.store (r, inv_var)\n tl.store (mu, mean_X)\n output = (XX * inv_var) * W_row + b_row\n tl.store(Y + col_offsets, output, mask = mask)\n\n@triton.jit\ndef layernorm_backward(\n dY, dY_row_stride,\n X, X_row_stride,\n W,\n b,\n r,\n mu,\n n_cols, eps,\n BLOCK_SIZE : tl.constexpr\n):\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dY += row_idx * dY_row_stride\n X += row_idx * X_row_stride\n r += row_idx\n mu += row_idx\n\n dY_row = tl.load(dY + col_offsets, mask = mask, other = 0).to(tl.float32)\n X_row = tl.load(X + col_offsets, mask = mask, other = 0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask = mask, other = 0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask = mask, other = 0).to(tl.float32)\n\n inv_var = tl.load(r) .to(tl.float32)\n mean = tl.load(mu).to(tl.float32)\n normed = (X_row - mean) * inv_var\n dY_W = dY_row * W_row\n dX_row = dY_W - tl.sum(dY_W, axis = 0) / n_cols - normed * tl.sum(dY_W * normed, axis = 0) / n_cols\n dX_row = dX_row * inv_var\n tl.store(dY + col_offsets, dX_row, mask = mask)\n\nclass Fast_Layernorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, X, W, b, eps):\n shape = X.shape\n dim = shape[-1]\n X = X.view(-1, dim)\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n Y = torch.empty((n_rows, n_cols), dtype = X.dtype, device = \"cuda:0\")\n r = torch.empty(n_rows, dtype = torch.float32, device = \"cuda:0\")\n mu = torch.empty(n_rows, dtype = torch.float32, device = \"cuda:0\")\n\n layernorm_forward[(n_rows,)](\n Y, Y.stride(0),\n X, X.stride(0),\n W,\n b,\n r,\n mu,\n n_cols, eps,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n ctx.eps = eps\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.save_for_backward(X, W, b, r, mu)\n return Y.view(*shape)\n \n @staticmethod\n def backward(ctx, dY):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.view(-1, dim)\n X, W, b, r, mu = ctx.saved_tensors\n n_rows, n_cols = dY.shape\n\n layernorm_backward[(n_rows,)](\n dY, dY.stride(0),\n X, X .stride(0),\n W,\n b,\n r,\n mu,\n n_cols, ctx.eps,\n BLOCK_SIZE = ctx.BLOCK_SIZE,\n num_warps = ctx.num_warps,\n )\n dX = dY.view(*shape)\n return dX, None, None, None, None\n \ndef fast_layernorm(layernorm, X):\n assert(layernorm.elementwise_affine is True)\n W = layernorm.weight\n bias = layernorm.bias\n eps = layernorm.variance_epsilon if \\\n hasattr(layernorm, \"variance_epsilon\") \\\n else layernorm.eps\n out = Fast_Layernorm.apply(X, W, bias, eps)\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-ed04847d69274f4ba6b32a57371ee565", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton kernel implementation calculates the mean of a tensor across specified dimensions. The kernel, `mean_dim_kernel`, receives tensor `X`, a tensor `Mean` to store the output, the dimensions `M` and `N` indicating the size, and `BLOCK_M` and `BLOCK_N` defining the block sizes for processing. The kernel uses a program ID `pid` to map each computation block to the correct row in `X`. It calculates a mask `row_mask` to ensure operations only occur within valid tensor bounds.\n\n The kernel computes the mean across a specified dimension by iteratively loading elements from `X` in blocks and accumulating their sum in `_mean`. After summing, it divides by `N` to find the mean, storing the result in `Mean`.\n\n The function `dim_compress` is used to rearrange the input tensor dimensions for efficient memory access patterns. This function takes a tensor `inp` and a list of dimensions `dims` to reduce, returning a permuted tensor with these dimensions at the end.\n\n The `mean_dim` function is a wrapper that prepares the inputs for the kernel. It converts the input tensor `x` to a format compatible with the kernel by reducing dimensions specified in `dim`. It calculates `M` as the product of remaining dimensions after reduction and `N` as the product of reduction dimensions. It creates an empty output tensor `out` with the required shape. The kernel is then launched using a grid configuration determined by the `cdiv` function, dividing `M` by `BLOCK_M`. After execution, if `keepdim` is `False`, the output tensor's reduced dimensions are squeezed out. The function returns the resulting tensor.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef mean_dim_kernel(X, Mean, M, N, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr):\n # Map the program id to the row of X it should compute.\n pid = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)[:, None] \n X = X + pid * N\n Mean = Mean + pid\n row_mask = pid < M\n\n # Compute mean\n _mean = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n for off in range(0, N, BLOCK_N):\n cols = off + tl.arange(0, BLOCK_N)[None, :]\n col_mask = cols < N\n mask = row_mask and col_mask\n\n a = tl.load(X + cols, mask, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=1) / N\n mean = mean[:, None]\n tl.store(Mean, mean, row_mask)\n\ndef dim_compress(inp: torch.Tensor, dims):\n if isinstance(dims, int):\n dims = [dims]\n dim = inp.ndim\n stride = inp.stride()\n batch_dim = [i for i in range(dim) if i not in dims]\n sorted_reduction_dim = sorted(dims, key=lambda x: stride[x], reverse=True)\n order = batch_dim + sorted_reduction_dim\n return inp.permute(order).contiguous()\n\ndef mean_dim(x, dim, keepdim=False, *, dtype=None):\n if dtype is None:\n dtype = x.dtype\n \n shape = list(x.shape)\n if isinstance(dim, int):\n dim = [dim]\n dim = [d % x.ndim for d in dim]\n x = dim_compress(x, dim)\n N = 1\n for i in dim:\n N *= shape[i]\n shape[i] = 1\n M = x.numel() // N\n out = torch.empty(shape, dtype=dtype, device=x.device)\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]),)\n\n with torch.cuda.device(x.device):\n mean_dim_kernel[grid](x, out, M, N, BLOCK_M=8, BLOCK_N=8)\n if not keepdim:\n out = out.squeeze(dim)\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-62ec7ebdf48347c28877793a6a0ef40d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton code is designed to efficiently fill key and value states into a cache for attention mechanisms using two different kernels, based on whether quantization is applied. The primary function, `fill_kv_cache`, handles this operation. It requires input tensors for keys and values (`k_states`, `v_states`), the destination caches (`k_caches`, `v_caches`), and metadata about query start locations and sequence lengths. Optional tensors `k_scales_zeros` and `v_scales_zeros` are used when quantization is enabled, defining scale and zero-point adjustments for int4 or int8 representation. The logic branches depending on `quant_policy`: if it is 0, the `_fill_kv_cache_kernel` is invoked, which directly copies states to caches. Otherwise, `_fill_kv_cache_quant_kernel` quantizes inputs using helper functions `_quant_int4` and `_quant_int8`, which compute quantization parameters and pack data accordingly. Each kernel is executed over a multi-dimensional grid, determined by `batch_size` and `max_num_blocks`, using Triton's grid-stride loop pattern for parallelism. The code ensures that only valid data is processed and stored by computing sequence lengths and using boolean masks for conditional memory operations. Additionally, the kernels are optimized to handle different head and block dimensions by adjusting block sizes and grid configuration.\n ", + "output": [ + "from typing import Literal\nimport torch\nimport triton\nimport triton.language as tl\n\n\nKERNEL_META = dict()\n\ndef get_kernel_meta(tensor: torch.Tensor):\n \"\"\"kernel meta.\"\"\"\n return KERNEL_META\n \n@triton.jit\ndef _div_up(val, other):\n return (val + other - 1) // other\n\n\n@triton.jit\ndef _quant_int8(val):\n val_min = tl.min(val, 1)\n val_max = tl.max(val, 1)\n scales = (val_max - val_min) / 255\n zeros = -val_min / scales\n q_val = (val / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n return q_val, scales, zeros\n\n\n@triton.jit\ndef _quant_int4(val1, val2):\n val1 = val1.to(tl.float32)\n val2 = val2.to(tl.float32)\n val_min = tl.min(tl.minimum(val1, val2), 1)\n val_max = tl.max(tl.maximum(val1, val2), 1)\n scales = (val_max - val_min) / 15\n zeros = -val_min / scales\n q_val1 = (val1 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n q_val2 = (val2 / scales[:, None] + zeros[:, None] + 0.5).to(tl.uint8)\n q_val = q_val1 + q_val2 * 16\n return q_val, scales, zeros\n\n\n@triton.jit\ndef _fill_kv_cache_kernel(\n KStates,\n VStates,\n KCaches,\n VCaches,\n QStartLoc,\n QSeqLens,\n KVSeqLens,\n BlockOffsets,\n num_heads: tl.constexpr,\n head_dim: tl.constexpr,\n head_dim_v: tl.constexpr,\n stride_kss,\n stride_ksh,\n stride_ksd,\n stride_vss,\n stride_vsh,\n stride_vsd,\n stride_kcn: tl.constexpr,\n stride_kcb: tl.constexpr,\n stride_kch: tl.constexpr,\n stride_kcd: tl.constexpr,\n stride_vcn: tl.constexpr,\n stride_vcb: tl.constexpr,\n stride_vch: tl.constexpr,\n stride_vcd: tl.constexpr,\n stride_boff,\n BLOCK: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_H: tl.constexpr,\n):\n \"\"\"fill kv cache kernel.\"\"\"\n batch_id = tl.program_id(0)\n block_id = tl.program_id(1)\n\n # initialize\n h_off = tl.arange(0, BLOCK_H)\n d_off = tl.arange(0, BLOCK_D)\n\n q_startloc = tl.load(QStartLoc + batch_id)\n q_seqlen = tl.load(QSeqLens + batch_id)\n kv_seqlen = tl.load(KVSeqLens + batch_id)\n history_seqlen = kv_seqlen - q_seqlen\n\n block0_first_tokenloc = history_seqlen % BLOCK\n\n state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc,\n 0)\n kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id\n kv_block_id = min(kv_block_id, stride_boff - 1)\n block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)\n\n cur_startloc = q_startloc + state_token_offset\n ks_ptr = KStates + cur_startloc * stride_kss\n vs_ptr = VStates + cur_startloc * stride_vss\n\n kc_ptr = KCaches + block_off * stride_kcn\n vc_ptr = VCaches + block_off * stride_vcn\n\n c_first_tokenloc = block0_first_tokenloc\n if block_id != 0:\n c_first_tokenloc *= 0\n c_last_tokenloc = tl.minimum(\n BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)\n\n for bidx in range(c_first_tokenloc, c_last_tokenloc):\n sidx = bidx - c_first_tokenloc\n mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)\n k = tl.load(ks_ptr + sidx * stride_kss + h_off[:, None] * stride_ksh +\n d_off[None, :] * stride_ksd,\n mask=mask)\n tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch +\n d_off[None, :] * stride_kcd,\n k,\n mask=mask)\n\n if BLOCK_DV > 0:\n dv_off = tl.arange(0, BLOCK_DV)\n maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] <\n head_dim_v)\n v = tl.load(vs_ptr + sidx * stride_vss +\n h_off[:, None] * stride_vsh +\n dv_off[None, :] * stride_vsd,\n mask=maskv)\n tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +\n dv_off[None, :] * stride_vcd,\n v,\n mask=maskv)\n\n\n@triton.jit\ndef _fill_kv_cache_quant_kernel(\n KStates,\n VStates,\n KCaches,\n VCaches,\n KScalesZeros,\n VScalesZeros,\n QStartLoc,\n QSeqLens,\n KVSeqLens,\n BlockOffsets,\n num_heads: tl.constexpr,\n head_dim: tl.constexpr,\n head_dim_v: tl.constexpr,\n stride_kss,\n stride_ksh,\n stride_ksd,\n stride_vss,\n stride_vsh,\n stride_vsd,\n stride_kcn: tl.constexpr,\n stride_kcb: tl.constexpr,\n stride_kch: tl.constexpr,\n stride_kcd: tl.constexpr,\n stride_vcn: tl.constexpr,\n stride_vcb: tl.constexpr,\n stride_vch: tl.constexpr,\n stride_vcd: tl.constexpr,\n stride_kszn: tl.constexpr,\n stride_kszb: tl.constexpr,\n stride_kszh: tl.constexpr,\n stride_kszd: tl.constexpr,\n stride_vszn: tl.constexpr,\n stride_vszb: tl.constexpr,\n stride_vszh: tl.constexpr,\n stride_vszd: tl.constexpr,\n quant_policy: tl.constexpr,\n stride_boff,\n BLOCK: tl.constexpr,\n BLOCK_D: tl.constexpr,\n BLOCK_DV: tl.constexpr,\n BLOCK_H: tl.constexpr,\n):\n \"\"\"fill kv cache kernel with int4 and int8 quant fused.\n\n Args:\n stride_xss: stride of sequence length dim of key or value states\n stride_xsh: stride of head_num dim of key or value states\n stride_xsh: stride of head_size dim of key or value states\n stride_xn: stride of page num dim\n stride_xb: stride of block size dim\n stride_xh: stride of head_num dim\n stride_xd: stride of head_size dim\n \"\"\"\n batch_id = tl.program_id(0)\n block_id = tl.program_id(1)\n d_off = tl.arange(0, BLOCK_D)\n\n # initialize\n h_off = tl.arange(0, BLOCK_H)\n szd_off = tl.arange(0, 2)\n\n q_startloc = tl.load(QStartLoc + batch_id)\n q_seqlen = tl.load(QSeqLens + batch_id)\n kv_seqlen = tl.load(KVSeqLens + batch_id)\n history_seqlen = kv_seqlen - q_seqlen\n\n block0_first_tokenloc = history_seqlen % BLOCK\n\n state_token_offset = tl.maximum(block_id * BLOCK - block0_first_tokenloc,\n 0)\n kv_block_id = _div_up(history_seqlen + 1, BLOCK) - 1 + block_id\n kv_block_id = min(kv_block_id, stride_boff - 1)\n block_off = tl.load(BlockOffsets + batch_id * stride_boff + kv_block_id)\n\n cur_startloc = q_startloc + state_token_offset\n ks_ptr = KStates + cur_startloc * stride_kss\n vs_ptr = VStates + cur_startloc * stride_vss\n\n kc_ptr = KCaches + block_off * stride_kcn\n vc_ptr = VCaches + block_off * stride_vcn\n\n ksz_ptr = KScalesZeros + block_off * stride_kszn\n vsz_ptr = VScalesZeros + block_off * stride_vszn\n\n c_first_tokenloc = block0_first_tokenloc\n if block_id != 0:\n c_first_tokenloc *= 0\n c_last_tokenloc = tl.minimum(\n BLOCK, q_seqlen + block0_first_tokenloc - block_id * BLOCK)\n\n for bidx in range(c_first_tokenloc, c_last_tokenloc):\n sidx = bidx - c_first_tokenloc\n mask = (h_off[:, None] < num_heads) & (d_off[None, :] < head_dim)\n if quant_policy == 4:\n k1 = tl.load(ks_ptr + sidx * stride_kss +\n h_off[:, None] * stride_ksh +\n d_off[None, :] * stride_ksd,\n mask=mask)\n k2 = tl.load(ks_ptr + sidx * stride_kss +\n h_off[:, None] * stride_ksh +\n d_off[None, :] * stride_ksd + head_dim * stride_ksd,\n mask=mask)\n q_k, k_scales, k_zeros = _quant_int4(k1, k2)\n else:\n k = tl.load(ks_ptr + sidx * stride_kss +\n h_off[:, None] * stride_ksh +\n d_off[None, :] * stride_ksd,\n mask=mask)\n q_k, k_scales, k_zeros = _quant_int8(k)\n tl.store(kc_ptr + bidx * stride_kcb + h_off[:, None] * stride_kch +\n d_off[None, :] * stride_kcd,\n q_k,\n mask=mask)\n tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh +\n szd_off[None, :] * stride_kszd,\n k_scales[:, None],\n mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))\n tl.store(ksz_ptr + bidx * stride_kszb + h_off[:, None] * stride_kszh +\n szd_off[None, :] * stride_kszd,\n k_zeros[:, None],\n mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))\n\n if BLOCK_DV > 0:\n if quant_policy == 4:\n dv_off = tl.arange(0, BLOCK_DV //\n 2) # int4 pack, half the head_dim\n maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] <\n head_dim_v // 2)\n v1 = tl.load(vs_ptr + sidx * stride_vss +\n h_off[:, None] * stride_vsh +\n dv_off[None, :] * stride_vsd,\n mask=maskv)\n v2 = tl.load(vs_ptr + sidx * stride_vss +\n h_off[:, None] * stride_vsh +\n dv_off[None, :] * stride_vsd +\n head_dim_v // 2 * stride_vsd,\n mask=maskv)\n q_v, v_scales, v_zeros = _quant_int4(v1, v2)\n else:\n dv_off = tl.arange(0, BLOCK_DV)\n maskv = (h_off[:, None] < num_heads) & (dv_off[None, :] <\n head_dim_v)\n v = tl.load(vs_ptr + sidx * stride_vss +\n h_off[:, None] * stride_vsh +\n dv_off[None, :] * stride_vsd,\n mask=maskv)\n q_v, v_scales, v_zeros = _quant_int8(v)\n tl.store(vc_ptr + bidx * stride_vcb + h_off[:, None] * stride_vch +\n dv_off[None, :] * stride_vcd,\n q_v,\n mask=maskv)\n tl.store(\n vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh +\n szd_off[None, :] * stride_vszd,\n v_scales[:, None],\n mask=(h_off[:, None] < num_heads) & (szd_off[None, :] < 1))\n tl.store(\n vsz_ptr + bidx * stride_vszb + h_off[:, None] * stride_vszh +\n szd_off[None, :] * stride_vszd,\n v_zeros[:, None],\n mask=(h_off[:, None] < num_heads) & (szd_off[None, :] == 1))\n\n\ndef fill_kv_cache(k_states: torch.Tensor,\n v_states: torch.Tensor,\n k_caches: torch.Tensor,\n v_caches: torch.Tensor,\n q_start_loc: torch.Tensor,\n q_seq_length: torch.Tensor,\n kv_seq_length: torch.Tensor,\n max_q_seq_length: int,\n block_offsets: torch.Tensor,\n k_scales_zeros: torch.Tensor = None,\n v_scales_zeros: torch.Tensor = None,\n quant_policy: Literal[0, 4, 8] = 0):\n \"\"\"fill key/value state to cache for paged attention.\"\"\"\n\n block_offsets = block_offsets.contiguous()\n batch_size = block_offsets.size(0)\n block_size, num_heads, head_dim = k_caches.size()[1:]\n head_dim_v = v_states.size(-1)\n max_num_blocks = triton.cdiv(max_q_seq_length, block_size) + 1\n\n BLOCK = block_size\n BLOCK_H = triton.next_power_of_2(num_heads)\n BLOCK_D = triton.next_power_of_2(head_dim)\n BLOCK_DV = triton.next_power_of_2(head_dim_v)\n grid = [batch_size, max_num_blocks]\n kernel_meta = get_kernel_meta(k_states)\n if quant_policy == 0:\n _fill_kv_cache_kernel[grid](\n k_states,\n v_states,\n k_caches,\n v_caches,\n q_start_loc,\n q_seq_length,\n kv_seq_length,\n block_offsets,\n num_heads=num_heads,\n head_dim=head_dim,\n head_dim_v=head_dim_v,\n stride_kss=k_states.stride(-3),\n stride_ksh=k_states.stride(-2),\n stride_ksd=k_states.stride(-1),\n stride_vss=v_states.stride(-3),\n stride_vsh=v_states.stride(-2),\n stride_vsd=v_states.stride(-1),\n stride_kcn=k_caches.stride(0),\n stride_kcb=k_caches.stride(1),\n stride_kch=k_caches.stride(2),\n stride_kcd=k_caches.stride(3),\n stride_vcn=v_caches.stride(0),\n stride_vcb=v_caches.stride(1),\n stride_vch=v_caches.stride(2),\n stride_vcd=v_caches.stride(3),\n stride_boff=block_offsets.stride(0),\n BLOCK=BLOCK,\n BLOCK_D=BLOCK_D,\n BLOCK_DV=BLOCK_DV,\n BLOCK_H=BLOCK_H,\n num_warps=4,\n num_stages=3,\n **kernel_meta,\n )\n else:\n _fill_kv_cache_quant_kernel[grid](\n k_states,\n v_states,\n k_caches,\n v_caches,\n k_scales_zeros,\n v_scales_zeros,\n q_start_loc,\n q_seq_length,\n kv_seq_length,\n block_offsets,\n num_heads=num_heads,\n head_dim=head_dim,\n head_dim_v=head_dim_v,\n stride_kss=k_states.stride(-3),\n stride_ksh=k_states.stride(-2),\n stride_ksd=k_states.stride(-1),\n stride_vss=v_states.stride(-3),\n stride_vsh=v_states.stride(-2),\n stride_vsd=v_states.stride(-1),\n stride_kcn=k_caches.stride(0),\n stride_kcb=k_caches.stride(1),\n stride_kch=k_caches.stride(2),\n stride_kcd=k_caches.stride(3),\n stride_vcn=v_caches.stride(0),\n stride_vcb=v_caches.stride(1),\n stride_vch=v_caches.stride(2),\n stride_vcd=v_caches.stride(3),\n stride_kszn=k_scales_zeros.stride(0),\n stride_kszb=k_scales_zeros.stride(1),\n stride_kszh=k_scales_zeros.stride(2),\n stride_kszd=k_scales_zeros.stride(3),\n stride_vszn=v_scales_zeros.stride(0),\n stride_vszb=v_scales_zeros.stride(1),\n stride_vszh=v_scales_zeros.stride(2),\n stride_vszd=v_scales_zeros.stride(3),\n quant_policy=quant_policy,\n stride_boff=block_offsets.stride(0),\n BLOCK=BLOCK,\n BLOCK_D=BLOCK_D,\n BLOCK_DV=BLOCK_DV,\n BLOCK_H=BLOCK_H,\n num_warps=4,\n num_stages=3,\n **kernel_meta,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-7f7fd786515c46dd96ddfef768934f56", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel code defines a function, `relu_kernel`, which performs the Rectified Linear Unit (ReLU) activation function on an input tensor. The function is compiled using the `triton.jit` decorator, allowing it to run on a GPU. The kernel processes the input tensor in parallel using multiple threads, each identified by `pid`, the program ID. The block of data each thread processes is determined by `block_start` and `offsets`, where `block_size` specifies the number of elements each block handles. The ReLU operation is executed by the `tl.where` function, setting negative values to zero and retaining positive values. The result is conditionally stored back to global memory via `tl.store`, only writing elements within the valid range of the input size `N`. The `relu` function, which serves as a wrapper, prepares the output tensor, calculates the required number of blocks, and launches the Triton kernel on a grid defined by these blocks. It ultimately returns the ReLU-activated tensor.\n ", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef relu_kernel(x_ptr, out_ptr, N: tl.constexpr, block_size: tl.constexpr):\n # Get the index of the current thread\n pid = tl.program_id(0)\n block_start = pid * block_size\n offsets = block_start + tl.arange(0, block_size)\n mask = offsets < N\n\n # Load elements from global memory\n x = tl.load(x_ptr + offsets, mask=mask)\n\n # Compute linear layer\n result = tl.where(x >= 0, x, 0.0)\n\n # Write result to global memory\n if pid == 0:\n tl.store(out_ptr + offsets, result, mask=mask)\n\ndef relu(x):\n # Prepare output tensor\n out = torch.empty_like(x, dtype=torch.float32, device=x.device)\n N = out.numel()\n\n BLOCK_SIZE = 1024\n num_blocks = (N + BLOCK_SIZE - 1) // BLOCK_SIZE # Calculate the number of blocks needed\n\n # Launch Triton kernel\n grid = lambda meta: (triton.cdiv(N, BLOCK_SIZE), )\n\n relu_kernel[grid](x, out, N, BLOCK_SIZE)\n\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-3e087da82ec147e795c923c2d95c4055", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code comprises multiple GPU-accelerated kernels using Triton, facilitating efficient max value computation in tensors.\n - `max_kernel_1`: This Triton JIT-compiled kernel computes the maximum values within a large 1D input tensor across predefined blocks (BLOCK_SIZE). The program index (pid) determines the starting offset for each block, and values are conditionally loaded using a mask to handle out-of-bound reads. The maximum value in each block is calculated using `tl.max()` and stored in the `mid` intermediate tensor at the corresponding position.\n - `max_kernel_2`: A second kernel that consolidates results from `max_kernel_1`. It calculates the overall maximum from the intermediate results stored in `mid` and writes the final maximum value to the `out` tensor.\n - `max_kernel`: An advanced kernel for multi-dimensional max computation, which calculates the maximum values along a specified dimension. It employs a dual-program index strategy: `pid_m` for blocks along the selected dimension and `pid_k` for other dimensions. The `tl.load` function pulls data with masking to ensure correct indexing, and `tl.max` derives the maximum value and corresponding index across dimensions.\n - `max`: A wrapper function that executes the sequential execution of `max_kernel_1` and `max_kernel_2`. It prepares necessary intermediate tensors (`mid` and `out`) and determines optimal block sizes (BLOCK_SIZE) through Triton's utility functions.\n - `max_dim`: Extends max computation to operate on any dimension specified by `dim`, which is checked for validity against the tensor's rank. It reshapes outputs accordingly and executes `max_kernel` using calculated grid dimensions, determined by tensor dimensions M (product of all sizes before `dim`) and K (sizes after `dim`).\n Each function is meticulously annotated for Triton's JIT compilation, facilitating optimal resource allocation for various GPU architectures. This design aims at reducing computational time by maximizing parallel execution potential of Triton kernels.\n ", + "output": [ + "import logging\nimport math\nfrom collections import namedtuple\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef max_kernel_1(\n inp,\n mid,\n M,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n max_val = tl.max(inp_val)\n mid_ptr = mid + pid\n tl.store(mid_ptr, max_val)\n\n\n@triton.jit\ndef max_kernel_2(mid, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=-float(\"inf\"))\n max_val = tl.max(mid_val)\n tl.store(out, max_val)\n\n\ndef heur_block_n(args):\n return triton.next_power_of_2(args[\"N\"])\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 8}, num_warps=8),\n triton.Config({\"BLOCK_M\": 16}, num_warps=8),\n triton.Config({\"BLOCK_M\": 32}, num_warps=8),\n ],\n key=[\n \"M\",\n \"N\",\n ],\n)\n@triton.heuristics(\n {\n \"BLOCK_N\": heur_block_n,\n }\n)\n@triton.jit\ndef max_kernel(\n inp,\n out_value,\n out_index,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n # set offset\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n offset_index = m_offset * K + pid_k\n # set mask\n mask1 = m_offset < M\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n result_value, result_index = tl.max(inp_vals, axis=1, return_indices=True)\n\n out_value_ptrs = out_value + offset_index\n out_index_ptrs = out_index + offset_index\n\n tl.store(out_value_ptrs, result_value, mask=mask1)\n tl.store(out_index_ptrs, result_index, mask=mask1)\n\n\ndef max(inp):\n logging.debug(\"GEMS MAX\")\n M = inp.numel()\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n\n dtype = inp.dtype\n mid = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n out = torch.empty([], dtype=dtype, device=inp.device)\n\n with torch.cuda.device(inp.device):\n max_kernel_1[(mid_size, 1, 1)](inp, mid, M, block_size)\n max_kernel_2[(1, 1, 1)](mid, out, mid_size, block_mid)\n return out\n\n\ndef max_dim(inp, dim=None, keepdim=False):\n logging.debug(\"GEMS MAX DIM\")\n assert dim >= -inp.ndim and dim < inp.ndim, \"Invalid dim\"\n shape = inp.shape\n dim = dim % inp.ndim\n N = shape[dim]\n M = math.prod(shape[:dim])\n K = inp.numel() // M // N\n\n inp = inp.contiguous()\n\n shape_list = list(shape)\n shape_list[dim] = 1\n out_value = torch.empty(shape_list, dtype=inp.dtype, device=inp.device)\n out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)\n\n if not keepdim:\n out_value = torch.squeeze(out_value, dim)\n out_index = torch.squeeze(out_index, dim)\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n max_kernel[grid](inp, out_value, out_index, M, N, K)\n Max_out = namedtuple(\"max\", [\"values\", \"indices\"])\n out = Max_out(values=out_value, indices=out_index)\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e9bcc9c3212e4907891ec12108663723", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton-based operator is a matrix multiplication kernel designed to optimize computational efficiency using techniques like memory hierarchy optimization and tile swizzling. It consists of several key functions:\n\n - `swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)`: This function calculates 2D tile coordinates from a given linear tile ID using a swizzling pattern. It improves L2 cache performance by changing the order in which tiles are accessed.\n - `linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)`: Converts a linear tile ID into 2D tile coordinates without reordering.\n - `mac_loop(A, B, C, M, N, K, locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, iters_per_tile, start_iter, end_iter, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M)`: Computes a portion of the matrix multiplication for the given range of iterations (start_iter to end_iter). It accumulates results in a local accumulator and handles synchronization using locks.\n - `first_wave(A, B, C, M, N, K, locks, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M)`: Manages the first set of work-items executed on the hardware, handling a batch of tiles efficiently by leveraging Stream-K techniques.\n - `full_tiles(A, B, C, M, N, K, stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn, total_tiles_streamk, BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE, GROUP_M)`: Computes tiles left after the initial \"first wave,\" managing the remaining work through classical blocking.\n\n The `matmul` class orchestrates the execution, setting up parameters and calling the necessary functions. The `_call` method manages grid setup, memory allocation, and kernel execution, while `forward` exposes the operation as a PyTorch-compatible function. Key input parameters include the dimensions of matrices A and B (M, N, K), block sizes (BLK_M, BLK_N, BLK_K), and parallelization configuration (number of stages/warps). Output is the product matrix C, calculated from the input matrices A and B.\n ", + "output": [ + "import torch\nimport triton\nfrom triton import language as tl\n\n@triton.jit()\ndef swizzle_tile(tile_id,\n M, N, K,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr\n ):\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = tile_id // width\n group_size = tl.minimum(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (tile_id % group_size)\n pid_n = (tile_id % width) // group_size\n return pid_m, pid_n\n\n\n@triton.jit()\ndef linear_tile(tile_id,\n M, N, K,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr\n ):\n pid_m = tile_id // tl.cdiv(N, BLOCK_N)\n pid_n = tile_id % tl.cdiv(N, BLOCK_N)\n return pid_m, pid_n\n\n\n@triton.jit()\ndef mac_loop(A, B, C,\n M, N, K,\n locks,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n iters_per_tile,\n start_iter, end_iter,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n ACC_TYPE: tl.constexpr, GROUP_M: tl.constexpr):\n\n # where are we in the grid\n tile_id = start_iter // iters_per_tile\n if GROUP_M > 0:\n pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)\n else:\n pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak) + BLOCK_K * stride_ak * (start_iter % iters_per_tile)\n B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn) + BLOCK_K * stride_bk * (start_iter % iters_per_tile)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n\n for current_iter in range(start_iter, end_iter):\n a = tl.load(A)\n b = tl.load(B)\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n if end_iter % iters_per_tile == 0: # last iteration of the tile always happens before its start on another SM\n C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) # compute inside the if/else to avoid spilling!\n tl.store(C_, acc)\n if start_iter % iters_per_tile != 0: # only if tile has been partially processed\n tl.atomic_xchg(locks + tile_id, 1)\n else:\n while tl.atomic_cas(locks + tile_id, 1, 1) != 1:\n pass\n C_ = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn) # compute inside the if/else to avoid spilling!\n tl.atomic_add(C_, acc)\n\n\n@triton.jit()\ndef first_wave(\n A, B, C,\n M, N, K,\n locks,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n total_full_tiles_streamk, total_partial_tiles_streamk, iters_per_tile,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,\n GROUP_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n start_iter = pid * total_full_tiles_streamk + tl.minimum(pid, total_partial_tiles_streamk)\n last_iter = (pid + 1) * total_full_tiles_streamk + tl.minimum(pid + 1, total_partial_tiles_streamk)\n\n while start_iter < last_iter:\n end_iter = tl.minimum(start_iter + (iters_per_tile - start_iter % iters_per_tile), last_iter)\n mac_loop(A, B, C,\n M, N, K,\n locks,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n iters_per_tile,\n start_iter, end_iter,\n BLOCK_M, BLOCK_N, BLOCK_K, ACC_TYPE,\n GROUP_M,\n )\n\n start_iter = end_iter\n\n\n@triton.jit()\ndef full_tiles(\n A, B, C,\n M, N, K,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n total_tiles_streamk,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr, ACC_TYPE: tl.constexpr,\n GROUP_M: tl.constexpr,\n):\n # first wave has done more tiles than there are SMs, we adjust pid\n tile_id = tl.program_id(0) + total_tiles_streamk\n if GROUP_M > 0:\n pid_m, pid_n = swizzle_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)\n else:\n pid_m, pid_n = linear_tile(tile_id, M, N, K, BLOCK_M, BLOCK_N, BLOCK_K, GROUP_M)\n\n # do matrix multiplication\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n # pointers\n A = A + (rm[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rn[None, :] * stride_bn)\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(0, tl.cdiv(K, BLOCK_K)):\n a = tl.load(A)\n b = tl.load(B)\n acc += tl.dot(a, b)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n acc = acc.to(tl.float16) # restore C.dtype.element_ty\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n tl.store(C, acc)\n\n\nclass matmul(torch.autograd.Function):\n\n _debug = False\n\n @staticmethod\n def set_debug(debug: bool):\n matmul._debug = debug\n\n @staticmethod\n def _call(a: torch.Tensor, b: torch.Tensor, total_programs_streamk: int, BLK_M: int, BLK_N: int, BLK_K: int, two_tiles: bool, num_stages: int, num_warps: int):\n device = a.device\n\n assert a.is_contiguous() and b.is_contiguous(), \"non-contiguous inputs are not supported\"\n # checks constraints\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n # accumulator types\n ACC_TYPE = tl.float32 if a.dtype in [torch.float16, torch.bfloat16, torch.float32] else tl.int32\n # compute grid (work to do per SM on the first wave)\n total_blocks_M = triton.cdiv(M, BLK_M)\n total_blocks_N = triton.cdiv(N, BLK_N)\n iters_per_tile = triton.cdiv(K, BLK_K)\n GROUP_M = 8 # 0 to disable swizzling\n total_tiles = total_blocks_M * total_blocks_N\n\n if total_programs_streamk > 0: # Stream-K\n # last wave may occupy less than total_programs_streamk SMs\n total_tiles_streamk = total_tiles % total_programs_streamk\n # for two-tile Stream-K + data-parallel from original paper\n if two_tiles and total_tiles - total_tiles_streamk > total_programs_streamk:\n total_tiles_streamk += total_programs_streamk\n # remaining tiles are computed using classical blocking\n total_blocking_tiles = total_tiles - total_tiles_streamk\n total_iters_streamk = total_tiles_streamk * iters_per_tile\n # iterations related to full waves\n total_full_tiles_streamk = total_iters_streamk // total_programs_streamk\n # iterations related to last (partial) wave\n total_partial_tiles_streamk = total_iters_streamk % total_programs_streamk\n\n else: # all tiles are computed using classical blocking\n total_blocking_tiles = total_tiles\n total_tiles_streamk = 0\n total_full_tiles_streamk = 0\n total_partial_tiles_streamk = 0\n total_iters_streamk = 0\n\n if matmul._debug:\n print(f\"M,N,K={M},{N},{K} ; BLK_M,N,K={BLK_M},{BLK_N},{BLK_K}\")\n print(f\"{total_blocks_M=} x {total_blocks_N=} = {total_tiles=}\")\n print(f\"{total_tiles_streamk=} + {total_blocking_tiles=} = {total_tiles=}\")\n print(f\"{total_programs_streamk=}\")\n print(f\"{total_blocking_tiles=}\")\n print(f\"{iters_per_tile=}\")\n print(f\"{total_iters_streamk=}\")\n\n # allocates output\n c = torch.empty((M, N), device=device, dtype=a.dtype)\n # allocates locks to sync work accross SMs\n locks = torch.zeros((total_tiles_streamk,), device=device, dtype=torch.int32)\n k1 = first_wave[(total_programs_streamk,)](\n a,\n b,\n c,\n M,\n N,\n K,\n locks,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n total_full_tiles_streamk=total_full_tiles_streamk,\n total_partial_tiles_streamk=total_partial_tiles_streamk,\n iters_per_tile=iters_per_tile,\n BLOCK_M=BLK_M,\n BLOCK_N=BLK_N,\n BLOCK_K=BLK_K,\n ACC_TYPE=ACC_TYPE,\n GROUP_M=GROUP_M,\n num_stages=num_stages,\n num_warps=num_warps,\n )\n if matmul._debug:\n print(f\"{k1.n_regs} registers used, {k1.n_spills} spills\")\n k2 = full_tiles[(total_blocking_tiles,)](\n a,\n b,\n c,\n M,\n N,\n K,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n total_tiles_streamk=total_tiles_streamk,\n BLOCK_M=BLK_M,\n BLOCK_N=BLK_N,\n BLOCK_K=BLK_K,\n ACC_TYPE=ACC_TYPE,\n GROUP_M=GROUP_M,\n num_stages=num_stages,\n num_warps=num_warps,\n )\n if matmul._debug:\n print(f\"{k2.n_regs} registers used, {k2.n_spills} spills\")\n return c\n\n @staticmethod\n def forward(ctx, a: torch.Tensor, b: torch.Tensor, grid: int, BLK_M=128, BLK_N=128, BLK_K=32, two_tiles=True, num_stages=3, num_warps=4):\n return matmul._call(a=a, b=b, total_programs_streamk=grid, BLK_M=BLK_M, BLK_N=BLK_N, BLK_K=BLK_K, two_tiles=two_tiles, num_warps=num_warps, num_stages=num_stages)\n\n\n\n" + ] + }, + { + "id": "openseek-8-b69c0558ac2c4c4abad66363f5796c67", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton kernel module is focused on enabling efficient recurrent computation with chunked gate processing. The module consists of two main kernels, `_fwd_recurrence` for forward propagation and `_bwd_recurrence` for backward propagation, and a high-level interface class, `ChunkGateRecurrent`.\n\n `_fwd_recurrence`:\n - Purpose: Executes forward pass with recurrent updates across multiple blocks of input data.\n - Parameters:\n - `S` and `d`: Input data and decay factors for each block.\n - `O`: Output tensor for storing results.\n - `NUM_HEAD`, `NUM_BLOCK`: Define the number of heads and blocks.\n - `D_MODEL_K`, `D_MODEL_V`: Overall dimensions of key and value models.\n - `BLOCK_MODEL_K`, `BLOCK_MODEL_V`: Dimensions of each block.\n - `last_kv`: Tensor storing the previous iteration's key/value state, if applicable.\n - Execution: Initializes accumulators, optionally using `last_kv`, then iterates over blocks to apply recurrent transformation and store results in `O`.\n\n `_bwd_recurrence`:\n - Purpose: Computes gradients by reversing the operations performed in the forward pass.\n - Parameters:\n - `S`, `d`: Similar roles as in the forward pass.\n - `DI`, `DG`, `DL`, `DS`: Tensors for storing computed gradients.\n - `NUM_HEAD`, `NUM_BLOCK`, `D_MODEL_K`, `D_MODEL_V`, `BLOCK_MODEL_K`, `BLOCK_MODEL_V`: Dimensional parameters matching those in the forward pass.\n - Execution: Iteratively computes gradients by reversing block operations, accumulating results into gradient tensors `DI`, `DG`, and `DL`.\n\n `ChunkGateRecurrent`:\n - A custom autograd function that manages data layout and execution of forward and backward kernels on GPUs.\n - `forward(ctx, kv, cross_decay, last_kv=None)`: Prepares and launches the forward kernel with given inputs. Saves output for backward pass.\n - `backward(ctx, DO)`: Prepares and launches the backward kernel, calculates gradients using previously saved outputs.\n - Uses blocks of size `BLOCK_MODEL_K` and `BLOCK_MODEL_V` to divide work efficiently across the GPU, ensuring proper tiling and parallelism.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\ntorch.backends.cudnn.allow_tf32 = True\n\n@triton.jit\ndef _fwd_recurrence(\n S, d, \n O,\n NUM_HEAD, NUM_BLOCK, \n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr,\n last_kv: tl.tensor # 不再使用 Optional\n):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]\n O = O + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]\n\n if last_kv is not None:\n last_kv = last_kv + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]\n acc = tl.load(last_kv).to(tl.float32)\n else:\n acc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32)\n\n tl.store(O, acc.to(O.dtype.element_ty))\n O += D_MODEL_K * D_MODEL_V\n d = d + offset_bh * NUM_BLOCK\n for i in range(NUM_BLOCK-1):\n d_i = tl.load(d)\n S_i = tl.load(S) \n acc = acc * d_i + S_i\n tl.store(O, acc.to(O.dtype.element_ty))\n d += 1\n S += D_MODEL_K * D_MODEL_V\n O += D_MODEL_K * D_MODEL_V\n \n\n## NUM_SPLIT_K/V. K/V dimension split into NUM_SPLIT_K/V parts with equal size BLOCK_MODEL\n@triton.jit\ndef _bwd_recurrence(\n S, d, \n DI, DG, DL, DS, \n NUM_HEAD, NUM_BLOCK,\n D_MODEL_K: tl.constexpr, D_MODEL_V: tl.constexpr,\n BLOCK_MODEL_K: tl.constexpr, BLOCK_MODEL_V: tl.constexpr,\n \n ):\n offset_bh = tl.program_id(0)\n offset_d = tl.program_id(1)\n offset_s = tl.program_id(2) \n\n # offset_h = offset_bh % NUM_HEAD\n NUM_K = D_MODEL_K // BLOCK_MODEL_K\n NUM_V = D_MODEL_V // BLOCK_MODEL_V\n # skip the last chunk because it is never used\n S = S + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n DI = DI + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 2) * D_MODEL_K * D_MODEL_V\n\n # start from the last chunk \n DS = DS + offset_bh * NUM_BLOCK * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :] + (NUM_BLOCK - 1) * D_MODEL_K * D_MODEL_V\n\n DG = DG + offset_bh * NUM_BLOCK * NUM_K * NUM_V + offset_d * NUM_V + offset_s + (NUM_BLOCK - 2) * NUM_K * NUM_V\n\n d = d + offset_bh * NUM_BLOCK + (NUM_BLOCK - 1)\n\n Dacc = tl.zeros([BLOCK_MODEL_K, BLOCK_MODEL_V], dtype=tl.float32) \n\n # ignore the first chunk\n for i in range(NUM_BLOCK - 1):\n S_i = tl.load(S)\n DS_i = tl.load(DS)\n d_i = tl.load(d)\n Dacc = Dacc * d_i + DS_i\n DG_i = tl.sum(Dacc * S_i.to(tl.float32))\n\n tl.store(DG, DG_i.to(DG.dtype.element_ty))\n tl.store(DI, Dacc.to(DI.dtype.element_ty)) \n\n S -= D_MODEL_K * D_MODEL_V\n DI -= D_MODEL_K * D_MODEL_V \n DS -= D_MODEL_K * D_MODEL_V\n DG -= NUM_K * NUM_V\n d -= 1\n \n DL = DL + offset_bh * D_MODEL_K * D_MODEL_V + offset_d * D_MODEL_V * BLOCK_MODEL_K + tl.arange(0, BLOCK_MODEL_K)[:, None] * D_MODEL_V + offset_s * BLOCK_MODEL_V + tl.arange(0, BLOCK_MODEL_V)[None, :]\n DS_i = tl.load(DS)\n d_i = tl.load(d)\n Dacc = Dacc * d_i + DS_i\n tl.store(DL, Dacc.to(DL.dtype.element_ty)) \n\nclass ChunkGateRecurrent(torch.autograd.Function):\n @staticmethod\n def forward(ctx, kv, cross_decay, last_kv=None):\n cross_decay = cross_decay.contiguous()\n kv = kv.contiguous()\n\n B, H, N, D_k, D_v = kv.shape \n output = torch.empty_like(kv) \n BLOCK_MODEL_K = 64\n BLOCK_MODEL_V = 16\n \n assert D_k % BLOCK_MODEL_K == 0\n assert D_v % BLOCK_MODEL_V == 0\n\n grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V)\n ctx.grid = grid\n ctx.have_last_kv = last_kv is not None\n ctx.BLOCK_MODEL_K = BLOCK_MODEL_K\n ctx.BLOCK_MODEL_V = BLOCK_MODEL_V\n\n _fwd_recurrence[grid](\n kv,\n cross_decay,\n output,\n D_MODEL_K=D_k, D_MODEL_V=D_v,\n NUM_BLOCK=N, NUM_HEAD=H,\n BLOCK_MODEL_K=BLOCK_MODEL_K,\n BLOCK_MODEL_V=BLOCK_MODEL_V,\n last_kv=last_kv\n )\n\n ctx.save_for_backward(output, cross_decay) \n return output\n\n @staticmethod\n def backward(ctx, DO):\n DO = DO.contiguous()\n\n output, cross_decay = ctx.saved_tensors \n\n B, H, N, D_k, D_v = output.shape \n \n BLOCK_MODEL_K = 64\n BLOCK_MODEL_V = 16\n\n grid = (B*H, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V)\n\n DI = torch.empty_like(DO)\n DG = torch.empty(B*H, N, D_k//BLOCK_MODEL_K, D_v//BLOCK_MODEL_V, device=cross_decay.device, dtype=cross_decay.dtype)\n DL = torch.empty(B, H, D_k, D_v, device=output.device, dtype=output.dtype)\n _bwd_recurrence[grid](\n output, cross_decay,\n DI, DG, DL, DO, \n NUM_HEAD=H, NUM_BLOCK = N, \n D_MODEL_K = D_k,\n D_MODEL_V = D_v, \n BLOCK_MODEL_K=BLOCK_MODEL_K,\n BLOCK_MODEL_V=BLOCK_MODEL_V,\n )\n\n DI[:, :, -1] = 0\n DG[:, -1] = 0\n DG = DG.view(B, H, N, -1).sum(dim=-1)\n return DI, DG, DL if ctx.have_last_kv else None\n\nchunk_gate_recurrent = ChunkGateRecurrent.apply\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-6308ec25657a4fe297a076985f597d47", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel code defines two main functions: `chunk_gated_abc_fwd_kernel_cum` and `chunk_gated_abc_fwd_kernel_h`, \n both leveraging parallel computing capabilities to perform operations on tensor data.\n\n `chunk_gated_abc_fwd_kernel_cum` takes an input tensor `s` and computes a cumulative result stored in `o`. \n It is parameterized by tensor strides and block sizes such as `T` (total number of rows), `S` (total number of columns), \n `BT` (block row size), and `BS` (block column size). The kernel constructs a mask `m_s` to apply cumulative operations \n only on relevant data points, ensuring efficient computation through block pointer manipulation and boundary checking.\n\n `chunk_gated_abc_fwd_kernel_h` is tailored for applying a gated cumulative sum on input tensors: `k` (keys), \n `v` (values), and `g` (gating factors). It considers initial and final state tensors `h0` and `ht` if needed. \n The kernel processes data in chunks defined by `BT` (block size for T), `BK` (block size for K), and `BV` (block size for V), \n updating the output tensor `h` via iterative transformations of `b_h` based on `b_k` and `b_v`, influenced by gating variables `b_g`.\n\n The `fwd_pre` function serves as a preparatory step, setting up the execution grid for `chunk_gated_abc_fwd_kernel_cum`. \n It reshapes and processes the input tensor `g`, effectively mimicking a cumulative sum operation.\n\n `fwd_inner` configures and invokes `chunk_gated_abc_fwd_kernel_h`. It defines the computational grid and parameters \n for efficient execution, with capabilities for adjusting based on warp sizes and stages. This function realizes \n a gated accumulation logic often used in attention mechanisms, adapting with or without state continuity.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BS': 16}, num_warps=2),\n triton.Config({'BS': 16}, num_warps=4),\n triton.Config({'BS': 16}, num_warps=8),\n triton.Config({'BS': 32}, num_warps=2),\n triton.Config({'BS': 32}, num_warps=4),\n triton.Config({'BS': 32}, num_warps=8),\n triton.Config({'BS': 64}, num_warps=2),\n triton.Config({'BS': 64}, num_warps=4),\n triton.Config({'BS': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_cum(\n s,\n o,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr,\n):\n i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.).to(tl.float32)\n\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_o = tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_gated_abc_fwd_kernel_h(\n k,\n v,\n g,\n h,\n h0,\n ht,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n s_h_h,\n s_h_t,\n s_h_d,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n GATEK: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h += tl.load(p_h, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, s_h_d), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if GATEK:\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_gn = tl.make_block_ptr(g + i_bh * s_k_h, (T * K,), (s_k_d,), ((i_t * BT + BT - 1) * K + i_k * BK,), (BK,), (0,))\n # [BK,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[:, None]\n # [BK, BT]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_k = (b_k * tl.exp(b_gn[:, None] - b_g)).to(b_k.dtype)\n else:\n p_g = tl.make_block_ptr(g + i_bh * s_v_h, (T, V), (s_v_t, s_v_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_gn = tl.make_block_ptr(g + i_bh * s_v_h, (T * V,), (s_v_d,), ((i_t * BT + BT - 1) * V + i_v * BV,), (BV,), (0,))\n # [BV,]\n b_gn = tl.load(p_gn, boundary_check=(0,))\n # [BK, BV]\n b_h *= tl.exp(b_gn)[None, :]\n # [BT, BV]\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_v = (b_v * tl.exp(b_gn[None, :] - b_g)).to(b_v.dtype)\n # [BK, BV]\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_h = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef fwd_pre(g, B, H, T, S, BT):\n NT = triton.cdiv(T, BT)\n g_org, g = g, torch.empty_like(g, dtype=torch.float)\n def grid(meta): return (triton.cdiv(meta['S'], meta['BS']), NT, B * H)\n # keep cummulative normalizer in fp32\n # this kernel is equivalent to\n # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1)\n chunk_gated_abc_fwd_kernel_cum[grid](\n g_org, g,\n g.stride(1), g.stride(2), g.stride(3),\n T=T, S=S, BT=BT\n )\n return g\n\n\ndef fwd_inner(q, k, v, g, B, H, T, K, V, BT, BK, BV, gatek=False, h0=None, ht=None):\n NT = triton.cdiv(T, BT)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_warps = 4 if BK == 64 else 2\n num_stages = 1\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NV, NK, B * H)\n chunk_gated_abc_fwd_kernel_h[grid](\n k, v, g, h, h0, ht,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2), h.stride(3),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n GATEK=gatek,\n USE_INITIAL_STATE=h0 is not None,\n STORE_FINAL_STATE=ht is not None,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return h\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-8588debc4ee2421092d54c7e78566f17", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code defines custom forward and backward operations for the SWiGLU (SwiGLU: Swish-Gated Linear Units) function using Triton kernels. The main functions are `swiglu_forward` and `swiglu_backward`. \n `swiglu_forward` takes input tensors `a` and `b`, reshapes them, and prepares an output tensor `c`. It computes the result using the `_swiglu_forward_kernel`, which calculates the element-wise product of `b` and the SiLU-activated `a`. The kernel uses `tl.load` to load data in blocks, applies the SiLU function defined by `silu`, and stores the result using `tl.store`. The kernel launches with a configurable grid size defined by the number of rows in the reshaped input tensor and a calculated block size, `BLOCK_SIZE`, which is a power of two based on the number of columns and limited by `MAX_FUSED_SIZE`.\n `swiglu_backward` computes gradients for `a` and `b` using the `_swiglu_backward_kernel`, which utilizes recomputation to save memory. It recalculates the SiLU activation and uses it to compute the gradients with respect to the input tensors. The backward kernel also uses `tl.load` and `tl.store` for handling data and follows a similar block configuration as the forward kernel.\n The helper function `calculate_settings` assists in determining the optimal `BLOCK_SIZE` and `num_warps` based on the input size, with adjustments depending on the hardware being used (HIP or otherwise). The SiLU activation is implemented as a Triton JIT-compiled function, `silu`, using the Triton `tl.sigmoid` operation to define the SiLU function.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nimport functools\n\ndef is_hip() -> bool:\n return torch.version.hip is not None\n\ndef calculate_settings(n):\n MAX_FUSED_SIZE = 65536\n BLOCK_SIZE = triton.next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(\n f\"Cannot launch Triton kernel since n = {n} exceeds \"\n f\"the recommended Triton blocksize = {MAX_FUSED_SIZE}.\"\n )\n\n num_warps = 4\n if BLOCK_SIZE >= 32768:\n num_warps = 32 if not is_hip() else 16\n elif BLOCK_SIZE >= 8192:\n num_warps = 16\n elif BLOCK_SIZE >= 2048:\n num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef silu(x):\n return x * tl.sigmoid(x)\n\n@triton.jit\ndef _swiglu_forward_kernel(\n a_ptr, b_ptr, c_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr\n):\n program_id = tl.program_id(0).to(tl.int64)\n\n # locate start index\n a_ptr += program_id * stride\n b_ptr += program_id * stride\n c_ptr += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n # sigmoid requires type float32\n a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)\n c_row = silu(a_row) * b_row\n tl.store(c_ptr + col_offsets, c_row, mask=mask)\n\n@triton.jit\ndef _swiglu_backward_kernel(\n dc_ptr, a_ptr, b_ptr, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr\n):\n program_id = tl.program_id(0).to(tl.int64)\n\n # locate start index\n dc_ptr += program_id * stride\n a_ptr += program_id * stride\n b_ptr += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dc_row = tl.load(dc_ptr + col_offsets, mask=mask, other=0)\n # sigmoid requires type float32\n a_row = tl.load(a_ptr + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b_ptr + col_offsets, mask=mask, other=0)\n\n # recomputation to save memory\n sig_a = tl.sigmoid(a_row)\n silu_a = a_row * sig_a\n db_row = dc_row * silu_a\n da_row = dc_row * (silu_a * (1 - sig_a) + sig_a) * b_row\n\n tl.store(a_ptr + col_offsets, da_row, mask=mask)\n tl.store(b_ptr + col_offsets, db_row, mask=mask)\n\ndef swiglu_forward(a, b):\n ori_shape = a.shape\n\n n_cols = ori_shape[-1]\n a = a.view(-1, n_cols)\n b = b.view(-1, n_cols)\n c = torch.empty_like(a)\n n_rows = a.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _swiglu_forward_kernel[(n_rows,)](\n a,\n b,\n c,\n c.stride(-2),\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a, b, c.view(*ori_shape)\n\ndef swiglu_backward(a, b, dc):\n ori_shape = dc.shape\n n_cols = ori_shape[-1]\n dc = dc.view(-1, n_cols)\n n_rows = dc.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _swiglu_backward_kernel[(n_rows,)](\n dc,\n a,\n b,\n dc.stride(-2),\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a.view(*ori_shape), b.view(*ori_shape)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-d598c479928d4d67bccb95afd6375be0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe code involves a custom addition operation utilizing the Triton programming model to perform element-wise addition on two input PyTorch tensors, `a` and `b`. The key components are the Triton kernel `_add_kernel` and its wrapper function `custom_add`.\n\nFunction `_add_kernel(A, B, C, size, BLOCK)`:\n- This Triton kernel function, decorated with `@triton.jit`, signifies its use of just-in-time compilation.\n- Parameters:\n - `A`, `B`: Pointers to input tensors containing the data to be added.\n - `C`: Pointer to the output tensor where the result will be stored.\n - `size`: Scalar specifying the number of elements in the tensors.\n - `BLOCK`: A compile-time constant defining the number of elements each program instance processes concurrently.\n- The kernel computes the addition of `A` and `B` and stores the result in `C`.\n- Program index for parallel execution is fetched via `tl.program_id(0)`.\n- `offs` calculates the offsets for block-wise data processing, determined by `prog_id * BLOCK + tl.arange(0, BLOCK)`.\n- Loads data from `A` and `B` using `tl.load` with masks to prevent out-of-bound access using `offs < size`.\n- The resulting sum is stored in `C` with `tl.store` under the same mask condition to handle edge cases.\n\nFunction `custom_add(a, b)`:\n- Acts as a wrapper to invoke `_add_kernel` and handle tensor operations within the PyTorch framework.\n- Initializes `c`, an empty tensor with the same shape as `a`, to store the result of the addition.\n- Computes `size` of input tensors which dictates the operation range.\n- Sets `BLOCK` size to 16, defining the granularity of each program instance's data processing.\n- Calculates `grid` using `triton.cdiv(size, BLOCK)` to determine the number of program instances required, ensuring complete coverage of the tensor.\n- Launches `_add_kernel` with calculated `grid`, and provides tensor pointers and necessary arguments for execution.\n- The resultant tensor `c` is returned after kernel execution.\n\nThis implementation splits the workload into blocks, processed in parallel, handling edge cases with masks. The consistent BLOCK size ensures efficient memory access patterns.\n", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _add_kernel(A, B, C, size, BLOCK: tl.constexpr):\n \"\"\"add kernel.\"\"\"\n prog_id = tl.program_id(0)\n offs = prog_id * BLOCK + tl.arange(0, BLOCK)\n a = tl.load(A + offs, mask=offs < size)\n b = tl.load(B + offs, mask=offs < size)\n tl.store(C + offs, a + b, mask=offs < size)\n\ndef custom_add(a, b):\n \"\"\"custom add one.\"\"\"\n c = torch.empty_like(a)\n size = c.size(0)\n BLOCK = 16\n\n grid = (triton.cdiv(size, BLOCK), )\n _add_kernel[grid](a, b, c, size, BLOCK=BLOCK)\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0ed60e91a0454a6cbdcaaca95590ed99", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton function `conv2d_forward_kernel` executes a 2D convolution, computing an output feature map by sliding a filter across the input data.\n The kernel uses parallel processing to divide the workload over multiple threads, facilitated by Triton's block-based computation.\n The kernel requires pointers to input, weight, and output arrays, along with dimensions and strides for these arrays. \n Additional parameters include the kernel size, stride, padding, groups, and flags for computation precision like FP16 and TF32.\n The block size constants for batch, input features, and output features control how the computation is partitioned among threads.\n The `conv2d_forward` function wraps this kernel, preparing the input and output tensors and calculating dimensions and strides.\n The function computes the output height and width using the convolution formula and initializes an empty output tensor.\n It then computes the necessary block and grid sizes for launching the Triton kernel, which performs the main computation.\n After execution, the function returns the computed output tensor.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef conv2d_forward_kernel(\n input_pointer, weight_pointer, output_pointer,\n batch_dim, in_feat_dim, in_height, in_width,\n out_feat_dim, out_height, out_width,\n input_batch_stride, input_in_feat_stride, input_height_stride, input_width_stride,\n weight_out_feat_stride, weight_in_feat_stride, weight_height_stride, weight_width_stride,\n output_batch_stride, output_out_feat_stride, output_height_stride, output_width_stride,\n kernel_height: tl.constexpr, kernel_width: tl.constexpr,\n stride_height: tl.constexpr, stride_width: tl.constexpr,\n padding_height: tl.constexpr, padding_width: tl.constexpr,\n groups: tl.constexpr, fp16: tl.constexpr, tf32: tl.constexpr,\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH: tl.constexpr, BLOCK_SIZE_IN_FEAT: tl.constexpr,\n BLOCK_SIZE_OUT_FEAT: tl.constexpr,\n ):\n \"\"\"\n 2D-convolves over the input using weights.\n\n Args:\n input_pointer: Pointer to the input to convolve over.\n The input must be of shape [batch_dim, in_feat_dim, in_height, in_width].\n weight_pointer: Pointer to the weights input is convolved over by.\n The weights must be of shape [out_feat_dim, in_feat_dim, kernel_height, kernel_width].\n output_pointer: Pointer to a container the result is written to.\n The container must be of shape [batch_dim, out_feat_dim, out_height, out_width].\n batch_dim: Batch dimension of the input and output.\n in_feat_dim: Dimensionality of the input features.\n in_height: Input height.\n in_width: Input width.\n out_feat_dim: Dimensionality of the output features.\n out_height: Output height.\n out_width: Output width.\n input_batch_stride: Stride necessary to jump one element along the\n input's batch dimension.\n input_in_feat_stride: Stride necessary to jump one element along the\n input's feature dimension.\n input_height_stride: Stride necessary to jump one element along the\n input's height dimension.\n input_width_stride: Stride necessary to jump one element along the\n input's width dimension.\n weight_out_feat_stride: Stride necessary to jump one element along the\n weights' output feature dimension.\n weight_in_feat_stride: Stride necessary to jump one element along the\n weights' input feature dimension.\n weight_height_stride: Stride necessary to jump one element along the\n weights' height dimension.\n weight_width_stride: Stride necessary to jump one element along the\n weights' width dimension.\n output_batch_stride: Stride necessary to jump one element along the\n output's batch dimension.\n output_out_feat_stride: Stride necessary to jump one element along the\n output's feature dimension.\n output_height_stride: Stride necessary to jump one element along the\n output's height dimension.\n output_width_stride: Stride necessary to jump one element along the\n output's width dimension.\n kernel_height: Kernel height.\n kernel_width: Kernel width.\n stride_height: Stride of kernel across the height dimension.\n stride_width: Stride of kernel across the width dimension.\n padding_height: Padding applied to the input across the height dimension.\n padding_width: Padding applied to the input across the width dimension.\n groups: Number of groups for the convolution.\n fp16: Flag for loading the input and weights in FP16.\n tf32: Flag for performing matrix products in TF32.\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH: Block size across the batch, height, and\n width dimensions.\n BLOCK_SIZE_IN_FEAT: Block size across the input feature dimension.\n BLOCK_SIZE_OUT_FEAT: Block size across the output feature dimension.\n \"\"\"\n batch_height_width_pid = tl.program_id(0)\n out_feat_pid = tl.program_id(1)\n group_pid = tl.program_id(2)\n\n in_group_dim = in_feat_dim // groups\n out_group_dim = out_feat_dim // groups\n\n batch_height_width_offset = (batch_height_width_pid * BLOCK_SIZE_BATCH_HEIGHT_WIDTH +\n tl.arange(0, BLOCK_SIZE_BATCH_HEIGHT_WIDTH))\n batch_height_offset = batch_height_width_offset // out_width\n batch_offset = batch_height_offset // out_height\n\n output_feat_offset = (out_feat_pid * BLOCK_SIZE_OUT_FEAT +\n tl.arange(0, BLOCK_SIZE_OUT_FEAT))\n output_height_offset = batch_height_offset % out_height\n output_width_offset = batch_height_width_offset % out_width\n\n input_pointer += (input_batch_stride * batch_offset +\n input_in_feat_stride * group_pid * in_group_dim)[:, None]\n weight_pointer += (weight_out_feat_stride * output_feat_offset +\n weight_out_feat_stride * group_pid * out_group_dim)[None, :]\n\n accum = tl.zeros((BLOCK_SIZE_BATCH_HEIGHT_WIDTH, BLOCK_SIZE_OUT_FEAT),\n dtype=tl.float32)\n\n for h in range(kernel_height):\n for w in range(kernel_width):\n for c in range(0, in_group_dim, BLOCK_SIZE_IN_FEAT):\n input_feat_offset = c + tl.arange(0, BLOCK_SIZE_IN_FEAT)\n input_height_offset = (h - padding_height +\n stride_height * output_height_offset)\n input_width_offset = (w - padding_width +\n stride_width * output_width_offset)\n\n curr_input_pointer = (input_pointer +\n (input_in_feat_stride * input_feat_offset)[None, :] +\n (input_height_stride * input_height_offset)[:, None] +\n (input_width_stride * input_width_offset)[:, None])\n curr_weight_pointer = (weight_pointer +\n (weight_in_feat_stride * input_feat_offset)[:, None] +\n (weight_height_stride * h) +\n (weight_width_stride * w))\n\n input_mask = ((batch_offset < batch_dim)[:, None] &\n (input_feat_offset < in_group_dim)[None, :] &\n (0 <= input_height_offset)[:, None] &\n (input_height_offset < in_height)[:, None] &\n (0 <= input_width_offset)[:, None] &\n (input_width_offset < in_width)[:, None])\n weight_mask = ((input_feat_offset < in_group_dim)[:, None] &\n (output_feat_offset < out_group_dim)[None, :])\n\n input_block = tl.load(curr_input_pointer, mask=input_mask)\n weight_block = tl.load(curr_weight_pointer, mask=weight_mask)\n\n if fp16:\n input_block = input_block.to(tl.float16)\n weight_block = weight_block.to(tl.float16)\n\n accum += tl.dot(input_block, weight_block, allow_tf32=tf32)\n\n output_pointer += ((output_batch_stride * batch_offset)[:, None] +\n (output_out_feat_stride * (group_pid * out_group_dim + output_feat_offset))[None, :] +\n (output_height_stride * output_height_offset)[:, None] +\n (output_width_stride * output_width_offset)[:, None])\n output_mask = ((batch_offset < batch_dim)[:, None] &\n (output_feat_offset < out_group_dim)[None, :] &\n (output_height_offset < out_height)[:, None] &\n (output_width_offset < out_width)[:, None])\n\n tl.store(output_pointer, accum, mask=output_mask)\n\n\ndef conv2d_forward(input_tensor, weight_tensor, kernel_height, kernel_width, stride_height, stride_width, padding_height, padding_width, groups, fp16=False, tf32=False):\n # Extract dimensions\n batch_dim, in_feat_dim, in_height, in_width = input_tensor.shape\n out_feat_dim, _, _, _ = weight_tensor.shape\n out_height = (in_height + 2 * padding_height - kernel_height) // stride_height + 1\n out_width = (in_width + 2 * padding_width - kernel_width) // stride_width + 1\n\n # Calculate strides\n input_batch_stride, input_in_feat_stride, input_height_stride, input_width_stride = input_tensor.stride()\n weight_out_feat_stride, weight_in_feat_stride, weight_height_stride, weight_width_stride = weight_tensor.stride()\n output_tensor = torch.empty((batch_dim, out_feat_dim, out_height, out_width), device=input_tensor.device, dtype=torch.float32)\n output_batch_stride, output_out_feat_stride, output_height_stride, output_width_stride = output_tensor.stride()\n\n # Define block sizes\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH = 128\n BLOCK_SIZE_IN_FEAT = 32\n BLOCK_SIZE_OUT_FEAT = 32\n\n # Launch the Triton kernel\n grid = (triton.cdiv(batch_dim * out_height * out_width, BLOCK_SIZE_BATCH_HEIGHT_WIDTH),\n triton.cdiv(out_feat_dim, BLOCK_SIZE_OUT_FEAT),\n groups)\n \n conv2d_forward_kernel[grid](\n input_tensor, weight_tensor, output_tensor,\n batch_dim, in_feat_dim, in_height, in_width,\n out_feat_dim, out_height, out_width,\n input_batch_stride, input_in_feat_stride, input_height_stride, input_width_stride,\n weight_out_feat_stride, weight_in_feat_stride, weight_height_stride, weight_width_stride,\n output_batch_stride, output_out_feat_stride, output_height_stride, output_width_stride,\n kernel_height, kernel_width,\n stride_height, stride_width,\n padding_height, padding_width,\n groups, fp16, tf32,\n BLOCK_SIZE_BATCH_HEIGHT_WIDTH, BLOCK_SIZE_IN_FEAT, BLOCK_SIZE_OUT_FEAT\n )\n \n return output_tensor\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4a75c661cc2240be84767427890c7479", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton-based operator for checking the finiteness of tensor elements. The `isfinite_func_wrapper_rank_1` function is a wrapper that processes input and output tensors, ensuring they have matching shapes. It uses heuristic functions to determine optimal tile sizes (`heuristics_for_tile_size`) and the number of warps (`heuristics_for_num_warps`) for efficient execution. The function calculates task parameters like `num_ctas` and `grid`, and then calls the Triton kernel `isfinite_func_kernel_rank_1`. \n\n The `isfinite_func_kernel_rank_1` kernel operates on rank-1 tensors and employs either a monolithic or grid-stride-loop approach based on `one_tile_per_cta`. The kernel leverages Triton's pointer and block pointer types to load data elements from the input tensor into a Triton tensor (`in0`). It uses boundary checking to prevent out-of-bounds accesses. The kernel applies the `isfinite_func` function, which determines finiteness by calling `_isfinited` for fp64 or `_finitef` for other types, casting appropriately. Results are stored back into the output tensor using the Triton `store` function. The kernel dynamically adjusts its execution strategy based on the input size and structure, ensuring both safety and performance.\n ", + "output": [ + "import math\nfrom typing import Union\nimport torch\nimport triton\nfrom triton import language as tl\n\ntry:\n from triton.language.extra.cuda.libdevice import isfinited as _isfinited\nexcept ImportError:\n try:\n from triton.language.math import isfinited as _isfinited\n except ImportError:\n from triton.language.libdevice import isfinited as _isfinited\n\ntry:\n from triton.language.extra.cuda.libdevice import finitef as _finitef\nexcept ImportError:\n try:\n from triton.language.math import finitef as _finitef\n except ImportError:\n from triton.language.libdevice import finitef as _finitef\n\ndef heuristics_for_tile_size(max_tile_size, *sizes):\n ndim = len(sizes)\n tile_sizes = [0 for _ in range(ndim)]\n for i in range(ndim):\n size = sizes[ndim - 1 - i]\n tile_size = min(max_tile_size, triton.next_power_of_2(size))\n tile_sizes[ndim - 1 - i] = tile_size\n max_tile_size = max(1, max_tile_size // tile_size)\n return tuple(tile_sizes)\n\ndef heuristics_for_num_warps(tile_size):\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n \n\nclass StridedBuffer:\n \"\"\"A drop-in replacement of torch.Tensor that can be used in wrapper generated by\n PointwiseDynamicFunction. It allows us to use a different shape, stride, data\n pointer that that of the base tensor.\n\n It is a kind of reinterpretation of the base tensor. We make this class since we\n cannot get a Tensor view with negative strides via torch APIs, while we need this\n to implement flip op.\n\n Although generated code can accept torch.Tensor & StridedBuffer, but StridedBuffer\n may not have all the methods as torch.Tensors do. We add some attributes & methods\n with the same name as torch.Tensor, which are used in the generated code. But we\n may not cover all the methods, add one if what you need is missing here.\n\n And can also be used in triton kernels since it also has dtype & data_ptr().\n \"\"\"\n\n def __init__(\n self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0\n ):\n self._base = base\n self.dtype = dtype or base.dtype\n if offset == 0:\n self._data_ptr = self._base.data_ptr()\n else:\n offset = self.dtype.itemsize * offset\n self._data_ptr = self._base.data_ptr() + offset\n self.shape = tuple(shape if shape is not None else self._base.shape)\n self._strides = tuple(strides if strides is not None else self._base.stride())\n self.device = self._base.device\n self.ndim = len(self.shape)\n\n def stride(self):\n return self._strides\n\n def size(self):\n return self.shape\n\n def element_size(self):\n return self.dtype.itemsize\n\n def numel(self):\n return math.prod(self.shape)\n\n def dim(self):\n return self.ndim\n\n def unwrap(self):\n return self._base\n\n def data_ptr(self):\n return self._data_ptr\n\n\ndef isfinite_func_wrapper_rank_1(in0: Union[torch.Tensor, StridedBuffer], /, *, out0: Union[torch.Tensor, StridedBuffer]): \n \"\"\"Generated wrapper function with Pointwise: StridedBuffer, StridedBuffer(a1!) -> StridedBuffer(a1!)\"\"\"\n assert in0.shape == out0.shape, 'operand shapes mismatch'\n # task partitioning\n shape = out0.shape\n num_tasks = out0.numel()\n tile_sizes = heuristics_for_tile_size(512, *shape)\n tile_size = math.prod(tile_sizes)\n num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))\n num_ctas = min(65536, num_tiles)\n tiles_per_cta = triton.cdiv(num_tiles, num_ctas)\n num_warps = heuristics_for_num_warps(tile_size)\n one_tile_per_cta = tiles_per_cta==1\n grid = (num_ctas, 1, 1)\n # kernel launch\n in0_strides = in0.stride()\n in0_stride_order = (0,)\n out0_strides = out0.stride()\n out0_stride_order = (0,)\n with torch.cuda._DeviceGuard(in0.device.index):\n isfinite_func_kernel_rank_1[grid](\n in0, out0,\n in0_strides[0], # stride for in0\n in0_stride_order[0], # stride order for in0\n out0_strides[0], # stride for out0\n out0_stride_order[0], # stride orderfor out0\n shape[0], # task indexing space\n num_tasks, # num tasks\n tiles_per_cta=tiles_per_cta, # tiles_per_cta\n tile_size0=tile_sizes[0],\n one_tile_per_cta=one_tile_per_cta,\n num_warps=num_warps,\n )\n return out0\n\n@triton.jit\ndef isfinite_func(x):\n return _isfinited(x) if x.dtype.is_fp64() else _finitef(x.to(tl.float32))\n\n@triton.jit\ndef isfinite_func_kernel_rank_1(\n in0_ptr: tl.tensor, # of tl.pointer_type\n out0_ptr: tl.tensor, # of tl.pointer_type\n in0_stride0: int, # strides for in0\n in0_stride_order0: tl.constexpr, # stride order for in0\n out0_stride0: int, # strides for out0\n out0_stride_order0: tl.constexpr, # stride order for out0\n s0: int, # task_space\n num_tasks: int,\n tiles_per_cta: int,\n tile_size0: tl.constexpr,\n one_tile_per_cta: tl.constexpr,\n):\n pid = tl.program_id(0)\n num_tiles0 = tl.cdiv(s0, tile_size0)\n if one_tile_per_cta: # monolitic kernel style\n tile_id = pid\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = isfinite_func(in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n else: # grid-stride-loop style kernel\n num_ctas = tl.num_programs(0)\n for j in range(0, tiles_per_cta):\n tile_id = pid + j * num_ctas\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = isfinite_func(in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-118494e845a745e8804528028c156149", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The given Triton code provides an implementation to convert tensors containing packed FP4 (4-bit floating point) values into BF16 (bfloat16) values on a CUDA-capable GPU using Triton. The implementation is composed of two main functions: `_fp4_packed_to_bf16` and `triton_f4_to_bf16_kernel`.\n\n 1. Function: `_fp4_packed_to_bf16`\n This function accepts packed FP4 values and produces BF16 values. The packed FP4 values consist of low and high bits, which are separated using bitwise operations and then interleaved. Each FP4 value consists of sign, exponent, and mantissa components that are extracted using provided bit masks and constants.\n \n Key operations include:\n - Separating sign, exponent, and mantissa.\n - Handling special cases, such as zero and denormal (0.5) values.\n - Adjusting exponents from FP4 to BF16, accounting for differences in exponent bias.\n - Reconstructing the value in FP32 format first, then converting to BF16.\n \n 2. Function: `triton_f4_to_bf16_kernel`\n This Triton kernel handles parallel processing of the input tensor. It processes data in blocks and employs `_fp4_packed_to_bf16` for conversion. Input includes pointers to packed FP4 data and the intended BF16 output location. The kernel uses a block size to iterate over chunks of data, efficiently converting each.\n\n Essential constants and parameters involved in conversion include:\n - `SIGN_MASK_F4`, `MANTISSA_MASK_F4`: Used for extracting parts of FP4 numbers.\n - `ZERO_BITS_F32`, `ZERO_POINT_FIVE_BITS_F32`: Bit patterns for special FP32 values (zero and 0.5).\n - `EBITS_F4_E2M1`, `MBITS_F4_E2M1`, `EBITS_F32`, `MBITS_F32`: Exponent and mantissa specifications for FP4 and FP32.\n - Bias constants to correct exponent value differences between formats.\n\n 3. Wrapper: `triton_f4_to_bf16`\n This Python function wraps around the kernel call, ensuring the input tensor is contiguous and CUDA-allocated. It calculates the grid dimensions for the kernel launch, corresponding to the number of elements and block size. The output tensor is allocated with the appropriate shape and type (bfloat16), and the kernel is executed.\n\n In conclusion, this code demonstrates how to effectively convert FP4 encoded data into BF16 using Triton's CUDA parallelism to achieve high performance, with attention to numerical precision and edge-case handling.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fp4_packed_to_bf16(\n x_packed,\n sign_mask_f4,\n mantissa_mask_f4,\n mbits_f4_e2m1,\n ebits_f4_e2m1,\n f4_e2m1_exp_bias,\n mbits_f32,\n ebits_f32,\n f32_exp_bias,\n zero_bits_f32,\n zero_point_five_bits_f32,\n):\n \"\"\"\n Input: a tensor of packed fp4 values\n Output: a tensor of bfloat16 values\n \"\"\"\n\n # low-bits: original location 0:3\n # high-bits: original location 4:7\n x_low_bits = x_packed >> 4\n x_high_bits = x_packed & 0xF\n x = tl.interleave(x_low_bits, x_high_bits)\n\n # cast logic below\n # output = x_unpacked.to(tl.float32)\n\n # save the sign\n sign_f4 = x & sign_mask_f4\n\n # set everything to positive, will add sign back at the end\n x_pos = x ^ sign_f4\n\n # Special case zero\n zero_mask = x_pos == 0\n\n # There is only one denormal value in fp4: s001, which is 0.5 in f32\n # Special case it.\n # TODO(later): will it be faster to repeat this for all 8 positive\n # values instead of the bit manipulations?\n denormal_mask = x_pos == 1\n\n # calculate the new exponent and shift it to bits 2:9 of the result\n exp_biased_f4 = x_pos >> mbits_f4_e2m1\n exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias\n exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32\n\n # shift the mantissa to bits 10:32 of the result\n mantissa_f4 = x_pos & mantissa_mask_f4\n mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)\n output = mantissa_f32\n\n # combine the pieces\n result = exp_biased_f32 | mantissa_f32\n # result[zero_mask] = ZERO_BITS_F32\n result = tl.where(zero_mask, zero_bits_f32, result)\n # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32\n result = tl.where(denormal_mask, zero_point_five_bits_f32, result)\n\n # add sign back\n sign_f32 = sign_f4.to(tl.int32) << (\n mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1\n )\n result = result | sign_f32\n\n # The bit shifting above is for float32, so for now we\n # bitcast to float32 and then regular cast to bfloat16\n # TODO(later): it should be pretty easy to cast directly to bf16, just\n # need to adjust the mbits/ebits/special values. Perf impact is likely\n # to be small as we would not be chaning memory access patterns.\n output = result.to(tl.float32, bitcast=True)\n output = output.to(tl.bfloat16)\n return output\n\n@triton.jit\ndef triton_f4_to_bf16_kernel(\n x_ptr,\n output_ptr,\n n_elements_in,\n sign_mask_f4: tl.constexpr,\n mantissa_mask_f4: tl.constexpr,\n mbits_f4_e2m1: tl.constexpr,\n ebits_f4_e2m1: tl.constexpr,\n f4_e2m1_exp_bias: tl.constexpr,\n mbits_f32: tl.constexpr,\n ebits_f32: tl.constexpr,\n f32_exp_bias: tl.constexpr,\n zero_bits_f32: tl.constexpr,\n zero_point_five_bits_f32: tl.constexpr,\n BLOCK_SIZE_IN: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n n_elements_out = n_elements_in * 2\n BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2\n\n block_start_in = pid * BLOCK_SIZE_IN\n offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)\n\n mask_in = offsets_in < n_elements_in\n\n # packed uint8\n x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)\n output = _fp4_packed_to_bf16(\n x_packed,\n sign_mask_f4,\n mantissa_mask_f4,\n mbits_f4_e2m1,\n ebits_f4_e2m1,\n f4_e2m1_exp_bias,\n mbits_f32,\n ebits_f32,\n f32_exp_bias,\n zero_bits_f32,\n zero_point_five_bits_f32,\n )\n\n # set up output offsets\n block_start_out = pid * BLOCK_SIZE_OUT\n offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)\n mask_out = offsets_out < n_elements_out\n\n tl.store(output_ptr + offsets_out, output, mask=mask_out)\n\nEBITS_F32, MBITS_F32 = 8, 23\nEBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1\nEBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3\nEBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2\n\nSIGN_MASK_F4 = 0x8 # 1000\nMANTISSA_MASK_F4 = 0x1 # 0001\n\nZERO_BITS_F32 = 0x0\nZERO_POINT_FIVE_BITS_F32 = 0x3F000000\nF4_E2M1_EXP_BIAS = 1\nF32_EXP_BIAS = 127\n\ndef triton_f4_to_bf16(x: torch.Tensor):\n \"\"\"\n Input: a tensor of packed fp4 values\n Output: a tensor of bfloat16 values\n\n Note: this function is only used in testing, so we can test\n the numerical correctness of the cast without the scaling.\n \"\"\"\n new_shape = (*x.shape[:-1], x.shape[-1] * 2)\n output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)\n assert x.is_contiguous()\n assert x.is_cuda and output.is_cuda\n n_elements_in = x.numel()\n grid = lambda meta: ( # noqa: E731\n triton.cdiv(n_elements_in, meta[\"BLOCK_SIZE_IN\"]),\n ) # noqa: E731,E501\n triton_f4_to_bf16_kernel[grid](\n x,\n output,\n n_elements_in,\n sign_mask_f4=SIGN_MASK_F4,\n mantissa_mask_f4=MANTISSA_MASK_F4,\n mbits_f4_e2m1=MBITS_F4_E2M1,\n ebits_f4_e2m1=EBITS_F4_E2M1,\n f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,\n mbits_f32=MBITS_F32,\n ebits_f32=EBITS_F32,\n f32_exp_bias=F32_EXP_BIAS,\n zero_bits_f32=ZERO_BITS_F32,\n zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,\n BLOCK_SIZE_IN=512,\n )\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-407f4a4a1ed34da2847b5935c7a9e4c1", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_triton_rope` kernel is a Triton-annotated function (`@triton.jit`) designed to apply rotary position embeddings to query (`q_ptr`) and key (`k_ptr`) matrices by using precomputed cosine and sine arrays (`cos`, `sin`). The kernel uses a unique identifier `pid` for each program instance, indexing into matrices for batch processing. The embeddings split the head dimensions into two halves, applying a rotation formula to each. The `cos` and `sin` arrays determine the phase shift for the embedding. Two configurations are supported: forward and backward, determined by `BACKWARD_PASS`. During the forward pass, the rotation uses standard trigonometric relationships, while the backward pass applies the inverse rotation. The `rope_backward` function is a wrapper that ensures gradients (`dq`, `dk`) are processed correctly by transposing and configuring the kernel call. Inputs are padded to power-of-two dimensions to align with efficient Triton memory access patterns. This function primarily sets up the operation and handles the memory alignment, then triggers the kernel for execution.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_rope(\n q_ptr,\n q_row_stride,\n k_ptr,\n k_row_stride,\n cos,\n cos_row_stride,\n sin,\n sin_row_stride,\n sl,\n bs: tl.constexpr,\n n_qh: tl.constexpr,\n n_kh: tl.constexpr,\n hd: tl.constexpr,\n pad_n_qh: tl.constexpr,\n pad_n_kh: tl.constexpr,\n pad_hd: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n BACKWARD_PASS: tl.constexpr = False,\n):\n pid = tl.program_id(0)\n\n q_ptr = q_ptr + pid * q_row_stride\n k_ptr = k_ptr + pid * k_row_stride\n\n cos_row_idx = pid % (sl)\n cos = cos + cos_row_idx * cos_row_stride\n sin = sin + cos_row_idx * sin_row_stride\n cos_offsets = tl.arange(0, pad_hd // 2)\n cos_mask = cos_offsets < hd // 2\n cos_row = tl.load(cos + cos_offsets, mask=cos_mask, other=0)\n sin_row = tl.load(sin + cos_offsets, mask=cos_mask, other=0)\n\n first_half_q_offsets = tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_half_k_offsets = tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]\n first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (tl.arange(0, pad_hd // 2)[None, :] < hd // 2)\n q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(sin_row.dtype)\n k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(sin_row.dtype)\n\n second_half_q_offsets = first_half_q_offsets + (hd // 2)\n second_half_k_offsets = first_half_k_offsets + (hd // 2)\n second_q_mask = first_q_mask\n second_k_mask = first_k_mask\n q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(sin_row.dtype)\n k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(sin_row.dtype)\n\n if not BACKWARD_PASS:\n new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n else:\n new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row\n tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)\n new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row\n tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)\n\n new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row\n tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)\n new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row\n tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)\n\ndef rope_backward(dq, dk, cos, sin):\n dq = dq.transpose(1, 2)\n dk = dk.transpose(1, 2)\n\n batch_size, seq_len, n_q_head, head_dim = dq.shape\n n_kv_head = dk.shape[2]\n pad_hd = triton.next_power_of_2(head_dim)\n pad_n_q_head = triton.next_power_of_2(n_q_head)\n pad_n_kv_head = triton.next_power_of_2(n_kv_head)\n BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head)\n\n n_row = batch_size * seq_len\n\n dq = dq.contiguous()\n dk = dk.contiguous()\n\n _triton_rope[(n_row,)](\n dq,\n dq.stride(1),\n dk,\n dk.stride(1),\n cos,\n cos.stride(-2),\n sin,\n sin.stride(-2),\n seq_len,\n batch_size,\n n_q_head,\n n_kv_head,\n head_dim,\n pad_n_q_head,\n pad_n_kv_head,\n pad_hd,\n BLOCK_SIZE=BLOCK_SIZE,\n BACKWARD_PASS=True,\n )\n return dq.transpose(1, 2), dk.transpose(1, 2)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-862fad300db84c21a27cb6d61ee3b6d0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code consists of a Triton kernel, `_sgmv_expand_slice_kernel`, and a wrapper function, `_sgmv_expand_slice`. \n The kernel performs a specialized form of matrix multiplication involving sparse generalized matrix-vector multiplication (SGMV). \n It operates on blocks of data defined by the parameters BLOCK_M, BLOCK_N, and BLOCK_K. \n The kernel checks for boundary conditions and handles data casting between different precisions. \n It processes sequences of a batch in parallel and uses lora_indices to access specific slices of LoRA weights for computations. \n The wrapper function sets up the input parameters, ensuring contiguity and dimension correctness, and configures the Triton launch grid to execute the kernel. \n It validates tensor shapes and data types, adjusts weights if needed, and launches the kernel with a grid configured to cover the input data based on the maximum sequence length and batch size.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _sgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n b_seq_start_loc,\n seq_lens,\n lora_indices,\n xm_stride,\n xk_stride, # 1\n l0_stride, # hidden_size*max_rank\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n cta_n_num = tl.cdiv(N, BLOCK_N)\n pid_m = pid // cta_n_num\n pid_n = pid % cta_n_num\n M = tl.load(seq_lens + cur_batch)\n if pid_m * BLOCK_M > M:\n return\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n cur_seq_start = tl.load(b_seq_start_loc + cur_batch)\n offset_m = tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_n = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N\n offset_k = tl.arange(0, BLOCK_K)\n ram = tl.max_contiguous(tl.multiple_of(offset_m % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(offset_n % N, BLOCK_N), BLOCK_N)\n\n a_ptr = (input_ptr + cur_seq_start * xm_stride + ram[:, None] * xm_stride +\n offset_k[None, :] * xk_stride, )\n b_ptr = (lora_ptr + l0_stride * lora_index +\n offset_k[:, None] * lora_n_stride + rbn[None, :] * lora_k_stride)\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n for k in range(tl.cdiv(K, BLOCK_K)):\n if EVEN_K:\n tiled_a = tl.load(a_ptr)\n tiled_b = tl.load(b_ptr)\n else:\n tiled_a = tl.load(a_ptr,\n mask=offset_k[None, :] < K - k * BLOCK_K,\n other=0)\n tiled_b = tl.load(b_ptr,\n mask=offset_k[:, None] < K - k * BLOCK_K,\n other=0)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n accumulator += tl.dot(\n tiled_a,\n tiled_b,\n )\n a_ptr += BLOCK_K * xk_stride\n b_ptr += BLOCK_K * lora_n_stride\n tiled_c = accumulator.to(lora_ptr.dtype.element_ty)\n offset_cm = cur_seq_start + tl.arange(0, BLOCK_M) + pid_m * BLOCK_M\n offset_cn = tl.arange(0, BLOCK_N) + pid_n * BLOCK_N + slice_offset\n c_ptr = (out_ptr + offset_cm[:, None] * cm_stride +\n offset_cn[None, :] * cn_stride)\n M = tl.load(seq_lens + cur_batch)\n c_mask = (offset_cm[:, None] < (cur_seq_start + M)) & (offset_cn[None, :] <\n (slice_offset + N))\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr, mask=c_mask)\n tiled_c += tiled_out\n tl.store(c_ptr, tiled_c, mask=c_mask)\n\n\n@torch.inference_mode()\ndef _sgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n b_seq_start_loc: torch.Tensor,\n seq_len_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n batches: int,\n max_seq_length: int,\n token_nums: int,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = False,\n) -> None:\n\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(0) == token_nums\n assert inputs.size(1) == lora_b_weights.size(-1)\n assert b_seq_start_loc.size(0) == batches\n assert lora_indices_tensor.size(0) == batches\n assert slice_size == lora_b_weights.size(-2)\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n\n assert lora_b_weights.is_contiguous()\n\n N, K = lora_b_weights.shape[-2:]\n\n BLOCK_M = 32\n BLOCK_N = 32\n BLOCK_K = 16\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n grid = (\n triton.cdiv(max_seq_length, BLOCK_M) * triton.cdiv(N, BLOCK_N),\n batches,\n )\n _sgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n b_seq_start_loc,\n seq_len_tensor,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_M,\n BLOCK_N,\n BLOCK_K,\n EVEN_K,\n ADD_INPUTS,\n CAST_TYPE,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c98b87096fd94ed48a3259c2fb079f19", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-accelerated function embedding_kernel is specialized for extracting and storing embedding vectors from a weight matrix for a sequence of token IDs. It uses program IDs to determine processing offsets and handles iteration over sequences with BLOCK_N and BLOCK_NN stride sizes. For each sequence, it computes token IDs and uses masks to ensure only valid data is loaded and processed. The weight matrix is addressed using a combination of token IDs and dimension offsets, facilitated by the stride of the weight tensor. The processed vectors are then stored into the 'out' tensor using calculated strides and masks, ensuring each output sequence position receives the correct embedding vector. The wrapping function, embedding, configures and invokes the kernel with appropriate grid settings, aligning BLOCK_DMODEL to the next power of two based on weight dimensions and leveraging constant memory settings to optimize the embedding extraction process.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef embedding_kernel(\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n stride_weight_seq,\n stride_out_seq,\n n_ctx,\n hiden_size: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_NN: tl.constexpr,\n):\n start_n = tl.program_id(0) * BLOCK_N\n\n offs_nn = start_n + tl.arange(0, BLOCK_NN)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n for start_nn in range(0, BLOCK_N, BLOCK_NN):\n start_nn = tl.multiple_of(start_nn, BLOCK_NN)\n offs_seq = start_nn + offs_nn\n n_ctx_mask = offs_seq < n_ctx\n token_ids = tl.load(input_ids + offs_seq, mask=n_ctx_mask, other=vob_end_id)\n id_mask = (token_ids >= vob_start_id) & (token_ids < vob_end_id)\n token_ids = token_ids - vob_start_id\n dim_mask = offs_d < hiden_size\n load_mask = id_mask[:, None] & dim_mask[None, :]\n store_mask = n_ctx_mask[:, None] & dim_mask[None, :]\n vecs = tl.load(weight + token_ids[:, None] * stride_weight_seq + offs_d[None, :], mask=load_mask, other=0.0)\n tl.store(out + offs_seq[:, None] * stride_out_seq + offs_d[None, :], vecs, mask=store_mask)\n\n@torch.no_grad()\ndef embedding(input_ids, weight: torch.Tensor, vob_start_id, vob_end_id, out: torch.Tensor):\n BLOCK_N = 64\n BLOCK_NN = 1\n BLOCK_DMODEL = triton.next_power_of_2(weight.shape[1])\n n_ctx = input_ids.shape[0]\n\n grid = (triton.cdiv(n_ctx, BLOCK_N), 1, 1)\n\n embedding_kernel[grid](\n weight,\n input_ids,\n out,\n vob_start_id,\n vob_end_id,\n weight.stride(0),\n out.stride(0),\n n_ctx=n_ctx,\n hiden_size=weight.shape[1],\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_N=BLOCK_N,\n BLOCK_NN=BLOCK_NN,\n num_warps=1,\n num_stages=1,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-2135932c1db940729cb7a2c692dc83aa", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `uniform_kernel` function in Triton is designed to generate uniform random numbers on the GPU within a specified range `[from_, to)`. It accepts several parameters: `out_ptr` is a pointer to where the result is stored, `N` is the total number of random numbers to generate, `philox_seed` and `philox_offset` are used for initializing the Philox random number generator, and `from_` and `to` define the range of random numbers. The block size `BLOCK` determines the number of threads per block.\n\n A key part of the kernel is utilizing Triton's Philox function, which takes the seed and offset to generate four random numbers (`r0`, `r1`, `r2`, `r3`). These random uints are then converted to floats in the `[0, 1)` range using `uint_to_uniform_float`, and scaled to the specified range `[from_, to)`.\n\n To efficiently manage memory, the kernel uses a loop to calculate offsets (`off_0`, `off_1`, `off_2`, `off_3`) for storing the four random numbers generated per loop iteration. Each number is conditionally stored based on its offset relative to `N`, using Triton's `tl.store` function with an eviction policy set to \"evict_first\". The kernel also includes heuristics to optimize `BLOCK` size and `num_warps` based on `N`.\n\n The `uniform_` function is a high-level wrapper that initializes the Philox generator state using the `philox_cuda_seed_offset` function, calculates the grid size, and invokes the kernel with the appropriate parameters. It ensures the kernel is executed within the context of the current CUDA device using `torch.cuda.device`.\n ", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\nimport functools\nimport operator\nfrom typing import Tuple\n\nShape = Tuple[int]\n\n\ntry:\n uint_to_uniform_float = tl.uint_to_uniform_float\nexcept AttributeError:\n # Copied from triton.language package for compatibility\n @triton.jit\n def uint_to_uniform_float(x):\n \"\"\"\n Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).\n \"\"\"\n # TODO: fix frontend issues and cleanup\n # conditions can be simplified\n # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)\n if tl.constexpr(x.dtype == tl.uint32) or tl.constexpr(x.dtype == tl.int32):\n # maximum value such that `MAX_INT * scale < 1.0` (with float rounding)\n x = x.to(tl.int32, bitcast=True)\n scale = 4.6566127342e-10\n else:\n tl.static_assert(\n tl.constexpr(x.dtype == tl.uint64) or tl.constexpr(x.dtype == tl.int64)\n )\n x = x.to(tl.int64, bitcast=True)\n scale = 1.0842020432385337e-19\n x = tl.where(x < 0, -x - 1, x)\n return x * scale\n\n\ndef philox_cuda_seed_offset(increment, device=None):\n device = device or torch.cuda.current_device()\n gen = torch.cuda.default_generators[device]\n state_copy = gen.get_state()\n c0, c1 = state_copy.view(torch.int64)\n seed, offset = int(c0), int(c1)\n increment = (increment + 3) // 4 * 4\n c1 += increment\n # get_state returns a new tensor, so it needs set_state to update the actual generator state.\n gen.set_state(state_copy)\n return seed, offset\n\n\ndef heur_block(args):\n if args[\"N\"] <= 512:\n return 512\n else:\n return 1024\n\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 512:\n return 4\n elif args[\"N\"] <= 1024:\n return 8\n else:\n return 16\n\n\n@triton.heuristics(\n {\n \"BLOCK\": heur_block,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit(do_not_specialize=[\"philox_seed\", \"philox_offset\"])\ndef uniform_kernel(\n out_ptr,\n N,\n philox_seed,\n philox_offset,\n from_,\n to,\n BLOCK: tl.constexpr,\n):\n philox_seed = philox_seed.to(tl.int64)\n philox_offset = philox_offset.to(tl.int64)\n c0 = (philox_offset & 0xFFFFFFFF).to(tl.uint32)\n c1 = ((philox_offset >> 32) & 0xFFFFFFFF).to(tl.uint32)\n i4 = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)\n c0 += i4\n _O = c0 * 0\n r0, r1, r2, r3 = tl.philox(philox_seed, c0, c1, _O, _O)\n r0 = uint_to_uniform_float(r0) * (to - from_) + from_\n r1 = uint_to_uniform_float(r1) * (to - from_) + from_\n r2 = uint_to_uniform_float(r2) * (to - from_) + from_\n r3 = uint_to_uniform_float(r3) * (to - from_) + from_\n off_0 = tl.program_id(0) * BLOCK * 4 + tl.arange(0, BLOCK)\n off_1 = off_0 + BLOCK\n off_2 = off_1 + BLOCK\n off_3 = off_2 + BLOCK\n tl.store(out_ptr + off_0, r0, mask=off_0 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_1, r1, mask=off_1 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_2, r2, mask=off_2 < N, eviction_policy=\"evict_first\")\n tl.store(out_ptr + off_3, r3, mask=off_3 < N, eviction_policy=\"evict_first\")\n\n\n\ndef volume(shape: Shape) -> int:\n return functools.reduce(operator.mul, shape, 1)\n\n\nUNROLL = 4\n\n\ndef uniform_(self, from_=0.0, to=1.0, *, generator=None):\n logging.debug(\"GEMS UNIFORM\")\n N = volume(self.shape)\n grid_fn = lambda meta: (triton.cdiv(N, meta[\"BLOCK\"] * UNROLL),)\n\n increment = triton.cdiv(N, UNROLL)\n philox_seed, philox_offset = philox_cuda_seed_offset(increment)\n with torch.cuda.device(self.device):\n uniform_kernel[grid_fn](self, N, philox_seed, philox_offset, from_, to)\n return self\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c4eabe5bf09b453a9fc4b85826d12a59", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_bgmv_expand_slice_kernel` is designed to optimize the matrix-vector multiplication by splitting the N dimension into multiple blocks, improving performance for large hidden sizes. It operates on inputs from `input_ptr` (input matrix), `lora_ptr` (LoRA weight matrix), and `out_ptr` (output matrix), considering the batch and index information provided by `lora_indices`. The kernel computes matrix-vector products for each block and optionally accumulates these products into the output tensor. Key constants like `BLOCK_N`, `BLOCK_K`, `SPLIT_N`, `EVEN_K`, `ADD_INPUTS`, and `CAST_TYPE` control the block dimensions, input data handling, and type casting behavior. The `_bgmv_expand_slice` function, marked with `@torch.inference_mode()`, initializes tensor properties, validates shapes, and prepares the grid configuration before invoking the Triton kernel, ensuring data and operational integrity.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _bgmv_expand_slice_kernel(\n input_ptr,\n lora_ptr,\n out_ptr,\n N,\n K,\n lora_indices,\n xm_stride,\n xk_stride,\n l0_stride,\n lora_k_stride,\n lora_n_stride,\n cm_stride,\n cn_stride,\n slice_offset,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n SPLIT_N: tl.constexpr,\n EVEN_K: tl.constexpr,\n ADD_INPUTS: tl.constexpr,\n CAST_TYPE: tl.constexpr,\n):\n \"\"\"\n GroupGEMV, introducing SPLIT_N can improve large hidden_size's performance\n \"\"\"\n pid_sn = tl.program_id(axis=0)\n cur_batch = tl.program_id(axis=1)\n lora_index = tl.load(lora_indices + cur_batch)\n if lora_index == -1:\n return\n offset_k = tl.arange(0, BLOCK_K)\n offset_n = tl.arange(0, BLOCK_N)\n if EVEN_K:\n tiled_a = tl.load(input_ptr + cur_batch * xm_stride +\n offset_k * xk_stride)\n else:\n tiled_a = tl.load(\n input_ptr + cur_batch * xm_stride + offset_k * xk_stride,\n mask=offset_k < K,\n other=0,\n )\n split_n_length = tl.cdiv(N, SPLIT_N)\n if CAST_TYPE:\n tiled_a = tiled_a.to(lora_ptr.dtype.element_ty)\n b_ptr = (lora_ptr + l0_stride * lora_index +\n pid_sn * split_n_length * lora_k_stride)\n c_ptr = (out_ptr + cur_batch * cm_stride + pid_sn * split_n_length +\n slice_offset * cn_stride)\n\n for n in range(0, split_n_length, BLOCK_N):\n current_n = n + offset_n\n b_ptr_mask = (current_n[:, None] < split_n_length) & (offset_k[None, :]\n < K)\n c_mask = current_n < split_n_length\n tiled_b = tl.load(\n b_ptr + current_n[:, None] * lora_k_stride +\n offset_k[None, :] * lora_n_stride,\n mask=b_ptr_mask,\n other=0.0,\n )\n if ADD_INPUTS:\n tiled_out = tl.load(c_ptr + current_n * cn_stride, mask=c_mask)\n accumulator = tl.sum(tiled_a * tiled_b, 1) + tiled_out\n else:\n accumulator = tl.sum(tiled_a * tiled_b, 1)\n tl.store(c_ptr + current_n * cn_stride, accumulator, mask=c_mask)\n\n@torch.inference_mode()\ndef _bgmv_expand_slice(\n inputs: torch.Tensor,\n lora_b_weights: torch.Tensor,\n output_tensor: torch.Tensor,\n lora_indices_tensor: torch.Tensor,\n slice_offset: int,\n slice_size: int,\n add_inputs: bool = True,\n) -> None:\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n\n assert slice_size == lora_b_weights.size(-2)\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4:\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3\n\n assert lora_b_weights.is_contiguous()\n\n N, K = lora_b_weights.shape[-2:]\n BLOCK_K = triton.next_power_of_2(K)\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n\n batches = lora_indices_tensor.size(0)\n\n # config = get_lora_op_configs(\"expand\", batches, N)\n\n grid = lambda META: (\n META[\"SPLIT_N\"],\n batches,\n )\n _bgmv_expand_slice_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n slice_offset,\n BLOCK_K=BLOCK_K,\n BLOCK_N=256,\n SPLIT_N=lora_b_weights.shape[-2:][0],\n EVEN_K=EVEN_K,\n ADD_INPUTS=ADD_INPUTS,\n CAST_TYPE=CAST_TYPE,\n # **config,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-26ae8834593a4982b0142b1f4e9f5b6f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton-based operation encapsulates fused recurrent computation which is split into two main kernels: 'fused_recurrent_fwd_kernel' and 'fused_recurrent_bwd_kernel'. The forward kernel is designed to perform element-wise multiplication and accumulation operations to process sequences of data in a parallelized manner using Triton's primitives. Key inputs are tensors q, k, and v, which stand for query, key, and value respectively, with additional parameters such as beta for scaling adjustments and initial_state for setting the initial hidden state. The operation also scales input queries using a specified or default scale factor. The output can optionally include the final state of the recurrent computation. The backward kernel computes gradients for each of the input tensors, leveraging Triton's efficient parallel computing capabilities to perform operations such as tensor load, store, and arithmetic computations in reverse order to update weights during backpropagation. The FusedRecurrentFunction class integrates these kernels into PyTorch's autograd mechanism, enabling them to be used as custom backward functions, with the 'fused_recurrent_delta_rule' acting as the interface for user-friendly application. Key parameters include B (batch size), H (number of heads), T (sequence length), K and V (dimensions of the key and value tensors). BK and BV are block sizes for partitioning the computation, and grid specifies the execution configuration for Triton. The operations also account for different cases such as whether initial states are used or if beta has head-wise scaling, adapting the kernel execution path accordingly.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_fwd_kernel(\n q, k, v, beta, o, h0, ht, s_qk_h, s_vo_h, scale, B, H, T, K: tl.constexpr, V: tl.constexpr, \n BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, \n IS_HEADWISE_BETA: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n else:\n p_beta = beta + i_bh * T\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _v_minus = tl.sum(h * b_k[None, :], axis=1)\n b_v -= _v_minus\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n tl.store(p_v, b_v.to(p_v.dtype.element_ty), mask=mask_bv)\n b_v *= b_beta\n h += b_k[None, :] * b_v[:, None]\n _o = h * b_q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += K\n p_k += K\n p_o += V\n p_v += V\n p_beta += V if IS_HEADWISE_BETA else 1\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_bwd_kernel(\n q, k, v, beta, dht, dh0, do, dq, dk, dv, dbeta, h0, s_qk_h, s_vo_h, NK, scale, B, H, T, \n K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, \n IS_HEADWISE_BETA: tl.constexpr, USE_DH0: tl.constexpr, USE_DHT: tl.constexpr\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n mask_bk = i_k * BK + tl.arange(0, BK) < K\n mask_bv = i_v * BV + tl.arange(0, BV) < V\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n else:\n p_beta = beta + i_bh * T + T - 1\n\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * K\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * V\n if IS_HEADWISE_BETA:\n p_dbeta = dbeta + (i_bh + i_k * B * H + i_v * B * H * NK) * s_vo_h + tl.arange(0, BV) + (T - 1) * V\n else:\n p_dbeta = dbeta + (i_bh + i_v * B * H) * T + T - 1\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_DHT:\n p_ht = dht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n d_h += tl.load(p_ht, mask=mask_bk[:, None] & mask_bv[None, :], other=0).to(tl.float32)\n\n for _ in range(T):\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n d_h += b_q[:, None] * b_do[None, :]\n d_k = tl.sum(d_h * (b_v * b_beta)[None, :], axis=1)\n d_v = tl.sum(d_h * b_k[:, None], axis=0)\n\n d_beta = d_v * b_v if IS_HEADWISE_BETA else tl.sum(d_v * b_v)\n d_v = d_v * b_beta\n\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n if IS_HEADWISE_BETA:\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty), mask=mask_bv)\n else:\n tl.store(p_dbeta, d_beta.to(p_dbeta.dtype.element_ty))\n\n d_h -= b_k[:, None] * d_v[None, :]\n\n p_do -= V\n p_q -= K\n p_k -= K\n p_v -= V\n p_dk -= K\n p_dv -= V\n p_dbeta -= V if IS_HEADWISE_BETA else 1\n p_beta -= V if IS_HEADWISE_BETA else 1\n\n if USE_DH0:\n p_dh0 = dh0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n tl.store(p_dh0, d_h.to(p_dh0.dtype.element_ty), mask=mask_bk[:, None] & mask_bv[None, :])\n\n tl.debug_barrier()\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n if IS_HEADWISE_BETA:\n p_beta = beta + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n else:\n p_beta = beta + i_bh * T\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[:, None]) * V + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n d_k = tl.load(p_dk, mask=mask_bk, other=0).to(tl.float32)\n d_v = tl.load(p_dv, mask=mask_bv, other=0).to(tl.float32)\n d_k -= tl.sum(d_v[None, :] * h, axis=1)\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n if IS_HEADWISE_BETA:\n b_beta = tl.load(p_beta, mask=mask_bv, other=0).to(tl.float32)\n else:\n b_beta = tl.load(p_beta).to(tl.float32)\n b_v *= b_beta\n\n h += b_k[:, None] * b_v[None, :]\n _d_q = h * b_do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += K\n p_do += V\n p_v += V\n p_dk += K\n p_dv += V\n p_dq += K\n p_beta += V if IS_HEADWISE_BETA else 1\n\nclass FusedRecurrentFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, beta, scale=None, initial_state=None, output_final_state=False):\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n assert NK == 1, \"NK > 1 is not supported yet\"\n o = q.new_empty(NK, B, H, T, V)\n\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32)\n else:\n final_state = None\n\n grid = (NV, NK, B * H)\n fused_recurrent_fwd_kernel[grid](\n q, k, v, beta, o, initial_state, final_state,\n q.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n IS_HEADWISE_BETA=beta.ndim == v.ndim,\n num_warps=num_warps,\n num_stages=num_stages,\n )\n o = o.squeeze(0)\n ctx.save_for_backward(q, k, v, beta, initial_state)\n ctx.scale = scale\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht):\n q, k, v, beta, initial_state = ctx.saved_tensors\n B, H, T, K, V = *q.shape, v.shape[-1]\n scale = ctx.scale\n BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n assert NK == 1, \"NK > 1 is not supported yet\"\n num_stages = 1\n num_warps = 2\n\n beta_vector = beta.ndim == v.ndim\n\n dq = q.new_empty(NV, B, H, T, K)\n dk = q.new_empty(NV, B, H, T, K)\n dv = q.new_empty(NK, B, H, T, V)\n if beta_vector:\n dbeta = q.new_empty(NV, NK, B, H, T, V)\n else:\n dbeta = q.new_empty(NV, B, H, T)\n grid = (NV, NK, B * H)\n\n if initial_state is not None and initial_state.requires_grad:\n dh0 = torch.empty_like(initial_state, dtype=torch.float32)\n else:\n dh0 = None\n\n fused_recurrent_bwd_kernel[grid](\n q, k, v, beta, dht, dh0, do, dq, dk, dv, dbeta, initial_state,\n q.stride(1),\n v.stride(1),\n NK, scale,\n B=B, H=H, T=T, K=K, V=V,\n BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n USE_DH0=dh0 is not None,\n USE_DHT=dht is not None,\n IS_HEADWISE_BETA=beta_vector,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n dbeta = dbeta.sum((0, 1)) if beta_vector else dbeta.sum(0)\n return dq.to(q), dk.to(k), dv.to(v), dbeta.to(beta), None, dh0, None\n\ndef fused_recurrent_delta_rule(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n beta: torch.Tensor = None,\n scale: float = None,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale is None:\n scale = q.shape[-1] ** -0.5\n else:\n assert scale > 0, \"scale must be positive\"\n if beta is None:\n beta = torch.ones_like(q[..., 0])\n o, final_state = FusedRecurrentFunction.apply(q, k, v, beta, scale, initial_state, output_final_state)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-875ce755aae344f7bb34ded8a8420640", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code defines a Triton kernel called `add_kernel` and a wrapper function `add` to perform element-wise addition of two CUDA tensors. `add_kernel` is a Triton JIT-compiled function that processes vectors in blocks of a specified size (`BLOCK_SIZE`). The kernel function takes pointers to the input vectors (`x_ptr` and `y_ptr`), an output vector pointer (`output_ptr`), and the number of elements (`n_elements`). It computes the sum of the elements from the two input vectors and stores the result in the output vector. \n\n The kernel uses `tl.program_id(axis=0)` to determine the block index for a 1D launch grid. `block_start` is calculated using the block index and `BLOCK_SIZE`, and `offsets` are computed for accessing elements within this block. A mask is created to handle out-of-bounds accesses, ensuring operations only occur on valid elements. `tl.load` and `tl.store` functions are utilized with the mask to safely load input elements and store the computed output.\n\n The `add` function serves as a wrapper, ensuring all tensors are on the CUDA device, calculating the total number of elements, and defining a grid based on the number of elements and `BLOCK_SIZE`. It then launches the kernel with the computed grid. The function returns the output tensor containing the summed elements.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(x_ptr, # *Pointer* to first input vector.\n y_ptr, # *Pointer* to second input vector.\n output_ptr, # *Pointer* to output vector.\n n_elements, # Size of the vector.\n BLOCK_SIZE: tl.constexpr, # Number of elements each program should process.\n ):\n pid = tl.program_id(axis=0) # We use a 1D launch grid so axis is 0.\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n y = tl.load(y_ptr + offsets, mask=mask)\n output = x + y\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef add(x: torch.Tensor, y: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and y.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n\n\n" + ] + }, + { + "id": "openseek-8-1ae8edcd84084ca08a648daee6ab7594", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a custom operation for RMS normalization using the Triton library, which is optimized for GPU execution. The kernel `rms_norm_kernel` is responsible for computing the RMS normalized version of input `X` with weights `W`, storing the results in `Y`. It operates on a per-row basis, controlled by the program ID `pid`, which assigns each row of `X` to a separate block. The mask ensures computations respect the dimensions of `X` even if it's smaller than the block size. Variance is computed as `var = tl.sum(x * x, axis=0) / N`, where `x` is the loaded data of the current row. `rrms` is calculated as the reciprocal of the root of `var` plus a small `eps` to ensure numerical stability. The computation `y = (x * rrms).to(Y.dtype.element_ty) * w` applies normalization and scales by weights. The result is stored in `Y`. The `RmsNorm` class encapsulates this in a PyTorch `Function`, providing a `forward` method which allocates outputs and calls the kernel, passing dimensions and strides. The method `rms_norm` wraps this `Function` for user-friendly access, taking the input tensor `x`, its normalized shape, the weight tensor, and an optional `eps` parameter, and returns the RMS normalized tensor.\n ", + "output": [ + "import torch\nimport triton\nimport math\nimport triton.language as tl\n\n@triton.jit(do_not_specialize=[\"eps\"])\ndef rms_norm_kernel(\n Y, # pointer to the output\n X, # pointer to the input\n W, # pointer to the weights\n y_stride_r,\n y_stride_c,\n x_stride_r, # how much to increase the pointer when moving by 1 row\n x_stride_c, # how much to increase the pointer when moving by 1 col\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(0)\n Y += pid * y_stride_r\n X += pid * x_stride_r\n\n mask = tl.arange(0, BLOCK_SIZE) < N\n cols = tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols * x_stride_c, mask, other=0.0).to(tl.float32)\n\n var = tl.sum(x * x, axis=0) / N\n rrms = 1 / tl.sqrt(var + eps)\n\n w = tl.load(W + tl.arange(0, BLOCK_SIZE), mask=mask, other=0.0)\n y = (x * rrms).to(Y.dtype.element_ty) * w\n tl.store(Y + cols * y_stride_c, y, mask=mask)\n\n\nclass RmsNorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, eps=1e-5):\n dim = x.ndim - len(normalized_shape)\n M = math.prod(x.shape[:dim])\n N = math.prod(normalized_shape)\n\n BLOCK_SIZE = triton.next_power_of_2(N)\n x = x.contiguous()\n weight = weight.contiguous()\n y = torch.empty_like(x)\n\n with torch.cuda.device(x.device):\n rms_norm_kernel[M,](y, x, weight, N, 1, N, 1, N, eps, BLOCK_SIZE)\n return y\n\n\ndef rms_norm(x, normalized_shape, weight, eps=1e-5):\n return RmsNorm.apply(x, normalized_shape, weight, eps)\n\n\n\n" + ] + }, + { + "id": "openseek-8-c8734c76552f40b1aeb4567f392f2063", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe code defines a Triton-based multinomial sampling procedure for sampling indices from a probability distribution described by the `scores` tensor. This distribution is organized as a 2D tensor with dimensions [batch_size, num_tokens], where each row represents a batch and each column represents a token's score. The `multinomial_sampling` function prepares for kernel execution by setting block sizes (BLOCK = 8, BLOCK_N = 128), creating a grid for parallel execution, and computing necessary strides for accessing tensor elements.\n\nThe Triton kernel `_multinomial_sampling_kernel` executes the core logic: for each batch, it initializes sampling seeds and offsets, uses them to generate random samples, and iteratively computes cumulative scores for the token scores block-wise. The kernel then determines the token indices where each random sample falls within the cumulative probability range and stores the result. This approach ensures efficient parallel computation and supports large batch sizes with many tokens.\n", + "output": [ + "from typing import Dict, Sequence, Union\nimport torch\nimport triton\nimport triton.language as tl\n\n\nKERNEL_META = dict()\n\ndef get_kernel_meta(tensor: torch.Tensor):\n \"\"\"kernel meta.\"\"\"\n return KERNEL_META\n\nTypeHintType = Union[Dict[str, type], Sequence[type], None]\n\n\n@triton.jit\ndef _multinomial_sampling_kernel(Scores, Seeds, Offsets, Indices, Outputs,\n stride_sb, stride_st, stride_ib, stride_it,\n num_batchs, num_tokens, BLOCK: tl.constexpr,\n BLOCK_N: tl.constexpr):\n \"\"\"Kernel.\"\"\"\n batch_block_id = tl.program_id(0)\n\n off = batch_block_id * BLOCK + tl.arange(0, BLOCK)\n n_off = tl.arange(0, BLOCK_N)\n\n off_mask = off < num_batchs\n seed = tl.load(Seeds + off, mask=off_mask)\n offset = tl.load(Offsets + off, mask=off_mask).to(tl.int32)\n\n samp = tl.rand(seed, offset)[:, None]\n acc = tl.zeros((BLOCK, ), dtype=tl.float32)\n output = tl.load(Indices + off * stride_ib, mask=off_mask)\n\n for b_idx in range(0, num_tokens, BLOCK_N):\n s_off = b_idx + n_off\n s_mask = off_mask[:, None] & (s_off[None, :] < num_tokens)\n scores = tl.load(Scores + off[:, None] * stride_sb +\n s_off[None, :] * stride_st,\n mask=s_mask,\n other=0.0).to(tl.float32)\n c_scores = tl.cumsum(scores, 1)\n cum_scores = acc[:, None] + c_scores\n acc += tl.max(c_scores, 1)\n\n pre_cum_scores = cum_scores - scores\n valid_mask = (samp > pre_cum_scores) & (samp <= cum_scores)\n found_mask = tl.sum(valid_mask, 1) > 0\n\n valid_pos = b_idx + tl.argmax(valid_mask.to(tl.int32), 1)\n indices = tl.load(Indices + off * stride_ib + valid_pos * stride_it,\n mask=found_mask & off_mask,\n other=-1)\n output = tl.where(found_mask, indices, output)\n\n tl.store(Outputs + off, output, mask=off_mask)\n\n\ndef multinomial_sampling(scores: torch.Tensor,\n seeds: torch.LongTensor,\n offsets: torch.LongTensor,\n indices: torch.Tensor = None):\n \"\"\"multinomial sampling.\"\"\"\n\n assert scores.dim() == 2\n batch_size, num_tokens = scores.size()\n device = scores.device\n\n if num_tokens == 1:\n return torch.zeros_like(scores, dtype=torch.long)\n\n if indices is None:\n indices = torch.arange(num_tokens, device=device)\n indices = indices.expand_as(scores)\n\n assert indices.dim() == 2\n assert indices.size() == scores.size()\n\n outputs = indices[:, 0].clone()\n\n BLOCK = 8\n BLOCK_N = 128\n\n grid = [triton.cdiv(batch_size, BLOCK)]\n kernel_meta = get_kernel_meta(scores)\n _multinomial_sampling_kernel[grid](scores,\n seeds,\n offsets,\n indices,\n outputs,\n stride_sb=scores.stride(0),\n stride_st=scores.stride(1),\n stride_ib=indices.stride(0),\n stride_it=indices.stride(1),\n num_batchs=batch_size,\n num_tokens=num_tokens,\n BLOCK=BLOCK,\n BLOCK_N=BLOCK_N,\n num_warps=8,\n **kernel_meta)\n\n return outputs\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-f1b0c6c8ba9b49cca7b1cb77c3cce205", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton operator implements a mixed sparse attention mechanism. The primary kernel function, `_triton_mixed_sparse_attn_fwd_kernel`, processes query (Q), key (K), and value (V) tensors, generating an output tensor (Out) by applying block-wise sparse attention computations. It uses block indices and counts specified by `block_count`, `block_offset`, `column_count`, and `column_index` tensors to selectively attend to certain regions of the inputs. The inputs also include `seqlens` for sequence lengths and `sm_scale` for scaling dot products. Within the kernel, the algorithm computes scaled dot-products, applies a causal mask to prevent attending to future tokens, and accumulates weighted sums to produce the output. Blocks of the input tensors are processed iteratively, using a maximum block and column count approach to manage computations. The `qk_scale` factor is applied to scale the query tensor before computing attention scores. The wrapper function `_triton_mixed_sparse_attention` prepares the input tensors and dimensions, manages the execution grid configuration, and invokes the Triton kernel.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _triton_mixed_sparse_attn_fwd_kernel(\n Q, K, V, seqlens, sm_scale,\n block_count, block_offset, column_count, column_index,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vn, stride_vk,\n stride_oz, stride_oh, stride_om, stride_ok,\n Z, H, N_CTX,\n NUM_ROWS, NNZ_S, NNZ_V,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n dtype: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n\n seqlen = tl.load(seqlens + off_hz // H)\n if start_m * BLOCK_M >= seqlen:\n return\n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n qo_offset = (off_hz // H) * stride_qz + (off_hz % H) * stride_qh\n kv_offset = (off_hz // H) * stride_kz + (off_hz % H) * stride_kh\n\n q_ptrs = Q + qo_offset + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk\n k_ptrs = K + kv_offset + offs_d[:, None] * stride_kk\n v_ptrs = V + kv_offset + offs_d[None, :] * stride_vk\n o_ptrs = Out + qo_offset + offs_m[:, None] * stride_om + offs_d[None, :] * stride_ok\n\n num_blks = tl.load(block_count + off_hz * NUM_ROWS + start_m)\n blks_ptr = block_offset + (off_hz * NUM_ROWS + start_m) * NNZ_S\n num_cols = tl.load(column_count + off_hz * NUM_ROWS + start_m)\n cols_ptr = column_index + (off_hz * NUM_ROWS + start_m) * NNZ_V\n\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n qk_scale = sm_scale * 1.44269504\n q = tl.load(q_ptrs)\n q = (q * qk_scale).to(dtype)\n\n m_mask = offs_m[:, None] < seqlen\n\n # 设置最大块数\n max_num_blks = 8 # 根据实际需求调整\n for block_index in range(max_num_blks):\n # 使用条件掩码代替 break\n cond = block_index < num_blks\n # 仅在 cond 为 True 时执行以下操作\n start_n = tl.load(blks_ptr + block_index, mask=cond)\n cols = start_n + offs_n\n n_mask = (cols < seqlen) & cond[:, None]\n k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)\n v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n causal_mask = (cols[None, :] <= offs_m[:, None])\n qk = tl.where(m_mask & causal_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(dtype), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n # 设置最大列数\n max_num_cols = 16 # 根据实际需求调整\n for start_n in range(0, max_num_cols, BLOCK_N):\n # 使用条件掩码代替 break\n cond = start_n < num_cols\n n_mask = (start_n + offs_n < num_cols) & cond\n cols = tl.load(cols_ptr + start_n + offs_n, mask=cond[:, None], other=0)\n k = tl.load(k_ptrs + cols[None, :] * stride_kn, mask=n_mask[None, :], other=0.0)\n v = tl.load(v_ptrs + cols[:, None] * stride_vn, mask=n_mask[:, None], other=0.0)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk = tl.where(m_mask & n_mask, qk, float(\"-inf\"))\n qk += tl.dot(q, k)\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n acc_scale = l_i * 0 + alpha\n acc *= acc_scale[:, None]\n acc += tl.dot(p.to(dtype), v)\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n\n acc /= l_i[:, None]\n tl.store(o_ptrs, acc.to(dtype), mask=m_mask)\n\n\ndef _triton_mixed_sparse_attention(\n q: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]\n k: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]\n v: torch.Tensor, # [BATCH, N_HEADS, N_CTX, D_HEAD]\n seqlens: torch.Tensor, # [BATCH, ]\n block_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]\n block_offset: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_S]\n column_count: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M)]\n column_index: torch.Tensor, # [BATCH, N_HEADS, cdiv(N_CTX, BLOCK_SIZE_M), NNZ_V]\n sm_scale: float,\n block_size_M: int = 64,\n block_size_N: int = 64,\n) -> torch.Tensor:\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.zeros_like(q)\n grid = (triton.cdiv(q.shape[2], block_size_M), q.shape[0] * q.shape[1], 1)\n dtype = tl.bfloat16 if q.dtype == torch.bfloat16 else tl.float16\n _triton_mixed_sparse_attn_fwd_kernel[grid](\n q, k, v, seqlens, sm_scale,\n block_count, block_offset, column_count, column_index,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n q.shape[0], q.shape[1], q.shape[2],\n block_count.shape[-1], block_offset.shape[-1], column_index.shape[-1],\n BLOCK_M=block_size_M, BLOCK_N=block_size_N,\n BLOCK_DMODEL=Lk,\n dtype=dtype,\n num_warps=4, num_stages=2,\n )\n\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-5927333b3f17454783fdf2f6b9e043f3", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements the forward and backward pass of a chunk retention operation using the Triton framework to optimize tensor computations. \n\n It consists of multiple Triton kernel definitions and their Python wrappers for executing these kernels with specified parameters. \n\n - `chunk_retention_fwd_kernel_h` is used in the forward pass to compute hidden states. It iteratively processes input tensors `k` and `v` across specified chunks (size `BT`) and updates the hidden state tensor `h`, optionally using an initial state and storing the final state. It takes into account block indices and handles boundaries with `boundary_check`.\n\n - `chunk_retention_fwd_kernel_o` computes the final output tensor by combining `q`, `k`, and `v`, applying scaling, and handling chunk-wise processing with decay factors for block-wise computation.\n\n - `chunk_retention_bwd_kernel_dh` calculates gradients for hidden states during the backward pass, processing `q`, `do`, and updating the gradient tensor `dh`.\n\n - `chunk_retention_bwd_kernel_dqkv` computes gradients for the input tensors `q`, `k`, and `v` by processing `do`, `dh`, and hidden state `h`.\n\n These kernels are executed by the Python functions:\n - `chunk_fwd_h_fn`: Sets up and calls `chunk_retention_fwd_kernel_h`.\n - `chunk_fwd_o_fn`: Sets up and calls `chunk_retention_fwd_kernel_o`.\n - `chunk_bwd_dh_fn`: Sets up and calls `chunk_retention_bwd_kernel_dh`.\n - `chunk_bwd_dqkv_fn`: Sets up and calls `chunk_retention_bwd_kernel_dqkv`.\n\n `ChunkRetentionFunction` is a custom PyTorch `autograd.Function` that utilizes these kernels for differentiable operations. It manages context saving and restores for gradient computations. The `forward` method initializes computations while the `backward` method orchestrates the gradient computations.\n\n The `chunk_retention` function offers a user-friendly interface, allowing the operation to be performed on input tensors with options for initial and final states, scaling, and checkpointing, facilitating both forward and backward passes efficiently.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4)\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_retention_fwd_kernel_h(\n k, v, h, h0, ht, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n # Triton kernel code for forward pass of chunk retention with initial and final state handling\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n if i_t == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False)\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4)\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_retention_fwd_kernel_o(\n q, k, v, h, o, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n scale, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n # Triton kernel code for forward pass of chunk retention with output scaling\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n o_i = tl.arange(0, BT)\n d_i = tl.math.exp2((o_i + 1) * b_b)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n b_s *= d_s\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4)\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_retention_bwd_kernel_dh(\n q, do, dh, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n scale, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # Triton kernel code for backward pass of chunk retention, computing gradients for hidden state\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b)\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh = d_b * b_dh + tl.dot(b_q, (b_do * d_i[:, None]).to(b_q.dtype), allow_tf32=False)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4)\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_retention_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h, s_vo_t, s_vo_d, s_h_h, s_h_t,\n scale, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n # Triton kernel code for backward pass of chunk retention, computing gradients for q, k, v\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n n_bh = tl.num_programs(2)\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n o_i = tl.arange(0, BT)\n d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n d_q = (d_q * scale).to(d_q.dtype)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s)\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False)\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n b_ds = (b_ds * d_s).to(b_q.dtype)\n b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\ndef chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state):\n B, H, T, K, V = *k.shape, v.shape[-1]\n final_state = None\n if output_final_state:\n final_state = k.new_empty(B, H, K, V, dtype=torch.float32)\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n h = k.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_fwd_kernel_h[grid](\n k, v, h, initial_state, final_state,\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state\n )\n return h, final_state\n\n\ndef chunk_fwd_o_fn(h, q, k, v, BT, scale):\n B, H, T, K, V = *k.shape, v.shape[-1]\n o = torch.empty_like(v)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NV = triton.cdiv(V, BV)\n NT = triton.cdiv(T, BT)\n grid = (NV, NT, B * H)\n chunk_retention_fwd_kernel_o[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV\n )\n return o\n\n\ndef chunk_bwd_dh_fn(do, q, k, v, BT, scale):\n B, H, T, K, V = *k.shape, v.shape[-1]\n BT = 64\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n dh = k.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_bwd_kernel_dh[grid](\n q, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT\n )\n return dh\n\n\ndef chunk_bwd_dqkv_fn(do, q, k, v, h, dh, scale):\n B, H, T, K, V = *k.shape, v.shape[-1]\n BT = 64\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT, NK = triton.cdiv(T, BT), triton.cdiv(K, BK)\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n chunk_retention_bwd_kernel_dqkv[grid](\n q, k, v, h, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT\n )\n dv = dv.sum(0)\n return dq, dk, dv\n\n\nclass ChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, initial_state, output_final_state, scale, checkpoint_level):\n BT = 64\n h, final_state = chunk_fwd_h_fn(k, v, BT, initial_state, output_final_state)\n o = chunk_fwd_o_fn(h, q, k, v, BT, scale)\n if checkpoint_level == 1:\n h = None\n ctx.save_for_backward(q, k, v, h, initial_state)\n ctx.BT, ctx.scale = BT, scale\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, d_ht=None):\n BT, scale = ctx.BT, ctx.scale\n q, k, v, h, initial_state = ctx.saved_tensors\n if h is None:\n h, _ = chunk_fwd_h_fn(k, v, BT, initial_state, False)\n dh = chunk_bwd_dh_fn(do, q, k, v, BT, scale)\n dq, dk, dv = chunk_bwd_dqkv_fn(do, q, k, v, h, dh, scale)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None, None\n\n\ndef chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False,\n scale: float = None,\n checkpoint_level: int = 1\n) -> Tuple[torch.Tensor, torch.Tensor]:\n assert checkpoint_level in [0, 1], \"checkpoint_level must be 0, 1\"\n assert q.dim() == k.dim() == v.dim() == 4, \"q, k, v must have 4 dimensions (b, h, l, d)\"\n assert q.dtype == k.dtype == v.dtype, \"q, k, v must have the same dtype\"\n if scale is None:\n scale = q.size(-1) ** -0.5\n o, final_state = ChunkRetentionFunction.apply(\n q, k, v, initial_state, output_final_state, scale, checkpoint_level)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-b6aabe7c1f14454ebb2e26d3386351d5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton kernel `square_kernel` computes the element-wise square of each element in a 2D input tensor. The kernel logic works as follows:\n - It launches one instance per row of the input, making rows independent for parallel execution. \n - The function parameters include pointers to the input and output, the row strides for input and output, the number of columns, and the `BLOCK_SIZE` representing the power of two greater than the number of columns to manage the memory layout.\n - For each row, it calculates starting pointers and loads the row into SRAM using `tl.load` with a masking operation to handle cases where `BLOCK_SIZE` exceeds the number of columns.\n - The computation of squaring each element in the row is performed simply by multiplying the row with itself, stored in `square_output`.\n - Finally, it stores the squared values back to global memory using `tl.store`.\n\n The wrapper function `square`:\n - Accepts a 2D tensor `x` and computes the number of rows and columns.\n - Determines `BLOCK_SIZE` as the smallest power of two greater than `n_cols` and sets `num_warps` for parallel execution; it adjusts this based on `BLOCK_SIZE` to optimize performance.\n - Allocates an output tensor `y` using PyTorch's `empty_like` to match the input tensor's shape.\n - Launches the `square_kernel` with a 1D grid configuration corresponding to the number of rows, ensuring each row is processed by one kernel instance.\n - Returns the squared tensor `y`.\n ", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef square_kernel(output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols, BLOCK_SIZE: tl.constexpr):\n # The rows of the softmax are independent, so we parallelize across those\n row_idx = tl.program_id(0)\n # The stride represents how much we need to increase the pointer to advance 1 row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n # The block size is the next power of two greater than n_cols, so we can fit each\n # row in a single block\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n # Load the row into SRAM, using a mask since BLOCK_SIZE may be > than n_cols\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float('inf'))\n\n square_output = row * row\n \n # Write back output to DRAM\n output_row_start_ptr = output_ptr + row_idx * output_row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n tl.store(output_ptrs, square_output, mask=col_offsets < n_cols)\n\n\ndef square(x):\n n_rows, n_cols = x.shape\n # The block size is the smallest power of two greater than the number of columns in `x`\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n # Another trick we can use is to ask the compiler to use more threads per row by\n # increasing the number of warps (`num_warps`) over which each row is distributed.\n # You will see in the next tutorial how to auto-tune this value in a more natural\n # way so you don't have to come up with manual heuristics yourself.\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n # Allocate output\n y = torch.empty_like(x)\n # Enqueue kernel. The 1D launch grid is simple: we have one kernel instance per row o\n # f the input matrix\n square_kernel[(n_rows, )](\n y,\n x,\n x.stride(0),\n y.stride(0),\n n_cols,\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return y\n\n\n\n" + ] + }, + { + "id": "openseek-8-636d07da5ea0422896c9c21ddecc5bff", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code is designed to handle complex matrix operations, primarily involving computation of attention matrices and outputs using block-based processing. Here's an intricate breakdown:\n\n The code consists of multiple Triton kernels, each designed for different segments of the operation:\n\n 1. **chunk_gla_fwd_A_kernel_intra_sub_inter**:\n - Input Parameters: `q` (query), `k` (key), `g` (gradient or gating), `A` (matrix for storing results), and several other configuration scalars.\n - Functionality: This kernel computes a sub-block of the matrix `A` using input matrices `q`, `k`, and `g`. It ensures that computations are done only for valid blocks (`i_i <= i_j` is skipped to ensure upper triangle operations). It loads blocks, scales and modifies them using exponential transformations, and accumulates results in `b_A`. The computed block is then stored back.\n\n 2. **chunk_gla_fwd_A_kernel_intra_sub_intra**:\n - Similar in setup to the first kernel but focuses on processing sub-blocks intra-thread to cover diagonal segments of `A`.\n - Implements logic to sum over the product of matrix blocks while considering exponentials and scaling for attention-like computations.\n\n 3. **chunk_gla_fwd_A_kernel_intra_sub_intra_split**:\n - Introduces splitting of computation along the `K` dimension when `K` is large. This kernel writes intermediate results into `A_intra`, allowing for memory-efficient partial computation across the `K` dimension.\n\n 4. **chunk_gla_fwd_A_kernel_intra_sub_intra_merge**:\n - Takes partially computed blocks from `A_intra` and merges them into the main matrix `A`. This step is necessary to consolidate the split results into a coherent output.\n\n 5. **chunk_gla_fwd_kernel_o**:\n - Manages the construction of output `o`, crucial for attention mechanisms.\n - It processes inputs using gated mechanisms and cumulative sums, combining them with the previous results to form the final output `o`.\n\n Wrapper Functions:\n - `chunk_fwd_intra_gated_gk_fn`: Executes the sequence of kernels responsible for calculating `A` based on the dimensions and splits specified.\n - `chunk_fwd_o_gated_gk_fn`: Employs the previously computed `A` along with other inputs to determine the final result `o`.\n\n The kernels utilize Triton's efficient block operations, leveraging parallelism in GPU architectures. This includes advanced techniques like `tl.dot` for fast matrix multiplication and `tl.load`/`tl.store` for seamless data movement across the GPU.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"BC\", \"BK\"],\n)\n@triton.jit\ndef chunk_gla_fwd_A_kernel_intra_sub_inter(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr\n):\n i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_i, i_j = i_c // NC, i_c % NC\n if i_t * BT + i_i * BC >= T:\n return\n if i_i <= i_j:\n return\n\n b_A = tl.zeros([BC, BC], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n o_k = i_k * BK + tl.arange(0, BK)\n m_k = o_k < K\n\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gk = tl.make_block_ptr(g + i_bh * s_k_h, (K, T), (1, s_k_t), (i_k * BK, i_t * BT + i_j * BC), (BK, BC), (0, 1))\n p_gn = tl.max_contiguous(tl.multiple_of(g + i_bh * s_k_h + (i_t * BT + i_i * BC) * K + o_k, BK), BK)\n b_gn = tl.load(p_gn, mask=m_k, other=0)\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = b_q * tl.exp(b_g - b_gn[None, :]) * scale\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_gk = tl.load(p_gk, boundary_check=(0, 1))\n b_kg = b_k * tl.exp(b_gn[:, None] - b_gk)\n b_A += tl.dot(b_qg, b_kg)\n\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_i * BC, i_j * BC), (BC, BC), (1, 0))\n tl.store(p_A, b_A.to(A.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"BK\", \"BT\"],\n)\n@triton.jit\ndef chunk_gla_fwd_A_kernel_intra_sub_intra(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr\n):\n i_t, i_i, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_j = i_i\n if i_t * BT + i_i * BC >= T:\n return\n\n o_i = tl.arange(0, BC)\n o_k = tl.arange(0, BK)\n o_A = i_bh * T * BT + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BT + i_j * BC\n m_k = o_k < K\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, 0), (BC, BK), (1, 0))\n\n p_k = tl.max_contiguous(tl.multiple_of(k + i_bh * s_k_h + (i_t * BT + i_j * BC) * K + o_k, BK), BK)\n p_gk = tl.max_contiguous(tl.multiple_of(g + i_bh * s_k_h + (i_t * BT + i_j * BC) * K + o_k, BK), BK)\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n for j in range(0, min(BC, T-i_t*BT-i_i*BC)):\n b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)\n b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)\n b_A = tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]), 1)\n b_A = tl.where(o_i >= j, b_A * scale, 0.)\n tl.store(A + o_A + j, b_A, mask=m_A)\n p_k += K\n p_gk += K\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"BC\", \"BK\"],\n)\n@triton.jit\ndef chunk_gla_fwd_A_kernel_intra_sub_intra_split(\n q,\n k,\n g,\n A,\n s_k_h,\n s_k_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n BK: tl.constexpr,\n NC: tl.constexpr\n):\n i_k, i_tc, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_t, i_i = i_tc // NC, i_tc % NC\n i_j = i_i\n n_bh = tl.num_programs(2)\n if i_t * BT + i_i * BC >= T:\n return\n\n o_i = tl.arange(0, BC)\n o_k = i_k * BK + tl.arange(0, BK)\n o_A = (i_bh + i_k * n_bh) * T * BC + (i_t * BT + i_i * BC + tl.arange(0, BC)) * BC\n m_k = o_k < K\n m_A = (i_t * BT + i_i * BC + tl.arange(0, BC)) < T\n\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT + i_i * BC, i_k * BK), (BC, BK), (1, 0))\n p_k = tl.max_contiguous(tl.multiple_of(k + i_bh * s_k_h + (i_t * BT + i_j * BC) * K + o_k, BK), BK)\n p_gk = tl.max_contiguous(tl.multiple_of(g + i_bh * s_k_h + (i_t * BT + i_j * BC) * K + o_k, BK), BK)\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_g = tl.load(p_g, boundary_check=(0, 1))\n for j in range(0, min(BC, T-i_t*BT-i_i*BC)):\n b_A = tl.zeros([BC], dtype=tl.float32)\n b_k = tl.load(p_k, mask=m_k, other=0).to(tl.float32)\n b_gk = tl.load(p_gk, mask=m_k, other=0).to(tl.float32)\n b_A += tl.sum(b_q * b_k[None, :] * tl.exp(b_g - b_gk[None, :]), 1)\n b_A = tl.where(o_i >= j, b_A * scale, 0.)\n tl.store(A + o_A + j, b_A, mask=m_A)\n p_k += K\n p_gk += K\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"BC\"],\n)\n@triton.jit\ndef chunk_gla_fwd_A_kernel_intra_sub_intra_merge(\n A,\n A2,\n T: tl.constexpr,\n BT: tl.constexpr,\n BC: tl.constexpr,\n NK: tl.constexpr\n):\n i_t, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n if i_t * BT + i_c * BC >= T:\n return\n n_bh = tl.num_programs(2)\n b_A = tl.zeros([BC, BC], dtype=tl.float32)\n for i_k in range(0, NK):\n p_A = tl.make_block_ptr(A + (i_bh + i_k*n_bh) * T * BC, (T, BC), (BC, 1), (i_t * BT + i_c * BC, 0), (BC, BC), (1, 0))\n b_A += tl.load(p_A, boundary_check=(0, 1))\n p_A2 = tl.make_block_ptr(A2 + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT + i_c * BC, i_c * BC), (BC, BC), (1, 0))\n tl.store(p_A2, b_A.to(A2.dtype.element_ty), boundary_check=(0, 1))\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"BK\", \"BV\", \"BT\"],\n)\n@triton.jit\ndef chunk_gla_fwd_kernel_o(\n q,\n v,\n g,\n h,\n o,\n A,\n s_k_h,\n s_k_t,\n s_v_h,\n s_v_t,\n s_h_h,\n s_h_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_g = tl.make_block_ptr(g + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_g = tl.load(p_g, boundary_check=(0, 1))\n b_qg = (b_q * tl.exp(b_g)).to(b_q.dtype)\n b_h = tl.load(p_h, boundary_check=(0, 1))\n if i_k >= 0:\n b_o += tl.dot(b_qg, b_h.to(b_qg.dtype))\n\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_A = tl.make_block_ptr(A + i_bh * T * BT, (T, BT), (BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_A = tl.load(p_A, boundary_check=(0, 1))\n b_A = tl.where(m_s, b_A, 0.).to(b_v.dtype)\n b_o += tl.dot(b_A, b_v, allow_tf32=False)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\ndef chunk_fwd_intra_gated_gk_fn(q, k, g, scale, BT):\n B, H, T, K = q.shape\n BC = 16\n NC = triton.cdiv(BT, BC)\n NT = triton.cdiv(T, BT)\n\n BK = min(64, triton.next_power_of_2(K))\n A = q.new_empty(B, H, T, BT, dtype=torch.float32)\n grid = (NT, NC * NC, B * H)\n chunk_gla_fwd_A_kernel_intra_sub_inter[grid](\n q, k, g, A,\n k.stride(1), k.stride(2),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC\n )\n grid = (NT, NC, B * H)\n if K <= 256:\n BK = triton.next_power_of_2(K)\n chunk_gla_fwd_A_kernel_intra_sub_intra[grid](\n q, k, g, A,\n k.stride(1), k.stride(2),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK\n )\n else:\n BK = 128\n NK = triton.cdiv(K, BK)\n A_intra = q.new_empty(NK, B, H, BT, BC, dtype=torch.float32)\n grid = (NK, NT * NC, B * H)\n chunk_gla_fwd_A_kernel_intra_sub_intra_split[grid](\n q, k, g, A_intra,\n k.stride(1), k.stride(2),\n scale,\n T=T, K=K, BT=BT, BC=BC, BK=BK, NC=NC\n )\n grid = (NT, NC, B * H)\n chunk_gla_fwd_A_kernel_intra_sub_intra_merge[grid](\n A_intra, A,\n T=T, BT=BT, BC=BC, NK=NK\n )\n return A\n\ndef chunk_fwd_o_gated_gk_fn(q, v, g_cumsum, A, h, BT, scale):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BK = min(32, triton.next_power_of_2(K))\n BV = min(32, triton.next_power_of_2(V))\n NV = triton.cdiv(V, BV)\n NT = triton.cdiv(T, BT)\n\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_gla_fwd_kernel_o[grid](\n q, v, g_cumsum, h, o, A,\n q.stride(1), q.stride(2),\n v.stride(1), v.stride(2),\n h.stride(1), h.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV\n )\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-fa527a1ca75c4c04b43efc8eacad70fb", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The forward function `_layer_norm_fwd_1pass_kernel` in Triton performs layer normalization on a row of a matrix with configurable settings for using residual connections, storing residual outputs, RMS normalization, and applying a bias. Inputs include tensors for data, weights, and optionally biases and residuals. Each row computes the mean and variance, normalizes the data, applies weights and biases, and stores the result. The kernel utilizes Triton's `autotune` to optimize for various configurations by selecting the number of warps. \n\n The backward function `_layer_norm_bwd_kernel` computes the gradients of input data, weights, and biases, considering residuals if used. It can also recompute the forward pass output when necessary for gradient accuracy. The function supports different configurations, such as whether to store additional residual gradients or use RMS normalization, and optimizes execution over multiple streaming multiprocessors (SMs) by distributing rows evenly. The gradients are accumulated across SMs and returned for use in parameter updates. \n\n Both functions rely on a BLOCK_N size determined by input dimensions and memory constraints, ensuring the features do not exceed a set threshold. The forward and backward functions handle edge cases like differing data types for outputs and residuals. Additionally, the functions use autotuning to find the best execution configuration for different input sizes and feature settings.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, RESIDUAL_OUT, Mean, Rstd, \n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, \n N, eps, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, \n HAS_RESIDUAL: tl.constexpr, STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, out_dtype=None, \n residual_dtype=None, is_rms_norm=False\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if residual is not None or (residual_dtype is not None and residual_dtype != x.dtype):\n residual_out = torch.empty(M, N, device=x.device, dtype=residual_dtype)\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=\"cuda\") if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=\"cuda\")\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, residual_out, \n mean, rstd, x.stride(0), y.stride(0), \n residual.stride(0) if residual is not None else 0, \n residual_out.stride(0) if residual_out is not None else 0, \n N, eps, is_rms_norm, BLOCK_N, residual is not None, \n residual_out is not None, bias is not None\n )\n return y, mean, rstd, residual_out if residual_out is not None else x\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_DRESIDUAL\", \"STORE_DRESIDUAL\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"Y\"] is not None})\n@triton.jit\ndef _layer_norm_bwd_kernel(\n X, W, B, Y, DY, DX, DW, DB, DRESIDUAL, DRESIDUAL_IN, \n Mean, Rstd, stride_x_row, stride_y_row, stride_dy_row, \n stride_dx_row, stride_dres_row, stride_dres_in_row, M, \n N, eps, rows_per_program, IS_RMS_NORM: tl.constexpr, \n BLOCK_N: tl.constexpr, HAS_DRESIDUAL: tl.constexpr, \n STORE_DRESIDUAL: tl.constexpr, HAS_BIAS: tl.constexpr, \n RECOMPUTE_OUTPUT: tl.constexpr\n):\n row_block_id = tl.program_id(0)\n row_start = row_block_id * rows_per_program\n cols = tl.arange(0, BLOCK_N)\n mask = cols < N\n X += row_start * stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += row_start * stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += row_start * stride_dres_in_row\n DY += row_start * stride_dy_row\n DX += row_start * stride_dx_row\n if RECOMPUTE_OUTPUT:\n Y += row_start * stride_y_row\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if RECOMPUTE_OUTPUT and HAS_BIAS:\n b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)\n dw = tl.zeros((BLOCK_N,), dtype=tl.float32)\n if HAS_BIAS:\n db = tl.zeros((BLOCK_N,), dtype=tl.float32)\n row_end = min((row_block_id + 1) * rows_per_program, M)\n for row in range(row_start, row_end):\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n if not IS_RMS_NORM:\n mean = tl.load(Mean + row) if Mean is not None else 0.0 # 修改此行\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n xhat = tl.where(mask, xhat, 0.0)\n if RECOMPUTE_OUTPUT:\n y = xhat * w + b if HAS_BIAS else xhat * w\n tl.store(Y + cols, y, mask=mask)\n wdy = w * dy\n dw += dy * xhat\n if HAS_BIAS:\n db += dy\n if not IS_RMS_NORM:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n else:\n c1 = tl.sum(xhat * wdy, axis=0) / N\n dx = (wdy - xhat * c1) * rstd\n if HAS_DRESIDUAL:\n dres = tl.load(DRESIDUAL + cols, mask=mask, other=0).to(tl.float32)\n dx += dres\n if STORE_DRESIDUAL:\n tl.store(DRESIDUAL_IN + cols, dx, mask=mask)\n tl.store(DX + cols, dx, mask=mask)\n X += stride_x_row\n if HAS_DRESIDUAL:\n DRESIDUAL += stride_dres_row\n if STORE_DRESIDUAL:\n DRESIDUAL_IN += stride_dres_in_row\n if RECOMPUTE_OUTPUT:\n Y += stride_y_row\n DY += stride_dy_row\n DX += stride_dx_row\n tl.store(DW + row_block_id * N + cols, dw, mask=mask)\n if HAS_BIAS:\n tl.store(DB + row_block_id * N + cols, db, mask=mask)\n\n\ndef _layer_norm_bwd(\n dy, x, weight, bias, eps, mean, rstd, dresidual=None, \n has_residual=False, is_rms_norm=False, x_dtype=None, recompute_output=False\n):\n M, N = x.shape\n assert x.stride(-1) == 1\n assert dy.stride(-1) == 1\n assert dy.shape == (M, N)\n if dresidual is not None:\n assert dresidual.stride(-1) == 1\n assert dresidual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n dx = (\n torch.empty_like(x)\n if x_dtype is None\n else torch.empty(M, N, dtype=x_dtype, device=x.device)\n )\n dresidual_in = torch.empty_like(x) if has_residual and dx.dtype != x.dtype else None\n y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count\n _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)\n _db = (\n torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)\n if bias is not None\n else None\n )\n rows_per_program = math.ceil(M / sm_count)\n grid = (sm_count,)\n with torch.cuda.device(x.device.index):\n _layer_norm_bwd_kernel[grid](\n x, weight, bias, y, dy, dx, _dw, _db, dresidual, \n dresidual_in, mean, rstd, x.stride(0), \n 0 if not recompute_output else y.stride(0), dy.stride(0), \n dx.stride(0), dresidual.stride(0) if dresidual is not None else 0, \n dresidual_in.stride(0) if dresidual_in is not None else 0, \n M, N, eps, rows_per_program, is_rms_norm, BLOCK_N, \n dresidual is not None, dresidual_in is not None, bias is not None\n )\n dw = _dw.sum(0).to(weight.dtype)\n db = _db.sum(0).to(bias.dtype) if bias is not None else None\n if has_residual and dx.dtype == x.dtype:\n dresidual_in = dx\n return (dx, dw, db, dresidual_in) if not recompute_output else (dx, dw, db, dresidual_in, y)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-54b29d0eec8f49909a260166ae32cf8f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_copy_to_kvcache_seqlen1_kernel` is defined to efficiently copy elements from the input tensors `K` (keys) and `V` (values) to the respective cache tensors `KCache` and `VCache`. It supports two types of cache layouts: one with a four-dimensional layout `[num_blocks, num_kv_heads, block_size, head_dim]`, and another with a five-dimensional layout `[num_blocks, num_kv_heads, head_dim // x, block_size, x]`.\n The kernel is parameterized by the block size, head dimensions, and stride lengths for efficient memory indexing. It uses Triton's `program_id` to obtain the current sequence and head index, determining where to store the data in the cache. The `block_tables` tensor stores mapping information of blocks for each sequence, and `context_lengths` provides the lengths of past sequences.\n The wrapper function `copy_kv_to_blocked_cache` orchestrates the setup for kernel invocation, asserting the compatibility of input shapes and deciding on the cache layout. It calculates strides and offsets based on whether the new cache layout is used, then launches the kernel with a calculated grid size based on batch size and number of heads.\n This setup ensures that during model decoding, only necessary data is efficiently copied to cache for further operations, thereby optimizing the memory bandwidth and computation required in subsequent steps.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n# supports two types of cache layouts\n# 1. [num_blocks, num_kv_heads, block_size, head_dim]\n# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x]\n@triton.jit\ndef _copy_to_kvcache_seqlen1_kernel(\n K,\n V,\n KCache,\n VCache,\n BLOCK_TABLES,\n context_lengths,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_vt,\n stride_vh,\n stride_vd,\n stride_kcb,\n stride_kch,\n stride_kcsplit_x,\n stride_kcs,\n stride_kcd,\n stride_vcb,\n stride_vch,\n stride_vcs,\n stride_vcd,\n stride_bts,\n stride_btb,\n block_size,\n HEAD_DIM: tl.constexpr,\n KCACHE_X: tl.constexpr,\n):\n cur_seq_idx = tl.program_id(0)\n cur_kv_head_idx = tl.program_id(1)\n\n past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) - 1\n last_bt_block_idx = past_kv_seq_len // block_size\n block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts\n block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)\n offsets_in_last_block = past_kv_seq_len % block_size\n\n range_x = tl.arange(0, KCACHE_X)\n offsets_dmodel_x_partition = tl.arange(0, KCACHE_X)\n\n for split_x in tl.static_range(HEAD_DIM // KCACHE_X):\n offsets_dmodel_x_partition = tl.arange(split_x * KCACHE_X, (split_x + 1) * KCACHE_X)\n offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel_x_partition * stride_kd\n k = tl.load(K + offsets_k)\n offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel_x_partition * stride_vd\n v = tl.load(V + offsets_v)\n\n offsets_kcache = (\n block_id * stride_kcb\n + cur_kv_head_idx * stride_kch\n + split_x * stride_kcsplit_x\n + offsets_in_last_block * stride_kcs\n + range_x\n )\n tl.store(KCache + offsets_kcache, k)\n offsets_vcache = (\n block_id * stride_vcb\n + cur_kv_head_idx * stride_vch\n + offsets_in_last_block * stride_vcs\n + offsets_dmodel_x_partition * stride_vcd\n )\n tl.store(VCache + offsets_vcache, v)\n return\n\n\ndef copy_kv_to_blocked_cache(\n k: torch.Tensor,\n v: torch.Tensor,\n k_cache: torch.Tensor,\n v_cache: torch.Tensor,\n kv_lengths: torch.Tensor,\n block_tables: torch.Tensor,\n use_new_kcache_layout: bool = False,\n):\n \"\"\"\n Copy keys or values to the blocked key/value cache during decoding stage.\n\n Args:\n k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys during decoding with seq len 1.\n v (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Values during decoding with seq len 1.\n k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key cache.\n v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked value cache.\n kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.\n block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.\n use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.\n \"\"\"\n k_cache_shape = k_cache.shape\n v_cache_shape = v_cache.shape\n\n if use_new_kcache_layout:\n assert (\n len(k_cache_shape) == 5\n and k_cache_shape[1] == v_cache_shape[1]\n and k_cache_shape[2] * k_cache_shape[4] == v_cache_shape[3]\n ), f\"Invalid KCache shape {k_cache_shape} and VCache shape {v_cache_shape}\"\n else:\n assert k.size(-1) == k_cache_shape[-1], \"Incompatible head dim\"\n assert (\n k_cache_shape == v_cache_shape\n ), f\"Incompatible KCache shape {k_cache_shape} and VCache shape {v_cache_shape}\"\n assert v.size(-1) == v_cache_shape[-1], \"Incompatible head dim\"\n\n k = k.squeeze(1) if k.dim() == 4 else k\n assert k.dim() == 3, f\"Incompatible k dim {k.dim()}\"\n v = v.squeeze(1) if v.dim() == 4 else v\n assert v.dim() == 3, f\"Incompatible v dim {v.dim()}\"\n\n bsz, num_kv_heads, head_dim = k.shape\n assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (\n f\"Got incompatible batch size (number of seqs):\\n\"\n f\" Past kv sequence lengths bsz {kv_lengths.shape[0]}; \"\n f\" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}\"\n )\n\n # Modify if the shape of kv cahce is changed.\n block_size = k_cache.size(-2)\n\n x = head_dim\n stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)\n if use_new_kcache_layout:\n x = k_cache.size(-1)\n stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]\n\n num_warps = 8 if head_dim > 128 else 4\n grid = (bsz, num_kv_heads)\n _copy_to_kvcache_seqlen1_kernel[grid](\n k,\n v,\n k_cache,\n v_cache,\n block_tables,\n kv_lengths,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n stride_kcsplit_x,\n stride_kcs,\n stride_kcd,\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n block_tables.stride(0),\n block_tables.stride(1),\n block_size,\n HEAD_DIM=head_dim,\n KCACHE_X=x,\n num_warps=num_warps,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-27d040fbdba54413b49ba6de2cc559a5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `chunk_simple_gla_fwd_kernel_o` function, decorated with `@triton.autotune`, implements a Triton kernel for efficient computation of a forward pass in a chunked matrix-multiplication setting. It receives tensors `q`, `k`, `v`, `h`, and `g`, alongside tensor strides (`s_k_h`, `s_k_t`, etc.), scaling factor `scale`, and chunk dimensions (`BT`, `BK`, `BV`). The kernel operates using Triton's parallel execution, where indices are computed using `tl.program_id` for efficient matrix partitioning. Block pointers (e.g., `p_q`, `p_k`) load sub-blocks of `q`, `k`, `h` into registers. It computes partial outputs `b_o` and `b_s` using dot products, adjusts them with exponentials, and conditions based on the mask `m_s`. The result is stored in the output tensor `o` using `tl.store`. The `chunk_fwd_o_fn` function handles higher-level orchestration, preparing grid dimensions, calculating chunk sizes (`BK`, `BV`), and calling the kernel with the pre-computed grid and problem parameters.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=4),\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_simple_gla_fwd_kernel_o(\n q,\n k,\n v,\n h,\n g,\n o,\n s_k_h,\n s_k_t,\n s_v_h,\n s_v_t,\n s_h_h,\n s_h_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n # Kernel implementation\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (K, T), (1, s_k_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_g = tl.load(p_g, boundary_check=(0,))\n b_o = b_o * tl.exp(b_g)[:, None]\n b_s = b_s * tl.exp(b_g[:, None] - b_g[None, :])\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\ndef chunk_fwd_o_fn(h, q, k, v, g, BT, scale):\n # Kernel call\n B, H, T, K, V = *k.shape, v.shape[-1]\n o = torch.empty_like(v)\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NV = triton.cdiv(V, BV)\n NT = triton.cdiv(T, BT)\n grid = (NV, NT, B * H)\n chunk_simple_gla_fwd_kernel_o[grid](\n q, k, v, h, g, o,\n q.stride(1), q.stride(2),\n v.stride(1), v.stride(2),\n h.stride(1), h.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV\n )\n return o\n\n\n\n" + ] + }, + { + "id": "openseek-8-211463d6ac074a679c19c3fb16e6499e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton implementation provides a specialized routine for performing matrix multiplication with quantized int8 matrices to leverage efficient GPU computation. \n\n The `quantize_int8_perrow_kernel` is a Triton JIT-compiled kernel, decorated with multiple `triton.autotune` configurations to optimize execution parameters such as `num_stages` and `num_warps`. It is designed to quantize a floating-point matrix to int8 per row, by first calculating the maximum absolute value per row and using it to scale the values to fit within the int8 range [-127, 127]. The kernel operates on a matrix `fpa` of dimensions MxK, loading blocks of data to compute maximums and perform quantization, and stores the resulting int8 values in matrix `a` while outputting the scale per row to `as`.\n\n The `quantize_int8_perrow` function serves as a Python wrapper for this kernel. It allocates output tensors for the quantized result and the scale factors, determines grid size based on input matrix dimensions, and launches the Triton kernel.\n\n The `matmul_kernel` is another Triton JIT kernel that handles matrix multiplication C = A x B where A and B are quantized matrices, accompanied by scaling factors `as_ptr` and `bs_ptr`. The kernel iteratively loads sub-blocks of A and B, accumulates the results in `accumulator`, and applies scaling factors before storing the final result in matrix C. The kernel supports various configurations, including SPLIT_K for advanced memory management and optimization.\n\n The `matmul_quantize_int8` function first applies row-wise quantization on a floating-point matrix `fpa` and then performs multiplication using `matmul_int8`.\n\n The `matmul_int8` method calls `matmul_kernel` with appropriate parameters, allocating and preparing output storage if not provided. It calculates grid size using configurable meta-parameters `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `SPLIT_K`.\n\n Finally, the `quantize_int8` function handles the quantization of matrices, computing scale factors based on maximum absolute values along the specified axis, and transforming the weight matrix to int8, preparing it for efficient matrix operations. Outputs from these functions include the quantized matrix and corresponding scaling factors, essential for accurate result reconstruction after computation.", + "output": [ + "import time\nimport torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=4),\n triton.Config({}, num_stages=2, num_warps=2),\n triton.Config({}, num_stages=2, num_warps=1),\n ],\n key=['K'],\n)\n@triton.jit\ndef quantize_int8_perrow_kernel(\n fpa_ptr, a_ptr, as_ptr,\n M, K, \n stride_fpam, stride_fpak,\n stride_am, stride_ak,\n stride_asm,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n\n fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n a_max = tl.zeros((BLOCK_SIZE_M,), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n a_max = tl.maximum(a_max, tl.max(tl.abs(fpa), axis=1))\n fpa_ptrs += BLOCK_SIZE_K * stride_fpak\n a_scale = (a_max / 127.)\n fpa_ptrs = fpa_ptr + offs_am[:, None] * stride_fpam + offs_k[None, :] * stride_fpak\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n fpa = tl.load(fpa_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n inta = (fpa / a_scale[:, None]).to(tl.int8)\n tl.store(a_ptrs, inta, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K)\n fpa_ptrs += BLOCK_SIZE_K * stride_fpak\n a_ptrs += BLOCK_SIZE_K * stride_ak\n as_offs = pid_m * BLOCK_SIZE_M * stride_asm + tl.arange(0, BLOCK_SIZE_M)\n tl.store(as_ptr + as_offs, a_scale)\n\n\ndef quantize_int8_perrow(fpa):\n a = torch.empty(fpa.shape, device=fpa.device, dtype=torch.int8)\n a_scale = torch.empty(fpa.shape[0], device=fpa.device, dtype=torch.float16)\n M, K = fpa.shape\n BLOCK_SIZE_M = 1\n BLOCK_SIZE_K = triton.next_power_of_2(K)\n grid = (M // BLOCK_SIZE_M,)\n quantize_int8_perrow_kernel[grid](\n fpa, a, a_scale,\n M, K,\n fpa.stride(0), fpa.stride(1),\n a.stride(0), a.stride(1),\n a_scale.stride(0),\n BLOCK_SIZE_M, BLOCK_SIZE_K,\n )\n return a, a_scale\n\n\n@triton.autotune(\n configs=[\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8),\n\t triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 16}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 16}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n ],\n key=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n # Pointers to matrices\n a_ptr, as_ptr, b_ptr, bs_ptr, c_ptr,\n # Matrix dimensions\n M, N, K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n # by to get the element one row down (A has M rows).\n stride_am, stride_ak,\n stride_asm,\n stride_bk, stride_bn,\n stride_bsn,\n stride_cm, stride_cn,\n # Meta-parameters\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr, \n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n # -----------------------------------------------------------\n # Map program ids `pid` to the block of C it should compute.\n # This is done in a grouped ordering to promote L2 data reuse.\n # See above `L2 Cache Optimizations` section for details.\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n # ----------------------------------------------------------\n # Create pointers for the first blocks of A and B.\n # We will advance this pointer as we move in the K direction\n # and accumulate\n # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers\n # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers\n # See above `Pointer Arithmetics` section for details\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n as_ptrs = as_ptr + offs_am * stride_asm\n bs_ptrs = bs_ptr + offs_bn * stride_bsn\n a_scale = tl.load(as_ptrs, mask=offs_am < M, other=0.0)\n b_scale = tl.load(bs_ptrs, mask=offs_bn < N, other=0.0)\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n # Load the next block of A and B, generate a mask by checking the K dimension.\n # If it is out of bounds, set it to 0.\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K * SPLIT_K, other=0.0)\n # We accumulate along the K dimension.\n accumulator += tl.dot(a, b)\n # Advance the ptrs to the next K block.\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_bk\n # You can fuse arbitrary activation functions here\n # while the accumulator is still in FP32!\n c = (accumulator.to(tl.float32) * a_scale[:, None] * b_scale[None, :]).to(tl.float16)\n # -----------------------------------------------------------\n # Write back the block of the output matrix C with masks.\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_quantize_int8(fpa, b, b_scale, out=None):\n a, a_scale = quantize_int8_perrow(fpa)\n # a, a_scale = quantize_int8(fpa, axis=1)\n return matmul_int8(a, a_scale, b, b_scale, out)\n\n\ndef matmul_int8(a, a_scale, b, b_scale, out=None):\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n # Allocates output.\n if out == None:\n c = torch.zeros((M, N), device=a.device, dtype=torch.float16)\n else:\n c = out.fill_(0.)\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n a, a_scale, b, b_scale, c,\n M, N, K,\n a.stride(0), a.stride(1),\n a_scale.stride(0),\n b.stride(0), b.stride(1),\n b_scale.stride(0),\n c.stride(0), c.stride(1),\n )\n return c\n\n\ndef quantize_int8(weight, axis=0):\n # Weight shape: [H1, H2]\n # Scale shape: [H2]\n scale = weight.abs().amax(axis, keepdim=True) / 127.\n weight = (weight / scale).to(torch.int8)\n # col major will accelerate i8xi8 kernel.\n if axis == 0:\n weight = weight.t().contiguous().t()\n scale = scale.squeeze(axis)\n return weight, scale, None\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-090c682c62bf41189aa243b5e51758c7", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel _seeded_dropout and the Python function seeded_dropout implement a seeded dropout operation on a given tensor x. The purpose of this implementation is to simulate the dropout operation, which is typically used to prevent overfitting in neural networks by randomly setting elements of the input tensor to zero based on a probability p. \n\n The function _seeded_dropout is a Triton kernel, meaning it's executed in parallel on a GPU. It takes several arguments: x_ptr and output_ptr, which are pointers to the input and output data; n_elements, the number of elements to process; p, the probability of setting a value to zero; seed, a fixed seed for random number generation to ensure reproducibility; and BLOCK_SIZE, a constant expression defining the size of blocks of data processed by each kernel instance. The kernel calculates offsets for each block based on the program ID, loads data from the input pointer using these offsets, and generates a random number per element using tl.rand. A mask is created where random numbers greater than the probability p indicate elements that are kept. These elements are scaled by 1/(1-p) to keep the output mean consistent with the input. The result is stored in the output pointer, ensuring values outside the bounds of n_elements are masked.\n\n The function seeded_dropout is a wrapper around the Triton kernel. It initializes an output tensor of the same shape and type as the input tensor x and ensures the input tensor is contiguous. It calculates the grid size needed to cover all input elements given the specified BLOCK_SIZE, which determines how many elements each kernel instance will process. The Triton kernel _seeded_dropout is then launched with this grid configuration. Finally, the output tensor is returned, containing the dropout-applied version of the input tensor.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _seeded_dropout(\n x_ptr,\n output_ptr,\n n_elements,\n p,\n seed,\n BLOCK_SIZE: tl.constexpr,\n):\n # compute memory offsets of elements handled by this instance\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n # load data from x\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n # randomly prune it\n random = tl.rand(seed, offsets)\n x_keep = random > p\n # write-back\n output = tl.where(x_keep, x / (1 - p), 0.0)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n\ndef seeded_dropout(x, p, seed):\n output = torch.empty_like(x)\n assert x.is_contiguous()\n n_elements = x.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )\n _seeded_dropout[grid](x, output, n_elements, p, seed, BLOCK_SIZE=1024)\n return output\n\n\n\n" + ] + }, + { + "id": "openseek-8-1daa839a2ecb49699f23e2efc6c64592", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_int8_matmul_rowwise_dequantize` kernel in Triton performs matrix multiplication with int8 inputs, applying row-wise dequantization, and supports optional bias addition. The kernel uses a block-wise computation strategy defined by the parameters BLOCK_M, BLOCK_N, and BLOCK_K. SPLIT_K enables further parallelization across the K dimension by dividing the work into segments. The grid size for launching the kernel is determined by the lambda function `grid`, which calculates the number of blocks required based on input dimensions and block sizes. The kernel first computes a per-block matrix product using the `tl.dot` operation. It accumulates results in an `acc` tensor of type tl.int32, which is then scaled by factors loaded from `state_x_ptr` and `state_w_ptr`. If bias is present, it's added post-dequantization. The final result is stored or atomically added to matrix C based on the SPLIT_K value. The outer function `int8_matmul_rowwise_dequantize` sets up the parameters and input tensors, handling strides and data contiguity, and invokes the Triton kernel with optimized configurations.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n\ndef get_configs_io_bound():\n def init_to_zero(name):\n return lambda nargs: nargs[name].zero_()\n\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_k in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n triton.Config(\n {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n, \"BLOCK_K\": block_k, \"SPLIT_K\": 1},\n num_stages=num_stages,\n num_warps=num_warps,\n ),\n )\n # split_k\n for split_k in [2, 4, 8, 16]:\n configs.append(\n triton.Config(\n {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n, \"BLOCK_K\": block_k, \"SPLIT_K\": split_k},\n num_stages=num_stages,\n num_warps=num_warps,\n pre_hook=init_to_zero(\"C\"),\n ),\n )\n return configs\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n *get_configs_io_bound(),\n ],\n key=[\"M\", \"N\", \"K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"EVEN_K\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n },\n)\n@triton.jit\ndef _int8_matmul_rowwise_dequantize(\n A,\n B,\n C,\n bias,\n state_x_ptr,\n state_w_ptr,\n M,\n N,\n K,\n divfactor,\n has_bias: tl.constexpr,\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n SPLIT_K: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr,\n):\n pid = tl.program_id(0)\n pid_z = tl.program_id(1)\n grid_m = tl.cdiv(M, BLOCK_M)\n grid_n = tl.cdiv(N, BLOCK_N)\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = pid_z * BLOCK_K + tl.arange(0, BLOCK_K)\n A = A + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = B + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n w_factor = tl.load(state_w_ptr + rbn)[None, :]\n x_factor = tl.load(state_x_ptr + ram)[:, None]\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, tl.cdiv(K, BLOCK_K * SPLIT_K)):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n k_remaining = K - k * (BLOCK_K * SPLIT_K)\n a = tl.load(A, mask=rk[None, :] < k_remaining, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k_remaining, other=0.0)\n acc += tl.dot(a, b)\n A += BLOCK_K * SPLIT_K * stride_ak\n B += BLOCK_K * SPLIT_K * stride_bk\n acc = w_factor * (x_factor * (acc * divfactor))\n acc = acc.to(C.dtype.element_ty)\n if has_bias:\n bias = tl.load(bias + rn).to(C.dtype.element_ty)\n acc = acc + bias[None, :]\n C = C + (rm[:, None] * stride_cm + rn[None, :] * stride_cn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n if SPLIT_K == 1:\n tl.store(C, acc, mask=mask)\n else:\n tl.atomic_add(C, acc, mask=mask)\n\ndef int8_matmul_rowwise_dequantize(a, b, state_x, state_w, bias):\n divfactor = 1.0 / (127.0 * 127.0)\n has_bias = 0 if bias is None else 1\n device = a.device\n if a.stride(0) > 1 and a.stride(1) > 1:\n a = a.contiguous()\n if b.stride(0) > 1 and b.stride(1) > 1:\n b = b.contiguous()\n assert a.shape[1] == b.shape[0], \"incompatible dimensions\"\n M, K = a.shape\n _, N = b.shape\n c = torch.empty((M, N), device=device, dtype=torch.float16)\n ACC_TYPE = tl.float32\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]), META[\"SPLIT_K\"])\n _int8_matmul_rowwise_dequantize[grid](\n a,\n b,\n c,\n bias,\n state_x,\n state_w,\n M,\n N,\n K,\n divfactor,\n has_bias,\n a.stride(0),\n a.stride(1),\n b.stride(0),\n b.stride(1),\n c.stride(0),\n c.stride(1),\n GROUP_M=8,\n ACC_TYPE=ACC_TYPE,\n )\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-169966033e284f21b14eeb9da064f5c0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton function `_quantize_global_transpose` performs a global quantization and transposition of a matrix `A`. It operates on a block of size `BLOCK_M` by `BLOCK_N` and partitions the matrix into groups as determined by `GROUP_M`. The kernel computes the grid dimensions needed to cover the matrix based on `M` and `N`, which represent the dimensions of `A`. The `program_id` is split into group indices `pid_m` and `pid_n` to navigate over blocks. `A` is accessed using pointers adjusted with `stride_am` and `stride_an`, and the quantization is achieved by multiplying each element by `absmax_inv` and scaling to int8 range. The result is stored in matrix `B`, transposing the indices using `stride_bm` and `stride_bn`. The `quantize_global_transpose` function facilitates the kernel launch, preparing the input by calculating `absmax` and its reciprocal, then initializes the output matrix. It defines a lambda to configure the launch grid based on the matrix dimensions and block sizes.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# global quantize and transpose\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"GROUP_M\": 8}, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"GROUP_M\": 8}, num_warps=4),\n # ...\n ],\n key=[\"M\", \"N\"],\n)\n@triton.jit\ndef _quantize_global_transpose(\n A,\n absmax_inv_ptr,\n B,\n stride_am,\n stride_an,\n stride_bn,\n stride_bm,\n M,\n N,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n GROUP_M: tl.constexpr,\n):\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // group_size\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n A = A + (rm[:, None] * stride_am + rn[None, :] * stride_an)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n a = tl.load(A, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n\n # rematerialize to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n B = B + (rm[:, None] * stride_bm + rn[None, :] * stride_bn)\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n\n output = tl.extra.cuda.libdevice.llrint(127.0 * (a * absmax_inv))\n\n tl.store(B, output, mask=mask)\n\ndef quantize_global_transpose(input):\n absmax = input.abs().max().unsqueeze(0)\n absmax_inv = 1.0 / absmax\n M, N = input.shape\n out = torch.empty(N, M, device=\"cuda\", dtype=torch.int8)\n\n assert out.size(0) == N and out.size(1) == M\n assert input.stride(0) == 1 or input.stride(1) == 1\n assert out.stride(0) == 1 or out.stride(1) == 1\n\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),)\n _quantize_global_transpose[grid](\n input,\n absmax_inv,\n out,\n input.stride(0),\n input.stride(1),\n out.stride(0),\n out.stride(1),\n M,\n N,\n )\n return out, absmax\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-d5f12cab10a548ea8e9db82a5fc60cfa", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel function `_fwd_kernel_token_att2` is designed for performing forward attention operations, similar to those found in Transformer models. It processes input tensors `Prob`, `V`, and `Req_to_tokens`, and computes an output tensor `Out`. \n\n- `Prob` represents the attention probabilities for each token.\n- `V` is the tensor of value vectors.\n- `Out` is the resulting tensor where the weighted sum of `V` is stored.\n- `Req_to_tokens` maps requests to token indices.\n- `B_req_idx`, `B_Start_Loc`, and `B_Seqlen` are batch-specific indices and sequence length data.\n\nThe kernel operates on blocks of data specified by `BLOCK_DMODEL` and `BLOCK_N`, which define the block size for model dimensions and tokens, respectively. Within the kernel:\n\n1. The `cur_batch` and `cur_head` are determined by the program's grid ID.\n2. It calculates offsets for accessing portions of the `Prob`, `V`, and `Req_to_tokens` tensors.\n3. For each block of tokens, it fetches the corresponding probabilities and values, computes their weighted sum, and accumulates it in `acc`.\n4. Finally, `acc` is stored back into the `Out` tensor at the appropriate location.\n\nThe function `token_att_fwd2` sets up the grid size based on batch and head dimensions and computes the `kv_group_num` which defines the number of key-value groups. It then launches the `_fwd_kernel_token_att2` with all the necessary tensor strides and configurations, such as `num_warps` and `num_stages`, to control the parallel execution environment.\n", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_token_att2(\n Prob,\n V,\n Out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n stride_req_to_tokens_b,\n stride_req_to_tokens_s,\n stride_ph,\n stride_pbs,\n stride_vbs,\n stride_vh,\n stride_vd,\n stride_obs,\n stride_oh,\n stride_od,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_kv_head = cur_head // kv_group_num\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_index = 0\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n cur_batch_req_idx = tl.load(B_req_idx + cur_batch)\n\n v_loc_off = cur_batch_req_idx * stride_req_to_tokens_b + (cur_batch_start_index + offs_n) * stride_req_to_tokens_s\n p_offs = cur_head * stride_ph + (cur_batch_in_all_start_index + offs_n) * stride_pbs\n v_offs = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd\n\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n p_value = tl.load(Prob + p_offs + start_n, mask=(start_n + offs_n) < cur_batch_seq_len, other=0.0)\n v_loc = tl.load(\n Req_to_tokens + v_loc_off + start_n * stride_req_to_tokens_s,\n mask=(start_n + offs_n) < cur_batch_seq_len,\n other=0.0,\n )\n v_value = tl.load(\n V + v_offs + v_loc[:, None] * stride_vbs, mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0\n )\n acc += tl.sum(p_value[:, None] * v_value, 0)\n\n acc = acc.to(Out.dtype.element_ty)\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_att_fwd2(prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen):\n BLOCK = 128\n batch, head = B_req_idx.shape[0], prob.shape[0]\n grid = (batch, head)\n num_warps = 4\n dim = v.shape[-1]\n\n kv_group_num = prob.shape[0] // v.shape[1]\n\n _fwd_kernel_token_att2[grid](\n prob,\n v,\n out,\n Req_to_tokens,\n B_req_idx,\n B_Start_Loc,\n B_Seqlen,\n Req_to_tokens.stride(0),\n Req_to_tokens.stride(1),\n prob.stride(0),\n prob.stride(1),\n v.stride(0),\n v.stride(1),\n v.stride(2),\n out.stride(0),\n out.stride(1),\n out.stride(2),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=dim,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n" + ] + }, + { + "id": "openseek-8-a974c6b18f3c4328bf75c9fedcff8c9f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-based implementation provides a high-performance LayerNorm operation optimized for execution on GPUs. It involves three specialized kernels: \n `_layer_norm_fwd_fused` performs the forward pass by normalizing input tensor `X` over its last dimension using batch statistics (mean and variance). It applies learned scale (`W`) and shift (`B`) parameters, storing the normalized result in output `Y`. The kernel uses block-wise reduction for efficient computation, storing per-row mean and inverse standard deviation (`Rstd`), which are later used for the backward pass.\n `_layer_norm_bwd_dx_fused` computes the gradient of inputs (`DX`) using the output gradients (`DY`). It adjusts for learned parameters (`W`) and uses saved batch statistics (`Mean`, `Rstd`). The kernel handles partial reduction of gradients for weights (`DW`) and biases (`DB`) across a workgroup, ensuring synchronization via locks during accumulation.\n `_layer_norm_bwd_dwdb` finalizes the backward pass by aggregating partial gradients for weights and biases collected across workgroups into final gradients (`FINAL_DW` and `FINAL_DB`). This kernel completes the distributed reduction using efficient memory operations.\n The `LayerNorm` class leverages these kernels, defining a custom PyTorch function. During the forward pass, it reshapes inputs and prepares necessary buffers, executing `_layer_norm_fwd_fused`. In the backward pass, `_layer_norm_bwd_dx_fused` and `_layer_norm_bwd_dwdb` are called sequentially, computing the necessary gradients with respect to inputs, weights, and biases. Parameters like `BLOCK_SIZE`, `GROUP_SIZE_M`, and kernel launch configurations (number of warps) are dynamically determined based on input dimensions to optimize performance.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _layer_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n B, # pointer to the biases\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n mean = 0\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Mean + row, mean)\n tl.store(Rstd + row, rstd)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n b = tl.load(B + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w + b\n tl.store(Y + cols, y, mask=mask)\n\n@triton.jit\ndef _layer_norm_bwd_dx_fused(DX, # pointer to the input gradient\n DY, # pointer to the output gradient\n DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n X, # pointer to the input\n W, # pointer to the weights\n Mean, # pointer to the mean\n Rstd, # pointer to the 1/std\n Lock, # pointer to the lock\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n GROUP_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n row = tl.program_id(0)\n cols = tl.arange(0, BLOCK_SIZE_N)\n mask = cols < N\n X += row * stride\n DY += row * stride\n DX += row * stride\n lock_id = row % GROUP_SIZE_M\n Lock += lock_id\n Count = Lock + GROUP_SIZE_M\n DW = DW + lock_id * N + cols\n DB = DB + lock_id * N + cols\n x = tl.load(X + cols, mask=mask, other=0).to(tl.float32)\n dy = tl.load(DY + cols, mask=mask, other=0).to(tl.float32)\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n mean = tl.load(Mean + row)\n rstd = tl.load(Rstd + row)\n xhat = (x - mean) * rstd\n wdy = w * dy\n xhat = tl.where(mask, xhat, 0.)\n wdy = tl.where(mask, wdy, 0.)\n c1 = tl.sum(xhat * wdy, axis=0) / N\n c2 = tl.sum(wdy, axis=0) / N\n dx = (wdy - (xhat * c1 + c2)) * rstd\n tl.store(DX + cols, dx, mask=mask)\n partial_dw = (dy * xhat).to(w.dtype)\n partial_db = (dy).to(w.dtype)\n while tl.atomic_cas(Lock, 0, 1) == 1:\n pass\n count = tl.load(Count)\n if count == 0:\n tl.atomic_xchg(Count, 1)\n else:\n partial_dw += tl.load(DW, mask=mask)\n partial_db += tl.load(DB, mask=mask)\n tl.store(DW, partial_dw, mask=mask)\n tl.store(DB, partial_db, mask=mask)\n tl.atomic_xchg(Lock, 0)\n\n@triton.jit\ndef _layer_norm_bwd_dwdb(DW, # pointer to the partial sum of weights gradient\n DB, # pointer to the partial sum of biases gradient\n FINAL_DW, # pointer to the weights gradient\n FINAL_DB, # pointer to the biases gradient\n M, # GROUP_SIZE_M\n N, # number of columns\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n pid = tl.program_id(0)\n cols = pid * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n dw = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n db = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for i in range(0, M, BLOCK_SIZE_M):\n rows = i + tl.arange(0, BLOCK_SIZE_M)\n mask = (rows[:, None] < M) & (cols[None, :] < N)\n offs = rows[:, None] * N + cols[None, :]\n dw += tl.load(DW + offs, mask=mask, other=0.)\n db += tl.load(DB + offs, mask=mask, other=0.)\n sum_dw = tl.sum(dw, axis=0)\n sum_db = tl.sum(db, axis=0)\n tl.store(FINAL_DW + cols, sum_dw, mask=cols < N)\n tl.store(FINAL_DB + cols, sum_db, mask=cols < N)\n\nclass LayerNorm(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, normalized_shape, weight, bias, eps):\n y = torch.empty_like(x)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n mean = torch.empty((M, ), dtype=torch.float32, device=x.device)\n rstd = torch.empty((M, ), dtype=torch.float32, device=x.device)\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n _layer_norm_fwd_fused[(M, )]( #\n x_arg, y, weight, bias, mean, rstd, #\n x_arg.stride(0), N, eps, #\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps, num_ctas=1)\n ctx.save_for_backward(x, weight, bias, mean, rstd)\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.eps = eps\n return y\n\n @staticmethod\n def backward(ctx, dy):\n x, w, b, m, v = ctx.saved_tensors\n N = w.shape[0]\n GROUP_SIZE_M = 64\n if N <= 8192: GROUP_SIZE_M = 96\n if N <= 4096: GROUP_SIZE_M = 128\n if N <= 1024: GROUP_SIZE_M = 256\n locks = torch.zeros(2 * GROUP_SIZE_M, dtype=torch.int32, device=w.device)\n _dw = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n _db = torch.zeros((GROUP_SIZE_M, N), dtype=x.dtype, device=w.device)\n dw = torch.empty((N, ), dtype=w.dtype, device=w.device)\n db = torch.empty((N, ), dtype=w.dtype, device=w.device)\n dx = torch.empty_like(dy)\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n _layer_norm_bwd_dx_fused[(M, )]( #\n dx, dy, _dw, _db, x, w, m, v, locks, #\n x_arg.stride(0), N, #\n BLOCK_SIZE_N=ctx.BLOCK_SIZE, #\n GROUP_SIZE_M=GROUP_SIZE_M, #\n num_warps=ctx.num_warps)\n grid = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]\n _layer_norm_bwd_dwdb[grid](\n _dw, _db, dw, db, min(GROUP_SIZE_M, M), N, #\n BLOCK_SIZE_M=32, #\n BLOCK_SIZE_N=128, num_ctas=1)\n return dx, None, dw, db, None\n\nlayer_norm = LayerNorm.apply\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-06c47a5284194a6b87f2b5b5c06bf638", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The kernel 'load_reduce_kernel' is implemented using the Triton library to perform an element-wise reduction operation. It reads from an input matrix 'x_ptr', computes the maximum value across each row (axis=1), and writes the result to an output vector 'y_ptr'. The function accepts several parameters: 'x_ptr' and 'y_ptr' are pointers to the input matrix and output vector, respectively; 'stride_xm' and 'stride_xn' define the leading and secondary dimension strides for accessing the matrix; 'stride_y' is the stride for the output vector. 'BLOCK_M' and 'BLOCK_N' are compile-time constants defining the block dimensions. The kernel creates a block pointer to manage memory access efficiently, loads a block of data, computes the row-wise maxima using tl.max, and stores the output using tl.store. The 'load_reduce' function tests this kernel by generating a random matrix 'x' and an empty output vector 'y', invoking the kernel with these parameters, and comparing the result against PyTorch's max function for correctness using 'assert_close'.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nfrom torch.testing import assert_close\n\n\n@triton.jit\ndef load_reduce_kernel(\n x_ptr, # pointer to the input matrix\n y_ptr, # pointer to the output vector\n stride_xm, # stride of matrix x in leading dimension\n stride_xn, # stride of matrix x in the second dimension\n stride_y, # stride of output vector y\n BLOCK_M: tl.constexpr, # block size in leading dimension\n BLOCK_N: tl.constexpr, # block size in second dimension\n):\n x_ptr = tl.make_block_ptr(\n base=x_ptr, shape=(BLOCK_M, BLOCK_N), strides=(stride_xm, stride_xn),\n offsets=(0, 0), block_shape=(BLOCK_M, BLOCK_N), order=(1, 0)\n )\n x = tl.load(x_ptr)\n y = tl.max(x, axis=1)\n tl.store(y_ptr + tl.arange(0, BLOCK_M), y)\n\n# Test function for load_reduce_kernel\ndef load_reduce(BLOCK_M, BLOCK_N, dtype_str):\n dtype_mapping = {\n 'float16': torch.float16,\n 'float32': torch.float32,\n }\n dtype = dtype_mapping[dtype_str]\n x = torch.randn((BLOCK_M, BLOCK_N), device='cuda', dtype=dtype)\n y = torch.empty((BLOCK_M,), device='cuda', dtype=dtype)\n\n load_reduce_kernel[(1,)](x, y, x.stride(0), x.stride(1), y.stride(0), BLOCK_M, BLOCK_N)\n\n golden = x.max(dim=1)[0]\n torch.set_printoptions(profile='full')\n assert_close(y, golden, rtol=1e-2, atol=1e-3, check_dtype=False)\n\n\n\n" + ] + }, + { + "id": "openseek-8-b1d06197aa744137a68431f2bfba64c3", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_layer_norm_fwd_kernel` function in Triton computes the forward pass of a layer normalization operation on a 3D input tensor `X`. The kernel normalizes the elements of `X` along the last dimension for each 2D slice, using a specified weight matrix `W` for scaling. \n\n The kernel operates with blocks of size `BLOCK_SIZE` and uses loop iterations to handle each segment along the normalization axis. The mean and variance are computed first, using floating point precision for stability, and then applied to obtain the normalized output `Y`.\n\n Input `X` is a pointer to a 3D tensor with shape `(D1, D2, N)`, and `W` is a pointer to weights with shape `(D2, N)`. `Y` is the output tensor of the same shape as `X`. The kernel uses provided strides to traverse the input tensor efficiently.\n\n The function `layernorm_forward` prepares the grid dimensions for kernel execution, computes the memory strides, and ensures the dimensions are compatible between `X` and `W`. It then invokes the Triton kernel `_layer_norm_fwd_kernel` with computed parameters such as grid size, and block size.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# LayerNorm adapted from triton tutorial\n@triton.jit\ndef _layer_norm_fwd_kernel(\n X, # pointer to the input\n W, # pointer to the weights\n Y, # output pointer\n stride_x_N,\n stride_x_hn,\n stride_x_hd,\n stride_y_N,\n stride_y_hn,\n stride_y_hd,\n stride_w_hn,\n stride_w_hd,\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n Seq = tl.program_id(0)\n H = tl.program_id(1)\n\n X += Seq * stride_x_N + H * stride_x_hn\n Y += Seq * stride_y_N + H * stride_y_hn\n W += H * stride_w_hn\n\n _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n _mean += a\n mean = tl.sum(_mean, axis=0) / N\n\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x - mean, 0.0)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = (x - mean) * rstd\n y = x_hat * w\n\n tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask)\n\n\ndef layernorm_forward(\n X, # pointer to the input\n W, # pointer to the weights\n eps, # epsilon to avoid division by zero\n):\n assert len(X.shape) == 3\n assert len(W.shape) == 2\n assert X.shape[-1] == W.shape[-1]\n assert X.shape[-2] == W.shape[-2]\n\n y = torch.empty_like(X)\n\n stride_x_N = X.stride(0)\n stride_x_hn = X.stride(1)\n stride_x_hd = X.stride(2)\n\n stride_y_N = y.stride(0)\n stride_y_hn = y.stride(1)\n stride_y_hd = y.stride(2)\n\n stride_w_hn = W.stride(0)\n stride_w_hd = W.stride(1)\n\n N = X.shape[-1]\n BLOCK_SIZE = 128\n\n grid = (X.shape[0], X.shape[1])\n _layer_norm_fwd_kernel[grid](\n X,\n W,\n y,\n stride_x_N,\n stride_x_hn,\n stride_x_hd,\n stride_y_N,\n stride_y_hn,\n stride_y_hd,\n stride_w_hn,\n stride_w_hd,\n N,\n eps,\n BLOCK_SIZE,\n )\n\n return y\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c393fd6730cb46fa989edee07161468b", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel _fwd_kernel_aligned performs scaled dot-product attention using queries (Q), keys (K), values (V), and a bias matrix (B0) to compute the output (Out). The kernel operates using blocks of size defined by BLOCK_M, BLOCK_N, and BLOCK_DMODEL, enabling efficient memory usage and computation. The kernel processes data by dividing the input matrices Q, K, and V into smaller blocks, loading these into shared memory, and iteratively processing them using dot products and bias additions. The computation involves scaling the dot products by sm_scale, adding biases from B0, and applying a numerically stable softmax operation using tl.math.exp2. Accumulated results are normalized before being written back to the output matrix Out. The kernel is parallelized using Triton's grid, spreading tasks across the Q's head dimension and batch size.\n\n The function _attention_rel_h_rel_w_kernel_aligned_device sets up and launches the Triton kernel. It first validates that input tensor shapes and types are consistent, particularly ensuring q, k, v, and rel_h_w are compatible and have expected dimensions. It configures a computation grid based on the third dimension of Q and its batch size. The function calculates necessary constants like OUT_DTYPE for the precision and BIAS_LAST_SIZE for bias processing. It then invokes the _fwd_kernel_aligned kernel with specified parameters such as tensor pointers, strides, block sizes, and tuning parameters (num_warps, num_stages) to optimize the execution.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_aligned(\n Q, K, V, B0, sm_scale,\n Out,\n stride_qh, stride_qm, stride_qk,\n stride_kh, stride_kn, stride_kk,\n stride_vh, stride_vk, stride_vn,\n stride_oh, stride_om, stride_on,\n stride_b0h, stride_b0m,\n Z,\n H,\n N_CTX,\n P_SEQ,\n OUT_DTYPE: tl.constexpr,\n BIAS_LAST_SIZE: tl.constexpr,\n B0_NUMEL: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n q_offset = off_hz * stride_qh\n kv_offset = off_hz * stride_kh\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + kv_offset,\n shape=(BLOCK_DMODEL, N_CTX + P_SEQ),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1)\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + kv_offset,\n shape=(N_CTX + P_SEQ, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0)\n )\n\n # initialize offsets\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # scale sm_scale by log_2(e) and use\n # 2^x instead of exp in the loop because CSE and LICM\n # don't work as expected with `exp` in the loop\n qk_scale = sm_scale * 1.44269504\n # load q: it will stay in SRAM throughout\n q = tl.load(Q_block_ptr) # , boundary_check=(1, 0), padding_option=\"zero\")\n q = (q * qk_scale).to(OUT_DTYPE)\n # loop over k, v and update accumulator\n lo = 0\n hi = N_CTX + P_SEQ\n\n b_ptr_offsets_m = tl.arange(0, BLOCK_M)\n\n b_offset = off_hz * stride_b0h\n b_ptr_offsets_n_1 = (tl.arange(0, BLOCK_N) %\n BIAS_LAST_SIZE) + BIAS_LAST_SIZE\n b1 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)\n * stride_b0m)[:, None] + b_ptr_offsets_n_1[None, :])\n for start_n in range(lo, hi, BLOCK_N):\n # -- load k, v --\n # , boundary_check=(0, 1), padding_option=\"zero\")\n k = tl.load(K_block_ptr)\n # , boundary_check=(1, 0), padding_option=\"zero\")\n v = tl.load(V_block_ptr)\n # -- compute qk ---\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=OUT_DTYPE)\n qk += tl.dot(q, k) #, out_dtype=OUT_DTYPE)\n\n # -- compute rel_h[:, None] + rel_w[None, :] bias ---\n\n # Bias\n b0 = tl.load(B0 + b_offset + ((start_m * BLOCK_M + b_ptr_offsets_m)\n * stride_b0m)[:, None] + start_n // BLOCK_N)\n qk += ((b0 + b1) * 1.44269504)\n\n # -- compute scaling constant ---\n m_i_new = tl.maximum(m_i, tl.max(qk, 1))\n alpha = tl.math.exp2(m_i - m_i_new)\n p = tl.math.exp2(qk - m_i_new[:, None])\n # -- scale and update acc --\n acc *= alpha[:, None]\n acc += tl.dot(p.to(OUT_DTYPE), v)\n # -- update m_i and l_i --\n l_i = l_i * alpha + tl.sum(p, 1)\n m_i = m_i_new\n # update pointers\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n\n # write back l and m\n acc = acc / l_i[:, None]\n\n # write back O\n O_block_ptr = tl.make_block_ptr(\n base=Out + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0)\n )\n tl.store(O_block_ptr, acc.to(OUT_DTYPE))\n\n\ndef _attention_rel_h_rel_w_kernel_aligned_device(q, k, v, rel_h_w, sm_scale, o,\n BLOCK_M,\n BLOCK_N,\n num_warps,\n num_stages):\n _, Lk, _ = q.shape[-1], k.shape[-1], v.shape[-1]\n assert q.size() == k.size()\n assert q.size() == v.size()\n assert q.size(-2) == rel_h_w.size(-2)\n assert (q.dtype == torch.bfloat16 or q.dtype == torch.float16)\n assert k.dtype == q.dtype\n assert v.dtype == k.dtype\n assert o.dtype == v.dtype\n assert rel_h_w.dtype == q.dtype\n assert rel_h_w.size(-1) == 128\n # assert rel_h_w.size(-1) == 2 * BLOCK_N\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n # print(\"q.shape[0] * q.shape[1]: \", q.shape[0] * q.shape[1])\n P_SEQ = 0 if q.shape[-2] == k.shape[-2] else k.shape[-2] - q.shape[-2]\n assert P_SEQ == 0\n assert rel_h_w.is_contiguous(), str(rel_h_w.stride())\n OUT_DTYPE = tl.float16 if q.dtype == torch.float16 else tl.bfloat16\n _fwd_kernel_aligned[grid](\n q, k, v,\n rel_h_w,\n sm_scale,\n o,\n q.stride(1), q.stride(2), q.stride(3),\n k.stride(1), k.stride(2), k.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n o.stride(1), o.stride(2), o.stride(3),\n rel_h_w.stride(1), rel_h_w.stride(2),\n q.shape[0],\n q.shape[1],\n q.shape[2],\n P_SEQ,\n OUT_DTYPE=OUT_DTYPE,\n BIAS_LAST_SIZE=(rel_h_w.size(-1) // 2),\n B0_NUMEL=rel_h_w.size(-1),\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n BLOCK_DMODEL=Lk,\n num_warps=num_warps,\n num_stages=num_stages)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-25d7a745119a48d1ac079a68317cb403", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` is a Triton JIT-compiled kernel function designed to perform efficient tensor calculations for each batch and head within a defined grid. It processes two main input tensors, `Mid_O` and `Mid_O_LogExpSum`, using their strides and other inputs to compute and store results in the output tensor `O`. Key parameters include:\n - `B_Seqlen`: Provides sequence lengths for the batches.\n - `Mid_O`: A tensor of shape [batch, head, seq_block_num, head_dim] used in weighted sum calculations.\n - `Mid_O_LogExpSum`: A tensor [batch, head, seq_block_num] storing log-exp sum values for each sequence block.\n - `O`: A tensor [batch, head, head_dim] for storing the final accumulated and normalized output.\n - Stride parameters for tensors ensure correct indexing: `stride_mid_ob`, `stride_mid_oh`, `stride_mid_os`, `stride_mid_od` for `Mid_O`, and similar for others.\n - Constants `BLOCK_SEQ` and `BLOCK_DMODEL` define the block sizes for sequences and model dimensions, which guide iteration within the kernel.\n The kernel uses a double loop: over sequence blocks to accumulate weighted values and a vectorized computation for efficiency. Inside the loop, it computes new logic values, scales existing accumulations, and adjusts for numerical stability. Finally, it normalizes the accumulation by the sum of exponentials and writes the result back to `O`. \n The wrapper function `flash_decode_stage2` is a PyTorch-compatible function that sets up the computation grid and launches this kernel. It checks the compatibility of model dimensions (`Lk`) and passes the appropriate arguments to the kernel.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n O, # [batch, head, head_dim]\n stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,\n stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,\n stride_obs, stride_oh, stride_od,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n \n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n \n tl.store(O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp)\n return\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq):\n Lk = mid_out.shape[-1]\n assert Lk in {16, 32, 64, 128}\n batch, head_num = mid_out.shape[0], mid_out.shape[1]\n grid = (batch, head_num)\n \n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen, mid_out, mid_out_logexpsum, O,\n mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3),\n mid_out_logexpsum.stride(0), mid_out_logexpsum.stride(1), mid_out_logexpsum.stride(2),\n O.stride(0), O.stride(1), O.stride(2),\n BLOCK_SEQ=block_seq,\n BLOCK_DMODEL=Lk,\n num_warps=4,\n num_stages=2,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-bfe100a11ce74b90aef06cdb1b7b1edf", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code implements a Triton kernel `index_select_cat_bwd_kernel` and a Python wrapper function `index_select_cat_bwd` to handle the backward operation of index selection followed by concatenation. The function `index_select_cat_bwd` is used in a deep learning context where the forward operation involves selecting certain rows from a source tensor (potentially using concatenation of results from multiple selections) based on the indices provided. The backward operation, therefore, needs to compute the gradient of the source tensor by redistributing the gradient received from the output (`grad_output`) back into the positions specified by `index` within `grad_source`. The Triton kernel processes the data in a grid layout, with the axes of the grid defined by `BLOCK_SIZE_INDEX` and `BLOCK_SIZE_COL`, ensuring parallel computation. The function also includes extensive checks for tensor compatibility, ensuring that the inputs are 2D, CUDA-based, and have appropriate matching strides. The grid configuration is dynamically determined to cover all indices and columns of the inputs.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef index_select_cat_bwd_kernel(\n grad_source_ptr, # *Pointer* to grad_source tensor.\n index_ptr, # *Pointer* to index tensor.\n grad_output_ptr, # *Pointer* to grad_output tensor.\n num_rows,\n num_indices,\n num_cols,\n stride0, # Stride information of input and source tensor.\n stride1,\n BLOCK_SIZE_INDEX: tl.constexpr, # Number of indices each program should process.\n BLOCK_SIZE_COL: tl.constexpr, # Number of cols each program should process.\n):\n pid0 = tl.program_id(axis=0) # We use 3D launch grid\n pid1 = tl.program_id(axis=1)\n\n cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)\n\n # load grad_output\n grad_output_indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)\n grad_output_offsets = (\n grad_output_ptr\n + grad_output_indices[:, None] * stride0\n + cols[None, :] * stride1\n )\n grad_output_mask = (grad_output_indices[:, None] < num_indices) & (\n cols[None, :] < num_cols\n )\n grad_output = tl.load(grad_output_offsets, mask=grad_output_mask).to(tl.float32)\n\n # select indices from grad_source\n grad_source_indices = tl.load(\n index_ptr + grad_output_indices, mask=(grad_output_indices < num_indices)\n )\n grad_source_offsets = (\n grad_source_ptr\n + grad_source_indices[:, None] * stride0\n + cols[None, :] * stride1\n )\n\n # compute scaled index add and save\n tl.store(grad_source_offsets, grad_output, mask=grad_output_mask)\n\n\ndef index_select_cat_bwd(\n grad_source: torch.Tensor,\n index: torch.Tensor,\n grad_output: torch.Tensor,\n):\n if not (grad_source.is_cuda and grad_output.is_cuda):\n raise ValueError(\"The grad_source and grad_output tensor must be of type CUDA!\")\n\n if not (grad_source.ndim == 2 and grad_output.ndim == 2):\n raise ValueError(\n f\"The grad_source and grad_output must be three-dimensional \"\n f\"(got {grad_source.ndim} and {grad_output.ndim})!\"\n )\n if not grad_source.shape[1] == grad_output.shape[1]:\n raise ValueError(\n f\"The number of elements along dimension 1 of grad_source and grad_output must be the same \"\n f\"(got {grad_source.shape[1]} and {grad_output.shape[1]})\"\n )\n\n num_rows, num_cols = grad_source.shape\n num_indices, num_cols = grad_output.shape\n if not num_rows >= num_indices:\n raise ValueError(\n f\"The number of elements along dimension 0 of grad_source must be larger than that of grad_output \"\n f\"(got {num_rows} and {num_indices})!\"\n )\n if not index.shape[0] == num_indices:\n raise ValueError(\n f\"The number of indices and the number of elements along dimension 0 of grad_output must match \"\n f\"(got {index.shape[0]} and {num_indices})!\"\n )\n\n stride0, stride1 = grad_source.stride(0), grad_source.stride(1)\n if not (grad_output.stride(0) == stride0 and grad_output.stride(1) == stride1):\n raise ValueError(\n f\"The strides of the grad_source and grad_output tensors must match \"\n f\"(got {stride0} vs. {grad_output.stride(0)}, {stride1} vs. {grad_output.stride(1)})!\"\n )\n\n def grid(meta):\n return (\n triton.cdiv(num_indices, meta[\"BLOCK_SIZE_INDEX\"]),\n triton.cdiv(num_cols, meta[\"BLOCK_SIZE_COL\"]),\n )\n\n index_select_cat_bwd_kernel[grid](\n grad_source,\n index,\n grad_output,\n num_rows,\n num_indices,\n num_cols,\n grad_source.stride(0),\n grad_source.stride(1),\n BLOCK_SIZE_INDEX=1,\n BLOCK_SIZE_COL=512,\n )\n\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-acdb0971d34c4320960dfae58d282e9c", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe provided Triton operator computes the forward and backward passes of the Kullback-Leibler divergence (KL divergence). It includes the kernel definitions `_kldiv_kernel_forward` and `_kldiv_kernel_backward` and their corresponding wrapper functions `kldiv_forward_triton` and `kldiv_backward_triton`. Here is a detailed breakdown:\n\n1. Constants:\n - `MAX_FUSED_SIZE`: Set to 65536 divided by 4, determining the maximum block size for tensor fusion.\n - Reduction mode constants: `_REDUCTION_MODE_NONE`, `_REDUCTION_MODE_SUM`, `_REDUCTION_MODE_MEAN`, and `_REDUCTION_MODE_BATCHMEAN` are defined to manage different reduction strategies.\n - `_str_to_reduction_mode`: A dictionary mapping string representations of reduction modes to their constant values.\n\n2. Function `get_num_warps(BLOCK_SIZE)`:\n - Determines the optimal number of warps based on the provided `BLOCK_SIZE`. It returns 4 by default but increases for larger block sizes, up to 32 for `BLOCK_SIZE` >= 32768.\n\n3. `_kldiv_kernel_forward` function:\n - Takes pointers to the prediction tensor `y_ptr`, ground truth tensor `gt_ptr`, and a loss tensor `loss_ptr`, along with their strides.\n - Iterates over blocks of `BLOCK_SIZE` columns, computing the KL divergence using the formula `KL(y_true || y) = y_true * (log(y_true) - log(y))`.\n - Supports different reduction modes, handling each mode within the loop and storing results accordingly.\n\n4. `_kldiv_kernel_backward` function:\n - Similar in structure to its forward counterpart, this kernel computes gradients for the backward pass.\n - For non-log targets, it computes the negative of the target. For log targets, it multiplies the negative of the exponential of the target by the target.\n\n5. `kldiv_forward_triton` function:\n - Configures and launches the `_kldiv_kernel_forward` with appropriate grid, block size, and warp count based on the input tensor shape.\n - Outputs a tensor reduced based on the specified `reduction` mode, returning summed or mean values if applicable.\n\n6. `kldiv_backward_triton` function:\n - Configures and launches the `_kldiv_kernel_backward` for computing gradients.\n - If the `grad_output` is a scalar tensor equal to one, it returns the computed gradient directly; otherwise, it scales the gradient by `grad_output`.\n\nThe operations leverage Triton's capabilities for efficient parallel execution, particularly suited for large-scale tensor computations often found in deep learning tasks.\n", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nimport numpy as np\n\ndef get_num_warps(BLOCK_SIZE):\n num_warps = 4\n if BLOCK_SIZE >= 32768:\n num_warps = 32\n elif BLOCK_SIZE >= 8192:\n num_warps = 16\n elif BLOCK_SIZE >= 2048:\n num_warps = 8\n\n return num_warps\n\n\nMAX_FUSED_SIZE = 65536 // 4\n\n_REDUCTION_MODE_NONE = tl.constexpr(0)\n_REDUCTION_MODE_SUM = tl.constexpr(1)\n_REDUCTION_MODE_MEAN = tl.constexpr(2)\n_REDUCTION_MODE_BATCHMEAN = tl.constexpr(3)\n\n_str_to_reduction_mode = {\n \"none\": _REDUCTION_MODE_NONE.value,\n \"sum\": _REDUCTION_MODE_SUM.value,\n \"mean\": _REDUCTION_MODE_MEAN.value,\n \"batchmean\": _REDUCTION_MODE_BATCHMEAN.value,\n}\n\n@triton.jit\ndef _kldiv_kernel_forward(\n y_ptr, # [B, S], prediction ptr, the kernel expects the prediction in log-space\n y_stride, # int, prediction stride\n gt_ptr, # [B, S], ground truth ptr\n gt_stride, # int, ground truth stride\n loss_ptr, # [B] or [B, S] if reduction == _REDUCTION_MODE_NONE, output ptr\n loss_stride, # int, output stride\n n_cols, # int, number of columns in the input tensor\n BLOCK_SIZE: tl.constexpr,\n log_target: tl.constexpr = False,\n reduction: tl.constexpr = _REDUCTION_MODE_BATCHMEAN,\n):\n pid = tl.program_id(0).to(tl.int64)\n y_ptr += pid * y_stride\n gt_ptr += pid * gt_stride\n loss_ptr += pid * loss_stride\n\n base_offsets = tl.arange(0, BLOCK_SIZE)\n\n for i in range(0, n_cols, BLOCK_SIZE):\n offsets = i + base_offsets\n mask = offsets < n_cols\n y = tl.load(y_ptr + offsets, mask=mask, other=0.0)\n y_true = tl.load(gt_ptr + offsets, mask=mask, other=0.0)\n\n # KL(y_true || y) = y_true * (log(y_true) - log(y))\n # We compute KL(y_true || y) with y in the log-space\n if not log_target:\n loss = y_true * (tl.log(y_true) - y)\n else:\n loss = tl.exp(y_true) * (y_true - y)\n\n if reduction == _REDUCTION_MODE_NONE:\n tl.store(loss_ptr + offsets, loss, mask=mask)\n else:\n loss = tl.sum(loss, axis=0)\n tl.store(loss_ptr, loss)\n loss_ptr += 1\n\n@triton.jit\ndef _kldiv_kernel_backward(\n input_ptr,\n input_stride,\n target_ptr,\n target_stride,\n n_cols,\n BLOCK_SIZE: tl.constexpr,\n log_target: tl.constexpr = False,\n):\n pid = tl.program_id(0).to(tl.int64)\n\n input_ptr += pid * input_stride\n target_ptr += pid * target_stride\n\n offsets = tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_cols\n\n for i in range(0, n_cols, BLOCK_SIZE):\n offsets = i + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_cols\n\n target = tl.load(target_ptr + offsets, mask=mask, other=0.0)\n\n if not log_target:\n res = target * -1\n else:\n res = -tl.exp(target)\n\n tl.store(input_ptr + offsets, res, mask=mask)\n\ndef kldiv_forward_triton(y_pred, y_true, log_target, reduction): # [B, S] # [B, S]\n B, S = y_pred.shape\n\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))\n num_warps = get_num_warps(BLOCK_SIZE)\n\n grid = (B,)\n reduction = _str_to_reduction_mode[reduction]\n\n out_size = (B, S) if reduction == _REDUCTION_MODE_NONE.value else (B,)\n output_tensor = torch.zeros(\n out_size,\n dtype=torch.float32,\n device=\"cuda\",\n )\n\n _kldiv_kernel_forward[grid](\n y_pred,\n y_pred.stride(0),\n y_true,\n y_true.stride(0),\n output_tensor,\n output_tensor.stride(0),\n S,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n log_target=log_target,\n reduction=reduction,\n )\n\n if reduction == _REDUCTION_MODE_BATCHMEAN.value:\n return output_tensor.sum() / B\n elif reduction == _REDUCTION_MODE_SUM.value:\n return output_tensor.sum(axis=0)\n elif reduction == _REDUCTION_MODE_MEAN.value:\n return output_tensor.mean(axis=0)\n else:\n return output_tensor\n\ndef kldiv_backward_triton(input, target, grad_output, log_target):\n B, S = input.shape\n\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(S))\n num_warps = get_num_warps(BLOCK_SIZE)\n\n grid = (B,)\n\n _kldiv_kernel_backward[grid](\n input,\n input.stride(0),\n target,\n target.stride(0),\n S,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n log_target=log_target,\n )\n\n if torch.equal(\n grad_output,\n torch.tensor(\n 1.0,\n dtype=grad_output.dtype,\n device=\"cuda\",\n ),\n ):\n return input\n\n return input * grad_output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-ed62d9ffc8fe41fbb22d915bf3e98188", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton code includes multiple kernels and wrapper functions for handling 4-bit quantized matrix multiplications and their dequantization. The code is structured as follows:\n\n 1. **matmul4_kernel**: This Triton kernel handles the matrix multiplication `C = A x B` where `A` is a float16 matrix and `B` is a 4-bit quantized matrix stored in int32 format. The kernel dequantizes `B` using provided `scales` and `zeros`. The dequantization logic involves shifting and masking bitwise operations to extract the 4-bit values from `B`, applying scales, and subtracting zeros. The computations are done in blocks determined by `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K`. The result is accumulated in float32 precision and then stored in float16 format.\n\n 2. **matmul_dequantize_int4_gptq**: This function serves as a wrapper for `matmul4_kernel`. It prepares the inputs and sets up the Triton grid based on the dimensions of matrices `x`, `qweight`, `scales`, and `qzeros`. The grid size is determined by dividing `M` and `N` by the corresponding block sizes. If no output tensor is provided, it initializes a new tensor for the result. The kernel is invoked with all necessary parameters, including strides for addressing the elements of input matrices.\n\n 3. **matmul_kernel**: This Triton kernel performs a similar operation as `matmul4_kernel` but with an added `SPLIT_K` parameter. This allows splitting the `K` dimension into multiple parts, which can be processed separately, making it suitable for handling larger matrices. The kernel ensures that the result is accumulated correctly even when `SPLIT_K` is greater than 1 using atomic addition operations.\n\n 4. **matmul_dequantize_int4_s2**: This function wraps around `matmul_kernel`, organizing the execution of matrix multiplication with dequantization. The grid configuration takes into account both the `M` and `N` dimensions and the `SPLIT_K` factor. The function initializes the output matrix and calls the kernel with parameters set up for efficient parallel computation.\n\n 5. **dequantize_kernel**: This kernel is focused on converting a 4-bit quantized matrix `b` into a full precision float16 matrix. It uses block sizes `BLOCK_SIZE_K` and `BLOCK_SIZE_N` to manage parallel processing. Dequantization is achieved by extracting 4-bit values, adjusting for scales, and zero points. The kernel writes the dequantized values to the output matrix.\n\n 6. **dequantize_int4**: This function uses `dequantize_kernel` to convert quantized matrices into full precision. It sets up the necessary grid parameters and computes the dequantized matrix by invoking the kernel with appropriate memory strides and dimensions.\n\n 7. **matmul_dequantize_int4_s1**: This function dequantizes the weight matrix first and then performs a matrix multiplication using PyTorch's `torch.mm`. This approach is beneficial in scenarios where the dequantization overhead is compensated by the reuse of the dequantized weights over multiple computations.\n\n Overall, the code is designed to efficiently handle matrix operations on quantized data with Triton's parallel computation capabilities, providing both flexibility and performance.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n # Additional configurations are omitted for brevity...\n ],\n key=['M', 'N', 'K', 'NO_GROUPS'],\n)\n@triton.jit\ndef matmul4_kernel(\n a_ptr, b_ptr, c_ptr,\n scales_ptr, zeros_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_scales_g, stride_scales_n,\n stride_zeros_g, stride_zeros_n,\n groupsize, NO_GROUPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N//8) int32\n groupsize is an int specifying the size of groups for scales and zeros.\n G is K // groupsize.\n Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient.\n WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K.\n WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N.\n WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K.\n \"\"\"\n bits = 4\n infearure_per_bits = 8\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,)\n zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,)\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n if NO_GROUPS:\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 \n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n zeros = zeros * scales\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated \n if not NO_GROUPS:\n g_id = k // (groupsize // BLOCK_SIZE_K)\n ptr = scales_ptrs + g_id * stride_scales_g\n scales = tl.load(ptr) # (BLOCK_SIZE_N,)\n ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,)\n zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32 \n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16 \n b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values\n b = b * scales[None, :] - zeros[None, :] # Scale and shift\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk \n c = accumulator.to(tl.float16) \n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor:\n \"\"\"\n Compute the matrix multiplication C = A x B + bias.\n Where B is quantized using GPTQ and groupsize = -1 into 4-bit values.\n\n A is of shape (..., K) float16\n qweight is of shape (K//8, N) int32\n scales is of shape (G, N) float16\n qzeros is of shape (G, N//8) int32\n bias is of shape (1, N) float16\n\n groupsize is the number of infeatures in each group.\n G = K // groupsize\n\n Returns C of shape (..., N) float16\n \"\"\"\n assert x.shape[-1] == (qweight.shape[0] * 8), \"A must be a multiple of 8 in the last dimension\"\n assert x.is_contiguous(), \"A must be contiguous\"\n\n M, K = x.shape\n N = qweight.shape[1]\n\n if output is None:\n inplace = False\n output = torch.empty((M, N), device=x.device, dtype=torch.float16)\n else:\n inplace = True\n\n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n matmul4_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size, group_size == K,\n )\n if not inplace:\n return output\n\n\n@triton.autotune(\n configs=[\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n # Additional configurations are omitted for brevity...\n ],\n key=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n bs_ptr, bzp_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n group_size,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr\n ):\n \"\"\"\n Matrix multiplication kernel with dequantization for 4-bit integers\n A: [M, K], B: [K//8, N], C: [M, N]\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \\\n + offs_bn[None, :] * stride_bsn\n bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \\\n + (offs_bn[None, :] // 8) * stride_bzpn\n b_shift_bits = (offs_k[:, None] % 8) * 4 \n bzp_shift_bits = (offs_bn[None, :] % 8) * 4\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n bs = tl.load(bs_ptrs)\n bzp = tl.load(bzp_ptrs)\n\n int_b = (b >> b_shift_bits) & 0xF\n int_bzp = (bzp >> bzp_shift_bits) & 0xF\n b = ((int_b - int_bzp) * bs).to(tl.float16)\n accumulator += tl.dot(a.to(tl.float16), b.to(tl.float16))\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8)\n\n c = accumulator.to(tl.float16)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor:\n \"\"\"\n Perform matrix multiplication and dequantize the result from 4-bit integers.\n \"\"\"\n assert x.is_contiguous(), \"A must be contiguous\"\n assert qweight.is_contiguous(), \"B must be contiguous\" \n M, K = x.shape\n N = scales.shape[1]\n if output is None:\n output = torch.zeros((M, N), device=x.device, dtype=torch.float16) \n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n )\n return output\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n # Additional configurations are omitted for brevity...\n ],\n key=['K', 'N'],\n)\n@triton.jit\ndef dequantize_kernel(\n # Pointers to matrices\n b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr,\n # Matrix dimensions\n K, N, group_size,\n stride_bk, stride_bn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n stride_fpbk, stride_fpbn,\n # Meta-parameters\n BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"Dequantize tile [BLOCK_SIZE_K, BLOCK_SIZE_N] in full precision.\n We should assert BLOCK_SIZE_N % 8 == 0.\n weight[K // 8, N], scale[K // group_size, N], zp[K // group_size, N // group_size]\n \"\"\"\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n fpb_offs = offs_k[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn\n b_offs = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n bzp_offs = (offs_k[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn\n bs_offs = (offs_k[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn\n n_mask = offs_n[None, :] < N\n k_mask = offs_k[:, None] < K\n mask = n_mask & k_mask\n int32_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n zp_b = tl.load(b_zp_ptr + bzp_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=mask, other=0.0)\n b_shift = (offs_k[:, None] % 8) * 4\n bzp_shift = (offs_n[None, :] % 8) * 4\n fp_weight = (((int32_b >> b_shift) & 0xF) - ((zp_b >> bzp_shift) & 0xF)) * scale_b\n tl.store(fpb_ptr + fpb_offs, fp_weight, mask=mask)\n\n\ndef dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size):\n Kw, N = b.shape\n K = Kw * 8\n fp_b = torch.ones((K, N), device=device, dtype=dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']),\n triton.cdiv(N, META['BLOCK_SIZE_N']), \n )\n dequantize_kernel[grid](\n b, b_scale, b_zero_point, fp_b,\n K, N, group_size,\n b.stride(0), b.stride(1),\n b_scale.stride(0), b_scale.stride(1),\n b_zero_point.stride(0), b_zero_point.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n return fp_b\n\n\ndef matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, group_size=128, out=None):\n \"\"\"\n Matmul dequantize int4 s1 dequantize weight to `fp_b` and do fp16 torch.mm,\n this is for `prefill` stage, since weight size is fixed so is dequantize overhead,\n perfill stage have more tokens to amortize dequant cost.\n \"\"\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n Kw, N = b.shape\n if out is None:\n out = torch.empty((M, N), device=a.device, dtype=a.dtype)\n fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size)\n torch.mm(a, fp_b, out=out)\n fp_b = None\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-156cabf8bd614b2a8a4743a9b90e4d37", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The _quantize_rowwise kernel in Triton is designed to perform row-wise quantization on a 2D input tensor, which is provided as a CUDA tensor. The kernel is executed per row, with each row processed independently by leveraging Triton's parallel computing capabilities. It takes several parameters: x_ptr (pointer to the input tensor), output_ptr (pointer to the output tensor where quantized results are stored), output_maxs (pointer to an array where max values per row are stored), and n_elements, which is the total number of elements in the output tensor. Two compile-time constants, BLOCK_SIZE and P2, define the block size for processing and the power of 2 ceiling of the row size, respectively. The kernel computes the absolute values of elements within a row, finds the maximum value among them, and uses this max value to scale each element such that it fits into the int8 range via the function tl.extra.cuda.libdevice.llrint. The quantized values and max values are then stored in their respective output locations. The quantize_rowwise function acts as a wrapper to set up the necessary data structures, ensure the input is a CUDA tensor, and launch the kernel with appropriate grid size, which corresponds to the number of rows. It returns the quantized output tensor and a tensor containing the max values for each row.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n\n# TODO: autotune this better.\n@triton.autotune(\n configs=[\n triton.Config({}, num_stages=1, num_warps=8),\n triton.Config({}, num_stages=2, num_warps=8),\n triton.Config({}, num_stages=4, num_warps=8),\n triton.Config({}, num_stages=8, num_warps=8),\n triton.Config({}, num_stages=1),\n triton.Config({}, num_stages=2),\n triton.Config({}, num_stages=4),\n triton.Config({}, num_stages=8),\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n ],\n key=[\"n_elements\"],\n)\n@triton.jit\ndef _quantize_rowwise(\n x_ptr,\n output_ptr,\n output_maxs,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n P2: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n arange = tl.arange(0, P2)\n offsets = block_start + arange\n row_mask = arange < BLOCK_SIZE\n x = tl.load(x_ptr + offsets, mask=row_mask)\n\n abs_x = tl.abs(x)\n max_val = tl.max(tl.where(row_mask, abs_x, 0), axis=0)\n output = tl.extra.cuda.libdevice.llrint(127.0 * (x / max_val))\n tl.store(output_ptr + offsets, output, mask=row_mask)\n tl.store(output_maxs + pid, max_val)\n\ndef quantize_rowwise(x: torch.Tensor):\n output = torch.empty(*x.shape, device=x.device, dtype=torch.int8)\n output_maxs = torch.empty(x.shape[0], device=x.device, dtype=torch.float16)\n\n P2 = int(2 ** (math.ceil(math.log2(x.shape[1]))))\n\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (x.shape[0],)\n _quantize_rowwise[grid](x, output, output_maxs, n_elements, BLOCK_SIZE=x.shape[1], P2=P2)\n return output, output_maxs\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-b90a5bccd8e84c538cd4edc6aa9bd03a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The '_fwd_kernel' Triton kernel performs a forward computation for a token-wise softmax reduction involving several key steps. The kernel operates on three-dimensional data structured by 'Logics', 'V', and 'Out', together with indices and sequence metadata ('B_Loc', 'B_Start_Loc', 'B_Seqlen') which are crucial for handling variable-length sequences within a batch. The primary purpose is to apply a scaled dot-product attention mechanism, which includes a softmax computation on logits and a weighted summation. Within the function, the grid is defined over batch and head dimensions, indicated by 'program_id(0)' and 'program_id(1)'. The kernel iteratively processes blocks of the sequence, as determined by 'BLOCK_N', calculating maximum exponentials 'e_max', probabilities 'p', and an accumulated weighted sum 'acc' over the sequence. This is achieved by carefully loading indices with 'tl.load', performing exponentiation and normalization, and ensuring numerical stability through max-shifted exponentials. Post-loop, results in 'acc' are normalized by the sum of exponentials and stored into 'Out'. The helper function 'token_softmax_reducev_fwd' is responsible for launching '_fwd_kernel' across all combinations of batch and head via the grid, setting up strides and block sizes using input tensor strides and the model dimension 'BLOCK_DMODEL'. Execution parameters include a configurable number of warps and stages for performance tuning.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n\n@triton.jit\ndef _fwd_kernel(\n Logics, V, Out,\n B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n stride_logic_h, stride_logic_bs,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n stride_b_loc_b, stride_b_loc_s,\n other_kv_index, # Avoid reading NaN data\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch)\n\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n off_v = cur_head * stride_vh + offs_d[None, :] * stride_vd\n off_b_loc = cur_batch * stride_b_loc_b + (max_input_len - cur_batch_seq_len) * stride_b_loc_s\n\n v_ptrs = V + off_v\n\n e_max = float(\"-inf\")\n e_sum = 0.0\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n for start_n in range(0, cur_batch_seq_len, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n v_index = tl.load(B_Loc + off_b_loc + (start_n + offs_n) * stride_b_loc_s, mask=(start_n + offs_n) < cur_batch_seq_len, other=other_kv_index)\n\n qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, \n mask=start_n + offs_n < cur_batch_seq_len, other=float(\"-inf\"))\n \n n_e_max = tl.maximum(tl.max(qk, 0), e_max)\n old_scale = tl.exp(e_max - n_e_max)\n p = tl.exp(qk - n_e_max)\n e_sum = e_sum * old_scale + tl.sum(p, 0)\n v = tl.load(v_ptrs + v_index[:, None] * stride_vbs)\n acc = acc * old_scale + tl.sum(p[:, None] * v, 0)\n e_max = n_e_max\n\n acc = acc / e_sum\n off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc)\n return\n\n\n@torch.no_grad()\ndef token_softmax_reducev_fwd(logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len, other_kv_index):\n BLOCK = 64\n batch, head = b_seq_len.shape[0], logics.shape[0]\n grid = (batch, head)\n num_warps = 1\n _fwd_kernel[grid](\n logics, v, o, b_loc, b_start_loc, b_seq_len, max_input_len,\n logics.stride(0), logics.stride(1),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n b_loc.stride(0), b_loc.stride(1),\n other_kv_index,\n BLOCK_DMODEL=v.shape[-1],\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=3\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-a5cbcbfd65a54a659d44b70104f8cc65", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based kernel for matrix multiplication of INT4 quantized weights and provides Python functions to handle the quantization and dequantization processes. \n The 'matmul_kernel' function is a Triton kernel using @triton.jit which performs matrix multiplication. It processes the input matrices in tiles defined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K, and uses a loop to iterate over the K dimension to accumulate results in fp32 precision, subsequently stored in the output buffer, potentially employing atomic add for SPLIT_K > 1.\n The kernel is set up with a wide variety of configurations, allowing Triton's autotuning capabilities to select optimal parameters based on matrix dimensions M, N, and K. It utilizes quantized INT4 weights, reconstructing them using scales and zero points to compute matrix products accurately. \n The Python function 'matmul_dequantize_int4_s2' serves as an interface to this kernel, preparing data and launching the computation on specified grid dimensions.\n The function 'quantize_int4' converts a floating-point weight matrix into INT4 format, organizing the data by packing 8 INT4 values into one INT32 and calculating scaling factors and zero points for each group of elements. \n The helper function 'unpack_int4' is intended for testing; it unpacks the INT4 matrix back into a floating-point format to verify the quantization process.\n ", + "output": [ + "import time\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n\tconfigs=[\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t \n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 1, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n\t\ttriton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 256, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n triton.Config({'SPLIT_K': 2, 'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 16, 'BLOCK_SIZE_K': 512, 'GROUP_SIZE_M': 16}, num_stages=2, num_warps=4),\n\t\t\n ],\n\tkey=['M', 'N', 'K'],\n reset_to_zero=['c_ptr']\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n bs_ptr, bzp_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n group_size,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, SPLIT_K: tl.constexpr\n ):\n \"\"\"\n assert K % (BLOCK_SIZE_K * SPLIT_K) == 0\n \"\"\"\n pid = tl.program_id(axis=0)\n pid_sp_k = tl.program_id(axis=1)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = pid_sp_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n\n # [BLOCK_M, BLOCK_K]\n a_ptrs = a_ptr + offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak\n # [BLOCK_K, BLOCK_N] but repeated 8 times in N\n b_ptrs = b_ptr + (offs_k[:, None] // 8) * stride_bk + offs_bn[None, :] * stride_bn\n # tl.static_print(\"shape\", a_ptrs, b_ptrs, bs_ptrs, bzp_ptrs)\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):\n # Load the next block of A and B.\n # [BLOCK_K, BLOCK_N] but repeated group_size times in K \n bs_ptrs = bs_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bsk \\\n + offs_bn[None, :] * stride_bsn\n # [BLOCK_K, BLOCK_N] but repeated in K and N\n bzp_ptrs = bzp_ptr + ((offs_k[:, None] + k * BLOCK_SIZE_K * SPLIT_K) // group_size) * stride_bzpk \\\n + (offs_bn[None, :] // 8) * stride_bzpn\n b_shift_bits = (offs_k[:, None] % 8) * 4 # assert BLOCK_SIZE_K % 8 == 0\n bzp_shift_bits = (offs_bn[None, :] % 8) * 4\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n bs = tl.load(bs_ptrs)\n bzp = tl.load(bzp_ptrs)\n # We accumulate along the K dimension.\n int_b = (b >> b_shift_bits) & 0xF\n int_bzp = (bzp >> bzp_shift_bits) & 0xF\n b = ((int_b - int_bzp) * bs).to(a.dtype)\n accumulator += tl.dot(a, b.to(a.dtype))\n # Advance the ptrs to the next K block.\n a_ptrs += BLOCK_SIZE_K * SPLIT_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K * SPLIT_K * stride_bk // 8) # assert BLOCK_SIZE_K % 8 == 0\n # You can fuse arbitrary activation functions here\n # while the accumulator is still in FP32!\n c = accumulator.to(c_ptr.dtype.element_ty)\n # -----------------------------------------------------------\n # Write back the block of the output matrix C with masks.\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if SPLIT_K == 1:\n tl.store(c_ptrs, c, mask=c_mask)\n else:\n tl.atomic_add(c_ptrs, c, mask=c_mask)\n\n\ndef matmul_dequantize_int4_s2(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size: int = 128, output=None) -> torch.FloatTensor:\n \"\"\"\n \"\"\"\n assert x.is_contiguous(), \"A must be contiguous\"\n assert qweight.is_contiguous(), \"B must be contiguous\" \n M, K = x.shape\n N = scales.shape[1]\n if output is None:\n output = torch.zeros((M, N), device=x.device, dtype=x.dtype) \n grid = lambda META: (\n triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n META['SPLIT_K'],\n )\n matmul_kernel[grid](\n x, qweight, output,\n scales, qzeros,\n M, N, K,\n x.stride(0), x.stride(1),\n qweight.stride(0), qweight.stride(1),\n output.stride(0), output.stride(1),\n scales.stride(0), scales.stride(1),\n qzeros.stride(0), qzeros.stride(1),\n group_size,\n )\n return output\n\ndef quantize_int4(weight, group_size=128, tp_rank=0):\n # Weight shape: [H1 // 8, H2]\n # Scale shape: [H1 // group_size, H2]\n # zero_pint shape: [H1 // group_size, H2 // 8]\n\n weight = weight.transpose(1, 0)\n h1, h2 = weight.shape\n assert h1 % 8 == 0 and h2 % 8 == 0, \"H1 {} H2 {}\".format(h1, h2)\n assert h2 % group_size == 0, \"H1 {} H2 {}\".format(h1, h2)\n weight = weight.contiguous().view(-1, group_size).cuda(tp_rank)\n weight_max = weight.amax(-1, keepdim=True)\n weight_max = torch.where(weight_max < 0, 0, weight_max)\n weight_min = weight.amin(-1, keepdim=True)\n weight_min = torch.where(weight_min > 0, 0, weight_min)\n weight_range = weight_max - weight_min \n scale = weight_range / (2 ** 4 - 1)\n zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32)\n weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2)\n int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device)\n int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device)\n zero_point = zero_point.view(h1, -1)\n scale = scale.view(h1, -1)\n # pack 8 int4 in an int32 number.\n # Weight pack in row.\n for pack in range(0, h2, 8):\n for i in range(8):\n int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4)\n # zero point pack in col.\n for pack in range(0, h1, 8):\n for i in range(8):\n int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4)\n '''\n fp_weight = torch.zeros(h1, h2).half().to(weight.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_weight[pack * 8 + i, :] = \\\n ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16\n print((fp_weight - weight).abs().sum())\n\n fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zp[pack * 8 + i, :] = \\\n (int_zero_point[pack, :] >> (i * 4)) & 15\n\n print((fp_zp - zero_point).abs().sum())\n '''\n weight = None\n return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size\n\n\ndef unpack_int4(weight, scale, zp):\n \"\"\"\n Test function to verify quantize int4 is correct.\n Will not be used in model inference.\n \"\"\"\n weight = weight.transpose(1, 0)\n scale = scale.transpose(1, 0)\n zp = zp.transpose(1, 0)\n h1, h2 = weight.shape\n group_size = h2 * 8 // scale.shape[1]\n group_num = scale.shape[1]\n fp_weight = torch.zeros(h1, h2 * 8).half().to(weight.device)\n fp_zero_point = torch.zeros(h1, group_num).to(weight.device)\n for pack in range(0, h2):\n for i in range(8):\n fp_weight[:, pack * 8 + i] = (weight[:, pack] >> (i * 4)) & 0xF\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zero_point[pack * 8 + i, :] = (zp[pack, :] >> (i * 4)) & 0xF\n for g in range(group_num):\n fp_weight[:, g * group_size:(g + 1) * group_size] = (fp_weight[:, g * group_size:(g + 1) * group_size] - \\\n fp_zero_point[:, g].unsqueeze(1)) * scale[:, g].unsqueeze(1)\n return fp_weight.transpose(1, 0)\n\n\n\n" + ] + }, + { + "id": "openseek-8-0d1269dfa9ab4586b06c0051314e209e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton code defines a kernel named 'puzzle1_kernel'. This kernel performs an element-wise addition of a constant value to each element in the input tensor 'x'. The kernel expects the following parameters: 'x_ptr' (a pointer to the input tensor), 'output_ptr' (a pointer to the output tensor), 'N' (the total number of elements), 'BLOCK_SIZE' (a compile-time constant defining the block size), and 'value' (the constant to add). Within the kernel, 'pid' identifies the program's unique ID along the specified axis (axis=0). Each block computes a starting index 'block_start', and 'offsets' determines the positions within this block. A 'mask' ensures that computations only occur within valid indices (offsets < N). The kernel loads data from 'x_ptr', adds 'value', and stores the result in 'output_ptr'. The wrapper function 'puzzle1' initializes an output tensor with the same shape and type as 'x'. It verifies CUDA compatibility of input and output tensors, calculates the total number of elements 'N', defines the grid size using the function 'grid', and invokes 'puzzle1_kernel' with specified block size and constant value (10). It returns the output tensor that contains the modified data.\n ", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\n# Triton kernel\n@triton.jit\ndef puzzle1_kernel(x_ptr, output_ptr, N, BLOCK_SIZE: tl.constexpr, value):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < N\n x = tl.load(x_ptr + offsets, mask=mask)\n output = x + value\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Wrapper function to call the kernel\ndef puzzle1(x: torch.Tensor):\n output = torch.empty_like(x)\n assert x.is_cuda and output.is_cuda\n N = output.numel()\n grid = lambda meta: (triton.cdiv(N, meta['BLOCK_SIZE']),)\n puzzle1_kernel[grid](x, output, N, BLOCK_SIZE=1024, value=10)\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c2ee88553e7048338788d5bcf967a2ad", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines two conversion operations using Triton: one for converting float8 data stored as int8 to float16 (`kernel_f8_to_f16` and `f8_to_f16`), and another for converting float16 or float32 data to float8 stored as int8 (`kernel_f16_to_f8` and `f16_to_f8`). \n The kernel `kernel_f8_to_f16` identifies a program ID `pid` and computes offsets `offs` based on `BLOCK_SIZE`. It loads data from input `X` with masking for valid indices, then writes the same data to output `Y`. The conversion is performed by interpreting the input int8 tensor as float8.\n The `f8_to_f16` function ensures the input tensor is of type `torch.int8` and resides on a CUDA device. It creates an output tensor of `torch.float16` type. It calculates the grid size based on the number of elements, and calls the Triton kernel.\n The `kernel_f16_to_f8` operates similarly, processing float16 or float32 inputs and outputting int8 results, with appropriate reinterpretation of data types. The `f16_to_f8` function follows similar assertions and grid calculations.\n Both functions rely on Triton’s grid and masking functionalities for efficient computation.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef kernel_f8_to_f16(Y, X, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offs < N\n x = tl.load(X + offs, mask=mask)\n tl.store(Y + offs, x, mask=mask)\n tl.store(Y + offs, x, mask=mask)\n\ndef f8_to_f16(x, dtypes=tl.float8e5) -> torch.Tensor:\n assert x.dtype == torch.int8, f\"torch.int8 expected but got {x.dtype}\"\n assert \"cuda\" in str(x.device), f\"CUDA tensors only but got {x.device}\"\n ret = torch.empty_like(x, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)\n numel = ret.untyped_storage().size() // ret.element_size()\n kernel_f8_to_f16[grid](ret, triton.reinterpret(x, dtypes), numel, BLOCK_SIZE=1024)\n return ret\n\n@triton.jit\ndef kernel_f16_to_f8(Y, X, N, BLOCK_SIZE: tl.constexpr):\n pid = tl.program_id(0)\n offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = offs < N\n x = tl.load(X + offs, mask=mask)\n tl.store(Y + offs, x, mask=mask)\n\ndef f16_to_f8(x: torch.Tensor, dtypes=tl.float8e5) -> torch.Tensor:\n assert x.dtype in [torch.float16, torch.float32]\n assert \"cuda\" in str(x.device), f\"CUDA tensors only but got {x.device}\"\n ret = torch.empty_like(x, dtype=torch.int8)\n grid = lambda META: (triton.cdiv(x.numel(), META['BLOCK_SIZE']),)\n numel = x.untyped_storage().size() // x.element_size()\n kernel_f16_to_f8[grid](triton.reinterpret(ret, dtypes), x, numel, BLOCK_SIZE=1024)\n return ret\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-2db59e58609e4ada8ee2446defd448db", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton kernel `matmul_kernel` is designed for matrix multiplication of two large matrices, `a` and `b`, storing the result in `c`. The kernel operates using a blocked approach, where `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K` dictate the dimensions of sub-matrices each thread block handles. The logical process begins by calculating each thread block's starting position in the output matrix via `pid_m` and `pid_n`, which are program IDs along the two major axes. Matrix strides (`stride_am`, `stride_ak`, `stride_bk`, `stride_bn`, `stride_cm`, `stride_cn`) are used for accessing elements efficiently. Offsets `offs_am`, `offs_bn`, and `offs_k` are calculated to locate the correct sub-blocks of `a` and `b` for loading. The computation iterates over the K dimension in blocks, performing a dot product of sub-blocks `a` and `b`, accumulating results in `accumulator`. After processing all blocks along K, the result in `accumulator` is cast to `float16` and stored in `c`. The outer `matmul` function configures the kernel launch, calculating grid sizes as a function of the matrix dimensions `M`, `N`, and `K` divided by their respective block sizes.\n ", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef matmul_kernel(c_ptr, a_ptr, b_ptr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid_m = tl.program_id(axis=0)\n pid_n = tl.program_id(axis=1)\n\n M, N, K = 4096, 4096, 4096\n stride_am = 4096\n stride_ak = 1\n stride_bk = 4096\n stride_bn = 1\n stride_cm = 4096\n stride_cn = 1\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs)\n b = tl.load(b_ptrs)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n c = tl.cast(accumulator, tl.float16)\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n tl.store(c_ptrs, c)\n\ndef matmul(c, a, b, M, N, K, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K):\n matmul_kernel[triton.cdiv(M, BLOCK_SIZE_M), triton.cdiv(N, BLOCK_SIZE_N)](\n c, a, b, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K\n )\n\n\n\n" + ] + }, + { + "id": "openseek-8-8a71364b91da4cfca85d866fc864b9f1", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a fused recurrent retention mechanism using Triton kernels for efficient computation. The primary function is 'fused_recurrent_retention', which takes query 'q', key 'k', value 'v', and an optional 'initial_state'. It performs a series of matrix operations across a temporal dimension 'T' to produce an output tensor 'o' and optionally a 'final_state' tensor.\n The operation is carried out by two Triton kernels: 'fused_recurrent_retention_fwd_kernel' for the forward pass and 'fused_recurrent_retention_bwd_kernel' for the backward pass. The forward kernel loads and processes blocks of data from 'q', 'k', 'v', and iteratively updates an accumulator 'h' using a decay factor based on the current head index 'i_h'. If 'USE_INITIAL_STATE' is true, the kernel begins with a preloaded initial state. The result for each timestep is stored in 'o', and if 'STORE_FINAL_STATE' is set, the final state is stored. The backward kernel reverses the process, starting from 'do' to compute gradients for 'q', 'k', 'v' while handling the updates in a reverse order.\n Important constants like 'BK', 'BV' define block sizes, while 'scale' scales the query tensors to normalize outputs. The computation grid is defined by dimensions related to input tensor shapes and Triton's efficient layout strategies.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_retention_fwd_kernel(\n q, k, v, o, initial_state, final_state, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n B, H, T, scale, \n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n # Kernel logic\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = (1 - tl.math.exp2(-5 - i_h * 1.0))\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_o = o + (i_bh + i_k * B * H) * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < DK\n mask_bv = (i_v * BV + tl.arange(0, BV)) < DV\n mask_kv = mask_bk[None, :] & mask_bv[:, None]\n\n h = tl.zeros([BV, BK], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for _ in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n\n h = b_b * h + _k[None, :] * _v[:, None]\n _o = h * _q[None, :]\n _o = tl.sum(_o, axis=1)\n tl.store(p_o, _o.to(p_o.dtype.element_ty), mask=mask_bv)\n\n p_q += DK\n p_k += DK\n p_o += DV\n p_v += DV\n\n if STORE_FINAL_STATE:\n p_final_s = final_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[None, :]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_final_s, h.to(p_final_s.dtype.element_ty), mask=mask_kv)\n\n@triton.jit\ndef fused_recurrent_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, initial_state, \n s_qk_h, s_qk_t, s_qk_d, \n s_vo_h, s_vo_t, s_vo_d, \n B, H, T, scale, \n BK: tl.constexpr, BV: tl.constexpr, DK: tl.constexpr, DV: tl.constexpr, \n USE_INITIAL_STATE: tl.constexpr\n):\n # Kernel logic\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = 1 - tl.math.exp2(-5 - i_h * 1.0)\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK)\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV)\n\n p_dq = dq + (i_bh + i_v * B * H) * s_qk_h + i_k * BK + tl.arange(0, BK)\n mask_bk = i_k * BK + tl.arange(0, BK) < DK\n mask_bv = i_v * BV + tl.arange(0, BV) < DV\n\n h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n mask_kv = mask_bk[:, None] & mask_bv[None, :]\n p_init_s = initial_state + i_bh * DK * DV + \\\n (i_k * BK + tl.arange(0, BK)[:, None]) * \\\n DV + (i_v * BV + tl.arange(0, BV)[None, :])\n h += tl.load(p_init_s, mask=mask_kv, other=0).to(tl.float32)\n\n for i in range(0, T):\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n\n h = b_b * h + _k[:, None] * _v[None, :]\n _d_q = h * _do[None, :]\n d_q = tl.sum(_d_q, axis=1) * scale\n tl.store(p_dq, d_q.to(p_dq.dtype.element_ty), mask=mask_bk)\n\n p_k += DK\n p_do += DV\n p_v += DV\n p_dq += DK\n\n tl.debug_barrier()\n\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (T - 1) * DK\n p_do = do + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_v = v + i_bh * s_vo_h + i_v * BV + tl.arange(0, BV) + (T - 1) * DV\n p_dk = dk + (i_bh + i_v * B * H) * s_qk_h + i_k * \\\n BK + tl.arange(0, BK) + (T - 1) * DK\n p_dv = dv + (i_bh + i_k * B * H) * s_vo_h + i_v * \\\n BV + tl.arange(0, BV) + (T - 1) * DV\n d_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n for _ in range(T):\n _do = tl.load(p_do, mask=mask_bv, other=0).to(tl.float32)\n _q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n _k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n _v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n d_h += _q[:, None] * _do[None, :]\n d_k = tl.sum(d_h * _v[None, :], axis=1)\n d_v = tl.sum(d_h * _k[:, None], axis=0)\n\n d_h *= b_b\n tl.store(p_dk, d_k.to(p_dk.dtype.element_ty), mask=mask_bk)\n tl.store(p_dv, d_v.to(p_dv.dtype.element_ty), mask=mask_bv)\n\n p_do -= DV\n p_q -= DK\n p_k -= DK\n p_v -= DV\n p_dk -= DK\n p_dv -= DV\n\nclass FusedRecurrentRetentionFunction(torch.autograd.Function):\n @staticmethod\n def forward(ctx, q, k, v, initial_state=None, output_final_state=False):\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n\n scale = d_head_qk ** -0.5\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n o = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n\n if output_final_state:\n final_state = q.new_empty(batch_size, n_heads, d_head_qk, d_head_v)\n else:\n final_state = None\n\n grid = (NV, NK, batch_size * n_heads)\n fused_recurrent_retention_fwd_kernel[grid](\n q, k, v, o, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, d_final_state=None):\n q, k, v, initial_state = ctx.saved_tensors\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n scale = d_head_qk ** -0.5\n\n BK, BV = min(d_head_qk, 32), min(d_head_v, 32)\n NK, NV = triton.cdiv(d_head_qk, BK), triton.cdiv(d_head_v, BV)\n num_stages = 1\n num_warps = 1\n\n dq = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dk = q.new_empty(NV, batch_size, n_heads, seq_len, d_head_qk)\n dv = q.new_empty(NK, batch_size, n_heads, seq_len, d_head_v)\n grid = (NV, NK, batch_size * n_heads)\n\n fused_recurrent_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv, initial_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n DK=d_head_qk, DV=d_head_v, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages,\n USE_INITIAL_STATE=initial_state is not None\n )\n dq = dq.sum(0)\n dk = dk.sum(0)\n dv = dv.sum(0)\n return dq, dk, dv, None, None\n\ndef fused_recurrent_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e2b318addf2d4725a561656bcee20ec4", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The 'rmsnorm_forward' function performs Root Mean Square (RMS) normalization on an input tensor using a custom Triton kernel, '_rms_norm_fwd_fused'. The inputs are 'x', a PyTorch tensor with arbitrary dimensions, and 'weight', a tensor of the same size as the last dimension of 'x' for element-wise multiplication after normalization. It outputs 'y', which is a tensor of the same shape as 'x'. Inside the Triton kernel, '_rms_norm_fwd_fused', the function is mapped to each row of the input 'X'. It calculates the variance by loading input elements in blocks of size 'BLOCK_SIZE', summing their squares, and averaging over the number of columns 'N'. The reciprocal standard deviation (rstd) is computed to normalize each element of the input tensor. The normalized values are then multiplied by the corresponding elements of 'weight' and stored in the output tensor 'Y'. The operation is conditioned to handle edge cases when the number of columns is not a perfect multiple of 'BLOCK_SIZE' using masking. Additionally, the 'rmsnorm_forward' function includes heuristics for deciding the number of Triton warps, managing feature size constraints, and arranging the data for efficient kernel execution.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.).to(tl.float32)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n x = tl.load(X + cols, mask=mask, other=0.).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y.to(tl.float16), mask=mask)\n\ndef rmsnorm_forward(x, weight, eps):\n # allocate output\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.view(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n BLOCK_SIZE = 128 * 2 * 2 * 2 * 2 * 2 * 2 * 2\n num_warps = 8\n # enqueue kernel\n _rms_norm_fwd_fused[(M,)](x_arg, y, weight,\n x_arg.stride(0), N, eps,\n BLOCK_SIZE=BLOCK_SIZE, num_warps=num_warps)\n return y\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1e2661c9f8c543a990905e3d122e2ea2", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code provides a Triton kernel named _fwd_kernel_destindex_copy_kv, designed to facilitate the copying of tensor slices based on specified destination indices. This kernel is invoked within the destindex_copy_kv function, a wrapper function designed for PyTorch. The core functionality of the kernel involves processing a 3D input tensor K, representing data across batch size, head number, and head dimension, and transferring selected slices to the output tensor Out using indices from the DestLoc tensor. Each element in DestLoc maps to a position in the batch dimension of Out. The kernel logic iterates over each sequence element, indexed by cur_index, and determines the corresponding slice to copy using DestLoc. Offsets offs_h and offs_d calculate strides for accessing head and depth dimensions. The kernel dynamically adapts to input sizes, employing BLOCK_HEAD and BLOCK_DMODEL as block size parameters, computed as powers of two for optimal performance. The wrapper function orchestrates kernel execution by defining the execution grid as the sequence length, confirming dimensional integrity between K and Out, and invoking the kernel with appropriate stride and dimension parameters. By operating within Triton's just-in-time compilation environment, this code exemplifies an efficient mechanism for parallel data manipulation in deep learning models.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K,\n Dest_loc,\n Out,\n stride_k_bs,\n stride_k_h,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_d,\n head_num,\n head_dim,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0)\n tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim))\n return\n\n\n@torch.no_grad()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K,\n DestLoc,\n Out,\n K.stride(0),\n K.stride(1),\n K.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n head_num,\n head_dim,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n" + ] + }, + { + "id": "openseek-8-c98ff8963d624662a4c340c2d5af4210", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_rotary_kernel` is a Triton JIT-compiled kernel aimed at applying a rotary positional encoding to tensors `Q` and `K`. This is achieved through the use of precomputed cosine and sine values contained in tensors `Cos` and `Sin`. The kernel operates on blocks of heads and sequences defined by compile-time constants `BLOCK_HEAD`, `BLOCK_SEQ`, and `BLOCK_DMODEL`.\n\n The kernel leverages Triton's parallel processing model by executing on a 2D grid. Each block is tasked with processing specific portions of the input based on the current head and sequence indices, which are derived from the program's IDs (`tl.program_id(0)` and `tl.program_id(1)`). Strides and offsets are computed to efficiently access and manipulate the relevant sections of `Q`, `K`, `Cos`, and `Sin`.\n\n The main computational steps involve:\n - Calculating offsets for `Q` and `K` based on their strides and the current block indices.\n - Loading segments of `Q` and `K` using these offsets and applying boundary masks to avoid out-of-bounds accesses.\n - Applying the rotary transformation by combining the loaded values with `Cos` and `Sin` using the rotation formulas: `out0 = q0 * cos0 - q1 * sin0` and `out1 = q0 * sin1 + q1 * cos1`.\n - Storing the transformed segments back to the original locations in `Q` and `K`.\n\n The `rotary_emb_fwd` function acts as a high-level interface for invoking the kernel. It validates input shapes, calculates execution grid dimensions (`grid`), and determines the number of warps to use based on the head dimension size. This function encapsulates all preparation logic required for the kernel execution, ensuring that `Q`, `K`, `Cos`, and `Sin` are correctly aligned with the expected input format of `_rotary_kernel`.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _rotary_kernel(\n Q,\n K,\n Cos,\n Sin,\n stride_qbs,\n stride_qh,\n stride_qd,\n stride_kbs,\n stride_kh,\n stride_kd,\n stride_cosbs,\n stride_cosd,\n stride_sinbs,\n stride_sind,\n max_total_len,\n HEAD_Q,\n HEAD_K,\n BLOCK_HEAD: tl.constexpr,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_head_index = tl.program_id(0)\n cur_seq_index = tl.program_id(1)\n\n cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)\n cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)\n\n dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2\n dim_range1 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + 1\n\n off_q0 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range0[None, None, :] * stride_qd\n )\n off_q1 = (\n cur_seq_range[:, None, None] * stride_qbs\n + cur_head_range[None, :, None] * stride_qh\n + dim_range1[None, None, :] * stride_qd\n )\n\n off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd\n\n q0 = tl.load(\n Q + off_q0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n q1 = tl.load(\n Q + off_q1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),\n other=0.0,\n )\n\n cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out0 = q0 * cos0 - q1 * sin0\n out1 = q0 * sin1 + q1 * cos1\n\n tl.store(\n Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n tl.store(\n Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)\n )\n\n off_k0 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range0[None, None, :] * stride_kd\n )\n off_k1 = (\n cur_seq_range[:, None, None] * stride_kbs\n + cur_head_range[None, :, None] * stride_kh\n + dim_range1[None, None, :] * stride_kd\n )\n\n off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd\n off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd\n\n k0 = tl.load(\n K + off_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n k1 = tl.load(\n K + off_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n other=0.0,\n )\n\n cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)\n\n out_k0 = k0 * cos0 - k1 * sin0\n out_k1 = k0 * sin1 + k1 * cos1\n\n tl.store(\n K + off_k0,\n out_k0,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n tl.store(\n K + off_k1,\n out_k1,\n mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),\n )\n return\n\n@torch.no_grad()\ndef rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0):\n total_len = q.shape[0]\n head_num_q, head_num_k = q.shape[1], k.shape[1]\n head_dim = int(q.shape[2] * partial_rotary_factor)\n assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f\"q shape {q.shape} cos shape {cos.shape}\"\n assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f\"k shape {k.shape} cos shape {cos.shape}\"\n\n BLOCK_SEQ = 16\n BLOCK_HEAD = 4\n if head_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))\n _rotary_kernel[grid](\n q,\n k,\n cos,\n sin,\n q.stride(0),\n q.stride(1),\n q.stride(2),\n k.stride(0),\n k.stride(1),\n k.stride(2),\n cos.stride(0),\n cos.stride(1),\n sin.stride(0),\n sin.stride(1),\n total_len,\n head_num_q,\n head_num_k,\n BLOCK_HEAD=BLOCK_HEAD,\n BLOCK_SEQ=BLOCK_SEQ,\n BLOCK_DMODEL=head_dim,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1f1054d88f2b4f27b2a68166f75d9bb8", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_fwd_kernel_flash_decode_stage2` Triton kernel is a parallel computation designed for processing sequences in a neural network context, specifically dealing with batches, heads, and sequence blocks. This kernel receives several inputs: `B_Seqlen`, `Mid_O`, `Mid_O_LogExpSum`, and `Out`, along with strides for indexing. `B_Seqlen` contains sequence lengths per batch, `Mid_O` contains intermediate outputs, `Mid_O_LogExpSum` holds log-exp sum values, and `Out` will store the final output. The kernel operates over a 2D grid defined by batch size and head count (`grid = (batch, head_num)`), with constants `BLOCK_SEQ` and `BLOCK_DMODEL` indicating sequence block size and dimension alignment respectively.\n\n The kernel function operates as follows:\n - Identifies the current batch and head using `tl.program_id`.\n - Initializes accumulators: `sum_exp`, `max_logic`, and `acc` to accumulate exponential logic and values.\n - Loads the current sequence length and calculates the number of sequence blocks (`block_n_size`).\n - Iterates over each block, where:\n - It loads values (`tv`) from `Mid_O` and logic sums (`tlogic`) from `Mid_O_LogExpSum`.\n - Computes the maximum logic value across blocks and scales previous accumulations.\n - Updates the accumulators by computing the exponential of adjusted logic values and scaling/accumulating.\n - Stores the final normalized result into `Out`, scaling accumulated values by the sum of exponentials.\n\n The `flash_decode_stage2` function sets up and invokes this kernel, determining dimensions and grid setup based on input tensor shapes. It ensures efficient computation by using Triton's parallel execution framework, specifying warp and stage numbers.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_flash_decode_stage2(\n B_Seqlen,\n Mid_O, # [batch, head, seq_block_num, head_dim]\n Mid_O_LogExpSum, # [batch, head, seq_block_num]\n Out, # [batch, head, head_dim]\n stride_mid_ob,\n stride_mid_oh,\n stride_mid_os,\n stride_mid_od,\n stride_mid_o_eb,\n stride_mid_o_eh,\n stride_mid_o_es,\n stride_obs,\n stride_oh,\n stride_od,\n head_dim,\n BLOCK_SEQ: tl.constexpr,\n BLOCK_DMODEL: tl.constexpr,\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n\n block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ\n\n sum_exp = 0.0\n max_logic = -float(\"inf\")\n acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32)\n\n offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d\n offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh\n for block_seq_n in range(0, block_n_size, 1):\n tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0)\n tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n)\n new_max_logic = tl.maximum(tlogic, max_logic)\n\n old_scale = tl.exp(max_logic - new_max_logic)\n acc *= old_scale\n exp_logic = tl.exp(tlogic - new_max_logic)\n acc += exp_logic * tv\n sum_exp = sum_exp * old_scale + exp_logic\n max_logic = new_max_logic\n\n tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim)\n return\n\n@torch.no_grad()\ndef flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq):\n Lk = mid_out.shape[-1]\n head_dim = Lk\n batch, head_num = mid_out.shape[0], mid_out.shape[1]\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (batch, head_num)\n\n _fwd_kernel_flash_decode_stage2[grid](\n B_Seqlen,\n mid_out,\n mid_out_logexpsum,\n Out,\n mid_out.stride(0),\n mid_out.stride(1),\n mid_out.stride(2),\n mid_out.stride(3),\n mid_out_logexpsum.stride(0),\n mid_out_logexpsum.stride(1),\n mid_out_logexpsum.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n head_dim,\n BLOCK_SEQ=block_seq,\n BLOCK_DMODEL=BLOCK_DMODEL,\n num_warps=4,\n num_stages=2,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-afa1281f85334d4aab4dc28ffdd0aff2", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code outlines an implementation for matrix multiplication on GPU, enhanced by an autotuning mechanism which optimizes the kernel for various GPU configurations. The key computational routine is encapsulated in `matmul_kernel`, which multiplies matrices A (MxK) and B (KxN) to produce C (MxN). The kernel is decorated with `@triton.autotune` to select optimal parameters dynamically based on input dimensions (M, N, K).\n\n The `matmul` function acts as a convenient interface, validating input tensor compatibility, ensuring contiguity, and calling the kernel with the right execution grid computed using the `grid` lambda function. The result is stored in a new tensor `c` of shape (M, N).\n\n Core logic inside `matmul_kernel`:\n - Computes grid and block indices to iterate over matrix chunks, using block size constants like `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K`.\n - Pointers (`a_ptrs`, `b_ptrs`) are computed for accessing matrix blocks, leveraging input strides.\n - An accumulator is initialized to zero and filled via dot-product operations across the K dimension.\n - If specified, a leaky ReLU activation is applied, modifying the result where needed.\n - The result matrix C is stored conditionally to handle edge cases on matrix boundaries.\n\n The `leaky_relu` function is conditionally applied, performing element-wise operations to modify negative values in the accumulator. Triton's use of `tl.where` facilitates this transformation efficiently.\n\n Overall, this implementation exemplifies efficient, configurable GPU-based matrix multiplication suitable for diverse GPU architectures.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Autotune configuration functions\ndef is_cuda():\n return triton.runtime.driver.active.get_current_target().backend == \"cuda\"\n\ndef get_cuda_autotune_config():\n return [\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n # Other configurations omitted for brevity\n ]\n\ndef get_hip_autotune_config():\n return [\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 16, 'GROUP_SIZE_M': 1, 'waves_per_eu': 2}, num_warps=4, num_stages=2),\n # Other configurations omitted for brevity\n ]\n\ndef get_autotune_config():\n if is_cuda():\n return get_cuda_autotune_config()\n else:\n return get_hip_autotune_config()\n\n# Triton kernel for matrix multiplication with autotuning\n@triton.autotune(\n configs=get_autotune_config(),\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak, \n stride_bk, stride_bn, \n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, \n GROUP_SIZE_M: tl.constexpr, \n ACTIVATION: tl.constexpr \n):\n \"\"\"Kernel for computing the matrix multiplication C = A x B.\n A has shape (M, K), B has shape (K, N), and C has shape (M, N).\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n \"\"\"Applies the Leaky ReLU activation function.\"\"\"\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n \"\"\"Wrapper function for matrix multiplication using Triton.\n a: Tensor with shape (M, K)\n b: Tensor with shape (K, N)\n activation: Activation function to be applied (optional)\n Returns the result of the matrix multiplication.\n \"\"\"\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c, \n M, N, K, \n a.stride(0), a.stride(1), \n b.stride(0), b.stride(1), \n c.stride(0), c.stride(1), \n ACTIVATION=activation \n )\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-bd877bac67a246f5a2d192ea5f9ae5ee", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code defines two GPU kernels using Triton's JIT compilation to facilitate the manipulation of cosine and sine caches often employed in sequence-based models.\n\n - `prefill_cache_kernel`: This kernel is optimized for pre-filling caches when given sequences (prompts) are involved. It receives parameters like `cos_cache` and `sin_cache`, which are the source caches, and target parameters like `cos_output` and `sin_output` for storing the processed cache sections. The kernel calculates the original sequence index for each target index using a cumulative sum of sequence lengths (`cumsum_lengths`) and extracts respective parts from the caches to fill `cos_output` and `sin_output`. It uses logical indexing based on the length of total sequences and caches specified by `HIDDEN_DIM`, `N_ELEMENTS`, and `BLOCK_SIZE`.\n\n - `decoding_cache_kernel`: Targeted at decoding cache data based on sequence lengths, this kernel takes in parameters like `cos_cache`, `sin_cache`, and `lengths`, alongside output buffers `cos_output` and `sin_output`. It retrieves the previous cache entries for a given sequence using its length-1 value to index into the cache, ensuring data continuity. Important constants here include `HIDDEN_DIM`, `NUM_SEQS`, and `BLOCK_SIZE`.\n\n The utility function `get_xine_cache` determines which kernel to execute based on the `is_prompts` flag, setting up computational grids and preparing tensors for output storage. For prompt processing, it aggregates sequence lengths to derive `total_length` and employs the `prefill_cache_kernel`. In non-prompt scenarios, it computes indices directly from `lengths` and engages the `decoding_cache_kernel`. This setup ensures flexible handling of cache operations across varied sequence lengths and types.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef prefill_cache_kernel(\n cos_cache,\n sin_cache,\n cumsum_lengths,\n cos_output,\n sin_output,\n cache_stride,\n hidden_stride,\n total_length,\n HIDDEN_DIM: tl.constexpr,\n N_ELEMENTS: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n idx0 = tl.program_id(axis=0)\n idx1 = tl.program_id(axis=1)\n idx = idx0 * BLOCK_SIZE + idx1\n\n # original seq_idx and pos\n cumsum_lens = tl.load(cumsum_lengths + tl.arange(0, N_ELEMENTS))\n ori_seq_idx = idx - tl.max(tl.where(cumsum_lens <= idx, cumsum_lens, 0))\n cos_cache_part = tl.load(\n cos_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length\n )\n sin_cache_part = tl.load(\n sin_cache + ori_seq_idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride, mask=idx < total_length\n )\n tl.store(\n cos_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,\n cos_cache_part,\n mask=idx < total_length,\n )\n tl.store(\n sin_output + idx * cache_stride + tl.arange(0, HIDDEN_DIM) * hidden_stride,\n sin_cache_part,\n mask=idx < total_length,\n )\n\n\n@triton.jit\ndef decoding_cache_kernel(\n cos_cache,\n sin_cache,\n lengths,\n cos_output,\n sin_output,\n cache_stride,\n hidden_stride,\n HIDDEN_DIM: tl.constexpr,\n NUM_SEQS: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n idx = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n ori_seq_idx = tl.load(lengths + idx, mask=(idx < NUM_SEQS), other=None) # [BLOCK_SIZE,]\n cos_cache_part = tl.load(\n cos_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,\n mask=idx[:, None] < NUM_SEQS,\n )\n sin_cache_part = tl.load(\n sin_cache + ori_seq_idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride,\n mask=idx[:, None] < NUM_SEQS,\n )\n tl.store(\n cos_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),\n cos_cache_part,\n mask=idx[:, None] < NUM_SEQS,\n )\n tl.store(\n sin_output + (idx[:, None] * cache_stride + tl.arange(0, HIDDEN_DIM)[None, :] * hidden_stride),\n sin_cache_part,\n mask=idx[:, None] < NUM_SEQS,\n )\n\n\ndef get_xine_cache(lengths: torch.Tensor, cos_cache: torch.Tensor, sin_cache: torch.Tensor, is_prompts: bool = False):\n assert cos_cache.shape[1] == sin_cache.shape[1]\n _, hidden_dim = cos_cache.shape\n num_seqs = lengths.numel()\n\n if hidden_dim >= 256:\n num_warps = 16\n elif hidden_dim >= 128:\n num_warps = 8\n else:\n num_warps = 4\n\n cache_stride = cos_cache.stride(0)\n hidden_stride = cos_cache.stride(1)\n\n if is_prompts:\n BLOCK_SIZE = 16\n total_length = lengths.sum().item()\n cumsum_lens = torch.cumsum(lengths, dim=0)\n cos_output = torch.empty((total_length, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)\n sin_output = torch.empty((total_length, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)\n grid = (triton.cdiv(total_length, BLOCK_SIZE), BLOCK_SIZE)\n prefill_cache_kernel[grid](\n cos_cache,\n sin_cache,\n cumsum_lens,\n cos_output,\n sin_output,\n cache_stride,\n hidden_stride,\n total_length,\n HIDDEN_DIM=hidden_dim,\n N_ELEMENTS=triton.next_power_of_2(num_seqs),\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n else:\n BLOCK_SIZE = 4\n nlengths = torch.as_tensor(lengths) - 1\n cos_output = torch.empty((num_seqs, hidden_dim), dtype=cos_cache.dtype, device=cos_cache.device)\n sin_output = torch.empty((num_seqs, hidden_dim), dtype=sin_cache.dtype, device=sin_cache.device)\n grid = (triton.cdiv(num_seqs, BLOCK_SIZE),)\n decoding_cache_kernel[grid](\n cos_cache,\n sin_cache,\n nlengths,\n cos_output,\n sin_output,\n cache_stride,\n hidden_stride,\n HIDDEN_DIM=hidden_dim,\n NUM_SEQS=num_seqs,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n\n return cos_output, sin_output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e950988001b1417498ac2e6dd2280b96", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code utilizes Triton to accelerate parallel attention mechanisms, commonly used in transformer models, on GPUs. It introduces `parallel_retention_fwd_kernel` and `parallel_retention_bwd_kernel` as the core computational kernels. These kernels employ Triton block pointers for optimized data access patterns and execute parallel computations over the sequence length and attention heads. The forward kernel calculates the attention outputs by performing scaled dot-product operations between the query (`q`), key (`k`), and value (`v`) tensors. It incorporates decay factors computed from the head index to adjust the attention weights, ensuring numerical stability and enhancing performance. The output (`o`) tensor is computed through accumulation and rescaling steps involving iterative block-level operations. The `ParallelRetentionFunction` class defines custom autograd functions for the forward and backward passes using the kernels. The `forward` method precomputes constants, including block sizes and scaling factors, and launches the forward kernel, while preserving the inputs for gradient computation. In contrast, the `backward` method leverages the stored inputs and the `do` (gradient of the output) tensor to compute gradients of the inputs (`dq`, `dk`, `dv`). It does so by invoking two sub-kernels: `_parallel_retention_bwd_dq` for the query gradient and `_parallel_retention_bwd_dkv` for the key and value gradients. The implementation carefully manages tensor dimensions and strides, ensuring contiguous memory accesses for efficient GPU execution. Overall, the design showcases advanced techniques in GPU programming, balancing performance with clarity through modular function decomposition.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nfrom torch.amp import custom_bwd, custom_fwd\nimport functools\n\ndef contiguous(fn):\n @functools.wraps(fn)\n def wrapper(ctx, *args, **kwargs):\n return fn(ctx,\n *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),\n **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})\n return wrapper\n\n@triton.jit\ndef parallel_retention_fwd_kernel(\n # B: batch_size, H: n_heads, T: seq_len, D: d_head\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n B, # batch size\n H, # n_heads\n T, # seq_len\n scale, # D_head_K ** -0.5\n BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q\n BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n DK: tl.constexpr, # D_head_K\n DV: tl.constexpr, # D_head_V\n):\n # i_c: chunk index. used for sequence parallelism\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n # decay rate given the head index\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n # cumulative decay from the end of the chunk\n o_k = tl.arange(0, BTS)\n d_h = tl.math.exp2((BTS - o_k) * b_b)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n # [BQ, BD] block Q, in the shared memory throughout the whole kernel\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n\n # Q block and K block have no overlap\n # no need for mask, thereby saving flops\n for _ in range(0, i_c * BTL, BTS):\n # [BK, BTS]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BTS, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n b_s = tl.dot(b_q, (b_k), allow_tf32=False) * d_h[None, :]\n # [BQ, BD]\n b_o = b_o * tl.math.exp2(b_b * BTS)\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n # # rescale interchunk output\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n d_q = tl.math.exp2(tl.arange(0, BTL) * b_b)\n b_o *= d_q[:, None]\n # # sync threads, easy for compiler to optimize\n # tl.debug_barrier()\n\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (DK, T),\n (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n # Q block and K block have overlap. masks required\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n # [BK, BTS]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BTS, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_s = tl.dot(b_q, b_k, allow_tf32=False) * d_s\n # [BTL, BV]\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, DV),\n (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n # decay rate given the head index\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n # overall decay rate for an entire block\n d_b = tl.math.exp2(b_b * BTS)\n # cumulative decay from the end of the chunk\n d_h = tl.math.exp2((BTS - tl.arange(0, BTS)) * b_b)\n for _ in range(0, i_c * BTL, BTS):\n # [BTS, BK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BV, BTS]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_h[None, :]\n # [BQ, BD]\n b_dq *= d_b\n b_dq += tl.dot(b_ds.to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n b_dq *= tl.math.exp2(tl.arange(0, BTL) * b_b)[:, None] * scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (DV, T),\n (s_vo_d, s_vo_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n # Q block and K block have overlap. masks required\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n # [BTS, BK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BV, BTS]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n m_s = o_q[:, None] >= o_k[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (o_q[:, None] - o_k[None, :]) * b_b), 0)\n b_ds = tl.dot(b_do, b_v, allow_tf32=False) * d_s * scale\n # [BTL, BK]\n b_dq += tl.dot(b_ds.to(b_k.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_qk_h, (T, DK),\n (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n # no overlap. no need for mask.\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n # overall decay rate for an entire block\n d_b = tl.math.exp2(b_b * BTS)\n # compute dk dv\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, DK), (s_qk_t, s_qk_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, DV), (s_vo_t, s_vo_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n d_h = tl.math.exp2((BTL - tl.arange(0, BTL)) * b_b)\n b_kd = (b_k * d_h[:, None]).to(b_k.dtype)\n d_q = tl.math.exp2(tl.arange(0, BTS) * b_b)\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]\n b_do = tl.load(p_do, boundary_check=(0, 1)) # [BV, BTS]\n b_do = (b_do * d_q[None, :]).to(b_do.dtype)\n\n b_dv *= d_b\n b_s = tl.dot(b_kd.to(b_q.dtype), b_q, allow_tf32=False) # [BTL, BTS]\n b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n\n b_dk *= d_b\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n b_dk *= d_h[:, None] * scale\n b_dv *= scale\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_qk_h, (DK, T), (s_qk_d, s_qk_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_vo_h, (DV, T), (s_vo_d, s_vo_t), (i_v * BV, i), (BV, BTS), (0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BK, BQ]\n m_s = o_k[:, None] <= o_q[None, :]\n d_s = tl.where(m_s, tl.math.exp2(\n (-o_k[:, None] + o_q[None, :]) * b_b.to(tl.float32)), 0) * scale\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * d_s\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * d_s\n # [BK, BD]\n b_dk += tl.dot(b_ds.to(b_q.dtype), tl.trans(b_q), allow_tf32=False)\n b_dv += tl.dot(b_s.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n o_q += BTS\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_qk_h,\n (T, DK), (s_qk_t, s_qk_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_vo_h,\n (T, DV), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n\n@triton.jit\ndef parallel_retention_bwd_kernel(\n q, k, v, do, dq, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale,\n BTL: tl.constexpr, BTS: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n DK: tl.constexpr, DV: tl.constexpr,\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(DV, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_retention_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n k, v, do, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=DK, DV=DV\n )\n tl.debug_barrier()\n _parallel_retention_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, B, H, T, scale, BTL, BTS, BK, BV, DK, DV\n )\n\n\nclass ParallelRetentionFunction(torch.autograd.Function):\n @staticmethod\n @contiguous\n @custom_fwd(device_type='cuda')\n def forward(ctx, q, k, v):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n o = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n parallel_retention_fwd_kernel[grid](\n q, k, v, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n return o.sum(0).to(q.dtype)\n\n @staticmethod\n @contiguous\n @custom_bwd(device_type='cuda')\n def backward(ctx, do):\n q, k, v = ctx.saved_tensors\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n batch_size, n_heads, seq_len, d_head_qk = q.shape\n d_head_v = v.shape[-1]\n num_stages = 3 if d_head_qk <= 64 else 2\n num_warps = 4\n NK = triton.cdiv(d_head_qk, BK)\n NV = triton.cdiv(d_head_v, BV)\n grid = (NK * NV, triton.cdiv(seq_len, BTL), batch_size * n_heads)\n scale = d_head_qk ** -0.5\n\n dq = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, batch_size, n_heads, seq_len,\n d_head_qk, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, batch_size, n_heads, seq_len,\n d_head_v, dtype=q.dtype, device=q.device)\n\n parallel_retention_bwd_kernel[grid](\n q, k, v, do, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n batch_size, n_heads, seq_len, scale,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV, DK=d_head_qk, DV=d_head_v,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype)\n\n\nparallel_retention = ParallelRetentionFunction.apply\n\n\n\n" + ] + }, + { + "id": "openseek-8-c0aa2a9ca6d341d49843fb143e4827ea", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a fast RMS Layernorm operation using the Triton language for GPU acceleration. It contains three main kernels: `_rms_layernorm_forward`, `_rms_layernorm_backward`, and `_gemma_rms_layernorm_forward`. The `_rms_layernorm_forward` kernel computes the row-wise variance, calculates the inverse square root of the variance, and normalizes the input data using this inverse variance. The result is scaled by a weight vector to produce the output. The `_rms_layernorm_backward` kernel computes gradients with respect to the input and the weight based on the chain rule. The `_gemma_rms_layernorm_forward` kernel is similar to `_rms_layernorm_forward` but includes an additional constant `1.0` added to the weight during the scaling operation. These kernels are used within the `Fast_RMS_Layernorm` PyTorch autograd function. The forward method calculates the output of the layer normalization and saves the relevant variables for the backward pass, while the backward method computes the gradients with respect to the inputs. Common variables include `BLOCK_SIZE`, which determines the computational block size, and `num_warps`, indicating the number of GPU warps used for parallel computation. The function `calculate_settings` decides these values based on input dimensions. The `fast_rms_layernorm` function is a helper that applies the `Fast_RMS_Layernorm` on the given input and layernorm configuration. This setup allows for efficient layer normalization computation on GPUs using custom Triton kernels, aimed at providing significant speedup over standard implementations.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\nimport torch.nn as nn\n\n\nnext_power_of_2 = triton.next_power_of_2\nMAX_FUSED_SIZE : int = 65536\n\ndef calculate_settings(n : int) -> (int, int,):\n BLOCK_SIZE : int = next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n num_warps : int = 4\n if BLOCK_SIZE >= 32768: num_warps = 32\n elif BLOCK_SIZE >= 8192: num_warps = 16\n elif BLOCK_SIZE >= 2048: num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef _rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE: tl.constexpr\n):\n \"\"\"\n Fast RMS Layernorm kernel\n Inspiration from a Triton tutorial:\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask=mask, other=0)\n\n row_var = tl.sum(X_row * X_row, axis=0) / n_cols\n inv_var = tl.math.rsqrt(row_var + eps)\n tl.store(r, inv_var)\n normed = X_row * inv_var\n normed = normed.to(W_row.dtype)\n output = normed * W_row\n tl.store(Y + col_offsets, output, mask=mask)\n\n@triton.heuristics({\"GEMMA\": lambda args: args[\"GEMMA\"],})\n@triton.jit\ndef _rms_layernorm_backward(\n dY, dY_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n dW, dW_row_stride,\n n_cols, eps,\n GEMMA: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n):\n \"\"\"\n Fast RMS Layernorm kernel for the backward pass\n Inspiration from a Triton tutorial:\n https://triton-lang.org/main/getting-started/tutorials/05-layer-norm.html\n \"\"\"\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dY += row_idx * dY_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n dY_row = tl.load(dY + col_offsets, mask=mask, other=0).to(tl.float32)\n X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)\n\n inv_var = tl.load(r).to(tl.float32)\n normed = X_row * inv_var\n\n if GEMMA: dY_W = dY_row * (W_row + 1.0)\n else: dY_W = dY_row * W_row\n\n rowsum_dY_normed = tl.sum(dY_W * normed, axis=0)\n output = inv_var/n_cols * (n_cols*dY_W - normed*rowsum_dY_normed)\n tl.store(dY + col_offsets, output, mask=mask)\n\n@triton.jit\ndef _gemma_rms_layernorm_forward(\n Y, Y_row_stride,\n X, X_row_stride,\n W, W_row_stride,\n r, r_row_stride,\n n_cols, eps,\n BLOCK_SIZE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n Y += row_idx * Y_row_stride\n X += row_idx * X_row_stride\n r += row_idx * r_row_stride\n\n X_row = tl.load(X + col_offsets, mask=mask, other=0).to(tl.float32)\n W_row = tl.load(W + col_offsets, mask=mask, other=0).to(tl.float32)\n\n row_var = tl.sum(X_row * X_row, axis=0) / n_cols\n inv_var = tl.math.rsqrt(row_var + eps)\n tl.store(r, inv_var)\n normed = X_row * inv_var\n output = normed * (W_row + 1.0)\n\n tl.store(Y + col_offsets, output, mask=mask)\n\nclass Fast_RMS_Layernorm(torch.autograd.Function):\n @staticmethod\n def forward(ctx, X, W, eps, gemma=False):\n shape = X.shape\n dim = shape[-1]\n X = X.view(-1, dim)\n n_rows, n_cols = X.shape\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n Y = torch.empty((n_rows, n_cols), dtype=X.dtype, device=\"cuda:0\")\n r = torch.empty(n_rows, dtype=torch.float32, device=\"cuda:0\")\n\n fx = _gemma_rms_layernorm_forward if gemma else _rms_layernorm_forward\n fx[(n_rows,)](\n Y, Y.stride(0),\n X, X.stride(0),\n W, W.stride(0),\n r, r.stride(0),\n n_cols, eps,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n ctx.eps = eps\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.GEMMA = gemma\n ctx.save_for_backward(X, W, r)\n return Y.view(*shape)\n\n @staticmethod\n def backward(ctx, dY):\n shape = dY.shape\n dim = shape[-1]\n dY = dY.view(-1, dim)\n X, W, r = ctx.saved_tensors\n n_rows, n_cols = dY.shape\n dW = X\n\n _rms_layernorm_backward[(n_rows,)](\n dY, dY.stride(0),\n X, X.stride(0),\n W, W.stride(0),\n r, r.stride(0),\n dW, dW.stride(0),\n n_cols, ctx.eps,\n GEMMA=ctx.GEMMA,\n BLOCK_SIZE=ctx.BLOCK_SIZE,\n num_warps=ctx.num_warps,\n )\n dX = dY.view(*shape)\n return dX, None, None, None\n\ndef fast_rms_layernorm(layernorm, X, gemma=False):\n W = layernorm.weight\n eps = layernorm.variance_epsilon if \\\n hasattr(layernorm, \"variance_epsilon\") \\\n else layernorm.eps\n out = Fast_RMS_Layernorm.apply(X, W, eps, gemma)\n return out\n\n\n# Define a simple layernorm module for testing\nclass SimpleLayerNorm(nn.Module):\n def __init__(self, normalized_shape, eps=1e-5):\n super(SimpleLayerNorm, self).__init__()\n self.weight = nn.Parameter(torch.ones(normalized_shape).cuda())\n self.eps = eps\n\n\n\n" + ] + }, + { + "id": "openseek-8-f6b2a229b984401698e42742237794c8", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements a Triton kernel for the forward pass of an attention operation, typically used in transformers. The `forward` function receives the query (`q`), key (`k`), and value (`v`) tensors, as well as their scaling factors (`q_scale`, `k_scale`). It calculates the attention output `o` through block-wise parallel processing.\n\n The primary Triton kernel `_attn_fwd` orchestrates the attention mechanism computation. It executes in parallel over multiple blocks defined by `BLOCK_M` and `BLOCK_N`. The kernel sets up index offsets and pointers for `Q`, `K`, and `V` tensors, used to load data efficiently into shared memory. Each thread block handles a subset of the input tensor to compute its contribution to the attention mechanism.\n\n Inside `_attn_fwd`, the auxiliary kernel `_attn_fwd_inner` performs the main computational loop. It processes blocks of keys and values, calculating their dot products with the queries, scaling them according to `q_scale` and `k_scale`, and applying a maximum reduction for numerical stability. It converts the result into exponentials and normalizes them by a sum, which resembles a softmax operation.\n\n During each iteration, `acc` stores the accumulated weighted sums of the values, while `l_i` stores the normalization factors. `m_i` is used to track the maximum values for numerical stability. The kernel accumulates the results in `acc`, divides by `l_i` for normalization, and writes back to the output tensor `o`.\n\n The kernel is launched with a grid configuration based on the dimensions of the query tensor, and the function ultimately returns the output tensor `o`.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Triton kernel for the inner loop of attention forward pass\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, q_scale, \n K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, \n N_CTX: tl.constexpr):\n lo, hi = 0, N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k_mask = (offs_n[None, :] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[:, None])\n k = tl.load(K_ptrs, mask = k_mask)\n k_scale = tl.load(K_scale_ptr)\n qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_ptrs, mask = (offs_n[:, None] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[None, :]))\n p = p.to(tl.float16)\n \n acc += tl.dot(p, v.to(tl.float16), out_dtype=tl.float16) \n m_i = m_ij\n K_ptrs += BLOCK_N * HEAD_DIM\n K_scale_ptr += 1\n V_ptrs += BLOCK_N * HEAD_DIM\n return acc, l_i\n\n# Triton kernel for the attention forward pass\n@triton.jit\ndef _attn_fwd(Q, K, V, Q_scale, K_scale, Out, \n stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vk, stride_vn, \n stride_oz, stride_oh, stride_om, stride_on, \n Z, H, N_CTX, \n HEAD_DIM: tl.constexpr, \n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr \n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n \n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n vk_offset = qvk_offset // stride_qm\n q_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_M)\n k_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_N) \n \n\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, 128)\n Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n Q_scale_ptr = Q_scale + q_scale_offset + start_m\n K_ptrs = K + qvk_offset + offs_k[:, None] + offs_n[None, :] * stride_kn\n K_scale_ptr = K_scale + k_scale_offset\n V_ptrs = V + qvk_offset + offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk\n O_block_ptr = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, 128], dtype=tl.float32)\n q = tl.load(Q_ptrs, mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n q_scale = tl.load(Q_scale_ptr)\n acc, l_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 4 - STAGE, offs_m, offs_n, N_CTX \n )\n acc = acc / l_i[:, None]\n tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n\n# Function to call the Triton kernel\ndef forward(q, k, v, q_scale, k_scale):\n BLOCK_M = 128\n BLOCK_N = 64\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n o = torch.empty_like(q, dtype=torch.bfloat16)\n stage = 1\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, q_scale, k_scale, o, \n q.stride(0), q.stride(1), q.stride(2), q.stride(3), \n k.stride(0), k.stride(1), k.stride(2), k.stride(3), \n v.stride(0), v.stride(1), v.stride(2), v.stride(3), \n o.stride(0), o.stride(1), o.stride(2), o.stride(3), \n q.shape[0], q.shape[1], \n N_CTX=q.shape[2], \n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, \n STAGE=stage, \n num_warps=8, \n num_stages=3)\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-a95e6f64940e4e06a3a9d1f48262df82", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rbe_triton` kernel applies a position-dependent transformation to a 3D input tensor `x` of shape [batch, M, K], performing operations on real and imaginary parts separately. Each thread block processes a rectangular section of the tensor determined by the block size, `BLOCK_SIZE_M` (2) for the M dimension and `BLOCK_SIZE_K` (1024) for the K dimension. The kernel uses `program_id` to compute the offsets `offs_m` and `offs_n` for each block, loading data from `x` into local variables `real` and `imag`, considering only even-numbered indices for real components. Sine and cosine values for position-dependent transformations are precomputed via `get_freq_multi_tokens`, with `theta=10000`. These frequency values are used to transform `real` and `imag` into `out_real` and `out_imag`, which are then written back to the output tensor `out` using calculated offsets and appropriate masks. The `rbe_triton_wrapper` function prepares the parameters, defines the execution grid based on input dimensions, and launches the kernel, managing data flow and maintaining tensor consistency.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\ntorch.manual_seed(1234)\n\n\n@triton.jit\ndef get_freq_multi_tokens(offs_cn, starting_idx, theta: tl.constexpr, NB_TOKENS: tl.constexpr):\n DIM: tl.constexpr = 128 # in model, dim = self.params.dim // self.params.n_heads\n freqs = offs_cn % DIM\n freqs = freqs.to(tl.float32) / DIM\n freqs = tl.extra.cuda.libdevice.pow(theta, freqs)\n freqs = (tl.arange(0, NB_TOKENS) + starting_idx)[:, None] / freqs[None, :]\n return tl.cos(freqs), tl.sin(freqs)\n\n\n@triton.jit\ndef rbe_triton(x_ptr, out_ptr,\n M, K,\n stride_x_batch, stride_x_m, stride_x_n,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n THETA: tl.constexpr, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(K, BLOCK_SIZE_K)\n pid_n = pid % tl.cdiv(K, BLOCK_SIZE_K)\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K // 2) * 2 # take only even numbers\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + stride_x_m * offs_m[:, None] + stride_x_n * offs_n[None, :])\n x_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n real = tl.load(x_ptrs, mask=x_real_mask, other=0.0)\n x_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n imag = tl.load(x_ptrs + 1, mask=x_imag_mask, other=0.0)\n tl.debug_barrier()\n start_block = start_token_position + pid_m * BLOCK_SIZE_M\n cos, sin = get_freq_multi_tokens(offs_cn=offs_n, starting_idx=start_block, theta=THETA, NB_TOKENS=BLOCK_SIZE_M)\n\n out_real = real * cos - imag * sin\n out_imag = real * sin + imag * cos\n tl.debug_barrier()\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + stride_out_m * offs_m[:, None] + stride_out_n * offs_n[None, :])\n out_real_mask = (offs_m[:, None] < M) & (offs_n[None, :] < K)\n tl.store(out_ptrs, out_real, mask=out_real_mask)\n out_imag_mask = (offs_m[:, None] < M) & (1 + offs_n[None, :] < K)\n tl.store(out_ptrs + 1, out_imag, mask=out_imag_mask)\n\n\ndef rbe_triton_wrapper(x: torch.Tensor, pos: int) -> torch.Tensor:\n batch, M, K = x.shape\n out = torch.empty_like(x)\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"K\"], META[\"BLOCK_SIZE_K\"]),)\n\n rbe_triton[grid](x, out,\n M, K,\n *x.stride(),\n *out.stride(),\n start_token_position=pos, THETA=10000., BLOCK_SIZE_M=2, BLOCK_SIZE_K=1024)\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-bc17a717211e438d96e0c54e6847bdea", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator defines a matrix multiplication function `iv_dependent_matmul_kernel`, which is executed on a GPU using the Triton framework. The kernel function takes pointers to matrices `a`, `b`, and `c`, along with their dimensions `M`, `N`, `K`, and strides for memory layout. It also accepts block sizes `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K` and a string `type` to define different loading strategies. The kernel uses a double loop strategy to load and compute block matrices from `a` and `b`, storing results in `c`, based on specified offsets calculated from program ID `pid`. The function supports conditional logic for various data preloading strategies to optimize memory access patterns. The wrapper function `iv_dependent_matmul_wrapper` handles setup, including device selection, random matrix generation, and grid configuration, to invoke the kernel with specific parameters. It uses the Triton grid configuration to determine execution dimensions and sets the number of pipeline stages depending on the scheduling type. The result is stored in `triton_output`.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef iv_dependent_matmul_kernel(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr, #\n type: tl.constexpr):\n pid = tl.program_id(axis=0)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n pid_m = pid // num_pid_n\n pid_n = pid % num_pid_n\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptr = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptr = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n a_ptrs = a_ptr\n b_ptrs = b_ptr\n if type == \"post_load_two_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs_next = a_ptr + BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + BLOCK_SIZE_K * stride_bk\n a_ptrs_next_next = a_ptr + 2 * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + 2 * BLOCK_SIZE_K * stride_bk\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n if type == \"pre_load\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + k * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n a_ptrs = a_ptr + k * BLOCK_SIZE_K * stride_ak\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n if type == \"post_load\":\n a_ptrs = a_ptr + (k + 1) * BLOCK_SIZE_K * stride_ak\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_pre_mixed\":\n b_ptrs = b_ptr + (k + 1) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_two_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptr + (k + 2) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next = b_ptr + (k + 2) * BLOCK_SIZE_K * stride_bk\n elif type == \"post_load_three_iters\":\n a_ptrs = a_ptrs_next\n b_ptrs = b_ptrs_next\n a_ptrs_next = a_ptrs_next_next\n b_ptrs_next = b_ptrs_next_next\n a_ptrs_next_next = a_ptr + (k + 3) * BLOCK_SIZE_K * stride_ak\n b_ptrs_next_next = b_ptr + (k + 3) * BLOCK_SIZE_K * stride_bk\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n\ndef iv_dependent_matmul_wrapper(\n M: int, \n K: int, \n N: int, \n BLOCK_SIZE_M: int, \n BLOCK_SIZE_N: int, \n BLOCK_SIZE_K: int, \n type: str = \"pre_load\", # Kernel type for scheduling (\"pre_load\", \"post_load\", etc.)\n device: torch.device = \"cuda\" # Device to run the test (defaults to \"cuda\")\n):\n # Ensure the device is correct\n device = torch.device(device)\n\n # Generate random input matrices a and b on the specified device\n a = torch.rand((M, K), device=device)\n b = torch.rand((K, N), device=device)\n\n # Create an empty tensor to store the Triton result\n triton_output = torch.empty((M, N), device=device)\n\n # Define Triton grid configuration\n def grid(META):\n return (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n\n # Set the number of stages based on the kernel type\n num_stages = 4 if type == \"post_load_three_iters\" else 3\n\n # Run the Triton kernel\n iv_dependent_matmul_kernel[grid](\n a, b, triton_output, M, N, K, #\n a.stride(0), a.stride(1), b.stride(0), b.stride(1), #\n triton_output.stride(0), triton_output.stride(1), #\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K, type=type, #\n num_stages=num_stages\n )\n\n # Optionally print the result for inspection\n # print(triton_output)\n\n return triton_output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4cd02a9342be487da920efacf2156ae2", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel `_swiglu_bwd_kernel` is an efficient implementation of the backward pass for Swish-Gated Linear Units (Swiglu). It operates on input tensors `X` and `Y`, their respective gradients `DX` and `DY`, a derivative tensor `DOUT`, and optionally the output tensor `OUT` for recomputation. The kernel maps program IDs to row indices and calculates column indices for block processing, defined by the `BLOCK_N` parameter. It loads slices of the input tensors using masks to handle boundary conditions. Core computations use the sigmoid function to derive the gradients, applying the Swish derivative formula for `dx` and straightforward multiplication for `dy`. These results are stored back to the gradients' memory locations. The kernel utilizes `RECOMPUTE_OUTPUT` to optionally recompute and store the output tensor. The wrapper function `_swiglu_bwd` is responsible for input preprocessing, ensuring tensors are contiguous, reshaping for batch dimensions, splitting `xy` into `x` and `y`, and setting up the execution grid that determines the number of Triton programs launched. It manages the lifecycle of intermediate tensors and adapts outputs based on the `recompute_output` flag.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Backward Triton kernel for Swish-Gated Linear Units (Swiglu)\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_N': 32}),\n triton.Config({'BLOCK_N': 64}),\n triton.Config({'BLOCK_N': 128}),\n triton.Config({'BLOCK_N': 256}),\n triton.Config({'BLOCK_N': 512}),\n triton.Config({'BLOCK_N': 1024}),\n ],\n key=['ncols'],\n)\n@triton.heuristics({\"RECOMPUTE_OUTPUT\": lambda args: args[\"OUT\"] is not None})\n@triton.jit\ndef _swiglu_bwd_kernel(\n X, Y, DOUT, OUT, DX, DY, stride_x_row, stride_y_row, stride_dout_row,\n stride_out_row, stride_dx_row, stride_dy_row, ncols, BLOCK_N: tl.constexpr,\n RECOMPUTE_OUTPUT: tl.constexpr\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n start_col = tl.program_id(1) * BLOCK_N\n X += row * stride_x_row\n Y += row * stride_y_row\n DOUT += row * stride_dout_row\n if RECOMPUTE_OUTPUT:\n OUT += row * stride_out_row\n DX += row * stride_dx_row\n DY += row * stride_dy_row\n cols = start_col + tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < ncols, other=0.).to(tl.float32)\n y = tl.load(Y + cols, mask=cols < ncols, other=0.).to(tl.float32)\n dout = tl.load(DOUT + cols, mask=cols < ncols, other=0.).to(tl.float32)\n x_sigmoid = tl.sigmoid(x)\n dx = x_sigmoid * (1 + x * (1 - x_sigmoid)) * y * dout\n dy = x * x_sigmoid * dout\n tl.store(DX + cols, dx, mask=cols < ncols)\n tl.store(DY + cols, dy, mask=cols < ncols)\n if RECOMPUTE_OUTPUT:\n out = x * x_sigmoid * y\n tl.store(OUT + cols, out, mask=cols < ncols)\n\n# Function to invoke the backward kernel\ndef _swiglu_bwd(xy, dout, dxy=None, recompute_output=False, out=None):\n if xy.stride(-1) != 1:\n xy = xy.contiguous()\n if dout.stride(-1) != 1:\n dout = dout.contiguous()\n batch_shape = xy.shape[:-1]\n xy = xy.reshape(-1, xy.shape[-1])\n x, y = xy.chunk(2, dim=-1)\n dout = dout.reshape(-1, dout.shape[-1])\n assert dout.shape == x.shape\n if dxy is None:\n dxy = torch.empty_like(xy)\n else:\n dxy = dxy.reshape(-1, dxy.shape[-1])\n assert dxy.shape == xy.shape\n dx, dy = dxy.chunk(2, dim=-1)\n assert dx.stride(-1) == 1\n assert dy.stride(-1) == 1\n if recompute_output:\n if out is None:\n out = torch.empty_like(x)\n else:\n out = out.reshape(-1, out.shape[-1])\n assert out.shape == x.shape\n assert out.stride(-1) == 1\n M, N = x.shape\n grid = lambda META: (M, triton.cdiv(N, META['BLOCK_N']))\n with torch.cuda.device(x.device.index):\n _swiglu_bwd_kernel[grid](\n x, y, dout, out if recompute_output else None, dx, dy, x.stride(0), y.stride(0),\n dout.stride(0), out.stride(0) if recompute_output else 0, dx.stride(0),\n dy.stride(0), N\n )\n if not recompute_output:\n return dxy.reshape(*batch_shape, dxy.shape[-1])\n else:\n return dxy.reshape(*batch_shape, dxy.shape[-1]), out.reshape(*batch_shape, out.shape[-1])\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-f44c967eedd24797b04e2624e91869b7", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton operator is defined to transpose a matrix using a kernel function and a wrapper function. The kernel function named 'kernel' is decorated with '@triton.jit' for just-in-time compilation and performs matrix transposition by directly manipulating pointers based on the given strides and dimensions. It accepts input parameters such as a matrix 'M', an output buffer 'Out', the strides of 'M' and 'Out', and the dimensions 'SIZE_M' and 'D_HEAD'. The kernel computes the pointers for elements of 'M' using 'matrix_stridex' and 'matrix_stridey', and for 'Out' using 'out_stridex' and 'out_stridey'. The transposition is achieved by loading elements from 'M' and storing them into 'Out' in a transposed layout. The wrapper function named 'wrapper' initializes 'matrix' with random float16 values and 'out' with zeros, both on CUDA. It defines the grid configuration as a tuple with a single element, then calls the kernel with these matrices and their properties. Finally, it returns the transposed matrix 'out'.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef kernel(\n M,\n Out,\n matrix_stridex,\n matrix_stridey,\n out_stridex,\n out_stridey,\n SIZE_M: tl.constexpr,\n D_HEAD: tl.constexpr,\n):\n size_m_arange = tl.arange(0, SIZE_M)\n d_head_arange = tl.arange(0, D_HEAD)\n # transpose\n matrix_ptr = M + d_head_arange[None, :] * matrix_stridey + size_m_arange[:, None] * matrix_stridex\n out_ptr = Out + d_head_arange[None, :] * out_stridex + size_m_arange[:, None] * out_stridey\n matrix = tl.load(matrix_ptr)\n tl.store(out_ptr, matrix)\n\ndef wrapper(size_m, d_head):\n matrix = torch.randn((size_m, d_head), dtype=torch.float16, device=\"cuda\")\n out = torch.zeros((d_head, size_m), dtype=torch.float16, device=\"cuda\")\n\n grid = (1,)\n kernel[grid](\n matrix,\n out,\n *matrix.stride(),\n *out.stride(),\n size_m,\n d_head,\n )\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-642c5feef8924d6ba01e87d4e1805f0e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `softmax_kernel_online_v2` function is a Triton-optimized kernel designed to compute the softmax function over a 2D input tensor. It uses tiling to manage computation over potentially large matrices by splitting the data into manageable chunks, or tiles, of size `TILE_N`. This function accepts pointers to the input and output data, dimensions of the input matrix (M by N), and the tile size `TILE_N`. The computation occurs in two main phases: a reduction phase to compute necessary statistics (maximum value and sum of exponentials for stable softmax) and a final output phase where the input data is exponentiated, normalized by the sum, and stored in the output location. The kernel efficiently handles matrix boundaries by leveraging the `prev_multiple_of` helper to compute the largest possible tile-aligned boundary within matrix dimensions. This boundary is used to handle potential edge cases in matrix dimensions that are not perfectly divisible by the tile size. The function's design ensures numerical stability and efficiency by computing softmax in a numerically stable manner using a two-pass algorithm. The wrapper `softmax` function prepares the necessary parameters, invokes this kernel, and returns the softmax-normalized output.\n ", + "output": [ + "import triton\nfrom triton import language as tl\nimport torch\n\n@triton.jit\ndef prev_multiple_of(a, b):\n # the largest x None:\n \"\"\"\n Args:\n inputs (torch.Tensor): input tensor\n lora_b_weights (torch.Tensor): lora'a weight\n output_tensor (torch.Tensor): output tensor\n lora_indices_tensor (torch.Tensor): (batch_size,). The LoRA index\n corresponding to each batch, An index of -1 means no lora should be\n applied.\n batches (int): batch size\n add_inputs (bool, optional): Defaults to False, adds the final lora \n results to the output.\n \"\"\"\n assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]\n assert lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]\n assert inputs.size(1) == lora_b_weights.size(-1)\n\n assert inputs.is_contiguous()\n assert output_tensor.is_contiguous()\n\n if lora_b_weights.ndim == 4: # shape:(lora_num,1,size,rank)\n assert lora_b_weights.size(1) == 1\n lora_b_weights = lora_b_weights.squeeze(dim=1)\n else:\n assert lora_b_weights.ndim == 3 # shape:(lora_num,size,rank)\n assert lora_b_weights.is_contiguous()\n\n # TODO tuning this config\n N, K = lora_b_weights.shape[-2:] # K= rank,N=hidden_size\n BLOCK_K = triton.next_power_of_2(K)\n EVEN_K = K % BLOCK_K == 0\n ADD_INPUTS = add_inputs\n CAST_TYPE = False\n if inputs.dtype == torch.float32 and lora_b_weights.dtype in [\n torch.float16,\n torch.bfloat16,\n ]:\n CAST_TYPE = True\n batches = lora_indices_tensor.size(0)\n # config = get_lora_op_configs(\"expand\", batches, N)\n grid = lambda META: (\n META[\"SPLIT_N\"],\n batches,\n )\n _bgmv_expand_kernel[grid](\n inputs,\n lora_b_weights,\n output_tensor,\n N,\n K,\n lora_indices_tensor,\n inputs.stride(0),\n inputs.stride(1),\n lora_b_weights.stride(0),\n lora_b_weights.stride(1),\n lora_b_weights.stride(2),\n output_tensor.stride(0),\n output_tensor.stride(1),\n BLOCK_K=BLOCK_K,\n BLOCK_N=256,\n SPLIT_N=lora_b_weights.shape[-2:][0],\n EVEN_K=EVEN_K,\n ADD_INPUTS=ADD_INPUTS,\n CAST_TYPE=CAST_TYPE,\n # **config,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-aa21609446184a628be28814a42b1143", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code is a Triton-accelerated implementation of the forward pass for an attention mechanism, typically used in transformer architectures. The main kernel function `_attn_fwd` is JIT-compiled using Triton, which allows for optimized, parallel execution on GPUs. The kernel takes inputs Q, K, and V matrices, which are the query, key, and value matrices respectively, and computes attention outputs using block matrix operations for efficiency. \n\n The `_attn_fwd_inner` function is central to this operation. It calculates attention scores through the matrix product of Q and K (QK), scales the results using `qk_scale`, applies a mask if a sliding window attention is specified, and uses exponential scaling to stabilize the softmax computation. The function updates running maxima `m_i` and likelihoods `l_i` for each block. These are crucial for numerical stability, avoiding overflow/underflow when computing softmax.\n\n The outer `_attn_fwd` function initializes memory pointers for the Q, K, V, and output matrices. It calculates block pointers which dictate data layout in shared memory, enabling efficient memory access patterns. Depending on the context sizes, it configures whether boundary checks are necessary for matrix loads to ensure out-of-bound accesses are avoided. The use of `tl.math.exp2` instead of the traditional exponential function hints at specific numerical optimizations in logarithm base 2 for GPU-friendly calculations.\n\n Parameters such as `BLOCK_M`, `BLOCK_N`, and `BLOCK_DMODEL` determine the size of each processing block, crucial for performance tuning, as they define shared memory usage and the degree of parallelism. The `_forward` function acts as an interface, preparing the input data, managing kernel execution, handling exceptions like `triton.OutOfResources`, and adjusting block sizes if resource limitations are hit. The function supports additional functionality such as handling end-of-sequence operations, initializing accumulators, and managing sliding window constraints.\n\n Overall, this code is structured to ensure high performance on GPU architectures by leveraging Triton's ability to automatically parallelize workloads, manage synchronization, and exploit fast memory hierarchies through efficient tiling and batching strategies.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n_BLOCK_N = 64\n_BLOCK_M = 64\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q,\n K_block_ptr, V_block_ptr,\n start_m, qk_scale, N_CTX,\n sliding_window_offset, sliding_window_size,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, SLIDING_WINDOW: tl.constexpr,\n IS_EVEN_M: tl.constexpr, IS_EVEN_N: tl.constexpr, COMPLEMENT_SLIDING_WINDOW: tl.constexpr\n ):\n # range of values handled by this stage\n if SLIDING_WINDOW and not COMPLEMENT_SLIDING_WINDOW:\n if COMPLEMENT_SLIDING_WINDOW:\n lo = 0\n hi = (((start_m + 1) * BLOCK_M + sliding_window_offset - sliding_window_size + BLOCK_N - 1) // BLOCK_N) * BLOCK_N\n else:\n lo = ((start_m * BLOCK_M + sliding_window_offset - sliding_window_size + 1) // BLOCK_N) * BLOCK_N\n hi = ((((start_m + 1) * BLOCK_M - 1) + sliding_window_offset + BLOCK_N) // BLOCK_N) * BLOCK_N\n if lo < 0:\n lo = 0\n if hi > N_CTX:\n hi = N_CTX\n\n lo = tl.multiple_of(lo, BLOCK_N)\n K_block_ptr = tl.advance(K_block_ptr, (0, lo))\n V_block_ptr = tl.advance(V_block_ptr, (lo, 0))\n else:\n lo, hi = 0, N_CTX\n\n # loop over k, v and update accumulator\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n if IS_EVEN_N:\n k = tl.load(K_block_ptr)\n else:\n k = tl.load(K_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk = qk * qk_scale\n\n if SLIDING_WINDOW:\n dist = tl.arange(0, BLOCK_M)[:, None] - tl.arange(0, BLOCK_N)[None, :] \\\n + start_m * BLOCK_M - start_n + sliding_window_offset\n\n if COMPLEMENT_SLIDING_WINDOW:\n mask = (dist >= sliding_window_size)\n else:\n mask = (dist >= 0) & (dist < sliding_window_size)\n\n qk = tl.where(mask, qk, float(\"-inf\"))\n\n if not IS_EVEN_N:\n qk = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], qk, float(\"-inf\"))\n\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n\n if SLIDING_WINDOW:\n p = tl.where(mask, p, 0)\n\n if not IS_EVEN_N:\n p = tl.where(((tl.arange(0, BLOCK_N) + start_n) < N_CTX)[None, :], p, 0)\n\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n tmp = m_i - m_ij\n alpha_mask = (tmp != tmp) # check nan\n alpha = tl.math.exp2(tmp)\n alpha = tl.where(alpha_mask, 1., alpha)\n l_i = l_i * alpha + l_ij\n # -- update output accumulator --\n acc = acc * alpha[:, None]\n # update acc\n if IS_EVEN_N:\n v = tl.load(V_block_ptr)\n else:\n v = tl.load(V_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n\n acc += tl.dot(p.to(v.dtype), v)\n # update m_i and l_i\n m_i = m_ij\n V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0))\n K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N))\n\n return acc, l_i, m_i\n\n\n@triton.heuristics(\n {\n \"IS_EVEN_M\": lambda args: args[\"N_CTX\"] % args[\"BLOCK_M\"] == 0,\n \"IS_EVEN_N\": lambda args: args[\"NKV_CTX\"] % args[\"BLOCK_N\"] == 0,\n }\n)\n@triton.jit\ndef _attn_fwd(Q, K, V, sm_scale, M, Out, L,#\n stride_qz, stride_qh, stride_qm, stride_qk, #\n stride_kz, stride_kh, stride_kn, stride_kk, #\n stride_vz, stride_vh, stride_vk, stride_vn, #\n stride_oz, stride_oh, stride_om, stride_on, #\n Z, H, H_KV, #\n N_CTX, #\n ROUND_CTX,\n NKV_CTX,\n sliding_window_offset,\n sliding_window_size,\n IS_EVEN_M: tl.constexpr,\n IS_EVEN_N: tl.constexpr,\n BLOCK_M: tl.constexpr, #\n BLOCK_DMODEL: tl.constexpr, #\n BLOCK_N: tl.constexpr, #\n END: tl.constexpr,\n INIT: tl.constexpr,\n SLIDING_WINDOW: tl.constexpr,\n COMPLEMENT_SLIDING_WINDOW: tl.constexpr\n ):\n\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n off_z = off_hz // H\n off_h = off_hz % H\n off_hkv = off_h // (H//H_KV)\n q_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n k_offset = off_z.to(tl.int64) * stride_kz + off_hkv.to(tl.int64) * stride_kh\n v_offset = off_z.to(tl.int64) * stride_vz + off_hkv.to(tl.int64) * stride_vh\n o_offset = off_z.to(tl.int64) * stride_oz + off_h.to(tl.int64) * stride_oh\n\n # block pointers\n Q_block_ptr = tl.make_block_ptr(\n base=Q + q_offset,\n shape=(N_CTX, BLOCK_DMODEL),\n strides=(stride_qm, stride_qk),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n V_block_ptr = tl.make_block_ptr(\n base=V + v_offset,\n shape=(NKV_CTX, BLOCK_DMODEL),\n strides=(stride_vk, stride_vn),\n offsets=(0, 0),\n block_shape=(BLOCK_N, BLOCK_DMODEL),\n order=(1, 0),\n )\n K_block_ptr = tl.make_block_ptr(\n base=K + k_offset,\n shape=(BLOCK_DMODEL, NKV_CTX),\n strides=(stride_kk, stride_kn),\n offsets=(0, 0),\n block_shape=(BLOCK_DMODEL, BLOCK_N),\n order=(0, 1),\n )\n O_block_ptr = tl.make_block_ptr(\n base=Out + o_offset,\n shape=(ROUND_CTX, BLOCK_DMODEL),\n strides=(stride_om, stride_on),\n offsets=(start_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_DMODEL),\n order=(1, 0),\n )\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n # initialize pointer to m and l\n m_ptrs = M + off_hz * ROUND_CTX + offs_m\n l_ptrs = L + off_hz * ROUND_CTX + offs_m\n if INIT:\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n else:\n # don't have to check boundary for q len\n m_i = tl.load(m_ptrs).to(tl.float32)\n l_i = tl.load(l_ptrs).to(tl.float32)\n acc = tl.load(O_block_ptr).to(tl.float32)\n\n qk_scale = sm_scale\n qk_scale *= 1.4426950408889634 # 1/log(2)\n # load q: it will stay in SRAM throughout\n if IS_EVEN_M:\n q = tl.load(Q_block_ptr)\n else:\n q = tl.load(Q_block_ptr, boundary_check=(0, 1), padding_option=\"zero\")\n\n acc, l_i, m_i = _attn_fwd_inner(acc, l_i, m_i, q, K_block_ptr, V_block_ptr, #\n start_m, qk_scale, NKV_CTX, #\n sliding_window_offset, sliding_window_size,\n BLOCK_M, BLOCK_DMODEL, BLOCK_N, SLIDING_WINDOW, IS_EVEN_M, IS_EVEN_N,\n COMPLEMENT_SLIDING_WINDOW)\n # epilogue\n if (END):\n m_i += tl.math.log2(l_i)\n acc = acc / l_i[:, None]\n else:\n tl.store(l_ptrs, l_i)\n\n tl.store(m_ptrs, m_i)\n tl.store(O_block_ptr, acc.to(Out.type.element_ty))\n\ndef _forward(\n q, k, v, sm_scale,\n o=None, m=None, l=None, end=False,\n sliding_window=None, init=False,\n complement_sliding_window=False\n):\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n q_round_len = math.ceil(q.shape[2] / 64) * 64\n\n if sliding_window is not None:\n sliding_window_offset, sliding_window_size = sliding_window\n else:\n sliding_window_offset, sliding_window_size = None, None\n\n grid = lambda META: (\n triton.cdiv(q.shape[2], META[\"BLOCK_M\"]),\n q.shape[0] * q.shape[1],\n )\n\n global _BLOCK_N\n global _BLOCK_M\n\n try:\n with torch.cuda.device(q.device):\n _attn_fwd[grid](\n q, k, v, sm_scale, m, o, l, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], k.shape[1], #\n q.shape[2], #\n q_round_len,\n k.shape[2],\n sliding_window_offset,\n sliding_window_size,\n BLOCK_DMODEL=Lk, #\n END=end,\n INIT=init,\n BLOCK_M=_BLOCK_M,\n BLOCK_N=_BLOCK_N,\n SLIDING_WINDOW=(sliding_window is not None),\n COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,\n num_warps=4,\n num_stages=4\n )\n except triton.OutOfResources as E:\n _BLOCK_N = _BLOCK_N // 2\n _BLOCK_M = _BLOCK_M // 2\n from warnings import warn\n warn(f\"Triton Attention Output Resources. {E}\\nUse smaller block size {_BLOCK_N}.\")\n with torch.cuda.device(q.device):\n _attn_fwd[grid](\n q, k, v, sm_scale, m, o, l, #\n q.stride(0), q.stride(1), q.stride(2), q.stride(3), #\n k.stride(0), k.stride(1), k.stride(2), k.stride(3), #\n v.stride(0), v.stride(1), v.stride(2), v.stride(3), #\n o.stride(0), o.stride(1), o.stride(2), o.stride(3), #\n q.shape[0], q.shape[1], k.shape[1], #\n q.shape[2], #\n q_round_len,\n k.shape[2],\n sliding_window_offset,\n sliding_window_size,\n BLOCK_DMODEL=Lk, #\n END=end,\n INIT=init,\n BLOCK_M=_BLOCK_M,\n BLOCK_N=_BLOCK_N,\n SLIDING_WINDOW=(sliding_window is not None),\n COMPLEMENT_SLIDING_WINDOW=complement_sliding_window,\n num_warps=4,\n num_stages=4\n )\n\n if end:\n o = o[:, :, :q.shape[2], :].contiguous().to(q.dtype)\n\n return o, m, l\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-09fdd96791014919af8c4b2eda440e3e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code defines a Triton kernel, `_fwd_kernel_destindex_copy_kv`, which is intended for copying elements from the source tensors `KV_nope` and `KV_rope` to the destination tensors `O_nope` and `O_rope` using indices specified in the tensor `DestLoc`. This kernel iterates over each index in the sequence, as specified by `DestLoc`, to load elements from the `KV_nope` and `KV_rope` tensors and subsequently store them in the corresponding position in the `O_nope` and `O_rope` tensors based on `DestLoc`. \n\n The kernel uses Triton's parallel programming constructs, with `tl.load` and `tl.store` operations, ensuring efficient memory transactions. Specifically, `tl.arange` is used to compute offsets for batch processing of dimensions defined by `BLOCK_DMODEL_NOPE` and `BLOCK_DMODEL_ROPE`, which represent the aligned head dimensions (to the next power of two for optimal performance). The kernel is invoked via the `destindex_copy_kv` function, a no-gradient context wrapper that configures grid size and execution parameters such as `num_warps` and `num_stages` for optimal performance on the Triton platform. \n\n The inputs to `destindex_copy_kv` include:\n - `KV_nope`, `KV_rope`: Source tensors of shape `(batch_size, num_heads, head_dim)` which contain the data to be copied.\n - `DestLoc`: A tensor indicating the destination indices in the output tensors.\n - `O_nope`, `O_rope`: Output tensors which receive the copied data.\n\n The function checks that the shapes of corresponding source and destination tensors match, prepares the aligned dimensions for kernel execution, calculates strides, and invokes the kernel with these parameters. This setup supports parallel data transfer leveraging Triton's capabilities for handling multi-dimensional tensor operations efficiently.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n KV_nope,\n KV_rope,\n Dest_loc,\n O_nope,\n O_rope,\n stride_kv_nope_bs,\n stride_kv_nope_h,\n stride_kv_nope_d,\n stride_kv_rope_bs,\n stride_kv_rope_h,\n stride_kv_rope_d,\n stride_o_nope_bs,\n stride_o_nope_h,\n stride_o_nope_d,\n stride_o_rope_bs,\n stride_o_rope_h,\n stride_o_rope_d,\n kv_nope_head_num,\n kv_rope_head_num,\n BLOCK_DMODEL_NOPE: tl.constexpr,\n BLOCK_DMODEL_ROPE: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_d_nope = tl.arange(0, BLOCK_DMODEL_NOPE)\n offs_d_rope = tl.arange(0, BLOCK_DMODEL_ROPE)\n dest_index = tl.load(Dest_loc + cur_index)\n\n kv_nope_ptrs = KV_nope + cur_index * stride_kv_nope_bs + stride_kv_nope_d * offs_d_nope[None, :]\n kv_rope_ptrs = KV_rope + cur_index * stride_kv_rope_bs + stride_kv_rope_d * offs_d_rope[None, :]\n\n o_nope_ptrs = O_nope + dest_index * stride_o_nope_bs + stride_o_nope_d * offs_d_nope[None, :]\n o_rope_ptrs = O_rope + dest_index * stride_o_rope_bs + stride_o_rope_d * offs_d_rope[None, :]\n\n kv_nope = tl.load(kv_nope_ptrs)\n kv_rope = tl.load(kv_rope_ptrs)\n\n tl.store(o_nope_ptrs, kv_nope)\n tl.store(o_rope_ptrs, kv_rope)\n return\n\n\n@torch.no_grad()\ndef destindex_copy_kv(KV_nope, KV_rope, DestLoc, O_nope, O_rope):\n seq_len = DestLoc.shape[0]\n kv_nope_head_num = KV_nope.shape[1]\n kv_rope_head_num = KV_rope.shape[1]\n\n kv_nope_head_dim = KV_nope.shape[2]\n kv_rope_head_dim = KV_rope.shape[2]\n\n aligned_d_nope = triton.next_power_of_2(kv_nope_head_dim) # 调整为2的幂次方\n aligned_d_rope = triton.next_power_of_2(kv_rope_head_dim) # 调整为2的幂次方\n\n assert KV_nope.shape[1] == O_nope.shape[1]\n assert KV_nope.shape[2] == O_nope.shape[2]\n assert KV_rope.shape[1] == O_rope.shape[1]\n assert KV_rope.shape[2] == O_rope.shape[2]\n grid = (seq_len,)\n num_warps = 2\n\n _fwd_kernel_destindex_copy_kv[grid](\n KV_nope,\n KV_rope,\n DestLoc,\n O_nope,\n O_rope,\n KV_nope.stride(0),\n KV_nope.stride(1),\n KV_nope.stride(2),\n KV_rope.stride(0),\n KV_rope.stride(1),\n KV_rope.stride(2),\n O_nope.stride(0),\n O_nope.stride(1),\n O_nope.stride(2),\n O_rope.stride(0),\n O_rope.stride(1),\n O_rope.stride(2),\n kv_nope_head_num,\n kv_rope_head_num,\n # BLOCK_DMODEL_NOPE=kv_nope_head_dim,\n # BLOCK_DMODEL_ROPE=kv_rope_head_dim,\n BLOCK_DMODEL_NOPE=aligned_d_nope, # 传递对齐后的值\n BLOCK_DMODEL_ROPE=aligned_d_rope, # 传递对齐后的值\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-eaa0d5385bc4416eba6701ab7acaea04", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `triton_red_fused_native_layer_norm_0` within `fused_native_layer_norm` performs layer normalization on the input tensor `primals_3` (shape `(S, D)`). The process involves calculating the mean and variance across the rows of the input tensor using a parallel reduction technique known as the Welford algorithm, facilitated by Triton's helper functions. Each thread block processes a segment of the input (`RBLOCK` elements) and updates shared buffers for mean, variance, and count (`tmp3_mean`, `tmp3_m2`, `tmp3_weight`). After the reduction, mean (`tmp3`), variance (`tmp4`), and count (`tmp5`) are stored in buffer `out_ptr0` and used to compute the normalization factor. The main normalization operation involves centering the input by subtracting the mean and scaling with the reciprocal of the variance's square root (`tmp10`). The kernel also applies optional affine transformations using `primals_1` (scale) and `primals_2` (bias), storing the final normalized outputs in `out_ptr1`. The function sets up the kernel execution using Triton's grid and block configuration and manages CUDA streams for device synchronization. The `fused_native_layer_norm` Python wrapper sets up these inputs and invokes the Triton kernel, preparing auxiliary output tensors like `buf0`, `buf3`, and the main result in `buf4`.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch._C import _cuda_getCurrentRawStream as get_raw_stream\nfrom torch._inductor.runtime import triton_helpers\nfrom torch._inductor.runtime.triton_helpers import libdevice\n\nempty_strided_cuda = torch._C._dynamo.guards._empty_strided_cuda\nreinterpret_tensor = torch.ops.inductor._reinterpret_tensor\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"XBLOCK\": 1,\n \"RBLOCK\": 1024,\n },\n num_stages=1,\n num_warps=8,\n ),\n triton.Config(\n {\n \"XBLOCK\": 1,\n \"RBLOCK\": 2048,\n },\n num_stages=1,\n num_warps=8,\n ),\n ],\n key=[\"xnumel\", \"rnumel\"],\n)\n@triton.jit\ndef triton_red_fused_native_layer_norm_0(\n in_out_ptr0,\n in_ptr0,\n in_ptr1,\n in_ptr2,\n out_ptr0,\n out_ptr1,\n xnumel,\n rnumel,\n XBLOCK: tl.constexpr,\n RBLOCK: tl.constexpr,\n):\n xoffset = tl.program_id(0) * XBLOCK\n xindex = xoffset + tl.arange(0, XBLOCK)[:, None]\n xmask = xindex < xnumel\n rbase = tl.arange(0, RBLOCK)[None, :]\n x0 = xindex\n tmp3_mean = tl.zeros([XBLOCK, RBLOCK], tl.float32)\n tmp3_m2 = tl.zeros([XBLOCK, RBLOCK], tl.float32)\n tmp3_weight = tl.zeros([XBLOCK, RBLOCK], tl.float32)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp0 = tl.load(\n in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy=\"evict_last\"\n ).to(tl.float32)\n tmp1 = tmp0.to(tl.float32)\n tmp2 = tl.broadcast_to(tmp1, [XBLOCK, RBLOCK])\n tmp3_mean_next, tmp3_m2_next, tmp3_weight_next = triton_helpers.welford_reduce(\n tmp2, tmp3_mean, tmp3_m2, tmp3_weight, roffset == 0\n )\n tmp3_mean = tl.where(rmask, tmp3_mean_next, tmp3_mean)\n tmp3_m2 = tl.where(rmask, tmp3_m2_next, tmp3_m2)\n tmp3_weight = tl.where(rmask, tmp3_weight_next, tmp3_weight)\n tmp3_tmp, tmp4_tmp, tmp5_tmp = triton_helpers.welford(\n tmp3_mean, tmp3_m2, tmp3_weight, 1\n )\n tmp3 = tmp3_tmp[:, None]\n tmp4 = tmp4_tmp[:, None]\n tmp5 = tmp5_tmp[:, None]\n tl.store(out_ptr0 + (x0), tmp3, None)\n tmp6 = rnumel\n tmp7 = tmp4 / tmp6\n tmp8 = 1e-05\n tmp9 = tmp7 + tmp8\n tmp10 = libdevice.rsqrt(tmp9)\n tl.debug_barrier()\n tl.store(in_out_ptr0 + (x0), tmp10, None)\n for roffset in range(0, rnumel, RBLOCK):\n rindex = roffset + rbase\n rmask = rindex < rnumel\n r1 = rindex\n tmp11 = tl.load(\n in_ptr0 + (r1 + (rnumel * x0)), rmask, eviction_policy=\"evict_first\"\n ).to(tl.float32)\n tmp15 = tl.load(in_ptr1 + (r1), rmask, eviction_policy=\"evict_last\").to(\n tl.float32\n )\n tmp18 = tl.load(in_ptr2 + (r1), rmask, eviction_policy=\"evict_last\").to(\n tl.float32\n )\n tmp12 = tmp11.to(tl.float32)\n tmp13 = tmp12 - tmp3\n tmp14 = tmp13 * tmp10\n tmp16 = tmp15.to(tl.float32)\n tmp17 = tmp14 * tmp16\n tmp19 = tmp18.to(tl.float32)\n tmp20 = tmp17 + tmp19\n tmp21 = tmp20.to(tl.float32)\n tl.store(out_ptr1 + (r1 + (rnumel * x0)), tmp21, rmask)\n\ndef fused_native_layer_norm(primals_1, primals_2, primals_3):\n S, D = primals_3.shape\n with torch.cuda._DeviceGuard(0):\n torch.cuda.set_device(0)\n buf0 = empty_strided_cuda((S, 1), (1, 1), torch.float32)\n buf1 = empty_strided_cuda((S, 1), (1, S), torch.float32)\n buf3 = reinterpret_tensor(buf1, (S, 1), (1, 1), 0)\n del buf1 # reuse\n buf4 = empty_strided_cuda((S, D), (D, 1), torch.bfloat16)\n stream0 = get_raw_stream(0)\n grid = lambda META: (triton.cdiv(S, META[\"XBLOCK\"]),)\n triton_red_fused_native_layer_norm_0[grid](\n buf3, primals_3, primals_1, primals_2, buf0, buf4, S, D\n )\n return (\n buf4,\n primals_3,\n buf0,\n buf3,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-d001a6df1a1348e891f77139e69eab2e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The function `spinning_lock_kernel` is a Triton kernel that performs a reduction operation using a spinning lock \n mechanism for synchronization. This kernel takes input pointers `P`, `C`, and `locks`, with integer parameters \n `num_sms`, `k`, `M`, `N`, `stride_cm`, and `stride_cn`, and two compile-time constants `BLOCK_SIZE_M` and \n `BLOCK_SIZE_N`. Each block is identified by a unique program id `pid`, which determines its position in a virtual \n grid through `pid_m` and `pid_n`. The kernel initializes an accumulator `acc` to zeros. It then loops up to 9 times, \n performing accumulation only for the first thread (when `pid % k == 0`) by iterating over other participating threads \n and accumulating their data if successful in acquiring the lock via `tl.atomic_cas`. Threads that don't participate \n directly in reduction (i.e., where `pid % k != 0`) store the accumulated result in `P` and release the lock using \n `tl.atomic_xchg`. Finally, the results are written to `C` using computed offsets, and a mask ensures that only valid \n memory locations are accessed. The `spinning_lock` function orchestrates the kernel launch, preparing the execution \n grid according to `num_sms` and invoking the kernel with the specified block sizes.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n# Define the Triton kernel\n@triton.jit\ndef spinning_lock_kernel(P, C, locks, num_sms, k, M, N, stride_cm, stride_cn, BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr):\n pid = tl.program_id(0)\n pid_m = pid // num_sms\n pid_n = pid % num_sms\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) # Assuming acc initialization\n\n # Perform reduction for every kth pid\n for iters in range(1, 10):\n if (pid % k == 0):\n next_pid = pid + 1\n\n while next_pid < pid + k and next_pid < num_sms:\n while tl.atomic_cas(locks + next_pid, 1, 1) != 1:\n pass\n\n rm1 = tl.arange(0, BLOCK_SIZE_M)\n rn1 = tl.arange(0, BLOCK_SIZE_N)\n P_ = P + next_pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]\n acc1 = tl.load(P_)\n acc += acc1\n\n next_pid += 1\n \n # Store results using temporary storage P for every k-1 pids\n else:\n rm1 = tl.arange(0, BLOCK_SIZE_M)\n rn1 = tl.arange(0, BLOCK_SIZE_N)\n P_ = P + pid * BLOCK_SIZE_M * BLOCK_SIZE_N + rm1[:, None] * BLOCK_SIZE_N + rn1[None, :]\n tl.store(P_, acc)\n tl.atomic_xchg(locks + pid, 1)\n\n # Store final results in C\n rm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n rn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n C_ = C + rm[:, None] * stride_cm + rn[None, :] * stride_cn\n mask = (rm < M)[:, None] & (rn < N)[None, :]\n tl.store(C_, acc, mask=mask)\n\n\ndef spinning_lock(P, C, locks, num_sms, k, M, N, stride_cm, stride_cn, BLOCK_SIZE_M, BLOCK_SIZE_N):\n grid = (num_sms,)\n spinning_lock_kernel[grid](\n P, C, locks, num_sms, k, M, N, stride_cm, stride_cn, BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N,)\n\n\n\n" + ] + }, + { + "id": "openseek-8-961907808a044714980a5a3055b6100e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `rmsnorm_triton` is designed to perform RMS normalization on a 3D tensor `x` with dimensions `[batch, M, K]`. The input pointers `x_ptr`, `rms_w_ptr`, and `out_ptr` correspond to the input tensor, the RMS weights, and the output tensor, respectively. The function is parallelized across the `batch` and `M` dimensions using Triton's `program_id`. Inside the kernel, the RMS value is computed by summing the squares of each element along the `K` dimension, normalized by `N_SIZE` (size of `K`), and then adjusted with a small constant `eps` to prevent division by zero. The computed RMS is then used to normalize the elements, which are subsequently scaled by `rms_weights`. The `rmsnorm_wrapper` function facilitates the execution of this kernel by organizing the input, output, strides, and block sizes before invoking the kernel with specified grid dimensions (`batch, M`). The kernel utilizes a `BLOCK_N_SIZE` to process chunks of the `K` dimension at a time, enhancing parallel execution, and uses `num_warps` to control the level of parallelism within a Triton block.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rmsnorm_triton(x_ptr, rms_w_ptr, out_ptr,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_k,\n N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):\n pid_batch = tl.program_id(0)\n pid_m = tl.program_id(1)\n\n # parallel at m dimension\n offset_m = pid_batch * stride_x_batch + pid_m * stride_x_m\n block_n_size = tl.arange(0, BLOCK_N_SIZE)\n var = tl.zeros((BLOCK_N_SIZE,), tl.float32)\n # parallel between blocks\n for block_n_strart_ptr in range(0, N_SIZE, BLOCK_N_SIZE):\n offset_n = block_n_strart_ptr + block_n_size\n x_ptr_mask = offset_n < N_SIZE\n x = tl.load(x_ptr + offset_m + offset_n * stride_x_k, mask=x_ptr_mask, other=0.) # careful stride_x_k\n xf = x.to(tl.float32)\n var += xf*xf\n var = tl.sum(var, axis=0) / N_SIZE # reduce between wrap\n std = tl.sqrt(var + eps)\n\n for block_n_strart_ptr in range(0, N_SIZE, BLOCK_N_SIZE):\n offset_n = block_n_strart_ptr + block_n_size\n x_ptr_mask = offset_n < N_SIZE\n\n rms_w_offset = tl.load(rms_w_ptr + offset_n * stride_rms_w, mask=x_ptr_mask)\n x = tl.load(x_ptr + offset_m + offset_n * stride_x_k, mask=x_ptr_mask, other=0.)\n\n x_new = x / std\n out = x_new * rms_w_offset\n out_offset = pid_batch * stride_out_batch + pid_m * stride_out_m + offset_n * stride_out_k\n tl.store(out_ptr + out_offset, out, mask=x_ptr_mask)\n\n\ndef rmsnorm_wrapper(x, rms_weights, eps=1e-6):\n batch, M, K = x.shape\n out = torch.empty_like(x)\n rmsnorm_triton[(batch, M,)](x, rms_weights, out,\n *x.stride(),\n *rms_weights.stride(), # 1\n *out.stride(),\n N_SIZE=K, eps=eps, BLOCK_N_SIZE=4096,\n num_warps=16\n )\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-56bb106a6ce14888b9de6ccdd9d4b15f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `chunk_simple_gla_bwd_kernel_dqkg` calculates the backward gradients for query (`dq`), key (`dk`), and a gating mechanism (`dg`) used in attention mechanisms of neural networks, specifically transformers. The kernel operates on the tensors: `q` (query), `k` (key), `v` (value), `h` (hidden states), `g` (gating), `do` (output differential), and `dh` (hidden differential). It leverages Triton's `autotune` decorator, which optimizes for performance by experimenting with 4 and 8 warp configurations based on input dimensions `BT`, `BK`, and `BV`. The primary function executes in parallel over grid dimensions `(NK, NT, B*H)` with indices set by Triton's program ID functions. The kernel first initializes zero tensors for gradients and performs calculations such as matrix multiplications and exponential scalings to fill `b_dq`, `b_dk`, and `b_dg` during the loop over `V`. Boundary checks are used during tensor loading to prevent out-of-bounds memory access. The kernel stores results back to the output tensors (`dq`, `dk`, `dg`) using block pointers to access appropriate memory locations. The host function `chunk_bwd_dqkg_fn` defines grid settings and tensor shapes, initializes output tensors, and launches the Triton kernel, facilitating gradient calculations in attention models.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8)\n ],\n key=[\"BT\", \"BK\", \"BV\"],\n)\n@triton.jit\ndef chunk_simple_gla_bwd_kernel_dqkg(\n q,\n k,\n v,\n h,\n g,\n do,\n dh,\n dq,\n dk,\n dg,\n s_k_h,\n s_k_t,\n s_v_h,\n s_v_t,\n s_h_h,\n s_h_t,\n scale,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n # Kernel implementation\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_g = tl.make_block_ptr(g + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_g = tl.load(p_g, boundary_check=(0,))\n last_idx = min(i_t * BT + BT, T) - 1\n b_g_last = tl.load(g + i_bh * T + last_idx)\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n b_dg_last = tl.zeros([1,], dtype=tl.float32)\n b_dg = tl.zeros([BT,], dtype=tl.float32)\n\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BV, BK]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n\n b_dg_last += (tl.sum(b_h * b_dh))\n b_ds += tl.dot(b_do, tl.trans(b_v))\n b_dq += tl.dot(b_do, b_h.to(b_do.dtype))\n b_dk += tl.dot(b_v, b_dh.to(b_v.dtype))\n\n p_q = tl.make_block_ptr(q + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_dg_last *= tl.exp(b_g_last)\n b_dq = b_dq * tl.exp(b_g)[:, None] * scale\n b_dk = b_dk * tl.exp(-b_g + b_g_last)[:, None]\n b_dg_last += tl.sum(b_dk * b_k)\n b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale * tl.exp(b_g[:, None] - b_g[None, :]), 0)\n b_ds = b_ds.to(b_k.dtype)\n # [BT, BK]\n b_dq += tl.dot(b_ds, b_k)\n b_dk += tl.dot(tl.trans(b_ds), b_q)\n b_dg += tl.sum(b_q * b_dq - b_k * b_dk, axis=1)\n # (SY 09/21) revcumsum in a separate kernel due to strange triton compiler issue\n # b_dg = tl.dot(tl.where(o_i[:, None] <= o_i[None, :], 1., 0.), b_dg, allow_tf32=False) + b_dg_last)\n b_dg = tl.where(o_i < min(BT, T-i_t*BT) - 1, b_dg, b_dg + b_dg_last)\n p_dq = tl.make_block_ptr(dq + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_k_h, (T, K), (s_k_t, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dg = tl.make_block_ptr(dg + (i_k*n_bh + i_bh) * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0,))\n\ndef chunk_bwd_dqkg_fn(do, q, k, v, g, h, dh, scale):\n # Kernel call\n B, H, T, K, V = *k.shape, v.shape[-1]\n BT = 64\n BK = min(triton.next_power_of_2(K), 64)\n BV = min(triton.next_power_of_2(V), 64)\n NT, NK = triton.cdiv(T, BT), triton.cdiv(K, BK)\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dg = torch.empty(NK, B, H, T, dtype=torch.float32, device=g.device).fill_(-1e9)\n chunk_simple_gla_bwd_kernel_dqkg[grid](\n q, k, v, h, g, do, dh, dq, dk, dg,\n q.stride(1), q.stride(2),\n v.stride(1), v.stride(2),\n dh.stride(1), dh.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT\n )\n return dq, dk, dg\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c401b4ab0ddd469796ebbbd9cb191c36", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code defines Triton kernels and a PyTorch autograd function for computing fifth-order spherical harmonics. The `fifth_order_fwd` kernel computes the spherical harmonics projections using x, y, z coordinates, defined with various constants for the harmonic orders. Inputs include a coordinate tensor, output buffer, block size, number of elements, and other configuration parameters. The output consists of harmonics in different orders computed using mathematical operations on x, y, z. The `fifth_order_bwd` kernel calculates gradients for these harmonics, using pre-calculated derivatives and loading gradients from the output space. Each kernel handles blocks of work determined by block size, calculated via a helper function. The `FifthOrderSphericalHarmonic` class is a PyTorch function managing these kernels' forward and backward execution. It initializes outputs, saves context for backpropagation, and allocates memory as needed. The code applies Triton's parallel capabilities to process tensor computations efficiently, aimed at GPU utilization. \n ", + "output": [ + "import triton\nimport torch\nfrom triton import language as tl\nimport math\n\ndef calculate_lastdim_num_blocks(input_tensor: torch.Tensor, block_size: int) -> int:\n \"\"\"\n Calculate the number of blocks for a tensor, assuming we\n stride along the last dimension, and a given block size.\n\n The corresponding pointer arithmetic looks like this:\n\n ```python\n block_id = tl.program_id(0)\n striding = tl.arange(0, block_size) * stride\n offset = (striding + (block_size * stride * block_id))\n ```\n\n This function is used to work out the amount of parallel\n work that needs to be done, given as the total number of\n elements divided by the last dimension stride, and a specified\n block size that will then divvy up the work.\n\n Parameters\n ----------\n input_tensor : torch.Tensor\n Torch N-d tensor to operate over.\n\n Returns\n -------\n int\n Number of blocks of work, given a block size.\n \"\"\"\n # get the stride of the last dimension\n stride = input_tensor.stride(-2)\n numel = input_tensor.numel()\n total_blocks = math.ceil(numel / stride)\n return total_blocks\n\n\n# This kernel computes the fifth order spherical harmonics projections.\n@triton.jit\ndef fifth_order_fwd(\n coord_ptr: tl.tensor, # Pointer to input coordinates tensor.\n output_ptr: tl.tensor, # Pointer to output tensor.\n block_size: tl.constexpr, # Number of elements in each block.\n coord_numel: tl.constexpr, # Total number of elements in the coordinates tensor.\n output_numel: tl.constexpr, # Total number of elements in the output tensor.\n col_offset: tl.constexpr, # Offset for the output tensor.\n output_stride: tl.constexpr, # Stride of the output tensor.\n):\n coord_stride = 3\n block_id = tl.program_id(0)\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n # -------------------- variable and constant definitions\n CONST000 = 1.73430461568895\n CONST001 = 2.32681380862329\n CONST002 = 1.60565407233314\n CONST003 = 3.21130814466628\n CONST004 = 3.31662479035540\n CONST005 = 6.21867148191637\n CONST006 = 6.21867148191637\n CONST007 = 1.60565407233314\n CONST009 = 11.6340690431164\n CONST010 = 12.8452325786651\n CONST011 = 12.4373429638327\n CONST012 = 12.8452325786651\n CONST013 = 13.8744369255116\n CONST017 = 33.9852909359329\n CONST018 = 7.35803132638072\n CONST020 = -44.1481879582843\n CONST021 = -41.6233107765348\n CONST022 = -29.4321253055229\n CONST023 = -23.2681380862329\n CONST024 = -19.2678488679977\n CONST025 = -19.2678488679977\n CONST026 = -16.9926454679664\n CONST027 = -16.9926454679664\n CONST028 = -13.8744369255116\n CONST029 = -16.5831239517770\n CONST030 = 3.46860923137790\n CONST031 = -8.49632273398321\n CONST032 = -5.20291384706685\n CONST033 = -3.46860923137790\n CONST034 = -1.73430461568895\n VAR05 = x * x * x * x * x\n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR14 = y * y * y * y * y\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR23 = z * z * z * z * z\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n # -------------------- kernel implementations\n Y00 = CONST001 * VAR05 + CONST009 * VAR24 * x + CONST023 * VAR07 * VAR26\n Y01 = y * (CONST022 * VAR07 * z - CONST022 * VAR25 * x)\n Y02 = (\n CONST000 * VAR05\n + VAR07 * (CONST028 * VAR17 + CONST033 * VAR26)\n + x * (-CONST021 * VAR17 * VAR26 + CONST032 * VAR24)\n )\n Y03 = CONST027 * VAR07 * y * z + x * (CONST017 * VAR16 * z + CONST026 * VAR25 * y)\n Y04 = (\n CONST002 * VAR05\n + VAR07 * (CONST003 * VAR26 + CONST025 * VAR17)\n + x * (CONST002 * VAR24 + CONST010 * VAR15 + CONST024 * VAR17 * VAR26)\n )\n Y05 = (\n CONST004 * VAR14\n + VAR16 * (CONST029 * VAR08 + CONST029 * VAR26)\n + y * (CONST005 * VAR06 + CONST006 * VAR24 + CONST011 * VAR08 * VAR26)\n )\n Y06 = (\n CONST002 * VAR23\n + VAR25 * (CONST003 * VAR08 + CONST024 * VAR17)\n + z * (CONST007 * VAR06 + CONST012 * VAR15 + CONST024 * VAR08 * VAR17)\n )\n Y07 = VAR16 * (CONST026 * VAR08 - CONST026 * VAR26) + y * (\n -CONST031 * VAR06 + CONST031 * VAR24\n )\n Y08 = (\n CONST034 * VAR23\n + VAR25 * (CONST013 * VAR17 + CONST030 * VAR08)\n + z * (CONST021 * VAR08 * VAR17 - CONST032 * VAR06)\n )\n Y09 = y * (CONST018 * VAR06 + CONST018 * VAR24 + CONST020 * VAR08 * VAR26)\n Y10 = CONST001 * VAR23 + CONST009 * VAR06 * z + CONST023 * VAR08 * VAR25\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n tl.store(output_ptr + output_row_offset, Y00, mask=output_row_offset < output_numel)\n tl.store(\n output_ptr + output_row_offset + 1,\n Y01,\n mask=output_row_offset + 1 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 2,\n Y02,\n mask=output_row_offset + 2 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 3,\n Y03,\n mask=output_row_offset + 3 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 4,\n Y04,\n mask=output_row_offset + 4 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 5,\n Y05,\n mask=output_row_offset + 5 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 6,\n Y06,\n mask=output_row_offset + 6 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 7,\n Y07,\n mask=output_row_offset + 7 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 8,\n Y08,\n mask=output_row_offset + 8 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 9,\n Y09,\n mask=output_row_offset + 9 < output_numel,\n )\n tl.store(\n output_ptr + output_row_offset + 10,\n Y10,\n mask=output_row_offset + 10 < output_numel,\n )\n\n\n# This kernel computes the gradients of the fifth order spherical harmonics projections.\n@triton.jit\ndef fifth_order_bwd(\n coord_ptr: tl.tensor, # Pointer to input coordinates tensor.\n coord_grad_ptr: tl.tensor, # Pointer to the gradient output tensor.\n sph_grad_ptr: tl.tensor, # Pointer to the gradient of the spherical harmonic.\n block_size: tl.constexpr, # Number of elements in each block.\n coord_numel: tl.constexpr, # Total number of elements in the coordinates tensor.\n output_numel: tl.constexpr, # Total number of elements in the output tensor.\n col_offset: tl.constexpr, # Offset for the output tensor.\n output_stride: tl.constexpr, # Stride of the output tensor.\n):\n block_id = tl.program_id(0)\n coord_stride = 3\n coord_striding = tl.arange(0, block_size) * coord_stride\n coord_row_offset = coord_striding + (block_size * coord_stride * block_id)\n x = tl.load(coord_ptr + coord_row_offset, mask=coord_row_offset < coord_numel)\n y = tl.load(\n coord_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n z = tl.load(\n coord_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n output_striding = tl.arange(0, block_size) * output_stride\n output_row_offset = (\n output_striding + (block_size * output_stride * block_id) + col_offset\n )\n g_0 = tl.load(\n sph_grad_ptr + output_row_offset, mask=output_row_offset < output_numel\n )\n g_1 = tl.load(\n sph_grad_ptr + output_row_offset + 1, mask=output_row_offset + 1 < output_numel\n )\n g_2 = tl.load(\n sph_grad_ptr + output_row_offset + 2, mask=output_row_offset + 2 < output_numel\n )\n g_3 = tl.load(\n sph_grad_ptr + output_row_offset + 3, mask=output_row_offset + 3 < output_numel\n )\n g_4 = tl.load(\n sph_grad_ptr + output_row_offset + 4, mask=output_row_offset + 4 < output_numel\n )\n g_5 = tl.load(\n sph_grad_ptr + output_row_offset + 5, mask=output_row_offset + 5 < output_numel\n )\n g_6 = tl.load(\n sph_grad_ptr + output_row_offset + 6, mask=output_row_offset + 6 < output_numel\n )\n g_7 = tl.load(\n sph_grad_ptr + output_row_offset + 7, mask=output_row_offset + 7 < output_numel\n )\n g_8 = tl.load(\n sph_grad_ptr + output_row_offset + 8, mask=output_row_offset + 8 < output_numel\n )\n g_9 = tl.load(\n sph_grad_ptr + output_row_offset + 9, mask=output_row_offset + 9 < output_numel\n )\n g_10 = tl.load(\n sph_grad_ptr + output_row_offset + 10,\n mask=output_row_offset + 10 < output_numel,\n )\n # -------------------- variable and constant definitions\n CONST000 = 1.60565407233314\n CONST001 = 3.00000000000000\n CONST002 = 3.21130814466628\n CONST003 = 1.60565407233314\n CONST004 = 6.42261628933256\n CONST005 = 6.42261628933256\n CONST006 = 8.67152307844476\n CONST007 = 8.02827036166571\n CONST008 = 6.93721846275580\n CONST009 = 11.6340690431164\n CONST010 = 12.8452325786651\n CONST011 = 6.21867148191637\n CONST012 = 6.21867148191637\n CONST014 = 12.4373429638327\n CONST017 = 12.8452325786651\n CONST018 = 13.8744369255116\n CONST019 = 24.8746859276655\n CONST020 = 24.8746859276655\n CONST021 = 27.7488738510232\n CONST024 = 29.4321253055229\n CONST027 = 7.35803132638072\n CONST029 = 46.5362761724657\n CONST030 = 51.3809303146605\n CONST031 = 51.3809303146605\n CONST034 = 101.955872807799\n CONST036 = -8.67152307844475\n CONST037 = 3.46860923137790\n CONST038 = -88.2963759165686\n CONST039 = -83.2466215530696\n CONST040 = -69.8044142586986\n CONST041 = -50.9779364038993\n CONST042 = -50.9779364038993\n CONST043 = -46.5362761724657\n CONST044 = -44.1481879582843\n CONST045 = -41.6233107765348\n CONST046 = -38.5356977359954\n CONST047 = -38.5356977359954\n CONST048 = -33.1662479035540\n CONST049 = -33.9852909359329\n CONST050 = 6.42261628933257\n CONST051 = -33.9852909359329\n CONST052 = -29.4321253055229\n CONST053 = -27.7488738510232\n CONST054 = -20.8116553882674\n CONST055 = -19.2678488679977\n CONST056 = -19.2678488679977\n CONST057 = -16.9926454679664\n CONST058 = -16.9926454679664\n CONST059 = -13.8744369255116\n CONST060 = -16.5831239517770\n CONST061 = -8.49632273398321\n CONST062 = -6.93721846275580\n CONST063 = -5.20291384706685\n CONST064 = -3.46860923137790\n VAR06 = x * x * x * x\n VAR07 = x * x * x\n VAR08 = x * x\n VAR15 = y * y * y * y\n VAR16 = y * y * y\n VAR17 = y * y\n VAR24 = z * z * z * z\n VAR25 = z * z * z\n VAR26 = z * z\n # -------------------- kernel implementations\n g_x = tl.load(\n coord_grad_ptr + coord_row_offset, mask=coord_row_offset < coord_numel\n )\n g_y = tl.load(\n coord_grad_ptr + coord_row_offset + 1, mask=coord_row_offset + 1 < coord_numel\n )\n g_z = tl.load(\n coord_grad_ptr + coord_row_offset + 2, mask=coord_row_offset + 2 < coord_numel\n )\n g_x += (\n g_0 * (CONST009 * VAR06 + CONST009 * VAR24 + CONST040 * VAR08 * VAR26)\n + g_1 * y * (CONST038 * VAR08 * z - CONST052 * VAR25)\n + g_10 * (CONST029 * VAR07 * z + CONST043 * VAR25 * x)\n + g_2\n * (\n CONST001 * VAR08 * (CONST059 * VAR17 + CONST064 * VAR26)\n + CONST006 * VAR06\n - CONST045 * VAR17 * VAR26\n + CONST063 * VAR24\n )\n + g_3 * (CONST041 * VAR08 * y * z - CONST049 * VAR16 * z + CONST057 * VAR25 * y)\n + g_4\n * (\n CONST000 * VAR24\n + CONST001 * VAR08 * (CONST002 * VAR26 + CONST055 * VAR17)\n + CONST007 * VAR06\n + CONST010 * VAR15\n + CONST056 * VAR17 * VAR26\n )\n + g_5 * (CONST048 * VAR16 * x + y * (CONST019 * VAR07 + CONST019 * VAR26 * x))\n + g_6 * (CONST005 * VAR25 * x + z * (CONST004 * VAR07 + CONST046 * VAR17 * x))\n + g_7 * (CONST049 * VAR16 * x - CONST051 * VAR07 * y)\n + g_8 * (CONST008 * VAR25 * x + z * (CONST039 * VAR17 * x - CONST054 * VAR07))\n + g_9 * y * (CONST024 * VAR07 + CONST038 * VAR26 * x)\n )\n g_y += (\n g_1 * (CONST052 * VAR07 * z - CONST052 * VAR25 * x)\n + g_2 * (-CONST039 * VAR26 * x * y + CONST053 * VAR07 * y)\n + g_3 * (CONST058 * VAR07 * z + x * (CONST034 * VAR17 * z + CONST057 * VAR25))\n + g_4 * (CONST047 * VAR07 * y + x * (CONST030 * VAR16 + CONST046 * VAR26 * y))\n + g_5\n * (\n CONST001 * VAR17 * (CONST060 * VAR08 + CONST060 * VAR26)\n + CONST011 * VAR06\n + CONST012 * VAR24\n + CONST014 * VAR08 * VAR26\n - CONST060 * VAR15\n )\n + g_6 * (CONST046 * VAR25 * y + z * (CONST031 * VAR16 + CONST046 * VAR08 * y))\n + g_7\n * (\n CONST001 * VAR17 * (CONST057 * VAR08 - CONST057 * VAR26)\n - CONST061 * VAR06\n + CONST061 * VAR24\n )\n + g_8 * (CONST021 * VAR25 * y + CONST039 * VAR08 * y * z)\n + g_9 * (CONST027 * VAR06 + CONST027 * VAR24 + CONST044 * VAR08 * VAR26)\n )\n g_z += (\n g_0 * (CONST029 * VAR25 * x + CONST043 * VAR07 * z)\n + g_1 * y * (-CONST038 * VAR26 * x + CONST052 * VAR07)\n + g_10 * (CONST009 * VAR06 + CONST009 * VAR24 + CONST040 * VAR08 * VAR26)\n + g_2 * (CONST062 * VAR07 * z + x * (-CONST039 * VAR17 * z + CONST054 * VAR25))\n + g_3 * (CONST058 * VAR07 * y + x * (CONST042 * VAR26 * y - CONST049 * VAR16))\n + g_4 * (CONST005 * VAR07 * z + x * (CONST046 * VAR17 * z + CONST050 * VAR25))\n + g_5 * (CONST048 * VAR16 * z + y * (CONST019 * VAR08 * z + CONST020 * VAR25))\n + g_6\n * (\n CONST001 * VAR26 * (CONST002 * VAR08 + CONST056 * VAR17)\n + CONST003 * VAR06\n + CONST007 * VAR24\n + CONST017 * VAR15\n + CONST056 * VAR08 * VAR17\n )\n + g_7 * (-CONST049 * VAR16 * z + CONST051 * VAR25 * y)\n + g_8\n * (\n CONST001 * VAR26 * (CONST018 * VAR17 + CONST037 * VAR08)\n + CONST036 * VAR24\n + CONST045 * VAR08 * VAR17\n - CONST063 * VAR06\n )\n + g_9 * y * (CONST024 * VAR25 + CONST038 * VAR08 * z)\n )\n tl.store(\n coord_grad_ptr + coord_row_offset, g_x, mask=coord_row_offset < coord_numel\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 1,\n g_y,\n mask=coord_row_offset + 1 < coord_numel,\n )\n tl.store(\n coord_grad_ptr + coord_row_offset + 2,\n g_z,\n mask=coord_row_offset + 2 < coord_numel,\n )\n\n\nclass FifthOrderSphericalHarmonic(torch.autograd.Function):\n @staticmethod\n def forward(\n ctx,\n coords: torch.Tensor, # Input coordinates tensor.\n output_tensor: torch.Tensor | None = None, # Output tensor placeholder.\n mask: torch.Tensor | None = None, # Mask tensor (not used in this implementation).\n block_size: int = 64, # Block size for Triton kernels.\n col_offset: int = 0, # Column offset for the output tensor.\n ):\n if not isinstance(output_tensor, torch.Tensor):\n output_tensor = torch.empty(\n (*coords.shape[:-1], 11), dtype=coords.dtype, device=coords.device\n )\n coord_numel = coords.numel()\n output_numel = output_tensor.numel()\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n fifth_order_fwd[num_blocks,](\n coords,\n output_tensor,\n block_size,\n coord_numel,\n output_numel,\n col_offset,\n output_tensor.stride(-2),\n )\n ctx.save_for_backward(coords)\n return output_tensor\n\n @staticmethod\n def backward(\n ctx,\n sph_grad_tensor: torch.Tensor, # Gradient of the output tensor.\n block_size: int = 64, # Block size for Triton kernels.\n col_offset: int = 0, # Column offset for the output tensor.\n ) -> torch.Tensor:\n (coords,) = ctx.saved_tensors\n coord_grad_output = torch.zeros_like(coords)\n num_blocks = calculate_lastdim_num_blocks(coords, block_size)\n fifth_order_bwd[num_blocks,](\n coords,\n coord_grad_output,\n sph_grad_tensor,\n block_size,\n coords.numel(),\n sph_grad_tensor.numel(),\n col_offset,\n sph_grad_tensor.stride(-2),\n )\n return coord_grad_output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-20cf39e7cadd42c880064a82736bcd94", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `fused_recurrent_rwkv6_fwd_kernel` is designed for efficient execution of a specific recurrent neural network operation used in the RWKV-6 model. This kernel processes the input tensors `q`, `k`, `v`, `w`, and `u` to perform fused operations that involve recurrent processing over time `T`. The kernel can optionally use an initial hidden state (`USE_INITIAL_STATE`) and store the final hidden state (`STORE_FINAL_STATE`). It supports reversing the input sequence (`REVERSE`). The main compute loop iterates over the sequence length `T` to perform recurrent updates by loading slices of `k`, `v`, and applying operations with `b_w` and `b_u` to update the hidden state and compute the output `b_o`. The function `FusedRecurrentRWKV6Function` is an autograd function that encapsulates the forward pass using this Triton kernel, setting up grid parameters (`BK`, `BV`, `NK`, `NV`) for parallel execution. It manages memory allocation for outputs and optionally for final hidden states, while also saving necessary tensors for backward pass in PyTorch's autograd. The `fused_recurrent_rwkv6` function acts as a user-facing API to perform this computation, with the ability to scale `q` by `scale` and return both the output tensor and optionally the final hidden state.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.jit\ndef fused_recurrent_rwkv6_fwd_kernel(\n q, k, v, w, u, o, h0, ht, s_k_h, s_v_h, scale, B: tl.constexpr, H: tl.constexpr, T: tl.constexpr, K: tl.constexpr, V: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, REVERSE: tl.constexpr,\n):\n i_v, i_k, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n\n p_q = q + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_k = k + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_v = v + i_bh * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_o = o + (i_bh + i_k * B * H) * s_v_h + i_v * BV + tl.arange(0, BV) + ((T-1) * V if REVERSE else 0)\n p_w = w + i_bh * s_k_h + i_k * BK + tl.arange(0, BK) + ((T-1) * K if REVERSE else 0)\n p_u = u + i_h * K + tl.arange(0, BK) + i_k * BK\n\n mask_bk = (i_k * BK + tl.arange(0, BK)) < K\n mask_bv = (i_v * BV + tl.arange(0, BV)) < V\n mask_kv = mask_bv[:, None] & mask_bk[None, :]\n\n b_h = tl.zeros([BV, BK], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n b_h += tl.load(p_h0, mask=mask_kv, other=0).to(tl.float32)\n\n b_u = tl.load(p_u, mask=mask_bk, other=0).to(tl.float32)\n for _ in range(0, T):\n b_k = tl.load(p_k, mask=mask_bk, other=0).to(tl.float32)\n b_v = tl.load(p_v, mask=mask_bv, other=0).to(tl.float32)\n b_q = tl.load(p_q, mask=mask_bk, other=0).to(tl.float32) * scale\n b_w = tl.load(p_w, mask=mask_bk, other=0).to(tl.float32)\n b_w = tl.exp(b_w)\n b_kv = b_k[None, :] * b_v[:, None]\n b_o = (b_h + b_kv * b_u[None, :]) * b_q[None, :]\n b_o = tl.sum(b_o, axis=1)\n b_h = b_h * b_w[None, :]\n b_h += b_kv\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_bv)\n p_q += -K if REVERSE else K\n p_k += -K if REVERSE else K\n p_o += -V if REVERSE else V\n p_v += -V if REVERSE else V\n p_w += -K if REVERSE else K\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * K * V + (i_k * BK + tl.arange(0, BK)[None, :]) * V + (i_v * BV + tl.arange(0, BV)[:, None])\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_kv)\n\nclass FusedRecurrentRWKV6Function(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, r, k, v, w, u, scale=None, initial_state=None, output_final_state=False, reverse=False):\n q = r\n B, H, T, K, V = *q.shape, v.shape[-1]\n\n BK, BV = min(triton.next_power_of_2(K), 32), min(triton.next_power_of_2(V), 32)\n NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 1\n\n final_state = q.new_empty(B, H, K, V) if output_final_state else None\n\n o = q.new_empty(NK, B, H, T, V, dtype=torch.float32)\n grid = (NV, NK, B * H)\n fused_recurrent_rwkv6_fwd_kernel[grid](\n q, k, v, w, u, o, initial_state, final_state,\n k.stride(1),\n v.stride(1),\n scale,\n B=B, H=H, T=T, K=K, V=V, BK=BK, BV=BV,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n REVERSE=reverse,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n o = o.sum(0)\n ctx.save_for_backward(q, k, v, w, u, initial_state)\n ctx.scale = scale\n ctx.reverse = reverse\n return o.to(q.dtype), final_state\n\ndef fused_recurrent_rwkv6(\n r: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n w: torch.Tensor,\n u: torch.Tensor,\n scale: float = -1,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale == -1:\n scale = r.shape[-1] ** -0.5\n o, final_state = FusedRecurrentRWKV6Function.apply(r, k, v, w, u, scale, initial_state, output_final_state)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-d9e4a55aa463494dbd28db59a28fd268", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel function 'softmax_kernel' is designed to compute the softmax operation for each row of a 2D input tensor, indicated by 'input_ptr'. It writes the result to the location pointed by 'output_ptr'. This kernel takes several parameters: 'input_row_stride' and 'output_row_stride' which denote the memory strides for input and output matrices respectively, 'n_cols' which denotes the number of columns in the row being processed, and a compile-time constant 'BLOCK_SIZE' which determines the size of the block of memory being processed in parallel.\n\n For each row, identified by 'tl.program_id(axis=0)', the kernel computes memory offsets for both input and output using the strides. It loads a row of data from global memory into on-chip SRAM using 'tl.load' with masking to handle cases where 'n_cols' is less than 'BLOCK_SIZE', filling absent values with '-inf'.\n\n The kernel computes the maximum value of the loaded row using 'tl.max' for numerical stability. This maximum value is subtracted from each element in the row, followed by exponentiation through 'tl.exp' to form the numerator of the softmax calculation.\n\n The sum of the numerators is computed using 'tl.sum' to serve as the denominator for normalization. Each numerator is then divided by this denominator to obtain the softmax result for the row. The final normalized values are stored back to the output tensor using 'tl.store', employing a mask to ensure that only valid column entries (less than 'n_cols') are written.\n\n The function 'triton_softmax' acts as a wrapper to set up and execute the Triton kernel. It extracts dimensions of the input tensor 'x' to determine 'n_rows' and 'n_cols'. An output tensor of the same shape is initialized. The block size for parallel processing is chosen as the next power of two greater than or equal to 'n_cols', but capped at 1024, to optimize the hardware resources usage. It configures the execution grid to cover all rows and launches the kernel to perform the operation across the entire input tensor. This setup allows efficient parallel execution of the softmax operation across multiple rows.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef softmax_kernel(\n output_ptr, input_ptr, input_row_stride, output_row_stride, n_cols,\n BLOCK_SIZE: tl.constexpr,\n):\n # Get the program ID\n row_idx = tl.program_id(axis=0)\n\n # Compute the memory offsets for this row\n row_start_ptr = input_ptr + row_idx * input_row_stride\n out_row_start_ptr = output_ptr + row_idx * output_row_stride\n\n # Load the row into SRAM\n row = tl.load(row_start_ptr + tl.arange(0, BLOCK_SIZE), mask=tl.arange(0, BLOCK_SIZE) < n_cols, other=-float('inf'))\n\n # Compute max for numerical stability\n row_max = tl.max(row, axis=0)\n \n # Subtract max from row and exponentiate\n numerator = tl.exp(row - row_max)\n \n # Compute sum for normalization\n denominator = tl.sum(numerator, axis=0)\n \n # Normalize\n softmax_output = numerator / denominator\n \n # Store the output\n tl.store(out_row_start_ptr + tl.arange(0, BLOCK_SIZE), softmax_output, mask=tl.arange(0, BLOCK_SIZE) < n_cols)\n\ndef triton_softmax(x):\n n_rows, n_cols = x.shape\n output = torch.empty_like(x)\n \n # Determine the block size\n BLOCK_SIZE = triton.next_power_of_2(n_cols)\n BLOCK_SIZE = min(BLOCK_SIZE, 1024) \n \n # Launch the Triton kernel\n grid = (n_rows,)\n softmax_kernel[grid](\n output, x,\n x.stride(0), output.stride(0),\n n_cols, BLOCK_SIZE=BLOCK_SIZE\n )\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0a06980a74f44c7abd2a24cd33e06959", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `chunk_global_cumsum_scalar_kernel` is designed to perform a cumulative sum operation on a 3D tensor, where the dimensions are Batch (B), Head (H), and Time (T). This kernel processes the tensor in chunks defined by block size `BT`, and each block's result is stored in the output tensor `o`. The kernel uses a loop to iterate over these chunks, loading a block of data using `tl.make_block_ptr` and `tl.load`, performing cumulative sum with `tl.cumsum`, and updating a running sum with `tl.sum` to carry over between chunks. `chunk_global_cumsum_scalar` acts as the kernel's wrapper function, preparing and launching it with the appropriate grid configuration based on the input tensor's dimensions. It accepts a PyTorch tensor `s` and an optional `dtype`, defaulting to the dtype of `s`. It initializes an empty output tensor `z` with the same shape and dtype as `s` to hold the computation results. The grid configuration is determined by the batch size and head count, ensuring each program instance processes one sequence.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=4),\n ],\n key=[]\n)\n@triton.jit\ndef chunk_global_cumsum_scalar_kernel(\n s,\n o,\n T: tl.constexpr,\n BT: tl.constexpr,\n):\n i_bh = tl.program_id(0)\n b_z = tl.zeros([], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n p_o = tl.make_block_ptr(o + i_bh * T, (T,), (1,), (i_t * BT,), (BT,), (0,))\n b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32)\n b_o = tl.cumsum(b_s, axis=0) + b_z[None]\n b_zz = tl.sum(b_s, axis=0)\n b_z += b_zz\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,))\n\ndef chunk_global_cumsum_scalar(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T = s.shape\n dtype = dtype or s.dtype\n grid = (B * H,)\n z = torch.empty_like(s, dtype=dtype)\n chunk_global_cumsum_scalar_kernel[grid](\n s, z,\n T=T\n )\n return z\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-9273ceeae6d4466e8a13d62af5486773", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements an efficient ReLU operation using Triton, a library for writing high-performance GPU kernels. Here is a detailed breakdown:\n - The `heuristics_for_tile_size` function calculates appropriate tile sizes to partition the computational work. It does so by considering the maximum allowed tile size and the dimensions of the input data.\n - The `heuristics_for_num_warps` function determines the number of warps needed for efficient execution based on the calculated tile size. Smaller tile sizes require fewer warps, ensuring optimal resource use.\n - The `StridedBuffer` class is designed to work with tensors having custom strides. This is crucial for supporting operations that need non-standard memory layouts, like tensors with negative strides (e.g., flipping).\n - `relu_forward_wrapper_rank_1` acts as a dispatcher for the kernel execution. It calculates grid and block dimensions, sets up the kernel's parameters like strides, sizes, and launches the `relu_forward_kernel_rank_1`.\n - `relu_forward_kernel_rank_1` is the core Triton kernel that performs the ReLU computation on 1D tensors. It efficiently handles tasks using either a monolithic or grid-stride-loop approach, loading data, applying the ReLU function using `tl.where`, and storing the results back.\n Each function and class are crafted to integrate with Triton's execution model, maximizing GPU performance for the ReLU operation.\n ", + "output": [ + "import math\nfrom typing import Union\nimport torch\nimport triton\nfrom triton import language as tl\n\ndef heuristics_for_tile_size(max_tile_size, *sizes):\n ndim = len(sizes)\n tile_sizes = [0 for _ in range(ndim)]\n for i in range(ndim):\n size = sizes[ndim - 1 - i]\n tile_size = min(max_tile_size, triton.next_power_of_2(size))\n tile_sizes[ndim - 1 - i] = tile_size\n max_tile_size = max(1, max_tile_size // tile_size)\n return tuple(tile_sizes)\n\ndef heuristics_for_num_warps(tile_size):\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n \n\nclass StridedBuffer:\n \"\"\"A drop-in replacement of torch.Tensor that can be used in wrapper generated by\n PointwiseDynamicFunction. It allows us to use a different shape, stride, data\n pointer that that of the base tensor.\n\n It is a kind of reinterpretation of the base tensor. We make this class since we\n cannot get a Tensor view with negative strides via torch APIs, while we need this\n to implement flip op.\n\n Although generated code can accept torch.Tensor & StridedBuffer, but StridedBuffer\n may not have all the methods as torch.Tensors do. We add some attributes & methods\n with the same name as torch.Tensor, which are used in the generated code. But we\n may not cover all the methods, add one if what you need is missing here.\n\n And can also be used in triton kernels since it also has dtype & data_ptr().\n \"\"\"\n\n def __init__(\n self, base: torch.Tensor, shape=None, strides=None, dtype=None, offset=0\n ):\n self._base = base\n self.dtype = dtype or base.dtype\n if offset == 0:\n self._data_ptr = self._base.data_ptr()\n else:\n offset = self.dtype.itemsize * offset\n self._data_ptr = self._base.data_ptr() + offset\n self.shape = tuple(shape if shape is not None else self._base.shape)\n self._strides = tuple(strides if strides is not None else self._base.stride())\n self.device = self._base.device\n self.ndim = len(self.shape)\n\n def stride(self):\n return self._strides\n\n def size(self):\n return self.shape\n\n def element_size(self):\n return self.dtype.itemsize\n\n def numel(self):\n return math.prod(self.shape)\n\n def dim(self):\n return self.ndim\n\n def unwrap(self):\n return self._base\n\n def data_ptr(self):\n return self._data_ptr\n\n\n\ndef relu_forward_wrapper_rank_1(in0: Union[torch.Tensor, StridedBuffer], /, *, out0: Union[torch.Tensor, StridedBuffer]): \n \"\"\"Generated wrapper function with Pointwise: StridedBuffer, StridedBuffer(a1!) -> StridedBuffer(a1!)\"\"\"\n assert in0.shape == out0.shape, 'operand shapes mismatch'\n # task partitioning\n shape = out0.shape\n num_tasks = out0.numel()\n tile_sizes = heuristics_for_tile_size(512, *shape)\n tile_size = math.prod(tile_sizes)\n num_tiles = math.prod(triton.cdiv(size, tile_size) for size, tile_size in zip(shape, tile_sizes))\n num_ctas = min(65536, num_tiles)\n tiles_per_cta = triton.cdiv(num_tiles, num_ctas)\n num_warps = heuristics_for_num_warps(tile_size)\n one_tile_per_cta = tiles_per_cta==1\n grid = (num_ctas, 1, 1)\n # kernel launch\n in0_strides = in0.stride()\n in0_stride_order = (0,)\n out0_strides = out0.stride()\n out0_stride_order = (0,)\n with torch.cuda._DeviceGuard(in0.device.index):\n relu_forward_kernel_rank_1[grid](\n in0, out0,\n in0_strides[0], # stride for in0\n in0_stride_order[0], # stride order for in0\n out0_strides[0], # stride for out0\n out0_stride_order[0], # stride orderfor out0\n shape[0], # task indexing space\n num_tasks, # num tasks\n tiles_per_cta=tiles_per_cta, # tiles_per_cta\n tile_size0=tile_sizes[0],\n one_tile_per_cta=one_tile_per_cta,\n num_warps=num_warps,\n )\n return out0\n\n@triton.jit\ndef relu_forward(x):\n return tl.where(x > 0, x, 0)\n\n@triton.jit\ndef relu_forward_kernel_rank_1(\n in0_ptr: tl.tensor, # of tl.pointer_type\n out0_ptr: tl.tensor, # of tl.pointer_type\n in0_stride0: int, # strides for in0\n in0_stride_order0: tl.constexpr, # stride order for in0\n out0_stride0: int, # strides for out0\n out0_stride_order0: tl.constexpr, # stride order for out0\n s0: int, # task_space\n num_tasks: int,\n tiles_per_cta: int,\n tile_size0: tl.constexpr,\n one_tile_per_cta: tl.constexpr,\n):\n pid = tl.program_id(0)\n num_tiles0 = tl.cdiv(s0, tile_size0)\n if one_tile_per_cta: # monolitic kernel style\n tile_id = pid\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = relu_forward(in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n else: # grid-stride-loop style kernel\n num_ctas = tl.num_programs(0)\n for j in range(0, tiles_per_cta):\n tile_id = pid + j * num_ctas\n # pid multi index recontruction: we use c ordering, right axes changes fastest\n tile_id0 = tile_id\n\n # tile offsets\n offset0 = tile_id0 * tile_size0\n # loads\n in0_bptr = tl.make_block_ptr(in0_ptr, (s0,), (in0_stride0,), (offset0,), (tile_size0,), order=(in0_stride_order0,))\n in0 = tl.load(in0_bptr, boundary_check=(in0_stride_order0,)).to(in0_ptr.type.element_ty) # workaround the bug on bool, we should use the original pointer's dtype(instead of block pointer's)\n\n # compute\n out0 = relu_forward(in0)\n\n # stores, note that store to block pointer does not automatically cast the value to the pointer's dtype\n out0_bptr = tl.make_block_ptr(out0_ptr, (s0,), (out0_stride0,), (offset0,), (tile_size0,), order=(out0_stride_order0,))\n tl.store(out0_bptr, out0.to(out0_bptr.type.element_ty), boundary_check=(out0_stride_order0,))\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e039f950d23e4124990be4c14f23ea92", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel, `_rope_embedding`, and its associated Python wrapper function, `Fast_RoPE_Embedding`, are designed to compute the Rotary Positional Encoding (RoPE) on input matrices, particularly for query (`Q`) and key (`K`) matrices in transformer architectures. The `_rope_embedding` kernel is written in Triton, a language for writing efficient GPU compute kernels, and it leverages parallel execution to accelerate the RoPE computation. \n\n The kernel takes multiple parameters, including:\n - `Q`: the query tensor to be transformed.\n - `Q_row_stride`: the stride between rows in the `Q` tensor.\n - `cos`, `sin`: tensors containing precomputed cosine and sine values for rotation.\n - `cos_row_stride`, `sin_row_stride`: respective strides for cosine and sine tensors.\n - `seqlen`: sequence length constraint.\n - `head_dim`: dimension of each head in the multi-head attention mechanism.\n - `n_heads`: number of attention heads.\n - `BACKWARD_PASS`: a constant indicating if the backward pass is computed.\n - `BLOCK_SIZE`: the block size for Triton kernel execution, determining the parallel execution granularity.\n\n Within the kernel, each thread computes parts of the RoPE transformation for segments of the input tensor. The operation involves computing `Q * cos + rotate_half(Q) * sin`, where `rotate_half(Q)` denotes a transformation involving half-head dimension offsets.\n\n The `Fast_RoPE_Embedding` class manages the forward and backward passes for autograd in PyTorch. In the `forward` method, input tensors are reshaped, processed by `_rope_embedding`, and reshaped back to their original form. The backward method processes gradients, passing them through the RoPE transformation for backpropagation.\n\n Finally, the `fast_rope_embedding` function applies this transformation to both query and key matrices, ensuring compatibility with transformer-based models. It utilizes the `transpose` operation to ensure correct dimensions for the RoPE computation.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\nROPE_GROUP_SIZE = 4\nMAX_FUSED_SIZE : int = 65536\n\ndef calculate_settings(n : int) -> (int, int,):\n BLOCK_SIZE : int = triton.next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n num_warps : int = 4\n if BLOCK_SIZE >= 32768: num_warps = 32\n elif BLOCK_SIZE >= 8192: num_warps = 16\n elif BLOCK_SIZE >= 2048: num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef _rope_embedding(\n Q, Q_row_stride,\n cos, cos_row_stride,\n sin, sin_row_stride,\n seqlen,\n head_dim : tl.constexpr,\n n_heads : tl.constexpr,\n BACKWARD_PASS : tl.constexpr,\n BLOCK_SIZE : tl.constexpr,\n):\n \"\"\"\n Calculates the RoPE Embedding quickly\n RoPE is Q * cos + rotate_half(Q) * sin\n See our blog post for more info\n \"\"\"\n ROPE_GROUP_SIZE = 4\n row_position = tl.program_id(0)\n group_head_position = tl.program_id(1)\n col_offsets = tl.arange(0, BLOCK_SIZE)\n half_head_dim = head_dim // 2\n mask = col_offsets < half_head_dim\n\n sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \\\n half_head_dim*0 + col_offsets, mask = mask, other = 0)\n\n if BACKWARD_PASS:\n # See our blog post for more info.\n sin1 = -sin1\n \n # [TODO] Autotune ROPE_GROUP_SIZE to be 1, 2, 4, 8\n head_start = group_head_position * ROPE_GROUP_SIZE\n head_end = min((head_start + ROPE_GROUP_SIZE), n_heads)\n\n # 10% Faster kernel from [HuyNguyen-hust](https://github.com/unslothai/unsloth/pull/238)\n for k in range(head_start, head_end):\n offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets\n offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim\n\n # For Gemma - sometimes RoPE must be done in float32 and not bfloat16\n Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)\n Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)\n\n tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)\n tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)\n \n\nclass Fast_RoPE_Embedding(torch.autograd.Function):\n @staticmethod\n def forward(ctx, Q, cos, sin):\n cos, sin = cos.squeeze(), sin.squeeze()\n batch, seq_len, n_heads, head_dim = Q.shape\n Q = Q.reshape(batch*seq_len, n_heads*head_dim)\n n_rows, n_cols = Q.shape\n assert(seq_len <= cos.shape[0])\n\n # [TODO] Changing blocksize to head_dim//2 seems to have\n # some concurrency / un-deterministic issues.\n BLOCK_SIZE, num_warps = calculate_settings(head_dim//2) # (head_dim//2)\n \n # group_size = 4 # 4 or 8, too large group_size can hurt performance.\n div, mod = divmod(n_heads, ROPE_GROUP_SIZE)\n n_groups = div + (mod != 0)\n\n _rope_embedding[(n_rows, n_groups, )](\n Q, Q.stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len,\n head_dim, n_heads,\n BACKWARD_PASS = False,\n BLOCK_SIZE = BLOCK_SIZE,\n num_warps = num_warps,\n )\n ctx.BLOCK_SIZE = BLOCK_SIZE\n ctx.num_warps = num_warps\n ctx.n_groups = n_groups\n ctx.cos = cos\n ctx.sin = sin\n return Q.view(batch, seq_len, n_heads, head_dim)\n \n @staticmethod\n def backward(ctx, dY):\n batch, seq_len, n_heads, head_dim = dY.shape\n dY = dY.reshape(batch*seq_len, n_heads*head_dim)\n # Must be reshape not view\n n_rows, n_cols = dY.shape\n\n cos = ctx.cos\n sin = ctx.sin\n\n _rope_embedding[(n_rows, ctx.n_groups, )](\n dY, dY .stride(0),\n cos, cos.stride(0),\n sin, sin.stride(0),\n seq_len, head_dim, n_heads,\n BACKWARD_PASS = True,\n BLOCK_SIZE = ctx.BLOCK_SIZE,\n num_warps = ctx.num_warps,\n )\n dY = dY.view(batch, seq_len, n_heads, head_dim)\n return dY, None, None,\n \n\ndef fast_rope_embedding(Q, K, cos, sin):\n Q = Fast_RoPE_Embedding.apply(Q.transpose(1, 2), cos, sin).transpose(1, 2)\n K = Fast_RoPE_Embedding.apply(K.transpose(1, 2), cos, sin).transpose(1, 2)\n return Q, K\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e718e11063714bd1afedddf558cdafbe", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe `chunk_global_reversed_cumsum_vector` function performs a reversed cumulative sum on a 4D tensor `s` using a custom Triton kernel, suitable for batched operations with attention-like structures. The function begins by extracting the dimensions of the input tensor [B, H, T, S] and sets a spatial block size `BS` of 32. The output tensor `z` is initialized with the same shape as `s`, but with the specified `dtype`. The kernel `chunk_global_reversed_cumsum_vector_kernel` is configured to optimize over various block sizes (`BT`) and numbers of warps to efficiently handle different input sizes. Within the kernel, a loop iterates backwards over time blocks determined by `BT`, starting from the end (`tl.cdiv(T, BT) - 1`) to the beginning. A block pointer is created for both the input and output tensors. The kernel loads the input block `b_s`, computes a masked dot product with a lower triangular mask `m_s` to ensure the cumulative sum is reversed, and stores the result in the corresponding block of the output tensor `z`. This operation ensures only the elements at or after the current time step are included. The cumulative sum `b_z` is updated after processing each block, adding the sum of the current block `b_s` along the feature dimension. The kernel ensures boundary checks during load and store operations to manage out-of-bounds memory accesses. The result is returned in `z` after the kernel execution, representing the reversed cumulative sum across the specified dimension.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_global_reversed_cumsum_vector_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] <= o_i[None, :], 1., 0.)\n\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT) - 1, -1, -1):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\ndef chunk_global_reversed_cumsum_vector(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_global_reversed_cumsum_vector_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\n\n\n" + ] + }, + { + "id": "openseek-8-0ad7d8e8cf6e43cbbe395404701e785b", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton operator `logsumexp_fwd_kernel` calculates the log-sum-exp over the last dimension of an input tensor. This is often used to stabilize the computation of the exponential function on input elements, especially with large values. The kernel operates in parallel using Triton, leveraging the GPU architecture for efficiency.\n\n The `logsumexp_fwd_kernel` kernel is decorated with `triton.autotune` to optimize its execution across different hardware configurations, particularly varying the number of warps between 1 and 32. It uses a heuristic `HAS_SCALE` to determine whether the input tensor should be scaled.\n\n Within the kernel, we calculate the indices `i_n` and `i_d` using Triton's `program_id`, which represent the current program's block IDs. The indices `o_d` and `m_d` determine the positions to operate on and apply masking to avoid out-of-bounds memory accesses. The input tensor `x` is loaded into a block `b_x`, potentially scaled, and the maximum value `b_m` is computed along the block dimension. The kernel then calculates the log of the sum of exponentials, adds the maximum value `b_m`, and stores the result in `z`.\n\n The `logsumexp_fwd` function prepares the input tensor by reshaping it and calculates `N`, `D`, and `B`. It determines `ND`, the division of `D` by `B`, creating an empty tensor `z` to hold results. The kernel is invoked with `(N, ND)` blocks. After computation, `z` is reduced along the last dimension to produce the final result. If a different output data type is desired, it casts `z` to this type before returning it.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=['D']\n)\n@triton.heuristics({\n 'HAS_SCALE': lambda args: args['scale'] is not None\n})\n@triton.jit\ndef logsumexp_fwd_kernel(\n x,\n z,\n scale,\n D: tl.constexpr,\n B: tl.constexpr,\n HAS_SCALE: tl.constexpr\n):\n i_n, i_d = tl.program_id(0).to(tl.int64), tl.program_id(1).to(tl.int64)\n o_d = i_d * B + tl.arange(0, B)\n m_d = o_d < D\n\n b_x = tl.load(x + i_n * D + o_d, mask=m_d, other=-float('inf'))\n if HAS_SCALE:\n b_x = b_x * scale\n b_m = tl.max(b_x, 0)\n b_z = tl.log(tl.sum(tl.exp(b_x - b_m), 0)) + b_m\n tl.store(z + i_n * tl.cdiv(D, B) + i_d, b_z)\n\ndef logsumexp_fwd(\n x,\n scale: Optional[float] = None,\n dtype: Optional[torch.dtype] = None\n):\n r\"\"\"\n Compute the logsumexp of the input tensor over the last dimension.\n\n Args:\n x (Tensor):\n The input tensor of any shape.\n scale (Optional[float]):\n The scale applied to the input tensor. Default: `None`.\n dtype (Optional[torch.dtype]):\n The data type of the output tensor. Default: `None`.\n Returns:\n Tensor: The logsumexp of the input tensor.\n \"\"\"\n\n shape = x.shape\n x = x.view(-1, shape[-1])\n N, D = x.shape\n B = min(triton.next_power_of_2(D), 64 * 1024)\n ND = triton.cdiv(D, B)\n\n z = x.new_empty(N, ND, dtype=torch.float)\n logsumexp_fwd_kernel[(N, ND)](\n x=x,\n z=z,\n scale=scale,\n D=D,\n B=B\n )\n z = z.logsumexp(-1).view(*shape[:-1])\n if dtype is not None and dtype != torch.float:\n z = z.to(dtype)\n return z\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1ce5399a6c2240e99f7e6103e58b91b2", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-based custom operation consists of several key components designed to efficiently perform parallel computations on batch tensor data, often seen in transformer models. The main kernel functions are `parallel_rebased_fwd_kernel` and `parallel_rebased_bwd_kernel`. The forward kernel (`parallel_rebased_fwd_kernel`) calculates attention scores by performing matrix multiplications between query (`q`), key (`k`), and value (`v`) tensors, with scaling and optional normalization, and outputs the result into tensor `o`. An additional tensor `z` holds a normalization factor computed from the scores. This kernel leverages Triton's block-wise parallelism by dividing the computations into blocks of size defined by `BTL`, `BTS`, `BK`, and `BV`.\n\n The backward pass is handled by `parallel_rebased_bwd_kernel`, which calculates gradients for input tensors using previously computed scores and the derivative outputs (`do` and `dz`). This is further divided into two helper functions `_parallel_rebased_bwd_dq` and `_parallel_rebased_bwd_dkv` to compute gradients for queries (`dq`) and key-value pairs (`dk`, `dv`) respectively. These functions utilize the block pointer API to efficiently handle memory access patterns for large tensor dimensions.\n\n The Triton kernel operations are encapsulated within the `ParallelBasedFunction` class that inherits from `torch.autograd.Function`, allowing for automatic differentiation. The `forward` method initializes the output tensors and launches the forward kernel. The `backward` method prepares the gradient tensors and invokes the backward kernel, returning gradients with respect to the inputs.\n\n The `parallel_rebased` function is a user-facing API that sets up and applies the custom operation. It supports optional scaling (`use_scale`) and normalization (`use_normalize`) of outputs, with configurable options such as `return_both` to return both output and normalization factors. It asserts that the feature dimension does not exceed 128, ensuring compatibility with the defined Triton grid and block sizes.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef parallel_rebased_fwd_kernel(\n q, # query [B, H, L, D_head_K]\n k, # key [B, H, L, D_head_V]\n v, # value [B, H, L, D_head_V]\n o, # output [B, H, L, D_head_V]\n z, # normalizer [B, H, L]\n s_qk_h, # stride size: L * D_head_K\n s_qk_t, # stride size: D_head_K\n s_qk_d, # stride size: 1\n s_vo_h, # stride size: L * D_head_V\n s_vo_t, # stride size: D_head_V\n s_vo_d, # stride size: 1\n scale, # D_head_K ** -0.5\n B, # batch size\n H, # H\n T, # T\n K: tl.constexpr, # D_head_K\n V: tl.constexpr, # D_head_V\n BTL: tl.constexpr, # BLOCK SIZE along the sequence dimension for Q\n BTS: tl.constexpr, # BLOCK SIZE along the sequence dimension for K/V\n BK: tl.constexpr, # BLOCK SIZE along the K dimension\n BV: tl.constexpr, # BLOCK SIZE along the V dimension\n):\n # i_c: chunk index. used for sequence parallelism\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(V, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, 0), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (0, i_v * BV), (BTS, BV), (1, 0))\n\n # [BQ, BD] block Q, in the shared memory throughout the whole kernel\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_o = tl.zeros([BTL, BV], dtype=tl.float32)\n b_z = tl.zeros([BTL], dtype=tl.float32)\n\n # Q block and K block have no overlap\n # no need for mask, thereby saving flops\n for _ in range(0, i_c * BTL, BTS):\n # [BK, BTS]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n\n # [BTS, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n b_s = tl.dot(b_q, (b_k), allow_tf32=False)\n b_s = b_s * b_s\n b_z += tl.sum(b_s, axis=1)\n\n # [BQ, BD]\n b_o = b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n\n # # rescale interchunk output\n tl.debug_barrier()\n o_q = tl.arange(0, BTL)\n # # sync threads, easy for compiler to optimize\n # tl.debug_barrier()\n\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_c * BTL), (BK, BTS), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c * BTL, i_v * BV), (BTS, BV), (1, 0))\n # Q block and K block have overlap. masks required\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n # [BK, BTS]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BTS, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n m_s = o_q[:, None] >= o_k[None, :]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n b_s = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_z += tl.sum(b_s, axis=1)\n # [BTL, BV]\n b_o += tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n p_k = tl.advance(p_k, (0, BTS))\n p_v = tl.advance(p_v, (BTS, 0))\n o_k += BTS\n\n p_o = tl.make_block_ptr(o + (i_bh + B * H * i_k) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n p_z = z + (i_bh + B * H * i_k) * T + i_c * BTL + tl.arange(0, BTL)\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_z, b_z.to(p_z.dtype.element_ty),\n mask=((i_c * BTL + tl.arange(0, BTL)) < T))\n\n@triton.jit\ndef _parallel_rebased_bwd_dq(\n i_bh,\n i_c,\n i_k,\n i_v,\n i_h,\n q,\n k,\n v,\n do,\n dz,\n dq,\n s_k_h,\n s_k_t,\n s_k_d,\n s_v_h,\n s_v_t,\n s_v_d,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BTL: tl.constexpr,\n BTS: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n p_do = tl.make_block_ptr(do + i_bh * s_v_h, (T, V), (s_v_t, s_v_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n p_q = tl.make_block_ptr(q + (i_bh) * s_k_h, (T, K),\n (s_k_t, s_k_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_q = (b_q * scale).to(b_q.dtype)\n b_dq = tl.zeros([BTL, BK], dtype=tl.float32)\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K),\n (s_k_t, s_k_d), (0, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T),\n (s_v_d, s_v_t), (i_v * BV, 0), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i_c * BTL + tl.arange(0, BTL)\n b_dz = tl.load(p_dz, mask=(i_c * BTL + tl.arange(0, BTL)) < T)\n\n for _ in range(0, i_c * BTL, BTS):\n # [BTS, BK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BV, BTS]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n # [BQ, BD]\n b_dq += tl.dot((2 * b_ds * b_s).to(b_v.dtype), b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n\n b_dq *= scale\n o_q = tl.arange(0, BTL)\n o_k = tl.arange(0, BTS)\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K),\n (s_k_t, s_k_d), (i_c * BTL, i_k * BK), (BTS, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (V, T),\n (s_v_d, s_v_t), (i_v * BV, i_c * BTL), (BV, BTS), (0, 1))\n # Q block and K block have overlap. masks required\n for _ in range(i_c * BTL, (i_c + 1) * BTL, BTS):\n # [BTS, BK]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BV, BTS]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BTL, BTS]\n m_s = o_q[:, None] >= o_k[None, :]\n b_ds = tl.dot(b_do, b_v, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[:, None]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n b_s = tl.dot(b_q, tl.trans(b_k), allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n # [BTL, BK]\n b_dq += tl.dot((2 * b_ds * b_s).to(b_k.dtype),\n b_k, allow_tf32=False)\n p_k = tl.advance(p_k, (BTS, 0))\n p_v = tl.advance(p_v, (0, BTS))\n o_k += BTS\n p_dq = tl.make_block_ptr(dq + (i_bh + B * H * i_v) * s_k_h, (T, K),\n (s_k_t, s_k_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_k_h, s_k_t, s_k_d, s_v_h,\n s_v_t, s_v_d,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BTL: tl.constexpr,\n BTS: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n):\n # compute dk dv\n p_k = tl.make_block_ptr(k + i_bh * s_k_h, (T, K), (s_k_t, s_k_d),\n (i_c * BTL, i_k * BK), (BTL, BK), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_v_h, (T, V), (s_v_t, s_v_d),\n (i_c * BTL, i_v * BV), (BTL, BV), (1, 0))\n b_k, b_v = tl.load(p_k, boundary_check=(0, 1)), tl.load(\n p_v, boundary_check=(0, 1))\n b_dk, b_dv = tl.zeros([BTL, BK], dtype=tl.float32), tl.zeros(\n [BTL, BV], dtype=tl.float32)\n\n for i in range((tl.cdiv(T, BTS) * BTS)-BTS, (i_c + 1) * BTL - BTS, -BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1)) # [BK, BTS]\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype) # [BV, BTS]\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n b_s = tl.dot(b_k.to(b_q.dtype), b_q, allow_tf32=False) * \\\n scale # [BTL, BTS]\n b_s2 = b_s * b_s\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_ds = tl.dot(b_v, b_do, allow_tf32=False) * scale\n if i_v == 0:\n b_ds += b_dz[None, :] * scale\n else:\n b_ds = b_ds\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n\n tl.debug_barrier()\n o_q, o_k = tl.arange(0, BTS), tl.arange(0, BTL)\n for i in range(i_c*BTL, (i_c+1)*BTL, BTS):\n p_q = tl.make_block_ptr(\n q + i_bh * s_k_h, (K, T), (s_k_d, s_k_t), (i_k * BK, i), (BK, BTS), (0, 1))\n p_do = tl.make_block_ptr(\n do + i_bh * s_v_h, (V, T), (s_v_d, s_v_t), (i_v * BV, i), (BV, BTS), (0, 1))\n p_dz = dz + i_bh * T + i + tl.arange(0, BTS)\n b_q = tl.load(p_q, boundary_check=(0, 1)) # [BD, BQ]\n b_do = tl.load(p_do, boundary_check=(0, 1)).to(b_q.dtype)\n b_dz = tl.load(p_dz, mask=(i + tl.arange(0, BTS)) < T)\n # [BK, BQ]\n m_s = o_k[:, None] <= o_q[None, :]\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s2 = b_s * b_s\n b_s = tl.where(m_s, b_s, 0)\n b_s2 = tl.where(m_s, b_s2, 0)\n\n b_ds = tl.dot(b_v, b_do, allow_tf32=False)\n if i_v == 0:\n b_ds += b_dz[None, :]\n else:\n b_ds = b_ds\n b_ds = tl.where(m_s, b_ds, 0) * scale\n # [BK, BD]\n b_dv += tl.dot(b_s2.to(b_q.dtype), tl.trans(b_do), allow_tf32=False)\n b_dk += tl.dot((2 * b_ds * b_s).to(b_q.dtype),\n tl.trans(b_q), allow_tf32=False)\n o_q += BTS\n\n p_dk = tl.make_block_ptr(dk + (i_bh + B * H * i_v) * s_k_h,\n (T, K), (s_k_t, s_k_d), (i_c*BTL, i_k*BK), (BTL, BK), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_bh + B * H * i_k) * s_v_h,\n (T, V), (s_v_t, s_v_d), (i_c*BTL, i_v*BV), (BTL, BV), (1, 0))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n return\n\n@triton.jit\ndef parallel_rebased_bwd_kernel(\n q,\n k,\n v,\n do,\n dz,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n scale,\n B: tl.constexpr,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BTL: tl.constexpr,\n BTS: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_kv, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n NV = tl.cdiv(V, BV)\n i_k = i_kv // (NV)\n i_v = i_kv % (NV)\n i_h = i_bh % H\n _parallel_rebased_bwd_dq(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dq, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d, scale,\n B=B, H=H, T=T, K=K, V=V, BTL=BTL, BTS=BTS, BK=BK, BV=BV\n )\n tl.debug_barrier()\n _parallel_rebased_bwd_dkv(\n i_bh, i_c, i_k, i_v, i_h,\n q, k, v, do, dz, dk, dv, s_qk_h, s_qk_t, s_qk_d, s_vo_h,\n s_vo_t, s_vo_d,\n scale,\n B=B, H=H, T=T, K=K, V=V, BTL=BTL, BTS=BTS, BK=BK, BV=BV\n )\n\nclass ParallelBasedFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, scale):\n BTL, BTS = 128, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n B, H, T, K, V = *k.shape, v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n grid = (NK * NV, triton.cdiv(T, BTL), B * H)\n\n assert NK == 1, \"will encounter some synchronization issue if not.\"\n\n o = torch.empty(NK, B, H, T, V, device=q.device)\n z = torch.empty(NK, B, H, T, device=q.device)\n parallel_rebased_fwd_kernel[grid](\n q, k, v, o, z,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v)\n ctx.scale = scale\n return o.sum(0).to(q.dtype), z.sum(0).to(q.dtype)\n\n @staticmethod\n def backward(ctx, do, dz):\n q, k, v = ctx.saved_tensors\n scale = ctx.scale\n BTL, BTS = 64, 32\n assert BTL % BTS == 0\n BK = min(128, triton.next_power_of_2(k.shape[-1]))\n BV = min(128, triton.next_power_of_2(v.shape[-1]))\n BK, BV = max(BK, 16), max(BV, 16)\n B, H, T, K, V = *k.shape, v.shape[-1]\n num_stages = 2\n num_warps = 4\n NK = triton.cdiv(K, BK)\n NV = triton.cdiv(V, BV)\n grid = (NK * NV, triton.cdiv(T, BTL), B * H)\n\n assert NK == 1, \"will encounter some synchronization issue if not\"\n\n dq = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)\n dk = torch.empty(NV, B, H, T, K, dtype=q.dtype, device=q.device)\n dv = torch.empty(NK, B, H, T, V, dtype=q.dtype, device=q.device)\n\n parallel_rebased_bwd_kernel[grid](\n q, k, v, do, dz, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n scale,\n B=B, H=H, T=T, K=K, V=V,\n BTL=BTL, BTS=BTS, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n return dq.sum(0).to(q.dtype), dk.sum(0).to(k.dtype), dv.sum(0).to(v.dtype), None\n\ntriton_parallel_based = ParallelBasedFunction.apply\n\ndef parallel_rebased(q, k, v, eps=1e-5, use_scale=True, use_normalize=True, return_both=False):\n assert q.shape[-1] <= 128, \"only support feature dim up to 128\"\n if use_scale:\n scale = q.shape[-1] ** -0.5\n else:\n scale = 1\n o, z = triton_parallel_based(q, k, v, scale)\n if return_both:\n return o, z\n if use_normalize:\n o = o / (z[..., None] + eps)\n else:\n o = o\n return o.to(q.dtype)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-24e7277ab16c4c85852750c5253e2780", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code implements two optimized GPU kernels using Triton: `matmul4_kernel` and `dequantize_kernel`. Both kernels leverage Triton's JIT compilation and autotuning capabilities to efficiently handle specific matrix operations commonly used in neural network computations, especially those involving quantized weights.\n\n 1. `matmul4_kernel`:\n - Purpose: Perform a quantized matrix multiplication where matrix `B` is in a 4-bit integer format, while `A` and `C` are in `float16`.\n - Inputs: \n - `a_ptr`, `b_ptr`, `c_ptr`: Pointers to matrices `A`, `B`, and `C`.\n - `scales_ptr`, `zeros_ptr`: Pointers to scale and zero-point matrices for quantization.\n - `M`, `N`, `K`: Dimensions of the matrices.\n - `stride_am`, `stride_ak`, `stride_bk`, `stride_bn`, `stride_cm`, `stride_cn`: Strides for memory access.\n - `stride_scales_g`, `stride_scales_n`, `stride_zeros_g`, `stride_zeros_n`: Strides for scales and zeros.\n - `groupsize`, `NO_GROUPS`: Meta-parameters for group handling.\n - Operation: Multiplies `A` with a dequantized version of `B`, using provided scales and zeros to adjust the values. It operates on blocks of size `BLOCK_SIZE_M x BLOCK_SIZE_N x BLOCK_SIZE_K`.\n - Characteristics: Handles grouped scaling efficiently by preloading scales and zero-points, and decomposes the operation into several small matrix multiplies over the specified block sizes.\n\n 2. `dequantize_kernel`:\n - Purpose: Unpacks and dequantizes a 4-bit quantized matrix `B` to full precision.\n - Inputs:\n - `b_ptr`: Pointer to the int4 quantized weights.\n - `b_scale_ptr`, `b_zp_ptr`: Pointers to scale and zero-point matrices.\n - `fpb_ptr`: Pointer to store the dequantized full precision result.\n - `K`, `N`, `group_size`: Dimensions and grouping size.\n - `stride_bk`, `stride_bn`, `stride_bsk`, `stride_bsn`, `stride_bzpk`, `stride_bzpn`, `stride_fpbk`, `stride_fpbn`: Strides for respective matrices.\n - Operation: Converts packed 4-bit integers into full-precision floats using group-wise scales and zero-points, processing each tile of size `BLOCK_SIZE_K x BLOCK_SIZE_N`.\n\n Additional Functions:\n - `dequantize_int4`: A helper function using `dequantize_kernel` to convert an int4 matrix to float.\n - `matmul_dequantize_int4_s1`: Dequantizes matrix `B` using `dequantize_int4` and performs a matrix multiplication with `A`.\n - `quantize_int4`: Quantizes a given weight matrix into a 4-bit format, preparing it for use in the above kernels. It calculates scales and zero-points based on groups and packs values into `int32`.\n\n The code is structured to facilitate the handling of quantized matrices, optimizing for modern hardware's parallel execution capabilities. The kernels are auto-tuned for varying hardware specifications, ensuring efficient execution.\n ", + "output": [ + "import time\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n\tconfigs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), \n ],\n\tkey=['M', 'N', 'K', 'NO_GROUPS'],\n)\n@triton.jit\ndef matmul4_kernel(\n\ta_ptr, b_ptr, c_ptr,\n\tscales_ptr, zeros_ptr,\n\tM, N, K,\n\tstride_am, stride_ak,\n\tstride_bk, stride_bn,\n\tstride_cm, stride_cn,\n\tstride_scales_g, stride_scales_n,\n\tstride_zeros_g, stride_zeros_n,\n\tgroupsize, NO_GROUPS: tl.constexpr,\n\tBLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n\tGROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N//8) int32\n groupsize is an int specifying the size of groups for scales and zeros.\n G is K // groupsize.\n Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient.\n WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K.\n WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N.\n WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K.\n \"\"\"\n bits = 4\n infearure_per_bits = 8\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,)\n # zeros_ptrs is set up such that it repeats elements along the N axis 8 times\n zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,)\n # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n # If G == 1, scales and zeros are the same for all K, so we can load them once\n if NO_GROUPS:\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32\t\n # Unpack zeros\n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n # zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16\n zeros = zeros * scales\n # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N)\n # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension\n # So this loop is along the infeatures dimension (K)\n # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n if not NO_GROUPS:\n g_id = k // (groupsize // BLOCK_SIZE_K)\n ptr = scales_ptrs + g_id * stride_scales_g\n scales = tl.load(ptr) # (BLOCK_SIZE_N,)\n ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,)\n zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32\t\n # Unpack zeros\n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16\t\n # Now we need to unpack b (which is 4-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values\n b = b * scales[None, :] - zeros[None, :] # Scale and shift\n # print(\"data type\", a, b)\n accumulator += tl.dot(a, b.to(a.dtype))\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk \n c = accumulator.to(c_ptr.dtype.element_ty) \n # Store the result\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n ],\n key=['K', 'N'],\n)\n@triton.jit\ndef dequantize_kernel(\n # Pointers to matrices\n b_ptr, b_scale_ptr, b_zp_ptr, fpb_ptr,\n # Matrix dimensions\n K, N, group_size,\n stride_bk, stride_bn,\n stride_bsk, stride_bsn,\n stride_bzpk, stride_bzpn,\n stride_fpbk, stride_fpbn,\n # Meta-parameters\n BLOCK_SIZE_K: tl.constexpr, BLOCK_SIZE_N: tl.constexpr,\n):\n \"\"\"Dequantize tile [BLOCK_SIZE_K, BLOCK_SIZE_N] in full precision.\n We should assert BLOCK_SIZE_N % 8 == 0.\n weight[K // 8, N], scale[K // group_size, N], zp[K // group_size, N // group_size]\n \"\"\"\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = k_block_idx * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n offs_n = n_block_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n fpb_offs = offs_k[:, None] * stride_fpbk + offs_n[None, :] * stride_fpbn\n b_offs = (offs_k[:, None] // 8) * stride_bk + offs_n[None, :] * stride_bn\n bzp_offs = (offs_k[:, None] // group_size) * stride_bzpk + (offs_n[None, :] // 8) * stride_bzpn\n bs_offs = (offs_k[:, None] // group_size) * stride_bsk + offs_n[None, :] * stride_bsn\n n_mask = offs_n[None, :] < N\n k_mask = offs_k[:, None] < K\n mask = n_mask & k_mask\n int32_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n zp_b = tl.load(b_zp_ptr + bzp_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=mask, other=0.0)\n b_shift = (offs_k[:, None] % 8) * 4\n bzp_shift = (offs_n[None, :] % 8) * 4\n fp_weight = (((int32_b >> b_shift) & 0xF) - ((zp_b >> bzp_shift) & 0xF)) * scale_b\n tl.store(fpb_ptr + fpb_offs, fp_weight, mask=mask)\n\n\ndef dequantize_int4(b, b_scale, b_zero_point, device, dtype, group_size):\n Kw, N = b.shape\n K = Kw * 8\n fp_b = torch.ones((K, N), device=device, dtype=dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']),\n triton.cdiv(N, META['BLOCK_SIZE_N']), \n )\n dequantize_kernel[grid](\n b, b_scale, b_zero_point, fp_b,\n K, N, group_size,\n b.stride(0), b.stride(1),\n b_scale.stride(0), b_scale.stride(1),\n b_zero_point.stride(0), b_zero_point.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n return fp_b\n\n\ndef matmul_dequantize_int4_s1(a, b, b_scale, b_zero_point, group_size=128, out=None):\n \"\"\"\n Matmul dequantize int4 s1 dequantize weight to `fp_b` and do fp16 torch.mm,\n this is for `prefill` stage, since weight size is fixed so is dequantize overhead,\n perfill stage have more tokens to amortize dequant cost.\n \"\"\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n # assert b.is_contiguous(), \"Matrix B must be contiguous\"\n M, K = a.shape\n Kw, N = b.shape\n if out is None:\n # Allocates output.\n out = torch.empty((M, N), device=a.device, dtype=a.dtype)\n fp_b = dequantize_int4(b, b_scale, b_zero_point, a.device, a.dtype, group_size)\n torch.mm(a, fp_b, out=out)\n fp_b = None\n return out\n\n\ndef quantize_int4(weight, group_size=128, tp_rank=0):\n # Weight shape: [H1 // 8, H2]\n # Scale shape: [H1 // group_size, H2]\n # zero_pint shape: [H1 // group_size, H2 // 8]\n\n weight = weight.transpose(1, 0)\n h1, h2 = weight.shape\n assert h1 % 8 == 0 and h2 % 8 == 0, \"H1 {} H2 {}\".format(h1, h2)\n assert h2 % group_size == 0, \"H1 {} H2 {}\".format(h1, h2)\n weight = weight.contiguous().view(-1, group_size).cuda(tp_rank)\n weight_max = weight.amax(-1, keepdim=True)\n weight_max = torch.where(weight_max < 0, 0, weight_max)\n weight_min = weight.amin(-1, keepdim=True)\n weight_min = torch.where(weight_min > 0, 0, weight_min)\n weight_range = weight_max - weight_min \n scale = weight_range / (2 ** 4 - 1)\n zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32)\n weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2)\n int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device)\n int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device)\n zero_point = zero_point.view(h1, -1)\n scale = scale.view(h1, -1)\n # pack 8 int4 in an int32 number.\n # Weight pack in row.\n for pack in range(0, h2, 8):\n for i in range(8):\n int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4)\n # zero point pack in col.\n for pack in range(0, h1, 8):\n for i in range(8):\n int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4)\n '''\n fp_weight = torch.zeros(h1, h2).half().to(weight.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_weight[pack * 8 + i, :] = \\\n ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16\n print((fp_weight - weight).abs().sum())\n\n fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zp[pack * 8 + i, :] = \\\n (int_zero_point[pack, :] >> (i * 4)) & 15\n\n print((fp_zp - zero_point).abs().sum())\n '''\n weight = None\n return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size\n\n\n\n" + ] + }, + { + "id": "openseek-8-34bd0e3b8b7b41efb6c21f3b1b591b5d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel '_fwd_kernel_destindex_copy_quantize_kv' takes input tensors 'K' and 'Dest_loc', along with output tensors 'Out' and 'Out_scale'. It processes elements in blocks defined by 'BLOCK_HEAD' and 'BLOCK_DMODEL' which are powers of two based on dimensions of 'K'. The kernel computes per-block maximum absolute values to determine scales for quantization, scales the input data, converts it to int8, and stores both the quantized data and corresponding scales to the output tensors. Indices from 'Dest_loc' determine positions in the output where processed blocks are stored. The 'destindex_copy_quantize_kv' function is a no-gradient PyTorch wrapper that calls the Triton kernel, initializing execution parameters such as grid size from 'seq_len', determining block sizes, and specifying constraints on warps and stages.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_quantize_kv(\n K,\n Dest_loc,\n Out,\n Out_scale,\n stride_k_bs,\n stride_k_h,\n stride_k_d,\n stride_o_bs,\n stride_o_h,\n stride_o_d,\n stride_os_bs,\n stride_os_h,\n stride_os_d,\n head_num,\n head_dim,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr,\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n src_data = tl.load(\n K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],\n mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim),\n other=0.0,\n )\n abs_data = tl.abs(src_data)\n data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None]\n q_src_data = (src_data / data_scale).to(tl.int8)\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]\n tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim))\n tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num))\n\n\n@torch.no_grad()\ndef destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n BLOCK_DMODEL = triton.next_power_of_2(head_dim)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_quantize_kv[grid](\n K,\n DestLoc,\n Out,\n Out_scale,\n K.stride(0),\n K.stride(1),\n K.stride(2),\n Out.stride(0),\n Out.stride(1),\n Out.stride(2),\n Out_scale.stride(0),\n Out_scale.stride(1),\n Out_scale.stride(2),\n head_num,\n head_dim,\n BLOCK_DMODEL=BLOCK_DMODEL,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-625b0d4fbe574ff69998317578a3afa6", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `rotary_kernel` function is a Triton kernel that performs rotary position encoding on a tensor `X` using precomputed cosine (`COS`) and sine (`SIN`) matrices. It modifies or populates the output tensor `OUT` with the transformed data. The kernel accommodates both fixed and variable sequence lengths, controlled by the presence of `CU_SEQLENS`. The kernel handles interleaved and non-interleaved formats and allows for in-place transformations and conjugate computations if specified.\n\n The kernel operates in a three-dimensional grid, processing batches (`pid_batch`), heads (`pid_head`), and sequences (`pid_m`). It calculates transformations by loading blocks of data and applying rotary transformations based on cosine and sine values. The key operations are tailored based on whether the data is interleaved or not, with conditional handling for conjugation using `CONJUGATE`.\n\n The `apply_rotary` function acts as a high-level interface to the Triton kernel. It accepts the input tensor `x`, cosine and sine matrices, sequence length offsets, and optional cumulative sequence lengths (`cu_seqlens`). The function determines the execution grid and block sizes, aligning them with the input data shape and configuration. It initializes an output tensor, copying non-rotary parts of `x` if required. The function ensures that the kernel is called with appropriate arguments, matching the shape and type expectations set within the kernel logic. This design allows for efficient rotary transformations in transformer architectures.\n ", + "output": [ + "from typing import Optional, Union\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef rotary_kernel(\n OUT, X, COS, SIN, CU_SEQLENS, SEQLEN_OFFSETS, seqlen, nheads, rotary_dim, seqlen_ro,\n CACHE_KEY_SEQLEN, stride_out_batch, stride_out_seqlen, stride_out_nheads,\n stride_out_headdim, stride_x_batch, stride_x_seqlen, stride_x_nheads,\n stride_x_headdim, BLOCK_K: tl.constexpr, IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,\n IS_VARLEN: tl.constexpr, INTERLEAVED: tl.constexpr, CONJUGATE: tl.constexpr, BLOCK_M: tl.constexpr,\n):\n pid_m = tl.program_id(axis=0)\n pid_batch = tl.program_id(axis=1)\n pid_head = tl.program_id(axis=2)\n rotary_dim_half = rotary_dim // 2\n\n if not IS_VARLEN:\n X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads\n OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads\n else:\n start_idx = tl.load(CU_SEQLENS + pid_batch)\n seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx\n X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads\n OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads\n\n if pid_m * BLOCK_M >= seqlen:\n return\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n if not IS_SEQLEN_OFFSETS_TENSOR:\n rm_cs = rm + SEQLEN_OFFSETS\n else:\n rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)\n rk = tl.arange(0, BLOCK_K)\n rk_half = tl.arange(0, BLOCK_K // 2)\n\n if not INTERLEAVED:\n X = X + (rm[:, None] * stride_x_seqlen + rk_half[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_half[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x1 = tl.load(X + rotary_dim_half * stride_x_headdim, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n o0 = x0 * cos - x1 * sin\n o1 = x0 * sin + x1 * cos\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk_half[None, :] * stride_out_headdim)\n tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n tl.store(OUT + rotary_dim_half * stride_out_headdim, o1, mask=(rm[:, None] < seqlen) & (rk_half[None, :] < rotary_dim_half))\n else:\n rk_swap = rk + ((rk + 1) % 2) * 2 - 1\n rk_repeat = tl.arange(0, BLOCK_K) // 2\n X0 = X + (rm[:, None] * stride_x_seqlen + rk[None, :] * stride_x_headdim)\n X1 = X + (rm[:, None] * stride_x_seqlen + rk_swap[None, :] * stride_x_headdim)\n COS = COS + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk_repeat[None, :])\n cos = tl.load(COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=1.0).to(tl.float32)\n sin = tl.load(SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk_repeat[None, :] < rotary_dim_half), other=0.0).to(tl.float32)\n x0 = tl.load(X0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim), other=0.0).to(tl.float32)\n x1 = tl.load(X1, mask=(rm[:, None] < seqlen) & (rk_swap[None, :] < rotary_dim), other=0.0).to(tl.float32)\n if CONJUGATE:\n sin = -sin\n x0_cos = x0 * cos\n x1_sin = x1 * sin\n out = tl.where(rk[None, :] % 2 == 0, x0_cos - x1_sin, x0_cos + x1_sin)\n OUT = OUT + (rm[:, None] * stride_out_seqlen + rk[None, :] * stride_out_headdim)\n tl.store(OUT, out, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim))\n\ndef apply_rotary(\n x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, seqlen_offsets: Union[int, torch.Tensor] = 0,\n cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None,\n interleaved=False, inplace=False, conjugate=False\n) -> torch.Tensor:\n is_varlen = cu_seqlens is not None\n if not is_varlen:\n batch, seqlen, nheads, headdim = x.shape\n else:\n total_seqlen, nheads, headdim = x.shape\n batch = cu_seqlens.shape[0] - 1\n seqlen = max_seqlen\n seqlen_ro, rotary_dim = cos.shape\n rotary_dim *= 2\n\n cos, sin = cos.contiguous(), sin.contiguous()\n if isinstance(seqlen_offsets, torch.Tensor):\n seqlen_offsets = seqlen_offsets.contiguous()\n else:\n seqlen_offsets += seqlen\n\n output = torch.empty_like(x) if not inplace else x\n if rotary_dim < headdim and not inplace:\n output[..., rotary_dim:].copy_(x[..., rotary_dim:])\n\n BLOCK_K = 32 if rotary_dim <= 32 else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))\n grid = lambda META: (triton.cdiv(seqlen, META[\"BLOCK_M\"]), batch, nheads)\n BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)\n\n with torch.cuda.device(x.device.index):\n rotary_kernel[grid](\n output, x, cos, sin, cu_seqlens, seqlen_offsets, seqlen, nheads, rotary_dim,\n seqlen_ro, seqlen // 128, output.stride(0) if not is_varlen else 0, output.stride(-3),\n output.stride(-2), output.stride(-1), x.stride(0) if not is_varlen else 0,\n x.stride(-3), x.stride(-2), x.stride(-1), BLOCK_K,\n isinstance(seqlen_offsets, torch.Tensor), is_varlen, interleaved, conjugate, BLOCK_M\n )\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-3b4061622f5e48c4b8111703e07eaebb", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton JIT kernel `_fwd_kernel_apply_penalty` is designed to adjust the logits of a language model based on three types of penalties: presence, frequency, and repetition. These adjustments help in controlling the model's output behavior concerning repetitive and previously seen tokens. The kernel operates per batch (`cur_batch`) and involves the following steps: it first loads the penalty coefficients for the current batch. It then determines the range of token indices for the batch using `p_cumsum_seq_len`. Using these indices, it loads token IDs (`batch_ids`) and their corresponding counts (`batch_ids_count`) from `p_token_ids` and `p_token_counts`, respectively. For each token, it adjusts the logit based on repetition, then frequency, and finally presence penalties. The adjusted logits are stored back into the `Logits` tensor. The function `apply_penalty` is a Python wrapper that ensures `Logits` is contiguous and calculates the appropriate block size (`BLOCK`) using Triton's `next_power_of_2` function, ensuring it adheres to a minimum size for efficiency. It then invokes the kernel for each batch dimension (`Logits.shape[0]`), passing necessary parameters like penalties, token IDs, counts, cumulative sequence lengths, and strides of the logits tensor. The number of warps is fixed at 8 to balance performance and resource usage.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_apply_penalty(\n Logits, presence_penalty, freqency_penalty, repetition_penalty,\n p_token_ids, p_token_counts, p_cumsum_seq_len, \n stride_logit_b, stride_logit_s,\n BLOCK_P: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_freqency = tl.load(freqency_penalty + cur_batch)\n cur_presence = tl.load(presence_penalty + cur_batch)\n cur_repetition = tl.load(repetition_penalty + cur_batch)\n\n cur_batch_start_index = tl.load(p_cumsum_seq_len + cur_batch)\n cur_batch_end_index = tl.load(p_cumsum_seq_len + cur_batch + 1)\n\n cur_batch_id_offset = cur_batch_start_index + tl.arange(0, BLOCK_P)\n batch_ids = tl.load(p_token_ids + cur_batch_id_offset, mask=cur_batch_id_offset 0, cur_logits / cur_repetition, cur_logits * cur_repetition)\n freq_logits = rep_logits - batch_ids_count * cur_freqency\n pre_logits = freq_logits - cur_presence\n output_ptr = Logits + cur_batch * stride_logit_b + batch_ids\n tl.store(output_ptr, pre_logits, mask=cur_batch_id_offset 128 else 32 \n BV = 64 if BK <= 64 else BV\n BC = 16 if BK > 128 else 32 \n BC = 64 if BK <= 64 else BC\n BC = min(BT, BC)\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n assert NK == 1, 'NK > 1 is not supported because it involves time-consuming synchronization'\n\n h = k.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n v_new = torch.empty_like(u)\n chunk_delta_rule_fwd_kernel_h[grid](\n k, u, w, v_new, h, initial_state, final_state,\n k.stride(1), k.stride(2), k.stride(3),\n u.stride(1), u.stride(2), u.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BC=BC, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None,\n )\n return h, v_new\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e7e7e1eaed1a49e3a5f4442901f8dc02", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel `matmul_kernel` to perform matrix multiplication C = A x B, where A has shape (M, K) and B has shape (K, N), resulting in matrix C with shape (M, N). The kernel processes blocks of A and B defined by `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, and `BLOCK_SIZE_K`, with parallel execution managed using a grid of program IDs. The kernel computes a block of C by iteratively loading blocks of A and B, performing dot products, and accumulating the results in the `accumulator`. Optionally, a leaky ReLU activation is applied to the `accumulator`. The final block result is stored in C using Triton's memory operations, with handling for matrix boundaries via masks. The `matmul` function serves as a high-level interface, ensuring input validity, preparing execution parameters, and invoking the kernel. It accepts optional activation, setting `ACTIVATION` based on user input.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr,\n ACTIVATION: tl.constexpr\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n # Create pointers for the first blocks of A and B\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n # Compute a block of the C matrix\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n \n # Apply activation function if specified\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n \n c = accumulator.to(tl.float16)\n\n # Write back the block of the output matrix C with masks\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']), )\n matmul_kernel[grid](\n a, b, c,\n M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n BLOCK_SIZE_M=32, BLOCK_SIZE_N=32, BLOCK_SIZE_K=32,\n GROUP_SIZE_M=4,\n ACTIVATION=activation\n )\n return c\n\n\n\n" + ] + }, + { + "id": "openseek-8-127b6e23c995470f9ff213462e8c98f5", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton JIT kernel `_fwd_kernel` is designed to compute the forward pass of a self-attention mechanism with support for both regular and causal attention (controlled by the `IS_CAUSAL` parameter). It takes in matrices Q, K, and V, which represent the query, key, and value tensors of the attention mechanism. The kernel implements a block-wise approach to handle large input sizes efficiently, processing data in chunks defined by constants `BLOCK_M`, `BLOCK_N`, and `BLOCK_DMODEL`.\n\n The kernel computes a scaled dot-product between the query and key matrices, applying a softmax operation to generate attention scores. It then applies these scores to the value matrix to produce the output. The `sm_scale` parameter is used to scale the logits before applying the softmax, which is a common practice to stabilize gradients. \n\n If `IS_CAUSAL` is true, the kernel masks future positions to maintain causality in autoregressive models, ensuring that each position only attends to previous positions.\n\n The kernel supports mixed-precision computations, leveraging 8-bit floats (FP8) when keys and values are stored as `torch.int8`, controlled by `USE_FP8`. It rematerializes offsets to manage Triton's register pressure effectively.\n\n The wrapper function `triton_fa` is responsible for invoking `_fwd_kernel`, ensuring the correct configurations and input data formats. It checks data type consistency and adjusts grid and block sizes based on the input dimensions, utilizing a dynamic grid size to cover the full input. The grid is determined by `triton.cdiv(m_size, BLOCK)` for the rows, and `head_size * batch` for the second dimension, to handle all heads in parallel. The kernel uses either 4 or 8 warps based on the block size (`Lk`), and stages are managed with `num_stages=2` to balance register usage and computational latency.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nimport math\n\n@triton.jit\ndef _fwd_kernel(\n Q, K, V, sm_scale,\n Out,\n stride_qz, stride_qh, stride_qm, stride_qk,\n stride_kz, stride_kh, stride_kn, stride_kk,\n stride_vz, stride_vh, stride_vk, stride_vn,\n stride_oz, stride_oh, stride_om, stride_on,\n N_HEAD, H, N_CTX,\n start_position, # <- ADDED\n IS_CAUSAL: tl.constexpr, # <- ADDED\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n USE_FP8: tl.constexpr,\n):\n start_m = tl.program_id(0)\n\n head_idx = tl.program_id(1)\n batch_id = head_idx // N_HEAD\n off_hz = head_idx % N_HEAD\n\n # initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_q = batch_id * stride_qz + off_hz * stride_qh + offs_m[:, None] * stride_qm + offs_d[None, :] * stride_qk # <- stride fixed\n off_k = batch_id * stride_kz + off_hz * stride_kh + offs_n[:, None] * stride_kn + offs_d[None, :] * stride_kk # <- stride fixed\n off_v = batch_id * stride_vz + off_hz * stride_vh + offs_n[:, None] * stride_vk + offs_d[None, :] * stride_vn # <- stride fixed\n # Initialize pointers to Q, K, V\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n # initialize pointer to m and l\n m_prev = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_prev = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n # load q: it will stay in SRAM throughout\n q = tl.load(q_ptrs, offs_m[:, None] < H, other=0.0)\n # loop over k, v and update accumulator\n block_n_end = N_CTX # <- ADDED (including the IF)\n if IS_CAUSAL:\n # in causal mode, we expect that BLOCK_M_SIZE == BLOCK_N_SIZE\n # autotune will prune shapes not matching this rule\n block_n_end = (start_m + 1) * BLOCK_N + start_position\n for start_n in range(0, block_n_end, BLOCK_N):\n block_n_offs = start_n + offs_n # <- ADDED\n # -- compute qk ----\n k = tl.load(k_ptrs, block_n_offs[:, None] < N_CTX, 0.)\n if USE_FP8:\n k = k.to(tl.float8e5, bitcast=True)\n k = k.to(tl.float16)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, tl.trans(k))\n qk = tl.where(offs_n[None, :] < N_CTX, qk, float(\"-inf\")) # <- ADDED\n qk *= sm_scale\n if IS_CAUSAL: # <- ADDED\n qk = tl.where(offs_m[:, None] >= (block_n_offs[None, :] + start_position), qk, float(\"-inf\"))\n \n # compute new m\n m_curr = tl.maximum(tl.max(qk, 1), m_prev)\n # correct old l\n l_prev *= tl.exp(m_prev - m_curr)\n # attention weights\n p = tl.exp(qk - m_curr[:, None])\n l_curr = tl.sum(p, 1) + l_prev\n # rescale operands of matmuls\n l_rcp = 1. / l_curr\n p *= l_rcp[:, None]\n acc *= (l_prev * l_rcp)[:, None]\n # update acc\n p = p.to(Q.dtype.element_ty)\n v = tl.load(v_ptrs, block_n_offs[:, None] < N_CTX, 0.0)\n if USE_FP8:\n v = v.to(tl.float8e5, bitcast=True)\n v = v.to(tl.float16)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_prev = l_curr\n m_prev = m_curr\n # update pointers\n k_ptrs += BLOCK_N * stride_kn\n v_ptrs += BLOCK_N * stride_vk\n # rematerialize offsets to save registers\n start_m = tl.program_id(0)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n # initialize pointers to output\n offs_d = tl.arange(0, BLOCK_DMODEL)\n off_o = batch_id * stride_oz + off_hz * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :] * stride_on\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, offs_m[:, None] < H)\n\n\ndef triton_fa(q, k, v, sm_scale, is_causal, start_position):\n assert q.dtype == torch.float16\n assert k.dtype == v.dtype and k.dtype in [torch.float16, torch.int8]\n\n BLOCK = 64\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n o = torch.empty_like(q)\n num_warps = 4 if Lk <= 64 else 8\n batch, head_size, m_size, dhead = q.size()\n grid = (triton.cdiv(m_size, BLOCK), head_size * batch)\n n_size = k.size(2)\n _fwd_kernel[grid](\n q, k, v, sm_scale,\n o,\n q.stride(0), q.stride(1), q.stride(2), q.stride(3),\n k.stride(0), k.stride(1), k.stride(2), k.stride(3),\n v.stride(0), v.stride(1), v.stride(2), v.stride(3),\n o.stride(0), o.stride(1), o.stride(2), o.stride(3),\n head_size, m_size, n_size,\n start_position=start_position,\n IS_CAUSAL=is_causal,\n BLOCK_M=BLOCK,\n BLOCK_N=BLOCK,\n BLOCK_DMODEL=Lk,\n USE_FP8=k.dtype == torch.int8, # USE_FP8\n num_warps=num_warps,\n num_stages=2,\n )\n\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-ad1cb694a1f4489da9f77194250d67fd", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel function `_fwd_kernel_destindex_copy_kv` is a GPU-accelerated operation designed to perform indexed copying of tensor data from `K` to `Out`, directed by indices specified in `Dest_loc`. The function is invoked over a grid of size equal to `seq_len`, derived from the size of `Dest_loc`. The main computation leverages Triton's parallel execution model by determining the current index using `tl.program_id(0)`. It constructs pointers for both source (`k_ptrs`) and destination (`o_ptrs`) using tensor strides for batch size, head, and depth dimensions. The load operation is masked by `offs_h` to ensure it stays within the valid `head_num` range. The wrapper function `destindex_copy_kv` sets kernel parameters such as strides and grid configuration, validating that the tensor shapes of `K` and `Out` align. It computes `BLOCK_HEAD` to be the next power of 2 of `head_num` for performance optimization, using Triton's utility function. The kernel is executed with a single warp, suitable for the problem size, and encapsulates the entire logic for transferring indexed data based on head and depth dimensions.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_destindex_copy_kv(\n K, Dest_loc,\n Out,\n stride_k_bs, stride_k_h, stride_k_d,\n stride_o_bs, stride_o_h, stride_o_d,\n head_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_HEAD: tl.constexpr\n):\n cur_index = tl.program_id(0)\n offs_h = tl.arange(0, BLOCK_HEAD)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n\n dest_index = tl.load(Dest_loc + cur_index)\n\n k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]\n o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]\n\n k = tl.load(k_ptrs, mask=offs_h[:, None] < head_num, other=0.0)\n tl.store(o_ptrs, k, mask=offs_h[:, None] < head_num)\n return\n\n\n@torch.no_grad()\ndef destindex_copy_kv(K, DestLoc, Out):\n seq_len = DestLoc.shape[0]\n head_num = K.shape[1]\n head_dim = K.shape[2]\n assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2]\n BLOCK_HEAD = triton.next_power_of_2(head_num)\n grid = (seq_len,)\n num_warps = 1\n\n _fwd_kernel_destindex_copy_kv[grid](\n K, DestLoc, Out,\n K.stride(0), K.stride(1), K.stride(2),\n Out.stride(0), Out.stride(1), Out.stride(2),\n head_num,\n BLOCK_DMODEL=head_dim,\n BLOCK_HEAD=BLOCK_HEAD,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n" + ] + }, + { + "id": "openseek-8-b0b6e723724d439db31f73244d345371", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The kernel `_fwd_kernel` is part of a Triton operator for computing scaled dot-product attention. It accepts tensors Q (queries), K (keys), V (values), and additional parameters like `sm_scale`, `B_Start_Loc`, and `B_Seqlen` to handle batches of varying sequence lengths. The `context_attention_fwd` function initializes and dispatches this kernel, with arguments including strides for Q, K, V, and Out, which help index these multidimensional tensors correctly within the kernel. Inside `_fwd_kernel`, a blocking strategy defined by constants BLOCK_M, BLOCK_DMODEL, and BLOCK_N is used to partition the computation, enhancing parallelism and memory efficiency. The kernel calculates attention scores by dot-product of queries and keys, scales them, and applies a softmax operation. It accumulates results to form the output by weighting value vectors V accordingly. The `context_attention_fwd` computes the softmax scaling factor `sm_scale` as the inverse square root of query depth, ensuring numeric stability. The grid for launching the kernel is determined by the batch size, number of heads, and input length, and it adapts the number of warps based on key length Lk.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\nif triton.__version__ >= \"2.1.0\":\n @triton.jit\n def _fwd_kernel(\n Q, K, V, sm_scale, B_Start_Loc, B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度\n Out,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n stride_vbs, stride_vh, stride_vd,\n stride_obs, stride_oh, stride_od,\n BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr,\n ):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_m = tl.program_id(2)\n\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n block_start_loc = BLOCK_M * start_m\n\n # initialize offsets\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_DMODEL)\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n off_q = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] * stride_qd\n off_k = offs_n[None, :] * stride_kbs + cur_head * stride_kh + offs_d[:, None] * stride_kd\n off_v = offs_n[:, None] * stride_vbs + cur_head * stride_vh + offs_d[None, :] * stride_vd\n\n q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0)\n\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32)\n\n block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0)\n\n for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n # -- compute qk ----\n k = tl.load(k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs,\n mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, other=0.0)\n # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0)\n\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n qk *= sm_scale\n qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float(\"-inf\"))\n\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n # update acc\n v = tl.load(v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs,\n mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, other=0.0)\n\n p = p.to(v.dtype)\n acc += tl.dot(p, v)\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n # initialize pointers to output\n off_o = (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_d[None, :] * stride_od\n out_ptrs = Out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len)\n return\n\n @torch.no_grad()\n def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):\n BLOCK = 128\n # shape constraints\n Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]\n assert Lq == Lk and Lk == Lv\n assert Lk in {16, 32, 64, 128}\n\n sm_scale = 1.0 / (Lq**0.5) # 计算scale系数\n batch, head = b_seq_len.shape[0], q.shape[1]\n\n grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head,\n\n num_warps = 4 if Lk <= 64 else 8\n _fwd_kernel[grid](\n q, k, v, sm_scale, b_start_loc, b_seq_len,\n o,\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n v.stride(0), v.stride(1), v.stride(2),\n o.stride(0), o.stride(1), o.stride(2),\n BLOCK_M=BLOCK,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\nelse:\n raise Exception(\"error triton version!\")\n\n\n\n" + ] + }, + { + "id": "openseek-8-94c726f3275446ec93e38981e7382d54", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a Triton kernel named `kernel_function`, which processes input data using block-wise operations. \n The kernel takes pointers to input and output data (`x_ptr` and `output_ptr`), the total number of elements to process (`n_elements`), and a constant block size (`BLOCK_SIZE`). \n Inside the kernel, each program instance calculates its starting point (`block_start`) and creates an `offsets` tensor for element indexing. \n A mask ensures operations only occur on valid indices within the input bounds. The kernel loads data from `x_ptr`, computes the sine using `tl.math.sin`, and stores the result in `output_ptr`. \n The `call_kernel` function prepares to execute the kernel by calculating the total number of elements (`n_elements`) and creates an output tensor. \n It defines a grid configuration function using lambda to handle thread block calculations based on `BLOCK_SIZE`, ensuring the entire input is processed. \n The kernel is then launched with the grid configuration, input, output, and element count.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# Kernel function using Triton\n@triton.jit\ndef kernel_function(x_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):\n # x_ptr: pointer to input data\n # output_ptr: pointer to output data\n # n_elements: number of elements to process\n # BLOCK_SIZE: block size for Triton kernel\n pid = tl.program_id(0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n output = tl.math.sin(x)\n tl.store(output_ptr + offsets, output, mask=mask)\n\n# Function to call the Triton kernel\ndef call_kernel(x):\n # x: input tensor\n n_elements = x.numel()\n output = torch.empty_like(x)\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n kernel_function[grid](x, output, n_elements, BLOCK_SIZE=1024)\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-8ef2fb4a1beb4604a7f3063d6c2d8350", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code defines a Triton-based cross-entropy loss function with optional label smoothing and scaled logits handling. The main components are two kernel functions: `cross_entropy_fwd_kernel` and `cross_entropy_bwd_kernel`, which are executed via Triton to perform computations efficiently on GPU.\n\n The `cross_entropy_fwd_kernel` computes the forward pass of the cross-entropy loss. It processes each row of logits and calculates the log-sum-exp (lse) and loss for each label, considering smoothing, logit scaling, and specific class handling for distributed tensor parallelism. The kernel handles cases where labels are ignored and adjusts loss calculation if the label index does not match the current block of classes. It uses triton's utilities to manage block operations and loop unrolling efficiently.\n\n The `cross_entropy_bwd_kernel` calculates the backward pass, deriving the gradient of the loss with respect to logits. It applies the chain rule considering the effect of label smoothing and optional scaling factors. The computed probabilities are adjusted based on whether smoothing is applied or not, and then used to compute the gradient contribution for each logit.\n\n The `CrossEntropyLoss` class encapsulates the forward and backward pass logic. The `forward` method prepares data structures for loss and lse, manages distributed computation conditions, and executes the forward kernel. It adjusts results for cases where the workload is split across processes or blocks. The computed losses are also adjusted to zero out contributions from ignored labels.\n\n The `backward` method in `CrossEntropyLoss` class computes gradients for the logits by calling the backward kernel. It supports an optional in-place backward for memory efficiency.\n\n Finally, the `cross_entropy_loss` function is a high-level entry point that applies the `CrossEntropyLoss` operation. It takes parameters such as logits, labels, and smoothing factors, and returns the computed losses and z-losses. The function is designed to handle both single and multi-process scenarios seamlessly. \n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_fwd_kernel(\n loss_ptr, # data ptrs\n lse_ptr,\n z_loss_ptr,\n logits_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n n_rows,\n logits_row_stride, # strides\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n SPLIT: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n max_logits = tl.max(logits, 0)\n if HAS_SMOOTHING:\n sum_logits = tl.sum(tl.where(col_offsets < n_cols, logits, 0.0), 0)\n lse = tl.log(tl.sum(tl.exp(logits - max_logits), 0)) + max_logits\n tl.store(lse_ptr + col_block_idx * n_rows + row_idx, lse)\n if label_idx == ignored_index:\n loss = 0.0\n z_loss = 0.0\n else:\n label_idx -= class_start_idx\n if label_idx >= col_block_idx * BLOCK_SIZE and label_idx < min(\n n_cols, (col_block_idx + 1) * BLOCK_SIZE\n ):\n logits_label = tl.load(logits_ptr + label_idx) * logit_scale\n if HAS_SMOOTHING:\n loss = (\n (lse if not SPLIT else 0.0)\n - smoothing * sum_logits / total_classes\n - (1 - smoothing) * logits_label\n )\n else:\n loss = (lse if not SPLIT else 0.0) - logits_label\n else:\n if HAS_SMOOTHING:\n loss = smoothing * ((lse if not SPLIT else 0.0) - sum_logits / total_classes)\n else:\n loss = 0.0\n if not SPLIT:\n z_loss = lse_square_scale * lse * lse\n loss += z_loss\n else:\n z_loss = 0.0\n tl.store(loss_ptr + col_block_idx * n_rows + row_idx, loss)\n if not SPLIT:\n tl.store(z_loss_ptr + col_block_idx * n_rows + row_idx, z_loss)\n\n@triton.heuristics(\n {\n \"HAS_SMOOTHING\": lambda args: args[\"smoothing\"] > 0.0,\n }\n)\n@triton.jit\ndef cross_entropy_bwd_kernel(\n dlogits_ptr, # data ptrs\n dloss_ptr,\n logits_ptr,\n lse_ptr,\n labels_ptr,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx, # Useful for tensor parallel when each rank only has a subset of classes\n n_cols, # shapes\n logits_row_stride, # strides\n dlogits_row_stride,\n dloss_row_stride,\n BLOCK_SIZE: tl.constexpr,\n HAS_SMOOTHING: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n col_block_idx = tl.program_id(1)\n logits_ptr = logits_ptr + row_idx * logits_row_stride.to(tl.int64)\n dlogits_ptr = dlogits_ptr + row_idx * dlogits_row_stride.to(tl.int64)\n col_offsets = col_block_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n label_idx = tl.load(labels_ptr + row_idx)\n if label_idx != ignored_index:\n dloss = tl.load(dloss_ptr + row_idx * dloss_row_stride)\n else:\n dloss = 0.0\n logits = tl.load(logits_ptr + col_offsets, mask=col_offsets < n_cols, other=-float(\"inf\")).to(\n tl.float32\n ) * logit_scale\n lse = tl.load(lse_ptr + row_idx)\n probs = tl.exp(logits - lse)\n probs += 2.0 * lse_square_scale * lse * probs\n label_idx -= class_start_idx\n if HAS_SMOOTHING:\n smooth_positive = 1.0 - smoothing\n smooth_negative = smoothing / total_classes\n probs = tl.where(col_offsets == label_idx, probs - (1 - smoothing), probs) - smooth_negative\n else:\n probs = tl.where(col_offsets == label_idx, probs - 1.0, probs)\n tl.store(dlogits_ptr + col_offsets, (dloss * logit_scale) * probs, mask=col_offsets < n_cols)\n\nclass CrossEntropyLoss(torch.autograd.Function):\n\n @staticmethod\n def forward(\n ctx,\n logits,\n labels,\n smoothing=0.0,\n logit_scale=1.0,\n lse_square_scale=0.0,\n ignored_index=-100,\n inplace_backward=False,\n process_group=None,\n ):\n n_rows, n_cols = logits.shape\n assert labels.shape == (n_rows,)\n world_size = 1 if process_group is None else torch.distributed.get_world_size(process_group)\n total_classes = world_size * n_cols\n rank = 0 if process_group is None else torch.distributed.get_rank(process_group)\n class_start_idx = rank * n_cols\n\n if logits.stride(-1) != 1:\n logits = logits.contiguous()\n MAX_BLOCK_SIZE = 64 * 1024\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), MAX_BLOCK_SIZE)\n num_warps = (\n 4\n if BLOCK_SIZE < 2048\n else (8 if BLOCK_SIZE < 8192 else (16 if BLOCK_SIZE < 128 * 1024 else 32))\n )\n split = world_size > 1 or n_cols > MAX_BLOCK_SIZE\n n_splits = (n_cols + BLOCK_SIZE - 1) // BLOCK_SIZE\n loss_shape = (n_splits, n_rows) if n_splits > 1 else (n_rows,)\n losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n lse = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n z_losses = torch.empty(*loss_shape, dtype=torch.float, device=logits.device)\n with torch.cuda.device(logits.device.index):\n cross_entropy_fwd_kernel[(n_rows, n_splits)](\n losses, # data ptrs\n lse,\n z_losses,\n logits,\n labels,\n smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n total_classes,\n class_start_idx,\n n_cols, # shapes\n n_rows,\n logits.stride(0), # strides\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n SPLIT=split,\n )\n\n if split:\n if n_splits > 1:\n lse = torch.logsumexp(lse, dim=0)\n losses = losses.sum(dim=0)\n if world_size > 1:\n lse_allgather = torch.empty(world_size, n_rows, dtype=lse.dtype, device=lse.device)\n torch.distributed.all_gather_into_tensor(lse_allgather, lse, group=process_group)\n handle_losses = torch.distributed.all_reduce(\n losses, op=torch.distributed.ReduceOp.SUM, group=process_group, async_op=True\n )\n lse = torch.logsumexp(lse_allgather, dim=0)\n handle_losses.wait()\n losses += lse\n if lse_square_scale != 0.0:\n z_losses = lse_square_scale * lse.square()\n z_losses.masked_fill_(labels == ignored_index, 0.0)\n losses += z_losses\n else:\n z_losses = torch.zeros_like(losses)\n losses.masked_fill_(labels == ignored_index, 0.0)\n\n ctx.save_for_backward(logits, lse, labels)\n ctx.mark_non_differentiable(z_losses)\n ctx.smoothing = smoothing\n ctx.logit_scale = logit_scale\n ctx.lse_square_scale = lse_square_scale\n ctx.ignored_index = ignored_index\n ctx.total_classes = total_classes\n ctx.class_start_idx = class_start_idx\n ctx.inplace_backward = inplace_backward\n\n return losses, z_losses\n\n @staticmethod\n def backward(ctx, grad_losses, grad_z_losses):\n del grad_z_losses # z_losses are only for logging.\n\n logits, lse, labels = ctx.saved_tensors\n dlogits = logits if ctx.inplace_backward else torch.empty_like(logits)\n n_rows, n_cols = logits.shape\n BLOCK_SIZE = min(triton.next_power_of_2(n_cols), 4 * 1024)\n num_warps = 4 if BLOCK_SIZE < 2048 else (8 if BLOCK_SIZE < 8192 else 16)\n grid = lambda META: (n_rows, triton.cdiv(n_cols, META[\"BLOCK_SIZE\"])) # noqa\n with torch.cuda.device(logits.device.index):\n cross_entropy_bwd_kernel[grid](\n dlogits, # data ptrs\n grad_losses,\n logits,\n lse,\n labels,\n ctx.smoothing,\n ctx.logit_scale,\n ctx.lse_square_scale,\n ctx.ignored_index,\n ctx.total_classes,\n ctx.class_start_idx,\n n_cols, # shapes\n logits.stride(0), # strides\n dlogits.stride(0),\n grad_losses.stride(0),\n BLOCK_SIZE=BLOCK_SIZE, # constants\n num_warps=num_warps,\n )\n return dlogits, None, None, None, None, None, None, None, None\n\ndef cross_entropy_loss(\n logits: torch.Tensor,\n labels: torch.Tensor,\n label_smoothing: float = 0.0,\n logit_scale: float = 1.0,\n lse_square_scale: float = 0.0,\n ignored_index=-100,\n inplace_backward: bool = False,\n process_group=None,\n) -> Tuple[torch.Tensor, torch.Tensor]:\n return CrossEntropyLoss.apply(\n logits,\n labels,\n label_smoothing,\n logit_scale,\n lse_square_scale,\n ignored_index,\n inplace_backward,\n process_group,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-6d593bdc68714b0db046e2653c220f24", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code provides an element-wise addition kernel named `add_kernel` and a corresponding wrapper function `add_wrapper`. The `add_kernel` is a decorated function with `@triton.jit`, enabling Just-In-Time compilation for optimized GPU execution. It receives pointers to two input tensors (`in_ptr0`, `in_ptr1`), an output tensor (`out_ptr`), the number of elements to process (`n_elements`), and a `BLOCK_SIZE`. Within the kernel, `pid` identifies the program's ID for the block dimension. The computation starts by calculating `block_start` as the base index for each block using `pid * BLOCK_SIZE`, and `offsets` as a range from `block_start` up to `BLOCK_SIZE`. A mask is applied using `offsets < n_elements` to prevent out-of-bound memory access, loading tensor values with `tl.load` for both input pointers. The values are summed and stored back in `out_ptr` using `tl.store`, still respecting the mask.\n\n The `add_wrapper` function orchestrates the setup and execution of the kernel. It initializes the output tensor `out` using `torch.zeros_like(x)` for similar shape and dtype. The function determines the total number of elements using `x.numel()` and computes the number of blocks needed with `num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE`. Finally, `add_kernel[(num_blocks,)](x, y, out, n_elements, BLOCK_SIZE)` launches the kernel, passing the input and output pointers, number of elements, and block size. The final result is returned in `out`.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef add_kernel(\n in_ptr0,\n in_ptr1,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n y = tl.load(in_ptr1 + offsets, mask=mask)\n output = x + y\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef add_wrapper(x, y):\n out = torch.zeros_like(x)\n \n BLOCK_SIZE = 4\n n_elements = x.numel()\n\n # Calculate the number of blocks needed\n num_blocks = (n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE\n\n # Launch the kernel\n add_kernel[(num_blocks,)](x, y, out, n_elements, BLOCK_SIZE)\n\n return out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1c4718cfbf7d4d8c99a4ea7f3eab3ce2", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The function `block_sparse_attention_kernel` is a Triton GPU kernel designed to perform block-sparse attention for transformer models in the context of prompt processing. It assumes a past sequence length of zero and right-padding only. The kernel takes as inputs the query tensor Q, key tensor K, value tensor V, and layout data in CSR (Compressed Sparse Row) format, specifically `layout_csr_row_indices` and `layout_csr_col_indices`. The kernel executes in a grid determined by the number of query blocks and the product of batch size and number of heads. It utilizes offsets to correctly index into these tensors, leveraging the block sizes BLOCK_M, BLOCK_N, and BLOCK_D. Each block's indices in the query (Q), key (K), and value (V) tensors are calculated, and memory is managed using these indices to perform the necessary loads. The kernel performs key calculations including computing the dot product between query and key blocks (qk), scaling by `softmax_scale`, and adjusting using softmax to prevent overflow. The results are accumulated over the required dimensions to produce the final attention-weighted values which are stored in the output tensor `out`. Support for multi-head attention is built-in, with special handling for head grouping based on `num_heads` and `num_kv_heads`. If multiple data blocks are present (NUM_D_BLOCKS > 1), additional accumulation is performed to account for the segmented data.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\n\n# This kernel is for prompt only and assume that past sequence length is 0. It only supports right padding.\n@triton.jit\ndef block_sparse_attention_kernel(\n out, # output [B, H, M, D]. Note that B is batch_size, H is num_heads, M is q_seq_len, and D is head_size\n Q, # query [B, H, M, D]\n K, # key [B, H_kv, N, D]. Note that N is max_seq_len for kv cache, H_kv is num_kv_heads\n V, # value [B, H_kv, N, D]\n layout_csr_row_indices, # block mask CSR format. Shape is [L, num_rows + 1] where num_rows = max_seq_len / BLOCK_M\n layout_csr_col_indices, # block mask CSR format. Shape is [L, num_rows * num_cols] where num_cols = max_seq_len / BLOCK_N\n layout_csr_row_stride_h, # stride per head for csr_row_indices, i.e. num_rows + 1\n layout_csr_col_stride_h, # stride per head for csr_col_indices, i.e. num_rows * num_cols\n num_layout, # number of sparse layout (L)\n softmax_scale,\n stride_qb,\n stride_qh,\n stride_qm,\n stride_kb,\n stride_kh,\n stride_kn,\n stride_vb,\n stride_vh,\n stride_vn,\n stride_ob,\n stride_oh,\n stride_om,\n num_heads,\n num_kv_heads,\n total_seq_len, # Total sequence length including past sequence length and query sequence length.\n BLOCK_M: tl.constexpr, # block size for q_seq_len\n EVEN_M: tl.constexpr, # whether q_seq_len % BLOCK_M == 0\n BLOCK_N: tl.constexpr, # block size for k_seq_len\n EVEN_N: tl.constexpr, # whether k_seq_len % BLOCK_N == 0\n BLOCK_D: tl.constexpr, # block size for D\n NUM_D_BLOCKS: tl.constexpr, # number of data blocks = D / BLOCK_D\n):\n tl.static_print(f\"{BLOCK_M=} {BLOCK_N=} {BLOCK_D=} {EVEN_M=} {EVEN_N=} {NUM_D_BLOCKS=}\")\n\n # Past sequence length is 0 since this kernel is for prompt only.\n q_seq_len = total_seq_len\n\n # Grid is [CDiv(q_seq_len, BLOCK_M), batch_size * num_heads]\n start_m = tl.program_id(0)\n off_bh = tl.program_id(1)\n\n off_h = off_bh % num_heads\n off_b = off_bh // num_heads\n\n # For group query attention, map the query head index to the corresponding one for key and value.\n head_groups = num_heads // num_kv_heads\n off_h_kv = off_h // head_groups\n\n Q += off_b * stride_qb + off_h * stride_qh\n K += off_b * stride_kb + off_h_kv * stride_kh\n V += off_b * stride_vb + off_h_kv * stride_vh\n\n # Initialize offsets\n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_d = tl.arange(0, BLOCK_D)\n off_q = offs_m[:, None] * stride_qm + offs_d[None, :] # [BLOCK_M, BLOCK_D]\n off_k = offs_n[None, :] * stride_kn + offs_d[:, None] # [BLOCK_D, BLOCK_N]\n off_v = offs_n[:, None] * stride_vn + offs_d[None, :] # [BLOCK_N, BLOCK_D]\n\n # Initialize pointers to query, key, value\n q_ptrs = Q + off_q\n k_ptrs = K + off_k\n v_ptrs = V + off_v\n\n # Initialize pointer to m and l\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32)\n acc = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)\n if NUM_D_BLOCKS >= 2:\n acc2 = tl.zeros([BLOCK_M, BLOCK_D], dtype=tl.float32)\n\n # Load q: it will stay in SRAM throughout\n if EVEN_M:\n q = tl.load(q_ptrs)\n if NUM_D_BLOCKS >= 2:\n q2 = tl.load(q_ptrs + BLOCK_D)\n else:\n q = tl.load(q_ptrs, mask=offs_m[:, None] < q_seq_len)\n if NUM_D_BLOCKS >= 2:\n q2 = tl.load(q_ptrs + BLOCK_D, mask=offs_m[:, None] < q_seq_len)\n\n layout_h = off_h % num_layout\n\n # This assumes that past sequence length is 0, otherwise need + (past_seq_len + 1) // BLOCK_M.\n layout_ptr = layout_csr_row_indices + layout_h * layout_csr_row_stride_h + start_m\n start_l = tl.load(layout_ptr).to(tl.int32)\n end_l = tl.load(layout_ptr + 1).to(tl.int32)\n\n # Loop over k, v and update accumulator\n for col_idx_idx in range(start_l, end_l):\n col_idx = tl.load(layout_csr_col_indices + layout_h * layout_csr_col_stride_h + col_idx_idx).to(tl.int32)\n start_n = col_idx * BLOCK_N\n # -- compute qk ----\n if EVEN_N:\n k = tl.load(k_ptrs + start_n * stride_kn)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn, mask=offs_n[None, :] + start_n < total_seq_len)\n qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)\n qk += tl.dot(q, k)\n\n if NUM_D_BLOCKS >= 2:\n if EVEN_N:\n k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D)\n else:\n k = tl.load(k_ptrs + start_n * stride_kn + BLOCK_D, mask=offs_n[None, :] + start_n < total_seq_len)\n qk += tl.dot(q2, k)\n\n qk *= softmax_scale\n\n # This assumes that past sequence length is 0, otherwise need offs_m[:, None] + past_seq_len >= ...\n qk += tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), 0, float(\"-inf\"))\n # -- compute m_ij, p, l_ij\n m_ij = tl.max(qk, 1)\n p = tl.exp(qk - m_ij[:, None])\n l_ij = tl.sum(p, 1)\n # -- update m_i and l_i\n m_i_new = tl.maximum(m_i, m_ij)\n alpha = tl.exp(m_i - m_i_new)\n beta = tl.exp(m_ij - m_i_new)\n l_i_new = alpha * l_i + beta * l_ij\n # -- update output accumulator --\n # scale p\n p_scale = beta / l_i_new\n p = p * p_scale[:, None]\n # scale acc\n acc_scale = l_i / l_i_new * alpha\n acc = acc * acc_scale[:, None]\n if NUM_D_BLOCKS >= 2:\n acc2 = acc2 * acc_scale[:, None]\n p = p.to(Q.dtype.element_ty)\n # update acc\n if EVEN_N:\n v = tl.load(v_ptrs + start_n * stride_vn)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn, mask=offs_n[:, None] + start_n < total_seq_len)\n acc += tl.dot(p, v)\n\n if NUM_D_BLOCKS >= 2:\n if EVEN_N:\n v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D)\n else:\n v = tl.load(v_ptrs + start_n * stride_vn + BLOCK_D, mask=offs_n[:, None] + start_n < total_seq_len)\n acc2 += tl.dot(p, v)\n\n # update m_i and l_i\n l_i = l_i_new\n m_i = m_i_new\n\n off_o = off_b * stride_ob + off_h * stride_oh + offs_m[:, None] * stride_om + offs_d[None, :]\n out_ptrs = out + off_o\n tl.store(out_ptrs, acc, mask=offs_m[:, None] < q_seq_len)\n if NUM_D_BLOCKS >= 2:\n tl.store(out_ptrs + BLOCK_D, acc2, mask=offs_m[:, None] < q_seq_len)\n\n# Define the wrapper function to call the Triton kernel\ndef block_sparse_attention(\n Q, K, V, layout_csr_row_indices, layout_csr_col_indices, layout_csr_row_stride_h, layout_csr_col_stride_h,\n num_layout, softmax_scale, num_heads, num_kv_heads, total_seq_len, BLOCK_M, EVEN_M, BLOCK_N, EVEN_N, BLOCK_D, NUM_D_BLOCKS\n):\n # Determine the grid size\n q_seq_len = total_seq_len\n grid = (triton.cdiv(q_seq_len, BLOCK_M), Q.shape[0] * num_heads)\n\n # Allocate output tensor\n out = torch.empty((Q.shape[0], num_heads, q_seq_len, Q.shape[-1]), device=Q.device, dtype=Q.dtype)\n\n # Call the Triton kernel\n block_sparse_attention_kernel[grid](\n out, Q, K, V, layout_csr_row_indices, layout_csr_col_indices, layout_csr_row_stride_h, layout_csr_col_stride_h,\n num_layout, softmax_scale, Q.stride(0), Q.stride(1), Q.stride(2), K.stride(0), K.stride(1), K.stride(2),\n V.stride(0), V.stride(1), V.stride(2), out.stride(0), out.stride(1), out.stride(2), num_heads, num_kv_heads,\n total_seq_len, BLOCK_M=BLOCK_M, EVEN_M=EVEN_M, BLOCK_N=BLOCK_N, EVEN_N=EVEN_N, BLOCK_D=BLOCK_D, NUM_D_BLOCKS=NUM_D_BLOCKS\n )\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-79bb069e46f448a995c3f73f769fa751", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator code showcases the implementation of efficient GPU kernels for performing integer matrix multiplication and scaled integer matrix multiplication using the Triton framework. The main constructs in the code are two kernel functions and two corresponding launch functions.\n\n 1. `@triton.jit def matmul_kernel_with_block_pointers(...)`: This kernel computes the matrix product C = A x B for matrices A (MxK) and B (KxN) with result C (MxN). It uses block pointers to access matrix elements in chunks defined by `BLOCK_M`, `BLOCK_N`, and `BLOCK_K`, promoting L2 cache reuse and minimizing global memory access latency. The kernel organizes the computation through program IDs, which are mapped to blocks of matrix C in a grouped ordering, enhancing data locality. The accumulation of results is done in an `accumulator` of type `tl.int32` for precision.\n\n 2. `@triton.jit def scaled_matmul_kernel_with_block_pointers(...)`: Extending the basic matrix multiplication, this kernel incorporates scaling of the output matrix by a scale matrix `scales1`. The approach is similar to the previous kernel, with the matrix multiplication results being scaled before being stored. The `EVEN_K` parameter allows optimization if K is a multiple of the block size.\n\n 3. `def int_matmul_kernel(a, b, c, config)`: This is a host function that prepares and launches the `matmul_kernel_with_block_pointers` kernel. It calculates the grid size needed for the kernel execution based on the dimensions of the matrices A, B, and C, then invokes the kernel with this configuration, passing the matrix data pointers and stride information alongside execution parameters from the `config` object.\n\n 4. `def int_scaled_matmul_kernel(a, b, scales1, c, config)`: Similar to `int_matmul_kernel`, this function sets up and launches the `scaled_matmul_kernel_with_block_pointers` kernel. It includes handling for the scale matrix `scales1`, adjusting kernel parameters to account for matrix sizes and execution configuration specified in the `config` object.\n\n 5. `class Config`: This helper class encapsulates configuration settings for the kernel launches, such as `num_warps`, `num_stages`, and `num_ctas`, which dictate how the kernel is executed on the GPU. These settings can significantly impact the performance and efficiency of the computation on different GPU architectures.\n\n Overall, the code leverages Triton's ability to generate efficient GPU kernels with minimal user intervention by handling boundary checks, block pointer manipulations, and execution configurations automatically, aiming to achieve high-performance matrix operations suitable for integer arithmetic on GPUs.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef matmul_kernel_with_block_pointers(\n # Pointers to matrices\n a_ptr,\n b_ptr,\n c_ptr,\n # Matrix dimensions\n M,\n N,\n K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n # by to get the element one row down (A has M rows).\n stride_am,\n stride_ak, #\n stride_bk,\n stride_bn, #\n stride_cm,\n stride_cn,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n # -----------------------------------------------------------\n # Map program ids `pid` to the block of C it should compute.\n # This is done in a grouped ordering to promote L2 data reuse.\n # See the matrix multiplication tutorial for details.\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_M)\n num_pid_n = tl.cdiv(N, BLOCK_N)\n num_pid_in_group = GROUP_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_M\n GROUP_M = min(num_pid_m - first_pid_m, GROUP_M)\n pid_m = first_pid_m + (pid % GROUP_M)\n pid_n = (pid % num_pid_in_group) // GROUP_M\n\n # ----------------------------------------------------------\n # Create block pointers for the first blocks of A and B.\n # We will advance this pointer as we move in the K direction and accumulate.\n # See above `Make a Block Pointer` section for details.\n a_block_ptr = tl.make_block_ptr(\n base=a_ptr,\n shape=(M, K),\n strides=(stride_am, stride_ak),\n offsets=(pid_m * BLOCK_M, 0),\n block_shape=(BLOCK_M, BLOCK_K),\n order=(1, 0),\n )\n b_block_ptr = tl.make_block_ptr(\n base=b_ptr,\n shape=(K, N),\n strides=(stride_bk, stride_bn),\n offsets=(0, pid_n * BLOCK_N),\n block_shape=(BLOCK_K, BLOCK_N),\n order=(1, 0),\n )\n\n # -----------------------------------------------------------\n # Iterate to compute a block of the C matrix.\n # We accumulate into a `[BLOCK_M, BLOCK_N]` block.\n # of fp32 values for higher accuracy.\n # `accumulator` will be converted back to fp16 after the loop.\n accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.int32)\n for k in range(0, K, BLOCK_K):\n # Load with boundary checks, no need to calculate the mask manually.\n # For better performance, you may remove some axis from the boundary\n # check, if you can guarantee that the access is always in-bound in\n # that axis.\n # See above `Load/Store a Block Pointer` section for details.\n a = tl.load(a_block_ptr, boundary_check=(0, 1))\n b = tl.load(b_block_ptr, boundary_check=(0, 1))\n # We accumulate along the K dimension.\n accumulator += tl.dot(a, b)\n # Advance the block pointer to the next K block.\n # See above `Advance a Block Pointer` section for details.\n a_block_ptr = tl.advance(a_block_ptr, (0, BLOCK_K))\n b_block_ptr = tl.advance(b_block_ptr, (BLOCK_K, 0))\n c = accumulator # .to(tl.float16)\n\n # ----------------------------------------------------------------\n # Write back the block of the output matrix C with boundary checks.\n # See above `Load/Store a Block Pointer` section for details.\n c_block_ptr = tl.make_block_ptr(\n base=c_ptr,\n shape=(M, N),\n strides=(stride_cm, stride_cn),\n offsets=(pid_m * BLOCK_M, pid_n * BLOCK_N),\n block_shape=(BLOCK_M, BLOCK_N),\n order=(1, 0),\n )\n tl.store(c_block_ptr, c, boundary_check=(0, 1))\n\n\n@triton.jit\ndef scaled_matmul_kernel_with_block_pointers(\n # Pointers to matrices\n a_ptr,\n b_ptr,\n c_ptr,\n s1_ptr,\n # Matrix dimensions\n M,\n N,\n K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`\n # by to get the element one row down (A has M rows).\n stride_am,\n stride_ak,\n stride_bk,\n stride_bn,\n stride_cm,\n stride_cn,\n stride_s1m,\n stride_s1n,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n GROUP_M: tl.constexpr,\n EVEN_K: tl.constexpr,\n ACC_TYPE: tl.constexpr = tl.int32,\n):\n # based on triton.ops.matmul\n pid = tl.program_id(0)\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_id = pid // width\n group_size = min(grid_m - group_id * GROUP_M, GROUP_M)\n pid_m = group_id * GROUP_M + (pid % group_size)\n pid_n = (pid % width) // (group_size)\n\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n ram = tl.max_contiguous(tl.multiple_of(rm % M, BLOCK_M), BLOCK_M)\n rbn = tl.max_contiguous(tl.multiple_of(rn % N, BLOCK_N), BLOCK_N)\n rk = tl.arange(0, BLOCK_K)\n A = a_ptr + (ram[:, None] * stride_am + rk[None, :] * stride_ak)\n B = b_ptr + (rk[:, None] * stride_bk + rbn[None, :] * stride_bn)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=ACC_TYPE)\n for k in range(K, 0, -BLOCK_K):\n if EVEN_K:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=rk[None, :] < k, other=0.0)\n b = tl.load(B, mask=rk[:, None] < k, other=0.0)\n acc += tl.dot(a, b) # , allow_tf32=ALLOW_TF32)\n A += BLOCK_K * stride_ak\n B += BLOCK_K * stride_bk\n\n # rematerialize rm and rn to save registers\n rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)\n idx_m = rm[:, None]\n idx_n = rn[None, :]\n mask = (idx_m < M) & (idx_n < N)\n\n # inductor generates a suffix\n xindex = idx_n + (N * idx_m)\n tmp0 = tl.load(\n s1_ptr + (tl.broadcast_to(idx_m, mask.shape)),\n mask,\n eviction_policy=\"evict_last\",\n )\n tl.store(c_ptr + (tl.broadcast_to(xindex, mask.shape)), acc * tmp0, mask)\n\n\ndef int_matmul_kernel(a, b, c, config):\n M, K = a.shape\n K, N = b.shape\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n )\n matmul_kernel_with_block_pointers[grid](\n a,\n b,\n c, #\n M,\n N,\n K, #\n a.stride(0),\n a.stride(1), #\n b.stride(0),\n b.stride(1), #\n c.stride(0),\n c.stride(1),\n num_warps=config.num_warps,\n num_stages=config.num_stages,\n num_ctas=config.num_ctas,\n **config.kwargs,\n )\n return c\n\n\ndef int_scaled_matmul_kernel(a, b, scales1, c, config):\n M, K = a.shape\n K, N = b.shape\n # print(\"a.sizes(): \", a.size(), \"a.strides(): \", a.stride(), \"a.dtype: \", a.dtype)\n # print(\"b.sizes(): \", b.size(), \"b.strides(): \", b.stride(), \"b.dtype: \", b.dtype)\n # print(\"c.sizes(): \", c.size(), \"c.strides(): \", c.stride(), \"c.dtype: \", c.dtype)\n # print(\"scales1.sizes(): \", scales1.size(), \"scales1.strides(): \", scales1.stride(), \"scales1.dtype\", scales1.dtype)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),\n )\n scaled_matmul_kernel_with_block_pointers[grid](\n a,\n b,\n c,\n scales1,\n M,\n N,\n K, #\n a.stride(0),\n a.stride(1), #\n b.stride(0),\n b.stride(1), #\n c.stride(0),\n c.stride(1),\n scales1.stride(0),\n scales1.stride(1),\n num_warps=config.num_warps,\n num_stages=config.num_stages,\n num_ctas=config.num_ctas,\n EVEN_K=(K % 2 == 0),\n **config.kwargs,\n )\n return c\n\n\nclass Config:\n def __init__(self, num_warps=4, num_stages=2, num_ctas=1, **kwargs):\n self.num_warps = num_warps\n self.num_stages = num_stages\n self.num_ctas = num_ctas\n self.kwargs = kwargs\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-8d3d5f125c504845bc19556bbd523d32", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code comprises three main Triton kernels for specific tensor operations, often utilized in transformer-like architectures:\n\n 1. **fwd_decay_cumsum**:\n - **Function**: This kernel performs a cumulative sum with decay for each element in a tensor `g`, writing the results to `g_o`. It iterates over a defined number of block rows (`BT`).\n - **Inputs**: Includes the input tensor `g`, output tensor `g_o`, and a decay constant, with parameters defining block and grid sizes.\n - **Logic**: It initializes a zero vector `cum_decay`, iteratively loads segments from `g`, scales them by `inv_ln2`, accumulates into `cum_decay`, and stores results in `g_o`. The kernel uses a mask to handle boundary conditions where the iteration range exceeds tensor dimensions.\n\n 2. **prepare_qg_kg**:\n - **Function**: Prepares new tensors `qg` and `kg` from input tensors `q`, `k`, and `g`. The preparation involves applying exponential decay and scaling transformations.\n - **Inputs**: Tensors `q`, `k`, and `g` with outputs `qg`, `kg`, and block/grid size parameters.\n - **Logic**: For each block row (`BT`), it retrieves values from `q`, `k`, and `g`, computes exponential transformations based on cumulative decay from `g`, and applies a scaling factor. Transformed values are stored into `qg` and `kg`, leveraging mask operations to handle dimension bounds.\n\n 3. **bwd_decay_global_cumsum**:\n - **Function**: Implements the backward operation to compute gradients for the cumulative decay operation, updating `dg` based on provided input gradients.\n - **Inputs**: Includes `dq_inner`, `dq_inter`, `dk_inner`, `dk_inter`, `q`, `k`, `g`, and the gradient tensor `dg`, along with configuration parameters.\n - **Logic**: Backtracks through each block row (`BT`), loads input gradients and tensors, computes the gradient of the decay operation by differentiating through the sum and product operations of `q` and `k`, and accumulates results into `dg`. Employs exponential decay based on the gradient load position and handles edge cases with a mask.\n\n **Launch Functions**: For each kernel, a corresponding launch function computes the necessary strides based on tensor dimensions, prepares grid dimensions (`DK // BK`, `T // BT`, `B * H`), and invokes the respective kernel with computed parameters.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\n\ninv_ln2 = tl.constexpr(1.44269504)\n\n# Forward decay cumulative sum kernel\n@triton.jit\ndef fwd_decay_cumsum(\n g,\n g_o, \n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_go = g_o + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n cum_decay = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n for i in range(BT):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n cum_decay += _g * inv_ln2\n tl.store(p_go, cum_decay.to(p_go.dtype.element_ty), mask=mask)\n p_g += DK\n p_go += DK\n\n# Prepare qg and kg kernel\n@triton.jit\ndef prepare_qg_kg(\n q,\n k,\n g,\n qg,\n kg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_g = g + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_k = k + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_qg = qg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n p_kg = kg + i_bh * s_qk_h + i_c * BT * DK + i_k * BK + tl.arange(0, BK)\n \n mask = (i_k * BK + tl.arange(0, BK)) < DK\n\n last_decay = tl.load(g + i_bh * s_qk_h + (i_c * BT + BT - 1) * DK + i_k * BK + tl.arange(0, BK))\n\n for i in range(BT):\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n _q *= tl.math.exp2(_g) * scale\n _k *= tl.math.exp2(last_decay - _g)\n tl.store(p_kg, _k.to(p_kg.dtype.element_ty), mask=mask)\n tl.store(p_qg, _q.to(p_qg.dtype.element_ty), mask=mask)\n p_q += DK\n p_g += DK\n p_k += DK\n p_kg += DK\n p_qg += DK\n\n# Backward decay global cumulative sum kernel\n@triton.jit\ndef bwd_decay_global_cumsum(\n dq_inner,\n dq_inter,\n dk_inner,\n dk_inter,\n q, k, g, dg,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n B,\n H,\n T,\n scale,\n BT: tl.constexpr,\n BK: tl.constexpr,\n DK: tl.constexpr\n):\n i_k, i_c, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n p_q = q + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_k = k + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_g = g + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dg = dg + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inner = dq_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inner = dk_inner + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dq_inter = dq_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n p_dk_inter = dk_inter + i_bh * s_qk_h + i_k * BK + tl.arange(0, BK) + (i_c * BT + BT - 1) * DK\n cum_grad_dg = tl.zeros([BK], dtype=tl.float32)\n mask = (i_k * BK + tl.arange(0, BK)) < DK\n last_g = tl.zeros([BK], dtype=tl.float32)\n for j in range(BT-1, -1, -1):\n _g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n if j == (BT-1):\n last_g = _g\n _dq1 = tl.load(p_dq_inner, mask=mask, other=0)\n _dq2 = tl.load(p_dq_inter, mask=mask, other=0)\n _dq2 *= tl.math.exp2(_g)\n _dq = _dq1 + _dq2\n tl.store(p_dq_inter, _dq, mask=mask)\n _dk1 = tl.load(p_dk_inner, mask=mask, other=0)\n _dk2 = tl.load(p_dk_inter, mask=mask, other=0)\n _dk2 *= tl.math.exp2(last_g - _g)\n _dk = _dk1 + _dk2\n tl.store(p_dk_inter, _dk, mask=mask)\n _q = tl.load(p_q, mask=mask, other=0)\n _k = tl.load(p_k, mask=mask, other=0)\n _dg = _dq * _q - _dk * _k\n cum_grad_dg += _dg\n tl.store(p_dg, cum_grad_dg.to(p_dg.dtype.element_ty), mask=mask)\n p_g -= DK\n p_k -= DK\n p_q -= DK\n p_dq_inner -= DK\n p_dk_inner -= DK\n p_dq_inter -= DK\n p_dk_inter -= DK\n p_dg -= DK\n\n# Define the kernel launch functions\ndef launch_fwd_decay_cumsum(g, g_o, B, H, T, scale, BT, BK, DK):\n # Calculate strides\n s_qk_h = H * T * DK\n s_qk_t = T * DK\n s_qk_d = DK\n\n # Launch the kernel\n grid = (DK // BK, T // BT, B * H)\n fwd_decay_cumsum[grid](\n g, g_o, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT=BT, BK=BK, DK=DK\n )\n\ndef launch_prepare_qg_kg(q, k, g, qg, kg, B, H, T, scale, BT, BK, DK):\n # Calculate strides\n s_qk_h = H * T * DK\n s_qk_t = T * DK\n s_qk_d = DK\n\n # Launch the kernel\n grid = (DK // BK, T // BT, B * H)\n prepare_qg_kg[grid](\n q, k, g, qg, kg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT=BT, BK=BK, DK=DK\n )\n\ndef launch_bwd_decay_global_cumsum(dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg, B, H, T, scale, BT, BK, DK):\n # Calculate strides\n s_qk_h = H * T * DK\n s_qk_t = T * DK\n s_qk_d = DK\n\n # Launch the kernel\n grid = (DK // BK, T // BT, B * H)\n bwd_decay_global_cumsum[grid](\n dq_inner, dq_inter, dk_inner, dk_inter, q, k, g, dg, s_qk_h, s_qk_t, s_qk_d, B, H, T, scale, BT=BT, BK=BK, DK=DK\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-2c53b56ba7c743feb983a5028141fb7d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `matmul_kernel_persistent` function is a Triton JIT-compiled kernel designed for matrix multiplication with persistent thread execution across different tiles of the matrices. The function takes pointers to input matrices `a_ptr`, `b_ptr`, and an output matrix `c_ptr`, along with the dimensions `M`, `N`, `K`, and strides for each matrix. It uses configurable block sizes for partitioning the matrices and utilizes GPU resources effectively by leveraging multiple streaming multiprocessors (NUM_SMS).\n\n The kernel divides the matrices into blocks of size `BLOCK_SIZE_M x BLOCK_SIZE_N` and iterates over `BLOCK_SIZE_K` tiles of the shared dimension `K`. Each tile is processed independently, and results are accumulated in the `accumulator` array. The kernel employs mask handling to manage edge cases where block indices exceed matrix dimensions and uses `tl.load` and `tl.store` operations with masks to safely access global memory. The accumulated results are converted to the appropriate data type before storing them in the output matrix.\n\n The `matmul_persistent` function acts as a wrapper to initialize and launch the Triton kernel. It validates input tensor shapes and data types and computes the appropriate grid size based on the input dimensions and hardware capabilities. The function configures kernel parameters such as block sizes, number of stages, and warps based on the input data type, and executes the `matmul_kernel_persistent` with these parameters. The output is an allocated tensor `c` which contains the resulting product of matrices `a` and `b`.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\ndef _matmul_launch_metadata(grid, kernel, args):\n ret = {}\n M, N, K = args[\"M\"], args[\"N\"], args[\"K\"]\n ret[\"name\"] = f\"{kernel.name} [M={M}, N={N}, K={K}]\"\n if \"c_ptr\" in args:\n bytes_per_elem = args[\"c_ptr\"].element_size()\n else:\n bytes_per_elem = 1 if args[\"FP8_OUTPUT\"] else 2\n ret[f\"flops{bytes_per_elem * 8}\"] = 2. * M * N * K\n ret[\"bytes\"] = bytes_per_elem * (M * K + N * K + M * N)\n return ret\n\n\n@triton.jit(launch_metadata=_matmul_launch_metadata)\ndef matmul_kernel_persistent(a_ptr, b_ptr, c_ptr, #\n M, N, K, #\n stride_am, stride_ak, #\n stride_bk, stride_bn, #\n stride_cm, stride_cn, #\n BLOCK_SIZE_M: tl.constexpr, #\n BLOCK_SIZE_N: tl.constexpr, #\n BLOCK_SIZE_K: tl.constexpr, #\n GROUP_SIZE_M: tl.constexpr, #\n NUM_SMS: tl.constexpr, #\n ):\n start_pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n k_tiles = tl.cdiv(K, BLOCK_SIZE_K)\n num_tiles = num_pid_m * num_pid_n\n\n tiles_per_SM = num_tiles // NUM_SMS\n if start_pid < num_tiles % NUM_SMS:\n tiles_per_SM += 1\n\n tile_id = start_pid - NUM_SMS\n ki = -1\n\n offs_k_for_mask = tl.arange(0, BLOCK_SIZE_K)\n\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n\n pid_m = 0\n pid_n = 0\n offs_am = tl.arange(0, BLOCK_SIZE_M)\n offs_bn = tl.arange(0, BLOCK_SIZE_N)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n for _ in range(0, k_tiles * tiles_per_SM):\n ki = tl.where(ki == k_tiles - 1, 0, ki + 1)\n if ki == 0:\n tile_id += NUM_SMS\n group_id = tile_id // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (tile_id % group_size_m)\n pid_n = (tile_id % num_pid_in_group) // group_size_m\n\n start_m = pid_m * BLOCK_SIZE_M\n start_n = pid_n * BLOCK_SIZE_N\n offs_am = start_m + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = start_n + tl.arange(0, BLOCK_SIZE_N)\n offs_am = tl.where(offs_am < M, offs_am, 0)\n offs_bn = tl.where(offs_bn < N, offs_bn, 0)\n offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)\n offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)\n offs_k = ki * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n a = tl.load(a_ptrs, mask=offs_k_for_mask[None, :] < K - ki * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k_for_mask[:, None] < K - ki * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n\n if ki == k_tiles - 1:\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n if (c_ptr.dtype.element_ty == tl.float8e4nv):\n c = accumulator.to(tl.float8e4nv)\n else:\n c = accumulator.to(tl.float16)\n tl.store(c_ptrs, c, mask=c_mask)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n\n\ndef matmul_persistent(a, b):\n configs = {\n torch.float8_e4m3fn: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 128, \"GROUP_SIZE_M\": 8, \"num_stages\": 4,\n \"num_warps\": 8\n }, torch.float16: {\n \"BLOCK_SIZE_M\": 128, \"BLOCK_SIZE_N\": 256, \"BLOCK_SIZE_K\": 64, \"GROUP_SIZE_M\": 8, \"num_stages\": 3,\n \"num_warps\": 8\n }\n }\n # Check constraints.\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.dtype == b.dtype, \"Incompatible dtypes\"\n NUM_SMS = torch.cuda.get_device_properties(\"cuda\").multi_processor_count\n M, K = a.shape\n K, N = b.shape\n dtype = a.dtype\n # Allocates output.\n c = torch.empty((M, N), device=a.device, dtype=dtype)\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (min(NUM_SMS, triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"])), )\n matmul_kernel_persistent[grid](\n a, b, c, #\n M, N, K, #\n a.stride(0), a.stride(1), #\n b.stride(0), b.stride(1), #\n c.stride(0), c.stride(1), #\n BLOCK_SIZE_M=configs[dtype][\"BLOCK_SIZE_M\"], #\n BLOCK_SIZE_N=configs[dtype][\"BLOCK_SIZE_N\"], #\n BLOCK_SIZE_K=configs[dtype][\"BLOCK_SIZE_K\"], #\n GROUP_SIZE_M=configs[dtype][\"GROUP_SIZE_M\"], #\n NUM_SMS=NUM_SMS, #\n num_stages=configs[dtype][\"num_stages\"], #\n num_warps=configs[dtype][\"num_warps\"], #\n )\n return c\n\n\n" + ] + }, + { + "id": "openseek-8-787ad8e991264bca9887a105e4aecc02", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This implementation leverages Triton to perform a GEGLU activation function using a tanh-based approximation for efficient computation on GPUs. It includes two main kernels: `_geglu_tanh_forward_kernel` for the forward pass and `_geglu_tanh_backward_kernel` for the backward pass.\n\n The `geglu_forward` function serves as the entry point for the forward computation. It takes two input tensors, `a` and `b`, representing linear transformations of some input data. These tensors are reshaped to a 2D format with the last dimension corresponding to the number of columns `n_cols`. The function then initializes an empty tensor `c` to store the output and calculates `BLOCK_SIZE` and `num_warps` using `calculate_settings`, assumed to be constant at 128 and 4 respectively. The `_geglu_tanh_forward_kernel` is launched across the number of rows (`n_rows`) to perform element-wise operations defined within the kernel: load `a` and `b`, compute the GEGLU activation using an approximation based on `tanh`, and store the result in `c`.\n\n Similarly, the `geglu_backward` function calculates gradients with respect to inputs `a` and `b` from the upstream gradient `dc`. It invokes `_geglu_tanh_backward_kernel` with the appropriate parameters. The backward kernel recomputes necessary intermediates to avoid storing them, calculates gradients for `a` and `b` using the derivative of the GEGLU function, and stores them back to the respective tensors.\n\n Common computations include a tanh approximation formula for GELU: `0.5 * a * (1 + tanh(sqrt(2/pi) * (a + 0.044715 * a^3)))`, and the corresponding backward pass calculations leveraging derivative rules for tanh and GELU. The code reflects efficient memory and computation usage, highlighting the suitability of Triton for such tasks.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom triton.language.extra.libdevice import tanh\n\n\ndef calculate_settings(n_cols):\n # This function calculates the BLOCK_SIZE and num_warps based on n_cols.\n # For simplicity, let's assume BLOCK_SIZE is 128 and num_warps is 4.\n BLOCK_SIZE = 128\n num_warps = 4\n return BLOCK_SIZE, num_warps\n\n@triton.jit\ndef _geglu_tanh_forward_kernel(\n a, b, c, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr\n):\n program_id = tl.program_id(0).to(tl.int64)\n\n # locate start index\n a += program_id * stride\n b += program_id * stride\n c += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask=mask, other=0)\n\n # tanh approximation form of GELU is computed with:\n # 0.5 * a * (1 + tanh(sqrt(2 / pi) * (a + 0.044715 * a^3)))\n sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)\n a_cubed = a_row * a_row * a_row\n tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)\n tanh_result = tanh(tanh_arg)\n geglu_a = 0.5 * a_row * (1 + tanh_result)\n c_row = geglu_a * b_row\n tl.store(c + col_offsets, c_row, mask=mask)\n\n@triton.jit\ndef _geglu_tanh_backward_kernel(\n dc, a, b, stride, n_cols: tl.constexpr, BLOCK_SIZE: tl.constexpr\n):\n program_id = tl.program_id(0).to(tl.int64)\n\n # locate start index\n dc += program_id * stride\n a += program_id * stride\n b += program_id * stride\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < n_cols\n\n dc_row = tl.load(dc + col_offsets, mask=mask, other=0)\n a_row = tl.load(a + col_offsets, mask=mask, other=0).to(tl.float32)\n b_row = tl.load(b + col_offsets, mask=mask, other=0)\n\n # recomputation to save memory\n sqrt_2_over_pi = 0.7978845608028654 # sqrt(2 / pi)\n a_cubed = a_row * a_row * a_row\n tanh_arg = sqrt_2_over_pi * (a_row + 0.044715 * a_cubed)\n tanh_result = tanh(tanh_arg)\n geglu_a = 0.5 * a_row * (1 + tanh_result)\n\n db_row = dc_row * geglu_a\n\n # Gradient w.r.t. a can be computed with:\n # b * (0.5 * (1 + tanh(z)) + 0.5 * a * (1 - tanh(z)^2) * (sqrt(2/pi) * (1 + 3 * 0.044715 * a^2)))\n # where z = sqrt(2/pi) * (a + 0.044715 * a^3)\n term1 = 0.5 * (1 + tanh_result)\n tanh_sq = tanh_result * tanh_result\n term2 = (\n 0.5\n * a_row\n * (1 - tanh_sq)\n * (sqrt_2_over_pi * (1 + 3 * 0.044715 * a_row * a_row))\n )\n da_row = dc_row * b_row * (term1 + term2)\n\n tl.store(a + col_offsets, da_row, mask=mask)\n tl.store(b + col_offsets, db_row, mask=mask)\n\ndef geglu_forward(a, b):\n ori_shape = a.shape\n\n n_cols = ori_shape[-1]\n a = a.view(-1, n_cols)\n b = b.view(-1, n_cols)\n c = torch.empty_like(a)\n n_rows = a.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _geglu_tanh_forward_kernel[(n_rows,)](\n a,\n b,\n c,\n c.stride(-2),\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return a, b, c.view(*ori_shape)\n\ndef geglu_backward(a, b, dc):\n ori_shape = dc.shape\n n_cols = ori_shape[-1]\n dc = dc.view(-1, n_cols)\n n_rows = dc.shape[0]\n\n BLOCK_SIZE, num_warps = calculate_settings(n_cols)\n\n _geglu_tanh_backward_kernel[(n_rows,)](\n dc,\n a,\n b,\n dc.stride(-2),\n n_cols=n_cols,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n\n return a.view(*ori_shape), b.view(*ori_shape)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-de454ccc491a4023963e0e8e54717855", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_bmm_chunk_fwd_kernel` is responsible for computing a batched matrix multiplication (BMM) with support for chunking, causal masking, and sequence indexing. The function takes pointers to input tensors `a` and `b`, an output tensor `out`, and several parameters dictating the shapes, strides, and other characteristics needed for efficient computation on GPUs.\n\n Key parameters include:\n - `BLOCK_SIZE_M`, `BLOCK_SIZE_N`, `BLOCK_SIZE_K`: constants defining the size of each processing block for the M, N, and K dimensions.\n - `IS_CAUSAL`: a compile-time constant indicating whether causal masking is applied.\n - `HAS_SEQ_IDX`: a compile-time constant indicating whether sequence indexing is applied.\n\n The kernel uses program IDs (`pid`) to determine the specific data block each thread processes, defined by combinations of batch, chunk, group, and head indices. The kernel then computes the dot product of corresponding sub-matrices from `a` and `b`, accumulating the result in `acc`, a local accumulation tensor.\n\n Once the accumulation is complete, if `HAS_SEQ_IDX` is enabled, the sequence indices are loaded and used to zero out contributions from mismatched indices. The resulting matrix product is stored back into the global memory.\n\n The `_bmm_chunk_fwd` function is a Python wrapper that sets up kernel arguments, determines execution grid size, and launches the kernel. It handles cases where input tensors are non-contiguous, ensures output tensor allocation based on the computed dimensions, and adjusts for optional parameters like sequence indexing and causality. This function is designed to facilitate BMM operations in various configurations, making it versatile for different input tensor shapes and application needs.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=2),\n ],\n key=['chunk_size', 'K', 'IS_CAUSAL'],\n)\n@triton.jit\ndef _bmm_chunk_fwd_kernel(\n a_ptr, b_ptr, out_ptr, seq_idx_ptr,\n seqlen, chunk_size, K, ngroups,\n stride_a_batch, stride_a_seqlen, stride_a_head, stride_ak,\n stride_b_batch, stride_b_seqlen, stride_b_head, stride_bk,\n stride_out_batch, stride_out_chunk, stride_out_head, stride_outm, stride_outn,\n stride_seq_idx_batch, stride_seq_idx_seqlen,\n IS_CAUSAL: tl.constexpr,\n dot_dtype: tl.constexpr,\n HAS_SEQ_IDX: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n pid_b = tl.program_id(axis=1)\n pid_ch = tl.program_id(axis=2)\n pid_c = pid_ch // ngroups\n pid_h = pid_ch - pid_c * ngroups\n num_pid_n = tl.cdiv(chunk_size, BLOCK_SIZE_N)\n pid_m = tl.program_id(axis=0) // num_pid_n\n pid_n = tl.program_id(axis=0) % num_pid_n\n if IS_CAUSAL:\n if pid_n * BLOCK_SIZE_N >= (pid_m + 1) * BLOCK_SIZE_M:\n return\n a_ptr += pid_b * stride_a_batch + pid_c * chunk_size * stride_a_seqlen + pid_h * stride_a_head\n b_ptr += pid_b * stride_b_batch + pid_c * chunk_size * stride_b_seqlen + pid_h * stride_b_head\n if HAS_SEQ_IDX:\n seq_idx_ptr += pid_b * stride_seq_idx_batch + pid_c * chunk_size * stride_seq_idx_seqlen\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_m[:, None] * stride_a_seqlen + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_n[None, :] * stride_b_seqlen)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=(offs_m[:, None] < chunk_size_limit) & (offs_k[None, :] < K - k * BLOCK_SIZE_K), other=0.0).to(dot_dtype)\n b = tl.load(b_ptrs, mask=(offs_k[:, None] < K - k * BLOCK_SIZE_K) & (offs_n[None, :] < chunk_size_limit), other=0.0).to(dot_dtype)\n acc += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n if HAS_SEQ_IDX:\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n seq_idx_m = tl.load(seq_idx_ptr + offs_m * stride_seq_idx_seqlen, mask=offs_m < chunk_size_limit, other=-1)\n seq_idx_n = tl.load(seq_idx_ptr + offs_n * stride_seq_idx_seqlen, mask=offs_n < chunk_size_limit, other=-2)\n acc = tl.where(seq_idx_m[:, None] == seq_idx_n[None, :], acc, 0.0)\n out = acc.to(out_ptr.dtype.element_ty)\n\n out_ptr += pid_b * stride_out_batch + pid_c * stride_out_chunk + pid_h * stride_out_head\n out_ptrs = out_ptr + (stride_outm * offs_m[:, None] + offs_n[None, :] * stride_outn)\n tl.store(out_ptrs, out, mask=(offs_m[:, None] < chunk_size) & (offs_n[None, :] < chunk_size))\n\ndef _bmm_chunk_fwd(a, b, chunk_size, seq_idx=None, causal=False, output_dtype=None):\n has_groups = a.dim() == 4\n if not has_groups:\n batch, seqlen, k = a.shape\n else:\n batch, seqlen, ngroups, k = a.shape\n assert b.shape == a.shape\n if seq_idx is not None:\n assert seq_idx.shape == (batch, seqlen)\n if a.stride(-1) != 1 and a.stride(1) != 1:\n a = a.contiguous()\n if b.stride(-1) != 1 and b.stride(1) != 1:\n b = b.contiguous()\n nchunks = math.ceil(seqlen / chunk_size)\n out_dtype = a.dtype if output_dtype is None else output_dtype\n out = torch.empty((batch, nchunks, chunk_size, chunk_size) if not has_groups else (batch, nchunks, ngroups, chunk_size, chunk_size),\n device=a.device, dtype=out_dtype)\n dot_dtype = (tl.bfloat16 if a.dtype == torch.bfloat16 or b.dtype == torch.bfloat16 else\n (tl.float16 if a.dtype == torch.float16 or b.dtype == torch.float16 else tl.float32))\n grid = lambda META: (triton.cdiv(chunk_size, META['BLOCK_SIZE_M']) * triton.cdiv(chunk_size, META['BLOCK_SIZE_N']),\n batch, nchunks if not has_groups else nchunks * ngroups)\n with torch.cuda.device(a.device.index):\n _bmm_chunk_fwd_kernel[grid](\n a, b, out, seq_idx,\n int(seqlen), int(chunk_size), int(k), int(ngroups if has_groups else 1),\n a.stride(0), a.stride(1), 0 if not has_groups else a.stride(2), a.stride(-1),\n b.stride(0), b.stride(1), 0 if not has_groups else b.stride(2), b.stride(-1),\n out.stride(0), out.stride(1), 0 if not has_groups else out.stride(2), out.stride(-2), out.stride(-1),\n *((seq_idx.stride(0), seq_idx.stride(1)) if seq_idx is not None else (0, 0)),\n causal,\n dot_dtype,\n HAS_SEQ_IDX=seq_idx is not None,\n )\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-05ca7984ef324b08809d6e7ebfb94a05", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel '_chunk_cumsum_fwd_kernel' performs a chunked cumulative sum on a 3D tensor 'dt' by iterating over batches, chunks, and heads. The data pointers are adjusted based on thread identifiers to load appropriate sections of the tensor. Optional bias and softplus transformations are applied to the data. The cumulative sum is computed for each chunk using the scaling factors 'A', and results are stored in 'dA_cumsum'. The output tensor 'dt_out' retains the modified values of 'dt' after transformation and clamping. The kernel uses block configurations to optimize performance for different problem sizes.\n The '_chunk_cumsum_fwd' function initializes output tensors, computes grid dimensions, and invokes the Triton kernel on the GPU. It checks shapes of inputs and optionally includes biases in the computation. Key configurations are determined by input dimensions, and blocks are selected for efficient parallel computation.\n Inputs:\n - dt: 3D tensor (batch, seqlen, nheads), source data.\n - A: 1D tensor (nheads,), scaling factors for cumulative sum.\n - chunk_size: Integer, size of data chunks.\n - dt_bias: Optional 1D tensor (nheads,), biases for dt.\n - dt_softplus: Boolean, softplus transformation flag.\n - dt_limit: Tuple, min and max clamping values for dt.\n Outputs:\n - dA_cumsum: Tensor (batch, nheads, nchunks, chunk_size), cumulative sum result.\n - dt_out: Tensor (batch, nheads, nchunks, chunk_size), transformed dt.\n ", + "output": [ + "\nimport math\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_H': 1}),\n triton.Config({'BLOCK_SIZE_H': 2}),\n triton.Config({'BLOCK_SIZE_H': 4}),\n triton.Config({'BLOCK_SIZE_H': 8}),\n triton.Config({'BLOCK_SIZE_H': 16}),\n triton.Config({'BLOCK_SIZE_H': 32}),\n triton.Config({'BLOCK_SIZE_H': 64}),\n ],\n key=['chunk_size', 'nheads'],\n)\n@triton.jit\ndef _chunk_cumsum_fwd_kernel(\n # Pointers to matrices\n dt_ptr, A_ptr, dt_bias_ptr, dt_out_ptr, dA_cumsum_ptr,\n # Matrix dimension\n batch, seqlen, nheads, chunk_size,\n dt_min, dt_max,\n # Strides\n stride_dt_batch, stride_dt_seqlen, stride_dt_head,\n stride_A_head,\n stride_dt_bias_head,\n stride_dt_out_batch, stride_dt_out_chunk, stride_dt_out_head, stride_dt_out_csize,\n stride_dA_cs_batch, stride_dA_cs_chunk, stride_dA_cs_head, stride_dA_cs_csize,\n # Meta-parameters\n DT_SOFTPLUS: tl.constexpr,\n HAS_DT_BIAS: tl.constexpr,\n BLOCK_SIZE_H: tl.constexpr, BLOCK_SIZE_CHUNK: tl.constexpr,\n):\n # Triton kernel implementation for chunked cumulative sum forward pass\n pid_b = tl.program_id(axis=0)\n pid_c = tl.program_id(axis=1)\n pid_h = tl.program_id(axis=2)\n dt_ptr += pid_b * stride_dt_batch + pid_c * chunk_size * stride_dt_seqlen\n dt_out_ptr += pid_b * stride_dt_out_batch + pid_c * stride_dt_out_chunk\n dA_cumsum_ptr += pid_b * stride_dA_cs_batch + pid_c * stride_dA_cs_chunk\n\n offs_h = pid_h * BLOCK_SIZE_H + tl.arange(0, BLOCK_SIZE_H)\n offs_c = tl.arange(0, BLOCK_SIZE_CHUNK)\n dt_ptrs = dt_ptr + (offs_h[:, None] * stride_dt_head + offs_c[None, :] * stride_dt_seqlen)\n A_ptrs = A_ptr + offs_h * stride_A_head\n dt_out_ptrs = dt_out_ptr + (offs_h[:, None] * stride_dt_out_head + offs_c[None, :] * stride_dt_out_csize)\n dA_cs_ptrs = dA_cumsum_ptr + (offs_h[:, None] * stride_dA_cs_head + offs_c[None, :] * stride_dA_cs_csize)\n chunk_size_limit = min(chunk_size, seqlen - pid_c * chunk_size)\n\n dt = tl.load(dt_ptrs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), other=0.0).to(tl.float32)\n if HAS_DT_BIAS:\n dt_bias = tl.load(dt_bias_ptr + offs_h * stride_dt_bias_head, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dt += dt_bias[:, None]\n if DT_SOFTPLUS:\n dt = tl.where(dt <= 20.0, tl.log(1 + tl.exp(dt)), dt)\n # As of Triton 2.2.0, tl.clamp is not available yet\n # dt = tl.clamp(dt, dt_min, dt_max)\n dt = tl.minimum(tl.maximum(dt, dt_min), dt_max)\n dt = tl.where((offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size_limit), dt, 0.0)\n tl.store(dt_out_ptrs, dt, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n A = tl.load(A_ptrs, mask=offs_h < nheads, other=0.0).to(tl.float32)\n dA = dt * A[:, None]\n dA_cs = tl.cumsum(dA, axis=1)\n tl.store(dA_cs_ptrs, dA_cs, mask=(offs_h[:, None] < nheads) & (offs_c[None, :] < chunk_size))\n\ndef _chunk_cumsum_fwd(dt, A, chunk_size, dt_bias=None, dt_softplus=False, dt_limit=(0.0, float(\"inf\"))):\n \"\"\"\n Function to perform the forward cumulative sum operation in chunks.\n\n Arguments:\n - dt: (batch, seqlen, nheads), the input tensor.\n - A: (nheads,), the scaling factors.\n - chunk_size: The size of each chunk to process at a time.\n - dt_bias: (nheads,), optional, biases for dt if applicable.\n - dt_softplus: Boolean, whether to apply the softplus operation to dt.\n - dt_limit: Tuple, (min, max) limits for clamping dt values.\n\n Returns:\n - dA_cumsum: Cumulative sum result.\n - dt_out: Modified dt after processing.\n \"\"\"\n batch, seqlen, nheads = dt.shape\n assert A.shape == (nheads,)\n if dt_bias is not None:\n assert dt_bias.shape == (nheads,)\n nchunks = math.ceil(seqlen / chunk_size)\n dt_out = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n dA_cumsum = torch.empty(batch, nheads, nchunks, chunk_size, device=dt.device, dtype=torch.float32)\n grid_chunk_cs = lambda META: (batch, nchunks, triton.cdiv(nheads, META['BLOCK_SIZE_H']))\n with torch.cuda.device(dt.device.index):\n _chunk_cumsum_fwd_kernel[grid_chunk_cs](\n dt, A, dt_bias, dt_out, dA_cumsum,\n int(batch), int(seqlen), int(nheads), int(chunk_size),\n dt_limit[0], dt_limit[1],\n dt.stride(0), dt.stride(1), dt.stride(2),\n A.stride(0),\n dt_bias.stride(0) if dt_bias is not None else 0,\n dt_out.stride(0), dt_out.stride(2), dt_out.stride(1), dt_out.stride(3),\n dA_cumsum.stride(0), dA_cumsum.stride(2), dA_cumsum.stride(1), dA_cumsum.stride(3),\n dt_softplus,\n HAS_DT_BIAS=dt_bias is not None,\n BLOCK_SIZE_CHUNK=triton.next_power_of_2(chunk_size),\n )\n return dA_cumsum, dt_out\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-04fd2eb3b7fe4c309407f44fce9c71f7", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code includes a Triton kernel called 'dequantize_kernel', which takes input int8 matrix 'b', scale factors 'b_scale', and outputs a float matrix 'fpb'. It uses block indices to load segments of 'b' and 'b_scale', multiplies these to dequantize, and stores the result in 'fpb'. Parameters 'K' and 'N' define matrix dimensions, while strides manage memory layout. The function 'matmul_dequantize_int8' ensures matrix dimension compatibility and performs the dequantization using 'dequantize_kernel', followed by a matrix multiplication with input matrix 'a' using PyTorch's 'torch.mm'. The result is stored in matrix 'c'. The kernel utilizes Triton's '@autotune' to optimize performance with different configurations specified by 'BLOCK_SIZE_N' and 'BLOCK_SIZE_K'.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 128}, num_stages=3, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 256}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 256}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32}, num_stages=5, num_warps=2),\n ],\n key=['K', 'N'],\n)\n\n\n@triton.jit\ndef dequantize_kernel(\n b_ptr, b_scale_ptr, fpb_ptr,\n K, N,\n stride_bk, stride_bn,\n stride_fpbk, stride_fpbn,\n BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n k_block_idx = tl.program_id(axis=0)\n n_block_idx = tl.program_id(axis=1)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n offs_n = tl.arange(0, BLOCK_SIZE_N)\n b_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_bk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_bn\n fpb_offs = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None]) * stride_fpbk + \\\n (n_block_idx * BLOCK_SIZE_N + offs_n[None, :]) * stride_fpbn\n bs_offs = n_block_idx * BLOCK_SIZE_N + offs_n[None, :]\n n_mask = n_block_idx * BLOCK_SIZE_N + offs_n[None, :] < N\n mask = (k_block_idx * BLOCK_SIZE_K + offs_k[:, None] < K) & n_mask\n int_b = tl.load(b_ptr + b_offs, mask=mask, other=0.0)\n scale_b = tl.load(b_scale_ptr + bs_offs, mask=n_mask, other=0.0)\n tl.store(fpb_ptr + fpb_offs, int_b * scale_b, mask=mask)\n\ndef matmul_dequantize_int8(a, b, b_scale, out=None):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n if out == None:\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n else:\n c = out\n fp_b = torch.empty((K, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (\n triton.cdiv(K, META['BLOCK_SIZE_K']), triton.cdiv(N, META['BLOCK_SIZE_N']),\n )\n dequantize_kernel[grid](\n b, b_scale, fp_b,\n K, N,\n b.stride(0), b.stride(1),\n fp_b.stride(0), fp_b.stride(1)\n )\n torch.mm(a, fp_b, out=c)\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-b42c3e9812af4f6f87020fca3fe7b117", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code provides a detailed example of a Triton kernel 'fused_add_mul_activation_kernel' and its corresponding PyTorch wrapper function 'fused_add_mul_activation_torch'. The kernel performs a fused operation of addition, multiplication, and activation on tensors. It operates on the input tensor 'x_ptr' by adding elements from 'bias_ptr' and a scaled version of 'in_ptr', then applies an activation function ('sigmoid' or 'relu'). Inputs are accessed in blocks, controlled by 'BLOCK_SIZE', which allows efficient memory handling and parallel execution. The function 'fused_add_mul_activation_torch' configures the grid for executing the Triton kernel based on the size of 'in_out_tensor' and invokes the kernel. The kernel processes data in chunks using block-wise indexing, applies the specified mathematical transformations, and stores the results back. It also manages memory through specific eviction policies to optimize cache usage. The wrapper sets a default multiplier and activation type, executes the kernel with these parameters, and returns the processed tensor.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef fused_add_mul_activation_kernel(x_ptr, bias_ptr, in_ptr,\n num_weights: tl.constexpr,\n xnumel: tl.constexpr,\n multiplier: tl.constexpr,\n activation: tl.constexpr,\n BLOCK_SIZE: tl.constexpr):\n xoffset = tl.program_id(0) * BLOCK_SIZE\n index = xoffset + tl.arange(0, BLOCK_SIZE)[:]\n mask = index < xnumel\n bias_index = index % num_weights\n tmp0 = tl.load(x_ptr + index, mask)\n tmp1 = tl.load(bias_ptr + bias_index, mask, eviction_policy='evict_last')\n tmp3 = tl.load(in_ptr + index, mask)\n activ_input = multiplier * tmp3 + tmp0 + tmp1\n if activation == \"sigmoid\":\n ma_result = tl.sigmoid(activ_input)\n # option 2 - calculate sigmoid using exp\n # ma_result = 1.0 / (1.0 + tl.exp(-sigmoid_input))\n # option 3: fast sigmoid - inaccurate but faster\n # ma_result = 1.0 / (1.0 + tl.abs(sigmoid_input))\n elif activation == \"relu\":\n ma_result = tl.maximum(0, activ_input)\n\n tl.store(x_ptr + index, ma_result, mask)\n\n\ndef fused_add_mul_activation_torch(in_out_tensor: torch.Tensor, bias: torch.Tensor,\n in_tensor: torch.Tensor) -> torch.Tensor:\n # print(\"calling fused_add_mul_relu_torch\")\n grid = lambda meta: (triton.cdiv(in_out_tensor.numel(), meta['BLOCK_SIZE']),)\n BLOCK_SIZE = min(2048, in_out_tensor.numel())\n fused_add_mul_activation_kernel[grid](in_out_tensor, bias, in_tensor,\n bias.numel(),\n in_out_tensor.numel(),\n multiplier=0.5,\n activation=\"sigmoid\",\n BLOCK_SIZE=BLOCK_SIZE)\n return in_out_tensor\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4c8f8b10044d47edb90ab45006fcd32a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `matmul_tma_load_store` is designed to perform block-based matrix multiplication. It uses the `tl.make_block_ptr` function to create block pointers for matrices A, B, and C, each with specific shapes and strides, which facilitate efficient loading of sub-matrices. The parameters `BLOCK_M`, `BLOCK_N`, and `BLOCK_K` define the shape of these sub-matrices. The kernel loads blocks of A and B using `tl.load` and computes the matrix product using `tl.dot`. If `OUTPUT_F16` is set, it converts the result to float16 before storing it with `tl.store`. The Python wrapper `warpper_tma_load_store` generates random matrices A and B, optionally transposes them, allocates matrix C, and calls the kernel. The wrapper accepts multiple parameters like matrix dimensions (M, N, K), number of warps and CTAs, transpose flags for A and B, and the output format for C.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_tma_load_store(\n a_ptr, b_ptr, c_ptr,\n M, N, K,\n stride_am, stride_ak,\n stride_bk, stride_bn,\n stride_cm, stride_cn,\n BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,\n OUTPUT_F16: tl.constexpr\n):\n # Create block pointers for A, B, and C matrices\n a_block_ptr = tl.make_block_ptr(base=a_ptr, shape=(M, K), strides=(stride_am, stride_ak), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_K), order=(1, 0))\n b_block_ptr = tl.make_block_ptr(base=b_ptr, shape=(K, N), strides=(stride_bk, stride_bn), offsets=(0, 0),\n block_shape=(BLOCK_K, BLOCK_N), order=(0, 1))\n c_block_ptr = tl.make_block_ptr(base=c_ptr, shape=(M, N), strides=(stride_cm, stride_cn), offsets=(0, 0),\n block_shape=(BLOCK_M, BLOCK_N), order=(1, 0))\n # Load A and B blocks\n a = tl.load(a_block_ptr)\n b = tl.load(b_block_ptr)\n\n # Compute matrix product\n c = tl.dot(a, b)\n # Optionally convert the result to float16\n if OUTPUT_F16:\n c = c.to(tl.float16)\n\n # Store the result\n tl.store(c_block_ptr, c)\n\n\ndef warpper_tma_load_store(M, N, K, NUM_CTAS, NUM_WARPS, TRANS_A, TRANS_B, OUTPUT_F16):\n # Prepare input matrices\n if (TRANS_A):\n a = torch.randn((K, M), device='cuda', dtype=torch.float16).T\n else:\n a = torch.randn((M, K), device='cuda', dtype=torch.float16)\n if (TRANS_B):\n b = torch.randn((N, K), device='cuda', dtype=torch.float16).T\n else:\n b = torch.randn((K, N), device='cuda', dtype=torch.float16)\n\n # Prepare output matrix\n c = torch.empty((M, N), device=a.device, dtype=torch.float32)\n if OUTPUT_F16:\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n\n # Execute Triton kernel\n matmul_tma_load_store[(1, 1)](\n a_ptr=a, b_ptr=b, c_ptr=c,\n M=M, N=N, K=K,\n stride_am=a.stride(0), stride_ak=a.stride(1),\n stride_bk=b.stride(0), stride_bn=b.stride(1),\n stride_cm=c.stride(0), stride_cn=c.stride(1),\n BLOCK_M=M, BLOCK_N=N, BLOCK_K=K,\n num_warps=NUM_WARPS, num_ctas=NUM_CTAS,\n OUTPUT_F16=OUTPUT_F16)\n return c\n \n\n\n\n" + ] + }, + { + "id": "openseek-8-aa1aa9c6f5dc483a99ca35aa0265bdb1", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton implementation contains two kernels, `mul2_kernel` and `mul2_inplace_kernel`, for multiplying elements of a tensor by 2 using parallel computation. Both kernels use a block size (`BLOCK_SIZE`) to split the computation into manageable chunks. They retrieve the program ID (`pid`) to identify the current block and calculate `block_start`, the starting index for this block. Offsets are computed as `block_start + tl.arange(0, BLOCK_SIZE)` to locate elements within the block. A mask is created to ensure operations stay within `n_elements`, preventing out-of-bounds access. The elements are loaded from memory, doubled, and written back to the specified location using `tl.load` and `tl.store`. The `triton_mul2` function initializes an output tensor, determines the grid size using `triton.cdiv(n_elements, BLOCK_SIZE)`, and calls `mul2_kernel` with appropriate arguments. The `triton_mul2_inplace` function directly modifies the input tensor, using a similar setup to call `mul2_inplace_kernel`. These functions leverage the efficiency of GPU parallelism provided by Triton to scale with the size of input data efficiently.\n ", + "output": [ + "import triton\nfrom triton import language as tl\nimport torch\n\n\n@triton.jit\ndef mul2_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = 2 * x\n tl.store(out_ptr + offsets, output, mask=mask)\n\n@triton.jit\ndef mul2_inplace_kernel(\n ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(ptr + offsets, mask=mask)\n output = 2 * x\n tl.store(ptr + offsets, output, mask=mask)\n\n\ndef triton_mul2(x, BLOCK_SIZE=16):\n output = torch.zeros_like(x)\n n_elements = x.numel()\n grid = (triton.cdiv(n_elements, BLOCK_SIZE),)\n mul2_kernel[grid](x, output, n_elements, BLOCK_SIZE)\n return output\n\ndef triton_mul2_inplace(x, BLOCK_SIZE=16):\n n_elements = x.numel()\n grid = (triton.cdiv(n_elements, BLOCK_SIZE),)\n mul2_inplace_kernel[grid](x, n_elements, BLOCK_SIZE)\n return x\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-68e9eb390c104c0a9c39dedd9de420bf", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The given Triton kernel named `update_fn_kernel` is designed to execute a parameter update using gradient and momentum for optimization purposes in a GPU environment. It operates over a 1D grid where each block handles a set of elements defined by `BLOCK_SIZE`. The kernel begins by computing the starting offset for each block using the program ID `pid` and `BLOCK_SIZE`. It masks the operation to ensure that only valid elements (within `n_elements`) are processed.\n\n The kernel reads values from memory for parameters (`p_ptr`), gradients (`grad_ptr`), and the exponential moving average of past gradients (`exp_avg_ptr`). For each parameter, it applies step weight decay to scale down the parameter by `(1 - lr * wd)`. It calculates the difference between the exponential average and the current gradient, and uses this to update the parameter with a momentum term determined by `beta1`. A sign-based adjustment is applied to the parameter only if there's a change, driven by a conditional mask, emulating a `torch.sign` operation.\n\n After the parameter update, the exponential moving average is decayed using `beta2` and the new gradient. The results for the updated parameter and exponential average are stored back to global memory. The wrapper function `update_fn` ensures all tensors involved are CUDA tensors and calculates the execution grid using the total number of elements divided by `BLOCK_SIZE`. This wrapper also sets up the execution environment, invoking the kernel with these tensors and the relevant hyperparameters like learning rate (`lr`), weight decay (`wd`), and momentum terms (`beta1`, `beta2`).\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(configs = [\n triton.Config({'BLOCK_SIZE': 128}, num_warps = 4),\n triton.Config({'BLOCK_SIZE': 1024}, num_warps = 8),\n], key = ['n_elements'], restore_value=['p_ptr', 'exp_avg_ptr'])\n\n# Triton CUDA kernel\n\n@triton.jit\ndef update_fn_kernel(\n p_ptr,\n grad_ptr,\n exp_avg_ptr,\n lr,\n wd,\n beta1,\n beta2,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n\n mask = offsets < n_elements\n\n # Offsetted pointers\n offset_p_ptr = p_ptr + offsets\n offset_grad_ptr = grad_ptr + offsets\n offset_exp_avg_ptr = exp_avg_ptr + offsets\n\n # Load\n p = tl.load(offset_p_ptr, mask=mask)\n grad = tl.load(offset_grad_ptr, mask=mask)\n exp_avg = tl.load(offset_exp_avg_ptr, mask=mask)\n\n # Stepweight decay\n p = p * (1 - lr * wd)\n\n # Diff between momentum running average and grad\n diff = exp_avg - grad\n\n # Weight update\n update = diff * beta1 + grad\n\n # torch.sign\n can_update = update != 0\n update_sign = tl.where(update > 0, -lr, lr)\n\n p = p + update_sign * can_update\n\n # Decay the momentum running average coefficient\n exp_avg = diff * beta2 + grad\n\n # Store new params and momentum running average coefficient\n tl.store(offset_p_ptr, p, mask=mask)\n tl.store(offset_exp_avg_ptr, exp_avg, mask=mask)\n\ndef update_fn(\n p: torch.Tensor,\n grad: torch.Tensor,\n exp_avg: torch.Tensor,\n lr: float,\n wd: float,\n beta1: float,\n beta2: float\n):\n assert all([t.is_cuda for t in (p, grad, exp_avg)])\n n_elements = p.numel()\n\n grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)\n\n update_fn_kernel[grid](\n p,\n grad,\n exp_avg,\n lr,\n wd,\n beta1,\n beta2,\n n_elements\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-be30aba757cc4547a8a9e279cde4ee82", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements a GPU-based argmax operation using Triton to optimize performance for different tensor shapes and dimensionalities. \n\n The primary function, `argmax`, executes one of two pathways based on whether a specific dimension is specified. Without a dimension (`dim=None`), it treats the input tensor as a flat array, applying a two-stage reduction:\n \n - `argmax_kernel_1`: This kernel iterates over blocks of the input tensor, calculating the maximum value and its corresponding index within each block. The kernel parameters include:\n - `inp`: The input tensor.\n - `mid_value`, `mid_index`: Buffers to store intermediate max values and their indices.\n - `M`: Total number of elements.\n - `BLOCK_SIZE`: Defines the block of elements each instance processes.\n - `INT64_INDEX`: Flags whether indices require 64-bit representation based on input size.\n \n - `argmax_kernel_2`: This kernel finalizes the maximum index from the intermediate results. It reads from `mid_value` and `mid_index` to determine the overall maximum index and store it in the output.\n - `mid_value`, `mid_index`: Intermediate results from the first kernel.\n - `out`: The tensor where the final maximum index is stored.\n - `mid_size`: The number of blocks processed.\n - `BLOCK_MID`: Defines the intermediate block size for reduction.\n \n When a dimension is specified, the `argmax_kernel` is employed, which operates as follows:\n - Iterates over the specified dimension to find the maximum value index across rows/columns.\n - `M`, `N`, `K`: Represent the product of dimensions before, along, and after the specified dimension.\n - `BLOCK_M`, `BLOCK_N`: Define the number of elements each block of threads should process in these dimensions.\n \n The function ensures correct memory use with either int32 or int64 indices based on the tensor's total number of elements by calling `can_use_int32_index`. It adjusts the execution grid dynamically to accommodate tensor size and block configuration.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n\ndef can_use_int32_index(tensor):\n # This function checks if the tensor can use int32 indices\n return tensor.numel() < 2**31\n\n# Kernel 1: argmax_kernel_1\n@triton.jit\ndef argmax_kernel_1(\n inp,\n mid_value,\n mid_index,\n M,\n BLOCK_SIZE: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n pid = tl.program_id(0)\n if INT64_INDEX:\n pid = pid.to(tl.int64)\n offset = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n inp_ptrs = inp + offset\n mask = offset < M\n inp_val = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n max_val, max_index = tl.max(inp_val, axis=0, return_indices=True)\n max_index = max_index + pid * BLOCK_SIZE\n mid_value_ptr = mid_value + pid\n max_index_ptr = mid_index + pid\n tl.store(mid_value_ptr, max_val)\n tl.store(max_index_ptr, max_index)\n\n# Kernel 2: argmax_kernel_2\n@triton.jit\ndef argmax_kernel_2(mid_value, mid_index, out, mid_size, BLOCK_MID: tl.constexpr):\n offset = tl.arange(0, BLOCK_MID)\n mid_ptrs = mid_value + offset\n mask = offset < mid_size\n mid_val = tl.load(mid_ptrs, mask=mask, other=-float(\"inf\"))\n index_val = tl.argmax(mid_val, axis=0)\n mid_index_ptrs = mid_index + index_val\n out_val = tl.load(mid_index_ptrs)\n tl.store(out, out_val)\n\n# Kernel 3: argmax_kernel\n@triton.jit\ndef argmax_kernel(\n inp,\n out_index,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n INT64_INDEX: tl.constexpr = False,\n):\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n if INT64_INDEX:\n pid_m = pid_m.to(tl.int64)\n pid_k = pid_k.to(tl.int64)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n\n max_values = tl.full([BLOCK_M], dtype=tl.float32, value=float(\"-inf\"))\n argmax_values = tl.full([BLOCK_M], dtype=tl.int64, value=0)\n for start_n in range(0, N, BLOCK_N):\n n_offset = start_n + tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n inp_ptrs = inp + offset\n inp_vals = tl.load(inp_ptrs, mask=mask, other=-float(\"inf\"))\n local_max, local_argmax = tl.max(\n inp_vals, 1, return_indices=True, return_indices_tie_break_left=True\n )\n update = local_max > max_values\n max_values = tl.where(update, local_max, max_values)\n argmax_values = tl.where(update, start_n + local_argmax, argmax_values)\n\n offset_index = m_offset * K + pid_k\n out_index_ptrs = out_index + offset_index\n mask1 = m_offset < M\n tl.store(out_index_ptrs, argmax_values, mask=mask1)\n\n# Function calling the kernels\ndef argmax(inp, dim=None, keepdim=False, *, dtype=None):\n if dim is None:\n M = inp.numel()\n if dtype is None:\n dtype = inp.dtype\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(M)))\n mid_size = triton.cdiv(M, block_size)\n block_mid = triton.next_power_of_2(mid_size)\n use_int64_index = not can_use_int32_index(inp)\n\n mid_value = torch.empty((mid_size,), dtype=dtype, device=inp.device)\n mid_index = torch.empty((mid_size,), dtype=torch.int64, device=inp.device)\n if keepdim:\n shape = list(inp.shape)\n for i in range(0, inp.dim()):\n shape[i] = 1\n out = torch.empty(shape, dtype=torch.int64, device=inp.device)\n else:\n out = torch.empty([], dtype=torch.int64, device=inp.device)\n\n with torch.cuda.device(inp.device):\n argmax_kernel_1[(mid_size, 1, 1)](\n inp,\n mid_value,\n mid_index,\n M,\n block_size,\n INT64_INDEX=use_int64_index,\n )\n argmax_kernel_2[(1, 1, 1)](mid_value, mid_index, out, mid_size, block_mid)\n return out\n else:\n assert dim >= -inp.ndim and dim < inp.ndim, \"Invalid dim\"\n shape = inp.shape\n dim = dim % inp.ndim\n N = shape[dim]\n M = math.prod(shape[:dim])\n K = inp.numel() // M // N\n\n inp = inp.contiguous()\n use_int64_index = not can_use_int32_index(inp)\n\n shape_list = list(shape)\n shape_list[dim] = 1\n out_index = torch.empty(shape_list, dtype=torch.int64, device=inp.device)\n if not keepdim:\n out_index = torch.squeeze(out_index, dim)\n\n BLOCK_M = 128 # Example, adjust as needed\n BLOCK_N = 128 # Example, adjust as needed\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n argmax_kernel[grid](\n inp,\n out_index,\n M,\n N,\n K,\n BLOCK_M=BLOCK_M,\n BLOCK_N=BLOCK_N,\n INT64_INDEX=use_int64_index,\n )\n\n return out_index\n\n\n\n" + ] + }, + { + "id": "openseek-8-9524190d9ec84d3bb1115da9de5c27b3", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `_layer_norm_fwd_1pass_kernel` function is a Triton-based CUDA kernel for computing layer normalization in a single pass. It is capable of handling several advanced operations: applying dropout, adding biases, and optionally handling residual connections or additional input tensors (`X1`, `W1`, `B1`). The function operates over 2D tensors `X` and `Y` with dimensions MxN. The main inputs are `X`, `Y`, `W` (weights), and `B` (biases). It supports optional components like `RESIDUAL`, `X1`, `W1`, `B1` and involves additional computations like scaling with `ROWSCALE`, random dropout determined by `SEEDS`, and storing results in `DROPOUT_MASK`. The computationally intensive parts are the normalization steps, which involve calculating the mean and variance (except in RMS norm mode), and applying weights and biases to produce the output. This is all highly parallelized for efficiency on GPUs. A collection of configurations and heuristics helps optimize the kernel's performance based on input size and features used. Outputs include normalized outputs stored in `Y` (and optionally `Y1`), computed mean and inverse standard deviation (for non-RMS normalization), adjusted residuals if `residual_out` is specified, as well as dropout seeds and masks.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"N\", \"HAS_RESIDUAL\", \"STORE_RESIDUAL_OUT\", \"IS_RMS_NORM\", \"HAS_BIAS\"],\n)\n@triton.heuristics({\"HAS_X1\": lambda args: args[\"X1\"] is not None})\n@triton.heuristics({\"HAS_W1\": lambda args: args[\"W1\"] is not None})\n@triton.heuristics({\"HAS_B1\": lambda args: args[\"B1\"] is not None})\n@triton.jit\ndef _layer_norm_fwd_1pass_kernel(\n X, Y, W, B, RESIDUAL, X1, W1, B1, Y1, RESIDUAL_OUT, ROWSCALE, SEEDS, DROPOUT_MASK, Mean, Rstd,\n stride_x_row, stride_y_row, stride_res_row, stride_res_out_row, stride_x1_row, stride_y1_row,\n M, N, eps, dropout_p, IS_RMS_NORM: tl.constexpr, BLOCK_N: tl.constexpr, HAS_RESIDUAL: tl.constexpr,\n STORE_RESIDUAL_OUT: tl.constexpr, HAS_BIAS: tl.constexpr, HAS_DROPOUT: tl.constexpr,\n STORE_DROPOUT_MASK: tl.constexpr, HAS_ROWSCALE: tl.constexpr, HAS_X1: tl.constexpr,\n HAS_W1: tl.constexpr, HAS_B1: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_y_row\n if HAS_RESIDUAL:\n RESIDUAL += row * stride_res_row\n if STORE_RESIDUAL_OUT:\n RESIDUAL_OUT += row * stride_res_out_row\n if HAS_X1:\n X1 += row * stride_x1_row\n if HAS_W1:\n Y1 += row * stride_y1_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + row).to(tl.float32)\n x *= rowscale\n if HAS_DROPOUT:\n keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)\n if HAS_X1:\n x1 = tl.load(X1 + cols, mask=cols < N, other=0.0).to(tl.float32)\n if HAS_ROWSCALE:\n rowscale = tl.load(ROWSCALE + M + row).to(tl.float32)\n x1 *= rowscale\n if HAS_DROPOUT:\n keep_mask = (\n tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p\n )\n x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)\n if STORE_DROPOUT_MASK:\n tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)\n x += x1\n if HAS_RESIDUAL:\n residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)\n x += residual\n if STORE_RESIDUAL_OUT:\n tl.store(RESIDUAL_OUT + cols, x, mask=cols < N)\n if not IS_RMS_NORM:\n mean = tl.sum(x, axis=0) / N\n tl.store(Mean + row, mean)\n xbar = tl.where(cols < N, x - mean, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n else:\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n tl.store(Rstd + row, rstd)\n mask = cols < N\n w = tl.load(W + cols, mask=mask).to(tl.float32)\n if HAS_BIAS:\n b = tl.load(B + cols, mask=mask).to(tl.float32)\n x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd\n y = x_hat * w + b if HAS_BIAS else x_hat * w\n tl.store(Y + cols, y, mask=mask)\n if HAS_W1:\n w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)\n if HAS_B1:\n b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)\n y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1\n tl.store(Y1 + cols, y1, mask=mask)\n\ndef _layer_norm_fwd(\n x, weight, bias, eps, residual=None, x1=None, weight1=None, bias1=None, dropout_p=0.0,\n rowscale=None, out_dtype=None, residual_dtype=None, is_rms_norm=False, return_dropout_mask=False,\n):\n if residual is not None:\n residual_dtype = residual.dtype\n M, N = x.shape\n assert x.stride(-1) == 1\n if residual is not None:\n assert residual.stride(-1) == 1\n assert residual.shape == (M, N)\n assert weight.shape == (N,)\n assert weight.stride(-1) == 1\n if bias is not None:\n assert bias.stride(-1) == 1\n assert bias.shape == (N,)\n if x1 is not None:\n assert x1.shape == x.shape\n assert rowscale is None\n assert x1.stride(-1) == 1\n if weight1 is not None:\n assert weight1.shape == (N,)\n assert weight1.stride(-1) == 1\n if bias1 is not None:\n assert bias1.shape == (N,)\n assert bias1.stride(-1) == 1\n if rowscale is not None:\n assert rowscale.is_contiguous()\n assert rowscale.shape == (M,)\n y = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)\n assert y.stride(-1) == 1\n if weight1 is not None:\n y1 = torch.empty_like(y)\n assert y1.stride(-1) == 1\n else:\n y1 = None\n if (\n residual is not None\n or (residual_dtype is not None and residual_dtype != x.dtype)\n or dropout_p > 0.0\n or rowscale is not None\n or x1 is not None\n ):\n residual_out = torch.empty(\n M, N, device=x.device, dtype=residual_dtype if residual_dtype is not None else x.dtype\n )\n assert residual_out.stride(-1) == 1\n else:\n residual_out = None\n mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None\n rstd = torch.empty((M,), dtype=torch.float32, device=x.device)\n if dropout_p > 0.0:\n seeds = torch.randint(\n 2**32, (M if x1 is None else 2 * M,), device=x.device, dtype=torch.int64\n )\n else:\n seeds = None\n if return_dropout_mask and dropout_p > 0.0:\n dropout_mask = torch.empty(M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool)\n else:\n dropout_mask = None\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _layer_norm_fwd_1pass_kernel[(M,)](\n x, y, weight, bias, residual, x1, weight1, bias1, y1, residual_out, rowscale, seeds,\n dropout_mask, mean, rstd, x.stride(0), y.stride(0), residual.stride(0) if residual is not None else 0,\n residual_out.stride(0) if residual_out is not None else 0, x1.stride(0) if x1 is not None else 0,\n y1.stride(0) if y1 is not None else 0, M, N, eps, dropout_p, is_rms_norm, BLOCK_N,\n residual is not None, residual_out is not None, bias is not None, dropout_p > 0.0,\n dropout_mask is not None, rowscale is not None,\n )\n if dropout_mask is not None and x1 is not None:\n dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)\n else:\n dropout_mask1 = None\n return (\n y, y1, mean, rstd, residual_out if residual_out is not None else x, seeds, dropout_mask, dropout_mask1,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-1e6f826df6e24b0aa1972b8bec59f061", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements chunk retention operations using Triton, targeting high efficiency for transformer-like operations. It includes several kernels, each focused on specific computation aspects, along with a PyTorch autograd function to facilitate integration into neural network training pipelines.\n\n - `chunk_retention_fwd_kernel_h`: This Triton kernel calculates an intermediate tensor 'h'. It accepts inputs 'k', 'v', and optional 'initial_state', manipulating a buffer 'b_h' initialized to zeros. If `USE_INITIAL_STATE` is true, it loads initial state data into 'b_h'. It performs several steps:\n 1. Computes decay factors 'd_b' and 'd_i' based on a custom decay function.\n 2. Iterates over time dimension 'NT', loading blocks of 'k' and 'v'.\n 3. Updates 'b_h' with a dot product computation, applying decay.\n 4. Optionally stores final state in 'final_state' if `STORE_FINAL_STATE` is set.\n\n - `chunk_retention_fwd_kernel_o`: This kernel computes the output tensor 'o'. It integrates:\n 1. Loading blocks of 'q', 'k', 'v', and precomputed 'h'.\n 2. Uses decay factors 'd_i' and precomputed values to modulate contributions.\n 3. Aggregates contributions in 'b_o' and 'b_s' from interactions between 'q', 'k', and 'v'.\n 4. Final 'o' is calculated by scaling and storing the aggregated results.\n\n - `chunk_retention_bwd_kernel_dh`: Part of the backward pass kernels, it calculates 'dh', accumulating gradient contributions while iterating backwards over time steps.\n\n - `chunk_retention_bwd_kernel_dqkv`: This kernel computes the gradients with respect to 'q', 'k', and 'v' (denoted as 'dq', 'dk', 'dv'). It uses:\n 1. Loading blocks and calculating gradients using Triton's matrix operations.\n 2. Applying decay and scaling to modulate contributions.\n 3. Storing calculated gradients for 'q', 'k', 'v' efficiently.\n\n - `ChunkRetentionFunction`: A custom PyTorch autograd function wrapping these kernels, providing an interface for forward and backward passes. It automatically handles device operations and manages kernel launch configuration.\n\n The `chunk_retention` function acts as a user interface, preparing inputs, invoking the autograd function, and handling optional initial state and final state management.\n ", + "output": [ + "from typing import Tuple\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch.amp import custom_bwd, custom_fwd\nimport functools\n\ndef contiguous(fn):\n @functools.wraps(fn)\n def wrapper(ctx, *args, **kwargs):\n return fn(ctx,\n *(i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args),\n **{k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()})\n return wrapper\n\n@triton.jit\ndef chunk_retention_fwd_kernel_h(\n k,\n v,\n h,\n initial_state, # initial state of the chunk [B, H, D_head_K, D_head_V]\n final_state, # final state of the chunk [B, H, D_head_K, D_head_V]\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n # [BK, BV]\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(initial_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n # [BK, BV]\n if i_t == NT - 1 and (T % BT) != 0:\n d_b = tl.math.exp2((T % BT) * b_b)\n d_i = tl.math.exp2(((T % BT) - o_i - 1) * b_b)\n b_h = d_b * b_h + tl.dot(b_k, (b_v * d_i[:, None]).to(b_k.dtype), allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(final_state + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_retention_fwd_kernel_o(\n q,\n k,\n v,\n h,\n o,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_i = tl.math.exp2((o_i + 1) * b_b)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0)\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n # [BT, BK]\n b_q = tl.load(p_q, boundary_check=(0, 1))\n # [BK, BT]\n b_k = tl.load(p_k, boundary_check=(0, 1))\n # [BK, BV]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot((b_q * d_i[:, None]).to(b_q.dtype), b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n\n b_s *= d_s\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dh(\n q,\n do,\n dh,\n v,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_b, d_i = tl.math.exp2(BT * b_b), tl.math.exp2((o_i + 1) * b_b)\n # [BK, BV]\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_o = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n # [BT, BV]\n b_o = tl.load(p_o, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n\n b_dh += tl.dot((b_o * d_i[:, None]).to(b_o.dtype), b_v, allow_tf32=False)\n\n b_dh *= d_b\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_k * K * V, (K, V), (s_h_t, 1), (i_v * BV, i_t * BT), (BK, BV), (1, 0))\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_retention_bwd_kernel_dqkv(\n q,\n k,\n v,\n h,\n do,\n dh,\n dq,\n dk,\n dv,\n s_qk_h,\n s_qk_t,\n s_qk_d,\n s_vo_h,\n s_vo_t,\n s_vo_d,\n s_h_h,\n s_h_t,\n scale,\n H: tl.constexpr,\n T: tl.constexpr,\n K: tl.constexpr,\n V: tl.constexpr,\n BT: tl.constexpr,\n BK: tl.constexpr,\n BV: tl.constexpr,\n NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n i_h = i_bh % H\n n_bh = tl.num_programs(2)\n b_b = tl.math.log2(1 - tl.math.exp2(-5 - i_h * 1.0))\n\n o_i = tl.arange(0, BT)\n d_q, d_k = tl.math.exp2((o_i + 1) * b_b), tl.math.exp2((BT - o_i - 1) * b_b)\n d_q = (d_q * scale).to(d_q.dtype)\n m_s = o_i[:, None] >= o_i[None, :]\n d_s = tl.where(m_s, tl.math.exp2((o_i[:, None] - o_i[None, :]) * b_b), 0) * scale\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * tl.trans(d_s)\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k * n_bh + i_bh) * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n # [BT, BV]\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n # [BV, BK]\n b_h = tl.load(p_h, boundary_check=(0, 1))\n # [BK, BV]\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n\n # [BT, BT]\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n # [BT, BK]\n b_dq += tl.dot(b_do, b_h, allow_tf32=False)\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n # [BT, BV]\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) * d_k[:, None] + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n # [BT, BT]\n b_ds = (b_ds * d_s).to(b_q.dtype)\n # [BT, BK]\n b_dq = b_dq * d_q[:, None] + tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk = b_dk * d_k[:, None] + tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkRetentionFunction(torch.autograd.Function):\n\n @staticmethod\n @custom_fwd(device_type='cuda')\n @contiguous\n def forward(ctx, q, k, v, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_fwd_kernel_h[grid](\n k, v, h, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_retention_fwd_kernel_o[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n ctx.save_for_backward(q, k, v, h)\n return o.to(q.dtype), final_state\n\n @staticmethod\n @custom_bwd(device_type='cuda')\n @contiguous\n def backward(ctx, do, d_ht=None):\n q, k, v, h = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = K ** -0.5\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_retention_bwd_kernel_dh[grid](\n q, do, dh, v,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_retention_bwd_kernel_dqkv[grid](\n q, k, v, h, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n H=H, T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None\n\ndef chunk_retention(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = ChunkRetentionFunction.apply(q, k, v, initial_state, output_final_state)\n return o, final_state\n\n\n\n" + ] + }, + { + "id": "openseek-8-ceb7394af7814196860d3cc2cc080501", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel sin_kernel computes the element-wise sine of a 1D input array. It takes the following parameters: in_ptr0, a pointer to the input data; out_ptr, a pointer to the output data; n_elements, the total number of elements to process; and BLOCK_SIZE, a compile-time constant determining the block of data processed per program instance. The kernel uses tl.program_id(axis=0) to identify which block it is processing and calculates offsets within the data based on the block ID and BLOCK_SIZE. It loads input data using tl.load with masking to prevent out-of-bounds access. After applying tl.sin to compute the sine of each loaded element, it stores the result back into memory using tl.store, again using a mask to ensure only valid memory locations are written. The wrapper function sin_triton initializes the computation by determining the number of elements in the input tensor x, and invokes the sin_kernel, specifying the grid size as the number of elements and setting the BLOCK_SIZE to 4.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef sin_kernel(\n in_ptr0,\n out_ptr,\n n_elements,\n BLOCK_SIZE: \"tl.constexpr\",\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(in_ptr0 + offsets, mask=mask)\n output = tl.sin(x)\n tl.store(out_ptr + offsets, output, mask=mask)\n\ndef sin_triton(x, out):\n n_elements = x.numel()\n sin_kernel[(n_elements,)](x, out, n_elements, BLOCK_SIZE=4)\n\n\n\n" + ] + }, + { + "id": "openseek-8-ea23f263a7f44d5394321aa13dc1abb1", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel function `decoding_fused_rotary_embedding_kernel` applies rotary embedding transformation to the input tensors `q`, `k`, and `v`. This is achieved by computing sine and cosine rotations separately for halves of the head dimension. The result is stored back into the original query tensor `q`. Additionally, it selectively updates `k_cache` and `v_cache` tensors, which store past key/value embeddings, based on the `KV_GROUP_NUM` and layout configurations.\n\n The kernel uses tensor slicing based on current program IDs to target specific heads and tokens within the tensors. The core of the computation involves two halves of the head dimension where sine and cosine transformations are applied to achieve rotary embeddings, which rotates the input embedding vectors to incorporate positional information.\n\n The function `decoding_fused_rotary_embedding` is responsible for setting up kernel execution, including determining the grid size which depends on the number of heads and tokens. It takes in additional parameters `block_tables` and `kv_lengths` for managing cache layouts. It supports flexible cache layouts controlled by `use_new_kcache_layout` and dynamically computes strides for tensor access. The function chooses the number of warps based on head dimension to balance performance and resource usage.\n\n Parameters for the kernel include various strides and offsets, indicating how to traverse and manipulate the memory layout efficiently. The grid defined by `(q_head_num, q_total_tokens)` enables parallel processing of each head and token.\n\n Together, these components facilitate efficient decoding in transformer models, enhancing operations like attention by embedding positional information directly in the input embeddings.\n ", + "output": [ + "import warnings\nfrom typing import Optional\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef decoding_fused_rotary_embedding_kernel(\n q,\n k,\n v,\n cos,\n sin,\n k_cache,\n v_cache,\n BLOCK_TABLES,\n context_lengths,\n x,\n q_token_stride,\n q_head_stride,\n k_token_stride,\n k_head_stride,\n head_dim_stride,\n cos_token_stride,\n cos_stride,\n kcb_stride,\n kch_stride,\n kcsplit_x_stride,\n kcs_stride,\n kcd_stride,\n vcb_stride,\n vch_stride,\n vcs_stride,\n vcd_stride,\n bts_stride,\n btb_stride,\n block_size,\n KV_GROUP_NUM: tl.constexpr,\n HEAD_DIM: tl.constexpr,\n):\n cur_head_idx = tl.program_id(0)\n cur_token_idx = tl.program_id(1)\n\n dim_range = tl.arange(0, HEAD_DIM)\n dim_range0 = tl.arange(0, HEAD_DIM // 2)\n dim_range1 = tl.arange(HEAD_DIM // 2, HEAD_DIM)\n\n off_q = cur_token_idx * q_token_stride + cur_head_idx * q_head_stride\n off_q0 = off_q + dim_range0 * head_dim_stride\n off_q1 = off_q + dim_range1 * head_dim_stride\n\n loaded_q0 = tl.load(q + off_q0)\n loaded_q1 = tl.load(q + off_q1)\n off_cos_sin = cur_token_idx * cos_token_stride + dim_range0 * cos_stride\n loaded_cos = tl.load(cos + off_cos_sin)\n loaded_sin = tl.load(sin + off_cos_sin)\n\n out_q0 = loaded_q0 * loaded_cos - loaded_q1 * loaded_sin\n out_q1 = loaded_q0 * loaded_sin + loaded_q1 * loaded_cos\n tl.store(q + off_q0, out_q0)\n tl.store(q + off_q1, out_q1)\n\n handle_kv = cur_head_idx % KV_GROUP_NUM == 0\n if handle_kv:\n cur_k_head_idx = cur_head_idx // KV_GROUP_NUM\n off_kv = cur_token_idx * k_token_stride + cur_k_head_idx * k_head_stride\n off_k0 = off_kv + dim_range0 * head_dim_stride\n off_k1 = off_kv + dim_range1 * head_dim_stride\n loaded_k0 = tl.load(k + off_k0)\n loaded_k1 = tl.load(k + off_k1)\n\n out_k0 = loaded_k0 * loaded_cos - loaded_k1 * loaded_sin\n out_k1 = loaded_k0 * loaded_sin + loaded_k1 * loaded_cos\n\n # NOTE The precondition here is that it's only for unpadded inputs during decoding stage,\n # and so that we could directly use the token index as the sequence index\n past_kv_seq_len = tl.load(context_lengths + cur_token_idx) - 1\n\n last_block_idx = past_kv_seq_len // block_size\n block_ids = tl.load(BLOCK_TABLES + cur_token_idx * bts_stride + last_block_idx * btb_stride)\n offsets_in_last_block = past_kv_seq_len % block_size\n offsets_cache_base = block_ids * kcb_stride + cur_k_head_idx * kch_stride\n k_range0 = (\n offsets_cache_base\n + offsets_in_last_block * kcs_stride\n + (dim_range0 // x) * kcsplit_x_stride\n + (dim_range0 % x) * kcd_stride\n )\n k_range1 = (\n offsets_cache_base\n + offsets_in_last_block * kcs_stride\n + (dim_range1 // x) * kcsplit_x_stride\n + (dim_range1 % x) * kcd_stride\n )\n tl.store(k_cache + k_range0, out_k0)\n tl.store(k_cache + k_range1, out_k1)\n\n off_v = off_kv + dim_range * head_dim_stride\n loaded_v = tl.load(v + off_v)\n v_range = (\n block_ids * vcb_stride\n + cur_k_head_idx * vch_stride\n + offsets_in_last_block * vcs_stride\n + dim_range * vcd_stride\n )\n tl.store(v_cache + v_range, loaded_v)\n\n\ndef decoding_fused_rotary_embedding(\n q: torch.Tensor,\n k: torch.Tensor,\n v: torch.Tensor,\n cos: torch.Tensor,\n sin: torch.Tensor,\n k_cache: Optional[torch.Tensor] = None,\n v_cache: Optional[torch.Tensor] = None,\n block_tables: Optional[torch.Tensor] = None,\n kv_lengths: Optional[torch.Tensor] = None,\n use_new_kcache_layout: bool = False,\n):\n \"\"\"\n Args:\n q: query tensor, [total_tokens, head_num, head_dim]\n k: key tensor, [total_tokens, kv_head_num, head_dim]\n v: value tensor, [total tokens, kv_head_num, head_dim]\n cos: cosine for rotary embedding, [max_position_len, head_dim]\n sin: sine for rotary embedding, [max_position_len, head_dim]\n k_cache (torch.Tensor): Blocked key cache. [num_blocks, kv_head_num, block_size, head_dim]\n v_cache (torch.Tensor): Blocked value cache. [num_blocks, kv_head_num, block_size, head_dim]\n kv_lengths, Past key/value sequence lengths plus current sequence length for each sequence. [bsz]\n block_tables: Block tables for each sequence. [bsz, max_blocks_per_sequence]\n \"\"\"\n q_total_tokens, q_head_num, head_dim = q.shape\n assert q.size(0) == k.size(0) == v.size(0)\n\n if head_dim >= 512:\n num_warps = 16\n elif head_dim >= 256:\n num_warps = 8\n else:\n num_warps = 4\n k_head_num = k.size(1)\n kv_group_num = q_head_num // k_head_num\n\n # For KCache and VCache with the same layout\n x = head_dim\n kcsplit_x_stride, kcs_stride, kcd_stride = 0, k_cache.stride(2), k_cache.stride(3)\n # For KCache layout [num_blocks, num_kv_heads, head_dim//x, block_size, x]\n if use_new_kcache_layout:\n assert (\n k_cache.dim() == 5\n and k_cache.shape[1] == v_cache.shape[1]\n and k_cache.shape[2] * k_cache.shape[4] == v_cache.shape[3]\n ), f\"Invalid KCache shape {k_cache.shape} and VCache shape {v_cache.shape}\"\n x = k_cache.size(-1)\n kcsplit_x_stride, kcs_stride, kcd_stride = k_cache.stride()[-3:]\n\n grid = (q_head_num, q_total_tokens)\n decoding_fused_rotary_embedding_kernel[grid](\n q,\n k,\n v,\n cos,\n sin,\n k_cache,\n v_cache,\n block_tables,\n kv_lengths,\n x,\n q.stride(0),\n q.stride(1),\n k.stride(0),\n k.stride(1),\n q.stride(2),\n cos.stride(0),\n cos.stride(1),\n k_cache.stride(0),\n k_cache.stride(1),\n kcsplit_x_stride,\n kcs_stride,\n kcd_stride,\n v_cache.stride(0),\n v_cache.stride(1),\n v_cache.stride(2),\n v_cache.stride(3),\n block_tables.stride(0),\n block_tables.stride(1),\n k_cache.size(-2),\n KV_GROUP_NUM=kv_group_num,\n HEAD_DIM=head_dim,\n num_warps=num_warps,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-277c43e062bd4ee3b32e34598121ddf7", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton implementation facilitates efficient linear attention computation by organizing calculations into specialized kernels. The `chunk_linear_attn_fwd_kernel_h` kernel initializes an intermediate tensor `h`, iteratively updating it by loading and computing the dot product of block pointers derived from key (`k`) and value (`v`) tensors across multiple time steps, incorporating an initial state if provided. The `chunk_linear_attn_fwd_kernel_o` kernel calculates the output tensor `o` by iterating over dimensions of the query (`q`), `k`, and `h` tensors, performing dot products to compute weighted sums and handling the softmax attention scoring with masked elements set to zero. The backward pass is divided into `chunk_linear_attn_bwd_kernel_dh`, which computes the gradient for `h` by propagating gradients from output to input space, and `chunk_linear_attn_bwd_kernel_dqkv`, which calculates the gradients for `q`, `k`, and `v` by reusing the intermediate computations in the forward pass. These kernels employ Triton's block pointer feature, boundary checks, and arithmetic capabilities to perform operations efficiently on GPUs. The `ChunkLinearAttentionFunction` class integrates these Triton kernels into PyTorch's autograd framework, allowing the forward and backward computations to be seamlessly included in PyTorch models. The function uses input dimensions (`T`, `K`, `V`), threading model parameters (`BT`, `BK`, `BV`), and Triton's grid and warp management to adaptively optimize for performance and accuracy. ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional, Tuple\n\n@triton.jit\ndef chunk_linear_attn_fwd_kernel_h(\n k, v, h, h0, ht,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr,\n NT: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_h = tl.zeros([BK, BV], dtype=tl.float32)\n\n if USE_INITIAL_STATE:\n p_h0 = tl.make_block_ptr(h0 + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n b_h = tl.load(p_h0, boundary_check=(0, 1)).to(tl.float32)\n\n for i_t in range(NT):\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_h, b_h.to(p_h.dtype.element_ty), boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n if STORE_FINAL_STATE:\n p_ht = tl.make_block_ptr(ht + i_bh * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_linear_attn_fwd_kernel_o(\n q, k, v, h, o,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr\n):\n i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n o_i = tl.arange(0, BT)\n m_s = o_i[:, None] >= o_i[None, :]\n\n b_o = tl.zeros([BT, BV], dtype=tl.float32)\n b_s = tl.zeros([BT, BT], dtype=tl.float32)\n for i_k in range(tl.cdiv(K, BK)):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_o += tl.dot(b_q, b_h, allow_tf32=False)\n b_s += tl.dot(b_q, b_k, allow_tf32=False)\n b_s = tl.where(m_s, b_s, 0)\n\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_o = (b_o + tl.dot(b_s.to(b_v.dtype), b_v, allow_tf32=False)) * scale\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1))\n\n\n@triton.jit\ndef chunk_linear_attn_bwd_kernel_dh(\n q, do, dh,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_v, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n\n b_dh = tl.zeros([BK, BV], dtype=tl.float32)\n for i_t in range(NT - 1, -1, -1):\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h + i_t * K * V, (K, V), (s_h_t, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0))\n\n tl.store(p_dh, b_dh.to(p_dh.dtype.element_ty), boundary_check=(0, 1))\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_q = (b_q * scale).to(b_q.dtype)\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_dh += tl.dot(b_q, b_do.to(b_q.dtype), allow_tf32=False)\n\n\n@triton.jit\ndef chunk_linear_attn_bwd_kernel_dqkv(\n q, k, v, h, do, dh, dq, dk, dv,\n s_qk_h, s_qk_t, s_qk_d,\n s_vo_h, s_vo_t, s_vo_d,\n s_h_h, s_h_t, scale,\n T: tl.constexpr, K: tl.constexpr, V: tl.constexpr,\n BT: tl.constexpr, BK: tl.constexpr, BV: tl.constexpr, NT: tl.constexpr\n):\n i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2)\n n_bh = tl.num_programs(2)\n o_i = tl.arange(0, BT)\n\n p_q = tl.make_block_ptr(q + i_bh * s_qk_h, (K, T), (s_qk_d, s_qk_t), (i_k * BK, i_t * BT), (BK, BT), (0, 1))\n p_k = tl.make_block_ptr(k + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n\n b_q = tl.load(p_q, boundary_check=(0, 1))\n b_k = tl.load(p_k, boundary_check=(0, 1))\n b_s = tl.dot(b_k, b_q, allow_tf32=False) * scale\n b_s = tl.where(o_i[:, None] <= o_i[None, :], b_s, 0)\n\n b_dq = tl.zeros([BT, BK], dtype=tl.float32)\n b_dk = tl.zeros([BT, BK], dtype=tl.float32)\n b_ds = tl.zeros([BT, BT], dtype=tl.float32)\n for i_v in range(tl.cdiv(V, BV)):\n p_v = tl.make_block_ptr(v + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_h_h, (V, NT * K), (1, s_h_t), (i_v * BV, i_t * K + i_k * BK), (BV, BK), (0, 1))\n p_do = tl.make_block_ptr(do + i_bh * s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n p_dh = tl.make_block_ptr(dh + i_bh * s_h_h, (NT * K, V), (s_h_t, 1), (i_t * K + i_k * BK, i_v * BV), (BK, BV), (1, 0))\n p_dv = tl.make_block_ptr(dv + (i_k*n_bh+i_bh)*s_vo_h, (T, V), (s_vo_t, s_vo_d), (i_t * BT, i_v * BV), (BT, BV), (1, 0))\n\n b_v = tl.load(p_v, boundary_check=(0, 1))\n b_do = tl.load(p_do, boundary_check=(0, 1))\n b_h = tl.load(p_h, boundary_check=(0, 1))\n b_dh = tl.load(p_dh, boundary_check=(0, 1))\n\n b_ds += tl.dot(b_do, tl.trans(b_v), allow_tf32=False)\n b_dq += tl.dot(b_do, b_h, allow_tf32=False) * scale\n b_dk += tl.dot(b_v, tl.trans(b_dh), allow_tf32=False)\n b_dv = tl.dot(b_k, b_dh, allow_tf32=False) + tl.dot(b_s.to(b_q.dtype), b_do, allow_tf32=False)\n tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))\n b_ds = tl.where(o_i[:, None] >= o_i[None, :], b_ds * scale, 0).to(b_q.dtype)\n b_dq += tl.dot(b_ds, b_k, allow_tf32=False)\n b_dk += tl.trans(tl.dot(b_q, b_ds, allow_tf32=False))\n\n p_dq = tl.make_block_ptr(dq + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n p_dk = tl.make_block_ptr(dk + i_bh * s_qk_h, (T, K), (s_qk_t, s_qk_d), (i_t * BT, i_k * BK), (BT, BK), (1, 0))\n tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1))\n tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1))\n\n\nclass ChunkLinearAttentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, scale, initial_state, output_final_state):\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n ctx.scale = scale\n\n final_state = None\n if output_final_state:\n final_state = q.new_empty(B, H, K, V, dtype=torch.float32, requires_grad=False)\n\n h = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_linear_attn_fwd_kernel_h[grid](\n k, v, h, initial_state, final_state,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=output_final_state,\n num_warps=num_warps,\n num_stages=num_stages\n )\n grid = (NV, NT, B * H)\n o = torch.empty_like(v)\n chunk_linear_attn_fwd_kernel_o[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n h.stride(1), h.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV,\n num_warps=num_warps,\n num_stages=num_stages\n )\n ctx.save_for_backward(q, k, v, h)\n return o.to(q.dtype), final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n q, k, v, h = ctx.saved_tensors\n\n B, H, T, K, V = *q.shape, v.shape[-1]\n BT = 64\n BK, BV = min(64, triton.next_power_of_2(K)), min(32 if q.dtype == torch.float32 else 64, triton.next_power_of_2(V))\n NT, NK, NV = triton.cdiv(T, BT), triton.cdiv(K, BK), triton.cdiv(V, BV)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n scale = ctx.scale\n\n dh = q.new_empty(B, H, NT * K, V)\n grid = (NK, NV, B * H)\n chunk_linear_attn_bwd_kernel_dh[grid](\n q, do, dh,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n\n grid = (NK, NT, B * H)\n dq = torch.empty_like(q)\n dk = torch.empty_like(k)\n dv = v.new_empty(NK, *v.shape)\n num_stages = 1\n num_warps = 4 if BK == 64 else 2\n chunk_linear_attn_bwd_kernel_dqkv[grid](\n q, k, v, h, do, dh, dq, dk, dv,\n q.stride(1), q.stride(2), q.stride(3),\n v.stride(1), v.stride(2), v.stride(3),\n dh.stride(1), dh.stride(2),\n scale,\n T=T, K=K, V=V, BT=BT, BK=BK, BV=BV, NT=NT,\n num_warps=num_warps,\n num_stages=num_stages\n )\n dv = dv.sum(0)\n return dq.to(q.dtype), dk.to(k.dtype), dv.to(v.dtype), None, None, None\n\n\ndef chunk_linear_attn(\n q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,\n scale: Optional[float] = None, initial_state: torch.Tensor = None,\n output_final_state: bool = False, normalize: bool = True\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if scale is None:\n scale = q.shape[-1] ** -0.5\n o, final_state = ChunkLinearAttentionFunction.apply(q, k, v, scale, initial_state, output_final_state)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-bb51893682e2403da3bfff6cbf51a6b6", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code defines two Triton kernels and their associated host functions for computing the softmax operation over the last dimension of a 3D tensor and its backward pass.\n\n The `_softmax` kernel applies a fused softmax operation, which can handle various configurations such as logarithmic softmax (`LOG`), causal masking (`CAUSAL`), and other types of masking (`MASK_TYPE`). The kernel uses Triton primitives like `tl.load` to read inputs and `tl.store` to write outputs. The kernel is designed to operate efficiently by determining an appropriate computational `DEPTH` using Triton's `next_power_of_2` function, ensuring optimized memory access patterns. The `IS_FP16` heuristic checks if the input tensor is half-precision and ensures higher precision calculations if necessary.\n\n The host function `softmax` handles input validations and tensor preparation, calling `_softmax` with a calculated grid that spans the first two dimensions of the input tensor. It manages strides for accessing elements correctly and sets up any additional parameters required by the kernel.\n\n The `_softmax_backward` kernel computes gradients for the softmax operation, again supporting options like logarithmic output and causal behavior. It handles the backward pass by adjusting gradients accordingly, using efficient arithmetic operations that leverage Triton's parallelism.\n\n The function `softmax_backward` validates gradients' shapes and prepares the backward pass's configuration, similarly invoking the JIT-compiled `_softmax_backward` kernel with appropriate parameters. It ensures strides and memory layout are correctly handled for efficient gradient computation.\n\n Overall, the use of Triton's decorators like `@triton.autotune`, `@triton.heuristics`, and `@triton.jit` highlights this code's focus on performance through compile-time and runtime optimizations. This approach ensures both the forward and backward softmax operations are executed efficiently on supported hardware architectures.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\ndef get_depth(K):\n return triton.next_power_of_2(K)\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics({'DEPTH': lambda nargs: get_depth(nargs['K'])})\n@triton.heuristics({'IS_FP16': lambda nargs: nargs['Y'].dtype == torch.float16})\n@triton.jit\ndef _softmax(\n Y, X, M,\n stride_ym, stride_yn,\n stride_xm, stride_xn,\n stride_m,\n K,\n LOG: tl.constexpr,\n MASK_TYPE: tl.constexpr,\n CAUSAL: tl.constexpr,\n DEPTH: tl.constexpr,\n IS_FP16: tl.constexpr,\n):\n \"\"\"\n Fused softmax kernel over a 3d tensor.\n The softmax is applied over the last dimension, equivalent to torch.softmax(tensor, dim=-1)\n \"\"\"\n m = tl.program_id(0)\n n = tl.program_id(1)\n k = tl.arange(0, DEPTH)\n x_ptrs = X + m * stride_xm + n * stride_xn + k\n io_mask = k < K\n if CAUSAL:\n io_mask = io_mask & (k <= n)\n x = tl.load(x_ptrs, mask=io_mask, other=float(\"-inf\"))\n if CAUSAL:\n off = float(\"-inf\")\n off = off.to(x.dtype)\n x = tl.where(k > n, off, x)\n if MASK_TYPE is not None:\n if MASK_TYPE == 'qk':\n mask_ptrs = M + n * stride_m + k\n elif MASK_TYPE == 'bk':\n mask_ptrs = M + m * stride_m + k\n add_mask = tl.load(mask_ptrs, io_mask, other=float(\"-inf\"))\n x += add_mask\n z = x - tl.max(x, axis=0)\n if IS_FP16:\n z = z.to(tl.float32)\n num = tl.exp(z)\n denom = tl.sum(num, axis=0)\n if LOG:\n y = z - tl.log(denom)\n else:\n y = num / denom\n y_ptrs = Y + m * stride_ym + n * stride_yn + k\n tl.store(y_ptrs, y, mask=k < K)\n\ndef softmax(Y, X, M=None, log=False, mask_type=None, causal=False):\n assert X.ndim == 3, \"Input tensor X must be 3D\"\n assert Y.shape == X.shape, \"Output tensor Y must have the same shape as X\"\n M = M if M is not None else torch.empty(0, device=X.device)\n \n K = X.shape[-1]\n stride_ym, stride_yn = Y.stride()[:-1]\n stride_xm, stride_xn = X.stride()[:-1]\n stride_m = M.stride(-1) if M.numel() > 0 else 0\n\n grid = (X.shape[0], X.shape[1])\n _softmax[grid](\n Y, X, M,\n stride_ym, stride_yn,\n stride_xm, stride_xn,\n stride_m,\n K,\n LOG=log,\n MASK_TYPE=mask_type,\n CAUSAL=causal\n )\n\ndef get_depth(K):\n return triton.next_power_of_2(K)\n\n\n@triton.autotune(\n configs=[\n triton.Config({}, num_warps=1),\n triton.Config({}, num_warps=2),\n triton.Config({}, num_warps=4),\n triton.Config({}, num_warps=8),\n triton.Config({}, num_warps=16),\n triton.Config({}, num_warps=32),\n ],\n key=[\"K\"],\n)\n@triton.heuristics({'DEPTH': lambda nargs: get_depth(nargs['K'])})\n@triton.heuristics({'IS_FP16': lambda nargs: nargs['GradIn'].dtype == torch.float16})\n@triton.jit\ndef _softmax_backward(\n GradIn, GradOut, Out,\n stride_bm, stride_bn,\n stride_gm, stride_gn,\n stride_om, stride_on,\n K,\n LOG: tl.constexpr,\n CAUSAL: tl.constexpr,\n DEPTH: tl.constexpr,\n IS_FP16: tl.constexpr,\n):\n \"\"\"\n Compute the softmax gradients.\n \"\"\"\n m = tl.program_id(0)\n n = tl.program_id(1)\n k = tl.arange(0, DEPTH)\n grad_out_ptrs = GradOut + m * stride_gm + n * stride_gn + k\n out_ptrs = Out + m * stride_om + n * stride_on + k\n io_mask = k < K\n if CAUSAL:\n io_mask = io_mask & (k <= n)\n g = tl.load(grad_out_ptrs, mask=io_mask, other=float(0))\n o = tl.load(out_ptrs, mask=io_mask, other=float(0))\n if CAUSAL:\n zero = float(0)\n zero = zero.to(g.dtype)\n g = tl.where(k > n, zero, g)\n o = tl.where(k > n, zero, o)\n if LOG:\n s = tl.sum(g, 0)\n if IS_FP16:\n o = o.to(tl.float32)\n grad_in = g - tl.exp(o) * s\n else:\n s = tl.sum(g * o, 0)\n grad_in = o * (g - s)\n grad_in_ptrs = GradIn + m * stride_bm + n * stride_bn + k\n tl.store(grad_in_ptrs, grad_in, mask=k < K)\n\n\ndef softmax_backward(GradIn, GradOut, Out, log=False, causal=False):\n assert GradOut.shape == Out.shape, \"GradOut and Out must have the same shape\"\n assert GradIn.shape == Out.shape, \"GradIn and Out must have the same shape\"\n \n K = Out.shape[-1]\n stride_bm, stride_bn = GradIn.stride()[:-1]\n stride_gm, stride_gn = GradOut.stride()[:-1]\n stride_om, stride_on = Out.stride()[:-1]\n\n grid = (Out.shape[0], Out.shape[1])\n _softmax_backward[grid](\n GradIn, GradOut, Out,\n stride_bm, stride_bn,\n stride_gm, stride_gn,\n stride_om, stride_on,\n K,\n LOG=log,\n CAUSAL=causal\n )\n\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-be31d7f3ac014b4db522c43238d5fa75", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code implements a custom RMS normalization using Triton to accelerate the computation process. The kernel function `rms_norm_fwd_fused` is defined with the `@triton.jit` decorator, which makes it suitable for just-in-time compilation and execution on GPUs. The function is responsible for calculating the root mean square (RMS) normalization of each row in the input matrix `X`. It does so by computing the variance of elements in each row, calculating the inverse of the standard deviation (rstd), and applying a learned weight matrix `W` to the normalized output.\n\n The function takes in pointers to input tensor `X`, output tensor `Y`, and weights `W`, along with some metadata such as `stride` for row access and `N` for the number of columns. It utilizes a block-based approach to iterate over input data, computing partial results in blocks of `BLOCK_SIZE` to efficiently leverage GPU capabilities.\n\n The `TritonLlamaRMSNorm` class is a PyTorch `nn.Module` that integrates this kernel into a model. The constructor receives a `weight` tensor and an `eps` value for numerical stability. The `forward` method reshapes the input tensor into a 2D format, calculates the optimal block size, and enqueues the Triton kernel for execution. The kernel computes the normalized output stored in tensor `y`, ensuring high performance for inputs where the feature dimension is less than 64KB.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nimport torch.nn as nn\n\n# Kernel function for fused RMSNorm\n@triton.jit\ndef rms_norm_fwd_fused(\n X, # pointer to the input\n Y, # pointer to the output\n W, # pointer to the weights\n stride, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_SIZE: tl.constexpr,\n):\n # Map the program id to the row of X and Y it should compute.\n row = tl.program_id(0)\n Y += row * stride\n X += row * stride\n # Compute variance\n _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n _var += x * x\n var = tl.sum(_var, axis=0) / N\n rstd = 1 / tl.sqrt(var + eps)\n # Normalize and apply linear transformation\n for off in range(0, N, BLOCK_SIZE):\n cols = off + tl.arange(0, BLOCK_SIZE)\n mask = cols < N\n w = tl.load(W + cols, mask=mask)\n x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)\n x_hat = x * rstd\n y = x_hat * w\n # Write output\n tl.store(Y + cols, y, mask=mask)\n\n# TritonLlamaRMSNorm class for integrating the kernel into a model\nclass TritonLlamaRMSNorm(nn.Module):\n def __init__(self, weight, eps=1e-6):\n \"\"\"\n LlamaRMSNorm is equivalent to T5LayerNorm\n \"\"\"\n super().__init__()\n self.weight = weight\n self.variance_epsilon = eps\n\n def forward(self, x):\n y = torch.empty_like(x)\n # reshape input data into 2D tensor\n x_arg = x.reshape(-1, x.shape[-1])\n M, N = x_arg.shape\n # Less than 64KB per feature: enqueue fused kernel\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_SIZE:\n raise RuntimeError(\"This layer norm doesn't support feature dim >= 64KB.\")\n # heuristics for number of warps\n num_warps = min(max(BLOCK_SIZE // 256, 1), 8)\n # enqueue kernel\n rms_norm_fwd_fused[(M,)](\n x_arg,\n y,\n self.weight,\n x_arg.stride(0),\n N,\n self.variance_epsilon,\n BLOCK_SIZE=BLOCK_SIZE,\n num_warps=num_warps,\n )\n return y\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-e0583024fdfc468da4bf741b3398929c", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel function `rmsnorm_triton` is designed to perform RMS normalization on a given 3D tensor `x`, which is stored in `x_ptr`. The normalization is applied over the last dimension of `x` with a specified size `N_SIZE`. The function takes pointers to the input tensor `x_ptr`, the RMS normalization weights `rms_w_ptr`, and the output tensor `output_ptr`. It also requires the strides for input tensor, weight, and output tensor, as well as the constants `N_SIZE`, `eps`, and `BLOCK_N_SIZE`.\n\n Within the kernel, `pid_batch` and `pid_m` identify the current batch and row of the matrix being processed. The kernel calculates the variance by iterating over chunks of size `BLOCK_N_SIZE` and accumulating the sum of squares of the elements. This is divided by `N_SIZE` to get the variance, from which the reciprocal of the standard deviation `rstd` is derived using an epsilon `eps` for stability.\n\n The function then proceeds to normalize the input, scale by the weights, and write the output back to `output_ptr`. The iteration over chunks ensures efficient memory access patterns. The wrapper function `rmsnorm_triton_wrapper` sets up the input data, initializes the output tensor, and launches the kernel with the appropriate grid dimensions based on batch size and the second dimension M of the input tensor.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\ntorch.manual_seed(1234)\n\n\n@triton.jit\ndef rmsnorm_triton(x_ptr, rms_w_ptr, output_ptr,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_k,\n N_SIZE: tl.constexpr, eps: tl.constexpr, BLOCK_N_SIZE: tl.constexpr):\n pid_batch = tl.program_id(0)\n pid_m = tl.program_id(1)\n\n offs_m = pid_batch * stride_x_batch + pid_m * stride_x_m\n block_N = tl.arange(0, BLOCK_N_SIZE)\n var = tl.zeros((BLOCK_N_SIZE,), tl.float32)\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0)\n var += tl.extra.cuda.libdevice.pow(x.to(tl.float32), 2)\n\n var = tl.sum(var, axis=0) / N_SIZE\n rstd = tl.math.rsqrt(var + eps)\n\n # multiply by weight and add bias\n for block_n_start_idx in range(0, N_SIZE, BLOCK_N_SIZE):\n offs_n = block_n_start_idx + block_N\n x_ptr_mask = offs_n < N_SIZE\n rms_w = tl.load(rms_w_ptr + offs_n * stride_rms_w, mask=x_ptr_mask)\n\n x = tl.load(x_ptr + offs_m + offs_n * stride_x_k, mask=x_ptr_mask, other=0.0).to(tl.float32)\n x_hat = x * rstd\n out = x_hat * rms_w\n out_off = pid_batch * stride_out_batch + pid_m * stride_out_m + offs_n * stride_out_k\n tl.store(output_ptr + out_off, out, mask=x_ptr_mask)\n\n\ndef rmsnorm_triton_wrapper(x, rms_w, eps=1e-6):\n batch, M, K = x.shape\n assert rms_w.shape[-1] == K\n out = torch.empty_like(x)\n rmsnorm_triton[(batch, M,)](x, rms_w, out,\n *x.stride(),\n *rms_w.stride(),\n *out.stride(),\n N_SIZE=K, eps=eps, BLOCK_N_SIZE=1024,\n )\n return out\n\n\n\n" + ] + }, + { + "id": "openseek-8-70f2017397a4487ca301f5f83f89624e", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The `nested3` function is a Triton kernel designed to perform complex memory operations in a nested loop manner, specifically within a 2x2 tile structure. Its main purpose is to load elements from an input array (`in_ptr`) and store them in an output array (`out_ptr`) using specific stride parameters (`stride_m` and `stride_n`) to calculate the correct memory offsets. The kernel employs three nested loops over indices i, j, and k. In each loop, it executes the following steps: (1) Calculate pointers `a_ptrs` and `c_ptrs` for the input and output using offsets derived from the loop indices and strides. (2) Load values from these calculated pointers using `tl.load`. (3) Store these loaded values to output pointers using `tl.store`. The pointer increment within nested loops ensures correct access patterns for intended operations. The `wrapper_nested3` function sets up the input tensor `x` and output tensor `output` on CUDA with specific dimensions derived from `n_rows` and `n_cols`. It defines a grid configuration for the Triton kernel execution, where grid size is determined by dividing `n_cols` by 4. Finally, it invokes the `nested3` kernel with the defined grid, passing in tensor data and computed strides, and prints the resultant output tensor.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef nested3(in_ptr, out_ptr, stride_m, stride_n):\n offs_am = tl.arange(0, 2)\n offs_an = tl.arange(0, 2)\n a_ptrs = in_ptr + (offs_am[:, None] * stride_m +\n offs_an[None, :] * stride_n)\n\n offs_cm = tl.arange(0, 2)\n offs_cn = tl.arange(0, 2)\n c_ptrs = out_ptr + stride_m * offs_cm[:, None] + stride_n * offs_cn[\n None, :]\n\n for i in range(0, 2):\n a1 = tl.load(a_ptrs)\n\n for j in range(0, 2):\n a_ptrs += 2 * stride_n\n a2 = tl.load(a_ptrs)\n\n for k in range(0, 2):\n a_ptrs += 2 * stride_n\n a3 = tl.load(a_ptrs)\n tl.store(c_ptrs, a1)\n c_ptrs += 2 * stride_n\n\n tl.store(c_ptrs, a2)\n c_ptrs += 2 * stride_n\n tl.store(c_ptrs, a3)\n c_ptrs += 2 * stride_n\n\n a_ptrs += 2 * stride_n\n\n\ndef wrapper_nested3(n_rows, n_cols):\n x = torch.arange(0, n_rows * n_cols, device=\"cuda\", dtype=torch.int32).reshape([n_rows, n_cols])\n output = torch.zeros([n_rows, n_cols], device=x.device, dtype=x.dtype)\n grid = lambda meta: (n_cols // 4,)\n nested3[grid](x, output, x.stride(0), x.stride(1))\n print(output)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-872351027dfe4dc087507b735f9a98f8", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton code snippet provides a custom operator tailored for performing matrix multiplications where the matrix B is stored in a quantized int4 format, utilizing GPTQ (General-Purpose Tensor Quantization). This efficient format reduces memory usage and increases computational speed for neural network operations.\n\n The main computational function, `matmul4_kernel`, is decorated with `@triton.jit`, allowing it to be compiled just-in-time for the specific hardware it runs on. It is also wrapped with `@triton.autotune`, providing different configurations to optimize performance based on the input matrix sizes and other parameters. This function handles the matrix multiplication C = A x B, where A is a float16 matrix, B is a quantized int4 matrix represented as int32, and the result C is a float16 matrix. Key inputs include pointers to the data arrays, stride information to navigate these arrays, and dimensional sizes (M, N, K). The kernel processes blocks of the matrices determined by BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K.\n\n The kernel ensures proper dequantization of B using stored scale and zero-point values to transform B back to float format during operations. It leverages Triton's ability to manage thread and warp execution, relying on identifiers like `program_id` for handling parallel processing across different blocks. The computations involve transforming B's packed int32 representation back to float values through bit manipulation (`>>` for shifting, `&` for masking) and scaling adjustments. Accumulation results in a float32 temporary matrix that is later converted to the desired dtype before storage.\n\n The wrapper function `matmul_dequantize_int4_gptq` orchestrates kernel execution. It validates input conditions, defines the execution grid size based on input dimensions, and calls the kernel with necessary parameters. Optionally, it allows in-place updates of the output matrix.\n\n The auxiliary function `quantize_int4` is essential for preparing the B matrix in its int4 format. The function quantizes matrix weights, computes appropriate scales and zero points for reconstruction, and efficiently packs these quantized values into the int32 format expected by the kernel. It processes input matrices by transposing, slicing into groups, and calculating min-max ranges to determine scale factors and zero points, ensuring that each group is accurately represented in its reduced bit-depth format.\n\n Together, these components facilitate high-performance computations on modern GPU architectures, enabling efficient processing for deep learning models, particularly in contexts where memory and bandwidth constraints are significant.\n ", + "output": [ + "import time\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.autotune(\n\tconfigs=[\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=2, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 128, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8), \n ],\n\tkey=['M', 'N', 'K', 'NO_GROUPS'],\n)\n@triton.jit\ndef matmul4_kernel(\n\ta_ptr, b_ptr, c_ptr,\n\tscales_ptr, zeros_ptr,\n\tM, N, K,\n\tstride_am, stride_ak,\n\tstride_bk, stride_bn,\n\tstride_cm, stride_cn,\n\tstride_scales_g, stride_scales_n,\n\tstride_zeros_g, stride_zeros_n,\n\tgroupsize, NO_GROUPS: tl.constexpr,\n\tBLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n\tGROUP_SIZE_M: tl.constexpr,\n):\n \"\"\"\n Compute the matrix multiplication C = A x B.\n A is of shape (M, K) float16\n B is of shape (K//8, N) int32\n C is of shape (M, N) float16\n scales is of shape (G, N) float16\n zeros is of shape (G, N//8) int32\n groupsize is an int specifying the size of groups for scales and zeros.\n G is K // groupsize.\n Set NO_GROUPS to groupsize == K, in which case G = 1 and the kernel is more efficient.\n WARNING: This kernel assumes that K is a multiple of BLOCK_SIZE_K.\n WARNING: This kernel assumes that N is a multiple of BLOCK_SIZE_N.\n WARNING: This kernel assumes that groupsize is a multiple of BLOCK_SIZE_K.\n \"\"\"\n bits = 4\n infearure_per_bits = 8\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m \n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n a_mask = (offs_am[:, None] < M)\n # b_ptrs is set up such that it repeats elements along the K axis 8 times\n b_ptrs = b_ptr + ((offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn) # (BLOCK_SIZE_K, BLOCK_SIZE_N)\n scales_ptrs = scales_ptr + offs_bn * stride_scales_n # (BLOCK_SIZE_N,)\n # zeros_ptrs is set up such that it repeats elements along the N axis 8 times\n zeros_ptrs = zeros_ptr + ((offs_bn // infearure_per_bits) * stride_zeros_n) # (BLOCK_SIZE_N,)\n # shifter is used to extract the 4 bits of each element in the 32-bit word from B and zeros\n shifter = (offs_k % infearure_per_bits) * bits\n zeros_shifter = (offs_bn % infearure_per_bits) * bits\n # If G == 1, scales and zeros are the same for all K, so we can load them once\n if NO_GROUPS:\n # Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop\n scales = tl.load(scales_ptrs) # (BLOCK_SIZE_N,)\n zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32\t\n # Unpack zeros\n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n # zeros = (zeros + 1) * scales # (BLOCK_SIZE_N,) float16\n zeros = zeros * scales\n # Now calculate a block of output of shape (BLOCK_SIZE_M, BLOCK_SIZE_N)\n # M is along the batch dimension, N is along the outfeatures dimension, K is along the infeatures dimension\n # So this loop is along the infeatures dimension (K)\n # It's calculating BLOCK_SIZE_M batches in parallel, and for each batch, BLOCK_SIZE_N outfeatures in parallel\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, num_pid_k):\n a = tl.load(a_ptrs, mask=a_mask, other=0.) # (BLOCK_SIZE_M, BLOCK_SIZE_K)\n b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated\n if not NO_GROUPS:\n g_id = k // (groupsize // BLOCK_SIZE_K)\n ptr = scales_ptrs + g_id * stride_scales_g\n scales = tl.load(ptr) # (BLOCK_SIZE_N,)\n ptr = zeros_ptrs + g_id * stride_zeros_g # (BLOCK_SIZE_N,)\n zeros = tl.load(ptr) # (BLOCK_SIZE_N,), each element is repeated 8 times, int32\t\n # Unpack zeros\n zeros = (zeros >> zeros_shifter) & 0xF # (BLOCK_SIZE_N,) int32\n zeros = (zeros) * scales # (BLOCK_SIZE_N,) float16\t\n # Now we need to unpack b (which is 4-bit values) into 32-bit values\n b = (b >> shifter[:, None]) & 0xF # Extract the 4-bit values\n b = b * scales[None, :] - zeros[None, :] # Scale and shift\n # print(\"data type\", a, b)\n accumulator += tl.dot(a, b.to(a.dtype))\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk \n c = accumulator.to(c_ptr.dtype.element_ty) \n # Store the result\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\n\ndef matmul_dequantize_int4_gptq(x: torch.FloatTensor, qweight: torch.IntTensor, scales: torch.FloatTensor, qzeros: torch.IntTensor, group_size, output=None) -> torch.FloatTensor:\n\t\"\"\"\n\tCompute the matrix multiplication C = A x B + bias.\n\tWhere B is quantized using GPTQ and groupsize = -1 into 4-bit values.\n\n\tA is of shape (..., K) float16\n\tqweight is of shape (K//8, N) int32\n\tscales is of shape (G, N) float16\n\tqzeros is of shape (G, N//8) int32\n\tbias is of shape (1, N) float16\n\n\tgroupsize is the number of infeatures in each group.\n\tG = K // groupsize\n\n\tReturns C of shape (..., N) float16\n\t\"\"\"\n\tassert x.shape[-1] == (qweight.shape[0] * 8), \"A must be a multiple of 8 in the last dimension\"\n\tassert x.is_contiguous(), \"A must be contiguous\"\n\n\tM, K = x.shape\n\tN = qweight.shape[1]\n\t# This is based on the possible BLOCK_SIZE_Ks\n\t# assert K % 16 == 0 and K % 32 == 0 and K % 64 == 0 and K % 128 == 0, \"K must be a multiple of 16, 32, 64, and 128\"\n\t# # This is based on the possible BLOCK_SIZE_Ns\n\t# assert N % 16 == 0 and N % 32 == 0 and N % 64 == 0 and N % 128 == 0 and N % 256 == 0, \"N must be a multiple of 16, 32, 64, 128, and 256\"\n\t# # This is based on the possible BLOCK_SIZE_Ks\n\t# assert groupsize % 32 == 0 and groupsize % 64 == 0 and groupsize % 128 == 0, \"groupsize must be a multiple of 32, 64, and 128\"\n\n\t# output = torch.empty((M, N), device='cuda', dtype=torch.float16)\n\tif output is None:\n\t\tinplace = False\n\t\toutput = torch.empty((M, N), device=x.device, dtype=x.dtype)\n\telse:\n\t\tinplace = True\n\n\tgrid = lambda META: (\n\t\ttriton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),\n\t)\n\tmatmul4_kernel[grid](\n\t\tx, qweight, output,\n\t\tscales, qzeros,\n\t\tM, N, K,\n\t\tx.stride(0), x.stride(1),\n\t\tqweight.stride(0), qweight.stride(1),\n\t\toutput.stride(0), output.stride(1),\n\t\tscales.stride(0), scales.stride(1),\n\t\tqzeros.stride(0), qzeros.stride(1),\n\t\tgroup_size, group_size == K,\n )\n\t# return output\n\tif not inplace:\n\t\treturn output\n\ndef quantize_int4(weight, group_size=128, tp_rank=0):\n # Weight shape: [H1 // 8, H2]\n # Scale shape: [H1 // group_size, H2]\n # zero_pint shape: [H1 // group_size, H2 // 8]\n\n weight = weight.transpose(1, 0)\n h1, h2 = weight.shape\n assert h1 % 8 == 0 and h2 % 8 == 0, \"H1 {} H2 {}\".format(h1, h2)\n assert h2 % group_size == 0, \"H1 {} H2 {}\".format(h1, h2)\n weight = weight.contiguous().view(-1, group_size).cuda(tp_rank)\n weight_max = weight.amax(-1, keepdim=True)\n weight_max = torch.where(weight_max < 0, 0, weight_max)\n weight_min = weight.amin(-1, keepdim=True)\n weight_min = torch.where(weight_min > 0, 0, weight_min)\n weight_range = weight_max - weight_min \n scale = weight_range / (2 ** 4 - 1)\n zero_point = (-weight_min / scale).round().clamp(0, 15).to(torch.int32)\n weight = (weight / scale + zero_point).round().clamp(0, 15).to(torch.int32).view(h1, h2)\n int_weight = torch.empty(h1, h2 // 8).to(torch.int32).to(weight.device)\n int_zero_point = torch.zeros(h1 // 8, h2 // group_size).to(torch.int32).to(weight.device)\n zero_point = zero_point.view(h1, -1)\n scale = scale.view(h1, -1)\n # pack 8 int4 in an int32 number.\n # Weight pack in row.\n for pack in range(0, h2, 8):\n for i in range(8):\n int_weight[:, pack // 8] += weight[:, pack + i] << (i * 4)\n # zero point pack in col.\n for pack in range(0, h1, 8):\n for i in range(8):\n int_zero_point[pack // 8, :] += zero_point[pack + i, :] << (i * 4)\n '''\n fp_weight = torch.zeros(h1, h2).half().to(weight.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_weight[pack * 8 + i, :] = \\\n ((int_weight[pack, :] << (28 - i * 4) >> 28) + 16) % 16\n print((fp_weight - weight).abs().sum())\n\n fp_zp = torch.zeros(zero_point.shape).half().to(zero_point.device)\n for pack in range(0, h1 // 8):\n for i in range(8):\n fp_zp[pack * 8 + i, :] = \\\n (int_zero_point[pack, :] >> (i * 4)) & 15\n\n print((fp_zp - zero_point).abs().sum())\n '''\n weight = None\n return int_weight.transpose(1, 0).contiguous(), scale.transpose(1, 0).contiguous(), int_zero_point.transpose(1, 0).contiguous(), group_size\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-046ba77c956b4f569823866cff3654a0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel and function provided are designed to efficiently convert packed fp4 values to bf16 format while applying a scaling factor represented in e8m0 format. \n\nKernel `triton_f4_to_scaled_bf16_kernel`:\n- Function: Convert 4-bit floating-point numbers to bfloat16 with scaling.\n- Inputs: \n - `x_ptr`: Pointer to packed 4-bit floats.\n - `s_ptr`: Pointer to scaling factors in e8m0 format.\n - `output_ptr`: Pointer where results are stored.\n - `n_elements_in`: Number of packed elements to process.\n - Constant expressions for masks and biases used in conversion logic (e.g., `SIGN_MASK_F4`).\n- Outputs: Scaled bf16 tensor.\n- Process: \n - Decomposes each packed byte into two 4-bit numbers.\n - Adjusts the sign and mantissa bits to convert to bf16, applying bitwise operations.\n - Corrects for special cases like zero and denormals with masks.\n - Uses Triton-specific constructs for parallel processing, calculating the correct output block and grid sizes.\n\nWrapper Function `triton_f4_to_scaled_bf16`:\n- Function: Setup and launch the Triton kernel with correct parameters.\n- Inputs:\n - `x`: The tensor of packed 4-bit floats.\n - `s_e8m0`: Scaling factor tensor.\n - `mx_block_size`: Block size for processing.\n- Outputs: Tensor with bf16 values, scaled.\n- Implementation:\n - Determines shape and size of output tensor.\n - Ensures input tensor is contiguous and on the CUDA device.\n - Computes grid configuration using the number of input elements.\n - Launches `triton_f4_to_scaled_bf16_kernel` with the computed grid and configuration parameters, ensuring efficient conversion and scaling of input data.\n", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fp4_packed_to_bf16(\n x_packed,\n sign_mask_f4,\n mantissa_mask_f4,\n mbits_f4_e2m1,\n ebits_f4_e2m1,\n f4_e2m1_exp_bias,\n mbits_f32,\n ebits_f32,\n f32_exp_bias,\n zero_bits_f32,\n zero_point_five_bits_f32,\n):\n \"\"\"\n Input: a tensor of packed fp4 values\n Output: a tensor of bfloat16 values\n \"\"\"\n\n # low-bits: original location 0:3\n # high-bits: original location 4:7\n x_low_bits = x_packed >> 4\n x_high_bits = x_packed & 0xF\n x = tl.interleave(x_low_bits, x_high_bits)\n\n # cast logic below\n # output = x_unpacked.to(tl.float32)\n\n # save the sign\n sign_f4 = x & sign_mask_f4\n\n # set everything to positive, will add sign back at the end\n x_pos = x ^ sign_f4\n\n # Special case zero\n zero_mask = x_pos == 0\n\n # There is only one denormal value in fp4: s001, which is 0.5 in f32\n # Special case it.\n # TODO(later): will it be faster to repeat this for all 8 positive\n # values instead of the bit manipulations?\n denormal_mask = x_pos == 1\n\n # calculate the new exponent and shift it to bits 2:9 of the result\n exp_biased_f4 = x_pos >> mbits_f4_e2m1\n exp_biased_f32 = exp_biased_f4 - f4_e2m1_exp_bias + f32_exp_bias\n exp_biased_f32 = exp_biased_f32.to(tl.int32) << mbits_f32\n\n # shift the mantissa to bits 10:32 of the result\n mantissa_f4 = x_pos & mantissa_mask_f4\n mantissa_f32 = mantissa_f4.to(tl.int32) << (mbits_f32 - mbits_f4_e2m1)\n output = mantissa_f32\n\n # combine the pieces\n result = exp_biased_f32 | mantissa_f32\n # result[zero_mask] = ZERO_BITS_F32\n result = tl.where(zero_mask, zero_bits_f32, result)\n # result[denormal_mask] = ZERO_POINT_FIVE_BITS_F32\n result = tl.where(denormal_mask, zero_point_five_bits_f32, result)\n\n # add sign back\n sign_f32 = sign_f4.to(tl.int32) << (\n mbits_f32 - mbits_f4_e2m1 + ebits_f32 - ebits_f4_e2m1\n )\n result = result | sign_f32\n\n # The bit shifting above is for float32, so for now we\n # bitcast to float32 and then regular cast to bfloat16\n # TODO(later): it should be pretty easy to cast directly to bf16, just\n # need to adjust the mbits/ebits/special values. Perf impact is likely\n # to be small as we would not be chaning memory access patterns.\n output = result.to(tl.float32, bitcast=True)\n output = output.to(tl.bfloat16)\n return output\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE_IN\": 128}),\n triton.Config({\"BLOCK_SIZE_IN\": 256}),\n triton.Config({\"BLOCK_SIZE_IN\": 512}),\n triton.Config({\"BLOCK_SIZE_IN\": 1024}),\n triton.Config({\"BLOCK_SIZE_IN\": 2048}),\n ],\n key=[\"n_elements_in\"],\n)\n@triton.jit\ndef triton_f4_to_scaled_bf16_kernel(\n x_ptr,\n s_ptr,\n output_ptr,\n n_elements_in,\n mx_block_size: tl.constexpr,\n sign_mask_f4: tl.constexpr,\n mantissa_mask_f4: tl.constexpr,\n mbits_f4_e2m1: tl.constexpr,\n ebits_f4_e2m1: tl.constexpr,\n f4_e2m1_exp_bias: tl.constexpr,\n mbits_f32: tl.constexpr,\n ebits_f32: tl.constexpr,\n f32_exp_bias: tl.constexpr,\n zero_bits_f32: tl.constexpr,\n zero_point_five_bits_f32: tl.constexpr,\n e8m0_exponent_bias: tl.constexpr,\n e8m0_exponent_nan_val: tl.constexpr,\n BLOCK_SIZE_IN: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n n_elements_out = n_elements_in * 2\n n_elements_s = n_elements_out // 32\n\n BLOCK_SIZE_S: tl.constexpr = BLOCK_SIZE_IN // 16\n BLOCK_SIZE_OUT: tl.constexpr = BLOCK_SIZE_IN * 2\n\n block_start_in = pid * BLOCK_SIZE_IN\n offsets_in = block_start_in + tl.arange(0, BLOCK_SIZE_IN)\n mask_in = offsets_in < n_elements_in\n # packed uint8\n x_packed = tl.load(x_ptr + offsets_in, mask=mask_in)\n output = _fp4_packed_to_bf16(\n x_packed,\n sign_mask_f4,\n mantissa_mask_f4,\n mbits_f4_e2m1,\n ebits_f4_e2m1,\n f4_e2m1_exp_bias,\n mbits_f32,\n ebits_f32,\n f32_exp_bias,\n zero_bits_f32,\n zero_point_five_bits_f32,\n )\n\n # load scale\n block_start_s = pid * BLOCK_SIZE_S\n offsets_s = block_start_s + tl.arange(0, BLOCK_SIZE_S)\n mask_s = offsets_s < n_elements_s\n s = tl.load(s_ptr + offsets_s, mask=mask_s)\n\n # create the scale in bf16\n s_offset = s.to(tl.int16) - e8m0_exponent_bias\n s_fp = tl.extra.cuda.libdevice.pow(2.0, s_offset).to(tl.bfloat16)\n s_fp = tl.where(s != e8m0_exponent_nan_val, s_fp, float(\"nan\"))\n\n # multiply output by scale\n # TODO(later): see if manipulating the exponent instead of fp\n # multiplication is going to give a significant speedup\n output = tl.reshape(\n output, (BLOCK_SIZE_OUT // mx_block_size, mx_block_size)\n ) # noqa: E501\n s_fp = tl.reshape(s_fp, (BLOCK_SIZE_S // 1, 1))\n output = output * s_fp\n output = tl.reshape(output, (BLOCK_SIZE_OUT,))\n\n # set up output offsets\n block_start_out = pid * BLOCK_SIZE_OUT\n offsets_out = block_start_out + tl.arange(0, BLOCK_SIZE_OUT)\n mask_out = offsets_out < n_elements_out\n\n tl.store(output_ptr + offsets_out, output, mask=mask_out)\n\n\nEBITS_F32, MBITS_F32 = 8, 23\nEBITS_F4_E2M1, MBITS_F4_E2M1 = 2, 1\nEBITS_F6_E2M3, MBITS_F6_E2M3 = 2, 3\nEBITS_F6_E3M2, MBITS_F6_E3M2 = 3, 2\n\nSIGN_MASK_F4 = 0x8 # 1000\nMANTISSA_MASK_F4 = 0x1 # 0001\n\nZERO_BITS_F32 = 0x0\nZERO_POINT_FIVE_BITS_F32 = 0x3F000000\nF4_E2M1_EXP_BIAS = 1\nF32_EXP_BIAS = 127\nE8M0_EXPONENT_BIAS = 127\nE8M0_EXPONENT_NAN_VAL = 255\n\n\ndef triton_f4_to_scaled_bf16(\n x: torch.Tensor,\n s_e8m0: torch.Tensor,\n mx_block_size: int,\n):\n \"\"\"\n Input: a tensor of packed fp4 values, and a scale in e8m0 format. The block\n size is currently assumed to be 32.\n Output: a tensor of bfloat16 values, multiplied by the encoded scale\n \"\"\"\n new_shape = (*x.shape[:-1], x.shape[-1] * 2)\n output = torch.empty(*new_shape, device=x.device, dtype=torch.bfloat16)\n assert x.is_contiguous()\n assert x.is_cuda and output.is_cuda\n n_elements_in = x.numel()\n grid = lambda meta: ( # noqa: E731\n triton.cdiv(n_elements_in, meta[\"BLOCK_SIZE_IN\"]),\n )\n triton_f4_to_scaled_bf16_kernel[grid](\n x,\n s_e8m0,\n output,\n n_elements_in,\n mx_block_size,\n sign_mask_f4=SIGN_MASK_F4,\n mantissa_mask_f4=MANTISSA_MASK_F4,\n mbits_f4_e2m1=MBITS_F4_E2M1,\n ebits_f4_e2m1=EBITS_F4_E2M1,\n f4_e2m1_exp_bias=F4_E2M1_EXP_BIAS,\n mbits_f32=MBITS_F32,\n ebits_f32=EBITS_F32,\n f32_exp_bias=F32_EXP_BIAS,\n zero_bits_f32=ZERO_BITS_F32,\n zero_point_five_bits_f32=ZERO_POINT_FIVE_BITS_F32,\n e8m0_exponent_bias=E8M0_EXPONENT_BIAS,\n e8m0_exponent_nan_val=E8M0_EXPONENT_NAN_VAL,\n )\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-12b947ccc11447428f1891b1913f2e87", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel named 'matmul_kernel' is designed for performing block-wise matrix multiplication. \n It operates on matrices x and y of dimensions (m_size, k_size) and (k_size, n_size), respectively, and outputs a result matrix z of dimensions (m_size, n_size).\n The kernel begins by determining which block of the output matrix to calculate based on the program's id (`pid`). \n It computes the indices for these blocks by dividing the output matrix into smaller blocks, with sizes determined by m_block_size, k_block_size, and n_block_size.\n Using these indices, the kernel computes memory offsets to fetch blocks of data from the input matrices x and y.\n These blocks are loaded into shared memory using `tl.load`, which leverages the high bandwidth of shared memory to accelerate computation.\n The core computation is performed using `tl.dot`, which computes the dot product of the fetched sub-matrices, accumulating results in a local variable `z`.\n This operation is repeated for all k blocks to fully compute the matrix product.\n Once the computation for the current block is complete, the resulting sub-matrix `z` is stored back to the output matrix z in global memory.\n The `matmul` function acts as a wrapper around this kernel, setting up the grid and blocks, and initializing the output matrix. \n It calculates the required grid size to cover all blocks of the output matrix and launches the kernel with appropriate arguments.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef matmul_kernel(\n x_ptr,\n y_ptr,\n z_ptr,\n m_size,\n k_size,\n n_size,\n m_block_size: tl.constexpr,\n k_block_size: tl.constexpr,\n n_block_size: tl.constexpr,\n):\n # Calculate the m and n blocks to compute using the program id\n pid = tl.program_id(0)\n num_n_blocks = tl.cdiv(n_size, n_block_size)\n m_block = pid // num_n_blocks\n n_block = pid % num_n_blocks\n\n m_offsets = tl.arange(0, m_block_size) + m_block * m_block_size\n n_offsets = tl.arange(0, n_block_size) + n_block * n_block_size\n k_offsets = tl.arange(0, k_block_size)\n\n x_ptrs = x_ptr + m_offsets[:, None] * k_size + k_offsets[None, :]\n y_ptrs = y_ptr + k_offsets[:, None] * n_size + n_offsets[None, :]\n z_ptrs = z_ptr + m_offsets[:, None] * n_size + n_offsets[None, :]\n\n z = tl.zeros((m_block_size, n_block_size), dtype=tl.float32)\n\n for _ in range(0, k_size, k_block_size):\n x_sub = tl.load(x_ptrs)\n y_sub = tl.load(y_ptrs)\n z += tl.dot(x_sub, y_sub, allow_tf32=False)\n x_ptrs += k_block_size\n y_ptrs += k_block_size * n_size\n\n tl.store(z_ptrs, z)\n\ndef matmul(x, y):\n m_size, k_size = x.shape\n _, n_size = y.shape\n z = torch.empty(m_size, n_size, device='cuda:0')\n def grid(meta):\n return (triton.cdiv(m_size, meta['m_block_size']) * triton.cdiv(n_size, meta['n_block_size']),)\n \n matmul_kernel[grid](\n x,\n y,\n z,\n m_size,\n k_size,\n n_size,\n m_size,\n k_size,\n n_size\n )\n return z\n\n\n\n" + ] + }, + { + "id": "openseek-8-376c1fbac02c40b980a9b02da149eddf", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton operator is composed of a kernel function `matmul_kernel` and a wrapper function `matmul`. The kernel performs matrix multiplication C = A x B, where A, B, and C have dimensions (M, K), (K, N), and (M, N) respectively. The kernel is optimized using Triton’s block mapping, with configurable block sizes: BLOCK_SIZE_M, BLOCK_SIZE_N, and BLOCK_SIZE_K. The operation supports activation functions, such as leaky ReLU, defined in `leaky_relu(x)`. The wrapper function `matmul` initializes the output matrix C, configures the execution grid based on input dimensions, and invokes the kernel. Key parameters include matrix dimensions and strides, alongside the activation function name, which determines if and how the activation is applied. The process involves splitting the computation into blocks and iterating over the K dimension to accumulate results. Finally, results are stored back into the appropriate locations in C using masked memory operations.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.autotune(\n configs=[\n triton.Config(\n {\n \"BLOCK_SIZE_M\": 128,\n \"BLOCK_SIZE_N\": 256,\n \"BLOCK_SIZE_K\": 64,\n \"GROUP_SIZE_M\": 8,\n },\n num_stages=3,\n num_warps=8,\n ),\n # Additional configurations omitted for brevity\n ],\n key=[\"M\", \"N\", \"K\"],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr, # Pointers to matrices\n M, N, K, # Matrix dimensions\n stride_am, stride_ak, # Strides for matrix A\n stride_bk, stride_bn, # Strides for matrix B\n stride_cm, stride_cn, # Strides for matrix C\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr, ACTIVATION: tl.constexpr,\n):\n \"\"\"Kernel for computing the matmul C = A x B.\n A has shape (M, K), B has shape (K, N) and C has shape (M, N)\n \"\"\"\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator = tl.dot(a, b, accumulator)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n if ACTIVATION == \"leaky_relu\":\n accumulator = leaky_relu(accumulator)\n c = accumulator.to(tl.float16)\n\n offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n c_ptrs = c_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]\n c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)\n tl.store(c_ptrs, c, mask=c_mask)\n\n@triton.jit\ndef leaky_relu(x):\n return tl.where(x >= 0, x, 0.01 * x)\n\ndef matmul(a, b, activation=\"\"):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n assert a.is_contiguous(), \"Matrix A must be contiguous\"\n M, K = a.shape\n K, N = b.shape\n c = torch.empty((M, N), device=a.device, dtype=torch.float16)\n grid = lambda META: (\n triton.cdiv(M, META[\"BLOCK_SIZE_M\"]) * triton.cdiv(N, META[\"BLOCK_SIZE_N\"]),\n )\n matmul_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1),\n ACTIVATION=activation,\n )\n return c\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-00001205a1d84eac9c8465384cc3e1ce", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This code defines a Triton-based implementation of a linear layer with optional activation functions. The implementation starts with several auxiliary functions for different activations, such as tanh, ReLU, GELU, and fast GELU. The core computation is done in the kernel_fma function, which is a Triton kernel. This function takes in matrix pointers A, B, and C, along with optional bias and activation inputs, and computes the matrix multiplication with optional additions and activations. Key parameters like matrix dimensions, strides, and meta-parameters are passed to control the operation. This kernel is designed to work efficiently on GPUs, employing Triton's autotuning and heuristics to choose optimal configurations based on the inputs' dimensions and characteristics. The LinearLayer class is a custom PyTorch Function that sets up and calls the kernel_fma, managing inputs and outputs between PyTorch tensors and Triton. It ensures the input dimensions match and applies the chosen activation function. The linear_layer function is a user-facing wrapper for applying the LinearLayer function. It handles activation options and the optional saving of pre-activation inputs for backpropagation.\n ", + "output": [ + "from typing import Optional\nimport math\n\nimport torch\nimport triton\nimport triton.language as tl\nfrom torch.autograd.function import FunctionCtx\nfrom torch.cuda.amp import custom_fwd\nfrom triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time\n\n\nsqrt2pi = math.sqrt(2.0 / math.pi)\nsqrt2 = tl.constexpr(math.sqrt(2.0))\n\n\n@triton.jit\ndef tanh(x):\n \"\"\"Tanh activation function\"\"\"\n return tl.extra.cuda.libdevice.tanh(x)\n\n\n@triton.jit\ndef relu(x):\n \"\"\"Relu activation function\"\"\"\n return tl.maximum(0, x)\n\n\n@triton.jit\ndef fast_gelu(x):\n \"\"\"Fast approximation of the gelu function. May slightly decrease accuracy.\"\"\"\n return 0.5 * x * (1 + tanh(sqrt2pi * (x + 0.044715 * x * x * x)))\n\n\n@triton.jit\ndef gelu(x):\n \"\"\"Gaussian Error Linear Unit (GELU)\"\"\"\n return x * 0.5 * (1.0 + tl.extra.cuda.libdevice.erf(x / sqrt2))\n\n\ndef init_to_zero(name):\n return lambda nargs: nargs[name].zero_()\n\n\ndef get_configs_io_bound():\n configs = []\n for num_stages in [2, 3, 4, 5, 6]:\n for block_m in [16, 32]:\n for block_k in [32, 64]:\n for block_n in [32, 64, 128, 256]:\n num_warps = 2 if block_n <= 64 else 4\n configs.append(\n triton.Config(\n {\"BLOCK_M\": block_m, \"BLOCK_N\": block_n, \"BLOCK_K\": block_k, \"SPLIT_K\": 1},\n num_stages=num_stages,\n num_warps=num_warps,\n )\n )\n # split_k not used\n # for split_k in [2, 4, 8, 16]:\n # configs.append(triton.Config(\n # {'BLOCK_M': block_m, 'BLOCK_N': block_n, 'BLOCK_K': block_k, 'SPLIT_K': split_k},\n # num_stages=num_stages, num_warps=num_warps, pre_hook=init_to_zero('C')))\n return configs\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 32, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n # good for int8\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=3, num_warps=8),\n triton.Config({\"BLOCK_M\": 256, \"BLOCK_N\": 64, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 256, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 128, \"BLOCK_K\": 128, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 64, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 128, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 128, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=4, num_warps=4),\n triton.Config({\"BLOCK_M\": 64, \"BLOCK_N\": 32, \"BLOCK_K\": 64, \"SPLIT_K\": 1}, num_stages=5, num_warps=2),\n ]\n + get_configs_io_bound(),\n key=[\"CACHE_KEY_M\", \"CACHE_KEY_N\", \"CACHE_KEY_K\"],\n prune_configs_by={\"early_config_prune\": early_config_prune, \"perf_model\": estimate_matmul_time, \"top_k\": 10},\n)\n@triton.heuristics(\n {\n \"K_LOAD_MASK_NEEDED\": lambda args: args[\"K\"] % (args[\"BLOCK_K\"] * args[\"SPLIT_K\"]) == 0,\n }\n)\n@triton.jit\ndef kernel_fma(\n C, # Pointers to matrices\n ACT_INPUTS,\n A,\n B,\n bias,\n # Matrix dimensions\n M,\n N,\n K,\n CACHE_KEY_M,\n CACHE_KEY_N,\n CACHE_KEY_K,\n # The stride variables represent how much to increase the ptr by when moving by 1\n # element in a particular dimension. E.g. stride_am is how much to increase a_ptr\n # by to get the element one row down (A has M rows)\n output_m_stride,\n output_n_stride,\n act_inputs_m_stride,\n act_inputs_n_stride,\n a_m_stride,\n a_k_stride,\n b_n_stride,\n b_k_stride,\n # Meta-parameters\n BLOCK_M: tl.constexpr,\n GROUP_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n BLOCK_K: tl.constexpr,\n # split k not used, not performant with activation, kept because early_config_prune is expecting it\n SPLIT_K: tl.constexpr,\n K_LOAD_MASK_NEEDED: tl.constexpr,\n HAS_BIAS: tl.constexpr,\n SHOULD_SAVE_ACT_INPUTS: tl.constexpr,\n ACTIVATION: tl.constexpr,\n):\n \"\"\"\n Kernel for computing Out = activation(A x W + C)\n\n - Input has shape (M, K)\n - Weight has shape (K, N)\n - Bias has shape (N,)\n - Output has shape (M, N)\n - ActInputs (optional) has shape (M, N)\n\n 'ActInputs' optionally saves the A x W + C intermediate for backward computations\n\n This kernel will consolidate over K\n \"\"\"\n program_idx = tl.program_id(axis=0)\n\n grid_m = (M + BLOCK_M - 1) // BLOCK_M\n grid_n = (N + BLOCK_N - 1) // BLOCK_N\n # re-order program ID for better L2 performance\n width = GROUP_M * grid_n\n group_idx = program_idx // width\n group_size = min(grid_m - group_idx * GROUP_M, GROUP_M)\n block_m_idx = group_idx * GROUP_M + (program_idx % group_size)\n block_n_idx = (program_idx % width) // group_size\n\n # now compute the block that each program will go through\n # m_offs (resp. n_offs) denotes a range of indices\n # for rows (resp. col) of C\n m_offs_untagged = block_m_idx * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offs_untagged = block_n_idx * BLOCK_N + tl.arange(0, BLOCK_N)\n\n # trick to avoid masking on M and N axis\n # m_offs_untagged and n_offs_untagged can contains addresses outside matrix boundaries\n # modulo operation is used to wrap around the indices that go beyond the matrix boundaries\n # The value loaded are not ok but at least we are not reading outside the A/B matrices\n # Then, during storing in C a mask is used and the results related to these wrong values is discarded!\n # Regarding max_contiguous and multiple_of, they are used to force the compiler to vectorize loads\n # multiple_of indicates that the first element of rm / rn is a multiple of BLOCK_M / BLOCK_N\n # max_contiguous indicates that the range is a block of BLOCK_M / BLOCK_N contiguous elements\n m_offs = tl.max_contiguous(tl.multiple_of(m_offs_untagged % M, BLOCK_M), BLOCK_M)\n n_offs = tl.max_contiguous(tl.multiple_of(n_offs_untagged % N, BLOCK_N), BLOCK_N)\n\n k_range_offs = tl.arange(0, BLOCK_K)\n\n A = A + (m_offs[:, None] * a_m_stride + k_range_offs[None, :] * a_k_stride)\n B = B + (k_range_offs[:, None] * b_k_stride + n_offs[None, :] * b_n_stride)\n\n acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)\n\n if HAS_BIAS:\n bias = tl.load(bias + n_offs, mask=n_offs < N, other=0.0).to(tl.float32)\n acc += bias[None, :]\n\n for k in range(K, 0, -BLOCK_K):\n if K_LOAD_MASK_NEEDED:\n a = tl.load(A)\n b = tl.load(B)\n else:\n a = tl.load(A, mask=k_range_offs[None, :] < k, other=0.0)\n b = tl.load(B, mask=k_range_offs[:, None] < k, other=0.0)\n acc += tl.dot(a, b)\n\n A += BLOCK_K * a_k_stride\n B += BLOCK_K * b_k_stride\n\n # optional: save the activation inputs\n if SHOULD_SAVE_ACT_INPUTS:\n act_in_ptrs = ACT_INPUTS + m_offs[:, None] * act_inputs_m_stride + n_offs[None, :] * act_inputs_n_stride\n tl.store(act_in_ptrs, acc)\n\n # optional: fused activation (while the data is in shared memory)\n if ACTIVATION == \"tanh\":\n acc = tanh(acc)\n if ACTIVATION == \"gelu\":\n acc = gelu(acc)\n if ACTIVATION == \"fast_gelu\":\n acc = fast_gelu(acc)\n if ACTIVATION == \"relu\":\n acc = relu(acc)\n\n # write back result\n C = C + m_offs[:, None] * output_m_stride + n_offs[None, :] * output_n_stride\n c_ptr_mask = (m_offs < M)[:, None] & (n_offs < N)[None, :]\n tl.store(C, acc, mask=c_ptr_mask)\n\n\nclass LinearLayer(torch.autograd.Function):\n @staticmethod\n @custom_fwd(cast_inputs=torch.float16)\n def forward(\n ctx: FunctionCtx,\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n activation: str,\n act_inputs: Optional[torch.Tensor],\n ) -> torch.Tensor:\n \"\"\"\n Compute e = activation(x @ weight + bias).\n This wrapper kicks the `kernel_fma` Triton kernel\n :param ctx: context for autograd\n :param x: input tensor\n :param weight: weight matrix\n :param bias: an optional bias tensor\n :param activation: Activation name. Needs to be a Triton kernel.\n :param act_inputs: an optional tensor to save the activation inputs (for backward)\n :return: result tensor\n \"\"\"\n x_ = x if x.ndim == 2 else x.flatten(0, 1)\n\n assert x.dtype == weight.dtype, f\"Input and weight must have the same dtype, got {x.dtype} and {weight.dtype}\"\n if bias is not None:\n assert x.dtype == bias.dtype, f\"Input and bias must have the same dtype, got {x.dtype} and {bias.dtype}\"\n assert x_.shape[1] == weight.shape[1], f\"Incompatible dimensions: {x_.shape} - {weight.shape}\"\n\n assert bias is None or bias.is_contiguous()\n assert bias is None or bias.shape[0] == weight.shape[0], \"Incompatible dimensions in between weight and bias\"\n assert weight.is_contiguous()\n\n M, K = x_.shape\n N, K = weight.shape\n\n outputs = torch.empty((M, N), device=x.device, dtype=x.dtype)\n\n # 1D launch kernel where each block gets its own program.\n grid = lambda META: (triton.cdiv(M, META[\"BLOCK_M\"]) * triton.cdiv(N, META[\"BLOCK_N\"]),) # noqa\n\n kernel_fma[grid](\n outputs,\n act_inputs,\n x_,\n weight, # data ptrs\n bias if bias is not None else x, # auto skip bias if not present\n M, # shapes\n N,\n K,\n M // 32, # key for triton cache (limit number of compilations)\n N // 32,\n K // 32,\n output_m_stride=outputs.stride(0), # strides\n output_n_stride=outputs.stride(1),\n act_inputs_m_stride=act_inputs.stride(0) if act_inputs is not None else 0,\n act_inputs_n_stride=act_inputs.stride(1) if act_inputs is not None else 0,\n a_m_stride=x_.stride(0),\n a_k_stride=x_.stride(1),\n b_n_stride=weight.stride(0),\n b_k_stride=weight.stride(1),\n HAS_BIAS=bias is not None, # optional fused bias\n SHOULD_SAVE_ACT_INPUTS=act_inputs is not None, # optional save activation inputs\n ACTIVATION=activation if not None else x, # optional fused activation\n GROUP_M=8, # speed optimization: group the programs\n )\n\n outputs = outputs if x.ndim == 2 else outputs.reshape(x.shape[0], -1, N)\n ctx.save_for_backward(weight, bias, x)\n return outputs\n\n\ndef linear_layer(\n x: torch.Tensor,\n weight: torch.Tensor,\n bias: Optional[torch.Tensor],\n activation=\"\",\n act_inputs: Optional[torch.Tensor] = None,\n) -> torch.Tensor:\n return LinearLayer.apply(x, weight, bias, activation, act_inputs)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-886efac9244044678ae4d37a449ce675", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The given Triton code defines a GPU-accelerated softmax operation through two functions: `softmax_kernel` and `softmax`.\n - `softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE)`: This is a Triton JIT-compiled function designed to calculate the softmax of rows in a 2D tensor. It first identifies the row index using `tl.program_id(0)` and computes the starting pointer for the row. `col_offsets` and `input_ptrs` define the block of columns to process. The function loads the input row into SRAM, adjusts it by subtracting the maximum value for numerical stability, and optionally adds a mask if `mask_ptr` is not `None`. It computes the exponentials, sums them to derive the denominator, and divides each element to produce the softmax output, which is stored back in the output tensor.\n - `softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor`: This is the wrapper function that sets up the Triton kernel. It first validates the input tensor and mask dimensions and ensures the operation is along the last dimension. The input tensor is reshaped into 2D if necessary. Depending on the size of the input, the function selects grid and block sizes, choosing an appropriate number of warps to optimize performance. For large row counts, it uses lambda functions to calculate grid dimensions dynamically. It then invokes `softmax_kernel` with the calculated parameters and returns the resulting tensor containing softmax values.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef softmax_kernel(output_ptr, input_ptr, row_stride, n_cols, mask_ptr, BLOCK_SIZE: tl.constexpr):\n r\"\"\"the kernel function for implementing softmax operator\n Args:\n output_ptr: the output after finishing softmax operation, (N, hidden_dim)\n input_ptr: the tensor of input, shape should be (N, hidden_dim)\n n_cols(tl.constexpr): the number of cols of input\n BLOCK_SIZE(tl.constexpr): the block_size of your hidden_dim dimension, typically BLOCK_SIZE >= hidden_dim\n \"\"\"\n row_idx = tl.program_id(0)\n row_start_ptr = input_ptr + row_idx * row_stride\n col_offsets = tl.arange(0, BLOCK_SIZE)\n input_ptrs = row_start_ptr + col_offsets\n row = tl.load(input_ptrs, mask=col_offsets < n_cols, other=-float(\"inf\")).to(tl.float32)\n row_minus_max = row - tl.max(row, axis=0)\n\n if mask_ptr is not None:\n # load mask into SRAM\n mask_ptrs = (mask_ptr + (row_idx * row_stride)) + col_offsets\n mask = tl.load(mask_ptrs, mask=col_offsets < n_cols, other=0).to(tl.float32)\n\n # update\n row_minus_max = row_minus_max + mask\n\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n output_row_start_ptr = output_ptr + row_idx * row_stride\n output_ptrs = output_row_start_ptr + col_offsets\n # Write back output to DRAM\n tl.store(output_ptrs, softmax_output, mask=col_offsets < n_cols)\n\ndef softmax(input: torch.Tensor, mask: torch.Tensor = None, dim=-1) -> torch.Tensor:\n if mask is not None:\n # assert input[-1] == mask[-1], \"the last dimentions should be the same for input and mask\"\n assert input.shape[-1] == mask.shape[-1], \"the last dimensions should be the same for input and mask\"\n assert dim == -1 or dim == len(input.shape) - 1, \"currently softmax layer only support last dimention\"\n\n hidden_dim = input.shape[-1]\n output = torch.empty_like(input)\n input = input.view(-1, hidden_dim)\n if mask is not None:\n mask = mask.view(-1, hidden_dim)\n assert input.shape[0] == mask.shape[0], \"the fist dimention of mask and input should be the same\"\n\n num_rows, num_cols = input.shape\n block_size = max(triton.next_power_of_2(num_cols), 2)\n num_warps = 16\n if block_size >= 4096:\n num_warps = 16\n elif block_size >= 2048:\n num_warps = 8\n else:\n num_warps = 4\n\n if num_rows <= 350000:\n grid = (num_rows,)\n softmax_kernel[grid](\n output, input, input.stride(0), num_cols, mask, BLOCK_SIZE=block_size, num_warps=num_warps\n )\n else:\n grid = lambda meta: (triton.cdiv(num_rows, meta[\"BLOCK_M\"]),)\n\n softmax_kernel[grid](\n output_ptr=output,\n input_ptr=input,\n row_stride=input.stride(0),\n n_rows=num_rows,\n n_cols=num_cols,\n mask_ptr=mask,\n # currently manually setting up size\n BLOCK_M=32,\n BLOCK_SIZE=block_size,\n )\n\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c3616cd178b049fcbcc617d94aefe03f", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton operator implements a forward-pass kernel for the attention mechanism, a central part of transformer models. The `attention_fwd_kernel` function is compiled using Triton’s `@triton.jit`, allowing the kernel to execute on the GPU. It computes the scaled dot-product attention, which involves taking queries (q), keys (k), and values (v) as inputs. The calculations are performed over batched multi-head inputs, where each batch and head are processed independently.\n\n Inside the kernel, `b_h` is initialized as a zero matrix and is used to accumulate intermediate results. The loop iterates over divided blocks of the sequence length (`tl.cdiv(T, BT)`). In each iteration, memory pointers `p_q`, `p_k`, `p_v`, `p_h`, and `p_o` are set up to point to the corresponding slices of the q, k, v, h, and o tensors. The dimensions and strides are set to efficiently load and store tensor blocks.\n\n The core computation begins with loading blocks of the query and key tensors, followed by a scaled dot-product to compute attention scores (`b_s`). The block of the value tensor is then weighted by these scores to produce the output block (`b_o`). Depending on the `IFCOND` condition, either a conditional update or standard update is applied to the intermediate `b_h` tensor, which may also be optionally stored when the `STORE` flag is set.\n\n The class `AttentionFunction` provides a user interface for invoking the Triton kernel. It initializes the output tensor `o` and an intermediate storage tensor `h`. The forward method sets the kernel launch configurations, like grid size, number of warps, and stages, tailored for different input dimensions. The method also computes scaling based on the head dimension and calls the Triton kernel with these parameters. The function returns the computed output tensor.\n\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef attention_fwd_kernel(\n q,\n k,\n v,\n h,\n o,\n s_qh,\n s_qt,\n s_qd,\n s_hh,\n s_ht,\n T,\n scale,\n BT: tl.constexpr,\n BD: tl.constexpr,\n NT: tl.constexpr,\n STORE: tl.constexpr,\n IFCOND: tl.constexpr\n):\n i_bh = tl.program_id(0)\n\n # [BD, BD]\n b_h = tl.zeros([BD, BD], dtype=tl.float32)\n for i in range(0, tl.cdiv(T, BT)):\n p_q = tl.make_block_ptr(q + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n p_k = tl.make_block_ptr(k + i_bh * s_qh, (BD, T), (s_qd, s_qt), (0, i * BT), (BD, BT), (0, 1))\n p_v = tl.make_block_ptr(v + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n p_h = tl.make_block_ptr(h + i_bh * s_hh, (NT * BD, BD), (s_ht, s_qd), (i * BD, 0), (BD, BD), (1, 0))\n p_o = tl.make_block_ptr(o + i_bh * s_qh, (T, BD), (s_qt, s_qd), (i * BT, 0), (BT, BD), (1, 0))\n\n if STORE:\n tl.store(p_h, b_h.to(p_h.dtype.element_ty))\n # [BT, BD]\n b_q = tl.load(p_q)\n b_q = (b_q * scale).to(b_q.dtype)\n # [BD, BT]\n b_k = tl.load(p_k)\n # [BT, BD]\n b_v = tl.load(p_v)\n\n # [BT, BT]\n b_s = tl.dot(b_q, b_k, allow_tf32=False)\n # [BT, BD]\n b_o = tl.dot(b_s.to(b_q.dtype), b_v, allow_tf32=False)\n if IFCOND:\n if i == 0:\n b_h = tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n else:\n b_o += tl.dot(b_q, b_h.to(b_q.dtype), allow_tf32=False)\n b_h += tl.dot(b_k, b_v, allow_tf32=False)\n\n tl.store(p_o, b_o.to(p_o.dtype.element_ty))\n\n\nclass AttentionFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, q, k, v, store=False, ifcond=False):\n batch_size, n_heads, seq_len, d_head = q.shape\n scale = d_head ** -0.5\n BD = q.shape[-1]\n BT = 32\n NT = triton.cdiv(seq_len, BT)\n num_stages = 3 if d_head <= 64 else 2\n num_warps = 4\n\n h = q.new_empty(batch_size, n_heads, NT * BD, BD)\n o = torch.empty_like(q)\n grid = (batch_size * n_heads,)\n attention_fwd_kernel[grid](\n q, k, v, h, o,\n q.stride(1), q.stride(2), q.stride(3), h.stride(1), h.stride(2),\n seq_len, scale,\n BT=BT, BD=BD, NT=NT, STORE=store, IFCOND=ifcond,\n num_warps=num_warps,\n num_stages=num_stages\n )\n return o\n\n\n\n" + ] + }, + { + "id": "openseek-8-b2326578003944b4907206c3b57051a0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton-based kernel performs a forward pass of the scaled dot-product attention mechanism, implemented through two main functions: `_attn_fwd` and `_attn_fwd_inner`. The kernel processes query (Q), key (K), value (V) matrices and corresponding scaling factors to compute attention scores in the output tensor `Out`. This is achieved by using a block-based approach, where computations are performed on blocks of size `BLOCK_M` (for queries) and `BLOCK_N` (for keys and values).\n\n The `_attn_fwd` function orchestrates the overall computation by initializing pointers and offsets needed for matrix access. It handles the iteration over context size `N_CTX` and invokes the helper function `_attn_fwd_inner` for each block.\n\n The `_attn_fwd_inner` function conducts the core computation by:\n - Loading blocks of K and V using pointer arithmetic, controlled via `K_ptrs`, `K_scale_ptr`, and `V_ptrs`.\n - Computing the scaled dot-product `qk` between a block of Q and K, using scaling factors `q_scale` and `k_scale`.\n - Applying the softmax operation over `qk` to obtain the probability matrix `p`.\n - Accumulating the result of `p` weighted by V into `acc`.\n - Updating normalization `l_i` and maximum score `m_i` for stability and normalization purposes.\n\n Function Inputs:\n - `Q`, `K`, `V`: 3D tensors representing the attention mechanism components.\n - `Q_scale`, `K_scale`: Tensors representing scale factors for query and key matrices.\n - `Out`: Output tensor to store the attention results.\n - `stride_qz`, `stride_qh`, `stride_qm`, `stride_qk`: Strides for navigating through the Q tensor.\n - `stride_kz`, `stride_kh`, `stride_kn`, `stride_kk`: Strides for navigating through the K tensor.\n - `stride_vz`, `stride_vh`, `stride_vk`, `stride_vn`: Strides for navigating through the V tensor.\n - `stride_oz`, `stride_oh`, `stride_om`, `stride_on`: Strides for navigating through the output tensor.\n\n Outputs:\n - `Out`: The result tensor holding the attention scores, normalized over the softmax.\n\n Key logic considerations:\n - Block-wise processing allows for efficient memory access and parallel computation.\n - Handling of numerical stability and normalization through `l_i` and `m_i`.\n - Triton's just-in-time (JIT) compilation is utilized to optimize performance across GPUs.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _attn_fwd_inner(acc, l_i, m_i, q, q_scale, \n K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr, BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr, offs_m: tl.constexpr, offs_n: tl.constexpr, \n N_CTX: tl.constexpr):\n lo, hi = 0, N_CTX\n for start_n in range(lo, hi, BLOCK_N):\n start_n = tl.multiple_of(start_n, BLOCK_N)\n k_mask = (offs_n[None, :] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[:, None])\n k = tl.load(K_ptrs, mask = k_mask)\n k_scale = tl.load(K_scale_ptr)\n qk = tl.dot(q, k).to(tl.float32) * q_scale * k_scale\n m_ij = tl.maximum(m_i, tl.max(qk, 1))\n qk = qk - m_ij[:, None]\n p = tl.math.exp2(qk)\n l_ij = tl.sum(p, 1)\n alpha = tl.math.exp2(m_i - m_ij)\n l_i = l_i * alpha + l_ij\n acc = acc * alpha[:, None]\n v = tl.load(V_ptrs, mask = (offs_n[:, None] < (N_CTX - start_n)) & ((tl.arange(0, 128) < 96)[None, :]))\n p = p.to(tl.float16)\n \n acc += tl.dot(p, v, out_dtype=tl.float16) \n m_i = m_ij\n K_ptrs += BLOCK_N * HEAD_DIM\n K_scale_ptr += 1\n V_ptrs += BLOCK_N * HEAD_DIM\n return acc, l_i\n\n@triton.jit\ndef _attn_fwd(Q, K, V, Q_scale, K_scale, Out, \n stride_qz, stride_qh, stride_qm, stride_qk, \n stride_kz, stride_kh, stride_kn, stride_kk, \n stride_vz, stride_vh, stride_vk, stride_vn, \n stride_oz, stride_oh, stride_om, stride_on, \n Z, H, N_CTX, \n HEAD_DIM: tl.constexpr, \n BLOCK_M: tl.constexpr, \n BLOCK_N: tl.constexpr, \n STAGE: tl.constexpr \n ):\n start_m = tl.program_id(0)\n off_hz = tl.program_id(1)\n \n off_z = off_hz // H\n off_h = off_hz % H\n qvk_offset = off_z.to(tl.int64) * stride_qz + off_h.to(tl.int64) * stride_qh\n vk_offset = qvk_offset // stride_qm\n q_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_M)\n k_scale_offset = off_hz * tl.cdiv(N_CTX, BLOCK_N) \n \n offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)\n offs_n = tl.arange(0, BLOCK_N)\n offs_k = tl.arange(0, 128)\n Q_ptrs = Q + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n Q_scale_ptr = Q_scale + q_scale_offset + start_m\n K_ptrs = K + qvk_offset + offs_k[:, None] + offs_n[None, :] * stride_kn\n K_scale_ptr = K_scale + k_scale_offset\n V_ptrs = V + qvk_offset + offs_n[:, None] * stride_qm + offs_k[None, :] * stride_qk\n O_block_ptr = Out + qvk_offset + offs_m[:, None] * stride_qm + offs_k[None, :] * stride_qk\n m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float(\"inf\")\n l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + 1.0\n acc = tl.zeros([BLOCK_M, 128], dtype=tl.float32)\n q = tl.load(Q_ptrs, mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n q_scale = tl.load(Q_scale_ptr)\n acc, l_i = _attn_fwd_inner(acc, l_i, m_i, q, q_scale, K_ptrs, K_scale_ptr, V_ptrs, \n start_m, \n BLOCK_M, HEAD_DIM, BLOCK_N, \n 4 - STAGE, offs_m, offs_n, N_CTX \n )\n acc = acc / l_i[:, None]\n tl.store(O_block_ptr, acc.to(Out.type.element_ty), mask = (offs_m[:, None] < N_CTX) & ((tl.arange(0, 128) < 96)[None, :]))\n\ndef forward(q, k, v, q_scale, k_scale):\n BLOCK_M = 128\n BLOCK_N = 64\n HEAD_DIM_Q, HEAD_DIM_K = q.shape[-1], k.shape[-1]\n HEAD_DIM_V = v.shape[-1]\n assert HEAD_DIM_Q == HEAD_DIM_K and HEAD_DIM_K == HEAD_DIM_V\n o = torch.empty_like(q, dtype=torch.float16)\n stage = 1\n\n grid = (triton.cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1)\n M = torch.empty((q.shape[0], q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32)\n _attn_fwd[grid](\n q, k, v, q_scale, k_scale, o, \n q.stride(0), q.stride(1), q.stride(2), q.stride(3), \n k.stride(0), k.stride(1), k.stride(2), k.stride(3), \n v.stride(0), v.stride(1), v.stride(2), v.stride(3), \n o.stride(0), o.stride(1), o.stride(2), o.stride(3), \n q.shape[0], q.shape[1], \n N_CTX=q.shape[2], \n BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, HEAD_DIM=HEAD_DIM_K, \n STAGE=stage, \n num_warps=8, \n num_stages=4)\n return o\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-0f4839ebf6bc44e684af45e54e00672d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton matrix multiplication code is organized around the `matmul_kernel` function, designed for GPU execution. \n This function is auto-tuned for efficiency, adapting to different matrix shapes with several block configurations specified via the `@triton.autotune` decorator. Each configuration varies in block sizes, stages, and warps, optimizing for a given MxNxK problem size. \n The kernel itself operates as follows:\n - `pid` determines the unique program ID within the grid, dividing the workload among GPU units.\n - The grid's structure is defined such that the matrix is divided into tiles of size BLOCK_SIZE_M x BLOCK_SIZE_N.\n - `num_pid_m` and `num_pid_n` calculate how many blocks fit in each matrix dimension.\n - Each program computes part of the result matrix C for a given (M, N) tile by iterating over the K dimension in blocks of BLOCK_SIZE_K.\n - Pointers `a_ptrs` and `b_ptrs` are established to access relevant submatrices.\n - The kernel loops over `K` in chunks of BLOCK_SIZE_K, loading submatrices from A and B, performing a dot product, and accumulating the result.\n - Finally, results are stored into matrix C with appropriate masking to handle edge cases. \n The `triton_matmul` function provides a Pythonic interface for calling the kernel, performing dimension checks, setting up the output matrix, and configuring the execution grid based on input dimensions and kernel metadata.\n ", + "output": [ + "import triton\nimport triton.language as tl\nimport torch\n\n@triton.autotune(\n configs=[\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 64, 'GROUP_SIZE_M': 8}, num_stages=3, num_warps=8),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 256, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 128, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 128, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=4, num_warps=4),\n triton.Config({'BLOCK_SIZE_M': 64, 'BLOCK_SIZE_N': 32, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n triton.Config({'BLOCK_SIZE_M': 32, 'BLOCK_SIZE_N': 64, 'BLOCK_SIZE_K': 32, 'GROUP_SIZE_M': 8}, num_stages=5, num_warps=2),\n ],\n key=['M', 'N', 'K'],\n)\n@triton.jit\ndef matmul_kernel(\n a_ptr, b_ptr, c_ptr, M, N, K,\n stride_am, stride_ak, stride_bk, stride_bn, stride_cm, stride_cn,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n GROUP_SIZE_M: tl.constexpr\n):\n pid = tl.program_id(axis=0)\n num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)\n num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)\n num_pid_in_group = GROUP_SIZE_M * num_pid_n\n group_id = pid // num_pid_in_group\n first_pid_m = group_id * GROUP_SIZE_M\n group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)\n pid_m = first_pid_m + (pid % group_size_m)\n pid_n = (pid % num_pid_in_group) // group_size_m\n\n offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)\n b_ptrs = b_ptr + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)\n\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)\n b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)\n accumulator += tl.dot(a, b)\n a_ptrs += BLOCK_SIZE_K * stride_ak\n b_ptrs += BLOCK_SIZE_K * stride_bk\n\n c_ptrs = c_ptr + (offs_am[:, None] * stride_cm + offs_bn[None, :] * stride_cn)\n c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)\n tl.store(c_ptrs, accumulator, mask=c_mask)\n\ndef triton_matmul(a, b):\n assert a.shape[1] == b.shape[0], \"Incompatible dimensions\"\n M, K = a.shape\n K, N = b.shape\n\n c = torch.empty((M, N), device=a.device, dtype=a.dtype)\n grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']) * triton.cdiv(N, META['BLOCK_SIZE_N']),)\n matmul_kernel[grid](\n a, b, c, M, N, K,\n a.stride(0), a.stride(1),\n b.stride(0), b.stride(1),\n c.stride(0), c.stride(1)\n )\n return c\n\n\n\n" + ] + }, + { + "id": "openseek-8-81fe0331bf714419b8076a4e78eea697", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton kernel '_copy_to_kcache_seqlen_n_kernel' and the function 'copy_k_to_blocked_cache' work together to handle the copying of key or value vectors into a cached structure used in sequence models, especially during the decoding stages. The kernel itself manages the computation of positions within a blocked cache layout, supporting both traditional and new formats as specified by a parameter. It determines current token and sequence indices, computes past sequence lengths, and figures out the exact location in a cache where the token data should be stored. This involves calculating offsets and leveraging Triton’s intrinsic memory operations such as 'tl.load' and 'tl.store'. The kernel is configured using grid dimensions based on batch size and head dimensions, while the function 'copy_k_to_blocked_cache' converts the input tensor 'k' to the appropriate shape and checks for compatibility between input and cache dimensions. It computes necessary strides and calls the kernel with calculated strides and block dimensions, passing constants like 'HEAD_DIM' and 'KCACHE_X' for memory layout control. This kernel and function collectively enable efficient copying of sequence tokens into a block-structured cache format, optimizing the process for various sequence lengths and layouts.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n\n# supports two types of cache layouts\n# 1. [num_blocks, num_kv_heads, block_size, head_dim]\n# 2. [num_blocks, num_kv_heads, head_dim // x, block_size, x]\n@triton.jit\ndef _copy_to_kcache_seqlen_n_kernel(\n K, # K or V\n KCache, # [num_blocks, num_kv_heads, head_dim // x, block_size, x]\n BLOCK_TABLES,\n seq_lengths,\n stride_kt,\n stride_kh,\n stride_kd,\n stride_kcb,\n stride_kch,\n stride_kcsplit_x,\n stride_kcs,\n stride_kcx,\n stride_bts,\n stride_btb,\n block_size,\n n_tokens,\n HEAD_DIM: tl.constexpr,\n KCACHE_X: tl.constexpr,\n):\n # `n_tokens` is used to specify the number of tokens to copy for each sequence\n # When n_tokens > 1, tokens from different sequences are packed into the first dimension of the grid,\n # `seq_lengths` must be the lengths of sequences counting the number of tokens to copy\n # E.g. if n_tokens = 5, seq_lengths = [12, 15], then the already-copied position ids are [0-6, 0-9]\n # for the two sequences, respectively. And the position ids to be copied are [7-11, 9-14].\n # When n_tokens = 1, consider token idx as the sequence idx, since it's only used during regular decoding stage\n cur_token_idx = tl.program_id(0)\n cur_seq_idx = cur_token_idx // n_tokens\n # `cur_token_shift` is only valid and functional when `n_tokens` > 1\n cur_token_shift = cur_token_idx - (n_tokens * (cur_seq_idx + 1))\n cur_kv_head_idx = tl.program_id(1)\n split_x_idx = tl.program_id(2)\n\n past_kv_seq_len = tl.load(seq_lengths + cur_seq_idx) + cur_token_shift\n last_bt_block_idx = past_kv_seq_len // block_size\n block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts\n block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)\n offset_last_block = past_kv_seq_len % block_size\n offsets_dmodel = split_x_idx * KCACHE_X + tl.arange(0, KCACHE_X)\n offsets_k = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd\n k = tl.load(K + offsets_k)\n offsets_kcache = (\n block_id * stride_kcb\n + cur_kv_head_idx * stride_kch\n + split_x_idx * stride_kcsplit_x\n + offset_last_block * stride_kcs\n + tl.arange(0, KCACHE_X)\n )\n tl.store(KCache + offsets_kcache, k)\n return\n\n\ndef copy_k_to_blocked_cache(\n k: torch.Tensor,\n k_cache: torch.Tensor,\n kv_lengths: torch.Tensor,\n block_tables: torch.Tensor,\n n: int = 1,\n use_new_kcache_layout: bool = False,\n):\n \"\"\"\n Copy keys or values to the blocked key/value cache during decoding stage.\n\n Args:\n k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.\n [bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n\n k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.\n new KCache Layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]\n kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.\n block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.\n n (int): Number of tokens to copy for each sequence. Default to 1.\n use_new_kcache_layout (bool): Whether to use the new layout for kcache. Default to False.\n \"\"\"\n assert k.dtype == k_cache.dtype, \"Expected consistent dtype for tensor and cache.\"\n if k.dim() == 4:\n k = k.reshape(-1, k.size(-2), k.size(-1))\n k_shape = k.shape\n bsz, num_kv_heads, head_dim = k_shape\n # NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]\n if n > 1:\n assert bsz % n == 0, \"Each sequence should have the same number of tokens to be copied\"\n bsz = bsz // n\n\n assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (\n f\"Got incompatible batch size (number of seqs):\\n\"\n f\" Past kv sequence lengths bsz {kv_lengths.shape[0]}; \"\n f\" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}\"\n )\n\n k_cache_shape = k_cache.shape\n # Modify if the shape of kv cahce is changed.\n block_size = k_cache_shape[-2]\n\n x = head_dim\n stride_kcsplit_x, stride_kcs, stride_kcd = 0, k_cache.stride(2), k_cache.stride(3)\n if use_new_kcache_layout:\n # when using kcache layout [num_blocks, num_kv_heads, head_dim // x, block_size, x]\n assert (\n len(k_cache_shape) == 5\n and k_cache_shape[1] == k_shape[1]\n and k_cache_shape[2] * k_cache_shape[4] == k_shape[2]\n ), f\"Incompatible k_cache shape {k_cache_shape} with k shape {k_shape}\"\n x = k_cache.size(-1)\n stride_kcsplit_x, stride_kcs, stride_kcd = k_cache.stride()[2:]\n\n num_warps = 8 if head_dim > 128 else 4\n grid = (bsz * n, num_kv_heads, head_dim // x)\n _copy_to_kcache_seqlen_n_kernel[grid](\n k,\n k_cache,\n block_tables,\n kv_lengths,\n k.stride(0),\n k.stride(1),\n k.stride(2),\n k_cache.stride(0),\n k_cache.stride(1),\n stride_kcsplit_x,\n stride_kcs,\n stride_kcd,\n block_tables.stride(0),\n block_tables.stride(1),\n block_size,\n n_tokens=n,\n HEAD_DIM=head_dim,\n KCACHE_X=x,\n num_warps=num_warps,\n )\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-f8f700ef89624ae5b2408f04c3b02dfc", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n This Triton kernel `cos_func` is designed to compute the cosine of each element in an input tensor `a` and write the results to an output tensor `b`. \n It uses a configurable constant `BLOCK_SIZE` which dictates the number of elements each parallel thread block processes. The `offset` variable is calculated to give each block its starting point by considering the global program id and the thread's local position in the block using `tl.arange`. \n The `mask` variable is created to ensure calculations are only performed on valid tensor elements within bounds. \n The elements from `a` are loaded into `a_value` using `tl.load`, respecting the bounds defined by `mask`. \n The cosine of each element is computed with `tl.cos` and stored in the `b_value` tensor. Finally, results are written to the output tensor `b` using `tl.store`, again with the `mask` ensuring correct boundaries. \n The `cos` function in Python acts as a driver for the Triton kernel, preparing the input and output tensors, calculating the block size as the nearest power of 2 greater than the square root of the number of elements, and determining the grid size to cover all elements. It then executes the kernel with these parameters.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\nimport math\n\n# Kernel function: Computes the cosine of each element in the input tensor.\n@triton.jit\ndef cos_func(a, b, n_elements, BLOCK_SIZE: tl.constexpr):\n # Calculate the offset for each block and thread\n offset = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n # Determine which elements are valid within the block\n mask = offset < n_elements\n # Load elements from tensor `a` into `a_value`\n a_value = tl.load(a + offset, mask=mask)\n # Compute the cosine of each element in `a_value`\n b_value = tl.cos(a_value.to(tl.float32))\n # Store the result back to tensor `b`\n tl.store(b + offset, b_value, mask=mask) \n\n# Function to invoke the Triton kernel and perform the computation\ndef cos(A):\n # Prepare output tensor `B` with the same shape and type as `A`\n B = torch.empty_like(A)\n # Determine the total number of elements in the input tensor `A`\n n_elements = A.numel()\n # Calculate the optimal block size\n block_size = triton.next_power_of_2(math.ceil(math.sqrt(n_elements)))\n # Determine the grid size needed to cover all elements\n grid_size = triton.cdiv(n_elements, block_size)\n # Launch the Triton kernel\n cos_func[(grid_size, 1, 1)](A, B, n_elements, block_size)\n return B\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-afb0ada9d57c4bc7ade586834db3428d", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton-based implementation provides optimized GPU kernels for cross-entropy loss calculations. It comprises several key functions:\n \n - `calculate_settings(n)`: Determines the optimal block size and number of warps for GPU execution based on the vocabulary size `n`. It ensures the block size does not exceed `MAX_FUSED_SIZE`.\n \n - `_cross_entropy_forward(...)`: This kernel calculates the cross-entropy loss for a single input row using a block of threads. It optionally applies softcapping and logit scaling to the logits. It performs the log-sum-exp calculation for numerical stability and stores the result. If the label is valid (not `-100`), it computes the loss by subtracting the logit corresponding to the true label from the log-sum-exp value.\n \n - `_chunked_cross_entropy_forward(...)`: This variant handles cases where the vocabulary size exceeds `MAX_FUSED_SIZE`. It breaks the computation into chunks. It uses similar logic as `_cross_entropy_forward` but accumulates partial log-sum-exp results across chunks.\n \n - `_cross_entropy_backward(...)`: Computes gradients with respect to logits. For each block of logits, it calculates the softmax derivative, applies optional transformations (softcapping, logit scaling), and stores the result multiplied by the incoming gradient `dlosses`.\n \n The `Fast_CrossEntropyLoss` class encapsulates these kernels. Its `forward` method checks if chunking is needed, sets up the data, and launches the appropriate kernel(s). The `backward` method launches `_cross_entropy_backward` to compute gradients.\n \n The `fast_cross_entropy_loss` function processes batched inputs and computes the mean loss, accounting for potentially masked labels (value `-100` indicates masking).\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n# from .utils import triton_tanh\nfrom triton.language.extra import libdevice\n\ntriton_tanh = libdevice.tanh\nnext_power_of_2 = triton.next_power_of_2\nMAX_FUSED_SIZE : int = 65536\n\ndef calculate_settings(n : int) -> (int, int):\n BLOCK_SIZE : int = next_power_of_2(n)\n if BLOCK_SIZE > MAX_FUSED_SIZE:\n raise RuntimeError(f\"Cannot launch Triton kernel since n = {n} exceeds \"\\\n f\"the maximum CUDA blocksize = {MAX_FUSED_SIZE}.\")\n num_warps : int = 4\n if BLOCK_SIZE >= 32768: num_warps = 32\n elif BLOCK_SIZE >= 8192: num_warps = 16\n elif BLOCK_SIZE >= 2048: num_warps = 8\n return BLOCK_SIZE, num_warps\n\n@triton.heuristics({\n \"DO_SOFTCAPPING\": lambda args: args[\"DO_SOFTCAPPING\"],\n \"DO_LOGIT_SCALING\": lambda args: args[\"DO_LOGIT_SCALING\"],\n})\n@triton.jit\ndef _cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n DO_SOFTCAPPING: tl.constexpr,\n SOFTCAP: tl.constexpr,\n DO_LOGIT_SCALING: tl.constexpr,\n LOGIT_SCALE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx\n labels_ptr += row_idx\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\"))\n\n if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits\n if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)\n\n logits = logits.to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx)\n if DO_LOGIT_SCALING: x = LOGIT_SCALE * x\n if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)\n loss = logsumexp - x.to(tl.float32)\n else:\n loss = 0.0\n tl.store(logsumexp_ptr, logsumexp)\n tl.store(loss_ptr, loss)\n\n@triton.heuristics({\n \"DO_SOFTCAPPING\": lambda args: args[\"DO_SOFTCAPPING\"],\n \"DO_LOGIT_SCALING\": lambda args: args[\"DO_LOGIT_SCALING\"],\n})\n@triton.jit\ndef _chunked_cross_entropy_forward(\n logits_ptr, logits_row_stride,\n loss_ptr,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE: tl.constexpr,\n N_CHUNKS: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n DO_SOFTCAPPING: tl.constexpr,\n SOFTCAP: tl.constexpr,\n DO_LOGIT_SCALING: tl.constexpr,\n LOGIT_SCALE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n chunk_idx = tl.program_id(1)\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n loss_ptr += row_idx\n logsumexp_ptr += row_idx * N_CHUNKS + chunk_idx\n labels_ptr += row_idx\n\n col_offsets = chunk_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n\n label_idx = tl.load(labels_ptr).to(tl.int32)\n logits = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\"))\n\n if DO_LOGIT_SCALING: logits = LOGIT_SCALE * logits\n if DO_SOFTCAPPING: logits = SOFTCAP * triton_tanh(logits / SOFTCAP)\n\n logits = logits.to(tl.float32)\n c = tl.max(logits, 0)\n logsumexp = c + tl.log(tl.sum(tl.exp(logits - c), 0))\n\n if chunk_idx == 0:\n if label_idx != -100:\n x = tl.load(logits_ptr + label_idx).to(tl.float32)\n if DO_LOGIT_SCALING: x = LOGIT_SCALE * x\n if DO_SOFTCAPPING: x = SOFTCAP * triton_tanh(x / SOFTCAP)\n loss = -1.0 * x.to(tl.float32)\n else:\n loss = 0.0\n tl.store(loss_ptr, loss)\n tl.store(logsumexp_ptr, logsumexp)\n\n\n@triton.heuristics({\n \"DO_SOFTCAPPING\": lambda args: args[\"DO_SOFTCAPPING\"],\n \"DO_LOGIT_SCALING\": lambda args: args[\"DO_LOGIT_SCALING\"],\n})\n@triton.jit\ndef _cross_entropy_backward(\n logits_ptr, logits_row_stride,\n dloss_ptr, dloss_row_stride,\n logsumexp_ptr,\n labels_ptr,\n VOCAB_SIZE: tl.constexpr,\n BLOCK_SIZE: tl.constexpr,\n DO_SOFTCAPPING: tl.constexpr,\n SOFTCAP: tl.constexpr,\n DO_LOGIT_SCALING: tl.constexpr,\n LOGIT_SCALE: tl.constexpr,\n):\n row_idx = tl.program_id(0)\n block_idx = tl.program_id(1)\n\n logits_ptr += row_idx * logits_row_stride.to(tl.int64)\n dloss_ptr += row_idx * dloss_row_stride\n col_offsets = block_idx*BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < VOCAB_SIZE\n label_idx = tl.load(labels_ptr + row_idx).to(tl.int32)\n\n if label_idx != -100:\n dloss = tl.load(dloss_ptr)\n else:\n dloss = 0.0\n\n x = tl.load(logits_ptr + col_offsets, mask=mask, other=-float(\"inf\"))\n\n if DO_LOGIT_SCALING:\n x = x * LOGIT_SCALE\n \n if DO_SOFTCAPPING:\n partial = triton_tanh(x / SOFTCAP)\n x = SOFTCAP * partial\n \n logsumexp = tl.load(logsumexp_ptr + row_idx)\n y = tl.exp(x.to(tl.float32) - logsumexp)\n y = tl.where(\n col_offsets == label_idx,\n y - 1.0,\n y,\n )\n\n if DO_LOGIT_SCALING:\n y = y * LOGIT_SCALE\n \n if DO_SOFTCAPPING:\n y = y * (1.0 - partial*partial)\n \n tl.store(logits_ptr + col_offsets, dloss * y, mask=mask)\n\nMAX_FUSED_SIZE = 65536\n\nclass Fast_CrossEntropyLoss(torch.autograd.Function):\n @staticmethod\n def forward(ctx, logits, labels, logit_softcapping=0, logit_scaling=0):\n n_rows, vocab_size = logits.shape\n\n div, mod = divmod(vocab_size, MAX_FUSED_SIZE)\n n_chunks = div + (mod != 0)\n losses = torch.empty(n_rows, dtype=torch.float32, device=\"cuda:0\")\n\n DO_SOFTCAPPING = (logit_softcapping != 0)\n DO_LOGIT_SCALING = (logit_scaling != 0)\n\n if n_chunks == 1:\n BLOCK_SIZE, num_warps = calculate_settings(vocab_size)\n logsumexp = torch.empty(n_rows, dtype=torch.float32, device=\"cuda:0\")\n\n _cross_entropy_forward[(n_rows,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE=vocab_size,\n BLOCK_SIZE=BLOCK_SIZE,\n DO_SOFTCAPPING=DO_SOFTCAPPING,\n SOFTCAP=logit_softcapping,\n DO_LOGIT_SCALING=DO_LOGIT_SCALING,\n LOGIT_SCALE=logit_scaling,\n num_warps=num_warps,\n )\n else:\n logsumexp = torch.empty((n_rows, n_chunks,), dtype=torch.float32, device=\"cuda:0\")\n\n _chunked_cross_entropy_forward[(n_rows, n_chunks,)](\n logits, logits.stride(0),\n losses,\n logsumexp,\n labels,\n VOCAB_SIZE=vocab_size,\n N_CHUNKS=n_chunks,\n BLOCK_SIZE=MAX_FUSED_SIZE,\n DO_SOFTCAPPING=DO_SOFTCAPPING,\n SOFTCAP=logit_softcapping,\n DO_LOGIT_SCALING=DO_LOGIT_SCALING,\n LOGIT_SCALE=logit_scaling,\n num_warps=32,\n )\n logsumexp = torch.logsumexp(logsumexp, dim=1)\n losses += logsumexp\n losses.masked_fill_(labels == -100, 0)\n \n ctx.save_for_backward(logits, logsumexp, labels)\n ctx.DO_SOFTCAPPING = DO_SOFTCAPPING\n ctx.logit_softcapping = logit_softcapping\n ctx.DO_LOGIT_SCALING = DO_LOGIT_SCALING\n ctx.logit_scaling = logit_scaling\n return losses\n \n @staticmethod\n def backward(ctx, dlosses):\n logits, logsumexp, labels = ctx.saved_tensors\n n_rows, vocab_size = logits.shape\n\n BLOCK_SIZE = 4096\n div, mod = divmod(vocab_size, BLOCK_SIZE)\n n_blocks = div + (mod != 0)\n\n _cross_entropy_backward[(n_rows, n_blocks,)](\n logits, logits.stride(0),\n dlosses, dlosses.stride(0),\n logsumexp,\n labels,\n VOCAB_SIZE=vocab_size,\n BLOCK_SIZE=BLOCK_SIZE,\n DO_SOFTCAPPING=ctx.DO_SOFTCAPPING,\n SOFTCAP=ctx.logit_softcapping,\n DO_LOGIT_SCALING=ctx.DO_LOGIT_SCALING,\n LOGIT_SCALE=ctx.logit_scaling,\n num_warps=8,\n )\n return logits, None, None, None,\n \ndef fast_cross_entropy_loss(\n logits,\n labels,\n logit_softcapping=0,\n logit_scaling=0,\n n_items=None,\n):\n batch, seq_len, d = logits.shape\n assert(labels.shape == (batch, seq_len))\n\n loss = Fast_CrossEntropyLoss.apply(\n logits.view(batch*seq_len, d),\n labels.view(-1),\n logit_softcapping,\n logit_scaling,\n )\n if n_items is None:\n n_items = torch.count_nonzero(labels != -100)\n return loss.sum() / n_items\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-dc527ecbaa8f48c6984e4c413efcbb54", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided code implements a custom log softmax operation using Triton, an efficient parallel programming library for GPUs, and integrates it as a PyTorch autograd function. The code comprises two main Triton kernels and a Python class to wrap them.\n\n The 'log_softmax_kernel' Triton kernel computes the log softmax along a specified dimension of a tensor. It does so by first calculating the row-wise maximum to stabilize the softmax computation, subtracting this maximum from each element, and then computing the exponentials. The sum of these exponentials for each row is used to normalize the outputs. The kernel uses triton.autotune and triton.heuristics for performance optimization based on input size, leveraging BLOCK_M and BLOCK_N constants to define block sizes and num_warps to optimize parallel execution.\n\n The 'log_softmax_backward_kernel' is the backward pass for the log softmax operation. It calculates the gradient of the input tensor by using the chain rule, considering the output and the gradient of the output from the forward pass. This involves subtracting the weighted exponentiated output from the gradient, scaled by the sum of the gradients along the specified dimension.\n\n The 'LogSoftmax' class serves as a PyTorch custom autograd function, which integrates the Triton kernels into PyTorch's autograd system. In the 'forward' method, it calculates necessary dimensions (M and N) and the grid size for kernel invocation, invoking 'log_softmax_kernel' to compute the result. The result is saved for backward computation. The 'backward' method retrieves the saved output and uses 'log_softmax_backward_kernel' to calculate the input gradient, correctly considering dimensionality for both forward and backward passes.\n\n Finally, the 'log_softmax' function acts as a convenient wrapper for users, accepting a tensor and a dimension along which to apply the log softmax, while handling contiguous memory and optional dtype specification. It applies the 'LogSoftmax' function, returning the result tensor with the same shape as the input.\n ", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n\n\ndef heur_block_n(args):\n return triton.next_power_of_2(args[\"N\"])\n\n\ndef heur_num_warps(args):\n if args[\"N\"] <= 1024:\n return 4\n elif args[\"N\"] <= 2048:\n return 8\n else:\n return 16\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 1}),\n triton.Config({\"BLOCK_M\": 2}),\n triton.Config({\"BLOCK_M\": 4}),\n triton.Config({\"BLOCK_M\": 8}),\n ],\n key=[\n \"M\",\n \"N\",\n ],\n)\n@triton.heuristics(\n {\n \"BLOCK_N\": heur_block_n,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit\ndef log_softmax_kernel(\n output_ptr,\n input_ptr,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n offset = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n input_ptrs = input_ptr + offset\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\")).to(tl.float32)\n row_minus_max = inp - tl.max(inp, axis=1)[:, None]\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=1)[:, None]\n softmax_output = tl.log(numerator / denominator)\n output_ptrs = output_ptr + offset\n tl.store(output_ptrs, softmax_output, mask=mask)\n\n\n\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_M\": 1}),\n triton.Config({\"BLOCK_M\": 2}),\n triton.Config({\"BLOCK_M\": 4}),\n triton.Config({\"BLOCK_M\": 8}),\n ],\n key=[\n \"M\",\n \"N\",\n ],\n)\n@triton.heuristics(\n {\n \"BLOCK_N\": heur_block_n,\n \"num_warps\": heur_num_warps,\n }\n)\n@triton.jit\ndef log_softmax_backward_kernel(\n out_ptr,\n out_grad_ptr,\n in_grad_ptr,\n M,\n N,\n K,\n BLOCK_M: tl.constexpr,\n BLOCK_N: tl.constexpr,\n):\n pid_m = tl.program_id(0)\n pid_k = tl.program_id(1)\n m_offset = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)\n n_offset = tl.arange(0, BLOCK_N)\n\n offsets = m_offset[:, None] * N * K + n_offset[None, :] * K + pid_k\n mask = m_offset[:, None] < M and n_offset[None, :] < N\n out_ptrs = out_ptr + offsets\n out = tl.load(out_ptrs, mask=mask).to(tl.float32)\n out_grad_ptrs = out_grad_ptr + offsets\n out_grad = tl.load(out_grad_ptrs, mask=mask).to(tl.float32)\n\n scale = tl.sum(out_grad, 1)\n in_grad = out_grad - tl.exp(out.to(tl.float32)) * scale[:, None]\n\n in_grad_ptrs = in_grad_ptr + offsets\n tl.store(in_grad_ptrs, in_grad, mask=mask)\n\n\nclass LogSoftmax(torch.autograd.Function):\n @staticmethod\n def forward(ctx, x, dim, dtype):\n logging.debug(\"GEMS LOG_SOFTMAX\")\n\n assert dim >= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i]\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n K = inp.numel() // M // N\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(inp.device):\n log_softmax_kernel[grid](\n out,\n inp,\n M,\n N,\n K,\n )\n ctx.save_for_backward(out)\n ctx.dim = dim\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n logging.debug(\"GEMS LOG_SOFTMAX VJP\")\n\n dim = ctx.dim\n (out,) = ctx.saved_tensors\n\n assert dim >= -out.ndim and dim < out.ndim, \"Invalid dim\"\n dim = dim % out.ndim\n M = 1\n N = out.shape[dim]\n for i in range(dim):\n M *= out.shape[i]\n\n out_grad = out_grad.contiguous()\n in_grad = torch.empty_like(out)\n K = out.numel() // M // N\n\n grid = lambda meta: (\n triton.cdiv(M, meta[\"BLOCK_M\"]),\n K,\n )\n with torch.cuda.device(in_grad.device):\n log_softmax_backward_kernel[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n K,\n )\n return in_grad, None, None\n\n\ndef log_softmax(x, dim=-1, dtype=None):\n return LogSoftmax.apply(x, dim, dtype)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-c1f9916869f04e54bd07ca46b9e26463", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.The `index_select_cat_fwd` function is implemented using Triton to accelerate the process of selecting and concatenating specific rows from a 2D source tensor based on a 1D index tensor, both residing on the GPU. This function requires that `source` and `index` be CUDA tensors, with `source` being 2-dimensional and `index` 1-dimensional. If the number of indices exceeds the number of rows in the source, a warning is printed, and the indices are truncated. The function extracts strides for source tensor dimension traversal. The `index_select_cat_fwd_kernel`, decorated with `@triton.jit`, uses a 2D grid to distribute workload across blocks, where each block handles a block of indices (`BLOCK_SIZE_INDEX`) and a block of columns (`BLOCK_SIZE_COL`). Within the kernel, program ids `pid0` and `pid1` determine the starting indices and columns for the block. It computes offsets for source and output tensors, applying masks to ensure bounds. `tl.load` is used to fetch data from the source, and `tl.store` writes the selected data into the output tensor. The `grid` function computes the dimensions of the launch grid based on the number of indices and columns, utilizing `triton.cdiv` for ceil division.", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef index_select_cat_fwd_kernel(\n output_ptr, # *Pointer* to output tensor.\n source_ptr, # *Pointer* to source tensor.\n index_ptr, # *Pointer* to index tensor.\n num_indices,\n num_cols,\n stride0, # Stride information of source tensor.\n stride1,\n BLOCK_SIZE_INDEX: tl.constexpr, # Number of indices each program should process.\n BLOCK_SIZE_COL: tl.constexpr, # Number of cols each program should process.\n):\n pid0 = tl.program_id(axis=0) # We use 2D launch grid\n pid1 = tl.program_id(axis=1)\n\n indices = pid0 * BLOCK_SIZE_INDEX + tl.arange(0, BLOCK_SIZE_INDEX)\n rows = tl.load(index_ptr + indices, mask=(indices < num_indices))\n cols = pid1 * BLOCK_SIZE_COL + tl.arange(0, BLOCK_SIZE_COL)\n\n source_offsets = source_ptr + rows[:, None] * stride0 + cols[None, :] * stride1\n mask = (indices[:, None] < num_indices) & (cols[None, :] < num_cols)\n output = tl.load(source_offsets, mask=mask)\n\n output_offsets = output_ptr + indices[:, None] * stride0 + cols[None, :] * stride1\n tl.store(output_offsets, output, mask=mask)\n\n\ndef index_select_cat_fwd(\n output: torch.Tensor,\n source: torch.Tensor,\n index: torch.Tensor,\n):\n if not (source.is_cuda and index.is_cuda):\n raise ValueError(\"The index tensor and the source tensor must be of type CUDA!\")\n\n if not source.ndim == 2:\n raise ValueError(f\"Expected 2-dimensional tensor, got {source.ndim}.\")\n if not index.ndim == 1:\n raise ValueError(f\"Expected 1-dimensional tensor, got {index.ndim}.\")\n\n num_rows, num_cols = source.shape\n num_indices = index.shape[0]\n\n if num_indices > num_rows:\n print(f\"Warning: The number of indices exceeds the number of rows in the source tensor. Truncating indices.\")\n num_indices = num_rows\n index = index[:num_rows]\n\n stride0, stride1 = source.stride(0), source.stride(1)\n\n def grid(meta):\n return (\n triton.cdiv(num_indices, meta[\"BLOCK_SIZE_INDEX\"]),\n triton.cdiv(num_cols, meta[\"BLOCK_SIZE_COL\"]),\n )\n\n index_select_cat_fwd_kernel[grid](\n output,\n source,\n index,\n num_indices,\n num_cols,\n stride0,\n stride1,\n BLOCK_SIZE_INDEX=1,\n BLOCK_SIZE_COL=512,\n )\n\n return output\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4bfb5bfd15644570a93f8ce592800cd0", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel function '_fwd_kernel_token_att1' performs a forward pass for a token attention mechanism, often used in transformer models for calculating the attention scores between query (Q) and key (K) tensors. This kernel accepts various input tensors, including 'Q' and 'K', which represent the query and key matrices, 'B_Loc', which provides positional information about the keys, 'B_Start_Loc' and 'B_Seqlen', which give sequence start indices and lengths, and 'max_input_len' as the maximum sequence length within a batch. The output 'Att_Out' stores the computed attention values. Internally, the kernel uses several predefined strides to navigate through the tensor dimensions. For each batch, head, and block of sequences defined by the block size 'BLOCK_N', it loads segments of the query and key tensors, computes their dot product, scales it by 'sm_scale', and stores the result back. The grid configuration ensures coverage across all required dimensions with specific 'num_warps' for parallel processing. The wrapper function 'token_att_fwd' initializes and launches this kernel, checking the compatibility of input dimensions and adjusting the scaling factor based on the key dimension.\n ", + "output": [ + "\nimport triton\nimport triton.language as tl\nimport torch\n\n@triton.jit\ndef _fwd_kernel_token_att1(\n Q, K, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n Att_Out,\n stride_b_loc_b, stride_b_loc_s,\n stride_qbs, stride_qh, stride_qd,\n stride_kbs, stride_kh, stride_kd,\n att_stride_h, att_stride_bs,\n kv_group_num,\n BLOCK_DMODEL: tl.constexpr,\n BLOCK_N: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n start_n = tl.program_id(2)\n \n cur_kv_head = cur_head // kv_group_num\n\n offs_d = tl.arange(0, BLOCK_DMODEL)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n cur_batch_start_index = max_input_len - cur_batch_seq_len\n cur_batch_end_index = max_input_len\n\n off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd\n\n offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N)\n\n block_stard_index = start_n * BLOCK_N\n block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0)\n\n for start_mark in range(0, block_mask, 1):\n q = tl.load(Q + off_q + start_mark)\n offs_n_new = cur_batch_start_index + offs_n\n k_loc = tl.load(B_Loc + stride_b_loc_b * cur_batch + stride_b_loc_s * offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)\n off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd\n k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)\n att_value = tl.sum(q[None, :] * k, 1)\n att_value *= sm_scale\n off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs\n tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index)\n return\n\n@torch.no_grad()\ndef token_att_fwd(q, k, att_out, B_Loc, B_Start_Loc, B_Seqlen, max_input_len):\n BLOCK = 32\n # shape constraints\n Lq, Lk = q.shape[-1], k.shape[-1]\n assert Lq == Lk\n assert Lk in {16, 32, 64, 128}\n sm_scale = 1.0 / (Lk ** 0.5)\n\n batch, head_num = B_Loc.shape[0], q.shape[1]\n\n grid = (batch, head_num, triton.cdiv(max_input_len, BLOCK))\n kv_group_num = q.shape[1] // k.shape[1]\n\n num_warps = 4 if Lk <= 64 else 8\n num_warps = 2\n\n _fwd_kernel_token_att1[grid](\n q, k, sm_scale, B_Loc, B_Start_Loc, B_Seqlen, max_input_len,\n att_out,\n B_Loc.stride(0), B_Loc.stride(1),\n q.stride(0), q.stride(1), q.stride(2),\n k.stride(0), k.stride(1), k.stride(2),\n att_out.stride(0), att_out.stride(1),\n kv_group_num=kv_group_num,\n BLOCK_DMODEL=Lk,\n BLOCK_N=BLOCK,\n num_warps=num_warps,\n num_stages=1,\n )\n return\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-722414b05558495cb37e835fe25a39aa", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_bwd_kernel` performs a backward pass operation for L2 normalization on a per-row basis. It receives pointers to input `X`, output gradient `DY`, and calculates the input gradient `DX`. Each row of the input is accessed using the `stride_x_row`. `BLOCK_N` determines the number of elements processed per block, set based on maximum allowable fused size and next power of 2 of `N`. Within the kernel, it computes the variance of the input slice, uses it to compute the reciprocal of the standard deviation (`rstd`), and then calculates `dx` using the formula `dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x`. The result is conditionally stored in `DX` using masks. The `_l2_norm_bwd` function orchestrates this process, ensuring input tensors `x` and `dy` are properly reshaped and their strides configured for contiguity if necessary. If `N` exceeds `BLOCK_N`, an error is raised to prevent excessive feature dimensions. Finally, the kernel is launched over `M` rows of the reshaped tensors, and the output `dx` is reshaped back to the original input shape.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_bwd_kernel(\n X, # pointer to the input\n DY, # pointer to the output gradient\n DX, # pointer to the input gradient\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n DX += row * stride_x_row\n DY += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n x = tl.where(cols < N, x, 0.0)\n var = tl.sum(x * x) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n dy = tl.load(DY + cols, mask=cols < N, other=0.0).to(tl.float32)\n dy = tl.where(cols < N, dy, 0.0)\n dx = dy * rstd - tl.sum(dy * x) * (1 / (var+eps)) * rstd * x\n tl.store(DX + cols, dx, mask=mask)\n\ndef _l2_norm_bwd(\n x, dy, eps=1e-5,\n):\n x_shape_og = x.shape\n x = x.reshape(-1, dy.shape[-1])\n dy = dy.reshape(-1, dy.shape[-1])\n if dy.stride(-1) != 1:\n dy = dy.contiguous()\n dx = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_bwd_kernel[(M,)](\n x,\n dy,\n dx,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return dx.reshape(x_shape_og)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4ada70c336aa4133a82373ffe8d8e6ac", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The provided Triton code defines a custom forward-pass kernel function `_fwd_kernel_token_softmax` designed for computing the softmax of token logits with variable sequence lengths within a batch and multiple attention heads. The kernel is launched through the `token_softmax_fwd` function, which orchestrates the setup of the kernel's execution parameters based on the input data. The main task of `_fwd_kernel_token_softmax` is to handle each batch and head independently using two-dimensional parallelism. It loads the appropriate segment of logits for each token sequence, applies a numerically stable softmax by subtracting the maximum logit value from each element, and normalizes by the total exponentiated sum. This operation efficiently computes softmax for padded sequences by masking invalid positions using Triton's `mask` mechanism, which replaces them with negative infinity. The kernel parameters such as `num_warps` are dynamically adjusted based on the block size, enhancing performance across various sequence lengths.\n ", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\n@triton.jit\ndef _fwd_kernel_token_softmax(\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n stride_logic_h, stride_logic_bs,\n stride_prob_h, stride_prob_bs,\n BLOCK_SIZE: tl.constexpr\n):\n cur_batch = tl.program_id(0)\n cur_head = tl.program_id(1)\n\n col_offsets = tl.arange(0, BLOCK_SIZE)\n cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)\n cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)\n\n row = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_in_all_start_index + col_offsets) * stride_logic_bs,\n mask=col_offsets < cur_batch_seq_len, other=-float('inf')).to(tl.float32)\n\n row_minus_max = row - tl.max(row, axis=0)\n numerator = tl.exp(row_minus_max)\n denominator = tl.sum(numerator, axis=0)\n softmax_output = numerator / denominator\n\n tl.store(Prob_Out + cur_head * stride_prob_h + (cur_batch_in_all_start_index + col_offsets)\n * stride_prob_bs, softmax_output, mask=col_offsets < cur_batch_seq_len)\n return\n\n\n@torch.no_grad()\ndef token_softmax_fwd(Logics, B_Start_Loc, B_Seqlen, Prob_Out, max_input_len):\n BLOCK_SIZE = triton.next_power_of_2(max_input_len)\n batch, head_num = B_Start_Loc.shape[0], Logics.shape[0]\n\n num_warps = 4\n if BLOCK_SIZE >= 2048:\n num_warps = 8\n if BLOCK_SIZE >= 4096:\n num_warps = 16\n\n _fwd_kernel_token_softmax[(batch, head_num)](\n Logics, B_Start_Loc, B_Seqlen,\n Prob_Out,\n Logics.stride(0), Logics.stride(1),\n Prob_Out.stride(0), Prob_Out.stride(1),\n num_warps=num_warps,\n BLOCK_SIZE=BLOCK_SIZE,\n )\n return\n\n\n\n" + ] + }, + { + "id": "openseek-8-1eaca87e449142678f3a42c7c5ee4196", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code defines a Triton-based GPU kernel called `chunk_global_cumsum_vector_kernel` and a corresponding Python function `chunk_global_cumsum_vector` for applying this kernel. The kernel computes a global cumulative sum over blocks in a 4D input tensor `s`. The tensor dimensions are interpreted as [Batch, Head, Time, Size]. The kernel uses block processing, defined by the `BT` (block time) and `BS` (block size) parameters. The kernel performs the following steps:\n\n 1. It determines the program's position using `tl.program_id`, which assigns work to the kernel across blocks of the tensor.\n 2. A lower triangular mask `m_s` is created to ensure each element's sum includes all previous elements in the current block row.\n 3. For each block in the time dimension (`T`), a pointer to the relevant data slice in `s` is created with `tl.make_block_ptr`.\n 4. Data from `s` is loaded into `b_s`, processed as `float32` for precision.\n 5. Matrix multiplication (`tl.dot`) with the mask `m_s` is used to compute the block-level cumulative sum, stored in `b_c`.\n 6. The result is stored back to the output tensor `z`.\n 7. The running sum `b_z` is updated after processing each block column.\n\n The `chunk_global_cumsum_vector` function prepares the input tensor `s`, and creates an output tensor `z` of the same shape. It sets the computation grid based on tensor dimensions and block size. The kernel is called with this grid, using dimensions like `stride` to correctly address tensor slices.\n\n The kernel uses Triton's `autotune` decorator to choose optimal execution configurations (combinations of block size and number of warps) based on input size `S`. This approach improves performance by adapting to specific hardware characteristics.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Optional\n\n@triton.autotune(\n configs=[\n triton.Config({'BT': 16}, num_warps=2),\n triton.Config({'BT': 16}, num_warps=4),\n triton.Config({'BT': 16}, num_warps=8),\n triton.Config({'BT': 32}, num_warps=2),\n triton.Config({'BT': 32}, num_warps=4),\n triton.Config({'BT': 32}, num_warps=8),\n triton.Config({'BT': 64}, num_warps=2),\n triton.Config({'BT': 64}, num_warps=4),\n triton.Config({'BT': 64}, num_warps=8),\n ],\n key=['S']\n)\n@triton.jit\ndef chunk_global_cumsum_vector_kernel(\n s,\n z,\n s_s_h,\n s_s_t,\n s_s_d,\n T: tl.constexpr,\n S: tl.constexpr,\n BT: tl.constexpr,\n BS: tl.constexpr\n):\n i_s, i_bh = tl.program_id(0), tl.program_id(1)\n o_i = tl.arange(0, BT)\n m_s = tl.where(o_i[:, None] >= o_i[None, :], 1., 0.)\n b_z = tl.zeros([BS], dtype=tl.float32)\n for i_t in range(tl.cdiv(T, BT)):\n p_s = tl.make_block_ptr(s + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n p_z = tl.make_block_ptr(z + i_bh * s_s_h, (T, S), (s_s_t, s_s_d), (i_t * BT, i_s * BS), (BT, BS), (1, 0))\n # [BT, BS]\n b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32)\n b_c = b_z[None, :] + tl.dot(m_s, b_s, allow_tf32=False)\n tl.store(p_z, b_c.to(p_z.dtype.element_ty), boundary_check=(0, 1))\n if i_t >= 0:\n b_z += tl.sum(b_s, 0)\n\ndef chunk_global_cumsum_vector(\n s: torch.Tensor,\n dtype: Optional[torch.dtype] = None,\n) -> torch.Tensor:\n B, H, T, S = s.shape\n BS = 32\n dtype = dtype or s.dtype\n grid = (triton.cdiv(S, BS), B * H)\n z = torch.empty_like(s, dtype=dtype)\n chunk_global_cumsum_vector_kernel[grid](\n s, z,\n s.stride(1), s.stride(2), s.stride(3),\n T=T, S=S, BS=BS\n )\n return z\n\n\n\n" + ] + }, + { + "id": "openseek-8-39b874b1f1494637a71a460f60bf2e72", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton SSM kernel functions define efficient GPU computations for state-space models by iterating through sequence data. \n\n `diag_ssm_forward_kernel` handles the forward pass for real-valued inputs, applying the transformation `s = s * Lambda + x` across a sequence of data stored in `x_ptr`. It reads initial states from `s_ptr` and writes outputs to `y_ptr`. The transformation is iteratively applied for each timestep `t` in the sequence `length`, over dimensions defined by `batch_size` and `dim`, using a grid that spans `(batch_size * dim)` divided into blocks of `BLOCK_SIZE`.\n\n `diag_ssm_forward_kernel_complex` extends this to complex numbers, handling real and imaginary parts separately, and performing complex multiplication and addition.\n\n For the backward pass, `diag_ssm_backward_kernel` and its complex counterpart compute gradients needed for learning, propagating them backwards from `grad_y_ptr` through the sequence, updating `grad_s_ptr`, `grad_x_ptr`, and `grad_lambda_ptr`. Gradients respect the chain rule, ensuring consistency with the complex autograd requirements in PyTorch.\n\n The enclosing class `_ssm_forward` uses these kernels within PyTorch's autograd framework, handling input validation and complex data type checks. It initiates GPU kernel execution by defining a grid configuration function to optimize computation.\n ", + "output": [ + "import torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef diag_ssm_forward_kernel(s_ptr, x_ptr, lambda_ptr, y_ptr, length,\n batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n 前向传播核函数(实数版本)\n\n 参数:\n s_ptr: [batch_size, dim]\n x_ptr: [length, batch_size, dim]\n lambda_ptr: [dim]\n y_ptr: [length, batch_size, dim]\n \"\"\"\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n s = tl.load(s_ptr + col_offsets, mask=mask, other=0)\n Lambda = tl.load(lambda_ptr + col_offsets % dim, mask=mask, other=0)\n for t in range(length):\n offsets = t * batch_size * dim + col_offsets\n x = tl.load(x_ptr + offsets, mask=mask, other=0)\n s = s * Lambda + x\n tl.store(y_ptr + offsets, s, mask=mask)\n\n@triton.jit\ndef diag_ssm_backward_kernel(\n s_ptr, lambda_ptr, y_ptr, grad_s_ptr, grad_x_ptr, grad_lambda_ptr,\n grad_y_ptr, length, batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n 反向传播核函数(实数版本)\n\n 参数:\n s_ptr: [batch_size, dim]\n lambda_ptr: [dim]\n y_ptr: [length, batch_size, dim]\n grad_s_ptr: [batch_size, dim]\n grad_x_ptr: [length, batch_size, dim]\n grad_lambda_ptr: [batch_size, dim]\n grad_y_ptr: [length, batch_size, dim]\n \"\"\"\n\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n Lambda = tl.load(lambda_ptr + col_offsets % dim, mask=mask, other=0)\n\n # 初始化梯度为零\n grad_s = tl.zeros_like(Lambda)\n grad_Lambda = tl.zeros_like(Lambda)\n\n for i in range(length):\n # Triton 不支持 range(length - 1, -1, -1)\n t = length - 1 - i\n offsets = t * batch_size * dim + col_offsets\n\n grad_y = tl.load(grad_y_ptr + offsets, mask=mask, other=0)\n if t > 0:\n s = tl.load(\n y_ptr + offsets - batch_size * dim, mask=mask, other=0)\n else:\n s = tl.load(s_ptr + col_offsets, mask=mask, other=0)\n\n grad_s = grad_y + grad_s\n grad_x = grad_s\n grad_Lambda += grad_s * s\n grad_s = grad_s * Lambda\n\n tl.store(grad_x_ptr + offsets, grad_x, mask=mask)\n\n tl.store(grad_s_ptr + col_offsets, grad_s, mask=mask)\n tl.store(grad_lambda_ptr + col_offsets, grad_Lambda, mask=mask)\n\n@triton.jit\ndef diag_ssm_forward_kernel_complex(s_ptr, x_ptr, y_ptr, lambda_ptr,\n length, batch_size, dim,\n BLOCK_SIZE: tl.constexpr):\n \"\"\"\n 前向传播核函数(复数版本)\n\n 参数:\n s_ptr: [batch_size, dim, 2]\n x_ptr: [length, batch_size, dim, 2]\n lambda_ptr: [dim, 2]\n y_ptr: [length, batch_size, dim, 2]\n \"\"\"\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n # 加载's'和'Lambda'的实部和虚部\n s_real = tl.load(s_ptr + col_offsets * 2, mask=mask, other=0)\n s_imag = tl.load(s_ptr + col_offsets * 2 + 1, mask=mask, other=0)\n lambda_real = tl.load(\n lambda_ptr + (col_offsets % dim) * 2, mask=mask, other=0)\n lambda_imag = tl.load(\n lambda_ptr + (col_offsets % dim) * 2 + 1, mask=mask, other=0)\n\n for t in range(length):\n offsets = (t * batch_size * dim + col_offsets) * 2\n # 加载'x'的实部和虚部\n x_real = tl.load(x_ptr + offsets, mask=mask, other=0)\n x_imag = tl.load(x_ptr + offsets + 1, mask=mask, other=0)\n\n # 复数的乘法和加法\n new_s_real = s_real * lambda_real - s_imag * lambda_imag + x_real\n new_s_imag = s_real * lambda_imag + s_imag * lambda_real + x_imag\n\n # 存储更新后的实部和虚部\n tl.store(y_ptr + offsets, new_s_real, mask=mask)\n tl.store(y_ptr + offsets + 1, new_s_imag, mask=mask)\n\n # 更新's'以进行下一次迭代\n s_real, s_imag = new_s_real, new_s_imag\n\n@triton.jit\ndef diag_ssm_backward_kernel_complex(\n s_ptr, lambda_ptr, y_ptr, grad_s_ptr, grad_x_ptr, grad_lambda_ptr,\n grad_y_ptr, length, batch_size, dim, BLOCK_SIZE: tl.constexpr):\n \"\"\"\n 反向传播核函数(复数版本)\n\n 参数:\n s_ptr: [batch_size, dim, 2]\n lambda_ptr: [dim, 2]\n y_ptr: [length, batch_size, dim, 2]\n grad_s_ptr: [batch_size, dim, 2]\n grad_x_ptr: [length, batch_size, dim, 2]\n grad_lambda_ptr: [batch_size, dim, 2]\n grad_y_ptr: [length, batch_size, dim, 2]\n \"\"\"\n\n # 复数自导数计算 \\partial f / \\partial z^*\n # 因此在计算过程中需要取共轭\n # 参考:https://pytorch.org/docs/stable/notes/autograd.html#autograd-for-complex-numbers\n # 所以在加载/存储梯度的虚部时,需要取反\n\n col_idx = tl.program_id(0) * BLOCK_SIZE\n col_offsets = col_idx + tl.arange(0, BLOCK_SIZE)\n mask = col_offsets < batch_size * dim\n\n # 加载'Lambda'的实部和虚部\n lambda_real = tl.load(\n lambda_ptr + (col_offsets % dim) * 2, mask=mask, other=0)\n lambda_imag = tl.load(\n lambda_ptr + (col_offsets % dim) * 2 + 1, mask=mask, other=0)\n\n # 初始化梯度为零\n grad_s_real = tl.zeros_like(lambda_real)\n grad_s_imag = tl.zeros_like(lambda_imag)\n grad_lambda_real = tl.zeros_like(lambda_real)\n grad_lambda_imag = tl.zeros_like(lambda_imag)\n\n for i in range(length):\n # Triton 不支持 range(length - 1, -1, -1)\n t = length - 1 - i\n offsets = (t * batch_size * dim + col_offsets) * 2\n\n grad_y_real = tl.load(grad_y_ptr + offsets, mask=mask, other=0)\n grad_y_imag = -tl.load(\n grad_y_ptr + offsets + 1, mask=mask, other=0)\n if t > 0:\n s_real = tl.load(\n y_ptr + offsets - 2 * batch_size * dim, mask=mask, other=0)\n s_imag = tl.load(\n y_ptr + offsets - 2 * batch_size * dim + 1,\n mask=mask,\n other=0)\n else:\n s_real = tl.load(s_ptr + 2 * col_offsets, mask=mask, other=0)\n s_imag = tl.load(\n s_ptr + 2 * col_offsets + 1, mask=mask, other=0)\n\n grad_s_real = grad_y_real + grad_s_real\n grad_s_imag = grad_y_imag + grad_s_imag\n grad_x_real = grad_s_real\n grad_x_imag = grad_s_imag\n grad_lambda_real += grad_s_real * s_real - grad_s_imag * s_imag\n grad_lambda_imag += grad_s_real * s_imag + grad_s_imag * s_real\n grad_s_real = grad_x_real * lambda_real - grad_x_imag * lambda_imag\n grad_s_imag = grad_x_real * lambda_imag + grad_x_imag * lambda_real\n\n tl.store(grad_x_ptr + offsets, grad_x_real, mask=mask)\n tl.store(grad_x_ptr + offsets + 1, -grad_x_imag, mask=mask)\n\n # 存储最终的梯度\n tl.store(grad_s_ptr + col_offsets * 2, grad_s_real, mask=mask)\n tl.store(grad_s_ptr + col_offsets * 2 + 1, -grad_s_imag, mask=mask)\n tl.store(\n grad_lambda_ptr + col_offsets * 2, grad_lambda_real, mask=mask)\n tl.store(\n grad_lambda_ptr + col_offsets * 2 + 1,\n -grad_lambda_imag,\n mask=mask)\n\nclass _ssm_forward(torch.autograd.Function):\n # TODO 使用 @triton.autotune 选择最佳的 BLOCK_SIZE\n # 对于3090,BLOCK_SIZE = 128似乎效果良好\n BLOCK_SIZE = 128\n\n @staticmethod\n def forward(ctx, s, x, Lambda):\n assert s.is_contiguous() and x.is_contiguous() and Lambda.is_contiguous()\n length, batch_size, dim = x.shape\n n = batch_size * dim\n y = torch.zeros_like(x)\n grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )\n\n if Lambda.dtype == torch.complex64:\n # 确保s和x是复数张量\n if not torch.is_complex(s):\n raise ValueError(\"当Lambda为复数时,s必须是复数张量\")\n if not torch.is_complex(x):\n raise ValueError(\"当Lambda为复数时,x必须是复数张量\")\n diag_ssm_forward_kernel_complex[grid](\n torch.view_as_real(s), torch.view_as_real(x),\n torch.view_as_real(y), torch.view_as_real(Lambda), length,\n batch_size, dim, _ssm_forward.BLOCK_SIZE)\n elif Lambda.dtype.is_floating_point:\n diag_ssm_forward_kernel[grid](s, x, Lambda, y, length,\n batch_size, dim,\n _ssm_forward.BLOCK_SIZE)\n else:\n raise ValueError(\"不支持的 dtype: %s\" % Lambda.dtype)\n ctx.save_for_backward(s, y, Lambda)\n return y\n\n @staticmethod\n def backward(ctx, grad_y):\n s, y, Lambda = ctx.saved_tensors\n length, batch_size, dim = y.shape\n grad_y = grad_y.contiguous()\n n = batch_size * dim\n grad_s = torch.empty_like(s)\n grad_x = torch.empty_like(grad_y)\n # grad_lambda 存储每个批次中 Lambda 的梯度\n # 我们将在内核完成后进行求和\n grad_lambda = torch.empty_like(s)\n grid = lambda meta: (triton.cdiv(n, meta['BLOCK_SIZE']), )\n if Lambda.dtype == torch.complex64:\n diag_ssm_backward_kernel_complex[grid](\n torch.view_as_real(s), torch.view_as_real(Lambda),\n torch.view_as_real(y), torch.view_as_real(grad_s),\n torch.view_as_real(grad_x),\n torch.view_as_real(grad_lambda),\n torch.view_as_real(grad_y), length, batch_size, dim,\n _ssm_forward.BLOCK_SIZE)\n else:\n diag_ssm_backward_kernel[grid](\n s, Lambda, y, grad_s, grad_x, grad_lambda, grad_y, length,\n batch_size, dim, _ssm_forward.BLOCK_SIZE)\n return grad_s, grad_x, grad_lambda.sum(dim=0)\n\ndiag_ssm_forward_triton = _ssm_forward.apply\n\n" + ] + }, + { + "id": "openseek-8-aeed98a776544b109375effa36c6a201", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThe Triton kernel `_quantize_global` performs global quantization on a tensor `x` by transforming its floating-point elements into an 8-bit integer representation. The kernel operates over CUDA using Triton's JIT compilation, and is equipped with autotuning to optimize the `BLOCK_SIZE` for different workloads. It takes `x_ptr`, a pointer to the input tensor `x`, `absmax_inv_ptr`, a pointer to the precomputed inverse of the maximum absolute value of `x`, and `output_ptr`, a pointer where the quantized results are stored. The `n_elements` parameter specifies the total number of elements to be processed. The kernel computes `pid` using the Triton `program_id` to identify the current program instance and calculates `block_start` and `offsets` to determine the specific segment of data it processes. It uses `tl.load` to read `BLOCK_SIZE` contiguous elements from `x`, applies the quantization operation using `tl.extra.cuda.libdevice.llrint`, which rounds the scaled values to the nearest integer, and writes the results into the output buffer. The surrounding Python function `quantize_global` first calculates `absmax` as the maximum absolute value of `x`, derives its reciprocal as `absmax_inv`, initializes an output tensor, and then dispatches the Triton kernel over a grid, with the grid size dynamically determined based on the input size and `BLOCK_SIZE`. After execution, it returns the quantized tensor and the original maximum absolute value, which might be used for dequantization.\n", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n# global quantize\n@triton.autotune(\n configs=[\n triton.Config({\"BLOCK_SIZE\": 1024}, num_warps=4),\n triton.Config({\"BLOCK_SIZE\": 2048}, num_stages=1),\n ],\n key=[\"n_elements\"],\n)\n@triton.jit\ndef _quantize_global(\n x_ptr,\n absmax_inv_ptr,\n output_ptr,\n n_elements,\n BLOCK_SIZE: tl.constexpr,\n):\n pid = tl.program_id(axis=0)\n block_start = pid * BLOCK_SIZE\n offsets = block_start + tl.arange(0, BLOCK_SIZE)\n mask = offsets < n_elements\n x = tl.load(x_ptr + offsets, mask=mask)\n absmax_inv = tl.load(absmax_inv_ptr)\n output = tl.extra.cuda.libdevice.llrint(127.0 * (x * absmax_inv))\n tl.store(output_ptr + offsets, output, mask=mask)\n\ndef quantize_global(x: torch.Tensor):\n absmax = x.abs().max().unsqueeze(0)\n absmax_inv = 1.0 / absmax\n output = torch.empty(*x.shape, device=\"cuda\", dtype=torch.int8)\n assert x.is_cuda and output.is_cuda\n n_elements = output.numel()\n grid = lambda meta: (triton.cdiv(n_elements, meta[\"BLOCK_SIZE\"]),)\n _quantize_global[grid](x, absmax_inv, output, n_elements)\n return output, absmax\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-4a74af86ec454ac789a6ac9e73c317dc", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The Triton kernel `_l2_norm_fwd_1pass_kernel` aims to perform L2 normalization on a 2D input tensor `X`. It processes each row separately using Triton's parallel execution model. The kernel expects pointers to `X` and `Y` along with the stride for rows (`stride_x_row`), number of columns in `X` (`N`), a small constant `eps` to prevent division by zero, and a compile-time constant `BLOCK_N`. The kernel computes L2 normalization by first loading a block of data from `X`, calculating the sum of squares for variance, and computing the reciprocal of the square root of the variance plus `eps` to get `rstd`. It then multiplies the input block by `rstd` to produce the normalized values, which are stored in `Y`.\n\n The Python function `_l2_norm_fwd` handles the setup and execution of the kernel. It first reshapes and possibly makes the input tensor `x` contiguous. It initializes an empty tensor `y` to store the output. The function calculates `BLOCK_N` based on `x`'s element size and ensures it doesn't exceed 64KB. If the feature dimension `N` is larger than `BLOCK_N`, it raises an error. The kernel is then launched with the total number of rows `M`, pointers to `x` and `y`, stride, number of columns, `eps`, and `BLOCK_N`. Finally, the function returns the normalized tensor reshaped to its original dimensions.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\n\n@triton.jit\ndef _l2_norm_fwd_1pass_kernel(\n X, # pointer to the input\n Y, # pointer to the output\n stride_x_row, # how much to increase the pointer when moving by 1 row\n N, # number of columns in X\n eps, # epsilon to avoid division by zero\n BLOCK_N: tl.constexpr,\n):\n row = tl.program_id(0)\n X += row * stride_x_row\n Y += row * stride_x_row\n cols = tl.arange(0, BLOCK_N)\n x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)\n xbar = tl.where(cols < N, x, 0.0)\n var = tl.sum(xbar * xbar, axis=0) \n rstd = 1 / tl.sqrt(var + eps)\n mask = cols < N\n y = x * rstd\n tl.store(Y + cols, y, mask=mask)\n\ndef _l2_norm_fwd(\n x, eps=1e-6\n):\n x_shape_og = x.shape\n x = x.reshape(-1, x.shape[-1])\n if x.stride(-1) != 1:\n x = x.contiguous()\n y = torch.empty_like(x)\n N = x.shape[-1]\n M = x.shape[0]\n MAX_FUSED_SIZE = 65536 // x.element_size()\n BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))\n if N > BLOCK_N:\n raise RuntimeError(\n \"This layer norm doesn't support feature dim >= 64KB.\")\n with torch.cuda.device(x.device.index):\n _l2_norm_fwd_1pass_kernel[(M,)](\n x,\n y,\n x.stride(0),\n N,\n eps,\n BLOCK_N,\n )\n return y.reshape(x_shape_og)\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-a0653344a6614430b51baacf54788f05", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThis implementation consists of Triton kernels designed to perform matrix multiplication enhanced by RMS normalization and optional rotary embeddings, serving particularly well in attention-based neural network layers. \n\n1. `rms_matmul_rbe` Kernel:\n - Function: Computes the expression c = (rms(x) * rms_w) @ w with optional rotary embeddings applied in the epilogue.\n - Inputs: \n - x_ptr, w_ptr, rms_w_ptr, out_ptr: Pointers to the input matrix, weight matrix, rms of the weight matrix, and output matrix, respectively.\n - M, N, K: Dimensions of the matrices involved.\n - Strides: Strides for batch, M, N, K dimensions for each matrix.\n - start_token_position, USE_FP8, RBE_EPILOGUE, THETA, EPS, BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K: Constants and configuration flags for numerical precision, RMS normalization, and rotary embedding application.\n - Logic: The kernel uses block-level computations to load, normalize using RMS, and multiply matrices. The rotary embeddings are optionally applied to output.\n\n2. `rms_matmul_rbe_qkv` Kernel:\n - Function: Acts as a wrapper to perform matrix multiplications separately for Q, K, and V matrices. It calls `rms_matmul_rbe` thrice, once for each type of matrix.\n - Inputs: Similar to `rms_matmul_rbe` but for Q, K, V matrices along with their respective weight matrices and strides.\n\n3. `rms_matmul_rbe_qkv_wrapper` Function:\n - Purpose: This is a high-level PyTorch interface for initializing the Triton kernel launch. It sets up the input/output matrices, strides, and prepares the grid configuration for the kernel execution.\n - Logic: It ensures type and shape checks for input matrices, manages data type transformations for kernel compatibility, and reshapes the output matrices (q, k, v) to expected dimensions for further processing in neural network models.\n\nThe kernels leverage Triton's ability to run optimized parallel computations on GPUs, allowing efficient implementation of operations common in Transformer architectures. The implementation checks data types and dimensions to support both FP16 and FP8, offering performance tuning via block size and grid configuration.\n", + "output": [ + "import torch\n\nimport triton\nimport triton.language as tl\n\n\ntorch.manual_seed(1234)\n\n\n@triton.jit\ndef rms_matmul_rbe(\n x_ptr, w_ptr, rms_w_ptr, out_ptr,\n M, N, K,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_w_k, stride_w_n,\n stride_rms_w,\n stride_out_batch, stride_out_m, stride_out_n,\n start_token_position,\n USE_FP8: tl.constexpr,\n RBE_EPILOGUE: tl.constexpr,\n THETA: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr,\n):\n \"\"\"\n Prologue: RMS\n Epilogue: nothing or Rotary embeddings\n c = ROBE((rms(a) * rms_w) @ b)\n \"\"\"\n pid_batch = tl.program_id(axis=0)\n pid = tl.program_id(axis=1)\n pid_m = pid // tl.cdiv(N, BLOCK_SIZE_N)\n pid_n = pid % tl.cdiv(N, BLOCK_SIZE_N)\n\n offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M\n offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N\n offs_k = tl.arange(0, BLOCK_SIZE_K)\n x_ptrs = x_ptr + (pid_batch * stride_x_batch + offs_m[:, None] * stride_x_m + offs_k[None, :] * stride_x_k)\n w_ptrs = w_ptr + (offs_k[:, None] * stride_w_k + offs_n[None, :] * stride_w_n)\n accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)\n rms_w_ptrs = rms_w_ptr + tl.arange(0, BLOCK_SIZE_K)[None, :] * stride_rms_w\n x_sum = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)\n for _ in range(0, tl.cdiv(K, BLOCK_SIZE_K)):\n x = tl.load(x_ptrs)\n x_sum += tl.extra.cuda.libdevice.pow(x.to(tl.float32), 2)\n rms_w = tl.load(rms_w_ptrs) # TODO add an assert that rms_w is a multiple of BLOCK SIZE K\n if USE_FP8:\n rms_w = rms_w.to(tl.float8e5, bitcast=True)\n rms_w = rms_w.to(tl.float16)\n x = x * rms_w\n w = tl.load(w_ptrs) # TODO add an assert that w is a multiple of BLOCK SIZE K\n if USE_FP8:\n w = w.to(tl.float8e5, bitcast=True)\n w = w.to(tl.float32)\n w = w.to(tl.float16)\n accumulator += tl.dot(x, w)\n x_ptrs += BLOCK_SIZE_K * stride_x_k\n w_ptrs += BLOCK_SIZE_K * stride_w_k\n rms_w_ptrs += BLOCK_SIZE_K * stride_rms_w\n x_mean = tl.sum(x_sum, axis=1) / K + EPS\n x_norm = tl.math.rsqrt(x_mean)\n accumulator = accumulator * x_norm[:, None]\n\n offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)\n offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)\n out_ptrs = out_ptr + (\n pid_batch * stride_out_batch + offs_m[:, None] * stride_out_m + offs_n[None, :] * stride_out_n)\n out_mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)\n\n tl.store(out_ptrs, accumulator, mask=out_mask)\n\n\n@triton.jit\ndef rms_matmul_rbe_qkv(x_ptr,\n q_weight_ptr, k_weight_ptr, v_weight_ptr,\n rms_w_ptr,\n q_ptr, k_ptr, v_ptr,\n M, N, K,\n stride_x_batch, stride_x_m, stride_x_k,\n stride_q_w_k, stride_q_w_n,\n stride_k_w_k, stride_k_w_n,\n stride_v_w_k, stride_v_w_n,\n stride_rms_w,\n stride_q_batch, stride_q_m, stride_q_n,\n stride_k_batch, stride_k_m, stride_k_n,\n stride_v_batch, stride_v_m, stride_v_n,\n start_token_position,\n USE_FP8: tl.constexpr,\n THETA: tl.constexpr,\n EPS: tl.constexpr,\n BLOCK_SIZE_M: tl.constexpr, BLOCK_SIZE_N: tl.constexpr, BLOCK_SIZE_K: tl.constexpr):\n # q\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=q_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=q_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_q_w_k, stride_w_n=stride_q_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_q_batch, stride_out_m=stride_q_m, stride_out_n=stride_q_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=True,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n # k\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=k_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=k_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_k_w_k, stride_w_n=stride_k_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_k_batch, stride_out_m=stride_k_m, stride_out_n=stride_k_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=True,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n # v\n rms_matmul_rbe(\n x_ptr=x_ptr,\n w_ptr=v_weight_ptr, rms_w_ptr=rms_w_ptr, out_ptr=v_ptr,\n M=M, N=N, K=K,\n stride_x_batch=stride_x_batch, stride_x_m=stride_x_m, stride_x_k=stride_x_k,\n stride_w_k=stride_v_w_k, stride_w_n=stride_v_w_n,\n stride_rms_w=stride_rms_w,\n stride_out_batch=stride_v_batch, stride_out_m=stride_v_m, stride_out_n=stride_v_n,\n start_token_position=start_token_position,\n USE_FP8=USE_FP8,\n RBE_EPILOGUE=False,\n THETA=THETA,\n EPS=EPS,\n BLOCK_SIZE_M=BLOCK_SIZE_M, BLOCK_SIZE_N=BLOCK_SIZE_N, BLOCK_SIZE_K=BLOCK_SIZE_K,\n )\n\n\ndef rms_matmul_rbe_qkv_wrapper(x: torch.Tensor,\n start_pos: int,\n q_weight: torch.Tensor, k_weight: torch.Tensor, v_weight: torch.Tensor,\n rms_w: torch.Tensor,\n n_heads: int, head_dim: int,\n k: torch.Tensor,\n v: torch.Tensor,\n eps: float = 1e-6, theta=10000.):\n assert q_weight.shape == k_weight.shape == v_weight.shape\n assert q_weight.dtype == k_weight.dtype == v_weight.dtype == rms_w.dtype\n assert q_weight.dtype in [torch.float16, torch.int8]\n batch, M, K = x.shape\n\n assert K == rms_w.shape[0]\n\n q_weight_t = q_weight.t()\n k_weight_t = k_weight.t()\n v_weight_t = v_weight.t()\n K_W, N = q_weight_t.shape\n assert K == K_W\n q = torch.empty((batch, M, N), dtype=torch.float16, device=q_weight_t.device)\n\n k = k.view((batch, M, N))\n v = v.view((batch, M, N))\n assert k.dtype == k_weight.dtype\n assert v.dtype == v_weight.dtype\n\n q_ptr = triton.reinterpret(q, tl.float16)\n k_ptr = triton.reinterpret(k, tl.float8e5 if k.dtype == torch.int8 else tl.float16)\n v_ptr = triton.reinterpret(v, tl.float8e5 if v.dtype == torch.int8 else tl.float16)\n\n grid = lambda META: (\n batch, triton.cdiv(META[\"M\"], META[\"BLOCK_SIZE_M\"]) * triton.cdiv(META[\"N\"], META[\"BLOCK_SIZE_N\"]))\n\n rms_matmul_rbe_qkv[grid](\n x_ptr=x,\n q_weight_ptr=q_weight_t, k_weight_ptr=k_weight_t, v_weight_ptr=v_weight_t,\n rms_w_ptr=rms_w,\n q_ptr=q_ptr, k_ptr=k_ptr, v_ptr=v_ptr,\n M=M, N=N, K=K,\n stride_x_batch=x.stride(0), stride_x_m=x.stride(1), stride_x_k=x.stride(2),\n stride_q_w_k=q_weight_t.stride(0), stride_q_w_n=q_weight_t.stride(1),\n stride_k_w_k=k_weight_t.stride(0), stride_k_w_n=k_weight_t.stride(1),\n stride_v_w_k=v_weight_t.stride(0), stride_v_w_n=v_weight_t.stride(1),\n stride_rms_w=rms_w.stride(0),\n stride_q_batch=q.stride(0), stride_q_m=q.stride(1), stride_q_n=q.stride(2),\n stride_k_batch=k.stride(0), stride_k_m=k.stride(1), stride_k_n=k.stride(2),\n stride_v_batch=v.stride(0), stride_v_m=v.stride(1), stride_v_n=v.stride(2),\n start_token_position=start_pos,\n USE_FP8=q_weight.dtype == torch.int8,\n THETA=theta,\n EPS=eps,\n BLOCK_SIZE_M=16, BLOCK_SIZE_N=64, BLOCK_SIZE_K=64,\n num_stages=4, num_warps=4\n )\n q = q.view(batch, M, n_heads, head_dim)\n k = k.view(batch, M, n_heads, head_dim)\n v = v.view(batch, M, n_heads, head_dim)\n return q, k, v\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-f415d2286bf949a7ac2144ccce054427", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\n The code implements a fused recurrent hierarchical gated recurrent network (HGRN) using Triton, which provides efficient GPU kernels. It includes two main kernels: `fused_recurrent_hgrn_fwd_kernel` for the forward pass and `fused_recurrent_hgrn_bwd_kernel` for the backward pass.\n\n The `fused_recurrent_hgrn_fwd_kernel` takes inputs `x` (input sequence), `g` (gates), `o` (output storage), `h0` (optional initial hidden state), `ht` (final state storage), and configuration constants `T`, `D`, `BD`, `USE_INITIAL_STATE`, `STORE_FINAL_STATE`. The kernel computes the output for each time step by iterating over the time dimension `T`, using the formula `b_h = b_g * b_h + b_x`, where `b_x` and `b_g` are loaded from `x` and `g`, respectively. The result is stored in `o`.\n\n The `fused_recurrent_hgrn_bwd_kernel` computes gradients `dx` and `dg` for the inputs `x` and `g`, given the gradient of the output `do`. It backpropagates through time from `T-1` to `0`, updating the gradients using `b_dh = b_dh + b_do`, `b_dx = b_dh`, `b_dg = b_dh * b_o`, and `b_dh = b_dh * b_g`.\n\n The `FusedRecurrentHGRNFunction` wraps these kernels in a PyTorch `autograd.Function` to allow for forward and backward passes during training. Its `forward` method sets up the necessary grid and calls the forward kernel. The `backward` method retrieves saved tensors and calls the backward kernel to compute gradients.\n\n Finally, `fused_recurrent_hgrn` is a user-friendly function that initializes the operation, handling input detachment and configuring the kernels based on whether an initial state is provided and whether the final state should be output. It returns both the computed sequence and optionally the final state of the RNN.\n ", + "output": [ + "\nimport torch\nimport triton\nimport triton.language as tl\nfrom typing import Tuple\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef fused_recurrent_hgrn_fwd_kernel(\n x,\n g,\n o,\n h0,\n ht,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr,\n STORE_FINAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_x = x + i_bh * T * D + o_d\n p_g = g + i_bh * T * D + o_d\n p_o = o + i_bh * T * D + o_d\n\n b_h = tl.zeros([BD], dtype=tl.float32)\n if USE_INITIAL_STATE:\n p_h0 = h0 + i_bh * D + o_d\n b_h += tl.load(p_h0, mask=mask, other=0).to(tl.float32)\n for _ in range(0, T):\n b_x = tl.load(p_x, mask=mask, other=0).to(tl.float32)\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_h = b_g * b_h + b_x\n tl.store(p_o, b_h.to(p_o.dtype.element_ty), mask=mask)\n\n p_x += D\n p_g += D\n p_o += D\n\n if STORE_FINAL_STATE:\n p_ht = ht + i_bh * D + o_d\n tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask)\n\n\n@triton.autotune(\n configs=[\n triton.Config({'BD': 32}, num_warps=1),\n triton.Config({'BD': 32}, num_warps=2),\n triton.Config({'BD': 32}, num_warps=4),\n triton.Config({'BD': 32}, num_warps=8),\n triton.Config({'BD': 64}, num_warps=1),\n triton.Config({'BD': 64}, num_warps=2),\n triton.Config({'BD': 64}, num_warps=4),\n triton.Config({'BD': 64}, num_warps=8),\n triton.Config({'BD': 128}, num_warps=1),\n triton.Config({'BD': 128}, num_warps=2),\n triton.Config({'BD': 128}, num_warps=4),\n triton.Config({'BD': 128}, num_warps=8),\n ],\n key=['D']\n)\n@triton.jit\ndef fused_recurrent_hgrn_bwd_kernel(\n g,\n o,\n dx,\n dg,\n do,\n h0,\n T: tl.constexpr,\n D: tl.constexpr,\n BD: tl.constexpr,\n USE_INITIAL_STATE: tl.constexpr\n):\n i_d, i_bh = tl.program_id(0), tl.program_id(1)\n o_d = i_d * BD + tl.arange(0, BD)\n mask = o_d < D\n\n p_g = g + (i_bh * T + T - 1) * D + o_d\n p_o = o + (i_bh * T + T - 2) * D + o_d\n p_dx = dx + (i_bh * T + T - 1) * D + o_d\n p_dg = dg + (i_bh * T + T - 1) * D + o_d\n p_do = do + (i_bh * T + T - 1) * D + o_d\n\n b_dh = tl.zeros([BD], dtype=tl.float32)\n for i in range(T - 1, -1, -1):\n b_g = tl.load(p_g, mask=mask, other=0).to(tl.float32)\n b_do = tl.load(p_do, mask=mask, other=0).to(tl.float32)\n if i > 0:\n b_o = tl.load(p_o, mask=mask, other=0).to(tl.float32)\n elif USE_INITIAL_STATE:\n b_o = tl.load(h0 + i_bh * D + o_d, mask=mask, other=0).to(tl.float32)\n else:\n b_o = tl.zeros([BD], dtype=tl.float32)\n\n b_dh = b_dh + b_do\n b_dx = b_dh\n b_dg = b_dh * b_o\n b_dh = b_dh * b_g\n tl.store(p_dx, b_dx.to(p_dx.dtype.element_ty), mask=mask)\n tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), mask=mask)\n\n p_g -= D\n p_o -= D\n p_dx -= D\n p_dg -= D\n p_do -= D\n\n\nclass FusedRecurrentHGRNFunction(torch.autograd.Function):\n\n @staticmethod\n def forward(ctx, x, g, initial_state=None, output_final_state=False):\n B, H, T, D = x.shape\n\n final_state = None\n if output_final_state:\n final_state = x.new_empty(B, H, D)\n\n o = torch.empty_like(x)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_fwd_kernel[grid](\n x, g, o, initial_state, final_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n STORE_FINAL_STATE=final_state is not None\n )\n ctx.save_for_backward(g, o, initial_state)\n return o, final_state\n\n @staticmethod\n def backward(ctx, do, dht=None):\n g, o, initial_state = ctx.saved_tensors\n B, H, T, D = do.shape\n\n dx = torch.empty_like(o)\n dg = torch.empty_like(g)\n def grid(meta): return (triton.cdiv(D, meta['BD']), B * H)\n fused_recurrent_hgrn_bwd_kernel[grid](\n g, o, dx, dg, do, initial_state,\n T, D,\n USE_INITIAL_STATE=initial_state is not None,\n )\n\n return dx, dg, None, None\n\n\ndef fused_recurrent_hgrn(\n x: torch.Tensor,\n g: torch.Tensor,\n initial_state: torch.Tensor = None,\n output_final_state: bool = False\n) -> Tuple[torch.Tensor, torch.Tensor]:\n if initial_state is not None:\n initial_state = initial_state.detach()\n o, final_state = FusedRecurrentHGRNFunction.apply(x, g, initial_state, output_final_state)\n return o, final_state\n\n\n\n\n" + ] + }, + { + "id": "openseek-8-8b6d990c0cf14c80888bfe82196da26a", + "input": "You are a expert in writing Triton operators for efficient GPU programming. Use triton language write a kernel and wrapper according following instruction.\nThis Triton script provides a specialized implementation of the softmax function, optimized for GPU computation. It includes both forward and backward passes for inner and non-inner dimensions, using Triton's kernel execution framework with tunable parameters to achieve optimal performance.\n\n### Key Components:\n\n- **Kernel Functions**:\n - `softmax_kernel_non_inner`: This function computes softmax over non-inner dimensions of a tensor. It uses the `TILE_K` and `TILE_N` parameters to define tile sizes, with `ONE_TILE_PER_CTA` indicating if a single tile is processed per CTA.\n - `softmax_kernel_inner`: Similar to the non-inner version, but optimized for inner dimension processing, focusing on maximizing the efficient use of shared memory.\n - `softmax_backward_kernel_non_inner` & `softmax_backward_kernel_inner`: These handle the backpropagation phase, computing the gradient based on the forward pass results. \n\n- **Heuristic Functions**:\n - `heur_tile_k`, `heur_tile_n_non_inner`, etc., determine optimal tile sizes (`TILE_K`, `TILE_N`) based on input dimensions M, N, K. These heuristics help balance the workload across available SMs (Streaming Multiprocessors) on the GPU.\n\n- **Softmax Class**:\n - Provides an autograd function with a custom forward and backward pass for PyTorch integration. The forward method executes the softmax operation using a Triton kernel, and the backward method computes the gradient.\n\n### Execution:\n- For forward pass, `softmax_kernel_non_inner` is invoked when K > 1, otherwise `softmax_kernel_inner` is used. Similar logic applies to the backward pass with the respective backward kernels.\n- Grid execution parameters are dynamically configured using the `grid` lambda functions.\n\n### Inputs/Outputs:\n- **Inputs**: `output_ptr`, `input_ptr`, and dimensions M, N, K. \n- **Outputs**: Computed softmax or gradient stored in `output_ptr` or `in_grad_ptr` respectively.\n\nBy splitting the softmax calculation into tiles and using Triton's parallel execution capabilities, this implementation leverages GPU resources effectively to provide scalable and high-performance softmax operations.\n", + "output": [ + "import logging\n\nimport torch\nimport triton\nimport triton.language as tl\n\n# from ..utils import libentry\n\nMAX_TILE_K = 8192\nNUM_SMS = torch.cuda.get_device_properties(\n torch.cuda.current_device()\n).multi_processor_count\n\n\ndef heur_tile_k(args):\n tile_k = 1\n upper_bound = min(args[\"K\"], MAX_TILE_K)\n while tile_k <= upper_bound:\n num_blocks = args[\"M\"] * triton.cdiv(args[\"K\"], tile_k)\n num_waves = num_blocks / NUM_SMS\n if (num_waves > 1) and (tile_k * 2 <= upper_bound):\n tile_k *= 2\n else:\n break\n return tile_k\n\n\ndef heur_tile_n_non_inner(args):\n return triton.cdiv(8192, args[\"TILE_K\"])\n\n\ndef heur_one_tile_per_cta(args):\n return args[\"TILE_N\"] >= args[\"N\"]\n\n\ndef heur_num_warps_non_inner(args):\n tile_size = args[\"TILE_N\"] * args[\"TILE_K\"]\n if tile_size < 2048:\n return 4\n elif tile_size < 4096:\n return 8\n else:\n return 16\n\n\n@triton.heuristics(\n {\n \"TILE_K\": heur_tile_k,\n \"TILE_N\": heur_tile_n_non_inner,\n \"ONE_TILE_PER_CTA\": heur_one_tile_per_cta,\n \"num_warps\": heur_num_warps_non_inner,\n }\n)\n@triton.jit\ndef softmax_kernel_non_inner(\n output_ptr,\n input_ptr,\n M,\n N,\n K,\n TILE_N: tl.constexpr,\n TILE_K: tl.constexpr,\n ONE_TILE_PER_CTA: tl.constexpr,\n):\n pid_k = tl.program_id(1)\n pid_m = tl.program_id(0)\n\n k_offsets = pid_k * TILE_K + tl.arange(0, TILE_K)\n\n if ONE_TILE_PER_CTA:\n n_offsets = tl.arange(0, TILE_N)\n offset = pid_m * N * K + n_offsets[:, None] * K + k_offsets\n mask = (n_offsets[:, None] < N) & (k_offsets < K)\n input_ptrs = input_ptr + offset\n inp = tl.load(input_ptrs, mask=mask, other=-float(\"inf\"))\n m = tl.max(inp, 0)\n e = tl.exp(inp - m[None, :])\n z = tl.sum(e, 0)\n out = e / z\n output_ptrs = output_ptr + offset\n tl.store(output_ptrs, out, mask=mask)\n else:\n m = tl.full([TILE_N, TILE_K], value=float(\"-inf\"), dtype=tl.float32)\n z = tl.full([TILE_N, TILE_K], value=0.0, dtype=tl.float32)\n\n # specialization does not improve performance inn this example, as tested\n for start_n in range(0, N, TILE_N):\n n_offsets = start_n + tl.arange(0, TILE_N)\n offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets\n mask = (n_offsets[:, None] < N) & (k_offsets < K)\n inp = tl.load(input_ptr + offsets, mask=mask, other=-float(\"inf\"))\n m_new = tl.maximum(m, inp)\n alpha = tl.exp(m - m_new)\n z = z * alpha + tl.exp(inp - m_new)\n m = m_new\n\n m_reduced = tl.max(m, 0) # (TILE_K,)\n z = tl.sum(z * tl.exp(m - m_reduced[None, :]), 0) # (TILE_K, )\n m = m_reduced\n\n # specialization does not improve performance inn this example, as tested\n previous_multiple = prev_multiple_of(N, TILE_N)\n for start_n in range(0, N, TILE_N):\n n_offsets = (previous_multiple - start_n) + tl.arange(0, TILE_N)\n offsets = pid_m * N * K + n_offsets[:, None] * K + k_offsets\n mask = (n_offsets[:, None] < N) & (k_offsets[None, :] < K)\n inp = tl.load(input_ptr + offsets, mask=mask, other=-float(\"inf\"))\n o = tl.exp(inp - m[None, :]) / z[None, :]\n tl.store(output_ptr + offsets, o, mask=mask)\n\n\n@triton.jit\ndef next_multiple_of(a, b):\n # the smallest x>=a that x%b ==0\n return tl.cidv(a, b) * b\n\n\n@triton.jit\ndef prev_multiple_of(a, b):\n # the largest x= -x.ndim and dim < x.ndim, \"Invalid dim\"\n dim = dim % x.ndim\n M = 1\n N = x.shape[dim]\n for i in range(dim):\n M *= x.shape[i] # pre_dim\n inp = x.contiguous()\n if dtype is None:\n dtype = x.dtype\n out = torch.empty_like(inp, dtype=dtype)\n K = inp.numel() // M // N # post_dim\n\n with torch.cuda.device(inp.device):\n if K > 1:\n grid = lambda meta: (M, triton.cdiv(K, meta[\"TILE_K\"]), 1)\n softmax_kernel_non_inner[grid](\n out,\n inp,\n M,\n N,\n K,\n )\n else:\n grid = (M, 1, 1)\n softmax_kernel_inner[grid](\n out,\n inp,\n M,\n N,\n )\n ctx.save_for_backward(out)\n ctx.dim = dim\n return out\n\n @staticmethod\n def backward(ctx, out_grad):\n logging.debug(\"GEMS SOFTMAX VJP\")\n dim = ctx.dim\n (out,) = ctx.saved_tensors\n\n assert dim >= -out.ndim and dim < out.ndim, \"Invalid dim\"\n dim = dim % out.ndim\n M = 1\n N = out.shape[dim]\n for i in range(dim):\n M *= out.shape[i]\n\n out_grad = out_grad.contiguous()\n in_grad = torch.empty_like(out)\n K = out.numel() // M // N\n\n with torch.cuda.device(in_grad.device):\n if K > 1:\n grid = lambda meta: (M, triton.cdiv(K, meta[\"TILE_K\"]), 1)\n softmax_backward_kernel_non_inner[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n K,\n )\n else:\n grid = lambda meta: (triton.cdiv(M, meta[\"TILE_M\"]), 1, 1)\n softmax_backward_kernel_inner[grid](\n out,\n out_grad,\n in_grad,\n M,\n N,\n )\n return in_grad, None, None\n\n\ndef softmax(x, dim=-1, dtype=None):\n return Softmax.apply(x, dim, dtype)\n\n\n\n\n" + ] + } + ], + "test_samples": [ + { + "id": "openseek-8-501f776ba20444458ac14dd7292cc913", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining batch matrix multiplication, RMS normalization, GELU activation, dropout, and subtraction. The function takes three input tensors, performs batch matrix multiplication on the first two, applies RMS normalization, GELU activation, and dropout, and finally subtracts the third tensor from the result.\nWrapper Entry Information: fused_bmm_rmsnorm_gelu_dropout_sub(input1, input2, other, normalized_shape, dropout_p=0.5, training=True, approximate='none', eps=1e-5, *, out=None) -> Tensor. Args: input1 (Tensor): First input tensor for batch matrix multiplication, of shape (B, N, M), where B is the batch size. input2 (Tensor): Second input tensor for batch matrix multiplication, of shape (B, M, P). other (Tensor): Tensor to subtract from the result after dropout, must be broadcastable to the shape of the output. normalized_shape (int or list or torch.Size): Shape over which RMS normalization is applied, typically the size of the last dimension P. dropout_p (float, optional): Probability of an element to be zeroed in the dropout layer. Default: 0.5. training (bool, optional): Apply dropout if True. Default: True. approximate (str, optional): Can be 'none' or 'tanh'. The approximation to use for GELU. Default: 'none'. eps (float, optional): A value added to the denominator for numerical stability in RMS normalization. Default: 1e-5. out (Tensor, optional): Output tensor. Ignored if None. Default: None. Shape: - Input1: (B, N, M), Input2: (B, M, P), Other: broadcastable to (B, N, P). Output: (B, N, P).\nMath: Given input tensors X, Y, and O, this function computes:\n\n\\[\n\\begin{align*}\nZ &= \\text{bmm}(X, Y) \\\\\nZ_{\\text{norm}} &= \\text{RMSNorm}(Z, \\epsilon) \\\\\nG &= \\text{GELU}(Z_{\\text{norm}}) \\\\\nD &= \\text{Dropout}(G, p) \\\\\nY &= D - O\n\\end{align*}\n\\]\n\nwhere:\n\n- \\text{bmm}(X, Y) performs batch matrix multiplication.\n- \\text{RMSNorm}(Z, \\epsilon) = \\frac{Z}{\\sqrt{\\text{mean}(Z^2, \\text{dim}=-1) + \\epsilon}} applies Root Mean Square Layer Normalization over the last dimension.\n- \\text{GELU}(Z_{\\text{norm}}) applies the Gaussian Error Linear Unit activation function element-wise.\n- \\text{Dropout}(G, p) randomly zeroes elements of G with probability p.\n- D - O subtracts tensor O from D, where O must be broadcastable to the shape of D.\nother: - The shapes of `input1` and `input2` must be compatible for batch matrix multiplication: `input1` of shape `(B, N, M)` and `input2` of shape `(B, M, P)` result in an output of shape `(B, N, P)`.\n- The `normalized_shape` argument for RMS normalization should match the dimensions over which to compute the RMS. For an output of shape `(B, N, P)`, setting `normalized_shape=P` applies normalization over the last dimension.\n- The `other` tensor must be broadcastable to the shape of the output tensor after dropout.\n- The `dropout` is applied during training when `training=True`. Set `training=False` to disable dropout during evaluation.\n- The `GELU` activation is applied element-wise to the normalized output.\n- All operations are differentiable and support autograd.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-82c29f05c3434437917c95f49fadff01", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Divides each element of the input tensor by the corresponding element of the other tensor, supporting broadcasting, type promotion, and handling integer, float, and complex inputs. Rounding behavior can be controlled with the rounding_mode parameter.\nWrapper Entry Information: div(input, other, *, rounding_mode=None, out=None) -> Tensor; input (Tensor): the dividend; other (Tensor or Number): the divisor; rounding_mode (str, optional): Type of rounding applied to the result; out (Tensor, optional): the output tensor\nMath: \\text{out}_i = \\frac{\\text{input}_i}{\\text{other}_i}\nother: By default, performs a 'true' division like Python 3. Supports broadcasting to a common shape, type promotion, and integer, float, and complex inputs. Always promotes integer types to the default scalar type.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-76a66f9a2bd5449fbb57b2b0a0bd7ec7", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over an input tensor with specified filters, followed by applying the sigmoid activation function element-wise to the result. This ensures that the convolutional output values are scaled between 0 and 1.\nWrapper Entry Information: sigmoid_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, out=None) -> Tensor\nArgs:\ninput (Tensor): The input tensor of shape `(minibatch, in_channels, iH, iW)`.\nweight (Tensor): The convolution filters of shape `(out_channels, in_channels / groups, kH, kW)`.\nbias (Tensor, optional): Optional bias tensor of shape `(out_channels)`. Default: None.\nstride (int or tuple, optional): The stride of the convolution kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1.\npadding (int, tuple, or string, optional): Padding on both sides of the input. Can be 'valid', 'same', single number, or tuple `(padH, padW)`. Default: 0.\ndilation (int or tuple, optional): The spacing between kernel elements. Default: 1.\ngroups (int, optional): Number of groups to split the input into. Default: 1.\nout (Tensor, optional): The output tensor.\nMath: \\text{out} = \\sigma(\\text{conv2d}(\\text{input}, \\text{weight})) where \\sigma(x) = \\frac{1}{1 + e^{-x}} is the sigmoid function.\nother: The function combines 2D convolution and sigmoid activation, ensuring output values are between 0 and 1.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-616ca19cad034c8ba1763cbd4420b620", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Solves multiple linear systems with the same coefficient matrix using LU decomposition. Given a square matrix A and multiple right-hand side vectors B, this function computes the solutions X to the linear systems A X = B by performing the LU decomposition of A and reusing it to solve for multiple right-hand sides efficiently. Supports batch dimensions.\nWrapper Entry Information: def solve_multiple_lu(A, Bs, *, pivot=True, out=None) -> Tensor\n\n - **A** (Tensor): Coefficient matrix of shape `(*, n, n)`, where `*` is zero or more batch dimensions.\n - **Bs** (Tensor): Right-hand side tensor of shape `(*, n, k)`, where `k` is the number of right-hand sides.\n - **pivot** (bool, optional): Controls whether to compute the LU decomposition with partial pivoting (`True`) or without pivoting (`False`). Default: `True`.\n - **out** (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: LU Decomposition:\nA = P L U\n- P is a permutation matrix.\n- L is a lower triangular matrix with unit diagonal elements.\n- U is an upper triangular matrix.\nSolving Linear Systems:\nFor each b_i in B, compute x_i by:\n- Apply permutation: b'_i = P^T b_i\n- Solve L y_i = b'_i\n- Solve U x_i = y_i\nother: This function efficiently reuses the LU decomposition of A to solve multiple linear systems with different right-hand sides. If `pivot=False`, no permutation is applied. Supports batch dimensions.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-7cc09b51ea774c9d9e44cd680d32435c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the hyperbolic tangent of the elements of the input tensor.\nWrapper Entry Information: tanh(input, *, out=None) -> Tensor Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\tanh(\\text{input}_{i})\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-a82af84ea5a14dc9b58d50f504ec8f5e", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the rectified linear unit (ReLU) function to each element in input, and then computes the square root of the result. This function ensures all negative values in input are set to zero before applying the square root.\nWrapper Entry Information: def relu_sqrt(input, inplace=False, out=None) -> Tensor: input (Tensor): The input tensor. inplace (bool, optional): If True, modifies input in-place (if possible). Default is False. out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = \\sqrt{\\max(0, \\text{input}_i)}\nother: The function modifies input in-place if inplace is set to True.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-215a58cbaf6d4e96a69284d61aeeaf3c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the square-root of the elements of the input tensor. It computes the square root element-wise.\nWrapper Entry Information: sqrt(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\sqrt{\\text{input}_{i}}\nother: The function can handle negative inputs, resulting in NaN for those elements.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-521e38ea57b1490a96e6dc76ff2f57b9", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the sigmoid (logistic) function to each element in the input and then computes the indices of the maximum values along the specified dimension or over all elements if no dimension is specified. If dim is not specified, it returns the index of the maximum value in the flattened tensor.\nWrapper Entry Information: sigmoid_argmax(input, dim=None, keepdim=False) -> LongTensor: input (Tensor): The input tensor. dim (int, optional): The dimension to reduce. Default is None, which computes the argmax over all elements. keepdim (bool, optional): Whether the output tensor has :attr:`dim` retained or not. Default is False.\nMath: sigmoid(x) = 1 / (1 + e^{-x})\nother: The function uses PyTorch tensor operations and returns a LongTensor containing indices.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b41b0e3a84e4430887282bc3faed8b81", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`. The operation is defined as: out_i = input_i - alpha * other_i. Supports broadcasting to a common shape, type promotion, and works with integer, float, and complex inputs.\nWrapper Entry Information: sub(input, other, *, alpha=1, out=None) -> Tensor; input (Tensor): the input tensor.; other (Tensor or Number): the tensor or number to subtract from input.; alpha (Number): the multiplier for other.; out (Tensor, optional): the output tensor.\nMath: out_i = input_i - alpha * other_i\nother: Supports broadcasting, type promotion, and works with integer, float, and complex inputs.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-ca9d997e5aef49cf8d0bbf48b8a22fbd", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes output using input values and pixel locations from grid, supporting spatial (4-D) and volumetric (5-D) input. Interpolates output value at specified grid positions using nearest or bilinear interpolation. Grid values are normalized within [-1, 1] range, and values outside are handled by padding_mode. Often used with affine_grid to build Spatial Transformer Networks.\nWrapper Entry Information: def grid_sample(input, grid, mode='bilinear', padding_mode='zeros', align_corners=False) -> Tensor\nother: Note: NaN values in grid are interpreted as -1. align_corners=True changes sampled grid positions with image resolution. Default for align_corners changed to False since version 1.2.0. bicubic mode implemented using cubic convolution algorithm with alpha=-0.75; other packages might use different alpha values.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-757113b30aed48eabfecadffd0aa1118", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the singular value decomposition (SVD) of a matrix. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions. The returned decomposition is a named tuple (U, S, Vh) which corresponds to U, S, V^{H} above. The singular values are returned in descending order. The parameter full_matrices chooses between the full (default) and reduced SVD. The driver kwarg may be used in CUDA with a cuSOLVER backend to choose the algorithm used to compute the SVD. The choice of a driver is a trade-off between accuracy and speed.\nWrapper Entry Information: def linalg.svd(A, full_matrices=True, *, driver=None, out=None) -> (Tensor, Tensor, Tensor)\n\nArgs:\n A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.\n full_matrices (bool, optional): controls whether to compute the full or reduced SVD, and consequently, the shape of the returned tensors `U` and `Vh`. Default: `True`.\n\nKeyword args:\n driver (str, optional): name of the cuSOLVER method to be used. This keyword argument only works on CUDA inputs. Available options are: `None`, `gesvd`, `gesvdj`, and `gesvda`. Default: `None`.\n out (tuple, optional): output tuple of three tensors. Ignored if `None`.\nMath: A = U \\operatorname{diag}(S) V^{\\text{H}} \\mathrlap{\\qquad U \\in \\mathbb{K}^{m \\times m}, S \\in \\mathbb{R}^k, V \\in \\mathbb{K}^{n \\times n}}\nother: Differences with numpy.linalg.svd: Unlike numpy.linalg.svd, this function always returns a tuple of three tensors and it doesn't support compute_uv argument. Please use torch.linalg.svdvals, which computes only the singular values, instead of compute_uv=False. When full_matrices=True, the gradients with respect to U[..., :, min(m, n):] and Vh[..., min(m, n):, :] will be ignored, as those vectors can be arbitrary bases of the corresponding subspaces. The returned tensors U and V are not unique, nor are they continuous with respect to A. Gradients computed using U or Vh will only be finite when A does not have repeated singular values.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-188273dd82f7465dafa78d1430aeb9ee", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the zeroth order modified Bessel function of the first kind for each element of the input tensor.\nWrapper Entry Information: i0(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor; Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = I_0(\\text{input}_{i}) = \\sum_{k=0}^{\\infty} \\frac{(\\text{input}_{i}^2/4)^k}{(k!)^2}\nother: The function calculates the zeroth order modified Bessel function of the first kind, which is a special mathematical function.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-e7f5dd02bec34352add8c7935b6d790f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the reciprocal of the square-root of each of the elements of the input tensor.\nWrapper Entry Information: rsqrt(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor.; Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\frac{1}{\\sqrt{\\text{input}_{i}}}\nother: Note: The function will return 'nan' for negative input values.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b36ca7e6da114f799ec8f9feaf26a769", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution followed by batch normalization, ReLU activation, and dropout. Sequentially applies conv2d, batch normalization for stabilizing training and reducing internal covariate shift, ReLU activation function, and dropout where some elements of the tensor are randomly zeroed with probability `p`.\nWrapper Entry Information: dropout_relu_batch_norm_conv2d(input: torch.Tensor, weight: torch.Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1, p=0.5, training=True, inplace=False) -> torch.Tensor; Args: input (Tensor): Input tensor of shape \\(N, C_{in}, H, W\\). weight (Tensor): Convolution filters of shape \\(C_{out}, C_{in} / \\text{groups}, kH, kW\\). bias (Tensor, optional): Bias tensor of shape \\(C_{out}\\). Default is None. stride (int or tuple, optional): Stride of the convolution. Default: 1 padding (int, tuple, or str, optional): Implicit padding on both sides of the input. Default: 0 dilation (int or tuple, optional): Spacing between kernel elements. Default: 1 groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1 p (float, optional): Probability of an element to be zeroed in dropout. Default: 0.5 training (bool, optional): If True, applies dropout during training. Default: True inplace (bool, optional): If True, performs the operation in-place. Default: False\nother: Output tensor is returned after applying conv2d, batch normalization, ReLU, and dropout.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-d4a55c1a2818498b9907bdaf461f3d0c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining matrix-vector multiplication, log-softmax activation, and dropout. The function first performs matrix-vector multiplication on the input matrix and vector. The result is then passed through a log-softmax activation function along the specified dimension. Finally, dropout is applied to the output of the log-softmax operation.\nWrapper Entry Information: fused_mv_logsoftmax_dropout(input, vec, p=0.5, training=True, inplace=False, dim=0, *, out=None) -> Tensor\nMath: Given an input matrix A ∈ ℝ^(n × m) and a vector v ∈ ℝ^m, the function computes:\n\nz = A * v\ns = log(exp(z) / ∑_j exp(z_j))\ny = Dropout(s, p)\n\nwhere log(exp(z) / ∑_j exp(z_j)) is the log-softmax function applied along dimension `dim`, and Dropout(s, p) randomly zeroes elements of s with probability p.\nother: - The shapes of `input` and `vec` must be compatible for matrix-vector multiplication: the number of columns in `input` must match the size of `vec`.\n- The `dim` argument in `log_softmax` specifies the dimension along which the log-softmax is computed. Since `z` is a 1-D tensor of shape `(n,)`, `dim` should be `0` or `-1`.\n- The `dropout` is applied during training when `training=True`. Set `training=False` to disable dropout during evaluation.\n- This function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-ddceed268b2546188db6e761c68b9522", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Adds the tensor or number 'other', scaled by 'alpha', to the 'input' tensor. Supports broadcasting to a common shape, type promotion, and accepts integer, float, and complex inputs.\nWrapper Entry Information: add(input, other, *, alpha=1, out=None) -> Tensor; input (Tensor): the input tensor.; other (Tensor or Number): the tensor or number to add to input.; alpha (Number): the multiplier for other.; out (Tensor, optional): the output tensor.\nMath: \\text{{out}}_i = \\text{{input}}_i + \\text{{alpha}} \\times \\text{{other}}_i\nother: Supports broadcasting and type promotion.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1a5b486bc85e4509a30bba465ae7a0f4", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies 2D Convolution, followed by Layer Normalization and SiLU activation to the input tensor `x`. Sequentially performs convolution on `x`, then applies layer normalization on the convolution output, followed by SiLU activation applied element-wise.\nWrapper Entry Information: fused_silu_layer_norm_conv2d(x: torch.Tensor, weight: torch.Tensor, conv_weight: torch.Tensor, conv_bias: torch.Tensor = None, conv_stride: int = 1, conv_padding: int = 0, conv_dilation: int = 1, conv_groups: int = 1, ln_eps: float = 1e-5) -> torch.Tensor\n\nArgs:\n x (Tensor): Input tensor for convolution, normalization, and activation.\n weight (Tensor): Learnable weight of size matching normalized output dimensions for LayerNorm.\n conv_weight (Tensor): Convolution kernel tensor of appropriate dimensions.\n conv_bias (Tensor, optional): Convolution bias tensor. Default: ``None``.\n conv_stride (int, optional): Stride of convolution. Default: 1.\n conv_padding (int, optional): Padding added to both sides of input. Default: 0.\n conv_dilation (int, optional): Dilation of convolution kernel. Default: 1.\n conv_groups (int, optional): Number of groups for convolution. Default: 1.\n ln_eps (float, optional): Epsilon value for Layer Normalization. Default: 1e-5.\n\nExample:\n >>> x = torch.randn(4, 3, 32, 32)\n >>> conv_weight = torch.randn(8, 3, 3, 3)\n >>> conv_bias = torch.zeros(8)\n >>> weight = torch.ones(8)\n >>> output = fused_silu_layer_norm_conv2d(x, weight, conv_weight, conv_bias, conv_stride=1, conv_padding=1)\n >>> print(output.shape)\n torch.Size([4, 8, 32, 32])\nother: Convolution operation parameters include stride, padding, dilation, and groups. Layer Normalization uses an epsilon value. Default values are provided for optional parameters.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1cc25388256c4207b53289a81921b12c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining index selection and element-wise equality comparison. It selects elements from the input tensor along a specified dimension using provided indices and then performs an element-wise equality comparison between the selected elements and another tensor or scalar. The result is a boolean tensor of the same shape as the selected elements, indicating where the comparisons are true.\nWrapper Entry Information: fused_index_select_eq(input, dim, index, other, *, out=None) -> Tensor. Args: input (Tensor): The input tensor X. dim (int): The dimension along which to index. index (IntTensor or LongTensor): The indices to select along dimension dim. other (Tensor or float): The tensor or value Y to compare with the selected tensor. out (Tensor, optional): Output tensor. Ignored if None. Default: None\nMath: Given an input tensor X, dimension \text{dim}, index tensor I, and another tensor or scalar Y, the function computes:\n\n1. **Index Selection:**\n\nSelect elements from X along dimension \text{dim} using indices I:\n\n\\[\nS = \\text{index\\_select}(X, \\text{dim}, I)\n\\]\n\n2. **Element-wise Equality Comparison:**\n\nCompare the selected tensor S with Y element-wise:\n\n\\[\nO = (S == Y)\n\\]\n\nThe output tensor O is a boolean tensor of the same shape as S.\nother: - The shapes of the selected tensor S and other must be broadcastable for the element-wise comparison.\n- If other is a scalar, it is broadcasted to the shape of S.\n- The function supports autograd for gradient computation, although the output is a boolean tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-bfb26a8289784475a2c5132dd723f6b3", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns the indices of the maximum values of a tensor across a specified dimension. If the dimension is None, it returns the index of the maximum value in the flattened input tensor. The output tensor can retain the reduced dimension if keepdim is set to True.\nWrapper Entry Information: argmax(input, dim, keepdim=False) -> LongTensor\nArgs:\n input (Tensor): the input tensor.\n dim (int): the dimension to reduce. If ``None``, the argmax of the flattened input is returned.\n keepdim (bool): whether the output tensor has :attr:`dim` retained or not.\nother: This is the second value returned by torch.max. See its documentation for the exact semantics of this method.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-d8c481c6232b4f68baba55a9f6fcfa8f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the solution `x` to the equation `Ax = b` using LU decomposition. Given matrix `A`, this function performs LU decomposition and then solves for `x` in `L @ U @ x = b`, where `P`, `L`, and `U` are derived from the LU decomposition.\nWrapper Entry Information: def fused_lu_solve(A: Tensor, b: Tensor) -> Tensor: A: The input matrix `A` of shape `(n, n)`. b: The right-hand side tensor `b` of shape `(n,)`.\nMath: Solves `Ax = b` using LU decomposition, where `A = P @ L @ U` and `L @ U @ x = b`.\nother: The function uses LU decomposition to solve linear equations.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-45c89b4e8cef4315bfaa885afd98c669", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the pairwise distance between `x1` and `x2` using the specified norm, then normalizes the resulting distances along the specified dimension. This combined operation is useful for obtaining normalized distance values between two sets of vectors.\nWrapper Entry Information: normalize_pairwise_distance(x1, x2, p_distance=2.0, eps_distance=1e-6, keepdim=False, p_norm=2, dim_norm=1, eps_norm=1e-12) -> Tensor; x1 (Tensor): The first input tensor; x2 (Tensor): The second input tensor, must have the same shape as `x1`; p_distance (float): The norm degree for computing the pairwise distance. Default: 2.0; eps_distance (float): Small value to avoid division by zero in pairwise distance calculation. Default: 1e-6; keepdim (bool): Whether to keep the reduced dimensions in the output. Default: False; p_norm (float): The exponent value in the norm formulation for normalization. Default: 2; dim_norm (int): The dimension along which normalization is applied. Default: 1; eps_norm (float): Small value to avoid division by zero in normalization. Default: 1e-12\nMath: \\text{distance} = \\frac{\\text{pairwise\\_distance}(x1, x2)}{\\max(\\lVert \\text{pairwise\\_distance}(x1, x2) \\rVert_p, \\epsilon)}\nother: The combined operation is useful for obtaining normalized distance values between two sets of vectors.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-eed93a70ff1546d8aa68ef247bc922bd", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a namedtuple (values, indices) where values is the maximum value of each row of the input tensor in the given dimension dim. Indices is the index location of each maximum value found (argmax). If keepdim is True, the output tensors are of the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed, resulting in the output tensors having 1 fewer dimension than input. If there are multiple maximal values in a reduced row, the indices of the first maximal value are returned.\nWrapper Entry Information: max(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) input (Tensor): the input tensor. dim (int): the dimension to reduce. keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Default: ``False``. out (tuple, optional): the result tuple of two output tensors (max, max_indices).\nother: If there are multiple maximal values in a reduced row then the indices of the first maximal value are returned.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f7521a91f83c43da8791d9f29ea31535", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a linear transformation to the input tensor followed by the log_softmax activation function. This combined operation is optimized to be numerically stable and efficient, applying both a linear transformation and log-softmax in one step.\nWrapper Entry Information: log_softmax_linear(input, weight, bias=None, dim=-1, dtype=None) -> Tensor: input (Tensor): The input tensor of shape `(*, in_features)`, where `*` represents any number of additional dimensions. weight (Tensor): The weight matrix of shape `(out_features, in_features)`. bias (Tensor, optional): The optional bias tensor of shape `(out_features)`. Default: None. dim (int): The dimension along which log_softmax will be computed. Default: -1. dtype (:class:`torch.dtype`, optional): The desired data type of the returned tensor. If specified, the input tensor is cast to :attr:`dtype` before the operation. Default: None.\nMath: \\text{out} = \\log\\left(\\frac{\\exp(\\text{linear}(\\text{input}))}{\\sum_j \\exp(\\text{linear}(\\text{input})_j)}\\right) y = xA^T + b\nother: The values along the specified dimension represent log probabilities and sum to 1.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3a629a95117a4ca18edda2c3bc560fc0", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the rectified linear unit function element-wise. This operation compares each element in the input tensor to zero and returns the element itself if it is greater than zero or zero otherwise. The operation can be performed in-place, modifying the input tensor directly if inplace=True.\nWrapper Entry Information: relu(input, inplace=False) -> Tensor\n\nArgs:\n inplace: can optionally do the operation in-place. Default: False\n\nShape:\n - Input: (*), where * means any number of dimensions.\n - Output: (*), same shape as the input.\nMath: ReLU(x) = (x)^+ = max(0, x)\nother: See torch.nn.ReLU for more details.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-495afd5d58c84aff9fd26367f8c28f40", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Solves the least squares problem for an overdetermined system of linear equations using QR decomposition. It computes the least squares solution x that minimizes the Euclidean 2-norm |Ax - b|_2, where A is the coefficient matrix and b is the right-hand side vector or matrix.\nWrapper Entry Information: def least_squares_qr(A, b, *, mode='reduced', out=None) -> Tensor: A (Tensor): Coefficient matrix of shape (*, m, n), where * is zero or more batch dimensions. b (Tensor): Right-hand side vector or matrix of shape (*, m) or (*, m, k), where k is the number of right-hand sides. mode (str, optional): Determines the type of QR decomposition to use. One of 'reduced' (default) or 'complete'. See torch.linalg.qr for details. out (Tensor, optional): Output tensor. Ignored if None. Default: None.\nMath: The QR decomposition of A is given by A = QR, where Q is a matrix with orthonormal columns and R is an upper triangular matrix. The least squares solution is x = R^{-1} Q^H b.\nother: The function utilizes QR decomposition to efficiently solve overdetermined linear systems by finding the least squares solution.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-9e761bfd59194360afd2b637f96c63a7", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the determinant of a square matrix using QR decomposition. It performs QR decomposition of a square matrix A in \\mathbb{K}^{n \times n} (where \\mathbb{K} is either \\mathbb{R} or \\mathbb{C}) and computes the determinant by taking the product of the diagonal elements of R.\nWrapper Entry Information: determinant_via_qr(A, *, mode='reduced', out=None) -> Tensor\nMath: The QR decomposition of A is: A = Q R, where Q is an orthogonal/unitary matrix, R is an upper triangular matrix. The determinant is given by: \\det(A) = \\det(Q)\\cdot \\prod_{i=1}^{n} R_{ii}. For real matrices, \\det(Q) = \\pm 1. For complex matrices, |\\det(Q)| = 1.\nother: Numerical stability considerations are important, especially for ill-conditioned matrices. The function explicitly computes \\det(Q) to account for the sign. For complex matrices, the result may be complex.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-cc45f93ffa22476c8063c1ef6a8207d6", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining tiling (repeating elements) and the exponential function. The input tensor is first repeated along each dimension according to the specified `dims` using the tiling operation, then the exponential function is applied element-wise to the resulting tensor.\nWrapper Entry Information: fused_tile_exp(input, dims, *, out=None) -> Tensor; input (Tensor): The input tensor X whose elements are to be repeated and exponentiated.; dims (tuple of int): The number of repetitions for each dimension. If `dims` has fewer dimensions than `input`, ones are prepended to `dims` until all dimensions are specified.; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: Given an input tensor X and a tuple of dimensions \text{dims}, the function computes:\n1. **Tiling:**\nThe input tensor is repeated along each dimension according to the specified number of times in `dims`:\nY = tile(X, dims)\n2. **Exponential Function:**\nThe exponential function is applied element-wise to the tiled tensor:\nZ = exp(Y)\nother: The `dims` parameter controls how many times the input tensor is repeated along each dimension.\nIf `dims` specifies fewer dimensions than `input`, ones are prepended to `dims` until all dimensions are specified.\nThe function supports autograd for gradient computation.\nAll operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f08e0eca68224df896e459818a75b5e3", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the square root of each element in the input tensor, and then applies the hyperbolic tangent (tanh) function to the square-rooted values. The function returns a tensor where each element is the result of applying sqrt followed by tanh to each element of the input.\nWrapper Entry Information: def sqrt_tanh(input, out=None) -> Tensor: input (Tensor): The input tensor. out (Tensor, optional): The output tensor.\nMath: \\text{out}_{i} = \\tanh(\\sqrt{\\text{input}_{i}})\nother: Using a tensor with some negative values results in NaN for those elements.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-5b28a4d5afff45b1879d276494830c51", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies Batch Normalization over an input tensor across channels, followed by the Sigmoid Linear Unit (SiLU) activation function applied element-wise. This combined operation normalizes the input tensor and then applies a non-linear SiLU activation.\nWrapper Entry Information: silu_batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-5) -> Tensor; input (Tensor): The input tensor for Batch Normalization.; running_mean (Tensor): The running mean tensor (used during evaluation).; running_var (Tensor): The running variance tensor (used during evaluation).; weight (Tensor, optional): The weight tensor for Batch Normalization scaling. Default: None.; bias (Tensor, optional): The bias tensor for Batch Normalization. Default: None.; training (bool, optional): Whether the module is in training mode. Default: False.; momentum (float, optional): Value used for the running mean and variance computation. Default: 0.1.; eps (float, optional): A small value added to the denominator for numerical stability. Default: 1e-5.\nMath: The combined operation is defined as: \\text{out} = \\text{silu}(\\text{BatchNorm}(x)), where the SiLU function is defined as: \\text{silu}(x) = x * \\sigma(x), \\text{where } \\sigma(x) = \\frac{1}{1 + \\exp(-x)}\nother: Returns: A tensor that has undergone batch normalization and SiLU activation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-781a002b0f944a9989a836c8da9dec47", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Fills the elements of the self tensor with a specified value by selecting the indices in the order given in the index tensor. The operation is performed along a specified dimension.\nWrapper Entry Information: index_fill_(dim, index, value) -> Tensor\n\nArgs:\n dim (int): dimension along which to index\n index (LongTensor): indices of :attr:`self` tensor to fill in\n value (float): the value to fill with\n\nExample::\n >>> x = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]], dtype=torch.float)\n >>> index = torch.tensor([0, 2])\n >>> x.index_fill_(1, index, -1)\n tensor([[-1., 2., -1.],\n [-1., 5., -1.],\n [-1., 8., -1.]])\nother: The function modifies the tensor in-place.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6fc781d88f5743efaacedd57d654066f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining cross-entropy loss computation, softmax activation, and layer normalization. It computes the cross-entropy loss for given logits and targets, applies softmax activation to the logits, and then applies layer normalization to the resulting probabilities.\nWrapper Entry Information: fused_cross_entropy_softmax_layernorm(logits, targets, normalized_shape, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0, eps=1e-5, *, out=None) -> Tuple[Tensor, Tensor] - logits (Tensor): Input logits of shape (N, C) or (N, C, *), where N is the batch size and C is the number of classes. - targets (Tensor): Ground truth class indices or class probabilities. If containing class indices: shape (N) or (N, *) with values 0 <= targets_i < C. If containing class probabilities: same shape as logits and values in [0, 1]. - normalized_shape (int or list or torch.Size): Input shape over which layer normalization is applied. - weight (Tensor, optional): A manual rescaling weight given to each class. If provided, must be of size C. - ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. Default: -100. - reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'. - label_smoothing (float, optional): A float in [0.0, 1.0] specifying the amount of smoothing when computing the loss. Default: 0.0. - eps (float, optional): A value added to the denominator for numerical stability in layer normalization. Default: 1e-5. - out (Tensor, optional): Output tensor for the normalized probabilities. Ignored if None. Default: None.\nMath: Given input logits \\mathbf{z} and target labels \\mathbf{y}, the function computes:\n\n1. **Cross-Entropy Loss:**\n\nFor target class indices:\n\n\\[\nL = -\\sum_{i=1}^N \\log\\left( \\frac{e^{z_{i, y_i}}}{\\sum_{j=1}^C e^{z_{i, j}}} \\right)\n\\]\n\nFor target class probabilities:\n\n\\[\nL = -\\sum_{i=1}^N \\sum_{j=1}^C y_{i, j} \\log\\left( \\frac{e^{z_{i, j}}}{\\sum_{k=1}^C e^{z_{i, k}}} \\right)\n\\]\n\n2. **Softmax Activation:**\n\n\\[\n\\mathbf{p}_{i, j} = \\text{Softmax}(z_{i, j}) = \\frac{e^{z_{i, j}}}{\\sum_{k=1}^C e^{z_{i, k}}}\n\\]\n\n3. **Layer Normalization:**\n\n\\[\n\\mathbf{o}_i = \\frac{\\mathbf{p}_i - \\mu_i}{\\sqrt{\\sigma_i^2 + \\epsilon}} \\times \\gamma + \\beta\n\\]\n\nwhere:\n\n- \\mu_i and \\sigma_i^2 are the mean and variance of \\mathbf{p}_i over the dimensions specified by `normalized_shape`.\n- \\gamma and \\beta are learnable parameters (omitted if `elementwise_affine=False`).\n- \\epsilon is a small value added for numerical stability.\nother: - The `logits` tensor should contain raw, unnormalized scores for each class.\n- The `targets` can be class indices or class probabilities matching the shape of `logits`.\n- The `normalized_shape` argument in `layer_norm` should correspond to the dimensions over which you want to apply normalization.\n- If `elementwise_affine` parameters (`weight` and `bias`) are needed in `layer_norm`, they can be defined and passed accordingly.\n- All operations support autograd for gradient computation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6895c51d60b848399f2d206e16b9c587", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns the mean value of each row of the input tensor in the given dimension dim. If dim is a list of dimensions, reduce over all of them. If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed, resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).\nWrapper Entry Information: input (Tensor): the input tensor. dim (int or tuple of ints): the dimension or dimensions to reduce. keepdim (bool): whether the output tensor has dim retained or not. dtype (torch.dtype, optional): the desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None. out (Tensor, optional): the output tensor.\nother: See also torch.nanmean which computes the mean value of non-NaN elements.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-2a6dd4dd9e43480182989a6b5a5cac1d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the eigenvalue decomposition of a square matrix if it exists. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions. The returned eigenvalues are not guaranteed to be in any specific order. The eigenvalues and eigenvectors of a real matrix may be complex. When inputs are on a CUDA device, this function synchronizes that device with the CPU. Assumes that A is diagonalizable. The returned eigenvectors are normalized to have norm 1. The eigenvectors of a matrix are not unique, nor are they continuous with respect to A. Gradients computed using the eigenvectors tensor will only be finite when A has distinct eigenvalues.\nWrapper Entry Information: def linalg.eig(A, *, out=None) -> (Tensor, Tensor) Args: A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of diagonalizable matrices. Keyword args: out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`.\nMath: A = V \\operatorname{diag}(\\Lambda) V^{-1}\\mathrlap{\\qquad V \\in \\mathbb{C}^{n \\times n}, \\Lambda \\in \\mathbb{C}^n}\nother: The eigenvalues and eigenvectors of a real matrix may be complex. When inputs are on a CUDA device, this function synchronizes that device with the CPU. Assumes that A is diagonalizable. The returned eigenvectors are normalized to have norm 1. The eigenvectors of a matrix are not unique, nor are they continuous with respect to A. Gradients computed using the eigenvectors tensor will only be finite when A has distinct eigenvalues.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-71075dbf0e0a4c1592be4985ebd1ba1b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function computes the logarithm of the sum of exponentials of input elements along the specified dimension. It is useful for numerical stability when computing log probabilities.\nWrapper Entry Information: def logsumexp(input, dim, keepdim=False, *, out=None) -> Tensor\nMath: logsumexp(x) = log(sum(exp(x)))\nother: Alias for torch.logsumexp.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-61ae6f57db984c43b3e82cbaa5f6753f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining embedding lookup, element-wise addition, and tanh activation. The function retrieves embeddings from an embedding matrix using input indices, adds another tensor to these embeddings, and applies a tanh activation function to the result. It supports options for padding indices, max norm for embeddings, scaling gradients by frequency, and sparse gradients.\nWrapper Entry Information: fused_embedding_add_tanh(input_indices, weight, other, *, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, out=None) -> Tensor; input_indices (LongTensor): Tensor containing indices into the embedding matrix, of arbitrary shape (*); weight (Tensor): The embedding matrix of shape (V, D), where V is the number of embeddings (vocabulary size), and D is the embedding dimension; other (Tensor): Tensor to be added to the embeddings, must be broadcastable to the shape of E; padding_idx (int, optional): If specified, the entries at `padding_idx` do not contribute to the gradient; max_norm (float, optional): If given, each embedding vector with norm larger than `max_norm` is renormalized to have norm `max_norm`; norm_type (float, optional): The p-norm to compute for the `max_norm` option. Default: `2.0`; scale_grad_by_freq (bool, optional): If `True`, scale gradients by the inverse of frequency of the words in the mini-batch. Default: `False`; sparse (bool, optional): If `True`, gradient w.r.t. `weight` will be a sparse tensor. Default: `False`; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`\nMath: Given input indices \\mathbf{i}, embedding weight matrix W, and tensor O, the function computes:\n\\[\n\\begin{align*}\nE &= \\text{Embedding}(\\mathbf{i}, W) \\\\\nS &= E + O \\\\\nY &= \\tanh(S)\n\\end{align*}\n\\]\nother: - The `other` tensor must be broadcastable to the shape of the embeddings retrieved by `torch.nn.functional.embedding`.\n- All parameters related to `torch.nn.functional.embedding` are passed through to allow for options like `padding_idx`, `max_norm`, etc.\n- This function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b76244965a1e4ee78c7a2e17f1ab9fdb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining matrix-vector multiplication, sigmoid activation, and subtraction.\nWrapper Entry Information: fused_mv_sigmoid_sub(input, vec, other, alpha=1, *, out=None) -> Tensor; input (Tensor): Input matrix A of shape (n, m); vec (Tensor): Input vector \\mathbf{v} of shape (m); other (Tensor or Number): Tensor or scalar b to subtract from the sigmoid output, scaled by \\alpha; alpha (Number, optional): Scalar multiplier for other. Default: `1`; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`\nMath: Given an input matrix A, a vector \\mathbf{v}, and another tensor or scalar b, the function computes:\n\\[\n\\begin{align*}\n\\mathbf{z} &= A \\mathbf{v} \\\\\n\\mathbf{s} &= \\sigma(\\mathbf{z}) = \\frac{1}{1 + \\exp(-\\mathbf{z})} \\\\\n\\mathbf{y} &= \\mathbf{s} - \\alpha b\n\\end{align*}\n\\]\nother: - The shapes of `input` and `vec` must be compatible for matrix-vector multiplication.\n- The `other` tensor must be broadcastable to the shape of the output from the sigmoid function.\n- The function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-97a0096f362e4089a20674eaad0d6173", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Adds the tensor or number `other`, scaled by the multiplier `alpha`, to the input tensor `input`, and then applies the Gaussian Error Linear Units (GELU) activation function to the result.\nWrapper Entry Information: def add_gelu(input, other, alpha=1, approximate='none', out=None) -> Tensor: input (Tensor): The input tensor. other (Tensor or Number): The tensor or number to add to input. alpha (Number, optional): The multiplier for other. Default is 1. approximate (str, optional): The approximation method for GELU. Default is 'none'. out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = \\text{GELU}(\\text{input}_i + \\text{alpha} \\times \\text{other}_i) where GELU is defined as: - \\text{GELU}(x) = x * \\Phi(x) when approximate is 'none', - \\text{GELU}(x) = 0.5 * x * (1 + \\text{Tanh}(\\sqrt{2 / \\pi} * (x + 0.044715 * x^3))) when approximate is 'tanh'.\nother: The GELU function is defined with two methods: an exact method using the Cumulative Distribution Function for Gaussian Distribution, and an approximate method using a tanh-based formula.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3eb056fc79754edcb5e161265c57732f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes cosine embedding loss between two normalized tensors. This function first normalizes the inputs along the specified dimension using L2 normalization and then calculates the cosine embedding loss. The loss encourages similarity when the target is 1 and dissimilarity when the target is -1. It accepts optional parameters margin for dissimilarity control and reduction method for output aggregation.\nWrapper Entry Information: def fused_cosine_embedding_loss_with_normalization(input1: torch.Tensor, input2: torch.Tensor, target: torch.Tensor, margin: float = 0, reduction: str = 'mean') -> torch.Tensor: input1 (Tensor): First input tensor to be normalized and compared. input2 (Tensor): Second input tensor to be normalized and compared. target (Tensor): Tensor label with values 1 or -1, where 1 encourages similarity and -1 encourages dissimilarity. margin (float, optional): Margin for dissimilarity. Default: 0. reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.\nother: The inputs are first L2 normalized along dimension 1 before loss calculation. The reduction parameter can be 'none', 'mean', or 'sum', with default as 'mean'.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-a6024651bc2a4554b4bf6898bdd0a33e", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a sequence of operations commonly used in transformer models, combining matrix multiplication, softmax, dropout, another matrix multiplication, layer normalization, and addition (residual connection).\nWrapper Entry Information: fused_transformer_block(input, weight1, weight2, residual, dropout_p=0.1, eps=1e-5, *, out=None) -> Tensor; input (Tensor): Input tensor of shape (*, N, D_in), where * denotes any number of batch dimensions.; weight1 (Tensor): Weight matrix of shape (D_in, D_k).; weight2 (Tensor): Weight matrix of shape (D_k, D_out).; residual (Tensor): Residual tensor to be added before layer normalization, must be broadcastable to the shape of Z_4.; dropout_p (float, optional): Probability of an element to be zeroed in the dropout layer. Default: 0.1.; eps (float, optional): A value added to the denominator for numerical stability in layer normalization. Default: 1e-5.; out (Tensor, optional): Output tensor. Ignored if None. Default: None.\nMath: Given an input tensor X, weight matrices W_1 and W_2, and a residual tensor R, the function computes:\n\n \\[\n \\begin{align*}\n Z_1 &= X W_1 \\\\\n Z_2 &= \\text{softmax}(Z_1) \\\\\n Z_3 &= \\text{dropout}(Z_2, p) \\\\\n Z_4 &= Z_3 W_2 \\\\\n Y &= \\text{LayerNorm}(Z_4 + R, \\gamma, \\beta, \\epsilon)\n \\end{align*}\n \\]\n\n where:\n - \\text{softmax}(Z) is applied along the last dimension.\n - \\text{dropout}(Z, p) randomly zeroes elements of Z with probability p.\n - \\text{LayerNorm} applies layer normalization with learnable parameters \\gamma and \\beta, and epsilon \\epsilon for numerical stability.\n - R is the residual tensor added to Z_4 before layer normalization.\nother: - The dimensions of `input` and `weight1` must be compatible for matrix multiplication: the last dimension of `input` must match the first dimension of `weight1`.\n - The output of the first matrix multiplication has shape `(*, N, D_k)`.\n - The `softmax` is applied along the last dimension (`dim=-1`).\n - The `dropout` is applied during training. Set `training=False` to disable dropout during evaluation.\n - The `layer_norm` is applied over the last dimension of the input tensor.\n - The `residual` tensor must be broadcastable to the shape of `z4`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-0b47029f5b3241c6ba2201b9babd9935", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the natural logarithm of (1 + input). This function is more accurate than torch.log for small values of input.\nWrapper Entry Information: log1p(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: y_i = \\log_{e} (x_i + 1)\nother: This function is more accurate than torch.log for small values of input.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f976b81fe08840119706d14983d18384", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies Batch Normalization over the input tensor across each channel, followed by applying the sigmoid activation function element-wise to the normalized result. This is useful for scaling the output to a range between 0 and 1 after normalization.\nWrapper Entry Information: def sigmoid_batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-5) -> Tensor\nArgs:\n input (Tensor): The input tensor of shape `(N, C)` or `(N, C, L)`, where `N` is batch size, `C` is the number of features or channels, and `L` is the sequence length.\n running_mean (Tensor): The running mean of the input channels.\n running_var (Tensor): The running variance of the input channels.\n weight (Tensor, optional): Learnable scaling factor for each channel, typically represented as `γ`. Default: None.\n bias (Tensor, optional): Learnable shift for each channel, typically represented as `β`. Default: None.\n training (bool, optional): If `True`, updates running statistics; if `False`, uses them for normalization. Default: False.\n momentum (float, optional): Value for updating the running mean and variance. Default: 0.1.\n eps (float, optional): A small value added for numerical stability. Default: 1e-5.\nMath: \\text{out} = \\sigma\\left(\\frac{\\text{input} - \\text{mean}}{\\sqrt{\\text{var} + \\epsilon}} * \\gamma + \\beta \\right) where \\sigma(x) = \\frac{1}{1 + \\exp(-x)} is the sigmoid function.\nother: The function normalizes the input tensor using batch normalization and then applies the sigmoid activation function to scale the output between 0 and 1.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-64b5d9d1883b4672a73b22b9040ef9e1", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies Batch Normalization followed by the Hardsigmoid activation function on the input tensor `x`. This function performs batch normalization on `x` using the specified parameters and then applies Hardsigmoid activation element-wise on the normalized output.\nWrapper Entry Information: fused_hardsigmoid_batch_norm(x: torch.Tensor, running_mean: torch.Tensor, running_var: torch.Tensor, weight: torch.Tensor = None, bias: torch.Tensor = None, training: bool = False, momentum: float = 0.1, eps: float = 1e-5, inplace: bool = False) -> torch.Tensor: Args: x (Tensor): Input tensor for batch normalization and activation. running_mean (Tensor): The running mean buffer (persistent). running_var (Tensor): The running variance buffer (persistent). weight (Tensor, optional): Learnable weight of size C for the normalized tensor. Default: ``None`` bias (Tensor, optional): Learnable bias of size C for the normalized tensor. Default: ``None`` training (bool, optional): Flag for training mode, used to update running estimates. Default: ``False`` momentum (float, optional): The value for the running mean and variance momentum. Default: ``0.1`` eps (float, optional): Small constant added to variance to improve numerical stability. Default: ``1e-5`` inplace (bool, optional): If ``True``, perform Hardsigmoid in-place. Default: ``False``\nother: The function includes optional parameters for learnable weight and bias, a training flag to update running estimates, momentum for running mean and variance, a small constant `eps` for numerical stability, and an `inplace` option for Hardsigmoid.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-289f73ed62e740c8b6cdf08f2ea929da", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Hurwitz zeta function, elementwise. The function calculates the sum of the series for each element in the input tensors, which represent the parameters x and q of the Hurwitz zeta function. The Riemann zeta function is a special case when q equals 1.\nWrapper Entry Information: zeta(input, other, *, out=None) -> Tensor; Args: input (Tensor): the input tensor corresponding to `x`. other (Tensor): the input tensor corresponding to `q`. Keyword args: out (Tensor, optional): the output tensor.\nMath: \\zeta(x, q) = \\sum_{k=0}^{\\infty} \\frac{1}{(k + q)^x}\nother: The Riemann zeta function corresponds to the case when `q = 1`\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-da7522a7c6924bf39ca8e4a82bb53c5a", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the matrix-vector product for a symmetric matrix `A` and a vector `x`, with scaling factors `alpha` and `beta`. Then calculates the norm of the resulting vector `y`. The operation performed is: 1. `y = alpha * torch.mv(A, x) + beta * y`, assuming `A` is symmetric. 2. `norm = torch.norm(y, p)`.\nWrapper Entry Information: def symmetric_matrix_vector_norm(A: torch.Tensor, x: torch.Tensor, alpha: float, beta: float, p: float = 2.0) -> torch.Tensor: A (Tensor): A symmetric matrix of shape `(n, n)`. x (Tensor): A vector of shape `(n,)`. alpha (float): Scalar multiplier for the matrix-vector product. beta (float): Scalar multiplier added to `y`. p (float, optional): Order of the norm. Default is 2.0 (Euclidean norm).\nMath: y = alpha * torch.mv(A, x) + beta * y\nnorm = torch.norm(y, p)\nother: Assumes `A` is symmetric.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-d4cabb13b39a499a9a6b23e853e99c8b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a linear transformation to the input tensor, followed by the Softplus activation function applied element-wise. This combined operation first performs a linear transformation and then introduces non-linearity with Softplus, which is smoother than ReLU and approximates it for large values. The function is particularly designed to improve numerical stability by reverting to a linear function for values above a specified threshold.\nWrapper Entry Information: softplus_linear(input, weight, bias=None, beta=1, threshold=20) -> Tensor\nMath: The combined operation is defined as: out = Softplus(Linear(x)), where the Softplus function is defined as: Softplus(x) = (1/β) * log(1 + exp(β * x))\nother: For values exceeding the threshold, the function helps maintain numerical stability by approximating a linear function, which enhances stability and prevents potential overflow.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-11320a88c62f40ce98b98080ba07ef61", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Reconstructs the input matrix `A` using its Singular Value Decomposition (SVD). This function combines the Singular Value Decomposition (SVD) with matrix reconstruction. Given a matrix `A`, it performs the following operations: 1. Compute the SVD of `A`: A = U Σ V^H, where `U` and `Vh` are unitary matrices and `S` contains the singular values of `A`. 2. Reconstruct `A` as A_reconstructed = U Σ V^H.\nWrapper Entry Information: fused_svd_reconstruct(A: Tensor) -> Tensor: The input matrix `A` of shape `(m, n)`.\nMath: A = U Σ V^H\nA_reconstructed = U diag(S) V^H\nother: The function returns the reconstructed matrix `A` of shape `(m, n)`, approximating the original matrix.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-50d2585ad4334780a630fe1bf041fb18", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining element-wise multiplication, addition, log-softmax activation, dropout, and batch matrix multiplication.\nWrapper Entry Information: fused_mul_add_logsoftmax_dropout_bmm(input1, input2, other, mat2, p=0.5, training=True, inplace=False, dim=-1, *, out=None) -> Tensor\nMath: Given input tensors X_1, X_2, O, and M, the function computes:\n\n\\[\n\\begin{align*}\nZ &= X_1 \\odot X_2 \\\\\nS &= Z + O \\\\\nL &= \\log\\left( \\frac{\\exp(S)}{\\sum_j \\exp(S_j)} \\right) \\\\\nD &= \\text{Dropout}(L, p) \\\\\nY &= \\text{bmm}(D, M)\n\\end{align*}\n\\]\n\nwhere:\n\n- X_1 and X_2 are input tensors for element-wise multiplication.\n- \\odot denotes element-wise multiplication.\n- O is a tensor or scalar to be added to Z, must be broadcastable to the shape of Z.\n- \\log\\left( \\frac{\\exp(S)}{\\sum_j \\exp(S_j)} \\right) is the log-softmax function applied along dimension `dim`.\n- \\text{Dropout}(L, p) randomly zeroes elements of L with probability p.\n- \\text{bmm}(D, M) performs batch matrix multiplication.\nother: - The shapes of `input1`, `input2`, and `other` must be broadcastable to each other.\n- The `mat2` tensor must have a shape compatible with the output of the dropout layer for batch matrix multiplication, i.e., `mat2` should have shape `(B, D_in, D_out)` if the dropout output has shape `(B, N, D_in)`.\n- The `log_softmax` function is applied along dimension `dim`, which should be the dimension of the features (typically `-1` for the last dimension).\n- The `dropout` is applied during training when `training=True`. Set `training=False` to disable dropout during evaluation.\n- All operations are differentiable and support autograd.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6c82bb8c97874fcda4c650e922fb65a1", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the element-wise SELU (Scaled Exponential Linear Unit) function to the input tensor. The SELU function is defined as scale * (max(0, x) + min(0, alpha * (exp(x) - 1))), where the constants alpha and scale are fixed values with alpha approximately 1.673 and scale approximately 1.051.\nWrapper Entry Information: selu(input, inplace=False) -> Tensor\nMath: SELU(x) = scale * (max(0,x) + min(0, alpha * (exp(x) - 1))), with alpha=1.6732632423543772848170429916717 and scale=1.0507009873554804934193349852946.\nother: See torch.nn.SELU for more details.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-192450a134ba4b8db5cf15f591eb74b5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes `y += alpha * x` and returns the 2-norm of the modified `y`. The function takes a target tensor `y`, a tensor `x` to be scaled by a scalar `alpha`, and adds the scaled `x` to `y`. It then calculates and returns the 2-norm of the updated `y`.\nWrapper Entry Information: scaled_add_norm(y: Tensor, x: Tensor, alpha: float) -> Tensor: y (Tensor): The target tensor to be modified, of shape `(n,)`. x (Tensor): The tensor to be scaled and added to `y`, of shape `(n,)`. alpha (float): The scalar multiplier for `x`.\nMath: y += alpha * x\nnorm = ||y||_2\nother: The function modifies the input tensor `y` in place and calculates the 2-norm using `torch.norm`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-78523a7fa58b495091e13f92efb7b7eb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over the input tensor, followed by applying the Leaky ReLU activation function element-wise to the result. This allows for both feature extraction and non-linear activation in one step.\nWrapper Entry Information: def leaky_relu_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, negative_slope=0.01, inplace=False) -> Tensor\nMath: The combined operation is defined as:\n\n.. math::\n \\text{out} = \\text{LeakyReLU}(\\text{conv2d}(\\text{input}))\n\nwhere the Leaky ReLU function is applied element-wise as:\n\n.. math::\n \\text{LeakyReLU}(x) = \\max(0, x) + \\text{negative\\_slope} \\times \\min(0, x)\nother: The function combines 2D convolution and Leaky ReLU activation in one step, allowing for efficient computation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-dc20e66156654cfc8d1a558548ccb016", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the square root of each element in :attr:`input`, and then applies the exponential function to the square-rooted values. The combined operation is defined as: out_i = e^(sqrt(input_i))\nWrapper Entry Information: def sqrt_exp(input, out=None) -> Tensor: input (Tensor): The input tensor. out (Tensor, optional): The output tensor.\nMath: out_i = e^(sqrt(input_i))\nother: N/A\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-125dc3fb47954939a752d1b89c38c022", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the cosine function element-wise to the input tensor, followed by a 1D average pooling. The function first computes the cosine of each element in the input tensor, then applies 1D average pooling over the resulting tensor with the specified kernel size, stride, padding, ceil mode, and padding inclusion.\nWrapper Entry Information: def cos_avg_pool1d(input: torch.Tensor, kernel_size: int, stride: int = None, padding: int = 0, ceil_mode: bool = False, count_include_pad: bool = True) -> torch.Tensor\n \n input (Tensor): The input tensor of shape (minibatch, in_channels, iW).\n kernel_size (int): Size of the pooling window.\n stride (int, optional): Stride of the pooling window. Defaults to `kernel_size`.\n padding (int, optional): Zero-padding added to both sides of the input. Default is 0.\n ceil_mode (bool, optional): If True, uses ceil instead of floor to compute the output shape. Default is False.\n count_include_pad (bool, optional): If True, includes the zero-padding in the averaging calculation. Default is True.\nMath: \\text{output} = \\text{avg\\_pool1d}(\\cos(\\text{input}))\nother: The function involves computing the cosine transformation followed by pooling, and handles parameters like stride, padding, and ceil mode.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-23e1c9547e2348be8fbf4e531d860e6e", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the sum of elements in the input tensor along the specified dimension(s), followed by calculating the standard deviation of the summed values.\nWrapper Entry Information: def sum_std(input, dim=None, keepdim=False, dtype=None, correction=1, out=None) -> Tensor: input (Tensor): The input tensor. dim (int or tuple of ints, optional): The dimension(s) to reduce. If None, all dimensions are reduced. keepdim (bool, optional): Whether the output tensor has dim retained or not. Default is False. dtype (torch.dtype, optional): The desired data type of the returned tensor. If specified, the input tensor is cast to dtype before the operation. Default: None. correction (int, optional): Difference between the sample size and sample degrees of freedom. Default is 1 (Bessel's correction). out (Tensor, optional): The output tensor.\nMath: \\text{sum} = \\sum_{i=0}^{N-1} x_i\n\n\\sigma = \\sqrt{\\frac{1}{\\max(0,~N - \\delta N)}\\sum_{i=0}^{N-1}(x_i-\\bar{x})^2}\nother: The function uses Bessel's correction by default with a correction value of 1.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3c2783058b9a49deba902ae94007f399", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function performs element-wise multiplication of two inputs, input and other, and then applies the Rectified Linear Unit (ReLU) function to the result, which replaces all negative values with zero.\nWrapper Entry Information: def mul_relu(input, other, inplace=False, out=None) -> Tensor: input (Tensor): The input tensor to be multiplied. other (Tensor or Number): The tensor or number to multiply with `input`. inplace (bool, optional): If True, modifies `input` in-place, if possible. Default is False. out (Tensor, optional): The output tensor.\nMath: ReLU(x) = max(0, x); out_i = ReLU(input_i * other_i)\nother: The function uses torch.mul for multiplication and F.relu for the ReLU operation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-27c7026034bc48de894b56c90b53633d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over an input tensor with specified filters, followed by applying the Gaussian Error Linear Units (GELU) activation function element-wise to the result. This helps introduce non-linearity after the convolution operation.\nWrapper Entry Information: def gelu_conv2d(input: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int], str] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, approximate: str = 'none', out: Optional[Tensor] = None) -> Tensor\nArgs:\n input (Tensor): The input tensor of shape `(minibatch, in_channels, iH, iW)`.\n weight (Tensor): The convolution filters of shape `(out_channels, in_channels / groups, kH, kW)`.\n bias (Tensor, optional): Optional bias tensor of shape `(out_channels)`. Default: ``None``.\n stride (int or tuple, optional): The stride of the convolution kernel. Can be a single number or a tuple `(sH, sW)`. Default: 1.\n padding (int, tuple, or string, optional): Padding on both sides of the input. Can be 'valid', 'same', single number, or tuple `(padH, padW)`. Default: 0.\n dilation (int or tuple, optional): The spacing between kernel elements. Default: 1.\n groups (int, optional): Number of groups to split the input into. Default: 1.\n approximate (str, optional): The approximation method for GELU. Default is 'none'.\n out (Tensor, optional): The output tensor.\nMath: The combined operation is defined as:\n\n.. math::\n \\text{out} = \\text{GELU}(\\text{conv2d}(\\text{input}, \\text{weight}))\n\nwhere GELU is computed as:\n\n- If :attr:`approximate` is 'none', GELU is computed as:\n\n .. math::\n \\text{GELU}(x) = x * \\Phi(x)\n\n where :math:`\\Phi(x)` is the Cumulative Distribution Function for Gaussian Distribution.\n\n- If :attr:`approximate` is 'tanh', GELU is approximated as:\n\n .. math::\n \\text{GELU}(x) = 0.5 * x * (1 + \\text{Tanh}(\\sqrt{2 / \\pi} * (x + 0.044715 * x^3)))\nother: The function combines 2D convolution and GELU activation, with options for approximation methods for GELU.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3580ec85929944749e9fa41e2e02bc65", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a fused operation consisting of a 2D convolution followed by SELU activation and instance normalization on the input tensor.\nWrapper Entry Information: fused_instance_norm_selu_conv2d(input: Tensor, weight: Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1, num_features=None, eps=1e-5, momentum=0.1, affine=False, track_running_stats=False) -> Tensor: input (Tensor): Input tensor of shape (minibatch, in_channels, iH, iW). weight (Tensor): Weights for the convolution, shape (out_channels, in_channels / groups, kH, kW). bias (Tensor, optional): Bias for the convolution layer, shape (out_channels). stride (int or tuple, optional): Stride of the convolution. Default is 1. padding (int or tuple, optional): Padding for the convolution. Default is 0. dilation (int or tuple, optional): Spacing between kernel elements. Default is 1. groups (int, optional): Number of blocked connections from input channels to output channels. Default is 1. num_features (int, optional): Number of features or channels in the input for instance normalization. eps (float, optional): A value added to the denominator for numerical stability in instance normalization. Default is 1e-5. momentum (float, optional): Momentum for updating running statistics in instance normalization. Default is 0.1. affine (bool, optional): If True, instance normalization has learnable affine parameters. Default is False. track_running_stats (bool, optional): If True, tracks running mean and variance for instance normalization. Default is False.\nother: The function combines convolution, SELU activation, and instance normalization in a single operation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-988e490635214576b037e77930deb604", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a ReLU activation followed by 2D fractional max pooling over an input signal composed of multiple planes. The input is first rectified (non-negative) and then pooled using fractional max pooling.\nWrapper Entry Information: def fused_fractional_max_pool2d_with_relu(input: torch.Tensor, kernel_size, output_size=None, output_ratio=None, return_indices=False) -> torch.Tensor: Input (Tensor): Input tensor. kernel_size (int or Tuple[int, int]): Size of the pooling window. output_size (Tuple[int, int], optional): Target output size (height, width). output_ratio (Tuple[float, float], optional): If set, output size is scaled as a ratio of the input size. return_indices (bool, optional): If `True`, return the max pooling indices along with the output.\nother: The function combines ReLU activation with fractional max pooling, allowing for optional output size or ratio specification and the option to return pooling indices.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-c16a172f3dc240c88cbafacaedfe6221", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Chebyshev polynomial of the first kind T_n(input). If n = 0, returns 1. If n = 1, returns input. For n < 6 or |input| > 1, uses a recursive formula. Otherwise, uses an explicit trigonometric formula.\nWrapper Entry Information: chebyshev_polynomial_t(input, n, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. n (Tensor): Degree of the polynomial. Keyword args: out (Tensor, optional): the output tensor.\nMath: T_{n + 1}(input) = 2 \\times input \\times T_{n}(input) - T_{n - 1}(input)\nT_{n}(input) = \\text{cos}(n \\times \\text{arccos}(x))\nother: If n = 0, returns 1. If n = 1, returns input. Uses recursion for n < 6 or |input| > 1, otherwise uses trigonometric formula.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-497431032a4749f08d2c2b4eb92faed1", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the logit of the elements of input. The input is clamped to [eps, 1 - eps] when eps is not None. When eps is None and input < 0 or input > 1, the function yields NaN.\nWrapper Entry Information: logit(input, eps=None, *, out=None) -> Tensor; input (Tensor): the input tensor.; eps (float, optional): the epsilon for input clamp bound. Default: None; out (Tensor, optional): the output tensor.\nMath: y_{i} = \\ln(\\frac{z_{i}}{1 - z_{i}}); z_{i} = \\begin{cases} x_{i} & \\text{if eps is None} \\\\ \\text{eps} & \\text{if } x_{i} < \\text{eps} \\\\ x_{i} & \\text{if } \\text{eps} \\leq x_{i} \\leq 1 - \\text{eps} \\\\ 1 - \\text{eps} & \\text{if } x_{i} > 1 - \\text{eps} \\end{cases}\nother: input is clamped to [eps, 1 - eps] when eps is not None. When eps is None and input < 0 or input > 1, the function yields NaN.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-d3eb7e27374b4ff881103691a389095d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Solves a symmetric (or Hermitian) linear system A x = b using LDL decomposition. The function first decomposes A into L and D through LDL decomposition, reconstructs matrix A, and then uses `torch.linalg.solve` to solve the linear system.\nWrapper Entry Information: solve_symmetric_ldl(A, b, *, hermitian=False, out=None) -> Tensor A (Tensor): 形状为 (*, n, n) 的对称(或 Hermitian)矩阵,其中 * 是零个或多个批次维度。 b (Tensor): 形状为 (*, n) 或 (*, n, k) 的右端项张量。 hermitian (bool, 可选): 是否将 A 视为 Hermitian 矩阵。默认值:False。 out (Tensor, 可选): 输出张量。如果为 None,则忽略。默认值:None。\nMath: Given a symmetric (or Hermitian) matrix A in \\mathbb{K}^{n \\times n} (where \\mathbb{K} is the real field \\mathbb{R} or complex field \\mathbb{C}), the LDL decomposition of A is represented as: A = L D L^{\\mathrm{T}} or A = L D L^{\\mathrm{H}}.\nother: This function supports batch processing; all computations are performed across batch dimensions.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-4f6f918252af440abb332a6351565838", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the exponential of each element in the input tensor, followed by calculating the square root of the result. Returns a tensor where each element is the result of applying exponential followed by square root to each element of input.\nWrapper Entry Information: def exp_sqrt(input, out=None) -> Tensor; input (Tensor): The input tensor.; out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = \\sqrt{e^{\\text{input}_i}}\nother: This function will return NaN for input elements that result in negative values after `exp` and `sqrt` due to overflow.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-171e992fdf344d2782f9673b7ed5a50d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a sequence of operations combining matrix multiplication, sigmoid, tanh, element-wise multiplication, and addition. It supports batches of inputs, where any leading batch dimensions in `input` will be preserved in the output. The function's operations are differentiable and support autograd. The function ensures the dimensions of `input` and `weight1` are compatible for matrix multiplication, and that `weight2` and `bias` are broadcastable to the shape of the output tensor.\nWrapper Entry Information: combined_activation(input, weight1, weight2, bias, *, out=None) -> Tensor; input (Tensor): Input tensor of shape (*, N, D_{in}), where * denotes any number of batch dimensions.; weight1 (Tensor): Weight matrix of shape (D_{in}, D_{out}).; weight2 (Tensor): Weight tensor for element-wise multiplication, must be broadcastable to the shape of the intermediate activation.; bias (Tensor): Bias tensor, must be broadcastable to the shape of the output.; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: Given an input tensor X, weight matrices W_1 and W_2, and a bias b, the function computes: Y = (tanh(sigmoid(X W_1)) ⊙ W_2) + b\n\n- σ(z) = 1 / (1 + exp(-z)) is the sigmoid function applied element-wise.\n- tanh(z) = (exp(z) - exp(-z)) / (exp(z) + exp(-z)) is the hyperbolic tangent function applied element-wise.\n- ⊙ denotes element-wise multiplication.\nother: The function supports differentiable operations and autograd. It requires compatibility in dimensions for matrix multiplication and broadcasting for element-wise operations.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6683d9566d4c4a498e388df775898fa2", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes `y += alpha * x` and returns the dot product of the modified `y` with itself. This fused function performs two operations: 1. Scales `x` by a factor of `alpha` and adds the result to `y`. 2. Computes the dot product of the modified `y` with itself.\nWrapper Entry Information: def scaled_add_dot(y: Tensor, x: Tensor, alpha: float) -> Tensor: y (Tensor): The target tensor to be modified, of shape (n,). x (Tensor): The tensor to be scaled and added to y, of shape (n,). alpha (float): The scalar multiplier for x.\nMath: y += alpha * x\ndot_product = torch.dot(y, y)\nother: The function modifies the input tensor `y` in place.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f52b79791fe8432bbd6a43023a96c586", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a contraction of a and b over multiple dimensions. It implements a generalized matrix product.\nWrapper Entry Information: def tensordot(a: Tensor, b: Tensor, dims: Union[int, Tuple[List[int], List[int]], List[List[int]]]) -> Tensor: \nArgs:\n a (Tensor): Left tensor to contract\n b (Tensor): Right tensor to contract\n dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to contract or explicit lists of dimensions for a and b respectively\nMath: r_{i_0,...,i_{m-d}, i_d,...,i_n} = \\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}.\nother: The sizes in the contracted dimensions must match, but broadcasted dimensions are handled.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-270812dbf8a249af947034627990704d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the QR decomposition of a matrix. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions. The parameter mode chooses between the full and reduced QR decomposition. It is always differentiable for 'reduced' mode, differentiable for 'complete' mode when m <= n, and never differentiable for 'r' mode.\nWrapper Entry Information: qr(A, mode='reduced', *, out=None) -> (Tensor, Tensor) A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. mode (str, optional): one of `'reduced'`, `'complete'`, `'r'`. Controls the shape of the returned tensors. Default: `'reduced'`. out (tuple, optional): output tuple of two tensors. Ignored if `None`. Default: `None`.\nMath: A = QR where Q is orthogonal in the real case and unitary in the complex case, and R is upper triangular with real diagonal. For tall matrices (m > n), the reduced QR decomposition is A = QR with Q in K^{m x n} and R in K^{n x n}.\nother: Differences with numpy.linalg.qr: mode='raw' is not implemented. Unlike numpy.linalg.qr, this function always returns a tuple of two tensors. When mode='r', the Q tensor is an empty tensor. The elements in the diagonal of R are not necessarily positive, making the QR decomposition unique only up to the sign of the diagonal of R. The QR decomposition is only well-defined if the first k = min(m, n) columns of every matrix in A are linearly independent.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-ac530c5c1f78450ba681b4b4195f1d79", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the arcsine of the elements of the input tensor. The function computes the inverse sine (arcsine) for each element in the input tensor.\nWrapper Entry Information: asin(input, *, out=None) -> Tensor: input (Tensor): the input tensor. out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\sin^{-1}(\\text{input}_{i})\nother: The function returns NaN for input values outside the range [-1, 1] as arcsine is not defined for those values.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-0f3a3783fd594587aac20e01eff12abc", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function performs a fused operation combining masked selection, addition, and GELU activation. It first selects elements from the input tensor based on a boolean mask, then adds a scalar or tensor (scaled by alpha) to the selected values, and finally applies the GELU (Gaussian Error Linear Unit) activation function element-wise to the result.\nWrapper Entry Information: fused_masked_select_add_gelu(input, mask, other, *, alpha=1, approximate='none', out=None) -> Tensor\nMath: Z = masked_select(X, M)\nS = Z + alpha * O\nY = GELU(S)\nother: The function is differentiable and supports autograd. The mask and other tensor must be broadcastable to the shape of the selected elements. The 'approximate' parameter can be set to 'tanh' for a faster, approximate GELU computation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-aa03d33964c946dc9c8062140df295d7", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function applies adaptive average pooling to the input tensors `x1` and `x2` to resize them to the specified `output_size`, and then computes the pairwise distance between the pooled outputs. The function first applies `adaptive_avg_pool2d` to each input tensor, and then calculates the pairwise distance using the specified norm `p`. A small value `eps` is added to avoid division by zero during distance calculation. The function can also retain the reduced dimension of the output via the `keepdim` parameter.\nWrapper Entry Information: def fused_pairwise_distance_adaptive_avg_pool2d(x1: torch.Tensor, x2: torch.Tensor, output_size: int or tuple, p: float = 2.0, eps: float = 1e-6, keepdim: bool = False) -> torch.Tensor: x1 (Tensor): First input tensor for adaptive average pooling and distance calculation. x2 (Tensor): Second input tensor for adaptive average pooling and distance calculation. output_size (int or tuple): The target output size for the adaptive average pooling. p (float, optional): The norm degree for pairwise distance calculation. Default: 2.0 eps (float, optional): Small value to avoid division by zero in pairwise distance. Default: 1e-6 keepdim (bool, optional): Whether to keep the reduced dimension. Default: False\nMath: No explicit formula provided. The function applies adaptive average pooling followed by pairwise distance calculation with norm p and epsilon to avoid division by zero.\nother: The function combines adaptive average pooling and pairwise distance calculation in a sequential manner.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-7ece5967c99242b4b406be8a2dc82ef5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Adds the `other` tensor, scaled by `alpha`, to the `input` tensor and computes the mean value along the specified dimension. If no dimension is specified, it computes the mean over all elements. Supports broadcasting, type promotion, and works with integer, float, and complex inputs.\nWrapper Entry Information: def add_mean(input, other, dim=None, alpha=1, keepdim=False, dtype=None, out=None) -> Tensor: input (Tensor): The input tensor. other (Tensor or Number): The tensor or number to add to input. dim (int or tuple of ints, optional): The dimension(s) to reduce. Default: None. alpha (Number, optional): The multiplier for other. Default: 1. keepdim (bool, optional): Whether the output tensor has dim retained or not. Default: False. dtype (torch.dtype, optional): The desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation. This is useful for preventing data type overflows. Default: None. out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = \\text{mean}(\\text{input}_i + \\text{alpha} \\times \\text{other}_i)\nother: Supports broadcasting to a common shape, type promotion, and integer, float, and complex inputs.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-4f188313a3474339ace1e868ae3ec2d5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a fused operation consisting of a linear transformation followed by ReLU activation and layer normalization on the input tensor.\nWrapper Entry Information: fused_layer_norm_relu_linear(input: Tensor, weight: Tensor, bias=None, normalized_shape=None, eps=1e-5, elementwise_affine=True) -> Tensor: Input (Tensor): Input tensor with shape (*, in_features). Weight (Tensor): Weights for the linear transformation, shape (out_features, in_features). Bias (Tensor, optional): Bias for the linear transformation, shape (out_features). Normalized_shape (int or list or torch.Size, optional): Shape of the dimensions to normalize. Eps (float, optional): A value added to the denominator for numerical stability. Default is 1e-5. Elementwise_affine (bool, optional): If True, layer normalization has learnable parameters. Default is True. Returns: Tensor: Result after applying the linear transformation, ReLU, and layer normalization. Example: >>> input = torch.randn(4, 5) # Example input tensor >>> weight = torch.randn(3, 5) # Linear transformation weights >>> bias = torch.randn(3) # Bias for linear layer >>> normalized_shape = 3 >>> # Apply fused operation >>> output = fused_layer_norm_relu_linear(input, weight, bias, normalized_shape) >>> print(output.shape) # Expected output shape: (4, 3)\nother: The function performs a sequence of operations: linear transformation, ReLU activation, and layer normalization. It supports optional bias and learnable parameters for layer normalization.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-dd5140f5d0104073bb68c57c7d334c88", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining element-wise addition, element-wise multiplication, and group normalization. It takes two input tensors, adds them element-wise, multiplies the result with the second tensor, and then applies group normalization using learnable parameters for scaling and shifting. The function supports autograd for gradient computation and all operations are differentiable.\nWrapper Entry Information: fused_add_mul_groupnorm(input1, input2, weight, bias, num_groups, eps=1e-5, *, out=None) -> Tensor; input1 (Tensor): The first input tensor X; input2 (Tensor): The second input tensor Y, must be broadcastable to the shape of X; weight (Tensor): Learnable weight parameter \\gamma of shape (C,), where C is the number of channels; bias (Tensor): Learnable bias parameter \\beta of shape (C,); num_groups (int): Number of groups to separate the channels into for group normalization; eps (float, optional): A value added to the denominator for numerical stability in group normalization. Default: `1e-5`; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`\nMath: Given two input tensors X and Y, and learnable parameters \\gamma and \\beta for group normalization, the function computes:\n\n\\[\n\\begin{align*}\nZ &= X + Y \\\\\nM &= Z \\odot Y \\\\\nO &= \\text{GroupNorm}(M, \\gamma, \\beta, \\text{num\\_groups}, \\epsilon)\n\\end{align*}\n\\]\n\nwhere:\n- X and Y are input tensors.\n- \\odot denotes element-wise multiplication.\n- \\text{GroupNorm}(M, \\gamma, \\beta, \\text{num\\_groups}, \\epsilon) applies group normalization to M with \\gamma and \\beta as affine parameters, over the specified number of groups, and \\epsilon for numerical stability.\nother: - The shapes of `input1` and `input2` must be broadcastable to each other.\n- The `weight` and `bias` parameters must have shape `(C,)`, where `C` is the number of channels in the input tensors.\n- The `num_groups` parameter must divide the number of channels `C` evenly.\n- This function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-837d48e849ab472cbd085950fb72c382", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Implements stochastic gradient descent, optionally with momentum, weight decay, dampening, and Nesterov momentum. It can maximize or minimize an objective function and supports different optimization algorithms for performance.\nWrapper Entry Information: def SGD(params, lr=1e-3, momentum=0, weight_decay=0, dampening=0, nesterov=False, maximize=False, foreach=None, differentiable=False, fused=None)\nMath: \\begin{aligned} &g_t \\leftarrow \\nabla_{\\theta} f_t (\\theta_{t-1}) \\\\\\ &\\text{if} \\: \\lambda \\neq 0 \\\\\\ &g_t \\leftarrow g_t + \\lambda \\theta_{t-1} \\\\\\ &\\text{if} \\: \\mu \\neq 0 \\\\\\ &\\text{if} \\: t > 1 \\\\\\ &\\textbf{b}_t \\leftarrow \\mu \\textbf{b}_{t-1} + (1-\\tau) g_t \\\\\\ &\\text{else} \\\\\\ &\\textbf{b}_t \\leftarrow g_t \\\\\\ &\\text{if} \\: \\textit{nesterov} \\\\\\ &g_t \\leftarrow g_{t} + \\mu \\textbf{b}_t \\\\\\ &\\text{else} \\\\\\ &g_t \\leftarrow \\textbf{b}_t \\\\\\ &\\text{if} \\: \\textit{maximize} \\\\\\ &\\theta_t \\leftarrow \\theta_{t-1} + \\gamma g_t \\\\\\ &\\text{else} \\\\\\ &\\theta_t \\leftarrow \\theta_{t-1} - \\gamma g_t \\end{aligned}\nother: Nesterov momentum is based on a research paper. The algorithm prioritizes different implementations based on performance. It differs from some traditional frameworks in its handling of momentum. The initial momentum buffer is set to the gradient value at the first step.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-83a297a664454392acae42b950384831", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over the input tensor, followed by batch normalization and then applies the ReLU activation function element-wise to the normalized result. This combined operation is useful for applying feature extraction, normalization, and non-linearity in one step, commonly used in convolutional neural networks (CNNs).\nWrapper Entry Information: def relu_batch_norm_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, running_mean=None, running_var=None, bn_weight=None, bn_bias=None, training=False, momentum=0.1, eps=1e-5, inplace=False) -> Tensor\n\nArgs:\n input (Tensor): The input tensor of shape (minibatch, in_channels, iH, iW).\n weight (Tensor): The convolution filters of shape (out_channels, in_channels / groups, kH, kW).\n bias (Tensor, optional): Optional bias tensor of shape (out_channels). Default: None.\n stride (int or tuple, optional): The stride of the convolution kernel. Default: 1.\n padding (int, tuple, or string, optional): Padding added to all sides of the input. Default: 0.\n dilation (int or tuple, optional): The spacing between kernel elements. Default: 1.\n groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1.\n running_mean (Tensor, optional): The running mean for batch normalization. Default: None.\n running_var (Tensor, optional): The running variance for batch normalization. Default: None.\n bn_weight (Tensor, optional): Learnable scaling factor for batch normalization (gamma). Default: None.\n bn_bias (Tensor, optional): Learnable shift factor for batch normalization (beta). Default: None.\n training (bool, optional): If True, updates running statistics for batch normalization. Default: False.\n momentum (float, optional): Value for updating the running mean and variance in batch normalization. Default: 0.1.\n eps (float, optional): A small value added for numerical stability in batch normalization. Default: 1e-5.\n inplace (bool, optional): If True, performs ReLU in-place. Default: False.\nMath: out = ReLU(BatchNorm(conv2d(input)))\nReLU(x) = max(0, x)\ny = \\frac{x - \\mathrm{E}[x]}{\\sqrt{\\mathrm{Var}[x] + \\epsilon}} * \\gamma + \\beta\nother: The function combines convolution, batch normalization, and ReLU activation in a single step, which is a common pattern in CNNs for efficient computation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-c4c5d71167f14729967dbe0df7067ee9", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over an input image composed of several input planes. Supports TensorFloat32. May select a nondeterministic algorithm on CUDA with CuDNN for performance. Supports complex data types.\nWrapper Entry Information: conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1) -> Tensor Args: input: input tensor of shape (minibatch , in_channels , iH , iW) weight: filters of shape (out_channels , in_channels/groups , kH , kW) bias: optional bias tensor of shape (out_channels). Default: None stride: the stride of the convolving kernel. Can be a single number or a tuple (sH, sW). Default: 1 padding: implicit paddings on both sides of the input. Can be a string {'valid', 'same'}, single number or a tuple (padH, padW). Default: 0 dilation: the spacing between kernel elements. Can be a single number or a tuple (dH, dW). Default: 1 groups: split input into groups, both in_channels and out_channels should be divisible by the number of groups. Default: 1\nother: Supports TensorFloat32. May select a nondeterministic algorithm on CUDA with CuDNN. Supports complex data types.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-d11b2ee92ecd46eda022a00907af3242", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the cosine similarity between two normalized input tensors `x1` and `x2`. This function normalizes `x1` and `x2` along a specified dimension using L_p normalization, and subsequently calculates the cosine similarity between these normalized tensors along the specified dimension. This involves ensuring vectors are scaled to avoid division by zero by introducing small epsilon values both during normalization and similarity computation.\nWrapper Entry Information: def normalized_cosine_similarity(x1: Tensor, x2: Tensor, dim: int = 1, eps_similarity: float = 1e-8, p_norm: float = 2, eps_norm: float = 1e-12) -> Tensor\nMath: The operation is defined as:\n\nsimilarity = \\frac{\\text{normalize}(x1) \\cdot \\text{normalize}(x2)}{\\max(\\lVert \\text{normalize}(x1) \\Vert _2, \\epsilon) \\cdot \\max(\\lVert \\text{normalize}(x2) \\Vert _2, \\epsilon)}\n\nwhere the `normalize` function is defined as:\n\nv = \\frac{v}{\\max(\\lVert v \\rVert_p, \\epsilon)}.\nother: The function allows broadcasting x2 to match x1's shape. Default values are provided for dimension, normalization, and similarity thresholds to enhance robustness against division by zero.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f3cef4ea85f0425a9bf4d76d144d22e5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the solution `x` to the equation `Ax = b` using the Cholesky decomposition. It first performs Cholesky decomposition on a symmetric positive-definite matrix `A` to obtain a lower triangular matrix `L` such that `A = L * L.T`, then solves for `x` in `Ax = b` using the Cholesky factorization.\nWrapper Entry Information: def fused_cholesky_solve(A: Tensor, b: Tensor) -> Tensor: A: The symmetric positive-definite matrix `A` of shape `(n, n)`. b: The right-hand side tensor `b` of shape `(n, k)`.\nMath: Cholesky decomposition: A = L * L.T, Solve: Ax = b\nother: The function assumes that the input matrix `A` is symmetric positive-definite.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-bd29ed3d82ce4259a74b4daff2fad6e4", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Matrix product of two tensors. The behavior depends on the dimensionality of the tensors: 1D tensors return a dot product; 2D tensors return a matrix-matrix product; 1D and 2D tensors return a matrix-vector product; N-dimensional tensors (N > 2) return a batched matrix multiply with broadcasting support. Sparse layouts are supported for 2D matrix-matrix products. TensorFloat32 is supported. On certain ROCm devices, float16 inputs use different precision for backward. The 1D dot product version does not support an out parameter.\nWrapper Entry Information: matmul(input, other, *, out=None) -> Tensor\nArguments:\n input (Tensor): the first tensor to be multiplied\n other (Tensor): the second tensor to be multiplied\nKeyword args:\n out (Tensor, optional): the output tensor.\nother: Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, or may not have autograd support. If you notice missing functionality please open a feature request.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-2de63b529046419bbdbe301183a782ca", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining torch.gather and torch.Tensor.masked_fill. It first gathers values from the input tensor along a specified dimension using provided indices, and then replaces the gathered elements with a specified value where the mask is True.\nWrapper Entry Information: fused_gather_masked_fill(input, dim, index, mask, value, *, sparse_grad=False, out=None) -> Tensor; input (Tensor): The input tensor X.; dim (int): The dimension along which to index.; index (LongTensor): The indices of elements to gather, of the same dimensionality as `input`.; mask (BoolTensor): A boolean mask tensor, broadcastable to the shape of the output tensor Y.; value (float): The value to fill in where `mask` is True.; sparse_grad (bool, optional): If True, gradient w.r.t. `input` will be a sparse tensor. Default: `False`.; out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: Y = \\text{gather}(X, \\text{dim}, I)\nY[M] = \\text{value}\nother: - The input and index tensors must have the same number of dimensions.\n- The size of index at each dimension d must not exceed the size of input at that dimension, except at dimension dim.\n- The mask tensor must be broadcastable to the shape of the gathered output.\n- The function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-2a6303ed19dd441a8b9406303601283b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function computes the cross entropy loss with log softmax applied to the input logits. It combines log softmax activation and cross entropy loss calculation in a numerically stable way. The log softmax is applied to the input logits, and the cross entropy loss is computed between the normalized logits and the target. The function allows customization with options such as which dimension to apply the log softmax, manual rescaling weights for each class, handling of ignored targets, reduction method for loss aggregation, and label smoothing to modify the target distribution.\nWrapper Entry Information: def fused_cross_entropy_log_softmax(input: torch.Tensor, target: torch.Tensor, dim: int = 1, weight: torch.Tensor = None, ignore_index: int = -100, reduction: str = 'mean', label_smoothing: float = 0.0) -> torch.Tensor\nArgs:\n input (Tensor): Input tensor of logits, where softmax will be computed along `dim`.\n target (Tensor): Ground truth class indices or probabilities.\n dim (int, optional): Dimension along which to compute log softmax. Default is 1.\n weight (Tensor, optional): Manual rescaling weight for each class.\n ignore_index (int, optional): Specifies a target value that is ignored and does not contribute to the input gradient. Default: -100.\n reduction (str, optional): Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. Default: 'mean'.\n label_smoothing (float, optional): Specifies the amount of smoothing to be applied, where 0.0 means no smoothing. Default: 0.0.\nMath: log_softmax(x_i) = log(exp(x_i) / sum(exp(x)))\nCE(y, p) = -sum(y * log(p))\nother: The function integrates the log softmax and cross entropy loss computation into a single operation for numerical stability. The input and target tensors must be of compatible shapes, where the input is expected to have logits of size (N, C) and target should have size (N,) for class indices.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b39378719da64c2d8300dc220b5a4232", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a matrix multiplication of the matrices mat1 and mat2. The matrix input is added to the final result. If mat1 is a (n x m) tensor, mat2 is a (m x p) tensor, then input must be broadcastable with a (n x p) tensor and out will be a (n x p) tensor. Alpha and beta are scaling factors on matrix-vector product between mat1 and mat2 and the added matrix input respectively. If beta is 0, then input will be ignored, and nan and inf in it will not be propagated. This operation supports sparse layouts. If input is sparse the result will have the same layout and if out is provided it must have the same layout as input. Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, or may not have autograd support. This operator supports TensorFloat32. On certain ROCm devices, when using float16 inputs this module will use different precision for backward.\nWrapper Entry Information: addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor; input (Tensor): matrix to be added; mat1 (Tensor): the first matrix to be matrix multiplied; mat2 (Tensor): the second matrix to be matrix multiplied; beta (Number, optional): multiplier for input (β); alpha (Number, optional): multiplier for mat1 @ mat2 (α); out (Tensor, optional): the output tensor.\nMath: out = β * input + α * (mat1 @ mat2)\nother: Sparse support is a beta feature and some layout(s)/dtype/device combinations may not be supported, or may not have autograd support. This operator supports TensorFloat32. On certain ROCm devices, when using float16 inputs this module will use different precision for backward.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-37a45ee5d6dc4415b23c9dd30e61fe61", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Solves the linear system `Ax = b` using QR decomposition. This function combines the QR decomposition with solving a linear system. Given a matrix `A` and a vector (or matrix) `b`, it performs the QR decomposition of `A` and computes the solution `x` using the formula `x = R^{-1} (Q^T b)`.\nWrapper Entry Information: def fused_qr_solve(A: Tensor, b: Tensor) -> Tensor: A: The matrix `A` of shape `(m, n)` where `m >= n`. b: The right-hand side tensor `b` of shape `(m, k)`.\nMath: x = R^{-1} Q^T b\nother: The function assumes `m >= n` for the matrix `A`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-91b698430a3e4048ab3cb178db442b7b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D adaptive average pooling over an input tensor, followed by the sigmoid activation function applied element-wise. This is used for downsampling a feature map to a specified output size and then normalizing the result with the sigmoid function.\nWrapper Entry Information: def sigmoid_adaptive_avg_pool2d(input: Tensor, output_size: Union[int, Tuple[int, int]]) -> Tensor\nMath: out = σ(AdaptiveAvgPool2D(input))\nSigmoid(x) = 1 / (1 + exp(-x))\nother: Each element in the resulting tensor is scaled to the range (0, 1) by the sigmoid activation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-dac33da1e6774a2191262f47ed0e75af", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the cosine of the elements of the input tensor.\nWrapper Entry Information: cos(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor.; Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\cos(\\text{input}_{i})\nother: The function computes the cosine of each element in the input tensor and returns a new tensor with these values.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-71b10a42ed2845888fc64dccdb2ee75c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining batch matrix multiplication, dropout, and GELU activation. It computes the batch matrix multiplication of two input tensors, applies dropout to the result, and then applies the GELU activation function.\nWrapper Entry Information: fused_bmm_dropout_gelu(input1, input2, p=0.5, training=True, inplace=False, approximate='none', *, out=None) -> Tensor\n- **input1** (Tensor): First input tensor for batch matrix multiplication, of shape (B, N, M), where B is the batch size.\n- **input2** (Tensor): Second input tensor for batch matrix multiplication, of shape (B, M, P).\n- **p** (float, optional): Probability of an element to be zeroed in the dropout layer. Default: `0.5`.\n- **training** (bool, optional): Apply dropout if `True`. Default: `True`.\n- **inplace** (bool, optional): If set to `True`, will perform the dropout operation in-place. Default: `False`.\n- **approximate** (str, optional): Can be `'none'` or `'tanh'`. The approximation to use for GELU. Default: `'none'`.\n- **out** (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: Given two input tensors X and Y, this function computes:\n\n\\[\n\\begin{align*}\nZ &= \\text{bmm}(X, Y) \\\\\nD &= \\text{Dropout}(Z, p) \\\\\nO &= \\text{GELU}(D)\n\\end{align*}\n\\]\n\nwhere:\n\n- \\text{bmm}(X, Y) performs batch matrix multiplication.\n- \\text{Dropout}(Z, p) randomly zeroes elements of Z with probability p.\n- \\text{GELU}(D) applies the Gaussian Error Linear Unit activation function element-wise to D.\nother: - The shapes of `input1` and `input2` must be compatible for batch matrix multiplication: `input1` of shape `(B, N, M)` and `input2` of shape `(B, M, P)` result in an output of shape `(B, N, P)`.\n- The `dropout` is applied during training when `training=True`. Set `training=False` to disable dropout during evaluation.\n- The `GELU` activation is applied element-wise to the output of dropout.\n- All operations are differentiable and support autograd.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-91f086bdba4744609803d8cf9b2ab3f3", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the truncated integer values of the elements of the input tensor. For integer inputs, it follows the array-api convention of returning a copy of the input tensor.\nWrapper Entry Information: trunc(input, *, out=None) -> Tensor\n\nArgs:\n input (Tensor): the input tensor.\n\nKeyword args:\n out (Tensor, optional): the output tensor.\nother: For integer inputs, follows the array-api convention of returning a copy of the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-eda30a668f9b445f864fa81b512aa3e3", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the matrix power A^k of a square matrix A using eigendecomposition. It relies on A being diagonalizable and computes the power through the equation A^k = V diag(Λ^k) V^(-1), where Λ and V are the eigenvalues and eigenvectors of A. It allows for fractional powers of matrices and supports real or complex exponents. If A is not diagonalizable, the result may not be accurate.\nWrapper Entry Information: def matrix_power_eig(A, k, *, out=None) -> Tensor\nArgs:\n A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of square matrices.\n k (float or complex): the exponent to which the matrix :attr:`A` is to be raised.\n\nKeyword args:\n out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.\nMath: A^k = V diag(Λ^k) V^{-1}, where A = V diag(Λ) V^{-1}, and Λ^k denotes the element-wise power of the eigenvalues.\nother: Supports input of float, double, cfloat, and cdouble dtypes. Also supports batches of matrices, output has the same batch dimensions. Note that the computed A^k may be complex even if A is real, due to complex eigenvalues. Warning: If A is not diagonalizable, the result may not be accurate. Gradients might be numerically unstable if the distance between any two eigenvalues is close to zero.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-7a073b2b19a54d098be7bbb0089c27cd", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the natural logarithm of each element in the input tensor, then applies the hyperbolic tangent (tanh) function to the result. This involves applying the logarithm first, which is only defined for positive numbers, and then applying tanh to transform the result between -1 and 1.\nWrapper Entry Information: def log_tanh(input, out=None) -> Tensor: input (Tensor): The input tensor. All elements must be positive for the log function. out (Tensor, optional): The output tensor.\nMath: \\text{out}_{i} = \\tanh(\\log(\\text{input}_{i}))\nother: All input elements must be positive for the logarithm function to be defined.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-88bb80e1e19f4e45974105bd5b4aa758", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the exponential of the elements of the input tensor.\nWrapper Entry Information: exp(input, *, out=None) -> Tensor\n input (Tensor): the input tensor.\n out (Tensor, optional): the output tensor.\nMath: y_{i} = e^{x_{i}}\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-63b594e894014f1cb17357d2ca37b053", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes two operations on matrix `C`: first, it performs the matrix-matrix product `C = alpha * torch.mm(A, B) + beta * C`, then updates `C` to be `C = alpha * torch.mm(C, C.T) + beta * C`. This function effectively performs two sequential matrix operations: a weighted sum of a matrix product and itself, followed by a weighted product of `C` and its transpose.\nWrapper Entry Information: matrix_multiply_symmetric(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor, alpha: float, beta: float) -> torch.Tensor; Args: A (Tensor): The first input matrix of shape `(n, m)`. B (Tensor): The second input matrix of shape `(m, p)`. C (Tensor): The target matrix for the operations, shape `(n, p)`. alpha (float): Scalar multiplier for matrix products. beta (float): Scalar multiplier for adding to `C`. Example: A = torch.tensor([[1.0, 2.0], [3.0, 4.0]]), B = torch.tensor([[0.5, -1.0], [1.5, 2.0]]), C = torch.tensor([[1.0, 0.0], [0.0, 1.0]]), alpha, beta = 2.0, 0.5, result = matrix_multiply_symmetric(A, B, C, alpha, beta)\nMath: C = alpha * torch.mm(A, B) + beta * C\nC = alpha * torch.mm(C, C.T) + beta * C\nother: This function performs a fused operation of matrix multiplication and symmetric update.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-480648b79e3d4207ac10bf110b90f31f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the cosine similarity between `x1` and `x2` along a specified dimension, adds a singleton dimension, and applies 2D average pooling. It first computes cosine similarity along dim=1 using `cosine_similarity`, then adds a singleton dimension using `unsqueeze`, and finally applies 2D average pooling using `avg_pool2d`.\nWrapper Entry Information: fused_avg_pool2d_cosine_similarity(x1: torch.Tensor, x2: torch.Tensor, kernel_size: int, stride: int = None, padding: int = 0, eps: float = 1e-8) -> torch.Tensor\nother: The function provides an optional `stride` parameter which defaults to the value of `kernel_size` if not provided. The `eps` parameter is used to prevent division by zero in cosine similarity.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-67eda67e084f415db53beb4402320699", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a fused operation consisting of dropout followed by hard shrinkage on the input tensor. The function first applies dropout to the input tensor, where each element is zeroed with a probability of p if training is True. The dropout can be applied in-place if specified. After dropout, a hard shrinkage operation is applied, which shrinks values towards zero based on the lambda parameter.\nWrapper Entry Information: def fused_hardshrink_dropout(input: torch.Tensor, p: float = 0.5, training: bool = True, inplace: bool = False, lambd: float = 0.5) -> torch.Tensor\nArgs:\n input (Tensor): The input tensor.\n p (float, optional): Probability of an element to be zeroed in dropout. Default is 0.5.\n training (bool, optional): Apply dropout if True. Default is True.\n inplace (bool, optional): If set to True, dropout will be applied in-place. Default is False.\n lambd (float, optional): The lambda parameter for the hard shrinkage function. Default is 0.5.\nReturns:\n Tensor: Result after applying dropout and then hard shrinkage on the input.\nother: The function combines dropout and hard shrinkage operations, which are typically used in neural network training to prevent overfitting and to enforce sparsity, respectively.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6584c3ee8b14474983d820e65a4742a4", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the complementary error function (erfc) and the square root of each element in the input tensor.\nWrapper Entry Information: def erfc_sqrt(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: The input tensor for which the erfc and square root are computed.\nMath: \\text{erfc}(x) = 1 - \\frac{2}{\\sqrt{\\pi}} \\int_{0}^{x} e^{-t^2} dt \\text{out}_{i} = \\sqrt{\\text{input}_{i}}\nother: Returns a tuple containing the erfc result and the square root result for each element in the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-438651ab55e5428daa39a47005a42e63", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns the reciprocal of the square root of the tensordot product of two tensors `a` and `b`. This function performs a tensor contraction of `a` and `b` over the specified dimensions using `torch.tensordot`, and then applies the element-wise reciprocal square root to the resulting tensor. The operation involves computing the tensordot product first and then applying the reciprocal of the square root element-wise to the result.\nWrapper Entry Information: def tensordot_rsqrt(a: torch.Tensor, b: torch.Tensor, dims) -> torch.Tensor: a (Tensor): Left tensor to contract. b (Tensor): Right tensor to contract. dims (int, Tuple[List[int], List[int]], or List[List[int]]): Dimensions for contraction, as per `torch.tensordot`.\nMath: \\text{output} = \\frac{1}{\\sqrt{\\sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \\times b_{k_0,...,k_{d-1}, i_d,...,i_n}}}\nother: The function applies the `torch.tensordot` and `torch.rsqrt` operations. The `dims` argument specifies the dimensions over which the contraction happens, similar to the `torch.tensordot` function.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-829de8149cf149d782ba0cbad32c09b5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the natural logarithm element-wise on the input tensor, followed by applying the softmax function along the specified dimension. This combined operation scales input values to a range between 0 and 1, summing to 1 after the logarithmic transformation. It allows transformation of the input tensor into a probability distribution.\nWrapper Entry Information: def softmax_log(input, dim=-1, dtype=None) -> Tensor: \nArgs:\n input (Tensor): The input tensor on which logarithm and softmax are applied.\n dim (int): The dimension along which softmax will be computed. Default: -1.\n dtype (:class:`torch.dtype`, optional): The desired data type of the returned tensor. If specified, the input tensor is cast to :attr:`dtype` before the operation is performed. Useful for preventing data type overflows. Default: None.\nExample:\n >>> import torch\n >>> import torch.nn.functional as F\n >>> # Define input tensor\n >>> input = torch.rand(3, 4) * 10\n >>> # Apply softmax_log along the last dimension\n >>> result = softmax_log(input, dim=1)\n >>> result\n tensor([[0.1829, 0.1782, 0.2783, 0.3606],\n [0.3119, 0.1724, 0.3256, 0.1900],\n [0.2057, 0.2166, 0.2991, 0.2786]])\n\n >>> # Apply softmax_log along a different dimension\n >>> result = softmax_log(input, dim=0)\n >>> result\n tensor([[0.3122, 0.4444, 0.2720, 0.2159],\n [0.3879, 0.2167, 0.4226, 0.2165],\n [0.2999, 0.3389, 0.3055, 0.5676]])\nMath: out = Softmax(log(input))\n\nwhere:\n\ny_{i} = \\frac{\\exp(\\log(x_{i}))}{\\sum_j \\exp(\\log(x_{j}))} = \\frac{x_i}{\\sum_j x_j}\nother: The function handles optional data type casting to prevent overflow and allows specifying the dimension for softmax application.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-a9aebe7cd5e741f9819610e210d594eb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a linear transformation followed by a sigmoid activation and dropout. This function sequentially applies a linear transformation to the input tensor, a sigmoid activation to scale the values between 0 and 1, and randomly zeroes some elements of the tensor with a specified probability during dropout.\nWrapper Entry Information: def dropout_sigmoid_linear(input: torch.Tensor, weight: torch.Tensor, bias=None, p=0.5, training=True, inplace=False) -> torch.Tensor: Input tensor of shape :math:`(*, \\text{in\\_features})`. Weight tensor of shape :math:`(\\text{out\\_features}, \\text{in\\_features})`. Bias tensor of shape :math:`(\\text{out\\_features})`. Default is `None`. Probability of an element to be zeroed in dropout. Default: 0.5 If `True`, applies dropout during training. Default: `True` If `True`, performs the operation in-place. Default: `False`\nother: The function applies dropout only if the `training` parameter is set to `True`. The `inplace` parameter allows for in-place operations to save memory.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6d7d7a1572de4ef19d1b20eeb4094268", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies Batch Normalization for each channel across a batch of data. Batch Normalization is a technique to improve the training of deep neural networks by ensuring that each layer receives whitened input, which helps to stabilize the learning process and reduce the number of training epochs needed to converge.\nWrapper Entry Information: def batch_norm(input, running_mean, running_var, weight=None, bias=None, training=False, momentum=0.1, eps=1e-05) -> Tensor\nother: This function is related to the BatchNorm classes like BatchNorm1d, BatchNorm2d, and BatchNorm3d, which are layers that handle this operation with additional features.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b211216562ce47218e4faefeb69a3284", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the natural logarithm of the absolute value of the gamma function on the input tensor.\nWrapper Entry Information: gammaln(input, *, out=None) -> Tensor\n\nArgs:\n input (Tensor): the input tensor.\n\nKeyword args:\n out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\ln \\Gamma(|\\text{input}_{i}|)\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-95acbc1a47824faaa34fb0d73a228b89", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the bitwise AND of input and other. The input tensor must be of integral or Boolean types. For bool tensors, it computes the logical AND.\nWrapper Entry Information: bitwise_and(input, other, *, out=None) -> Tensor; input: the first input tensor; other: the second input tensor; out (Tensor, optional): the output tensor.\nother: The input tensors must be of integral or Boolean types.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-e4a7846ad75646708b931b6639175bfd", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Subtracts 'other', scaled by 'alpha', from 'input', and then applies the Gaussian Error Linear Units (GELU) activation function to the result. The function supports two modes for GELU: exact and approximate using 'tanh'.\nWrapper Entry Information: def sub_gelu(input, other, alpha=1, approximate='none', out=None) -> Tensor: input (Tensor): The input tensor. other (Tensor or Number): The tensor or number to subtract from input. alpha (Number, optional): The multiplier for other. Default is 1. approximate (str, optional): The approximation method for GELU. Default is 'none'. out (Tensor, optional): The output tensor.\nMath: out_i = GELU(input_i - alpha * other_i)\n\nGELU(x) = x * Φ(x) when approximate is 'none'\nGELU(x) = 0.5 * x * (1 + Tanh(√(2/π) * (x + 0.044715 * x^3))) when approximate is 'tanh'\nother: The function allows for an optional output tensor and supports both exact and approximate GELU calculations.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-02cc469192bb4412938dede63a8eedda", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the Gaussian Error Linear Units (GELU) activation function to the elements of input, then computes the standard deviation along the specified dimension(s). The GELU function is applied element-wise to the input tensor, with an option to use an approximation method. After activation, the standard deviation of the result is calculated over specified dimensions, with options to keep reduced dimensions and apply a correction factor.\nWrapper Entry Information: def gelu_std(input, dim=None, keepdim=False, correction=1, approximate='none', out=None) -> Tensor: input (Tensor): The input tensor. dim (int or tuple of ints, optional): The dimension or dimensions to reduce. If None, computes over all dimensions. keepdim (bool, optional): Whether to retain the dimension(s) with size 1 after reduction. Default is False. correction (int, optional): The correction factor for standard deviation. Default is 1. approximate (str, optional): The approximation method for GELU. Default is 'none'. out (Tensor, optional): The output tensor.\nMath: GELU(x) = x * Φ(x) (when approximate is 'none')\nGELU(x) = 0.5 * x * (1 + Tanh(√(2/π) * (x + 0.044715 * x^3))) (when approximate is 'tanh')\nσ = √(1/(max(0, N - δN)) * Σ(x_i - x̄)^2)\nother: The function allows the use of a correction factor in the standard deviation calculation. It supports two methods for computing GELU: exact using CDF or approximate using a tanh-based formula.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-0ed62fee44d9485ea80491be353d9dc6", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs the same operation as torch.permute, which rearranges the dimensions of the input tensor according to the specified dims, but all output tensors are freshly created instead of aliasing the input.\nWrapper Entry Information: torch.permute_copy(input, dims) -> Tensor\nother: Freshly created output tensors mean that the function does not create views, so changes to the output will not affect the input.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1028736f1c1045d7ada072ce8e7b81a9", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the logarithmic derivative of the gamma function on input. This function is similar to SciPy's scipy.special.digamma. From PyTorch 1.8 onwards, the digamma function returns -Inf for 0, previously it returned NaN for 0.\nWrapper Entry Information: digamma(input, *, out=None) -> Tensor; Args: input (Tensor): the tensor to compute the digamma function on; Keyword args: out (Tensor, optional): the output tensor.\nMath: \\digamma(x) = \\frac{d}{dx} \\ln\\left(\\Gamma\\left(x\\right)\\right) = \\frac{\\Gamma'(x)}{\\Gamma(x)}\nother: This function is similar to SciPy's scipy.special.digamma. From PyTorch 1.8 onwards, the digamma function returns -Inf for 0, previously it returned NaN for 0.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-80ac379da8704c958ef03daed8d41b46", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the softmax function to the input tensor along the specified dimension, and then multiplies the softmaxed values by another tensor or number. The softmax function re-scales the elements so that they lie in the range [0, 1] and sum to 1 along the specified dimension.\nWrapper Entry Information: def softmax_mul(input, other, dim, dtype=None, out=None) -> Tensor: Applies the softmax function to the input tensor along the specified dimension, and then multiplies the softmaxed values by other. Args: input (Tensor): The input tensor to apply softmax on. other (Tensor or Number): The tensor or number to multiply with the softmaxed values. dim (int): The dimension along which softmax will be computed. dtype (torch.dtype, optional): The desired data type of returned tensor. If specified, the input tensor is cast to dtype before the operation. Default is None. out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = \\text{Softmax}(\\text{input}_i) \\times \\text{other}_i\n \\text{Softmax}(x_{i}) = \\frac{\\exp(x_i)}{\\sum_j \\exp(x_j)}\nother: Softmax re-scales the elements so that they lie in the range [0, 1] and sum to 1 along the specified dimension.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-fb3ffb7be7524d2494d8cc837084eb6a", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the bitwise AND operation between two tensors and then applies a Binomial distribution sampling based on the resulting tensor's values. First, it computes the bitwise AND of `input` and `other`. Then, the result is used as input for the Binomial distribution, with each element representing the number of trials with the probability specified in `probs` or `logits`.\nWrapper Entry Information: def bitwise_and_binomial(input: torch.Tensor, other: torch.Tensor, total_count: torch.Tensor, probs: torch.Tensor = None, logits: torch.Tensor = None) -> torch.Tensor: input (Tensor): The first input tensor of integral or Boolean type. other (Tensor): The second input tensor of integral or Boolean type. total_count (Tensor): Number of Bernoulli trials, must be broadcastable with `probs` or `logits`. probs (Tensor, optional): Event probabilities. Only one of `probs` or `logits` should be provided. logits (Tensor, optional): Event log-odds.\nMath: \\text{output} = \\text{Binomial}( \\text{bitwise\\_and}(\\text{input}, \\text{other}))\nother: The function performs a bitwise AND on the input tensors and uses the result to conduct Binomial sampling with specified probabilities or logits. The `total_count` indicates the number of Bernoulli trials, and either `probs` or `logits` must be provided for the Binomial distribution.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-352b77bb1ac149459fbcda6a1e61ec0c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function computes the conversion of angles from radians to degrees and calculates the square root for each element in the input tensor. It returns a tuple where the first element is the converted degrees and the second is the square root of the input tensor elements.\nWrapper Entry Information: def rad2deg_sqrt(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: Args: input (Tensor): The input tensor with angles in radians.\nMath: \\text{out}_{i} = \\text{input}_{i} \\times (180.0 / \\pi) \\text{out}_{i} = \\sqrt{\\text{input}_{i}}\nother: The function uses torch's rad2deg and sqrt functions to perform the operations.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-42e97cc21acd464bb7a9ec6323a4fe8c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Bessel function of the first kind of order 1 for each element of the input tensor.\nWrapper Entry Information: bessel_j1(input, *, out=None) -> Tensor Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: Bessel function of the first kind of order :math:`1`.\nother: The function supports an optional output tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-42f721c18cde485fb32fbe1e29128328", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the LU decomposition with partial pivoting of a matrix. If pivot=True, returns a permutation matrix P, a lower triangular matrix L, and an upper triangular matrix U such that A = PLU. If pivot=False and A is on GPU, computes the LU decomposition without pivoting, returning empty P, L and U such that A = LU. Supports float, double, cfloat, and cdouble dtypes, as well as batches of matrices. Outputs have the same batch dimensions as input.\nWrapper Entry Information: lu(A, *, pivot=True, out=None) -> (Tensor, Tensor, Tensor) Args: A (Tensor): tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions. pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting or no pivoting. Default: `True`. Keyword args: out (tuple, optional): output tuple of three tensors. Ignored if `None`. Default: `None`.\nMath: A = PLU where P is a permutation matrix, L is lower triangular with ones on the diagonal, U is upper triangular. If pivot=False, A = LU.\nother: LU decomposition is not unique; different platforms may yield different decompositions. Gradient computations are supported only if the matrix is full-rank.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-da2be421bd4f4679a35aab46c5608101", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the Gaussian Error Linear Units (GELU) activation function to each element in the input tensor, followed by computing the minimum value along the specified dimension. If no dimension is specified, it computes the minimum over all elements. The function supports two methods for computing GELU: exact ('none') and an approximation using 'tanh'.\nWrapper Entry Information: gelu_min(input, approximate='none', dim=None, keepdim=False, out=None) -> Tensor or (Tensor, LongTensor)\nMath: When approximate is 'none': GELU(x) = x * Φ(x), where Φ(x) is the Cumulative Distribution Function for Gaussian Distribution.\nWhen approximate is 'tanh': GELU(x) = 0.5 * x * (1 + Tanh(√(2/π) * (x + 0.044715 * x^3)))\nother: Returns a namedtuple (values, indices) if dim is specified, otherwise returns the minimum value tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-5f0ba656d54941d1b319a195df05031a", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function applies an affine transformation to the input tensor followed by grid sampling. It first generates a 2D flow field (sampling grid) based on the input affine matrix `theta` using `affine_grid`. Then it uses the generated grid to sample from the input image using `grid_sample`. It supports multiple interpolation modes (such as 'bilinear', 'nearest', and 'bicubic'), different padding modes ('zeros', 'border', 'reflection'), and has an option to align corners for transformation consistency.\nWrapper Entry Information: def grid_sample_with_affine(input: torch.Tensor, theta: torch.Tensor, size: torch.Size, mode: str = 'bilinear', padding_mode: str = 'zeros', align_corners: bool = False) -> torch.Tensor: Input tensor of shape (N, C, H_{in}, W_{in}) (4D). Affine transformation matrix of shape (N, 2, 3) for 2D transformations. Target output image size as a 4D size (N, C, H_{out}, W_{out}). Interpolation mode to calculate output values, 'bilinear', 'nearest', or 'bicubic'. Default is 'bilinear'. Defines how to handle grid values outside the input range. Options: 'zeros', 'border', 'reflection'. Default is 'zeros'. If True, aligns the grid to corner pixels for transformation consistency. Default is False.\nother: The function generates an affine transformation grid and applies grid sampling to the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-177d413e25474275bfcd9471c75cb895", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Moore-Penrose pseudoinverse of a matrix using Singular Value Decomposition (SVD). It decomposes the input matrix A into its singular value components, inverts the non-zero singular values above a certain threshold to avoid numerical instability, and reconstructs the pseudoinverse using these components. Supports input of float, double, cfloat, and cdouble dtypes, and can handle batches of matrices.\nWrapper Entry Information: def pseudoinverse_svd(A, *, full_matrices=True, rcond=1e-15, out=None) -> Tensor\n\nArgs:\n A (Tensor): Input tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.\n\nKeyword args:\n full_matrices (bool, optional): If `True` (default), compute the full SVD. If `False`, compute the reduced SVD.\n rcond (float, optional): Relative condition number threshold. Singular values smaller than `rcond * largest_singular_value` are set to zero. Default: `1e-15`.\n out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: A^{+} = V^{\\mathrm{H}} \\Sigma^{+} U^{\\mathrm{H}}; \\sigma_i^{+} = \\begin{cases} \\dfrac{1}{\\sigma_i}, & \\text{if } \\sigma_i > \\text{rcond} \\times \\sigma_{\\max} \\\\ 0, & \\text{otherwise} \\end{cases}\nother: Supports input of float, double, cfloat, and cdouble dtypes; Handles batches of matrices\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-685c416260624574b55e451f2644af7d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the exponential function to each element in the input tensor and then computes the mean value of the result along the specified dimension or over all elements if no dimension is specified.\nWrapper Entry Information: def exp_mean(input, dim=None, keepdim=False, dtype=None, out=None) -> Tensor\nMath: The combined operation is defined as: out = mean(e^{input}) where the exponential function is defined as: y_{i} = e^{x_{i}}\nother: The function first applies the exponential function to each element of the input tensor and then computes the mean of these exponential values. The function allows specifying dimensions to reduce, whether to keep dimensions, and the data type of the output.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8a70c2f4fded4de79b5c5303cc5dc73c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes a rank-k approximation of a matrix using its Singular Value Decomposition (SVD). The function retains the top-k singular values and corresponding singular vectors from the SVD of A to form the approximation Ak. This low-rank approximation minimizes the Frobenius norm of the difference between A and Ak among all rank-k matrices. Supports input of float, double, cfloat, and cdouble dtypes, and batches of matrices.\nWrapper Entry Information: def low_rank_svd_approximation(A, k, *, full_matrices=True, out=None) -> Tensor\n\nArgs:\n A (Tensor): Tensor of shape `(*, m, n)` where `*` is zero or more batch dimensions.\n k (int): Rank of the approximation (must satisfy `1 <= k <= min(m, n)`).\n full_matrices (bool, optional): Controls whether to compute the full or reduced SVD. Default: `True`.\n\nKeyword args:\n out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: A \\approx A_k = U_k \\Sigma_k V_k^{\\text{H}}; U_k \\in \\mathbb{K}^{m \\times k}; \\Sigma_k \\in \\mathbb{R}^{k \\times k}; V_k^{\\text{H}} \\in \\mathbb{K}^{k \\times n}\nother: Supports input of float, double, cfloat, and cdouble dtypes; Batches of matrices are supported.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1bcc9abd9154461cb857951cc82f2789", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns the minimum value of each row of the input tensor in the given dimension dim, along with the index location of each minimum value found. If keepdim is True, the output tensors retain the same size as input except in the dimension dim where they are of size 1. Otherwise, dim is squeezed, resulting in the output tensors having 1 fewer dimension than input. If there are multiple minimal values in a reduced row, the indices of the first minimal value are returned. The function can also compare two tensors element-wise and return a tensor with the minimum values.\nWrapper Entry Information: min(input, dim, keepdim=False, *, out=None) -> (Tensor, LongTensor) Args: input (Tensor): the input tensor. dim (int): the dimension to reduce. keepdim (bool): whether the output tensor has :attr:`dim` retained or not. Keyword args: out (tuple, optional): the tuple of two output tensors (min, min_indices)\nother: If there are multiple minimal values in a reduced row, the indices of the first minimal value are returned.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1a96aa0c423349ba95e3564d6c9e8c3d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a symmetric matrix multiplication by multiplying matrix `A` with its transpose, scales the result by `alpha`, adds it to matrix `C` scaled by `beta`, and returns the sum of the absolute values of the resulting matrix.\nWrapper Entry Information: symmetric_mm_and_abs_sum(A: torch.Tensor, C: torch.Tensor, alpha: float, beta: float) -> torch.Tensor\nArgs:\n A (Tensor): Input matrix of shape `(n, m)` for which the symmetric product with its transpose is calculated.\n C (Tensor): Matrix of the same shape as `alpha * torch.mm(A, A.T)` to accumulate the scaled result.\n alpha (float): Scaling factor for the matrix product.\n beta (float): Scaling factor for matrix `C`.\nReturns:\n Tensor: Scalar tensor representing the sum of absolute values of the resulting matrix `C`.\nMath: 1. `C = alpha * torch.mm(A, A.T) + beta * C`; 2. `asum = torch.sum(torch.abs(C))`\nother: Returns a scalar tensor representing the sum of absolute values of the resulting matrix `C`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-659c185115c548589643df14f1c77a25", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the determinant of a square matrix using LU decomposition. The function performs LU decomposition on a given square matrix A and calculates its determinant. It supports matrices over real or complex numbers and can handle batch dimensions. The determinant is computed as the product of the diagonal elements of the upper triangular matrix U from the LU decomposition, adjusted by the sign of the permutation matrix P if pivoting is used. The function assumes A is invertible and supports float, double, cfloat, and cdouble dtypes.\nWrapper Entry Information: determinant_lu(A, *, pivot=True, out=None) -> Tensor; A (Tensor): Tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of square matrices. pivot (bool, optional): Controls whether to compute the LU decomposition with partial pivoting (`True`) or without pivoting (`False`). Default: `True`. out (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: \\det(A) = \\det(P) \\cdot \\prod_{i=1}^{n} U_{ii}; When pivot=False: \\det(A) = \\prod_{i=1}^{n} U_{ii}\nother: This method assumes that A is invertible. If A is singular, the determinant will be zero, and the function may return `inf` or `nan` due to division by zero or numerical instability.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f074ea9a5243428bac40a55e25ce18fa", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a linear transformation to the input tensor followed by a Tanh activation function. This combined operation is useful for introducing non-linearity after a linear transformation, helping to capture complex relationships in the data.\nWrapper Entry Information: def tanh_linear(input, weight, bias=None) -> Tensor: input (Tensor): The input tensor of shape `(*, in_features)`, where `*` represents any number of additional dimensions. weight (Tensor): The weight matrix of shape `(out_features, in_features)`. bias (Tensor, optional): The optional bias tensor of shape `(out_features)`. Default: None.\nMath: The combined operation is defined as: out = tanh(linear(input, weight, bias)) where the linear transformation is applied as y = xA^T + b and Tanh activation is applied element-wise as: Tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))\nother: A linear transformation followed by a Tanh activation helps capture complex relationships by introducing non-linearity.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-e5046812327840df84cf151a4a410978", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns the sum of each row of the input tensor in the given dimension dim. If dim is a list of dimensions, reduce over all of them. If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed, resulting in the output tensor having 1 (or len(dim)) fewer dimension(s).\nWrapper Entry Information: def sum(input, dim, keepdim=False, *, dtype=None) -> Tensor; input (Tensor): the input tensor.; dim (int or tuple of ints, optional): the dimension or dimensions to reduce.; keepdim (bool): whether the output tensor has :attr:`dim` retained or not.; dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.\nother: If dim is a list of dimensions, reduce over all of them. If keepdim is True, the output tensor is of the same size as input except in the dimension(s) dim where it is of size 1. Otherwise, dim is squeezed.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-ac47cee255454660b25d893807c4731d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Creates a one-dimensional tensor of size 'steps' whose values are evenly spaced from base^start to base^end, inclusive, on a logarithmic scale with a specified base. The tensor values are generated in a logarithmic progression from base^start to base^end using the specified number of steps.\nWrapper Entry Information: logspace(start, end, steps, base=10.0, *, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False) -> Tensor\n\nArgs:\n start (float or Tensor): the starting value for the set of points. If `Tensor`, it must be 0-dimensional\n end (float or Tensor): the ending value for the set of points. If `Tensor`, it must be 0-dimensional\n steps (int): size of the constructed tensor\n base (float, optional): base of the logarithm function. Default: 10.0.\n\nKeyword arguments:\n out (Tensor, optional): the output tensor.\n dtype (torch.dtype, optional): the data type to perform the computation in. Default: if None, uses the global default dtype (see torch.get_default_dtype()) when both start and end are real, and corresponding complex dtype when either is complex.\n layout (torch.layout, optional): the desired layout of returned Tensor. Default: torch.strided.\n device (torch.device, optional): the desired device of returned tensor. Default: if None, uses the current device for the default tensor type (see torch.set_default_device). device will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.\n requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.\nMath: (\text{base}^{\text{start}}, \text{base}^{(\text{start} + \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \\ldots, \text{base}^{(\text{start} + (\text{steps} - 2) * \frac{\text{end} - \text{start}}{ \text{steps} - 1})}, \text{base}^{\text{end}})\nother: From PyTorch 1.11, the 'steps' argument is required. Use steps=100 to restore the previous behavior. The function allows specifying various properties of the output tensor such as dtype, layout, and device.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-31089932de764b6a93545a1ca1f976e5", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Solves the triangular system of linear equations Ax = b, where A is a triangular matrix. Then, adds a scaled version of the vector y to the solution x. The operations performed are: 1. Solve the triangular system Ax = b using torch.linalg.solve_triangular with A as an upper triangular matrix. 2. Add the scaled vector alpha * y to the solution x.\nWrapper Entry Information: def solve_and_add_scaled_vector(A: torch.Tensor, b: torch.Tensor, y: torch.Tensor, alpha: float) -> torch.Tensor: A (Tensor): A triangular matrix of shape `(n, n)`. b (Tensor): Right-hand side vector or matrix of shape `(n,)` or `(n, k)`. y (Tensor): Vector to be scaled and added, must have shape `(n,)` or broadcastable to `(n,)`. alpha (float): Scaling factor for the vector y.\nMath: x = torch.linalg.solve_triangular(A, b, upper=True)\nx += alpha * y\nother: The function assumes A is an upper triangular matrix.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-fabf0f38be3c48e385547bb1eb32ae71", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution followed by pixel shuffle upscaling to rearrange the spatial dimensions. This function sequentially applies a 2D convolution operation and then rearranges the elements of the convolution output to increase the spatial resolution by the upscale_factor.\nWrapper Entry Information: def pixel_shuffle_conv2d(input: torch.Tensor, weight: torch.Tensor, bias=None, stride=1, padding=0, dilation=1, groups=1, upscale_factor=2) -> torch.Tensor: Input tensor of shape (minibatch, in_channels, iH, iW). Convolution filter tensor of shape (out_channels, in_channels/groups, kH, kW). Optional bias tensor of shape (out_channels). Stride of the convolving kernel. Padding added to all four sides of the input. Spacing between kernel elements. Number of blocked connections from input channels to output channels. Factor by which to increase spatial resolution.\nother: The function first applies a 2D convolution and then uses pixel shuffle to upscale the spatial dimensions by the given upscale_factor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3b11d4629e6e4254acc225208c9959bb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the matrix-vector product `y = alpha * torch.mv(A, x) + beta * y` and then returns the dot product `torch.dot(y, x)`. The function first computes a scaled matrix-vector product and updates `y`, then calculates the dot product of the updated `y` with `x`. It requires an input matrix `A` of shape `(n, m)`, an input vector `x` of shape `(m,)`, and a target vector `y` of shape `(n,)` that is modified in-place. The scalar `alpha` is a multiplier for `torch.mv(A, x)`, while `beta` is a multiplier for `y`.\nWrapper Entry Information: def matrix_vector_dot(A: Tensor, x: Tensor, y: Tensor, alpha: float, beta: float) -> Tensor: \nArgs:\n A (Tensor): The input matrix of shape `(n, m)`.\n x (Tensor): The input vector of shape `(m,)`.\n y (Tensor): The target vector to be modified, of shape `(n,)`.\n alpha (float): Scalar multiplier for `torch.mv(A, x)`.\n beta (float): Scalar multiplier for `y`.\nMath: y = alpha * torch.mv(A, x) + beta * y; result = torch.dot(y, x)\nother: The function modifies the `y` vector in-place and calculates a dot product after the update.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f59fa2c7622c45649d2ebd96a1c9eef2", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Gaussian Error Linear Units (GELU) activation on the input tensor, then returns the minimum value along the specified dimension(s) or over all elements if no dimension is specified. The function supports two methods for computing GELU: exact and approximate using 'tanh'.\nWrapper Entry Information: min_gelu(input, dim=None, keepdim=False, approximate='none', out=None) -> Tensor: input (Tensor): The input tensor. dim (int, optional): The dimension to reduce. If ``None``, returns the minimum of all elements. keepdim (bool, optional): Whether the output tensor retains :attr:`dim` as size 1. Default is ``False``. approximate (str, optional): The approximation method for GELU. Default is 'none'. out (Tensor, optional): The output tensor.\nMath: out = min(GELU(input))\n\nGELU(x) = x * Φ(x) if approximate is 'none'\nGELU(x) = 0.5 * x * (1 + Tanh(√(2/π) * (x + 0.044715 * x^3))) if approximate is 'tanh'\nother: Returns a namedtuple (values, indices) if dim is specified, otherwise returns the minimum value tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-aa17cdc9ea3b4b9692480d221ed2437b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Takes the power of each element in input with exponent and returns a tensor with the result. exponent can be either a single float number or a Tensor with the same number of elements as input. If exponent is a scalar value, the operation applied is out_i = x_i ^ exponent. If exponent is a tensor, the operation applied is out_i = x_i ^ exponent_i. When exponent is a tensor, the shapes of input and exponent must be broadcastable.\nWrapper Entry Information: pow(input, exponent, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. exponent (float or tensor): the exponent value; Keyword args: out (Tensor, optional): the output tensor.\nMath: out_i = x_i ^ exponent (for scalar exponent)\nout_i = x_i ^ exponent_i (for tensor exponent)\nother: The operation supports both scalar and tensor exponents. When exponent is a tensor, its shape must be broadcastable with the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-fa93e89275484a3aa3306469ffc19232", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over the input tensor, followed by max pooling and then applies the ReLU activation function element-wise to the pooled result. This combined operation is often used in convolutional neural networks (CNNs) for feature extraction, downsampling, and adding non-linearity.\nWrapper Entry Information: relu_max_pool2d_conv2d(input, weight, bias=None, conv_stride=1, conv_padding=0, conv_dilation=1, conv_groups=1, pool_kernel_size=2, pool_stride=None, pool_padding=0, pool_dilation=1, pool_ceil_mode=False, inplace=False) -> Tensor: input (Tensor): The input tensor of shape `(minibatch, in_channels, iH, iW)`. weight (Tensor): The convolution filters of shape `(out_channels, in_channels / groups, kH, kW)`. bias (Tensor, optional): Optional bias tensor of shape `(out_channels)`. Default: None. conv_stride (int or tuple, optional): The stride of the convolution kernel. Default: 1. conv_padding (int, tuple, or string, optional): Padding added to all sides of the input in convolution. Default: 0. conv_dilation (int or tuple, optional): The spacing between kernel elements in convolution. Default: 1. conv_groups (int, optional): Number of blocked connections from input channels to output channels in convolution. Default: 1. pool_kernel_size (int or tuple): The size of the pooling region in max pooling. pool_stride (int or tuple, optional): The stride of the pooling operation. Default: `pool_kernel_size`. pool_padding (int or tuple, optional): Padding added to all sides of the input in max pooling. Default: 0. pool_dilation (int or tuple, optional): The stride between elements within a sliding window in max pooling. Default: 1. pool_ceil_mode (bool, optional): If True, uses `ceil` instead of `floor` to compute output shape. Default: False. inplace (bool, optional): If True, performs ReLU in-place. Default: False.\nMath: \n \\text{out} = \\text{ReLU}(\\text{MaxPool2D}(\\text{conv2d}(\\text{input})))\n\nwhere the ReLU function is applied element-wise as:\n\n \\text{ReLU}(x) = \\max(0, x)\nother: The function is typically used in CNNs.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-ba5e0d7afa334a6c9a9fb928e3e3a67b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the error function of the input tensor. The error function is used in probability, statistics, and partial differential equations describing diffusion.\nWrapper Entry Information: erf(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: \\mathrm{erf}(x) = \\frac{2}{\\sqrt{\\pi}} \\int_{0}^{x} e^{-t^2} dt\nother: The function outputs a tensor with values representing the error function of each element in the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f30359dfb1514b54a0560bb570006024", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: This function computes the sigmoid of the input tensor element-wise. The sigmoid function is a common activation function used in neural networks, which maps any real-valued number into the range (0, 1).\nWrapper Entry Information: sigmoid(input, *, out=None) -> Tensor\nMath: The sigmoid function is defined as: sigmoid(x) = 1 / (1 + exp(-x))\nother: Alias for torch.special.expit.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-17a9f84522e74f8d980c07ebfc722a6b", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the Gaussian Error Linear Unit (GELU) activation function element-wise to the input tensor. The function can be computed exactly or approximately using a tanh-based formula depending on the 'approximate' argument.\nWrapper Entry Information: gelu(input, approximate='none') -> Tensor\nMath: When approximate is 'none': GELU(x) = x * Φ(x), where Φ(x) is the Cumulative Distribution Function for Gaussian Distribution. When approximate is 'tanh': GELU(x) = 0.5 * x * (1 + Tanh(√(2/π) * (x + 0.044715 * x^3)))\nother: See Gaussian Error Linear Units (GELUs) https://arxiv.org/abs/1606.08415\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-47ab9b2df14a4716a755b572550d005c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the determinant of a square matrix. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions.\nWrapper Entry Information: linalg.det(A, *, out=None) -> Tensor; A (Tensor): tensor of shape (*, n, n) where * is zero or more batch dimensions; out (Tensor, optional): output tensor. Ignored if None. Default: None.\nother: :func:`torch.linalg.slogdet` computes the sign and natural logarithm of the absolute value of the determinant of square matrices.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-536ad43d80e44453b64fc5d527e231a1", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining batch matrix multiplication, RMS normalization, GELU activation, and dropout.\nWrapper Entry Information: fused_bmm_rmsnorm_gelu_dropout(input1, input2, normalized_shape, dropout_p=0.1, eps=1e-5, training=True, approximate='none', *, out=None) -> Tensor; input1 (Tensor): First input tensor for bmm, of shape (B, N, M), where B is the batch size; input2 (Tensor): Second input tensor for bmm, of shape (B, M, P); normalized_shape (int or list or torch.Size): Input shape from an expected input of size (B, N, P). This is the shape over which RMS normalization is applied; dropout_p (float, optional): Probability of an element to be zeroed in the dropout layer. Default: 0.1; eps (float, optional): A value added to the denominator for numerical stability in RMS normalization. Default: 1e-5; training (bool, optional): Apply dropout if True. Default: True; approximate (str, optional): Can be 'none' or 'tanh'. The approximation to use for GELU. Default: 'none'; out (Tensor, optional): Output tensor. Ignored if None. Default: None.\nMath: Given two input tensors X and Y, this function computes: \\[ \\begin{align*} Z_1 &= \\text{bmm}(X, Y) \\\\ Z_2 &= \\text{RMSNorm}(Z_1, \\epsilon) \\\\ Z_3 &= \\text{GELU}(Z_2) \\\\ Z &= \\text{Dropout}(Z_3, p) \\end{align*} \\] where: \\- \\text{bmm}(X, Y) performs batch matrix multiplication. \\- \\text{RMSNorm}(Z_1, \\epsilon) = \\frac{Z_1}{\\sqrt{\\text{mean}(Z_1^2, \\text{dim}=\\text{last}) + \\epsilon}} \\times \\gamma, where \\gamma is a learnable parameter (if `elementwise_affine=True`). \\- \\text{GELU}(Z_2) applies the Gaussian Error Linear Unit activation function element-wise. \\- \\text{Dropout}(Z_3, p) randomly zeroes elements of Z_3 with probability p.\nother: - The shapes of `input1` and `input2` must be compatible for batch matrix multiplication: `input1` of shape `(B, N, M)` and `input2` of shape `(B, M, P)` result in an output of shape `(B, N, P)`.\n- The `normalized_shape` argument for RMS normalization should match the shape of the last dimension(s) of the output tensor over which to compute the RMS.\n- The `GELU` activation is applied element-wise to the normalized output.\n- The `dropout` is applied during training when `training=True`. Set `training=False` to disable dropout during evaluation.\n- All operations are differentiable and support autograd.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-580130884817408d8c27ea57df9d733a", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the floor of the elements of the input, the largest integer less than or equal to each element. For integer inputs, follows the array-api convention of returning a copy of the input tensor.\nWrapper Entry Information: floor(input, *, out=None) -> Tensor\nArgs:\n input (Tensor): the input tensor.\n\nKeyword args:\n out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\left\\lfloor \\text{input}_{i} \\right\\rfloor\nother: For integer inputs, the function returns a copy of the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-172f014718f34f869824a75fdb9b3094", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a tensor filled with random numbers from a uniform distribution on the interval [0, 1). The shape of the tensor is defined by the variable argument size.\nWrapper Entry Information: rand(*size, *, generator=None, out=None, dtype=None, layout=torch.strided, device=None, requires_grad=False, pin_memory=False) -> Tensor\n\nArgs:\n size (int...): a sequence of integers defining the shape of the output tensor. Can be a variable number of arguments or a collection like a list or tuple.\n\nKeyword args:\n generator (:class:`torch.Generator`, optional): a pseudorandom number generator for sampling\n out (Tensor, optional): the output tensor.\n dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor. Default: if ``None``, uses a global default (see :func:`torch.set_default_dtype`).\n layout (:class:`torch.layout`, optional): the desired layout of returned Tensor. Default: ``torch.strided``.\n device (:class:`torch.device`, optional): the desired device of returned tensor. Default: if ``None``, uses the current device for the default tensor type (see :func:`torch.set_default_device`). :attr:`device` will be the CPU for CPU tensor types and the current CUDA device for CUDA tensor types.\n requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: ``False``.\n pin_memory (bool, optional): If set, returned tensor would be allocated in the pinned memory. Works only for CPU tensors. Default: ``False``.\nother: The function can take a variable number of arguments to define the shape of the tensor. It supports optional parameters for generator, output tensor, data type, layout, device, autograd recording, and pinned memory.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-a6cb6970cb9c4598aa966bc6942f93d8", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the solution of a system of linear equations with complex Hermitian or real symmetric positive-definite lhs given its Cholesky decomposition. Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if :math:`A` or :math:`B` is a batch of matrices then the output has the same batch dimensions.\nWrapper Entry Information: cholesky_solve(B, L, upper=False, *, out=None) -> Tensor; B (Tensor): right-hand side tensor of shape (*, n, k) where * is zero or more batch dimensions; L (Tensor): tensor of shape (*, n, n) where * is zero or more batch dimensions consisting of lower or upper triangular Cholesky decompositions of symmetric or Hermitian positive-definite matrices; upper (bool, optional): flag that indicates whether L is lower triangular or upper triangular. Default: False; out (Tensor, optional): output tensor. Ignored if None. Default: None\nMath: A = LL^{\\text{H}}; AX = B\nother: Supports float, double, cfloat, cdouble dtypes; Handles batches of matrices; Uses Cholesky decomposition\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-23d22a2a0b5949789ba007e8fd8e5f93", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Multiplies the input tensor by another tensor or number, then subtracts another tensor or number from the result, scaled by a given alpha. This operation is performed element-wise.\nWrapper Entry Information: def mul_sub(input, other_mul, other_sub, alpha=1, out=None) -> Tensor: input (Tensor): The input tensor to be multiplied. other_mul (Tensor or Number): The tensor or number to multiply with `input`. other_sub (Tensor or Number): The tensor or number to subtract from the multiplication result. alpha (Number, optional): The multiplier for :attr:`other_sub`. Default is 1. out (Tensor, optional): The output tensor.\nMath: \\text{out}_i = (\\text{input}_i \\times \\text{other\\_mul}_i) - \\text{alpha} \\times \\text{other\\_sub}_i\nother: The function allows for element-wise operations and supports both tensor and scalar inputs for multiplication and subtraction. The output can be stored in a specified tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-3a972c1556d2460ea5090fd8e1c73be6", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes a compact representation of the LDL factorization of a Hermitian or symmetric (possibly indefinite) matrix. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions. When A is complex valued it can be Hermitian (hermitian=True) or symmetric (hermitian=False). The factorization is of the form A = L D L^T. If hermitian is True then transpose operation is the conjugate transpose. L (or U) and D are stored in compact form in LD. They follow the format specified by LAPACK's sytrf function. These tensors may be used in torch.linalg.ldl_solve to solve linear systems.\nWrapper Entry Information: linalg.ldl_factor(A, *, hermitian=False, out=None) -> (Tensor, Tensor)\n\nArgs:\n A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of symmetric or Hermitian matrices.\n\nKeyword args:\n hermitian (bool, optional): whether to consider the input to be Hermitian or symmetric. For real-valued matrices, this switch has no effect. Default: `False`.\n out (tuple, optional): tuple of two tensors to write the output to. Ignored if `None`. Default: `None`.\n\nReturns:\n A named tuple `(LD, pivots)`.\nMath: A = L D L^T\nother: When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see torch.linalg.ldl_factor_ex.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-b5d21b3de60a4ba4a05b2d523e1ecc8f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the absolute value of each element in the input tensor.\nWrapper Entry Information: abs(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = |\\text{input}_{i}|\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-71baef9db7104be1819ab8f0c31187da", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Multiplies the input tensor by another tensor or a number, supporting broadcasting to a common shape, type promotion, and integer, float, and complex inputs.\nWrapper Entry Information: mul(input, other, *, out=None) -> Tensor\n input (Tensor): the input tensor.\n other (Tensor or Number) - the tensor or number to multiply input by.\n out (Tensor, optional): the output tensor.\nMath: \\text{out}_i = \\text{input}_i \\times \\text{other}_i\nother: Supports broadcasting and type promotion.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f51bac3ed24e40beb2f8d5041a140c84", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Apply a softmax function to all slices along the specified dimension, re-scaling them so that the elements lie in the range [0, 1] and sum to 1.\nWrapper Entry Information: def softmax(input, dim, dtype=None) -> Tensor: input (Tensor): input; dim (int): A dimension along which softmax will be computed.; dtype (torch.dtype, optional): the desired data type of returned tensor. If specified, the input tensor is casted to dtype before the operation is performed. This is useful for preventing data type overflows. Default: None.\nMath: Softmax(x_i) = exp(x_i) / sum_j exp(x_j)\nother: This function doesn't work directly with NLLLoss, which expects the Log to be computed between the Softmax and itself. Use log_softmax instead (it's faster and has better numerical properties).\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1e86017637da48a7a9803d5bfda9c102", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies the Leaky ReLU activation function element-wise to the input tensor. The function is defined as LeakyReLU(x) = max(0, x) + negative_slope * min(0, x), where negative_slope is a small constant that allows a small, non-zero gradient when the unit is not active.\nWrapper Entry Information: leaky_relu(input, negative_slope=0.01, inplace=False) -> Tensor\nMath: LeakyReLU(x) = max(0, x) + negative_slope * min(0, x)\nother: See torch.nn.LeakyReLU for more details.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-5f2629245b0141738177fbea44858a10", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the inverse of a square matrix using LU decomposition. Given a square invertible matrix A, it computes the inverse A^{-1} by performing LU decomposition and solving linear systems involving triangular matrices. Supports inputs of 'float', 'double', 'cfloat', and 'cdouble' dtypes, as well as batches of matrices.\nWrapper Entry Information: invert_matrix_lu(A, *, pivot=True, out=None) -> Tensor\nMath: A = P L U\nA^{-1} = U^{-1} L^{-1} P\nY = L^{-1} P\nA^{-1} = U^{-1} Y\nother: The function allows computing the inverse with or without pivoting (partial pivoting by default). It can handle batches of matrices, and an output tensor can be specified which will be ignored if set to None.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-9e58a2371fd64537bd6d437d98e33fdb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Calculates the standard deviation over the specified dimensions of the input tensor. The dim argument can specify a single dimension, a list of dimensions, or None to reduce over all dimensions. If keepdim is set to True, the output tensor retains the reduced dimensions as size 1; otherwise, these dimensions are removed. The correction parameter adjusts the calculation for the difference between sample size and degrees of freedom, defaulting to Bessel's correction with correction=1.\nWrapper Entry Information: def std(input, dim=None, *, correction=1, keepdim=False, out=None) -> Tensor: input (Tensor): the input tensor. dim (int or tuple of ints): the dimension or dimensions to reduce. correction (int): difference between the sample size and sample degrees of freedom. Defaults to `Bessel's correction`, correction=1. keepdim (bool): whether the output tensor has dim retained or not. out (Tensor, optional): the output tensor.\nMath: \\sigma = \\sqrt{\\frac{1}{\\max(0,~N - \\delta N)}\\sum_{i=0}^{N-1}(x_i-\\bar{x})^2}\nother: The standard deviation function has undergone a change in version 2.0, where the argument previously called unbiased has been renamed to correction. Bessel's correction link: https://en.wikipedia.org/wiki/Bessel%27s_correction\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-f359b4c150724486982dbf2f7f7bfee8", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a matrix multiplication of the lower triangular part of matrix `A` with matrix `B`, scales the result by `alpha`, and then scales the final output by `beta`. The operations are as follows:\n1. Perform matrix multiplication between the lower triangular part of `A` (denoted as `torch.tril(A)`) and `B`, and scale the result by `alpha`.\n2. Scale the resulting matrix from step 1 by `beta` to obtain the final result.\nWrapper Entry Information: def tril_mm_and_scale(A: torch.Tensor, B: torch.Tensor, alpha: float, beta: float) -> torch.Tensor: A (Tensor): A 2D matrix to be multiplied, of shape (n, n). B (Tensor): A matrix to be multiplied with the lower triangular part of A, of shape (n, p). alpha (float): Scaling factor for the initial matrix multiplication result. beta (float): Scaling factor for the final result.\nMath: B = alpha * torch.mm(torch.tril(A), B)\nC = beta * B\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-6855d52e8dc4451dbdadc700c03a6746", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the solution of a square system of linear equations with a unique solution. Supports inputs of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if the inputs are batches of matrices then the output has the same batch dimensions. Assumes that matrix A is invertible.\nWrapper Entry Information: A (Tensor), B (Tensor), *, left (bool, optional), out (Tensor, optional)\nMath: AX = B; XA = B\nother: This function computes `X = A.inverse() @ B` in a faster and more numerically stable way than performing the computations separately. When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see `torch.linalg.solve_ex`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8ea8849df9b24a91809cd8738fa3a5c9", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Airy function Ai for each element of the input tensor.\nWrapper Entry Information: airy_ai(input, *, out=None) -> Tensor Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: Airy function :math:`\\text{Ai}\\left(\\text{input}\\right)`.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-aecb03abd3124ad49388b16605028005", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Tests if each element of the input tensor has its sign bit set or not. It handles signed zeros, so negative zero (-0) returns True.\nWrapper Entry Information: signbit(input, *, out=None) -> Tensor; Args: input (Tensor): the input tensor.; Keyword args: out (Tensor, optional): the output tensor.\nother: signbit handles signed zeros, so negative zero (-0) returns True.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-c0f63db0a8d84d1da24213d03b505974", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes a scaled matrix-matrix product, then calculates the dot product of the first two rows of the resulting matrix. First, it multiplies matrix A and B using the scalar alpha and then adds the scaled version of matrix C using scalar beta. Finally, it computes the dot product of the first two rows of the updated matrix C.\nWrapper Entry Information: def matrix_multiply_and_row_dot(A: torch.Tensor, B: torch.Tensor, alpha: float, beta: float, C: torch.Tensor) -> torch.Tensor: A (Tensor): First input matrix of shape `(n, m)`. B (Tensor): Second input matrix of shape `(m, p)`. alpha (float): Scalar multiplier for the matrix-matrix product. beta (float): Scalar multiplier for the input matrix `C`. C (Tensor): Output matrix of shape `(n, p)` where the results are added.\nMath: 1. `C = alpha * torch.mm(A, B) + beta * C`; 2. `result = torch.dot(C[0], C[1])`\nother: Assumes `C` has at least two rows for the dot product to be computed.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-02c5ca5a4ca444d6a8aac3278647e2be", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the n-th derivative of the digamma function on input. The function is implemented for nonnegative integers n >= 0.\nWrapper Entry Information: def polygamma(n, input, *, out=None) -> Tensor: n (int): the order of the polygamma function; input (Tensor): the input tensor.; out (Tensor, optional): the output tensor.\nMath: \\psi^{(n)}(x) = \\frac{d^{(n)}}{dx^{(n)}} \\psi(x)\nother: Implemented only for nonnegative integers n >= 0.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-60b9ddf2ac9a4a34b1a2ae077afdf8f4", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a linear transformation to the input tensor, followed by the Exponential Linear Unit (ELU) activation function applied element-wise. This combined operation first performs a linear transformation and then introduces non-linearity with ELU.\nWrapper Entry Information: def elu_linear(input, weight, bias=None, alpha=1.0, inplace=False) -> Tensor: input (Tensor): The input tensor for the linear layer. weight (Tensor): The weight tensor for the linear transformation. bias (Tensor, optional): The bias tensor for the linear transformation. Default: None. alpha (float, optional): The \\(\\alpha\\) parameter for the ELU function. Default: 1.0. inplace (bool, optional): Whether to apply ELU in-place. Default: False.\nMath: \\text{out} = \\text{ELU}(\\text{Linear}(x))\n\n\\text{ELU}(x) = \\begin{cases}\n x, & \\text{ if } x > 0\\\\\n \\alpha * (\\exp(x) - 1), & \\text{ if } x \\leq 0\n\\end{cases}\nother: The function integrates linear transformation and ELU activation. The ELU activation applies element-wise to incorporate non-linearity after linear mapping.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-514b7dabc27a48b097098f6986cfac12", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the pairwise distance between two input tensors `x1` and `x2` after normalizing both tensors. Normalization is performed along the specified dimension, followed by pairwise distance calculation.\nWrapper Entry Information: def fused_pairwise_distance_normalize(x1: torch.Tensor, x2: torch.Tensor, p_norm: float = 2.0, eps_norm: float = 1e-12, eps_distance: float = 1e-6, keepdim: bool = False) -> torch.Tensor\nArgs:\n x1 (Tensor): First input tensor.\n x2 (Tensor): Second input tensor.\n p_norm (float, optional): The exponent value in the norm for normalization. Default: 2.\n eps_norm (float, optional): Small value to avoid division by zero during normalization. Default: 1e-12.\n eps_distance (float, optional): Small value to avoid division by zero in distance calculation. Default: 1e-6.\n keepdim (bool, optional): If `True`, retains the last dimension in the output. Default: `False`.\nother: Normalization is performed along the specified dimension. Small values `eps_norm` and `eps_distance` are used to avoid division by zero during normalization and distance calculation, respectively.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-e1e036a7a3c547bd8d5311a08a5b5997", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Implements the Adam optimization algorithm, which is an adaptive learning rate optimization algorithm designed for training deep neural networks. It computes individual adaptive learning rates for different parameters from estimates of first and second moments of the gradients. The algorithm can optionally use the AMSGrad variant, apply weight decay, and maximize the objective function. It supports various implementation optimizations like foreach and fused implementations for performance improvements on CUDA.\nWrapper Entry Information: def Adam(params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, amsgrad=False, foreach=None, maximize=False, capturable=False, differentiable=False, fused=None) -> Optimizer\nMath: m_t = \\beta_1 m_{t-1} + (1 - \\beta_1) g_t; v_t = \\beta_2 v_{t-1} + (1-\\beta_2) g^2_t; \\widehat{m_t} = m_t/(1-\\beta_1^t); \\widehat{v_t} = v_t/(1-\\beta_2^t); \\theta_t = \\theta_{t-1} - \\gamma \\widehat{m_t}/(\\sqrt{\\widehat{v_t}} + \\epsilon)\nother: The foreach and fused implementations are typically faster than the for-loop, single-tensor implementation. The algorithm is based on the paper 'Adam: A Method for Stochastic Optimization'.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-03bc8db2c11a462db455c5f133949ebb", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining horizontal stacking (hstack) and element-wise division. The function first horizontally stacks a sequence of tensors and then divides each element of the resulting tensor by the corresponding element of a divisor tensor, with optional rounding modes.\nWrapper Entry Information: fused_hstack_div(tensors, divisor, *, rounding_mode=None, out=None) -> Tensor\n\n- **tensors** (sequence of Tensors): Sequence of tensors to be horizontally stacked. The tensors must have compatible shapes for stacking.\n- **divisor** (Tensor or Number): The tensor or number to divide the stacked tensor by. Must be broadcastable to the shape of the stacked tensor.\n- **rounding_mode** (str, optional): Type of rounding applied to the result:\n - `None`: Default behavior. Performs no rounding and, if both `input` and `divisor` are integer types, promotes the inputs to the default scalar type. Equivalent to true division in Python (`/` operator).\n - `'trunc'`: Rounds the results of the division towards zero.\n - `'floor'`: Rounds the results of the division down.\n\n Default: `None`.\n\n- **out** (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: Given a sequence of tensors [X_1, X_2, \\dots, X_n] and a divisor tensor D, the function computes:\n\n1. **Horizontal Stacking:**\n\n\\[\nX = \\text{hstack}(X_1, X_2, \\dots, X_n)\n\\]\n\n2. **Element-wise Division:**\n\n\\[\nY = \\frac{X}{D}\n\\]\nother: - The tensors in `tensors` must have shapes that are compatible for horizontal stacking, i.e., the dimensions except for the stacking dimension must be the same.\n- The `divisor` tensor must be broadcastable to the shape of the stacked tensor.\n- The function supports autograd for gradient computation.\n- All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-373c9833aa20491b8d0f164413265869", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Broadcasts the given tensors according to broadcasting semantics. This function takes multiple tensors as input and broadcasts them to have the same shape. Broadcasting refers to expanding the dimensions of tensors as necessary to make them compatible for element-wise operations. The broadcasted tensors share the same memory location for their elements, leading to potential issues with in-place operations.\nWrapper Entry Information: broadcast_tensors(*tensors) -> List of Tensors: *tensors (Args: any number of tensors of the same type) -> Example: x = torch.arange(3).view(1, 3), y = torch.arange(2).view(2, 1), a, b = torch.broadcast_tensors(x, y), a.size() == torch.Size([2, 3]), a == tensor([[0, 1, 2],[0, 1, 2]])\nother: More than one element of a broadcasted tensor may refer to a single memory location. In-place operations may result in incorrect behavior. If writing to tensors is needed, clone them first.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8050c0195af44fc39f4be117c5679de6", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over an input tensor, followed by applying the rectified linear unit (ReLU) activation function element-wise on the result. This operation first applies a 2D convolution over the input tensor using the specified filters, and then applies ReLU activation to the convolution result, setting all negative values to zero.\nWrapper Entry Information: relu_conv2d(input, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, inplace=False) -> Tensor: input (Tensor): The input tensor of shape (minibatch, in_channels, iH, iW). weight (Tensor): The convolution filters of shape (out_channels, in_channels / groups, kH, kW). bias (Tensor, optional): Optional bias tensor of shape (out_channels). Default: None. stride (int or tuple, optional): The stride of the convolution kernel. Default: 1. padding (int, tuple, or string, optional): Padding added to all sides of the input. Default: 0. dilation (int or tuple, optional): The spacing between kernel elements. Default: 1. groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1. inplace (bool, optional): If True, will perform ReLU operation in-place. Default: False.\nMath: The operation is defined as: \\text{out} = \\text{ReLU}(\\text{conv2d}(\\text{input})), where \\text{ReLU}(x) = \\max(0, x).\nother: Returns:\n Tensor: A tensor resulting from the 2D convolution followed by ReLU activation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-1ff3bd3bc01f4b7ea2acce581e682d0a", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the natural logarithm of the elements of the input tensor.\nWrapper Entry Information: log(input, *, out=None) -> Tensor Args: input (Tensor): the input tensor. Keyword args: out (Tensor, optional): the output tensor.\nMath: y_{i} = \\log_{e} (x_{i})\nother: The function computes the natural logarithm (base e) of each element in the input tensor.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-e27f3597fb8242328141587a95027f22", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Apply a 2D adaptive average pooling over an input signal composed of several input planes. The output is of size H x W, for any input size. The number of output features is equal to the number of input planes. The target output size of the image can be a tuple (H, W) or a single H for a square image H x H. H and W can be either an int, or None which means the size will be the same as that of the input.\nWrapper Entry Information: def adaptive_avg_pool2d(output_size) -> Tensor\nArgs:\n output_size: the target output size (single integer or double-integer tuple)\n\nShape:\n - Input: (N, C, H_in, W_in) or (C, H_in, W_in)\n - Output: (N, C, S_0, S_1) or (C, S_0, S_1), where S=output_size\nother: The target output size can be a single integer for square images or a tuple for rectangular dimensions. H and W can be None to retain input dimensions.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8afb8e554ecf4aff97a0e4253c79e69c", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Converts a float model to a dynamic quantized model by replacing specified modules with their dynamic weight-only quantized versions. Provides simple usage with a dtype argument (either float16 or qint8), and fine-grained control with qconfig and mapping parameters. The process is performed in-place if specified, transforming the original model.\nWrapper Entry Information: quantize_dynamic(model, qconfig_spec=None, inplace=False, mapping=None) -> Model\nArgs:\n model: input model\n qconfig_spec: Either a dictionary mapping submodule names/types to quantization configurations or a set of types/names for dynamic quantization\n inplace: carry out model transformations in-place, mutating the original module\n mapping: maps submodule types to dynamically quantized versions\nother: Dynamic quantization is typically performed on layers with large weight sizes such as Linear and RNN variants. The qconfig_spec can be a dictionary mapping submodule types or names to quantization configurations, or a set specifying which submodules to apply dynamic quantization to. If qconfig is provided, it overrides dtype.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-116e236fa2714fae998822e835c5c7a1", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Applies a 2D convolution over an input image using specified filters and an optional bias, then adds another tensor or scalar to the convolution result, scaled by alpha. The input tensor shape is (minibatch, in_channels, iH, iW), and the weight tensor shape is (out_channels, in_channels / groups, kH, kW). The function also allows for setting the stride, padding, dilation, groups, and an optional output tensor.\nWrapper Entry Information: conv2d_add(input, weight, bias=None, other=None, stride=1, padding=0, dilation=1, groups=1, alpha=1, out=None) -> Tensor: input (Tensor): The input tensor of shape (minibatch, in_channels, iH, iW). weight (Tensor): The convolution filters of shape (out_channels, in_channels / groups, kH, kW). bias (Tensor, optional): Optional bias tensor of shape (out_channels). Default: None. other (Tensor or Number, optional): The tensor or number to add to the convolution result. Default: None. stride (int or tuple, optional): The stride of the convolution kernel. Can be a single number or a tuple (sH, sW). Default: 1. padding (int, tuple, or string, optional): Padding on both sides of the input. Can be 'valid', 'same', single number, or tuple (padH, padW). Default: 0. dilation (int or tuple, optional): The spacing between kernel elements. Default: 1. groups (int, optional): Number of groups to split the input into, must divide in_channels and out_channels. Default: 1. alpha (Number, optional): The multiplier for other. Default: 1. out (Tensor, optional): The output tensor.\nMath: \\text{out} = \\text{conv2d}(\\text{input}, \\text{weight}) + \\alpha \\times \\text{other}\nother: The 'groups' argument must divide both in_channels and out_channels. Padding can be specified as 'valid', 'same', a single number, or a tuple. The output tensor shape depends on convolution parameters.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-7bf6eff22256447782da374f4fb4ecd2", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: The function torch.fft.ifftshift is the inverse of torch.fft.fftshift. It rearranges the elements of the input tensor, which is in FFT order, such that the zero-frequency component is moved back to the original position. This is useful for preparing data for inverse FFT operations. The function can rearrange specified dimensions or all dimensions by default.\nWrapper Entry Information: ifftshift(input, dim=None) -> Tensor\n\nArgs:\n input (Tensor): the tensor in FFT order\n dim (int, Tuple[int], optional): The dimensions to rearrange.\n Only dimensions specified here will be rearranged, any other dimensions\n will be left in their original order.\n Default: All dimensions of input.\nother: Inverse of torch.fft.fftshift.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8bf05d5dc7404f4ea036130533e3578d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the sign bit check and the bitwise AND operation on the input tensors. `signbit` checks if the sign bit of each element in `input` is set, returning True for negative values, including -0. `bitwise_and` computes the bitwise AND between `input` and `other`, with the tensors needing to be of integral or boolean types.\nWrapper Entry Information: def signbit_bitwise_and(input: torch.Tensor, other: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n Args:\n input (Tensor): The input tensor.\n other (Tensor): The second tensor for bitwise AND, should be of integral or boolean types.\n Example:\n >>> a = torch.tensor([0.7, -1.2, 0., 2.3])\n >>> b = torch.tensor([1, 0, 1, 1], dtype=torch.int8)\n >>> signbit_result, bitwise_and_result = signbit_bitwise_and(a, b)\n >>> signbit_result\n tensor([False, True, False, False])\n >>> bitwise_and_result\n tensor([0, 0, 0, 0], dtype=torch.int8)\nother: The tensors need to be of integral or boolean types for the bitwise AND operation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-5cd43fb7d32f415590ca4dfa123762b6", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Performs a fused operation combining element-wise repeat interleave and log-softmax activation. First, the input tensor is repeated along the specified dimension according to the values in 'repeats'. Then, a log-softmax activation is applied to the repeated tensor along the specified dimension. This function is differentiable and supports autograd for gradient computation, making it useful for backpropagation in neural networks.\nWrapper Entry Information: fused_repeat_interleave_log_softmax(input, repeats, dim=None, *, output_size=None, dtype=None, out=None) -> Tensor\nMath: Given an input tensor X and repeats r, the function computes: 1. Repeat Interleave: The input tensor is repeated along the specified dimension: Y = repeat_interleave(X, r, dim). 2. Log-Softmax Activation: The log-softmax function is applied to the repeated tensor along the specified dimension: Z_i = log( exp(Y_i) / sum_j exp(Y_j) ) where the summation is over the specified dimension.\nother: The 'repeats' parameter controls how many times each element is repeated along the specified dimension. The 'dim' parameter specifies the dimension along which to repeat and apply log-softmax. If 'dim' is None, the input is flattened before repeating. All operations are differentiable and support backpropagation.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-c83a278744b94bd1a0ee2fdcf199989f", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the Cholesky decomposition of a complex Hermitian or real symmetric positive-definite matrix. Supports input of float, double, cfloat and cdouble dtypes. Also supports batches of matrices, and if A is a batch of matrices then the output has the same batch dimensions.\nWrapper Entry Information: def linalg.cholesky(A, *, upper=False, out=None) -> Tensor\n\nArgs:\n A (Tensor): tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions\n consisting of symmetric or Hermitian positive-definite matrices.\n\nKeyword args:\n upper (bool, optional): whether to return an upper triangular matrix.\n The tensor returned with upper=True is the conjugate transpose of the tensor\n returned with upper=False.\n out (Tensor, optional): output tensor. Ignored if `None`. Default: `None`.\nMath: A = LL^{\\text{H}} where L is a lower triangular matrix with real positive diagonal and L^{\\text{H}} is the conjugate transpose when L is complex, and the transpose when L is real-valued.\nother: When inputs are on a CUDA device, this function synchronizes that device with the CPU. For a version of this function that does not synchronize, see torch.linalg.cholesky_ex. Raises RuntimeError if the A matrix or any matrix in a batched A is not Hermitian (resp. symmetric) positive-definite.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-aca437779ac84f62bd511eadb6202c94", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a tensor filled with the scalar value 1, with the same size as the input tensor. It mirrors the properties of the input in terms of dtype, layout, device, and memory format unless specified otherwise. The function does not support the 'out' keyword as of version 0.4, and equivalent operation needs an alternative approach.\nWrapper Entry Information: ones_like(input, *, dtype=None, layout=None, device=None, requires_grad=False, memory_format=torch.preserve_format) -> Tensor; input (Tensor): the size of :attr:`input` will determine size of the output tensor.; dtype (torch.dtype, optional): the desired data type of returned Tensor. Default: if None, defaults to the dtype of :attr:`input`.; layout (torch.layout, optional): the desired layout of returned tensor. Default: if None, defaults to the layout of :attr:`input`.; device (torch.device, optional): the desired device of returned tensor. Default: if None, defaults to the device of :attr:`input`.; requires_grad (bool, optional): If autograd should record operations on the returned tensor. Default: False.; memory_format (torch.memory_format, optional): the desired memory format of returned Tensor. Default: torch.preserve_format.\nother: Function does not support an 'out' keyword as of version 0.4. Use torch.ones for similar functionality if 'out' keyword is needed.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-155a864dd7fb44d58adb6a1c60aa7949", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: The function `torch.cuda.amp.autocast` is deprecated and replaced by `torch.amp.autocast(\"cuda\", args...)`. It allows scripts to run in mixed precision, improving performance while maintaining accuracy. `autocast` serves as a context manager or decorator, wrapping the forward pass(es) of a network and any related loss computations. Tensors can be any type when entering an autocast region, and it is not necessary to manually cast models or inputs to `half()` or `bfloat16()`. The function selects op-specific data types for operations within an autocast region. Backward operations should not be run under autocast, as they execute in the same data type chosen for the corresponding forward operations.\nWrapper Entry Information: autocast(device_type, enabled=True, dtype=None, cache_enabled=True) -> ContextManager\nother: Deprecated in favor of torch.amp.autocast(\"cuda\"). Recommended to use for forward pass and loss computation only. Avoid using for backward passes. State is thread-local. Can be nested with `autocast(enabled=False)` to force a subregion to run in a specific dtype. The use of autocast in a new thread requires invoking the context manager or decorator in that thread.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-8159829c87d0461bb1cbf060d61fe800", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Returns a new tensor with the reciprocal of the elements of the input. Unlike NumPy's reciprocal, this function supports integral inputs by promoting them to the default scalar type.\nWrapper Entry Information: reciprocal(input, *, out=None) -> Tensor; input (Tensor): the input tensor.; out (Tensor, optional): the output tensor.\nMath: \\text{out}_{i} = \\frac{1}{\\text{input}_{i}}\nother: Integral inputs to reciprocal are automatically promoted to the default scalar type.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-bc963f2985a4493c9f0e747073930e5d", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the cosine of each element in the input tensor, followed by determining the sign bit for each cosine result, indicating if it is positive or negative.\nWrapper Entry Information: def cos_signbit(input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: Args: input (Tensor): The input tensor for which the cosine and sign bit are computed.\nMath: \\text{cos\\_result} = \\cos(\\text{input}) \\text{sign\\_bit} = \\text{signbit}(\\text{cos\\_result})\nother: Returns a tuple containing the cosine of each element and a boolean tensor indicating the sign bit of each cosine result.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-7b43064c8d5e4260a988dddb31dcfa46", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the spectral norm (operator norm induced by the Euclidean vector norm) of a square matrix using its eigenvalues. The spectral norm is the largest absolute value among the eigenvalues of a matrix. It supports inputs of float, double, cfloat, and cdouble dtypes and handles batches of matrices.\nWrapper Entry Information: spectral_norm_eig(A, *, out=None) -> Tensor\nA (Tensor): Tensor of shape `(*, n, n)` where `*` is zero or more batch dimensions consisting of square matrices.\nout (Tensor, optional): Output tensor. Ignored if `None`. Default: `None`.\nMath: \\|A\\|_2 = \\max \\{ |\\lambda| : \\lambda \\text{ is an eigenvalue of } A \\}\nother: For normal matrices (where A A^{H} = A^{H} A), the spectral norm equals the largest absolute eigenvalue.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + }, + { + "id": "openseek-8-54e0885a7e7a4f00bd14e3a05e53c090", + "input": "You are an expert in Trion programming, capable of writing corresponding Triton kernels and wrapper functions based on functional descriptions and function parameters. Ensure that the wrapper function fully corresponds to the provided function information.\nFunctional Description: Computes the N dimensional discrete Fourier transform of the input tensor. It returns all positive and negative frequency terms, even though for real inputs, half of these values are redundant. Supports torch.half and torch.chalf on CUDA with GPU Architecture SM53 or greater, but only for powers of 2 signal length in every transformed dimension.\nWrapper Entry Information: fftn(input, s=None, dim=None, norm=None, *, out=None) -> Tensor; input (Tensor): the input tensor; s (Tuple[int], optional): Signal size in the transformed dimensions. If given, each dimension dim[i] will either be zero-padded or trimmed to the length s[i] before computing the FFT. If a length -1 is specified, no padding is done in that dimension. Default: s = [input.size(d) for d in dim]; dim (Tuple[int], optional): Dimensions to be transformed. Default: all dimensions, or the last len(s) dimensions if s is given.; norm (str, optional): Normalization mode. For the forward transform (fftn), these correspond to: 'forward' - normalize by 1/n; 'backward' - no normalization; 'ortho' - normalize by 1/sqrt(n) (making the FFT orthonormal) Where n = prod(s) is the logical FFT size. Calling the backward transform (ifftn) with the same normalization mode will apply an overall normalization of 1/n between the two transforms. This is required to make ifftn the exact inverse. Default is 'backward' (no normalization).; out (Tensor, optional): the output tensor.\nother: The Fourier domain representation of any real signal satisfies the Hermitian property. torch.fft.rfftn returns the more compact one-sided representation where only the positive frequencies of the last dimension are returned.\nAfter generation, verify if the Triton wrapper aligns with the provided func_inputs. If not, regenerate." + } + ], + "License": "Apache 2.0 License" +} \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/outputs/README.md b/openseek/competition/LongContext-ICL-Annotation/outputs/README.md new file mode 100644 index 0000000..27f7c2b --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/outputs/README.md @@ -0,0 +1,55 @@ +## Submission Format (JSONL + ZIP) + +Below is the standard format for submitting model predictions. Please save your predictions into **eight `.jsonl` files** (one per task), then **package them into a single `.zip` archive** and upload it to the FlagOS platform for automatic evaluation. + +--- + +### 1) JSONL File Content + +Each `.jsonl` file consists of multiple JSON objects (**one prediction per line**). +Each prediction must contain the following two fields: + +- `test_sample_id`: corresponds to the sample `id` in the competition dataset. +- `prediction`: the model’s predicted result for that sample. + +**Single-line example:** +```json +{"test_sample_id":"openseek-1-ed5ac69191204cd4bfb0ca41bc7f197f","prediction":"..."} +``` + +### 2) ZIP Archive Requirements (Mandatory) +Each submission must upload **one** `.zip` file, and the archive must contain **8** prediction files: + +- Each filename must start with `openseek-[id]` (e.g., `openseek-1*.jsonl`) +- All **8** tasks correspond to **8** `.jsonl` files for automated scoring. + +> Recommendation: Make sure the `.zip` archive contains these **8** `.jsonl` files directly (no nested folders), and avoid including any unrelated extra files to prevent evaluation parsing issues. + +--- + +## 提交格式说明(JSONL + ZIP) + +以下为标准的模型预测结果提交规范。请将模型预测结果分别保存为 **8 个 `.jsonl` 文件**,并将它们 **打包为一个 `.zip` 压缩包** 后上传至 FlagOS 平台进行自动评测。 + +--- + +### 1) JSONL 文件内容格式 + +每个 `.jsonl` 文件由多行 JSON 对象组成(**一行一个预测结果**)。 +每条预测必须包含以下两个字段: + +- `test_sample_id`:对应赛题数据中的样本 `id` +- `prediction`:模型对该样本的预测结果 + +**单行示例:** +```json +{"test_sample_id":"openseek-1-ed5ac69191204cd4bfb0ca41bc7f197f","prediction":"..."} +``` + +### 2) ZIP 压缩包要求(必须满足) +每次提交需上传 **一个** `.zip` 文件,且该压缩包必须同时包含 **8** 个预测文件: + +- 文件名需以 `openseek-[id]` 开头(例如 `openseek-1*.jsonl`) +- 共 **8** 个任务各对应 **8** 个 `.jsonl` 文件,用于自动化评分 + +> 建议:确保压缩包内直接包含这 **8** 个 `.jsonl` 文件(不嵌套文件夹),并避免额外无关文件,以免影响评测解析。 \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/outputs/result.zip b/openseek/competition/LongContext-ICL-Annotation/outputs/result.zip new file mode 100644 index 0000000..9c5de26 Binary files /dev/null and b/openseek/competition/LongContext-ICL-Annotation/outputs/result.zip differ diff --git a/openseek/competition/LongContext-ICL-Annotation/src/README.md b/openseek/competition/LongContext-ICL-Annotation/src/README.md deleted file mode 100644 index 8b13789..0000000 --- a/openseek/competition/LongContext-ICL-Annotation/src/README.md +++ /dev/null @@ -1 +0,0 @@ - diff --git a/openseek/competition/LongContext-ICL-Annotation/src/api_test.py b/openseek/competition/LongContext-ICL-Annotation/src/api_test.py new file mode 100644 index 0000000..3d30993 --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/src/api_test.py @@ -0,0 +1,21 @@ +import requests + +url = "http://0.0.0.0:2026/v1/completions" +prompts = [ + "Hello, FlagScale + vLLM!", + "Translate 'Hello World' to Chinese.", + "Write a short poem about autumn." + # '用中文写一首短诗,诗句开头用包裹起来' +] + +for prompt in prompts: + data = { + "model": "../Qwen3-4B", + "prompt": prompt, + "max_tokens": 1000 + } + resp = requests.post(url, json=data) + print(f"Prompt: {prompt}") + print("Response:", resp.json(), "\n") + + print("*"*50) diff --git a/openseek/competition/LongContext-ICL-Annotation/src/create_env_nvidia.sh b/openseek/competition/LongContext-ICL-Annotation/src/create_env_nvidia.sh new file mode 100644 index 0000000..8861624 --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/src/create_env_nvidia.sh @@ -0,0 +1,46 @@ + +git clone https://github.com/FlagOpen/FlagScale.git +cd FlagScale + +source ~/miniconda3/etc/profile.d/conda.sh +conda create -n flagscale python=3.11.11 -y +conda activate flagscale + +pip install --upgrade setuptools + +pip --trusted-host pypi.tuna.tsinghua.edu.cn install torch==2.6.0 torchvision==0.21.0 torchaudio==2.6.0 --index-url https://download.pytorch.org/whl/cu124 + +pip install -r ./requirements/requirements-base.txt +pip install -r ./requirements/requirements-common.txt + +pip install deepspeed +pip3 install --no-build-isolation transformer_engine[pytorch]==2.6.0.post1 +pip install nvidia-cudnn-frontend + +cu=$(nvcc --version | grep "Cuda compilation tools" | awk '{print $5}' | cut -d '.' -f 1) +torch=$(pip show torch | grep Version | awk '{print $2}' | cut -d '+' -f 1 | cut -d '.' -f 1,2) +cp=$(python3 --version | awk '{print $2}' | awk -F. '{print $1$2}') +flash_attn_version="2.8.3" +echo "https://github.com/Dao-AILab/flash-attention/releases/download/v${flash_attn_version}/flash_attn-${flash_attn_version}+cu${cu}torch${torch}-cp${cp}-cp${cp}-linux_x86_64.whl" +wget --continue --timeout=60 --no-check-certificate --tries=5 --waitretry=10 https://github.com/Dao-AILab/flash-attention/releases/download/v${flash_attn_version}/flash_attn-${flash_attn_version}+cu${cu}torch${torch}-cp${cp}-cp${cp}-linux_x86_64.whl +flash_attn-${flash_attn_version}+cu${cu}torch${torch}-cp${cp}-cp${cp}-linux_x86_64.whl +# Recommend to download the wheel handly, for example flash_attn-2.8.3+cu12torch2.6cxx11abiFALSE-cp311-cp311-linux_x86_64 +pip install flash_attn-2.8.3+cu124torch2.6-cp311-cp311-linux_x86_64.whl + +# maybe slow, be patient +pip install --no-build-isolation "git+https://github.com/Dao-AILab/flash-attention.git@v2.7.2#egg=flashattn-hopper&subdirectory=hopper" + + +# Maybe slow too, be patient +pip install -r ./requirements/inference/requirements.txt +pip install vllm==0.8.5 +python tools/patch/unpatch.py --backend llama.cpp +python tools/patch/unpatch.py --backend omniinfer +python tools/patch/unpatch.py --backend Megatron-LM + +pip install build +pip install setuptools-scm +pip install "git+https://github.com/state-spaces/mamba.git@v2.2.4" + +pip install -r ./requirements/serving/requirements.txt +pip install --no-build-isolation git+https://github.com/FlagOpen/FlagGems.git@release_v1.0.0 \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/src/llm_config.yaml b/openseek/competition/LongContext-ICL-Annotation/src/llm_config.yaml new file mode 100644 index 0000000..94e8a2c --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/src/llm_config.yaml @@ -0,0 +1,30 @@ +serve: +- serve_id: vllm_model + engine: vllm + engine_args: + model: ../Qwen3-4B + host: 0.0.0.0 + uvicorn_log_level: warning + port: 2026 + gpu_memory_utilization: 0.9 + trust_remote_code: true + no_enable_prefix_caching: true + +experiment: + exp_name: qwen3_4b + exp_dir: outputs/${experiment.exp_name} + task: + type: serve + runner: + hostfile: null + deploy: + use_fs_serve: false + envs: + CUDA_VISIBLE_DEVICES: 0 + CUDA_DEVICE_MAX_CONNECTIONS: 1 + +action: run + +hydra: + run: + dir: ${experiment.exp_dir}/hydra \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/src/main.py b/openseek/competition/LongContext-ICL-Annotation/src/main.py new file mode 100644 index 0000000..c294979 --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/src/main.py @@ -0,0 +1,92 @@ +import json, os, argparse +from tqdm import tqdm, trange +from transformers import AutoTokenizer + +# from method import build_prompt, select_examples, annotate + +from method import build_prompt, select_examples + +from method import annotate_nvidia as annotate # For Nvidia GPU +# from method import annotate_ascend as annotate # For Huawei Ascend + +TASK_FILES = { + 1: './data/openseek-1_closest_integers.json', + 2: './data/openseek-2_count_nouns_verbs.json', + 3: './data/openseek-3_collatz_conjecture.json', + 4: './data/openseek-4_conala_concat_strings.json', + 5: './data/openseek-5_semeval_2018_task1_tweet_sadness_detection.json', + 6: './data/openseek-6_mnli_same_genre_classification.json', + 7: './data/openseek-7_jeopardy_answer_generation_all.json', + 8: '../data/openseek-8_kernel_generation.json', +} + +def parser_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--task_id', type=int, required=True, + help='Task ID to evaluate, should be in [1, 7].') + parser.add_argument('--max_input_length', type=int, default=10_000, + help='Maximum input length for the model.') + parser.add_argument('--log_path_prefix', type=str, + default='../outputs/', + help='Prefix path to save the evaluation logs.') + parser.add_argument('--tokenizer_path', type=str, + default='/share/project/wuhaiming/spaces/data_agent/OpenSeek-main/openseek/competition/LongContext-ICL-Annotation/src/Qwen3-4B') + args = parser.parse_args() + return args + +def evaluate(task_id:int, + qwen_tokenizer:AutoTokenizer, + max_input_length:int=128_000, + log_path_prefix:str='./outputs/' + )->float: + assert task_id in [i for i in range(1, 9)],\ + f"task_id should be in [1, 8], but got {task_id}." + + task_file = TASK_FILES[task_id] + with open(task_file, 'r') as f: + task_dict = json.load(f) + + task_name = task_dict['task_name'] + task_description = task_dict['Definition'][0] + icl_examples = task_dict['examples'][:100] + test_samples = task_dict['test_samples'] + + version = 1 + output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' + output_path = os.path.dirname(output_file) + os.makedirs(output_path, exist_ok=True) + while os.path.exists(output_file): + version += 1 + output_file = f'{log_path_prefix}openseek-{task_id}-v{version}.jsonl' + with open(output_file, 'w') as f: + pass + + examples_str = None + for test_sample in tqdm(test_samples, desc=f'Evaluation on Task {task_id}: {task_name}'): + test_record = dict() + + test_sample_id = test_sample['id'] + test_record['test_sample_id'] = test_sample_id + + + text2annotate = test_sample['input'] + prompt = build_prompt(task_description, text2annotate) + if examples_str is None: + examples_str = select_examples(icl_examples, task_description, text2annotate) + input_prompt = prompt.replace("[[EXAMPLES]]\n\n", examples_str+'\n\n') + + # tokenized_input = qwen_tokenizer(input_prompt, return_tensors="pt") + # if tokenized_input['input_ids'].shape[1] > max_input_length: + # test_record['prediction'] = None + # else: + # prediction = annotate(input_prompt) + # test_record['prediction'] = prediction + prediction = annotate(input_prompt) + test_record['prediction'] = prediction + with open(output_file, 'a') as f: + f.write(json.dumps(test_record)+'\n') + +if __name__ == '__main__': + args = parser_args() + qwen_tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_path) + evaluate(args.task_id, qwen_tokenizer, args.max_input_length, args.log_path_prefix) \ No newline at end of file diff --git a/openseek/competition/LongContext-ICL-Annotation/src/method.py b/openseek/competition/LongContext-ICL-Annotation/src/method.py new file mode 100644 index 0000000..386daf2 --- /dev/null +++ b/openseek/competition/LongContext-ICL-Annotation/src/method.py @@ -0,0 +1,277 @@ + +import re +from collections import Counter +from transformers import AutoTokenizer + +""" Here is an example of implementation of Long-Context Data Annotation. """ + +def build_prompt____(task_description: str, text2annotate: str) -> str: + """ + Build a high-precision English prompt for long-context data annotation (optimized for Qwen3-4B). + Core requirement: Final answer MUST be wrapped in