diff --git a/.claude/skills/impl-jit-kernel/SKILL.md b/.claude/skills/impl-jit-kernel/SKILL.md new file mode 100644 index 000000000..39cc02b6f --- /dev/null +++ b/.claude/skills/impl-jit-kernel/SKILL.md @@ -0,0 +1,486 @@ +--- +name: impl-jit-kernel +description: Guide for implementing CUDA or CPU JIT kernels in mllm-kernel. Use when the user asks to create, add, or implement a new kernel in mllm-kernel. +--- + +# Implementing a JIT Kernel in mllm-kernel + +## Overview + +mllm-kernel uses a JIT (Just-In-Time) compilation system built on `tvm_ffi`. Kernels are written in C++20 (`.cuh` for CUDA, `.cpp` for CPU), validated at runtime via `TensorMatcher`, and exposed to Python through a `@jit` decorator. No pre-compilation is needed -- kernels compile on first call and are cached at `~/.cache/mllm_kernel/`. + +## File Layout + +For a kernel named `my_kernel`: + +``` +mllm-kernel/ + mllm_kernel/ + cuda/ + csrc/my_kernel.cuh # CUDA kernel implementation + jit/my_kernel.py # Python JIT wrapper + jit/__init__.py # Add export here + cpu/ + csrc/my_kernel.cpp # CPU kernel implementation (Highway SIMD) + include/mllm_kernel/cpu/ + my_kernel.hpp # CPU SIMD body (NO #pragma once) + jit/my_kernel.py # Python JIT wrapper + jit/__init__.py # Add export here + tests/test_my_kernel.py # Pytest correctness tests + benchmarks/bench_my_kernel.py # Profiler benchmark vs PyTorch reference +``` + +--- + +## CUDA Kernel Walkthrough + +### Step 1: Write the `.cuh` kernel + +Create `mllm_kernel/cuda/csrc/my_kernel.cuh`: + +```cpp +#pragma once + +#include // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType +#include // RuntimeCheck, Panic, div_ceil +#include // LaunchKernel, fp16_t, bf16_t, PDL helpers + +#include +#include + +#include + +namespace { + +// --------------------------------------------------------------------------- +// 1. 
Parameter struct (trivially copyable, passed to kernel by value) +// --------------------------------------------------------------------------- +struct MyKernelParams { + const float* __restrict__ input; + float* __restrict__ output; + int32_t num_elements; +}; + +// --------------------------------------------------------------------------- +// 2. CUDA kernel +// --------------------------------------------------------------------------- +__global__ void my_kernel(const MyKernelParams params) { + const int idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx >= params.num_elements) return; + params.output[idx] = params.input[idx] * 2.0f; +} + +// --------------------------------------------------------------------------- +// 3. Host-side launcher (entry point for TVM FFI binding) +// --------------------------------------------------------------------------- +struct MyKernel { + static void run(tvm::ffi::TensorView input, tvm::ffi::TensorView output) { + using namespace mllm_kernel::host; + + // --- Validate tensors --- + SymbolicSize N{"num_elements"}; + SymbolicDevice device; + + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device) + .verify(input); + + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device) + .verify(output); + + const int64_t n = N.unwrap(); + RuntimeCheck(n > 0, "num_elements must be positive, got ", n); + + // --- Build params --- + MyKernelParams params{ + .input = static_cast(input.data_ptr()), + .output = static_cast(output.data_ptr()), + .num_elements = static_cast(n), + }; + + // --- Launch --- + constexpr int kBlock = 256; + const int grid = static_cast(div_ceil(n, kBlock)); + LaunchKernel(grid, kBlock, device.unwrap())(my_kernel, params); + } +}; + +} // namespace +``` + +**Key rules:** + +- **Always wrap in `namespace {}`** (anonymous namespace). +- **Entry point** is a `static void run(tvm::ffi::TensorView ...)` method. +- **Validate every tensor** with `TensorMatcher` before reading `.data_ptr()`. 
+- **Never dereference device pointers on host** -- `data_ptr()` returns a GPU pointer. +- **Use `LaunchKernel`** to launch -- it handles stream resolution and error checking. + +### Step 2: Write the Python JIT wrapper + +Create `mllm_kernel/cuda/jit/my_kernel.py`: + +```python +"""JIT wrapper for my_kernel CUDA kernel.""" + +import torch +from mllm_kernel.jit_utils import jit + + +@jit( + args=[], + device="cuda", + cuda_files=["my_kernel.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("my_kernel", "MyKernel::run")], + func_name="my_kernel", +) +def _kernel(compiled_module, input: torch.Tensor, output: torch.Tensor) -> None: + compiled_module.my_kernel(input, output) + + +def my_kernel(input: torch.Tensor) -> torch.Tensor: + """Double every element in *input*. + + Parameters + ---------- + input : torch.Tensor + 1-D float32 tensor on CUDA. + + Returns + ------- + torch.Tensor + Same shape and dtype as *input*. + """ + output = torch.empty_like(input) + _kernel(input, output) + return output +``` + +### Step 3: Export in `__init__.py` + +Edit `mllm_kernel/cuda/jit/__init__.py` and add: + +```python +from mllm_kernel.cuda.jit.my_kernel import my_kernel +``` + +### Step 4: Clear JIT cache after editing `.cuh` + +Any time you modify the `.cuh` file, delete the cached `.so`: + +```bash +rm -rf ~/.cache/mllm_kernel/cuda_my_kernel* +``` + +The next Python call will trigger recompilation automatically. + +--- + +## Template-Parameterized CUDA Kernels + +When the kernel takes compile-time constants (e.g. 
block size, dtype), use `make_cpp_args`: + +```python +from mllm_kernel.jit_utils import jit, make_cpp_args + +def _make_kernel(block_size: int, use_pdl: bool): + cpp_args = make_cpp_args(block_size, use_pdl) # -> "256, true" + + @jit( + args=[block_size, use_pdl], + device="cuda", + cuda_files=["my_kernel.cuh"], + cpp_wrappers=[], + cuda_wrappers=[("my_kernel", f"MyKernel<{cpp_args}>::run")], + func_name="my_kernel", + ) + def _kernel(compiled_module, input, output): + compiled_module.my_kernel(input, output) + return _kernel +``` + +`make_cpp_args` converts Python types to C++ literals: +- `int/float` -> string literal +- `bool` -> `"true"` / `"false"` +- `torch.dtype` -> C++ type (`torch.float32` -> `"fp32_t"`, `torch.float16` -> `"fp16_t"`, `torch.bfloat16` -> `"bf16_t"`, `torch.int32` -> `"int32_t"`, etc.) + +--- + +## CPU Kernel Walkthrough + +CPU kernels use **Google Highway** for portable SIMD. The key difference: the `.hpp` body is included **multiple times** by Highway's `foreach_target` dispatch, so it must NOT have `#pragma once`. + +### Step 1: Write the SIMD body (`.hpp`) + +Create `mllm_kernel/cpu/include/mllm_kernel/cpu/my_kernel.hpp`: + +```cpp +// NOTE: NO #pragma once -- this file is included multiple times by Highway. 
+ +#include + +HWY_BEFORE_NAMESPACE(); +namespace mllm_kernel::cpu { +namespace HWY_NAMESPACE { +namespace hn = hwy::HWY_NAMESPACE; + +template +inline void my_kernel_impl(float* HWY_RESTRICT dst, + const float* HWY_RESTRICT src, + size_t count) { + const hn::ScalableTag d; + const size_t lanes = hn::Lanes(d); + const auto vc = hn::Set(d, static_cast(Constant)); + size_t i = 0; + for (; i + lanes <= count; i += lanes) { + const auto v = hn::Load(d, src + i); + hn::Store(hn::Add(v, vc), d, dst + i); + } + for (; i < count; ++i) { + dst[i] = src[i] + static_cast(Constant); + } +} + +// Named entry points for HWY_EXPORT +static HWY_NOINLINE HWY_MAYBE_UNUSED void my_kernel_1(float* d, const float* s, size_t n) { + my_kernel_impl<1>(d, s, n); +} + +} // namespace HWY_NAMESPACE +} // namespace mllm_kernel::cpu +HWY_AFTER_NAMESPACE(); +``` + +### Step 2: Write the `.cpp` source + +Create `mllm_kernel/cpu/csrc/my_kernel.cpp`: + +```cpp +#include +#include +#include + +#undef HWY_TARGET_INCLUDE +#define HWY_TARGET_INCLUDE "../csrc/my_kernel.cpp" +#include + +#include + +#if HWY_ONCE +#include +#endif + +namespace mllm_kernel::cpu { +#if HWY_ONCE + +HWY_EXPORT(my_kernel_1); + +template +void my_kernel(tvm::ffi::TensorView dst, tvm::ffi::TensorView src) { + using namespace mllm_kernel::host; + SymbolicSize N{"num_elements"}; + SymbolicDevice device_; + (void)TensorMatcher({N}) + .with_dtype() + .with_device(device_) + .verify(dst) + .verify(src); + const size_t n = N.unwrap(); + auto* dst_ptr = static_cast(dst.data_ptr()); + const auto* src_ptr = static_cast(src.data_ptr()); + HWY_DYNAMIC_DISPATCH(my_kernel_1)(dst_ptr, src_ptr, n); +} + +// Explicit instantiation +template void my_kernel<1>(tvm::ffi::TensorView, tvm::ffi::TensorView); + +#endif +} // namespace mllm_kernel::cpu +``` + +### Step 3: Write the Python JIT wrapper + +Create `mllm_kernel/cpu/jit/my_kernel.py`: + +```python +import torch +from mllm_kernel.jit_utils import jit + +@jit( + args=1, + device="cpu", + 
cpp_files=["my_kernel.cpp"], + cpp_wrappers=[("my_kernel", "mllm_kernel::cpu::my_kernel<1>")], + func_name="my_kernel", +) +def _kernel_1(compiled_module, dst, src): + compiled_module.my_kernel(dst, src) + +def my_kernel(src: torch.Tensor) -> torch.Tensor: + dst = torch.empty_like(src) + _kernel_1(dst, src) + return dst +``` + +**Key CPU differences from CUDA:** + +| Aspect | CUDA | CPU | +|--------|------|-----| +| Source file | `.cuh` in `cuda/csrc/` | `.cpp` + `.hpp` in `cpu/csrc/` and `cpu/include/` | +| Namespace | Anonymous `namespace {}` | `mllm_kernel::cpu` | +| Device check | `with_device` | `with_device` | +| Launch | `LaunchKernel(grid, block, device)(...)` | Direct function call via `HWY_DYNAMIC_DISPATCH` | +| SIMD | CUDA warps | Highway `ScalableTag` | +| Wrapper fields | `cuda_files`, `cuda_wrappers` | `cpp_files`, `cpp_wrappers` | +| Wrapper name | `"MyKernel::run"` | `"mllm_kernel::cpu::my_kernel<1>"` (fully qualified) | + +--- + +## TensorMatcher Reference + +`TensorMatcher` validates shape, dtype, device, and strides of `tvm::ffi::TensorView` arguments. 
+ +```cpp +using namespace mllm_kernel::host; + +// Symbolic dimensions -- bind on first .verify(), check consistency on subsequent calls +SymbolicSize B{"batch"}, N{"seq_len"}, D{"dim"}; +SymbolicSize Stride0{"stride0"}; +SymbolicDType dtype; +SymbolicDevice device; + +// Shape [B, N, D], contiguous, float32, on CUDA +(void)TensorMatcher({B, N, D}) + .with_dtype(dtype) + .with_device(device) + .verify(tensor_a); + +// Shape [B, N, D], same dtype and device (already bound) +(void)TensorMatcher({B, N, D}) + .with_dtype(dtype) + .with_device(device) + .verify(tensor_b); + +// Shape [B, D] with explicit strides (non-contiguous OK) +(void)TensorMatcher({B, D}) + .with_strides({Stride0, 1}) + .with_dtype() + .with_device(device) + .verify(indices); + +// Multiple acceptable dtypes +SymbolicDType flex_dtype; +(void)TensorMatcher({N}) + .with_dtype(flex_dtype) + .with_device(device) + .verify(mixed_tensor); + +// Extract bound values +int64_t batch = B.unwrap(); +int64_t dim = D.unwrap(); +DLDevice dev = device.unwrap(); +``` + +--- + +## LaunchKernel Reference + +```cpp +using namespace mllm_kernel::host; + +// Basic launch (resolves CUDA stream from DLDevice) +DLDevice dev = device.unwrap(); +LaunchKernel(grid_dim, block_dim, dev)(kernel_func, param_struct); + +// With shared memory +LaunchKernel(grid, block, dev, shared_mem_bytes)(kernel, params); + +// With PDL (Programmatic Dependent Launch, sm_90+) +LaunchKernel(grid, block, dev).enable_pdl(true)(kernel, params); +``` + +--- + +## Utility Reference (`mllm_kernel::host`) + +| Function | Description | +|----------|-------------| +| `RuntimeCheck(cond, msg...)` | Throws `PanicError` if `cond` is false | +| `Panic(msg...)` | Always throws (unreachable code) | +| `div_ceil(a, b)` | Integer ceiling division | +| `dtype_bytes(DLDataType)` | Byte size of a DLPack dtype | + +CUDA-only (`mllm_kernel::device`): + +| Symbol | Value | +|--------|-------| +| `kWarpThreads` | 32 | +| `kFullMask` | 0xffffffff | +| `fp16_t` | 
`__half` | +| `bf16_t` | `__nv_bfloat16` | + +--- + +## Testing Pattern + +Create `tests/test_my_kernel.py`: + +```python +import pytest +import torch +from mllm_kernel.cuda.jit.my_kernel import my_kernel + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required") +@pytest.mark.parametrize("n", [1, 128, 1024, 65536]) +def test_my_kernel(n): + x = torch.randn(n, dtype=torch.float32, device="cuda") + result = my_kernel(x) + torch.cuda.synchronize() + expected = x * 2.0 + assert torch.allclose(result, expected) +``` + +Run: +```bash +pytest tests/test_my_kernel.py -v +``` + +--- + +## Benchmark Pattern + +Create `benchmarks/bench_my_kernel.py`. Use `torch.profiler.profile` with `ProfilerActivity.CPU` and `ProfilerActivity.CUDA`. Compare the JIT kernel against a naive PyTorch implementation and report speedup. + +Run: +```bash +python benchmarks/bench_my_kernel.py --num-elements 1000000 +``` + +--- + +## Checklist for a New Kernel + +- [ ] `.cuh` / `.cpp` + `.hpp` kernel source created +- [ ] `TensorMatcher` validates all tensor arguments (shape, dtype, device) +- [ ] No host-side dereference of device pointers +- [ ] Python `@jit` wrapper created with correct `cuda_wrappers` or `cpp_wrappers` +- [ ] Public API function added (allocates output, calls internal `_kernel`) +- [ ] Exported in `jit/__init__.py` +- [ ] JIT cache cleared after `.cuh` edits (`rm -rf ~/.cache/mllm_kernel/cuda_*`) +- [ ] Pytest test with `@pytest.mark.parametrize` and PyTorch reference +- [ ] Benchmark with `torch.profiler` (optional but recommended) + +--- + +## Common Pitfalls + +1. **Segfault from dereferencing device pointer on host** -- `tensor.data_ptr()` returns a GPU pointer for CUDA tensors. Never read its contents in host code. Use `TensorMatcher` for validation instead. +2. **Stale JIT cache** -- After editing `.cuh`, delete `~/.cache/mllm_kernel/cuda_*/`. The old `.so` will be reused otherwise. +3. 
**Missing `#include <hwy/targets.h>`** -- CPU kernels must include this inside `#if HWY_ONCE` to provide `GetChosenTarget` for the JIT-built module. +4. **`#pragma once` in Highway `.hpp`** -- Highway's `foreach_target` includes the file multiple times for different SIMD targets. `#pragma once` breaks this. +5. **Wrong wrapper name** -- CUDA uses short names (`"MyKernel::run"`); CPU uses fully qualified names (`"mllm_kernel::cpu::my_kernel<1>"`). +6. **Generator device mismatch in tests** -- `torch.randperm` needs a CUDA generator on CUDA; `torch.randint` only accepts CPU generators. Use separate generators. diff --git a/.claude/skills/install-pymllm/SKILL.md b/.claude/skills/install-pymllm/SKILL.md new file mode 100644 index 000000000..d9d637989 --- /dev/null +++ b/.claude/skills/install-pymllm/SKILL.md @@ -0,0 +1,73 @@ +--- +name: install-pymllm +description: Install the pymllm Python package. Asks the user whether to do a full build (with CMake C++ compilation) or a fast install (Python-only, skip CMake). Use when the user asks to install, set up, or reinstall pymllm. +--- + +# Install pymllm + +## Goal + +Help the user install the `pymllm` package with the right configuration for their use case. + +## Workflow + +### Step 1: Ask the user which install mode they want + +Use `AskUserQuestion` to present two options: + +**Full Install (with C++ build)** +- Compiles the C++ mllm runtime and FFI extension via CMake +- Required if the user needs mobile inference, model conversion with FFI, or CPU/QNN backends +- Slower (several minutes depending on the machine) +- Command: `pip wheel -v -w dist . 
&& pip install dist/*.whl --force-reinstall` + +**Fast Install (Python-only, skip CMake)** +- Skips the entire CMake build step +- Only installs the pure Python package +- Recommended for users who only use CUDA backends (FlashInfer, TileLang) and do not need the C++ mllm runtime +- Much faster (seconds) +- Command: `SKBUILD_WHEEL_CMAKE=false pip install -e .` + +### Step 2: Ask editable or non-editable + +Use `AskUserQuestion` to ask: + +- **Editable (`pip install -e .`)**: For active development. Python imports point to the source tree. Changes to `.py` files take effect immediately without reinstalling. +- **Non-editable (wheel)**: For stable usage. Installs a wheel into site-packages. + +### Step 3: Ask whether the user needs CUDA optional dependencies + +Use `AskUserQuestion` to ask whether the user needs CUDA support (FlashInfer, TileLang, pyzmq, etc.). + +This determines whether to append `[cuda]` to the install specifier (e.g. `pip install -e ".[cuda]"` instead of `pip install -e .`). + +**This applies to ALL install modes.** For fast-install users this is especially important since the CUDA packages are the primary compute backend. + +### Step 4: Execute the install + +Based on user choices, compose and run the appropriate command. The install specifier is either `.` or `".[cuda]"` depending on Step 3. + +| Mode | Editable | CUDA | Command | +|------|----------|------|---------| +| Full | Yes | No | `pip install -v -e .` | +| Full | Yes | Yes | `pip install -v -e ".[cuda]"` | +| Full | No | No | `pip wheel -v -w dist . && pip install dist/*.whl --force-reinstall` | +| Full | No | Yes | `pip wheel -v -w dist . && pip install dist/*.whl --force-reinstall && pip install "pymllm[cuda]"` | +| Fast | Yes | No | `SKBUILD_WHEEL_CMAKE=false pip install -e .` | +| Fast | Yes | Yes | `SKBUILD_WHEEL_CMAKE=false pip install -e ".[cuda]"` | +| Fast | No | No | `SKBUILD_WHEEL_CMAKE=false pip wheel -v -w dist . 
&& pip install dist/*.whl --force-reinstall` | +| Fast | No | Yes | `SKBUILD_WHEEL_CMAKE=false pip wheel -v -w dist . && pip install dist/*.whl --force-reinstall && pip install "pymllm[cuda]"` | + +### Step 5: Post-install for editable + full build + +If the user chose **editable + full build**, the compiled `.so` files live in a build directory (e.g. `build/bin/`), not in the source tree. The Python code at `pymllm/__init__.py` looks for libraries at `pymllm/lib/MllmFFIExtension.so`. A symlink is needed to bridge this gap. + +**Invoke the `/link-pymllm-lib` skill** to help the user set up the symlink. + +## Important Notes + +- The project root must contain `pyproject.toml` with `scikit-build-core` as the build backend. +- The `wheel.cmake = true` flag in `pyproject.toml` controls whether CMake runs. The env var `SKBUILD_WHEEL_CMAKE=false` overrides it at install time without modifying the file. +- For non-editable full builds, the `.so` files are bundled inside the wheel automatically — no symlink needed. +- For fast installs, `pymllm.is_mobile_available()` will return `False` since no C++ libraries are present. This is expected. +- The `[cuda]` optional dependencies are defined in `pyproject.toml` under `[project.optional-dependencies]`. diff --git a/.claude/skills/link-pymllm-lib/SKILL.md b/.claude/skills/link-pymllm-lib/SKILL.md new file mode 100644 index 000000000..b8d9760f2 --- /dev/null +++ b/.claude/skills/link-pymllm-lib/SKILL.md @@ -0,0 +1,83 @@ +--- +name: link-pymllm-lib +description: Create or update the pymllm/lib symlink to point to a C++ build directory's bin/ folder. Required after editable installs with C++ builds so that Python can find the compiled .so libraries. Use when the user asks to link, fix, or set up pymllm native libraries. 
+--- + +# Link pymllm lib + +## Goal + +Create a symlink at `pymllm/lib` pointing to the correct build output directory so that an editable-installed pymllm can load the compiled C++ shared libraries (`MllmFFIExtension.so`, `libMllmRT.so`, etc.). + +## Background + +When pymllm is installed in editable mode (`pip install -e .`), Python imports from the source tree directly. The C++ libraries are compiled into `<build-dir>/bin/` by CMake, but pymllm looks for them at `pymllm/lib/`. A symlink bridges this gap: + +``` +pymllm/lib -> <project-root>/<build-dir>/bin +``` + +## Workflow + +### Step 1: Detect available build directories + +Scan the project root for directories matching the pattern `build*/bin/` that contain `MllmFFIExtension.so` (or `.dylib` on macOS). List all valid candidates. + +Common build directories and their corresponding platforms: + +| Build directory | Platform / Config | Typical build command | +|----------------|-------------------|----------------------| +| `build/bin` | X86 CPU only | `python task.py tasks/build_x86.yaml` | +| `build-x86-cuda/bin` | X86 + CUDA | `python task.py tasks/build_x86_cuda.yaml` | +| `build-qnn-aot/bin` | X86 + QNN AOT | `python task.py tasks/build_x86_qnn_aot.yaml` | +| `build-android-arm64-v8a-qnn/bin` | Android ARM + QNN | `python task.py tasks/build_android_qnn.yaml` | + +### Step 2: Ask the user which build to link + +Use `AskUserQuestion` to let the user pick from the detected build directories. Show each option with its path and the platform it corresponds to. + +If no build directories with `.so` files are found, inform the user they need to build first: + +```bash +pip install -r requirements.txt +python task.py tasks/build_x86.yaml # or another build task +``` + +### Step 3: Check existing symlink + +Before creating a new symlink, check if `pymllm/lib` already exists: + +- If it's a symlink, show where it currently points and confirm replacement. +- If it's a real directory, warn the user and ask before removing it. 
+- If it doesn't exist, proceed directly. + +### Step 4: Create the symlink + +```bash +ln -sfn <project-root>/<build-dir>/bin <project-root>/pymllm/lib +``` + +Use `ln -sfn` to atomically replace any existing symlink. + +### Step 5: Verify + +After creating the symlink, verify by checking that the target `.so` file is accessible: + +```bash +ls -la pymllm/lib/MllmFFIExtension.so +``` + +Then run a quick Python check: + +```bash +python -c "import pymllm; print('mobile available:', pymllm.is_mobile_available())" +``` + +If `is_mobile_available()` returns `True`, the link is correct. + +## Important Notes + +- The symlink target must be an **absolute path** for reliability. +- On macOS, the library extension is `.dylib` instead of `.so`. +- Android build directories (e.g., `build-android-arm64-v8a-qnn/bin`) contain ARM binaries that cannot run on x86 hosts. Warn the user if they select one of these on a non-ARM machine. +- If the user has multiple build directories, they can re-run this skill anytime to switch which build pymllm uses. diff --git a/.claude/skills/update-codeowners/SKILL.md b/.claude/skills/update-codeowners/SKILL.md new file mode 100644 index 000000000..286667045 --- /dev/null +++ b/.claude/skills/update-codeowners/SKILL.md @@ -0,0 +1,44 @@ +--- +name: update-codeowners +description: Updates CODEOWNERS entries safely with consistent path and owner formatting. Use when the user asks to add, remove, or modify CODEOWNERS rules, ownership mappings, reviewers, or module maintainers. +--- + +# Update CODEOWNERS + +## Goal +Maintain `CODEOWNERS` accurately while preserving the repository's existing section/comment style. + +## Workflow +1. Read the current `CODEOWNERS` file before editing. +2. Identify requested changes as one of: + - Add new path rule + - Modify owners for existing path rule + - Remove obsolete path rule + - Reorganize section comments (only if requested) +3. Update rules in place instead of creating duplicates for the same path. +4. 
Keep existing section headers and comment style unless the user asks to refactor structure. +5. Return a concise changelog describing which paths were added, changed, or removed. + +## Rule Format +- Use one rule per line: ` ...` +- Owners must be GitHub handles prefixed with `@`. +- Keep path style consistent with the file (in this repo, path patterns typically start with `/`). +- Do not leave rules with empty owner lists. + +## Editing Guidelines +- Prefer minimal edits near related sections. +- If a path already exists, update that line instead of adding a second conflicting line. +- If a new rule logically belongs to an existing section, place it in that section. +- Preserve human-readable grouping and blank lines. +- Keep comments intact unless they are clearly outdated and the user asked for cleanup. + +## Validation Checklist +- [ ] Every non-comment, non-empty line has at least one owner. +- [ ] Every owner token starts with `@`. +- [ ] No accidental duplicate rule for the exact same path pattern. +- [ ] Existing comments/sections were preserved unless explicitly changed. + +## Example Requests +- "Add `/mllm/models/new_model/ @alice @bob` under models." +- "Change `/core/Storage` owner to `@team-core`." +- "Remove ownership rule for deprecated path `/legacy/`." 
diff --git a/.codespellrc b/.codespellrc index 9ddb9d851..bbf02bd17 100644 --- a/.codespellrc +++ b/.codespellrc @@ -1,3 +1,3 @@ [codespell] -ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao +ignore-words-list = ans, als, hel, boostrap, childs, te, vas, hsa, ment, cann, thi, makro, wil, rouge, PRIS, bfloat, constexpr, cuda, dlpack, expt, forceinline, ifndef, linalg, LPBQ, mllm, pymllm, Quantizaton, Qwen, ROCM, silu, torchao, flashinfer skip = *.json,*.jsonl,*.patch,*.txt diff --git a/.gitignore b/.gitignore index 7397d6ecc..b441a62eb 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ .cache/ .tmp/ compile_commands.json -.claude/ +settings.local.json # MLLM Team Specific tasks/mllmteam* @@ -13,7 +13,7 @@ tasks/mllmteam* # Building files and binary build*/ -install*/ +/install*/ mllm-sdk-*/ mllm-install-*/ diff --git a/README-ZH.md b/README-ZH.md index e33b718d2..b5592d1e9 100644 --- a/README-ZH.md +++ b/README-ZH.md @@ -17,6 +17,7 @@ mllm ## 最新动态 +- [2026 年 3 月 18 日] 🔥🔥🔥 `pymllm` 已支持在 Jetson Orin 和 Jetson Thor 设备上使用 CUDA(实验特性,仍在持续开发中)。 - [2026 年 2 月 3 日] 🔥🔥🔥 MLLM Qnn AOT 已支持在 NPU 上全图执行![快速开始](https://ubiquitouslearning.github.io/mllm/qnn_backend/aot_execute.html), [技术报告](https://chenghuawang.github.io/News/2026-01-29-mllm-qnn-aot-support/) - [2025 年 11 月 27 日] Android Demo 更新:通过一种全新的 In-App Go 服务架构,在 Android 上实现了 Qwen3 和 DeepSeek-OCR 的稳定流式推理。 - [2025 年 11 月 23 日] MLLM v2 发布! 
@@ -78,6 +79,7 @@ mllm 框架可以与主流社区框架的模型检查点无缝集成。通过 ml |-----------------------------------------------------------------------------|------|-----------------------| | [Qwen3-0.6B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-0.6B-w4a32kai) | | | [Qwen3-1.7B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-1.7B-w4a8-i8mm-kai) | [W4A16-SM8650](https://modelscope.cn/models/mllmTeam/Qwen3-1.7B-Qnn-AOT-SM8650/summary) | +| [Qwen3-4B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-4B-w4a8-i8mm-kai) | | | [DeepSeek-OCR](https://github.com/deepseek-ai/DeepSeek-OCR) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/DeepSeek-OCR-w4a8-i8mm-kai) | | | [SmolLM3](https://huggingface.co/blog/smollm3)| [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/SmolLM3-3B-w4a8-i8mm-kai) | | | [Qwen2-VL-2B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-2B-Instruct-w4a32kai) || diff --git a/README.md b/README.md index 92dc29a6b..decfbf68a 100644 --- a/README.md +++ b/README.md @@ -17,6 +17,7 @@ mllm ## Latest News +- [2026 Mar 18] 🔥🔥🔥 `pymllm` now supports CUDA on Jetson Orin and Jetson Thor devices (experimental; still under active development). - [2026 Feb 03] 🔥🔥🔥 MLLM Qnn AOT Support for Full Graph Execution on NPU! [Quick Start](https://ubiquitouslearning.github.io/mllm/qnn_backend/aot_execute.html), [Technical Report](https://chenghuawang.github.io/News/2026-01-29-mllm-qnn-aot-support-en/) - [2025 Nov 27] Android Demo Update: Enabled stable Qwen3 and DeepSeek-OCR streaming on Android via a novel In-App Go Server Architecture. - [2025 Nov 23] MLLM v2 released! 
@@ -76,6 +77,7 @@ The mllm framework integrates seamlessly with popular community frameworks' chec |-----------------------------------------------------------------------------|------|-----------------------| | [Qwen3-0.6B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-0.6B-w4a32kai) | | | [Qwen3-1.7B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-1.7B-w4a8-i8mm-kai) | [W4A16-SM8650](https://modelscope.cn/models/mllmTeam/Qwen3-1.7B-Qnn-AOT-SM8650/) | +| [Qwen3-4B](https://github.com/QwenLM/Qwen3) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen3-4B-w4a8-i8mm-kai) | | | [DeepSeek-OCR](https://github.com/deepseek-ai/DeepSeek-OCR) | [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/DeepSeek-OCR-w4a8-i8mm-kai) | | | [SmolLM3](https://huggingface.co/blog/smollm3)| [✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/SmolLM3-3B-w4a8-i8mm-kai) | | | [Qwen2-VL-2B-Instruct](https://qwenlm.github.io/zh/blog/qwen2-vl/)|[✔️ w4a8](https://www.modelscope.cn/models/mllmTeam/Qwen2-VL-2B-Instruct-w4a32kai) || @@ -308,6 +310,15 @@ mllm provides a set of model converters to convert models from other popular mod bash ./scripts/install_pymllm.sh ``` +> **Tip for CUDA-only users:** If you only use CUDA backends (e.g., FlashInfer, TileLang) and do not need the C++ mllm runtime, you can skip the CMake build to speed up installation significantly: +> +> ```shell +> SKBUILD_WHEEL_CMAKE=false pip install -e . +> pip install pymllm[cuda] +> ``` +> +> This installs only the pure Python package without compiling the C++ components. + **future:** Once PyPI approves the creation of the mllm organization, we will publish it there. Afterwards, you can use the command below to install it in the future. 
diff --git a/assets/pymllm-arch.png b/assets/pymllm-arch.png new file mode 100644 index 000000000..37c48b2a0 Binary files /dev/null and b/assets/pymllm-arch.png differ diff --git a/docs/index.rst b/docs/index.rst index 1f06ef487..3db7d58e2 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -246,6 +246,17 @@ mllm provides a set of model converters to convert models from other popular mod bash ./scripts/install_pymllm.sh +.. tip:: + + **For CUDA-only users:** If you only use CUDA backends (e.g., FlashInfer, TileLang) and do not need the C++ mllm runtime, you can skip the CMake build to speed up installation significantly: + + .. code-block:: shell + + SKBUILD_WHEEL_CMAKE=false pip install -e . + pip install pymllm[cuda] + + This installs only the pure Python package without compiling the C++ components. + **future:** Once PyPI approves the creation of the mllm organization, we will publish it there. Afterwards, you can use the command below to install it in the future. diff --git a/docs/qnn_backend/aot_execute.rst b/docs/qnn_backend/aot_execute.rst index 6b03834c0..7fd2a9a6b 100644 --- a/docs/qnn_backend/aot_execute.rst +++ b/docs/qnn_backend/aot_execute.rst @@ -60,6 +60,10 @@ Taking ``qwen3_qnn_aot`` as an example, the detailed steps are as follows. pip install -e . # link lib to pymllm's dir, so that tvm ffi can find the lib + # + # NOTE:! build x86 qualcomm aot first ! + source /bin/envsetup.sh + python task.py tasks/build_x86_qnn_aot.yaml ln -s /bin/ mllm/pymllm/lib @@ -82,6 +86,7 @@ Taking ``qwen3_qnn_aot`` as an example, the detailed steps are as follows. .. 
code-block:: shell # In the mllm-v2 project root directory + source /bin/envsetup.sh python task.py tasks/build_x86_qnn_aot.yaml # Run the compiler program diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 3df37bddc..0f025fcf6 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(qwen2vl) add_subdirectory(qwen2vl_tracer) add_subdirectory(qwen2_5vl) add_subdirectory(qwen2_5vl_tracer) +add_subdirectory(minicpm_o45) add_subdirectory(llama) add_subdirectory(minicpm_o) add_subdirectory(minicpm4) diff --git a/examples/minicpm_o45/CMakeLists.txt b/examples/minicpm_o45/CMakeLists.txt new file mode 100644 index 000000000..bf30aa52b --- /dev/null +++ b/examples/minicpm_o45/CMakeLists.txt @@ -0,0 +1,11 @@ +add_executable(mllm-minicpm-o45-runner main.cpp) +target_link_libraries(mllm-minicpm-o45-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-minicpm-o45-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-minicpm-o45-runner-dbg main_dbg.cpp) +target_link_libraries(mllm-minicpm-o45-runner-dbg PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-minicpm-o45-runner-dbg PRIVATE ${MLLM_INCLUDE_DIR}) + +# add_executable(mllm-minicpm-o45-runner-python main_python.cpp) +# target_link_libraries(mllm-minicpm-o45-runner-python PRIVATE MllmRT MllmCPUBackend) +# target_include_directories(mllm-minicpm-o45-runner-python PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/minicpm_o45/config_minicpm_o45.json b/examples/minicpm_o45/config_minicpm_o45.json new file mode 100644 index 000000000..e432e2355 --- /dev/null +++ b/examples/minicpm_o45/config_minicpm_o45.json @@ -0,0 +1,285 @@ +{ + "architectures": [ + "MiniCPMO" + ], + "version": "4.5", + "attention_bias": false, + "attention_dropout": 0.0, + "audio_chunk_length": 1.0, + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "openai/whisper-medium", + "activation_dropout": 0.0, + "activation_function": "gelu", + 
"apply_spec_augment": false, + "architectures": [ + "MiniCPMWhisperEncoder" + ], + "attention_dropout": 0.0, + "begin_suppress_tokens": [ + 220, + 50257 + ], + "bos_token_id": 50257, + "classifier_proj_size": 256, + "d_model": 1024, + "decoder_attention_heads": 16, + "decoder_ffn_dim": 4096, + "decoder_layerdrop": 0.0, + "decoder_layers": 24, + "decoder_start_token_id": 50258, + "dropout": 0.0, + "encoder_attention_heads": 16, + "encoder_ffn_dim": 4096, + "encoder_layerdrop": 0.0, + "encoder_layers": 24, + "eos_token_id": 50257, + "forced_decoder_ids": [ + [ + 1, + 50259 + ], + [ + 2, + 50359 + ], + [ + 3, + 50363 + ] + ], + "init_std": 0.02, + "mask_feature_length": 10, + "mask_feature_min_masks": 0, + "mask_feature_prob": 0.0, + "mask_time_length": 10, + "mask_time_min_masks": 2, + "mask_time_prob": 0.05, + "max_length": 448, + "max_source_positions": 1500, + "max_target_positions": 448, + "median_filter_width": 7, + "model_type": "whisper", + "num_hidden_layers": 24, + "num_mel_bins": 80, + "pad_token_id": 50257, + "scale_embedding": false, + "suppress_tokens": [ + 1, + 2, + 7, + 8, + 9, + 10, + 14, + 25, + 26, + 27, + 28, + 29, + 31, + 58, + 59, + 60, + 61, + 62, + 63, + 90, + 91, + 92, + 93, + 359, + 503, + 522, + 542, + 873, + 893, + 902, + 918, + 922, + 931, + 1350, + 1853, + 1982, + 2460, + 2627, + 3246, + 3253, + 3268, + 3536, + 3846, + 3961, + 4183, + 4667, + 6585, + 6647, + 7273, + 9061, + 9383, + 10428, + 10929, + 11938, + 12033, + 12331, + 12562, + 13793, + 14157, + 14635, + 15265, + 15618, + 16553, + 16604, + 18362, + 18956, + 20075, + 21675, + 22520, + 26130, + 26161, + 26435, + 28279, + 29464, + 31650, + 32302, + 32470, + 36865, + 42863, + 47425, + 49870, + 50254, + 50258, + 50358, + 50359, + 50360, + 50361, + 50362 + ], + "torch_dtype": "float32", + "use_cache": true, + "use_weighted_layer_sum": false, + "vocab_size": 51865 + }, + "audio_pool_step": 5, + "auto_map": { + "AutoConfig": "configuration_minicpmo.MiniCPMOConfig", + "AutoModel": 
"modeling_minicpmo.MiniCPMO", + "AutoModelForCausalLM": "modeling_minicpmo.MiniCPMO" + }, + "batch_vision_input": true, + "bos_token_id": 151643, + "drop_vision_last_layer": false, + "eos_token_id": 151645, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 4096, + "image_size": 448, + "init_audio": true, + "init_tts": true, + "init_vision": true, + "initializer_range": 0.02, + "intermediate_size": 12288, + "listen_speak_type": "asr", + "max_position_embeddings": 40960, + "max_window_layers": 36, + "model_type": "minicpmo", + "num_attention_heads": 32, + "num_hidden_layers": 36, + "num_key_value_heads": 8, + "patch_size": 14, + "query_num": 64, + "rms_norm_eps": 1e-06, + "rope_scaling": null, + "rope_theta": 1000000, + "slice_config": { + "max_slice_nums": 1, + "model_type": "minicpmv", + "patch_size": 14, + "scale_resolution": 448 + }, + "slice_mode": true, + "sliding_window": null, + "stream_input": true, + "tie_word_embeddings": false, + "torch_dtype": "bfloat16", + "transformers_version": "4.51.0", + "tts_config": { + "_attn_implementation_autoset": true, + "attention_type": "full_attention", + "attn_implementation": "sdpa", + "audio_bos_token_id": 151687, + "audio_tokenizer_sample_rate": 16000, + "audio_tokenizer_type": "s3tokenizer", + "aug_layer_loss_weight": false, + "aug_loss_weight": false, + "backbone_model": "llama", + "condition_type": "hidden_text_merge", + "cosyvoice_config_path": null, + "cosyvoice_model_dir": null, + "filter_tts_loss": false, + "hidden_act": "silu", + "hidden_size": 768, + "interleaved": false, + "intermediate_size": 3072, + "llm_dim": 4096, + "llm_dim_model_base": 256, + "llm_down_scale": false, + "llm_hidden_size": 4096, + "llm_intermediate_size": 768, + "long_weight": 0.1, + "max_position_embeddings": 4096, + "model_type": "minicpmtts", + "normalize_projected_hidden": true, + "num_attention_heads": 12, + "num_audio_tokens": 6562, + "num_hidden_layers": 20, + "num_key_value_heads": 12, + "num_mel_bins": 100, + 
"num_text_tokens": 152064, + "num_vq": 1, + "projector_type": "mlp", + "recomputed_chunks": 1, + "s3_stream_chunk_size": 25, + "s3_stream_generate": false, + "s3_stream_n_timesteps": 10, + "s3_stream_prelook_size": 3, + "short_weight": 0.1, + "streaming": false, + "streaming_audio_chunk_size": 50, + "streaming_sliding_window": false, + "streaming_sliding_window_audio_frame_rate": 50, + "streaming_sliding_window_audio_init_text_length": 10, + "streaming_sliding_window_audio_window_size": 300, + "streaming_sliding_window_average_speed": 5, + "streaming_sliding_window_fast_speed": 7, + "streaming_sliding_window_max_text_len": 500, + "streaming_sliding_window_slow_speed": 3, + "streaming_sliding_window_text_window_size": 50, + "streaming_text_chunk_max": 7, + "streaming_text_chunk_min": 3, + "streaming_text_reserved_len": 300, + "text_eos_token_id": 151692, + "tts_filter_loss_fix": false, + "use_llm_hidden_state": false, + "use_text": true, + "window_size": 2 + }, + "use_cache": true, + "use_image_id": true, + "use_sliding_window": false, + "vision_batch_size": 16, + "vision_config": { + "_attn_implementation_autoset": true, + "attention_dropout": 0.0, + "hidden_act": "gelu_pytorch_tanh", + "hidden_size": 1152, + "image_size": 980, + "intermediate_size": 4304, + "layer_norm_eps": 1e-06, + "model_type": "siglip_vision_model", + "num_attention_heads": 16, + "num_channels": 3, + "num_hidden_layers": 27, + "patch_size": 14 + }, + "vocab_size": 151748 +} diff --git a/examples/minicpm_o45/main.cpp b/examples/minicpm_o45/main.cpp new file mode 100644 index 000000000..3c7cdf6ab --- /dev/null +++ b/examples/minicpm_o45/main.cpp @@ -0,0 +1,316 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/minicpm_o45/configuration_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp" +#include "mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/token2wav_prompt_cache.hpp" + +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").def(""); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version: v1/v2").def("v1"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer path (tokenizer.json)").def(""); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").def(""); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def("Describe the input."); + auto& image_path = Argparse::add("-i|--image").help("Optional image path").def(""); + auto& audio_path = Argparse::add("-a|--audio").help("Optional audio path (wav)").def(""); + auto& ref_audio_path = Argparse::add("--ref_audio") + .help("Optional reference audio path for system voice-cloning prompt (wav).") + .def(""); + auto& ref_audio_prompt_prefix = Argparse::add("--ref_audio_prompt_prefix") + .help("System prompt prefix placed before reference audio.") + .def("Clone the voice in the provided audio prompt."); + auto& ref_audio_prompt_suffix = Argparse::add("--ref_audio_prompt_suffix") + .help("System prompt suffix placed after reference audio.") + .def("As an assistant, you will speak using this voice style."); + auto& generate_tts_tokens = Argparse::add("-gt|--generate_tts_tokens") + .help("Generate TTS tokens (text->tts-token stage, no waveform)") + .def(false); + auto& 
text_max_new_tokens = Argparse::add("--text_max_new_tokens").help("Max new text tokens").def(512); + auto& tts_max_new_tokens = Argparse::add("--tts_max_new_tokens").help("Max new TTS tokens").def(1024); + auto& tts_min_new_tokens = Argparse::add("--tts_min_new_tokens").help("Min new TTS tokens").def(50); + auto& tts_force_no_stop = Argparse::add("--tts_force_no_stop").help("Disable TTS EOS stopping").def(false); + auto& tts_temperature = Argparse::add("--tts_temperature").help("TTS sampling temperature").def(0.8f); + auto& tts_top_k = Argparse::add("--tts_top_k").help("TTS top-k sampling (<=0 disables)").def(25); + auto& tts_top_p = Argparse::add("--tts_top_p").help("TTS top-p sampling (<=0 or >=1 disables)").def(0.85f); + auto& tts_repetition_penalty = + Argparse::add("--tts_repetition_penalty").help("TTS repetition penalty (1.0 disables)").def(1.05f); + auto& tts_repetition_window = + Argparse::add("--tts_repetition_window").help("TTS repetition window size in generated tokens").def(16); + auto& tts_greedy = Argparse::add("--tts_greedy").help("Use greedy decoding for TTS tokens").def(false); + auto& tts_tokens_out = Argparse::add("--tts_tokens_out").help("Output path for generated TTS token ids").def(""); + auto& tts_tokens_in = + Argparse::add("--tts_tokens_in").help("Input path for pre-generated TTS token ids (one per line or whitespace).").def(""); + auto& tts_wav_out = Argparse::add("--tts_wav_out") + .help("Output wav path. 
If set, run native C++ token2wav.") + .def(""); + auto& tts_token2wav_model_path = Argparse::add("--tts_token2wav_model_path") + .help("Path to token2wav .mllm (if empty, fallback to --model_path).") + .def(""); + auto& tts_token2wav_model_version = Argparse::add("--tts_token2wav_model_version") + .help("token2wav model version: v1/v2") + .def("v1"); + auto& tts_prompt_cache = Argparse::add("--tts_prompt_cache") + .help("Path to fixed prompt cache generated by export_prompt_cache.py") + .def(""); + auto& tts_token2wav_n_timesteps = Argparse::add("--tts_token2wav_n_timesteps") + .help("Flow diffusion steps for native token2wav") + .def(10); + auto& debug_progress = Argparse::add("--debug_progress").help("Print step-level debug progress.").def(false); + auto& debug_interval = + Argparse::add("--debug_interval").help("Token step interval for debug progress logs.").def(16); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; } + + auto token2wav_model_path = tts_token2wav_model_path.get().empty() ? 
model_path.get() : tts_token2wav_model_path.get(); + mllm::ModelFileVersion token2wav_file_version = mllm::ModelFileVersion::kV1; + if (tts_token2wav_model_version.get() == "v2") { token2wav_file_version = mllm::ModelFileVersion::kV2; } + + auto run_native_token2wav = !tts_wav_out.get().empty(); + if (run_native_token2wav && tts_prompt_cache.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_prompt_cache is required when --tts_wav_out is set."); + } + + auto debug_t0 = std::chrono::steady_clock::now(); + auto debug_log = [&](const std::string& msg) { + if (!debug_progress.get()) { return; } + auto now = std::chrono::steady_clock::now(); + auto sec = std::chrono::duration_cast(now - debug_t0).count() / 1000.0; + fmt::print("[debug +{:.3f}s] {}\n", sec, msg); + }; + + if (!tts_tokens_in.get().empty()) { + if (!run_native_token2wav) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_wav_out is required when --tts_tokens_in is set."); + } + if (token2wav_model_path.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Missing token2wav model path (--tts_token2wav_model_path or --model_path)."); + } + + std::ifstream ifs(tts_tokens_in.get()); + if (!ifs.is_open()) { MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open token file: {}", tts_tokens_in.get()); } + std::vector token_ids; + for (std::string line; std::getline(ifs, line);) { + if (line.empty()) { continue; } + std::stringstream ss(line); + while (!ss.eof()) { + int64_t token = 0; + ss >> token; + if (!ss.fail()) { token_ids.push_back(token); } + } + } + if (token_ids.empty()) { MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No token id found in {}", tts_tokens_in.get()); } + + fmt::print("Loaded {} TTS token IDs from {}\n", token_ids.size(), tts_tokens_in.get()); + debug_log("Loading token2wav model and prompt cache..."); + auto token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = 
mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + mllm::shutdownContext(); + return 0; + } + + if (model_path.get().empty() || tokenizer_path.get().empty() || config_path.get().empty()) { + Argparse::printHelp(); + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "Missing required arguments: --model_path, --tokenizer_path, --config_path"); + } + + auto cfg = mllm::models::minicpm_o45::MiniCPMO45Config(config_path.get()); + + debug_log("Loading tokenizer and model modules..."); + auto tokenizer = mllm::models::minicpm_o45::MiniCPMO45Tokenizer(tokenizer_path.get(), cfg.vision_patch_size, cfg.audio_pool_step); + auto model = mllm::models::minicpm_o45::MiniCPMO45ForCausalLM(cfg); + + debug_log("Loading model parameters..."); + auto param = mllm::load(model_path.get(), file_version); + model.llm_.load(param); + model.vpm_.load(param); + model.resampler_.load(param); + model.apm_.load(param); + model.audio_projection_layer_.load(param); + if (generate_tts_tokens.get()) { model.tts_.loadFromParameter(param); } + debug_log("Model parameters loaded."); + + mllm::models::minicpm_o45::MiniCPMO45Message message; + message.prompt = prompt.get(); + message.img_file_path = image_path.get(); + message.audio_file_path = audio_path.get(); + message.ref_audio_file_path = ref_audio_path.get(); + message.ref_audio_prompt_prefix = ref_audio_prompt_prefix.get(); + 
message.ref_audio_prompt_suffix = ref_audio_prompt_suffix.get(); + + auto inputs = tokenizer.convertMessage(message, generate_tts_tokens.get()); + debug_log("Tokenizer convertMessage finished."); + + fmt::print("\n{:*^60}\n", " MiniCPM-o-4_5 CLI "); + fmt::print("Prompt: {}\n", message.prompt); + if (!message.img_file_path.empty()) { fmt::print("Image : {}\n", message.img_file_path); } + if (!message.audio_file_path.empty()) { fmt::print("Audio : {}\n", message.audio_file_path); } + if (!message.ref_audio_file_path.empty()) { fmt::print("RefAudio : {}\n", message.ref_audio_file_path); } + + if (!generate_tts_tokens.get()) { + fmt::print("\nResponse: "); + for (auto& step : model.chat(inputs)) { + std::wcout << tokenizer.detokenize(step.cur_token_id) << std::flush; + } + fmt::print("\n"); + } else { + auto tts_eos_id = tokenizer.lookupTokenId(L"<|tts_eos|>"); + auto im_end_id = tokenizer.lookupTokenId(L"<|im_end|>"); + auto eot_id = tokenizer.lookupTokenId(L"<|endoftext|>"); + + std::vector stop_token_ids = { + tts_eos_id, + im_end_id, + eot_id, + cfg.eos_token_id, + }; + + debug_log("Start text generation for TTS conditioning..."); + auto text_out = model.generateTextWithHidden( + inputs, text_max_new_tokens.get(), stop_token_ids, false, 1.0f, 0, 0.0f, + [&](int32_t step, int64_t token_id) { + auto interval = std::max(debug_interval.get(), 1); + if (debug_progress.get() && (step == 1 || (step % interval) == 0)) { + debug_log(fmt::format("Text generation step {} (token_id={})", step, token_id)); + } + }); + debug_log(fmt::format("Text generation done, generated_tokens={}", text_out.generated_tokens.size())); + + fmt::print("\nGenerated text tokens: {}\n", text_out.generated_tokens.size()); + fmt::print("Text (for TTS conditioning): "); + + std::vector tts_text_tokens; + std::vector tts_hidden_states; + for (size_t i = 0; i < text_out.aligned_tokens.size() && i < text_out.aligned_hidden_states.size(); ++i) { + auto token_id = text_out.aligned_tokens[i]; + if 
(token_id == tts_eos_id || token_id == im_end_id || token_id == eot_id || token_id == cfg.eos_token_id) { break; } + tts_text_tokens.push_back(token_id); + tts_hidden_states.push_back(text_out.aligned_hidden_states[i]); + std::wcout << tokenizer.detokenize(token_id) << std::flush; + } + fmt::print("\n"); + + if (tts_text_tokens.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "No text token available before <|tts_eos|>/<|im_end|>; cannot build TTS condition."); + } + + auto condition_embeds = model.tts_.makeConditionEmbeddings(tts_text_tokens, tts_hidden_states); + if (condition_embeds.isNil()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Failed to build TTS conditioning embeddings."); + } + debug_log(fmt::format("Built TTS condition embeddings from {} text tokens.", tts_text_tokens.size())); + + mllm::models::minicpm_o45::MiniCPMO45TTSGenerationConfig tts_cfg; + tts_cfg.max_new_tokens = tts_max_new_tokens.get(); + tts_cfg.min_new_tokens = tts_min_new_tokens.get(); + tts_cfg.force_no_stop = tts_force_no_stop.get(); + tts_cfg.do_sample = !tts_greedy.get(); + tts_cfg.temperature = {tts_temperature.get()}; + tts_cfg.top_k = tts_top_k.get(); + tts_cfg.top_p = tts_top_p.get(); + tts_cfg.repetition_penalty = tts_repetition_penalty.get(); + tts_cfg.repetition_penalty_window = tts_repetition_window.get(); + tts_cfg.debug_interval = std::max(debug_interval.get(), 1); + if (debug_progress.get()) { + tts_cfg.step_callback = [&](int32_t step, const std::vector& tokens, bool has_eos) { + auto first_token = tokens.empty() ? -1 : tokens[0]; + debug_log(fmt::format("TTS generation step {} (first_vq_token={}, has_eos={})", step, first_token, + has_eos ? 
"true" : "false")); + }; + } + + debug_log("Start TTS token generation..."); + auto tts_out = model.tts_.generate(condition_embeds, tts_cfg); + debug_log("TTS token generation finished."); + if (tts_out.new_ids.isNil()) { + fmt::print("Generated TTS tokens: 0\n"); + } else { + auto token_count = tts_out.new_ids.shape()[1]; + fmt::print("Generated TTS tokens: {} (finished={})\n", token_count, tts_out.finished ? "true" : "false"); + + std::vector token_ids; + token_ids.reserve(token_count); + auto tts_ids = tts_out.new_ids.contiguous(); + const auto* tts_ids_ptr = tts_ids.ptr(); + auto num_vq = tts_ids.shape()[2]; + for (int32_t i = 0; i < token_count; ++i) { token_ids.push_back(tts_ids_ptr[static_cast(i) * num_vq]); } + + fmt::print("TTS token IDs:\n"); + for (size_t i = 0; i < token_ids.size(); ++i) { + fmt::print("{}{}", token_ids[i], (i + 1 == token_ids.size() ? "\n" : " ")); + } + + if (!tts_tokens_out.get().empty()) { + std::ofstream ofs(tts_tokens_out.get()); + if (!ofs.is_open()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open output file: {}", tts_tokens_out.get()); + } + for (auto id : token_ids) { ofs << std::to_string(id) << '\n'; } + fmt::print("Saved TTS token ids to {}\n", tts_tokens_out.get()); + debug_log(fmt::format("Saved token ids to {}", tts_tokens_out.get())); + } + + if (!tts_wav_out.get().empty()) { + debug_log("Loading token2wav model and prompt cache..."); + auto token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter 
wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + } + } + } + + model.perfSummary(); + mllm::memoryReport(); +}) diff --git a/examples/minicpm_o45/main_dbg.cpp b/examples/minicpm_o45/main_dbg.cpp new file mode 100644 index 000000000..a3d4a78af --- /dev/null +++ b/examples/minicpm_o45/main_dbg.cpp @@ -0,0 +1,325 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/minicpm_o45/configuration_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp" +#include "mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp" +#include "mllm/models/minicpm_o45/token2wav_prompt_cache.hpp" + +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +//MLLM_MAIN({ +int main(int argc, char** argv) { + ::mllm::__setup_signal_handler(); + ::mllm::initializeContext(); + + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").def(""); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version: v1/v2").def("v1"); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer path (tokenizer.json)").def(""); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").def(""); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def("Describe the input."); + auto& image_path = Argparse::add("-i|--image").help("Optional image path").def(""); + auto& audio_path = Argparse::add("-a|--audio").help("Optional audio path (wav)").def(""); + auto& ref_audio_path = Argparse::add("--ref_audio") + .help("Optional 
reference audio path for system voice-cloning prompt (wav).") + .def(""); + auto& ref_audio_prompt_prefix = Argparse::add("--ref_audio_prompt_prefix") + .help("System prompt prefix placed before reference audio.") + .def("Clone the voice in the provided audio prompt."); + auto& ref_audio_prompt_suffix = Argparse::add("--ref_audio_prompt_suffix") + .help("System prompt suffix placed after reference audio.") + .def("As an assistant, you will speak using this voice style."); + auto& generate_tts_tokens = Argparse::add("-gt|--generate_tts_tokens") + .help("Generate TTS tokens (text->tts-token stage, no waveform)") + .def(false); + auto& text_max_new_tokens = Argparse::add("--text_max_new_tokens").help("Max new text tokens").def(512); + auto& tts_max_new_tokens = Argparse::add("--tts_max_new_tokens").help("Max new TTS tokens").def(1024); + auto& tts_min_new_tokens = Argparse::add("--tts_min_new_tokens").help("Min new TTS tokens").def(50); + auto& tts_force_no_stop = Argparse::add("--tts_force_no_stop").help("Disable TTS EOS stopping").def(false); + auto& tts_temperature = Argparse::add("--tts_temperature").help("TTS sampling temperature").def(0.8f); + auto& tts_top_k = Argparse::add("--tts_top_k").help("TTS top-k sampling (<=0 disables)").def(25); + auto& tts_top_p = Argparse::add("--tts_top_p").help("TTS top-p sampling (<=0 or >=1 disables)").def(0.85f); + auto& tts_repetition_penalty = + Argparse::add("--tts_repetition_penalty").help("TTS repetition penalty (1.0 disables)").def(1.05f); + auto& tts_repetition_window = + Argparse::add("--tts_repetition_window").help("TTS repetition window size in generated tokens").def(16); + auto& tts_greedy = Argparse::add("--tts_greedy").help("Use greedy decoding for TTS tokens").def(false); + auto& tts_tokens_out = Argparse::add("--tts_tokens_out").help("Output path for generated TTS token ids").def(""); + auto& tts_tokens_in = + Argparse::add("--tts_tokens_in").help("Input path for pre-generated TTS token ids (one per line or 
whitespace).").def(""); + auto& tts_wav_out = Argparse::add("--tts_wav_out") + .help("Output wav path. If set, run native C++ token2wav.") + .def(""); + auto& tts_token2wav_model_path = Argparse::add("--tts_token2wav_model_path") + .help("Path to token2wav .mllm (if empty, fallback to --model_path).") + .def(""); + auto& tts_token2wav_model_version = Argparse::add("--tts_token2wav_model_version") + .help("token2wav model version: v1/v2") + .def("v1"); + auto& tts_prompt_cache = Argparse::add("--tts_prompt_cache") + .help("Path to fixed prompt cache generated by export_prompt_cache.py") + .def(""); + auto& tts_token2wav_n_timesteps = Argparse::add("--tts_token2wav_n_timesteps") + .help("Flow diffusion steps for native token2wav") + .def(10); + auto& debug_progress = Argparse::add("--debug_progress").help("Print step-level debug progress.").def(false); + auto& debug_interval = + Argparse::add("--debug_interval").help("Token step interval for debug progress logs.").def(16); + + Argparse::parse(argc, argv); + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v2") { file_version = mllm::ModelFileVersion::kV2; } + + auto token2wav_model_path = tts_token2wav_model_path.get().empty() ? 
model_path.get() : tts_token2wav_model_path.get(); + mllm::ModelFileVersion token2wav_file_version = mllm::ModelFileVersion::kV1; + if (tts_token2wav_model_version.get() == "v2") { token2wav_file_version = mllm::ModelFileVersion::kV2; } + + auto run_native_token2wav = !tts_wav_out.get().empty(); + if (run_native_token2wav && tts_prompt_cache.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_prompt_cache is required when --tts_wav_out is set."); + } + + auto debug_t0 = std::chrono::steady_clock::now(); + auto debug_log = [&](const std::string& msg) { + if (!debug_progress.get()) { return; } + auto now = std::chrono::steady_clock::now(); + auto sec = std::chrono::duration_cast(now - debug_t0).count() / 1000.0; + fmt::print("[debug +{:.3f}s] {}\n", sec, msg); + }; + + if (!tts_tokens_in.get().empty()) { + if (!run_native_token2wav) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "--tts_wav_out is required when --tts_tokens_in is set."); + } + if (token2wav_model_path.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Missing token2wav model path (--tts_token2wav_model_path or --model_path)."); + } + + std::ifstream ifs(tts_tokens_in.get()); + if (!ifs.is_open()) { MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open token file: {}", tts_tokens_in.get()); } + std::vector token_ids; + for (std::string line; std::getline(ifs, line);) { + if (line.empty()) { continue; } + std::stringstream ss(line); + while (!ss.eof()) { + int64_t token = 0; + ss >> token; + if (!ss.fail()) { token_ids.push_back(token); } + } + } + if (token_ids.empty()) { MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "No token id found in {}", tts_tokens_in.get()); } + + fmt::print("Loaded {} TTS token IDs from {}\n", token_ids.size(), tts_tokens_in.get()); + debug_log("Loading token2wav model and prompt cache..."); + auto token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = 
mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + mllm::shutdownContext(); + return 0; + } + + if (model_path.get().empty() || tokenizer_path.get().empty() || config_path.get().empty()) { + Argparse::printHelp(); + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "Missing required arguments: --model_path, --tokenizer_path, --config_path"); + } + + auto cfg = mllm::models::minicpm_o45::MiniCPMO45Config(config_path.get()); + + debug_log("Loading tokenizer and model modules..."); + auto tokenizer = mllm::models::minicpm_o45::MiniCPMO45Tokenizer(tokenizer_path.get(), cfg.vision_patch_size, cfg.audio_pool_step); + auto model = mllm::models::minicpm_o45::MiniCPMO45ForCausalLM(cfg); + + debug_log("Loading model parameters..."); + auto param = mllm::load(model_path.get(), file_version); + model.llm_.load(param); + model.vpm_.load(param); + model.resampler_.load(param); + model.apm_.load(param); + model.audio_projection_layer_.load(param); + if (generate_tts_tokens.get()) { model.tts_.loadFromParameter(param); } + debug_log("Model parameters loaded."); + + mllm::models::minicpm_o45::MiniCPMO45Message message; + message.prompt = prompt.get(); + message.img_file_path = image_path.get(); + message.audio_file_path = audio_path.get(); + message.ref_audio_file_path = ref_audio_path.get(); + message.ref_audio_prompt_prefix = ref_audio_prompt_prefix.get(); + 
message.ref_audio_prompt_suffix = ref_audio_prompt_suffix.get(); + + auto inputs = tokenizer.convertMessage(message, generate_tts_tokens.get()); + debug_log("Tokenizer convertMessage finished."); + + fmt::print("\n{:*^60}\n", " MiniCPM-o-4_5 CLI "); + fmt::print("Prompt: {}\n", message.prompt); + if (!message.img_file_path.empty()) { fmt::print("Image : {}\n", message.img_file_path); } + if (!message.audio_file_path.empty()) { fmt::print("Audio : {}\n", message.audio_file_path); } + if (!message.ref_audio_file_path.empty()) { fmt::print("RefAudio : {}\n", message.ref_audio_file_path); } + + if (!generate_tts_tokens.get()) { + fmt::print("\nResponse: "); + for (auto& step : model.chat(inputs)) { + std::wcout << tokenizer.detokenize(step.cur_token_id) << std::flush; + } + fmt::print("\n"); + } else { + auto tts_eos_id = tokenizer.lookupTokenId(L"<|tts_eos|>"); + auto im_end_id = tokenizer.lookupTokenId(L"<|im_end|>"); + auto eot_id = tokenizer.lookupTokenId(L"<|endoftext|>"); + + std::vector stop_token_ids = { + tts_eos_id, + im_end_id, + eot_id, + cfg.eos_token_id, + }; + + debug_log("Start text generation for TTS conditioning..."); + auto text_out = model.generateTextWithHidden( + inputs, text_max_new_tokens.get(), stop_token_ids, false, 1.0f, 0, 0.0f, + [&](int32_t step, int64_t token_id) { + auto interval = std::max(debug_interval.get(), 1); + if (debug_progress.get() && (step == 1 || (step % interval) == 0)) { + debug_log(fmt::format("Text generation step {} (token_id={})", step, token_id)); + } + }); + debug_log(fmt::format("Text generation done, generated_tokens={}", text_out.generated_tokens.size())); + + fmt::print("\nGenerated text tokens: {}\n", text_out.generated_tokens.size()); + fmt::print("Text (for TTS conditioning): "); + + std::vector tts_text_tokens; + std::vector tts_hidden_states; + for (size_t i = 0; i < text_out.aligned_tokens.size() && i < text_out.aligned_hidden_states.size(); ++i) { + auto token_id = text_out.aligned_tokens[i]; + if 
(token_id == tts_eos_id || token_id == im_end_id || token_id == eot_id || token_id == cfg.eos_token_id) { break; } + tts_text_tokens.push_back(token_id); + tts_hidden_states.push_back(text_out.aligned_hidden_states[i]); + std::wcout << tokenizer.detokenize(token_id) << std::flush; + } + fmt::print("\n"); + + if (tts_text_tokens.empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, + "No text token available before <|tts_eos|>/<|im_end|>; cannot build TTS condition."); + } + + auto condition_embeds = model.tts_.makeConditionEmbeddings(tts_text_tokens, tts_hidden_states); + if (condition_embeds.isNil()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Failed to build TTS conditioning embeddings."); + } + debug_log(fmt::format("Built TTS condition embeddings from {} text tokens.", tts_text_tokens.size())); + + mllm::models::minicpm_o45::MiniCPMO45TTSGenerationConfig tts_cfg; + tts_cfg.max_new_tokens = tts_max_new_tokens.get(); + tts_cfg.min_new_tokens = tts_min_new_tokens.get(); + tts_cfg.force_no_stop = tts_force_no_stop.get(); + tts_cfg.do_sample = !tts_greedy.get(); + tts_cfg.temperature = {tts_temperature.get()}; + tts_cfg.top_k = tts_top_k.get(); + tts_cfg.top_p = tts_top_p.get(); + tts_cfg.repetition_penalty = tts_repetition_penalty.get(); + tts_cfg.repetition_penalty_window = tts_repetition_window.get(); + tts_cfg.debug_interval = std::max(debug_interval.get(), 1); + if (debug_progress.get()) { + tts_cfg.step_callback = [&](int32_t step, const std::vector& tokens, bool has_eos) { + auto first_token = tokens.empty() ? -1 : tokens[0]; + debug_log(fmt::format("TTS generation step {} (first_vq_token={}, has_eos={})", step, first_token, + has_eos ? 
"true" : "false")); + }; + } + + debug_log("Start TTS token generation..."); + auto tts_out = model.tts_.generate(condition_embeds, tts_cfg); + debug_log("TTS token generation finished."); + if (tts_out.new_ids.isNil()) { + fmt::print("Generated TTS tokens: 0\n"); + } else { + auto token_count = tts_out.new_ids.shape()[1]; + fmt::print("Generated TTS tokens: {} (finished={})\n", token_count, tts_out.finished ? "true" : "false"); + + std::vector token_ids; + token_ids.reserve(token_count); + auto tts_ids = tts_out.new_ids.contiguous(); + const auto* tts_ids_ptr = tts_ids.ptr(); + auto num_vq = tts_ids.shape()[2]; + for (int32_t i = 0; i < token_count; ++i) { token_ids.push_back(tts_ids_ptr[static_cast(i) * num_vq]); } + + fmt::print("TTS token IDs:\n"); + for (size_t i = 0; i < token_ids.size(); ++i) { + fmt::print("{}{}", token_ids[i], (i + 1 == token_ids.size() ? "\n" : " ")); + } + + if (!tts_tokens_out.get().empty()) { + std::ofstream ofs(tts_tokens_out.get()); + if (!ofs.is_open()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kIOError, "Failed to open output file: {}", tts_tokens_out.get()); + } + for (auto id : token_ids) { ofs << std::to_string(id) << '\n'; } + fmt::print("Saved TTS token ids to {}\n", tts_tokens_out.get()); + debug_log(fmt::format("Saved token ids to {}", tts_tokens_out.get())); + } + + if (!tts_wav_out.get().empty()) { + debug_log("Loading token2wav model and prompt cache..."); + auto token2wav_param = mllm::load(token2wav_model_path, token2wav_file_version); + auto prompt_cache = mllm::models::minicpm_o45::loadMiniCPMO45Token2WavPromptCache(tts_prompt_cache.get()); + + mllm::models::minicpm_o45::MiniCPMO45Token2WavModel token2wav("token2wav", {}); + token2wav.loadFromParameter(token2wav_param); + debug_log("Native token2wav model loaded."); + + debug_log("Running native flow + HiFT..."); + auto wav = token2wav.infer(token_ids, prompt_cache, std::max(1, tts_token2wav_n_timesteps.get())); + auto wav_i16 = wav * 32767.0f; + wenet::WavWriter 
wav_writer(wav_i16.ptr(), wav_i16.shape().back(), 1, 24000, 16); + wav_writer.Write(tts_wav_out.get()); + fmt::print("Saved TTS waveform to {}\n", tts_wav_out.get()); + debug_log("Native token2wav finished."); + } + } + } + + model.perfSummary(); + mllm::memoryReport(); + + ::mllm::shutdownContext(); + return 0; +} + +//}) diff --git a/examples/minicpm_o45/quant_cfg_gguf_q4_0_aggressive.json b/examples/minicpm_o45/quant_cfg_gguf_q4_0_aggressive.json new file mode 100644 index 000000000..35c587d48 --- /dev/null +++ b/examples/minicpm_o45/quant_cfg_gguf_q4_0_aggressive.json @@ -0,0 +1,359 @@ +{ + "^llm\\.model\\.embed_tokens\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 151748, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.k_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1024, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.v_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1024, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.o_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.gate_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 12288, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.up_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 12288, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.down_proj\\.weight$": { + "hints": { + "quant_method": "gguf", 
+ "gguf_type": "Q4_0", + "shape": [ + 4096, + 12288 + ], + "replace": true + } + }, + "^llm\\.lm_head\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 151748, + 4096 + ], + "replace": true + } + }, + + "^vpm\\.embeddings\\.position_embedding\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4900, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.self_attn\\.q_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1152, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.self_attn\\.k_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1152, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.self_attn\\.v_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1152, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.self_attn\\.out_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1152, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.mlp\\.fc1\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4304, + 1152 + ], + "replace": true + } + }, + "^vpm\\.encoder\\.layers\\.\\d+\\.mlp\\.fc2\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 1152, + 4304 + ], + "replace": true + } + }, + + "^resampler\\.kv_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4096, + 1152 + ], + "replace": true + } + }, + "^resampler\\.attn\\.out_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + + "^audio_projection_layer\\.linear1\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + 
"shape": [ + 4096, + 1024 + ], + "replace": true + } + }, + "^audio_projection_layer\\.linear2\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + + "^tts\\.projector_spk\\.linear1\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 4096 + ], + "replace": true + } + }, + "^tts\\.projector_spk\\.linear2\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^tts\\.projector_semantic\\.linear1\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 4096 + ], + "replace": true + } + }, + "^tts\\.projector_semantic\\.linear2\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + + "^tts\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.self_attn\\.k_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.self_attn\\.v_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.self_attn\\.o_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.mlp\\.gate_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 3072, + 768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.mlp\\.up_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 3072, + 
768 + ], + "replace": true + } + }, + "^tts\\.model\\.layers\\.\\d+\\.mlp\\.down_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 768, + 3072 + ], + "replace": true + } + }, + "^tts\\.emb_text\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_0", + "shape": [ + 152064, + 768 + ], + "replace": true + } + } +} diff --git a/examples/minicpm_o45/quant_cfg_gguf_q4_k.json b/examples/minicpm_o45/quant_cfg_gguf_q4_k.json new file mode 100644 index 000000000..59aecafb0 --- /dev/null +++ b/examples/minicpm_o45/quant_cfg_gguf_q4_k.json @@ -0,0 +1,101 @@ +{ + "^llm\\.model\\.embed_tokens\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 151748, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.q_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.k_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 1024, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.v_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q5_K", + "shape": [ + 1024, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.self_attn\\.o_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 4096, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.gate_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 12288, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.up_proj\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q4_K", + "shape": [ + 12288, + 4096 + ], + "replace": true + } + }, + "^llm\\.model\\.layers\\.\\d+\\.mlp\\.down_proj\\.weight$": { + "hints": { + 
"quant_method": "gguf", + "gguf_type": "Q5_K", + "shape": [ + 4096, + 12288 + ], + "replace": true + } + }, + "^llm\\.lm_head\\.weight$": { + "hints": { + "quant_method": "gguf", + "gguf_type": "Q6_K", + "shape": [ + 151748, + 4096 + ], + "replace": true + } + } +} diff --git a/examples/qwen2_5omni/CMakeLists.txt b/examples/qwen2_5omni/CMakeLists.txt new file mode 100644 index 000000000..2fdd3690f --- /dev/null +++ b/examples/qwen2_5omni/CMakeLists.txt @@ -0,0 +1,19 @@ +add_executable(mllm-qwen2_5-omni-text-runner text_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-text-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-text-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner image_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-runner audio_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-audio-out-runner audio_out_infer.cpp) +target_link_libraries(mllm-qwen2_5-omni-audio-out-runner PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-audio-out-runner PRIVATE ${MLLM_INCLUDE_DIR}) + +add_executable(mllm-qwen2_5-omni-image-runner-dbg image_infer_dbg.cpp) +target_link_libraries(mllm-qwen2_5-omni-image-runner-dbg PRIVATE MllmRT MllmCPUBackend) +target_include_directories(mllm-qwen2_5-omni-image-runner-dbg PRIVATE ${MLLM_INCLUDE_DIR}) diff --git a/examples/qwen2_5omni/audio_infer.cpp b/examples/qwen2_5omni/audio_infer.cpp new file mode 100644 index 000000000..d159c2b3e --- /dev/null +++ b/examples/qwen2_5omni/audio_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Audio CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string audio_path; + std::string prompt_text; + + fmt::print("Audio path (or 'exit/quit'): "); + //std::getline(std::cin, audio_path); + //if (audio_path == "exit" || audio_path == "quit") { return 0; } + audio_path = ""; + + fmt::print("Prompt text: "); + //std::getline(std::cin, prompt_text); + //if (prompt_text.empty()) { prompt_text = "Please describe the audio."; } + prompt_text = ""; + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertAudioMessage({.prompt 
= prompt_text, .audio_file_path = audio_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/audio_out_infer.cpp b/examples/qwen2_5omni/audio_out_infer.cpp new file mode 100644 index 000000000..9e46fcd0e --- /dev/null +++ b/examples/qwen2_5omni/audio_out_infer.cpp @@ -0,0 +1,93 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include +#include "wenet_audio/wav.h" + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + auto& spk_dict_path = Argparse::add("-s|--spk_dict_path").help("Speaker json path").required(true); + auto& prompt = Argparse::add("-p|--prompt").help("Prompt text").def(""); + auto& image_path = Argparse::add("-i|--image_path").help("Image path").def(""); + auto& audio_path = Argparse::add("-a|--audio_path").help("Audio path").def(""); + auto& speaker = Argparse::add("-sp|--speaker").help("Speaker name (default: first entry)").def(""); + auto& output_path = 
Argparse::add("-o|--output_path").help("Output wav path").def("./qwen2_5omni.wav"); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + if (!image_path.get().empty() && !audio_path.get().empty()) { + MLLM_ERROR_EXIT(mllm::ExitCode::kCoreError, "Only one of --image_path or --audio_path can be set."); + } + + auto qwen_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen_omni = mllm::models::qwen2_5omni::Qwen2_5OmniForConditionalGeneration(qwen_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen_omni.load(param); + qwen_omni.loadSpeakers(spk_dict_path.get()); + + std::string prompt_text = prompt.get(); + if (prompt_text.empty()) { + fmt::print("Prompt text: "); + std::getline(std::cin, prompt_text); + if (prompt_text.empty()) { prompt_text = "Please respond."; } + } + + mllm::models::ARGenerationOutputPast inputs; + if (!image_path.get().empty()) { + inputs = qwen_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path.get()}); + } else if (!audio_path.get().empty()) { + inputs = qwen_tokenizer.convertAudioMessage({.prompt = prompt_text, .audio_file_path = audio_path.get()}); + } else { + inputs = qwen_tokenizer.convertMessage({.prompt = prompt_text}); + } + + mllm::models::qwen2_5omni::Qwen2_5OmniAudioGenerationConfig gen_cfg; + auto output = qwen_omni.generateAudio(inputs, gen_cfg, speaker.get()); + + auto input_len = inputs["sequence"].shape()[1]; + auto total_len = output.sequences.shape()[1]; + fmt::print("\nResponse: "); + for (int i = input_len; i < total_len; ++i) { + 
std::wcout << qwen_tokenizer.detokenize(output.sequences.at({0, i})) << std::flush; + } + fmt::print("\n"); + + auto wav = output.wav * 32767.0f; + wenet::WavWriter wav_writer(wav.ptr(), wav.shape().back(), 1, 24000, 16); + wav_writer.Write(output_path.get()); + + fmt::print("Saved audio to {}\n", output_path.get()); + + qwen_omni.thinker_.perfSummary(); + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/config_qwen2_5omni_7B.json b/examples/qwen2_5omni/config_qwen2_5omni_7B.json new file mode 100644 index 000000000..8f27b94b9 --- /dev/null +++ b/examples/qwen2_5omni/config_qwen2_5omni_7B.json @@ -0,0 +1,495 @@ +{ + "architectures": [ + "Qwen2_5OmniModel" + ], + "enable_audio_output": true, + "enable_talker": true, + "model_type": "qwen2_5_omni", + "talker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/talker", + "architectures": [ + "Qwen2OmniTalkerForConditionalGeneration" + ], + "attention_dropout": 0.0, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "embedding_size": 3584, + "head_dim": 128, + "hidden_act": "silu", + "hidden_size": 896, + "image_token_index": 151655, + "init_std": 0.02, + "initializer_range": 0.02, + "intermediate_size": 18944, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "model_type": "qwen2_5_omni_talker", + "num_attention_heads": 12, + "num_hidden_layers": 24, + "num_key_value_heads": 4, + "position_id_per_seconds": 25, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "rope_theta": 1000000.0, + "seconds_per_chunk": 2, + "sliding_window": 32768, + "spatial_merge_size": 2, + "torch_dtype": "bfloat16", + "tts_codec_end_token_id": 8294, + "tts_codec_mask_token_id": 8296, + "tts_codec_pad_token_id": 8292, + "tts_codec_start_token_id": 8293, + "tts_text_end_token_id": 151861, + "tts_text_pad_token_id": 151859, + 
"tts_text_start_token_id": 151860, + "use_cache": true, + "use_sliding_window": false, + "video_token_index": 151656, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vocab_size": 8448 + }, + "thinker_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "Qwen2.5-Omni-7B/thinker", + "architectures": [ + "Qwen2OmniNaViTThinkerForConditionalGeneration" + ], + "audio_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "activation_dropout": 0.0, + "activation_function": "gelu", + "add_cross_attention": false, + "architectures": null, + "attention_dropout": 0.0, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "d_model": 1280, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.0, + "early_stopping": false, + "encoder_attention_heads": 20, + "encoder_ffn_dim": 5120, + "encoder_layerdrop": 0.0, + "encoder_layers": 32, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "init_std": 0.02, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "max_source_positions": 1500, + "min_length": 0, + "model_type": "qwen2_5_omni_audio_encoder", + "n_window": 100, + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_hidden_layers": 32, + "num_mel_bins": 128, + "num_return_sequences": 1, + "output_attentions": false, + "output_dim": 3584, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + 
"return_dict": true, + "return_dict_in_generate": false, + "scale_embedding": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "text_config": { + "model_type": "qwen2_5_omni_text", + "hidden_act": "silu", + "hidden_size": 3584, + "init_std": 0.02, + "intermediate_size": 18944, + "vocab_size": 152064, + "num_attention_heads": 28, + "num_hidden_layers": 28, + "num_key_value_heads": 4, + "max_position_embeddings": 32768, + "max_window_layers": 28, + "rms_norm_eps": 1e-06, + "rope_scaling": { + "mrope_section": [ + 16, + 24, + 24 + ], + "rope_type": "default", + "type": "default" + }, + "use_cache": true, + "rope_theta": 1000000.0, + "use_sliding_window": false, + "sliding_window": 32768, + "attention_dropout": 0.0, + "tie_word_embeddings": false + }, + "audio_end_token_id": 151648, + "audio_start_token_id": 151647, + "audio_token_index": 151646, + "bos_token_id": 151644, + "eos_token_id": 151645, + "ignore_index": -100, + "image_token_index": 151655, + "init_std": 0.02, + "model_type": "qwen2_5_omni_thinker", + "pad_token_id": 151643, + "position_id_per_seconds": 25, + "seconds_per_chunk": 2, + "torch_dtype": "bfloat16", + "user_token_id": 872, + "video_token_index": 151656, + "vision_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "depth": 32, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "embed_dim": 1280, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": 
null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "fullatt_block_indexes": [ + 7, + 15, + 23, + 31 + ], + "hidden_act": "silu", + "hidden_size": 1280, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "in_channels": 3, + "in_chans": 3, + "init_std": 0.02, + "intermediate_size": 3420, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "min_length": 0, + "model_type": "qwen2_5_omni_vision_encoder", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_heads": 16, + "num_return_sequences": 1, + "out_hidden_size": 3584, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "patch_size": 14, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "spatial_merge_size": 2, + "spatial_patch_size": 14, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "temporal_patch_size": 2, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "tokens_per_second": 25, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false, + "window_size": 112 + }, + "vision_end_token_id": 151653, + "vision_start_token_id": 151652, + "vision_token_id": 151654 + }, + "token2wav_config": { + "_attn_implementation_autoset": true, + "bigvgan_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + 
"cross_attention_hidden_size": null, + "decoder_start_token_id": null, + "diversity_penalty": 0.0, + "do_sample": false, + "early_stopping": false, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_bigvgan", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repetition_penalty": 1.0, + "resblock_dilation_sizes": [ + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ], + [ + 1, + 3, + 5 + ] + ], + "resblock_kernel_sizes": [ + 3, + 7, + 11 + ], + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": null, + "torchscript": false, + "typical_p": 1.0, + "upsample_initial_channel": 1536, + "upsample_kernel_sizes": [ + 11, + 7, + 4, + 4, + 4, + 4 + ], + "upsample_rates": [ + 5, + 3, + 2, + 2, + 2, + 2 + ], + "use_bfloat16": false, + "use_bias_at_final": false + }, + "dit_config": { + "_attn_implementation_autoset": true, + "_name_or_path": "", + "add_cross_attention": false, + "architectures": null, + "bad_words_ids": null, + "begin_suppress_tokens": null, + "bos_token_id": null, + "chunk_size_feed_forward": 0, + "cross_attention_hidden_size": 
null, + "decoder_start_token_id": null, + "depth": 22, + "dim": 1024, + "diversity_penalty": 0.0, + "do_sample": false, + "dropout": 0.1, + "early_stopping": false, + "emb_dim": 512, + "enc_attention_channels": 64, + "enc_channels": [ + 256, + 256, + 256, + 256, + 768 + ], + "enc_dilations": [ + 1, + 2, + 3, + 4, + 1 + ], + "enc_dim": 128, + "enc_emb_dim": 192, + "enc_global_context": true, + "enc_kernel_sizes": [ + 5, + 3, + 3, + 3, + 1 + ], + "enc_lin_neurons": 192, + "enc_res2net_scale": 2, + "enc_se_channels": 64, + "encoder_no_repeat_ngram_size": 0, + "eos_token_id": null, + "exponential_decay_length_penalty": null, + "ff_mult": 2, + "finetuning_task": null, + "forced_bos_token_id": null, + "forced_eos_token_id": null, + "head_dim": 64, + "heads": 16, + "id2label": { + "0": "LABEL_0", + "1": "LABEL_1" + }, + "is_decoder": false, + "is_encoder_decoder": false, + "label2id": { + "LABEL_0": 0, + "LABEL_1": 1 + }, + "length_penalty": 1.0, + "max_length": 20, + "mel_dim": 80, + "min_length": 0, + "model_type": "qwen2_5_omni_dit", + "no_repeat_ngram_size": 0, + "num_beam_groups": 1, + "num_beams": 1, + "num_embeds": 8193, + "num_return_sequences": 1, + "output_attentions": false, + "output_hidden_states": false, + "output_scores": false, + "pad_token_id": null, + "prefix": null, + "problem_type": null, + "pruned_heads": {}, + "remove_invalid_values": false, + "repeats": 2, + "repetition_penalty": 1.0, + "return_dict": true, + "return_dict_in_generate": false, + "sep_token_id": null, + "suppress_tokens": null, + "task_specific_params": null, + "temperature": 1.0, + "tf_legacy_loss": false, + "tie_encoder_decoder": false, + "tie_word_embeddings": true, + "tokenizer_class": null, + "top_k": 50, + "top_p": 1.0, + "torch_dtype": "float32", + "torchscript": false, + "typical_p": 1.0, + "use_bfloat16": false + }, + "model_type": "qwen2_5_omni_token2wav" + }, + "torch_dtype": "bfloat16", + "transformers_version": "4.50.0.dev0" +} diff --git 
a/examples/qwen2_5omni/image_infer.cpp b/examples/qwen2_5omni/image_infer.cpp new file mode 100644 index 000000000..473f7de60 --- /dev/null +++ b/examples/qwen2_5omni/image_infer.cpp @@ -0,0 +1,84 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == 
"exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/examples/qwen2_5omni/image_infer_dbg.cpp b/examples/qwen2_5omni/image_infer_dbg.cpp new file mode 100644 index 000000000..de21c8ec7 --- /dev/null +++ b/examples/qwen2_5omni/image_infer_dbg.cpp @@ -0,0 +1,91 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +//MLLM_MAIN({ +int main(int argc, char** argv) { + ::mllm::__setup_signal_handler(); + ::mllm::initializeContext(); + + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = + mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get(), qwen2_5omni_cfg.visual_spatial_merge_size); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Image CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string image_path; + std::string prompt_text; + + fmt::print("Image path (or 'exit/quit'): "); + image_path = "../../rsc/pics.jpg"; + //std::getline(std::cin, image_path); + if (image_path == "exit" || image_path == "quit") { return 0; } + + fmt::print("Prompt text: "); + prompt_text = "描述图片中物体"; + //std::getline(std::cin, prompt_text); + + try { + 
fmt::print("Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertVisionMessage({.prompt = prompt_text, .img_file_path = image_path}); + + fmt::print("\nResponse: "); + qwen2_5omni.streamGenerate(inputs, + { + {"do_sample", mllm::AnyValue(false)}, + {"max_length", mllm::AnyValue(qwen2_5omni_cfg.max_cache_length)}, + }, + [&](int64_t token_id) { + auto str = qwen2_5omni_tokenizer.detokenize(token_id); + std::wcout << str << std::flush; + }); + + fmt::print("\n{}\n", std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\nError: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); + + ::mllm::shutdownContext(); + return 0; +} diff --git a/examples/qwen2_5omni/text_infer.cpp b/examples/qwen2_5omni/text_infer.cpp new file mode 100644 index 000000000..299a0e07d --- /dev/null +++ b/examples/qwen2_5omni/text_infer.cpp @@ -0,0 +1,72 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include +#include +#include +#include +#include +#include + +using mllm::Argparse; + +MLLM_MAIN({ + mllm::Logger::level() = mllm::LogLevel::kError; + + auto& help = Argparse::add("-h|--help").help("Show help message"); + auto& model_path = Argparse::add("-m|--model_path").help("Model path").required(true); + auto& model_version = Argparse::add("-mv|--model_version").help("Model version").required(true); + auto& tokenizer_path = Argparse::add("-t|--tokenizer_path").help("Tokenizer directory").required(true); + auto& config_path = Argparse::add("-c|--config_path").help("Config path").required(true); + + Argparse::parse(argc, argv); + + mllm::ModelFileVersion file_version = mllm::ModelFileVersion::kV1; + if (model_version.get() == "v1") { + file_version = mllm::ModelFileVersion::kV1; + } else if (model_version.get() == "v2") { + file_version = mllm::ModelFileVersion::kV2; + } + + if (help.isSet()) { + Argparse::printHelp(); + mllm::shutdownContext(); + return 0; + } + + { + auto qwen2_5omni_cfg = mllm::models::qwen2_5omni::Qwen2_5OmniConfig(config_path.get()); + auto qwen2_5omni_tokenizer = mllm::models::qwen2_5omni::Qwen2_5OmniTokenizer(tokenizer_path.get()); + auto qwen2_5omni = mllm::models::qwen2_5omni::Qwen2_5OmniForCausalLM(qwen2_5omni_cfg); + + auto param = mllm::load(model_path.get(), file_version); + qwen2_5omni.thinker_.load(param); + + fmt::print("\n{:*^60}\n", " Qwen2.5-Omni Text CLI "); + fmt::print("Enter 'exit' or 'quit' to end the session\n\n"); + + std::string prompt_text; + + fmt::print("💬 Prompt text (or 'exit/quit'): "); + std::getline(std::cin, prompt_text); + + if (prompt_text == "exit" || prompt_text == "quit") { return 0; } + + try { + fmt::print("🔄 Processing...\n"); + auto inputs = qwen2_5omni_tokenizer.convertMessage({.prompt = prompt_text}); + + fmt::print("\n🤖 Response: "); + for (auto& step : qwen2_5omni.chat(inputs)) { + std::wcout << qwen2_5omni_tokenizer.detokenize(step.cur_token_id) << std::flush; + } + + fmt::print("\n{}\n", 
std::string(60, '-')); + } catch (const std::exception& e) { fmt::print("\n❌ Error: {}\n{}\n", e.what(), std::string(60, '-')); } + + qwen2_5omni.perfSummary(); + } + + mllm::print("\n"); + mllm::memoryReport(); +}) diff --git a/mllm-kernel/.gitignore b/mllm-kernel/.gitignore index df61d0fae..3eefc8fba 100644 --- a/mllm-kernel/.gitignore +++ b/mllm-kernel/.gitignore @@ -3,3 +3,4 @@ build-py/ .vscode/settings.json compile_commands.json .clangd +.pytest_cache/ diff --git a/mllm-kernel/README.md b/mllm-kernel/README.md index 14c8118f0..0a4580495 100644 --- a/mllm-kernel/README.md +++ b/mllm-kernel/README.md @@ -80,31 +80,30 @@ y = add_constant(x, 8) Use the helpers in `mllm_kernel.jit_utils`: -- `load_cpu_jit` -- `load_cuda_jit` +- `jit` - `make_cpp_args` -- `cache_once` -Example pattern: +Recommended pattern (CPU example): ```python import torch -from mllm_kernel.jit_utils import cache_once, load_cpu_jit, make_cpp_args - -@cache_once -def _jit_my_kernel_module(param: int): - args = make_cpp_args(param) - return load_cpu_jit( - "my_kernel", - *args, - cpp_files=["my_kernel.cpp"], - cpp_wrappers=[("my_kernel", f"my_namespace::my_kernel<{args}>")], - ) +import mllm_kernel + +@mllm_kernel.jit( + args=16, + device="cpu", + cpp_files=["my_kernel.cpp"], + cpp_wrappers=[("my_kernel", "my_namespace::my_kernel<16>")], + func_name="my_kernel", +) +def _my_kernel_16(compiled_module, dst: torch.Tensor, src: torch.Tensor) -> None: + compiled_module.my_kernel(dst, src) def my_kernel(src: torch.Tensor, param: int) -> torch.Tensor: + if param != 16: + raise ValueError("This demo only supports param=16.") dst = torch.empty_like(src) - module = _jit_my_kernel_module(param) - module.my_kernel(dst, src) + _my_kernel_16(dst, src) return dst ``` diff --git a/mllm-kernel/benchmarks/bench_create_kv_indices.py b/mllm-kernel/benchmarks/bench_create_kv_indices.py new file mode 100644 index 000000000..f570e66de --- /dev/null +++ b/mllm-kernel/benchmarks/bench_create_kv_indices.py @@ -0,0 +1,218 
@@ +"""Benchmark create_kv_indices vs naive torch gather using torch.profiler. + +Example: + python benchmarks/bench_create_kv_indices.py --batch-size 512 --max-reqs 2048 --max-ctx 4096 +""" + +from __future__ import annotations + +import argparse + +import torch +from torch.profiler import ProfilerActivity, profile + +from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices + + +def _make_batch( + *, + max_reqs: int, + max_ctx: int, + batch_size: int, + use_start_offsets: bool, + device: torch.device, + seed: int, +): + g_cuda = torch.Generator(device=device).manual_seed(seed) + g_cpu = torch.Generator(device="cpu").manual_seed(seed) + + req_to_token = torch.arange( + max_reqs * max_ctx, dtype=torch.int32, device=device + ).reshape(max_reqs, max_ctx) + + assert batch_size <= max_reqs + req_pool_indices = torch.randperm(max_reqs, generator=g_cuda, device=device)[ + :batch_size + ].to(torch.int32) + + page_kernel_lens_list = [] + kv_start_idx_list = [] + for _ in range(batch_size): + L = int(torch.randint(1, max_ctx, (1,), generator=g_cpu).item()) + if use_start_offsets: + start_max = max_ctx - L + start = int(torch.randint(0, max(start_max, 1), (1,), generator=g_cpu).item()) + else: + start = 0 + page_kernel_lens_list.append(L) + kv_start_idx_list.append(start) + + page_kernel_lens = torch.tensor( + page_kernel_lens_list, dtype=torch.int32, device=device + ) + kv_start_idx = torch.tensor(kv_start_idx_list, dtype=torch.int32, device=device) + + kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device) + kv_indptr[0] = 0 + kv_indptr[1:] = torch.cumsum(page_kernel_lens, dim=0) + + kv_indices = torch.empty( + int(kv_indptr[-1].item()), dtype=torch.int32, device=device + ) + + return ( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + +def _profile( + name: str, fn, *, warmup: int, iters: int, row_limit: int, trace_path: str | None +): + for _ in range(warmup): + fn() + 
torch.cuda.synchronize() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + + events = prof.key_averages() + time_attr = ( + "self_cuda_time_total" + if events and hasattr(events[0], "self_cuda_time_total") + else "self_device_time_total" + ) + sort_key = ( + "self_cuda_time_total" + if time_attr == "self_cuda_time_total" + else "self_device_time_total" + ) + total_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events) + avg_us = total_us / max(iters, 1) + + print(f"\n=== {name} ===") + print( + prof.key_averages().table( + sort_by=sort_key, + row_limit=row_limit, + ) + ) + print(f"{name} total self device time: {total_us:.2f} us") + print(f"{name} avg self device time/iter: {avg_us:.2f} us") + + if trace_path: + prof.export_chrome_trace(trace_path) + print(f"{name} trace exported: {trace_path}") + + return avg_us + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark create_kv_indices vs naive torch gather", + ) + parser.add_argument("--batch-size", type=int, default=512) + parser.add_argument("--max-reqs", type=int, default=2048) + parser.add_argument("--max-ctx", type=int, default=4096) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--row-limit", type=int, default=20) + parser.add_argument("--export-trace-dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + parser.add_argument( + "--use-start-offsets", + action="store_true", + help="Enable non-zero kv_start_idx to emulate sliding-window decode", + ) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + torch.manual_seed(args.seed) + device = torch.device("cuda") + + ( + req_to_token, + req_pool_indices, + 
page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) = _make_batch( + max_reqs=args.max_reqs, + max_ctx=args.max_ctx, + batch_size=args.batch_size, + use_start_offsets=args.use_start_offsets, + device=device, + seed=args.seed, + ) + + print("=== create_kv_indices profiler benchmark ===") + print( + f"batch_size={args.batch_size}, max_reqs={args.max_reqs}, max_ctx={args.max_ctx}, " + f"use_start_offsets={args.use_start_offsets}" + ) + print(f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}") + + trace_dir = args.export_trace_dir.strip() + kernel_trace = f"{trace_dir}/create_kv_indices_trace.json" if trace_dir else None + torch_trace = f"{trace_dir}/torch_gather_trace.json" if trace_dir else None + + def _run_kernel_once(): + create_kv_indices( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + def _run_torch_once(): + # Torch reference implementation on device: gather per-sequence ranges + # from req_to_token into a flat buffer. 
+ out = [] + for i in range(args.batch_size): + req = req_pool_indices[i].item() + start = kv_start_idx[i].item() if args.use_start_offsets else 0 + L = page_kernel_lens[i].item() + row = req_to_token[req, start : start + L] + out.append(row) + torch.cat(out, out=kv_indices) + + kernel_avg_us = _profile( + "create_kv_indices", + _run_kernel_once, + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=kernel_trace, + ) + + torch_avg_us = _profile( + "torch_reference", + _run_torch_once, + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + + speedup = torch_avg_us / max(kernel_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/mllm-kernel/benchmarks/bench_store_cache.py b/mllm-kernel/benchmarks/bench_store_cache.py new file mode 100644 index 000000000..b96fa608b --- /dev/null +++ b/mllm-kernel/benchmarks/bench_store_cache.py @@ -0,0 +1,164 @@ +"""Benchmark store_cache vs torch index with torch.profiler. 
+ +Example: +python benchmarks/bench_store_cache.py --warmup 20 --iters 200 --batch-size 512 --num-slots 8192 +""" + +import argparse + +import torch +from torch.profiler import ProfilerActivity, profile + +from mllm_kernel.cuda.jit import can_use_store_cache, store_cache + + +def _run_store_cache_once( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +): + store_cache(k, v, k_cache, v_cache, indices) + + +def _run_torch_index_once( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, +): + k_cache[indices] = k + v_cache[indices] = v + + +def _profile_path( + name: str, + fn, + *, + warmup: int, + iters: int, + row_limit: int, + trace_path: str | None, +): + for _ in range(warmup): + fn() + torch.cuda.synchronize() + + with profile( + activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], + record_shapes=False, + profile_memory=False, + with_stack=False, + ) as prof: + for _ in range(iters): + fn() + torch.cuda.synchronize() + + events = prof.key_averages() + # torch profiler times are in microseconds. + # PyTorch versions vary between *cuda* and *device* naming. 
+ time_attr = ( + "self_cuda_time_total" + if events and hasattr(events[0], "self_cuda_time_total") + else "self_device_time_total" + ) + sort_key = ( + "self_cuda_time_total" + if time_attr == "self_cuda_time_total" + else "self_device_time_total" + ) + total_self_device_us = sum(float(getattr(evt, time_attr, 0.0)) for evt in events) + avg_self_device_us = total_self_device_us / max(iters, 1) + + print(f"\n=== {name} ===") + print( + prof.key_averages().table( + sort_by=sort_key, + row_limit=row_limit, + ) + ) + print(f"{name} total self device time: {total_self_device_us:.2f} us") + print(f"{name} avg self device time/iter: {avg_self_device_us:.2f} us") + + if trace_path: + prof.export_chrome_trace(trace_path) + print(f"{name} trace exported: {trace_path}") + + return avg_self_device_us + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Benchmark store_cache vs torch index using torch.profiler" + ) + parser.add_argument("--batch-size", type=int, default=1024) + parser.add_argument("--num-slots", type=int, default=16384) + parser.add_argument("--head-num", type=int, default=8) + parser.add_argument("--head-dim", type=int, default=128) + parser.add_argument( + "--dtype", + type=str, + default="float16", + choices=["float16", "bfloat16", "float32"], + ) + parser.add_argument("--warmup", type=int, default=50) + parser.add_argument("--iters", type=int, default=200) + parser.add_argument("--row-limit", type=int, default=20) + parser.add_argument("--export-trace-dir", type=str, default="") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required for this benchmark") + + torch.manual_seed(args.seed) + device = torch.device("cuda") + dtype = getattr(torch, args.dtype) + + row_dim = args.head_num * args.head_dim + row_bytes = row_dim * torch.tensor([], dtype=dtype).element_size() + if not can_use_store_cache(row_bytes): + raise 
RuntimeError(f"store_cache is unavailable for row_bytes={row_bytes}") + + k = torch.randn(args.batch_size, row_dim, device=device, dtype=dtype) + v = torch.randn(args.batch_size, row_dim, device=device, dtype=dtype) + # Use unique indices to avoid write conflicts. + indices = torch.randperm(args.num_slots, device=device)[: args.batch_size].to( + torch.int64 + ) + k_cache = torch.zeros(args.num_slots, row_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + print("=== store_cache profiler benchmark ===") + print( + f"shape: batch={args.batch_size}, row_dim={row_dim}, slots={args.num_slots}, dtype={dtype}" + ) + print(f"warmup={args.warmup}, iters={args.iters}, row_limit={args.row_limit}") + + trace_dir = args.export_trace_dir.strip() + store_trace = f"{trace_dir}/store_cache_trace.json" if trace_dir else None + torch_trace = f"{trace_dir}/torch_index_trace.json" if trace_dir else None + + store_avg_us = _profile_path( + "store_cache", + lambda: _run_store_cache_once(k, v, k_cache, v_cache, indices), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=store_trace, + ) + torch_avg_us = _profile_path( + "torch_index", + lambda: _run_torch_index_once(k, v, k_cache, v_cache, indices), + warmup=args.warmup, + iters=args.iters, + row_limit=args.row_limit, + trace_path=torch_trace, + ) + speedup = torch_avg_us / max(store_avg_us, 1e-12) + print(f"\nSpeedup: {speedup:.3f}x") + + +if __name__ == "__main__": + main() diff --git a/mllm-kernel/include/mllm_kernel/scalar_type.hpp b/mllm-kernel/include/mllm_kernel/scalar_type.hpp new file mode 100644 index 000000000..def41a12b --- /dev/null +++ b/mllm-kernel/include/mllm_kernel/scalar_type.hpp @@ -0,0 +1,260 @@ +#pragma once + +#include +#include +#ifndef __CUDACC__ +#include +#endif + +namespace host { + +// +// ScalarType can represent a wide range of floating point and integer types, +// in particular it can be used to represent sub-byte data types (something +// that 
torch.dtype currently does not support). +// +class ScalarType { + public: + enum NanRepr : uint8_t { + NAN_NONE = 0, // nans are not supported + NAN_IEEE_754 = 1, // nans are: exp all 1s, mantissa not all 0s + NAN_EXTD_RANGE_MAX_MIN = 2, // nans are: exp all 1s, mantissa all 1s + + NAN_REPR_ID_MAX + }; + + constexpr ScalarType(uint8_t exponent, uint8_t mantissa, bool signed_, int32_t bias, bool finite_values_only = false, + NanRepr nan_repr = NAN_IEEE_754) + : exponent(exponent), + mantissa(mantissa), + signed_(signed_), + bias(bias), + finite_values_only(finite_values_only), + nan_repr(nan_repr) {}; + + static constexpr ScalarType int_(uint8_t size_bits, int32_t bias = 0) { return ScalarType(0, size_bits - 1, true, bias); } + + static constexpr ScalarType uint(uint8_t size_bits, int32_t bias = 0) { return ScalarType(0, size_bits, false, bias); } + + // IEEE 754 compliant floating point type + static constexpr ScalarType float_IEEE754(uint8_t exponent, uint8_t mantissa) { + assert(mantissa > 0 && exponent > 0); + return ScalarType(exponent, mantissa, true, 0, false, NAN_IEEE_754); + } + + // IEEE 754 non-compliant floating point type + static constexpr ScalarType float_(uint8_t exponent, uint8_t mantissa, bool finite_values_only, NanRepr nan_repr) { + assert(nan_repr < NAN_REPR_ID_MAX); + assert(mantissa > 0 && exponent > 0); + assert(nan_repr != NAN_IEEE_754); + return ScalarType(exponent, mantissa, true, 0, finite_values_only, nan_repr); + } + + uint8_t const exponent; // size of the exponent field (0 for integer types) + uint8_t const mantissa; // size of the mantissa field (size of the integer + // excluding the sign bit for integer types) + bool const signed_; // flag if the type supports negative numbers (i.e. has a + // sign bit) + int32_t const bias; // stored values equal value + bias, + // used for quantized type + + // Extra Floating point info + bool const finite_values_only; // i.e. 
no +/-inf if true + NanRepr const nan_repr; // how NaNs are represented + // (not applicable for integer types) + + using Id = int64_t; + + private: + // Field size in id + template + static constexpr size_t member_id_field_width() { + using T = std::decay_t; + return std::is_same_v ? 1 : sizeof(T) * 8; + } + + template + static constexpr auto reduce_members_helper(Fn f, Init val, Member member, Rest... rest) { + auto new_val = f(val, member); + if constexpr (sizeof...(rest) > 0) { + return reduce_members_helper(f, new_val, rest...); + } else { + return new_val; + }; + } + + template + constexpr auto reduce_members(Fn f, Init init) const { + // Should be in constructor order for `from_id` + return reduce_members_helper(f, init, exponent, mantissa, signed_, bias, finite_values_only, nan_repr); + }; + + template + static constexpr auto reduce_member_types(Fn f, Init init) { + constexpr auto dummy_type = ScalarType(0, 0, false, 0, false, NAN_NONE); + return dummy_type.reduce_members(f, init); + }; + + static constexpr auto id_size_bits() { + return reduce_member_types([](int acc, auto member) -> int { return acc + member_id_field_width(); }, 0); + } + + public: + constexpr Id id() const { + static_assert(id_size_bits() <= sizeof(Id) * 8, "ScalarType id is too large to be stored"); + + auto or_and_advance = [](std::pair result, auto member) -> std::pair { + auto [id, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + return {id | (int64_t(member) & ((uint64_t(1) << bits) - 1)) << bit_offset, bit_offset + bits}; + }; + return reduce_members(or_and_advance, std::pair{}).first; + } + + static constexpr ScalarType from_id(Id id) { + auto extract_and_advance = [id](auto result, auto member) { + using T = decltype(member); + auto [tuple, bit_offset] = result; + auto constexpr bits = member_id_field_width(); + auto extracted_val = static_cast((int64_t(id) >> bit_offset) & ((uint64_t(1) << bits) - 1)); + auto new_tuple = std::tuple_cat(tuple, 
std::make_tuple(extracted_val)); + return std::pair{new_tuple, bit_offset + bits}; + }; + + auto [tuple_args, _] = reduce_member_types(extract_and_advance, std::pair, int>{}); + return std::apply([](auto... args) { return ScalarType(args...); }, tuple_args); + } + + constexpr int64_t size_bits() const { return mantissa + exponent + is_signed(); } + constexpr bool is_signed() const { return signed_; } + constexpr bool is_integer() const { return exponent == 0; } + constexpr bool is_floating_point() const { return exponent > 0; } + constexpr bool is_ieee_754() const { return is_floating_point() && finite_values_only == false && nan_repr == NAN_IEEE_754; } + constexpr bool has_nans() const { return is_floating_point() && nan_repr != NAN_NONE; } + constexpr bool has_infs() const { return is_floating_point() && finite_values_only == false; } + constexpr bool has_bias() const { return bias != 0; } + +#ifndef __CUDACC__ + private: + double _floating_point_max() const { + assert(mantissa <= 52 && exponent <= 11); + + uint64_t max_mantissa = (uint64_t(1) << mantissa) - 1; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN) { max_mantissa -= 1; } + + uint64_t max_exponent = (uint64_t(1) << exponent) - 2; + if (nan_repr == NAN_EXTD_RANGE_MAX_MIN || nan_repr == NAN_NONE) { + assert(exponent < 11); + max_exponent += 1; + } + + uint64_t exponent_bias = (uint64_t(1) << (exponent - 1)) - 1; + uint64_t exponent_bias_double = (uint64_t(1) << 10) - 1; // double e = 11 + + uint64_t max_exponent_double = max_exponent - exponent_bias + exponent_bias_double; + + uint64_t double_raw = (max_mantissa << (52 - mantissa)) | (max_exponent_double << 52); + + return *reinterpret_cast(&double_raw); + } + + constexpr std::variant _raw_max() const { + if (is_floating_point()) { + return {_floating_point_max()}; + } else { + assert(size_bits() < 64 || (size_bits() == 64 && is_signed())); + return {(int64_t(1) << mantissa) - 1}; + } + } + + constexpr std::variant _raw_min() const { + if (is_floating_point()) { 
+ assert(is_signed()); + constexpr uint64_t sign_bit_double = (uint64_t(1) << 63); + + double max = _floating_point_max(); + uint64_t max_raw = *reinterpret_cast(&max); + uint64_t min_raw = max_raw | sign_bit_double; + return {*reinterpret_cast(&min_raw)}; + } else { + assert(!is_signed() || size_bits() <= 64); + if (is_signed()) { + return {INT64_MIN >> (64 - size_bits())}; + } else { + return {int64_t(0)}; + } + } + } + + public: + constexpr std::variant max() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_max()); + } + + constexpr std::variant min() const { + return std::visit([this](auto x) -> std::variant { return {x - bias}; }, _raw_min()); + } +#endif // __CUDACC__ + + public: + std::string str() const { + if (is_floating_point()) { + auto ret = "float" + std::to_string(size_bits()) + "_e" + std::to_string(exponent) + "m" + std::to_string(mantissa); + if (!is_ieee_754()) { + if (finite_values_only) { ret += "f"; } + if (nan_repr != NAN_NONE) { ret += "n"; } + } + return ret; + } else { + auto ret = ((is_signed()) ? 
"int" : "uint") + std::to_string(size_bits()); + if (has_bias()) { ret += "b" + std::to_string(bias); } + return ret; + } + } + + constexpr bool operator==(ScalarType const& other) const { + return mantissa == other.mantissa && exponent == other.exponent && bias == other.bias && signed_ == other.signed_ + && finite_values_only == other.finite_values_only && nan_repr == other.nan_repr; + } +}; + +using ScalarTypeId = ScalarType::Id; + +// "rust style" names +static inline constexpr auto kS4 = ScalarType::int_(4); +static inline constexpr auto kU4 = ScalarType::uint(4); +static inline constexpr auto kU4B8 = ScalarType::uint(4, 8); +static inline constexpr auto kS8 = ScalarType::int_(8); +static inline constexpr auto kU8 = ScalarType::uint(8); +static inline constexpr auto kU8B128 = ScalarType::uint(8, 128); + +static inline constexpr auto kFE2M1f = ScalarType::float_(2, 1, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE3M2f = ScalarType::float_(3, 2, true, ScalarType::NAN_NONE); +static inline constexpr auto kFE4M3fn = ScalarType::float_(4, 3, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE8M0fnu = ScalarType(8, 0, false, 0, true, ScalarType::NAN_EXTD_RANGE_MAX_MIN); +static inline constexpr auto kFE5M2 = ScalarType::float_IEEE754(5, 2); +static inline constexpr auto kFE8M7 = ScalarType::float_IEEE754(8, 7); +static inline constexpr auto kFE5M10 = ScalarType::float_IEEE754(5, 10); + +// Fixed width style names +static inline constexpr auto kInt4 = kS4; +static inline constexpr auto kUint4 = kU4; +static inline constexpr auto kUint4b8 = kU4B8; +static inline constexpr auto kInt8 = kS8; +static inline constexpr auto kUint8 = kU8; +static inline constexpr auto kUint8b128 = kU8B128; + +static inline constexpr auto kFloat4_e2m1f = kFE2M1f; +static inline constexpr auto kFloat6_e3m2f = kFE3M2f; +static inline constexpr auto kFloat8_e4m3fn = kFE4M3fn; +static inline constexpr auto kFloat8_e5m2 = kFE5M2; +static inline constexpr 
auto kFloat16_e8m7 = kFE8M7; +static inline constexpr auto kFloat16_e5m10 = kFE5M10; + +// colloquial names +static inline constexpr auto kHalf = kFE5M10; +static inline constexpr auto kFloat16 = kHalf; +static inline constexpr auto kBFloat16 = kFE8M7; + +static inline constexpr auto kFloat16Id = kFloat16.id(); +} // namespace host diff --git a/mllm-kernel/mllm_kernel/__main__.py b/mllm-kernel/mllm_kernel/__main__.py index d4888b86c..e5f0779d6 100644 --- a/mllm-kernel/mllm_kernel/__main__.py +++ b/mllm-kernel/mllm_kernel/__main__.py @@ -388,7 +388,7 @@ def main() -> None: logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") parser = argparse.ArgumentParser( - prog="python -m mllm_kernel", + prog="mllm_kernel", description="mllm-kernel helper commands.", ) parser.add_argument( diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh new file mode 100644 index 000000000..0b9e4c888 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/create_kv_indices.cuh @@ -0,0 +1,282 @@ +// High-performance CUDA kernel to build FlashInfer KV index arrays from +// pymllm's ReqToTokenPool mapping table. +// +// This is the CUDA-C equivalent of the Triton kernel +// `_create_kv_indices_triton` previously defined in +// `pymllm/layers/attention/flashinfer_backend.py`. +// +// Motivation +// ---------- +// FlashInfer's paged KV attention API expects a *flat* buffer of KV indices +// (`kv_indices`) together with a prefix-sum pointer array (`kv_indptr`). +// +// * `kv_indices` is a 1-D int32 array that stores, for every token of every +// sequence in a batch, the corresponding *slot index* in the KV cache. +// * `kv_indptr` (length = batch_size + 1) stores prefix sums over the +// per-sequence token counts. 
For sequence `i` we have tokens in: +// +// kv_indices[kv_indptr[i] : kv_indptr[i + 1]] +// +// In pymllm, the mapping from (request_slot, position_in_sequence) to KV slot +// index is stored in a 2-D tensor `req_to_token` owned by `ReqToTokenPool`: +// +// req_to_token[req_slot, position] -> kv_index (int32) +// +// For each batch we also know: +// * which request slots we are serving: `req_pool_indices[bs]` +// * how many tokens to use from each sequence: `page_kernel_lens[bs]` +// * the starting position inside each sequence: `kv_start_idx[bs]` (optional, +// used for sliding-window / partial-context attention) +// +// This kernel converts that 2-D layout into the flat `(kv_indptr, kv_indices)` +// layout in a single, highly parallel CUDA pass: +// +// For each sequence i in the batch: +// - let req = req_pool_indices[i] +// - let len = page_kernel_lens[i] +// - let start = kv_start_idx[i] (or 0 if not provided) +// - let offset = kv_indptr[i] +// - for j in [0, len): +// kv_indices[offset + j] = req_to_token[req, start + j] +// +// Requirements / invariants +// ------------------------- +// * `req_to_token` is int32 (aligned with sglang). +// * All tensors must reside on the same CUDA device. +// * The kernel is designed for extremely high throughput: +// - a block is assigned per sequence (batch element), +// - threads cooperate within the block to copy the token range with +// coalesced loads/stores. +// * Shape and dtype checks are performed at runtime via mllm_kernel's +// TensorMatcher utilities, so misuse is caught with clear error messages. +// +// Integration +// ----------- +// The exported entry point is `CreateKvIndicesKernel::run(...)`. The Python +// wrapper in `mllm_kernel/cuda/jit/create_kv_indices.py` JIT-compiles this +// kernel and exposes a `create_kv_indices(...)` function which is then called +// by `pymllm.layers.attention.flashinfer_backend`. 
+ +#pragma once + +#include // TensorMatcher, SymbolicSize, SymbolicDevice, SymbolicDType +#include // div_ceil, RuntimeCheck, Panic +#include // LaunchKernel + +#include +#include + +#include + +namespace { + +// --------------------------------------------------------------------------- +// Parameter block passed to the CUDA kernel +// --------------------------------------------------------------------------- +// +// We keep this struct trivially-copyable so it can be passed via +// `__grid_constant__` if desired. Each field is carefully documented to make +// the data flow explicit. + +struct CreateKvIndicesParams { + // Pointer to ReqToTokenPool mapping table: + // req_to_token[req_slot, position] -> kv_index (int32) + // shape: [max_reqs, max_context_len] + const int32_t* __restrict__ req_to_token; + + // Request slots participating in this batch. + // shape: [batch_size] + const int32_t* __restrict__ req_pool_indices; + + // Number of tokens to copy for each sequence in the batch. + // shape: [batch_size] + const int32_t* __restrict__ page_kernel_lens; + + // Prefix sums over per-sequence token counts. + // kv_indptr[i] is the starting offset in kv_indices for sequence i. + // shape: [batch_size + 1] + const int32_t* __restrict__ kv_indptr; + + // Optional starting position inside each request's sequence. When nullptr, + // we assume start = 0 for all sequences. When non-null, shape is + // [batch_size]. + const int32_t* __restrict__ kv_start_idx; + + // Output flat KV index buffer (int32). Length must be at least + // kv_indptr[batch_size]. + int32_t* __restrict__ kv_indices; + + // Stride of the first dimension of req_to_token, i.e. the number of + // positions per request (max_context_len). + int32_t req_to_token_stride; + + // Number of sequences in the batch. + uint32_t batch_size; + + // Whether kv_start_idx is valid (1) or should be ignored (0). 
+ uint32_t has_kv_start; +}; + +// We use a fixed block size chosen to balance occupancy and per-sequence +// parallelism. Each block is mapped to a single sequence and threads within +// the block cooperate to copy its token range. +constexpr int kBlockSize = 256; + +// --------------------------------------------------------------------------- +// Core CUDA kernel +// --------------------------------------------------------------------------- +// +// Grid mapping: +// * blockIdx.x -> sequence index `i` in [0, batch_size) +// * threadIdx.x -> intra-sequence worker; threads stride over the token +// range [0, len) with step `blockDim.x`. +// +// This design has several advantages: +// * No inter-block synchronisation is required. +// * Memory accesses are fully coalesced because each thread block walks a +// contiguous segment of the `req_to_token` and `kv_indices` arrays. +// * It handles variable-length sequences naturally; sequences with more +// tokens simply iterate more in the inner loop. + +__global__ void create_kv_indices_kernel(const CreateKvIndicesParams params) { + const uint32_t seq_id = blockIdx.x; // which sequence in the batch + if (seq_id >= params.batch_size) { return; } + + // Resolve the request slot for this sequence. + const int32_t req_slot = params.req_pool_indices[seq_id]; + + // Compute the output range [out_offset, out_offset + len) in kv_indices. + const int32_t out_offset = params.kv_indptr[seq_id]; + const int32_t len = params.page_kernel_lens[seq_id]; + + // Compute the starting position inside the original sequence. + int32_t start = 0; + if (params.has_kv_start && params.kv_start_idx != nullptr) { start = params.kv_start_idx[seq_id]; } + + // Base pointers for this sequence. + const int32_t* __restrict__ row = params.req_to_token + static_cast(req_slot) * params.req_to_token_stride; + int32_t* __restrict__ out = params.kv_indices + out_offset; + + // Each thread in the block handles a strided subset of [0, len). 
+ for (int32_t t = threadIdx.x; t < len; t += blockDim.x) { + // Guard against out-of-bounds reads if (start + t) exceeds the + // configured context length. Under normal conditions upstream + // invariants guarantee `start + len <= req_to_token_stride`, but + // this check makes the kernel robust against misconfigured inputs + // and prevents rare segmentation faults observed during testing. + const int32_t pos = start + t; + if (pos < 0 || pos >= params.req_to_token_stride) { continue; } + + out[t] = row[pos]; + } +} + +// --------------------------------------------------------------------------- +// Host-side launcher used by the JIT wrapper +// --------------------------------------------------------------------------- +// +// `CreateKvIndicesKernel::run(...)` is the C++ entry point that will be bound +// to a TVM FFI function and called from Python via the JIT utility. It is +// responsible for: +// 1. Validating tensor shapes / dtypes / devices. +// 2. Extracting symbolic sizes and strides. +// 3. Building the parameter block. +// 4. Launching the CUDA kernel using mllm_kernel::host::LaunchKernel. + +struct CreateKvIndicesKernel { + static void run(tvm::ffi::TensorView req_to_token, tvm::ffi::TensorView req_pool_indices, + tvm::ffi::TensorView page_kernel_lens, tvm::ffi::TensorView kv_indptr, tvm::ffi::TensorView kv_start_idx, + tvm::ffi::TensorView kv_indices) { + using namespace mllm_kernel::host; + + // --------------------------------------------------------------------- + // 1. 
Validate input tensors + // --------------------------------------------------------------------- + // req_to_token: [max_reqs, max_context_len], int32, CUDA + SymbolicSize MaxReqs{"max_reqs"}; + SymbolicSize MaxCtx{"max_context_len"}; + SymbolicSize ReqStride{"req_stride"}; + SymbolicDType req_dtype; + SymbolicDevice device; + + (void)TensorMatcher({MaxReqs, MaxCtx}) + .with_strides({ReqStride, 1}) + .with_dtype(req_dtype) + .with_device(device) + .verify(req_to_token); + + // req_pool_indices: [B], int32, CUDA + SymbolicSize B{"batch_size"}; + SymbolicSize ReqPoolStride{"req_pool_stride"}; + (void)TensorMatcher({B}).with_strides({ReqPoolStride}).with_dtype().with_device(device).verify(req_pool_indices); + + // page_kernel_lens: [B], int32, same device + SymbolicSize PageStride{"page_stride"}; + (void)TensorMatcher({B}).with_strides({PageStride}).with_dtype().with_device(device).verify(page_kernel_lens); + + // kv_indptr: [Nind], int32, same device (we later require Nind >= B + 1) + SymbolicSize Nind{"indptr_len"}; + (void)TensorMatcher({Nind}).with_dtype().with_device(device).verify(kv_indptr); + + // kv_start_idx: either [B] or [0]; int32, same device + SymbolicSize StartLen{"start_len"}; + SymbolicSize StartStride{"start_stride"}; + (void)TensorMatcher({StartLen}).with_strides({StartStride}).with_dtype().with_device(device).verify(kv_start_idx); + + // kv_indices: [Nidx], int32, same device + SymbolicSize Nidx{"num_indices"}; + (void)TensorMatcher({Nidx}).with_dtype().with_device(device).verify(kv_indices); + + // Extract concrete sizes. + const int64_t batch_size = B.unwrap(); + const int64_t indptr_len = Nind.unwrap(); + const int64_t req_stride = ReqStride.unwrap(); + + // Basic consistency checks. 
+ RuntimeCheck(batch_size > 0, "batch_size must be positive, got ", batch_size); + RuntimeCheck(indptr_len >= batch_size + 1, "kv_indptr length (", indptr_len, ") must be at least batch_size+1 (", + batch_size + 1, ")"); + + // NOTE: We intentionally do NOT read kv_indptr[batch_size] on the host to + // validate that kv_indices is large enough. kv_indptr resides in device + // memory and dereferencing it from host code would be an illegal memory + // access (segfault). Callers are responsible for ensuring that + // kv_indices.numel() >= kv_indptr[batch_size]. + + // kv_start_idx is optional; when StartLen == 0 we treat it as absent. + RuntimeCheck(StartLen.unwrap() == 0 || StartLen.unwrap() == batch_size, + "kv_start_idx must have length 0 or batch_size; got ", StartLen.unwrap(), " vs batch_size=", batch_size); + + const bool has_kv_start = (StartLen.unwrap() == batch_size); + + // --------------------------------------------------------------------- + // 2. Build parameter block + // --------------------------------------------------------------------- + CreateKvIndicesParams params{ + .req_to_token = static_cast(req_to_token.data_ptr()), + .req_pool_indices = static_cast(req_pool_indices.data_ptr()), + .page_kernel_lens = static_cast(page_kernel_lens.data_ptr()), + .kv_indptr = static_cast(kv_indptr.data_ptr()), + .kv_start_idx = has_kv_start ? static_cast(kv_start_idx.data_ptr()) : nullptr, + .kv_indices = static_cast(kv_indices.data_ptr()), + .req_to_token_stride = static_cast(req_stride), + .batch_size = static_cast(batch_size), + .has_kv_start = has_kv_start ? 1u : 0u, + }; + + const DLDevice dl_device = device.unwrap(); + + // --------------------------------------------------------------------- + // 3. Launch the CUDA kernel + // --------------------------------------------------------------------- + // We launch one block per sequence so that each sequence can be processed + // independently with fully coalesced memory accesses. 
The per-thread + // inner loop runs over the token range [0, len) with stride = blockDim.x. + + const int grid_size = static_cast(batch_size); + + LaunchKernel(grid_size, kBlockSize, dl_device)(create_kv_indices_kernel, params); + } +}; + +} // namespace diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh new file mode 100644 index 000000000..4c2833c06 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gdn_decode.cuh @@ -0,0 +1,432 @@ +// Fused GDN (Gated Delta Net) decode kernel for linear attention. +// +// Performs a single-token recurrent update per request: +// g = -exp(A_log) * softplus(a + dt_bias) +// beta = sigmoid(b) +// q = L2norm(q) * scale +// k = L2norm(k) +// state *= exp(g) (decay) +// v_delta = v - state @ k (delta rule) +// v_delta *= beta (gated update) +// state += v_delta outer k (state update) +// output = state @ q (readout) +// +// Works on SM80+ (Ampere, Jetson Orin, Hopper, ...). +// Matches the algorithm of sglang's fused_sigmoid_gating_delta_rule_update. +// +// Grid : (NV, bs * HV) where NV = ceil(V / BV) +// Block: BLOCK_K threads (one thread per K-dimension element) +// +// Each thread owns BV state elements at its K position. +// Two cross-thread reductions (over K) compute delta and output dot products. 
+ +#pragma once + +#include +#include + +#include +#include + +#include +#include +#include + +#include + +namespace GDNDecodeKernel { + +// --------------------------------------------------------------------------- +// Constants +// --------------------------------------------------------------------------- + +inline constexpr int BV = 32; // V-dimension tile size + +// --------------------------------------------------------------------------- +// Warp-level reduction +// --------------------------------------------------------------------------- + +__device__ __forceinline__ float warp_reduce_sum(float val) { + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, offset); + } + return val; +} + +// --------------------------------------------------------------------------- +// Type conversion helpers +// --------------------------------------------------------------------------- + +template +__device__ __forceinline__ float to_float(T val); + +template <> +__device__ __forceinline__ float to_float<__half>(__half val) { + return __half2float(val); +} + +template <> +__device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ __forceinline__ float to_float(float val) { + return val; +} + +template +__device__ __forceinline__ T from_float(float val); + +template <> +__device__ __forceinline__ __half from_float<__half>(float val) { + return __float2half(val); +} + +template <> +__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ __forceinline__ float from_float(float val) { + return val; +} + +// --------------------------------------------------------------------------- +// Block-level scalar reduction (sum across all threads → broadcast result) +// --------------------------------------------------------------------------- + +// Reduces a 
scalar across all threads in the block. +// Returns the sum in ALL threads (via shared memory broadcast). +// smem must have at least (blockDim.x / 32) floats. +__device__ __forceinline__ float block_reduce_sum(float val, float* smem) { + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int num_warps = blockDim.x / 32; + + val = warp_reduce_sum(val); + if (lane_id == 0) smem[warp_id] = val; + __syncthreads(); + + // First warp reduces across warps + if (warp_id == 0) { + float v = (lane_id < num_warps) ? smem[lane_id] : 0.0f; + v = warp_reduce_sum(v); + if (lane_id == 0) smem[0] = v; + } + __syncthreads(); + return smem[0]; +} + +// --------------------------------------------------------------------------- +// Block-level vector reduction: BV independent sums across all K threads +// --------------------------------------------------------------------------- + +// Each thread contributes partial[0..BV-1]. After this call, the results +// are written to out[0..BV-1] and are valid in all threads. +// reduce_buf must have at least BV * num_warps floats. +// broadcast_buf must have at least BV floats. 
+__device__ __forceinline__ void block_reduce_bv( + float partial[BV], + float* reduce_buf, // [num_warps * BV] + float* broadcast_buf, // [BV] + float out[BV] +) { + const int warp_id = threadIdx.x / 32; + const int lane_id = threadIdx.x % 32; + const int num_warps = blockDim.x / 32; + + // Intra-warp reduction for each bv + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + float val = warp_reduce_sum(partial[bv]); + if (lane_id == 0) { + reduce_buf[warp_id * BV + bv] = val; + } + } + __syncthreads(); + + // Inter-warp reduction: threads 0..BV-1 each reduce one bv + if (threadIdx.x < BV) { + float sum = 0.0f; + #pragma unroll 8 + for (int w = 0; w < num_warps; w++) { + sum += reduce_buf[w * BV + threadIdx.x]; + } + broadcast_buf[threadIdx.x] = sum; + } + __syncthreads(); + + // Broadcast to all threads + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + out[bv] = broadcast_buf[bv]; + } +} + +// --------------------------------------------------------------------------- +// Main GDN decode kernel +// --------------------------------------------------------------------------- + +template +__global__ void gdn_decode_kernel( + const T* __restrict__ q_ptr, // [bs, H, K] + const T* __restrict__ k_ptr, // [bs, H, K] + const T* __restrict__ v_ptr, // [bs, HV, V] + const T* __restrict__ a_ptr, // [bs, HV] + const T* __restrict__ b_ptr, // [bs, HV] + const float* __restrict__ A_log_ptr, // [HV] + const float* __restrict__ dt_bias_ptr, // [HV] + float* __restrict__ state_pool, // [pool_size, HV, V, K] + const int64_t* __restrict__ cache_indices, // [bs] + T* __restrict__ output_ptr, // [bs, HV, V] + const int bs, + const int H, // num_k_heads + const int HV, // num_v_heads + const int K, // head_k_dim + const int V, // head_v_dim + const float scale // K^-0.5 +) { + // Block indices + const int bv_block = blockIdx.x; // V-tile index + const int batch_head = blockIdx.y; // batch * HV + const int i_n = batch_head / HV; // batch index + const int i_hv = batch_head % HV; 
// value head index + const int i_h = i_hv * H / HV; // key head index (GQA mapping) + const int k_idx = threadIdx.x; // K-dimension index + const int v_start = bv_block * BV; // V-dimension start + + if (i_n >= bs) return; + + // Shared memory layout (declared dynamically) + extern __shared__ float smem[]; + const int num_warps = BLOCK_K / 32; + float* sq = smem; // [BLOCK_K] + float* sk = smem + BLOCK_K; // [BLOCK_K] + float* sv_broadcast = smem + 2 * BLOCK_K; // [BV] + float* warp_buf = smem + 2 * BLOCK_K + BV; // [num_warps] + float* reduce_buf = smem + 2 * BLOCK_K + BV + num_warps; // [BV * num_warps] + + // ===== 1. Load gating parameters and compute decay + beta ===== + // All threads load the same scalars (cheap, avoids shared memory) + const float A_log_val = A_log_ptr[i_hv]; + const float dt_bias_val = dt_bias_ptr[i_hv]; + const float a_val = to_float(a_ptr[i_n * HV + i_hv]); + const float b_val = to_float(b_ptr[i_n * HV + i_hv]); + + const float x = a_val + dt_bias_val; + // softplus with numerical stability: softplus(x) = log(1+exp(x)), or x for x>20 + const float softplus_x = (x <= 20.0f) ? logf(1.0f + expf(x)) : x; + const float g = -expf(A_log_val) * softplus_x; + const float decay = expf(g); + const float beta = 1.0f / (1.0f + expf(-b_val)); + + // ===== 2. 
Load q, k and compute L2 norms ===== + float q_val = 0.0f, k_val = 0.0f; + if (k_idx < K) { + q_val = to_float(q_ptr[i_n * H * K + i_h * K + k_idx]); + k_val = to_float(k_ptr[i_n * H * K + i_h * K + k_idx]); + } + + // L2 norm: reduce q*q and k*k across block + float q_sq_sum = block_reduce_sum(q_val * q_val, warp_buf); + float k_sq_sum = block_reduce_sum(k_val * k_val, warp_buf); + + float q_norm = rsqrtf(q_sq_sum + 1e-6f); + float k_norm = rsqrtf(k_sq_sum + 1e-6f); + + // Store normalized q (scaled) and k in shared memory + if (k_idx < K) { + sq[k_idx] = q_val * q_norm * scale; + sk[k_idx] = k_val * k_norm; + } else { + sq[k_idx] = 0.0f; + sk[k_idx] = 0.0f; + } + __syncthreads(); + + // ===== 3. Load state elements for this thread ===== + const int64_t pool_idx = cache_indices[i_n]; + // state_pool layout: [pool_size, HV, V, K] + const int64_t state_base = pool_idx * HV * V * K + i_hv * V * K; + + float state[BV]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + const int v_idx = v_start + bv; + if (v_idx < V && k_idx < K) { + state[bv] = state_pool[state_base + (int64_t)v_idx * K + k_idx]; + } else { + state[bv] = 0.0f; + } + } + + // ===== 4. Decay: state *= exp(g) ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + state[bv] *= decay; + } + + // ===== 5. Delta: v_delta[bv] = v[bv] - sum_k(state[bv,k] * k_norm[k]) ===== + float partial_delta[BV]; + const float my_k = sk[k_idx]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + partial_delta[bv] = state[bv] * my_k; + } + + float delta[BV]; + block_reduce_bv(partial_delta, reduce_buf, sv_broadcast, delta); + + // Compute v_delta = (v - delta) * beta and broadcast to all threads. + // Threads 0..BV-1 each load one v element, compute v_delta, write to smem. + if (k_idx < BV) { + const int my_v_idx = v_start + k_idx; + float my_v = (my_v_idx < V) + ? 
to_float(v_ptr[i_n * HV * V + i_hv * V + my_v_idx]) + : 0.0f; + sv_broadcast[k_idx] = (my_v - delta[k_idx]) * beta; + } + __syncthreads(); + + float v_delta[BV]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + v_delta[bv] = sv_broadcast[bv]; + } + + // ===== 6. State update: state[bv,k] += v_delta[bv] * k_norm[k] ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + state[bv] += v_delta[bv] * my_k; + } + + // ===== 7. Output: o[bv] = sum_k(state[bv,k] * q_norm_scaled[k]) ===== + float partial_out[BV]; + const float my_q = sq[k_idx]; + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + partial_out[bv] = state[bv] * my_q; + } + + float out_vals[BV]; + block_reduce_bv(partial_out, reduce_buf, sv_broadcast, out_vals); + + // ===== 8. Store output ===== + // output layout: [bs, HV, V] + if (k_idx < BV) { + const int v_idx = v_start + k_idx; + if (v_idx < V) { + output_ptr[i_n * HV * V + i_hv * V + v_idx] = from_float(out_vals[k_idx]); + } + } + + // ===== 9. Store state back to pool ===== + #pragma unroll + for (int bv = 0; bv < BV; bv++) { + const int v_idx = v_start + bv; + if (v_idx < V && k_idx < K) { + state_pool[state_base + (int64_t)v_idx * K + k_idx] = state[bv]; + } + } +} + +// --------------------------------------------------------------------------- +// Launch wrapper (called via TVM FFI) +// --------------------------------------------------------------------------- + +void run( + tvm::ffi::TensorView q, // [bs, H, K] + tvm::ffi::TensorView k, // [bs, H, K] + tvm::ffi::TensorView v, // [bs, HV, V] + tvm::ffi::TensorView a, // [bs, HV] + tvm::ffi::TensorView b, // [bs, HV] + tvm::ffi::TensorView A_log, // [HV] + tvm::ffi::TensorView dt_bias, // [HV] + tvm::ffi::TensorView state_pool, // [pool_size, HV, V, K] + tvm::ffi::TensorView cache_indices, // [bs] + tvm::ffi::TensorView output // [bs, HV, V] +) { + using namespace mllm_kernel::host; + + // --- Extract dimensions --- + auto BS = SymbolicSize{"bs"}; + auto H_ = SymbolicSize{"H"}; + auto 
HV_ = SymbolicSize{"HV"}; + auto K_ = SymbolicSize{"K"}; + auto V_ = SymbolicSize{"V"}; + auto PS = SymbolicSize{"pool_size"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({BS, H_, K_}).with_dtype(dtype).with_device(device).verify(q); + (void)TensorMatcher({BS, H_, K_}).with_dtype(dtype).with_device(device).verify(k); + (void)TensorMatcher({BS, HV_, V_}).with_dtype(dtype).with_device(device).verify(v); + (void)TensorMatcher({BS, HV_}).with_dtype(dtype).with_device(device).verify(a); + (void)TensorMatcher({BS, HV_}).with_dtype(dtype).with_device(device).verify(b); + (void)TensorMatcher({HV_}).with_dtype().with_device(device).verify(A_log); + (void)TensorMatcher({HV_}).with_dtype().with_device(device).verify(dt_bias); + (void)TensorMatcher({PS, HV_, V_, K_}).with_dtype().with_device(device).verify(state_pool); + (void)TensorMatcher({BS}).with_device(device).verify(cache_indices); + (void)TensorMatcher({BS, HV_, V_}).with_dtype(dtype).with_device(device).verify(output); + + const int bs = static_cast(BS.unwrap()); + const int H = static_cast(H_.unwrap()); + const int HV = static_cast(HV_.unwrap()); + const int K = static_cast(K_.unwrap()); + const int V = static_cast(V_.unwrap()); + const float scale = 1.0f / sqrtf(static_cast(K)); + + // Block size = K (rounded up to warp multiple, max 1024) + int block_k = ((K + 31) / 32) * 32; + if (block_k > 1024) block_k = 1024; + const int num_warps = block_k / 32; + + // Grid + const int NV = (V + BV - 1) / BV; + dim3 grid(NV, bs * HV); + dim3 block(block_k); + + // Dynamic shared memory: sq[block_k] + sk[block_k] + sv[BV] + warp_buf[nw] + reduce[BV*nw] + const size_t smem_bytes = (2 * block_k + BV + num_warps + BV * num_warps) * sizeof(float); + + const DLDevice dl_device = device.unwrap(); + + // Typed launch helper + #define LAUNCH_GDN_DECODE(CType, BKVAL) \ + LaunchKernel(grid, block, dl_device, smem_bytes)( \ + gdn_decode_kernel, \ 
+ static_cast(q.data_ptr()), \ + static_cast(k.data_ptr()), \ + static_cast(v.data_ptr()), \ + static_cast(a.data_ptr()), \ + static_cast(b.data_ptr()), \ + static_cast(A_log.data_ptr()), \ + static_cast(dt_bias.data_ptr()), \ + static_cast(state_pool.data_ptr()), \ + static_cast(cache_indices.data_ptr()), \ + static_cast(output.data_ptr()), \ + bs, H, HV, K, V, scale \ + ) + + // Dispatch based on dtype and block size + if (dtype.is_type()) { + if (block_k == 64) { LAUNCH_GDN_DECODE(__nv_bfloat16, 64); } + else if (block_k == 128) { LAUNCH_GDN_DECODE(__nv_bfloat16, 128); } + else if (block_k == 256) { LAUNCH_GDN_DECODE(__nv_bfloat16, 256); } + else { LAUNCH_GDN_DECODE(__nv_bfloat16, 256); } + } else { + if (block_k == 64) { LAUNCH_GDN_DECODE(__half, 64); } + else if (block_k == 128) { LAUNCH_GDN_DECODE(__half, 128); } + else if (block_k == 256) { LAUNCH_GDN_DECODE(__half, 256); } + else { LAUNCH_GDN_DECODE(__half, 256); } + } + + #undef LAUNCH_GDN_DECODE +} + +} // namespace GDNDecodeKernel diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/awq_marlin_repack.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/awq_marlin_repack.cuh new file mode 100644 index 000000000..71ace4470 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/awq_marlin_repack.cuh @@ -0,0 +1,251 @@ +#pragma once + +#include + +#include + +#include "marlin.cuh" + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + return; +} +#else + +template +__global__ void awq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, int size_k, int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, (int)gridDim.x); + + auto start_k_tile = 
blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int tile_n_ints = tile_n_size / pack_factor; + + constexpr int stage_n_threads = tile_n_ints / 4; + constexpr int stage_k_threads = tile_k_size; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + int first_n_packed = first_n / pack_factor; + + int4* sh_ptr = sh + stage_size * pipe; + + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast( + &(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) + first_n_packed + (n_id * 4)]))); + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } + + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + int cur_n_packed = cur_n / pack_factor; + int cur_n_pos = cur_n % pack_factor; + + constexpr int sh_stride = tile_n_ints; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh + 
stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + // Undo interleaving + int cur_n_pos_unpacked; + if constexpr (num_bits == 4) { + constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } else { + constexpr int undo_pack[4] = {0, 2, 1, 3}; + cur_n_pos_unpacked = undo_pack[cur_n_pos]; + } + + uint32_t vals[8]; +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + + int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem]; + int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) + sh_stride * cur_elem]; + + vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask; + vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask; + } + + constexpr int tile_size_val = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size_val; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for 
(int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} +#endif + +} // namespace device::marlin + +// Host wrapper +void awq_marlin_repack( + tvm::ffi::TensorView out, tvm::ffi::TensorView b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { + using namespace host; + using namespace device::marlin; + + // Validate alignment + RuntimeCheck(size_k % tile_k_size == 0, "size_k = ", size_k, " is not divisible by tile_k_size = ", tile_k_size); + RuntimeCheck(size_n % tile_n_size == 0, "size_n = ", size_n, " is not divisible by tile_n_size = ", tile_n_size); + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. 
Got = ", num_bits); + + int const pack_factor = 32 / num_bits; + + // Validate tensors + SymbolicDevice cuda_device; + cuda_device.set_options(); + + TensorMatcher({size_k, size_n / pack_factor}).with_dtype().with_device(cuda_device).verify(b_q_weight); + + TensorMatcher({size_k / tile_size, size_n * tile_size / pack_factor}) + .with_dtype() + .with_device(cuda_device) + .verify(out); + + // Get device and stream + auto device = cuda_device.unwrap(); + auto stream = LaunchKernel::resolve_device(device); + + // Get pointers + auto* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + auto* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get device attributes + int blocks = 0; + cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, device.device_id); + + int max_shared_mem = 0; + cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, device.device_id); + RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + // Dispatch based on num_bits + if (num_bits == 4) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else if (num_bits == 8) { + cudaFuncSetAttribute( + awq_marlin_repack_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); + LaunchKernel(blocks, repack_threads, stream, max_shared_mem)( + awq_marlin_repack_kernel, + b_q_weight_ptr, + out_ptr, + static_cast(size_k), + static_cast(size_n)); + } else { + RuntimeCheck(false, "Unsupported repack config: num_bits = ", num_bits); + } +} diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/dequant.h b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/dequant.h new file mode 100644 index 000000000..d194cf3ec --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/dequant.h @@ -0,0 +1,504 @@ +/* +Fast 
Dequantization (Converting INT4/INT8/FP4/FP8 to FP16/BF16) + +The process of fast dequantization can be summarized as a combination +of bitwise operations and floating-point computations: + +weight =>(bit_op / bitwise operations)=> +f16_value =>(flop / floating-point computation)=> +dequantized_weight + +Since the dequantized weights typically require subtracting the zero point and +applying a scale factor, the floating-point computation step can be fused with +the zero-point subtraction and scaling operations. + +The following are the parts that need to be modified for the fused operation +of zero-point subtraction and scaling. + +## INT4 => FP16/BF16 or INT8 => FP16 + +The floating-point computation is `__hsub2`. + +If there are integer zero points: + + flop(bit_op(weight)) - flop(bit_op(zp)) + = sub(bit_op(weight), bias) - sub(bit_op(zp), bias) + = bit_op(weight) - bit_op(zp) + +so we don't need any additional modification. + +If there are float zero points: + + flop(bit_op(weight)) - fzp + = sub(bit_op(weight), bias) - fzp + = bit_op(weight) - (fzp + bias) + +where `fzp + bias` can be computed at weight loading. But this +may have accuracy issues, so we should not use it in most cases. + +If there are no zero points: + + scale(flop(bit_op(weight))) + = scale(sub(bit_op(weight), bias)) + = scale(bit_op(weight)) - scale(bias) + = fma(bit_op(weight), scale_factor, scale(bias)) + +where `scale(bias)` can be cached. But this may have accuracy issues, +so we should not use it in most cases. + + +## INT8 => BF16 + +INT8 => BF16 is a special case: it uses byte_perm instead of flop. +We cannot fuse byte_perm with scaling. + + +## FP4/FP8 => FP16/BF16 + + scale(flop(bit_op(weight))) + = scale(mul(bit_op(weight), multiplier)) + = mul(bit_op(weight), scale_factor * multiplier) + +where `scale_factor * multiplier` can be computed at weight loading. 
+ +*/ + +#include "marlin_dtypes.cuh" + +namespace device::marlin { + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 +// Lookup-table based 3-input logical operation; explicitly used for +// dequantization as the compiler does not seem to automatically recognize it in +// all cases. +template +__device__ inline int lop3(int a, int b, int c) { + int res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" : "=r"(res) : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} + +// Constructs destination register by taking bytes from 2 sources (based on +// mask) +template +__device__ inline uint32_t prmt(uint32_t a) { + uint32_t res; + asm volatile("prmt.b32 %0, %1, %2, %3;\n" : "=r"(res) : "r"(a), "n"(start_byte), "n"(mask)); + return res; +} + +template +__device__ inline void dequant(int q, scalar_t2* frag_b); + +// +// Efficiently dequantize 4bit values packed in an int32 value into a full +// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below, +// with some small changes: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385 +// +template<> +__device__ inline void dequant(int q, half2* frag_b) { + const int MASK = 0x000f000f; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64086408; + const int MUL = 0x2c002c00; + const int ADD = 0xd480d480; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = + __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + const int LO = 0x000f000f; + const int HI = 0x00f000f0; + const int EX = 0x64006400; + // Guarantee that the `(a & b) | c` operations are LOP3s. + // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX); + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX); + // clang-format on + // We want signed int4 outputs, hence we fuse the `-8` symmetric zero point + // directly into `SUB` and `ADD`. + const int SUB = 0x64006400; + const int MUL = 0x2c002c00; + const int ADD = 0xd400d400; + frag_b[0] = __hsub2(*reinterpret_cast(&lo), *reinterpret_cast(&SUB)); + frag_b[1] = + __hfma2(*reinterpret_cast(&hi), *reinterpret_cast(&MUL), *reinterpret_cast(&ADD)); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + static constexpr uint32_t MASK = 0x000f000f; + static constexpr uint32_t EX = 0x43004300; + + // Guarantee that the `(a & b) | c` operations are LOP3s. 
+ // clang-format off + int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + q >>= 4; + int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX); + // clang-format on + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43084308; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t SUB = 0x43004300; + + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&SUB)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&SUB)); +} + +// +// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or +// bf16 Reference: +// - FP16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85 +// - BF16: +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175 +// +template<> +__device__ inline void dequant(int q, half2* frag_b) { + static constexpr uint32_t mask_for_elt_01 = 0x5250; + static constexpr uint32_t mask_for_elt_23 = 0x5351; + static constexpr uint32_t start_byte_for_fp16 = 0x64646464; + + uint32_t lo = prmt(q); + uint32_t hi = prmt(q); + + frag_b[0] = *reinterpret_cast(&lo); + frag_b[1] = *reinterpret_cast(&hi); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64806480; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = 
__hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400; + frag_b[0] = __hsub2(frag_b[0], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); + frag_b[1] = __hsub2(frag_b[1], *reinterpret_cast(&I8s_TO_F16s_MAGIC_NUM)); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388736.f; + fp32_intermediates[1] -= 8388736.f; + fp32_intermediates[2] -= 8388736.f; + fp32_intermediates[3] -= 8388736.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + float fp32_intermediates[4]; + uint32_t* fp32_intermediates_casted = reinterpret_cast(fp32_intermediates); + + static constexpr uint32_t fp32_base = 0x4B000000; + fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650); + fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652); + fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651); + fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653); + + fp32_intermediates[0] -= 8388608.f; + fp32_intermediates[1] -= 
8388608.f; + fp32_intermediates[2] -= 8388608.f; + fp32_intermediates[3] -= 8388608.f; + + uint32_t* bf16_result_ptr = reinterpret_cast(frag_b); + bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0], fp32_intermediates_casted[1], 0x7632); + bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2], fp32_intermediates_casted[3], 0x7632); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and FP16 formats + constexpr int FP8_EXPONENT = 4, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights 
are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP8 (E4M3) and BF16 formats + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to bfloat162 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template<> +__device__ inline void dequant(int q, half2* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, FP16_EXPONENT = 5; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (FP16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + // 
Constants for FP4 (E2M1) and FP16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP4_EXPONENT; + constexpr int MASK = 0x70007000; + + // Extract and shift FP4 values to FP16 format + int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 4; + int Out2 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template<> +__device__ inline void dequant(int q, nv_bfloat162* frag_b) { + dequant(q, frag_b); + + // Constants for FP4 (E2M1) and BF16 formats + constexpr int FP4_EXPONENT = 2, BF16_EXPONENT = 8; + + // Construct and apply exponent bias + constexpr int BIAS_OFFSET = (1 << (BF16_EXPONENT - 1)) - (1 << (FP4_EXPONENT - 1)); + // Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent + // position + constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + + // Convert to half2 and apply bias + frag_b[1] = __hmul2(frag_b[1], bias_reg); + frag_b[0] = __hmul2(frag_b[0], bias_reg); +} + +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template<> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template<> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int 
Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +// New version with s_type_id parameter for marlin_moe_wna16_v2 +template +__device__ inline void dequant_fp8_scales(int q, scalar_t2* frag_b); + +template<> +__device__ inline void dequant_fp8_scales(int q, half2* frag_b) { + int Out1 = (q & 0xFF00FF00) >> 1; + ; + q <<= 8; + int Out2 = (q & 0xFF00FF00) >> 1; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +}; + +template<> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + constexpr int FP8_EXPONENT = 4, BF16_EXPONENT = 8; + constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT; + constexpr int MASK = 0x7F007F00; + + // Extract and shift FP8 values to BF16 format + int Out1 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + q <<= 8; + int Out2 = ((q & 0x80008000) >> 1) | ((q & MASK) >> RIGHT_SHIFT); + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +template<> +__device__ inline void dequant_fp8_scales(int q, nv_bfloat162* frag_b) { + // In this conversion, 2 ** -127 in FP8E8M0 would become 0 in BF16, + // but we assume that such a extreme value would not occur in real models. 
+ int Out1 = (q & 0xFF00FF00) >> 1; + q <<= 7; + int Out2 = q & 0x7F807F80; + + // Note: reverse indexing is intentional because weights are permuted + frag_b[1] = *reinterpret_cast(&Out1); + frag_b[0] = *reinterpret_cast(&Out2); +} + +#endif + +} // namespace device::marlin diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin.cuh new file mode 100644 index 000000000..02b3f5222 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin.cuh @@ -0,0 +1,1001 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include + +#include + +#include "kernel.h" +#include "marlin_template.h" + +namespace device::marlin { + +__global__ void MarlinDefault(MARLIN_KERNEL_PARAMS){}; + +using MarlinFuncPtr = void (*)(MARLIN_KERNEL_PARAMS); + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) {} + +#else + +// For a given "a" of size [M,K] performs a permutation of the K columns based +// on the given "perm" indices. 
+__global__ void permute_cols_kernel( + int4 const* __restrict__ a_int4_ptr, + int const* __restrict__ perm_int_ptr, + int4* __restrict__ out_int4_ptr, + int size_m, + int size_k, + int lda, + int block_rows) { + auto start_row = block_rows * blockIdx.x; + int finish_row = start_row + block_rows; + if (finish_row > size_m) { + finish_row = size_m; + } + int cur_block_rows = finish_row - start_row; + + int input_row_stride = lda * sizeof(half) / 16; + int output_row_stride = size_k * sizeof(half) / 16; + + auto permute_row = [&](int row) { + int iters = size_k / default_threads; + int rest = size_k % default_threads; + + int input_offset = row * input_row_stride; + int output_offset = row * output_row_stride; + + half const* a_row_half = reinterpret_cast(a_int4_ptr + input_offset); + half* out_half = reinterpret_cast(out_int4_ptr + output_offset); + + int base_k = 0; + + for (int i = 0; i < iters; i++) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + + base_k += default_threads; + } + + if (rest) { + if (threadIdx.x < rest) { + auto cur_k = base_k + threadIdx.x; + int src_pos = perm_int_ptr[cur_k]; + + out_half[cur_k] = a_row_half[src_pos]; + } + } + }; + + for (int i = 0; i < cur_block_rows; i++) { + int cur_row = start_row + i; + if (cur_row < size_m) { + permute_row(cur_row); + } + } +} + +typedef struct { + int thread_k; + int thread_n; + int num_threads; +} thread_config_t; + +thread_config_t small_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {128, 128, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +thread_config_t large_batch_thread_configs[] = { + // Ordered by priority + + // thread_k, thread_n, num_threads + {64, 256, 256}, + {64, 128, 128}, + {128, 64, 128}}; + +typedef struct { + int blocks_per_sm; + thread_config_t tb_cfg; +} exec_config_t; + +int get_scales_cache_size( + thread_config_t const& th_config, + int prob_m, + int prob_n, + 
int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full) { + bool cache_scales_chunk = has_act_order && !is_k_full; + + int tb_n = th_config.thread_n; + int tb_k = th_config.thread_k; + + // Get max scale groups per thread-block + int tb_groups; + if (group_size == -1) { + tb_groups = 1; + } else if (group_size == 0) { + tb_groups = div_ceil(tb_k, 32); // Worst case is 32 group size + } else { + tb_groups = div_ceil(tb_k, group_size); + } + + if (cache_scales_chunk) { + int load_groups = tb_groups * pipe_stages * 2; // Chunk size is 2x pipeline over dim K + load_groups = max(load_groups, 32); // We load at least 32 scale groups + return load_groups * tb_n * 2; + } else { + int tb_scales = tb_groups * tb_n * 2; + + return tb_scales * pipe_stages; + } +} + +int get_kernel_cache_size( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float) { + int pack_factor = 32 / num_bits; + + // Get B size + int tb_k = th_config.thread_k; + int tb_n = th_config.thread_n; + int tb_m = thread_m_blocks * 16; + int sh_a_size = pipe_stages * (tb_m * tb_k) * 2; + int sh_b_size = pipe_stages * (tb_k * tb_n / pack_factor) * 4; + int sh_red_size = tb_m * (tb_n + 8); + int sh_s_size = + get_scales_cache_size(th_config, prob_m, prob_n, prob_k, num_bits, group_size, has_act_order, is_k_full); + int sh_g_idx_size = has_act_order && !is_k_full ? 
pipe_stages * tb_k / 4 : 0; + int sh_zp_size = 0; + if (has_zp) { + if (is_zp_float) + sh_zp_size = sh_s_size; + else if (num_bits == 4) + sh_zp_size = sh_s_size / 4; + else if (num_bits == 8) + sh_zp_size = sh_s_size / 2; + } + + int total_size = max(sh_b_size, sh_red_size) + sh_a_size + sh_s_size + sh_zp_size + sh_g_idx_size; + + return total_size; +} + +bool is_valid_config( + thread_config_t const& th_config, + int thread_m_blocks, + int prob_m, + int prob_n, + int prob_k, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + int has_zp, + int is_zp_float, + int max_shared_mem) { + // Sanity + if (th_config.thread_k == -1 || th_config.thread_n == -1 || th_config.num_threads == -1) { + return false; + } + + // Verify K/N are divisible by thread K/N + if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) { + return false; + } + + // Verify min for thread K/N + if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) { + return false; + } + + // num_threads must be at least 128 (= 4 warps) + if (th_config.num_threads < 128) { + return false; + } + + // Check that pipeline fits into cache + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + return cache_size <= max_shared_mem; +} + +#define _GET_IF( \ + W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, M_BLOCK_SIZE_8, GROUP_BLOCKS, NUM_THREADS, IS_ZP_FLOAT) \ + else if ( \ + q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && thread_n_blocks == THREAD_N_BLOCKS && \ + thread_k_blocks == THREAD_K_BLOCKS && m_block_size_8 == M_BLOCK_SIZE_8 && group_blocks == GROUP_BLOCKS && \ + num_threads == NUM_THREADS && is_zp_float == IS_ZP_FLOAT) { \ + kernel = Marlin< \ + scalar_t, \ + W_TYPE.id(), \ + NUM_THREADS, \ + THREAD_M_BLOCKS, \ + THREAD_N_BLOCKS, \ + THREAD_K_BLOCKS, \ + M_BLOCK_SIZE_8, \ + pipe_stages, \ + 
GROUP_BLOCKS, \ + IS_ZP_FLOAT>; \ + } + +// COMMON: cases for (group_blocks in [-1, 2, 4, 8] and is_zp_float == false) +// this is the most common cases +// BIGGROUP: cases for big group size (group_blocks in [-1, 8]) +// FZP: cases for float-zero-point (is_zp_float = true) +// ACT: cases for act order case (group_blocks == 0) +// FP4: cases for nvfp4(e2m1) (group_blocks == 1) +#define COMMON_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define COMMON_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define 
COMMON_GET_IF(W_TYPE) \ + COMMON_GET_IF_M1(W_TYPE, 8, 8, 256) \ + COMMON_GET_IF_M1(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M1(W_TYPE, 4, 8, 128) \ + COMMON_GET_IF_M234(W_TYPE, 16, 4, 256) \ + COMMON_GET_IF_M234(W_TYPE, 8, 4, 128) \ + COMMON_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define BIGGROUP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS, false) + +#define BIGGROUP_GET_IF(W_TYPE) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + BIGGROUP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + BIGGROUP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + BIGGROUP_GET_IF_M234(W_TYPE, 4, 8, 128) + +#define FP4_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 1, NUM_THREADS, false) + +#define FP4_GET_IF(W_TYPE) \ + FP4_GET_IF_M1(W_TYPE, 8, 8, 
256) \ + FP4_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FP4_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FP4_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FP4_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define FZP_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS, true) + +#define FZP_GET_IF(W_TYPE) \ + FZP_GET_IF_M1(W_TYPE, 8, 8, 256) \ + FZP_GET_IF_M1(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M1(W_TYPE, 4, 8, 128) \ + FZP_GET_IF_M234(W_TYPE, 16, 4, 256) \ + FZP_GET_IF_M234(W_TYPE, 8, 4, 128) \ + FZP_GET_IF_M234(W_TYPE, 4, 8, 128) + +// We currently have 4-bit models only with group_blocks == 4 +#define ACT_GET_IF_M1(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF_M234(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \ + _GET_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) \ + _GET_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, 0, NUM_THREADS, false) + +#define ACT_GET_IF(W_TYPE) \ + ACT_GET_IF_M1(W_TYPE, 8, 8, 256) \ + ACT_GET_IF_M1(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M1(W_TYPE, 4, 8, 128) \ + ACT_GET_IF_M234(W_TYPE, 16, 4, 256) \ + ACT_GET_IF_M234(W_TYPE, 8, 4, 128) \ + ACT_GET_IF_M234(W_TYPE, 4, 8, 128) + +template +MarlinFuncPtr get_marlin_kernel( + const host::ScalarType q_type, + int thread_m_blocks, + int thread_n_blocks, + int thread_k_blocks, + bool 
m_block_size_8, + bool has_act_order, + bool has_zp, + int group_blocks, + int num_threads, + bool is_zp_float) { + int num_bits = q_type.size_bits(); + auto kernel = MarlinDefault; + if (false) { + } + + COMMON_GET_IF(host::kU4) + COMMON_GET_IF(host::kU4B8) + COMMON_GET_IF(host::kU8B128) + + FP4_GET_IF(host::kFE2M1f) + + BIGGROUP_GET_IF(host::kFE4M3fn) + + ACT_GET_IF(host::kU4B8) + ACT_GET_IF(host::kU8B128) + + if (std::is_same::value) { + if (false) { + } + FZP_GET_IF(host::kU4) + } + + return kernel; +} + +template +exec_config_t determine_exec_config( + const host::ScalarType& q_type, + int prob_m, + int prob_n, + int prob_k, + int thread_m_blocks, + bool m_block_size_8, + int num_bits, + int group_size, + bool has_act_order, + bool is_k_full, + bool has_zp, + bool is_zp_float, + int max_shared_mem, + int sms) { + exec_config_t exec_cfg = exec_config_t{1, thread_config_t{-1, -1, -1}}; + thread_config_t* thread_configs = thread_m_blocks > 1 ? large_batch_thread_configs : small_batch_thread_configs; + int thread_configs_size = thread_m_blocks > 1 ? sizeof(large_batch_thread_configs) / sizeof(thread_config_t) + : sizeof(small_batch_thread_configs) / sizeof(thread_config_t); + + for (int i = 0; i < thread_configs_size; i++) { + thread_config_t th_config = thread_configs[i]; + + if (!is_valid_config( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem)) { + continue; + } + + int cache_size = get_kernel_cache_size( + th_config, + thread_m_blocks, + prob_m, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float); + + int group_blocks = 0; + if (!has_act_order) { + group_blocks = group_size == -1 ? 
-1 : group_size / 16; + } + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + th_config.thread_n / 16, + th_config.thread_k / 16, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + th_config.num_threads, + is_zp_float); + + if (kernel == MarlinDefault) continue; + + // int m_tiles = div_ceil(prob_m, thread_m_blocks * 16); + // int n_tiles = prob_n / th_config.thread_n; + // int k_tiles = prob_k / th_config.thread_k; + + return {1, th_config}; + } + + return exec_cfg; +} + +template +void marlin_mm( + const void* A, + const void* B, + void* C, + void* C_tmp, + void* s, + void* s2, + void* zp, + void* g_idx, + void* perm, + void* a_tmp, + int prob_m, + int prob_n, + int prob_k, + int lda, + void* workspace, + host::ScalarType const& q_type, + bool has_act_order, + bool is_k_full, + bool has_zp, + int num_groups, + int group_size, + int dev, + cudaStream_t stream, + int thread_k_init, + int thread_n_init, + int sms, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + if (has_zp) { + host::RuntimeCheck( + q_type == host::kU4 || q_type == host::kU8, "q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str()); + } else { + host::RuntimeCheck( + q_type == host::kU4B8 || q_type == host::kU8B128 || q_type == host::kFE4M3fn || q_type == host::kFE2M1f, + "q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + q_type.str()); + } + + host::RuntimeCheck( + prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m, ", ", prob_n, ", ", prob_k, "]"); + + int group_blocks = 0; + if (has_act_order) { + if (is_k_full) { + host::RuntimeCheck(group_size != -1); + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } else { + host::RuntimeCheck(group_size == 0); + group_blocks = 0; + } + } else { + if (group_size == -1) { + group_blocks = -1; + } else { + group_blocks = group_size / 16; + host::RuntimeCheck( + prob_k % group_blocks == 0, "prob_k = ", prob_k, " is not divisible by group_blocks = ", group_blocks); + } + } + + int num_bits = q_type.size_bits(); + const int4* A_ptr = (const int4*)A; + const int4* B_ptr = (const int4*)B; + int4* C_ptr = (int4*)C; + int4* C_tmp_ptr = (int4*)C_tmp; + const int4* s_ptr = (const int4*)s; + const uint16_t* s2_ptr = (const uint16_t*)s2; + const int4* zp_ptr = (const int4*)zp; + const int* g_idx_ptr = (const int*)g_idx; + const int* perm_ptr = (const int*)perm; + int4* a_tmp_ptr = (int4*)a_tmp; + + int* locks = (int*)workspace; + + if (has_act_order) { + // Permute A columns + int block_rows = div_ceil(prob_m, sms); + host::LaunchKernel(sms, default_threads, stream)( + permute_cols_kernel, A_ptr, perm_ptr, a_tmp_ptr, prob_m, prob_k, lda, block_rows); + A_ptr = a_tmp_ptr; + lda = prob_k; + + // If we have a full K, then we can run the non-act-order version of Marlin + // (since the weight rows are reordered by increasing group ids, and by + // having a full K, we have full original groups) + if (is_k_full) has_act_order = false; + } + + int max_shared_mem = 0; + host::RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + host::RuntimeCheck(max_shared_mem > 0); + + int max_par = 16; + if (prob_n <= 4096) max_par = 16 * 8; + int max_shared_mem_new = max_shared_mem; 
+ int rest_m = prob_m; + int max_thread_m_blocks = 4; + while (rest_m) { + int par_count = rest_m / (max_thread_m_blocks * 16); + if (par_count > max_par) par_count = max_par; + int prob_m_split = par_count > 0 ? (par_count * (max_thread_m_blocks * 16)) : rest_m; + + int thread_k = thread_k_init; + int thread_n = thread_n_init; + + int thread_m_blocks = min(div_ceil(prob_m_split, 16), max_thread_m_blocks); + int m_block_size_8 = prob_m_split <= 8; + + // Set thread config + exec_config_t exec_cfg; + thread_config_t thread_tfg; + if (thread_k != -1 && thread_n != -1) { + thread_tfg = thread_config_t{thread_k, thread_n, default_threads}; + exec_cfg = exec_config_t{1, thread_tfg}; + host::RuntimeCheck(prob_n % thread_n == 0, "prob_n = ", prob_n, " is not divisible by thread_n = ", thread_n); + host::RuntimeCheck(prob_k % thread_k == 0, "prob_k = ", prob_k, " is not divisible by thread_k = ", thread_k); + } else { + // Auto config + exec_cfg = determine_exec_config( + q_type, + prob_m_split, + prob_n, + prob_k, + thread_m_blocks, + m_block_size_8, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem, + sms); + thread_tfg = exec_cfg.tb_cfg; + if (thread_tfg.thread_k == -1 && max_thread_m_blocks > 1) { + max_thread_m_blocks--; + continue; + } + } + + int num_threads = thread_tfg.num_threads; + thread_k = thread_tfg.thread_k; + thread_n = thread_tfg.thread_n; + int blocks = sms * exec_cfg.blocks_per_sm; + if (exec_cfg.blocks_per_sm > 1) max_shared_mem_new = max_shared_mem / exec_cfg.blocks_per_sm - 1024; + + int thread_k_blocks = thread_k / 16; + int thread_n_blocks = thread_n / 16; + + host::RuntimeCheck( + is_valid_config( + thread_tfg, + thread_m_blocks, + prob_m_split, + prob_n, + prob_k, + num_bits, + group_size, + has_act_order, + is_k_full, + has_zp, + is_zp_float, + max_shared_mem_new), + "Invalid thread config: thread_m_blocks = ", + thread_m_blocks, + ", thread_k = ", + thread_tfg.thread_k, + ", thread_n = ", + 
thread_tfg.thread_n, + ", num_threads = ", + thread_tfg.num_threads, + " for MKN = [", + prob_m, + ", ", + prob_k, + ", ", + prob_n, + "] and num_bits = ", + num_bits, + ", prob_m_split = ", + prob_m_split, + ", group_size = ", + group_size, + ", has_act_order = ", + has_act_order, + ", is_k_full = ", + is_k_full, + ", has_zp = ", + has_zp, + ", is_zp_float = ", + is_zp_float, + ", max_shared_mem_new = ", + max_shared_mem_new); + + auto kernel = get_marlin_kernel( + q_type, + thread_m_blocks, + thread_n_blocks, + thread_k_blocks, + m_block_size_8, + has_act_order, + has_zp, + group_blocks, + num_threads, + is_zp_float); + + if (kernel == MarlinDefault) { + host::Panic( + "Unsupported shapes: MNK = [", + prob_m, + ", ", + prob_n, + ", ", + prob_k, + "]", + ", has_act_order = ", + has_act_order, + ", num_groups = ", + num_groups, + ", group_size = ", + group_size, + ", prob_m_split = ", + prob_m_split, + ", thread_m_blocks = ", + thread_m_blocks, + ", thread_n_blocks = ", + thread_n_blocks, + ", thread_k_blocks = ", + thread_k_blocks, + ", num_threads = ", + num_threads, + ", num_bits = ", + num_bits); + } + + host::RuntimeDeviceCheck( + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem_new)); + + bool part_use_atomic_add = use_atomic_add && div_ceil(prob_m_split, 64) * prob_n <= 2048; + + host::LaunchKernel(blocks, num_threads, stream, max_shared_mem_new)( + kernel, + A_ptr, + B_ptr, + C_ptr, + C_tmp_ptr, + s_ptr, + s2_ptr, + zp_ptr, + g_idx_ptr, + num_groups, + prob_m_split, + prob_n, + prob_k, + lda, + locks, + part_use_atomic_add, + use_fp32_reduce, + max_shared_mem_new); + + A_ptr += prob_m_split * (lda / 8); + C_ptr += prob_m_split * (prob_n / 8); + rest_m -= prob_m_split; + } +} + +#endif + +} // namespace device::marlin + +template +void gptq_marlin_gemm( + tvm::ffi::TensorView a, + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView b_scales, + tvm::ffi::TensorView global_scale, + tvm::ffi::TensorView b_zeros, + 
tvm::ffi::TensorView g_idx, + tvm::ffi::TensorView perm, + tvm::ffi::TensorView c, + tvm::ffi::TensorView c_tmp, + tvm::ffi::TensorView a_tmp, + tvm::ffi::TensorView workspace, + int64_t b_q_type_id, + bool is_k_full, + bool use_atomic_add, + bool use_fp32_reduce, + bool is_zp_float) { + using namespace host; + + ScalarType const b_q_type = ScalarType::from_id(b_q_type_id); + int pack_factor = 32 / b_q_type.size_bits(); + + // Bind symbolic sizes + auto M = SymbolicSize{"M"}; + auto K = SymbolicSize{"K"}; + auto N = SymbolicSize{"N"}; + auto device = SymbolicDevice{}; + device.set_options(); + + // Verify a: [M, K] + auto lda = SymbolicSize{"lda"}; + TensorMatcher({M, K}).with_strides({lda, 1}).with_dtype().with_device(device).verify(a); + + int64_t size_m = M.unwrap(); + int64_t size_k = K.unwrap(); + + // Verify b_q_weight: [K/tile_size, packed_N] + RuntimeCheck( + size_k % device::marlin::tile_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t expected_bqw_dim0 = size_k / device::marlin::tile_size; + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(expected_bqw_dim0); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device).verify(b_q_weight); + + RuntimeCheck( + b_q_weight.size(1) % device::marlin::tile_size == 0, + "b_q_weight.size(1) = ", + b_q_weight.size(1), + " is not divisible by tile_size = ", + device::marlin::tile_size); + int64_t actual_size_n = (b_q_weight.size(1) / device::marlin::tile_size) * pack_factor; + N.set_value(actual_size_n); + int64_t size_n = N.unwrap(); + + // Verify stride alignment + int64_t a_stride0 = a.stride(0); + RuntimeCheck(a_stride0 % 8 == 0, "a.stride(0) must be divisible by 8"); + + // Verify b_scales: [num_groups, N] + auto num_groups_sym = SymbolicSize{"num_groups"}; + TensorMatcher({num_groups_sym, N}).with_device(device).verify(b_scales); + int num_groups = 
static_cast(num_groups_sym.unwrap()); + + // Verify c: [M, N] + TensorMatcher({M, N}).with_dtype().with_device(device).verify(c); + + // Early return for zero-size M + if (size_m == 0) return; + + // Determine has_act_order from g_idx/perm sizes + int64_t g_idx_size = g_idx.size(0); + int64_t perm_size = perm.size(0); + bool has_act_order = g_idx_size > 0 && perm_size > 0; + + if (has_act_order) { + RuntimeCheck( + (g_idx_size == size_k && perm_size == size_k), + "Unexpected g_idx.size(0) = ", + g_idx_size, + " and perm.size(0) = ", + perm_size, + ", where size_k = ", + size_k); + } + + // Determine has_zp from b_zeros size + int64_t b_zeros_size = b_zeros.size(0); + bool has_zp = b_zeros_size > 0; + + if (has_zp) { + RuntimeCheck( + b_q_type == kU4 || b_q_type == kU8, "b_q_type must be u4 or u8 when has_zp = True. Got = ", b_q_type.str()); + } else { + RuntimeCheck( + b_q_type == kU4B8 || b_q_type == kU8B128 || b_q_type == kFE4M3fn || b_q_type == kFE2M1f, + "b_q_type must be uint4b8, uint8b128, float8_e4m3fn or float4_e2m1f when " + "has_zp = False. 
Got = ", + b_q_type.str()); + } + + if (has_zp && is_zp_float) { + RuntimeCheck( + std::is_same::value, "Computation type must be float16 (half) when using float zero points."); + } + + // Verify b_zeros shape + if (has_zp) { + RuntimeCheck(b_zeros.dim() == 2, "b_zeros rank = ", b_zeros.dim(), " is not 2"); + if (is_zp_float) { + RuntimeCheck(b_zeros.size(1) == size_n, "b_zeros dim 1 = ", b_zeros.size(1), " is not size_n = ", size_n); + RuntimeCheck( + num_groups == b_zeros.size(0), "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck(num_groups != -1, "num_groups must be != -1"); + } else { + RuntimeCheck( + b_zeros.size(0) == num_groups, "b_zeros dim 0 = ", b_zeros.size(0), " is not num_groups = ", num_groups); + RuntimeCheck( + b_zeros.size(1) == size_n / pack_factor, + "b_zeros dim 1 = ", + b_zeros.size(1), + " is not size_n / pack_factor = ", + size_n / pack_factor); + } + } + + // Verify global_scale + int64_t global_scale_size = global_scale.size(0); + if (global_scale_size > 0) { + RuntimeCheck(b_q_type == kFE2M1f, "global_scale can only be used for float4_e2m1f."); + } else { + RuntimeCheck(!(b_q_type == kFE2M1f), "the global_scale parameter must be passed for float4_e2m1f."); + } + + // Derive group_size + int group_size = -1; + if (has_act_order) { + if (is_k_full) { + RuntimeCheck(num_groups > 1, "For act_order, num_groups must be > 1"); + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = 0; + } + } else { + if (num_groups > 1) { + RuntimeCheck(size_k % num_groups == 0, "size_k = ", size_k, ", is not divisible by num_groups = ", num_groups); + group_size = static_cast(size_k / num_groups); + } else { + group_size = -1; + } + } + + // Verify workspace and get device info + RuntimeCheck( + size_n % device::marlin::min_thread_n == 0, + "size_n = ", + size_n, + ", is not divisible by 
min_thread_n = ", + device::marlin::min_thread_n); + + DLDevice dl_device = device.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + + int sms = -1; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev)); + + RuntimeCheck( + workspace.size(0) >= sms, "workspace.size(0) = ", workspace.size(0), " is below min_workspace_size = ", sms); + + // Hardcoded defaults (auto config) + int thread_k_init = -1; + int thread_n_init = -1; + + // Compute c_tmp and a_tmp pointers + // c_tmp and a_tmp are pre-allocated by caller + + device::marlin::marlin_mm( + a.data_ptr(), + b_q_weight.data_ptr(), + c.data_ptr(), + c_tmp.data_ptr(), + b_scales.data_ptr(), + global_scale.data_ptr(), + b_zeros.data_ptr(), + g_idx.data_ptr(), + perm.data_ptr(), + a_tmp.data_ptr(), + static_cast(size_m), + static_cast(size_n), + static_cast(size_k), + static_cast(a_stride0), + workspace.data_ptr(), + b_q_type, + has_act_order, + is_k_full, + has_zp, + num_groups, + group_size, + dev, + stream, + thread_k_init, + thread_n_init, + sms, + use_atomic_add, + use_fp32_reduce, + is_zp_float); +} diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin_repack.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin_repack.cuh new file mode 100644 index 000000000..b869260c1 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/gptq_marlin_repack.cuh @@ -0,0 +1,362 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ + +#pragma once + +#include + +#include + +#include "marlin.cuh" + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + return; +} +#else +template +__global__ void gptq_marlin_repack_kernel( + uint32_t const* __restrict__ b_q_weight_ptr, + uint32_t const* __restrict__ perm_ptr, + uint32_t* __restrict__ out_ptr, + int size_k, + int size_n) { + constexpr int pack_factor = 32 / num_bits; + + int k_tiles = size_k / tile_k_size; + int n_tiles = size_n / tile_n_size; + int block_k_tiles = div_ceil(k_tiles, gridDim.x); + + auto start_k_tile = blockIdx.x * block_k_tiles; + if (start_k_tile >= k_tiles) { + return; + } + + int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles); + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). 
+ cp_async_wait(); + __syncthreads(); + }; + + extern __shared__ int4 sh[]; + + constexpr int perm_size = tile_k_size / 4; + + int4* sh_perm_ptr = sh; + int4* sh_pipe_ptr = sh_perm_ptr; + if constexpr (has_perm) { + sh_pipe_ptr += perm_size; + } + + constexpr int tile_ints = tile_k_size / pack_factor; + + constexpr int stage_n_threads = tile_n_size / 4; + constexpr int stage_k_threads = has_perm ? tile_k_size : tile_ints; + constexpr int stage_size = stage_k_threads * stage_n_threads; + + auto load_perm_to_shared = [&](int k_tile_id) { + int first_k_int4 = (k_tile_id * tile_k_size) / 4; + + int4 const* perm_int4_ptr = reinterpret_cast(perm_ptr); + + if (threadIdx.x < perm_size) { + sh_perm_ptr[threadIdx.x] = perm_int4_ptr[first_k_int4 + threadIdx.x]; + } + __syncthreads(); + }; + + auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + cp_async_fence(); + return; + } + + int first_n = n_tile_id * tile_n_size; + + int4* sh_ptr = sh_pipe_ptr + stage_size * pipe; + + if constexpr (has_perm) { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + uint32_t const* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + int src_k = sh_perm_int_ptr[k_id]; + int src_k_packed = src_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[src_k_packed * size_n + first_n + (n_id * 4)]))); + } + + } else { + if (threadIdx.x < stage_size) { + auto k_id = threadIdx.x / stage_n_threads; + auto n_id = threadIdx.x % stage_n_threads; + + int first_k = k_tile_id * tile_k_size; + int first_k_packed = first_k / pack_factor; + + cp_async4( + &sh_ptr[k_id * stage_n_threads + n_id], + reinterpret_cast(&(b_q_weight_ptr[(first_k_packed + k_id) * size_n + first_n + (n_id * 4)]))); + } + } + + cp_async_fence(); + }; + + auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) { + if (n_tile_id >= n_tiles) { + return; + } 
+ + auto warp_id = threadIdx.x / 32; + auto th_id = threadIdx.x % 32; + + if (warp_id >= 4) { + return; + } + + int tc_col = th_id / 4; + int tc_row = (th_id % 4) * 2; + + constexpr int tc_offsets[4] = {0, 1, 8, 9}; + + int cur_n = warp_id * 16 + tc_col; + + constexpr int sh_stride = 64; + constexpr uint32_t mask = (1 << num_bits) - 1; + + int4* sh_stage_ptr = sh_pipe_ptr + stage_size * pipe; + uint32_t* sh_stage_int_ptr = reinterpret_cast(sh_stage_ptr); + + uint32_t* sh_perm_int_ptr = reinterpret_cast(sh_perm_ptr); + + uint32_t vals[8]; + + if constexpr (has_perm) { + for (int i = 0; i < 4; i++) { + int k_idx = tc_row + tc_offsets[i]; + + uint32_t src_k = sh_perm_int_ptr[k_idx]; + uint32_t src_k_pos = src_k % pack_factor; + + uint32_t b1_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n]; + uint32_t b1_cur_val = (b1_val >> (src_k_pos * num_bits)) & mask; + + uint32_t b2_val = sh_stage_int_ptr[k_idx * sh_stride + cur_n + 8]; + uint32_t b2_cur_val = (b2_val >> (src_k_pos * num_bits)) & mask; + + vals[i] = b1_cur_val; + vals[4 + i] = b2_cur_val; + } + + } else { + uint32_t b1_vals[tile_ints]; + uint32_t b2_vals[tile_ints]; + +#pragma unroll + for (int i = 0; i < tile_ints; i++) { + b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; + b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; + } + +#pragma unroll + for (int i = 0; i < 4; i++) { + int cur_elem = tc_row + tc_offsets[i]; + int cur_int = cur_elem / pack_factor; + int cur_pos = cur_elem % pack_factor; + + vals[i] = (b1_vals[cur_int] >> (cur_pos * num_bits)) & mask; + vals[4 + i] = (b2_vals[cur_int] >> (cur_pos * num_bits)) & mask; + } + } + + constexpr int tile_size = tile_k_size * tile_n_size / pack_factor; + int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size; + + // Result of: + // https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h + if constexpr (num_bits == 4) { + constexpr int 
pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; + + uint32_t res = 0; +#pragma unroll + for (int i = 0; i < 8; i++) { + res |= vals[pack_idx[i]] << (i * 4); + } + + out_ptr[out_offset + th_id * 4 + warp_id] = res; + + } else { + constexpr int pack_idx[4] = {0, 2, 1, 3}; + + uint32_t res1 = 0; + uint32_t res2 = 0; +#pragma unroll + for (int i = 0; i < 4; i++) { + res1 |= vals[pack_idx[i]] << (i * 8); + res2 |= vals[4 + pack_idx[i]] << (i * 8); + } + + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1; + out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2; + } + }; + + auto start_pipes = [&](int k_tile_id, int n_tile_id) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages - 1; pipe++) { + fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); + } + + wait_for_stage(); + }; +#pragma unroll + for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { + int n_tile_id = 0; + + if constexpr (has_perm) { + load_perm_to_shared(k_tile_id); + } + + start_pipes(k_tile_id, n_tile_id); + + while (n_tile_id < n_tiles) { +#pragma unroll + for (int pipe = 0; pipe < repack_stages; pipe++) { + fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); + repack_tile(pipe, k_tile_id, n_tile_id + pipe); + wait_for_stage(); + } + n_tile_id += repack_stages; + } + } +} +#endif + +} // namespace device::marlin + +#define CALL_IF_REPACK(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + host::RuntimeDeviceCheck(cudaFuncSetAttribute( \ + device::marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, \ + max_shared_mem)); \ + host::LaunchKernel(blocks, device::marlin::repack_threads, stream, static_cast(max_shared_mem))( \ + device::marlin::gptq_marlin_repack_kernel, \ + b_q_weight_ptr, \ + perm_ptr, \ + out_ptr, \ + size_k, \ + size_n); \ + } + +void gptq_marlin_repack( + tvm::ffi::TensorView b_q_weight, + tvm::ffi::TensorView perm, + 
tvm::ffi::TensorView out, + int64_t size_k, + int64_t size_n, + int64_t num_bits) { + using namespace host; + + // Validate num_bits + RuntimeCheck(num_bits == 4 || num_bits == 8, "num_bits must be 4 or 8. Got = ", num_bits); + int const pack_factor = 32 / static_cast(num_bits); + + // Validate size alignment + RuntimeCheck( + size_k % device::marlin::tile_k_size == 0, + "size_k = ", + size_k, + " is not divisible by tile_k_size = ", + device::marlin::tile_k_size); + RuntimeCheck( + size_n % device::marlin::tile_n_size == 0, + "size_n = ", + size_n, + " is not divisible by tile_n_size = ", + device::marlin::tile_n_size); + + // Validate b_q_weight + auto bqw_dim0 = SymbolicSize{"bqw_dim0"}; + auto bqw_dim1 = SymbolicSize{"bqw_dim1"}; + bqw_dim0.set_value(size_k / pack_factor); + bqw_dim1.set_value(size_n); + auto device_ = SymbolicDevice{}; + device_.set_options(); + TensorMatcher({bqw_dim0, bqw_dim1}).with_dtype().with_device(device_).verify(b_q_weight); + + // Validate out + auto out_dim0 = SymbolicSize{"out_dim0"}; + auto out_dim1 = SymbolicSize{"out_dim1"}; + out_dim0.set_value(size_k / device::marlin::tile_size); + out_dim1.set_value(size_n * device::marlin::tile_size / pack_factor); + TensorMatcher({out_dim0, out_dim1}).with_dtype().with_device(device_).verify(out); + + // Detect if there is act_order + bool has_perm = perm.size(0) != 0; + + // Get ptrs + uint32_t const* b_q_weight_ptr = reinterpret_cast(b_q_weight.data_ptr()); + uint32_t const* perm_ptr = reinterpret_cast(perm.data_ptr()); + uint32_t* out_ptr = reinterpret_cast(out.data_ptr()); + + // Get dev info + DLDevice dl_device = device_.unwrap(); + int dev = dl_device.device_id; + cudaStream_t stream = LaunchKernel::resolve_device(dl_device); + int blocks; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev)); + + int max_shared_mem = 0; + RuntimeDeviceCheck(cudaDeviceGetAttribute(&max_shared_mem, cudaDevAttrMaxSharedMemoryPerBlockOptin, dev)); + 
RuntimeCheck(max_shared_mem > 0, "max_shared_mem must be > 0"); + + if (false) { + } + CALL_IF_REPACK(4, false) + CALL_IF_REPACK(4, true) + CALL_IF_REPACK(8, false) + CALL_IF_REPACK(8, true) + else { + Panic("Unsupported repack config: num_bits = ", num_bits, ", has_perm = ", has_perm); + } +} + +#undef CALL_IF_REPACK diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/kernel.h b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/kernel.h new file mode 100644 index 000000000..e54dd426f --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/kernel.h @@ -0,0 +1,32 @@ + +#include + +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define MARLIN_KERNEL_PARAMS \ + const int4 *__restrict__ A, const int4 *__restrict__ B, int4 *__restrict__ C, int4 *__restrict__ C_tmp, \ + const int4 *__restrict__ scales_ptr, const uint16_t *__restrict__ scale2_ptr, const int4 *__restrict__ zp_ptr, \ + const int *__restrict__ g_idx, int num_groups, int prob_m, int prob_n, int prob_k, int lda, int *locks, \ + bool use_atomic_add, bool use_fp32_reduce, int max_shared_mem + +namespace device::marlin { +templateshared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(MARLIN_KERNEL_PARAMS); + +} // namespace device::marlin diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh new file mode 100644 index 000000000..483ff5fc5 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin.cuh @@ -0,0 +1,89 @@ +#pragma once + +#include + +#include + +// Bridge the mllm_kernel::host namespace to the `host` namespace expected by +// Marlin code (originally from sglang). 
+namespace host = ::mllm_kernel::host; + +namespace device::marlin { +// Marlin params + +// 8 warps are a good choice since every SM has 4 schedulers and having more +// than 1 warp per schedule allows some more latency hiding. At the same time, +// we want relatively few warps to have many registers per warp and small tiles. +static constexpr int default_threads = 256; + +static constexpr int pipe_stages = 4; // 4 pipeline stages fit into shared memory + +static constexpr int min_thread_n = 64; +static constexpr int min_thread_k = 64; +static constexpr int max_thread_n = 256; + +static constexpr int tile_size = 16; +static constexpr int max_par = 16; + +// Repack params +static constexpr int repack_stages = 8; + +static constexpr int repack_threads = 256; + +static constexpr int tile_k_size = tile_size; +static constexpr int tile_n_size = tile_k_size * 4; + +// Helpers +template +struct Vec { + T elems[n]; + __device__ T& operator[](int i) { + return elems[i]; + } +}; + +using I4 = Vec; + +using host::div_ceil; + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 +// No support for async +#else + +__device__ inline void cp_async4_pred(void* smem_ptr, const void* glob_ptr, bool pred = true) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" + "}\n" ::"r"((int)pred), + "r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async4(void* smem_ptr, const void* glob_ptr) { + const int BYTES = 16; + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + asm volatile( + "{\n" + " cp.async.cg.shared.global [%0], [%1], %2;\n" + "}\n" ::"r"(smem), + "l"(glob_ptr), + "n"(BYTES)); +} + +__device__ inline void cp_async_fence() { + asm volatile("cp.async.commit_group;\n" ::); +} + +template +__device__ inline void cp_async_wait() { + asm volatile("cp.async.wait_group %0;\n" 
::"n"(n)); +} + +#endif + +} // namespace device::marlin diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_dtypes.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_dtypes.cuh new file mode 100644 index 000000000..40b538688 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_dtypes.cuh @@ -0,0 +1,77 @@ +#ifndef _data_types_cuh +#define _data_types_cuh +#include + +#include "marlin.cuh" + +namespace device::marlin { + +template +class ScalarType {}; + +template <> +class ScalarType { + public: + using scalar_t = fp16_t; + using scalar_t2 = fp16x2_t; + + // Matrix fragments for tensor core instructions; their precise layout is + // documented here: + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-fragments-for-mma-m16n8k16-with-floating-point-type + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + + static __device__ float inline num2float(const fp16_t x) { + return __half2float(x); + } + + static __device__ fp16x2_t inline num2num2(const fp16_t x) { + return __half2half2(x); + } + + static __device__ fp16x2_t inline nums2num2(const fp16_t x1, const fp16_t x2) { + return __halves2half2(x1, x2); + } + + static __host__ __device__ fp16_t inline float2num(const float x) { + return __float2half(x); + } +}; + +template <> +class ScalarType { + public: + using scalar_t = bf16_t; + using scalar_t2 = bf16x2_t; + + using FragA = Vec; + using FragB = Vec; + using FragC = Vec; + using FragS = Vec; + using FragZP = Vec; + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 800 + static __device__ float inline num2float(const bf16_t x) { + return __bfloat162float(x); + } + + static __device__ bf16x2_t inline num2num2(const bf16_t x) { + return __bfloat162bfloat162(x); + } + + static __device__ bf16x2_t inline nums2num2(const bf16_t x1, const bf16_t x2) { + return __halves2bfloat162(x1, x2); + } + + static __host__ __device__ bf16_t inline float2num(const 
float x) { + return __float2bfloat16(x); + } +#endif +}; + +} // namespace device::marlin + +#endif diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_template.h b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_template.h new file mode 100644 index 000000000..04052838c --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/gemm/marlin/marlin_template.h @@ -0,0 +1,1514 @@ +/* + * Modified by Neural Magic + * Copyright (C) Marlin.2024 Elias Frantar + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Adapted from https://github.com/IST-DASLab/marlin + */ +#include + +#include "dequant.h" +#include "marlin.cuh" +#include "marlin_dtypes.cuh" + +#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ + static_assert(std::is_same::value || std::is_same::value, \ + "only float16 and bfloat16 is supported"); + +namespace device::marlin { + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + +templateshared + // fetch pipeline + const bool has_act_order, // whether act_order is enabled + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? 
+ > +__global__ void Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int* locks, // extra global storage for barrier synchronization + bool use_fp32_reduce // whether to use fp32 global reduce +) {} + +} // namespace device::marlin + +#else + +// m16n8k16 tensor core mma instruction with fp16 inputs and fp32 +// output/accumulation. +template +__device__ inline void mma(const typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +template +__device__ inline void mma_trans(const 
typename ScalarType::FragA& a_frag, + const typename ScalarType::FragB& frag_b, + const typename ScalarType::FragB& frag_b2, + typename ScalarType::FragC& frag_c) { + const uint32_t* a = reinterpret_cast(&a_frag); + const uint32_t* b = reinterpret_cast(&frag_b); + const uint32_t* b2 = reinterpret_cast(&frag_b2); + float* c = reinterpret_cast(&frag_c); + if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else if constexpr (std::is_same::value) { + asm volatile("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3]) + : "r"(b[0]), "r"(b2[0]), "r"(b[1]), "r"(b2[1]), "r"(a[0]), "r"(a[1]), "f"(c[0]), "f"(c[1]), "f"(c[2]), + "f"(c[3])); + } else { + STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t); + } +} + +// Instruction for loading a full 16x16 matrix fragment of operand A from shared +// memory, directly in tensor core layout. 
+template +__device__ inline void ldsm(typename ScalarType::FragA& frag_a, const void* smem_ptr) { + uint32_t* a = reinterpret_cast(&frag_a); + uint32_t smem = static_cast(__cvta_generic_to_shared(smem_ptr)); + if constexpr (count == 4) { + asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n" + : "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3]) + : "r"(smem)); + } else if constexpr (count == 2) { + asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0,%1}, [%2];\n" : "=r"(a[0]), "=r"(a[1]) : "r"(smem)); + } else if constexpr (count == 1) { + asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n" : "=r"(a[0]) : "r"(smem)); + } else { + static_assert(count == 1 || count == 2 || count == 4, "invalid count"); + } +} + +// Multiply dequantized values by the corresponding quantization scale; used +// only for grouped quantization. +template +__device__ inline void scale(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s = ScalarType::num2num2(reinterpret_cast(&frag_s)[i]); + frag_b[0] = __hmul2(frag_b[0], s); + frag_b[1] = __hmul2(frag_b[1], s); +} + +template +__device__ inline void scale_and_sub(typename ScalarType::FragB& frag_b, scalar_t s, scalar_t zp) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s2 = ScalarType::num2num2(s); + scalar_t2 zp2 = ScalarType::num2num2(zp); + frag_b[0] = __hfma2(frag_b[0], s2, __hneg2(zp2)); + frag_b[1] = __hfma2(frag_b[1], s2, __hneg2(zp2)); +} + +template +__device__ inline void sub_zp(typename ScalarType::FragB& frag_b, typename ScalarType::scalar_t2& frag_zp, + int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 zp = ScalarType::num2num2(reinterpret_cast(&frag_zp)[i]); + frag_b[0] = __hsub2(frag_b[0], zp); + frag_b[1] = __hsub2(frag_b[1], zp); +} + +// Same as above, but for act_order (each K is multiplied individually) +template +__device__ inline void 
scale4(typename ScalarType::FragB& frag_b, typename ScalarType::FragS& frag_s_1, + typename ScalarType::FragS& frag_s_2, typename ScalarType::FragS& frag_s_3, + typename ScalarType::FragS& frag_s_4, int i) { + using scalar_t2 = typename ScalarType::scalar_t2; + scalar_t2 s_val_1_2; + s_val_1_2.x = reinterpret_cast(&frag_s_1)[i]; + s_val_1_2.y = reinterpret_cast(&frag_s_2)[i]; + + scalar_t2 s_val_3_4; + s_val_3_4.x = reinterpret_cast(&frag_s_3)[i]; + s_val_3_4.y = reinterpret_cast(&frag_s_4)[i]; + + frag_b[0] = __hmul2(frag_b[0], s_val_1_2); + frag_b[1] = __hmul2(frag_b[1], s_val_3_4); +} + +// Given 2 floats multiply by 2 scales (halves) +template +__device__ inline void scale_float(float* c, typename ScalarType::FragS& s) { + scalar_t* s_ptr = reinterpret_cast(&s); + c[0] = __fmul_rn(c[0], ScalarType::num2float(s_ptr[0])); + c[1] = __fmul_rn(c[1], ScalarType::num2float(s_ptr[1])); +} + +// Wait until barrier reaches `count`, then lock for current threadblock. +__device__ inline void barrier_acquire(int* lock, int count) { + if (threadIdx.x == 0) { + int state = -1; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state != count); + } + __syncthreads(); +} + +// Release barrier and increment visitation count. +__device__ inline void barrier_release(int* lock, bool reset = false) { + __syncthreads(); + if (threadIdx.x == 0) { + if (reset) { + lock[0] = 0; + return; + } + int val = 1; + // Make sure that all writes since acquiring this barrier are visible + // globally, while releasing the barrier. 
+ asm volatile("fence.acq_rel.gpu;\n"); + asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(lock), "r"(val)); + } +} + +// Wait until value of lock to be negative, and then add 1 +__device__ inline void wait_negative_and_add(int* lock) { + if (threadIdx.x == 0) { + int state = 0; + do + // Guarantee that subsequent writes by this threadblock will be visible + // globally. + asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(lock)); + while (state >= 0); + atomicAdd(lock, 1); + } + __syncthreads(); +} + +templateshared + // fetch pipeline + const int group_blocks, // number of consecutive 16x16 blocks + // with a separate quantization scale + const bool is_zp_float // is zero point of float16 type? + > +__global__ void Marlin(const int4* __restrict__ A, // fp16 input matrix of shape mxk + const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn + int4* __restrict__ C, // fp16 output buffer of shape mxn + int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce) + const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape + // (k/groupsize)xn + const uint16_t* __restrict__ scale2_ptr, // fp16 global scale (for nvfp4 + // only) + const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape + // (k/groupsize)x(n/pack_factor) + const int* __restrict__ g_idx, // int32 group indices of shape k + int num_groups, // number of scale groups per output channel + int prob_m, // batch dimension m + int prob_n, // output dimension n + int prob_k, // reduction dimension k + int lda, // A.stride(0), equal to prob_k is A is contiguous + int* locks, // extra global storage for barrier synchronization + bool use_atomic_add, // whether to use atomic add to reduce + bool use_fp32_reduce, // whether to use fp32 global reduce + int max_shared_mem) { + // Each threadblock processes one "stripe" of the B matrix with (roughly) the + // same size, which might involve multiple column "slices" (of width 16 * + 
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM + // example: + // 0 1 3 + // 0 2 3 + // 1 2 4 + // While this kind of partitioning makes things somewhat more complicated, it + // ensures good utilization of all SMs for many kinds of shape and GPU + // configurations, while requiring as few slow global cross-threadblock + // reductions as possible. + using Dtype = ScalarType; + using scalar_t2 = typename ScalarType::scalar_t2; + using FragA = typename ScalarType::FragA; + using FragB = typename ScalarType::FragB; + using FragC = typename ScalarType::FragC; + using FragS = typename ScalarType::FragS; + using FragZP = typename ScalarType::FragZP; + + static constexpr auto w_type = host::ScalarType::from_id(w_type_id); + constexpr bool has_zp = w_type == host::kU4 || w_type == host::kU8; + constexpr bool is_int_type = w_type == host::kU4 || w_type == host::kU8 || w_type == host::kU4B8 || w_type == host::kU8B128; + // see comments of dequant.h for more details + constexpr bool dequant_skip_flop = !is_int_type || has_zp && !is_zp_float && !std::is_same::value + || has_zp && !is_zp_float && !(w_type == host::kU8); + + scalar_t2 global_scale; + + if constexpr (w_type == host::kFE2M1f) { + uint16_t val = scale2_ptr[0]; + global_scale = Dtype::num2num2(*reinterpret_cast(&val)); + } + + constexpr bool has_act_order = group_blocks == 0; + constexpr int m_block_size = m_block_size_8 ? 
8 : (16 * thread_m_blocks); + + constexpr int pack_factor = 32 / w_type.size_bits(); + static_assert(thread_m_blocks == 1 || !m_block_size_8); + + // For larger GEMMs we run multiple batchsize 64 versions in parallel for a + // better partitioning with less reductions + int parallel = 1; + if (prob_m > m_block_size) { + parallel = prob_m / m_block_size; + prob_m = m_block_size; + } + + int k_tiles = prob_k / 16 / thread_k_blocks; + int n_tiles = prob_n / 16 / thread_n_blocks; + int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x); + + if constexpr (!has_act_order && group_blocks != -1) { + if (group_blocks >= thread_k_blocks) { + // Ensure that the number of tiles in each stripe is a multiple of the + // groupsize; this avoids an annoying special case where a stripe starts + // in the middle of group. + iters = (group_blocks / thread_k_blocks) * div_ceil(iters, (group_blocks / thread_k_blocks)); + } + } + + int slice_row = (iters * blockIdx.x) % k_tiles; + int slice_col_par = (iters * blockIdx.x) / k_tiles; + int slice_col = slice_col_par; + int slice_iters; // number of threadblock tiles in the current slice + int slice_count = 0; // total number of active threadblocks in the current slice + int slice_idx; // index of threadblock in current slice; numbered bottom to + // top + + int par_id = 0; + int locks_off = 0; + + // We can easily implement parallel problem execution by just remapping + // indices and advancing global pointers + if (slice_col_par >= n_tiles) { + A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * lda / 8; + C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8; + slice_col = slice_col_par % n_tiles; + par_id = slice_col_par / n_tiles; + } + if (parallel * n_tiles >= gridDim.x) { + // when parallel * n_tiles >= sms + // then there are at most $sms$ conflict tile blocks + locks_off = blockIdx.x; + } else { + locks_off = (iters * blockIdx.x) / k_tiles - 1; + } + + // Compute all information about the current slice which 
is required for + // synchronization. + auto init_slice = [&](bool first_init = false) { + slice_iters = iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row); + if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0; + if (slice_iters == 0) return; + if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row; + slice_count = 1; + slice_idx = 0; + int col_first = iters * div_ceil(k_tiles * slice_col_par, iters); + if (col_first <= k_tiles * (slice_col_par + 1)) { + int col_off = col_first - k_tiles * slice_col_par; + slice_count = div_ceil(k_tiles - col_off, iters); + if (col_off > 0) slice_count++; + int delta_first = iters * blockIdx.x - col_first; + if (delta_first < 0 || (col_off == 0 && delta_first == 0)) + slice_idx = slice_count - 1; + else { + slice_idx = slice_count - 1 - delta_first / iters; + if (col_off > 0) slice_idx--; + } + } + if (parallel * n_tiles >= gridDim.x) { + if (slice_count > 1 && slice_idx == slice_count - 1) { locks_off++; } + } else { + locks_off++; + } + + if (first_init && use_atomic_add && slice_count > 1 && slice_idx == 0) { + constexpr int threads_per_m = 16 * thread_n_blocks / 8; + int m_per_thread = div_ceil(thread_m_blocks * 16, threads / threads_per_m); + if (m_block_size_8) m_per_thread = div_ceil(8, threads / threads_per_m); + for (int i = 0; i < m_per_thread; i++) { + int row = threads / threads_per_m * i + threadIdx.x / threads_per_m; + if (row < prob_m) { + int col = slice_col * 16 * thread_n_blocks / 8 + threadIdx.x % threads_per_m; + C[row * prob_n / 8 + col] = {0, 0, 0, 0}; + } + } + // After write zero to output, write a negative value to lock. + // Every SM that processes the same slice would wait for + // the negative value, and then atomicAdd 1 to it. + // After all SMs are processed, the lock value would back to 0 again. 
+ __syncthreads(); + if (threadIdx.x == 0) locks[locks_off] = 1 - slice_count; + } + + if (slice_col == n_tiles) { + A += 16 * thread_m_blocks * lda / 8; + C += 16 * thread_m_blocks * prob_n / 8; + slice_col = 0; + par_id++; + } + }; + init_slice(true); + + // A sizes/strides + + // stride of the A matrix in global memory + int a_gl_stride = lda / 8; + // stride of an A matrix tile in shared memory + constexpr int a_sh_stride = 16 * thread_k_blocks / 8; + // delta between subsequent A tiles in global memory + constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8; + // between subsequent accesses within a tile + int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o); + // between shared memory writes + constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o); + // between shared memory tile reads + constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4)); + // within a shared memory tile + constexpr int a_sh_rd_delta_i = a_sh_stride * 16; + // overall size of a tile + constexpr int a_sh_stage = a_sh_stride * m_block_size; + // number of shared write iterations for a tile + constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta); + + // B sizes/strides + int b_gl_stride = 16 * prob_n / (pack_factor * 4); + constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4; + constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 
1 : 2; + constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs; + + int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks; + int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads); + constexpr int b_sh_wr_delta = threads * b_thread_vecs; + constexpr int b_sh_rd_delta = threads * b_thread_vecs; + constexpr int b_sh_stage = b_sh_stride * thread_k_blocks; + constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta; + + // Scale sizes/strides without act_order + int s_gl_stride = prob_n / 8; + constexpr int s_sh_stride = 16 * thread_n_blocks / 8; + constexpr int s_tb_groups = !has_act_order && group_blocks != -1 && group_blocks < thread_k_blocks + ? thread_k_blocks / group_blocks / (w_type == host::kFE2M1f ? 2 : 1) + : 1; + constexpr int s_sh_stage = s_tb_groups * s_sh_stride; + int s_gl_rd_delta = s_gl_stride; + + // Scale size/strides with act_order + constexpr int tb_k = 16 * thread_k_blocks; + constexpr int g_idx_stage = has_act_order ? (tb_k * sizeof(int)) / 16 : 0; + // constexpr int act_s_row_stride = 1; + // int act_s_col_stride = act_s_row_stride * num_groups; + constexpr int act_s_max_num_groups = 32; + int act_s_col_stride = 1; + int act_s_col_warp_stride = act_s_col_stride * 8; + + int tb_n_warps = thread_n_blocks / 4; + int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps; + + // Zero-points sizes/strides + int zp_gl_stride = is_zp_float ? prob_n / 8 : (prob_n / pack_factor) / 4; + constexpr int zp_sh_stride = is_zp_float ? 16 * thread_n_blocks / 8 : ((16 * thread_n_blocks) / pack_factor) / 4; + constexpr int zp_tb_groups = s_tb_groups; + constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0; + int zp_gl_rd_delta = zp_gl_stride; + + // Global A read index of current thread. + int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + a_gl_rd += a_gl_rd_delta_o * slice_row; + // Shared write index of current thread. 
+ int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); + // Shared read index. + int a_sh_rd = a_sh_stride * ((threadIdx.x % 32) % (16 / (m_block_size_8 ? 2 : 1))) + + (threadIdx.x % 32) / (16 / (m_block_size_8 ? 2 : 1)); + a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4)); + + int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads) * b_thread_vecs; + b_gl_rd += b_sh_stride * slice_col; + b_gl_rd += b_gl_rd_delta_o * slice_row; + auto b_sh_wr = threadIdx.x * b_thread_vecs; + auto b_sh_rd = threadIdx.x * b_thread_vecs; + + // For act_order + constexpr int k_iter_size = tb_k / b_sh_wr_iters; + int slice_k_start = tb_k * slice_row; + int slice_k_finish = slice_k_start + tb_k * slice_iters; + int slice_k_start_shared_fetch = slice_k_start; + int slice_n_offset = act_s_col_tb_stride * slice_col; + + // No act_order + int s_gl_rd; + if constexpr (!has_act_order) { + if constexpr (group_blocks == -1) { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + } else { + s_gl_rd = s_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) / (w_type == host::kFE2M1f ? 2 : 1) + + s_sh_stride * slice_col + threadIdx.x; + } + } + auto s_sh_wr = threadIdx.x; + bool s_sh_wr_pred = threadIdx.x < s_sh_stride; + + // Zero-points + int zp_gl_rd; + if constexpr (has_zp) { + if constexpr (group_blocks == -1) { + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } else { + zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) + zp_sh_stride * slice_col + threadIdx.x; + } + } + auto zp_sh_wr = threadIdx.x; + bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride; + + // We use a different scale layout for grouped and column-wise quantization as + // we scale a `half2` tile in column-major layout in the former and in + // row-major in the latter case. 
+ int s_sh_rd; + if constexpr (group_blocks != -1 && w_type == host::kFE2M1f) { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + int warp_row = warp_id / n_warps; + + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + s_sh_rd = s_sh_rd * 2 + warp_row % 2; + + } else if constexpr (group_blocks != -1) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + else if constexpr (group_blocks == -1 && (m_block_size_8 || (has_zp && !dequant_skip_flop))) + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 8; + else + s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4; + + // Zero-points have the same read layout as the scales + // (without column-wise case) + constexpr int num_col_threads = 8; + constexpr int num_row_threads = 4; + constexpr int num_ints_per_thread = 8 / pack_factor; + int zp_sh_rd; + if constexpr (has_zp) { + if constexpr (is_zp_float) { + if constexpr (group_blocks != -1) { + zp_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) / 4; + } + } else { + zp_sh_rd = num_ints_per_thread * num_col_threads * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + + num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads); + } + } + + // Precompute which thread should not read memory in which iterations; this is + // needed if there are more threads than required for a certain tilesize or + // when the batchsize is not a multiple of 16. + bool a_sh_wr_pred[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m; + + // To ensure that writing and reading A tiles to/from shared memory, the + // latter in fragment format, is fully bank conflict free, we need to use a + // rather fancy XOR-based layout. 
The key here is that neither reads nor + // writes of the 16-byte `int4` blocks of 8 consecutive threads involve the + // same shared memory banks. Further, it seems (based on NSight-Compute) that + // each warp must also write a consecutive memory segment? + auto transform_a = [&](int i) { + int row = i / a_gl_rd_delta_o; + return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ (row % 8); + }; + // Since the computation of this remapping is non-trivial and, due to our main + // loop unrolls, all shared memory accesses are static, we simply precompute + // both transformed reads and writes. + int a_sh_wr_trans[a_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr); + int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < thread_m_blocks; j++) + a_sh_rd_trans[i][j] = transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd); + } + + // Since B-accesses have non-constant stride they have to be computed at + // runtime; we break dependencies between subsequent accesses with a tile by + // maintining multiple pointers (we have enough registers), a tiny + // optimization. + const int4* B_ptr[b_sh_wr_iters]; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd; + + extern __shared__ int4 sh[]; + // Shared memory storage for global fetch pipelines. + constexpr int sh_red_size = (2 * thread_n_blocks + 1) * 16 * thread_m_blocks; + constexpr int sh_b_size = stages * b_sh_stage; + int4* sh_b = sh; + int4* sh_red = sh; + int4* sh_g_idx = sh_b + (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + int4* sh_zp = sh_g_idx + (stages * g_idx_stage); + constexpr int sh_s_size = has_act_order ? 
(act_s_max_num_groups * s_sh_stride) : (stages * s_sh_stage); + int4* sh_s = sh_zp + (stages * zp_sh_stage); + // shared memory reused by reduction should be smaller than + // shared memory used by weight. + static_assert(thread_m_blocks * 16 * thread_n_blocks * 16 / 8 <= stages * b_sh_stage); + int4* sh_a = sh_s + sh_s_size; + // constexpr int shm_size_used = + // stages * (g_idx_stage + zp_sh_stage) + sh_s_size + + // (sh_red_size > sh_b_size ? sh_red_size : sh_b_size); + + // Register storage for double buffer of shared memory reads. + FragA frag_a[2][thread_m_blocks]; + I4 frag_b_quant[2][b_thread_vecs]; + FragC frag_c[thread_m_blocks][4][2]; + FragS frag_s[2][4]; // No act-order + FragS act_frag_s[2][4][4]; // For act-order + int frag_qzp[2][num_ints_per_thread]; // Zero-points + FragZP frag_zp; // Zero-points in fp16 + FragZP frag_zpf[2]; // Zero-points in fp16 in HQQ + + // Zero accumulators. + auto zero_accums = [&]() { +#pragma unroll + for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++) reinterpret_cast(frag_c)[i] = 0; + }; + + int sh_first_group_id = -1; + int sh_num_groups = -1; + + auto fetch_act_order_scales_to_shared = [&](bool is_async, int first_group_id, int last_group_id) { + sh_first_group_id = first_group_id; + sh_num_groups = last_group_id - first_group_id + 1; + + if (sh_num_groups > act_s_max_num_groups) { sh_num_groups = act_s_max_num_groups; } + + if (sh_first_group_id + sh_num_groups > num_groups) { sh_num_groups = num_groups - sh_first_group_id; } + + int row_offset = first_group_id * s_gl_stride; + + if (is_async) { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x], + &scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + threadIdx.x]); + } + } + } else { + for (int i = 0; i < sh_num_groups; i++) { + if (threadIdx.x < s_sh_stride) { + sh_s[(i * s_sh_stride) + threadIdx.x] = scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset + 
threadIdx.x]; + } + } + } + }; + // Asynchronously fetch the next A, B and s tile from global to the next + // shared memory pipeline location. + auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) { + if (pred) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < a_sh_wr_iters; i++) { + cp_async4_pred(&sh_a_stage[a_sh_wr_trans[i]], &A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off], + a_sh_wr_pred[i]); + } + int4* sh_b_stage = sh_b + b_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) { +#pragma unroll + for (int j = 0; j < b_thread_vecs; j++) { cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j); } + + B_ptr[i] += b_gl_rd_delta_o; + } + + if constexpr (has_act_order) { + // Fetch g_idx thread-block portion + int full_pipe = a_off; + int cur_k = slice_k_start_shared_fetch + tb_k * full_pipe; + if (cur_k < prob_k && cur_k < slice_k_finish) { + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + + int4 const* cur_g_idx_stage_ptr = reinterpret_cast(&g_idx[cur_k]); + + if (threadIdx.x < g_idx_stage) { cp_async4_pred(&sh_g_idx_stage[threadIdx.x], &cur_g_idx_stage_ptr[threadIdx.x]); } + } + } else { + if constexpr (group_blocks != -1) { + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch scales if this tile starts a new group + if (pipe % (group_blocks / thread_k_blocks) == 0) { + if (s_sh_wr_pred) { cp_async4(&sh_s_stage[s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta; + } + } else { + for (int i = 0; i < s_tb_groups; i++) { + if (s_sh_wr_pred) { cp_async4(&sh_s_stage[i * s_sh_stride + s_sh_wr], &scales_ptr[s_gl_rd]); } + s_gl_rd += s_gl_rd_delta; + } + } + } + + if constexpr (has_zp && group_blocks != -1) { + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + if constexpr (group_blocks >= thread_k_blocks) { + // Only fetch zero-points if this tile starts a new group + if (pipe % 
(group_blocks / thread_k_blocks) == 0) { + if (zp_sh_wr_pred) { cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta; + } + } else { + for (int i = 0; i < zp_tb_groups; i++) { + if (zp_sh_wr_pred) { cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr], &zp_ptr[zp_gl_rd]); } + zp_gl_rd += zp_gl_rd_delta; + } + } + } + } + } + // Insert a fence even when we are winding down the pipeline to ensure that + // waiting is also correct at this point. + cp_async_fence(); + }; + + auto fetch_col_zp_to_shared = [&]() { + if (zp_sh_wr_pred) { cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]); } + }; + + auto fetch_col_scale_to_shared = [&]() { + if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } + }; + + // Wait until the next thread tile has been loaded to shared memory. + auto wait_for_stage = [&]() { + // We only have `stages - 2` active fetches since we are double buffering + // and can only issue the next fetch when it is guaranteed that the previous + // shared memory load is fully complete (as it may otherwise be + // overwritten). + cp_async_wait(); + __syncthreads(); + }; + + // Load the next sub-tile from the current location in the shared memory pipe + // into the current register buffer. 
+ auto fetch_to_registers = [&](int k, int pipe) { + int4* sh_a_stage = sh_a + a_sh_stage * pipe; +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) + ldsm(frag_a[k % 2][i], &sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]); + int4* sh_b_stage = sh_b + b_sh_stage * pipe; + +#pragma unroll + for (int i = 0; i < b_thread_vecs; i++) { + frag_b_quant[k % 2][i] = *reinterpret_cast(&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]); + } + }; + + bool is_same_group[stages]; + int same_group_id[stages]; + + auto init_same_group = [&](int pipe) { + if constexpr (!has_act_order) { return; } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + int group_id_1 = sh_g_idx_int_ptr[0]; + int group_id_2 = sh_g_idx_int_ptr[tb_k - 1]; + + is_same_group[pipe] = group_id_1 == group_id_2; + same_group_id[pipe] = group_id_1; + }; + + auto fetch_scales_to_registers = [&](int k, int full_pipe) { + int pipe = full_pipe % stages; + + if constexpr (!has_act_order) { + // No act-order case + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + } + } else if constexpr (group_blocks != -1) { + if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_s_stage = + sh_s + s_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd]; + } else { + reinterpret_cast(&frag_s[1])[0] = reinterpret_cast(&frag_s[0])[0]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = k_blocks / (group_blocks * (w_type == host::kFE2M1f ? 
2 : 1)); + + int4* sh_s_stage = sh_s + s_sh_stage * pipe; + + if constexpr (w_type_id != host::kFE2M1f.id()) { + reinterpret_cast(&frag_s[k % 2])[0] = sh_s_stage[s_sh_rd + cur_group_id * s_sh_stride]; + } else { + reinterpret_cast(&frag_s[k % 2])[0] = + reinterpret_cast(sh_s_stage)[s_sh_rd + cur_group_id * (2 * s_sh_stride)]; + } + } + } + + return; + } + + // Act-order case + + // Determine K of the "current" thread-block + int cur_k = slice_k_start + tb_k * full_pipe; + if (cur_k >= prob_k || cur_k >= slice_k_finish) { return; } + + // Reset (to current thread-block) since we read g_idx portion from the + // shared memory + cur_k = 0; + + // Progress to current iteration + cur_k += k_iter_size * (k % b_sh_wr_iters); + + // Determine "position" inside the thread-block (based on warp and + // thread-id) + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; // Each warp processes 4 16-size tiles over N + + int warp_row = warp_id / n_warps; + int warp_col = warp_id % n_warps; + + cur_k += warp_row * 16; + + auto th_id = threadIdx.x % 32; + cur_k += (th_id % 4) * 2; // Due to tensor-core layout for fp16 B matrix + + int s_col_shift = + /*slice_n_offset +*/ (act_s_col_warp_stride * warp_col) + (th_id / 4) * act_s_col_stride; + + if (is_same_group[pipe]) { + if (k % 2 == 0) { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = + sh_s[(same_group_id[pipe] - sh_first_group_id) * s_sh_stride + s_col_shift]; + } else { + *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))) = *(reinterpret_cast(&(act_frag_s[(k - 1) % 2][0][0]))); + } + + for (int i = 1; i < 4; i++) { + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = *(reinterpret_cast(&(act_frag_s[k % 2][0][0]))); + } + return; + } + + int4* sh_g_idx_stage = sh_g_idx + g_idx_stage * pipe; + int* sh_g_idx_int_ptr = reinterpret_cast(sh_g_idx_stage); + + constexpr int k_frag_offsets[4] = {0, 1, 8, 9}; // Tensor core offsets per thread + +#pragma unroll + for (int i = 0; i < 4; i++) { + int actual_k = cur_k + 
k_frag_offsets[i]; + + int group_id = sh_g_idx_int_ptr[actual_k]; + int rel_group_id = group_id - sh_first_group_id; + + *(reinterpret_cast(&(act_frag_s[k % 2][i][0]))) = sh_s[rel_group_id * s_sh_stride + s_col_shift]; + } + }; + + auto fetch_zp_to_registers = [&](int k, int full_pipe) { + // This code does not handle group_blocks == 0, + // which signifies act_order. + // has_zp implies AWQ, which doesn't have act_order, + static_assert(!has_zp || group_blocks != 0); + + if constexpr (has_zp && !is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks == -1) { + // load only when starting a new slice + if (k == 0 && full_pipe == 0) { +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp))[zp_sh_rd + i]; } + } + + } else if constexpr (group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + int cur_group_id = 0; + + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + sh_zp_stage += cur_group_id * zp_sh_stride; + +#pragma unroll + for (int i = 0; i < num_ints_per_thread; i++) { + frag_qzp[k % 2][i] = (reinterpret_cast(sh_zp_stage))[zp_sh_rd + i]; + } + } + } + + else if constexpr (has_zp && is_zp_float) { + int pipe = full_pipe % stages; + + if constexpr (group_blocks != -1) { + if constexpr 
(group_blocks >= thread_k_blocks) { + if (k % b_sh_wr_iters == 0) { + int4* sh_zp_stage = + sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) * (pipe / (group_blocks / thread_k_blocks))); + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd]; + } + } else { + auto warp_id = threadIdx.x / 32; + int n_warps = thread_n_blocks / 4; + + int warp_row = warp_id / n_warps; + + int cur_k = warp_row * 16; + cur_k += k_iter_size * (k % b_sh_wr_iters); + + int k_blocks = cur_k / 16; + // Suppress bogus and persistent divide-by-zero warning +#pragma nv_diagnostic push +#pragma nv_diag_suppress divide_by_zero + int cur_group_id = k_blocks / group_blocks; +#pragma nv_diagnostic pop + + int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe; + + reinterpret_cast(&frag_zpf[k % 2])[0] = sh_zp_stage[zp_sh_rd + cur_group_id * zp_sh_stride]; + } + } + } + }; + + auto dequant_data = [&](int q, scalar_t2* frag_b_ptr) { dequant(q, frag_b_ptr); }; + + // Execute the actual tensor core matmul of a sub-tile. 
+ bool is_first_matmul_in_slice = true; + auto matmul = [&](int k) { + int k2 = k % 2; + const bool is_new_zp = ((group_blocks != -1) && (group_blocks < thread_k_blocks || k == 0)) + || (group_blocks == -1 && is_first_matmul_in_slice); + if constexpr (has_zp && !is_zp_float) { + if (is_new_zp) { + if constexpr (group_blocks == -1) is_first_matmul_in_slice = false; + FragB frag_zp_0; + FragB frag_zp_1; + int zp_quant_0, zp_quant_1; + + if constexpr (w_type.size_bits() == 4) { + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = zp_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + zp_quant_0 = frag_qzp[k2][0]; + zp_quant_1 = frag_qzp[k2][1]; + } + + dequant_data(zp_quant_0, reinterpret_cast(&frag_zp)); + dequant_data(zp_quant_1, reinterpret_cast(&frag_zp) + 2); + } + } + if constexpr (!dequant_skip_flop && has_zp && is_zp_float) { + if (is_new_zp) { reinterpret_cast(&frag_zp)[0] = reinterpret_cast(&frag_zpf[k2])[0]; } + } + + if constexpr (w_type == host::kFE2M1f) { + int s_quant_0 = reinterpret_cast(frag_s[k2])[0]; + int s_quant_1 = reinterpret_cast(frag_s[k2])[1]; + + dequant_fp8_scales(s_quant_0, reinterpret_cast(&frag_s[k2])); + dequant_fp8_scales(s_quant_1, reinterpret_cast(&frag_s[k2]) + 2); + } + +// We have the m dimension as the inner loop in order to encourage overlapping +// dequantization and matmul operations. 
+#pragma unroll + for (int j = 0; j < 4; j++) { + FragB frag_b0; + FragB frag_b1; + int b_quant_0, b_quant_1; + + if constexpr (w_type_id == host::kFE2M1f.id()) { + b_quant_1 = frag_b_quant[k2][0][j]; + b_quant_0 = b_quant_1 << 8; + } else if constexpr (w_type.size_bits() == 4) { + b_quant_0 = frag_b_quant[k2][0][j]; + b_quant_1 = b_quant_0 >> 8; + } else { + static_assert(w_type.size_bits() == 8); + int* frag_b_quant_ptr = reinterpret_cast(frag_b_quant[k2]); + b_quant_0 = frag_b_quant_ptr[j * 2 + 0]; + b_quant_1 = frag_b_quant_ptr[j * 2 + 1]; + } + + dequant_data(b_quant_0, reinterpret_cast(&frag_b0)); + dequant_data(b_quant_1, reinterpret_cast(&frag_b1)); + + if constexpr (dequant_skip_flop && has_zp && !is_zp_float) { + sub_zp(frag_b0, frag_zp[j], 0); + sub_zp(frag_b1, frag_zp[j], 1); + } + + // Apply scale to frag_b0 + if constexpr (has_act_order) { + static_assert(group_blocks != -1); + scale4(frag_b0, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 0); + scale4(frag_b1, act_frag_s[k2][0][j], act_frag_s[k2][1][j], act_frag_s[k2][2][j], act_frag_s[k2][3][j], 1); + } else if constexpr (!dequant_skip_flop && has_zp && !is_zp_float && group_blocks == -1) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2 s2 = Dtype::nums2num2(reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 0])[idx], + reinterpret_cast(&frag_s[j / 2][j % 2 * 2 + 1])[idx]); + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], s2); + scale_and_sub(frag_b0, s2.x, frag_zp[j].x); + scale_and_sub(frag_b1, s2.y, frag_zp[j].y); + } else if constexpr (!dequant_skip_flop && has_zp && group_blocks != -1) { + if (is_new_zp) frag_zp[j] = __hmul2(frag_zp[j], *reinterpret_cast(&frag_s[k2][j])); + scale_and_sub(frag_b0, frag_s[k2][j][0].x, frag_zp[j].x); + scale_and_sub(frag_b1, frag_s[k2][j][0].y, frag_zp[j].y); + } else if constexpr (group_blocks != -1) { + scale(frag_b0, frag_s[k2][j], 0); + scale(frag_b1, frag_s[k2][j], 1); + } + +#pragma unroll + for (int i = 0; i < 
thread_m_blocks; i++) { + if constexpr (m_block_size_8) { + mma_trans(frag_a[k2][i], frag_b0, frag_b1, frag_c[i][j][0]); + } else { + mma(frag_a[k2][i], frag_b0, frag_c[i][j][0]); + mma(frag_a[k2][i], frag_b1, frag_c[i][j][1]); + } + } + } + }; + + // Since we slice across the k dimension of a tile in order to increase the + // number of warps while keeping the n dimension of a tile reasonable, we have + // multiple warps that accumulate their partial sums of the same output + // location; which we have to reduce over in the end. We do in shared memory. + auto thread_block_reduce = [&]() { + constexpr int red_off = threads / b_sh_stride_threads / 2; + if (red_off >= 1) { + auto red_idx = threadIdx.x / b_sh_stride_threads; + constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2; + constexpr int red_sh_delta = b_sh_stride_threads; + int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) + (threadIdx.x % b_sh_stride_threads); + + // Parallel logarithmic shared memory reduction. We make sure to avoid any + // unnecessary read or write iterations, e.g., for two warps we write only + // once by warp 1 and read only once by warp 0. + +#pragma unroll + for (int m_block = 0; m_block < thread_m_blocks; m_block++) { +#pragma unroll + for (int i = red_off; i > 0; i /= 2) { + if (i <= red_idx && red_idx < 2 * i) { +#pragma unroll + for (int j = 0; j < 4 * 2; j += (m_block_size_8 ? 2 : 1)) { + int red_sh_wr = red_sh_delta * j + (red_sh_rd - red_sh_stride * i); + if (i < red_off) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * j + red_sh_rd]); + float* c_wr = reinterpret_cast(&sh_red[red_sh_wr]); +#pragma unroll + for (int k = 0; k < 4; k++) reinterpret_cast(frag_c)[4 * 2 * m_block + j][k] += c_rd[k] + c_wr[k]; + } + sh_red[red_sh_wr] = reinterpret_cast(&frag_c)[4 * 2 * m_block + j]; + } + } + __syncthreads(); + } + if (red_idx == 0) { +#pragma unroll + for (int i = 0; i < 4 * 2; i += (m_block_size_8 ? 
2 : 1)) { + float* c_rd = reinterpret_cast(&sh_red[red_sh_delta * i + red_sh_rd]); +#pragma unroll + for (int j = 0; j < 4; j++) reinterpret_cast(frag_c)[4 * 2 * m_block + i][j] += c_rd[j]; + } + } + __syncthreads(); + } + } + }; + + // Since multiple threadblocks may process parts of the same column slice, we + // finally have to globally reduce over the results. As the striped + // partitioning minimizes the number of such reductions and our outputs are + // usually rather small, we perform this reduction serially in L2 cache. + auto global_reduce_fp16 = [&](bool first = false, bool last = false) { + // We are very careful here to reduce directly in the output buffer to + // maximize L2 cache utilization in this step. To do this, we write out + // results in FP16 (but still reduce with FP32 compute). + constexpr int active_threads = 32 * thread_n_blocks / 4; + if (threadIdx.x < active_threads) { + int c_gl_stride = prob_n / 8; + int c_gl_wr_delta_o = 8 * c_gl_stride; + int c_gl_wr_delta_i = 4 * (active_threads / 32); + int c_gl_wr; + if constexpr (m_block_size_8) { + c_gl_wr = c_gl_stride * ((threadIdx.x % 4) * 2) + 4 * (threadIdx.x / 32) + (threadIdx.x % 32) / 8; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } else { + c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) + 4 * (threadIdx.x / 32) + threadIdx.x % 4; + c_gl_wr += (2 * thread_n_blocks) * slice_col; + } + constexpr int c_sh_wr_delta = active_threads; + auto c_sh_wr = threadIdx.x; + + int row = (threadIdx.x % 32) / 4; + + if (!first) { +// Interestingly, doing direct global accesses here really seems to mess up +// the compiler and lead to slowdowns, hence we also use async-copies even +// though these fetches are not actually asynchronous. +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 
2 : thread_m_blocks * 4); i++) { + if constexpr (m_block_size_8) { + cp_async4_pred(&sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i], + (threadIdx.x % 4) * 2 + i < prob_m); + } else { + cp_async4_pred(&sh_red[c_sh_wr + c_sh_wr_delta * i], + &C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)], + i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m); + } + } + cp_async_fence(); + cp_async_wait<0>(); + } + +#pragma unroll + for (int i = 0; i < (m_block_size_8 ? 2 : thread_m_blocks * 4); i++) { + bool mask = (!m_block_size_8) && (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) + || (m_block_size_8) && ((threadIdx.x % 4) * 2 + i < prob_m); + if (mask) { + if (!first) { + int4 c_red = sh_red[c_sh_wr + i * c_sh_wr_delta]; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } + reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta] += + Dtype::num2float(reinterpret_cast(&c_red)[j]); + } + } + if (!last) { + int4 c; +#pragma unroll + for (int j = 0; j < 2 * 4; j++) { + int delta = 0; + if constexpr (m_block_size_8) { delta = j % 2 == 1 ? -2 : 0; } + reinterpret_cast(&c)[j] = + Dtype::float2num(reinterpret_cast(&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4) + delta]); + } + if constexpr (m_block_size_8) + C[c_gl_wr + i * c_gl_stride + (threadIdx.x % 8) / 4 * c_gl_wr_delta_i] = c; + else + C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] = c; + } + } + } + } + }; + + // Globally reduce over threadblocks that compute the same column block. + // We use a tmp C buffer to reduce in full fp32 precision. 
+ auto global_reduce_fp32 = [&](bool first = false, bool last = false) { + constexpr int tb_m = thread_m_blocks * 16; + constexpr int tb_n = thread_n_blocks * 16; + + constexpr int c_size = tb_m * tb_n * sizeof(float) / 16; + + constexpr int active_threads = 32 * thread_n_blocks / 4; + bool is_th_active = threadIdx.x < active_threads; + + constexpr int num_floats = thread_m_blocks * 4 * 2 * 4; + constexpr int th_size = num_floats * sizeof(float) / 16; + + int c_cur_offset = locks_off * c_size; + + if (!is_th_active) { return; } + + if (!first) { + float* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + sh_red[threadIdx.x] = C_tmp[c_cur_offset + active_threads * k + threadIdx.x]; + + float* sh_c_ptr = reinterpret_cast(&sh_red[threadIdx.x]); +#pragma unroll + for (int f = 0; f < 4; f++) { frag_c_ptr[k * 4 + f] += sh_c_ptr[f]; } + } + } + + if (!last) { + int4* frag_c_ptr = reinterpret_cast(&frag_c); +#pragma unroll + for (int k = 0; k < th_size; k += (m_block_size_8 ? 2 : 1)) { + C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k]; + } + } + }; + + // Write out the reduce final result in the correct layout. We only actually + // reshuffle matrix fragments in this step, the reduction above is performed + // in fragment layout. 
+ auto write_result = [&]() { + int c_gl_stride = prob_n / 8; + constexpr int c_sh_stride = 2 * thread_n_blocks + 1; + int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks)); + constexpr int c_sh_rd_delta = c_sh_stride * (threads / (2 * thread_n_blocks)); + + int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + c_gl_wr += (2 * thread_n_blocks) * slice_col; + int c_sh_wr; + if constexpr (m_block_size_8) { + c_sh_wr = (8 * c_sh_stride) * ((threadIdx.x % 32) % 4 * 2) + (threadIdx.x % 32) / 4; + c_sh_wr += 64 * (threadIdx.x / 32); + } else { + c_sh_wr = (4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4; + c_sh_wr += 32 * (threadIdx.x / 32); + } + + int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) + (threadIdx.x % (2 * thread_n_blocks)); + + int c_gl_wr_end = c_gl_stride * prob_m; + // We first reorder in shared memory to guarantee the most efficient final + // global write patterns + auto write = [&](int idx, float c0, float c1, FragS& s) { + scalar_t2 res = Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1)); + + // For per-column quantization we finally apply the scale here (only for + // 4-bit) + if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 4 + && (has_zp && dequant_skip_flop || !has_zp)) { + res = __hmul2(res, s[0]); + } + + if constexpr (w_type == host::kFE2M1f) { res = __hmul2(res, global_scale); } + + if constexpr (m_block_size_8) { + ((scalar_t*)sh_red)[idx] = res.x; + ((scalar_t*)sh_red)[idx + 8 * c_sh_stride] = res.y; + } else { + ((scalar_t2*)sh_red)[idx] = res; + } + }; + + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + if constexpr (m_block_size_8) { + int wr = c_sh_wr + 16 * j; + write(wr, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + 8, frag_c[i][j][0][2], 
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 1]); + } else { + int wr = c_sh_wr + 8 * j; + write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0], frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2], frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]); + write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0], frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]); + write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2], frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]); + } + } + c_sh_wr += 16 * (4 * c_sh_stride); + } + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks)); i++) { + if (c_gl_wr < c_gl_wr_end) { + if (use_atomic_add && slice_count > 1) { + scalar_t2* C_half2 = reinterpret_cast(&C[c_gl_wr]); + scalar_t2* sh_red_half2 = reinterpret_cast(&sh_red[c_sh_rd]); +#pragma unroll + for (int a = 0; a < 4; a++) { atomicAdd(&C_half2[a], sh_red_half2[a]); } + } else { + C[c_gl_wr] = sh_red[c_sh_rd]; + } + c_gl_wr += c_gl_wr_delta; + c_sh_rd += c_sh_rd_delta; + } + } + __syncthreads(); + }; + + // Start global fetch and register load pipelines. 
+ auto start_pipes = [&]() { + +#pragma unroll + for (int i = 0; i < stages - 1; i++) { + if (has_act_order && i == 0) { + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } + fetch_act_order_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]); + } + + if constexpr (has_zp && !is_zp_float && group_blocks == -1) { + if (i == 0) { + fetch_col_zp_to_shared(); + if constexpr (!dequant_skip_flop) { fetch_col_scale_to_shared(); } + } + } + fetch_to_shared(i, i, i < slice_iters); + } + + zero_accums(); + wait_for_stage(); + init_same_group(0); + fetch_to_registers(0, 0); + fetch_scales_to_registers(0, 0); + fetch_zp_to_registers(0, 0); + a_gl_rd += a_gl_rd_delta_o * (stages - 1); + if constexpr (has_act_order) { slice_k_start_shared_fetch += tb_k * (stages - 1); } + }; + if (slice_iters) { start_pipes(); } + + // Main loop. + while (slice_iters) { + // We unroll over both the global fetch and the register load pipeline to + // ensure all shared memory accesses are static. Note that both pipelines + // have even length meaning that the next iteration will always start at + // index 0. 
+ +#pragma unroll + for (int pipe = 0; pipe < stages;) { +#pragma unroll + for (int k = 0; k < b_sh_wr_iters; k++) { + fetch_to_registers(k + 1, pipe % stages); + fetch_scales_to_registers(k + 1, pipe); + fetch_zp_to_registers(k + 1, pipe); + if (k == b_sh_wr_iters - 2) { + fetch_to_shared((pipe + stages - 1) % stages, pipe, slice_iters >= stages); + pipe++; + wait_for_stage(); + init_same_group(pipe % stages); + } + matmul(k); + } + slice_iters--; + if (slice_iters == 0) { break; } + } + + a_gl_rd += a_gl_rd_delta_o * stages; + + if constexpr (has_act_order) { + slice_k_start += tb_k * stages; + + if (slice_k_start < prob_k) { + slice_k_start_shared_fetch += tb_k * stages; + int first_group_id = g_idx[slice_k_start]; + int last_g_idx = slice_k_start + stages * tb_k * 2; + if (last_g_idx >= prob_k) { last_g_idx = prob_k - 1; } + int last_group_id = g_idx[last_g_idx]; + if (last_group_id >= sh_first_group_id + sh_num_groups) { + fetch_act_order_scales_to_shared(false, first_group_id, last_group_id); + __syncthreads(); + } + } + } + + // Process results and, if necessary, proceed to the next column slice. + // While this pattern may not be the most readable, other ways of writing + // the loop seemed to noticeably worse performance after compilation. 
+ if (slice_iters == 0) { + cp_async_wait<0>(); + bool last = slice_idx == slice_count - 1; + // For per-column scales, we only fetch them here in the final step before + // write-out + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + if (s_sh_wr_pred) { cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]); } + cp_async_fence(); + } + } + + thread_block_reduce(); + if constexpr (!has_act_order && group_blocks == -1 && (has_zp && dequant_skip_flop || !has_zp)) { + if (w_type.size_bits() == 8 || (last || use_atomic_add)) { + cp_async_wait<0>(); + __syncthreads(); + if (threadIdx.x / 32 < thread_n_blocks / 4) { + reinterpret_cast(&frag_s)[0] = sh_s[s_sh_rd + 0]; + reinterpret_cast(&frag_s)[1] = sh_s[s_sh_rd + 4]; + if constexpr (m_block_size_8) { + int idx = (threadIdx.x / 4) % 2; + scalar_t2* frag_s_half2 = reinterpret_cast(frag_s); +#pragma unroll + for (int i = 0; i < 8; i++) { + frag_s_half2[i] = Dtype::num2num2(reinterpret_cast(&frag_s_half2[i])[idx]); + } + } + } + } + } + + // For 8-bit channelwise, we apply the scale before the global reduction + // that converts the fp32 results to fp16 (so that we avoid possible + // overflow in fp16) + if constexpr (!has_act_order && group_blocks == -1 && w_type.size_bits() == 8 + && (has_zp && dequant_skip_flop || !has_zp)) { + if (threadIdx.x / 32 < thread_n_blocks / 4) { +#pragma unroll + for (int i = 0; i < thread_m_blocks; i++) { +#pragma unroll + for (int j = 0; j < 4; j++) { + scale_float(reinterpret_cast(&frag_c[i][j][0][0]), frag_s[j / 2][2 * (j % 2) + 0]); + scale_float(reinterpret_cast(&frag_c[i][j][0][2]), + frag_s[j / 2][2 * (j % 2) + (m_block_size_8 ? 
1 : 0)]); + + if constexpr (!m_block_size_8) { + scale_float(reinterpret_cast(&frag_c[i][j][1][0]), frag_s[j / 2][2 * (j % 2) + 1]); + scale_float(reinterpret_cast(&frag_c[i][j][1][2]), frag_s[j / 2][2 * (j % 2) + 1]); + } + } + } + } + } + + if (slice_count > 1 && !use_atomic_add) { + // only globally reduce if there is more than one block in a slice + barrier_acquire(&locks[locks_off], slice_idx); + if (use_fp32_reduce) { + global_reduce_fp32(slice_idx == 0, last); + } else { + global_reduce_fp16(slice_idx == 0, last); + } + barrier_release(&locks[locks_off], last); + } + if (use_atomic_add && slice_count > 1 && slice_idx != 0) wait_negative_and_add(&locks[locks_off]); + if (last || use_atomic_add) + // only the last block in a slice actually writes the result + write_result(); + slice_row = 0; + slice_col_par++; + slice_col++; + is_first_matmul_in_slice = true; + init_slice(); + + if (slice_iters) { + a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) + (threadIdx.x % a_gl_rd_delta_o); +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles; + if (slice_col == 0) { +#pragma unroll + for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride; + } + + // Update slice k/n for scales loading + if constexpr (has_act_order) { + slice_k_start = tb_k * slice_row; + slice_k_finish = slice_k_start + tb_k * slice_iters; + slice_k_start_shared_fetch = slice_k_start; + slice_n_offset = act_s_col_tb_stride * slice_col; + + } else { + s_gl_rd = s_sh_stride * slice_col + threadIdx.x; + zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x; + } + + start_pipes(); + } + } + } +} + +} // namespace device::marlin + +#endif diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh new file mode 100644 index 000000000..b61246029 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/rms_norm_gated.cuh @@ -0,0 +1,212 @@ +// Fused RMSNorm with optional SiLU gating 
for Qwen3.5 GDN attention. +// +// Computes: output = rmsnorm(x, weight, eps) * silu(z) (if z provided) +// output = rmsnorm(x, weight, eps) (if z is null) +// +// Where: rmsnorm(x) = x / sqrt(mean(x^2) + eps) * weight +// silu(z) = z * sigmoid(z) +// +// This kernel fuses both operations into a single pass over the data, +// maximizing memory bandwidth utilization. Each block processes one row +// (one token position). +// +// Supported dtypes: float16, bfloat16 (accumulation in float32). + +#pragma once + +#include +#include +#include + +#include +#include + +#include +#include +#include + +namespace RMSNormGatedKernel { + +// --------------------------------------------------------------------------- +// Warp-level reduction +// --------------------------------------------------------------------------- + +__device__ __forceinline__ float warp_reduce_sum(float val) { + #pragma unroll + for (int offset = 16; offset > 0; offset >>= 1) { + val += __shfl_xor_sync(0xffffffff, val, offset); + } + return val; +} + +// --------------------------------------------------------------------------- +// Type conversion helpers +// --------------------------------------------------------------------------- + +template +__device__ __forceinline__ float to_float(T val); + +template <> +__device__ __forceinline__ float to_float(half val) { + return __half2float(val); +} + +template <> +__device__ __forceinline__ float to_float<__nv_bfloat16>(__nv_bfloat16 val) { + return __bfloat162float(val); +} + +template <> +__device__ __forceinline__ float to_float(float val) { + return val; +} + +template +__device__ __forceinline__ T from_float(float val); + +template <> +__device__ __forceinline__ half from_float(float val) { + return __float2half(val); +} + +template <> +__device__ __forceinline__ __nv_bfloat16 from_float<__nv_bfloat16>(float val) { + return __float2bfloat16(val); +} + +template <> +__device__ __forceinline__ float from_float(float val) { + return val; +} + +// 
--------------------------------------------------------------------------- +// Main kernel +// --------------------------------------------------------------------------- + +template +__global__ void rms_norm_gated_kernel( + T* __restrict__ output, // [M, N] + const T* __restrict__ input, // [M, N] + const T* __restrict__ weight, // [N] + const T* __restrict__ gate, // [M, N] or nullptr + const int M, // number of rows + const int N, // number of columns (hidden_size) + const float eps +) { + const int row = blockIdx.x; + if (row >= M) return; + + const int tid = threadIdx.x; + const T* x_row = input + row * N; + T* out_row = output + row * N; + const T* z_row = (gate != nullptr) ? gate + row * N : nullptr; + + // --- Pass 1: compute sum of squares --- + float sum_sq = 0.0f; + for (int col = tid; col < N; col += BLOCK_SIZE) { + float val = to_float(x_row[col]); + sum_sq += val * val; + } + + // Block-level reduction + __shared__ float shared_sum[32]; // one per warp + int warp_id = tid / 32; + int lane_id = tid % 32; + + sum_sq = warp_reduce_sum(sum_sq); + if (lane_id == 0) { + shared_sum[warp_id] = sum_sq; + } + __syncthreads(); + + // Final reduction in first warp + if (warp_id == 0) { + float val = (lane_id < (BLOCK_SIZE / 32)) ? 
shared_sum[lane_id] : 0.0f; + val = warp_reduce_sum(val); + if (lane_id == 0) { + shared_sum[0] = val; + } + } + __syncthreads(); + + float rms = rsqrtf(shared_sum[0] / (float)N + eps); + + // --- Pass 2: normalize, scale by weight, optionally gate with silu(z) --- + for (int col = tid; col < N; col += BLOCK_SIZE) { + float val = to_float(x_row[col]); + float w = to_float(weight[col]); + + float normed = val * rms * w; + + if (z_row != nullptr) { + float z = to_float(z_row[col]); + // silu(z) = z * sigmoid(z) + float silu_z = z / (1.0f + expf(-z)); + normed *= silu_z; + } + + out_row[col] = from_float(normed); + } +} + +// --------------------------------------------------------------------------- +// Launch wrapper (called via TVM FFI) +// --------------------------------------------------------------------------- + +void run( + tvm::ffi::TensorView output, + tvm::ffi::TensorView input, + tvm::ffi::TensorView weight, + tvm::ffi::TensorView gate, // empty tensor (numel==0) means no gate + double eps +) { + using namespace mllm_kernel::host; + + auto M = SymbolicSize{"M"}; + auto N = SymbolicSize{"N"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + device.set_options(); + dtype.set_options(); + + (void)TensorMatcher({M, N}).with_dtype(dtype).with_device(device).verify(input); + (void)TensorMatcher({M, N}).with_dtype(dtype).with_device(device).verify(output); + (void)TensorMatcher({N}).with_dtype(dtype).with_device(device).verify(weight); + + const int rows = static_cast(M.unwrap()); + const int cols = static_cast(N.unwrap()); + const bool has_gate = (gate.numel() > 0); + + constexpr int BLOCK_SIZE = 256; + + if (dtype.is_type()) { + LaunchKernel(rows, BLOCK_SIZE, device.unwrap())( + rms_norm_gated_kernel, + static_cast(output.data_ptr()), + static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), + has_gate ? 
static_cast(gate.data_ptr()) : nullptr, + rows, cols, static_cast(eps) + ); + } else if (dtype.is_type()) { + LaunchKernel(rows, BLOCK_SIZE, device.unwrap())( + rms_norm_gated_kernel<__nv_bfloat16, BLOCK_SIZE>, + static_cast<__nv_bfloat16*>(output.data_ptr()), + static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), + has_gate ? static_cast(gate.data_ptr()) : nullptr, + rows, cols, static_cast(eps) + ); + } else { + LaunchKernel(rows, BLOCK_SIZE, device.unwrap())( + rms_norm_gated_kernel, + static_cast(output.data_ptr()), + static_cast(input.data_ptr()), + static_cast(weight.data_ptr()), + has_gate ? static_cast(gate.data_ptr()) : nullptr, + rows, cols, static_cast(eps) + ); + } +} + +} // namespace RMSNormGatedKernel diff --git a/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh b/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh new file mode 100644 index 000000000..05daabee0 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/csrc/store_cache.cuh @@ -0,0 +1,202 @@ +// Copyright SGLang Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// +// Store KV cache kernel: efficiently scatter key/value tensors into a +// pre-allocated KV cache pool using warp-level vectorized copies. 
+// +// Reference: sglang jit_kernel/csrc/elementwise/kvcache.cuh + +#pragma once + +#include +#include +#include + +#include +#include + +#include + +namespace { + +// ─────────────────────────────────────────────────────────────── +// Parameter block passed to the kernel via __grid_constant__ +// ─────────────────────────────────────────────────────────────── + +struct StoreKVCacheParams { + const void* __restrict__ k; + const void* __restrict__ v; + void* __restrict__ k_cache; + void* __restrict__ v_cache; + const void* __restrict__ indices; + int64_t stride_k_bytes; + int64_t stride_v_bytes; + int64_t stride_cache_bytes; + int64_t stride_indices; + uint32_t batch_size; +}; + +constexpr uint32_t kNumWarps = 4; +constexpr uint32_t kThreadsPerBlock = kNumWarps * device::kWarpThreads; + +// ─────────────────────────────────────────────────────────────── +// Vectorized warp-level KV copy +// ─────────────────────────────────────────────────────────────── +// +// Each warp copies kElementBytes of K data and kElementBytes of V +// data using the widest possible aligned vector type (uint4 = 16B, +// uint2 = 8B, or uint32_t = 4B). + +namespace detail { + +template +__device__ __forceinline__ void warp_copy_bytes(const void* __restrict__ src, void* __restrict__ dst, int64_t num_vecs) { + const int lane = threadIdx.x % device::kWarpThreads; + const auto* s = static_cast(src); + auto* d = static_cast(dst); + for (int64_t i = lane; i < num_vecs; i += device::kWarpThreads) { d[i] = s[i]; } +} + +} // namespace detail + +template +__device__ __forceinline__ void copy_kv_warp(const void* __restrict__ k_src, const void* __restrict__ v_src, + void* __restrict__ k_dst, void* __restrict__ v_dst) { + static_assert(kElementBytes > 0 && kElementBytes % 4 == 0, "Element size must be a positive multiple of 4 bytes"); + + // Pick the widest aligned vector type the element size supports. 
+ if constexpr (kElementBytes % 16 == 0) { + constexpr int64_t N = kElementBytes / 16; + detail::warp_copy_bytes(k_src, k_dst, N); + detail::warp_copy_bytes(v_src, v_dst, N); + } else if constexpr (kElementBytes % 8 == 0) { + constexpr int64_t N = kElementBytes / 8; + detail::warp_copy_bytes(k_src, k_dst, N); + detail::warp_copy_bytes(v_src, v_dst, N); + } else { + constexpr int64_t N = kElementBytes / 4; + detail::warp_copy_bytes(k_src, k_dst, N); + detail::warp_copy_bytes(v_src, v_dst, N); + } +} + +// ─────────────────────────────────────────────────────────────── +// Main kernel +// ─────────────────────────────────────────────────────────────── +// +// Template parameters: +// kElementBytes total bytes per token row (head_num * head_dim * dtype_size) +// kSplit how many warps collaborate on one element (1, 2, or 4) +// kUsePDL whether to emit PDL synchronisation instructions +// T index dtype (int32_t or int64_t) + +template +__global__ void store_kvcache(const __grid_constant__ StoreKVCacheParams params) { + using namespace device; + constexpr auto kSplitSize = kElementBytes / kSplit; + + const uint32_t warp_id = blockIdx.x * kNumWarps + threadIdx.x / kWarpThreads; + const uint32_t item_id = warp_id / kSplit; + const uint32_t split_id = warp_id % kSplit; + + const auto& [k_input, v_input, k_cache, v_cache, indices, stride_k, stride_v, stride_cache, stride_indices, batch_size] = + params; + + if (item_id >= batch_size) return; + + const auto index_ptr = static_cast(indices) + item_id * stride_indices; + PDLWaitPrimary(); + + const auto index = *index_ptr; + const auto k_src = pointer::offset(k_input, item_id * stride_k, split_id * kSplitSize); + const auto v_src = pointer::offset(v_input, item_id * stride_v, split_id * kSplitSize); + const auto k_dst = pointer::offset(k_cache, index * stride_cache, split_id * kSplitSize); + const auto v_dst = pointer::offset(v_cache, index * stride_cache, split_id * kSplitSize); + + copy_kv_warp(k_src, v_src, k_dst, v_dst); + 
PDLTriggerSecondary(); +} + +template +struct StoreKVCacheKernel { + static_assert(kElementBytes > 0 && kElementBytes % 4 == 0); + + template + static constexpr auto store_kernel = store_kvcache; + + template + static auto get_kernel(int num_split) { + using namespace mllm_kernel::host; + if constexpr (kElementBytes % (4 * 128) == 0) { + if (num_split == 4) return store_kernel<4, T>; + } + if constexpr (kElementBytes % (2 * 128) == 0) { + if (num_split == 2) return store_kernel<2, T>; + } + if (num_split == 1) return store_kernel<1, T>; + Panic("Unsupported num_split ", num_split, " for element size ", kElementBytes); + } + + static void run(tvm::ffi::TensorView k, tvm::ffi::TensorView v, tvm::ffi::TensorView k_cache, tvm::ffi::TensorView v_cache, + tvm::ffi::TensorView indices, int num_split) { + using namespace mllm_kernel::host; + + auto B = SymbolicSize{"batch_size"}; + auto D = SymbolicSize{"element_size"}; + auto KS = SymbolicSize{"k_stride"}; + auto VS = SymbolicSize{"v_stride"}; + auto S = SymbolicSize{"cache_stride"}; + auto I = SymbolicSize{"indices_stride"}; + auto dtype = SymbolicDType{}; + auto device = SymbolicDevice{}; + auto indice_dtype = SymbolicDType{}; + device.set_options(); + + // k, v: [B, D] with strides [KS, 1] + (void)TensorMatcher({B, D}).with_strides({KS, 1}).with_dtype(dtype).with_device(device).verify(k); + (void)TensorMatcher({B, D}).with_strides({VS, 1}).with_dtype(dtype).with_device(device).verify(v); + + // k_cache, v_cache: [*, D] with strides [S, 1] + (void)TensorMatcher({-1, D}).with_strides({S, 1}).with_dtype(dtype).with_device(device).verify(k_cache).verify(v_cache); + + // indices: [B] with strides [I] + (void)TensorMatcher({B}).with_strides({I}).with_dtype(indice_dtype).with_device(device).verify(indices); + + const int64_t dtype_size = dtype_bytes(dtype.unwrap()); + const uint32_t num_elements = static_cast(B.unwrap()); + RuntimeCheck(kElementBytes == dtype_size * D.unwrap(), "Element size mismatch: expected ", 
kElementBytes, " but got ", + dtype_size * D.unwrap()); + + const auto params = StoreKVCacheParams{ + .k = k.data_ptr(), + .v = v.data_ptr(), + .k_cache = k_cache.data_ptr(), + .v_cache = v_cache.data_ptr(), + .indices = indices.data_ptr(), + .stride_k_bytes = KS.unwrap() * dtype_size, + .stride_v_bytes = VS.unwrap() * dtype_size, + .stride_cache_bytes = S.unwrap() * dtype_size, + .stride_indices = I.unwrap(), + .batch_size = num_elements, + }; + + const auto use_int32 = indice_dtype.is_type(); + const auto kernel = use_int32 ? get_kernel(num_split) : get_kernel(num_split); + const auto num_blocks = div_ceil(num_elements * num_split, kNumWarps); + + LaunchKernel(num_blocks, kThreadsPerBlock, device.unwrap()).enable_pdl(kUsePDL)(kernel, params); + } +}; + +} // namespace diff --git a/mllm-kernel/requirements.txt b/mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh similarity index 100% rename from mllm-kernel/requirements.txt rename to mllm-kernel/mllm_kernel/cuda/csrc/vocab_embedding.cuh diff --git a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py index 696e73ea0..1fe41f560 100644 --- a/mllm-kernel/mllm_kernel/cuda/jit/__init__.py +++ b/mllm-kernel/mllm_kernel/cuda/jit/__init__.py @@ -1,3 +1,14 @@ from .add_constant import add_constant +from .awq_marlin_repack import awq_marlin_repack +from .gdn_decode import gdn_decode +from .gptq_marlin import gptq_marlin_gemm +from .store_cache import can_use_store_cache, store_cache -__all__ = ["add_constant"] +__all__ = [ + "add_constant", + "awq_marlin_repack", + "can_use_store_cache", + "gdn_decode", + "gptq_marlin_gemm", + "store_cache", +] diff --git a/mllm-kernel/mllm_kernel/cuda/jit/awq_marlin_repack.py b/mllm-kernel/mllm_kernel/cuda/jit/awq_marlin_repack.py new file mode 100644 index 000000000..f13f50475 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/awq_marlin_repack.py @@ -0,0 +1,78 @@ +"""AWQ Marlin weight repack CUDA JIT kernel. 
+ +Repacks AWQ-format quantized weights into Marlin kernel layout. + +Usage:: + + from mllm_kernel.cuda.jit.awq_marlin_repack import awq_marlin_repack + + out = awq_marlin_repack(b_q_weight, size_k, size_n, num_bits) +""" + +from __future__ import annotations + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +@cache_once +def _make_awq_marlin_repack_kernel(): + """JIT-compile the AWQ Marlin repack CUDA kernel.""" + + @jit( + args=[], + device="cuda", + cuda_files=["gemm/marlin/awq_marlin_repack.cuh"], + cuda_wrappers=[("awq_marlin_repack", "awq_marlin_repack")], + func_name="awq_marlin_repack", + ) + def _kernel( + compiled_module, + out: torch.Tensor, + b_q_weight: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, + ) -> None: + compiled_module.awq_marlin_repack(out, b_q_weight, size_k, size_n, num_bits) + + return _kernel + + +def awq_marlin_repack( + b_q_weight: torch.Tensor, + size_k: int, + size_n: int, + num_bits: int, +) -> torch.Tensor: + """Repack AWQ-format quantized weights into Marlin kernel layout. + + Parameters + ---------- + b_q_weight : torch.Tensor + AWQ packed weight tensor, shape ``(size_k, size_n // pack_factor)``, + dtype ``int32``. + size_k : int + Number of input features (must be divisible by 16). + size_n : int + Number of output features (must be divisible by 64). + num_bits : int + Weight quantization bit-width (4 or 8). + + Returns + ------- + torch.Tensor + Repacked weight tensor in Marlin layout, shape + ``(size_k // 16, size_n * 16 // pack_factor)``, dtype ``int32``. 
+ """ + tile_size = 16 + pack_factor = 32 // num_bits + out = torch.empty( + (size_k // tile_size, size_n * tile_size // pack_factor), + dtype=b_q_weight.dtype, + device=b_q_weight.device, + ) + kernel = _make_awq_marlin_repack_kernel() + kernel(out, b_q_weight, size_k, size_n, num_bits) + return out diff --git a/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py b/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py new file mode 100644 index 000000000..565686a40 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/create_kv_indices.py @@ -0,0 +1,118 @@ +"""High-performance CUDA JIT wrapper for create_kv_indices. + +This module exposes a single function: + + create_kv_indices(req_to_token, req_pool_indices, + page_kernel_lens, kv_indptr, + kv_start_idx, kv_indices) + +which is a Python binding around the C++/CUDA kernel defined in +`mllm_kernel/cuda/csrc/create_kv_indices.cuh`. + +The kernel transforms pymllm's 2-D ReqToTokenPool mapping table into the flat +`(kv_indptr, kv_indices)` layout expected by FlashInfer's paged KV attention +wrappers. It is carefully written for maximum throughput and is intended to +replace the Triton implementation `_create_kv_indices_triton` in +`pymllm.layers.attention.flashinfer_backend`. +""" + +from __future__ import annotations + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +@cache_once +def _make_create_kv_indices_kernel(): + """JIT-compile the CUDA kernel and return a callable wrapper. + + The JIT system will: + * locate `create_kv_indices.cuh` under the mllm-kernel CUDA csrc tree, + * compile it into a TVM FFI module, + * expose `CreateKvIndicesKernel::run` as `compiled_module.create_kv_indices`. 
+ """ + + @jit( + args=[], + device="cuda", + cuda_files=["create_kv_indices.cuh"], + cpp_wrappers=[], + cuda_wrappers=[ + ("create_kv_indices", "CreateKvIndicesKernel::run"), + ], + func_name="create_kv_indices", + ) + def _kernel( + compiled_module, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + page_kernel_lens: torch.Tensor, + kv_indptr: torch.Tensor, + kv_start_idx: torch.Tensor, + kv_indices: torch.Tensor, + ) -> None: + compiled_module.create_kv_indices( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + return _kernel + + +def create_kv_indices( + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + page_kernel_lens: torch.Tensor, + kv_indptr: torch.Tensor, + kv_start_idx: torch.Tensor | None, + kv_indices: torch.Tensor, +) -> None: + """Fill a flat KV-index buffer from the ReqToTokenPool mapping. + + This is a thin Python wrapper that forwards to the JIT-compiled CUDA + kernel. All tensors must be placed on the same CUDA device. + + Args + ---- + req_to_token: + Mapping tensor from ReqToTokenPool, shape + ``[max_reqs, max_context_len]``, dtype ``torch.int32``. + req_pool_indices: + Request slots participating in this batch, shape ``[batch_size]``, + dtype ``torch.int32``. + page_kernel_lens: + Per-sequence token counts (how many tokens to attend), shape + ``[batch_size]``, dtype ``torch.int32``. + kv_indptr: + Prefix sums over per-sequence token counts, shape ``[batch_size + 1]``, + dtype ``torch.int32``. ``kv_indptr[i]`` is the starting offset in + ``kv_indices`` for sequence ``i``. + kv_start_idx: + Optional starting positions inside each sequence, shape + ``[batch_size]`` or ``[0]``, dtype ``torch.int32``. When + ``None``, the kernel assumes 0 for all sequences. + kv_indices: + Output flat KV-index buffer, shape ``[N]``, dtype ``torch.int32``. + ``N`` must be at least ``kv_indptr[batch_size]``. 
+ """ + if kv_start_idx is None: + # Use an empty tensor to signal "no start offsets". The C++ launcher + # treats length==0 as "no kv_start" and will pass a nullptr into the + # parameter block, which is slightly cheaper than materialising a + # full zero tensor on every call. + kv_start_idx = req_pool_indices.new_empty(0, dtype=torch.int32) + + kernel = _make_create_kv_indices_kernel() + kernel( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) diff --git a/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py b/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py new file mode 100644 index 000000000..53aaeaab3 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/gdn_decode.py @@ -0,0 +1,114 @@ +"""Fused GDN decode CUDA JIT kernel. + +Performs a single-token GDN (Gated Delta Net) recurrent update per request, +fusing gating + L2 normalization + delta rule + output computation into +one kernel. Works on SM80+ (Ampere, Jetson Orin, Hopper, ...). + +Usage:: + + from mllm_kernel.cuda.jit.gdn_decode import gdn_decode + + output = gdn_decode(q, k, v, a, b, A_log, dt_bias, state_pool, cache_indices) +""" + +from __future__ import annotations + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +@cache_once +def _make_gdn_decode_kernel(): + """JIT-compile the fused GDN decode CUDA kernel.""" + + @jit( + args=[], + device="cuda", + cuda_files=["gdn_decode.cuh"], + cpp_wrappers=[], + cuda_wrappers=[ + ("gdn_decode", "GDNDecodeKernel::run"), + ], + func_name="gdn_decode", + ) + def _kernel( + compiled_module, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + state_pool: torch.Tensor, + cache_indices: torch.Tensor, + output: torch.Tensor, + ) -> None: + compiled_module.gdn_decode( + q, k, v, a, b, A_log, dt_bias, state_pool, cache_indices, output + ) + + return _kernel + + +def gdn_decode( + q: torch.Tensor, + k: 
torch.Tensor, + v: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, + state_pool: torch.Tensor, + cache_indices: torch.Tensor, +) -> torch.Tensor: + """Fused GDN decode: gating + L2 norm + delta rule + output. + + Parameters + ---------- + q : torch.Tensor + Query tensor, shape ``(bs, num_k_heads, head_k_dim)``, bf16/fp16. + k : torch.Tensor + Key tensor, shape ``(bs, num_k_heads, head_k_dim)``, bf16/fp16. + v : torch.Tensor + Value tensor, shape ``(bs, num_v_heads, head_v_dim)``, bf16/fp16. + a : torch.Tensor + Decay gate input, shape ``(bs, num_v_heads)``, bf16/fp16. + b : torch.Tensor + Update gate input, shape ``(bs, num_v_heads)``, bf16/fp16. + A_log : torch.Tensor + Log-space decay parameter, shape ``(num_v_heads,)``, float32. + dt_bias : torch.Tensor + Bias for decay gate, shape ``(num_v_heads,)``, float32. + state_pool : torch.Tensor + Pooled recurrent state, shape ``(pool_size, num_v_heads, head_v_dim, head_k_dim)``, + float32. Modified in-place. + cache_indices : torch.Tensor + Pool indices per request, shape ``(bs,)``, int64. + + Returns + ------- + torch.Tensor + Output tensor, shape ``(bs, num_v_heads, head_v_dim)``, same dtype as v. + """ + bs = q.shape[0] + num_v_heads = v.shape[1] + head_v_dim = v.shape[2] + + output = torch.empty(bs, num_v_heads, head_v_dim, dtype=v.dtype, device=v.device) + + kernel = _make_gdn_decode_kernel() + kernel( + q.contiguous(), + k.contiguous(), + v.contiguous(), + a.contiguous(), + b.contiguous(), + A_log.contiguous(), + dt_bias.contiguous(), + state_pool, + cache_indices.to(torch.int64).contiguous(), + output, + ) + return output diff --git a/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py new file mode 100644 index 000000000..9eeefa765 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/gptq_marlin.py @@ -0,0 +1,213 @@ +"""GPTQ Marlin GEMM CUDA JIT kernel. 
+ +Performs quantized matrix multiplication using the Marlin kernel for +GPTQ/AWQ-style W4A16 or W8A16 quantized weights. + +Usage:: + + from mllm_kernel.cuda.jit.gptq_marlin import gptq_marlin_gemm + + output = gptq_marlin_gemm( + a, c, b_q_weight, b_scales, global_scale, b_zeros, + g_idx, perm, workspace, b_q_type_id, + size_m, size_n, size_k, + ) +""" + +from __future__ import annotations + +from typing import Optional + +import torch + +from mllm_kernel.jit_utils import cache_once, jit, make_cpp_args + +# Constants matching device::marlin:: in marlin.cuh +_MAX_THREAD_N = 256 + + +@cache_once +def _make_gptq_marlin_gemm_kernel(dtype: torch.dtype): + """JIT-compile the GPTQ Marlin GEMM kernel for a specific dtype.""" + args = make_cpp_args(dtype) + + @jit( + args=args, + device="cuda", + cuda_files=["gemm/marlin/gptq_marlin.cuh"], + cuda_wrappers=[("gptq_marlin_gemm", f"gptq_marlin_gemm<{args}>")], + func_name="gptq_marlin_gemm", + ) + def _kernel( + compiled_module, + a: torch.Tensor, + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: torch.Tensor, + b_zeros: torch.Tensor, + g_idx: torch.Tensor, + perm: torch.Tensor, + c: torch.Tensor, + c_tmp: torch.Tensor, + a_tmp: torch.Tensor, + workspace: torch.Tensor, + b_q_type_id: int, + is_k_full: bool, + use_atomic_add: bool, + use_fp32_reduce: bool, + is_zp_float: bool, + ) -> None: + compiled_module.gptq_marlin_gemm( + a, + b_q_weight, + b_scales, + global_scale, + b_zeros, + g_idx, + perm, + c, + c_tmp, + a_tmp, + workspace, + b_q_type_id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) + + return _kernel + + +def _or_empty( + t: Optional[torch.Tensor], device: torch.device, dtype: torch.dtype +) -> torch.Tensor: + return t if t is not None else torch.empty(0, device=device, dtype=dtype) + + +def gptq_marlin_gemm( + a: torch.Tensor, + c: Optional[torch.Tensor], + b_q_weight: torch.Tensor, + b_scales: torch.Tensor, + global_scale: Optional[torch.Tensor], + b_zeros: 
Optional[torch.Tensor], + g_idx: Optional[torch.Tensor], + perm: Optional[torch.Tensor], + workspace: torch.Tensor, + b_q_type_id: int, + size_m: int, + size_n: int, + size_k: int, + is_k_full: bool = True, + use_atomic_add: bool = False, + use_fp32_reduce: bool = False, + is_zp_float: bool = False, +) -> torch.Tensor: + """Perform quantized GEMM using the Marlin kernel. + + Parameters + ---------- + a : torch.Tensor + Input activation tensor, shape ``(size_m, size_k)``, fp16 or bf16. + c : torch.Tensor or None + Output buffer, shape ``(size_m, size_n)``. Allocated if ``None``. + b_q_weight : torch.Tensor + Quantized weight in Marlin layout, int32. + b_scales : torch.Tensor + Per-group quantization scales. + global_scale : torch.Tensor or None + Global scale for FP8 quantization. + b_zeros : torch.Tensor or None + Per-group zero points (for AWQ-style asymmetric quantization). + g_idx : torch.Tensor or None + Group indices for activation reordering. + perm : torch.Tensor or None + Permutation indices for activation reordering. + workspace : torch.Tensor + Workspace buffer for synchronization. + b_q_type_id : int + ScalarType id for the quantized weight type. + size_m : int + Batch dimension. + size_n : int + Output dimension. + size_k : int + Reduction dimension. + is_k_full : bool + Whether the full K dimension is present (no TP split on K). + use_atomic_add : bool + Use atomic add for output reduction. + use_fp32_reduce : bool + Use fp32 for global reduction. + is_zp_float : bool + Whether zero points are float16 type. + + Returns + ------- + torch.Tensor + Output tensor, shape ``(size_m, size_n)``. 
+ """ + device = a.device + + # Allocate output if not provided + if c is None: + c = torch.empty((size_m, size_n), dtype=a.dtype, device=device) + + # Early return for zero-size M + if size_m == 0: + return c + + # Determine activation ordering + has_act_order = ( + g_idx is not None + and perm is not None + and g_idx.numel() > 0 + and perm.numel() > 0 + ) + + # Allocate c_tmp for fp32 reduce + if use_fp32_reduce: + sms = torch.cuda.get_device_properties(device).multi_processor_count + max_m_block = min(((size_m + 15) // 16) * 16, 64) + c_tmp = torch.empty( + sms * max_m_block * _MAX_THREAD_N, + dtype=torch.float32, + device=device, + ) + else: + c_tmp = torch.empty(0, dtype=torch.float32, device=device) + + # Allocate a_tmp for act_order column permutation + if has_act_order: + a_tmp = torch.empty((size_m, size_k), dtype=a.dtype, device=device) + else: + a_tmp = torch.empty(0, dtype=a.dtype, device=device) + + # Convert Optional tensors to empty tensors + global_scale_t = _or_empty(global_scale, device, a.dtype) + b_zeros_t = _or_empty(b_zeros, device, torch.int32) + g_idx_t = _or_empty(g_idx, device, torch.int32) + perm_t = _or_empty(perm, device, torch.int32) + + kernel = _make_gptq_marlin_gemm_kernel(a.dtype) + kernel( + a, + b_q_weight, + b_scales, + global_scale_t, + b_zeros_t, + g_idx_t, + perm_t, + c, + c_tmp, + a_tmp, + workspace, + b_q_type_id, + is_k_full, + use_atomic_add, + use_fp32_reduce, + is_zp_float, + ) + + return c diff --git a/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py b/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py new file mode 100644 index 000000000..d7906a383 --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/rms_norm_gated.py @@ -0,0 +1,87 @@ +"""Fused RMSNorm + SiLU gating CUDA JIT kernel for Qwen3.5 GDN attention. + +Computes ``rmsnorm(x, weight, eps) * silu(z)`` in a single fused pass. 
+ +Usage:: + + from mllm_kernel.cuda.jit.rms_norm_gated import rms_norm_gated + + output = rms_norm_gated(x, weight, z=gate, eps=1e-6) +""" + +from __future__ import annotations + +import torch + +from mllm_kernel.jit_utils import cache_once, jit + + +@cache_once +def _make_rms_norm_gated_kernel(): + """JIT-compile the fused RMSNorm+gating CUDA kernel.""" + + @jit( + args=[], + device="cuda", + cuda_files=["rms_norm_gated.cuh"], + cpp_wrappers=[], + cuda_wrappers=[ + ("rms_norm_gated", "RMSNormGatedKernel::run"), + ], + func_name="rms_norm_gated", + ) + def _kernel( + compiled_module, + output: torch.Tensor, + input: torch.Tensor, + weight: torch.Tensor, + gate: torch.Tensor, + eps: float, + ) -> None: + compiled_module.rms_norm_gated(output, input, weight, gate, eps) + + return _kernel + + +def rms_norm_gated( + x: torch.Tensor, + weight: torch.Tensor, + z: torch.Tensor | None = None, + eps: float = 1e-6, +) -> torch.Tensor: + """Fused RMSNorm with optional SiLU gating. + + Parameters + ---------- + x : torch.Tensor + Input tensor, shape ``(M, N)`` or ``(..., N)``. + weight : torch.Tensor + Normalization weight, shape ``(N,)``. + z : torch.Tensor or None + Optional gating tensor, same shape as ``x``. + If provided: ``output = rmsnorm(x) * silu(z)`` + eps : float + Epsilon for numerical stability. + + Returns + ------- + torch.Tensor + Output with same shape and dtype as ``x``. 
+ """ + x_shape = x.shape + x_2d = x.reshape(-1, x.shape[-1]) + + if z is not None: + z_2d = z.reshape(-1, z.shape[-1]) + if z_2d.stride(-1) != 1: + z_2d = z_2d.contiguous() + else: + z_2d = x.new_empty(0) # empty tensor signals "no gate" to the kernel + + if x_2d.stride(-1) != 1: + x_2d = x_2d.contiguous() + + output = torch.empty_like(x_2d) + kernel = _make_rms_norm_gated_kernel() + kernel(output, x_2d, weight.contiguous(), z_2d, eps) + return output.reshape(x_shape) diff --git a/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py b/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py new file mode 100644 index 000000000..96a73f5ef --- /dev/null +++ b/mllm-kernel/mllm_kernel/cuda/jit/store_cache.py @@ -0,0 +1,127 @@ +# Copyright (c) MLLM Team. +# Licensed under the MIT License. +# +# Python interface for the store_cache CUDA kernel. +# Efficiently scatters key/value tensors into a pre-allocated KV cache pool. + +from __future__ import annotations + +import logging +import torch +from mllm_kernel.jit_utils import jit +from mllm_kernel.jit_utils.compile import cache_once, make_cpp_args + + +logger = logging.getLogger(__name__) + + +@cache_once +def _is_arch_support_pdl() -> bool: + if not torch.cuda.is_available(): + return False + major, minor = torch.cuda.get_device_capability() + # PDL requires sm_90a (Hopper) or later + return major > 9 or (major == 9 and minor >= 0) + + +def _make_store_cache_kernel(row_bytes: int): + """Create a JIT-compiled store_cache kernel for the given row_bytes.""" + pdl = _is_arch_support_pdl() + cpp_args = make_cpp_args(row_bytes, pdl) + + @jit( + args=[row_bytes, pdl], + device="cuda", + cuda_files=["store_cache.cuh"], + cpp_wrappers=[], + cuda_wrappers=[ + ("store_cache", f"StoreKVCacheKernel<{cpp_args}>::run"), + ], + func_name="store_cache", + ) + def _kernel( + compiled_module, + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, + num_split: int, + ) -> None: + 
compiled_module.store_cache(k, v, k_cache, v_cache, indices, num_split) + + return _kernel + + +_KERNEL_CACHE: dict[int, object] = {} + + +def _get_kernel(row_bytes: int): + if row_bytes not in _KERNEL_CACHE: + _KERNEL_CACHE[row_bytes] = _make_store_cache_kernel(row_bytes) + return _KERNEL_CACHE[row_bytes] + + +@cache_once +def can_use_store_cache(row_bytes: int) -> bool: + """Check whether the JIT store_cache kernel supports the given row size. + + Returns ``False`` if *row_bytes* is not a multiple of 4 or if the JIT + compilation fails for any reason. + """ + if row_bytes % 4 != 0: + logger.warning( + "Unsupported row_bytes=%d for JIT store_cache kernel: " + "must be multiple of 4", + row_bytes, + ) + return False + try: + _get_kernel(row_bytes) + return True + except Exception as e: + logger.warning( + "Failed to load JIT store_cache kernel with row_bytes=%d: %s", + row_bytes, + e, + ) + return False + + +def store_cache( + k: torch.Tensor, + v: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + indices: torch.Tensor, + *, + row_bytes: int = 0, + num_split: int = 0, +) -> None: + """Store key and value tensors into a KV cache at specified indices. + + Each row of *k* (and *v*) is scattered into *k_cache* (and *v_cache*) + at the location given by the corresponding entry in *indices*. + + Args: + k: Key tensor, shape ``(batch_size, head_num * head_dim)``. + v: Value tensor, shape ``(batch_size, head_num * head_dim)``. + k_cache: Key cache, shape ``(num_slots, head_num * head_dim)``. + v_cache: Value cache, shape ``(num_slots, head_num * head_dim)``. + indices: Index tensor, shape ``(batch_size,)``, dtype int32 or int64. + row_bytes: Bytes per row. Auto-detected from *k* when 0. + num_split: Number of warps that cooperate on each element (1, 2, or 4). + When 0 the best value is chosen automatically based on alignment. 
+ """ + row_bytes = row_bytes or k.shape[-1] * k.element_size() + kernel = _get_kernel(row_bytes) + + if num_split <= 0: + if row_bytes % 2048 == 0: + num_split = 4 + elif row_bytes % 1024 == 0: + num_split = 2 + else: + num_split = 1 + + kernel(k, v, k_cache, v_cache, indices, num_split) diff --git a/mllm-kernel/pyproject.toml b/mllm-kernel/pyproject.toml index f64e1306e..13147f068 100644 --- a/mllm-kernel/pyproject.toml +++ b/mllm-kernel/pyproject.toml @@ -18,7 +18,7 @@ dependencies = [ "packaging", "torch", "torch-c-dlpack-ext", - "apache-tvm-ffi", + "apache-tvm-ffi == 0.1.8.post2", ] [project.optional-dependencies] @@ -27,6 +27,9 @@ dev = [ "pytest-html", ] +[project.scripts] +mllm-kernel = "mllm_kernel.__main__:main" + [tool.scikit-build] # Build configuration wheel.py-api = "py3" @@ -52,7 +55,7 @@ logging.level = "INFO" # Wheel configuration - include the Python package wheel.packages = ["mllm_kernel"] -wheel.install-dir = "mllm_kernel" +wheel.install-dir = "" # Install directories for cmake targets wheel.cmake = true diff --git a/mllm-kernel/tests/test_create_kv_indices.py b/mllm-kernel/tests/test_create_kv_indices.py new file mode 100644 index 000000000..e8bf770a3 --- /dev/null +++ b/mllm-kernel/tests/test_create_kv_indices.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +import pytest +import torch + +from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices + + +def _make_batch( + *, + max_reqs: int, + max_ctx: int, + batch_size: int, + use_start_offsets: bool, + seed: int = 0, +): + """Construct a random-but-bounded test batch for create_kv_indices. + + The constraints ensure that for every sequence i: + 0 <= kv_start_idx[i] + 0 < page_kernel_lens[i] + kv_start_idx[i] + page_kernel_lens[i] <= max_ctx + so the kernel never reads beyond the ReqToTokenPool row. + """ + # Use a CUDA generator for randperm (which requires matching device) + # and a separate CPU generator for randint (which only accepts CPU). 
+ g_cuda = torch.Generator(device="cuda").manual_seed(seed) + g_cpu = torch.Generator(device="cpu").manual_seed(seed) + + device = "cuda" + # req_to_token[req_slot, position] -> kv_index (here we simply use a + # monotonically increasing pattern so correctness is easy to check). + req_to_token = torch.arange( + max_reqs * max_ctx, dtype=torch.int32, device=device + ).reshape(max_reqs, max_ctx) + + # Sample distinct request slots for the batch. + assert batch_size <= max_reqs + req_pool_indices = torch.randperm(max_reqs, generator=g_cuda, device=device)[ + :batch_size + ].to(torch.int32) + + # For each sequence choose a valid (start, length) pair. + page_kernel_lens_list = [] + kv_start_idx_list = [] + for _ in range(batch_size): + # ensure at least 1 token per sequence + L = int(torch.randint(1, max_ctx, (1,), generator=g_cpu).item()) + if use_start_offsets: + start_max = max_ctx - L + start = int(torch.randint(0, max(start_max, 1), (1,), generator=g_cpu).item()) + else: + start = 0 + page_kernel_lens_list.append(L) + kv_start_idx_list.append(start) + + page_kernel_lens = torch.tensor( + page_kernel_lens_list, dtype=torch.int32, device=device + ) + kv_start_idx = torch.tensor(kv_start_idx_list, dtype=torch.int32, device=device) + + # Build kv_indptr prefix sums. 
+ kv_indptr = torch.empty(batch_size + 1, dtype=torch.int32, device=device) + kv_indptr[0] = 0 + kv_indptr[1:] = torch.cumsum(page_kernel_lens, dim=0) + + kv_indices = torch.empty( + int(kv_indptr[-1].item()), dtype=torch.int32, device=device + ) + + return ( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("use_start_offsets", [False, True]) +@pytest.mark.parametrize( + "batch_size,max_reqs,max_ctx", + [ + (1, 4, 16), # minimal batch + (4, 8, 64), # small batch + (32, 64, 512), # medium batch, longer context + (128, 256, 2048), # larger batch, stress inner loop + ], +) +def test_create_kv_indices_matches_reference( + use_start_offsets: bool, + batch_size: int, + max_reqs: int, + max_ctx: int, +): + """create_kv_indices must match a naive PyTorch reference implementation. + + The reference is computed on CPU using explicit loops over + (request_slot, start, length); the CUDA kernel must produce identical + flat kv_indices for the same inputs. + """ + ( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx, + kv_indices, + ) = _make_batch( + max_reqs=max_reqs, + max_ctx=max_ctx, + batch_size=batch_size, + use_start_offsets=use_start_offsets, + seed=2026, + ) + + # Call CUDA kernel (kv_start_idx can be None to exercise that path). + create_kv_indices( + req_to_token, + req_pool_indices, + page_kernel_lens, + kv_indptr, + kv_start_idx if use_start_offsets else None, + kv_indices, + ) + torch.cuda.synchronize() + + # Naive reference on CPU. 
+ req_to_token_cpu = req_to_token.cpu() + req_pool_indices_cpu = req_pool_indices.cpu().to(torch.long) + page_kernel_lens_cpu = page_kernel_lens.cpu() + kv_start_idx_cpu = kv_start_idx.cpu() + + ref_segments = [] + for i in range(batch_size): + req = req_pool_indices_cpu[i].item() + start = kv_start_idx_cpu[i].item() if use_start_offsets else 0 + L = page_kernel_lens_cpu[i].item() + row = req_to_token_cpu[req, start : start + L] + ref_segments.append(row) + ref = torch.cat(ref_segments, dim=0) + + assert kv_indices.shape == ref.shape + assert torch.equal(kv_indices.cpu(), ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_single_token_per_sequence(): + """Each sequence has exactly 1 token — exercises the minimal-work path.""" + device = "cuda" + bs = 8 + max_ctx = 32 + req_to_token = torch.arange(bs * max_ctx, dtype=torch.int32, device=device).reshape(bs, max_ctx) + req_pool_indices = torch.arange(bs, dtype=torch.int32, device=device) + page_kernel_lens = torch.ones(bs, dtype=torch.int32, device=device) + kv_indptr = torch.arange(bs + 1, dtype=torch.int32, device=device) + kv_indices = torch.empty(bs, dtype=torch.int32, device=device) + + create_kv_indices(req_to_token, req_pool_indices, page_kernel_lens, kv_indptr, None, kv_indices) + torch.cuda.synchronize() + + # Each sequence contributes req_to_token[i, 0]. 
+ expected = req_to_token[:, 0] + assert torch.equal(kv_indices, expected) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_oversized_output_buffer(): + """kv_indices buffer is larger than needed (prefill path uses +256 padding).""" + device = "cuda" + bs = 4 + max_ctx = 64 + req_to_token = torch.arange(bs * max_ctx, dtype=torch.int32, device=device).reshape(bs, max_ctx) + req_pool_indices = torch.arange(bs, dtype=torch.int32, device=device) + page_kernel_lens = torch.full((bs,), 10, dtype=torch.int32, device=device) + kv_indptr = torch.arange(0, bs * 10 + 1, 10, dtype=torch.int32, device=device) + # Allocate with extra padding, like the prefill path does. + kv_indices = torch.full((bs * 10 + 256,), -1, dtype=torch.int32, device=device) + + create_kv_indices(req_to_token, req_pool_indices, page_kernel_lens, kv_indptr, None, kv_indices) + torch.cuda.synchronize() + + # First bs*10 entries should match; padding should remain -1. + ref_segments = [] + for i in range(bs): + ref_segments.append(req_to_token[i, :10]) + ref = torch.cat(ref_segments, dim=0) + assert torch.equal(kv_indices[:bs * 10], ref) + assert torch.all(kv_indices[bs * 10:] == -1) diff --git a/mllm-kernel/tests/test_store_cache.py b/mllm-kernel/tests/test_store_cache.py new file mode 100644 index 000000000..5e4f1bcc3 --- /dev/null +++ b/mllm-kernel/tests/test_store_cache.py @@ -0,0 +1,66 @@ +from __future__ import annotations + +import pytest +import torch + +from mllm_kernel.cuda.jit import can_use_store_cache, store_cache + + +def _make_inputs( + *, + batch_size: int, + num_slots: int, + row_dim: int, + dtype: torch.dtype, + index_dtype: torch.dtype, + seed: int = 0, +): + torch.manual_seed(seed) + device = "cuda" + k = torch.randn(batch_size, row_dim, device=device, dtype=dtype) + v = torch.randn(batch_size, row_dim, device=device, dtype=dtype) + # Use unique indices to avoid write conflicts on the same cache slot. 
+ indices = torch.randperm(num_slots, device=device)[:batch_size].to(index_dtype) + k_cache = torch.zeros(num_slots, row_dim, device=device, dtype=dtype) + v_cache = torch.zeros_like(k_cache) + return k, v, k_cache, v_cache, indices + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +@pytest.mark.parametrize("dtype", [torch.float16, torch.float32]) +@pytest.mark.parametrize("index_dtype", [torch.int32, torch.int64]) +def test_store_cache_matches_torch_index(dtype: torch.dtype, index_dtype: torch.dtype): + batch_size = 257 + num_slots = 4096 + row_dim = 8 * 128 # 1024 -> fp16 row_bytes=2048 + row_bytes = row_dim * torch.tensor([], dtype=dtype).element_size() + + assert can_use_store_cache(row_bytes), f"store_cache unavailable for row_bytes={row_bytes}" + + k, v, k_cache, v_cache, indices = _make_inputs( + batch_size=batch_size, + num_slots=num_slots, + row_dim=row_dim, + dtype=dtype, + index_dtype=index_dtype, + seed=2026, + ) + + k_ref = k_cache.clone() + v_ref = v_cache.clone() + k_ref[indices] = k + v_ref[indices] = v + + store_cache(k, v, k_cache, v_cache, indices) + torch.cuda.synchronize() + + assert torch.equal(k_cache, k_ref) + assert torch.equal(v_cache, v_ref) + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA is required") +def test_can_use_store_cache_rejects_invalid_row_bytes(): + assert not can_use_store_cache(2) + assert not can_use_store_cache(6) + assert can_use_store_cache(4) + diff --git a/mllm/backends/cpu/CPUBackend.cpp b/mllm/backends/cpu/CPUBackend.cpp index 0964cba0d..f4b909913 100644 --- a/mllm/backends/cpu/CPUBackend.cpp +++ b/mllm/backends/cpu/CPUBackend.cpp @@ -14,6 +14,7 @@ #include "mllm/backends/cpu/ops/ConcatOp.hpp" #include "mllm/backends/cpu/ops/ContiguousOp.hpp" #include "mllm/backends/cpu/ops/Conv1DOp.hpp" +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" #include "mllm/backends/cpu/ops/Conv2DOp.hpp" #include "mllm/backends/cpu/ops/Conv3DOp.hpp" #include 
"mllm/backends/cpu/ops/CopyOp.hpp" @@ -52,6 +53,7 @@ #include "mllm/backends/cpu/ops/Scatter2ShardsOp.hpp" #include "mllm/backends/cpu/ops/SiLUOp.hpp" #include "mllm/backends/cpu/ops/SigmoidOp.hpp" +#include "mllm/backends/cpu/ops/TanhOp.hpp" #include "mllm/backends/cpu/ops/SliceOp.hpp" #include "mllm/backends/cpu/ops/SoftmaxOp.hpp" #include "mllm/backends/cpu/ops/SplitOp.hpp" @@ -78,12 +80,12 @@ CPUBackend::CPUBackend() : Backend(kCPU, createCPUAllocator()) { CPUSiLUOpFactory, CPUSigmoidOpFactory, CPURMSNormOpFactory, CPUGELUOpFactory, CPUQuickGELUOpFactory, CPUReLUOpFactory, CPUMatMulOpFactory, CPUFlashAttention2OpFactory, CPUSliceOpFactory, CPUVisionRoPEOpFactory, CPUParamOpFactory, CPUMultimodalRoPEOpFactory, CPURoPEOpFactory, CPUCausalMaskOpFactory, CPUConv1DOpFactory, - CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, CPUTopKOpFactory, CPUClipOpFactory, - CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, - CPUConv2DOpFactory, CPULayerNorm2DOpFactory, CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, - CPUArgsortOpFactory, CPUCloneOpFactory, CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, - CPURadixAttnRelaxOpFactory, CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, - CPUGatherOpFactory>(); + CPUConvTranspose1DOpFactory, CPUConv3DOpFactory, CPUSTFTOpFactory, CPUISTFTOpFactory, CPUIndexOpFactory, + CPUTopKOpFactory, CPUClipOpFactory, CPUMeanOpFactory, CPUKVCacheOpFactory, CPUPagedAttnOpFactory, + CPUScatter2ShardsOpFactory, CPURadixAttnOpFactory, CPUConv2DOpFactory, CPULayerNorm2DOpFactory, + CPUInterpolateOpFactory, CPUPadOpFactory, CPUMaskedScatterOpFactory, CPUArgsortOpFactory, CPUCloneOpFactory, + CPUAvgPool1dOpFactory, CPUFlashAttention2SwaSinkOpFactory, CPURadixAttnRelaxOpFactory, + CPURadixAttnSwaSinkOpFactory, CPUEqualOpFactory, CPUWhereOpFactory, CPUGatherOpFactory, CPUTanhOpFactory>(); } CPUBackend::~CPUBackend() { 
diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..15a8097d1 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.cpp @@ -0,0 +1,93 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/backends/cpu/ops/ConvTranspose1DOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUConvTranspose1DOp::CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options) + : aops::ConvTranspose1DOp(options) {} + +void CPUConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + auto& input = inputs[0]; + auto& output = outputs[0]; + + auto i_shape = input.shape(); + auto o_shape = output.shape(); + + // input shape: [batch, in_channels, sequence] + // output shape: [batch, out_channels, out_sequence] + const int batch = i_shape[0]; + const int in_channels = i_shape[1]; + const int sequence = i_shape[2]; + + const int out_channels = o_shape[1]; + const int out_sequence = o_shape[2]; + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int padding = options_.padding; + const int dilation = options_.dilation; + const int groups = options_.groups; + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(output.dtype(), kFloat32); + MLLM_RT_ASSERT(weight_.dtype() == kFloat32); + const auto* weight_ptr = weight_.ptr(); + const auto* input_ptr = input.ptr(); + auto* output_ptr = output.ptr(); + + float* bias_ptr = nullptr; + if (options_.bias && !bias_.isNil()) { bias_ptr = bias_.ptr(); } + + std::fill_n(output_ptr, output.numel(), 0.0f); + + const int total_iterations = batch * out_channels * out_sequence; + + switch (output.dtype()) { + case kFloat32: + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, 
idx, 0, total_iterations, 1, { + int b = idx / (out_channels * out_sequence); + int oc = (idx % (out_channels * out_sequence)) / out_sequence; + int out_pos = idx % out_sequence; + + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + + float sum = 0.0f; + + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + + sum += input_ptr[input_idx] * weight_ptr[weight_idx]; + } + } + + if (bias_ptr) { sum += bias_ptr[oc]; } + + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_ptr[output_idx] = sum; + }); + break; + default: NYI("ConvTranspose1D: unsupported data type"); + } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..fd1163ed3 --- /dev/null +++ b/mllm/backends/cpu/ops/ConvTranspose1DOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::cpu { + +class CPUConvTranspose1DOp final : public aops::ConvTranspose1DOp { + public: + explicit CPUConvTranspose1DOp(const aops::ConvTranspose1DOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUConvTranspose1DOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::ConvTranspose1DOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/EmbeddingOp.cpp b/mllm/backends/cpu/ops/EmbeddingOp.cpp index 71af75f68..f25849f7a 100644 --- a/mllm/backends/cpu/ops/EmbeddingOp.cpp +++ b/mllm/backends/cpu/ops/EmbeddingOp.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include +#include #include "mllm/backends/cpu/ops/EmbeddingOp.hpp" #include "mllm/core/DataTypes.hpp" @@ -22,30 +23,40 @@ void CPUEmbeddingOp::forward(const std::vector& inputs, std::vector 0); + MLLM_RT_ASSERT(options_.hidden_size > 0); + + static std::atomic warned_token_oob{false}; const bool use_parallel = options_.getThreads() > 1; const int thread_count = options_.getThreads(); for (int b = 0; b < B; ++b) { MLLM_CONDITIONAL_PARALLEL_FOR(use_parallel, thread_count, s, 0, S, 1, { - switch (weight_dtype) { - case kFloat32: - std::memcpy(ous.coffsettedPtr({b, (int)s, 0}), - weight_.ptr() + options_.hidden_size * (*ins.coffsettedPtr({b, (int)s})), - options_.hidden_size * sizeof(float)); - break; - case kFloat16: - std::memcpy(ous.coffsettedPtr({b, (int)s, 0}), - weight_.ptr() + options_.hidden_size * (*ins.coffsettedPtr({b, (int)s})), - options_.hidden_size * sizeof(mllm_fp16_t)); - break; - case kGGUF_Q4_K: { - auto token_idx = *ins.coffsettedPtr({b, (int)s}); - if (token_idx >= 0) { + const auto token_idx = *ins.coffsettedPtr({b, (int)s}); + auto* out_ptr = ous.coffsettedPtr({b, (int)s, 0}); + if (token_idx < 0 
|| token_idx >= options_.vocab_size) { + std::memset(out_ptr, 0, options_.hidden_size * bytesOfType(ous.dtype())); + bool expected = false; + if (warned_token_oob.compare_exchange_strong(expected, true)) { + MLLM_WARN("Embedding token index out of range (idx={}, vocab={}), output row is zero-filled.", + token_idx, options_.vocab_size); + } + } else { + switch (weight_dtype) { + case kFloat32: + std::memcpy(out_ptr, weight_.ptr() + options_.hidden_size * token_idx, + options_.hidden_size * sizeof(float)); + break; + case kFloat16: + std::memcpy(out_ptr, weight_.ptr() + options_.hidden_size * token_idx, + options_.hidden_size * sizeof(mllm_fp16_t)); + break; + case kGGUF_Q4_K: { dequantize_row_q4_K(weight_.ptr() + token_idx * options_.hidden_size / QK_K, ous.coffsettedPtr({b, (int)s, 0}), options_.hidden_size); + break; } - break; - } + case kGGUF_Q4_0: { auto token_idx = *ins.coffsettedPtr({b, (int)s}); if (token_idx >= 0) { @@ -56,6 +67,7 @@ void CPUEmbeddingOp::forward(const std::vector& inputs, std::vector + +#include "mllm/backends/cpu/ops/TanhOp.hpp" +#include "mllm/core/Parallel.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::cpu { + +CPUTanhOp::CPUTanhOp(const aops::TanhOpOptions& options) : aops::TanhOp(options) {} + +void CPUTanhOp::forward(const std::vector& inputs, std::vector& outputs) { + const auto& X = inputs[0]; + auto& Y = outputs[0]; + + const auto numel = X.numel(); + + switch (X.dtype()) { + case kFloat32: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + y_ptr[idx] = std::tanh(x_ptr[idx]); + }); + break; + } + case kFloat16: { + const auto* x_ptr = X.ptr(); + auto* y_ptr = Y.ptr(); + MLLM_CONDITIONAL_PARALLEL_FOR(options_.getThreads() > 1, 4, idx, 0, numel, 1, { + float v = static_cast(x_ptr[idx]); + y_ptr[idx] = static_cast(std::tanh(v)); + }); + break; + } + default: NYI("CPUTanhOp::forward not support dtype {}", nameOfType(X.dtype())); break; 
+ } +} + +} // namespace mllm::cpu diff --git a/mllm/backends/cpu/ops/TanhOp.hpp b/mllm/backends/cpu/ops/TanhOp.hpp new file mode 100644 index 000000000..c88fae9ce --- /dev/null +++ b/mllm/backends/cpu/ops/TanhOp.hpp @@ -0,0 +1,25 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::cpu { + +class CPUTanhOp final : public aops::TanhOp { + public: + explicit CPUTanhOp(const aops::TanhOpOptions& options); + + void forward(const std::vector& inputs, std::vector& outputs) override; +}; + +class CPUTanhOpFactory : public TypedOpFactory { + public: + std::shared_ptr createOpImpl(const aops::TanhOpOptions& options) override { + return std::make_shared(options); + } +}; + +} // namespace mllm::cpu diff --git a/mllm/compile/ir/GeneratedRTTIKind.hpp b/mllm/compile/ir/GeneratedRTTIKind.hpp index d100dc621..2d83493d1 100644 --- a/mllm/compile/ir/GeneratedRTTIKind.hpp +++ b/mllm/compile/ir/GeneratedRTTIKind.hpp @@ -44,6 +44,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_RepeatOp, RK_Op_LinalgIROp_PermuteOp, RK_Op_LinalgIROp_Conv1DOp, + RK_Op_LinalgIROp_ConvTranspose1DOp, RK_Op_LinalgIROp_Conv2DOp, RK_Op_LinalgIROp_Conv3DOp, RK_Op_LinalgIROp_GELUOp, @@ -86,6 +87,7 @@ enum NodeKind : uint32_t { RK_Op_LinalgIROp_EqualOp, RK_Op_LinalgIROp_WhereOp, RK_Op_LinalgIROp_SigmoidOp, + RK_Op_LinalgIROp_TanhOp, RK_Op_LinalgIROp_CustomizedOp, RK_Op_LinalgIROp_Last, RK_Op_GraphIROp, diff --git a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp index 6a98797a9..4c3313cf9 100644 --- a/mllm/compile/ir/NodeRTTIClassOfImpl.hpp +++ b/mllm/compile/ir/NodeRTTIClassOfImpl.hpp @@ -102,6 +102,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_CONV1DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv1DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv1DOp +#define RTTI_RK_OP_LINALGIROP_CONVTRANSPOSE1DOP_IMPL(v) \ + return 
(v)->getKind() >= RK_Op_LinalgIROp_ConvTranspose1DOp && (v)->getKind() <= RK_Op_LinalgIROp_ConvTranspose1DOp + #define RTTI_RK_OP_LINALGIROP_CONV2DOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_Conv2DOp && (v)->getKind() <= RK_Op_LinalgIROp_Conv2DOp @@ -229,6 +232,9 @@ struct NodeRTTIClassOfImpl { #define RTTI_RK_OP_LINALGIROP_SIGMOIDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_SigmoidOp && (v)->getKind() <= RK_Op_LinalgIROp_SigmoidOp +#define RTTI_RK_OP_LINALGIROP_TANHOP_IMPL(v) \ + return (v)->getKind() >= RK_Op_LinalgIROp_TanhOp && (v)->getKind() <= RK_Op_LinalgIROp_TanhOp + #define RTTI_RK_OP_LINALGIROP_CUSTOMIZEDOP_IMPL(v) \ return (v)->getKind() >= RK_Op_LinalgIROp_CustomizedOp && (v)->getKind() <= RK_Op_LinalgIROp_CustomizedOp diff --git a/mllm/compile/ir/linalg/Op.cpp b/mllm/compile/ir/linalg/Op.cpp index bb4e2fb9d..ad05e9437 100644 --- a/mllm/compile/ir/linalg/Op.cpp +++ b/mllm/compile/ir/linalg/Op.cpp @@ -55,6 +55,7 @@ LINALG_AOPS_DECL(OpTypes::kTranspose, TransposeOp); LINALG_AOPS_DECL(OpTypes::kRMSNorm, RMSNormOp); LINALG_AOPS_DECL(OpTypes::kSiLU, SiLUOp); LINALG_AOPS_DECL(OpTypes::kSigmoid, SigmoidOp); +LINALG_AOPS_DECL(OpTypes::kTanh, TanhOp); LINALG_AOPS_DECL(OpTypes::kCastType, CastTypeOp); @@ -70,6 +71,7 @@ LINALG_AOPS_DECL(OpTypes::kRepeat, RepeatOp); LINALG_AOPS_DECL(OpTypes::kPermute, PermuteOp); LINALG_AOPS_DECL(OpTypes::kConv1D, Conv1DOp); +LINALG_AOPS_DECL(OpTypes::kConvTranspose1D, ConvTranspose1DOp); LINALG_AOPS_DECL(OpTypes::kConv2D, Conv2DOp); LINALG_AOPS_DECL(OpTypes::kConv3D, Conv3DOp); diff --git a/mllm/compile/ir/linalg/Op.hpp b/mllm/compile/ir/linalg/Op.hpp index 02d04400b..6e6de4785 100644 --- a/mllm/compile/ir/linalg/Op.hpp +++ b/mllm/compile/ir/linalg/Op.hpp @@ -29,6 +29,7 @@ class TransposeOp; class RMSNormOp; class SiLUOp; class SigmoidOp; +class TanhOp; class CausalMaskOp; class CastTypeOp; class X2XOp; @@ -38,6 +39,7 @@ class FlashAttention2Op; class RepeatOp; class PermuteOp; class Conv1DOp; +class 
ConvTranspose1DOp; class Conv2DOp; class Conv3DOp; class GELUOp; @@ -188,6 +190,7 @@ LINALG_AOPS_DEFINE(TransposeOp, TRANSPOSEOP); LINALG_AOPS_DEFINE(RMSNormOp, RMSNORMOP); LINALG_AOPS_DEFINE(SiLUOp, SILUOP); LINALG_AOPS_DEFINE(SigmoidOp, SIGMOIDOP); +LINALG_AOPS_DEFINE(TanhOp, TANHOP); LINALG_AOPS_DEFINE(CastTypeOp, CASTTYPEOP); @@ -201,6 +204,7 @@ LINALG_AOPS_DEFINE(RepeatOp, REPEATOP); LINALG_AOPS_DEFINE(PermuteOp, PERMUTEOP); LINALG_AOPS_DEFINE(Conv1DOp, CONV1DOP); +LINALG_AOPS_DEFINE(ConvTranspose1DOp, CONVTRANSPOSE1DOP); LINALG_AOPS_DEFINE(Conv2DOp, CONV2DOP); LINALG_AOPS_DEFINE(Conv3DOp, CONV3DOP); diff --git a/mllm/core/OpTypes.hpp b/mllm/core/OpTypes.hpp index 310b39cd0..d64d484fe 100644 --- a/mllm/core/OpTypes.hpp +++ b/mllm/core/OpTypes.hpp @@ -96,6 +96,8 @@ enum class OpTypes : int32_t { kWhere = 74, kSigmoid = 75, + kTanh = 76, + kConvTranspose1D = 77, // Dynamic Op Start for user to register there own ops. kDynamicOp_Start = 4096, @@ -181,6 +183,8 @@ inline std::string optype2Str(OpTypes type) { case OpTypes::kEqual: return "Equal"; case OpTypes::kWhere: return "Where"; case OpTypes::kSigmoid: return "Sigmoid"; + case OpTypes::kTanh: return "Tanh"; + case OpTypes::kConvTranspose1D: return "ConvTranspose1D"; case OpTypes::kDynamicOp_Start: return "DynamicOp_Start"; case OpTypes::kOpType_End: return "OpType_End"; default: return "Unknown"; diff --git a/mllm/core/aops/ConvTranspose1DOp.cpp b/mllm/core/aops/ConvTranspose1DOp.cpp new file mode 100644 index 000000000..25d1b5935 --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.cpp @@ -0,0 +1,95 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/ConvTranspose1DOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" +#include "mllm/compile/ir/graph/Op.hpp" +#include "mllm/compile/ir/tensor/Op.hpp" + +namespace mllm::aops { + +ConvTranspose1DOp::ConvTranspose1DOp(const ConvTranspose1DOpOptions& options) + : BaseOp(OpTypes::kConvTranspose1D), options_(options) {} + +void ConvTranspose1DOp::load(const ParameterFile::ptr_t& ploader) { + switch (ploader->version()) { + case ModelFileVersion::kV1: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + weight_ = weight_.view({options_.in_channels, options_.out_channels / options_.groups, options_.kernel_size}); + if (options_.bias) { bias_ = bias_.view({options_.out_channels}); } + break; + } + case ModelFileVersion::kUserTemporary: + case ModelFileVersion::kV2: { + weight_ = ploader->pull(getName() + ".weight"); + if (options_.bias) { bias_ = ploader->pull(getName() + ".bias"); } + break; + } + default: NYI("Unsupported model file version") + } +} + +void ConvTranspose1DOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + + if (weight_ && !ir_ctx->lookupSymbolTable(getName() + ".weight")) { + ir::IRWriterGuard guard(ir_ctx, ir_ctx->lookupSymbolTable("init")->cast_()->getTopRegion()); + ir_ctx->create(ir_ctx->create(weight_)); + if (options_.bias) { ir_ctx->create(ir_ctx->create(bias_)); } + } + + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void ConvTranspose1DOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("ConvTranspose1DOp::forward not implemented in aops base."); +} + +void ConvTranspose1DOp::reshape(const std::vector& inputs, 
std::vector& outputs) { + const auto& i = inputs[0]; + const auto& ishape = i.shape(); + + if (ishape.size() != 3) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ConvTranspose1DOp expects 3D input, got {} D", ishape.size()); + outputs.emplace_back(Tensor::empty(i.shape(), i.dtype(), i.device())); + return; + } + + const int batch = ishape[0]; + const int in_channels = ishape[1]; + const int sequence = ishape[2]; + + MLLM_RT_ASSERT_EQ(in_channels, options_.in_channels); + MLLM_RT_ASSERT_EQ(in_channels % options_.groups, 0); + MLLM_RT_ASSERT_EQ(options_.out_channels % options_.groups, 0); + + const int kernel_size = options_.kernel_size; + const int stride = options_.stride; + const int dilation = options_.dilation; + const int padding = options_.padding; + const int output_padding = options_.output_padding; + + const int seq_out = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + + auto new_shape = std::vector{batch, options_.out_channels, seq_out}; + outputs.emplace_back(Tensor::empty(new_shape, i.dtype(), i.device())); +} + +void ConvTranspose1DOp::setup(const std::vector& inputs, std::vector& outputs) { + BaseOp::setup(inputs, outputs); +} + +ParameterFile::ptr_t ConvTranspose1DOp::getParams() { + auto p = ParameterFile::create(); + p->push(getName() + ".weight", weight_); + if (options_.bias) { p->push(getName() + ".bias", bias_); } + return p; +} + +} // namespace mllm::aops diff --git a/mllm/core/aops/ConvTranspose1DOp.hpp b/mllm/core/aops/ConvTranspose1DOp.hpp new file mode 100644 index 000000000..daeda0b8e --- /dev/null +++ b/mllm/core/aops/ConvTranspose1DOp.hpp @@ -0,0 +1,52 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct ConvTranspose1DOpOptions : public BaseOpOptions { + int32_t in_channels; + int32_t out_channels; + int32_t kernel_size; + int32_t stride = 1; + int32_t padding = 0; + int32_t output_padding = 0; + int32_t dilation = 1; + int32_t groups = 1; + bool bias = true; +}; + +class ConvTranspose1DOp : public BaseOp { + public: + explicit ConvTranspose1DOp(const ConvTranspose1DOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + ParameterFile::ptr_t getParams() override; + + inline Tensor& weight() { return weight_; } + + inline Tensor& bias() { return bias_; } + + inline ConvTranspose1DOpOptions& options() { return options_; } + + protected: + Tensor weight_; + Tensor bias_; + ConvTranspose1DOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.cpp b/mllm/core/aops/TanhOp.cpp new file mode 100644 index 000000000..c0938d82f --- /dev/null +++ b/mllm/core/aops/TanhOp.cpp @@ -0,0 +1,37 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#include "mllm/core/aops/TanhOp.hpp" +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/compile/ir/linalg/Op.hpp" + +namespace mllm::aops { + +TanhOp::TanhOp(const TanhOpOptions& options) : BaseOp(OpTypes::kTanh), options_(options) {} + +void TanhOp::load(const ParameterFile::ptr_t& ploader) { MLLM_EMPTY_SCOPE; } + +void TanhOp::trace(void* trace_context, const std::vector& inputs, std::vector& outputs) { + auto ir_ctx = (ir::IRContext*)trace_context; + auto i_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, inputs); + auto o_irs = ir::tensor::wrapTensors2TensorIR(ir_ctx, outputs); + ir_ctx->create(shared_from_this(), i_irs, o_irs); +} + +void TanhOp::forward(const std::vector& inputs, std::vector& outputs) { + NYI("TanhOp::forward not implemented in aops base."); +} + +void TanhOp::reshape(const std::vector& inputs, std::vector& outputs) { + if (options_.isInplace()) { + outputs.emplace_back(inputs[0]); + } else { + outputs.emplace_back(Tensor::empty(inputs[0].shape(), inputs[0].dtype(), inputs[0].device())); + } +} + +void TanhOp::setup(const std::vector& inputs, std::vector& outputs) { BaseOp::setup(inputs, outputs); } + +} // namespace mllm::aops diff --git a/mllm/core/aops/TanhOp.hpp b/mllm/core/aops/TanhOp.hpp new file mode 100644 index 000000000..8b2ce4f43 --- /dev/null +++ b/mllm/core/aops/TanhOp.hpp @@ -0,0 +1,33 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/core/BaseOp.hpp" +#include "mllm/core/ParameterFile.hpp" + +namespace mllm::aops { + +struct TanhOpOptions : public BaseOpOptions {}; + +class TanhOp : public BaseOp { + public: + explicit TanhOp(const TanhOpOptions& options); + + void load(const ParameterFile::ptr_t& ploader) override; + + void trace(void* trace_context, const std::vector& inputs, std::vector& outputs) override; + + void forward(const std::vector& inputs, std::vector& outputs) override; + + void reshape(const std::vector& inputs, std::vector& outputs) override; + + void setup(const std::vector& inputs, std::vector& outputs) override; + + inline TanhOpOptions& options() { return options_; } + + protected: + TanhOpOptions options_; +}; + +} // namespace mllm::aops diff --git a/mllm/ffi/Extension.cc b/mllm/ffi/Extension.cc index cb999191d..f3f2d2488 100644 --- a/mllm/ffi/Extension.cc +++ b/mllm/ffi/Extension.cc @@ -83,12 +83,12 @@ TVM_FFI_STATIC_INIT_BLOCK() { // Tensor related refl::GlobalDef().def("mllm.empty", mllm::ffi::empty); refl::GlobalDef().def("mllm.from_torch", [](const tvm::ffi::Tensor& t) -> mllm::ffi::Tensor { - auto dl_pack = t.get()->ToDLPack(); + auto dl_pack = t.ToDLPack(); return ::mllm::ffi::Tensor(mllm::ffi::__from_dlpack(dl_pack)); }); refl::GlobalDef().def("mllm.from_numpy", [](const tvm::ffi::Tensor& t) -> mllm::ffi::Tensor { - auto dl_pack = t.get()->ToDLPack(); + auto dl_pack = t.ToDLPack(); return ::mllm::ffi::Tensor(mllm::ffi::__from_dlpack(dl_pack)); }); @@ -345,6 +345,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::ObjectDef<::mllm::ffi::BaseOpObj>(); + refl::ObjectDef<::mllm::ffi::ParameterFileObj>(); refl::GlobalDef().def("mllm.BaseOp.load", [](const mllm::ffi::BaseOp& self, const mllm::ffi::ParameterFile& obj) -> void { self.get()->op_ptr_->load(obj.get()->pf_ptr_); }); diff --git a/mllm/models/minicpm_o2_6/modeling_chattts.hpp b/mllm/models/minicpm_o2_6/modeling_chattts.hpp index 3190210a7..4814d0eb7 
100644 --- a/mllm/models/minicpm_o2_6/modeling_chattts.hpp +++ b/mllm/models/minicpm_o2_6/modeling_chattts.hpp @@ -401,7 +401,6 @@ class ConditionalChatTTS : public nn::Module { // Apply softmax to get probabilities: [num_vq, codebook_size] auto scores = nn::functional::softmax(logits.view({1, 1, logits.shape()[0], logits.shape()[1]}), -1).squeeze(); - logits.delete_(); // Free memory // Sample from each VQ codebook independently using multinomial sampling // This matches PyTorch's torch.multinomial(scores, num_samples=1) behavior @@ -418,7 +417,6 @@ class ConditionalChatTTS : public nn::Module { if (sampled_token == eos_token) { finished = true; } } - scores.delete_(); // Free memory progress++; audio_bos = false; diff --git a/mllm/models/minicpm_o2_6/modeling_resampler.hpp b/mllm/models/minicpm_o2_6/modeling_resampler.hpp index f447521bd..adc80a39f 100644 --- a/mllm/models/minicpm_o2_6/modeling_resampler.hpp +++ b/mllm/models/minicpm_o2_6/modeling_resampler.hpp @@ -99,34 +99,43 @@ class ResamplerAttention : public nn::Module { q = q + q_bias; k = k + k_bias; v = v + v_bias; + q = q.contiguous(); + k = k.contiguous(); + v = v.contiguous(); auto q_reshaped = Tensor::empty({num_heads_, num_queries, head_dim_}, kFloat32).alloc(); + const auto* q_ptr = q.ptr(); + auto* q_reshaped_ptr = q_reshaped.ptr(); for (int nq = 0; nq < num_queries; nq++) { + auto q_row_ptr = q_ptr + static_cast(nq) * embed_dim_; for (int h = 0; h < num_heads_; h++) { - for (int d = 0; d < head_dim_; d++) { - float val = q.at({nq, h * head_dim_ + d}); - *q_reshaped.offsettedPtr({h, nq, d}) = val; - } + auto src_ptr = q_row_ptr + h * head_dim_; + auto dst_ptr = q_reshaped_ptr + (static_cast(h) * num_queries + nq) * head_dim_; + std::memcpy(dst_ptr, src_ptr, static_cast(head_dim_) * sizeof(float)); } } q = q_reshaped; // [num_heads, num_queries, head_dim] auto k_reshaped = Tensor::empty({num_heads_, seq_len, head_dim_}, kFloat32).alloc(); + const auto* k_ptr = k.ptr(); + auto* k_reshaped_ptr = 
k_reshaped.ptr(); for (int s = 0; s < seq_len; s++) { + auto k_row_ptr = k_ptr + static_cast(s) * embed_dim_; for (int h = 0; h < num_heads_; h++) { - for (int d = 0; d < head_dim_; d++) { - float val = k.at({s, h * head_dim_ + d}); - *k_reshaped.offsettedPtr({h, s, d}) = val; - } + auto src_ptr = k_row_ptr + h * head_dim_; + auto dst_ptr = k_reshaped_ptr + (static_cast(h) * seq_len + s) * head_dim_; + std::memcpy(dst_ptr, src_ptr, static_cast(head_dim_) * sizeof(float)); } } k = k_reshaped; auto v_reshaped = Tensor::empty({num_heads_, seq_len, head_dim_}, kFloat32).alloc(); + const auto* v_ptr = v.ptr(); + auto* v_reshaped_ptr = v_reshaped.ptr(); for (int s = 0; s < seq_len; s++) { + auto v_row_ptr = v_ptr + static_cast(s) * embed_dim_; for (int h = 0; h < num_heads_; h++) { - for (int d = 0; d < head_dim_; d++) { - float val = v.at({s, h * head_dim_ + d}); - *v_reshaped.offsettedPtr({h, s, d}) = val; - } + auto src_ptr = v_row_ptr + h * head_dim_; + auto dst_ptr = v_reshaped_ptr + (static_cast(h) * seq_len + s) * head_dim_; + std::memcpy(dst_ptr, src_ptr, static_cast(head_dim_) * sizeof(float)); } } v = v_reshaped; @@ -140,10 +149,12 @@ class ResamplerAttention : public nn::Module { if (has_key_padding_mask && key_padding_mask.numel() > 0) { auto mask_value = -std::numeric_limits::infinity(); + auto key_padding_mask_contiguous = key_padding_mask.isContiguous() ? 
key_padding_mask : key_padding_mask.contiguous(); + const auto* key_padding_mask_ptr = key_padding_mask_contiguous.ptr(); for (int32_t h = 0; h < num_heads_; ++h) { for (int32_t q_idx = 0; q_idx < num_queries; ++q_idx) { for (int32_t s = 0; s < seq_len; ++s) { - if (key_padding_mask.at({s}) == 1) { *attn_weights.offsettedPtr({h, q_idx, s}) = mask_value; } + if (key_padding_mask_ptr[s] == 1) { *attn_weights.offsettedPtr({h, q_idx, s}) = mask_value; } } } } @@ -152,14 +163,16 @@ class ResamplerAttention : public nn::Module { attn_weights = nn::functional::softmax(attn_weights.unsqueeze(0), -1).squeeze(0); auto attn_output = nn::functional::matmul(attn_weights, v); // [num_heads, num_queries, head_dim] + attn_output = attn_output.contiguous(); auto attn_output_reshaped = Tensor::empty({num_queries, embed_dim_}, kFloat32).alloc(); + const auto* attn_output_ptr = attn_output.ptr(); + auto* attn_output_reshaped_ptr = attn_output_reshaped.ptr(); for (int h = 0; h < num_heads_; h++) { for (int nq = 0; nq < num_queries; nq++) { - for (int d = 0; d < head_dim_; d++) { - float val = attn_output.at({h, nq, d}); - *attn_output_reshaped.offsettedPtr({nq, h * head_dim_ + d}) = val; - } + auto src_ptr = attn_output_ptr + (static_cast(h) * num_queries + nq) * head_dim_; + auto dst_ptr = attn_output_reshaped_ptr + static_cast(nq) * embed_dim_ + h * head_dim_; + std::memcpy(dst_ptr, src_ptr, static_cast(head_dim_) * sizeof(float)); } } attn_output = attn_output_reshaped; @@ -224,11 +237,15 @@ class Resampler : public nn::Module { std::vector patch_len(batch_size); int max_h = 0, max_w = 0, max_patch_len = 0; + auto tgt_sizes_contiguous = tgt_sizes.isContiguous() ? 
tgt_sizes : tgt_sizes.contiguous(); + const auto* tgt_sizes_ptr = tgt_sizes_contiguous.ptr(); for (int i = 0; i < batch_size; i++) { - patch_len[i] = tgt_sizes.at({i, 0}) * tgt_sizes.at({i, 1}); + auto tgt_h = tgt_sizes_ptr[i * 2]; + auto tgt_w = tgt_sizes_ptr[i * 2 + 1]; + patch_len[i] = tgt_h * tgt_w; if (patch_len[i] > max_patch_len) max_patch_len = patch_len[i]; - if (tgt_sizes.at({i, 0}) > max_h) max_h = tgt_sizes.at({i, 0}); - if (tgt_sizes.at({i, 1}) > max_w) max_w = tgt_sizes.at({i, 1}); + if (tgt_h > max_h) max_h = tgt_h; + if (tgt_w > max_w) max_w = tgt_w; } if (max_h > max_size_[0] || max_w > max_size_[1]) { @@ -238,30 +255,35 @@ class Resampler : public nn::Module { registerBuffer("pos_embed", new_pos_embed); } - auto pos_embed = getBuffer("pos_embed"); // [max_h, max_w, embed_dim] + auto pos_embed = getBuffer("pos_embed").contiguous(); // [max_h, max_w, embed_dim] + const auto* pos_embed_ptr = pos_embed.ptr(); auto key_padding_mask = Tensor::empty({batch_size, max_patch_len}, kUInt8).alloc(); + auto* key_padding_mask_ptr = key_padding_mask.ptr(); for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < max_patch_len; j++) { key_padding_mask.at({i, j}) = 1; } - for (int j = 0; j < patch_len[i] && j < max_patch_len; j++) { key_padding_mask.at({i, j}) = 0; } + auto* key_padding_mask_row_ptr = key_padding_mask_ptr + static_cast(i) * max_patch_len; + std::memset(key_padding_mask_row_ptr, 1, static_cast(max_patch_len)); + if (patch_len[i] > 0) { + std::memset(key_padding_mask_row_ptr, 0, static_cast(std::min(patch_len[i], max_patch_len))); + } } std::vector pos_embed_list; for (int i = 0; i < batch_size; i++) { - int32_t tgt_h = tgt_sizes.at({i, 0}); - int32_t tgt_w = tgt_sizes.at({i, 1}); + int32_t tgt_h = tgt_sizes_ptr[i * 2]; + int32_t tgt_w = tgt_sizes_ptr[i * 2 + 1]; int32_t patch_count = tgt_h * tgt_w; Tensor pos_embed_i = Tensor::empty({patch_count, embed_dim_}, kFloat32).alloc(); + auto* pos_embed_i_ptr = pos_embed_i.ptr(); int patch_idx = 0; for 
(int h = 0; h < tgt_h; h++) { for (int w = 0; w < tgt_w; w++) { - for (int d = 0; d < embed_dim_; d++) { - float value = pos_embed.at({h, w, d}); - *pos_embed_i.offsettedPtr({patch_idx, d}) = value; - } + auto src_ptr = pos_embed_ptr + (static_cast(h) * max_w + w) * embed_dim_; + auto dst_ptr = pos_embed_i_ptr + static_cast(patch_idx) * embed_dim_; + std::memcpy(dst_ptr, src_ptr, static_cast(embed_dim_) * sizeof(float)); patch_idx++; } } @@ -270,18 +292,22 @@ class Resampler : public nn::Module { } Tensor pos_embed_padded = Tensor::empty({batch_size, max_patch_len, embed_dim_}, kFloat32).alloc(); + auto* pos_embed_padded_ptr = pos_embed_padded.ptr(); for (int i = 0; i < batch_size; i++) { - auto& pos_embed_i = pos_embed_list[i]; + auto pos_embed_i = pos_embed_list[i].contiguous(); int actual_len = pos_embed_i.shape()[0]; + const auto* pos_embed_i_ptr = pos_embed_i.ptr(); + auto* pos_embed_padded_batch_ptr = pos_embed_padded_ptr + static_cast(i) * max_patch_len * embed_dim_; - for (int j = 0; j < actual_len && j < max_patch_len; j++) { - for (int k = 0; k < embed_dim_; k++) { - *pos_embed_padded.offsettedPtr({i, j, k}) = pos_embed_i.at({j, k}); - } + auto rows_to_copy = std::min(actual_len, max_patch_len); + if (rows_to_copy > 0) { + std::memcpy(pos_embed_padded_batch_ptr, pos_embed_i_ptr, + static_cast(rows_to_copy) * embed_dim_ * sizeof(float)); } - for (int j = actual_len; j < max_patch_len; j++) { - for (int k = 0; k < embed_dim_; k++) { *pos_embed_padded.offsettedPtr({i, j, k}) = 0.0f; } + if (rows_to_copy < max_patch_len) { + std::memset(pos_embed_padded_batch_ptr + static_cast(rows_to_copy) * embed_dim_, 0, + static_cast(max_patch_len - rows_to_copy) * embed_dim_ * sizeof(float)); } } @@ -315,13 +341,7 @@ class Resampler : public nn::Module { // key_padding_mask for this batch Tensor key_padding_mask_b = key_padding_mask[{b, kAll}].view({max_patch_len}); - bool has_padding = false; - for (int i = 0; i < seq_len; i++) { - if (key_padding_mask_b.at({i}) == 1) { 
- has_padding = true; - break; - } - } + bool has_padding = patch_len[b] < seq_len; auto attn_output = has_padding ? attn_(q, kv_input, x_b, key_padding_mask_b)[0] : attn_(q, kv_input, x_b)[0]; diff --git a/mllm/models/minicpm_o2_6/modeling_siglip.hpp b/mllm/models/minicpm_o2_6/modeling_siglip.hpp index 30750201b..deb5deb3b 100644 --- a/mllm/models/minicpm_o2_6/modeling_siglip.hpp +++ b/mllm/models/minicpm_o2_6/modeling_siglip.hpp @@ -50,10 +50,14 @@ class SiglipVisionEmbeddings final : public nn::Module { // Create position embeddings if (!tgt_sizes.isNil() && !patch_attention_mask.isNil()) { + if (!tgt_sizes.isContiguous()) { tgt_sizes = tgt_sizes.contiguous(); } + if (!patch_attention_mask.isContiguous()) { patch_attention_mask = patch_attention_mask.contiguous(); } auto max_im_h = pixel_values.shape()[2]; auto max_im_w = pixel_values.shape()[3]; auto max_nb_patches_h = max_im_h / patch_size_; auto max_nb_patches_w = max_im_w / patch_size_; + const auto* tgt_sizes_ptr = tgt_sizes.ptr(); + const auto* patch_mask_ptr = patch_attention_mask.ptr(); // Create boundaries like torch.arange(1 / self.num_patches_per_side, 1.0, 1 / self.num_patches_per_side) std::vector boundaries; @@ -63,10 +67,8 @@ class SiglipVisionEmbeddings final : public nn::Module { // Create position_ids tensor - using the max_patches from patch_attention_mask shape auto max_patches = patch_attention_mask.shape()[2]; auto position_ids = Tensor::empty({batch_size, max_patches}, kInt64).alloc(); - // Initialize to zeros - for (int b = 0; b < batch_size; b++) { - for (int p = 0; p < max_patches; p++) { position_ids.at({b, p}) = 0; } - } + std::memset(position_ids.ptr(), 0, static_cast(batch_size) * max_patches * sizeof(int64_t)); + auto* position_ids_ptr = position_ids.ptr(); // Fill position ids based on patch grid and attention mask for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { @@ -74,8 +76,8 @@ class SiglipVisionEmbeddings final : public nn::Module { int nb_patches_w = 
max_nb_patches_w; if (tgt_sizes.shape().size() == 2 && batch_idx < tgt_sizes.shape()[0]) { - nb_patches_h = tgt_sizes.at({batch_idx, 0}); - nb_patches_w = tgt_sizes.at({batch_idx, 1}); + nb_patches_h = tgt_sizes_ptr[batch_idx * 2]; + nb_patches_w = tgt_sizes_ptr[batch_idx * 2 + 1]; } // Create fractional coordinates like torch.arange(0, 1 - 1e-6, 1 / nb_patches_h/w) @@ -132,10 +134,11 @@ class SiglipVisionEmbeddings final : public nn::Module { // Apply pos_ids only where patch_attention_mask is True (now it's 1D) // position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids int pos_ids_idx = 0; + auto patch_mask_batch_ptr = patch_mask_ptr + static_cast(batch_idx) * max_patches; + auto position_ids_batch_ptr = position_ids_ptr + static_cast(batch_idx) * max_patches; for (int flat_idx = 0; flat_idx < max_patches; ++flat_idx) { - uint8_t mask_val = patch_attention_mask.at({batch_idx, 0, flat_idx}); - if (mask_val && pos_ids_idx < pos_ids.size()) { - position_ids.at({batch_idx, flat_idx}) = pos_ids[pos_ids_idx]; + if (patch_mask_batch_ptr[flat_idx] && pos_ids_idx < pos_ids.size()) { + position_ids_batch_ptr[flat_idx] = pos_ids[pos_ids_idx]; pos_ids_idx++; } } @@ -350,20 +353,28 @@ class SiglipVisionModel final : public nn::Module { auto batch_size = pixel_values.shape()[0]; int max_patches = 0; // Calculate max_patches based on tgt_sizes + if (!tgt_sizes.isContiguous()) { tgt_sizes = tgt_sizes.contiguous(); } + const auto* tgt_sizes_ptr = tgt_sizes.ptr(); for (int i = 0; i < tgt_sizes.shape()[0]; i++) { - if (tgt_sizes.at({i, 0}) > 0 && tgt_sizes.at({i, 1}) > 0) { - int patches = (tgt_sizes.at({i, 0})) * (tgt_sizes.at({i, 1})); + auto tgt_h = tgt_sizes_ptr[i * 2]; + auto tgt_w = tgt_sizes_ptr[i * 2 + 1]; + if (tgt_h > 0 && tgt_w > 0) { + int patches = tgt_h * tgt_w; if (patches > max_patches) max_patches = patches; } } auto patch_attention_mask = Tensor::empty({batch_size, 1, max_patches}, kUInt8).alloc(); + auto* patch_attention_mask_ptr = 
patch_attention_mask.ptr(); for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < max_patches; j++) { patch_attention_mask.at({i, 0, j}) = 0; } + auto* patch_attention_mask_batch_ptr = patch_attention_mask_ptr + static_cast(i) * max_patches; + std::memset(patch_attention_mask_batch_ptr, 0, static_cast(max_patches)); if (!tgt_sizes.isNil() && i < tgt_sizes.shape()[0]) { - int nb_patches_h = tgt_sizes.at({i, 0}); - int nb_patches_w = tgt_sizes.at({i, 1}); + int nb_patches_h = tgt_sizes_ptr[i * 2]; + int nb_patches_w = tgt_sizes_ptr[i * 2 + 1]; int valid_patches = nb_patches_h * nb_patches_w; - for (int j = 0; j < valid_patches && j < max_patches; j++) { patch_attention_mask.at({i, 0, j}) = 1; } + if (valid_patches > 0) { + std::memset(patch_attention_mask_batch_ptr, 1, static_cast(std::min(valid_patches, max_patches))); + } } } std::vector hidden_states_result; @@ -374,7 +385,7 @@ class SiglipVisionModel final : public nn::Module { } auto hidden_states = hidden_states_result[0]; // [B, num_patches, embed_dim] - patch_attention_mask = patch_attention_mask.squeeze(1); // [B, max_patches] + patch_attention_mask = patch_attention_mask.squeeze(1).contiguous(); // [B, max_patches] // Create attention mask for encoder (4D mask for multi-head attention) // TODO: this will take about 100ms, optimize it @@ -382,42 +393,38 @@ class SiglipVisionModel final : public nn::Module { if (!patch_attention_mask.isNil()) { auto batch_size = patch_attention_mask.shape()[0]; auto max_patches = patch_attention_mask.shape()[1]; + const auto* patch_mask_ptr = patch_attention_mask.ptr(); bool all_valid = true; for (int i = 0; i < batch_size && all_valid; i++) { + auto patch_mask_batch_ptr = patch_mask_ptr + static_cast(i) * max_patches; for (int j = 0; j < max_patches && all_valid; j++) { - uint8_t mask_val = patch_attention_mask.at({i, j}); - if (mask_val == 0) { all_valid = false; } + if (patch_mask_batch_ptr[j] == 0) { all_valid = false; } } } if (!all_valid) { - // Convert 
patch_attention_mask to float and create 4D attention mask - auto patch_mask_float = Tensor::empty({batch_size, max_patches}, kFloat32).alloc(); - for (int i = 0; i < batch_size; i++) { - for (int j = 0; j < max_patches; j++) { - uint8_t mask_val = patch_attention_mask.at({i, j}); - patch_mask_float.at({i, j}) = mask_val ? 1.0f : 0.0f; - } - } - // Create 4D attention mask: [B, 1, max_patches, max_patches] attention_mask = Tensor::empty({batch_size, 1, max_patches, max_patches}, kFloat32).alloc(); + auto* attention_mask_ptr = attention_mask.ptr(); // Optimize with cache-friendly access patterns and reduced redundant accesses for (int b = 0; b < batch_size; b++) { // Pre-fetch mask values for this batch to improve cache locality - std::vector batch_mask(max_patches); - for (int p = 0; p < max_patches; p++) { batch_mask[p] = patch_mask_float.at({b, p}); } + std::vector batch_mask(max_patches); + std::memcpy(batch_mask.data(), patch_mask_ptr + static_cast(b) * max_patches, static_cast(max_patches)); // Compute attention mask for this batch with optimized memory access + auto* attention_mask_batch_ptr = + attention_mask_ptr + static_cast(b) * max_patches * max_patches; for (int i = 0; i < max_patches; i++) { - float mask_i = batch_mask[i]; + uint8_t mask_i = batch_mask[i]; + auto* attention_mask_row_ptr = attention_mask_batch_ptr + static_cast(i) * max_patches; // Process row in chunks for better cache utilization for (int j = 0; j < max_patches; j++) { - float mask_j = batch_mask[j]; + uint8_t mask_j = batch_mask[j]; // Both positions must be valid (branchless computation) float final_mask = (mask_i > 0.0f && mask_j > 0.0f) ? 
0.0f : -1e9f; - attention_mask.at({b, 0, i, j}) = final_mask; + attention_mask_row_ptr[j] = final_mask; } } } diff --git a/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp b/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp new file mode 100644 index 000000000..92d9106b8 --- /dev/null +++ b/mllm/models/minicpm_o45/configuration_minicpm_o45.hpp @@ -0,0 +1,177 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45Config : protected ConfigFile { + MiniCPMO45Config() = default; + + explicit MiniCPMO45Config(const std::string& file_path) : ConfigFile(file_path) { + auto& cfg = data(); + auto get_or = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + return cfg.contains(key) ? cfg[key].get() : fallback; + }; + + auto vision_cfg = cfg.contains("vision_config") ? cfg["vision_config"] : nlohmann::json::object(); + auto audio_cfg = cfg.contains("audio_config") ? cfg["audio_config"] : nlohmann::json::object(); + auto tts_cfg = cfg.contains("tts_config") ? cfg["tts_config"] : nlohmann::json::object(); + + auto get_vision = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (vision_cfg.contains(key)) { return vision_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + auto get_audio = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (audio_cfg.contains(key)) { return audio_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + auto get_tts = [&](const std::string& key, auto fallback) { + using T = decltype(fallback); + if (tts_cfg.contains(key)) { return tts_cfg[key].get(); } + if (cfg.contains(key)) { return cfg[key].get(); } + return fallback; + }; + + // Vision config. 
+ vision_hidden_size = get_vision("vision_hidden_size", get_vision("hidden_size", vision_hidden_size)); + vision_intermediate_size = get_vision("vision_intermediate_size", get_vision("intermediate_size", vision_intermediate_size)); + vision_num_hidden_layers = get_vision("vision_num_hidden_layers", get_vision("num_hidden_layers", vision_num_hidden_layers)); + vision_num_attention_heads = get_vision("vision_num_attention_heads", get_vision("num_attention_heads", vision_num_attention_heads)); + vision_num_channels = get_vision("vision_num_channels", get_vision("num_channels", vision_num_channels)); + vision_image_size = get_vision("vision_image_size", get_vision("image_size", vision_image_size)); + vision_patch_size = get_vision("vision_patch_size", get_vision("patch_size", vision_patch_size)); + + // LLM config (Qwen3). + attention_bias = get_or("attention_bias", attention_bias); + hidden_size = get_or("hidden_size", hidden_size); + num_attention_heads = get_or("num_attention_heads", num_attention_heads); + num_key_value_heads = get_or("num_key_value_heads", num_key_value_heads); + head_dim = get_or("head_dim", hidden_size / std::max(num_attention_heads, 1)); + intermediate_size = get_or("intermediate_size", intermediate_size); + num_hidden_layers = get_or("num_hidden_layers", num_hidden_layers); + max_position_embeddings = get_or("max_position_embeddings", max_position_embeddings); + rms_norm_eps = get_or("rms_norm_eps", rms_norm_eps); + vocab_size = get_or("vocab_size", vocab_size); + + // Resampler config. + query_num = get_or("query_num", query_num); + + // Audio config (Whisper encoder). 
+ audio_hidden_size = get_audio("audio_hidden_size", get_audio("d_model", audio_hidden_size)); + audio_num_hidden_layers = get_audio("audio_num_hidden_layers", get_audio("num_hidden_layers", audio_num_hidden_layers)); + audio_num_attention_heads = get_audio("audio_num_attention_heads", get_audio("encoder_attention_heads", audio_num_attention_heads)); + audio_max_position_embeddings = + get_audio("audio_max_position_embeddings", get_audio("max_source_positions", audio_max_position_embeddings)); + audio_chunk_length = get_audio("audio_chunk_length", audio_chunk_length); + audio_pool_step = get_or("audio_pool_step", audio_pool_step); + + // TTS config (token generation stage). + tts_llm_dim = get_tts("tts_llm_dim", get_tts("llm_dim", tts_llm_dim)); + tts_llm_intermediate_size = get_tts("tts_llm_intermediate_size", get_tts("llm_intermediate_size", tts_llm_intermediate_size)); + tts_hidden_size = get_tts("tts_hidden_size", get_tts("hidden_size", tts_hidden_size)); + tts_intermediate_size = get_tts("tts_intermediate_size", get_tts("intermediate_size", tts_intermediate_size)); + tts_num_attention_heads = get_tts("tts_num_attention_heads", get_tts("num_attention_heads", tts_num_attention_heads)); + tts_num_key_value_heads = get_tts("tts_num_key_value_heads", get_tts("num_key_value_heads", tts_num_key_value_heads)); + tts_num_hidden_layers = get_tts("tts_num_hidden_layers", get_tts("num_hidden_layers", tts_num_hidden_layers)); + tts_max_position_embeddings = get_tts("tts_max_position_embeddings", get_tts("max_position_embeddings", tts_max_position_embeddings)); + tts_num_audio_tokens = get_tts("tts_num_audio_tokens", get_tts("num_audio_tokens", tts_num_audio_tokens)); + tts_num_text_tokens = get_tts("tts_num_text_tokens", get_tts("num_text_tokens", tts_num_text_tokens)); + tts_num_vq = get_tts("tts_num_vq", get_tts("num_vq", tts_num_vq)); + tts_audio_bos_token_id = get_tts("tts_audio_bos_token_id", get_tts("audio_bos_token_id", tts_audio_bos_token_id)); + 
tts_text_eos_token_id = get_tts("tts_text_eos_token_id", get_tts("text_eos_token_id", tts_text_eos_token_id)); + tts_backbone_vocab_size = tts_cfg.contains("vocab_size") ? tts_cfg["vocab_size"].get() : tts_backbone_vocab_size; + tts_rms_norm_eps = get_tts("tts_rms_norm_eps", get_tts("rms_norm_eps", tts_rms_norm_eps)); + tts_rope_theta = get_tts("tts_rope_theta", get_tts("rope_theta", tts_rope_theta)); + tts_hidden_act = get_tts("tts_hidden_act", get_tts("hidden_act", tts_hidden_act)); + tts_projector_type = get_tts("tts_projector_type", get_tts("projector_type", tts_projector_type)); + tts_condition_type = get_tts("tts_condition_type", get_tts("condition_type", tts_condition_type)); + tts_normalize_projected_hidden = get_tts("tts_normalize_projected_hidden", get_tts("normalize_projected_hidden", tts_normalize_projected_hidden)); + + // Common config. + max_cache_length = get_or("max_cache_length", max_cache_length); + eos_token_id = get_or("eos_token_id", eos_token_id); + bos_token_id = get_or("bos_token_id", bos_token_id); + rope_theta = get_or("rope_theta", rope_theta); + tie_word_embeddings = get_or("tie_word_embeddings", tie_word_embeddings); + + linear_impl_type = cfg.contains("linear_impl_type") ? aops::str2LinearImplTypes(cfg["linear_impl_type"]) : linear_impl_type; + } + + // Vision config (SigLIP). + int32_t vision_hidden_size = 1152; + int32_t vision_intermediate_size = 4304; + int32_t vision_num_hidden_layers = 27; + int32_t vision_num_attention_heads = 16; + int32_t vision_num_channels = 3; + int32_t vision_image_size = 980; + int32_t vision_patch_size = 14; + + // LLM config (Qwen3-8B). + bool attention_bias = false; + int32_t hidden_size = 4096; + int32_t head_dim = 128; + int32_t intermediate_size = 12288; + int32_t num_attention_heads = 32; + int32_t num_key_value_heads = 8; + int32_t num_hidden_layers = 36; + int32_t max_position_embeddings = 40960; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 151748; + + // Resampler config. 
+ int32_t query_num = 64; + + // Audio config (Whisper-medium). + int32_t audio_hidden_size = 1024; + int32_t audio_num_hidden_layers = 24; + int32_t audio_num_attention_heads = 16; + int32_t audio_max_position_embeddings = 1500; + float audio_chunk_length = 1.0f; + int32_t audio_pool_step = 5; + + // TTS config (MiniCPMTTS in MiniCPM-o-4_5). + int32_t tts_llm_dim = 4096; + int32_t tts_llm_intermediate_size = 768; + int32_t tts_hidden_size = 768; + int32_t tts_intermediate_size = 3072; + int32_t tts_num_attention_heads = 12; + int32_t tts_num_key_value_heads = 12; + int32_t tts_num_hidden_layers = 20; + int32_t tts_max_position_embeddings = 4096; + int32_t tts_num_audio_tokens = 6562; + int32_t tts_num_text_tokens = 152064; + int32_t tts_num_vq = 1; + int32_t tts_audio_bos_token_id = 151687; + int32_t tts_text_eos_token_id = 151692; + int32_t tts_backbone_vocab_size = 32000; + float tts_rms_norm_eps = 1e-06f; + float tts_rope_theta = 10000.0f; + std::string tts_hidden_act = "silu"; + std::string tts_projector_type = "mlp"; + std::string tts_condition_type = "hidden_text_merge"; + bool tts_normalize_projected_hidden = true; + + // Common config. + int32_t max_cache_length = 4096; + int64_t eos_token_id = 151645; + int64_t bos_token_id = 151643; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py b/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py new file mode 100644 index 000000000..760869611 --- /dev/null +++ b/mllm/models/minicpm_o45/convert_token2wav_pt_to_mllm.py @@ -0,0 +1,414 @@ +#!/usr/bin/env python3 +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +"""Lightweight MiniCPM-o-4_5 token2wav converter. + +This script merges `flow.pt` + `hift.pt` into one `.mllm` file without +depending on `pymllm`/`tvm_ffi`. 
+""" + +from __future__ import annotations + +import argparse +import gc +import os +import struct +from dataclasses import dataclass +from pathlib import Path +from typing import Dict, Iterable, List, Tuple + +try: + import torch +except ImportError as exc: + raise ImportError("PyTorch is required. Please install torch in your Python env.") from exc + + +# ----------------------------- MLLM constants ----------------------------- # + +MLLM_MODEL_FILE_V1_MAGIC_NUMBER = 20012 +MLLM_MODEL_FILE_V2_MAGIC_NUMBER = 0x519A +MLLM_MODEL_FILE_V2_VERSION = 2 +MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH = 512 +MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH = 256 +MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH = 16 + +MODEL_FILE_V2_DESC_SIZE = 532 +MODEL_FILE_V2_PARAM_DESC_SIZE = 352 + + +def _build_torch_type_mapping() -> Dict[torch.dtype, int]: + mapping = { + torch.float32: 0, # kFloat32 + torch.float16: 1, # kFloat16 + torch.bfloat16: 128, # kBFloat16 + torch.int8: 16, # kInt8 + torch.int16: 17, # kInt16 + torch.int32: 18, # kInt32 + torch.int64: 132, # kInt64 + torch.uint8: 129, # kUInt8 + torch.bool: 129, # kUInt8 + } + if hasattr(torch, "uint16"): + mapping[torch.uint16] = 130 # kUInt16 + return mapping + + +TORCH_TYPE_MAPPING = _build_torch_type_mapping() + + +# ----------------------------- Helpers ----------------------------- # + + +@dataclass +class TensorMeta: + raw_name: str + full_name: str + dtype_id: int + data_len: int + + +def _load_pt(path: Path) -> Dict[str, torch.Tensor]: + if not path.exists(): + raise FileNotFoundError(f"Checkpoint not found: {path}") + + try: + obj = torch.load(path, map_location="cpu", weights_only=True) + except TypeError: + obj = torch.load(path, map_location="cpu") + + if isinstance(obj, dict): + if obj and all(torch.is_tensor(v) for v in obj.values()): + return obj + for candidate in ("state_dict", "model", "module"): + cand = obj.get(candidate) + if isinstance(cand, dict) and cand and any(torch.is_tensor(v) for v in cand.values()): + return {k: v for k, 
v in cand.items() if torch.is_tensor(v)} + tensor_only = {k: v for k, v in obj.items() if torch.is_tensor(v)} + if tensor_only: + return tensor_only + + raise ValueError(f"Unsupported checkpoint layout: {path}") + + +def _normalized_tensor(t: torch.Tensor) -> torch.Tensor: + x = t.detach().cpu().contiguous() + if x.dim() == 0: + x = x.reshape(1) + return x + + +def _tensor_to_bytes(t: torch.Tensor) -> bytes: + x = _normalized_tensor(t) + return x.view(torch.uint8).numpy().tobytes() + + +def _tensor_dtype_id(dtype: torch.dtype) -> int: + if dtype not in TORCH_TYPE_MAPPING: + raise ValueError(f"Unsupported tensor dtype for .mllm export: {dtype}") + return TORCH_TYPE_MAPPING[dtype] + + +def _collect_source_meta( + ckpt_path: Path, + out_prefix: str, + strip_prefix: str, + preview_limit: int, +) -> List[TensorMeta]: + state = _load_pt(ckpt_path) + keys = list(state.keys()) + print(f"[inspect] {ckpt_path}: {len(keys)} tensors") + + metas: List[TensorMeta] = [] + for i, raw_name in enumerate(keys): + t = state[raw_name] + out_name = raw_name[len(strip_prefix) :] if (strip_prefix and raw_name.startswith(strip_prefix)) else raw_name + full_name = f"{out_prefix}{out_name}" + x = _normalized_tensor(t) + dtype_id = _tensor_dtype_id(x.dtype) + data_len = int(x.numel()) * int(x.element_size()) + metas.append(TensorMeta(raw_name=raw_name, full_name=full_name, dtype_id=dtype_id, data_len=data_len)) + if i < max(preview_limit, 0): + print(f" - {raw_name} shape={tuple(x.shape)} dtype={x.dtype}") + + del state + gc.collect() + return metas + + +def _check_duplicate_names(metas: Iterable[TensorMeta]) -> None: + seen = set() + for m in metas: + if m.full_name in seen: + raise ValueError(f"Duplicated tensor name after rename: {m.full_name}") + seen.add(m.full_name) + + +def _stream_source_tensors( + ckpt_path: Path, + metas: List[TensorMeta], +) -> Iterable[Tuple[TensorMeta, torch.Tensor]]: + state = _load_pt(ckpt_path) + try: + for m in metas: + if m.raw_name not in state: + raise 
KeyError(f"Tensor missing in checkpoint: {ckpt_path} -> {m.raw_name}") + t = _normalized_tensor(state[m.raw_name]) + yield m, t + finally: + del state + gc.collect() + + +# ----------------------------- V1 writer ----------------------------- # + + +def _write_v1( + output: Path, + model_name: str, + flow_pt: Path, + flow_metas: List[TensorMeta], + hift_pt: Path, + hift_metas: List[TensorMeta], +) -> None: + del model_name # v1 header has no model name + + all_metas = flow_metas + hift_metas + _check_duplicate_names(all_metas) + + desc_size = 0 + for m in all_metas: + name_bytes = m.full_name.encode("utf-8") + desc_size += 4 + len(name_bytes) + 8 + 8 + 4 + + output.parent.mkdir(parents=True, exist_ok=True) + with open(output, "wb") as f: + f.write(struct.pack(" {output}") + + +# ----------------------------- V2 writer ----------------------------- # + + +def _pack_v2_file_desc(model_name: str, num_params: int) -> bytes: + name_bytes = model_name.encode("utf-8") + name_bytes = name_bytes.ljust(MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH, b"\0")[:MLLM_MODEL_FILE_V2_MODEL_NAME_LENGTH] + return struct.pack( + f" bytes: + if len(shape) > MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH: + raise ValueError(f"Tensor rank > {MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH} is not supported: {name}") + + shape_padded = list(shape) + [0] * (MLLM_MODEL_FILE_V2_TENSOR_SHAPE_LENGTH - len(shape)) + name_bytes = name.encode("utf-8") + name_bytes = name_bytes.ljust(MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH, b"\0")[:MLLM_MODEL_FILE_V2_PARAMS_NAME_LENGTH] + return struct.pack( + f" None: + if self.num_params >= self.max_params: + raise ValueError(f"Descriptor buffer exceeded: {self.num_params} >= {self.max_params}") + + dtype_id = _tensor_dtype_id(tensor.dtype) + shape = tuple(int(v) for v in tensor.shape) + data = _tensor_to_bytes(tensor) + data_offset = self.f.tell() + data_len = len(data) + + self.f.write(data) + + desc_off = MODEL_FILE_V2_DESC_SIZE + self.num_params * MODEL_FILE_V2_PARAM_DESC_SIZE + 
self.f.seek(desc_off, os.SEEK_SET) + self.f.write( + _pack_v2_param_desc( + param_id=self.num_params, + param_type=dtype_id, + param_size=data_len, + param_offset=data_offset, + shape=shape, + name=name, + ) + ) + self.f.seek(0, os.SEEK_END) + self.num_params += 1 + + def finalize(self) -> None: + self.f.seek(0, os.SEEK_SET) + self.f.write(_pack_v2_file_desc(self.model_name, self.num_params)) + self.f.flush() + + def close(self) -> None: + if not self.f.closed: + self.f.close() + + +def _write_v2( + output: Path, + model_name: str, + flow_pt: Path, + flow_metas: List[TensorMeta], + hift_pt: Path, + hift_metas: List[TensorMeta], + max_param_desc: int, +) -> None: + all_metas = flow_metas + hift_metas + _check_duplicate_names(all_metas) + + if max_param_desc <= 0: + max_param_desc = len(all_metas) + if max_param_desc < len(all_metas): + raise ValueError(f"--max-param-desc ({max_param_desc}) < total tensors ({len(all_metas)})") + + output.parent.mkdir(parents=True, exist_ok=True) + writer = _V2StreamingWriter(output=output, model_name=model_name, max_params=max_param_desc) + written = 0 + try: + for m, t in _stream_source_tensors(flow_pt, flow_metas): + writer.write_tensor(m.full_name, t) + written += 1 + for m, t in _stream_source_tensors(hift_pt, hift_metas): + writer.write_tensor(m.full_name, t) + written += 1 + writer.finalize() + finally: + writer.close() + + print(f"[done:v2] wrote {written} tensors -> {output}") + + +# ----------------------------- Main ----------------------------- # + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Convert MiniCPM-o-4_5 token2wav flow.pt + hift.pt into one .mllm file." 
+ ) + parser.add_argument( + "--flow-pt", + default="mllm/models/minicpm_o45/python_src_code/assets/token2wav/flow.pt", + help="Path to flow.pt", + ) + parser.add_argument( + "--hift-pt", + default="mllm/models/minicpm_o45/python_src_code/assets/token2wav/hift.pt", + help="Path to hift.pt", + ) + parser.add_argument("--output", required=True, help="Output .mllm path") + parser.add_argument("--model-name", default="minicpm_o45_token2wav", help="Model name (used in v2 header)") + parser.add_argument("--format", choices=["v1", "v2"], default="v1", help="Output model format") + parser.add_argument("--flow-prefix", default="token2wav.flow_model.", help="Prefix for flow tensor names") + parser.add_argument("--hift-prefix", default="token2wav.hift_model.", help="Prefix for hift tensor names") + parser.add_argument( + "--strip-hift-prefix", + default="generator.", + help="Strip this prefix from hift tensor names before adding --hift-prefix", + ) + parser.add_argument( + "--max-param-desc", + type=int, + default=0, + help="Only for v2: max descriptor buffer size, 0 means auto", + ) + parser.add_argument("--inspect-only", action="store_true", help="Only inspect checkpoints and quit") + parser.add_argument("--preview-limit", type=int, default=8, help="How many tensors to print per checkpoint") + args = parser.parse_args() + + flow_pt = Path(args.flow_pt).expanduser().resolve() + hift_pt = Path(args.hift_pt).expanduser().resolve() + output = Path(args.output).expanduser().resolve() + + flow_metas = _collect_source_meta(flow_pt, args.flow_prefix, "", args.preview_limit) + hift_metas = _collect_source_meta(hift_pt, args.hift_prefix, args.strip_hift_prefix, args.preview_limit) + + total = len(flow_metas) + len(hift_metas) + print(f"[count] flow={len(flow_metas)}, hift={len(hift_metas)}, total={total}") + + if args.inspect_only: + print("[inspect-only] done") + return + + if args.format == "v1": + _write_v1( + output=output, + model_name=args.model_name, + flow_pt=flow_pt, + 
flow_metas=flow_metas, + hift_pt=hift_pt, + hift_metas=hift_metas, + ) + else: + _write_v2( + output=output, + model_name=args.model_name, + flow_pt=flow_pt, + flow_metas=flow_metas, + hift_pt=hift_pt, + hift_metas=hift_metas, + max_param_desc=args.max_param_desc, + ) + + +if __name__ == "__main__": + main() diff --git a/mllm/models/minicpm_o45/export_prompt_cache.py b/mllm/models/minicpm_o45/export_prompt_cache.py new file mode 100644 index 000000000..041cc0c3f --- /dev/null +++ b/mllm/models/minicpm_o45/export_prompt_cache.py @@ -0,0 +1,241 @@ +#!/usr/bin/env python3 +# Copyright (c) MLLM Team. +# Licensed under the MIT License. + +"""Export fixed MiniCPM-o-4_5 token2wav prompt cache for native C++ runtime. + +This script extracts prompt_speech_tokens / prompt_mels / speaker_embedding from +one reference wav and writes a compact binary cache. +""" + +from __future__ import annotations + +import argparse +import struct +import sys +import time +import types +from pathlib import Path + +import numpy as np + + +def _setup_cosyvoice2_alias() -> None: + if "cosyvoice2.flow.flow" in sys.modules: + return + + import stepaudio2.cosyvoice2.flow.decoder_dit as _step_decoder_dit + import stepaudio2.cosyvoice2.flow.flow as _step_flow + import stepaudio2.cosyvoice2.flow.flow_matching as _step_flow_matching + import stepaudio2.cosyvoice2.transformer.upsample_encoder_v2 as _step_upsample + + cosyvoice2_pkg = types.ModuleType("cosyvoice2") + cosyvoice2_flow_pkg = types.ModuleType("cosyvoice2.flow") + cosyvoice2_transformer_pkg = types.ModuleType("cosyvoice2.transformer") + + cosyvoice2_flow_pkg.flow = _step_flow + cosyvoice2_flow_pkg.flow_matching = _step_flow_matching + cosyvoice2_flow_pkg.decoder_dit = _step_decoder_dit + cosyvoice2_transformer_pkg.upsample_encoder_v2 = _step_upsample + + cosyvoice2_pkg.flow = cosyvoice2_flow_pkg + cosyvoice2_pkg.transformer = cosyvoice2_transformer_pkg + + sys.modules["cosyvoice2"] = cosyvoice2_pkg + sys.modules["cosyvoice2.flow"] = 
cosyvoice2_flow_pkg + sys.modules["cosyvoice2.flow.flow"] = _step_flow + sys.modules["cosyvoice2.flow.flow_matching"] = _step_flow_matching + sys.modules["cosyvoice2.flow.decoder_dit"] = _step_decoder_dit + sys.modules["cosyvoice2.transformer"] = cosyvoice2_transformer_pkg + sys.modules["cosyvoice2.transformer.upsample_encoder_v2"] = _step_upsample + + +def _resolve_device(torch_mod, req: str): + req = req.lower() + if req == "cuda": + if not torch_mod.cuda.is_available(): + raise RuntimeError("Requested --device=cuda but CUDA is unavailable") + return torch_mod.device("cuda") + if req == "mps": + if not getattr(torch_mod.backends, "mps", None) or not torch_mod.backends.mps.is_available(): + raise RuntimeError("Requested --device=mps but MPS is unavailable") + return torch_mod.device("mps") + if req == "cpu": + return torch_mod.device("cpu") + if req != "auto": + raise ValueError(f"Unsupported --device: {req}") + + if torch_mod.cuda.is_available(): + return torch_mod.device("cuda") + if getattr(torch_mod.backends, "mps", None) and torch_mod.backends.mps.is_available(): + return torch_mod.device("mps") + return torch_mod.device("cpu") + + +def _move_model(model, device): + if device.type == "cuda" and hasattr(model, "cuda"): + return model.cuda() + if device.type == "cpu" and hasattr(model, "cpu"): + return model.cpu() + if hasattr(model, "to"): + return model.to(device) + return model + + +class _StageLogger: + def __init__(self, verbose: bool): + self.verbose = verbose + self.t0 = time.time() + + def log(self, msg: str) -> None: + if not self.verbose: + return + dt = time.time() - self.t0 + print(f"[prompt-cache +{dt:.3f}s] {msg}", flush=True) + + +def _write_cache(path: Path, prompt_tokens: np.ndarray, prompt_mels: np.ndarray, spk_emb: np.ndarray) -> None: + # File layout (little-endian): + # magic[8] = "M45PC1\\0\\0" + # u32 version = 1 + # i32 prompt_token_len + # i32 prompt_mel_frames + # i32 mel_dim + # i32 spk_dim + # i32[prompt_token_len] + # 
f32[prompt_mel_frames * mel_dim] + # f32[spk_dim] + magic = b"M45PC1\0\0" + version = 1 + token_len = int(prompt_tokens.shape[0]) + mel_frames = int(prompt_mels.shape[0]) + mel_dim = int(prompt_mels.shape[1]) + spk_dim = int(spk_emb.shape[0]) + + path.parent.mkdir(parents=True, exist_ok=True) + with open(path, "wb") as f: + f.write(magic) + f.write(struct.pack(" None: + parser = argparse.ArgumentParser(description="Export MiniCPM-o-4_5 fixed prompt cache for native C++ token2wav") + parser.add_argument("--ref_wav", required=True, help="Reference wav path used for voice style") + parser.add_argument("--token2wav_dir", required=True, help="Path to assets/token2wav directory") + parser.add_argument("--python_src_root", required=True, help="Path to MiniCPM-o-4_5 python_src_code directory") + parser.add_argument("--out_cache", required=True, help="Output cache path (.bin)") + parser.add_argument("--device", default="auto", choices=["auto", "cpu", "mps", "cuda"], help="Runtime device") + parser.add_argument("--verbose", action="store_true", help="Print detailed stage logs") + args = parser.parse_args() + + python_src_root = Path(args.python_src_root).expanduser().resolve() + ref_wav = Path(args.ref_wav).expanduser().resolve() + token2wav_dir = Path(args.token2wav_dir).expanduser().resolve() + out_cache = Path(args.out_cache).expanduser().resolve() + + if str(python_src_root) not in sys.path: + sys.path.insert(0, str(python_src_root)) + + import onnxruntime + import s3tokenizer + import torch + import torchaudio + import torchaudio.compliance.kaldi as kaldi + from hyperpyyaml import load_hyperpyyaml + from stepaudio2.flashcosyvoice.utils.audio import mel_spectrogram + + logger = _StageLogger(args.verbose) + + logger.log("Resolving runtime device...") + device = _resolve_device(torch, args.device) + print(f"[prompt-cache] device={device}", flush=True) + print(f"[prompt-cache] ref_wav={ref_wav}", flush=True) + print(f"[prompt-cache] token2wav_dir={token2wav_dir}", 
flush=True) + + _setup_cosyvoice2_alias() + + logger.log("Loading speech tokenizer ONNX...") + audio_tokenizer = s3tokenizer.load_model(str(token2wav_dir / "speech_tokenizer_v2_25hz.onnx")) + audio_tokenizer = _move_model(audio_tokenizer, device) + if hasattr(audio_tokenizer, "eval"): + audio_tokenizer = audio_tokenizer.eval() + + logger.log("Loading campplus.onnx...") + option = onnxruntime.SessionOptions() + option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL + option.intra_op_num_threads = 1 + spk_model = onnxruntime.InferenceSession( + str(token2wav_dir / "campplus.onnx"), + sess_options=option, + providers=["CPUExecutionProvider"], + ) + + logger.log("Reading flow.yaml for up_rate...") + with open(token2wav_dir / "flow.yaml", "r", encoding="utf-8") as f: + cfg = load_hyperpyyaml(f) + up_rate = int(cfg["flow"].up_rate) + print(f"[prompt-cache] flow.up_rate={up_rate}", flush=True) + + logger.log("Preparing prompt speech tokens (16k)...") + audio = s3tokenizer.load_audio(str(ref_wav), sr=16000) + mels = s3tokenizer.log_mel_spectrogram(audio) + mels, mels_lens = s3tokenizer.padding([mels]) + + quantize_device = device + try: + prompt_tokens, prompt_tokens_lens = audio_tokenizer.quantize(mels.to(quantize_device), mels_lens.to(quantize_device)) + except Exception: + quantize_device = torch.device("cpu") + audio_tokenizer = _move_model(audio_tokenizer, quantize_device) + if hasattr(audio_tokenizer, "eval"): + audio_tokenizer = audio_tokenizer.eval() + prompt_tokens, prompt_tokens_lens = audio_tokenizer.quantize(mels.to(quantize_device), mels_lens.to(quantize_device)) + + prompt_tokens = prompt_tokens.to(device) + prompt_tokens_lens = prompt_tokens_lens.to(device) + logger.log(f"prompt_tokens shape={tuple(prompt_tokens.shape)}, lens={prompt_tokens_lens.tolist()}") + + logger.log("Preparing speaker embedding...") + spk_feat = kaldi.fbank(audio.unsqueeze(0), num_mel_bins=80, dither=0, sample_frequency=16000) + spk_feat = spk_feat - 
spk_feat.mean(dim=0, keepdim=True) + spk_emb_np = spk_model.run( + None, + {spk_model.get_inputs()[0].name: spk_feat.unsqueeze(dim=0).cpu().numpy()}, + )[0] + spk_emb = torch.tensor(spk_emb_np, device=device, dtype=torch.float32) + logger.log(f"spk_emb shape={tuple(spk_emb.shape)}") + + logger.log("Preparing prompt mel (24k)...") + audio_24k, sample_rate = torchaudio.load(str(ref_wav), backend="soundfile") + audio_24k = audio_24k.mean(dim=0, keepdim=True) + if sample_rate != 24000: + audio_24k = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=24000)(audio_24k) + prompt_mel = mel_spectrogram(audio_24k).transpose(1, 2).squeeze(0) # [T, 80] + prompt_mels = prompt_mel.unsqueeze(0).to(device) + target_len = int(prompt_tokens.shape[1]) * up_rate + if target_len > prompt_mels.shape[1]: + prompt_mels = torch.nn.functional.pad( + prompt_mels, + (0, 0, 0, target_len - prompt_mels.shape[1]), + mode="replicate", + ) + logger.log(f"prompt_mels shape={tuple(prompt_mels.shape)}") + + logger.log("Writing cache...") + token_np = prompt_tokens[0].detach().cpu().numpy().astype(np.int32) + mel_np = prompt_mels[0].detach().cpu().numpy().astype(np.float32) + spk_np = spk_emb[0].detach().cpu().numpy().astype(np.float32) + _write_cache(out_cache, token_np, mel_np, spk_np) + + print(f"[prompt-cache] wrote: {out_cache}", flush=True) + print(f"[prompt-cache] token_len={token_np.shape[0]}, mel_frames={mel_np.shape[0]}, mel_dim={mel_np.shape[1]}, spk_dim={spk_np.shape[0]}", + flush=True) + + +if __name__ == "__main__": + main() + diff --git a/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp b/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp new file mode 100644 index 000000000..156896847 --- /dev/null +++ b/mllm/models/minicpm_o45/modeling_minicpm_o45.hpp @@ -0,0 +1,916 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/llama/configuration_llama.hpp" +#include "mllm/models/llama/modeling_llama.hpp" +#include "mllm/models/minicpm_o2_6/configuration_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/modeling_resampler.hpp" +#include "mllm/models/minicpm_o2_6/modeling_siglip.hpp" +#include "mllm/models/minicpm_o2_6/modeling_whisper_encoder.hpp" +#include "mllm/models/minicpm_o45/configuration_minicpm_o45.hpp" +#include "mllm/models/qwen3/configuration_qwen3.hpp" +#include "mllm/models/qwen3/modeling_qwen3.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +class AudioProjectionLayer final : public nn::Module { + public: + AudioProjectionLayer() = default; + + AudioProjectionLayer(const std::string& name, int32_t input_dim, int32_t hidden_dim, int32_t output_dim) : Module(name) { + linear1_ = reg("linear1", input_dim, hidden_dim, true); + relu_ = reg("relu"); + linear2_ = reg("linear2", hidden_dim, output_dim, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + x = linear1_(x); + x = relu_(x); + x = linear2_(x); + return {x}; + } + + private: + nn::Linear linear1_; + nn::ReLU relu_; + nn::Linear linear2_; +}; + +class AudioAvgPooler final : public nn::Module { + public: + AudioAvgPooler() = default; + + AudioAvgPooler(const std::string& name, int32_t kernel_size, int32_t stride) : Module(name) { + avg_pool_ = reg("pool", kernel_size, stride); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {avg_pool_(inputs[0])}; + } + + private: + nn::AvgPool1d avg_pool_; +}; + 
+class TTSProjector final : public nn::Module { + public: + TTSProjector() = default; + + TTSProjector(const std::string& name, int32_t input_dim, int32_t output_dim) : nn::Module(name) { + linear1_ = reg("linear1", input_dim, output_dim, true); + relu_ = reg("relu"); + linear2_ = reg("linear2", output_dim, output_dim, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = linear1_(inputs[0]); + x = relu_(x); + x = linear2_(x); + return {x}; + } + + private: + nn::Linear linear1_; + nn::ReLU relu_; + nn::Linear linear2_; +}; + +struct MiniCPMO45TTSGenerationConfig { + int32_t max_new_tokens = 1024; + int32_t min_new_tokens = 50; + bool force_no_stop = false; + bool do_sample = true; + int32_t top_k = 25; + float top_p = 0.85f; + float repetition_penalty = 1.05f; + int32_t repetition_penalty_window = 16; + std::vector temperature = {0.8f}; + int32_t debug_interval = 16; + std::function& tokens, bool has_eos)> step_callback = nullptr; +}; + +struct MiniCPMO45TTSGenerationOutput { + Tensor new_ids = Tensor::nil(); + bool finished = false; +}; + +class MiniCPMO45TTS final : public nn::Module { + public: + MiniCPMO45TTS() = default; + + MiniCPMO45TTS(const std::string& name, const MiniCPMO45Config& cfg) : nn::Module(name), cfg_(cfg) { + projector_spk_ = reg("projector_spk", cfg.tts_llm_dim, cfg.tts_hidden_size); + projector_semantic_ = reg("projector_semantic", cfg.tts_llm_dim, cfg.tts_hidden_size); + + emb_text_ = reg("emb_text", cfg.tts_num_text_tokens, cfg.tts_hidden_size); + + emb_code_.reserve(cfg.tts_num_vq); + for (int32_t i = 0; i < cfg.tts_num_vq; ++i) { + emb_code_.emplace_back(reg("emb_code." 
+ std::to_string(i), cfg.tts_num_audio_tokens, cfg.tts_hidden_size)); + } + + auto llama_cfg = llama::LLaMAConfig(); + llama_cfg.vocab_size = cfg.tts_backbone_vocab_size; + llama_cfg.hidden_size = cfg.tts_hidden_size; + llama_cfg.intermediate_size = cfg.tts_intermediate_size; + llama_cfg.num_attention_heads = cfg.tts_num_attention_heads; + llama_cfg.num_key_value_heads = cfg.tts_num_key_value_heads; + llama_cfg.num_hidden_layers = cfg.tts_num_hidden_layers; + llama_cfg.max_position_embeddings = cfg.tts_max_position_embeddings; + llama_cfg.rms_norm_eps = cfg.tts_rms_norm_eps; + llama_cfg.rope_theta = cfg.tts_rope_theta; + llama_cfg.hidden_act = cfg.tts_hidden_act; + llama_cfg.tie_word_embeddings = false; + llama_cfg.attention_bias = false; + llama_cfg.linear_impl_type = cfg.linear_impl_type; + model_ = reg("model", llama_cfg); + } + + void loadFromParameter(const ParameterFile::ptr_t& param_file) { + nn::Module::load(param_file); + + head_code_weight_.clear(); + head_code_weight_.reserve(cfg_.tts_num_vq); + + auto prefix = getModuleName() + ".head_code."; + for (int32_t i = 0; i < cfg_.tts_num_vq; ++i) { + auto g = param_file->pull(prefix + std::to_string(i) + ".parametrizations.weight.original0"); + auto v = param_file->pull(prefix + std::to_string(i) + ".parametrizations.weight.original1"); + if (g.dtype() != kFloat32) { g = g.to(kFloat32); } + if (v.dtype() != kFloat32) { v = v.to(kFloat32); } + g = g.contiguous(); + v = v.contiguous().view({cfg_.tts_num_audio_tokens, cfg_.tts_hidden_size}); + + auto weight = Tensor::empty({cfg_.tts_num_audio_tokens, cfg_.tts_hidden_size}, kFloat32, kCPU).alloc(); + + auto* g_ptr = g.ptr(); + auto* v_ptr = v.ptr(); + auto* w_ptr = weight.ptr(); + + constexpr float kEps = 1e-12f; + for (int32_t out_idx = 0; out_idx < cfg_.tts_num_audio_tokens; ++out_idx) { + float norm = 0.0f; + auto row_offset = out_idx * cfg_.tts_hidden_size; + for (int32_t d = 0; d < cfg_.tts_hidden_size; ++d) { + auto val = v_ptr[row_offset + d]; + norm += val 
* val; + } + norm = std::sqrt(norm); + if (norm < kEps) { norm = kEps; } + auto scale = g_ptr[out_idx] / norm; + for (int32_t d = 0; d < cfg_.tts_hidden_size; ++d) { w_ptr[row_offset + d] = v_ptr[row_offset + d] * scale; } + } + + head_code_weight_.push_back(weight); + } + } + + Tensor makeConditionEmbeddings(const std::vector& text_token_ids, const std::vector& text_hidden_states) { + if (text_token_ids.empty() || text_hidden_states.empty()) { return Tensor::nil(); } + if (text_token_ids.size() != text_hidden_states.size()) { + MLLM_ERROR("MiniCPM-o-4_5 TTS input mismatch: token count {} != hidden count {}.", + text_token_ids.size(), text_hidden_states.size()); + return Tensor::nil(); + } + + Tensor token_ids = Tensor::empty({1, static_cast(text_token_ids.size())}, kInt64, kCPU).alloc(); + auto* token_ids_ptr = token_ids.ptr(); + for (size_t i = 0; i < text_token_ids.size(); ++i) { + auto token_id = text_token_ids[i]; + if (token_id < 0 || token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS text token id out of range: token_id={} valid=[0, {}).", + token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + token_ids_ptr[i] = token_id; + } + + auto llm_embeds = emb_text_(token_ids); + + Tensor hidden_states = text_hidden_states.size() == 1 ? 
text_hidden_states[0] : nn::functional::concat(text_hidden_states, 1); + auto projected_hidden = projector_semantic_(hidden_states)[0]; + if (cfg_.tts_normalize_projected_hidden) { projected_hidden = normalizeProjectedHidden(projected_hidden); } + + auto tts_embeds = llm_embeds + projected_hidden; + + Tensor text_eos = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + text_eos.ptr()[0] = cfg_.tts_text_eos_token_id; + Tensor audio_bos = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + audio_bos.ptr()[0] = cfg_.tts_audio_bos_token_id; + if (cfg_.tts_text_eos_token_id < 0 || cfg_.tts_text_eos_token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS text_eos_token_id out of range: {} (vocab={}).", + cfg_.tts_text_eos_token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + if (cfg_.tts_audio_bos_token_id < 0 || cfg_.tts_audio_bos_token_id >= cfg_.tts_num_text_tokens) { + MLLM_ERROR("MiniCPM-o-4_5 TTS audio_bos_token_id out of range: {} (vocab={}).", + cfg_.tts_audio_bos_token_id, cfg_.tts_num_text_tokens); + return Tensor::nil(); + } + + auto text_eos_embed = emb_text_(text_eos); + auto audio_bos_embed = emb_text_(audio_bos); + + return nn::functional::concat({tts_embeds, text_eos_embed, audio_bos_embed}, 1); + } + + MiniCPMO45TTSGenerationOutput generate(const Tensor& condition_embeds, + const MiniCPMO45TTSGenerationConfig& generation_cfg = {}) { + if (condition_embeds.isNil()) { return {}; } + + auto eos_token = cfg_.tts_num_audio_tokens - 1; + + std::vector temperature = generation_cfg.temperature; + if (temperature.empty()) { temperature.assign(cfg_.tts_num_vq, 1.0f); } + if (temperature.size() < static_cast(cfg_.tts_num_vq)) { + temperature.resize(cfg_.tts_num_vq, temperature.back()); + } + + nn::StaticCache kv_cache(cfg_.tts_max_position_embeddings, cfg_.tts_num_hidden_layers, + cfg_.tts_num_attention_heads, // q heads + cfg_.tts_num_key_value_heads, // kv heads + cfg_.tts_hidden_size / cfg_.tts_num_attention_heads, + kFloat32, // k dtype + 
kFloat32, // v dtype + kCPU, // device + false // use fa2 + ); + + Tensor generated = Tensor::zeros({1, generation_cfg.max_new_tokens, cfg_.tts_num_vq}, kInt64, kCPU); + int32_t generated_len = 0; + bool finished = false; + auto condition_length = condition_embeds.shape()[1]; + std::vector> generated_history(cfg_.tts_num_vq); + + for (int32_t t = 0; t < generation_cfg.max_new_tokens; ++t) { + Tensor inputs_embeds = Tensor::nil(); + Tensor position_ids = Tensor::nil(); + + if (t == 0) { + inputs_embeds = condition_embeds; + position_ids = Tensor::empty({1, condition_length}, kInt64, kCPU).alloc(); + auto* position_ids_ptr = position_ids.ptr(); + for (int32_t i = 0; i < condition_length; ++i) { position_ids_ptr[i] = i; } + } else { + for (int32_t q = 0; q < cfg_.tts_num_vq; ++q) { + auto code_ids = generated[{kAll, {t - 1, t}, {q, q + 1}}].contiguous().view({1, 1}); + auto code_embeds = emb_code_[q](code_ids); + if (q == 0) { + inputs_embeds = code_embeds; + } else { + inputs_embeds = inputs_embeds + code_embeds; + } + } + position_ids = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + position_ids.ptr()[0] = condition_length + t - 1; + } + + auto [llm_embedding_sin, llm_embedding_cos] = llama::makeRotaryPosEmbedding(position_ids, model_.getBuffer("inv_freq"), 1.0f); + Tensor causal_mask = Tensor::nil(); + auto* cache_ptr = static_cast(&kv_cache); + auto hidden_states = model_(inputs_embeds, llm_embedding_sin, llm_embedding_cos, causal_mask, AnyValue(cache_ptr))[0]; + + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1, seq_len}, kAll}].contiguous(); + + bool has_eos = false; + std::vector step_tokens; + step_tokens.reserve(cfg_.tts_num_vq); + for (int32_t q = 0; q < cfg_.tts_num_vq; ++q) { + MLLM_RT_ASSERT(q < static_cast(head_code_weight_.size())); + auto logits = nn::functional::matmul(last_hidden, head_code_weight_[q], false, true)[{0, 0, kAll}].contiguous(); + auto temp = std::max(temperature[q], 1e-5f); + logits = 
logits / temp; + + if (t > 0) { + applyRepetitionPenalty(logits, generated_history[q], generation_cfg.repetition_penalty, + generation_cfg.repetition_penalty_window); + applyTopPLogits(logits, generation_cfg.top_p, 3); + applyTopKLogits(logits, generation_cfg.top_k, 3); + } + + if (t < generation_cfg.min_new_tokens || generation_cfg.force_no_stop) { + if (logits.dtype() == kFloat32) { + logits.ptr()[eos_token] = -std::numeric_limits::infinity(); + } else if (logits.dtype() == kFloat16) { + logits.ptr()[eos_token] = -65504.0f; + } + } + + bool use_sampling = generation_cfg.do_sample || generation_cfg.top_k > 0 || generation_cfg.top_p > 0.0f + || std::abs(temp - 1.0f) > 1e-6f; + auto token_id = sampleFromLogits(logits, use_sampling); + *generated.offsettedPtr({0, t, q}) = token_id; + generated_history[q].push_back(token_id); + step_tokens.push_back(token_id); + has_eos = has_eos || token_id == eos_token; + } + + if (generation_cfg.step_callback) { + auto interval = std::max(generation_cfg.debug_interval, 1); + if (t == 0 || ((t + 1) % interval) == 0 || has_eos) { + generation_cfg.step_callback(t + 1, step_tokens, has_eos); + } + } + + generated_len = t + 1; + if (has_eos) { + finished = true; + break; + } + } + + auto out_len = generated_len; + if (finished && out_len > 0) { out_len -= 1; } // do not return terminal token + + Tensor out_ids = Tensor::nil(); + if (out_len > 0) { out_ids = generated[{kAll, {0, out_len}, kAll}].contiguous(); } + return {.new_ids = out_ids, .finished = finished}; + } + + private: + static int64_t argmax1d(const Tensor& logits) { + auto probs = logits; + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + auto* data = probs.ptr(); + auto n = probs.shape().back(); + + auto max_idx = 0; + auto max_value = data[0]; + for (int32_t i = 1; i < n; ++i) { + if (data[i] > max_value) { + max_value = data[i]; + max_idx = i; + } + } + return max_idx; + } + + static int64_t sampleFromLogits(Tensor logits, bool do_sample) { + if 
(logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (!do_sample) { return argmax1d(logits); } + + auto probs = nn::functional::softmax(logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample1d(probs); + } + + static int64_t categoricalSample1d(const Tensor& probs) { + MLLM_RT_ASSERT_EQ(probs.dtype(), kFloat32); + auto* prob_data = probs.ptr(); + auto vocab_size = probs.shape().back(); + + std::vector cumulative_probs(vocab_size); + std::partial_sum(prob_data, prob_data + vocab_size, cumulative_probs.begin()); + + auto total = cumulative_probs.back(); + if (total <= 0.0f) { return argmax1d(probs); } + + static thread_local std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution dist(0.0f, total); + auto target = dist(rng); + + auto it = std::lower_bound(cumulative_probs.begin(), cumulative_probs.end(), target); + if (it == cumulative_probs.end()) { return vocab_size - 1; } + return static_cast(std::distance(cumulative_probs.begin(), it)); + } + + static void applyRepetitionPenalty(Tensor& logits, const std::vector& token_ids, float penalty, + int32_t past_window) { + if (penalty <= 1.0f || token_ids.empty()) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + std::unordered_map frequencies; + + int32_t start = 0; + if (past_window > 0 && static_cast(token_ids.size()) > past_window) { + start = static_cast(token_ids.size()) - past_window; + } + for (int32_t i = start; i < static_cast(token_ids.size()); ++i) { + auto token_id = token_ids[i]; + if (token_id < 0 || token_id >= vocab_size) { continue; } + frequencies[token_id] += 1; + } + + auto* logits_ptr = logits.ptr(); + for (const auto& [token_id, freq] : frequencies) { + auto alpha = std::pow(penalty, static_cast(freq)); + float& value = logits_ptr[token_id]; + value = value < 0.0f ? 
value * alpha : value / alpha; + } + } + + static void applyTopKLogits(Tensor& logits, int32_t top_k, int32_t min_tokens_to_keep) { + if (top_k <= 0) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + int32_t k = std::min(std::max(top_k, min_tokens_to_keep), vocab_size); + if (k >= vocab_size) { return; } + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&logits_ptr](int32_t lhs, int32_t rhs) { return logits_ptr[lhs] > logits_ptr[rhs]; }); + + auto threshold = logits_ptr[indices[k - 1]]; + auto neg_inf = -std::numeric_limits::infinity(); + for (int32_t i = 0; i < vocab_size; ++i) { + if (logits_ptr[i] < threshold) { logits_ptr[i] = neg_inf; } + } + } + + static void applyTopPLogits(Tensor& logits, float top_p, int32_t min_tokens_to_keep) { + if (top_p <= 0.0f || top_p >= 1.0f) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + auto vocab_size = logits.shape().back(); + if (vocab_size <= min_tokens_to_keep) { return; } + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), + [&logits_ptr](int32_t lhs, int32_t rhs) { return logits_ptr[lhs] > logits_ptr[rhs]; }); + + auto max_logit = logits_ptr[indices[0]]; + std::vector probs(vocab_size); + float sum_exp = 0.0f; + for (int32_t i = 0; i < vocab_size; ++i) { + auto prob = std::exp(logits_ptr[indices[i]] - max_logit); + probs[i] = prob; + sum_exp += prob; + } + if (sum_exp <= 0.0f) { return; } + for (auto& p : probs) { p /= sum_exp; } + + int32_t keep = 0; + float cumulative = 0.0f; + for (int32_t i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + keep += 1; + if (cumulative >= top_p && keep >= min_tokens_to_keep) { break; } + } + keep = std::max(keep, 
min_tokens_to_keep); + keep = std::min(keep, vocab_size); + + auto neg_inf = -std::numeric_limits::infinity(); + for (int32_t i = keep; i < vocab_size; ++i) { logits_ptr[indices[i]] = neg_inf; } + } + + static Tensor normalizeProjectedHidden(Tensor hidden_states) { + auto original_dtype = hidden_states.dtype(); + auto normalized = original_dtype == kFloat32 ? hidden_states.contiguous() : hidden_states.to(kFloat32).contiguous(); + + auto B = normalized.shape()[0]; + auto S = normalized.shape()[1]; + auto D = normalized.shape()[2]; + auto* ptr = normalized.ptr(); + + constexpr float kEps = 1e-12f; + for (int32_t b = 0; b < B; ++b) { + for (int32_t s = 0; s < S; ++s) { + auto base = b * S * D + s * D; + float norm = 0.0f; + for (int32_t d = 0; d < D; ++d) { norm += ptr[base + d] * ptr[base + d]; } + norm = std::sqrt(norm); + if (norm < kEps) { norm = kEps; } + for (int32_t d = 0; d < D; ++d) { ptr[base + d] /= norm; } + } + } + + if (original_dtype != kFloat32) { return normalized.to(original_dtype); } + return normalized; + } + + private: + MiniCPMO45Config cfg_; + TTSProjector projector_spk_; + TTSProjector projector_semantic_; + std::vector emb_code_; + nn::Embedding emb_text_; + std::vector head_code_weight_; + llama::LlamaText model_; +}; + +class MiniCPMO45TextModel final : public nn::Module { + public: + MiniCPMO45TextModel() = default; + + MiniCPMO45TextModel(const std::string& name, const MiniCPMO45Config& cfg) : Module(name) { + auto llm_cfg = toQwen3Config(cfg); + decode_blocks_ = reg>("layers", llm_cfg.num_hidden_layers, llm_cfg); + for (auto [idx, block] : enumerate(decode_blocks_.list())) { block.self_attn_.layer_idx_ = idx; } + norm_ = reg("norm", llm_cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", llm_cfg.vocab_size, llm_cfg.hidden_size); + registerBuffer("last_hidden_states", Tensor::nil()); + } + + Tensor embed(const Tensor& input_ids) { return embedding_(input_ids); } + + std::vector forward(const std::vector& inputs, const std::vector& args) 
override { + auto hidden_states = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto kv_cache = args[0].get(); + + for (auto& block : decode_blocks_.list()) { + hidden_states = block(hidden_states, llm_embedding_sin, llm_embedding_cos, AnyValue(kv_cache))[0]; + } + updateBuffer("last_hidden_states", hidden_states); + hidden_states = norm_(hidden_states); + return {hidden_states}; + } + + private: + static qwen3::Qwen3Config toQwen3Config(const MiniCPMO45Config& cfg) { + qwen3::Qwen3Config llm_cfg; + llm_cfg.attention_bias = cfg.attention_bias; + llm_cfg.hidden_size = cfg.hidden_size; + llm_cfg.head_dim = cfg.head_dim; + llm_cfg.intermediate_size = cfg.intermediate_size; + llm_cfg.num_attention_heads = cfg.num_attention_heads; + llm_cfg.num_key_value_heads = cfg.num_key_value_heads; + llm_cfg.num_hidden_layers = cfg.num_hidden_layers; + llm_cfg.max_position_embeddings = cfg.max_position_embeddings; + llm_cfg.rms_norm_eps = cfg.rms_norm_eps; + llm_cfg.vocab_size = cfg.vocab_size; + llm_cfg.bos_token_id = cfg.bos_token_id; + llm_cfg.eos_token_id = cfg.eos_token_id; + llm_cfg.end_of_text_token_id = static_cast(cfg.eos_token_id); + llm_cfg.rope_theta = cfg.rope_theta; + llm_cfg.tie_word_embeddings = cfg.tie_word_embeddings; + llm_cfg.max_cache_length = cfg.max_cache_length; + llm_cfg.linear_impl_type = cfg.linear_impl_type; + return llm_cfg; + } + + private: + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + nn::Embedding embedding_; +}; + +class MiniCPMO45LLM final : public nn::Module { + public: + MiniCPMO45LLM() = default; + + MiniCPMO45LLM(const std::string& name, const MiniCPMO45Config& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + registerBuffer("inv_freq", qwen3::makeRoPEInvFreq(cfg.head_dim, cfg.rope_theta)); + } + + Tensor embed(const Tensor& input_ids) { return model_.embedding_(input_ids); } + + 
Tensor logits(const Tensor& hidden_states) { return lm_head_(hidden_states); } + + Tensor hiddenStates(Tensor& input_embeddings, Tensor& llm_embedding_sin, Tensor& llm_embedding_cos, nn::StaticCache* kv_cache) { + return model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(kv_cache))[0]; + } + + public: + MiniCPMO45TextModel model_; + + private: + nn::Linear lm_head_; +}; + +class MiniCPMO45ForCausalLM : public models::ARGeneration { + public: + struct TextGenerationWithHiddenOutput { + std::vector generated_tokens; + std::vector aligned_tokens; + std::vector aligned_hidden_states; + bool finished = false; + }; + + explicit MiniCPMO45ForCausalLM(const MiniCPMO45Config& config) + : config_(config), + legacy_config_(createLegacyConfig(config)), + llm_("llm", config), + vpm_("vpm", legacy_config_), + resampler_("resampler", config.query_num, config.hidden_size, config.num_attention_heads, config.vision_hidden_size), + apm_("apm", legacy_config_), + audio_projection_layer_("audio_projection_layer", config.audio_hidden_size, config.hidden_size, config.hidden_size), + audio_avg_pooler_("audio_avg_pooler", config.audio_pool_step, config.audio_pool_step), + tts_("tts", config), + kv_cache_(config.max_cache_length, config.num_hidden_layers, + config.num_attention_heads, // q heads + config.num_key_value_heads, // kv heads + config.head_dim, // kv dim + kFloat32, // k dtype + kFloat32, // v dtype + kCPU, // device + false // use fa2 + ) { + eos_token_id_ = static_cast(config.eos_token_id); + max_length_ = config.max_cache_length; + } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& inputs, const ARGenerationArgs& args) override { + Tensor input_ids = Tensor::nil(); + if (inputs.count("input_ids")) { + input_ids = inputs.at("input_ids"); + } else if (inputs.count("sequence")) { + input_ids = inputs.at("sequence"); + } else { + MLLM_ERROR("No input_ids or sequence found in MiniCPM-o-4_5 forward input."); + return {}; + } + + auto 
input_embeddings = llm_.embed(input_ids); + + Tensor prev_position_ids = inputs.count("position_ids") ? inputs.at("position_ids") : Tensor::nil(); + + // Prefill-only multimodal embedding injection. + if (prev_position_ids.isNil()) { + auto pixel_values = inputs.count("pixel_values") ? inputs.at("pixel_values") : Tensor::nil(); + auto tgt_sizes = inputs.count("tgt_sizes") ? inputs.at("tgt_sizes") : Tensor::nil(); + auto image_bounds = inputs.count("image_bounds") ? inputs.at("image_bounds") : Tensor::nil(); + + if (!pixel_values.isNil() && !tgt_sizes.isNil() && !image_bounds.isNil()) { + auto vision_outputs = vpm_(pixel_values, tgt_sizes)[0]; + auto vision_embeddings = resampler_(vision_outputs, tgt_sizes)[0]; + input_embeddings = mergeVisionTextEmbeddings(input_embeddings, vision_embeddings, image_bounds); + } + + auto audio_features = inputs.count("audio_features") ? inputs.at("audio_features") : Tensor::nil(); + auto audio_bounds = inputs.count("audio_bounds") ? inputs.at("audio_bounds") : Tensor::nil(); + + if (!audio_features.isNil() && !audio_bounds.isNil()) { + auto audio_embeddings = encodeAudio(audio_features); + input_embeddings = mergeAudioTextEmbeddings(input_embeddings, audio_embeddings, audio_bounds); + } + } + + Tensor position_ids = makePositionIds(input_embeddings.shape()[1], prev_position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = qwen3::makeRotaryPosEmbedding(position_ids, llm_.getBuffer("inv_freq"), 1.0f); + + auto hidden_states = llm_.hiddenStates(input_embeddings, llm_embedding_sin, llm_embedding_cos, &kv_cache_); + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1, seq_len}, kAll}].contiguous(); + auto logits = llm_.logits(last_hidden); + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + {"last_hidden", last_hidden}, + }; + } + + Tensor encodeAudio(const Tensor& audio_features) { + // 1) Whisper encoder. 
+ auto audio_states = apm_(audio_features)[0]; + + // 2) Project to the LLM hidden space. + auto audio_embeds = audio_projection_layer_(audio_states)[0]; + + // 3) Temporal pooling. + audio_embeds = audio_embeds.transpose(1, 2); + audio_embeds = audio_avg_pooler_(audio_embeds)[0]; + audio_embeds = audio_embeds.transpose(1, 2); + return audio_embeds; + } + + TextGenerationWithHiddenOutput generateTextWithHidden(const ARGenerationOutputPast& initial_inputs, int32_t max_new_tokens, + const std::vector& stop_token_ids, bool do_sample = false, + float temperature = 1.0f, int32_t top_k = 0, float top_p = 0.0f, + const std::function& step_callback = + nullptr) { + TextGenerationWithHiddenOutput result; + + auto current_input = initial_inputs; + bool has_previous_generated = false; + int64_t previous_generated_token = 0; + + for (int32_t i = 0; i < max_new_tokens; ++i) { + auto output = forward(current_input, {}); + + if (has_previous_generated && output.count("last_hidden")) { + result.aligned_tokens.push_back(previous_generated_token); + result.aligned_hidden_states.push_back(output.at("last_hidden").contiguous().clone()); + } + + Tensor logits = output.at("sequence"); + int64_t next_token_id = 0; + if (do_sample || temperature != 1.0f || top_k > 0 || top_p > 0.0f) { + if (top_k > 0) { + next_token_id = sampleTopK(logits, top_k, temperature); + } else if (top_p > 0.0f) { + next_token_id = sampleTopP(logits, top_p, temperature); + } else { + next_token_id = sampleTemperature(logits, temperature); + } + } else { + next_token_id = sampleGreedy(logits); + } + result.generated_tokens.push_back(next_token_id); + if (step_callback) { step_callback(i + 1, next_token_id); } + + if (isStopToken(next_token_id, stop_token_ids)) { + result.finished = true; + break; + } + + current_input = std::move(output); + current_input["sequence"] = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + current_input["sequence"].at({0, 0}) = next_token_id; + + previous_generated_token = next_token_id; + 
has_previous_generated = true; + } + + if (!result.finished && has_previous_generated + && result.aligned_tokens.size() + 1 == result.generated_tokens.size()) { + auto probe_output = forward(current_input, {}); + if (probe_output.count("last_hidden")) { + result.aligned_tokens.push_back(previous_generated_token); + result.aligned_hidden_states.push_back(probe_output.at("last_hidden").contiguous().clone()); + } + } + + return result; + } + + public: + MiniCPMO45Config config_; + minicpmo::MiniCPMOConfig legacy_config_; + + MiniCPMO45LLM llm_; + minicpmo::SiglipVisionModel vpm_; + minicpmo::Resampler resampler_; + + minicpmo::WhisperEncoder apm_; + AudioProjectionLayer audio_projection_layer_; + AudioAvgPooler audio_avg_pooler_; + MiniCPMO45TTS tts_; + + private: + template + static void copyEmbeddingVector(Tensor& dst, const Tensor& src, int32_t dst_batch, int32_t dst_pos, int32_t src_batch, + int32_t src_pos, int32_t hidden_size) { + auto* dst_ptr = dst.offsettedPtr({dst_batch, dst_pos, 0}); + auto* src_ptr = src.coffsettedPtr({src_batch, src_pos, 0}); + std::memcpy(dst_ptr, src_ptr, hidden_size * sizeof(DType)); + } + + static Tensor mergeVisionTextEmbeddings(Tensor& text_embeddings, Tensor& vision_embeddings, const Tensor& image_bounds) { + auto batch_size = text_embeddings.shape()[0]; + auto hidden_size = text_embeddings.shape()[2]; + auto vision_seq_len = vision_embeddings.shape()[1]; + auto num_bounds = std::min(image_bounds.shape()[0], vision_embeddings.shape()[0]); + + if (vision_embeddings.shape()[0] != image_bounds.shape()[0]) { + MLLM_WARN("MiniCPM-o-4_5 vision bound count ({}) != embedding group count ({}). 
Using min={}.", + image_bounds.shape()[0], vision_embeddings.shape()[0], num_bounds); + } + + if (vision_embeddings.dtype() != text_embeddings.dtype()) { vision_embeddings = vision_embeddings.to(text_embeddings.dtype()); } + + for (int32_t b = 0; b < batch_size; ++b) { + auto image_bounds_ptr = image_bounds.ptr(); + for (int32_t bound_idx = 0; bound_idx < num_bounds; ++bound_idx) { + int32_t vision_idx = 0; + auto start_pos = image_bounds_ptr[bound_idx * 2] + 1; + auto end_pos = image_bounds_ptr[bound_idx * 2 + 1] - 1; + + for (int32_t pos = start_pos; pos <= end_pos && vision_idx < vision_seq_len; ++pos, ++vision_idx) { + if (text_embeddings.dtype() == kFloat32) { + copyEmbeddingVector(text_embeddings, vision_embeddings, b, pos, bound_idx, vision_idx, hidden_size); + } else if (text_embeddings.dtype() == kFloat16) { + copyEmbeddingVector(text_embeddings, vision_embeddings, b, pos, bound_idx, vision_idx, hidden_size); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported text embedding dtype in MiniCPM-o-4_5 vision merge."); + } + } + } + } + return text_embeddings; + } + + static Tensor mergeAudioTextEmbeddings(Tensor& text_embeddings, Tensor& audio_embeddings, const Tensor& audio_bounds) { + auto batch_size = text_embeddings.shape()[0]; + auto hidden_size = text_embeddings.shape()[2]; + auto audio_seq_len = audio_embeddings.shape()[1]; + auto num_bounds = std::min(audio_bounds.shape()[0], audio_embeddings.shape()[0]); + + if (audio_embeddings.shape()[0] != audio_bounds.shape()[0]) { + MLLM_WARN("MiniCPM-o-4_5 audio bound count ({}) != embedding group count ({}). 
Using min={}.", + audio_bounds.shape()[0], audio_embeddings.shape()[0], num_bounds); + } + + if (audio_embeddings.dtype() != text_embeddings.dtype()) { audio_embeddings = audio_embeddings.to(text_embeddings.dtype()); } + + for (int32_t b = 0; b < batch_size; ++b) { + auto audio_bounds_ptr = audio_bounds.ptr(); + for (int32_t bound_idx = 0; bound_idx < num_bounds; ++bound_idx) { + int32_t audio_idx = 0; + auto start_pos = audio_bounds_ptr[bound_idx * 2]; + auto end_pos = audio_bounds_ptr[bound_idx * 2 + 1] - 1; + + for (int32_t pos = start_pos; pos <= end_pos && audio_idx < audio_seq_len; ++pos, ++audio_idx) { + if (text_embeddings.dtype() == kFloat32) { + copyEmbeddingVector(text_embeddings, audio_embeddings, b, pos, bound_idx, audio_idx, hidden_size); + } else if (text_embeddings.dtype() == kFloat16) { + copyEmbeddingVector(text_embeddings, audio_embeddings, b, pos, bound_idx, audio_idx, hidden_size); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported text embedding dtype in MiniCPM-o-4_5 audio merge."); + } + } + } + } + return text_embeddings; + } + + Tensor makePositionIds(int32_t seq_len, const Tensor& prev_position_ids) { + Tensor position_ids = Tensor::empty({1, seq_len}, kInt64).alloc(); + if (!prev_position_ids.isNil()) { + auto last_pos = *prev_position_ids.coffsettedPtr({0, prev_position_ids.shape()[1] - 1}); + auto* position_ids_ptr = position_ids.ptr(); + for (int32_t i = 0; i < seq_len; ++i) { position_ids_ptr[i] = last_pos + i + 1; } + return position_ids; + } + + auto last_seen_tokens = kv_cache_.getCurrentSeqCnt(0); + auto* position_ids_ptr = position_ids.ptr(); + for (int32_t i = 0; i < seq_len; ++i) { position_ids_ptr[i] = last_seen_tokens + i; } + return position_ids; + } + + static minicpmo::MiniCPMOConfig createLegacyConfig(const MiniCPMO45Config& config) { + minicpmo::MiniCPMOConfig legacy; + legacy.vision_hidden_size = config.vision_hidden_size; + legacy.vision_intermediate_size = config.vision_intermediate_size; + 
legacy.vision_num_hidden_layers = config.vision_num_hidden_layers; + legacy.vision_num_attention_heads = config.vision_num_attention_heads; + legacy.vision_num_channels = config.vision_num_channels; + legacy.vision_image_size = config.vision_image_size; + legacy.vision_patch_size = config.vision_patch_size; + + legacy.hidden_size = config.hidden_size; + legacy.intermediate_size = config.intermediate_size; + legacy.num_attention_heads = config.num_attention_heads; + legacy.num_key_value_heads = config.num_key_value_heads; + legacy.num_hidden_layers = config.num_hidden_layers; + legacy.max_position_embeddings = config.max_position_embeddings; + legacy.rms_norm_eps = config.rms_norm_eps; + legacy.vocab_size = config.vocab_size; + + legacy.query_num = config.query_num; + + legacy.audio_hidden_size = config.audio_hidden_size; + legacy.audio_num_hidden_layers = config.audio_num_hidden_layers; + legacy.audio_num_attention_heads = config.audio_num_attention_heads; + legacy.audio_max_position_embeddings = config.audio_max_position_embeddings; + legacy.audio_chunk_length = config.audio_chunk_length; + legacy.audio_pool_step = config.audio_pool_step; + + legacy.max_cache_length = config.max_cache_length; + legacy.eos_token_id = config.eos_token_id; + legacy.bos_token_id = config.bos_token_id; + legacy.rope_theta = config.rope_theta; + legacy.tie_word_embeddings = config.tie_word_embeddings; + + legacy.linear_impl_type = config.linear_impl_type; + return legacy; + } + + static bool isStopToken(int64_t token_id, const std::vector& stop_token_ids) { + for (auto id : stop_token_ids) { + if (token_id == id) { return true; } + } + return false; + } + + private: + nn::StaticCache kv_cache_; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp b/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp new file mode 100644 index 000000000..0e145d22c --- /dev/null +++ 
b/mllm/models/minicpm_o45/modeling_minicpm_o45_token2wav.hpp @@ -0,0 +1,1522 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/layers/STFT.hpp" +#include "mllm/utils/Enumerate.hpp" +#include "mllm/utils/Log.hpp" + +#include "mllm/models/minicpm_o45/token2wav_prompt_cache.hpp" +#include "mllm/models/minicpm_o45/token2wav_weight_norm.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45FlowConfig { + int32_t input_size = 512; + int32_t output_size = 80; + int32_t spk_embed_dim = 192; + int32_t vocab_size = 6561; + int32_t up_rate = 2; + + int32_t encoder_attention_heads = 8; + int32_t encoder_linear_units = 2048; + int32_t encoder_num_blocks = 6; + int32_t encoder_num_up_blocks = 4; + int32_t pre_lookahead_len = 3; + + int32_t dit_in_channels = 320; + int32_t dit_out_channels = 80; + float dit_mlp_ratio = 4.0f; + int32_t dit_depth = 16; + int32_t dit_num_heads = 8; + int32_t dit_head_dim = 64; + int32_t dit_hidden_size = 512; + float cfm_inference_cfg_rate = 0.7f; +}; + +struct MiniCPMO45HiFTConfig { + int32_t in_channels = 80; + int32_t base_channels = 512; + int32_t nb_harmonics = 8; + int32_t sampling_rate = 24000; + float nsf_alpha = 0.1f; + float nsf_sigma = 0.003f; + float nsf_voiced_threshold = 10.0f; + std::vector upsample_rates = {8, 5, 3}; + std::vector upsample_kernel_sizes = {16, 11, 7}; + int32_t istft_n_fft = 16; + int32_t istft_hop_len = 4; + std::vector resblock_kernel_sizes = {3, 7, 11}; + std::vector> resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + std::vector source_resblock_kernel_sizes = {7, 7, 11}; + std::vector> source_resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 5}, {1, 3, 5}}; + float lrelu_slope = 
0.1f; + float audio_limit = 0.99f; +}; + +struct MiniCPMO45Token2WavConfig { + MiniCPMO45FlowConfig flow{}; + MiniCPMO45HiFTConfig hift{}; +}; + +namespace token2wav { + +inline bool isDebugEnabled() { + static bool enabled = []() { + const char* v = std::getenv("MLLM_TOKEN2WAV_DEBUG"); + if (v == nullptr) { return false; } + return std::string(v) != "0"; + }(); + return enabled; +} + +inline void debugLog(const std::string& msg) { + if (!isDebugEnabled()) { return; } + std::cerr << "[token2wav-cpp] " << msg << std::endl; +} + +inline std::string shapeOf(const Tensor& x) { + std::string s = "["; + const auto& sh = x.shape(); + for (int32_t i = 0; i < static_cast(sh.size()); ++i) { + s += std::to_string(sh[i]); + if (i + 1 != static_cast(sh.size())) { s += ","; } + } + s += "]"; + return s; +} + +inline std::string descOf(const Tensor& x) { + return "shape=" + shapeOf(x) + ",dtype=" + std::to_string(static_cast(x.dtype())); +} + +inline Tensor repeatInterleaveSeq(Tensor x, int32_t repeats) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape = in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 3); + const int32_t batch = shape[0]; + const int32_t seq_len = shape[1]; + const int32_t channels = shape[2]; + + auto out = Tensor::empty({batch, seq_len * repeats, channels}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t in_stride_b = static_cast(seq_len) * channels; + const int64_t out_stride_b = static_cast(seq_len) * repeats * channels; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + static_cast(b) * in_stride_b; + float* dst_b = dst + static_cast(b) * out_stride_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float* src_s = src_b + static_cast(s) * channels; + for (int32_t r = 0; r < repeats; ++r) { + float* dst_s = dst_b + (static_cast(s) * repeats + r) * channels; + std::memcpy(dst_s, src_s, sizeof(float) * channels); + } + } + } + return out; 
+} + +inline Tensor concatInt64Seq(Tensor a, Tensor b) { + MLLM_RT_ASSERT_EQ(a.dtype(), kInt64); + MLLM_RT_ASSERT_EQ(b.dtype(), kInt64); + MLLM_RT_ASSERT_EQ(static_cast(a.shape().size()), 2); + MLLM_RT_ASSERT_EQ(static_cast(b.shape().size()), 2); + MLLM_RT_ASSERT_EQ(a.shape()[0], b.shape()[0]); + + auto av = a.contiguous(); + auto bv = b.contiguous(); + const int32_t B = av.shape()[0]; + const int32_t Ta = av.shape()[1]; + const int32_t Tb = bv.shape()[1]; + auto out = Tensor::empty({B, Ta + Tb}, kInt64, kCPU).alloc(); + + const auto* ap = av.ptr(); + const auto* bp = bv.ptr(); + auto* op = out.ptr(); + for (int32_t bidx = 0; bidx < B; ++bidx) { + std::memcpy(op + static_cast(bidx) * (Ta + Tb), ap + static_cast(bidx) * Ta, sizeof(int64_t) * Ta); + std::memcpy(op + static_cast(bidx) * (Ta + Tb) + Ta, bp + static_cast(bidx) * Tb, sizeof(int64_t) * Tb); + } + return out; +} + +inline Tensor repeatInterleave1d(Tensor x, int32_t repeats) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape = in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 3); + const int32_t batch = shape[0]; + const int32_t channels = shape[1]; + const int32_t seq_len = shape[2]; + + auto out = Tensor::empty({batch, channels, seq_len * repeats}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + + const int64_t in_stride_b = static_cast(channels) * seq_len; + const int64_t out_stride_b = static_cast(channels) * seq_len * repeats; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + static_cast(b) * in_stride_b; + float* dst_b = dst + static_cast(b) * out_stride_b; + for (int32_t c = 0; c < channels; ++c) { + const float* src_c = src_b + static_cast(c) * seq_len; + float* dst_c = dst_b + static_cast(c) * seq_len * repeats; + for (int32_t t = 0; t < seq_len; ++t) { + const float v = src_c[t]; + for (int32_t r = 0; r < repeats; ++r) { dst_c[t * repeats + r] = v; } + } + } + } + return out; +} + +inline 
Tensor l2NormalizeRow(Tensor x, float eps = 1e-12f) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const auto& shape = in.shape(); + MLLM_RT_ASSERT_EQ(static_cast(shape.size()), 2); + const int32_t batch = shape[0]; + const int32_t dim = shape[1]; + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + + const auto* src = in.ptr(); + auto* dst = out.ptr(); + for (int32_t b = 0; b < batch; ++b) { + const int64_t base = static_cast(b) * dim; + float norm = 0.0f; + for (int32_t i = 0; i < dim; ++i) { + const float v = src[base + i]; + norm += v * v; + } + norm = std::sqrt(std::max(norm, eps)); + for (int32_t i = 0; i < dim; ++i) { dst[base + i] = src[base + i] / norm; } + } + return out; +} + +inline Tensor tensorMish(Tensor x) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + const float sp = std::log1p(std::exp(v)); + dst[i] = v * std::tanh(sp); + }); + return out; +} + +inline Tensor tensorLeakyRelu(Tensor x, float slope) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + dst[i] = (v >= 0.0f) ? 
v : (v * slope); + }); + return out; +} + +inline Tensor makeHannWindow(int32_t win_length) { + auto w = Tensor::empty({1, win_length}, kFloat32, kCPU).alloc(); + auto* ptr = w.ptr(); + constexpr float kPi = 3.14159265358979323846f; + for (int32_t i = 0; i < win_length; ++i) { + ptr[i] = 0.5f - 0.5f * std::cos(2.0f * kPi * static_cast(i) / static_cast(win_length)); + } + return w; +} + +inline Tensor relShift(Tensor x) { + // x: [B, H, T, 2T-1], output [B, H, T, T] + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto in = x.contiguous(); + const int32_t B = in.shape()[0]; + const int32_t H = in.shape()[1]; + const int32_t T = in.shape()[2]; + const int32_t R = in.shape()[3]; + MLLM_RT_ASSERT_EQ(R, 2 * T - 1); + + auto out = Tensor::empty({B, H, T, T}, kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t in_stride_b = static_cast(H) * T * R; + const int64_t in_stride_h = static_cast(T) * R; + const int64_t in_stride_t = R; + const int64_t out_stride_b = static_cast(H) * T * T; + const int64_t out_stride_h = static_cast(T) * T; + const int64_t out_stride_t = T; + + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + const float* src_h = src + static_cast(b) * in_stride_b + static_cast(h) * in_stride_h; + float* dst_h = dst + static_cast(b) * out_stride_b + static_cast(h) * out_stride_h; + for (int32_t i = 0; i < T; ++i) { + const float* src_i = src_h + static_cast(i) * in_stride_t; + float* dst_i = dst_h + static_cast(i) * out_stride_t; + for (int32_t j = 0; j < T; ++j) { + const int32_t src_idx = j - i + T - 1; + dst_i[j] = src_i[src_idx]; + } + } + } + } + return out; +} + +inline void addHeadBiasInplace(Tensor& q, Tensor bias) { + // q: [B, H, T, D], bias: [H, D] + MLLM_RT_ASSERT_EQ(q.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(bias.dtype(), kFloat32); + auto qv = q.contiguous(); + auto bv = bias.contiguous(); + const int32_t B = qv.shape()[0]; + const int32_t H = qv.shape()[1]; + const int32_t T = 
qv.shape()[2]; + const int32_t D = qv.shape()[3]; + MLLM_RT_ASSERT_EQ(bv.shape()[0], H); + MLLM_RT_ASSERT_EQ(bv.shape()[1], D); + auto* q_ptr = qv.ptr(); + const auto* b_ptr = bv.ptr(); + + const int64_t q_stride_b = static_cast(H) * T * D; + const int64_t q_stride_h = static_cast(T) * D; + const int64_t q_stride_t = D; + + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + const float* bh = b_ptr + static_cast(h) * D; + for (int32_t t = 0; t < T; ++t) { + float* row = q_ptr + static_cast(b) * q_stride_b + static_cast(h) * q_stride_h + + static_cast(t) * q_stride_t; + for (int32_t d = 0; d < D; ++d) { row[d] += bh[d]; } + } + } + } + q = qv; +} + +inline Tensor concatChannel(const std::vector& xs) { + return nn::functional::concat(xs, 1); +} + +inline Tensor makeTimeStepsTensor(const std::vector& values) { + auto t = Tensor::empty({static_cast(values.size())}, kFloat32, kCPU).alloc(); + auto* ptr = t.ptr(); + for (size_t i = 0; i < values.size(); ++i) { ptr[i] = values[i]; } + return t; +} + +inline Tensor randomNormalLike(const std::vector& shape, float scale = 1.0f) { + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { ptr[i] = dist(rng) * scale; } + return out; +} + +class EspnetRelPositionalEncoding final : public nn::Module { + public: + EspnetRelPositionalEncoding() = default; + EspnetRelPositionalEncoding(const std::string& name, int32_t dim) : nn::Module(name), dim_(dim) { xscale_ = std::sqrt(static_cast(dim)); } + + std::pair forwardWithPos(Tensor x) { + const int32_t T = x.shape()[1]; + auto pos = positionEncoding(T); + return {x * xscale_, pos}; + } + + private: + Tensor positionEncoding(int32_t size) const { + const int32_t dim = dim_; + auto pe_pos = Tensor::empty({size, dim}, kFloat32, kCPU).alloc(); + auto 
pe_neg = Tensor::empty({size, dim}, kFloat32, kCPU).alloc(); + auto* pos_ptr = pe_pos.ptr(); + auto* neg_ptr = pe_neg.ptr(); + + for (int32_t p = 0; p < size; ++p) { + for (int32_t i = 0; i < dim; i += 2) { + const float div = std::exp(-std::log(10000.0f) * static_cast(i) / static_cast(dim)); + const float v1 = std::sin(static_cast(p) * div); + const float v2 = std::cos(static_cast(p) * div); + pos_ptr[p * dim + i] = v1; + pos_ptr[p * dim + i + 1] = v2; + neg_ptr[p * dim + i] = -v1; + neg_ptr[p * dim + i + 1] = v2; + } + } + + auto pe_positive = Tensor::empty({1, size, dim}, kFloat32, kCPU).alloc(); + auto pe_negative = Tensor::empty({1, std::max(size - 1, 0), dim}, kFloat32, kCPU).alloc(); + auto* pp = pe_positive.ptr(); + auto* pn = pe_negative.ptr(); + for (int32_t i = 0; i < size; ++i) { + std::memcpy(pp + static_cast(i) * dim, pos_ptr + static_cast(size - 1 - i) * dim, + sizeof(float) * dim); + } + for (int32_t i = 1; i < size; ++i) { + std::memcpy(pn + static_cast(i - 1) * dim, neg_ptr + static_cast(i) * dim, sizeof(float) * dim); + } + return nn::functional::concat({pe_positive, pe_negative}, 1); + } + + private: + int32_t dim_ = 0; + float xscale_ = 1.0f; +}; + +class LinearNoSubsampling final : public nn::Module { + public: + LinearNoSubsampling() = default; + LinearNoSubsampling(const std::string& name, int32_t idim, int32_t odim) : nn::Module(name) { + out_linear_ = reg("out.0", idim, odim, true); + out_norm_ = reg("out.1", std::vector{odim}, true, true, 1e-5f); + pos_enc_ = reg("pos_enc", odim); + } + + std::pair forwardWithPos(Tensor x) { + auto y = out_linear_(x); + y = out_norm_(y); + return pos_enc_.forwardWithPos(y); + } + + private: + nn::Linear out_linear_; + nn::LayerNorm out_norm_; + EspnetRelPositionalEncoding pos_enc_; +}; + +class PositionwiseFeedForward final : public nn::Module { + public: + PositionwiseFeedForward() = default; + PositionwiseFeedForward(const std::string& name, int32_t idim, int32_t hidden_units) : nn::Module(name) { + w1_ 
= reg("w_1", idim, hidden_units, true); + w2_ = reg("w_2", hidden_units, idim, true); + } + + Tensor forwardOne(Tensor x) { + auto y = w1_(x); + y = nn::functional::silu(y); + y = w2_(y); + return y; + } + + private: + nn::Linear w1_; + nn::Linear w2_; +}; + +class RelPositionMultiHeadedAttention final : public nn::Module { + public: + RelPositionMultiHeadedAttention() = default; + RelPositionMultiHeadedAttention(const std::string& name, int32_t n_head, int32_t n_feat, bool key_bias) + : nn::Module(name), n_head_(n_head), n_feat_(n_feat) { + d_k_ = n_feat_ / n_head_; + linear_q_ = reg("linear_q", n_feat_, n_feat_, true); + linear_k_ = reg("linear_k", n_feat_, n_feat_, key_bias); + linear_v_ = reg("linear_v", n_feat_, n_feat_, true); + linear_out_ = reg("linear_out", n_feat_, n_feat_, true); + linear_pos_ = reg("linear_pos", n_feat_, n_feat_, false); + pos_bias_u_ = reg("pos_bias_u", getModuleName() + ".pos_bias_u", Tensor::shape_t{n_head_, d_k_}); + pos_bias_v_ = reg("pos_bias_v", getModuleName() + ".pos_bias_v", Tensor::shape_t{n_head_, d_k_}); + } + + Tensor forwardOne(Tensor x, Tensor pos_emb) { + auto q = linear_q_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); // [B,H,T,D] + auto k = linear_k_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); + auto v = linear_v_(x).view({x.shape()[0], x.shape()[1], n_head_, d_k_}).transpose(1, 2); + + auto p = linear_pos_(pos_emb).view({pos_emb.shape()[0], pos_emb.shape()[1], n_head_, d_k_}).transpose(1, 2); // [1,H,2T-1,D] + + auto q_with_bias_u = q.contiguous(); + auto q_with_bias_v = q.contiguous(); + addHeadBiasInplace(q_with_bias_u, pos_bias_u_.weight()); + addHeadBiasInplace(q_with_bias_v, pos_bias_v_.weight()); + + auto matrix_ac = nn::functional::matmul(q_with_bias_u, k.transpose(2, 3), false, false); // [B,H,T,T] + auto matrix_bd = nn::functional::matmul(q_with_bias_v, p.transpose(2, 3), false, false); // [B,H,T,2T-1] + if (matrix_ac.shape()[3] != matrix_bd.shape()[3]) { 
matrix_bd = relShift(matrix_bd); } + auto scores = (matrix_ac + matrix_bd) / std::sqrt(static_cast(d_k_)); + auto attn = nn::functional::softmax(scores, -1); + auto y = nn::functional::matmul(attn, v, false, false); // [B,H,T,D] + y = y.transpose(1, 2).view({x.shape()[0], x.shape()[1], n_feat_}); + return linear_out_(y); + } + + private: + int32_t n_head_ = 0; + int32_t n_feat_ = 0; + int32_t d_k_ = 0; + nn::Linear linear_q_; + nn::Linear linear_k_; + nn::Linear linear_v_; + nn::Linear linear_out_; + nn::Linear linear_pos_; + nn::Param pos_bias_u_; + nn::Param pos_bias_v_; +}; + +class ConformerEncoderLayer final : public nn::Module { + public: + ConformerEncoderLayer() = default; + ConformerEncoderLayer(const std::string& name, int32_t size, int32_t n_head, int32_t linear_units, bool key_bias) + : nn::Module(name) { + self_attn_ = reg("self_attn", n_head, size, key_bias); + feed_forward_ = reg("feed_forward", size, linear_units); + norm_ff_ = reg("norm_ff", std::vector{size}, true, true, 1e-12f); + norm_mha_ = reg("norm_mha", std::vector{size}, true, true, 1e-12f); + } + + Tensor forwardOne(Tensor x, Tensor pos_emb) { + auto h = norm_mha_(x); + auto y = self_attn_.forwardOne(h, pos_emb); + y = x + y; + auto z = norm_ff_(y); + z = feed_forward_.forwardOne(z); + return y + z; + } + + private: + RelPositionMultiHeadedAttention self_attn_; + PositionwiseFeedForward feed_forward_; + nn::LayerNorm norm_ff_; + nn::LayerNorm norm_mha_; +}; + +class PreLookaheadLayer final : public nn::Module { + public: + PreLookaheadLayer() = default; + PreLookaheadLayer(const std::string& name, int32_t channels, int32_t pre_lookahead_len) : nn::Module(name), pre_(pre_lookahead_len) { + conv1_ = reg("conv1", channels, channels, pre_ + 1, 1, 0, 1, 1, true); + conv2_ = reg("conv2", channels, channels, 3, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor inputs) { + auto x = inputs.transpose(1, 2).contiguous(); // [B,C,T] + x = nn::functional::pad(x, {0, pre_}, aops::PadMode::kConstant, 
0.0f); // right pad + x = conv1_(x); + x = tensorLeakyRelu(x, 0.01f); + x = nn::functional::pad(x, {2, 0}, aops::PadMode::kConstant, 0.0f); // left pad + x = conv2_(x); + x = x.transpose(1, 2).contiguous(); // [B,T,C] + return x + inputs; + } + + private: + int32_t pre_ = 3; + nn::Conv1D conv1_; + nn::Conv1D conv2_; +}; + +class Upsample1D final : public nn::Module { + public: + Upsample1D() = default; + Upsample1D(const std::string& name, int32_t channels, int32_t out_channels, int32_t stride) : nn::Module(name), stride_(stride) { + conv_ = reg("conv", channels, out_channels, stride_ * 2 + 1, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor inputs) { + auto x = repeatInterleave1d(inputs, stride_); + x = nn::functional::pad(x, {stride_ * 2, 0}, aops::PadMode::kConstant, 0.0f); + return conv_(x); + } + + int32_t stride() const { return stride_; } + + private: + int32_t stride_ = 2; + nn::Conv1D conv_; +}; + +class UpsampleConformerEncoderV2 final : public nn::Module { + public: + UpsampleConformerEncoderV2() = default; + UpsampleConformerEncoderV2(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + embed_ = reg("embed", cfg.input_size, cfg.input_size); + pre_lookahead_ = reg("pre_lookahead_layer", cfg.input_size, cfg.pre_lookahead_len); + encoders_ = reg>("encoders", cfg.encoder_num_blocks, cfg.input_size, + cfg.encoder_attention_heads, cfg.encoder_linear_units, true); + up_layer_ = reg("up_layer", cfg.input_size, cfg.input_size, cfg.up_rate); + up_embed_ = reg("up_embed", cfg.input_size, cfg.input_size); + up_encoders_ = reg>("up_encoders", cfg.encoder_num_up_blocks, cfg.input_size, + cfg.encoder_attention_heads, cfg.encoder_linear_units, true); + after_norm_ = reg("after_norm", std::vector{cfg.input_size}, true, true, 1e-5f); + } + + Tensor forwardOne(Tensor xs) { + auto [x0, pos0] = embed_.forwardWithPos(xs); + x0 = pre_lookahead_.forwardOne(x0); + for (auto& layer : encoders_.list()) { x0 = layer.forwardOne(x0, pos0); } 
+ + x0 = x0.transpose(1, 2).contiguous(); + x0 = up_layer_.forwardOne(x0); + x0 = x0.transpose(1, 2).contiguous(); + + auto [x1, pos1] = up_embed_.forwardWithPos(x0); + for (auto& layer : up_encoders_.list()) { x1 = layer.forwardOne(x1, pos1); } + x1 = after_norm_(x1); + return x1; + } + + private: + MiniCPMO45FlowConfig cfg_; + LinearNoSubsampling embed_; + PreLookaheadLayer pre_lookahead_; + nn::ModuleList encoders_; + Upsample1D up_layer_; + LinearNoSubsampling up_embed_; + nn::ModuleList up_encoders_; + nn::LayerNorm after_norm_; +}; + +class DiTAttention final : public nn::Module { + public: + DiTAttention() = default; + DiTAttention(const std::string& name, int32_t dim, int32_t num_heads, int32_t head_dim) : nn::Module(name), + dim_(dim), heads_(num_heads), head_dim_(head_dim), inner_dim_(num_heads * head_dim) { + to_q_ = reg("to_q", dim_, inner_dim_, true); + to_k_ = reg("to_k", dim_, inner_dim_, true); + to_v_ = reg("to_v", dim_, inner_dim_, true); + q_norm_ = reg("q_norm", std::vector{head_dim_}, true, true, 1e-5f); + k_norm_ = reg("k_norm", std::vector{head_dim_}, true, true, 1e-5f); + proj_ = reg("proj", inner_dim_, dim_, true); + } + + Tensor forwardOne(Tensor x) { + debugLog("dit.attn: enter x(" + descOf(x) + ")"); + auto q = to_q_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); // [B,H,T,D] + debugLog("dit.attn: to_q done"); + auto k = to_k_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); + debugLog("dit.attn: to_k done"); + auto v = to_v_(x).view({x.shape()[0], x.shape()[1], heads_, head_dim_}).transpose(1, 2); + debugLog("dit.attn: to_v done"); + + q = q_norm_(q); + k = k_norm_(k); + + auto out = nn::functional::scaledDotProductAttention(q, k, v); // [B,H,T,D] + out = out.transpose(1, 2).contiguous().view({x.shape()[0], x.shape()[1], inner_dim_}); + out = proj_(out); + debugLog("dit.attn: exit"); + return out; + } + + private: + int32_t dim_ = 0; + int32_t heads_ = 0; + int32_t head_dim_ = 0; + 
int32_t inner_dim_ = 0; + nn::Linear to_q_; + nn::Linear to_k_; + nn::Linear to_v_; + nn::LayerNorm q_norm_; + nn::LayerNorm k_norm_; + nn::Linear proj_; +}; + +class CausalConv1dBlock final : public nn::Module { + public: + CausalConv1dBlock() = default; + CausalConv1dBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t kernel_size) : nn::Module(name), + in_channels_(in_channels), out_channels_(out_channels), kernel_size_(kernel_size) { + conv1_ = reg("block.1", in_channels_, out_channels_, kernel_size_, 1, 0, 1, 1, true); + norm_ = reg("block.3", std::vector{out_channels_}, true, true, 1e-5f); + conv2_ = reg("block.6", out_channels_, out_channels_, kernel_size_, 1, 0, 1, 1, true); + } + + Tensor forwardOne(Tensor x) { + auto y = x.transpose(1, 2).contiguous(); + y = nn::functional::pad(y, {kernel_size_ - 1, 0}, aops::PadMode::kConstant, 0.0f); + y = conv1_(y); + y = y.transpose(1, 2).contiguous(); + y = norm_(y); + y = tensorMish(y); + y = y.transpose(1, 2).contiguous(); + y = nn::functional::pad(y, {kernel_size_ - 1, 0}, aops::PadMode::kConstant, 0.0f); + y = conv2_(y); + y = y.transpose(1, 2).contiguous(); + return y; + } + + private: + int32_t in_channels_ = 0; + int32_t out_channels_ = 0; + int32_t kernel_size_ = 3; + nn::Conv1D conv1_; + nn::LayerNorm norm_; + nn::Conv1D conv2_; +}; + +class DiTMLP final : public nn::Module { + public: + DiTMLP() = default; + DiTMLP(const std::string& name, int32_t in_features, int32_t hidden_features) : nn::Module(name) { + fc1_ = reg("fc1", in_features, hidden_features, true); + gelu_ = reg("act"); + fc2_ = reg("fc2", hidden_features, in_features, true); + } + + Tensor forwardOne(Tensor x) { + auto y = fc1_(x); + y = gelu_(y); + y = fc2_(y); + return y; + } + + private: + nn::Linear fc1_; + nn::GELU gelu_; + nn::Linear fc2_; +}; + +class TimestepEmbedder final : public nn::Module { + public: + TimestepEmbedder() = default; + TimestepEmbedder(const std::string& name, int32_t hidden_size, int32_t 
frequency_embedding_size = 256) + : nn::Module(name), hidden_size_(hidden_size), freq_size_(frequency_embedding_size) { + fc1_ = reg("mlp.0", freq_size_, hidden_size_, true); + act_ = reg("mlp.1"); + fc2_ = reg("mlp.2", hidden_size_, hidden_size_, true); + } + + Tensor forwardOne(Tensor t) { + auto emb = timestepEmbedding(t, freq_size_); + emb = fc1_(emb); + emb = act_(emb); + emb = fc2_(emb); + return emb; + } + + private: + Tensor timestepEmbedding(Tensor t, int32_t dim) const { + MLLM_RT_ASSERT_EQ(t.dtype(), kFloat32); + auto tt = t.contiguous(); + const int32_t N = tt.shape()[0]; + const int32_t half = dim / 2; + auto out = Tensor::empty({N, dim}, kFloat32, kCPU).alloc(); + const auto* tp = tt.ptr(); + auto* op = out.ptr(); + for (int32_t i = 0; i < N; ++i) { + const float tv = tp[i] * 1000.0f; + for (int32_t j = 0; j < half; ++j) { + const float freq = std::exp(-std::log(10000.0f) * static_cast(j) / static_cast(half)); + const float a = tv * freq; + op[i * dim + j] = std::cos(a); + op[i * dim + half + j] = std::sin(a); + } + if (dim % 2 == 1) { op[i * dim + dim - 1] = 0.0f; } + } + return out; + } + + private: + int32_t hidden_size_ = 0; + int32_t freq_size_ = 0; + nn::Linear fc1_; + nn::SiLU act_; + nn::Linear fc2_; +}; + +class FinalLayer final : public nn::Module { + public: + FinalLayer() = default; + FinalLayer(const std::string& name, int32_t hidden_size, int32_t out_channels) : nn::Module(name) { + adaln_act_ = reg("adaLN_modulation.0"); + adaln_linear_ = reg("adaLN_modulation.1", hidden_size, 2 * hidden_size, true); + norm_ = reg("norm_final", std::vector{hidden_size}, false, false, 1e-6f); + linear_ = reg("linear", hidden_size, out_channels, true); + } + + Tensor forwardOne(Tensor x, Tensor c) { + auto m = adaln_linear_(adaln_act_(c)); + auto chunks = nn::functional::chunk<2>(m, 2); + auto shift = chunks[0]; + auto scale = chunks[1]; + auto y = norm_(x); + if (scale.rank() == 2) { scale = scale.view({scale.shape()[0], 1, scale.shape()[1]}); } + if 
(shift.rank() == 2) { shift = shift.view({shift.shape()[0], 1, shift.shape()[1]}); } + y = y * (scale + 1.0f) + shift; + y = linear_(y); + return y; + } + + private: + nn::SiLU adaln_act_; + nn::Linear adaln_linear_; + nn::LayerNorm norm_; + nn::Linear linear_; +}; + +class DiTBlock final : public nn::Module { + public: + DiTBlock() = default; + DiTBlock(const std::string& name, int32_t hidden_size, int32_t num_heads, int32_t head_dim, float mlp_ratio) + : nn::Module(name), hidden_size_(hidden_size) { + norm1_ = reg("norm1", std::vector{hidden_size_}, false, false, 1e-6f); + attn_ = reg("attn", hidden_size_, num_heads, head_dim); + norm2_ = reg("norm2", std::vector{hidden_size_}, false, false, 1e-6f); + norm3_ = reg("norm3", std::vector{hidden_size_}, false, false, 1e-6f); + const int32_t mlp_hidden = static_cast(hidden_size_ * mlp_ratio); + mlp_ = reg("mlp", hidden_size_, mlp_hidden); + conv_ = reg("conv", hidden_size_, hidden_size_, 3); + adaln_act_ = reg("adaLN_modulation.0"); + adaln_linear_ = reg("adaLN_modulation.1", hidden_size_, hidden_size_ * 9, true); + } + + Tensor forwardOne(Tensor x, Tensor c) { + debugLog("dit.block: enter x(" + descOf(x) + ") c(" + descOf(c) + ")"); + auto mods = adaln_linear_(adaln_act_(c)); // [B,1,9C] + debugLog("dit.block: adaln_linear done mods(" + descOf(mods) + ")"); + const int32_t C = hidden_size_; + auto shift_msa = mods[{kAll, kAll, {0 * C, 1 * C}}].contiguous(); + auto scale_msa = mods[{kAll, kAll, {1 * C, 2 * C}}].contiguous(); + auto gate_msa = mods[{kAll, kAll, {2 * C, 3 * C}}].contiguous(); + auto shift_mlp = mods[{kAll, kAll, {3 * C, 4 * C}}].contiguous(); + auto scale_mlp = mods[{kAll, kAll, {4 * C, 5 * C}}].contiguous(); + auto gate_mlp = mods[{kAll, kAll, {5 * C, 6 * C}}].contiguous(); + auto shift_conv = mods[{kAll, kAll, {6 * C, 7 * C}}].contiguous(); + auto scale_conv = mods[{kAll, kAll, {7 * C, 8 * C}}].contiguous(); + auto gate_conv = mods[{kAll, kAll, {8 * C, 9 * C}}].contiguous(); + debugLog("dit.block: 
chunk9 done"); + + auto y = norm1_(x); + y = y * (scale_msa + 1.0f) + shift_msa; + debugLog("dit.block: before attn y(" + descOf(y) + ")"); + auto attn_out = attn_.forwardOne(y); + debugLog("dit.block: attn done"); + auto h = x + attn_out * gate_msa; + + auto c_in = norm3_(h); + c_in = c_in * (scale_conv + 1.0f) + shift_conv; + auto conv_out = conv_.forwardOne(c_in); + debugLog("dit.block: conv done"); + h = h + conv_out * gate_conv; + + auto m_in = norm2_(h); + m_in = m_in * (scale_mlp + 1.0f) + shift_mlp; + auto mlp_out = mlp_.forwardOne(m_in); + debugLog("dit.block: mlp done"); + h = h + mlp_out * gate_mlp; + debugLog("dit.block: exit"); + return h; + } + + private: + int32_t hidden_size_ = 0; + nn::LayerNorm norm1_; + DiTAttention attn_; + nn::LayerNorm norm2_; + nn::LayerNorm norm3_; + DiTMLP mlp_; + CausalConv1dBlock conv_; + nn::SiLU adaln_act_; + nn::Linear adaln_linear_; +}; + +class DiTEstimator final : public nn::Module { + public: + DiTEstimator() = default; + DiTEstimator(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + t_embedder_ = reg("t_embedder", cfg.dit_hidden_size, 256); + in_proj_ = reg("in_proj", cfg.dit_in_channels, cfg.dit_hidden_size, true); + blocks_ = reg>("blocks", cfg.dit_depth, cfg.dit_hidden_size, cfg.dit_num_heads, cfg.dit_head_dim, + cfg.dit_mlp_ratio); + final_layer_ = reg("final_layer", cfg.dit_hidden_size, cfg.dit_out_channels); + } + + Tensor forwardOne(Tensor x, Tensor mu, Tensor t, Tensor spks, Tensor cond) { + // x,mu,cond: [B,C,T], spks: [B,C], t:[B] + debugLog("dit.forward: begin"); + auto time_emb = t_embedder_.forwardOne(t).view({t.shape()[0], 1, cfg_.dit_hidden_size}); + debugLog("dit.forward: t_embedder done"); + auto spk_seq = spks.view({spks.shape()[0], spks.shape()[1], 1}).repeat(x.shape()[2], 2); + auto packed = concatChannel({x, mu, spk_seq, cond}); // [B,320,T] + debugLog("dit.forward: concat packed done"); + auto h = packed.transpose(1, 2).contiguous(); // [B,T,320] + h 
= in_proj_(h); // [B,T,512] + debugLog("dit.forward: in_proj done"); + int32_t block_idx = 0; + for (auto& block : blocks_.list()) { + h = block.forwardOne(h, time_emb); + if (block_idx == 0) { debugLog("dit.forward: block0 done"); } + ++block_idx; + } + h = final_layer_.forwardOne(h, time_emb); // [B,T,80] + debugLog("dit.forward: final_layer done"); + h = h.transpose(1, 2).contiguous(); // [B,80,T] + debugLog("dit.forward: end"); + return h; + } + + private: + MiniCPMO45FlowConfig cfg_; + TimestepEmbedder t_embedder_; + nn::Linear in_proj_; + nn::ModuleList blocks_; + FinalLayer final_layer_; +}; + +class CausalConditionalCFM final : public nn::Module { + public: + CausalConditionalCFM() = default; + CausalConditionalCFM(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + estimator_ = reg("estimator", cfg_); + } + + Tensor forwardOne(Tensor mu, Tensor spks, Tensor cond, int32_t n_timesteps, float temperature = 1.0f) { + // all in float32 cpu. + debugLog("cfm.forward: start"); + const int32_t B = mu.shape()[0]; + const int32_t C = mu.shape()[1]; + const int32_t T = mu.shape()[2]; + MLLM_RT_ASSERT_EQ(B, 1); + + auto z = randomNormalLike({B, C, T}, temperature); + + std::vector t_span(static_cast(n_timesteps + 1), 0.0f); + constexpr float kPi = 3.14159265358979323846f; + for (int32_t i = 0; i <= n_timesteps; ++i) { + float t = static_cast(i) / static_cast(n_timesteps); + t_span[static_cast(i)] = 1.0f - std::cos(t * 0.5f * kPi); + } + + auto x = z; + auto mu_in = nn::functional::concat({mu, Tensor::zeros(mu.shape(), kFloat32, kCPU)}, 0); + auto spk_in = nn::functional::concat({spks, Tensor::zeros(spks.shape(), kFloat32, kCPU)}, 0); + auto cond_in = nn::functional::concat({cond, Tensor::zeros(cond.shape(), kFloat32, kCPU)}, 0); + + float t = t_span[0]; + float dt = t_span[1] - t_span[0]; + for (int32_t step = 1; step <= n_timesteps; ++step) { + if (step == 1) { debugLog("cfm.forward: first estimator step"); } + auto x_in = 
nn::functional::concat({x, x}, 0); // [2,C,T] + auto t_in = makeTimeStepsTensor({t, t}); + auto dphi = estimator_.forwardOne(x_in, mu_in, t_in, spk_in, cond_in); // [2,C,T] + auto dphi_split = nn::functional::chunk<2>(dphi, 0); + auto dphi_main = dphi_split[0]; + auto dphi_cfg = dphi_split[1]; + auto dphi_out = dphi_main * (1.0f + cfg_.cfm_inference_cfg_rate) - dphi_cfg * cfg_.cfm_inference_cfg_rate; + x = x + dphi_out * dt; + t += dt; + if (step < n_timesteps) { dt = t_span[static_cast(step + 1)] - t; } + } + debugLog("cfm.forward: finish"); + return x; + } + + private: + MiniCPMO45FlowConfig cfg_; + DiTEstimator estimator_; +}; + +class CausalMaskedDiffWithXvec final : public nn::Module { + public: + CausalMaskedDiffWithXvec() = default; + CausalMaskedDiffWithXvec(const std::string& name, const MiniCPMO45FlowConfig& cfg) : nn::Module(name), cfg_(cfg) { + input_embedding_ = reg("input_embedding", cfg.vocab_size, cfg.input_size); + spk_embed_affine_layer_ = reg("spk_embed_affine_layer", cfg.spk_embed_dim, cfg.output_size, true); + encoder_ = reg("encoder", cfg); + encoder_proj_ = reg("encoder_proj", cfg.input_size, cfg.output_size, true); + decoder_ = reg("decoder", cfg); + } + + Tensor inference(Tensor token, Tensor prompt_token, Tensor prompt_feat, Tensor embedding, + int32_t n_timesteps) { + // token/prompt_token: [1,T], int64 + debugLog("flow.inference: start"); + auto spk = l2NormalizeRow(embedding); + spk = spk_embed_affine_layer_(spk); // [1,80] + debugLog("flow.inference: spk_embed_affine_layer done"); + + auto all_token = concatInt64Seq(prompt_token, token); + auto token_embed = input_embedding_(all_token); + debugLog("flow.inference: input_embedding done"); + + auto h = encoder_.forwardOne(token_embed); + debugLog("flow.inference: encoder done"); + h = encoder_proj_(h); // [1, Tm, 80] + debugLog("flow.inference: encoder_proj done"); + + const int32_t mel_len1 = prompt_feat.shape()[1]; + const int32_t mel_len_total = h.shape()[1]; + const int32_t mel_len2 
= mel_len_total - mel_len1; + MLLM_RT_ASSERT(mel_len2 > 0); + + auto conds = Tensor::zeros(h.shape(), kFloat32, kCPU); + // copy prompt mel to prefix. + auto* cond_ptr = conds.ptr(); + const auto* prm_ptr = prompt_feat.ptr(); + const int32_t C = h.shape()[2]; + for (int32_t t = 0; t < mel_len1; ++t) { + std::memcpy(cond_ptr + static_cast(t) * C, prm_ptr + static_cast(t) * C, sizeof(float) * C); + } + + auto feat = decoder_.forwardOne(h.transpose(1, 2).contiguous(), spk, conds.transpose(1, 2).contiguous(), n_timesteps); + debugLog("flow.inference: decoder done"); + // remove prompt part. + auto out = feat[{kAll, kAll, {mel_len1, mel_len1 + mel_len2}}].contiguous(); + debugLog("flow.inference: finish"); + return out; + } + + private: + MiniCPMO45FlowConfig cfg_; + nn::Embedding input_embedding_; + nn::Linear spk_embed_affine_layer_; + UpsampleConformerEncoderV2 encoder_; + nn::Linear encoder_proj_; + CausalConditionalCFM decoder_; +}; + +class SnakeActivation final : public nn::Module { + public: + SnakeActivation() = default; + SnakeActivation(const std::string& name, int32_t channels) : nn::Module(name) { + alpha_ = reg("alpha", getModuleName() + ".alpha", Tensor::shape_t{channels}); + } + + Tensor forwardOne(Tensor x) { + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto in = x.contiguous(); + auto* dst = out.ptr(); + const auto* src = in.ptr(); + const auto* alpha = alpha_.weight().contiguous().ptr(); + const int32_t B = in.shape()[0]; + const int32_t C = in.shape()[1]; + const int32_t T = in.shape()[2]; + const int64_t stride_b = static_cast(C) * T; + const int64_t stride_c = T; + constexpr float eps = 1e-9f; + for (int32_t b = 0; b < B; ++b) { + for (int32_t c = 0; c < C; ++c) { + const float a = alpha[c]; + for (int32_t t = 0; t < T; ++t) { + const int64_t idx = static_cast(b) * stride_b + static_cast(c) * stride_c + t; + const float v = src[idx]; + const float s = std::sin(v * a); + dst[idx] = v + (s 
* s) / (a + eps); + } + } + } + return out; + } + + private: + nn::Param alpha_; +}; + +class ResBlock final : public nn::Module { + public: + ResBlock() = default; + ResBlock(const std::string& name, int32_t channels, int32_t kernel_size, const std::vector& dilations) + : nn::Module(name) { + MLLM_RT_ASSERT_EQ(static_cast(dilations.size()), 3); + for (int32_t i = 0; i < 3; ++i) { + convs1_.emplace_back(reg("convs1." + std::to_string(i), channels, channels, kernel_size, 1, + getPadding(kernel_size, dilations[i]), dilations[i], 1, true)); + convs2_.emplace_back(reg("convs2." + std::to_string(i), channels, channels, kernel_size, 1, + getPadding(kernel_size, 1), 1, 1, true)); + activations1_.emplace_back(reg("activations1." + std::to_string(i), channels)); + activations2_.emplace_back(reg("activations2." + std::to_string(i), channels)); + } + } + + Tensor forwardOne(Tensor x) { + auto out = x; + for (int32_t i = 0; i < 3; ++i) { + auto y = activations1_[i].forwardOne(out); + y = convs1_[i](y); + y = activations2_[i].forwardOne(y); + y = convs2_[i](y); + out = out + y; + } + return out; + } + + private: + static int32_t getPadding(int32_t kernel_size, int32_t dilation) { return (kernel_size * dilation - dilation) / 2; } + + private: + std::vector convs1_; + std::vector convs2_; + std::vector activations1_; + std::vector activations2_; +}; + +class ConvRNNF0Predictor final : public nn::Module { + public: + ConvRNNF0Predictor() = default; + ConvRNNF0Predictor(const std::string& name, int32_t in_channels = 80, int32_t cond_channels = 512) : nn::Module(name) { + condnet_0_ = reg("condnet.0", in_channels, cond_channels, 3, 1, 1, 1, 1, true); + condnet_2_ = reg("condnet.2", cond_channels, cond_channels, 3, 1, 1, 1, 1, true); + condnet_4_ = reg("condnet.4", cond_channels, cond_channels, 3, 1, 1, 1, 1, true); + condnet_6_ = reg("condnet.6", cond_channels, cond_channels, 3, 1, 1, 1, 1, true); + condnet_8_ = reg("condnet.8", cond_channels, cond_channels, 3, 1, 1, 1, 1, true); + 
classifier_ = reg("classifier", cond_channels, 1, true); + } + + Tensor forwardOne(Tensor x) { + auto y = condnet_0_(x); + y = tensorElu(y); + y = condnet_2_(y); + y = tensorElu(y); + y = condnet_4_(y); + y = tensorElu(y); + y = condnet_6_(y); + y = tensorElu(y); + y = condnet_8_(y); + y = tensorElu(y); + y = y.transpose(1, 2).contiguous(); + y = classifier_(y).squeeze(-1); + y = tensorAbs(y); + return y; + } + + private: + static Tensor tensorAbs(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { dst[i] = std::abs(src[i]); }); + return out; + } + + static Tensor tensorElu(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + MLLM_CONDITIONAL_PARALLEL_FOR(n > 4096, 4, i, 0, n, 1, { + const float v = src[i]; + dst[i] = (v >= 0.0f) ? 
v : std::expm1(v); + }); + return out; + } + + private: + nn::Conv1D condnet_0_; + nn::Conv1D condnet_2_; + nn::Conv1D condnet_4_; + nn::Conv1D condnet_6_; + nn::Conv1D condnet_8_; + nn::Linear classifier_; +}; + +class SineGen2 { + public: + SineGen2() = default; + SineGen2(int32_t sampling_rate, int32_t upsample_scale, int32_t harmonic_num, float sine_amp, float noise_std, float voiced_threshold) + : sampling_rate_(sampling_rate), + upsample_scale_(upsample_scale), + harmonic_num_(harmonic_num), + sine_amp_(sine_amp), + noise_std_(noise_std), + voiced_threshold_(voiced_threshold) {} + + std::tuple forward(Tensor f0) { + // f0: [B, T, 1] + auto fn = makeHarmonics(f0); + auto sine = f02sine(fn) * sine_amp_; + auto uv = f02uv(f0); + auto inv_uv = uv * -1.0f + 1.0f; + auto noise_amp = uv * noise_std_ + inv_uv * (sine_amp_ / 3.0f); + auto noise = randomLike(noise_amp); + auto out = sine * uv + noise_amp * noise; + return {out, uv, noise_amp * noise}; + } + + private: + Tensor makeHarmonics(Tensor f0) const { + const int32_t B = f0.shape()[0]; + const int32_t T = f0.shape()[1]; + const int32_t H = harmonic_num_ + 1; + auto out = Tensor::empty({B, T, H}, kFloat32, kCPU).alloc(); + const auto* fp = f0.contiguous().ptr(); + auto* op = out.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + const float v = fp[(static_cast(b) * T + t)]; + for (int32_t h = 0; h < H; ++h) { op[(static_cast(b) * T + t) * H + h] = v * static_cast(h + 1); } + } + } + return out; + } + + Tensor f02uv(Tensor f0) const { + auto out = Tensor::empty(f0.shape(), kFloat32, kCPU).alloc(); + const auto* src = f0.contiguous().ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + for (int64_t i = 0; i < n; ++i) { dst[i] = src[i] > voiced_threshold_ ? 
1.0f : 0.0f; } + return out; + } + + Tensor f02sine(Tensor f0_values) const { + // f0_values: [B, T, H] + auto fv = f0_values.contiguous(); + const int32_t B = fv.shape()[0]; + const int32_t T = fv.shape()[1]; + const int32_t H = fv.shape()[2]; + auto rad = Tensor::empty(fv.shape(), kFloat32, kCPU).alloc(); + const auto* fp = fv.ptr(); + auto* rp = rad.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + for (int32_t h = 0; h < H; ++h) { + const int64_t idx = (static_cast(b) * T + t) * H + h; + float v = fp[idx] / static_cast(sampling_rate_); + v = v - std::floor(v); + rp[idx] = v; + } + } + } + + std::mt19937 rng(std::random_device{}()); + std::uniform_real_distribution uni(0.0f, 1.0f); + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 1; h < H; ++h) { + const float phase0 = uni(rng); + rp[(static_cast(b) * T + 0) * H + h] += phase0; + } + } + + // linear interpolate in time by 1 / upsample_scale, then cumulative phase, then upsample back. + auto rad_t = rad.transpose(1, 2).contiguous(); // [B,H,T] + auto down_t = nn::functional::interpolateByScale(rad_t, {1.0f / static_cast(upsample_scale_)}, + aops::InterpolateOpMode::kLinear, false, false); + down_t = down_t.transpose(1, 2).contiguous(); // [B,T',H] + + auto phase = Tensor::empty(down_t.shape(), kFloat32, kCPU).alloc(); + auto* pp = phase.ptr(); + const auto* dp = down_t.ptr(); + const int32_t Td = down_t.shape()[1]; + for (int32_t b = 0; b < B; ++b) { + for (int32_t h = 0; h < H; ++h) { + float acc = 0.0f; + for (int32_t t = 0; t < Td; ++t) { + const int64_t idx = (static_cast(b) * Td + t) * H + h; + acc += dp[idx]; + constexpr float kPi = 3.14159265358979323846f; + pp[idx] = acc * 2.0f * kPi; + } + } + } + + auto phase_t = phase.transpose(1, 2).contiguous(); // [B,H,T'] + phase_t = phase_t * static_cast(upsample_scale_); + auto up_t = nn::functional::interpolateByScale(phase_t, {static_cast(upsample_scale_)}, + aops::InterpolateOpMode::kLinear, false, false); + up_t = 
up_t.transpose(1, 2).contiguous(); // [B,T,H] + auto out = nn::functional::sin(up_t); + return out; + } + + static Tensor randomLike(Tensor x) { + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { dst[i] = dist(rng); } + return out; + } + + private: + int32_t sampling_rate_ = 24000; + int32_t upsample_scale_ = 480; + int32_t harmonic_num_ = 8; + float sine_amp_ = 0.1f; + float noise_std_ = 0.003f; + float voiced_threshold_ = 10.0f; +}; + +class SourceModuleHnNSF2 { + public: + SourceModuleHnNSF2() = default; + SourceModuleHnNSF2(int32_t sampling_rate, int32_t upsample_scale, int32_t harmonic_num, float sine_amp, float noise_std, + float voiced_threshold) + : l_sin_gen_(sampling_rate, upsample_scale, harmonic_num, sine_amp, noise_std, voiced_threshold), + sine_amp_(sine_amp) {} + + // This wrapper only supports loading external weight via setLinearWeights(). 
+ void setLinearWeights(Tensor w, Tensor b) { + linear_w_ = w.contiguous(); + linear_b_ = b.contiguous(); + } + + std::tuple forward(Tensor x) { + auto [sine_wavs, uv, _] = l_sin_gen_.forward(x); + auto sine_merge = linearForward(sine_wavs); + sine_merge = tensorTanh(sine_merge); + auto noise = randomLike(uv) * (sine_amp_ / 3.0f); + return {sine_merge, noise, uv}; + } + + private: + Tensor linearForward(Tensor x) { + // x: [B,T,H], weight [1,H] + MLLM_RT_ASSERT(!linear_w_.isNil()); + MLLM_RT_ASSERT(!linear_b_.isNil()); + const int32_t B = x.shape()[0]; + const int32_t T = x.shape()[1]; + const int32_t H = x.shape()[2]; + auto out = Tensor::empty({B, T, 1}, kFloat32, kCPU).alloc(); + const auto* xp = x.contiguous().ptr(); + const auto* wp = linear_w_.contiguous().ptr(); + const float bias = linear_b_.constAt({0}); + auto* op = out.ptr(); + for (int32_t b = 0; b < B; ++b) { + for (int32_t t = 0; t < T; ++t) { + float acc = bias; + for (int32_t h = 0; h < H; ++h) { acc += xp[(static_cast(b) * T + t) * H + h] * wp[h]; } + op[static_cast(b) * T + t] = acc; + } + } + return out; + } + + static Tensor randomLike(Tensor x) { + auto out = Tensor::empty(x.shape(), kFloat32, kCPU).alloc(); + auto* dst = out.ptr(); + const int64_t n = static_cast(out.numel()); + static thread_local std::mt19937 rng(std::random_device{}()); + std::normal_distribution dist(0.0f, 1.0f); + for (int64_t i = 0; i < n; ++i) { dst[i] = dist(rng); } + return out; + } + + static Tensor tensorTanh(Tensor x) { + auto in = x.contiguous(); + auto out = Tensor::empty(in.shape(), kFloat32, kCPU).alloc(); + const auto* src = in.ptr(); + auto* dst = out.ptr(); + const int64_t n = static_cast(in.numel()); + for (int64_t i = 0; i < n; ++i) { dst[i] = std::tanh(src[i]); } + return out; + } + + private: + SineGen2 l_sin_gen_; + float sine_amp_ = 0.1f; + Tensor linear_w_ = Tensor::nil(); + Tensor linear_b_ = Tensor::nil(); +}; + +class HiFTGenerator final : public nn::Module { + public: + HiFTGenerator() = default; 
+ HiFTGenerator(const std::string& name, const MiniCPMO45HiFTConfig& cfg) + : nn::Module(name), cfg_(cfg), + upsample_total_scale_(cfg.upsample_rates[0] * cfg.upsample_rates[1] * cfg.upsample_rates[2] * cfg.istft_hop_len), + m_source_(cfg.sampling_rate, upsample_total_scale_, cfg.nb_harmonics, cfg.nsf_alpha, cfg.nsf_sigma, cfg.nsf_voiced_threshold) { + conv_pre_ = reg("conv_pre", cfg.in_channels, cfg.base_channels, 7, 1, 3, 1, 1, true); + for (int32_t i = 0; i < static_cast(cfg.upsample_rates.size()); ++i) { + const int32_t in_ch = cfg.base_channels / static_cast(std::pow(2, i)); + const int32_t out_ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + ups_.emplace_back( + reg("ups." + std::to_string(i), in_ch, out_ch, cfg.upsample_kernel_sizes[i], cfg.upsample_rates[i], + (cfg.upsample_kernel_sizes[i] - cfg.upsample_rates[i]) / 2, 0, 1, 1, true)); + } + + // source downs + std::vector downsample_rates = {1, cfg.upsample_rates[2], cfg.upsample_rates[2] * cfg.upsample_rates[1]}; + std::reverse(downsample_rates.begin(), downsample_rates.end()); // [15,3,1] + for (int32_t i = 0; i < static_cast(downsample_rates.size()); ++i) { + const int32_t u = downsample_rates[i]; + const int32_t out_ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + if (u == 1) { + source_downs_.emplace_back(reg("source_downs." + std::to_string(i), cfg.istft_n_fft + 2, out_ch, 1, 1, 0, 1, 1, true)); + } else { + source_downs_.emplace_back(reg("source_downs." + std::to_string(i), cfg.istft_n_fft + 2, out_ch, u * 2, u, (u / 2), 1, 1, true)); + } + source_resblocks_.emplace_back( + reg("source_resblocks." 
+ std::to_string(i), out_ch, cfg.source_resblock_kernel_sizes[i], cfg.source_resblock_dilation_sizes[i])); + } + + const int32_t num_ups = static_cast(cfg.upsample_rates.size()); + const int32_t num_kernels = static_cast(cfg.resblock_kernel_sizes.size()); + for (int32_t i = 0; i < num_ups; ++i) { + const int32_t ch = cfg.base_channels / static_cast(std::pow(2, i + 1)); + for (int32_t j = 0; j < num_kernels; ++j) { + resblocks_.emplace_back(reg("resblocks." + std::to_string(static_cast(resblocks_.size())), ch, + cfg.resblock_kernel_sizes[j], cfg.resblock_dilation_sizes[j])); + } + } + + conv_post_ = reg("conv_post", cfg.base_channels / static_cast(std::pow(2, cfg.upsample_rates.size())), + cfg.istft_n_fft + 2, 7, 1, 3, 1, 1, true); + f0_predictor_ = reg("f0_predictor"); + stft_ = reg("internal_stft", cfg.istft_n_fft, cfg.istft_hop_len, cfg.istft_n_fft, true, true, "reflect", false); + istft_ = reg("internal_istft", cfg.istft_n_fft, cfg.istft_hop_len, cfg.istft_n_fft, true, true, "reflect"); + hann_window_ = makeHannWindow(cfg.istft_n_fft); + } + + void loadFromParameter(const ParameterFile::ptr_t& param) { + nn::Module::load(param); + // SourceModuleHnNSF2 linear is not a nn::Module member, load manually. 
+ auto w = param->pull(getModuleName() + ".m_source.l_linear.weight"); + auto b = param->pull(getModuleName() + ".m_source.l_linear.bias"); + if (w.dtype() != kFloat32) { w = w.to(kFloat32); } + if (b.dtype() != kFloat32) { b = b.to(kFloat32); } + w = w.contiguous().view({1, cfg_.nb_harmonics + 1}); + b = b.contiguous().view({1}); + m_source_.setLinearWeights(w, b); + } + + Tensor forwardOne(Tensor speech_feat) { + auto f0 = f0_predictor_.forwardOne(speech_feat); // [B,T] + auto f0_ex = f0.view({f0.shape()[0], 1, f0.shape()[1]}); // [B,1,T] + auto s = repeatInterleave1d(f0_ex, upsample_total_scale_).transpose(1, 2); // [B,S,1] + auto [s_merge, _, _uv] = m_source_.forward(s); + auto src = s_merge.transpose(1, 2).contiguous(); // [B,1,S] + auto wav = decode(speech_feat, src); + return wav; + } + + private: + Tensor decode(Tensor x_in, Tensor s) { + auto stft = stft_(s.squeeze(1), hann_window_); // [B,F,T,2] + auto stft_chunks = nn::functional::chunk<2>(stft, 3); + auto s_real = stft_chunks[0].squeeze(-1); + auto s_imag = stft_chunks[1].squeeze(-1); + auto s_stft = nn::functional::concat({s_real, s_imag}, 1); // [B,F*2,T] + + auto x = conv_pre_(x_in); + const int32_t num_ups = static_cast(ups_.size()); + const int32_t num_kernels = static_cast(cfg_.resblock_kernel_sizes.size()); + for (int32_t i = 0; i < num_ups; ++i) { + x = tensorLeakyRelu(x, cfg_.lrelu_slope); + x = ups_[i](x); + if (i == num_ups - 1) { x = nn::functional::pad(x, {1, 0}, aops::PadMode::kReflect); } + + auto si = source_downs_[i](s_stft); + si = source_resblocks_[i].forwardOne(si); + x = x + si; + + Tensor xs = Tensor::nil(); + for (int32_t j = 0; j < num_kernels; ++j) { + auto y = resblocks_[i * num_kernels + j].forwardOne(x); + if (j == 0) { + xs = y; + } else { + xs = xs + y; + } + } + x = xs / static_cast(num_kernels); + } + + x = tensorLeakyRelu(x, 0.01f); + x = conv_post_(x); // [B,18,T] + auto mag = x[{kAll, {0, cfg_.istft_n_fft / 2 + 1}, kAll}].contiguous(); + auto phase = x[{kAll, 
{cfg_.istft_n_fft / 2 + 1, cfg_.istft_n_fft + 2}, kAll}].contiguous(); + mag = nn::functional::exp(mag); + mag = nn::functional::clip(mag, 0.0f, 1e2f); + // Keep parity with python HiFT: phase is first squashed by sin() before ISTFT synthesis. + phase = nn::functional::sin(phase); + auto real = mag * nn::functional::cos(phase); + auto imag = mag * nn::functional::sin(phase); + auto S = real + std::complex{0, 1} * imag; + auto wav = istft_(S, hann_window_); + wav = nn::functional::clip(wav, -cfg_.audio_limit, cfg_.audio_limit); + return wav; + } + + private: + MiniCPMO45HiFTConfig cfg_; + int32_t upsample_total_scale_ = 480; + nn::Conv1D conv_pre_; + std::vector ups_; + std::vector source_downs_; + std::vector source_resblocks_; + std::vector resblocks_; + nn::Conv1D conv_post_; + ConvRNNF0Predictor f0_predictor_; + nn::STFT stft_; + nn::ISTFT istft_; + Tensor hann_window_ = Tensor::nil(); + SourceModuleHnNSF2 m_source_; +}; + +class MiniCPMO45Token2WavModel final : public nn::Module { + public: + MiniCPMO45Token2WavModel() = default; + MiniCPMO45Token2WavModel(const std::string& name, const MiniCPMO45Token2WavConfig& cfg) : nn::Module(name), cfg_(cfg) { + flow_model_ = reg("flow_model", cfg_.flow); + hift_model_ = reg("hift_model", cfg_.hift); + } + + void loadFromParameter(const ParameterFile::ptr_t& param_file) { + // Materialize weight_norm reparameterized conv weights in-place. 
+ (void)materializeWeightNormParameters(param_file, getModuleName() + ".hift_model."); + flow_model_.load(param_file); + hift_model_.loadFromParameter(param_file); + } + + Tensor infer(const std::vector& token_ids, const MiniCPMO45Token2WavPromptCache& prompt_cache, int32_t n_timesteps) { + if (token_ids.empty()) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "MiniCPM-o-4_5 token2wav got empty token ids."); } + if (prompt_cache.prompt_tokens.empty() || prompt_cache.prompt_mels.isNil() || prompt_cache.spk_emb.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "MiniCPM-o-4_5 token2wav prompt cache is incomplete."); + } + + auto token = Tensor::empty({1, static_cast(token_ids.size())}, kInt64, kCPU).alloc(); + for (int32_t i = 0; i < static_cast(token_ids.size()); ++i) { token.at({0, i}) = token_ids[static_cast(i)]; } + + auto prompt_token = Tensor::empty({1, static_cast(prompt_cache.prompt_tokens.size())}, kInt64, kCPU).alloc(); + for (int32_t i = 0; i < static_cast(prompt_cache.prompt_tokens.size()); ++i) { + prompt_token.at({0, i}) = static_cast(prompt_cache.prompt_tokens[static_cast(i)]); + } + + auto prompt_mels = Tensor(prompt_cache.prompt_mels); + auto spk = Tensor(prompt_cache.spk_emb); + if (prompt_mels.dtype() != kFloat32) { prompt_mels = prompt_mels.to(kFloat32); } + if (spk.dtype() != kFloat32) { spk = spk.to(kFloat32); } + + auto mel = flow_model_.inference(token, prompt_token, prompt_mels, spk, n_timesteps); + auto wav = hift_model_.forwardOne(mel); + return wav; + } + + private: + MiniCPMO45Token2WavConfig cfg_; + CausalMaskedDiffWithXvec flow_model_; + HiFTGenerator hift_model_; +}; + +} // namespace token2wav + +using token2wav::MiniCPMO45Token2WavModel; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp b/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp new file mode 100644 index 000000000..fecbe1f93 --- /dev/null +++ b/mllm/models/minicpm_o45/token2wav_prompt_cache.hpp @@ -0,0 +1,70 @@ +// 
Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +struct MiniCPMO45Token2WavPromptCache { + std::vector prompt_tokens; + Tensor prompt_mels = Tensor::nil(); // [1, Tm, 80], float32 + Tensor spk_emb = Tensor::nil(); // [1, 192], float32 +}; + +inline MiniCPMO45Token2WavPromptCache loadMiniCPMO45Token2WavPromptCache(const std::string& file_path) { + MiniCPMO45Token2WavPromptCache out; + + std::ifstream in(file_path, std::ios::binary); + if (!in.is_open()) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to open MiniCPM-o-4_5 prompt cache: {}", file_path); + } + + std::array magic{}; + in.read(magic.data(), static_cast(magic.size())); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache header read failed: {}", file_path); } + const std::array expected = {'M', '4', '5', 'P', 'C', '1', '\0', '\0'}; + if (magic != expected) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Invalid prompt cache magic: {}", file_path); } + + uint32_t version = 0; + in.read(reinterpret_cast(&version), sizeof(version)); + if (version != 1) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Unsupported prompt cache version {}: {}", version, file_path); } + + int32_t token_len = 0; + int32_t mel_frames = 0; + int32_t mel_dim = 0; + int32_t spk_dim = 0; + in.read(reinterpret_cast(&token_len), sizeof(token_len)); + in.read(reinterpret_cast(&mel_frames), sizeof(mel_frames)); + in.read(reinterpret_cast(&mel_dim), sizeof(mel_dim)); + in.read(reinterpret_cast(&spk_dim), sizeof(spk_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache meta read failed: {}", file_path); } + if (token_len <= 0 || mel_frames <= 0 || mel_dim <= 0 || spk_dim <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt cache has invalid shape metadata: {}", file_path); + } + + out.prompt_tokens.resize(static_cast(token_len)); + 
in.read(reinterpret_cast(out.prompt_tokens.data()), sizeof(int32_t) * static_cast(token_len)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt token section read failed: {}", file_path); } + + out.prompt_mels = Tensor::empty({1, mel_frames, mel_dim}, kFloat32, kCPU).alloc(); + in.read(reinterpret_cast(out.prompt_mels.ptr()), + sizeof(float) * static_cast(mel_frames) * static_cast(mel_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt mel section read failed: {}", file_path); } + + out.spk_emb = Tensor::empty({1, spk_dim}, kFloat32, kCPU).alloc(); + in.read(reinterpret_cast(out.spk_emb.ptr()), sizeof(float) * static_cast(spk_dim)); + if (!in.good()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Prompt speaker embedding section read failed: {}", file_path); } + + return out; +} + +} // namespace mllm::models::minicpm_o45 + diff --git a/mllm/models/minicpm_o45/token2wav_weight_norm.hpp b/mllm/models/minicpm_o45/token2wav_weight_norm.hpp new file mode 100644 index 000000000..cfc7fdd9a --- /dev/null +++ b/mllm/models/minicpm_o45/token2wav_weight_norm.hpp @@ -0,0 +1,80 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include + +#include "mllm/core/ParameterFile.hpp" +#include "mllm/core/Tensor.hpp" +#include "mllm/utils/Log.hpp" + +namespace mllm::models::minicpm_o45 { + +inline bool _endsWith(const std::string& s, const std::string& suffix) { + return s.size() >= suffix.size() && s.compare(s.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +inline Tensor _materializeWeightNorm(Tensor g_in, Tensor v_in) { + auto g = g_in.dtype() == kFloat32 ? g_in.contiguous() : g_in.to(kFloat32).contiguous(); + auto v = v_in.dtype() == kFloat32 ? 
v_in.contiguous() : v_in.to(kFloat32).contiguous(); + + const int64_t out_dim = static_cast(g.numel()); + const int64_t total = static_cast(v.numel()); + if (out_dim <= 0 || total <= 0 || (total % out_dim) != 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, + "Invalid weight-norm tensor shape: g.numel()={}, v.numel()={}", + out_dim, total); + } + + const int64_t row = total / out_dim; + auto w = Tensor::empty({static_cast(total)}, kFloat32, kCPU).alloc(); + auto* g_ptr = g.ptr(); + auto* v_ptr = v.ptr(); + auto* w_ptr = w.ptr(); + + constexpr float kEps = 1e-12f; + for (int64_t i = 0; i < out_dim; ++i) { + const int64_t base = i * row; + float norm = 0.0f; + for (int64_t j = 0; j < row; ++j) { + const float val = v_ptr[base + j]; + norm += val * val; + } + norm = std::sqrt(std::max(norm, kEps)); + const float scale = g_ptr[i] / norm; + for (int64_t j = 0; j < row; ++j) { w_ptr[base + j] = v_ptr[base + j] * scale; } + } + return w; +} + +inline int32_t materializeWeightNormParameters(const ParameterFile::ptr_t& param_file, const std::string& scope_prefix) { + std::vector keys; + keys.reserve(param_file->dict().size()); + for (const auto& kv : param_file->dict()) { keys.push_back(kv.first); } + + const std::string marker = ".parametrizations.weight.original0"; + int32_t count = 0; + for (const auto& key : keys) { + if (!_endsWith(key, marker)) { continue; } + if (!scope_prefix.empty() && key.rfind(scope_prefix, 0) != 0) { continue; } + + const auto prefix = key.substr(0, key.size() - marker.size()); + const auto key_g = prefix + ".parametrizations.weight.original0"; + const auto key_v = prefix + ".parametrizations.weight.original1"; + const auto key_w = prefix + ".weight"; + + if (param_file->has(key_w)) { continue; } + if (!param_file->has(key_g) || !param_file->has(key_v)) { continue; } + + auto weight = _materializeWeightNorm(param_file->pull(key_g), param_file->pull(key_v)); + param_file->push(key_w, weight); + ++count; + } + return count; +} + +} // namespace 
mllm::models::minicpm_o45 diff --git a/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp b/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp new file mode 100644 index 000000000..6e68c4bbc --- /dev/null +++ b/mllm/models/minicpm_o45/tokenization_minicpm_o45.hpp @@ -0,0 +1,594 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/core/DataTypes.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/minicpm_o2_6/audio_preprocessor_minicpmo.hpp" +#include "mllm/models/minicpm_o2_6/image_preprocessor_minicpmo.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" + +namespace mllm::models::minicpm_o45 { + +// Same tokenizer splitting rules as Qwen2/Qwen3 family. +inline bool miniCPMO45TokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } + + if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if 
(preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && !preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } + + pos = original_pos; + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } + pos = start; + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool miniCPMO45Regex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (miniCPMO45TokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct MiniCPMO45Message { + std::string prompt; + std::string img_file_path; + std::string audio_file_path; + std::string ref_audio_file_path; + std::string system_prompt = 
+ "You are a helpful assistant. You can accept video, audio and text input and output voice and text."; + std::string ref_audio_prompt_prefix = "Clone the voice in the provided audio prompt."; + std::string ref_audio_prompt_suffix = "As an assistant, you will speak using this voice style."; + + [[nodiscard]] std::string buildChatMessage(bool generate_audio = false) const { + std::string result; + if (!ref_audio_file_path.empty()) { + result += "<|im_start|>system\n"; + if (!ref_audio_prompt_prefix.empty()) { result += ref_audio_prompt_prefix + "\n"; } + result += ""; + if (!ref_audio_prompt_suffix.empty()) { result += "\n" + ref_audio_prompt_suffix; } + result += "<|im_end|>\n"; + } else if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + + result += "<|im_start|>user\n"; + if (!img_file_path.empty()) { result += "./"; } + if (!audio_file_path.empty()) { result += ""; } + + if (!prompt.empty()) { + if (!img_file_path.empty() || !audio_file_path.empty()) { result += "\n"; } + result += prompt; + } + + result += "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + + if (generate_audio) { result += "\n\n\n\n<|tts_bos|>"; } + return result; + } +}; + +class MiniCPMO45Tokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit MiniCPMO45Tokenizer(const std::string& tokenizer_path, int32_t patch_size = 14, int32_t audio_pool_step = 5) + : image_preprocessor_(patch_size), + audio_preprocessor_(16000, 80, 160), + audio_pool_step_(audio_pool_step) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + + bpe_.initFromSentencePieceJson(tokenizer_path); + + const std::vector special_tokens = { + L"", + L"<|endoftext|>", + L"<|im_start|>", + L"<|im_end|>", + L"<|object_ref_start|>", + L"<|object_ref_end|>", + L"<|box_start|>", + L"<|box_end|>", + 
L"<|quad_start|>", + L"<|quad_end|>", + L"<|vision_start|>", + L"<|vision_end|>", + L"<|vision_pad|>", + L"<|image_pad|>", + L"<|video_pad|>", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"", + L"<|audio_start|>", + L"<|audio|>", + L"<|audio_end|>", + L"<|spk_bos|>", + L"<|spk|>", + L"<|spk_eos|>", + L"<|tts_bos|>", + L"<|tts_eos|>", + L"<|listen|>", + L"<|speak|>", + L"<|interrupt|>", + L"<|vad_start|>", + L"<|vad_end|>", + L"<|chunk_eos|>", + L"<|chunk_bos|>", + L"<|chunk_tts_bos|>", + L"<|chunk_tts_eos|>", + }; + + for (const auto& token : special_tokens) { addSpecialToken(token); } + loadSpecialTokensFromTokenizerJson(tokenizer_path); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::minicpm_o45::miniCPMO45Regex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_tokens = bpe_._bpe(mapped_str); + for (const auto& bpe_token : bpe_tokens) { ret.push_back(bpe_token); } + } + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for (const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string 
utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back(static_cast(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("minicpmo45-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + return ret; + } + + int64_t lookupTokenId(const std::wstring& token) { return bpe_._lookup_vocab(token); } + + ARGenerationOutputPast convertMessage(const MiniCPMO45Message& message, bool generate_audio_prompt = false) { + bool has_image = !message.img_file_path.empty(); + bool has_ref_audio = !message.ref_audio_file_path.empty(); + bool has_user_audio = !message.audio_file_path.empty(); + bool has_audio = has_ref_audio || has_user_audio; + + auto applied_string = message.buildChatMessage(generate_audio_prompt); + + std::vector img_tensors; + std::vector> original_sizes; + std::vector> tgt_sizes; + std::vector grid; + + Tensor audio_features = Tensor::nil(); + std::vector audio_lengths; + std::vector audio_feature_list; + + if (has_image) { + auto [tensors, orig_size, target_sizes, img_grid] = image_preprocessor_.process(message.img_file_path); + img_tensors = std::move(tensors); + original_sizes = std::move(orig_size); + tgt_sizes = std::move(target_sizes); + grid = std::move(img_grid); + } + + if (has_ref_audio) { + auto audio_data = mllm::audio::readWAV(message.ref_audio_file_path, 16000); + auto audio_length = static_cast(audio_data.size()); + if (audio_length > 0) { + audio_lengths.push_back(audio_length); + auto ref_audio_features = audio_preprocessor_.processAudioData(audio_data.data(), audio_length); + if (!ref_audio_features.isNil()) { 
audio_feature_list.push_back(ref_audio_features); } + } + } + + if (has_user_audio) { + auto audio_data = mllm::audio::readWAV(message.audio_file_path, 16000); + auto audio_length = static_cast(audio_data.size()); + if (audio_length > 0) { + audio_lengths.push_back(audio_length); + auto user_audio_features = audio_preprocessor_.processAudioData(audio_data.data(), audio_length); + if (!user_audio_features.isNil()) { audio_feature_list.push_back(user_audio_features); } + } + } + + if (!audio_feature_list.empty()) { + int32_t batch = static_cast(audio_feature_list.size()); + auto channels = audio_feature_list[0].shape()[1]; + auto frames = audio_feature_list[0].shape()[2]; + audio_features = Tensor::empty({batch, channels, frames}, kFloat32, kCPU) + .setMemType(kExtraInput) + .setName("audio_features") + .alloc(); + auto* dst = audio_features.ptr(); + auto single_size = static_cast(channels) * static_cast(frames); + for (int32_t i = 0; i < batch; ++i) { + std::memcpy(dst + static_cast(i) * single_size, audio_feature_list[i].ptr(), + single_size * sizeof(float)); + } + } + + if (has_image) { + std::regex img_pattern(R"(\./)"); + std::vector image_tags; + std::sregex_iterator iter(applied_string.begin(), applied_string.end(), img_pattern); + std::sregex_iterator end; + + for (; iter != end; ++iter) { image_tags.push_back(iter->str()); } + + std::vector text_chunks; + int32_t pos = 0; + for (const auto& tag : image_tags) { + auto found = applied_string.find(tag, pos); + if (found != std::string::npos) { + text_chunks.push_back(applied_string.substr(pos, found - pos)); + pos = static_cast(found + tag.size()); + } + } + text_chunks.push_back(applied_string.substr(pos)); + + std::string final_text; + for (size_t i = 0; i < image_tags.size(); ++i) { + final_text += text_chunks[i]; + final_text += image_preprocessor_.get_slice_image_placeholder(original_sizes[i], grid, static_cast(i)); + } + final_text += text_chunks.back(); + applied_string = final_text; + } + + if 
(has_audio) { + size_t search_pos = 0; + for (auto audio_length : audio_lengths) { + auto audio_placeholder = getAudioPlaceholder(audio_length, false); + auto audio_placeholder_pos = applied_string.find("", search_pos); + if (audio_placeholder_pos == std::string::npos) { break; } + applied_string.replace(audio_placeholder_pos, std::string("").size(), audio_placeholder); + search_pos = audio_placeholder_pos + audio_placeholder.size(); + } + } + + auto sequence_str = tokenize(applied_string); + std::vector input_ids_vec; + input_ids_vec.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { input_ids_vec.emplace_back(bpe_._lookup_vocab(str)); } + + std::vector>image_bounds; + std::vector> audio_bounds; + + if (has_image) { + auto [_, bounds] = image_preprocessor_.calc_bounds(input_ids_vec, bpe_); + image_bounds = std::move(bounds); + } + + if (has_audio) { + int64_t audio_start_id = bpe_._lookup_vocab(L"<|audio_start|>"); + int64_t audio_end_id = bpe_._lookup_vocab(L"<|audio_end|>"); + audio_bounds = audio_preprocessor_.calcAudioBounds(input_ids_vec, audio_start_id, audio_end_id); + } + + return convertToTensors(input_ids_vec, img_tensors, tgt_sizes, image_bounds, audio_features, audio_bounds); + } + + private: + void addSpecialToken(const std::wstring& token) { + if (!token.empty()) { special_tokens_trie_.add(token); } + } + + void loadSpecialTokensFromTokenizerJson(const std::string& tokenizer_path) { + std::ifstream in(tokenizer_path); + if (!in.is_open()) { return; } + + nlohmann::json json_data; + try { + json_data = nlohmann::json::parse(in); + } catch (...) 
{ + return; + } + + if (!json_data.contains("added_tokens") || !json_data["added_tokens"].is_array()) { return; } + for (const auto& token_info : json_data["added_tokens"]) { + if (!token_info.contains("content")) { continue; } + addSpecialToken(preprocessor::utf8string2WideString(token_info["content"].get())); + } + } + + [[nodiscard]] std::string getAudioPlaceholder(int32_t audio_length, bool chunk_input, float chunk_length = 1.0f) const { + int32_t capped_audio_length = std::min(audio_length, max_audio_samples_); + int32_t feature_lens = static_cast(std::ceil(static_cast(capped_audio_length) / hop_length_)); + feature_lens = (feature_lens - 1) / 2 + 1; + + auto output_lens = (feature_lens - audio_pool_step_) / audio_pool_step_ + 1; + output_lens = std::max(output_lens, 0); + + if (!chunk_input) { + std::string audio_placeholder = "<|audio_start|>"; + for (int32_t i = 0; i < output_lens; ++i) { audio_placeholder += ""; } + audio_placeholder += "<|audio_end|>"; + return audio_placeholder; + } + + auto fbank_feat_in_chunk = static_cast(chunk_length * 100); + auto cnn_feat_in_chunk = (fbank_feat_in_chunk - 1) / 2 + 1; + auto audio_embeds_in_chunk = (cnn_feat_in_chunk - audio_pool_step_) / audio_pool_step_ + 1; + audio_embeds_in_chunk = std::max(audio_embeds_in_chunk, 1); + + auto num_audio_chunks = (output_lens + audio_embeds_in_chunk - 1) / audio_embeds_in_chunk; + + std::string placeholders; + int32_t total_unk_len = 0; + for (int32_t i = 0; i < num_audio_chunks; ++i) { + auto unk_len = std::min(audio_embeds_in_chunk, output_lens - total_unk_len); + placeholders += "<|audio_start|>"; + for (int32_t j = 0; j < unk_len; ++j) { placeholders += ""; } + placeholders += "<|audio_end|>"; + total_unk_len += unk_len; + } + return placeholders; + } + + ARGenerationOutputPast convertToTensors(const std::vector& input_ids_vec, std::vector& img_tensors, + const std::vector>& tgt_sizes, + const std::vector>& image_bounds, const Tensor& audio_features, + const std::vector>& 
audio_bounds) { + ARGenerationOutputPast result; + + if (!input_ids_vec.empty()) { + auto input_ids_tensor = Tensor::empty({1, static_cast(input_ids_vec.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("input_ids") + .alloc(); + auto* input_ids_ptr = input_ids_tensor.ptr(); + for (size_t i = 0; i < input_ids_vec.size(); ++i) { input_ids_ptr[i] = input_ids_vec[i]; } + result["input_ids"] = input_ids_tensor; + } + + if (!img_tensors.empty()) { + int32_t channels = img_tensors[0].shape()[0]; + int32_t patch_size = img_tensors[0].shape()[1]; + int32_t hw_patch_size = img_tensors[0].shape()[2]; + for (const auto& img_tensor : img_tensors) { + if (img_tensor.shape()[2] > hw_patch_size) { hw_patch_size = img_tensor.shape()[2]; } + } + + auto pixel_values = Tensor::empty({static_cast(img_tensors.size()), channels, patch_size, hw_patch_size}, kFloat32, + kCPU) + .setMemType(kExtraInput) + .setName("pixel_values") + .alloc(); + auto* pixel_values_ptr = pixel_values.ptr(); + std::memset(pixel_values_ptr, 0, static_cast(img_tensors.size()) * channels * patch_size * hw_patch_size * sizeof(float_t)); + + for (int32_t b = 0; b < static_cast(img_tensors.size()); ++b) { + int32_t src_hw = img_tensors[b].shape()[2]; + const auto* src_ptr = img_tensors[b].ptr(); + + for (int32_t c = 0; c < channels; ++c) { + for (int32_t p = 0; p < patch_size; ++p) { + int32_t src_offset = c * patch_size * src_hw + p * src_hw; + int32_t dst_offset = b * channels * patch_size * hw_patch_size + c * patch_size * hw_patch_size + p * hw_patch_size; + std::memcpy(pixel_values_ptr + dst_offset, src_ptr + src_offset, src_hw * sizeof(float_t)); + } + } + } + + result["pixel_values"] = pixel_values; + } + + if (!tgt_sizes.empty()) { + auto tgt_sizes_tensor = Tensor::empty({static_cast(tgt_sizes.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("tgt_sizes") + .alloc(); + auto* tgt_sizes_ptr = tgt_sizes_tensor.ptr(); + for (size_t i = 0; i < tgt_sizes.size(); ++i) { + tgt_sizes_ptr[i 
* 2] = tgt_sizes[i].first; + tgt_sizes_ptr[i * 2 + 1] = tgt_sizes[i].second; + } + result["tgt_sizes"] = tgt_sizes_tensor; + } + + if (!image_bounds.empty()) { + auto image_bounds_tensor = Tensor::empty({static_cast(image_bounds.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("image_bounds") + .alloc(); + auto* image_bounds_ptr = image_bounds_tensor.ptr(); + for (size_t i = 0; i < image_bounds.size(); ++i) { + image_bounds_ptr[i * 2] = image_bounds[i].first; + image_bounds_ptr[i * 2 + 1] = image_bounds[i].second; + } + result["image_bounds"] = image_bounds_tensor; + } + + if (!audio_features.isNil()) { result["audio_features"] = audio_features; } + + if (!audio_bounds.empty()) { + auto audio_bounds_tensor = Tensor::empty({static_cast(audio_bounds.size()), 2}, kInt32, kCPU) + .setMemType(kExtraInput) + .setName("audio_bounds") + .alloc(); + auto* audio_bounds_ptr = audio_bounds_tensor.ptr(); + for (size_t i = 0; i < audio_bounds.size(); ++i) { + audio_bounds_ptr[i * 2] = audio_bounds[i].first; + audio_bounds_ptr[i * 2 + 1] = audio_bounds[i].second; + } + result["audio_bounds"] = audio_bounds_tensor; + } + + return result; + } + + private: + minicpmo::MiniCPMOImageProcessor image_preprocessor_; + minicpmo::MiniCPMOAudioProcessor audio_preprocessor_; + int32_t audio_pool_step_ = 5; + int32_t hop_length_ = 160; + int32_t max_audio_samples_ = 30 * 16000; + + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; +}; + +} // namespace mllm::models::minicpm_o45 diff --git a/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp new file mode 100644 index 000000000..392bfc17b --- /dev/null +++ b/mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp @@ -0,0 +1,240 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Tensor.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/preprocessor/audio/Audio.hpp" + +namespace mllm::models::qwen2_5omni { + +inline float hertz_to_mel_slaney(float freq) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = 27.0f / std::log(6.4f); + + if (freq < kMinLogHertz) { + return 3.0f * freq / 200.0f; + } + return kMinLogMel + std::log(freq / kMinLogHertz) * logstep; +} + +inline float mel_to_hertz_slaney(float mel) { + constexpr float kMinLogHertz = 1000.0f; + constexpr float kMinLogMel = 15.0f; + const float logstep = std::log(6.4f) / 27.0f; + + if (mel < kMinLogMel) { + return 200.0f * mel / 3.0f; + } + return kMinLogHertz * std::exp(logstep * (mel - kMinLogMel)); +} + +inline Tensor create_hann_window(int32_t window_length, bool periodic = true) { + int32_t length = periodic ? window_length + 1 : window_length; + auto window = Tensor::empty({1, window_length}, kFloat32, kCPU).alloc(); + float* window_ptr = window.ptr(); + + for (int32_t i = 0; i < window_length; ++i) { + float n = static_cast(i); + float denominator = periodic ? 
static_cast(length) : static_cast(length - 1); + window_ptr[i] = 0.5f - 0.5f * std::cos(2.0f * M_PI * n / denominator); + } + + return window; +} + +inline Tensor create_mel_filterbank(int32_t num_frequency_bins, int32_t num_mel_filters, float min_frequency, + float max_frequency, int32_t sampling_rate) { + std::vector fft_freqs(num_frequency_bins); + for (int32_t i = 0; i < num_frequency_bins; ++i) { + fft_freqs[i] = static_cast(i) * (sampling_rate / 2.0f) / (num_frequency_bins - 1); + } + + float mel_min = hertz_to_mel_slaney(min_frequency); + float mel_max = hertz_to_mel_slaney(max_frequency); + + std::vector mel_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { + mel_freqs[i] = mel_min + static_cast(i) * (mel_max - mel_min) / (num_mel_filters + 1); + } + + std::vector filter_freqs(num_mel_filters + 2); + for (int32_t i = 0; i < num_mel_filters + 2; ++i) { filter_freqs[i] = mel_to_hertz_slaney(mel_freqs[i]); } + + auto mel_filters = Tensor::empty({num_frequency_bins, num_mel_filters}, kFloat32, kCPU).alloc(); + float* filters_ptr = mel_filters.ptr(); + std::fill_n(filters_ptr, num_frequency_bins * num_mel_filters, 0.0f); + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float left_freq = filter_freqs[mel_idx]; + float center_freq = filter_freqs[mel_idx + 1]; + float right_freq = filter_freqs[mel_idx + 2]; + + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + float freq = fft_freqs[freq_idx]; + float value = 0.0f; + + if (freq >= left_freq && freq <= center_freq && center_freq != left_freq) { + value = (freq - left_freq) / (center_freq - left_freq); + } else if (freq >= center_freq && freq <= right_freq && right_freq != center_freq) { + value = (right_freq - freq) / (right_freq - center_freq); + } + + filters_ptr[freq_idx * num_mel_filters + mel_idx] = value; + } + } + + for (int32_t mel_idx = 0; mel_idx < num_mel_filters; ++mel_idx) { + float enorm = 2.0f / (filter_freqs[mel_idx + 2] - 
filter_freqs[mel_idx]); + for (int32_t freq_idx = 0; freq_idx < num_frequency_bins; ++freq_idx) { + filters_ptr[freq_idx * num_mel_filters + mel_idx] *= enorm; + } + } + + return mel_filters; +} + +class MelSpectrogramFeatures final : public nn::Module { + int32_t n_fft_; + int32_t hop_length_; + int32_t win_length_; + int32_t n_mels_; + std::string padding_; + int power_; + nn::STFT stft_; + Tensor window_; + Tensor melscale_fbanks_; + + public: + MelSpectrogramFeatures() = default; + + explicit inline MelSpectrogramFeatures(const std::string& name, int32_t sample_rate = 16000, int32_t n_fft = 400, + int32_t hop_length = 160, int32_t n_mels = 128, + const std::string& padding = "center", int power = 2) + : nn::Module(name), n_fft_(n_fft), hop_length_(hop_length), n_mels_(n_mels), padding_(padding), power_(power) { + if (padding != "center" && padding != "same") { throw std::invalid_argument("Padding must be 'center' or 'same'."); } + + win_length_ = n_fft_; + stft_ = reg("stft", n_fft_, hop_length_, win_length_, true, true, "reflect", true); + window_ = create_hann_window(win_length_, true); + melscale_fbanks_ = create_mel_filterbank(n_fft_ / 2 + 1, n_mels_, 0.0f, 8000.0f, sample_rate); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto audio = inputs[0]; // [B, T] + + if (padding_ == "same") { + NYI("apply same padding in MelSpectrogramFeatures not implemented"); + } + + auto stft_result = stft_(audio, window_); + auto specgram = stft_result.abs(); + if (power_ == 2) { + specgram = specgram * specgram; + } else if (power_ != 1) { + NYI("power != 1 and power != 2 not implemented"); + } + + auto mel_specgram = nn::functional::matmul(specgram.T(), melscale_fbanks_).T(); + mel_specgram = nn::functional::clip(mel_specgram, 1e-10f, std::numeric_limits::max()); + mel_specgram = nn::functional::log(mel_specgram) / std::log(10.0f); + auto max_val = mel_specgram.max(); + float threshold = max_val.item() - 8.0f; + mel_specgram 
= nn::functional::clip(mel_specgram, threshold, std::numeric_limits::max()); + mel_specgram = (mel_specgram + 4.0f) / 4.0f; + + return {mel_specgram}; + } +}; + +struct Qwen2_5OmniAudioFeatures { + Tensor input_features = Tensor::nil(); + int32_t feature_length = 0; +}; + +class Qwen2_5OmniAudioPreprocessor { + MelSpectrogramFeatures mel_extractor_; + int32_t sample_rate_; + int32_t n_mels_; + int32_t hop_length_; + int32_t chunk_length_; + int32_t n_samples_; + + public: + explicit Qwen2_5OmniAudioPreprocessor(int32_t sample_rate = 16000, int32_t n_mels = 128, int32_t hop_length = 160, + int32_t chunk_length = 300) + : mel_extractor_("feature_extractor.mel_spec", sample_rate, 400, hop_length, n_mels, "center", 2), + sample_rate_(sample_rate), + n_mels_(n_mels), + hop_length_(hop_length), + chunk_length_(chunk_length), + n_samples_(chunk_length * sample_rate) {} + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioFile(const std::string& audio_file_path) { + auto audio_data = mllm::audio::readWAV(audio_file_path, sample_rate_); + if (audio_data.empty()) { return {}; } + return processAudioData(audio_data.data(), static_cast(audio_data.size())); + } + + [[nodiscard]] Qwen2_5OmniAudioFeatures processAudioData(const float* audio_data, int32_t audio_length) { + Qwen2_5OmniAudioFeatures result; + if (audio_data == nullptr || audio_length <= 0) { return result; } + + int32_t padded_length = n_samples_; + int32_t effective_length = std::min(audio_length, padded_length); + + auto audio_tensor = Tensor::empty({1, padded_length}, kFloat32, kCPU).alloc(); + float* audio_ptr = audio_tensor.ptr(); + + if (audio_length <= padded_length) { + std::memcpy(audio_ptr, audio_data, audio_length * sizeof(float)); + std::fill(audio_ptr + audio_length, audio_ptr + padded_length, 0.0f); + } else { + std::memcpy(audio_ptr, audio_data, padded_length * sizeof(float)); + } + + auto mel_spec = mel_extractor_.forward({audio_tensor}, {})[0]; // [1, n_mels, n_frames] + + int32_t valid_frames = 
calcFeatureLength(effective_length); + int32_t max_frames = mel_spec.shape()[2]; + if (valid_frames > max_frames) { valid_frames = max_frames; } + if (valid_frames <= 0) { return result; } + + auto trimmed = Tensor::empty({1, n_mels_, valid_frames}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < n_mels_; ++m) { + auto src_ptr = mel_spec.offsettedPtr({0, m, 0}); + auto dst_ptr = trimmed.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, valid_frames * sizeof(float)); + } + + result.input_features = trimmed; + result.feature_length = valid_frames; + return result; + } + + [[nodiscard]] int32_t calcFeatureLength(int32_t audio_length) const { + if (audio_length <= 0) { return 0; } + return (audio_length + hop_length_ - 1) / hop_length_; + } + + [[nodiscard]] int32_t calcAudioTokenLength(int32_t feature_length) const { + if (feature_length <= 0) { return 0; } + int32_t after_conv = (feature_length - 1) / 2 + 1; + if (after_conv < 2) { return 0; } + int32_t after_pool = (after_conv - 2) / 2 + 1; + return std::max(0, after_pool); + } +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp new file mode 100644 index 000000000..496ff6996 --- /dev/null +++ b/mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp @@ -0,0 +1,382 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include + +#include "mllm/core/aops/LinearOp.hpp" +#include "mllm/engine/ConfigFile.hpp" + +namespace mllm::models::qwen2_5omni { + +struct Qwen2_5OmniTalkerConfig { + Qwen2_5OmniTalkerConfig() = default; + + explicit Qwen2_5OmniTalkerConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + audio_token_id = root.value("audio_token_index", audio_token_id); + image_token_id = root.value("image_token_index", image_token_id); + video_token_id = root.value("video_token_index", video_token_id); + + vocab_size = root.value("vocab_size", vocab_size); + tts_text_start_token_id = root.value("tts_text_start_token_id", tts_text_start_token_id); + tts_text_end_token_id = root.value("tts_text_end_token_id", tts_text_end_token_id); + tts_text_pad_token_id = root.value("tts_text_pad_token_id", tts_text_pad_token_id); + tts_codec_start_token_id = root.value("tts_codec_start_token_id", tts_codec_start_token_id); + tts_codec_end_token_id = root.value("tts_codec_end_token_id", tts_codec_end_token_id); + tts_codec_pad_token_id = root.value("tts_codec_pad_token_id", tts_codec_pad_token_id); + tts_codec_mask_token_id = root.value("tts_codec_mask_token_id", tts_codec_mask_token_id); + + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + + embedding_size = root.value("embedding_size", embedding_size); + hidden_size = root.value("hidden_size", hidden_size); + intermediate_size = root.value("intermediate_size", intermediate_size); + num_hidden_layers = root.value("num_hidden_layers", num_hidden_layers); + num_attention_heads = root.value("num_attention_heads", num_attention_heads); + num_key_value_heads = root.value("num_key_value_heads", 
num_key_value_heads); + head_dim = root.value("head_dim", head_dim); + max_position_embeddings = root.value("max_position_embeddings", max_position_embeddings); + rms_norm_eps = root.value("rms_norm_eps", rms_norm_eps); + rope_theta = root.value("rope_theta", rope_theta); + use_sliding_window = root.value("use_sliding_window", use_sliding_window); + sliding_window = root.value("sliding_window", sliding_window); + max_window_layers = root.value("max_window_layers", max_window_layers); + attention_dropout = root.value("attention_dropout", attention_dropout); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); + spatial_merge_size = root.value("spatial_merge_size", spatial_merge_size); + + if (root.contains("rope_scaling") && root["rope_scaling"].contains("mrope_section")) { + mrope_section = root["rope_scaling"]["mrope_section"].get>(); + } + } + + int64_t audio_token_id = 151646; + int64_t image_token_id = 151655; + int64_t video_token_id = 151656; + + int32_t vocab_size = 8448; + int64_t tts_text_start_token_id = 151860; + int64_t tts_text_end_token_id = 151861; + int64_t tts_text_pad_token_id = 151859; + int64_t tts_codec_start_token_id = 8293; + int64_t tts_codec_end_token_id = 8294; + int64_t tts_codec_pad_token_id = 8292; + int64_t tts_codec_mask_token_id = 8296; + + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + + int32_t embedding_size = 3584; + int32_t hidden_size = 896; + int32_t intermediate_size = 18944; + int32_t num_hidden_layers = 24; + int32_t num_attention_heads = 12; + int32_t num_key_value_heads = 4; + int32_t head_dim = 128; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + float rope_theta = 1000000.0f; + bool use_sliding_window = false; + int32_t sliding_window = 32768; + int32_t 
max_window_layers = 28; + float attention_dropout = 0.0f; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; + int32_t spatial_merge_size = 2; + std::vector mrope_section = {16, 24, 24}; +}; + +struct Qwen2_5OmniDiTConfig { + Qwen2_5OmniDiTConfig() = default; + + explicit Qwen2_5OmniDiTConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + hidden_size = root.value("dim", hidden_size); + num_hidden_layers = root.value("depth", num_hidden_layers); + num_attention_heads = root.value("heads", num_attention_heads); + ff_mult = root.value("ff_mult", ff_mult); + emb_dim = root.value("emb_dim", emb_dim); + head_dim = root.value("head_dim", head_dim); + repeats = root.value("repeats", repeats); + num_embeds = root.value("num_embeds", num_embeds); + mel_dim = root.value("mel_dim", mel_dim); + dropout = root.value("dropout", dropout); + + max_position_embeddings = root.value("max_position_embeddings", max_position_embeddings); + block_size = root.value("block_size", block_size); + if (root.contains("look_ahead_layers")) { look_ahead_layers = root["look_ahead_layers"].get>(); } + if (root.contains("look_backward_layers")) { look_backward_layers = root["look_backward_layers"].get>(); } + rope_theta = root.value("rope_theta", rope_theta); + rope_type = root.value("rope_type", rope_type); + if (root.contains("rope_parameters")) { + const auto& rope_params = root["rope_parameters"]; + rope_theta = rope_params.value("rope_theta", rope_theta); + rope_type = rope_params.value("rope_type", rope_type); + } + + enc_emb_dim = root.value("enc_emb_dim", enc_emb_dim); + enc_dim = root.value("enc_dim", enc_dim); + if (root.contains("enc_channels")) { enc_channels = root["enc_channels"].get>(); } + if (root.contains("enc_kernel_sizes")) { enc_kernel_sizes = root["enc_kernel_sizes"].get>(); } + if (root.contains("enc_dilations")) { enc_dilations = root["enc_dilations"].get>(); } + enc_attention_channels = 
root.value("enc_attention_channels", enc_attention_channels); + enc_res2net_scale = root.value("enc_res2net_scale", enc_res2net_scale); + enc_se_channels = root.value("enc_se_channels", enc_se_channels); + } + + int32_t hidden_size = 1024; + int32_t num_hidden_layers = 22; + int32_t num_attention_heads = 16; + int32_t ff_mult = 2; + int32_t emb_dim = 512; + int32_t head_dim = 64; + int32_t max_position_embeddings = 32768; + int32_t block_size = 24; + std::vector look_ahead_layers = {10}; + std::vector look_backward_layers = {0, 20}; + int32_t repeats = 2; + int32_t num_embeds = 8193; + int32_t mel_dim = 80; + float dropout = 0.1f; + + int32_t enc_emb_dim = 192; + int32_t enc_dim = 128; + std::vector enc_channels = {256, 256, 256, 256, 768}; + std::vector enc_kernel_sizes = {5, 3, 3, 3, 1}; + std::vector enc_dilations = {1, 2, 3, 4, 1}; + int32_t enc_attention_channels = 64; + int32_t enc_res2net_scale = 2; + int32_t enc_se_channels = 64; + + float rope_theta = 10000.0f; + std::string rope_type = "default"; +}; + +struct Qwen2_5OmniBigVGANConfig { + Qwen2_5OmniBigVGANConfig() = default; + + explicit Qwen2_5OmniBigVGANConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + mel_dim = root.value("mel_dim", mel_dim); + upsample_initial_channel = root.value("upsample_initial_channel", upsample_initial_channel); + if (root.contains("resblock_kernel_sizes")) { + resblock_kernel_sizes = root["resblock_kernel_sizes"].get>(); + } + if (root.contains("resblock_dilation_sizes")) { + resblock_dilation_sizes = root["resblock_dilation_sizes"].get>>(); + } + if (root.contains("upsample_rates")) { upsample_rates = root["upsample_rates"].get>(); } + if (root.contains("upsample_kernel_sizes")) { + upsample_kernel_sizes = root["upsample_kernel_sizes"].get>(); + } + } + + int32_t mel_dim = 80; + int32_t upsample_initial_channel = 1536; + std::vector resblock_kernel_sizes = {3, 7, 11}; + std::vector> resblock_dilation_sizes = {{1, 3, 5}, {1, 3, 
5}, {1, 3, 5}}; + std::vector upsample_rates = {5, 3, 2, 2, 2, 2}; + std::vector upsample_kernel_sizes = {11, 7, 4, 4, 4, 4}; +}; + +struct Qwen2_5OmniToken2WavConfig { + Qwen2_5OmniToken2WavConfig() = default; + + explicit Qwen2_5OmniToken2WavConfig(const nlohmann::json& root) { parse(root); } + + void parse(const nlohmann::json& root) { + if (root.contains("dit_config")) { dit_config.parse(root["dit_config"]); } + if (root.contains("bigvgan_config")) { bigvgan_config.parse(root["bigvgan_config"]); } + } + + Qwen2_5OmniDiTConfig dit_config{}; + Qwen2_5OmniBigVGANConfig bigvgan_config{}; +}; + +struct Qwen2_5OmniConfig : protected ConfigFile { + Qwen2_5OmniConfig() = default; + + explicit Qwen2_5OmniConfig(const std::string& file_path) : ConfigFile(file_path) { + auto& root = data(); + enable_audio_output = root.value("enable_audio_output", root.value("enable_talker", enable_audio_output)); + + if (root.contains("talker_config")) { talker_cfg.parse(root["talker_config"]); } + if (root.contains("token2wav_config")) { token2wav_cfg.parse(root["token2wav_config"]); } + + if (root.contains("thinker_config")) { + auto& thinker_cfg = root["thinker_config"]; + auto& text_cfg = thinker_cfg["text_config"]; + + hidden_size = text_cfg["hidden_size"]; + intermediate_size = text_cfg["intermediate_size"]; + num_attention_heads = text_cfg["num_attention_heads"]; + num_key_value_heads = text_cfg["num_key_value_heads"]; + num_hidden_layers = text_cfg["num_hidden_layers"]; + max_position_embeddings = text_cfg["max_position_embeddings"]; + rms_norm_eps = text_cfg["rms_norm_eps"]; + vocab_size = text_cfg["vocab_size"]; + rope_theta = text_cfg["rope_theta"]; + tie_word_embeddings = text_cfg.value("tie_word_embeddings", false); + + if (text_cfg.contains("rope_scaling") && text_cfg["rope_scaling"].contains("mrope_section")) { + mrope_section = text_cfg["rope_scaling"]["mrope_section"].get>(); + } + + if (thinker_cfg.contains("vision_config")) { + auto& vision_cfg = 
thinker_cfg["vision_config"]; + visual_in_chans = vision_cfg.value("in_channels", vision_cfg.value("in_chans", visual_in_chans)); + visual_hidden_size = vision_cfg.value("hidden_size", vision_cfg.value("embed_dim", visual_hidden_size)); + visual_patch_size = vision_cfg.value("patch_size", vision_cfg.value("spatial_patch_size", visual_patch_size)); + visual_temporal_patch_size = vision_cfg.value("temporal_patch_size", visual_temporal_patch_size); + visual_spatial_merge_size = vision_cfg.value("spatial_merge_size", visual_spatial_merge_size); + visual_out_hidden_size = vision_cfg.value("out_hidden_size", visual_out_hidden_size); + visual_num_heads = vision_cfg.value("num_heads", visual_num_heads); + visual_depth = vision_cfg.value("depth", visual_depth); + visual_intermediate_size = vision_cfg.value("intermediate_size", visual_intermediate_size); + if (vision_cfg.contains("fullatt_block_indexes")) { + visual_fullatt_block_indexes = vision_cfg["fullatt_block_indexes"].get>(); + } + visual_window_size = vision_cfg.value("window_size", visual_window_size); + } + + if (thinker_cfg.contains("audio_config")) { + auto& audio_cfg = thinker_cfg["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions = audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + + bos_token_id = thinker_cfg.value("bos_token_id", bos_token_id); + eos_token_id = thinker_cfg.value("eos_token_id", eos_token_id); + pad_token_id = 
thinker_cfg.value("pad_token_id", pad_token_id); + image_token_id = thinker_cfg.value("image_token_index", image_token_id); + audio_token_id = thinker_cfg.value("audio_token_index", audio_token_id); + video_token_id = thinker_cfg.value("video_token_index", video_token_id); + audio_start_token_id = thinker_cfg.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = thinker_cfg.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = thinker_cfg.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = thinker_cfg.value("vision_end_token_id", vision_end_token_id); + vision_token_id = thinker_cfg.value("vision_token_id", vision_token_id); + position_id_per_seconds = thinker_cfg.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = thinker_cfg.value("seconds_per_chunk", seconds_per_chunk); + } else { + hidden_size = root["hidden_size"]; + intermediate_size = root["intermediate_size"]; + num_attention_heads = root["num_attention_heads"]; + num_key_value_heads = root["num_key_value_heads"]; + num_hidden_layers = root["num_hidden_layers"]; + max_position_embeddings = root["max_position_embeddings"]; + rms_norm_eps = root["rms_norm_eps"]; + vocab_size = root["vocab_size"]; + rope_theta = root["rope_theta"]; + tie_word_embeddings = root.value("tie_word_embeddings", tie_word_embeddings); + if (root.contains("mrope_section")) { + mrope_section = root["mrope_section"].get>(); + } + if (root.contains("audio_config")) { + auto& audio_cfg = root["audio_config"]; + audio_d_model = audio_cfg.value("d_model", audio_d_model); + audio_num_mel_bins = audio_cfg.value("num_mel_bins", audio_num_mel_bins); + audio_encoder_layers = audio_cfg.value("encoder_layers", audio_encoder_layers); + audio_encoder_attention_heads = audio_cfg.value("encoder_attention_heads", audio_encoder_attention_heads); + audio_encoder_ffn_dim = audio_cfg.value("encoder_ffn_dim", audio_encoder_ffn_dim); + audio_max_source_positions 
= audio_cfg.value("max_source_positions", audio_max_source_positions); + audio_n_window = audio_cfg.value("n_window", audio_n_window); + audio_output_dim = audio_cfg.value("output_dim", audio_output_dim); + } + bos_token_id = root.value("bos_token_id", bos_token_id); + eos_token_id = root.value("eos_token_id", eos_token_id); + pad_token_id = root.value("pad_token_id", pad_token_id); + image_token_id = root.value("image_token_id", image_token_id); + audio_token_id = root.value("audio_token_id", audio_token_id); + video_token_id = root.value("video_token_id", video_token_id); + audio_start_token_id = root.value("audio_start_token_id", audio_start_token_id); + audio_end_token_id = root.value("audio_end_token_id", audio_end_token_id); + vision_start_token_id = root.value("vision_start_token_id", vision_start_token_id); + vision_end_token_id = root.value("vision_end_token_id", vision_end_token_id); + vision_token_id = root.value("vision_token_id", vision_token_id); + position_id_per_seconds = root.value("position_id_per_seconds", position_id_per_seconds); + seconds_per_chunk = root.value("seconds_per_chunk", seconds_per_chunk); + } + + max_cache_length = root.value("max_cache_length", max_position_embeddings); + + if (root.contains("linear_impl_type")) { + linear_impl_type = aops::str2LinearImplTypes(root["linear_impl_type"]); + } + } + + int32_t hidden_size = 3584; + int32_t intermediate_size = 18944; + int32_t num_attention_heads = 28; + int32_t num_key_value_heads = 4; + int32_t num_hidden_layers = 28; + int32_t max_position_embeddings = 32768; + float rms_norm_eps = 1e-06f; + int32_t vocab_size = 152064; + std::vector mrope_section = {16, 24, 24}; + float rope_theta = 1000000.0f; + bool tie_word_embeddings = false; + + int32_t visual_in_chans = 3; + int32_t visual_hidden_size = 1280; + int32_t visual_patch_size = 14; + int32_t visual_temporal_patch_size = 2; + int32_t visual_spatial_merge_size = 2; + int32_t visual_out_hidden_size = 3584; + int32_t visual_num_heads 
= 16; + int32_t visual_depth = 32; + int32_t visual_intermediate_size = 3420; + std::vector visual_fullatt_block_indexes = {7, 15, 23, 31}; + int32_t visual_window_size = 112; + + int32_t audio_d_model = 1280; + int32_t audio_num_mel_bins = 128; + int32_t audio_encoder_layers = 32; + int32_t audio_encoder_attention_heads = 20; + int32_t audio_encoder_ffn_dim = 5120; + int32_t audio_max_source_positions = 1500; + int32_t audio_n_window = 100; + int32_t audio_output_dim = 3584; + + int32_t max_cache_length = 32768; + + int64_t bos_token_id = 151644; + int64_t eos_token_id = 151645; + int64_t pad_token_id = 151643; + int64_t image_token_id = 151655; + int64_t audio_token_id = 151646; + int64_t video_token_id = 151656; + int64_t audio_start_token_id = 151647; + int64_t audio_end_token_id = 151648; + int64_t vision_start_token_id = 151652; + int64_t vision_end_token_id = 151653; + int64_t vision_token_id = 151654; + int32_t position_id_per_seconds = 25; + int32_t seconds_per_chunk = 2; + + bool enable_audio_output = true; + Qwen2_5OmniTalkerConfig talker_cfg{}; + Qwen2_5OmniToken2WavConfig token2wav_cfg{}; + + aops::LinearImplTypes linear_impl_type = aops::LinearImplTypes::kDefault; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp new file mode 100644 index 000000000..42bae162f --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni.hpp @@ -0,0 +1,2036 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/mllm.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" +#include "mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp" +#include "mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp" + +namespace mllm::models::qwen2_5omni { + +inline auto makeMultimodalRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeMultimodalPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, int seq_len, int output_dim, + const std::vector& mrope_section) -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if 
(!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + + +inline auto makeWindowIndex(const Tensor& grid_thw, int window_size, int spatial_merge_size, + int patch_size) -> std::pair, std::vector> { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + const int grid_num = grid_thw.shape()[0]; + + const int vit_merger_window_size = window_size / spatial_merge_size / patch_size; + const int spatial_merge_unit = spatial_merge_size * spatial_merge_size; + + std::vector window_index; + std::vector cu_window_seqlens = {0}; + int window_index_id = 0; + + for (int grid_idx = 0; grid_idx < grid_num; ++grid_idx) { + const int grid_t = grid_thw.constAt({grid_idx, 0}); + const int grid_h = grid_thw.constAt({grid_idx, 1}); + const 
int grid_w = grid_thw.constAt({grid_idx, 2}); + + const int llm_grid_h = grid_h / spatial_merge_size; + const int llm_grid_w = grid_w / spatial_merge_size; + const int pad_h = (vit_merger_window_size - llm_grid_h % vit_merger_window_size) % vit_merger_window_size; + const int pad_w = (vit_merger_window_size - llm_grid_w % vit_merger_window_size) % vit_merger_window_size; + + const int num_windows_h = (llm_grid_h + pad_h) / vit_merger_window_size; + const int num_windows_w = (llm_grid_w + pad_w) / vit_merger_window_size; + const int total_windows = grid_t * num_windows_h * num_windows_w; + + std::vector>> index( + grid_t, std::vector>(llm_grid_h, std::vector(llm_grid_w))); + + int counter = 0; + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index[t][h][w] = counter++; } + } + } + + std::vector>> index_padded( + grid_t, std::vector>(llm_grid_h + pad_h, std::vector(llm_grid_w + pad_w, -100))); + + for (int t = 0; t < grid_t; t++) { + for (int h = 0; h < llm_grid_h; h++) { + for (int w = 0; w < llm_grid_w; w++) { index_padded[t][h][w] = index[t][h][w]; } + } + } + + std::vector seqlens(total_windows, 0); + for (int t = 0; t < grid_t; t++) { + for (int wh = 0; wh < num_windows_h; wh++) { + for (int ww = 0; ww < num_windows_w; ww++) { + const int window_idx = t * num_windows_h * num_windows_w + wh * num_windows_w + ww; + for (int h = 0; h < vit_merger_window_size; h++) { + for (int w = 0; w < vit_merger_window_size; w++) { + const int orig_h = wh * vit_merger_window_size + h; + const int orig_w = ww * vit_merger_window_size + w; + if (index_padded[t][orig_h][orig_w] != -100) { + window_index.push_back(index_padded[t][orig_h][orig_w] + window_index_id); + seqlens[window_idx]++; + } + } + } + } + } + } + + int cumulative = cu_window_seqlens.back(); + for (int i = 0; i < total_windows; i++) { + cumulative += seqlens[i] * spatial_merge_unit; + cu_window_seqlens.push_back(cumulative); + } + + 
window_index_id += grid_t * llm_grid_h * llm_grid_w; + } + + return {window_index, cu_window_seqlens}; +} + +inline auto makeVisualRoPEInvFreq(int32_t dims, float theta) -> Tensor { + const int half_dim = dims / (2 * 2); + Tensor inv_freq = Tensor::empty({half_dim}, kFloat32).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + const float dims_inv = 1.0f / static_cast(dims / 2); + for (int i = 0; i < half_dim; ++i) { + const float exponent = (2.0f * i) * dims_inv; + inv_freq_ptr[i] = 1.0f / std::pow(theta, exponent); + } + return inv_freq; +} + +inline auto makeVisualRotaryPosEmbIds(Tensor& grid_thw, int32_t spatial_merge_size) -> Tensor { + MLLM_RT_ASSERT_EQ(grid_thw.shape().size(), 2); + + const auto img_nums = grid_thw.shape()[0]; + int total_positions = 0; + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({total_positions, 2}, kInt32).alloc(); + int* out_ptr = out.ptr(); + int out_offset = 0; + + for (int row = 0; row < img_nums; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + const int t = dims[0]; + const int h = dims[1]; + const int w = dims[2]; + + const int num_h_blocks = h / spatial_merge_size; + const int num_w_blocks = w / spatial_merge_size; + const int total_blocks = num_h_blocks * num_w_blocks; + const int block_area = spatial_merge_size * spatial_merge_size; + const int grid_size = h * w; + + std::vector flatten_hpos(grid_size); + std::vector flatten_wpos(grid_size); + + for (int block_idx = 0; block_idx < total_blocks; ++block_idx) { + const int i_h = block_idx / num_w_blocks; + const int i_w = block_idx % num_w_blocks; + const int start_idx = block_idx * block_area; + + const int base_h = i_h * spatial_merge_size; + const int base_w = i_w * spatial_merge_size; + + for (int j_h = 0; j_h < spatial_merge_size; ++j_h) { + const int global_h = base_h + j_h; + for (int j_w = 0; j_w < spatial_merge_size; 
++j_w) { + const int global_w = base_w + j_w; + const int pos = start_idx + j_h * spatial_merge_size + j_w; + flatten_hpos[pos] = global_h; + flatten_wpos[pos] = global_w; + } + } + } + + for (int frame = 0; frame < t; ++frame) { + for (int pos = 0; pos < grid_size; ++pos) { + const int out_idx = out_offset + (frame * grid_size + pos) * 2; + out_ptr[out_idx] = flatten_hpos[pos]; + out_ptr[out_idx + 1] = flatten_wpos[pos]; + } + } + out_offset += t * grid_size * 2; + } + + return out; +} + +inline float kaiserBesselI0(float x) { + const float ax = std::fabs(x); + if (ax < 3.75f) { + const float y = (x / 3.75f) * (x / 3.75f); + return 1.0f + y * (3.5156229f + y * (3.0899424f + y * (1.2067492f + y * (0.2659732f + y * (0.0360768f + y * 0.0045813f))))); + } + const float y = 3.75f / ax; + return (std::exp(ax) / std::sqrt(ax)) * + (0.39894228f + y * (0.01328592f + y * (0.00225319f + y * (-0.00157565f + y * (0.00916281f + + y * (-0.02057706f + y * (0.02635537f + y * (-0.01647633f + y * 0.00392377f)))))))); +} + +inline Tensor kaiserSincFilter1d(float cutoff, float half_width, int32_t kernel_size) { + const bool is_even = (kernel_size % 2 == 0); + const int32_t half_size = kernel_size / 2; + + const float delta_f = 4.0f * half_width; + const float attenuation = 2.285f * (half_size - 1) * static_cast(M_PI) * delta_f + 7.95f; + + float beta = 0.0f; + if (attenuation > 50.0f) { + beta = 0.1102f * (attenuation - 8.7f); + } else if (attenuation >= 21.0f) { + beta = 0.5842f * std::pow(attenuation - 21.0f, 0.4f) + 0.07886f * (attenuation - 21.0f); + } + + std::vector window(kernel_size); + const float denom = kaiserBesselI0(beta); + for (int32_t n = 0; n < kernel_size; ++n) { + const float ratio = (kernel_size == 1) ? 0.0f : (2.0f * n) / (kernel_size - 1) - 1.0f; + const float val = beta * std::sqrt(std::max(0.0f, 1.0f - ratio * ratio)); + window[n] = (denom == 0.0f) ? 
0.0f : kaiserBesselI0(val) / denom; + } + + std::vector time_indices(kernel_size); + if (is_even) { + for (int32_t i = 0; i < kernel_size; ++i) { time_indices[i] = static_cast(i - half_size) + 0.5f; } + } else { + for (int32_t i = 0; i < kernel_size; ++i) { time_indices[i] = static_cast(i - half_size); } + } + + Tensor filter = Tensor::empty({1, 1, kernel_size}, kFloat32, kCPU).alloc(); + auto* filter_ptr = filter.ptr(); + + if (cutoff == 0.0f) { + std::fill(filter_ptr, filter_ptr + kernel_size, 0.0f); + return filter; + } + + float sum = 0.0f; + for (int32_t i = 0; i < kernel_size; ++i) { + const float x = 2.0f * cutoff * time_indices[i]; + const float sinc = (x == 0.0f) ? 1.0f : std::sin(static_cast(M_PI) * x) / (static_cast(M_PI) * x); + const float value = 2.0f * cutoff * window[i] * sinc; + filter_ptr[i] = value; + sum += value; + } + if (sum != 0.0f) { + for (int32_t i = 0; i < kernel_size; ++i) { filter_ptr[i] /= sum; } + } + + return filter; +} + +inline auto makeVisualRotaryPosEmbFull(Tensor& inv_freq, int seq_len) -> Tensor { + MLLM_RT_ASSERT(seq_len > 0); + const int32_t dim = inv_freq.shape()[0]; + Tensor freqs = Tensor::empty({seq_len, dim}, kFloat32, kCPU).alloc(); + float* inv_freq_ptr = inv_freq.ptr(); + float* freqs_ptr = freqs.ptr(); + for (int i = 0; i < seq_len; ++i) { + const float i_val = static_cast(i); + float* row_ptr = freqs_ptr + i * dim; + for (int j = 0; j < dim; ++j) { row_ptr[j] = i_val * inv_freq_ptr[j]; } + } + return freqs; +} + +inline auto makeVisualRotaryPosEmb(Tensor& rotary_pos_emb_full, Tensor& pos_ids, Tensor& grid_thw) -> Tensor { + const int32_t dim = rotary_pos_emb_full.shape()[1]; + const int32_t batch_size = pos_ids.shape()[0]; + const int32_t seq_len = pos_ids.shape()[1]; + + int total_positions = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + total_positions += dims[0] * dims[1] * dims[2]; + } + + Tensor out = Tensor::empty({batch_size, seq_len * dim}, 
kFloat32, kCPU).alloc(); + + auto rotary_pos_emb_full_ptr = rotary_pos_emb_full.ptr(); + auto pos_ids_ptr = pos_ids.ptr(); + + if (rotary_pos_emb_full.shape()[0] <= 0 || dim <= 0 || batch_size <= 0) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Invalid tensor dimensions"); + } + + if (total_positions != batch_size) { MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Grid dimensions mismatch with batch size"); } + + for (int i = 0; i < batch_size; ++i) { + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + if (idx < 0 || idx >= rotary_pos_emb_full.shape()[0]) { + MLLM_ERROR_EXIT(ExitCode::kSliceOB, "Position index out of bounds"); + } + } + } + + for (int i = 0; i < batch_size; ++i) { + auto batch_ptr = out.offsettedPtr({i, 0}); + size_t offset = 0; + for (int j = 0; j < seq_len; ++j) { + const int idx = pos_ids_ptr[i * seq_len + j]; + auto emb_ptr = rotary_pos_emb_full_ptr + idx * dim; + std::copy(emb_ptr, emb_ptr + dim, batch_ptr + offset); + offset += dim; + } + } + + return out; +} + +inline auto makeVisualRotarySinCos(Tensor& rotary_pos_emb) -> std::pair { + const auto seq = rotary_pos_emb.shape()[0]; + const auto dim = rotary_pos_emb.shape()[1]; + + auto rotary_pos_emb_ptr = rotary_pos_emb.ptr(); + + Tensor sin_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + Tensor cos_pos_emb = Tensor::empty({seq, dim}, kFloat32, kCPU).alloc(); + + auto sin_pos_emb_ptr = sin_pos_emb.ptr(); + auto cos_pos_emb_ptr = cos_pos_emb.ptr(); + + for (int i = 0; i < seq; i++) { + for (int j = 0; j < dim; j++) { + sin_pos_emb_ptr[i * dim + j] = std::sin(rotary_pos_emb_ptr[i * dim + j]); + cos_pos_emb_ptr[i * dim + j] = std::cos(rotary_pos_emb_ptr[i * dim + j]); + } + } + + return {sin_pos_emb, cos_pos_emb}; +} + +inline auto makeAudioSinusoidalPosEmb(int32_t length, int32_t channels, float max_timescale = 10000.0f) -> Tensor { + MLLM_RT_ASSERT(channels % 2 == 0); + auto pos_emb = Tensor::empty({length, channels}, kFloat32, kCPU).alloc(); + auto pos_ptr = 
pos_emb.ptr(); + + const int half = channels / 2; + const float log_timescale_increment = std::log(max_timescale) / static_cast(half - 1); + + std::vector inv_timescales(half); + for (int i = 0; i < half; ++i) { + inv_timescales[i] = std::exp(-log_timescale_increment * static_cast(i)); + } + + for (int t = 0; t < length; ++t) { + for (int i = 0; i < half; ++i) { + const float scaled_time = static_cast(t) * inv_timescales[i]; + pos_ptr[t * channels + i] = std::sin(scaled_time); + pos_ptr[t * channels + half + i] = std::cos(scaled_time); + } + } + + return pos_emb; +} + +class Qwen2_5OmniPatchEmbed final : public nn::Module { + int32_t in_chans_; + int32_t embed_dim_; + int32_t patch_size_; + int32_t temporal_patch_size_; + + nn::Conv3D proj_; + + public: + Qwen2_5OmniPatchEmbed() = default; + + explicit Qwen2_5OmniPatchEmbed(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + in_chans_ = cfg.visual_in_chans; + embed_dim_ = cfg.visual_hidden_size; + patch_size_ = cfg.visual_patch_size; + temporal_patch_size_ = cfg.visual_temporal_patch_size; + + proj_ = reg("proj", cfg.visual_in_chans, cfg.visual_hidden_size, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + std::vector{cfg.visual_temporal_patch_size, cfg.visual_patch_size, cfg.visual_patch_size}, + false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + hidden_states = hidden_states.view({-1, in_chans_, temporal_patch_size_, patch_size_, patch_size_}); + hidden_states = proj_(hidden_states).view({-1, embed_dim_}); + return {hidden_states}; + } +}; + +class Qwen2_5OmniPatchMerger final : public nn::Module { + int32_t hidden_size_; + int32_t spatial_merge_size_; + int32_t context_dim_; + + nn::RMSNorm ln_q_; + nn::Linear mlp_0_; + nn::Linear mlp_2_; + nn::GELU mlp_gelu_; + + public: + Qwen2_5OmniPatchMerger() = default; + + explicit Qwen2_5OmniPatchMerger(const 
std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + context_dim_ = cfg.visual_hidden_size; + spatial_merge_size_ = cfg.visual_spatial_merge_size; + hidden_size_ = context_dim_ * spatial_merge_size_ * spatial_merge_size_; + + ln_q_ = reg("ln_q", 1e-6); + mlp_0_ = reg("mlp.0", hidden_size_, hidden_size_, true, cfg.linear_impl_type); + mlp_gelu_ = reg("mlp.gelu"); + mlp_2_ = reg("mlp.2", hidden_size_, cfg.visual_out_hidden_size, true, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto o = ln_q_(inputs[0]).view({-1, hidden_size_}); + o = mlp_0_(o); + o = mlp_gelu_(o); + o = mlp_2_(o); + return {o}; + } +}; + +class Qwen2_5OmniVisionMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniVisionMLP() = default; + explicit Qwen2_5OmniVisionMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.visual_hidden_size, cfg.visual_intermediate_size, true); + down_proj_ = reg("down_proj", cfg.visual_intermediate_size, cfg.visual_hidden_size, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniVisionAttention final : public nn::Module { + int32_t dim_; + int32_t num_heads_; + int32_t head_dim_; + + nn::Linear q_; + nn::Linear k_; + nn::Linear v_; + nn::Linear proj_; + nn::Softmax softmax_; + nn::VisionRoPE vision_rope_q_; + nn::VisionRoPE vision_rope_k_; + + public: + Qwen2_5OmniVisionAttention() = default; + + explicit Qwen2_5OmniVisionAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + dim_ = 
cfg.visual_hidden_size; + num_heads_ = cfg.visual_num_heads; + head_dim_ = dim_ / num_heads_; + + q_ = reg("q", dim_, dim_, true, cfg.linear_impl_type); + k_ = reg("k", dim_, dim_, true, cfg.linear_impl_type); + v_ = reg("v", dim_, dim_, true, cfg.linear_impl_type); + proj_ = reg("proj", dim_, dim_, true, cfg.linear_impl_type); + softmax_ = reg("softmax", -1); + + vision_rope_q_ = reg("vision_rope_q", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + vision_rope_k_ = reg("vision_rope_k", aops::VisionRoPEOpOptionsType::kQwen2VL, + aops::Qwen2VLRoPEOpOptions{ + .dims = head_dim_, + .spatial_merge_size = cfg.visual_spatial_merge_size, + .theta = 10000.0, + }); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto& mask = inputs[3]; + + auto seq_length = hidden_states.shape()[0]; + + auto query_states = q_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto key_states = k_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + auto value_states = v_(hidden_states).view({seq_length, num_heads_, head_dim_}).unsqueeze(0); + + query_states = vision_rope_q_(query_states, visual_embedding_sin, visual_embedding_cos); + key_states = vision_rope_k_(key_states, visual_embedding_sin, visual_embedding_cos); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + auto attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + if (mask) { attn = attn + mask; } + attn = softmax_(attn); + + auto attn_output = nn::functional::matmul(attn, value_states); + attn_output = attn_output.transpose(1, 2).view({seq_length, -1}); + attn_output = 
proj_(attn_output); + return {attn_output}; + } +}; + +class Qwen2_5OmniVisionBlock final : public nn::Module { + nn::RMSNorm norm1_; + nn::RMSNorm norm2_; + + Qwen2_5OmniVisionAttention attn_; + Qwen2_5OmniVisionMLP mlp_; + + public: + Qwen2_5OmniVisionBlock() = default; + + explicit Qwen2_5OmniVisionBlock(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + norm1_ = reg("norm1", 1e-6); + norm2_ = reg("norm2", 1e-6); + attn_ = reg("attn", cfg); + mlp_ = reg("mlp", cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto visual_embedding_sin = inputs[1]; + auto visual_embedding_cos = inputs[2]; + auto mask = inputs[3]; + + hidden_states = hidden_states + attn_(norm1_(hidden_states), visual_embedding_sin, visual_embedding_cos, mask)[0]; + hidden_states = hidden_states + mlp_(norm2_(hidden_states))[0]; + return {hidden_states}; + } +}; + +class Qwen2_5OmniVisionEncoder final : public nn::Module { + Qwen2_5OmniPatchEmbed patch_embed_; + Qwen2_5OmniPatchMerger patch_merger_; + nn::ModuleList blocks_; + std::vector visual_fullatt_block_indexes_; + int32_t visual_window_size_ = 0; + int32_t visual_spatial_merge_size_ = 1; + int32_t visual_patch_size_ = 1; + int32_t spatial_merge_unit_ = 1; + + public: + Qwen2_5OmniVisionEncoder() = default; + + explicit Qwen2_5OmniVisionEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + visual_window_size_ = cfg.visual_window_size; + visual_spatial_merge_size_ = cfg.visual_spatial_merge_size; + visual_patch_size_ = cfg.visual_patch_size; + spatial_merge_unit_ = visual_spatial_merge_size_ * visual_spatial_merge_size_; + visual_fullatt_block_indexes_ = cfg.visual_fullatt_block_indexes; + patch_embed_ = reg("patch_embed", cfg); + patch_merger_ = reg("merger", cfg); + blocks_ = reg>("blocks", cfg.visual_depth, cfg); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override 
{ + auto hidden_states = inputs[0]; + auto embedding_sin = inputs[1]; + auto embedding_cos = inputs[2]; + auto& grid_thw = inputs[3]; + + hidden_states = patch_embed_(hidden_states)[0]; + auto [window_index, cu_window_seqlens] = + makeWindowIndex(grid_thw, visual_window_size_, visual_spatial_merge_size_, visual_patch_size_); + + auto seq_len = hidden_states.shape()[0]; + hidden_states = hidden_states.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + hidden_states = hidden_states[{window_index, {kAll}, {kAll}}]; + hidden_states = hidden_states.view({seq_len, -1}); + + embedding_sin = embedding_sin.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_sin = embedding_sin[{window_index, {kAll}, {kAll}}]; + embedding_sin = embedding_sin.view({seq_len, -1}); + embedding_cos = embedding_cos.view({seq_len / spatial_merge_unit_, spatial_merge_unit_, -1}); + embedding_cos = embedding_cos[{window_index, {kAll}, {kAll}}]; + embedding_cos = embedding_cos.view({seq_len, -1}); + + auto mask = Tensor::empty({1, 1, seq_len, seq_len}, DataTypes::kFloat32, DeviceTypes::kCPU).alloc(); + { + auto mask_ptr = mask.ptr(); + const mllm_fp32_t neg_inf = -1e12f; + for (int i = 0; i < seq_len * seq_len; ++i) { mask_ptr[i] = neg_inf; } + for (int i = 1; i < cu_window_seqlens.size(); ++i) { + const int start = cu_window_seqlens[i - 1]; + const int end = cu_window_seqlens[i]; + for (int r = start; r < end; ++r) { + for (int c = start; c < end; ++c) { mask_ptr[r * seq_len + c] = 0.0f; } + } + } + } + + for (auto [layer_idx, b] : enumerate(blocks_.list())) { + if (std::find(visual_fullatt_block_indexes_.begin(), visual_fullatt_block_indexes_.end(), layer_idx) + != visual_fullatt_block_indexes_.end()) { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, Tensor::nil())[0]; + } else { + hidden_states = b(hidden_states, embedding_sin, embedding_cos, mask)[0]; + } + } + + hidden_states = patch_merger_(hidden_states)[0]; + + std::vector 
reverse_indices(window_index.size()); + std::iota(reverse_indices.begin(), reverse_indices.end(), 0); + std::sort(reverse_indices.begin(), reverse_indices.end(), + [&window_index](int i, int j) { return window_index[i] < window_index[j]; }); + hidden_states = hidden_states[{reverse_indices, {kAll}}]; + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioAttention final : public nn::Module { + int32_t embed_dim_ = 0; + int32_t num_heads_ = 0; + int32_t head_dim_ = 0; + + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear q_proj_; + nn::Linear out_proj_; + + public: + Qwen2_5OmniAudioAttention() = default; + + explicit Qwen2_5OmniAudioAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + embed_dim_ = cfg.audio_d_model; + num_heads_ = cfg.audio_encoder_attention_heads; + head_dim_ = embed_dim_ / num_heads_; + + k_proj_ = reg("k_proj", embed_dim_, embed_dim_, false); + v_proj_ = reg("v_proj", embed_dim_, embed_dim_, true); + q_proj_ = reg("q_proj", embed_dim_, embed_dim_, true); + out_proj_ = reg("out_proj", embed_dim_, embed_dim_, true); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; // [seq_len, embed_dim] + auto seq_len = hidden_states.shape()[0]; + + auto hidden = hidden_states.unsqueeze(0); // [1, seq_len, embed_dim] + auto query_states = q_proj_(hidden); + auto key_states = k_proj_(hidden); + auto value_states = v_proj_(hidden); + + query_states = query_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + key_states = key_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + value_states = value_states.view({1, seq_len, num_heads_, head_dim_}).transpose(1, 2); + + float scale = 1.0f / std::sqrt(static_cast(head_dim_)); + auto attn_weights = nn::functional::matmul(query_states, key_states.transpose(-2, -1)) * scale; + attn_weights = nn::functional::softmax(attn_weights, -1); + auto attn_output = 
nn::functional::matmul(attn_weights, value_states); + + attn_output = attn_output.transpose(1, 2).contiguous().view({1, seq_len, embed_dim_}); + attn_output = out_proj_(attn_output); + + return {attn_output.squeeze(0)}; + } +}; + +class Qwen2_5OmniAudioEncoderLayer final : public nn::Module { + Qwen2_5OmniAudioAttention self_attn_; + nn::LayerNorm self_attn_layer_norm_; + nn::Linear fc1_; + nn::Linear fc2_; + nn::LayerNorm final_layer_norm_; + nn::GELU activation_fn_; + + public: + Qwen2_5OmniAudioEncoderLayer() = default; + + explicit Qwen2_5OmniAudioEncoderLayer(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + const int32_t embed_dim = cfg.audio_d_model; + self_attn_ = reg("self_attn", cfg); + self_attn_layer_norm_ = + reg("self_attn_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + fc1_ = reg("fc1", embed_dim, cfg.audio_encoder_ffn_dim, true); + fc2_ = reg("fc2", cfg.audio_encoder_ffn_dim, embed_dim, true); + final_layer_norm_ = reg("final_layer_norm", std::vector{embed_dim}, true, true, 1e-5); + activation_fn_ = reg("activation_fn"); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto hidden_states = inputs[0]; + auto residual = hidden_states; + + hidden_states = self_attn_layer_norm_(hidden_states); + hidden_states = self_attn_(hidden_states)[0]; + hidden_states = residual + hidden_states; + + residual = hidden_states; + hidden_states = final_layer_norm_(hidden_states); + hidden_states = fc1_(hidden_states); + hidden_states = activation_fn_(hidden_states); + hidden_states = fc2_(hidden_states); + hidden_states = residual + hidden_states; + + if (hidden_states.dtype() == kFloat16) { + const float clamp_value = 65504.0f - 1000.0f; + hidden_states = nn::functional::clip(hidden_states, -clamp_value, clamp_value); + } + + return {hidden_states}; + } +}; + +class Qwen2_5OmniAudioEncoder final : public nn::Module { + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::GELU gelu_; + 
nn::ModuleList layers_; + nn::LayerNorm ln_post_; + nn::AvgPool1d avg_pooler_; + nn::Linear proj_; + nn::Embedding audio_bos_eos_token_; + + int32_t num_mel_bins_ = 0; + int32_t embed_dim_ = 0; + int32_t n_window_ = 0; + int32_t output_dim_ = 0; + + public: + Qwen2_5OmniAudioEncoder() = default; + + explicit Qwen2_5OmniAudioEncoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + num_mel_bins_ = cfg.audio_num_mel_bins; + embed_dim_ = cfg.audio_d_model; + n_window_ = cfg.audio_n_window; + output_dim_ = cfg.audio_output_dim; + + conv1_ = reg("conv1", num_mel_bins_, embed_dim_, 3, 1, 1); + conv2_ = reg("conv2", embed_dim_, embed_dim_, 3, 2, 1); + gelu_ = reg("gelu"); + audio_bos_eos_token_ = reg("audio_bos_eos_token", 2, cfg.audio_output_dim); + layers_ = reg>("layers", cfg.audio_encoder_layers, cfg); + ln_post_ = reg("ln_post", std::vector{embed_dim_}, true, true, 1e-5); + avg_pooler_ = reg("avg_pooler", 2, 2); + proj_ = reg("proj", embed_dim_, cfg.audio_output_dim, true); + + auto pos_emb = makeAudioSinusoidalPosEmb(cfg.audio_max_source_positions, embed_dim_); + registerBuffer("positional_embedding", pos_emb); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto input_features = inputs[0]; // [B, n_mels, T] + MLLM_RT_ASSERT_EQ(input_features.shape().size(), 3); + + const int32_t batch_size = input_features.shape()[0]; + MLLM_RT_ASSERT_EQ(input_features.shape()[1], num_mel_bins_); + const int32_t feature_len = input_features.shape()[2]; + MLLM_RT_ASSERT(feature_len > 0); + + auto pos_emb = getBuffer("positional_embedding"); + + std::vector audio_outputs; + audio_outputs.reserve(batch_size); + + for (int32_t b = 0; b < batch_size; ++b) { + Tensor audio_b = input_features[make_slice(b), kAll, kAll].view({1, num_mel_bins_, feature_len}).contiguous(); + + const int32_t chunk_size = n_window_ * 2; + const int32_t num_chunks = (feature_len + chunk_size - 1) / chunk_size; + + std::vector 
chunk_outputs; + chunk_outputs.reserve(num_chunks); + + for (int32_t chunk_idx = 0; chunk_idx < num_chunks; ++chunk_idx) { + const int32_t start = chunk_idx * chunk_size; + const int32_t chunk_len = std::min(chunk_size, feature_len - start); + auto chunk = Tensor::empty({1, num_mel_bins_, chunk_len}, kFloat32, kCPU).alloc(); + for (int32_t m = 0; m < num_mel_bins_; ++m) { + auto src_ptr = audio_b.offsettedPtr({0, m, start}); + auto dst_ptr = chunk.offsettedPtr({0, m, 0}); + std::memcpy(dst_ptr, src_ptr, chunk_len * sizeof(float)); + } + + auto x = conv1_(chunk); + x = gelu_(x); + x = conv2_(x); + x = gelu_(x); + x = x.transpose(1, 2).contiguous(); // [1, T2, D] + + const int32_t t2 = x.shape()[1]; + MLLM_RT_ASSERT(t2 <= pos_emb.shape()[0]); + auto pos_ptr = pos_emb.ptr(); + auto x_ptr = x.ptr(); + for (int32_t t = 0; t < t2; ++t) { + const float* pos_row = pos_ptr + t * embed_dim_; + float* x_row = x_ptr + t * embed_dim_; + for (int32_t d = 0; d < embed_dim_; ++d) { x_row[d] += pos_row[d]; } + } + + auto hidden_states = x.squeeze(0); // [T2, D] + for (auto& layer : layers_.list()) { hidden_states = layer(hidden_states)[0]; } + if (hidden_states.shape()[0] < 2) { continue; } + + auto pooled = hidden_states.unsqueeze(0).transpose(1, 2); // [1, D, T] + pooled = avg_pooler_(pooled); + pooled = pooled.transpose(1, 2).squeeze(0); // [T', D] + pooled = ln_post_(pooled); + pooled = proj_(pooled); + chunk_outputs.push_back(pooled); + } + + int32_t total_len = 0; + for (const auto& chunk : chunk_outputs) { total_len += chunk.shape()[0]; } + + auto merged = Tensor::empty({total_len, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& chunk : chunk_outputs) { + const int32_t len = chunk.shape()[0]; + const float* src_ptr = chunk.ptr(); + float* dst_ptr = merged.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + audio_outputs.push_back(merged); + } + + int32_t total_audio_tokens = 
0; + for (const auto& out : audio_outputs) { total_audio_tokens += out.shape()[0]; } + + auto output = Tensor::empty({total_audio_tokens, output_dim_}, kFloat32, kCPU).alloc(); + int32_t offset = 0; + for (const auto& out : audio_outputs) { + const int32_t len = out.shape()[0]; + const float* src_ptr = out.ptr(); + float* dst_ptr = output.offsettedPtr({offset, 0}); + std::memcpy(dst_ptr, src_ptr, len * output_dim_ * sizeof(float)); + offset += len; + } + + return {output}; + } +}; + +class Qwen2_5OmniMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniMLP() = default; + Qwen2_5OmniMLP(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false, cfg.linear_impl_type); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false, cfg.linear_impl_type); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniAttention() = default; + + Qwen2_5OmniAttention(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + head_dim_ = 
hidden_size_ / num_attention_heads_; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true, cfg.linear_impl_type); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true, cfg.linear_impl_type); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false, cfg.linear_impl_type); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + 
value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_; +}; + +class Qwen2_5OmniDecoder final : public nn::Module { + public: + Qwen2_5OmniAttention self_attn_; + Qwen2_5OmniMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniDecoder() = default; + + Qwen2_5OmniDecoder(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniText final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniText() = default; + + Qwen2_5OmniText(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : 
enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.hidden_size); + + auto inv = makeMultimodalRoPEInvFreq(cfg.hidden_size / cfg.num_attention_heads, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +class Qwen2_5OmniThinker final : public nn::Module { + public: + Qwen2_5OmniThinker() = default; + Qwen2_5OmniThinker(const std::string& name, const Qwen2_5OmniConfig& cfg) : nn::Module(name) { + model_ = reg("model", cfg); + audio_tower_ = reg("audio_tower", cfg); + visual_ = reg("visual", cfg); + lm_head_ = reg("lm_head", cfg.hidden_size, cfg.vocab_size, false, cfg.linear_impl_type); + } + + Qwen2_5OmniText model_; + Qwen2_5OmniAudioEncoder audio_tower_; + Qwen2_5OmniVisionEncoder visual_; + nn::Linear lm_head_; +}; + +class Qwen2_5OmniForCausalLM : public ARGeneration { + public: + explicit Qwen2_5OmniForCausalLM(const Qwen2_5OmniConfig& cfg) : cfg_(cfg), thinker_("thinker", cfg) { + kv_cache_ = nn::StaticCache(cfg.max_cache_length, cfg.num_hidden_layers, + cfg.num_attention_heads, + cfg.num_key_value_heads, + cfg.hidden_size / cfg.num_attention_heads, + kFloat32, + kFloat32, + kCPU, + false); + eos_token_id_ = cfg.eos_token_id; + max_length_ = cfg.max_cache_length; + } + + void clearCache() { kv_cache_.clearCache(); } + + ARGenerationOutputPast forward(const ARGenerationOutputPast& input, const ARGenerationArgs& args) override { + auto sequence = input.at("sequence"); + + auto input_embeddings = 
thinker_.model_.embedding_(sequence); + + if (input.count("input_features")) { + auto input_features = input.at("input_features"); + auto audio_embeddings = thinker_.audio_tower_(input_features)[0]; + MLLM_RT_ASSERT_EQ(audio_embeddings.shape()[1], input_embeddings.shape()[2]); + if (audio_embeddings.dtype() != input_embeddings.dtype()) { + audio_embeddings = audio_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector audio_positions; + audio_positions.reserve(audio_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.audio_token_id) { audio_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(audio_positions.size()), audio_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < audio_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, audio_positions[i], 0}); + auto in_ptr = audio_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni audio input."); + } + } + + if (input.count("img")) { + auto img = input.at("img"); + auto grid_thw = input.at("grid_thw"); + + auto inv_freq = makeVisualRoPEInvFreq(cfg_.visual_hidden_size / cfg_.visual_num_heads, 10000.0f); + auto pos_ids = makeVisualRotaryPosEmbIds(grid_thw, cfg_.visual_spatial_merge_size); + + int max_grid = 0; + for (int row = 0; row < grid_thw.shape()[0]; ++row) { + const int* dims = grid_thw.offsettedPtr({row, 0}); + 
max_grid = std::max({max_grid, dims[1], dims[2]}); + } + MLLM_RT_ASSERT(max_grid > 0); + auto rotary_pos_emb_full = makeVisualRotaryPosEmbFull(inv_freq, max_grid); + auto pos_emb = makeVisualRotaryPosEmb(rotary_pos_emb_full, pos_ids, grid_thw); + auto [visual_embedding_sin, visual_embedding_cos] = makeVisualRotarySinCos(pos_emb); + + auto visual_embeddings = thinker_.visual_(img, visual_embedding_sin, visual_embedding_cos, grid_thw)[0]; + MLLM_RT_ASSERT_EQ(visual_embeddings.shape()[1], input_embeddings.shape()[2]); + if (visual_embeddings.dtype() != input_embeddings.dtype()) { + visual_embeddings = visual_embeddings.to(input_embeddings.dtype()); + } + + MLLM_RT_ASSERT_EQ(sequence.shape()[0], 1); + auto S = sequence.shape()[1]; + std::vector image_positions; + image_positions.reserve(visual_embeddings.shape()[0]); + auto input_ids_ptr = sequence.ptr(); + for (int s = 0; s < S; ++s) { + if (input_ids_ptr[s] == cfg_.image_token_id) { image_positions.push_back(s); } + } + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), visual_embeddings.shape()[0]); + + auto D = input_embeddings.shape()[2]; + if (input_embeddings.dtype() == kFloat32) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else if (input_embeddings.dtype() == kFloat16) { + for (size_t i = 0; i < image_positions.size(); ++i) { + auto out_ptr = input_embeddings.offsettedPtr({0, image_positions[i], 0}); + auto in_ptr = visual_embeddings.offsettedPtr({static_cast(i), 0}); + std::copy(in_ptr, in_ptr + D, out_ptr); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni image input."); + } + } + + Tensor position_ids = input.count("position_ids") ? input.at("position_ids") : Tensor::nil(); + Tensor img = input.count("img") ? 
input.at("img") : Tensor::nil(); + Tensor grid_thw = input.count("grid_thw") ? input.at("grid_thw") : Tensor::nil(); + position_ids = getPositionIds(img, grid_thw, sequence, position_ids); + + auto [llm_embedding_sin, llm_embedding_cos] = + makeMultimodalPositionEmbedding(position_ids, thinker_.model_.getBuffer("inv_freq"), cfg_.max_position_embeddings, + cfg_.hidden_size / cfg_.num_attention_heads, cfg_.mrope_section); + + auto hidden_states = thinker_.model_(input_embeddings, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto seq_len = hidden_states.shape()[1]; + auto last_hidden = hidden_states[{kAll, {seq_len - 1}, kAll}]; + auto logits = thinker_.lm_head_(last_hidden); + + const bool output_hidden_states = + args.count("output_hidden_states") ? args.at("output_hidden_states").get() : false; + + if (output_hidden_states) { + return { + {"sequence", logits}, + {"position_ids", position_ids}, + {"hidden_states", hidden_states}, + {"input_embeddings", input_embeddings}, + }; + } + + return { + {"sequence", logits}, + {"position_ids", position_ids}, + }; + } + + Qwen2_5OmniThinker thinker_; + + private: + Tensor getPositionIds(Tensor& img, Tensor& grid_thw, Tensor& input_ids, Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = *position_ids.offsettedPtr({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + 
*ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(Tensor& input_ids, Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? 
next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + MLLM_RT_ASSERT(image_idx < total_images); + if (image_grid_thw.isNil()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing grid_thw for Qwen2.5-Omni vision input."); + } + MLLM_RT_ASSERT_EQ(image_grid_thw.shape().size(), 2); + + std::vector image_positions; + for (int i = segment_start + 1; i < vision_end; ++i) { + if (input_ids_ptr[i] == cfg_.image_token_id) { + image_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside vision segment."); + } + } + + const int* grid_dims = image_grid_thw.offsettedPtr({image_idx, 0}); + const int grid_t = grid_dims[0]; + const int grid_h = grid_dims[1]; + const int grid_w = grid_dims[2]; + + const int image_token_len = (grid_t * grid_h * grid_w) + / (cfg_.visual_spatial_merge_size * cfg_.visual_spatial_merge_size); + MLLM_RT_ASSERT_EQ(static_cast(image_positions.size()), image_token_len); + + const int inputs_t = grid_t; + const int inputs_h = grid_h / cfg_.visual_spatial_merge_size; + const int inputs_w = grid_w / cfg_.visual_spatial_merge_size; + + const int64_t vision_start_id = current_max_position_id + 1; + int pos_counter = 0; + for (int ti = 0; ti < inputs_t; ++ti) { + const int64_t t_id = vision_start_id + static_cast(ti) * cfg_.position_id_per_seconds; + for (int hi = 0; hi < inputs_h; ++hi) { + for (int wi = 0; wi < inputs_w; ++wi) { + const auto seq_pos = image_positions[pos_counter++]; + *position_ids.offsettedPtr({0, 0, seq_pos}) = t_id; + 
*position_ids.offsettedPtr({1, 0, seq_pos}) = vision_start_id + hi; + *position_ids.offsettedPtr({2, 0, seq_pos}) = vision_start_id + wi; + } + } + } + + const int64_t dim_0_tail = vision_start_id + static_cast(inputs_t - 1) * cfg_.position_id_per_seconds; + const int64_t dim_1_tail = vision_start_id + inputs_h - 1; + const int64_t dim_2_tail = vision_start_id + inputs_w - 1; + current_max_position_id = std::max({dim_0_tail, dim_1_tail, dim_2_tail}); + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + + const int audio_len = static_cast(audio_positions.size()); + if (audio_len == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); + } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = audio_end + 1; + } + } + + MLLM_RT_ASSERT_EQ(image_idx, total_images); + return position_ids; + } + + const Qwen2_5OmniConfig& cfg_; + nn::StaticCache kv_cache_; +}; + +struct Qwen2_5OmniAudioGenerationConfig { + int32_t 
thinker_max_new_tokens = 1024; + bool thinker_do_sample = false; + int32_t thinker_top_k = 0; + float thinker_top_p = 0.0f; + float thinker_temperature = 1.0f; + + int32_t talker_max_new_tokens = 1024; + int32_t talker_min_new_tokens = 128; + bool talker_do_sample = true; + int32_t talker_top_k = 40; + float talker_top_p = 0.8f; + float talker_temperature = 0.9f; + float talker_repetition_penalty = 1.05f; + std::vector talker_eos_token_ids = {}; + bool suppress_codec_bos = true; + + int32_t token2wav_num_steps = 10; + float token2wav_guidance_scale = 0.5f; + float token2wav_sway_coefficient = -1.0f; +}; + +struct Qwen2_5OmniAudioGenerationResult { + Tensor sequences = Tensor::nil(); + Tensor wav = Tensor::nil(); +}; + +class Qwen2_5OmniForConditionalGeneration { + public: + explicit Qwen2_5OmniForConditionalGeneration(const Qwen2_5OmniConfig& cfg) + : cfg_(cfg), + thinker_(cfg_), + talker_("talker", cfg_.talker_cfg), + token2wav_("token2wav", cfg_.token2wav_cfg) {} + + void load(const ParameterFile::ptr_t& param) { + thinker_.thinker_.load(param); + if (cfg_.enable_audio_output) { + talker_.load(param); + token2wav_.load(param); + } + } + + void loadSpeakers(const std::string& path) { speaker_map_ = loadSpeakerMap(path); } + + void clearCache() { + thinker_.clearCache(); + talker_.clearCache(); + } + + Qwen2_5OmniAudioGenerationResult generateAudio(const ARGenerationOutputPast& input, const Qwen2_5OmniAudioGenerationConfig& gen_cfg, + const std::string& speaker = "") { + if (!cfg_.enable_audio_output) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Audio output is disabled in Qwen2.5-Omni config."); + } + if (speaker_map_.speakers.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Speaker map is empty. Call loadSpeakers() first."); + } + + const std::string speaker_name = speaker.empty() ? 
speaker_map_.default_speaker : speaker; + auto spk_it = speaker_map_.speakers.find(speaker_name); + if (spk_it == speaker_map_.speakers.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unknown speaker '{}'.", speaker_name); + } + + auto thinker_output = runThinkerGeneration(input, gen_cfg); + if (thinker_output.generated_ids.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Thinker produced no tokens; cannot run talker."); + } + + auto talker_output = runTalkerGeneration(input, thinker_output, spk_it->second, gen_cfg); + auto wav = token2wav_.forward(talker_output, spk_it->second.cond.to(kFloat32), spk_it->second.ref_mel.to(kFloat32), + gen_cfg.token2wav_num_steps, gen_cfg.token2wav_guidance_scale, gen_cfg.token2wav_sway_coefficient); + + return { + .sequences = thinker_output.sequences, + .wav = wav, + }; + } + + Tensor generateReferenceWav(const std::string& speaker = "") { + if (speaker_map_.speakers.empty()) { return Tensor::nil(); } + const std::string speaker_name = speaker.empty() ? 
speaker_map_.default_speaker : speaker; + auto spk_it = speaker_map_.speakers.find(speaker_name); + if (spk_it == speaker_map_.speakers.end()) { return Tensor::nil(); } + auto ref_mel = spk_it->second.ref_mel.to(kFloat32); + ref_mel = ref_mel.permute({0, 2, 1}); + if (!ref_mel.isContiguous()) { ref_mel = ref_mel.contiguous(); } + return token2wav_.vocodeMel(ref_mel); + } + + private: + Qwen2_5OmniConfig cfg_; + Qwen2_5OmniSpeakerMap speaker_map_{}; + + public: + Qwen2_5OmniForCausalLM thinker_; + Qwen2_5OmniTalker talker_; + Qwen2_5OmniToken2WavModel token2wav_; + + private: + struct ThinkerGenerationOutput { + Tensor sequences = Tensor::nil(); + std::vector generated_ids; + std::vector token_embeddings; + std::vector token_hidden_states; + int32_t prompt_len = 0; + }; + + static Tensor makeTokenTensor(int64_t token_id) { + Tensor out = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + out.at({0, 0}) = token_id; + return out; + } + + static Tensor makeTokenTensor(const std::vector& ids) { + Tensor out = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU).alloc(); + auto* ptr = out.ptr(); + std::copy(ids.begin(), ids.end(), ptr); + return out; + } + + static Tensor concatTokenTensors(const std::vector& parts) { + MLLM_RT_ASSERT(!parts.empty()); + int32_t total_len = 0; + for (const auto& part : parts) { + MLLM_RT_ASSERT_EQ(part.shape().size(), 2); + MLLM_RT_ASSERT_EQ(part.shape()[0], 1); + MLLM_RT_ASSERT_EQ(part.dtype(), kInt64); + MLLM_RT_ASSERT_EQ(part.device(), kCPU); + total_len += part.shape()[1]; + } + + Tensor out = Tensor::empty({1, total_len}, kInt64, kCPU).alloc(); + auto* out_ptr = out.ptr(); + int32_t offset = 0; + for (const auto& part : parts) { + auto* in_ptr = part.ptr(); + int32_t len = part.shape()[1]; + std::copy(in_ptr, in_ptr + len, out_ptr + offset); + offset += len; + } + return out; + } + + static void zeroEmbeddingsByTokenId(Tensor& embeds, const Tensor& input_ids, int64_t token_id) { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + 
MLLM_RT_ASSERT_EQ(embeds.shape().size(), 3); + MLLM_RT_ASSERT_EQ(input_ids.shape()[1], embeds.shape()[1]); + + auto seq_len = input_ids.shape()[1]; + auto dim = embeds.shape()[2]; + auto* ids = input_ids.ptr(); + + if (embeds.dtype() == kFloat32) { + for (int s = 0; s < seq_len; ++s) { + if (ids[s] != token_id) continue; + auto* out_ptr = embeds.offsettedPtr({0, s, 0}); + std::fill(out_ptr, out_ptr + dim, 0.0f); + } + } else if (embeds.dtype() == kFloat16) { + for (int s = 0; s < seq_len; ++s) { + if (ids[s] != token_id) continue; + auto* out_ptr = embeds.offsettedPtr({0, s, 0}); + std::fill(out_ptr, out_ptr + dim, static_cast(0)); + } + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported embedding dtype for Qwen2.5-Omni talker preparation."); + } + } + + static Tensor getLastLogits(const Tensor& logits) { + MLLM_RT_ASSERT_EQ(logits.shape().size(), 3); + if (logits.shape()[1] == 1) { return logits; } + return logits[{kAll, logits.shape()[1] - 1, kAll}]; + } + + static int64_t sampleFromDistribution(const std::vector& probs) { + std::random_device rd; + std::mt19937 gen(rd()); + std::discrete_distribution<> dist(probs.begin(), probs.end()); + return dist(gen); + } + + static int64_t categoricalSample(const Tensor& probs) { + MLLM_RT_ASSERT_EQ(probs.dtype(), kFloat32); + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + + std::vector cumulative_probs(vocab_size); + std::partial_sum(prob_data, prob_data + vocab_size, cumulative_probs.begin()); + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution<> dis(0.0, 1.0); + float r = dis(gen); + + auto it = std::lower_bound(cumulative_probs.begin(), cumulative_probs.end(), r); + if (it == cumulative_probs.end()) { return static_cast(vocab_size - 1); } + return static_cast(std::distance(cumulative_probs.begin(), it)); + } + + static void applyRepetitionPenalty(Tensor& logits, const std::vector& token_ids, float penalty) { + if (penalty <= 1.0f || 
token_ids.empty()) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + + int vocab_size = logits.shape().back(); + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + std::unordered_set unique_ids; + unique_ids.reserve(token_ids.size()); + for (auto id : token_ids) { unique_ids.insert(id); } + + auto* logits_ptr = logits.ptr(); + for (auto id : unique_ids) { + if (id < 0 || id >= vocab_size) { continue; } + float& v = logits_ptr[id]; + v = (v < 0.0f) ? v * penalty : v / penalty; + } + } + + static void applyTopKLogits(Tensor& logits, int32_t top_k) { + if (top_k <= 0) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + int vocab_size = logits.shape().back(); + int k = std::min(std::max(top_k, 1), vocab_size); + + auto* logits_ptr = logits.ptr(); + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&logits_ptr](int i1, int i2) { return logits_ptr[i1] > logits_ptr[i2]; }); + + float threshold = logits_ptr[indices[k - 1]]; + float neg_inf = -std::numeric_limits::infinity(); + for (int i = 0; i < vocab_size; ++i) { + if (logits_ptr[i] < threshold) { logits_ptr[i] = neg_inf; } + } + } + + static void applyTopPLogits(Tensor& logits, float top_p) { + if (top_p <= 0.0f || top_p >= 1.0f) { return; } + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (logits.shape().size() == 2) { MLLM_RT_ASSERT_EQ(logits.shape()[0], 1); } + + int vocab_size = logits.shape().back(); + auto* logits_ptr = logits.ptr(); + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&logits_ptr](int i1, int i2) { return logits_ptr[i1] > logits_ptr[i2]; }); + + float max_logit = logits_ptr[indices[0]]; + std::vector probs(vocab_size); + 
float sum_exp = 0.0f; + for (int i = 0; i < vocab_size; ++i) { + float exp_val = std::exp(logits_ptr[indices[i]] - max_logit); + probs[i] = exp_val; + sum_exp += exp_val; + } + if (sum_exp <= 0.0f) { return; } + for (auto& p : probs) { p /= sum_exp; } + + float cumulative = 0.0f; + int keep = 0; + for (int i = 0; i < vocab_size; ++i) { + cumulative += probs[i]; + keep++; + if (cumulative > top_p) { break; } + } + keep = std::max(keep, 1); + + float neg_inf = -std::numeric_limits::infinity(); + for (int i = keep; i < vocab_size; ++i) { + logits_ptr[indices[i]] = neg_inf; + } + } + + static int64_t sampleFromLogits(Tensor logits, bool do_sample) { + if (logits.dtype() != kFloat32) { logits = logits.to(kFloat32); } + if (!do_sample) { + auto* logits_ptr = logits.ptr(); + int vocab_size = logits.shape().back(); + auto max_it = std::max_element(logits_ptr, logits_ptr + vocab_size); + return static_cast(std::distance(logits_ptr, max_it)); + } + Tensor probs = nn::functional::softmax(logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample(probs); + } + + static int64_t sampleGreedyLocal(const Tensor& logits) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + auto* logits_data = last_logits.ptr(); + int vocab_size = last_logits.shape().back(); + auto max_it = std::max_element(logits_data, logits_data + vocab_size); + return static_cast(std::distance(logits_data, max_it)); + } + + static int64_t sampleTemperatureLocal(const Tensor& logits, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + return categoricalSample(probs); + 
} + + static int64_t sampleTopKLocal(const Tensor& logits, int k, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + if (k <= 0 || k > vocab_size) { k = vocab_size; } + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::partial_sort(indices.begin(), indices.begin() + k, indices.end(), + [&prob_data](int i1, int i2) { return prob_data[i1] > prob_data[i2]; }); + + std::vector top_k_probs(k); + float sum = 0.0f; + for (int i = 0; i < k; ++i) { + top_k_probs[i] = prob_data[indices[i]]; + sum += top_k_probs[i]; + } + if (sum <= 0.0f) { return static_cast(indices[0]); } + for (int i = 0; i < k; ++i) { top_k_probs[i] *= (1.f / sum); } + + return static_cast(indices[sampleFromDistribution(top_k_probs)]); + } + + static int64_t sampleTopPLocal(const Tensor& logits, float p, float temperature) { + Tensor last_logits = getLastLogits(logits); + if (last_logits.dtype() != kFloat32) { last_logits = last_logits.to(kFloat32); } + if (temperature != 1.0f && temperature > 0.0f) { last_logits = last_logits * (1.f / temperature); } + Tensor probs = nn::functional::softmax(last_logits, -1); + if (probs.dtype() != kFloat32) { probs = probs.to(kFloat32); } + + auto* prob_data = probs.ptr(); + int vocab_size = probs.shape().back(); + + std::vector indices(vocab_size); + std::iota(indices.begin(), indices.end(), 0); + std::sort(indices.begin(), indices.end(), [&prob_data](int i1, int i2) { return prob_data[i1] > prob_data[i2]; }); + + std::vector top_probs; + float cumulative_prob = 0.0f; + int i = 0; + for (; i < vocab_size && cumulative_prob < p; ++i) { 
+ top_probs.push_back(prob_data[indices[i]]); + cumulative_prob += prob_data[indices[i]]; + } + + float sum = std::accumulate(top_probs.begin(), top_probs.end(), 0.0f); + if (sum <= 0.0f) { return static_cast(indices[0]); } + for (float& prob : top_probs) { prob *= (1.f / sum); } + + return static_cast(indices[sampleFromDistribution(top_probs)]); + } + + int64_t sampleToken(const Tensor& logits, bool do_sample, int32_t top_k, float top_p, float temperature) { + bool use_sampling = do_sample || (temperature != 1.0f) || (top_k > 0) || (top_p > 0.0f); + if (use_sampling) { + if (top_k > 0) { return sampleTopKLocal(logits, top_k, temperature); } + if (top_p > 0.0f) { return sampleTopPLocal(logits, top_p, temperature); } + return sampleTemperatureLocal(logits, temperature); + } + return sampleGreedyLocal(logits); + } + + ThinkerGenerationOutput runThinkerGeneration(const ARGenerationOutputPast& input, const Qwen2_5OmniAudioGenerationConfig& gen_cfg) { + thinker_.clearCache(); + + ARGenerationOutputPast past = input; + ARGenerationArgs args; + args.emplace("output_hidden_states", AnyValue(true)); + + const auto& input_ids = input.at("sequence"); + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + std::vector generated_ids; + std::vector token_embeddings; + std::vector token_hidden_states; + + for (int32_t step = 0; step < gen_cfg.thinker_max_new_tokens; ++step) { + auto output = thinker_.forward(past, args); + auto logits = output.at("sequence"); + + auto input_embeddings = output.at("input_embeddings"); + auto hidden_states = output.at("hidden_states"); + + if (step == 0) { + auto embeds_to_talker = input_embeddings.clone(); + if (input.count("input_features")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.audio_token_id); } + if (input.count("img")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.image_token_id); } + if (input.count("video")) { zeroEmbeddingsByTokenId(embeds_to_talker, input_ids, cfg_.video_token_id); } + 
token_embeddings.emplace_back(std::move(embeds_to_talker)); + } else { + token_embeddings.emplace_back(std::move(input_embeddings)); + } + token_hidden_states.emplace_back(std::move(hidden_states)); + + int64_t next_token_id = sampleToken(logits, gen_cfg.thinker_do_sample, gen_cfg.thinker_top_k, gen_cfg.thinker_top_p, + gen_cfg.thinker_temperature); + generated_ids.push_back(next_token_id); + + if (next_token_id == cfg_.eos_token_id) { break; } + + past = std::move(output); + past["sequence"] = makeTokenTensor(next_token_id); + } + + std::vector sequence_ids; + sequence_ids.reserve(input_ids.shape()[1] + generated_ids.size()); + auto* input_ptr = input_ids.ptr(); + for (int i = 0; i < input_ids.shape()[1]; ++i) { sequence_ids.push_back(input_ptr[i]); } + sequence_ids.insert(sequence_ids.end(), generated_ids.begin(), generated_ids.end()); + + return { + .sequences = makeTokenTensor(sequence_ids), + .generated_ids = std::move(generated_ids), + .token_embeddings = std::move(token_embeddings), + .token_hidden_states = std::move(token_hidden_states), + .prompt_len = input_ids.shape()[1], + }; + } + + Tensor runTalkerGeneration(const ARGenerationOutputPast& input, const ThinkerGenerationOutput& thinker_output, + const Qwen2_5OmniSpeakerParams& speaker_params, const Qwen2_5OmniAudioGenerationConfig& gen_cfg) { + if (thinker_output.generated_ids.empty()) { return Tensor::nil(); } + + talker_.clearCache(); + + const auto& input_ids = input.at("sequence"); + const auto& token_embeddings = thinker_output.token_embeddings; + const auto& token_hidden_states = thinker_output.token_hidden_states; + + std::vector reply_hidden_states(token_hidden_states.begin() + 1, token_hidden_states.end()); + std::vector reply_token_embeds(token_embeddings.begin() + 1, token_embeddings.end()); + + auto hidden_dtype = token_hidden_states[0].dtype(); + auto hidden_device = token_hidden_states[0].device(); + auto embed_dtype = token_embeddings[0].dtype(); + auto embed_device = 
token_embeddings[0].device(); + Tensor reply_hidden = reply_hidden_states.empty() + ? Tensor::empty({1, 0, token_hidden_states[0].shape()[2]}, hidden_dtype, hidden_device).alloc() + : nn::functional::concat(reply_hidden_states, 1); + Tensor reply_embeds = reply_token_embeds.empty() + ? Tensor::empty({1, 0, token_embeddings[0].shape()[2]}, embed_dtype, embed_device).alloc() + : nn::functional::concat(reply_token_embeds, 1); + auto thinker_reply_part = reply_hidden + reply_embeds; + if (thinker_reply_part.shape()[1] == 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Thinker response is too short for talker conditioning."); + } + + std::vector talker_text_ids; + talker_text_ids.reserve(input_ids.shape()[1] + 2); + auto* input_ptr = input_ids.ptr(); + for (int i = 0; i < input_ids.shape()[1]; ++i) { talker_text_ids.push_back(input_ptr[i]); } + talker_text_ids.push_back(speaker_params.bos_token); + talker_text_ids.push_back(thinker_output.generated_ids.front()); + auto talker_input_text_ids = makeTokenTensor(talker_text_ids); + + std::vector talker_codec_ids(input_ids.shape()[1] + 2, talker_.codec_mask_token()); + talker_codec_ids[input_ids.shape()[1]] = talker_.codec_pad_token(); + talker_codec_ids[input_ids.shape()[1] + 1] = talker_.codec_bos_token(); + auto talker_input_ids = makeTokenTensor(talker_codec_ids); + + auto talker_inputs_embeds = Tensor(token_hidden_states[0]); + talker_inputs_embeds = talker_inputs_embeds + token_embeddings[0]; + auto talker_text_bos_embed = thinker_.thinker_.model_.embedding_(makeTokenTensor(speaker_params.bos_token)); + auto first_reply = thinker_reply_part.shape()[1] > 0 + ? 
thinker_reply_part[{kAll, {0, 1}, kAll}] + : Tensor::empty({1, 0, talker_inputs_embeds.shape()[2]}, talker_inputs_embeds.dtype(), talker_inputs_embeds.device()) + .alloc(); + talker_inputs_embeds = nn::functional::concat({talker_inputs_embeds, talker_text_bos_embed, first_reply}, 1); + + auto eos_embedding = thinker_.thinker_.model_.embedding_(makeTokenTensor(talker_.text_eos_token())); + auto pad_embedding = thinker_.thinker_.model_.embedding_(makeTokenTensor(talker_.text_pad_token())); + Tensor reply_tail = + thinker_reply_part.shape()[1] > 1 + ? thinker_reply_part[{kAll, {1, thinker_reply_part.shape()[1]}, kAll}] + : Tensor::empty({1, 0, talker_inputs_embeds.shape()[2]}, talker_inputs_embeds.dtype(), talker_inputs_embeds.device()).alloc(); + thinker_reply_part = nn::functional::concat({reply_tail, eos_embedding, pad_embedding}, 1); + + Tensor talker_attention_mask = Tensor::nil(); + if (input.count("attention_mask")) { + auto mask = input.at("attention_mask"); + if (mask.dtype() != kFloat16 && mask.dtype() != kFloat32) { mask = mask.to(kFloat32); } + auto ones = Tensor::ones({1, 2}, mask.dtype(), mask.device()); + talker_attention_mask = nn::functional::concat({mask, ones}, 1); + } + + Tensor image_grid_thw = input.count("grid_thw") ? 
input.at("grid_thw") : Tensor::nil(); + + std::vector generated_codes; + Tensor position_ids = Tensor::nil(); + Tensor cur_input_ids = talker_input_ids; + Tensor cur_input_text_ids = talker_input_text_ids; + Tensor cur_inputs_embeds = talker_inputs_embeds; + Tensor cur_reply_part = thinker_reply_part; + + std::vector repetition_tokens = talker_codec_ids; + repetition_tokens.reserve(talker_codec_ids.size() + gen_cfg.talker_max_new_tokens); + + std::vector eos_ids = gen_cfg.talker_eos_token_ids; + if (eos_ids.empty()) { + eos_ids.push_back(talker_.codec_pad_token()); + eos_ids.push_back(talker_.codec_eos_token()); + } + + for (int32_t step = 0; step < gen_cfg.talker_max_new_tokens; ++step) { + auto output = talker_.forward(cur_input_ids, cur_input_text_ids, cur_reply_part, cur_inputs_embeds, talker_attention_mask, + image_grid_thw, position_ids); + + auto logits = output.logits; + auto last_logits = getLastLogits(logits); + + const int32_t vocab_size = last_logits.shape().back(); + + if (gen_cfg.suppress_codec_bos) { + auto* logits_ptr = last_logits.ptr(); + logits_ptr[talker_.codec_bos_token()] = -1e9f; + } + if (gen_cfg.talker_min_new_tokens > 0 && step < gen_cfg.talker_min_new_tokens) { + auto* logits_ptr = last_logits.ptr(); + for (int64_t eos_id : eos_ids) { + if (eos_id >= 0 && eos_id < vocab_size) { logits_ptr[eos_id] = -1e9f; } + } + } + applyRepetitionPenalty(last_logits, repetition_tokens, gen_cfg.talker_repetition_penalty); + + Tensor sample_logits = last_logits; + if (gen_cfg.talker_temperature != 1.0f && gen_cfg.talker_temperature > 0.0f) { + sample_logits = sample_logits * (1.f / gen_cfg.talker_temperature); + } + if (gen_cfg.talker_do_sample) { + if (gen_cfg.talker_top_k > 0) { applyTopKLogits(sample_logits, gen_cfg.talker_top_k); } + if (gen_cfg.talker_top_p > 0.0f) { applyTopPLogits(sample_logits, gen_cfg.talker_top_p); } + } + + int64_t next_token_id = sampleFromLogits(sample_logits, gen_cfg.talker_do_sample); + 
generated_codes.push_back(next_token_id); + repetition_tokens.push_back(next_token_id); + + if (std::find(eos_ids.begin(), eos_ids.end(), next_token_id) != eos_ids.end()) { break; } + + position_ids = output.position_ids; + cur_reply_part = output.thinker_reply_part; + cur_input_ids = makeTokenTensor(next_token_id); + cur_input_text_ids = Tensor::nil(); + cur_inputs_embeds = Tensor::nil(); + } + + if (!generated_codes.empty()) { generated_codes.pop_back(); } + if (generated_codes.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Talker produced no codec tokens."); + } + return makeTokenTensor(generated_codes); + } + + +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp new file mode 100644 index 000000000..df8019a84 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_talker.hpp @@ -0,0 +1,626 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/nn/lmcache/StaticCache.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +constexpr float kPi = 3.14159265358979323846f; + +inline auto makeTalkerRoPEInvFreq(int output_dim, float rope_theta) -> Tensor { + auto inv_freq = Tensor::empty({output_dim / 2}, kFloat32, kCPU).alloc(); + auto inv_freq_ptr = inv_freq.ptr(); + for (int i = 0; i < output_dim / 2; i++) { inv_freq_ptr[i] = 1.0f / std::pow(rope_theta, 2.0f * i / output_dim); } + return inv_freq; +} + +inline auto makeTalkerPositionEmbedding(Tensor& position_ids, const Tensor& inv_freq, const std::vector& mrope_section) + -> std::pair { + MLLM_RT_ASSERT_EQ(position_ids.shape().size(), 3); + MLLM_RT_ASSERT_EQ(position_ids.shape()[1], 1); + + Tensor tmp_sin = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + Tensor tmp_cos = Tensor::empty({3, position_ids.shape()[2], inv_freq.shape()[0] * 2}).alloc(); + + for (int b = 0; b < 3; ++b) { + for (int d = 0; d < inv_freq.shape()[0]; ++d) { + for (int s = 0; s < position_ids.shape()[2]; ++s) { + auto value = inv_freq.ptr()[d] * (*position_ids.offsettedPtr({b, 0, s})); + *tmp_cos.offsettedPtr({b, s, d}) = cosf(value); + *tmp_cos.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = cosf(value); + *tmp_sin.offsettedPtr({b, s, d}) = sinf(value); + *tmp_sin.offsettedPtr({b, s, d + inv_freq.shape()[0]}) = sinf(value); + } + } + } + + Tensor sin = Tensor::nil(); + Tensor cos = Tensor::nil(); + + if (!mrope_section.empty()) { + auto double_rope_section = mrope_section; + for (int i : 
mrope_section) { double_rope_section.push_back(i); } + + int num_rows = tmp_sin.shape()[1]; + int num_cols = tmp_sin.shape()[2]; + + sin = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + cos = Tensor::empty({num_rows, num_cols}, kFloat32, kCPU).alloc(); + + std::vector start_cols; + int current_start = 0; + start_cols.push_back(current_start); + for (int s : double_rope_section) { + current_start += s; + start_cols.push_back(current_start); + } + + for (int j = 0; j < static_cast(double_rope_section.size()); ++j) { + int layer = j % 3; + int s_j = double_rope_section[j]; + int start_col_in = start_cols[j]; + int start_col_out = start_cols[j]; + for (int row = 0; row < num_rows; ++row) { + auto in_cos_row_ptr = tmp_cos.offsettedPtr({layer, row, 0}); + auto out_cos_row_ptr = cos.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_cos_row_ptr[start_col_out + c] = in_cos_row_ptr[start_col_in + c]; } + + auto in_sin_row_ptr = tmp_sin.offsettedPtr({layer, row, 0}); + auto out_sin_row_ptr = sin.offsettedPtr({row, 0}); + for (int c = 0; c < s_j; ++c) { out_sin_row_ptr[start_col_out + c] = in_sin_row_ptr[start_col_in + c]; } + } + } + } else { + sin = tmp_sin; + cos = tmp_cos; + } + + return {sin, cos}; +} + +struct Qwen2_5OmniSpeakerParams { + int64_t bos_token = 0; + Tensor cond = Tensor::nil(); + Tensor ref_mel = Tensor::nil(); +}; + +struct Qwen2_5OmniSpeakerMap { + std::unordered_map speakers; + std::string default_speaker; +}; + +inline Tensor tensorFromJson(const nlohmann::ordered_json& obj) { + if (!obj.contains("shape") || !obj.contains("data")) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Invalid speaker json entry: missing shape/data."); + } + auto shape = obj["shape"].get>(); + auto data = obj["data"].get>(); + + int64_t expected = 1; + for (auto dim : shape) { expected *= dim; } + MLLM_RT_ASSERT_EQ(expected, static_cast(data.size())); + + Tensor out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + std::copy(data.begin(), data.end(), 
out.ptr()); + return out; +} + +inline Qwen2_5OmniSpeakerMap loadSpeakerMap(const std::string& path) { + std::ifstream in(path); + if (!in.is_open()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to open spk_dict.json at {}", path); } + + nlohmann::ordered_json root; + in >> root; + + Qwen2_5OmniSpeakerMap map; + bool first = true; + for (auto it = root.begin(); it != root.end(); ++it) { + const auto& name = it.key(); + const auto& entry = it.value(); + Qwen2_5OmniSpeakerParams params; + params.bos_token = entry.value("bos_token", 0); + params.cond = tensorFromJson(entry["cond"]); + params.ref_mel = tensorFromJson(entry["ref_mel"]); + map.speakers.emplace(name, std::move(params)); + if (first) { + map.default_speaker = name; + first = false; + } + } + + if (map.speakers.empty()) { MLLM_ERROR_EXIT(ExitCode::kIOError, "Empty speaker map in {}", path); } + return map; +} + +class Qwen2_5OmniTalkerMLP final : public nn::Module { + nn::Linear gate_proj_; + nn::Linear up_proj_; + nn::Linear down_proj_; + nn::SiLU silu_; + + public: + Qwen2_5OmniTalkerMLP() = default; + Qwen2_5OmniTalkerMLP(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + gate_proj_ = reg("gate_proj", cfg.hidden_size, cfg.intermediate_size, false); + silu_ = reg("act"); + up_proj_ = reg("up_proj", cfg.hidden_size, cfg.intermediate_size, false); + down_proj_ = reg("down_proj", cfg.intermediate_size, cfg.hidden_size, false); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = gate_proj_(inputs[0]); + x = silu_(x); + auto y = up_proj_(inputs[0]); + x = x * y; + x = down_proj_(x); + return {x}; + } +}; + +class Qwen2_5OmniTalkerAttention final : public nn::Module { + nn::Linear q_proj_; + nn::Linear k_proj_; + nn::Linear v_proj_; + nn::Linear o_proj_; + nn::MultimodalRoPE q_rope_; + nn::MultimodalRoPE k_rope_; + nn::CausalMask mask_; + nn::Softmax softmax_; + + int hidden_size_; + int head_dim_; + int num_attention_heads_; + 
int num_key_value_heads_; + int num_key_value_groups_; + + public: + Qwen2_5OmniTalkerAttention() = default; + + Qwen2_5OmniTalkerAttention(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + hidden_size_ = cfg.hidden_size; + head_dim_ = cfg.head_dim; + num_attention_heads_ = cfg.num_attention_heads; + num_key_value_heads_ = cfg.num_key_value_heads; + num_key_value_groups_ = num_attention_heads_ / num_key_value_heads_; + + q_proj_ = reg("q_proj", hidden_size_, head_dim_ * num_attention_heads_, true); + k_proj_ = reg("k_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + v_proj_ = reg("v_proj", hidden_size_, head_dim_ * num_key_value_heads_, true); + o_proj_ = reg("o_proj", head_dim_ * num_attention_heads_, hidden_size_, false); + + q_rope_ = reg( + "q_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + k_rope_ = reg( + "k_rope", aops::Qwen2VLMultimodalRoPEOpOptions{.rope_theta = cfg.rope_theta, + .max_position_embeddings = cfg.max_position_embeddings, + .mrope_section = cfg.mrope_section}); + + mask_ = reg("mask"); + softmax_ = reg("softmax", -1); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto past_kv_cache = args[0].get(); + + auto query_states = q_proj_(x); + auto key_states = k_proj_(x); + auto value_states = v_proj_(x); + + int B = inputs[0].shape()[0]; + int S = inputs[0].shape()[1]; + + query_states = query_states.view({B, S, num_attention_heads_, head_dim_}); + key_states = key_states.view({B, S, num_key_value_heads_, head_dim_}); + value_states = value_states.view({B, S, num_key_value_heads_, head_dim_}); + + query_states = query_states.transpose(1, 2); + key_states = key_states.transpose(1, 2); + value_states = value_states.transpose(1, 2); + + 
query_states = q_rope_(query_states, llm_embedding_sin, llm_embedding_cos); + key_states = k_rope_(key_states, llm_embedding_sin, llm_embedding_cos); + + auto [k, v] = past_kv_cache->updateKVCache(layer_idx_, key_states, value_states); + key_states = k; + value_states = v; + + Tensor attn; + if (key_states.dtype() == kFloat32) { + attn = nn::functional::matmul(query_states, key_states, false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + } else if (key_states.dtype() == kFloat16) { + attn = nn::functional::matmul(query_states.to(kFloat32), key_states.to(kFloat32), false, true) * (1.f / sqrtf(head_dim_)); + attn = mask_(attn); + attn = softmax_(attn); + attn = attn.to(kFloat16); + } + + auto output = nn::functional::matmul(attn, value_states); + output = output.transpose(1, 2).view({B, S, num_attention_heads_ * head_dim_}); + output = o_proj_(output); + return {output}; + } + + int layer_idx_ = 0; +}; + +class Qwen2_5OmniTalkerDecoder final : public nn::Module { + public: + Qwen2_5OmniTalkerAttention self_attn_; + Qwen2_5OmniTalkerMLP mlp_; + nn::RMSNorm input_layer_norm_; + nn::RMSNorm post_attention_layer_norm_; + + Qwen2_5OmniTalkerDecoder() = default; + + Qwen2_5OmniTalkerDecoder(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + self_attn_ = reg("self_attn", cfg); + mlp_ = reg("mlp", cfg); + input_layer_norm_ = reg("input_layernorm", cfg.rms_norm_eps); + post_attention_layer_norm_ = reg("post_attention_layernorm", cfg.rms_norm_eps); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + auto x = input_layer_norm_(inputs[0]); + x = self_attn_(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; + auto tmp = x + inputs[0]; + x = post_attention_layer_norm_(tmp); + x = mlp_(x)[0]; + x = x + tmp; + return {x}; + } +}; + +class Qwen2_5OmniTalkerModel 
final : public nn::Module { + nn::ModuleList decode_blocks_; + nn::RMSNorm norm_; + + public: + Qwen2_5OmniTalkerModel() = default; + + Qwen2_5OmniTalkerModel(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name) { + decode_blocks_ = reg>("layers", cfg.num_hidden_layers, cfg); + for (auto [idx, b] : enumerate(decode_blocks_.list())) { b.self_attn_.layer_idx_ = idx; } + + norm_ = reg("norm", cfg.rms_norm_eps); + embedding_ = reg("embed_tokens", cfg.vocab_size, cfg.embedding_size); + + auto inv = makeTalkerRoPEInvFreq(cfg.head_dim, cfg.rope_theta); + registerBuffer("inv_freq", inv); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + auto& blocks = decode_blocks_.list(); + auto x = inputs[0]; + auto llm_embedding_sin = inputs[1]; + auto llm_embedding_cos = inputs[2]; + auto& kv_cache = args[0]; + + for (auto& block : blocks) { x = block(x, llm_embedding_sin, llm_embedding_cos, kv_cache)[0]; } + x = norm_(x); + + return {x}; + } + + nn::Embedding embedding_; +}; + +struct Qwen2_5OmniTalkerOutput { + Tensor logits = Tensor::nil(); + Tensor thinker_reply_part = Tensor::nil(); + Tensor position_ids = Tensor::nil(); +}; + +class Qwen2_5OmniTalker final : public nn::Module { + public: + Qwen2_5OmniTalker() = delete; + Qwen2_5OmniTalker(const std::string& name, const Qwen2_5OmniTalkerConfig& cfg) : nn::Module(name), cfg_(cfg) { + thinker_to_talker_proj_ = reg("thinker_to_talker_proj", cfg.embedding_size, cfg.hidden_size, true); + model_ = reg("model", cfg); + codec_head_ = reg("codec_head", cfg.hidden_size, cfg.vocab_size, false); + + kv_cache_ = nn::StaticCache(cfg.max_position_embeddings, cfg.num_hidden_layers, cfg.num_attention_heads, cfg.num_key_value_heads, + cfg.head_dim, kFloat32, kFloat32, kCPU, false); + + codec_bos_token_ = cfg.tts_codec_start_token_id; + codec_eos_token_ = cfg.tts_codec_end_token_id; + codec_pad_token_ = cfg.tts_codec_pad_token_id; + codec_mask_token_ = 
cfg.tts_codec_mask_token_id; + text_bos_token_ = cfg.tts_text_start_token_id; + text_eos_token_ = cfg.tts_text_end_token_id; + text_pad_token_ = cfg.tts_text_pad_token_id; + } + + void clearCache() { + kv_cache_.clearCache(); + rope_deltas_ = Tensor::nil(); + } + + Qwen2_5OmniTalkerOutput forward(const Tensor& input_ids, const Tensor& input_text_ids, Tensor thinker_reply_part, + Tensor inputs_embeds, const Tensor& attention_mask, const Tensor& image_grid_thw, + Tensor position_ids) { + Tensor ids_for_pos = input_text_ids.isNil() ? input_ids : input_text_ids; + position_ids = getPositionIds(ids_for_pos, image_grid_thw, position_ids); + + const bool prefill = kv_cache_.getCurrentSeqCnt(0) == 0; + if (!inputs_embeds.isNil() && prefill) { + const auto S = inputs_embeds.shape()[1]; + MLLM_RT_ASSERT(S >= 2); + + auto bos_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + bos_token.at({0, 0}) = codec_bos_token_; + auto bos_embed = model_.embedding_(bos_token); + + auto pad_token = Tensor::empty({1, 1}, kInt64, kCPU).alloc(); + pad_token.at({0, 0}) = codec_pad_token_; + auto pad_embed = model_.embedding_(pad_token); + + auto embed_dim = inputs_embeds.shape()[2]; + if (inputs_embeds.dtype() == kFloat32) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] += bos_ptr[d]; + pad_ptr[d] += pad_src_ptr[d]; + } + } else if (inputs_embeds.dtype() == kFloat16) { + auto* out_ptr = inputs_embeds.offsettedPtr({0, S - 1, 0}); + auto* pad_ptr = inputs_embeds.offsettedPtr({0, S - 2, 0}); + auto* bos_ptr = bos_embed.ptr(); + auto* pad_src_ptr = pad_embed.ptr(); + for (int d = 0; d < embed_dim; ++d) { + out_ptr[d] = static_cast(static_cast(out_ptr[d]) + static_cast(bos_ptr[d])); + pad_ptr[d] = static_cast(static_cast(pad_ptr[d]) + static_cast(pad_src_ptr[d])); + } + } + } + + if 
(inputs_embeds.isNil()) { + auto codec_embeds = model_.embedding_(input_ids); + inputs_embeds = codec_embeds + thinker_reply_part[{kAll, {0, 1}, kAll}]; + if (thinker_reply_part.shape()[1] > 1) { + thinker_reply_part = thinker_reply_part[{kAll, {1, thinker_reply_part.shape()[1]}, kAll}]; + } + } + + auto [llm_embedding_sin, llm_embedding_cos] = + makeTalkerPositionEmbedding(position_ids, model_.getBuffer("inv_freq"), cfg_.mrope_section); + + auto talker_lm_input = thinker_to_talker_proj_(inputs_embeds); + auto hidden_states = model_(talker_lm_input, llm_embedding_sin, llm_embedding_cos, AnyValue(&kv_cache_))[0]; + auto logits = codec_head_(hidden_states).to(kFloat32); + + return { + .logits = logits, + .thinker_reply_part = thinker_reply_part, + .position_ids = position_ids, + }; + } + + int64_t codec_bos_token() const { return codec_bos_token_; } + int64_t codec_eos_token() const { return codec_eos_token_; } + int64_t codec_pad_token() const { return codec_pad_token_; } + int64_t codec_mask_token() const { return codec_mask_token_; } + int64_t text_eos_token() const { return text_eos_token_; } + int64_t text_pad_token() const { return text_pad_token_; } + int64_t text_bos_token() const { return text_bos_token_; } + + Qwen2_5OmniTalkerModel model_; + + private: + Tensor getPositionIds(const Tensor& input_ids, const Tensor& image_grid_thw, const Tensor& position_ids) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + bool has_multimodal = false; + auto input_ids_ptr = input_ids.ptr(); + auto seq_len = input_ids.shape()[1]; + for (int s = 0; s < seq_len; ++s) { + if (input_ids_ptr[s] == cfg_.vision_start_token_id || input_ids_ptr[s] == cfg_.audio_start_token_id) { + has_multimodal = true; + break; + } + } + + if (has_multimodal) { return getPositionIdsPrefill(input_ids, image_grid_thw); } + + if (!position_ids.isNil()) { + auto last_pos = position_ids.constAt({0, 0, position_ids.shape()[2] - 1}); + auto ret_position_ids = Tensor::empty({3, 1, 1}, kInt64, 
kCPU).alloc(); + *ret_position_ids.offsettedPtr({0, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({1, 0, 0}) = last_pos + 1; + *ret_position_ids.offsettedPtr({2, 0, 0}) = last_pos + 1; + return ret_position_ids; + } + + auto B = input_ids.shape()[0]; + auto S = seq_len; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor out = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + for (int d = 0; d < 3; ++d) { + auto out_ptr = out.offsettedPtr({d, 0, 0}); + for (int64_t s = 0; s < S; ++s) { out_ptr[s] = s; } + } + return out; + } + + Tensor getPositionIdsPrefill(const Tensor& input_ids, const Tensor& image_grid_thw) const { + MLLM_RT_ASSERT_EQ(input_ids.shape().size(), 2); + + auto B = input_ids.shape()[0]; + auto S = input_ids.shape()[1]; + MLLM_RT_ASSERT_EQ(B, 1); + + Tensor position_ids = Tensor::empty({3, B, S}, kInt64, kCPU).alloc(); + auto input_ids_ptr = input_ids.ptr(); + + auto fill_text_positions = [&](int start_seq, int len, int64_t start_id) { + for (int d = 0; d < 3; ++d) { + auto out_ptr = position_ids.offsettedPtr({d, 0, 0}); + for (int i = 0; i < len; ++i) { out_ptr[start_seq + i] = start_id + i; } + } + }; + + int seq_idx = 0; + int image_idx = 0; + int64_t current_max_position_id = -1; + const int total_images = image_grid_thw.isNil() ? 0 : image_grid_thw.shape()[0]; + + while (seq_idx < S) { + int next_vision = -1; + int next_audio = -1; + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_start_token_id) { + next_vision = i; + break; + } + } + for (int i = seq_idx; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_start_token_id) { + next_audio = i; + break; + } + } + + if (next_vision == -1 && next_audio == -1) { + const int text_len = S - seq_idx; + if (text_len > 0) { fill_text_positions(seq_idx, text_len, current_max_position_id + 1); } + break; + } + + const bool is_vision = (next_vision != -1) && (next_audio == -1 || next_vision < next_audio); + const int segment_start = is_vision ? 
next_vision : next_audio; + + const int text_len = segment_start - seq_idx; + if (text_len > 0) { + fill_text_positions(seq_idx, text_len, current_max_position_id + 1); + current_max_position_id += text_len; + } + + if (is_vision) { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int vision_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.vision_end_token_id) { + vision_end = i; + break; + } + } + MLLM_RT_ASSERT(vision_end != -1); + + if (image_idx >= total_images) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Image index out of range."); } + + auto grid_t = image_grid_thw.ptr()[image_idx * 3]; + auto grid_h = image_grid_thw.ptr()[image_idx * 3 + 1]; + auto grid_w = image_grid_thw.ptr()[image_idx * 3 + 2]; + int vision_len = grid_t * grid_h * grid_w; + vision_len /= (cfg_.spatial_merge_size * cfg_.spatial_merge_size); + + for (int i = 0; i < vision_len; ++i) { + const int pos = segment_start + 1 + i; + if (pos >= S) { break; } + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, pos}) = current_max_position_id + 1 + i; + } + } + current_max_position_id += vision_len; + + fill_text_positions(vision_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + + seq_idx = vision_end + 1; + image_idx += 1; + } else { + fill_text_positions(segment_start, 1, current_max_position_id + 1); + current_max_position_id += 1; + + int audio_end = -1; + for (int i = segment_start + 1; i < S; ++i) { + if (input_ids_ptr[i] == cfg_.audio_end_token_id) { + audio_end = i; + break; + } + } + MLLM_RT_ASSERT(audio_end != -1); + + std::vector audio_positions; + for (int i = segment_start + 1; i < audio_end; ++i) { + if (input_ids_ptr[i] == cfg_.audio_token_id) { + audio_positions.push_back(i); + } else { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Unsupported token inside audio segment."); + } + } + const int audio_len = static_cast(audio_positions.size()); + if (audio_len 
== 0) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Empty audio tokens inside audio segment."); } + const int64_t audio_start_id = current_max_position_id + 1; + for (int i = 0; i < audio_len; ++i) { + const int64_t pos_id = audio_start_id + i; + for (int d = 0; d < 3; ++d) { + *position_ids.offsettedPtr({d, 0, audio_positions[i]}) = pos_id; + } + } + current_max_position_id += audio_len; + fill_text_positions(audio_end, 1, current_max_position_id + 1); + current_max_position_id += 1; + seq_idx = audio_end + 1; + } + } + + return position_ids; + } + + const Qwen2_5OmniTalkerConfig& cfg_; + nn::Linear thinker_to_talker_proj_; + nn::Linear codec_head_; + nn::StaticCache kv_cache_; + Tensor rope_deltas_ = Tensor::nil(); + + int64_t codec_bos_token_ = 0; + int64_t codec_eos_token_ = 0; + int64_t codec_pad_token_ = 0; + int64_t codec_mask_token_ = 0; + int64_t text_bos_token_ = 0; + int64_t text_eos_token_ = 0; + int64_t text_pad_token_ = 0; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp new file mode 100644 index 000000000..6e5939a44 --- /dev/null +++ b/mllm/models/qwen2_5omni/modeling_qwen2_5omni_token2wav.hpp @@ -0,0 +1,1508 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "mllm/core/Parallel.hpp" +#include "mllm/core/SlicePrimitives.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Functional.hpp" +#include "mllm/nn/Module.hpp" +#include "mllm/nn/Nn.hpp" +#include "mllm/utils/Common.hpp" +#include "mllm/utils/Enumerate.hpp" + +#include "mllm/models/qwen2_5omni/configuration_qwen2_5omni.hpp" + +namespace mllm::models::qwen2_5omni { + +namespace token2wav { + +constexpr float kPi = 3.14159265358979323846f; + +inline Tensor pad1dReflect(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReflect); +} + +inline Tensor pad1dReplicate(const Tensor& x, int32_t pad_left, int32_t pad_right) { + if (pad_left == 0 && pad_right == 0) { return x; } + return nn::functional::pad(x, {pad_left, pad_right}, aops::PadMode::kReplicate); +} + +inline Tensor clampTensor(const Tensor& x, float min_val, float max_val) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + + auto out = Tensor::empty(x.shape(), x.dtype(), x.device()).alloc(); + const auto* src = x.ptr(); + auto* dst = out.ptr(); + const auto numel = x.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = src[idx]; + v = std::min(std::max(v, min_val), max_val); + dst[idx] = v; + }); + return out; +} + +inline Tensor amplitudeToDb(const Tensor& amplitude, float min_db_level) { + MLLM_RT_ASSERT_EQ(amplitude.device(), kCPU); + MLLM_RT_ASSERT_EQ(amplitude.dtype(), kFloat32); + + const float min_level = std::exp(min_db_level / 20.0f * std::log(10.0f)); + const float log10_scale = 1.0f / std::log(10.0f); + + auto out = Tensor::empty(amplitude.shape(), amplitude.dtype(), amplitude.device()).alloc(); + const auto* src = amplitude.ptr(); + auto* dst = out.ptr(); + const 
auto numel = amplitude.numel(); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = std::max(src[idx], min_level); + dst[idx] = 20.0f * std::log(v) * log10_scale; + }); + + return out; +} + +inline Tensor normalizeSpectrogram(const Tensor& spectrogram, float max_value, float min_db) { + MLLM_RT_ASSERT_EQ(spectrogram.device(), kCPU); + MLLM_RT_ASSERT_EQ(spectrogram.dtype(), kFloat32); + + auto out = Tensor::empty(spectrogram.shape(), spectrogram.dtype(), spectrogram.device()).alloc(); + const auto* src = spectrogram.ptr(); + auto* dst = out.ptr(); + const auto numel = spectrogram.numel(); + + const float scale = (2.0f * max_value) / (-min_db); + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + float v = scale * (src[idx] - min_db) - max_value; + v = std::min(std::max(v, -max_value), max_value); + dst[idx] = v; + }); + return out; +} + +inline float besselI0(float x) { + const float ax = std::abs(x); + if (ax < 3.75f) { + const float y = (ax / 3.75f); + const float y2 = y * y; + return 1.0f + y2 * (3.5156229f + + y2 * (3.0899424f + + y2 * (1.2067492f + + y2 * (0.2659732f + + y2 * (0.0360768f + + y2 * 0.0045813f))))); + } + + const float y = 3.75f / ax; + const float exp_ax = std::exp(ax); + return (exp_ax / std::sqrt(ax)) * + (0.39894228f + + y * (0.01328592f + + y * (0.00225319f + + y * (-0.00157565f + + y * (0.00916281f + + y * (-0.02057706f + + y * (0.02635537f + + y * (-0.01647633f + + y * 0.00392377f)))))))); +} + +inline Tensor kaiserSincFilter1d(float cutoff, float half_width, int32_t kernel_size) { + const bool is_even = (kernel_size % 2) == 0; + const int32_t half_size = kernel_size / 2; + + if (cutoff == 0.0f) { return Tensor::zeros({1, 1, kernel_size}, kFloat32, kCPU); } + + const float delta_f = 4.0f * half_width; + const float attenuation = 2.285f * static_cast(half_size - 1) * kPi * delta_f + 7.95f; + + float beta = 0.0f; + if (attenuation > 50.0f) { + beta = 0.1102f * (attenuation - 8.7f); + } else if 
(attenuation >= 21.0f) { + beta = 0.5842f * std::pow(attenuation - 21.0f, 0.4f) + 0.07886f * (attenuation - 21.0f); + } + + const float denom = besselI0(beta); + std::vector window(kernel_size, 1.0f); + for (int32_t n = 0; n < kernel_size; ++n) { + const float ratio = (2.0f * static_cast(n) / static_cast(kernel_size - 1)) - 1.0f; + const float val = std::sqrt(std::max(0.0f, 1.0f - ratio * ratio)); + window[n] = besselI0(beta * val) / denom; + } + + std::vector filter(kernel_size, 0.0f); + float sum = 0.0f; + for (int32_t n = 0; n < kernel_size; ++n) { + float t = static_cast(n) - static_cast(half_size); + if (is_even) { t += 0.5f; } + const float arg = 2.0f * cutoff * t; + const float sinc = (arg == 0.0f) ? 1.0f : std::sin(kPi * arg) / (kPi * arg); + const float v = 2.0f * cutoff * window[n] * sinc; + filter[n] = v; + sum += v; + } + + if (sum != 0.0f) { + for (auto& v : filter) { v /= sum; } + } + + auto out = Tensor::empty({1, 1, kernel_size}, kFloat32, kCPU).alloc(); + std::copy(filter.begin(), filter.end(), out.ptr()); + return out; +} + +inline Tensor convTranspose1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - 1) * stride + kernel; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; 
+ for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t i = 0; i < in_len; ++i) { + const float v = in_c[i]; + const int32_t base = i * stride; + for (int32_t k = 0; k < kernel; ++k) { out_c[base + k] += v * f[k]; } + } + } + } + + return out; +} + +inline Tensor conv1dDepthwise(const Tensor& input, const Tensor& filter, int32_t stride) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(filter.device(), kCPU); + MLLM_RT_ASSERT_EQ(filter.dtype(), kFloat32); + + const auto& in_shape = input.shape(); + const int32_t batch = in_shape[0]; + const int32_t channels = in_shape[1]; + const int32_t in_len = in_shape[2]; + const int32_t kernel = filter.shape()[2]; + + const int32_t out_len = (in_len - kernel) / stride + 1; + auto out = Tensor::zeros({batch, channels, out_len}, kFloat32, kCPU); + + const auto* in_ptr = input.ptr(); + const auto* filt_ptr = filter.ptr(); + auto* out_ptr = out.ptr(); + + const int32_t in_step = channels * in_len; + const int32_t out_step = channels * out_len; + + for (int32_t b = 0; b < batch; ++b) { + const float* in_b = in_ptr + b * in_step; + float* out_b = out_ptr + b * out_step; + for (int32_t c = 0; c < channels; ++c) { + const float* in_c = in_b + c * in_len; + float* out_c = out_b + c * out_len; + const float* f = filt_ptr; + for (int32_t o = 0; o < out_len; ++o) { + float sum = 0.0f; + const int32_t base = o * stride; + for (int32_t k = 0; k < kernel; ++k) { sum += in_c[base + k] * f[k]; } + out_c[o] = sum; + } + } + } + + return out; +} + +inline Tensor randomNormal(const std::vector& shape, float mean = 0.0f, float std = 1.0f) { + auto out = Tensor::empty(shape, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + const int64_t numel = out.numel(); + std::mt19937 gen(static_cast(mllm::Context::instance().getRandomState())); + std::normal_distribution dist(mean, std); + 
for (int64_t i = 0; i < numel; ++i) { ptr[i] = dist(gen); } + return out; +} + +inline Tensor linspace(float start, float end, int32_t steps) { + auto out = Tensor::empty({steps}, kFloat32, kCPU).alloc(); + auto* ptr = out.ptr(); + if (steps <= 1) { + if (steps == 1) { ptr[0] = start; } + return out; + } + const float step = (end - start) / static_cast(steps - 1); + for (int32_t i = 0; i < steps; ++i) { ptr[i] = start + step * static_cast(i); } + return out; +} + +inline Tensor repeatInterleave(const Tensor& input, int32_t repeats, int32_t dim) { + MLLM_RT_ASSERT_EQ(input.device(), kCPU); + MLLM_RT_ASSERT_EQ(input.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(dim, 1); + + if (repeats == 1) { return input; } + + const auto& shape = input.shape(); + const int32_t batch = shape[0]; + const int32_t seq_len = shape[1]; + const int32_t channels = shape[2]; + + auto out = Tensor::empty({batch, seq_len * repeats, channels}, kFloat32, kCPU).alloc(); + const auto* src = input.ptr(); + auto* dst = out.ptr(); + + const int64_t in_stride_b = static_cast(seq_len) * channels; + const int64_t out_stride_b = static_cast(seq_len) * repeats * channels; + + for (int32_t b = 0; b < batch; ++b) { + const float* src_b = src + b * in_stride_b; + float* dst_b = dst + b * out_stride_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float* src_s = src_b + static_cast(s) * channels; + for (int32_t r = 0; r < repeats; ++r) { + float* dst_s = dst_b + (static_cast(s) * repeats + r) * channels; + std::memcpy(dst_s, src_s, sizeof(float) * channels); + } + } + } + + return out; +} + +class SnakeBeta final : public nn::Module { + nn::Param alpha_; + nn::Param beta_; + float no_div_by_zero_ = 1e-9f; + + public: + SnakeBeta() = default; + SnakeBeta(const std::string& name, int32_t in_features) : nn::Module(name) { + alpha_ = reg("alpha", getModuleName() + ".alpha", std::vector{in_features}); + beta_ = reg("beta", getModuleName() + ".beta", std::vector{in_features}); + } + + std::vector forward(const 
std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(x.dtype(), kFloat32); + if (!x.isContiguous()) { x = x.contiguous(); } + + const auto& shape = x.shape(); + const int32_t batch = shape[0]; + const int32_t channels = shape[1]; + const int32_t seq_len = shape[2]; + + auto y = Tensor::empty(shape, kFloat32, kCPU).alloc(); + const auto* x_ptr = x.ptr(); + auto* y_ptr = y.ptr(); + + auto alpha = alpha_.weight(); + auto beta = beta_.weight(); + const auto* alpha_ptr = alpha.ptr(); + const auto* beta_ptr = beta.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const float a = std::exp(alpha_ptr[c]); + const float bb = std::exp(beta_ptr[c]); + const float inv_b = 1.0f / (bb + no_div_by_zero_); + const int32_t base = b * stride_b + c * stride_c; + for (int32_t t = 0; t < seq_len; ++t) { + float v = x_ptr[base + t]; + const float s = std::sin(v * a); + v = v + inv_b * (s * s); + y_ptr[base + t] = v; + } + } + } + + return {y}; + } + +}; + +class TorchActivation1d final : public nn::Module { + public: + TorchActivation1d() = default; + TorchActivation1d(const std::string& name, int32_t channels, int32_t up_ratio = 2, int32_t down_ratio = 2, + int32_t up_kernel_size = 12, int32_t down_kernel_size = 12) + : nn::Module(name), + up_ratio_(up_ratio), + down_ratio_(down_ratio), + up_kernel_size_(up_kernel_size), + down_kernel_size_(down_kernel_size) { + act_ = reg("act", channels); + + up_kernel_size_ = (up_kernel_size_ <= 0) ? static_cast(int(6 * up_ratio_ / 2) * 2) : up_kernel_size_; + up_stride_ = up_ratio_; + up_pad_ = up_kernel_size_ / up_ratio_ - 1; + up_pad_left_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_) / 2; + up_pad_right_ = up_pad_ * up_stride_ + (up_kernel_size_ - up_stride_ + 1) / 2; + + down_kernel_size_ = (down_kernel_size_ <= 0) ? 
static_cast(int(6 * down_ratio_ / 2) * 2) : down_kernel_size_; + down_stride_ = down_ratio_; + down_even_ = (down_kernel_size_ % 2) == 0; + down_pad_left_ = down_kernel_size_ / 2 - (down_even_ ? 1 : 0); + down_pad_right_ = down_kernel_size_ / 2; + + up_filter_ = kaiserSincFilter1d(0.5f / static_cast(up_ratio_), 0.6f / static_cast(up_ratio_), up_kernel_size_); + down_filter_ = + kaiserSincFilter1d(0.5f / static_cast(down_ratio_), 0.6f / static_cast(down_ratio_), down_kernel_size_); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + x = upsample(x); + x = act_(x)[0]; + x = downsample(x); + return {x}; + } + + private: + Tensor upsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, up_pad_, up_pad_); + auto out = convTranspose1dDepthwise(padded, up_filter_, up_stride_); + out = out * static_cast(up_ratio_); + if (up_pad_left_ > 0 || up_pad_right_ > 0) { + auto length = out.shape()[2]; + auto start = up_pad_left_; + auto end = length - up_pad_right_; + out = out[{kAll, kAll, {start, end}}]; + } + return out; + } + + Tensor downsample(const Tensor& input) const { + auto padded = pad1dReplicate(input, down_pad_left_, down_pad_right_); + auto out = conv1dDepthwise(padded, down_filter_, down_stride_); + return out; + } + + SnakeBeta act_; + int32_t up_ratio_ = 2; + int32_t down_ratio_ = 2; + int32_t up_kernel_size_ = 12; + int32_t down_kernel_size_ = 12; + int32_t up_stride_ = 2; + int32_t down_stride_ = 2; + int32_t up_pad_ = 0; + int32_t up_pad_left_ = 0; + int32_t up_pad_right_ = 0; + int32_t down_pad_left_ = 0; + int32_t down_pad_right_ = 0; + bool down_even_ = false; + Tensor up_filter_ = Tensor::nil(); + Tensor down_filter_ = Tensor::nil(); +}; + +class TimeDelayNetBlock final : public nn::Module { + public: + TimeDelayNetBlock() = default; + TimeDelayNetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), 
kernel_size_(kernel_size), dilation_(dilation) { + conv_ = reg("conv", in_channels, out_channels, kernel_size_, 1, 0, dilation_, 1, true); + relu_ = reg("relu"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t pad_total = dilation_ * (kernel_size_ - 1); + const int32_t pad_left = pad_total / 2; + const int32_t pad_right = pad_total - pad_left; + if (pad_total > 0) { x = pad1dReflect(x, pad_left, pad_right); } + x = conv_(x); + x = relu_(x); + return {x}; + } + + private: + nn::Conv1D conv_; + nn::ReLU relu_; + int32_t kernel_size_ = 1; + int32_t dilation_ = 1; +}; + +class Res2NetBlock final : public nn::Module { + public: + Res2NetBlock() = default; + Res2NetBlock(const std::string& name, int32_t in_channels, int32_t out_channels, int32_t scale, int32_t kernel_size, int32_t dilation) + : nn::Module(name), scale_(scale) { + const int32_t in_channel = in_channels / scale; + const int32_t hidden_channel = out_channels / scale; + blocks_ = reg>("blocks", scale_ - 1, in_channel, hidden_channel, kernel_size, dilation); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto x = inputs[0]; + const int32_t channels = x.shape()[1]; + const int32_t split = channels / scale_; + + std::vector outputs; + outputs.reserve(scale_); + Tensor output_part = Tensor::nil(); + + for (int32_t i = 0; i < scale_; ++i) { + auto hidden_part = x[{kAll, {i * split, (i + 1) * split}, kAll}]; + if (i == 0) { + output_part = hidden_part; + } else if (i == 1) { + output_part = blocks_.list()[i - 1](hidden_part)[0]; + } else { + output_part = blocks_.list()[i - 1](hidden_part + output_part)[0]; + } + outputs.push_back(output_part); + } + + auto out = nn::functional::concat(outputs, 1); + return {out}; + } + + private: + int32_t scale_ = 1; + nn::ModuleList blocks_; +}; + +class SqueezeExcitationBlock final : public nn::Module { + public: + SqueezeExcitationBlock() = default; + 
SqueezeExcitationBlock(const std::string& name, int32_t in_channels, int32_t se_channels, int32_t out_channels) + : nn::Module(name) { + conv1_ = reg("conv1", in_channels, se_channels, 1, 1, 0, 1, 1, true); + conv2_ = reg("conv2", se_channels, out_channels, 1, 1, 0, 1, 1, true); + relu_ = reg("relu"); + sigmoid_ = reg("sigmoid"); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto hidden_mean = nn::functional::mean(hidden_states, 2, true); + hidden_mean = relu_(conv1_(hidden_mean)); + hidden_mean = sigmoid_(conv2_(hidden_mean)); + hidden_states = hidden_states * hidden_mean; + return {hidden_states}; + } + + private: + nn::Conv1D conv1_; + nn::Conv1D conv2_; + nn::ReLU relu_; + nn::Sigmoid sigmoid_; +}; + +class AttentiveStatisticsPooling final : public nn::Module { + public: + AttentiveStatisticsPooling() = default; + AttentiveStatisticsPooling(const std::string& name, int32_t channels, int32_t attention_channels) + : nn::Module(name), channels_(channels) { + tdnn_ = reg("tdnn", channels * 3, attention_channels, 1, 1); + tanh_ = reg("tanh"); + conv_ = reg("conv", attention_channels, channels, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t channels = hidden_states.shape()[1]; + const int32_t seq_len = hidden_states.shape()[2]; + + auto mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + + const auto* x_ptr = hidden_states.ptr(); + auto* mean_ptr = mean.ptr(); + auto* std_ptr = std.ptr(); + + const int32_t stride_c = seq_len; + const int32_t stride_b = channels * seq_len; + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < 
channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float sum = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { sum += x_ptr[base + t]; } + float m = sum / static_cast(seq_len); + mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += diff * diff; + } + var /= static_cast(seq_len); + std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto mean_rep = mean.view({batch, channels, 1}).repeat(seq_len, 2); + auto std_rep = std.view({batch, channels, 1}).repeat(seq_len, 2); + + auto attention = nn::functional::concat({hidden_states, mean_rep, std_rep}, 1); + attention = tdnn_(attention)[0]; + attention = tanh_(attention); + attention = conv_(attention); + attention = nn::functional::softmax(attention, 2); + + auto out_mean = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto out_std = Tensor::empty({batch, channels}, kFloat32, kCPU).alloc(); + auto* out_mean_ptr = out_mean.ptr(); + auto* out_std_ptr = out_std.ptr(); + const auto* attn_ptr = attention.ptr(); + + for (int32_t b = 0; b < batch; ++b) { + for (int32_t c = 0; c < channels; ++c) { + const int32_t base = b * stride_b + c * stride_c; + float m = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { m += attn_ptr[base + t] * x_ptr[base + t]; } + out_mean_ptr[b * channels + c] = m; + + float var = 0.0f; + for (int32_t t = 0; t < seq_len; ++t) { + float diff = x_ptr[base + t] - m; + var += attn_ptr[base + t] * diff * diff; + } + out_std_ptr[b * channels + c] = std::sqrt(std::max(var, 1e-12f)); + } + } + + auto pooled = nn::functional::concat({out_mean, out_std}, 1).view({batch, channels * 2, 1}); + return {pooled}; + } + + private: + int32_t channels_ = 0; + TimeDelayNetBlock tdnn_; + nn::Tanh tanh_; + nn::Conv1D conv_; +}; + +class SqueezeExcitationRes2NetBlock final : public nn::Module { + public: + SqueezeExcitationRes2NetBlock() = default; + SqueezeExcitationRes2NetBlock(const 
std::string& name, int32_t in_channels, int32_t out_channels, int32_t res2net_scale, + int32_t se_channels, int32_t kernel_size, int32_t dilation) + : nn::Module(name), out_channels_(out_channels) { + tdnn1_ = reg("tdnn1", in_channels, out_channels, 1, 1); + res2net_block_ = reg("res2net_block", out_channels, out_channels, res2net_scale, kernel_size, dilation); + tdnn2_ = reg("tdnn2", out_channels, out_channels, 1, 1); + se_block_ = reg("se_block", out_channels, se_channels, out_channels); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_state = inputs[0]; + auto residual = hidden_state; + + hidden_state = tdnn1_(hidden_state)[0]; + hidden_state = res2net_block_(hidden_state)[0]; + hidden_state = tdnn2_(hidden_state)[0]; + hidden_state = se_block_(hidden_state)[0]; + hidden_state = hidden_state + residual; + return {hidden_state}; + } + + private: + int32_t out_channels_ = 0; + TimeDelayNetBlock tdnn1_; + Res2NetBlock res2net_block_; + TimeDelayNetBlock tdnn2_; + SqueezeExcitationBlock se_block_; +}; + +class ECAPA_TimeDelayNet final : public nn::Module { + public: + ECAPA_TimeDelayNet() = default; + explicit ECAPA_TimeDelayNet(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + if (cfg.enc_channels.size() != cfg.enc_kernel_sizes.size() || cfg.enc_channels.size() != cfg.enc_dilations.size()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels, enc_kernel_sizes and enc_dilations should have same length"); + } + + if (cfg.enc_channels.empty()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "enc_channels should not be empty"); + } + + const int32_t num_blocks = static_cast(cfg.enc_channels.size()); + tdnn0_ = reg("blocks.0", cfg.mel_dim, cfg.enc_channels[0], cfg.enc_kernel_sizes[0], cfg.enc_dilations[0]); + + for (int32_t i = 1; i < num_blocks - 1; ++i) { + se_blocks_.emplace_back(reg( + "blocks." 
+ std::to_string(i), + cfg.enc_channels[i - 1], + cfg.enc_channels[i], + cfg.enc_res2net_scale, + cfg.enc_se_channels, + cfg.enc_kernel_sizes[i], + cfg.enc_dilations[i])); + } + + mfa_ = reg("mfa", cfg.enc_channels.back(), cfg.enc_channels.back(), cfg.enc_kernel_sizes.back(), + cfg.enc_dilations.back()); + asp_ = reg("asp", cfg.enc_channels.back(), cfg.enc_attention_channels); + fc_ = reg("fc", cfg.enc_channels.back() * 2, cfg.enc_dim, 1, 1, 0, 1, 1, true); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + + hidden_states = hidden_states.transpose(1, 2); + + std::vector hidden_states_list; + hidden_states = tdnn0_(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + + for (auto& block : se_blocks_) { + hidden_states = block(hidden_states)[0]; + hidden_states_list.push_back(hidden_states); + } + + if (hidden_states_list.size() <= 1) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "ECAPA_TimeDelayNet expects at least 2 blocks."); + } + + std::vector mfa_inputs; + for (size_t i = 1; i < hidden_states_list.size(); ++i) { mfa_inputs.push_back(hidden_states_list[i]); } + hidden_states = nn::functional::concat(mfa_inputs, 1); + hidden_states = mfa_(hidden_states)[0]; + hidden_states = asp_(hidden_states)[0]; + hidden_states = fc_(hidden_states); + hidden_states = hidden_states.squeeze(-1); + + return {hidden_states}; + } + + private: + TimeDelayNetBlock tdnn0_; + std::vector se_blocks_; + TimeDelayNetBlock mfa_; + AttentiveStatisticsPooling asp_; + nn::Conv1D fc_; +}; + +class DiTInputEmbedding final : public nn::Module { + public: + DiTInputEmbedding() = default; + explicit DiTInputEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name) { + const int32_t in_dim = cfg.mel_dim + cfg.enc_dim + cfg.enc_emb_dim + cfg.emb_dim; + proj_ = reg("proj", in_dim, 
cfg.hidden_size, true); + spk_encoder_ = reg("spk_encoder", cfg); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& speaker_embedding, const Tensor& condition_vector, const Tensor& code_embed, + bool drop_audio_cond, const Tensor& code_embed_uncond, bool apply_cfg) { + auto x = hidden_states; + auto spk = speaker_embedding; + auto cond = condition_vector; + auto code = code_embed; + + if (apply_cfg) { + x = nn::functional::concat({x, x}, 0); + spk = nn::functional::concat({spk, Tensor::zeros(spk.shape(), spk.dtype(), spk.device())}, 0); + cond = nn::functional::concat({cond, Tensor::zeros(cond.shape(), cond.dtype(), cond.device())}, 0); + code = nn::functional::concat({code, code_embed_uncond}, 0); + } else if (drop_audio_cond) { + cond = Tensor::zeros(cond.shape(), cond.dtype(), cond.device()); + spk = Tensor::zeros(spk.shape(), spk.dtype(), spk.device()); + } + + auto cond_embed = spk_encoder_(cond)[0]; + const int32_t seq_len = x.shape()[1]; + cond_embed = cond_embed.view({cond_embed.shape()[0], 1, cond_embed.shape()[1]}).repeat(seq_len, 1); + + auto merged = nn::functional::concat({x, cond_embed, code, spk}, -1); + auto out = proj_(merged); + return out; + } + + private: + nn::Linear proj_; + ECAPA_TimeDelayNet spk_encoder_; +}; + +class DiTCodecEmbedding final : public nn::Module { + public: + DiTCodecEmbedding() = default; + DiTCodecEmbedding(const std::string& name, int32_t codec_num_embeds, int32_t codec_dim, int32_t repeats) + : nn::Module(name), repeats_(repeats) { + codec_embed_ = reg("codec_embed", codec_num_embeds + 1, codec_dim); + } + + Tensor forward(const Tensor& code, bool drop_code) { + Tensor code_ids = code; + if (drop_code) { code_ids = Tensor::zeros(code.shape(), code.dtype(), code.device()); } + auto code_embed = codec_embed_(code_ids); + return repeatInterleave(code_embed, repeats_, 1); + } + + private: + int32_t repeats_ = 1; + nn::Embedding codec_embed_; +}; + +class Qwen2_5_OmniAdaLayerNormZero final : public nn::Module { 
+ public: + Qwen2_5_OmniAdaLayerNormZero() = default; + Qwen2_5_OmniAdaLayerNormZero(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 6, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + std::vector forward(const std::vector& inputs, const std::vector&) override { + auto hidden_states = inputs[0]; + auto emb = inputs[1]; + emb = linear_(silu_(emb)); + + auto chunks = nn::functional::chunk<6>(emb, 1); + auto shift_msa = chunks[0]; + auto scale_msa = chunks[1]; + auto gate_msa = chunks[2]; + auto shift_mlp = chunks[3]; + auto scale_mlp = chunks[4]; + auto gate_mlp = chunks[5]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + auto scale = scale_msa.view({scale_msa.shape()[0], 1, scale_msa.shape()[1]}).repeat(seq_len, 1); + auto shift = shift_msa.view({shift_msa.shape()[0], 1, shift_msa.shape()[1]}).repeat(seq_len, 1); + normed = normed * (scale + 1.0f) + shift; + + return {normed, gate_msa, shift_mlp, scale_mlp, gate_mlp}; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class Qwen2_5_OmniAdaLayerNormZero_Final final : public nn::Module { + public: + Qwen2_5_OmniAdaLayerNormZero_Final() = default; + Qwen2_5_OmniAdaLayerNormZero_Final(const std::string& name, int32_t dim) : nn::Module(name) { + silu_ = reg("silu"); + linear_ = reg("linear", dim, dim * 2, true); + norm_ = reg("norm", std::vector{dim}, false, false, 1e-6f); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& emb) { + auto emb_out = linear_(silu_(emb)); + auto chunks = nn::functional::chunk<2>(emb_out, 1); + auto scale = chunks[0]; + auto shift = chunks[1]; + + auto normed = norm_(hidden_states); + const int32_t seq_len = hidden_states.shape()[1]; + scale = scale.view({scale.shape()[0], 1, scale.shape()[1]}).repeat(seq_len, 1); + shift = shift.view({shift.shape()[0], 1, shift.shape()[1]}).repeat(seq_len, 1); + 
normed = normed * (scale + 1.0f) + shift; + return normed; + } + + private: + nn::SiLU silu_; + nn::Linear linear_; + nn::LayerNorm norm_; +}; + +class DiTMLP final : public nn::Module { + public: + DiTMLP() = default; + DiTMLP(const std::string& name, int32_t dim, int32_t mult) : nn::Module(name) { + const int32_t inner_dim = dim * mult; + fc1_ = reg("ff.0", dim, inner_dim, true); + act_ = reg("ff.1"); + fc2_ = reg("ff.3", inner_dim, dim, true); + } + + Tensor forward(const Tensor& hidden_states) { + auto x = fc1_(hidden_states); + x = act_(x); + x = fc2_(x); + return x; + } + + private: + nn::Linear fc1_; + nn::GELU act_; + nn::Linear fc2_; +}; + +inline void applyRotaryPosEmbFirstHead(Tensor& q, Tensor& k, const Tensor& cos, const Tensor& sin) { + MLLM_RT_ASSERT_EQ(q.device(), kCPU); + MLLM_RT_ASSERT_EQ(k.device(), kCPU); + MLLM_RT_ASSERT_EQ(cos.device(), kCPU); + MLLM_RT_ASSERT_EQ(sin.device(), kCPU); + MLLM_RT_ASSERT_EQ(q.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(k.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(cos.dtype(), kFloat32); + MLLM_RT_ASSERT_EQ(sin.dtype(), kFloat32); + + const int32_t batch = q.shape()[0]; + const int32_t heads = q.shape()[1]; + const int32_t seq_len = q.shape()[2]; + const int32_t head_dim = q.shape()[3]; + MLLM_RT_ASSERT_EQ(head_dim % 2, 0); + MLLM_RT_ASSERT_EQ(cos.shape()[0], batch); + MLLM_RT_ASSERT_EQ(cos.shape()[1], seq_len); + MLLM_RT_ASSERT_EQ(cos.shape()[2], head_dim); + + const auto* cos_ptr = cos.ptr(); + const auto* sin_ptr = sin.ptr(); + auto* q_ptr = q.ptr(); + auto* k_ptr = k.ptr(); + + const int64_t stride_q_b = static_cast(heads) * seq_len * head_dim; + const int64_t stride_q_h = static_cast(seq_len) * head_dim; + const int64_t stride_q_s = head_dim; + + const int64_t stride_cos_b = static_cast(seq_len) * head_dim; + const int64_t stride_cos_s = head_dim; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t q_base_b = static_cast(b) * stride_q_b; + const int64_t cos_base_b = static_cast(b) * stride_cos_b; + for (int32_t 
s = 0; s < seq_len; ++s) { + float* q_row = q_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + float* k_row = k_ptr + q_base_b + 0 * stride_q_h + static_cast(s) * stride_q_s; + const float* cos_row = cos_ptr + cos_base_b + static_cast(s) * stride_cos_s; + const float* sin_row = sin_ptr + cos_base_b + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < head_dim; d += 2) { + const float c = cos_row[d]; + const float ss = sin_row[d]; + const float q1 = q_row[d]; + const float q2 = q_row[d + 1]; + const float k1 = k_row[d]; + const float k2 = k_row[d + 1]; + q_row[d] = q1 * c - q2 * ss; + q_row[d + 1] = q1 * ss + q2 * c; + k_row[d] = k1 * c - k2 * ss; + k_row[d + 1] = k1 * ss + k2 * c; + } + } + } +} + +inline Tensor makeBlockDiff(int32_t batch, int32_t heads, int32_t seq_len, int32_t block_size) { + (void)heads; + MLLM_RT_ASSERT(block_size > 0); + std::vector block_indices(seq_len, 0); + for (int32_t i = 0; i < seq_len; ++i) { block_indices[i] = i / block_size; } + + std::vector base(static_cast(seq_len) * seq_len, 0.0f); + for (int32_t i = 0; i < seq_len; ++i) { + for (int32_t j = 0; j < seq_len; ++j) { + base[static_cast(i) * seq_len + j] = static_cast(block_indices[j] - block_indices[i]); + } + } + + // Use a broadcast-friendly shape to avoid materializing head copies while keeping naive broadcast support. 
+ auto out = Tensor::empty({batch, 1, seq_len, seq_len}, kFloat32, kCPU).alloc(); + const int64_t block_stride = static_cast(seq_len) * seq_len; + auto* out_ptr = out.ptr(); + for (int32_t b = 0; b < batch; ++b) { + float* dst = out_ptr + static_cast(b) * block_stride; + std::memcpy(dst, base.data(), sizeof(float) * base.size()); + } + return out; +} + +inline Tensor makeBlockMask(const Tensor& block_diff, int32_t look_backward_block, int32_t look_ahead_block) { + MLLM_RT_ASSERT_EQ(block_diff.device(), kCPU); + MLLM_RT_ASSERT_EQ(block_diff.dtype(), kFloat32); + + auto mask = Tensor::empty(block_diff.shape(), kFloat32, kCPU).alloc(); + const auto* src = block_diff.ptr(); + auto* dst = mask.ptr(); + const int64_t numel = block_diff.numel(); + const float lower = -static_cast(look_backward_block); + const float upper = static_cast(look_ahead_block); + + MLLM_CONDITIONAL_PARALLEL_FOR(numel > 1024, 4, idx, 0, numel, 1, { + const float v = src[idx]; + dst[idx] = (v >= lower && v <= upper) ? 0.0f : -1e4f; + }); + return mask; +} + +class DiTAttention final : public nn::Module { + public: + DiTAttention() = default; + explicit DiTAttention(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + dim_ = cfg.hidden_size; + heads_ = cfg.num_attention_heads; + head_dim_ = cfg.head_dim; + inner_dim_ = head_dim_ * heads_; + + to_q_ = reg("to_q", dim_, inner_dim_, true); + to_k_ = reg("to_k", dim_, inner_dim_, true); + to_v_ = reg("to_v", dim_, inner_dim_, true); + to_out_ = reg("to_out.0", inner_dim_, dim_, true); + } + + Tensor forward(const Tensor& hidden_states, const std::pair& position_embeddings, const Tensor& attention_mask) { + auto query = to_q_(hidden_states); + auto key = to_k_(hidden_states); + auto value = to_v_(hidden_states); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t seq_len = hidden_states.shape()[1]; + + query = query.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + key = key.view({batch, 
seq_len, heads_, head_dim_}).transpose(1, 2); + value = value.view({batch, seq_len, heads_, head_dim_}).transpose(1, 2); + + if (!position_embeddings.first.isNil()) { + applyRotaryPosEmbFirstHead(query, key, position_embeddings.first, position_embeddings.second); + } + + auto attn_output = nn::functional::scaledDotProductAttention(query, key, value, attention_mask); + attn_output = attn_output.transpose(1, 2).view({batch, seq_len, inner_dim_}); + attn_output = to_out_(attn_output); + return attn_output; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t dim_ = 0; + int32_t heads_ = 0; + int32_t head_dim_ = 0; + int32_t inner_dim_ = 0; + nn::Linear to_q_; + nn::Linear to_k_; + nn::Linear to_v_; + nn::Linear to_out_; +}; + +class SinusPositionEmbedding final : public nn::Module { + public: + SinusPositionEmbedding() = default; + explicit SinusPositionEmbedding(const std::string& name, int32_t dim) : nn::Module(name), dim_(dim) {} + + Tensor forward(const Tensor& hidden_states, float scale = 1000.0f) { + MLLM_RT_ASSERT_EQ(hidden_states.device(), kCPU); + MLLM_RT_ASSERT_EQ(hidden_states.dtype(), kFloat32); + + const int32_t batch = hidden_states.shape()[0]; + const int32_t half_dim = dim_ / 2; + auto out = Tensor::empty({batch, dim_}, kFloat32, kCPU).alloc(); + auto* out_ptr = out.ptr(); + const auto* hs_ptr = hidden_states.ptr(); + + const float emb = std::log(10000.0f) / static_cast(half_dim - 1); + std::vector freqs(half_dim); + for (int32_t i = 0; i < half_dim; ++i) { freqs[i] = std::exp(-emb * static_cast(i)); } + + for (int32_t b = 0; b < batch; ++b) { + const float t = hs_ptr[b] * scale; + float* row = out_ptr + static_cast(b) * dim_; + for (int32_t i = 0; i < half_dim; ++i) { + const float val = t * freqs[i]; + row[i] = std::sin(val); + row[i + half_dim] = std::cos(val); + } + } + + return out; + } + + private: + int32_t dim_ = 0; +}; + +class DiTTimestepEmbedding final : public nn::Module { + public: + DiTTimestepEmbedding() = default; + explicit 
DiTTimestepEmbedding(const std::string& name, int32_t dim, int32_t freq_embed_dim = 256) + : nn::Module(name), freq_embed_dim_(freq_embed_dim) { + time_embed_ = reg("time_embed", freq_embed_dim_); + fc1_ = reg("time_mlp.0", freq_embed_dim_, dim, true); + act_ = reg("time_mlp.1"); + fc2_ = reg("time_mlp.2", dim, dim, true); + } + + Tensor forward(const Tensor& timestep) { + auto time_hidden = time_embed_.forward(timestep); + time_hidden = fc1_(time_hidden); + time_hidden = act_(time_hidden); + time_hidden = fc2_(time_hidden); + return time_hidden; + } + + private: + int32_t freq_embed_dim_ = 256; + SinusPositionEmbedding time_embed_; + nn::Linear fc1_; + nn::SiLU act_; + nn::Linear fc2_; +}; + +class DiTDecoderLayer final : public nn::Module { + public: + DiTDecoderLayer() = default; + DiTDecoderLayer(const std::string& name, const Qwen2_5OmniDiTConfig& cfg, int32_t look_ahead_block, int32_t look_backward_block) + : nn::Module(name), look_ahead_block_(look_ahead_block), look_backward_block_(look_backward_block) { + attn_norm_ = reg("attn_norm", cfg.hidden_size); + attn_ = reg("attn", cfg); + ff_norm_ = reg("ff_norm", std::vector{cfg.hidden_size}, false, false, 1e-6f); + ff_ = reg("ff", cfg.hidden_size, cfg.ff_mult); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& timestep, const std::pair& position_embeddings, + const Tensor& block_diff) { + auto attn_norm_out = attn_norm_(hidden_states, timestep); + auto norm = attn_norm_out[0]; + auto gate_msa = attn_norm_out[1]; + auto shift_mlp = attn_norm_out[2]; + auto scale_mlp = attn_norm_out[3]; + auto gate_mlp = attn_norm_out[4]; + + Tensor attn_mask = Tensor::nil(); + if (!block_diff.isNil()) { attn_mask = makeBlockMask(block_diff, look_backward_block_, look_ahead_block_); } + auto attn_output = attn_.forward(norm, position_embeddings, attn_mask); + + auto gate_msa_rep = gate_msa.view({gate_msa.shape()[0], 1, gate_msa.shape()[1]}).repeat(hidden_states.shape()[1], 1); + auto x = Tensor(hidden_states); + x 
= x + gate_msa_rep * attn_output; + + auto norm_ff = ff_norm_(x); + auto scale_rep = scale_mlp.view({scale_mlp.shape()[0], 1, scale_mlp.shape()[1]}).repeat(x.shape()[1], 1); + auto shift_rep = shift_mlp.view({shift_mlp.shape()[0], 1, shift_mlp.shape()[1]}).repeat(x.shape()[1], 1); + norm_ff = norm_ff * (scale_rep + 1.0f) + shift_rep; + auto ff_output = ff_.forward(norm_ff); + auto gate_mlp_rep = gate_mlp.view({gate_mlp.shape()[0], 1, gate_mlp.shape()[1]}).repeat(x.shape()[1], 1); + x = x + gate_mlp_rep * ff_output; + return x; + } + + private: + Qwen2_5_OmniAdaLayerNormZero attn_norm_; + DiTAttention attn_; + nn::LayerNorm ff_norm_; + DiTMLP ff_; + int32_t look_ahead_block_ = 0; + int32_t look_backward_block_ = 0; +}; + +class Qwen2_5OmniDiTRotaryEmbedding final : public nn::Module { + public: + Qwen2_5OmniDiTRotaryEmbedding() = default; + explicit Qwen2_5OmniDiTRotaryEmbedding(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + const int32_t dim = cfg.head_dim; + inv_freq_ = reg("inv_freq", getModuleName() + ".inv_freq", std::vector{dim / 2}); + attention_scaling_ = 1.0f; + + auto inv = inv_freq_.weight(); + if (!inv.isNil() && inv.numel() == 0) { + inv = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + inv_freq_.weight().copy2(inv); + } + } + + std::pair forward(const Tensor& x, const Tensor& position_ids) { + MLLM_RT_ASSERT_EQ(x.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.device(), kCPU); + MLLM_RT_ASSERT_EQ(position_ids.dtype(), kInt64); + + const int32_t batch = position_ids.shape()[0]; + const int32_t seq_len = position_ids.shape()[1]; + auto inv_freq = inv_freq_.weight(); + if (inv_freq.isNil() || inv_freq.numel() == 0) { + const int32_t dim = cfg_.head_dim; + inv_freq = Tensor::empty({dim / 2}, kFloat32, kCPU).alloc(); + auto* ptr = inv_freq.ptr(); + for (int32_t i = 0; i < dim / 2; ++i) { + ptr[i] = 1.0f / std::pow(cfg_.rope_theta, 2.0f * i / static_cast(dim)); + } + } + + const int32_t half_dim = 
inv_freq.shape()[0]; + auto cos = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + auto sin = Tensor::empty({batch, seq_len, half_dim * 2}, kFloat32, kCPU).alloc(); + + const auto* inv_ptr = inv_freq.ptr(); + const auto* pos_ptr = position_ids.ptr(); + auto* cos_ptr = cos.ptr(); + auto* sin_ptr = sin.ptr(); + + const int64_t stride_pos_b = seq_len; + const int64_t stride_cos_b = static_cast(seq_len) * half_dim * 2; + const int64_t stride_cos_s = half_dim * 2; + + for (int32_t b = 0; b < batch; ++b) { + const int64_t pos_base = static_cast(b) * stride_pos_b; + const int64_t out_base = static_cast(b) * stride_cos_b; + for (int32_t s = 0; s < seq_len; ++s) { + const float position = static_cast(pos_ptr[pos_base + s]); + float* cos_row = cos_ptr + out_base + static_cast(s) * stride_cos_s; + float* sin_row = sin_ptr + out_base + static_cast(s) * stride_cos_s; + for (int32_t d = 0; d < half_dim; ++d) { + const float freq = inv_ptr[d] * position; + const float c = std::cos(freq) * attention_scaling_; + const float ss = std::sin(freq) * attention_scaling_; + cos_row[d] = c; + cos_row[d + half_dim] = c; + sin_row[d] = ss; + sin_row[d + half_dim] = ss; + } + } + } + + return {cos, sin}; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + nn::Param inv_freq_; + float attention_scaling_ = 1.0f; +}; + +class RungeKutta4ODESolver { + public: + using Function = std::function; + + RungeKutta4ODESolver(Function function, Tensor initial_value) + : function_(std::move(function)), initial_value_(std::move(initial_value)) {} + + Tensor integrate(const std::vector& time_points) { + auto current_value = initial_value_; + if (time_points.size() < 2) { return current_value; } + + for (size_t i = 0; i + 1 < time_points.size(); ++i) { + const float time_start = time_points[i]; + const float time_end = time_points[i + 1]; + const float time_step = time_end - time_start; + + auto k1 = function_(time_start, current_value); + auto k2 = function_(time_start + time_step * 
one_third_, current_value + k1 * (time_step * one_third_)); + auto k3 = function_(time_start + time_step * two_thirds_, + current_value + (k2 - k1 * one_third_) * time_step); + auto k4 = function_(time_end, current_value + (k1 - k2 + k3) * time_step); + + auto delta = (k1 + (k2 + k3) * 3.0f + k4) * (time_step / 8.0f); + current_value = current_value + delta; + } + + return current_value; + } + + private: + Function function_; + Tensor initial_value_; + float one_third_ = 1.0f / 3.0f; + float two_thirds_ = 2.0f / 3.0f; +}; + +class Qwen2_5OmniToken2WavDiTModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavDiTModel() = default; + explicit Qwen2_5OmniToken2WavDiTModel(const std::string& name, const Qwen2_5OmniDiTConfig& cfg) : nn::Module(name), cfg_(cfg) { + mel_dim_ = cfg.mel_dim; + repeats_ = cfg.repeats; + block_size_ = cfg.block_size; + num_attention_heads_ = cfg.num_attention_heads; + + time_embed_ = reg("time_embed", cfg.hidden_size); + text_embed_ = reg("text_embed", cfg.num_embeds, cfg.emb_dim, cfg.repeats); + input_embed_ = reg("input_embed", cfg); + rotary_embed_ = reg("rotary_embed", cfg); + + for (int32_t i = 0; i < cfg.num_hidden_layers; ++i) { + const bool look_ahead = std::find(cfg.look_ahead_layers.begin(), cfg.look_ahead_layers.end(), i) != cfg.look_ahead_layers.end(); + const bool look_backward = + std::find(cfg.look_backward_layers.begin(), cfg.look_backward_layers.end(), i) != cfg.look_backward_layers.end(); + transformer_blocks_.emplace_back(reg("transformer_blocks." + std::to_string(i), cfg, look_ahead ? 1 : 0, + look_backward ? 
1 : 0)); + } + + norm_out_ = reg("norm_out", cfg.hidden_size); + proj_out_ = reg("proj_out", cfg.hidden_size, cfg.mel_dim, true); + } + + Tensor forward(const Tensor& hidden_states, const Tensor& condition_vector, const Tensor& speaker_embedding, const Tensor& quantized_code, + const Tensor& time_step, bool drop_audio_conditioning, bool drop_code, bool apply_cfg) { + Tensor timestep = time_step; + if (timestep.shape().empty()) { timestep = timestep.view({1}); } + if (timestep.shape().size() == 1 && timestep.shape()[0] == 1 && hidden_states.shape()[0] > 1) { + timestep = timestep.repeat(hidden_states.shape()[0], 0); + } + + auto time_embedding = time_embed_.forward(timestep); + auto text_embedding = text_embed_.forward(quantized_code, apply_cfg ? false : drop_code); + Tensor text_embedding_uncond = Tensor::nil(); + if (apply_cfg) { text_embedding_uncond = text_embed_.forward(quantized_code, true); } + + auto x = input_embed_.forward(hidden_states, speaker_embedding, condition_vector, text_embedding, drop_audio_conditioning, + text_embedding_uncond, apply_cfg); + + const int32_t seq_len = x.shape()[1]; + auto position_ids = Tensor::empty({x.shape()[0], seq_len}, kInt64, kCPU).alloc(); + auto* pos_ptr = position_ids.ptr(); + for (int32_t b = 0; b < position_ids.shape()[0]; ++b) { + for (int32_t s = 0; s < seq_len; ++s) { pos_ptr[b * seq_len + s] = s; } + } + + auto position_embeddings = rotary_embed_.forward(x, position_ids); + auto block_diff = makeBlockDiff(x.shape()[0], num_attention_heads_, seq_len, block_size_); + + for (auto& block : transformer_blocks_) { x = block.forward(x, time_embedding, position_embeddings, block_diff); } + + x = norm_out_.forward(x, time_embedding); + x = proj_out_(x); + return x; + } + + Tensor sample(const Tensor& conditioning_vector, const Tensor& reference_mel, const Tensor& quantized_code, int32_t num_steps, + float guidance_scale, float sway_coefficient) { + const int32_t max_duration = quantized_code.shape()[1] * repeats_; + auto 
initial_state = randomNormal({1, max_duration, mel_dim_}); + + const int32_t batch = reference_mel.shape()[0]; + if (batch != 1) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "Only batch size = 1 is supported for Qwen2.5-Omni token2wav."); } + + auto cond = Tensor(conditioning_vector); + cond = cond.view({batch, 1, conditioning_vector.shape()[1]}).repeat(max_duration, 1); + + auto ode_function = [&](float time_step, const Tensor& hidden) -> Tensor { + auto t = Tensor::empty({1}, kFloat32, kCPU).alloc(); + t.ptr()[0] = time_step; + + if (guidance_scale < 1e-5f) { + return forward(hidden, reference_mel, cond, quantized_code, t, false, false, false); + } + + auto model_output = forward(hidden, reference_mel, cond, quantized_code, t, false, false, true); + auto outputs = nn::functional::chunk<2>(model_output, 0); + return outputs[0] + (outputs[0] - outputs[1]) * guidance_scale; + }; + + auto time_points_tensor = linspace(0.0f, 1.0f, num_steps); + std::vector time_points(static_cast(num_steps)); + const auto* tp_ptr = time_points_tensor.ptr(); + for (int32_t i = 0; i < num_steps; ++i) { time_points[i] = tp_ptr[i]; } + + if (sway_coefficient != 0.0f) { + for (auto& t : time_points) { + t = t + sway_coefficient * (std::cos(kPi / 2.0f * t) - 1.0f + t); + } + } + + RungeKutta4ODESolver solver(ode_function, initial_state); + auto generated = solver.integrate(time_points); + auto mel = generated.permute({0, 2, 1}); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return mel; + } + + private: + Qwen2_5OmniDiTConfig cfg_; + int32_t mel_dim_ = 0; + int32_t repeats_ = 1; + int32_t block_size_ = 1; + int32_t num_attention_heads_ = 1; + + DiTTimestepEmbedding time_embed_; + DiTCodecEmbedding text_embed_; + DiTInputEmbedding input_embed_; + Qwen2_5OmniDiTRotaryEmbedding rotary_embed_; + std::vector transformer_blocks_; + Qwen2_5_OmniAdaLayerNormZero_Final norm_out_; + nn::Linear proj_out_; +}; + +class AMPBlock final : public nn::Module { + public: + AMPBlock() = default; + 
AMPBlock(const std::string& name, int32_t channels, int32_t kernel_size, const std::vector& dilations) + : nn::Module(name) { + if (dilations.size() != 3) { MLLM_ERROR_EXIT(ExitCode::kCoreError, "AMPBlock expects 3 dilation values."); } + + convs1_.emplace_back(reg("convs1.0", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[0]), + dilations[0], 1, true)); + convs1_.emplace_back(reg("convs1.1", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[1]), + dilations[1], 1, true)); + convs1_.emplace_back(reg("convs1.2", channels, channels, kernel_size, 1, getPadding(kernel_size, dilations[2]), + dilations[2], 1, true)); + + convs2_.emplace_back(reg("convs2.0", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.1", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + convs2_.emplace_back(reg("convs2.2", channels, channels, kernel_size, 1, getPadding(kernel_size, 1), 1, 1, true)); + + const int32_t num_layers = static_cast(convs1_.size() + convs2_.size()); + for (int32_t i = 0; i < num_layers; ++i) { + activations_.emplace_back(reg("activations." 
+ std::to_string(i), channels)); + } + } + + Tensor forward(const Tensor& hidden_states) { + auto out = hidden_states; + const int32_t num_blocks = static_cast(convs1_.size()); + for (int32_t i = 0; i < num_blocks; ++i) { + auto residual = out; + auto x = activations_[i * 2].forward({out}, {})[0]; + x = convs1_[i](x); + x = activations_[i * 2 + 1].forward({x}, {})[0]; + x = convs2_[i](x); + out = residual + x; + } + return out; + } + + private: + static int32_t getPadding(int32_t kernel_size, int32_t dilation) { + return static_cast((kernel_size * dilation - dilation) / 2); + } + + std::vector convs1_; + std::vector convs2_; + std::vector activations_; +}; + +class Qwen2_5OmniToken2WavBigVGANModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavBigVGANModel() = default; + explicit Qwen2_5OmniToken2WavBigVGANModel(const std::string& name, const Qwen2_5OmniBigVGANConfig& cfg) : nn::Module(name), cfg_(cfg) { + num_residual_blocks_ = static_cast(cfg.resblock_kernel_sizes.size()); + num_upsample_layers_ = static_cast(cfg.upsample_rates.size()); + + conv_pre_ = reg("conv_pre", cfg.mel_dim, cfg.upsample_initial_channel, 7, 1, 3, 1, 1, true); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t stride = cfg.upsample_rates[layer_idx]; + const int32_t kernel = cfg.upsample_kernel_sizes[layer_idx]; + const int32_t in_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx)); + const int32_t out_ch = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + const int32_t padding = (kernel - stride) / 2; + ups_.emplace_back(reg("ups." 
+ std::to_string(layer_idx) + ".0", in_ch, out_ch, kernel, stride, + padding, 0, 1, 1, true)); + } + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + const int32_t channels = cfg.upsample_initial_channel / static_cast(std::pow(2, layer_idx + 1)); + for (size_t i = 0; i < cfg.resblock_kernel_sizes.size(); ++i) { + resblocks_.emplace_back(reg("resblocks." + std::to_string(resblocks_.size()), channels, + cfg.resblock_kernel_sizes[i], cfg.resblock_dilation_sizes[i])); + } + } + + activation_post_ = + reg("activation_post", cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_))); + conv_post_ = reg("conv_post", + cfg.upsample_initial_channel / static_cast(std::pow(2, num_upsample_layers_)), 1, 7, 1, 3, 1, 1, + false); + } + + Tensor forward(const Tensor& mel_spectrogram) { + auto mel = mel_spectrogram; + if (!mel.isContiguous()) { mel = mel.contiguous(); } + auto processed = processMelSpectrogram(mel); + return forwardProcessed(processed); + } + + private: + Tensor forwardProcessed(const Tensor& processed) { + auto hidden = conv_pre_(processed); + + for (int32_t layer_idx = 0; layer_idx < num_upsample_layers_; ++layer_idx) { + hidden = ups_[layer_idx](hidden); + Tensor residual_sum = Tensor::zeros(hidden.shape(), hidden.dtype(), hidden.device()); + for (int32_t block_idx = 0; block_idx < num_residual_blocks_; ++block_idx) { + residual_sum = residual_sum + resblocks_[layer_idx * num_residual_blocks_ + block_idx].forward(hidden); + } + hidden = residual_sum * (1.0f / static_cast(num_residual_blocks_)); + } + + hidden = activation_post_.forward({hidden}, {})[0]; + auto output = conv_post_(hidden); + output = clampTensor(output, -1.0f, 1.0f); + return output.squeeze(); + } + Tensor processMelSpectrogram(const Tensor& mel_spectrogram) const { + auto amplitude = nn::functional::exp(mel_spectrogram); + auto decibel = amplitudeToDb(amplitude, -115.0f) + (-20.0f); + return normalizeSpectrogram(decibel, 1.0f, -115.0f); + } + 
+ Qwen2_5OmniBigVGANConfig cfg_; + int32_t num_residual_blocks_ = 0; + int32_t num_upsample_layers_ = 0; + nn::Conv1D conv_pre_; + std::vector ups_; + std::vector resblocks_; + TorchActivation1d activation_post_; + nn::Conv1D conv_post_; +}; + +class Qwen2_5OmniToken2WavModel final : public nn::Module { + public: + Qwen2_5OmniToken2WavModel() = default; + explicit Qwen2_5OmniToken2WavModel(const std::string& name, const Qwen2_5OmniToken2WavConfig& cfg) : nn::Module(name), cfg_(cfg) { + code2wav_dit_model_ = reg("code2wav_dit_model", cfg.dit_config); + code2wav_bigvgan_model_ = reg("code2wav_bigvgan_model", cfg.bigvgan_config); + } + + Tensor forward(const Tensor& code, const Tensor& conditioning, const Tensor& reference_mel, int32_t num_steps = 10, + float guidance_scale = 0.5f, float sway_coefficient = -1.0f) { + auto mel = code2wav_dit_model_.sample(conditioning, reference_mel, code, num_steps, guidance_scale, sway_coefficient); + if (!mel.isContiguous()) { mel = mel.contiguous(); } + return code2wav_bigvgan_model_.forward(mel); + } + + Tensor vocodeMel(const Tensor& mel) { + return code2wav_bigvgan_model_.forward(mel); + } + + private: + Qwen2_5OmniToken2WavConfig cfg_; + Qwen2_5OmniToken2WavDiTModel code2wav_dit_model_; + Qwen2_5OmniToken2WavBigVGANModel code2wav_bigvgan_model_; +}; + +} // namespace token2wav + +using token2wav::Qwen2_5OmniToken2WavBigVGANModel; +using token2wav::Qwen2_5OmniToken2WavDiTModel; +using token2wav::Qwen2_5OmniToken2WavModel; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp new file mode 100644 index 000000000..961b5c8f2 --- /dev/null +++ b/mllm/models/qwen2_5omni/tokenization_qwen2_5omni.hpp @@ -0,0 +1,385 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include +#include + +#include "mllm/preprocessor/tokenizers/BPE.hpp" +#include "mllm/preprocessor/tokenizers/Unicode.hpp" +#include "mllm/preprocessor/tokenizers/AutoTokenizer.hpp" +#include "mllm/models/ARGeneration.hpp" +#include "mllm/models/qwen2vl/image_preprocessor_qwen2vl.hpp" +#include "mllm/models/qwen2_5omni/audio_preprocessor_qwen2_5omni.hpp" +#include "mllm/utils/Common.hpp" + +namespace mllm::models::qwen2_5omni { + +// same regex as Qwen2/Qwen2-VL tokenizers +inline bool qwen2_5OmniTokenizerMatchPattern(const std::wstring& str, size_t& pos, std::wstring& matched) { + if (pos >= str.size()) return false; + + static const std::wstring contractions[] = {L"'s", L"'t", L"'re", L"'ve", L"'m", L"'ll", L"'d"}; + for (const auto& contraction : contractions) { + if (pos + contraction.size() <= str.size() && str.compare(pos, contraction.size(), contraction) == 0) { + matched = contraction; + pos += contraction.size(); + return true; + } + } + + { + size_t original_pos = pos; + bool has_prefix = false; + matched.clear(); + + if (!preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos]) && str[pos] != L'\r' && str[pos] != L'\n') { + matched += str[pos]; + ++pos; + has_prefix = true; + } + + if (pos < str.size() && preprocessor::isLetter(str[pos])) { + do { + matched += str[pos]; + ++pos; + } while (pos < str.size() && preprocessor::isLetter(str[pos])); + return true; + } else if (has_prefix) { + pos = original_pos; + matched.clear(); + } + } + + if (preprocessor::isDigit(str[pos])) { + matched = str.substr(pos, 1); + ++pos; + return true; + } + + { + size_t original_pos = pos; + matched.clear(); + size_t start = pos; + + if (str[pos] == L' ') { ++pos; } + + if (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) && !preprocessor::isDigit(str[pos])) { + do { + ++pos; + } while (pos < str.size() && !std::iswspace(str[pos]) && !preprocessor::isLetter(str[pos]) + && 
!preprocessor::isDigit(str[pos])); + + matched = str.substr(start, pos - start); + + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + matched += str[pos]; + ++pos; + } + return true; + } else { + pos = original_pos; + } + } + + { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) { + while (pos < str.size() && (str[pos] == L'\r' || str[pos] == L'\n')) ++pos; + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + if (pos >= str.size() || std::iswspace(str[pos])) { + matched = str.substr(start, pos - start); + return true; + } else { + pos = start; + } + } + + if (std::iswspace(str[pos])) { + size_t start = pos; + while (pos < str.size() && std::iswspace(str[pos])) ++pos; + matched = str.substr(start, pos - start); + return true; + } + + return false; +} + +inline bool qwen2_5OmniRegex(const std::string& str, std::vector& splitted) { + auto w_string = preprocessor::utf8string2WideString(str); + size_t pos = 0; + while (pos < w_string.size()) { + std::wstring matched; + if (qwen2_5OmniTokenizerMatchPattern(w_string, pos, matched)) { + splitted.push_back(matched); + } else { + ++pos; + } + } + return true; +} + +struct Qwen2_5OmniMessage { + std::string prompt; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniVisionMessage { + std::string prompt; + std::string img_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] 
std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|vision_bos|><|IMAGE|><|vision_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +struct Qwen2_5OmniAudioMessage { + std::string prompt; + std::string audio_file_path; + std::string system_prompt = "You are a helpful assistant."; + + [[nodiscard]] std::string buildChatMessage() const { + std::string result; + if (!system_prompt.empty()) { + result += "<|im_start|>system\n" + system_prompt + "<|im_end|>\n"; + } + result += "<|im_start|>user\n<|audio_bos|><|AUDIO|><|audio_eos|>" + prompt + "<|im_end|>\n"; + result += "<|im_start|>assistant\n"; + return result; + } +}; + +class Qwen2_5OmniTokenizer final : public mllm::preprocessor::AutoTokenizer { + public: + explicit Qwen2_5OmniTokenizer(const std::string& file_path, + int32_t spatial_merge_size = 2, + int32_t min_pixels = 56 * 56, + int32_t max_pixels = 1280 * 1280, + int32_t audio_sample_rate = 16000, + int32_t audio_n_mels = 128, + int32_t audio_hop_length = 160, + int32_t audio_chunk_length = 300) + //interestingly, the answer went bad when setting max_pixels higher, eg. 
3584*3584) + : image_preprocessor_(min_pixels, max_pixels), + audio_preprocessor_(audio_sample_rate, audio_n_mels, audio_hop_length, audio_chunk_length), + spatial_merge_size_(spatial_merge_size) { + preprocessor::initLocal(); + preprocessor::makeBytes2UnicodeMap(bytes_2_unicode_dict_); + for (auto& kv : bytes_2_unicode_dict_) { bytes_2_unicode_dict_inverse_.insert({kv.second, kv.first}); } + bpe_.initFromSentencePieceJson(file_path); + special_tokens_trie_.add(L"<|endoftext|>"); + special_tokens_trie_.add(L"<|im_start|>"); + special_tokens_trie_.add(L"<|im_end|>"); + special_tokens_trie_.add(L"<|object_ref_start|>"); + special_tokens_trie_.add(L"<|object_ref_end|>"); + special_tokens_trie_.add(L"<|box_start|>"); + special_tokens_trie_.add(L"<|box_end|>"); + special_tokens_trie_.add(L"<|quad_start|>"); + special_tokens_trie_.add(L"<|quad_end|>"); + special_tokens_trie_.add(L"<|vision_bos|>"); + special_tokens_trie_.add(L"<|vision_eos|>"); + special_tokens_trie_.add(L"<|vision_pad|>"); + special_tokens_trie_.add(L"<|image_pad|>"); + special_tokens_trie_.add(L"<|video_pad|>"); + special_tokens_trie_.add(L"<|AUDIO|>"); + special_tokens_trie_.add(L"<|audio_bos|>"); + special_tokens_trie_.add(L"<|audio_eos|>"); + special_tokens_trie_.add(L"<|IMAGE|>"); + special_tokens_trie_.add(L"<|VIDEO|>"); + } + + std::vector _tokenize(const std::string& str) override { + std::vector ret; + std::vector splitted; + ::mllm::models::qwen2_5omni::qwen2_5OmniRegex(str, splitted); + for (const auto& s : splitted) { + auto utf_8_str = preprocessor::wideString2Utf8String(s); + std::wstring mapped_str; + for (unsigned char c : utf_8_str) { mapped_str.push_back(bytes_2_unicode_dict_[c]); } + + auto bpe_ts = bpe_._bpe(mapped_str); + + for (const auto& bpe_t : bpe_ts) { ret.push_back(bpe_t); } + } + + return ret; + } + + std::vector tokenize(const std::string& str) override { + auto tokens = special_tokens_trie_.split(preprocessor::utf8string2WideString(str)); + std::vector all_tokens; + for 
(const auto& token : tokens) { + if (special_tokens_trie_.isSpecialToken(token)) { + all_tokens.emplace_back(token); + continue; + } + auto tmp_tokens = _tokenize(preprocessor::wideString2Utf8String(token)); + all_tokens.insert(all_tokens.end(), tmp_tokens.begin(), tmp_tokens.end()); + } + return all_tokens; + } + + std::wstring _detokenize(int64_t pos_idx) override { return bpe_._lookup_inverse_vocab(pos_idx); } + + std::wstring detokenize(int64_t pos_idx) override { + auto str = _detokenize(pos_idx); + std::string utf_8_str; + for (wchar_t c : str) { utf_8_str.push_back((unsigned char)(bytes_2_unicode_dict_inverse_[c])); } + return {mllm::preprocessor::utf8string2WideString(utf_8_str)}; + } + + Tensor convert2Ids(const std::vector& strs) override { + std::vector ids; + ids.reserve(strs.size()); + for (const auto& str : strs) { ids.emplace_back(bpe_._lookup_vocab(str)); } + Tensor ret = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kExtraInput) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = ret.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return ret; + } + + ARGenerationOutputPast convertMessage(const Qwen2_5OmniMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return {{"sequence", sequence}}; + } + + ARGenerationOutputPast convertVisionMessage(const Qwen2_5OmniVisionMessage& message) { + auto applied_string = message.buildChatMessage(); + + auto [img, grid_thw] = image_preprocessor_(message.img_file_path); + + auto sequence_str = 
tokenize(applied_string); + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto grid_t = grid_thw.ptr()[0]; + auto grid_h = grid_thw.ptr()[1]; + auto grid_w = grid_thw.ptr()[2]; + int32_t img_token_nums = grid_t * grid_h * grid_w; + img_token_nums /= (spatial_merge_size_ * spatial_merge_size_); + + auto image_token_id = bpe_._lookup_vocab(L"<|IMAGE|>"); + { + auto it = std::find(ids.begin(), ids.end(), image_token_id); + if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|IMAGE|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, img_token_nums - 1, image_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + return { + {"sequence", sequence}, + {"img", img}, + {"grid_thw", grid_thw}, + }; + } + + ARGenerationOutputPast convertAudioMessage(const Qwen2_5OmniAudioMessage& message) { + auto applied_string = message.buildChatMessage(); + auto sequence_str = tokenize(applied_string); + + std::vector ids; + ids.reserve(sequence_str.size()); + for (const auto& str : sequence_str) { ids.emplace_back(bpe_._lookup_vocab(str)); } + + auto audio_result = audio_preprocessor_.processAudioFile(message.audio_file_path); + if (audio_result.input_features.isNil() || audio_result.feature_length <= 0) { + MLLM_ERROR_EXIT(ExitCode::kIOError, "Failed to extract audio features for Qwen2.5-Omni."); + } + + int32_t audio_token_nums = audio_preprocessor_.calcAudioTokenLength(audio_result.feature_length); + if (audio_token_nums <= 0) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Invalid audio token length for Qwen2.5-Omni."); + } + + auto audio_token_id = bpe_._lookup_vocab(L"<|AUDIO|>"); + { + auto it = std::find(ids.begin(), ids.end(), audio_token_id); 
+ if (it == ids.end()) { + MLLM_ERROR_EXIT(ExitCode::kCoreError, "Missing <|AUDIO|> token in Qwen2.5-Omni prompt template."); + } + ids.insert(it + 1, audio_token_nums - 1, audio_token_id); + } + + Tensor sequence = Tensor::empty({1, static_cast(ids.size())}, kInt64, kCPU) + .setMemType(kNormal) + .setName("qwen2_5omni-tokenizer-i0") + .alloc(); + + auto ptr = sequence.ptr(); + for (size_t i = 0; i < ids.size(); ++i) { ptr[i] = ids[i]; } + + audio_result.input_features.setName("input_features"); + + return { + {"sequence", sequence}, + {"input_features", audio_result.input_features}, + }; + } + + private: + preprocessor::BPE bpe_; + std::unordered_map bytes_2_unicode_dict_; + std::unordered_map bytes_2_unicode_dict_inverse_; + mllm::models::qwen2vl::Qwen2VLImagePreprocessor image_preprocessor_; + Qwen2_5OmniAudioPreprocessor audio_preprocessor_; + int32_t spatial_merge_size_ = 2; +}; + +} // namespace mllm::models::qwen2_5omni diff --git a/mllm/nn/Nn.hpp b/mllm/nn/Nn.hpp index fdb0edc82..160e1cb43 100644 --- a/mllm/nn/Nn.hpp +++ b/mllm/nn/Nn.hpp @@ -11,6 +11,7 @@ #include "mllm/nn/layers/RMSNorm.hpp" // IWYU pragma: export #include "mllm/nn/layers/SiLU.hpp" // IWYU pragma: export #include "mllm/nn/layers/Sigmoid.hpp" // IWYU pragma: export +#include "mllm/nn/layers/Tanh.hpp" // IWYU pragma: export #include "mllm/nn/layers/Embedding.hpp" // IWYU pragma: export #include "mllm/nn/layers/GELU.hpp" // IWYU pragma: export #include "mllm/nn/layers/QuickGELU.hpp" // IWYU pragma: export @@ -26,6 +27,7 @@ #include "mllm/nn/layers/Param.hpp" // IWYU pragma: export #include "mllm/nn/layers/KVCache.hpp" // IWYU pragma: export #include "mllm/nn/layers/Conv1D.hpp" // IWYU pragma: export +#include "mllm/nn/layers/ConvTranspose1D.hpp" // IWYU pragma: export #include "mllm/nn/layers/AvgPool1d.hpp" // IWYU pragma: export #include "mllm/nn/layers/STFT.hpp" // IWYU pragma: export #include "mllm/nn/layers/PagedAttn.hpp" // IWYU pragma: export diff --git 
a/mllm/nn/layers/ConvTranspose1D.cpp b/mllm/nn/layers/ConvTranspose1D.cpp new file mode 100644 index 000000000..de2a7a5c7 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.cpp @@ -0,0 +1,32 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/ConvTranspose1D.hpp" + +namespace mllm::nn { + +ConvTranspose1D::ConvTranspose1D() : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{}) {} + +ConvTranspose1D::ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size, + int32_t padding, int32_t output_padding, int32_t dilation, int32_t groups, bool bias) + : Layer(OpTypes::kConvTranspose1D, aops::ConvTranspose1DOpOptions{.in_channels = in_channels, + .out_channels = out_channels, + .kernel_size = kernel_size, + .stride = stride_size, + .padding = padding, + .output_padding = output_padding, + .dilation = dilation, + .groups = groups, + .bias = bias}) {} + +ConvTranspose1D::ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options) : Layer(OpTypes::kConvTranspose1D, options) {} + +Tensor ConvTranspose1D::weight() const { + return std::static_pointer_cast(impl()->getInstancedOp())->weight(); +} + +Tensor ConvTranspose1D::bias() const { + return std::static_pointer_cast(impl()->getInstancedOp())->bias(); +} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/ConvTranspose1D.hpp b/mllm/nn/layers/ConvTranspose1D.hpp new file mode 100644 index 000000000..6ddc2fac3 --- /dev/null +++ b/mllm/nn/layers/ConvTranspose1D.hpp @@ -0,0 +1,29 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/ConvTranspose1DOp.hpp" + +namespace mllm::nn { + +class ConvTranspose1D : public Layer { + public: + ConvTranspose1D(); + + ConvTranspose1D(int32_t in_channels, int32_t out_channels, int32_t kernel_size, int32_t stride_size = 1, + int32_t padding = 0, int32_t output_padding = 0, int32_t dilation = 1, int32_t groups = 1, + bool bias = true); + + explicit ConvTranspose1D(const aops::ConvTranspose1DOpOptions& options); + + [[nodiscard]] Tensor weight() const; + + [[nodiscard]] Tensor bias() const; + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD +}; + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.cpp b/mllm/nn/layers/Tanh.cpp new file mode 100644 index 000000000..dda95f7ae --- /dev/null +++ b/mllm/nn/layers/Tanh.cpp @@ -0,0 +1,12 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. + +#include "mllm/nn/layers/Tanh.hpp" + +namespace mllm::nn { + +Tanh::Tanh() : Layer(OpTypes::kTanh, aops::TanhOpOptions{}) {} + +Tanh::Tanh(const aops::TanhOpOptions& options) : Layer(OpTypes::kTanh, options) {} + +} // namespace mllm::nn diff --git a/mllm/nn/layers/Tanh.hpp b/mllm/nn/layers/Tanh.hpp new file mode 100644 index 000000000..ab84e7eeb --- /dev/null +++ b/mllm/nn/layers/Tanh.hpp @@ -0,0 +1,21 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+ +#pragma once + +#include "mllm/nn/Layer.hpp" +#include "mllm/core/aops/TanhOp.hpp" + +namespace mllm::nn { + +class Tanh : public Layer { + public: + Tanh(); + + explicit Tanh(const aops::TanhOpOptions& options); + + MLLM_LAYER_ANY_INPUTS_1_OUTPUTS_FORWARD + MLLM_LAYER_ENABLE_INPLACE_ATTRIBUTE(Tanh) +}; + +} // namespace mllm::nn diff --git a/pymllm/README.md b/pymllm/README.md index e69de29bb..bee5ac41c 100644 --- a/pymllm/README.md +++ b/pymllm/README.md @@ -0,0 +1,3 @@ +# pymllm + +![pymllm-arch](../assets/pymllm-arch.png) diff --git a/pymllm/__init__.py b/pymllm/__init__.py index 1bd31cd6c..3f2488d27 100644 --- a/pymllm/__init__.py +++ b/pymllm/__init__.py @@ -2,48 +2,32 @@ # Licensed under the MIT License. from __future__ import annotations +import os +import sys -from . import ffi -from . import convertor -from . import utils -from . import quantize -from . import nn -from . import compile -from . import service -from . import backends -from .ffi import ( - # Floating point types - float32, - float16, - bfloat16, - # Signed integer types - int8, - int16, - int32, - int64, - # Unsigned integer types - uint8, - uint16, - uint32, - uint64, - # Bool type - boolean, - # Devices - cpu, - cuda, - qnn, - # Tensor and utilities - Tensor, - empty, - echo, - device, - is_torch_available, - is_numpy_available, - from_torch, - from_numpy, - zeros, - ones, - arange, - random, -) -from .nn.functional import matmul +__all__ = [] + + +def _has_mobile_libs() -> bool: + parent_dir = os.path.dirname(os.path.realpath(__file__)) + + # Platform-specific library names + if sys.platform.startswith("win32"): + lib_name = "MllmFFIExtension.dll" + elif sys.platform.startswith("darwin"): + lib_name = "MllmFFIExtension.dylib" + else: + lib_name = "MllmFFIExtension.so" + + lib_path = os.path.join(parent_dir, "lib", lib_name) + return os.path.exists(lib_path) + + +def is_mobile_available() -> bool: + return _has_mobile_libs() + + +if _has_mobile_libs(): + from . 
import mobile + + __all__.append("mobile") diff --git a/pymllm/__main__.py b/pymllm/__main__.py new file mode 100644 index 000000000..0b427fcee --- /dev/null +++ b/pymllm/__main__.py @@ -0,0 +1,39 @@ +def show_config() -> None: + from . import is_mobile_available + + mobile_enabled = str(is_mobile_available()).lower() + print(f"mllm mobile: {mobile_enabled}") + + # try import mllm_kernel, if true, print mllm_kernel config + try: + import mllm_kernel + + print(f"mllm_kernel: {mllm_kernel.__version__}") + except ImportError: + print("mllm_kernel: not found") + + +def main() -> None: + import argparse + + parser = argparse.ArgumentParser( + prog="pymllm", + description="pymllm helper commands.", + ) + parser.add_argument( + "command", + nargs="?", + choices=["show-config"], + help="Run helper command. Use 'show-config' to print config details.", + ) + args = parser.parse_args() + + if args.command == "show-config": + show_config() + return + + parser.print_help() + + +if __name__ == "__main__": + main() diff --git a/pymllm/backends/__init__.py b/pymllm/backends/__init__.py deleted file mode 100644 index 5e926d580..000000000 --- a/pymllm/backends/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) MLLM Team. -# Licensed under the MIT License. - -from . 
import cuda, qualcomm diff --git a/pymllm/backends/cuda/tilelang_compile_test.py b/pymllm/backends/cuda/tilelang_compile_test.py deleted file mode 100644 index 65a2e0071..000000000 --- a/pymllm/backends/cuda/tilelang_compile_test.py +++ /dev/null @@ -1,41 +0,0 @@ -import tilelang -import tilelang.language as T - - -@tilelang.jit( - out_idx=[-1], compile_flags=["-O3", "--use_fast_math", "--expt-relaxed-constexpr"] -) -def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): - @T.prim_func - def elem_add( - A: T.Tensor((M, N), in_dtype), - B: T.Tensor((M, N), in_dtype), - C: T.Tensor((M, N), out_dtype), - ): - with T.Kernel( - T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads - ) as (bx, by): - A_shared = T.alloc_shared((block_M, block_N), in_dtype) - B_shared = T.alloc_shared((block_M, block_N), in_dtype) - C_local = T.alloc_fragment((block_M, block_N), out_dtype) - C_shared = T.alloc_shared((block_M, block_N), out_dtype) - - T.copy(A[by * block_M, bx * block_N], A_shared) - T.copy(B[by * block_M, bx * block_N], B_shared) - for local_y, local_x in T.Parallel(block_M, block_N): - C_local[local_y, local_x] = ( - A_shared[local_y, local_x] + B_shared[local_y, local_x] - ) - T.copy(C_local, C_shared) - T.copy(C_shared, C[by * block_M, bx * block_N]) - - return elem_add - - -def compile_test(): - M = 1024 - N = 1024 - config = {"block_M": 128, "block_N": 128, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype="float16", out_dtype="float16") - source = kernel.get_kernel_source() - print(source) diff --git a/pymllm/configs/__init__.py b/pymllm/configs/__init__.py new file mode 100644 index 000000000..a23de035c --- /dev/null +++ b/pymllm/configs/__init__.py @@ -0,0 +1,14 @@ +"""Configuration module for pymllm.""" + +from pymllm.configs.global_config import GlobalConfig, get_global_config +from pymllm.configs.model_config import ModelConfig +from pymllm.configs.quantization_config import QuantizationConfig +from 
pymllm.configs.server_config import ServerConfig + +__all__ = [ + "GlobalConfig", + "get_global_config", + "ServerConfig", + "ModelConfig", + "QuantizationConfig", +] diff --git a/pymllm/configs/global_config.py b/pymllm/configs/global_config.py new file mode 100644 index 000000000..6ec68dda2 --- /dev/null +++ b/pymllm/configs/global_config.py @@ -0,0 +1,356 @@ +"""Global configuration singleton aggregating all sub-configs.""" + +from __future__ import annotations + +import argparse +import types +from dataclasses import MISSING, dataclass, field, fields +from pathlib import Path +from typing import ( + Any, + Callable, + Literal, + Optional, + Sequence, + Union, + get_args, + get_origin, + get_type_hints, +) + +from pymllm.configs.server_config import ServerConfig +from pymllm.configs.model_config import ModelConfig +from pymllm.configs.quantization_config import QuantizationConfig + + +@dataclass(init=False) +class GlobalConfig: + """Singleton that holds every sub-config pymllm needs. + + Usage:: + + from pymllm.configs import get_global_config + + cfg = get_global_config() + cfg.model.model_path + cfg.model.hidden_size + cfg.quantization.method + cfg.server.host + + .. note:: + + Always use :meth:`get_instance` (or the module-level + :func:`get_global_config` shortcut) to obtain the singleton. + ``GlobalConfig()`` is safe to call multiple times — the second and + subsequent calls return the existing instance without re-initialising + fields. + """ + + server: "ServerConfig" + model: ModelConfig + quantization: QuantizationConfig + _initialized: bool + + def __new__(cls): + if not hasattr(cls, "_instance") or cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + # Guard: skip re-initialisation on repeated GlobalConfig() calls. + # The dataclass auto-generated __init__ is disabled (init=False) so + # this custom __init__ has full control. 
+ if getattr(self, "_initialized", False): + return + self.server = ServerConfig(model_path=None) + self.model = ModelConfig() + self.quantization = QuantizationConfig() + self._initialized = True + + @classmethod + def get_instance(cls) -> "GlobalConfig": + if not hasattr(cls, "_instance") or cls._instance is None: + cls._instance = cls() + return cls._instance + + @classmethod + def reset(cls) -> None: + """Destroy the singleton (useful in tests).""" + cls._instance = None + + +def _parse_bool(value: Any) -> bool: + """Convert common CLI boolean spellings into ``bool``. + + This helper is intentionally permissive because CLI users often provide + booleans in different forms (for example ``true``, ``1``, ``yes``, + ``false``, ``0``, ``no``). The function raises ``argparse.ArgumentTypeError`` + to integrate naturally with ``argparse`` validation and error reporting. + """ + + if isinstance(value, bool): + return value + if value is None: + return True + + lowered = str(value).strip().lower() + if lowered in {"1", "true", "t", "yes", "y", "on"}: + return True + if lowered in {"0", "false", "f", "no", "n", "off"}: + return False + raise argparse.ArgumentTypeError( + f"Invalid boolean value: {value!r}. Expected one of true/false, 1/0, yes/no." + ) + + +def _unwrap_optional(annotation: Any) -> tuple[Any, bool]: + """Return ``(inner_type, is_optional)`` for Optional/Union annotations.""" + + origin = get_origin(annotation) + if origin not in (Union, types.UnionType): + return annotation, False + + args = [arg for arg in get_args(annotation) if arg is not type(None)] + if len(args) == 1 and len(get_args(annotation)) == 2: + return args[0], True + return annotation, False + + +def _converter_for_annotation(annotation: Any) -> Optional[Callable[[str], Any]]: + """Map a type annotation to an ``argparse`` converter. + + Only scalar, CLI-friendly annotations are supported. 
Complex runtime fields + (for example nested dict/object handles) are intentionally excluded from the + generated CLI surface to keep the interface predictable and safe. + """ + + inner, _ = _unwrap_optional(annotation) + origin = get_origin(inner) + if origin is not None: + if origin is Literal: + literal_values = get_args(inner) + if literal_values: + return type(literal_values[0]) + return str + return None + + if inner in (str, int, float): + return inner + if inner is Path: + return Path + return None + + +def _choices_for_annotation(annotation: Any) -> Optional[list]: + """Extract allowed values from a ``Literal`` annotation, if applicable.""" + + inner, _ = _unwrap_optional(annotation) + origin = get_origin(inner) + if origin is Literal: + return list(get_args(inner)) + return None + + +def _is_bool_annotation(annotation: Any) -> bool: + """Return ``True`` if annotation represents a bool/Optional[bool] field.""" + + inner, _ = _unwrap_optional(annotation) + return inner is bool + + +def _format_default_for_help(value: Any) -> str: + """Create a concise, readable default string for CLI help text.""" + + if value is MISSING: + return "" + if value is None: + return "None" + if isinstance(value, Path): + return str(value) + return repr(value) + + +def make_args( + parser: Optional[argparse.ArgumentParser] = None, +) -> argparse.ArgumentParser: + """Create an ``argparse`` parser with two-level GlobalConfig CLI options. + + The generated options follow the naming pattern ``--
.`` so + each sub-config can be configured independently: + + - ``server`` options map to :class:`ServerConfig` fields. + - ``model`` options map to :class:`ModelConfig` fields. + - ``quantization`` options map to :class:`QuantizationConfig` fields. + + Examples + -------- + - ``--server.host 0.0.0.0`` + - ``--server.port 8080`` + - ``--server.sleep_on_idle`` (implicit true) + - ``--server.sleep_on_idle false`` (explicit false) + - ``--quantization.method awq`` + + Design notes + ------------ + - Options are generated from dataclass metadata, which keeps the CLI surface + synchronized with config definitions and avoids manual drift. + - Parser defaults are suppressed (``argparse.SUPPRESS``), so ``read_args`` + can reliably detect whether a value was explicitly provided by the user. + - Only CLI-friendly scalar fields are exposed; runtime-only fields are + skipped automatically. + """ + + if parser is None: + parser = argparse.ArgumentParser( + prog="pymllm", + description="CLI options for configuring pymllm GlobalConfig.", + ) + + cfg = GlobalConfig.get_instance() + sections: list[tuple[str, Any]] = [ + ("server", cfg.server), + ("model", cfg.model), + ("quantization", cfg.quantization), + ] + + for section_name, section_obj in sections: + section_group = parser.add_argument_group( + f"{section_name} config", + f"Options for the '{section_name}' section of GlobalConfig.", + ) + type_hints = get_type_hints(type(section_obj)) + for dc_field in fields(section_obj): + if dc_field.name.startswith("_"): + continue + + annotation = type_hints.get(dc_field.name, dc_field.type) + option = f"--{section_name}.{dc_field.name}" + dest = f"{section_name}__{dc_field.name}" + default_value = getattr(section_obj, dc_field.name) + + if _is_bool_annotation(annotation): + section_group.add_argument( + option, + dest=dest, + nargs="?", + const=True, + type=_parse_bool, + default=argparse.SUPPRESS, + help=( + f"{section_name}.{dc_field.name} (bool, default: " + 
f"{_format_default_for_help(default_value)}). " + "Can be provided as a flag for true or with an explicit value." + ), + ) + continue + + converter = _converter_for_annotation(annotation) + if converter is None: + # Skip non-scalar or runtime-only fields (e.g. arbitrary objects). + continue + + choices = _choices_for_annotation(annotation) + kwargs: dict[str, Any] = dict( + dest=dest, + type=converter, + default=argparse.SUPPRESS, + ) + if choices is not None: + kwargs["choices"] = choices + choices_str = ", ".join(str(c) for c in choices) + kwargs["help"] = ( + f"{section_name}.{dc_field.name} " + f"{{choices: {choices_str}}} " + f"(default: {_format_default_for_help(default_value)})." + ) + else: + kwargs["help"] = ( + f"{section_name}.{dc_field.name} (default: " + f"{_format_default_for_help(default_value)})." + ) + + section_group.add_argument(option, **kwargs) + + return parser + + +def read_args( + argv: Optional[Sequence[str]] = None, + parser: Optional[argparse.ArgumentParser] = None, +) -> GlobalConfig: + """Parse CLI args and apply overrides to the singleton ``GlobalConfig``. + + Parameters + ---------- + argv + Optional argument vector. If ``None``, ``argparse`` reads from + ``sys.argv`` (standard CLI behavior). + parser + Optional parser to use. When omitted, this function builds one through + :func:`make_args`. + + Returns + ------- + GlobalConfig + The singleton config instance after CLI overrides have been applied. + + Behavior + -------- + 1. Parse all generated ``--section.field`` options. + 2. Apply only explicitly provided options (no accidental overwrite by parser + defaults). + 3. Rebuild ``ServerConfig`` when server fields change so validation in + ``ServerConfig.__post_init__`` and ``_validate`` remains enforced. + 4. Keep ``server.model_path`` and ``model.model_path`` aligned when only one + side is explicitly overridden (the same precedence used by runtime config + loading conventions). 
+ """ + + if parser is None: + parser = make_args() + + namespace = parser.parse_args(argv) + parsed = vars(namespace) + cfg = GlobalConfig.get_instance() + + # Server: reconstruct to preserve validation behavior. + from pymllm.configs.server_config import ServerConfig + + server_updates: dict[str, Any] = {} + for dc_field in fields(cfg.server): + key = f"server__{dc_field.name}" + if key in parsed: + server_updates[dc_field.name] = parsed[key] + if server_updates: + server_values = { + dc_field.name: getattr(cfg.server, dc_field.name) + for dc_field in fields(cfg.server) + } + server_values.update(server_updates) + cfg.server = ServerConfig(**server_values) + + # Model / Quantization: in-place updates are sufficient. + for section_name, section_obj in ( + ("model", cfg.model), + ("quantization", cfg.quantization), + ): + for dc_field in fields(section_obj): + key = f"{section_name}__{dc_field.name}" + if key in parsed: + setattr(section_obj, dc_field.name, parsed[key]) + + # Keep model path synchronized when only one side is explicitly overridden. 
+ server_model_overridden = "server__model_path" in parsed + model_model_overridden = "model__model_path" in parsed + if server_model_overridden and not model_model_overridden: + cfg.model.model_path = cfg.server.model_path + elif model_model_overridden and not server_model_overridden: + cfg.server.model_path = cfg.model.model_path + + cfg._initialized = True + return cfg + + +def get_global_config() -> GlobalConfig: + """Return the global config singleton.""" + return GlobalConfig.get_instance() diff --git a/pymllm/configs/model_config.py b/pymllm/configs/model_config.py new file mode 100644 index 000000000..c23dff1d9 --- /dev/null +++ b/pymllm/configs/model_config.py @@ -0,0 +1,31 @@ +"""Lightweight model configuration: path + HuggingFace config handle.""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Optional + + +@dataclass +class ModelConfig: + """Minimal model config wrapping a HuggingFace PretrainedConfig. + + Attributes on ``hf_config`` are flattened onto this object:: + + cfg = get_global_config().model + cfg.hidden_size # -> hf_config.hidden_size + cfg.vocab_size # -> hf_config.vocab_size + cfg.text_config # -> hf_config.text_config (multimodal) + """ + + # Populated at runtime via ``transformers.AutoConfig.from_pretrained`` + hf_config: Optional[Any] = field(default=None, repr=False) + + def __getattr__(self, name: str) -> Any: + hf = object.__getattribute__(self, "hf_config") + if hf is not None and hasattr(hf, name): + return getattr(hf, name) + raise AttributeError( + f"'{type(self).__name__}' has no attribute '{name}' " + f"(also not found on hf_config)" + ) diff --git a/pymllm/configs/quantization_config.py b/pymllm/configs/quantization_config.py new file mode 100644 index 000000000..850ea82b8 --- /dev/null +++ b/pymllm/configs/quantization_config.py @@ -0,0 +1,18 @@ +"""Quantization settings for model weights and KV cache.""" + +from __future__ import annotations + +from dataclasses 
import dataclass +from typing import Literal, Optional + + +@dataclass +class QuantizationConfig: + """Quantization configuration for weights and KV cache.""" + + # Weight quantization method (e.g. "awq", "gptq", "fp8", None for no quant) + method: Optional[str] = None + # KV cache data type override + kv_cache_dtype: Literal[ + "auto", "float16", "bfloat16", "fp8_e4m3", "fp8_e5m2" + ] = "auto" diff --git a/pymllm/configs/server_config.py b/pymllm/configs/server_config.py new file mode 100644 index 000000000..92d02e05e --- /dev/null +++ b/pymllm/configs/server_config.py @@ -0,0 +1,174 @@ +from __future__ import annotations + +from pathlib import Path +from typing import Any, Literal, Optional +from dataclasses import dataclass, field + + +@dataclass +class ServerConfig: + """Centralized runtime configuration for the MLLM server.""" + + # --------------------------------------------------------------------- # + # Model and tokenizer configuration + # --------------------------------------------------------------------- # + model_path: Optional[Path] = None + tokenizer_path: Optional[Path] = None + tokenizer_mode: Literal["auto", "slow", "fast"] = "auto" + load_format: Literal["auto", "safetensors"] = "auto" + trust_remote_code: bool = False + download_dir: Optional[Path] = None + context_length: Optional[int] = None + dtype: Literal["auto", "float16", "bfloat16", "float32"] = "auto" + + # --------------------------------------------------------------------- # + # HTTP / API server + # --------------------------------------------------------------------- # + host: str = "127.0.0.1" + port: int = 30000 + fastapi_root_path: str = "" + api_key: Optional[str] = None + admin_api_key: Optional[str] = None + served_model_name: Optional[str] = None + file_storage_path: Path = Path("mllm_storage") + cors_allow_origins: list[str] = field(default_factory=lambda: ["*"]) + + # --------------------------------------------------------------------- # + # Scheduling and memory + # 
--------------------------------------------------------------------- # + mem_fraction_static: Optional[float] = None + max_running_requests: Optional[int] = 1 + max_queued_requests: Optional[int] = None + max_total_tokens: Optional[int] = None + chunked_prefill_size: Optional[int] = None + max_prefill_tokens: Optional[int] = None + schedule_policy: Literal["auto", "fcfs"] = "fcfs" + schedule_conservativeness: float = 1.0 + sleep_on_idle: bool = False + stream_interval: int = 1 + stream_output: bool = True + + # --------------------------------------------------------------------- # + # Device + # --------------------------------------------------------------------- # + base_gpu_id: int = 0 + + # --------------------------------------------------------------------- # + # Backend / acceleration + # --------------------------------------------------------------------- # + attention_backend: Literal["auto", "flashinfer"] = "auto" + gdn_decode_backend: Literal["auto", "flashinfer", "mllm_kernel", "pytorch"] = "auto" + sampling_backend: Optional[str] = None + disable_cuda_graph: bool = False + enable_torch_compile: bool = False + torch_compile_max_bs: int = 32 + random_seed: Optional[int] = 42 + + # --------------------------------------------------------------------- # + # Output parsers (reasoning / tool calls) + # --------------------------------------------------------------------- # + reasoning_parser: Optional[str] = None # e.g. "deepseek-r1", "qwen3" + tool_call_parser: Optional[str] = None # e.g. 
"qwen25", "llama3", "hermes" + + # --------------------------------------------------------------------- # + # Logging and observability + # --------------------------------------------------------------------- # + log_level: Literal["debug", "info", "warning", "error", "critical"] = "info" + enable_metrics: bool = False + show_time_cost: bool = False + # Log prefill/decode throughput stats every N decode batches (0 = disabled) + decode_log_interval: int = 40 + + # --------------------------------------------------------------------- # + # Feature switches + # --------------------------------------------------------------------- # + enable_shared_queue: bool = False # Use shared memory queue for fast IPC + disable_radix_cache: bool = False # Disable radix-tree prefix caching (uses ChunkCache) + radix_cache_page_size: int = 1 # Number of tokens per KV-pool page in RadixCache + enable_mamba_cache: bool = False # Use MambaRadixCache for SSM state caching + + # CUDA IPC transport for multimodal GPU tensors. + # Requires enable_shared_queue=True to take effect. + # + # Three transport modes (mutually exclusive for GPU tensors): + # + # "default" + # GPU tensors are moved to CPU first (GPU→CPU copy), then placed in + # POSIX shared memory via share_memory_(). Safe but adds a device copy. + # + # "cuda_ipc" + # GPU tensors stay on GPU. Each tensor is wrapped in a + # TransportProxyTensor whose __getstate__ calls storage._share_cuda_() + # to obtain an IPC handle; the receiver reconstructs via + # UntypedStorage._new_shared_cuda(*handle). Simple, but the underlying + # GPU allocation is never freed until the sender process exits + # (PyTorch limitation) -- can leak GPU memory in long-running services. + # + # "cuda_ipc_pool" [recommended for production] + # GPU tensors are copied into a pre-allocated fixed-size GPU workspace + # (MmItemMemoryPool). Each outgoing tensor occupies a "chunk" of the + # pool; the chunk's IPC handle is sent via CudaIpcTensorTransportProxy. 
+ # After the receiver finishes copying data it increments a shared-memory + # sync flag; a background recycler thread in the sender watches these + # flags and returns chunks to the available pool. No GPU memory is leaked. + tensor_transport_mode: str = "default" # one of: default, cuda_ipc, cuda_ipc_pool + + # Size of the pre-allocated CUDA IPC memory pool in MB. + # Only used when tensor_transport_mode == "cuda_ipc_pool". + cuda_ipc_pool_size_mb: int = 512 + + # How often (seconds) the pool recycler thread wakes up. + cuda_ipc_recycle_interval: float = 0.1 + # enable_lora: bool = False + # max_loaded_loras: Optional[int] = None + # max_loras_per_batch: int = 8 + # lora_backend: Literal["triton", "csgmv", "torch_native"] = "csgmv" + # enable_multimodal: bool = False + # speculative_algorithm: Optional[str] = None + # speculative_draft_model_path: Optional[Path] = None + # speculative_num_steps: Optional[int] = None + # speculative_num_draft_tokens: Optional[int] = None + + # --------------------------------------------------------------------- # + # Extra + # --------------------------------------------------------------------- # + extra_options: dict[str, Any] = field(default_factory=dict) + + def __post_init__(self) -> None: + if self.tokenizer_path is None: + self.tokenizer_path = self.model_path + if self.served_model_name is None: + self.served_model_name = str(self.model_path) + self._validate() + + def _validate(self) -> None: + valid_modes = {"default", "cuda_ipc", "cuda_ipc_pool"} + if self.tensor_transport_mode not in valid_modes: + raise ValueError( + f"`tensor_transport_mode` must be one of {valid_modes}, " + f"got {self.tensor_transport_mode!r}." + ) + if self.tensor_transport_mode != "default" and not self.enable_shared_queue: + raise ValueError( + "`tensor_transport_mode` != 'default' requires `enable_shared_queue=True`." 
+ ) + if self.cuda_ipc_pool_size_mb <= 0: + raise ValueError("`cuda_ipc_pool_size_mb` must be > 0.") + if self.port <= 0 or self.port > 65535: + raise ValueError("`port` must be in range [1, 65535].") + if self.max_prefill_tokens is not None and self.max_prefill_tokens <= 0: + raise ValueError("`max_prefill_tokens` must be > 0.") + if self.stream_interval <= 0: + raise ValueError("`stream_interval` must be > 0.") + if self.mem_fraction_static is not None and not ( + 0.0 < self.mem_fraction_static < 1.0 + ): + raise ValueError("`mem_fraction_static` must be in (0.0, 1.0).") + if self.max_running_requests is not None and self.max_running_requests <= 0: + raise ValueError("`max_running_requests` must be > 0 when set.") + if self.max_queued_requests is not None and self.max_queued_requests < 0: + raise ValueError("`max_queued_requests` must be >= 0 when set.") + if self.radix_cache_page_size < 1: + raise ValueError("`radix_cache_page_size` must be >= 1.") + if self.schedule_conservativeness <= 0: + raise ValueError("`schedule_conservativeness` must be > 0.") diff --git a/pymllm/engine/__init__.py b/pymllm/engine/__init__.py new file mode 100644 index 000000000..50f2b7249 --- /dev/null +++ b/pymllm/engine/__init__.py @@ -0,0 +1,8 @@ +"""Engine module for pymllm.""" + +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode + +__all__ = [ + "ForwardBatch", + "ForwardMode", +] diff --git a/pymllm/engine/forward_batch.py b/pymllm/engine/forward_batch.py new file mode 100644 index 000000000..428da7b66 --- /dev/null +++ b/pymllm/engine/forward_batch.py @@ -0,0 +1,191 @@ +"""ForwardMode and ForwardBatch for pymllm. + +Simplified forward-batch abstraction: no speculative decoding, no +encoder-decoder support, and no distributed-attention complexity (DP/TP +head splitting is handled at the layer level by the model code, not here). 
+ +Typical data flow +----------------- + ModelRunner builds a ForwardBatch + ↓ + attn_backend.init_forward_metadata(forward_batch) + ↓ + model.forward(input_ids, positions, forward_batch) + ↓ + RadixAttention.forward(q, k, v, forward_batch) + ↓ + forward_batch.attn_backend.forward(q, k, v, layer, forward_batch) +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from enum import IntEnum, auto +from typing import TYPE_CHECKING, List, Optional + +import torch + +if TYPE_CHECKING: + from pymllm.layers.attention.attention_backend import AttentionBackend + from pymllm.mem_cache.memory_pool import KVPool, ReqToTokenPool + + +# --------------------------------------------------------------------------- +# ForwardMode +# --------------------------------------------------------------------------- + + +class ForwardMode(IntEnum): + """Describes what kind of forward pass is being performed. + + Covers standard prefill / decode inference without speculative decoding. + """ + + # Prefill / extend: process new tokens. The KV cache of the prefix (if + # any) is already populated (e.g. shared system-prompt via radix cache). + EXTEND = auto() + + # Decode: generate exactly one new token per sequence. + DECODE = auto() + + # Mixed: a chunked-prefill batch that contains both extend and decode + # sequences simultaneously. + MIXED = auto() + + # Idle: no sequences to process (used with data-parallel workers when some + # ranks have no allocated sequences). + IDLE = auto() + + # ---- helpers ---- + + def is_extend(self) -> bool: + """True for EXTEND or MIXED (i.e. 
any prefill-style pass).""" + return self in (ForwardMode.EXTEND, ForwardMode.MIXED) + + def is_prefill(self) -> bool: + """Alias for ``is_extend()``.""" + return self.is_extend() + + def is_decode(self) -> bool: + return self == ForwardMode.DECODE + + def is_mixed(self) -> bool: + return self == ForwardMode.MIXED + + def is_idle(self) -> bool: + return self == ForwardMode.IDLE + + def is_decode_or_idle(self) -> bool: + return self == ForwardMode.DECODE or self == ForwardMode.IDLE + + +# --------------------------------------------------------------------------- +# ForwardBatch +# --------------------------------------------------------------------------- + + +@dataclass +class ForwardBatch: + """All tensors required by a single forward pass through the model. + + Parameters + ---------- + forward_mode + The kind of pass being performed (EXTEND / DECODE / MIXED / IDLE). + batch_size + Number of sequences in the batch. + input_ids + Token ids for every position in the batch, shape ``[num_tokens]``. + For decode, ``num_tokens == batch_size``; for extend, + ``num_tokens == extend_num_tokens``. + req_pool_indices + Index of each sequence in ``ReqToTokenPool``, shape ``[batch_size]`` + (int32 or int64, on the target device). + seq_lens + Total (prefix + new) length of each sequence, shape ``[batch_size]`` + (int32). + out_cache_loc + KV-pool slot that each *output* token is written to, shape + ``[num_tokens]`` (int64). + seq_lens_sum + Python ``int`` equal to ``seq_lens.sum()``. Cached to avoid repeated + device-to-host syncs. + seq_lens_cpu + CPU copy of ``seq_lens`` (optional; used by some attention backends + for plan computation without a device sync). + positions + Token position for each input token, shape ``[num_tokens]`` + (int32 or int64). + extend_num_tokens + Total number of new (non-prefix) tokens across the batch. Only set + during EXTEND / MIXED passes. + extend_seq_lens + Number of *new* tokens for each sequence, shape ``[batch_size]`` + (int32). 
Only set during EXTEND / MIXED. + extend_prefix_lens + Length of the already-cached prefix for each sequence, + shape ``[batch_size]`` (int32). Only set during EXTEND / MIXED. + extend_start_loc + Cumulative start offset of each sequence in the flattened extend + token stream, shape ``[batch_size]`` (int32). + extend_prefix_lens_cpu + CPU list mirror of ``extend_prefix_lens``. + extend_seq_lens_cpu + CPU list mirror of ``extend_seq_lens``. + return_logprob + Whether to compute per-token log-probabilities. + top_logprobs_nums + Number of top log-probs to return per sequence (None or list of ints). + req_to_token_pool + Reference to the ``ReqToTokenPool`` (set by the model runner). + token_to_kv_pool + Reference to the ``KVPool`` (set by the model runner). + attn_backend + The attention backend to use (set by the model runner before calling + ``model.forward``). + """ + + # ---- required fields (positional) ---- + forward_mode: ForwardMode + batch_size: int + input_ids: torch.Tensor # [num_tokens] + req_pool_indices: torch.Tensor # [batch_size] int32/int64 + seq_lens: torch.Tensor # [batch_size] int32 + out_cache_loc: torch.Tensor # [num_tokens] int64 + seq_lens_sum: int # python int + + # ---- optional metadata ---- + + # CPU mirror of seq_lens + seq_lens_cpu: Optional[torch.Tensor] = None + + # Position encoding – shape [num_tokens], int32 or int64 + positions: Optional[torch.Tensor] = None + + # ---- extend / prefill specific ---- + extend_num_tokens: Optional[int] = None + extend_seq_lens: Optional[torch.Tensor] = None # [batch_size] int32 + extend_prefix_lens: Optional[torch.Tensor] = None # [batch_size] int32 + extend_start_loc: Optional[torch.Tensor] = None # [batch_size] int32 + extend_prefix_lens_cpu: Optional[List[int]] = None + extend_seq_lens_cpu: Optional[List[int]] = None + + # ---- logprob options ---- + return_logprob: bool = False + top_logprobs_nums: Optional[List[int]] = None + + # ---- memory pools (set by model runner) ---- + req_to_token_pool: 
Optional["ReqToTokenPool"] = None + token_to_kv_pool: Optional["KVPool"] = None + + # ---- attention backend (set by model runner) ---- + attn_backend: Optional["AttentionBackend"] = None + + # ---- multimodal M-RoPE ---- + # Per-request position delta for M-RoPE decode steps. + # Set by the model during prefill; consumed during decode to offset positions. + mrope_position_deltas: Optional[torch.Tensor] = None # [batch_size] int64 + + # ---- multimodal vision inputs (extend / prefill only) ---- + pixel_values: Optional[torch.Tensor] = None + image_grid_thw: Optional[torch.Tensor] = None diff --git a/pymllm/engine/io_struct.py b/pymllm/engine/io_struct.py new file mode 100644 index 000000000..06c8d78d6 --- /dev/null +++ b/pymllm/engine/io_struct.py @@ -0,0 +1,205 @@ +from __future__ import annotations + +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, Iterator, List, Optional, Union + + +@dataclass +class BaseReq: + rid: Optional[Union[str, List[str]]] = field(default=None, kw_only=True) + + def regenerate_rid(self) -> Union[str, List[str]]: + if isinstance(self.rid, list): + self.rid = [uuid.uuid4().hex for _ in range(len(self.rid))] + else: + self.rid = uuid.uuid4().hex + return self.rid + + +@dataclass +class BaseBatchReq: + rids: List[str] + + def regenerate_rids(self) -> List[str]: + self.rids = [uuid.uuid4().hex for _ in range(len(self.rids))] + return self.rids + + +@dataclass +class GenerateReqInput(BaseReq): + text: Optional[Union[List[str], str]] = None + input_ids: Optional[Union[List[List[int]], List[int]]] = None + sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + return_logprob: Optional[Union[List[bool], bool]] = None + logprob_start_len: Optional[Union[List[int], int]] = None + top_logprobs_num: Optional[Union[List[int], int]] = None + stream: bool = False + + # Multimodal placeholders. 
+ image_data: Optional[Any] = None + video_data: Optional[Any] = None + audio_data: Optional[Any] = None + + # Runtime extension placeholders. + lora_path: Optional[Union[List[Optional[str]], str]] = None + session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + extra_options: Dict[str, Any] = field(default_factory=dict) + + # Derived fields populated by normalization. + is_single: bool = field(default=True, init=False) + batch_size: int = field(default=1, init=False) + + def normalize_batch_and_arguments(self) -> None: + self._validate_inputs() + self._determine_batch_size() + + def _validate_inputs(self) -> None: + has_text = self.text is not None + has_input_ids = self.input_ids is not None + if has_text == has_input_ids: + raise ValueError("Exactly one of `text` or `input_ids` must be provided.") + + def _determine_batch_size(self) -> None: + if self.text is not None: + if isinstance(self.text, str): + self.is_single = True + self.batch_size = 1 + else: + if len(self.text) == 0: + raise ValueError("`text` cannot be an empty list.") + self.is_single = False + self.batch_size = len(self.text) + return + + assert self.input_ids is not None + if len(self.input_ids) == 0: + raise ValueError("`input_ids` cannot be empty.") + if isinstance(self.input_ids[0], int): + self.is_single = True + self.batch_size = 1 + else: + self.is_single = False + self.batch_size = len(self.input_ids) + + def __getitem__(self, i: int) -> "GenerateReqInput": + if i < 0 or i >= self.batch_size: + raise IndexError(f"index {i} out of range for batch size {self.batch_size}") + if self.batch_size == 1: + return self + return GenerateReqInput( + rid=self._pick(self.rid, i), + text=self._pick(self.text, i), + input_ids=self._pick(self.input_ids, i), + sampling_params=self._pick(self.sampling_params, i), + return_logprob=self._pick(self.return_logprob, i), + logprob_start_len=self._pick(self.logprob_start_len, i), + top_logprobs_num=self._pick(self.top_logprobs_num, i), + 
stream=self.stream, + image_data=self._pick(self.image_data, i), + video_data=self._pick(self.video_data, i), + audio_data=self._pick(self.audio_data, i), + lora_path=self._pick(self.lora_path, i), + session_params=self._pick(self.session_params, i), + extra_options=self.extra_options.copy(), + ) + + @staticmethod + def _pick(value: Any, i: int) -> Any: + if isinstance(value, list): + return value[i] + return value + + def to_request_dict(self) -> Dict[str, Any]: + payload: Dict[str, Any] = {} + for key, value in { + "rid": self.rid, + "text": self.text, + "input_ids": self.input_ids, + "sampling_params": self.sampling_params, + "return_logprob": self.return_logprob, + "logprob_start_len": self.logprob_start_len, + "top_logprobs_num": self.top_logprobs_num, + "stream": self.stream, + "image_data": self.image_data, + "video_data": self.video_data, + "audio_data": self.audio_data, + "lora_path": self.lora_path, + "session_params": self.session_params, + }.items(): + if value is not None: + payload[key] = value + payload.update(self.extra_options) + return payload + + +@dataclass +class TokenizedGenerateReqInput(BaseReq): + # The decoded text passed to the tokenizer (empty string if only input_ids + # were provided by the caller). + input_text: str = "" + # Token IDs produced by the tokenizer. + input_ids: List[int] = field(default_factory=list) + # Multimodal inputs (processor output, e.g. pixel_values, or raw image / + # audio / video data when no processor is available). ``None`` means the + # request is text-only. + mm_inputs: Optional[Dict[str, Any]] = None + # Raw sampling parameters dict (parsed into a SamplingParams object by the + # model runner when needed). 
+ sampling_params: Dict[str, Any] = field(default_factory=dict) + stream: bool = False + return_logprob: bool = False + logprob_start_len: int = -1 + top_logprobs_num: int = 0 + lora_path: Optional[str] = None + session_params: Optional[Dict[str, Any]] = None + + +@dataclass +class BatchTokenizedGenerateReqInput(BaseBatchReq): + reqs: List[TokenizedGenerateReqInput] + + def __len__(self) -> int: + return len(self.reqs) + + def __getitem__(self, i: int) -> TokenizedGenerateReqInput: + return self.reqs[i] + + def __iter__(self) -> Iterator[TokenizedGenerateReqInput]: + return iter(self.reqs) + + +@dataclass +class BatchTokenIDOutput(BaseBatchReq): + finished_reasons: List[Optional[str]] + decode_ids: List[int] + read_offsets: List[int] + output_ids: Optional[List[int]] + skip_special_tokens: List[bool] + prompt_tokens: List[int] + completion_tokens: List[int] + input_token_logprobs_val: List[float] = field(default_factory=list) + input_token_logprobs_idx: List[int] = field(default_factory=list) + output_token_logprobs_val: List[float] = field(default_factory=list) + output_token_logprobs_idx: List[int] = field(default_factory=list) + input_top_logprobs_val: List[List[float]] = field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = field(default_factory=list) + output_top_logprobs_val: List[List[float]] = field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = field(default_factory=list) + + +@dataclass +class BatchStrOutput(BaseBatchReq): + finished_reasons: List[Optional[str]] + output_strs: List[str] + output_ids: Optional[List[int]] + prompt_tokens: List[int] + completion_tokens: List[int] + input_token_logprobs_val: List[float] = field(default_factory=list) + input_token_logprobs_idx: List[int] = field(default_factory=list) + output_token_logprobs_val: List[float] = field(default_factory=list) + output_token_logprobs_idx: List[int] = field(default_factory=list) + input_top_logprobs_val: List[List[float]] = 
field(default_factory=list) + input_top_logprobs_idx: List[List[int]] = field(default_factory=list) + output_top_logprobs_val: List[List[float]] = field(default_factory=list) + output_top_logprobs_idx: List[List[int]] = field(default_factory=list) diff --git a/pymllm/engine/launch.py b/pymllm/engine/launch.py new file mode 100644 index 000000000..8fd39caab --- /dev/null +++ b/pymllm/engine/launch.py @@ -0,0 +1,658 @@ +import asyncio +import atexit +import logging +import os +import threading +import time +import uuid +from pathlib import Path +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import torch +import torch.multiprocessing as mp +from transformers import AutoConfig +from huggingface_hub import snapshot_download + +try: + from pyfiglet import figlet_format + from termcolor import colored + + HAS_BANNER_LIBS = True +except ImportError: + HAS_BANNER_LIBS = False + +from pymllm.configs import get_global_config +from pymllm.engine.io_struct import GenerateReqInput +from pymllm.orchestrator.ipc_utils import cleanup_ipc_files, make_ipc_address +from pymllm.orchestrator.request_response_process import ( + ReqState, + RequestResponseProcess, +) +from pymllm.orchestrator.tokenizer_process import run_tokenizer_process +from pymllm.orchestrator.scheduler_process import run_scheduler_process +from pymllm.orchestrator.detokenizer_process import run_detokenizer_process + +logger = logging.getLogger(__name__) + +# Standard HuggingFace config fields that indicate max output tokens, +# checked in priority order. 
+_MAX_NEW_TOKENS_FIELDS = ( + "max_new_tokens", + "max_tokens", + "max_completion_tokens", +) + + +def _normalize_eos_raw(raw) -> List[int]: + """Normalize a raw eos_token_id value (int, list, or None) to a list.""" + if raw is None: + return [] + if isinstance(raw, int): + return [raw] + if isinstance(raw, (list, tuple)): + return [x for x in raw if isinstance(x, int)] + return [] + + +def _get_eos_token_ids(hf_config, model_path=None) -> List[int]: + """Extract EOS token ID(s) from a HuggingFace model config. + + Searches in priority order: + 1. ``hf_config.eos_token_id`` (top-level, standard models) + 2. ``hf_config.text_config.eos_token_id`` (VL / multimodal models) + 3. ``generation_config.json`` (many models store EOS here) + 4. ``tokenizer_config.json`` via AutoTokenizer (last resort) + """ + if hf_config is None: + return [] + + # 1. Top-level config + ids = _normalize_eos_raw(getattr(hf_config, "eos_token_id", None)) + if ids: + return ids + + # 2. Nested text_config (VL / multimodal models like Qwen3-VL) + text_config = getattr(hf_config, "text_config", None) + if text_config is not None: + ids = _normalize_eos_raw(getattr(text_config, "eos_token_id", None)) + if ids: + return ids + + # 3. generation_config.json (lightweight, just reads a JSON file) + if model_path is not None: + try: + from transformers import GenerationConfig + + gen_cfg = GenerationConfig.from_pretrained(str(model_path)) + ids = _normalize_eos_raw(getattr(gen_cfg, "eos_token_id", None)) + if ids: + logger.info("EOS token IDs from generation_config.json: %s", ids) + return ids + except Exception: + pass + + # 4. 
Tokenizer (last resort) + if model_path is not None: + try: + from transformers import AutoTokenizer + + tok = AutoTokenizer.from_pretrained(str(model_path), trust_remote_code=True) + if tok.eos_token_id is not None: + ids = [tok.eos_token_id] + logger.info("EOS token ID from tokenizer: %s", ids) + return ids + except Exception: + pass + + return [] + + +def _get_model_default_max_new_tokens(hf_config) -> Optional[int]: + """Extract max output token limit from a HuggingFace model config. + + Checks standard fields in priority order. Returns ``None`` when the + config does not specify any recognised output-length field. + """ + if hf_config is None: + return None + for field_name in _MAX_NEW_TOKENS_FIELDS: + value = getattr(hf_config, field_name, None) + if value is not None and isinstance(value, int) and value > 0: + logger.info( + "Using model config %s=%d as default max_new_tokens", + field_name, + value, + ) + return value + return None + + +class Engine: + def __init__(self): + self._subprocesses: List[mp.Process] = [] + self._rr_process: Optional[RequestResponseProcess] = None + self._ipc_uid: Optional[str] = None + self._subprocess_healthy: bool = True + self._config_logging() + self._set_default_torch_dtype() + self._check_model_and_tokenizer() + + @property + def is_healthy(self) -> bool: + """True if engine and all subprocesses are alive.""" + return self._subprocess_healthy + + def launch(self) -> None: + self._launch_processes() + self._start_health_monitor() + atexit.register(self.shutdown) + + def _start_health_monitor(self) -> None: + """Start a daemon thread that checks subprocess liveness.""" + + def _monitor(): + while self._subprocess_healthy: + for proc in self._subprocesses: + if not proc.is_alive(): + logger.error( + "Subprocess pid=%s died unexpectedly (exitcode=%s)", + proc.pid, + proc.exitcode, + ) + self._subprocess_healthy = False + return + time.sleep(5) + + t = threading.Thread(target=_monitor, daemon=True, name="engine-health-monitor") 
+ t.start() + + def _launch_processes(self) -> None: + """Spawn all subprocess workers and wire up ZMQ IPC channels.""" + mp.set_start_method("spawn", force=True) + uid = str(os.getpid()) + self._ipc_uid = uid + + # IPC addresses for ZMQ communication between processes + addr_request_response_to_tokenizer: str = make_ipc_address( + "request_response_to_tokenizer", uid + ) + addr_tokenizer_to_scheduler: str = make_ipc_address( + "tokenizer_to_scheduler", uid + ) + addr_scheduler_to_detokenizer: str = make_ipc_address( + "scheduler_to_detokenizer", uid + ) + addr_detokenizer_to_request_response: str = make_ipc_address( + "detokenizer_to_request_response", uid + ) + # Record all subprocesses + procs_and_readers: List[tuple] = [] + + # Config dict for the tokenizer subprocess (must be picklable). + cfg = get_global_config() + enable_shared_queue = cfg.server.enable_shared_queue + transport_mode: str = ( + cfg.server.tensor_transport_mode + ) # "default" | "cuda_ipc" | "cuda_ipc_pool" + + # Create shared queue if enabled. + # Note: the MmItemMemoryPool (for "cuda_ipc_pool") is created *inside* + # the tokenizer subprocess after CUDA is initialised. The queue here + # is constructed without a pool; TokenizerProcess._ensure_pool() will + # swap in a pool-aware queue at runtime. + shared_queue = None + if enable_shared_queue: + from pymllm.orchestrator.shared_memory_queue import TensorQueue as _TQ + + # Construct with the configured transport mode. The pool is not + # supplied here; it will be lazily initialised inside the subprocess. 
        # Bounded shared-memory queue for fast tensor IPC between subprocesses.
        # NOTE(review): the guard/indentation above this point is outside this
        # chunk — presumably this runs only when enable_shared_queue is set;
        # confirm against the start of this method.
        shared_queue = _TQ(
            maxsize=1000,
            transport_mode=transport_mode,
            pool=None,  # pool initialised lazily inside TokenizerProcess
        )
        logger.info(
            "Shared memory queue enabled for fast IPC (transport_mode=%s)",
            transport_mode,
        )

        # Configuration dict handed to both the tokenizer and (below) the
        # detokenizer subprocess — the detokenizer reuses the same settings.
        tokenizer_cfg: Dict[str, Any] = {
            "tokenizer_path": str(cfg.server.tokenizer_path),
            "tokenizer_mode": cfg.server.tokenizer_mode,
            "trust_remote_code": cfg.server.trust_remote_code,
            "context_length": cfg.server.context_length,
            "hf_config": cfg.model.hf_config,
            "enable_shared_queue": enable_shared_queue,
            "tensor_transport_mode": transport_mode,
            "cuda_ipc_pool_size_mb": cfg.server.cuda_ipc_pool_size_mb,
            "cuda_ipc_recycle_interval": cfg.server.cuda_ipc_recycle_interval,
            "log_level": cfg.server.log_level,
        }

        # Tokenizer
        tokenizer_reader, tokenizer_writer = mp.Pipe(duplex=False)
        tokenizer_proc = mp.Process(
            target=run_tokenizer_process,
            args=(
                addr_request_response_to_tokenizer,
                addr_tokenizer_to_scheduler,
                tokenizer_writer,
                tokenizer_cfg,
                shared_queue,  # Pass shared queue
            ),
            daemon=True,
        )
        procs_and_readers.append((tokenizer_proc, tokenizer_reader, "tokenizer"))

        # Determine default max_new_tokens from model config (if available)
        model_max_new_tokens = _get_model_default_max_new_tokens(cfg.model.hf_config)
        scheduler_kwargs = {}
        if model_max_new_tokens is not None:
            scheduler_kwargs["default_max_new_tokens"] = model_max_new_tokens

        # Extract EOS token ID(s) from model config
        eos_token_ids = _get_eos_token_ids(
            cfg.model.hf_config, model_path=cfg.server.model_path
        )
        if eos_token_ids:
            scheduler_kwargs["eos_token_ids"] = eos_token_ids
            logger.info("EOS token IDs for scheduler: %s", eos_token_ids)

        # Model runner config — passed to the scheduler process which now
        # owns the model runner in-process.
        scheduler_kwargs["server_config"] = cfg.server
        scheduler_kwargs["model_config"] = cfg.model
        scheduler_kwargs["gpu_id"] = cfg.server.base_gpu_id

        # Scheduler (+ in-process model runner)
        scheduler_reader, scheduler_writer = mp.Pipe(duplex=False)
        scheduler_proc = mp.Process(
            target=run_scheduler_process,
            args=(
                addr_tokenizer_to_scheduler,
                addr_scheduler_to_detokenizer,
                scheduler_writer,
                shared_queue,  # Pass shared queue
                enable_shared_queue,  # Pass flag
                transport_mode,  # Pass tensor transport mode
                cfg.server.log_level,  # Pass log level
            ),
            kwargs=scheduler_kwargs,
            daemon=True,
        )
        procs_and_readers.append((scheduler_proc, scheduler_reader, "scheduler"))

        # Detokenizer
        detokenizer_reader, detokenizer_writer = mp.Pipe(duplex=False)
        detokenizer_proc = mp.Process(
            target=run_detokenizer_process,
            args=(
                addr_scheduler_to_detokenizer,
                addr_detokenizer_to_request_response,
                detokenizer_writer,
                tokenizer_cfg,
            ),
            daemon=True,
        )
        procs_and_readers.append((detokenizer_proc, detokenizer_reader, "detokenizer"))

        # Start all subprocesses
        for proc, _, name in procs_and_readers:
            proc.start()
            self._subprocesses.append(proc)
            logger.info("Started %s process (pid=%s)", name, proc.pid)

        # Wait for readiness signals. Each subprocess sends a single dict
        # through its Pipe once initialised; EOF means it died during init.
        for _, reader, name in procs_and_readers:
            try:
                msg = reader.recv()
            except EOFError:
                raise RuntimeError(f"{name} process died before signalling readiness")
            if msg.get("status") != "ready":
                raise RuntimeError(f"{name} process failed to initialise: {msg}")
            logger.info("%s process ready", name)

        # RR Process is current main process — only bind ZMQ sockets here.
        # Background tasks are started lazily by listen() on the first
        # add_request(), so they always run on the correct event loop.
        self._rr_process = RequestResponseProcess(
            send_to_tokenizer_addr=addr_request_response_to_tokenizer,
            recv_from_detokenizer_addr=addr_detokenizer_to_request_response,
        )
        self._rr_process.start()
        logger.info("RequestResponseProcess sockets bound")

        # Print colorful gradient ASCII art banner
        if HAS_BANNER_LIBS:
            try:
                text = figlet_format("pymllm", font="slant")
                fired_up = figlet_format("FIRED UP!", font="slant")

                # Apply blue-purple gradient
                lines = text.strip().split("\n")
                colors_cycle = ["blue", "cyan", "blue", "magenta", "magenta"]
                for i, line in enumerate(lines):
                    color = colors_cycle[i % len(colors_cycle)]
                    print(colored(line, color, attrs=["bold"]))

                # Print "FIRED UP!" in bright magenta
                for line in fired_up.strip().split("\n"):
                    print(colored(line, "magenta", attrs=["bold"]))
                print()
            except Exception as e:
                # Best-effort: banner failures must never block startup.
                logger.debug("Failed to print banner: %s", e)
                print("🚀 pymllm FIRED UP! 🚀\n")
        else:
            print("🚀 pymllm FIRED UP! 🚀\n")

    def generate(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        image_data: Optional[Any] = None,
        audio_data: Optional[Any] = None,
        video_data: Optional[Any] = None,
        return_logprob: Optional[Union[List[bool], bool]] = None,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[Union[List[Optional[str]], str]] = None,
        session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        stream: bool = False,
        rid: Optional[Union[List[str], str]] = None,
        **kwargs,
    ) -> Union[Dict[str, Any], List[Dict[str, Any]]]:
        """Synchronous, non-streaming generation entry point.

        Accepts a single prompt (``str``) or a batch (``List[str]``).
        Returns a
        single result dict for single inputs and a list of result dicts for batch
        inputs, preserving the input order.
        """
        rid = self._make_rids(rid, prompt, input_ids)
        request = GenerateReqInput(
            rid=rid,
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            stream=stream,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            lora_path=lora_path,
            session_params=session_params,
            extra_options=kwargs,
        )
        request.normalize_batch_and_arguments()

        async def _run() -> Union[Dict[str, Any], List[Dict[str, Any]]]:
            # add_request returns one ReqState for a single request or a
            # list of ReqStates for a batch (mirrors request.is_single).
            max_queued = get_global_config().server.max_queued_requests
            result = await self._rr_process.add_request(request, max_queued=max_queued)
            if request.is_single:
                single_rid = rid if isinstance(rid, str) else rid[0]
                return await self._wait_for_final_result(single_rid, result)  # type: ignore[arg-type]
            # Batch: wait for every sub-request concurrently.
            rids_list: List[str] = rid if isinstance(rid, list) else [rid]  # type: ignore[assignment]
            states: List[ReqState] = result  # type: ignore[assignment]
            outputs = await asyncio.gather(
                *(self._wait_for_final_result(r, s) for r, s in zip(rids_list, states))
            )
            return list(outputs)

        # A fresh event loop is created per call, so this sync method must
        # not be called from a thread that already runs an asyncio loop
        # (use generate_async there instead).
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(_run())
        finally:
            loop.close()

    async def generate_async(
        self,
        prompt: Optional[Union[List[str], str]] = None,
        sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        input_ids: Optional[Union[List[List[int]], List[int]]] = None,
        image_data: Optional[Any] = None,
        audio_data: Optional[Any] = None,
        video_data: Optional[Any] = None,
        return_logprob: Optional[Union[List[bool], bool]] = None,
        logprob_start_len: Optional[Union[List[int], int]] = None,
        top_logprobs_num: Optional[Union[List[int], int]] = None,
        lora_path: Optional[Union[List[Optional[str]], str]] = None,
        session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None,
        stream: bool = False,
        rid: Optional[Union[List[str], str]] = None,
        **kwargs,
    ) -> AsyncIterator[Dict[str, Any]]:
        """Asynchronous generation entry point.

        For a **single** request and ``stream=False`` yields one final result
        dict; with ``stream=True`` yields incremental chunks.

        For a **batch** request the iterator yields the final result for each
        sub-request as it completes (order not guaranteed); streaming mode yields
        incremental chunks from all sub-requests interleaved.
        """
        rid = self._make_rids(rid, prompt, input_ids)
        request = GenerateReqInput(
            rid=rid,
            text=prompt,
            input_ids=input_ids,
            sampling_params=sampling_params,
            return_logprob=return_logprob,
            logprob_start_len=logprob_start_len,
            top_logprobs_num=top_logprobs_num,
            stream=stream,
            image_data=image_data,
            audio_data=audio_data,
            video_data=video_data,
            lora_path=lora_path,
            session_params=session_params,
            extra_options=kwargs,
        )
        request.normalize_batch_and_arguments()
        max_queued = get_global_config().server.max_queued_requests
        result = await self._rr_process.add_request(request, max_queued=max_queued)

        if request.is_single:
            single_rid = rid if isinstance(rid, str) else rid[0]  # type: ignore[index]
            state: ReqState = result  # type: ignore[assignment]
            try:
                if stream:
                    async for chunk in self._stream_results(single_rid, state):
                        yield chunk
                else:
                    yield await self._wait_for_final_result(single_rid, state)
            finally:
                # Generator close (client disconnect) lands here: abort a
                # still-running request, otherwise just drop its state.
                if not state.finished:
                    logger.info("Aborting request %s (client disconnected)", single_rid)
                    await self._rr_process.abort_request(single_rid)
                else:
                    self._rr_process.remove_state(single_rid)
        else:
            rids_list: List[str] = rid if isinstance(rid, list) else [rid]  # type: ignore[assignment]
            states: List[ReqState] = result  # type: ignore[assignment]
            _bg_tasks: List[asyncio.Task] = []
            try:
                if stream:
                    # Merge streams from all sub-requests using an asyncio queue.
+ queue: asyncio.Queue = asyncio.Queue() + + async def _forward(r: str, s: ReqState) -> None: + async for chunk in self._stream_results(r, s): + await queue.put(chunk) + await queue.put(None) # sentinel + + _bg_tasks = [ + asyncio.create_task(_forward(r, s)) + for r, s in zip(rids_list, states) + ] + done_count = 0 + while done_count < len(_bg_tasks): + item = await queue.get() + if item is None: + done_count += 1 + else: + yield item + await asyncio.gather(*_bg_tasks) + else: + for coro in asyncio.as_completed( + [ + self._wait_for_final_result(r, s) + for r, s in zip(rids_list, states) + ] + ): + yield await coro + finally: + for t in _bg_tasks: + t.cancel() + for r, s in zip(rids_list, states): + if not s.finished: + logger.info("Aborting request %s (client disconnected)", r) + await self._rr_process.abort_request(r) + else: + self._rr_process.remove_state(r) + + @staticmethod + async def _wait_for_final_result(rid: str, state: ReqState) -> Dict[str, Any]: + """Block until the request is finished and return the last output.""" + while True: + await state.event.wait() + if state.finished: + return state.out_list[-1] + state.event.clear() + + @staticmethod + async def _stream_results( + rid: str, state: ReqState + ) -> AsyncIterator[Dict[str, Any]]: + """Yield incremental chunks as they arrive, until finished.""" + while True: + await state.event.wait() + for item in state.out_list: + yield item + state.out_list.clear() + if state.finished: + return + state.event.clear() + + @staticmethod + def _make_rids( + rid: Optional[Union[str, List[str]]], + prompt: Optional[Union[str, List[str]]], + input_ids: Optional[Union[List[int], List[List[int]]]], + ) -> Union[str, List[str]]: + """Return rids, auto-generating UUIDs when *rid* is ``None``. + + The helper infers whether the call is a batch from *prompt* / *input_ids* + so callers don't have to handle this case themselves. 
+ """ + if rid is not None: + return rid + # Determine batch size from the text/input_ids argument. + is_batch = isinstance(prompt, list) or ( + isinstance(input_ids, list) + and len(input_ids) > 0 + and isinstance(input_ids[0], list) + ) + if is_batch: + n = len(prompt) if prompt is not None else len(input_ids) # type: ignore[arg-type] + return [uuid.uuid4().hex for _ in range(n)] + return uuid.uuid4().hex + + def shutdown(self) -> None: + """Terminate all subprocesses.""" + if self._rr_process is not None: + try: + loop = asyncio.get_running_loop() + # Loop is running (e.g. called from uvicorn shutdown) — + # schedule cleanup as a fire-and-forget task. + loop.create_task(self._rr_process.shutdown()) + except RuntimeError: + # No running loop — create a temporary one for cleanup. + loop = asyncio.new_event_loop() + try: + loop.run_until_complete(self._rr_process.shutdown()) + finally: + loop.close() + for proc in self._subprocesses: + if proc.is_alive(): + proc.terminate() + proc.join(timeout=5) + if proc.is_alive(): + proc.kill() + self._subprocesses.clear() + # Clean up IPC socket files + if self._ipc_uid is not None: + cleanup_ipc_files(self._ipc_uid) + logger.info("All subprocesses shut down") + + def _set_default_torch_dtype(self): + """Set the default torch dtype based on the server configuration.""" + dtype = get_global_config().server.dtype + if dtype == "auto": + dtype = "bfloat16" if torch.cuda.is_available() else "float32" + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + torch_dtype = dtype_map.get(dtype) + if torch_dtype is None: + raise ValueError(f"Unsupported dtype for torch default dtype: {dtype!r}") + torch.set_default_dtype(torch_dtype) + + def _config_logging(self): + """Configure logging level from server configuration.""" + level_name = get_global_config().server.log_level.upper() + level = getattr(logging, level_name, logging.INFO) + root_logger = logging.getLogger() + if not 
root_logger.handlers: + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + else: + root_logger.setLevel(level) + logging.getLogger("pymllm").setLevel(level) + + def _check_model_and_tokenizer(self): + cfg = get_global_config() + if cfg.server.model_path is None or cfg.server.tokenizer_path is None: + logger.error("Model path or tokenizer path is not set") + raise ValueError("Model path or tokenizer path is not set") + model_path = cfg.server.model_path + tokenizer_path = cfg.server.tokenizer_path + download_dir = cfg.server.download_dir + trust_remote_code = cfg.server.trust_remote_code + + shared_path = model_path == tokenizer_path + + model_path = self._maybe_download(model_path, download_dir) + cfg.server.model_path = model_path + + if shared_path: + cfg.server.tokenizer_path = model_path + else: + cfg.server.tokenizer_path = self._maybe_download( + tokenizer_path, download_dir + ) + + cfg.model.hf_config = AutoConfig.from_pretrained( + str(model_path), + trust_remote_code=trust_remote_code, + ) + logger.info("Loaded model config: %s", cfg.model.hf_config.__class__.__name__) + + @staticmethod + def _maybe_download(path: Path, download_dir: Optional[Path] = None) -> Path: + if path.is_dir(): + return path + repo_id = str(path) + logger.info("Downloading '%s' ...", repo_id) + kwargs = {} + if download_dir is not None: + kwargs["local_dir"] = str(download_dir / path.name) + downloaded = snapshot_download(repo_id=repo_id, **kwargs) + logger.info("Downloaded '%s' to '%s'", repo_id, downloaded) + return Path(downloaded) diff --git a/pymllm/executor/__init__.py b/pymllm/executor/__init__.py new file mode 100644 index 000000000..b513b8705 --- /dev/null +++ b/pymllm/executor/__init__.py @@ -0,0 +1,10 @@ +"""Executor module: model loading, forward pass, and sampling.""" + +from pymllm.executor.cuda_graph_runner import CudaGraphRunner +from pymllm.executor.model_runner import LogitsProcessorOutput, ModelRunner + 
__all__ = [
    "CudaGraphRunner",
    "LogitsProcessorOutput",
    "ModelRunner",
]
diff --git a/pymllm/executor/cuda_graph_runner.py b/pymllm/executor/cuda_graph_runner.py
new file mode 100644
index 000000000..7fa674b7b
--- /dev/null
+++ b/pymllm/executor/cuda_graph_runner.py
@@ -0,0 +1,589 @@
"""CUDA-graph accelerated forward pass for decode steps.

Captures CUDA graphs for a set of discrete batch sizes so that the decode
forward pass can be replayed without CPU-side kernel-launch overhead.

``CudaGraphRunner`` for pymllm's single-GPU architecture. Handles:

* Pre-allocated input buffers (avoids per-step allocations)
* CUDA-graph capture for each batch size
* Optional ``torch.compile`` integration
* Graph replay with padding to the nearest captured batch size

Typical lifecycle::

    runner = CudaGraphRunner(model_runner)  # captures all batch sizes

    # --- inside the inference loop ---
    if runner.can_run(forward_batch):
        logits_output = runner.replay(forward_batch)
    else:
        logits_output = model_runner.forward(forward_batch)

Integration with :class:`~pymllm.executor.model_runner.ModelRunner`
-------------------------------------------------------------------
The ``ModelRunner`` owns the ``CudaGraphRunner`` and delegates decode
batches to it when the batch size is within the captured range. The
``CudaGraphRunner`` calls ``attn_backend.init_forward_metadata_*_cuda_graph``
directly (bypassing the normal ``init_forward_metadata`` path) so that
FlashInfer's per-batch planning is recorded inside the graph.
"""

from __future__ import annotations

import bisect
import gc
import logging
import time
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Union

import torch

from pymllm.engine.forward_batch import ForwardBatch, ForwardMode
from pymllm.executor.model_runner import LogitsProcessorOutput

if TYPE_CHECKING:
    from pymllm.executor.model_runner import ModelRunner
    from pymllm.layers.attention.attention_backend import AttentionBackend

logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global CUDA-graph memory pool (shared across all CudaGraphRunner instances)
# ---------------------------------------------------------------------------

# Opaque handle (from torch.cuda.graph_pool_handle()); sharing it lets all
# captured graphs reuse one GPU memory pool.
_global_graph_memory_pool: Optional[tuple] = None


def get_global_graph_memory_pool() -> Optional[tuple]:
    """Return the shared CUDA graph memory pool handle."""
    return _global_graph_memory_pool


def set_global_graph_memory_pool(pool: tuple) -> None:
    """Set the shared CUDA graph memory pool handle."""
    global _global_graph_memory_pool
    _global_graph_memory_pool = pool


# ---------------------------------------------------------------------------
# Context managers
# ---------------------------------------------------------------------------

# Flag indicating whether we are currently capturing a CUDA graph.
# NOTE(review): plain module-global flag — not thread-safe; presumably
# capture only ever happens from one thread. Confirm if that changes.
_is_capture_mode: bool = False


def is_capture_mode() -> bool:
    """Return ``True`` if a CUDA-graph capture is in progress."""
    return _is_capture_mode


@contextmanager
def model_capture_mode():
    """Context manager that sets the global capture-mode flag."""
    global _is_capture_mode
    _is_capture_mode = True
    try:
        yield
    finally:
        _is_capture_mode = False


@contextmanager
def freeze_gc():
    """Freeze the garbage collector during CUDA-graph capture.

    GC activity during capture can interfere with the recorded stream
    ordering. This context manager collects garbage before capture,
    freezes all surviving objects, and unfreezes + re-collects afterwards.
    """
    gc.collect()
    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
        gc.collect()


# ---------------------------------------------------------------------------
# Pre-allocated input buffers
# ---------------------------------------------------------------------------


@dataclass
class _InputBuffers:
    """Pre-allocated GPU tensors used as CUDA-graph inputs.

    During graph capture these buffers are used as-is. During replay the
    real batch data is copied into the first ``batch_size`` rows while the
    remaining padding rows retain their fill values.
    """

    input_ids: torch.Tensor  # [max_bs] int64
    req_pool_indices: torch.Tensor  # [max_bs] int32
    seq_lens: torch.Tensor  # [max_bs] int32
    seq_lens_cpu: torch.Tensor  # [max_bs] int32 (CPU)
    out_cache_loc: torch.Tensor  # [max_bs] int64
    positions: torch.Tensor  # [max_bs] int64
    mrope_position_deltas: torch.Tensor  # [max_bs] int64

    @classmethod
    def create(
        cls,
        *,
        device: torch.device,
        max_bs: int,
        seq_len_fill_value: int,
    ) -> "_InputBuffers":
        """Allocate all buffers for the given maximum batch size."""
        # torch.device as a context manager routes all allocations below to
        # the target device without repeating device= on every factory call.
        with torch.device(device):
            input_ids = torch.zeros((max_bs,), dtype=torch.int64)
            req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
            seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
            out_cache_loc = torch.zeros((max_bs,), dtype=torch.int64)
            positions = torch.zeros((max_bs,), dtype=torch.int64)
            mrope_position_deltas = torch.zeros((max_bs,), dtype=torch.int64)

        # seq_lens_cpu must be a real CPU tensor.
        seq_lens_cpu = torch.full(
            (max_bs,),
            seq_len_fill_value,
            dtype=torch.int32,
            device="cpu",
        )

        return cls(
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            seq_lens_cpu=seq_lens_cpu,
            out_cache_loc=out_cache_loc,
            positions=positions,
            mrope_position_deltas=mrope_position_deltas,
        )

    def populate(
        self,
        forward_batch: ForwardBatch,
        padded_bs: int,
        seq_len_fill_value: int,
    ) -> None:
        """Copy real batch data into the pre-allocated buffers.

        Any padding slots (``[real_bs : padded_bs]``) are filled with safe
        defaults so that the captured graph does not access invalid memory.
        """
        real_bs = forward_batch.batch_size

        # Reset padding slots when the padded size exceeds the real size.
        # NOTE(review): `positions` padding rows are intentionally not reset
        # here — they keep whatever a previous replay left; their outputs
        # are discarded after replay. Confirm this is safe for all models.
        if padded_bs != real_bs:
            self.seq_lens.fill_(seq_len_fill_value)
            self.out_cache_loc.zero_()
            self.mrope_position_deltas.zero_()

        self.input_ids[:real_bs].copy_(forward_batch.input_ids)
        self.req_pool_indices[:real_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:real_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:real_bs].copy_(forward_batch.out_cache_loc)
        self.positions[:real_bs].copy_(forward_batch.positions)

        # Copy M-RoPE position deltas (used by Qwen3-VL for multimodal).
        if forward_batch.mrope_position_deltas is not None:
            self.mrope_position_deltas[:real_bs].copy_(
                forward_batch.mrope_position_deltas
            )
        else:
            self.mrope_position_deltas[:real_bs].zero_()

        if forward_batch.seq_lens_cpu is not None:
            if padded_bs != real_bs:
                self.seq_lens_cpu.fill_(seq_len_fill_value)
            self.seq_lens_cpu[:real_bs].copy_(forward_batch.seq_lens_cpu)


# ---------------------------------------------------------------------------
# Batch-size schedule
# ---------------------------------------------------------------------------


def _default_capture_batch_sizes(max_bs: int) -> List[int]:
    """Return a list of batch sizes to capture.

    Uses the same schedule as sglang (non-speculative)::

        [1, 2, 4, 8, 12, 16, 24, 32, 40, …, 256, 272, 288, …, 512, 544, …]

    Capped at *max_bs*.
    """
    bs_list = (
        [1, 2, 4, 8, 12]
        + list(range(16, 257, 8))
        + list(range(272, 512, 16))
        + list(range(512, max_bs + 1, 32))
    )
    bs_list = sorted(set(bs for bs in bs_list if bs <= max_bs))
    # Guarantee at least one entry (e.g. max_bs < 1) so capture has work.
    if not bs_list:
        bs_list = [1]
    return bs_list


# ---------------------------------------------------------------------------
# CudaGraphRunner
# ---------------------------------------------------------------------------


class CudaGraphRunner:
    """Captures and replays CUDA graphs for decode-step forward passes.

    This class is the pymllm equivalent of sglang's ``CudaGraphRunner``,
    stripped of distributed, speculative-decoding, LoRA, mamba, TBO, and
    piecewise-graph complexities.

    Parameters
    ----------
    model_runner
        The owning :class:`~pymllm.executor.model_runner.ModelRunner`.
        Must have been fully initialised before the ``CudaGraphRunner``
        is constructed.
+ """ + + def __init__(self, model_runner: "ModelRunner"): + self.model_runner = model_runner + self.device = model_runner.device + + self.graphs: Dict[int, torch.cuda.CUDAGraph] = {} + self.output_buffers: Dict[int, LogitsProcessorOutput] = {} + + self.enable_torch_compile: bool = ( + model_runner.server_config.enable_torch_compile + ) + self.torch_compile_max_bs: int = model_runner.server_config.torch_compile_max_bs + + # ----------------------------------------------------------- + # Batch-size schedule + # ----------------------------------------------------------- + max_bs = model_runner.max_running_requests + self.capture_bs: List[int] = _default_capture_batch_sizes(max_bs) + self.compile_bs: List[int] = ( + [bs for bs in self.capture_bs if bs <= self.torch_compile_max_bs] + if self.enable_torch_compile + else [] + ) + self.max_bs: int = max(self.capture_bs) + + logger.info("CUDA graph capture batch sizes: %s", self.capture_bs) + + # ----------------------------------------------------------- + # Attention-backend CUDA-graph state + # ----------------------------------------------------------- + self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs, self.max_bs) + + # Fill value for padded seq_lens so attention kernels don't div-by-0. 
+ self.seq_len_fill_value: int = ( + self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value() + ) + + # ----------------------------------------------------------- + # Pre-allocated input buffers + # ----------------------------------------------------------- + self.buffers: _InputBuffers = _InputBuffers.create( + device=torch.device(self.device), + max_bs=self.max_bs, + seq_len_fill_value=self.seq_len_fill_value, + ) + + # ----------------------------------------------------------- + # Optional torch.compile config + # ----------------------------------------------------------- + if self.enable_torch_compile: + _set_torch_compile_config() + + # ----------------------------------------------------------- + # Capture all batch sizes + # ----------------------------------------------------------- + try: + with model_capture_mode(): + self.capture() + except RuntimeError as exc: + raise RuntimeError( + f"CUDA graph capture failed: {exc}\n" + "Possible fixes:\n" + " 1. Reduce --server.mem_fraction_static (e.g. 0.7)\n" + " 2. Reduce --server.max_running_requests\n" + " 3. Disable CUDA graph with --server.disable_cuda_graph\n" + ) from exc + + # ------------------------------------------------------------------ + # Capability check + # ------------------------------------------------------------------ + + def can_run(self, forward_batch: ForwardBatch) -> bool: + """Return ``True`` if the batch can be run via CUDA graph replay. + + The batch must be a decode (or idle) batch whose size does not + exceed the largest captured batch size. + """ + return ( + forward_batch.forward_mode.is_decode_or_idle() + and forward_batch.batch_size <= self.max_bs + ) + + # ------------------------------------------------------------------ + # Capture + # ------------------------------------------------------------------ + + def capture(self) -> None: + """Capture CUDA graphs for every batch size in ``capture_bs``. 
        Iterates in reverse order (largest first) so that the GPU memory
        pool allocated for the largest graph is reused by smaller ones.
        """
        tic = time.perf_counter()
        before_mem = _get_avail_mem(self.device)
        logger.info("CUDA graph capture begin. avail mem=%.2f GB", before_mem)

        with freeze_gc():
            # Warmup and capture run on a dedicated non-default stream.
            stream = torch.cuda.Stream()
            with torch.cuda.stream(stream):
                for bs in reversed(self.capture_bs):
                    forward_fn = self._get_forward_fn(bs)
                    graph, output = self._capture_one_batch_size(bs, forward_fn, stream)
                    self.graphs[bs] = graph
                    self.output_buffers[bs] = output

        after_mem = _get_avail_mem(self.device)
        logger.info(
            "CUDA graph capture end. elapsed=%.2f s, mem usage=%.2f GB, "
            "avail mem=%.2f GB",
            time.perf_counter() - tic,
            before_mem - after_mem,
            after_mem,
        )

    def _get_forward_fn(self, bs: int) -> Callable:
        """Return the forward callable for the given batch size.

        When ``torch.compile`` is enabled and *bs* is within the compile
        threshold, the model's forward method is wrapped with
        ``torch.compile``.
        """
        model_forward = self.model_runner.model.forward
        if self.enable_torch_compile and bs in self.compile_bs:
            # "no-cudagraphs" mode: we record the graph ourselves, so
            # inductor must not add its own CUDA-graph wrapping.
            return torch.compile(
                torch.no_grad()(model_forward),
                mode="max-autotune-no-cudagraphs",
            )
        return model_forward

    def _capture_one_batch_size(
        self,
        bs: int,
        forward: Callable,
        stream: torch.cuda.Stream,
    ) -> tuple:
        """Capture a single CUDA graph for batch size *bs*.

        Steps:
        1. Build a ``ForwardBatch`` from the pre-allocated buffers.
        2. Tell the attention backend to plan for CUDA-graph capture.
        3. Run the forward pass twice for warmup.
        4. Capture the third run into a ``CUDAGraph``.

        Returns ``(graph, output_buffers)``.
        """
        buffers = self.buffers

        # Slice pre-allocated buffers to the capture size.
        input_ids = buffers.input_ids[:bs]
        req_pool_indices = buffers.req_pool_indices[:bs]
        seq_lens = buffers.seq_lens[:bs]
        seq_lens_cpu = buffers.seq_lens_cpu[:bs]
        out_cache_loc = buffers.out_cache_loc[:bs]
        positions = buffers.positions[:bs]
        mrope_position_deltas = buffers.mrope_position_deltas[:bs]

        # Build ForwardBatch (DECODE mode).
        # mrope_position_deltas is set to the static buffer (initially zeros)
        # so that the graph captures the ``positions + deltas`` path. During
        # replay the buffer is updated with real delta values.
        forward_batch = ForwardBatch(
            forward_mode=ForwardMode.DECODE,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=int(seq_lens.sum().item()),
            seq_lens_cpu=seq_lens_cpu,
            positions=positions,
            return_logprob=False,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            mrope_position_deltas=mrope_position_deltas,
        )

        # Tell the attention backend to set up CUDA-graph-aware metadata.
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs=bs,
            num_tokens=bs,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            forward_mode=ForwardMode.DECODE,
        )

        # The single forward-pass function to be captured.
        def run_once():
            return forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
            )

        # Warmup (2 eager runs to stabilise cudnn / autotuner / etc.).
        for _ in range(2):
            torch.cuda.synchronize()
            run_once()

        # ----- Capture -----
        # Lazily create the shared memory pool on first capture so every
        # graph (across all runner instances) draws from the same pool.
        global _global_graph_memory_pool
        if _global_graph_memory_pool is None:
            _global_graph_memory_pool = torch.cuda.graph_pool_handle()

        graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(
            graph,
            pool=_global_graph_memory_pool,
            stream=stream,
        ):
            output = run_once()

        return graph, output

    # ------------------------------------------------------------------
    # Replay
    # ------------------------------------------------------------------

    def replay(
        self,
        forward_batch: ForwardBatch,
    ) -> LogitsProcessorOutput:
        """Replay a captured CUDA graph for the given decode batch.

        The batch is padded to the nearest captured size, inputs are copied
        into the pre-allocated buffers, the graph is replayed, and the
        output is sliced back to the real batch size.

        Parameters
        ----------
        forward_batch
            The decode batch from the scheduler.

        Returns
        -------
        LogitsProcessorOutput
            The logits for the real (un-padded) sequences.
        """
        real_bs = forward_batch.batch_size

        # Find the smallest captured bs >= real_bs. can_run() guarantees
        # real_bs <= max(capture_bs), so the index is always in range.
        idx = bisect.bisect_left(self.capture_bs, real_bs)
        padded_bs = self.capture_bs[idx]

        # Copy real data into the static buffers.
        self.buffers.populate(
            forward_batch,
            padded_bs=padded_bs,
            seq_len_fill_value=self.seq_len_fill_value,
        )

        # Update the attention backend for replay. Padding rows contribute
        # seq_len_fill_value each to the total.
        seq_lens_sum = (
            forward_batch.seq_lens_sum + (padded_bs - real_bs) * self.seq_len_fill_value
        )
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs=padded_bs,
            req_pool_indices=self.buffers.req_pool_indices[:padded_bs],
            seq_lens=self.buffers.seq_lens[:padded_bs],
            seq_lens_sum=seq_lens_sum,
            forward_mode=ForwardMode.DECODE,
            seq_lens_cpu=self.buffers.seq_lens_cpu[:padded_bs],
        )

        # Replay the graph.
        self.graphs[padded_bs].replay()

        # Retrieve output and slice to real batch size.
+ output = self.output_buffers[padded_bs] + + if isinstance(output, LogitsProcessorOutput): + return LogitsProcessorOutput( + next_token_logits=output.next_token_logits[:real_bs], + hidden_states=( + output.hidden_states[:real_bs] + if output.hidden_states is not None + else None + ), + ) + elif isinstance(output, torch.Tensor): + # Raw tensor output: assume [padded_bs, vocab_size]. + return LogitsProcessorOutput( + next_token_logits=output[:real_bs], + ) + else: + # HuggingFace-style output with .logits attribute. + if hasattr(output, "logits"): + logits = output.logits + if logits.dim() == 3: + return LogitsProcessorOutput( + next_token_logits=logits[:real_bs, -1, :], + ) + return LogitsProcessorOutput( + next_token_logits=logits[:real_bs], + ) + raise TypeError(f"Unexpected CUDA graph output type: {type(output)}") + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + """Release all captured CUDA graphs and associated buffers.""" + for graph in self.graphs.values(): + del graph + self.graphs.clear() + self.output_buffers.clear() + logger.info("CudaGraphRunner shutdown complete.") + + +# --------------------------------------------------------------------------- +# Utility helpers +# --------------------------------------------------------------------------- + + +def _get_avail_mem(device: str) -> float: + """Return available GPU memory in GB.""" + if device != "cuda" or not torch.cuda.is_available(): + return 0.0 + free, _ = torch.cuda.mem_get_info() + return free / (1 << 30) + + +def _set_torch_compile_config() -> None: + """Set dynamo / inductor configs for optimal CUDA-graph + compile.""" + try: + import torch._dynamo.config + import torch._inductor.config + + torch._inductor.config.coordinate_descent_tuning = True + torch._inductor.config.triton.unique_kernel_names = True + torch._inductor.config.fx_graph_cache = True + 
torch._dynamo.config.accumulated_cache_size_limit = 1024 + if hasattr(torch._dynamo.config, "cache_size_limit"): + torch._dynamo.config.cache_size_limit = 1024 + except ImportError: + logger.warning("torch._dynamo / torch._inductor not available.") diff --git a/pymllm/executor/model_runner.py b/pymllm/executor/model_runner.py new file mode 100644 index 000000000..2178afa99 --- /dev/null +++ b/pymllm/executor/model_runner.py @@ -0,0 +1,1452 @@ +"""ModelRunner runs the forward passes of the models. + +pymllm's single-GPU inference architecture. Handles: + +* Model loading (HuggingFace checkpoint via ``transformers``) +* KV-cache memory pool initialisation +* Attention backend setup (FlashInfer) +* Forward pass dispatch (extend / decode / idle) +* Token sampling from logits + +Typical lifecycle:: + + runner = ModelRunner(server_config, model_config) + runner.initialize() + + # --- inside the inference loop --- + forward_batch = runner.prepare_forward_batch_decode(...) + logits_output = runner.forward(forward_batch) + next_token_ids = runner.sample(logits_output, forward_batch) + +Typical data flow +----------------- + SchedulerProcess builds a batch dict + ↓ + ModelRunnerProcess calls ModelRunner.forward(forward_batch) + ↓ + attn_backend.init_forward_metadata(forward_batch) + ↓ + model.forward(input_ids, positions, forward_batch) + ↓ + ModelRunner.sample(logits_output, forward_batch) + ↓ + next_token_ids returned to scheduler +""" + +from __future__ import annotations + +import gc +import logging +import time +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn + +from pymllm.configs import get_global_config +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode +from pymllm.mem_cache.memory_pool import ( + GDNPool, + KVPool, + ReqToTokenPool, + TokenToKVPoolAllocator, + make_full_attention_net_mem_pool, + make_req_to_token_pool, +) + +if TYPE_CHECKING: + from 
pymllm.configs.model_config import ModelConfig + from pymllm.configs.server_config import ServerConfig + from pymllm.executor.cuda_graph_runner import CudaGraphRunner + from pymllm.layers.attention.attention_backend import AttentionBackend + +logger = logging.getLogger(__name__) + + +def _suppress_cpu_threads() -> None: + """Limit PyTorch intra-op threads to 1 for GPU inference.""" + torch.set_num_threads(1) + + +# --------------------------------------------------------------------------- +# Utility: GPU memory query +# --------------------------------------------------------------------------- + + +def get_available_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float: + """Return available GPU memory in GB.""" + if device != "cuda" or not torch.cuda.is_available(): + return 0.0 + torch.cuda.set_device(gpu_id) + free, _ = torch.cuda.mem_get_info(gpu_id) + return free / (1 << 30) + + +def get_total_gpu_memory(device: str = "cuda", gpu_id: int = 0) -> float: + """Return total GPU memory in GB.""" + if device != "cuda" or not torch.cuda.is_available(): + return 0.0 + torch.cuda.set_device(gpu_id) + _, total = torch.cuda.mem_get_info(gpu_id) + return total / (1 << 30) + + +# --------------------------------------------------------------------------- +# LogitsProcessorOutput +# --------------------------------------------------------------------------- + + +@dataclass +class LogitsProcessorOutput: + """Container for output logits produced by the model's forward pass. + + Attributes + ---------- + next_token_logits + Raw logits for the last token of each sequence in the batch, + shape ``[batch_size, vocab_size]``. + hidden_states + Optional hidden states from the model (e.g. for speculative decoding + or auxiliary loss computation). 
+ """ + + next_token_logits: torch.Tensor # [batch_size, vocab_size] + hidden_states: Optional[torch.Tensor] = None + + +# --------------------------------------------------------------------------- +# Penalty helpers +# --------------------------------------------------------------------------- + + +def _apply_penalties( + logits: torch.Tensor, + token_histories: List[List[int]], + repetition_penalties: torch.Tensor, + frequency_penalties: torch.Tensor, + presence_penalties: torch.Tensor, +) -> torch.Tensor: + """Apply repetition, frequency, and presence penalties to logits in-place. + + - **repetition_penalty** (multiplicative, default 1.0): + For each token that appeared in the history, if logit > 0 divide + by the penalty, else multiply by it. Values > 1 discourage repetition. + + - **frequency_penalty** (additive, default 0.0): + Subtract ``penalty * count(token)`` from the logit for each token + that appeared in the history. The more a token appears, the + stronger the penalty. + + - **presence_penalty** (additive, default 0.0): + Subtract ``penalty`` from the logit for each token that appeared + at least once in the history (binary, not count-based). 
+ + Parameters + ---------- + logits : [batch_size, vocab_size] + token_histories : list of list of int, length batch_size + repetition_penalties : [batch_size] + frequency_penalties : [batch_size] + presence_penalties : [batch_size] + """ + logits = logits.clone() + batch_size, vocab_size = logits.shape + device = logits.device + + for i in range(batch_size): + history = token_histories[i] + if not history: + continue + + rep_p = repetition_penalties[i].item() + freq_p = frequency_penalties[i].item() + pres_p = presence_penalties[i].item() + + # Skip if all penalties are neutral + if rep_p == 1.0 and freq_p == 0.0 and pres_p == 0.0: + continue + + # Count token occurrences + token_counts: Dict[int, int] = {} + for t in history: + if 0 <= t < vocab_size: + token_counts[t] = token_counts.get(t, 0) + 1 + + if not token_counts: + continue + + token_ids = list(token_counts.keys()) + token_ids_t = torch.tensor(token_ids, dtype=torch.long, device=device) + selected_logits = logits[i, token_ids_t] + + # Repetition penalty (multiplicative) + if rep_p != 1.0: + selected_logits = torch.where( + selected_logits > 0, + selected_logits / rep_p, + selected_logits * rep_p, + ) + + # Frequency penalty (additive, proportional to count) + if freq_p != 0.0: + counts = torch.tensor( + [token_counts[t] for t in token_ids], + dtype=torch.float32, + device=device, + ) + selected_logits = selected_logits - freq_p * counts + + # Presence penalty (additive, binary) + if pres_p != 0.0: + selected_logits = selected_logits - pres_p + + logits[i, token_ids_t] = selected_logits + + return logits + + +# --------------------------------------------------------------------------- +# ModelRunner +# --------------------------------------------------------------------------- + + +class ModelRunner: + """Runs the forward passes of the models. + + This is the core execution component that owns the model, memory pools, + and attention backend. 
It is used by + :class:`~pymllm.orchestrator.model_runner_process.ModelRunnerProcess` to + execute batches dispatched by the scheduler. + + Parameters + ---------- + server_config + Server runtime configuration. Falls back to the global singleton + when ``None``. + model_config + Model configuration (wraps a HuggingFace ``PretrainedConfig``). + Falls back to the global singleton when ``None``. + gpu_id + GPU device index to use. + """ + + def __init__( + self, + server_config: Optional["ServerConfig"] = None, + model_config: Optional["ModelConfig"] = None, + gpu_id: int = 0, + ): + cfg = get_global_config() + self.server_config = server_config or cfg.server + self.model_config = model_config or cfg.model + + self.gpu_id = gpu_id + self.device: str = "cuda" if torch.cuda.is_available() else "cpu" + self.dtype: torch.dtype = self._resolve_dtype() + + # Set by initialize() + self.model: Optional[nn.Module] = None + self.req_to_token_pool: Optional[ReqToTokenPool] = None + self.token_to_kv_pool: Optional[KVPool] = None + self.token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None + self.gdn_pool: Optional[GDNPool] = None + self.attn_backend: Optional["AttentionBackend"] = None + self.graph_runner: Optional["CudaGraphRunner"] = None + + # Memory configuration + self.max_total_num_tokens: int = 0 + self.max_running_requests: int = 0 + + # Model metadata (populated after loading) + self.num_hidden_layers: int = 0 + self.num_attention_heads: int = 0 + self.num_kv_heads: int = 0 + self.head_dim: int = 0 + self.hidden_size: int = 0 + self.vocab_size: int = 0 + self.context_len: int = 0 + + # KV cache dtype -- same as model dtype by default; may differ for + # quantised KV caches in the future. + self.kv_cache_dtype: torch.dtype = self.dtype + + # Forward pass counter (monotonically increasing). 
+ self.forward_pass_id: int = 0 + + # ------------------------------------------------------------------ + # Initialisation + # ------------------------------------------------------------------ + + def initialize(self) -> None: + """Full initialisation: set device, load model, init memory + backend. + + Call this once before any forward pass. + """ + tic = time.perf_counter() + logger.info("ModelRunner initialisation begin.") + + # Set device + if self.device == "cuda": + torch.cuda.set_device(self.gpu_id) + + # Limit PyTorch CPU threads to 1 for GPU inference. + # PyTorch's default (= CPU core count) causes OpenMP thread pool + # spin-wait that wastes CPU. GPU models don't benefit from CPU + # parallelism. + if self.device != "cpu": + _suppress_cpu_threads() + + # Set default dtype + torch.set_default_dtype(self.dtype) + + # Load the model + self.load_model() + + # Extract model metadata from hf_config + self._extract_model_metadata() + + # Resolve KV-cache dtype + self._configure_kv_cache_dtype() + + # Initialise memory pools + self.init_memory_pool() + + # Initialise attention backend + self.init_attention_backend() + + # Warm up cuBLAS + if self.device == "cuda": + self._init_cublas() + + # Capture CUDA graphs (must be after model + pools + backend) + self.init_cuda_graphs() + + elapsed = time.perf_counter() - tic + logger.info( + "ModelRunner initialisation complete. 
elapsed=%.2f s, " + "device=%s, dtype=%s, kv_dtype=%s, max_tokens=%d, max_reqs=%d", + elapsed, + self.device, + self.dtype, + self.kv_cache_dtype, + self.max_total_num_tokens, + self.max_running_requests, + ) + + # ------------------------------------------------------------------ + # Dtype resolution + # ------------------------------------------------------------------ + + def _resolve_dtype(self) -> torch.dtype: + """Resolve the model dtype from configuration.""" + dtype_str = self.server_config.dtype + if dtype_str == "auto": + if torch.cuda.is_available(): + if torch.cuda.get_device_capability()[0] >= 8: + return torch.bfloat16 + return torch.float16 + return torch.float32 + dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "float32": torch.float32, + } + result = dtype_map.get(dtype_str) + if result is None: + raise ValueError(f"Unsupported dtype: {dtype_str!r}") + return result + + def _configure_kv_cache_dtype(self) -> None: + """Determine the dtype used for KV-cache storage. + + The global ``QuantizationConfig.kv_cache_dtype`` can override the + model dtype (e.g. ``fp8_e4m3`` for quantised KV caches). When set + to ``"auto"`` the model dtype is used as-is. 
+ """ + cfg = get_global_config() + kv_dtype_str = cfg.quantization.kv_cache_dtype + + if kv_dtype_str == "auto": + self.kv_cache_dtype = self.dtype + return + + kv_dtype_map = { + "float16": torch.float16, + "bfloat16": torch.bfloat16, + "fp8_e4m3": torch.float8_e4m3fn, + "fp8_e5m2": torch.float8_e5m2, + } + resolved = kv_dtype_map.get(kv_dtype_str) + if resolved is None: + logger.warning( + "Unrecognised kv_cache_dtype %r, falling back to model dtype.", + kv_dtype_str, + ) + self.kv_cache_dtype = self.dtype + else: + self.kv_cache_dtype = resolved + + logger.info("KV-cache dtype: %s", self.kv_cache_dtype) + + # ------------------------------------------------------------------ + # Model metadata + # ------------------------------------------------------------------ + + def _extract_model_metadata(self) -> None: + """Extract key model parameters from the HuggingFace config.""" + hf_config = self.model_config.hf_config + if hf_config is None: + raise RuntimeError( + "HuggingFace config not loaded. " + "Make sure model_config.hf_config is set before calling " + "initialize()." + ) + + # Handle text_config for multimodal models + text_config = getattr(hf_config, "text_config", hf_config) + + self.num_hidden_layers = getattr(text_config, "num_hidden_layers", 0) + self.num_attention_heads = getattr(text_config, "num_attention_heads", 0) + self.num_kv_heads = getattr( + text_config, + "num_key_value_heads", + self.num_attention_heads, + ) + self.head_dim = getattr( + text_config, + "head_dim", + getattr(text_config, "hidden_size", 0) // max(self.num_attention_heads, 1), + ) + self.hidden_size = getattr(text_config, "hidden_size", 0) + self.vocab_size = getattr(text_config, "vocab_size", 0) + + # V-head dim may differ from K-head dim (e.g. 
MLA) + self.v_head_dim: int = getattr(text_config, "v_head_dim", self.head_dim) + + # Context length + self.context_len = self.server_config.context_length or getattr( + text_config, "max_position_embeddings", 4096 + ) + + # Hybrid model metadata (GDN layers) + self.num_gdn_layers: int = getattr(self.model, "num_gdn_layers", 0) + self.full_attn_layer_ids: set = getattr( + self.model, "full_attn_layer_ids", set() + ) + + logger.info( + "Model metadata: layers=%d, q_heads=%d, kv_heads=%d, " + "head_dim=%d, v_head_dim=%d, hidden=%d, vocab=%d, ctx_len=%d" + + (", gdn_layers=%d" if self.num_gdn_layers > 0 else ""), + self.num_hidden_layers, + self.num_attention_heads, + self.num_kv_heads, + self.head_dim, + self.v_head_dim, + self.hidden_size, + self.vocab_size, + self.context_len, + *([self.num_gdn_layers] if self.num_gdn_layers > 0 else []), + ) + + # ------------------------------------------------------------------ + # Quantization config resolution + # ------------------------------------------------------------------ + + @staticmethod + def _load_quant_config_dict(model_path: str) -> dict: + """Probe checkpoint directory for quantization metadata. + + Checks files listed by each registered ``QuantizationConfig`` + (e.g. ``quantize_config.json``), then falls back to the + ``quantization_config`` section of ``config.json``. + + Returns an empty dict when no quantization metadata is found. 
+ """ + import json + from pathlib import Path + + from pymllm.quantization import QuantizationConfig + + model_path = Path(model_path) + + # Collect candidate filenames from all registered config classes + filenames: list[str] = [] + for subcls in QuantizationConfig.__subclasses__(): + filenames.extend(subcls.get_config_filenames()) + # Deduplicate while preserving order + seen: set[str] = set() + unique: list[str] = [] + for f in filenames: + if f not in seen: + seen.add(f) + unique.append(f) + + for fname in unique: + fpath = model_path / fname + if fpath.exists(): + with open(fpath) as fp: + return json.load(fp) + + # Fallback: config.json → quantization_config section + config_path = model_path / "config.json" + if config_path.exists(): + with open(config_path) as fp: + cfg = json.load(fp) + if "quantization_config" in cfg: + return cfg["quantization_config"] + + return {} + + def _resolve_quant_config(self): + """Resolve the quantization configuration for this model. + + Priority: + 1. CLI value from ``GlobalConfig.quantization.method`` + 2. Auto-detect from checkpoint's ``quantize_config.json`` + or ``config.json`` → ``quantization_config.quant_method`` + 3. Auto-upgrade ``"awq"`` → ``"awq_marlin"`` on SM80+ GPUs + + Returns ``None`` when quantization is not requested / detected. 
+ """ + from pymllm.quantization import get_quantization_config + + global_cfg = get_global_config() + method = global_cfg.quantization.method + model_path = self.server_config.model_path + + config_dict = self._load_quant_config_dict(model_path) + + # Auto-detect from checkpoint if CLI didn't specify a method + if method is None and config_dict: + method = config_dict.get("quant_method") + + if method is None: + return None + + # Auto-upgrade awq → awq_marlin on Ampere+ GPUs + if method == "awq": + cap = torch.cuda.get_device_capability(self.gpu_id) + sm = cap[0] * 10 + cap[1] + if sm >= 80: + logger.info( + "Auto-upgrading quantization: awq → awq_marlin (SM%d)", + sm, + ) + method = "awq_marlin" + + config_cls = get_quantization_config(method) + quant_config = config_cls.from_config(config_dict) + logger.info( + "Quantization: %s (bits=%s, group_size=%s)", + quant_config.get_name(), + getattr(quant_config, "weight_bits", "?"), + getattr(quant_config, "group_size", "?"), + ) + return quant_config + + # ------------------------------------------------------------------ + # Model loading + # ------------------------------------------------------------------ + + def load_model(self) -> None: + """Load the model from a HuggingFace checkpoint. + + First checks the pymllm model registry for a custom implementation + that uses ``RadixAttention``. If found, instantiates it with the + HuggingFace config and loads weights via ``load_weights()``. + Otherwise falls back to ``AutoModelForCausalLM.from_pretrained``. + """ + tic = time.perf_counter() + model_path = self.server_config.model_path + + if model_path is None: + raise RuntimeError("server_config.model_path is not set.") + + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + "Load model begin. 
path=%s, avail mem=%.2f GB", + model_path, + before_mem, + ) + + # Look up the architecture in the pymllm model registry + from pymllm.models import _MODEL_REGISTRY, get_model_class + + hf_config = self.model_config.hf_config + architectures = [] + if hf_config is not None: + architectures = getattr(hf_config, "architectures", None) or [] + + if not architectures: + supported = ", ".join(sorted(_MODEL_REGISTRY.keys())) + raise RuntimeError( + f"Cannot determine model architecture from config. " + f"Supported architectures: {supported}" + ) + + architecture = architectures[0] + model_cls = get_model_class(architecture) + if model_cls is None: + supported = ", ".join(sorted(_MODEL_REGISTRY.keys())) + raise RuntimeError( + f"Architecture {architecture!r} is not supported by pymllm. " + f"Supported architectures: {supported}" + ) + + logger.info("Using pymllm model class: %s", model_cls.__name__) + + quant_config = self._resolve_quant_config() + + device_str = f"cuda:{self.gpu_id}" if self.device == "cuda" else self.device + # Use set_default_dtype so parameters created without explicit dtype + # get the target dtype, while parameters with explicit dtype=torch.float32 + # (e.g. A_log, dt_bias in GDN layers) stay in float32. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(self.dtype) + try: + with torch.device(device_str): + if quant_config is not None: + self.model = model_cls(hf_config, quant_config=quant_config) + else: + self.model = model_cls(hf_config) + finally: + torch.set_default_dtype(old_dtype) + self.model.load_weights(self._iter_weights(model_path)) + + # Post-load processing: let each quantization method repack/transform + # weights from checkpoint format to runtime format (e.g. AWQ → Marlin, + # GPTQ g_idx shuffling, FP8 calibration). 
+ for _name, module in self.model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None and hasattr(quant_method, "process_weights_after_loading"): + quant_method.process_weights_after_loading(module) + + self.model.eval() + + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + weight_mem = before_mem - after_mem + logger.info( + "Load model end. elapsed=%.2f s, type=%s, " + "weight_mem=%.2f GB, avail mem=%.2f GB", + time.perf_counter() - tic, + type(self.model).__name__, + weight_mem, + after_mem, + ) + + @staticmethod + def _iter_weights(model_path) -> "Generator[Tuple[str, torch.Tensor], None, None]": + """Yield ``(name, tensor)`` pairs from safetensors or ``.bin`` files. + + Prefers safetensors when available; falls back to PyTorch ``.bin`` + files otherwise. + """ + import glob as _glob + from pathlib import Path + + model_path = Path(model_path) + + # Prefer safetensors + st_files = sorted(_glob.glob(str(model_path / "*.safetensors"))) + if st_files: + from safetensors.torch import load_file + + for fpath in st_files: + state_dict = load_file(fpath) + yield from state_dict.items() + del state_dict + return + + # Fallback: PyTorch .bin files + bin_files = sorted(_glob.glob(str(model_path / "*.bin"))) + for fpath in bin_files: + state_dict = torch.load(fpath, map_location="cpu", weights_only=True) + yield from state_dict.items() + del state_dict + + # ------------------------------------------------------------------ + # Memory pool initialisation + # ------------------------------------------------------------------ + + def init_memory_pool(self) -> None: + """Initialise KV-cache memory pools and request-to-token mapping. + + 1. Profiles available GPU memory to determine the maximum number of + KV-cache token slots (``max_total_num_tokens``). + 2. Derives ``max_running_requests`` from config or heuristic. + 3. 
Creates :class:`~pymllm.mem_cache.memory_pool.ReqToTokenPool`, + :class:`~pymllm.mem_cache.memory_pool.KVPool`, and + :class:`~pymllm.mem_cache.memory_pool.TokenToKVPoolAllocator`. + """ + logger.info("Initialising memory pools...") + + # Determine max number of tokens in KV cache + self.max_total_num_tokens = self._profile_max_num_tokens() + + # Determine max running requests + max_reqs = self.server_config.max_running_requests + if max_reqs is None: + max_reqs = min( + max( + int(self.max_total_num_tokens / self.context_len * 512), + 2048, + ), + 4096, + ) + self.max_running_requests = max_reqs + + if self.max_total_num_tokens <= 0: + raise RuntimeError( + "Not enough memory for KV cache. " + "Try reducing context_length or using a smaller model." + ) + + # Create ReqToTokenPool + self.req_to_token_pool = make_req_to_token_pool( + max_reqs=self.max_running_requests, + max_context_len=self.context_len + 4, # small padding + device=self.device, + ) + + # Create KVPool + TokenToKVPoolAllocator + # Note: layer_num uses num_hidden_layers even for hybrid models + # because the KV pool is indexed by global layer_id. GDN layers' + # KV slots are allocated but unused (they use GDNPool instead). 
+ self.token_to_kv_pool, self.token_to_kv_pool_allocator = ( + make_full_attention_net_mem_pool( + size=self.max_total_num_tokens, + layer_num=self.num_hidden_layers, + k_head_num=self.num_kv_heads, + k_head_dim=self.head_dim, + v_head_num=self.num_kv_heads, + v_head_dim=self.v_head_dim, + device=self.device, + dtype=self.kv_cache_dtype, + ) + ) + + # Create GDNPool if hybrid model with GDN layers + if self.num_gdn_layers > 0: + hf_config = self.model_config.hf_config + text_config = getattr(hf_config, "text_config", hf_config) + gdn_num_k_heads = getattr(text_config, "linear_num_key_heads", 16) + gdn_num_v_heads = getattr(text_config, "linear_num_value_heads", 32) + gdn_head_k_dim = getattr(text_config, "linear_key_head_dim", 128) + gdn_head_v_dim = getattr(text_config, "linear_value_head_dim", 128) + gdn_conv_kernel = getattr(text_config, "linear_conv_kernel_dim", 4) + gdn_conv_dim = ( + gdn_num_k_heads * gdn_head_k_dim * 2 + gdn_num_v_heads * gdn_head_v_dim + ) + + self.gdn_pool = GDNPool( + max_reqs=self.max_running_requests, + num_gdn_layers=self.num_gdn_layers, + num_v_heads=gdn_num_v_heads, + head_k_dim=gdn_head_k_dim, + head_v_dim=gdn_head_v_dim, + conv_dim=gdn_conv_dim, + conv_kernel_size=gdn_conv_kernel, + device=self.device, + dtype=self.dtype, + max_track_slots=self.max_running_requests, + ) + + logger.info( + "Memory pool initialised: max_tokens=%d, max_reqs=%d, kv_pool=%.2f GB" + + (", gdn_pool=%.2f GB" if self.gdn_pool is not None else ""), + self.max_total_num_tokens, + self.max_running_requests, + self.token_to_kv_pool._mem_bytes() / (1 << 30), + *( + [self.gdn_pool.mem_bytes() / (1 << 30)] + if self.gdn_pool is not None + else [] + ), + ) + + def _profile_max_num_tokens(self) -> int: + """Profile available memory to determine maximum KV-cache tokens. + + If ``server_config.max_total_tokens`` is explicitly set that value + is used directly. Otherwise a memory-fraction-based heuristic + similar to sglang's ``profile_max_num_token`` is applied. 
+ """ + # If user explicitly set max_total_tokens, use that. + if self.server_config.max_total_tokens is not None: + return self.server_config.max_total_tokens + + if self.device != "cuda": + # For CPU, use a conservative default. + return 4096 + + available_gb = get_available_gpu_memory(self.device, self.gpu_id) + + # Determine memory fraction for static allocation (KV cache). + mem_fraction = self.server_config.mem_fraction_static + if mem_fraction is None: + mem_fraction = 0.85 # default: use 85% of remaining memory + + # Calculate per-token KV cache size in bytes. + kv_element_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size() + cell_size = ( + self.num_kv_heads + * (self.head_dim + self.v_head_dim) # K + V + * self.num_hidden_layers + * kv_element_size + ) + + if cell_size == 0: + logger.warning( + "cell_size is 0 (model metadata may be incomplete); " + "using default max_total_num_tokens=4096" + ) + return 4096 + + rest_memory_bytes = int(available_gb * mem_fraction * (1 << 30)) + + # Reserve memory for GDN pool if hybrid model + if self.num_gdn_layers > 0: + hf_config = self.model_config.hf_config + text_config = getattr(hf_config, "text_config", hf_config) + gdn_num_k_heads = getattr(text_config, "linear_num_key_heads", 16) + gdn_num_v_heads = getattr(text_config, "linear_num_value_heads", 32) + gdn_head_k_dim = getattr(text_config, "linear_key_head_dim", 128) + gdn_head_v_dim = getattr(text_config, "linear_value_head_dim", 128) + gdn_conv_kernel = getattr(text_config, "linear_conv_kernel_dim", 4) + gdn_conv_dim = ( + gdn_num_k_heads * gdn_head_k_dim * 2 + gdn_num_v_heads * gdn_head_v_dim + ) + + # Estimate GDN pool memory for max_running_requests + # Track slots add max_reqs_est extra slots for prefix cache snapshots + max_reqs_est = ( + min( + max( + int(rest_memory_bytes / cell_size / self.context_len * 512), + 2048, + ), + 4096, + ) + if self.server_config.max_running_requests is None + else self.server_config.max_running_requests + ) + 
pool_size = max_reqs_est + 1 + max_reqs_est # +track_slots + recurrent_bytes = ( + self.num_gdn_layers + * pool_size + * gdn_num_v_heads + * gdn_head_v_dim + * gdn_head_k_dim + * 4 # float32 + ) + dtype_size = torch.tensor([], dtype=self.dtype).element_size() + conv_bytes = ( + self.num_gdn_layers + * pool_size + * gdn_conv_dim + * (gdn_conv_kernel - 1) + * dtype_size + ) + gdn_pool_bytes = recurrent_bytes + conv_bytes + rest_memory_bytes -= gdn_pool_bytes + logger.info( + "GDN pool memory reservation: %.2f GB", + gdn_pool_bytes / (1 << 30), + ) + + max_num_tokens = rest_memory_bytes // cell_size + + logger.info( + "Memory profiling: avail=%.2f GB, fraction=%.2f, " + "cell_size=%d bytes, max_tokens=%d", + available_gb, + mem_fraction, + cell_size, + max_num_tokens, + ) + + return max(max_num_tokens, 1) # at least 1 + + # ------------------------------------------------------------------ + # Attention backend + # ------------------------------------------------------------------ + + def init_attention_backend(self) -> None: + """Initialise the attention backend. + + Creates a :class:`FlashInferAttnBackend` for standard models, or a + :class:`HybridAttnBackend` (FlashInfer + GDN) for hybrid models. 
+ """ + from pymllm.layers.attention.flashinfer_backend import FlashInferAttnBackend + + logger.info("Initialising attention backend...") + + flash_backend = FlashInferAttnBackend( + num_heads=self.num_attention_heads, + num_kv_heads=self.num_kv_heads, + head_dim=self.head_dim, + kv_cache_dtype=self.kv_cache_dtype, + q_dtype=self.dtype, + max_context_len=self.context_len, + req_to_token=self.req_to_token_pool.req_to_token, + device=torch.device(self.device), + max_req_pool_size=self.req_to_token_pool.size, + ) + + if self.gdn_pool is not None: + from pymllm.layers.attention.gdn_backend import GDNAttnBackend + from pymllm.layers.attention.hybrid_backend import HybridAttnBackend + + gdn_backend = GDNAttnBackend( + gdn_pool=self.gdn_pool, + device=torch.device(self.device), + ) + self.attn_backend = HybridAttnBackend( + full_attn_backend=flash_backend, + gdn_backend=gdn_backend, + full_attn_layer_ids=self.full_attn_layer_ids, + ) + else: + self.attn_backend = flash_backend + + logger.info( + "Attention backend: %s", + type(self.attn_backend).__name__, + ) + + # ------------------------------------------------------------------ + # Warmup + # ------------------------------------------------------------------ + + def _init_cublas(self) -> None: + """Run a small matmul to initialise cuBLAS. + + Without this, the first real matmul may incur a significant + initialisation overhead. + """ + dtype = torch.float16 + device = "cuda" + a = torch.ones((16, 16), dtype=dtype, device=device) + b = torch.ones((16, 16), dtype=dtype, device=device) + _ = a @ b + + # ------------------------------------------------------------------ + # CUDA graph capture + # ------------------------------------------------------------------ + + def init_cuda_graphs(self) -> None: + """Capture CUDA graphs for decode-step acceleration. + + Skipped when: + * The device is not CUDA. + * ``server_config.disable_cuda_graph`` is ``True``. + * The model is not a generation model. 
+ """ + self.graph_runner = None + + if self.device != "cuda": + return + if self.server_config.disable_cuda_graph: + logger.info("CUDA graphs disabled by config.") + return + if not self.is_generation: + return + + from pymllm.executor.cuda_graph_runner import CudaGraphRunner + + tic = time.perf_counter() + before_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info("Capturing CUDA graphs... avail mem=%.2f GB", before_mem) + + self.graph_runner = CudaGraphRunner(self) + + after_mem = get_available_gpu_memory(self.device, self.gpu_id) + logger.info( + "CUDA graph capture complete. elapsed=%.2f s, " + "mem usage=%.2f GB, avail mem=%.2f GB", + time.perf_counter() - tic, + before_mem - after_mem, + after_mem, + ) + + # ------------------------------------------------------------------ + # ForwardBatch construction + # ------------------------------------------------------------------ + + def prepare_forward_batch_extend( + self, + input_ids: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + extend_prefix_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + return_logprob: bool = False, + top_logprobs_nums: Optional[List[int]] = None, + ) -> ForwardBatch: + """Build a :class:`ForwardBatch` for an extend (prefill) pass. + + Parameters + ---------- + input_ids + Token IDs for all new tokens, shape ``[total_new_tokens]``. + req_pool_indices + Index of each request in ``ReqToTokenPool``, + shape ``[batch_size]``. + seq_lens + Total (prefix + new) length of each sequence, + shape ``[batch_size]``. + extend_seq_lens + Number of new tokens per sequence, shape ``[batch_size]``. + extend_prefix_lens + Cached prefix length per sequence, shape ``[batch_size]``. + out_cache_loc + KV-pool slot indices for each new token, + shape ``[total_new_tokens]``. + return_logprob + Whether to return per-token log-probabilities. + top_logprobs_nums + Number of top log-probs per sequence. 
+ """ + batch_size = req_pool_indices.shape[0] + seq_lens_sum = int(seq_lens.sum().item()) + extend_num_tokens = int(extend_seq_lens.sum().item()) + + # Compute positions for each token + positions = _compute_positions(extend_seq_lens, extend_prefix_lens) + + # Compute extend_start_loc (exclusive cumsum of extend_seq_lens) + extend_start_loc = torch.zeros( + batch_size, dtype=torch.int32, device=self.device + ) + if batch_size > 1: + extend_start_loc[1:] = torch.cumsum(extend_seq_lens[:-1], dim=0).to( + torch.int32 + ) + + return ForwardBatch( + forward_mode=ForwardMode.EXTEND, + batch_size=batch_size, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens.cpu(), + positions=positions, + extend_num_tokens=extend_num_tokens, + extend_seq_lens=extend_seq_lens, + extend_prefix_lens=extend_prefix_lens, + extend_start_loc=extend_start_loc, + extend_prefix_lens_cpu=extend_prefix_lens.tolist(), + extend_seq_lens_cpu=extend_seq_lens.tolist(), + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + attn_backend=self.attn_backend, + ) + + def prepare_forward_batch_decode( + self, + input_ids: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + return_logprob: bool = False, + top_logprobs_nums: Optional[List[int]] = None, + mrope_position_deltas: Optional[torch.Tensor] = None, + ) -> ForwardBatch: + """Build a :class:`ForwardBatch` for a decode step. + + Parameters + ---------- + input_ids + Token IDs (one per sequence), shape ``[batch_size]``. + req_pool_indices + Index of each request in ``ReqToTokenPool``, + shape ``[batch_size]``. + seq_lens + Total sequence length of each request, shape ``[batch_size]``. + out_cache_loc + KV-pool slot for each sequence's new token, + shape ``[batch_size]``. 
+ return_logprob + Whether to return per-token log-probabilities. + top_logprobs_nums + Number of top log-probs per sequence. + mrope_position_deltas + Per-request M-RoPE position deltas, shape ``[batch_size]`` (int64). + Used by multimodal models (e.g. Qwen3-VL) to offset decode-step + positions by the spatial extent of prefill images. + """ + batch_size = req_pool_indices.shape[0] + seq_lens_sum = int(seq_lens.sum().item()) + + # For decode, positions = seq_lens - 1 (the new token position) + positions = (seq_lens - 1).to(torch.int64) + + return ForwardBatch( + forward_mode=ForwardMode.DECODE, + batch_size=batch_size, + input_ids=input_ids, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + out_cache_loc=out_cache_loc, + seq_lens_sum=seq_lens_sum, + seq_lens_cpu=seq_lens.cpu(), + positions=positions, + return_logprob=return_logprob, + top_logprobs_nums=top_logprobs_nums, + req_to_token_pool=self.req_to_token_pool, + token_to_kv_pool=self.token_to_kv_pool, + attn_backend=self.attn_backend, + mrope_position_deltas=mrope_position_deltas, + ) + + # ------------------------------------------------------------------ + # Forward pass + # ------------------------------------------------------------------ + + def forward( + self, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Run a forward pass through the model. + + Dispatches to the appropriate method based on the batch's + :attr:`~pymllm.engine.forward_batch.ForwardMode`. For decode + batches, automatically uses CUDA-graph replay when a captured + graph is available. + + Parameters + ---------- + forward_batch + The prepared batch (from ``prepare_forward_batch_*``). + + Returns + ------- + LogitsProcessorOutput + Contains ``next_token_logits`` of shape + ``[batch_size, vocab_size]``. + """ + self.forward_pass_id += 1 + + if forward_batch.forward_mode.is_idle(): + return self._forward_idle(forward_batch) + + # Try CUDA graph replay for decode batches. 
        if (
            forward_batch.forward_mode.is_decode()
            and self.graph_runner is not None
            and self.graph_runner.can_run(forward_batch)
        ):
            return self.graph_runner.replay(forward_batch)

        if forward_batch.forward_mode.is_decode():
            return self.forward_decode(forward_batch)
        elif forward_batch.forward_mode.is_extend():
            return self.forward_extend(forward_batch)
        else:
            raise ValueError(f"Unsupported forward mode: {forward_batch.forward_mode}")

    def forward_decode(
        self,
        forward_batch: ForwardBatch,
    ) -> LogitsProcessorOutput:
        """Run a decode forward pass (one new token per sequence).

        Calls ``attn_backend.init_forward_metadata`` followed by
        ``model.forward``.
        """
        # NOTE: currently identical to forward_extend's body; kept as a
        # separate method so decode-specific logic has a place to live.
        self.attn_backend.init_forward_metadata(forward_batch)
        model_output = self.model.forward(
            forward_batch.input_ids,
            forward_batch.positions,
            forward_batch,
        )
        return self._process_logits(model_output, forward_batch)

    def forward_extend(
        self,
        forward_batch: ForwardBatch,
    ) -> LogitsProcessorOutput:
        """Run an extend (prefill) forward pass.

        Calls ``attn_backend.init_forward_metadata`` followed by
        ``model.forward``.
+ """ + self.attn_backend.init_forward_metadata(forward_batch) + model_output = self.model.forward( + forward_batch.input_ids, + forward_batch.positions, + forward_batch, + ) + return self._process_logits(model_output, forward_batch) + + def _forward_idle( + self, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Return empty logits for an idle batch (no sequences to process).""" + return LogitsProcessorOutput( + next_token_logits=torch.empty( + (0, self.vocab_size), + dtype=self.dtype, + device=self.device, + ), + ) + + # ------------------------------------------------------------------ + # Logits post-processing + # ------------------------------------------------------------------ + + def _process_logits( + self, + model_output: Any, + forward_batch: ForwardBatch, + ) -> LogitsProcessorOutput: + """Extract last-token logits from model output. + + Handles: + * A :class:`LogitsProcessorOutput` returned by custom model + implementations. + * A ``CausalLMOutput`` (from HuggingFace ``transformers``) with a + ``.logits`` attribute. + * A raw ``torch.Tensor`` of logits. + """ + if isinstance(model_output, LogitsProcessorOutput): + return model_output + + # Standard HuggingFace output + if hasattr(model_output, "logits"): + logits = model_output.logits + elif isinstance(model_output, torch.Tensor): + logits = model_output + else: + raise TypeError( + f"Unexpected model output type: {type(model_output)}. " + "Expected torch.Tensor or an object with .logits attribute." 
            )

        # --- Decode: logits is [bs, 1, vocab] or [bs, vocab] ---
        if forward_batch.forward_mode.is_decode():
            if logits.dim() == 3:
                next_token_logits = logits[:, -1, :]
            else:
                next_token_logits = logits
        else:
            # --- Extend: pick the last token of each sequence ---
            next_token_logits = self._gather_last_token_logits(logits, forward_batch)

        return LogitsProcessorOutput(next_token_logits=next_token_logits)

    def _gather_last_token_logits(
        self,
        logits: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Gather the logits of the last token in each sequence for extend.

        During extend, the model processes all tokens but we only need the
        logits at the last position of each sequence for next-token sampling.
        """
        if logits.dim() == 3:
            # [batch_size, seq_len, vocab_size] from standard HF model
            # NOTE(review): assumes the last position of every row is a real
            # token (right-padded batches would need per-row indices) — confirm.
            return logits[:, -1, :]

        # Flat layout [total_tokens, vocab_size]
        if (
            forward_batch.extend_start_loc is not None
            and forward_batch.extend_seq_lens is not None
        ):
            # Last token of sequence i sits at start_loc[i] + seq_len[i] - 1.
            last_indices = (
                forward_batch.extend_start_loc + forward_batch.extend_seq_lens - 1
            ).long()
            return logits[last_indices]

        # Fallback: last row
        return logits[-1:, :]

    # ------------------------------------------------------------------
    # Sampling
    # ------------------------------------------------------------------

    def sample(
        self,
        logits_output: LogitsProcessorOutput,
        forward_batch: ForwardBatch,
        temperatures: Optional[torch.Tensor] = None,
        top_ps: Optional[torch.Tensor] = None,
        top_ks: Optional[torch.Tensor] = None,
        penalty_params: Optional[Dict[str, Any]] = None,
    ) -> torch.Tensor:
        """Sample next-token IDs from logits.

        Supports per-request temperature, top-p, top-k, and penalties
        (repetition, frequency, presence).

        Parameters
        ----------
        logits_output
            The logits from :meth:`forward`.
        forward_batch
            The current forward batch.
        temperatures
            Per-request temperature, shape ``[batch_size]``.
        top_ps
            Per-request top-p, shape ``[batch_size]``.
        top_ks
            Per-request top-k, shape ``[batch_size]``.
        penalty_params
            Optional dict with keys ``repetition_penalties``,
            ``frequency_penalties``, ``presence_penalties`` (tensors of
            shape ``[batch_size]``), and ``token_histories`` (list of
            list of int).

        Returns
        -------
        torch.Tensor
            Next-token IDs, shape ``[batch_size]``, dtype ``int32``.
        """
        # Local import avoids a circular dependency at module load time.
        from pymllm.layers.sampling import (
            sampling_from_probs,
            softmax,
            top_k_top_p_sampling_from_probs,
        )

        logits = logits_output.next_token_logits

        if logits.numel() == 0:
            return torch.empty(0, dtype=torch.int32, device=self.device)

        # Apply penalties to logits before temperature/sampling.
        if penalty_params is not None:
            logits = _apply_penalties(
                logits,
                penalty_params["token_histories"],
                penalty_params["repetition_penalties"],
                penalty_params["frequency_penalties"],
                penalty_params["presence_penalties"],
            )

        # Greedy path: temperature=0 (or all zeros) → argmax, no sampling.
        # NOTE(review): the greedy shortcut only fires when *every* request is
        # greedy; a mixed batch falls through to the stochastic path where
        # near-zero temperatures reach softmax() — verify that softmax guards
        # against division by ~0 temperatures.
        if temperatures is not None:
            all_greedy = bool((temperatures < 1e-6).all())
        else:
            all_greedy = False

        if all_greedy:
            return logits.argmax(dim=-1).to(torch.int32)

        # Stochastic path: apply temperature then sample.
        if temperatures is not None:
            probs = softmax(logits, temperature=temperatures)
        else:
            probs = torch.softmax(logits.float(), dim=-1)

        # Apply top-k / top-p sampling if specified
        has_top_k = top_ks is not None
        has_top_p = top_ps is not None

        if has_top_k or has_top_p:
            # Neutral defaults when only one of the two is given:
            # k = vocab size (no truncation), p = 1.0 (no nucleus cut).
            k = top_ks if has_top_k else logits.shape[-1]
            p = top_ps if has_top_p else 1.0
            next_token_ids = top_k_top_p_sampling_from_probs(probs, k, p)
        else:
            next_token_ids = sampling_from_probs(probs)

        return next_token_ids

    # ------------------------------------------------------------------
    # Cleanup
    # ------------------------------------------------------------------

    def shutdown(self) -> None:
        """Release model and memory resources."""
        logger.info("ModelRunner shutting down...")

        if self.graph_runner is not None:
            self.graph_runner.shutdown()
            self.graph_runner = None
        if self.model is not None:
            del self.model
            self.model = None
        if self.token_to_kv_pool is not None:
            del self.token_to_kv_pool
            self.token_to_kv_pool = None
        if self.token_to_kv_pool_allocator is not None:
            del self.token_to_kv_pool_allocator
            self.token_to_kv_pool_allocator = None
        if self.gdn_pool is not None:
            del self.gdn_pool
            self.gdn_pool = None
        if self.req_to_token_pool is not None:
            del self.req_to_token_pool
            self.req_to_token_pool = None
        self.attn_backend = None

        if self.device == "cuda":
            torch.cuda.empty_cache()
        gc.collect()

        logger.info("ModelRunner shutdown complete.")

    # ------------------------------------------------------------------
    # Properties
    # ------------------------------------------------------------------

    @property
    def is_generation(self) -> bool:
        """True if the model is a generation (causal-LM) model."""
        return True

    @property
    def sliding_window_size(self) -> Optional[int]:
        """Sliding-window attention span, or ``None`` for full context."""
        hf_config = self.model_config.hf_config
        if hf_config is None:
            return None
        # Multimodal configs nest the LM settings under ``text_config``.
        text_config = getattr(hf_config, "text_config", hf_config)
        return getattr(text_config, "sliding_window", None)


# ---------------------------------------------------------------------------
# Utility functions
# ---------------------------------------------------------------------------


def _compute_positions(
    extend_seq_lens: torch.Tensor,
    extend_prefix_lens: torch.Tensor,
) -> torch.Tensor:
    """Compute per-token positions for an extend batch.

    For each sequence, positions are
    ``[prefix_len, prefix_len+1, ..., prefix_len+seq_len-1]``.
    The result is a flat 1-D tensor of shape ``[sum(extend_seq_lens)]``.
    """
    device = extend_seq_lens.device
    batch_size = extend_seq_lens.shape[0]
    total_tokens = int(extend_seq_lens.sum().item())

    if total_tokens == 0:
        return torch.empty(0, dtype=torch.int64, device=device)

    # NOTE(review): the per-sequence .item() calls below sync host/device on
    # every iteration; if this shows up in profiles it can be vectorized
    # (e.g. torch.repeat_interleave of prefix offsets + a flat arange).
    positions = torch.empty(total_tokens, dtype=torch.int64, device=device)
    offset = 0
    for i in range(batch_size):
        seq_len = int(extend_seq_lens[i].item())
        prefix_len = int(extend_prefix_lens[i].item())
        if seq_len > 0:
            positions[offset : offset + seq_len] = torch.arange(
                prefix_len,
                prefix_len + seq_len,
                dtype=torch.int64,
                device=device,
            )
        offset += seq_len

    return positions
diff --git a/pymllm/layers/__init__.py b/pymllm/layers/__init__.py
new file mode 100644
index 000000000..2ecb13965
--- /dev/null
+++ b/pymllm/layers/__init__.py
"""Layers module for pymllm."""

from pymllm.layers.base import MllmBaseLayer
from pymllm.layers.embedding import VocabParallelEmbedding
from pymllm.layers.layer_norm import LayerNorm
from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear
from pymllm.layers.mlp import MLP, ParallelMLP
from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm
from pymllm.layers.rms_norm_gated import RMSNormGated
from pymllm.layers.gated_delta_net import GatedDeltaNet
from pymllm.layers.rope import (
    apply_llama31_rope,
apply_llama31_rope_pos_ids, + apply_mrope, + apply_rope, + apply_rope_pos_ids, + apply_rope_with_cos_sin_cache, +) +from pymllm.layers.sampling import ( + chain_speculative_sampling, + min_p_sampling_from_probs, + sampling_from_logits, + sampling_from_probs, + softmax, + top_k_mask_logits, + top_k_renorm_probs, + top_k_sampling_from_probs, + top_k_top_p_sampling_from_logits, + top_k_top_p_sampling_from_probs, + top_p_renorm_probs, + top_p_sampling_from_probs, +) +from pymllm.layers.utils import set_weight_attrs + +__all__ = [ + "MllmBaseLayer", + "set_weight_attrs", + "VocabParallelEmbedding", + "ColumnParallelLinear", + "Linear", + "RowParallelLinear", + "MLP", + "ParallelMLP", + "LayerNorm", + "RMSNorm", + "GemmaRMSNorm", + "apply_mrope", + "apply_rope", + "apply_llama31_rope", + "apply_rope_pos_ids", + "apply_llama31_rope_pos_ids", + "apply_rope_with_cos_sin_cache", + "softmax", + "sampling_from_probs", + "sampling_from_logits", + "top_p_sampling_from_probs", + "top_k_sampling_from_probs", + "min_p_sampling_from_probs", + "top_k_top_p_sampling_from_logits", + "top_k_top_p_sampling_from_probs", + "top_p_renorm_probs", + "top_k_renorm_probs", + "top_k_mask_logits", + "chain_speculative_sampling", +] diff --git a/pymllm/layers/attention/__init__.py b/pymllm/layers/attention/__init__.py new file mode 100644 index 000000000..ae187975d --- /dev/null +++ b/pymllm/layers/attention/__init__.py @@ -0,0 +1,33 @@ +"""Attention layers and backends for pymllm.""" + +from pymllm.layers.attention.attention_backend import AttentionBackend +from pymllm.layers.attention.flashinfer_backend import ( + DecodeMetadata, + FlashInferAttnBackend, + PrefillMetadata, + WrapperDispatch, + should_use_tensor_core, +) +from pymllm.layers.attention.gdn_backend import GDNAttnBackend +from pymllm.layers.attention.hybrid_backend import HybridAttnBackend +from pymllm.layers.attention.radix_attention import AttentionType, RadixAttention +from pymllm.layers.attention.radix_linear_attention import 
RadixLinearAttention + +__all__ = [ + # Base + "AttentionBackend", + # RadixAttention + "AttentionType", + "RadixAttention", + # RadixLinearAttention (GDN) + "RadixLinearAttention", + # FlashInfer backend + "FlashInferAttnBackend", + "DecodeMetadata", + "PrefillMetadata", + "WrapperDispatch", + "should_use_tensor_core", + # GDN + Hybrid backends + "GDNAttnBackend", + "HybridAttnBackend", +] diff --git a/pymllm/layers/attention/attention_backend.py b/pymllm/layers/attention/attention_backend.py new file mode 100644 index 000000000..fe168c2d2 --- /dev/null +++ b/pymllm/layers/attention/attention_backend.py @@ -0,0 +1,165 @@ +"""Abstract base class for pymllm attention backends. + +Every concrete backend (FlashInfer, Triton, torch-native, …) must implement +at minimum: + + * ``init_forward_metadata`` – called once per batch before the model forward. + * ``forward_extend`` – prefill / extend attention. + * ``forward_decode`` – single-token decode attention. + +The public ``forward`` method dispatches to the correct variant based on +``forward_batch.forward_mode``. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional + +import torch + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch, ForwardMode + from pymllm.layers.attention.radix_attention import RadixAttention + + +class AttentionBackend(ABC): + """Abstract base class for attention backends. + + All concrete backends inherit from this class and implement the abstract + methods below. + """ + + # ------------------------------------------------------------------ + # Core interface – must be implemented by every backend + # ------------------------------------------------------------------ + + @abstractmethod + def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None: + """Prepare per-batch metadata before the model's attention layers run. 

        For FlashInfer this plans the KV-index arrays and calls
        ``wrapper.begin_forward``; for Triton / torch-native this is a no-op.
        Must be called once per batch *before* ``model.forward``.
        """
        raise NotImplementedError

    @abstractmethod
    def forward_decode(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for a decode step (one new token per sequence)."""
        raise NotImplementedError

    @abstractmethod
    def forward_extend(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Run attention for a prefill / extend step."""
        raise NotImplementedError

    # ------------------------------------------------------------------
    # Dispatch – shared logic; do not override in normal backends
    # ------------------------------------------------------------------

    def forward(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",
        forward_batch: "ForwardBatch",
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Dispatch to ``forward_decode`` or ``forward_extend`` based on mode.

        For IDLE batches an empty (uninitialized) output tensor of the right
        shape is returned without any compute.
        """
        if forward_batch.forward_mode.is_idle():
            # Return empty output without computation.
            # new_empty leaves values uninitialized; acceptable because an
            # idle batch has zero rows downstream consumers actually read.
            return q.new_empty(q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
        elif forward_batch.forward_mode.is_decode():
            return self.forward_decode(
                q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
            )
        else:
            return self.forward_extend(
                q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs
            )

    # ------------------------------------------------------------------
    # GDN linear-attention interface (used by HybridAttnBackend)
    # ------------------------------------------------------------------

    def forward_gdn(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """Run GDN linear-attention for one layer.

        Only implemented by backends that support hybrid (full + GDN)
        architectures. The default raises ``NotImplementedError``.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not support GDN linear attention. "
            "Use HybridAttnBackend for hybrid full+GDN models."
        )

    # ------------------------------------------------------------------
    # Optional CUDA-graph interface
    # ------------------------------------------------------------------

    def get_cuda_graph_seq_len_fill_value(self) -> int:
        """Fill value used to pad ``seq_lens`` tensors for CUDA-graph capture.

        Most backends use ``1`` (not ``0``) to avoid division-by-zero in
        attention kernels.
+ """ + raise NotImplementedError + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: + """Allocate shared CUDA-graph state (buffers reused across captures).""" + raise NotImplementedError + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + forward_mode: "ForwardMode", + ) -> None: + """Set up per-batch metadata for capturing a CUDA graph.""" + raise NotImplementedError + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + forward_mode: "ForwardMode", + seq_lens_cpu: Optional[torch.Tensor], + ) -> None: + """Update metadata when replaying a captured CUDA graph.""" + raise NotImplementedError diff --git a/pymllm/layers/attention/flashinfer_backend.py b/pymllm/layers/attention/flashinfer_backend.py new file mode 100644 index 000000000..85b785f6d --- /dev/null +++ b/pymllm/layers/attention/flashinfer_backend.py @@ -0,0 +1,977 @@ +"""FlashInfer attention backend for pymllm. + + * No model-runner object -- constructor takes explicit scalar / tensor params. + * No tensor-parallelism head splitting (handled at the model layer level). + * No speculative decoding support. + * ``KVPool`` API: + - ``get_kv_buffer(layer_id)`` returns ``(k_buf, v_buf)`` each shaped + ``[buf_len, num_heads, head_dim]``. + - ``set_kv_buffer(layer_id, indices, k, v)`` -- no scale arguments. + +Supports: + * Single-wrapper mode (full context, no sliding window) + * Sliding-window mode (two wrappers: window + full) + * CUDA-graph capture / replay for decode and target-verify passes. 
+""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass +from enum import Enum, auto +from typing import List, Optional, Union + +import torch + +from pymllm.engine.forward_batch import ForwardBatch, ForwardMode +from pymllm.layers.attention.attention_backend import AttentionBackend +from mllm_kernel.cuda.jit.create_kv_indices import create_kv_indices + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Optional FlashInfer import +# --------------------------------------------------------------------------- + +_flashinfer_available = False +try: + from flashinfer import ( + BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, + BatchPrefillWithRaggedKVCacheWrapper, + ) + + try: + from flashinfer import fast_decode_plan + from functools import partial as _partial + + _has_fast_decode_plan = True + except ImportError: + _has_fast_decode_plan = False + + from flashinfer.cascade import merge_state + + _flashinfer_available = True +except ImportError: + logger.warning( + "flashinfer is not installed; FlashInferAttnBackend will raise " + "NotImplementedError if used." + ) + +# --------------------------------------------------------------------------- +# Global workspace buffer (shared across all FlashInfer wrapper instances) +# --------------------------------------------------------------------------- + +_global_workspace_buffer: Optional[torch.Tensor] = None + +# Default workspace size (128 MB); can be overridden via environment variable. 
_DEFAULT_WORKSPACE_BYTES = int(
    os.environ.get("PYMLLM_FLASHINFER_WORKSPACE_SIZE", 128 * 1024 * 1024)
)

# ---------------------------------------------------------------------------
# Enums / dataclasses
# ---------------------------------------------------------------------------


class WrapperDispatch(Enum):
    """Indicates which wrapper to use for a given attention layer."""

    SLIDING_WINDOW = auto()
    CROSS_ATTENTION = auto()


@dataclass
class DecodeMetadata:
    """Per-batch metadata for a decode step."""

    decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]"


@dataclass
class PrefillMetadata:
    """Per-batch metadata for a prefill / extend step."""

    prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]"
    use_ragged: bool
    extend_no_prefix: bool


# ---------------------------------------------------------------------------
# CUDA kernel – build the flat kv_indices array for FlashInfer
# ---------------------------------------------------------------------------

# ---------------------------------------------------------------------------
# Helper – choose whether to use tensor cores for decode
# ---------------------------------------------------------------------------


def should_use_tensor_core(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    """Return whether FlashInfer decode should use tensor cores.

    For FP8 we always use tensor cores. For fp16 / bf16 we use them when
    the GQA group size (num_attention_heads / num_kv_heads) is ≥ 4, which
    fuses the head group with the token dimension in the MMA instruction.
+ """ + env_override = os.environ.get("PYMLLM_FLASHINFER_USE_TENSOR_CORE") + if env_override is not None: + return env_override.lower() == "true" + + try: + from flashinfer.decode import _grouped_size_compiled_for_decode_kernels + + return not _grouped_size_compiled_for_decode_kernels( + num_attention_heads, num_kv_heads + ) + except (ImportError, AttributeError): + pass + + gqa_group_size = num_attention_heads // num_kv_heads + if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + return True + if kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16): + return gqa_group_size >= 4 + return False + + +# --------------------------------------------------------------------------- +# FlashInferAttnBackend +# --------------------------------------------------------------------------- + + +class FlashInferAttnBackend(AttentionBackend): + """FlashInfer-based attention backend for pymllm. + + This class does not depend on a ``ModelRunner`` object. Instead it takes + all required configuration explicitly so that it can be constructed + independently of any particular model runner. + + Parameters + ---------- + num_heads + Number of query heads per device (after any TP sharding). + num_kv_heads + Number of KV heads per device. + head_dim + Per-head dimension for Q and K. + kv_cache_dtype + ``torch.dtype`` of the KV cache (e.g. ``torch.float16``). + q_dtype + ``torch.dtype`` of the query tensor. + max_context_len + Maximum sequence length the model supports. + req_to_token + The ``[max_reqs, max_context_len]`` int32 tensor from + ``ReqToTokenPool.req_to_token``. + device + Target device (e.g. ``torch.device("cuda")``) + max_req_pool_size + Maximum number of concurrent requests (= ``ReqToTokenPool.size``). + Used to pre-allocate ``kv_indptr`` / ``kv_last_page_len`` buffers. + sliding_window_size + When not ``None``, enables sliding-window attention mode which + allocates two wrapper sets (window + full context). 
    skip_prefill
        When ``True``, skip creating prefill wrappers (for backends that only
        perform decode, e.g. multi-step draft backends).
    kv_indptr_buf
        Optional pre-allocated ``kv_indptr`` buffer. Used when sharing
        buffers across multiple backend instances (e.g. multi-step draft).
    kv_last_page_len_buf
        Optional pre-allocated ``kv_last_page_len`` buffer.
    init_new_workspace
        When ``True`` allocate a fresh workspace buffer instead of reusing the
        global one.
    """

    def __init__(
        self,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        kv_cache_dtype: torch.dtype,
        q_dtype: torch.dtype,
        max_context_len: int,
        req_to_token: torch.Tensor,
        device: torch.device,
        max_req_pool_size: int,
        sliding_window_size: Optional[int] = None,
        skip_prefill: bool = False,
        kv_indptr_buf: Optional[torch.Tensor] = None,
        kv_last_page_len_buf: Optional[torch.Tensor] = None,
        init_new_workspace: bool = False,
    ):
        if not _flashinfer_available:
            raise RuntimeError(
                "flashinfer is required for FlashInferAttnBackend but is not "
                "installed. Run: pip install flashinfer-python"
            )

        super().__init__()

        self.num_heads = num_heads
        self.num_kv_heads = num_kv_heads
        self.head_dim = head_dim
        self.kv_cache_dtype = kv_cache_dtype
        self.q_dtype = q_dtype
        self.max_context_len = max_context_len
        self.req_to_token = req_to_token
        self.device = device
        self.skip_prefill = skip_prefill

        # Tensor-core preference for decode
        self.decode_use_tensor_cores = should_use_tensor_core(
            kv_cache_dtype, num_heads, num_kv_heads
        )

        # Sliding-window / cross-attention wrapper dispatch
        if sliding_window_size is not None:
            self.num_wrappers = 2
            self.dispatch_reason: Optional[WrapperDispatch] = (
                WrapperDispatch.SLIDING_WINDOW
            )
            self.sliding_window_size: Optional[int] = sliding_window_size
        else:
            self.num_wrappers = 1
            self.dispatch_reason = None
            self.sliding_window_size = None

        # ------------------------------------------------------------------
        # Workspace buffer
        # ------------------------------------------------------------------
        # NOTE(review): the global buffer is allocated here even when
        # ``init_new_workspace=True`` (in which case it may go unused) —
        # consider allocating it lazily only on the reuse path.
        global _global_workspace_buffer
        if _global_workspace_buffer is None:
            _global_workspace_buffer = torch.empty(
                _DEFAULT_WORKSPACE_BYTES,
                dtype=torch.uint8,
                device=device,
            )
        if init_new_workspace:
            self.workspace_buffer = torch.empty(
                _DEFAULT_WORKSPACE_BYTES,
                dtype=torch.uint8,
                device=device,
            )
        else:
            self.workspace_buffer = _global_workspace_buffer

        # ------------------------------------------------------------------
        # kv_indptr        [num_wrappers × (max_req_pool_size + 1)]
        # kv_last_page_len [max_req_pool_size]
        # ------------------------------------------------------------------
        if kv_indptr_buf is None:
            self.kv_indptr: List[torch.Tensor] = [
                torch.zeros((max_req_pool_size + 1,), dtype=torch.int32, device=device)
                for _ in range(self.num_wrappers)
            ]
        else:
            # A shared external buffer is only valid in single-wrapper mode.
            assert self.num_wrappers == 1
            self.kv_indptr = [kv_indptr_buf]

        if kv_last_page_len_buf is None:
            # page_size is 1 in this backend, so every last page holds 1 token.
            self.kv_last_page_len = torch.ones(
                (max_req_pool_size,), dtype=torch.int32, device=device
            )
        else:
            assert self.num_wrappers == 1
            self.kv_last_page_len = kv_last_page_len_buf

        # qo_indptr – only needed for prefill
        if not skip_prefill:
            self.qo_indptr: List[torch.Tensor] = [
                torch.zeros((max_req_pool_size + 1,), dtype=torch.int32, device=device)
                for _ in range(self.num_wrappers)
            ]

        # ------------------------------------------------------------------
        # Create FlashInfer wrappers
        # ------------------------------------------------------------------
        self.prefill_wrapper_ragged: Optional[
            "BatchPrefillWithRaggedKVCacheWrapper"
        ] = None
        self.prefill_wrappers_paged: List["BatchPrefillWithPagedKVCacheWrapper"] = []
        self.decode_wrappers: List["BatchDecodeWithPagedKVCacheWrapper"] = []

        if not skip_prefill:
            self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
                self.workspace_buffer, "NHD"
            )

        for _ in range(self.num_wrappers):
            if not skip_prefill:
                self.prefill_wrappers_paged.append(
                    BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
                )
            self.decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
                    "NHD",
                    use_tensor_cores=self.decode_use_tensor_cores,
                )
            )

        # ------------------------------------------------------------------
        # Indices updaters
        # ------------------------------------------------------------------
        if not skip_prefill:
            self.indices_updater_prefill = _FlashInferIndicesUpdaterPrefill(self)
        self.indices_updater_decode = _FlashInferIndicesUpdaterDecode(self)

        # Per-batch metadata set by init_forward_metadata
        self.forward_metadata: Optional[Union[DecodeMetadata, PrefillMetadata]] = None

        # CUDA-graph metadata stores
        self.decode_cuda_graph_metadata: dict = {}
        self.prefill_cuda_graph_metadata: dict = {}

    # ------------------------------------------------------------------
    # init_forward_metadata
    # ------------------------------------------------------------------

    def init_forward_metadata(self, forward_batch: ForwardBatch) -> None:
        """Prepare FlashInfer wrappers for the current batch.

        Must be called once per batch before the model's ``forward`` method.
        """
        if forward_batch.forward_mode.is_decode_or_idle():
            self.indices_updater_decode.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                decode_wrappers=self.decode_wrappers,
            )
            self.forward_metadata = DecodeMetadata(self.decode_wrappers)
        else:
            # Extend / prefill
            prefix_lens = forward_batch.extend_prefix_lens
            extend_no_prefix = forward_batch.extend_prefix_lens_cpu is None or not any(
                forward_batch.extend_prefix_lens_cpu
            )
            # use_ragged=True
            #   - extend_no_prefix=True  → ragged-only (pure prefill, no cache)
            #   - extend_no_prefix=False → ragged+paged merge (cache hit)
            # The paged wrapper covers only the cached prefix (prefix_lens),
            # the ragged wrapper covers the new extend tokens. No overlap.
            # NOTE: to avoid a FlashInfer edge-case with 1-token ragged
            # extends, _allocate_extend guarantees extend_len >= 2.
            use_ragged = True

            self.indices_updater_prefill.update(
                forward_batch.req_pool_indices,
                forward_batch.seq_lens,
                forward_batch.seq_lens_cpu,
                forward_batch.seq_lens_sum,
                prefix_lens=prefix_lens,
                prefill_wrappers=self.prefill_wrappers_paged,
                use_ragged=use_ragged,
            )
            self.forward_metadata = PrefillMetadata(
                self.prefill_wrappers_paged,
                use_ragged=use_ragged,
                extend_no_prefix=extend_no_prefix,
            )

    # ------------------------------------------------------------------
    # forward_extend
    # ------------------------------------------------------------------

    def forward_extend(
        self,
        q: torch.Tensor,
        k: Optional[torch.Tensor],
        v: Optional[torch.Tensor],
        layer: "RadixAttention",  # noqa: F821
        forward_batch: ForwardBatch,
        save_kv_cache: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Prefill / extend attention via FlashInfer (ragged and/or paged)."""
        # Local import avoids a circular dependency at module load time.
        from pymllm.layers.attention.radix_attention import RadixAttention

        assert isinstance(layer, RadixAttention)
        meta: PrefillMetadata = self.forward_metadata

        prefill_wrapper_paged = meta.prefill_wrappers[self._get_wrapper_idx(layer)]
        cache_loc = forward_batch.out_cache_loc

        # Write K/V into the pool
        if k is not None:
            assert v is not None
            if save_kv_cache:
                forward_batch.token_to_kv_pool.set_kv_buffer(
                    layer.layer_id, cache_loc, k, v
                )

        q_3d = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)

        if not meta.use_ragged:
            # Paged-only path: uses the full KV cache (prefix + extend).
            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                layer.layer_id
            )
            # Reshape to [buf_len, page_size=1, num_heads, head_dim] for FlashInfer.
+ paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + + o = prefill_wrapper_paged.forward( + q_3d, + paged_kv, + causal=not layer.is_cross_attention, + sm_scale=layer.scaling, + window_left=layer.sliding_window_size, + logits_soft_cap=layer.logit_cap if layer.logit_cap > 0 else None, + ) + else: + # Ragged path: query attends only to the new (ragged) K/V; + # prefix K/V is in the paged pool. + if k is None: + # Fallback: load K/V from the pool. + k_buf, v_buf = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + k = k_buf + v = v_buf + + k_3d = k.view(-1, layer.tp_k_head_num, layer.head_dim) + v_3d = v.view(-1, layer.tp_v_head_num, layer.v_head_dim) + + if meta.extend_no_prefix: + # Pure prefill – no prefix at all. + o = self.prefill_wrapper_ragged.forward( + q_3d, + k_3d, + v_3d, + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + else: + # Extend with prefix: merge ragged (new) and paged (prefix). + o1, s1 = self.prefill_wrapper_ragged.forward_return_lse( + q_3d, + k_3d, + v_3d, + causal=True, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer( + layer.layer_id + ) + paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + o2, s2 = prefill_wrapper_paged.forward_return_lse( + q_3d, + paged_kv, + causal=False, + sm_scale=layer.scaling, + logits_soft_cap=(layer.logit_cap if layer.logit_cap > 0 else None), + ) + + o, _ = merge_state(o1, s1, o2, s2) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + # ------------------------------------------------------------------ + # forward_decode + # ------------------------------------------------------------------ + + def forward_decode( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", # noqa: F821 + forward_batch: ForwardBatch, + save_kv_cache: 
bool = True, + **kwargs, + ) -> torch.Tensor: + from pymllm.layers.attention.radix_attention import RadixAttention + + assert isinstance(layer, RadixAttention) + meta: DecodeMetadata = self.forward_metadata + + decode_wrapper = meta.decode_wrappers[self._get_wrapper_idx(layer)] + cache_loc = forward_batch.out_cache_loc + + if k is not None: + assert v is not None + if save_kv_cache: + forward_batch.token_to_kv_pool.set_kv_buffer( + layer.layer_id, cache_loc, k, v + ) + + k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id) + paged_kv = (k_cache.unsqueeze(1), v_cache.unsqueeze(1)) + + o = decode_wrapper.forward( + q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim), + paged_kv, + sm_scale=layer.scaling, + logits_soft_cap=layer.logit_cap if layer.logit_cap > 0 else None, + ) + + return o.view(-1, layer.tp_q_head_num * layer.v_head_dim) + + # ------------------------------------------------------------------ + # CUDA-graph support + # ------------------------------------------------------------------ + + def get_cuda_graph_seq_len_fill_value(self) -> int: + return 1 + + def init_cuda_graph_state( + self, + max_bs: int, + max_num_tokens: int, + kv_indices_buf: Optional[torch.Tensor] = None, + ) -> None: + """Allocate CUDA-graph shared state buffers.""" + if kv_indices_buf is None: + cuda_graph_kv_indices = torch.zeros( + (max_num_tokens * self.max_context_len,), + dtype=torch.int32, + device=self.device, + ) + else: + cuda_graph_kv_indices = kv_indices_buf + + self.cuda_graph_kv_indices = [cuda_graph_kv_indices] + [ + cuda_graph_kv_indices.clone() for _ in range(self.num_wrappers - 1) + ] + + if not self.skip_prefill: + self.cuda_graph_custom_mask = torch.zeros( + (max_num_tokens * self.max_context_len,), + dtype=torch.uint8, + device=self.device, + ) + self.cuda_graph_qk_indptr = [x.clone() for x in self.kv_indptr] + self.cuda_graph_qo_indptr = [x.clone() for x in self.kv_indptr] + + def init_forward_metadata_capture_cuda_graph( + 
        self,
        bs: int,
        num_tokens: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        forward_mode: ForwardMode,
    ) -> None:
        """Set up metadata for CUDA-graph capture of a decode step."""
        if not forward_mode.is_decode_or_idle():
            raise ValueError(
                "CUDA-graph capture is only supported for decode / idle modes."
            )

        decode_wrappers = []
        for i in range(self.num_wrappers):
            decode_wrappers.append(
                BatchDecodeWithPagedKVCacheWrapper(
                    self.workspace_buffer,
                    "NHD",
                    use_cuda_graph=True,
                    use_tensor_cores=self.decode_use_tensor_cores,
                    paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
                    paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
                    paged_kv_last_page_len_buffer=self.kv_last_page_len[:num_tokens],
                )
            )

        seq_lens_sum = seq_lens.sum().item()
        self.indices_updater_decode.update(
            req_pool_indices,
            seq_lens,
            seq_lens.cpu(),
            seq_lens_sum,
            decode_wrappers=decode_wrappers,
        )
        self.decode_cuda_graph_metadata[bs] = decode_wrappers
        self.forward_metadata = DecodeMetadata(decode_wrappers)

        # NOTE(review): fast_decode_plan is patched in AFTER the update() call
        # above, so the initial capture plans with the stock begin_forward and
        # only replays use the fast path — confirm this ordering is intended.
        if _has_fast_decode_plan:
            for i in range(self.num_wrappers):
                decode_wrappers[i].begin_forward = _partial(
                    fast_decode_plan, decode_wrappers[i]
                )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_sum: int,
        forward_mode: ForwardMode,
        seq_lens_cpu: Optional[torch.Tensor],
    ) -> None:
        """Update metadata when replaying a CUDA graph for decode."""
        if not forward_mode.is_decode_or_idle():
            raise ValueError(
                "CUDA-graph replay is only supported for decode / idle modes."
            )

        self.indices_updater_decode.update(
            req_pool_indices[:bs],
            seq_lens[:bs],
            seq_lens_cpu[:bs] if seq_lens_cpu is not None else None,
            seq_lens_sum,
            decode_wrappers=self.decode_cuda_graph_metadata[bs],
        )

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _get_wrapper_idx(self, layer) -> int:
        """Return the wrapper index for the given attention layer."""
        if self.num_wrappers == 1:
            return 0
        if self.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            # Wrapper 0 → sliding window attention.
            # Wrapper 1 → full-context attention.
            return int(layer.sliding_window_size == -1)
        raise ValueError(f"Unknown dispatch reason: {self.dispatch_reason}")


# ---------------------------------------------------------------------------
# _FlashInferIndicesUpdaterDecode
# ---------------------------------------------------------------------------


class _FlashInferIndicesUpdaterDecode:
    """Populates ``kv_indptr`` / ``kv_indices`` and calls
    ``wrapper.begin_forward`` before every decode step.
+ """ + + def __init__(self, backend: FlashInferAttnBackend): + self.num_qo_heads = backend.num_heads + self.num_kv_heads = backend.num_kv_heads + self.head_dim = backend.head_dim + self.data_type = backend.kv_cache_dtype + self.q_data_type = backend.q_dtype + self.sliding_window_size = backend.sliding_window_size + self.backend = backend + + self.kv_indptr = backend.kv_indptr + self.kv_last_page_len = backend.kv_last_page_len + self.req_to_token = backend.req_to_token + + def update( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]", + kv_start_idx: Optional[torch.Tensor] = None, + ) -> None: + if self.backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW: + self._update_sliding_window( + req_pool_indices, + seq_lens, + seq_lens_cpu, + seq_lens_sum, + decode_wrappers, + ) + else: + # Single-wrapper: full-context decode. Build kv_indptr/kv_indices + # and call FlashInfer's plan function via the CUDA kernel. + bs = len(req_pool_indices) + kv_indptr = self.kv_indptr[0] + + # Fill kv_indptr: prefix sums of paged_kernel_lens. + kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0) + kv_indptr_sliced = kv_indptr[: bs + 1] + + if seq_lens_cpu is not None: + seq_lens_sum = int(seq_lens_cpu.sum().item()) + else: + seq_lens_sum = int(seq_lens.sum().item()) + + # Allocate KV indices buffer. + if decode_wrappers and decode_wrappers[0].is_cuda_graph_enabled: + kv_indices = decode_wrappers[0]._paged_kv_indices_buf + else: + kv_indices = torch.empty( + seq_lens_sum, dtype=torch.int32, device=self.req_to_token.device + ) + + # Use high-performance CUDA kernel to populate kv_indices. 
+ create_kv_indices( + self.req_to_token, + req_pool_indices.to(torch.int32), + seq_lens.to(torch.int32), + kv_indptr_sliced, + None, + kv_indices, + ) + + decode_wrappers = decode_wrappers or self.decode_wrappers + decode_wrappers[0].begin_forward( + kv_indptr_sliced, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + data_type=self.data_type, + q_data_type=self.q_data_type, + non_blocking=True, + ) + + def _update_sliding_window( + self, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_cpu: Optional[torch.Tensor], + seq_lens_sum: int, + decode_wrappers: "List[BatchDecodeWithPagedKVCacheWrapper]", + ) -> None: + assert self.sliding_window_size is not None + for wrapper_id in range(2): + if wrapper_id == 0: + # Sliding-window attention: clamp to window size + 1 + paged_kernel_lens = torch.clamp( + seq_lens, max=self.sliding_window_size + 1 + ) + paged_kernel_lens_sum = int(paged_kernel_lens.sum().item()) + kv_start_idx = seq_lens - paged_kernel_lens + seq_lens_cpu_tmp = ( + torch.clamp(seq_lens_cpu, max=self.sliding_window_size + 1) + if seq_lens_cpu is not None + else None + ) + else: + # Full-context attention + paged_kernel_lens = seq_lens + paged_kernel_lens_sum = seq_lens_sum + kv_start_idx = None + seq_lens_cpu_tmp = seq_lens_cpu + + bs = len(req_pool_indices) + kv_indptr = self.kv_indptr[wrapper_id] + kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0) + kv_indptr_sliced = kv_indptr[: bs + 1] + + if decode_wrappers and decode_wrappers[wrapper_id].is_cuda_graph_enabled: + kv_indices = decode_wrappers[wrapper_id]._paged_kv_indices_buf + else: + kv_indices = torch.empty( + paged_kernel_lens_sum, + dtype=torch.int32, + device=self.req_to_token.device, + ) + + # High-performance CUDA kernel populates kv_indices from req_to_token. 
            create_kv_indices(
                self.req_to_token,
                req_pool_indices.to(torch.int32),
                paged_kernel_lens.to(torch.int32),
                kv_indptr_sliced,
                kv_start_idx.to(torch.int32) if kv_start_idx is not None else None,
                kv_indices,
            )

            decode_wrappers[wrapper_id].begin_forward(
                kv_indptr_sliced,
                kv_indices,
                self.kv_last_page_len[:bs],
                self.num_qo_heads,
                self.num_kv_heads,
                self.head_dim,
                1,
                data_type=self.data_type,
                q_data_type=self.q_data_type,
                non_blocking=True,
            )


# ---------------------------------------------------------------------------
# _FlashInferIndicesUpdaterPrefill
# ---------------------------------------------------------------------------


class _FlashInferIndicesUpdaterPrefill:
    """Populates indices and calls ``wrapper.begin_forward`` before extend."""

    def __init__(self, backend: FlashInferAttnBackend):
        # Cached backend geometry / dtypes used by every plan call.
        self.num_qo_heads = backend.num_heads
        self.num_kv_heads = backend.num_kv_heads
        self.head_dim = backend.head_dim
        self.data_type = backend.kv_cache_dtype
        self.q_data_type = backend.q_dtype
        self.sliding_window_size = backend.sliding_window_size
        self.backend = backend

        # Shared index buffers owned by the backend.
        self.kv_indptr = backend.kv_indptr
        self.kv_last_page_len = backend.kv_last_page_len
        self.qo_indptr = backend.qo_indptr
        self.req_to_token = backend.req_to_token
        self.prefill_wrapper_ragged = backend.prefill_wrapper_ragged

    def update(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]",
        use_ragged: bool,
    ) -> None:
        """Build indices and plan the prefill wrapper(s) for this batch."""
        if self.backend.dispatch_reason == WrapperDispatch.SLIDING_WINDOW:
            self._update_sliding_window(
                req_pool_indices,
                seq_lens,
                seq_lens_cpu,
                seq_lens_sum,
                prefix_lens,
                prefill_wrappers,
                use_ragged,
            )
        else:
            if use_ragged:
                # Merge path: paged covers ONLY the cached prefix so there
                # is no overlap with the ragged (extend) tokens.
                paged_kernel_lens = prefix_lens
                paged_kernel_lens_sum = int(paged_kernel_lens.sum().item())
            else:
                # Paged-only path: covers the full sequence.
                paged_kernel_lens = seq_lens
                paged_kernel_lens_sum = seq_lens_sum

            self._call_begin_forward(
                self.prefill_wrapper_ragged,
                prefill_wrappers[0],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx=None,
                kv_indptr=self.kv_indptr[0],
                qo_indptr=self.qo_indptr[0],
                use_ragged=use_ragged,
            )

    def _update_sliding_window(
        self,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
        seq_lens_sum: int,
        prefix_lens: Optional[torch.Tensor],
        prefill_wrappers: "List[BatchPrefillWithPagedKVCacheWrapper]",
        use_ragged: bool,
    ) -> None:
        """Plan both SWA wrappers: 0 = windowed context, 1 = full context."""
        assert self.sliding_window_size is not None
        for wrapper_id in range(2):
            if wrapper_id == 0:
                # Sliding-window portion uses a limited context window.
                extend_lens = seq_lens - prefix_lens
                paged_kernel_lens = torch.minimum(
                    seq_lens,
                    torch.tensor(self.sliding_window_size, device=seq_lens.device)
                    + extend_lens,
                )
                paged_kernel_lens_sum = int(paged_kernel_lens.sum().item())
                kv_start_idx = seq_lens - paged_kernel_lens
            else:
                # Full-context SWA wrapper: same split as non-SWA.
                if use_ragged:
                    paged_kernel_lens = prefix_lens
                    paged_kernel_lens_sum = int(paged_kernel_lens.sum().item())
                else:
                    paged_kernel_lens = seq_lens
                    paged_kernel_lens_sum = seq_lens_sum
                kv_start_idx = None

            kv_indptr = self.kv_indptr[wrapper_id]
            qo_indptr = self.qo_indptr[wrapper_id]

            self._call_begin_forward(
                self.prefill_wrapper_ragged,
                prefill_wrappers[wrapper_id],
                req_pool_indices,
                paged_kernel_lens,
                paged_kernel_lens_sum,
                seq_lens,
                prefix_lens,
                kv_start_idx=kv_start_idx,
                kv_indptr=kv_indptr,
                qo_indptr=qo_indptr,
                use_ragged=use_ragged,
            )

    def _call_begin_forward(
        self,
        wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper",
        wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper",
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        seq_lens: torch.Tensor,
        prefix_lens: Optional[torch.Tensor],
        kv_start_idx: Optional[torch.Tensor],
        kv_indptr: torch.Tensor,
        qo_indptr: torch.Tensor,
        use_ragged: bool,
    ) -> None:
        """Build kv/qo indptr + kv_indices, then plan ragged/paged wrappers."""
        bs = len(seq_lens)

        # Build kv_indptr and kv_indices using the CUDA kernel.
        kv_indptr_sliced = kv_indptr[: bs + 1]
        kv_indptr_sliced[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        # NOTE(review): +256 looks like deliberate slack for the index kernel;
        # confirm against create_kv_indices' contract.
        kv_indices = torch.empty(
            paged_kernel_lens_sum + 256,
            dtype=torch.int32,
            device=req_pool_indices.device,
        )

        create_kv_indices(
            self.req_to_token,
            req_pool_indices.to(torch.int32),
            paged_kernel_lens.to(torch.int32),
            kv_indptr_sliced,
            kv_start_idx.to(torch.int32) if kv_start_idx is not None else None,
            kv_indices,
        )

        # Build qo_indptr (number of new tokens per sequence).
        if prefix_lens is not None:
            extend_lens = seq_lens - prefix_lens
        else:
            extend_lens = seq_lens
        qo_indptr_sliced = qo_indptr[: bs + 1]
        qo_indptr_sliced[1:] = torch.cumsum(extend_lens, dim=0)

        # Plan the ragged wrapper (new tokens only).
+ if use_ragged: + wrapper_ragged.begin_forward( + qo_indptr_sliced, + qo_indptr_sliced, + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + q_data_type=self.q_data_type, + ) + + # Plan the paged wrapper (cached prefix tokens). + wrapper_paged.begin_forward( + qo_indptr_sliced, + kv_indptr_sliced, + kv_indices, + self.kv_last_page_len[:bs], + self.num_qo_heads, + self.num_kv_heads, + self.head_dim, + 1, + q_data_type=self.q_data_type, + kv_data_type=self.data_type, + non_blocking=True, + ) diff --git a/pymllm/backends/cuda/__init__.py b/pymllm/layers/attention/gdn.py similarity index 100% rename from pymllm/backends/cuda/__init__.py rename to pymllm/layers/attention/gdn.py diff --git a/pymllm/layers/attention/gdn_backend.py b/pymllm/layers/attention/gdn_backend.py new file mode 100644 index 000000000..2b6e27b48 --- /dev/null +++ b/pymllm/layers/attention/gdn_backend.py @@ -0,0 +1,660 @@ +"""GDN attention backend -- pooled-state GDN computation for hybrid models. + +Performs GDN (Gated Delta Net) linear-attention using externalized state +stored in a :class:`~pymllm.mem_cache.memory_pool.GDNPool`. Supports +both extend (prefill) and decode paths with FlashInfer kernels. + +This backend is not used directly; it is wrapped by +:class:`~pymllm.layers.attention.hybrid_backend.HybridAttnBackend`. 
+""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn.functional as F + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + from pymllm.mem_cache.memory_pool import GDNPool + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Server config: gdn_decode_backend override +# --------------------------------------------------------------------------- + + +def _get_gdn_decode_backend_override() -> str: + """Read ``server.gdn_decode_backend`` from GlobalConfig. + + Returns one of: ``"auto"``, ``"flashinfer"``, ``"mllm_kernel"``, ``"pytorch"``. + """ + try: + from pymllm.configs import get_global_config + return get_global_config().server.gdn_decode_backend + except Exception: + return "auto" + + +# --------------------------------------------------------------------------- +# mllm-kernel GDN decode (lazy import, SM80+) +# --------------------------------------------------------------------------- + +_mllm_gdn_decode = None + + +def _get_mllm_gdn_decode(): + """Lazy import for mllm-kernel fused GDN decode CUDA kernel.""" + global _mllm_gdn_decode + if _mllm_gdn_decode is None: + try: + from mllm_kernel.cuda.jit.gdn_decode import gdn_decode + + _mllm_gdn_decode = gdn_decode + logger.info("GDNAttnBackend: [probe] mllm-kernel GDN decode available (SM80+)") + except (ImportError, RuntimeError) as e: + logger.info("GDNAttnBackend: [probe] mllm-kernel GDN decode not available: %s", e) + _mllm_gdn_decode = False + return _mllm_gdn_decode if _mllm_gdn_decode is not False else None + + +# --------------------------------------------------------------------------- +# FlashInfer GDN kernel (lazy import) +# 
--------------------------------------------------------------------------- + +_flashinfer_available: Optional[bool] = None +_fi_chunk_gated_delta_rule = None +_fi_gated_delta_rule_decode = None + + +def _get_flashinfer_gdn(): + """Lazy import for FlashInfer GDN kernels (prefill + decode).""" + global _flashinfer_available, _fi_chunk_gated_delta_rule, _fi_gated_delta_rule_decode + if _flashinfer_available is None: + try: + os.environ.setdefault("FLASHINFER_DISABLE_VERSION_CHECK", "1") + _flashinfer_available = ( + torch.cuda.is_available() + and torch.cuda.get_device_capability()[0] >= 9 + ) + if not _flashinfer_available: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN not available (requires SM90+, " + "current SM%d%d)", *torch.cuda.get_device_capability() + ) + return _flashinfer_available, None, None + + from flashinfer.gdn_prefill import chunk_gated_delta_rule + _fi_chunk_gated_delta_rule = chunk_gated_delta_rule + + try: + from flashinfer.gdn_decode import gated_delta_rule_decode_pretranspose + _fi_gated_delta_rule_decode = gated_delta_rule_decode_pretranspose + logger.info("GDNAttnBackend: [probe] FlashInfer GDN available (prefill + decode)") + except ImportError: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN partially available " + "(prefill only, decode not found)" + ) + except (ImportError, RuntimeError) as e: + logger.info( + "GDNAttnBackend: [probe] FlashInfer GDN not available: %s", e + ) + _flashinfer_available = False + return _flashinfer_available, _fi_chunk_gated_delta_rule, _fi_gated_delta_rule_decode + + +# --------------------------------------------------------------------------- +# GDN gating computation +# --------------------------------------------------------------------------- + + +def _gdn_gating( + a: torch.Tensor, + b: torch.Tensor, + A_log: torch.Tensor, + dt_bias: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute GDN gating factors. 

    Returns
    -------
    g : log-space decay factor: -exp(A_log) * softplus(a + dt_bias)
    beta : update gate: sigmoid(b)
    """
    g = -torch.exp(A_log) * F.softplus(a + dt_bias)
    beta = torch.sigmoid(b)
    return g, beta


# ---------------------------------------------------------------------------
# Forward metadata
# ---------------------------------------------------------------------------


@dataclass
class GDNForwardMetadata:
    """Per-batch metadata for GDN backend."""

    cache_indices: torch.Tensor  # [batch_size] = req_pool_indices
    cu_seqlens: Optional[torch.Tensor] = None  # extend only


# ---------------------------------------------------------------------------
# GDNAttnBackend
# ---------------------------------------------------------------------------


class GDNAttnBackend:
    """GDN linear-attention backend using pooled states.

    Handles both extend (prefill) and decode paths for GDN layers.
    Uses FlashInfer kernels when available (SM90+), with PyTorch fallback.

    Parameters
    ----------
    gdn_pool
        Pre-allocated :class:`~pymllm.mem_cache.memory_pool.GDNPool`.
    device
        Target device.
    """

    def __init__(self, gdn_pool: "GDNPool", device: torch.device):
        self.gdn_pool = gdn_pool
        self.device = device
        self.forward_metadata: Optional[GDNForwardMetadata] = None

        # Pre-check FlashInfer availability
        self._use_flashinfer, _, _ = _get_flashinfer_gdn()

        # One-shot flags to log the selected backend on first actual forward call
        self._decode_backend_logged = False
        self._extend_backend_logged = False

    def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None:
        """Prepare GDN metadata from the current forward batch."""
        cache_indices = forward_batch.req_pool_indices.to(torch.int64)

        cu_seqlens = None
        if forward_batch.forward_mode.is_extend():
            # Build cu_seqlens from extend_seq_lens
            if forward_batch.extend_seq_lens is not None:
                seq_lens = forward_batch.extend_seq_lens.to(torch.int64)
                cu_seqlens = torch.zeros(
                    len(seq_lens) + 1,
                    dtype=torch.int64,
                    device=self.device,
                )
                torch.cumsum(seq_lens, dim=0, out=cu_seqlens[1:])

        self.forward_metadata = GDNForwardMetadata(
            cache_indices=cache_indices,
            cu_seqlens=cu_seqlens,
        )

    # ------------------------------------------------------------------
    # CUDA-graph interface
    # ------------------------------------------------------------------

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        """Allocate CUDA-graph state for GDN backend.

        The GDN pool buffers are already pre-allocated at fixed addresses,
        so we only need to allocate the metadata tensor.
        """
        self._cuda_graph_cache_indices = torch.zeros(
            (max_bs,), dtype=torch.int64, device=self.device
        )

    def init_forward_metadata_capture_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """Set up GDN metadata for CUDA-graph capture (decode only)."""
        self._cuda_graph_cache_indices[:bs].copy_(
            req_pool_indices[:bs].to(torch.int64)
        )
        self.forward_metadata = GDNForwardMetadata(
            cache_indices=self._cuda_graph_cache_indices[:bs],
        )

    def init_forward_metadata_replay_cuda_graph(
        self,
        bs: int,
        req_pool_indices: torch.Tensor,
        seq_lens: torch.Tensor,
    ) -> None:
        """Update GDN metadata for CUDA-graph replay (decode only)."""
        self._cuda_graph_cache_indices[:bs].copy_(
            req_pool_indices[:bs].to(torch.int64)
        )
        self.forward_metadata = GDNForwardMetadata(
            cache_indices=self._cuda_graph_cache_indices[:bs],
        )

    # ------------------------------------------------------------------
    # Forward: decode
    # ------------------------------------------------------------------

    def forward_decode(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """GDN decode: one new token per request.

        Steps:
        1. Gather conv_state from pool → [bs, conv_dim, K-1]
        2. Conv1d update: shift + weighted sum for 1 new token
        3. Scatter updated conv_state back to pool
        4. SiLU → split q,k,v
        5. 
        FlashInfer gated_delta_rule_decode (or PyTorch fallback)
        """
        metadata = self.forward_metadata
        cache_indices = metadata.cache_indices
        gdn_idx = layer.gdn_layer_idx
        bs = mixed_qkv.shape[0]

        recurrent_buf, conv_buf = self.gdn_pool.get_layer_state(gdn_idx)
        conv_weight = layer.conv_weight  # [conv_dim, kernel_size]
        K = conv_weight.shape[1]

        # --- Conv1d decode: single-token update ---
        conv_state = conv_buf[cache_indices]  # [bs, conv_dim, K-1]
        x = mixed_qkv.unsqueeze(-1)  # [bs, conv_dim, 1]

        # Shift the window left by one and append the new token.
        new_conv_state = torch.cat([conv_state[:, :, 1:], x], dim=-1)
        full_window = torch.cat([conv_state, x], dim=-1)  # [bs, conv_dim, K]
        conv_out = (full_window * conv_weight.unsqueeze(0)).sum(dim=-1)

        conv_buf[cache_indices] = new_conv_state

        # --- SiLU activation ---
        conv_out = F.silu(conv_out)

        # --- Split q, k, v ---
        key_dim = layer.num_k_heads * layer.head_k_dim
        value_dim = layer.num_v_heads * layer.head_v_dim
        q, k, v = conv_out.split([key_dim, key_dim, value_dim], dim=-1)
        q = q.view(bs, layer.num_k_heads, layer.head_k_dim)
        k = k.view(bs, layer.num_k_heads, layer.head_k_dim)
        v = v.view(bs, layer.num_v_heads, layer.head_v_dim)

        # --- Recurrent update ---
        # Priority (when "auto"): FlashInfer SM90+ > mllm-kernel SM80+ > PyTorch
        # Can be overridden via --server.gdn_decode_backend
        backend = _get_gdn_decode_backend_override()
        use_fi, _, fi_decode = _get_flashinfer_gdn()
        mllm_gdn = _get_mllm_gdn_decode()

        use_flashinfer = (
            (backend in ("auto", "flashinfer"))
            and use_fi and fi_decode is not None
            and mixed_qkv.is_cuda
        )
        use_mllm = (
            (backend in ("auto", "mllm_kernel"))
            and not (backend == "auto" and use_flashinfer)
            and mllm_gdn is not None
            and mixed_qkv.is_cuda
        )

        if backend == "flashinfer" and not use_flashinfer:
            logger.warning("GDNAttnBackend: gdn_decode_backend='flashinfer' requested but unavailable, falling back")
        if backend == "mllm_kernel" and mllm_gdn is None:
            logger.warning("GDNAttnBackend: gdn_decode_backend='mllm_kernel' requested but unavailable, falling back")

        if not self._decode_backend_logged:
            if use_flashinfer:
                selected = "flashinfer"
            elif use_mllm:
                selected = "mllm_kernel"
            else:
                selected = "pytorch"
            logger.info(
                "GDNAttnBackend: [decode] using backend=%s (config=%s)", selected, backend
            )
            self._decode_backend_logged = True

        if use_flashinfer:
            # FlashInfer decode (SM90+)
            query_fi = q.unsqueeze(1)
            key_fi = k.unsqueeze(1)
            value_fi = v.unsqueeze(1)
            a_fi = a.unsqueeze(1)
            b_fi = b.unsqueeze(1)

            state_batch = recurrent_buf[cache_indices]

            # The decode kernel normalizes q/k internally (use_qk_l2norm=True)
            # and applies the gating from A_log / dt_bias itself.
            output_fi, new_state = fi_decode(
                q=query_fi, k=key_fi, v=value_fi,
                state=state_batch,
                A_log=layer.A_log.detach(),
                a=a_fi, dt_bias=layer.dt_bias.detach(), b=b_fi,
                scale=None, output=None, use_qk_l2norm=True,
            )

            recurrent_buf[cache_indices] = new_state
            output = output_fi.squeeze(1)

        elif use_mllm:
            # mllm-kernel fused CUDA decode (SM80+)
            output = mllm_gdn(
                q, k, v, a, b,
                layer.A_log, layer.dt_bias,
                recurrent_buf, cache_indices,
            )

        else:
            # PyTorch fallback
            g, beta = _gdn_gating(a, b, layer.A_log, layer.dt_bias)
            output = self._decode_pytorch_fallback(
                q, k, v, g, beta, recurrent_buf, cache_indices, layer
            )

        return output.reshape(bs, value_dim)

    def _decode_pytorch_fallback(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        recurrent_buf: torch.Tensor,
        cache_indices: torch.Tensor,
        layer: "RadixLinearAttention",
    ) -> torch.Tensor:
        """Pure PyTorch decode fallback for GDN with delta rule and L2 norm.

        Matches the sglang Triton kernel (fused_sigmoid_gating_delta_rule_update):
        state *= exp(g)  # decay
        v_delta = v - state @ k  # delta rule
        v_delta *= beta  # gating
        state += v_delta outer k  # state update
        output = state @ q  # readout
        """
        bs = q.shape[0]
        num_v_heads = layer.num_v_heads
        num_k_heads = layer.num_k_heads

        # GQA: expand k/q heads to match v heads
        if num_k_heads != num_v_heads:
            repeats = num_v_heads // num_k_heads
            q = q.repeat_interleave(repeats, dim=1)
            k = k.repeat_interleave(repeats, dim=1)

        # All computation in float32 (state is float32, avoids dtype mismatch)
        orig_dtype = q.dtype
        q = q.float()
        k = k.float()
        v = v.float()

        # L2 normalize q and k per-head (matching use_qk_l2norm_in_kernel=True)
        q = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
        k = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

        decay = torch.exp(g.float())  # [bs, num_v_heads]
        beta_f = beta.float()  # [bs, num_v_heads]

        outputs = []
        for i in range(bs):
            idx = cache_indices[i]
            state = recurrent_buf[idx]  # [H, V, K] float32

            # Decay
            state = state * decay[i].unsqueeze(-1).unsqueeze(-1)

            k_i = k[i]  # [H, K]
            v_i = v[i]  # [H, V]
            b_i = beta_f[i]  # [H]
            q_i = q[i]  # [H, K]

            # Delta rule: v_delta = v - state @ k
            v_delta = v_i - torch.bmm(state, k_i.unsqueeze(-1)).squeeze(-1)
            v_delta = v_delta * b_i.unsqueeze(-1)  # gating

            # State update: state += v_delta ⊗ k (outer product in [V, K] layout)
            state = state + v_delta.unsqueeze(-1) * k_i.unsqueeze(-2)
            recurrent_buf[idx] = state

            # Output: o = state @ q
            o_t = torch.bmm(state, q_i.unsqueeze(-1)).squeeze(-1)  # [H, V]
            outputs.append(o_t)

        return torch.stack(outputs, dim=0).to(orig_dtype)  # [bs, H, V]

    # ------------------------------------------------------------------
    # Forward: extend (prefill)
    # ------------------------------------------------------------------

    def forward_extend(
        self,
        layer: "RadixLinearAttention",
        forward_batch: "ForwardBatch",
        mixed_qkv: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        """GDN extend (prefill): multi-token per request.

        Steps:
        1. Gather conv_state from pool for each request
        2. Per-request causal conv1d
        3. Scatter new conv_state back to pool
        4. SiLU → split q,k,v → gating
        5. FlashInfer chunk_gated_delta_rule (or PyTorch fallback)
        6. Scatter final recurrent state back to pool
        """
        metadata = self.forward_metadata
        cache_indices = metadata.cache_indices
        cu_seqlens = metadata.cu_seqlens
        gdn_idx = layer.gdn_layer_idx
        total_tokens = mixed_qkv.shape[0]

        recurrent_buf, conv_buf = self.gdn_pool.get_layer_state(gdn_idx)
        conv_weight = layer.conv_weight  # [conv_dim, kernel_size]
        K = conv_weight.shape[1]
        batch_size = cache_indices.shape[0]

        key_dim = layer.num_k_heads * layer.head_k_dim
        value_dim = layer.num_v_heads * layer.head_v_dim

        # --- Per-request causal conv1d ---
        conv_out = torch.empty_like(mixed_qkv)  # [total_tokens, conv_dim]

        for i in range(batch_size):
            start = int(cu_seqlens[i].item())
            end = int(cu_seqlens[i + 1].item())
            seq_len = end - start
            if seq_len == 0:
                continue

            idx = cache_indices[i]
            x = mixed_qkv[start:end]  # [seq_len, conv_dim]
            prev_state = conv_buf[idx]  # [conv_dim, K-1]

            # Pad with previous conv state
            x_padded = torch.cat([prev_state.T, x], dim=0)  # [K-1+seq_len, conv_dim]

            # Save new conv state (last K-1 tokens)
            conv_buf[idx] = x_padded[-(K - 1):].T.clone()

            # Causal conv1d
            out = torch.zeros(seq_len, x.shape[1], device=x.device, dtype=x.dtype)
            for kk in range(K):
                out += x_padded[kk: kk + seq_len] * conv_weight[:, kk]
            conv_out[start:end] = out

        # --- SiLU activation ---
        conv_out = F.silu(conv_out)

        # --- Split q, k, v ---
        q, k, v = conv_out.split([key_dim, key_dim, value_dim], dim=-1)
        q = q.view(total_tokens, layer.num_k_heads, layer.head_k_dim)
        k = k.view(total_tokens, layer.num_k_heads, layer.head_k_dim)
        v = v.view(total_tokens, layer.num_v_heads, layer.head_v_dim)

        # --- GDN gating ---
        g, beta = _gdn_gating(a, b, layer.A_log, layer.dt_bias)

        # --- Recurrent computation ---
        use_fi, fi_prefill, _ = _get_flashinfer_gdn()
        use_fi_extend = use_fi and fi_prefill is not None and mixed_qkv.is_cuda

        if not self._extend_backend_logged:
            logger.info(
                "GDNAttnBackend: [extend] using backend=%s",
                "flashinfer" if use_fi_extend else "pytorch",
            )
            self._extend_backend_logged = True

        if use_fi_extend:
            # Gather initial states for this batch
            init_state = recurrent_buf[cache_indices].to(torch.float32)
            # [batch_size, num_v_heads, head_v_dim, head_k_dim]

            # NOTE(review): exp(g) is passed as the kernel's `g` argument —
            # verify chunk_gated_delta_rule expects a linear-space decay here
            # (the PyTorch fallback uses g in log space).
            alpha = torch.exp(g.to(torch.float32))
            beta_f32 = beta.to(torch.float32)

            # FlashInfer's use_qk_l2norm_in_kernel is silently ignored —
            # the flag is declared in the Python wrapper but never forwarded
            # to the CUDA kernel. Pre-normalize q and k here, matching
            # sglang's approach (l2norm_fwd before calling with False).
            q_fi = q / (q.norm(dim=-1, keepdim=True) + 1e-6)
            k_fi = k / (k.norm(dim=-1, keepdim=True) + 1e-6)

            output, final_state = fi_prefill(
                q=q_fi.contiguous(),
                k=k_fi.contiguous(),
                v=v.contiguous(),
                g=alpha,
                beta=beta_f32,
                initial_state=init_state,
                output_final_state=True,
                cu_seqlens=cu_seqlens,
                use_qk_l2norm_in_kernel=False,
            )

            # Scatter final states back to pool
            recurrent_buf[cache_indices] = final_state.to(recurrent_buf.dtype)
        else:
            # PyTorch fallback: per-request sequential scan
            output = self._extend_pytorch_fallback(
                q, k, v, g, beta, recurrent_buf, cache_indices, cu_seqlens, layer
            )

        return output.reshape(total_tokens, value_dim)

    def _extend_pytorch_fallback(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        g: torch.Tensor,
        beta: torch.Tensor,
        recurrent_buf: torch.Tensor,
        cache_indices: torch.Tensor,
        cu_seqlens: torch.Tensor,
        layer: "RadixLinearAttention",
    ) -> torch.Tensor:
        """Pure PyTorch extend fallback for GDN with 
delta rule and L2 norm.""" + total_tokens = q.shape[0] + num_v_heads = layer.num_v_heads + num_k_heads = layer.num_k_heads + head_v_dim = layer.head_v_dim + batch_size = cache_indices.shape[0] + + # All computation in float32 + orig_dtype = q.dtype + q = q.float() + k = k.float() + v = v.float() + + # L2 normalize q and k per-head + q = q / (q.norm(dim=-1, keepdim=True) + 1e-6) + k = k / (k.norm(dim=-1, keepdim=True) + 1e-6) + + # GQA expansion + if num_k_heads != num_v_heads: + repeats = num_v_heads // num_k_heads + q = q.repeat_interleave(repeats, dim=1) + k = k.repeat_interleave(repeats, dim=1) + + output = torch.zeros( + total_tokens, num_v_heads, head_v_dim, + device=q.device, dtype=torch.float32, + ) + + for i in range(batch_size): + start = int(cu_seqlens[i].item()) + end = int(cu_seqlens[i + 1].item()) + seq_len = end - start + if seq_len == 0: + continue + + idx = cache_indices[i] + q_seq = q[start:end] + k_seq = k[start:end] + v_seq = v[start:end] + g_seq = g[start:end] + beta_seq = beta[start:end] + + decay = torch.exp(g_seq.float()) # [seq_len, H] + beta_f = beta_seq.float() # [seq_len, H] + state = recurrent_buf[idx].clone() # [H, V, K] float32 + + seq_outputs = [] + for t in range(seq_len): + # Decay + state = state * decay[t].unsqueeze(-1).unsqueeze(-1) + + k_t = k_seq[t] # [H, K] + v_t = v_seq[t] # [H, V] + b_t = beta_f[t] # [H] + q_t = q_seq[t] # [H, K] + + # Delta rule: v_delta = v - state @ k + v_delta = v_t - torch.bmm(state, k_t.unsqueeze(-1)).squeeze(-1) + v_delta = v_delta * b_t.unsqueeze(-1) + + # State update + state = state + v_delta.unsqueeze(-1) * k_t.unsqueeze(-2) + + # Output + o_t = torch.bmm(state, q_t.unsqueeze(-1)).squeeze(-1) + seq_outputs.append(o_t) + + recurrent_buf[idx] = state + output[start:end] = torch.stack(seq_outputs, dim=0) + + return output.to(orig_dtype) + + # ------------------------------------------------------------------ + # Dispatch entry point + # 
------------------------------------------------------------------ + + def forward_gdn( + self, + layer: "RadixLinearAttention", + forward_batch: "ForwardBatch", + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + """Route to decode or extend based on forward mode.""" + if forward_batch.forward_mode.is_decode(): + return self.forward_decode(layer, forward_batch, mixed_qkv, a, b) + else: + return self.forward_extend(layer, forward_batch, mixed_qkv, a, b) diff --git a/pymllm/layers/attention/hybrid_backend.py b/pymllm/layers/attention/hybrid_backend.py new file mode 100644 index 000000000..a5628259e --- /dev/null +++ b/pymllm/layers/attention/hybrid_backend.py @@ -0,0 +1,184 @@ +"""Hybrid attention backend -- FlashInfer + GDN for hybrid architectures. + +Wraps a :class:`FlashInferAttnBackend` (for full-attention layers) and a +:class:`GDNAttnBackend` (for GDN linear-attention layers). Dispatches +based on layer type: + +* ``RadixAttention`` calls → delegated to ``full_attn_backend`` +* ``RadixLinearAttention`` calls (via ``forward_gdn``) → delegated to ``gdn_backend`` + +CUDA-graph compatible: delegates all graph lifecycle methods to both +sub-backends. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Optional, Set + +import torch + +from pymllm.layers.attention.attention_backend import AttentionBackend + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch, ForwardMode + from pymllm.layers.attention.flashinfer_backend import FlashInferAttnBackend + from pymllm.layers.attention.gdn_backend import GDNAttnBackend + from pymllm.layers.attention.radix_attention import RadixAttention + from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + +logger = logging.getLogger(__name__) + + +class HybridAttnBackend(AttentionBackend): + """Composite attention backend for hybrid full-attention + GDN models. 
+ + Parameters + ---------- + full_attn_backend + FlashInfer backend for standard transformer attention layers. + gdn_backend + GDN backend for linear-attention layers. + full_attn_layer_ids + Set of global layer IDs that use full attention (for logging). + """ + + def __init__( + self, + full_attn_backend: "FlashInferAttnBackend", + gdn_backend: "GDNAttnBackend", + full_attn_layer_ids: Set[int], + ): + self.full_attn_backend = full_attn_backend + self.gdn_backend = gdn_backend + self.full_attn_layer_ids = full_attn_layer_ids + + logger.info( + "HybridAttnBackend created: %d full-attn layers, " + "%d GDN layers", + len(full_attn_layer_ids), + gdn_backend.gdn_pool.num_gdn_layers, + ) + + # ------------------------------------------------------------------ + # Core interface: init_forward_metadata + # ------------------------------------------------------------------ + + def init_forward_metadata(self, forward_batch: "ForwardBatch") -> None: + """Initialize metadata for both sub-backends.""" + self.full_attn_backend.init_forward_metadata(forward_batch) + self.gdn_backend.init_forward_metadata(forward_batch) + + # ------------------------------------------------------------------ + # Full attention: forward_decode / forward_extend + # ------------------------------------------------------------------ + + def forward_decode( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Delegate full-attention decode to FlashInfer backend.""" + return self.full_attn_backend.forward_decode( + q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs + ) + + def forward_extend( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + layer: "RadixAttention", + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Delegate full-attention 
extend to FlashInfer backend.""" + return self.full_attn_backend.forward_extend( + q, k, v, layer, forward_batch, save_kv_cache=save_kv_cache, **kwargs + ) + + # ------------------------------------------------------------------ + # GDN linear attention: forward_gdn + # ------------------------------------------------------------------ + + def forward_gdn( + self, + layer: "RadixLinearAttention", + forward_batch: "ForwardBatch", + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + """Delegate GDN computation to the GDN backend.""" + return self.gdn_backend.forward_gdn( + layer=layer, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + # ------------------------------------------------------------------ + # CUDA-graph interface: delegate to both sub-backends + # ------------------------------------------------------------------ + + def get_cuda_graph_seq_len_fill_value(self) -> int: + """Delegate to the full-attention backend.""" + return self.full_attn_backend.get_cuda_graph_seq_len_fill_value() + + def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None: + """Allocate CUDA-graph state for both sub-backends.""" + self.full_attn_backend.init_cuda_graph_state(max_bs, max_num_tokens) + self.gdn_backend.init_cuda_graph_state(max_bs, max_num_tokens) + + def init_forward_metadata_capture_cuda_graph( + self, + bs: int, + num_tokens: int, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + forward_mode: "ForwardMode", + ) -> None: + """Set up metadata for CUDA-graph capture in both sub-backends.""" + self.full_attn_backend.init_forward_metadata_capture_cuda_graph( + bs=bs, + num_tokens=num_tokens, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + forward_mode=forward_mode, + ) + self.gdn_backend.init_forward_metadata_capture_cuda_graph( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) + + def init_forward_metadata_replay_cuda_graph( + self, + bs: int, + 
req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + seq_lens_sum: int, + forward_mode: "ForwardMode", + seq_lens_cpu: Optional[torch.Tensor], + ) -> None: + """Update metadata for CUDA-graph replay in both sub-backends.""" + self.full_attn_backend.init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + seq_lens_sum=seq_lens_sum, + forward_mode=forward_mode, + seq_lens_cpu=seq_lens_cpu, + ) + self.gdn_backend.init_forward_metadata_replay_cuda_graph( + bs=bs, + req_pool_indices=req_pool_indices, + seq_lens=seq_lens, + ) diff --git a/pymllm/layers/attention/radix_attention.py b/pymllm/layers/attention/radix_attention.py new file mode 100644 index 000000000..114130dbf --- /dev/null +++ b/pymllm/layers/attention/radix_attention.py @@ -0,0 +1,171 @@ +"""RadixAttention -- the attention layer used by pymllm models. + +This module is kept small intentionally: all heavy computation is delegated +to the pluggable ``AttentionBackend`` that is attached to the ``ForwardBatch``. +""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, Optional + +import torch +from torch import nn + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + + +# --------------------------------------------------------------------------- +# AttentionType +# --------------------------------------------------------------------------- + + +class AttentionType(Enum): + """Attention variant used by a :class:`RadixAttention` layer. + + Uses string values so that ``torch.compile`` can treat them as constants. + """ + + # Standard causal self-attention in a decoder layer. + DECODER = "decoder" + + # Bidirectional self-attention for image tokens inside a decoder + # (e.g. VLM visual encoder embedded in the language model). + DECODER_BIDIRECTIONAL = "decoder_bidirectional" + + # Full bidirectional self-attention in an encoder-only model. 
+ ENCODER_ONLY = "encoder_only" + + +# --------------------------------------------------------------------------- +# RadixAttention +# --------------------------------------------------------------------------- + + +class RadixAttention(nn.Module): + """Attention layer that delegates computation to a pluggable backend. + + Each transformer attention layer in a pymllm model creates exactly one + ``RadixAttention`` with a unique ``layer_id``. During the forward pass + the layer looks up the correct KV buffer via ``layer_id`` and calls the + backend attached to the current :class:`~pymllm.engine.forward_batch.ForwardBatch`. + + Parameters + ---------- + num_heads + Number of query attention heads (after any tensor-parallelism + sharding; pass the full count if not using TP). + head_dim + Per-head dimension for query and key projections. + scaling + Softmax pre-scale, typically ``1 / sqrt(head_dim)``. + num_kv_heads + Number of key / value heads (supports GQA / MQA). + layer_id + Zero-based index of this layer within the model. Used to index into + ``KVPool.k_buffer`` / ``v_buffer``. + logit_cap + If > 0, attention logits are soft-capped to this value via a ``tanh`` + gate (used by Gemma2 / Gemma3 style models). Set to ``0.0`` to + disable. + v_head_dim + Per-head dimension of the value projection. Defaults to ``head_dim`` + (i.e. standard square QKV). + sliding_window_size + Sliding-window attention span. ``-1`` means full context (no window). + is_cross_attention + ``True`` for cross-attention layers in encoder-decoder models. + attn_type + One of :class:`AttentionType`. 
+ """ + + def __init__( + self, + num_heads: int, + head_dim: int, + scaling: float, + num_kv_heads: int, + layer_id: int, + logit_cap: float = 0.0, + v_head_dim: int = -1, + sliding_window_size: int = -1, + is_cross_attention: bool = False, + attn_type: AttentionType = AttentionType.DECODER, + ): + super().__init__() + + self.tp_q_head_num: int = num_heads + self.tp_k_head_num: int = num_kv_heads + self.tp_v_head_num: int = num_kv_heads + + self.head_dim: int = head_dim + self.qk_head_dim: int = head_dim + self.v_head_dim: int = v_head_dim if v_head_dim != -1 else head_dim + + self.scaling: float = scaling + self.layer_id: int = layer_id + self.logit_cap: float = logit_cap + self.sliding_window_size: int = ( + sliding_window_size if sliding_window_size is not None else -1 + ) + self.is_cross_attention: bool = is_cross_attention + self.attn_type: AttentionType = attn_type + + # ------------------------------------------------------------------ + # forward + # ------------------------------------------------------------------ + + def forward( + self, + q: torch.Tensor, + k: Optional[torch.Tensor], + v: Optional[torch.Tensor], + forward_batch: "ForwardBatch", + save_kv_cache: bool = True, + **kwargs, + ) -> torch.Tensor: + """Run attention for one batch. + + Parameters + ---------- + q + Query tensor, shape ``[num_tokens, tp_q_head_num * head_dim]`` + (or already reshaped to ``[num_tokens, tp_q_head_num, head_dim]``). + k + Key tensor, same leading dimension as ``q``, shape + ``[num_tokens, tp_k_head_num * qk_head_dim]``. + Pass ``None`` for cross-layer KV sharing (``v`` must also be + ``None`` in this case). + v + Value tensor, shape + ``[num_tokens, tp_v_head_num * v_head_dim]``. + forward_batch + Batch metadata and references to memory pools / backend. + save_kv_cache + When ``False``, skip writing K/V into the pool (useful for draft + models in speculative decoding). + **kwargs + Passed through to the backend (e.g. ``q_rope``, ``k_rope``). 
+ """ + if k is not None: + assert v is not None, "k and v must both be provided or both be None" + k = k.view(-1, self.tp_k_head_num, self.qk_head_dim) + v = v.view(-1, self.tp_v_head_num, self.v_head_dim) + + return forward_batch.attn_backend.forward( + q, k, v, self, forward_batch, save_kv_cache, **kwargs + ) + + def extra_repr(self) -> str: + return ( + f"layer_id={self.layer_id}, " + f"q_heads={self.tp_q_head_num}, " + f"kv_heads={self.tp_k_head_num}, " + f"head_dim={self.head_dim}, " + f"v_head_dim={self.v_head_dim}, " + f"scaling={self.scaling:.4f}, " + f"logit_cap={self.logit_cap}, " + f"sliding_window={self.sliding_window_size}, " + f"attn_type={self.attn_type.value}" + ) diff --git a/pymllm/layers/attention/radix_linear_attention.py b/pymllm/layers/attention/radix_linear_attention.py new file mode 100644 index 000000000..01993163d --- /dev/null +++ b/pymllm/layers/attention/radix_linear_attention.py @@ -0,0 +1,116 @@ +"""RadixLinearAttention -- GDN linear-attention layer for hybrid models. + +Analogous to :class:`RadixAttention` but for GDN (Gated Delta Net) layers. +Stores per-layer GDN parameters and delegates computation to the +:meth:`AttentionBackend.forward_gdn` method on the current +:class:`~pymllm.engine.forward_batch.ForwardBatch`. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import torch +from torch import nn + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + + +class RadixLinearAttention(nn.Module): + """GDN linear-attention layer that delegates to the attention backend. + + Each GDN layer in a pymllm model creates one ``RadixLinearAttention`` + with a unique ``layer_id`` and ``gdn_layer_idx``. During forward, it + calls ``forward_batch.attn_backend.forward_gdn(...)`` which routes to + the appropriate GDN backend implementation. + + Parameters + ---------- + layer_id : int + Global zero-based layer index within the model. 
+ gdn_layer_idx : int + Sequential zero-based index among GDN layers only (not global). + Used to index into :class:`~pymllm.mem_cache.memory_pool.GDNPool`. + num_k_heads : int + Number of key heads. + num_v_heads : int + Number of value heads. + head_k_dim : int + Per-head key dimension. + head_v_dim : int + Per-head value dimension. + conv_weight : nn.Parameter + Reference to the GDNConv1d weight parameter. + A_log : nn.Parameter + Log-space decay parameter. + dt_bias : nn.Parameter + Bias for the decay gate. + """ + + def __init__( + self, + layer_id: int, + gdn_layer_idx: int, + num_k_heads: int, + num_v_heads: int, + head_k_dim: int, + head_v_dim: int, + conv_weight: nn.Parameter, + A_log: nn.Parameter, + dt_bias: nn.Parameter, + ): + super().__init__() + self.layer_id = layer_id + self.gdn_layer_idx = gdn_layer_idx + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + # Store references to model parameters (not copies) + self.conv_weight = conv_weight + self.A_log = A_log + self.dt_bias = dt_bias + + def forward( + self, + forward_batch: "ForwardBatch", + mixed_qkv: torch.Tensor, + a: torch.Tensor, + b: torch.Tensor, + ) -> torch.Tensor: + """Delegate GDN computation to the attention backend. + + Parameters + ---------- + forward_batch + Batch metadata with ``attn_backend`` attached. + mixed_qkv + Concatenated Q/K/V projection output before conv1d. + a + Decay gate input, shape ``[num_tokens, num_v_heads]``. + b + Update gate input, shape ``[num_tokens, num_v_heads]``. + + Returns + ------- + torch.Tensor + GDN attention output, shape ``[num_tokens, num_v_heads * head_v_dim]``. 
+ """ + return forward_batch.attn_backend.forward_gdn( + layer=self, + forward_batch=forward_batch, + mixed_qkv=mixed_qkv, + a=a, + b=b, + ) + + def extra_repr(self) -> str: + return ( + f"layer_id={self.layer_id}, " + f"gdn_layer_idx={self.gdn_layer_idx}, " + f"k_heads={self.num_k_heads}, " + f"v_heads={self.num_v_heads}, " + f"k_dim={self.head_k_dim}, " + f"v_dim={self.head_v_dim}" + ) diff --git a/pymllm/layers/base.py b/pymllm/layers/base.py new file mode 100644 index 000000000..3e762ae5a --- /dev/null +++ b/pymllm/layers/base.py @@ -0,0 +1,30 @@ +import torch +from torch import nn +from torch.nn import Parameter +from pymllm.layers.utils import set_weight_attrs +from typing import Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from pymllm.layers.quantize_base import QuantizeMethodBase + + +class MllmBaseLayer(nn.Module): + def __init__(self): + super().__init__() + self.quant_method: Optional["QuantizeMethodBase"] = None + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load weights into a parameter. + + This is the default implementation that directly copies the loaded weight + into the parameter. Subclasses should override this method to implement + custom loading logic (e.g., tensor parallelism sharding). + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint. 
+ """ + param.data.copy_(loaded_weight) + + def forward(self, *args, **kwargs): + raise NotImplementedError("Subclasses must implement forward method") diff --git a/pymllm/backends/qualcomm/transformers/core/__init__.py b/pymllm/layers/custom_event.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/__init__.py rename to pymllm/layers/custom_event.py diff --git a/pymllm/layers/embedding.py b/pymllm/layers/embedding.py new file mode 100644 index 000000000..ec99c5b2d --- /dev/null +++ b/pymllm/layers/embedding.py @@ -0,0 +1,160 @@ +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs +from pymllm.orchestrator import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_reduce, +) + + +class VocabParallelEmbedding(MllmBaseLayer): + """Embedding layer with vocabulary parallelism. + + This layer shards the embedding table along the vocabulary dimension + for tensor parallelism. + + Args: + num_embeddings: Size of the vocabulary. + embedding_dim: Size of the embedding vector. + padding_idx: Index for padding token (optional). 
+ """ + + def __init__( + self, + num_embeddings: int, + embedding_dim: int, + padding_idx: int = None, + ): + super().__init__() + + # Get TP info from global state + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.num_embeddings = num_embeddings + self.embedding_dim = embedding_dim + self.padding_idx = padding_idx + + # Calculate sharded size + if self.num_embeddings % self.tp_size != 0: + raise ValueError( + f"num_embeddings ({num_embeddings}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + + self.num_embeddings_per_partition = divide(num_embeddings, self.tp_size) + + # Create sharded weight + self.weight = Parameter( + torch.empty(self.num_embeddings_per_partition, embedding_dim) + ) + + # Calculate shard range + self.vocab_start_index = self.tp_rank * self.num_embeddings_per_partition + self.vocab_end_index = ( + self.vocab_start_index + self.num_embeddings_per_partition + ) + + # Set weight attributes for loading + set_weight_attrs( + self.weight, + { + "output_dim": 0, # Shard along vocab dimension + "input_dim": 1, # Embedding dimension + "weight_loader": self.weight_loader, + }, + ) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + output_dim = getattr(param, "output_dim", None) + + if output_dim is None or self.tp_size == 1: + # No sharding, direct copy + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + # Sharded loading: slice the loaded weight + assert loaded_weight.shape[output_dim] == self.num_embeddings, ( + f"Loaded weight vocab size {loaded_weight.shape[output_dim]} " + f"does not match expected {self.num_embeddings}" + ) + + # Slice along vocab dimension + if output_dim == 0: + shard_weight = loaded_weight[ + self.vocab_start_index : self.vocab_end_index, : + ] + else: + shard_weight = loaded_weight.narrow( + output_dim, + self.vocab_start_index, + self.num_embeddings_per_partition, + ) + + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward pass of the embedding layer with TP support. + + Args: + x: Input tensor of token ids. + + Returns: + Embedded representation (all-reduced across TP group if needed). + """ + local_padding_idx = self.padding_idx + if self.tp_size > 1: + # Create mask for valid vocab range + vocab_mask = (x >= self.vocab_start_index) & (x < self.vocab_end_index) + + # Adjust indices to local vocab space + masked_input = torch.where( + vocab_mask, + x - self.vocab_start_index, + torch.zeros_like(x), # Invalid indices become 0 (will be masked) + ) + # F.embedding expects indices in local weight-table space. + # Only pass padding_idx on the owning rank, remapped to local offset. 
+            if self.padding_idx is not None:
+                if self.vocab_start_index <= self.padding_idx < self.vocab_end_index:
+                    local_padding_idx = self.padding_idx - self.vocab_start_index
+                else:
+                    local_padding_idx = None
+        else:
+            masked_input = x
+            vocab_mask = None
+
+        # Lookup embeddings
+        output = F.embedding(
+            masked_input.long(),
+            self.weight,
+            padding_idx=local_padding_idx,
+        )
+
+        # Mask invalid positions (for TP)
+        if vocab_mask is not None:
+            output.masked_fill_(~vocab_mask.unsqueeze(-1), 0)
+
+        # All-reduce across TP group
+        if self.tp_size > 1:
+            output = tensor_model_parallel_all_reduce(output)
+
+        return output
diff --git a/pymllm/layers/gated_delta_net.py b/pymllm/layers/gated_delta_net.py
new file mode 100644
index 000000000..5472371da
--- /dev/null
+++ b/pymllm/layers/gated_delta_net.py
@@ -0,0 +1,194 @@
+"""Gated Delta Network (GDN) linear attention for Qwen3.5.
+
+This implements the linear attention mechanism used in Qwen3.5's hybrid
+architecture. GDN alternates with standard full-attention layers.
+
+Core formulation (decode, per-head):
+    g_t = -exp(A_log) * softplus(a_t + dt_bias)
+    beta_t = sigmoid(b_t)
+    state_t = exp(g_t) * state_{t-1} + beta_t * (v_t - (exp(g_t) * state_{t-1}) @ k_t) outer k_t
+    output_t = state_t @ q_t
+
+State is externalized into a :class:`~pymllm.mem_cache.memory_pool.GDNPool`
+and computation is delegated to the attention backend via
+:class:`~pymllm.layers.attention.radix_linear_attention.RadixLinearAttention`.
+""" + +from __future__ import annotations + +import logging +from typing import Any, Optional + +import torch +import torch.nn as nn + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.linear import Linear +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Conv1d weight holder +# --------------------------------------------------------------------------- + + +class GDNConv1d(nn.Module): + """Causal 1D convolution weight holder for GDN sequence mixing. + + The actual convolution computation is performed by the GDN backend + using pooled conv states. This module only holds the learnable weight. + """ + + def __init__(self, channels: int, kernel_size: int): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.weight = nn.Parameter(torch.empty(channels, kernel_size)) + + +# --------------------------------------------------------------------------- +# GatedDeltaNet — main GDN layer +# --------------------------------------------------------------------------- + + +class GatedDeltaNet(MllmBaseLayer): + """Gated Delta Network linear attention layer for Qwen3.5. + + State is externalized into a GDNPool and computation is delegated to + the attention backend via RadixLinearAttention. + + Parameters + ---------- + hidden_size : int + Model hidden dimension. + num_k_heads : int + Number of key heads. + num_v_heads : int + Number of value heads. + head_k_dim : int + Per-head key dimension. + head_v_dim : int + Per-head value dimension. + conv_kernel_size : int + Causal conv1d kernel width. + layer_id : int + Global layer index. + gdn_layer_idx : int + Sequential index among GDN layers (0-based). + rms_norm_eps : float + Epsilon for gated RMS normalization. 
+ """ + + def __init__( + self, + hidden_size: int, + num_k_heads: int = 16, + num_v_heads: int = 32, + head_k_dim: int = 128, + head_v_dim: int = 128, + conv_kernel_size: int = 4, + layer_id: int = 0, + gdn_layer_idx: int = 0, + rms_norm_eps: float = 1e-6, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.hidden_size = hidden_size + self.num_k_heads = num_k_heads + self.num_v_heads = num_v_heads + self.head_k_dim = head_k_dim + self.head_v_dim = head_v_dim + self.key_dim = head_k_dim * num_k_heads + self.value_dim = head_v_dim * num_v_heads + self.conv_kernel_size = conv_kernel_size + self.layer_id = layer_id + self.gdn_layer_idx = gdn_layer_idx + + def _get_qm(suffix, out_features): + # Skip quantization for small projections — Marlin kernels + # require minimum thread tile sizes that exceed these dims. + if quant_config is None or out_features < 64: + return None + return quant_config.get_quant_method( + layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + # Input projections + self.in_proj_qkv = Linear( + hidden_size, self.key_dim * 2 + self.value_dim, bias=False, + quant_method=_get_qm("in_proj_qkv", self.key_dim * 2 + self.value_dim), + ) + self.in_proj_z = Linear( + hidden_size, self.value_dim, bias=False, + quant_method=_get_qm("in_proj_z", self.value_dim), + ) + self.in_proj_a = Linear( + hidden_size, num_v_heads, bias=False, + quant_method=_get_qm("in_proj_a", num_v_heads), + ) + self.in_proj_b = Linear( + hidden_size, num_v_heads, bias=False, + quant_method=_get_qm("in_proj_b", num_v_heads), + ) + + # Causal convolution (weight only — computation is in the backend) + self.conv1d = GDNConv1d(self.key_dim * 2 + self.value_dim, conv_kernel_size) + + # State parameters (must stay float32 for numerical stability) + self.A_log = nn.Parameter(torch.empty(num_v_heads, dtype=torch.float32)) + self.dt_bias = nn.Parameter(torch.ones(num_v_heads, dtype=torch.float32)) + set_weight_attrs(self.A_log, {"weight_loader": 
self.weight_loader}) + set_weight_attrs(self.dt_bias, {"weight_loader": self.weight_loader}) + + # Gated RMSNorm (mllm-kernel accelerated) + from pymllm.layers.rms_norm_gated import RMSNormGated + self.norm = RMSNormGated(head_v_dim, eps=rms_norm_eps, norm_before_gate=True) + + # Output projection + self.out_proj = Linear( + self.value_dim, hidden_size, bias=False, + quant_method=_get_qm("out_proj", hidden_size), + ) + + # RadixLinearAttention — delegates to the attention backend + from pymllm.layers.attention.radix_linear_attention import RadixLinearAttention + self.attn = RadixLinearAttention( + layer_id=layer_id, + gdn_layer_idx=gdn_layer_idx, + num_k_heads=num_k_heads, + num_v_heads=num_v_heads, + head_k_dim=head_k_dim, + head_v_dim=head_v_dim, + conv_weight=self.conv1d.weight, + A_log=self.A_log, + dt_bias=self.dt_bias, + ) + + def forward( + self, hidden_states: torch.Tensor, forward_batch: Any = None, + ) -> torch.Tensor: + seq_len, _ = hidden_states.shape + + # Input projections + mixed_qkv = self.in_proj_qkv(hidden_states) + z = self.in_proj_z(hidden_states) + a = self.in_proj_a(hidden_states) + b = self.in_proj_b(hidden_states) + + # Delegate to backend via RadixLinearAttention + # The backend handles: conv1d, SiLU, split, gating, recurrent update + attn_out = self.attn(forward_batch, mixed_qkv, a, b) + + # Gated norm + output projection + attn_out = attn_out.view(seq_len, self.num_v_heads, self.head_v_dim) + z = z.view(seq_len, self.num_v_heads, self.head_v_dim) + + attn_flat = attn_out.reshape(-1, self.head_v_dim) + z_flat = z.reshape(-1, self.head_v_dim) + normed = self.norm(attn_flat, z_flat) + normed = normed.view(seq_len, self.num_v_heads, self.head_v_dim) + normed = normed.reshape(seq_len, self.value_dim) + return self.out_proj(normed) diff --git a/pymllm/layers/layer_norm.py b/pymllm/layers/layer_norm.py new file mode 100644 index 000000000..54d94c19e --- /dev/null +++ b/pymllm/layers/layer_norm.py @@ -0,0 +1,43 @@ +from __future__ import 
annotations + +import torch +import flashinfer +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs + + +class LayerNorm(MllmBaseLayer): + """LayerNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + # flashinfer.norm.layernorm expects gamma/beta in fp32. + self.weight = Parameter(torch.ones(hidden_size, dtype=torch.float32)) + self.bias = Parameter(torch.zeros(hidden_size, dtype=torch.float32)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + if x.dtype != torch.bfloat16: + raise TypeError( + "flashinfer.norm.layernorm requires bfloat16 input, " + f"but got {x.dtype}" + ) + + if x.dim() == 2: + return flashinfer.norm.layernorm(x, self.weight, self.bias, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.layernorm(x_2d, self.weight, self.bias, self.eps) + return out.reshape(original_shape) diff --git a/pymllm/layers/linear.py b/pymllm/layers/linear.py new file mode 100644 index 000000000..b4058c2da --- /dev/null +++ b/pymllm/layers/linear.py @@ -0,0 +1,316 @@ +"""Linear layers with quantization method dispatch. + +Every linear layer holds a ``quant_method`` attribute (an instance of +:class:`~pymllm.layers.quantize_base.LinearMethodBase`). When no +quantization is configured, :class:`UnquantizedLinearMethod` is used as the +default — it creates a standard FP weight and forwards via ``F.linear``. + +Quantized checkpoints plug in a different ``LinearMethodBase`` (e.g. 
+``AWQLinearMethod``) which creates packed int4 weights, scales, and +zero-points, and overrides :meth:`apply` with a fused dequant+matmul kernel. + +Usage in model definitions:: + + # Non-quantized (default) + layer = ColumnParallelLinear(4096, 4096) + + # Quantized — pass a quant_method from QuantizationConfig + qm = awq_config.get_quant_method(layer, prefix="model.layers.0.q_proj") + layer = ColumnParallelLinear(4096, 4096, quant_method=qm) +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.quantize_base import LinearMethodBase, UnquantizedLinearMethod +from pymllm.layers.utils import set_weight_attrs +from pymllm.orchestrator import ( + divide, + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather, + tensor_model_parallel_all_reduce, +) + + +class ColumnParallelLinear(MllmBaseLayer): + """Linear layer with column parallelism (output-dimension sharding). + + The weight matrix is split along the output dimension across TP ranks. + Each rank holds ``out_features / tp_size`` rows of the weight. + + Parameters + ---------- + in_features + Size of each input sample. + out_features + Size of each output sample (before sharding). + bias + If ``True``, adds a learnable bias. + gather_output + If ``True``, all-gather the output across TP ranks so every rank + gets the full ``out_features``. Set to ``False`` when the next + layer is a :class:`RowParallelLinear` that expects a split input. + quant_method + Quantization method instance. ``None`` → :class:`UnquantizedLinearMethod`. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + gather_output: bool = True, + quant_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.in_features = in_features + self.out_features = out_features + self.gather_output = gather_output + + if out_features % self.tp_size != 0: + raise ValueError( + f"out_features ({out_features}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + self.out_features_per_partition = divide(out_features, self.tp_size) + + self.output_start_index = self.tp_rank * self.out_features_per_partition + self.output_end_index = self.output_start_index + self.out_features_per_partition + + # --- Quantization method --- + # The quant_method creates the weight parameters on this layer via + # create_weights(). For UnquantizedLinearMethod this creates a + # standard FP Parameter named "weight". For quantized methods it + # may instead create qweight, scales, qzeros, etc. + self.quant_method = quant_method or UnquantizedLinearMethod() + self.quant_method.create_weights( + layer=self, + input_size_per_partition=in_features, + output_partition_sizes=[self.out_features_per_partition], + input_size=in_features, + output_size=out_features, + params_dtype=torch.get_default_dtype(), + weight_loader=self.weight_loader, + ) + + if bias: + self.bias_flag = True + self.bias = Parameter(torch.empty(self.out_features_per_partition)) + set_weight_attrs( + self.bias, + { + "output_dim": 0, + "weight_loader": self.weight_loader, + }, + ) + else: + self.bias_flag = False + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + output_dim = getattr(param, "output_dim", None) + + if output_dim is None or self.tp_size == 1: + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + shard_weight = loaded_weight.narrow( + output_dim, + self.output_start_index, + self.out_features_per_partition, + ) + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Delegate computation to the quant_method. For unquantized layers + # this is F.linear; for quantized layers it's a fused dequant+matmul. + output = self.quant_method.apply(self, x, self.bias) + + if self.gather_output and self.tp_size > 1: + output = tensor_model_parallel_all_gather(output, dim=-1) + + return output + + +class RowParallelLinear(MllmBaseLayer): + """Linear layer with row parallelism (input-dimension sharding). + + The weight matrix is split along the input dimension across TP ranks. + Each rank holds all ``out_features`` rows but only + ``in_features / tp_size`` columns. + + Typically placed after a :class:`ColumnParallelLinear` whose + ``gather_output=False``, so the input is already split. + + Parameters + ---------- + in_features + Size of each input sample (before sharding). + out_features + Size of each output sample. + bias + If ``True``, adds a learnable bias (applied after all-reduce). + reduce_output + If ``True``, all-reduce the output across TP ranks. + quant_method + Quantization method instance. ``None`` → :class:`UnquantizedLinearMethod`. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + reduce_output: bool = True, + quant_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + + self.tp_rank = get_tensor_model_parallel_rank() + self.tp_size = get_tensor_model_parallel_world_size() + + self.in_features = in_features + self.out_features = out_features + self.reduce_output = reduce_output + + if in_features % self.tp_size != 0: + raise ValueError( + f"in_features ({in_features}) must be divisible by " + f"tp_size ({self.tp_size})" + ) + self.in_features_per_partition = divide(in_features, self.tp_size) + + self.input_start_index = self.tp_rank * self.in_features_per_partition + self.input_end_index = self.input_start_index + self.in_features_per_partition + + # --- Quantization method --- + self.quant_method = quant_method or UnquantizedLinearMethod() + self.quant_method.create_weights( + layer=self, + input_size_per_partition=self.in_features_per_partition, + output_partition_sizes=[out_features], + input_size=in_features, + output_size=out_features, + params_dtype=torch.get_default_dtype(), + weight_loader=self.weight_loader, + ) + + if bias: + self.bias_flag = True + self.bias = Parameter(torch.empty(out_features)) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + else: + self.bias_flag = False + self.register_parameter("bias", None) + + def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): + """Load sharded weights into the parameter. + + Args: + param: The parameter to load weights into. + loaded_weight: The weight tensor loaded from checkpoint (full size). 
+ """ + input_dim = getattr(param, "input_dim", None) + + if input_dim is None or self.tp_size == 1: + assert param.data.shape == loaded_weight.shape, ( + f"Shape mismatch: param {param.data.shape} vs " + f"loaded {loaded_weight.shape}" + ) + param.data.copy_(loaded_weight) + else: + shard_weight = loaded_weight.narrow( + input_dim, + self.input_start_index, + self.in_features_per_partition, + ) + assert param.data.shape == shard_weight.shape, ( + f"Shard shape mismatch: param {param.data.shape} vs " + f"shard {shard_weight.shape}" + ) + param.data.copy_(shard_weight) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Delegate computation to the quant_method (no bias here; bias is + # added after the all-reduce below). + output = self.quant_method.apply(self, x) + + if self.reduce_output and self.tp_size > 1: + output = tensor_model_parallel_all_reduce(output) + + if self.bias is not None: + output = output + self.bias + + return output + + +class Linear(MllmBaseLayer): + """Non-parallel linear layer with quantization dispatch. + + Parameters + ---------- + in_features + Size of each input sample. + out_features + Size of each output sample. + bias + If ``True``, adds a learnable bias. + quant_method + Quantization method instance. ``None`` → :class:`UnquantizedLinearMethod`. 
+ """ + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + quant_method: Optional[LinearMethodBase] = None, + ): + super().__init__() + self.in_features = in_features + self.out_features = out_features + + # --- Quantization method --- + self.quant_method = quant_method or UnquantizedLinearMethod() + self.quant_method.create_weights( + layer=self, + input_size_per_partition=in_features, + output_partition_sizes=[out_features], + input_size=in_features, + output_size=out_features, + params_dtype=torch.get_default_dtype(), + weight_loader=self.weight_loader, + ) + + if bias: + self.bias = Parameter(torch.empty(out_features)) + set_weight_attrs(self.bias, {"weight_loader": self.weight_loader}) + else: + self.register_parameter("bias", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.quant_method.apply(self, x, self.bias) diff --git a/pymllm/layers/mlp.py b/pymllm/layers/mlp.py new file mode 100644 index 000000000..1894e23ca --- /dev/null +++ b/pymllm/layers/mlp.py @@ -0,0 +1,230 @@ +from __future__ import annotations + +import logging +from typing import Callable, Literal, Optional + +import flashinfer +import torch + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.linear import ColumnParallelLinear, Linear, RowParallelLinear + +logger = logging.getLogger(__name__) + +MLPActivation = Literal["silu", "gelu", "gelu_tanh"] + +_ACTIVATION_MAP: dict[MLPActivation, Callable[..., torch.Tensor]] = { + "silu": flashinfer.activation.silu_and_mul, + "gelu": flashinfer.activation.gelu_and_mul, + "gelu_tanh": flashinfer.activation.gelu_tanh_and_mul, +} + + +def _validate_mlp_args( + hidden_size: int, intermediate_size: int, activation: str +) -> None: + if hidden_size <= 0: + raise ValueError(f"hidden_size must be > 0, but got {hidden_size}") + if intermediate_size <= 0: + raise ValueError( + f"intermediate_size must be > 0, but got {intermediate_size}" + ) + if activation not in _ACTIVATION_MAP: + 
raise ValueError( + f"Unsupported activation '{activation}'. " + f"Expected one of: {list(_ACTIVATION_MAP)}" + ) + + +def _run_gated_activation( + gate_up: torch.Tensor, + intermediate_size: int, + activation: MLPActivation, + enable_pdl: Optional[bool], +) -> torch.Tensor: + if gate_up.shape[-1] != 2 * intermediate_size: + raise ValueError( + "Expected last dim of gate_up tensor to be " + f"{2 * intermediate_size}, but got {gate_up.shape[-1]}" + ) + return _ACTIVATION_MAP[activation](gate_up, enable_pdl=enable_pdl) + + +class MLP(MllmBaseLayer): + """Feed-forward MLP block with FlashInfer fused gated activations. + + Non-parallel version (TP=1). Uses :class:`Linear` for all projections. + + Supported activations: ``silu``, ``gelu``, ``gelu_tanh``. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: MLPActivation = "silu", + use_fused_gate_up_proj: bool = True, + use_bias_gate_up: bool = False, + use_bias_down: bool = False, + enable_pdl: Optional[bool] = None, + quant_config=None, + prefix: str = "", + ): + super().__init__() + _validate_mlp_args(hidden_size, intermediate_size, activation) + + # Quantized checkpoints store gate_proj / up_proj separately; + # fusing them into a single packed-int32 parameter is impractical, + # so force the unfused path when quantisation is active. + if quant_config is not None: + use_fused_gate_up_proj = False + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.activation = activation + self.use_fused_gate_up_proj = use_fused_gate_up_proj + self.enable_pdl = enable_pdl + + def _get_qm(suffix): + if quant_config is None: + return None + return quant_config.get_quant_method( + layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + if not use_fused_gate_up_proj and quant_config is None: + logger.warning( + "MLP with use_fused_gate_up_proj=False uses a lower-efficiency path. 
" + "Use use_fused_gate_up_proj=True for better performance.", + ) + + if use_fused_gate_up_proj: + self.gate_up_proj = Linear( + hidden_size, 2 * intermediate_size, bias=use_bias_gate_up, + quant_method=_get_qm("gate_up_proj"), + ) + self.gate_proj = None + self.up_proj = None + else: + self.gate_up_proj = None + self.gate_proj = Linear( + hidden_size, intermediate_size, bias=use_bias_gate_up, + quant_method=_get_qm("gate_proj"), + ) + self.up_proj = Linear( + hidden_size, intermediate_size, bias=use_bias_gate_up, + quant_method=_get_qm("up_proj"), + ) + + self.down_proj = Linear( + intermediate_size, hidden_size, bias=use_bias_down, + quant_method=_get_qm("down_proj"), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + if self.use_fused_gate_up_proj: + assert self.gate_up_proj is not None + gate_up = self.gate_up_proj(x) + else: + assert self.gate_proj is not None and self.up_proj is not None + gate_up = torch.cat([self.gate_proj(x), self.up_proj(x)], dim=-1) + + hidden = _run_gated_activation( + gate_up, self.intermediate_size, self.activation, self.enable_pdl, + ) + return self.down_proj(hidden) + + +class ParallelMLP(MllmBaseLayer): + """Tensor-parallel MLP with column-sharded intermediate dimension. + + Projection layout (Megatron-style): + + - ``gate_proj``: :class:`ColumnParallelLinear` + ``(hidden_size → intermediate_size, gather_output=False)`` + - ``up_proj``: :class:`ColumnParallelLinear` + ``(hidden_size → intermediate_size, gather_output=False)`` + - ``down_proj``: :class:`RowParallelLinear` + ``(intermediate_size → hidden_size, reduce_output=True)`` + + Gate and up projections are kept separate so that each TP rank holds a + correctly paired ``[gate_shard, up_shard]`` for the gated activation. + + Cost: **1 all-reduce** (inside ``down_proj``). 
+ + Input shape : ``(*, hidden_size)`` — full / replicated. + Output shape: ``(*, hidden_size)`` — full / replicated. + + Args: + hidden_size: Model hidden dimension. + intermediate_size: Intermediate (expanded) dimension **before** TP + sharding. + activation: Gated activation type. + use_bias_gate_up: Add bias to the gate/up projections. + use_bias_down: Add bias to the down projection. + enable_pdl: FlashInfer PDL flag. + """ + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + activation: MLPActivation = "silu", + use_bias_gate_up: bool = False, + use_bias_down: bool = False, + enable_pdl: Optional[bool] = None, + quant_config=None, + prefix: str = "", + ): + super().__init__() + _validate_mlp_args(hidden_size, intermediate_size, activation) + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.activation = activation + self.enable_pdl = enable_pdl + + def _get_qm(suffix): + if quant_config is None: + return None + return quant_config.get_quant_method( + layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + self.gate_proj = ColumnParallelLinear( + hidden_size, intermediate_size, + bias=use_bias_gate_up, gather_output=False, + quant_method=_get_qm("gate_proj"), + ) + self.up_proj = ColumnParallelLinear( + hidden_size, intermediate_size, + bias=use_bias_gate_up, gather_output=False, + quant_method=_get_qm("up_proj"), + ) + + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, + bias=use_bias_down, reduce_output=True, + quant_method=_get_qm("down_proj"), + ) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + gate_up = torch.cat([self.gate_proj(x), self.up_proj(x)], dim=-1) + + shard_inter = self.down_proj.in_features_per_partition + hidden = _run_gated_activation( + gate_up, shard_inter, self.activation, 
self.enable_pdl, + ) + return self.down_proj(hidden) diff --git a/pymllm/layers/quantize_base.py b/pymllm/layers/quantize_base.py new file mode 100644 index 000000000..951fc6115 --- /dev/null +++ b/pymllm/layers/quantize_base.py @@ -0,0 +1,275 @@ +"""Quantization method base classes for pymllm layers. + +This module defines the plugin interface that all quantization methods must +implement. The pattern follows sglang / vLLM's ``LinearMethodBase`` design: + +1. Each quantization algorithm (AWQ, GPTQ, FP8, ...) provides a concrete + subclass of :class:`LinearMethodBase`. +2. Linear layers hold a ``quant_method`` attribute (an instance of + :class:`LinearMethodBase`). +3. During ``__init__``, the linear layer calls + ``quant_method.create_weights(layer, ...)`` to register the appropriate + parameters (packed int weights, scales, zero-points, etc.) on itself. +4. During ``forward``, the linear layer calls + ``quant_method.apply(layer, x, bias)`` instead of ``F.linear``. +5. After checkpoint loading, :class:`~pymllm.executor.model_runner.ModelRunner` + iterates all modules and calls + ``quant_method.process_weights_after_loading(layer)`` for format conversion, + repacking (e.g. AWQ → Marlin), or calibration. + +Typical lifecycle:: + + # ---- model construction ---- + quant_method = SomeLinearMethod(bits=4, group_size=128) + layer = ColumnParallelLinear(4096, 4096, quant_method=quant_method) + # → calls quant_method.create_weights(layer, ...) + # → layer now has .qweight, .scales, .qzeros, etc. + + # ---- weight loading ---- + model.load_weights(iter_weights(...)) + # → checkpoint tensors are loaded into the parameters created above, + # using each parameter's ``weight_loader`` attribute. + + # ---- post-load processing ---- + for module in model.modules(): + qm = getattr(module, "quant_method", None) + if qm is not None: + qm.process_weights_after_loading(module) + # → AWQ repacks int4 → Marlin layout, GPTQ shuffles by g_idx, etc. 
+ + # ---- inference ---- + output = layer(x) + # → calls quant_method.apply(layer, x, bias) + # → dequant + matmul (or fused kernel) +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +import torch +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.utils import set_weight_attrs + + +# --------------------------------------------------------------------------- +# Base classes +# --------------------------------------------------------------------------- + + +class QuantizeMethodBase(ABC): + """Base class for all quantization methods (linear, embedding, MoE, ...). + + Every concrete quantization algorithm must implement at least + :meth:`create_weights` and :meth:`apply`. + + How to implement a new quantization method + ------------------------------------------- + 1. Subclass :class:`LinearMethodBase` (for linear layers). + 2. Override :meth:`create_weights` to register quantized parameters + (``qweight``, ``scales``, ``qzeros``, etc.) on the layer via + ``layer.register_parameter()``. + 3. Override :meth:`apply` to perform the quantized forward computation. + 4. Optionally override :meth:`process_weights_after_loading` if the + checkpoint format differs from the runtime format (e.g. repacking, + transposing, or calibrating scales). + """ + + def create_weights( + self, + layer: torch.nn.Module, + *args: Any, + **kwargs: Any, + ) -> None: + """Create and register quantized weight parameters on *layer*. + + Called once during layer construction (``__init__``). Implementations + should call ``layer.register_parameter(name, param)`` and attach + metadata via :func:`~pymllm.layers.utils.set_weight_attrs` so that + the weight-loading infrastructure knows how to shard and load them. + + Parameters + ---------- + layer + The ``nn.Module`` (e.g. ``ColumnParallelLinear``) that will own + the parameters. 
+ """ + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + *args: Any, + **kwargs: Any, + ) -> torch.Tensor: + """Execute the quantized forward pass. + + Called by ``layer.forward()`` every inference step. The method should + read the parameters previously created by :meth:`create_weights` from + *layer* (e.g. ``layer.qweight``, ``layer.scales``), dequantize or + invoke a fused kernel, and return the output tensor. + + Parameters + ---------- + layer + The module that owns the quantized parameters. + """ + raise NotImplementedError + + def process_weights_after_loading(self, layer: torch.nn.Module) -> None: + """Post-process parameters after checkpoint loading. + + Called once by ``ModelRunner`` after all checkpoint tensors have been + loaded into the layer's parameters. Use this for: + + * **Repacking**: converting checkpoint layout to kernel-native layout + (e.g. AutoAWQ int4 → Marlin packed format). + * **Transposing**: rearranging dimensions for optimised GEMM kernels. + * **Calibration**: computing per-tensor or per-channel scales from + the loaded FP weights (e.g. dynamic FP8 quantisation). + * **Cleanup**: replacing custom parameter wrappers with plain + ``torch.nn.Parameter`` to avoid overhead during inference. + + The default implementation is a no-op. + """ + return + + +class LinearMethodBase(QuantizeMethodBase): + """Base class for quantization methods applied to linear layers. + + Narrows the :class:`QuantizeMethodBase` interface with concrete + signatures tailored to linear (matmul) operations. + + Subclasses must implement :meth:`create_weights` and :meth:`apply`. + """ + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs: Any, + ) -> None: + """Create quantized weight tensors on *layer*. 
+ + Parameters + ---------- + layer + The linear module that will own the parameters. + input_size_per_partition + Number of input features on this TP rank. + output_partition_sizes + Output sizes of each logical weight on this TP rank. For a + standard linear layer this is ``[out_features_per_partition]``. + For a merged QKV layer it might be ``[q_size, k_size, v_size]``. + input_size + Full (un-sharded) input dimension. + output_size + Full (un-sharded) output dimension. + params_dtype + Data type for full-precision parameters (e.g. ``torch.float16``). + **extra_weight_attrs + Additional metadata to attach to created parameters (e.g. + ``weight_loader``, ``packed_dim``, ``packed_factor``). + + Example (AWQ W4A16):: + + # Register packed 4-bit weights, scales, and zero-points + qweight = Parameter(torch.empty(..., dtype=torch.int32)) + layer.register_parameter("qweight", qweight) + + scales = Parameter(torch.empty(..., dtype=params_dtype)) + layer.register_parameter("scales", scales) + + qzeros = Parameter(torch.empty(..., dtype=torch.int32)) + layer.register_parameter("qzeros", qzeros) + """ + raise NotImplementedError + + @abstractmethod + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Compute the quantized linear forward. + + Parameters + ---------- + layer + The module that owns quantized parameters (set by + :meth:`create_weights`). + x + Input activation tensor, shape ``(*, input_size_per_partition)``. + bias + Optional bias vector. + + Returns + ------- + torch.Tensor + Output tensor, shape ``(*, sum(output_partition_sizes))``. 
+ + Example (AWQ W4A16):: + + qweight = layer.qweight # packed int32 + scales = layer.scales # fp16 per-group scales + qzeros = layer.qzeros # packed int32 zero-points + # → invoke dequant + matmul kernel + """ + raise NotImplementedError + + +# --------------------------------------------------------------------------- +# Default unquantized implementation +# --------------------------------------------------------------------------- + + +class UnquantizedLinearMethod(LinearMethodBase): + """Default pass-through for non-quantized linear layers. + + Creates a standard FP weight ``(out_features, in_features)`` and + forwards via ``F.linear``. This is used when no quantization config + is specified so that every linear layer always has a ``quant_method`` + attribute with a uniform interface. + """ + + def create_weights( + self, + layer: torch.nn.Module, + input_size_per_partition: int, + output_partition_sizes: List[int], + input_size: int, + output_size: int, + params_dtype: torch.dtype, + **extra_weight_attrs: Any, + ) -> None: + """Create a standard full-precision weight parameter.""" + weight = Parameter( + torch.empty( + sum(output_partition_sizes), + input_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0}) + layer.register_parameter("weight", weight) + set_weight_attrs(weight, extra_weight_attrs) + + def apply( + self, + layer: torch.nn.Module, + x: torch.Tensor, + bias: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Standard ``F.linear`` forward.""" + return F.linear(x, layer.weight, bias) diff --git a/pymllm/layers/rms_norm.py b/pymllm/layers/rms_norm.py new file mode 100644 index 000000000..b20b36f30 --- /dev/null +++ b/pymllm/layers/rms_norm.py @@ -0,0 +1,84 @@ +from __future__ import annotations + +from typing import Optional, Tuple, Union + +import torch +import flashinfer +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from 
pymllm.layers.utils import set_weight_attrs + + +class RMSNorm(MllmBaseLayer): + """RMSNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + self.weight = Parameter(torch.empty(hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + flashinfer.norm.fused_add_rmsnorm(x, residual, self.weight.data, self.eps) + return x, residual + + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + # FlashInfer rmsnorm accepts 2D/3D input; flatten higher-rank tensors to 2D. + if x.dim() in (2, 3): + return flashinfer.norm.rmsnorm(x, self.weight, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.rmsnorm(x_2d, self.weight, self.eps) + return out.reshape(original_shape) + + +class GemmaRMSNorm(MllmBaseLayer): + """Gemma-style RMSNorm layer implemented with FlashInfer kernel.""" + + def __init__(self, hidden_size: int, eps: float = 1e-6): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + + self.weight = Parameter(torch.empty(hidden_size)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + residual: Optional[torch.Tensor] = None, + ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if residual is not None: + flashinfer.norm.gemma_fused_add_rmsnorm( + x, residual, self.weight.data, self.eps + ) + return x, residual + + if x.shape[-1] != self.hidden_size: + raise ValueError( + f"Expected last dim == hidden_size ({self.hidden_size}), " + f"but got input shape {tuple(x.shape)}" + ) + + # 
gemma_rmsnorm is defined on 2D input; flatten other ranks to 2D. + if x.dim() == 2: + return flashinfer.norm.gemma_rmsnorm(x, self.weight, self.eps) + + original_shape = x.shape + x_2d = x.reshape(-1, self.hidden_size) + out = flashinfer.norm.gemma_rmsnorm(x_2d, self.weight, self.eps) + return out.reshape(original_shape) diff --git a/pymllm/layers/rms_norm_gated.py b/pymllm/layers/rms_norm_gated.py new file mode 100644 index 000000000..caec9b88d --- /dev/null +++ b/pymllm/layers/rms_norm_gated.py @@ -0,0 +1,154 @@ +"""Gated RMSNorm layer for Qwen3.5 GDN attention. + +Computes ``rmsnorm(x, weight, eps) * silu(z)`` using a fused CUDA kernel +from mllm-kernel. Falls back to PyTorch when the kernel is unavailable. +""" + +from __future__ import annotations + +import logging +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Parameter + +from pymllm.layers.base import MllmBaseLayer +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Try to load the mllm-kernel fused CUDA implementation +# --------------------------------------------------------------------------- +_HAS_MLLM_KERNEL_CUDA = False +try: + from mllm_kernel.cuda.jit.rms_norm_gated import ( + rms_norm_gated as _mllm_rms_norm_gated, + ) + + _HAS_MLLM_KERNEL_CUDA = True +except Exception: + _mllm_rms_norm_gated = None + + +# --------------------------------------------------------------------------- +# Pure-PyTorch fallback +# --------------------------------------------------------------------------- + + +def _rms_norm_gated_pytorch( + x: torch.Tensor, + weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + eps: float = 1e-6, + norm_before_gate: bool = True, +) -> torch.Tensor: + """Pure-PyTorch reference implementation.""" + dtype = x.dtype + x_fp32 = x.float() + w_fp32 = weight.float() + z_fp32 = z.float() if z 
is not None else None + + if z_fp32 is not None and not norm_before_gate: + x_fp32 = x_fp32 * F.silu(z_fp32) + + variance = x_fp32.pow(2).mean(dim=-1, keepdim=True) + rstd = torch.rsqrt(variance + eps) + out = x_fp32 * rstd * w_fp32 + + if z_fp32 is not None and norm_before_gate: + out = out * F.silu(z_fp32) + + return out.to(dtype) + + +# --------------------------------------------------------------------------- +# Unified dispatch +# --------------------------------------------------------------------------- + + +def rms_norm_gated( + x: torch.Tensor, + weight: torch.Tensor, + z: Optional[torch.Tensor] = None, + eps: float = 1e-6, + norm_before_gate: bool = True, +) -> torch.Tensor: + """Compute (optionally gated) RMS normalization. + + Uses the fused mllm-kernel CUDA implementation when available, + otherwise falls back to a pure-PyTorch implementation. + """ + if _HAS_MLLM_KERNEL_CUDA and x.is_cuda: + return _mllm_rms_norm_gated(x, weight, z=z, eps=eps) + return _rms_norm_gated_pytorch( + x, weight, z=z, eps=eps, norm_before_gate=norm_before_gate, + ) + + +# --------------------------------------------------------------------------- +# nn.Module wrapper +# --------------------------------------------------------------------------- + + +class RMSNormGated(MllmBaseLayer): + """Gated RMS Normalization layer for Qwen3.5 GDN attention. + + Computes:: + + output = rmsnorm(x, weight) * silu(z) # z is not None + output = rmsnorm(x, weight) # z is None + + Uses a fused CUDA kernel from mllm-kernel for maximum throughput. + + Parameters + ---------- + hidden_size : int + Dimensionality of the input (and weight vector). + eps : float + Small constant for numerical stability. + norm_before_gate : bool + If ``True`` (default): ``rmsnorm(x) * silu(z)``. + If ``False``: ``rmsnorm(x * silu(z))``. 
+ """ + + def __init__( + self, + hidden_size: int, + eps: float = 1e-6, + group_size: Optional[int] = None, + norm_before_gate: bool = True, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ): + super().__init__() + self.hidden_size = hidden_size + self.eps = eps + self.norm_before_gate = norm_before_gate + + factory_kwargs = {} + if device is not None: + factory_kwargs["device"] = device + if dtype is not None: + factory_kwargs["dtype"] = dtype + + self.weight = Parameter(torch.ones(hidden_size, **factory_kwargs)) + set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + + def forward( + self, + x: torch.Tensor, + z: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return rms_norm_gated( + x, self.weight, z=z, eps=self.eps, + norm_before_gate=self.norm_before_gate, + ) + + def extra_repr(self) -> str: + return ( + f"hidden_size={self.hidden_size}, eps={self.eps}, " + f"norm_before_gate={self.norm_before_gate}" + ) diff --git a/pymllm/layers/rope.py b/pymllm/layers/rope.py new file mode 100644 index 000000000..94f89b20d --- /dev/null +++ b/pymllm/layers/rope.py @@ -0,0 +1,401 @@ +from __future__ import annotations + +from typing import List, Optional, Tuple + +import torch +import flashinfer + + +def apply_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 1.0, + rope_theta: float = 1e4, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding to a batch of queries/keys (stored as RaggedTensor). + + cos/sin values are computed on the fly inside the kernel. Position offsets + are provided per-segment via ``indptr`` and ``offsets``. + + Args: + q: Query ragged tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key ragged tensor, shape ``(nnz, num_k_heads, head_dim)``. + indptr: Indptr tensor, shape ``(batch_size + 1,)``. 
The i-th segment + spans ``q[indptr[i]:indptr[i+1]]``. + offsets: Relative position offsets per segment, shape ``(batch_size,)``. + inplace: If ``True``, apply RoPE in-place and return ``None``. + If ``False``, return new ``(q_rope, k_rope)`` tensors. + rotary_dim: Number of dimensions to apply RoPE to. ``None`` means + the entire ``head_dim``. + interleave: If ``True``, rotate even/odd dims (``[..., ::2]`` / + ``[..., 1::2]``). If ``False``, rotate first/second half dims. + rope_scale: Scaling factor for position indices. + rope_theta: Base frequency theta. + + Returns: + ``None`` when *inplace* is ``True``, otherwise a tuple + ``(q_rope, k_rope)`` of rotated tensors with the same shapes as + the inputs. + """ + if inplace: + flashinfer.rope.apply_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + return None + + return flashinfer.rope.apply_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + + +def apply_llama31_rope( + q: torch.Tensor, + k: torch.Tensor, + indptr: torch.Tensor, + offsets: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 8.0, + rope_theta: float = 5e5, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + old_context_len: int = 8192, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply Llama 3.1 style rotary embedding to a batch of queries/keys. + + This variant adjusts frequencies with ``low_freq_factor``, + ``high_freq_factor``, and ``old_context_len`` following the Llama 3.1 + RoPE recipe. cos/sin values are computed on the fly. + + Args: + q: Query ragged tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key ragged tensor, shape ``(nnz, num_k_heads, head_dim)``. + indptr: Indptr tensor, shape ``(batch_size + 1,)``. 
+ offsets: Relative position offsets per segment, shape ``(batch_size,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. ``None`` means + the entire ``head_dim``. + interleave: If ``True``, rotate even/odd dims; otherwise first/second + half dims. + rope_scale: Scaling factor for position indices (default ``8``). + rope_theta: Base frequency theta (default ``5e5``). + low_freq_factor: Low frequency factor for Llama 3.1 RoPE. + high_freq_factor: High frequency factor for Llama 3.1 RoPE. + old_context_len: Original context length for Llama 3.1 RoPE. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. + """ + if inplace: + flashinfer.rope.apply_llama31_rope_inplace( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + return None + + return flashinfer.rope.apply_llama31_rope( + q, + k, + indptr, + offsets, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + + +def apply_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 1.0, + rope_theta: float = 1e4, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding using explicit per-token position IDs. + + Unlike :func:`apply_rope` which derives positions from ``indptr`` / + ``offsets``, this function takes a flat ``pos_ids`` tensor that supplies + an explicit position for every token. + + Args: + q: Query tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key tensor, shape ``(nnz, num_k_heads, head_dim)``. 
+ pos_ids: Position indices, shape ``(nnz,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. + interleave: Interleaved layout flag. + rope_scale: Scaling factor for position indices. + rope_theta: Base frequency theta. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. + """ + if inplace: + flashinfer.rope.apply_rope_pos_ids_inplace( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + return None + + return flashinfer.rope.apply_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + ) + + +def apply_llama31_rope_pos_ids( + q: torch.Tensor, + k: torch.Tensor, + pos_ids: torch.Tensor, + inplace: bool = False, + rotary_dim: Optional[int] = None, + interleave: bool = False, + rope_scale: float = 8.0, + rope_theta: float = 5e5, + low_freq_factor: float = 1.0, + high_freq_factor: float = 4.0, + old_context_len: int = 8192, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply Llama 3.1 style RoPE using explicit per-token position IDs. + + Combines Llama 3.1 frequency adjustments with explicit ``pos_ids``. + + Args: + q: Query tensor, shape ``(nnz, num_q_heads, head_dim)``. + k: Key tensor, shape ``(nnz, num_k_heads, head_dim)``. + pos_ids: Position indices, shape ``(nnz,)``. + inplace: If ``True``, apply in-place and return ``None``. + rotary_dim: Number of dimensions to apply RoPE to. + interleave: Interleaved layout flag. + rope_scale: Scaling factor (default ``8``). + rope_theta: Base frequency theta (default ``5e5``). + low_freq_factor: Low frequency factor for Llama 3.1 RoPE. + high_freq_factor: High frequency factor for Llama 3.1 RoPE. + old_context_len: Original context length for Llama 3.1 RoPE. + + Returns: + ``None`` when *inplace* is ``True``, otherwise ``(q_rope, k_rope)``. 
+ """ + if inplace: + flashinfer.rope.apply_llama31_rope_pos_ids_inplace( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + return None + + return flashinfer.rope.apply_llama31_rope_pos_ids( + q, + k, + pos_ids, + rotary_dim=rotary_dim, + interleave=interleave, + rope_scale=rope_scale, + rope_theta=rope_theta, + low_freq_factor=low_freq_factor, + high_freq_factor=high_freq_factor, + old_context_len=old_context_len, + ) + + +def apply_rope_with_cos_sin_cache( + positions: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + head_size: int, + cos_sin_cache: torch.Tensor, + inplace: bool = False, + is_neox: bool = True, +) -> Optional[Tuple[torch.Tensor, torch.Tensor]]: + """Apply rotary embedding with precomputed cos/sin cache. + + Compatible with SGL/vLLM implementations. Note that ``query`` and ``key`` + use a **flattened** head layout ``(nnz, num_heads * head_size)`` instead + of the 3-D layout used by the other ``apply_rope*`` functions. + + Args: + positions: Position indices, shape ``(nnz,)``. + query: Query tensor, shape ``(nnz, num_q_heads * head_size)``. + key: Key tensor, shape ``(nnz, num_k_heads * head_size)``. + head_size: Size of each attention head. + cos_sin_cache: Precomputed cos/sin tensor, shape + ``(max_seq_len, rotary_dim)``. The first half of ``rotary_dim`` + stores cosine values, the second half stores sine values. + inplace: If ``True``, apply in-place and return ``None``. + is_neox: If ``True`` (default), use GPT-NeoX style (rotate + first/second half dims). If ``False``, use interleaved style + (rotate even/odd dims). + + Returns: + ``None`` when *inplace* is ``True``, otherwise + ``(query_out, key_out)`` with the same shapes as the inputs. 
+ """ + if inplace: + flashinfer.rope.apply_rope_with_cos_sin_cache_inplace( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox=is_neox, + ) + return None + + return flashinfer.rope.apply_rope_with_cos_sin_cache( + positions, + query, + key, + head_size, + cos_sin_cache, + is_neox=is_neox, + ) + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate the second half of the last dimension into the first half (neox-style).""" + half = x.shape[-1] // 2 + return torch.cat((-x[..., half:], x[..., :half]), dim=-1) + + +def apply_mrope( + q: torch.Tensor, + k: torch.Tensor, + positions: torch.Tensor, + cos_sin_cache: torch.Tensor, + mrope_section: List[int], + mrope_interleaved: bool = True, +) -> Tuple[torch.Tensor, torch.Tensor]: + """Apply multi-dimensional rotary position embedding (M-RoPE). + + Used by Qwen3-VL which assigns independent (t, h, w) position indices to + each token. For text tokens all three indices are the same sequential + value; for image tokens they follow the spatial grid layout. + + Args: + q: Query tensor, shape ``(T, num_q_heads, head_dim)``. + k: Key tensor, shape ``(T, num_kv_heads, head_dim)``. + positions: 3-D position IDs, shape ``(3, T)`` — rows are + ``(temporal, height, width)`` position indices. + cos_sin_cache: Precomputed cache, shape ``(max_pos, head_dim)``. + The first ``head_dim // 2`` columns are cosine values and the + remaining columns are sine values, each for frequencies + ``0, 1, ..., head_dim // 2 - 1``. + mrope_section: Three integers ``[s_t, s_h, s_w]`` that partition + the ``head_dim // 2`` rotary frequency dimensions among the + temporal, height, and width components. + ``sum(mrope_section)`` must equal ``head_dim // 2``. + mrope_interleaved: When ``True`` (Qwen3-VL default), uses the + interleaved layout where frequency dimensions are cycled + ``(t, h, w, t, h, w, ...)`` rather than grouped consecutively. + + Returns: + ``(q_rope, k_rope)`` with the same shapes as the inputs. 
+ """ + rotary_dim = cos_sin_cache.shape[-1] # = head_dim + half_dim = rotary_dim // 2 + + # Look up cos/sin for each of the 3 position dimensions. + # positions: [3, T] => cos_sin: [3, T, rotary_dim] + cos_sin = cos_sin_cache[positions] + cos = cos_sin[..., :half_dim] # [3, T, half_dim] + sin = cos_sin[..., half_dim:] # [3, T, half_dim] + + if mrope_interleaved: + # Interleaved layout (Qwen3-VL): within the first + # mrope_section[1]*3 frequency dims, indices cycle (t, h, w). + # Remaining dims (indices >= span) all use the temporal position. + # Matches SGLang's apply_interleaved_rope. + cos_merged = cos[0].clone() # start with temporal; shape [T, half_dim] + sin_merged = sin[0].clone() + span_h = mrope_section[1] * 3 + span_w = mrope_section[2] * 3 + cos_merged[..., 1:span_h:3] = cos[1, ..., 1:span_h:3] + cos_merged[..., 2:span_w:3] = cos[2, ..., 2:span_w:3] + sin_merged[..., 1:span_h:3] = sin[1, ..., 1:span_h:3] + sin_merged[..., 2:span_w:3] = sin[2, ..., 2:span_w:3] + else: + # Non-interleaved (Qwen2-VL style): consecutive frequency sections. 
+ cos_sects = cos.split(mrope_section, dim=-1) # list of [T, s_i] + sin_sects = sin.split(mrope_section, dim=-1) + # Section i picks its cos/sin from positions[i] + cos_merged = torch.cat( + [cos_sects[i][i] for i in range(3)], dim=-1 + ) # [T, half_dim] + sin_merged = torch.cat( + [sin_sects[i][i] for i in range(3)], dim=-1 + ) # [T, half_dim] + + # Expand to full rotary_dim for the neox-style rotation formula: + # q_rot = q * cos_full + rotate_half(q) * sin_full + cos_full = cos_merged.repeat(1, 2) # [T, rotary_dim] + sin_full = sin_merged.repeat(1, 2) # [T, rotary_dim] + cos_4d = cos_full.unsqueeze(1) # [T, 1, rotary_dim] -- broadcasts over heads + sin_4d = sin_full.unsqueeze(1) + + q_rot = q[..., :rotary_dim] * cos_4d + _rotate_half(q[..., :rotary_dim]) * sin_4d + k_rot = k[..., :rotary_dim] * cos_4d + _rotate_half(k[..., :rotary_dim]) * sin_4d + + q_out = ( + torch.cat([q_rot, q[..., rotary_dim:]], dim=-1) + if rotary_dim < q.shape[-1] + else q_rot + ) + k_out = ( + torch.cat([k_rot, k[..., rotary_dim:]], dim=-1) + if rotary_dim < k.shape[-1] + else k_rot + ) + return q_out, k_out diff --git a/pymllm/layers/sampling.py b/pymllm/layers/sampling.py new file mode 100644 index 000000000..26c769ffd --- /dev/null +++ b/pymllm/layers/sampling.py @@ -0,0 +1,776 @@ +"""Sampling operations with FlashInfer acceleration and PyTorch fallback. + +This module wraps all flashinfer.sampling APIs and provides pure-PyTorch +fallback implementations so that the rest of the codebase can import from +here without worrying about whether FlashInfer is installed. 
+""" + +from __future__ import annotations + +import logging +from typing import Optional, Tuple, Union + +import torch + +logger = logging.getLogger(__name__) + +try: + import flashinfer.sampling as _fi_sampling + + _HAS_FLASHINFER = True +except ImportError: + _HAS_FLASHINFER = False + logger.warning("flashinfer not found, falling back to PyTorch sampling kernels") + + +# --------------------------------------------------------------------------- +# Helper utilities (torch fallback) +# --------------------------------------------------------------------------- + + +def _resolve_indices( + data: torch.Tensor, indices: Optional[torch.Tensor] +) -> torch.Tensor: + """If *indices* is given, gather rows from *data* accordingly.""" + if indices is None: + return data + return data[indices.long()] + + +def _to_scalar_or_tensor( + value: Union[torch.Tensor, float, int], + batch_size: int, + device: torch.device, +) -> torch.Tensor: + """Broadcast a scalar or per-batch tensor to shape ``(batch_size,)``.""" + if isinstance(value, (int, float)): + return torch.full((batch_size,), value, device=device, dtype=torch.float32) + return value.to(device=device, dtype=torch.float32) + + +# --------------------------------------------------------------------------- +# softmax +# --------------------------------------------------------------------------- + + +def softmax( + logits: torch.Tensor, + temperature: Optional[Union[torch.Tensor, float]] = None, + enable_pdl: Optional[bool] = None, +) -> torch.Tensor: + """Safe softmax with optional temperature scaling. + + Parameters + ---------- + logits : torch.Tensor + Shape ``(batch_size, num_classes)``. + temperature : Optional[Union[torch.Tensor, float]] + Scalar or per-request ``(batch_size,)`` temperature. + enable_pdl : Optional[bool] + FlashInfer PDL flag (ignored in fallback). + + Returns + ------- + torch.Tensor + Probabilities with the same shape as *logits*. 
+ """ + # Clamp temperature to avoid division by zero (temperature=0 → greedy). + # Replace 0 with 1 here; the caller (ModelRunner.sample) handles + # temperature=0 via argmax before reaching this path. + if temperature is not None: + if isinstance(temperature, torch.Tensor): + temperature = temperature.clamp(min=1e-6) + elif temperature < 1e-6: + temperature = 1.0 # effectively no scaling; caller uses argmax + + if _HAS_FLASHINFER: + return _fi_sampling.softmax( + logits, temperature=temperature, enable_pdl=enable_pdl + ) + + if temperature is not None: + if isinstance(temperature, (int, float)): + logits = logits / temperature + else: + logits = logits / temperature.unsqueeze(-1) + return torch.softmax(logits, dim=-1) + + +# --------------------------------------------------------------------------- +# sampling_from_probs +# --------------------------------------------------------------------------- + + +def sampling_from_probs( + probs: torch.Tensor, + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Category sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)`` or ``(unique_batch_size, num_classes)`` + when *indices* is provided. + indices : Optional[torch.Tensor] + Maps each output to a row in *probs*. + deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.sampling_from_probs( + probs, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices) + samples = torch.multinomial(p.float(), num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# sampling_from_logits +# --------------------------------------------------------------------------- + + +def sampling_from_logits( + logits: torch.Tensor, + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Category sampling from logits (applies softmax internally). + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.sampling_from_logits( + logits, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + probs = torch.softmax(logits.float(), dim=-1) + return sampling_from_probs( + probs, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + ) + + +# --------------------------------------------------------------------------- +# top_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_p_sampling_from_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-p (nucleus) sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_p : Union[torch.Tensor, float] + Top-p threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_p_sampling_from_probs( + probs, + top_p, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + renormed = _torch_top_p_renorm_probs(p, top_p) + samples = torch.multinomial(renormed, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_k_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_k_sampling_from_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_sampling_from_probs( + probs, + top_k, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + renormed = _torch_top_k_renorm_probs(p, top_k) + samples = torch.multinomial(renormed, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# min_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def min_p_sampling_from_probs( + probs: torch.Tensor, + min_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Min-p sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + min_p : Union[torch.Tensor, float] + Min-p threshold. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.min_p_sampling_from_probs( + probs, + min_p, + indices=indices, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + batch_size = p.shape[0] + min_p_t = _to_scalar_or_tensor(min_p, batch_size, p.device) + # min-p: keep tokens whose probability >= min_p * max_prob + max_probs = p.max(dim=-1, keepdim=True).values # (B,1) + threshold = min_p_t.unsqueeze(-1) * max_probs # (B,1) + mask = p < threshold + filtered = p.clone() + filtered[mask] = 0.0 + # renormalize + sums = filtered.sum(dim=-1, keepdim=True) + sums = sums.clamp(min=1e-8) + filtered = filtered / sums + samples = torch.multinomial(filtered, num_samples=1, generator=generator).squeeze( + -1 + ) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_k_top_p_sampling_from_logits +# --------------------------------------------------------------------------- + + +def top_k_top_p_sampling_from_logits( + logits: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k + top-p sampling from pre-softmax logits. + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + top_p : Union[torch.Tensor, float] + filter_apply_order : str + ``"top_k_first"`` or ``"joint"``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_top_p_sampling_from_logits( + logits, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + probs = torch.softmax(logits.float(), dim=-1) + return top_k_top_p_sampling_from_probs( + probs, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + ) + + +# --------------------------------------------------------------------------- +# top_k_top_p_sampling_from_probs +# --------------------------------------------------------------------------- + + +def top_k_top_p_sampling_from_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], + top_p: Union[torch.Tensor, float], + indices: Optional[torch.Tensor] = None, + filter_apply_order: str = "top_k_first", + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + check_nan: bool = False, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> torch.Tensor: + """Top-k + top-p sampling from probabilities. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + top_p : Union[torch.Tensor, float] + filter_apply_order : str + ``"top_k_first"`` or ``"joint"``. + indices, deterministic, generator, check_nan, seed, offset + See FlashInfer docs. + + Returns + ------- + torch.Tensor + Sampled token ids, shape ``(batch_size,)``. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_top_p_sampling_from_probs( + probs, + top_k, + top_p, + indices=indices, + filter_apply_order=filter_apply_order, + deterministic=deterministic, + generator=generator, + check_nan=check_nan, + seed=seed, + offset=offset, + ) + + p = _resolve_indices(probs, indices).float() + if filter_apply_order == "top_k_first": + p = _torch_top_k_renorm_probs(p, top_k) + p = _torch_top_p_renorm_probs(p, top_p) + else: + # joint: apply both filters simultaneously + p = _torch_top_k_renorm_probs(p, top_k) + p = _torch_top_p_renorm_probs(p, top_p) + samples = torch.multinomial(p, num_samples=1, generator=generator).squeeze(-1) + return samples.to(torch.int32) + + +# --------------------------------------------------------------------------- +# top_p_renorm_probs +# --------------------------------------------------------------------------- + + +def top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + """Renormalize probabilities by top-p thresholding. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_p : Union[torch.Tensor, float] + Top-p threshold in ``(0, 1)``. + + Returns + ------- + torch.Tensor + Renormalized probabilities. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_p_renorm_probs(probs, top_p) + + return _torch_top_p_renorm_probs(probs.float(), top_p).to(probs.dtype) + + +def _torch_top_p_renorm_probs( + probs: torch.Tensor, + top_p: Union[torch.Tensor, float], +) -> torch.Tensor: + """Pure-torch top-p renormalization (operates on float32).""" + sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True) + cumsum = torch.cumsum(sorted_probs, dim=-1) + + if isinstance(top_p, (int, float)): + mask = cumsum - sorted_probs > top_p + else: + top_p_t = top_p.unsqueeze(-1) + mask = cumsum - sorted_probs > top_p_t + + sorted_probs[mask] = 0.0 + # scatter back + result = torch.zeros_like(probs) + result.scatter_(1, sorted_indices, sorted_probs) + # renormalize + sums = result.sum(dim=-1, keepdim=True).clamp(min=1e-8) + return result / sums + + +# --------------------------------------------------------------------------- +# top_k_renorm_probs +# --------------------------------------------------------------------------- + + +def top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Renormalize probabilities by top-k thresholding. + + Parameters + ---------- + probs : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + + Returns + ------- + torch.Tensor + Renormalized probabilities. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_renorm_probs(probs, top_k) + + return _torch_top_k_renorm_probs(probs.float(), top_k).to(probs.dtype) + + +def _torch_top_k_renorm_probs( + probs: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Pure-torch top-k renormalization (operates on float32).""" + if isinstance(top_k, int): + # uniform top_k across batch + topk_vals, _ = torch.topk(probs, top_k, dim=-1) + threshold = topk_vals[:, -1:] # (B, 1) + else: + # per-request top_k: use sorting + sorted_probs, _ = torch.sort(probs, dim=-1, descending=True) + # gather the k-th value for each row + k_indices = (top_k.long() - 1).unsqueeze(-1) # (B, 1) + threshold = sorted_probs.gather(1, k_indices) # (B, 1) + + mask = probs < threshold + filtered = probs.clone() + filtered[mask] = 0.0 + sums = filtered.sum(dim=-1, keepdim=True).clamp(min=1e-8) + return filtered / sums + + +# --------------------------------------------------------------------------- +# top_k_mask_logits +# --------------------------------------------------------------------------- + + +def top_k_mask_logits( + logits: torch.Tensor, + top_k: Union[torch.Tensor, int], +) -> torch.Tensor: + """Mask logits by top-k thresholding (set non-top-k to -inf). + + Parameters + ---------- + logits : torch.Tensor + ``(batch_size, num_classes)``. + top_k : Union[torch.Tensor, int] + Top-k threshold. + + Returns + ------- + torch.Tensor + Masked logits with the same shape and dtype. 
+ """ + if _HAS_FLASHINFER: + return _fi_sampling.top_k_mask_logits(logits, top_k) + + if isinstance(top_k, int): + topk_vals, _ = torch.topk(logits, top_k, dim=-1) + threshold = topk_vals[:, -1:] + else: + sorted_logits, _ = torch.sort(logits, dim=-1, descending=True) + k_indices = (top_k.long() - 1).unsqueeze(-1) + threshold = sorted_logits.gather(1, k_indices) + + mask = logits < threshold + result = logits.clone() + result[mask] = float("-inf") + return result + + +# --------------------------------------------------------------------------- +# chain_speculative_sampling +# --------------------------------------------------------------------------- + + +def chain_speculative_sampling( + draft_probs: torch.Tensor, + draft_token_ids: torch.Tensor, + target_probs: torch.Tensor, + maybe_output_accepted_token_num: Optional[torch.Tensor] = None, + maybe_output_emitted_draft_token_num: Optional[torch.Tensor] = None, + deterministic: bool = True, + generator: Optional[torch.Generator] = None, + seed: Optional[int] = None, + offset: Optional[int] = None, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Speculative sampling for sequence generation. + + Parameters + ---------- + draft_probs : torch.Tensor + ``(batch_size, num_speculate_tokens, vocab_size)``. + draft_token_ids : torch.Tensor + ``(batch_size, num_speculate_tokens)``. + target_probs : torch.Tensor + ``(batch_size, num_speculate_tokens + 1, vocab_size)``. + maybe_output_accepted_token_num : Optional[torch.Tensor] + If provided, accepted counts are added in-place. + maybe_output_emitted_draft_token_num : Optional[torch.Tensor] + If provided, emitted counts are added in-place. + deterministic, generator, seed, offset + See FlashInfer docs. + + Returns + ------- + output_token_ids : torch.Tensor + ``(batch_size, num_speculate_tokens + 1)``, rejected slots padded with -1. + output_accepted_token_num : torch.Tensor + ``(batch_size,)``. + output_emitted_draft_token_num : torch.Tensor + ``(batch_size,)``. 
+    """
+    if _HAS_FLASHINFER:
+        return _fi_sampling.chain_speculative_sampling(
+            draft_probs,
+            draft_token_ids,
+            target_probs,
+            maybe_output_accepted_token_num=maybe_output_accepted_token_num,
+            maybe_output_emitted_draft_token_num=maybe_output_emitted_draft_token_num,
+            deterministic=deterministic,
+            generator=generator,
+            seed=seed,
+            offset=offset,
+        )
+
+    return _torch_chain_speculative_sampling(
+        draft_probs,
+        draft_token_ids,
+        target_probs,
+        maybe_output_accepted_token_num,
+        maybe_output_emitted_draft_token_num,
+        generator,
+    )
+
+
+def _torch_chain_speculative_sampling(
+    draft_probs: torch.Tensor,
+    draft_token_ids: torch.Tensor,
+    target_probs: torch.Tensor,
+    maybe_output_accepted_token_num: Optional[torch.Tensor],
+    maybe_output_emitted_draft_token_num: Optional[torch.Tensor],
+    generator: Optional[torch.Generator],
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Pure-torch chain speculative sampling.
+
+    Implements the rejection-sampling algorithm from
+    "Accelerating Large Language Model Decoding with Speculative Sampling"
+    (Leviathan et al., 2023).
+    """
+    # draft_probs: [batch, num_spec, vocab]; target_probs has at least
+    # num_spec + 1 positions (the extra one feeds the bonus token below).
+    batch_size, num_spec, vocab_size = draft_probs.shape
+    device = draft_probs.device
+
+    # -1 marks "no token emitted at this position".
+    output_ids = torch.full(
+        (batch_size, num_spec + 1), -1, dtype=torch.int32, device=device
+    )
+    accepted_count = torch.zeros(batch_size, dtype=torch.int32, device=device)
+    emitted_count = torch.zeros(batch_size, dtype=torch.int32, device=device)
+
+    # NOTE(review): torch.rand(..., device=device) with a CPU torch.Generator
+    # raises when device is CUDA — confirm callers pass a matching generator.
+    for b in range(batch_size):
+        all_accepted = True
+        for t in range(num_spec):
+            draft_tok = draft_token_ids[b, t].item()
+            p_draft = draft_probs[b, t, draft_tok].item()
+            p_target = target_probs[b, t, draft_tok].item()
+
+            # independent acceptance check (for the metric)
+            # NOTE(review): this draws a coin independent of the chain
+            # decision below, so accepted_count can disagree with the tokens
+            # actually emitted — confirm this matches the intended metric.
+            if p_target >= p_draft:
+                accepted_count[b] += 1
+            else:
+                r = torch.rand(1, generator=generator, device=device).item()
+                if r < p_target / max(p_draft, 1e-10):
+                    accepted_count[b] += 1
+
+            # sequential chain: accept / reject
+            if all_accepted:
+                r = torch.rand(1, generator=generator, device=device).item()
+                if r < min(1.0, p_target / max(p_draft, 1e-10)):
+                    output_ids[b, t] = draft_tok
+                    emitted_count[b] += 1
+                else:
+                    # reject: sample from max(0, p_target - p_draft)
+                    diff = target_probs[b, t].float() - draft_probs[b, t].float()
+                    diff = torch.clamp(diff, min=0.0)
+                    dsum = diff.sum()
+                    if dsum > 1e-8:
+                        diff = diff / dsum
+                    else:
+                        # Degenerate case: distributions identical — fall back
+                        # to the (renormalised) target distribution.
+                        diff = target_probs[b, t].float()
+                        diff = diff / diff.sum().clamp(min=1e-8)
+                    resampled = torch.multinomial(
+                        diff.unsqueeze(0), num_samples=1, generator=generator
+                    ).item()
+                    output_ids[b, t] = resampled
+                    # NOTE(review): the resampled token is counted as emitted —
+                    # confirm against FlashInfer's "emitted draft token" count.
+                    emitted_count[b] += 1
+                    all_accepted = False
+
+        # bonus token (sampled from target at position after last emitted)
+        if all_accepted:
+            pos = num_spec
+            bonus_probs = target_probs[b, pos].float()
+            bonus_probs = bonus_probs / bonus_probs.sum().clamp(min=1e-8)
+            bonus = torch.multinomial(
+                bonus_probs.unsqueeze(0), num_samples=1, generator=generator
+            ).item()
+            output_ids[b, num_spec] = bonus
+
+    # Optional in-place accumulators mirror the FlashInfer signature.
+    if maybe_output_accepted_token_num is not None:
+        maybe_output_accepted_token_num.add_(accepted_count)
+    if maybe_output_emitted_draft_token_num is not None:
+        maybe_output_emitted_draft_token_num.add_(emitted_count)
+
+    return output_ids, accepted_count, emitted_count
+
+
+# ---------------------------------------------------------------------------
+# Aliases (FlashInfer also exposes these)
+# ---------------------------------------------------------------------------
+top_p_renorm_prob = top_p_renorm_probs
+top_k_renorm_prob = top_k_renorm_probs
diff --git a/pymllm/layers/utils.py b/pymllm/layers/utils.py
new file mode 100644
index 000000000..0dcbd1ac0
--- /dev/null
+++ b/pymllm/layers/utils.py
@@ -0,0 +1,45 @@
+"""Utility functions for layers."""
+
+from typing import Any, Dict
+
+import torch
+
+
+def set_weight_attrs(
+    weight: torch.Tensor,
+    weight_attrs: Dict[str, Any] | None,
+) -> None:
+    """Set attributes on a weight tensor.
+
+    This method is used to set attributes on a weight tensor. This method
+    will not overwrite existing attributes.
+
+    Args:
+        weight: The weight tensor or parameter.
+        weight_attrs: A dictionary of attributes to set on the weight tensor.
+            Common attributes include:
+            - output_dim: The dimension along which to shard the weight (typically 0 for output dim)
+            - input_dim: The input dimension (typically 1 for input dim)
+            - weight_loader: A callable to load weights into this parameter
+            - packed_dim: The dimension along which the weight is packed (for quantization)
+            - packed_factor: The packing factor (for quantization)
+
+    Example:
+        >>> weight = nn.Parameter(torch.empty(100, 64))
+        >>> set_weight_attrs(weight, {
+        ...     "output_dim": 0,
+        ...     "input_dim": 1,
+        ...     "weight_loader": my_loader_func,
+        ... })
+    """
+    if weight_attrs is None:
+        return
+
+    for key, value in weight_attrs.items():
+        if hasattr(weight, key):
+            raise AttributeError(
+                f"Overwriting existing tensor attribute: {key}. "
+                f"Existing value: {getattr(weight, key)}, "
+                f"New value: {value}"
+            )
+        setattr(weight, key, value)
diff --git a/pymllm/mem_cache/__init__.py b/pymllm/mem_cache/__init__.py
new file mode 100644
index 000000000..cc449e426
--- /dev/null
+++ b/pymllm/mem_cache/__init__.py
@@ -0,0 +1,46 @@
+from pymllm.mem_cache.base_prefix_cache import (
+    BasePrefixCache,
+    EvictResult,
+    InsertResult,
+    MatchResult,
+    RadixKey,
+    hash_bytes,
+    hash_to_int64,
+    hash_token_ids,
+)
+from pymllm.mem_cache.chunk_cache import ChunkCache
+from pymllm.mem_cache.mamba_radix_cache import MambaRadixCache, MambaTreeNode
+from pymllm.mem_cache.memory_pool import (
+    KVPool,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
+    make_full_attention_net_mem_pool,
+    make_req_to_token_pool,
+)
+from pymllm.mem_cache.radix_cache import RadixCache, TreeNode
+
+__all__ = [
+    # base_prefix_cache
+    "BasePrefixCache",
+    "RadixKey",
+    "MatchResult",
+    "InsertResult",
+    "EvictResult",
+    "hash_token_ids",
+    "hash_to_int64",
+    "hash_bytes",
+    # radix_cache
+    "RadixCache",
+    "TreeNode",
+    # chunk_cache
+    "ChunkCache",
+    # mamba_radix_cache
+    "MambaRadixCache",
+    "MambaTreeNode",
+    # memory_pool
+    "KVPool",
+    "TokenToKVPoolAllocator",
+    "ReqToTokenPool",
+    "make_full_attention_net_mem_pool",
+    "make_req_to_token_pool",
+]
diff --git a/pymllm/mem_cache/base_prefix_cache.py b/pymllm/mem_cache/base_prefix_cache.py
new file mode 100644
index 000000000..a49355d6e
--- /dev/null
+++ b/pymllm/mem_cache/base_prefix_cache.py
@@ -0,0 +1,206 @@
+"""Abstract base class and shared data types for prefix cache implementations.
+
+All concrete caches (:class:`RadixCache`, :class:`ChunkCache`,
+:class:`MambaRadixCache`) inherit from :class:`BasePrefixCache` and share
+the data classes defined here.
+"""
+
+from __future__ import annotations
+
+import hashlib
+from abc import ABC, abstractmethod
+# NOTE(review): `field` is imported but not used below — candidate for removal.
+from dataclasses import dataclass, field
+from typing import Any, Iterator, List, Optional, Tuple, Union
+
+import torch
+
+
+# ======================================================================
+# Hashing utilities
+# ======================================================================
+
+
+def hash_token_ids(
+    token_ids: List[Union[int, Tuple[int, ...]]],
+    prior_hash: Optional[str] = None,
+) -> str:
+    """SHA-256 hash of a token-id page with optional chain-hash.
+
+    Each token is encoded as a 4-byte little-endian unsigned integer;
+    tuples (bigram / EAGLE) hash each element in order. When *prior_hash*
+    is supplied the digest is seeded with the raw bytes of the previous
+    hash, making the result position-aware.
+    """
+    hasher = hashlib.sha256()
+    if prior_hash:
+        hasher.update(bytes.fromhex(prior_hash))
+    for t in token_ids:
+        if isinstance(t, tuple):
+            for elem in t:
+                # Negative or > 32-bit ids would raise OverflowError here;
+                # token ids are assumed non-negative and < 2**32.
+                hasher.update(elem.to_bytes(4, byteorder="little", signed=False))
+        else:
+            hasher.update(t.to_bytes(4, byteorder="little", signed=False))
+    return hasher.hexdigest()
+
+
+def hash_to_int64(hex_str: str) -> int:
+    """Convert a hex digest to a signed 64-bit integer (first 16 hex chars)."""
+    val = int(hex_str[:16], 16)
+    # Two's-complement fold into the signed int64 range.
+    return val - (1 << 64) if val >= (1 << 63) else val
+
+
+def hash_bytes(data: bytes) -> int:
+    """SHA-256 -> unsigned 64-bit int. Useful for multimodal embedding keys."""
+    return int.from_bytes(hashlib.sha256(data).digest()[:8], "big", signed=False)
+
+
+# ======================================================================
+# Compound lookup key
+# ======================================================================
+
+
+class RadixKey:
+    """Compound lookup key: token-id sequence + optional namespace tag.
+
+    ``extra_key`` isolates independent namespaces so that sequences with
+    identical leading tokens but different adapters / LoRA ids / multimodal
+    context hashes never share prefix nodes.
+    """
+
+    __slots__ = ("token_ids", "extra_key")
+
+    def __init__(
+        self,
+        token_ids: List[Union[int, Tuple[int, ...]]],
+        extra_key: Optional[str] = None,
+    ):
+        self.token_ids = token_ids
+        self.extra_key = extra_key
+
+    def __len__(self) -> int:
+        return len(self.token_ids)
+
+    def __iter__(self) -> Iterator:
+        return iter(self.token_ids)
+
+    def __getitem__(self, idx: Union[int, slice]) -> RadixKey:
+        # Slicing / indexing preserves the namespace tag.
+        if isinstance(idx, slice):
+            return RadixKey(self.token_ids[idx], self.extra_key)
+        return RadixKey([self.token_ids[idx]], self.extra_key)
+
+    def __repr__(self) -> str:
+        preview = self.token_ids[:10]
+        tail = "..." if len(self.token_ids) > 10 else ""
+        return f"RadixKey(extra={self.extra_key!r}, toks={preview}{tail})"
+
+
+# ======================================================================
+# Result data classes
+# ======================================================================
+
+
+@dataclass
+class MatchResult:
+    """Returned by :meth:`BasePrefixCache.match_prefix`."""
+
+    indices: torch.Tensor
+    last_node: Any = None
+    prefix_len: int = 0
+    # SSM / Mamba support
+    mamba_branching_seqlen: Optional[int] = None
+
+
+@dataclass
+class InsertResult:
+    """Returned by :meth:`BasePrefixCache.insert`."""
+
+    prefix_len: int = 0
+    last_node: Any = None
+    # SSM / Mamba support: True when mamba state already existed in tree
+    mamba_exist: bool = False
+
+
+@dataclass
+class EvictResult:
+    """Returned by :meth:`BasePrefixCache.evict`."""
+
+    full_evicted: int = 0
+    swa_evicted: int = 0
+    mamba_evicted: int = 0
+
+
+# ======================================================================
+# Abstract base class
+# ======================================================================
+
+
+class BasePrefixCache(ABC):
+    """Abstract interface for all prefix cache implementations.
+
+    Concrete implementations:
+
+    * :class:`~pymllm.mem_cache.radix_cache.RadixCache` -- radix-tree with
+      SWA tombstone support
+    * :class:`~pymllm.mem_cache.chunk_cache.ChunkCache` -- no-op fallback
+      (``disable_radix_cache=True``)
+    * :class:`~pymllm.mem_cache.mamba_radix_cache.MambaRadixCache` -- radix-tree
+      with independent Mamba/SSM state tracking
+    """
+
+    @abstractmethod
+    def reset(self) -> None:
+        """Clear all cached state and re-initialise."""
+        ...
+
+    @abstractmethod
+    def match_prefix(self, key: RadixKey) -> MatchResult:
+        """Find the longest cached prefix of *key*."""
+        ...
+
+    @abstractmethod
+    def insert(
+        self,
+        key: RadixKey,
+        value: Optional[torch.Tensor] = None,
+        **kwargs: Any,
+    ) -> InsertResult:
+        """Insert *key*/*value* into the cache."""
+        ...
+
+    @abstractmethod
+    def evict(self, num_tokens: int, swa_num_tokens: int = 0) -> EvictResult:
+        """Evict tokens to free memory."""
+        ...
+
+    @abstractmethod
+    def inc_lock_ref(self, node: Any) -> Optional[Any]:
+        """Lock *node* (and ancestors) to prevent eviction.
+
+        Returns an opaque token (e.g. ``swa_boundary_id``) that must be
+        passed back to :meth:`dec_lock_ref`.
+        """
+        ...
+
+    @abstractmethod
+    def dec_lock_ref(self, node: Any, **kwargs: Any) -> None:
+        """Unlock *node* (and ancestors)."""
+        ...
+
+    # ------------------------------------------------------------------
+    # Size queries (default implementations return 0)
+    # ------------------------------------------------------------------
+
+    def evictable_size(self) -> int:
+        return 0
+
+    def swa_evictable_size(self) -> int:
+        return 0
+
+    def protected_size(self) -> int:
+        return 0
+
+    def swa_protected_size(self) -> int:
+        return 0
+
+    def total_size(self) -> int:
+        return 0
diff --git a/pymllm/mem_cache/chunk_cache.py b/pymllm/mem_cache/chunk_cache.py
new file mode 100644
index 000000000..c53b2b69e
--- /dev/null
+++ b/pymllm/mem_cache/chunk_cache.py
@@ -0,0 +1,74 @@
+"""No-op prefix cache used when ``disable_radix_cache=True``.
+
+Every request is fully computed from scratch -- no prefix sharing, no
+tree structure, no eviction logic. This is the simplest possible
+:class:`~pymllm.mem_cache.base_prefix_cache.BasePrefixCache` implementation.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import torch
+
+from pymllm.mem_cache.base_prefix_cache import (
+    BasePrefixCache,
+    EvictResult,
+    InsertResult,
+    MatchResult,
+    RadixKey,
+)
+
+
+class ChunkCache(BasePrefixCache):
+    """No-op prefix cache: no prefix sharing, no eviction.
+
+    When the radix cache is disabled, this class replaces it so that
+    the rest of the system can call the same interface without branching.
+
+    Parameters
+    ----------
+    token_to_kv_pool_allocator:
+        Pool allocator used to free KV indices on request completion.
+    device:
+        Device for empty tensors returned by :meth:`match_prefix`.
+    """
+
+    def __init__(
+        self,
+        token_to_kv_pool_allocator: Any = None,
+        device: torch.device = torch.device("cpu"),
+    ):
+        # NOTE(review): `self.pool` is stored but never used in this class —
+        # presumably kept for interface parity with RadixCache; confirm.
+        self.pool = token_to_kv_pool_allocator
+        self.device = device
+
+    def reset(self) -> None:
+        pass
+
+    def match_prefix(self, key: RadixKey) -> MatchResult:
+        """Always returns an empty match (no prefix sharing)."""
+        return MatchResult(
+            indices=torch.empty(0, dtype=torch.int64, device=self.device),
+            last_node=None,
+        )
+
+    def insert(
+        self,
+        key: RadixKey,
+        value: Optional[torch.Tensor] = None,
+        **kwargs: Any,
+    ) -> InsertResult:
+        """No-op: nothing is cached."""
+        return InsertResult()
+
+    def evict(self, num_tokens: int, swa_num_tokens: int = 0) -> EvictResult:
+        """No-op: nothing to evict."""
+        return EvictResult()
+
+    def inc_lock_ref(self, node: Any) -> Optional[Any]:
+        """No-op: nothing to lock."""
+        return None
+
+    def dec_lock_ref(self, node: Any, **kwargs: Any) -> None:
+        """No-op: nothing to unlock."""
+        pass
diff --git a/pymllm/mem_cache/mamba_radix_cache.py b/pymllm/mem_cache/mamba_radix_cache.py
new file mode 100644
index 000000000..bee8027e6
--- /dev/null
+++ b/pymllm/mem_cache/mamba_radix_cache.py
@@ -0,0 +1,653 @@
+"""Radix-tree KV cache with independent Mamba/SSM state tracking.
+
+Extends :class:`~pymllm.mem_cache.radix_cache.RadixCache` with dual-tracked
+state for hybrid models that combine full attention layers and SSM (Mamba /
+GDN) layers. Each tree node stores both:
+
+- ``value``: KV-pool indices for full-attention layers
+- ``mamba_value``: state-pool indices for SSM layers
+
+The two pools have **independent reference counting and LRU eviction**:
+Mamba state can be evicted more aggressively than full KV cache.
+
+Reference: sglang ``MambaRadixCache``.
+"""
+
+from __future__ import annotations
+
+# NOTE(review): `heapq` and `_BaseTreeNode` below are imported but unused in
+# this module — candidates for removal.
+import heapq
+import logging
+import time
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Tuple
+
+import torch
+
+from pymllm.mem_cache.base_prefix_cache import (
+    BasePrefixCache,
+    EvictResult,
+    InsertResult,
+    MatchResult,
+    RadixKey,
+)
+from pymllm.mem_cache.radix_cache import (
+    TreeNode as _BaseTreeNode,
+    _child_key,
+    _key_match,
+    _next_node_id,
+)
+
+logger = logging.getLogger(__name__)
+
+
+# ======================================================================
+# Mamba-aware tree node
+# ======================================================================
+
+
+class MambaTreeNode:
+    """Tree node with dual KV + Mamba state tracking.
+
+    Invariant: ``full_lock_ref >= mamba_lock_ref``. If Mamba state is
+    locked, full KV must also be locked; full KV alone can be locked
+    without locking Mamba state.
+    """
+
+    __slots__ = (
+        "children",
+        "parent",
+        "key",
+        "value",
+        "mamba_value",
+        "full_lock_ref",
+        "mamba_lock_ref",
+        "last_access_time",
+        "hit_count",
+        "id",
+        # LRU doubly-linked list pointers (full)
+        "prev",
+        "next",
+        # LRU doubly-linked list pointers (mamba)
+        "mamba_prev",
+        "mamba_next",
+    )
+
+    def __init__(self) -> None:
+        self.children: Dict[Any, MambaTreeNode] = defaultdict(MambaTreeNode)
+        self.parent: Optional[MambaTreeNode] = None
+        self.key: Optional[RadixKey] = None
+        self.value: Optional[torch.Tensor] = None
+        self.mamba_value: Optional[torch.Tensor] = None
+
+        self.full_lock_ref: int = 0
+        self.mamba_lock_ref: int = 0
+
+        self.last_access_time: float = time.monotonic()
+        self.hit_count: int = 0
+        self.id: int = _next_node_id()
+
+        # LRU list pointers
+        self.prev: Optional[MambaTreeNode] = None
+        self.next: Optional[MambaTreeNode] = None
+        self.mamba_prev: Optional[MambaTreeNode] = None
+        self.mamba_next: Optional[MambaTreeNode] = None
+
+    @property
+    def evicted(self) -> bool:
+        return self.value is None
+
+    @property
+    def mamba_tombstone(self) -> bool:
+        """Node has full KV but Mamba state was evicted."""
+        return self.value is not None and self.mamba_value is None
+
+    def __lt__(self, other: MambaTreeNode) -> bool:
+        return self.last_access_time < other.last_access_time
+
+
+# ======================================================================
+# Doubly-linked LRU list
+# ======================================================================
+
+
+class LRUList:
+    """Intrusive doubly-linked list for LRU ordering.
+
+    Supports two modes via *mamba* flag: uses ``prev``/``next`` or
+    ``mamba_prev``/``mamba_next`` pointers on :class:`MambaTreeNode`.
+    """
+
+    def __init__(self, mamba: bool = False):
+        self.mamba = mamba
+        # Attribute names are resolved dynamically via getattr/setattr so the
+        # same list code serves both pointer families.
+        if mamba:
+            self._prv = "mamba_prev"
+            self._nxt = "mamba_next"
+            self._lock = "mamba_lock_ref"
+        else:
+            self._prv = "prev"
+            self._nxt = "next"
+            self._lock = "full_lock_ref"
+
+        # Sentinel head (MRU side) and tail (LRU side)
+        self.head = MambaTreeNode()
+        self.tail = MambaTreeNode()
+        setattr(self.head, self._nxt, self.tail)
+        setattr(self.tail, self._prv, self.head)
+        self._cache: Dict[int, MambaTreeNode] = {}
+
+    def __len__(self) -> int:
+        return len(self._cache)
+
+    def __contains__(self, node: Optional[MambaTreeNode]) -> bool:
+        return node is not None and node.id in self._cache
+
+    # -- Mutations --------------------------------------------------------
+
+    def insert_mru(self, node: MambaTreeNode) -> None:
+        """Insert *node* at the MRU (head) position."""
+        self._cache[node.id] = node
+        self._add_after(self.head, node)
+
+    def remove(self, node: MambaTreeNode) -> None:
+        """Remove *node* from the list."""
+        self._cache.pop(node.id, None)
+        self._unlink(node)
+
+    def touch_mru(self, node: MambaTreeNode) -> None:
+        """Move an existing *node* to the MRU position."""
+        if node.id not in self._cache:
+            return
+        self._unlink(node)
+        self._add_after(self.head, node)
+
+    def touch_node_and_parents_mru(
+        self, node: MambaTreeNode, root: MambaTreeNode
+    ) -> None:
+        """Move *node* and all ancestors up to *root* to MRU.
+
+        Child is more recently used than parent.
+        """
+        prev = self.head
+        cur = node
+        while cur != root:
+            if cur.id in self._cache:
+                # In mamba mode, tombstoned nodes (no state) keep their
+                # position; only stateful nodes are reordered.
+                if self.mamba and cur.mamba_value is None:
+                    cur = cur.parent
+                    continue
+                self._unlink(cur)
+                self._add_after(prev, cur)
+                prev = cur
+            cur = cur.parent
+
+    # -- Queries ----------------------------------------------------------
+
+    def get_lru_leaf_unlocked(self) -> Optional[MambaTreeNode]:
+        """Return the LRU leaf node with lock_ref == 0, or ``None``."""
+        x = getattr(self.tail, self._prv)
+        while x != self.head:
+            if getattr(x, self._lock) == 0 and len(x.children) == 0:
+                return x
+            x = getattr(x, self._prv)
+        return None
+
+    def get_lru_unlocked(self) -> Optional[MambaTreeNode]:
+        """Return the LRU node with lock_ref == 0, or ``None``."""
+        x = getattr(self.tail, self._prv)
+        while x != self.head:
+            if getattr(x, self._lock) == 0:
+                return x
+            x = getattr(x, self._prv)
+        return None
+
+    # -- Internal ---------------------------------------------------------
+
+    def _add_after(self, old: MambaTreeNode, new: MambaTreeNode) -> None:
+        nxt = getattr(old, self._nxt)
+        setattr(new, self._prv, old)
+        setattr(new, self._nxt, nxt)
+        setattr(nxt, self._prv, new)
+        setattr(old, self._nxt, new)
+
+    def _unlink(self, node: MambaTreeNode) -> None:
+        prv = getattr(node, self._prv)
+        nxt = getattr(node, self._nxt)
+        if prv is not None:
+            setattr(prv, self._nxt, nxt)
+        if nxt is not None:
+            setattr(nxt, self._prv, prv)
+        setattr(node, self._prv, None)
+        setattr(node, self._nxt, None)
+
+
+# ======================================================================
+# MambaRadixCache
+# ======================================================================
+
+
+class MambaRadixCache(BasePrefixCache):
+    """Radix tree with independent Mamba/SSM state tracking.
+
+    Parameters
+    ----------
+    page_size:
+        Number of tokens per KV-pool page.
+    token_to_kv_pool_allocator:
+        Pool allocator for full-attention KV indices.
+    mamba_pool:
+        Pool object for Mamba/SSM state. Must support ``alloc_track_slot()``,
+        ``free_track_slot(slot)``, ``copy_states(src, dst)``.
+    on_node_evict:
+        Optional callback invoked with node id on eviction.
+    """
+
+    def __init__(
+        self,
+        page_size: int = 1,
+        token_to_kv_pool_allocator: Any = None,
+        mamba_pool: Any = None,
+        on_node_evict: Optional[Callable[[int], None]] = None,
+    ):
+        self.page_size = page_size
+        self.pool = token_to_kv_pool_allocator
+        self.mamba_pool = mamba_pool
+        self.on_node_evict = on_node_evict
+
+        if self.pool is not None and hasattr(self.pool, "device"):
+            self.device = self.pool.device
+        else:
+            self.device = torch.device("cpu")
+
+        # Dual LRU lists
+        # NOTE(review): reset() below re-creates both lists and zeroes the
+        # counters, so these assignments are redundant — harmless, but could
+        # be dropped.
+        self.full_lru = LRUList(mamba=False)
+        self.mamba_lru = LRUList(mamba=True)
+
+        # Size counters
+        self._full_evictable: int = 0
+        self._full_protected: int = 0
+        self._mamba_evictable: int = 0
+        self._mamba_protected: int = 0
+
+        self.reset()
+
+    # ------------------------------------------------------------------
+    # Size queries
+    # ------------------------------------------------------------------
+
+    def evictable_size(self) -> int:
+        return self._full_evictable
+
+    def protected_size(self) -> int:
+        return self._full_protected
+
+    def mamba_evictable_size(self) -> int:
+        return self._mamba_evictable
+
+    def mamba_protected_size(self) -> int:
+        return self._mamba_protected
+
+    def total_size(self) -> int:
+        # Iterative DFS over non-evicted nodes, summing KV index counts.
+        total = 0
+        stack = [self.root_node]
+        while stack:
+            n = stack.pop()
+            if n.value is not None:
+                total += len(n.value)
+            stack.extend(c for c in n.children.values() if not c.evicted)
+        return total
+
+    # ------------------------------------------------------------------
+    # BasePrefixCache interface
+    # ------------------------------------------------------------------
+
+    def reset(self) -> None:
+        self.root_node = MambaTreeNode()
+        self.root_node.key = RadixKey([])
+        self.root_node.value = torch.tensor([], dtype=torch.int64)
+        self.root_node.mamba_value = torch.tensor([], dtype=torch.int64)
+        # Root is permanently locked so eviction can never remove it.
+        self.root_node.full_lock_ref = 1
+        self.root_node.mamba_lock_ref = 1
+        self._full_evictable = 0
+        self._full_protected = 0
+        self._mamba_evictable = 0
+        self._mamba_protected = 0
+        self.full_lru = LRUList(mamba=False)
+        self.mamba_lru = LRUList(mamba=True)
+
+    def match_prefix(self, key: RadixKey) -> MatchResult:
+        """Find longest cached prefix. Also returns ``mamba_branching_seqlen``."""
+        empty = MatchResult(
+            indices=torch.empty(0, dtype=torch.int64, device=self.device),
+            last_node=self.root_node,
+        )
+        if len(key) == 0:
+            return empty
+
+        key = self._page_align_key(key)
+        if len(key) == 0:
+            return empty
+
+        node = self.root_node
+        values: List[torch.Tensor] = []
+        # First matched offset at which Mamba state is missing (tombstone).
+        mamba_branching_seqlen: Optional[int] = None
+        total_matched = 0
+
+        while len(key) > 0:
+            ck = _child_key(key, self.page_size)
+            if ck not in node.children:
+                break
+            child = node.children[ck]
+            child.hit_count += 1
+            plen = _key_match(child.key, key, self.page_size)
+
+            if plen < len(child.key):
+                # Partial match: split the child so the matched prefix
+                # becomes its own node, then stop.
+                new_node = self._split_node(child.key, child, plen)
+                values.append(new_node.value)
+                # Track mamba branching point
+                if mamba_branching_seqlen is None and new_node.mamba_tombstone:
+                    mamba_branching_seqlen = total_matched
+                total_matched += len(new_node.value)
+                node = new_node
+                break
+
+            values.append(child.value)
+            if mamba_branching_seqlen is None and child.mamba_tombstone:
+                mamba_branching_seqlen = total_matched
+            total_matched += len(child.value)
+            node = child
+            key = key[plen:]
+
+        # Update LRU for matched path
+        self.full_lru.touch_node_and_parents_mru(node, self.root_node)
+        self.mamba_lru.touch_node_and_parents_mru(node, self.root_node)
+
+        cat = (
+            torch.cat(values)
+            if values
+            else torch.empty(0, dtype=torch.int64, device=self.device)
+        )
+        return MatchResult(
+            indices=cat,
+            last_node=node,
+            prefix_len=len(cat),
+            mamba_branching_seqlen=mamba_branching_seqlen,
+        )
+
+    def insert(
+ self, + key: RadixKey, + value: Optional[torch.Tensor] = None, + *, + mamba_value: Optional[torch.Tensor] = None, + **kwargs: Any, + ) -> InsertResult: + """Insert with both full KV and Mamba state values.""" + if value is None: + value = torch.tensor(key.token_ids, dtype=torch.int64) + + if len(key) == 0: + return InsertResult() + + node = self.root_node + total_prefix = 0 + mamba_exist = False + + ck = _child_key(key, self.page_size) + while len(key) > 0 and ck in node.children: + node = node.children[ck] + plen = _key_match(node.key, key, self.page_size) + total_prefix += plen + key = key[plen:] + value = value[plen:] + + if plen < len(node.key): + node = self._split_node(node.key, node, plen) + + # Check if mamba state already exists + if node.mamba_value is not None: + mamba_exist = True + + if len(key) > 0: + ck = _child_key(key, self.page_size) + + if len(key) > 0: + new_leaf = self._add_leaf(node, key, value, mamba_value=mamba_value) + node = new_leaf + elif mamba_value is not None and node.mamba_value is None: + # Existing node gains mamba state (un-tombstone) + node.mamba_value = mamba_value.clone() + self.mamba_lru.insert_mru(node) + self._mamba_evictable += len(node.value) + + return InsertResult( + prefix_len=total_prefix, last_node=node, mamba_exist=mamba_exist + ) + + def evict(self, num_tokens: int, swa_num_tokens: int = 0) -> EvictResult: + """Evict full KV and/or Mamba state tokens. + + Phase 1: Evict full KV leaves (frees both KV and Mamba state). + Phase 2: Evict Mamba state from internal nodes (tombstone mamba). 
+ """ + full_evicted = 0 + mamba_evicted = 0 + + # Phase 1: full leaf eviction + if num_tokens > 0: + while full_evicted < num_tokens: + node = self.full_lru.get_lru_leaf_unlocked() + if node is None: + break + n = len(node.value) + self._free_full_indices(node.value) + if node.mamba_value is not None: + self._free_mamba_value(node.mamba_value) + mamba_evicted += n + full_evicted += n + self._delete_leaf(node) + + # Cascade: parent may become evictable leaf + p = node.parent + if ( + p is not None + and p != self.root_node + and len(p.children) == 0 + and p.full_lock_ref == 0 + ): + # Will be picked up in next iteration via LRU + pass + + # Phase 2: mamba-only tombstone eviction + target_mamba = swa_num_tokens + if target_mamba > 0 and mamba_evicted < target_mamba: + while mamba_evicted < target_mamba: + node = self.mamba_lru.get_lru_unlocked() + if node is None: + break + if node.mamba_value is None: + continue + n = len(node.mamba_value) + self._free_mamba_value(node.mamba_value) + node.mamba_value = None + self.mamba_lru.remove(node) + self._mamba_evictable -= n + mamba_evicted += n + + return EvictResult( + full_evicted=full_evicted, mamba_evicted=mamba_evicted + ) + + def inc_lock_ref(self, node: MambaTreeNode) -> Optional[Any]: + """Lock full KV and Mamba state from *node* to root. + + Full lock propagates up to root. Mamba lock only applies to + the node itself (not ancestors). 
+ """ + if node is None: + return None + + # Lock mamba on the node itself + if node.mamba_value is not None: + if node.mamba_lock_ref == 0 and node in self.mamba_lru: + self._mamba_evictable -= len(node.mamba_value) + self._mamba_protected += len(node.mamba_value) + node.mamba_lock_ref += 1 + + # Lock full KV up to root + cur = node + while cur != self.root_node: + if cur.full_lock_ref == 0: + self._full_evictable -= len(cur.key) + self._full_protected += len(cur.key) + cur.full_lock_ref += 1 + cur = cur.parent + return None + + def dec_lock_ref(self, node: MambaTreeNode, **kwargs: Any) -> None: + """Unlock full KV and Mamba state.""" + if node is None: + return + + # Unlock mamba on the node itself + if node.mamba_lock_ref > 0: + node.mamba_lock_ref -= 1 + if node.mamba_lock_ref == 0 and node.mamba_value is not None: + self._mamba_evictable += len(node.mamba_value) + self._mamba_protected -= len(node.mamba_value) + + # Unlock full KV up to root + cur = node + while cur != self.root_node: + if cur.full_lock_ref == 1: + self._full_evictable += len(cur.key) + self._full_protected -= len(cur.key) + cur.full_lock_ref -= 1 + cur = cur.parent + + # ------------------------------------------------------------------ + # Internal: tree manipulation + # ------------------------------------------------------------------ + + def _add_leaf( + self, + parent: MambaTreeNode, + key: RadixKey, + value: torch.Tensor, + mamba_value: Optional[torch.Tensor] = None, + ) -> MambaTreeNode: + # Parent may lose leaf status + if ( + len(parent.children) == 0 + and parent != self.root_node + and parent.full_lock_ref == 0 + and not parent.evicted + ): + self._full_evictable -= len(parent.key) + + new_node = MambaTreeNode() + new_node.parent = parent + new_node.key = key + new_node.value = value.clone() + parent.children[_child_key(key, self.page_size)] = new_node + + # Track in full LRU + self.full_lru.insert_mru(new_node) + self._full_evictable += len(key) + + # Track mamba state if provided 
+        if mamba_value is not None:
+            new_node.mamba_value = mamba_value.clone()
+            self.mamba_lru.insert_mru(new_node)
+            self._mamba_evictable += len(key)
+
+        return new_node
+
+    def _split_node(
+        self, key: RadixKey, child: MambaTreeNode, split_len: int
+    ) -> MambaTreeNode:
+        """Split *child* at *split_len*, returning the new parent node."""
+        new_node = MambaTreeNode()
+        new_node.children[_child_key(key[split_len:], self.page_size)] = child
+        new_node.parent = child.parent
+        # Locks are inherited by the new prefix node so protected paths
+        # stay protected across the split.
+        new_node.full_lock_ref = child.full_lock_ref
+        new_node.mamba_lock_ref = child.mamba_lock_ref
+        new_node.key = child.key[:split_len]
+        new_node.value = child.value[:split_len].clone()
+
+        # Split mamba value
+        # NOTE(review): this slices mamba_value by token count, but
+        # _free_mamba_value frees it per-element as state slots — confirm
+        # mamba_value is really token-indexed and safe to split this way.
+        if child.mamba_value is not None:
+            new_node.mamba_value = child.mamba_value[:split_len].clone()
+            child.mamba_value = child.mamba_value[split_len:].clone()
+
+        child.parent = new_node
+        child.key = child.key[split_len:]
+        child.value = child.value[split_len:].clone()
+        new_node.parent.children[_child_key(key, self.page_size)] = new_node
+
+        # Update LRU lists: insert new_node, keep child
+        self.full_lru.insert_mru(new_node)
+        if new_node.mamba_value is not None:
+            self.mamba_lru.insert_mru(new_node)
+
+        return new_node
+
+    def _delete_leaf(self, node: MambaTreeNode) -> None:
+        ck = _child_key(node.key, self.page_size)
+        node.parent.children.pop(ck, None)
+
+        # Remove from LRU lists
+        if node in self.full_lru:
+            self.full_lru.remove(node)
+            self._full_evictable -= len(node.key)
+
+        if node.mamba_value is not None and node in self.mamba_lru:
+            self.mamba_lru.remove(node)
+            self._mamba_evictable -= len(node.key)
+
+        node.value = None
+        node.mamba_value = None
+
+        if self.on_node_evict is not None:
+            self.on_node_evict(node.id)
+
+    # ------------------------------------------------------------------
+    # Internal: memory management
+    # ------------------------------------------------------------------
+
+    def _free_full_indices(self, indices: torch.Tensor) -> None:
+        if self.pool is not None and len(indices) > 0:
+            self.pool.free(indices)
+
+    def _free_mamba_value(self, mamba_value: torch.Tensor) -> None:
+        if self.mamba_pool is not None and len(mamba_value) > 0:
+            for idx in mamba_value.tolist():
+                self.mamba_pool.free_track_slot(int(idx))
+
+    def _page_align_key(self, key: RadixKey) -> RadixKey:
+        # Truncate the key to a whole number of pages.
+        if self.page_size == 1:
+            return key
+        aligned = len(key) // self.page_size * self.page_size
+        return key[:aligned]
+
+    def pretty_print(self) -> None:
+        """Print the tree structure to stdout."""
+        self._print_helper(self.root_node, 0)
+        print(
+            f"total={self.total_size()} "
+            f"full_evictable={self._full_evictable} "
+            f"mamba_evictable={self._mamba_evictable}"
+        )
+
+    def _print_helper(self, node: MambaTreeNode, indent: int) -> None:
+        # Iterative DFS; depth doubles as the indent level.
+        stack = [(node, indent)]
+        while stack:
+            n, ind = stack.pop()
+            toks = n.key.token_ids[:10] if n.key else []
+            klen = len(n.key) if n.key else 0
+            has_mamba = n.mamba_value is not None
+            print(
+                f"{' ' * ind}[{klen}] {toks} "
+                f"full_lock={n.full_lock_ref} mamba_lock={n.mamba_lock_ref} "
+                f"mamba={'Y' if has_mamba else 'N'}"
+            )
+            for c in n.children.values():
+                stack.append((c, ind + 1))
diff --git a/pymllm/mem_cache/memory_pool.py b/pymllm/mem_cache/memory_pool.py
new file mode 100644
index 000000000..9c8ab2a99
--- /dev/null
+++ b/pymllm/mem_cache/memory_pool.py
@@ -0,0 +1,639 @@
+"""Lightweight KV-cache memory pools
+
+Three-layer architecture::
+
+    ReqToTokenPool            maps (req_slot, position) → kv_index
+    TokenToKVPoolAllocator    manages a free-list of integer indices
+    KVPool                    holds the actual GPU K/V tensors
+
+All indices are **int32** tensors on the target device. Slot 0 in the KV
+buffers is reserved as a padding / dummy-output slot and is never allocated.
+"""
+
+import logging
+from typing import List, Optional, Tuple, Union
+
+import torch
+
+from mllm_kernel.cuda.jit.store_cache import store_cache, can_use_store_cache
+
+logger = logging.getLogger(__name__)
+
+
+class KVPool:
+    """GPU (or CPU) storage for per-layer key and value caches.
+
+    Layout per layer::
+
+        JIT:
+            k_buffer[layer][slot, k_head_num * k_head_dim]
+            v_buffer[layer][slot, v_head_num * v_head_dim]
+
+        PyTorch:
+            k_buffer[layer][slot, k_head_num, k_head_dim]
+            v_buffer[layer][slot, v_head_num, v_head_dim]
+
+    K and V may have **independent** head counts and head dimensions, which
+    covers standard MHA, GQA / MQA, and architectures like MLA where value
+    projection uses a different dimensionality.
+
+    ``size`` usable slots are numbered ``[1, size]``. Slot 0 is a dummy
+    padding slot that absorbs writes from padded tokens.
+
+    Parameters
+    ----------
+    size : int
+        Number of usable token slots (total buffer length = ``size + 1``).
+    layer_num : int
+        Number of transformer layers (one K buffer + one V buffer per layer).
+    k_head_num : int
+        Number of key heads.
+    k_head_dim : int
+        Dimension of each key head.
+    device : str | torch.device
+        Target device (``"cuda"``, ``"cpu"``, …).
+    dtype : torch.dtype
+        Storage data type.
+    v_head_num : int, optional
+        Number of value heads. Defaults to *k_head_num*.
+    v_head_dim : int, optional
+        Dimension of each value head. Defaults to *k_head_dim*.
+    pin_memory : bool, optional
+        Whether to use pinned memory. Defaults to True.
+    """
+
+    def __init__(
+        self,
+        size: int,
+        layer_num: int,
+        k_head_num: int,
+        k_head_dim: int,
+        device: Union[str, torch.device] = "cuda",
+        dtype: torch.dtype = torch.float16,
+        v_head_num: Optional[int] = None,
+        v_head_dim: Optional[int] = None,
+        pin_memory: bool = True,
+    ):
+        self.size = size
+        self.layer_num = layer_num
+        self.k_head_num = k_head_num
+        self.k_head_dim = k_head_dim
+        self.v_head_num = v_head_num if v_head_num is not None else k_head_num
+        self.v_head_dim = v_head_dim if v_head_dim is not None else k_head_dim
+        self.device = torch.device(device)
+        self.dtype = dtype
+
+        # pin_memory only applies to CPU tensors
+        if self.device.type != "cpu":
+            pin_memory = False
+
+        buf_len = size + 1  # slot 0 is padding
+
+        if buf_len % 8 != 0:
+            logger.warning(
+                "KVPool buffer length is not divisible by 8, padding to the next multiple of 8"
+            )
+            # Round up to the next multiple of 8 (extra slots stay unused).
+            buf_len = (buf_len + 7) & ~7
+
+        k_row_dim = self.k_head_num * self.k_head_dim
+        v_row_dim = self.v_head_num * self.v_head_dim
+        self._same_kv_dim = k_row_dim == v_row_dim
+        # element_size() of an empty tensor gives the dtype's byte width.
+        self._row_bytes = k_row_dim * torch.tensor([], dtype=dtype).element_size()
+        # The JIT store_cache kernel requires CUDA, matching K/V row widths,
+        # and a row size the kernel supports.
+        self._use_jit = (
+            self.device.type == "cuda"
+            and self._same_kv_dim
+            and can_use_store_cache(self._row_bytes)
+        )
+        if not self._use_jit:
+            logger.warning(
+                f"Fallback to PyTorch index for KVPool, which is slower than the mllm-kernel's implementation, same_kv_dim={self._same_kv_dim}, row_bytes={self._row_bytes}"
+            )
+
+        self.k_buffer: List[torch.Tensor] = [
+            torch.zeros(
+                (buf_len, self.k_head_num, self.k_head_dim),
+                dtype=dtype,
+                device=self.device,
+                pin_memory=pin_memory,
+            )
+            for _ in range(layer_num)
+        ]
+        self.v_buffer: List[torch.Tensor] = [
+            torch.zeros(
+                (buf_len, self.v_head_num, self.v_head_dim),
+                dtype=dtype,
+                device=self.device,
+                pin_memory=pin_memory,
+            )
+            for _ in range(layer_num)
+        ]
+
+        # Pre-computed 2D views for the JIT store_cache kernel.
+        # Zero-copy: same underlying storage as k_buffer / v_buffer.
+        if self._use_jit:
+            self._k_buffer_2d = [b.view(buf_len, -1) for b in self.k_buffer]
+            self._v_buffer_2d = [b.view(buf_len, -1) for b in self.v_buffer]
+
+        logger.info(
+            "KVPool allocated: %d layers, %d slots, K=[%d,%d] V=[%d,%d], %.2f GB",
+            layer_num,
+            size,
+            self.k_head_num,
+            self.k_head_dim,
+            self.v_head_num,
+            self.v_head_dim,
+            self._mem_bytes() / (1 << 30),
+        )
+
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        return self.v_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        indices: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+    ) -> None:
+        """Write K/V vectors into the cache at the given *indices*.
+
+        ``k`` / ``v`` can be any shape as long as the trailing dimensions
+        multiply to ``head_num * head_dim`` (the row dimension). All leading
+        dimensions are treated as the batch axis and must match ``indices``
+        after flattening. Typical shapes::
+
+            k: [num_tokens, head_num, head_dim]        indices: [num_tokens]
+            k: [batch, seq_len, head_num, head_dim]    indices: [batch, seq_len]
+            k: [num_tokens, head_num * head_dim]       indices: [num_tokens]
+        """
+        if self._use_jit:
+            # _use_jit implies k and v share the same row width.
+            row_dim = self.k_head_num * self.k_head_dim
+            store_cache(
+                k.reshape(-1, row_dim),
+                v.reshape(-1, row_dim),
+                self._k_buffer_2d[layer_id],
+                self._v_buffer_2d[layer_id],
+                indices.reshape(-1),
+                row_bytes=self._row_bytes,
+            )
+        else:
+            # Fallback: plain advanced-index scatter.
+            self.k_buffer[layer_id][indices] = k
+            self.v_buffer[layer_id][indices] = v
+
+    def _mem_bytes(self) -> int:
+        # Total bytes across all K and V buffers.
+        total = 0
+        for buf in self.k_buffer + self.v_buffer:
+            total += buf.nelement() * buf.element_size()
+        return total
+
+
+class TokenToKVPoolAllocator:
+    """Manages allocation / deallocation of integer indices into a :class:`KVPool`.
class TokenToKVPoolAllocator:
    """Manages allocation / deallocation of integer indices into a :class:`KVPool`.

    Each ``alloc(n)`` returns *n* free indices; each ``free(indices)`` returns
    them to the pool.

    Uses a **dual-buffer** strategy (``free_slots`` + ``release_slots``) so
    that ``free()`` never cats onto the large main free-list. Freed indices
    accumulate in the smaller ``release_slots`` and are merged lazily (with an
    optional sort) only when ``alloc()`` cannot be satisfied from
    ``free_slots`` alone.

    A **batch-free** API (``free_group_begin`` / ``free_group_end``) further
    amortises cost when many ``free()`` calls happen in a tight loop (e.g.
    during scheduling or eviction).

    Typical usage::

        allocator = TokenToKVPoolAllocator(size=4096, device="cuda")

        # --- basic alloc / free ---
        indices = allocator.alloc(128)   # 128 free slot indices (int32)
        allocator.free(indices[:64])     # return 64 slots

        # --- batch free (amortised) ---
        allocator.free_group_begin()
        for req in finished_requests:
            allocator.free(req.kv_indices)   # O(1) list append each
        allocator.free_group_end()           # single torch.cat + release

    Parameters
    ----------
    size : int
        Total number of allocatable slots (must match ``KVPool.size``).
    device : str | torch.device
        Device for the free-list tensor.
    page_size : int
        When > 1 the allocator works in page-aligned mode: ``alloc`` returns
        multiples of ``page_size`` contiguous within each page, and ``free``
        deduplicates by page.
    need_sort : bool
        When ``True`` (default), ``merge_and_sort_free`` sorts after merging
        so that lower-index slots are allocated first (better memory locality).
    """

    def __init__(
        self,
        size: int,
        device: Union[str, torch.device] = "cuda",
        page_size: int = 1,
        need_sort: bool = True,
    ):
        self.size = size
        self.page_size = page_size
        self.device = torch.device(device)
        self.need_sort = need_sort
        self.clear()

    def clear(self) -> None:
        """Reset the allocator so that all slots ``[1, size]`` are free.

        The first slot is reserved for padding."""
        if self.page_size == 1:
            self.free_slots = torch.arange(
                1, self.size + 1, dtype=torch.int32, device=self.device
            )
        else:
            # Page-aligned mode: the free-list stores *page* numbers, starting
            # at 1 so page 0 (which contains the reserved padding slot 0) is
            # never handed out.
            # NOTE(review): page ``num_pages`` expands to token indices up to
            # ``size + page_size - 1``; confirm the paired KVPool buffer
            # (``size + 1`` rounded up to a multiple of 8) covers that range.
            num_pages = self.size // self.page_size
            self.free_slots = torch.arange(
                1, num_pages + 1, dtype=torch.int32, device=self.device
            )
        self.release_slots = torch.empty((0,), dtype=torch.int32, device=self.device)
        self._is_not_in_free_group = True
        self._free_group: List[torch.Tensor] = []

    def available_size(self) -> int:
        """Number of tokens that can still be allocated."""
        return (len(self.free_slots) + len(self.release_slots)) * self.page_size

    def merge_and_sort_free(self) -> None:
        """Merge ``release_slots`` into ``free_slots`` (and sort if ``need_sort``)."""
        if len(self.release_slots) == 0:
            return
        self.free_slots = torch.cat((self.free_slots, self.release_slots))
        if self.need_sort:
            self.free_slots, _ = torch.sort(self.free_slots)
        self.release_slots = torch.empty((0,), dtype=torch.int32, device=self.device)

    def free_group_begin(self) -> None:
        """Start collecting ``free()`` calls; actual release is deferred to ``free_group_end``."""
        self._is_not_in_free_group = False
        self._free_group = []

    def free_group_end(self) -> None:
        """Flush all ``free()`` calls collected since ``free_group_begin``."""
        self._is_not_in_free_group = True
        if self._free_group:
            self.free(torch.cat(self._free_group))
        self._free_group = []

    def alloc(self, need_size: int) -> Optional[torch.Tensor]:
        """Allocate *need_size* token indices.

        Returns a 1-D ``int32`` tensor on success, or ``None`` if the pool is
        exhausted.
        """
        if self.page_size == 1:
            # Fast path: serve from free_slots; merge the release buffer only
            # when the request cannot be satisfied.
            if need_size > len(self.free_slots):
                self.merge_and_sort_free()
            if need_size > len(self.free_slots):
                return None
            out = self.free_slots[:need_size]
            self.free_slots = self.free_slots[need_size:]
            return out

        # Page-aligned path: grab whole pages and expand them to token indices.
        num_pages = (need_size + self.page_size - 1) // self.page_size
        if num_pages > len(self.free_slots):
            self.merge_and_sort_free()
        if num_pages > len(self.free_slots):
            return None
        pages = self.free_slots[:num_pages]
        self.free_slots = self.free_slots[num_pages:]
        # BUG FIX: torch.arange defaults to int64; without an explicit dtype
        # the paged path returned int64 indices while the page_size == 1 path
        # returns int32, and a later free() of those indices silently promoted
        # the whole free-list to int64 via torch.cat.
        offsets = torch.arange(self.page_size, dtype=torch.int32, device=self.device)
        out = (pages[:, None] * self.page_size + offsets).reshape(-1)
        # The unused tail of a partially-filled last page stays allocated until
        # the caller frees any index in that page (free() dedups by page).
        return out[:need_size]

    def free(self, indices: torch.Tensor) -> None:
        """Return *indices* to the free pool."""
        if indices.numel() == 0:
            return

        # Inside a free-group: just remember the tensor; released in one batch.
        if not self._is_not_in_free_group:
            self._free_group.append(indices)
            return

        # Paged mode: freeing any index of a page releases the whole page.
        if self.page_size != 1:
            indices = torch.unique(indices // self.page_size)

        if self.need_sort:
            self.release_slots = torch.cat((self.release_slots, indices))
        else:
            self.free_slots = torch.cat((self.free_slots, indices))
class ReqToTokenPool:
    """Per-request mapping from token position to KV-pool index.

    A 2-D table ``req_to_token[slot, position]`` holds, for every active
    request slot, the KV-pool index of each of its token positions. Slots
    are handed out and recycled through a plain free-list.

    The table is deliberately a *pure* mapping: it never tracks how long a
    request currently is. Callers keep ``req_pool_idx`` and ``seq_len``
    themselves and use them to slice ``req_to_token`` when reading back.

    Typical usage::

        pool = ReqToTokenPool(max_reqs=256, max_context_len=4096)

        # --- on new request arrival ---
        [slot] = pool.alloc(1)                    # slot = req_pool_idx
        kv_indices = kv_allocator.alloc(seq_len)  # from TokenToKVPoolAllocator
        pool.write((slot, slice(0, seq_len)), kv_indices)

        # --- read back (caller tracks seq_len) ---
        kv_indices = pool.req_to_token[slot, :seq_len]

        # --- on request completion ---
        kv_allocator.free(pool.req_to_token[slot, :seq_len])
        pool.free(slot)

    Parameters
    ----------
    max_reqs : int
        Maximum number of concurrent requests (rows of the table).
    max_context_len : int
        Longest sequence any single request may reach (columns).
    device : str | torch.device
        Device on which the mapping tensor lives.
    """

    def __init__(
        self,
        max_reqs: int,
        max_context_len: int,
        device: Union[str, torch.device] = "cuda",
    ):
        self.size = max_reqs
        self.max_context_len = max_context_len
        self.device = torch.device(device)

        # One int32 row per request slot.
        self.req_to_token = torch.zeros(
            (max_reqs, max_context_len), dtype=torch.int32, device=self.device
        )
        self._free_slots: List[int] = list(range(max_reqs))

    def available_size(self) -> int:
        """How many request slots are still unassigned."""
        return len(self._free_slots)

    def alloc(self, n: int = 1) -> Optional[List[int]]:
        """Hand out *n* request slots, or ``None`` when fewer than *n* remain."""
        if len(self._free_slots) < n:
            return None
        granted = self._free_slots[:n]
        del self._free_slots[:n]
        return granted

    def free(self, slot: int) -> None:
        """Give one request slot back to the free-list."""
        self._free_slots.append(slot)

    def write(self, index: Tuple, values: torch.Tensor) -> None:
        """Store KV indices into the mapping table.

        *index* is usually ``(req_pool_idx, slice(start, end))``.
        """
        self.req_to_token[index] = values

    def clear(self) -> None:
        """Forget every request: zero the table and reset the free-list."""
        self.req_to_token.zero_()
        self._free_slots = list(range(self.size))


def make_full_attention_net_mem_pool(
    size: int,
    layer_num: int,
    k_head_num: int,
    k_head_dim: int,
    v_head_num: int,
    v_head_dim: int,
    device: Union[str, torch.device] = "cuda",
    dtype: torch.dtype = torch.float16,
    page_size: int = 1,
    need_sort: bool = True,
    pin_memory: bool = True,
) -> Tuple["KVPool", "TokenToKVPoolAllocator"]:
    """Build the KV storage and its index allocator for a full-attention
    (non-SWA) model and return them as a pair.

    Parameters mirror :class:`KVPool` (``size``, ``layer_num``, head
    geometry, ``device``, ``dtype``, ``pin_memory``) and
    :class:`TokenToKVPoolAllocator` (``page_size``, ``need_sort``).

    Returns
    -------
    (KVPool, TokenToKVPoolAllocator)
    """
    kv_pool = KVPool(
        size=size,
        layer_num=layer_num,
        k_head_num=k_head_num,
        k_head_dim=k_head_dim,
        device=device,
        dtype=dtype,
        v_head_num=v_head_num,
        v_head_dim=v_head_dim,
        pin_memory=pin_memory,
    )
    index_allocator = TokenToKVPoolAllocator(
        size=size,
        device=device,
        page_size=page_size,
        need_sort=need_sort,
    )
    return kv_pool, index_allocator
class GDNPool:
    """Pre-allocated memory pool for GDN recurrent and conv states.

    Indexed by ``req_pool_idx`` (same index space as :class:`ReqToTokenPool`).
    Slot 0 is reserved as a padding / dummy slot and is never allocated.

    Layout::

        recurrent_state[gdn_layer_idx, slot, num_v_heads, head_v_dim, head_k_dim]
            float32 (FlashInfer requirement; the state matrix is stored in
            (V, K) order, matching the tensor allocated below)
        conv_state[gdn_layer_idx, slot, conv_dim, kernel_size - 1]
            model dtype (bfloat16 / float16)

    Parameters
    ----------
    max_reqs : int
        Maximum number of concurrent requests (matches ``ReqToTokenPool.size``).
    num_gdn_layers : int
        Number of GDN (linear attention) layers in the model.
    num_v_heads : int
        Number of value heads per GDN layer.
    head_k_dim : int
        Per-head key dimension.
    head_v_dim : int
        Per-head value dimension.
    conv_dim : int
        Total convolution input dimension (``key_dim * 2 + value_dim``).
    conv_kernel_size : int
        Causal conv1d kernel width (state stores ``kernel_size - 1`` columns).
    device : str | torch.device
        Target device.
    dtype : torch.dtype
        Storage dtype for conv_state (recurrent_state is always float32).
    max_track_slots : int, optional
        Number of extra "track" slots appended after the working slots, used
        to snapshot GDN states for the prefix cache. Defaults to 0.
    """

    def __init__(
        self,
        max_reqs: int,
        num_gdn_layers: int,
        num_v_heads: int,
        head_k_dim: int,
        head_v_dim: int,
        conv_dim: int,
        conv_kernel_size: int,
        device: Union[str, torch.device] = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_track_slots: int = 0,
    ):
        self.max_reqs = max_reqs
        self.num_gdn_layers = num_gdn_layers
        self.num_v_heads = num_v_heads
        self.head_k_dim = head_k_dim
        self.head_v_dim = head_v_dim
        self.conv_dim = conv_dim
        self.conv_kernel_size = conv_kernel_size
        self.device = torch.device(device)
        self.dtype = dtype
        self.max_track_slots = max_track_slots

        # Track slots live after the working slots: indices
        # [max_reqs + 1, max_reqs + 1 + max_track_slots)
        pool_size = max_reqs + 1 + max_track_slots  # slot 0 is padding

        # Recurrent state: always float32 (FlashInfer requirement)
        # Shape: [num_gdn_layers, pool_size, num_v_heads, head_v_dim, head_k_dim]
        # Note: FlashInfer uses (V, K) layout for the state matrix
        self.recurrent_state = torch.zeros(
            (num_gdn_layers, pool_size, num_v_heads, head_v_dim, head_k_dim),
            dtype=torch.float32,
            device=self.device,
        )

        # Conv state: model dtype
        # Shape: [num_gdn_layers, pool_size, conv_dim, kernel_size - 1]
        self.conv_state = torch.zeros(
            (num_gdn_layers, pool_size, conv_dim, conv_kernel_size - 1),
            dtype=dtype,
            device=self.device,
        )

        # Track-slot free list (indices into the pool starting after working slots)
        self._track_slot_base = max_reqs + 1
        self._free_track_slots: List[int] = list(
            range(self._track_slot_base, self._track_slot_base + max_track_slots)
        )

        logger.info(
            "GDNPool allocated: %d GDN layers, %d working + %d track slots, "
            "v_heads=%d, k_dim=%d, v_dim=%d, conv_dim=%d, kernel=%d, %.2f GB",
            num_gdn_layers,
            max_reqs,
            max_track_slots,
            num_v_heads,
            head_k_dim,
            head_v_dim,
            conv_dim,
            conv_kernel_size,
            self.mem_bytes() / (1 << 30),
        )

    def get_layer_state(
        self, gdn_layer_idx: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(recurrent_state, conv_state)`` for a specific GDN layer.

        Both are views into the pool tensors with shape:
        - recurrent: ``[pool_size, num_v_heads, head_v_dim, head_k_dim]``
        - conv: ``[pool_size, conv_dim, kernel_size - 1]``
        """
        return (
            self.recurrent_state[gdn_layer_idx],
            self.conv_state[gdn_layer_idx],
        )

    def reset_states(self, req_pool_indices: torch.Tensor) -> None:
        """Zero-init GDN states for the given request pool indices.

        Called when new requests are allocated to ensure clean state.
        """
        if req_pool_indices.numel() == 0:
            return
        # Zero both recurrent and conv states for all GDN layers
        self.recurrent_state[:, req_pool_indices] = 0
        self.conv_state[:, req_pool_indices] = 0

    # ------------------------------------------------------------------
    # Track-slot management (for prefix cache GDN state snapshots)
    # ------------------------------------------------------------------

    def alloc_track_slot(self) -> Optional[int]:
        """Allocate a single track slot index.

        Returns ``None`` if exhausted."""
        if not self._free_track_slots:
            return None
        # LIFO: most recently freed slot is reused first.
        return self._free_track_slots.pop()

    def free_track_slot(self, slot: int) -> None:
        """Return a track slot to the free list."""
        self._free_track_slots.append(slot)

    def copy_states(self, src_index: int, dst_index: int) -> None:
        """Copy recurrent and conv states from *src_index* to *dst_index*.

        Works for any pool indices (working or track slots).
        """
        self.recurrent_state[:, dst_index] = self.recurrent_state[:, src_index]
        self.conv_state[:, dst_index] = self.conv_state[:, src_index]

    def mem_bytes(self) -> int:
        """Total memory consumption in bytes."""
        return (
            self.recurrent_state.nelement() * self.recurrent_state.element_size()
            + self.conv_state.nelement() * self.conv_state.element_size()
        )


def make_req_to_token_pool(
    max_reqs: int,
    max_context_len: int,
    device: Union[str, torch.device] = "cuda",
) -> ReqToTokenPool:
    """Convenience factory for a :class:`ReqToTokenPool`."""
    return ReqToTokenPool(max_reqs, max_context_len, device)
+ +Supports: + - Multi-batch serving on a single GPU + - Sliding Window Attention (SWA) via tombstone mechanism + - Multimodal namespace isolation via ``extra_key`` + - SHA256 position-aware hashing + - Page-aligned operations (page_size >= 1) + - LRU leaf eviction +""" + +from __future__ import annotations + +import heapq +import logging +import time +from collections import defaultdict +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch + +from pymllm.mem_cache.base_prefix_cache import ( + BasePrefixCache, + EvictResult, + InsertResult, + MatchResult, + RadixKey, + hash_token_ids, +) + +logger = logging.getLogger(__name__) + + +# ====================================================================== +# Tree node +# ====================================================================== + +_node_counter: int = 0 + + +def _next_node_id() -> int: + global _node_counter + _node_counter += 1 + return _node_counter + + +class TreeNode: + """A single node in the radix tree. + + ``value`` holds a 1-D ``int64`` tensor of KV-pool indices (one per token + in ``key``). When the node has been evicted, ``value`` is ``None``. 
+ """ + + __slots__ = ( + "children", + "parent", + "key", + "value", + "lock_ref", + "swa_lock_ref", + "swa_tombstone", + "swa_boundary_id", + "last_access_time", + "hit_count", + "hash_values", + "id", + ) + + def __init__(self) -> None: + self.children: Dict[Any, TreeNode] = defaultdict(TreeNode) + self.parent: Optional[TreeNode] = None + self.key: Optional[RadixKey] = None + self.value: Optional[torch.Tensor] = None + + self.lock_ref: int = 0 + self.swa_lock_ref: int = 0 + self.swa_tombstone: bool = False + self.swa_boundary_id: Optional[int] = None + + self.last_access_time: float = time.monotonic() + self.hit_count: int = 0 + self.hash_values: Optional[List[str]] = None + self.id: int = _next_node_id() + + @property + def evicted(self) -> bool: + return self.value is None + + def __lt__(self, other: TreeNode) -> bool: + return self.last_access_time < other.last_access_time + + +# ====================================================================== +# Helper functions +# ====================================================================== + + +def _key_match(key0: RadixKey, key1: RadixKey, page_size: int) -> int: + """Return the length of the common prefix (page-aligned when *page_size* > 1).""" + if key0.extra_key != key1.extra_key: + return 0 + if page_size == 1: + i = 0 + for a, b in zip(key0.token_ids, key1.token_ids): + if a != b: + break + i += 1 + return i + min_len = min(len(key0), len(key1)) + i = 0 + while i < min_len: + if key0.token_ids[i : i + page_size] != key1.token_ids[i : i + page_size]: + break + i += page_size + return i + + +def _child_key(key: RadixKey, page_size: int) -> Any: + """Derive the dict key used in ``node.children``.""" + plain = key.token_ids[0] if page_size == 1 else tuple(key.token_ids[:page_size]) + return (key.extra_key, plain) if key.extra_key is not None else plain + + +# ====================================================================== +# RadixCache +# 
======================================================================


class RadixCache(BasePrefixCache):
    """Radix tree for KV-cache prefix sharing.

    Parameters
    ----------
    page_size:
        Number of tokens per KV-pool page. Keys and values are aligned to
        this granularity.
    sliding_window_size:
        If set, enables SWA mode. The cache tracks which nodes have had
        their SWA KV freed (tombstoned) and constrains prefix matching
        so that the sliding-window invariant is maintained.
    token_to_kv_pool_allocator:
        Optional pool allocator with ``free(indices)`` (and ``free_swa`` for
        SWA mode). When *None*, index tensors are simply discarded.
    on_node_evict:
        Optional callback invoked with the node id when a node is evicted.
    """

    def __init__(
        self,
        page_size: int = 1,
        sliding_window_size: Optional[int] = None,
        token_to_kv_pool_allocator: Any = None,
        on_node_evict: Optional[Callable[[int], None]] = None,
    ):
        self.page_size = page_size
        self.sliding_window_size = sliding_window_size
        self.pool = token_to_kv_pool_allocator
        self.on_node_evict = on_node_evict

        # Inherit the allocator's device so matched index tensors live where
        # the KV pool lives; fall back to CPU when running without a pool.
        if self.pool is not None and hasattr(self.pool, "device"):
            self.device = self.pool.device
        else:
            self.device = torch.device("cpu")

        self._swa_boundary_counter: int = 0
        self.reset()

    @property
    def supports_swa(self) -> bool:
        # SWA mode is enabled purely by configuring a sliding window.
        return self.sliding_window_size is not None

    # ------------------------------------------------------------------
    # Size queries
    # ------------------------------------------------------------------

    def evictable_size(self) -> int:
        # Tokens in unlocked nodes (full-attention KV that may be evicted).
        return self._evictable_size

    def swa_evictable_size(self) -> int:
        # Tokens whose SWA KV may still be tombstoned away.
        return self._swa_evictable_size

    def protected_size(self) -> int:
        # Tokens pinned by lock_ref > 0.
        return self._protected_size

    def swa_protected_size(self) -> int:
        # Tokens pinned by swa_lock_ref > 0.
        return self._swa_protected_size

    def total_size(self) -> int:
        """Total number of cached tokens (including tombstoned)."""
        total = 0
        # Iterative DFS; only descend into children that still hold values.
        stack: List[TreeNode] = [self.root_node]
        while stack:
            n = stack.pop()
            if n.value is not None:
                total += len(n.value)
            stack.extend(c for c in n.children.values() if not c.evicted)
        return total

    # ------------------------------------------------------------------
    # BasePrefixCache interface
    # ------------------------------------------------------------------

    def reset(self) -> None:
        """Clear all cached state and re-initialise the root node."""
        self.root_node = TreeNode()
        self.root_node.key = RadixKey([])
        self.root_node.value = torch.tensor([], dtype=torch.int64)
        # The root is permanently locked so it can never be evicted.
        self.root_node.lock_ref = 1
        self.root_node.swa_lock_ref = 1
        self._evictable_size: int = 0
        self._swa_evictable_size: int = 0
        self._protected_size: int = 0
        self._swa_protected_size: int = 0

    def match_prefix(self, key: RadixKey) -> MatchResult:
        """Find the longest cached prefix of *key*.

        For SWA mode the match is further constrained: the path from the
        returned ``last_node`` to root must have at least
        ``sliding_window_size`` non-tombstone tokens (or be entirely
        tombstone-free back to root).

        Accessing a prefix refreshes LRU timestamps along the matched path.
        """
        empty = MatchResult(
            indices=torch.empty(0, dtype=torch.int64, device=self.device),
            last_node=self.root_node,
        )
        if len(key) == 0:
            return empty

        # Drop any ragged tail so lookups always operate on whole pages.
        key = self._page_align_key(key)
        if len(key) == 0:
            return empty

        if self.supports_swa:
            values, last_node, best_count = self._match_swa(key)
            # Keep only the SWA-safe leading portion of the matched path.
            values = values[:best_count]
        else:
            values, last_node = self._match_normal(key)

        cat = (
            torch.cat(values)
            if values
            else torch.empty(0, dtype=torch.int64, device=self.device)
        )
        return MatchResult(indices=cat, last_node=last_node, prefix_len=len(cat))

    def insert(
        self,
        key: RadixKey,
        value: Optional[torch.Tensor] = None,
        *,
        prev_prefix_len: int = 0,
        swa_evicted_seqlen: int = 0,
        **kwargs: Any,
    ) -> InsertResult:
        """Insert *key*/*value* into the tree.

        Returns how many leading tokens were already present (the prefix
        length). The caller is responsible for freeing duplicate KV indices
        in the range ``[prev_prefix_len, prefix_len)``.

        Parameters
        ----------
        prev_prefix_len:
            (SWA mode) tokens before this offset are already protected and
            should not have their values overwritten.
        swa_evicted_seqlen:
            (SWA mode) the sequence length up to which SWA KV has been
            previously evicted. Used to decide whether a tombstoned node can
            be un-tombstoned with the incoming value.
        """
        if value is None:
            # Fall back to the token ids themselves as placeholder indices.
            value = torch.tensor(key.token_ids, dtype=torch.int64)
        if self.supports_swa:
            plen = self._insert_swa(
                self.root_node, key, value, prev_prefix_len, swa_evicted_seqlen
            )
            return InsertResult(prefix_len=plen)
        else:
            plen, last_node = self._insert_normal(self.root_node, key, value)
            return InsertResult(prefix_len=plen, last_node=last_node)

    def evict(self, num_tokens: int, swa_num_tokens: int = 0) -> EvictResult:
        """Evict up to *num_tokens* (full) and *swa_num_tokens* (SWA) tokens.

        Full eviction removes leaf nodes entirely; SWA eviction tombstones
        internal nodes (freeing SWA KV but retaining full-attn KV).
        """
        full_evicted = 0
        swa_evicted = 0

        # Phase 1: full leaf eviction
        if num_tokens > 0:
            # LRU: oldest last_access_time is evicted first via a min-heap.
            leaves = self._collect_evictable_leaves()
            heap: List[Tuple[float, TreeNode]] = [
                (n.last_access_time, n) for n in leaves
            ]
            heapq.heapify(heap)

            while full_evicted < num_tokens and heap:
                _, node = heapq.heappop(heap)
                # Stale heap entries: node may have been evicted or locked
                # since it was pushed.
                if node.evicted or node.lock_ref > 0:
                    continue
                n = len(node.value)
                self._free_indices(node.value)
                full_evicted += n
                # Deleting a leaf also drops its SWA KV, so it counts toward
                # both budgets.
                # NOTE(review): this assumes the collected leaves are not
                # tombstoned (their SWA KV would already be freed) — confirm
                # _collect_evictable_leaves guarantees that.
                swa_evicted += n
                self._delete_leaf(node)

                # The parent may have just become a childless, unlocked node.
                p = node.parent
                if (
                    p is not None
                    and p != self.root_node
                    and len(p.children) == 0
                    and p.lock_ref == 0
                ):
                    if self.supports_swa and p.swa_tombstone:
                        # Tombstoned parent: its SWA KV is already gone, so
                        # only the full-attn side counts here.
                        self._free_indices(p.value)
                        full_evicted += len(p.value)
                        self._delete_leaf(p)
                    else:
                        # Re-enter the heap as a fresh eviction candidate.
                        heapq.heappush(heap, (p.last_access_time, p))

        # Phase 2: SWA tombstone eviction (internal nodes)
        if self.supports_swa and swa_evicted < swa_num_tokens:
            candidates = self._collect_swa_evictable()
            heap2: List[Tuple[float, TreeNode]] = [
                (n.last_access_time, n) for n in candidates
            ]
            heapq.heapify(heap2)

            while swa_evicted < swa_num_tokens and heap2:
                _, node = heapq.heappop(heap2)
                if node.swa_tombstone or node.swa_lock_ref > 0 or node.evicted:
                    continue
                n = len(node.value)
                if len(node.children) == 0 and node.lock_ref == 0:
                    # Childless and fully unlocked: remove outright.
                    self._free_indices(node.value)
                    full_evicted += n
                    swa_evicted += n
                    self._delete_leaf(node)
                elif len(node.children) > 0:
                    # Internal node: free only the SWA KV, keep full-attn KV.
                    self._free_swa_indices(node.value)
                    swa_evicted += n
                    self._tombstone_node(node)

        return EvictResult(full_evicted=full_evicted, swa_evicted=swa_evicted)

    def inc_lock_ref(self, node: TreeNode) -> Optional[int]:
        """Lock nodes from *node* up to root (prevents eviction).

        Returns ``swa_boundary_id`` that must be passed back to
        :meth:`dec_lock_ref`. In non-SWA mode, returns ``None``.
        """
        if node is None:
            return None

        swa_locked = 0
        swa_boundary_id: Optional[int] = None
        cur = node
        while cur != self.root_node:
            # First lock moves this node's tokens from evictable to protected.
            if cur.lock_ref == 0:
                self._evictable_size -= len(cur.key)
                self._protected_size += len(cur.key)
            cur.lock_ref += 1

            # SWA locking stops once the sliding window is covered; tombstoned
            # nodes carry no SWA KV and are skipped.
            if (
                self.supports_swa
                and swa_locked < self.sliding_window_size
                and not cur.swa_tombstone
            ):
                if cur.swa_lock_ref == 0:
                    self._swa_evictable_size -= len(cur.key)
                    self._swa_protected_size += len(cur.key)
                cur.swa_lock_ref += 1
                swa_locked += len(cur.key)
                if swa_locked >= self.sliding_window_size:
                    # Mark where SWA locking stopped so dec_lock_ref can stop
                    # at exactly the same node.
                    if cur.swa_boundary_id is None:
                        self._swa_boundary_counter += 1
                        cur.swa_boundary_id = self._swa_boundary_counter
                    swa_boundary_id = cur.swa_boundary_id

            cur = cur.parent
        return swa_boundary_id

    def dec_lock_ref(
        self, node: TreeNode, swa_boundary_id: Optional[int] = None, **kwargs: Any
    ) -> None:
        """Unlock nodes from *node* up to root."""
        if node is None:
            return

        dec_swa = True
        cur = node
        while cur != self.root_node:
            # Last unlock moves the tokens back to the evictable pool.
            if cur.lock_ref == 1:
                self._evictable_size += len(cur.key)
                self._protected_size -= len(cur.key)
            cur.lock_ref -= 1

            if self.supports_swa and dec_swa and not cur.swa_tombstone:
                if cur.swa_lock_ref == 1:
                    self._swa_evictable_size += len(cur.key)
                    self._swa_protected_size -= len(cur.key)
                cur.swa_lock_ref -= 1
                # The boundary node itself is decremented (matching
                # inc_lock_ref); everything above it is skipped.
                if swa_boundary_id and cur.swa_boundary_id == swa_boundary_id:
                    dec_swa = False

            cur = cur.parent

    # ------------------------------------------------------------------
    # Hashing & pretty-print
    # ------------------------------------------------------------------

    def compute_node_hash(self, node: TreeNode) -> List[str]:
        """Compute position-aware SHA-256 hashes for *node* (one per page).

        Lazily computed and cached on ``node.hash_values``.
+ """ + if node.hash_values is not None: + return node.hash_values + + parent_hash: Optional[str] = None + if ( + node.parent is not None + and node.parent.hash_values is not None + and len(node.parent.key) > 0 + and len(node.parent.hash_values) > 0 + ): + parent_hash = node.parent.hash_values[-1] + + hashes: List[str] = [] + for start in range(0, len(node.key), self.page_size): + page = node.key.token_ids[start : start + self.page_size] + if not page: + continue + h = hash_token_ids(page, prior_hash=parent_hash) + hashes.append(h) + parent_hash = h + + node.hash_values = hashes + return hashes + + def pretty_print(self) -> None: + """Print the tree structure to stdout.""" + self._print_helper(self.root_node, 0) + print( + f"total={self.total_size()} evictable={self._evictable_size}" + + ( + f" swa_evictable={self._swa_evictable_size}" + if self.supports_swa + else "" + ) + ) + + # ------------------------------------------------------------------ + # Internal: match + # ------------------------------------------------------------------ + + def _match_normal(self, key: RadixKey) -> Tuple[List[torch.Tensor], TreeNode]: + node = self.root_node + now = time.monotonic() + node.last_access_time = now + values: List[torch.Tensor] = [] + + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + child = node.children[ck] + child.last_access_time = now + child.hit_count += 1 + plen = _key_match(child.key, key, self.page_size) + if plen < len(child.key): + new_node = self._split_node(child.key, child, plen) + values.append(new_node.value) + node = new_node + break + values.append(child.value) + node = child + key = key[plen:] + + return values, node + + def _match_swa(self, key: RadixKey) -> Tuple[List[torch.Tensor], TreeNode, int]: + """SWA-aware match. Returns *(values, last_node, best_value_count)*. 
+ + ``best_value_count`` is the number of value tensors from *values* + that form a valid SWA-safe prefix (enough non-tombstone tokens within + the sliding window, or a tombstone-free path to root). + """ + node = self.root_node + values: List[torch.Tensor] = [] + non_tomb_len: float = float("inf") + best_count = 0 + best_node = node + + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + child = node.children[ck] + + if child.swa_tombstone: + if non_tomb_len >= self.sliding_window_size: + best_count = len(values) + best_node = node + non_tomb_len = 0 + + plen = _key_match(child.key, key, self.page_size) + if plen < len(child.key): + new_node = self._split_node(child.key, child, plen) + values.append(new_node.value) + if not new_node.swa_tombstone: + non_tomb_len += len(new_node.value) + node = new_node + break + values.append(child.value) + if not child.swa_tombstone: + non_tomb_len += len(child.value) + node = child + key = key[plen:] + + if non_tomb_len >= self.sliding_window_size: + best_count = len(values) + best_node = node + + return values, best_node, best_count + + # ------------------------------------------------------------------ + # Internal: insert + # ------------------------------------------------------------------ + + def _insert_normal( + self, node: TreeNode, key: RadixKey, value: torch.Tensor + ) -> Tuple[int, TreeNode]: + """Insert into non-SWA tree. 
Returns ``(prefix_len, last_node)``.""" + now = time.monotonic() + node.last_access_time = now + if len(key) == 0: + return 0, node + + total_prefix = 0 + ck = _child_key(key, self.page_size) + while len(key) > 0 and ck in node.children: + node = node.children[ck] + node.last_access_time = now + plen = _key_match(node.key, key, self.page_size) + total_prefix += plen + key = key[plen:] + value = value[plen:] + + if plen < len(node.key): + node = self._split_node(node.key, node, plen) + if len(key) > 0: + ck = _child_key(key, self.page_size) + + if len(key) > 0: + new_leaf = self._add_leaf(node, key, value) + node = new_leaf + + return total_prefix, node + + def _insert_swa( + self, + node: TreeNode, + key: RadixKey, + value: torch.Tensor, + prev_prefix_len: int, + swa_evicted_seqlen: int, + ) -> int: + """Insert with SWA tombstone awareness. + + When an existing node is tombstoned and the incoming *value* carries + fresh SWA KV (i.e. beyond *swa_evicted_seqlen*), the node is + un-tombstoned and its value is replaced. 
+ """ + now = time.monotonic() + node.last_access_time = now + if len(key) == 0: + return 0 + + total_prefix = 0 + while len(key) > 0: + ck = _child_key(key, self.page_size) + if ck not in node.children: + break + node = node.children[ck] + node.last_access_time = now + plen = _key_match(node.key, key, self.page_size) + + if plen < len(node.key): + self._split_node(node.key, node, plen) + + beyond_protected = prev_prefix_len < total_prefix + plen + if beyond_protected and node.swa_tombstone: + if swa_evicted_seqlen <= total_prefix: + self._free_indices(node.value[:plen]) + node.value = value[:plen].clone() + node.swa_tombstone = False + self._swa_evictable_size += len(node.value) + else: + self._free_indices(value[:plen]) + elif beyond_protected: + self._free_indices(value[:plen]) + + total_prefix += plen + key = key[plen:] + value = value[plen:] + + if len(key) > 0: + if ( + swa_evicted_seqlen > total_prefix + and swa_evicted_seqlen < total_prefix + len(key) + ): + tomb_len = swa_evicted_seqlen - total_prefix + self._add_leaf( + node, key[:tomb_len], value[:tomb_len], swa_tombstone=True + ) + node = node.children[_child_key(key, self.page_size)] + key = key[tomb_len:] + value = value[tomb_len:] + + if len(key) > 0: + self._add_leaf(node, key, value, swa_tombstone=False) + + return total_prefix + + # ------------------------------------------------------------------ + # Internal: tree manipulation + # ------------------------------------------------------------------ + + def _add_leaf( + self, + parent: TreeNode, + key: RadixKey, + value: torch.Tensor, + swa_tombstone: bool = False, + ) -> TreeNode: + if ( + len(parent.children) == 0 + and parent != self.root_node + and parent.lock_ref == 0 + and not parent.evicted + ): + self._evictable_size -= len(parent.key) + + new_node = TreeNode() + new_node.parent = parent + new_node.key = key + new_node.value = value.clone() + new_node.swa_tombstone = swa_tombstone + parent.children[_child_key(key, self.page_size)] = 
new_node + self._evictable_size += len(key) + if self.supports_swa and not swa_tombstone: + self._swa_evictable_size += len(key) + return new_node + + def _split_node(self, key: RadixKey, child: TreeNode, split_len: int) -> TreeNode: + """Split *child* at *split_len*, returning the new parent node.""" + logger.debug( + "[SPLIT] node_id=%d key_len=%d split_len=%d " + "parent_val[:4]=%s child_val[:4]=%s", + child.id, + len(key), + split_len, + child.value[:min(split_len, 4)].tolist() + if child.value is not None + else [], + child.value[split_len : split_len + 4].tolist() + if child.value is not None and len(child.value) > split_len + else [], + ) + new_node = TreeNode() + new_node.children[_child_key(key[split_len:], self.page_size)] = child + new_node.parent = child.parent + new_node.lock_ref = child.lock_ref + new_node.swa_lock_ref = child.swa_lock_ref + new_node.swa_tombstone = child.swa_tombstone + new_node.swa_boundary_id = child.swa_boundary_id + child.swa_boundary_id = None + new_node.key = child.key[:split_len] + new_node.value = child.value[:split_len].clone() + + # Split hash values if they exist + if child.hash_values is not None: + pages = split_len // self.page_size if self.page_size > 1 else split_len + new_node.hash_values = child.hash_values[:pages] + child.hash_values = child.hash_values[pages:] + else: + new_node.hash_values = None + + child.parent = new_node + child.key = child.key[split_len:] + child.value = child.value[split_len:].clone() + new_node.parent.children[_child_key(key, self.page_size)] = new_node + return new_node + + def _delete_leaf(self, node: TreeNode) -> None: + ck = _child_key(node.key, self.page_size) + node.parent.children.pop(ck, None) + self._evictable_size -= len(node.key) + if self.supports_swa and not node.swa_tombstone: + self._swa_evictable_size -= len(node.key) + node.value = None + if self.on_node_evict is not None: + self.on_node_evict(node.id) + + def _tombstone_node(self, node: TreeNode) -> None: + 
node.swa_tombstone = True + self._swa_evictable_size -= len(node.key) + + # ------------------------------------------------------------------ + # Internal: collection helpers + # ------------------------------------------------------------------ + + def _collect_evictable_leaves(self) -> List[TreeNode]: + leaves: List[TreeNode] = [] + stack: List[TreeNode] = [self.root_node] + while stack: + n = stack.pop() + if n.evicted: + continue + has_live_child = False + for c in n.children.values(): + if not c.evicted: + has_live_child = True + stack.append(c) + if not has_live_child and n.lock_ref == 0 and n != self.root_node: + leaves.append(n) + return leaves + + def _collect_swa_evictable(self) -> List[TreeNode]: + nodes: List[TreeNode] = [] + stack: List[TreeNode] = [self.root_node] + while stack: + n = stack.pop() + if n.evicted: + continue + if n != self.root_node and not n.swa_tombstone and n.swa_lock_ref == 0: + nodes.append(n) + stack.extend(c for c in n.children.values() if not c.evicted) + return nodes + + def _page_align_key(self, key: RadixKey) -> RadixKey: + if self.page_size == 1: + return key + aligned = len(key) // self.page_size * self.page_size + return key[:aligned] + + def _free_indices(self, indices: torch.Tensor) -> None: + if self.pool is not None and len(indices) > 0: + self.pool.free(indices) + + def _free_swa_indices(self, indices: torch.Tensor) -> None: + if self.pool is not None and len(indices) > 0: + if hasattr(self.pool, "free_swa"): + self.pool.free_swa(indices) + else: + self.pool.free(indices) + + def _print_helper(self, node: TreeNode, indent: int) -> None: + stack = [(node, indent)] + while stack: + n, ind = stack.pop() + toks = n.key.token_ids[:10] if n.key else [] + klen = len(n.key) if n.key else 0 + flags = f"lock={n.lock_ref}" + if self.supports_swa: + flags += f" swa={n.swa_lock_ref} tomb={n.swa_tombstone}" + print(f"{' ' * ind}[{klen}] {toks} {flags}") + for c in n.children.values(): + stack.append((c, ind + 1)) diff --git 
a/pymllm/mobile/README.md b/pymllm/mobile/README.md index 29877ea00..ceb71a5d3 100644 --- a/pymllm/mobile/README.md +++ b/pymllm/mobile/README.md @@ -1 +1,2 @@ -We should refactor current pymllm's src to mobile directory. And provide more functionalities for torch based VLA. +# Pymllm mobile + diff --git a/pymllm/mobile/__init__.py b/pymllm/mobile/__init__.py new file mode 100644 index 000000000..8796bbeaf --- /dev/null +++ b/pymllm/mobile/__init__.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from . import ffi +from . import convertor +from . import utils +from . import quantize +from . import nn +from . import service +from . import backends +from .ffi import ( + # Floating point types + float32, + float16, + bfloat16, + # Signed integer types + int8, + int16, + int32, + int64, + # Unsigned integer types + uint8, + uint16, + uint32, + uint64, + # Bool type + boolean, + # Devices + cpu, + cuda, + qnn, + # Tensor and utilities + Tensor, + empty, + echo, + device, + is_torch_available, + is_numpy_available, + from_torch, + from_numpy, + zeros, + ones, + arange, + random, +) +from .nn.functional import matmul diff --git a/pymllm/quantize/spinquant/__init__.py b/pymllm/mobile/backends/__init__.py similarity index 71% rename from pymllm/quantize/spinquant/__init__.py rename to pymllm/mobile/backends/__init__.py index ea8e2bec7..1578a0d87 100644 --- a/pymllm/quantize/spinquant/__init__.py +++ b/pymllm/mobile/backends/__init__.py @@ -1,2 +1,4 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. + +from . 
import qualcomm diff --git a/pymllm/backends/qualcomm/README.md b/pymllm/mobile/backends/qualcomm/README.md similarity index 100% rename from pymllm/backends/qualcomm/README.md rename to pymllm/mobile/backends/qualcomm/README.md diff --git a/pymllm/backends/qualcomm/__init__.py b/pymllm/mobile/backends/qualcomm/__init__.py similarity index 100% rename from pymllm/backends/qualcomm/__init__.py rename to pymllm/mobile/backends/qualcomm/__init__.py diff --git a/pymllm/backends/qualcomm/nn.py b/pymllm/mobile/backends/qualcomm/nn.py similarity index 75% rename from pymllm/backends/qualcomm/nn.py rename to pymllm/mobile/backends/qualcomm/nn.py index 0ba9aef55..e4bc91ace 100644 --- a/pymllm/backends/qualcomm/nn.py +++ b/pymllm/mobile/backends/qualcomm/nn.py @@ -1,4 +1,4 @@ -from pymllm.nn._layers import Softmax, RoPE +from pymllm.mobile.nn._layers import Softmax, RoPE class QnnSoftmax(Softmax): diff --git a/pymllm/backends/qualcomm/qnn_aot_env.py b/pymllm/mobile/backends/qualcomm/qnn_aot_env.py similarity index 83% rename from pymllm/backends/qualcomm/qnn_aot_env.py rename to pymllm/mobile/backends/qualcomm/qnn_aot_env.py index 8b0c0d2e1..bc48c7c97 100644 --- a/pymllm/backends/qualcomm/qnn_aot_env.py +++ b/pymllm/mobile/backends/qualcomm/qnn_aot_env.py @@ -1,7 +1,7 @@ -from pymllm.ffi import is_qnn_aot_on_x86_enabled +from pymllm.mobile.ffi import is_qnn_aot_on_x86_enabled if is_qnn_aot_on_x86_enabled(): - from pymllm.ffi import ( + from pymllm.mobile.ffi import ( QnnDeviceAndContext, QnnAOTEnv, QcomChipset, diff --git a/pymllm/backends/qualcomm/transformers/.gitignore b/pymllm/mobile/backends/qualcomm/transformers/.gitignore similarity index 100% rename from pymllm/backends/qualcomm/transformers/.gitignore rename to pymllm/mobile/backends/qualcomm/transformers/.gitignore diff --git a/pymllm/backends/qualcomm/transformers/README.md b/pymllm/mobile/backends/qualcomm/transformers/README.md similarity index 100% rename from pymllm/backends/qualcomm/transformers/README.md 
rename to pymllm/mobile/backends/qualcomm/transformers/README.md diff --git a/pymllm/backends/qualcomm/transformers/__init__.py b/pymllm/mobile/backends/qualcomm/transformers/__init__.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/__init__.py rename to pymllm/mobile/backends/qualcomm/transformers/__init__.py diff --git a/pymllm/compile/mlir/__init__.py b/pymllm/mobile/backends/qualcomm/transformers/core/__init__.py similarity index 100% rename from pymllm/compile/mlir/__init__.py rename to pymllm/mobile/backends/qualcomm/transformers/core/__init__.py diff --git a/pymllm/backends/qualcomm/transformers/core/embedding.py b/pymllm/mobile/backends/qualcomm/transformers/core/embedding.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/embedding.py rename to pymllm/mobile/backends/qualcomm/transformers/core/embedding.py diff --git a/pymllm/backends/qualcomm/transformers/core/observer.py b/pymllm/mobile/backends/qualcomm/transformers/core/observer.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/observer.py rename to pymllm/mobile/backends/qualcomm/transformers/core/observer.py diff --git a/pymllm/backends/qualcomm/transformers/core/qdq.py b/pymllm/mobile/backends/qualcomm/transformers/core/qdq.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/qdq.py rename to pymllm/mobile/backends/qualcomm/transformers/core/qdq.py diff --git a/pymllm/backends/qualcomm/transformers/core/qlinear.py b/pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py similarity index 99% rename from pymllm/backends/qualcomm/transformers/core/qlinear.py rename to pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py index 9e90ba8a5..35439180c 100644 --- a/pymllm/backends/qualcomm/transformers/core/qlinear.py +++ b/pymllm/mobile/backends/qualcomm/transformers/core/qlinear.py @@ -2,7 +2,7 @@ import torch.nn as nn import torch.nn.functional as F from torch.ao.quantization 
import FakeQuantize, PerChannelMinMaxObserver -from pymllm.backends.qualcomm.transformers.core.observer import ( +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ( PerBlockParamFakeQuantize, ) from torchao.quantization.quant_primitives import ( diff --git a/pymllm/backends/qualcomm/transformers/core/rms_norm.py b/pymllm/mobile/backends/qualcomm/transformers/core/rms_norm.py similarity index 100% rename from pymllm/backends/qualcomm/transformers/core/rms_norm.py rename to pymllm/mobile/backends/qualcomm/transformers/core/rms_norm.py diff --git a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/llama/modeling_llama.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py index 119ec04bc..6b65f34b9 100644 --- a/pymllm/backends/qualcomm/transformers/llama/modeling_llama.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/modeling_llama.py @@ -52,16 +52,16 @@ from transformers.models.llama.configuration_llama import LlamaConfig # Replace linear, rms_norm with: -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver logger = 
logging.get_logger(__name__) diff --git a/pymllm/backends/qualcomm/transformers/llama/runner.py b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py similarity index 96% rename from pymllm/backends/qualcomm/transformers/llama/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/runner.py index 8aa4627bf..730147d0f 100644 --- a/pymllm/backends/qualcomm/transformers/llama/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.llama.modeling_llama import LlamaForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/llama/train.py b/pymllm/mobile/backends/qualcomm/transformers/llama/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/llama/train.py rename to pymllm/mobile/backends/qualcomm/transformers/llama/train.py index cd10befba..41ffc0e27 100644 --- 
a/pymllm/backends/qualcomm/transformers/llama/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/llama/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from pymllm.backends.qualcomm.transformers.llama.runner import LlamaQuantizer +from pymllm.mobile.backends.qualcomm.transformers.llama.runner import LlamaQuantizer def main(): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py index 56b19c421..a43d8b7ea 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/modeling_qwen2.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/modeling_qwen2.py @@ -31,16 +31,16 @@ from transformers.models.qwen2.configuration_qwen2 import Qwen2Config # Replace linear, rms_norm with: -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen2MLP(nn.Module): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/runner.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py similarity index 96% rename from 
pymllm/backends/qualcomm/transformers/qwen2/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py index d2f5be05b..ce55fd06d 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.qwen2.modeling_qwen2 import Qwen2ForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/qwen2/train.py b/pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/qwen2/train.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py index fec5fdfca..1a8f25ce9 100644 --- a/pymllm/backends/qualcomm/transformers/qwen2/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen2/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from 
pymllm.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer +from pymllm.mobile.backends.qualcomm.transformers.qwen2.runner import Qwen2Quantizer def main(): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py similarity index 98% rename from pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py index 2dabf5c9c..6a8788bad 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/modeling_qwen3.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/modeling_qwen3.py @@ -46,16 +46,16 @@ from transformers.models.qwen3.configuration_qwen3 import Qwen3Config # Replace linear, rms_norm with: -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, ) -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver class Qwen3MLP(nn.Module): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/runner.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py similarity index 96% rename from pymllm/backends/qualcomm/transformers/qwen3/runner.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py index 02ea6a5f0..0d7499c96 100644 --- 
a/pymllm/backends/qualcomm/transformers/qwen3/runner.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/runner.py @@ -2,18 +2,18 @@ from tqdm import tqdm from modelscope.msdatasets import MsDataset from transformers import AutoTokenizer -from pymllm.backends.qualcomm.transformers.core.qdq import ( +from pymllm.mobile.backends.qualcomm.transformers.core.qdq import ( ActivationQDQ, FixedActivationQDQ, ) -from pymllm.backends.qualcomm.transformers.core.rms_norm import QRMSNorm -from pymllm.backends.qualcomm.transformers.core.qlinear import ( +from pymllm.mobile.backends.qualcomm.transformers.core.rms_norm import QRMSNorm +from pymllm.mobile.backends.qualcomm.transformers.core.qlinear import ( QLinearLPBQ, QLinearW8A16_PerChannelSym, ) -from pymllm.backends.qualcomm.transformers.core.embedding import QEmbedding -from pymllm.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM -from pymllm.backends.qualcomm.transformers.core.observer import ConcatObserver +from pymllm.mobile.backends.qualcomm.transformers.core.embedding import QEmbedding +from pymllm.mobile.backends.qualcomm.transformers.qwen3.modeling_qwen3 import Qwen3ForCausalLM +from pymllm.mobile.backends.qualcomm.transformers.core.observer import ConcatObserver def recompute_scale_zp(module): diff --git a/pymllm/backends/qualcomm/transformers/qwen3/train.py b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py similarity index 94% rename from pymllm/backends/qualcomm/transformers/qwen3/train.py rename to pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py index 63c6d0e86..f44fa67b5 100644 --- a/pymllm/backends/qualcomm/transformers/qwen3/train.py +++ b/pymllm/mobile/backends/qualcomm/transformers/qwen3/train.py @@ -2,7 +2,7 @@ import torch import argparse from safetensors.torch import save_model -from pymllm.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer +from pymllm.mobile.backends.qualcomm.transformers.qwen3.runner import Qwen3Quantizer def 
main(): diff --git a/pymllm/convertor/__init__.py b/pymllm/mobile/convertor/__init__.py similarity index 100% rename from pymllm/convertor/__init__.py rename to pymllm/mobile/convertor/__init__.py diff --git a/pymllm/convertor/mllm_type_mapping.py b/pymllm/mobile/convertor/mllm_type_mapping.py similarity index 100% rename from pymllm/convertor/mllm_type_mapping.py rename to pymllm/mobile/convertor/mllm_type_mapping.py diff --git a/pymllm/convertor/model_file_v1.py b/pymllm/mobile/convertor/model_file_v1.py similarity index 100% rename from pymllm/convertor/model_file_v1.py rename to pymllm/mobile/convertor/model_file_v1.py diff --git a/pymllm/convertor/model_file_v2.py b/pymllm/mobile/convertor/model_file_v2.py similarity index 100% rename from pymllm/convertor/model_file_v2.py rename to pymllm/mobile/convertor/model_file_v2.py diff --git a/pymllm/ffi/__init__.py b/pymllm/mobile/ffi/__init__.py similarity index 100% rename from pymllm/ffi/__init__.py rename to pymllm/mobile/ffi/__init__.py diff --git a/pymllm/ffi/_ffi_api.py b/pymllm/mobile/ffi/_ffi_api.py similarity index 100% rename from pymllm/ffi/_ffi_api.py rename to pymllm/mobile/ffi/_ffi_api.py diff --git a/pymllm/ffi/base.py b/pymllm/mobile/ffi/base.py similarity index 90% rename from pymllm/ffi/base.py rename to pymllm/mobile/ffi/base.py index 07a01c49e..96aed2425 100644 --- a/pymllm/ffi/base.py +++ b/pymllm/mobile/ffi/base.py @@ -8,7 +8,7 @@ def _load_lib(): file_dir = os.path.dirname(os.path.realpath(__file__)) - parent_dir = os.path.dirname(file_dir) + parent_dir = os.path.dirname(os.path.dirname(file_dir)) # Platform-specific library names if sys.platform.startswith("win32"): diff --git a/pymllm/nn/__init__.py b/pymllm/mobile/nn/__init__.py similarity index 100% rename from pymllm/nn/__init__.py rename to pymllm/mobile/nn/__init__.py diff --git a/pymllm/nn/_layers.py b/pymllm/mobile/nn/_layers.py similarity index 100% rename from pymllm/nn/_layers.py rename to pymllm/mobile/nn/_layers.py diff --git 
a/pymllm/nn/_module.py b/pymllm/mobile/nn/_module.py similarity index 100% rename from pymllm/nn/_module.py rename to pymllm/mobile/nn/_module.py diff --git a/pymllm/nn/functional.py b/pymllm/mobile/nn/functional.py similarity index 100% rename from pymllm/nn/functional.py rename to pymllm/mobile/nn/functional.py diff --git a/pymllm/quantize/__init__.py b/pymllm/mobile/quantize/__init__.py similarity index 100% rename from pymllm/quantize/__init__.py rename to pymllm/mobile/quantize/__init__.py diff --git a/pymllm/quantize/cast2fp32_pass.py b/pymllm/mobile/quantize/cast2fp32_pass.py similarity index 100% rename from pymllm/quantize/cast2fp32_pass.py rename to pymllm/mobile/quantize/cast2fp32_pass.py diff --git a/pymllm/compile/__init__.py b/pymllm/mobile/quantize/gguf/__init__.py similarity index 100% rename from pymllm/compile/__init__.py rename to pymllm/mobile/quantize/gguf/__init__.py diff --git a/pymllm/quantize/kai/__init__.py b/pymllm/mobile/quantize/kai/__init__.py similarity index 100% rename from pymllm/quantize/kai/__init__.py rename to pymllm/mobile/quantize/kai/__init__.py diff --git a/pymllm/quantize/kai/w4a32.py b/pymllm/mobile/quantize/kai/w4a32.py similarity index 100% rename from pymllm/quantize/kai/w4a32.py rename to pymllm/mobile/quantize/kai/w4a32.py diff --git a/pymllm/quantize/pipeline.py b/pymllm/mobile/quantize/pipeline.py similarity index 100% rename from pymllm/quantize/pipeline.py rename to pymllm/mobile/quantize/pipeline.py diff --git a/pymllm/quantize/quantize_pass.py b/pymllm/mobile/quantize/quantize_pass.py similarity index 100% rename from pymllm/quantize/quantize_pass.py rename to pymllm/mobile/quantize/quantize_pass.py diff --git a/pymllm/quantize/solver.py b/pymllm/mobile/quantize/solver.py similarity index 100% rename from pymllm/quantize/solver.py rename to pymllm/mobile/quantize/solver.py diff --git a/pymllm/quantize/gguf/__init__.py b/pymllm/mobile/quantize/spinquant/__init__.py similarity index 100% rename from 
pymllm/quantize/gguf/__init__.py rename to pymllm/mobile/quantize/spinquant/__init__.py diff --git a/pymllm/service/__init__.py b/pymllm/mobile/service/__init__.py similarity index 100% rename from pymllm/service/__init__.py rename to pymllm/mobile/service/__init__.py diff --git a/pymllm/service/models_hub.py b/pymllm/mobile/service/models_hub.py similarity index 100% rename from pymllm/service/models_hub.py rename to pymllm/mobile/service/models_hub.py diff --git a/pymllm/service/network.py b/pymllm/mobile/service/network.py similarity index 100% rename from pymllm/service/network.py rename to pymllm/mobile/service/network.py diff --git a/pymllm/service/rr_process.py b/pymllm/mobile/service/rr_process.py similarity index 100% rename from pymllm/service/rr_process.py rename to pymllm/mobile/service/rr_process.py diff --git a/pymllm/service/tools.py b/pymllm/mobile/service/tools.py similarity index 100% rename from pymllm/service/tools.py rename to pymllm/mobile/service/tools.py diff --git a/pymllm/tests/qualcomm/test_context_create.py b/pymllm/mobile/tests/qualcomm/test_context_create.py similarity index 89% rename from pymllm/tests/qualcomm/test_context_create.py rename to pymllm/mobile/tests/qualcomm/test_context_create.py index 18983daa7..94f42b513 100644 --- a/pymllm/tests/qualcomm/test_context_create.py +++ b/pymllm/mobile/tests/qualcomm/test_context_create.py @@ -1,5 +1,5 @@ -import pymllm as mllm -from pymllm.backends.qualcomm.qnn_aot_env import ( +import pymllm.mobile as mllm +from pymllm.mobile.backends.qualcomm.qnn_aot_env import ( QnnAOTEnv, QnnDeviceAndContext, QcomTryBestPerformance, diff --git a/pymllm/tests/test_nn.py b/pymllm/mobile/tests/test_nn.py similarity index 83% rename from pymllm/tests/test_nn.py rename to pymllm/mobile/tests/test_nn.py index d9a3db2d8..403060e99 100644 --- a/pymllm/tests/test_nn.py +++ b/pymllm/mobile/tests/test_nn.py @@ -1,5 +1,5 @@ -import pymllm as mllm -from pymllm import nn +import pymllm.mobile as mllm +from 
pymllm.mobile import nn class FooModule(nn.Module): diff --git a/pymllm/tests/test_tensor.py b/pymllm/mobile/tests/test_tensor.py similarity index 89% rename from pymllm/tests/test_tensor.py rename to pymllm/mobile/tests/test_tensor.py index e935f10b4..474e10922 100644 --- a/pymllm/tests/test_tensor.py +++ b/pymllm/mobile/tests/test_tensor.py @@ -1,7 +1,7 @@ # Copyright (c) MLLM Team. # Licensed under the MIT License. -import pymllm as torch +import pymllm.mobile as torch def test_empty_tensor_create() -> bool: diff --git a/pymllm/utils/__init__.py b/pymllm/mobile/utils/__init__.py similarity index 100% rename from pymllm/utils/__init__.py rename to pymllm/mobile/utils/__init__.py diff --git a/pymllm/utils/adb.py b/pymllm/mobile/utils/adb.py similarity index 100% rename from pymllm/utils/adb.py rename to pymllm/mobile/utils/adb.py diff --git a/pymllm/utils/error_handler.py b/pymllm/mobile/utils/error_handler.py similarity index 100% rename from pymllm/utils/error_handler.py rename to pymllm/mobile/utils/error_handler.py diff --git a/pymllm/utils/mllm_convertor.py b/pymllm/mobile/utils/mllm_convertor.py similarity index 100% rename from pymllm/utils/mllm_convertor.py rename to pymllm/mobile/utils/mllm_convertor.py diff --git a/pymllm/models/__init__.py b/pymllm/models/__init__.py new file mode 100644 index 000000000..7751b3091 --- /dev/null +++ b/pymllm/models/__init__.py @@ -0,0 +1,62 @@ +"""Model registry for pymllm. + +Maps HuggingFace ``config.architectures[0]`` strings to pymllm model classes. +Models are imported lazily via ``importlib`` so that heavy dependencies (torch, +numpy, etc.) are only loaded when a model is actually requested. 
+""" + +from __future__ import annotations + +import importlib +import logging +from typing import Dict, Optional, Tuple, Type + +import torch.nn as nn + +logger = logging.getLogger(__name__) + +# (module_path, class_name) +_MODEL_REGISTRY: Dict[str, Tuple[str, str]] = { + "Qwen3VLForConditionalGeneration": ( + "pymllm.models.qwen3_vl", + "Qwen3VLForConditionalGeneration", + ), + # Qwen3.5 (hybrid attention: full + GDN linear) + "Qwen3_5ForCausalLM": ( + "pymllm.models.qwen3_5", + "Qwen3_5ForCausalLM", + ), + "Qwen3_5ForConditionalGeneration": ( + "pymllm.models.qwen3_5", + "Qwen3_5ForConditionalGeneration", + ), +} + + +def get_model_class(architecture: str) -> Optional[Type[nn.Module]]: + """Look up a pymllm model class by HuggingFace architecture string. + + Returns ``None`` if the architecture is not registered or cannot be + imported. The caller is responsible for raising an appropriate error. + """ + entry = _MODEL_REGISTRY.get(architecture) + if entry is None: + return None + + module_path, class_name = entry + try: + module = importlib.import_module(module_path) + cls = getattr(module, class_name) + logger.info( + "Resolved architecture %r -> %s.%s", architecture, module_path, class_name + ) + return cls + except (ImportError, AttributeError) as exc: + logger.warning( + "Failed to import %s.%s for architecture %r: %s", + module_path, + class_name, + architecture, + exc, + ) + return None diff --git a/pymllm/models/qwen3_5.py b/pymllm/models/qwen3_5.py new file mode 100644 index 000000000..5b6bd558a --- /dev/null +++ b/pymllm/models/qwen3_5.py @@ -0,0 +1,561 @@ +"""Inference-only Qwen3.5 model for pymllm. + +Implements the hybrid attention architecture: +- **Full attention layers** (standard transformer with RoPE + output gate) +- **GDN linear attention layers** (Gated Delta Network, O(n) complexity) + +Layers alternate: linear, attention, linear, attention, ... based on +``full_attention_interval`` in the config. 
+ +Supports: +- Dense (non-MoE) variant +- Vision-Language (multimodal) via inheritance from Qwen3VL + +Adapted from sglang's ``qwen3_5.py``. +""" + +from __future__ import annotations + +import logging +import math +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pymllm.layers.attention.radix_attention import RadixAttention +from pymllm.layers.embedding import VocabParallelEmbedding +from pymllm.layers.gated_delta_net import GatedDeltaNet +from pymllm.layers.linear import Linear +from pymllm.layers.mlp import MLP +from pymllm.layers.rms_norm import GemmaRMSNorm, RMSNorm +from pymllm.layers.rope import apply_rope_pos_ids +from pymllm.layers.utils import set_weight_attrs + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Config helpers +# --------------------------------------------------------------------------- + + +def _get_text_config(config): + """Extract the text sub-config from a multimodal config, or return as-is.""" + return getattr(config, "text_config", config) + + +def _get_layer_types(config) -> List[str]: + """Return per-layer type list: 'attention' or 'linear_attention'.""" + if hasattr(config, "layers_block_type"): + return config.layers_block_type + # Compute from full_attention_interval + interval = getattr(config, "full_attention_interval", 2) + n_layers = config.num_hidden_layers + types = [] + for i in range(n_layers): + if (i + 1) % interval == 0: + types.append("attention") + else: + types.append("linear_attention") + return types + + +# --------------------------------------------------------------------------- +# Full Attention Layer (with output gate + QK norm) +# --------------------------------------------------------------------------- + + +class Qwen3_5FullAttention(nn.Module): + """Standard multi-head attention with RoPE, QK-norm, and optional output gate.""" + + 
def __init__(self, config, layer_id: int, quant_config=None, prefix: str = ""): + super().__init__() + tc = _get_text_config(config) + self.hidden_size = tc.hidden_size + self.num_heads = tc.num_attention_heads + self.num_kv_heads = tc.num_key_value_heads + self.head_dim = getattr(tc, "head_dim", self.hidden_size // self.num_heads) + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = self.head_dim ** -0.5 + self.layer_id = layer_id + + # Output gate: Qwen3.5 doubles the Q projection and uses half as a + # sigmoid gate on the attention output. + self.attn_output_gate = getattr(tc, "attn_output_gate", True) + + if self.attn_output_gate: + q_proj_size = self.q_size * 2 # Q + gate + else: + q_proj_size = self.q_size + + def _get_qm(suffix): + if quant_config is None: + return None + return quant_config.get_quant_method( + layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + self.q_proj = Linear(self.hidden_size, q_proj_size, bias=False, quant_method=_get_qm("q_proj")) + self.k_proj = Linear(self.hidden_size, self.kv_size, bias=False, quant_method=_get_qm("k_proj")) + self.v_proj = Linear(self.hidden_size, self.kv_size, bias=False, quant_method=_get_qm("v_proj")) + self.o_proj = Linear(self.q_size, self.hidden_size, bias=False, quant_method=_get_qm("o_proj")) + + # QK normalization + self.q_norm = GemmaRMSNorm(self.head_dim, eps=tc.rms_norm_eps) + self.k_norm = GemmaRMSNorm(self.head_dim, eps=tc.rms_norm_eps) + + # RoPE config + self.partial_rotary_factor = getattr(tc, "partial_rotary_factor", 1.0) + rope_config = getattr(tc, "rope_parameters", None) or getattr(tc, "rope_scaling", None) or {} + self.rope_theta = rope_config.get("rope_theta", getattr(tc, "rope_theta", 10000.0)) + self.rotary_dim = int(self.head_dim * self.partial_rotary_factor) + + # RadixAttention layer — delegates to the pluggable attention backend + self.attn = RadixAttention( + num_heads=self.num_heads, + 
head_dim=self.head_dim, + scaling=self.scaling, + num_kv_heads=self.num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: Any, + ) -> torch.Tensor: + seq_len = hidden_states.shape[0] + + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + if self.attn_output_gate: + # Split Q into actual Q and gate + q_gate = q.view(seq_len, self.num_heads, self.head_dim * 2) + q, gate = q_gate.chunk(2, dim=-1) + q = q.reshape(seq_len, -1) + gate = gate.reshape(seq_len, -1) + + # QK norm + q = self.q_norm(q.reshape(-1, self.head_dim)).view(seq_len, -1) + k = self.k_norm(k.reshape(-1, self.head_dim)).view(seq_len, -1) + + # RoPE (inplace; rotary_dim handles partial rotation) + q = q.view(seq_len, self.num_heads, self.head_dim) + k = k.view(seq_len, self.num_kv_heads, self.head_dim) + apply_rope_pos_ids( + q, k, positions, inplace=True, + rotary_dim=self.rotary_dim, rope_theta=self.rope_theta, + ) + q = q.reshape(seq_len, -1) + k = k.reshape(seq_len, -1) + + # Standard attention via RadixAttention → attn_backend + attn_output = self.attn(q, k, v, forward_batch) + + # Output gate + if self.attn_output_gate: + attn_output = attn_output * torch.sigmoid(gate) + + return self.o_proj(attn_output) + + +# --------------------------------------------------------------------------- +# Full Attention Decoder Layer +# --------------------------------------------------------------------------- + + +class Qwen3_5AttentionDecoderLayer(nn.Module): + """Decoder layer with full attention + MLP.""" + + def __init__(self, config, layer_id: int, quant_config=None, prefix: str = ""): + super().__init__() + tc = _get_text_config(config) + self.self_attn = Qwen3_5FullAttention( + config, layer_id, + quant_config=quant_config, + prefix=f"{prefix}.self_attn" if prefix else "self_attn", + ) + self.mlp = MLP( + hidden_size=tc.hidden_size, + intermediate_size=tc.intermediate_size, 
+ activation=tc.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp" if prefix else "mlp", + ) + self.input_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: Any, + ): + # Pre-norm + residual + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + + # Post-attention norm + residual + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Linear Attention (GDN) Decoder Layer +# --------------------------------------------------------------------------- + + +class Qwen3_5LinearDecoderLayer(nn.Module): + """Decoder layer with GDN linear attention + MLP.""" + + def __init__(self, config, layer_id: int, gdn_layer_idx: int = 0, + quant_config=None, prefix: str = ""): + super().__init__() + tc = _get_text_config(config) + self.linear_attn = GatedDeltaNet( + hidden_size=tc.hidden_size, + num_k_heads=getattr(tc, "linear_num_key_heads", 16), + num_v_heads=getattr(tc, "linear_num_value_heads", 32), + head_k_dim=getattr(tc, "linear_key_head_dim", 128), + head_v_dim=getattr(tc, "linear_value_head_dim", 128), + conv_kernel_size=getattr(tc, "linear_conv_kernel_dim", 4), + layer_id=layer_id, + gdn_layer_idx=gdn_layer_idx, + rms_norm_eps=tc.rms_norm_eps, + quant_config=quant_config, + prefix=f"{prefix}.linear_attn" if prefix else "linear_attn", + ) + self.mlp = MLP( + hidden_size=tc.hidden_size, + intermediate_size=tc.intermediate_size, + 
activation=tc.hidden_act, + quant_config=quant_config, + prefix=f"{prefix}.mlp" if prefix else "mlp", + ) + self.input_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + self.post_attention_layernorm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + residual: Optional[torch.Tensor], + forward_batch: Any, + ): + if residual is None: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + else: + hidden_states, residual = self.input_layernorm(hidden_states, residual) + + hidden_states = self.linear_attn(hidden_states, forward_batch) + + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) + hidden_states = self.mlp(hidden_states) + + return hidden_states, residual + + +# --------------------------------------------------------------------------- +# Layer type registry +# --------------------------------------------------------------------------- + +_DECODER_LAYER_TYPES = { + "attention": Qwen3_5AttentionDecoderLayer, + "linear_attention": Qwen3_5LinearDecoderLayer, +} + + +# --------------------------------------------------------------------------- +# Qwen3.5 Language Model (dense variant) +# --------------------------------------------------------------------------- + + +class Qwen3_5ForCausalLM(nn.Module): + """Qwen3.5 causal language model with hybrid attention. + + Alternates between full attention and GDN linear attention layers. + Dense (non-MoE) variant. 
+ """ + + def __init__(self, config, quant_config=None): + super().__init__() + tc = _get_text_config(config) + self.config = tc + self.quant_config = quant_config + self.hidden_size = tc.hidden_size + self.vocab_size = tc.vocab_size + + # Embedding + self.embed_tokens = VocabParallelEmbedding(tc.vocab_size, tc.hidden_size) + + # Build hybrid decoder layers with sequential GDN indexing + layer_types = _get_layer_types(tc) + self.layer_types = layer_types + self.layers = nn.ModuleList() + gdn_count = 0 + self.full_attn_layer_ids = set() + for idx in range(tc.num_hidden_layers): + layer_type = layer_types[idx] + layer_prefix = f"layers.{idx}" + if layer_type == "linear_attention": + self.layers.append( + Qwen3_5LinearDecoderLayer( + config, idx, gdn_layer_idx=gdn_count, + quant_config=quant_config, prefix=layer_prefix, + ) + ) + gdn_count += 1 + else: + self.layers.append( + Qwen3_5AttentionDecoderLayer( + config, idx, + quant_config=quant_config, prefix=layer_prefix, + ) + ) + self.full_attn_layer_ids.add(idx) + self.num_gdn_layers = gdn_count + + # Final norm + self.norm = GemmaRMSNorm(tc.hidden_size, eps=tc.rms_norm_eps) + + logger.info( + "Qwen3_5ForCausalLM: %d layers (%d attention + %d GDN)", + tc.num_hidden_layers, + len(self.full_attn_layer_ids), + self.num_gdn_layers, + ) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: Any, + input_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + residual = None + for layer in self.layers: + hidden_states, residual = layer( + positions=positions, + hidden_states=hidden_states, + residual=residual, + forward_batch=forward_batch, + ) + + # Final normalization + if residual is not None: + hidden_states, _ = self.norm(hidden_states, residual) + else: + hidden_states = self.norm(hidden_states) + + return hidden_states + + def 
load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load HuggingFace checkpoint weights with name remapping.""" + # When quantized, gate/up are separate projections — skip stacking. + if self.quant_config is not None: + stacked_params_mapping = [] + else: + stacked_params_mapping = [ + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + params_dict = dict(self.named_parameters()) + loaded: Set[str] = set() + + for name, weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if "mtp" in name: + continue + if "visual" in name: + continue + if "language_model" in name: + name = name.replace("model.language_model.", "") + if name.startswith("model."): + name = name[len("model."):] + # NOTE: do NOT strip .self_attn — pymllm keeps it as a submodule + + # Handle stacked params (gate_up_proj = gate_proj + up_proj) + matched = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + param = params_dict[name] + # gate_up_proj is a plain Linear — manually place each shard + output_dim = param.shape[0] // 2 + param.data[shard_id * output_dim : (shard_id + 1) * output_dim].copy_( + weight + ) + matched = True + break + + if not matched: + if name not in params_dict: + continue + param = params_dict[name] + loader = getattr(param, "weight_loader", None) + if loader is not None: + loader(param, weight) + else: + # Squeeze conv1d weight from [C, 1, K] to [C, K] + if weight.dim() != param.dim(): + weight = weight.squeeze() + param.data.copy_(weight) + + loaded.add(name) + + logger.info("Loaded %d parameter tensors for Qwen3_5ForCausalLM", len(loaded)) + return loaded + + +# --------------------------------------------------------------------------- +# Qwen3.5 Vision-Language Model +# 
--------------------------------------------------------------------------- + + +class Qwen3_5ForConditionalGeneration(nn.Module): + """Qwen3.5 multimodal model (text + vision). + + Inherits vision encoder from Qwen3VL and uses Qwen3.5's hybrid + language model. + """ + + def __init__(self, config, quant_config=None): + super().__init__() + from pymllm.models.qwen3_vl import ( + Qwen3VLVisionModel, + ) + + self.config = config + self.quant_config = quant_config + tc = _get_text_config(config) + + # Vision encoder — NOT quantized + vision_config = getattr(config, "vision_config", None) + if vision_config is not None: + self.visual = Qwen3VLVisionModel( + depth=getattr(vision_config, "depth", 27), + hidden_size=getattr(vision_config, "hidden_size", 1152), + hidden_act=getattr(vision_config, "hidden_act", "gelu_pytorch_tanh"), + intermediate_size=getattr(vision_config, "intermediate_size", 4304), + num_heads=getattr(vision_config, "num_heads", 16), + in_channels=getattr(vision_config, "in_channels", 3), + patch_size=getattr(vision_config, "patch_size", 16), + spatial_merge_size=getattr(vision_config, "spatial_merge_size", 2), + temporal_patch_size=getattr(vision_config, "temporal_patch_size", 2), + out_hidden_size=getattr(vision_config, "out_hidden_size", 3584), + num_position_embeddings=getattr( + vision_config, "num_position_embeddings", 2304 + ), + deepstack_visual_indexes=getattr( + vision_config, "deepstack_visual_indexes", [8, 16, 24] + ), + norm_eps=getattr(tc, "rms_norm_eps", 1e-6), + ) + else: + self.visual = None + + # Language model + self.model = Qwen3_5ForCausalLM(config, quant_config=quant_config) + + # Expose hybrid model metadata for ModelRunner + self.num_gdn_layers = self.model.num_gdn_layers + self.full_attn_layer_ids = self.model.full_attn_layer_ids + + # LM head (tied to embedding when tie_word_embeddings=True) + self.lm_head = Linear(tc.hidden_size, tc.vocab_size, bias=False) + if getattr(tc, "tie_word_embeddings", False): + self.lm_head.weight = 
self.model.embed_tokens.weight + + # Vision token IDs + self.image_token_id = getattr(config, "image_token_id", 151655) + self.video_token_id = getattr(config, "video_token_id", 151656) + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: Any, + input_embeds: Optional[torch.Tensor] = None, + pixel_values: Optional[torch.Tensor] = None, + image_grid_thw: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Process vision inputs if provided + if input_embeds is None and pixel_values is not None and self.visual is not None: + input_embeds = self.model.embed_tokens(input_ids) + # Run vision encoder + visual_embeds = self.visual(pixel_values, grid_thw=image_grid_thw) + # Replace image/video token positions with visual embeddings + mask = (input_ids == self.image_token_id) | (input_ids == self.video_token_id) + if mask.any(): + input_embeds[mask] = visual_embeds.reshape(-1, visual_embeds.shape[-1]) + + hidden_states = self.model( + input_ids=input_ids, + positions=positions, + forward_batch=forward_batch, + input_embeds=input_embeds, + ) + + # LM head + logits = self.lm_head(hidden_states) + return logits + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + """Load weights, dispatching visual vs language params.""" + visual_weights = [] + language_weights = [] + + for name, weight in weights: + if "visual" in name or "model.visual" in name: + # Normalize visual weight names + name = name.replace("model.visual.", "visual.") + name = name.replace("attn.qkv.", "attn.qkv_proj.") + visual_weights.append((name, weight)) + else: + language_weights.append((name, weight)) + + # Load language model weights + self.model.load_weights(language_weights) + + # Load visual weights + if self.visual is not None and visual_weights: + params_dict = dict(self.named_parameters()) + for name, weight in visual_weights: + if name in params_dict: + param = params_dict[name] + loader = getattr(param, 
"weight_loader", None) + if loader is not None: + loader(param, weight) + else: + param.data.copy_(weight) + + logger.info("Qwen3_5ForConditionalGeneration weights loaded") diff --git a/pymllm/compile/mllm_ir/trace.py b/pymllm/models/qwen3_moe.py similarity index 100% rename from pymllm/compile/mllm_ir/trace.py rename to pymllm/models/qwen3_moe.py diff --git a/pymllm/models/qwen3_vl.py b/pymllm/models/qwen3_vl.py new file mode 100644 index 000000000..b253ad091 --- /dev/null +++ b/pymllm/models/qwen3_vl.py @@ -0,0 +1,1385 @@ +# Copyright 2025 Qwen Team +# Copyright 2025 SGLang Team +# Adapted for pymllm +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Inference-only Qwen3-VL model for pymllm. + +Adapted from sglang's Qwen3-VL implementation for pymllm's single-GPU +inference architecture. Uses pymllm layers (RadixAttention, RMSNorm, MLP) +and conforms to the pymllm forward interface:: + + model.forward(input_ids, positions, forward_batch) + +Designed for a single accelerator card — no tensor / pipeline parallelism. 
+""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING, Iterable, List, Optional, Tuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pymllm.layers import RMSNorm, apply_mrope +from pymllm.layers.attention.radix_attention import RadixAttention +from pymllm.layers.linear import Linear +from pymllm.layers.mlp import MLP + +if TYPE_CHECKING: + from pymllm.engine.forward_batch import ForwardBatch + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Vision Encoder +# --------------------------------------------------------------------------- + + +class Qwen3VisionMLP(nn.Module): + """MLP block for the vision encoder.""" + + def __init__( + self, + in_features: int, + hidden_features: int, + hidden_act: str = "silu", + bias: bool = True, + ): + super().__init__() + self.linear_fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.linear_fc2 = nn.Linear(hidden_features, in_features, bias=bias) + if hidden_act == "gelu_pytorch_tanh": + self.act = nn.GELU(approximate="tanh") + elif hidden_act == "gelu": + self.act = nn.GELU() + else: + self.act = nn.SiLU() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.linear_fc2(self.act(self.linear_fc1(x))) + + +class Qwen3VLVisionPatchEmbed(nn.Module): + """3D convolution patch embedding for video/image patchification.""" + + def __init__( + self, + patch_size: int = 16, + temporal_patch_size: int = 2, + in_channels: int = 3, + embed_dim: int = 1152, + ): + super().__init__() + self.patch_size = patch_size + self.temporal_patch_size = temporal_patch_size + self.in_channels = in_channels + self.embed_dim = embed_dim + + kernel_size = [temporal_patch_size, patch_size, patch_size] + self.proj = nn.Conv3d( + in_channels, + embed_dim, + kernel_size=kernel_size, + stride=kernel_size, + bias=True, + ) + + def forward(self, hidden_states: 
torch.Tensor) -> torch.Tensor: + target_dtype = self.proj.weight.dtype + hidden_states = hidden_states.view( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, + ) + hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view( + -1, self.embed_dim + ) + return hidden_states + + +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate half the hidden dims of the input for RoPE.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +class Qwen3VisionAttention(nn.Module): + """Multi-head self-attention for the vision encoder (no KV cache).""" + + def __init__(self, embed_dim: int, num_heads: int): + super().__init__() + self.embed_dim = embed_dim + self.num_heads = num_heads + self.head_dim = embed_dim // num_heads + + self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=True) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + ) -> torch.Tensor: + """Forward pass with variable-length sequences via cu_seqlens. + + Args: + x: [total_tokens, embed_dim] + cu_seqlens: [num_seqs + 1] cumulative sequence lengths + rotary_pos_emb_cos: [total_tokens, rotary_dim] + rotary_pos_emb_sin: [total_tokens, rotary_dim] + """ + seq_len = x.shape[0] + qkv = self.qkv_proj(x) + q, k, v = qkv.reshape(seq_len, 3, self.num_heads, self.head_dim).unbind(dim=1) + + # Apply rotary position embedding. + # cos/sin are [total_tokens, head_dim // 2]. + # VisionAttention: double them to full head_dim and apply RoPE to + # all head dimensions (the rotation pairs (q[i], q[i + head_dim//2])). 
+ cos = rotary_pos_emb_cos + sin = rotary_pos_emb_sin + if cos.shape[-1] * 2 == self.head_dim: + cos = torch.cat([cos, cos], dim=-1) + sin = torch.cat([sin, sin], dim=-1) + + cos = cos.unsqueeze(1) # [seq, 1, head_dim] + sin = sin.unsqueeze(1) # [seq, 1, head_dim] + + q = q * cos + _rotate_half(q) * sin + k = k * cos + _rotate_half(k) * sin + + # Scaled dot-product attention per variable-length sequence + output = torch.empty_like(q) + num_seqs = cu_seqlens.shape[0] - 1 + for i in range(num_seqs): + start = cu_seqlens[i].item() + end = cu_seqlens[i + 1].item() + qi = q[start:end].transpose(0, 1).unsqueeze(0) # [1, heads, seq, dim] + ki = k[start:end].transpose(0, 1).unsqueeze(0) + vi = v[start:end].transpose(0, 1).unsqueeze(0) + oi = F.scaled_dot_product_attention(qi, ki, vi) + output[start:end] = oi.squeeze(0).transpose(0, 1) + + output = output.reshape(seq_len, self.embed_dim) + return self.out_proj(output) + + +class Qwen3VisionBlock(nn.Module): + """Single vision transformer block.""" + + def __init__( + self, + dim: int, + num_heads: int, + intermediate_dim: int, + hidden_act: str = "silu", + norm_eps: float = 1e-6, + ): + super().__init__() + self.norm1 = nn.LayerNorm(dim, eps=norm_eps) + self.norm2 = nn.LayerNorm(dim, eps=norm_eps) + self.attn = Qwen3VisionAttention(embed_dim=dim, num_heads=num_heads) + self.mlp = Qwen3VisionMLP( + dim, intermediate_dim, hidden_act=hidden_act, bias=True + ) + + def forward( + self, + x: torch.Tensor, + cu_seqlens: torch.Tensor, + rotary_pos_emb_cos: torch.Tensor, + rotary_pos_emb_sin: torch.Tensor, + ) -> torch.Tensor: + x = x + self.attn( + self.norm1(x), + cu_seqlens=cu_seqlens, + rotary_pos_emb_cos=rotary_pos_emb_cos, + rotary_pos_emb_sin=rotary_pos_emb_sin, + ) + x = x + self.mlp(self.norm2(x)) + return x + + +class Qwen3VLVisionPatchMerger(nn.Module): + """Merges spatial patches to reduce sequence length. 
+ + Groups ``spatial_merge_size ** 2`` consecutive patch tokens and projects + them to the language model hidden dimension. + """ + + def __init__( + self, + dim: int, + context_dim: int, + spatial_merge_size: int = 2, + use_postshuffle_norm: bool = False, + norm_eps: float = 1e-6, + ): + super().__init__() + self.hidden_size = context_dim * (spatial_merge_size**2) + self.use_postshuffle_norm = use_postshuffle_norm + self.norm = nn.LayerNorm( + self.hidden_size if use_postshuffle_norm else context_dim, eps=norm_eps + ) + self.linear_fc1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True) + self.act_fn = nn.GELU() + self.linear_fc2 = nn.Linear(self.hidden_size, dim, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.use_postshuffle_norm: + x = self.norm(x.view(-1, self.hidden_size)) + else: + x = self.norm(x).view(-1, self.hidden_size) + x = self.act_fn(self.linear_fc1(x)) + return self.linear_fc2(x) + + +class Qwen3VLVisionModel(nn.Module): + """Complete vision encoder for Qwen3-VL. + + Produces patch embeddings from raw pixel values, applies a stack of + vision transformer blocks with 3D rotary embeddings, then merges + spatial patches. Supports "deep stack" where intermediate layer + outputs are captured and concatenated to the final output. 
+ """ + + def __init__( + self, + depth: int = 27, + hidden_size: int = 1152, + hidden_act: str = "gelu_pytorch_tanh", + intermediate_size: int = 4304, + num_heads: int = 16, + in_channels: int = 3, + patch_size: int = 16, + spatial_merge_size: int = 2, + temporal_patch_size: int = 2, + out_hidden_size: int = 3584, + num_position_embeddings: int = 2304, + deepstack_visual_indexes: Optional[List[int]] = None, + norm_eps: float = 1e-6, + ): + super().__init__() + if deepstack_visual_indexes is None: + deepstack_visual_indexes = [8, 16, 24] + + self.hidden_size = hidden_size + self.num_heads = num_heads + self.num_position_embeddings = num_position_embeddings + self.num_grid_per_side = int(num_position_embeddings**0.5) + self.patch_size = patch_size + self.spatial_merge_size = spatial_merge_size + self.temporal_patch_size = temporal_patch_size + self.deepstack_visual_indexes = deepstack_visual_indexes + # Total output dim = out_hidden_size * (1 main + N deepstack mergers) + self.out_hidden_size = out_hidden_size * (1 + len(deepstack_visual_indexes)) + + self.patch_embed = Qwen3VLVisionPatchEmbed( + patch_size=patch_size, + temporal_patch_size=temporal_patch_size, + in_channels=in_channels, + embed_dim=hidden_size, + ) + + self.pos_embed = nn.Embedding(num_position_embeddings, hidden_size) + + head_dim = hidden_size // num_heads + self._init_rope_cache(head_dim) + + self.blocks = nn.ModuleList( + [ + Qwen3VisionBlock( + dim=hidden_size, + num_heads=num_heads, + intermediate_dim=intermediate_size, + hidden_act=hidden_act, + norm_eps=norm_eps, + ) + for _ in range(depth) + ] + ) + + self.merger = Qwen3VLVisionPatchMerger( + dim=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + norm_eps=norm_eps, + ) + + self.deepstack_merger_list = nn.ModuleList( + [ + Qwen3VLVisionPatchMerger( + dim=out_hidden_size, + context_dim=hidden_size, + spatial_merge_size=spatial_merge_size, + use_postshuffle_norm=True, + norm_eps=norm_eps, + ) + for _ in 
range(len(deepstack_visual_indexes)) + ] + ) + + def _init_rope_cache(self, head_dim: int, max_grid_size: int = 8192): + """Precompute cos/sin cache for 2D rotary embeddings.""" + rotary_dim = head_dim // 2 + inv_freq = 1.0 / ( + 10000.0 + ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim) + ) + t = torch.arange(max_grid_size, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self.register_buffer("cos_cache", torch.cos(freqs), persistent=False) + self.register_buffer("sin_cache", torch.sin(freqs), persistent=False) + + @property + def dtype(self) -> torch.dtype: + return self.patch_embed.proj.weight.dtype + + @property + def device(self) -> torch.device: + return self.patch_embed.proj.weight.device + + # -- Rotary position embedding helpers -- + + @staticmethod + def _rot_pos_ids(h: int, w: int, spatial_merge_size: int) -> torch.Tensor: + """Compute 2D rotary position IDs for a grid of *h* x *w* patches. + + The patches are re-ordered to group ``spatial_merge_size ** 2`` + neighbours together (matching the merger's token order). + + Returns tensor of shape ``[h*w, 2]`` with ``(height_pos, width_pos)``. 
+ """ + merge = spatial_merge_size + h_ids = torch.arange(h).unsqueeze(1).expand(-1, w) + w_ids = torch.arange(w).unsqueeze(0).expand(h, -1) + + h_ids = h_ids.reshape(h // merge, merge, w // merge, merge) + w_ids = w_ids.reshape(h // merge, merge, w // merge, merge) + + h_ids = h_ids.permute(0, 2, 1, 3).flatten() + w_ids = w_ids.permute(0, 2, 1, 3).flatten() + + return torch.stack([h_ids, w_ids], dim=-1) + + def rot_pos_emb( + self, grid_thw: List[List[int]] + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Compute rotary pos-emb cos/sin for all images/videos in the batch.""" + pos_ids = [] + for t, h, w in grid_thw: + base = self._rot_pos_ids(h, w, self.spatial_merge_size) + pos_ids.append(base if t == 1 else base.repeat(t, 1)) + + pos_ids = torch.cat(pos_ids, dim=0).to(self.device, non_blocking=True) + cos_combined = self.cos_cache[pos_ids].flatten(1) + sin_combined = self.sin_cache[pos_ids].flatten(1) + return cos_combined, sin_combined + + # -- Position embedding interpolation -- + + def _get_interpolation_indices(self, dim_size: int) -> np.ndarray: + indices = (np.arange(dim_size, dtype=np.float32) + 0.5) * ( + self.num_grid_per_side / dim_size + ) - 0.5 + return np.clip(indices, 0, self.num_grid_per_side - 1) + + def _calculate_indices_and_weights( + self, h_idxs: np.ndarray, w_idxs: np.ndarray + ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + """Compute bilinear interpolation indices and weights.""" + side = self.num_grid_per_side + h_f = np.floor(h_idxs).astype(np.int64) + h_c = np.clip(h_f + 1, 0, side - 1) + dh = h_idxs - h_f + w_f = np.floor(w_idxs).astype(np.int64) + w_c = np.clip(w_f + 1, 0, side - 1) + dw = w_idxs - w_f + + indices = [ + (h_f[:, None] * side + w_f).flatten(), + (h_f[:, None] * side + w_c).flatten(), + (h_c[:, None] * side + w_f).flatten(), + (h_c[:, None] * side + w_c).flatten(), + ] + weights = [ + ((1 - dh)[:, None] * (1 - dw)).flatten(), + ((1 - dh)[:, None] * dw).flatten(), + (dh[:, None] * (1 - dw)).flatten(), + (dh[:, None] * 
dw).flatten(), + ] + return indices, weights + + def _get_position_embedding( + self, + patch_pos_embeds: List[torch.Tensor], + grid_ts: List[int], + grid_hs: List[int], + grid_ws: List[int], + ) -> torch.Tensor: + """Tile and reorganize position embeddings to align with the merged token order.""" + result_parts = [] + merge = self.spatial_merge_size + for pos_embed, t, h, w in zip(patch_pos_embeds, grid_ts, grid_hs, grid_ws): + pos_embed = pos_embed.repeat(t, 1) + pos_embed = ( + pos_embed.view(t, h // merge, merge, w // merge, merge, -1) + .permute(0, 1, 3, 2, 4, 5) + .flatten(0, 4) + ) + result_parts.append(pos_embed) + return torch.cat(result_parts, dim=0) + + def fast_pos_embed_interpolate(self, grid_thw: torch.Tensor) -> torch.Tensor: + """Interpolate position embeddings via bilinear interpolation.""" + grid_thw_cpu = grid_thw.cpu().numpy() + temporal_dims = grid_thw_cpu[:, 0].tolist() + height_dims = grid_thw_cpu[:, 1].tolist() + width_dims = grid_thw_cpu[:, 2].tolist() + + device = self.pos_embed.weight.device + dtype = self.pos_embed.weight.dtype + + patches_size = [h * w for h, w in zip(height_dims, width_dims)] + total_patches = sum(patches_size) + all_indices_np = np.zeros((4, total_patches), dtype=np.int64) + all_weights_np = np.zeros((4, total_patches), dtype=np.float32) + + current_idx = 0 + for _t, h, w in zip(temporal_dims, height_dims, width_dims): + h_idxs = self._get_interpolation_indices(h) + w_idxs = self._get_interpolation_indices(w) + indices, weights = self._calculate_indices_and_weights(h_idxs, w_idxs) + end_idx = current_idx + h * w + for i in range(4): + all_indices_np[i, current_idx:end_idx] = indices[i] + all_weights_np[i, current_idx:end_idx] = weights[i] + current_idx = end_idx + + idx_tensor = torch.from_numpy(all_indices_np).to(device) + weight_tensor = torch.from_numpy(all_weights_np).to(dtype=dtype, device=device) + + pos_embeds = self.pos_embed(idx_tensor.view(-1)) + pos_embeds = pos_embeds.view(4, total_patches, -1) + 
patch_pos_embeds = (pos_embeds * weight_tensor.unsqueeze(-1)).sum(dim=0) + patch_pos_embeds = patch_pos_embeds.split(patches_size) + return self._get_position_embedding( + list(patch_pos_embeds), temporal_dims, height_dims, width_dims + ) + + # -- Forward -- + + def forward( + self, + x: torch.Tensor, + grid_thw: torch.Tensor, + ) -> torch.Tensor: + """Run the vision encoder. + + Args: + x: Pixel values, shape ``[total_patches, patch_dim]``. + grid_thw: Grid dimensions ``[num_images, 3]`` with ``(T, H, W)``. + + Returns: + Vision features of shape + ``[num_merged_tokens, out_hidden_size * (1 + num_deepstack)]``. + """ + x = x.to(device=self.device, dtype=self.dtype) + x = self.patch_embed(x) + + if isinstance(grid_thw, list): + grid_thw_list = grid_thw + grid_thw = torch.tensor(grid_thw, dtype=torch.int32) + else: + grid_thw_list = grid_thw.tolist() + + pos_embeds = self.fast_pos_embed_interpolate(grid_thw) + x += pos_embeds + + rotary_pos_emb_cos, rotary_pos_emb_sin = self.rot_pos_emb(grid_thw_list) + + cu_seqlens = _compute_cu_seqlens_from_grid(grid_thw) + cu_seqlens = cu_seqlens.to(self.device, non_blocking=True) + + deepstack_features = [] + ds_idx = 0 + + for layer_num, blk in enumerate(self.blocks): + x = blk(x, cu_seqlens, rotary_pos_emb_cos, rotary_pos_emb_sin) + + if layer_num in self.deepstack_visual_indexes: + # x is [total_tokens, hidden]. The merger expects the last + # dim to be context_dim so it can group spatial_merge_size^2 + # tokens; reshape to [total_tokens, 1, hidden] so that the + # `.view(-1, hidden_size)` inside the merger collapses the + # spatial merge correctly. + ds_feat = self.deepstack_merger_list[ds_idx](x.unsqueeze(1)) + deepstack_features.append(ds_feat) + ds_idx += 1 + + x = self.merger(x.unsqueeze(1)) + + # Concatenate main + deepstack features along the feature dimension. 
+ # Result: [num_merged_tokens, out_hidden_size * (1 + num_deepstack)] + hidden_states = torch.cat([x] + deepstack_features, dim=-1) + return hidden_states + + +def _compute_cu_seqlens_from_grid(grid_thw: torch.Tensor) -> torch.Tensor: + """Compute cumulative sequence lengths from grid dimensions.""" + grid_np = grid_thw.cpu().numpy() + seq_lens = (grid_np[:, 0] * grid_np[:, 1] * grid_np[:, 2]).astype(np.int32) + cu_seqlens = np.concatenate([[0], np.cumsum(seq_lens)]) + return torch.tensor(cu_seqlens, dtype=torch.int32) + + +def _build_cos_sin_cache( + head_dim: int, + rope_theta: float, + max_pos: int, + dtype: torch.dtype, +) -> torch.Tensor: + """Build a [max_pos, head_dim] cos/sin cache for M-RoPE. + + Layout: first ``head_dim // 2`` columns are cos values, second half are sin. + Each row corresponds to one position index. + """ + inv_freq = 1.0 / ( + rope_theta ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim) + ) + t = torch.arange(max_pos, dtype=torch.float32) + freqs = torch.outer(t, inv_freq) # [max_pos, head_dim // 2] + return torch.cat([torch.cos(freqs), torch.sin(freqs)], dim=-1).to(dtype) + + +def get_rope_index( + input_ids: torch.Tensor, + image_grid_thw: Optional[torch.Tensor], + image_token_id: int, + vision_start_token_id: int, + spatial_merge_size: int, +) -> Tuple[torch.Tensor, int]: + """Compute M-RoPE 3-D position IDs for one sequence. + + For text tokens all three (temporal, height, width) indices are equal to + the sequential counter. For image tokens the indices follow the spatial + grid ``(t, h, w)``. + + Args: + input_ids: Token IDs for one sequence, shape ``[T]``. + image_grid_thw: Grid dimensions for every image in the sequence, + shape ``[num_images, 3]``. ``None`` when there are no images. + image_token_id: Token ID used as placeholder for image patches. + vision_start_token_id: Token ID that precedes each image block. + spatial_merge_size: Number of patches merged per spatial dimension + (e.g. 
2 → 2x2 merge, so llm_grid_h = H // 2). + + Returns: + ``(position_ids, mrope_position_delta)`` where ``position_ids`` has + shape ``[3, T]`` and ``mrope_position_delta`` is a Python ``int`` + equal to ``max_position_used + 1 - T``. + """ + total_tokens = input_ids.shape[0] + device = input_ids.device + position_ids = torch.zeros(3, total_tokens, dtype=torch.long, device=device) + + if image_grid_thw is None or image_grid_thw.shape[0] == 0: + pos = torch.arange(total_tokens, dtype=torch.long, device=device) + position_ids[0] = pos + position_ids[1] = pos + position_ids[2] = pos + return position_ids, 0 + + input_ids_cpu = input_ids.cpu().tolist() + grid_thw_list = image_grid_thw.cpu().tolist() + + llm_pos_ids_start = 0 + image_idx = 0 + i = 0 + + while i < total_tokens: + token = input_ids_cpu[i] + + if token == vision_start_token_id and image_idx < len(grid_thw_list): + # The vision_start token itself gets a regular sequential position. + position_ids[:, i] = llm_pos_ids_start + llm_pos_ids_start += 1 + i += 1 + + # Compute LLM-side grid dimensions (after spatial merging). + t_g = int(grid_thw_list[image_idx][0]) + h_g = int(grid_thw_list[image_idx][1]) + w_g = int(grid_thw_list[image_idx][2]) + llm_grid_t = t_g + llm_grid_h = h_g // spatial_merge_size + llm_grid_w = w_g // spatial_merge_size + num_image_tokens = llm_grid_t * llm_grid_h * llm_grid_w + + # Build per-patch 3-D indices. 
+ t_idx = ( + torch.arange(llm_grid_t, device=device) + .view(-1, 1, 1) + .expand(-1, llm_grid_h, llm_grid_w) + .flatten() + ) + h_idx = ( + torch.arange(llm_grid_h, device=device) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_idx = ( + torch.arange(llm_grid_w, device=device) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + + img_start = i + img_end = i + num_image_tokens + position_ids[0, img_start:img_end] = t_idx + llm_pos_ids_start + position_ids[1, img_start:img_end] = h_idx + llm_pos_ids_start + position_ids[2, img_start:img_end] = w_idx + llm_pos_ids_start + + llm_pos_ids_start += max(llm_grid_t, llm_grid_h, llm_grid_w) + i += num_image_tokens + image_idx += 1 + else: + # Text token (including vision_end and all non-image tokens). + position_ids[:, i] = llm_pos_ids_start + llm_pos_ids_start += 1 + i += 1 + + mrope_position_delta = llm_pos_ids_start - total_tokens + return position_ids, mrope_position_delta + + +# --------------------------------------------------------------------------- +# Text Decoder (Language Model) +# --------------------------------------------------------------------------- + + +class Qwen3VLAttention(nn.Module): + """Attention layer for the Qwen3-VL text decoder. + + Uses QK-norm (per-head RMSNorm on Q and K before RoPE) and + :class:`RadixAttention` for KV-cached inference. Applies + interleaved M-RoPE with a precomputed cos/sin cache. 
+ """ + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + layer_id: int, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.q_size = num_heads * head_dim + self.kv_size = num_kv_heads * head_dim + self.scaling = head_dim**-0.5 + self.mrope_section = list(mrope_section) + self.mrope_interleaved = mrope_interleaved + + def _get_qm(suffix): + if quant_config is None: + return None + return quant_config.get_quant_method( + layer=None, prefix=f"{prefix}.{suffix}" if prefix else suffix, + ) + + # When quantized, AWQ checkpoints store q/k/v separately so we + # cannot fuse them into a single packed-int32 parameter. + self.use_fused_qkv = quant_config is None + + if self.use_fused_qkv: + self.qkv_proj = Linear( + hidden_size, self.q_size + 2 * self.kv_size, bias=False, + ) + self.q_proj = None + self.k_proj = None + self.v_proj = None + else: + self.qkv_proj = None + self.q_proj = Linear( + hidden_size, self.q_size, bias=False, + quant_method=_get_qm("q_proj"), + ) + self.k_proj = Linear( + hidden_size, self.kv_size, bias=False, + quant_method=_get_qm("k_proj"), + ) + self.v_proj = Linear( + hidden_size, self.kv_size, bias=False, + quant_method=_get_qm("v_proj"), + ) + + # Output projection + self.o_proj = Linear( + num_heads * head_dim, hidden_size, bias=False, + quant_method=_get_qm("o_proj"), + ) + + # QK normalization + self.q_norm = RMSNorm(head_dim, eps=rms_norm_eps) + self.k_norm = RMSNorm(head_dim, eps=rms_norm_eps) + + # Precomputed M-RoPE cos/sin cache: [max_pos, head_dim] + cos_sin = _build_cos_sin_cache( + head_dim, rope_theta, max_position_embeddings, torch.float32 + ) + self.register_buffer("cos_sin_cache", 
cos_sin, persistent=False) + + # Radix attention (single-GPU: heads == tp_heads) + self.attn = RadixAttention( + num_heads=num_heads, + head_dim=head_dim, + scaling=self.scaling, + num_kv_heads=num_kv_heads, + layer_id=layer_id, + ) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: "ForwardBatch", + ) -> torch.Tensor: + if self.use_fused_qkv: + qkv = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + else: + q = self.q_proj(hidden_states) + k = self.k_proj(hidden_states) + v = self.v_proj(hidden_states) + + # Per-head QK normalization + q = self.q_norm(q.view(-1, self.num_heads, self.head_dim)) + k = self.k_norm(k.view(-1, self.num_kv_heads, self.head_dim)) + + # Apply M-RoPE. positions is [3, T] for prefill (3-D) or may arrive + # as [T] for purely text-only batches; expand to [3, T] in that case. + if positions.ndim == 1: + positions = positions.unsqueeze(0).expand(3, -1) + q, k = apply_mrope( + q, + k, + positions, + self.cos_sin_cache.to(q.dtype), + self.mrope_section, + self.mrope_interleaved, + ) + + q = q.reshape(-1, self.q_size) + k = k.reshape(-1, self.kv_size) + + # Attention with KV cache + attn_output = self.attn(q, k, v, forward_batch) + return self.o_proj(attn_output) + + +class Qwen3VLDecoderLayer(nn.Module): + """Single decoder layer for the Qwen3-VL text model.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + head_dim: int, + intermediate_size: int, + layer_id: int, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + quant_config=None, + prefix: str = "", + ): + super().__init__() + self.self_attn = Qwen3VLAttention( + hidden_size=hidden_size, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + head_dim=head_dim, + layer_id=layer_id, + rope_theta=rope_theta, + 
rms_norm_eps=rms_norm_eps, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"{prefix}.self_attn" if prefix else "self_attn", + ) + self.mlp = MLP( + hidden_size=hidden_size, + intermediate_size=intermediate_size, + activation="silu", + use_fused_gate_up_proj=True, + use_bias_gate_up=False, + use_bias_down=False, + quant_config=quant_config, + prefix=f"{prefix}.mlp" if prefix else "mlp", + ) + self.input_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + self.post_attention_layernorm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + forward_batch: "ForwardBatch", + deepstack_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + # Self-attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn(positions, hidden_states, forward_batch) + hidden_states = residual + hidden_states + + # Add deepstack embeddings after residual (matches HF ordering) + if deepstack_embeds is not None: + hidden_states = hidden_states + deepstack_embeds + + # MLP + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + return hidden_states + + +class Qwen3VLTextModel(nn.Module): + """Qwen3-VL text backbone (embedding + decoder layers + final norm).""" + + def __init__( + self, + vocab_size: int = 151936, + hidden_size: int = 4096, + intermediate_size: int = 22016, + num_hidden_layers: int = 32, + num_attention_heads: int = 32, + num_key_value_heads: int = 32, + head_dim: int = 128, + rope_theta: float = 5_000_000.0, + rms_norm_eps: float = 1e-6, + mrope_section: Tuple[int, int, int] = (24, 20, 20), + mrope_interleaved: bool = True, + max_position_embeddings: int = 32768, + quant_config=None, + ): + super().__init__() 
+ self.hidden_size = hidden_size + self.num_hidden_layers = num_hidden_layers + + self.embed_tokens = nn.Embedding(vocab_size, hidden_size) + + self.layers = nn.ModuleList( + [ + Qwen3VLDecoderLayer( + hidden_size=hidden_size, + num_heads=num_attention_heads, + num_kv_heads=num_key_value_heads, + head_dim=head_dim, + intermediate_size=intermediate_size, + layer_id=layer_id, + rope_theta=rope_theta, + rms_norm_eps=rms_norm_eps, + mrope_section=mrope_section, + mrope_interleaved=mrope_interleaved, + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + prefix=f"model.layers.{layer_id}", + ) + for layer_id in range(num_hidden_layers) + ] + ) + + self.norm = RMSNorm(hidden_size, eps=rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: "ForwardBatch", + input_embeds: Optional[torch.Tensor] = None, + input_deepstack_embeds: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + if input_embeds is None: + hidden_states = self.embed_tokens(input_ids) + else: + hidden_states = input_embeds + + for layer_idx, layer in enumerate(self.layers): + ds_embeds = _get_deepstack_embeds( + layer_idx, input_deepstack_embeds, self.hidden_size + ) + hidden_states = layer( + positions, + hidden_states, + forward_batch, + deepstack_embeds=ds_embeds, + ) + + return self.norm(hidden_states) + + +def _get_deepstack_embeds( + layer_idx: int, + input_deepstack_embeds: Optional[torch.Tensor], + hidden_size: int, +) -> Optional[torch.Tensor]: + """Extract deepstack embeddings for a specific decoder layer.""" + if input_deepstack_embeds is None: + return None + num_deepstack = input_deepstack_embeds.shape[-1] // hidden_size + if layer_idx >= num_deepstack: + return None + start = hidden_size * layer_idx + return input_deepstack_embeds[:, start : start + hidden_size] + + +# --------------------------------------------------------------------------- +# Full Model: Qwen3VLForConditionalGeneration +# 
--------------------------------------------------------------------------- + + +class Qwen3VLForConditionalGeneration(nn.Module): + """Qwen3-VL multimodal model for conditional generation. + + Combines a vision encoder and text decoder. During prefill, image/video + tokens are replaced with visual features from the vision encoder. + During decode, the model runs only the text decoder. + + Forward interface:: + + logits = model.forward(input_ids, positions, forward_batch) + """ + + def __init__(self, config, quant_config=None) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + + text_config = getattr(config, "text_config", config) + vision_config = getattr(config, "vision_config", None) + + # Vision encoder — NOT quantized + if vision_config is not None: + self.visual = Qwen3VLVisionModel( + depth=getattr(vision_config, "depth", 27), + hidden_size=getattr(vision_config, "hidden_size", 1152), + hidden_act=getattr(vision_config, "hidden_act", "gelu_pytorch_tanh"), + intermediate_size=getattr(vision_config, "intermediate_size", 4304), + num_heads=getattr(vision_config, "num_heads", 16), + in_channels=getattr(vision_config, "in_channels", 3), + patch_size=getattr(vision_config, "patch_size", 16), + spatial_merge_size=getattr(vision_config, "spatial_merge_size", 2), + temporal_patch_size=getattr(vision_config, "temporal_patch_size", 2), + out_hidden_size=getattr(vision_config, "out_hidden_size", 3584), + num_position_embeddings=getattr( + vision_config, "num_position_embeddings", 2304 + ), + deepstack_visual_indexes=getattr( + vision_config, "deepstack_visual_indexes", [8, 16, 24] + ), + norm_eps=getattr(text_config, "rms_norm_eps", 1e-6), + ) + else: + self.visual = None + + # Text decoder + hidden_size = getattr(text_config, "hidden_size", 4096) + vocab_size = getattr(text_config, "vocab_size", 151936) + + # M-RoPE configuration -- mrope_section lives inside rope_scaling, + # NOT as a top-level attribute of text_config. 
+ rope_scaling = getattr(text_config, "rope_scaling", None) or {} + if isinstance(rope_scaling, dict): + mrope_section = rope_scaling.get("mrope_section", [24, 20, 20]) + mrope_interleaved = rope_scaling.get("mrope_interleaved", True) + else: + mrope_section = getattr(rope_scaling, "mrope_section", [24, 20, 20]) + mrope_interleaved = getattr(rope_scaling, "mrope_interleaved", True) + max_position_embeddings = getattr(text_config, "max_position_embeddings", 32768) + + self.model = Qwen3VLTextModel( + vocab_size=vocab_size, + hidden_size=hidden_size, + intermediate_size=getattr(text_config, "intermediate_size", 22016), + num_hidden_layers=getattr(text_config, "num_hidden_layers", 32), + num_attention_heads=getattr(text_config, "num_attention_heads", 32), + num_key_value_heads=getattr(text_config, "num_key_value_heads", 32), + head_dim=getattr(text_config, "head_dim", 128), + rope_theta=getattr(text_config, "rope_theta", 5_000_000.0), + rms_norm_eps=getattr(text_config, "rms_norm_eps", 1e-6), + mrope_section=tuple(mrope_section), + mrope_interleaved=bool(mrope_interleaved), + max_position_embeddings=max_position_embeddings, + quant_config=quant_config, + ) + + # LM head — following sglang's pattern: always use lm_head.weight + # for matmul in forward(), so it works whether lm_head is nn.Embedding + # (tied) or nn.Linear (untied). 
+ tie_word_embeddings = getattr(config, "tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head = self.model.embed_tokens + else: + self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False) + + # Token IDs for multimodal + self.image_token_id = getattr(config, "image_token_id", 151655) + self.video_token_id = getattr(config, "video_token_id", 151656) + self.vision_start_token_id = getattr(config, "vision_start_token_id", 151652) + + # Spatial merge size (needed for get_rope_index) + self.spatial_merge_size = ( + getattr(vision_config, "spatial_merge_size", 2) + if vision_config is not None + else 2 + ) + + # Deepstack config + if vision_config is not None: + ds_indexes = getattr(vision_config, "deepstack_visual_indexes", [8, 16, 24]) + self.num_deepstack_embeddings = len(ds_indexes) + else: + self.num_deepstack_embeddings = 0 + + self._hidden_size = hidden_size + + def get_input_embeddings(self) -> nn.Module: + return self.model.embed_tokens + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: "ForwardBatch", + ) -> torch.Tensor: + """Run forward pass for Qwen3-VL. + + Args: + input_ids: Flattened input token IDs, shape ``[num_tokens]``. + positions: Position IDs, shape ``[num_tokens]`` (1-D, from model + runner). Overridden internally with 3-D M-RoPE positions. + forward_batch: :class:`ForwardBatch` with attention metadata. + + Returns: + Logits tensor of shape ``[num_tokens, vocab_size]``. + """ + pixel_values = getattr(forward_batch, "pixel_values", None) + image_grid_thw = getattr(forward_batch, "image_grid_thw", None) + + # ------------------------------------------------------------------ + # Build 3-D M-RoPE positions + # ------------------------------------------------------------------ + if forward_batch.forward_mode.is_extend(): + # Prefill: compute per-sequence 3-D position IDs from input_ids + # and image grids, then store per-request deltas for future decode. 
# Count vision_start tokens to determine how many image grids
                # from image_grid_thw belong to this sequence.
+ stored_deltas = getattr(forward_batch, "mrope_position_deltas", None) + if stored_deltas is not None: + pos_1d = forward_batch.positions + stored_deltas + else: + pos_1d = forward_batch.positions + positions = pos_1d.unsqueeze(0).expand(3, -1) # [3, batch_size] + + input_embeds = None + input_deepstack_embeds = None + + if ( + pixel_values is not None + and image_grid_thw is not None + and self.visual is not None + and not forward_batch.forward_mode.is_decode() + ): + # Run vision encoder + vision_features = self.visual(pixel_values, grid_thw=image_grid_thw) + + # Separate main embeddings and deepstack embeddings + if self.num_deepstack_embeddings > 0: + vision_embeds = vision_features[:, : self._hidden_size] + deepstack_embeds = vision_features[:, self._hidden_size :] + else: + vision_embeds = vision_features + deepstack_embeds = None + + # Get text embeddings and replace image tokens with vision features + input_embeds = self.model.embed_tokens(input_ids) + image_mask = input_ids == self.image_token_id + if image_mask.any(): + input_embeds[image_mask] = vision_embeds.to(input_embeds.dtype) + + # Build per-token deepstack embeddings + if deepstack_embeds is not None and image_mask.any(): + input_deepstack_embeds = torch.zeros( + input_embeds.shape[0], + deepstack_embeds.shape[-1], + dtype=input_embeds.dtype, + device=input_embeds.device, + ) + input_deepstack_embeds[image_mask] = deepstack_embeds.to( + input_embeds.dtype + ) + + # Text decoder + hidden_states = self.model( + input_ids, + positions, + forward_batch, + input_embeds=input_embeds, + input_deepstack_embeds=input_deepstack_embeds, + ) + + # Prune hidden_states before lm_head to avoid a wasteful + # [total_tokens, vocab] matmul during prefill. + # LogitsProcessor._get_pruned_states(): in extend mode only keep + # the last token of each sequence; in decode mode all rows are + # already one-per-sequence. 
+ if forward_batch.forward_mode.is_extend(): + if ( + forward_batch.extend_start_loc is not None + and forward_batch.extend_seq_lens is not None + ): + last_index = ( + forward_batch.extend_start_loc + forward_batch.extend_seq_lens - 1 + ).long() + hidden_states = hidden_states[last_index] + else: + hidden_states = hidden_states[-1:] + + # LM head: always use weight matrix directly for the linear + # projection. Works for both nn.Embedding (tied) and nn.Linear + # (untied). + logits = torch.matmul( + hidden_states.to(self.lm_head.weight.dtype), + self.lm_head.weight.T, + ) + + # Return LogitsProcessorOutput so that ModelRunner._process_logits + # skips redundant last-token gathering. + from pymllm.executor.model_runner import LogitsProcessorOutput + + return LogitsProcessorOutput(next_token_logits=logits) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None: + """Load weights from a HuggingFace checkpoint. + + Handles weight name remapping between HuggingFace Qwen3-VL + checkpoints and this model's parameter names. + """ + # When quantized, the model has separate q/k/v and gate/up projections + # (no fused qkv_proj / gate_up_proj), so skip the stacking logic. + if self.quant_config is not None: + stacked_params_mapping = [] + else: + stacked_params_mapping = [ + # (param_name, weight_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".up_proj", 1), + (".gate_up_proj", ".gate_proj", 0), + ] + + params_dict = dict(self.named_parameters()) + + tie_word_embeddings = getattr(self.config, "tie_word_embeddings", False) + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + + # When weights are tied, lm_head.weight is the same tensor as + # embed_tokens.weight — skip the duplicate from the checkpoint. 
+ if tie_word_embeddings and "lm_head.weight" in name: + continue + + name = _remap_weight_name(name) + + # Handle language model stacked parameters (QKV, gate_up) + handled = False + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name or "visual" in name: + continue + name = name.replace(weight_name, param_name) + if name not in params_dict: + continue + _load_stacked_weight(params_dict[name], loaded_weight, shard_id) + handled = True + break + + if handled: + continue + + # Handle vision encoder QKV stacking + if "visual" in name: + for qkv_key in (".attn.q.", ".attn.k.", ".attn.v."): + if qkv_key not in name: + continue + qkv_name = name.replace(qkv_key, ".attn.qkv_proj.") + if qkv_name in params_dict: + shard = {"q": 0, "k": 1, "v": 2}[qkv_key[-2]] + _load_vision_qkv_weight( + params_dict[qkv_name], loaded_weight, shard + ) + handled = True + break + + if handled: + continue + + # Direct parameter loading + if name in params_dict: + param = params_dict[name] + loader = getattr(param, "weight_loader", None) + if loader is not None: + loader(param, loaded_weight) + elif param.data.shape == loaded_weight.shape: + param.data.copy_(loaded_weight) + else: + logger.warning( + "Shape mismatch: param %s (%s) vs loaded (%s), skipping.", + name, + param.data.shape, + loaded_weight.shape, + ) + + +# --------------------------------------------------------------------------- +# Weight loading helpers +# --------------------------------------------------------------------------- + + +def _remap_weight_name(name: str) -> str: + """Remap HuggingFace weight names to pymllm parameter names.""" + # transformers >= v4.52: model.language_model.* -> model.* + if name.startswith("model.language_model."): + name = name.replace("model.language_model.", "model.", 1) + # model.visual.* -> visual.* + elif name.startswith("model.visual."): + name = name.replace("model.visual.", "visual.", 1) + + # Vision attention QKV renaming (fused weights in 
checkpoint) + if "visual" in name: + name = name.replace("attn.qkv.", "attn.qkv_proj.") + + return name + + +def _load_stacked_weight( + param: nn.Parameter, + loaded_weight: torch.Tensor, + shard_id, +) -> None: + """Load one shard (q/k/v or gate/up) into a fused parameter. + + For QKV with GQA (grouped-query attention), Q has a different size + from K and V. The fused layout is ``[Q, K, V]`` where + ``Q_size = total - 2 * KV_size``. We must use cumulative offsets + rather than ``idx * shard_size`` to handle the asymmetry correctly. + """ + if isinstance(shard_id, str): + # QKV fused layout: [Q, K, V] + # Q may have a different size from K/V (GQA). + total_size = param.data.shape[0] + shard_size = loaded_weight.shape[0] + if shard_id == "q": + param.data[0:shard_size].copy_(loaded_weight) + elif shard_id == "k": + kv_size = shard_size + q_size = total_size - 2 * kv_size + param.data[q_size : q_size + kv_size].copy_(loaded_weight) + elif shard_id == "v": + kv_size = shard_size + q_size = total_size - 2 * kv_size + param.data[q_size + kv_size : q_size + 2 * kv_size].copy_(loaded_weight) + else: + # gate_up: 0 -> gate, 1 -> up (same size, idx*size is correct) + shard_size = loaded_weight.shape[0] + param.data[shard_id * shard_size : (shard_id + 1) * shard_size].copy_( + loaded_weight + ) + + +def _load_vision_qkv_weight( + param: nn.Parameter, + loaded_weight: torch.Tensor, + shard_idx: int, +) -> None: + """Load a Q, K, or V weight shard into a fused QKV parameter.""" + shard_size = param.data.shape[0] // 3 + start = shard_idx * shard_size + param.data[start : start + shard_size].copy_(loaded_weight) diff --git a/pymllm/orchestrator/__init__.py b/pymllm/orchestrator/__init__.py new file mode 100644 index 000000000..f1716d794 --- /dev/null +++ b/pymllm/orchestrator/__init__.py @@ -0,0 +1,48 @@ +"""Orchestrator module for distributed computation.""" + +from pymllm.orchestrator.group_coordinator import ( + GroupCoordinator, + divide, + split_tensor_along_dim, +) +from 
GPU tensors stay on GPU. PyTorch's ``storage._share_cuda_()`` yields a serializable
+ +These two paths are **mutually exclusive for GPU tensors**. ``enable_cuda_ipc`` takes +priority; when active the CPU-copy step in ``TensorQueue._make_tensors_shareable`` is +skipped. + +## CUDA IPC memory-leak problem and its fix + +PyTorch never releases the GPU allocation backing an IPC-exported tensor until the +*sending* process exits. If we export raw model tensors we permanently leak GPU memory. + +**Solution** (pool-based recycling via ``MmItemMemoryPool``): + +* Allocate a single, fixed-size GPU workspace (``MmItemMemoryPool``). +* For each outgoing GPU tensor, copy it into a chunk of the workspace and export the + *chunk* via IPC (the workspace is never freed; its chunks are recycled). +* After the receiving process has finished with the data it writes a sync flag + (``ShmSyncBuffer``) to signal that the chunk may be reused. +* A background recycler thread in the sender walks ``occupied_chunks`` and returns + chunks whose sync flag has been incremented back to ``available_chunks``. + +## Transport modes + +``TensorTransportMode``: +* ``"default"`` – CPU/shared-memory path; no CUDA IPC. +* ``"cuda_ipc"`` – Simple CUDA IPC: wraps GPU tensors in ``TransportProxyTensor`` + (a ``torch.Tensor`` subclass whose ``__getstate__``/``__setstate__`` use + ``_share_cuda_``). Suitable for single-process-group scenarios; incurs the + PyTorch memory-leak noted above. +* ``"cuda_ipc_pool"`` – Pool-based CUDA IPC: copies GPU tensors into a pre-allocated + ``MmItemMemoryPool`` and wraps the slice in ``CudaIpcTensorTransportProxy``. + The pool is recycled, so there is no memory leak. 
+""" + +from __future__ import annotations + +import fcntl +import logging +import threading +import time +from multiprocessing import shared_memory +from typing import Any, Dict, List, Literal, Optional, Tuple + +import numpy as np +import torch + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Type alias for transport mode +# --------------------------------------------------------------------------- + +TensorTransportMode = Literal["default", "cuda_ipc", "cuda_ipc_pool"] + + +# --------------------------------------------------------------------------- +# ShmSyncBuffer – a tiny POSIX shared memory float used as a sync counter +# --------------------------------------------------------------------------- + + +class ShmSyncBuffer: + """A single float32 in POSIX shared memory used as a sync counter. + + The sender resets it to 0 before exporting a chunk. The receiver + increments it (atomically under a file lock) once it has finished copying + data out of the chunk. When the value reaches the number of consumers + (``tp_size``) the sender recycles the chunk. 
+ """ + + def __init__(self, byte_size: int = 4) -> None: + self.buffer = shared_memory.SharedMemory(create=True, size=byte_size) + self._arr = np.ndarray(1, dtype=np.float32, buffer=self.buffer.buf) + self._arr *= 0 # initialise to 0 + self.meta_data: Dict[str, Any] = { + "handle": self.buffer.name, + "shape": self._arr.shape, + "dtype": str(self._arr.dtype), + } + + # ------------------------------------------------------------------ + # Helpers consumed by the *receiver* side + # ------------------------------------------------------------------ + + @staticmethod + def open( + meta_data: Dict[str, Any], + ) -> Tuple[shared_memory.SharedMemory, np.ndarray]: + """Open an existing ShmSyncBuffer from the metadata dict.""" + shm = shared_memory.SharedMemory(name=meta_data["handle"]) + arr = np.ndarray(meta_data["shape"], dtype=meta_data["dtype"], buffer=shm.buf) + return shm, arr + + def __del__(self) -> None: + try: + self.buffer.close() + self.buffer.unlink() + except Exception: + pass + + +# Lock file used to serialise writes to sync flags across processes +_SHM_LOCK_FILE = "/tmp/pymllm_shm_wr_lock.lock" + + +def _increment_sync_flag(meta_data: Dict[str, Any]) -> None: + """Increment the sync flag by 1 under a process-level file lock.""" + shm, arr = ShmSyncBuffer.open(meta_data) + try: + open(_SHM_LOCK_FILE, "a").close() # ensure file exists + with open(_SHM_LOCK_FILE, "w+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + arr += 1.0 + fcntl.flock(f, fcntl.LOCK_UN) + finally: + shm.close() + + +# --------------------------------------------------------------------------- +# MmItemMemoryChunk +# --------------------------------------------------------------------------- + + +class MmItemMemoryChunk: + """A contiguous slice of the ``MmItemMemoryPool`` workspace tensor.""" + + def __init__(self, area: Tuple[int, int], sync_flag: ShmSyncBuffer) -> None: + self.area = area + self.sync_flag = sync_flag + + @property + def mem_size(self) -> int: + return self.area[1] - 
self.area[0] + + @property + def start(self) -> int: + return self.area[0] + + @property + def end(self) -> int: + return self.area[1] + + def try_to_recycle(self, num_consumers: int = 1) -> bool: + """Return True if all consumers have finished and the chunk can be reused.""" + val = float(self.sync_flag._arr.item()) + logger.debug( + "[try_to_recycle] area=%s flag=%.0f consumers=%d", + self.area, + val, + num_consumers, + ) + if val >= float(num_consumers): + self.sync_flag._arr *= 0.0 # reset for next use + return True + return False + + +# --------------------------------------------------------------------------- +# MmItemMemoryPool – pre-allocated GPU workspace to avoid IPC memory leaks +# --------------------------------------------------------------------------- + + +class MmItemMemoryPool: + """Pre-allocated GPU memory pool for CUDA IPC tensor transport. + + Chunks are allocated from a contiguous ``torch.int8`` tensor on GPU. + A background thread periodically recycles chunks whose sync flags show + that all consumers have finished reading. + + Args: + memory_size: Pool size in **bytes**. + recycle_interval: How often (seconds) the recycler thread runs. + num_consumers: Number of consumer processes (tp_size). Each consumer + must increment the sync flag once before a chunk is recycled. + device: CUDA device index. 
+ """ + + def __init__( + self, + memory_size: int, + recycle_interval: float = 0.1, + num_consumers: int = 1, + device: int = 0, + ) -> None: + self.num_consumers = num_consumers + self._recycle_interval = recycle_interval + self._lock = threading.Lock() + self._stop = False + + with torch.cuda.device(device): + self.memory_pool: torch.Tensor = torch.empty( + memory_size, dtype=torch.int8, device=f"cuda:{device}" + ).contiguous() + + init_chunk = MmItemMemoryChunk((0, memory_size), self._new_sync_buffer()) + self.available_chunks: List[MmItemMemoryChunk] = [init_chunk] + self.occupied_chunks: List[MmItemMemoryChunk] = [] + # Pool of reusable ShmSyncBuffer objects (returned from recycled chunks) + self._sync_pool: List[ShmSyncBuffer] = [] + + self._recycler = threading.Thread( + target=self._recycle_loop, + name="MmItemMemoryPoolRecycler", + daemon=True, + ) + self._recycler.start() + + logger.info( + "MmItemMemoryPool: %d MB on cuda:%d, recycle_interval=%.2fs", + memory_size // (1024 * 1024), + device, + recycle_interval, + ) + + # ------------------------------------------------------------------ + # Sync buffer management + # ------------------------------------------------------------------ + + def _new_sync_buffer(self) -> ShmSyncBuffer: + if self._sync_pool: + return self._sync_pool.pop() + return ShmSyncBuffer() + + def _return_sync_buffer(self, buf: ShmSyncBuffer) -> None: + buf._arr *= 0.0 # reset counter + self._sync_pool.append(buf) + + # ------------------------------------------------------------------ + # Allocation + # ------------------------------------------------------------------ + + def _get_available_chunk(self, src: torch.Tensor) -> Optional[MmItemMemoryChunk]: + """Best-fit allocation: find the smallest available chunk >= src size.""" + needed = src.numel() * src.element_size() + best: Optional[MmItemMemoryChunk] = None + for chunk in self.available_chunks: + if chunk.mem_size >= needed: + if best is None or chunk.mem_size < best.mem_size: + 
best = chunk + if best is None: + return None + + # Split the selected chunk + occupied_area = (best.start, best.start + needed) + occupied = MmItemMemoryChunk(occupied_area, best.sync_flag) + self.occupied_chunks.append(occupied) + self.available_chunks.remove(best) + + remainder = (occupied.end, best.end) + if remainder[0] < remainder[1]: + split = MmItemMemoryChunk(remainder, self._new_sync_buffer()) + self.available_chunks.append(split) + + return occupied + + def get_slice_with_flag( + self, src: torch.Tensor + ) -> Tuple[Optional[Dict[str, Any]], Optional[torch.Tensor]]: + """Allocate a pool slice for *src* and return ``(sync_flag_meta, slice_tensor)``. + + Thread-safe. Returns ``(None, None)`` if the pool is full. + """ + with self._lock: + chunk = self._get_available_chunk(src) + if chunk is None: + logger.warning( + "MmItemMemoryPool full (%d occupied, %d available); " + "falling back to CPU transport", + len(self.occupied_chunks), + len(self.available_chunks), + ) + return None, None + pool_slice = self.memory_pool[chunk.start : chunk.end] + return chunk.sync_flag.meta_data, pool_slice + + # ------------------------------------------------------------------ + # Recycling + # ------------------------------------------------------------------ + + def _recycle_loop(self) -> None: + while not self._stop: + try: + with self._lock: + self._recycle_chunks() + self._merge_chunks() + except Exception as exc: + logger.warning( + "MmItemMemoryPool recycler error: %s", exc, exc_info=True + ) + time.sleep(self._recycle_interval) + + def _recycle_chunks(self) -> None: + new_occupied: List[MmItemMemoryChunk] = [] + for chunk in self.occupied_chunks: + if chunk.try_to_recycle(self.num_consumers): + self._return_sync_buffer(chunk.sync_flag) + chunk.sync_flag = self._new_sync_buffer() + self.available_chunks.append(chunk) + else: + new_occupied.append(chunk) + self.occupied_chunks = new_occupied + + def _merge_chunks(self) -> None: + """Coalesce adjacent free chunks to 
reduce fragmentation.""" + merged: List[MmItemMemoryChunk] = [] + for chunk in sorted(self.available_chunks, key=lambda c: c.start): + if merged and merged[-1].end == chunk.start: + prev = merged.pop() + self._return_sync_buffer(chunk.sync_flag) + merged.append( + MmItemMemoryChunk((prev.start, chunk.end), prev.sync_flag) + ) + else: + merged.append(chunk) + self.available_chunks = merged + + def shutdown(self) -> None: + self._stop = True + if self._recycler.is_alive(): + self._recycler.join(timeout=2.0) + + +# --------------------------------------------------------------------------- +# CudaIpcTensorTransportProxy – pool-based CUDA IPC proxy object +# --------------------------------------------------------------------------- + + +class CudaIpcTensorTransportProxy: + """Proxy that carries a CUDA IPC handle for a pool-slice tensor. + + The *sender* process: + 1. Copies the source tensor into a ``MmItemMemoryPool`` slice (int8 view). + 2. Wraps the slice in this proxy, which captures the CUDA IPC handle via + ``storage._share_cuda_()``. + 3. Sends the proxy through ``multiprocessing.Queue`` (pickle). + + The *receiver* process: + 1. Calls :meth:`reconstruct_on_device` to map the IPC memory and copy it + into a fresh local tensor. + 2. The copy increments the sync flag, allowing the sender's recycler to + reclaim the pool slice. + + Fallback: if ``_share_cuda_()`` fails (e.g. TP ranks), ``tensor_data`` holds + the raw tensor (which will be pickled the normal way, incurring serialization cost). 
+ """ + + def __init__( + self, + data: torch.Tensor, + info_data: torch.Tensor, + sync_buffer_meta: Dict[str, Any], + ) -> None: + if not isinstance(data, torch.Tensor) or not isinstance( + info_data, torch.Tensor + ): + raise TypeError( + f"data and info_data must be torch.Tensors, got {type(data)}, {type(info_data)}" + ) + + self.sync_data_meta = sync_buffer_meta + self._state = self._build_state(data, info_data) + self._reconstructed: Optional[torch.Tensor] = None + self._shm: Optional[shared_memory.SharedMemory] = None + + def _build_state( + self, data: torch.Tensor, info_data: torch.Tensor + ) -> Dict[str, Any]: + try: + storage = data.untyped_storage() + handle = storage._share_cuda_() + return { + "ipc_handle": { + "handle": handle, + "shape": data.shape, + "dtype": data.dtype, + "stride": data.stride(), + "device_index": data.device.index, + "storage_offset": data.storage_offset(), + "target_shape": info_data.shape, + "target_dtype": info_data.dtype, + }, + "tensor_data": None, + } + except Exception as exc: + logger.warning( + "CudaIpcTensorTransportProxy: _share_cuda_() failed (%s); " + "falling back to direct tensor.", + exc, + ) + return {"ipc_handle": None, "tensor_data": data} + + def reconstruct_on_device(self, device_index: Optional[int] = None) -> torch.Tensor: + """Map IPC memory and copy into a new local tensor. + + This **must** be called from the *receiver* process. After the copy + the sync flag is incremented so the sender can recycle the pool chunk. 
+ """ + if self._reconstructed is not None: + return self._reconstructed + + state = self._state + if state["ipc_handle"] is not None: + h = state["ipc_handle"] + source_device = torch.device(f"cuda:{h['device_index']}") + target_device = ( + source_device + if device_index is None + else torch.device(f"cuda:{device_index}") + ) + with torch.cuda.device(source_device): + storage = torch.UntypedStorage._new_shared_cuda(*h["handle"]) + slice_tensor = torch.empty( + 0, dtype=h["dtype"], device=source_device + ).set_( + storage, + storage_offset=h["storage_offset"], + size=h["shape"], + stride=h["stride"], + ) + + result = torch.empty( + h["target_shape"], dtype=h["target_dtype"], device=target_device + ).contiguous() + result.view(torch.int8).view(-1).copy_(slice_tensor) + + # Signal sender that the chunk can be recycled + _increment_sync_flag(self.sync_data_meta) + elif state["tensor_data"] is not None: + result = state["tensor_data"] + if device_index is not None: + result = result.to(f"cuda:{device_index}", non_blocking=True) + else: + raise RuntimeError("CudaIpcTensorTransportProxy: invalid state") + + self._reconstructed = result + return result + + +# --------------------------------------------------------------------------- +# TransportProxyTensor – simple CUDA IPC via torch.Tensor subclass + pickle +# --------------------------------------------------------------------------- + + +class TransportProxyTensor(torch.Tensor): + """A ``torch.Tensor`` subclass whose pickle uses CUDA IPC handles. + + When ``transport_mode == "cuda_ipc"`` and the tensor is on CUDA, + ``__getstate__`` exports the tensor via ``storage._share_cuda_()`` instead + of serialising the raw data. ``__setstate__`` reconstructs it in the + receiving process via ``UntypedStorage._new_shared_cuda``. + + Caveat: The underlying GPU allocation is never freed until the *sender* + process exits (PyTorch limitation). Prefer ``"cuda_ipc_pool"`` mode for + long-running services to avoid GPU memory leaks. 
+ + When the tensor is on CPU or ``transport_mode == "default"``, the tensor + is serialised normally (pickle of raw data). + """ + + @staticmethod + def __new__( + cls, + data: torch.Tensor, + transport_mode: TensorTransportMode = "default", + ) -> "TransportProxyTensor": + if not isinstance(data, torch.Tensor): + raise TypeError(f"data must be a torch.Tensor, got {type(data)}") + instance = data.as_subclass(cls) + instance._transport_mode = transport_mode + return instance + + def __getstate__(self) -> Dict[str, Any]: + state: Dict[str, Any] = { + "transport_mode": self._transport_mode, + "tensor_data": None, + "ipc_extra": None, + } + if self._transport_mode == "cuda_ipc" and self.is_cuda: + try: + storage = self.untyped_storage() + handle = storage._share_cuda_() + state["ipc_extra"] = { + "handle": handle, + "shape": self.shape, + "dtype": self.dtype, + "stride": self.stride(), + "device_index": self.device.index, + "storage_offset": self.storage_offset(), + } + except Exception as exc: + logger.warning( + "TransportProxyTensor: _share_cuda_() failed (%s); falling back.", + exc, + ) + state["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + else: + state["transport_mode"] = "default" + state["tensor_data"] = self.as_subclass(torch.Tensor) + return state + + def __setstate__(self, state: Dict[str, Any]) -> None: + self._transport_mode = state["transport_mode"] + if state["transport_mode"] == "cuda_ipc" and state["ipc_extra"] is not None: + h = state["ipc_extra"] + target = torch.device(f"cuda:{h['device_index']}") + try: + with torch.cuda.device(target): + storage = torch.UntypedStorage._new_shared_cuda(*h["handle"]) + reconstructed = torch.empty( + 0, dtype=h["dtype"], device=target + ).set_( + storage, + storage_offset=h["storage_offset"], + size=h["shape"], + stride=h["stride"], + ) + self.set_(reconstructed) + except Exception as exc: + logger.error("TransportProxyTensor: failed to open IPC handle: %s", exc) + raise + 
elif state["tensor_data"] is not None: + self.set_(state["tensor_data"]) + else: + raise RuntimeError("TransportProxyTensor: invalid state – no tensor data") + + @property + def transport_mode(self) -> TensorTransportMode: + return getattr(self, "_transport_mode", "default") + + +# --------------------------------------------------------------------------- +# Helpers: wrap / unwrap mm_inputs dicts +# --------------------------------------------------------------------------- + + +def wrap_mm_inputs_for_ipc( + mm_inputs: Optional[Dict[str, Any]], + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"] = None, +) -> Optional[Dict[str, Any]]: + """Recursively wrap CUDA tensors in *mm_inputs* for IPC transport. + + Args: + mm_inputs: Nested dict/list of tensors and other data. + transport_mode: One of ``"default"``, ``"cuda_ipc"``, ``"cuda_ipc_pool"``. + pool: Required when ``transport_mode == "cuda_ipc_pool"``. + + Returns: + A new data structure with CUDA tensors replaced by IPC proxies. + CPU tensors are left unchanged (they will be shared via ``share_memory_()`` + or normal pickling downstream). 
+ """ + if mm_inputs is None: + return None + return _wrap_recursive(mm_inputs, transport_mode, pool) + + +def _wrap_recursive( + data: Any, + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"], +) -> Any: + if isinstance(data, torch.Tensor) and data.is_cuda: + return _wrap_cuda_tensor(data, transport_mode, pool) + elif isinstance(data, dict): + return {k: _wrap_recursive(v, transport_mode, pool) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + wrapped = [_wrap_recursive(item, transport_mode, pool) for item in data] + return type(data)(wrapped) + else: + return data + + +def _wrap_cuda_tensor( + tensor: torch.Tensor, + transport_mode: TensorTransportMode, + pool: Optional["MmItemMemoryPool"], +) -> Any: + if transport_mode == "cuda_ipc": + return TransportProxyTensor(tensor, transport_mode="cuda_ipc") + + if transport_mode == "cuda_ipc_pool": + if pool is None: + raise ValueError("pool must be provided for transport_mode='cuda_ipc_pool'") + sync_meta, pool_slice = pool.get_slice_with_flag(tensor) + if pool_slice is not None: + # Copy tensor bytes into the pool slice + pool_slice.copy_(tensor.view(torch.int8).view(-1), non_blocking=True) + return CudaIpcTensorTransportProxy( + data=pool_slice, + info_data=tensor, + sync_buffer_meta=sync_meta, + ) + else: + # Pool full – fall back to simple IPC (with potential memory leak) + logger.warning( + "Pool full; falling back to simple CUDA IPC (potential memory leak)" + ) + return TransportProxyTensor(tensor, transport_mode="cuda_ipc") + + # "default" – move to CPU shared memory (handled by share_memory_() downstream) + return tensor + + +def unwrap_mm_inputs_from_ipc( + mm_inputs: Optional[Dict[str, Any]], + device_index: Optional[int] = None, +) -> Optional[Dict[str, Any]]: + """Recursively reconstruct tensors from IPC proxy objects. + + Call this in the *receiver* process after getting data from the queue. 
+ + Args: + mm_inputs: Data structure possibly containing IPC proxy objects. + device_index: If not None, move reconstructed tensors to this device. + """ + if mm_inputs is None: + return None + return _unwrap_recursive(mm_inputs, device_index) + + +def _unwrap_recursive(data: Any, device_index: Optional[int]) -> Any: + if isinstance(data, CudaIpcTensorTransportProxy): + return data.reconstruct_on_device(device_index) + elif isinstance(data, TransportProxyTensor): + # Already reconstructed during unpickling; just return as plain tensor + return data.as_subclass(torch.Tensor) + elif isinstance(data, dict): + return {k: _unwrap_recursive(v, device_index) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [_unwrap_recursive(item, device_index) for item in data] + return type(data)(result) + else: + return data diff --git a/pymllm/orchestrator/detokenizer_process.py b/pymllm/orchestrator/detokenizer_process.py new file mode 100644 index 000000000..1bbda98d0 --- /dev/null +++ b/pymllm/orchestrator/detokenizer_process.py @@ -0,0 +1,209 @@ +""" +DetokenizerProcess -- subprocess that converts token IDs back to text. + +Receives ``BatchTokenIDOut``-style dicts from the SchedulerProcess, +detokenizes them, and forwards the decoded strings to the +RequestResponseProcess. +""" + +import logging +from multiprocessing.connection import Connection +from typing import Any, Dict, List, Optional + +import zmq + +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging + +logger = logging.getLogger(__name__) + + +class DetokenizerProcess: + """Runs inside a subprocess. 
Detokenizes finished outputs.""" + + def __init__( + self, + recv_from_scheduler_addr: str, + send_to_rr_addr: str, + tokenizer_cfg: Optional[Dict[str, Any]] = None, + ): + self._recv_from_scheduler_addr = recv_from_scheduler_addr + self._send_to_rr_addr = send_to_rr_addr + self._tokenizer_cfg = tokenizer_cfg or {} + + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_scheduler: Optional[zmq.Socket] = None + self._send_to_rr: Optional[zmq.Socket] = None + + self._tokenizer = None + # Track previous decoded text per rid for incremental (delta) output + self._rid_to_prev_text: Dict[str, str] = {} + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + self._recv_from_scheduler = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_scheduler_addr, + bind=False, + ) + self._send_to_rr = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_rr_addr, + bind=False, + ) + + def init_tokenizer(self) -> None: + """Load the tokenizer from the configured path.""" + tokenizer_path = self._tokenizer_cfg.get("tokenizer_path") + if tokenizer_path is None: + logger.warning( + "No tokenizer_path in tokenizer_cfg; detokenization disabled" + ) + return + + from transformers import AutoTokenizer + + trust_remote_code = self._tokenizer_cfg.get("trust_remote_code", False) + self._tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + trust_remote_code=trust_remote_code, + ) + logger.info("Detokenizer loaded tokenizer from %s", tokenizer_path) + + def event_loop(self) -> None: + """Infinite loop: recv token IDs -> detokenize -> send text to RR.""" + logger.info("DetokenizerProcess event loop started") + while True: + token_id_out = self._recv_from_scheduler.recv_pyobj() + results = self._detokenize(token_id_out) + for result in results: + self._send_to_rr.send_pyobj(result) + + 
# ------------------------------------------------------------------ + # Detokenization + # ------------------------------------------------------------------ + + def _detokenize(self, token_id_out: Dict[str, Any]) -> List[Dict[str, Any]]: + """Convert token IDs to text and fan out one result per rid. + + The scheduler sends a batch dict with parallel lists keyed by + ``"rids"``, ``"output_ids"``, ``"finished_reasons"``, etc. + This method decodes each rid's output_ids and produces one result + dict per rid with keys ``"rid"`` (singular) and ``"finished"`` + (bool) as expected by ``RequestResponseProcess._recv_loop``. + """ + rids: List[str] = token_id_out.get("rids", []) + output_ids: List[int] = token_id_out.get("output_ids", []) + finished_reasons: List[Optional[str]] = token_id_out.get("finished_reasons", []) + + # NOTE: The scheduler currently sends one rid per message. The shared + # output_ids list is the complete output for that single rid. If + # batched sending is ever added, each rid will need its own output_ids. 
+ if len(rids) > 1: + logger.warning( + "Detokenizer received %d rids in one message; " + "output_ids are shared -- results may be incorrect", + len(rids), + ) + decode_ids: List[int] = token_id_out.get("decode_ids", []) + skip_special_tokens_list: List[bool] = token_id_out.get( + "skip_special_tokens", [] + ) + prompt_tokens_list: List[int] = token_id_out.get("prompt_tokens", []) + completion_tokens_list: List[int] = token_id_out.get("completion_tokens", []) + + results: List[Dict[str, Any]] = [] + + for i, rid in enumerate(rids): + finished_reason = finished_reasons[i] if i < len(finished_reasons) else None + is_finished = finished_reason is not None + skip_special = ( + skip_special_tokens_list[i] + if i < len(skip_special_tokens_list) + else True + ) + prompt_tokens = prompt_tokens_list[i] if i < len(prompt_tokens_list) else 0 + completion_tokens = ( + completion_tokens_list[i] if i < len(completion_tokens_list) else 0 + ) + + # Decode text from output_ids + if self._tokenizer is not None: + text = self._tokenizer.decode( + output_ids, + skip_special_tokens=skip_special, + ) + else: + text = "" + + # Compute incremental delta by diffing against previous text + prev_text = self._rid_to_prev_text.get(rid, "") + delta_text = text[len(prev_text):] + self._rid_to_prev_text[rid] = text + + # Clean up tracking when request finishes + if is_finished: + self._rid_to_prev_text.pop(rid, None) + + result: Dict[str, Any] = { + "rid": rid, + "text": text, + "delta": delta_text, + "output_token_ids": list(output_ids), + "finished": is_finished, + "finished_reason": finished_reason, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + results.append(result) + + return results + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._recv_from_scheduler is not None: + self._recv_from_scheduler.close() + if 
self._send_to_rr is not None: + self._send_to_rr.close() + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + +def run_detokenizer_process( + recv_from_scheduler_addr: str, + send_to_rr_addr: str, + pipe_writer: Connection, + tokenizer_cfg: Optional[Dict[str, Any]] = None, +) -> None: + """Entry point for ``torch.multiprocessing.Process(target=...)``.""" + setup_subprocess_logging((tokenizer_cfg or {}).get("log_level", "info")) + + # Limit CPU threads — detokenizer doesn't need PyTorch parallelism. + import torch + torch.set_num_threads(1) + + proc = DetokenizerProcess( + recv_from_scheduler_addr, + send_to_rr_addr, + tokenizer_cfg=tokenizer_cfg, + ) + proc.init_sockets() + proc.init_tokenizer() + + pipe_writer.send({"status": "ready", "process": "detokenizer"}) + pipe_writer.close() + + try: + proc.event_loop() + except KeyboardInterrupt: + pass + finally: + proc.shutdown() diff --git a/pymllm/orchestrator/group_coordinator.py b/pymllm/orchestrator/group_coordinator.py new file mode 100644 index 000000000..2fec30784 --- /dev/null +++ b/pymllm/orchestrator/group_coordinator.py @@ -0,0 +1,104 @@ +"""GroupCoordinator for distributed communication.""" + +from typing import List +import torch +import torch.distributed as dist + + +class GroupCoordinator: + """Manages a group of processes for distributed communication. + + Lightweight wrapper around torch.distributed.ProcessGroup. + + Args: + ranks: List of global ranks in this group + local_rank: Local rank for device assignment + backend: Backend to use (nccl, gloo, etc.) 
+ """ + + def __init__( + self, + ranks: List[int], + local_rank: int, + backend: str = "nccl", + ): + self.ranks = ranks + self.local_rank = local_rank + self.backend = backend + self.world_size = len(ranks) + + # Get rank in this specific group + self.rank_in_group = ranks.index(dist.get_rank()) if dist.is_initialized() else 0 + + # Create process group + if dist.is_initialized() and self.world_size > 1: + self.device_group = dist.new_group(ranks, backend=backend) + else: + self.device_group = None + + def all_reduce(self, tensor: torch.Tensor) -> torch.Tensor: + """All-reduce across the group.""" + if self.device_group is not None: + dist.all_reduce(tensor, group=self.device_group) + return tensor + + def all_gather(self, tensor: torch.Tensor, dim: int = 0) -> torch.Tensor: + """All-gather across the group.""" + if self.device_group is None: + return tensor + + world_size = self.world_size + if dim == 0: + shape = list(tensor.shape) + shape[0] = shape[0] * world_size + output = torch.empty(shape, dtype=tensor.dtype, device=tensor.device) + dist.all_gather_into_tensor(output, tensor, group=self.device_group) + return output + else: + # For non-dim-0 gathers, use tensor list + tensor_list = [ + torch.empty_like(tensor) for _ in range(world_size) + ] + dist.all_gather(tensor_list, tensor, group=self.device_group) + return torch.cat(tensor_list, dim=dim) + + def broadcast(self, tensor: torch.Tensor, src: int = 0) -> torch.Tensor: + """Broadcast from source rank to all. + + Args: + tensor: Tensor to broadcast. + src: Source rank relative to this group (0 <= src < world_size). 
+ """ + if self.device_group is not None: + global_src = self.ranks[src] + dist.broadcast(tensor, src=global_src, group=self.device_group) + return tensor + + +def divide(numerator: int, denominator: int) -> int: + """Divide and ensure divisibility.""" + assert numerator % denominator == 0, ( + f"{numerator} is not divisible by {denominator}" + ) + return numerator // denominator + + +def split_tensor_along_dim( + tensor: torch.Tensor, + dim: int, + world_size: int, + rank: int, +) -> torch.Tensor: + """Split tensor along a dimension for tensor parallelism.""" + dim_size = tensor.size(dim) + assert dim_size % world_size == 0, ( + f"Dimension {dim} ({dim_size}) not divisible by world_size {world_size}" + ) + + chunk_size = dim_size // world_size + start = rank * chunk_size + end = start + chunk_size + + slices = [slice(None)] * tensor.ndim + slices[dim] = slice(start, end) + return tensor[tuple(slices)] diff --git a/pymllm/orchestrator/ipc_utils.py b/pymllm/orchestrator/ipc_utils.py new file mode 100644 index 000000000..abb59849a --- /dev/null +++ b/pymllm/orchestrator/ipc_utils.py @@ -0,0 +1,105 @@ +"""ZMQ IPC utilities for inter-process communication. + +Provides helpers to generate unique IPC addresses and create pre-configured +ZMQ sockets so that every process uses the same conventions. +""" + +import logging +import os +import tempfile +from typing import Optional + +import zmq + + +_IPC_DIR = os.path.join(tempfile.gettempdir(), "pymllm_ipc") + + +def _ensure_ipc_dir() -> None: + os.makedirs(_IPC_DIR, exist_ok=True) + + +def make_ipc_address(name: str, unique_id: Optional[str] = None) -> str: + """Return an ``ipc://`` address for *name*, optionally scoped by *unique_id*. + + Parameters + ---------- + name + Logical channel name, e.g. ``"rr_to_tokenizer"``. + unique_id + Per-engine identifier (typically ``str(os.getpid())``) to avoid + collisions when multiple engines run on the same host. 
+ """ + _ensure_ipc_dir() + suffix = f"_{unique_id}" if unique_id else "" + return f"ipc://{_IPC_DIR}/pymllm_{name}{suffix}" + + +def create_zmq_socket( + ctx: zmq.Context, + socket_type: int, + address: str, + bind: bool, +) -> zmq.Socket: + """Create a ZMQ socket, bind or connect it, and return it. + + Parameters + ---------- + ctx + A ``zmq.Context`` shared within the process. + socket_type + One of ``zmq.PUSH``, ``zmq.PULL``, ``zmq.PAIR``, etc. + address + The ``ipc://`` address string. + bind + If ``True`` the socket calls ``bind``; otherwise ``connect``. + """ + sock = ctx.socket(socket_type) + sock.setsockopt(zmq.LINGER, 0) + if bind: + sock.bind(address) + else: + sock.connect(address) + return sock + + +def close_zmq_socket(sock: zmq.Socket) -> None: + """Close a ZMQ socket, ignoring errors.""" + try: + sock.close() + except zmq.ZMQError: + pass + + +def cleanup_ipc_files(unique_id: Optional[str] = None) -> None: + """Remove IPC socket files for the given engine (or all if no id given).""" + import glob as _glob + + suffix = f"_{unique_id}" if unique_id else "" + pattern = os.path.join(_IPC_DIR, f"pymllm_*{suffix}") + for f in _glob.glob(pattern): + try: + os.unlink(f) + except OSError: + pass + + +def setup_subprocess_logging(log_level: str = "info") -> None: + """Configure logging for a spawned subprocess. + + When Python spawns a subprocess (``mp.set_start_method('spawn')``), the + child starts with a blank logging configuration. Call this function at the + very beginning of every subprocess entry point so that log records are + emitted at the correct level. + + Parameters + ---------- + log_level + Case-insensitive level name, e.g. ``"debug"``, ``"info"``, ``"warning"``. 
+ """ + level = getattr(logging, log_level.upper(), logging.INFO) + logging.basicConfig( + level=level, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + ) + logging.getLogger("pymllm").setLevel(level) diff --git a/pymllm/orchestrator/model_runner_process.py b/pymllm/orchestrator/model_runner_process.py new file mode 100644 index 000000000..a514ac2e9 --- /dev/null +++ b/pymllm/orchestrator/model_runner_process.py @@ -0,0 +1,1142 @@ +""" +ModelRunnerProcess -- GPU-owning component that executes model forward passes. + +Instantiated **in-process** by :class:`SchedulerProcess` +The scheduler calls :meth:`_forward_batch` directly — +no inter-process communication is involved. + +This component owns the GPU: it holds a :class:`ModelRunner` with model +weights, KV-cache memory pools, and the attention backend. It also owns +the :class:`RadixCache` for prefix-aware KV reuse. + +RadixCache lifecycle +-------------------- +1. **match_prefix** — called during ``_allocate_extend`` before KV allocation. +2. **inc_lock_ref** — locks matched radix-tree nodes to prevent eviction. +3. **insert (prefill)** — inserts prompt KV indices after prefill. +4. **insert (completion)** — re-inserts the full sequence when a request finishes. +5. **dec_lock_ref** — unlocks radix-tree nodes when a request is freed. +6. **evict** — called when KV allocation fails to free stale cache entries. +""" + +import logging +from typing import Any, Dict, List, Optional, Tuple + +import torch + +from pymllm.mem_cache.base_prefix_cache import BasePrefixCache, RadixKey +from pymllm.mem_cache.chunk_cache import ChunkCache +from pymllm.mem_cache.mamba_radix_cache import MambaRadixCache +from pymllm.mem_cache.radix_cache import RadixCache + +logger = logging.getLogger(__name__) + +# Fraction of KV pool to try evicting when allocation fails. +_EVICT_FRACTION = 0.10 +# Maximum number of eviction retries before giving up. 
_MAX_EVICT_RETRIES = 3


class ModelRunnerProcess:
    """GPU-owning component created in-process by SchedulerProcess."""

    def __init__(
        self,
        gpu_id: int = 0,
        server_config: Optional[Any] = None,
        model_config: Optional[Any] = None,
    ):
        # Config objects are opaque here; attributes are read via getattr
        # with defaults in _create_prefix_cache.
        self._gpu_id = gpu_id
        self._server_config = server_config
        self._model_config = model_config

        # The ModelRunner instance (created in init_model)
        self._runner = None
        # True when the model has GDN (linear-attention) layers; set in
        # _create_prefix_cache from runner.num_gdn_layers.
        self._is_hybrid: bool = False

        # RadixCache instance (created in init_model, after memory pools)
        self._radix_cache: Optional[RadixCache] = None

        # GPU resource tracking: maps rid -> req_pool_idx (slot in ReqToTokenPool)
        self._rid_to_req_pool_idx: Dict[str, int] = {}
        # Maps rid -> kv_indices tensor (all KV-cache token indices for this request)
        self._rid_to_kv_indices: Dict[str, torch.Tensor] = {}
        # Maps rid -> input_ids used for prefill (needed for radix cache insert)
        self._rid_to_input_ids: Dict[str, List[int]] = {}
        # Maps rid -> list of generated (decode) token ids, appended each step.
        # Used to build the full sequence for radix cache insert at completion.
        self._rid_to_output_ids: Dict[str, List[int]] = {}
        # Maps rid -> cache_protected_len: the length of the prefix that has
        # already been inserted into the radix cache. When insert() returns
        # prefix_len > cache_protected_len, the KV indices in the overlap
        # range [cache_protected_len, prefix_len) are duplicates that must
        # be freed from the allocator (the tree already holds cloned copies).
        self._rid_to_cache_protected_len: Dict[str, int] = {}
        # Maps rid -> (last_node, lock_token) for radix cache lock tracking.
        # last_node type depends on the cache implementation (TreeNode, MambaTreeNode, etc.)
        self._rid_to_radix_lock: Dict[str, Tuple[Any, Optional[Any]]] = {}
        # Maps rid -> mrope_position_delta (M-RoPE positional offset per request)
        # Populated during prefill; used to offset decode-step positions for
        # multimodal models (Qwen3-VL) that consume more position indices than
        # tokens due to 3-D image grid positions.
        self._rid_to_mrope_delta: Dict[str, int] = {}

        # GDN prefix cache state tracking (hybrid models only):
        # Maps rid -> GDN track slot index in GDNPool (for snapshotting state)
        self._rid_to_gdn_track_slot: Dict[str, int] = {}
        # Maps radix tree node id -> GDN track slot index
        self._node_id_to_gdn_track_slot: Dict[int, int] = {}

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def init_model(self) -> None:
        """Create and initialise the ModelRunner and RadixCache.

        Must run inside the subprocess (after spawn) since it does CUDA init.
        """
        # Imported lazily so importing this module does not pull in CUDA.
        from pymllm.executor.model_runner import ModelRunner

        logger.info(
            "ModelRunnerProcess: initialising ModelRunner on GPU %d",
            self._gpu_id,
        )
        self._runner = ModelRunner(
            server_config=self._server_config,
            model_config=self._model_config,
            gpu_id=self._gpu_id,
        )
        self._runner.initialize()

        # Initialise prefix cache after memory pools are ready.
        self._radix_cache = self._create_prefix_cache()
        logger.info("ModelRunnerProcess: ModelRunner ready")

    def _create_prefix_cache(self) -> BasePrefixCache:
        """Factory: create the appropriate prefix cache based on config.

        Selection order: ChunkCache when radix caching is disabled,
        MambaRadixCache when explicitly enabled, otherwise RadixCache
        (with an eviction callback only for hybrid/GDN models).
        """
        disable_cache = getattr(self._server_config, "disable_radix_cache", False)
        self._is_hybrid = self._runner.num_gdn_layers > 0
        enable_mamba_cache = getattr(self._server_config, "enable_mamba_cache", False)
        sliding_window = self._runner.sliding_window_size
        page_size = getattr(self._server_config, "radix_cache_page_size", 1)
        allocator = self._runner.token_to_kv_pool_allocator

        if disable_cache:
            device = allocator.device if allocator is not None else torch.device("cpu")
            logger.info("ModelRunnerProcess: using ChunkCache (radix cache disabled)")
            return ChunkCache(
                token_to_kv_pool_allocator=allocator,
                device=device,
            )

        if enable_mamba_cache:
            mamba_pool = getattr(self._runner, "gdn_pool", None)
            logger.info(
                "ModelRunnerProcess: using MambaRadixCache "
                "(mamba_pool=%s, page_size=%d)",
                "available" if mamba_pool is not None else "none",
                page_size,
            )
            evict_cb = self._on_radix_node_evict if self._is_hybrid else None
            return MambaRadixCache(
                page_size=page_size,
                token_to_kv_pool_allocator=allocator,
                mamba_pool=mamba_pool,
                on_node_evict=evict_cb,
            )

        # Standard RadixCache (with optional SWA)
        if self._is_hybrid:
            logger.info(
                "ModelRunnerProcess: prefix caching ENABLED with GDN state "
                "tracking (%d GDN layers)",
                self._runner.num_gdn_layers,
            )
        evict_cb = self._on_radix_node_evict if self._is_hybrid else None
        logger.info(
            "ModelRunnerProcess: using RadixCache "
            "(sliding_window=%s, page_size=%d)",
            sliding_window,
            page_size,
        )
        return RadixCache(
            page_size=page_size,
            sliding_window_size=sliding_window,
            token_to_kv_pool_allocator=allocator,
            on_node_evict=evict_cb,
        )

    # ------------------------------------------------------------------
    # Forward pass
    #
    # ------------------------------------------------------------------

    def _forward_batch(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Run the model forward pass and sampling for *batch*.

        *batch* is a dict produced by ``ScheduleBatch.to_batch_dict()``
        containing ``"forward_mode"``, ``"input_ids"``, ``"seq_lens"``,
        ``"req_pool_indices"``, ``"requests"`` (metadata list), etc.

        Implements 6 phases:
        1. Cleanup: free GPU resources for rids no longer in the batch
        2. Prefix matching + KV allocation
        3. Build GPU tensors
        4. Forward + sample
        5. Radix cache insert (extend only)
        6. Build result dict
        """
        runner = self._runner
        forward_mode = batch.get("forward_mode", "decode")
        batch_size = batch.get("batch_size", 0)
        requests_meta: List[Dict[str, Any]] = batch.get("requests", [])

        if batch_size == 0:
            # Nothing to run; return an empty result for this batch id.
            return {"batch_id": batch.get("batch_id"), "outputs": []}

        device = runner.device

        # Collect current batch rids
        # NOTE(review): `current_rids` is not referenced anywhere else in
        # this method — presumably leftover from the "Phase 1 cleanup"
        # described in the docstring; confirm and remove if truly unused.
        current_rids: Set[str] = {m["rid"] for m in requests_meta}

        # ==============================================================
        # Phase 2: Prefix matching + KV allocation
        # ==============================================================
        # For extend batches, match_prefix is done inside _allocate_extend
        # which may update extend_prefix_lens and extend_seq_lens.
        if forward_mode == "extend":
            out_cache_loc, actual_prefix_lens, actual_extend_lens = (
                self._allocate_extend(batch, requests_meta)
            )
        else:
            out_cache_loc = self._allocate_decode(batch, requests_meta)
            actual_prefix_lens = None
            actual_extend_lens = None

        # ==============================================================
        # Phase 3: Build GPU tensors
        # ==============================================================
        if forward_mode == "extend" and actual_prefix_lens is not None:
            # Rebuild input_ids and seq_lens using actual prefix matches.
            # The scheduler sent tokens assuming prefix_len=0; we need to
            # trim the input_ids to skip the prefix-matched tokens.
            (
                input_ids_tensor,
                seq_lens_tensor,
                extend_seq_lens_t,
                extend_prefix_lens_t,
            ) = self._rebuild_extend_tensors(
                batch, requests_meta, actual_prefix_lens, actual_extend_lens, device
            )
        else:
            input_ids_list: List[int] = batch["input_ids"]
            seq_lens_list: List[int] = batch["seq_lens"]
            input_ids_tensor = torch.tensor(
                input_ids_list, dtype=torch.int32, device=device
            )
            seq_lens_tensor = torch.tensor(
                seq_lens_list, dtype=torch.int32, device=device
            )
            extend_seq_lens_t = None
            extend_prefix_lens_t = None

        # Build req_pool_indices from our own tracking (NOT from scheduler)
        req_pool_indices = torch.tensor(
            [self._rid_to_req_pool_idx[m["rid"]] for m in requests_meta],
            dtype=torch.int64,
            device=device,
        )

        out_cache_loc = out_cache_loc.to(torch.int64)

        # ==============================================================
        # Phase 4: Forward + sample
        # ==============================================================
        # Extract per-request sampling params
        temperatures = []
        top_ps = []
        top_ks = []
        repetition_penalties = []
        frequency_penalties = []
        presence_penalties = []
        for m in requests_meta:
            sp = m.get("sampling_params") or {}
            temperatures.append(sp.get("temperature", 1.0))
            top_ps.append(sp.get("top_p", 1.0))
            top_ks.append(sp.get("top_k", -1))
            repetition_penalties.append(sp.get("repetition_penalty", 1.0))
            frequency_penalties.append(sp.get("frequency_penalty", 0.0))
            presence_penalties.append(sp.get("presence_penalty", 0.0))

        temps_tensor = torch.tensor(temperatures, dtype=torch.float32, device=device)
        top_ps_tensor = torch.tensor(top_ps, dtype=torch.float32, device=device)
        top_ks_tensor = torch.tensor(top_ks, dtype=torch.int32, device=device)

        # Collect token histories for penalty computation.
        # Each entry is (input_ids + output_ids_so_far) for the request.
        # Only built when at least one request actually uses a penalty,
        # since histories can be long.
        has_penalties = (
            any(p != 1.0 for p in repetition_penalties)
            or any(p != 0.0 for p in frequency_penalties)
            or any(p != 0.0 for p in presence_penalties)
        )
        penalty_params = None
        if has_penalties:
            token_histories = []
            for m in requests_meta:
                rid = m["rid"]
                input_ids = self._rid_to_input_ids.get(rid, [])
                output_ids = self._rid_to_output_ids.get(rid, [])
                token_histories.append(list(input_ids) + list(output_ids))
            penalty_params = {
                "repetition_penalties": torch.tensor(
                    repetition_penalties, dtype=torch.float32, device=device
                ),
                "frequency_penalties": torch.tensor(
                    frequency_penalties, dtype=torch.float32, device=device
                ),
                "presence_penalties": torch.tensor(
                    presence_penalties, dtype=torch.float32, device=device
                ),
                "token_histories": token_histories,
            }

        if forward_mode == "extend":
            if extend_seq_lens_t is None:
                extend_seq_lens_list: List[int] = batch["extend_seq_lens"]
                extend_prefix_lens_list: List[int] = batch["extend_prefix_lens"]
                extend_seq_lens_t = torch.tensor(
                    extend_seq_lens_list, dtype=torch.int32, device=device
                )
                extend_prefix_lens_t = torch.tensor(
                    extend_prefix_lens_list, dtype=torch.int32, device=device
                )

            fb = runner.prepare_forward_batch_extend(
                input_ids=input_ids_tensor,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens_tensor,
                extend_seq_lens=extend_seq_lens_t,
                extend_prefix_lens=extend_prefix_lens_t,
                out_cache_loc=out_cache_loc,
            )

            # Attach multimodal vision inputs to ForwardBatch so the
            # model's vision encoder can process images during prefill.
            # The tokenizer wraps processor output under "image_inputs";
            # fall back to top-level keys for direct dicts.
            pixel_values_list = []
            image_grid_thw_list = []
            for m in requests_meta:
                mm = m.get("mm_inputs")
                if mm is None:
                    continue
                # AutoProcessor output is nested under "image_inputs"
                src = mm.get("image_inputs") if "image_inputs" in mm else mm
                if src is None:
                    continue
                # `src` may be a dict-like or an object with attributes.
                pv = (
                    src.get("pixel_values")
                    if hasattr(src, "get")
                    else getattr(src, "pixel_values", None)
                )
                thw = (
                    src.get("image_grid_thw")
                    if hasattr(src, "get")
                    else getattr(src, "image_grid_thw", None)
                )
                if pv is not None:
                    if not isinstance(pv, torch.Tensor):
                        pv = torch.as_tensor(pv)
                    pixel_values_list.append(pv.to(device=device))
                if thw is not None:
                    if not isinstance(thw, torch.Tensor):
                        thw = torch.as_tensor(thw)
                    image_grid_thw_list.append(thw.to(device=device))
            if pixel_values_list:
                fb.pixel_values = torch.cat(pixel_values_list, dim=0)
            if image_grid_thw_list:
                fb.image_grid_thw = torch.cat(image_grid_thw_list, dim=0)
        else:
            # Build mrope_position_deltas tensor for decode batches.
            mrope_deltas = [
                self._rid_to_mrope_delta.get(m["rid"], 0) for m in requests_meta
            ]
            mrope_deltas_tensor = torch.tensor(
                mrope_deltas, dtype=torch.int64, device=device
            )

            fb = runner.prepare_forward_batch_decode(
                input_ids=input_ids_tensor,
                req_pool_indices=req_pool_indices,
                seq_lens=seq_lens_tensor,
                out_cache_loc=out_cache_loc,
                mrope_position_deltas=mrope_deltas_tensor,
            )

        logits_output = runner.forward(fb)

        # Persist M-RoPE position deltas for multimodal models (Qwen3-VL).
        # The model sets mrope_position_deltas on the ForwardBatch during
        # prefill; we store them here so decode steps can retrieve them.
        if (
            forward_mode == "extend"
            and getattr(fb, "mrope_position_deltas", None) is not None
        ):
            deltas_cpu = fb.mrope_position_deltas.cpu().tolist()
            for idx, m in enumerate(requests_meta):
                self._rid_to_mrope_delta[m["rid"]] = int(deltas_cpu[idx])

        next_token_ids = runner.sample(
            logits_output,
            fb,
            temperatures=temps_tensor,
            top_ps=top_ps_tensor,
            top_ks=top_ks_tensor,
            penalty_params=penalty_params,
        )

        # ==============================================================
        # Phase 4.5: Snapshot GDN state after extend (hybrid models)
        # ==============================================================
        if forward_mode == "extend" and self._is_hybrid:
            self._track_gdn_state_after_extend(requests_meta)

        # ==============================================================
        # Phase 5: Radix cache insert (extend only)
        # ==============================================================
        if forward_mode == "extend" and self._radix_cache is not None:
            self._insert_into_radix_cache(requests_meta)

        # ==============================================================
        # Phase 6: Build result & track output tokens
        # ==============================================================
        next_ids_cpu = next_token_ids.cpu().tolist()
        outputs: List[Dict[str, Any]] = []
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            token_id = next_ids_cpu[i] if i < len(next_ids_cpu) else 0
            # Track output tokens for radix cache insert at completion
            out_ids = self._rid_to_output_ids.get(rid)
            if out_ids is not None:
                out_ids.append(token_id)

            out: Dict[str, Any] = {
                "rid": rid,
                "output_token_ids": [token_id],
            }
            # Report actual prefix_len back to the scheduler so it can
            # update its token budget tracking accurately.
            if actual_prefix_lens is not None:
                out["prefix_len"] = actual_prefix_lens[i]
            outputs.append(out)

        return {
            "batch_id": batch.get("batch_id"),
            "outputs": outputs,
        }

    # ------------------------------------------------------------------
    # Tensor rebuild for prefix-matched extend
    # ------------------------------------------------------------------

    def _rebuild_extend_tensors(
        self,
        batch: Dict[str, Any],
        requests_meta: List[Dict[str, Any]],
        actual_prefix_lens: List[int],
        actual_extend_lens: List[int],
        device: str,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Rebuild input_ids and related tensors after prefix matching.

        The scheduler sent input_ids assuming no prefix cache hit. After
        radix cache matching, we know the actual prefix lengths and must
        trim the input_ids accordingly.

        Returns (input_ids, seq_lens, extend_seq_lens, extend_prefix_lens)
        as GPU tensors.
        """
        # Reconstruct trimmed input_ids: for each request, take only the
        # tokens beyond the matched prefix.
        new_input_ids: List[int] = []
        seq_lens_list: List[int] = batch["seq_lens"]

        for i, m in enumerate(requests_meta):
            full_input_ids = m.get("input_ids", [])
            prefix_len = actual_prefix_lens[i]
            # Only send tokens after the prefix
            new_input_ids.extend(full_input_ids[prefix_len:])

        input_ids = torch.tensor(new_input_ids, dtype=torch.int32, device=device)
        seq_lens = torch.tensor(seq_lens_list, dtype=torch.int32, device=device)
        extend_seq_lens = torch.tensor(
            actual_extend_lens, dtype=torch.int32, device=device
        )
        extend_prefix_lens = torch.tensor(
            actual_prefix_lens, dtype=torch.int32, device=device
        )
        return input_ids, seq_lens, extend_seq_lens, extend_prefix_lens

    # ------------------------------------------------------------------
    # Radix cache insert
    # ------------------------------------------------------------------

    def _insert_into_radix_cache(self, requests_meta: List[Dict[str, Any]]) -> None:
        """Insert prefill KV indices into the radix cache for future reuse.

        1. **Insert** the request's token → KV index mapping into the tree.
        2. **Free duplicates** — indices in ``[cache_protected_len, new_prefix_len)``
           are now owned by the tree; the request's copies are redundant.
        3. **Re-match + write-back** — fetch the tree's *own* indices via
           ``match_prefix`` and write them into ``req_to_token_pool``,
           replacing the just-freed entries. Without this step the pool
           still points at freed slots → use-after-free during decode.
        4. **Update** ``cache_protected_len`` and radix lock.
        """
        # Gate the verbose per-request logging on DEBUG once, up front.
        _dbg = logger.isEnabledFor(logging.DEBUG)
        cache = self._radix_cache
        if cache is None:
            return

        runner = self._runner
        gdn_pool = getattr(runner, "gdn_pool", None)

        for m in requests_meta:
            rid = m["rid"]
            input_ids = self._rid_to_input_ids.get(rid)
            if input_ids is None:
                continue

            slot = self._rid_to_req_pool_idx.get(rid)
            if slot is None:
                continue

            # Read the request's current KV indices from the pool; only the
            # prompt portion (seq_len tokens) is inserted here.
            seq_len = len(input_ids)
            kv_indices = runner.req_to_token_pool.req_to_token[slot, :seq_len].to(
                torch.int64
            )

            if _dbg:
                logger.debug(
                    "[CACHE INSERT] rid=%s seq_len=%d pool[slot=%d,0:%d]=%s",
                    rid,
                    seq_len,
                    slot,
                    min(seq_len, 8),
                    kv_indices[: min(seq_len, 8)].tolist(),
                )

            key = RadixKey(input_ids)
            result = cache.insert(key, kv_indices)
            new_prefix_len = result.prefix_len

            # --- Step 2: free duplicates ---
            cache_protected_len = self._rid_to_cache_protected_len.get(rid, 0)
            if _dbg:
                logger.debug(
                    "[CACHE INSERT] rid=%s insert prefix_len=%d cache_protected=%d",
                    rid,
                    new_prefix_len,
                    cache_protected_len,
                )
            if new_prefix_len > cache_protected_len:
                dup_indices = kv_indices[cache_protected_len:new_prefix_len]
                if _dbg:
                    logger.debug(
                        "[CACHE INSERT] rid=%s freeing dup [%d:%d]=%s",
                        rid,
                        cache_protected_len,
                        new_prefix_len,
                        dup_indices[: min(len(dup_indices), 8)].tolist(),
                    )
                if dup_indices.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(dup_indices)

            # --- Step 3: re-match + write-back ---
            # The tree now owns indices for [0, new_prefix_len). Fetch them
            # and patch req_to_token_pool so the request reads the tree's
            # (still-live) indices instead of the freed ones.
            rematch = cache.match_prefix(key)
            new_indices = rematch.indices
            if _dbg:
                logger.debug(
                    "[CACHE INSERT] rid=%s rematch len=%d indices[:8]=%s",
                    rid,
                    len(new_indices),
                    new_indices[: min(len(new_indices), 8)].tolist(),
                )
            if cache.page_size == 1:
                # With page_size==1 a full insert must be fully re-matchable.
                assert len(new_indices) == seq_len, (
                    f"Re-match length mismatch after insert: "
                    f"{len(new_indices)=}, {seq_len=}, rid={rid}"
                )
            if len(new_indices) > cache_protected_len:
                if _dbg:
                    logger.debug(
                        "[CACHE INSERT] rid=%s write-back pool[slot=%d,%d:%d]=%s",
                        rid,
                        slot,
                        cache_protected_len,
                        len(new_indices),
                        new_indices[
                            cache_protected_len : cache_protected_len + 8
                        ].tolist(),
                    )
                runner.req_to_token_pool.write(
                    (slot, slice(cache_protected_len, len(new_indices))),
                    new_indices[cache_protected_len:].to(torch.int32),
                )

            # --- Step 4: update tracking ---
            self._rid_to_cache_protected_len[rid] = len(new_indices)

            # Update radix lock to cover the new (potentially deeper) node.
            old_lock = self._rid_to_radix_lock.pop(rid, None)
            if old_lock is not None:
                old_node, old_swa = old_lock
                cache.dec_lock_ref(old_node, old_swa)
            new_last_node = rematch.last_node
            if new_last_node is not None and len(new_indices) > 0:
                swa_id = cache.inc_lock_ref(new_last_node)
                self._rid_to_radix_lock[rid] = (new_last_node, swa_id)

            # --- GDN track slot association (hybrid models) ---
            # Bind this request's GDN snapshot slot to the inserted node so a
            # future cache hit on that node can restore the GDN state. If the
            # node already has a slot, ours is redundant and is freed.
            if gdn_pool is not None and result.last_node is not None:
                track_slot = self._rid_to_gdn_track_slot.get(rid)
                if track_slot is not None:
                    node_id = result.last_node.id
                    old_ts = self._node_id_to_gdn_track_slot.get(node_id)
                    if old_ts is None:
                        self._node_id_to_gdn_track_slot[node_id] = track_slot
                    else:
                        gdn_pool.free_track_slot(track_slot)
                    self._rid_to_gdn_track_slot.pop(rid, None)

    # ------------------------------------------------------------------
    # KV allocation helpers
    # ------------------------------------------------------------------

    def _allocate_extend(
        self, batch:
Dict[str, Any], requests_meta: List[Dict[str, Any]]
    ) -> Tuple[torch.Tensor, List[int], List[int]]:
        """Allocate req pool slots and KV tokens for an extend (prefill) batch.

        Performs radix cache prefix matching before allocation:
        1. For each request, call ``match_prefix`` to find cached KV indices.
        2. Write cached indices into ``ReqToTokenPool``.
        3. Only allocate new KV tokens for the non-cached suffix.
        4. Lock matched radix nodes to prevent eviction.

        Returns ``(out_cache_loc, actual_prefix_lens, actual_extend_lens)``.
        ``out_cache_loc`` has shape ``[total_new_tokens]``.
        """
        runner = self._runner
        cache = self._radix_cache
        batch_size = batch["batch_size"]
        seq_lens: List[int] = batch["seq_lens"]

        # --- Step 1: Radix cache prefix matching ---
        actual_prefix_lens: List[int] = []
        actual_extend_lens: List[int] = []
        matched_nodes: List[Optional[Any]] = []
        # Cache the match results so we don't call match_prefix twice
        cached_indices_list: List[Optional[torch.Tensor]] = []
        gdn_pool = getattr(runner, "gdn_pool", None)

        for i, m in enumerate(requests_meta):
            full_input_ids: List[int] = m.get("input_ids", [])
            full_seq_len = seq_lens[i]

            # Store input_ids for later radix cache insert
            self._rid_to_input_ids[m["rid"]] = full_input_ids

            if cache is not None and len(full_input_ids) > 0:
                key = RadixKey(full_input_ids)
                match_result = cache.match_prefix(key)
                prefix_len = match_result.prefix_len
                last_node = match_result.last_node
                cached_indices = match_result.indices
            else:
                prefix_len = 0
                last_node = None
                cached_indices = None

            # Hybrid model guard: only use a KV cache hit if the matched
            # node has a GDN state snapshot. Without it, the full-attention
            # layers would use cached KV while GDN layers start from zero,
            # causing an attention/GDN state mismatch. Discard the hit so
            # the entire prompt is processed from scratch.
            if (
                gdn_pool is not None
                and prefix_len > 0
                and last_node is not None
                and self._node_id_to_gdn_track_slot.get(last_node.id) is None
            ):
                logger.debug(
                    "Discarding radix cache hit for rid=%s: no GDN state "
                    "for matched node (prefix_len=%d)",
                    m["rid"],
                    prefix_len,
                )
                prefix_len = 0
                last_node = None
                cached_indices = None

            # Ensure at least 2 tokens are extended (not nearly fully cached).
            # Reasons:
            # 1. A full cache hit (prefix_len == full_seq_len) would produce a
            #    0-length input tensor that crashes CUDA kernels.
            # 2. A 1-token extend triggers an edge case in FlashInfer's
            #    ragged forward_return_lse (qo_len=1, kv_len=1, causal=True)
            #    where s1 (log-partition) is computed incorrectly, causing
            #    the cascade merge to produce wrong logits → EOS.
            # By ensuring extend_len >= 2, we avoid both issues.
            if prefix_len >= full_seq_len - 1 and full_seq_len >= 2:
                prefix_len = full_seq_len - 2
                if cached_indices is not None:
                    cached_indices = cached_indices[:prefix_len]

            extend_len = full_seq_len - prefix_len
            actual_prefix_lens.append(prefix_len)
            actual_extend_lens.append(extend_len)
            matched_nodes.append(last_node)
            cached_indices_list.append(cached_indices)

            if prefix_len > 0:
                logger.info(
                    "Radix cache hit for rid=%s: %d/%d tokens reused (%.1f%%) "
                    "node_id=%s cached_kv[:8]=%s",
                    m["rid"],
                    prefix_len,
                    full_seq_len,
                    100.0 * prefix_len / full_seq_len,
                    last_node.id if last_node is not None else None,
                    cached_indices[: min(prefix_len, 8)].tolist()
                    if cached_indices is not None
                    else [],
                )
                logger.info(
                    "Radix cache tree after match: evictable=%d protected=%d",
                    cache.evictable_size(),
                    cache.protected_size(),
                )

        total_new_tokens = sum(actual_extend_lens)

        # --- Step 1.5: Lock matched radix nodes BEFORE allocation ---
        # This MUST happen before any allocation that could trigger eviction.
        # Without locking first, _alloc_kv_with_eviction could evict the
        # matched nodes, freeing their KV pool slots and causing
        # use-after-free when we later read from cached_indices.
        if cache is not None:
            for i, m in enumerate(requests_meta):
                node = matched_nodes[i]
                if node is not None and actual_prefix_lens[i] > 0:
                    swa_boundary_id = cache.inc_lock_ref(node)
                    self._rid_to_radix_lock[m["rid"]] = (node, swa_boundary_id)

        # --- Step 2: Allocate req pool slots ---
        slots = runner.req_to_token_pool.alloc(batch_size)
        if slots is None:
            # Rollback locks on failure
            self._unlock_matched_nodes(requests_meta)
            raise RuntimeError("Failed to allocate req pool slots for extend batch")

        # --- Step 3: Allocate KV tokens (with eviction retry) ---
        out_cache_loc = self._alloc_kv_with_eviction(total_new_tokens)
        if out_cache_loc is None:
            # Roll back the req-pool slots allocated in Step 2 as well.
            for s in slots:
                runner.req_to_token_pool.free(s)
            # Rollback locks on failure
            self._unlock_matched_nodes(requests_meta)
            raise RuntimeError(
                f"Failed to allocate {total_new_tokens} KV tokens for extend batch "
                f"(even after eviction)"
            )

        # --- Step 4: Write indices into req_to_token_pool ---
        offset = 0
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            slot = slots[i]
            prefix_len = actual_prefix_lens[i]
            extend_len = actual_extend_lens[i]
            full_seq_len = seq_lens[i]

            # Write cached prefix indices (from the match result we saved)
            cached_indices = cached_indices_list[i]
            if cached_indices is not None and prefix_len > 0:
                logger.debug(
                    "[ALLOC EXTEND] rid=%s writing prefix[0:%d] to pool[slot=%d]: %s",
                    rid,
                    prefix_len,
                    slot,
                    cached_indices[: min(prefix_len, 8)].tolist(),
                )
                runner.req_to_token_pool.write(
                    (slot, slice(0, prefix_len)),
                    cached_indices[:prefix_len].to(torch.int32),
                )

            # Write new KV indices for the suffix
            kv_indices = out_cache_loc[offset : offset + extend_len]
            runner.req_to_token_pool.write(
                (slot, slice(prefix_len, full_seq_len)), kv_indices
            )

            self._rid_to_req_pool_idx[rid] = slot
            self._rid_to_kv_indices[rid] = kv_indices.clone()
            self._rid_to_output_ids[rid] = []
            # The prefix portion is already protected in the radix cache
            # (from a previous request's insert). We start with this as
            # cache_protected_len so that subsequent insert() calls know
            # which range is already covered.
            self._rid_to_cache_protected_len[rid] = actual_prefix_lens[i]
            offset += extend_len

        # GDN state management: restore from track slot on cache hit, or reset
        if gdn_pool is not None:
            for i, m in enumerate(requests_meta):
                rid = m["rid"]
                working_slot = slots[i]
                prefix_len = actual_prefix_lens[i]
                node = matched_nodes[i]

                if prefix_len > 0 and node is not None:
                    # Cache hit — try to restore GDN state from the track slot
                    # associated with the matched radix node.
                    track_slot = self._node_id_to_gdn_track_slot.get(node.id)
                    if track_slot is not None:
                        gdn_pool.copy_states(track_slot, working_slot)
                        logger.debug(
                            "GDN state restored for rid=%s from track_slot=%d "
                            "(prefix_len=%d)",
                            rid,
                            track_slot,
                            prefix_len,
                        )
                    else:
                        # Cache hit but no GDN snapshot — reset to zero.
                        # This can happen if the track slot was evicted.
                        idx = torch.tensor(
                            [working_slot], dtype=torch.int64, device=runner.device
                        )
                        gdn_pool.reset_states(idx)
                        logger.debug(
                            "GDN state reset for rid=%s (cache hit but no "
                            "track slot, prefix_len=%d)",
                            rid,
                            prefix_len,
                        )
                else:
                    # No cache hit — fresh request, zero-init
                    idx = torch.tensor(
                        [working_slot], dtype=torch.int64, device=runner.device
                    )
                    gdn_pool.reset_states(idx)

                # Allocate a track slot only when the radix cache is enabled;
                # track slots are freed via the eviction callback so they must
                # be associated with a node, which only happens when cache is on.
                if cache is not None:
                    ts = gdn_pool.alloc_track_slot()
                    if ts is not None:
                        self._rid_to_gdn_track_slot[rid] = ts

        # (Locking already done in Step 1.5 above)

        return out_cache_loc, actual_prefix_lens, actual_extend_lens

    def _unlock_matched_nodes(self, requests_meta: List[Dict[str, Any]]) -> None:
        """Rollback radix locks acquired during match_prefix.

        Called when allocation fails after locking matched nodes.
        """
        cache = self._radix_cache
        if cache is None:
            return
        for m in requests_meta:
            lock = self._rid_to_radix_lock.pop(m["rid"], None)
            if lock is not None:
                node, swa_id = lock
                cache.dec_lock_ref(node, swa_id)

    def _alloc_kv_with_eviction(self, num_tokens: int) -> Optional[torch.Tensor]:
        """Try to allocate KV tokens, evicting from radix cache if needed.

        Returns the allocated index tensor, or ``None`` when the pool is
        exhausted even after up to ``_MAX_EVICT_RETRIES`` eviction rounds.
        """
        runner = self._runner
        cache = self._radix_cache

        if num_tokens == 0:
            return torch.empty(0, dtype=torch.int32, device=runner.device)

        # First attempt: direct allocation
        result = runner.token_to_kv_pool_allocator.alloc(num_tokens)
        if result is not None:
            return result

        # Eviction loop: try evicting from radix cache to free space
        if cache is None:
            return None

        for attempt in range(_MAX_EVICT_RETRIES):
            evictable = cache.evictable_size()
            if evictable == 0:
                logger.warning(
                    "KV allocation failed: need %d tokens, no evictable cache entries",
                    num_tokens,
                )
                return None

            # Evict a fraction of the cache (at least what we need)
            evict_target = max(
                num_tokens,
                int(runner.token_to_kv_pool_allocator.size * _EVICT_FRACTION),
            )
            evict_result = cache.evict(evict_target)
            logger.info(
                "Radix cache eviction attempt %d: evicted %d tokens (target=%d)",
                attempt + 1,
                evict_result.full_evicted,
                evict_target,
            )

            # Retry allocation
            result = runner.token_to_kv_pool_allocator.alloc(num_tokens)
            if result is not None:
                return result

        return None

    def _allocate_decode(
        self, batch: Dict[str, Any], requests_meta: List[Dict[str,
Any]]
    ) -> torch.Tensor:
        """Allocate 1 KV token per request for a decode step.

        Returns ``out_cache_loc`` tensor of shape ``[batch_size]``.
        """
        runner = self._runner
        batch_size = batch["batch_size"]
        seq_lens: List[int] = batch["seq_lens"]

        # Allocate 1 new KV token per request (with eviction retry)
        out_cache_loc = self._alloc_kv_with_eviction(batch_size)
        if out_cache_loc is None:
            raise RuntimeError(
                f"Failed to allocate {batch_size} KV tokens for decode batch"
            )

        # Write the new KV token index into each request's mapping
        for i, m in enumerate(requests_meta):
            rid = m["rid"]
            slot = self._rid_to_req_pool_idx.get(rid)
            if slot is None:
                logger.warning("Decode step for unknown rid=%s, skipping KV write", rid)
                continue

            cur_seq_len = seq_lens[i]
            kv_new = out_cache_loc[i : i + 1]
            # The scheduler increments req.seq_len by 1 after every step, so
            # seq_lens[i] == (number of tokens in the KV cache INCLUDING the
            # token being decoded now). The new token's slot must therefore be
            # written at index seq_lens[i] - 1, matching the position used by
            # prepare_forward_batch_decode (positions = seq_lens - 1) and the
            # window FlashInfer reads (req_to_token_pool[slot, 0:seq_lens[i]]).
            write_pos = cur_seq_len - 1
            runner.req_to_token_pool.write(
                (slot, slice(write_pos, write_pos + 1)), kv_new
            )

            # Append to tracked kv_indices
            prev = self._rid_to_kv_indices.get(rid)
            if prev is not None:
                self._rid_to_kv_indices[rid] = torch.cat([prev, kv_new])
            else:
                self._rid_to_kv_indices[rid] = kv_new.clone()

        return out_cache_loc

    # ------------------------------------------------------------------
    # Resource cleanup
    # ------------------------------------------------------------------

    def _free_rid_resources(self, rid: str) -> None:
        """Free GPU resources (req pool slot + KV indices) for a finished rid.

        KV index ownership model (when radix cache is enabled):

        ``req_to_token_pool[slot]`` contains three regions after
        ``insert()`` returns ``new_prefix_len``::

            [0, cache_protected_len)
                Indices shared with the radix tree from a previous insert.
                **Do not free** — the tree already owns them.

            [cache_protected_len, new_prefix_len)
                Indices allocated by THIS request that turned out to overlap
                with tree nodes inserted concurrently. The tree already
                holds cloned copies → these are duplicates → **free them**.

            [new_prefix_len, total_len)
                Indices that ``insert()`` just added to the tree (cloned).
                The tree now owns the underlying KV pool slots.
                **Do not free** — the tree will free during eviction.

        When the radix cache is disabled, all KV indices are freed directly.
        """
        runner = self._runner
        cache = self._radix_cache

        # Pop ALL per-rid tracking up front so the maps are consistent even
        # if a later step raises.
        # NOTE(review): `radix_lock` is captured here but not referenced in
        # the visible portion of this method — presumably consumed further
        # below (dec_lock_ref); confirm against the rest of the file.
        slot = self._rid_to_req_pool_idx.pop(rid, None)
        kv_indices = self._rid_to_kv_indices.pop(rid, None)
        input_ids = self._rid_to_input_ids.pop(rid, None)
        output_ids = self._rid_to_output_ids.pop(rid, None)
        cache_protected_len = self._rid_to_cache_protected_len.pop(rid, 0)
        radix_lock = self._rid_to_radix_lock.pop(rid, None)
        self._rid_to_mrope_delta.pop(rid, None)

        # Free GDN track slot (if any) — the slot's association with a
        # radix node is managed separately via _node_id_to_gdn_track_slot
        # and the eviction callback; here we just remove the rid mapping.
        self._rid_to_gdn_track_slot.pop(rid, None)

        cache_enabled = cache is not None

        # ----------------------------------------------------------
        # Phase 1: Read all KV indices BEFORE freeing anything.
        # ----------------------------------------------------------
        prompt_len = len(input_ids) if input_ids is not None else 0
        decode_len = len(output_ids) if output_ids else 0
        total_len = prompt_len + decode_len

        all_kv_indices: Optional[torch.Tensor] = None
        if slot is not None and input_ids is not None:
            all_kv_indices = runner.req_to_token_pool.req_to_token[slot, :total_len].to(
                torch.int64
            )

        # ----------------------------------------------------------
        # Phase 2: Insert into radix cache (if enabled).
        # ----------------------------------------------------------
        did_insert = False
        if cache_enabled and all_kv_indices is not None:
            if self._is_hybrid and decode_len > 0:
                # Hybrid model: insert only prompt tokens (not decode)
                # because GDN state is only tracked at the prompt boundary.
                prompt_kv = all_kv_indices[:prompt_len]
                decode_kv = all_kv_indices[prompt_len:]
                key = RadixKey(list(input_ids))
                result = cache.insert(key, prompt_kv)
                new_prefix_len = result.prefix_len

                # Free duplicate KV indices in the overlap region.
                if new_prefix_len > cache_protected_len:
                    dup_indices = prompt_kv[cache_protected_len:new_prefix_len]
                    if dup_indices.numel() > 0:
                        runner.token_to_kv_pool_allocator.free(dup_indices)

                # Free decode KV indices (tree does not own them)
                if decode_kv.numel() > 0:
                    runner.token_to_kv_pool_allocator.free(decode_kv)
            else:
                # Non-hybrid or no decode tokens: insert full sequence
                full_token_ids = list(input_ids)
                if output_ids:
                    full_token_ids.extend(output_ids)
                key = RadixKey(full_token_ids)
                result = cache.insert(key, all_kv_indices)
                new_prefix_len = result.prefix_len

                # Free duplicate KV indices in the overlap region.
+ if new_prefix_len > cache_protected_len: + dup_indices = all_kv_indices[cache_protected_len:new_prefix_len] + if dup_indices.numel() > 0: + runner.token_to_kv_pool_allocator.free(dup_indices) + + did_insert = True + + # ---------------------------------------------------------- + # Phase 3: Unlock radix cache nodes. + # ---------------------------------------------------------- + if cache_enabled and radix_lock is not None: + node, swa_boundary_id = radix_lock + cache.dec_lock_ref(node, swa_boundary_id) + + # ---------------------------------------------------------- + # Phase 4: Free KV indices not owned by the radix cache. + # ---------------------------------------------------------- + if not did_insert: + if cache_enabled and all_kv_indices is not None: + # Cache enabled but insert skipped (shouldn't happen in + # normal flow). Tree owns [0, cache_protected_len); + # free the rest. + tail = all_kv_indices[cache_protected_len:] + if tail.numel() > 0: + runner.token_to_kv_pool_allocator.free(tail) + elif not cache_enabled: + # Cache disabled — free all newly-allocated KV indices. + if all_kv_indices is not None and all_kv_indices.numel() > 0: + runner.token_to_kv_pool_allocator.free(all_kv_indices) + elif kv_indices is not None and kv_indices.numel() > 0: + runner.token_to_kv_pool_allocator.free(kv_indices) + + # ---------------------------------------------------------- + # Phase 5: Free the req pool slot. 
+        # ----------------------------------------------------------
+        # Release this request's row back to the req-to-token pool.
+        if slot is not None:
+            runner.req_to_token_pool.free(slot)
+
+        logger.debug(
+            "Freed resources for rid=%s (slot=%s, kv_tokens=%d)",
+            rid,
+            slot,
+            kv_indices.numel() if kv_indices is not None else 0,
+        )
+
+    # ------------------------------------------------------------------
+    # GDN state tracking helpers (hybrid models)
+    # ------------------------------------------------------------------
+
+    def _track_gdn_state_after_extend(
+        self, requests_meta: List[Dict[str, Any]]
+    ) -> None:
+        """Snapshot working GDN state into each request's track slot.
+
+        Called immediately after ``runner.forward()`` for extend batches so
+        that the FINAL recurrent/conv state (after processing the full prompt)
+        is saved. The track slot is later associated with a radix node in
+        ``_insert_into_radix_cache``.
+        """
+        # Non-hybrid runners expose no gdn_pool; nothing to snapshot then.
+        gdn_pool = getattr(self._runner, "gdn_pool", None)
+        if gdn_pool is None:
+            return
+
+        for m in requests_meta:
+            rid = m["rid"]
+            working_slot = self._rid_to_req_pool_idx.get(rid)
+            track_slot = self._rid_to_gdn_track_slot.get(rid)
+            # Skip requests that never got a working slot or a track slot
+            # (e.g. track-slot allocation failed earlier); copy is best-effort.
+            if working_slot is not None and track_slot is not None:
+                gdn_pool.copy_states(working_slot, track_slot)
+
+    def _on_radix_node_evict(self, node_id: int) -> None:
+        """Callback invoked by RadixCache when a node is evicted.
+
+        Frees the GDN track slot associated with the evicted node.
+ """ + track_slot = self._node_id_to_gdn_track_slot.pop(node_id, None) + if track_slot is not None: + gdn_pool = getattr(self._runner, "gdn_pool", None) + if gdn_pool is not None: + gdn_pool.free_track_slot(track_slot) + logger.debug( + "Freed GDN track slot %d for evicted node %d", + track_slot, + node_id, + ) + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._runner is not None: + self._runner.shutdown() diff --git a/pymllm/orchestrator/parallel_state.py b/pymllm/orchestrator/parallel_state.py new file mode 100644 index 000000000..9fb208769 --- /dev/null +++ b/pymllm/orchestrator/parallel_state.py @@ -0,0 +1,183 @@ +"""Minimal parallel state for single-GPU serving. + +pymllm targets single-GPU, high-concurrency inference. This module keeps +the TP / DP / PP scaffolding so the rest of the codebase can query ranks +and groups uniformly, but the default (and expected) case is world_size=1. 
+""" + +import logging +from typing import Optional + +import torch +import torch.distributed as dist + +from pymllm.orchestrator.group_coordinator import GroupCoordinator + +logger = logging.getLogger(__name__) + +_TP_GROUP: Optional[GroupCoordinator] = None +_DP_GROUP: Optional[GroupCoordinator] = None +_PP_GROUP: Optional[GroupCoordinator] = None + +_TP_RANK: int = 0 +_TP_SIZE: int = 1 +_DP_RANK: int = 0 +_DP_SIZE: int = 1 +_PP_RANK: int = 0 +_PP_SIZE: int = 1 + + +def initialize_model_parallel( + tensor_model_parallel_size: int = 1, + data_parallel_size: int = 1, + pipeline_model_parallel_size: int = 1, + backend: str = "nccl", +) -> None: + global _TP_GROUP, _DP_GROUP, _PP_GROUP + global _TP_RANK, _TP_SIZE, _DP_RANK, _DP_SIZE, _PP_RANK, _PP_SIZE + + _TP_SIZE = tensor_model_parallel_size + _DP_SIZE = data_parallel_size + _PP_SIZE = pipeline_model_parallel_size + + if not dist.is_initialized(): + return + + world_size = dist.get_world_size() + world_rank = dist.get_rank() + local_rank = int(torch.cuda.current_device()) if torch.cuda.is_available() else 0 + + assert ( + tensor_model_parallel_size * data_parallel_size * pipeline_model_parallel_size + == world_size + ), ( + f"TP({tensor_model_parallel_size}) * DP({data_parallel_size}) * " + f"PP({pipeline_model_parallel_size}) != World({world_size})" + ) + + logger.info( + "Parallel init: world=%d rank=%d tp=%d dp=%d pp=%d", + world_size, + world_rank, + tensor_model_parallel_size, + data_parallel_size, + pipeline_model_parallel_size, + ) + + if tensor_model_parallel_size > 1: + num_tp_groups = world_size // tensor_model_parallel_size + for i in range(num_tp_groups): + ranks = list( + range( + i * tensor_model_parallel_size, + (i + 1) * tensor_model_parallel_size, + ) + ) + if world_rank in ranks: + _TP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _TP_RANK = _TP_GROUP.rank_in_group + break + + if data_parallel_size > 1: + num_dp_groups = world_size // data_parallel_size 
+ for i in range(num_dp_groups): + ranks = list(range(i, world_size, num_dp_groups)) + if world_rank in ranks: + _DP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _DP_RANK = _DP_GROUP.rank_in_group + break + + if pipeline_model_parallel_size > 1: + num_pp_groups = world_size // pipeline_model_parallel_size + for i in range(num_pp_groups): + start = i * pipeline_model_parallel_size + ranks = list(range(start, start + pipeline_model_parallel_size)) + if world_rank in ranks: + _PP_GROUP = GroupCoordinator( + ranks=ranks, + local_rank=local_rank, + backend=backend, + ) + _PP_RANK = _PP_GROUP.rank_in_group + break + + +# ---- group accessors ------------------------------------------------------ + + +def get_tp_group() -> Optional[GroupCoordinator]: + return _TP_GROUP + + +def get_dp_group() -> Optional[GroupCoordinator]: + return _DP_GROUP + + +def get_pp_group() -> Optional[GroupCoordinator]: + return _PP_GROUP + + +# ---- rank / size helpers -------------------------------------------------- + + +def get_tensor_model_parallel_rank() -> int: + return _TP_RANK + + +def get_tensor_model_parallel_world_size() -> int: + return _TP_SIZE + + +def get_data_parallel_rank() -> int: + return _DP_RANK + + +def get_data_parallel_world_size() -> int: + return _DP_SIZE + + +def get_pipeline_model_parallel_rank() -> int: + return _PP_RANK + + +def get_pipeline_model_parallel_world_size() -> int: + return _PP_SIZE + + +def model_parallel_is_initialized() -> bool: + return _TP_GROUP is not None or _DP_GROUP is not None or _PP_GROUP is not None + + +# ---- communication helpers ------------------------------------------------ + + +def tensor_model_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor: + group = get_tp_group() + if group is None: + return tensor + return group.all_reduce(tensor) + + +def tensor_model_parallel_all_gather( + tensor: torch.Tensor, + dim: int = 0, +) -> torch.Tensor: + group = get_tp_group() + if group is 
None: + return tensor + return group.all_gather(tensor, dim=dim) + + +def data_parallel_all_reduce(tensor: torch.Tensor) -> torch.Tensor: + group = get_dp_group() + if group is None: + return tensor + return group.all_reduce(tensor) diff --git a/pymllm/orchestrator/request_response_process.py b/pymllm/orchestrator/request_response_process.py new file mode 100644 index 000000000..f59ffa51e --- /dev/null +++ b/pymllm/orchestrator/request_response_process.py @@ -0,0 +1,236 @@ +""" +RequestResponseProcess -- the main-process entry point for user requests. + +This process is **not** a subprocess; it lives in the engine's main process. +Incoming requests are placed into an ``asyncio.Queue`` and forwarded to the +TokenizerProcess via ZMQ. Decoded results arrive back from the +DetokenizerProcess and are dispatched to the waiting callers. + +The request-tracking model uses ``ReqState`` pattern: each request +gets an ``asyncio.Event`` + output list so that streaming (multiple incremental +chunks) and one-shot responses are both supported. +""" + +import asyncio +import dataclasses +import logging +import time +from typing import Any, Dict, List, Optional, Union + +import zmq +import zmq.asyncio + +from pymllm.engine.io_struct import GenerateReqInput +from pymllm.orchestrator.ipc_utils import create_zmq_socket, close_zmq_socket + +logger = logging.getLogger(__name__) + + +@dataclasses.dataclass +class ReqState: + """Per-request state that supports both streaming and one-shot responses. + + ``ReqState`` (Event + out_list). + + The recv loop appends results to *out_list* and signals *event*; + callers ``await event.wait()`` in a loop, consuming results until + *finished* is ``True``. 
+ """ + + out_list: List[Dict[str, Any]] = dataclasses.field(default_factory=list) + finished: bool = False + event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) + created_at: float = dataclasses.field(default_factory=time.time) + + +class RequestResponseProcess: + """Sits in the main process; bridges user-facing API and subprocess pipeline.""" + + def __init__( + self, + send_to_tokenizer_addr: str, + recv_from_detokenizer_addr: str, + ): + self._send_to_tokenizer_addr: str = send_to_tokenizer_addr + self._recv_from_detokenizer_addr: str = recv_from_detokenizer_addr + + # asyncio queue that buffers incoming user requests + self._request_queue: asyncio.Queue[Dict[str, Any]] = asyncio.Queue() + + # rid -> ReqState (replaces the old rid -> Future dict) + self._rid_to_state: Dict[str, ReqState] = {} + + # ZMQ (async context, sockets created lazily in the event loop) + self._zmq_ctx: Optional[zmq.asyncio.Context] = None + self._send_to_tokenizer: Optional[zmq.asyncio.Socket] = None + self._recv_from_detokenizer: Optional[zmq.asyncio.Socket] = None + + self._loop_task: Optional[asyncio.Task] = None + + def start(self) -> None: + """Bind ZMQ sockets. Background tasks are started lazily by + :meth:`listen` on the first :meth:`add_request` call, so they + always run on the correct event loop regardless of whether the + caller is uvicorn, ``loop.run_until_complete``, or anything else. + """ + self._zmq_ctx = zmq.asyncio.Context() + self._send_to_tokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_tokenizer_addr, + bind=True, + ) + self._recv_from_detokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_detokenizer_addr, + bind=True, + ) + + def listen(self) -> None: + """Start the send/recv background tasks on the **current** running + event loop. Idempotent — subsequent calls are no-ops while the + tasks are still alive. 
+ + Called automatically by :meth:`add_request`, so callers never need + to invoke this directly. + """ + if self._loop_task is not None and not self._loop_task.done(): + return + loop = asyncio.get_running_loop() + self._loop_task = loop.create_task(self._run()) + logger.debug("RequestResponseProcess: background tasks started") + + async def add_request( + self, + request: GenerateReqInput, + max_queued: Optional[int] = None, + ) -> Union[ReqState, List[ReqState]]: + """Enqueue request(s) and return the corresponding :class:`ReqState`(s). + + * **Single request** (``request.is_single is True``): behaves exactly as + before – registers one ``ReqState`` and enqueues one message. + * **Batch request** (``request.is_single is False``): splits the batch + into *N* individual sub-requests, registers a ``ReqState`` per rid, and + enqueues each sub-request separately so the downstream pipeline sees + independent messages. Returns a ``List[ReqState]`` in the same order + as the input rids. + + Parameters + ---------- + max_queued + If set, raise ``RuntimeError`` when the queue already has this many + items (back-pressure / overload protection). + """ + self.listen() + if max_queued is not None and self._request_queue.qsize() >= max_queued: + raise RuntimeError("Server overloaded: too many queued requests") + + if request.is_single: + rid = request.rid if isinstance(request.rid, str) else request.rid[0] + state = ReqState() + self._rid_to_state[rid] = state + await self._request_queue.put(request.to_request_dict()) + return state + + # Batch path: fan-out into individual sub-requests. 
+        states: List[ReqState] = []
+        for i in range(request.batch_size):
+            sub = request[i]
+            # NOTE(review): the single-request path above picks rid[0] when
+            # rid is a list, while this path stringifies the whole value —
+            # confirm sub.rid is always a plain str here.
+            rid = sub.rid if isinstance(sub.rid, str) else str(sub.rid)
+            state = ReqState()
+            self._rid_to_state[rid] = state
+            await self._request_queue.put(sub.to_request_dict())
+            states.append(state)
+        return states
+
+    def remove_state(self, rid: str) -> None:
+        """Remove the ``ReqState`` for *rid* (called by the caller once done)."""
+        self._rid_to_state.pop(rid, None)
+
+    async def abort_request(self, rid: str) -> None:
+        """Cancel a pending request and notify downstream processes."""
+        state = self._rid_to_state.pop(rid, None)
+        if state is not None and not state.finished:
+            state.finished = True
+            state.out_list.append({"rid": rid, "error": "aborted", "finished": True})
+            state.event.set()
+        # The abort sentinel is forwarded even when the rid is unknown or
+        # already finished, so downstream processes can drop in-flight work.
+        if self._send_to_tokenizer is not None:
+            await self._send_to_tokenizer.send_pyobj({"rid": rid, "abort": True})
+
+    async def shutdown(self) -> None:
+        """Cancel the background send/recv task and tear down ZMQ resources."""
+        if self._loop_task is not None:
+            self._loop_task.cancel()
+        if self._send_to_tokenizer is not None:
+            close_zmq_socket(self._send_to_tokenizer)
+        if self._recv_from_detokenizer is not None:
+            close_zmq_socket(self._recv_from_detokenizer)
+        if self._zmq_ctx is not None:
+            self._zmq_ctx.term()
+
+    # ------------------------------------------------------------------
+    # Internal loops
+    # ------------------------------------------------------------------
+
+    async def _run(self) -> None:
+        """Main loop: forward requests to tokenizer, receive results from detokenizer."""
+        send_task = asyncio.create_task(self._send_loop())
+        recv_task = asyncio.create_task(self._recv_loop())
+        await asyncio.gather(send_task, recv_task)
+
+    async def _send_loop(self) -> None:
+        """Drain the asyncio queue and push requests to the TokenizerProcess."""
+        # NOTE(review): assumes start() ran before listen(), so
+        # _send_to_tokenizer is non-None here — confirm the call order.
+        while True:
+            request = await self._request_queue.get()
+            await self._send_to_tokenizer.send_pyobj(request)
+
+    # Stale state cleanup constants
+    _STALE_TIMEOUT = 1800  # 30 minutes
+    _CLEANUP_INTERVAL = 60  # seconds
+ + async def _recv_loop(self) -> None: + """Receive decoded results from DetokenizerProcess and dispatch to ReqStates.""" + last_cleanup = time.time() + while True: + # Use a timeout so that stale-state cleanup runs even when no + # results are flowing back from the detokenizer. + try: + result = await asyncio.wait_for( + self._recv_from_detokenizer.recv_pyobj(), + timeout=self._CLEANUP_INTERVAL, + ) + except asyncio.TimeoutError: + self._cleanup_stale_states() + last_cleanup = time.time() + continue + + rid = result.get("rid") + state = self._rid_to_state.get(rid) + if state is None: + logger.warning("Received result for unknown rid=%s", rid) + continue + state.out_list.append(result) + if result.get("finished", False): + state.finished = True + state.event.set() + + # Also run cleanup on the normal path when enough time has passed + now = time.time() + if now - last_cleanup > self._CLEANUP_INTERVAL: + last_cleanup = now + self._cleanup_stale_states() + + def _cleanup_stale_states(self) -> None: + """Remove request states that have been pending longer than ``_STALE_TIMEOUT``.""" + now = time.time() + stale = [ + r + for r, s in self._rid_to_state.items() + if not s.finished and (now - s.created_at) > self._STALE_TIMEOUT + ] + for r in stale: + logger.warning("Cleaning stale request state: rid=%s", r) + s = self._rid_to_state.pop(r) + s.finished = True + s.out_list.append({"rid": r, "error": "timeout", "finished": True}) + s.event.set() diff --git a/pymllm/orchestrator/scheduler_process.py b/pymllm/orchestrator/scheduler_process.py new file mode 100644 index 000000000..3bc3466a1 --- /dev/null +++ b/pymllm/orchestrator/scheduler_process.py @@ -0,0 +1,1069 @@ +""" +SchedulerProcess -- the central scheduling and inference hub. + +Receives tokenized requests from the TokenizerProcess, organises them into +batches, runs model forward passes via the **in-process** model runner, +and streams finished token IDs to the DetokenizerProcess. 
+ +Architecture: the scheduler owns the :class:`ModelRunnerProcess` directly +(same process, direct function calls). GPU resources (KV cache, req pool +slots) are freed immediately when requests finish — no cross-process +communication needed. + +Request ingestion supports two modes: + 1. ZMQ path: Receive TokenizedGenerateReqInput via ZMQ recv_pyobj + 2. Shared queue fast path: Read from shared memory + multiprocessing queue + +The main ``event_loop``:: + + while True: + recv_requests() + process_input_requests() + batch = get_next_batch_to_run() # also frees finished GPU resources + if batch: + result = run_batch(batch) # direct call to model runner + process_batch_result(batch, result) + else: + idle_sleeper.sleep() # block until ZMQ data or timeout + stream_output() +""" + +import logging +import queue as stdlib_queue +import time +from collections import deque +from multiprocessing.connection import Connection +from typing import Any, Deque, Dict, List, Optional + +import zmq + +from pymllm.engine.forward_batch import ForwardMode +from pymllm.engine.io_struct import BatchTokenIDOutput, TokenizedGenerateReqInput +from pymllm.orchestrator.cuda_ipc_transport import ( + TensorTransportMode, + unwrap_mm_inputs_from_ipc, +) +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging +from pymllm.orchestrator.shared_memory_queue import SharedMemoryManager, TensorQueue + +logger = logging.getLogger(__name__) + +# Default scheduling limits +_DEFAULT_MAX_RUNNING_REQUESTS = 256 +_DEFAULT_IDLE_POLL_TIMEOUT_MS = 1000 +_DEFAULT_MAX_PREFILL_TOKENS = 8192 +_DEFAULT_MAX_TOTAL_TOKENS = 131072 +_DEFAULT_MAX_NEW_TOKENS = 32768 + + +# ====================================================================== +# IdleSleeper -- avoid busy-looping when no work is available +# ====================================================================== + + +class IdleSleeper: + """Block the scheduler thread when idle using ZMQ Poller. 
+ + Avoids 100% CPU spinning when no requests are pending. The poller + wakes immediately when data arrives on any registered socket, so + request latency is not affected. + """ + + def __init__( + self, sockets: list, poll_timeout_ms: int = _DEFAULT_IDLE_POLL_TIMEOUT_MS + ): + self.poller = zmq.Poller() + for s in sockets: + self.poller.register(s, zmq.POLLIN) + self.poll_timeout_ms = poll_timeout_ms + + def sleep(self) -> None: + """Block until data arrives on any registered socket, or timeout.""" + self.poller.poll(self.poll_timeout_ms) + + +# ====================================================================== +# Req -- per-request state tracker +# ====================================================================== + + +class Req: + """Tracks a single request through its lifecycle (prefill -> decode -> finish). + + Created by :meth:`SchedulerProcess.process_input_requests` from a + :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput`. + """ + + __slots__ = ( + "rid", + "input_ids", + "input_text", + "sampling_params", + "mm_inputs", + "stream", + "return_logprob", + "logprob_start_len", + "top_logprobs_num", + # KV-cache state + "req_pool_idx", + "seq_len", + # Prefix-cache hit (set during scheduling when radix cache is active) + "prefix_len", + # Generation state + "output_ids", + "finished_reason", + "is_prefilled", + # Sampling parameters (parsed) + "max_new_tokens", + "temperature", + "top_p", + "top_k", + "stop_token_ids", + # Streaming + "read_offset", + # Prompt length (for token accounting) + "prompt_len", + ) + + def __init__( + self, + rid: str, + input_ids: List[int], + input_text: str = "", + sampling_params: Optional[Dict[str, Any]] = None, + mm_inputs: Optional[Dict[str, Any]] = None, + stream: bool = False, + return_logprob: bool = False, + logprob_start_len: int = -1, + top_logprobs_num: int = 0, + ): + self.rid = rid + self.input_ids = list(input_ids) + self.input_text = input_text + self.mm_inputs = mm_inputs + self.stream = stream 
+ self.return_logprob = return_logprob + self.logprob_start_len = logprob_start_len + self.top_logprobs_num = top_logprobs_num + + # Parse sampling params + sp = sampling_params or {} + self.sampling_params = sp + self.max_new_tokens: int = sp.get("max_new_tokens", _DEFAULT_MAX_NEW_TOKENS) + self.temperature: float = sp.get("temperature", 1.0) + self.top_p: float = sp.get("top_p", 1.0) + self.top_k: int = sp.get("top_k", -1) + self.stop_token_ids: List[int] = list(sp.get("stop_token_ids", [])) + + # KV-cache state (assigned during scheduling) + self.req_pool_idx: int = -1 + self.seq_len: int = len(input_ids) + # Number of prefix tokens served from the radix/KV cache (0 = no hit). + # Updated by process_batch_result when the model runner reports a + # prefix cache hit. Used in _free_req_resources to correctly + # release the token budget. + self.prefix_len: int = 0 + + # Generation state + self.output_ids: List[int] = [] + self.finished_reason: Optional[str] = None + self.is_prefilled: bool = False + + # Streaming + self.read_offset: int = 0 + + # Prompt length + self.prompt_len: int = len(input_ids) + + def check_finished(self) -> bool: + """Check if this request has reached a finish condition. + + Sets ``finished_reason`` and returns True if finished. + Checks: + 1. Stop token (EOS tokens are merged into stop_token_ids during + :meth:`SchedulerProcess.process_input_requests`) + 2. 
``max_new_tokens`` reached + """ + if self.finished_reason is not None: + return True + + if self.output_ids: + last_token = self.output_ids[-1] + if last_token in self.stop_token_ids: + self.finished_reason = "eos" + return True + + # Check max_new_tokens + if len(self.output_ids) >= self.max_new_tokens: + self.finished_reason = "length" + return True + + return False + + @property + def is_finished(self) -> bool: + return self.finished_reason is not None + + def abort(self) -> None: + """Mark this request as aborted.""" + self.finished_reason = "abort" + + def __repr__(self) -> str: + return ( + f"Req(rid={self.rid!r}, seq_len={self.seq_len}, " + f"out={len(self.output_ids)}, finished={self.finished_reason})" + ) + + +# ====================================================================== +# ScheduleBatch -- batch container +# ====================================================================== + + +class ScheduleBatch: + """Wraps a list of :class:`Req` objects for a single forward pass. + + Provides helpers to assemble the batch dict sent to the ModelRunnerProcess + in the format expected by :class:`~pymllm.engine.forward_batch.ForwardBatch`. + """ + + def __init__(self, reqs: List[Req], forward_mode: ForwardMode): + self.reqs = reqs + self.forward_mode = forward_mode + + @property + def batch_size(self) -> int: + return len(self.reqs) + + def prepare_for_extend(self) -> Dict[str, Any]: + """Assemble a batch dict for prefill / extend forward pass. + + Returns a dict with flattened ``input_ids``, per-request ``positions``, + ``req_pool_indices``, ``seq_lens``, ``extend_seq_lens``, + ``extend_prefix_lens``, and request metadata. + + Note: The scheduler sends the **full** input_ids (no prefix trimming). + The ModelRunnerProcess performs radix cache prefix matching and + rebuilds the tensors with actual prefix lengths before the forward + pass. The ``extend_prefix_lens`` here are always 0 from the + scheduler; they serve as placeholders. 
+        """
+        all_input_ids: List[int] = []
+        all_positions: List[int] = []
+        req_pool_indices: List[int] = []
+        seq_lens: List[int] = []
+        extend_seq_lens: List[int] = []
+        extend_prefix_lens: List[int] = []
+        requests_meta: List[Dict[str, Any]] = []
+
+        for req in self.reqs:
+            input_len = len(req.input_ids)
+
+            # Send full input_ids; model runner will trim based on prefix
+            all_input_ids.extend(req.input_ids)
+            # Positions are 0..input_len-1 per request; prefix-cache trimming
+            # (and position rebasing) happens in the model runner, not here.
+            all_positions.extend(range(input_len))
+            req_pool_indices.append(req.req_pool_idx)
+            seq_lens.append(req.seq_len)
+            extend_seq_lens.append(input_len)
+            extend_prefix_lens.append(0)
+            requests_meta.append(
+                {
+                    "rid": req.rid,
+                    "input_ids": req.input_ids,
+                    "mm_inputs": req.mm_inputs,
+                    "sampling_params": req.sampling_params,
+                    "return_logprob": req.return_logprob,
+                    "logprob_start_len": req.logprob_start_len,
+                    "top_logprobs_num": req.top_logprobs_num,
+                }
+            )
+
+        return {
+            "forward_mode": "extend",
+            "batch_size": self.batch_size,
+            "input_ids": all_input_ids,
+            "positions": all_positions,
+            "req_pool_indices": req_pool_indices,
+            "seq_lens": seq_lens,
+            "extend_seq_lens": extend_seq_lens,
+            "extend_prefix_lens": extend_prefix_lens,
+            "requests": requests_meta,
+            # NOTE: id(self) is only unique while this batch object is alive;
+            # the id may be reused after garbage collection.
+            "batch_id": id(self),
+            "created_at": time.time(),
+        }
+
+    def prepare_for_decode(self) -> Dict[str, Any]:
+        """Assemble a batch dict for decode forward pass (one token per request).
+
+        Returns a dict with one input token per request (the last generated
+        token), positions at ``seq_len``, and request metadata.
+ """ + all_input_ids: List[int] = [] + all_positions: List[int] = [] + req_pool_indices: List[int] = [] + seq_lens: List[int] = [] + requests_meta: List[Dict[str, Any]] = [] + + for req in self.reqs: + # For decode, the input is the last generated token + if req.output_ids: + all_input_ids.append(req.output_ids[-1]) + else: + # Fallback: last input token (shouldn't happen normally) + all_input_ids.append(req.input_ids[-1]) + all_positions.append(req.seq_len) + req_pool_indices.append(req.req_pool_idx) + seq_lens.append(req.seq_len) + requests_meta.append( + { + "rid": req.rid, + "sampling_params": req.sampling_params, + "return_logprob": req.return_logprob, + "logprob_start_len": req.logprob_start_len, + "top_logprobs_num": req.top_logprobs_num, + } + ) + + return { + "forward_mode": "decode", + "batch_size": self.batch_size, + "input_ids": all_input_ids, + "positions": all_positions, + "req_pool_indices": req_pool_indices, + "seq_lens": seq_lens, + "requests": requests_meta, + "batch_id": id(self), + "created_at": time.time(), + } + + def to_batch_dict(self) -> Dict[str, Any]: + """Build the batch dict appropriate for the current forward mode.""" + if self.forward_mode.is_extend(): + return self.prepare_for_extend() + else: + return self.prepare_for_decode() + + def __repr__(self) -> str: + return f"ScheduleBatch(mode={self.forward_mode.name}, size={self.batch_size})" + + +# ====================================================================== +# SchedulerProcess +# ====================================================================== + + +class SchedulerProcess: + """Runs inside a subprocess. 
Central hub that drives the inference loop.""" + + def __init__( + self, + recv_from_tokenizer_addr: str, + send_to_detokenizer_addr: str, + server_config: Optional[Any] = None, + model_config: Optional[Any] = None, + gpu_id: int = 0, + shared_queue: Optional[TensorQueue] = None, + enable_shared_queue: bool = False, + tensor_transport_mode: TensorTransportMode = "default", + # Scheduling limits + max_running_requests: int = _DEFAULT_MAX_RUNNING_REQUESTS, + max_prefill_tokens: int = _DEFAULT_MAX_PREFILL_TOKENS, + max_total_tokens: int = _DEFAULT_MAX_TOTAL_TOKENS, + eos_token_ids: Optional[List[int]] = None, + default_max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS, + ): + # ZMQ addresses (tokenizer + detokenizer only) + self._recv_from_tokenizer_addr = recv_from_tokenizer_addr + self._send_to_detokenizer_addr = send_to_detokenizer_addr + + # Model config (for in-process model runner) + self._server_config = server_config + self._model_config = model_config + self._gpu_id = gpu_id + + # Shared queue configuration + self._shared_queue = shared_queue + self._enable_shared_queue = enable_shared_queue + self._tensor_transport_mode = tensor_transport_mode + + # ZMQ runtime objects (initialised in init_sockets) + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_tokenizer: Optional[zmq.Socket] = None + self._send_to_detokenizer: Optional[zmq.Socket] = None + self._poller: Optional[zmq.Poller] = None + + # In-process model runner (initialised in init_model) + self._model_runner = None + + # Request management -- three-stage pipeline + self._waiting_queue: Deque[TokenizedGenerateReqInput] = deque() + self._pending_queue: List[Req] = [] + self._running_batch: List[Req] = [] + self._finished: Deque[Dict[str, Any]] = deque() + + # Scheduling limits + self._max_running_requests = max_running_requests + self._max_prefill_tokens = max_prefill_tokens + + # KV-cache token budget (simplified single-GPU tracking). 
+ self._max_total_tokens = max_total_tokens + self._used_tokens: int = 0 + + # EOS token(s) for finish detection + self._eos_token_ids: List[int] = list(eos_token_ids) if eos_token_ids else [] + + # Default max_new_tokens (from model config or fallback) + self._default_max_new_tokens = default_max_new_tokens + + # Monotonic request-slot counter (simplified; no GPU pool access) + self._next_req_pool_idx: int = 0 + + # ------ Throughput metrics ------ + # How often (in decode batches) to log throughput stats. + self._decode_log_interval: int = ( + server_config.decode_log_interval + if server_config is not None + and hasattr(server_config, "decode_log_interval") + else 40 + ) + # Accumulators reset at each log interval + self._num_prefill_tokens: int = 0 # new prefill tokens (excluding cache hits) + self._num_prefill_cache_tokens: int = 0 # prefill tokens served from cache + self._num_decode_tokens: int = 0 # generated decode tokens + self._num_prefill_reqs: int = 0 # prefill requests count + # Timestamps for throughput calculation + self._last_prefill_stats_tic: float = time.time() + self._last_decode_stats_tic: float = time.time() + # Forward pass counters + self._forward_ct_decode: int = 0 + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + + self._recv_from_tokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_tokenizer_addr, + bind=False, + ) + self._send_to_detokenizer = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_detokenizer_addr, + bind=True, + ) + + # Poller for non-blocking recv from tokenizer + self._poller = zmq.Poller() + self._poller.register(self._recv_from_tokenizer, zmq.POLLIN) + + # Idle sleeper: blocks the event loop when no batch is ready, + # wakes immediately on incoming ZMQ messages. 
+ self._idle_sleeper = IdleSleeper([self._recv_from_tokenizer]) + + def init_model(self) -> None: + """Create and initialise the in-process model runner. + + Must be called after ``init_sockets`` and inside the subprocess + (after spawn) since it performs CUDA initialisation. + """ + from pymllm.orchestrator.model_runner_process import ModelRunnerProcess + + self._model_runner = ModelRunnerProcess( + gpu_id=self._gpu_id, + server_config=self._server_config, + model_config=self._model_config, + ) + self._model_runner.init_model() + logger.info("In-process model runner initialised on GPU %d", self._gpu_id) + + def event_loop(self) -> None: + """Infinite scheduling loop.""" + logger.info( + "SchedulerProcess event loop started (shared_queue=%s, transport=%s)", + self._enable_shared_queue, + self._tensor_transport_mode, + ) + while True: + self.recv_requests() + self.process_input_requests() + batch = self.get_next_batch_to_run() + if batch is not None: + result = self.run_batch(batch) + self.process_batch_result(batch, result) + else: + # No work available -- sleep until a new request arrives + # on the ZMQ socket (or timeout). Avoids busy-looping. + self._idle_sleeper.sleep() + self.stream_output() + + # ------------------------------------------------------------------ + # Step 1: receive tokenized requests (non-blocking) + # ------------------------------------------------------------------ + + def recv_requests(self) -> None: + """Non-blocking receive of tokenized requests from TokenizerProcess. + + Supports two modes: + 1. Legacy ZMQ: Uses ``zmq.Poller`` with a short timeout + 2. Shared queue: Non-blocking get from multiprocessing.Queue + + Messages are either: + * A :class:`~pymllm.engine.io_struct.TokenizedGenerateReqInput` + dataclass – appended to ``_waiting_queue``. + * A plain abort sentinel dict ``{"rid": ..., "abort": True}`` – handled + inline by removing the matching rid from the waiting queue. 
+ """ + if self._enable_shared_queue and self._shared_queue is not None: + self._recv_from_shared_queue() + else: + self._recv_from_zmq() + + def _recv_from_zmq(self) -> None: + """Receive requests via legacy ZMQ path.""" + while True: + events = dict(self._poller.poll(timeout=0)) # non-blocking + if self._recv_from_tokenizer not in events: + break + msg = self._recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK) + # Abort sentinel: plain dict with "abort" key. + if isinstance(msg, dict) and msg.get("abort"): + rid = msg.get("rid") + logger.debug("Scheduler received abort for rid=%s", rid) + self._waiting_queue = type(self._waiting_queue)( + r for r in self._waiting_queue if r.rid != rid + ) + # Also abort from pending queue + self._abort_request(rid) + else: + self._waiting_queue.append(msg) + + def _recv_from_shared_queue(self) -> None: + """Receive requests via shared memory + shared queue fast path. + + After reading a ``(rid, shm_name, mm_inputs)`` tuple from the queue: + 1. The tokenized metadata is read from the POSIX shared memory segment. + 2. If CUDA IPC is enabled, ``mm_inputs`` may contain + :class:`~pymllm.orchestrator.cuda_ipc_transport.CudaIpcTensorTransportProxy` + or :class:`~pymllm.orchestrator.cuda_ipc_transport.TransportProxyTensor` + objects that are reconstructed by calling + :func:`~pymllm.orchestrator.cuda_ipc_transport.unwrap_mm_inputs_from_ipc`. + This step also increments sync flags so the sender can recycle pool chunks. + 3. A full ``TokenizedGenerateReqInput`` is assembled and appended to + ``_waiting_queue``. 
+ """ + while True: + try: + rid, shm_name, mm_inputs = self._shared_queue.get(timeout=0.002) + + # Read metadata from shared memory (and unlink immediately) + metadata: TokenizedGenerateReqInput = SharedMemoryManager.read_metadata( + shm_name, unlink=True + ) + + # Reconstruct GPU tensors from CUDA IPC handles (if any) + if self._tensor_transport_mode in ("cuda_ipc", "cuda_ipc_pool"): + mm_inputs = unwrap_mm_inputs_from_ipc(mm_inputs) + + # Reassemble the full request + full_request = TokenizedGenerateReqInput( + rid=metadata.rid, + input_text=metadata.input_text, + input_ids=metadata.input_ids, + mm_inputs=mm_inputs, + sampling_params=metadata.sampling_params, + stream=metadata.stream, + return_logprob=metadata.return_logprob, + logprob_start_len=metadata.logprob_start_len, + top_logprobs_num=metadata.top_logprobs_num, + lora_path=metadata.lora_path, + session_params=metadata.session_params, + ) + + self._waiting_queue.append(full_request) + logger.debug("Received request %s from shared queue", rid) + + except stdlib_queue.Empty: + break + except Exception as exc: + logger.error( + "Error receiving from shared queue: %s", exc, exc_info=True + ) + try: + if "shm_name" in locals(): + SharedMemoryManager.cleanup(shm_name) + except Exception: + pass + break + + # ------------------------------------------------------------------ + # Step 2: process input requests + # ------------------------------------------------------------------ + + def process_input_requests(self) -> None: + """Convert raw :class:`TokenizedGenerateReqInput` in ``_waiting_queue`` + into :class:`Req` objects and move them to ``_pending_queue``. + + For each request: + 1. Parse sampling params (max_new_tokens, temperature, top_p, top_k, + stop_token_ids with defaults from EOS token). + 2. Create a ``Req`` object. + 3. Move from ``_waiting_queue`` to ``_pending_queue``. 
+ """ + while self._waiting_queue: + raw = self._waiting_queue.popleft() + + # Merge EOS token into stop_token_ids if not already present + sp = dict(raw.sampling_params) if raw.sampling_params else {} + # Inject model-aware default for max_new_tokens when not provided + if "max_new_tokens" not in sp: + sp["max_new_tokens"] = self._default_max_new_tokens + stop_ids = list(sp.get("stop_token_ids", [])) + for eid in self._eos_token_ids: + if eid not in stop_ids: + stop_ids.append(eid) + sp["stop_token_ids"] = stop_ids + + req = Req( + rid=raw.rid, + input_ids=raw.input_ids, + input_text=raw.input_text, + sampling_params=sp, + mm_inputs=raw.mm_inputs, + stream=raw.stream, + return_logprob=raw.return_logprob, + logprob_start_len=raw.logprob_start_len, + top_logprobs_num=raw.top_logprobs_num, + ) + self._pending_queue.append(req) + logger.debug("Processed input request %s (len=%d)", req.rid, req.seq_len) + + # ------------------------------------------------------------------ + # Step 3: build the next batch + # ------------------------------------------------------------------ + + def get_next_batch_to_run(self) -> Optional[ScheduleBatch]: + """Implements continuous batching with two phases. + + 1. **Filter finished**: Remove finished requests from + ``_running_batch`` and free their token budget. + 2. **Schedule new prefills**: From ``_pending_queue``, admit + requests that fit within the token budget and + ``max_running_requests``. + 3. **Build batch**: + - If new prefill requests exist -> EXTEND batch + - Else if running decode requests exist -> DECODE batch + - Else -> None (idle) + + Note on prefix cache: The actual prefix matching is done by the + ModelRunnerProcess (which owns the RadixCache). The scheduler + uses ``input_len`` as a conservative budget estimate. The model + runner reports back actual ``prefix_len`` in results, and the + scheduler adjusts ``_used_tokens`` accordingly in + ``process_batch_result``. 
+ """ + # Phase 1: filter finished requests from running batch + still_running: List[Req] = [] + for req in self._running_batch: + if req.is_finished: + self._model_runner._free_rid_resources(req.rid) + self._free_req_resources(req) + else: + still_running.append(req) + self._running_batch = still_running + + # Phase 2: schedule new prefill requests from pending queue + new_prefill: List[Req] = [] + remaining_pending: List[Req] = [] + prefill_token_budget = self._max_prefill_tokens + + for req in self._pending_queue: + input_len = len(req.input_ids) + total_running = len(self._running_batch) + len(new_prefill) + + # Check capacity constraints. + # We reserve the full input_len as KV budget (conservative). + # If the model runner finds a prefix cache hit, some tokens + # won't need new KV allocation; the budget is corrected in + # process_batch_result. + can_fit_request = total_running < self._max_running_requests + can_fit_tokens = (self._used_tokens + input_len) <= self._max_total_tokens + can_fit_prefill = input_len <= prefill_token_budget + + if can_fit_request and can_fit_tokens and can_fit_prefill: + # Allocate req pool slot + req.req_pool_idx = self._next_req_pool_idx + self._next_req_pool_idx += 1 + # Reserve token budget (full input_len as conservative estimate) + self._used_tokens += input_len + prefill_token_budget -= input_len + new_prefill.append(req) + logger.debug( + "Scheduled prefill for %s (len=%d, used=%d/%d)", + req.rid, + input_len, + self._used_tokens, + self._max_total_tokens, + ) + else: + remaining_pending.append(req) + + self._pending_queue = remaining_pending + + # Phase 3: build batch + if new_prefill: + return ScheduleBatch(new_prefill, ForwardMode.EXTEND) + elif self._running_batch: + return ScheduleBatch(self._running_batch, ForwardMode.DECODE) + else: + return None + + # ------------------------------------------------------------------ + # Step 4: run the batch via ModelRunnerProcess + # 
    # ------------------------------------------------------------------

    def run_batch(self, batch: ScheduleBatch) -> Dict[str, Any]:
        """Execute the batch via the in-process model runner.

        Direct function call — no ZMQ serialisation overhead.
        """
        batch_dict = batch.to_batch_dict()
        # NOTE(review): uses the runner's private `_forward_batch`; a public
        # forward() entry point would decouple the two components — confirm.
        return self._model_runner._forward_batch(batch_dict)

    # ------------------------------------------------------------------
    # Step 5: process batch result
    # ------------------------------------------------------------------

    def process_batch_result(
        self, batch: ScheduleBatch, result: Dict[str, Any]
    ) -> None:
        """Handle the result returned by the ModelRunnerProcess.

        For each request in the result:
        1. Update ``prefix_len`` from the model runner's radix cache hit.
        2. Adjust ``_used_tokens`` if a prefix cache hit was found (the
           scheduler over-reserved during scheduling).
        3. Append new token(s) to ``req.output_ids``.
        4. Increment ``req.seq_len``.
        5. Call ``req.check_finished()`` (EOS token, max_new_tokens).
        6. If prefill request: mark ``req.is_prefilled = True``, move to
           running batch for decode.
        7. If finished: collect for output, free KV-cache budget.
        """
        # Build a rid -> Req lookup for the batch
        rid_to_req: Dict[str, Req] = {req.rid: req for req in batch.reqs}

        # The result may contain per-request outputs in "finished" and
        # "unfinished" lists, or a flat "outputs" list. Handle both.
        # (Assumes each item is a dict with at least a "rid" key — the
        # schema is owned by ModelRunnerProcess; verify against it.)
        output_items: List[Dict[str, Any]] = []
        output_items.extend(result.get("finished", []))
        output_items.extend(result.get("unfinished", []))
        if "outputs" in result:
            output_items.extend(result["outputs"])

        for out in output_items:
            rid = out.get("rid")
            req = rid_to_req.get(rid)
            if req is None:
                # Tolerate results for requests no longer tracked (e.g. aborted).
                logger.warning("Result for unknown rid=%s, skipping", rid)
                continue

            # Update prefix_len from model runner's radix cache matching.
            # The model runner reports the actual prefix_len it found.
            # The scheduler originally reserved full input_len in
            # get_next_batch_to_run; correct the over-reservation now.
            if "prefix_len" in out and batch.forward_mode.is_extend():
                actual_prefix_len = out["prefix_len"]
                if actual_prefix_len > req.prefix_len:
                    saved = actual_prefix_len - req.prefix_len
                    req.prefix_len = actual_prefix_len
                    # Give back the over-reserved tokens. The model runner
                    # reused cached KV for `saved` tokens, so those tokens
                    # do not consume new KV pool slots.
                    self._used_tokens = max(0, self._used_tokens - saved)
                    logger.info(
                        "Prefix cache hit for rid=%s: %d tokens reused, "
                        "budget adjusted by -%d (used=%d/%d)",
                        rid,
                        actual_prefix_len,
                        saved,
                        self._used_tokens,
                        self._max_total_tokens,
                    )

            # Append generated token(s); accept either a single int or a list.
            new_token_ids = out.get("output_token_ids", [])
            if isinstance(new_token_ids, int):
                new_token_ids = [new_token_ids]
            req.output_ids.extend(new_token_ids)
            req.seq_len += len(new_token_ids)

            # Update token budget for newly generated tokens
            self._used_tokens += len(new_token_ids)

            # Check finish conditions (EOS tokens already in stop_token_ids)
            req.check_finished()

        # Process batch requests based on forward mode
        if batch.forward_mode.is_extend():
            # Prefill batch: mark as prefilled and route
            for req in batch.reqs:
                req.is_prefilled = True
                if req.is_finished:
                    self._collect_finished_output(req)
                    self._model_runner._free_rid_resources(req.rid)
                    self._free_req_resources(req)
                else:
                    self._running_batch.append(req)

            # --- Accumulate prefill metrics ---
            # New prefill work = prompt tokens minus radix-cache hits.
            total_input = 0
            total_cached = 0
            for req in batch.reqs:
                total_input += req.prompt_len
                total_cached += req.prefix_len
            self._num_prefill_tokens += total_input - total_cached
            self._num_prefill_cache_tokens += total_cached
            self._num_prefill_reqs += len(batch.reqs)
            self._log_prefill_stats()
        else:
            # Decode batch: check finish and collect
            new_running: List[Req] = []
            for req in batch.reqs:
                if req.is_finished:
                    self._collect_finished_output(req)
                    self._model_runner._free_rid_resources(req.rid)
                    self._free_req_resources(req)
                else:
                    new_running.append(req)
            self._running_batch = new_running

            # --- Accumulate decode metrics ---
            self._num_decode_tokens += batch.batch_size  # 1 token per request
            self._forward_ct_decode += 1
            if (
                self._decode_log_interval > 0
                and self._forward_ct_decode % self._decode_log_interval == 0
            ):
                self._log_decode_stats()

    # ------------------------------------------------------------------
    # Step 6: stream output to DetokenizerProcess
    # ------------------------------------------------------------------

    def stream_output(self) -> None:
        """Send finished/streaming outputs to the DetokenizerProcess.

        Produces :class:`~pymllm.engine.io_struct.BatchTokenIDOutput`-compatible
        dicts. For streaming requests, intermediate tokens are also sent.
        """
        # Collect streaming outputs from running requests (skip aborted)
        for req in self._running_batch:
            if req.finished_reason == "abort":
                continue
            if req.stream and len(req.output_ids) > req.read_offset:
                # Only the tokens generated since the last flush are sent.
                decode_ids = req.output_ids[req.read_offset :]
                output = {
                    "rids": [req.rid],
                    "finished_reasons": [None],
                    "decode_ids": decode_ids,
                    "read_offsets": [req.read_offset],
                    "output_ids": list(req.output_ids),
                    "skip_special_tokens": [True],
                    "prompt_tokens": [req.prompt_len],
                    "completion_tokens": [len(req.output_ids)],
                }
                req.read_offset = len(req.output_ids)
                self._send_to_detokenizer.send_pyobj(output)

        # Send finished outputs
        while self._finished:
            item = self._finished.popleft()
            self._send_to_detokenizer.send_pyobj(item)

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _log_prefill_stats(self) -> None:
        """Log prefill throughput at INFO level (called after each prefill batch)."""
        now = time.time()
elapsed = now - self._last_prefill_stats_tic + self._last_prefill_stats_tic = now + + if elapsed > 0: + input_throughput = self._num_prefill_tokens / elapsed + else: + input_throughput = 0.0 + + logger.info( + "Prefill batch: %d reqs, " + "new tokens: %d, " + "cached tokens: %d, " + "input throughput: %.2f token/s", + self._num_prefill_reqs, + self._num_prefill_tokens, + self._num_prefill_cache_tokens, + input_throughput, + ) + # Reset accumulators + self._num_prefill_tokens = 0 + self._num_prefill_cache_tokens = 0 + self._num_prefill_reqs = 0 + + def _log_decode_stats(self) -> None: + """Log decode throughput at INFO level (called every decode_log_interval batches).""" + now = time.time() + elapsed = now - self._last_decode_stats_tic + self._last_decode_stats_tic = now + + if elapsed > 0: + gen_throughput = self._num_decode_tokens / elapsed + else: + gen_throughput = 0.0 + + logger.info( + "Decode: %d steps, " + "gen tokens: %d, " + "running: %d reqs, " + "gen throughput: %.2f token/s", + self._forward_ct_decode, + self._num_decode_tokens, + len(self._running_batch), + gen_throughput, + ) + # Reset accumulators + self._num_decode_tokens = 0 + self._forward_ct_decode = 0 + + def _collect_finished_output(self, req: Req) -> None: + """Build a finished output dict and add it to ``_finished``.""" + decode_ids = req.output_ids[req.read_offset :] + output: Dict[str, Any] = { + "rids": [req.rid], + "finished_reasons": [req.finished_reason], + "decode_ids": decode_ids, + "read_offsets": [req.read_offset], + "output_ids": list(req.output_ids), + "skip_special_tokens": [True], + "prompt_tokens": [req.prompt_len], + "completion_tokens": [len(req.output_ids)], + } + self._finished.append(output) + logger.debug( + "Request %s finished: reason=%s, tokens=%d", + req.rid, + req.finished_reason, + len(req.output_ids), + ) + + def _free_req_resources(self, req: Req) -> None: + """Release KV-cache token budget for a finished request. 
+ + The budget was charged as follows: + - At scheduling: ``+input_len`` (full prompt as conservative estimate) + - After prefix correction: ``-prefix_len`` (cached prefix doesn't need + new KV allocation; model runner manages those via radix cache) + - At each decode step: ``+1`` per generated token + + So the net charge for this request is: + ``(input_len - prefix_len) + num_decode_tokens`` + = ``seq_len - prefix_len`` + + We release exactly that amount. + """ + tokens_to_free = req.seq_len - req.prefix_len + self._used_tokens = max(0, self._used_tokens - tokens_to_free) + req.req_pool_idx = -1 + + def _abort_request(self, rid: str) -> None: + """Abort a request by rid from pending or running queues.""" + # Remove from pending queue + self._pending_queue = [r for r in self._pending_queue if r.rid != rid] + # Abort in running batch + for req in self._running_batch: + if req.rid == rid: + req.abort() + break + + # ------------------------------------------------------------------ + # Cleanup + # ------------------------------------------------------------------ + + def shutdown(self) -> None: + if self._model_runner is not None: + self._model_runner.shutdown() + for sock in ( + self._recv_from_tokenizer, + self._send_to_detokenizer, + ): + if sock is not None: + sock.close() + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + +def run_scheduler_process( + recv_from_tokenizer_addr: str, + send_to_detokenizer_addr: str, + pipe_writer: Connection, + shared_queue: Optional[TensorQueue] = None, + enable_shared_queue: bool = False, + tensor_transport_mode: TensorTransportMode = "default", + log_level: str = "info", + default_max_new_tokens: int = _DEFAULT_MAX_NEW_TOKENS, + eos_token_ids: Optional[List[int]] = None, + server_config: Optional[Any] = None, + model_config: Optional[Any] = None, + gpu_id: int = 0, +) -> None: + """Entry point for ``torch.multiprocessing.Process(target=...)``. 
+ + The scheduler process now also owns the model runner, + so model initialisation happens here. + """ + setup_subprocess_logging(log_level) + + # Extract scheduling limits from server_config (fall back to defaults) + max_running = _DEFAULT_MAX_RUNNING_REQUESTS + max_prefill = _DEFAULT_MAX_PREFILL_TOKENS + max_total = _DEFAULT_MAX_TOTAL_TOKENS + if server_config is not None: + if getattr(server_config, "max_running_requests", None) is not None: + max_running = server_config.max_running_requests + if getattr(server_config, "max_prefill_tokens", None) is not None: + max_prefill = server_config.max_prefill_tokens + if getattr(server_config, "max_total_tokens", None) is not None: + max_total = server_config.max_total_tokens + + proc = SchedulerProcess( + recv_from_tokenizer_addr, + send_to_detokenizer_addr, + server_config=server_config, + model_config=model_config, + gpu_id=gpu_id, + shared_queue=shared_queue, + enable_shared_queue=enable_shared_queue, + tensor_transport_mode=tensor_transport_mode, + max_running_requests=max_running, + max_prefill_tokens=max_prefill, + max_total_tokens=max_total, + default_max_new_tokens=default_max_new_tokens, + eos_token_ids=eos_token_ids, + ) + proc.init_sockets() + proc.init_model() + + pipe_writer.send({"status": "ready", "process": "scheduler"}) + pipe_writer.close() + + try: + proc.event_loop() + except KeyboardInterrupt: + pass + finally: + proc.shutdown() diff --git a/pymllm/orchestrator/shared_memory_queue.py b/pymllm/orchestrator/shared_memory_queue.py new file mode 100644 index 000000000..2f006bdc0 --- /dev/null +++ b/pymllm/orchestrator/shared_memory_queue.py @@ -0,0 +1,292 @@ +""" +Shared memory and queue utilities for fast IPC between tokenizer and scheduler. + +This module implements the shared-queue fast path to avoid expensive ZMQ +serialization of large multimodal tensors. + +## Design + +- **Metadata lane**: Small tokenized objects are written to a POSIX shared memory + segment keyed by the request ID (``rid``). 
The scheduler reads and immediately + unlinks the segment. + +- **Tensor lane**: Large tensors can be transported in one of three modes, + controlled by ``TensorTransportMode`` (passed at queue construction time): + + * ``"default"`` – CPU tensors only. GPU tensors are moved to POSIX shared + memory via ``tensor.share_memory_()`` (or left on CPU if already there). + This is the original behaviour and requires no CUDA support. + + * ``"cuda_ipc"`` – GPU tensors stay on GPU and are wrapped in + :class:`~pymllm.orchestrator.cuda_ipc_transport.TransportProxyTensor`. On the + receiver side the proxy's ``__setstate__`` automatically reconstructs the + tensor from the CUDA IPC handle during unpickling. CPU tensors are handled as + in ``"default"`` mode. **Caveat**: GPU memory is not freed until the sender + process exits (PyTorch limitation). Prefer ``"cuda_ipc_pool"`` for services. + + * ``"cuda_ipc_pool"`` – GPU tensors are copied into a pre-allocated + :class:`~pymllm.orchestrator.cuda_ipc_transport.MmItemMemoryPool` workspace and + wrapped in :class:`~pymllm.orchestrator.cuda_ipc_transport.CudaIpcTensorTransportProxy`. + After the receiver copies the data it increments a sync flag and the sender's + recycler thread returns the chunk to the pool. This avoids GPU memory leaks. + CPU tensors are handled as in ``"default"`` mode. + +## Key relationship with CUDA IPC + +``"default"`` and ``"cuda_ipc*"`` modes are **mutually exclusive for GPU tensors**: + +- In ``"default"`` mode, GPU tensors that need to cross process boundaries must + first be moved to CPU (``share_memory_()``). This incurs a GPU→CPU copy. +- In ``"cuda_ipc*"`` modes, GPU tensors are shared as-is via CUDA IPC handles; + no copy to CPU is needed. + +CPU tensors are always handled via ``share_memory_()`` regardless of the mode. 
+""" + +from __future__ import annotations + +import logging +import pickle +import uuid +from multiprocessing import Queue +from multiprocessing.shared_memory import SharedMemory +from typing import Any, Dict, Literal, Optional + +import torch + +from pymllm.orchestrator.cuda_ipc_transport import ( + MmItemMemoryPool, + TensorTransportMode, + unwrap_mm_inputs_from_ipc, + wrap_mm_inputs_for_ipc, +) + +logger = logging.getLogger(__name__) + + +class SharedMemoryManager: + """Manages shared memory segments for passing metadata between processes. + + Each tokenized request's metadata is written to a unique shared memory + segment keyed by its request ID (rid). The scheduler reads and immediately + unlinks the segment to prevent memory leaks. + """ + + @staticmethod + def write_metadata(rid: str, metadata: Any) -> str: + """Write metadata to shared memory and return the segment name. + + Args: + rid: Request ID (used as part of the shared memory name) + metadata: Serializable metadata object + + Returns: + str: The shared memory segment name + """ + data = pickle.dumps(metadata) + size = len(data) + shm_name = f"pymllm_meta_{rid}_{uuid.uuid4().hex[:8]}" + try: + shm = SharedMemory(name=shm_name, create=True, size=size) + shm.buf[:size] = data + shm.close() + logger.debug("Wrote %d bytes to shared memory %s", size, shm_name) + return shm_name + except Exception as exc: + logger.error("Failed to write metadata to shared memory: %s", exc) + raise + + @staticmethod + def read_metadata(shm_name: str, unlink: bool = True) -> Any: + """Read metadata from shared memory and optionally unlink it. 
+ + Args: + shm_name: The shared memory segment name + unlink: If True, immediately unlink the segment after reading + + Returns: + The deserialized metadata object + """ + try: + shm = SharedMemory(name=shm_name, create=False) + data = bytes(shm.buf[:]) + metadata = pickle.loads(data) + shm.close() + if unlink: + try: + shm.unlink() + logger.debug("Read and unlinked shared memory %s", shm_name) + except FileNotFoundError: + pass + return metadata + except Exception as exc: + logger.error( + "Failed to read metadata from shared memory %s: %s", shm_name, exc + ) + raise + + @staticmethod + def cleanup(shm_name: str) -> None: + """Manually cleanup a shared memory segment (for error recovery).""" + try: + shm = SharedMemory(name=shm_name, create=False) + shm.close() + shm.unlink() + logger.debug("Cleaned up shared memory %s", shm_name) + except FileNotFoundError: + pass + except Exception as exc: + logger.warning("Failed to cleanup shared memory %s: %s", shm_name, exc) + + +class TensorQueue: + """Queue for passing large tensors between processes. + + Depending on ``transport_mode``, GPU tensors are either moved to CPU shared + memory (``"default"``) or kept on GPU and shared via CUDA IPC handles + (``"cuda_ipc"`` / ``"cuda_ipc_pool"``). + + Args: + maxsize: Maximum queue size (0 for unlimited). + transport_mode: Controls how GPU tensors are transported. + pool: Required when ``transport_mode == "cuda_ipc_pool"``. + """ + + def __init__( + self, + maxsize: int = 0, + transport_mode: TensorTransportMode = "default", + pool: Optional[MmItemMemoryPool] = None, + ) -> None: + # pool is allowed to be None at construction time for "cuda_ipc_pool" mode + # because the pool is initialised lazily inside the sender subprocess. + # The pool reference is injected later via _pool attribute assignment. 
+ self._queue: Queue = Queue(maxsize=maxsize) + self._transport_mode = transport_mode + self._pool = pool + + # ------------------------------------------------------------------ + # Producer side + # ------------------------------------------------------------------ + + def put( + self, + rid: str, + shm_name: str, + mm_inputs: Optional[Dict[str, Any]], + ) -> None: + """Put a request into the queue. + + GPU tensors inside *mm_inputs* are wrapped according to + ``transport_mode`` before being placed into the underlying + ``multiprocessing.Queue``. + + Args: + rid: Request ID. + shm_name: Shared memory segment name for the tokenized metadata. + mm_inputs: Multimodal inputs dict (may contain CUDA tensors). + """ + if mm_inputs is not None: + if self._transport_mode in ("cuda_ipc", "cuda_ipc_pool"): + if self._transport_mode == "cuda_ipc_pool" and self._pool is None: + # Pool not yet initialised (race condition or CUDA unavailable); + # fall back to simple CUDA IPC for this message. + effective_mode = "cuda_ipc" + else: + effective_mode = self._transport_mode + # Wrap CUDA tensors in IPC proxies (stays on GPU, no copy to CPU) + mm_inputs = wrap_mm_inputs_for_ipc( + mm_inputs, + transport_mode=effective_mode, + pool=self._pool, + ) + # CPU tensors within mm_inputs are still shared via share_memory_() + mm_inputs = self._share_cpu_tensors(mm_inputs) + else: + # "default": move all tensors to CPU shared memory + mm_inputs = self._make_tensors_shareable(mm_inputs) + + self._queue.put((rid, shm_name, mm_inputs)) + logger.debug("Put request %s into tensor queue (shm=%s)", rid, shm_name) + + # ------------------------------------------------------------------ + # Consumer side + # ------------------------------------------------------------------ + + def get( + self, timeout: Optional[float] = None + ) -> tuple[str, str, Optional[Dict[str, Any]]]: + """Get a request from the queue. 
+ + GPU tensors wrapped as IPC proxies are **not** automatically + reconstructed here – the caller (scheduler) must call + :func:`~pymllm.orchestrator.cuda_ipc_transport.unwrap_mm_inputs_from_ipc` + after retrieval. + + Args: + timeout: Timeout in seconds (None for blocking). + + Returns: + Tuple of ``(rid, shm_name, mm_inputs)``. + """ + rid, shm_name, mm_inputs = self._queue.get(timeout=timeout) + logger.debug("Got request %s from tensor queue (shm=%s)", rid, shm_name) + return rid, shm_name, mm_inputs + + # ------------------------------------------------------------------ + # Queue introspection + # ------------------------------------------------------------------ + + def empty(self) -> bool: + return self._queue.empty() + + def qsize(self) -> int: + try: + return self._queue.qsize() + except NotImplementedError: + return 0 + + def close(self) -> None: + self._queue.close() + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + @staticmethod + def _make_tensors_shareable(data: Any) -> Any: + """Recursively move all tensors (CPU and CUDA) to POSIX shared memory. + + GPU tensors are first moved to CPU (incurring a device copy), then + placed in shared memory. This is the ``"default"`` path. + """ + if isinstance(data, torch.Tensor): + if data.is_cuda: + data = data.cpu() + if not data.is_shared(): + data = data.share_memory_() + return data + elif isinstance(data, dict): + return {k: TensorQueue._make_tensors_shareable(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [TensorQueue._make_tensors_shareable(item) for item in data] + return type(data)(result) + else: + return data + + @staticmethod + def _share_cpu_tensors(data: Any) -> Any: + """Recursively place CPU tensors in shared memory (GPU tensors are already + wrapped as IPC proxies and must not be touched here). 
+ """ + if isinstance(data, torch.Tensor) and not data.is_cuda: + if not data.is_shared(): + data = data.share_memory_() + return data + elif isinstance(data, dict): + return {k: TensorQueue._share_cpu_tensors(v) for k, v in data.items()} + elif isinstance(data, (list, tuple)): + result = [TensorQueue._share_cpu_tensors(item) for item in data] + return type(data)(result) + else: + return data diff --git a/pymllm/orchestrator/tokenizer_process.py b/pymllm/orchestrator/tokenizer_process.py new file mode 100644 index 000000000..44a4c897c --- /dev/null +++ b/pymllm/orchestrator/tokenizer_process.py @@ -0,0 +1,509 @@ +""" +TokenizerProcess -- subprocess that tokenizes incoming raw requests. + +Receives raw requests from RequestResponseProcess via ZMQ, tokenizes them, +and forwards the tokenized payloads to the SchedulerProcess. + +Supports two transport modes (controlled by ``enable_shared_queue`` and +``tensor_transport_mode`` in the tokenizer config): + +1. **Legacy ZMQ path** (``enable_shared_queue=False``): + Tokenized objects are sent directly via ``ZMQ send_pyobj`` (pickle). This + is simple but slow for large multimodal tensors. + +2. **Shared queue fast path** (``enable_shared_queue=True``): + Metadata is written to POSIX shared memory and the queue carries a + lightweight ``(rid, shm_name, mm_inputs)`` tuple. The GPU tensors inside + ``mm_inputs`` are transported differently depending on ``tensor_transport_mode``: + + * ``"default"`` – GPU tensors are moved to CPU first (GPU→CPU copy), + then placed in POSIX shared memory. + * ``"cuda_ipc"`` – GPU tensors stay on GPU; they are wrapped in a + :class:`TransportProxyTensor` whose pickle uses CUDA IPC handles. + Simple but may leak GPU memory. + * ``"cuda_ipc_pool"`` – GPU tensors are copied into a pre-allocated + :class:`MmItemMemoryPool` workspace and shared via pool-chunk IPC + handles. Chunks are recycled; no GPU memory is leaked. 
+""" + +import logging +from multiprocessing.connection import Connection +from typing import Any, Dict, List, Optional, Union + +import zmq +from transformers import AutoProcessor, AutoTokenizer + +from pymllm.engine.io_struct import TokenizedGenerateReqInput +from pymllm.orchestrator.cuda_ipc_transport import MmItemMemoryPool, TensorTransportMode +from pymllm.orchestrator.ipc_utils import create_zmq_socket, setup_subprocess_logging +from pymllm.orchestrator.shared_memory_queue import SharedMemoryManager, TensorQueue + +logger = logging.getLogger(__name__) + + +class TokenizerProcess: + """Runs inside a subprocess spawned by ``torch.multiprocessing``.""" + + def __init__( + self, + recv_from_rr_addr: str, + send_to_scheduler_addr: str, + tokenizer_cfg: Dict[str, Any], + shared_queue: Optional[TensorQueue] = None, + ): + """ + Parameters + ---------- + tokenizer_cfg: + Serialisable dict built by the parent process (``Engine``) before + spawning. Required keys: + + * ``tokenizer_path`` – str, path to the tokenizer directory. + * ``tokenizer_mode`` – ``"auto" | "slow" | "fast"``. + * ``trust_remote_code`` – bool. + * ``context_length`` – Optional[int], explicit cap; inferred + from ``hf_config`` when ``None``. + * ``hf_config`` – Optional HuggingFace PretrainedConfig. + * ``enable_shared_queue`` – bool, whether to use shared memory fast path. + * ``tensor_transport_mode`` – ``"default" | "cuda_ipc" | "cuda_ipc_pool"``. + * ``cuda_ipc_pool_size_mb`` – int, pool size in MB (cuda_ipc_pool only). + * ``cuda_ipc_recycle_interval`` – float, recycler sleep interval (s). + + shared_queue: + Optional :class:`TensorQueue` for the shared memory fast path. + When *transport_mode* is ``"cuda_ipc_pool"`` this queue should have + been constructed with a ``MmItemMemoryPool``; the ``TokenizerProcess`` + initialises its own pool in that case. 
+ """ + self._recv_from_rr_addr = recv_from_rr_addr + self._send_to_scheduler_addr = send_to_scheduler_addr + self._tokenizer_cfg = tokenizer_cfg + self._enable_shared_queue = tokenizer_cfg.get("enable_shared_queue", False) + self._shared_queue = shared_queue + + # Tensor transport configuration + self._transport_mode: TensorTransportMode = tokenizer_cfg.get( + "tensor_transport_mode", "default" + ) + # Pool for cuda_ipc_pool mode – will be initialised lazily when the + # process first encounters a CUDA tensor. + self._ipc_pool: Optional[MmItemMemoryPool] = None + if self._transport_mode == "cuda_ipc_pool": + # The pool must be created inside the subprocess (after fork/spawn) + # because it allocates CUDA memory. We defer to _ensure_pool(). + pool_mb: int = int(tokenizer_cfg.get("cuda_ipc_pool_size_mb", 512)) + recycle: float = float(tokenizer_cfg.get("cuda_ipc_recycle_interval", 0.1)) + self._ipc_pool_size_mb = pool_mb + self._ipc_recycle_interval = recycle + + self._zmq_ctx: Optional[zmq.Context] = None + self._recv_from_rr: Optional[zmq.Socket] = None + self._send_to_scheduler: Optional[zmq.Socket] = None + + self._tokenizer = None + self._mm_processor = None + self._context_length: Optional[int] = None + + self._init_tokenizers() + + # ------------------------------------------------------------------ + # Lifecycle + # ------------------------------------------------------------------ + + def init_sockets(self) -> None: + self._zmq_ctx = zmq.Context() + self._recv_from_rr = create_zmq_socket( + self._zmq_ctx, + zmq.PULL, + self._recv_from_rr_addr, + bind=False, + ) + self._send_to_scheduler = create_zmq_socket( + self._zmq_ctx, + zmq.PUSH, + self._send_to_scheduler_addr, + bind=True, + ) + + def event_loop(self) -> None: + """Infinite loop: recv raw request -> tokenize -> send to scheduler.""" + logger.info( + "TokenizerProcess event loop started (shared_queue=%s, transport=%s)", + self._enable_shared_queue, + self._transport_mode, + ) + while True: + 
raw_request: Dict[str, Any] = self._recv_from_rr.recv_pyobj() + tokenized = self._tokenize(raw_request) + + if self._enable_shared_queue and self._shared_queue is not None: + # Shared queue fast path + self._send_via_shared_queue(tokenized) + else: + # Legacy ZMQ path + self._send_to_scheduler.send_pyobj(tokenized) + + def _send_via_shared_queue( + self, tokenized: Union[TokenizedGenerateReqInput, Dict[str, Any]] + ) -> None: + """Send tokenized request via shared memory + shared queue fast path. + + GPU tensors inside ``mm_inputs`` are handled according to + ``self._transport_mode``: + + * ``"default"`` – moved to CPU via ``share_memory_()`` by ``TensorQueue``. + * ``"cuda_ipc"`` – wrapped in :class:`TransportProxyTensor` (stays on GPU). + * ``"cuda_ipc_pool"`` – copied into the :class:`MmItemMemoryPool` workspace and + wrapped in :class:`CudaIpcTensorTransportProxy`. + + Abort sentinel messages are forwarded via ZMQ (they are lightweight dicts). + """ + # Handle abort sentinel + if isinstance(tokenized, dict) and tokenized.get("abort"): + # Fallback to ZMQ for abort messages (no tensor payload) + self._send_to_scheduler.send_pyobj(tokenized) + return + + assert isinstance(tokenized, TokenizedGenerateReqInput), ( + f"Expected TokenizedGenerateReqInput, got {type(tokenized)}" + ) + + # Lazily initialise the CUDA IPC pool (must happen inside the subprocess) + if self._transport_mode == "cuda_ipc_pool": + self._ensure_pool() + + rid = tokenized.rid + mm_inputs = tokenized.mm_inputs + + # Create lightweight metadata object (mm_inputs sent separately via queue) + metadata = TokenizedGenerateReqInput( + rid=tokenized.rid, + input_text=tokenized.input_text, + input_ids=tokenized.input_ids, + mm_inputs=None, # Will be passed separately via shared queue + sampling_params=tokenized.sampling_params, + stream=tokenized.stream, + return_logprob=tokenized.return_logprob, + logprob_start_len=tokenized.logprob_start_len, + top_logprobs_num=tokenized.top_logprobs_num, + 
lora_path=tokenized.lora_path, + session_params=tokenized.session_params, + ) + + # Write metadata to shared memory + shm_name = SharedMemoryManager.write_metadata(rid, metadata) + + # Put (rid, shm_name, mm_inputs) into shared queue + # TensorQueue.put() handles wrapping mm_inputs based on transport_mode + self._shared_queue.put(rid, shm_name, mm_inputs) + + logger.debug( + "Sent request %s via shared queue (shm=%s, transport=%s)", + rid, + shm_name, + self._transport_mode, + ) + + # ------------------------------------------------------------------ + # CUDA IPC pool initialisation (deferred to subprocess) + # ------------------------------------------------------------------ + + def _ensure_pool(self) -> None: + """Lazily create the MmItemMemoryPool inside the subprocess. + + This is deferred because CUDA context creation must happen after + ``torch.multiprocessing.Process`` has started (post-fork/spawn). + Once the pool is created we update the shared queue's transport config + in-place so the same underlying ``multiprocessing.Queue`` object is reused + (both processes already hold a reference to it). + """ + if self._ipc_pool is not None: + return + try: + import torch + + if not torch.cuda.is_available(): + logger.warning( + "CUDA not available; falling back to transport_mode='default'" + ) + self._transport_mode = "default" + if self._shared_queue is not None: + self._shared_queue._transport_mode = "default" + return + + pool_bytes = self._ipc_pool_size_mb * 1024 * 1024 + device = torch.cuda.current_device() + self._ipc_pool = MmItemMemoryPool( + memory_size=pool_bytes, + recycle_interval=self._ipc_recycle_interval, + device=device, + ) + # Update the shared queue's config in-place. + # Both processes share the same multiprocessing.Queue object, so we + # just update the wrapper's transport metadata; the underlying queue + # pipe is unchanged. 
+ if self._shared_queue is not None: + self._shared_queue._transport_mode = self._transport_mode + self._shared_queue._pool = self._ipc_pool + + logger.info( + "MmItemMemoryPool initialised: %d MB on cuda:%d", + self._ipc_pool_size_mb, + device, + ) + except Exception as exc: + logger.error( + "Failed to initialise MmItemMemoryPool: %s; " + "falling back to transport_mode='default'", + exc, + exc_info=True, + ) + self._transport_mode = "default" + if self._shared_queue is not None: + self._shared_queue._transport_mode = "default" + + # ------------------------------------------------------------------ + # Tokenization and multimodal preprocessing + # ------------------------------------------------------------------ + + def _init_tokenizers(self) -> None: + """Initialise text tokenizer and (optionally) multimodal processor. + + All configuration is read from ``self._tokenizer_cfg`` which was + serialised by the parent process before ``spawn``. No global config + access happens inside the subprocess. + """ + cfg = self._tokenizer_cfg + tokenizer_path: str = cfg["tokenizer_path"] + tokenizer_mode: str = cfg.get("tokenizer_mode", "auto") + trust_remote_code: bool = bool(cfg.get("trust_remote_code", False)) + + tokenizer_kwargs: Dict[str, Any] = { + "use_fast": tokenizer_mode != "slow", + "trust_remote_code": trust_remote_code, + } + + self._tokenizer = AutoTokenizer.from_pretrained( + tokenizer_path, + **tokenizer_kwargs, + ) + + # Default to left padding for generation. + try: + self._tokenizer.padding_side = "left" + except Exception: + pass + + # Context length: explicit config value takes priority; fall back to + # common HF config field names. 
+ context_len: Optional[int] = cfg.get("context_length") + if context_len is None: + hf_cfg = cfg.get("hf_config") + for name in ("max_position_embeddings", "max_sequence_length", "seq_len"): + if hf_cfg is not None and hasattr(hf_cfg, name): + context_len = int(getattr(hf_cfg, name)) + break + self._context_length = context_len + + # Try to load multimodal processor (optional). + try: + self._mm_processor = AutoProcessor.from_pretrained( + tokenizer_path, + trust_remote_code=trust_remote_code, + ) + except Exception: + # Text-only models don't provide a processor; that's fine. + self._mm_processor = None + + def _tokenize( + self, raw_request: Dict[str, Any] + ) -> Union[TokenizedGenerateReqInput, Dict[str, Any]]: + """Tokenize one raw request dict and return a typed object. + + * **Abort** messages (``{"rid": ..., "abort": True}``) are returned as + plain dicts so the scheduler can intercept them without importing the + io_struct. + * Normal requests are returned as a :class:`TokenizedGenerateReqInput` + dataclass instance that carries ``input_ids``, ``mm_inputs``, and all + sampling meta-data in typed fields. + + Each message arriving here corresponds to exactly one sub-request + because batch splitting happens upstream in ``RequestResponseProcess``. + """ + # Abort: propagate as a plain sentinel dict. + if raw_request.get("abort"): + return {"rid": raw_request.get("rid"), "abort": True} + + # ------------------------------------------------------------------ # + # 1. Text tokenization + # ------------------------------------------------------------------ # + if raw_request.get("input_ids") is not None: + # Caller already tokenized – skip text processing. 
+ input_ids: List[int] = list(raw_request["input_ids"]) + raw_text = raw_request.get("text") + input_text: str = ( + str(raw_text[0]) if isinstance(raw_text, list) else str(raw_text or "") + ) + else: + text = raw_request.get("text") + if text is None: + raise ValueError( + "TokenizerProcess expects either `text` or `input_ids`." + ) + # Accept a list for robustness; take the first element. + input_text = str(text[0]) if isinstance(text, list) else str(text) + logger.debug(f"Tokenizing input text {input_text}") + + encode_kwargs: Dict[str, Any] = { + "add_special_tokens": True, + "return_attention_mask": False, + } + if self._context_length is not None: + encode_kwargs.update( + {"truncation": True, "max_length": self._context_length} + ) + + encoding = self._tokenizer(input_text, **encode_kwargs) + input_ids = encoding["input_ids"] + + # ------------------------------------------------------------------ # + # 2. Multimodal pre-processing + # ------------------------------------------------------------------ # + mm_inputs = self._collect_mm_inputs(raw_request, text=input_text) + + # ------------------------------------------------------------------ # + # 3. Pack into the typed dataclass + # ------------------------------------------------------------------ # + return TokenizedGenerateReqInput( + rid=raw_request.get("rid"), + input_text=input_text, + input_ids=input_ids, + mm_inputs=mm_inputs, + sampling_params=raw_request.get("sampling_params") or {}, + stream=bool(raw_request.get("stream", False)), + return_logprob=bool(raw_request.get("return_logprob", False)), + logprob_start_len=int(raw_request.get("logprob_start_len", -1)), + top_logprobs_num=int(raw_request.get("top_logprobs_num", 0)), + lora_path=raw_request.get("lora_path"), + session_params=raw_request.get("session_params"), + ) + + def _normalize_image_input(self, image_data: Any) -> List[Any]: + """Normalise ``image_data`` into a list of image-like objects. 
+ + Supported input forms: + - single PIL.Image / numpy array / torch.Tensor + - path string or bytes + - list/tuple of the above + """ + + def _to_image(obj: Any) -> Any: + # Lazily import Pillow to avoid hard dependency for text-only models. + try: + from PIL import Image # type: ignore + except Exception as exc: # pragma: no cover - optional dependency + raise RuntimeError( + "Pillow is required for image preprocessing in TokenizerProcess" + ) from exc + + if obj is None: + return None + if isinstance(obj, Image.Image): + return obj + if isinstance(obj, (str, bytes)): + return Image.open(obj) + return obj + + if isinstance(image_data, (list, tuple)): + return [ + img for img in (_to_image(x) for x in image_data) if img is not None + ] + return [img for img in (_to_image(image_data),) if img is not None] + + def _collect_mm_inputs( + self, raw_request: Dict[str, Any], text: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + """Pre-process multimodal data and return a consolidated ``mm_inputs`` dict. + + Returns ``None`` for text-only requests. Otherwise returns a flat dict + whose keys are ready to be unpacked by the model runner: + + * ``image_inputs`` – output of ``AutoProcessor`` (contains + ``pixel_values``, etc.) when a processor is available. + * ``image_data`` – raw image objects when no processor is available. + * ``audio_data`` – forwarded verbatim (no processor yet). + * ``video_data`` – forwarded verbatim (no processor yet). + """ + image_data = raw_request.get("image_data") + video_data = raw_request.get("video_data") + audio_data = raw_request.get("audio_data") + + if not any(x is not None for x in (image_data, video_data, audio_data)): + return None # text-only request + + mm: Dict[str, Any] = {} + + # Image: prefer AutoProcessor output; fall back to raw data. 
+ if image_data is not None: + if self._mm_processor is not None: + images = self._normalize_image_input(image_data) + try: + processor_inputs = self._mm_processor( + images=images, + text=text if text is not None else raw_request.get("text"), + return_tensors="pt", + ) + mm["image_inputs"] = processor_inputs + except Exception: + mm["image_data"] = image_data + else: + mm["image_data"] = image_data + + # Audio / video forwarded verbatim for now. + if audio_data is not None: + mm["audio_data"] = audio_data + if video_data is not None: + mm["video_data"] = video_data + + return mm + + def shutdown(self) -> None: + if self._ipc_pool is not None: + self._ipc_pool.shutdown() + if self._recv_from_rr is not None: + self._recv_from_rr.close() + if self._send_to_scheduler is not None: + self._send_to_scheduler.close() + if self._zmq_ctx is not None: + self._zmq_ctx.term() + + +def run_tokenizer_process( + recv_from_rr_addr: str, + send_to_scheduler_addr: str, + pipe_writer: Connection, + tokenizer_cfg: Dict[str, Any], + shared_queue: Optional[TensorQueue] = None, +) -> None: + """Entry point for ``torch.multiprocessing.Process(target=...)``.""" + setup_subprocess_logging(tokenizer_cfg.get("log_level", "info")) + + # Limit CPU threads — tokenizer doesn't need PyTorch parallelism. 
+ import torch + torch.set_num_threads(1) + + proc = TokenizerProcess( + recv_from_rr_addr, send_to_scheduler_addr, tokenizer_cfg, shared_queue + ) + proc.init_sockets() + + # Signal readiness to the parent process + pipe_writer.send({"status": "ready", "process": "tokenizer"}) + pipe_writer.close() + + try: + proc.event_loop() + except KeyboardInterrupt: + pass + finally: + proc.shutdown() diff --git a/pymllm/parsers/__init__.py b/pymllm/parsers/__init__.py new file mode 100644 index 000000000..5ac5c2922 --- /dev/null +++ b/pymllm/parsers/__init__.py @@ -0,0 +1,10 @@ +"""Output parsers for reasoning (thinking) content and tool calls.""" + +from pymllm.parsers.reasoning_parser import ReasoningParser +from pymllm.parsers.tool_call_parser import ToolCallParser, ToolCallItem + +__all__ = [ + "ReasoningParser", + "ToolCallParser", + "ToolCallItem", +] diff --git a/pymllm/parsers/reasoning_parser.py b/pymllm/parsers/reasoning_parser.py new file mode 100644 index 000000000..1f73c7885 --- /dev/null +++ b/pymllm/parsers/reasoning_parser.py @@ -0,0 +1,212 @@ +"""Reasoning / thinking content parser. + +Separates ``...`` (or model-specific markers) from normal +assistant content. Supports both one-shot and incremental streaming modes. 
+ +Usage:: + + # Non-streaming + parser = ReasoningParser("qwen3") + reasoning, content = parser.parse_non_stream(full_text) + + # Streaming + parser = ReasoningParser("qwen3") + for delta in deltas: + reasoning_delta, content_delta = parser.parse_stream_chunk(delta) +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Type + + +# --------------------------------------------------------------------------- +# Detector registry +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _DetectorConfig: + start: str + end: str + force: bool # True = always assume reasoning at start + + +_DETECTOR_MAP: Dict[str, _DetectorConfig] = { + # DeepSeek-R1: always starts in reasoning mode + "deepseek-r1": _DetectorConfig("", "", force=True), + # Qwen3: optional thinking (controlled by request) + "qwen3": _DetectorConfig("", "", force=False), + # Qwen3 forced thinking + "qwen3-thinking": _DetectorConfig("", "", force=True), + # GLM-4.5 + "glm45": _DetectorConfig("", "", force=False), + # Kimi + "kimi": _DetectorConfig("\u25c1think\u25b7", "\u25c1/think\u25b7", force=False), +} + + +# --------------------------------------------------------------------------- +# ReasoningParser +# --------------------------------------------------------------------------- + + +class ReasoningParser: + """Model-agnostic reasoning content parser. + + Parameters + ---------- + model_type + Key into the detector registry (e.g. ``"qwen3"``, ``"deepseek-r1"``). + stream_reasoning + If ``True``, stream reasoning content incrementally as it arrives. + If ``False``, buffer reasoning until the end tag is found. + """ + + SUPPORTED = set(_DETECTOR_MAP) + + def __init__(self, model_type: str, stream_reasoning: bool = True): + cfg = _DETECTOR_MAP.get(model_type) + if cfg is None: + raise ValueError( + f"Unknown reasoning parser {model_type!r}. 
" + f"Supported: {sorted(_DETECTOR_MAP)}" + ) + self._start = cfg.start + self._end = cfg.end + self._force = cfg.force + self._stream_reasoning = stream_reasoning + + # -- streaming state -- + self._buffer = "" + self._in_reasoning = cfg.force + self._start_consumed = False # True once start tag has been stripped + self._done = False # True once end tag has been seen + + # ------------------------------------------------------------------ # + # Non-streaming + # ------------------------------------------------------------------ # + + def parse_non_stream(self, text: str) -> Tuple[Optional[str], str]: + """Parse complete text. + + Returns ``(reasoning_content, content)`` where either may be empty. + """ + start_idx = text.find(self._start) + end_idx = text.find(self._end) + + if start_idx == -1 and not self._force: + return None, text + + # Determine boundaries + if self._force and start_idx == -1: + # Model didn't emit explicit start tag; treat prefix as reasoning + reason_start = 0 + else: + reason_start = start_idx + len(self._start) + + before = text[:start_idx] if start_idx != -1 else "" + + if end_idx != -1 and end_idx >= reason_start: + reasoning = text[reason_start:end_idx] + after = text[end_idx + len(self._end) :] + else: + reasoning = text[reason_start:] + after = "" + + content = (before + after).strip() + reasoning = reasoning.strip() + return reasoning or None, content + + # ------------------------------------------------------------------ # + # Streaming + # ------------------------------------------------------------------ # + + def parse_stream_chunk(self, delta: str) -> Tuple[str, str]: + """Parse an incremental streaming delta. + + Returns ``(reasoning_delta, content_delta)``. Either may be ``""``. 
+ """ + if not delta: + return "", "" + + if self._done: + return "", delta + + self._buffer += delta + reasoning_out = "" + content_out = "" + + # In forced reasoning mode, consume the start tag if it appears + # (the model may or may not emit it explicitly). + if self._in_reasoning and not self._start_consumed: + idx = self._buffer.find(self._start) + if idx != -1: + # Start tag found — strip it and any text before it + self._buffer = self._buffer[idx + len(self._start) :] + self._start_consumed = True + elif _could_be_partial(self._buffer, self._start): + # Might be a partial start tag — hold the buffer + return "", "" + else: + # No start tag coming — mark consumed and continue + self._start_consumed = True + + if not self._in_reasoning: + # --- look for start tag --- + idx = self._buffer.find(self._start) + if idx != -1: + content_out += self._buffer[:idx] + self._buffer = self._buffer[idx + len(self._start) :] + self._in_reasoning = True + self._start_consumed = True + elif _could_be_partial(self._buffer, self._start): + # Potential partial match at tail — hold the buffer + safe = len(self._buffer) - len(self._start) + 1 + if safe > 0: + content_out += self._buffer[:safe] + self._buffer = self._buffer[safe:] + return "", content_out + else: + content_out += self._buffer + self._buffer = "" + return "", content_out + + if self._in_reasoning: + # --- look for end tag --- + idx = self._buffer.find(self._end) + if idx != -1: + reasoning_out += self._buffer[:idx] + after = self._buffer[idx + len(self._end) :] + self._buffer = "" + self._in_reasoning = False + self._done = True + if after: + content_out += after + elif _could_be_partial(self._buffer, self._end): + safe = len(self._buffer) - len(self._end) + 1 + if safe > 0: + reasoning_out += self._buffer[:safe] + self._buffer = self._buffer[safe:] + else: + reasoning_out += self._buffer + self._buffer = "" + + if not self._stream_reasoning: + reasoning_out = "" + + return reasoning_out, content_out + + +# 
--------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _could_be_partial(text: str, pattern: str) -> bool: + """Return True if *text* ends with a prefix of *pattern*.""" + for i in range(1, len(pattern)): + if text.endswith(pattern[:i]): + return True + return False diff --git a/pymllm/parsers/tool_call_parser.py b/pymllm/parsers/tool_call_parser.py new file mode 100644 index 000000000..fdfe93914 --- /dev/null +++ b/pymllm/parsers/tool_call_parser.py @@ -0,0 +1,433 @@ +"""Tool-call (function-calling) output parser. + +Extracts structured tool calls from model output text. Supports both +one-shot and incremental streaming modes. + +Formats supported: + +* **qwen25** — ``{"name":...,"arguments":...}`` +* **llama3** — ``<|python_tag|>{"name":...,"parameters":...}`` +* **hermes** — ``{"name":...,"arguments":...}`` (same tags, Hermes schema) + +Usage:: + + # Non-streaming + parser = ToolCallParser("qwen25", tools=tools_list) + content, tool_calls = parser.parse_non_stream(full_text) + + # Streaming + parser = ToolCallParser("qwen25", tools=tools_list) + for delta in deltas: + content_delta, tool_call_deltas = parser.parse_stream_chunk(delta) +""" + +from __future__ import annotations + +import json +import re +import uuid +from dataclasses import dataclass, field +from typing import Any, Dict, List, Optional, Tuple + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class ToolCallItem: + """A single parsed tool call.""" + + name: Optional[str] = None + arguments: str = "" + tool_call_id: str = "" + index: int = 0 + + def to_openai_dict(self, streaming: bool = True) -> Dict[str, Any]: + """Convert to OpenAI ``tool_calls[]`` element format. 
+ + Parameters + ---------- + streaming + If True, include ``index`` (streaming delta format). + If False, omit ``index`` (non-streaming message format). + """ + d: Dict[str, Any] = {"type": "function", "function": {}} + if streaming: + d["index"] = self.index + if self.tool_call_id: + d["id"] = self.tool_call_id + fn: Dict[str, Any] = d["function"] + if self.name is not None: + fn["name"] = self.name + fn["arguments"] = self.arguments or "" + return d + + +# --------------------------------------------------------------------------- +# Detector base +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class _FormatConfig: + bot_token: str + end_token: str + # Regex to extract individual call bodies between bot/end tokens. + # If None, the entire text between bot and end tokens is one call. + call_regex: Optional[str] = None + + +_FORMAT_MAP: Dict[str, _FormatConfig] = { + "qwen25": _FormatConfig( + bot_token="\n", + end_token="\n", + ), + "qwen3_coder": _FormatConfig( + bot_token="", + end_token="", + ), + "hermes": _FormatConfig( + bot_token="\n", + end_token="\n", + ), + "llama3": _FormatConfig( + bot_token="<|python_tag|>", + end_token="", # Llama3 uses EOT, detected via EOS + ), +} + + +# --------------------------------------------------------------------------- +# ToolCallParser +# --------------------------------------------------------------------------- + + +class ToolCallParser: + """Model-agnostic tool-call parser. + + Parameters + ---------- + model_type + Key into the format registry (e.g. ``"qwen25"``, ``"llama3"``). + tools + The ``tools`` list from the OpenAI chat request (used to resolve + function names). + """ + + SUPPORTED = set(_FORMAT_MAP) + + def __init__(self, model_type: str, tools: Optional[List[Any]] = None): + cfg = _FORMAT_MAP.get(model_type) + if cfg is None: + raise ValueError( + f"Unknown tool-call parser {model_type!r}. 
" + f"Supported: {sorted(_FORMAT_MAP)}" + ) + self._bot = cfg.bot_token + self._end = cfg.end_token + self._model_type = model_type + self._tools = tools or [] + + # -- streaming state -- + self._buffer = "" + self._in_call = False + self._current_tool_idx = 0 + self._current_call_buf = "" + self._prev_args_len = 0 + self._name_sent = False + self._completed_calls: List[ToolCallItem] = [] + + # ------------------------------------------------------------------ # + # Non-streaming + # ------------------------------------------------------------------ # + + def has_tool_call(self, text: str) -> bool: + """Return True if *text* contains a tool-call pattern.""" + return self._bot in text + + def parse_non_stream( + self, text: str + ) -> Tuple[str, List[ToolCallItem]]: + """Parse complete text. + + Returns ``(remaining_content, tool_calls)``. + """ + if not self.has_tool_call(text): + return text, [] + + tool_calls: List[ToolCallItem] = [] + normal_parts: List[str] = [] + + remaining = text + idx = 0 + while True: + bot_pos = remaining.find(self._bot) + if bot_pos == -1: + normal_parts.append(remaining) + break + normal_parts.append(remaining[:bot_pos]) + remaining = remaining[bot_pos + len(self._bot) :] + + if self._end: + end_pos = remaining.find(self._end) + if end_pos == -1: + call_body = remaining + remaining = "" + else: + call_body = remaining[:end_pos] + remaining = remaining[end_pos + len(self._end) :] + else: + call_body = remaining + remaining = "" + + parsed = self._parse_call_body(call_body.strip()) + if parsed is not None: + parsed.index = idx + parsed.tool_call_id = _make_tool_call_id() + tool_calls.append(parsed) + idx += 1 + + content = "".join(normal_parts).strip() + return content, tool_calls + + # ------------------------------------------------------------------ # + # Streaming + # ------------------------------------------------------------------ # + + def parse_stream_chunk( + self, delta: str + ) -> Tuple[str, List[ToolCallItem]]: + """Parse an 
incremental streaming delta. + + Returns ``(content_delta, tool_call_items)``. + + For tool call items: + - First item for a call: ``name`` is set, ``arguments`` is ``""``. + - Subsequent items: ``name`` is ``None``, ``arguments`` is the new + characters appended (argument delta). + """ + if not delta: + return "", [] + + self._buffer += delta + content_out = "" + items: List[ToolCallItem] = [] + + while True: + if not self._in_call: + # --- look for bot token --- + bot_pos = self._buffer.find(self._bot) + if bot_pos != -1: + content_out += self._buffer[:bot_pos] + self._buffer = self._buffer[bot_pos + len(self._bot) :] + self._in_call = True + self._current_call_buf = "" + self._prev_args_len = 0 + self._name_sent = False + continue # try to process call content + else: + # Check for partial bot token at tail + if self._bot and _could_be_partial(self._buffer, self._bot): + safe = len(self._buffer) - len(self._bot) + 1 + if safe > 0: + content_out += self._buffer[:safe] + self._buffer = self._buffer[safe:] + else: + content_out += self._buffer + self._buffer = "" + break + + if self._in_call: + # --- look for end token --- + if self._end: + end_pos = self._buffer.find(self._end) + if end_pos != -1: + self._current_call_buf += self._buffer[:end_pos] + self._buffer = self._buffer[end_pos + len(self._end) :] + # Emit final tool call + item = self._finalize_call() + if item is not None: + items.append(item) + self._in_call = False + self._current_tool_idx += 1 + continue # there may be more calls + else: + # Accumulate and stream arguments + self._current_call_buf += self._buffer + self._buffer = "" + item = self._stream_partial_call() + if item is not None: + items.append(item) + break + else: + # No end token (e.g. 
Llama3) — accumulate everything
                    # until flush() is called at end of request.
                    self._current_call_buf += self._buffer
                    self._buffer = ""
                    item = self._stream_partial_call()
                    if item is not None:
                        items.append(item)
                    break

        return content_out, items

    def flush(self) -> List[ToolCallItem]:
        """Flush any remaining buffered tool call (call at request end)."""
        items: List[ToolCallItem] = []
        if self._in_call and self._current_call_buf.strip():
            item = self._finalize_call()
            if item is not None:
                items.append(item)
        self._in_call = False
        return items

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _parse_call_body(self, body: str) -> Optional[ToolCallItem]:
        """Parse a single call body (JSON or qwen3_coder XML-style).

        Returns ``None`` when the body is not valid JSON (non-qwen3_coder
        formats) or when the format-specific parse fails.
        """
        if self._model_type == "qwen3_coder":
            return self._parse_qwen3_coder_body(body)
        try:
            obj = json.loads(body)
        except json.JSONDecodeError:
            return None
        name = obj.get("name")
        # Some models emit "parameters" instead of "arguments".
        args = obj.get("arguments") or obj.get("parameters") or {}
        if isinstance(args, dict):
            # Normalize to a JSON string, preserving non-ASCII characters.
            args = json.dumps(args, ensure_ascii=False)
        return ToolCallItem(name=name, arguments=args)

    @staticmethod
    def _parse_qwen3_coder_body(body: str) -> Optional[ToolCallItem]:
        """Parse a qwen3_coder XML-style call body into a ``ToolCallItem``.

        The format is a function tag naming the tool followed by parameter
        tags carrying argument values — presumably
        ``<function=name><parameter=key>value</parameter>...`` (TODO confirm;
        the patterns below look corrupted, see notes).
        """
        # Extract function name
        # NOTE(review): r"]+)>" is not a valid regex (unbalanced group) and
        # will raise re.error at runtime. It was presumably
        # r"<function=([^>]+)>" with the angle-bracket portion stripped by a
        # tooling pass — restore before shipping.
        func_m = re.search(r"]+)>", body)
        if func_m is None:
            return None
        name = func_m.group(1)
        # Extract parameters
        # NOTE(review): same corruption as above — presumably
        # r"<parameter=([^>]+)>(.*?)(?:</parameter>|(?=<parameter=))".
        params: Dict[str, Any] = {}
        for pm in re.finditer(
            r"]+)>(.*?)(?:|(?=))",
            body,
            re.DOTALL,
        ):
            key = pm.group(1)
            val = pm.group(2).strip()
            # Try to parse as JSON value, otherwise keep as string
            try:
                params[key] = json.loads(val)
            except (json.JSONDecodeError, ValueError):
                params[key] = val
        return ToolCallItem(
            name=name,
            arguments=json.dumps(params, ensure_ascii=False),
        )

    def _stream_partial_call(self) -> Optional[ToolCallItem]:
        """Try to extract streaming information from the
partial call."""
        body = self._current_call_buf.strip()
        if not body:
            return None

        # Try to extract name first — the name item must be emitted before
        # any argument deltas for this call.
        if not self._name_sent:
            name = self._try_extract_name(body)
            if name is not None:
                self._name_sent = True
                return ToolCallItem(
                    name=name,
                    arguments="",
                    tool_call_id=_make_tool_call_id(),
                    index=self._current_tool_idx,
                )
            return None

        # Stream argument characters: emit only the suffix that is new
        # since the previous delta.
        args_str = self._try_extract_args_partial(body)
        if args_str is not None and len(args_str) > self._prev_args_len:
            new_chars = args_str[self._prev_args_len :]
            self._prev_args_len = len(args_str)
            return ToolCallItem(
                name=None,
                arguments=new_chars,
                index=self._current_tool_idx,
            )
        return None

    def _finalize_call(self) -> Optional[ToolCallItem]:
        """Finalize a complete call — emit any remaining argument chars."""
        parsed = self._parse_call_body(self._current_call_buf.strip())
        if parsed is None:
            return None

        if not self._name_sent:
            # Entire call came at once — emit a single complete item.
            parsed.index = self._current_tool_idx
            parsed.tool_call_id = _make_tool_call_id()
            return parsed

        # Name was already sent — emit only the not-yet-streamed argument
        # suffix so the client receives each character exactly once.
        full_args = parsed.arguments
        new_chars = full_args[self._prev_args_len :]
        if new_chars:
            return ToolCallItem(
                name=None,
                arguments=new_chars,
                index=self._current_tool_idx,
            )
        return None

    def _try_extract_name(self, partial: str) -> Optional[str]:
        """Try to extract function name from partial call body."""
        if self._model_type == "qwen3_coder":
            # NOTE(review): r"]+)>" is an invalid regex (unbalanced group);
            # presumably r"<function=([^>]+)>" with brackets stripped —
            # re.search will raise re.error as written.
            m = re.search(r"]+)>", partial)
            return m.group(1) if m else None
        m = re.search(r'"name"\s*:\s*"([^"]+)"', partial)
        return m.group(1) if m else None

    def _try_extract_args_partial(self, partial: str) -> Optional[str]:
        """Try to extract partial arguments from call body."""
        if self._model_type == "qwen3_coder":
            # Build JSON incrementally from parameter tags.
            # NOTE(review): pattern corrupted like the others — presumably
            # r"<parameter=([^>]+)>(.*?)(?:</parameter>)".
            params: Dict[str, Any] = {}
            for pm in re.finditer(
                r"]+)>(.*?)(?:)",
                partial,
                re.DOTALL,
            ):
                key =
pm.group(1)
                val = pm.group(2).strip()
                try:
                    params[key] = json.loads(val)
                except (json.JSONDecodeError, ValueError):
                    params[key] = val
            if params:
                return json.dumps(params, ensure_ascii=False)
            return None
        # JSON formats: return everything from the opening brace of the
        # arguments object onward (possibly still incomplete JSON).
        m = re.search(r'"arguments"\s*:\s*(\{.*)', partial, re.DOTALL)
        if m:
            return m.group(1)
        m = re.search(r'"parameters"\s*:\s*(\{.*)', partial, re.DOTALL)
        if m:
            return m.group(1)
        return None


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_tool_call_id() -> str:
    # OpenAI-style id: "call_" followed by 24 hex characters.
    return f"call_{uuid.uuid4().hex[:24]}"


def _could_be_partial(text: str, pattern: str) -> bool:
    # True if some non-empty proper prefix of *pattern* is a suffix of
    # *text*, i.e. the tail of *text* might still grow into *pattern* once
    # more streamed characters arrive.
    for i in range(1, len(pattern)):
        if text.endswith(pattern[:i]):
            return True
    return False
diff --git a/pymllm/quantization/QUANTIZATION.md b/pymllm/quantization/QUANTIZATION.md
new file mode 100644
index 000000000..edb49dcc7
--- /dev/null
+++ b/pymllm/quantization/QUANTIZATION.md
@@ -0,0 +1,325 @@
# pymllm Quantization Guide

## Architecture

pymllm uses a **plugin-based** quantization system. Each quantization
algorithm (AWQ, GPTQ, FP8, W8A8, ...) is a self-contained plugin that
implements three methods: **create weights**, **apply** (forward), and
**process weights after loading**.

```
        QuantizationConfig
        (parses checkpoint)
               │
               │ get_quant_method(layer, prefix)
               ▼
┌─────────────────────────────────────────────────────┐
│ LinearMethodBase                                    │
│                                                     │
│  create_weights()  ← called during layer __init__   │
│  apply()           ← called during layer forward    │
│  process_weights_after_loading()  ← called once     │
│                      after checkpoint is loaded     │
└─────────────────────────────────────────────────────┘
               │
               │ registered on layer as
               │ layer.quant_method
               ▼
    Linear / ColumnParallelLinear / ...
+``` + +### Key modules + +| Module | Purpose | +|--------|---------| +| `pymllm.layers.quantize_base` | `QuantizeMethodBase`, `LinearMethodBase`, `UnquantizedLinearMethod` | +| `pymllm.quantization.quant_config` | `QuantizationConfig` base class, registry, factory | +| `pymllm.quantization.methods/` | Concrete implementations (AWQ, GPTQ, FP8, ...) | + +## Lifecycle + +### 1. Model construction + +Each linear layer accepts an optional `quant_method` argument. If `None`, +`UnquantizedLinearMethod` is used (standard FP weight + `F.linear`). + +```python +from pymllm.layers.linear import ColumnParallelLinear + +# No quantization (default) +layer = ColumnParallelLinear(4096, 4096) + +# With quantization +layer = ColumnParallelLinear(4096, 4096, quant_method=my_quant_method) +``` + +During `__init__`, the layer calls: + +```python +self.quant_method.create_weights( + layer=self, + input_size_per_partition=in_features, + output_partition_sizes=[out_features_per_partition], + input_size=in_features, + output_size=out_features, + params_dtype=torch.get_default_dtype(), + weight_loader=self.weight_loader, +) +``` + +This registers the appropriate parameters on the layer. For unquantized +layers, this is a single `weight` parameter. For AWQ, it might be +`qweight` (packed int32), `scales` (fp16), and `qzeros` (packed int32). + +### 2. Weight loading + +The standard `model.load_weights(iter)` loop loads checkpoint tensors into +the parameters created above, using each parameter's `weight_loader` +attribute for tensor-parallel sharding. + +### 3. 
Post-load processing + +After all weights are loaded, `ModelRunner` calls: + +```python +for name, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) +``` + +This is where format conversions happen: +- **AWQ**: repack AutoAWQ int4 layout → Marlin kernel layout +- **GPTQ**: shuffle weights according to `g_idx` for exllama kernels +- **FP8**: quantize FP16 weights to FP8 and compute per-tensor scales + +### 4. Inference + +Every forward call goes through `quant_method.apply()`: + +```python +# Inside ColumnParallelLinear.forward(): +output = self.quant_method.apply(self, x, self.bias) +``` + +For unquantized layers this is just `F.linear`. For quantized layers it +invokes a fused dequant+matmul kernel. + +## Server Usage + +### CLI flag + +```bash +python -m pymllm.server --model_path /path/to/model --quantization.method awq_marlin +``` + +### Auto-detection + +If `--quantization.method` is not specified, pymllm probes the checkpoint +directory for `quantize_config.json` (or the `quantization_config` section +of `config.json`). When found, the `quant_method` field is used +automatically. + +### Auto-upgrade: awq → awq_marlin + +On Ampere+ GPUs (SM80+), `"awq"` is automatically upgraded to +`"awq_marlin"` for significantly faster inference via the Marlin GEMM +kernel. No user action required. 
+ +### Supported models + +| Model | Status | +|-------|--------| +| Qwen3VL (`Qwen3VLForConditionalGeneration`) | Supported | +| Qwen3.5 (`Qwen3_5ForConditionalGeneration`, `Qwen3_5ForCausalLM`) | Supported | + +### End-to-end pipeline + +``` +CLI: --quantization.method awq_marlin (or auto-detected) + │ + ▼ +ModelRunner._resolve_quant_config() + reads quantize_config.json / config.json + auto-upgrades "awq" → "awq_marlin" on SM80+ + │ + ▼ +model_cls(hf_config, quant_config=AWQMarlinConfig(...)) + │ + ▼ propagates quant_config through sub-modules +Qwen3VLForConditionalGeneration → Qwen3VLTextModel → decoder layers + │ + ▼ each Linear() call gets quant_method +quant_config.get_quant_method(layer, prefix) → AWQMarlinLinearMethod + │ + ▼ Linear.__init__ calls quant_method.create_weights() +registers qweight, scales, qzeros (instead of weight) + │ + ▼ model.load_weights() loads checkpoint tensors + │ + ▼ process_weights_after_loading() +repacks AWQ int4 → Marlin kernel layout + │ + ▼ inference via quant_method.apply() +calls gptq_marlin_gemm kernel +``` + +### Notes + +- Vision encoder is **never quantized** — only text decoder layers +- Fused QKV and gate_up projections are automatically split into separate + projections when quantized (AWQ checkpoints store them separately) +- Embedding, layer norms, and lm_head remain in full precision + +--- + +## How to add a new quantization method + +### Step 1: Implement `LinearMethodBase` + +Create a file in `pymllm/quantization/methods/`, e.g. `awq.py`: + +```python +from pymllm.layers.quantize_base import LinearMethodBase +from pymllm.layers.utils import set_weight_attrs + +class AWQLinearMethod(LinearMethodBase): + \"\"\"AWQ W4A16 quantized linear method.\"\"\" + + def __init__(self, weight_bits: int, group_size: int, zero_point: bool): + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + self.pack_factor = 32 // weight_bits # e.g. 
8 for 4-bit + + def create_weights( + self, layer, input_size_per_partition, output_partition_sizes, + input_size, output_size, params_dtype, **extra_weight_attrs, + ): + output_size_per_partition = sum(output_partition_sizes) + + # Packed 4-bit weights: each int32 holds 8 x 4-bit values + qweight = Parameter( + torch.empty( + input_size_per_partition, + output_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + set_weight_attrs(qweight, {"input_dim": 0, "output_dim": 1}) + layer.register_parameter("qweight", qweight) + set_weight_attrs(qweight, extra_weight_attrs) + + # Per-group scales + scales = Parameter( + torch.empty( + input_size_per_partition // self.group_size, + output_size_per_partition, + dtype=params_dtype, + ), + requires_grad=False, + ) + layer.register_parameter("scales", scales) + set_weight_attrs(scales, extra_weight_attrs) + + # Per-group zero-points (packed) + qzeros = Parameter( + torch.empty( + input_size_per_partition // self.group_size, + output_size_per_partition // self.pack_factor, + dtype=torch.int32, + ), + requires_grad=False, + ) + layer.register_parameter("qzeros", qzeros) + set_weight_attrs(qzeros, extra_weight_attrs) + + def apply(self, layer, x, bias=None): + # Dequantize and compute matmul + # In practice, call a fused CUDA kernel here + out = awq_dequantize_and_gemm(x, layer.qweight, layer.scales, layer.qzeros) + if bias is not None: + out = out + bias + return out + + def process_weights_after_loading(self, layer): + # Optional: repack weights for a faster kernel layout + # e.g. 
convert AutoAWQ format → Marlin format + layer.qweight = Parameter(layer.qweight.data, requires_grad=False) + layer.scales = Parameter(layer.scales.data, requires_grad=False) + layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False) +``` + +### Step 2: Implement `QuantizationConfig` + +```python +from pymllm.quantization.quant_config import QuantizationConfig, register_quantization + +@register_quantization("awq") +class AWQConfig(QuantizationConfig): + def __init__(self, weight_bits, group_size, zero_point): + self.weight_bits = weight_bits + self.group_size = group_size + self.zero_point = zero_point + + def get_name(self) -> str: + return "awq" + + @classmethod + def from_config(cls, config: dict) -> "AWQConfig": + return cls( + weight_bits=config["bits"], + group_size=config["group_size"], + zero_point=config["zero_point"], + ) + + def get_quant_method(self, layer, prefix=""): + # Skip quantization for certain layers if needed + # if "lm_head" in prefix: + # return None + return AWQLinearMethod(self.weight_bits, self.group_size, self.zero_point) +``` + +### Step 3: Use it + +```python +from pymllm.quantization import get_quantization_config + +# Parse from checkpoint config +ConfigClass = get_quantization_config("awq") +config = ConfigClass.from_config({"bits": 4, "group_size": 128, "zero_point": True}) + +# Create layer with quantization +quant_method = config.get_quant_method(layer=None, prefix="model.layers.0.q_proj") +layer = ColumnParallelLinear(4096, 4096, quant_method=quant_method) +``` + +## API Reference + +### `QuantizeMethodBase` + +| Method | When called | Purpose | +|--------|-------------|---------| +| `create_weights(layer, ...)` | `layer.__init__` | Register parameters (weight, scales, etc.) 
on the layer | +| `apply(layer, x, bias)` | `layer.forward` | Quantized matmul computation | +| `process_weights_after_loading(layer)` | After `load_weights` | Repack / transform loaded checkpoint tensors | + +### `QuantizationConfig` + +| Method | Purpose | +|--------|---------| +| `get_name()` | Return method name (e.g. `"awq"`) | +| `from_config(config_dict)` | Class method: parse checkpoint JSON into config instance | +| `get_quant_method(layer, prefix)` | Return `LinearMethodBase` for a specific layer | +| `get_supported_act_dtypes()` | Activation dtypes this method supports | +| `get_min_capability()` | Minimum CUDA compute capability | +| `get_config_filenames()` | Checkpoint files to probe (default: `["quantize_config.json"]`) | + +### Registry functions + +| Function | Purpose | +|----------|---------| +| `@register_quantization("name")` | Decorator to register a config class | +| `get_quantization_config("name")` | Look up registered config class by name | +| `list_quantization_methods()` | List all registered method names | diff --git a/pymllm/quantization/__init__.py b/pymllm/quantization/__init__.py new file mode 100644 index 000000000..e4bf77025 --- /dev/null +++ b/pymllm/quantization/__init__.py @@ -0,0 +1,18 @@ +"""Quantization infrastructure for pymllm.""" + +from pymllm.quantization.quant_config import ( + QuantizationConfig, + get_quantization_config, + list_quantization_methods, + register_quantization, +) + +# Import methods module to trigger @register_quantization decorators +import pymllm.quantization.methods # noqa: F401 + +__all__ = [ + "QuantizationConfig", + "get_quantization_config", + "list_quantization_methods", + "register_quantization", +] diff --git a/pymllm/quantization/methods/__init__.py b/pymllm/quantization/methods/__init__.py new file mode 100644 index 000000000..90367f741 --- /dev/null +++ b/pymllm/quantization/methods/__init__.py @@ -0,0 +1,15 @@ +"""Quantization method implementations. 
+ +Importing this module triggers registration of all built-in quantization +methods via the ``@register_quantization`` decorator. +""" + +from pymllm.quantization.methods.awq_marlin import ( + AWQMarlinConfig, + AWQMarlinLinearMethod, +) + +__all__ = [ + "AWQMarlinConfig", + "AWQMarlinLinearMethod", +] diff --git a/pymllm/quantization/methods/awq_marlin.py b/pymllm/quantization/methods/awq_marlin.py new file mode 100644 index 000000000..e8f929aa3 --- /dev/null +++ b/pymllm/quantization/methods/awq_marlin.py @@ -0,0 +1,524 @@ +"""AWQ quantization with Marlin kernel acceleration. + +This module implements the AWQ Marlin quantization plugin for pymllm, +providing high-performance W4A16 inference via the Marlin GEMM kernel. + +Classes +------- +AWQMarlinConfig + Quantization configuration parsed from ``quantize_config.json``. +AWQMarlinLinearMethod + Linear method that uses AWQ weight format with Marlin kernel dispatch. +""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Optional + +import numpy +import torch +from torch.nn import Parameter + +from pymllm.layers.quantize_base import LinearMethodBase +from pymllm.layers.utils import set_weight_attrs +from pymllm.quantization.quant_config import QuantizationConfig, register_quantization + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Marlin constants +# --------------------------------------------------------------------------- + +MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128] +GPTQ_MARLIN_MIN_THREAD_N = 64 +GPTQ_MARLIN_MIN_THREAD_K = 128 +GPTQ_MARLIN_TILE = 16 + + +# --------------------------------------------------------------------------- +# ScalarType helpers (matching host::ScalarType in scalar_type.hpp) +# --------------------------------------------------------------------------- + +class _ScalarTypeInfo: + """Lightweight Python mirror of host::ScalarType for type id computation.""" + + def 
__init__(self, name: str, size_bits: int, type_id: int):
        # name: human-readable type name; size_bits: storage width in bits;
        # type_id: integer id matching the C++ ScalarType::Id encoding.
        self.name = name
        self.size_bits = size_bits
        self.id = type_id

    def __repr__(self) -> str:
        return f"ScalarType({self.name})"

    def __eq__(self, other: object) -> bool:
        # Identity is fully determined by the encoded id.
        if isinstance(other, _ScalarTypeInfo):
            return self.id == other.id
        return NotImplemented

    def __hash__(self) -> int:
        return hash(self.id)


def _compute_scalar_type_id(
    exponent: int, mantissa: int, signed: bool, bias: int,
    finite_values_only: bool = False, nan_repr: int = 1,
) -> int:
    """Compute the ScalarType::Id matching the C++ implementation.

    Packs the fields into one integer, LSB first, with fixed widths:
    exponent(8) | mantissa(8) | signed(1) | bias(32) |
    finite_values_only(1) | nan_repr(8).
    """
    bit_offset = 0
    result = 0

    for value, width in [
        (exponent, 8),
        (mantissa, 8),
        (signed, 1),
        (bias, 32),
        (finite_values_only, 1),
        (nan_repr, 8),
    ]:
        # Mask each field to its width so out-of-range inputs cannot bleed
        # into neighbouring fields.
        int_val = int(value)
        mask = (1 << width) - 1
        result |= (int_val & mask) << bit_offset
        bit_offset += width

    return result


# Pre-compute the scalar type ids we need
_uint4_id = _compute_scalar_type_id(0, 4, False, 0)
_uint8_id = _compute_scalar_type_id(0, 8, False, 0)
# Biased (b8 / b128) variants — not referenced in this module; presumably
# kept for GPTQ-style types. TODO confirm they are used elsewhere.
_uint4b8_id = _compute_scalar_type_id(0, 4, False, 8)
_uint8b128_id = _compute_scalar_type_id(0, 8, False, 128)

SCALAR_TYPE_UINT4 = _ScalarTypeInfo("uint4", 4, _uint4_id)
SCALAR_TYPE_UINT8 = _ScalarTypeInfo("uint8", 8, _uint8_id)


# num_bits -> ScalarType mapping
_TYPE_MAP: Dict[int, _ScalarTypeInfo] = {
    4: SCALAR_TYPE_UINT4,
    8: SCALAR_TYPE_UINT8,
}


# ---------------------------------------------------------------------------
# Marlin utility functions
# ---------------------------------------------------------------------------

def verify_marlin_supported(
    quant_type: _ScalarTypeInfo, group_size: int, has_zp: bool
) -> None:
    """Verify that the Marlin kernel supports this configuration.

    Raises ``ValueError`` if the GPU or group size is unsupported.
    Note: ``quant_type`` and ``has_zp`` are currently not inspected —
    presumably kept for signature parity; TODO confirm.
    """
    major, minor = torch.cuda.get_device_capability()
    capability = major * 10 + minor
    if capability < 80:
        raise ValueError(
            f"Marlin requires SM80+ (Ampere or newer). Got SM{capability}."
        )
    if group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
        raise ValueError(
            f"Marlin does not support group_size={group_size}. "
            f"Supported: {MARLIN_SUPPORTED_GROUP_SIZES}"
        )


def verify_marlin_supports_shape(
    output_size_per_partition: int,
    input_size_per_partition: int,
    input_size: int,
    group_size: int,
) -> None:
    """Verify that tensor dimensions are compatible with Marlin.

    The kernel tiles N in multiples of ``GPTQ_MARLIN_MIN_THREAD_N`` and K
    in multiples of ``GPTQ_MARLIN_MIN_THREAD_K``; group quantization
    additionally requires the K partition to hold whole groups.

    Raises ``ValueError`` on any violation.
    """
    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
        raise ValueError(
            f"output_size_per_partition={output_size_per_partition} is not "
            f"divisible by min_thread_n={GPTQ_MARLIN_MIN_THREAD_N}."
        )
    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
        raise ValueError(
            f"input_size_per_partition={input_size_per_partition} is not "
            f"divisible by min_thread_k={GPTQ_MARLIN_MIN_THREAD_K}."
        )
    # group_size >= input_size means one group spans all of K, so no
    # divisibility constraint applies.
    if group_size < input_size and input_size_per_partition % group_size != 0:
        raise ValueError(
            f"input_size_per_partition={input_size_per_partition} is not "
            f"divisible by group_size={group_size}."
        )


def marlin_make_workspace(device: torch.device) -> torch.Tensor:
    """Create Marlin workspace buffer for threadblock synchronization."""
    # One int32 slot per SM, zero-initialized.
    sms = torch.cuda.get_device_properties(device).multi_processor_count
    return torch.zeros(sms, dtype=torch.int, device=device, requires_grad=False)


def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
    """Create empty g_idx tensor (AWQ doesn't use activation reordering)."""
    return torch.nn.Parameter(
        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
    )


def get_scale_perms():
    """Get the scale permutation indices for Marlin format.

    Returns ``(scale_perm, scale_perm_single)``: the 64-entry permutation
    used for grouped scales and the 32-entry permutation used when a single
    group spans all of K.
    """
    scale_perm: list[int] = []
    for i in range(8):
        scale_perm.extend([i + 8 * j for j in range(8)])
    scale_perm_single: list[int] = []
    for i in range(4):
        scale_perm_single.extend(
            [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
        )
    return scale_perm, scale_perm_single


def marlin_permute_scales(
    s: torch.Tensor, size_k: int, size_n: int, group_size: int
) -> torch.Tensor:
    """Permute quantization scales from standard to Marlin layout."""
    scale_perm, scale_perm_single = get_scale_perms()
    # Grouped scales use the full permutation; per-column scales (one group
    # covering all of K) use the single-group permutation.
    if group_size < size_k and group_size != -1:
        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
    else:
        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
    s = s.reshape((-1, size_n)).contiguous()
    return s


def pack_cols(
    q_w: torch.Tensor, num_bits: int, size_k: int, size_n: int
) -> torch.Tensor:
    """Pack quantized columns into int32 values.

    Every ``32 // num_bits`` adjacent columns are packed into one int32
    column, lowest-index column in the least significant bits.
    """
    pack_factor = 32 // num_bits
    assert size_n % pack_factor == 0
    out = torch.zeros(
        (size_k, size_n // pack_factor), dtype=torch.int32, device=q_w.device
    )
    for i in range(pack_factor):
        out.bitwise_or_(q_w[:, i::pack_factor].int() << (num_bits * i))
    return out


def unpack_cols(
    packed: torch.Tensor, num_bits: int, size_k: int, size_n: int
) -> torch.Tensor:
    """Unpack int32 packed columns into individual quantized values.

    Inverse of :func:`pack_cols`.
    """
    pack_factor = 32 // num_bits
    mask = (1 << num_bits) - 1
    out = torch.zeros(
        (size_k, size_n), dtype=torch.int32, device=packed.device
    )
    for i in range(pack_factor):
        out[:, i::pack_factor] = (packed >> (num_bits * i)) & mask
    return out


def marlin_zero_points(
    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Permute and pack zero points into Marlin format."""
    # Zero points share the grouped-scale permutation.
    scale_perm, _ = get_scale_perms()
    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]

    # Interleave column dim (for the dequantize code) and pack to int32
    if num_bits == 4:
        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
    elif num_bits == 8:
        interleave = numpy.array([0, 2, 1, 3])
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
    zp = zp.reshape((-1, size_n)).contiguous()
    zp = pack_cols(zp, num_bits, size_k, size_n)
    return zp


def awq_to_marlin_zero_points(
    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
) -> torch.Tensor:
    """Convert AWQ-format zero points to Marlin format.

    AWQ zero-points are quantized and packed on the column dim with a specific
    interleaving. This function undoes the AWQ interleaving, then applies
    Marlin permutation and repacks.
    """
    # Unpack to one value per column so the interleaving can be reversed.
    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)

    # Undo AWQ interleaving (argsort of the forward interleave gives the
    # inverse permutation).
    if num_bits == 4:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
    elif num_bits == 8:
        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
    else:
        raise ValueError(f"num_bits must be 4 or 8, got {num_bits}")

    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
    q_zp = q_zp.reshape((-1, size_n)).contiguous()

    return marlin_zero_points(q_zp, size_k, size_n, num_bits)


def replace_parameter(
    layer: torch.nn.Module, name: str, new_data: torch.Tensor
) -> None:
    """Replace a parameter on a layer with new data."""
    param = torch.nn.Parameter(new_data, requires_grad=False)
    layer.register_parameter(name, param)


# ---------------------------------------------------------------------------
# AWQMarlinLinearMethod
# ---------------------------------------------------------------------------

class AWQMarlinLinearMethod(LinearMethodBase):
    """Linear method for AWQ with Marlin kernel acceleration.

    Uses the Marlin W4A16 GEMM kernel for high-performance inference.
    Weights are repacked from AWQ format to Marlin format after loading.
    """

    def __init__(self, quant_config: AWQMarlinConfig) -> None:
        # All bit-width / group-size decisions below read from this config.
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs: Any,
    ) -> None:
        """Register AWQ-format parameters (qweight, qzeros, scales) on *layer*.

        Shapes follow the AutoAWQ checkpoint layout so tensors can be loaded
        directly; they are repacked for Marlin in
        :meth:`process_weights_after_loading`.
        """
        del output_size
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")
        # NOTE(review): weight_loader is extracted but never used below —
        # the loader reaches the params via set_weight_attrs(extra_weight_attrs);
        # consider removing this local.

        # Normalize group_size: -1 means one group spanning all of K.
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        verify_marlin_supports_shape(
            output_size_per_partition=output_size_per_partition,
            input_size_per_partition=input_size_per_partition,
            input_size=input_size,
            group_size=group_size,
        )

        # Packed quantized weights: (input_size, output_size // pack_factor)
        qweight = Parameter(
            torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(qweight, {
            "input_dim": 0,
            "output_dim": 1,
        })
        layer.register_parameter("qweight", qweight)
        set_weight_attrs(qweight, extra_weight_attrs)

        num_groups = input_size_per_partition // group_size

        # Zero points: (num_groups, output_size // pack_factor)
        qzeros = Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            requires_grad=False,
        )
        set_weight_attrs(qzeros, {
            "input_dim": 0,
            "output_dim": 1,
        })
        layer.register_parameter("qzeros", qzeros)
        set_weight_attrs(qzeros, extra_weight_attrs)

        # Scales: (num_groups, output_size), kept in the activation dtype.
        scales = Parameter(
            torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        set_weight_attrs(scales, {
            "input_dim": 0,
            "output_dim": 1,
        })
        layer.register_parameter("scales", scales)
        set_weight_attrs(scales, extra_weight_attrs)

        # Store dimensions for post-loading processing
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        """Repack AWQ weights to Marlin format after checkpoint loading."""
        # JIT import: compiles/loads the repack kernel on first use.
        from mllm_kernel.cuda.jit.awq_marlin_repack import awq_marlin_repack

        device = layer.qweight.device

        # Unwrap parameter data for processing
        layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
        layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
        layer.scales = Parameter(layer.scales.data, requires_grad=False)

        # Allocate marlin workspace
        layer.workspace = marlin_make_workspace(device)

        # Repack weights from AWQ format to Marlin format
        marlin_qweight = awq_marlin_repack(
            layer.qweight,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qweight", marlin_qweight)

        # Permute scales from AWQ format to Marlin format
        marlin_scales = marlin_permute_scales(
            layer.scales,
            size_k=layer.input_size_per_partition,
            size_n=layer.output_size_per_partition,
            group_size=self.quant_config.group_size,
        )
        replace_parameter(layer, "scales", marlin_scales)

        # Convert zero points from AWQ format to Marlin format.
        # size_k here is num_groups: zero points are per-group rows.
        marlin_zp = awq_to_marlin_zero_points(
            layer.qzeros,
            size_k=layer.num_groups,
            size_n=layer.output_size_per_partition,
            num_bits=self.quant_config.quant_type.size_bits,
        )
        replace_parameter(layer, "qzeros", marlin_zp)

        # AWQ doesn't use activation reordering
        layer.g_idx = marlin_make_empty_g_idx(device)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Perform quantized matmul using the Marlin GEMM kernel."""
        from mllm_kernel.cuda.jit.gptq_marlin import gptq_marlin_gemm

        # Flatten all leading dims into M; restore the shape afterwards.
        reshaped_x = x.reshape(-1, x.shape[-1])
        out_shape = x.shape[:-1] + (layer.output_size_per_partition,)

        size_m = reshaped_x.shape[0]
        size_n = layer.output_size_per_partition
        size_k = layer.input_size_per_partition

        output = gptq_marlin_gemm(
            a=reshaped_x,
            c=None,
            b_q_weight=layer.qweight,
            b_scales=layer.scales,
            global_scale=None,
            b_zeros=layer.qzeros,
            g_idx=layer.g_idx,
            perm=layer.g_idx_sort_indices,
            workspace=layer.workspace,
            b_q_type_id=self.quant_config.quant_type.id,
            size_m=size_m,
            size_n=size_n,
            size_k=size_k,
            is_k_full=True,
            use_fp32_reduce=True,
            is_zp_float=False,
        )

        # In-place bias add on the kernel output.
        if bias is not None:
            output.add_(bias)

        return output.reshape(out_shape)


# ---------------------------------------------------------------------------
# AWQMarlinConfig
# ---------------------------------------------------------------------------

@register_quantization("awq_marlin")
class AWQMarlinConfig(QuantizationConfig):
    """Configuration for AWQ quantization with Marlin kernel acceleration.

    This config is used when loading models quantized with AutoAWQ and
    running inference with the high-performance Marlin W4A16 GEMM kernel.

    Registered as ``"awq_marlin"`` in the quantization registry.
    """

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
    ) -> None:
        super().__init__()
        self.weight_bits = weight_bits
        self.group_size = group_size
        self.zero_point = zero_point
        # NOTE(review): computed before the weight_bits validation below, so
        # weight_bits == 0 raises ZeroDivisionError instead of the friendly
        # ValueError — consider moving this after the _TYPE_MAP check.
        self.pack_factor = 32 // weight_bits

        if weight_bits not in _TYPE_MAP:
            raise ValueError(
                f"Unsupported weight_bits={weight_bits}. 
"
                f"Supported: {list(_TYPE_MAP.keys())}"
            )
        self.quant_type = _TYPE_MAP[weight_bits]

        # Fail fast if the current GPU / group size cannot run Marlin.
        verify_marlin_supported(
            self.quant_type,
            group_size=self.group_size,
            has_zp=self.zero_point,
        )

    def __repr__(self) -> str:
        return (
            f"AWQMarlinConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point})"
        )

    def get_name(self) -> str:
        return "awq_marlin"

    def get_supported_act_dtypes(self) -> List[torch.dtype]:
        # Marlin accepts fp16 and bf16 activations.
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        # SM80 (Ampere) — matches the check in verify_marlin_supported().
        return 80

    @staticmethod
    def get_config_filenames() -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> AWQMarlinConfig:
        # AutoAWQ checkpoints use either "bits" or "w_bit" for the bit width.
        weight_bits = config.get("bits", config.get("w_bit", 4))
        group_size = config.get("group_size", 128)
        zero_point = config.get("zero_point", True)
        return cls(weight_bits, group_size, zero_point)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str = "",
    ) -> Optional[AWQMarlinLinearMethod]:
        # Every layer gets the same method; return None here to exempt
        # specific layers (e.g. lm_head) from quantization.
        return AWQMarlinLinearMethod(self)
diff --git a/pymllm/quantization/quant_config.py b/pymllm/quantization/quant_config.py
new file mode 100644
index 000000000..8225f6d11
--- /dev/null
+++ b/pymllm/quantization/quant_config.py
@@ -0,0 +1,203 @@
"""Quantization configuration base class and registry.

This module provides the bridge between a model checkpoint's quantization
metadata (e.g. ``quantize_config.json``) and the runtime
:class:`~pymllm.layers.quantize_base.LinearMethodBase` instances used by
each linear layer.

Architecture overview::

    quantize_config.json ──parse──► QuantizationConfig subclass
                                          │
                                          │ get_quant_method(layer, prefix)
                                          ▼
                                 LinearMethodBase instance
                                 (AWQLinearMethod, FP8LinearMethod, ...)

How to add a new quantization method
-------------------------------------
1. Create a ``QuantizationConfig`` subclass (e.g. ``AWQConfig``).
2. 
Implement ``get_name()``, ``from_config()``, ``get_quant_method()``. +3. Register it:: + + from pymllm.quantization.quant_config import register_quantization + + @register_quantization("awq") + class AWQConfig(QuantizationConfig): + ... + +4. When the server starts with ``--quantization.method awq``, the loader + will call ``get_quantization_config("awq")`` to obtain the config class, + then ``from_config(hf_quant_config)`` to instantiate it, and finally + ``config.get_quant_method(layer, prefix)`` for each linear layer. +""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional, Type + +import torch + +from pymllm.layers.quantize_base import QuantizeMethodBase + + +# --------------------------------------------------------------------------- +# Registry +# --------------------------------------------------------------------------- + +# Maps method name (e.g. "awq", "gptq", "fp8") to config class. +_QUANTIZATION_REGISTRY: Dict[str, Type[QuantizationConfig]] = {} + + +def register_quantization( + name: str, +) -> "type[type[QuantizationConfig]]": + """Class decorator that registers a :class:`QuantizationConfig` subclass. + + Usage:: + + @register_quantization("awq") + class AWQConfig(QuantizationConfig): + ... + """ + + def decorator(cls: Type[QuantizationConfig]) -> Type[QuantizationConfig]: + if name in _QUANTIZATION_REGISTRY: + raise ValueError( + f"Quantization method {name!r} is already registered " + f"by {_QUANTIZATION_REGISTRY[name].__name__}" + ) + _QUANTIZATION_REGISTRY[name] = cls + return cls + + return decorator # type: ignore[return-value] + + +def get_quantization_config(method: str) -> Type[QuantizationConfig]: + """Look up a registered :class:`QuantizationConfig` by name. + + Raises ``KeyError`` if the method is not registered. 
+ """ + if method not in _QUANTIZATION_REGISTRY: + supported = ", ".join(sorted(_QUANTIZATION_REGISTRY)) or "(none)" + raise KeyError( + f"Unknown quantization method {method!r}. " + f"Registered methods: {supported}" + ) + return _QUANTIZATION_REGISTRY[method] + + +def list_quantization_methods() -> List[str]: + """Return sorted list of registered quantization method names.""" + return sorted(_QUANTIZATION_REGISTRY) + + +# --------------------------------------------------------------------------- +# Base config class +# --------------------------------------------------------------------------- + + +class QuantizationConfig(ABC): + """Base class for quantization configurations. + + A ``QuantizationConfig`` is instantiated once per model load. It reads + quantization metadata from the checkpoint (bit-width, group size, etc.) + and provides :class:`~pymllm.layers.quantize_base.QuantizeMethodBase` + instances to each layer. + + Subclass contract + ----------------- + * :meth:`get_name` — return the method name (e.g. ``"awq"``). + * :meth:`from_config` — class method that parses a dict from the + checkpoint's ``quantize_config.json``. + * :meth:`get_quant_method` — return the appropriate + ``LinearMethodBase`` (or ``None`` to skip quantization for a layer). + + Optional overrides + ------------------ + * :meth:`get_supported_act_dtypes` — restrict activation dtypes. + * :meth:`get_min_capability` — minimum GPU compute capability. + * :meth:`get_config_filenames` — files to probe in the checkpoint dir. + """ + + @abstractmethod + def get_name(self) -> str: + """Return the canonical name of this quantization method. + + Examples: ``"awq"``, ``"gptq"``, ``"fp8"``, ``"w8a8"``. + """ + raise NotImplementedError + + @classmethod + @abstractmethod + def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig": + """Create an instance from a checkpoint's quantization config dict. 
+ + Parameters + ---------- + config + Parsed JSON from the checkpoint's ``quantize_config.json`` or + the ``quantization_config`` section of ``config.json``. + + Example config dict (AWQ):: + + { + "quant_method": "awq", + "bits": 4, + "group_size": 128, + "zero_point": true + } + """ + raise NotImplementedError + + @abstractmethod + def get_quant_method( + self, + layer: torch.nn.Module, + prefix: str = "", + ) -> Optional[QuantizeMethodBase]: + """Return the quantization method for *layer*, or ``None`` to skip. + + Parameters + ---------- + layer + The ``nn.Module`` being constructed (e.g. ``ColumnParallelLinear``). + prefix + The layer's full dotted name in the model (e.g. + ``"model.layers.0.self_attn.q_proj"``). Can be used to + selectively skip quantization for certain layers. + + Returns + ------- + QuantizeMethodBase or None + The method instance. ``None`` means this layer should fall back + to the default :class:`~pymllm.layers.quantize_base.UnquantizedLinearMethod`. + """ + raise NotImplementedError + + # -- Optional hooks (with sensible defaults) -- + + def get_supported_act_dtypes(self) -> List[torch.dtype]: + """Activation dtypes supported by this method. + + Override to restrict (e.g. FP8 only supports ``float16``). + Default: no restriction. + """ + return [torch.float16, torch.bfloat16, torch.float32] + + @classmethod + def get_min_capability(cls) -> int: + """Minimum CUDA compute capability (e.g. 75 for Turing). + + Default: 0 (no restriction). + """ + return 0 + + @staticmethod + def get_config_filenames() -> List[str]: + """File names to look for in the checkpoint directory. + + Default: ``["quantize_config.json"]``. 
+ """ + return ["quantize_config.json"] diff --git a/pymllm/server/__init__.py b/pymllm/server/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/server/launch.py b/pymllm/server/launch.py new file mode 100644 index 000000000..7f756d46d --- /dev/null +++ b/pymllm/server/launch.py @@ -0,0 +1,1084 @@ +"""pymllm HTTP server -- RESTful API entry point. + +This module implements a FastAPI-based HTTP server that wraps the pymllm +:class:`Engine` and exposes OpenAI-compatible and native REST endpoints. + +Endpoints +--------- +* ``GET /health`` -- liveness probe +* ``GET /v1/models`` -- list served models (OpenAI-compatible) +* ``POST /generate`` -- native generate (streaming via SSE) +* ``POST /v1/completions`` -- OpenAI-compatible completions +* ``POST /v1/chat/completions`` -- OpenAI-compatible chat completions +* ``GET /model_info`` -- model metadata +* ``GET /server_info`` -- runtime config dump +* ``POST /flush_cache`` -- flush internal caches +* ``POST /abort_request`` -- cancel a running request +""" + +import asyncio +import contextlib +import logging +import os +import time +import uuid +from contextlib import asynccontextmanager +from typing import Any, AsyncIterator, Dict, List, Optional, Union + +import orjson +import uvicorn +import uvloop +from fastapi import FastAPI, HTTPException, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import ORJSONResponse, Response, StreamingResponse +from pydantic import BaseModel, Field + +from pymllm.configs.global_config import get_global_config, make_args, read_args +from pymllm.engine.launch import Engine + +logger = logging.getLogger(__name__) +asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + + +# --------------------------------------------------------------------------- +# Disconnect-aware async generator wrapper +# --------------------------------------------------------------------------- + +_DISCONNECT_CHECK_INTERVAL = 1.0 # seconds + + +async 
def _iter_with_disconnect_check( + agen: AsyncIterator, + request: Request, + interval: float = _DISCONNECT_CHECK_INTERVAL, +) -> AsyncIterator: + """Wrap an async generator, periodically checking for client disconnect. + + The standard ``async for chunk in agen`` pattern only checks between + items. If the generator blocks waiting for the next item (e.g. waiting + for a decode step), a client disconnect goes unnoticed. + + This wrapper uses ``asyncio.wait`` with a timeout so that + ``request.is_disconnected()`` is polled every *interval* seconds even + while waiting for the next item. + + When a disconnect is detected, the underlying generator is closed via + ``aclose()`` which triggers its ``finally`` cleanup (abort logic). + """ + aiter = agen.__aiter__() + while True: + # Start fetching the next item without blocking indefinitely. + next_task = asyncio.ensure_future(aiter.__anext__()) + try: + while True: + done, _ = await asyncio.wait({next_task}, timeout=interval) + if done: + break + # Timeout: check if client is still connected. + if await request.is_disconnected(): + next_task.cancel() + with contextlib.suppress( + asyncio.CancelledError, StopAsyncIteration + ): + await next_task + # Close the generator to trigger its finally block. 
+ await agen.aclose() + return + except Exception: + next_task.cancel() + with contextlib.suppress(asyncio.CancelledError, StopAsyncIteration): + await next_task + raise + + try: + yield next_task.result() + except StopAsyncIteration: + return + +# --------------------------------------------------------------------------- +# Global handles (populated at startup) +# --------------------------------------------------------------------------- +_engine: Optional[Engine] = None +_tokenizer: Optional[Any] = None + + +def _get_engine() -> Engine: + """Return the running engine or raise.""" + if _engine is None: + raise RuntimeError("Engine not initialised") + return _engine + + +# --------------------------------------------------------------------------- +# Pydantic request / response models +# --------------------------------------------------------------------------- + + +class GenerateRequest(BaseModel): + """Body for ``POST /generate``.""" + + text: Optional[Union[List[str], str]] = None + input_ids: Optional[Union[List[List[int]], List[int]]] = None + sampling_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + image_data: Optional[Any] = None + audio_data: Optional[Any] = None + video_data: Optional[Any] = None + return_logprob: Optional[Union[List[bool], bool]] = None + logprob_start_len: Optional[Union[List[int], int]] = None + top_logprobs_num: Optional[Union[List[int], int]] = None + lora_path: Optional[Union[List[Optional[str]], str]] = None + session_params: Optional[Union[List[Dict[str, Any]], Dict[str, Any]]] = None + stream: bool = False + rid: Optional[Union[List[str], str]] = None + + model_config = {"extra": "allow"} # forward unknown keys as extra_options + + +# -- OpenAI-compatible models ----------------------------------------------- + + +class ImageUrl(BaseModel): + url: str + detail: Optional[str] = "auto" + + +class ContentPart(BaseModel): + type: str + text: Optional[str] = None + image_url: Optional[ImageUrl] = None + + 
+class ChatMessage(BaseModel): + role: str + content: Optional[Union[str, List[ContentPart]]] = None + name: Optional[str] = None + tool_calls: Optional[List[Any]] = None + tool_call_id: Optional[str] = None + + model_config = {"extra": "allow"} + + +class StreamOptions(BaseModel): + include_usage: Optional[bool] = False + continuous_usage_stats: Optional[bool] = False + + +class ToolFunction(BaseModel): + name: str + description: Optional[str] = None + parameters: Optional[Dict[str, Any]] = None + + +class Tool(BaseModel): + type: str = "function" + function: ToolFunction + + +class ChatCompletionRequest(BaseModel): + """OpenAI ``POST /v1/chat/completions`` body.""" + + model: str = "" + messages: List[ChatMessage] + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + max_tokens: Optional[int] = None + max_completion_tokens: Optional[int] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + stop: Optional[Union[str, List[str]]] = None + n: int = 1 + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + repetition_penalty: Optional[float] = None + seed: Optional[int] = None + logprobs: Optional[bool] = None + top_logprobs: Optional[int] = None + user: Optional[str] = None + # Tool calling + tools: Optional[List[Tool]] = None + tool_choice: Optional[Union[str, Dict[str, Any]]] = None + # Reasoning control + separate_reasoning: bool = True + stream_reasoning: bool = True + # Pass-through to tokenizer.apply_chat_template (e.g. 
enable_thinking) + chat_template_kwargs: Optional[Dict[str, Any]] = None + + model_config = {"extra": "allow"} + + +class CompletionRequest(BaseModel): + """OpenAI ``POST /v1/completions`` body.""" + + model: str = "" + prompt: Union[str, List[str]] + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + max_tokens: Optional[int] = None + stream: bool = False + stream_options: Optional[StreamOptions] = None + stop: Optional[Union[str, List[str]]] = None + n: int = 1 + frequency_penalty: Optional[float] = None + presence_penalty: Optional[float] = None + repetition_penalty: Optional[float] = None + seed: Optional[int] = None + echo: bool = False + logprobs: Optional[int] = None + user: Optional[str] = None + + model_config = {"extra": "allow"} + + +class AbortRequest(BaseModel): + rid: Optional[str] = None + + +# --------------------------------------------------------------------------- +# FastAPI application & lifespan +# --------------------------------------------------------------------------- + + +@asynccontextmanager +async def lifespan(app: FastAPI): + """Startup / shutdown hooks for the FastAPI app.""" + global _engine, _tokenizer + _engine = app.state.engine # type: ignore[attr-defined] + + # Load tokenizer in server process for apply_chat_template + cfg = get_global_config() + try: + from transformers import AutoTokenizer + + _tokenizer = AutoTokenizer.from_pretrained( + str(cfg.server.tokenizer_path), + trust_remote_code=cfg.server.trust_remote_code, + ) + logger.info( + "Loaded tokenizer for chat template: %s", cfg.server.tokenizer_path + ) + except Exception as e: + logger.warning("Failed to load tokenizer for chat template: %s", e) + + logger.info( + "HTTP server ready at http://%s:%s", + cfg.server.host, + cfg.server.port, + ) + yield + # Shutdown + if _engine is not None: + _engine.shutdown() + _engine = None + + +app = FastAPI(lifespan=lifespan) +# NOTE: CORS middleware is added in launch_server() after 
config is loaded, +# so that cors_allow_origins from ServerConfig can be used. + + +# --------------------------------------------------------------------------- +# Authentication middleware +# --------------------------------------------------------------------------- + +# Paths that are always accessible without an API key (liveness probes). +_AUTH_EXEMPT_PATHS = frozenset({"/health", "/health_generate"}) + + +@app.middleware("http") +async def _auth_middleware(request: Request, call_next): + """Enforce ``Authorization: Bearer `` when ``api_key`` is configured.""" + cfg = get_global_config() + api_key = cfg.server.api_key + if api_key is None: + # No key configured — open access. + return await call_next(request) + if request.url.path in _AUTH_EXEMPT_PATHS: + return await call_next(request) + auth = request.headers.get("Authorization", "") + if auth == f"Bearer {api_key}": + return await call_next(request) + admin_key = cfg.server.admin_api_key + if admin_key and auth == f"Bearer {admin_key}": + return await call_next(request) + return ORJSONResponse( + status_code=401, + content={"error": {"message": "Invalid or missing API key", "code": 401}}, + ) + + +# --------------------------------------------------------------------------- +# Exception handlers +# --------------------------------------------------------------------------- + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request: Request, exc: HTTPException): + return ORJSONResponse( + content={"error": {"message": exc.detail, "code": exc.status_code}}, + status_code=exc.status_code, + ) + + +# --------------------------------------------------------------------------- +# Health / info endpoints +# --------------------------------------------------------------------------- + + +@app.get("/health") +@app.get("/health_generate") +async def health(): + """Liveness / readiness probe. 
Returns 503 if subprocesses died.""" + engine = _engine + if engine is None or not engine.is_healthy: + return Response(status_code=503) + return Response(status_code=200) + + +@app.get("/model_info") +async def model_info(): + """Return basic model metadata.""" + cfg = get_global_config() + hf_cfg = cfg.model.hf_config + return { + "model_path": str(cfg.server.model_path), + "tokenizer_path": str(cfg.server.tokenizer_path), + "served_model_name": cfg.server.served_model_name, + "model_type": getattr(hf_cfg, "model_type", None) if hf_cfg else None, + "architectures": getattr(hf_cfg, "architectures", None) if hf_cfg else None, + } + + +_SERVER_INFO_REDACT = frozenset({"api_key", "admin_api_key"}) + + +@app.get("/server_info") +async def server_info(): + """Dump runtime server configuration (sensitive fields redacted).""" + import dataclasses as _dc + + cfg = get_global_config() + d = _dc.asdict(cfg.server) + for k in _SERVER_INFO_REDACT: + d.pop(k, None) + return d + + +@app.get("/v1/models") +async def list_models(): + """OpenAI-compatible model listing.""" + cfg = get_global_config() + model_name = cfg.server.served_model_name or str(cfg.server.model_path) + return { + "object": "list", + "data": [_model_card(model_name)], + } + + +@app.get("/v1/models/{model_id:path}") +async def retrieve_model(model_id: str): + """OpenAI-compatible single model retrieval.""" + cfg = get_global_config() + model_name = cfg.server.served_model_name or str(cfg.server.model_path) + if model_id != model_name: + raise HTTPException( + status_code=404, + detail=f"Model '{model_id}' not found. 
Available: '{model_name}'", + ) + return _model_card(model_name) + + +def _model_card(model_name: str) -> Dict[str, Any]: + """Build an OpenAI-compatible Model object.""" + return { + "id": model_name, + "object": "model", + "created": int(time.time()), + "owned_by": "pymllm", + } + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +# Map internal finish reasons to OpenAI-standard values. +_FINISH_REASON_MAP = { + "eos": "stop", + "stop": "stop", + "length": "length", + "abort": "stop", +} + + +def _normalize_finish_reason(reason: Optional[str]) -> Optional[str]: + """Convert internal finish reason to OpenAI-compatible value.""" + if reason is None: + return None + return _FINISH_REASON_MAP.get(reason, reason) + + +def _build_sampling_params( + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + max_tokens: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + frequency_penalty: Optional[float] = None, + presence_penalty: Optional[float] = None, + repetition_penalty: Optional[float] = None, + seed: Optional[int] = None, + **extra: Any, +) -> Dict[str, Any]: + """Build a sampling_params dict from OpenAI-style fields.""" + params: Dict[str, Any] = {} + if temperature is not None: + params["temperature"] = temperature + if top_p is not None: + params["top_p"] = top_p + if top_k is not None: + params["top_k"] = top_k + if max_tokens is not None: + params["max_new_tokens"] = max_tokens + if stop is not None: + params["stop"] = stop if isinstance(stop, list) else [stop] + if frequency_penalty is not None: + params["frequency_penalty"] = frequency_penalty + if presence_penalty is not None: + params["presence_penalty"] = presence_penalty + if repetition_penalty is not None: + params["repetition_penalty"] = repetition_penalty + if seed is not None: + params["seed"] = seed + 
params.update(extra) + return params + + +def _messages_to_prompt( + messages: List[ChatMessage], + chat_template_kwargs: Optional[Dict[str, Any]] = None, +) -> str: + """Render chat messages into a prompt string via the model's chat template. + + Uses ``tokenizer.apply_chat_template()`` when available (handles Llama, + Qwen, Mistral, etc. automatically). Falls back to ChatML format. + + Parameters + ---------- + chat_template_kwargs + Extra keyword arguments forwarded to ``apply_chat_template`` + (e.g. ``enable_thinking=True`` for Qwen3). + """ + # Flatten each message into a plain dict for the tokenizer. + msg_dicts: List[Dict[str, Any]] = [] + for msg in messages: + content = msg.content + if isinstance(content, list): + # Multimodal: extract only text parts for the prompt string. + text_parts = [p.text for p in content if p.type == "text" and p.text] + content = "\n".join(text_parts) if text_parts else "" + elif content is None: + content = "" + d: Dict[str, Any] = {"role": msg.role, "content": content} + if msg.name is not None: + d["name"] = msg.name + msg_dicts.append(d) + + tokenizer = _tokenizer + if tokenizer is not None and hasattr(tokenizer, "apply_chat_template"): + try: + extra = dict(chat_template_kwargs) if chat_template_kwargs else {} + return tokenizer.apply_chat_template( + msg_dicts, + tokenize=False, + add_generation_prompt=True, + **extra, + ) + except Exception as e: + logger.warning("apply_chat_template failed, using fallback: %s", e) + + # Fallback: ChatML format (Qwen-style) + parts: List[str] = [] + for m in msg_dicts: + parts.append(f"<|im_start|>{m['role']}\n{m['content']}<|im_end|>") + parts.append("<|im_start|>assistant\n") + return "\n".join(parts) + + +def _extract_image_data(messages: List[ChatMessage]) -> Optional[List[str]]: + """Extract image URLs / base64 strings from multimodal content parts.""" + images: List[str] = [] + for msg in messages: + if not isinstance(msg.content, list): + continue + for part in msg.content: + if 
part.type == "image_url" and part.image_url is not None: + images.append(part.image_url.url) + return images if images else None + + +def _make_completion_id() -> str: + return f"cmpl-{uuid.uuid4().hex[:24]}" + + +def _make_chat_completion_id() -> str: + return f"chatcmpl-{uuid.uuid4().hex[:24]}" + + +# --------------------------------------------------------------------------- +# Native generate endpoint +# --------------------------------------------------------------------------- + + +@app.api_route("/generate", methods=["POST", "PUT"]) +async def generate(obj: GenerateRequest, request: Request): + """Native generation endpoint. Supports SSE streaming.""" + engine = _get_engine() + + # Collect extra fields as extra_options + known = set(GenerateRequest.model_fields.keys()) + extra_options = {k: v for k, v in obj.model_dump().items() if k not in known} + + kwargs: Dict[str, Any] = { + "prompt": obj.text, + "input_ids": obj.input_ids, + "sampling_params": obj.sampling_params, + "image_data": obj.image_data, + "audio_data": obj.audio_data, + "video_data": obj.video_data, + "return_logprob": obj.return_logprob, + "logprob_start_len": obj.logprob_start_len, + "top_logprobs_num": obj.top_logprobs_num, + "lora_path": obj.lora_path, + "session_params": obj.session_params, + "stream": obj.stream, + "rid": obj.rid, + **extra_options, + } + # Strip None values so Engine defaults are used + kwargs = {k: v for k, v in kwargs.items() if v is not None} + + if obj.stream: + + async def _stream() -> AsyncIterator[bytes]: + gen = engine.generate_async(**kwargs) + try: + async for chunk in _iter_with_disconnect_check(gen, request): + # Skip empty intermediate chunks (e.g. 
special tokens + # stripped by the detokenizer) + if not chunk.get("delta") and not chunk.get("finished"): + continue + yield b"data: " + orjson.dumps(chunk) + b"\n\n" + except Exception as e: + logger.error("[generate] stream error: %s", e, exc_info=True) + err = {"error": {"message": "Internal server error"}} + yield b"data: " + orjson.dumps(err) + b"\n\n" + finally: + await gen.aclose() + yield b"data: [DONE]\n\n" + + return StreamingResponse(_stream(), media_type="text/event-stream") + + gen = engine.generate_async(**kwargs) + try: + results = [] + async for item in _iter_with_disconnect_check(gen, request): + results.append(item) + if not results: + raise HTTPException(status_code=500, detail="No output from engine") + result = results[0] if len(results) == 1 else results + return ORJSONResponse(result) + except HTTPException: + raise + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except RuntimeError as e: + if "too many queued" in str(e): + raise HTTPException(status_code=429, detail=str(e)) + logger.error("[generate] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + except Exception as e: + logger.error("[generate] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + finally: + await gen.aclose() + + +# --------------------------------------------------------------------------- +# OpenAI-compatible /v1/completions +# --------------------------------------------------------------------------- + + +@app.post("/v1/completions") +async def openai_completions(obj: CompletionRequest, request: Request): + """OpenAI-compatible text completion endpoint.""" + if obj.n > 1: + raise HTTPException(status_code=400, detail="n > 1 is not supported") + if obj.echo: + raise HTTPException(status_code=400, detail="echo is not yet supported") + if obj.logprobs is not None and obj.logprobs > 0: + raise HTTPException(status_code=400, 
detail="logprobs is not yet supported") + engine = _get_engine() + sp = _build_sampling_params( + temperature=obj.temperature, + top_p=obj.top_p, + top_k=obj.top_k, + max_tokens=obj.max_tokens, + stop=obj.stop, + frequency_penalty=obj.frequency_penalty, + presence_penalty=obj.presence_penalty, + repetition_penalty=obj.repetition_penalty, + seed=obj.seed, + ) + cfg = get_global_config() + model_name = obj.model or cfg.server.served_model_name or str(cfg.server.model_path) + include_usage = ( + obj.stream_options is not None and obj.stream_options.include_usage + ) + + if obj.stream: + + async def _stream() -> AsyncIterator[bytes]: + comp_id = _make_completion_id() + prompt_tokens = 0 + completion_tokens = 0 + gen = engine.generate_async( + prompt=obj.prompt, sampling_params=sp, stream=True + ) + try: + async for chunk in _iter_with_disconnect_check(gen, request): + prompt_tokens = chunk.get("prompt_tokens", prompt_tokens) + completion_tokens = chunk.get("completion_tokens", completion_tokens) + delta_text = chunk.get("delta", "") + finish_reason = _normalize_finish_reason( + chunk.get("finished_reason") + ) + # Skip empty intermediate chunks + if not delta_text and finish_reason is None: + continue + sse: Dict[str, Any] = { + "id": comp_id, + "object": "text_completion", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "text": delta_text, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + yield b"data: " + orjson.dumps(sse) + b"\n\n" + except Exception as e: + logger.error("[v1/completions] stream error: %s", e, exc_info=True) + err = {"error": {"message": "Internal server error"}} + yield b"data: " + orjson.dumps(err) + b"\n\n" + finally: + await gen.aclose() + # Final usage-only chunk (OpenAI stream_options.include_usage) + if include_usage: + usage_chunk: Dict[str, Any] = { + "id": comp_id, + "object": "text_completion", + "created": int(time.time()), + "model": model_name, + "choices": [], + "usage": { + 
"prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + yield b"data: " + orjson.dumps(usage_chunk) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse(_stream(), media_type="text/event-stream") + + gen = engine.generate_async( + prompt=obj.prompt, sampling_params=sp + ) + try: + results = [] + async for item in _iter_with_disconnect_check(gen, request): + results.append(item) + choices = [] + prompt_tokens = 0 + completion_tokens = 0 + for i, r in enumerate(results): + choices.append( + { + "index": i, + "text": r.get("text", ""), + "logprobs": None, + "finish_reason": _normalize_finish_reason( + r.get("finished_reason", "stop") + ), + } + ) + prompt_tokens += r.get("prompt_tokens", 0) + completion_tokens += r.get("completion_tokens", 0) + + return ORJSONResponse( + { + "id": _make_completion_id(), + "object": "text_completion", + "created": int(time.time()), + "model": model_name, + "choices": choices, + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except RuntimeError as e: + if "too many queued" in str(e): + raise HTTPException(status_code=429, detail=str(e)) + logger.error("[v1/completions] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + except Exception as e: + logger.error("[v1/completions] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + finally: + await gen.aclose() + + +# --------------------------------------------------------------------------- +# OpenAI-compatible /v1/chat/completions +# --------------------------------------------------------------------------- + + +@app.post("/v1/chat/completions") +async def openai_chat_completions(obj: ChatCompletionRequest, 
request: Request): + """OpenAI-compatible chat completion endpoint with reasoning & tool-call parsing.""" + if obj.n > 1: + raise HTTPException(status_code=400, detail="n > 1 is not supported") + if obj.logprobs: + raise HTTPException(status_code=400, detail="logprobs is not yet supported") + engine = _get_engine() + cfg = get_global_config() + # Auto-enable thinking when reasoning_parser is configured and the + # client didn't explicitly set enable_thinking. + chat_kwargs = dict(obj.chat_template_kwargs) if obj.chat_template_kwargs else {} + if cfg.server.reasoning_parser and "enable_thinking" not in chat_kwargs: + chat_kwargs["enable_thinking"] = True + prompt = _messages_to_prompt(obj.messages, chat_template_kwargs=chat_kwargs or None) + image_data = _extract_image_data(obj.messages) + + # max_completion_tokens takes precedence over max_tokens (OpenAI convention) + max_tokens = obj.max_completion_tokens if obj.max_completion_tokens is not None else obj.max_tokens + + sp = _build_sampling_params( + temperature=obj.temperature, + top_p=obj.top_p, + top_k=obj.top_k, + max_tokens=max_tokens, + stop=obj.stop, + frequency_penalty=obj.frequency_penalty, + presence_penalty=obj.presence_penalty, + repetition_penalty=obj.repetition_penalty, + seed=obj.seed, + ) + model_name = obj.model or cfg.server.served_model_name or str(cfg.server.model_path) + include_usage = ( + obj.stream_options is not None and obj.stream_options.include_usage + ) + + # Resolve parsers from server config + reasoning_type = cfg.server.reasoning_parser + tool_call_type = cfg.server.tool_call_parser + + gen_kwargs: Dict[str, Any] = { + "prompt": prompt, + "sampling_params": sp, + } + if image_data is not None: + gen_kwargs["image_data"] = image_data + + if obj.stream: + + async def _stream() -> AsyncIterator[bytes]: + from pymllm.parsers import ReasoningParser, ToolCallParser + + comp_id = _make_chat_completion_id() + created = int(time.time()) + first = True + prompt_tokens = 0 + completion_tokens = 
0 + has_tool_calls = False # track across entire stream + + # Instantiate streaming parsers + r_parser = ( + ReasoningParser(reasoning_type, stream_reasoning=obj.stream_reasoning) + if reasoning_type and obj.separate_reasoning + else None + ) + tc_parser = ( + ToolCallParser(tool_call_type, tools=obj.tools) + if tool_call_type and obj.tools + else None + ) + + def _make_sse(delta: Dict[str, Any], finish: Optional[str] = None) -> bytes: + sse: Dict[str, Any] = { + "id": comp_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish, + } + ], + } + return b"data: " + orjson.dumps(sse) + b"\n\n" + + gen = engine.generate_async(**gen_kwargs, stream=True) + try: + async for chunk in _iter_with_disconnect_check(gen, request): + prompt_tokens = chunk.get("prompt_tokens", prompt_tokens) + completion_tokens = chunk.get("completion_tokens", completion_tokens) + + raw_delta = chunk.get("delta", "") + finish_reason = _normalize_finish_reason( + chunk.get("finished_reason") + ) + + # --- Phase 1: reasoning parser --- + reasoning_delta = "" + content_delta = raw_delta + if r_parser and raw_delta: + reasoning_delta, content_delta = r_parser.parse_stream_chunk( + raw_delta + ) + + # --- Phase 2: tool-call parser --- + tool_items: list = [] + if tc_parser and content_delta: + content_delta, tool_items = tc_parser.parse_stream_chunk( + content_delta + ) + + # --- Emit chunks --- + # Role chunk (first) + if first: + yield _make_sse({"role": "assistant"}) + first = False + + # Reasoning content + if reasoning_delta: + yield _make_sse({"reasoning_content": reasoning_delta}) + + # Tool call deltas + if tool_items: + has_tool_calls = True + for tc in tool_items: + yield _make_sse({"tool_calls": [tc.to_openai_dict()]}) + + # Normal content + if content_delta: + yield _make_sse({"content": content_delta}) + + # Finish + if finish_reason is not None: + # Flush 
remaining tool call data + if tc_parser: + remaining = tc_parser.flush() + for tc in remaining: + has_tool_calls = True + yield _make_sse({"tool_calls": [tc.to_openai_dict()]}) + if has_tool_calls: + finish_reason = "tool_calls" + yield _make_sse({}, finish=finish_reason) + + except Exception as e: + logger.error("[v1/chat/completions] stream error: %s", e, exc_info=True) + err = {"error": {"message": "Internal server error"}} + yield b"data: " + orjson.dumps(err) + b"\n\n" + finally: + await gen.aclose() + # Final usage-only chunk + if include_usage: + usage_chunk: Dict[str, Any] = { + "id": comp_id, + "object": "chat.completion.chunk", + "created": created, + "model": model_name, + "choices": [], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + yield b"data: " + orjson.dumps(usage_chunk) + b"\n\n" + yield b"data: [DONE]\n\n" + + return StreamingResponse(_stream(), media_type="text/event-stream") + + # -- Non-streaming -- + gen = engine.generate_async(**gen_kwargs) + try: + from pymllm.parsers import ReasoningParser, ToolCallParser + + r = {} + async for item in _iter_with_disconnect_check(gen, request): + r = item + prompt_tokens = r.get("prompt_tokens", 0) + completion_tokens = r.get("completion_tokens", 0) + text = r.get("text", "") + finish_reason = _normalize_finish_reason(r.get("finished_reason", "stop")) + + # Parse reasoning + reasoning_content = None + if reasoning_type and obj.separate_reasoning: + rp = ReasoningParser(reasoning_type) + reasoning_content, text = rp.parse_non_stream(text) + + # Parse tool calls + tool_calls_list = None + if tool_call_type and obj.tools: + tp = ToolCallParser(tool_call_type, tools=obj.tools) + if tp.has_tool_call(text): + text, parsed_calls = tp.parse_non_stream(text) + if parsed_calls: + tool_calls_list = [tc.to_openai_dict(streaming=False) for tc in parsed_calls] + finish_reason = "tool_calls" + + message: Dict[str, 
Any] = {"role": "assistant", "content": text or None} + if reasoning_content: + message["reasoning_content"] = reasoning_content + if tool_calls_list: + message["tool_calls"] = tool_calls_list + + return ORJSONResponse( + { + "id": _make_chat_completion_id(), + "object": "chat.completion", + "created": int(time.time()), + "model": model_name, + "choices": [ + { + "index": 0, + "message": message, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + "usage": { + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + "total_tokens": prompt_tokens + completion_tokens, + }, + } + ) + except ValueError as e: + raise HTTPException(status_code=400, detail=str(e)) + except RuntimeError as e: + if "too many queued" in str(e): + raise HTTPException(status_code=429, detail=str(e)) + logger.error("[v1/chat/completions] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + except Exception as e: + logger.error("[v1/chat/completions] Error: %s", e, exc_info=True) + raise HTTPException(status_code=500, detail="Internal server error") + finally: + await gen.aclose() + + +# --------------------------------------------------------------------------- +# Administrative endpoints +# --------------------------------------------------------------------------- + + +@app.api_route("/flush_cache", methods=["GET", "POST"]) +async def flush_cache(): + """Cache flush (not yet implemented).""" + raise HTTPException(status_code=501, detail="Cache flush not implemented") + + +@app.post("/abort_request") +async def abort_request(obj: AbortRequest): + """Abort a running request by rid.""" + engine = _get_engine() + if obj.rid and engine._rr_process is not None: + await engine._rr_process.abort_request(obj.rid) + return Response(status_code=200) + raise HTTPException(status_code=400, detail="Missing or invalid rid") + + +# --------------------------------------------------------------------------- +# Prepare args 
helper +# --------------------------------------------------------------------------- + + +def _prepare_args(): + """Parse CLI arguments into the global config singleton.""" + parser = make_args() + read_args(parser=parser) + + +# --------------------------------------------------------------------------- +# Server launcher +# --------------------------------------------------------------------------- + + +def launch_server(): + """Launch the pymllm Engine then start the uvicorn HTTP server. + + It first boots all engine subprocesses (tokenizer, scheduler, model-runner, detokenizer) + and then hands off to uvicorn to serve HTTP traffic. + """ + _prepare_args() + cfg = get_global_config() + + # Add CORS middleware (after config is loaded so origins are configurable). + origins = cfg.server.cors_allow_origins + app.add_middleware( + CORSMiddleware, + allow_origins=origins, + allow_credentials=(origins != ["*"]), + allow_methods=["*"], + allow_headers=["*"], + ) + + engine = Engine() + engine.launch() + + # Attach engine to app.state so the lifespan hook can pick it up. + app.state.engine = engine # type: ignore[attr-defined] + + logger.info( + "Starting HTTP server on %s:%s (root_path=%r)", + cfg.server.host, + cfg.server.port, + cfg.server.fastapi_root_path, + ) + + uvicorn.run( + app, + host=cfg.server.host, + port=cfg.server.port, + root_path=cfg.server.fastapi_root_path, + log_level=cfg.server.log_level, + timeout_keep_alive=5, + loop="uvloop", + ) + + +def main(): + """CLI entry point.""" + launch_server() + + +if __name__ == "__main__": + main() diff --git a/pymllm/tests/README.md b/pymllm/tests/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/pymllm/tests/test_vocab_parallel_embedding.py b/pymllm/tests/test_vocab_parallel_embedding.py new file mode 100644 index 000000000..44148f983 --- /dev/null +++ b/pymllm/tests/test_vocab_parallel_embedding.py @@ -0,0 +1,312 @@ +"""Tests for VocabParallelEmbedding layer. 
+ +This module tests the VocabParallelEmbedding layer with and without +tensor parallelism. +""" + +import os +import logging +import pytest +import torch +import torch.nn as nn +import torch.multiprocessing as mp +from typing import Callable + +from pymllm.layers import VocabParallelEmbedding +from pymllm.orchestrator import initialize_model_parallel +from pymllm.orchestrator.parallel_state import ( + get_tensor_model_parallel_rank, + get_tensor_model_parallel_world_size, +) + +# Show runtime init logs during test execution. +logging.basicConfig(level=logging.INFO, force=True) +logging.getLogger().setLevel(logging.INFO) + + +# ============================================================================= +# Helper: weight loading +# ============================================================================= +def load_weight(param: nn.Parameter, loaded_weight: torch.Tensor) -> None: + """Load weight using the weight_loader attached to param attribute.""" + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is None: + # Fallback: direct copy + param.data.copy_(loaded_weight) + else: + # Call the loader attached to param + weight_loader(param, loaded_weight) + + +# ============================================================================= +# Real distributed tests with world_size=8 on CUDA +# ============================================================================= +def run_worker_tp8_cuda( + rank: int, + local_rank: int, + world_size: int, + local_world_size: int, + test_func: Callable, + return_dict: dict, +): + """Worker function for multi-process testing with TP=8 on CUDA. 
+ + Args: + rank: Global rank across all nodes + local_rank: Local rank within this node (used for GPU binding) + world_size: Total number of processes across all nodes + local_world_size: Number of processes on this node + test_func: Test function to run + return_dict: Shared dict for returning results + """ + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = "29500" + + # Set device using local_rank (binds to GPU 0,1,2,3 on this node) + torch.cuda.set_device(local_rank) + + torch.distributed.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + ) + + initialize_model_parallel(tensor_model_parallel_size=8) + + try: + result = test_func(rank, local_rank, world_size) + return_dict[rank] = result + except Exception as e: + import traceback + + return_dict[rank] = f"ERROR: {e}\n{traceback.format_exc()}" + finally: + torch.distributed.destroy_process_group() + + +def embedding_forward_tp8_worker_cuda(rank: int, local_rank: int, world_size: int): + """Test forward pass with real TP=8 on CUDA. 
+ + Args: + rank: Global rank + local_rank: Local rank within this node (for logging/debugging) + world_size: Total world size + """ + tp_size = get_tensor_model_parallel_world_size() + tp_rank = get_tensor_model_parallel_rank() + + assert tp_size == 8, f"Rank {rank}: tp_size should be 8" + assert tp_rank == rank, f"Rank {rank}: tp_rank mismatch" + + vocab_size = 1024 + embed_dim = 64 + # .cuda() uses the device set by torch.cuda.set_device(local_rank) + layer = VocabParallelEmbedding(vocab_size, embed_dim).cuda() + + # Verify the layer is on the correct GPU + assert layer.weight.device.index == local_rank, ( + f"Rank {rank}: weight should be on GPU {local_rank}, got {layer.weight.device}" + ) + + expected_shard_size = vocab_size // 8 + assert layer.num_embeddings_per_partition == expected_shard_size + assert layer.weight.shape == (expected_shard_size, embed_dim) + + # Each rank initializes its own shard with known pattern + with torch.no_grad(): + layer.weight.fill_(float(rank + 1)) # Rank 0: 1.0, Rank 1: 2.0, ... + + # Create input on the correct GPU + input_ids = torch.tensor([[0, 128, 256, 384], [512, 640, 768, 896]], device="cuda") + + output = layer(input_ids) + assert output.shape == (2, 4, embed_dim) + + # Verify output is on correct GPU + assert output.device.index == local_rank, ( + f"Rank {rank}: output should be on GPU {local_rank}, got {output.device}" + ) + + if rank == 0: + # Each token is owned by exactly one TP rank. Since each rank fills its + # local shard with (rank + 1), post-all-reduce output must match below. 
+ expected_token_values = torch.tensor( + [[1, 2, 3, 4], [5, 6, 7, 8]], + device=output.device, + dtype=output.dtype, + ) + expected_output = expected_token_values.unsqueeze(-1).expand(-1, -1, embed_dim) + + if torch.equal(output, expected_output): + return "PASSED" + return "FAILED: embedding output does not match expected TP aggregation" + + return "OK" + + +def weight_loading_tp8_worker_cuda(rank: int, local_rank: int, world_size: int): + """Test weight loading with real TP=8 on CUDA. + + Args: + rank: Global rank + local_rank: Local rank within this node (for GPU binding verification) + world_size: Total world size + """ + vocab_size = 1024 + embed_dim = 64 + layer = VocabParallelEmbedding(vocab_size, embed_dim).cuda() + + # Verify the layer is on the correct GPU + assert layer.weight.device.index == local_rank, ( + f"Rank {rank}: weight should be on GPU {local_rank}, got {layer.weight.device}" + ) + + full_weight = torch.randn(vocab_size, embed_dim) + load_weight(layer.weight, full_weight.cuda()) + + shard_size = vocab_size // 8 + start_idx = rank * shard_size + end_idx = start_idx + shard_size + expected_shard = full_weight[start_idx:end_idx] + + if not torch.allclose(layer.weight.cpu(), expected_shard): + return f"FAILED: shard mismatch at rank {rank}" + + if rank == 0: + gathered_shards = [layer.weight.cpu().clone()] + for other_rank in range(1, 8): + other_shard = full_weight[ + other_rank * shard_size : (other_rank + 1) * shard_size + ] + gathered_shards.append(other_shard) + + reconstructed = torch.cat(gathered_shards, dim=0) + if torch.allclose(reconstructed, full_weight): + return "PASSED" + else: + return "FAILED: reconstruction mismatch" + + return "OK" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +@pytest.mark.skipif(torch.cuda.device_count() < 8, reason="Requires at least 8 GPUs") +class TestVocabParallelEmbeddingRealTP8: + """Real distributed tests with world_size=8 and TP=8 on CUDA.""" + + def 
test_forward_pass_tp8_real(self): + """Test forward pass with real TP=8 using 8 processes on CUDA.""" + world_size = 8 + local_world_size = 8 # Single node with 8 GPUs + + mp.set_start_method("spawn", force=True) + + manager = mp.Manager() + return_dict = manager.dict() + + processes = [] + for rank in range(world_size): + # In single-node setup, local_rank == rank + local_rank = rank + p = mp.Process( + target=run_worker_tp8_cuda, + args=( + rank, + local_rank, + world_size, + local_world_size, + embedding_forward_tp8_worker_cuda, + return_dict, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join(timeout=120) + if p.is_alive(): + p.terminate() + p.join() + + for rank in range(world_size): + result = return_dict.get(rank, "TIMEOUT") + if rank == 0: + assert result == "PASSED", f"Rank {rank} failed: {result}" + else: + assert "ERROR" not in str(result), f"Rank {rank} error: {result}" + + def test_weight_loading_tp8_real(self): + """Test weight loading with real TP=8 using 8 processes on CUDA.""" + world_size = 8 + local_world_size = 8 # Single node with 8 GPUs + + mp.set_start_method("spawn", force=True) + + manager = mp.Manager() + return_dict = manager.dict() + + processes = [] + for rank in range(world_size): + # In single-node setup, local_rank == rank + local_rank = rank + p = mp.Process( + target=run_worker_tp8_cuda, + args=( + rank, + local_rank, + world_size, + local_world_size, + weight_loading_tp8_worker_cuda, + return_dict, + ), + ) + p.start() + processes.append(p) + + for p in processes: + p.join(timeout=120) + if p.is_alive(): + p.terminate() + p.join() + + for rank in range(world_size): + result = return_dict.get(rank, "TIMEOUT") + if rank == 0: + assert result == "PASSED", f"Rank {rank} failed: {result}" + else: + assert "ERROR" not in str(result), f"Rank {rank} error: {result}" + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available") +class TestVocabParallelEmbeddingCUDA: + """Tests for 
non-parallel TP=1 mode on CUDA.""" + + @pytest.fixture(autouse=True) + def setup_config(self): + import pymllm.orchestrator.parallel_state as ps + ps._TP_SIZE = 1 + ps._TP_RANK = 0 + yield + ps._TP_SIZE = 1 + ps._TP_RANK = 0 + + def test_cuda_forward(self): + layer = VocabParallelEmbedding(1000, 512).cuda() + input_ids = torch.randint(0, 1000, (4, 32), device="cuda") + + output = layer(input_ids) + + assert output.device.type == "cuda" + assert output.shape == (4, 32, 512) + + def test_cuda_weight_loader(self): + layer = VocabParallelEmbedding(100, 64).cuda() + + cpu_weight = torch.randn(100, 64) + load_weight(layer.weight, cpu_weight.cuda()) + + assert torch.allclose(layer.weight.cpu(), cpu_weight) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/pymllm/utils/mllm_convertor_server/service.py b/pymllm/utils/mllm_convertor_server/service.py deleted file mode 100644 index ea8e2bec7..000000000 --- a/pymllm/utils/mllm_convertor_server/service.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) MLLM Team. -# Licensed under the MIT License. 
diff --git a/pyproject.toml b/pyproject.toml index 703d4456a..ce64b2ee1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [build-system] requires = [ - "scikit-build-core>=0.11.0", "apache-tvm-ffi" + "scikit-build-core>=0.11.0", "apache-tvm-ffi == 0.1.8" ] build-backend = "scikit_build_core.build" @@ -21,7 +21,7 @@ dependencies=[ "packaging", "pytest", "pytest-html", - "apache-tvm-ffi == 0.1.0b4", + "apache-tvm-ffi == 0.1.8.post2", "pyyaml >= 6.0.2", "openai", "modelscope", @@ -30,14 +30,18 @@ dependencies=[ "typer", "torch", "torchao", + "pyfiglet", + "termcolor", ] [project.optional-dependencies] -cuda = ["tilelang"] +cuda = ["tilelang", "flashinfer-python", "pyzmq"] [project.scripts] -mllm-convertor = "pymllm.utils.mllm_convertor:main" -mllm-service = "pymllm.service.tools:cli_app" +pymllm = "pymllm.__main__:main" +mllm-convertor = "pymllm.mobile.utils.mllm_convertor:main" +mllm-service = "pymllm.mobile.service.tools:cli_app" +pymllm-server = "pymllm.server.launch:main" [tool.setuptools.exclude-package-data] "*" = ["*.pyc"] @@ -50,6 +54,8 @@ first_party_detection = false target-version = ["py310", "py311", "py312"] [tool.scikit-build] +# Set to false or use env var SKBUILD_WHEEL_CMAKE=false to skip CMake build +wheel.cmake = true # ABI-agnostic wheel wheel.py-api = "py3" cmake.args = [ diff --git a/tests/cpu/ConvTranspose1DKernelTest.hpp b/tests/cpu/ConvTranspose1DKernelTest.hpp new file mode 100644 index 000000000..d7657baf1 --- /dev/null +++ b/tests/cpu/ConvTranspose1DKernelTest.hpp @@ -0,0 +1,134 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. 
+#pragma once + +#include +#include + +#include "KernelTestHelper.hpp" +#include "mllm/core/ParameterFile.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +using namespace mllm; // NOLINT + +void naive_conv_transpose1d(const float* input_data, const float* weight_data, const float* bias_data, float* output_data, + int batch, int in_channels, int sequence, int out_channels, int kernel_size, int stride, + int padding, int dilation, int output_padding, int groups) { + const int out_sequence = (sequence - 1) * stride - 2 * padding + dilation * (kernel_size - 1) + output_padding + 1; + std::fill_n(output_data, batch * out_channels * out_sequence, 0.0f); + + const int in_channels_per_group = in_channels / groups; + const int out_channels_per_group = out_channels / groups; + + for (int b = 0; b < batch; ++b) { + for (int oc = 0; oc < out_channels; ++oc) { + const int group_idx = oc / out_channels_per_group; + const int oc_in_group = oc % out_channels_per_group; + for (int out_pos = 0; out_pos < out_sequence; ++out_pos) { + float sum = 0.0f; + for (int ic_in_group = 0; ic_in_group < in_channels_per_group; ++ic_in_group) { + const int ic = group_idx * in_channels_per_group + ic_in_group; + const int base_input_idx = b * (in_channels * sequence) + ic * sequence; + const int base_weight_idx = (ic * out_channels_per_group + oc_in_group) * kernel_size; + + for (int k = 0; k < kernel_size; ++k) { + int input_pos = out_pos + padding - k * dilation; + if (input_pos % stride != 0) { continue; } + input_pos /= stride; + if (input_pos < 0 || input_pos >= sequence) { continue; } + + const int input_idx = base_input_idx + input_pos; + const int weight_idx = base_weight_idx + k; + sum += input_data[input_idx] * weight_data[weight_idx]; + } + } + if (bias_data != nullptr) { sum += bias_data[oc]; } + const int output_idx = b * (out_channels * out_sequence) + oc * out_sequence + out_pos; + output_data[output_idx] = sum; + } + } + } +} + +class ConvTranspose1DModule : public 
nn::Module { + nn::ConvTranspose1D conv_; + + public: + ConvTranspose1DModule(int in_channel, int out_channel, int kernel_size, int stride, int padding, int output_padding, + int dilation, int groups, bool bias) { + conv_ = reg("conv", in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + } + + std::vector forward(const std::vector& inputs, const std::vector& args) override { + return {conv_(inputs[0])}; + } +}; + +class ConvTranspose1DKernelTest : public KernelTest { + public: + bool testConvTranspose1DOnce(const std::unordered_map& cfg) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + + auto module = ConvTranspose1DModule(in_channel, out_channel, kernel_size, stride, padding, output_padding, dilation, + groups, bias); + + auto weight_param = + Tensor::random({in_channel, out_channel / groups, kernel_size}, -1, 1, kFloat32, kCPU); + auto bias_param = Tensor::random({out_channel}, -1, 1, kFloat32, kCPU); + weight_param.setName("conv.weight"); + bias_param.setName("conv.bias"); + + auto param = ParameterFile::create(); + param->push("conv.weight", weight_param); + if (bias) { param->push("conv.bias", bias_param); } + module.load(param); + + auto input = Tensor::random({batch, in_channel, sequence}, -1, 1, kFloat32, kCPU); + auto predict = module(input)[0]; + + auto expected = Tensor::zeros(predict.shape(), kFloat32, kCPU); + naive_conv_transpose1d(input.ptr(), weight_param.ptr(), bias ? 
bias_param.ptr() : nullptr, + expected.ptr(), batch, in_channel, sequence, out_channel, kernel_size, stride, padding, + dilation, output_padding, groups); + + auto result = test::allClose(expected, predict, 1e-4f, 1e-4f); + if (!result) { + print(result); + return false; + } + return true; + } + + bool testConvTranspose1D(const std::vector>& cfgs) { + for (auto& cfg : cfgs) { + if (!testConvTranspose1DOnce(cfg)) { + auto batch = cfg.at("batch"); + auto in_channel = cfg.at("in_channel"); + auto out_channel = cfg.at("out_channel"); + auto sequence = cfg.at("sequence"); + auto kernel_size = cfg.at("kernel_size"); + auto stride = cfg.at("stride"); + auto padding = cfg.at("padding"); + auto output_padding = cfg.at("output_padding"); + auto dilation = cfg.at("dilation"); + auto groups = cfg.at("groups"); + auto bias = cfg.at("bias"); + print(batch, in_channel, out_channel, sequence, kernel_size, stride, padding, output_padding, dilation, groups, bias); + return false; + } + } + return true; + } +}; diff --git a/tests/cpu/KernelTest.cpp b/tests/cpu/KernelTest.cpp index 9f8d613ee..575360703 100644 --- a/tests/cpu/KernelTest.cpp +++ b/tests/cpu/KernelTest.cpp @@ -857,6 +857,48 @@ TEST_F(FlashAttn2KernelTest, fwd_bshd) { } #endif +//===----------------------------------------------------------------------===// +// Tanh +//===----------------------------------------------------------------------===// +#include "TanhKernelTest.hpp" +TEST_F(TanhKernelTest, TanhFloat32) { EXPECT_EQ(testTanh({{8}, {2, 3, 4}}), true); } + +//===----------------------------------------------------------------------===// +// ConvTranspose1D +//===----------------------------------------------------------------------===// +#include "ConvTranspose1DKernelTest.hpp" +TEST_F(ConvTranspose1DKernelTest, Basic) { + EXPECT_EQ(testConvTranspose1D({ + { + {"batch", 1}, + {"in_channel", 2}, + {"out_channel", 3}, + {"sequence", 4}, + {"kernel_size", 3}, + {"stride", 2}, + {"padding", 1}, + {"output_padding", 0}, 
+ {"dilation", 1}, + {"groups", 1}, + {"bias", 1}, + }, + { + {"batch", 2}, + {"in_channel", 1}, + {"out_channel", 2}, + {"sequence", 5}, + {"kernel_size", 2}, + {"stride", 1}, + {"padding", 0}, + {"output_padding", 0}, + {"dilation", 1}, + {"groups", 1}, + {"bias", 0}, + }, + }), + true); +} + //===----------------------------------------------------------------------===// // Conv2D Test // diff --git a/tests/cpu/TanhKernelTest.hpp b/tests/cpu/TanhKernelTest.hpp new file mode 100644 index 000000000..ff6762170 --- /dev/null +++ b/tests/cpu/TanhKernelTest.hpp @@ -0,0 +1,49 @@ +// Copyright (c) MLLM Team. +// Licensed under the MIT License. +#pragma once + +#include + +#include "KernelTestHelper.hpp" +#include "mllm/mllm.hpp" +#include "mllm/nn/Nn.hpp" + +class TanhModule : public mllm::nn::Module { + mllm::nn::Tanh tanh_; + + public: + TanhModule() { tanh_ = reg("tanh"); } + + std::vector forward(const std::vector& inputs, + const std::vector& args) override { + return {tanh_(inputs[0])}; + } +}; + +class TanhKernelTest : public KernelTest { + public: + bool testTanh(const std::vector& shapes) { + using mllm::Tensor; + using mllm::kCPU; + using mllm::kFloat32; + TanhModule module; + + for (auto& s : shapes) { + auto input = Tensor::random(s, -3, 3, kFloat32, kCPU); + auto output = module(input)[0]; + auto expected = Tensor::empty(s, kFloat32, kCPU).alloc(); + + const auto* in_ptr = input.ptr(); + auto* out_ptr = expected.ptr(); + const auto numel = input.numel(); + for (size_t i = 0; i < numel; ++i) { out_ptr[i] = std::tanh(in_ptr[i]); } + + auto result = mllm::test::allClose(expected, output, 1e-5f, 1e-5f); + if (!result) { + mllm::print(result); + return false; + } + } + return true; + } +};