diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json
new file mode 100644
index 00000000..39bbd268
--- /dev/null
+++ b/.devcontainer/devcontainer.json
@@ -0,0 +1,4 @@
+{
+ "image": "mcr.microsoft.com/devcontainers/universal:2",
+ "features": {}
+}
diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS
new file mode 100644
index 00000000..d185e0a5
--- /dev/null
+++ b/.github/CODEOWNERS
@@ -0,0 +1,5 @@
+@openai/developer-experience
+dkundel-openai
+Maratyszcza
+scott-oai
+volsgd
diff --git a/LICENSE b/LICENSE
index 4ecba18e..d6456956 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,181 +1,182 @@
+
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
-TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-1. Definitions.
-
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
-
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
-
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
-
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
-
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
-2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
-3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
-4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
-5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
-6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
-7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
-8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
-9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
-END OF TERMS AND CONDITIONS
-
-APPENDIX: How to apply the Apache License to your work.
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
@@ -186,16 +187,16 @@ APPENDIX: How to apply the Apache License to your work.
same "printed page" as the copyright notice for easier
identification within third-party archives.
-Copyright 2025 OpenAI
+ Copyright [yyyy] [name of copyright owner]
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
\ No newline at end of file
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 00000000..7bd37930
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include _build *
\ No newline at end of file
diff --git a/README.md b/README.md
index b50dbc4f..0104cec4 100644
--- a/README.md
+++ b/README.md
@@ -1,36 +1,54 @@
-
-
-Try gpt-oss | Guides | Model card
-
Learn more about OpenAI's open models
-Download gpt-oss-120b and gpt-oss-20b on Hugging Face
+
+ Try gpt-oss ·
+ Guides ·
+ Model card ·
+ OpenAI blog
+
+ Download gpt-oss-120b and gpt-oss-20b on Hugging Face
+
+
-
Welcome to the gpt-oss series, [OpenAI's open-weight models](https://openai.com/open-models/) designed for powerful reasoning, agentic tasks, and versatile developer use cases.
-We're releasing two flavors of the open models:
+We're releasing two flavors of these open models:
-- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fits into a single H100 GPU (117B parameters with 5.1B active parameters)
+- `gpt-oss-120b` — for production, general purpose, high reasoning use cases that fit into a single 80GB GPU (like NVIDIA H100 or AMD MI300X) (117B parameters with 5.1B active parameters)
- `gpt-oss-20b` — for lower latency, and local or specialized use cases (21B parameters with 3.6B active parameters)
-Both models were trained on our [harmony response format][harmony] and should only be used with the harmony format as it will not work correctly otherwise.
+Both models were trained using our [harmony response format][harmony] and should only be used with this format; otherwise, they will not work correctly.
+
+## Table of Contents
+- [Highlights](#highlights)
+- [Inference examples](#inference-examples)
+- [About this repository](#about-this-repository)
+- [Setup](#setup)
+- [Download the model](#download-the-model)
+- [Reference PyTorch implementation](#reference-pytorch-implementation)
+- [Reference Triton implementation (single GPU)](#reference-triton-implementation-single-gpu)
+- [Reference Metal implementation](#reference-metal-implementation)
+- [Harmony format & tools](#harmony-format--tools)
+- [Clients](#clients)
+- [Tools](#tools)
+- [Other details](#other-details)
+- [Contributing](#contributing)
### Highlights
- **Permissive Apache 2.0 license:** Build freely without copyleft restrictions or patent risk—ideal for experimentation, customization, and commercial deployment.
- **Configurable reasoning effort:** Easily adjust the reasoning effort (low, medium, high) based on your specific use case and latency needs.
-- **Full chain-of-thought:** Gain complete access to the model's reasoning process, facilitating easier debugging and increased trust in outputs. It's not intended to be shown to end users.
+- **Full chain-of-thought:** Provides complete access to the model's reasoning process, facilitating easier debugging and greater trust in outputs. This information is not intended to be shown to end users.
- **Fine-tunable:** Fully customize models to your specific use case through parameter fine-tuning.
- **Agentic capabilities:** Use the models' native capabilities for function calling, [web browsing](#browser), [Python code execution](#python), and Structured Outputs.
-- **Native MXFP4 quantization:** The models are trained with native MXFP4 precision for the MoE layer, making `gpt-oss-120b` run on a single H100 GPU and the `gpt-oss-20b` model run within 16GB of memory.
+- **MXFP4 quantization:** The models were post-trained with MXFP4 quantization of the MoE weights, making `gpt-oss-120b` run on a single 80GB GPU (like NVIDIA H100 or AMD MI300X) and the `gpt-oss-20b` model run within 16GB of memory. All evals were performed with the same MXFP4 quantization.
### Inference examples
#### Transformers
-You can use `gpt-oss-120b` and `gpt-oss-20b` with Transformers. If you use Transformers's chat template it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package.
+You can use `gpt-oss-120b` and `gpt-oss-20b` with the Transformers library. If you use Transformers' chat template, it will automatically apply the [harmony response format][harmony]. If you use `model.generate` directly, you need to apply the harmony format manually using the chat template or use our [`openai-harmony`][harmony] package.
```python
from transformers import pipeline
@@ -60,7 +78,7 @@ print(outputs[0]["generated_text"][-1])
#### vLLM
-vLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible webserver. The following command will automatically download the model and start the server.
+vLLM recommends using [`uv`](https://docs.astral.sh/uv/) for Python dependency management. You can use vLLM to spin up an OpenAI-compatible web server. The following command will automatically download the model and start the server.
```bash
uv pip install --pre vllm==0.10.1+gptoss \
@@ -73,7 +91,83 @@ vllm serve openai/gpt-oss-20b
[Learn more about how to use gpt-oss with vLLM.](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
-#### Pytorch / Triton / Metal
+Offline Serve Code:
+- run this code after installing proper libraries as described, while additionally installing this:
+- `uv pip install openai-harmony`
+```python
+# source .oss/bin/activate
+
+import os
+os.environ["VLLM_USE_FLASHINFER_SAMPLER"] = "0"
+
+import json
+from openai_harmony import (
+ HarmonyEncodingName,
+ load_harmony_encoding,
+ Conversation,
+ Message,
+ Role,
+ SystemContent,
+ DeveloperContent,
+)
+
+from vllm import LLM, SamplingParams
+import os
+
+# --- 1) Render the prefill with Harmony ---
+encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
+
+convo = Conversation.from_messages(
+ [
+ Message.from_role_and_content(Role.SYSTEM, SystemContent.new()),
+ Message.from_role_and_content(
+ Role.DEVELOPER,
+ DeveloperContent.new().with_instructions("Always respond in riddles"),
+ ),
+ Message.from_role_and_content(Role.USER, "What is the weather like in SF?"),
+ ]
+)
+
+prefill_ids = encoding.render_conversation_for_completion(convo, Role.ASSISTANT)
+
+# Harmony stop tokens (pass to sampler so they won't be included in output)
+stop_token_ids = encoding.stop_tokens_for_assistant_actions()
+
+# --- 2) Run vLLM with prefill ---
+llm = LLM(
+ model="openai/gpt-oss-20b",
+ trust_remote_code=True,
+ gpu_memory_utilization = 0.95,
+ max_num_batched_tokens=4096,
+ max_model_len=5000,
+ tensor_parallel_size=1
+)
+
+sampling = SamplingParams(
+ max_tokens=128,
+ temperature=1,
+ stop_token_ids=stop_token_ids,
+)
+
+outputs = llm.generate(
+ prompt_token_ids=[prefill_ids], # batch of size 1
+ sampling_params=sampling,
+)
+
+# vLLM gives you both text and token IDs
+gen = outputs[0].outputs[0]
+text = gen.text
+output_tokens = gen.token_ids # <-- these are the completion token IDs (no prefill)
+
+# --- 3) Parse the completion token IDs back into structured Harmony messages ---
+entries = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
+
+# 'entries' is a sequence of structured conversation entries (assistant messages, tool calls, etc.).
+for message in entries:
+ print(f"{json.dumps(message.to_dict())}")
+```
+
+#### PyTorch / Triton / Metal
These implementations are largely reference implementations for educational purposes and are not expected to be run in production.
@@ -113,20 +207,21 @@ Check out our [awesome list](./awesome-gpt-oss.md) for a broader collection of g
This repository provides a collection of reference implementations:
- **Inference:**
- - [`torch`](#reference-pytorch-implementation) — a non-optimized [Pytorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4x H100s because it's not optimized
- - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [Pytorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching
+ - [`torch`](#reference-pytorch-implementation) — a non-optimized [PyTorch](https://pytorch.org/) implementation for educational purposes only. Requires at least 4× H100 GPUs due to lack of optimization.
+ - [`triton`](#reference-triton-implementation-single-gpu) — a more optimized implementation using [PyTorch](https://pytorch.org/) & [Triton](https://github.com/triton-lang/triton) incl. using CUDA graphs and basic caching
- [`metal`](#reference-metal-implementation) — a Metal-specific implementation for running the models on Apple Silicon hardware
- **Tools:**
- [`browser`](#browser) — a reference implementation of the browser tool the models got trained on
- [`python`](#python) — a stateless reference implementation of the python tool the model got trained on
- **Client examples:**
- - [`chat`](#terminal-chat) — a basic terminal chat application that uses the Pytorch or Triton implementations for inference along with the python and browser tools
+ - [`chat`](#terminal-chat) — a basic terminal chat application that uses the PyTorch or Triton implementations for inference along with the python and browser tools
- [`responses_api`](#responses-api) — an example Responses API compatible server that implements the browser tool along with other Responses-compatible functionality
## Setup
### Requirements
+- Python 3.12
- On macOS: Install the Xcode CLI tools --> `xcode-select --install`
- On Linux: These reference implementations require CUDA
- On Windows: These reference implementations have not been tested on Windows. Try using solutions like Ollama if you are trying to run the model locally.
@@ -148,7 +243,7 @@ If you want to modify the code or try the metal implementation set the project u
```shell
git clone https://github.com/openai/gpt-oss.git
-pip install -e .[metal]
+GPTOSS_BUILD_METAL=1 pip install -e ".[metal]"
```
## Download the model
@@ -157,20 +252,20 @@ You can download the model weights from the [Hugging Face Hub](https://huggingfa
```shell
# gpt-oss-120b
-huggingface-cli download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/
+hf download openai/gpt-oss-120b --include "original/*" --local-dir gpt-oss-120b/
# gpt-oss-20b
-huggingface-cli download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/
+hf download openai/gpt-oss-20b --include "original/*" --local-dir gpt-oss-20b/
```
## Reference PyTorch implementation
We include an inefficient reference PyTorch implementation in [gpt_oss/torch/model.py](gpt_oss/torch/model.py). This code uses basic PyTorch operators to show the exact model architecture, with a small addition of supporting tensor parallelism in MoE so that the larger model can run with this code (e.g., on 4xH100 or 2xH200). In this implementation, we upcast all weights to BF16 and run the model in BF16.
-To run the reference implementation. Install dependencies:
+To run the reference implementation, install the dependencies:
```shell
-pip install -e .[torch]
+pip install -e ".[torch]"
```
And then run:
@@ -192,9 +287,10 @@ git clone https://github.com/triton-lang/triton
cd triton/
pip install -r python/requirements.txt
pip install -e . --verbose --no-build-isolation
+pip install -e python/triton_kernels
# Install the gpt-oss triton implementation
-pip install -e .[triton]
+pip install -e ".[triton]"
```
And then run:
@@ -205,16 +301,16 @@ export PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True
python -m gpt_oss.generate --backend triton gpt-oss-120b/original/
```
-If you encounter `torch.OutOfMemoryError` make sure to turn on the expandable allocator to avoid crashes when loading weights from the checkpoint.
+If you encounter `torch.OutOfMemoryError`, make sure to turn on the expandable allocator to avoid crashes when loading weights from the checkpoint.
## Reference Metal implementation
-Additionally we are providing a reference implementation for Metal to run on Apple Silicon. This implementation is not production ready but is accurate to the Pytorch implementation.
+Additionally we are providing a reference implementation for Metal to run on Apple Silicon. This implementation is not production-ready but is accurate to the PyTorch implementation.
The implementation will get automatically compiled when running the `.[metal]` installation on an Apple Silicon device:
```shell
-pip install -e .[metal]
+GPTOSS_BUILD_METAL=1 pip install -e ".[metal]"
```
To perform inference you'll need to first convert the SafeTensor weights from Hugging Face into the right format using:
@@ -223,10 +319,11 @@ To perform inference you'll need to first convert the SafeTensor weights from Hu
python gpt_oss/metal/scripts/create-local-model.py -s -d
```
-Or downloaded the pre-converted weight:
+Or download the pre-converted weights:
+
```shell
-huggingface-cli download openai/gpt-oss-120b --include "metal/*" --local-dir gpt-oss-120b/metal/
-huggingface-cli download openai/gpt-oss-20b --include "metal/*" --local-dir gpt-oss-20b/metal/
+hf download openai/gpt-oss-120b --include "metal/*" --local-dir gpt-oss-120b/metal/
+hf download openai/gpt-oss-20b --include "metal/*" --local-dir gpt-oss-20b/metal/
```
To test it you can run:
@@ -245,7 +342,7 @@ We also include two system tools for the model: browsing and python container. C
### Terminal Chat
-The terminal chat application is a basic example on how to use the harmony format together with the Pytorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used.
+The terminal chat application is a basic example of how to use the harmony format together with the PyTorch, Triton, and vLLM implementations. It also exposes both the python and browser tool as optional tools that can be used.
```bash
usage: python -m gpt_oss.chat [-h] [-r REASONING_EFFORT] [-a] [-b] [--show-browser-results] [-p] [--developer-message DEVELOPER_MESSAGE] [-c CONTEXT] [--raw] [--backend {triton,torch,vllm}] FILE
@@ -274,7 +371,7 @@ options:
```
> [!NOTE]
-> The torch and triton implementation requires original checkpoint under `gpt-oss-120b/original/` and `gpt-oss-20b/original/` respectively. While vLLM uses the Hugging Face converted checkpoint under `gpt-oss-120b/` and `gpt-oss-20b/` root directory respectively.
+> The torch and triton implementations require original checkpoint under `gpt-oss-120b/original/` and `gpt-oss-20b/original/` respectively. While vLLM uses the Hugging Face converted checkpoint under `gpt-oss-120b/` and `gpt-oss-20b/` root directory respectively.
### Responses API
@@ -284,7 +381,7 @@ You can start this server with the following inference backends:
- `triton` — uses the triton implementation
- `metal` — uses the metal implementation on Apple Silicon only
-- `ollama` — uses the Ollama /api/generate API as a inference solution
+- `ollama` — uses the Ollama /api/generate API as an inference solution
- `vllm` — uses your installed vllm version to perform inference
- `transformers` — uses your installed transformers version to perform local inference
@@ -329,7 +426,7 @@ codex -p oss
### Browser
> [!WARNING]
-> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`ExaBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment.
+> This implementation is purely for educational purposes and should not be used in production. You should implement your own equivalent of the [`YouComBackend`](gpt_oss/tools/simple_browser/backend.py) class with your own browsing environment. Currently we have available `YouComBackend` and `ExaBackend`.
Both gpt-oss models were trained with the capability to browse using the `browser` tool that exposes the following three methods:
@@ -339,20 +436,25 @@ Both gpt-oss models were trained with the capability to browse using the `browse
#### Usage
-To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
+To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_browser_tool()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
```python
import datetime
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from openai_harmony import SystemContent, Message, Conversation, Role, load_harmony_encoding, HarmonyEncodingName
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
-# Exa backend requires you to have set the EXA_API_KEY environment variable
-backend = ExaBackend(
+# Depending on the choice of the browser backend you need corresponding env variables setup
+# In case you use You.com backend requires you to have set the YDC_API_KEY environment variable,
+# while for Exa you might need EXA_API_KEY environment variable set
+backend = YouComBackend(
source="web",
)
+# backend = ExaBackend(
+# source="web",
+# )
browser_tool = SimpleBrowserTool(backend=backend)
# create a basic system prompt
@@ -365,7 +467,7 @@ if use_browser_tool:
# enables the tool
system_message_content = system_message_content.with_tools(browser_tool.tool_config)
# alternatively you could use the following if your tool is not stateless
- system_message_content = system_message_content.with_browser()
+ system_message_content = system_message_content.with_browser_tool()
# construct the system message
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
@@ -393,20 +495,20 @@ if last_message.recipient.startswith("browser"):
#### Details
-To control the context window size this tool use a scrollable window of text that the model can interact with. So it might fetch the first 50 lines of a page and then scroll to the next 20 lines after that. The model has also been trained to then use citations from this tool in its answers.
+To control the context window size this tool uses a scrollable window of text that the model can interact with. So it might fetch the first 50 lines of a page and then scroll to the next 20 lines after that. The model has also been trained to then use citations from this tool in its answers.
To improve performance the tool caches requests so that the model can revisit a different part of a page without having to reload the page. For that reason you should create a new browser instance for every request.
### Python
-The model got trained on using a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony].
+The model was trained to use a python tool to perform calculations and other actions as part of its chain-of-thought. During the training the model used a stateful tool which makes running tools between CoT loops easier. This reference implementation, however, uses a stateless mode. As a result the PythonTool defines its own tool description to override the definition in [`openai-harmony`][harmony].
> [!WARNING]
> This implementation runs in a permissive Docker container which could be problematic in cases like prompt injections. It's serving as an example and you should consider implementing your own container restrictions in production.
#### Usage
-To enable the browser tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
+To enable the python tool, you'll have to place the definition into the `system` message of your harmony formatted prompt. You can either use the `with_python()` method if your tool implements the full interface or modify the definition using `with_tools()`. For example:
```python
import datetime
@@ -433,7 +535,7 @@ if use_python_tool:
system_message = Message.from_role_and_content(Role.SYSTEM, system_message_content)
# create the overall prompt
-messages = [system_message, Message.from_role_and_content(Role.USER, "What's the squareroot of 9001?")]
+messages = [system_message, Message.from_role_and_content(Role.USER, "What's the square root of 9001?")]
conversation = Conversation.from_messages(messages)
# convert to tokens
@@ -443,7 +545,7 @@ token_ids = encoding.render_conversation_for_completion(conversation, Role.ASSIS
# ...
# parse the output
-messages = messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
+messages = encoding.parse_messages_from_completion_tokens(output_tokens, Role.ASSISTANT)
last_message = messages[-1]
if last_message.recipient == "python":
# perform python call
@@ -463,10 +565,10 @@ if last_message.recipient == "python":
We released the models with native quantization support. Specifically, we use [MXFP4](https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf) for the linear projection weights in the MoE layer. We store the MoE tensor in two parts:
-- `tensor.blocks` stores the actual fp4 values. We pack every two value in one `uint8` value.
+- `tensor.blocks` stores the actual fp4 values. We pack every two values in one `uint8` value.
- `tensor.scales` stores the block scale. The block scaling is done among the last dimension for all MXFP4 tensors.
-All other tensors will be in BF16. We also recommend use BF16 as the activation precision for the model.
+All other tensors will be in BF16. We also recommend using BF16 as the activation precision for the model.
### Recommended Sampling Parameters
@@ -477,3 +579,17 @@ We recommend sampling with `temperature=1.0` and `top_p=1.0`.
The reference implementations in this repository are meant as a starting point and inspiration. Outside of bug fixes we do not intend to accept new feature contributions. If you build implementations based on this code such as new tool implementations you are welcome to contribute them to the [`awesome-gpt-oss.md`](./awesome-gpt-oss.md) file.
[harmony]: https://github.com/openai/harmony
+
+## Citation
+
+```bibtex
+@misc{openai2025gptoss120bgptoss20bmodel,
+ title={gpt-oss-120b & gpt-oss-20b Model Card},
+ author={OpenAI},
+ year={2025},
+ eprint={2508.10925},
+ archivePrefix={arXiv},
+ primaryClass={cs.CL},
+ url={https://arxiv.org/abs/2508.10925},
+}
+```
diff --git a/_build/gpt_oss_build_backend/__init__.py b/_build/gpt_oss_build_backend/__init__.py
new file mode 100644
index 00000000..2f46b29d
--- /dev/null
+++ b/_build/gpt_oss_build_backend/__init__.py
@@ -0,0 +1 @@
+"""In-tree PEP 517 backend package for gpt-oss."""
\ No newline at end of file
diff --git a/_build/gpt_oss_build_backend/backend.py b/_build/gpt_oss_build_backend/backend.py
new file mode 100644
index 00000000..5cd76bdf
--- /dev/null
+++ b/_build/gpt_oss_build_backend/backend.py
@@ -0,0 +1,140 @@
+"""
+Build backend for gpt-oss that supports two modes:
+
+1) Default (pure wheel for PyPI)
+ - Delegates to setuptools.build_meta.
+ - Produces a py3-none-any wheel so PyPI accepts it (no linux_x86_64 tag).
+
+2) Optional Metal/C extension build (local only)
+ - If the environment variable GPTOSS_BUILD_METAL is set to a truthy value
+ (1/true/on/yes), delegates to scikit_build_core.build.
+ - Dynamically injects build requirements (scikit-build-core, cmake, ninja,
+ pybind11) only for this mode.
+
+Why this is needed
+- PyPI rejects Linux wheels tagged linux_x86_64; manylinux/musllinux is required
+ for binary wheels. We ship a pure wheel by default, but still allow developers
+ to build/install the native Metal backend locally when needed.
+
+Typical usage
+- Publish pure wheel: `python -m build` (do not set GPTOSS_BUILD_METAL).
+- Local Metal dev: `GPTOSS_BUILD_METAL=1 pip install -e ".[metal]"`.
+- CI: keep GPTOSS_BUILD_METAL unset for releases; set it in internal jobs that
+ exercise the extension.
+
+Notes
+- The base package remains importable without the extension. The Metal backend
+ is only used when `gpt_oss.metal` is explicitly imported.
+- This file is discovered via `backend-path = ["_build"]` and
+ `build-backend = "gpt_oss_build_backend.backend"` in pyproject.toml.
+"""
+import os
+from importlib import import_module
+from typing import Any, Mapping, Sequence
+
+
+TRUE_VALUES = {"1", "true", "TRUE", "on", "ON", "yes", "YES"}
+
+
+def _use_metal_backend() -> bool:
+ return str(os.environ.get("GPTOSS_BUILD_METAL", "")).strip() in TRUE_VALUES
+
+
+def _setuptools_backend():
+ from setuptools import build_meta as _bm # type: ignore
+
+ return _bm
+
+
+def _scikit_build_backend():
+ return import_module("scikit_build_core.build")
+
+
+def _backend():
+ return _scikit_build_backend() if _use_metal_backend() else _setuptools_backend()
+
+
+# Required PEP 517 hooks
+
+def build_wheel(
+ wheel_directory: str,
+ config_settings: Mapping[str, Any] | None = None,
+ metadata_directory: str | None = None,
+) -> str:
+ return _backend().build_wheel(wheel_directory, config_settings, metadata_directory)
+
+
+def build_sdist(
+ sdist_directory: str, config_settings: Mapping[str, Any] | None = None
+) -> str:
+ return _backend().build_sdist(sdist_directory, config_settings)
+
+
+def prepare_metadata_for_build_wheel(
+ metadata_directory: str, config_settings: Mapping[str, Any] | None = None
+) -> str:
+ # Fallback if backend doesn't implement it
+ be = _backend()
+ fn = getattr(be, "prepare_metadata_for_build_wheel", None)
+ if fn is None:
+ # setuptools exposes it; scikit-build-core may not. Defer to building a wheel for metadata.
+ return _setuptools_backend().prepare_metadata_for_build_wheel(
+ metadata_directory, config_settings
+ )
+ return fn(metadata_directory, config_settings)
+
+
+# Optional hooks
+
+def build_editable(
+ editable_directory: str, config_settings: Mapping[str, Any] | None = None, metadata_directory: str | None = None
+) -> str:
+ be = _backend()
+ fn = getattr(be, "build_editable", None)
+ if fn is None:
+ # setuptools implements build_editable; if not available, raise the standard error
+ raise RuntimeError("Editable installs not supported by the selected backend")
+ return fn(editable_directory, config_settings)
+
+
+def get_requires_for_build_wheel(
+ config_settings: Mapping[str, Any] | None = None,
+) -> Sequence[str]:
+ if _use_metal_backend():
+ # Add dynamic build requirements only when building the Metal backend
+ return [
+ "scikit-build-core>=0.10",
+ "pybind11>=2.12",
+ "cmake>=3.26",
+ "ninja",
+ ]
+ # setuptools usually returns []
+ return list(_setuptools_backend().get_requires_for_build_wheel(config_settings))
+
+
+def get_requires_for_build_sdist(
+ config_settings: Mapping[str, Any] | None = None,
+) -> Sequence[str]:
+ # No special requirements for SDist
+ be = _backend()
+ fn = getattr(be, "get_requires_for_build_sdist", None)
+ if fn is None:
+ return []
+ return list(fn(config_settings))
+
+
+def get_requires_for_build_editable(
+ config_settings: Mapping[str, Any] | None = None,
+) -> Sequence[str]:
+ if _use_metal_backend():
+ return [
+ "scikit-build-core>=0.10",
+ "pybind11>=2.12",
+ "cmake>=3.26",
+ "ninja",
+ ]
+ be = _setuptools_backend()
+ fn = getattr(be, "get_requires_for_build_editable", None)
+ if fn is None:
+ return []
+ return list(fn(config_settings))
\ No newline at end of file
diff --git a/awesome-gpt-oss.md b/awesome-gpt-oss.md
index 37befa2a..8b82ebf8 100644
--- a/awesome-gpt-oss.md
+++ b/awesome-gpt-oss.md
@@ -10,6 +10,7 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [Cloud](#cloud)
- [Examples / Tutorials](#examples--tutorials)
- [Tools](#tools)
+- [Training](#training)
## Inference
@@ -25,36 +26,48 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [Use gpt-oss-120b with LM Studio](https://lmstudio.ai/models/openai/gpt-oss-120b)
- Hugging Face & Transformers
- [How to run gpt-oss with Transformers](https://cookbook.openai.com/articles/gpt-oss/run-transformers)
- - [Hugging Face & gpt-oss launch blog](http://huggingface.co/blog/welcome-openai-gpt-oss)
+ - [Hugging Face & gpt-oss launch blog](https://huggingface.co/blog/welcome-openai-gpt-oss)
- [Collection of Hugging Face examples](https://github.com/huggingface/gpt-oss-recipes)
- NVIDIA
- [gpt-oss on RTX](https://blogs.nvidia.com/blog/rtx-ai-garage-openai-oss)
+- AMD
+ - [Running gpt-oss models on AMD Ryzen AI Processors and Radeon Graphics Cards](https://www.amd.com/en/blogs/2025/how-to-run-openai-gpt-oss-20b-120b-models-on-amd-ryzen-ai-radeon.html)
+ - [Running gpt-oss on STX Halo and Radeon dGPUs using Lemonade](https://lemonade-server.ai/news/gpt-oss.html)
+- llama.cpp
+ - [Running gpt-oss with llama.cpp](https://github.com/ggml-org/llama.cpp/discussions/15396)
### Server
- vLLM
- [How to run gpt-oss with vLLM](https://cookbook.openai.com/articles/gpt-oss/run-vllm)
+ - [vLLM & gpt-oss recipies](https://docs.vllm.ai/projects/recipes/en/latest/OpenAI/GPT-OSS.html)
- NVIDIA
- - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/gpt-oss/run-nvidia)
- - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog_9_Deploying_GPT_OSS_on_TRTLLM.md)
+ - [Optimizing gpt-oss with NVIDIA TensorRT-LLM](https://cookbook.openai.com/articles/run-nvidia)
+ - [Deploying gpt-oss on TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM/blob/main/docs/source/blogs/tech_blog/blog9_Deploying_GPT_OSS_on_TRTLLM.md)
+- AMD
+ - [Running the Latest Open Models from OpenAI on AMD AI Hardware](https://rocm.blogs.amd.com/ecosystems-and-partners/openai-day-0/README.html)
### Cloud
- Groq
- - [Groq & gpt-oss launch blog](http://groq.com/day-zero-support-for-openai-open-model)
+ - [Groq & gpt-oss launch blog](https://groq.com/blog/day-zero-support-for-openai-open-models)
- [gpt-oss-120b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-120b)
- [gpt-oss-20b model on the GroqCloud Playground](https://console.groq.com/playground?model=openai/gpt-oss-20b)
- [gpt-oss with built-in web search on GroqCloud](https://console.groq.com/docs/browser-search)
- - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution)
- - [Responses API on Groq](https://console.groq.com/docs/responses)
+ - [gpt-oss with built-in code execution on GroqCloud](https://console.groq.com/docs/code-execution)
+ - [Responses API on Groq](https://console.groq.com/docs/responses-api)
- NVIDIA
- [NVIDIA launch blog post](https://blogs.nvidia.com/blog/openai-gpt-oss/)
- [NVIDIA & gpt-oss developer launch blog post](https://developer.nvidia.com/blog/delivering-1-5-m-tps-inference-on-nvidia-gb200-nvl72-nvidia-accelerates-openai-gpt-oss-models-from-cloud-to-edge/)
- Use [gpt-oss-120b](https://build.nvidia.com/openai/gpt-oss-120b) and [gpt-oss-20b](https://build.nvidia.com/openai/gpt-oss-20b) on NVIDIA's Cloud
- Cloudflare
- - [Cloudflare & gpt-oss launch blog post](http://blog.cloudflare.com/openai-gpt-oss-on-workers-ai)
+ - [Cloudflare & gpt-oss launch blog post](https://blog.cloudflare.com/openai-gpt-oss-on-workers-ai)
- [gpt-oss-120b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-120b)
- [gpt-oss-20b on Cloudflare Workers AI](https://developers.cloudflare.com/workers-ai/models/gpt-oss-20b)
+- AMD
+ - [gpt-oss-120B on AMD MI300X](https://huggingface.co/spaces/amd/gpt-oss-120b-chatbot)
+- AWS (Deploy via Tensorfuse)
+ - [Deploy gpt-oss for both 20b and 120b models on AWS EKS](https://tensorfuse.io/docs/guides/modality/text/openai_oss)
## Examples & Tutorials
@@ -65,6 +78,12 @@ This is a list of guides and resources to help you get started with the gpt-oss
- [Example `python` tool for gpt-oss](./gpt_oss/tools/python_docker/)
- [Example `browser` tool for gpt-oss](./gpt_oss/tools/simple_browser/)
+## Training
+
+- [Hugging Face TRL examples](https://github.com/huggingface/gpt-oss-recipes)
+- [LlamaFactory examples](https://llamafactory.readthedocs.io/en/latest/advanced/best_practice/gpt-oss.html)
+- [Unsloth examples](https://docs.unsloth.ai/basics/gpt-oss-how-to-run-and-fine-tune)
+
## Contributing
Feel free to open a PR to add your own guides and resources on how to run gpt-oss. We will try to review it and add it here.
diff --git a/compatibility-test/.gitignore b/compatibility-test/.gitignore
new file mode 100644
index 00000000..2ba323b0
--- /dev/null
+++ b/compatibility-test/.gitignore
@@ -0,0 +1,142 @@
+# Logs
+logs
+*.log
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+*.lcov
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+node_modules/
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+web_modules/
+
+# TypeScript cache
+*.tsbuildinfo
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Optional stylelint cache
+.stylelintcache
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+*.tgz
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variable files
+.env
+.env.*
+!.env.example
+
+# parcel-bundler cache (https://parceljs.org/)
+.cache
+.parcel-cache
+
+# Next.js build output
+.next
+out
+
+# Nuxt.js build / generate output
+.nuxt
+dist
+
+# Gatsby files
+.cache/
+# Comment in the public line in if your project uses Gatsby and not Next.js
+# https://nextjs.org/blog/next-9-1#public-directory-support
+# public
+
+# vuepress build output
+.vuepress/dist
+
+# vuepress v2.x temp and cache directory
+.temp
+.cache
+
+# Sveltekit cache directory
+.svelte-kit/
+
+# vitepress build output
+**/.vitepress/dist
+
+# vitepress cache directory
+**/.vitepress/cache
+
+# Docusaurus cache and generated files
+.docusaurus
+
+# Serverless directories
+.serverless/
+
+# FuseBox cache
+.fusebox/
+
+# DynamoDB Local files
+.dynamodb/
+
+# Firebase cache directory
+.firebase/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v3
+.pnp.*
+.yarn/*
+!.yarn/patches
+!.yarn/plugins
+!.yarn/releases
+!.yarn/sdks
+!.yarn/versions
+
+# Vite logs files
+vite.config.js.timestamp-*
+vite.config.ts.timestamp-*
+
+rollout_*.jsonl
+analysis_*.json
\ No newline at end of file
diff --git a/compatibility-test/README.md b/compatibility-test/README.md
new file mode 100644
index 00000000..22e0007f
--- /dev/null
+++ b/compatibility-test/README.md
@@ -0,0 +1,29 @@
+# API Compatibility Test
+
+This script uses the Agents SDK in TypeScript and the underlying OpenAI client to verify the shape of the API calls but also whether the API performs tool calling.
+
+## What it tests
+
+1.
+
+## How to run
+
+0. Run `npm install` in this directory.
+1. Update `providers.ts` to create an entry for the API to test. Change `vllm` to the provider name of your choice. Use `chat` for Chat Completions tests and `responses` for Responses API tests.
+2. Run an initial quick test to make sure things work. This will only run one test
+
+```
+npm start -- --provider -n 1 -k 1
+```
+
+3. Run the full test (runs each test 5 times to test consistency)
+
+```
+npm start -- --provider -k 5
+```
+
+## Considerations
+
+1. The tests will fail if the API shape does not match the expected behavior
+2. Events in the chat API are currently not tested
+3. If the schema validation succeeds but the input is wrong the test will still pass for this test. That's because it's likely more of a prompt engineering issue or a validator issue than an API issue as it still nailed the input
diff --git a/compatibility-test/analysis.ts b/compatibility-test/analysis.ts
new file mode 100644
index 00000000..9c5cf97d
--- /dev/null
+++ b/compatibility-test/analysis.ts
@@ -0,0 +1,142 @@
+export function analyze(caseResults: any[], tries: number) {
+ // Group results by unique task: test_case + apiType
+ type TaskKey = string;
+ const taskKeyFor = (r: any): TaskKey =>
+ `${r.test_case}::${r.result?.apiType}`;
+
+ const successesByTask: Map> = new Map();
+
+ // Count wrong-input tool calls (schema correct but incorrect arguments)
+ let wrongInputToolCalls = 0;
+
+ // Count invalid response shapes per API type
+ const totalByApiType: Record = {};
+ const invalidByApiType: Record = {};
+
+ for (const r of caseResults) {
+ if (!r?.result || typeof r.result.apiType !== "string") continue;
+
+ // Parse attempt index from run_id `${i}_${k}` safely
+ let attemptIndex: number | undefined;
+ if (typeof r.run_id === "string") {
+ const parts = r.run_id.split("_");
+ const k = Number(parts[1]);
+ if (Number.isFinite(k)) attemptIndex = k;
+ }
+
+ const key = taskKeyFor(r);
+ if (!successesByTask.has(key)) successesByTask.set(key, new Map());
+ if (attemptIndex != null) {
+ successesByTask.get(key)!.set(attemptIndex, Boolean(r.success));
+ }
+
+ const d = r.result.toolCallingDetails ?? {};
+ const calledToolAtLeastOnce = Boolean(d.calledToolAtLeastOnce);
+ const calledToolWithRightSchema = Boolean(d.calledToolWithRightSchema);
+ const calledToolWithRightArguments = Boolean(
+ d.calledToolWithRightArguments
+ );
+ if (
+ calledToolAtLeastOnce &&
+ calledToolWithRightSchema &&
+ !calledToolWithRightArguments
+ ) {
+ wrongInputToolCalls++;
+ }
+
+ // Track invalid/total per apiType for response shape
+ const apiType = r.result.apiType as string;
+ totalByApiType[apiType] = (totalByApiType[apiType] ?? 0) + 1;
+ const isValidResponse = r.result.validResponse === true;
+ if (!isValidResponse) {
+ invalidByApiType[apiType] = (invalidByApiType[apiType] ?? 0) + 1;
+ }
+ }
+
+ const totalTasks = successesByTask.size;
+
+ // Compute pass@k and pass^k for k = 1..tries
+ const passAtKByK: number[] = [];
+ const passHatKByK: number[] = [];
+
+ for (let k = 1; k <= tries; k++) {
+ let tasksSuccessfulK = 0; // any success in first k attempts
+ let tasksAllSuccessfulK = 0; // all success in first k attempts
+
+ for (const [, attemptsMap] of successesByTask) {
+ let anySuccess = false;
+ let allSuccess = true;
+ for (let i = 0; i < k; i++) {
+ const v = attemptsMap.get(i) === true;
+ anySuccess = anySuccess || v;
+ if (!v) allSuccess = false;
+ }
+ if (anySuccess) tasksSuccessfulK++;
+ if (allSuccess) tasksAllSuccessfulK++;
+ }
+
+ const passAtK = totalTasks > 0 ? tasksSuccessfulK / totalTasks : 0;
+ const passHatK = totalTasks > 0 ? tasksAllSuccessfulK / totalTasks : 0;
+ passAtKByK.push(passAtK);
+ passHatKByK.push(passHatK);
+ }
+
+ // Convenience: final k=tries values
+ const passAtK = passAtKByK[tries - 1] ?? 0;
+ const passHatK = passHatKByK[tries - 1] ?? 0;
+
+ return {
+ totalTasks,
+ passAtKByK,
+ passHatKByK,
+ passAtK,
+ passHatK,
+ wrongInputToolCalls,
+ // New stats for invalid response shapes per API
+ invalidByApiType,
+ totalByApiType,
+ };
+}
+
+export function printAnalysis(
+ stats: ReturnType,
+ caseResults: any[],
+ provider: string,
+ selectedLines: string[],
+ tries: number,
+ skipped: number,
+ analysisFile: string
+) {
+ const formatPerK = (arr: number[]) =>
+ Array.from({ length: tries }, (_, i) => {
+ const v = arr[i] ?? 0;
+ return `${i + 1}=${v.toFixed(3)}`;
+ }).join(", ");
+
+ console.log("Summary:");
+ console.log(` Provider: ${provider}`);
+ console.log(` Total input cases: ${selectedLines.length}`);
+ console.log(` Tries: ${tries}`);
+ console.log(` Total tasks: ${stats.totalTasks}`);
+ console.log(` Total runs: ${caseResults.length}`);
+ // Conditionally print invalid response shape stats per API type
+ if ((stats.totalByApiType["responses"] ?? 0) > 0) {
+ const bad = stats.invalidByApiType["responses"] ?? 0;
+ const tot = stats.totalByApiType["responses"] ?? 0;
+ console.log(` Invalid Responses API responses: ${bad} (out of ${tot})`);
+ }
+ if ((stats.totalByApiType["chat"] ?? 0) > 0) {
+ const bad = stats.invalidByApiType["chat"] ?? 0;
+ const tot = stats.totalByApiType["chat"] ?? 0;
+ console.log(
+ ` Invalid Chat Completions API responses: ${bad} (out of ${tot})`
+ );
+ }
+ console.log(` pass@k (k=1..${tries}): ${formatPerK(stats.passAtKByK)}`);
+ console.log(` pass^k (k=1..${tries}): ${formatPerK(stats.passHatKByK)}`);
+ console.log(` pass@k (k=${tries}): ${stats.passAtK.toFixed(3)}`);
+ console.log(` pass^k (k=${tries}): ${stats.passHatK.toFixed(3)}`);
+ console.log(` Wrong-input tool calls: ${stats.wrongInputToolCalls}`);
+ console.log(` Invalid cases.jsonl lines: ${skipped}`);
+ console.log(` Analysis written to ${analysisFile}`);
+}
diff --git a/compatibility-test/cases.jsonl b/compatibility-test/cases.jsonl
new file mode 100644
index 00000000..29e7d4e8
--- /dev/null
+++ b/compatibility-test/cases.jsonl
@@ -0,0 +1,30 @@
+{"tool_name":"get_system_health","input":"Hey, quick check: is everything up and running?","expected_arguments":"{}"}
+{"tool_name":"get_system_health","input":"Status report please.","expected_arguments":"{}"}
+{"tool_name":"get_system_health","input":"Can you confirm the LLM health before we start?","expected_arguments":"{}"}
+{"tool_name":"get_system_health","input":"Need a health snapshot.","expected_arguments":"{}"}
+{"tool_name":"get_system_health","input":"Hi, what's the current system health?","expected_arguments":"{}"}
+{"tool_name":"markdown_to_html","input":"Convert this markdown to HTML:\n\n# Title\n\nSome *italic* text.","expected_arguments":"{\"markdown\":\"# Title\\n\\nSome *italic* text.\"}"}
+{"tool_name":"markdown_to_html","input":"Hey, could you turn `## Docs` into HTML?","expected_arguments":"{\"markdown\":\"## Docs\"}"}
+{"tool_name":"markdown_to_html","input":"Please render the following markdown:\n\n- item 1\n- item 2","expected_arguments":"{\"markdown\":\"- item 1\\n- item 2\"}"}
+{"tool_name":"markdown_to_html","input":"I have `**bold**` markdown; give me HTML.","expected_arguments":"{\"markdown\":\"**bold**\"}"}
+{"tool_name":"markdown_to_html","input":"Markdown to HTML: > quote","expected_arguments":"{\"markdown\":\"> quote\"}"}
+{"tool_name":"detect_language","input":"Hey, what language is this: 'Buenos días, ¿cómo estás?'","expected_arguments":"{\"text\":\"Buenos días, ¿cómo estás?\"}"}
+{"tool_name":"detect_language","input":"Identify the language: \"Guten Morgen\"","expected_arguments":"{\"text\":\"Guten Morgen\"}"}
+{"tool_name":"detect_language","input":"Language detection needed: こんにちは、お元気ですか?","expected_arguments":"{\"text\":\"こんにちは、お元気ですか?\"}"}
+{"tool_name":"detect_language","input":"Detect language for: 'Привет, как дела?'","expected_arguments":"{\"text\":\"Привет, как дела?\"}"}
+{"tool_name":"detect_language","input":"What language is 'Bonjour tout le monde'?","expected_arguments":"{\"text\":\"Bonjour tout le monde\"}"}
+{"tool_name":"generate_chart","input":"Plot a simple line chart for these points: (1,2),(2,4),(3,9).","expected_arguments":"{\"data\":[[1,2],[2,4],[3,9]],\"chart_type\":\"line\"}"}
+{"tool_name":"generate_chart","input":"Hey, can I get a bar chart of my sales: 10, 20, 30 across Q1–Q3?","expected_arguments":"{\"data\":[[1,10],[2,20],[3,30]],\"chart_type\":\"bar\",\"title\":\"Quarterly Sales\"}"}
+{"tool_name":"generate_chart","input":"Make a scatter chart titled 'Experiment' with x label Time and y label Value for data [ [0,1], [1,1.5], [2,2.2] ].","expected_arguments":"{\"data\":[[0,1],[1,1.5],[2,2.2]],\"chart_type\":\"scatter\",\"title\":\"Experiment\",\"x_label\":\"Time\",\"y_label\":\"Value\"}"}
+{"tool_name":"generate_chart","input":"Create a line chart of temperatures 70,72,68,65 over 4 days, label x as 'Day'.","expected_arguments":"{\"data\":[[1,70],[2,72],[3,68],[4,65]],\"chart_type\":\"line\",\"x_label\":\"Day\"}"}
+{"tool_name":"generate_chart","input":"Visualize visits per day with a bar chart; numbers: 100,150,120.","expected_arguments":"{\"data\":[[1,100],[2,150],[3,120]],\"chart_type\":\"bar\",\"title\":\"Daily Visits\",\"y_label\":\"Visitors\"}"}
+{"tool_name":"query_database","input":"Give me the ids and emails from users table, limit 5.","expected_arguments":"{\"table\":\"users\",\"columns\":[\"id\",\"email\"],\"limit\":5}"}
+{"tool_name":"query_database","input":"Hey, fetch order_id and amount from orders where status is 'shipped'.","expected_arguments":"{\"table\":\"orders\",\"columns\":[\"order_id\",\"amount\"],\"filters\":\"status = 'shipped'\"}"}
+{"tool_name":"query_database","input":"Retrieve name and price from products ordered by price descending, top 10 please.","expected_arguments":"{\"table\":\"products\",\"columns\":[\"name\",\"price\"],\"limit\":10,\"order_by\":\"price DESC\"}"}
+{"tool_name":"query_database","input":"I need the first 3 log entries from audit_log table.","expected_arguments":"{\"table\":\"audit_log\",\"columns\":[\"id\",\"timestamp\",\"action\"],\"limit\":3}"}
+{"tool_name":"query_database","input":"Query the customers table for name, city where city = 'Berlin'.","expected_arguments":"{\"table\":\"customers\",\"columns\":[\"name\",\"city\"],\"filters\":\"city = 'Berlin'\"}"}
+{"tool_name":"get_weather","input":"What's the weather in San Francisco right now?","expected_arguments":"{\"location\":\"San Francisco\"}"}
+{"tool_name":"get_weather","input":"Weather for Tokyo, please.","expected_arguments":"{\"location\":\"Tokyo\"}"}
+{"tool_name":"get_weather","input":"Get me the current weather for 10001.","expected_arguments":"{\"location\":\"10001\"}"}
+{"tool_name":"get_weather","input":"How's the weather in Paris today?","expected_arguments":"{\"location\":\"Paris\"}"}
+{"tool_name":"get_weather","input":"Check the weather for Sydney.","expected_arguments":"{\"location\":\"Sydney\"}"}
diff --git a/compatibility-test/index.ts b/compatibility-test/index.ts
new file mode 100644
index 00000000..ca6b03dc
--- /dev/null
+++ b/compatibility-test/index.ts
@@ -0,0 +1,196 @@
+import { parseArgs } from "node:util";
+import { createWriteStream } from "node:fs";
+import { readFile, writeFile } from "node:fs/promises";
+import path from "node:path";
+import process from "node:process";
+import { runCase, RunCaseSummary } from "./runCase";
+import { Listr, ListrTaskWrapper } from "listr2";
+import { analyze, printAnalysis } from "./analysis";
+
+function formatTimestamp(d: Date): string {
+ const pad = (n: number) => String(n).padStart(2, "0");
+ const yyyy = d.getFullYear();
+ const mm = pad(d.getMonth() + 1);
+ const dd = pad(d.getDate());
+ const hh = pad(d.getHours());
+ const mi = pad(d.getMinutes());
+ const ss = pad(d.getSeconds());
+ return `${yyyy}${mm}${dd}_${hh}${mi}${ss}`;
+}
+
+async function main() {
+ const args = parseArgs({
+ options: {
+ cases: { type: "string", short: "c", default: "cases.jsonl" },
+ provider: { type: "string", short: "p", default: "openai" },
+ streaming: { type: "boolean", short: "s", default: false },
+ maxTurns: { type: "string", short: "t", default: "10" },
+ n: { type: "string", short: "n" },
+ strict: { type: "boolean", short: "s", default: false },
+ tries: { type: "string", short: "k", default: "1" },
+ },
+ });
+ const casesPathArg = args.values.cases;
+ const provider = args.values.provider as string;
+ const streaming = Boolean(args.values.streaming);
+ const maxTurns = Number(args.values.maxTurns ?? 10);
+ const nRaw = args.values.n as string | undefined;
+ const triesRaw = args.values.tries as string | undefined;
+ const tries = triesRaw != null ? Number(triesRaw) : 1;
+ const limit = nRaw != null ? Number(nRaw) : undefined;
+ if (limit != null && (!Number.isFinite(limit) || limit <= 0)) {
+ console.error("--n must be a positive integer");
+ process.exitCode = 1;
+ return;
+ }
+
+ if (!casesPathArg) {
+ console.error("--cases is required (path to JSONL file)");
+ process.exitCode = 1;
+ return;
+ }
+
+ const casesPath = path.isAbsolute(casesPathArg)
+ ? casesPathArg
+ : path.join(process.cwd(), casesPathArg);
+
+ const timestamp = formatTimestamp(new Date());
+ const defaultFilename = `rollout_${provider}_${timestamp}.jsonl`;
+ const outputFile = path.join(process.cwd(), defaultFilename);
+ const analysisFile = path.join(
+ process.cwd(),
+ `analysis_${provider}_${timestamp}.json`
+ );
+
+ let fileContent: string;
+ try {
+ fileContent = await readFile(casesPath, "utf8");
+ } catch (err: any) {
+ console.error(
+ `Failed to read cases file at ${casesPath}: ${err?.message ?? err}`
+ );
+ process.exitCode = 1;
+ return;
+ }
+
+ const lines = fileContent
+ .split(/\r?\n/)
+ .map((l) => l.trim())
+ .filter((l) => l.length > 0);
+
+ const selectedLines =
+ typeof limit === "number" ? lines.slice(0, limit) : lines;
+
+ const out = createWriteStream(outputFile, { flags: "w", encoding: "utf8" });
+
+ const writeLine = (obj: any) =>
+ new Promise((resolve, reject) => {
+ const str = JSON.stringify(obj) + "\n";
+ out.write(str, (err) => (err ? reject(err) : resolve()));
+ });
+
+ // Accumulators for post-run analysis
+ let skipped = 0; // invalid JSON lines
+ const caseResults: Array<{
+ run_id: string;
+ success: boolean;
+ provider: string;
+ test_case: number;
+ tool_name: string;
+ input: string;
+ result: RunCaseSummary;
+ }> = [];
+
+ async function processIndex(
+ i: number,
+ k: number,
+ task: ListrTaskWrapper
+ ) {
+ const line = selectedLines[i];
+ let caseObj: any;
+ try {
+ caseObj = JSON.parse(line);
+ } catch (err: any) {
+ console.error(
+ `Skipping invalid JSON on line ${i + 1}: ${err?.message ?? err}`
+ );
+ skipped++;
+ return;
+ }
+
+ try {
+ const summaries = await runCase(provider, caseObj, {
+ maxTurns,
+ streaming,
+ strict: args.values.strict,
+ });
+
+ for (const summary of summaries) {
+ const record = {
+ run_id: `${i}_${k}`,
+ success: summary.success,
+ provider,
+ test_case: i,
+ tool_name: caseObj.tool_name,
+ input: caseObj.input,
+ result: summary,
+ };
+ task.output = `Case ${i} (attempt ${k + 1}): ${
+ summary.success ? "Success" : "Failed"
+ } ${summary.toolCallingDetails.warning || ""}`;
+ caseResults.push(record);
+ await writeLine(record);
+ }
+ } catch (err: any) {
+ const record = {
+ provider,
+ test_case: i,
+ tool_name: caseObj?.tool_name,
+ input: caseObj?.input,
+ expected_output: caseObj?.expected_output,
+ instructions: caseObj?.instructions,
+ error: String(err?.message ?? err),
+ };
+ await writeLine(record);
+ task.output = `Case ${i} failed: ${err?.message ?? err}`;
+ }
+ }
+
+ const listr = new Listr<{
+ output: string;
+ }>(
+ selectedLines.flatMap((line, index) => {
+ return Array.from({ length: tries }, (_, attempt) => ({
+ title: `Processing case ${index} (attempt ${attempt + 1})`,
+ task: async (_, task) => {
+ await processIndex(index, attempt, task);
+ },
+ rendererOptions: { persistentOutput: true },
+ }));
+ }),
+ {
+ concurrent: 5,
+ }
+ );
+
+ await listr.run();
+
+ await new Promise((resolve) => out.end(resolve));
+ console.log(`Results written to ${outputFile}`);
+ const stats = analyze(caseResults, tries);
+ await writeFile(analysisFile, JSON.stringify(stats, null, 2), "utf8");
+ printAnalysis(
+ stats,
+ caseResults,
+ provider,
+ selectedLines,
+ tries,
+ skipped,
+ analysisFile
+ );
+}
+
+main().catch((err) => {
+ console.error(err);
+ process.exitCode = 1;
+});
diff --git a/compatibility-test/package-lock.json b/compatibility-test/package-lock.json
new file mode 100644
index 00000000..89b6a5e8
--- /dev/null
+++ b/compatibility-test/package-lock.json
@@ -0,0 +1,1633 @@
+{
+ "name": "compatibility-test",
+ "lockfileVersion": 3,
+ "requires": true,
+ "packages": {
+ "": {
+ "dependencies": {
+ "@openai/agents": "^0.0.15",
+ "ajv": "^8.17.1",
+ "listr2": "^9.0.1"
+ }
+ },
+ "node_modules/@modelcontextprotocol/sdk": {
+ "version": "1.17.1",
+ "resolved": "https://registry.npmjs.org/@modelcontextprotocol/sdk/-/sdk-1.17.1.tgz",
+ "integrity": "sha512-CPle1OQehbWqd25La9Ack5B07StKIxh4+Bf19qnpZKJC1oI22Y0czZHbifjw1UoczIfKBwBDAp/dFxvHG13B5A==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "ajv": "^6.12.6",
+ "content-type": "^1.0.5",
+ "cors": "^2.8.5",
+ "cross-spawn": "^7.0.5",
+ "eventsource": "^3.0.2",
+ "eventsource-parser": "^3.0.0",
+ "express": "^5.0.1",
+ "express-rate-limit": "^7.5.0",
+ "pkce-challenge": "^5.0.0",
+ "raw-body": "^3.0.0",
+ "zod": "^3.23.8",
+ "zod-to-json-schema": "^3.24.1"
+ },
+ "engines": {
+ "node": ">=18"
+ }
+ },
+ "node_modules/@modelcontextprotocol/sdk/node_modules/ajv": {
+ "version": "6.12.6",
+ "resolved": "https://registry.npmjs.org/ajv/-/ajv-6.12.6.tgz",
+ "integrity": "sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "fast-deep-equal": "^3.1.1",
+ "fast-json-stable-stringify": "^2.0.0",
+ "json-schema-traverse": "^0.4.1",
+ "uri-js": "^4.2.2"
+ },
+ "funding": {
+ "type": "github",
+ "url": "https://github.com/sponsors/epoberezkin"
+ }
+ },
+ "node_modules/@modelcontextprotocol/sdk/node_modules/json-schema-traverse": {
+ "version": "0.4.1",
+ "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-0.4.1.tgz",
+ "integrity": "sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/@openai/agents": {
+ "version": "0.0.15",
+ "resolved": "https://registry.npmjs.org/@openai/agents/-/agents-0.0.15.tgz",
+ "integrity": "sha512-B8y+WyWOeHowflPx09pyCfcqikC4OYWK27HTyNGt1oraXv93CzuamSr76iAaU1nWQ1MPbUwl6LHPX4BPUikVkQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@openai/agents-core": "0.0.15",
+ "@openai/agents-openai": "0.0.15",
+ "@openai/agents-realtime": "0.0.15",
+ "debug": "^4.4.0",
+ "openai": "^5.10.1"
+ }
+ },
+ "node_modules/@openai/agents-core": {
+ "version": "0.0.15",
+ "resolved": "https://registry.npmjs.org/@openai/agents-core/-/agents-core-0.0.15.tgz",
+ "integrity": "sha512-ODTqttjW0s0ejBe5PKnYRlFbJSZH2IO6OtUlRhIKmWiWrX6pGRxvpKjTSOXy8DEtpRHBj6Nhky0UoSlO6eOkDQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@openai/zod": "npm:zod@3.25.40 - 3.25.67",
+ "debug": "^4.4.0",
+ "openai": "^5.10.1"
+ },
+ "optionalDependencies": {
+ "@modelcontextprotocol/sdk": "^1.12.0"
+ },
+ "peerDependencies": {
+ "zod": "3.25.40 - 3.25.67"
+ },
+ "peerDependenciesMeta": {
+ "zod": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/@openai/agents-openai": {
+ "version": "0.0.15",
+ "resolved": "https://registry.npmjs.org/@openai/agents-openai/-/agents-openai-0.0.15.tgz",
+ "integrity": "sha512-YIX3n98HdmmWKkb/71OB+DCQUYyGEpqfzPjejzdtNLUvAEs3jvXf7nkC8oTISsuCwrirgBz0rQEefeo0oUlyFQ==",
+ "license": "MIT",
+ "dependencies": {
+ "@openai/agents-core": "0.0.15",
+ "@openai/zod": "npm:zod@3.25.40 - 3.25.67",
+ "debug": "^4.4.0",
+ "openai": "^5.10.1"
+ }
+ },
+ "node_modules/@openai/agents-realtime": {
+ "version": "0.0.15",
+ "resolved": "https://registry.npmjs.org/@openai/agents-realtime/-/agents-realtime-0.0.15.tgz",
+ "integrity": "sha512-kSZzMyij9Xt3BpMb/9snuVnu7a5qKZLyhtN/kWMA+wmfETvWz23BBz6tbO5xOmurAt9//OktkB+94e0T0RBtlA==",
+ "license": "MIT",
+ "dependencies": {
+ "@openai/agents-core": "0.0.15",
+ "@openai/zod": "npm:zod@3.25.40 - 3.25.67",
+ "@types/ws": "^8.18.1",
+ "debug": "^4.4.0",
+ "ws": "^8.18.1"
+ }
+ },
+ "node_modules/@openai/zod": {
+ "name": "zod",
+ "version": "3.25.67",
+ "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz",
+ "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==",
+ "license": "MIT",
+ "funding": {
+ "url": "https://github.com/sponsors/colinhacks"
+ }
+ },
+ "node_modules/@types/node": {
+ "version": "24.2.0",
+ "resolved": "https://registry.npmjs.org/@types/node/-/node-24.2.0.tgz",
+ "integrity": "sha512-3xyG3pMCq3oYCNg7/ZP+E1ooTaGB4cG8JWRsqqOYQdbWNY4zbaV0Ennrd7stjiJEFZCaybcIgpTjJWHRfBSIDw==",
+ "license": "MIT",
+ "dependencies": {
+ "undici-types": "~7.10.0"
+ }
+ },
+ "node_modules/@types/ws": {
+ "version": "8.18.1",
+ "resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
+ "integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==",
+ "license": "MIT",
+ "dependencies": {
+ "@types/node": "*"
+ }
+ },
+ "node_modules/accepts": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/accepts/-/accepts-2.0.0.tgz",
+ "integrity": "sha512-5cvg6CtKwfgdmVqY1WIiXKc3Q1bkRqGLi+2W/6ao+6Y7gu/RCwRuAhGEzh5B4KlszSuTLgZYuqFqo5bImjNKng==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "mime-types": "^3.0.0",
+ "negotiator": "^1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/ajv": {
+ "version": "8.17.1",
+ "resolved": "https://registry.npmjs.org/ajv/-/ajv-8.17.1.tgz",
+ "integrity": "sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==",
+ "license": "MIT",
+ "dependencies": {
+ "fast-deep-equal": "^3.1.3",
+ "fast-uri": "^3.0.1",
+ "json-schema-traverse": "^1.0.0",
+ "require-from-string": "^2.0.2"
+ },
+ "funding": {
+ "type": "github",
+ "url": "https://github.com/sponsors/epoberezkin"
+ }
+ },
+ "node_modules/ansi-escapes": {
+ "version": "7.0.0",
+ "resolved": "https://registry.npmjs.org/ansi-escapes/-/ansi-escapes-7.0.0.tgz",
+ "integrity": "sha512-GdYO7a61mR0fOlAsvC9/rIHf7L96sBc6dEWzeOu+KAea5bZyQRPIpojrVoI4AXGJS/ycu/fBTdLrUkA4ODrvjw==",
+ "license": "MIT",
+ "dependencies": {
+ "environment": "^1.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/ansi-regex": {
+ "version": "6.1.0",
+ "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz",
+ "integrity": "sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-regex?sponsor=1"
+ }
+ },
+ "node_modules/ansi-styles": {
+ "version": "6.2.1",
+ "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.1.tgz",
+ "integrity": "sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/ansi-styles?sponsor=1"
+ }
+ },
+ "node_modules/body-parser": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/body-parser/-/body-parser-2.2.0.tgz",
+ "integrity": "sha512-02qvAaxv8tp7fBa/mw1ga98OGm+eCbqzJOKoRt70sLmfEEi+jyBYVTDGfCL/k06/4EMk/z01gCe7HoCH/f2LTg==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "bytes": "^3.1.2",
+ "content-type": "^1.0.5",
+ "debug": "^4.4.0",
+ "http-errors": "^2.0.0",
+ "iconv-lite": "^0.6.3",
+ "on-finished": "^2.4.1",
+ "qs": "^6.14.0",
+ "raw-body": "^3.0.0",
+ "type-is": "^2.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ }
+ },
+ "node_modules/bytes": {
+ "version": "3.1.2",
+ "resolved": "https://registry.npmjs.org/bytes/-/bytes-3.1.2.tgz",
+ "integrity": "sha512-/Nf7TyzTx6S3yRJObOAV7956r8cr2+Oj8AC5dt8wSP3BQAoeX58NoHyCU8P8zGkNXStjTSi6fzO6F0pBdcYbEg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/call-bind-apply-helpers": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
+ "integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "es-errors": "^1.3.0",
+ "function-bind": "^1.1.2"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/call-bound": {
+ "version": "1.0.4",
+ "resolved": "https://registry.npmjs.org/call-bound/-/call-bound-1.0.4.tgz",
+ "integrity": "sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "call-bind-apply-helpers": "^1.0.2",
+ "get-intrinsic": "^1.3.0"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/cli-cursor": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/cli-cursor/-/cli-cursor-5.0.0.tgz",
+ "integrity": "sha512-aCj4O5wKyszjMmDT4tZj93kxyydN/K5zPWSCe6/0AV/AA1pqe5ZBIw0a2ZfPQV7lL5/yb5HsUreJ6UFAF1tEQw==",
+ "license": "MIT",
+ "dependencies": {
+ "restore-cursor": "^5.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/cli-truncate": {
+ "version": "4.0.0",
+ "resolved": "https://registry.npmjs.org/cli-truncate/-/cli-truncate-4.0.0.tgz",
+ "integrity": "sha512-nPdaFdQ0h/GEigbPClz11D0v/ZJEwxmeVZGeMo3Z5StPtUTkA9o1lD6QwoirYiSDzbcwn2XcjwmCp68W1IS4TA==",
+ "license": "MIT",
+ "dependencies": {
+ "slice-ansi": "^5.0.0",
+ "string-width": "^7.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/colorette": {
+ "version": "2.0.20",
+ "resolved": "https://registry.npmjs.org/colorette/-/colorette-2.0.20.tgz",
+ "integrity": "sha512-IfEDxwoWIjkeXL1eXcDiow4UbKjhLdq6/EuSVR9GMN7KVH3r9gQ83e73hsz1Nd1T3ijd5xv1wcWRYO+D6kCI2w==",
+ "license": "MIT"
+ },
+ "node_modules/content-disposition": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/content-disposition/-/content-disposition-1.0.0.tgz",
+ "integrity": "sha512-Au9nRL8VNUut/XSzbQA38+M78dzP4D+eqg3gfJHMIHHYa3bg067xj1KxMUWj+VULbiZMowKngFFbKczUrNJ1mg==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "safe-buffer": "5.2.1"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/content-type": {
+ "version": "1.0.5",
+ "resolved": "https://registry.npmjs.org/content-type/-/content-type-1.0.5.tgz",
+ "integrity": "sha512-nTjqfcBFEipKdXCv4YDQWCfmcLZKm81ldF0pAopTvyrFGVbcR6P/VAAd5G7N+0tTr8QqiU0tFadD6FK4NtJwOA==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/cookie": {
+ "version": "0.7.2",
+ "resolved": "https://registry.npmjs.org/cookie/-/cookie-0.7.2.tgz",
+ "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/cookie-signature": {
+ "version": "1.2.2",
+ "resolved": "https://registry.npmjs.org/cookie-signature/-/cookie-signature-1.2.2.tgz",
+ "integrity": "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=6.6.0"
+ }
+ },
+ "node_modules/cors": {
+ "version": "2.8.5",
+ "resolved": "https://registry.npmjs.org/cors/-/cors-2.8.5.tgz",
+ "integrity": "sha512-KIHbLJqu73RGr/hnbrO9uBeixNGuvSQjul/jdFvS/KFSIH1hWVd1ng7zOHx+YrEfInLG7q4n6GHQ9cDtxv/P6g==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "object-assign": "^4",
+ "vary": "^1"
+ },
+ "engines": {
+ "node": ">= 0.10"
+ }
+ },
+ "node_modules/cross-spawn": {
+ "version": "7.0.6",
+ "resolved": "https://registry.npmjs.org/cross-spawn/-/cross-spawn-7.0.6.tgz",
+ "integrity": "sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "path-key": "^3.1.0",
+ "shebang-command": "^2.0.0",
+ "which": "^2.0.1"
+ },
+ "engines": {
+ "node": ">= 8"
+ }
+ },
+ "node_modules/debug": {
+ "version": "4.4.1",
+ "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz",
+ "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==",
+ "license": "MIT",
+ "dependencies": {
+ "ms": "^2.1.3"
+ },
+ "engines": {
+ "node": ">=6.0"
+ },
+ "peerDependenciesMeta": {
+ "supports-color": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/depd": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/depd/-/depd-2.0.0.tgz",
+ "integrity": "sha512-g7nH6P6dyDioJogAAGprGpCtVImJhpPk/roCzdb3fIh61/s/nPsfR6onyMwkCAR/OlC3yBC0lESvUoQEAssIrw==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/dunder-proto": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
+ "integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "call-bind-apply-helpers": "^1.0.1",
+ "es-errors": "^1.3.0",
+ "gopd": "^1.2.0"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/ee-first": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/ee-first/-/ee-first-1.1.1.tgz",
+ "integrity": "sha512-WMwm9LhRUo+WUaRN+vRuETqG89IgZphVSNkdFgeb6sS/E4OrDIN7t48CAewSHXc6C8lefD8KKfr5vY61brQlow==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/emoji-regex": {
+ "version": "10.4.0",
+ "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-10.4.0.tgz",
+ "integrity": "sha512-EC+0oUMY1Rqm4O6LLrgjtYDvcVYTy7chDnM4Q7030tP4Kwj3u/pR6gP9ygnp2CJMK5Gq+9Q2oqmrFJAz01DXjw==",
+ "license": "MIT"
+ },
+ "node_modules/encodeurl": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/encodeurl/-/encodeurl-2.0.0.tgz",
+ "integrity": "sha512-Q0n9HRi4m6JuGIV1eFlmvJB7ZEVxu93IrMyiMsGC0lrMJMWzRgx6WGquyfQgZVb31vhGgXnfmPNNXmxnOkRBrg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/environment": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/environment/-/environment-1.1.0.tgz",
+ "integrity": "sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/es-define-property": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
+ "integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/es-errors": {
+ "version": "1.3.0",
+ "resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
+ "integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/es-object-atoms": {
+ "version": "1.1.1",
+ "resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
+ "integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "es-errors": "^1.3.0"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/escape-html": {
+ "version": "1.0.3",
+ "resolved": "https://registry.npmjs.org/escape-html/-/escape-html-1.0.3.tgz",
+ "integrity": "sha512-NiSupZ4OeuGwr68lGIeym/ksIZMJodUGOSCZ/FSnTxcrekbvqrgdUxlJOMpijaKZVjAJrWrGs/6Jy8OMuyj9ow==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/etag": {
+ "version": "1.8.1",
+ "resolved": "https://registry.npmjs.org/etag/-/etag-1.8.1.tgz",
+ "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/eventemitter3": {
+ "version": "5.0.1",
+ "resolved": "https://registry.npmjs.org/eventemitter3/-/eventemitter3-5.0.1.tgz",
+ "integrity": "sha512-GWkBvjiSZK87ELrYOSESUYeVIc9mvLLf/nXalMOS5dYrgZq9o5OVkbZAVM06CVxYsCwH9BDZFPlQTlPA1j4ahA==",
+ "license": "MIT"
+ },
+ "node_modules/eventsource": {
+ "version": "3.0.7",
+ "resolved": "https://registry.npmjs.org/eventsource/-/eventsource-3.0.7.tgz",
+ "integrity": "sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "eventsource-parser": "^3.0.1"
+ },
+ "engines": {
+ "node": ">=18.0.0"
+ }
+ },
+ "node_modules/eventsource-parser": {
+ "version": "3.0.3",
+ "resolved": "https://registry.npmjs.org/eventsource-parser/-/eventsource-parser-3.0.3.tgz",
+ "integrity": "sha512-nVpZkTMM9rF6AQ9gPJpFsNAMt48wIzB5TQgiTLdHiuO8XEDhUgZEhqKlZWXbIzo9VmJ/HvysHqEaVeD5v9TPvA==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=20.0.0"
+ }
+ },
+ "node_modules/express": {
+ "version": "5.1.0",
+ "resolved": "https://registry.npmjs.org/express/-/express-5.1.0.tgz",
+ "integrity": "sha512-DT9ck5YIRU+8GYzzU5kT3eHGA5iL+1Zd0EutOmTE9Dtk+Tvuzd23VBU+ec7HPNSTxXYO55gPV/hq4pSBJDjFpA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "accepts": "^2.0.0",
+ "body-parser": "^2.2.0",
+ "content-disposition": "^1.0.0",
+ "content-type": "^1.0.5",
+ "cookie": "^0.7.1",
+ "cookie-signature": "^1.2.1",
+ "debug": "^4.4.0",
+ "encodeurl": "^2.0.0",
+ "escape-html": "^1.0.3",
+ "etag": "^1.8.1",
+ "finalhandler": "^2.1.0",
+ "fresh": "^2.0.0",
+ "http-errors": "^2.0.0",
+ "merge-descriptors": "^2.0.0",
+ "mime-types": "^3.0.0",
+ "on-finished": "^2.4.1",
+ "once": "^1.4.0",
+ "parseurl": "^1.3.3",
+ "proxy-addr": "^2.0.7",
+ "qs": "^6.14.0",
+ "range-parser": "^1.2.1",
+ "router": "^2.2.0",
+ "send": "^1.1.0",
+ "serve-static": "^2.2.0",
+ "statuses": "^2.0.1",
+ "type-is": "^2.0.1",
+ "vary": "^1.1.2"
+ },
+ "engines": {
+ "node": ">= 18"
+ },
+ "funding": {
+ "type": "opencollective",
+ "url": "https://opencollective.com/express"
+ }
+ },
+ "node_modules/express-rate-limit": {
+ "version": "7.5.1",
+ "resolved": "https://registry.npmjs.org/express-rate-limit/-/express-rate-limit-7.5.1.tgz",
+ "integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 16"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/express-rate-limit"
+ },
+ "peerDependencies": {
+ "express": ">= 4.11"
+ }
+ },
+ "node_modules/fast-deep-equal": {
+ "version": "3.1.3",
+ "resolved": "https://registry.npmjs.org/fast-deep-equal/-/fast-deep-equal-3.1.3.tgz",
+ "integrity": "sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==",
+ "license": "MIT"
+ },
+ "node_modules/fast-json-stable-stringify": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz",
+ "integrity": "sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/fast-uri": {
+ "version": "3.0.6",
+ "resolved": "https://registry.npmjs.org/fast-uri/-/fast-uri-3.0.6.tgz",
+ "integrity": "sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==",
+ "funding": [
+ {
+ "type": "github",
+ "url": "https://github.com/sponsors/fastify"
+ },
+ {
+ "type": "opencollective",
+ "url": "https://opencollective.com/fastify"
+ }
+ ],
+ "license": "BSD-3-Clause"
+ },
+ "node_modules/finalhandler": {
+ "version": "2.1.0",
+ "resolved": "https://registry.npmjs.org/finalhandler/-/finalhandler-2.1.0.tgz",
+ "integrity": "sha512-/t88Ty3d5JWQbWYgaOGCCYfXRwV1+be02WqYYlL6h0lEiUAMPM8o8qKGO01YIkOHzka2up08wvgYD0mDiI+q3Q==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "debug": "^4.4.0",
+ "encodeurl": "^2.0.0",
+ "escape-html": "^1.0.3",
+ "on-finished": "^2.4.1",
+ "parseurl": "^1.3.3",
+ "statuses": "^2.0.1"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/forwarded": {
+ "version": "0.2.0",
+ "resolved": "https://registry.npmjs.org/forwarded/-/forwarded-0.2.0.tgz",
+ "integrity": "sha512-buRG0fpBtRHSTCOASe6hD258tEubFoRLb4ZNA6NxMVHNw2gOcwHo9wyablzMzOA5z9xA9L1KNjk/Nt6MT9aYow==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/fresh": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/fresh/-/fresh-2.0.0.tgz",
+ "integrity": "sha512-Rx/WycZ60HOaqLKAi6cHRKKI7zxWbJ31MhntmtwMoaTeF7XFH9hhBp8vITaMidfljRQ6eYWCKkaTK+ykVJHP2A==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/function-bind": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
+ "integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
+ "license": "MIT",
+ "optional": true,
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/get-east-asian-width": {
+ "version": "1.3.0",
+ "resolved": "https://registry.npmjs.org/get-east-asian-width/-/get-east-asian-width-1.3.0.tgz",
+ "integrity": "sha512-vpeMIQKxczTD/0s2CdEWHcb0eeJe6TFjxb+J5xgX7hScxqrGuyjmv4c1D4A/gelKfyox0gJJwIHF+fLjeaM8kQ==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/get-intrinsic": {
+ "version": "1.3.0",
+ "resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
+ "integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "call-bind-apply-helpers": "^1.0.2",
+ "es-define-property": "^1.0.1",
+ "es-errors": "^1.3.0",
+ "es-object-atoms": "^1.1.1",
+ "function-bind": "^1.1.2",
+ "get-proto": "^1.0.1",
+ "gopd": "^1.2.0",
+ "has-symbols": "^1.1.0",
+ "hasown": "^2.0.2",
+ "math-intrinsics": "^1.1.0"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/get-proto": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
+ "integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "dunder-proto": "^1.0.1",
+ "es-object-atoms": "^1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/gopd": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
+ "integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/has-symbols": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
+ "integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/hasown": {
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
+ "integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "function-bind": "^1.1.2"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/http-errors": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/http-errors/-/http-errors-2.0.0.tgz",
+ "integrity": "sha512-FtwrG/euBzaEjYeRqOgly7G0qviiXoJWnvEH2Z1plBdXgbyjv34pHTSb9zoeHMyDy33+DWy5Wt9Wo+TURtOYSQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "depd": "2.0.0",
+ "inherits": "2.0.4",
+ "setprototypeof": "1.2.0",
+ "statuses": "2.0.1",
+ "toidentifier": "1.0.1"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/http-errors/node_modules/statuses": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.1.tgz",
+ "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/iconv-lite": {
+ "version": "0.6.3",
+ "resolved": "https://registry.npmjs.org/iconv-lite/-/iconv-lite-0.6.3.tgz",
+ "integrity": "sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "safer-buffer": ">= 2.1.2 < 3.0.0"
+ },
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/inherits": {
+ "version": "2.0.4",
+ "resolved": "https://registry.npmjs.org/inherits/-/inherits-2.0.4.tgz",
+ "integrity": "sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==",
+ "license": "ISC",
+ "optional": true
+ },
+ "node_modules/ipaddr.js": {
+ "version": "1.9.1",
+ "resolved": "https://registry.npmjs.org/ipaddr.js/-/ipaddr.js-1.9.1.tgz",
+ "integrity": "sha512-0KI/607xoxSToH7GjN1FfSbLoU0+btTicjsQSWQlh/hZykN8KpmMf7uYwPW3R+akZ6R/w18ZlXSHBYXiYUPO3g==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.10"
+ }
+ },
+ "node_modules/is-fullwidth-code-point": {
+ "version": "4.0.0",
+ "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-4.0.0.tgz",
+ "integrity": "sha512-O4L094N2/dZ7xqVdrXhh9r1KODPJpFms8B5sGdJLPy664AgvXsreZUyCQQNItZRDlYug4xStLjNp/sz3HvBowQ==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/is-promise": {
+ "version": "4.0.0",
+ "resolved": "https://registry.npmjs.org/is-promise/-/is-promise-4.0.0.tgz",
+ "integrity": "sha512-hvpoI6korhJMnej285dSg6nu1+e6uxs7zG3BYAm5byqDsgJNWwxzM6z6iZiAgQR4TJ30JmBTOwqZUw3WlyH3AQ==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/isexe": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz",
+ "integrity": "sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==",
+ "license": "ISC",
+ "optional": true
+ },
+ "node_modules/json-schema-traverse": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/json-schema-traverse/-/json-schema-traverse-1.0.0.tgz",
+ "integrity": "sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==",
+ "license": "MIT"
+ },
+ "node_modules/listr2": {
+ "version": "9.0.1",
+ "resolved": "https://registry.npmjs.org/listr2/-/listr2-9.0.1.tgz",
+ "integrity": "sha512-SL0JY3DaxylDuo/MecFeiC+7pedM0zia33zl0vcjgwcq1q1FWWF1To9EIauPbl8GbMCU0R2e0uJ8bZunhYKD2g==",
+ "license": "MIT",
+ "dependencies": {
+ "cli-truncate": "^4.0.0",
+ "colorette": "^2.0.20",
+ "eventemitter3": "^5.0.1",
+ "log-update": "^6.1.0",
+ "rfdc": "^1.4.1",
+ "wrap-ansi": "^9.0.0"
+ },
+ "engines": {
+ "node": ">=20.0.0"
+ }
+ },
+ "node_modules/log-update": {
+ "version": "6.1.0",
+ "resolved": "https://registry.npmjs.org/log-update/-/log-update-6.1.0.tgz",
+ "integrity": "sha512-9ie8ItPR6tjY5uYJh8K/Zrv/RMZ5VOlOWvtZdEHYSTFKZfIBPQa9tOAEeAWhd+AnIneLJ22w5fjOYtoutpWq5w==",
+ "license": "MIT",
+ "dependencies": {
+ "ansi-escapes": "^7.0.0",
+ "cli-cursor": "^5.0.0",
+ "slice-ansi": "^7.1.0",
+ "strip-ansi": "^7.1.0",
+ "wrap-ansi": "^9.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/log-update/node_modules/is-fullwidth-code-point": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-5.0.0.tgz",
+ "integrity": "sha512-OVa3u9kkBbw7b8Xw5F9P+D/T9X+Z4+JruYVNapTjPYZYUznQ5YfWeFkOj606XYYW8yugTfC8Pj0hYqvi4ryAhA==",
+ "license": "MIT",
+ "dependencies": {
+ "get-east-asian-width": "^1.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/log-update/node_modules/slice-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-7.1.0.tgz",
+ "integrity": "sha512-bSiSngZ/jWeX93BqeIAbImyTbEihizcwNjFoRUIY/T1wWQsfsm2Vw1agPKylXvQTU7iASGdHhyqRlqQzfz+Htg==",
+ "license": "MIT",
+ "dependencies": {
+ "ansi-styles": "^6.2.1",
+ "is-fullwidth-code-point": "^5.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/slice-ansi?sponsor=1"
+ }
+ },
+ "node_modules/math-intrinsics": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
+ "integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ }
+ },
+ "node_modules/media-typer": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/media-typer/-/media-typer-1.1.0.tgz",
+ "integrity": "sha512-aisnrDP4GNe06UcKFnV5bfMNPBUw4jsLGaWwWfnH3v02GnBuXX2MCVn5RbrWo0j3pczUilYblq7fQ7Nw2t5XKw==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/merge-descriptors": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/merge-descriptors/-/merge-descriptors-2.0.0.tgz",
+ "integrity": "sha512-Snk314V5ayFLhp3fkUREub6WtjBfPdCPY1Ln8/8munuLuiYhsABgBVWsozAG+MWMbVEvcdcpbi9R7ww22l9Q3g==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/mime-db": {
+ "version": "1.54.0",
+ "resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.54.0.tgz",
+ "integrity": "sha512-aU5EJuIN2WDemCcAp2vFBfp/m4EAhWJnUNSSw0ixs7/kXbd6Pg64EmwJkNdFhB8aWt1sH2CTXrLxo/iAGV3oPQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/mime-types": {
+ "version": "3.0.1",
+ "resolved": "https://registry.npmjs.org/mime-types/-/mime-types-3.0.1.tgz",
+ "integrity": "sha512-xRc4oEhT6eaBpU1XF7AjpOFD+xQmXNB5OVKwp4tqCuBpHLS/ZbBDrc07mYTDqVMg6PfxUjjNp85O6Cd2Z/5HWA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "mime-db": "^1.54.0"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/mimic-function": {
+ "version": "5.0.1",
+ "resolved": "https://registry.npmjs.org/mimic-function/-/mimic-function-5.0.1.tgz",
+ "integrity": "sha512-VP79XUPxV2CigYP3jWwAUFSku2aKqBH7uTAapFWCBqutsbmDo96KY5o8uh6U+/YSIn5OxJnXp73beVkpqMIGhA==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/ms": {
+ "version": "2.1.3",
+ "resolved": "https://registry.npmjs.org/ms/-/ms-2.1.3.tgz",
+ "integrity": "sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==",
+ "license": "MIT"
+ },
+ "node_modules/negotiator": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/negotiator/-/negotiator-1.0.0.tgz",
+ "integrity": "sha512-8Ofs/AUQh8MaEcrlq5xOX0CQ9ypTF5dl78mjlMNfOK08fzpgTHQRQPBxcPlEtIw0yRpws+Zo/3r+5WRby7u3Gg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/object-assign": {
+ "version": "4.1.1",
+ "resolved": "https://registry.npmjs.org/object-assign/-/object-assign-4.1.1.tgz",
+ "integrity": "sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/object-inspect": {
+ "version": "1.13.4",
+ "resolved": "https://registry.npmjs.org/object-inspect/-/object-inspect-1.13.4.tgz",
+ "integrity": "sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/on-finished": {
+ "version": "2.4.1",
+ "resolved": "https://registry.npmjs.org/on-finished/-/on-finished-2.4.1.tgz",
+ "integrity": "sha512-oVlzkg3ENAhCk2zdv7IJwd/QUD4z2RxRwpkcGY8psCVcCYZNq4wYnVWALHM+brtuJjePWiYF/ClmuDr8Ch5+kg==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "ee-first": "1.1.1"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/once": {
+ "version": "1.4.0",
+ "resolved": "https://registry.npmjs.org/once/-/once-1.4.0.tgz",
+ "integrity": "sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==",
+ "license": "ISC",
+ "optional": true,
+ "dependencies": {
+ "wrappy": "1"
+ }
+ },
+ "node_modules/onetime": {
+ "version": "7.0.0",
+ "resolved": "https://registry.npmjs.org/onetime/-/onetime-7.0.0.tgz",
+ "integrity": "sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ==",
+ "license": "MIT",
+ "dependencies": {
+ "mimic-function": "^5.0.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/openai": {
+ "version": "5.12.0",
+ "resolved": "https://registry.npmjs.org/openai/-/openai-5.12.0.tgz",
+ "integrity": "sha512-vUdt02xiWgOHiYUmW0Hj1Qu9OKAiVQu5Bd547ktVCiMKC1BkB5L3ImeEnCyq3WpRKR6ZTaPgekzqdozwdPs7Lg==",
+ "license": "Apache-2.0",
+ "bin": {
+ "openai": "bin/cli"
+ },
+ "peerDependencies": {
+ "ws": "^8.18.0",
+ "zod": "^3.23.8"
+ },
+ "peerDependenciesMeta": {
+ "ws": {
+ "optional": true
+ },
+ "zod": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/parseurl": {
+ "version": "1.3.3",
+ "resolved": "https://registry.npmjs.org/parseurl/-/parseurl-1.3.3.tgz",
+ "integrity": "sha512-CiyeOxFT/JZyN5m0z9PfXw4SCBJ6Sygz1Dpl0wqjlhDEGGBP1GnsUVEL0p63hoG1fcj3fHynXi9NYO4nWOL+qQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/path-key": {
+ "version": "3.1.1",
+ "resolved": "https://registry.npmjs.org/path-key/-/path-key-3.1.1.tgz",
+ "integrity": "sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/path-to-regexp": {
+ "version": "8.2.0",
+ "resolved": "https://registry.npmjs.org/path-to-regexp/-/path-to-regexp-8.2.0.tgz",
+ "integrity": "sha512-TdrF7fW9Rphjq4RjrW0Kp2AW0Ahwu9sRGTkS6bvDi0SCwZlEZYmcfDbEsTz8RVk0EHIS/Vd1bv3JhG+1xZuAyQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=16"
+ }
+ },
+ "node_modules/pkce-challenge": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/pkce-challenge/-/pkce-challenge-5.0.0.tgz",
+ "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=16.20.0"
+ }
+ },
+ "node_modules/proxy-addr": {
+ "version": "2.0.7",
+ "resolved": "https://registry.npmjs.org/proxy-addr/-/proxy-addr-2.0.7.tgz",
+ "integrity": "sha512-llQsMLSUDUPT44jdrU/O37qlnifitDP+ZwrmmZcoSKyLKvtZxpyV0n2/bD/N4tBAAZ/gJEdZU7KMraoK1+XYAg==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "forwarded": "0.2.0",
+ "ipaddr.js": "1.9.1"
+ },
+ "engines": {
+ "node": ">= 0.10"
+ }
+ },
+ "node_modules/punycode": {
+ "version": "2.3.1",
+ "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz",
+ "integrity": "sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=6"
+ }
+ },
+ "node_modules/qs": {
+ "version": "6.14.0",
+ "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz",
+ "integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==",
+ "license": "BSD-3-Clause",
+ "optional": true,
+ "dependencies": {
+ "side-channel": "^1.1.0"
+ },
+ "engines": {
+ "node": ">=0.6"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/range-parser": {
+ "version": "1.2.1",
+ "resolved": "https://registry.npmjs.org/range-parser/-/range-parser-1.2.1.tgz",
+ "integrity": "sha512-Hrgsx+orqoygnmhFbKaHE6c296J+HTAQXoxEF6gNupROmmGJRoyzfG3ccAveqCBrwr/2yxQ5BVd/GTl5agOwSg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/raw-body": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/raw-body/-/raw-body-3.0.0.tgz",
+ "integrity": "sha512-RmkhL8CAyCRPXCE28MMH0z2PNWQBNk2Q09ZdxM9IOOXwxwZbN+qbWaatPkdkWIKL2ZVDImrN/pK5HTRz2PcS4g==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "bytes": "3.1.2",
+ "http-errors": "2.0.0",
+ "iconv-lite": "0.6.3",
+ "unpipe": "1.0.0"
+ },
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/require-from-string": {
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/require-from-string/-/require-from-string-2.0.2.tgz",
+ "integrity": "sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=0.10.0"
+ }
+ },
+ "node_modules/restore-cursor": {
+ "version": "5.1.0",
+ "resolved": "https://registry.npmjs.org/restore-cursor/-/restore-cursor-5.1.0.tgz",
+ "integrity": "sha512-oMA2dcrw6u0YfxJQXm342bFKX/E4sG9rbTzO9ptUcR/e8A33cHuvStiYOwH7fszkZlZ1z/ta9AAoPk2F4qIOHA==",
+ "license": "MIT",
+ "dependencies": {
+ "onetime": "^7.0.0",
+ "signal-exit": "^4.1.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/rfdc": {
+ "version": "1.4.1",
+ "resolved": "https://registry.npmjs.org/rfdc/-/rfdc-1.4.1.tgz",
+ "integrity": "sha512-q1b3N5QkRUWUl7iyylaaj3kOpIT0N2i9MqIEQXP73GVsN9cw3fdx8X63cEmWhJGi2PPCF23Ijp7ktmd39rawIA==",
+ "license": "MIT"
+ },
+ "node_modules/router": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/router/-/router-2.2.0.tgz",
+ "integrity": "sha512-nLTrUKm2UyiL7rlhapu/Zl45FwNgkZGaCpZbIHajDYgwlJCOzLSk+cIPAnsEqV955GjILJnKbdQC1nVPz+gAYQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "debug": "^4.4.0",
+ "depd": "^2.0.0",
+ "is-promise": "^4.0.0",
+ "parseurl": "^1.3.3",
+ "path-to-regexp": "^8.0.0"
+ },
+ "engines": {
+ "node": ">= 18"
+ }
+ },
+ "node_modules/safe-buffer": {
+ "version": "5.2.1",
+ "resolved": "https://registry.npmjs.org/safe-buffer/-/safe-buffer-5.2.1.tgz",
+ "integrity": "sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==",
+ "funding": [
+ {
+ "type": "github",
+ "url": "https://github.com/sponsors/feross"
+ },
+ {
+ "type": "patreon",
+ "url": "https://www.patreon.com/feross"
+ },
+ {
+ "type": "consulting",
+ "url": "https://feross.org/support"
+ }
+ ],
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/safer-buffer": {
+ "version": "2.1.2",
+ "resolved": "https://registry.npmjs.org/safer-buffer/-/safer-buffer-2.1.2.tgz",
+ "integrity": "sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==",
+ "license": "MIT",
+ "optional": true
+ },
+ "node_modules/send": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/send/-/send-1.2.0.tgz",
+ "integrity": "sha512-uaW0WwXKpL9blXE2o0bRhoL2EGXIrZxQ2ZQ4mgcfoBxdFmQold+qWsD2jLrfZ0trjKL6vOw0j//eAwcALFjKSw==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "debug": "^4.3.5",
+ "encodeurl": "^2.0.0",
+ "escape-html": "^1.0.3",
+ "etag": "^1.8.1",
+ "fresh": "^2.0.0",
+ "http-errors": "^2.0.0",
+ "mime-types": "^3.0.1",
+ "ms": "^2.1.3",
+ "on-finished": "^2.4.1",
+ "range-parser": "^1.2.1",
+ "statuses": "^2.0.1"
+ },
+ "engines": {
+ "node": ">= 18"
+ }
+ },
+ "node_modules/serve-static": {
+ "version": "2.2.0",
+ "resolved": "https://registry.npmjs.org/serve-static/-/serve-static-2.2.0.tgz",
+ "integrity": "sha512-61g9pCh0Vnh7IutZjtLGGpTA355+OPn2TyDv/6ivP2h/AdAVX9azsoxmg2/M6nZeQZNYBEwIcsne1mJd9oQItQ==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "encodeurl": "^2.0.0",
+ "escape-html": "^1.0.3",
+ "parseurl": "^1.3.3",
+ "send": "^1.2.0"
+ },
+ "engines": {
+ "node": ">= 18"
+ }
+ },
+ "node_modules/setprototypeof": {
+ "version": "1.2.0",
+ "resolved": "https://registry.npmjs.org/setprototypeof/-/setprototypeof-1.2.0.tgz",
+ "integrity": "sha512-E5LDX7Wrp85Kil5bhZv46j8jOeboKq5JMmYM3gVGdGH8xFpPWXUMsNrlODCrkoxMEeNi/XZIwuRvY4XNwYMJpw==",
+ "license": "ISC",
+ "optional": true
+ },
+ "node_modules/shebang-command": {
+ "version": "2.0.0",
+ "resolved": "https://registry.npmjs.org/shebang-command/-/shebang-command-2.0.0.tgz",
+ "integrity": "sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "shebang-regex": "^3.0.0"
+ },
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/shebang-regex": {
+ "version": "3.0.0",
+ "resolved": "https://registry.npmjs.org/shebang-regex/-/shebang-regex-3.0.0.tgz",
+ "integrity": "sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=8"
+ }
+ },
+ "node_modules/side-channel": {
+ "version": "1.1.0",
+ "resolved": "https://registry.npmjs.org/side-channel/-/side-channel-1.1.0.tgz",
+ "integrity": "sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "es-errors": "^1.3.0",
+ "object-inspect": "^1.13.3",
+ "side-channel-list": "^1.0.0",
+ "side-channel-map": "^1.0.1",
+ "side-channel-weakmap": "^1.0.2"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/side-channel-list": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/side-channel-list/-/side-channel-list-1.0.0.tgz",
+ "integrity": "sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "es-errors": "^1.3.0",
+ "object-inspect": "^1.13.3"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/side-channel-map": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/side-channel-map/-/side-channel-map-1.0.1.tgz",
+ "integrity": "sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "call-bound": "^1.0.2",
+ "es-errors": "^1.3.0",
+ "get-intrinsic": "^1.2.5",
+ "object-inspect": "^1.13.3"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/side-channel-weakmap": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/side-channel-weakmap/-/side-channel-weakmap-1.0.2.tgz",
+ "integrity": "sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "call-bound": "^1.0.2",
+ "es-errors": "^1.3.0",
+ "get-intrinsic": "^1.2.5",
+ "object-inspect": "^1.13.3",
+ "side-channel-map": "^1.0.1"
+ },
+ "engines": {
+ "node": ">= 0.4"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/ljharb"
+ }
+ },
+ "node_modules/signal-exit": {
+ "version": "4.1.0",
+ "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz",
+ "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==",
+ "license": "ISC",
+ "engines": {
+ "node": ">=14"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/isaacs"
+ }
+ },
+ "node_modules/slice-ansi": {
+ "version": "5.0.0",
+ "resolved": "https://registry.npmjs.org/slice-ansi/-/slice-ansi-5.0.0.tgz",
+ "integrity": "sha512-FC+lgizVPfie0kkhqUScwRu1O/lF6NOgJmlCgK+/LYxDCTk8sGelYaHDhFcDN+Sn3Cv+3VSa4Byeo+IMCzpMgQ==",
+ "license": "MIT",
+ "dependencies": {
+ "ansi-styles": "^6.0.0",
+ "is-fullwidth-code-point": "^4.0.0"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/slice-ansi?sponsor=1"
+ }
+ },
+ "node_modules/statuses": {
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/statuses/-/statuses-2.0.2.tgz",
+ "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/string-width": {
+ "version": "7.2.0",
+ "resolved": "https://registry.npmjs.org/string-width/-/string-width-7.2.0.tgz",
+ "integrity": "sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==",
+ "license": "MIT",
+ "dependencies": {
+ "emoji-regex": "^10.3.0",
+ "get-east-asian-width": "^1.0.0",
+ "strip-ansi": "^7.1.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/sponsors/sindresorhus"
+ }
+ },
+ "node_modules/strip-ansi": {
+ "version": "7.1.0",
+ "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-7.1.0.tgz",
+ "integrity": "sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==",
+ "license": "MIT",
+ "dependencies": {
+ "ansi-regex": "^6.0.1"
+ },
+ "engines": {
+ "node": ">=12"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/strip-ansi?sponsor=1"
+ }
+ },
+ "node_modules/toidentifier": {
+ "version": "1.0.1",
+ "resolved": "https://registry.npmjs.org/toidentifier/-/toidentifier-1.0.1.tgz",
+ "integrity": "sha512-o5sSPKEkg/DIQNmH43V0/uerLrpzVedkUh8tGNvaeXpfpuwjKenlSox/2O/BTlZUtEe+JG7s5YhEz608PlAHRA==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">=0.6"
+ }
+ },
+ "node_modules/type-is": {
+ "version": "2.0.1",
+ "resolved": "https://registry.npmjs.org/type-is/-/type-is-2.0.1.tgz",
+ "integrity": "sha512-OZs6gsjF4vMp32qrCbiVSkrFmXtG/AZhY3t0iAMrMBiAZyV9oALtXO8hsrHbMXF9x6L3grlFuwW2oAz7cav+Gw==",
+ "license": "MIT",
+ "optional": true,
+ "dependencies": {
+ "content-type": "^1.0.5",
+ "media-typer": "^1.1.0",
+ "mime-types": "^3.0.0"
+ },
+ "engines": {
+ "node": ">= 0.6"
+ }
+ },
+ "node_modules/undici-types": {
+ "version": "7.10.0",
+ "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.10.0.tgz",
+ "integrity": "sha512-t5Fy/nfn+14LuOc2KNYg75vZqClpAiqscVvMygNnlsHBFpSXdJaYtXMcdNLpl/Qvc3P2cB3s6lOV51nqsFq4ag==",
+ "license": "MIT"
+ },
+ "node_modules/unpipe": {
+ "version": "1.0.0",
+ "resolved": "https://registry.npmjs.org/unpipe/-/unpipe-1.0.0.tgz",
+ "integrity": "sha512-pjy2bYhSsufwWlKwPc+l3cN7+wuJlK6uz0YdJEOlQDbl6jo/YlPi4mb8agUkVC8BF7V8NuzeyPNqRksA3hztKQ==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/uri-js": {
+ "version": "4.4.1",
+ "resolved": "https://registry.npmjs.org/uri-js/-/uri-js-4.4.1.tgz",
+ "integrity": "sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==",
+ "license": "BSD-2-Clause",
+ "optional": true,
+ "dependencies": {
+ "punycode": "^2.1.0"
+ }
+ },
+ "node_modules/vary": {
+ "version": "1.1.2",
+ "resolved": "https://registry.npmjs.org/vary/-/vary-1.1.2.tgz",
+ "integrity": "sha512-BNGbWLfd0eUPabhkXUVm0j8uuvREyTh5ovRa/dyow/BqAbZJyC+5fU+IzQOzmAKzYqYRAISoRhdQr3eIZ/PXqg==",
+ "license": "MIT",
+ "optional": true,
+ "engines": {
+ "node": ">= 0.8"
+ }
+ },
+ "node_modules/which": {
+ "version": "2.0.2",
+ "resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
+ "integrity": "sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==",
+ "license": "ISC",
+ "optional": true,
+ "dependencies": {
+ "isexe": "^2.0.0"
+ },
+ "bin": {
+ "node-which": "bin/node-which"
+ },
+ "engines": {
+ "node": ">= 8"
+ }
+ },
+ "node_modules/wrap-ansi": {
+ "version": "9.0.0",
+ "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-9.0.0.tgz",
+ "integrity": "sha512-G8ura3S+3Z2G+mkgNRq8dqaFZAuxfsxpBB8OCTGRTCtp+l/v9nbFNmCUP1BZMts3G1142MsZfn6eeUKrr4PD1Q==",
+ "license": "MIT",
+ "dependencies": {
+ "ansi-styles": "^6.2.1",
+ "string-width": "^7.0.0",
+ "strip-ansi": "^7.1.0"
+ },
+ "engines": {
+ "node": ">=18"
+ },
+ "funding": {
+ "url": "https://github.com/chalk/wrap-ansi?sponsor=1"
+ }
+ },
+ "node_modules/wrappy": {
+ "version": "1.0.2",
+ "resolved": "https://registry.npmjs.org/wrappy/-/wrappy-1.0.2.tgz",
+ "integrity": "sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==",
+ "license": "ISC",
+ "optional": true
+ },
+ "node_modules/ws": {
+ "version": "8.18.3",
+ "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz",
+ "integrity": "sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==",
+ "license": "MIT",
+ "engines": {
+ "node": ">=10.0.0"
+ },
+ "peerDependencies": {
+ "bufferutil": "^4.0.1",
+ "utf-8-validate": ">=5.0.2"
+ },
+ "peerDependenciesMeta": {
+ "bufferutil": {
+ "optional": true
+ },
+ "utf-8-validate": {
+ "optional": true
+ }
+ }
+ },
+ "node_modules/zod": {
+ "version": "3.25.67",
+ "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.67.tgz",
+ "integrity": "sha512-idA2YXwpCdqUSKRCACDE6ItZD9TZzy3OZMtpfLoh6oPR47lipysRrJfjzMqFxQ3uJuUPyUeWe1r9vLH33xO/Qw==",
+ "license": "MIT",
+ "optional": true,
+ "funding": {
+ "url": "https://github.com/sponsors/colinhacks"
+ }
+ },
+ "node_modules/zod-to-json-schema": {
+ "version": "3.24.6",
+ "resolved": "https://registry.npmjs.org/zod-to-json-schema/-/zod-to-json-schema-3.24.6.tgz",
+ "integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==",
+ "license": "ISC",
+ "optional": true,
+ "peerDependencies": {
+ "zod": "^3.24.1"
+ }
+ }
+ }
+}
diff --git a/compatibility-test/package.json b/compatibility-test/package.json
new file mode 100644
index 00000000..66d51439
--- /dev/null
+++ b/compatibility-test/package.json
@@ -0,0 +1,11 @@
+{
+ "type": "module",
+ "dependencies": {
+ "@openai/agents": "^0.0.15",
+ "ajv": "^8.17.1",
+ "listr2": "^9.0.1"
+ },
+ "scripts": {
+ "start": "tsx index.ts"
+ }
+}
diff --git a/compatibility-test/providers.ts b/compatibility-test/providers.ts
new file mode 100644
index 00000000..91f58e0f
--- /dev/null
+++ b/compatibility-test/providers.ts
@@ -0,0 +1,15 @@
+export const PROVIDERS = {
+ vllm: {
+ apiBaseUrl: "http://localhost:8000/v1",
+ apiKey: "vllm",
+ apiType: ["responses", "chat"], // choose from responses, chat, or both
+ modelName: "openai/gpt-oss-120b",
+ providerDetails: {
+ // add any provider-specific details here. These will be passed as part of every request
+ // for example to fix the provider for openrouter, you can do:
+ // provider: {
+ // only: ["example"],
+ // },
+ },
+ },
+};
diff --git a/compatibility-test/runCase.ts b/compatibility-test/runCase.ts
new file mode 100644
index 00000000..fd066c0c
--- /dev/null
+++ b/compatibility-test/runCase.ts
@@ -0,0 +1,331 @@
+import {
+ Agent,
+ Runner,
+ OpenAIResponsesModel,
+ OpenAIChatCompletionsModel,
+ RunResult,
+ StreamedRunResult,
+ FunctionTool,
+ setTracingDisabled,
+} from "@openai/agents";
+import { Ajv } from "ajv";
+import { OpenAI } from "openai";
+import { PROVIDERS } from "./providers";
+import { TOOLS_MAP } from "./tools";
+
+setTracingDisabled(true);
+
+const ajv = new Ajv();
+
+export type Case = {
+ tool_name: string;
+ input: string;
+ expected_arguments: string;
+ instructions?: string;
+};
+
+// Summary shape for each apiType
+export type RunCaseSummary = {
+ apiType: string;
+ success: boolean;
+ validResponse: boolean;
+ validEvents?: boolean;
+ details: Record;
+ history: any[];
+ successToolCall: boolean;
+ toolCallingDetails: Record;
+};
+
+export async function runCase(
+ provider: string,
+ caseData: Case,
+ {
+ maxTurns,
+ streaming,
+ strict,
+ }: { maxTurns: number; streaming: boolean; strict: boolean }
+): Promise {
+ const config = PROVIDERS[provider];
+ if (!config) {
+ throw new Error(
+ `Provider ${provider} not found. Valid providers are: ${Object.keys(
+ PROVIDERS
+ ).join(", ")}`
+ );
+ }
+
+ const agent = new Agent({
+ name: caseData.tool_name,
+ instructions: caseData.instructions,
+ tools: [TOOLS_MAP[caseData.tool_name]],
+ });
+
+ const client = new OpenAI({
+ apiKey: config.apiKey,
+ baseURL: config.apiBaseUrl,
+ });
+
+ const summaries: RunCaseSummary[] = [];
+
+ for (const apiType of config.apiType) {
+ const runner = new Runner({
+ model:
+ apiType === "responses"
+ ? new OpenAIResponsesModel(client, config.modelName)
+ : new OpenAIChatCompletionsModel(client, config.modelName),
+ modelSettings: {
+ providerData: config.providerDetails ?? {},
+ },
+ });
+
+ let result: RunResult | StreamedRunResult;
+ let streamedEvents: any[] | undefined = undefined;
+ if (streaming) {
+ result = await runner.run(agent, caseData.input, {
+ stream: streaming,
+ maxTurns: maxTurns,
+ });
+ if (result instanceof StreamedRunResult) {
+ // Collect streaming events if applicable
+ streamedEvents = [];
+ for await (const event of result) {
+ if (event.type === "raw_model_stream_event") {
+ if (event.data.type === "model") {
+ streamedEvents.push(event.data.event);
+ }
+ }
+ }
+ await result.completed;
+ }
+ } else {
+ result = await runner.run(agent, caseData.input, {
+ maxTurns: maxTurns,
+ });
+ }
+
+ const { success: successToolCall, details: toolCallingDetails } =
+ testToolCall(apiType, caseData, result, strict);
+
+ const { validResponse, details } = testOutputData(
+ apiType,
+ result.rawResponses,
+ streaming
+ );
+
+ const { validEvents, details: eventsDetails } = streaming
+ ? testEvents(apiType, streamedEvents)
+ : { validEvents: true, details: {} };
+
+ let success = successToolCall && validResponse;
+ if (streaming) {
+ success = success && validEvents;
+ }
+ const summary: RunCaseSummary = {
+ apiType,
+ success,
+ validResponse,
+ validEvents,
+ details: {
+ ...details,
+ ...eventsDetails,
+ },
+ history: result?.rawResponses.map((entry) => entry.providerData) ?? [],
+ successToolCall,
+ toolCallingDetails,
+ };
+
+ summaries.push(summary);
+ }
+
+ return summaries;
+}
+
+function testToolCall(apiType, caseData, result, strict) {
+ let details: Record = {};
+ result.newItems.forEach((item) => {
+ // for this test for now we only care if the tool is called at least once
+ if (details.calledToolAtLeastOnce) {
+ return;
+ }
+
+ const isToolCall = item.type === "tool_call_item";
+ if (isToolCall) {
+ if (item.rawItem.type === "function_call") {
+ if (item.rawItem.name === caseData.tool_name) {
+ const validate = ajv.compile(
+ (TOOLS_MAP[caseData.tool_name] as FunctionTool).parameters
+ );
+ const valid = validate(JSON.parse(item.rawItem.arguments));
+ details.calledToolWithRightSchema = valid;
+ details.calledToolAtLeastOnce = true;
+
+ if (details.calledToolWithRightSchema) {
+ const parsedArguments = JSON.parse(item.rawItem.arguments);
+ const expectedArguments = JSON.parse(caseData.expected_arguments);
+ details.calledToolWithRightArguments = deepEqual(
+ parsedArguments,
+ expectedArguments
+ );
+ if (!details.calledToolWithRightArguments) {
+ if (details.calledToolWithRightSchema) {
+ details.warning = `Tool call with wrong arguments but correct schema. Check logs for full details. Not failing this test. Parsed: ${JSON.stringify(
+ parsedArguments
+ )} Expected: ${JSON.stringify(expectedArguments)}`;
+ }
+ details.actualArguments = parsedArguments;
+ details.expectedArguments = expectedArguments;
+ }
+ }
+ }
+ }
+ }
+ });
+
+ return {
+ success:
+ !!details.calledToolAtLeastOnce &&
+ !!details.calledToolWithRightSchema &&
+ (!strict || !!details.calledToolWithRightArguments),
+ details,
+ };
+}
+
+function testEvents(apiType, events) {
+ // In an ideal world we would check all the events to follow and reconstruct the final response
+ // and then compare it against the final response in the response.completed event
+ // for now we just check that certain events are present
+
+ let details: Record = {};
+ let validEvents: boolean = false;
+
+ if (apiType === "chat") {
+ let hasReasoningDeltas = false;
+ for (const event of events) {
+ hasReasoningDeltas =
+ hasReasoningDeltas ||
+ (typeof event.choices[0].delta.reasoning === "string" &&
+ event.choices[0].delta.reasoning.length > 0);
+ }
+ details.hasReasoningDeltas = hasReasoningDeltas;
+ validEvents = hasReasoningDeltas;
+ }
+
+ if (apiType === "responses") {
+ let hasReasoningDeltaEvents = false;
+ let hasReasoningDoneEvents = false;
+ for (const event of events) {
+ if (event.type === "raw_model_stream_event") {
+ if (event.data.type === "model") {
+ if (event.data.event.type === "response.reasoning_text.delta") {
+ hasReasoningDeltaEvents = true;
+ }
+ if (event.data.event.type === "response.reasoning_text.done") {
+ hasReasoningDoneEvents = true;
+ }
+ }
+ }
+ }
+
+ details.hasReasoningDeltaEvents = hasReasoningDeltaEvents;
+ details.hasReasoningDoneEvents = hasReasoningDoneEvents;
+ validEvents =
+ details.hasReasoningDeltaEvents && details.hasReasoningDoneEvents;
+ }
+
+ return {
+ validEvents,
+ details,
+ };
+}
+
+function testOutputData(apiType, rawResponses, streaming) {
+ let details: Record = {};
+ let validResponse: boolean = false;
+
+ if (apiType === "chat") {
+ for (const response of rawResponses) {
+ if (streaming && !response.providerData) {
+ // with Chat Completions we don't have a final response object that's native so we skip this test
+ return {
+ validResponse: true,
+ details: {
+ skippedBecauseStreaming: true,
+ },
+ };
+ }
+
+ // this is the actual HTTP response from the provider
+ // Since it's not guaranteed that every response has a reasoning field, we check if it's present
+ // at least once across all responses
+ const data = response.providerData;
+ const message = data.choices[0].message;
+ if (message.role === "assistant" && !message.refusal) {
+ details.hasReasoningField =
+ details.hasReasoningField ||
+ ("reasoning" in message && typeof message.reasoning === "string");
+ details.hasReasoningContentField =
+ details.hasReasoningContentField ||
+ ("reasoning_content" in message &&
+ typeof message.reasoning_content === "string");
+
+ validResponse =
+ validResponse ||
+ (details.hasReasoningField && message.reasoning.length > 0);
+ }
+ }
+ } else if (apiType === "responses") {
+ // this is the actual HTTP response from the provider
+ const data = rawResponses[0].providerData;
+ for (const item of data.output) {
+ // Since it's not guaranteed that every response has a reasoning field, we check if it's present
+ // at least once across all responses
+
+ if (item.type === "reasoning") {
+ details.hasReasoningContentArray = Array.isArray(item.content);
+ details.hasReasoningContentArrayLength = item.content.length > 0;
+ details.hasReasoningContentArrayItemType = item.content.every(
+ (item) => item.type === "reasoning_text"
+ );
+ details.hasReasoningContentArrayItemText = item.content.every(
+ (item) => item.text.length > 0
+ );
+
+ validResponse =
+ details.hasReasoningContentArray &&
+ details.hasReasoningContentArrayLength &&
+ details.hasReasoningContentArrayItemType &&
+ details.hasReasoningContentArrayItemText;
+ }
+ }
+ }
+
+ return {
+ validResponse,
+ details,
+ };
+}
+
+function deepEqual(a: any, b: any): boolean {
+ if (a === b) return true;
+ if (typeof a !== typeof b) return false;
+ if (a && b && typeof a === "object") {
+ if (Array.isArray(a) !== Array.isArray(b)) return false;
+ if (Array.isArray(a)) {
+ if (a.length !== b.length) return false;
+ for (let i = 0; i < a.length; i++) {
+ if (!deepEqual(a[i], b[i])) return false;
+ }
+ return true;
+ } else {
+ const aKeys = Object.keys(a);
+ const bKeys = Object.keys(b);
+ if (aKeys.length !== bKeys.length) return false;
+ for (const key of aKeys) {
+ if (!b.hasOwnProperty(key)) return false;
+ if (!deepEqual(a[key], b[key])) return false;
+ }
+ return true;
+ }
+ }
+ return false;
+}
diff --git a/compatibility-test/tools.ts b/compatibility-test/tools.ts
new file mode 100644
index 00000000..d2d4db6e
--- /dev/null
+++ b/compatibility-test/tools.ts
@@ -0,0 +1,156 @@
+import { Tool, tool } from "@openai/agents";
+
+function convertToTool(toolData: any) {
+ return tool({
+ name: toolData.name,
+ description: toolData.description,
+ parameters: toolData.parameters,
+ execute: async (parameters) => {
+ return toolData.output;
+ },
+ strict: false,
+ });
+}
+
+export const TOOLS = [
+ {
+ type: "function",
+ name: "get_weather",
+ description: "Get the weather for a given location",
+ parameters: {
+ type: "object",
+ properties: {
+ location: {
+ type: "string",
+ description: "The location to get the weather for",
+ },
+ },
+ required: ["location"],
+ additionalProperties: false,
+ },
+ output: '{"weather":"sunny"}',
+ },
+ {
+ type: "function",
+ name: "get_system_health",
+ description:
+ "Returns the current health status of the LLM runtime—use before critical operations to verify the service is live.",
+ parameters: { type: "object", properties: {} },
+ output: '{"status":"ok","uptime_seconds":372045}',
+ },
+ {
+ type: "function",
+ name: "markdown_to_html",
+ description:
+ "Converts a Markdown string to sanitized HTML—use when you need browser-renderable output.",
+ parameters: {
+ type: "object",
+ properties: {
+ markdown: { type: "string", description: "Raw Markdown content" },
+ },
+ required: ["markdown"],
+ additionalProperties: false,
+ },
+ output: '{"html":"Hello World
This is great.
"}',
+ },
+ {
+ type: "function",
+ name: "detect_language",
+ description:
+ "Identifies the ISO language code of the supplied text—use for routing text to language-specific models.",
+ parameters: {
+ type: "object",
+ properties: {
+ text: {
+ type: "string",
+ description: "Text whose language should be detected",
+ },
+ },
+ required: ["text"],
+ additionalProperties: false,
+ },
+ output: '{"language":"de","confidence":0.98}',
+ },
+ {
+ type: "function",
+ name: "generate_chart",
+ description:
+ "Creates a base64-encoded PNG chart from tabular data—use for quick visualizations inside chat.",
+ parameters: {
+ type: "object",
+ properties: {
+ data: {
+ type: "array",
+ items: { type: "array", items: { type: "number" } },
+ description: "2-D numeric data matrix",
+ },
+ chart_type: {
+ type: "string",
+ enum: ["line", "bar", "scatter"],
+ description: "Type of chart to generate",
+ },
+ title: {
+ type: "string",
+ description: "Chart title",
+ default: "",
+ },
+ x_label: {
+ type: "string",
+ description: "Label for the x-axis",
+ default: "",
+ },
+ y_label: {
+ type: "string",
+ description: "Label for the y-axis",
+ default: "",
+ },
+ },
+ required: ["data", "chart_type"],
+ additionalProperties: false,
+ },
+ output: '{"image_png_base64":"iVBORw0KGgoAAAANSUhEUgAA..."}',
+ },
+ {
+ type: "function",
+ name: "query_database",
+ description:
+ "Runs a parameterized SQL SELECT on the internal analytics DB—use for lightweight data look-ups.",
+ parameters: {
+ type: "object",
+ properties: {
+ table: { type: "string", description: "Table name to query" },
+ columns: {
+ type: "array",
+ items: { type: "string" },
+ description: "Columns to return",
+ },
+ filters: {
+ type: "string",
+ description: "SQL WHERE clause without the word WHERE",
+ default: "",
+ },
+ limit: {
+ type: "integer",
+ minimum: 1,
+ maximum: 10000,
+ description: "Max rows to return",
+ default: 100,
+ },
+ order_by: {
+ type: "string",
+ description: "Column to order by (optional)",
+ default: "",
+ },
+ },
+ required: ["table", "columns"],
+ additionalProperties: false,
+ },
+ output:
+ '{"rows":[{"id":1,"email":"user@example.com"},{"id":2,"email":"foo@bar.com"}],"row_count":2}',
+ },
+];
+
+export const TOOLS_MAP = TOOLS.reduce((acc, tool) => {
+ acc[tool.name] = convertToTool(tool);
+ return acc;
+}, {} as Record);
diff --git a/examples/agents-sdk-python/example.py b/examples/agents-sdk-python/example.py
new file mode 100644
index 00000000..af0be603
--- /dev/null
+++ b/examples/agents-sdk-python/example.py
@@ -0,0 +1,102 @@
+import asyncio
+from pathlib import Path
+import shutil
+
+from openai import AsyncOpenAI
+from agents import (
+ Agent,
+ ItemHelpers,
+ Runner,
+ set_default_openai_api,
+ set_default_openai_client,
+ set_tracing_disabled,
+ function_tool,
+)
+from agents.mcp import MCPServerStdio
+
+
+async def prompt_user(question: str) -> str:
+ """Async input prompt function"""
+ loop = asyncio.get_event_loop()
+ return await loop.run_in_executor(None, input, question)
+
+
+async def main():
+ # Set up OpenAI client for local server (e.g., Ollama)
+ openai_client = AsyncOpenAI(
+ api_key="local",
+ base_url="http://localhost:11434/v1",
+ )
+
+ # Get current working directory
+ samples_dir = str(Path.cwd())
+
+ # Create MCP server for filesystem operations
+ mcp_server = MCPServerStdio(
+ name="Filesystem MCP Server, via npx",
+ params={
+ "command": "npx",
+ "args": [
+ "-y",
+ "@modelcontextprotocol/server-filesystem",
+ samples_dir,
+ ],
+ },
+ )
+
+ # Connect to MCP server
+ await mcp_server.connect()
+
+ # Configure agents SDK
+ set_tracing_disabled(True)
+ set_default_openai_client(openai_client)
+ set_default_openai_api("chat_completions")
+
+ # Define weather tool
+ @function_tool
+ async def get_weather(location: str) -> str:
+ return f"The weather in {location} is sunny."
+
+ # Create agent
+ agent = Agent(
+ name="My Agent",
+ instructions="You are a helpful assistant.",
+ tools=[get_weather],
+ model="gpt-oss:20b-test",
+ mcp_servers=[mcp_server],
+ )
+
+ # Get user input
+ user_input = await prompt_user("> ")
+
+ # Run agent with streaming
+ result = Runner.run_streamed(agent, user_input)
+
+ # Process streaming results
+ async for event in result.stream_events():
+ if event.type == "raw_response_event":
+ continue
+ elif event.type == "agent_updated_stream_event":
+ print(f"Agent updated: {event.new_agent.name}")
+ elif event.type == "run_item_stream_event":
+ if event.item.type == "tool_call_item":
+ print("-- Tool was called")
+ elif event.item.type == "tool_call_output_item":
+ print(f"-- Tool output: {event.item.output}")
+ elif event.item.type == "message_output_item":
+ print(
+ f"-- Message output:\n {ItemHelpers.text_message_output(event.item)}"
+ )
+ else:
+ pass
+
+ print("=== Run complete ===")
+
+
+if __name__ == "__main__":
+
+ if not shutil.which("npx"):
+ raise RuntimeError(
+ "npx is not installed. Please install it with `npm install -g npx`."
+ )
+ asyncio.run(main())
diff --git a/examples/agents-sdk-python/pyproject.toml b/examples/agents-sdk-python/pyproject.toml
new file mode 100644
index 00000000..e8d24a81
--- /dev/null
+++ b/examples/agents-sdk-python/pyproject.toml
@@ -0,0 +1,9 @@
+[project]
+name = "agents-sdk-python"
+version = "0.1.0"
+description = "Add your description here"
+readme = "README.md"
+requires-python = ">=3.12"
+dependencies = [
+ "openai-agents>=0.2.4",
+]
diff --git a/examples/gradio/gradio_chat.py b/examples/gradio/gradio_chat.py
new file mode 100644
index 00000000..da742bd3
--- /dev/null
+++ b/examples/gradio/gradio_chat.py
@@ -0,0 +1,247 @@
+import json
+import requests
+import gradio as gr
+
+DEFAULT_FUNCTION_PROPERTIES = """
+{
+ "type": "object",
+ "properties": {
+ "location": {
+ "type": "string",
+ "description": "The city and state, e.g. San Francisco, CA"
+ }
+ },
+ "required": ["location"]
+}
+""".strip()
+
+def chat_with_model(message, history, model_choice, instructions, effort, use_functions,
+ function_name, function_description, function_parameters,
+ use_browser_search, temperature, max_output_tokens, debug_mode):
+
+ if not message.strip():
+ return history, ""
+
+ # Append user message and empty assistant placeholder (idiomatic Gradio pattern)
+ history = history + [[message, ""]]
+
+ # Build messages list from history (excluding the empty assistant placeholder)
+ messages = []
+
+ # Convert history to messages format (excluding the last empty assistant message)
+ for user_msg, assistant_msg in history[:-1]:
+ if user_msg:
+ messages.append({
+ "type": "message",
+ "role": "user",
+ "content": [{"type": "input_text", "text": user_msg}]
+ })
+ if assistant_msg:
+ messages.append({
+ "type": "message",
+ "role": "assistant",
+ "content": [{"type": "output_text", "text": assistant_msg}]
+ })
+
+ # Add current user message
+ messages.append({
+ "type": "message",
+ "role": "user",
+ "content": [{"type": "input_text", "text": message}]
+ })
+
+ # Prepare tools
+ tools = []
+ if use_functions:
+ try:
+ tools.append({
+ "type": "function",
+ "name": function_name,
+ "description": function_description,
+ "parameters": json.loads(function_parameters),
+ })
+ except json.JSONDecodeError:
+ pass
+
+ if use_browser_search:
+ tools.append({"type": "browser_search"})
+
+ # Get URL based on model (matching streamlit logic)
+ options = ["large", "small"]
+ URL = ("http://localhost:8081/v1/responses" if model_choice == options[1]
+ else "http://localhost:8000/v1/responses")
+
+ try:
+ response = requests.post(
+ URL,
+ json={
+ "input": messages,
+ "stream": True,
+ "instructions": instructions,
+ "reasoning": {"effort": effort},
+ "metadata": {"__debug": debug_mode},
+ "tools": tools,
+ "temperature": temperature,
+ "max_output_tokens": max_output_tokens,
+ },
+ stream=True,
+ )
+
+ full_content = ""
+ text_delta = ""
+ current_output_index = 0
+ in_reasoning = False
+
+ for line in response.iter_lines(decode_unicode=True):
+ if not line or not line.startswith("data:"):
+ continue
+ data_str = line[len("data:"):].strip()
+ if not data_str:
+ continue
+
+ try:
+ data = json.loads(data_str)
+ except Exception:
+ continue
+
+ event_type = data.get("type", "")
+ output_index = data.get("output_index", 0)
+
+ if event_type == "response.output_item.added":
+ current_output_index = output_index
+ output_type = data.get("item", {}).get("type", "message")
+ text_delta = ""
+
+ if output_type == "reasoning":
+ if not in_reasoning:
+ full_content += "🤔 **Thinking...**\n"
+ in_reasoning = True
+ elif output_type == "message":
+ if in_reasoning:
+ full_content += "\n\n"
+ in_reasoning = False
+
+ elif event_type == "response.reasoning_text.delta":
+ delta = data.get("delta", "")
+ full_content += delta
+
+ # Update last assistant message (idiomatic Gradio pattern)
+ history[-1][1] = full_content
+ yield history, ""
+
+ elif event_type == "response.output_text.delta":
+ delta = data.get("delta", "")
+ full_content += delta
+
+ # Update last assistant message (idiomatic Gradio pattern)
+ history[-1][1] = full_content
+ yield history, ""
+
+ elif event_type == "response.output_item.done":
+ item = data.get("item", {})
+ if item.get("type") == "function_call":
+ function_call_text = f"\n\n🔨 Called `{item.get('name')}`\n**Arguments**\n```json\n{item.get('arguments', '')}\n```"
+ full_content += function_call_text
+
+ # Update last assistant message (idiomatic Gradio pattern)
+ history[-1][1] = full_content
+ yield history, ""
+
+ elif item.get("type") == "web_search_call":
+ web_search_text = f"\n\n🌐 **Web Search**\n```json\n{json.dumps(item.get('action', {}), indent=2)}\n```\n✅ Done"
+ full_content += web_search_text
+
+ # Update last assistant message (idiomatic Gradio pattern)
+ history[-1][1] = full_content
+ yield history, ""
+
+ elif event_type == "response.completed":
+ response_data = data.get("response", {})
+ if debug_mode:
+ debug_info = response_data.get("metadata", {}).get("__debug", "")
+ if debug_info:
+ full_content += f"\n\n**Debug**\n```\n{debug_info}\n```"
+
+ # Update last assistant message (idiomatic Gradio pattern)
+ history[-1][1] = full_content
+ yield history, ""
+ break
+
+ # Return final history and empty string to clear textbox
+ return history, ""
+
+ except Exception as e:
+ error_message = f"❌ Error: {str(e)}"
+ history[-1][1] = error_message
+ return history, ""
+
+
+# Create the Gradio interface
+with gr.Blocks(title="💬 Chatbot") as demo:
+ gr.Markdown("# 💬 Chatbot")
+
+ with gr.Row():
+ with gr.Column(scale=3):
+ chatbot = gr.Chatbot(height=500)
+
+ with gr.Row():
+ msg = gr.Textbox(placeholder="Type a message...", scale=4, show_label=False)
+ send_btn = gr.Button("Send", scale=1)
+
+ clear_btn = gr.Button("Clear Chat")
+
+ with gr.Column(scale=1):
+ model_choice = gr.Radio(["large", "small"], value="small", label="Model")
+
+ instructions = gr.Textbox(
+ label="Instructions",
+ value="You are a helpful assistant that can answer questions and help with tasks.",
+ lines=3
+ )
+
+ effort = gr.Radio(["low", "medium", "high"], value="medium", label="Reasoning effort")
+
+ gr.Markdown("#### Functions")
+ use_functions = gr.Checkbox(label="Use functions", value=False)
+
+ with gr.Column(visible=False) as function_group:
+ function_name = gr.Textbox(label="Function name", value="get_weather")
+ function_description = gr.Textbox(
+ label="Function description",
+ value="Get the weather for a given city"
+ )
+ function_parameters = gr.Textbox(
+ label="Function parameters",
+ value=DEFAULT_FUNCTION_PROPERTIES,
+ lines=6
+ )
+
+ # Conditional browser search (matching Streamlit logic)
+ # In Streamlit: if "show_browser" in st.query_params:
+ # For Gradio, we'll always show it (simplified)
+ gr.Markdown("#### Built-in Tools")
+ use_browser_search = gr.Checkbox(label="Use browser search", value=False)
+
+ temperature = gr.Slider(0.0, 1.0, value=1.0, step=0.01, label="Temperature")
+ max_output_tokens = gr.Slider(1000, 20000, value=1024, step=100, label="Max output tokens")
+
+ debug_mode = gr.Checkbox(label="Debug mode", value=False)
+
+ # Event handlers
+ def toggle_function_group(use_funcs):
+ return gr.update(visible=use_funcs)
+
+ use_functions.change(toggle_function_group, use_functions, function_group)
+
+ # Chat functionality
+ inputs = [msg, chatbot, model_choice, instructions, effort, use_functions,
+ function_name, function_description, function_parameters,
+ use_browser_search, temperature, max_output_tokens, debug_mode]
+
+ msg.submit(chat_with_model, inputs, [chatbot, msg])
+ send_btn.click(chat_with_model, inputs, [chatbot, msg])
+ clear_btn.click(lambda: [], outputs=chatbot)
+
+
+if __name__ == "__main__":
+ demo.launch()
\ No newline at end of file
diff --git a/examples/streamlit/streamlit_chat.py b/examples/streamlit/streamlit_chat.py
index d03fe9c0..9185ff67 100644
--- a/examples/streamlit/streamlit_chat.py
+++ b/examples/streamlit/streamlit_chat.py
@@ -48,12 +48,10 @@
st.sidebar.subheader("Functions")
use_functions = st.sidebar.toggle("Use functions", value=False)
-if "show_browser" in st.query_params:
- st.sidebar.subheader("Built-in Tools")
+st.sidebar.subheader("Built-in Tools")
# Built-in Tools section
- use_browser_search = st.sidebar.toggle("Use browser search", value=False)
-else:
- use_browser_search = False
+use_browser_search = st.sidebar.toggle("Use browser search", value=False)
+use_code_interpreter = st.sidebar.toggle("Use code interpreter", value=False)
if use_functions:
function_name = st.sidebar.text_input("Function name", value="get_weather")
@@ -72,7 +70,7 @@
"Temperature", min_value=0.0, max_value=1.0, value=1.0, step=0.01
)
max_output_tokens = st.sidebar.slider(
- "Max output tokens", min_value=1000, max_value=20000, value=1024, step=100
+ "Max output tokens", min_value=1, max_value=131072, value=30000, step=1000
)
st.sidebar.divider()
debug_mode = st.sidebar.toggle("Debug mode", value=False)
@@ -89,6 +87,7 @@
else "http://localhost:8000/v1/responses"
)
+
def trigger_fake_tool(container):
function_output = st.session_state.get("function_output", "It's sunny!")
last_call = st.session_state.messages[-1]
@@ -117,6 +116,8 @@ def run(container):
# Add browser_search tool if checkbox is checked
if use_browser_search:
tools.append({"type": "browser_search"})
+ if use_code_interpreter:
+ tools.append({"type": "code_interpreter"})
response = requests.post(
URL,
json={
@@ -134,7 +135,7 @@ def run(container):
text_delta = ""
- current_output_index = 0
+ _current_output_index = 0
for line in response.iter_lines(decode_unicode=True):
if not line or not line.startswith("data:"):
continue
@@ -149,7 +150,7 @@ def run(container):
event_type = data.get("type", "")
output_index = data.get("output_index", 0)
if event_type == "response.output_item.added":
- current_output_index = output_index
+ _current_output_index = output_index
output_type = data.get("item", {}).get("type", "message")
if output_type == "message":
output = container.chat_message("assistant")
@@ -159,7 +160,13 @@ def run(container):
placeholder = output.empty()
elif output_type == "web_search_call":
output = container.chat_message("web_search_call", avatar="🌐")
- output.code(json.dumps(data.get("item", {}).get("action", {}), indent=4), language="json")
+ output.code(
+ json.dumps(data.get("item", {}).get("action", {}), indent=4),
+ language="json",
+ )
+ placeholder = output.empty()
+ elif output_type == "code_interpreter_call":
+ output = container.chat_message("code_interpreter_call", avatar="🧪")
placeholder = output.empty()
text_delta = ""
elif event_type == "response.reasoning_text.delta":
@@ -173,11 +180,23 @@ def run(container):
item = data.get("item", {})
if item.get("type") == "function_call":
with container.chat_message("function_call", avatar="🔨"):
- st.markdown(f"Called `{item.get("name")}`")
+ st.markdown(f"Called `{item.get('name')}`")
st.caption("Arguments")
st.code(item.get("arguments", ""), language="json")
if item.get("type") == "web_search_call":
placeholder.markdown("✅ Done")
+ if item.get("type") == "code_interpreter_call":
+ placeholder.markdown("✅ Done")
+ elif event_type == "response.code_interpreter_call.in_progress":
+ try:
+ placeholder.markdown("⏳ Running")
+ except Exception:
+ pass
+ elif event_type == "response.code_interpreter_call.completed":
+ try:
+ placeholder.markdown("✅ Done")
+ except Exception:
+ pass
elif event_type == "response.completed":
response = data.get("response", {})
if debug_mode:
@@ -187,7 +206,7 @@ def run(container):
st.session_state.messages.extend(response.get("output", []))
if st.session_state.messages[-1].get("type") == "function_call":
with container.form("function_output_form"):
- function_output = st.text_input(
+ _function_output = st.text_input(
"Enter function output",
value=st.session_state.get("function_output", "It's sunny!"),
key="function_output",
@@ -213,7 +232,9 @@ def run(container):
st.markdown(item["text"])
if item.get("annotations"):
annotation_lines = "\n".join(
- f"- {annotation.get('url')}" for annotation in item["annotations"] if annotation.get("url")
+ f"- {annotation.get('url')}"
+ for annotation in item["annotations"]
+ if annotation.get("url")
)
st.caption(f"**Annotations:**\n{annotation_lines}")
elif msg.get("type") == "reasoning":
@@ -223,7 +244,7 @@ def run(container):
st.markdown(item["text"])
elif msg.get("type") == "function_call":
with st.chat_message("function_call", avatar="🔨"):
- st.markdown(f"Called `{msg.get("name")}`")
+ st.markdown(f"Called `{msg.get('name')}`")
st.caption("Arguments")
st.code(msg.get("arguments", ""), language="json")
elif msg.get("type") == "function_call_output":
@@ -234,6 +255,9 @@ def run(container):
with st.chat_message("web_search_call", avatar="🌐"):
st.code(json.dumps(msg.get("action", {}), indent=4), language="json")
st.markdown("✅ Done")
+ elif msg.get("type") == "code_interpreter_call":
+ with st.chat_message("code_interpreter_call", avatar="🧪"):
+ st.markdown("✅ Done")
if render_input:
# Input field
diff --git a/gpt-oss-mcp-server/README.md b/gpt-oss-mcp-server/README.md
index 6326b2e7..10aedd5f 100644
--- a/gpt-oss-mcp-server/README.md
+++ b/gpt-oss-mcp-server/README.md
@@ -1,8 +1,8 @@
# MCP Servers for gpt-oss reference tools
This directory contains MCP servers for the reference tools in the [gpt-oss](https://github.com/openai/gpt-oss) repository.
-You can set up these tools behind MCP servers and use them in your applications.
-For inference service that integrates with MCP, you can also use these as reference tools.
+You can set up these tools behind MCP servers and use them in your applications.
+For inference service that integrates with MCP, you can also use these as reference tools.
In particular, this directory contains a `build-system-prompt.py` script that will generate exactly the same system prompt as `reference-system-prompt.py`.
The build system prompt script show case all the care needed to automatically discover the tools and construct the system prompt before feeding it into Harmony.
@@ -22,8 +22,8 @@ mcp run -t sse browser_server.py:mcp
mcp run -t sse python_server.py:mcp
```
-You can now use MCP inspector to play with the tools.
+You can now use MCP inspector to play with the tools.
Once opened, set SSE to `http://localhost:8001/sse` and `http://localhost:8000/sse` respectively.
-To compare the system prompt and see how to construct it via MCP service discovery, see `build-system-prompt.py`.
+To compare the system prompt and see how to construct it via MCP service discovery, see `build-system-prompt.py`.
This script will generate exactly the same system prompt as `reference-system-prompt.py`.
diff --git a/gpt-oss-mcp-server/browser_server.py b/gpt-oss-mcp-server/browser_server.py
index 5d5ad4ad..b37a63a6 100644
--- a/gpt-oss-mcp-server/browser_server.py
+++ b/gpt-oss-mcp-server/browser_server.py
@@ -1,3 +1,4 @@
+import os
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
@@ -5,8 +6,7 @@
from mcp.server.fastmcp import Context, FastMCP
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
-
+from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend
@dataclass
class AppContext:
@@ -14,7 +14,13 @@ class AppContext:
def create_or_get_browser(self, session_id: str) -> SimpleBrowserTool:
if session_id not in self.browsers:
- backend = ExaBackend(source="web")
+ tool_backend = os.getenv("BROWSER_BACKEND", "exa")
+ if tool_backend == "youcom":
+ backend = YouComBackend(source="web")
+ elif tool_backend == "exa":
+ backend = ExaBackend(source="web")
+ else:
+ raise ValueError(f"Invalid tool backend: {tool_backend}")
self.browsers[session_id] = SimpleBrowserTool(backend=backend)
return self.browsers[session_id]
diff --git a/gpt-oss-mcp-server/build-system-prompt.py b/gpt-oss-mcp-server/build-system-prompt.py
index 58e953ad..1aca256a 100644
--- a/gpt-oss-mcp-server/build-system-prompt.py
+++ b/gpt-oss-mcp-server/build-system-prompt.py
@@ -1,7 +1,7 @@
import datetime
import asyncio
-from gpt_oss.tokenizer import tokenizer
+from gpt_oss.tokenizer import get_tokenizer
from openai_harmony import (
Conversation,
@@ -66,6 +66,7 @@ def post_process_tools_description(
return list_tools_result
+tokenizer = get_tokenizer()
tools_urls = [
"http://localhost:8001/sse", # browser
diff --git a/gpt-oss-mcp-server/python_server.py b/gpt-oss-mcp-server/python_server.py
index bea86587..7ec35308 100644
--- a/gpt-oss-mcp-server/python_server.py
+++ b/gpt-oss-mcp-server/python_server.py
@@ -20,7 +20,7 @@
When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.
""",
annotations={
- # Harmony format don't wnat this schema to be part of it because it's simple text in text out
+ # Harmony format don't want this schema to be part of it because it's simple text in text out
"include_in_prompt": False,
})
async def python(code: str) -> str:
diff --git a/gpt-oss-mcp-server/reference-system-prompt.py b/gpt-oss-mcp-server/reference-system-prompt.py
index 98f171dd..6ddbf7c9 100644
--- a/gpt-oss-mcp-server/reference-system-prompt.py
+++ b/gpt-oss-mcp-server/reference-system-prompt.py
@@ -1,7 +1,7 @@
import datetime
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from gpt_oss.tokenizer import tokenizer
@@ -22,7 +22,7 @@
ReasoningEffort.LOW).with_conversation_start_date(
datetime.datetime.now().strftime("%Y-%m-%d")))
-backend = ExaBackend(source="web", )
+backend = YouComBackend(source="web")
browser_tool = SimpleBrowserTool(backend=backend)
system_message_content = system_message_content.with_tools(
browser_tool.tool_config)
diff --git a/gpt_oss/chat.py b/gpt_oss/chat.py
index ed2bda21..4856a397 100644
--- a/gpt_oss/chat.py
+++ b/gpt_oss/chat.py
@@ -19,7 +19,7 @@
from gpt_oss.tools import apply_patch
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend
from gpt_oss.tools.python_docker.docker_tool import PythonTool
from openai_harmony import (
@@ -85,7 +85,7 @@ def main(args):
)
if args.browser:
- backend = ExaBackend(
+ backend = YouComBackend(
source="web",
)
browser_tool = SimpleBrowserTool(backend=backend)
@@ -123,6 +123,8 @@ def main(args):
elif args.developer_message:
developer_message_content = DeveloperContent.new().with_instructions(args.developer_message)
messages.append(Message.from_role_and_content(Role.DEVELOPER, developer_message_content))
+ else:
+ developer_message_content = None
if args.raw:
conversation = Conversation.from_messages(messages)
@@ -142,9 +144,9 @@ def main(args):
print(termcolor.colored("Browser Tool:", "cyan"), "Enabled" if args.browser else "Disabled", flush=True)
print(termcolor.colored("Python Tool:", "cyan"), "Enabled" if args.python else "Disabled", flush=True)
print(termcolor.colored("Apply Patch Function:", "cyan"), "Enabled" if args.apply_patch else "Disabled", flush=True)
- # Developer message
- print(termcolor.colored("Developer Message:", "yellow"), flush=True)
- print(developer_message_content.instructions, flush=True)
+ if developer_message_content:
+ print(termcolor.colored("Developer Message:", "yellow"), flush=True)
+ print(developer_message_content.instructions, flush=True)
# Print the system message and the user message start
MESSAGE_PADDING = 12
diff --git a/gpt_oss/evals/__main__.py b/gpt_oss/evals/__main__.py
index 7e95ab7e..40d56c12 100644
--- a/gpt_oss/evals/__main__.py
+++ b/gpt_oss/evals/__main__.py
@@ -3,12 +3,13 @@
from datetime import datetime
from . import report
+from .basic_eval import BasicEval
from .gpqa_eval import GPQAEval
from .aime_eval import AIME25Eval
from .healthbench_eval import HealthBenchEval
-from .chat_completion_sampler import (
+from .chat_completions_sampler import (
OPENAI_SYSTEM_MESSAGE_API,
- ChatCompletionSampler,
+ ChatCompletionsSampler,
)
from .responses_sampler import ResponsesSampler
@@ -19,12 +20,23 @@ def main():
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
- "--list-models", action="store_true", help="List available models"
+ "--model",
+ type=str,
+ default="gpt-oss-120b,gpt-oss-20b",
+ help="Select a model by name. Accepts a comma-separated list.",
)
parser.add_argument(
- "--model",
+ "--reasoning-effort",
+ type=str,
+ default="low,medium,high",
+ help="Reasoning effort (low, medium, high). Accepts a comma-separated list.",
+ )
+ parser.add_argument(
+ "--sampler",
type=str,
- help="Select a model by name. Also accepts a comma-separated list of models.",
+ choices=["responses", "chat_completions"],
+ default="responses",
+ help="Sampler backend to use for models.",
)
parser.add_argument(
"--base-url",
@@ -36,7 +48,7 @@ def main():
"--eval",
type=str,
default="gpqa,healthbench,healthbench_hard,healthbench_consensus,aime25",
- help="Select an eval by name. Also accepts a comma-separated list of evals.",
+ help="Select an eval by name. Accepts a comma-separated list.",
)
parser.add_argument(
"--temperature",
@@ -59,71 +71,27 @@ def main():
args = parser.parse_args()
- models = {
- "120b-low": ResponsesSampler(
- model="gpt-oss-120b",
- reasoning_model=True,
- reasoning_effort="low",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- "120b": ResponsesSampler(
- model="gpt-oss-120b",
- reasoning_model=True,
- reasoning_effort="medium",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- "120b-high": ResponsesSampler(
- model="gpt-oss-120b",
- reasoning_model=True,
- reasoning_effort="high",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- "20b-low": ResponsesSampler(
- model="gpt-oss-20b",
- reasoning_model=True,
- reasoning_effort="low",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- "20b": ResponsesSampler(
- model="gpt-oss-20b",
- reasoning_model=True,
- reasoning_effort="medium",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- "20b-high": ResponsesSampler(
- model="gpt-oss-20b",
- reasoning_model=True,
- reasoning_effort="high",
- temperature=args.temperature,
- base_url=args.base_url,
- ),
- }
-
- if args.list_models:
- print("Available models:")
- for model_name in models.keys():
- print(f" - {model_name}")
- return
-
- if args.model:
- models_chosen = args.model.split(",")
- for model_name in models_chosen:
- if model_name not in models:
- print(f"Error: Model '{model_name}' not found.")
- return
- models = {model_name: models[model_name] for model_name in models_chosen}
+ sampler_cls = ResponsesSampler if args.sampler == "responses" else ChatCompletionsSampler
+
+ models = {}
+ for model_name in args.model.split(","):
+ for reasoning_effort in args.reasoning_effort.split(","):
+ models[f"{model_name}-{reasoning_effort}"] = sampler_cls(
+ model=model_name,
+ reasoning_model=True,
+ reasoning_effort=reasoning_effort,
+ temperature=args.temperature,
+ base_url=args.base_url,
+ max_tokens=131_072,
+ )
print(f"Running with args {args}")
- grading_sampler = ChatCompletionSampler(
+ grading_sampler = ChatCompletionsSampler(
model="gpt-4.1-2025-04-14",
system_message=OPENAI_SYSTEM_MESSAGE_API,
max_tokens=2048,
+ base_url="https://api.openai.com/v1",
)
def get_evals(eval_name, debug_mode):
@@ -132,9 +100,11 @@ def get_evals(eval_name, debug_mode):
)
# Set num_examples = None to reproduce full evals
match eval_name:
+ case "basic":
+ return BasicEval()
case "gpqa":
return GPQAEval(
- n_repeats=8,
+ n_repeats=1 if args.debug else 8,
num_examples=num_examples,
debug=debug_mode,
n_threads=args.n_threads or 1,
@@ -165,28 +135,27 @@ def get_evals(eval_name, debug_mode):
)
case "aime25":
return AIME25Eval(
- n_repeats=8,
+ n_repeats=1 if args.debug else 8,
num_examples=num_examples,
n_threads=args.n_threads or 1,
)
case _:
raise Exception(f"Unrecognized eval type: {eval_name}")
- evals_list = args.eval.split(",")
evals = {}
- for eval_name in evals_list:
+ for eval_name in args.eval.split(","):
evals[eval_name] = get_evals(eval_name, args.debug)
- print(evals)
debug_suffix = "_DEBUG" if args.debug else ""
print(debug_suffix)
mergekey2resultpath = {}
- print(f"Running the following evals: {list(evals.keys())}")
- print(f"Running evals for the following models: {list(models.keys())}")
+ print(f"Running the following evals: {evals}")
+ print(f"Running evals for the following models: {models}")
now = datetime.now()
date_str = now.strftime("%Y%m%d_%H%M%S")
for model_name, sampler in models.items():
+ model_name = model_name.replace("/", "__")
for eval_name, eval_obj in evals.items():
result = eval_obj(sampler)
# ^^^ how to use a sampler
@@ -220,6 +189,7 @@ def get_evals(eval_name, debug_mode):
print(f"Writing all results to {full_result_filename}")
mergekey2resultpath[f"{file_stem}"] = result_filename
+
merge_metrics = []
for eval_model_name, result_filename in mergekey2resultpath.items():
try:
diff --git a/gpt_oss/evals/basic_eval.py b/gpt_oss/evals/basic_eval.py
new file mode 100644
index 00000000..77995307
--- /dev/null
+++ b/gpt_oss/evals/basic_eval.py
@@ -0,0 +1,38 @@
+"""
+Basic eval
+"""
+from . import report
+
+from .types import Eval, EvalResult, SamplerBase, SingleEvalResult
+
+class BasicEval(Eval):
+ def __init__(self,):
+ self.examples = [{
+ "question": "hi",
+ "answer": "hi, how can i help?",
+ }]
+
+ def __call__(self, sampler: SamplerBase) -> EvalResult:
+ def fn(row: dict):
+ sampler_response = sampler([
+ sampler._pack_message(content=row["question"], role="user")
+ ])
+ response_text = sampler_response.response_text
+ extracted_answer = response_text
+ actual_queried_prompt_messages = sampler_response.actual_queried_message_list
+ score = 1.0 if len(extracted_answer) > 0 else 0.0
+ html = report.jinja_env.from_string(report.HTML_JINJA).render(
+ prompt_messages=actual_queried_prompt_messages,
+ next_message=dict(content=response_text, role="assistant"),
+ score=score,
+ correct_answer=row["answer"],
+ extracted_answer=extracted_answer,
+ )
+ convo = actual_queried_prompt_messages + [dict(content=response_text, role="assistant")]
+ return SingleEvalResult(
+ html=html, score=score, convo=convo, metrics={"chars": len(response_text)}
+ )
+
+ results = report.map_with_progress(fn, self.examples, num_threads=1)
+ return report.aggregate_results(results)
+
diff --git a/gpt_oss/evals/chat_completion_sampler.py b/gpt_oss/evals/chat_completions_sampler.py
similarity index 57%
rename from gpt_oss/evals/chat_completion_sampler.py
rename to gpt_oss/evals/chat_completions_sampler.py
index 4a1f9618..29c1a0a8 100644
--- a/gpt_oss/evals/chat_completion_sampler.py
+++ b/gpt_oss/evals/chat_completions_sampler.py
@@ -6,6 +6,7 @@
from .types import MessageList, SamplerBase, SamplerResponse
+
OPENAI_SYSTEM_MESSAGE_API = "You are a helpful assistant."
OPENAI_SYSTEM_MESSAGE_CHATGPT = (
"You are ChatGPT, a large language model trained by OpenAI, based on the GPT-4 architecture."
@@ -13,10 +14,8 @@
)
-class ChatCompletionSampler(SamplerBase):
- """
- Sample from OpenAI's chat completion API
- """
+class ChatCompletionsSampler(SamplerBase):
+ """Sample from a Chat Completions compatible API."""
def __init__(
self,
@@ -24,17 +23,20 @@ def __init__(
system_message: str | None = None,
temperature: float = 0.5,
max_tokens: int = 1024,
+ reasoning_model: bool = False,
+ reasoning_effort: str | None = None,
+ base_url: str = "http://localhost:8000/v1",
):
- self.api_key_name = "OPENAI_API_KEY"
- self.client = OpenAI()
- # using api_key=os.environ.get("OPENAI_API_KEY") # please set your API_KEY
+ self.client = OpenAI(base_url=base_url, timeout=24 * 60 * 60)
self.model = model
self.system_message = system_message
self.temperature = temperature
self.max_tokens = max_tokens
+ self.reasoning_model = reasoning_model
+ self.reasoning_effort = reasoning_effort
self.image_format = "url"
- def _pack_message(self, role: str, content: Any):
+ def _pack_message(self, role: str, content: Any) -> dict[str, Any]:
return {"role": str(role), "content": content}
def __call__(self, message_list: MessageList) -> SamplerResponse:
@@ -45,21 +47,34 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
trial = 0
while True:
try:
- response = self.client.chat.completions.create(
- model=self.model,
- messages=message_list,
- temperature=self.temperature,
- max_tokens=self.max_tokens,
- )
- content = response.choices[0].message.content
- if content is None:
+ if self.reasoning_model:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=message_list,
+ reasoning_effort=self.reasoning_effort,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+ else:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=message_list,
+ temperature=self.temperature,
+ max_tokens=self.max_tokens,
+ )
+
+ choice = response.choices[0]
+ content = choice.message.content
+ if getattr(choice.message, "reasoning", None):
+ message_list.append(self._pack_message("assistant", choice.message.reasoning))
+
+ if not content:
raise ValueError("OpenAI API returned empty response; retrying")
return SamplerResponse(
response_text=content,
response_metadata={"usage": response.usage},
actual_queried_message_list=message_list,
)
- # NOTE: BadRequestError is triggered once for MMMU, please uncomment if you are reruning MMMU
except openai.BadRequestError as e:
print("Bad Request Error", e)
return SamplerResponse(
@@ -68,7 +83,7 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
actual_queried_message_list=message_list,
)
except Exception as e:
- exception_backoff = 2**trial # expontial back off
+ exception_backoff = 2 ** trial # exponential back off
print(
f"Rate limit exception so wait and retry {trial} after {exception_backoff} sec",
e,
diff --git a/gpt_oss/evals/healthbench_eval.py b/gpt_oss/evals/healthbench_eval.py
index 77a6b3a5..09d184c1 100644
--- a/gpt_oss/evals/healthbench_eval.py
+++ b/gpt_oss/evals/healthbench_eval.py
@@ -8,7 +8,7 @@
You can also evaluate physician ideal completions or reference completions against the HealthBench rubrics. To do so, run the following command:
- To evaluate physician ideal completions: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completions`
-- To evaluate reference model completions used by physicians: `python -m gpt_oss.evalshealthbench_eval --run_mode=physician_completion_references`
+- To evaluate reference model completions used by physicians: `python -m gpt_oss.evals.healthbench_eval --run_mode=physician_completion_references`
"""
import argparse
@@ -26,9 +26,9 @@
import numpy as np
from . import report
-from .chat_completion_sampler import (
+from .chat_completions_sampler import (
OPENAI_SYSTEM_MESSAGE_API,
- ChatCompletionSampler,
+ ChatCompletionsSampler,
)
from .types import Eval, EvalResult, MessageList, SamplerBase, SingleEvalResult
@@ -48,7 +48,7 @@
# Instructions
Return a json object with the following fields: "explanation" and "criteria_met".
- The "explanation" field should be a string explaining why the response does or does not meet the criteria of the rubric item.
-- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true is all of the criteria are met.
+- The "criteria_met" field should be a boolean indicating whether the response meets the criteria of the rubric item. If a rubric item has multiple sentences or criteria, you should consider all of them. If any of the criteria is not met, the answer should be false. Only return true if all of the criteria are met.
- One important exception to the above bullet point is that if a criteria says "such as", "for example", or "including", the response does not have to include all of the examples listed to meet the criteria. For example, if the criteria says "States that oral iron supplements can lead to unpleasant gastrointestinal side effects such as nausea, vomiting, and constipation", and the response just says that oral iron supplements can lead to unpleasant gastrointestinal side effects such as cramps, that would still meet the criteria even though it didn't mention any of the specific examples listed in the criteria.
# Example 1
@@ -72,7 +72,7 @@
}
```
-As another example, if the critera says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:
+As another example, if the criteria says "Is overly verbose" and the response is concise, then the criteria is not met, and you should return a json like this:
```json
{
@@ -540,10 +540,11 @@ def physician_completions_main(
now = datetime.now()
date_str = now.strftime("%Y%m%d_%H%M")
- grading_sampler = ChatCompletionSampler(
+ grading_sampler = ChatCompletionsSampler(
model="gpt-4.1-2025-04-14",
system_message=OPENAI_SYSTEM_MESSAGE_API,
max_tokens=2048,
+ base_url="https://api.openai.com/v1",
)
dummy_sampler = SamplerBase()
diff --git a/gpt_oss/evals/responses_sampler.py b/gpt_oss/evals/responses_sampler.py
index ec4e0485..134303f5 100644
--- a/gpt_oss/evals/responses_sampler.py
+++ b/gpt_oss/evals/responses_sampler.py
@@ -17,12 +17,11 @@ def __init__(
model: str,
developer_message: str | None = None,
temperature: float = 1.0,
- max_tokens: int = 1024,
+ max_tokens: int = 131_072,
reasoning_model: bool = False,
reasoning_effort: str | None = None,
base_url: str = "http://localhost:8000/v1",
):
- self.api_key_name = "OPENAI_API_KEY"
self.client = OpenAI(base_url=base_url, timeout=24*60*60)
self.model = model
self.developer_message = developer_message
@@ -43,24 +42,17 @@ def __call__(self, message_list: MessageList) -> SamplerResponse:
trial = 0
while True:
try:
+ request_kwargs = {
+ "model": self.model,
+ "input": message_list,
+ "temperature": self.temperature,
+ "max_output_tokens": self.max_tokens,
+ }
if self.reasoning_model:
- reasoning = (
- {"effort": self.reasoning_effort}
- if self.reasoning_effort
- else None
- )
- response = self.client.responses.create(
- model=self.model,
- input=message_list,
- reasoning=reasoning,
- )
- else:
- response = self.client.responses.create(
- model=self.model,
- input=message_list,
- temperature=self.temperature,
- max_output_tokens=self.max_tokens,
+ request_kwargs["reasoning"] = (
+ {"effort": self.reasoning_effort} if self.reasoning_effort else None
)
+ response = self.client.responses.create(**request_kwargs)
for output in response.output:
if hasattr(output, "text"):
diff --git a/gpt_oss/generate.py b/gpt_oss/generate.py
index cd60ca4c..c0755805 100644
--- a/gpt_oss/generate.py
+++ b/gpt_oss/generate.py
@@ -19,20 +19,21 @@ def main(args):
from gpt_oss.torch.utils import init_distributed
from gpt_oss.triton.model import TokenGenerator as TritonGenerator
device = init_distributed()
- generator = TritonGenerator(args.checkpoint, context=4096, device=device)
+ generator = TritonGenerator(args.checkpoint, context=args.context_length, device=device)
case "vllm":
from gpt_oss.vllm.token_generator import TokenGenerator as VLLMGenerator
- generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=2)
+ generator = VLLMGenerator(args.checkpoint, tensor_parallel_size=args.tensor_parallel_size)
case _:
raise ValueError(f"Invalid backend: {args.backend}")
tokenizer = get_tokenizer()
tokens = tokenizer.encode(args.prompt)
- for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=args.limit, return_logprobs=True):
+ max_tokens = None if args.limit == 0 else args.limit
+ for token, logprob in generator.generate(tokens, stop_tokens=[tokenizer.eot_token], temperature=args.temperature, max_tokens=max_tokens, return_logprobs=True):
tokens.append(token)
- decoded_token = tokenizer.decode([token])
+ token_text = tokenizer.decode([token])
print(
- f"Generated token: {repr(decoded_token)}, logprob: {logprob}"
+ f"Generated token: {repr(token_text)}, logprob: {logprob}"
)
@@ -77,6 +78,18 @@ def main(args):
choices=["triton", "torch", "vllm"],
help="Inference backend",
)
+ parser.add_argument(
+ "--tensor-parallel-size",
+ type=int,
+ default=2,
+ help="Tensor parallel size for vLLM backend",
+ )
+ parser.add_argument(
+ "--context-length",
+ type=int,
+ default=4096,
+ help="Context length for Triton backend",
+ )
args = parser.parse_args()
main(args)
diff --git a/gpt_oss/metal/CMakeLists.txt b/gpt_oss/metal/CMakeLists.txt
index c6a8e32b..52f83b0f 100644
--- a/gpt_oss/metal/CMakeLists.txt
+++ b/gpt_oss/metal/CMakeLists.txt
@@ -147,6 +147,14 @@ add_executable(f32-bf16w-rmsnorm-bench benchmark/f32-bf16w-rmsnorm.cc)
target_link_libraries(f32-bf16w-rmsnorm-bench PRIVATE benchmark::benchmark metal-kernels)
target_include_directories(f32-bf16w-rmsnorm-bench PRIVATE source/include)
+add_executable(end-to-end-bench benchmark/end-to-end.cc)
+target_link_libraries(end-to-end-bench PRIVATE benchmark::benchmark gptoss)
+target_include_directories(end-to-end-bench PRIVATE source/include)
+
+add_executable(end-to-end-threadgroup-bench benchmark/end-to-end-threadgroup.cc)
+target_link_libraries(end-to-end-threadgroup-bench PRIVATE benchmark::benchmark gptoss)
+target_include_directories(end-to-end-threadgroup-bench PRIVATE source/include)
+
# --- [ Python extension ] -----------------------------------------------
find_package(pybind11 CONFIG REQUIRED) # provides pybind11_add_module
diff --git a/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc b/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc
new file mode 100644
index 00000000..93fb1647
--- /dev/null
+++ b/gpt_oss/metal/benchmark/end-to-end-threadgroup.cc
@@ -0,0 +1,590 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+
+constexpr std::uint32_t kNumGeneratedTokens = 100;
+
+
+static void attn_qkv_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->attn_qkv_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void AttnQKVThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto attn_qkv_threadgroup_size = 32; attn_qkv_threadgroup_size <= 1024; attn_qkv_threadgroup_size += 32) {
+ const auto num_simdgroups = attn_qkv_threadgroup_size / 32;
+ if (5120 % num_simdgroups != 0) {
+ // Skip incompatible threadgroup sizes
+ continue;
+ }
+ b->Args({attn_qkv_threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(attn_qkv_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnQKVThreadgroupSizeArguments);
+
+static void attn_out_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->attn_out_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void AttnOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto attn_out_threadgroup_size = 32; attn_out_threadgroup_size <= 1024; attn_out_threadgroup_size += 32) {
+ const auto num_simdgroups = attn_out_threadgroup_size / 32;
+ if (2880 % num_simdgroups != 0) {
+ // Skip incompatible threadgroup sizes
+ continue;
+ }
+ b->Args({attn_out_threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(attn_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(AttnOutThreadgroupSizeArguments);
+
+static void mlp_gate_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->mlp_gate_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void MlpGateThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto mlp_gate_threadgroup_size = 32; mlp_gate_threadgroup_size <= 1024; mlp_gate_threadgroup_size += 32) {
+ const auto num_simdgroups = mlp_gate_threadgroup_size / 32;
+ if (128 % num_simdgroups != 0) {
+ // Skip incompatible threadgroup sizes
+ continue;
+ }
+ b->Args({mlp_gate_threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(mlp_gate_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpGateThreadgroupSizeArguments);
+
+static void mlp_swiglu_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->mlp_swiglu_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void MlpSwigluThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) {
+ const auto num_simdgroups = threadgroup_size / 32;
+ if (5760 % num_simdgroups != 0) {
+ // Skip incompatible threadgroup sizes
+ continue;
+ }
+ b->Args({threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(mlp_swiglu_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpSwigluThreadgroupSizeArguments);
+
+static void mlp_out_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->mlp_out_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void MlpOutThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto threadgroup_size = 64; threadgroup_size <= 1024; threadgroup_size += 64) {
+ const auto num_simdgroups = threadgroup_size / 32;
+ if (5760 % num_simdgroups != 0) {
+ // Skip incompatible threadgroup sizes
+ continue;
+ }
+ b->Args({threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(mlp_out_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpOutThreadgroupSizeArguments);
+
+static void mlp_acc_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->mlp_acc_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void MlpAccThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) {
+ b->Args({threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(mlp_acc_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(MlpAccThreadgroupSizeArguments);
+
+static void unembedding_tgsize(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, /*max_batch_tokens=*/0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+ model->unembedding_threadgroup_size = static_cast(state.range(0));
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ const std::size_t num_kvcache_tokens = context->num_kv_tokens;
+
+ std::uint64_t rng_seed = 0;
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void UnembeddingThreadgroupSizeArguments(benchmark::internal::Benchmark* b) {
+ b->ArgNames({"tgsize"});
+ for (auto threadgroup_size = 32; threadgroup_size <= 1024; threadgroup_size += 32) {
+ b->Args({threadgroup_size});
+ }
+}
+
+BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_20b, "GPT_OSS_20B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments);
+BENCHMARK_CAPTURE(unembedding_tgsize, gpt_oss_120b, "GPT_OSS_120B_PATH")
+ ->UseRealTime()->Unit(benchmark::kMillisecond)->Apply(UnembeddingThreadgroupSizeArguments);
+
+BENCHMARK_MAIN();
diff --git a/gpt_oss/metal/benchmark/end-to-end.cc b/gpt_oss/metal/benchmark/end-to-end.cc
new file mode 100644
index 00000000..b0f4367c
--- /dev/null
+++ b/gpt_oss/metal/benchmark/end-to-end.cc
@@ -0,0 +1,221 @@
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+constexpr std::uint32_t kNumGeneratedTokens = 100;
+
+static void end2end_decode(benchmark::State& state, const char* env_var_name) {
+ const char* model_path = getenv(env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set", env_var_name));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status = gptoss_model_create_from_file(model_path, &model_ptr, 0);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_model_release)> model(model_ptr, gptoss_model_release);
+
+ gptoss_context_t context_ptr = nullptr;
+ status = gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr, decltype(&gptoss_context_release)> context(context_ptr, gptoss_context_release);
+
+ const char* prompt = "why did the chicken cross the road?";
+ std::size_t num_prompt_tokens = 0;
+ status = gptoss_context_append_chars(context.get(), prompt, strlen(prompt), &num_prompt_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format("failed to tokenize prompt \"{}\"", prompt));
+ return;
+ }
+
+ // Prefill
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ std::uint64_t rng_seed = 0;
+
+ for (auto _ : state) {
+ const std::uint64_t current_rng_seed = rng_seed++;
+ context->num_kv_tokens = num_prompt_tokens;
+ context->num_tokens = num_prompt_tokens;
+
+ std::array tokens;
+ std::size_t num_generated_tokens = 0;
+ do {
+ std::size_t num_current_generated_tokens = 0;
+ status = gptoss_context_sample(context.get(), /*temperature=*/1.0f, /*rng_state=*/current_rng_seed,
+ /*max_tokens=*/kNumGeneratedTokens - num_generated_tokens, tokens.data(), &num_current_generated_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to sample from the Context object");
+ return;
+ }
+ num_generated_tokens += num_current_generated_tokens;
+ } while (num_generated_tokens < kNumGeneratedTokens);
+ }
+
+ state.counters["generations"] =
+ benchmark::Counter(state.iterations(), benchmark::Counter::kIsRate);
+ state.counters["tokens"] =
+ benchmark::Counter(state.iterations() * kNumGeneratedTokens, benchmark::Counter::kIsRate);
+}
+
+static void end2end_prefill(benchmark::State& state,
+ const char* model_path_env_var_name,
+ const char* prompt_env_var_name,
+ size_t context_length = 0) {
+ const char* model_path = getenv(model_path_env_var_name);
+ if (model_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set",
+ model_path_env_var_name));
+ return;
+ }
+
+ const char* prompt_file_path = getenv(prompt_env_var_name);
+ if (prompt_file_path == NULL) {
+ state.SkipWithError(std::format("environment variable {} is not set",
+ prompt_env_var_name));
+ return;
+ }
+
+ // Read prompt contents from file into a std::string
+ std::ifstream prompt_file(prompt_file_path,
+ std::ios::in | std::ios::binary);
+ if (!prompt_file) {
+ state.SkipWithError(
+ std::format("failed to open prompt file {}", prompt_file_path));
+ return;
+ }
+ std::string prompt_str;
+ prompt_file.seekg(0, std::ios::end);
+ std::streampos file_size = prompt_file.tellg();
+ if (file_size < 0) {
+ state.SkipWithError(std::format("failed to read prompt file size {}",
+ prompt_file_path));
+ return;
+ }
+ prompt_str.resize(static_cast(file_size));
+ prompt_file.seekg(0, std::ios::beg);
+ if (file_size > 0) {
+ prompt_file.read(prompt_str.data(), file_size);
+ }
+ if (!prompt_file) {
+ state.SkipWithError(
+ std::format("failed to read prompt file {}", prompt_file_path));
+ return;
+ }
+
+ gptoss_model_t model_ptr = nullptr;
+ gptoss_status status =
+ gptoss_model_create_from_file(model_path, &model_ptr, 1024);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(
+ std::format("failed to load model from file {}", model_path));
+ return;
+ }
+ std::unique_ptr,
+ decltype(&gptoss_model_release)>
+ model(model_ptr, gptoss_model_release);
+
+ gptoss_tokenizer_t tokenizer_ptr = nullptr;
+ status = gptoss_model_get_tokenizer(model.get(), &tokenizer_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to retrieve Tokenizer");
+ return;
+ }
+ std::unique_ptr,
+ decltype(&gptoss_tokenizer_release)>
+ tokenizer(tokenizer_ptr, gptoss_tokenizer_release);
+
+ gptoss_context_t context_ptr = nullptr;
+ status =
+ gptoss_context_create(model.get(), /*context_lenght=*/0, &context_ptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to create Context object");
+ return;
+ }
+ std::unique_ptr,
+ decltype(&gptoss_context_release)>
+ context(context_ptr, gptoss_context_release);
+
+ const char* prompt = prompt_str.c_str();
+ status = gptoss_context_append_chars(context.get(), prompt,
+ prompt_str.size(), nullptr);
+ if (status != gptoss_status_success) {
+ state.SkipWithError(std::format(
+ "failed to tokenize prompt from file {}", prompt_file_path));
+ return;
+ }
+
+ size_t num_tokens;
+ status = gptoss_context_get_num_tokens(context.get(), &num_tokens);
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to get number of tokens");
+ return;
+ }
+ if (context_length != 0) {
+ assert(context_length <= num_tokens);
+ context->num_tokens = context_length;
+ }
+ // Prefill
+ for (auto _ : state) {
+ status = gptoss_context_process(context.get());
+ if (status != gptoss_status_success) {
+ state.SkipWithError("failed to prefill Context object");
+ return;
+ }
+ context->num_kv_tokens = 0;
+ }
+
+ state.counters["tokens"] = num_tokens;
+ state.counters["tokens/s"] = benchmark::Counter(
+ state.iterations() * num_tokens, benchmark::Counter::kIsRate);
+}
+
+// Decode end-to-end benchmark
+BENCHMARK_CAPTURE(end2end_decode, gpt_oss_20b_decode, "GPT_OSS_20B_PATH")
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(end2end_decode, gpt_oss_120b_decode, "GPT_OSS_120B_PATH")
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+
+// Prefill end-to-end benchmark
+BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_1024,
+ "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 1024)
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_1024, "GPT_OSS_20B_PATH",
+ "GPT_OSS_PROMPT_FILE_PATH", 1024)
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+
+BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_120b_prefill_3072,
+ "GPT_OSS_120B_PATH", "GPT_OSS_PROMPT_FILE_PATH", 3072)
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+BENCHMARK_CAPTURE(end2end_prefill, gpt_oss_20b_prefill_3072, "GPT_OSS_20B_PATH",
+ "GPT_OSS_PROMPT_FILE_PATH", 3072)
+ ->UseRealTime()
+ ->Unit(benchmark::kMillisecond);
+
+BENCHMARK_MAIN();
diff --git a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
index 17515942..ee7551c2 100644
--- a/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
+++ b/gpt_oss/metal/benchmark/f32-bf16w-rmsnorm.cc
@@ -26,6 +26,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
Buffer input_buffer{device, num_tokens * num_channels * sizeof(float)};
Buffer weight_buffer{device, num_channels * sizeof(gptoss_bfloat16)};
Buffer output_buffer{device, num_tokens * num_channels * sizeof(float)};
+ Buffer control_buffer{device, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
{
CommandBuffer command_buffer{command_queue};
@@ -69,6 +71,8 @@ static void f32_bf16w_rnsnorm(benchmark::State& state) {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens,
num_channels,
kEpsilon),
diff --git a/gpt_oss/metal/examples/generate.py b/gpt_oss/metal/examples/generate.py
index b9c0beac..3b781999 100644
--- a/gpt_oss/metal/examples/generate.py
+++ b/gpt_oss/metal/examples/generate.py
@@ -3,8 +3,7 @@
import argparse
import sys
-from datetime import date
-from gpt_oss import Context, Model
+from gpt_oss.metal import Context, Model
parser = argparse.ArgumentParser(description='Chat with gpt-oss', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
diff --git a/gpt_oss/metal/include/gpt-oss/functions.h b/gpt_oss/metal/include/gpt-oss/functions.h
index a81bf50a..5b0d83ea 100644
--- a/gpt_oss/metal/include/gpt-oss/functions.h
+++ b/gpt_oss/metal/include/gpt-oss/functions.h
@@ -15,13 +15,17 @@ extern "C" {
*
* @param path Path to the file containing the model in GPT-OSS format.
* @param model_out Pointer to the Model object that will be created. Must be released with gptoss_release_model.
+ * @param max_batch_tokens Maximum number of tokens that can be processed in a single batch.
+ * Larger values may improve prefill performance, but require more memory.
+ * Specify 0 to use the default value.
*
* On success, returns gptoss_status_success and saves a pointer to the created Model in the model_out argument.
* On failure, returns an error code and stores null pointer in the model_out argument.
*/
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
const char* path,
- gptoss_model_t* model_out);
+ gptoss_model_t* model_out,
+ size_t max_batch_tokens);
/*
* Query the Tokenizer object associated with the Model.
@@ -218,7 +222,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_max_tokens(
*
* On success, returns gptoss_status_success and stores cached token IDs in the tokens_out argument and the number of
* cached tokens in the num_tokens_out argument.
- * On failure, returns an error code and leaves the values specified by tokend_out and num_tokens_out unchanged.
+ * On failure, returns an error code and leaves the values specified by tokens_out and num_tokens_out unchanged.
*/
enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
gptoss_context_t context,
@@ -267,7 +271,7 @@ enum gptoss_status GPTOSS_ABI gptoss_context_reset(
gptoss_context_t context);
/*
- * Pre-process the tokens in the Context and generate probability distrubution over the next token.
+ * Pre-process the tokens in the Context and generate probability distribution over the next token.
*
* @param context Context object created by gptoss_context_create.
*
@@ -290,7 +294,9 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
gptoss_context_t context,
float temperature,
uint64_t seed,
- uint32_t* token_out);
+ size_t max_tokens,
+ uint32_t* tokens_out,
+ size_t* num_tokens_out);
/*
* Increments a Context object's reference count.
diff --git a/gpt_oss/metal/python/context.c b/gpt_oss/metal/python/context.c
index d71cc396..abc031af 100644
--- a/gpt_oss/metal/python/context.c
+++ b/gpt_oss/metal/python/context.c
@@ -120,25 +120,54 @@ static PyObject* PyGPTOSSContext_process(PyGPTOSSContext* self) {
}
static PyObject* PyGPTOSSContext_sample(PyGPTOSSContext* self, PyObject* args, PyObject* kwargs) {
- static char *kwlist[] = {"temperature", "seed", NULL};
+ static char *kwlist[] = {"max_output_tokens", "temperature", "seed", NULL};
+ PyObject* token_list_obj = NULL;
+ uint32_t* token_ptr = NULL;
+ unsigned int max_output_tokens = 0;
unsigned long long seed = 0;
float temperature = 1.0f;
- if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$fK", kwlist,
- &temperature, &seed))
+ if (!PyArg_ParseTupleAndKeywords(args, kwargs, "I|$fK", kwlist,
+ &max_output_tokens, &temperature, &seed))
{
return NULL;
}
- uint32_t token_out = UINT32_MAX;
- enum gptoss_status status = gptoss_context_sample(
- self->handle, temperature, (uint64_t) seed, &token_out);
+ token_ptr = (uint32_t*) PyMem_Malloc(max_output_tokens * sizeof(uint32_t));
+ if (token_ptr == NULL) {
+ goto error;
+ }
+
+ size_t num_tokens = 0;
+ const enum gptoss_status status = gptoss_context_sample(
+ self->handle, temperature, (uint64_t) seed,
+ (size_t) max_output_tokens, token_ptr, &num_tokens);
if (status != gptoss_status_success) {
// TODO: set exception
- return NULL;
+ goto error;
}
- return PyLong_FromUnsignedLong((unsigned long) token_out);
+ token_list_obj = PyList_New((Py_ssize_t) num_tokens);
+ if (token_list_obj == NULL) {
+ goto error;
+ }
+
+ for (size_t t = 0; t < num_tokens; t++) {
+ PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+ if (token_obj == NULL) {
+ goto error;
+ }
+
+ PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
+ }
+
+ PyMem_Free(token_ptr);
+ return token_list_obj;
+
+error:
+ PyMem_Free(token_ptr);
+ Py_XDECREF(token_list_obj);
+ return NULL;
}
static PyObject* PyGPTOSSContext_reset(PyGPTOSSContext* self) {
@@ -155,7 +184,7 @@ static PyMethodDef PyGPTOSSContext_methods[] = {
{"__copy__", (PyCFunction) PyGPTOSSContext_copy, METH_NOARGS, "Create a copy of the Context"},
{"append", (PyCFunction) PyGPTOSSContext_append, METH_O, "Append bytes to the Context"},
{"process", (PyCFunction) PyGPTOSSContext_process, METH_NOARGS, "Process tokens in the Context"},
- {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token prediction from the Context"},
+ {"sample", (PyCFunction) PyGPTOSSContext_sample, METH_VARARGS | METH_KEYWORDS, "Sample token predictions from the Context"},
{"reset", (PyCFunction) PyGPTOSSContext_reset, METH_NOARGS, "Discard the content of the Context"},
{NULL},
};
@@ -184,7 +213,6 @@ static PyObject* PyGPTOSSContext_get_max_tokens(PyGPTOSSContext* self, void* clo
static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure) {
PyObject* token_list_obj = NULL;
- PyObject* token_obj = NULL;
uint32_t* token_ptr = NULL;
size_t num_tokens = 0;
@@ -210,14 +238,12 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure
}
for (size_t t = 0; t < num_tokens; t++) {
- token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
+ PyObject* token_obj = PyLong_FromUnsignedLong((unsigned long) token_ptr[t]);
if (token_obj == NULL) {
goto error;
}
- if (PyList_SetItem(token_list_obj, (Py_ssize_t) t, token_obj) < 0) {
- goto error;
- }
- token_obj = NULL; // PyList_SetItem stole the reference
+
+ PyList_SET_ITEM(token_list_obj, (Py_ssize_t) t, token_obj);
}
PyMem_Free(token_ptr);
@@ -225,7 +251,6 @@ static PyObject* PyGPTOSSContext_get_tokens(PyGPTOSSContext* self, void* closure
error:
PyMem_Free(token_ptr);
- Py_XDECREF(token_obj);
Py_XDECREF(token_list_obj);
return NULL;
}
diff --git a/gpt_oss/metal/python/model.c b/gpt_oss/metal/python/model.c
index 49202a2c..a1713be7 100644
--- a/gpt_oss/metal/python/model.c
+++ b/gpt_oss/metal/python/model.c
@@ -12,7 +12,7 @@ static int PyGPTOSSModel_init(PyGPTOSSModel* self, PyObject* args, PyObject* kwa
if (!PyArg_ParseTuple(args, "s", &filepath)) {
return -1;
}
- status = gptoss_model_create_from_file(filepath, &self->handle);
+ status = gptoss_model_create_from_file(filepath, &self->handle, 0);
if (status != gptoss_status_success) {
// TODO: set exception
return -1;
diff --git a/gpt_oss/metal/source/accumulate.metal b/gpt_oss/metal/source/accumulate.metal
index f7ebc506..70dc4c2b 100644
--- a/gpt_oss/metal/source/accumulate.metal
+++ b/gpt_oss/metal/source/accumulate.metal
@@ -12,11 +12,15 @@ kernel void gptoss_f32_accumulate_e4(
const device float4* input [[ buffer(1) ]],
const device gptoss_expert_prediction* expert [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint2 threadgroup_size [[ threads_per_threadgroup ]])
{
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_vecs_per_threadgroup = args.num_vecs_per_threadgroup;
const uint threadgroup_start = gid.x * num_vecs_per_threadgroup;
diff --git a/gpt_oss/metal/source/context.c b/gpt_oss/metal/source/context.c
index af6a5d65..5cdaee7f 100644
--- a/gpt_oss/metal/source/context.c
+++ b/gpt_oss/metal/source/context.c
@@ -47,6 +47,45 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
atomic_store_explicit(&context->ref_count, 1, memory_order_relaxed);
context->max_tokens = context_length;
+ // Activation buffers
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->residual_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &context->rmsnorm_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &context->qkv_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &context->sdpa_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &context->gate_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &context->expert_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &context->swiglu_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &context->moe_activation_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ // Input/output buffers
+ status = gptoss_metal_buffer_create(&model->device, sizeof(struct gptoss_control), NULL, &context->control_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
status = gptoss_metal_buffer_create(&model->device, context_length * sizeof(uint32_t), NULL, &context->token_buffer);
if (status != gptoss_status_success) {
goto cleanup;
@@ -73,7 +112,11 @@ enum gptoss_status GPTOSS_ABI gptoss_context_create(
}
context->kvcache_size = context->kvcache_buffer.size;
- context->allocation_size = context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
+ context->allocation_size =
+ context->residual_activation_buffer.size + context->rmsnorm_activation_buffer.size +
+ context->qkv_activation_buffer.size + context->sdpa_activation_buffer.size +
+ context->gate_activation_buffer.size + context->expert_activation_buffer.size + context->swiglu_activation_buffer.size + context->moe_activation_buffer.size +
+ context->token_buffer.size + context->kvcache_buffer.size + context->score_buffer.size + context->argmax_buffer.size;
context->model = model;
gptoss_model_retain(model);
@@ -118,338 +161,477 @@ enum gptoss_status GPTOSS_ABI gptoss_context_get_tokens(
return gptoss_status_success;
}
-static enum gptoss_status process_batch(
- gptoss_context_t context)
+// Prefill: input_tokens_offset = number of tokens in KV cache, num_input_tokens > 0, num_output_tokens = 0.
+// Sampling: input_tokens_offset = number of tokens in the context - 1, num_input_tokens = 1, num_output_tokens = 1.
+// Perplexity: input_tokens_offset = 0, num_input_tokens > 1, num_output_tokens = num_input_tokens.
+static enum gptoss_status process_tokens(
+ gptoss_context_t context,
+ struct gptoss_metal_command_buffer* command_buffer,
+ size_t input_tokens_offset,
+ size_t num_input_tokens,
+ size_t num_output_tokens)
{
+ assert(num_input_tokens != 0);
+ assert(num_input_tokens <= context->max_batch_tokens);
+ assert(num_output_tokens <= context->max_batch_tokens);
+ assert(num_input_tokens >= num_output_tokens);
+ const size_t dense_matmul_kernel_token_multiple_constraint = 64;
+
enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
- struct gptoss_metal_command_buffer command_buffer = {0};
const size_t attn_qkv_dim = model->head_dim * (model->num_heads + 2 * model->num_kv_heads);
- status = gptoss_metal_command_buffer_create(&model->command_queue, &command_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
- &command_buffer,
- &model->bf16_f32_embeddings_fn,
- /*threadgroup_size=*/512,
- &context->token_buffer,
- (context->num_tokens - context->num_batch_tokens) * sizeof(uint32_t),
- &model->shared_weight_buffer,
- /*weight_offset=*/0,
- &model->residual_activation_buffer,
- /*output_offset=*/0,
- /*num_tokens=*/context->num_batch_tokens,
- /*num_channels=*/model->embedding_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
- goto cleanup;
- }
- for (uint32_t n = 0; n < model->num_blocks; n++) {
- const bool last_block = n + 1 == model->num_blocks;
- const size_t num_output_tokens = last_block ? 1 : context->num_batch_tokens;
-
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
- &model->f32_bf16w_rmsnorm_fn,
- &model->residual_activation_buffer,
- /*input_offset=*/0,
- &model->shared_weight_buffer,
- /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
- &model->rmsnorm_activation_buffer,
- /*output_offset=*/0,
- /*num_tokens=*/context->num_batch_tokens,
- /*num_channels=*/model->embedding_dim,
- model->rmsnorm_epsilon);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
- }
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
- &command_buffer,
- &model->f32_bf16w_matmul_fn,
- /*threadgroup_size=*/256,
- &model->rmsnorm_activation_buffer,
- /*input_offset=*/0,
- &model->shared_weight_buffer,
- /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
+ const size_t input_tokens_end = input_tokens_offset + num_input_tokens;
+ for (size_t input_batch_start = input_tokens_offset;
+ input_batch_start < input_tokens_end;
+ input_batch_start += model->max_batch_tokens)
+ {
+ const size_t input_batch_size = math_min(model->max_batch_tokens, input_tokens_end - input_batch_start);
+ const size_t input_batch_end = input_batch_start + input_batch_size;
+ const size_t output_batch_size = math_sub_sat(num_output_tokens, input_tokens_end - input_batch_end);
+
+ status = gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
+ command_buffer,
+ &model->bf16_f32_embeddings_fn,
+ model->embeddings_threadgroup_size,
+ &context->token_buffer,
+ input_batch_start * sizeof(uint32_t),
&model->shared_weight_buffer,
- /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
- &model->qkv_activation_buffer,
+ /*weight_offset=*/0,
+ &context->residual_activation_buffer,
/*output_offset=*/0,
- /*num_tokens=*/context->num_batch_tokens,
- /*num_cols=*/model->embedding_dim,
- /*num_rows=*/attn_qkv_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
- goto cleanup;
- }
-
- status = gptoss_metal_command_buffer_encode_launch_f32_rope(
- &command_buffer,
- &model->f32_rope_fn,
- /*threadgroup_size=*/32,
- &model->qkv_activation_buffer,
- model->rope_theta,
- model->interpolation_scale,
- model->yarn_offset,
- model->yarn_scale,
- model->yarn_multiplier,
- context->num_batch_tokens,
- model->num_heads,
- model->num_kv_heads,
- model->head_dim,
- /*token_offset=*/context->num_kv_tokens);
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/input_batch_size,
+ /*num_channels=*/model->embedding_dim);
if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
- goto cleanup;
+ GPTOSS_LOG_ERROR("failed to encode bf16_f32_embeddings kernel launch");
+ return status;
}
- for (uint32_t t = 0; t < context->num_batch_tokens; t++) {
- status = gptoss_metal_command_buffer_encode_copy_buffer(
- &command_buffer,
- &model->qkv_activation_buffer,
- /*input_offset=*/(t * attn_qkv_dim + model->num_heads * model->head_dim) * sizeof(float),
- &context->kvcache_buffer,
- /*output_offset=*/(n * context->max_tokens + context->num_kv_tokens + t) * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
- /*size=*/2 * model->num_kv_heads * model->head_dim * sizeof(float));
+ for (uint32_t n = 0; n < model->num_blocks; n++) {
+ const bool last_block = n + 1 == model->num_blocks;
+ const size_t num_block_output_tokens = last_block ? output_batch_size : input_batch_size;
+
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
+ command_buffer,
+ &model->f32_bf16w_rmsnorm_fn,
+ &context->residual_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->attn_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
+ &context->rmsnorm_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/input_batch_size,
+ /*num_channels=*/model->embedding_dim,
+ model->rmsnorm_epsilon);
if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
- goto cleanup;
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
+ return status;
}
- }
- status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
- &command_buffer,
- &model->f32_sdpa_q8_d64_fn,
- &model->qkv_activation_buffer,
- /*q_offset=*/attn_qkv_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
- &context->kvcache_buffer,
- /*k_offset=*/n * context->max_tokens * 2 * model->num_kv_heads * model->head_dim * sizeof(float),
- &context->kvcache_buffer,
- /*v_offset=*/(n * context->max_tokens * 2 + 1) * model->num_kv_heads * model->head_dim * sizeof(float),
- &model->shared_weight_buffer,
- /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
- &model->sdpa_activation_buffer, /*output_offset=*/0,
- /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
- num_output_tokens, context->num_kv_tokens + (context->num_batch_tokens - num_output_tokens),
- model->num_heads, model->num_kv_heads, model->head_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
- goto cleanup;
- }
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
- &command_buffer,
- &model->f32_bf16w_matmul_fn,
- /*threadgroup_size=*/256,
- &model->sdpa_activation_buffer,
- /*input_offset=*/0,
- &model->shared_weight_buffer,
- /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
- &model->shared_weight_buffer,
- /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
- &model->residual_activation_buffer,
- /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
- /*num_tokens=*/num_output_tokens,
- /*num_cols=*/model->num_heads * model->head_dim,
- /*num_rows=*/model->embedding_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
- goto cleanup;
- }
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
+ command_buffer,
+ &model->f32_bf16w_dense_matmul_qkv_fn,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
+ &context->qkv_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/input_batch_size,
+ /*num_cols=*/model->embedding_dim,
+ /*num_rows=*/attn_qkv_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_qkv kernel launch");
+ return status;
+ }
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
- &model->f32_bf16w_rmsnorm_fn,
- &model->residual_activation_buffer,
- /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
- &model->shared_weight_buffer,
- /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
- &model->rmsnorm_activation_buffer,
- /*output_offset=*/0,
- num_output_tokens,
- model->embedding_dim,
- model->rmsnorm_epsilon);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
- }
+ status = gptoss_metal_command_buffer_encode_launch_f32_rope(
+ command_buffer,
+ &model->f32_rope_fn,
+ /*threadgroup_size=*/32,
+ &context->qkv_activation_buffer,
+ /*input_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->rope_theta,
+ model->interpolation_scale,
+ model->yarn_offset,
+ model->yarn_scale,
+ model->yarn_multiplier,
+ input_batch_size,
+ model->num_heads,
+ model->num_kv_heads,
+ model->head_dim,
+ /*token_offset=*/input_batch_start);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_rope kernel launch");
+ return status;
+ }
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
- &command_buffer,
- &model->f32_bf16w_matmul_fn,
- /*threadgroup_size=*/256,
- &model->rmsnorm_activation_buffer,
- /*input_offset=*/0,
- &model->shared_weight_buffer,
- /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
- &model->shared_weight_buffer,
- /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
- &model->gate_activation_buffer,
- /*output_offset=*/0,
- /*num_tokens=*/num_output_tokens,
- /*num_cols=*/model->embedding_dim,
- /*num_rows=*/model->num_experts);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
- goto cleanup;
- }
+ for (uint32_t t = 0; t < input_batch_size; t++) {
+ for (uint32_t kv = 0; kv < 2; kv++) {
+ for (uint32_t h = 0; h < model->num_kv_heads; h++) {
+ status = gptoss_metal_command_buffer_encode_copy_buffer(
+ command_buffer,
+ &context->qkv_activation_buffer,
+ /*input_offset=*/(t * attn_qkv_dim + (model->num_heads + kv * model->num_kv_heads + h) * model->head_dim) * sizeof(float),
+ &context->kvcache_buffer,
+ /*output_offset=*/(((n * model->num_kv_heads + h) * context->max_tokens + input_batch_start + t) * 2 + kv) * model->head_dim * sizeof(float),
+ /*size=*/model->head_dim * sizeof(float));
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode copy of token %" PRIu32 " to KV cache", t);
+ return status;
+ }
+ }
+ }
+ }
+ } else {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+ command_buffer,
+ &model->f32_bf16w_matmul_qkv_fn,
+ model->attn_qkv_threadgroup_size,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->attn_qkv_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->attn_qkv_bias_offset + model->per_block_shared_weights_size * n,
+ &context->qkv_activation_buffer,
+ /*output_offset=*/0,
+ &context->kvcache_buffer,
+ /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/input_batch_size,
+ /*num_cols=*/model->embedding_dim,
+ /*num_q_heads=*/model->num_heads,
+ /*num_kv_heads=*/model->num_kv_heads,
+ /*attn_head_dim=*/model->head_dim,
+ /*token_offset=*/input_batch_start,
+ /*max_tokens=*/context->max_tokens,
+ /*rope_base=*/model->rope_theta,
+ /*interpolation_scale=*/model->interpolation_scale,
+ /*yarn_offset=*/model->yarn_offset,
+ /*yarn_scale=*/model->yarn_scale,
+ /*yarn_multiplier=*/model->yarn_multiplier);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch");
+ return status;
+ }
+ }
- const char* kernel_name = NULL;
- switch (model->num_experts) {
- case 32:
- kernel_name = "f32_topk_softmax_e32_k4_fn";
- status = gptoss_metal_command_buffer_encode_launch_f32_topk(
- &command_buffer,
- &model->f32_topk_softmax_e32_k4_fn,
- &model->gate_activation_buffer, /*input_offset=*/0,
- &model->expert_activation_buffer, /*output_offset=*/0,
- num_output_tokens,
- model->num_experts,
- model->num_active_experts);
- break;
- case 128:
- kernel_name = "f32_topk_softmax_e128_k4_fn";
- status = gptoss_metal_command_buffer_encode_launch_f32_topk(
- &command_buffer,
- &model->f32_topk_softmax_e128_k4_fn,
- &model->gate_activation_buffer, /*input_offset=*/0,
- &model->expert_activation_buffer, /*output_offset=*/0,
- num_output_tokens,
- model->num_experts,
- model->num_active_experts);
- break;
- default:
- status = gptoss_status_unsupported_argument;
- GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
- goto cleanup;
- }
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
- goto cleanup;
- }
+ if (num_block_output_tokens != 0) {
+ status = gptoss_metal_command_buffer_encode_launch_f32_sdpa(
+ command_buffer,
+ &model->f32_sdpa_q8_d64_fn,
+ &context->qkv_activation_buffer,
+ /*q_offset=*/attn_qkv_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->kvcache_buffer,
+ /*kv_offset=*/n * model->num_kv_heads * context->max_tokens * 2 * model->head_dim * sizeof(float),
+ &model->shared_weight_buffer,
+ /*s_offset=*/model->attn_sdpa_sink_offset + model->per_block_shared_weights_size * n,
+ &context->sdpa_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*window=*/n % 2 == 0 ? model->attention_window : UINT32_MAX,
+ /*kv_stride=*/2 * context->max_tokens * model->head_dim,
+ num_block_output_tokens,
+ input_batch_start + input_batch_size - num_block_output_tokens,
+ model->num_heads, model->num_kv_heads, model->head_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_sdpa kernel launch");
+ return status;
+ }
- status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
- &command_buffer,
- &model->f32_mf4w_moe_matmul_swiglu_fn,
- /*threadgroup_size=*/512,
- &model->rmsnorm_activation_buffer, /*input_offset=*/0,
- &model->expert_activation_buffer, /*expert_offset=*/0,
- &model->block_weight_buffers[n], /*weight_block_offset=*/0,
- &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
- &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_swiglu_bias_offset,
- &model->swiglu_activation_buffer, /*output_offset=*/0,
- model->swiglu_limit,
- model->per_expert_block_weight_size,
- num_output_tokens,
- model->num_active_experts,
- model->embedding_dim,
- model->mlp_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
- goto cleanup;
- }
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
+ command_buffer,
+ &model->f32_bf16w_dense_matmul_attn_output_fn,
+ &context->sdpa_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
+ &context->residual_activation_buffer,
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/num_block_output_tokens,
+ /*num_cols=*/model->num_heads * model->head_dim,
+ /*num_rows=*/model->embedding_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_attn_output kernel launch");
+ return status;
+ }
+ } else {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
+ command_buffer,
+ &model->f32_bf16w_matmul_fn,
+ model->attn_out_threadgroup_size,
+ &context->sdpa_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->attn_out_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->attn_out_bias_offset + model->per_block_shared_weights_size * n,
+ &context->residual_activation_buffer,
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/num_block_output_tokens,
+ /*num_cols=*/model->num_heads * model->head_dim,
+ /*num_rows=*/model->embedding_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_add kernel launch");
+ return status;
+ }
+ }
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
+ command_buffer,
+ &model->f32_bf16w_rmsnorm_fn,
+ &context->residual_activation_buffer,
+ /*input_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->mlp_rmsnorm_gain_offset + model->per_block_shared_weights_size * n,
+ &context->rmsnorm_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ num_block_output_tokens,
+ model->embedding_dim,
+ model->rmsnorm_epsilon);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
+ return status;
+ }
+ if (input_batch_size % dense_matmul_kernel_token_multiple_constraint == 0) {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
+ command_buffer,
+ &model->f32_bf16w_dense_matmul_mlp_gate_fn,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
+ &context->gate_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ num_block_output_tokens,
+ model->embedding_dim,
+ model->num_experts);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul_mlp_gate kernel launch");
+ return status;
+ }
+ } else {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
+ command_buffer,
+ &model->f32_bf16w_matmul_fn,
+ model->mlp_gate_threadgroup_size,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->mlp_gate_weight_offset + model->per_block_shared_weights_size * n,
+ &model->shared_weight_buffer,
+ /*bias_offset=*/model->mlp_gate_bias_offset + model->per_block_shared_weights_size * n,
+ &context->gate_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/num_block_output_tokens,
+ /*num_cols=*/model->embedding_dim,
+ /*num_rows=*/model->num_experts);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul kernel launch");
+ return status;
+ }
+ }
- status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
- &command_buffer,
- &model->f32_mf4w_moe_matmul_fn,
- /*threadgroup_size=*/512,
- &model->swiglu_activation_buffer, /*input_offset=*/0,
- &model->expert_activation_buffer, /*expert_offset=*/0,
- &model->block_weight_buffers[n], /*weight_block_offset=*/model->mlp_out_block_offset,
- &model->block_weight_buffers[n], /*weight_scale_offset=*/model->mlp_out_scale_offset,
- &model->block_weight_buffers[n], /*bias_offset=*/model->mlp_out_bias_offset,
- &model->moe_activation_buffer, /*output_offset=*/0,
- model->per_expert_block_weight_size,
- num_output_tokens,
- model->num_active_experts,
- model->mlp_dim,
- model->embedding_dim);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
- goto cleanup;
- }
+ const char* kernel_name = NULL;
+ switch (model->num_experts) {
+ case 32:
+ kernel_name = "f32_topk_softmax_e32_k4_fn";
+ status = gptoss_metal_command_buffer_encode_launch_f32_topk(
+ command_buffer,
+ &model->f32_topk_softmax_e32_k4_fn,
+ &context->gate_activation_buffer, /*input_offset=*/0,
+ &context->expert_activation_buffer, /*output_offset=*/0,
+ &context->control_buffer, /*control_offset=*/0,
+ num_block_output_tokens,
+ model->num_experts,
+ model->num_active_experts);
+ break;
+ case 128:
+ kernel_name = "f32_topk_softmax_e128_k4_fn";
+ status = gptoss_metal_command_buffer_encode_launch_f32_topk(
+ command_buffer,
+ &model->f32_topk_softmax_e128_k4_fn,
+ &context->gate_activation_buffer, /*input_offset=*/0,
+ &context->expert_activation_buffer, /*output_offset=*/0,
+ &context->control_buffer, /*control_offset=*/0,
+ num_block_output_tokens,
+ model->num_experts,
+ model->num_active_experts);
+ break;
+ default:
+ status = gptoss_status_unsupported_argument;
+ GPTOSS_LOG_ERROR("missing Top-K kernel for %" PRIu32 " experts", model->num_experts);
+ return status;
+ }
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode %s kernel launch", kernel_name);
+ return status;
+ }
- status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
- &command_buffer,
- &model->f32_accumulate_e4_fn,
- /*threadgroup_size=*/256,
- model->max_threadgroups,
- &model->moe_activation_buffer,
- /*input_offset=*/0,
- &model->expert_activation_buffer,
- /*expert_offset=*/0,
- &model->residual_activation_buffer,
- /*output_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
- model->embedding_dim,
- num_output_tokens,
- model->num_active_experts);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
- goto cleanup;
- }
- }
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
+ command_buffer,
+ &model->f32_mf4w_moe_matmul_swiglu_fn,
+ model->mlp_swiglu_threadgroup_size,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &context->expert_activation_buffer,
+ /*expert_offset=*/0,
+ &model->block_weight_buffers[n],
+ /*weight_block_offset=*/0,
+ &model->block_weight_buffers[n],
+ /*weight_scale_offset=*/model->mlp_swiglu_scale_offset,
+ &model->block_weight_buffers[n],
+ /*bias_offset=*/model->mlp_swiglu_bias_offset,
+ &context->swiglu_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->swiglu_limit,
+ model->per_expert_block_weight_size,
+ num_block_output_tokens,
+ model->num_active_experts,
+ model->embedding_dim,
+ model->mlp_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul_swiglu kernel launch");
+ return status;
+ }
- const size_t num_output_tokens = 1;
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
- &command_buffer,
- &model->f32_bf16w_rmsnorm_fn,
- &model->residual_activation_buffer,
- /*input_offset=*/model->embedding_dim * (context->num_batch_tokens - num_output_tokens) * sizeof(float),
- &model->shared_weight_buffer,
- /*weight_offset=*/model->rmsnorm_weight_offset,
- &model->rmsnorm_activation_buffer,
- /*output_offset=*/0,
- /*num_tokens=*/num_output_tokens,
- /*num_channels=*/model->embedding_dim,
- model->rmsnorm_epsilon);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
- goto cleanup;
- }
+ status = gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
+ command_buffer,
+ &model->f32_mf4w_moe_matmul_fn,
+ model->mlp_out_threadgroup_size,
+ &context->swiglu_activation_buffer,
+ /*input_offset=*/0,
+ &context->expert_activation_buffer,
+ /*expert_offset=*/0,
+ &model->block_weight_buffers[n],
+ /*weight_block_offset=*/model->mlp_out_block_offset,
+ &model->block_weight_buffers[n],
+ /*weight_scale_offset=*/model->mlp_out_scale_offset,
+ &model->block_weight_buffers[n],
+ /*bias_offset=*/model->mlp_out_bias_offset,
+ &context->moe_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->per_expert_block_weight_size,
+ num_block_output_tokens,
+ model->num_active_experts,
+ model->mlp_dim,
+ model->embedding_dim);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_mf4w_moe_matmul kernel launch");
+ return status;
+ }
- status = gptoss_metal_command_buffer_encode_fill_buffer(
- &command_buffer,
- &context->argmax_buffer,
- /*offset=*/0,
- /*size=*/sizeof(uint64_t) * num_output_tokens,
- /*fill_value=*/0xFF);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode fill buffer command");
- goto cleanup;
- }
- status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
- &command_buffer,
- &model->f32_bf16w_unembedding_fn,
- /*threadgroup_size=*/256,
- model->max_threadgroups,
- &model->rmsnorm_activation_buffer,
- /*input_offset=*/0,
- &model->shared_weight_buffer,
- /*weight_offset=*/model->unembedding_weight_offset,
- &context->score_buffer,
- /*output_offset=*/0,
- &context->argmax_buffer,
- /*argmax_offset=*/0,
- /*num_tokens=*/num_output_tokens,
- /*num_cols=*/model->embedding_dim,
- /*num_rows=*/model->vocabulary_size);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
- goto cleanup;
- }
+ status = gptoss_metal_command_buffer_encode_launch_f32_accumulate(
+ command_buffer,
+ &model->f32_accumulate_e4_fn,
+ model->mlp_acc_threadgroup_size,
+ model->max_threadgroups,
+ &context->moe_activation_buffer,
+ /*input_offset=*/0,
+ &context->expert_activation_buffer,
+ /*expert_offset=*/0,
+ &context->residual_activation_buffer,
+ /*output_offset=*/model->embedding_dim * (input_batch_size - num_block_output_tokens) * sizeof(float),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->embedding_dim,
+ num_block_output_tokens,
+ model->num_active_experts);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_accumulate kernel launch");
+ return status;
+ }
+ }
+ }
- gptoss_metal_command_buffer_commit(&command_buffer);
- gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
+ if (output_batch_size != 0) {
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
+ command_buffer,
+ &model->f32_bf16w_rmsnorm_fn,
+ &context->residual_activation_buffer,
+ /*input_offset=*/model->embedding_dim * (input_batch_size - output_batch_size) * sizeof(float),
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->rmsnorm_weight_offset,
+ &context->rmsnorm_activation_buffer,
+ /*output_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/output_batch_size,
+ /*num_channels=*/model->embedding_dim,
+ model->rmsnorm_epsilon);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_rmsnorm kernel launch");
+ return status;
+ }
- context->num_kv_tokens = context->num_tokens;
- context->num_processed_tokens = num_output_tokens;
- context->num_batch_tokens = 0;
+ status = gptoss_metal_command_buffer_encode_fill_buffer(
+ command_buffer,
+ &context->argmax_buffer,
+ /*offset=*/0,
+ /*size=*/sizeof(uint64_t) * output_batch_size,
+ /*fill_value=*/0xFF);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode fill buffer command");
+ return status;
+ }
-cleanup:
- gptoss_metal_command_buffer_release(&command_buffer);
- return status;
+ status = gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
+ command_buffer,
+ &model->f32_bf16w_unembedding_fn,
+ model->unembedding_threadgroup_size,
+ model->max_threadgroups,
+ &context->rmsnorm_activation_buffer,
+ /*input_offset=*/0,
+ &model->shared_weight_buffer,
+ /*weight_offset=*/model->unembedding_weight_offset,
+ &context->score_buffer,
+ /*output_offset=*/0,
+ &context->argmax_buffer,
+ /*argmax_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*num_tokens=*/output_batch_size,
+ /*num_cols=*/model->embedding_dim,
+ /*num_rows=*/model->vocabulary_size);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_unembedding kernel launch");
+ return status;
+ }
+ }
+ }
+ return gptoss_status_success;
}
enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
@@ -491,17 +673,18 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_chars(
}
uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
- input_tokens[context->num_tokens] = best_token;
- context->num_tokens++;
- num_appended_tokens++;
- if (++context->num_batch_tokens == model->max_batch_tokens) {
- status = process_batch(context);
- if (status != gptoss_status_success) {
- break;
+ if (context->num_kv_tokens > context->num_tokens) {
+ if (input_tokens[context->num_tokens] != best_token) {
+ input_tokens[context->num_tokens] = best_token;
+
+ // Invalidate the KV cache starting with the newly added token.
+ context->num_kv_tokens = context->num_tokens;
}
- assert(context->num_batch_tokens == 0);
+ context->num_tokens++;
+ } else {
+ input_tokens[context->num_tokens++] = best_token;
}
- assert(context->num_batch_tokens < model->max_batch_tokens);
+ num_appended_tokens++;
text += best_token_length;
text_length -= best_token_length;
}
@@ -531,27 +714,32 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
enum gptoss_status status = gptoss_status_success;
uint32_t* input_tokens = (uint32_t*) context->token_buffer.ptr;
while (num_tokens != 0) {
- assert(context->num_batch_tokens < model->max_batch_tokens);
if (context->num_tokens == context->max_tokens) {
status = gptoss_status_context_overflow;
break;
}
- const size_t num_tokens_to_copy =
- math_min(context->max_tokens - context->num_tokens,
- math_min(num_tokens, model->max_batch_tokens - context->num_batch_tokens));
- memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
- context->num_tokens += num_tokens_to_copy;
- context->num_batch_tokens += num_tokens_to_copy;
- if (context->num_batch_tokens == model->max_batch_tokens) {
- status = process_batch(context);
- if (status != gptoss_status_success) {
- break;
+ if (context->num_kv_tokens > context->num_tokens) {
+ const size_t num_tokens_to_verify = math_min(context->num_kv_tokens - context->num_tokens, num_tokens);
+ size_t num_verified_tokens = 0;
+ for (; num_verified_tokens < num_tokens_to_verify; num_verified_tokens++) {
+ if (input_tokens[context->num_tokens + num_verified_tokens] != tokens[num_verified_tokens]) {
+ // Invalidate the KV cache starting with the newly added tokens.
+ context->num_kv_tokens = context->num_tokens + num_verified_tokens;
+ break;
+ }
}
- assert(context->num_batch_tokens == 0);
+
+ context->num_tokens += num_verified_tokens;
+ tokens += num_verified_tokens;
+ num_tokens -= num_verified_tokens;
+ } else {
+ const size_t num_tokens_to_copy = math_min(context->max_tokens - context->num_tokens, num_tokens);
+ memcpy(input_tokens + context->num_tokens, tokens, num_tokens_to_copy * sizeof(uint32_t));
+ context->num_tokens += num_tokens_to_copy;
+ tokens += num_tokens_to_copy;
+ num_tokens -= num_tokens_to_copy;
}
- tokens += num_tokens_to_copy;
- num_tokens -= num_tokens_to_copy;
}
return status;
@@ -560,10 +748,44 @@ enum gptoss_status GPTOSS_ABI gptoss_context_append_tokens(
enum gptoss_status GPTOSS_ABI gptoss_context_process(
gptoss_context_t context)
{
- if (context->num_batch_tokens != 0) {
- process_batch(context);
- }
+ if (context->num_tokens > context->num_kv_tokens) {
+ struct gptoss_metal_command_buffer command_buffer = {0};
+
+ enum gptoss_status status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
+ control->abort = 0;
+
+ status = process_tokens(
+ context,
+ &command_buffer,
+ /*input_tokens_offset=*/context->num_kv_tokens,
+ /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
+ /*num_output_tokens=*/0);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_command_buffer_commit(&command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ status = gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+
+ context->num_kv_tokens = context->num_tokens;
+
+cleanup:
+ gptoss_metal_command_buffer_release(&command_buffer);
+ return status;
+ }
+
return gptoss_status_success;
}
@@ -571,120 +793,135 @@ enum gptoss_status GPTOSS_ABI gptoss_context_sample(
gptoss_context_t context,
float temperature,
uint64_t seed,
- uint32_t* token_out)
+ size_t max_tokens,
+ uint32_t* tokens_out,
+ size_t* num_tokens_out)
{
enum gptoss_status status = gptoss_status_success;
const struct gptoss_model* model = context->model;
struct gptoss_metal_command_buffer command_buffer = {0};
- *token_out = UINT32_MAX;
- if (context->num_batch_tokens != 0) {
- status = process_batch(context);
- if (status != gptoss_status_success) {
- return status;
- }
- }
-
- if (temperature == 0.0f) {
- const uint64_t argmax_bits = ((const uint64_t*) context->argmax_buffer.ptr)[0];
- *token_out = (uint32_t) argmax_bits;
- } else {
- assert(context->num_processed_tokens != 0);
- status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
+ *num_tokens_out = 0;
- uint32_t num_threadgroups = 0;
- uint32_t num_dims_per_threadgroup = 0;
- status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
- &command_buffer,
- &model->f32_softmax_fn,
- /*threadgroup_size=*/256,
- model->max_threadgroups,
- &context->score_buffer,
- /*score_offset=*/0,
- &context->argmax_buffer,
- /*argmax_offset=*/0,
- &context->prob_buffer,
- /*prob_offset=*/0,
- &context->sum_buffer,
- /*sum_offset=*/0,
- model->vocabulary_size,
- /*num_tokens=*/1,
- temperature,
- &num_threadgroups,
- &num_dims_per_threadgroup);
- if (status != gptoss_status_success) {
- GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
- }
+ const uint32_t num_original_tokens = context->num_tokens;
- gptoss_metal_command_buffer_commit(&command_buffer);
- gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
+ status = gptoss_metal_command_buffer_create(&context->model->command_queue, &command_buffer);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
- const uint32_t sample_word = rng_squares32(context->num_tokens, seed + UINT64_C(0x123456789ABCDEF));
- float sample_cdf = (float) ((int32_t) sample_word & INT32_C(0x00FFFFFF)) * 0x1.0p-24f;
+ struct gptoss_control* control = (struct gptoss_control*) context->control_buffer.ptr;
+ control->abort = 0;
- const float* sum_ptr = (const float*) context->sum_buffer.ptr;
- float sum = 0.0f;
- for (uint32_t i = 0; i < num_threadgroups; i++) {
- sum += sum_ptr[i];
+ for (size_t t = 0; t < max_tokens; t++) {
+ if (context->num_kv_tokens < context->num_tokens) {
+ status = process_tokens(
+ context,
+ &command_buffer,
+ /*input_tokens_offset=*/context->num_kv_tokens,
+ /*num_input_tokens=*/context->num_tokens - context->num_kv_tokens,
+ /*num_output_tokens=*/1);
+ context->num_kv_tokens = context->num_tokens;
+ } else {
+ status = process_tokens(
+ context,
+ &command_buffer,
+ /*input_tokens_offset=*/context->num_tokens - 1,
+ /*num_input_tokens=*/1,
+ /*num_output_tokens=*/1);
}
- sample_cdf *= sum;
-
- uint32_t block_idx = 0, token_idx = 0;
- if (sample_cdf == 0.0f) {
- // Make sure we choose the first token with non-zero probability rather than just the first token
- sample_cdf = FLT_TRUE_MIN;
+ if (status != gptoss_status_success) {
+ goto cleanup;
}
- // Step 1: find block
- float cumsum = 0.0f;
- for (; block_idx < num_threadgroups; block_idx++) {
- const float new_cumsum = cumsum + sum_ptr[block_idx];
- if (new_cumsum >= sample_cdf) {
- break;
+ if (temperature != 0.0f) {
+ assert(context->num_processed_tokens != 0);
+ uint32_t num_threadgroups = 0;
+ uint32_t num_dims_per_threadgroup = 0;
+ status = gptoss_metal_command_buffer_encode_launch_f32_softmax(
+ &command_buffer,
+ &model->f32_softmax_fn,
+ /*threadgroup_size=*/512,
+ model->max_threadgroups,
+ &context->score_buffer,
+ /*score_offset=*/0,
+ &context->argmax_buffer,
+ /*argmax_offset=*/0,
+ &context->prob_buffer,
+ /*prob_offset=*/0,
+ &context->sum_buffer,
+ /*sum_offset=*/0,
+ &context->control_buffer,
+ /*control_offset=*/0,
+ model->vocabulary_size,
+ /*num_tokens=*/1,
+ temperature,
+ &num_threadgroups,
+ &num_dims_per_threadgroup);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_softmax kernel launch");
+ goto cleanup;
}
- cumsum = new_cumsum;
- }
- if (block_idx == num_threadgroups) {
- block_idx -= 1;
- }
- // Step 2: find token
- const float* prob_ptr = (const float*) context->prob_buffer.ptr + block_idx * num_dims_per_threadgroup;
- assert(model->vocabulary_size > num_dims_per_threadgroup * block_idx);
- uint32_t num_dims_per_block = math_min(num_dims_per_threadgroup, model->vocabulary_size - num_dims_per_threadgroup * block_idx);
- for (; token_idx < num_dims_per_block; token_idx++) {
- const float new_cumsum = cumsum + prob_ptr[token_idx];
- if (new_cumsum >= sample_cdf) {
- break;
+ status = gptoss_metal_command_buffer_encode_launch_f32_sample(
+ &command_buffer,
+ &model->f32_sample_fn,
+ /*min_threadgroup_size=*/512,
+ &context->prob_buffer,
+ /*prob_offset=*/0,
+ &context->sum_buffer,
+ /*sum_offset=*/0,
+ &context->token_buffer,
+ /*token_offset=*/context->num_tokens * sizeof(uint32_t),
+ &context->control_buffer,
+ /*control_offset=*/0,
+ /*rng_seed=*/seed + UINT64_C(0x123456789ABCDEF),
+ /*rng_offset=*/context->num_tokens,
+ /*num_blocks=*/num_threadgroups,
+ /*num_channels=*/model->vocabulary_size,
+ /*num_channels_per_block=*/num_dims_per_threadgroup);
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode f32_sample kernel launch");
+ goto cleanup;
+ }
+ } else {
+ status = gptoss_metal_command_buffer_encode_copy_buffer(
+ &command_buffer,
+ &context->argmax_buffer,
+ /*input_offset=*/0,
+ &context->token_buffer,
+ /*output_offset=*/context->num_tokens * sizeof(uint32_t),
+ /*size=*/sizeof(uint32_t));
+ if (status != gptoss_status_success) {
+ GPTOSS_LOG_ERROR("failed to encode copy buffer");
+ goto cleanup;
}
- cumsum = new_cumsum;
- }
- if (token_idx == num_dims_per_block) {
- token_idx -= 1;
}
+ context->num_tokens += 1;
+ context->num_kv_tokens = context->num_tokens;
+ }
- token_idx += block_idx * num_dims_per_threadgroup;
+ gptoss_metal_command_buffer_commit(&command_buffer);
+ gptoss_metal_command_buffer_wait_completion(&command_buffer, NULL);
- *token_out = token_idx;
+ const uint32_t* token_ptr = (const uint32_t*) context->token_buffer.ptr;
+ const uint32_t num_generated_tokens = context->num_tokens - num_original_tokens;
+ memcpy(tokens_out, token_ptr + num_original_tokens, num_generated_tokens * sizeof(uint32_t));
+ *num_tokens_out = num_generated_tokens;
cleanup:
- gptoss_metal_command_buffer_release(&command_buffer);
- return status;
- }
-
- return gptoss_status_success;
+ gptoss_metal_command_buffer_release(&command_buffer);
+ return status;
}
enum gptoss_status GPTOSS_ABI gptoss_context_reset(
gptoss_context_t context)
{
context->num_tokens = 0;
- context->num_kv_tokens = 0;
- context->num_batch_tokens = 0;
- context->num_processed_tokens = 0;
+
+ // Note: context->num_kv_tokens is not reset and context->input_tokens_buffer is not cleared.
+ // If the subsequently added tokens match the tokens already in the KV cache, we reuse the KV cache.
+
return gptoss_status_success;
}
@@ -700,6 +937,18 @@ enum gptoss_status GPTOSS_ABI gptoss_context_release(
{
if (context != NULL) {
if (atomic_fetch_sub_explicit(&context->ref_count, 1, memory_order_acq_rel) == 1) {
+ // Activation buffers
+ gptoss_metal_buffer_release(&context->residual_activation_buffer);
+ gptoss_metal_buffer_release(&context->rmsnorm_activation_buffer);
+ gptoss_metal_buffer_release(&context->qkv_activation_buffer);
+ gptoss_metal_buffer_release(&context->sdpa_activation_buffer);
+ gptoss_metal_buffer_release(&context->gate_activation_buffer);
+ gptoss_metal_buffer_release(&context->expert_activation_buffer);
+ gptoss_metal_buffer_release(&context->swiglu_activation_buffer);
+ gptoss_metal_buffer_release(&context->moe_activation_buffer);
+
+ // Input/output buffers
+ gptoss_metal_buffer_release(&context->control_buffer);
gptoss_metal_buffer_release(&context->token_buffer);
gptoss_metal_buffer_release(&context->score_buffer);
gptoss_metal_buffer_release(&context->prob_buffer);
diff --git a/gpt_oss/metal/source/embeddings.metal b/gpt_oss/metal/source/embeddings.metal
index b4541d21..9cc7d121 100644
--- a/gpt_oss/metal/source/embeddings.metal
+++ b/gpt_oss/metal/source/embeddings.metal
@@ -9,10 +9,15 @@ kernel void gptoss_bf16_f32_embeddings(
const device uint* tokens [[ buffer(1) ]],
const device bfloat4* weights [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
+ if (control->abort != 0) {
+ return;
+ }
+
const uint t = tokens[gid];
weights += t * args.num_vecs;
diff --git a/gpt_oss/metal/source/generate.c b/gpt_oss/metal/source/generate.c
index 976046f6..63a8569c 100644
--- a/gpt_oss/metal/source/generate.c
+++ b/gpt_oss/metal/source/generate.c
@@ -162,7 +162,7 @@ struct options parse_options(int argc, char** argv) {
static void print_profile() {
const size_t num_prefill_tokens = atomic_load(&globals.num_prefill_tokens);
const uint64_t prefill_microseconds = atomic_load(&globals.prefill_microseconds);
- const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens) - 1;
+ const size_t num_generated_tokens = atomic_load(&globals.num_generated_tokens);
const uint64_t generation_microseconds = atomic_load(&globals.generation_microseconds);
const uint64_t inference_bytes = atomic_load(&globals.inference_bytes);
if (num_prefill_tokens != 0 || num_generated_tokens != 0) {
@@ -173,10 +173,10 @@ static void print_profile() {
num_prefill_tokens,
(double) num_prefill_tokens / (double) prefill_microseconds * 1.0e+6);
}
- if (num_generated_tokens > 5) {
- printf("Generation speed (%zu tokens, excluding the first 5): %.1f tokens/second\n",
- (num_generated_tokens - 5),
- (double) (num_generated_tokens - 5) / (double) generation_microseconds * 1.0e+6);
+ if (num_generated_tokens != 0) {
+ printf("Generation speed (%zu tokens): %.1f tokens/second\n",
+ num_generated_tokens,
+ (double) num_generated_tokens / (double) generation_microseconds * 1.0e+6);
}
}
@@ -200,7 +200,7 @@ int main(int argc, char *argv[]) {
struct options options = parse_options(argc, argv);
const uint64_t load_start_time = mach_continuous_time();
- status = gptoss_model_create_from_file(options.model, &model);
+ status = gptoss_model_create_from_file(options.model, &model, 0);
if (status != gptoss_status_success) {
fprintf(stderr, "Error: failed to load model from file %s\n", options.model);
goto error;
@@ -268,8 +268,9 @@ int main(int argc, char *argv[]) {
while (options.max_tokens == 0 || atomic_load(&globals.num_generated_tokens) < options.max_tokens) {
uint32_t predicted_token = UINT32_MAX;
+ size_t num_predicted_tokens = 0;
const uint64_t inference_start_timestamp = mach_continuous_time();
- status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, &predicted_token);
+ status = gptoss_context_sample(context, options.temperature, /*rng_state=*/0, /*num_tokens=*/1, &predicted_token, &num_predicted_tokens);
if (status != gptoss_status_success) {
fprintf(stderr, "Error: failed to sample from the Context object\n");
goto error;
@@ -292,7 +293,7 @@ int main(int argc, char *argv[]) {
const size_t previous_num_generated_tokens = atomic_fetch_add(&globals.num_generated_tokens, 1);
if (previous_num_generated_tokens == 0) {
atomic_fetch_add(&globals.prefill_microseconds, mach_timestamp_diff_to_microseconds(prefill_start_time, prefill_end_time));
- } else if (previous_num_generated_tokens > 5) {
+ } else {
atomic_fetch_add(&globals.generation_microseconds, mach_timestamp_diff_to_microseconds(inference_start_timestamp, inference_end_timestamp));
}
printf("%.*s", (int) token_size, (const char*) token_ptr);
diff --git a/gpt_oss/metal/source/include/internal/kernel-args.h b/gpt_oss/metal/source/include/internal/kernel-args.h
index 677ce488..90dbdcf7 100644
--- a/gpt_oss/metal/source/include/internal/kernel-args.h
+++ b/gpt_oss/metal/source/include/internal/kernel-args.h
@@ -4,11 +4,34 @@
#include
#endif
+// TODO(ibahmed): specalize using metal function constants.
+#define QKV_Bm 64
+#define QKV_Bn 64
+#define QKV_Bk 32
+#define QKV_Sg_Bm 32
+#define QKV_Sg_Bn 32
+
+#define ATTN_OUTPUT_Bm 32
+#define ATTN_OUTPUT_Bn 64
+#define ATTN_OUTPUT_Bk 64
+#define ATTN_OUTPUT_Sg_Bm 32
+#define ATTN_OUTPUT_Sg_Bn 16
+
+#define MLP_GATE_Bm 64
+#define MLP_GATE_Bn 16
+#define MLP_GATE_Bk 64
+#define MLP_GATE_Sg_Bm 16
+#define MLP_GATE_Sg_Bn 16
+
struct gptoss_expert_prediction {
uint32_t expert_id;
float score;
};
+struct gptoss_control {
+ uint32_t abort;
+};
+
struct gptoss_topk_args {
uint32_t num_vecs_per_token;
};
@@ -16,6 +39,7 @@ struct gptoss_topk_args {
struct gptoss_sdpa_args {
uint32_t qkv_dim;
uint32_t num_kv_tokens;
+ uint32_t kv_stride;
uint32_t window;
};
@@ -62,6 +86,12 @@ struct gptoss_matmul_args {
uint32_t add;
};
+struct gptoss_dense_matmul_args {
+ uint32_t m;
+ uint32_t n;
+ uint32_t k;
+};
+
struct gptoss_unembedding_args {
uint32_t num_column_vecs;
uint32_t num_rows_per_threadgroup;
@@ -97,9 +127,29 @@ struct gptoss_rope_args {
float yarn_multiplier;
};
+struct gptoss_qkv_args {
+ uint32_t num_column_vecs;
+ uint32_t num_rows;
+ uint32_t token_offset;
+ float freq_scale;
+ float interpolation_scale;
+ float yarn_offset;
+ float yarn_scale;
+ float yarn_multiplier;
+ uint32_t max_tokens;
+};
+
struct gptoss_softmax_args {
uint32_t num_vecs;
uint32_t num_vecs_per_threadgroup;
uint32_t max_threadgroups;
float temperature;
};
+
+struct gptoss_sample_args {
+ uint64_t rng_seed;
+ uint32_t rng_offset;
+ uint32_t num_blocks;
+ uint32_t num_dims;
+ uint32_t num_dims_per_block;
+};
diff --git a/gpt_oss/metal/source/include/internal/math.h b/gpt_oss/metal/source/include/internal/math.h
index 8d6a9040..06f2b1f1 100644
--- a/gpt_oss/metal/source/include/internal/math.h
+++ b/gpt_oss/metal/source/include/internal/math.h
@@ -1,5 +1,6 @@
#pragma once
+#include
#include
#include
@@ -15,11 +16,25 @@ inline static size_t math_min(size_t a, size_t b) {
return a < b ? a : b;
}
-static size_t math_round_up_po2(size_t bytes, size_t multiple) {
+inline static size_t math_sub_sat(size_t a, size_t b) {
+ return a > b ? a - b : 0;
+}
+
+static size_t math_round_down_po2(size_t number, size_t multiple) {
+ assert(multiple != 0);
+ assert((multiple & (multiple - 1)) == 0);
+
+ return number & -multiple;
+}
+
+static size_t math_round_up_po2(size_t number, size_t multiple) {
+ assert(multiple != 0);
+ assert((multiple & (multiple - 1)) == 0);
+
const size_t multiple_mask = multiple - 1;
- if ((bytes & multiple_mask) != 0) {
- bytes |= multiple_mask;
- bytes += 1;
+ if ((number & multiple_mask) != 0) {
+ number |= multiple_mask;
+ number += 1;
}
- return bytes;
+ return number;
}
diff --git a/gpt_oss/metal/source/include/internal/metal-kernels.h b/gpt_oss/metal/source/include/internal/metal-kernels.h
index aa5a3ef7..c12a834d 100644
--- a/gpt_oss/metal/source/include/internal/metal-kernels.h
+++ b/gpt_oss/metal/source/include/internal/metal-kernels.h
@@ -74,6 +74,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels);
@@ -86,6 +88,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels,
float epsilon);
@@ -102,10 +106,41 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
+ size_t threadgroup_size,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* kv_buffer,
+ size_t kv_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_q_heads,
+ uint32_t num_kv_heads,
+ uint32_t attn_head_dim,
+ uint32_t token_offset,
+ uint32_t max_tokens,
+ float rope_base,
+ float interpolation_scale,
+ float yarn_offset,
+ float yarn_scale,
+ float yarn_multiplier);
+
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
const struct gptoss_metal_command_buffer* command_buffer,
const struct gptoss_metal_function* f32_bf16w_matmul_fn,
@@ -118,6 +153,62 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows);
+
+enum gptoss_status
+gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows);
+
+enum gptoss_status
+gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows);
+
+enum gptoss_status
+gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
@@ -135,6 +226,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
size_t output_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows);
@@ -155,6 +248,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float swiglu_limit,
uint32_t expert_stride,
uint32_t num_tokens,
@@ -178,6 +273,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
@@ -189,6 +286,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
const struct gptoss_metal_function* f32_rope_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* activations_buffer,
+ size_t activations_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float rope_base,
float interpolation_scale,
float yarn_offset,
@@ -211,6 +311,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
size_t expert_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_experts);
@@ -222,6 +324,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_active_experts);
@@ -231,15 +335,16 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
const struct gptoss_metal_function* f32_sdpa_fn,
const struct gptoss_metal_buffer* q_buffer,
size_t q_offset,
- const struct gptoss_metal_buffer* k_buffer,
- size_t k_offset,
- const struct gptoss_metal_buffer* v_buffer,
- size_t v_offset,
+ const struct gptoss_metal_buffer* kv_buffer,
+ size_t kv_offset,
const struct gptoss_metal_buffer* s_buffer,
size_t s_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t window,
+ uint32_t kv_stride,
uint32_t num_q_tokens,
uint32_t num_kv_tokens,
uint32_t num_q_heads,
@@ -259,12 +364,32 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
float temperature,
uint32_t* num_threadgroups_out,
uint32_t* num_channels_per_threadgroup_out);
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_sample_fn,
+ size_t min_threadgroup_size,
+ const struct gptoss_metal_buffer* prob_buffer,
+ size_t prob_offset,
+ const struct gptoss_metal_buffer* sum_buffer,
+ size_t sum_offset,
+ const struct gptoss_metal_buffer* token_buffer,
+ size_t token_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint64_t rng_seed,
+ uint32_t rng_offset,
+ uint32_t num_blocks,
+ uint32_t num_channels,
+ uint32_t num_channels_per_block);
+
#ifdef __cplusplus
} // extern "C"
#endif
diff --git a/gpt_oss/metal/source/include/internal/metal.h b/gpt_oss/metal/source/include/internal/metal.h
index 41194bda..f38190f0 100644
--- a/gpt_oss/metal/source/include/internal/metal.h
+++ b/gpt_oss/metal/source/include/internal/metal.h
@@ -118,9 +118,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
size_t num_threadgroups_z,
size_t params_size,
const void* params,
- size_t num_buffers,
- const struct gptoss_metal_buffer** buffers,
- const size_t* buffer_offsets);
+ size_t num_device_buffers,
+ const struct gptoss_metal_buffer** device_buffers,
+ const size_t* device_buffer_offsets,
+ size_t threadgroup_buffer_size);
enum gptoss_status gptoss_metal_command_buffer_commit(
const struct gptoss_metal_command_buffer* command_buffer);
diff --git a/gpt_oss/metal/source/include/internal/metal.hpp b/gpt_oss/metal/source/include/internal/metal.hpp
index 9df7aed7..a143a11a 100644
--- a/gpt_oss/metal/source/include/internal/metal.hpp
+++ b/gpt_oss/metal/source/include/internal/metal.hpp
@@ -246,10 +246,11 @@ class CommandBuffer {
const std::array& threadgroup_size,
const std::array& num_threadgroups,
size_t params_size, const void* params,
- std::initializer_list buffers = {})
+ std::initializer_list device_buffers = {},
+ size_t threadgroup_buffer_size = 0)
{
- std::vector buffer_handles(buffers.size());
- std::transform(buffers.begin(), buffers.end(), buffer_handles.begin(),
+ std::vector buffer_handles(device_buffers.size());
+ std::transform(device_buffers.begin(), device_buffers.end(), buffer_handles.begin(),
[](const Buffer* buffer) -> const gptoss_metal_buffer* { return buffer->handle(); });
Check(gptoss_metal_command_buffer_encode_launch_kernel(
&command_buffer_, function.handle(),
@@ -258,7 +259,8 @@ class CommandBuffer {
params_size, params,
buffer_handles.size(),
buffer_handles.data(),
- /*buffer_offsets=*/nullptr),
+ /*buffer_offsets=*/nullptr,
+ threadgroup_buffer_size),
"gptoss_metal_command_buffer_encode_launch_kernel");
}
diff --git a/gpt_oss/metal/source/include/internal/model.h b/gpt_oss/metal/source/include/internal/model.h
index e2a45647..c63578a7 100644
--- a/gpt_oss/metal/source/include/internal/model.h
+++ b/gpt_oss/metal/source/include/internal/model.h
@@ -1,6 +1,9 @@
#pragma once
-#include
+#ifndef __cplusplus
+ #include
+#endif
+#include
#include
#include
@@ -8,7 +11,11 @@
struct gptoss_tokenizer {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
void* mapping_ptr;
size_t mapping_size;
@@ -23,7 +30,11 @@ struct gptoss_tokenizer {
};
struct gptoss_model {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
struct gptoss_tokenizer* tokenizer;
@@ -54,6 +65,8 @@ struct gptoss_model {
// Once the batch size is reached, we process it to fill the KV cache.
size_t max_batch_tokens;
+ bool lock_memory;
+
size_t weights_size;
size_t allocation_size;
@@ -65,6 +78,10 @@ struct gptoss_model {
struct gptoss_metal_function bf16_f32_embeddings_fn;
struct gptoss_metal_function f32_bf16w_rmsnorm_fn;
struct gptoss_metal_function f32_bf16w_matmul_fn;
+ struct gptoss_metal_function f32_bf16w_matmul_qkv_fn;
+ struct gptoss_metal_function f32_bf16w_dense_matmul_qkv_fn;
+ struct gptoss_metal_function f32_bf16w_dense_matmul_attn_output_fn;
+ struct gptoss_metal_function f32_bf16w_dense_matmul_mlp_gate_fn;
struct gptoss_metal_function f32_bf16w_unembedding_fn;
struct gptoss_metal_function f32_rope_fn;
struct gptoss_metal_function f32_mf4w_moe_matmul_swiglu_fn;
@@ -74,21 +91,20 @@ struct gptoss_model {
struct gptoss_metal_function f32_topk_softmax_e128_k4_fn;
struct gptoss_metal_function f32_sdpa_q8_d64_fn;
struct gptoss_metal_function f32_softmax_fn;
-
- // Activation buffers.
- // TODO: merge into a single buffer.
- struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
- struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
- struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
- struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
- struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
- struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
- struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
- struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
+ struct gptoss_metal_function f32_sample_fn;
size_t per_block_shared_weights_size;
size_t per_expert_block_weight_size;
+ size_t embeddings_threadgroup_size;
+ size_t attn_qkv_threadgroup_size;
+ size_t attn_out_threadgroup_size;
+ size_t mlp_gate_threadgroup_size;
+ size_t mlp_swiglu_threadgroup_size;
+ size_t mlp_out_threadgroup_size;
+ size_t mlp_acc_threadgroup_size;
+ size_t unembedding_threadgroup_size;
+
size_t attn_rmsnorm_gain_offset;
size_t attn_qkv_weight_offset;
size_t attn_qkv_bias_offset;
@@ -115,7 +131,11 @@ struct gptoss_model {
#define GPTOSS_DEFAULT_BATCH_SIZE 128
struct gptoss_context {
+#ifndef __cplusplus
atomic_uint_least64_t ref_count;
+#else
+ uint_least64_t ref_count;
+#endif
struct gptoss_model* model;
// Number of tokens processed in the context.
@@ -125,16 +145,22 @@ struct gptoss_context {
// Length of the context.
size_t max_tokens;
- // Current number of tokens in the batch.
- // Always in the [0, max_batch_tokens) range.
- size_t num_batch_tokens;
- // Number of tokens processed in the last batch.
- // Activations for [num_batch_tokens, num_processed_tokens) tokens can be accessed from internal structures.
- size_t num_processed_tokens;
-
size_t kvcache_size;
size_t allocation_size;
+ // Activation buffers.
+ // TODO: merge into a single buffer.
+ struct gptoss_metal_buffer residual_activation_buffer; // Residual stream
+ struct gptoss_metal_buffer rmsnorm_activation_buffer; // Both attention & MLP RMSNorm output
+ struct gptoss_metal_buffer qkv_activation_buffer; // QKV projection output
+ struct gptoss_metal_buffer sdpa_activation_buffer; // SDPA output
+ struct gptoss_metal_buffer gate_activation_buffer; // MoE gating output
+ struct gptoss_metal_buffer expert_activation_buffer; // MoE expert predictions
+ struct gptoss_metal_buffer swiglu_activation_buffer; // MLP+SwiGLU output
+ struct gptoss_metal_buffer moe_activation_buffer; // MoE MLP output (per-active expert)
+
+ // Input/output buffers.
+ struct gptoss_metal_buffer control_buffer;
struct gptoss_metal_buffer token_buffer; // uint32 token IDs
struct gptoss_metal_buffer score_buffer; // unembedding outputs
struct gptoss_metal_buffer prob_buffer;
@@ -142,12 +168,3 @@ struct gptoss_context {
struct gptoss_metal_buffer argmax_buffer;
struct gptoss_metal_buffer kvcache_buffer;
};
-
-struct gptoss_sampler {
- atomic_uint_least64_t ref_count;
-
- float temperature;
- float top_p;
- float presence_penalty;
- float frequency_penalty;
-};
diff --git a/gpt_oss/metal/source/matmul.metal b/gpt_oss/metal/source/matmul.metal
index 6396f6cc..8831563b 100644
--- a/gpt_oss/metal/source/matmul.metal
+++ b/gpt_oss/metal/source/matmul.metal
@@ -3,6 +3,7 @@
#include
#include
#include
+#include
#include
@@ -23,12 +24,16 @@ kernel void gptoss_f32_bf16w_matmul(
const device bfloat4* weight [[ buffer(2) ]],
const device bfloat* bias [[ buffer(3) ]],
device float* output [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
@@ -62,12 +67,101 @@ kernel void gptoss_f32_bf16w_matmul(
}
}
+kernel void gptoss_f32_bf16w_matmul_qkv(
+ constant gptoss_qkv_args& args [[ buffer(0) ]],
+ const device float4* input [[ buffer(1) ]],
+ const device bfloat4* weight [[ buffer(2) ]],
+ const device bfloat* bias [[ buffer(3) ]],
+ device float* q [[ buffer(4) ]],
+ device float* kv [[ buffer(5) ]],
+ const device gptoss_control* control [[ buffer(6) ]],
+ threadgroup void* scratch [[ threadgroup(0) ]],
+ uint2 gid [[threadgroup_position_in_grid]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+{
+ const uint simdgroup_size = 32;
+ const uint head_dim = 64;
+ const uint num_q_heads = 64;
+ const uint num_kv_heads = 8;
+ if (control->abort != 0) {
+ return;
+ }
+
+ const uint num_column_vecs = args.num_column_vecs;
+ const uint row = gid.x * num_simdgroups + simdgroup_idx;
+
+ input += gid.y * num_column_vecs + simdgroup_tid;
+ weight += num_column_vecs * row + simdgroup_tid;
+ bias += row;
+ q += gid.y * args.num_rows;
+
+ uint num_iter = (num_column_vecs - simdgroup_tid + (simdgroup_size - 1)) / simdgroup_size;
+
+ float4 sum4 = 0.0f;
+ do {
+ const bfloat4 w = *weight;
+ const float4 i = *input;
+ sum4 = metal::fma(static_cast(w), i, sum4);
+
+ weight += simdgroup_size;
+ input += simdgroup_size;
+ } while (--num_iter != 0);
+ const float2 sum2 = sum4.xy + sum4.zw;
+ float sum = sum2.x + sum2.y;
+ sum = metal::simd_sum(sum);
+ if (metal::simd_is_first()) {
+ sum += static_cast(*bias);
+ static_cast(scratch)[simdgroup_idx] = sum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_idx == 0) {
+ const uint num_half_simdgroups = num_simdgroups / 2;
+ if (simdgroup_tid < num_half_simdgroups) {
+ float2 vals = static_cast(scratch)[simdgroup_tid];
+ const uint idx = gid.x * num_half_simdgroups + simdgroup_tid;
+ const uint head_idx = idx / (head_dim / 2);
+ const uint token_idx = args.token_offset + gid.y;
+ const uint dim_idx = idx % (head_dim / 2);
+ if (head_idx < num_q_heads + num_kv_heads) {
+ const float dim_idx_val = static_cast(dim_idx);
+ const float inv_extrapolation_freq = metal::precise::exp(dim_idx_val * args.freq_scale);
+ const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
+ const float alpha = metal::saturate(metal::fma(dim_idx_val, args.yarn_scale, args.yarn_offset));
+ const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
+
+ const float phi = static_cast(token_idx) * inv_freq;
+ const float yarn_multiplier = args.yarn_multiplier;
+ float cosphi;
+ const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
+ cosphi *= yarn_multiplier;
+
+ vals = (float2) {
+ vals.x * cosphi - vals.y * sinphi,
+ vals.x * sinphi + vals.y * cosphi,
+ };
+ }
+ if (head_idx < num_q_heads) {
+ reinterpret_cast(q)[idx] = vals;
+ } else if (head_idx < num_q_heads + num_kv_heads) {
+ const uint h = head_idx - num_q_heads;
+ reinterpret_cast(kv + (h * args.max_tokens + token_idx) * 2 * head_dim)[dim_idx] = vals;
+ } else {
+ const uint h = head_idx - num_q_heads - num_kv_heads;
+ reinterpret_cast(kv + (h * args.max_tokens + token_idx) * 2 * head_dim + head_dim)[dim_idx] = vals;
+ }
+ }
+ }
+}
+
kernel void gptoss_f32_bf16w_unembedding(
constant gptoss_unembedding_args& args [[ buffer(0) ]],
const device float4* input [[ buffer(1) ]],
const device bfloat4* weight [[ buffer(2) ]],
device float* output [[ buffer(3) ]],
device metal::atomic_ulong* argmax [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint2 gid [[threadgroup_position_in_grid]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
@@ -75,6 +169,9 @@ kernel void gptoss_f32_bf16w_unembedding(
{
const uint simdgroup_size = 32;
threadgroup uint2 threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row_start = gid.x * args.num_rows_per_threadgroup + simdgroup_idx;
@@ -135,3 +232,191 @@ kernel void gptoss_f32_bf16w_unembedding(
}
}
}
+
+// Current constraints for the dense matmul kernel:
+// 1- All B* and Sg_* are a multiple of 8.
+// 2- Bm is divisible by Sg_n and Bn is divisible by Sg_n.
+// 3- M, N and K are all divisible by 8..
+template
+inline void _gptoss_f32_bf16w_dense_matmul_impl(
+ constant gptoss_dense_matmul_args& args, const device float* lhs,
+ const device bfloat* rhs, const device bfloat* __restrict__ bias,
+ device float* out, const device gptoss_control* control, threadgroup float* scratch, threadgroup float* bias_tile,
+ uint sg_id, uint sg_count_per_tg, uint3 gid, uint3 tg_id, uint3 local_tid,
+ uint3 threadgroup_size) {
+
+ if (control->abort != 0) {
+ return;
+ }
+
+ // The kernel assumes that M, K, and N are divisible by 8.
+ const uint M = args.m;
+ const uint K = args.k;
+ const uint N = args.n;
+ static_assert((Bm % 8u) == 0u, "Bm must be a multiple of 8");
+ static_assert((Bn % 8u) == 0u, "Bn must be a multiple of 8");
+ static_assert((Bk % 8u) == 0u, "Bk must be a multiple of 8");
+ static_assert((Sg_Bm % 8u) == 0u, "Bk must be a multiple of 8");
+ static_assert((Sg_Bn % 8u) == 0u, "Bk must be a multiple of 8");
+ static_assert((Bn % Sg_Bn) == 0u, "Bn must be a multiple of Sg_Bn");
+ static_assert((Bm % Sg_Bm) == 0u, "Bm must be a multiple of Sg_Bm");
+
+ // Get row and col tg.
+ const uint row_tg = tg_id.y;
+ const uint col_tg = tg_id.x;
+ // Get row and col local tid.
+ const uint row_tg_offset = row_tg * Bm;
+ const uint col_tg_offset = col_tg * Bn;
+
+ const uint sg_col_count = Bn / Sg_Bn;
+ const uint row_sg = sg_id / sg_col_count;
+ const uint col_sg = sg_id % sg_col_count;
+
+ const uint row_sg_offset = row_sg * Sg_Bm;
+ const uint col_sg_offset = col_sg * Sg_Bn;
+ constexpr uint temp_result_size = (Sg_Bm / 8) * (Sg_Bn / 8);
+ // Create an array of simdgroup_float8x8 to hold temp results.
+ metal::simdgroup_float8x8 OutTiles[temp_result_size];
+#pragma clang loop unroll(full)
+ for (uint i = 0; i < temp_result_size; i++) {
+ OutTiles[i] = metal::make_filled_simdgroup_matrix(
+ static_cast(0.0));
+ }
+
+ for (uint k_offset = 0; k_offset < K; k_offset += Bk) {
+#pragma clang loop unroll(full)
+ for (uint k = 0; k < Bk; k += 8) {
+#pragma clang loop unroll(full)
+ for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
+ // const uint m_subtile = row_sg_offset + m_subtile_;
+ // const uint row_index_in_out_tile = (m_subtile - row_sg_offset) / 8;
+ const uint row_index_in_out_tile = m_subtile_ / 8;
+ metal::simdgroup_float8x8 LHStile;
+ const uint k_id = k + k_offset;
+ const uint row_offset = row_tg_offset + row_sg_offset + m_subtile_;
+ metal::simdgroup_load(LHStile, lhs, K, ulong2(k_id, row_offset));
+ metal::simdgroup_bfloat8x8 RHStile;
+#pragma clang loop unroll(full)
+ for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
+ const uint col_index_in_out_tile = n_subtile_ / 8;
+ const uint current_index_out_tile =
+ row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
+ const uint col_offset = col_tg_offset + col_sg_offset + n_subtile_;
+ simdgroup_load(RHStile, rhs, K, ulong2(k_id, col_offset), /*transpose=*/true);
+ // If rhs was not transposed, use the following instead:
+ // simdgroup_load(RHStile, rhs, N, ulong2(col_offset, k_id));
+ simdgroup_multiply_accumulate(OutTiles[current_index_out_tile],
+ LHStile, RHStile,
+ OutTiles[current_index_out_tile]);
+ }
+ }
+ }
+ }
+ // Epilogue.
+#pragma clang loop unroll(full)
+ for (uint n_subtile_ = 0; n_subtile_ < Sg_Bn; n_subtile_ += 8) {
+ const uint col_index_in_out_tile = n_subtile_ / 8;
+ const uint local_col_offset = col_sg_offset + n_subtile_;
+#pragma clang loop unroll(full)
+ for (uint m_subtile_ = 0; m_subtile_ < Sg_Bm; m_subtile_ += 8) {
+ const uint row_index_in_out_tile = m_subtile_ / 8;
+ const uint local_row_offset = row_sg_offset + m_subtile_;
+ const uint current_index_out_tile =
+ row_index_in_out_tile * (Sg_Bn / 8) + col_index_in_out_tile;
+ simdgroup_store(OutTiles[current_index_out_tile], scratch, Bn,
+ ulong2(local_col_offset, local_row_offset));
+ }
+ }
+ // TODO(ibahmed): vectorize these loads an maybe unroll the loop.
+ const uint thread_count_per_tg =
+ threadgroup_size.x * threadgroup_size.y * threadgroup_size.z;
+ for (uint c_local = local_tid.x; c_local < Bn;
+ c_local += thread_count_per_tg) {
+ const uint c_global = col_tg_offset + c_local;
+ bias_tile[c_local] =
+ (c_global < N) ? static_cast(bias[c_global]) : 0.0f;
+ }
+
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+
+ // TODO(ibahmed): vectorize these stores and maybe unroll the loop.
+ for (uint idx = local_tid.x; idx < Bm * Bn; idx += thread_count_per_tg) {
+ const uint r = idx / Bn;
+ const uint c = idx % Bn;
+
+ const uint out_row = row_tg_offset + r;
+ const uint out_col = col_tg_offset + c;
+
+ if (out_row < M && out_col < N) {
+ float acc = scratch[idx] + bias_tile[c];
+ if (add) {
+ acc += out[out_row * N + out_col];
+ }
+ out[out_row * N + out_col] = acc;
+ }
+ }
+}
+
+kernel void gptoss_f32_bf16w_dense_matmul_qkv(
+ constant gptoss_dense_matmul_args& args [[buffer(0)]],
+ const device float* lhs [[buffer(1)]],
+ const device bfloat* rhs [[buffer(2)]],
+ const device bfloat* __restrict__ bias [[buffer(3)]],
+ device float* out [[buffer(4)]],
+ const device gptoss_control* control [[buffer(5)]],
+ uint sg_id [[simdgroup_index_in_threadgroup]],
+ uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+ uint3 gid [[thread_position_in_grid]],
+ uint3 tg_id [[threadgroup_position_in_grid]],
+ uint3 local_tid [[thread_position_in_threadgroup]],
+ uint3 threadgroup_size [[threads_per_threadgroup]]) {
+ threadgroup float scratch[QKV_Bm * QKV_Bn];
+ threadgroup float bias_tile[QKV_Bn];
+ _gptoss_f32_bf16w_dense_matmul_impl(
+ args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+ gid, tg_id, local_tid, threadgroup_size);
+}
+
+kernel void gptoss_f32_bf16w_dense_matmul_attn_output(
+ constant gptoss_dense_matmul_args& args [[buffer(0)]],
+ const device float* lhs [[buffer(1)]],
+ const device bfloat* rhs [[buffer(2)]],
+ const device bfloat* __restrict__ bias [[buffer(3)]],
+ device float* out [[buffer(4)]],
+ const device gptoss_control* control [[buffer(5)]],
+ uint sg_id [[simdgroup_index_in_threadgroup]],
+ uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+ uint3 gid [[thread_position_in_grid]],
+ uint3 tg_id [[threadgroup_position_in_grid]],
+ uint3 local_tid [[thread_position_in_threadgroup]],
+ uint3 threadgroup_size [[threads_per_threadgroup]]) {
+ threadgroup float scratch[ATTN_OUTPUT_Bm * ATTN_OUTPUT_Bn];
+ threadgroup float bias_tile[ATTN_OUTPUT_Bn];
+ _gptoss_f32_bf16w_dense_matmul_impl(
+ args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+ gid, tg_id, local_tid, threadgroup_size);
+}
+
+kernel void gptoss_f32_bf16w_dense_matmul_mlp_gate(
+ constant gptoss_dense_matmul_args& args [[buffer(0)]],
+ const device float* lhs [[buffer(1)]],
+ const device bfloat* rhs [[buffer(2)]],
+ const device bfloat* __restrict__ bias [[buffer(3)]],
+ device float* out [[buffer(4)]],
+ const device gptoss_control* control [[buffer(5)]],
+ uint sg_id [[simdgroup_index_in_threadgroup]],
+ uint sg_count_per_tg [[dispatch_simdgroups_per_threadgroup]],
+ uint3 gid [[thread_position_in_grid]],
+ uint3 tg_id [[threadgroup_position_in_grid]],
+ uint3 local_tid [[thread_position_in_threadgroup]],
+ uint3 threadgroup_size [[threads_per_threadgroup]]) {
+ threadgroup float scratch[MLP_GATE_Bm * MLP_GATE_Bn];
+ threadgroup float bias_tile[MLP_GATE_Bn];
+ _gptoss_f32_bf16w_dense_matmul_impl(
+ args, lhs, rhs, bias, out, control, scratch, bias_tile, sg_id, sg_count_per_tg,
+ gid, tg_id, local_tid, threadgroup_size);
+}
diff --git a/gpt_oss/metal/source/metal-kernels.c b/gpt_oss/metal/source/metal-kernels.c
index 61b9c973..3aaeb32f 100644
--- a/gpt_oss/metal/source/metal-kernels.c
+++ b/gpt_oss/metal/source/metal-kernels.c
@@ -46,7 +46,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_u32_fill_random(
threadgroup_size, 1, 1,
num_threadgroups, 1, 1,
sizeof(args), &args,
- 1, &output_buffer, &output_offset);
+ 1, &output_buffer, &output_offset,
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
@@ -93,7 +94,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_fill_random(
threadgroup_size, 1, 1,
num_threadgroups, 1, 1,
sizeof(args), &args,
- 1, &output_buffer, &output_offset);
+ 1, &output_buffer, &output_offset,
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
@@ -140,7 +142,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_fill_random(
threadgroup_size, 1, 1,
num_threadgroups, 1, 1,
sizeof(args), &args,
- 1, &output_buffer, &output_offset);
+ 1, &output_buffer, &output_offset,
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
@@ -180,7 +183,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_mf4_f32_convert(
threadgroup_size, 1, 1,
num_threadgroups, 1, 1,
sizeof(args), &args,
- 3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL);
+ 3, (const struct gptoss_metal_buffer *[]) {block_buffer, scale_buffer, output_buffer}, NULL,
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings(
@@ -193,6 +197,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels)
{
@@ -220,9 +226,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings
threadgroup_size, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer},
- (const size_t[]) {token_offset, weight_offset, output_offset});
+ 4,
+ (const struct gptoss_metal_buffer *[]) {token_buffer, weight_buffer, output_buffer, control_buffer},
+ (const size_t[]) {token_offset, weight_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
@@ -234,6 +241,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
size_t weight_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_channels,
float epsilon)
@@ -266,9 +275,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_rmsnorm(
/*threadgroup_size=*/1024, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, output_offset});
+ 4,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
@@ -283,6 +293,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -323,9 +335,105 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_qkv(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_matmul_qkv_fn,
+ size_t threadgroup_size,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* kv_buffer,
+ size_t kv_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_q_heads,
+ uint32_t num_kv_heads,
+ uint32_t attn_head_dim,
+ uint32_t token_offset,
+ uint32_t max_tokens,
+ float rope_base,
+ float interpolation_scale,
+ float yarn_offset,
+ float yarn_scale,
+ float yarn_multiplier)
+{
+ if (command_buffer->object == NULL || f32_bf16w_matmul_qkv_fn->pipeline_state_object == NULL) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: invalid command buffer or pipeline state object");
+ return gptoss_status_invalid_state;
+ }
+
+ if (threadgroup_size == 0) {
+ threadgroup_size = f32_bf16w_matmul_qkv_fn->simdgroup_threads;
+ } else if (threadgroup_size > f32_bf16w_matmul_qkv_fn->max_threadgroup_threads) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: threadgroup size (%zu) exceeds supported maximum (%zu)",
+ threadgroup_size, f32_bf16w_matmul_qkv_fn->max_threadgroup_threads);
+ return gptoss_status_invalid_argument;
+ }
+
+ if (num_cols % 4 != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of columns (%" PRIu32 ") is not divisible by 4",
+ num_cols);
+ return gptoss_status_invalid_argument;
+ }
+
+ if (num_q_heads != 64) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of Q heads (%" PRIu32 ") must be 64",
+ num_q_heads);
+ return gptoss_status_invalid_argument;
+ }
+ if (num_kv_heads != 8) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of KV heads (%" PRIu32 ") must be 8",
+ num_kv_heads);
+ return gptoss_status_invalid_argument;
+ }
+ if (attn_head_dim != 64) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: attention head dimension (%" PRIu32 ") must be 64",
+ attn_head_dim);
+ return gptoss_status_invalid_argument;
+ }
+
+ const size_t num_simdgroups = threadgroup_size / f32_bf16w_matmul_qkv_fn->simdgroup_threads;
+ const uint32_t num_rows = (num_q_heads + 2 * num_kv_heads) * attn_head_dim;
+ if (num_rows % num_simdgroups != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_matmul_qkv kernel launch: number of rows (%" PRIu32 ") is not divisible by the number of simdgroups (%zu)",
+ num_rows, num_simdgroups);
+ return gptoss_status_invalid_argument;
+ }
+
+ const struct gptoss_qkv_args args = {
+ .num_column_vecs = num_cols / 4,
+ .num_rows = num_rows,
+ .token_offset = token_offset,
+ .freq_scale = -logf(rope_base) / (float) (int32_t) (attn_head_dim / 2),
+ .interpolation_scale = interpolation_scale,
+ .yarn_offset = yarn_offset,
+ .yarn_scale = yarn_scale,
+ .yarn_multiplier = yarn_multiplier,
+ .max_tokens = max_tokens,
+ };
+
+ return gptoss_metal_command_buffer_encode_launch_kernel(
+ command_buffer, f32_bf16w_matmul_qkv_fn,
+ threadgroup_size, 1, 1,
+ num_rows / num_simdgroups, num_tokens, 1,
+ sizeof(args), &args,
+ 6,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, kv_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, kv_offset, control_offset},
+ /*threadgroup_buffer_size=*/num_simdgroups * sizeof(float));
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_add(
@@ -340,6 +448,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -380,9 +490,186 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul_ad
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset});
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, bias_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows,
+ uint32_t Bm,
+ uint32_t Bn,
+ uint32_t Bk,
+ uint32_t Sg_Bm,
+ uint32_t Sg_Bn)
+{
+
+ if (command_buffer->object == NULL || f32_bf16w_dense_matmul_fn->pipeline_state_object == NULL) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: invalid command buffer or pipeline state object");
+ return gptoss_status_invalid_state;
+ }
+
+ if (num_cols % 8 != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of columns (%" PRIu32 ") is not divisible by 8",
+ num_cols);
+ return gptoss_status_invalid_argument;
+ }
+ if (num_rows % 8 != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of rows (%" PRIu32 ") is not divisible by 8",
+ num_rows);
+ return gptoss_status_invalid_argument;
+ }
+ if (num_tokens % 8 != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: number of tokens (%" PRIu32 ") is not divisible by 8",
+ num_tokens);
+ return gptoss_status_invalid_argument;
+ }
+
+ const struct gptoss_dense_matmul_args args = {
+ .m = num_tokens,
+ .n = num_rows,
+ .k = num_cols,
+ };
+ const size_t threads_per_simdgroup = f32_bf16w_dense_matmul_fn->simdgroup_threads;
+ const uint32_t m = args.m;
+ const uint32_t n = args.n;
+ const uint32_t k = args.k;
+ if (Bm % Sg_Bm != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bm (%" PRIu32 ") is not divisible by Sg_Bm (%" PRIu32 ")",
+ Bm, Sg_Bm);
+ return gptoss_status_invalid_argument;
+ }
+ if (Bn % Sg_Bn != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: Bn (%" PRIu32 ") is not divisible by Sg_Bn (%" PRIu32 ")",
+ Bn, Sg_Bn);
+ return gptoss_status_invalid_argument;
+ }
+ const size_t threadgroup_size_x = (Bm / Sg_Bm) * (Bn / Sg_Bn) * threads_per_simdgroup;
+ const size_t threadgroup_size_y = 1;
+ const size_t threadgroup_size_z = 1;
+ const size_t total_threadgroup_size = threadgroup_size_x * threadgroup_size_y * threadgroup_size_z;
+ if (total_threadgroup_size > f32_bf16w_dense_matmul_fn->max_threadgroup_threads) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: total threadgroup size (%zu) exceeds supported maximum (%zu)",
+ total_threadgroup_size, f32_bf16w_dense_matmul_fn->max_threadgroup_threads);
+ return gptoss_status_invalid_argument;
+ }
+ if (m % Bm != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: m (%" PRIu32 ") is not divisible by Bm (%" PRIu32 ")",
+ m, Bm);
+ return gptoss_status_invalid_argument;
+ }
+ if (n % Bn != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: n (%" PRIu32 ") is not divisible by Bn (%" PRIu32 ")",
+ n, Bn);
+ return gptoss_status_invalid_argument;
+ }
+ if (k % Bk != 0) {
+ GPTOSS_LOG_ERROR("failed to encode f32_bf16w_dense_matmul kernel launch: k (%" PRIu32 ") is not divisible by Bk (%" PRIu32 ")",
+ k, Bk);
+ return gptoss_status_invalid_argument;
+ }
+ const size_t grid_x = n / Bn;
+ const size_t grid_y = m / Bm;
+ const size_t grid_z = 1;
+
+ return gptoss_metal_command_buffer_encode_launch_kernel(
+ command_buffer, f32_bf16w_dense_matmul_fn,
+ threadgroup_size_x, threadgroup_size_y, threadgroup_size_z,
+ grid_x, grid_y, grid_z,
+ sizeof(args), &args,
+ 5,
+ (const struct gptoss_metal_buffer *[]){input_buffer, weight_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]){input_offset, weight_offset, bias_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
+ return gptoss_status_success;
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows)
+{
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
+ output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, QKV_Bm, QKV_Bn, QKV_Bk,
+ QKV_Sg_Bm, QKV_Sg_Bn);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows)
+{
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
+ output_offset, control_buffer, control_offset, num_tokens, num_cols, num_rows, ATTN_OUTPUT_Bm,
+ ATTN_OUTPUT_Bn, ATTN_OUTPUT_Bk, ATTN_OUTPUT_Sg_Bm, ATTN_OUTPUT_Sg_Bn);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_bf16w_dense_matmul_fn,
+ const struct gptoss_metal_buffer* input_buffer,
+ size_t input_offset,
+ const struct gptoss_metal_buffer* weight_buffer,
+ size_t weight_offset,
+ const struct gptoss_metal_buffer* bias_buffer,
+ size_t bias_offset,
+ const struct gptoss_metal_buffer* output_buffer,
+ size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint32_t num_tokens,
+ uint32_t num_cols,
+ uint32_t num_rows)
+{
+ return _gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_impl(
+ command_buffer, f32_bf16w_dense_matmul_fn, input_buffer, input_offset,
+ weight_buffer, weight_offset, bias_buffer, bias_offset, output_buffer,
+ output_offset, control_buffer, control_offset, num_tokens, num_cols,
+ num_rows, MLP_GATE_Bm, MLP_GATE_Bn, MLP_GATE_Bk, MLP_GATE_Sg_Bm,
+ MLP_GATE_Sg_Bn);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembedding(
@@ -398,6 +685,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
size_t output_offset,
const struct gptoss_metal_buffer* argmax_buffer,
size_t argmax_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_cols,
uint32_t num_rows)
@@ -435,9 +724,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_bf16w_unembeddi
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
- 4,
- (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer},
- (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset});
+ 5,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, weight_buffer, output_buffer, argmax_buffer, control_buffer},
+ (const size_t[]) {input_offset, weight_offset, output_offset, argmax_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul_swiglu(
@@ -456,6 +746,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float swiglu_limit,
uint32_t expert_stride,
uint32_t num_tokens,
@@ -508,9 +800,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
threadgroup_size, 1, 1,
(2 * num_rows) / num_simdgroups, num_tokens, num_active_experts,
sizeof(args), &args,
- 6,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
+ 7,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul(
@@ -529,6 +822,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
size_t bias_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t expert_stride,
uint32_t num_tokens,
uint32_t num_active_experts,
@@ -579,9 +874,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_mf4w_moe_matmul
threadgroup_size, 1, 1,
num_rows / num_simdgroups, num_tokens, num_active_experts,
sizeof(args), &args,
- 6,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset});
+ 7,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, weight_block_buffer, weight_scale_buffer, bias_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, weight_block_offset, weight_scale_offset, bias_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
@@ -589,6 +885,9 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
const struct gptoss_metal_function* f32_rope_fn,
size_t threadgroup_size,
const struct gptoss_metal_buffer* activations_buffer,
+ size_t activations_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
float rope_base,
float interpolation_scale,
float yarn_offset,
@@ -631,7 +930,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_rope(
threadgroup_size, 1, 1,
num_qk_heads / num_simdgroups, num_tokens, 1,
sizeof(args), &args,
- 1, (const struct gptoss_metal_buffer *[]) {activations_buffer}, NULL);
+ 2,
+ (const struct gptoss_metal_buffer *[]) {activations_buffer, control_buffer},
+ (const size_t[]) {activations_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
@@ -645,6 +947,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
size_t expert_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
uint32_t num_experts)
@@ -678,9 +982,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_accumulate(
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
- 3,
- (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer},
- (const size_t[]) {input_offset, expert_offset, output_offset});
+ 4,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, expert_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, expert_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
@@ -690,6 +995,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
size_t input_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_tokens,
uint32_t num_experts,
uint32_t num_active_experts)
@@ -713,9 +1020,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_topk(
/*threadgroup_size=*/32, 1, 1,
num_tokens, 1, 1,
sizeof(args), &args,
- 2,
- (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer},
- (const size_t[]) {input_offset, output_offset});
+ 3,
+ (const struct gptoss_metal_buffer *[]) {input_buffer, output_buffer, control_buffer},
+ (const size_t[]) {input_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
@@ -723,15 +1031,16 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
const struct gptoss_metal_function* f32_sdpa_fn,
const struct gptoss_metal_buffer* q_buffer,
size_t q_offset,
- const struct gptoss_metal_buffer* k_buffer,
- size_t k_offset,
- const struct gptoss_metal_buffer* v_buffer,
- size_t v_offset,
+ const struct gptoss_metal_buffer* kv_buffer,
+ size_t kv_offset,
const struct gptoss_metal_buffer* s_buffer,
size_t s_offset,
const struct gptoss_metal_buffer* output_buffer,
size_t output_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t window,
+ uint32_t kv_stride,
uint32_t num_q_tokens,
uint32_t num_kv_tokens,
uint32_t num_q_heads,
@@ -753,20 +1062,27 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sdpa(
return gptoss_status_invalid_argument;
}
+ const size_t max_context_tokens = math_min(num_q_tokens + num_kv_tokens + 1, window);
+ const size_t threadgroup_size = math_min(f32_sdpa_fn->max_threadgroup_threads,
+ max_context_tokens * f32_sdpa_fn->simdgroup_threads);
+ const size_t half_threadgroup_size = math_round_down_po2(threadgroup_size / 2, f32_sdpa_fn->simdgroup_threads);
+
const struct gptoss_sdpa_args args = {
.qkv_dim = head_dim * (num_q_heads + 2 * num_kv_heads),
.num_kv_tokens = num_kv_tokens,
+ .kv_stride = kv_stride,
.window = window,
};
return gptoss_metal_command_buffer_encode_launch_kernel(
command_buffer, f32_sdpa_fn,
- /*threadgroup_size=*/32, 1, 1,
+ threadgroup_size, 1, 1,
num_q_tokens, num_kv_heads, 1,
sizeof(args), &args,
5,
- (const struct gptoss_metal_buffer *[]) {q_buffer, k_buffer, v_buffer, s_buffer, output_buffer},
- (const size_t[]) {q_offset, k_offset, v_offset, s_offset, output_offset});
+ (const struct gptoss_metal_buffer *[]) {q_buffer, kv_buffer, s_buffer, output_buffer, control_buffer},
+ (const size_t[]) {q_offset, kv_offset, s_offset, output_offset, control_offset},
+ /*threadgroup_buffer_size=*/half_threadgroup_size * 8 * 4 * sizeof(float));
}
enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
@@ -782,6 +1098,8 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
size_t prob_offset,
const struct gptoss_metal_buffer* sum_buffer,
size_t sum_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
uint32_t num_channels,
uint32_t num_tokens,
float temperature,
@@ -811,7 +1129,63 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_softmax(
threadgroup_size, 1, 1,
num_threadgroups, num_tokens, 1,
sizeof(args), &args,
+ 5,
+ (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer, control_buffer},
+ (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
+}
+
+enum gptoss_status gptoss_metal_command_buffer_encode_launch_f32_sample(
+ const struct gptoss_metal_command_buffer* command_buffer,
+ const struct gptoss_metal_function* f32_sample_fn,
+ size_t min_threadgroup_size,
+ const struct gptoss_metal_buffer* prob_buffer,
+ size_t prob_offset,
+ const struct gptoss_metal_buffer* sum_buffer,
+ size_t sum_offset,
+ const struct gptoss_metal_buffer* token_buffer,
+ size_t token_offset,
+ const struct gptoss_metal_buffer* control_buffer,
+ size_t control_offset,
+ uint64_t rng_seed,
+ uint32_t rng_offset,
+ uint32_t num_blocks,
+ uint32_t num_channels,
+ uint32_t num_channels_per_block)
+{
+ if (command_buffer->object == NULL || f32_sample_fn->pipeline_state_object == NULL) {
+ return gptoss_status_invalid_state;
+ }
+
+ if (min_threadgroup_size > f32_sample_fn->max_threadgroup_threads) {
+ return gptoss_status_invalid_argument;
+ }
+
+ if (min_threadgroup_size % f32_sample_fn->simdgroup_threads != 0) {
+ return gptoss_status_invalid_argument;
+ }
+
+ if (num_blocks > f32_sample_fn->max_threadgroup_threads) {
+ return gptoss_status_invalid_argument;
+ }
+
+ const struct gptoss_sample_args args = {
+ .rng_seed = rng_seed,
+ .rng_offset = rng_offset,
+ .num_blocks = num_blocks,
+ .num_dims = num_channels,
+ .num_dims_per_block = num_channels_per_block,
+ };
+
+ const size_t threadgroup_size = math_max(min_threadgroup_size,
+ math_round_up_po2(num_blocks, f32_sample_fn->simdgroup_threads));
+ return gptoss_metal_command_buffer_encode_launch_kernel(
+ command_buffer, f32_sample_fn,
+ threadgroup_size, 1, 1,
+ 1, 1, 1,
+ sizeof(args), &args,
4,
- (const struct gptoss_metal_buffer *[]) {score_buffer, argmax_buffer, prob_buffer, sum_buffer},
- (const size_t[]) {score_offset, argmax_offset, prob_offset, sum_offset});
+ (const struct gptoss_metal_buffer *[]) {prob_buffer, sum_buffer, token_buffer, control_buffer},
+ (const size_t[]) {prob_offset, sum_offset, token_offset, control_offset},
+ /*threadgroup_buffer_size=*/0);
}
diff --git a/gpt_oss/metal/source/metal.m b/gpt_oss/metal/source/metal.m
index 4f6cb35f..03d69962 100644
--- a/gpt_oss/metal/source/metal.m
+++ b/gpt_oss/metal/source/metal.m
@@ -96,18 +96,19 @@ enum gptoss_status gptoss_metal_library_create_default(
enum gptoss_status status = gptoss_status_success;
id device_obj = (id) device->object;
id library_obj = nil;
- NSError* error_obj = nil;
- NSString* error_string_obj = nil;
+ NSAutoreleasePool* autorelease_pool = nil;
dispatch_data_t library_blob = NULL;
unsigned long library_size = 0;
uint8_t* library_data = getsectiondata(&__dso_handle, "__METAL", "__shaders", &library_size);
if (library_data != NULL) {
library_blob = dispatch_data_create(library_data, library_size, NULL, DISPATCH_DATA_DESTRUCTOR_DEFAULT);
+
+ autorelease_pool = [[NSAutoreleasePool alloc] init];
+ NSError* error_obj = nil;
library_obj = [device_obj newLibraryWithData:library_blob error:&error_obj];
if (library_obj == nil) {
- error_string_obj = [error_obj localizedDescription];
- GPTOSS_LOG_ERROR("failed to create Metal library: %s", [error_string_obj UTF8String]);
+ GPTOSS_LOG_ERROR("failed to create Metal library: %s", [[error_obj localizedDescription] UTF8String]);
status = gptoss_status_unsupported_system;
goto cleanup;
}
@@ -129,11 +130,8 @@ enum gptoss_status gptoss_metal_library_create_default(
if (library_blob != NULL) {
dispatch_release(library_blob);
}
- if (error_string_obj != nil) {
- [error_string_obj release];
- }
- if (error_obj != nil) {
- [error_obj release];
+ if (autorelease_pool != nil) {
+ [autorelease_pool drain];
}
return status;
}
@@ -154,14 +152,16 @@ enum gptoss_status gptoss_metal_function_create(
const char* name,
struct gptoss_metal_function* function_out)
{
- NSString* name_obj = nil;
- NSError* error_obj = nil;
- NSString* error_string_obj = nil;
+ __block NSString* error_string_obj = nil;
id function_obj = nil;
+ MTLComputePipelineDescriptor* pipeline_descriptor_obj = nil;
+ __block id pipeline_state_obj = nil;
+ dispatch_semaphore_t pipeline_build_semaphore = NULL;
enum gptoss_status status = gptoss_status_success;
+ NSAutoreleasePool* autorelease_pool = [[NSAutoreleasePool alloc] init];
id library_obj = (id) library->object;
- name_obj = [NSString stringWithUTF8String:name];
+ NSString* name_obj = [NSString stringWithUTF8String:name];
function_obj = [library_obj newFunctionWithName:name_obj];
if (function_obj == nil) {
GPTOSS_LOG_ERROR("failed to create Metal function %s", name);
@@ -169,11 +169,33 @@ enum gptoss_status gptoss_metal_function_create(
goto cleanup;
}
id device_obj = [library_obj device];
- id pipeline_state_obj = [device_obj newComputePipelineStateWithFunction:function_obj error:&error_obj];
+ pipeline_descriptor_obj = [[MTLComputePipelineDescriptor alloc] init];
+ [pipeline_descriptor_obj setComputeFunction:function_obj];
+ [pipeline_descriptor_obj setThreadGroupSizeIsMultipleOfThreadExecutionWidth:YES];
+
+ pipeline_build_semaphore = dispatch_semaphore_create(/*value=*/0);
+ [device_obj newComputePipelineStateWithDescriptor:pipeline_descriptor_obj
+ options:MTLPipelineOptionNone
+ completionHandler:^(id _Nullable new_state,
+ MTLComputePipelineReflection* _Nullable reflection,
+ NSError* _Nullable error_obj) {
+ if (new_state != nil) {
+ pipeline_state_obj = [new_state retain];
+ }
+ if (error_obj != nil) {
+ error_string_obj = [[error_obj localizedDescription] copy];
+ }
+ dispatch_semaphore_signal(pipeline_build_semaphore);
+ }];
+ dispatch_semaphore_wait(pipeline_build_semaphore, DISPATCH_TIME_FOREVER);
+
if (pipeline_state_obj == nil) {
- error_string_obj = [error_obj localizedDescription];
+ const char* error_string = "unknown error";
+ if (error_string_obj != nil) {
+ error_string = [error_string_obj UTF8String];
+ }
GPTOSS_LOG_ERROR("failed to create Metal compute pipeline state for function %s: %s",
- name, [error_string_obj UTF8String]);
+ name, error_string);
status = gptoss_status_unsupported_system;
goto cleanup;
}
@@ -189,17 +211,20 @@ enum gptoss_status gptoss_metal_function_create(
pipeline_state_obj = nil;
cleanup:
- if (name_obj != nil) {
- [name_obj release];
- }
if (function_obj != nil) {
[function_obj release];
}
+ if (pipeline_descriptor_obj != nil) {
+ [pipeline_descriptor_obj release];
+ }
if (error_string_obj != nil) {
[error_string_obj release];
}
- if (error_obj != nil) {
- [error_obj release];
+ if (pipeline_build_semaphore != NULL) {
+ dispatch_release(pipeline_build_semaphore);
+ }
+ if (autorelease_pool != nil) {
+ [autorelease_pool drain];
}
return status;
}
@@ -380,9 +405,10 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
size_t num_threadgroups_z,
size_t params_size,
const void* params,
- size_t num_buffers,
- const struct gptoss_metal_buffer** buffers,
- const size_t* buffer_offsets)
+ size_t num_device_buffers,
+ const struct gptoss_metal_buffer** device_buffers,
+ const size_t* device_buffer_offsets,
+ size_t threadgroup_buffer_size)
{
if (command_buffer->object == NULL || function->pipeline_state_object == NULL) {
return gptoss_status_invalid_state;
@@ -396,11 +422,14 @@ enum gptoss_status gptoss_metal_command_buffer_encode_launch_kernel(
// Set kernel arguments
[command_encoder_obj setComputePipelineState:pipeline_state_obj];
[command_encoder_obj setBytes:params length:params_size atIndex:0];
- for (size_t i = 0; i < num_buffers; ++i) {
- id buffer_obj = (id) buffers[i]->object;
- const NSUInteger offset = buffer_offsets == NULL ? 0 : (NSUInteger) buffer_offsets[i];
+ for (size_t i = 0; i < num_device_buffers; ++i) {
+ id buffer_obj = (id) device_buffers[i]->object;
+ const NSUInteger offset = device_buffer_offsets == NULL ? 0 : (NSUInteger) device_buffer_offsets[i];
[command_encoder_obj setBuffer:buffer_obj offset:offset atIndex:i + 1];
}
+ if (threadgroup_buffer_size != 0) {
+ [command_encoder_obj setThreadgroupMemoryLength:threadgroup_buffer_size atIndex:0];
+ }
// Dispatch kernel
const MTLSize threadgroup_size = MTLSizeMake(threadgroup_size_x, threadgroup_size_y, threadgroup_size_z);
diff --git a/gpt_oss/metal/source/model.c b/gpt_oss/metal/source/model.c
index aba8a27e..469ef232 100644
--- a/gpt_oss/metal/source/model.c
+++ b/gpt_oss/metal/source/model.c
@@ -79,7 +79,8 @@ static void prefetch_fd(int fd, size_t offset, size_t size, const char* path) {
enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
const char* path,
- gptoss_model_t* model_out)
+ gptoss_model_t* model_out,
+ size_t max_batch_tokens)
{
*model_out = NULL;
@@ -192,7 +193,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
model->yarn_multiplier = model_header.yarn_multiplier;
model->rmsnorm_epsilon = model_header.rmsnorm_epsilon;
- model->max_batch_tokens = GPTOSS_DEFAULT_BATCH_SIZE;
+ model->max_batch_tokens = max_batch_tokens == 0 ? GPTOSS_DEFAULT_BATCH_SIZE : max_batch_tokens;
struct gptoss_uuid tokenizer_uuid;
status = read_fd(fd, &tokenizer_uuid, sizeof(tokenizer_uuid), path);
@@ -290,6 +291,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
prefetch_fd(fd, model_mapping_start, model_mapping_size, path);
+ if (mlock(model_mapping_ptr, model_mapping_size) != 0) {
+ GPTOSS_LOG_WARNING("mlock(%s, size=%zu) failed with error %d", path, model_mapping_size, errno);
+ } else {
+ model->lock_memory = true;
+ }
+
// Initialize Metal
status = gptoss_metal_device_create_system_default(&model->device);
if (status != gptoss_status_success) {
@@ -318,6 +325,22 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
if (status != gptoss_status_success) {
goto cleanup;
}
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_matmul_qkv", &model->f32_bf16w_matmul_qkv_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_qkv", &model->f32_bf16w_dense_matmul_qkv_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_attn_output", &model->f32_bf16w_dense_matmul_attn_output_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_dense_matmul_mlp_gate", &model->f32_bf16w_dense_matmul_mlp_gate_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
status = gptoss_metal_function_create(&model->library, "gptoss_f32_bf16w_unembedding", &model->f32_bf16w_unembedding_fn);
if (status != gptoss_status_success) {
goto cleanup;
@@ -350,11 +373,25 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
if (status != gptoss_status_success) {
goto cleanup;
}
+ status = gptoss_metal_function_create(&model->library, "gptoss_f32_sample", &model->f32_sample_fn);
+ if (status != gptoss_status_success) {
+ goto cleanup;
+ }
status = gptoss_metal_function_create(&model->library, "gptoss_f32_sdpa_q8_d64", &model->f32_sdpa_q8_d64_fn);
if (status != gptoss_status_success) {
goto cleanup;
}
+ // Kernel launch parameters
+ model->embeddings_threadgroup_size = 512;
+ model->attn_qkv_threadgroup_size = 1024;
+ model->attn_out_threadgroup_size = 768;
+ model->mlp_gate_threadgroup_size = 256;
+ model->mlp_swiglu_threadgroup_size = 192;
+ model->mlp_out_threadgroup_size = 192;
+ model->mlp_acc_threadgroup_size = 768;
+ model->unembedding_threadgroup_size = 416;
+
// Weight buffers
const char* current_ptr = (const char*) model->mapping_ptr;
@@ -421,45 +458,6 @@ enum gptoss_status GPTOSS_ABI gptoss_model_create_from_file(
model->weights_size += moe_block_weight_size;
}
- // Activation buffers
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->residual_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->embedding_dim * sizeof(float), NULL, &model->rmsnorm_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * (model->num_heads + 2 * model->num_kv_heads) * sizeof(float), NULL, &model->qkv_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->head_dim * model->num_heads * sizeof(float), NULL, &model->sdpa_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(float), NULL, &model->gate_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_experts * sizeof(struct gptoss_expert_prediction), NULL, &model->expert_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->mlp_dim * sizeof(float), NULL, &model->swiglu_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
- status = gptoss_metal_buffer_create(&model->device, model->max_batch_tokens * model->num_active_experts * model->embedding_dim * sizeof(float), NULL, &model->moe_activation_buffer);
- if (status != gptoss_status_success) {
- goto cleanup;
- }
-
- model->allocation_size =
- model->residual_activation_buffer.size + model->rmsnorm_activation_buffer.size +
- model->qkv_activation_buffer.size + model->sdpa_activation_buffer.size +
- model->gate_activation_buffer.size + model->expert_activation_buffer.size + model->swiglu_activation_buffer.size + model->moe_activation_buffer.size;
-
// Commit tokenizer
model->tokenizer = tokenizer;
tokenizer = NULL;
@@ -510,16 +508,6 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
if (atomic_fetch_sub_explicit(&model->ref_count, 1, memory_order_acq_rel) == 1) {
gptoss_tokenizer_release(model->tokenizer);
- // Activation buffers
- gptoss_metal_buffer_release(&model->residual_activation_buffer);
- gptoss_metal_buffer_release(&model->rmsnorm_activation_buffer);
- gptoss_metal_buffer_release(&model->qkv_activation_buffer);
- gptoss_metal_buffer_release(&model->sdpa_activation_buffer);
- gptoss_metal_buffer_release(&model->gate_activation_buffer);
- gptoss_metal_buffer_release(&model->expert_activation_buffer);
- gptoss_metal_buffer_release(&model->swiglu_activation_buffer);
- gptoss_metal_buffer_release(&model->moe_activation_buffer);
-
// Weight buffers
gptoss_metal_buffer_release(&model->shared_weight_buffer);
for (uint32_t n = 0; n < model->num_blocks; n++) {
@@ -530,6 +518,10 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
gptoss_metal_function_release(&model->bf16_f32_embeddings_fn);
gptoss_metal_function_release(&model->f32_bf16w_rmsnorm_fn);
gptoss_metal_function_release(&model->f32_bf16w_matmul_fn);
+ gptoss_metal_function_release(&model->f32_bf16w_matmul_qkv_fn);
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_qkv_fn);
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_attn_output_fn);
+ gptoss_metal_function_release(&model->f32_bf16w_dense_matmul_mlp_gate_fn);
gptoss_metal_function_release(&model->f32_bf16w_unembedding_fn);
gptoss_metal_function_release(&model->f32_rope_fn);
gptoss_metal_function_release(&model->f32_mf4w_moe_matmul_swiglu_fn);
@@ -538,6 +530,7 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
gptoss_metal_function_release(&model->f32_topk_softmax_e32_k4_fn);
gptoss_metal_function_release(&model->f32_topk_softmax_e128_k4_fn);
gptoss_metal_function_release(&model->f32_softmax_fn);
+ gptoss_metal_function_release(&model->f32_sample_fn);
gptoss_metal_function_release(&model->f32_sdpa_q8_d64_fn);
gptoss_metal_library_release(&model->library);
@@ -546,6 +539,12 @@ enum gptoss_status GPTOSS_ABI gptoss_model_release(
// Weight buffers
if (model->mapping_ptr != NULL && model->mapping_size != 0) {
+ if (model->lock_memory) {
+ if (munlock(model->mapping_ptr, model->mapping_size) != 0) {
+ GPTOSS_LOG_WARNING("munlock for model weight mapping failed with error %d", errno);
+ }
+ }
+
if (munmap(model->mapping_ptr, model->mapping_size) != 0) {
GPTOSS_LOG_WARNING("munmap for model weight mapping failed with error %d", errno);
}
diff --git a/gpt_oss/metal/source/moematmul.metal b/gpt_oss/metal/source/moematmul.metal
index 6e2f6950..58247484 100644
--- a/gpt_oss/metal/source/moematmul.metal
+++ b/gpt_oss/metal/source/moematmul.metal
@@ -24,6 +24,7 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
const device uchar* weight_scales [[ buffer(4) ]],
const device bfloat* bias [[ buffer(5) ]],
device float* output [[ buffer(6) ]],
+ const device gptoss_control* control [[ buffer(7) ]],
uint3 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
@@ -32,6 +33,9 @@ kernel void gptoss_f32_mf4w_moe_matmul_swiglu(
{
const uint simdgroup_size = 32;
threadgroup float threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
@@ -130,6 +134,7 @@ kernel void gptoss_f32_mf4w_moe_matmul(
const device uchar* weight_scales [[ buffer(4) ]],
const device bfloat* bias [[ buffer(5) ]],
device float* output [[ buffer(6) ]],
+ const device gptoss_control* control [[ buffer(7) ]],
uint3 gid [[threadgroup_position_in_grid]],
uint tid [[thread_index_in_threadgroup]],
uint simdgroup_tid [[thread_index_in_simdgroup]],
@@ -137,6 +142,9 @@ kernel void gptoss_f32_mf4w_moe_matmul(
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
const uint num_column_vecs = args.num_column_vecs;
const uint row = gid.x * num_simdgroups + simdgroup_idx;
diff --git a/gpt_oss/metal/source/rmsnorm.metal b/gpt_oss/metal/source/rmsnorm.metal
index ceb690f0..fc4bcaa2 100644
--- a/gpt_oss/metal/source/rmsnorm.metal
+++ b/gpt_oss/metal/source/rmsnorm.metal
@@ -14,12 +14,16 @@ kernel void gptoss_f32_bf16w_rmsnorm(
const device float4* input [[ buffer(1) ]],
const device bfloat4* weights [[ buffer(2) ]],
device float4* output [[ buffer(3) ]],
+ const device gptoss_control* control [[ buffer(4) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]],
uint threadgroup_size [[ threads_per_threadgroup ]])
{
const uint simdgroup_size = 32;
threadgroup float threadgroup_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
input += gid * args.num_vecs;
output += gid * args.num_vecs;
diff --git a/gpt_oss/metal/source/rope.metal b/gpt_oss/metal/source/rope.metal
index 2739b5fa..8bd2f568 100644
--- a/gpt_oss/metal/source/rope.metal
+++ b/gpt_oss/metal/source/rope.metal
@@ -13,17 +13,22 @@
kernel void gptoss_f32_rope(
constant gptoss_rope_args& args [[ buffer(0) ]],
device float2* activations [[ buffer(1) ]],
+ const device gptoss_control* control [[ buffer(2) ]],
uint2 gid [[thread_position_in_grid]])
{
const uint num_head_dims = 64;
- const float head_idx = static_cast(gid.x % (num_head_dims / 2));
+ if (control->abort != 0) {
+ return;
+ }
+
+ const float dim_idx = static_cast(gid.x % (num_head_dims / 2));
const uint token_idx = args.token_offset + gid.y;
activations += gid.y * args.token_stride + gid.x;
const float2 input_vals = *activations;
- const float inv_extrapolation_freq = metal::precise::exp(head_idx * args.freq_scale);
+ const float inv_extrapolation_freq = metal::precise::exp(dim_idx * args.freq_scale);
const float inv_interpolation_freq = inv_extrapolation_freq * args.interpolation_scale;
- const float alpha = metal::saturate(metal::fma(head_idx, args.yarn_scale, args.yarn_offset));
+ const float alpha = metal::saturate(metal::fma(dim_idx, args.yarn_scale, args.yarn_offset));
const float inv_freq = metal::mix(inv_extrapolation_freq, inv_interpolation_freq, alpha);
const float phi = static_cast(token_idx) * inv_freq;
@@ -32,7 +37,7 @@ kernel void gptoss_f32_rope(
const float sinphi = metal::precise::sincos(phi, cosphi) * yarn_multiplier;
cosphi *= yarn_multiplier;
- const float output_re = metal::fma(-input_vals.y, sinphi, input_vals.x * cosphi);
- const float output_im = metal::fma(input_vals.y, cosphi, input_vals.x * sinphi);
+ const float output_re = input_vals.x * cosphi - input_vals.y * sinphi;
+ const float output_im = input_vals.x * sinphi + input_vals.y * cosphi;
*activations = (float2) { output_re, output_im };
}
diff --git a/gpt_oss/metal/source/sample.metal b/gpt_oss/metal/source/sample.metal
index b739f72c..4a0efe3b 100644
--- a/gpt_oss/metal/source/sample.metal
+++ b/gpt_oss/metal/source/sample.metal
@@ -9,12 +9,34 @@
#pragma METAL fp contract(off)
+inline static uint rng_squares32(ulong offset, ulong seed) {
+ const ulong y = offset * seed;
+ const ulong z = y + seed;
+
+ /* Round 1 */
+ ulong x = y * y + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 2 */
+ x = x * x + z;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 3 */
+ x = x * x + y;
+ x = metal::rotate(x, 32ul);
+
+ /* Round 4 */
+ x = x * x + z;
+ return as_type(x).y;
+}
+
kernel void gptoss_f32_softmax(
constant gptoss_softmax_args& args [[ buffer(0) ]],
const device float* score [[ buffer(1) ]],
const device uint2* argmax [[ buffer(2) ]],
device float* prob [[ buffer(3) ]],
device float* sum [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(5) ]],
uint tidx [[thread_index_in_threadgroup]],
uint2 gid [[threadgroup_position_in_grid]],
uint2 threadgroup_size [[threads_per_threadgroup]],
@@ -23,6 +45,9 @@ kernel void gptoss_f32_softmax(
uint num_simdgroups [[simdgroups_per_threadgroup]])
{
threadgroup float threadgroup_sumexp[32];
+ if (control->abort != 0) {
+ return;
+ }
score += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
prob += gid.y * args.num_vecs + gid.x * args.num_vecs_per_threadgroup;
@@ -58,3 +83,127 @@ kernel void gptoss_f32_softmax(
}
}
}
+
+[[max_total_threads_per_threadgroup(1024)]]
+kernel void gptoss_f32_sample(
+ constant gptoss_sample_args& args [[ buffer(0) ]],
+ device const float* prob [[ buffer(1) ]],
+ device const float* sum [[ buffer(2) ]],
+ device uint* prediction [[ buffer(3) ]],
+ device gptoss_control* control [[ buffer(4) ]],
+ uint tid [[thread_position_in_threadgroup]],
+ uint threadgroup_size [[threads_per_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
+{
+ threadgroup float threadgroup_sum_buffer[32];
+ threadgroup uint threadgroup_idx_buffer[32];
+ threadgroup float threadgroup_cumsum_buffer[32];
+ if (control->abort != 0) {
+ return;
+ }
+
+ const uint sample_word = rng_squares32(args.rng_offset, args.rng_seed);
+ float sample_cdf = static_cast(sample_word & 0x00FFFFFFu) * 0x1.0p-24f;
+
+ float cumsum = 0.0f;
+ if (tid < args.num_blocks) {
+ cumsum = sum[tid];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+
+ sample_cdf *= threadgroup_sum;
+ sample_cdf = metal::max(sample_cdf, 0x1.0p-149f);
+
+ // Find the block: the smallest tid where sample_cdf >= s
+ uint block_idx = args.num_blocks;
+ float block_sum = cumsum;
+ if (tid >= args.num_blocks - 1) {
+ block_idx = args.num_blocks - 1;
+ block_sum = 0.0f;
+ } else if (cumsum >= sample_cdf) {
+ block_idx = tid;
+ block_sum = 0.0f;
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = block_idx;
+ threadgroup_cumsum_buffer[simdgroup_idx] = block_sum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ block_idx = threadgroup_idx_buffer[simdgroup_tid];
+ block_sum = threadgroup_cumsum_buffer[simdgroup_tid];
+ }
+ block_idx = metal::simd_min(block_idx);
+ block_sum = metal::simd_max(block_sum);
+
+ const uint block_start = args.num_dims_per_block * block_idx;
+ const uint block_end = metal::min(block_start + args.num_dims_per_block, args.num_dims);
+ uint offset = block_start + tid;
+ float accumulated_sum = block_sum;
+ uint sample_idx;
+
+ // This loop must be threadgroup-uniform.
+ do {
+ // Find the token: the smallest tid where sample_cdf >= s
+ float cumsum = 0.0f;
+ if (offset < block_end) {
+ cumsum = prob[offset];
+ }
+ cumsum = metal::simd_prefix_inclusive_sum(cumsum);
+ if (simdgroup_tid == 31) {
+ threadgroup_sum_buffer[simdgroup_idx] = cumsum;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ float threadgroup_cumsum = 0.0f, threadgroup_sum = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ threadgroup_sum = threadgroup_sum_buffer[simdgroup_tid];
+ if (simdgroup_tid < simdgroup_idx) {
+ threadgroup_cumsum = threadgroup_sum;
+ }
+ }
+ threadgroup_sum = metal::simd_sum(threadgroup_sum);
+ cumsum += metal::simd_sum(threadgroup_cumsum);
+ cumsum += accumulated_sum;
+
+ sample_idx = block_end;
+ if (offset >= block_end) {
+ // Trigger loop exit, with the last token in the block being sampled if no other candidate was found.
+ sample_idx = block_end - 1;
+ } else if (cumsum >= sample_cdf) {
+ sample_idx = offset;
+ }
+ sample_idx = metal::simd_min(sample_idx);
+ if (simdgroup_tid == 0) {
+ threadgroup_idx_buffer[simdgroup_idx] = sample_idx;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (simdgroup_tid < num_simdgroups) {
+ sample_idx = threadgroup_idx_buffer[simdgroup_tid];
+ }
+ sample_idx = metal::simd_min(sample_idx);
+
+ offset += threadgroup_size;
+ accumulated_sum += threadgroup_sum;
+ } while (sample_idx == block_end);
+
+ if (tid == 0) {
+ *prediction = sample_idx;
+ }
+}
diff --git a/gpt_oss/metal/source/sdpa.metal b/gpt_oss/metal/source/sdpa.metal
index 5564be6c..d112569f 100644
--- a/gpt_oss/metal/source/sdpa.metal
+++ b/gpt_oss/metal/source/sdpa.metal
@@ -11,28 +11,36 @@
// Each threadgroup handles 8 Q heads / 1 KV head for 1 token
-[[max_total_threads_per_threadgroup(32)]]
kernel void gptoss_f32_sdpa_q8_d64(
constant gptoss_sdpa_args& args [[ buffer(0) ]],
const device float* q [[ buffer(1) ]],
- const device float* k [[ buffer(2) ]],
- const device float* v [[ buffer(3) ]],
- const device bfloat* s [[ buffer(4) ]],
- device float* output [[ buffer(5) ]],
+ const device float* kv [[ buffer(2) ]],
+ const device bfloat* s [[ buffer(3) ]],
+ device float* output [[ buffer(4) ]],
+ const device gptoss_control* control [[ buffer(6) ]],
+ threadgroup void* threadgroup_buffer [[ threadgroup(0) ]],
uint2 gid [[threadgroup_position_in_grid]],
- uint tid [[thread_index_in_threadgroup]])
+ uint2 tid [[thread_position_in_threadgroup]],
+ uint simdgroup_tid [[thread_index_in_simdgroup]],
+ uint simdgroup_idx [[simdgroup_index_in_threadgroup]],
+ uint num_simdgroups [[simdgroups_per_threadgroup]])
{
+ const uint simdgroup_size = 32;
+ if (control->abort != 0) {
+ return;
+ }
+
const uint num_q_heads = 64;
- const uint num_kv_heads = 8;
const uint head_dim = 64;
const uint qmul = 8;
+ const uint token_stride = 2 * head_dim;
+
const uint qt = gid.x; // Q token index
const uint h = gid.y; // KV head index
q += qt * args.qkv_dim + h * (qmul * head_dim);
- k += h * head_dim;
- v += h * head_dim;
+ kv += h * args.kv_stride;
output += qt * (num_q_heads * head_dim) + h * (qmul * head_dim);
float m0 = static_cast(s[h * qmul + 0]);
@@ -44,14 +52,14 @@ kernel void gptoss_f32_sdpa_q8_d64(
float m6 = static_cast(s[h * qmul + 6]);
float m7 = static_cast(s[h * qmul + 7]);
- float l0 = 1.0f;
- float l1 = 1.0f;
- float l2 = 1.0f;
- float l3 = 1.0f;
- float l4 = 1.0f;
- float l5 = 1.0f;
- float l6 = 1.0f;
- float l7 = 1.0f;
+ float l0 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l1 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l2 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l3 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l4 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l5 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l6 = simdgroup_idx == 0 ? 1.0f : 0.0f;
+ float l7 = simdgroup_idx == 0 ? 1.0f : 0.0f;
float2 out0 = 0.0f;
float2 out1 = 0.0f;
@@ -62,22 +70,20 @@ kernel void gptoss_f32_sdpa_q8_d64(
float2 out6 = 0.0f;
float2 out7 = 0.0f;
- float2 q0 = reinterpret_cast(q + 0 * head_dim)[tid];
- float2 q1 = reinterpret_cast(q + 1 * head_dim)[tid];
- float2 q2 = reinterpret_cast(q + 2 * head_dim)[tid];
- float2 q3 = reinterpret_cast(q + 3 * head_dim)[tid];
- float2 q4 = reinterpret_cast(q + 4 * head_dim)[tid];
- float2 q5 = reinterpret_cast(q + 5 * head_dim)[tid];
- float2 q6 = reinterpret_cast(q + 6 * head_dim)[tid];
- float2 q7 = reinterpret_cast(q + 7 * head_dim)[tid];
+ float2 q0 = reinterpret_cast(q + 0 * head_dim)[simdgroup_tid];
+ float2 q1 = reinterpret_cast(q + 1 * head_dim)[simdgroup_tid];
+ float2 q2 = reinterpret_cast(q + 2 * head_dim)[simdgroup_tid];
+ float2 q3 = reinterpret_cast(q + 3 * head_dim)[simdgroup_tid];
+ float2 q4 = reinterpret_cast(q + 4 * head_dim)[simdgroup_tid];
+ float2 q5 = reinterpret_cast(q + 5 * head_dim)[simdgroup_tid];
+ float2 q6 = reinterpret_cast(q + 6 * head_dim)[simdgroup_tid];
+ float2 q7 = reinterpret_cast(q + 7 * head_dim)[simdgroup_tid];
const uint kt_end = qt + args.num_kv_tokens + 1;
- const uint kt_start = metal::subsat(kt_end, args.window);
- k += 2 * num_kv_heads * head_dim * kt_start;
- v += 2 * num_kv_heads * head_dim * kt_start;
- for (uint kt = kt_start; kt < kt_end; kt++) {
- const float2 kval = reinterpret_cast(k)[tid];
- k += 2 * num_kv_heads * head_dim;
+ const uint kt_start = metal::subsat(kt_end, args.window) + simdgroup_idx;
+ kv += token_stride * kt_start;
+ for (uint kt = kt_start; kt < kt_end; kt += num_simdgroups) {
+ const float2 kval = reinterpret_cast(kv)[simdgroup_tid];
float qk0 = metal::dot(q0, kval);
float qk1 = metal::dot(q1, kval);
@@ -142,8 +148,8 @@ kernel void gptoss_f32_sdpa_q8_d64(
m6 = new_m6;
m7 = new_m7;
- const float2 vval = reinterpret_cast(v)[tid];
- v += 2 * num_kv_heads * head_dim;
+ const float2 vval = reinterpret_cast(kv + head_dim)[simdgroup_tid];
+ kv += token_stride * num_simdgroups;
out0 = metal::fma(vval, qk0, out0 * alpha0);
out1 = metal::fma(vval, qk1, out1 * alpha1);
out2 = metal::fma(vval, qk2, out2 * alpha2);
@@ -153,12 +159,135 @@ kernel void gptoss_f32_sdpa_q8_d64(
out6 = metal::fma(vval, qk6, out6 * alpha6);
out7 = metal::fma(vval, qk7, out7 * alpha7);
}
- reinterpret_cast(output + 0 * head_dim)[tid] = out0 / l0;
- reinterpret_cast(output + 1 * head_dim)[tid] = out1 / l1;
- reinterpret_cast(output + 2 * head_dim)[tid] = out2 / l2;
- reinterpret_cast(output + 3 * head_dim)[tid] = out3 / l3;
- reinterpret_cast(output + 4 * head_dim)[tid] = out4 / l4;
- reinterpret_cast(output + 5 * head_dim)[tid] = out5 / l5;
- reinterpret_cast(output + 6 * head_dim)[tid] = out6 / l6;
- reinterpret_cast(output + 7 * head_dim)[tid] = out7 / l7;
+ if (num_simdgroups > 1) {
+ if (metal::simd_is_first()) {
+ static_cast(threadgroup_buffer)[0 * num_simdgroups + simdgroup_idx] = m0;
+ static_cast(threadgroup_buffer)[1 * num_simdgroups + simdgroup_idx] = m1;
+ static_cast(threadgroup_buffer)[2 * num_simdgroups + simdgroup_idx] = m2;
+ static_cast(threadgroup_buffer)[3 * num_simdgroups + simdgroup_idx] = m3;
+ static_cast(threadgroup_buffer)[4 * num_simdgroups + simdgroup_idx] = m4;
+ static_cast(threadgroup_buffer)[5 * num_simdgroups + simdgroup_idx] = m5;
+ static_cast(threadgroup_buffer)[6 * num_simdgroups + simdgroup_idx] = m6;
+ static_cast(threadgroup_buffer)[7 * num_simdgroups + simdgroup_idx] = m7;
+
+ static_cast(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_idx] = l0;
+ static_cast(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_idx] = l1;
+ static_cast(threadgroup_buffer)[10 * num_simdgroups + simdgroup_idx] = l2;
+ static_cast(threadgroup_buffer)[11 * num_simdgroups + simdgroup_idx] = l3;
+ static_cast(threadgroup_buffer)[12 * num_simdgroups + simdgroup_idx] = l4;
+ static_cast(threadgroup_buffer)[13 * num_simdgroups + simdgroup_idx] = l5;
+ static_cast(threadgroup_buffer)[14 * num_simdgroups + simdgroup_idx] = l6;
+ static_cast(threadgroup_buffer)[15 * num_simdgroups + simdgroup_idx] = l7;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ // Note: simdgroup refers not to the thread's current simdgroup, but to one with simdgroup_idx == thread's simdgroup_tid.
+ float simdgroup_m0 = m0;
+ float simdgroup_m1 = m1;
+ float simdgroup_m2 = m2;
+ float simdgroup_m3 = m3;
+ float simdgroup_m4 = m4;
+ float simdgroup_m5 = m5;
+ float simdgroup_m6 = m6;
+ float simdgroup_m7 = m7;
+ if (simdgroup_tid < num_simdgroups) {
+ simdgroup_m0 = static_cast(threadgroup_buffer)[0 * num_simdgroups + simdgroup_tid];
+ simdgroup_m1 = static_cast(threadgroup_buffer)[1 * num_simdgroups + simdgroup_tid];
+ simdgroup_m2 = static_cast(threadgroup_buffer)[2 * num_simdgroups + simdgroup_tid];
+ simdgroup_m3 = static_cast(threadgroup_buffer)[3 * num_simdgroups + simdgroup_tid];
+ simdgroup_m4 = static_cast(threadgroup_buffer)[4 * num_simdgroups + simdgroup_tid];
+ simdgroup_m5 = static_cast(threadgroup_buffer)[5 * num_simdgroups + simdgroup_tid];
+ simdgroup_m6 = static_cast(threadgroup_buffer)[6 * num_simdgroups + simdgroup_tid];
+ simdgroup_m7 = static_cast(threadgroup_buffer)[7 * num_simdgroups + simdgroup_tid];
+ }
+
+ const float threadgroup_m0 = metal::simd_max(simdgroup_m0);
+ const float threadgroup_m1 = metal::simd_max(simdgroup_m1);
+ const float threadgroup_m2 = metal::simd_max(simdgroup_m2);
+ const float threadgroup_m3 = metal::simd_max(simdgroup_m3);
+ const float threadgroup_m4 = metal::simd_max(simdgroup_m4);
+ const float threadgroup_m5 = metal::simd_max(simdgroup_m5);
+ const float threadgroup_m6 = metal::simd_max(simdgroup_m6);
+ const float threadgroup_m7 = metal::simd_max(simdgroup_m7);
+
+ out0 *= metal::fast::exp(m0 - threadgroup_m0);
+ out1 *= metal::fast::exp(m1 - threadgroup_m1);
+ out2 *= metal::fast::exp(m2 - threadgroup_m2);
+ out3 *= metal::fast::exp(m3 - threadgroup_m3);
+ out4 *= metal::fast::exp(m4 - threadgroup_m4);
+ out5 *= metal::fast::exp(m5 - threadgroup_m5);
+ out6 *= metal::fast::exp(m6 - threadgroup_m6);
+ out7 *= metal::fast::exp(m7 - threadgroup_m7);
+
+ if (simdgroup_idx == 0) {
+ l0 = 0.0f;
+ l1 = 0.0f;
+ l2 = 0.0f;
+ l3 = 0.0f;
+ l4 = 0.0f;
+ l5 = 0.0f;
+ l6 = 0.0f;
+ l7 = 0.0f;
+ if (simdgroup_tid < num_simdgroups) {
+ l0 = static_cast(threadgroup_buffer)[ 8 * num_simdgroups + simdgroup_tid];
+ l1 = static_cast(threadgroup_buffer)[ 9 * num_simdgroups + simdgroup_tid];
+ l2 = static_cast(threadgroup_buffer)[10 * num_simdgroups + simdgroup_tid];
+ l3 = static_cast(threadgroup_buffer)[11 * num_simdgroups + simdgroup_tid];
+ l4 = static_cast(threadgroup_buffer)[12 * num_simdgroups + simdgroup_tid];
+ l5 = static_cast(threadgroup_buffer)[13 * num_simdgroups + simdgroup_tid];
+ l6 = static_cast(threadgroup_buffer)[14 * num_simdgroups + simdgroup_tid];
+ l7 = static_cast(threadgroup_buffer)[15 * num_simdgroups + simdgroup_tid];
+ }
+
+ l0 = metal::simd_sum(l0 * metal::fast::exp(simdgroup_m0 - threadgroup_m0));
+ l1 = metal::simd_sum(l1 * metal::fast::exp(simdgroup_m1 - threadgroup_m1));
+ l2 = metal::simd_sum(l2 * metal::fast::exp(simdgroup_m2 - threadgroup_m2));
+ l3 = metal::simd_sum(l3 * metal::fast::exp(simdgroup_m3 - threadgroup_m3));
+ l4 = metal::simd_sum(l4 * metal::fast::exp(simdgroup_m4 - threadgroup_m4));
+ l5 = metal::simd_sum(l5 * metal::fast::exp(simdgroup_m5 - threadgroup_m5));
+ l6 = metal::simd_sum(l6 * metal::fast::exp(simdgroup_m6 - threadgroup_m6));
+ l7 = metal::simd_sum(l7 * metal::fast::exp(simdgroup_m7 - threadgroup_m7));
+ }
+
+ uint num_threads = num_simdgroups * simdgroup_size;
+ do {
+ const uint num_smem_threads = (num_threads / 2) & -simdgroup_size;
+ const uint num_half_threads = num_threads - num_smem_threads;
+
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ const uint smem_tid = tid.x - num_half_threads;
+ if (smem_tid < num_smem_threads) {
+ static_cast(threadgroup_buffer)[num_smem_threads * 0 + smem_tid] = out0;
+ static_cast(threadgroup_buffer)[num_smem_threads * 1 + smem_tid] = out1;
+ static_cast(threadgroup_buffer)[num_smem_threads * 2 + smem_tid] = out2;
+ static_cast(threadgroup_buffer)[num_smem_threads * 3 + smem_tid] = out3;
+ static_cast(threadgroup_buffer)[num_smem_threads * 4 + smem_tid] = out4;
+ static_cast(threadgroup_buffer)[num_smem_threads * 5 + smem_tid] = out5;
+ static_cast(threadgroup_buffer)[num_smem_threads * 6 + smem_tid] = out6;
+ static_cast(threadgroup_buffer)[num_smem_threads * 7 + smem_tid] = out7;
+ }
+ metal::threadgroup_barrier(metal::mem_flags::mem_threadgroup);
+ if (tid.x < num_smem_threads) {
+ out0 += static_cast(threadgroup_buffer)[num_smem_threads * 0 + tid.x];
+ out1 += static_cast(threadgroup_buffer)[num_smem_threads * 1 + tid.x];
+ out2 += static_cast(threadgroup_buffer)[num_smem_threads * 2 + tid.x];
+ out3 += static_cast(threadgroup_buffer)[num_smem_threads * 3 + tid.x];
+ out4 += static_cast(threadgroup_buffer)[num_smem_threads * 4 + tid.x];
+ out5 += static_cast(threadgroup_buffer)[num_smem_threads * 5 + tid.x];
+ out6 += static_cast(threadgroup_buffer)[num_smem_threads * 6 + tid.x];
+ out7 += static_cast(threadgroup_buffer)[num_smem_threads * 7 + tid.x];
+ }
+
+ num_threads = num_half_threads;
+ } while (num_threads > simdgroup_size);
+ }
+ if (simdgroup_idx == 0) {
+ reinterpret_cast(output + 0 * head_dim)[simdgroup_tid] = out0 / l0;
+ reinterpret_cast(output + 1 * head_dim)[simdgroup_tid] = out1 / l1;
+ reinterpret_cast(output + 2 * head_dim)[simdgroup_tid] = out2 / l2;
+ reinterpret_cast(output + 3 * head_dim)[simdgroup_tid] = out3 / l3;
+ reinterpret_cast(output + 4 * head_dim)[simdgroup_tid] = out4 / l4;
+ reinterpret_cast(output + 5 * head_dim)[simdgroup_tid] = out5 / l5;
+ reinterpret_cast(output + 6 * head_dim)[simdgroup_tid] = out6 / l6;
+ reinterpret_cast(output + 7 * head_dim)[simdgroup_tid] = out7 / l7;
+ }
}
diff --git a/gpt_oss/metal/source/topk.metal b/gpt_oss/metal/source/topk.metal
index d3532ac6..90f4e51c 100644
--- a/gpt_oss/metal/source/topk.metal
+++ b/gpt_oss/metal/source/topk.metal
@@ -14,11 +14,15 @@ kernel void gptoss_f32_topk_softmax_e128_k4(
constant gptoss_topk_args& args [[ buffer(0) ]],
const device float4* input [[ buffer(1) ]],
device gptoss_expert_prediction* output [[ buffer(2) ]],
+ const device gptoss_control* control [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]])
{
const uint num_experts = 128;
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
input += gid * (num_experts / 4);
output += gid * num_active_experts;
@@ -132,11 +136,15 @@ kernel void gptoss_f32_topk_softmax_e32_k4(
constant gptoss_topk_args& args [[ buffer(0) ]],
const device float* input [[ buffer(1) ]],
device gptoss_expert_prediction* output [[ buffer(2) ]],
+ const device gptoss_control* control [[ buffer(3) ]],
uint gid [[threadgroup_position_in_grid]],
uint tid [[thread_position_in_threadgroup]])
{
const uint num_experts = 32;
const uint num_active_experts = 4;
+ if (control->abort != 0) {
+ return;
+ }
input += gid * num_experts;
output += gid * num_active_experts;
diff --git a/gpt_oss/metal/test/embeddings-kernel-tester.hpp b/gpt_oss/metal/test/embeddings-kernel-tester.hpp
index fd810c6d..83092a8c 100644
--- a/gpt_oss/metal/test/embeddings-kernel-tester.hpp
+++ b/gpt_oss/metal/test/embeddings-kernel-tester.hpp
@@ -69,6 +69,8 @@ class EmbeddingsKernelTester {
metal::Buffer token_buffer{device_, sizeof(std::uint32_t)};
metal::Buffer weight_buffer{device_, vocabulary_size() * num_channels() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_channels() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
std::uint32_t* token_ptr = static_cast(token_buffer.ptr());
for (std::uint32_t t = 0; t < num_tokens(); t++) {
@@ -85,6 +87,8 @@ class EmbeddingsKernelTester {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens(),
num_channels()),
"gptoss_metal_command_buffer_encode_launch_bf16_f32_embeddings");
diff --git a/gpt_oss/metal/test/f32-bf16w-matmul.cc b/gpt_oss/metal/test/f32-bf16w-matmul.cc
index 9692e6a7..745bff2e 100644
--- a/gpt_oss/metal/test/f32-bf16w-matmul.cc
+++ b/gpt_oss/metal/test/f32-bf16w-matmul.cc
@@ -58,3 +58,30 @@ TEST(F32_BF16W_MATMUL, multiple_tokens) {
.threadgroup_size(threadgroup_size)
.TestF32_BF16W();
}
+
+TEST(F32_BF16W_DENSE_MATMUL_QKV, seq_len_1024) {
+ MatMulKernelTester()
+ .num_tokens(1024)
+ .num_rows(5120)
+ .num_cols(2880)
+ .TestF32_BF16W(
+ MatMulKernelTester::MatMulKernelType::PREFILL_QKV_OPTIMIZED);
+}
+
+TEST(F32_BF16W_DENSE_MATMUL_ATTN_OUTPUT, seq_len_1024) {
+ MatMulKernelTester()
+ .num_tokens(1024)
+ .num_rows(2880)
+ .num_cols(4096)
+ .TestF32_BF16W(
+ MatMulKernelTester::MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED);
+}
+
+TEST(F32_BF16W_DENSE_MATMUL_MLP_GATE, seq_len_1024) {
+ MatMulKernelTester()
+ .num_tokens(1024)
+ .num_rows(128)
+ .num_cols(2880)
+ .TestF32_BF16W(
+ MatMulKernelTester::MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED);
+}
\ No newline at end of file
diff --git a/gpt_oss/metal/test/matmul-kernel-tester.hpp b/gpt_oss/metal/test/matmul-kernel-tester.hpp
index ec13af6b..f5958c7b 100644
--- a/gpt_oss/metal/test/matmul-kernel-tester.hpp
+++ b/gpt_oss/metal/test/matmul-kernel-tester.hpp
@@ -10,9 +10,39 @@
#include
#include
-
namespace gptoss {
+template
+::testing::AssertionResult
+IsNearAbsRel(const char* a_expr, const char* b_expr, const char* abs_expr,
+ const char* rel_expr, T a, T b, T abs_tol, T rel_tol = 1.0) {
+
+ using std::abs;
+ if (!std::isfinite(a) || !std::isfinite(b)) {
+ return ::testing::AssertionFailure()
+ << "Non-finite value(s): " << a_expr << "=" << a << ", " << b_expr
+ << "=" << b;
+ // At least one of abs_tol and rel_tol must be provided
+ }
+ const T diff = abs(a - b);
+ const T rel = rel_tol * std::max(abs(a), abs(b));
+ const T thr = std::max(abs_tol, rel);
+
+ if (diff <= thr)
+ return ::testing::AssertionSuccess();
+
+ return ::testing::AssertionFailure()
+ << a_expr << " vs " << b_expr << " differ by " << diff
+ << " > max(abs_tol=" << abs_tol << ", rel_tol*max(|a|,|b|)=" << rel
+ << ") with " << abs_expr << "=" << abs_tol << ", " << rel_expr << "="
+ << rel_tol << ". \n"
+ << a_expr << "=" << a << ". \n"
+ << b_expr << "=" << b;
+}
+
+#define ASSERT_NEAR_ABS_REL(a, b, abs_tol, rel_tol) \
+ ASSERT_PRED_FORMAT4(IsNearAbsRel, a, b, abs_tol, rel_tol)
+
class MatMulKernelTester {
public:
MatMulKernelTester() { }
@@ -70,16 +100,26 @@ class MatMulKernelTester {
ASSERT_NE(threadgroup_size(), 0);
}
- void TestF32_BF16W() const {
+ enum class MatMulKernelType {
+ DECODE_OPTIMIZED,
+ PREFILL_QKV_OPTIMIZED,
+ PREFILL_ATTN_OUTPUT_OPTIMIZED,
+ PREFILL_MLP_GATE_OPTIMIZED,
+ };
+
+ void TestF32_BF16W(MatMulKernelType kernel_type = MatMulKernelType::DECODE_OPTIMIZED) const {
Validate(/*vec_size=*/4);
- metal::CommandBuffer command_buffer{command_queue_};
+ metal::CommandBuffer command_buffer_initialize{command_queue_};
metal::Buffer input_buffer{device_, num_tokens() * num_cols() * sizeof(float)};
metal::Buffer weight_buffer{device_, num_rows() * num_cols() * sizeof(gptoss_bfloat16)};
metal::Buffer bias_buffer{device_, num_rows() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_tokens() * num_rows() * sizeof(float)};
+ metal::Buffer output_buffer_copy{device_, num_tokens() * num_rows() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
- command_buffer.encode_launch_f32_fill_random(
+ command_buffer_initialize.encode_launch_f32_fill_random(
f32_fill_random_fn_,
/*threadgroup_size=*/0,
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
@@ -87,7 +127,7 @@ class MatMulKernelTester {
/*output_offset=*/0,
num_tokens() * num_cols(), kSeed, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
- command_buffer.encode_launch_bf16_fill_random(
+ command_buffer_initialize.encode_launch_bf16_fill_random(
bf16_fill_random_fn_,
/*threadgroup_size=*/0,
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
@@ -95,7 +135,7 @@ class MatMulKernelTester {
/*output_offset=*/0,
num_rows() * num_cols(), kSeed + 1, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
- command_buffer.encode_launch_bf16_fill_random(
+ command_buffer_initialize.encode_launch_bf16_fill_random(
bf16_fill_random_fn_,
/*threadgroup_size=*/0,
/*max_threadgroups=*/kFillRandomMaxThreadgroups,
@@ -103,30 +143,90 @@ class MatMulKernelTester {
/*output_offset=*/0,
num_rows(), kSeed + 2, /*offset=*/0, /*min=*/-1.0f, /*max=*/1.0);
- Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
- command_buffer.handle(),
- f32_bf16w_matmul_fn_.handle(),
- /*threadgroup_size=*/threadgroup_size(),
- input_buffer.handle(),
- /*input_offset=*/0,
- weight_buffer.handle(),
- /*weight_offset=*/0,
- bias_buffer.handle(),
- /*bias_offset=*/0,
- output_buffer.handle(),
- /*output_offset=*/0,
- num_tokens(),
- num_cols(),
- num_rows()),
- "gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul");
-
- command_buffer.commit();
- command_buffer.wait_completion();
+ // Fill output buffer with random values to test matmul with add.
+ command_buffer_initialize.encode_launch_f32_fill_random(
+ f32_fill_random_fn_,
+ /*threadgroup_size=*/0,
+ /*max_threadgroups=*/kFillRandomMaxThreadgroups,
+ /*output_buffer=*/output_buffer,
+ /*output_offset=*/0, num_tokens() * num_rows(), kSeed + 3,
+ /*offset=*/0,
+ /*min=*/-1.0f, /*max=*/1.0);
+ command_buffer_initialize.commit();
+ command_buffer_initialize.wait_completion();
+ if (kernel_type ==
+ MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) {
+ // Copy output buffer to output buffer copy to use when calculating reference.
+ const uint64_t bytes =
+ uint64_t(num_tokens()) * uint64_t(num_rows()) * sizeof(float);
+ void* src = output_buffer.ptr();
+ void* dst = output_buffer_copy.ptr();
+ assert(src && dst && "Buffers must be CPU-mappable for memcpy");
+
+ std::memcpy(reinterpret_cast(dst),
+ reinterpret_cast(src), bytes);
+ }
+
+ metal::CommandBuffer command_buffer_compute{command_queue_};
+ switch (kernel_type) {
+ case MatMulKernelType::DECODE_OPTIMIZED:
+ Check(gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul(
+ command_buffer_compute.handle(), f32_bf16w_matmul_fn_.handle(),
+ /*threadgroup_size=*/threadgroup_size(), input_buffer.handle(),
+ /*input_offset=*/0, weight_buffer.handle(),
+ /*weight_offset=*/0, bias_buffer.handle(),
+ /*bias_offset=*/0, output_buffer.handle(),
+ /*output_offset=*/0, control_buffer.handle(),
+ /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),
+ "gptoss_metal_command_buffer_encode_launch_f32_bf16w_matmul");
+ break;
+ case MatMulKernelType::PREFILL_QKV_OPTIMIZED:
+ Check(
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv(
+ command_buffer_compute.handle(),
+ f32_bf16w_dense_matmul_qkv_fn_.handle(), input_buffer.handle(),
+ /*input_offset=*/0, weight_buffer.handle(),
+ /*weight_offset=*/0, bias_buffer.handle(),
+ /*bias_offset=*/0, output_buffer.handle(),
+ /*output_offset=*/0, control_buffer.handle(),
+ /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),
+ "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_qkv");
+ break;
+ case MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED:
+ Check(
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output(
+ command_buffer_compute.handle(),
+ f32_bf16w_dense_matmul_attn_output_fn_.handle(),
+ input_buffer.handle(),
+ /*input_offset=*/0, weight_buffer.handle(),
+ /*weight_offset=*/0, bias_buffer.handle(),
+ /*bias_offset=*/0, output_buffer.handle(),
+ /*output_offset=*/0, control_buffer.handle(),
+ /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),
+ "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_attn_output");
+ break;
+ case MatMulKernelType::PREFILL_MLP_GATE_OPTIMIZED:
+ Check(
+ gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate(
+ command_buffer_compute.handle(),
+ f32_bf16w_dense_matmul_mlp_gate_fn_.handle(),
+ input_buffer.handle(),
+ /*input_offset=*/0, weight_buffer.handle(),
+ /*weight_offset=*/0, bias_buffer.handle(),
+ /*bias_offset=*/0, output_buffer.handle(),
+ /*output_offset=*/0, control_buffer.handle(),
+ /*control_offset=*/0, num_tokens(), num_cols(), num_rows()),
+ "gptoss_metal_command_buffer_encode_launch_f32_bf16w_dense_matmul_mlp_gate");
+ break;
+ }
+ command_buffer_compute.commit();
+ command_buffer_compute.wait_completion();
const float* input_ptr = static_cast(input_buffer.ptr());
const gptoss_bfloat16* weight_ptr = static_cast(weight_buffer.ptr());
const gptoss_bfloat16* bias_ptr = static_cast(bias_buffer.ptr());
const float* output_ptr = static_cast(output_buffer.ptr());
+ const float* output_ptr_copy = static_cast(output_buffer_copy.ptr());
for (size_t t = 0; t < num_tokens(); t++) {
for (size_t r = 0; r < num_rows(); r++) {
double ref_sum = upcast(bias_ptr[r]);
@@ -135,7 +235,13 @@ class MatMulKernelTester {
const double input_value = upcast(input_ptr[t * num_cols() + c]);
ref_sum = std::fma(input_value, ref_weight, ref_sum);
}
- ASSERT_NEAR(upcast(output_ptr[t * num_rows() + r]), ref_sum, std::abs(ref_sum) * 1.0e-5)
+
+ if (kernel_type ==
+ MatMulKernelType::PREFILL_ATTN_OUTPUT_OPTIMIZED) {
+ ref_sum += upcast(output_ptr_copy[t * num_rows() + r]);
+ }
+ ASSERT_NEAR_ABS_REL(upcast(output_ptr[t * num_rows() + r]),
+ ref_sum, 2.0e-4, 1.0e-4)
<< "token " << t;
}
}
@@ -155,6 +261,9 @@ class MatMulKernelTester {
metal::Function f32_fill_random_fn_{library_, "gptoss_f32_fill_random"};
metal::Function bf16_fill_random_fn_{library_, "gptoss_bf16_fill_random"};
metal::Function f32_bf16w_matmul_fn_{library_, "gptoss_f32_bf16w_matmul"};
+ metal::Function f32_bf16w_dense_matmul_qkv_fn_{library_, "gptoss_f32_bf16w_dense_matmul_qkv"};
+ metal::Function f32_bf16w_dense_matmul_attn_output_fn_{library_, "gptoss_f32_bf16w_dense_matmul_attn_output"};
+ metal::Function f32_bf16w_dense_matmul_mlp_gate_fn_{library_, "gptoss_f32_bf16w_dense_matmul_mlp_gate"};
std::uint32_t num_tokens_{1};
std::uint32_t num_rows_{1};
std::uint32_t num_cols_{32};
diff --git a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
index 16a6da64..3111eecb 100644
--- a/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
+++ b/gpt_oss/metal/test/rmsnorm-kernel-tester.hpp
@@ -64,6 +64,8 @@ class RMSNormKernelTester {
metal::Buffer input_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
metal::Buffer weight_buffer{device_, num_channels() * sizeof(gptoss_bfloat16)};
metal::Buffer output_buffer{device_, num_tokens() * num_channels() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
metal::CommandBuffer command_buffer{command_queue_};
@@ -90,6 +92,8 @@ class RMSNormKernelTester {
/*weight_offset=*/0,
output_buffer.handle(),
/*output_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
num_tokens(),
num_channels(),
epsilon()),
diff --git a/gpt_oss/metal/test/rope-kernel-tester.hpp b/gpt_oss/metal/test/rope-kernel-tester.hpp
index 602912a1..cb930621 100644
--- a/gpt_oss/metal/test/rope-kernel-tester.hpp
+++ b/gpt_oss/metal/test/rope-kernel-tester.hpp
@@ -112,6 +112,8 @@ class RoPEKernelTester {
metal::Buffer activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
metal::Buffer ref_activations_buffer{device_, (num_tokens() * num_qkv_heads() + num_qk_heads()) * head_dim() * sizeof(float)};
+ metal::Buffer control_buffer{device_, sizeof(gptoss_control)};
+ std::memset(control_buffer.ptr(), 0, sizeof(gptoss_control));
metal::CommandBuffer command_buffer{command_queue_};
@@ -138,6 +140,9 @@ class RoPEKernelTester {
f32_rope_fn_.handle(),
threadgroup_size(),
activations_buffer.handle(),
+ /*activations_offset=*/0,
+ control_buffer.handle(),
+ /*control_offset=*/0,
frequency_base(),
/*interpolation_scale=*/1.0f,
/*yarn_offset=*/0.0f,
diff --git a/gpt_oss/responses_api/api_server.py b/gpt_oss/responses_api/api_server.py
index 908d86c0..8eb053f1 100644
--- a/gpt_oss/responses_api/api_server.py
+++ b/gpt_oss/responses_api/api_server.py
@@ -1,8 +1,7 @@
-import asyncio
+import os
import datetime
import uuid
from typing import Callable, Literal, Optional
-import json
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
@@ -20,29 +19,32 @@
ToolDescription,
)
+from gpt_oss.tools.python_docker.docker_tool import PythonTool
from gpt_oss.tools.simple_browser import SimpleBrowserTool
-from gpt_oss.tools.simple_browser.backend import ExaBackend
+from gpt_oss.tools.simple_browser.backend import YouComBackend, ExaBackend
from .events import (
+ ResponseCodeInterpreterCallCompleted,
+ ResponseCodeInterpreterCallInProgress,
ResponseCompletedEvent,
+ ResponseContentPartAdded,
+ ResponseContentPartDone,
ResponseCreatedEvent,
- ResponseInProgressEvent,
ResponseEvent,
+ ResponseInProgressEvent,
ResponseOutputItemAdded,
ResponseOutputItemDone,
- ResponseContentPartAdded,
- ResponseContentPartDone,
- ResponseOutputTextDone,
+ ResponseOutputTextAnnotationAdded,
ResponseOutputTextDelta,
- ResponseReasoningTextDone,
+ ResponseOutputTextDone,
ResponseReasoningTextDelta,
+ ResponseReasoningTextDone,
+ ResponseWebSearchCallCompleted,
ResponseWebSearchCallInProgress,
ResponseWebSearchCallSearching,
- ResponseWebSearchCallCompleted,
- ResponseOutputTextAnnotationAdded
)
from .types import (
- UrlCitation,
+ CodeInterpreterCallItem,
Error,
FunctionCallItem,
Item,
@@ -51,11 +53,12 @@
ResponseObject,
ResponsesRequest,
TextContentItem,
+ UrlCitation,
Usage,
- WebSearchCallItem,
- WebSearchActionSearch,
- WebSearchActionOpenPage,
WebSearchActionFind,
+ WebSearchActionOpenPage,
+ WebSearchActionSearch,
+ WebSearchCallItem,
)
DEFAULT_TEMPERATURE = 0.0
@@ -64,14 +67,20 @@
def get_reasoning_effort(effort: Literal["low", "medium", "high"]) -> ReasoningEffort:
if effort == "low":
return ReasoningEffort.LOW
- elif effort == "medium":
+ if effort == "medium":
return ReasoningEffort.MEDIUM
- elif effort == "high":
+ if effort == "high":
return ReasoningEffort.HIGH
+ raise ValueError(f"Invalid reasoning effort: {effort}")
def is_not_builtin_tool(recipient: str) -> bool:
- return not recipient.startswith("browser.") and not recipient == "python" and not recipient == "assistant"
+ return (
+ not recipient.startswith("browser.")
+ and not recipient == "python"
+ and not recipient == "assistant"
+ )
+
def create_api_server(
infer_next_token: Callable[[list[int], float], int], encoding: HarmonyEncoding
@@ -89,6 +98,8 @@ def generate_response(
previous_response_id: Optional[str] = None,
browser_tool: Optional[SimpleBrowserTool] = None,
browser_call_ids: Optional[list[str]] = None,
+ python_tool: Optional[PythonTool] = None,
+ python_call_ids: Optional[list[str]] = None,
) -> ResponseObject:
output = []
error = None
@@ -112,9 +123,12 @@ def generate_response(
fc_index = 0
browser_tool_index = 0
+ python_tool_index = 0
for entry in entries:
entry_dict = entry.to_dict()
- if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(entry_dict["recipient"]):
+ if len(entry_dict.get("recipient", "")) > 0 and is_not_builtin_tool(
+ entry_dict["recipient"]
+ ):
call = entry_dict["content"][0]
arguments = call["text"]
name = entry_dict["recipient"]
@@ -138,12 +152,16 @@ def generate_response(
call_id=call_id,
)
)
- elif len(entry_dict.get("recipient", "")) > 0 and entry_dict["recipient"].startswith("browser.") and browser_tool is not None:
+ elif (
+ len(entry_dict.get("recipient", "")) > 0
+ and entry_dict["recipient"].startswith("browser.")
+ and browser_tool is not None
+ ):
# Mirror event-based creation of WebSearchCallItems when the browser tool is invoked
name = entry_dict["recipient"]
call = entry_dict["content"][0]
arguments = call["text"]
- function_name = name[len("browser."):]
+ function_name = name[len("browser.") :]
# Reconstruct a Message for argument parsing
tool_msg = (
@@ -176,7 +194,9 @@ def generate_response(
action = None
if action is not None:
- if browser_call_ids and browser_tool_index < len(browser_call_ids):
+ if browser_call_ids and browser_tool_index < len(
+ browser_call_ids
+ ):
web_search_call_id = browser_call_ids[browser_tool_index]
else:
web_search_call_id = f"ws_{uuid.uuid4().hex}"
@@ -188,11 +208,29 @@ def generate_response(
action=action,
)
)
+ elif (
+ len(entry_dict.get("recipient", "")) > 0
+ and entry_dict["recipient"].startswith("python")
+ and python_tool is not None
+ ):
+ if python_call_ids and python_tool_index < len(python_call_ids):
+ code_call_id = python_call_ids[python_tool_index]
+ else:
+ code_call_id = f"ci_{uuid.uuid4().hex}"
+ python_tool_index += 1
+ output.append(
+ CodeInterpreterCallItem(
+ type="code_interpreter_call",
+ id=code_call_id,
+ )
+ )
elif entry_dict["channel"] == "final":
content = []
- for content_entry in entry_dict["content"]:
+ for content_entry in entry_dict["content"]:
if browser_tool:
- text_content, annotation_entries, _has_partial_citations = browser_tool.normalize_citations(content_entry["text"])
+ text_content, annotation_entries, _has_partial_citations = (
+ browser_tool.normalize_citations(content_entry["text"])
+ )
annotations = [UrlCitation(**a) for a in annotation_entries]
else:
text_content = content_entry["text"]
@@ -287,7 +325,6 @@ class StreamResponsesEvents:
request_body: ResponsesRequest
request: Request
sequence_number: int
-
def __init__(
self,
@@ -300,6 +337,7 @@ def __init__(
Callable[[str, ResponsesRequest, ResponseObject], None]
] = None,
browser_tool: Optional[SimpleBrowserTool] = None,
+ python_tool: Optional[PythonTool] = None,
):
self.initial_tokens = initial_tokens
self.tokens = initial_tokens.copy()
@@ -326,6 +364,9 @@ def __init__(
self.browser_tool = browser_tool
self.use_browser_tool = browser_tool is not None
self.browser_call_ids: list[str] = []
+ self.python_tool = python_tool
+ self.use_code_interpreter = python_tool is not None
+ self.python_call_ids: list[str] = []
def _send_event(self, event: ResponseEvent):
event.sequence_number = self.sequence_number
@@ -345,6 +386,10 @@ async def run(self):
function_call_ids=self.function_call_ids,
response_id=self.response_id,
previous_response_id=self.request_body.previous_response_id,
+ browser_tool=self.browser_tool,
+ browser_call_ids=self.browser_call_ids,
+ python_tool=self.python_tool,
+ python_call_ids=self.python_call_ids,
)
initial_response.status = "in_progress"
yield self._send_event(
@@ -367,9 +412,9 @@ async def run(self):
sent_output_item_added = False
# we use this if the model outputs a citation to buffer until completed
- output_delta_buffer = ""
+ output_delta_buffer = ""
# we use this to track the current output text content for things like providing the right indices in citations
- current_output_text_content = ""
+ current_output_text_content = ""
current_annotations = []
while True:
@@ -386,7 +431,7 @@ async def run(self):
self.tokens.append(next_tok)
try:
self.parser.process(next_tok)
- except Exception as e:
+ except Exception:
pass
if self.parser.state == StreamState.EXPECT_START:
@@ -462,9 +507,17 @@ async def run(self):
)
)
if previous_item.channel == "final":
- annotations = [UrlCitation(**a) for a in current_annotations]
+ annotations = [
+ UrlCitation(**a) for a in current_annotations
+ ]
if browser_tool:
- normalized_text, _annotations, _has_partial_citations = browser_tool.normalize_citations(previous_item.content[0].text)
+ (
+ normalized_text,
+ _annotations,
+ _has_partial_citations,
+ ) = browser_tool.normalize_citations(
+ previous_item.content[0].text
+ )
else:
normalized_text = previous_item.content[0].text
annotations = []
@@ -530,14 +583,26 @@ async def run(self):
should_send_output_text_delta = True
if browser_tool:
# we normalize on the full current text to get the right indices in citations
- updated_output_text, annotations, has_partial_citations = browser_tool.normalize_citations(current_output_text_content + output_delta_buffer)
+ updated_output_text, annotations, has_partial_citations = (
+ browser_tool.normalize_citations(
+ current_output_text_content + output_delta_buffer
+ )
+ )
# remove the current text to get back the delta but now normalized
- output_delta_buffer = updated_output_text[len(current_output_text_content):]
-
+ output_delta_buffer = updated_output_text[
+ len(current_output_text_content) :
+ ]
+
# Filter annotations to only include those whose start_index is not already present in current_annotations
# this is to avoid sending duplicate annotations as multiple annotations can't be in the same place
- existing_start_indices = {a["start_index"] for a in current_annotations}
- new_annotations = [a for a in annotations if a["start_index"] not in existing_start_indices]
+ existing_start_indices = {
+ a["start_index"] for a in current_annotations
+ }
+ new_annotations = [
+ a
+ for a in annotations
+ if a["start_index"] not in existing_start_indices
+ ]
for a in new_annotations:
current_annotations.append(a)
citation = UrlCitation(**a)
@@ -554,7 +619,6 @@ async def run(self):
if has_partial_citations:
should_send_output_text_delta = False
-
if should_send_output_text_delta:
yield self._send_event(
ResponseOutputTextDelta(
@@ -588,7 +652,9 @@ async def run(self):
type="response.content_part.added",
output_index=current_output_index,
content_index=current_content_index,
- part=ReasoningTextContentItem(type="reasoning_text", text=""),
+ part=ReasoningTextContentItem(
+ type="reasoning_text", text=""
+ ),
)
)
yield self._send_event(
@@ -617,7 +683,7 @@ async def run(self):
and last_message.recipient is not None
and last_message.recipient.startswith("browser.")
):
- function_name = last_message.recipient[len("browser."):]
+ function_name = last_message.recipient[len("browser.") :]
action = None
parsed_args = browser_tool.process_arguments(last_message)
if function_name == "search":
@@ -628,32 +694,42 @@ async def run(self):
elif function_name == "open":
action = WebSearchActionOpenPage(
type="open_page",
- url=parsed_args["url"] if "url" in parsed_args else None,
+ url=(
+ parsed_args["url"]
+ if "url" in parsed_args
+ else None
+ ),
)
elif function_name == "find":
action = WebSearchActionFind(
type="find",
pattern=parsed_args["pattern"],
- url=parsed_args["url"] if "url" in parsed_args else None,
+ url=(
+ parsed_args["url"]
+ if "url" in parsed_args
+ else None
+ ),
)
if action is not None:
web_search_call_id = f"ws_{uuid.uuid4().hex}"
self.browser_call_ids.append(web_search_call_id)
- yield self._send_event(ResponseOutputItemAdded(
- type="response.output_item.added",
- output_index=current_output_index,
- item=WebSearchCallItem(
- type="web_search_call",
- id=web_search_call_id,
- action=action,
- ),
- ))
+ yield self._send_event(
+ ResponseOutputItemAdded(
+ type="response.output_item.added",
+ output_index=current_output_index,
+ item=WebSearchCallItem(
+ type="web_search_call",
+ id=web_search_call_id,
+ action=action,
+ ),
+ )
+ )
yield self._send_event(
ResponseWebSearchCallInProgress(
type="response.web_search_call.in_progress",
output_index=current_output_index,
- id=web_search_call_id
+ id=web_search_call_id,
)
)
@@ -675,10 +751,12 @@ async def run_tool():
new_tokens = encoding.render_conversation_for_completion(
Conversation.from_messages(result), Role.ASSISTANT
)
-
+
print(encoding.decode_utf8(new_tokens))
self.output_tokens.append(next_tok)
- self.tokens.append(encoding.encode('<|end|>', allowed_special="all")[0])
+ self.tokens.append(
+ encoding.encode("<|end|>", allowed_special="all")[0]
+ )
for token in new_tokens:
self.parser.process(token)
@@ -692,19 +770,94 @@ async def run_tool():
id=web_search_call_id,
)
)
- yield self._send_event(ResponseOutputItemDone(
- type="response.output_item.done",
- output_index=current_output_index,
- item=WebSearchCallItem(
- type="web_search_call",
- id=web_search_call_id,
- action=action,
- ),
- ))
+ yield self._send_event(
+ ResponseOutputItemDone(
+ type="response.output_item.done",
+ output_index=current_output_index,
+ item=WebSearchCallItem(
+ type="web_search_call",
+ id=web_search_call_id,
+ action=action,
+ ),
+ )
+ )
+
+ current_output_index += 1
+ self.new_request = True
+
+ continue
+
+ elif (
+ self.use_code_interpreter
+ and last_message.recipient is not None
+ and last_message.recipient.startswith("python")
+ ):
+ code_call_id = f"ci_{uuid.uuid4().hex}"
+ self.python_call_ids.append(code_call_id)
+ yield self._send_event(
+ ResponseOutputItemAdded(
+ type="response.output_item.added",
+ output_index=current_output_index,
+ item=CodeInterpreterCallItem(
+ type="code_interpreter_call",
+ id=code_call_id,
+ ),
+ )
+ )
+ yield self._send_event(
+ ResponseCodeInterpreterCallInProgress(
+ type="response.code_interpreter_call.in_progress",
+ output_index=current_output_index,
+ id=code_call_id,
+ )
+ )
+
+ async def run_python_tool():
+ results = []
+ async for msg in self.python_tool.process(last_message):
+ results.append(msg)
+ return results
+
+ result = await run_python_tool()
+
+ print(result)
+
+ new_tokens = encoding.render_conversation_for_completion(
+ Conversation.from_messages(result), Role.ASSISTANT
+ )
+
+ print(encoding.decode_utf8(new_tokens))
+ self.output_tokens.append(next_tok)
+ self.tokens.append(
+ encoding.encode("<|end|>", allowed_special="all")[0]
+ )
+
+ for token in new_tokens:
+ self.parser.process(token)
+ self.output_tokens.append(token)
+ self.tokens.append(token)
+
+ yield self._send_event(
+ ResponseCodeInterpreterCallCompleted(
+ type="response.code_interpreter_call.completed",
+ output_index=current_output_index,
+ id=code_call_id,
+ )
+ )
+ yield self._send_event(
+ ResponseOutputItemDone(
+ type="response.output_item.done",
+ output_index=current_output_index,
+ item=CodeInterpreterCallItem(
+ type="code_interpreter_call",
+ id=code_call_id,
+ ),
+ )
+ )
current_output_index += 1
self.new_request = True
-
+
continue
else:
@@ -746,15 +899,28 @@ async def generate(body: ResponsesRequest, request: Request):
getattr(tool, "type", None) == "browser_search"
for tool in (body.tools or [])
)
+ use_code_interpreter = any(
+ getattr(tool, "type", None) == "code_interpreter"
+ for tool in (body.tools or [])
+ )
if use_browser_tool:
- backend = ExaBackend(
- source="web",
- )
+ tool_backend = os.getenv("BROWSER_BACKEND", "exa")
+ if tool_backend == "youcom":
+ backend = YouComBackend(source="web")
+ elif tool_backend == "exa":
+ backend = ExaBackend(source="web")
+ else:
+ raise ValueError(f"Invalid tool backend: {tool_backend}")
browser_tool = SimpleBrowserTool(backend=backend)
else:
browser_tool = None
+ if use_code_interpreter:
+ python_tool = PythonTool()
+ else:
+ python_tool = None
+
if body.previous_response_id:
prev = responses_store.get(body.previous_response_id)
if prev:
@@ -778,31 +944,44 @@ def _ensure_list(inp):
body.instructions = prev_req.instructions
body.input = merged_input
-
system_message_content = SystemContent.new().with_conversation_start_date(
datetime.datetime.now().strftime("%Y-%m-%d")
)
-
+
if body.reasoning is not None:
- reasoning_effort = get_reasoning_effort(body.reasoning.effort)
- system_message_content = system_message_content.with_reasoning_effort(reasoning_effort)
+ try:
+
+ reasoning_effort = get_reasoning_effort(body.reasoning.effort)
+ except ValueError as e:
+ from fastapi import HTTPException
+
+ raise HTTPException(status_code=422, detail=str(e))
+ system_message_content = system_message_content.with_reasoning_effort(
+ reasoning_effort
+ )
if use_browser_tool:
- system_message_content = system_message_content.with_tools(browser_tool.tool_config)
+ system_message_content = system_message_content.with_tools(
+ browser_tool.tool_config
+ )
+ if use_code_interpreter:
+ system_message_content = system_message_content.with_tools(
+ python_tool.tool_config
+ )
system_message = Message.from_role_and_content(
Role.SYSTEM, system_message_content
)
+ messages = [system_message]
- developer_message_content = DeveloperContent.new().with_instructions(
- body.instructions
- )
+ if body.instructions or body.tools:
+ developer_message_content = DeveloperContent.new().with_instructions(
+ body.instructions
+ )
- tools = []
- if body.tools:
+ tools = []
for tool in body.tools:
if tool.type == "function":
- has_functions = True
tools.append(
ToolDescription.new(
tool.name,
@@ -810,17 +989,17 @@ def _ensure_list(inp):
tool.parameters,
)
)
-
- if len(tools) > 0:
- developer_message_content = developer_message_content.with_function_tools(
- tools
- )
- developer_message = Message.from_role_and_content(
- Role.DEVELOPER, developer_message_content
- )
+ if tools:
+ developer_message_content = (
+ developer_message_content.with_function_tools(tools)
+ )
+
+ developer_message = Message.from_role_and_content(
+ Role.DEVELOPER, developer_message_content
+ )
- messages = [system_message, developer_message]
+ messages.append(developer_message)
if isinstance(body.input, str):
user_message = Message.from_role_and_content(Role.USER, body.input)
@@ -846,7 +1025,9 @@ def _ensure_list(inp):
else:
for content_item in item.content:
messages.append(
- Message.from_role_and_content(item.role, content_item.text)
+ Message.from_role_and_content(
+ item.role, content_item.text
+ )
)
# add final channel to the last assistant message if it's from the assistant
if item.role == Role.ASSISTANT:
@@ -879,7 +1060,9 @@ def _ensure_list(inp):
Message.from_author_and_content(
Author.new(Role.TOOL, f"functions.{function_call.name}"),
item.output,
- ).with_recipient("assistant").with_channel("commentary")
+ )
+ .with_recipient("assistant")
+ .with_channel("commentary")
)
conversation = Conversation.from_messages(messages)
@@ -901,6 +1084,7 @@ def store_callback(rid: str, req: ResponsesRequest, resp: ResponseObject):
response_id=response_id,
store_callback=store_callback,
browser_tool=browser_tool,
+ python_tool=python_tool,
)
if body.stream:
diff --git a/gpt_oss/responses_api/events.py b/gpt_oss/responses_api/events.py
index 7adecc64..fed4c6e6 100644
--- a/gpt_oss/responses_api/events.py
+++ b/gpt_oss/responses_api/events.py
@@ -4,14 +4,15 @@
from pydantic import BaseModel
from .types import (
+ CodeInterpreterCallItem,
FunctionCallItem,
Item,
ReasoningItem,
+ ReasoningTextContentItem,
ResponseObject,
TextContentItem,
- ReasoningTextContentItem,
- WebSearchCallItem,
UrlCitation,
+ WebSearchCallItem,
)
@@ -67,13 +68,25 @@ class ResponseReasoningTextDone(ResponseEvent):
class ResponseOutputItemAdded(ResponseEvent):
type: Literal["response.output_item.added"] = "response.output_item.added"
output_index: int = 0
- item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem]
+ item: Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ WebSearchCallItem,
+ CodeInterpreterCallItem,
+ ]
class ResponseOutputItemDone(ResponseEvent):
type: Literal["response.output_item.done"] = "response.output_item.done"
output_index: int = 0
- item: Union[Item, ReasoningItem, FunctionCallItem, WebSearchCallItem]
+ item: Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ WebSearchCallItem,
+ CodeInterpreterCallItem,
+ ]
class ResponseInProgressEvent(ResponseEvent):
@@ -105,25 +118,53 @@ class ResponseContentPartDone(ResponseEvent):
content_index: int = 0
part: Union[TextContentItem, ReasoningTextContentItem]
+
class ResponseOutputTextAnnotationAdded(ResponseEvent):
- type: Literal["response.output_text.annotation.added"] = "response.output_text.annotation.added"
+ type: Literal["response.output_text.annotation.added"] = (
+ "response.output_text.annotation.added"
+ )
item_id: str = "item_1234"
output_index: int = 0
content_index: int = 0
annotation_index: int = 0
annotation: UrlCitation
+
class ResponseWebSearchCallInProgress(ResponseEvent):
- type: Literal["response.web_search_call.in_progress"] = "response.web_search_call.in_progress"
+ type: Literal["response.web_search_call.in_progress"] = (
+ "response.web_search_call.in_progress"
+ )
output_index: int = 0
item_id: str = "item_1234"
+
class ResponseWebSearchCallSearching(ResponseEvent):
- type: Literal["response.web_search_call.searching"] = "response.web_search_call.searching"
+ type: Literal["response.web_search_call.searching"] = (
+ "response.web_search_call.searching"
+ )
output_index: int = 0
item_id: str = "item_1234"
+
class ResponseWebSearchCallCompleted(ResponseEvent):
- type: Literal["response.web_search_call.completed"] = "response.web_search_call.completed"
+ type: Literal["response.web_search_call.completed"] = (
+ "response.web_search_call.completed"
+ )
output_index: int = 0
- item_id: str = "item_1234"
\ No newline at end of file
+ item_id: str = "item_1234"
+
+
+class ResponseCodeInterpreterCallInProgress(ResponseEvent):
+ type: Literal["response.code_interpreter_call.in_progress"] = (
+ "response.code_interpreter_call.in_progress"
+ )
+ output_index: int = 0
+ item_id: str = "item_1234"
+
+
+class ResponseCodeInterpreterCallCompleted(ResponseEvent):
+ type: Literal["response.code_interpreter_call.completed"] = (
+ "response.code_interpreter_call.completed"
+ )
+ output_index: int = 0
+ item_id: str = "item_1234"
diff --git a/gpt_oss/responses_api/inference/metal.py b/gpt_oss/responses_api/inference/metal.py
index 9abe50db..9b62b660 100644
--- a/gpt_oss/responses_api/inference/metal.py
+++ b/gpt_oss/responses_api/inference/metal.py
@@ -5,74 +5,39 @@
from gpt_oss.metal import Context, Model
+# Tunables
+MAX_OUTPUT_TOKENS = 100
+
+
def setup_model(checkpoint: str) -> Callable[[list[int], float], int]:
"""Load the Metal model and return an inference function."""
model = Model(checkpoint)
context = Context(model)
- def lcp(cache: list[int], inp: list[int]) -> list[int]:
- i = 0
- max_len = min(len(cache), len(inp))
- while i < max_len and cache[i] == inp[i]:
- i += 1
- return cache[:i]
-
- tokens_so_far = []
+ seed = 0
+ output_tokens = []
def infer_next_token(
tokens: list[int], temperature: float = 0.0, new_request: bool = False
) -> int:
"""Infer next token using incremental LCP caching when possible."""
- nonlocal tokens_so_far
-
- # Fast path: first call or explicitly new request.
- if new_request or not tokens_so_far:
- context.reset()
- for t in tokens:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
+ nonlocal output_tokens
- # Longest common prefix length
- overlap = lcp(tokens_so_far, tokens)
- ol = len(overlap)
- prev_len = len(tokens_so_far)
- cur_len = len(tokens)
+ if new_request:
+ output_tokens = []
- diverged_midstream = (ol < prev_len) and (
- ol < cur_len
- ) # mismatch not at the end
-
- if diverged_midstream:
- # safest: rebuild
+ if len(output_tokens) == 0:
+ # Context handles LCP caching internally; if `tokens` matches the
+ # tokens in the KV cache, the KV cache is reused after reset+append.
context.reset()
for t in tokens:
context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
-
- if cur_len > prev_len:
- # pure extension (good for KV reuse)
- extension = tokens[prev_len:]
- for t in extension:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
- if cur_len < prev_len:
- # truncation/backspace; easiest correct behavior is rebuild
- context.reset()
- for t in tokens:
- context.append(t)
- tokens_so_far = tokens.copy()
- context.process()
- return int(context.sample(temperature=temperature))
+ output_tokens = context.sample(max_output_tokens=MAX_OUTPUT_TOKENS,
+ temperature=temperature,
+ seed=seed)
- # cur_len == prev_len and everything matches => no new tokens; just sample.
- return int(context.sample(temperature=temperature))
+ return int(output_tokens.pop(0))
return infer_next_token
diff --git a/gpt_oss/responses_api/inference/ollama.py b/gpt_oss/responses_api/inference/ollama.py
index cab54adf..e0196c6d 100644
--- a/gpt_oss/responses_api/inference/ollama.py
+++ b/gpt_oss/responses_api/inference/ollama.py
@@ -1,6 +1,6 @@
"""
-NOTE: this is a stiched together implementation that uses Ollama for inference. It's primarily used
-for testing and development. It does not leverage any prompt caching or other optimizations and
+NOTE: this is a stitched together implementation that uses Ollama for inference. It's primarily used
+for testing and development. It does not leverage any prompt caching or other optimizations and
can therefore be slow between turns.
"""
@@ -8,17 +8,17 @@
import threading
import time
from typing import Callable, Optional
-import requests
-from openai_harmony import load_harmony_encoding, HarmonyEncodingName
+import requests
+from openai_harmony import HarmonyEncodingName, load_harmony_encoding
EOS_TOKEN = 200002 # only used on hard timeout
# Tunables
-POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
-CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
-NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
-FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
+POLL_INTERVAL_S = 0.01 # 10ms between buffer checks
+CALL_MAX_WAIT_S = 0.250 # max time to block inside a single infer call
+NO_TOKEN_TIMEOUT_S = 15.0 # overall inactivity timeout before emitting EOS
+FIRST_BYTE_TIMEOUT_S = 30.0 # time to wait for first token before EOS
# Shared state
_token_buffer: list[int] = []
@@ -26,9 +26,10 @@
_stream_thread: Optional[threading.Thread] = None
_stream_done = threading.Event()
_stream_error: Optional[Exception] = None
-_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
+_last_progress_ts: float = 0.0 # updated whenever we enqueue or dequeue tokens
_previous_request_tokens: list[int] = []
+
def lcp(cache: list[int], inp: list[int]) -> list[int]:
i = 0
max_len = min(len(cache), len(inp))
@@ -36,13 +37,16 @@ def lcp(cache: list[int], inp: list[int]) -> list[int]:
i += 1
return cache[:i]
+
def _now():
return time.monotonic()
+
def _touch_progress():
global _last_progress_ts
_last_progress_ts = _now()
+
def _reset_stream_state():
global _token_buffer, _stream_thread, _stream_error
with _buffer_lock:
@@ -52,12 +56,14 @@ def _reset_stream_state():
_stream_error = None
_touch_progress()
+
def setup_model(checkpoint: str) -> Callable[[list[int], float, bool], int]:
encoding = load_harmony_encoding(HarmonyEncodingName.HARMONY_GPT_OSS)
model_name = checkpoint
def _start_stream(token_ids: list[int], temperature: float):
prompt_text = encoding.decode(token_ids)
+
def run():
nonlocal prompt_text, temperature
global _stream_error
@@ -68,21 +74,13 @@ def run():
try:
url = "http://localhost:11434/api/generate"
- context = None
- if len(_previous_request_tokens) > 0:
- context = _previous_request_tokens
- # cache_hit = lcp(_previous_request_tokens, token_ids)
- # if len(cache_hit) > 0:
- # context = cache_hit
- # print(f"Cache hit: {encoding.decode(context)}")
- # prompt_text = encoding.decode(token_ids[len(context):])
payload = {
"model": model_name,
"prompt": prompt_text,
"stream": True,
- "context": context,
"options": {"temperature": temperature},
+ "raw": True,
}
with requests.post(url, json=payload, stream=True, timeout=60) as resp:
@@ -106,9 +104,6 @@ def run():
_token_buffer.append(EOS_TOKEN)
last_len = len(toks)
_touch_progress()
- context = obj.get("context")
- if context and len(context) > 0:
- _previous_request_tokens = context
break
_stream_done.set()
@@ -187,6 +182,8 @@ def infer_next_token(
# If we reach here, we still haven't got a token—ask the caller to call again soon.
# Return a harmless token that the server will replace/ignore if your interface supports it.
# If your interface does NOT allow a sentinel, keep the short-blocking behavior above.
- return EOS_TOKEN if False else 0 # replace `0` with a PAD/NOOP token your server ignores
+ return (
+ EOS_TOKEN if False else 0
+ ) # replace `0` with a PAD/NOOP token your server ignores
return infer_next_token
diff --git a/gpt_oss/responses_api/types.py b/gpt_oss/responses_api/types.py
index 1d908e34..454d8e07 100644
--- a/gpt_oss/responses_api/types.py
+++ b/gpt_oss/responses_api/types.py
@@ -6,7 +6,8 @@
MODEL_IDENTIFIER = "gpt-oss-120b"
DEFAULT_TEMPERATURE = 0.0
REASONING_EFFORT = ReasoningEffort.LOW
-DEFAULT_MAX_OUTPUT_TOKENS = 10_000
+DEFAULT_MAX_OUTPUT_TOKENS = 131072
+
class UrlCitation(BaseModel):
type: Literal["url_citation"]
@@ -15,6 +16,7 @@ class UrlCitation(BaseModel):
url: str
title: str
+
class TextContentItem(BaseModel):
type: Union[Literal["text"], Literal["input_text"], Literal["output_text"]]
text: str
@@ -61,25 +63,37 @@ class FunctionCallOutputItem(BaseModel):
call_id: str = "call_1234"
output: str
+
class WebSearchActionSearch(BaseModel):
type: Literal["search"]
query: Optional[str] = None
+
class WebSearchActionOpenPage(BaseModel):
type: Literal["open_page"]
url: Optional[str] = None
+
class WebSearchActionFind(BaseModel):
type: Literal["find"]
pattern: Optional[str] = None
url: Optional[str] = None
+
class WebSearchCallItem(BaseModel):
type: Literal["web_search_call"]
id: str = "ws_1234"
status: Literal["in_progress", "completed", "incomplete"] = "completed"
action: Union[WebSearchActionSearch, WebSearchActionOpenPage, WebSearchActionFind]
+
+class CodeInterpreterCallItem(BaseModel):
+ type: Literal["code_interpreter_call"]
+ id: str = "ci_1234"
+ status: Literal["in_progress", "completed", "incomplete"] = "completed"
+ input: Optional[str] = None
+
+
class Error(BaseModel):
code: str
message: str
@@ -107,6 +121,10 @@ class BrowserToolConfig(BaseModel):
type: Literal["browser_search"]
+class CodeInterpreterToolConfig(BaseModel):
+ type: Literal["code_interpreter"]
+
+
class ReasoningConfig(BaseModel):
effort: Literal["low", "medium", "high"] = REASONING_EFFORT
@@ -115,11 +133,24 @@ class ResponsesRequest(BaseModel):
instructions: Optional[str] = None
max_output_tokens: Optional[int] = DEFAULT_MAX_OUTPUT_TOKENS
input: Union[
- str, list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
+ str,
+ list[
+ Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ FunctionCallOutputItem,
+ WebSearchCallItem,
+ ]
+ ],
]
model: Optional[str] = MODEL_IDENTIFIER
stream: Optional[bool] = False
- tools: Optional[list[Union[FunctionToolDefinition, BrowserToolConfig]]] = []
+ tools: Optional[
+ list[
+ Union[FunctionToolDefinition, BrowserToolConfig, CodeInterpreterToolConfig]
+ ]
+ ] = []
reasoning: Optional[ReasoningConfig] = ReasoningConfig()
metadata: Optional[Dict[str, Any]] = {}
tool_choice: Optional[Literal["auto", "none"]] = "auto"
@@ -131,7 +162,16 @@ class ResponsesRequest(BaseModel):
class ResponseObject(BaseModel):
- output: list[Union[Item, ReasoningItem, FunctionCallItem, FunctionCallOutputItem, WebSearchCallItem]]
+ output: list[
+ Union[
+ Item,
+ ReasoningItem,
+ FunctionCallItem,
+ FunctionCallOutputItem,
+ WebSearchCallItem,
+ CodeInterpreterCallItem,
+ ]
+ ]
created_at: int
usage: Optional[Usage] = None
status: Literal["completed", "failed", "incomplete", "in_progress"] = "in_progress"
diff --git a/gpt_oss/tools/python_docker/docker_tool.py b/gpt_oss/tools/python_docker/docker_tool.py
index f2d9183b..3d630cc1 100644
--- a/gpt_oss/tools/python_docker/docker_tool.py
+++ b/gpt_oss/tools/python_docker/docker_tool.py
@@ -1,6 +1,11 @@
# Run this before running the tool:
# $ docker image pull python:3.11
+import io
+import tarfile
from typing import Any, AsyncIterator
+import tempfile
+import os
+import subprocess
import docker
from openai_harmony import (
@@ -11,14 +16,16 @@
TextContent,
ToolNamespaceConfig,
)
-import io
-import tarfile
from ..tool import Tool
-
_docker_client = None
+PYTHON_EXECUTION_BACKEND = "docker"
+
+if os.environ.get("PYTHON_EXECUTION_BACKEND") == "dangerously_use_uv":
+ PYTHON_EXECUTION_BACKEND = "dangerously_use_uv"
+
def call_python_script(script: str) -> str:
"""
@@ -59,6 +66,21 @@ def call_python_script(script: str) -> str:
return output
+def call_python_script_with_uv(script: str) -> str:
+ """
+ Call a python script by writing it to a file to a temporary directory
+ and executing it with uv.
+ """
+ with tempfile.TemporaryDirectory() as temp_dir:
+ script_path = os.path.join(temp_dir, "script.py")
+ with open(script_path, "w") as f:
+ f.write(script)
+ exec_result = subprocess.run(
+ ["uv", "run", "--no-project", "python", script_path],
+ capture_output=True)
+ return exec_result.stdout.decode("utf-8")
+
+
class PythonTool(Tool):
def __init__(
self,
@@ -78,23 +100,22 @@ def name(self) -> str:
def instruction(self) -> str:
return """
Use this tool to execute Python code in your chain of thought. The code will not be shown to the user. This tool should be used for internal reasoning, but not for code that is intended to be visible to the user (e.g. when creating plots, tables, or files).
-When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you.
+When you send a message containing python code to python, it will be executed in a stateless docker container, and the stdout of that process will be returned to you. You have to use print statements to access the output.
""".strip()
@property
def tool_config(self) -> ToolNamespaceConfig:
return ToolNamespaceConfig(
- name=self.get_tool_name(),
- description=self.instruction,
- tools=[]
+ name=self.get_tool_name(), description=self.instruction, tools=[]
)
def _make_response(
self,
output: str,
+ channel: str | None = None,
) -> Message:
content = TextContent(text=output)
- return self.make_response(content=content)
+ return self.make_response(content=content, channel=channel)
def make_response(
self,
@@ -110,7 +131,7 @@ def make_response(
message = Message(
author=author,
content=[content],
- ).with_recipient('assistant')
+ ).with_recipient("assistant")
if channel:
message = message.with_channel(channel)
@@ -120,5 +141,12 @@ def make_response(
async def _process(self, message: Message) -> AsyncIterator[Message]:
script = message.content[0].text
channel = message.channel
- output = call_python_script(script)
+ if PYTHON_EXECUTION_BACKEND == "docker":
+ output = call_python_script(script)
+ elif PYTHON_EXECUTION_BACKEND == "dangerously_use_uv":
+ output = call_python_script_with_uv(script)
+ else:
+ raise ValueError(
+ f"Invalid PYTHON_EXECUTION_BACKEND: {PYTHON_EXECUTION_BACKEND}"
+ )
yield self._make_response(output, channel=channel)
diff --git a/gpt_oss/tools/simple_browser/__init__.py b/gpt_oss/tools/simple_browser/__init__.py
index 9043cb18..da3ff280 100644
--- a/gpt_oss/tools/simple_browser/__init__.py
+++ b/gpt_oss/tools/simple_browser/__init__.py
@@ -1,7 +1,8 @@
from .simple_browser_tool import SimpleBrowserTool
-from .backend import ExaBackend
+from .backend import ExaBackend, YouComBackend
__all__ = [
"SimpleBrowserTool",
"ExaBackend",
+ "YouComBackend",
]
diff --git a/gpt_oss/tools/simple_browser/backend.py b/gpt_oss/tools/simple_browser/backend.py
index 03bdf566..33daf8d6 100644
--- a/gpt_oss/tools/simple_browser/backend.py
+++ b/gpt_oss/tools/simple_browser/backend.py
@@ -3,6 +3,7 @@
"""
import functools
+import asyncio
import logging
import os
from abc import abstractmethod
@@ -87,6 +88,24 @@ async def search(
async def fetch(self, url: str, session: ClientSession) -> PageContents:
pass
+ async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:
+ headers = {"x-api-key": self._get_api_key()}
+ async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp:
+ if resp.status != 200:
+ raise BackendError(
+ f"{self.__class__.__name__} error {resp.status}: {await resp.text()}"
+ )
+ return await resp.json()
+
+ async def _get(self, session: ClientSession, endpoint: str, params: dict) -> dict:
+ headers = {"x-api-key": self._get_api_key()}
+ async with session.get(f"{self.BASE_URL}{endpoint}", params=params, headers=headers) as resp:
+ if resp.status != 200:
+ raise BackendError(
+ f"{self.__class__.__name__} error {resp.status}: {await resp.text()}"
+ )
+ return await resp.json()
+
@chz.chz(typecheck=True)
class ExaBackend(Backend):
@@ -106,14 +125,6 @@ def _get_api_key(self) -> str:
raise BackendError("Exa API key not provided")
return key
- async def _post(self, session: ClientSession, endpoint: str, payload: dict) -> dict:
- headers = {"x-api-key": self._get_api_key()}
- async with session.post(f"{self.BASE_URL}{endpoint}", json=payload, headers=headers) as resp:
- if resp.status != 200:
- raise BackendError(
- f"Exa API error {resp.status}: {await resp.text()}"
- )
- return await resp.json()
async def search(
self, query: str, topn: int, session: ClientSession
@@ -164,3 +175,78 @@ async def fetch(self, url: str, session: ClientSession) -> PageContents:
display_urls=True,
session=session,
)
+
+@chz.chz(typecheck=True)
+class YouComBackend(Backend):
+ """Backend that uses the You.com Search API."""
+
+ source: str = chz.field(doc="Description of the backend source")
+
+ BASE_URL: str = "https://api.ydc-index.io"
+
+ def _get_api_key(self) -> str:
+ key = os.environ.get("YDC_API_KEY")
+ if not key:
+ raise BackendError("You.com API key not provided")
+ return key
+
+
+ async def search(
+ self, query: str, topn: int, session: ClientSession
+ ) -> PageContents:
+ data = await self._get(
+ session,
+ "/v1/search",
+ {"query": query, "count": topn},
+ )
+ # make a simple HTML page to work with browser format
+ web_titles_and_urls, news_titles_and_urls = [], []
+ if "web" in data["results"]:
+ web_titles_and_urls = [
+ (result["title"], result["url"], result["snippets"])
+ for result in data["results"]["web"]
+ ]
+ if "news" in data["results"]:
+ news_titles_and_urls = [
+ (result["title"], result["url"], result["description"])
+ for result in data["results"]["news"]
+ ]
+ titles_and_urls = web_titles_and_urls + news_titles_and_urls
+ html_page = f"""
+
+Search Results
+
+{"".join([f"- {title} {summary}
" for title, url, summary in titles_and_urls])}
+
+
+"""
+
+ return process_html(
+ html=html_page,
+ url="",
+ title=query,
+ display_urls=True,
+ session=session,
+ )
+
+ async def fetch(self, url: str, session: ClientSession) -> PageContents:
+ is_view_source = url.startswith(VIEW_SOURCE_PREFIX)
+ if is_view_source:
+ url = url[len(VIEW_SOURCE_PREFIX) :]
+ data = await self._post(
+ session,
+ "/v1/contents",
+ {"urls": [url], "livecrawl_formats": "html"},
+ )
+ if not data:
+ raise BackendError(f"No contents returned for {url}")
+ if "html" not in data[0]:
+ raise BackendError(f"No HTML returned for {url}")
+ return process_html(
+ html=data[0].get("html", ""),
+ url=url,
+ title=data[0].get("title", ""),
+ display_urls=True,
+ session=session,
+ )
+
diff --git a/gpt_oss/tools/simple_browser/page_contents.py b/gpt_oss/tools/simple_browser/page_contents.py
index 4a18fc97..6fffd3f1 100644
--- a/gpt_oss/tools/simple_browser/page_contents.py
+++ b/gpt_oss/tools/simple_browser/page_contents.py
@@ -64,6 +64,7 @@ class Tokens:
def get_domain(url: str) -> str:
+ """Extracts the domain from a URL."""
if "http" not in url:
# If `get_domain` is called on a domain, add a scheme so that the
# original domain is returned instead of the empty string.
@@ -72,12 +73,14 @@ def get_domain(url: str) -> str:
def multiple_replace(text: str, replacements: dict[str, str]) -> str:
+ """Performs multiple string replacements using regex pass."""
regex = re.compile("(%s)" % "|".join(map(re.escape, replacements.keys())))
return regex.sub(lambda mo: replacements[mo.group(1)], text)
@functools.lru_cache(maxsize=1024)
def mark_lines(text: str) -> str:
+ """Adds line numbers (ex: 'L0:') to the beginning of each line in a string."""
# Split the string by newline characters
lines = text.split("\n")
@@ -87,17 +90,20 @@ def mark_lines(text: str) -> str:
@functools.cache
-def _tiktoken_vocabulary_lenghts(enc_name: str) -> list[int]:
+def _tiktoken_vocabulary_lengths(enc_name: str) -> list[int]:
+ """Gets the character lengths of all tokens in the specified TikToken vocabulary."""
encoding = tiktoken.get_encoding(enc_name)
return [len(encoding.decode([i])) for i in range(encoding.n_vocab)]
def warmup_caches(enc_names: list[str]) -> None:
- for _ in map(_tiktoken_vocabulary_lenghts, enc_names):
+ """Warm up the cache by computing token length lists for the given TikToken encodings."""
+ for _ in map(_tiktoken_vocabulary_lengths, enc_names):
pass
def _replace_special_chars(text: str) -> str:
+ """Replaces specific special characters with visually similar alternatives."""
replacements = {
"【": "〖",
"】": "〗",
@@ -110,16 +116,19 @@ def _replace_special_chars(text: str) -> str:
def merge_whitespace(text: str) -> str:
+ """Replace newlines with spaces and merge consecutive whitespace into a single space."""
text = text.replace("\n", " ")
text = re.sub(r"\s+", " ", text)
return text
def arxiv_to_ar5iv(url: str) -> str:
+ """Converts an arxiv.org URL to its ar5iv.org equivalent."""
return re.sub(r"arxiv.org", r"ar5iv.org", url)
def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]:
+ """Processes all anchor tags in the HTML, replaces them with a custom format and returns an ID-to-URL mapping."""
cur_domain = get_domain(cur_url)
urls: dict[str, str] = {}
urls_rev: dict[str, str] = {}
@@ -156,10 +165,12 @@ def _clean_links(root: lxml.html.HtmlElement, cur_url: str) -> dict[str, str]:
def _get_text(node: lxml.html.HtmlElement) -> str:
+ """Extracts all text from an HTML element and merges it into a whitespace-normalized string."""
return merge_whitespace(" ".join(node.itertext()))
def _remove_node(node: lxml.html.HtmlElement) -> None:
+ """Removes a node from its parent in the lxml tree."""
node.getparent().remove(node)
@@ -172,6 +183,7 @@ def _escape_md_section(text: str, snob: bool = False) -> str:
def html_to_text(html: str) -> str:
+ """Converts an HTML string to clean plaintext."""
html = re.sub(HTML_SUP_RE, r"^{\2}", html)
html = re.sub(HTML_SUB_RE, r"_{\2}", html)
# add spaces between tags such as table cells
@@ -195,6 +207,7 @@ def html_to_text(html: str) -> str:
def _remove_math(root: lxml.html.HtmlElement) -> None:
+ """Removes all