diff --git a/.agents/skills/tilelang-build/SKILL.md b/.agents/skills/tilelang-build/SKILL.md new file mode 100644 index 0000000000..f474736fc1 --- /dev/null +++ b/.agents/skills/tilelang-build/SKILL.md @@ -0,0 +1,88 @@ +--- +name: tilelang-build +description: Repository-specific build, rebuild, install, and test instructions for tilelang. Use when working in the tilelang repository and the correct commands are needed for building from source, reinstalling after changes, or running project tests. +--- + +# Build & Install + +## Installing / Rebuilding tilelang + +The standard way to build and install: + +```bash +pip install . +``` + +Or with verbose output for debugging build issues: + +```bash +pip install . -v +``` + +`uv pip install .` also works if `uv` is available but is not required. + +Build dependencies are declared in `pyproject.toml` and resolved automatically during `pip install .`. + +If `ccache` is available, repeated builds only recompile changed C++ files. + +## Alternative: Development Build with `--no-build-isolation` + +If you need faster iteration (e.g. calling `cmake` directly to recompile C++ without re-running the full pip install), install build dependencies first: + +```bash +pip install -r requirements-dev.txt +pip install --no-build-isolation . +``` + +After this, you can invoke `cmake --build build` directly to recompile only changed C++ files. This is useful when iterating on C++ code. + +## Alternative: cmake + PYTHONPATH (recommended for C++ development) + +For the fastest C++ iteration, bypass pip entirely and drive cmake directly: + +```bash +# Configure (auto-detects CUDA; git submodules are initialised automatically) +cmake -S . -B build + +# Build +cmake --build build -j$(nproc) + +# Make the local tilelang package importable +export PYTHONPATH=$(pwd):$PYTHONPATH +``` + +After the initial configure, recompiling is just `cmake --build build -j$(nproc)`. The runtime automatically discovers native libraries from `build/lib/` when it detects a dev checkout (see `tilelang/env.py`). + +Useful cmake options: + +| Flag | Purpose | +|------|---------| +| `-DUSE_CUDA=ON/OFF` | Enable/disable CUDA backend (ON by default) | +| `-DUSE_ROCM=ON` | Enable ROCm/HIP backend | +| `-DUSE_METAL=ON` | Enable Metal backend (default on macOS) | +| `-DCMAKE_BUILD_TYPE=Debug` | Debug build with `TVM_LOG_DEBUG` enabled | + +## Editable Installs + +**Never use `pip install -e .`** (editable install). When running Python from the repo root, the local `./tilelang` directory is imported instead of the installed copy (because `.` is on `sys.path` by default), so an editable install is unnecessary and, with this project's layout, only invites import confusion. + +## Running Tests + +Most tests require a GPU.
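+Run the full suite from the repository root: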
+ +```bash +python -m pytest testing/python/ -x +``` + +Run a specific test file or test case: + +```bash +python -m pytest testing/python/language/test_tilelang_language_copy.py -x +python -m pytest testing/python/language/test_tilelang_language_copy.py -x -k "test_name" +``` + +For Metal-specific tests (requires macOS with Apple Silicon): + +```bash +python -m pytest testing/python/metal/ -x +``` diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 8d5f3ffb48..1931c353ab 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -201,10 +201,10 @@ jobs: if [[ "${UV_INDEX}" == *"/nightly/"* ]]; then uv pip install --prerelease=allow -v torch fi - uv pip install -v -r requirements-test.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + uv pip install -v -r requirements-test.txt echo "import torch; print(f'torch: {torch.__version__}')" | uv run --no-project --script - if [[ "${{ matrix.runner.toolkit }}" == *"CUDA"* ]]; then - uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt -i https://pypi.tuna.tsinghua.edu.cn/simple + uv pip install --no-build-isolation-package=flash-attn -v -r requirements-test-cuda.txt echo "import flash_attn; print(f'flash_attn: {flash_attn.__version__}')" | uv run --no-project --script - # elif [[ "${{ matrix.runner.toolkit }}" == *"ROCm"* ]]; then # uv pip install -v -r requirements-test-rocm.txt @@ -304,12 +304,12 @@ jobs: # Run distributed tests (marked with @requires_distributed) with TILELANG_USE_DISTRIBUTED=1 # DeepEP tests requires fullmesh nvl or internode environment, we disable for now echo "Running distributed examples with TILELANG_USE_DISTRIBUTED=1:" - TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed --ignore-glob='*deepep*' . || true + TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed --ignore-glob='*deepep*' . # Run remaining example tests (non-distributed) # Temporarily disable problematic tests: sink, vs_sparse echo "Running non-distributed examples:" - "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not sink and not vs_sparse" . || true + "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not sink and not vs_sparse" . # NVIDIA CUDA tests - name: Run CUDA tests with Python ${{ matrix.python-version }} (${{ matrix.runner.toolkit }}) @@ -325,12 +325,12 @@ jobs: # Run distributed tests (marked with @requires_distributed) with TILELANG_USE_DISTRIBUTED=1 echo "Running distributed tests with TILELANG_USE_DISTRIBUTED=1:" - TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed . || true + TILELANG_USE_DISTRIBUTED=1 "${PYTEST[@]}" --maxfail=3 --numprocesses=1 -m distributed . # Run remaining tests (non-distributed) # Temporarily disable problematic tests: tilelibrary_gemm, jit_gemm_ctypes echo "Running non-distributed tests:" - "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not tilelibrary_gemm and not jit_gemm_ctypes" . || true + "${PYTEST[@]}" --maxfail=3 --numprocesses=2 -m "not distributed" -k "not tilelibrary_gemm and not jit_gemm_ctypes" . - name: List generated files if: ${{ !cancelled() }} diff --git a/.github/workflows/dist.yml b/.github/workflows/dist.yml index 74132ffb3f..ade9eb8cb2 100644 --- a/.github/workflows/dist.yml +++ b/.github/workflows/dist.yml @@ -125,7 +125,7 @@ jobs: fi - name: Build wheels - uses: pypa/cibuildwheel@v3.3 + uses: pypa/cibuildwheel@v3.4 with: package-dir: . 
output-dir: wheelhouse diff --git a/.gitignore b/.gitignore index e85c2c0943..9d994457ee 100644 --- a/.gitignore +++ b/.gitignore @@ -133,3 +133,4 @@ maint/host_checks/logs/* # perf regression test .perf_regression/ +nvshmem_issue.md diff --git a/3rdparty/composable_kernel b/3rdparty/composable_kernel index 1c45ca35dd..b38bb492a1 160000 --- a/3rdparty/composable_kernel +++ b/3rdparty/composable_kernel @@ -1 +1 @@ -Subproject commit 1c45ca35dd5c215e0c1db1f40f01556f467f52a8 +Subproject commit b38bb492a1a55b5abb0c345962143c0f9c482cfb diff --git a/3rdparty/tvm b/3rdparty/tvm index 23bce012ff..0e15b274bc 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 23bce012ffd255a24289eea6ceab74a40b94a096 +Subproject commit 0e15b274bce8b46f971abf5ac390e844aa6acee5 diff --git a/CMakeLists.txt b/CMakeLists.txt index 4fb370d509..a4601f4f3e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -2,6 +2,12 @@ # https://github.com/mlc-ai/mlc-llm/blob/main/CMakeLists.txt cmake_minimum_required(VERSION 3.26) + +# Detect CUDA toolkit: tries host installation first, then falls back to +# pip-installed packages (env WITH_PIP_CUDA_TOOLCHAIN or auto-detect). +# Must be included before project() so CMAKE_CUDA_COMPILER is set. +include(${CMAKE_CURRENT_LIST_DIR}/cmake/FindPipCUDAToolkit.cmake) + project(TILE_LANG C CXX) set(CMAKE_CXX_STANDARD 17) @@ -110,6 +116,37 @@ foreach(BACKEND IN LISTS TILELANG_BACKENDS) endforeach() set(PREBUILD_CYTHON ON) + +# CUDA stub libraries (cuda/cudart/nvrtc) are used to build wheels that can run +# across different CUDA Toolkit major versions and/or on CPU-only machines by +# avoiding hard DT_NEEDED dependencies on versioned CUDA SONAMEs. +# +# These stubs are currently POSIX-only (dlopen/dlsym via <dlfcn.h>). +if(WIN32 AND NOT CYGWIN) + set(_TILELANG_USE_CUDA_STUBS_DEFAULT OFF) +else() + set(_TILELANG_USE_CUDA_STUBS_DEFAULT ON) +endif() +option(TILELANG_USE_CUDA_STUBS + "Use POSIX dlopen-based CUDA stub libraries (cuda/cudart/nvrtc) for portable wheels" + ${_TILELANG_USE_CUDA_STUBS_DEFAULT}) +unset(_TILELANG_USE_CUDA_STUBS_DEFAULT) + +# HIP stub libraries (hip/hiprtc) are used to build wheels that can be imported +# on machines without ROCm installed by avoiding hard DT_NEEDED dependencies on +# libamdhip64.so / libhiprtc.so. +# +# These stubs are currently POSIX-only (dlopen/dlsym via <dlfcn.h>). +if(WIN32 AND NOT CYGWIN) + set(_TILELANG_USE_HIP_STUBS_DEFAULT OFF) +else() + # Only meaningful when USE_ROCM is enabled. + set(_TILELANG_USE_HIP_STUBS_DEFAULT ON) +endif() +option(TILELANG_USE_HIP_STUBS + "Use POSIX dlopen-based HIP stub libraries (hip/hiprtc) for portable wheels" + ${_TILELANG_USE_HIP_STUBS_DEFAULT}) +unset(_TILELANG_USE_HIP_STUBS_DEFAULT) # Configs end include(cmake/load_tvm.cmake) @@ -127,6 +164,8 @@ foreach(BACKEND IN LISTS TILELANG_BACKENDS) set(${_backend_var} ${TILELANG_OPTION_${_backend_var}} CACHE STRING "${_doc}" FORCE) set(${_backend_var} ${TILELANG_OPTION_${_backend_var}}) endforeach() +# tvm tries to detect gtest by default, but may fail if its header is not installed.
+set(USE_GTEST OFF) # Include directories for TileLang set(TILE_LANG_INCLUDES ${TVM_INCLUDES}) @@ -140,8 +179,8 @@ file(GLOB TILE_LANG_SRCS src/op/*.cc src/target/utils.cc src/target/codegen_c_host.cc - src/target/codegen_cpp.cc - src/target/rt_mod_cpp.cc + src/target/codegen_c.cc + src/target/rt_mod_c.cc # intrin_rule doesn't have system dependency src/target/intrin_rule*.cc ) @@ -151,6 +190,8 @@ list(APPEND TILE_LANG_SRCS src/runtime/error_helpers.cc ) +set(TILELANG_OUTPUT_TARGETS tilelang tvm) + # Track if the user explicitly selected a backend via cache options. set(TILELANG_BACKEND_USER_SELECTED OFF) foreach(BACKEND IN LISTS TILELANG_BACKENDS) @@ -183,49 +224,11 @@ if(NOT TILELANG_BACKEND_USER_SELECTED) endif() endif() -if(USE_METAL) - file(GLOB TILE_LANG_METAL_SRCS - src/target/rt_mod_metal.cc - ) - list(APPEND TILE_LANG_SRCS ${TILE_LANG_METAL_SRCS}) - # FIXME: CIBW failed with backtrace, why??? - set(TVM_FFI_USE_LIBBACKTRACE OFF) -elseif(USE_ROCM) - set(CMAKE_HIP_STANDARD 17) - include(${TVM_SOURCE}/cmake/utils/FindROCM.cmake) - find_rocm(${USE_ROCM}) - add_compile_definitions(__HIP_PLATFORM_AMD__ __HIP_PLATFORM_HCC__=1) - - file(GLOB TILE_LANG_HIP_SRCS - src/target/codegen_hip.cc - src/target/rt_mod_hip.cc - ) - list(APPEND TILE_LANG_SRCS ${TILE_LANG_HIP_SRCS}) - list(APPEND TILE_LANG_INCLUDES ${ROCM_INCLUDE_DIRS}) -elseif(USE_CUDA) - set(CMAKE_CUDA_STANDARD 17) - find_package(CUDAToolkit REQUIRED) - set(CMAKE_CUDA_COMPILER "${CUDAToolkit_BIN_DIR}/nvcc") - add_compile_definitions("CUDA_MAJOR_VERSION=${CUDAToolkit_VERSION_MAJOR}") - - # Set `USE_CUDA=/usr/local/cuda-x.y` - cmake_path(GET CUDAToolkit_BIN_DIR PARENT_PATH USE_CUDA) - - file(GLOB TILE_LANG_CUDA_SRCS - src/runtime/runtime.cc - src/runtime/tilescale_cuda_module.cc - src/target/ptx.cc - src/target/codegen_cuda.cc - src/target/codegen_py.cc - src/target/codegen_utils.cc - src/target/codegen_cutedsl.cc - src/target/rt_mod_cuda.cc - src/target/rt_mod_cutedsl.cc - ) - list(APPEND TILE_LANG_SRCS ${TILE_LANG_CUDA_SRCS}) - - list(APPEND TILE_LANG_INCLUDES ${CUDAToolkit_INCLUDE_DIRS}) -endif() +# Backend-local CMake files own native source lists, stubs, include paths, and +# compile definitions. Top-level CMake only selects and delegates. +include("${CMAKE_CURRENT_SOURCE_DIR}/src/backend/cuda/CMakeLists.txt") +include("${CMAKE_CURRENT_SOURCE_DIR}/src/backend/rocm/CMakeLists.txt") +include("${CMAKE_CURRENT_SOURCE_DIR}/src/backend/metal/CMakeLists.txt") set(USE_Z3 ON CACHE STRING "Use Z3 SMT solver for TileLang optimizations") set(USE_PYPI_Z3 ON CACHE BOOL "Use Z3 provided by PyPI z3-solver package") @@ -235,27 +238,50 @@ if(USE_Z3 AND USE_PYPI_Z3) find_package(Z3 REQUIRED) endif() +# Enable custom logging so we control the output format (e.g. strip build paths +# from __FILE__ so wheel users don't see CI machine paths in warnings). +set(USE_CUSTOM_LOGGING ON CACHE BOOL "Use custom logging implementation" FORCE) + +# Detect release (wheel) builds: in CI (cibuildwheel) or scikit-build-core wheel builds, +# we strip source paths from LOG(WARNING) etc. for a cleaner user experience. +# Local dev builds keep full paths for debugging. 
+if(DEFINED ENV{CIBUILDWHEEL} OR "$ENV{SKBUILD_STATE}" STREQUAL "wheel") + set(TILELANG_RELEASE_BUILD_DEFAULT ON) +else() + set(TILELANG_RELEASE_BUILD_DEFAULT OFF) +endif() +option(TILELANG_RELEASE_BUILD "Strip source paths from log messages (for wheel releases)" ${TILELANG_RELEASE_BUILD_DEFAULT}) + # Include tvm after configs have been populated add_subdirectory(${TVM_SOURCE} tvm EXCLUDE_FROM_ALL) +# Provide the custom LogMessageImpl / LogFatalImpl implementation to TVM, +# since TVM_LOG_CUSTOMIZE=1 requires them to be supplied by the user. +target_sources(tvm_objs PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/src/runtime/logging.cc") +if(TILELANG_RELEASE_BUILD) + target_compile_definitions(tvm_objs PRIVATE TILELANG_RELEASE_BUILD=1) +endif() + # Resolve compile warnings in tvm add_compile_definitions(DMLC_USE_LOGGING_LIBRARY=) add_library(tilelang_objs OBJECT ${TILE_LANG_SRCS}) # Set debug mode compile definitions -# We open the deubg option of TVM, i.e. TVM_LOG_DEBUG +# Enable the TVM debug option, i.e., TVM_LOG_DEBUG if(CMAKE_BUILD_TYPE STREQUAL "Debug") message(STATUS "Building TileLang with DEBUG mode") target_compile_definitions(tilelang_objs PRIVATE "TVM_LOG_DEBUG") endif() target_include_directories(tilelang_objs PRIVATE ${TILE_LANG_INCLUDES}) +target_compile_definitions(tilelang_objs PRIVATE TVM_LOG_CUSTOMIZE=1) +if(TILELANG_RELEASE_BUILD) + target_compile_definitions(tilelang_objs PRIVATE TILELANG_RELEASE_BUILD=1) +endif() add_library(tilelang SHARED $) -add_library(tilelang_module SHARED $) -target_link_libraries(tilelang PUBLIC tvm_runtime tvm) -target_link_libraries(tilelang_module PUBLIC tvm) +target_link_libraries(tilelang PUBLIC tvm) # Place dev build outputs under build/lib for consistency set_target_properties(tilelang PROPERTIES @@ -263,11 +289,6 @@ set_target_properties(tilelang PROPERTIES RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" ) -set_target_properties(tilelang_module PROPERTIES - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" -) # Build cython extension find_package(Python REQUIRED COMPONENTS Interpreter Development.Module ${SKBUILD_SABI_COMPONENT}) @@ -288,6 +309,13 @@ endif() python_add_library(tilelang_cython_wrapper MODULE "${CMAKE_BINARY_DIR}/tilelang_cython_wrapper.cpp" ${USE_SABI} WITH_SOABI) +# Disable Cython's PEP-489 multi-phase init for the wrapper. The generated +# C++ depends on CPython's private `_xxsubinterpreters` module at import +# time, which is stripped from some distributor-built Python 3.12 builds +# (notably Red Hat's RHEL 9 system Python). Single-phase init avoids that +# dependency and matches Cython's own suggested workaround. See #2125. 
+target_compile_definitions(tilelang_cython_wrapper PRIVATE CYTHON_PEP489_MULTI_PHASE_INIT=0) + # Ensure dev builds drop the extension into build/lib alongside other shared libs set_target_properties(tilelang_cython_wrapper PROPERTIES LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" @@ -304,7 +332,7 @@ install(TARGETS tilelang_cython_wrapper # Copy libz3.so to build folder to workaround isolated build env issue if(USE_Z3 AND USE_PYPI_Z3) get_target_property(Z3_LIBRARY_PATH z3::libz3 IMPORTED_LOCATION) - install(FILES "${Z3_LIBRARY_PATH}" DESTINATION "${CMAKE_BINARY_DIR}/tvm") + install(FILES "${Z3_LIBRARY_PATH}" DESTINATION "${CMAKE_BINARY_DIR}/lib") if(APPLE) set_target_properties(tvm PROPERTIES BUILD_RPATH "@loader_path") else() @@ -312,87 +340,74 @@ if(USE_Z3 AND USE_PYPI_Z3) endif() endif() +if(DEFINED TILELANG_ACTIVE_BACKEND_STUB_LINK) + foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + target_link_libraries(${target} PUBLIC ${TILELANG_ACTIVE_BACKEND_STUB_LINK}) + endforeach() +endif() + +# Append stub targets after the linking loop so they don't link to themselves +if(DEFINED TILELANG_ACTIVE_BACKEND_STUB_TARGETS) + list(APPEND TILELANG_OUTPUT_TARGETS ${TILELANG_ACTIVE_BACKEND_STUB_TARGETS}) +endif() + +unset(PATCHELF_EXECUTABLE CACHE) + if(APPLE) set(TILELANG_INSTALL_RPATH "@loader_path;@loader_path/../../tvm_ffi/lib") if(USE_Z3 AND USE_PYPI_Z3) - # some z3 is placed in lib/ and some in bin/, we add both in rpath - list(APPEND TILELANG_INSTALL_RPATH "@loader_path/../../z3/lib" "@loader_path/../../z3/bin") + # Some z3 is placed in lib/ and some in bin/, we add both in rpath + string(APPEND TILELANG_INSTALL_RPATH ";@loader_path/../../z3/lib;@loader_path/../../z3/bin") endif() elseif(UNIX) set(TILELANG_INSTALL_RPATH "\$ORIGIN:\$ORIGIN/../../tvm_ffi/lib") if(USE_Z3 AND USE_PYPI_Z3) - # cmake uses ; by default, we explicitly use : for linux string(APPEND TILELANG_INSTALL_RPATH ":\$ORIGIN/../../z3/lib") endif() + if(DEFINED TILELANG_ACTIVE_BACKEND_RPATH_EXTRA) + string(APPEND TILELANG_INSTALL_RPATH "${TILELANG_ACTIVE_BACKEND_RPATH_EXTRA}") + endif() + find_program(PATCHELF_EXECUTABLE patchelf) + if (NOT PATCHELF_EXECUTABLE) + message(STATUS "`patchelf` not found.") + endif() endif() -set_target_properties( - tilelang tilelang_module tvm tvm_runtime - PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") +# Let libtilelang search for tvm in the same directory +foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + set_target_properties(${target} PROPERTIES INSTALL_RPATH "${TILELANG_INSTALL_RPATH}") + set_target_properties(${target} PROPERTIES + LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + RUNTIME_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" + ) +endforeach() + +# Strip backend runtime dependencies for portable wheels +if(DEFINED TILELANG_ACTIVE_BACKEND_PATCHELF_REMOVE AND PATCHELF_EXECUTABLE) + foreach(_needed IN LISTS TILELANG_ACTIVE_BACKEND_PATCHELF_REMOVE) + set(_patchelf_remove_args "${_patchelf_remove_args} --remove-needed ${_needed}") + endforeach() + foreach(target IN LISTS TILELANG_OUTPUT_TARGETS) + install(CODE " + execute_process( + COMMAND ${PATCHELF_EXECUTABLE}${_patchelf_remove_args} + \"$\" + WORKING_DIRECTORY \"${CMAKE_INSTALL_PREFIX}\" + RESULT_VARIABLE patchelf_result + ) + if(patchelf_result EQUAL 0) + message(STATUS \"patchelf: removed dependencies from $\") + else() + message(WARNING \"patchelf failed for $\") + endif() + ") + endforeach() +endif() install( - TARGETS tvm tvm_runtime tilelang_module tilelang + 
TARGETS ${TILELANG_OUTPUT_TARGETS} LIBRARY DESTINATION tilelang/lib + RUNTIME DESTINATION tilelang/lib + ARCHIVE DESTINATION tilelang/lib ) - -# Build tilescale_ext PyTorch C++ extension -if(USE_CUDA) - # Find Torch - execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import torch; print(torch.utils.cmake_prefix_path)" - OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH - OUTPUT_STRIP_TRAILING_WHITESPACE - RESULT_VARIABLE TORCH_CMAKE_RESULT - ) - if(TORCH_CMAKE_RESULT EQUAL 0 AND EXISTS "${TORCH_CMAKE_PREFIX_PATH}") - list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX_PATH}") - endif() - - find_package(Torch QUIET) - if(Torch_FOUND) - message(STATUS "Building tilescale_ext with Torch ${Torch_VERSION}") - - set(TILESCALE_EXT_SOURCES - ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/ts_ext_bindings.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/tensor.cpp - ${CMAKE_CURRENT_SOURCE_DIR}/tilelang/utils/ts_ext/ipc_ops.cpp - ) - - # Find libtorch_python.so - execute_process( - COMMAND "${Python_EXECUTABLE}" -c "import torch; import os; print(os.path.join(os.path.dirname(torch.__file__), 'lib', 'libtorch_python.so'))" - OUTPUT_VARIABLE TORCH_PYTHON_LIBRARY - OUTPUT_STRIP_TRAILING_WHITESPACE - RESULT_VARIABLE TORCH_PYTHON_RESULT - ) - - python_add_library(tilescale_ext_C MODULE ${TILESCALE_EXT_SOURCES} WITH_SOABI) - target_compile_definitions(tilescale_ext_C PRIVATE TORCH_EXTENSION_NAME=_C) - target_include_directories(tilescale_ext_C PRIVATE - ${TORCH_INCLUDE_DIRS} - ${CUDAToolkit_INCLUDE_DIRS} - ) - - if(TORCH_PYTHON_RESULT EQUAL 0 AND EXISTS "${TORCH_PYTHON_LIBRARY}") - message(STATUS "Found libtorch_python: ${TORCH_PYTHON_LIBRARY}") - target_link_libraries(tilescale_ext_C PRIVATE ${TORCH_LIBRARIES} ${TORCH_PYTHON_LIBRARY} CUDA::cudart) - else() - message(WARNING "libtorch_python.so not found, extension may have undefined symbols") - target_link_libraries(tilescale_ext_C PRIVATE ${TORCH_LIBRARIES} CUDA::cudart) - endif() - - target_compile_options(tilescale_ext_C PRIVATE -fPIC) - set_target_properties(tilescale_ext_C PROPERTIES - OUTPUT_NAME "_C" - CXX_STANDARD 17 - LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib" - ) - - # Install as tilescale_ext/_C.so so it can be imported as tilescale_ext._C - install(TARGETS tilescale_ext_C - LIBRARY DESTINATION tilescale_ext - RUNTIME DESTINATION tilescale_ext) - else() - message(WARNING "Torch not found, tilescale_ext will not be built") - endif() -endif() diff --git a/VERSION b/VERSION index e52aba075b..97c7127d8d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.1.7.post1 +0.1.9.post1 diff --git a/cmake/FindPipCUDAToolkit.cmake b/cmake/FindPipCUDAToolkit.cmake new file mode 100644 index 0000000000..29cb3f3642 --- /dev/null +++ b/cmake/FindPipCUDAToolkit.cmake @@ -0,0 +1,70 @@ +# FindPipCUDAToolkit.cmake +# +# Locate CUDA toolkit — first trying the host system, then falling back +# to pip-installed packages (nvidia-cuda-nvcc, nvidia-cuda-cccl). +# +# This module should be included BEFORE project() to set CMAKE_CUDA_COMPILER +# when pip CUDA is used. +# +# Detection order: +# 1. Try find_package(CUDAToolkit QUIET) — succeeds if a host CUDA +# installation is available; skip pip detection. +# 2. If env var WITH_PIP_CUDA_TOOLCHAIN is set to a path (e.g., .../cu13), +# use that directory directly as the CUDA toolkit root. +# 3. Otherwise, try auto-detecting from the current Python environment's +# site-packages (works with --no-build-isolation). 
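+# +# Example invocation (hypothetical path; assumes a pip-provided toolchain laid +# out under site-packages/nvidia/cu13, e.g. from the nvidia-cuda-nvcc wheel): +#   WITH_PIP_CUDA_TOOLCHAIN=/path/to/site-packages/nvidia/cu13 cmake -S . -B build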
+ +# --- Try host CUDA first --- +find_package(CUDAToolkit QUIET) +if(CUDAToolkit_FOUND) + return() +endif() + +find_program(_PIP_CUDA_PYTHON_EXE NAMES python3 python) +if(NOT _PIP_CUDA_PYTHON_EXE) + return() +endif() + +# --- Strategy 1: explicit path via env var --- +if(DEFINED ENV{WITH_PIP_CUDA_TOOLCHAIN}) + set(_PIP_CUDA_ROOT "$ENV{WITH_PIP_CUDA_TOOLCHAIN}") + if(NOT EXISTS "${_PIP_CUDA_ROOT}/bin/nvcc") + message(FATAL_ERROR + "FindPipCUDAToolkit: WITH_PIP_CUDA_TOOLCHAIN is set to '${_PIP_CUDA_ROOT}' " + "but nvcc was not found at '${_PIP_CUDA_ROOT}/bin/nvcc'") + endif() + # Prepare the directory (create lib64 symlink, unversioned .so symlinks, + # libcuda.so stub) that CMake / nvcc expect but pip packages omit. + execute_process( + COMMAND "${_PIP_CUDA_PYTHON_EXE}" "${CMAKE_CURRENT_LIST_DIR}/find_pip_cuda.py" + "${_PIP_CUDA_ROOT}" + OUTPUT_QUIET + ) + message(STATUS "FindPipCUDAToolkit: using env WITH_PIP_CUDA_TOOLCHAIN=${_PIP_CUDA_ROOT}") +else() + # --- Strategy 2: auto-detect from current Python env --- + execute_process( + COMMAND "${_PIP_CUDA_PYTHON_EXE}" "${CMAKE_CURRENT_LIST_DIR}/find_pip_cuda.py" + OUTPUT_VARIABLE _PIP_CUDA_OUTPUT + OUTPUT_STRIP_TRAILING_WHITESPACE + RESULT_VARIABLE _PIP_CUDA_RESULT + ) + + if(NOT _PIP_CUDA_RESULT EQUAL 0) + message(STATUS "FindPipCUDAToolkit: pip-installed CUDA toolkit not found") + return() + endif() + + string(JSON _PIP_CUDA_ROOT GET "${_PIP_CUDA_OUTPUT}" "root") + message(STATUS "FindPipCUDAToolkit: auto-detected from Python environment") +endif() + +# --- Common pip-CUDA setup --- +set(CMAKE_CUDA_COMPILER "${_PIP_CUDA_ROOT}/bin/nvcc" CACHE FILEPATH "CUDA compiler (from pip)" FORCE) +set(CUDAToolkit_ROOT "${_PIP_CUDA_ROOT}" CACHE PATH "CUDA toolkit root (from pip)" FORCE) + +list(APPEND CMAKE_LIBRARY_PATH "${_PIP_CUDA_ROOT}/lib/stubs" "${_PIP_CUDA_ROOT}/lib") + +message(STATUS "FindPipCUDAToolkit: using pip-installed CUDA toolkit") +message(STATUS " nvcc: ${CMAKE_CUDA_COMPILER}") +message(STATUS " root: ${CUDAToolkit_ROOT}") diff --git a/cmake/find_pip_cuda.py b/cmake/find_pip_cuda.py new file mode 100644 index 0000000000..2bbf6b890a --- /dev/null +++ b/cmake/find_pip_cuda.py @@ -0,0 +1,103 @@ +"""Locate pip-installed CUDA toolkit and prepare it for CMake consumption. + +Used by cmake/FindPipCUDAToolkit.cmake via ``execute_process``. +Outputs a JSON object with paths on success, exits with code 1 on failure. 
+ +Usage: + python find_pip_cuda.py # auto-detect from current env + python find_pip_cuda.py /path/to/cu13 # use explicit path, just prepare it +""" + +import contextlib +import json +import pathlib +import subprocess +import sys + + +def _find_cu_dir(): + """Find the nvidia/cu directory from the nvidia pip package.""" + try: + import nvidia + except ImportError: + return None + + nvidia_dir = pathlib.Path(nvidia.__path__[0]) + cu_dirs = sorted( + (d for d in nvidia_dir.iterdir() if d.name[:2] == "cu" and d.name[2:].isdigit()), + key=lambda d: int(d.name[2:]), + ) + if not cu_dirs: + return None + cu_dir = cu_dirs[-1] + if (cu_dir / "bin" / "nvcc").is_file(): + return cu_dir + return None + + +def _ensure_lib_symlinks(cu_dir): + """Create symlinks that CMake / nvcc expect but pip packages omit.""" + lib_dir = cu_dir / "lib" + if not lib_dir.is_dir(): + return + + # nvcc expects lib64/ on 64-bit + lib64 = cu_dir / "lib64" + if not lib64.exists(): + with contextlib.suppress(OSError): + lib64.symlink_to("lib") + + # CMake expects unversioned .so (e.g., libcudart.so) + for so in lib_dir.glob("*.so.*"): + base = lib_dir / (so.name.split(".so.")[0] + ".so") + if not base.exists(): + with contextlib.suppress(OSError): + base.symlink_to(so.name) + + +def _ensure_cuda_stub(cu_dir): + """Create a minimal libcuda.so stub for build-time -lcuda linking.""" + stubs_dir = cu_dir / "lib" / "stubs" + stub = stubs_dir / "libcuda.so" + if stub.exists(): + return + stubs_dir.mkdir(parents=True, exist_ok=True) + src = stubs_dir / "_stub.c" + try: + src.write_text("void cuGetErrorString(void){}\n") + subprocess.check_call( + ["gcc", "-shared", "-o", str(stub), str(src)], + stderr=subprocess.DEVNULL, + ) + except Exception: + pass + finally: + src.unlink(missing_ok=True) + + +def main(): + if len(sys.argv) > 1: + # Explicit path provided — just prepare it + cu_dir = pathlib.Path(sys.argv[1]) + else: + # Auto-detect from current Python environment + cu_dir = _find_cu_dir() + + if cu_dir is None or not (cu_dir / "bin" / "nvcc").is_file(): + sys.exit(1) + + _ensure_lib_symlinks(cu_dir) + _ensure_cuda_stub(cu_dir) + + print( + json.dumps( + { + "nvcc": str(cu_dir / "bin" / "nvcc"), + "root": str(cu_dir), + } + ) + ) + + +if __name__ == "__main__": + main() diff --git a/docs/_static/img/ir_transform_diagram.png b/docs/_static/img/ir_transform_diagram.png index 3bd8689139..f6cbc9da4a 100644 Binary files a/docs/_static/img/ir_transform_diagram.png and b/docs/_static/img/ir_transform_diagram.png differ diff --git a/docs/deeplearning_operators/gemv.md b/docs/deeplearning_operators/gemv.md index 38287f2205..c2dddf47fe 100644 --- a/docs/deeplearning_operators/gemv.md +++ b/docs/deeplearning_operators/gemv.md @@ -292,7 +292,7 @@ def splitk_gemv_vectorized_tvm( C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): @@ -377,7 +377,7 @@ def get_best_config(N, K): C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype) C_reduced = T.alloc_local((1,), accum_dtype) with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): diff --git a/docs/merge_upstream_tilelang.md 
b/docs/merge_upstream_tilelang.md index 48a1ec9a65..c702585027 100644 --- a/docs/merge_upstream_tilelang.md +++ b/docs/merge_upstream_tilelang.md @@ -344,6 +344,135 @@ PR [#50](https://github.com/tile-ai/tilescale/pull/50) ("Sync mainstream TileLan --- +## 10. Practical Lessons from PR #58 + +This section captures the hard-won lessons from the `0.1.7.post1 → 0.1.9` sync (~50 upstream commits, ~80K LOC diff). Unlike PR #50, which cherry-picked individual commits, PR #58 merged the entire upstream delta in one operation — a much more aggressive approach that revealed systematic failure modes. + +### 10.1 `src/transform/` and `src/tl_templates/` Are NOT All TileScale-Exclusive + +Section 2.3 classifies these directories as "TileScale-specific pass infrastructure" and says "never overwrite". This is **misleading for bulk operations**. In practice: + +- **Most files in `src/transform/` exist in both repos** (e.g., `layout_inference.cc`, `loop_partition.cc`, `lower_tile_op.cc`, `inject_pipeline.cc`). They originated from upstream and were modified by both sides. +- **Only a small subset are truly TileScale-only**: `lower_cpengine_intrin.cc`, `storage_access.cc/h`, `wgmma_sync_rewriter.cc`, `align_dynamic_shared_memory_allocations.cc`, `inject_ptx_async_copy.cc`, `inject_fence_proxy.cc`. +- **Upstream also adds new transforms** (e.g., `producer_consumer_ws.cc`, `unroll_loop.cc`, `verify_parallel_loop.cc`, `fuse_mbarrier_arrive_expect_tx.cc`) that TileScale should absorb. +- **The rule**: For each file, check whether it exists in the upstream commit (`git cat-file -e <upstream-sha>:<path>`). If it does, `--theirs` (take upstream) is the safe default. Only keep `--ours` for files that are genuinely TileScale-only. + +### 10.2 `git merge` vs `git cherry-pick` + +- **`git merge` (one-shot)**: Fast but produces ~400 conflicted files. Resolution must be done programmatically (batch `--ours`/`--theirs`). The batch resolution can silently corrupt files that need manual adaptation. +- **`git cherry-pick` (per-commit)**: Safer, more auditable, but slow for 50+ commits. +- **For >20 commits**: Consider `git merge --no-commit`, then resolve conflicts with the file-by-file decision table below, then `git commit`. + +### 10.3 Mandatory Build-Import-Run Loop + +After conflict resolution, the merge is **never** clean on the first try.
Follow this loop until both `import tilelang` and a distributed example pass: + +```bash +ninja -C build 2>&1 | grep error: # fix C++ build errors +python -c "import tilelang" # fix Python import errors +python examples/distributed/example_xxx.py # fix runtime errors +``` + +Common failure categories and their symptoms: + +| Symptom | Root Cause | Fix | +|---------|-----------|-----| +| `undefined symbol: _ZN3tvm2tl31ApplyMultiVersionBufferRewriterE...` | Stale TileScale `.cc` kept as `--ours`; upstream added function to this file | `git checkout <upstream-sha> -- <file>` | +| `no matching function for call to 'VectorizeLoop(..., LayoutMap&)'` | Upstream removed/renamed an overload | Check upstream `loop_vectorize.h` for new signatures; adjust callers | +| `'create_list_of_mbarrier' was not declared` / `'get_mbarrier' was not declared` | TileScale ops registered in old `builtin.cc`; removed in upstream | Add them back to `builtin.cc` and `builtin.h` | +| `error: 'LoopPragmaUnroll' was not declared` | Upstream renamed to `PragmaUnrollLoop` | Bulk rename | +| `error: 'atomicadd_elem_op' was not declared; did you mean 'atomic_add_elem_op'?` | Upstream added underscore | Bulk rename | +| `Module has no function '__tilescale_init_table'` | Upstream `rt_mod_cuda.cc` uses `CUDAModuleCreate` instead of `TileScaleCUDAModuleCreate` | Restore `TileScaleCUDAModuleCreate` calls + include in `rt_mod_cuda.cc` | +| `'JITKernel' object has no attribute 'initialize'` | Upstream `jit/kernel.py` doesn't have TileScale's `initialize()` | Add back `initialize()` method + `allocator` attribute | +| `TVMFFIKernelAdapter has no attribute 'init_table'` | Upstream adapter doesn't have `init_table()` | Add back `init_table()` to `tilelang/jit/adapter/tvm_ffi.py` | +| `'lazy_jit' not found in tilelang.jit` | Upstream `jit/__init__.py` doesn't export `lazy_jit` | Remove from `__init__.py` import, or add back implementation | + +### 10.4 The Distributed Codegen Must Be Surgically Preserved + +TileScale adds significant infrastructure to CUDA codegen that upstream knows nothing about. After merging upstream codegen files, the following **must** be present: + +**In `src/target/codegen_cuda.h`**: +```cpp +static inline bool use_distributed() { + const char *env = std::getenv("TILELANG_USE_DISTRIBUTED"); + if (env) return std::string(env) == "1"; + return false; +} +// Inside class CodeGenTileLangCUDA: +bool use_distributed_{use_distributed()}; +bool need_multimem_h_{false}; +``` + +**In `src/target/codegen_cuda.cc`**: +```cpp +#include "../op/distributed.h" +#include "../op/sync.h" + +// Inside Finish(): +if (use_distributed_) { + decl_stream << "#include <tl_templates/cuda/sync.h>\n"; + decl_stream << "#include <tl_templates/cuda/ldst.h>\n"; + decl_stream << "#include <tl_templates/cuda/distributed.h>\n"; + decl_stream << "extern \"C\" __constant__ uint64_t meta_data[1024];\n"; +} +if (need_multimem_h_) { + decl_stream << "#include <tl_templates/cuda/multimem.h>\n"; +} +``` + +**In `src/target/rt_mod_cuda.cc`**: Replace upstream `CUDAModuleCreate` with `TileScaleCUDAModuleCreate` and add `#include "../runtime/tilescale_cuda_module.h"`.
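+ +A quick post-merge check (a sketch; run from the repo root, with the file and symbol names taken from the tables above) that the TileScale module factory survived the merge: + +```bash +grep -n -e TileScaleCUDAModuleCreate -e tilescale_cuda_module.h src/target/rt_mod_cuda.cc +``` + +Both patterns should match; if neither does, restore the file per section 10.8.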
+ +### 10.5 TileScale-Specific Python Utilities That Pass Silently + +These files exist only in TileScale and were not overwritten by the merge, but their callers in shared modules may have changed: + +| File | TileScale Purpose | What Can Break | +|------|-------------------|----------------| +| `tilelang/utils/allocator.py` | `BaseAllocator`, `get_allocator()` | `torch.set_default_device` conflicts with `all_gather_object`; `parse_device` must be correct | +| `tilelang/utils/tensor.py` (the `tensor()` function) | `tilelang.tensor(...)` factory | Lost `tensor()` function if upstream file overwrites it | +| `tilelang/utils/target.py` (the `parse_device()` function) | device string parsing for allocator | `parse_device("cuda")` returning hardcoded 0 instead of `current_device()` | +| `tilelang/distributed/utils.py` | `init_dist()`, `perf_fn()` | `torch.set_default_device("cuda")` before `init_process_group` causes NCCL device mismatch | + +### 10.6 The Device Mismatch Trap + +When `init_dist()` calls `torch.set_default_device("cuda")` (without device index), all PyTorch tensors default to `cuda:0`. With newer PyTorch (2.2+) passing `device_id` to `init_process_group`, NCCL enforces that collective tensors match the process group's device. This causes: + +``` +torch.distributed.DistBackendError: Tensor found on device cuda:0 but backend constrained to cuda:1 +``` + +**Fix**: Call `torch.cuda.set_device(local_rank)` BEFORE `init_process_group`, and use explicit device strings. Also ensure `parse_device("cuda")` returns `torch.cuda.current_device()` rather than hardcoded `0`. + +### 10.7 Duplicate Op Registration Detection + +After a large merge, upstream may have added op registrations that TileScale's files also register. Check with: + +```bash +python -c "import tilelang" 2>&1 | grep "already registered" +``` + +If you see `Global Function 'tl.X' is already registered`, search for duplicate `refl::GlobalDef().def("tl.X", ...)` registrations and remove the TileScale copy (keep the upstream one). + +### 10.8 After Merge: Restore Truly TileScale-Only Files from Old Main + +After batch resolution, verify these files match the pre-sync TileScale version: + +| Category | Key Files | +|----------|-----------| +| Distributed C++ ops | `src/op/distributed.cc/h`, `src/op/remote_copy.cc/h`, `src/op/sync.cc/h`, `src/op/multimem.cc/h`, `src/op/multimem_rewriter.h`, `src/op/gemm_py.cc/h` | +| Distributed runtime | `src/runtime/tilescale_cuda_module.cc/h`, `src/shared_memory/shared_memory.cc` | +| Distributed templates | `src/tl_templates/cuda/distributed.h`, `sync.h`, `ldst.h`, `multimem.h` | +| TileScale transforms | `lower_cpengine_intrin.cc`, `storage_access.cc/h`, `wgmma_sync_rewriter.cc`, `align_dynamic_shared_memory_allocations.cc`, `inject_ptx_async_copy.cc`, `inject_fence_proxy.cc` | +| Python distributed | `tilelang/distributed/**`, `tilelang/language/distributed/**`, `tilelang/utils/allocator.py` | +| Build config | `src/backend/cuda/CMakeLists.txt` (must include `tilescale_cuda_module.cc` and `shared_memory/shared_memory.cc`) | + +```bash +# Restore a known-good TileScale file +git show main:<path> > <path> +``` + +--- + ## 9.
Checklist for Each Sync PR Before opening the PR: @@ -353,8 +482,13 @@ Before opening the PR: - [ ] `CMakeLists.txt` conflict resolved; `tilescale_ext` target intact - [ ] `tilelang/__init__.py` still exports distributed namespace - [ ] Full build passes +- [ ] `import tilelang` succeeds with no import errors +- [ ] `tilelang.distributed` imports successfully - [ ] Shared `testing/python/` tests pass - [ ] At least one distributed example runs end-to-end +- [ ] `TileScaleCUDAModuleCreate` used in `rt_mod_cuda.cc` (not `CUDAModuleCreate`) +- [ ] Distributed template includes present in `codegen_cuda.cc` (`sync.h`, `ldst.h`, `distributed.h`, `multimem.h`, `meta_data`) +- [ ] No duplicate TVM FFI registrations (`python -c "import tilelang"` clean) - [ ] API-breaking upstream changes reflected in TileScale distributed layer if applicable - [ ] PR title follows: `[Sync] Merge upstream TileLang <version>` - [ ] PR description lists: last-synced upstream SHA, new upstream SHA, major features included, any skipped items with justification diff --git a/docs/programming_guides/python_compatibility.md b/docs/programming_guides/python_compatibility.md new file mode 100644 index 0000000000..b858e392ab --- /dev/null +++ b/docs/programming_guides/python_compatibility.md @@ -0,0 +1,59 @@ +# Python Compatibility + +TileLang is a Python-embedded DSL, but not all Python syntax is supported inside +the TileLang DSL. This guide clarifies what works, what doesn't, and how +to translate common Python patterns into TileLang equivalents. Specifically, we focus on +kernel-side semantics (code inside `with T.Kernel`). For host-side semantics when +using eager-style JIT, please stay tuned for our upcoming documentation. + +The examples below use the conventional aliases: + +```python +import tilelang +import tilelang.language as T +from tilelang import jit +``` + +## Control Flow & Loops + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `for i in range(n)` | ✅ | Maps to `T.serial(n)` | +| `for i in range(a,b,s)` | ✅ | Maps to `T.serial(a, b, s)` | +| `for x in list` | ❌ | Use index-based loop | +| `while condition` | ✅ | | +| `if` / `elif` / `else` | ✅ | | +| `x if cond else y` | ✅ | Ternary expression | +| `break` / `continue` | ✅ | | +| `enumerate()` / `zip()` | ❌ | | + +## Data Access + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `a[i]` indexing | ✅ | Multi-dim indexing supported: `a[i, j, k]` | +| `a[i:j]` slicing | ✅ | Creates `BufferRegion` | +| `a[-1]` negative index | ✅ | | + +## Assignment & Arithmetic Operations + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `x = expr` | ✅ | | +| `+`, `-`, `*`, `/`, `%` | ✅ | Maps to device-side arithmetic operations | +| `+=`, `-=`, `*=`, etc. | ✅ | Augmented assignment | +| `a = b = c` | ❌ | Use separate assignments | + +## Functions & Classes + +As a kernel script language, TileLang doesn't support defining functions or classes inside kernels. You can use `@T.macro` to define reusable code blocks, which are inlined at compile time like a `__device__` function.
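+ +For example, a minimal sketch (the macro name and its arguments are illustrative, not a fixed API): + +```python +import tilelang.language as T + +@T.macro +def scale_rows(buf, factor, rows, cols): +    # Expanded inline at each call site, like a __device__ function. +    for i, j in T.Parallel(rows, cols): +        buf[i, j] = buf[i, j] * factor +```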
+ +## Statements & Built-in Functions + +| Python Feature | Supported | Notes / Alternative | +|-------------------------|:---------:|------------------------------------------| +| `with` | ⚠️ | Only `T.Kernel`, `T.ws` | +| `assert` | ⚠️ | Use `T.device_assert` or `T.assert` | +| `print()` | ⚠️ | Use `T.print()`; `print` works for Python expressions | +| `len()` | ❌ | Use `buffer.shape[dim]` | +| `type()`, `isinstance()`| ❌ | | diff --git a/docs/runtime_internals/stubs.md b/docs/runtime_internals/stubs.md new file mode 100644 index 0000000000..ee4c628e79 --- /dev/null +++ b/docs/runtime_internals/stubs.md @@ -0,0 +1,47 @@ +# CUDA and ROCm Stub Libraries + +This document describes TileLang's stub mechanism for GPU driver/runtime +libraries (CUDA and ROCm/HIP). + +## Purpose + +CUDA: + +1. **CUDA Driver (`cuda_stub`)**: Allows TileLang to be imported on systems + without a GPU (e.g., CI/compilation nodes) by lazy-loading `libcuda.so` only + when needed. +2. **CUDA Runtime & Compiler (`cudart_stub`, `nvrtc_stub`)**: Resolves SONAME + versioning mismatches (e.g. `libcudart.so.11` vs `libcudart.so.12`), + enabling a single build to work across different CUDA versions. This is + achieved by reusing CUDA libraries already loaded by frameworks like PyTorch + when possible. + +ROCm: + +1. **HIP Runtime/Module API (`hip_stub`)**: Allows TileLang to be imported on + systems without ROCm installed by lazy-loading `libamdhip64.so` only when + needed. The stub also prefers already-loaded symbols via `RTLD_DEFAULT` / + `RTLD_NEXT` to interoperate with frameworks that have already loaded HIP. +2. **HIP Runtime Compiler (`hiprtc_stub`)**: Lazily loads `libhiprtc.so` and + exposes the minimal HIPRTC API subset used by TileLang/TVM. + +## Implementation + +The stubs in `src/target/stubs/` implement a lazy-loading mechanism: + +- **Lazy Loading**: Libraries are loaded via `dlopen` only upon the first API call. +- **Global Symbol Reuse**: For `cudart` and `nvrtc`, the stubs first check the global namespace (`RTLD_DEFAULT`) to use any already loaded symbols (e.g., from PyTorch). +- **ROCm Notes**: `hip_stub` checks `RTLD_DEFAULT` / `RTLD_NEXT` first and then + falls back to `dlopen("libamdhip64.so")`. It additionally provides wrappers + for `hsa_init` / `hsa_shut_down` so that ROCm-enabled wheels do not record a + hard dependency on `libhsa-runtime64` at import time. +- **Versioning Support**: Handles ABI differences between CUDA versions (e.g., `cudaGraphInstantiate` changes in CUDA 12). + +## Build Option + +- `TILELANG_USE_CUDA_STUBS` (Default: `ON`) controls CUDA stubs. When enabled, + TileLang links against these stubs instead of the system CUDA toolkit + libraries. +- `TILELANG_USE_HIP_STUBS` (Default: `ON`) controls ROCm stubs. When enabled + (and `USE_ROCM=ON`), TileLang/TVM link against `hip_stub` / `hiprtc_stub` + instead of the system ROCm libraries. 
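+ +For example, to link directly against the host toolkit libraries instead of the stubs (useful when debugging linkage issues), disable the option at configure time: + +```bash +cmake -S . -B build -DTILELANG_USE_CUDA_STUBS=OFF +```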
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 6fd4334594..2d418f0142 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -1,6 +1,8 @@ cancelled +dout HDA hsa +inouts ist LOD nd diff --git a/examples/amd/example_amd_flash_attn_bwd.py b/examples/amd/example_amd_flash_attn_bwd.py index 788aec367c..0a9b7d26cb 100644 --- a/examples/amd/example_amd_flash_attn_bwd.py +++ b/examples/amd/example_amd_flash_attn_bwd.py @@ -1,3 +1,4 @@ +import sys import torch import torch.nn.functional as F import tilelang @@ -10,6 +11,15 @@ import time +def IsRDNA(): + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name().strip() + return "Radeon" in gpu_name + else: + print("Error: GPU Device is not detected") + sys.exit(1) + + def ref_program(Q, K, V, is_causal, groups=1): assert Q.size(2) == K.size(2) * groups, f"Q heads {Q.size(2)} K heads {K.size(2)} groups {groups}" assert Q.size(2) == V.size(2) * groups, f"Q heads {Q.size(2)} V heads {V.size(2)} groups {groups}" @@ -30,13 +40,24 @@ def ref_program(Q, K, V, is_causal, groups=1): def get_fwd_configs(): - block_M = [32, 64, 128, 256] - block_N = [32, 64, 128, 256] - threads = [128, 256, 512] - num_split_q = [64, 128, 256] - num_stages = [0, 1] + # Match the standalone forward example on RDNA. WMMA configs larger than + # 32x32 can trigger layout issues when bridging the softmax fragment into + # the second GEMM's A-layout. + if IsRDNA(): + block_M = [16, 32, 64] + block_N = [16, 32, 64] + threads = [32, 64] + num_split_q = [16, 32, 64] + num_stages = [0] + k_pack = [1] + else: + block_M = [32, 64, 128, 256] + block_N = [32, 64, 128, 256] + threads = [128, 256, 512] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + k_pack = [2] enable_rasterization = [True] - k_pack = [2] panel_size = [7, 8, 9, 10] qk_coalesced_width = [8] v_coalesced_width = [4] @@ -46,6 +67,8 @@ def get_fwd_configs(): for m, n, s, t, stages, r, k, p, qkw, vw in itertools.product( block_M, block_N, num_split_q, threads, num_stages, enable_rasterization, k_pack, panel_size, qk_coalesced_width, v_coalesced_width ): + if IsRDNA() and m == 16 and n == 16 and t == 64: + continue valid_configs.append( { "block_M": m, @@ -112,7 +135,7 @@ def main( bx_loop_var = T.alloc_var(T.int32) bx_loop_var = b_split - with T.While(bx_loop_var < num_q_blocks): + while bx_loop_var < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) @@ -127,6 +150,10 @@ def main( Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) + # Bridge the WMMA D-layout softmax fragment into the A-layout + # expected by GEMM 2 on RDNA GPUs. 
+ if IsRDNA(): + P_shared = T.alloc_shared([block_M, block_N], dtype) acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -188,7 +215,12 @@ def main( for i in T.Parallel(block_M): l_i[i] += row_sum[i] - T.copy(acc_s, acc_s_cast) + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + P_shared[i, j] = T.cast(acc_s[i, j], dtype) + T.copy(P_shared, acc_s_cast) + else: + T.copy(acc_s, acc_s_cast) T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) @@ -211,15 +243,27 @@ def main( def get_bwd_configs(): - block_M = [16, 32, 64, 128, 256] - block_N = [16, 32, 64, 128, 256] - threads = [64, 128, 256, 512, 1024] - num_stages = [0, 1, 2] + # Keep the RDNA search space aligned with the WMMA-friendly tile sizes + # verified above. Larger tiles and some warp/block combinations are either + # unsupported or known to trigger invalid lowering on RDNA. + if IsRDNA(): + block_M = [16, 32] + block_N = [16, 32] + threads = [32, 64] + num_stages = [0] + panel_size = [7, 8] + else: + block_M = [16, 32, 64, 128, 256] + block_N = [16, 32, 64, 128, 256] + threads = [64, 128, 256, 512, 1024] + num_stages = [0, 1, 2] + panel_size = [7, 8, 9, 10] enable_rasterization = [True] - panel_size = [7, 8, 9, 10] configs = [] for m, n, stages, t, r, p in itertools.product(block_M, block_N, num_stages, threads, enable_rasterization, panel_size): + if IsRDNA() and m == 16 and n == 16 and t == 64: + continue configs.append( { "block_M": m, @@ -305,6 +349,10 @@ def flash_bwd_kernel( lse_shared = T.alloc_shared([block_N], accum_dtype) delta_shared = T.alloc_shared([block_N], accum_dtype) ds_shared = T.alloc_shared([block_M, block_N], dtype) + if IsRDNA(): + # Bridge the WMMA D-layout fragment produced by GEMM/elementwise + # ops into the A-layout expected by the following GEMM. 
+ p_shared = T.alloc_shared([block_M, block_N], dtype) p_cast = T.alloc_fragment([block_M, block_N], dtype) qkT = T.alloc_fragment([block_M, block_N], accum_dtype) @@ -343,17 +391,29 @@ def flash_bwd_kernel( T.gemm(V_shared, do_shared, dP, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(P_acc, p_cast) + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + p_shared[i, j] = T.cast(P_acc[i, j], dtype) + T.copy(p_shared, p_cast) + else: + T.copy(P_acc, p_cast) T.gemm(p_cast, do_shared, dv, policy=T.GemmWarpPolicy.FullRow) T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta_shared) - for i, j in T.Parallel(block_M, block_N): - p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale - - T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) - + if IsRDNA(): + for i, j in T.Parallel(block_M, block_N): + dP[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + for i, j in T.Parallel(block_M, block_N): + p_shared[i, j] = T.cast(dP[i, j], dtype) + T.copy(p_shared, p_cast) + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) + else: + for i, j in T.Parallel(block_M, block_N): + p_cast[i, j] = P_acc[i, j] * (dP[i, j] - delta_shared[j]) * sm_scale + T.gemm(p_cast, q_shared, dk, policy=T.GemmWarpPolicy.FullRow) T.copy(p_cast, ds_shared) + T.clear(dq) T.gemm(ds_shared, K_shared, dq, transpose_A=True) for i, j in T.Parallel(block_N, dim): diff --git a/examples/amd/example_amd_flash_attn_fwd.py b/examples/amd/example_amd_flash_attn_fwd.py index ca9c361ff1..cb7d1225d2 100644 --- a/examples/amd/example_amd_flash_attn_fwd.py +++ b/examples/amd/example_amd_flash_attn_fwd.py @@ -1,3 +1,4 @@ +import sys import torch import torch.nn.functional as F import tilelang @@ -8,7 +9,15 @@ from functools import partial -# Custom supply function to ensure tensors are created on GPU +def IsRDNA(): + if torch.cuda.is_available(): + gpu_name = torch.cuda.get_device_name().strip() + return "Radeon" in gpu_name + else: + print("Error: GPU Device is not detected") + sys.exit(1) + + def supply_tensors_gpu(params): """Supply function that creates tensors on GPU for ROCm/HIP.""" tensors = [] @@ -16,7 +25,9 @@ def supply_tensors_gpu(params): if hasattr(param, "shape") and hasattr(param, "dtype"): # Force creation on GPU device shape = [int(s) for s in param.shape] - tensor = torch.randn(shape, dtype=param.dtype, device="cuda") + # Convert TileLang dtype to PyTorch dtype + torch_dtype = param.dtype.as_torch() + tensor = torch.randn(shape, dtype=torch_dtype, device="cuda") tensors.append(tensor) else: tensors.append(param) @@ -42,14 +53,31 @@ def ref_program(Q, K, V, is_causal, groups=1): def get_configs(): - """Generates configurations for the autotuner, tailored for FA-2 style parallelism.""" - block_M = [32, 64, 128, 256] - block_N = [32, 64, 128, 256] - threads = [128, 256, 512] - num_split_q = [64, 128, 256] - num_stages = [0, 1] + """Generates configurations for the autotuner. + + For RDNA (gfx11xx/gfx12xx) GPUs using WMMA instructions, block sizes + are limited to 32x32 due to a layout mismatch between WMMA D output and A input + registers when block_M > 16 * num_warps_per_32_threads. Larger blocks cause + incorrect results in the shared memory transpose used to convert softmax scores + to the GEMM 2 A-matrix layout. + """ + if IsRDNA(): + block_M = [16, 32] + block_N = [16, 32] + threads = [32, 64] + num_split_q = [16, 32, 64] + num_stages = [0] + # k_pack=2 is broken for RDNA WMMA (incorrect K-dimension loading for multi-k_pack). 
+ # Use k_pack=1 only until fixed. + k_pack = [1] + else: + block_M = [64, 128, 256] + block_N = [64, 128, 256] + threads = [128, 256] + num_split_q = [64, 128, 256] + num_stages = [0, 1] + k_pack = [2] enable_rasterization = [True] - k_pack = [2] panel_size = [7, 8] qk_coalesced_width = [8] v_coalesced_width = [4] @@ -124,7 +152,7 @@ def main( bx = T.alloc_var(T.int32) bx = b_split - with T.While(bx < num_q_blocks): + while bx < num_q_blocks: acc_o = T.alloc_fragment([block_M, dim], accum_dtype) m_i = T.alloc_fragment([block_M], accum_dtype) l_i = T.alloc_fragment([block_M], accum_dtype) @@ -138,6 +166,13 @@ def main( Q_shared = T.alloc_shared([block_M, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) + # P_shared is used to bridge the WMMA D-layout (acc_s output) to + # A-layout (acc_s_cast input for GEMM 2). On RDNA GPUs with WMMA, + # D and A have different register layouts, so a direct fragment-to- + # fragment copy would cause a layout conflict. Routing through shared + # memory correctly transposes the softmax values. + if IsRDNA(): + P_shared = T.alloc_shared([block_M, block_N], dtype) # Use register fragment for P instead of shared memory to reduce LDS usage acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) @@ -184,6 +219,7 @@ def main( for i, j in T.Parallel(block_M, dim): acc_o[i, j] *= scale_factor[i] + # Compute softmax values for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.exp(acc_s[i, j] * scale - m_i[i] * scale) @@ -191,8 +227,20 @@ def main( for i in T.Parallel(block_M): l_i[i] += row_sum[i] - # Cast acc_s (accum_dtype) to dtype in registers and directly GEMM with V - T.copy(acc_s, acc_s_cast) + if IsRDNA(): + # Cast softmax values from f32 (acc_s, D-layout) to f16 (acc_s_cast, A-layout). + # On RDNA with WMMA, D and A have different register layouts. + # Route through shared memory (P_shared) to correctly bridge them: + # 1) T.Parallel writes acc_s values to P_shared at D-layout coordinates. + # 2) T.copy reads P_shared into acc_s_cast at A-layout coordinates. + # This shared-memory transpose is only correct when block_M / threads + # gives at most 2 warps (block_M=32 with 64 threads, or block_M=16 with 32 threads). 
+                for i, j in T.Parallel(block_M, block_N): +                    P_shared[i, j] = T.cast(acc_s[i, j], dtype) +                T.copy(P_shared, acc_s_cast) +            else: +                # This avoids layout conflict between acc_s and acc_s_cast +                T.copy(acc_s, acc_s_cast) T.gemm(acc_s_cast, V_shared, acc_o, policy=GemmWarpPolicy.FullRow) diff --git a/examples/attention_sink/benchmark_gqa_sink_fwd.py b/examples/attention_sink/benchmark_gqa_sink_fwd.py index 211ef1d18c..e0cb5480f3 100644 --- a/examples/attention_sink/benchmark_gqa_sink_fwd.py +++ b/examples/attention_sink/benchmark_gqa_sink_fwd.py @@ -5,7 +5,7 @@ import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor -from example_gqa_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from example_gqa_sink_fwd_bhsd import flashattn, ref_program, gen_inputs from typing import Optional diff --git a/examples/attention_sink/benchmark_mha_sink_fwd.py b/examples/attention_sink/benchmark_mha_sink_fwd.py index 50747e6b09..c1e25d9be6 100644 --- a/examples/attention_sink/benchmark_mha_sink_fwd.py +++ b/examples/attention_sink/benchmark_mha_sink_fwd.py @@ -5,7 +5,7 @@ import triton import triton.language as tl from triton.tools.tensor_descriptor import TensorDescriptor -from example_mha_sink_fwd_bhsd_wgmma_pipelined import flashattn, ref_program, gen_inputs +from example_mha_sink_fwd_bhsd import flashattn, ref_program, gen_inputs from typing import Optional diff --git a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py index cfdcd21b58..97a504a0d1 100644 --- a/examples/attention_sink/example_gqa_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_gqa_sink_bwd_bhsd.py @@ -542,7 +542,7 @@ def run_kernel_only(): parser.add_argument("--n_ctx", type=int, default=4096, help="Context size") parser.add_argument("--d_head", type=int, default=128, help="Head dimension") parser.add_argument("--groups", type=int, default=8, help="Groups") - parser.add_argument("--window_size", type=int, default=None, help="window size (default: None, which means full attention)") + parser.add_argument("--window_size", type=int, default=128, help="window size (default: 128)") parser.add_argument("--dtype", type=str, default="float16", help="dtype, can be float16 or bfloat16") args = parser.parse_args() main(args.batch, args.h, args.n_ctx, args.d_head, args.groups, args.window_size, args.dtype) diff --git a/examples/attention_sink/example_gqa_sink_bwd_varlen.py b/examples/attention_sink/example_gqa_sink_bwd_varlen.py new file mode 100644 index 0000000000..64a5a39a86 --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_bwd_varlen.py @@ -0,0 +1,798 @@ +import torch +import tilelang +from tilelang.profiler import do_bench +import tilelang.language as T +import argparse +from typing import Optional +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "../flash_attention")) +from varlen_utils import generate_random_padding_mask, generate_qkv + + +def get_bwd_configs(): +    sm_major, sm_minor = torch.cuda.get_device_capability() +    sm_version = sm_major * 10 + sm_minor +    if sm_version == 80: +        return 64, 32, 1, 128 +    else: +        return 128, 32, 2, 256 + + +@tilelang.jit( +    out_idx=[6, 7], +    pass_configs={ +        tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, +    }, +) +def flashattn_fwd( +    batch_size, +    groups, +    UQ, +    UKV, +    N_CTX, +    heads, +    max_seq_len, +    dim, +    is_causal, +    window_size=None,  # None for full causal attention +    sm_scale=None, +    block_M=64, +
block_N=64, + num_stages=1, + threads=128, + dtype=T.float16, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + Sinks: T.Tensor([heads], dtype), + Output_unpad: T.Tensor(o_shape, dtype), + lse: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + ): + with T.Kernel(T.ceildiv(max_seq_len, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[head_idx] + + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M + + # Determine loop range based on causal mask and sliding window + if is_causal: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + start = 0 + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.ceildiv(kv_current_seqlen, block_N) + else: + start = 0 + end = T.ceildiv(kv_current_seqlen, block_N) + + loop_range = end - start + + for k in T.Pipelined(loop_range, num_stages=num_stages): + actual_k = k + start + T.copy(K_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], K_shared) + + # Build mask considering causal, sliding window, and padding + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx < k_idx) + or (q_idx >= k_idx + window_size) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + 
-T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < actual_k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx >= k_idx + window_size) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(V_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Handle case where scores_max is -inf (query sees no keys due to causal mask or sliding window) + # This can happen when q_len > k_len (offset < 0) in causal attention, or with sliding window + for i in T.Parallel(block_M): + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + + T.copy(acc_s, acc_s_cast) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + # Attention sink: add sink contribution to logsum + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = acc_o[i, d] + + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + for i in T.Parallel(block_M): + if bx * block_M + i < q_current_seqlen: + lse[bz, head_idx, bx * block_M + i] = logsum[i] + + return main + + +@tilelang.jit( + out_idx=[3], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_preprocess(batch_size, heads, UQ, N_CTX, max_seq_len, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [UQ, heads, dim] + blk = 32 + + @T.prim_func + def flash_bwd_prep( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + Delta: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, blk), batch_size) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + + 
q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + for i, j in T.Parallel(blk, blk): + if by * blk + i < q_current_seqlen and k * blk + j < dim: + o[i, j] = O[q_start_idx + by * blk + i, bx, k * blk + j] + do[i, j] = dO[q_start_idx + by * blk + i, bx, k * blk + j] + else: + o[i, j] = 0.0 + do[i, j] = 0.0 + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + + for i in T.Parallel(blk): + if by * blk + i < q_current_seqlen: + Delta[bz, bx, by * blk + i] = delta[i] + + return flash_bwd_prep + + +def make_dq_layout(dQ): + # Reorder dq for atomic add: [seq, head, dim] -> permuted layout + return T.Layout(dQ.shape, lambda l, h, d: [h, l, d]) + + +@tilelang.jit( + out_idx=[1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_bwd_postprocess(UQ, heads, dim, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [UQ, heads, dim] + blk = 64 + + @T.prim_func + def flash_bwd_post( + dQ: T.Tensor(shape, accum_dtype), + dQ_out: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(UQ, blk), heads, threads=128) as (bx, by): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bx * blk : (bx + 1) * blk, by, :], + dQ_out[bx * blk : (bx + 1) * blk, by, :], + ) + + return flash_bwd_post + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + } +) +def flashattn_bwd( + batch_size, + groups, + UQ, + UKV, + N_CTX, + heads, + max_seq_len, + dim, + is_causal, + window_size=None, + sm_scale=None, + dtype=T.float16, +): + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + accum_dtype = T.float32 + + block_M, block_N, num_stages, threads = get_bwd_configs() + + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + @T.prim_func + def flash_bwd( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + Delta: T.Tensor([batch_size, heads, N_CTX], accum_dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block_M), batch_size, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + + q_start_idx = 
cu_seqlens_q[bz] + kv_start_idx = cu_seqlens_k[bz] + q_end_idx = cu_seqlens_q[bz + 1] + k_end_idx = cu_seqlens_k[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.annotate_layout( + { + dQ: make_dq_layout(dQ), + } + ) + T.copy(K[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + + # For varlen causal attention, we need to account for offset between q and kv lengths + # In forward: Q at pos q can see KV at pos k if q + offset >= k (where offset = kv_len - q_len) + # In backward: KV at pos kv_pos is seen by Q at pos q_pos if kv_pos <= q_pos + offset + offset = kv_current_seqlen - q_current_seqlen + + # loop_st: first Q block that can see this KV block + # kv_pos <= q_pos + offset => by * block_M <= k * block_N + offset + # => k >= (by * block_M - offset) / block_N + loop_st = T.max(0, T.floordiv(by * block_M - offset, block_N)) if is_causal else 0 + loop_ed = ( + T.min(T.ceildiv((by + 1) * block_M - offset + window_size, block_N), T.ceildiv(q_current_seqlen, block_N)) + if window_size is not None + else T.ceildiv(q_current_seqlen, block_N) + ) + + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[q_start_idx + k * block_N : q_start_idx + (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + # Causal: kv_pos <= q_pos + offset + # Sliding window: kv_pos > q_pos + offset - window_size + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k * block_N + j + offset) + and (by * block_M + i > k * block_N + j + offset - window_size) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + # Causal: kv_pos <= q_pos + offset + qkT[i, j] = T.if_then_else( + (by * block_M + i <= k * block_N + j + offset) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + (by * block_M + i > k * block_N + j + offset - window_size) + and (by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen), + qkT[i, j], + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else( + by * block_M + i < kv_current_seqlen and k * block_N + j < q_current_seqlen, + qkT[i, j], + 0, + ) + + T.copy(dO[q_start_idx + k * block_N : q_start_idx + (k + 1) * block_N, bx, :], dst=do) + T.clear(dsT) + T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.gemm(dsT_shared, K_shared, dq, transpose_A=True) + T.atomic_add(dQ[q_start_idx + k * block_N : q_start_idx + (k + 1) * 
block_N, bx, :], dq) + + T.copy(dv, dv_shared) + T.atomic_add(dV[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dK[kv_start_idx + by * block_M : kv_start_idx + (by + 1) * block_M, bx // groups, :], dk_shared) + + return flash_bwd + + +@tilelang.jit(out_idx=-1) +def flashattn_bwd_dsink(batch_size, heads, N_CTX, max_seq_len, block=256, dtype: T.dtype = T.float16): + accum_dtype = T.float32 + shape = [batch_size, heads, N_CTX] + + @T.prim_func + def flash_bwd_dsink( + Sinks: T.Tensor([heads], dtype), + Delta: T.Tensor(shape, accum_dtype), + lse: T.Tensor(shape, accum_dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + dsinks: T.Tensor(shape, dtype), + ): + with T.Kernel(heads, T.ceildiv(max_seq_len, block), batch_size, threads=256) as (bx, by, bz): + lse_fragment = T.alloc_fragment([block], accum_dtype) + delta_fragment = T.alloc_fragment([block], accum_dtype) + dsink_fragment = T.alloc_fragment([block], dtype) + + # Get actual sequence length for this batch item + q_start_idx = cu_seqlens_q[bz] + q_end_idx = cu_seqlens_q[bz + 1] + q_current_seqlen = q_end_idx - q_start_idx + + sink = Sinks[bx] + T.copy(lse[bz, bx, by * block : (by + 1) * block], lse_fragment) + T.copy(Delta[bz, bx, by * block : (by + 1) * block], delta_fragment) + for i in T.Parallel(block): + # Only compute for valid positions, set 0 for positions beyond sequence length + dsink_fragment[i] = T.if_then_else( + by * block + i < q_current_seqlen, + -T.exp2(sink * 1.44269504 - lse_fragment[i]) * delta_fragment[i], + 0, + ) + T.copy(dsink_fragment, dsinks[bz, bx, by * block : (by + 1) * block]) + + return flash_bwd_dsink + + +class _attention(torch.autograd.Function): + @staticmethod + def forward( + ctx, q_unpad, k_unpad, v_unpad, sinks, cu_seqlens_q, cu_seqlens_k, N_CTX, max_seqlen_q, max_seqlen_k, window_size, groups, is_causal + ): + def maybe_contiguous(x): + if x.stride(-1) != 1: + return x.contiguous() + return x + + q_unpad, k_unpad, v_unpad, sinks = [maybe_contiguous(x) for x in (q_unpad, k_unpad, v_unpad, sinks)] + UQ, H, D_HEAD = q_unpad.shape + UKV = k_unpad.shape[0] + batch_size = cu_seqlens_q.shape[0] - 1 + dtype = T.float16 if q_unpad.dtype == torch.float16 else T.bfloat16 + + kernel = flashattn_fwd( + batch_size, + groups, + UQ, + UKV, + N_CTX, + H, + max_seqlen_q, + D_HEAD, + is_causal, + window_size=window_size, + block_M=64, + block_N=64, + num_stages=1, + threads=128, + dtype=dtype, + ) + o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, sinks) + + ctx.save_for_backward(q_unpad, k_unpad, v_unpad, sinks, o_unpad, lse, cu_seqlens_q, cu_seqlens_k) + ctx.window_size = window_size + ctx.groups = groups + ctx.is_causal = is_causal + ctx.N_CTX = N_CTX + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.batch_size = batch_size + return o_unpad + + @staticmethod + def backward(ctx, do): + q_unpad, k_unpad, v_unpad, sinks, o_unpad, lse, cu_seqlens_q, cu_seqlens_k = ctx.saved_tensors + UQ, H, D_HEAD = q_unpad.shape + UKV = k_unpad.shape[0] + groups = ctx.groups + batch_size = ctx.batch_size + dtype = T.float16 if q_unpad.dtype == torch.float16 else T.bfloat16 + + kernel_prep = flashattn_bwd_preprocess(batch_size, H, UQ, ctx.N_CTX, ctx.max_seqlen_q, D_HEAD, dtype=dtype) + kernel_post = flashattn_bwd_postprocess(UQ, H, D_HEAD, dtype=dtype) + delta = kernel_prep(o_unpad, do, cu_seqlens_q) + + kernel = flashattn_bwd( + batch_size, + groups, + UQ, + UKV, + ctx.N_CTX, + H, + 
ctx.max_seqlen_q, + D_HEAD, + ctx.is_causal, + window_size=ctx.window_size, + dtype=dtype, + ) + + head_kv = H // groups + dq = torch.zeros_like(q_unpad, dtype=torch.float32) + dk = torch.zeros([UKV, head_kv, D_HEAD], dtype=torch.float32, device=q_unpad.device) + dv = torch.zeros([UKV, head_kv, D_HEAD], dtype=torch.float32, device=q_unpad.device) + + kernel(q_unpad, k_unpad, v_unpad, do, lse, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) + dq = kernel_post(dq) + dk = dk.to(q_unpad.dtype) + dv = dv.to(q_unpad.dtype) + + kernel_dsink = flashattn_bwd_dsink(batch_size, H, ctx.N_CTX, ctx.max_seqlen_q, dtype=dtype) + dsinks = kernel_dsink(sinks, delta, lse, cu_seqlens_q).sum(0).sum(1) + + return dq, dk, dv, dsinks, None, None, None, None, None, None, None, None + + +attention = _attention.apply + + +def ref_program( + q_unpad: torch.Tensor, + k_unpad: torch.Tensor, + v_unpad: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sinks: torch.Tensor, + batch_size: int, + is_causal: bool, + sliding_window: Optional[int] = None, + groups: int = 1, +) -> torch.Tensor: + """Reference implementation for varlen attention with sinks.""" + total_q, num_heads, head_dim = q_unpad.shape + _, num_key_value_heads, _ = k_unpad.shape + + sm_scale = 1.0 / head_dim**0.5 + + output = torch.zeros_like(q_unpad) + + for b in range(batch_size): + q_start = cu_seqlens_q[b].item() + q_end = cu_seqlens_q[b + 1].item() + k_start = cu_seqlens_k[b].item() + k_end = cu_seqlens_k[b + 1].item() + + q_len = q_end - q_start + k_len = k_end - k_start + + if q_len == 0: + continue + + q_seq = q_unpad[q_start:q_end] # [q_len, heads, dim] + k_seq = k_unpad[k_start:k_end] # [k_len, head_kv, dim] + v_seq = v_unpad[k_start:k_end] # [k_len, head_kv, dim] + + # Reshape for GQA + q_seq = q_seq.view(q_len, num_key_value_heads, groups, head_dim) + sinks_expanded = sinks.view(num_key_value_heads, groups, 1, 1).float() + + k_seq = k_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + v_seq = v_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + + logits = torch.einsum("qhgd,khgd->hgqk", q_seq.float(), k_seq.float()) * sm_scale + + start_q = k_len - q_len + pos_keys = torch.arange(k_len, device=q_unpad.device) + pos_queries = torch.arange(q_len, device=q_unpad.device) + start_q + + if is_causal: + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + else: + mask = torch.zeros(q_len, k_len, device=q_unpad.device) + + if sliding_window is not None: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = logits + mask[None, None, :, :] + + logits_max = torch.max(logits, dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks_expanded, logits_max) + sinks_exp = torch.exp(sinks_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks_exp + scores = unnormalized_scores / normalizer + + out = torch.einsum("hgqk,khgd->qhgd", scores, v_seq.float()) + out = out.reshape(q_len, num_heads, head_dim).to(q_unpad.dtype) + + output[q_start:q_end] = out + + return output + + +def main( + batch: int = 1, + heads: int = 64, + q_seqlen: int = 2048, + k_seqlen: int = 2048, + dim: int = 128, + groups: int = 16, + is_causal: bool = True, + window_size: Optional[int] = None, +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 
2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 5 * flops_per_matmul # fwd + bwd + + if is_causal: + total_flops *= 0.5 + + if window_size is not None: + print(f"Using sliding window attention with window_size={window_size}") + flops_per_matmul = 2.0 * batch * heads * min(window_size, k_seqlen // 2) * q_seqlen * dim + total_flops = 5 * flops_per_matmul + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + sinks = torch.randn(heads, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + q_unpad = q_unpad.requires_grad_(True) + k_unpad = k_unpad.requires_grad_(True) + v_unpad = v_unpad.requires_grad_(True) + sinks = sinks.requires_grad_(True) + + dO_unpad = torch.randn_like(q_unpad) + + # TileLang forward + backward + # N_CTX is the padded sequence length used for tensor allocation + N_CTX = q_seqlen + O_unpad = attention( + q_unpad, k_unpad, v_unpad, sinks, cu_seqlens_q, cu_seqlens_k, N_CTX, max_seqlen_q, max_seqlen_k, window_size, groups, is_causal + ) + O_unpad.backward(dO_unpad, retain_graph=True) + dQ, q_unpad.grad = q_unpad.grad.clone(), None + dK, k_unpad.grad = k_unpad.grad.clone(), None + dV, v_unpad.grad = v_unpad.grad.clone(), None + dsinks, sinks.grad = sinks.grad.clone(), None + + # Reference forward + backward + O_ref_unpad = ref_program( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sinks, + batch, + is_causal, + sliding_window=window_size, + groups=groups, + ) + O_ref_unpad.backward(dO_unpad, retain_graph=True) + dQ_ref, q_unpad.grad = q_unpad.grad.clone(), None + dK_ref, k_unpad.grad = k_unpad.grad.clone(), None + dV_ref, v_unpad.grad = v_unpad.grad.clone(), None + dsinks_ref, sinks.grad = sinks.grad.clone(), None + + # Checks + # Sliding window attention has slightly higher numerical error due to more complex masking + rtol, atol = (2e-2, 2e-2) if window_size is not None else (1e-2, 1e-2) + assert torch.allclose(O_unpad, O_ref_unpad, rtol=rtol, atol=atol), f"O max err: {(O_unpad - O_ref_unpad).abs().max()}" + assert torch.allclose(dV, dV_ref, rtol=rtol, atol=atol), f"dV max err: {(dV - dV_ref).abs().max()}" + assert torch.allclose(dK, dK_ref, rtol=rtol, atol=atol), f"dK max err: {(dK - dK_ref).abs().max()}" + assert torch.allclose(dQ, dQ_ref, rtol=rtol, atol=atol), f"dQ max err: {(dQ - dQ_ref).abs().max()}" + assert torch.allclose(dsinks, dsinks_ref, rtol=rtol, atol=atol), f"dsinks max err: {(dsinks - dsinks_ref).abs().max()}" + + print("All checks passed for tilelang kernels.✅") + + # Benchmark backward + def torch_bwd(): + O_ref_unpad.backward(dO_unpad, retain_graph=True) + + def tl_bwd(): + O_unpad.backward(dO_unpad, retain_graph=True) + + latency = do_bench(torch_bwd, warmup=500) + print("torch: {:.2f} ms".format(latency)) + print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + latency = do_bench(tl_bwd, warmup=500) + print("tilelang: {:.2f} 
ms".format(latency)) + print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + parser.add_argument("--is_causal", action="store_true", help="causal attention") + parser.add_argument("--window_size", type=int, default=None, help="sliding window size (default: None for full attention)") + args = parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal, args.window_size) diff --git a/examples/attention_sink/example_gqa_sink_fwd_varlen.py b/examples/attention_sink/example_gqa_sink_fwd_varlen.py new file mode 100644 index 0000000000..16838dd860 --- /dev/null +++ b/examples/attention_sink/example_gqa_sink_fwd_varlen.py @@ -0,0 +1,401 @@ +# ruff: noqa +# Using varlen (variable length) format with attention sink + +import argparse +import torch +import tilelang +import tilelang.language as T +import tilelang.testing +from tilelang.profiler import do_bench +from typing import Optional +import sys +import os + +sys.path.append(os.path.join(os.path.dirname(__file__), "../flash_attention")) +from varlen_utils import generate_random_padding_mask, generate_qkv + + +@tilelang.jit( + out_idx=[7], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn_sink( + batch_size, + groups, + UQ, + UKV, + heads, + dim, + is_causal, + window_size=None, # None for full causal attention + sm_scale=None, + block_M=64, + block_N=64, + num_stages=1, + threads=128, +): + if window_size is not None: + assert window_size % block_N == 0, "window_size must be divisible by block_N" + + if sm_scale is None: + sm_scale = (1.0 / dim) ** 0.5 + scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // groups + q_shape = [UQ, heads, dim] + kv_shape = [UKV, head_kv, dim] + o_shape = [UQ, heads, dim] + dtype = T.float16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q_unpad: T.Tensor(q_shape, dtype), + K_unpad: T.Tensor(kv_shape, dtype), + V_unpad: T.Tensor(kv_shape, dtype), + cu_seqlens_q: T.Tensor([batch_size + 1], T.int32), + cu_seqlens_k: T.Tensor([batch_size + 1], T.int32), + max_seqlen_q: T.int32, + Sinks: T.Tensor([heads], dtype), + Output_unpad: T.Tensor(o_shape, dtype), + ): + with T.Kernel(T.ceildiv(max_seqlen_q, block_M), heads, batch_size, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + sinks = T.alloc_fragment([block_M], 
dtype) + + batch_idx = bz + head_idx = by + kv_head_idx = head_idx // groups + + q_start_idx = cu_seqlens_q[batch_idx] + kv_start_idx = cu_seqlens_k[batch_idx] + q_end_idx = cu_seqlens_q[batch_idx + 1] + k_end_idx = cu_seqlens_k[batch_idx + 1] + + q_current_seqlen = q_end_idx - q_start_idx + kv_current_seqlen = k_end_idx - kv_start_idx + + T.copy(Q_unpad[q_start_idx + bx * block_M : q_start_idx + (bx + 1) * block_M, head_idx, :], Q_shared) + + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + for i in T.Parallel(block_M): + sinks[i] = Sinks[head_idx] + + offset = kv_current_seqlen - q_current_seqlen # always align on the right + max_visible_k_idx = offset + (bx + 1) * block_M + + # Determine loop range based on causal mask and sliding window + if is_causal: + if window_size is not None: + # Sliding window + causal: start from window boundary + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + # Full causal attention + start = 0 + end = T.min(T.ceildiv(max_visible_k_idx, block_N), T.ceildiv(kv_current_seqlen, block_N)) + else: + if window_size is not None: + start = T.max(0, (offset + bx * block_M - window_size + 1) // block_N) + end = T.ceildiv(kv_current_seqlen, block_N) + else: + start = 0 + end = T.ceildiv(kv_current_seqlen, block_N) + + loop_range = end - start + + for k in T.Pipelined(loop_range, num_stages=num_stages): + actual_k = k + start + T.copy(K_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], K_shared) + + # Build mask considering causal, sliding window, and padding + if is_causal: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + # Causal + sliding window mask + acc_s[i, j] = T.if_then_else( + (q_idx < k_idx) # causal: can't see future + or (q_idx >= k_idx + window_size) # sliding window: too old + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i + offset < actual_k * block_N + j) + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + if window_size is not None: + for i, j in T.Parallel(block_M, block_N): + q_idx = bx * block_M + i + offset + k_idx = actual_k * block_N + j + acc_s[i, j] = T.if_then_else( + (q_idx >= k_idx + window_size) # sliding window: too old + or (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else( + (bx * block_M + i >= q_current_seqlen or actual_k * block_N + j >= kv_current_seqlen), + -T.infinity(acc_s.dtype), + 0, + ) + + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + + # Check_inf for sliding window attention + if window_size is not None: + for i in T.Parallel(block_M): + scores_max[i] = T.if_then_else(scores_max[i] == -T.infinity(accum_dtype), 0, scores_max[i]) + + for 
i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + + T.copy(V_unpad[kv_start_idx + actual_k * block_N : kv_start_idx + (actual_k + 1) * block_N, kv_head_idx, :], V_shared) + + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Attention sink: add sink contribution to logsum + for i in T.Parallel(block_M): + logsum[i] += T.exp2(sinks[i] * 1.44269504 - scores_max[i] * scale) + + for i, j in T.Parallel(block_M, dim): + # When sq > skv, some tokens can see nothing (for causal) + acc_o[i, j] = 0 if is_causal and bx * block_M + i + offset < 0 else acc_o[i, j] / logsum[i] + + T.copy(acc_o, O_shared) + for i, d in T.Parallel(block_M, dim): + if bx * block_M + i < q_current_seqlen: + Output_unpad[q_start_idx + bx * block_M + i, head_idx, d] = O_shared[i, d] + + return main + + +def ref_program( + q_unpad: torch.Tensor, + k_unpad: torch.Tensor, + v_unpad: torch.Tensor, + cu_seqlens_q: torch.Tensor, + cu_seqlens_k: torch.Tensor, + max_seqlen_q: int, + max_seqlen_k: int, + sinks: torch.Tensor, + batch_size: int, + is_causal: bool, + sliding_window: Optional[int] = None, + groups: int = 1, +) -> torch.Tensor: + """Reference implementation for varlen attention with sinks.""" + # q_unpad: [total_q, heads, dim] + # k_unpad: [total_kv, head_kv, dim] + # v_unpad: [total_kv, head_kv, dim] + total_q, num_heads, head_dim = q_unpad.shape + _, num_key_value_heads, _ = k_unpad.shape + + sm_scale = 1.0 / head_dim**0.5 + + output = torch.zeros_like(q_unpad) + + for b in range(batch_size): + q_start = cu_seqlens_q[b].item() + q_end = cu_seqlens_q[b + 1].item() + k_start = cu_seqlens_k[b].item() + k_end = cu_seqlens_k[b + 1].item() + + q_len = q_end - q_start + k_len = k_end - k_start + + if q_len == 0: + continue + + # Extract sequences for this batch + q_seq = q_unpad[q_start:q_end] # [q_len, heads, dim] + k_seq = k_unpad[k_start:k_end] # [k_len, head_kv, dim] + v_seq = v_unpad[k_start:k_end] # [k_len, head_kv, dim] + + # Reshape for GQA + q_seq = q_seq.view(q_len, num_key_value_heads, groups, head_dim) # [q_len, head_kv, groups, dim] + sinks_expanded = sinks.view(num_key_value_heads, groups, 1, 1).float() # [head_kv, groups, 1, 1] + + k_seq = k_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + v_seq = v_seq.unsqueeze(2) # [k_len, head_kv, 1, dim] + + # Compute attention + # q_seq: [q_len, head_kv, groups, dim], k_seq: [k_len, head_kv, 1, dim] + logits = torch.einsum("qhgd,khgd->hgqk", q_seq.float(), k_seq.float()) * sm_scale + + # Build mask + start_q = k_len - q_len # offset for causal alignment + pos_keys = torch.arange(k_len, device=q_unpad.device) + pos_queries = torch.arange(q_len, device=q_unpad.device) + start_q + + if is_causal: + mask = pos_keys[None, :] > pos_queries[:, None] + mask = mask.float().masked_fill(mask, float("-inf")) + else: + mask = torch.zeros(q_len, k_len, device=q_unpad.device) + + if sliding_window is not None: + too_old = pos_keys[None, :] < (pos_queries[:, None] - sliding_window + 1) + mask.masked_fill_(too_old, float("-inf")) + + logits = logits + mask[None, None, :, :] # [head_kv, groups, q_len, k_len] + + # Apply sink-adjusted softmax + logits_max = torch.max(logits, 
dim=-1, keepdim=True).values + logits_or_sinks_max = torch.maximum(sinks_expanded, logits_max) + sinks_exp = torch.exp(sinks_expanded - logits_or_sinks_max) + unnormalized_scores = torch.exp(logits - logits_or_sinks_max) + normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks_exp + scores = unnormalized_scores / normalizer + + # Compute output + out = torch.einsum("hgqk,khgd->qhgd", scores, v_seq.float()) + out = out.reshape(q_len, num_heads, head_dim).to(q_unpad.dtype) + + output[q_start:q_end] = out + + return output + + +def main( + batch: int = 1, + heads: int = 64, + q_seqlen: int = 2048, + k_seqlen: int = 2048, + dim: int = 128, + groups: int = 16, + is_causal: bool = True, + window_size: Optional[int] = None, +): + assert heads % groups == 0, "heads must be divisible by groups" + + flops_per_matmul = 2.0 * batch * heads * q_seqlen * k_seqlen * dim + total_flops = 2 * flops_per_matmul + + tilelang.testing.set_random_seed(0) + + if is_causal: + total_flops *= 0.5 + + if window_size is not None: + print(f"Using sliding window attention with window_size={window_size}") + flops_per_matmul = 2.0 * batch * heads * min(window_size, k_seqlen // 2) * q_seqlen * dim + total_flops = 2 * flops_per_matmul + + dtype = torch.float16 + device = torch.device("cuda") + + head_kv = heads // groups + q = torch.randn(batch, q_seqlen, heads, dim, dtype=dtype, device=device) + k = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + v = torch.randn(batch, k_seqlen, head_kv, dim, dtype=dtype, device=device) + sinks = torch.randn(heads, dtype=dtype, device=device) + + query_padding_mask = generate_random_padding_mask(q_seqlen, batch, device, mode="random") + key_padding_mask = generate_random_padding_mask(k_seqlen, batch, device, mode="random") + + ( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + q, + k, + v, + output_pad_fn, + _, + _, + ) = generate_qkv(q, k, v, query_padding_mask, key_padding_mask, kvpacked=False) + + UQ = q_unpad.shape[0] + UKV = k_unpad.shape[0] + + kernel = flashattn_sink( + batch, groups, UQ, UKV, heads, dim, is_causal, window_size=window_size, block_M=128, block_N=128, num_stages=2, threads=256 + ) + + out_unpad = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks) + out = output_pad_fn(out_unpad) + + # Reference implementation + ref_out_unpad = ref_program( + q_unpad, + k_unpad, + v_unpad, + cu_seqlens_q, + cu_seqlens_k, + max_seqlen_q, + max_seqlen_k, + sinks, + batch, + is_causal, + sliding_window=window_size, + groups=groups, + ) + ref_out = output_pad_fn(ref_out_unpad) + + torch.testing.assert_close(out, ref_out, rtol=1e-2, atol=1e-2) + + print("All checks passed.✅") + latency = do_bench( + lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, sinks), + warmup=500, + ) + print("Tile-lang: {:.2f} ms".format(latency)) + print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=8, help="batch size") + parser.add_argument("--heads", type=int, default=64, help="query heads") + parser.add_argument("--groups", type=int, default=16, help="groups") + parser.add_argument("--q_seqlen", type=int, default=2048, help="query sequence length") + parser.add_argument("--k_seqlen", type=int, default=2048, help="key/value sequence length") + parser.add_argument("--dim", type=int, default=128, help="head dim") + 
parser.add_argument("--is_causal", action="store_true", help="causal attention") + parser.add_argument("--window_size", type=int, default=None, help="sliding window size (default: None for full attention)") + args = parser.parse_args() + main(args.batch, args.heads, args.q_seqlen, args.k_seqlen, args.dim, args.groups, args.is_causal, args.window_size) diff --git a/examples/attention_sink/example_mha_sink_bwd_bhsd.py b/examples/attention_sink/example_mha_sink_bwd_bhsd.py index 66905f55d1..fa045a1d78 100644 --- a/examples/attention_sink/example_mha_sink_bwd_bhsd.py +++ b/examples/attention_sink/example_mha_sink_bwd_bhsd.py @@ -13,7 +13,7 @@ def get_bwd_configs(): sm_version = sm_major * 10 + sm_minor if sm_version == 80: return 64, 32, 1, 128 - elif sm_version == 90: + elif sm_version >= 90: return 128, 32, 2, 256 else: raise ValueError(f"Unsupported SM version: {sm_version}") diff --git a/examples/attention_sink/regression_attention_sink.py b/examples/attention_sink/regression_attention_sink.py index e2453173cf..f1116befcb 100644 --- a/examples/attention_sink/regression_attention_sink.py +++ b/examples/attention_sink/regression_attention_sink.py @@ -1,9 +1,7 @@ import tilelang.testing import example_mha_sink_fwd_bhsd -import example_mha_sink_fwd_bhsd_wgmma_pipelined import example_mha_sink_bwd_bhsd import example_gqa_sink_bwd_bhsd -import example_gqa_sink_fwd_bhsd_wgmma_pipelined def regression_example_mha_sink_fwd_bhsd(): @@ -16,30 +14,6 @@ def regression_example_mha_sink_fwd_bhsd_sliding_window(): ) -def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined(): - tilelang.testing.process_func(example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) - - -def regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): - tilelang.testing.process_func( - example_mha_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, - "regression_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window", - window_size=128, - ) - - -def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined(): - tilelang.testing.process_func(example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf) - - -def regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window(): - tilelang.testing.process_func( - example_gqa_sink_fwd_bhsd_wgmma_pipelined.run_regression_perf, - "regression_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window", - window_size=128, - ) - - def regression_example_mha_sink_bwd_bhsd(): tilelang.testing.process_func(example_mha_sink_bwd_bhsd.run_regression_perf) diff --git a/examples/attention_sink/test_example_attention_sink.py b/examples/attention_sink/test_example_attention_sink.py index 57242c199c..41682c6c7c 100644 --- a/examples/attention_sink/test_example_attention_sink.py +++ b/examples/attention_sink/test_example_attention_sink.py @@ -1,10 +1,10 @@ import tilelang.testing import example_mha_sink_fwd_bhsd -import example_mha_sink_fwd_bhsd_wgmma_pipelined -import example_gqa_sink_fwd_bhsd_wgmma_pipelined import example_mha_sink_bwd_bhsd import example_gqa_sink_bwd_bhsd +import example_gqa_sink_fwd_varlen +import example_gqa_sink_bwd_varlen @tilelang.testing.requires_cuda @@ -17,30 +17,6 @@ def test_example_mha_sink_fwd_bhsd_sliding_window(): example_mha_sink_fwd_bhsd.main(window_size=128) -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_full_attn(): - example_mha_sink_fwd_bhsd_wgmma_pipelined.main() - - -@tilelang.testing.requires_cuda 
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_mha_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
-    example_mha_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
-
-
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_full_attn():
-    example_gqa_sink_fwd_bhsd_wgmma_pipelined.main()
-
-
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_gqa_sink_fwd_bhsd_wgmma_pipelined_sliding_window():
-    example_gqa_sink_fwd_bhsd_wgmma_pipelined.main(window_size=128)
-
-
 @tilelang.testing.requires_cuda
 def test_example_mha_sink_bwd_bhsd():
     example_mha_sink_bwd_bhsd.main()
@@ -61,5 +37,12 @@ def test_example_gqa_sink_bwd_bhsd_sliding_window():
     example_gqa_sink_bwd_bhsd.main(window_size=128)
 
 
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_gqa_sink_varlen():
+    example_gqa_sink_fwd_varlen.main()  # non-causal
+    example_gqa_sink_bwd_varlen.main()  # causal
+
+
 if __name__ == "__main__":
     tilelang.testing.main()
diff --git a/examples/autodd/README.md b/examples/autodd/README.md
new file mode 100644
index 0000000000..9ae9f98167
--- /dev/null
+++ b/examples/autodd/README.md
@@ -0,0 +1,126 @@
+# AutoDD - Automatic Delta Debugging for TileLang
+
+AutoDD (Automatic Delta Debugging) is a built-in debugging tool for TileLang that automatically simplifies a complex Python program down to the minimal code needed to reproduce a specific error. This is extremely useful when debugging large, complex TileLang programs.
+
+## What is Delta Debugging?
+
+Delta debugging is an automated debugging technique whose core idea is:
+1. Start from a program that triggers a bug
+2. Systematically remove code fragments from the program
+3. Check whether the simplified program still triggers the same bug
+4. Repeat until only the minimal bug-triggering code remains
+
+AutoDD uses a Probability Distribution Driven Delta Debugging (PDD) algorithm to search efficiently for a minimized program.
+
+## Why AutoDD?
+
+When developing TileLang programs, bugs are often hidden in complex code:
+
+- **Lots of irrelevant code**: Real projects may have hundreds of lines of configuration, helper functions, logging, etc.
+- **Hard to locate**: Error messages may point at the underlying TVM/CUDA layers rather than at TileLang code
+- **Tedious debugging**: Manually deleting code to locate a bug is very time-consuming
+
+AutoDD automates this process, reducing hundreds of lines of code to a few dozen and directly exposing the root cause of the problem.
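+
+To make the loop above concrete, here is a minimal, self-contained sketch of greedy line-level delta debugging. This is only an illustration of the general technique; AutoDD itself works on the AST and drives the search with its probability model rather than this naive loop:
+
+```python
+# Illustrative greedy delta debugging over source lines (not AutoDD's code).
+# A deletion is kept whenever the target error message still appears.
+import subprocess
+import sys
+import tempfile
+
+
+def still_fails(lines, err_msg, timeout=60):
+    """Run the candidate program and check that the error message survives."""
+    with tempfile.NamedTemporaryFile("w", suffix=".py", delete=False) as f:
+        f.write("\n".join(lines))
+        path = f.name
+    try:
+        proc = subprocess.run([sys.executable, path], capture_output=True, text=True, timeout=timeout)
+    except subprocess.TimeoutExpired:
+        return False
+    return err_msg in proc.stdout or err_msg in proc.stderr
+
+
+def minimize(lines, err_msg):
+    """Try ever-smaller deletions until no chunk can be removed."""
+    chunk = max(1, len(lines) // 2)
+    while chunk >= 1:
+        i = 0
+        while i < len(lines):
+            candidate = lines[:i] + lines[i + chunk:]
+            if candidate and still_fails(candidate, err_msg):
+                lines = candidate  # the deletion kept the bug: accept it
+            else:
+                i += chunk  # the deletion lost the bug: keep these lines
+        chunk //= 2
+    return lines
+```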
+
+## Usage
+
+### Basic Usage
+
+```bash
+python -m tilelang.autodd <input.py> --err-msg "<error message>" -o <minimized.py>
+```
+
+### Parameters
+
+| Parameter | Description |
+|-----------|-------------|
+| `source` | Path to the input Python source file |
+| `--err-msg` | Error message to match (searched in stdout or stderr) |
+| `-o, --output` | Path to the minimized output file |
+| `--backend` | Execution backend: `runner` (faster) or `subproc` (more stable), default `runner` |
+| `--timeout` | Timeout for each task in seconds, default 60 |
+| `-j, --jobs` | Number of parallel jobs, default 1 |
+
+### Example
+
+Run AutoDD on `tilelang_buggy.py` in this directory:
+
+```bash
+# Use 4 parallel jobs, search for "Dimension mismatch" error
+python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4
+
+# Or use the subprocess backend (more stable but slower)
+python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py --backend subproc
+```
+
+## Example Files
+
+### `tilelang_buggy.py`
+
+A complex TileLang program with a bug (~200 lines), containing:
+- Multiple useless helper functions (`calculate_optimal_block_size`, `get_memory_requirements`, etc.)
+- A complex configuration class (`MatmulConfig`)
+- Unused benchmark code (`benchmark_pytorch`)
+- **A GEMM shape mismatch bug**
+
+The bug is on line 124:
+```python
+B_shared = T.alloc_shared((block_M, block_N), dtype)  # Wrong! Should be (block_K, block_N)
+```
+
+### `tilelang_minimized_expected.py`
+
+The expected output after AutoDD simplification (~30 lines). The simplified code clearly shows the root cause of the bug:
+
+```python
+def buggy_matmul(...):
+    @T.prim_func
+    def matmul_kernel():
+        with T.Kernel():
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_M, block_N), dtype)  # Bug!
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+            T.gemm(A_shared, B_shared, C_local)  # Error occurs here
+```
+
+## How AutoDD Works
+
+AutoDD uses AST (Abstract Syntax Tree) analysis and multiple rewrite rules to simplify code; a small illustrative sketch follows the four rule groups below:
+
+### 1. Fast Reducers
+- **Statement removal**: Directly remove statements that don't affect bug reproduction
+- **If statement simplification**: Simplify `if cond: body` to `body`
+- **For loop simplification**: Bind loop variables to constants
+
+### 2. Canonicalizers
+- **With statement expansion**: Convert `with expr as var` to explicit assignment
+- **Function argument extension**: Add `*args, **kwargs` for compatibility
+
+### 3. Simplifiers
+- **Assignment simplification**: Replace complex expressions with constants
+- **Function call simplification**: Simplify `f(x)` to `x`
+- **Binary operation simplification**: Simplify `a + b` to `a` or `b`
+
+### 4. Slow Reducers
+- **Expression removal**: Remove arbitrary expressions
+- **Argument removal**: Remove function arguments
+- **Integer reduction**: Gradually reduce large integers
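+
+As a concrete illustration of the rule groups above, the sketch below performs a single AST-level "statement removal" step. It is illustrative only (AutoDD's actual reducers are more elaborate), but the overall shape is the same: parse, mutate the tree, unparse, then re-run the candidate and check for the error:
+
+```python
+# Hypothetical single "statement removal" step (not AutoDD's implementation).
+import ast
+
+
+def drop_top_level_stmt(source: str, i: int) -> str:
+    """Return `source` with its i-th top-level statement removed."""
+    tree = ast.parse(source)
+    if 0 <= i < len(tree.body):
+        del tree.body[i]
+    return ast.unparse(tree)  # requires Python 3.9+
+
+
+src = "x = 1\nprint('kept')\nraise RuntimeError('bug')\n"
+# Dropping statement 0 removes "x = 1" but keeps the failing line, so a
+# checker that re-runs this candidate would accept the deletion.
+print(drop_top_level_stmt(src, 0))
+```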
+
+## Use Cases
+
+1. **TileLang kernel debugging**: Simplify complex TileLang programs to locate bugs
+2. **Bug report submission**: Generate minimal reproduction code for easier issue tracking
+3. **Understanding errors**: Easier to understand the nature of errors after removing irrelevant code
+4. **Regression testing**: Simplified code can serve as regression test cases
+
+## Notes
+
+1. **Error message matching**: The `--err-msg` string must exactly match a substring of the program's stdout or stderr
+2. **Timeout setting**: For programs with long compilation times, you may need to increase `--timeout`
+3. **Parallel jobs**: Increasing `-j` can speed up the simplification process but consumes more resources
+4. **Backend selection**: If the `runner` backend is unstable, try the `subproc` backend
+
+## References
+
+- [Delta Debugging Paper](https://www.st.cs.uni-saarland.de/papers/tse2002/)
+- [TileLang Documentation](https://github.com/tile-ai/tilelang)
diff --git a/examples/autodd/tilelang_buggy.py b/examples/autodd/tilelang_buggy.py
new file mode 100644
index 0000000000..d2c5469bbe
--- /dev/null
+++ b/examples/autodd/tilelang_buggy.py
@@ -0,0 +1,229 @@
+"""
+A complex TileLang program with lots of redundant code and a bug that triggers an error.
+AutoDD will simplify it to the minimal code needed to reproduce the error.
+
+This example demonstrates how AutoDD can help developers quickly isolate bugs
+in complex TileLang programs by automatically removing irrelevant code.
+
+To run AutoDD on this file:
+    python -m tilelang.autodd tilelang_buggy.py --err-msg "Dimension mismatch" -o minimized.py -j 4
+
+The bug in this file: B_shared has shape (block_M, block_N) instead of (block_K, block_N),
+causing a GEMM dimension mismatch error.
+"""
+
+import tilelang
+import tilelang.language as T
+import torch
+
+
+# Useless helper function - will be removed by AutoDD
+def calculate_optimal_block_size(M, N, K):
+    """Calculate optimal block size - this function is completely useless"""
+    options = [32, 64, 128, 256]
+    best = 128
+    for opt in options:
+        if M % opt == 0 and N % opt == 0:
+            best = opt
+            break
+    return best, best, 32
+
+
+def get_memory_requirements(M, N, K, block_M, block_N, block_K, dtype_size=2):
+    """Calculate memory requirements - completely useless"""
+    shared_mem_a = block_M * block_K * dtype_size
+    shared_mem_b = block_K * block_N * dtype_size
+    total_shared = shared_mem_a + shared_mem_b
+    return total_shared
+
+
+def validate_parameters(M, N, K, block_M, block_N, block_K):
+    """Validate parameters - redundant check"""
+    if M <= 0 or N <= 0 or K <= 0:
+        raise ValueError("Matrix dimensions must be positive")
+    if block_M <= 0 or block_N <= 0 or block_K <= 0:
+        raise ValueError("Block sizes must be positive")
+    if M % block_M != 0:
+        print(f"Warning: M ({M}) not divisible by block_M ({block_M})")
+    if N % block_N != 0:
+        print(f"Warning: N ({N}) not divisible by block_N ({block_N})")
+    if K % block_K != 0:
+        print(f"Warning: K ({K}) not divisible by block_K ({block_K})")
+    return True
+
+
+class MatmulConfig:
+    """Configuration class - increases code complexity but is actually useless"""
+
+    def __init__(self, M, N, K):
+        self.M = M
+        self.N = N
+        self.K = K
+        self.block_M = 128
+        self.block_N = 128
+        self.block_K = 32
+        self.num_stages = 3
+        self.threads = 128
+        self.dtype = "float16"
+        self.accum_dtype = "float32"
+
+    def get_grid_size(self):
+        grid_x = (self.N + self.block_N - 1) // self.block_N
+        grid_y = (self.M + self.block_M - 1) // self.block_M
+        return grid_x, grid_y
+
+    def get_shared_memory_size(self):
+        return get_memory_requirements(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
+
+    def validate(self):
+        return validate_parameters(self.M, self.N, self.K, self.block_M, self.block_N, self.block_K)
+
+
+def create_reference_output(a, b, activation="relu"):
+    """Create reference output - not actually used in verification"""
+    result = a @ b
+    if activation == "relu":
+        result = torch.relu(result)
+    elif activation == "gelu":
+        result = torch.nn.functional.gelu(result)
+    elif activation == "sigmoid":
+        result = torch.sigmoid(result)
+    return result
+
+
+def benchmark_pytorch(M, N, K, num_iters=10, warmup=5): + """PyTorch benchmark - not used""" + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + + # Warmup + for _ in range(warmup): + _ = a @ b + torch.cuda.synchronize() + + # Benchmark + import time + + start = time.time() + for _ in range(num_iters): + _ = a @ b + torch.cuda.synchronize() + end = time.time() + + return (end - start) / num_iters * 1000 # ms + + +# Main TileLang kernel - contains a BUG: GEMM shape mismatch! +@tilelang.jit +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32): + @T.prim_func + def matmul_kernel( + A: T.Tensor((M, K), dtype), + B: T.Tensor((K, N), dtype), + C: T.Tensor((M, N), dtype), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=128) as (bx, by): + # Allocate shared memory + A_shared = T.alloc_shared((block_M, block_K), dtype) + # BUG: the first dimension of B_shared should be block_K, but block_M is used here! + B_shared = T.alloc_shared((block_M, block_N), dtype) # Wrong shape! + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Allocate some useless temp variables + temp_buffer = T.alloc_fragment((block_M, block_N), accum_dtype) + + # Zero out + T.clear(C_local) + T.clear(temp_buffer) + + # Main loop + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + # Copy a tile of A + T.copy(A[by * block_M, ko * block_K], A_shared) + + # Copy a tile of B - shape can mismatch here too + T.copy(B[ko * block_K, bx * block_N], B_shared) + + # GEMM computation - shape mismatch will cause an error + # A_shared: (block_M, block_K) + # B_shared: (block_M, block_N) <- should be (block_K, block_N) + T.gemm(A_shared, B_shared, C_local) + + # ReLU activation + for i, j in T.Parallel(block_M, block_N): + C_local[i, j] = T.max(C_local[i, j], 0) + + # Some useless postprocessing + for i, j in T.Parallel(block_M, block_N): + if temp_buffer[i, j] > 0: + C_local[i, j] = C_local[i, j] + 0.0 + + # Write back result + T.copy(C_local, C[by * block_M, bx * block_N]) + + return matmul_kernel + + +def run_kernel(config): + """Run kernel - includes extra redundant logic""" + # Validate parameters + config.validate() + + # Get config + M, N, K = config.M, config.N, config.K + block_M, block_N, block_K = config.block_M, config.block_N, config.block_K + + # Calculate some useless statistics + grid_size = config.get_grid_size() + shared_mem = config.get_shared_memory_size() + print(f"Grid size: {grid_size}") + print(f"Shared memory: {shared_mem} bytes") + + # Create test data + a = torch.randn(M, K, device="cuda", dtype=torch.float16) + b = torch.randn(K, N, device="cuda", dtype=torch.float16) + c = torch.empty(M, N, device="cuda", dtype=torch.float16) + + # Compile and run kernel - will trigger the BUG here + kernel = buggy_matmul(M, N, K, block_M, block_N, block_K) + kernel(a, b, c) + + # Validate results (if it can get here) + ref_c = torch.relu(a @ b) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("Kernel output matches PyTorch reference.") + + return c + + +def main(): + # Useless printing + print("=" * 60) + print("TileLang Matmul Kernel Test") + print("=" * 60) + + # Create config + M, N, K = 512, 512, 512 + config = MatmulConfig(M, N, K) + + # Calculate some useless values + optimal_block = calculate_optimal_block_size(M, N, K) + print(f"Optimal block size: {optimal_block}") + + # Run PyTorch benchmark - result is not used + # pytorch_time = 
benchmark_pytorch(M, N, K) + # print(f"PyTorch time: {pytorch_time:.3f} ms") + + # Run our kernel - will trigger the error here + try: + result = run_kernel(config) + print(f"Result shape: {result.shape}") + except Exception as e: + print(f"Error: {e}") + raise + + print("Done!") + + +if __name__ == "__main__": + main() diff --git a/examples/autodd/tilelang_minimized_expected.py b/examples/autodd/tilelang_minimized_expected.py new file mode 100644 index 0000000000..3dc88f9921 --- /dev/null +++ b/examples/autodd/tilelang_minimized_expected.py @@ -0,0 +1,49 @@ +""" +This is the expected output after running AutoDD on tilelang_buggy.py. +AutoDD automatically simplified the 200+ line buggy program to ~30 lines +while preserving the ability to reproduce the error. + +The minimized code clearly shows the root cause of the bug: +- A_shared has shape (block_M, block_K) +- B_shared has shape (block_M, block_N) - should be (block_K, block_N) +- This causes a dimension mismatch in T.gemm() +""" + +import tilelang.language as T + + +class MatmulConfig: + def __init__(self, *args, **kwargs): + self.M = 1 + self.N = 1 + self.K = 1 + self.block_M = 2 + self.block_N = 1 + self.block_K = 1 + + +def buggy_matmul(M, N, K, block_M, block_N, block_K, dtype=T.float16, accum_dtype=T.float32, *args, **kwargs): + @T.prim_func + def matmul_kernel(): + with T.Kernel(): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_M, block_N), dtype) # Bug: should be (block_K, block_N) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + T.gemm(A_shared, B_shared, C_local) + + +def run_kernel(config, *args, **kwargs): + M, N, K = (config.M, config.N, config.K) + block_M, block_N, block_K = (config.block_M, config.block_N, config.block_K) + buggy_matmul(M, N, K, block_M, block_N, block_K) + + +def main(*args, **kwargs): + config = MatmulConfig() + try: + run_kernel(config) + except Exception as e: + print(f"{e}") + + +main() diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py index 7b8b7b95cd..ab8346619d 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_decode.py @@ -139,8 +139,8 @@ def kernel( T.call_extern( "handle", "decode_i2u_to_i8s", - T.address_of(B_quant_local[0]), - T.address_of(B_dequantize_local[0]), + T.access_ptr(B_quant_local, "r"), + T.access_ptr(B_dequantize_local, "w"), ) if use_dp4a: @@ -155,7 +155,7 @@ def kernel( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): diff --git a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py index f4a60098a5..9d7ebcf88c 100644 --- a/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py +++ b/examples/bitnet-1.58b/kernel_benchmark/tilelang_bitnet_158_int8xint2_prefill.py @@ -253,8 +253,8 @@ def main( T.call_extern( "handle", "decode_i2u_to_i8s", - T.address_of(B_local[0]), - T.address_of(B_dequantize_local[0]), + T.access_ptr(B_local, "r"), + T.access_ptr(B_dequantize_local, "w"), ) for v in T.vectorized(0, local_size): diff --git 
a/examples/bitnet-1.58b/requirements.txt b/examples/bitnet-1.58b/requirements.txt
index 67357781e0..7660c28c6d 100644
--- a/examples/bitnet-1.58b/requirements.txt
+++ b/examples/bitnet-1.58b/requirements.txt
@@ -1,3 +1,3 @@
 lm_eval==0.3.0
 flash_attn
-transformers==4.53.0
+transformers==5.0.0rc3
diff --git a/examples/bitnet-1.58b/tokenization_bitnet.py b/examples/bitnet-1.58b/tokenization_bitnet.py
index 2adfd6dee1..8db57a9c09 100644
--- a/examples/bitnet-1.58b/tokenization_bitnet.py
+++ b/examples/bitnet-1.58b/tokenization_bitnet.py
@@ -38,10 +38,10 @@
 PRETRAINED_VOCAB_FILES_MAP = {
     "vocab_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model",
+        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer.model"
     },
     "tokenizer_file": {
-        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json",
+        "hf-internal-testing/llama-tokenizer": "https://huggingface.co/hf-internal-testing/llama-tokenizer/resolve/main/tokenizer_config.json"
     },
 }
 PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES = {
diff --git a/examples/blockscaled_gemm_sm100/figures/blockscaled_data_path.svg b/examples/blockscaled_gemm_sm100/figures/blockscaled_data_path.svg
new file mode 100644
index 0000000000..f08de2b779
--- /dev/null
+++ b/examples/blockscaled_gemm_sm100/figures/blockscaled_data_path.svg
@@ -0,0 +1,138 @@
[SVG markup elided. Title: "SM100 block-scaled GEMM data path". Description: global memory A, B, SFA, and SFB move through TMA, shared memory, scale-factor transpose, UTCCP copy, tensor memory MMA, and epilogue store. Panels: GLOBAL MEMORY, SHARED MEMORY PIPELINE STAGE, TENSOR MEMORY AND MMA; barrier ring: consumed -> loaded -> with_sf_full -> consumed; final handoff via tmem_full (persistent path also uses tmem_empty).]
diff --git a/examples/blockscaled_gemm_sm100/figures/blockscaled_sf_layout.svg b/examples/blockscaled_gemm_sm100/figures/blockscaled_sf_layout.svg
new file mode 100644
index 0000000000..94d8ef0ca5
--- /dev/null
+++ b/examples/blockscaled_gemm_sm100/figures/blockscaled_sf_layout.svg
@@ -0,0 +1,163 @@
[SVG markup elided. Title: "Scale-factor packing, SMEM transpose, and TMEM layout". Description: UE8M0 scale factors are packed four per uint32, loaded to shared memory, transposed in 128-word chunks, copied to tensor memory columns, and selected by sf_id. Panels: 1. Global packed SF (word = sf0 | sf1 << 8 | sf2 << 16 | sf3 << 24, k % 4 selects the byte, scale = 2^(ue8m0 - 127)); 2. SMEM transpose for UTCCP (4 x 32 -> 32 x 4 per 128-word chunk); 3. TMEM SF columns (tcgen05.cp.32x128b.warpx4, sf_id selects the byte sub-column for each K iteration).]
diff --git a/examples/blockscaled_gemm_sm100/figures/blockscaled_variants.svg b/examples/blockscaled_gemm_sm100/figures/blockscaled_variants.svg
new file mode 100644
index 0000000000..f570c22666
--- /dev/null
+++ b/examples/blockscaled_gemm_sm100/figures/blockscaled_variants.svg
@@ -0,0 +1,97 @@
[SVG markup elided. Title: "Block-scaled GEMM kernel variants". Description: comparison of the 1CTA (mxfp8_blockscaled_gemm), 2CTA (mxfp8_blockscaled_gemm_2cta), and persistent 2CTA (mxfp8_blockscaled_gemm_2cta_persistent) scheduling designs; all run the same block-scaled MMA data path and differ only in CTA grouping and scheduling.]
diff --git a/examples/blockscaled_gemm_sm100/figures/blockscaled_warp_specialization.svg b/examples/blockscaled_gemm_sm100/figures/blockscaled_warp_specialization.svg
new file mode 100644
index 0000000000..2ae2e1b43e
--- /dev/null
+++ b/examples/blockscaled_gemm_sm100/figures/blockscaled_warp_specialization.svg
@@ -0,0 +1,103 @@
[SVG markup elided. Title: "Warp-specialized pipeline and barriers". Description: timeline for the TMA producer (warp 0), SF transposer (warp 2), leader-CTA MMA (warp 1), and epilogue roles; 2CTA uses cluster barriers for with_sf_full, consumed, and tmem_empty.]
diff --git a/examples/blockscaled_gemm_sm100/figures/sfa.png b/examples/blockscaled_gemm_sm100/figures/sfa.png
new file mode 100644
index 0000000000..545a648693
Binary files /dev/null and b/examples/blockscaled_gemm_sm100/figures/sfa.png differ
diff --git a/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py b/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py
new file mode 100644
index 0000000000..a8acb6dc97
--- /dev/null
+++ b/examples/blockscaled_gemm_sm100/gemm_mxfp8_blockscaled_1d1d.py
@@ -0,0 +1,762 @@
+# MXFP8 Block-Scaled GEMM on SM100
+# Blockscale size: (M, N, K) = (1, 1, 128)
+
+import argparse
+import torch
+import tilelang
+import tilelang.language as T
+from tilelang.carver.arch import driver
+from tilelang.profiler import do_bench
+
+
+@tilelang.jit
+def mxfp8_blockscaled_gemm(
+    A,
+    B,
+    SFA,
+    SFB,
+    block_M,
+    block_N,
+    block_K,
+    in_dtype,
+    out_dtype,
+    accum_dtype,
+    num_stages,
+    sf_granularity_k=128,
+    transpose_B=False,
+):
+    """1D-1D Block-scaled MXFP8 GEMM.
+
+    A: [M, K] in FP8 (E4M3 or E5M2)
+    B: [K, N] in FP8 (E4M3 or E5M2), or [N, K] when transpose_B=True
+    SFA: [ceil((K / sf_granularity_k) / 4) * M] in uint32
+        Group-major packed E8M0 scale factors for A.
+    SFB: [ceil((K / sf_granularity_k) / 4) * N] in uint32
+        Group-major packed E8M0 scale factors for B.
+    """
+    M, N, K = T.const("M, N, K")
+
+    k_iters = T.ceildiv(K, block_K)
+    # Load 4 K-blocks of SF at once → load every 4 iterations
+    sf_load_period = sf_granularity_k * 4 // block_K
+    sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4)
+
+    A: T.Tensor[[M, K], in_dtype]
+    B: T.Tensor[[N, K] if transpose_B else [K, N], in_dtype]
+    SFA: T.Tensor[[sf_k_groups * M], T.uint32]
+    SFB: T.Tensor[[sf_k_groups * N], T.uint32]
+    C = T.empty((M, N), out_dtype)
+
+    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):
+        # Data shared memory (pipelined)
+        A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype)
+        B_shared = T.alloc_shared(
+            (num_stages, block_N, block_K) if transpose_B else (num_stages, block_K, block_N),
+            in_dtype,
+        )
+
+        # Scale factor shared memory — one uint32 per row/column, packing 4 K-blocks.
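+        # With block_K = sf_granularity_k = 128, each packed word
+        # sf0 | sf1 << 8 | sf2 << 16 | sf3 << 24 covers four K iterations,
+        # so the SF TMA only runs once every sf_load_period stages.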
+ SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + # Accumulator in tensor memory + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + + # Scale factors in tensor memory (TMEM has 128 rows / 32-bit cells) + SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], "uint32") + + # Output buffers + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + # Barriers + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_barrier([32] * num_stages) + consumed = T.alloc_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + T.use_swizzle(8) + + if tx < 32: + # Warp 0: TMA load + for k in T.serial(k_iters): + T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + if transpose_B: + T.tma_copy( + B[by * block_N : (by + 1) * block_N, k * block_K : (k + 1) * block_K], + B_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + else: + T.tma_copy( + B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], + B_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + # Load one packed uint32 SF word every sf_load_period iterations. + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[k % num_stages, :], + barrier=loaded[k % num_stages], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[k % num_stages, :], + barrier=loaded[k % num_stages], + ) + T.mbarrier_arrive(loaded[k % num_stages]) + + elif tx < 64: + # Warp 1: MMA issue + UTCCP + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + T.mbarrier_wait_parity(with_sf_full[stage], phase) + + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem) + + # sf_id selects which of the 4 packed E8M0 values to use + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=transpose_B, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + ) + + T.tcgen05_mma_arrive(tmem_full) + + elif tx < 96: + # Warp 2: scale-factor transpose + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage]) + + # Epilogue: all warps + T.mbarrier_wait_parity(tmem_full, 0) + T.sync_threads() + + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return C + + +@tilelang.jit +def mxfp8_blockscaled_gemm_2cta( + A, + B, + SFA, + SFB, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k=128, + transpose_B=False, +): + M, N, K = T.const("M, N, K") + + assert block_M == 128 
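+    # The 2CTA variant fixes its tile shape: a CTA pair computes one logical
+    # 128 x 256 tile, with each CTA loading a half_N = 128 panel of B.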
+ assert block_N == 256 + assert block_K == 128 + assert sf_granularity_k == 128 + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + assert sf_load_period == 4 + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[N, K] if transpose_B else [K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M], T.uint32] + SFB: T.Tensor[[sf_k_groups * N], T.uint32] + C = T.empty((M, N), out_dtype) + + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (bx, by): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared( + (num_stages, half_N, block_K) if transpose_B else (num_stages, block_K, half_N), + in_dtype, + ) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, 8], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + T.use_swizzle(16) + + if warp_idx == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], phase ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + if transpose_B: + T.tma_copy( + B[ + (by * block_N + cta_id * half_N) : (by * block_N + (cta_id + 1) * half_N), + k * block_K : (k + 1) * block_K, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + else: + T.tma_copy( + B[ + k * block_K : (k + 1) * block_K, + (by * block_N + cta_id * half_N) : (by * block_N + (cta_id + 1) * half_N), + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], phase) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=transpose_B, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + if k % sf_load_period == 0: + 
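+                    # Rewrite each 128-word SF chunk from a 4 x 32 to a 32 x 4
+                    # word view, matching the tcgen05.cp.32x128b.warpx4 source
+                    # pattern (see mxfp8_illustrated.md).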
T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + T.mbarrier_wait_parity(tmem_full, 0) + T.copy(C_tmem, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + + return C + + +@tilelang.jit +def mxfp8_blockscaled_gemm_2cta_persistent( + A, + B, + SFA, + SFB, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k=128, + transpose_B=False, + use_tma_store=True, + store_block_N=64, +): + M, N, K = T.const("M, N, K") + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[N, K] if transpose_B else [K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M], T.uint32] + SFB: T.Tensor[[sf_k_groups * N], T.uint32] + C = T.empty((M, N), out_dtype) + + sm_num = driver.get_num_sms() + num_clusters = sm_num // 2 + m_blocks = T.ceildiv(M, block_M) + m_clusters = m_blocks // 2 + n_blocks = T.ceildiv(N, block_N) + assert K % (2 * block_K) == 0 # for simplicity + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 16 # in cluster + assert n_blocks % (2 * group_size) == 0 # Please adjust group_size if not satisfied + + with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared( + (num_stages, half_N, block_K) if transpose_B else (num_stages, block_K, half_N), + in_dtype, + ) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1]) + tmem_empty = T.alloc_cluster_barrier([128 * 2]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + + if warp_idx == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], parity ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + if transpose_B: + T.tma_copy( + B[ + by * block_N + cta_id * half_N : by * block_N + (cta_id + 1) * half_N, + k * block_K : (k + 1) * block_K, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + else: + T.tma_copy( + B[ + k * block_K : (k + 1) * block_K, + by * block_N + cta_id * half_N : by * block_N + (cta_id + 1) * half_N, + ], + 
B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M + bx * block_M : sf_group_idx * M + (bx + 1) * block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[sf_group_idx * N + by * block_N : sf_group_idx * N + (by + 1) * block_N], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_empty, (w & 1) ^ 1) + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], parity) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=transpose_B, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], parity) + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + elif 128 <= tx < 256: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_full, w & 1) + T.copy(C_tmem, C_local) + T.mbarrier_arrive(tmem_empty, 0) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + return C + + +def unpack_sf_u32_1d(packed_sf, mn, sf_k_blocks): + sf_k_groups = (sf_k_blocks + 3) // 4 + packed_2d = packed_sf.view(sf_k_groups, mn).T.contiguous().to(torch.int64) + unpacked = torch.empty((mn, sf_k_groups * 4), device=packed_sf.device, dtype=torch.uint8) + for i in range(4): + unpacked[:, i::4] = ((packed_2d >> (8 * i)) & 0xFF).to(torch.uint8) + return unpacked[:, :sf_k_blocks].contiguous() + + +def pack_sf_u8_to_u32_1d(sf_u8): + assert sf_u8.dtype == torch.uint8 + assert sf_u8.dim() == 2 + mn, sf_k_padded = sf_u8.shape + assert sf_k_padded % 4 == 0 
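+    # Little-endian packing: byte i of each uint32 word holds the scale for
+    # K group 4 * g + i, matching the sf_id = k % 4 selection in the kernels.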
+ words = sf_u8.to(torch.int64) + packed = (words[:, 0::4] | (words[:, 1::4] << 8) | (words[:, 2::4] << 16) | (words[:, 3::4] << 24)).to(torch.uint32) + return packed.T.contiguous().reshape(-1) + + +def quantize_fp8_with_packed_ue8m0(x, gran_k=128): + """DeepGEMM-style per-token FP8 quantization with UE8M0 scale factors. + + Returns: + x_fp8: [MN, K] in float8_e4m3fn + sf_packed_u32: flattened group-major packed uint32 scale factors + sf_u8: [MN, ceil(K / gran_k)] unpacked E8M0 exponents + """ + + def ceil_div_int(x, y): + return (x + y - 1) // y + + def align_up(x, y): + return ceil_div_int(x, y) * y + + def ceil_to_ue8m0(x): + bits = x.abs().float().view(torch.int32) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).ne(0).to(torch.int32) + return (exp.clamp(1, 254) << 23).view(torch.float32) + + assert x.dim() == 2 + mn, k = x.shape + padded_k = align_up(k, gran_k) + + x_padded = torch.zeros((mn, padded_k), device=x.device, dtype=x.dtype) + x_padded[:, :k] = x + x_view = x_padded.view(mn, padded_k // gran_k, gran_k) + + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(mn, padded_k)[:, :k].contiguous() + + sf_u8 = (sf.contiguous().view(torch.int32) >> 23).to(torch.uint8) + sf_k_blocks = sf_u8.shape[1] + sf_k_padded = align_up(sf_k_blocks, 4) + if sf_k_padded != sf_k_blocks: + sf_u8_padded = torch.full((mn, sf_k_padded), 127, device=x.device, dtype=torch.uint8) + sf_u8_padded[:, :sf_k_blocks] = sf_u8 + else: + sf_u8_padded = sf_u8 + + sf_packed_u32 = pack_sf_u8_to_u32_1d(sf_u8_padded) + return x_fp8, sf_packed_u32, sf_u8 + + +def blockscaled_gemm_ref(a, b, sfa_packed, sfb_packed, sf_granularity_k=128, transpose_B=False): + """Torch reference for block-scaled MXFP8 GEMM. 
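+
+    Dequantizes A and B one sf_granularity_k-wide K block at a time using
+    2^(ue8m0 - 127) scales and accumulates in float32 (correctness check only).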
+ + Args: + a: [M, K] FP8 tensor + b: [K, N] FP8 tensor, or [N, K] when transpose_B=True + sfa_packed: [(sf_k_blocks / 4) * M] uint32 packed E8M0 scale factors for A + sfb_packed: [(sf_k_blocks / 4) * N] uint32 packed E8M0 scale factors for B + sf_granularity_k: number of K elements per scale factor block (default 128) + + Returns: + [M, N] float32 result + """ + M, K = a.shape + if transpose_B: + N, K2 = b.shape + else: + K2, N = b.shape + assert K == K2 + sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k + sfa_unpacked = unpack_sf_u32_1d(sfa_packed, M, sf_k_blocks) + sfb_unpacked = unpack_sf_u32_1d(sfb_packed, N, sf_k_blocks) + + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + + # E8M0 exponent to float scale: 2^(exp - 127) + sfa_scales = torch.pow(2.0, sfa_unpacked.to(torch.float32) - 127.0) # [M, sf_k_blocks] + sfb_scales = torch.pow(2.0, sfb_unpacked.to(torch.float32) - 127.0) # [N, sf_k_blocks] + + c = torch.zeros(M, N, device=a.device, dtype=torch.float32) + for bi in range(sf_k_blocks): + k_start = bi * sf_granularity_k + k_end = min(k_start + sf_granularity_k, K) + # Scale A block: [M, block_k] * [M, 1] + a_block = a_f32[:, k_start:k_end] * sfa_scales[:, bi : bi + 1] + if transpose_B: + # Scale B block: [N, block_k] * [N, 1] + b_block = b_f32[:, k_start:k_end] * sfb_scales[:, bi : bi + 1] + c += a_block @ b_block.T + else: + # Scale B block: [block_k, N] * [1, N] + b_block = b_f32[k_start:k_end, :] * sfb_scales[:, bi : bi + 1].T + c += a_block @ b_block + return c + + +def cosine_similarity(a, b): + a_flat = a.flatten().float() + b_flat = b.flatten().float() + return (a_flat @ b_flat) / (a_flat.norm() * b_flat.norm()) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--use-e2e-quant-path", action="store_true", default=True) + parser.add_argument("--persistent", action="store_true", default=True) + parser.add_argument("--enable-2cta", action="store_true", default=True) + parser.add_argument("--transpose-b", action="store_true", help="Use B as [N, K] and compute A @ B.T.") + return parser.parse_args() + + +def main(): + args = parse_args() + + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 128 + in_dtype, out_dtype, accum_dtype = T.float8_e4m3fn, T.bfloat16, T.float + use_e2e_quant_path = args.use_e2e_quant_path + persistent = args.persistent + enable_2cta = args.enable_2cta + transpose_B = args.transpose_b + num_stages = 6 if enable_2cta else 4 + if persistent: + assert enable_2cta + kernel = mxfp8_blockscaled_gemm_2cta_persistent + else: + kernel = mxfp8_blockscaled_gemm_2cta if enable_2cta else mxfp8_blockscaled_gemm + sf_granularity_k = 128 + assert sf_granularity_k == 128 + + if use_e2e_quant_path: + # End-to-end path: + # fp16/bf16 source tensors -> per-token FP8 quantization with UE8M0 SF + # -> pack 4 SF entries into one uint32 -> blockscaled GEMM + x = torch.randn(M, K, device="cuda", dtype=torch.float16) + w_nt = torch.randn(N, K, device="cuda", dtype=torch.float16) + + a, sfa, _ = quantize_fp8_with_packed_ue8m0(x, gran_k=sf_granularity_k) + b_nt, sfb, _ = quantize_fp8_with_packed_ue8m0(w_nt, gran_k=sf_granularity_k) + b = b_nt if transpose_B else b_nt.T.contiguous() + else: + a = torch.randn(M, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + if transpose_B: + b = torch.randn(N, K, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + else: + b = torch.randn(K, N, device="cuda", dtype=torch.float16).to(torch.float8_e4m3fn) + + # E8M0 scale factors: one uint32 per row per 4 
K-blocks. + sf_k_blocks = (K + sf_granularity_k - 1) // sf_granularity_k + + # Pad to multiple of 4 (UTCCP loads 4 K-blocks at a time) + sf_k_padded = ((sf_k_blocks + 3) // 4) * 4 + sfa_u8 = torch.randint(127 - 5, 127 + 5, (M, sf_k_padded), device="cuda", dtype=torch.uint8) + sfb_u8 = torch.randint(127 - 5, 127 + 5, (N, sf_k_padded), device="cuda", dtype=torch.uint8) + sfa = pack_sf_u8_to_u32_1d(sfa_u8) + sfb = pack_sf_u8_to_u32_1d(sfb_u8) + + c = kernel( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + transpose_B, + ) + print( + kernel.get_kernel_source( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + transpose_B, + ) + ) + + if use_e2e_quant_path: + # For the end-to-end quantization path, compare against the reference with bf16 gemm + ref_c = (x.float() @ w_nt.float().T).to(torch.bfloat16) + else: + ref_c = blockscaled_gemm_ref(a, b, sfa, sfb, sf_granularity_k, transpose_B=transpose_B).to(torch.bfloat16) + sim = cosine_similarity(c, ref_c) + + print(f"Output shape: {c.shape}, dtype: {c.dtype}") + print(f"E2E quant path: {use_e2e_quant_path}") + print(f"transpose_B: {transpose_B}") + print(f"{c=}, {ref_c=}") + # print(f"Max abs error: {(c.float() - ref_c.float()).abs().max().item():.6f}") + print(f"Cosine similarity: {sim.item():.6f}") + if use_e2e_quant_path: + assert 1 - sim < 1e-3 # err tolerance from DeepGEMM + print("e2e check passed ✅") + + tl_latency = do_bench( + lambda: kernel( + a, + b, + sfa, + sfb, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + sf_granularity_k, + transpose_B, + ), + backend="cupti", + ) + print(f"Tilelang MXFP8 latency: {tl_latency} ms") + print(f"TFLOPs: {2 * M * N * K / (tl_latency / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/blockscaled_gemm_sm100/grouped_gemm_mxfp8_blockscaled_1d1d.py b/examples/blockscaled_gemm_sm100/grouped_gemm_mxfp8_blockscaled_1d1d.py new file mode 100644 index 0000000000..eaa87e994e --- /dev/null +++ b/examples/blockscaled_gemm_sm100/grouped_gemm_mxfp8_blockscaled_1d1d.py @@ -0,0 +1,657 @@ +# Grouped MXFP8 block-scaled GEMM on SM100. +# Blockscale size: (M, N, K) = (1, 1, 128) + +import argparse + +import torch +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +from tilelang.profiler import do_bench + + +@tilelang.jit +def grouped_mxfp8_blockscaled_gemm_2cta( + A, + B, + SFA, + SFB, + offsets, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + max_M_per_E, + transpose_B=True, + sf_granularity_k=128, +): + """Grouped 2CTA MXFP8 blockscaled GEMM. + + Logical scale shape follows tilelang_gemm.py: + SFA [M_total, sf_k_packed], SFB [E, N, sf_k_packed] + + Kernel scale operands are group-major flat buffers so the SF loads can use + the same contiguous TMA pattern as mxfp8_blockscaled_gemm_2cta. 
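+
+    offsets: [E + 1] int32 prefix sums with a leading 0, so expert eid owns
+    rows A[offsets[eid]:offsets[eid + 1]].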
+ """ + M_total, N, K, E, E1 = T.const("M_total, N, K, E, E1") + + assert block_M == 128 + assert block_N == 256 + assert block_K == 128 + assert sf_granularity_k == 128 + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + assert sf_load_period == 4 + + A: T.Tensor[[M_total, K], in_dtype] + B: T.Tensor[[E, N, K] if transpose_B else [E, K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M_total], T.uint32] + SFB: T.Tensor[[sf_k_groups * E * N], T.uint32] + offsets: T.Tensor[[E1], T.int32] + C = T.empty((M_total, N), out_dtype) + + n_blocks = T.ceildiv(N, block_N) + max_M_blocks = T.ceildiv(max_M_per_E, block_M) + max_M_blocks_padded = T.ceildiv(max_M_blocks, 2) * 2 + + with T.Kernel(max_M_blocks_padded, n_blocks, E, threads=128, cluster_dims=2) as (pid_m, pid_n, eid): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared( + (num_stages, half_N, block_K) if transpose_B else (num_stages, block_K, half_N), + in_dtype, + ) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, 8], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + T.use_swizzle(16) + + start_m = offsets[eid] + end_m = offsets[eid + 1] + m_size = end_m - start_m + expert_m_blocks = T.ceildiv(m_size, block_M) + clamped_pid_m = T.min(pid_m, T.max(expert_m_blocks, 1) - 1) + tile_m = start_m + clamped_pid_m * block_M + + if warp_idx == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], phase ^ 1) + T.tma_copy( + A[tile_m : tile_m + block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + if transpose_B: + T.tma_copy( + B[ + eid, + pid_n * block_N + cta_id * half_N : pid_n * block_N + (cta_id + 1) * half_N, + k * block_K : (k + 1) * block_K, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + else: + T.tma_copy( + B[ + eid, + k * block_K : (k + 1) * block_K, + pid_n * block_N + cta_id * half_N : pid_n * block_N + (cta_id + 1) * half_N, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M_total + tile_m : sf_group_idx * M_total + tile_m + block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[sf_group_idx * E * N + eid * N + pid_n * block_N : sf_group_idx * E * N + eid * N + (pid_n + 1) * block_N], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], phase) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], 
SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=transpose_B, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for k in T.serial(k_iters): + stage = k % num_stages + phase = (k // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], phase) + if k % sf_load_period == 0: + T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + T.mbarrier_wait_parity(tmem_full, 0) + T.copy(C_tmem, C_local) + + if pid_m * block_M < m_size and tile_m + block_M <= end_m: + T.copy(C_local, C_shared) + T.copy(C_shared, C[tile_m, pid_n * block_N]) + elif pid_m * block_M < m_size: + T.copy(C_local, C_local_cast) + actual_rows = end_m - tile_m + for i, j in T.Parallel(block_M, block_N): + if i < actual_rows and pid_n * block_N + j < N: + C[tile_m + i, pid_n * block_N + j] = C_local_cast[i, j] + + return C + + +@tilelang.jit +def grouped_mxfp8_blockscaled_gemm_2cta_persistent( + A, + B, + SFA, + SFB, + offsets, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + max_M_per_E, + transpose_B=True, + sf_granularity_k=128, + store_block_N=64, +): + """Persistent grouped 2CTA MXFP8 blockscaled GEMM with one accumulator TMEM.""" + M_total, N, K, E, E1 = T.const("M_total, N, K, E, E1") + + assert block_M == 128 + assert block_N == 256 + assert block_K == 128 + assert sf_granularity_k == 128 + + half_N = block_N // 2 + k_iters = T.ceildiv(K, block_K) + sf_load_period = sf_granularity_k * 4 // block_K + sf_k_groups = T.ceildiv(T.ceildiv(K, sf_granularity_k), 4) + assert sf_load_period == 4 + + A: T.Tensor[[M_total, K], in_dtype] + B: T.Tensor[[E, N, K] if transpose_B else [E, K, N], in_dtype] + SFA: T.Tensor[[sf_k_groups * M_total], T.uint32] + SFB: T.Tensor[[sf_k_groups * E * N], T.uint32] + offsets: T.Tensor[[E1], T.int32] + C = T.empty((M_total, N), out_dtype) + + sm_num = driver.get_num_sms() + num_clusters = sm_num // 2 + n_blocks = T.ceildiv(N, block_N) + max_M_blocks = T.ceildiv(max_M_per_E, block_M) + max_M_blocks_padded = T.ceildiv(max_M_blocks, 2) * 2 + m_clusters = max_M_blocks_padded // 2 + total_cluster_tiles = E * n_blocks * m_clusters + waves = T.ceildiv(total_cluster_tiles, num_clusters) + group_size = 8 + + with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared( + (num_stages, half_N, block_K) if transpose_B else (num_stages, block_K, half_N), + in_dtype, + ) + SFA_shared = T.alloc_shared((num_stages, block_M), "uint32") + SFB_shared = T.alloc_shared((num_stages, block_N), "uint32") + + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + SFA_tmem = T.alloc_tmem([block_M, block_M // 128 * 4], "uint32") + SFB_tmem = T.alloc_tmem([block_M, block_N // 128 * 4], "uint32") + + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + + loaded = T.alloc_barrier([32] * num_stages) + with_sf_full = T.alloc_cluster_barrier([32 * 2] * 
num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1]) + tmem_empty = T.alloc_cluster_barrier([128 * 2]) + + tx = T.get_thread_binding() + warp_idx = tx // 32 + + if warp_idx == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + eid = tile_id // (n_blocks * m_clusters) + local_tile_id = tile_id - eid * n_blocks * m_clusters + num_pid_in_group = group_size * n_blocks + group_id = local_tile_id // num_pid_in_group + first_pid_m_cluster = group_id * group_size + group_m = T.min(m_clusters - first_pid_m_cluster, group_size) + pid_m_cluster = first_pid_m_cluster + (local_tile_id % num_pid_in_group) % group_m + pid_n = (local_tile_id % num_pid_in_group) // group_m + + if tile_id < total_cluster_tiles: + start_m = offsets[eid] + end_m = offsets[eid + 1] + m_size = end_m - start_m + expert_m_blocks = T.ceildiv(m_size, block_M) + pid_m = pid_m_cluster * 2 + cta_id + safe_pid_m = T.min(pid_m, T.max(expert_m_blocks, 1) - 1) + tile_m = start_m + safe_pid_m * block_M + + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(consumed[stage], parity ^ 1) + T.tma_copy( + A[tile_m : tile_m + block_M, k * block_K : (k + 1) * block_K], + A_shared[stage, :, :], + barrier=loaded[stage], + ) + if transpose_B: + T.tma_copy( + B[ + eid, + pid_n * block_N + cta_id * half_N : pid_n * block_N + (cta_id + 1) * half_N, + k * block_K : (k + 1) * block_K, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + else: + T.tma_copy( + B[ + eid, + k * block_K : (k + 1) * block_K, + pid_n * block_N + cta_id * half_N : pid_n * block_N + (cta_id + 1) * half_N, + ], + B_shared[stage, :, :], + barrier=loaded[stage], + ) + if k % sf_load_period == 0: + sf_group_idx = k // sf_load_period + T.tma_copy( + SFA[sf_group_idx * M_total + tile_m : sf_group_idx * M_total + tile_m + block_M], + SFA_shared[stage, :], + barrier=loaded[stage], + ) + T.tma_copy( + SFB[ + sf_group_idx * E * N + eid * N + pid_n * block_N : sf_group_idx * E * N + + eid * N + + (pid_n + 1) * block_N + ], + SFB_shared[stage, :], + barrier=loaded[stage], + ) + T.mbarrier_arrive(loaded[stage]) + + elif warp_idx == 1 and cta_id == 0: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + + if tile_id < total_cluster_tiles: + T.mbarrier_wait_parity(tmem_empty, (w & 1) ^ 1) + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(with_sf_full[stage], parity) + if k % sf_load_period == 0: + T.tcgen05_cp_warpx4(SFA_shared[stage, :], SFA_tmem, use_2cta=True) + T.tcgen05_cp_warpx4(SFB_shared[stage, :], SFB_tmem, use_2cta=True) + T.tcgen05_gemm_blockscaled( + A_shared[stage, :, :], + B_shared[stage, :, :], + C_tmem, + SFA_tmem, + SFB_tmem, + transpose_B=transpose_B, + mbar=consumed[stage], + clear_accum=k == 0, + sf_a_id=k % sf_load_period, + sf_b_id=k % sf_load_period, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + elif warp_idx == 2: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + + if tile_id < total_cluster_tiles: + for k in T.serial(k_iters): + phase = w * k_iters + k + stage = phase % num_stages + parity = (phase // num_stages) & 1 + T.mbarrier_wait_parity(loaded[stage], parity) + if k % sf_load_period == 0: + 
T.tcgen05_sf_warp_transpose(SFA_shared[stage, :]) + T.tcgen05_sf_warp_transpose(SFB_shared[stage, :]) + T.fence_proxy_async() + T.mbarrier_arrive(with_sf_full[stage], 0) + + elif 128 <= tx < 256: + for w in T.unroll(waves): + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + eid = tile_id // (n_blocks * m_clusters) + local_tile_id = tile_id - eid * n_blocks * m_clusters + num_pid_in_group = group_size * n_blocks + group_id = local_tile_id // num_pid_in_group + first_pid_m_cluster = group_id * group_size + group_m = T.min(m_clusters - first_pid_m_cluster, group_size) + pid_m_cluster = first_pid_m_cluster + (local_tile_id % num_pid_in_group) % group_m + pid_n = (local_tile_id % num_pid_in_group) // group_m + pid_m = pid_m_cluster * 2 + cta_id + + if tile_id < total_cluster_tiles: + start_m = offsets[eid] + end_m = offsets[eid + 1] + m_size = end_m - start_m + tile_m = start_m + pid_m * block_M + T.mbarrier_wait_parity(tmem_full, w & 1) + T.copy(C_tmem, C_local) + T.mbarrier_arrive(tmem_empty, 0) + + if pid_m * block_M < m_size and tile_m + block_M <= end_m: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.copy(C_shared, C[tile_m, pid_n * block_N + i * store_block_N]) + elif pid_m * block_M < m_size: + T.copy(C_local, C_local_cast) + actual_rows = end_m - tile_m + for i, j in T.Parallel(block_M, block_N): + if i < actual_rows and pid_n * block_N + j < N: + C[tile_m + i, pid_n * block_N + j] = C_local_cast[i, j] + + return C + + +def pack_sf_u8_to_u32_rows(sf_u8): + assert sf_u8.dtype == torch.uint8 + assert sf_u8.dim() == 2 + assert sf_u8.shape[1] % 4 == 0 + words = sf_u8.to(torch.int64) + return (words[:, 0::4] | (words[:, 1::4] << 8) | (words[:, 2::4] << 16) | (words[:, 3::4] << 24)).to(torch.uint32).contiguous() + + +def pack_rows_to_group_major_flat(packed_rows): + return packed_rows.contiguous().T.contiguous().reshape(-1) + + +def pack_sfb_to_group_major_flat(packed_sfb): + return packed_sfb.contiguous().permute(2, 0, 1).contiguous().reshape(-1) + + +def unpack_sf_u32_rows(packed_sf, sf_k_blocks): + words = packed_sf.contiguous().view(-1, packed_sf.shape[-1]).to(torch.int64) + unpacked = torch.empty((words.shape[0], words.shape[1] * 4), device=packed_sf.device, dtype=torch.uint8) + for i in range(4): + unpacked[:, i::4] = ((words >> (8 * i)) & 0xFF).to(torch.uint8) + return unpacked[:, :sf_k_blocks].view(*packed_sf.shape[:-1], sf_k_blocks).contiguous() + + +def quantize_fp8_with_packed_ue8m0_rows(x, gran_k=128): + def ceil_div_int(x, y): + return (x + y - 1) // y + + def align_up(x, y): + return ceil_div_int(x, y) * y + + def ceil_to_ue8m0(x): + bits = x.abs().float().view(torch.int32) + exp = ((bits >> 23) & 0xFF) + (bits & 0x7FFFFF).ne(0).to(torch.int32) + return (exp.clamp(1, 254) << 23).view(torch.float32) + + assert x.dim() == 2 + mn, k = x.shape + padded_k = align_up(k, gran_k) + x_padded = torch.zeros((mn, padded_k), device=x.device, dtype=x.dtype) + x_padded[:, :k] = x + x_view = x_padded.view(mn, padded_k // gran_k, gran_k) + + x_amax = x_view.abs().float().amax(dim=2).clamp_min(1e-4) + sf = ceil_to_ue8m0(x_amax / 448.0) + x_fp8 = (x_view * (1.0 / sf.unsqueeze(2))).to(torch.float8_e4m3fn) + x_fp8 = x_fp8.view(mn, padded_k)[:, :k].contiguous() + + sf_u8 = (sf.contiguous().view(torch.int32) >> 23).to(torch.uint8) + sf_k_padded = align_up(sf_u8.shape[1], 4) + if sf_k_padded != sf_u8.shape[1]: + sf_u8_padded = torch.full((mn, sf_k_padded), 127, device=x.device, 
dtype=torch.uint8) + sf_u8_padded[:, : sf_u8.shape[1]] = sf_u8 + else: + sf_u8_padded = sf_u8 + return x_fp8, pack_sf_u8_to_u32_rows(sf_u8_padded), sf_u8 + + +def grouped_blockscaled_gemm_ref(a, b, sfa_packed, sfb_packed, offsets, sf_granularity_k=128, transpose_B=True): + m_total, k = a.shape + if transpose_B: + e, n, k2 = b.shape + else: + e, k2, n = b.shape + assert k == k2 + sf_k_blocks = (k + sf_granularity_k - 1) // sf_granularity_k + sfa_unpacked = unpack_sf_u32_rows(sfa_packed, sf_k_blocks) + sfb_unpacked = unpack_sf_u32_rows(sfb_packed, sf_k_blocks) + + a_f32 = a.to(torch.float32) + b_f32 = b.to(torch.float32) + sfa_scales = torch.pow(2.0, sfa_unpacked.to(torch.float32) - 127.0) + sfb_scales = torch.pow(2.0, sfb_unpacked.to(torch.float32) - 127.0) + + c = torch.empty((m_total, n), device=a.device, dtype=torch.float32) + for eid in range(e): + start = int(offsets[eid].item()) + end = int(offsets[eid + 1].item()) + if start == end: + continue + out = torch.zeros((end - start, n), device=a.device, dtype=torch.float32) + for bi in range(sf_k_blocks): + k_start = bi * sf_granularity_k + k_end = min(k_start + sf_granularity_k, k) + a_block = a_f32[start:end, k_start:k_end] * sfa_scales[start:end, bi : bi + 1] + if transpose_B: + b_block = b_f32[eid, :, k_start:k_end] * sfb_scales[eid, :, bi : bi + 1] + out += a_block @ b_block.T + else: + b_block = b_f32[eid, k_start:k_end, :] * sfb_scales[eid, :, bi : bi + 1].T + out += a_block @ b_block + c[start:end] = out + return c + + +def cosine_similarity(a, b): + a_flat = a.flatten().float() + b_flat = b.flatten().float() + return (a_flat @ b_flat) / (a_flat.norm() * b_flat.norm()) + + +def make_offsets(batch_sizes, device): + offsets = torch.zeros(len(batch_sizes) + 1, device=device, dtype=torch.int32) + offsets[1:] = torch.tensor(batch_sizes, device=device, dtype=torch.int32).cumsum(0) + return offsets + + +def run_grouped_mxfp8_blockscaled_gemm( + a, + b, + sfa_flat, + sfb_flat, + offsets, + max_M_per_E, + transpose_B=True, + persistent=True, +): + block_M, block_N, block_K = 128, 256, 128 + in_dtype, out_dtype, accum_dtype = T.float8_e4m3fn, T.bfloat16, T.float + num_stages = 6 + sf_granularity_k = 128 + + m_total, k = a.shape + if transpose_B: + _, n, k2 = b.shape + else: + _, k2, n = b.shape + assert k == k2 + assert n % block_N == 0, f"N={n} not divisible by {block_N}" + assert k % block_K == 0, f"K={k} not divisible by {block_K}" + + kernel = grouped_mxfp8_blockscaled_gemm_2cta_persistent if persistent else grouped_mxfp8_blockscaled_gemm_2cta + return kernel( + a, + b, + sfa_flat, + sfb_flat, + offsets, + block_M, + block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + max_M_per_E, + transpose_B, + sf_granularity_k, + ) + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("--batch-sizes", type=str, default="512,1024,1536,2048") + parser.add_argument("--N", type=int, default=8192) + parser.add_argument("--K", type=int, default=8192) + parser.add_argument("--transpose-b", action="store_true", help="Use B as [E, N, K] and compute grouped A @ B.T.") + parser.add_argument("--no-persistent", action="store_true", help="Run the non-persistent 2CTA kernel.") + parser.add_argument("--no-bench", action="store_true") + return parser.parse_args() + + +def main(): + args = parse_args() + batch_sizes = [int(x) for x in args.batch_sizes.split(",") if x.strip()] + transpose_B = args.transpose_b + persistent = not args.no_persistent + device = "cuda" + m_total = sum(batch_sizes) + e = len(batch_sizes) + 
n = args.N + k = args.K + + offsets = make_offsets(batch_sizes, device) + max_M_per_E = max(batch_sizes) + + x = torch.randn(m_total, k, device=device, dtype=torch.float16) + w_nt = torch.randn(e, n, k, device=device, dtype=torch.float16) + + a, sfa, _ = quantize_fp8_with_packed_ue8m0_rows(x) + b_nt, sfb_2d, _ = quantize_fp8_with_packed_ue8m0_rows(w_nt.view(e * n, k)) + b_nt = b_nt.view(e, n, k).contiguous() + sfb = sfb_2d.view(e, n, -1).contiguous() + + sfa_flat = pack_rows_to_group_major_flat(sfa) + sfb_flat = pack_sfb_to_group_major_flat(sfb) + b = b_nt if transpose_B else b_nt.transpose(1, 2).contiguous() + + c = run_grouped_mxfp8_blockscaled_gemm( + a, + b, + sfa_flat, + sfb_flat, + offsets, + max_M_per_E, + transpose_B, + persistent, + ) + ref_c = grouped_blockscaled_gemm_ref(a, b, sfa, sfb, offsets, transpose_B=transpose_B).to(torch.bfloat16) + sim = cosine_similarity(c, ref_c) + max_abs = (c.float() - ref_c.float()).abs().max().item() + + print(f"Output shape: {c.shape}, dtype: {c.dtype}") + print(f"batch_sizes: {batch_sizes}") + print(f"transpose_B: {transpose_B}") + print(f"persistent: {persistent}") + print(f"Cosine similarity: {sim.item():.6f}") + print(f"Max abs error: {max_abs:.6f}") + assert 1 - sim < 1e-5 + print("grouped blockscaled check passed") + + if not args.no_bench: + latency = do_bench( + lambda: run_grouped_mxfp8_blockscaled_gemm( + a, + b, + sfa_flat, + sfb_flat, + offsets, + max_M_per_E, + transpose_B, + persistent, + ), + backend="cupti", + ) + print(f"Tilelang grouped MXFP8 latency: {latency} ms") + print(f"TFLOPs: {2 * m_total * n * k / (latency / 1e3) / 1e12:.2f}") + + +if __name__ == "__main__": + main() diff --git a/examples/blockscaled_gemm_sm100/mxfp8_illustrated.md b/examples/blockscaled_gemm_sm100/mxfp8_illustrated.md new file mode 100644 index 0000000000..48460bed88 --- /dev/null +++ b/examples/blockscaled_gemm_sm100/mxfp8_illustrated.md @@ -0,0 +1,117 @@ +# SM100 MXFP8 Blockscaled Illustration + +This note explains the Blackwell data path used by +[`gemm_mxfp8_blockscaled_1d1d.py`](gemm_mxfp8_blockscaled_1d1d.py) and the +grouped variants in +[`grouped_gemm_mxfp8_blockscaled_1d1d.py`](grouped_gemm_mxfp8_blockscaled_1d1d.py). +The kernels use 1D-1D MXFP8 block scaling: one `ue8m0` scale for each A row +and B column per 128 K elements. Four adjacent K scale blocks are packed into +one `uint32`. + +![SM100 block-scaled GEMM data path](figures/blockscaled_data_path.svg) + +## Kernel Variants + +![1CTA, 2CTA, and persistent 2CTA comparison](figures/blockscaled_variants.svg) + +| Variant | Function | Launch shape | Tile ownership | Main difference | +| --- | --- | --- | --- | --- | +| 1CTA | `mxfp8_blockscaled_gemm` | `threads=128` | One CTA computes one logical `block_M x block_N` tile. | Local barriers. Warp 0 loads A/B/SF, warp 1 issues UTCCP and MMA, warp 2 transposes SF in SMEM. | +| 2CTA | `mxfp8_blockscaled_gemm_2cta` | `threads=128, cluster_dims=2` | A CTA pair computes one logical `128 x 256` tile. | Each peer loads a half-N B panel, A is loaded in both peers, and leader CTA issues `use_2cta=True`. | +| 2CTA persistent | `mxfp8_blockscaled_gemm_2cta_persistent` | `T.Kernel(sm_num, threads=256, cluster_dims=2)` | Resident CTA pairs walk logical tiles over multiple waves. | Adds a persistent scheduler and dedicated epilogue warpgroup, with `tmem_empty` handing TMEM back to the MMA warp. | + +The grouped kernels reuse the same 2CTA and persistent structure. 
Their extra
+work is scheduler-side: map `(pid_m, pid_n, eid)` through `offsets`, clamp tail
+M blocks, and use `SFB[sf_group * E * N + eid * N + n]` for expert-local B
+scales.
+
+## Warp Specialization
+
+![Warp-specialized pipeline](figures/blockscaled_warp_specialization.svg)
+
+The examples are warp-specialized by role, but they do not use the
+`tcgen05.mma.ws` PTX form. The block-scaled call lowers to
+`tcgen05.mma.cta_group::{1,2}.kind::mxf8f6f4.block_scale`; TileLang currently
+rejects combining the block-scaled `.ws` variant with 2CTA.
+
+| Threads | Non-persistent role | Persistent role | Main handoff |
+| --- | --- | --- | --- |
+| warp 0 | TMA producer for A, B, SFA, SFB | Same, inside the wave loop | Wait `consumed`, arrive `loaded`. |
+| warp 1 in leader CTA | UTCCP copy plus `tcgen05.mma.block_scale` issue | Same, but waits `tmem_empty` before each wave | Wait `with_sf_full`, arrive `consumed`, finally arrive `tmem_full`. |
+| warp 2 | SF transposer | Same | Wait `loaded`, transpose packed SF chunks, `fence_proxy_async`, arrive `with_sf_full`. |
+| all warps / warps 4-7 | Copy `C_tmem` through registers/SMEM to global | Dedicated epilogue warpgroup | Wait `tmem_full`; persistent path arrives `tmem_empty` after reading TMEM. |
+
+Two details are easy to miss:
+
+- `tcgen05.mma` has single-thread issue semantics, so one elected lane in the
+  MMA warp initiates the whole operation.
+- `tcgen05.mma` and `tcgen05.cp` access SMEM through the async proxy. After the
+  SF warp rewrites the SF buffer in normal SMEM, it uses `fence_proxy_async()`
+  before the MMA warp uses `tcgen05.cp`.
+
+## Scale-Factor Layout
+
+Blackwell block-scaled tcgen05 MMA instructions require a special layout for scale factors in TMEM. Taking K = 32 and SFA as an example:
+![Layout requirement for SFA in TMEM](figures/sfa.png)
+
+The process below packs the scale factors into the required layout:
+
+![Scale-factor packing and TMEM layout](figures/blockscaled_sf_layout.svg)
+
+For `sf_granularity_k = block_K = 128`, the examples load one packed SF word
+every four K iterations:
+
+```text
+sf_load_period = sf_granularity_k * 4 / block_K = 4
+sf_k_blocks = ceil(K / 128)
+sf_k_groups = ceil(sf_k_blocks / 4)
+```
+
+The global flat layout is group-major:
+
+```text
+SFA[sf_group * M + m]
+SFB[sf_group * N + n]
+
+word = sf0 | (sf1 << 8) | (sf2 << 16) | (sf3 << 24)
+```
+
+For grouped GEMM, SFA is still group-major over the concatenated M dimension,
+while SFB is group-major over `(E, N)`:
+
+```text
+SFA[sf_group * M_total + m]
+SFB[sf_group * E * N + eid * N + n]
+```
+
+Each SF TMA places a 1D `uint32` vector in SMEM:
+
+- SFA has `block_M` words, one per output row.
+- SFB has `block_N` words, one per output column.
+- Each word contains four `ue8m0` bytes for four consecutive 128-wide K
+  groups.
+
+`T.tcgen05_sf_warp_transpose` works on each 128-word chunk. It rewrites a
+`4 x 32` word view into a `32 x 4` word view, matching the
+`tcgen05.cp.32x128b.warpx4` source pattern. `T.tcgen05_cp_warpx4` then copies
+one 128-word chunk into four TMEM columns and duplicates it across the four
+32-lane TMEM partitions, a layout the block-scaled MMA hardware requires.
+
+The resulting TMEM shapes in the 128x256 examples are:
+
+```text
+SFA_tmem: [128 lanes, 4 columns]
+SFB_tmem: [128 lanes, 8 columns] # two 128-column N chunks
+```
+
+During MMA issue, `sf_a_id = k % 4` and `sf_b_id = k % 4` select the active
+byte sub-column from the packed `uint32` cell.
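+
+As a minimal sketch (plain Python with a hypothetical helper name, not kernel
+code), the byte selection and UE8M0 decode look like this:
+
+```python
+def sf_scale_from_word(word: int, sf_id: int) -> float:
+    """Decode the UE8M0 byte selected by sf_id from a packed uint32 SF word."""
+    assert 0 <= sf_id < 4
+    ue8m0 = (word >> (8 * sf_id)) & 0xFF  # byte sub-column picked by k % 4
+    return 2.0 ** (ue8m0 - 127)  # UE8M0 exponent -> float scale
+
+
+# Four consecutive K iterations reuse one packed word with sf_id = 0, 1, 2, 3:
+word = 0x7D7E7F80  # sf0 = 0x80, sf1 = 0x7F, sf2 = 0x7E, sf3 = 0x7D
+print([sf_scale_from_word(word, k % 4) for k in range(4)])  # [2.0, 1.0, 0.5, 0.25]
+```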
This is why one SF TMA load +serves four adjacent `block_K=128` MMA iterations. + +## Related + +- TileLang kernels: [`gemm_mxfp8_blockscaled_1d1d.py`](gemm_mxfp8_blockscaled_1d1d.py) +- Grouped kernels: [`grouped_gemm_mxfp8_blockscaled_1d1d.py`](grouped_gemm_mxfp8_blockscaled_1d1d.py) +- TileLang helpers: `T.tcgen05_cp_warpx4`, `T.tcgen05_sf_warp_transpose`, and + `T.tcgen05_gemm_blockscaled` +- PTX document: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#tcgen05-block-scaling diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py index 6e73214522..a93e4de135 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_paged.py @@ -2,7 +2,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -13,160 +12,159 @@ from heuristic import num_splits_heuristic -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, page_block_size, num_stages, threads, num_pages): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func( - block_N, block_H, page_block_size, num_split, num_stages, threads, num_pages, max_num_blocks_per_seq, max_selected_blocks + num_split = T.dynamic("num_split") + max_num_blocks_per_seq = T.dynamic("max_num_blocks_per_seq") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [num_pages, page_block_size, heads_kv, dim] + shape_v = [num_pages, page_block_size, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_block_table = [batch, max_num_blocks_per_seq] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + assert block_N <= page_block_size and page_block_size % block_N == 0 + block_ratio = page_block_size // block_N + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + block_table: T.Tensor(shape_block_table, T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), ): - shape_q = [batch, heads, dim] - shape_k = [num_pages, page_block_size, heads_kv, dim] - shape_v = [num_pages, page_block_size, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_block_table = [batch, max_num_blocks_per_seq] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - assert block_N <= page_block_size and page_block_size % block_N == 0 - block_ratio = page_block_size // block_N - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: 
T.Tensor([batch], T.int32), - block_table: T.Tensor(shape_block_table, T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - logical_block_idx = block_indices[bid, cur_kv_head, start + k] - if logical_block_idx >= 0: - has_valid_block = True - block_table_idx = T.floordiv(logical_block_idx, block_ratio) - block_tile_idx = T.floormod(logical_block_idx, block_ratio) - physical_block_idx = block_table[bid, block_table_idx] - T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + # flash_attn_split + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + 
cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + logical_block_idx = block_indices[bid, cur_kv_head, start + k] + if logical_block_idx >= 0: + has_valid_block = True + block_table_idx = T.floordiv(logical_block_idx, block_ratio) + block_tile_idx = T.floormod(logical_block_idx, block_ratio) + physical_block_idx = block_table[bid, block_table_idx] + T.copy(K[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else( + logical_block_idx * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j] + ) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[physical_block_idx, block_tile_idx * block_N : (block_tile_idx + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - 
T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: lse_local_split = glse[bz, by, k] - if lse_local_split != 0: - max_split = k - lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split: - lse_local_split = glse[bz, by, k] - lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): - if k <= max_split: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split = glse[bz, by, k] - scale_local = T.exp2(lse_local_split - lse_logsum_local) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main - - return kernel_func + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] + lse_local_split = glse[bz, by, k] + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + print(main) + return main class SparseFlashAttn(torch.nn.Module): @@ -181,19 +179,6 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, self.page_block_size = page_block_size self.num_pages = num_pages self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_N, - block_H=self.block_H, - page_block_size=page_block_size, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - num_pages=num_pages, - max_num_blocks_per_seq=T.dynamic("max_num_blocks_per_seq"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -221,16 +206,19 @@ def forward(self, query, key, value, block_indices, cache_seqlens, block_table): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = 
self.kernel( - query, - key, - value, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + page_block_size=self.page_block_size, + num_stages=2, + threads=128, + num_pages=self.num_pages, + )(query, key, value, block_indices, cache_seqlens, block_table, glse, output_partial) return output @@ -513,6 +501,8 @@ def main(args): def run_regression_perf(args): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = ( args.batch, args.heads, @@ -524,15 +514,15 @@ def run_regression_perf(args): sparse_ratio = args.sparse_ratio block_N = args.block_N page_block_size = args.page_block_size - num_blocks = args.num_pages + num_pages = args.num_pages max_selected_blocks = int(math.ceil(max_cache_seqlen / block_N)) dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") cache_seqlens = torch.randint(max_cache_seqlen // 2, max_cache_seqlen + 1, (batch,), dtype=torch.int32, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") - K_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") - V_cache = torch.zeros((num_blocks, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") + K_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim), dtype=dtype, device="cuda") + V_cache = torch.zeros((num_pages, page_block_size, heads_kv, dim_v), dtype=dtype, device="cuda") max_num_blocks_per_seq = int(math.ceil(max_cache_seqlen / page_block_size)) block_table = torch.zeros((batch, max_num_blocks_per_seq), dtype=torch.int32, device="cuda") block_indices = torch.zeros((batch, heads_kv, max_selected_blocks), dtype=torch.int32, device="cuda") @@ -596,22 +586,20 @@ def run_regression_perf(args): for i in range(len(selected_blocks), max_selected_blocks): block_indices[seq_idx, head_idx, i] = -1 - sparse_attn = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_blocks) - kernel = sparse_attn.kernel - batch = sparse_attn.batch - heads = sparse_attn.heads - heads_kv = sparse_attn.heads_kv - dim_v = sparse_attn.dim_v - dim = sparse_attn.dim - block_size = sparse_attn.block_N + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, page_block_size, block_N, num_pages) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_N max_selected_blocks = block_indices.shape[-1] - num_m_blocks = 1 * (heads // heads_kv + sparse_attn.block_H - 1) // sparse_attn.block_H + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - - num_sm = sparse_attn.num_sm + num_sm = sparse_kernel.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 @@ -619,18 +607,22 @@ def run_regression_perf(args): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = 
flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + page_block_size=sparse_kernel.page_block_size, + num_stages=2, + threads=128, + num_pages=sparse_kernel.num_pages, + ) def run_kernel_only(): - kernel( - Q, - K_cache, - V_cache, - block_indices, - cache_seqlens, - block_table, - glse, - output_partial, - ) + kernel(Q, K_cache, V_cache, block_indices, cache_seqlens, block_table, glse, output_partial) return do_bench(run_kernel_only, backend="cupti") diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py index d6cf7d9176..b608adfae9 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_indice.py @@ -10,153 +10,150 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, max_selected_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_indices = [batch, heads_kv, max_selected_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_indices: T.Tensor(shape_indices, T.int32), - cache_seqlens: T.Tensor([batch], T.int32), - # actual_num_blocks: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - Output: T.Tensor(shape_o, dtype), - ): - # flash_attn_split(Q, K, V, block_indices, cache_seqlens, actual_num_blocks, glse, Output_partial) - # flash_attn_split - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - # O_shared = T.alloc_shared([valid_block_H, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - num_blocks = max_selected_blocks - 
blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - - for k in T.Pipelined(loop_range, num_stages=num_stages): - i_s = block_indices[bid, cur_kv_head, start + k] - if i_s >= 0: - has_valid_block = True - T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + max_selected_blocks = T.dynamic("max_selected_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_indices = [batch, heads_kv, max_selected_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_indices: T.Tensor(shape_indices, T.int32), + cache_seqlens: T.Tensor([batch], T.int32), + # actual_num_blocks: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + num_blocks = max_selected_blocks + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + + for k in T.Pipelined(loop_range, num_stages=num_stages): + i_s = block_indices[bid, cur_kv_head, start + k] + if i_s >= 0: + has_valid_block = True + 
T.copy(K[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + if k == 0: # assume block_indices is sorted in reverse order, otherwise, remove this if condition for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] - + acc_s[i, j] = T.if_then_else(i_s * block_N + j >= cache_seqlens[bid], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, i_s * block_N : (i_s + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - # combine - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - max_split = T.alloc_var(T.int32) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in 
T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k + lse_max_local = T.max(lse_max_local, glse[bz, by, k]) + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: + lse_local_split = glse[bz, by, k] + lse_logsum_local += T.exp2(lse_local_split - lse_max_local) + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: + for i in T.Parallel(dim_v): + po_local[i] = Output_partial[bz, by, k, i] lse_local_split = glse[bz, by, k] - if lse_local_split != 0: - max_split = k - lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - - for k in T.Pipelined(num_split, num_stages=1): - if k <= max_split: - lse_local_split = glse[bz, by, k] - lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): - if k <= max_split: - for i in T.Parallel(dim_v): - po_local[i] = Output_partial[bz, by, k, i] - lse_local_split = glse[bz, by, k] - scale_local = T.exp2(lse_local_split - lse_logsum_local) - for i in T.Parallel(dim_v): - o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main - - return kernel_func + scale_local = T.exp2(lse_local_split - lse_logsum_local) + for i in T.Parallel(dim_v): + o_accum_local[i] += po_local[i] * scale_local + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] + + return main class SparseFlashAttn(torch.nn.Module): @@ -168,19 +165,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -208,7 +193,18 @@ def forward(self, query, key, value, block_indices, cache_seqlens): glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + ) + output = kernel(query, key, value, block_indices, cache_seqlens, glse, output_partial) return output @@ -252,14 +248,16 @@ def sparse_gqa_decode_varlen_indice(query, key, value, block_indices, cache_seql glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - max_selected_blocks=T.dynamic("max_selected_blocks"), ) output = kernel(query, key, value, block_indices, cache_seqlens, glse, Output_partial) @@ -311,7 +309,7 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return 
output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: @@ -324,29 +322,17 @@ def debug(name, expect, actual, atol=1e-3, rtol=1e-3): def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v + dtype = torch.float16 sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) - print("max_selected_blocks: ", max_selected_blocks) - dtype = torch.float16 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') - # # Ensure at least one element equals cache_seqlen - # random_index = torch.randint(0, batch, (1,), device='cuda').item() # Select a random index - # # cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) - # Initialize block_indices with -1 (for padding blocks) block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") - # max_num_blocks = int((max_cache_seqlen + block_size - 1)/ block_size) - # block_indices = torch.full((batch, heads_kv, max_num_blocks), -1, dtype=torch.int32, device='cuda') # Assign valid indices while ensuring no duplicates within each batch-group for b in range(batch): @@ -354,27 +340,17 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 if max_valid_block > 0: # Ensure there's at least one valid block for h in range(heads_kv): valid_indices = torch.randperm(max_valid_block, device="cuda", dtype=torch.int32)[:max_selected_blocks] - # valid_indices = torch.randperm(max_valid_block, device='cuda', dtype=torch.int32)[:max_num_blocks] block_indices[b, h, : len(valid_indices)] = valid_indices - # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) # parity reference ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) - - import flash_attn # noqa: F401 + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) ## latency reference for _ in range(10): @@ -387,12 +363,10 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, 
dim=128, dim_v=12 print("dense time: ", (time.time() - start) / 100 * 1000) for _ in range(10): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() start = time.time() for _ in range(100): - # out = sparse_gqa_decode_varlen_indice(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, block_size) out = sparse_kernel(Q, K, V, block_indices, cache_seqlens) torch.cuda.synchronize() print("sparse time: ", (time.time() - start) / 100 * 1000) @@ -428,24 +402,9 @@ def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, di dim_v = sparse_kernel.dim_v dim = sparse_kernel.dim block_size = sparse_kernel.block_size - max_selected_blocks = block_indices.shape[-1] - - num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H - num_n_blocks = max_selected_blocks - size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 - total_mblocks = batch * heads_kv * num_m_blocks - num_sm = sparse_kernel.num_sm - - num_split = num_splits_heuristic( - total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 - ) - - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = sparse_kernel.kernel def run_kernel_only(): - kernel(Q, K, V, block_indices, cache_seqlens, glse, output_partial) + sparse_kernel(Q, K, V, block_indices, cache_seqlens) return do_bench(run_kernel_only, backend="cupti") diff --git a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py index e48428fb89..e588ec54cc 100644 --- a/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py +++ b/examples/blocksparse_attention/example_tilelang_sparse_gqa_decode_varlen_mask.py @@ -1,7 +1,6 @@ import torch import torch.nn.functional as F import tilelang -from tilelang.autotuner import * import tilelang.language as T from einops import rearrange, einsum import argparse @@ -11,137 +10,144 @@ from tilelang.profiler import do_bench -def flashattn(batch, heads, heads_kv, dim, dim_v): +@tilelang.jit( + out_idx=[-1], + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def flashattn(batch, heads, heads_kv, dim, dim_v, block_N, block_H, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) dtype = T.float16 accum_dtype = T.float32 kv_group_num = heads // heads_kv - @tilelang.jit( - out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, - ) - def kernel_func(block_N, block_H, num_split, num_stages, threads, max_cache_seqlen, num_blocks): - shape_q = [batch, heads, dim] - shape_k = [batch, max_cache_seqlen, heads_kv, dim] - shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] - shape_mask = [batch, heads_kv, num_blocks] - shape_o = [batch, heads, dim_v] - part_shape = [batch, heads, num_split, dim_v] - valid_block_H = min(block_H, kv_group_num) - - @T.prim_func - def main( - Q: T.Tensor(shape_q, dtype), - K: T.Tensor(shape_k, dtype), - V: T.Tensor(shape_v, dtype), - block_mask: T.Tensor(shape_mask, T.bool), - cache_seqlens: T.Tensor([batch], T.int32), - glse: T.Tensor([batch, heads, num_split], accum_dtype), - Output_partial: T.Tensor(part_shape, accum_dtype), - 
Output: T.Tensor(shape_o, dtype), - ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): - Q_shared = T.alloc_shared([block_H, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim_v], dtype) - acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) - acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) - - scores_max = T.alloc_fragment([block_H], accum_dtype) - scores_max_prev = T.alloc_fragment([block_H], accum_dtype) - scores_scale = T.alloc_fragment([block_H], accum_dtype) - scores_sum = T.alloc_fragment([block_H], accum_dtype) - logsum = T.alloc_fragment([block_H], accum_dtype) - has_valid_block = T.alloc_var("bool") - - bid = bx - hid = by - sid = bz - cur_kv_head = hid // (kv_group_num // valid_block_H) - - T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - blocks_per_split = T.floordiv(num_blocks, num_split) - remaining_blocks = T.floormod(num_blocks, num_split) - loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) - start = blocks_per_split * sid + T.min(sid, remaining_blocks) - has_valid_block = False - for k in T.Pipelined(loop_range, num_stages=num_stages): - if block_mask[bid, hid, start + k]: - has_valid_block = True - T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) - T.clear(acc_s) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.if_then_else( - (start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j] - ) - T.copy(scores_max, scores_max_prev) - T.fill(scores_max, -T.infinity(accum_dtype)) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_H): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_H, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_H): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - T.copy(acc_s, acc_s_cast) - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] *= scores_scale[i] - T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - if has_valid_block: - for i, j in T.Parallel(block_H, dim_v): - acc_o[i, j] /= logsum[i] + num_split = T.dynamic("num_split") + max_cache_seqlen = T.dynamic("max_cache_seqlen") + num_blocks = T.dynamic("num_blocks") + + shape_q = [batch, heads, dim] + shape_k = [batch, max_cache_seqlen, heads_kv, dim] + shape_v = [batch, max_cache_seqlen, heads_kv, dim_v] + shape_mask = [batch, heads_kv, num_blocks] + shape_o = [batch, heads, dim_v] + part_shape = [batch, heads, num_split, dim_v] + valid_block_H = min(block_H, kv_group_num) + + @T.prim_func + def main( + Q: T.Tensor(shape_q, dtype), + K: T.Tensor(shape_k, dtype), + V: T.Tensor(shape_v, dtype), + block_mask: T.Tensor(shape_mask, T.bool), + cache_seqlens: T.Tensor([batch], T.int32), + glse: T.Tensor([batch, heads, num_split], accum_dtype), + Output_partial: T.Tensor(part_shape, accum_dtype), + Output: T.Tensor(shape_o, dtype), + ): + with T.Kernel(batch, 
heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_H, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim_v], dtype) + acc_s = T.alloc_fragment([block_H, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_H, block_N], dtype) + acc_o = T.alloc_fragment([block_H, dim_v], accum_dtype) + + scores_max = T.alloc_fragment([block_H], accum_dtype) + scores_max_prev = T.alloc_fragment([block_H], accum_dtype) + scores_scale = T.alloc_fragment([block_H], accum_dtype) + scores_sum = T.alloc_fragment([block_H], accum_dtype) + logsum = T.alloc_fragment([block_H], accum_dtype) + has_valid_block = T.alloc_var(T.bool) + + bid = bx + hid = by + sid = bz + cur_kv_head = hid // (kv_group_num // valid_block_H) + + T.copy(Q[bid, hid * valid_block_H : hid * valid_block_H + block_H, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + blocks_per_split = T.floordiv(num_blocks, num_split) + remaining_blocks = T.floormod(num_blocks, num_split) + loop_range = blocks_per_split + T.if_then_else(sid < remaining_blocks, 1, 0) + start = blocks_per_split * sid + T.min(sid, remaining_blocks) + has_valid_block = False + for k in T.Pipelined(loop_range, num_stages=num_stages): + if block_mask[bid, hid, start + k]: + has_valid_block = True + T.copy(K[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], K_shared) + T.clear(acc_s) + T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.if_then_else((start + k) * block_N + j >= cache_seqlens[bx], -T.infinity(accum_dtype), acc_s[i, j]) + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) for i in T.Parallel(block_H): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - - for i in T.Parallel(block_H): - if i < valid_block_H: - glse[bid, hid * valid_block_H + i, sid] = logsum[i] - + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_H, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_H): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + T.copy(acc_s, acc_s_cast) + for i, j in T.Parallel(block_H, dim_v): + acc_o[i, j] *= scores_scale[i] + T.copy(V[bid, (start + k) * block_N : (start + k + 1) * block_N, cur_kv_head, :], V_shared) + T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + if has_valid_block: for i, j in T.Parallel(block_H, dim_v): - if i < valid_block_H: - Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] - - with T.Kernel(heads, batch, threads=128) as (by, bz): - po_local = T.alloc_fragment([dim_v], accum_dtype) - o_accum_local = T.alloc_fragment([dim_v], accum_dtype) - lse_local_split = T.alloc_var(accum_dtype) - lse_logsum_local = T.alloc_var(accum_dtype) - lse_max_local = T.alloc_var(accum_dtype) - scale_local = T.alloc_var(accum_dtype) - - T.clear(lse_logsum_local) - T.clear(o_accum_local) - lse_max_local = -T.infinity(accum_dtype) - for k in T.serial(num_split): + acc_o[i, j] /= logsum[i] + for i in T.Parallel(block_H): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + + # TODO(lei): Support T.Parallel(valid_block_H) + for i in 
T.Parallel(block_H): + if i < valid_block_H: + glse[bid, hid * valid_block_H + i, sid] = logsum[i] + for i, j in T.Parallel(block_H, dim_v): + if i < valid_block_H: + Output_partial[bid, hid * valid_block_H + i, sid, j] = acc_o[i, j] + + # combine + with T.Kernel(heads, batch, threads=128) as (by, bz): + po_local = T.alloc_fragment([dim_v], accum_dtype) + o_accum_local = T.alloc_fragment([dim_v], accum_dtype) + lse_local_split = T.alloc_var(accum_dtype) + lse_logsum_local = T.alloc_var(accum_dtype) + lse_max_local = T.alloc_var(accum_dtype) + scale_local = T.alloc_var(accum_dtype) + max_split = T.alloc_var(T.int32) + + T.clear(lse_logsum_local) + T.clear(o_accum_local) + lse_max_local = -T.infinity(accum_dtype) + for k in T.serial(num_split): + lse_local_split = glse[bz, by, k] + if lse_local_split != 0: + max_split = k lse_max_local = T.max(lse_max_local, glse[bz, by, k]) - for k in T.Pipelined(num_split, num_stages=1): + + for k in T.Pipelined(num_split, num_stages=1): + if k <= max_split: lse_local_split = glse[bz, by, k] lse_logsum_local += T.exp2(lse_local_split - lse_max_local) - lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local - for k in T.serial(num_split): + lse_logsum_local = T.log2(lse_logsum_local) + lse_max_local + for k in T.serial(num_split): + if k <= max_split: for i in T.Parallel(dim_v): po_local[i] = Output_partial[bz, by, k, i] lse_local_split = glse[bz, by, k] scale_local = T.exp2(lse_local_split - lse_logsum_local) for i in T.Parallel(dim_v): o_accum_local[i] += po_local[i] * scale_local - for i in T.Parallel(dim_v): - Output[bz, by, i] = o_accum_local[i] - - return main + for i in T.Parallel(dim_v): + Output[bz, by, i] = o_accum_local[i] - return kernel_func + return main class SparseFlashAttn(torch.nn.Module): @@ -153,19 +159,7 @@ def __init__(self, batch, heads, heads_kv, dim, dim_v, block_size): self.dim = dim self.dim_v = dim_v self.block_size = block_size - self.block_H = 64 - - self.kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( - block_N=block_size, - block_H=self.block_H, - num_split=T.dynamic("num_split"), - num_stages=2, - threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), - ) - props = torch.cuda.get_device_properties(torch.device("cuda:0")) self.num_sm = props.multi_processor_count @@ -176,24 +170,33 @@ def forward(self, query, key, value, block_mask, cache_seqlens): dim_v = self.dim_v dim = self.dim block_size = self.block_size - block_H = self.block_H max_cache_seqlen = key.shape[1] # get num_split max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H + num_m_blocks = 1 * (heads // heads_kv + self.block_H - 1) // self.block_H num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 # kv_seqlen * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - # num_sm = 132 num_sm = self.num_sm num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - # print("num_split: ", num_split) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - output = self.kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + + 
output = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=self.block_H, + num_stages=2, + threads=128, + )(query, key, value, block_mask, cache_seqlens, glse, output_partial) return output @@ -233,21 +236,21 @@ def sparse_gqa_decode_varlen_mask(query, key, value, block_mask, cache_seqlens, total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) - kernel = flashattn(batch, heads, heads_kv, dim, dim_v)( + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") + Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, block_N=block_size, block_H=block_H, - num_split=T.dynamic("num_split"), num_stages=2, threads=128, - max_cache_seqlen=T.dynamic("max_cache_seqlen"), - num_blocks=T.dynamic("num_blocks"), ) - glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - # print(kernel.get_kernel_source()) output = kernel(query, key, value, block_mask, cache_seqlens, glse, Output_partial) - return output @@ -297,12 +300,10 @@ def ref_program_fa(query, key, value, block_indices, cache_seqlens, max_cache_se return output -def debug(name, expect, actual, atol=1e-3, rtol=1e-3): +def assert_close(name, expect, actual, atol=1e-3, rtol=1e-3): all_close = torch.allclose(expect, actual, atol=atol, rtol=rtol) print(name + " all_close={}".format(all_close)) if not all_close: - # print(expect[3, 28]) - # print(actual[3, 28]) diff = (expect - actual).abs() print("all_close={}, max={}, min={}, mean={}".format(all_close, diff.max().item(), diff.min().item(), diff.mean().item())) max_indices = torch.nonzero(diff == diff.max().item()) @@ -353,7 +354,7 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 # out = sparse_gqa_decode_varlen_mask(Q, K, V, block_mask, cache_seqlens, block_size) model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) out = model(Q, K, V, block_mask, cache_seqlens) - debug("output", ref, out, atol=1e-3, rtol=1e-3) + assert_close("output", ref, out, atol=1e-3, rtol=1e-3) import flash_attn # noqa: F401 @@ -381,12 +382,13 @@ def main(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=12 def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32): + torch.manual_seed(42) + torch.cuda.manual_seed_all(42) batch, heads, heads_kv, max_cache_seqlen, dim, dim_v = batch, heads, heads_kv, max_cache_seqlen, dim, dim_v sparse_ratio = sparse_ratio block_size = block_size max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) dtype = torch.float16 - Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") @@ -408,31 +410,41 @@ def run_regression_perf(batch=8, heads=32, heads_kv=8, max_cache_seqlen=8192, di perm = torch.randperm(max_valid_block, device="cuda")[:valid_num_block] block_mask[b, h, perm] = True - model = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) - batch = model.batch - heads = model.heads - heads_kv = model.heads_kv - dim_v = model.dim_v - dim = model.dim - block_size = 
model.block_size - block_H = model.block_H - max_cache_seqlen = K.shape[1] + sparse_kernel = SparseFlashAttn(batch, heads, heads_kv, dim, dim_v, block_size) + batch = sparse_kernel.batch + heads = sparse_kernel.heads + heads_kv = sparse_kernel.heads_kv + dim_v = sparse_kernel.dim_v + dim = sparse_kernel.dim + block_size = sparse_kernel.block_size max_selected_blocks = (max_cache_seqlen + block_size - 1) // block_size - num_m_blocks = 1 * (heads // heads_kv + block_H - 1) // block_H - num_n_blocks = max_selected_blocks + num_m_blocks = 1 * (heads // heads_kv + sparse_kernel.block_H - 1) // sparse_kernel.block_H + num_n_blocks = max_selected_blocks size_one_kv_head = max_selected_blocks * block_size * (dim + dim_v) * 2 total_mblocks = batch * heads_kv * num_m_blocks - num_sm = model.num_sm + num_sm = sparse_kernel.num_sm + num_split = num_splits_heuristic( total_mblocks, num_sm, num_n_blocks, num_m_blocks, size_one_kv_head, is_causal_or_local=True, max_splits=128 ) + glse = torch.empty((batch, heads, num_split), dtype=torch.float32, device="cuda") - Output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") - kernel = model.kernel + output_partial = torch.empty((batch, heads, num_split, dim_v), dtype=torch.float32, device="cuda") + kernel = flashattn( + batch, + heads, + heads_kv, + dim, + dim_v, + block_N=block_size, + block_H=sparse_kernel.block_H, + num_stages=2, + threads=128, + ) def run_kernel_only(): - kernel(Q, K, V, block_mask, cache_seqlens, glse, Output_partial) + kernel(Q, K, V, block_mask, cache_seqlens, glse, output_partial) return do_bench(run_kernel_only, backend="cupti") diff --git a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py index 01695742b5..91d85a1a43 100644 --- a/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py +++ b/examples/blocksparse_attention/example_triton_sparse_gqa_decode_varlen_indice.py @@ -329,21 +329,15 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 max_selected_blocks = int(math.ceil(max_cache_seqlen * (1 - sparse_ratio) / block_size)) print("max_selected_blocks: ", max_selected_blocks) dtype = torch.float16 - block_H = 64 Q = torch.randn((batch, heads, dim), dtype=dtype, device="cuda") K = torch.randn((batch, max_cache_seqlen, heads_kv, dim), dtype=dtype, device="cuda") V = torch.randn((batch, max_cache_seqlen, heads_kv, dim_v), dtype=dtype, device="cuda") cache_seqlens = torch.randint(1, max_cache_seqlen, (batch,), dtype=torch.int32, device="cuda") - # cache_seqlens = torch.full((batch,), max_cache_seqlen, dtype=torch.int32, device='cuda') # Ensure at least one element equals cache_seqlen random_index = torch.randint(0, batch, (1,), device="cuda").item() # Select a random index cache_seqlens[random_index] = max_cache_seqlen # Assign cache_seqlen to ensure at least one occurrence - - print("cache_seqlens: ", cache_seqlens) - max_valid_num_blocks = torch.ceil(cache_seqlens / block_size).int() - print("max_valid_num_blocks: ", max_valid_num_blocks) # Initialize block_indices with -1 (for padding blocks) block_indices = torch.full((batch, heads_kv, max_selected_blocks), -1, dtype=torch.int32, device="cuda") @@ -357,13 +351,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 # Sort indices within each batch-group for consistency block_indices, _ = block_indices.sort(dim=-1, descending=True) - # 
print("block_indices: ", block_indices) - actual_num_blocks = torch.sum(block_indices != -1, dim=-1).to(torch.int32)[:, 0] - print("actual_num_blocks: ", actual_num_blocks) - # print(block_indices.shape, actual_num_blocks.shape) - max_num_blocks = torch.max(max_valid_num_blocks).item() - print("max_num_blocks: ", max_num_blocks) ref = ref_program_torch(Q, K, V, block_indices, cache_seqlens, max_cache_seqlen, max_num_blocks, block_size) @@ -402,6 +390,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 avg_time = elapsed_time / 1000 avg_flops = total_flops / avg_time print(f"Average time: {avg_time:.6f} seconds") + print(f"Average FLOPS: {avg_flops:.2f} GFLOPS") # Measure performance of reference implementation import flash_attn # noqa: F401 @@ -415,7 +404,7 @@ def main(batch=64, heads=32, heads_kv=8, max_cache_seqlen=8192, dim=128, dim_v=1 avg_time_ref = elapsed_time_ref / 1000 avg_flops_ref = total_flops / avg_time_ref print(f"Average time of ref: {avg_time_ref:.6f} seconds") - + print(f"Average FLOPS of ref: {avg_flops_ref:.2f} GFLOPS") print(f"Speedup: {avg_time_ref / avg_time:.2f}x") diff --git a/examples/blocksparse_attention/test_example_blocksparse_attention.py b/examples/blocksparse_attention/test_example_blocksparse_attention.py index dd33f46c4e..0144f2004a 100644 --- a/examples/blocksparse_attention/test_example_blocksparse_attention.py +++ b/examples/blocksparse_attention/test_example_blocksparse_attention.py @@ -3,8 +3,6 @@ import example_tilelang_block_sparse_attn import example_tilelang_sparse_gqa_decode_varlen_indice import example_tilelang_sparse_gqa_decode_varlen_mask -import example_triton_sparse_gqa_decode_varlen_indice -import example_triton_sparse_gqa_decode_varlen_mask def test_block_sparse_attn_triton(): @@ -23,17 +21,5 @@ def test_example_tilelang_sparse_gqa_decode_varlen_mask(): example_tilelang_sparse_gqa_decode_varlen_mask.main(batch=1, max_cache_seqlen=2048) -def test_example_triton_sparse_gqa_decode_varlen_indice(): - example_triton_sparse_gqa_decode_varlen_indice.main( - batch=8, heads=8, heads_kv=4, max_cache_seqlen=2048, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 - ) - - -def test_example_triton_sparse_gqa_decode_varlen_mask(): - example_triton_sparse_gqa_decode_varlen_mask.main( - batch=16, heads=16, heads_kv=8, max_cache_seqlen=1024, dim=128, dim_v=128, sparse_ratio=0.8, block_size=32 - ) - - if __name__ == "__main__": tilelang.testing.main() diff --git a/examples/blocksparse_gemm/example_blocksparse_gemm.py b/examples/blocksparse_gemm/example_blocksparse_gemm.py index 178cc59842..9defb72882 100644 --- a/examples/blocksparse_gemm/example_blocksparse_gemm.py +++ b/examples/blocksparse_gemm/example_blocksparse_gemm.py @@ -2,10 +2,8 @@ import itertools import tilelang import tilelang.language as T -from tilelang.engine.param import KernelParam from tilelang.utils.tensor import get_tensor_supply, TensorSupplyType import torch -from typing import List from tilelang.profiler import do_bench DEFAULT_BLOCK_M = 128 @@ -14,24 +12,8 @@ DEFAULT_NUM_STAGES = 2 DEFAULT_THREAD_NUM = 128 DEFAULT_ENABLE_RASTERIZATION = True - -parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") -parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") -parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") -parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") -parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio 
(0-1)") -parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") - -args, _ = parser.parse_known_args() -M, N, K = args.m, args.n, args.k -sparsity = args.sparsity -use_autotune = args.use_autotune default_tensor_supply = get_tensor_supply(TensorSupplyType.Auto) -print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") -print(f"Target Block Sparsity: {sparsity}") -print(f"Using Autotuner: {use_autotune}\n") - def get_configs(): block_M = [64, 128, 256] @@ -57,6 +39,8 @@ def get_configs(): def ref_program(A, B, BlockMask, block_M, block_N, block_K): + M, K = A.shape + _, N = B.shape ref_c = torch.zeros((M, N), dtype=torch.float16, device=A.device) for i in range(M // block_M): for j in range(N // block_N): @@ -70,25 +54,6 @@ def ref_program(A, B, BlockMask, block_M, block_N, block_K): return ref_c -def supply_program(params: List[KernelParam]): - input_tensors = [] - - for p in params: - # Check if the kernel parameter is BlockMask tensor. - # Here, BlockMask is uniquely identified by having 3 dimensions. - if len(p.shape) != 3: - # For non-BlockMask tensors, use the default tensor generation logic. - input_tensors.append(default_tensor_supply(p)) - else: - # For BlockMask tensor, randomly set elements to True based on desired - # sparsity level. - block_mask = torch.zeros(p.shape, dtype=torch.bool, device=torch.cuda.current_device()) - block_mask[:, :, :] = torch.rand(p.shape) > sparsity - input_tensors.append(block_mask) - - return input_tensors - - @tilelang.autotune( configs=get_configs(), ) @@ -127,6 +92,20 @@ def block_sparse_matmul( def main(): + parser = argparse.ArgumentParser(description="Autotuned BlockSparse MatMul Benchmark") + parser.add_argument("--m", type=int, default=1024, help="Matrix dimension M") + parser.add_argument("--n", type=int, default=1024, help="Matrix dimension N") + parser.add_argument("--k", type=int, default=1024, help="Matrix dimension K") + parser.add_argument("--sparsity", type=float, default=0.5, help="Sparsity ratio (0-1)") + parser.add_argument("--use_autotune", action="store_true", default=False, help="Whether to use autotune") + + args, _ = parser.parse_known_args() + M, N, K = args.m, args.n, args.k + sparsity = args.sparsity + use_autotune = args.use_autotune + print(f"Running BlockSparse MatMul Benchmark for M={M}, N={N}, K={K}") + print(f"Target Block Sparsity: {sparsity}") + print(f"Using Autotuner: {use_autotune}\n") # Initialize input matrices A and B on the GPU with half precision a = torch.randn(M, K).cuda().half() b = torch.randn(K, N).cuda().half() @@ -158,6 +137,7 @@ def main(): ) block_M, block_N, block_K = DEFAULT_BLOCK_M, DEFAULT_BLOCK_N, DEFAULT_BLOCK_K print(f"Using default kernel with block size ({block_M}, {block_N}, {block_K})") + # Create block mask with desired sparsity mask_shape = (M // block_M, N // block_N, K // block_K) block_mask = torch.rand(mask_shape).cuda() > sparsity @@ -177,6 +157,8 @@ def main(): def run_regression_perf(): + M = N = K = 1024 + sparsity = 0.5 torch.manual_seed(42) torch.cuda.manual_seed_all(42) a = torch.randn(M, K).cuda().half() diff --git a/examples/cast/example_per_token_cast_to_fp8.py b/examples/cast/example_per_token_cast_to_fp8.py index 4b3730b4b9..693e90d30a 100644 --- a/examples/cast/example_per_token_cast_to_fp8.py +++ b/examples/cast/example_per_token_cast_to_fp8.py @@ -92,21 +92,15 @@ def main(M=8192, N=8192, blk_m=8): print("Tile-lang: {:.2f} ms".format(latency)) from tilelang.profiler import do_bench + from 
example_triton_cast_to_fp8 import per_token_group_quant_fp8 - # Triton fp8e4nv is only supported on Hopper (SM90) and later - major, _ = torch.cuda.get_device_capability() - if major >= 9: - from example_triton_cast_to_fp8 import per_token_group_quant_fp8 + def run_triton(): + x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) + return x_fp8_triton_, x_amax_triton_ - def run_triton(): - x_fp8_triton_, x_amax_triton_ = per_token_group_quant_fp8(x, 128, 1e-4, dtype=torch.float8_e4m3fn, column_major_scales=False) - return x_fp8_triton_, x_amax_triton_ - - x_fp8_triton, x_amax_triton = run_triton() - latency = do_bench(run_triton) - print("Triton: {:.2f} ms".format(latency)) - else: - print("Triton fp8e4nv benchmark skipped (requires SM90+)") + x_fp8_triton, x_amax_triton = run_triton() + latency = do_bench(run_triton) + print("Triton: {:.2f} ms".format(latency)) def run_regression_perf(M=8192, N=8192, blk_m=8): diff --git a/examples/conftest.py b/examples/conftest.py index 4010e0d83a..afc122b6c2 100644 --- a/examples/conftest.py +++ b/examples/conftest.py @@ -21,6 +21,42 @@ np.random.seed(0) +# --------------------------------------------------------------------------- +# CuTeDSL backend: auto-mark known failures / unsupported tests +# --------------------------------------------------------------------------- + +# Known failures when running with TILELANG_TARGET=cutedsl. +# These are marked as xfail(strict=False) so unexpected passes are reported. +CUTEDSL_KNOWN_FAILURES = { + # Unimplemented sparse ops: tl.tl_gemm_sp + "sparse_tensorcore/test_example_sparse_tensorcore.py::test_tilelang_example_sparse_tensorcore", + "gemm_sp/test_example_gemm_sp.py::test_example_gemm_sp", + # Flaky — passes when run in isolation, fails under parallel execution + "minference/test_vs_sparse_attn.py::test_vs_sparse_attn", +} + + +def _match_any(nodeid, patterns): + """Return True if *nodeid* contains any of the *patterns*.""" + return any(p in nodeid for p in patterns) + + +def pytest_collection_modifyitems(config, items): # noqa: ARG001 + """When TILELANG_TARGET=cutedsl, annotate known-bad tests automatically.""" + if os.environ.get("TILELANG_TARGET") != "cutedsl": + return + + for item in items: + nid = item.nodeid + if _match_any(nid, CUTEDSL_KNOWN_FAILURES): + item.add_marker( + pytest.mark.xfail( + reason="CuTeDSL: known limitation (unimplemented op or flaky)", + strict=False, + ) + ) + + def pytest_terminal_summary(terminalreporter, exitstatus, config): """Ensure that at least one test is collected. 
Error out if all tests are skipped.""" known_types = { diff --git a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py index 18467a8118..22ce27de18 100644 --- a/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py +++ b/examples/deepseek_deepgemm/example_deepgemm_fp8_2xAcc.py @@ -4,7 +4,6 @@ import tilelang.testing import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type tilelang.testing.set_random_seed(42) @@ -28,7 +27,7 @@ def tl_gemm( ], "Currently only float16 and float32 are supported" group_size = 128 - block_M = 128 + block_M = 64 block_K = 128 A_shape = (M, K) @@ -51,7 +50,6 @@ def main( A_shared = T.alloc_shared(A_shared_shape, in_dtype) B_shared = T.alloc_shared(B_shared_shape, in_dtype) C_shared = T.alloc_shared(C_shared_shape, out_dtype) - Scale_C_shared = T.alloc_shared((block_M), T.float32) C_local = T.alloc_fragment(C_shared_shape, accum_dtype) C_local_accum = T.alloc_fragment(C_shared_shape, accum_dtype) @@ -66,15 +64,12 @@ def main( T.copy(A[by * block_M, k * block_K], A_shared) # Load B into shared memory T.copy(B[bx * block_N, k * block_K], B_shared) - # Load scale into shared memory Scale_B = scales_b[bx * block_N // group_size, k] - for i in T.Parallel(block_M): - Scale_C_shared[i] = scales_a[by * block_M + i, k] * Scale_B T.gemm(A_shared, B_shared, C_local, transpose_B=True) # Promote to enable 2xAcc for i, j in T.Parallel(block_M, block_N): - C_local_accum[i, j] += C_local[i, j] * Scale_C_shared[i] + C_local_accum[i, j] += C_local[i, j] * (scales_a[by * block_M + i, k] * Scale_B) T.clear(C_local) # TMA store T.copy(C_local_accum, C_shared) @@ -148,9 +143,9 @@ def assert_tl_gemm_correctness(M, N, K, block_N, in_dtype, out_dtype, accum_dtyp # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + accum_dtype = accum_dtype.as_torch() A = torch.randn(M, K).to(torch.bfloat16).cuda() B = torch.randn(N, K).to(torch.bfloat16).cuda() diff --git a/examples/deepseek_mhc/example_mhc_bwd.py b/examples/deepseek_mhc/example_mhc_bwd.py new file mode 100644 index 0000000000..2961d87952 --- /dev/null +++ b/examples/deepseek_mhc/example_mhc_bwd.py @@ -0,0 +1,283 @@ +# NOTE: This bwd script is not an official upstream script; it is community-written and provided for reference only. 
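+#
+# Informal sketch of the math this script implements (read off the code below,
+# not taken from an upstream reference): the forward pass Sinkhorn-normalizes
+# P = exp(M) into R. At the fixed point the row/column marginals are constant,
+# so implicit differentiation gives
+#     dL/dM = (dL/dR - x1_i - x2_j) * R,
+# where the corrections (x1, x2) solve the linear system
+#     x1 + R @ x2 = rowsum(R * dL/dR),
+#     x2 + R.T @ x1 = colsum(R * dL/dR).
+# The kernel solves this system tile-by-tile with conjugate gradient (matvec_A
+# and dot below are the CG building blocks), instead of backpropagating through
+# every Sinkhorn iteration.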
+# checkout pr: https://github.com/tile-ai/tilelang/pull/1758 +import torch + +import tilelang +import tilelang.language as T +from tilelang.autotuner import set_autotune_inputs +from tqdm import trange + + +dtype = torch.float32 + +seqlen = 65536 +n_stream = 16 +iters = 100 +repeat = 512 + +EPS = 1e-10 + + +def sinkhorn_forward(M, iters=20): + P = torch.exp(M) + R = P + + for _ in range(iters): + R = R / R.sum(-2, keepdim=True) + R = R / R.sum(-1, keepdim=True) + + return R, P + + +def sinkhorn_bwd_configs(n_stream, seqlen): + """Generate autotune configurations for different tilesize and threads""" + configs = [] + + # Explore different tile sizes and thread counts + tilesizes = [1, 2, 4, 8, 16, 32, 64] + thread_counts = [32, 64, 128, 256] + + for tilesize in tilesizes: + # Skip if tilesize doesn't divide seqlen evenly (optional constraint) + if seqlen % tilesize != 0: + continue + + for threads in thread_counts: + configs.append({"tilesize": tilesize, "threads": threads}) + + return configs + + +@tilelang.autotune( + configs=sinkhorn_bwd_configs(n_stream, seqlen), + warmup=4, + rep=repeat, +) +@tilelang.jit( + out_idx=[2], + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, +) +def sinkhorn_bwd_implicit_cg(n_stream: int, tilesize: int = 32, threads: int = 128): + seqlen = T.dynamic("seqlen") + tensor_shape = [seqlen, n_stream, n_stream] + dtype = T.float32 + + @T.macro + def matvec_A(R, x1, x2, buf, y1, y2): + for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream): + buf[i_tile, i, j] = R[i_tile, i, j] * x2[i_tile, j] + T.reduce_sum(buf, y1, dim=-1) + + for i_tile, i, j in T.Parallel(tilesize, n_stream, n_stream): + buf[i_tile, i, j] = R[i_tile, i, j] * x1[i_tile, i] + T.reduce_sum(buf, y2, dim=-2) + + for i_tile, i in T.Parallel(tilesize, n_stream): + y1[i_tile, i] += x1[i_tile, i] + y2[i_tile, i] += x2[i_tile, i] + + @T.macro + def dot(x1, x2, y1, y2, buf, out): + for i_tile, i in T.Parallel(tilesize, n_stream): + buf[i_tile, i] = x1[i_tile, i] * y1[i_tile, i] + x2[i_tile, i] * y2[i_tile, i] + + T.reduce_sum(buf, out, dim=-1) + + @T.prim_func + def main( + out: T.Tensor(tensor_shape, dtype), + dout: T.Tensor(tensor_shape, dtype), + res: T.Tensor(tensor_shape, dtype), + ): + with T.Kernel(T.ceildiv(seqlen, tilesize), threads=threads) as i_seq: + R = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) + dR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) + RdR = T.alloc_fragment([tilesize, n_stream, n_stream], dtype=dtype) + res_tile = T.alloc_shared([tilesize, n_stream, n_stream], dtype=dtype) + b1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + b2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + x1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + x2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + r1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + r2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + p1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + p2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + alpha = T.alloc_fragment([tilesize, n_stream], dtype=dtype) + beta = T.alloc_fragment([tilesize, n_stream], dtype=dtype) + r_normsq = T.alloc_fragment([tilesize], dtype=dtype) + r_new_normsq = T.alloc_fragment([tilesize], dtype=dtype) + Ap1 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + Ap2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + pAp = T.alloc_fragment([tilesize], dtype=dtype) + + # Buffers for intermediate results + buf1 = T.alloc_shared([tilesize, n_stream, n_stream], 
dtype=dtype) + buf2 = T.alloc_shared([tilesize, n_stream], dtype=dtype) + + T.copy(out[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], R) + T.copy(dout[i_seq * tilesize : (i_seq + 1) * tilesize, :, :], dR) + + for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream): + RdR[i_tile, i_nx, i_ny] = R[i_tile, i_nx, i_ny] * dR[i_tile, i_nx, i_ny] + + T.reduce_sum(RdR, b1, dim=-1) + T.reduce_sum(RdR, b2, dim=-2) + + T.fill(x1, 0.0) + T.fill(x2, 0.0) + + matvec_A(R, x1, x2, buf1, r1, r2) + + for i_tile, i_n in T.Parallel(tilesize, n_stream): + r1[i_tile, i_n] = b1[i_tile, i_n] - r1[i_tile, i_n] + + for i_tile, i_n in T.Parallel(tilesize, n_stream): + r2[i_tile, i_n] = b2[i_tile, i_n] - r2[i_tile, i_n] + + T.copy(r1, p1) + T.copy(r2, p2) + + dot(r1, r2, r1, r2, buf2, r_normsq) + + # Conjugate gradient: iteration starts + for _ in T.serial(2 * n_stream): + matvec_A(R, p1, p2, buf1, Ap1, Ap2) + + dot(p1, p2, Ap1, Ap2, buf2, pAp) + + for i_tile, i_n in T.Parallel(tilesize, n_stream): + # VERY important to avoid divide by zero + alpha[i_tile, i_n] = r_normsq[i_tile] / (pAp[i_tile] + EPS) + for i_tile, i_n in T.Parallel(tilesize, n_stream): + x1[i_tile, i_n] += alpha[i_tile, i_n] * p1[i_tile, i_n] + for i_tile, i_n in T.Parallel(tilesize, n_stream): + x2[i_tile, i_n] += alpha[i_tile, i_n] * p2[i_tile, i_n] + for i_tile, i_n in T.Parallel(tilesize, n_stream): + r1[i_tile, i_n] -= alpha[i_tile, i_n] * Ap1[i_tile, i_n] + for i_tile, i_n in T.Parallel(tilesize, n_stream): + r2[i_tile, i_n] -= alpha[i_tile, i_n] * Ap2[i_tile, i_n] + + dot(r1, r2, r1, r2, buf2, r_new_normsq) + + for i_tile, i_n in T.Parallel(tilesize, n_stream): + # not very important to avoid divide by zero, but it's good to have it + beta[i_tile, i_n] = r_new_normsq[i_tile] / (r_normsq[i_tile] + EPS) + for i_tile, i_n in T.Parallel(tilesize, n_stream): + p1[i_tile, i_n] = r1[i_tile, i_n] + beta[i_tile, i_n] * p1[i_tile, i_n] + for i_tile, i_n in T.Parallel(tilesize, n_stream): + p2[i_tile, i_n] = r2[i_tile, i_n] + beta[i_tile, i_n] * p2[i_tile, i_n] + + T.copy(r_new_normsq, r_normsq) + # Conjugate gradient: iteration ends + + for i_tile, i_nx, i_ny in T.Parallel(tilesize, n_stream, n_stream): + res_tile[i_tile, i_nx, i_ny] = (dR[i_tile, i_nx, i_ny] - x1[i_tile, i_nx] - x2[i_tile, i_ny]) * R[i_tile, i_nx, i_ny] + + T.copy(res_tile, res[i_seq * tilesize : (i_seq + 1) * tilesize, :, :]) + + return main + + +def main(): + print("Autotuning TileLang kernel for sinkhorn backward pass") + print(f"{seqlen = }") + print(f"{n_stream = }") + print(f"{iters = }") + print(f"{repeat = }") + + ###################################################################### + # Variable + ###################################################################### + dist = torch.distributions.uniform.Uniform(0.0, 4.0) + device = torch.device("cuda") + M = dist.sample((seqlen, n_stream, n_stream)).to(device) + M.requires_grad_() + + ###################################################################### + # Shared forward + one shared loss weight + ###################################################################### + R, P = sinkhorn_forward(M, iters) + loss_weight = torch.randn_like(R) + + ###################################################################### + # Method A: Autograd (reference) + ###################################################################### + loss_a = (R * loss_weight).sum() + loss_a.backward() + grad_M_autograd = M.grad.detach().clone() + + ###################################################################### + # Method B: Implicit 
differentiation with autotuning + ###################################################################### + grad_R = loss_weight + + print("\n" + "=" * 60) + print("Starting autotuning...") + print("=" * 60) + + # Set autotune inputs + with set_autotune_inputs(R, grad_R): + kernel = sinkhorn_bwd_implicit_cg(n_stream) + print(kernel.get_kernel_source()) + print("\n" + "=" * 60) + print("Autotuning completed! Running with best configuration...") + print("=" * 60) + + # Warmup and timing with best config + a = torch.randn(8192, 8192, device=device) + for _ in trange(4, desc="Warmup"): + _ = a @ a + grad_M_implicit = kernel(R, grad_R) + torch.cuda.synchronize() + + # Timing + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + + torch.cuda.synchronize() + start_event.record() + + for _ in range(repeat): + grad_M_implicit = kernel(R, grad_R) + + end_event.record() + torch.cuda.synchronize() + + elapsed_time_ms = start_event.elapsed_time(end_event) + + print(f"\nKernel execution time ({repeat = }): {elapsed_time_ms:.3f} ms") + print(f"Average time per iteration: {elapsed_time_ms / repeat:.3f} ms") + + ###################################################################### + # Compare + ###################################################################### + g1 = grad_M_autograd + g2 = grad_M_implicit + + abs_diff = (g1 - g2).abs() + # Use max of absolute values for more stable relative error + rel_diff = abs_diff / (torch.maximum(g1.abs(), g2.abs()) + 1e-8) + + print("\n" + "=" * 60) + print("Comparison of gradients dL/dM") + print("=" * 60) + + def format_list(ls): + return [f"{x:.2e}" for x in ls] + + MAE = abs_diff.mean(dim=(-1, -2)).tolist() + max_abs_diff = abs_diff.reshape(seqlen, -1).max(-1).values.tolist() + mean_rel_diff = rel_diff.mean(dim=(-1, -2)).tolist() + max_rel_diff = rel_diff.reshape(seqlen, -1).max(-1).values.tolist() + + print(f"Max MAE = {max(MAE):.6e}") + print(f"Max max_abs_diff = {max(max_abs_diff):.6e}") + print(f"Max mean_rel_diff = {max(mean_rel_diff):.6e}") + print(f"Max max_rel_diff = {max(max_rel_diff):.6e}") + + print("\nGrad (autograd) sample:\n", g1[0, :3, :3]) + print("\nGrad (implicit) sample:\n", g2[0, :3, :3]) + + +if __name__ == "__main__": + main() diff --git a/examples/deepseek_mhc/example_mhc_post.py b/examples/deepseek_mhc/example_mhc_post.py new file mode 100644 index 0000000000..9c9dc2f720 --- /dev/null +++ b/examples/deepseek_mhc/example_mhc_post.py @@ -0,0 +1,140 @@ +import math + +import torch + +import tilelang +import tilelang.language as T + + +@tilelang.jit( + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10}, +) +def mhc_post_tilelang(a, b, c, d, x, hc: int, hidden: int, n_thr: int = 128, h_blk: int = 1024) -> tilelang.JITKernel: + # rename for shorter code + n = T.dynamic("num_tokens") + h = hidden + + h_blk = math.gcd(hidden, h_blk) + a: T.Tensor((n, hc, hc), T.float32) + b: T.Tensor((n, hc, h), T.bfloat16) + c: T.Tensor((n, hc), T.float32) + d: T.Tensor((n, h), T.bfloat16) + x: T.Tensor((n, hc, h), T.bfloat16) + with T.Kernel(n, threads=n_thr) as i_n: + x_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + b_shared = T.alloc_shared((hc, h_blk), T.bfloat16) + d_shared = T.alloc_shared(h_blk, T.bfloat16) + + x_local = T.alloc_fragment((hc, h_blk), T.float32) + b_local = T.alloc_fragment((hc, h_blk), T.float32) + d_local = T.alloc_fragment(h_blk, T.float32) + + a_local = T.alloc_fragment((hc, hc), T.float32) + c_local 
= T.alloc_fragment(hc, T.float32)
+        T.copy(a[i_n, 0, 0], a_local)
+        T.copy(c[i_n, 0], c_local)
+
+        for i0_h in T.Pipelined(T.ceildiv(h, h_blk), num_stages=2):
+            T.copy(b[i_n, 0, i0_h * h_blk], b_shared)
+            T.copy(d[i_n, i0_h * h_blk], d_shared)
+
+            T.copy(b_shared, b_local)
+            T.copy(d_shared, d_local)
+            for i_hco, i1_h in T.Parallel(hc, h_blk):
+                x_local[i_hco, i1_h] = c_local[i_hco] * d_local[i1_h]
+                for i_hci in T.serial(hc):
+                    x_local[i_hco, i1_h] += a_local[i_hci, i_hco] * b_local[i_hci, i1_h]
+            T.copy(x_local, x_shared)
+
+            T.copy(x_shared, x[i_n, 0, i0_h * h_blk])
+
+
+def mhc_post(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    post_layer_mix: torch.Tensor,
+    comb_res_mix: torch.Tensor,
+) -> torch.Tensor:
+    out = torch.empty_like(residual)
+    mhc_post_tilelang(comb_res_mix, residual, post_layer_mix.squeeze(-1), x, out, residual.shape[-2], residual.shape[-1])
+    return out
+
+
+def mhc_post_ref(
+    x: torch.Tensor,
+    residual: torch.Tensor,
+    post_layer_mix: torch.Tensor,
+    comb_res_mix: torch.Tensor,
+) -> torch.Tensor:
+    term2 = torch.bmm(comb_res_mix.mT, residual.float())
+    return (x.float().unsqueeze(-2) * post_layer_mix + term2).bfloat16()
+
+
+def generate_test_data(
+    n: int,
+    h: int,
+    hc_mult: int,
+    device: str = "cuda",
+) -> dict[str, torch.Tensor]:
+    """Generate test data for post operator."""
+    torch.random.manual_seed(42)
+
+    x = torch.randn((n, h), dtype=torch.bfloat16, device=device)
+    residual = torch.randn((n, hc_mult, h), dtype=torch.bfloat16, device=device)
+    post_layer_mix = torch.randn((n, hc_mult, 1), dtype=torch.float32, device=device)
+    comb_res_mix = torch.randn((n, hc_mult, hc_mult), dtype=torch.float32, device=device)
+
+    return {
+        "x": x,
+        "residual": residual,
+        "post_layer_mix": post_layer_mix,
+        "comb_res_mix": comb_res_mix,
+    }
+
+
+def test(n: int, h: int) -> None:
+    print(f"Testing mhc_post with {n=} {h=}")
+    test_data = generate_test_data(n=n, h=h, hc_mult=4)
+    out_tl = mhc_post(**test_data)
+    out_ref = mhc_post_ref(**test_data)
+    torch.testing.assert_close(out_tl, out_ref)
+
+
+def run_regression_perf(n: int = 4096, h: int = 2560, hc_mult: int = 4) -> float:
+    test_data = generate_test_data(n=n, h=h, hc_mult=hc_mult)
+    out = torch.empty_like(test_data["residual"])
+    post_layer_mix = test_data["post_layer_mix"].squeeze(-1)
+
+    def run_kernel_only():
+        mhc_post_tilelang(
+            test_data["comb_res_mix"],
+            test_data["residual"],
+            post_layer_mix,
+            test_data["x"],
+            out,
+            hc_mult,
+            h,
+        )
+
+    run_kernel_only()
+
+    from tilelang.profiler import do_bench
+
+    return do_bench(run_kernel_only, backend="cupti")
+
+
+def main():
+    for n in [4096]:
+        for h in [1280, 2560, 7168]:
+            test(n=n, h=h)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/deepseek_mhc/example_mhc_pre.py b/examples/deepseek_mhc/example_mhc_pre.py
new file mode 100644
index 0000000000..28b6c32bf6
--- /dev/null
+++ b/examples/deepseek_mhc/example_mhc_pre.py
@@ -0,0 +1,490 @@
+import math
+
+import tilelang
+import tilelang.language as T
+import torch
+
+
+@tilelang.jit(
+    pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_PTXAS_REGISTER_USAGE_LEVEL: 10},
+)
+def mhc_pre_big_fuse_tilelang(
+    gemm_out_mul,
+    gemm_out_sqrsum,
+    hc_scale,
+    hc_base,
+    residual,
+    post_mix,
+    comb_mix,
+    layer_input,
+    hidden_size: int,
+    rms_eps: 
float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 16, + hc_mult: int = 4, +): + """Deeply fused kernels, everything other than gemm & sqrsum in mHC pre block.""" + num_tokens = T.dynamic("num_tokens") + hc_mult3 = hc_mult * (2 + hc_mult) + hidden_block = math.gcd(512, hidden_size) + + gemm_out_mul: T.Tensor[[n_splits, num_tokens, hc_mult3], T.float32] + gemm_out_sqrsum: T.Tensor[[n_splits, num_tokens], T.float32] + hc_scale: T.Tensor[[3], T.float32] + hc_base: T.Tensor[[hc_mult3], T.float32] + residual: T.Tensor[[num_tokens, hc_mult, hidden_size], T.bfloat16] + # outputs + post_mix: T.Tensor[[num_tokens, hc_mult], T.float32] + comb_mix: T.Tensor[[num_tokens, hc_mult * hc_mult], T.float32] + layer_input: T.Tensor[[num_tokens, hidden_size], T.bfloat16] + + with T.Kernel(num_tokens, threads=96) as i: + ################################################################## + # _pre_norm_fn_fwd_norm + rms = T.alloc_fragment(1, T.float32) + mixes = T.alloc_fragment(hc_mult3, T.float32) + T.clear(mixes) + rms[0] = 0 + for i_split in T.serial(n_splits): + rms[0] += gemm_out_sqrsum[i_split, i] + rms[0] = T.rsqrt(rms[0] / (hc_mult * hidden_size) + rms_eps) + for j in T.Parallel(hc_mult3): + mixes[j] = 0 + for i_split in T.serial(n_splits): + mixes[j] += gemm_out_mul[i_split, i, j] + mixes[j] *= rms[0] + mixes_shared = T.alloc_shared(hc_mult3, T.float32) + T.copy(mixes, mixes_shared) + + if T.get_thread_binding() < 32: + ################################################################## + # _pre_split_mixes_fwd (post & comb) + cm = T.alloc_fragment((hc_mult, hc_mult), T.float32) + for j in T.Parallel(hc_mult): + post_mix[i, j] = T.sigmoid(mixes_shared[j + hc_mult] * hc_scale[1] + hc_base[j + hc_mult]) * hc_post_mult_value + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = mixes_shared[j * hc_mult + k + hc_mult * 2] * hc_scale[2] + hc_base[j * hc_mult + k + hc_mult * 2] + + ################################################################## + # _sinkhorn_fwd + row_sum = T.alloc_fragment(hc_mult, T.float32) + col_sum = T.alloc_fragment(hc_mult, T.float32) + + # comb = comb.softmax(-1) + eps + row_max = T.alloc_fragment(hc_mult, T.float32) + T.reduce_max(cm, row_max, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = T.exp(cm[j, k] - row_max[j]) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / row_sum[j] + hc_sinkhorn_eps + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + for _ in T.serial(sinkhorn_repeat - 1): + # comb = comb / (comb.sum(-1) + eps) + T.reduce_sum(cm, row_sum, dim=1) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (row_sum[j] + hc_sinkhorn_eps) + + # comb = comb / (comb.sum(-2) + eps) + T.reduce_sum(cm, col_sum, dim=0) + for j, k in T.Parallel(hc_mult, hc_mult): + cm[j, k] = cm[j, k] / (col_sum[k] + hc_sinkhorn_eps) + + # save comb_mix to global memory + for j, k in T.Parallel(hc_mult, hc_mult): + comb_mix[i, j * hc_mult + k] = cm[j, k] + else: + ################################################################## + # _pre_split_mixes_fwd (pre) + pre_mix_shared = T.alloc_shared(hc_mult, T.float32) + for j in T.Parallel(hc_mult): + pre_mix_shared[j] = ( + T.sigmoid( + mixes_shared[j] * hc_scale[0] + hc_base[j], + ) + + hc_pre_eps + ) + ################################################################### + 
# _pre_apply_mix_fwd + for i0_h in T.Pipelined(hidden_size // hidden_block, num_stages=2): + xs = T.alloc_shared((hc_mult, hidden_block), T.float32) + xl = T.alloc_fragment((hc_mult, hidden_block), T.float32) + T.copy(residual[i, 0, i0_h * hidden_block], xs) + T.copy(xs, xl) + + ol = T.alloc_fragment(hidden_block, T.float32) + T.clear(ol) + + for i_hc in T.serial(hc_mult): + pre = pre_mix_shared[i_hc] + for i1_h in T.Parallel(hidden_block): + ol[i1_h] += pre * xl[i_hc, i1_h] + + T.copy(ol, layer_input[i, i0_h * hidden_block]) + + +@tilelang.jit +def mhc_pre_gemm_sqrsum_tilelang( + x, + fn, + out, + sqrsum, + hc_mult3: int, + hc_hidden_size: int, + token_block: int = 32, + hidden_block: int = 256, +) -> tilelang.JITKernel: + """Not highly optimized TileLang implementation of fused gemm and sqrsum in mHC pre block.""" + assert hc_mult3 <= 32 # should be 24 usually + num_tokens = T.dynamic("num_tokens") + assert hc_hidden_size % hidden_block == 0 + + x: T.Tensor((num_tokens, hc_hidden_size), T.bfloat16) + fn: T.Tensor((hc_mult3, hc_hidden_size), T.float32) + out: T.Tensor((num_tokens, hc_mult3), T.float32) + sqrsum: T.Tensor((num_tokens), T.float32) + + with T.Kernel(T.ceildiv(num_tokens, token_block)) as px: + out_frag = T.alloc_fragment((token_block, 32), T.float32) + sqrsum_part = T.alloc_fragment((token_block, 4), T.float32) + T.clear(out_frag) + T.clear(sqrsum_part) + for pz in T.Pipelined(hc_hidden_size // hidden_block, num_stages=2): + x_smem_16 = T.alloc_shared((token_block, hidden_block), T.bfloat16) + fn_smem = T.alloc_shared((32, hidden_block), T.float32) + + T.annotate_layout({x_smem_16: tilelang.layout.make_swizzled_layout(x_smem_16)}) + + T.copy(x[px * token_block, pz * hidden_block], x_smem_16) + T.copy(fn[0, pz * hidden_block], fn_smem) + + x_frag_16 = T.alloc_fragment((token_block, hidden_block), T.bfloat16) + T.copy(x_smem_16, x_frag_16) + x_frag = T.alloc_fragment((token_block, hidden_block), T.float32) + T.copy(x_frag_16, x_frag) + + for jj in T.serial(hidden_block // 4): + for i, j in T.Parallel(token_block, 4): + sqrsum_part[i, j] += x_frag[i, jj * 4 + j] * x_frag[i, jj * 4 + j] + + # should be TF32 gemm + T.gemm( + x_frag, + fn_smem, + out_frag, + transpose_A=False, + transpose_B=True, + clear_accum=False, + ) + sqrsum_l = T.alloc_fragment(token_block, T.float32) + T.reduce_sum(sqrsum_part, sqrsum_l) + for i in T.Parallel(token_block): + sqrsum[px * token_block + i] = sqrsum_l[i] + for i, j in T.Parallel(token_block, 32): + if j < hc_mult3: + out[px * token_block + i, j] = out_frag[i, j] + + +def mhc_pre( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, + n_splits: int = 1, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """ + Forward pass for mHC pre block. 
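+
+    Pipeline (this summarizes the two kernels defined above in this file, not an
+    external API):
+    1. mhc_pre_gemm_sqrsum_tilelang computes mixes = residual @ fn.T together with
+       the squared sum of residual that feeds the RMS normalization factor.
+    2. mhc_pre_big_fuse_tilelang applies the RMS scaling, derives the sigmoid
+       pre/post mixes, Sinkhorn-normalizes the comb mix, and reduces the residual
+       streams with the pre mix to produce layer_input.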
+ + Args: + residual: shape (..., hc_mult, hidden_size), dtype torch.bfloat16 + fn: shape (hc_mult3, hc_mult * hidden_size), dtype torch.float32 + hc_scale: shape (3,), dtype torch.float32 + hc_base: shape (hc_mult3,), dtype torch.float32 + rms_eps: RMS normalization epsilon + hc_pre_eps: pre-mix epsilon + hc_sinkhorn_eps: sinkhorn epsilon + hc_post_mult_value: post-mix multiplier value + sinkhorn_repeat: number of sinkhorn iterations + n_splits: split-k factor; TileLang version of mhc_pre_gemm_sqrsum doesn't support this + + Returns: + post_mix: shape (..., hc_mult), dtype torch.float32 + comb_mix: shape (..., hc_mult, hc_mult), dtype torch.float32 + layer_input: shape (..., hidden_size), dtype torch.bfloat16 + """ + + # Validate shapes + assert residual.dtype == torch.bfloat16 + assert fn.dtype == torch.float32 + assert hc_scale.dtype == torch.float32 + assert hc_base.dtype == torch.float32 + + hc_mult = residual.shape[-2] + hidden_size = residual.shape[-1] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + hc_hidden_size = hc_mult * hidden_size + assert fn.shape[0] == hc_mult3 + assert fn.shape[1] == hc_hidden_size + assert hc_scale.shape == (3,) + assert hc_base.shape == (hc_mult3,) + + outer_shape = residual.shape[:-2] + + residual_flat = residual.view(-1, hc_mult, hidden_size) + num_tokens = residual_flat.shape[0] + fn_flat = fn + + post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=residual.device) + comb_mix = torch.empty(num_tokens, hc_mult2, dtype=torch.float32, device=residual.device) + layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, device=residual.device) + + gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device) + gemm_out_sqrsum = torch.empty(n_splits, num_tokens, dtype=torch.float32, device=residual.device) + assert n_splits == 1, "The simple TileLang version gemm_sqrsum doesn't support split-k" + mhc_pre_gemm_sqrsum_tilelang( + residual_flat.view(num_tokens, hc_mult * hidden_size), + fn_flat, + gemm_out_mul.squeeze(0), + gemm_out_sqrsum.squeeze(0), + hc_mult3, + hc_mult * hidden_size, + ) + + mhc_pre_big_fuse_tilelang( + gemm_out_mul, + gemm_out_sqrsum, + hc_scale, + hc_base, + residual_flat, + post_mix, + comb_mix, + layer_input, + hidden_size, + rms_eps, + hc_pre_eps, + hc_sinkhorn_eps, + hc_post_mult_value, + sinkhorn_repeat, + n_splits, + hc_mult, + ) + + post_mix = post_mix.view(*outer_shape, hc_mult, 1) + comb_mix = comb_mix.view(*outer_shape, hc_mult, hc_mult) + layer_input = layer_input.view(*outer_shape, hidden_size) + + return post_mix, comb_mix, layer_input + + +def sinkhorn_normalize_ref(x: torch.Tensor, repeat: int, eps: float) -> torch.Tensor: + x = x.softmax(-1) + eps + x = x / (x.sum(-2, keepdim=True) + eps) + for _ in range(repeat - 1): + x = x / (x.sum(-1, keepdim=True) + eps) + x = x / (x.sum(-2, keepdim=True) + eps) + return x + + +def mhc_pre_ref( + residual: torch.Tensor, + fn: torch.Tensor, + hc_scale: torch.Tensor, + hc_base: torch.Tensor, + rms_eps: float, + hc_pre_eps: float, + hc_sinkhorn_eps: float, + hc_post_mult_value: float, + sinkhorn_repeat: int, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + hc_mult = residual.shape[-2] + + residual_flat = residual.flatten(-2, -1).float() + sqrsum = residual_flat.square().sum(-1) + mixes = residual_flat @ fn.T * (sqrsum.unsqueeze(-1) / fn.shape[-1] + rms_eps).rsqrt() + + hc_scale = torch.cat( + [ + hc_scale[0].expand(hc_mult), + hc_scale[1].expand(hc_mult), + 
hc_scale[2].expand(hc_mult * hc_mult), + ], + ) + mixes = mixes * hc_scale + hc_base + + pre_mix = mixes[:, :hc_mult].sigmoid().unsqueeze(-1) + hc_pre_eps + post_mix = (mixes[:, hc_mult : 2 * hc_mult].sigmoid() * hc_post_mult_value).unsqueeze(-1) + res_mix = mixes[:, 2 * hc_mult :].view(-1, hc_mult, hc_mult) + + res_mix = sinkhorn_normalize_ref(res_mix, repeat=sinkhorn_repeat, eps=hc_sinkhorn_eps) + + layer_input = (residual * pre_mix).sum(-2).bfloat16() + + return post_mix, res_mix, layer_input + + +def generate_test_data( + n: int, + hc_mult: int, + hidden_size: int, + rms_eps: float = 1e-6, + hc_pre_eps: float = 1e-6, + hc_sinkhorn_eps: float = 1e-6, + hc_post_mult_value: float = 1.0, + sinkhorn_repeat: int = 10, +) -> dict[str, torch.Tensor | float]: + """Generate test data for big fuse operator.""" + torch.random.manual_seed(42) + + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + device = "cuda" + + residual = ( + torch.randn((n, hc_mult, hidden_size), dtype=torch.float, device=device) + .mul(1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + .bfloat16() + ) + + fn = ( + torch.randn((hc_mult3, hc_mult, hidden_size), dtype=torch.float, device=device) + * 1e-4 + * (1 + torch.arange(hc_mult, device=device).mul(0.01).view(1, -1, 1)) + ).flatten(1, 2) + + hc_scale = torch.randn((3,), dtype=torch.float, device=device) * 0.1 + + hc_base = torch.randn((hc_mult3,), dtype=torch.float, device=device) * 0.1 + + return { + "residual": residual, + "fn": fn, + "hc_scale": hc_scale, + "hc_base": hc_base, + "rms_eps": rms_eps, + "hc_pre_eps": hc_pre_eps, + "hc_sinkhorn_eps": hc_sinkhorn_eps, + "hc_post_mult_value": hc_post_mult_value, + "sinkhorn_repeat": sinkhorn_repeat, + } + + +def test(n: int, hidden_size: int, hc_mult: int) -> None: + print(f"Testing mhc_pre with {n=} {hidden_size=} {hc_mult=}") + test_data = generate_test_data( + n=n, + hc_mult=hc_mult, + hidden_size=hidden_size, + ) + + # Forward pass with big fuse + post_mix_fused, comb_mix_fused, layer_input_fused = mhc_pre(**test_data) + + # Forward pass with reference + post_mix_ref, comb_mix_ref, layer_input_ref = mhc_pre_ref(**test_data) + + # Compare outputs + torch.testing.assert_close(post_mix_fused, post_mix_ref) + torch.testing.assert_close(comb_mix_fused, comb_mix_ref) + torch.testing.assert_close(layer_input_fused, layer_input_ref) + + +def run_regression_perf( + n: int = 2048, + hidden_size: int = 4096, + hc_mult: int = 4, + rms_eps: float = 1e-6, + hc_pre_eps: float = 1e-6, + hc_sinkhorn_eps: float = 1e-6, + hc_post_mult_value: float = 1.0, + sinkhorn_repeat: int = 10, + n_splits: int = 1, +) -> float: + assert n_splits == 1, "The simple TileLang version gemm_sqrsum doesn't support split-k" + + test_data = generate_test_data( + n=n, + hc_mult=hc_mult, + hidden_size=hidden_size, + rms_eps=rms_eps, + hc_pre_eps=hc_pre_eps, + hc_sinkhorn_eps=hc_sinkhorn_eps, + hc_post_mult_value=hc_post_mult_value, + sinkhorn_repeat=sinkhorn_repeat, + ) + + residual = test_data["residual"] + fn = test_data["fn"] + hc_scale = test_data["hc_scale"] + hc_base = test_data["hc_base"] + + num_tokens = residual.shape[0] + hc_mult2 = hc_mult * hc_mult + hc_mult3 = hc_mult * 2 + hc_mult2 + + residual_flat = residual.view(num_tokens, hc_mult, hidden_size) + post_mix = torch.empty(num_tokens, hc_mult, dtype=torch.float32, device=residual.device) + comb_mix = torch.empty(num_tokens, hc_mult2, dtype=torch.float32, device=residual.device) + layer_input = torch.empty(num_tokens, hidden_size, dtype=torch.bfloat16, 
device=residual.device)
+    gemm_out_mul = torch.empty(n_splits, num_tokens, hc_mult3, dtype=torch.float32, device=residual.device)
+    gemm_out_sqrsum = torch.empty(n_splits, num_tokens, dtype=torch.float32, device=residual.device)
+
+    def run_kernel_only():
+        mhc_pre_gemm_sqrsum_tilelang(
+            residual_flat.view(num_tokens, hc_mult * hidden_size),
+            fn,
+            gemm_out_mul.squeeze(0),
+            gemm_out_sqrsum.squeeze(0),
+            hc_mult3,
+            hc_mult * hidden_size,
+        )
+
+        mhc_pre_big_fuse_tilelang(
+            gemm_out_mul,
+            gemm_out_sqrsum,
+            hc_scale,
+            hc_base,
+            residual_flat,
+            post_mix,
+            comb_mix,
+            layer_input,
+            hidden_size,
+            rms_eps,
+            hc_pre_eps,
+            hc_sinkhorn_eps,
+            hc_post_mult_value,
+            sinkhorn_repeat,
+            n_splits,
+            hc_mult,
+        )
+
+    run_kernel_only()
+
+    from tilelang.profiler import do_bench
+
+    return do_bench(run_kernel_only, backend="cupti")
+
+
+def main():
+    for n1 in [512, 1024, 2048, 8192]:
+        for hidden_size in [1280, 2560, 4096]:
+            for hc_mult in [4]:
+                test(n=n1, hidden_size=hidden_size, hc_mult=hc_mult)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/deepseek_mhc/regression_example_mhc.py b/examples/deepseek_mhc/regression_example_mhc.py
new file mode 100644
index 0000000000..880c3e04a6
--- /dev/null
+++ b/examples/deepseek_mhc/regression_example_mhc.py
@@ -0,0 +1,15 @@
+import tilelang.testing
+import example_mhc_post
+import example_mhc_pre
+
+
+def regression_example_mhc_post():
+    tilelang.testing.process_func(example_mhc_post.run_regression_perf)
+
+
+def regression_example_mhc_pre():
+    tilelang.testing.process_func(example_mhc_pre.run_regression_perf)
+
+
+if __name__ == "__main__":
+    tilelang.testing.regression()
diff --git a/examples/deepseek_mhc/test_example_mhc.py b/examples/deepseek_mhc/test_example_mhc.py
new file mode 100644
index 0000000000..3d9ecad4da
--- /dev/null
+++ b/examples/deepseek_mhc/test_example_mhc.py
@@ -0,0 +1,18 @@
+import tilelang.testing
+
+from example_mhc_post import main as main_post
+from example_mhc_pre import main as main_pre
+
+
+@tilelang.testing.requires_cuda
+def test_mhc_post():
+    main_post()
+
+
+@tilelang.testing.requires_cuda
+def test_mhc_pre():
+    main_pre()
+
+
+if __name__ == "__main__":
+    tilelang.testing.main()
diff --git a/examples/deepseek_mla/README.md b/examples/deepseek_mla/README.md
index bd3539d269..f75e606bd5 100644
--- a/examples/deepseek_mla/README.md
+++ b/examples/deepseek_mla/README.md
@@ -44,7 +44,7 @@ for i in range(loop_range):
     scores_scale = exp(scores_max_prev - scores_max)
     acc_o *= scores_scale
     acc_s = exp(acc_s - scores_max)
-    acc_o = acc_s @ V[i]
+    acc_o += acc_s @ V[i]
     ...
```

diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py
new file mode 100644
index 0000000000..9eae480822
--- /dev/null
+++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_aiter.py
@@ -0,0 +1,290 @@
+# This benchmark script is adapted from: https://github.com/deepseek-ai/FlashMLA/blob/main/benchmark/bench_flash_mla.py
+# ruff: noqa
+import argparse
+import math
+import random
+import torch
+
+import triton
+import triton.language as tl
+
+import tilelang
+from tilelang.profiler import do_bench
+
+try:
+    from aiter.mla import mla_decode_fwd
+except ImportError:
+    print("aiter is an AMD-specific kernel library. 
Please make sure aiter is installed on your AMD device.") + + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32) + for i in range(b): + begin = i * max_seqlen_pad + end = begin + cache_seqlens[i] + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[begin:end].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[begin:end].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out, lse + + out_torch, lse_torch = ref_mla() + t = triton.testing.do_bench(ref_mla) + return out_torch, lse_torch, t + + +@torch.inference_mode() +def run_mla_aiter(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + assert d > dv, "mla with rope dim should be larger than no rope dim" + + qo_indptr = torch.zeros(b + 1, dtype=torch.int) + kv_indptr = torch.zeros(b + 1, dtype=torch.int) + seq_lens_qo = torch.empty(b, dtype=torch.int) + seq_lens_qo.fill_(1) + max_seqlen_qo = seq_lens_qo.max().item() + + kv_indptr[1 : b + 1] = torch.cumsum(cache_seqlens, dim=0) + qo_indptr[1 : b + 1] = torch.cumsum(seq_lens_qo, dim=0) + total_q = qo_indptr[-1].item() + + # set block_size to 1 + page_size = 1 + kv_buffer = blocked_k.view(-1, page_size, h_kv, d) + + flat_indices = [] + for i in range(b): + start = i * max_seqlen_pad + end = start + cache_seqlens[i] + flat_indices.append(torch.arange(start, end, dtype=torch.int)) + + kv_indices = torch.cat(flat_indices) + + kv_last_page_lens = torch.ones(b, dtype=torch.int) + + sm_scale = 1.0 / (d**0.5) + + def mla_aiter(): + out_aiter = torch.empty((total_q, h_q, dv), dtype=dtype).fill_(-1) + attn_logits_aiter, attn_lse_aiter = mla_decode_fwd( + q.view((total_q, h_q, d)), + kv_buffer, + out_aiter, + qo_indptr, + kv_indptr, + kv_indices, + kv_last_page_lens, + max_seqlen_qo, + sm_scale, + ) + return out_aiter.view([b, s_q, h_q, dv]) + + out_aiter = mla_aiter() + t = triton.testing.do_bench(mla_aiter) + return out_aiter, None, t + + +FUNC_TABLE = { + "torch": run_torch_mla, + "mla_aiter": run_mla_aiter, +} + + +def compare_ab(baseline, target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): + print( + f"comparing {baseline} vs {target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}" + ) + device = torch.device("cuda:0") + torch.set_default_dtype(dtype) + torch.set_default_device(device) + torch.cuda.set_device(device) + 
torch.manual_seed(0)
+    random.seed(0)
+    assert baseline in FUNC_TABLE
+    assert target in FUNC_TABLE
+    baseline_func = FUNC_TABLE[baseline]
+    target_func = FUNC_TABLE[target]
+
+    total_seqlens = cache_seqlens.sum().item()
+    max_seqlen = cache_seqlens.max().item()
+    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
+
+    q = torch.randn(b, s_q, h_q, d)
+    block_size = 64
+    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
+
+    out_a, lse_a, perf_a = baseline_func(
+        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+    )
+    out_b, lse_b, perf_b = target_func(
+        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+    )
+
+    torch.testing.assert_close(out_b.float(), out_a.float(), atol=1e-2, rtol=1e-2)
+    if target not in ["mla_aiter"]:
+        # mla_aiter doesn't return lse, so only compare it for other targets
+        torch.testing.assert_close(lse_b.float(), lse_a.float(), atol=1e-2, rtol=1e-2)
+
+    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
+    print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.3f} TFLOPS, {bytes / 10**6 / perf_a:.3f} GB/s")
+    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s")
+    return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b
+
+
+def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype):
+    print(f"{target}: {b=}, {s_q=}, mean_seqlens={cache_seqlens.float().mean()}, {h_q=}, {h_kv=}, {d=}, {dv=}, {causal=}, {dtype=}")
+    torch.set_default_dtype(dtype)
+    device = torch.device("cuda:0")
+    torch.set_default_device(device)
+    torch.cuda.set_device(device)
+    torch.manual_seed(0)
+    random.seed(0)
+    assert target in FUNC_TABLE, f"target {target} not in {FUNC_TABLE}"
+    target_func = FUNC_TABLE[target]
+
+    total_seqlens = cache_seqlens.sum().item()
+    max_seqlen = cache_seqlens.max().item()
+    max_seqlen_pad = triton.cdiv(max_seqlen, 256) * 256
+    # print(f"{total_seqlens=}, {mean_seqlens=}, {max_seqlen=}")
+
+    q = torch.randn(b, s_q, h_q, d)
+    block_size = 64
+    block_table = torch.arange(b * max_seqlen_pad // block_size, dtype=torch.int32).view(b, max_seqlen_pad // block_size)
+    blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d)
+
+    out_b, lse_b, perf_b = target_func(
+        q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype
+    )
+
+    FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2
+    bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8)
+    print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s")
+    return bytes / 10**6 / perf_b
+
+
+available_targets = [
+    "torch",
+    "mla_aiter",
+]
+
+shape_configs = [
+    {
+        "b": batch,
+        "s_q": 1,
+        "cache_seqlens": torch.tensor([seqlen + 2 * i for i in range(batch)], dtype=torch.int32, device="cuda"),
+        "h_q": head,
+        "h_kv": 1,
+        "d": 512 + 64,
+        "dv": 512,
+        "causal": True,
+        "dtype": torch.bfloat16,
+    }
+    for batch in [64, 128]
+    for seqlen in [1024, 2048, 4096, 8192, 16384]
+    for head in [128]
+]
+
+
+def get_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--baseline", type=str, default="torch")
+    parser.add_argument("--target", type=str, default="mla_aiter")
+    parser.add_argument("--all", action="store_true")
+    parser.add_argument("--one", action="store_true")
+    parser.add_argument("--compare", action="store_true")
+    args = parser.parse_args()
+    return args
+
+
+if __name__ == "__main__":
+    args = get_args()
+    benchmark_type = "all" if args.all else f"{args.baseline}_vs_{args.target}" if args.compare else args.target
+    with open(f"{benchmark_type}_perf.csv", "w") as fout:
+        fout.write("name,batch,seqlen,head,bw\n")
+        for shape in shape_configs:
+            if args.all:
+                for target in available_targets:
+                    perf = compare_a(
+                        target,
+                        shape["b"],
+                        shape["s_q"],
+                        shape["cache_seqlens"],
+                        shape["h_q"],
+                        shape["h_kv"],
+                        shape["d"],
+                        shape["dv"],
+                        shape["causal"],
+                        shape["dtype"],
+                    )
+                    fout.write(
+                        f"{target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
+                    )
+            elif args.compare:
+                perf_a, perf_b = compare_ab(
+                    args.baseline,
+                    args.target,
+                    shape["b"],
+                    shape["s_q"],
+                    shape["cache_seqlens"],
+                    shape["h_q"],
+                    shape["h_kv"],
+                    shape["d"],
+                    shape["dv"],
+                    shape["causal"],
+                    shape["dtype"],
+                )
+                fout.write(
+                    f"{args.baseline},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf_a:.0f}\n"
+                )
+                fout.write(
+                    f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf_b:.0f}\n"
+                )
+            elif args.one:
+                perf = compare_a(
+                    args.target,
+                    shape["b"],
+                    shape["s_q"],
+                    shape["cache_seqlens"],
+                    shape["h_q"],
+                    shape["h_kv"],
+                    shape["d"],
+                    shape["dv"],
+                    shape["causal"],
+                    shape["dtype"],
+                )
+                fout.write(
+                    f"{args.target},{shape['b']},{shape['cache_seqlens'].float().mean().cpu().item():.0f},{shape['h_q']},{perf:.0f}\n"
+                )
diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
index dccf333ad3..399bb8e6e5 100644
--- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
+++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_tilelang.py
@@ -131,7 +131,7 @@ def main_split(
                 lse_local_split = glse[bz, by, k]
                 scale_local = T.exp2(lse_local_split - lse_logsum_local)
                 for i in T.Parallel(dim):
-                    o_accum_local[i] += po_local[i] * scale_local[0]
+                    o_accum_local[i] += po_local[i] * scale_local
             for i in T.Parallel(dim):
                 Output[bz, by, i] = o_accum_local[i]

@@ -259,6 +259,8 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     num_split = 4
     threads = 128

+    print(f"Using {batch=}, {heads=}, {kv_heads=}, {kv_ctx=}, {dim=}, {pe_dim=}")
+
     if enable_autotune:
         kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
     else:
@@ -267,8 +269,6 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     input_tensors = profiler._get_inputs()
     tilelang_output = kernel(*input_tensors)
     ref_output = ref_program(*input_tensors)
-    print(f"Tilelang output: {tilelang_output}")
-    print(f"Ref output: {ref_output}")
     torch.testing.assert_close(tilelang_output, ref_output, rtol=0.01, atol=0.01)
     latency = profiler.do_bench(warmup=500)
     print(f"Latency: {latency} ms")
diff --git a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
index 861e841c4e..e8c1006a01 100644
--- a/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
+++ b/examples/deepseek_mla/amd/benchmark_mla_decode_amd_triton.py
@@ -378,8 +378,8 @@ def compare_ab(baseline, 
target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.0f} TFLOPS, {bytes / 10**6 / perf_a:.0f} GB/s") - print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + print(f"perf {baseline}: {perf_a:.3f} ms, {FLOPS / 10**9 / perf_a:.3f} TFLOPS, {bytes / 10**6 / perf_a:.3f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") return bytes / 10**6 / perf_a, bytes / 10**6 / perf_b @@ -410,7 +410,7 @@ def compare_a(target, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype): FLOPS = s_q * total_seqlens * h_q * (d + dv) * 2 bytes = (total_seqlens * h_kv * d + b * s_q * h_q * d + b * s_q * h_q * dv) * (torch.finfo(dtype).bits // 8) - print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.0f} TFLOPS, {bytes / 10**6 / perf_b:.0f} GB/s") + print(f"perf {target}: {perf_b:.3f} ms, {FLOPS / 10**9 / perf_b:.3f} TFLOPS, {bytes / 10**6 / perf_b:.3f} GB/s") return bytes / 10**6 / perf_b diff --git a/examples/deepseek_mla/example_mla_decode.py b/examples/deepseek_mla/example_mla_decode.py index 7de4faf089..4daa39f494 100644 --- a/examples/deepseek_mla/example_mla_decode.py +++ b/examples/deepseek_mla/example_mla_decode.py @@ -6,14 +6,10 @@ from einops import rearrange, einsum import argparse -tilelang.disable_cache() - @tilelang.jit( - out_idx=[6], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, + out_idx=[4], + pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}, ) def flashattn(batch, heads, kv_head_num, seqlen_kv, dim, pe_dim, block_N, block_H, num_split, softmax_scale): scale = float(softmax_scale * 1.44269504) # log2(e) @@ -29,10 +25,10 @@ def main_split( Q_pe: T.Tensor([batch, heads, pe_dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype), ): + glse = T.alloc_global([batch, heads, num_split], dtype) + Output_partial = T.alloc_global([batch, heads, num_split, dim], dtype) # flash_attn_split with T.Kernel(batch, heads // min(block_H, kv_group_num), num_split, threads=256) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -127,8 +123,6 @@ def main_no_split( Q_pe: T.Tensor([batch, heads, pe_dim], dtype), KV: T.Tensor([batch, seqlen_kv, kv_head_num, dim], dtype), K_pe: T.Tensor([batch, seqlen_kv, kv_head_num, pe_dim], dtype), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor([batch, heads, num_split, dim], dtype), Output: T.Tensor([batch, heads, dim], dtype), ): with T.Kernel(heads // min(block_H, kv_group_num), batch, threads=256) as (hid, bid): @@ -187,15 +181,13 @@ def main_no_split( return main_no_split -def ref_program(q, q_pe, kv, k_pe, glse, Output_partial): +def ref_program(q, q_pe, kv, k_pe): # """ # Inputs: # - q (Tensor): [batch, heads, dim] # - q_pe (Tensor): [batch, heads, pe_dim] # - kv (Tensor): [batch, seqlen_kv, kv_head_num, dim] # - k_pe (Tensor): [batch, seqlen_kv, kv_head_num, pe_dim] - # - glse (Tensor): [batch, heads, num_split] - # - Output_partial (Tensor): [batch, heads, num_split, dim] # 
Outputs: # - output (Tensor): [batch, heads, dim] # """ diff --git a/examples/deepseek_mla/example_mla_decode_ws.py b/examples/deepseek_mla/example_mla_decode_ws.py index 32eb0d4754..98657e381a 100644 --- a/examples/deepseek_mla/example_mla_decode_ws.py +++ b/examples/deepseek_mla/example_mla_decode_ws.py @@ -101,9 +101,9 @@ def main_split( T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) T.clear(acc_s) - T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -136,9 +136,9 @@ def main_split( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) T.clear(acc_s) - T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -215,40 +215,44 @@ def main_split( T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ - bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v - ] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): kv_indices = (seqlen_kv // num_split) * bz + (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - 
K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ - bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v - ] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) # combine @@ -346,9 +350,9 @@ def main_no_split( T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) T.clear(acc_s) - T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -381,9 +385,9 @@ def main_no_split( T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) T.clear(acc_s) - T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -459,40 +463,44 @@ def main_no_split( T.barrier_wait(bar_k_0_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): kv_indices = (i_i * 2) * block_N + r * 16 + (tx - 256) // 8 - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ - bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v - ] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 T.barrier_wait(bar_k_1_free[0], ((i_i & 1) ^ 1)) for r in T.serial(4): kv_indices = (i_i * 2 + 1) * block_N + r * 16 + (tx - 256) // 8 - with 
T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = K_pe[ - bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8 + v - ] + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[bid, kv_indices, cur_kv_head, dim // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(K_pe[bid, kv_indices, cur_kv_head, (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) if num_split > 1: diff --git a/examples/deepseek_mla/test_example_mla_decode.py b/examples/deepseek_mla/test_example_mla_decode.py index a269ea57ae..00e30023a4 100644 --- a/examples/deepseek_mla/test_example_mla_decode.py +++ b/examples/deepseek_mla/test_example_mla_decode.py @@ -1,9 +1,14 @@ +import os +import pytest import tilelang.testing import example_mla_decode +_is_cutedsl = os.environ.get("TILELANG_TARGET", "").lower() == "cutedsl" + @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_mla_decode(): example_mla_decode.main() diff --git a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py index ca98d01be9..697f3de38c 100644 --- a/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py +++ b/examples/deepseek_nsa/benchmark/benchmark_nsa_fwd.py @@ -460,13 +460,7 @@ def get_configs(): @tilelang.autotune( configs=get_configs(), ) -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_sparse_attention( batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16, block_T=128, num_stages=2, threads=32 ): diff --git a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py index 3da285a9ba..2aa30a5bc5 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_bwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_bwd.py @@ -18,13 +18,7 @@ import tilelang -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_kernel_fwd( batch, heads, diff --git 
a/examples/deepseek_nsa/example_tilelang_nsa_decode.py b/examples/deepseek_nsa/example_tilelang_nsa_decode.py index 381d92493e..79414a762b 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_decode.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_decode.py @@ -12,11 +12,7 @@ # auto warp specialization may have some bugs. @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}, ) def native_sparse_attention( batch, diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py index 7b36d6e26f..abed2e41dd 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd.py @@ -10,11 +10,7 @@ @tilelang.jit( out_idx=[-1], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def native_sparse_attention(batch, heads, seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: diff --git a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py index b52ebe42e2..1d5d942b40 100644 --- a/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py +++ b/examples/deepseek_nsa/example_tilelang_nsa_fwd_varlen.py @@ -17,13 +17,7 @@ from einops import rearrange -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - } -) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def native_sparse_attention_varlen(batch, heads, c_seq_len, dim, is_causal, scale=None, block_size=64, groups=1, selected_blocks=16): if scale is None: scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) diff --git a/examples/deepseek_v32/fp8_lighting_indexer.py b/examples/deepseek_v32/fp8_lighting_indexer.py index 03e88dd972..2f88575971 100644 --- a/examples/deepseek_v32/fp8_lighting_indexer.py +++ b/examples/deepseek_v32/fp8_lighting_indexer.py @@ -283,28 +283,16 @@ def run_regression_perf(S=4096, SKV=8192, H=32, HKV=1, D=64, kv_stride=1): q = torch.randn(S, H, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) kv = torch.randn(SKV, D, device="cuda", dtype=torch.bfloat16).to(torch.bfloat16) weights = torch.randn(S, H, device="cuda", dtype=torch.float32) - p = (torch.randn(S, SKV, device="cuda", dtype=torch.float32) * 4).softmax(dim=-1) - ks, ke = generate_random_cu_seqlens(per_cp_seqlen=S, cp_size=4, cp_rank=3, kv_stride=kv_stride, average_q_len=2048) - logits_ref, cost_ref = ref_fp8_mqa_logits(q=q, kv=kv, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - q_fp8 = q.to(torch.float8_e4m3fn) kv_fp8, kv_scales = per_custom_dims_cast_to_fp8(kv, (0,), False) - logits_tl = mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - diff = validate_tensor_match(logits_ref, logits_tl, tolerance=1e-14, 
tensor_name="logits", should_raise=False) - from tilelang.profiler import do_bench def logits_fn(): return mqa_attn_return_logits_interface(q=q_fp8, kv=kv_fp8, kv_scales=kv_scales, weights=weights, cu_seqlen_ks=ks, cu_seqlen_ke=ke) - with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof: - logits_fn() - - print(prof.key_averages().table(sort_by="cuda_time_total", max_name_column_width=50)) - return do_bench(logits_fn, backend="cupti") diff --git a/examples/deepseek_v32/inference/README.md b/examples/deepseek_v32/inference/README.md index 60afe7ceb1..f5cd62491c 100644 --- a/examples/deepseek_v32/inference/README.md +++ b/examples/deepseek_v32/inference/README.md @@ -1,6 +1,6 @@ # DeepSeek V3.2 -First convert huggingface model weights to the the format required by our inference demo. Set `MP` to match your available GPU count: +First convert huggingface model weights to the format required by our inference demo. Set `MP` to match your available GPU count: ```bash cd inference export EXPERTS=256 diff --git a/examples/deepseek_v32/inference/convert.py b/examples/deepseek_v32/inference/convert.py index 090be71455..cb912e1d8b 100644 --- a/examples/deepseek_v32/inference/convert.py +++ b/examples/deepseek_v32/inference/convert.py @@ -29,8 +29,7 @@ "wq_b": ("wq_b", None), "wk": ("wk", None), "k_norm": ("k_norm", None), - "weights_proj": ("weights_proj", None), -} + "weights_proj": ("weights_proj", None)} def main(hf_ckpt_path, save_path, n_experts, mp): diff --git a/examples/deepseek_v32/inference/kernel.py b/examples/deepseek_v32/inference/kernel.py index 25abf15d59..9d9402d1a8 100644 --- a/examples/deepseek_v32/inference/kernel.py +++ b/examples/deepseek_v32/inference/kernel.py @@ -7,9 +7,7 @@ pass_configs = { tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True, -} + tilelang.PassConfigKey.TL_DISABLE_FAST_MATH: True} FP8 = T.float8_e4m3fn BF16 = T.bfloat16 @@ -17,15 +15,15 @@ def fast_log2_ceil(x): - bits_x = T.reinterpret(T.uint32, x) + bits_x = T.reinterpret(x, T.uint32) exp_x = (bits_x >> 23) & 0xFF man_bits = bits_x & ((1 << 23) - 1) - return T.Cast(T.int32, exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0)) + return T.cast(exp_x - 127 + T.if_then_else(man_bits != 0, 1, 0), T.int32) def fast_pow2(x): bits_x = (x + 127) << 23 - return T.reinterpret(T.float32, bits_x) + return T.reinterpret(bits_x, T.float32) def fast_round_scale(amax, fp8_max_inv): diff --git a/examples/deepseek_v32/sparse_mla_bwd.py b/examples/deepseek_v32/sparse_mla_bwd.py index 527de22b39..50192fa2bf 100644 --- a/examples/deepseek_v32/sparse_mla_bwd.py +++ b/examples/deepseek_v32/sparse_mla_bwd.py @@ -76,7 +76,6 @@ def postprocess_kernel( @tilelang.jit( out_idx=[-2], pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, }, diff --git a/examples/deepseek_v32/sparse_mla_fwd.py b/examples/deepseek_v32/sparse_mla_fwd.py index 2c8bf7fc74..5426d9072b 100644 --- a/examples/deepseek_v32/sparse_mla_fwd.py +++ b/examples/deepseek_v32/sparse_mla_fwd.py @@ -7,10 +7,7 @@ @tilelang.jit( out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def sparse_mla_fwd( heads, @@ -161,9 +158,7 
@@ def main( for h_i in T.Parallel(H_per_block): sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale - T.copy(acc_o, O_shared) T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) - T.copy(sumexp, Lse_shared) T.copy(sumexp, Lse[b_i, s_i, H0:H1]) return main diff --git a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py index 7e664d11b4..bff9f19b98 100644 --- a/examples/deepseek_v32/sparse_mla_fwd_pipelined.py +++ b/examples/deepseek_v32/sparse_mla_fwd_pipelined.py @@ -133,9 +133,9 @@ def main( tx = T.get_thread_binding() - T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) - T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) - T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + T.tma_copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l, barrier=bar_q) + T.tma_copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r, barrier=bar_q) + T.tma_copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared, barrier=bar_q) T.barrier_arrive(bar_q) if tx < 128: @@ -151,9 +151,9 @@ def main( for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_0_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_0_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_0, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -187,9 +187,9 @@ def main( for h_i, bi_i in T.Parallel(H_per_block, BI): acc_s[h_i, bi_i] = T.if_then_else(is_kv_valid[bi_i], 0, -T.infinity(acc_s.dtype)) - T.gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True, wg_wait=-1) - T.gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True, wg_wait=-1) + T.wgmma_gemm(Q_shared_l, KV_shared_1_l, acc_s, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_1_r, acc_s, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_1, acc_s, transpose_B=True) T.wait_wgmma(0) @@ -266,20 +266,23 @@ def main( indices_local = Indices[b_i, s_i, g_i, (i_i * 2) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v - ] + # Manually issue cp.async copies for KV_left, KV_right, and K_tail. 
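+                            # Thread-to-element mapping for the cooperative load: the 128 producer
+                            # threads (tx in [256, 384)) work 8-per-row, so row = r * 16 + (tx - 256) // 8
+                            # and each thread owns a contiguous run of 8 elements starting at column
+                            # 64 * u + (tx - 256) % 8 * 8; over u = 0..3 the 8 threads cover all 256
+                            # columns of each KV half (D // 2).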
+ for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, D + (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_0_ready[0]) # Buffer 1 @@ -288,20 +291,23 @@ def main( indices_local = Indices[b_i, s_i, g_i, (i_i * 2 + 1) * BI + r * 16 + (tx - 256) // 8] is_kv_valid[r * 16 + (tx - 256) // 8] = indices_local <= max_kv_i if is_kv_valid[r * 16 + (tx - 256) // 8]: - with T.attr("default", "async_scope", 1): - for u in T.serial(4): - for v in T.vectorized(8): - KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8 + v - ] - KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8 + v - ] - with T.attr("default", "async_scope", 1): - for v in T.vectorized(8): - K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8 + v] = KV[ - b_i, indices_local, g_i, D + (tx - 256) % 8 * 8 + v - ] + # Manually issue cp.async copies for KV_left, KV_right, and K_tail. + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, indices_local, g_i, D + (tx - 256) % 8 * 8], "r", 8), + 8, + ) T.cp_async_barrier_noinc(bar_k_1_ready[0]) return main @@ -410,8 +416,6 @@ def fn(): tl_out, tl_lse = fn() ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) - # print(f"tl_out: {tl_out}") - # print(f"ref_out: {ref_out}") torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) @@ -450,12 +454,12 @@ def run_regression_perf(B=1, S=4096, SKV=8192, H=128, HKV=1, DQK=576, DV=512, to CP0 = q_start_s_index == 0 kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, KV_stride, kv_group, None, True, CP0) - def run_kernel_only(): + def fn(): kernel(q, kv, indices, torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda")) from tilelang.profiler import do_bench - return do_bench(run_kernel_only, backend="cupti") + return do_bench(fn, backend="cupti") if __name__ == "__main__": diff --git a/examples/deepseek_v32/sparse_mla_fwd_seesaw.py b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py new file mode 100644 index 0000000000..cdffac281a --- /dev/null +++ b/examples/deepseek_v32/sparse_mla_fwd_seesaw.py @@ -0,0 +1,643 @@ +# ruff: noqa +import torch +import tilelang +from tilelang import language as T +import argparse + + +@tilelang.jit( + out_idx=[-2, -1], + compile_flags=[ + "-O3", + "--ptxas-options=-v,--register-usage-level=10", + "-DNDEBUG", + 
"-Wno-deprecated-declarations", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + ], +) +def sparse_mla_fwd( + batch, + seq_len, + seq_len_kv, + heads, + dim, + tail_dim, + topk, + kv_stride, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=0, + threads=384, +): + assert dim == tilelang.math.next_power_of_2(dim), f"haven't check padding correctness yet, dim={dim}" + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), f"haven't check padding correctness yet, dim={tail_dim}" + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, "otherwise will load some index=0 thus causing wrong kv to be loaded" + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = "int32" + dtype = "bfloat16" + accum_dtype = "float" + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you " + "should handle Q copy and Output copy with your mask (when " + "kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would " + "be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + assert NI % 2 == 0, "NI should be a multiple of 2" + D = dim + D_tail = tail_dim + KV_stride = kv_stride + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + # Increasing from 32->64 reduces the time spent reading kvcache. If num_query_head = 128 + # and num_kv_head = 1, the same kvcache originally needed to be read 4 times, but now only 2 times + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + q_start_index_s: T.Tensor(1, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel( + # If CP0 is True (i.e., start of sequence), skip the first (KV_stride - 1) + # queries that cannot see any KV. 
Also be careful that seq_len < kv_stride could cause negative grid size + (max(0, seq_len - kv_stride + 1) if CP0 else seq_len) * REPLICATE_H, + batch, + kv_group, + threads=threads, + ) as (bx, by, bz): + Q_shared_l = T.alloc_shared([H_per_block, D // 2], dtype) + Q_shared_r = T.alloc_shared([H_per_block, D // 2], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + + KV_shared_0_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_0_r = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_l = T.alloc_shared([BI, D // 2], dtype) + KV_shared_1_r = T.alloc_shared([BI, D // 2], dtype) + K_tail_shared_0 = T.alloc_shared([BI, D_tail], dtype) + K_tail_shared_1 = T.alloc_shared([BI, D_tail], dtype) + + O_shared_l = Q_shared_l + O_shared_r = Q_shared_r + + # Whether the kv in current BI is visible for this query + # Producer alternates writing to buf0 and buf1 masks. To avoid the situation + # where consumer0 is still reading buf0 mask when producer has already started + # writing buf1 mask, we use two buf masks + is_kv_valid = T.alloc_shared([2, BI], "bool", scope="shared") + + acc_o_l = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + acc_o_r = T.alloc_fragment([H_per_block, D // 2], accum_dtype) + + # WG0 computes S0(BI_2*i), WG1 computes S1(BI_2*i+1), shared via shared memory + + # Reuse K_tail_shared for S_shared to save memory when dimensions match + # Must reuse, otherwise H100 SM's shared mem is insufficient (> 228kb), this is shared mem bound + S_shared_0 = K_tail_shared_0 + S_shared_1 = K_tail_shared_1 + + # WG0 and WG1 exchange local max with each other, compare to compute global max, and rescale their O_L or O_R accordingly + row_max_shared_0 = T.alloc_shared([H_per_block], accum_dtype) + row_max_shared_1 = T.alloc_shared([H_per_block], accum_dtype) + + # Used to store sum of exps for even BI and odd BI respectively, which will be summed up for integration later + row_sum_shared_0 = T.alloc_shared([H_per_block], accum_dtype) + row_sum_shared_1 = T.alloc_shared([H_per_block], accum_dtype) + + # acc_s, sumexp, m_i each need to be allocated separately for consumer0 and consumer1 + acc_s_0 = T.alloc_fragment([H_per_block, BI], accum_dtype) + acc_s_1 = T.alloc_fragment([H_per_block, BI], accum_dtype) + + sumexp_0 = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev_0 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_peer_0 = T.alloc_fragment([H_per_block], accum_dtype) + + sumexp_1 = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev_1 = T.alloc_fragment([H_per_block], accum_dtype) + m_i_peer_1 = T.alloc_fragment([H_per_block], accum_dtype) + + bar_q = T.alloc_barrier(arrive_count=384) + + # Producer -> Consumer Barriers + bar_k_0_ready = T.alloc_barrier(arrive_count=128) # Prod arrives + bar_k_1_ready = T.alloc_barrier(arrive_count=128) # Prod arrives + + # Consumer -> Producer Barriers (Both consumers must arrive) + bar_k_0_free = T.alloc_barrier(arrive_count=256) + bar_k_1_free = T.alloc_barrier(arrive_count=256) + + # Inter-Consumer Barriers (Seesaw Sync) + bar_stats_0_ready = T.alloc_barrier(arrive_count=128) # Cons 0 arrives + bar_stats_1_ready = T.alloc_barrier(arrive_count=128) # Cons 1 arrives + + bar_S_0_ready = T.alloc_barrier(arrive_count=128) # Cons 0 arrives + bar_S_1_ready = T.alloc_barrier(arrive_count=128) # Cons 
1 arrives + + b_i, g_i = by, bz + # If it's the first chunk, start computing directly from the (kv_stride - 1)-th token + s_i = (bx + (KV_stride - 1 if CP0 else 0)) if REPLICATE_H == 1 else (bx // REPLICATE_H + (KV_stride - 1 if CP0 else 0)) + q_i = q_start_index_s[0] + s_i + # Sometimes to reduce kvcache size, we may not store KV for every token, but store + # KV every KV_stride tokens (usually the last token in the stride window), + # so the kv range visible to the current query should be [0:max_kv_i] + max_kv_i = (q_i + 1 - KV_stride) // KV_stride + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + tx = T.get_thread_binding() + + T.copy(Q[b_i, s_i, H0:H1, 0 : D // 2], Q_shared_l) + T.copy(Q[b_i, s_i, H0:H1, D // 2 : D], Q_shared_r) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + # Non-blockingly increment the barrier's internal counter, producer threads can start loading kv ahead of time + T.barrier_arrive(bar_q) + + if tx >= 256: + # producer: prefetch kvcache to shared mem + T.set_max_nreg(72, 0) + + prefetch_indices_0 = T.alloc_fragment([4], indices_dtype) + prefetch_indices_1 = T.alloc_fragment([4], indices_dtype) + + # Prime the Pump! Prefetch indices for iter_0 + for r in T.serial(4): + # This read will cause a long scoreboard stall, but it only happens once before the loop starts + prefetch_indices_0[r] = Indices[b_i, s_i, g_i, r * 16 + (tx - 256) // 8] + prefetch_indices_1[r] = Indices[b_i, s_i, g_i, BI + r * 16 + (tx - 256) // 8] + + for i_i in T.serial(T.ceildiv(NI, 2)): + # Buffer 0 + # Wait for both KV_shared_0_l and KV_shared_0_r to be done being used + + T.barrier_wait(bar_k_0_free[0], (i_i & 1)) + + # Block size `BI` is 64, loading is divided into 4 iterations, each processing 16 indices + # Producer has 128 threads total, 8 consecutive threads collaborate to load kv for one index + for r in T.serial(4): + # mitigate long scoreboard stall here + index = prefetch_indices_0[r] + is_kv_valid[0, r * 16 + (tx - 256) // 8] = index <= max_kv_i + if is_kv_valid[0, r * 16 + (tx - 256) // 8]: + # 8 threads collaborate to load one row of KV_dim (512) in 4 iters, each loading 8 elems + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_0_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_0_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + # tail_dim (64) needs only one iter of 8 elems per 8 collaborating threads + T.ptx_cp_async( + T.access_ptr(K_tail_shared_0[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, index, g_i, D + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.cp_async_barrier_noinc(bar_k_0_ready[0]) + + if i_i + 1 < T.ceildiv(NI, 2): + # Async prefetch indices needed for the next round of kv data loading, overlaps with current round to hide latency + for r in T.serial(4): + prefetch_indices_0[r] = Indices[b_i, s_i, g_i, ((i_i + 1) * 2) * BI + r * 16 + (tx - 256) // 8] + + # Buffer 1 + T.barrier_wait(bar_k_1_free[0], (i_i & 1)) + + for r in T.serial(4): + index = prefetch_indices_1[r] + is_kv_valid[1, r * 16 + (tx - 256) // 8] = index <= max_kv_i + if is_kv_valid[1, r * 16 + (tx - 256) // 8]: + for u in T.serial(4): + T.ptx_cp_async( + T.access_ptr(KV_shared_1_l[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + 
T.access_ptr(KV[b_i, index, g_i, 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(KV_shared_1_r[r * 16 + (tx - 256) // 8, 64 * u + (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, index, g_i, D // 2 + 64 * u + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.ptx_cp_async( + T.access_ptr(K_tail_shared_1[r * 16 + (tx - 256) // 8, (tx - 256) % 8 * 8], "w", 8), + T.access_ptr(KV[b_i, index, g_i, D + (tx - 256) % 8 * 8], "r", 8), + 8, + ) + T.cp_async_barrier_noinc(bar_k_1_ready[0]) + + if i_i + 1 < T.ceildiv(NI, 2): + for r in T.serial(4): + prefetch_indices_1[r] = Indices[b_i, s_i, g_i, ((i_i + 1) * 2 + 1) * BI + r * 16 + (tx - 256) // 8] + + elif tx < 128: + # Check if 384 threads have already arrived at bar_q (phase0 completed), + # if not continue waiting, otherwise pass through directly + T.barrier_wait(bar_q, 0) + + # pre-arrive free barriers to indicate buffers are initially free + # At the beginning of phase0, tells producer it can load data into both buffers + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_k_1_free[0]) + + # Consumer 0 (WG0): Responsible for Even Blocks and O_L (Left Half) + T.set_max_nreg(216, 1) + T.fill(sumexp_0, 0) + for h_i in T.Parallel(H_per_block): + m_i_0[h_i] = -5e4 + T.fill(acc_o_l, 0) + + # Each iteration, two consumers cooperate to compute two BIs + for i_i in T.serial(T.ceildiv(NI, 2)): + # --- Step 1: Compute S0 = Q @ K0^T (Even Block) --- + T.barrier_wait(bar_k_0_ready[0], (i_i & 1)) + + T.fill(acc_s_0, 0) + T.wgmma_gemm(Q_shared_l, KV_shared_0_l, acc_s_0, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_0_r, acc_s_0, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_0, acc_s_0, transpose_B=True) + + T.copy(m_i_0, m_i_prev_0) + T.wait_wgmma(0) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + if not is_kv_valid[0, bi_i]: + acc_s_0[h_i, bi_i] = -5e4 + T.reduce_max(acc_s_0, m_i_0, dim=1, clear=False) + + # --- Step 2: Local Softmax Stats & Exchange --- + T.copy(m_i_0, row_max_shared_0) + T.barrier_arrive(bar_stats_0_ready) + # If consumer0 has received the local max from consumer1 at iter_i, this also means + # consumer1 has finished using S_0 passed by consumer0 at iter_i-1, + # so we can write to it directly without blocking below + T.barrier_wait(bar_stats_1_ready, (i_i & 1)) + T.copy(row_max_shared_1, m_i_peer_0) + + # Update global max and scale O + for h_i in T.Parallel(H_per_block): + m_i_0[h_i] = T.max(m_i_0[h_i], m_i_peer_0[h_i]) + + # Scale O_L + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] *= T.exp2((m_i_prev_0[h_i] - m_i_0[h_i]) * sm_scale) + + # Scale SumExp + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] *= T.exp2((m_i_prev_0[h_i] - m_i_0[h_i]) * sm_scale) + + # Compute P0 = exp(S0 - m_new) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s_0[h_i, bi_i] = T.exp2(acc_s_0[h_i, bi_i] * sm_scale - m_i_0[h_i] * sm_scale) + + # Update SumExp with P0 + T.reduce_sum(acc_s_0, sumexp_i_0, dim=1) + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] += sumexp_i_0[h_i] + + # --- Step 3: O_L += P0 @ V0_L (Self-Attention) --- + # Wait for S0 buffer to be free (consumed by peer in prev iter) + # T.barrier_wait(bar_S_0_free, (i_i & 1)) + T.copy(acc_s_0, S_shared_0) + T.barrier_arrive(bar_S_0_ready) + + T.wgmma_gemm(S_shared_0, KV_shared_0_l, acc_o_l, transpose_B=False) + + # --- Step 4: O_L += P1 @ V1_L (Cross-Attention) --- + # Wait for P1 (S1) from peer + T.barrier_wait(bar_S_1_ready, (i_i & 1)) + + T.wgmma_gemm(S_shared_1, KV_shared_1_l, acc_o_l, 
transpose_B=False) + + # NOTE: However, k_0 and k_1 are used by both consumer0 and consumer1, so this doesn't bring much performance improvement + # Except for the most recent async gemm (i.e., S_shared_1 @ KV_shared_1_k), all others need to wait to finish + T.wait_wgmma(1) + T.barrier_arrive(bar_k_0_free[0]) + # Wait for all async gemms to finish + T.wait_wgmma(0) + T.barrier_arrive(bar_k_1_free[0]) + + T.copy(sumexp_0, row_sum_shared_0) + T.barrier_arrive(bar_stats_0_ready) # Reuse barrier + T.barrier_wait(bar_stats_1_ready, T.ceildiv(NI, 2) & 1) + T.copy(row_sum_shared_1, sumexp_i_0) # Reuse sumexp_i buffer + + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] += sumexp_i_0[h_i] + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_l[h_i, d_i] /= sumexp_0[h_i] + + for h_i in T.Parallel(H_per_block): + sumexp_0[h_i] = T.log2(sumexp_0[h_i]) + m_i_0[h_i] * sm_scale + + T.copy(acc_o_l, O_shared_l) + T.copy(O_shared_l, Output[b_i, s_i, H0:H1, 0 : D // 2]) + T.copy(sumexp_0, Lse[b_i, s_i, H0:H1]) # Write LSE + + elif tx >= 128 and tx < 256: + T.barrier_wait(bar_q, 0) + + # pre-arrive free barriers to indicate buffers are initially free + # At the beginning of phase0, tells producer it can load data into both buffers + T.barrier_arrive(bar_k_0_free[0]) + T.barrier_arrive(bar_k_1_free[0]) + + # Consumer 1 (WG1): Responsible for Odd Blocks and O_R (Right Half) + # NOTE: 256 * 216 + 128 * 72 = 64,512 < 65536 (H100 SM RegFile Limit), + # setting more registers will cause a hang, all values must be multiples of 8 + T.set_max_nreg(216, 1) + T.fill(sumexp_1, 0) + for h_i in T.Parallel(H_per_block): + m_i_1[h_i] = -5e4 + T.fill(acc_o_r, 0) + + for i_i in T.serial(T.ceildiv(NI, 2)): + # --- Step 1: Compute S1 = Q @ K1^T (Odd Block) --- + T.barrier_wait(bar_k_1_ready[0], (i_i & 1)) + + T.fill(acc_s_1, 0) + T.wgmma_gemm(Q_shared_l, KV_shared_1_l, acc_s_1, transpose_B=True) + T.wgmma_gemm(Q_shared_r, KV_shared_1_r, acc_s_1, transpose_B=True) + T.wgmma_gemm(Q_tail_shared, K_tail_shared_1, acc_s_1, transpose_B=True) + + # --- Step 2: Local Softmax Stats & Exchange --- + T.copy(m_i_1, m_i_prev_1) + T.wait_wgmma(0) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + if not is_kv_valid[1, bi_i]: + acc_s_1[h_i, bi_i] = -5e4 + + T.reduce_max(acc_s_1, m_i_1, dim=1, clear=False) + T.copy(m_i_1, row_max_shared_1) + T.barrier_arrive(bar_stats_1_ready) + T.barrier_wait(bar_stats_0_ready, (i_i & 1)) + T.copy(row_max_shared_0, m_i_peer_1) + + for h_i in T.Parallel(H_per_block): + m_i_1[h_i] = T.max(m_i_1[h_i], m_i_peer_1[h_i]) + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] *= T.exp2((m_i_prev_1[h_i] - m_i_1[h_i]) * sm_scale) + + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] *= T.exp2((m_i_prev_1[h_i] - m_i_1[h_i]) * sm_scale) + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s_1[h_i, bi_i] = T.exp2(acc_s_1[h_i, bi_i] * sm_scale - m_i_1[h_i] * sm_scale) + + T.reduce_sum(acc_s_1, sumexp_i_1, dim=1) + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] += sumexp_i_1[h_i] + + # --- Step 3: O_R += P1 @ V1_R (Self-Attention) --- + T.copy(acc_s_1, S_shared_1) + + T.barrier_arrive(bar_S_1_ready) + + T.wgmma_gemm(S_shared_1, KV_shared_1_r, acc_o_r, transpose_B=False) + + # --- Step 4: O_R += P0 @ V0_R (Cross-Attention) --- + T.barrier_wait(bar_S_0_ready, (i_i & 1)) + + T.wgmma_gemm(S_shared_0, KV_shared_0_r, acc_o_r, transpose_B=False) + + T.wait_wgmma(1) + T.barrier_arrive(bar_k_1_free[0]) + T.wait_wgmma(0) + T.barrier_arrive(bar_k_0_free[0]) + + T.copy(sumexp_1, row_sum_shared_1) 
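+                    # Final seesaw exchange: each consumer accumulated exp-sums only for its
+                    # half of the index blocks, so the two workgroups now swap row sums through
+                    # shared memory and add them, leaving both sides with the full softmax
+                    # denominator.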
+ T.barrier_arrive(bar_stats_1_ready) + T.barrier_wait(bar_stats_0_ready, T.ceildiv(NI, 2) & 1) + T.copy(row_sum_shared_0, sumexp_i_1) + + for h_i in T.Parallel(H_per_block): + sumexp_1[h_i] += sumexp_i_1[h_i] + + for h_i, d_i in T.Parallel(H_per_block, D // 2): + acc_o_r[h_i, d_i] /= sumexp_1[h_i] + + T.copy(acc_o_r, O_shared_r) + T.copy(O_shared_r, Output[b_i, s_i, H0:H1, D // 2 : D]) + + return main + + +def sparse_mla_fwd_interface( + q, kv, indices, q_start_index_s, kv_stride, sm_scale=None, is_casual=True, return_kernel=False, print_kernel=False +): + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = 512 + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + if q_start_index_s != 0: + assert q_start_index_s > kv_stride, ( + "If it is because each cp has too short length, you should fix the logic involving CP0 (cp_rank == 0), to make sure q with pos < KV_Stride - 1 is masked (or you may just ignore how this is handled if nan in these q's Out would not effect others, which is reported to be likely to happen by wangding)" + ) + CP0 = q_start_index_s == 0 + + # Compile the kernel + kernel = sparse_mla_fwd(batch, seq_len, seq_len_kv, heads, dim, tail_dim, topk, kv_stride, kv_group, sm_scale, is_casual, CP0) + + if print_kernel: + print(kernel.get_kernel_source()) + + if return_kernel: + return kernel + + ( + out, + lse, + ) = kernel(q, kv, indices, torch.tensor([q_start_index_s], dtype=torch.int32, device="cuda")) + if q_start_index_s == 0 and kv_stride > 1: + # Set the output of the first (kv_stride - 1) positions to 0, since they cannot see any kv so no computation was performed + out[:, : kv_stride - 1, :, :] = 0 + return out, lse + + +def ref_sparse_mla_fwd_interface(q, kv, indices, q_start_index_s, kv_stride=1, sm_scale=None, is_casual=True): + q = q.float() + kv = kv.float() + indices = indices.transpose(1, 2) + b, sq, h, dim_q = q.shape + b, sk, g, _ = kv.shape + if q_start_index_s is None: + q_start_index_s = sk * kv_stride - sq + + assert kv.shape[-1] == 576, "you should assign dim otherwise" + dim = 512 + k = kv + v = kv[..., :dim] + + b, _, _, dim_v = v.shape + num_kv_per_index = 1 + g_index = g + h_index = h // g + compressed_casual_mask = torch.arange(q_start_index_s, sq + q_start_index_s, dtype=torch.int32, device="cuda").view( + -1, 1 + ) >= torch.arange(kv_stride - 1, sk * kv_stride, kv_stride, dtype=torch.int32, device="cuda").view(1, -1) + + mask = q.new_zeros(b, g_index, sq, sk + 1, dtype=torch.bool).scatter(3, indices.long(), 1) + mask = mask[..., :-1] + mask = mask & compressed_casual_mask.view(1, 1, sq, sk) + mask[:, :, : kv_stride - 1, 0] = True + mask = mask.view(b, g_index, 1, sq, sk) + + q = q.view(b, sq, g, -1, dim_q) + score = torch.einsum("bmghd,bngd->bghmn", q, k) + sm_scale = dim_q**-0.5 if sm_scale is None else sm_scale + score = score.masked_fill(~mask, float("-inf")).mul(sm_scale) + p = score.softmax(dim=-1) + p = p.view(b, g_index, h_index, -1, sq, sk) + p = p.view(b, g, -1, sq, sk) + o = torch.einsum("bghmn,bngd->bmghd", p.type(v.dtype), v) + o = o.reshape(b, sq, h, dim_v) + return o.to(torch.bfloat16) + + +def test_sparse_mla_fwd_pipelined( + B=1, + S=4096, + SKV=8192, + H=128, + HKV=1, + DQK=576, + DV=512, + 
topk=2048, + dtype=torch.bfloat16, + # Offset of query in global sequence position (or relative to kv) + q_start_s_index=2048, + check_correctness=True, + profile=False, +): + KV_stride = 1 + + torch.random.manual_seed(0) + q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True) / 10 + q_start_s_index_t = torch.tensor([q_start_s_index], dtype=torch.int32, device="cuda") + + q.clamp_(-10, 10) + kv.clamp_(-10, 10) + + indices = torch.full((B, S, HKV, topk), SKV, dtype=torch.int32, device="cuda") + for b in range(B): + for t in range(S): + for h in range(HKV): + # Add offset q_start_s_index to convert to global sequence position + i_i = torch.randperm(min(max(1, ((t + q_start_s_index) // KV_stride)), SKV))[:topk] + indices[b, t, h, : len(i_i)] = i_i + + print("index generation finished") + + kernel = sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride, return_kernel=True, print_kernel=True) + + def fn(): + return kernel(q, kv, indices, q_start_s_index_t) + + if check_correctness: + tl_out, tl_lse = fn() + assert KV_stride == 1, "KV_stride > 1 not supported" + # if q_start_s_index == 0 and KV_stride > 1: + # tl_out[:, :KV_stride - 1, :, :] = 0 + ref_out = ref_sparse_mla_fwd_interface(q, kv, indices, q_start_s_index, KV_stride) + print(f"tl_out: {tl_out}") + print(f"ref_out: {ref_out}") + torch.testing.assert_close(tl_out, ref_out, rtol=1e-3, atol=1e-3) + + if profile: + print("Profiling mode: running minimal iterations (1 warmup + 1 run)...") + fn() + torch.cuda.synchronize() + fn() + torch.cuda.synchronize() + return + + from tilelang.profiler import do_bench + + ms = do_bench( + fn, + rep=20, + warmup=10, + ) + print(f"Average time: {ms:.3f} ms") + print(f"fwd io bandwidth = ", (B * S * DQK * topk * 2) / (ms * 1e-3) / 1e12) + tflops = (B * S * (DQK + DV) * topk * 2 * H) / (ms * 1e-3) / 1e12 + print(f"fwd tflops = {tflops:.2f}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--test_correctness", action="store_true") + parser.add_argument("--profile", action="store_true") + args = parser.parse_args() + if args.test_correctness: + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 1, 1024, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined(B, S, SKV, H, HKV, DQK, DV, topk, dtype, check_correctness=True, profile=args.profile) + else: + # Prefill Benchmark: long context + print(" --- Prefill Benchmark --- ") + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 2, 4096, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined( + B, S, SKV, H, HKV, DQK, DV, topk, dtype, q_start_s_index=4096, check_correctness=False, profile=args.profile + ) + + # Decode Benchmark: large batch size, high throughput generation + print("\n --- Decode Benchmark --- ") + # Increase batch size to saturate h100 for decode + B, S, SKV, H, HKV, DQK, DV, topk, dtype = 128 * 16, 2, 8192, 128, 1, 576, 512, 2048, torch.bfloat16 + test_sparse_mla_fwd_pipelined( + B, S, SKV, H, HKV, DQK, DV, topk, dtype, q_start_s_index=2048 + 4096, check_correctness=False, profile=args.profile + ) diff --git a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py index 983798f9f0..9e4c6a63d9 100644 --- a/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py +++ b/examples/deepseek_v32/test_tilelang_example_deepseek_v32.py @@ -9,30 +9,34 @@ import sparse_mla_bwd 
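# The tests below pin the compute capability to exactly 9.0: these kernels now
# rely on Hopper-specific paths (wgmma + cp.async mbarriers), which presumably
# do not carry over to newer architectures, so a `ge` check would over-match.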
+@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_topk_selector(): topk_selector.test_topk_selector() +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_fp8_lighting_indexer(): fp8_lighting_indexer.test_fp8_lighting_indexer(S=512, SKV=1024, H=32, HKV=1, D=64, kv_stride=1) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_sparse_mla_fwd(): # small shapes for testing sparse_mla_fwd.test_sparse_mla_fwd(S=256, SKV=1024, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_sparse_mla_fwd_pipelined(): # small shapes for testing sparse_mla_fwd_pipelined.test_sparse_mla_fwd_pipelined(S=256, SKV=512, H=64, HKV=1, DQK=576, DV=512, topk=256, check_correctness=False) @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_sparse_mla_bwd(): sparse_mla_bwd.test_sparse_mla_bwd(S=256, SKV=512, H=64, HKV=1, DQKV=576, DV=512, topk=256, check_correctness=False) sparse_mla_bwd.test_sparse_mla_bwd( diff --git a/examples/deepseek_v32/topk_selector.py b/examples/deepseek_v32/topk_selector.py index 078eb26868..8b29c6fd5e 100644 --- a/examples/deepseek_v32/topk_selector.py +++ b/examples/deepseek_v32/topk_selector.py @@ -8,18 +8,18 @@ def convert_to_uint16(x): - hval = T.Cast(T.float16, x) - bits_uint = T.reinterpret(T.uint16, hval) + hval = T.cast(x, T.float16) + bits_uint = T.reinterpret(hval, T.uint16) bits_uint = T.if_then_else(x < 0, ~bits_uint & (0xFFFF), bits_uint | (0x8000)) return bits_uint >> 8 def convert_to_uint32(x): - bits_uint = T.reinterpret(T.uint32, x) + bits_uint = T.reinterpret(x, T.uint32) bits_uint = T.if_then_else( x < 0, - ~bits_uint & T.Cast(T.uint32, (0xFFFFFFFF)), - bits_uint | T.Cast(T.uint32, (0x80000000)), + ~bits_uint & T.cast((0xFFFFFFFF), T.uint32), + bits_uint | T.cast((0x80000000), T.uint32), ) return bits_uint @@ -57,6 +57,8 @@ def tl_topk_kernel( l_end_idx = T.alloc_var(T.int32) l_out_pos = T.alloc_var(T.int32) + pos = T.alloc_var(T.int32) + l_new_topk = topk l_start_idx = starts[bx] l_end_idx = ends[bx] @@ -99,7 +101,7 @@ def tl_topk_kernel( input_idx = s * BLOCK_SIZE + tx if input_idx < l_end_idx and input_idx >= l_start_idx and input_idx < seq_len: bin_id = convert_to_uint16(input[bx, input_idx]) - l_bin_id32 = T.Cast(T.int32, bin_id) + l_bin_id32 = T.cast(bin_id, T.int32) if l_bin_id32 > l_threshold_bin_id: # need a pos = T.atomic_add(s_histogram[bin_id32+1], 1) pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) @@ -113,7 +115,7 @@ def tl_topk_kernel( # stage 2: tail pass for round in T.serial(4): if l_new_topk <= 0: - T.loop_break() + break r_idx = round % 2 l_start_pos = topk - l_new_topk @@ -127,8 +129,8 @@ def tl_topk_kernel( l_num_input = s_num_input[r_idx] for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + l_bin_id32 = T.cast( + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF), T.int32 ) 
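+                    # Each tail-pass round keys the histogram on successive 8-bit digits of the
+                    # order-preserving uint32 encoding, from the most significant byte down
+                    # (shift 24 - round * 8), so candidates that tied in the coarse fp16 bin
+                    # are resolved exactly within at most four rounds.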
T.atomic_add(s_histogram[l_bin_id32], 1) T.sync_threads() @@ -156,8 +158,8 @@ def tl_topk_kernel( for s in T.serial(T.ceildiv(l_num_input, BLOCK_SIZE)): T.sync_threads() if s * BLOCK_SIZE + tx < l_num_input: - l_bin_id32 = T.Cast( - T.int32, ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF) + l_bin_id32 = T.cast( + ((convert_to_uint32(input[bx, s_input_idx[r_idx, s * BLOCK_SIZE + tx]]) >> (24 - round * 8)) & 0xFF), T.int32 ) if l_bin_id32 > l_threshold_bin_id: pos = T.atomic_add(s_histogram[l_bin_id32 + 1], 1, return_prev=True) + l_start_pos @@ -183,9 +185,6 @@ def tl_topk(input, starts, ends, topk): def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): - batch = 64 - seq_len = 32 * 1024 - topk = 2048 torch.manual_seed(1) input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() starts = torch.zeros(batch, dtype=torch.int32).cuda() @@ -241,27 +240,11 @@ def test_topk_selector(batch=64, seq_len=32 * 1024, topk=2048): def run_regression_perf(batch=64, seq_len=32 * 1024, topk=2048): - batch = 64 - seq_len = 32 * 1024 - topk = 2048 torch.manual_seed(1) input = torch.randn(batch, seq_len, dtype=torch.float32).cuda() starts = torch.zeros(batch, dtype=torch.int32).cuda() ends = torch.ones(batch, dtype=torch.int32).cuda() * seq_len - indexes = tl_topk(input, starts, ends, topk) - - indexes_ref = torch.topk(input, topk, dim=-1)[1] - - for i in range(batch): - ref_np = indexes_ref[i].cpu().to(torch.int32).numpy() - trt_np = indexes[i].cpu().to(torch.int32).numpy() - - set_ref = set(ref_np) - set_trt = set(trt_np) - intersection = set_ref & set_trt - print("selected/all:", len(intersection), "/", len(set_ref), "=", len(intersection) / len(set_ref)) - from tilelang.profiler import do_bench def run_kernel_only(): diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py index 36b32c0a8a..2ae9bdf3eb 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py @@ -41,7 +41,7 @@ def get_configs(): ) @tilelang.jit( out_idx=[-1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def matmul( M, @@ -180,8 +180,8 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared): # Then, dequant. T.call_extern( func_name, - T.address_of(B_local_thread[0]), - T.address_of(B_dequantize_local_thread[0]), + T.access_ptr(B_local_thread, "r"), + T.access_ptr(B_dequantize_local_thread, "w"), 1, dtype=out_dtype, ) diff --git a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py index cc37c8bc42..0842e16856 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py +++ b/examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py @@ -246,8 +246,8 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k): # Then, dequant. 
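            # T.access_ptr replaces T.address_of below: it yields the same base pointer
            # but additionally records the access intent ("r" read / "w" write), giving
            # later passes visibility into which buffer the extern dequant routine mutates.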
T.call_extern( func_name, - T.address_of(B_local_thread[0]), - T.address_of(B_dequantize_local_thread[0]), + T.access_ptr(B_local_thread, "r"), + T.access_ptr(B_dequantize_local_thread, "w"), 1, dtype=out_dtype, ) diff --git a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py index 37826874bc..a870208083 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py +++ b/examples/dequantize_gemm/example_dequant_gemm_fine_grained.py @@ -302,8 +302,16 @@ def main( T.call_extern( "handle", "decode_i4u_to_f16", - T.address_of(B_local[j * local_size_b // num_elems_per_byte]), - T.address_of(B_dequantize_local[j * local_size_b]), + T.access_ptr( + B_local[j * local_size_b // num_elems_per_byte], + "r", + local_size_b // num_elems_per_byte, + ), + T.access_ptr( + B_dequantize_local[j * local_size_b], + "w", + local_size_b, + ), 8, ) diff --git a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py index b1f8b11328..2db3cd61a9 100644 --- a/examples/dequantize_gemm/example_dequant_gemm_w4a8.py +++ b/examples/dequantize_gemm/example_dequant_gemm_w4a8.py @@ -202,4 +202,3 @@ def run_regression_perf(m=4096, n=4096, k=4096): M, N, K = args.m, args.n, args.k main(M, N, K, args.tune) - # main(M, N, K, True) diff --git a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py index 43e97f9309..b67d8165b4 100644 --- a/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py +++ b/examples/dequantize_gemm/example_dequant_gemv_fp16xint4.py @@ -113,8 +113,8 @@ def main( if fast_decoding: T.call_extern( func_name, - T.address_of(B_quant_local[0]), - T.address_of(B_dequantize_local[0]), + T.access_ptr(B_quant_local, "r"), + T.access_ptr(B_dequantize_local, "w"), dtype=in_dtype, ) else: @@ -135,7 +135,7 @@ def main( accum_res[0] += A_local[ki] * B_dequantize_local[ki] with T.attr( - T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]), + T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]), "reduce_scope", T.reinterpret(T.uint64(0), dtype="handle"), ): diff --git a/examples/dequantize_gemm/regression_example_dequantize_gemm.py b/examples/dequantize_gemm/regression_example_dequantize_gemm.py index 4ab03784ff..51b7c53e00 100644 --- a/examples/dequantize_gemm/regression_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/regression_example_dequantize_gemm.py @@ -4,7 +4,6 @@ import example_dequant_gemm_fp4_hopper import example_dequant_gemm_w4a8 import example_dequant_gemv_fp16xint4 -import example_dequant_groupedgemm_bf16_mxfp4_hopper def regression_example_dequant_gemv_fp16xint4(): @@ -23,10 +22,6 @@ def regression_example_dequant_gemm_bf16_mxfp4_hopper(): tilelang.testing.process_func(example_dequant_gemm_bf16_mxfp4_hopper.run_regression_perf) -def regression_example_dequant_groupedgemm_bf16_mxfp4_hopper(): - tilelang.testing.process_func(example_dequant_groupedgemm_bf16_mxfp4_hopper.run_regression_perf) - - def regression_example_dequant_gemm_w4a8(): tilelang.testing.process_func(example_dequant_gemm_w4a8.run_regression_perf) diff --git a/examples/dequantize_gemm/test_example_dequantize_gemm.py b/examples/dequantize_gemm/test_example_dequantize_gemm.py index a2f777222b..021402a363 100644 --- a/examples/dequantize_gemm/test_example_dequantize_gemm.py +++ b/examples/dequantize_gemm/test_example_dequantize_gemm.py @@ -3,7 +3,6 @@ import example_dequant_gemv_fp16xint4 import 
example_dequant_gemm_fp4_hopper import example_dequant_gemm_bf16_mxfp4_hopper -import example_dequant_groupedgemm_bf16_mxfp4_hopper import example_dequant_gemm_w4a8 @@ -13,25 +12,19 @@ def test_example_dequant_gemv_fp16xint4(): @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_dequant_gemm_fp4_hopper(): example_dequant_gemm_fp4_hopper.main() @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_dequant_gemm_bf16_mxfp4_hopper(): example_dequant_gemm_bf16_mxfp4_hopper.main() @tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_dequant_groupedgemm_bf16_mxfp4_hopper(): - example_dequant_groupedgemm_bf16_mxfp4_hopper.main() - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_dequant_gemm_w4a8(): example_dequant_gemm_w4a8.main() diff --git a/examples/distributed/example_allgather_gemm_specialized.py b/examples/distributed/example_allgather_gemm_specialized.py new file mode 100644 index 0000000000..17a6a6a6e8 --- /dev/null +++ b/examples/distributed/example_allgather_gemm_specialized.py @@ -0,0 +1,243 @@ +import os +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing + +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +from tilelang.distributed import init_dist +from tilelang.distributed import perf_fn +from tilelang.utils.allocator import get_allocator + +tilelang.disable_cache() +os.environ["NCCL_DEBUG"] = "WARN" + + +@tilelang.jit +def ag_gemm_sm_specialized_kernel( + M, + N, + K, + num_ranks, + num_comm_sms: int, + block_M: int, + block_N: int, + block_K: int, + threads: int, + dtype: str = "float16", +): + sm_num = driver.get_num_sms() + num_comp_sms = sm_num - num_comm_sms + M_per_rank = M // num_ranks + N_per_rank = N // num_ranks + m_blocks = T.ceildiv(M, block_M) + n_blocks = T.ceildiv(N_per_rank, block_N) + local_m_blocks = T.ceildiv(M_per_rank, block_M) + k_blocks = T.ceildiv(K, block_K) + total_tiles = m_blocks * n_blocks + waves = T.ceildiv(total_tiles, num_comp_sms) + GROUP_SIZE_M = 8 + accum_dtype = "float" + + @T.prim_func + def main( + A_local: T.Tensor((M_per_rank, K), dtype), + B: T.Tensor((K, N_per_rank), dtype), + mcast_A: T.Tensor((M, K), dtype), + gathered_A: T.Tensor((M, K), dtype), + mcast_signal: T.Tensor((m_blocks,), "uint32"), + local_signal: T.Tensor((m_blocks,), "uint32"), + grid_barrier: T.Tensor((num_ranks,), "int32"), + C: T.Tensor((M, N_per_rank), dtype), + local_rank: T.int32, + ): + with T.Kernel(sm_num, threads=threads) as bid: + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_shared = T.alloc_shared((block_M, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + tid = T.get_thread_binding(0) + + if bid == 0: + for i in T.serial(T.ceildiv(m_blocks, threads)): + signal_idx = i * threads + tid + if signal_idx < m_blocks: + local_signal[signal_idx] = 0 + T.fence_sys() + T.barrier_blocks(grid_barrier) + + if bid < num_comp_sms: + for w in T.serial(waves): + tile_id = bid + w * num_comp_sms + if tile_id < total_tiles: + num_pid_in_group = GROUP_SIZE_M * n_blocks + group_id = tile_id // num_pid_in_group 
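+                        # Grouped tile ordering (the usual grouped-matmul rasterization): tiles
+                        # are walked in GROUP_SIZE_M-tall, column-major groups so that
+                        # consecutively scheduled tiles reuse the same A rows / B columns while
+                        # they are still hot in L2.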
+ first_pid_m = group_id * GROUP_SIZE_M + group_size_m = T.min(m_blocks - first_pid_m, GROUP_SIZE_M) + pid_m = first_pid_m + ((tile_id % num_pid_in_group) % group_size_m) + pid_n = (tile_id % num_pid_in_group) // group_size_m + + if tid == 0: + T.wait_eq(local_signal[pid_m], 1) + + T.clear(C_local) + for k in T.Pipelined(k_blocks, num_stages=3): + T.copy(gathered_A[pid_m * block_M, k * block_K], A_shared) + T.copy(B[k * block_K, pid_n * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + T.copy(C_local, C_shared) + T.copy(C_shared, C[pid_m * block_M, pid_n * block_N]) + else: + loaded = T.alloc_barrier([256]) + parity = 0 + comm_sm_id = bid - num_comp_sms + for local_m in T.serial(T.ceildiv(local_m_blocks, num_comm_sms)): + local_pid_m = comm_sm_id + local_m * num_comm_sms + if local_pid_m < local_m_blocks: + global_pid_m = local_rank * local_m_blocks + local_pid_m + for k in T.serial(k_blocks): + T.tma_load(A_local[local_pid_m * block_M, k * block_K], A_shared) + T.mbarrier_arrive(loaded) + T.mbarrier_wait_parity(loaded, parity) + parity = (parity + 1) % 2 + T.copy( + A_shared, + mcast_A[global_pid_m * block_M, k * block_K], + ) # TODO(wt): Change to canonical mcast tma store later + + T.fence_sys() + if tid == 0: + T.multimem_signal(mcast_signal[global_pid_m], 1) + + return main + + +def ag_gemm_op( + A, + B, + mcast_A, + gathered_A, + mcast_signal, + local_signal, + grid_barrier, + C, + kernel, + local_rank, +): + kernel(A, B, mcast_A, gathered_A, mcast_signal, local_signal, grid_barrier, C, local_rank) + return C + + +def torch_ag_gemm(group: torch.distributed.ProcessGroup, A: torch.Tensor, B: torch.Tensor, ag_out: torch.Tensor): + torch.distributed.all_gather_into_tensor(ag_out, A, group) + return torch.matmul(ag_out, B) + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + dtype = torch.float16 + M, N, K = args.M, args.N, args.K + block_M, block_N, block_K = args.block_m, args.block_n, args.block_k + threads = args.threads + num_comm_sms = args.num_comm_sms + + assert M % num_local_ranks == 0, "M must be divisible by num-processes" + assert N % num_local_ranks == 0, "N must be divisible by num-processes" + assert (M // num_local_ranks) % block_M == 0, "M_per_rank must be divisible by block_m" + assert (N // num_local_ranks) % block_N == 0, "N_per_rank must be divisible by block_n" + assert K % block_K == 0, "K must be divisible by block_k" + assert 0 < num_comm_sms < driver.get_num_sms(), "num_comm_sms must leave at least one compute SM" + + M_per_rank = M // num_local_ranks + N_per_rank = N // num_local_ranks + m_blocks = M // block_M + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + assert rank == local_rank and num_ranks == num_local_ranks, "only support single-node launch for now" + + dtype_bytes = torch.tensor([], dtype=dtype).element_size() + signal_bytes = torch.tensor([], dtype=torch.uint32).element_size() + # _allocate_mcast_tensor uses aligned bump allocation; keep room for padding + # between the gathered A buffer and the signal buffer. 
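+    # With the defaults (M=32768, K=2048, fp16, block_M=128) this is 128 MiB
+    # for the gathered A replica plus 1 KiB of signals; the extra 4096 bytes
+    # absorb the allocator's alignment padding.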
+ mcast_bytes = M * K * dtype_bytes + m_blocks * signal_bytes + 4096 + allocator = get_allocator( + size=2**30, + device=f"cuda:{local_rank}", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + use_vmm=True, + mcast_size=mcast_bytes, + ) + + kernel = ag_gemm_sm_specialized_kernel( + M, + N, + K, + num_local_ranks, + num_comm_sms, + block_M, + block_N, + block_K, + threads, + ) + kernel.initialize(allocator=allocator) + if local_rank == 0 and args.print_source: + print(kernel.get_kernel_source()) + + torch.manual_seed(42 + local_rank) + A = tilelang.tensor((M_per_rank, K), dtype, allocator=allocator).normal_() + B = tilelang.tensor((K, N_per_rank), dtype, allocator=allocator).normal_() + C = tilelang.tensor((M, N_per_rank), dtype, allocator=allocator) + grid_barrier = tilelang.tensor((num_local_ranks,), torch.int32, allocator=allocator).zero_() + + mcast_A_flat, gathered_A_flat = allocator._allocate_mcast_tensor((M * K,), dtype) + mcast_signal, local_signal = allocator._allocate_mcast_tensor((m_blocks,), torch.uint32) + mcast_A = mcast_A_flat.view(M, K) + gathered_A = gathered_A_flat.view(M, K) + + dist.barrier(group) + tilelang_C = ag_gemm_op(A, B, mcast_A, gathered_A, mcast_signal, local_signal, grid_barrier, C, kernel, local_rank) + torch.cuda.synchronize() + dist.barrier(group) + + torch_ag_buffer = torch.empty((M, K), dtype=dtype, device=f"cuda:{local_rank}") + torch_C = torch_ag_gemm(group, A, B, torch_ag_buffer) + + if torch.allclose(torch_C, tilelang_C, atol=1e-2, rtol=1e-2): + print(f"rank {local_rank} check passed.") + else: + max_diff = (torch_C - tilelang_C).abs().max().item() + print(f"rank {local_rank} check failed. max_diff={max_diff}") + + tl_t = perf_fn( + lambda: ag_gemm_op(A, B, mcast_A, gathered_A, mcast_signal, local_signal, grid_barrier, C, kernel, local_rank), + warmup=args.warmup, + rep=args.rep, + ) + print(f"rank {local_rank} tilelang specialized ag_gemm time: {tl_t:.2f} ms, TFLOPS: {2 * M * N * K / 1e9 / tl_t / num_local_ranks:.2f}") + + allocator.close() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=8) + parser.add_argument("--M", type=int, default=32768) + parser.add_argument("--N", type=int, default=16384) + parser.add_argument("--K", type=int, default=2048) + parser.add_argument("--block-m", type=int, default=128) + parser.add_argument("--block-n", type=int, default=256) + parser.add_argument("--block-k", type=int, default=64) + parser.add_argument("--threads", type=int, default=256) + parser.add_argument("--num-comm-sms", type=int, default=8) + parser.add_argument("--warmup", type=int, default=5) + parser.add_argument("--rep", type=int, default=10) + parser.add_argument("--print-source", action="store_true") + args = parser.parse_args() + + torch.multiprocessing.spawn(main, args=(args.num_processes, args), nprocs=args.num_processes, join=True) diff --git a/examples/distributed/example_multimem_allreduce.py b/examples/distributed/example_multimem_allreduce.py new file mode 100644 index 0000000000..9247d9e040 --- /dev/null +++ b/examples/distributed/example_multimem_allreduce.py @@ -0,0 +1,123 @@ +""" +Multimem allreduce example using NVSwitch multicast instructions. + +Multi-process multi-GPU: each process manages one GPU, multicast handle +shared via fabric handles through torch.distributed. 
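+
+Each rank writes its local contribution into its own physical mapping of the
+multicast buffer; a single multimem ld_reduce issued on the multicast address
+then returns the element-wise sum over all ranks' replicas.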
+ +Usage: + export TILESCALE_USE_VMM=1 + export NCCL_IB_DISABLE=1 + export TILELANG_USE_DISTRIBUTED=1 + python examples/distributed/example_multimem_allreduce.py [--num-processes 8] + +Requirements: + - NVSwitch with multicast support (H100/B200 DGX) +""" + +import os +import argparse + +import torch +import torch.distributed as dist +import torch.multiprocessing + +import tilelang +import tilelang.language as T +from tilelang.distributed import init_dist +from tilelang.utils.allocator import get_allocator + +tilelang.disable_cache() +os.environ["NCCL_DEBUG"] = "WARN" + + +def multimem_allreduce_kernel(N, block_N, threads): + @T.prim_func + def main( + mcast_buf: T.Tensor((N,), "float32"), + result: T.Tensor((N,), "float32"), + ): + with T.Kernel(T.ceildiv(N, block_N), threads=threads) as (bx,): + result_local = T.alloc_fragment([block_N], "float32") + T.multimem_ld_reduce( + mcast_buf[bx * block_N : (bx + 1) * block_N], + result_local, + reduce_op=T.MultimemReduceOp.ADD, + ) + T.copy(result_local, result[bx * block_N : (bx + 1) * block_N]) + + return main + + +def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace): + N = args.N + BLOCK_N = args.block_n + threads = args.threads + + rank, num_ranks, group = init_dist(local_rank, num_local_ranks) + + # Create allocator with integrated multicast buffer + allocator = get_allocator( + size=N * 4, # float32 = 4 bytes + device=f"cuda:{local_rank}", + is_distributed=True, + local_rank=local_rank, + num_local_ranks=num_local_ranks, + group=group, + mcast_size=N * 4, + ) + + # Compile kernel + kernel = tilelang.compile( + multimem_allreduce_kernel(N, BLOCK_N, threads), + pass_configs={"tl.disable_tma_lower": True}, + ) + if local_rank == 0 and args.print_source: + print(kernel.get_kernel_source()) + + # Random input per rank + torch.manual_seed(42 + local_rank) + local_data = torch.randn(N, dtype=torch.float32, device=f"cuda:{local_rank}") + + # Allocate from multicast buffer + # mcast_tensor: MC VA for multimem instructions (read) + # local_tensor: physical VA for writing data + mcast_tensor, local_tensor = allocator._allocate_mcast_tensor((N,), torch.float32) + + # Write to physical memory (NOT the MC VA) + local_tensor.copy_(local_data) + torch.cuda.synchronize() + dist.barrier(group) + result = torch.empty(N, dtype=torch.float32, device=f"cuda:{local_rank}") + kernel(mcast_tensor, result) + torch.cuda.synchronize() + + # torch.distributed reference + expected = local_data.clone() + dist.all_reduce(expected, op=dist.ReduceOp.SUM, group=group) + + # Compare (fp32 should be exact or near-exact) + atol = 1e-5 + max_diff = (result - expected).abs().max().item() + passed = max_diff < atol + + if local_rank == 0: + print(f"N={N}, num_ranks={num_ranks}, max_diff={max_diff:.4f}, atol={atol}") + if passed: + print(f"[rank {local_rank}] PASSED") + else: + print(f"[rank {local_rank}] FAILED (max_diff={max_diff:.4f})") + + allocator.close() + dist.destroy_process_group() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--num-processes", type=int, default=8) + parser.add_argument("--N", type=int, default=65536) + parser.add_argument("--block_n", type=int, default=4096) + parser.add_argument("--threads", type=int, default=128) + parser.add_argument("--print_source", action="store_true") + args = parser.parse_args() + + torch.multiprocessing.spawn(main, args=(args.num_processes, args), nprocs=args.num_processes, join=True) diff --git a/examples/dsa_hisa/README.md b/examples/dsa_hisa/README.md new 
file mode 100644 index 0000000000..5eb0f72a8a --- /dev/null +++ b/examples/dsa_hisa/README.md @@ -0,0 +1,200 @@ +# tilelang_kernels — hisa prefill pipeline + +Tilelang prefill implementation of **hisa** (HIerarchical Sparse Attention). +Paper: . + +## What is HISA? + +HISA optimizes DeepSeek sparse attention by a plug-and-play replacement +for the indexer that rewrites the search path from a flat token scan into +a two-stage hierarchical procedure. + +**Stage 1 — coarse block-level selection.** Group K tokens into pool blocks +of `k_block_size` tokens, mean-pool each block, then score each query +against all pool blocks and pick the top `block_topk` blocks per query. + +**Stage 2 — fine-grained token-level scoring.** For each query, run a +full-resolution MQA over the raw tokens inside its selected blocks, then +pick the top `topk_tokens` tokens per query. + +## Files + +| file | step | role | +|---|---|---| +| `fp8_block_mean_pooling.py` | 1.1 | Mean-pool raw K into pool blocks (fp8 + per-block f32 scale) | +| `pool_mqa_fp8.py` | 1.2 | fp8×fp8 score `Q · pooled_K` → one logit per (query, pool block) | +| `clean_and_maintain_logits.py` | 1.3 | In-place mask on stage-1 logits: -inf outside per-query range, +inf at first/last valid block | +| `block_sparse_mqa_fp8.py` | 2.1 | fp8×fp8 fine-grained score over the raw tokens of the `block_topk` selected blocks | +| `hisa.py` | — | End-to-end orchestration: all four kernels + the two `torch.topk` steps + the index-translation post-processing | + +Each per-kernel file has one `test_*` entry that (a) runs the kernel + +torch ref, (b) asserts via `torch.testing.assert_close`, (c) prints the +latency of the kernel. `hisa.py` has `test_hisa` that runs the full +pipeline, checks the output-index mask invariant, and prints end-to-end +latency. + +## Per-kernel reference + +### 1.1 `fp8_block_mean_pooling.py` + +**Function**: `fp8_native_block_mean_pooling` + +**Meaning**: flat per-block mean of the chunk's K tokens, re-quantized to +fp8 with a per-block f32 scale. Groups `N` K tokens into +`ceildiv(N, k_block_size)` pool blocks. + +**Interface**: +```python +blocked_k, blocked_k_scale = fp8_native_block_mean_pooling_interface( + k, # [N, D] fp8 + k_scale, # [N] f32 — per-token scale from indexer_k_quant_and_cache + k_block_size, +) +# blocked_k: [num_blocks, D] fp8 +# blocked_k_scale: [num_blocks] f32 +``` + +**What it does**: per pool block `b` of size `kb = k_block_size`, +1. dequantize each of the `kb` tokens: `k_f[i] = k_fp8[i] * k_scale[i]` +2. average across the block in f32: `mean = sum_i k_f[i] / kb` (or the + actual valid count for the ragged tail block) +3. re-quantize the f32 mean to fp8 with a per-block scale + `block_scale = max(max_abs(mean) / 448, 1e-10)`, writing + `blocked_k[b] = fp8(mean / block_scale)` and `blocked_k_scale[b] = block_scale`. + +### 1.2 `pool_mqa_fp8.py` + +**Function**: `pool_mqa_attn_return_logits_fp8` + +**Meaning**: coarse-grained fp8 multi-query attention over the **pooled** K +(one vector per pool block). Produces one logit per (query, pool-block). 
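+
+For intuition, a dense torch sketch of this score (ignoring the per-query
+block range, which step 1.3 masks right afterwards) looks like:
+
+```python
+# Dense reference for illustration; the real kernel only writes each
+# tile's visible union and applies the per-block scale after the fp8 GEMM.
+s = torch.einsum("mhd,nd->mnh", q_fp8.float(),
+                 blocked_kv_fp8.float() * blocked_kv_scale[:, None])
+block_k_score = (s.clamp(min=0) * weights_f32[:, None, :]).sum(dim=-1)  # [M, Nb]
+```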
+ +**Interface**: +```python +block_k_score = pool_mqa_attn_return_logits_fp8_interface( + q_fp8, # [M, H, D] fp8 + blocked_kv_fp8, # [Nb, D] fp8 (from step 1.1) + blocked_kv_scale, # [Nb] f32 (from step 1.1) + weights_f32, # [M, H] f32 + cu_seqlen_blocked_ks, # [M] int32 — per-query start in pool-block coords + cu_seqlen_blocked_ke, # [M] int32 — per-query end in pool-block coords +) +# block_k_score: [M, Nb] f32 +``` + +**What it does**: for each query `m` and each pool block `n` in +`[cu_seqlen_blocked_ks[m], cu_seqlen_blocked_ke[m])`, +``` +block_k_score[m, n] = sum_h ReLU(q[m, h] · blocked_k[n]) * blocked_k_scale[n] * weights[m, h] +``` +Uses tile-level fp8×fp8→f32 Tensor Core GEMM; the per-block scale is +applied post-GEMM. The kernel processes queries in tiles of size +`block_Q × block_N` and **writes the union of the tile's queries' visible +K ranges** — entries outside an individual query's range inside that +union still carry raw dot-product values (they will be masked by +step 1.3 next). Entries outside the tile union are left at their +zero-init value. + +### 1.3 `clean_and_maintain_logits.py` + +**Function**: `clean_and_maintain_logits_` + +**Meaning**: in-place post-kernel mask on the stage-1 logits. + +**Interface**: +```python +clean_and_maintain_logits_interface( + logits, # [M, Nb] f32 — stage-1 output; modified in place + cu_seqlen_ks, # [M] int32 — per-row start (inclusive) + cu_seqlen_ke, # [M] int32 — per-row end (exclusive) +) +``` + +**What it does**: for each row `m`, +- positions outside `[cu_seqlen_ks[m], cu_seqlen_ke[m])` → set to `-inf` + (so `torch.topk` ignores them), +- positions `cu_seqlen_ks[m]` and `cu_seqlen_ke[m] - 1` → set to `+inf` + (force-maintain the boundary blocks: they are always picked by the + subsequent top-block selection — a standard hisa trick to preserve + sink and local blocks). + +### 2.1 `block_sparse_mqa_fp8.py` + +**Function**: `fp8_native_block_sparse_mqa_attn_return_logits` + +**Meaning**: fine-grained fp8 MQA over only the **raw K tokens** inside the +top-`block_topk` pool blocks selected per query. Two kernel variants are +auto-dispatched by the factory: +- general (`kv_block_size > block_N`): pipelined sub-block inner loop +- small-pooling-size (`kv_block_size == block_N`): single pass, no pipeline + +**Interface**: +```python +block_sparse_logits = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, # [M, H, D] fp8 + k, # [N, D] fp8 + k_scale, # [N] f32 + topk_block_index, # [M, block_topk] int64 — from torch.topk over stage-1 scores + kv_block_size, # == k_block_size + weights, # [M, H] f32 + cu_seqlen_ks, # [M] int32 — per-query K start (absolute, in raw tokens) + cu_seqlen_ke, # [M] int32 — per-query K end +) +# block_sparse_logits: [M, block_topk * kv_block_size] f32 +``` + +**What it does**: for each query `m`, for each selected block +`t ∈ [0, block_topk)` with `blk = topk_block_index[m, t]`, for each +in-block offset `i ∈ [0, kv_block_size)`, +``` +k_abs = blk * kv_block_size + i +if k_abs ∉ [cu_seqlen_ks[m], cu_seqlen_ke[m]) or k_abs >= N: + block_sparse_logits[m, t * kv_block_size + i] = -inf +else: + block_sparse_logits[m, t * kv_block_size + i] = + sum_h ReLU(q[m, h] · k[k_abs]) * k_scale[k_abs] * weights[m, h] +``` +The out-of-range mask is written directly by this kernel — no separate +mask pass is needed here (unlike stage 1). + +### End-to-end `hisa.py` + +**Function**: `hisa_indexer` + +**Meaning**: single entry point that runs the full pipeline below. 
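+
+The post-processing after the two `torch.topk` calls (step 2.3 in the
+pipeline below) boils down to plain torch index arithmetic:
+
+```python
+# stage-2.5 slot ids -> absolute K positions -> per-query relative offsets
+blk = torch.gather(topk_block_indices, -1, relevant_topk_indices // k_block_size)
+topk_indices = (blk * k_block_size + relevant_topk_indices % k_block_size).to(torch.int32)
+topk_indices -= cu_seqlen_ks[:, None]  # relative to the query's K start
+span = (cu_seqlen_ke - cu_seqlen_ks)[:, None]
+topk_indices = topk_indices.masked_fill((topk_indices < 0) | (topk_indices >= span), -1)
+```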
+ +**Interface**: +```python +topk_indices = hisa_indexer( + q, # [M, H, D] fp8 + k, # [N, D] fp8 + k_scale, # [N] f32 + weights, # [M, H] f32 + cu_seqlen_ks, # [M] int32 — per-query K start + cu_seqlen_ke, # [M] int32 — per-query K end + *, + k_block_size, # pool block size (=128 in DeepSeek-V3.2) + block_topk, # number of top pool blocks kept per query + topk_tokens, # final top-k size handed to the sparse attention +) +# topk_indices: [M, topk_tokens] int32 — each row is the query's top-k K +# positions expressed as offsets within its own [cu_ks, cu_ke) window. +# Out-of-range slots are -1. +``` + +**Pipeline**: + +``` +(1.1) fp8_native_block_mean_pooling K, k_scale → blocked_k, blocked_k_scale +(1.2) pool_mqa_attn_return_logits_fp8 Q × blocked_k → block_k_score[M, Nb] +(1.3) clean_and_maintain_logits in-place mask (-inf/+inf) on block_k_score +(1.4) torch.topk(block_k_score.bfloat16(), → topk_block_indices[M, block_topk] int64 + k=block_topk, sorted=False) +(2.1) fp8_native_block_sparse_mqa_… Q × K[selected] → block_sparse_logits + [M, block_topk * k_block_size] +(2.2) torch.topk(block_sparse_logits, → relevant_topk_indices[M, topk_tokens] int64 + k=topk_tokens) +(2.3) (Python) gather topk_block_indices + → absolute K positions, then subtract + arith + subtract cu_seqlen_ks + mask cu_seqlen_ks for per-query-relative offsets + → topk_indices[M, topk_tokens] int32 +``` diff --git a/examples/dsa_hisa/block_sparse_mqa_fp8.py b/examples/dsa_hisa/block_sparse_mqa_fp8.py new file mode 100644 index 0000000000..10f95bb204 --- /dev/null +++ b/examples/dsa_hisa/block_sparse_mqa_fp8.py @@ -0,0 +1,269 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def fp8_native_block_sparse_mqa_attn_return_logits( + IndexQ, + IndexK, + IndexKScale, + TopKBlockIndex, + Weights, + CuSeqLenKS, + CuSeqLenKE, + heads: int = 64, + index_dim: int = 128, + kv_block_size: int = 128, + topk: int = 64, + block_N: int = 128, + num_stages: int = 1, + threads: int = 256, +): + fp8_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + topk_index_dtype = T.int64 + + seq_len, seq_len_kv = T.const("seq_len, seq_len_kv") + + H_per_block = heads + block_N = min(block_N, kv_block_size // 2) + assert kv_block_size % block_N == 0, "block_N must divide kv_block_size" + + IndexQ: T.Tensor[[seq_len * heads, index_dim], fp8_dtype] + IndexK: T.Tensor[[seq_len_kv, index_dim], fp8_dtype] + IndexKScale: T.Tensor[[seq_len_kv], accum_dtype] + TopKBlockIndex: T.Tensor[[seq_len, topk], topk_index_dtype] + Weights: T.Tensor[[seq_len, heads], accum_dtype] + CuSeqLenKS: T.Tensor[[seq_len], index_dtype] + CuSeqLenKE: T.Tensor[[seq_len], index_dtype] + + Logits = T.empty((seq_len, topk * kv_block_size), accum_dtype) + + with T.Kernel(seq_len, threads=threads) as bx: + index_q_shared = T.alloc_shared([H_per_block, index_dim], fp8_dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], fp8_dtype) + # Shared (zero-init'd) — see note in the hisa source about serial-topk + # loop making shared slightly faster than fragment here. 
+ scale_shared = T.alloc_shared([block_N], accum_dtype) + + s = T.alloc_fragment([block_N, H_per_block], accum_dtype) + s_reshaped = T.reshape(s, (block_N, H_per_block // heads, heads)) + logits = T.alloc_fragment([block_N, H_per_block // heads], accum_dtype) + weights = T.alloc_fragment([H_per_block // heads, heads], accum_dtype) + + seq_len_i = bx + + cu_k_s_min = CuSeqLenKS[seq_len_i] + cu_k_e_max = CuSeqLenKE[seq_len_i] + + T.copy(IndexQ[seq_len_i * heads : seq_len_i * heads + H_per_block, :], index_q_shared) + T.copy(Weights[seq_len_i, :], weights) + + for n_i in T.serial(topk): + topk_block_id = T.cast(TopKBlockIndex[seq_len_i, n_i], index_dtype) + block_s = topk_block_id * kv_block_size + for b_i in T.Pipelined(kv_block_size // block_N, num_stages=num_stages): + block_s_i = block_s + b_i * block_N + + T.copy(IndexK[block_s_i : block_s_i + block_N, :], index_k_shared) + for bn_i in T.Parallel(block_N): + scale_shared[bn_i] = IndexKScale[block_s_i + bn_i] + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullRow, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, H_per_block // heads, heads): + s_reshaped[bn_i, bq_i, h_i] = T.max(s_reshaped[bn_i, bq_i, h_i] * scale_shared[bn_i], 0) * weights[bq_i, h_i] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for i_i in T.Parallel(block_N): + k_i = block_s_i + i_i + if k_i < cu_k_s_min or k_i >= cu_k_e_max: + logits[i_i, 0] = -T.infinity(accum_dtype) + + for bn_i in T.Parallel(block_N): + Logits[seq_len_i, n_i * kv_block_size + b_i * block_N + bn_i] = logits[bn_i, 0] + + return Logits + + +def fp8_native_block_sparse_mqa_attn_return_logits_interface( + q: torch.Tensor, + k: torch.Tensor, + k_scale: torch.Tensor, + topk_block_index: torch.Tensor, + kv_block_size: int, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + seq_len, heads, index_dim = q.shape + topk = topk_block_index.shape[1] + logits = fp8_native_block_sparse_mqa_attn_return_logits( + q.view(seq_len * heads, index_dim), + k, + k_scale, + topk_block_index, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + heads=heads, + index_dim=index_dim, + kv_block_size=kv_block_size, + topk=topk, + ) + return logits + + +def ref_fp8_block_sparse_mqa( + q_fp8: torch.Tensor, + k_fp8: torch.Tensor, + k_scale: torch.Tensor, + topk_block_index: torch.Tensor, + kv_block_size: int, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + M, H, D = q_fp8.shape + N = k_fp8.shape[0] + topk = topk_block_index.shape[1] + + block_starts = topk_block_index.long() * kv_block_size # [M, topk] + pos_in_block = torch.arange(kv_block_size, device=q_fp8.device) + k_abs = block_starts[..., None] + pos_in_block[None, None, :] # [M, topk, B] + k_safe = k_abs.clamp(0, N - 1) + + q_f = q_fp8.float() + k_f = k_fp8.float() * k_scale[:, None] + gathered_k = k_f[k_safe.flatten()].reshape(M, topk, kv_block_size, D) + + s = torch.einsum("mhd,mtid->mtih", q_f, gathered_k) # [M, topk, B, H] + logits = (s.clamp(min=0) * weights[:, None, None, :]).sum(dim=-1) # [M, topk, B] + + in_range = (k_abs >= cu_seqlen_ks.long()[:, None, None]) & (k_abs < cu_seqlen_ke.long()[:, None, None]) & (k_abs < N) + logits = logits.masked_fill(~in_range, float("-inf")) + return logits.reshape(M, topk * kv_block_size) + + +def test_fp8_block_sparse_mqa( + M: int = 1024, + H: int = 64, + D: int = 128, + kv_block_size: int = 128, + topk: int = 64, + num_seqs: int = 1, +): + """Correctness + 
speed test packing `num_seqs` equal-length causal + sequences into the [M, H, D] Q and [M, D] K tensors. Each query sees + only the prefix of its own sequence (``cu_ks = start_of_seq``, + ``cu_ke = start_of_seq + position_in_seq + 1``). + + ``topk_block_index`` is drawn at random from [0, num_k_blocks) — some + picks will point to blocks outside the query's own sequence; those + positions get -inf via the kernel's built-in mask, and the torch ref + produces the same -inf. Comparison checks both the +/-inf mask + pattern (exact) and the finite values (fp8 tolerance).""" + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + N = M # causal self-attention prefill, packed + + per_seq = M // num_seqs + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + # Random per-query top-k blocks (distinct indices drawn from [0, num_blocks)). + num_k_blocks = (N + kv_block_size - 1) // kv_block_size + topk = min(topk, num_k_blocks) + g = torch.Generator(device="cuda").manual_seed(42) + topk_block_index = torch.stack([torch.randperm(num_k_blocks, generator=g, device="cuda")[:topk] for _ in range(M)]).to(torch.int64) + + # Correctness. + got = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + ref = ref_fp8_block_sparse_mqa( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + # The kernel marks out-of-range as -inf. Compare finite positions only — + # the -inf mask pattern must agree exactly, so we also check that. + finite = torch.isfinite(got) & torch.isfinite(ref) + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + torch.testing.assert_close(got[finite], ref[finite], rtol=1e-1, atol=2e-1) + print(f" correctness: PASS (M={M}, H={H}, D={D}, kv_block_size={kv_block_size}, topk={topk}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + return fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_index, + kv_block_size, + weights, + cu_ks, + cu_ke, + ) + + ms = do_bench(fn, warmup=50, rep=200) + # FLOPs: M × topk × kv_block_size × H × D (fp8×fp8) × 2 (mul+add). + total_flops = 2 * M * topk * kv_block_size * H * D + tflops = total_flops / (ms * 1e-3) / 1e12 + print(f" latency: {ms:.4f} ms ({tflops:.2f} fp8 TFLOPS)") + + +if __name__ == "__main__": + # Ref path materialises [M, topk, B, D] fp32 gathered_k which is ~M GB at + # topk=64, kv_block_size=128, D=128. Keep M modest to avoid OOM. 
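+    # Concretely, gathered_k holds M * topk * kv_block_size * D fp32 values:
+    # 2**30 elements = 4 GiB at M=1024, scaling linearly in M (~32 GiB at M=8192).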
+ # (M, H, D, kv_block_size, topk, num_seqs) + for cfg in [ + (1024, 64, 128, 128, 64, 1), + (4096, 64, 128, 128, 64, 1), + (4096, 64, 128, 128, 64, 4), + (8192, 64, 128, 128, 64, 1), + (8192, 64, 128, 128, 64, 8), + (8192, 64, 128, 64, 128, 8), + (8192, 64, 128, 256, 32, 8), + ]: + test_fp8_block_sparse_mqa(*cfg) + torch.cuda.empty_cache() diff --git a/examples/dsa_hisa/clean_and_maintain_logits.py b/examples/dsa_hisa/clean_and_maintain_logits.py new file mode 100644 index 0000000000..12ff8a4c10 --- /dev/null +++ b/examples/dsa_hisa/clean_and_maintain_logits.py @@ -0,0 +1,121 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +@tilelang.jit +def clean_and_maintain_logits_( + Logits, + CuSeqLenKS, + CuSeqLenKE, + threads: int = 512, + block_K: int = 4096, +): + seq_len, seq_len_kv = T.const("seq_len, seq_len_kv") + + dtype = T.float + indices_dtype = T.int32 + + Logits: T.Tensor[[seq_len, seq_len_kv], dtype] + CuSeqLenKS: T.Tensor[[seq_len], indices_dtype] + CuSeqLenKE: T.Tensor[[seq_len], indices_dtype] + + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx == cu_k_s or idx == cu_k_e - 1: + Logits[bx, idx] = T.infinity(dtype) + if idx < cu_k_s or idx >= cu_k_e: + Logits[bx, idx] = -T.infinity(dtype) + + +def clean_and_maintain_logits_interface( + logits: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +): + """In-place: applies +inf/-inf mask based on per-row [ks, ke).""" + clean_and_maintain_logits_(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits + + +def ref_clean_and_maintain_logits( + logits: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, +) -> torch.Tensor: + """Pure torch equivalent. Returns a new tensor (doesn't mutate the input).""" + M, N = logits.shape + out = logits.clone() + n = torch.arange(N, device=logits.device)[None, :] + mask_out = (n < cu_seqlen_ks.long()[:, None]) | (n >= cu_seqlen_ke.long()[:, None]) + out = out.masked_fill(mask_out, float("-inf")) + m_idx = torch.arange(M, device=logits.device) + out[m_idx, cu_seqlen_ks.long()] = float("inf") + out[m_idx, (cu_seqlen_ke - 1).clamp(min=0).long()] = float("inf") + return out + + +def test_clean_and_maintain_logits(M: int = 4096, N: int = 4096, num_seqs: int = 1): + """Correctness + speed test where `M` query rows are packed from + `num_seqs` equal-length causal sequences. Per-row ``cu_ks / cu_ke`` + is derived from ``prepare_ks_ke_from_cu_seqlens`` so each row sees + only the prefix of its own sequence (causal self-attention).""" + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + assert (M // num_seqs) <= N, "N must accommodate the longest sequence" + + per_seq = M // num_seqs + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).clamp(max=N).contiguous() + + logits_init = torch.randn(M, N, device="cuda", dtype=torch.float32) + + # Run kernel in place on a copy. + got = logits_init.clone() + clean_and_maintain_logits_interface(got, cu_ks, cu_ke) + + # Ref. 
+ ref = ref_clean_and_maintain_logits(logits_init, cu_ks, cu_ke) + + # Exact equality: this kernel only writes +/-inf, other positions untouched + # (ref clones the input and does the same). Compare directly. + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + finite = torch.isfinite(got) & torch.isfinite(ref) + torch.testing.assert_close(got[finite], ref[finite], rtol=0.0, atol=0.0) + print(f" correctness: PASS (M={M}, N={N}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + logits = torch.randn(M, N, device="cuda", dtype=torch.float32) # fresh copy each iter + clean_and_maintain_logits_interface(logits, cu_ks, cu_ke) + return logits + + ms = do_bench(fn, warmup=50, rep=200) + # ~2 reads + 1 write of [M, N] f32, but mostly no-op except at mask boundaries. + bytes_moved = 2 * M * N * 4 + gbps = bytes_moved / (ms * 1e-3) / 1e9 + print(f" latency: {ms:.4f} ms ({gbps:.1f} GB/s)") + + +if __name__ == "__main__": + # (M, N, num_seqs) + for cfg in [ + (4096, 4096, 1), + (4096, 4096, 4), + (16384, 16384, 1), + (16384, 16384, 8), + (65536, 65536, 16), + ]: + test_clean_and_maintain_logits(*cfg) diff --git a/examples/dsa_hisa/fp8_block_mean_pooling.py b/examples/dsa_hisa/fp8_block_mean_pooling.py new file mode 100644 index 0000000000..1c9f90cc4c --- /dev/null +++ b/examples/dsa_hisa/fp8_block_mean_pooling.py @@ -0,0 +1,146 @@ +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def fp8_native_block_mean_pooling( + K, + KScale, + dim: int = 128, + pooling_block_size: int = 128, + block_N: int = 64, + num_stages: int = 1, + threads: int = 256, +): + dtype = T.float8_e4m3fn + accum_dtype = T.float32 + FP8_MAX_INV = 1.0 / 448.0 + + seq_len_k = T.const("seq_len_k") + + K: T.Tensor[[seq_len_k, dim], dtype] + KScale: T.Tensor[[seq_len_k], accum_dtype] + + num_blocks = T.ceildiv(seq_len_k, pooling_block_size) + BlockedK = T.empty((num_blocks, dim), dtype) + BlockedKScale = T.empty((num_blocks,), accum_dtype) + + with T.Kernel(num_blocks, threads=threads) as bx: + index_k = T.alloc_fragment([block_N, dim], dtype) + scale = T.alloc_fragment([block_N], accum_dtype) + acc = T.alloc_fragment([dim], accum_dtype) + max_abs = T.alloc_fragment([1], accum_dtype) + T.fill(acc, 0.0) + + k_start = bx * pooling_block_size + k_end = T.min(k_start + pooling_block_size, seq_len_k) + cur_pooling_block_size = k_end - k_start + + for b_i in T.serial(T.ceildiv(cur_pooling_block_size, block_N)): + T.fill(index_k, 0.0) + + tl_block_s = k_start + b_i * block_N + tl_block_e = T.min(k_start + (b_i + 1) * block_N, k_end) + T.copy(K[tl_block_s : tl_block_s + block_N, :], index_k) + for bn_i in T.Parallel(block_N): + scale[bn_i] = KScale[tl_block_s + bn_i] + + for bn_i, d_i in T.Parallel(block_N, dim): + index_k[bn_i, d_i] = index_k[bn_i, d_i] * scale[bn_i] + + cur_tl_block_size = tl_block_e - tl_block_s + for n_i in T.parallel(block_N): + for d_i in T.parallel(dim): + if n_i >= cur_tl_block_size: + index_k[n_i, d_i] = T.cast(0, accum_dtype) + + T.reduce_sum(index_k, acc, dim=0, clear=False) + + inv_count = T.cast(1.0, accum_dtype) / T.cast(cur_pooling_block_size, accum_dtype) + for d_i in T.Parallel(dim): + acc[d_i] = acc[d_i] * inv_count + + # Re-quantize f32 mean to fp8 with a per-block scale. 
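+        # 448 is the largest finite float8_e4m3fn value, so block_scale maps
+        # the mean back into fp8 range; the 1e-10 floor keeps the scale
+        # positive for all-zero blocks.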
+ T.reduce_absmax(acc, max_abs, dim=0, clear=True) + block_scale = T.max(max_abs[0] * T.cast(FP8_MAX_INV, accum_dtype), T.cast(1e-10, accum_dtype)) + inv_block_scale = T.cast(1.0, accum_dtype) / block_scale + + for d_i in T.Parallel(dim): + BlockedK[bx, d_i] = T.cast(acc[d_i] * inv_block_scale, dtype) + BlockedKScale[bx] = block_scale + + return BlockedK, BlockedKScale + + +def fp8_native_block_mean_pooling_interface(k: torch.Tensor, k_scale: torch.Tensor, k_block_size: int): + return fp8_native_block_mean_pooling(k, k_scale, dim=k.shape[1], pooling_block_size=k_block_size) + + +def ref_fp8_block_mean_pooling(k_fp8: torch.Tensor, k_scale: torch.Tensor, k_block_size: int) -> torch.Tensor: + """Spec: per-token dequant + per-block mean (dividing by actual valid count). + Returns the f32 mean (caller can compare against fp8*scale re-quant of the kernel).""" + N, D = k_fp8.shape + dequant = k_fp8.float() * k_scale[:, None] + num_blocks = (N + k_block_size - 1) // k_block_size + out = torch.empty(num_blocks, D, device=k_fp8.device, dtype=torch.float32) + for b in range(num_blocks): + s = b * k_block_size + e = min(s + k_block_size, N) + out[b] = dequant[s:e].sum(dim=0) / (e - s) + return out + + +def test_fp8_block_mean_pooling(N: int = 16384, D: int = 128, k_block_size: int = 128, num_seqs: int = 1): + """Correctness + speed test with `num_seqs` sequences of equal length + packed into the flat K buffer. + + NOTE: the flat mean-pool kernel is sequence-agnostic — it pools every + `k_block_size` consecutive tokens regardless of sequence boundaries. + `num_seqs` is accepted here for API consistency with the other kernels' + tests; it affects how `cu_seqlens` is laid out (shown for illustration) + but not the kernel's inputs / outputs. + """ + torch.manual_seed(0) + assert N % num_seqs == 0, f"N ({N}) must be divisible by num_seqs ({num_seqs})" + per_seq = N // num_seqs + + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + + # Correctness. + blocked_k_fp8, blocked_k_scale = fp8_native_block_mean_pooling_interface(k, k_scale, k_block_size) + got = blocked_k_fp8.float() * blocked_k_scale[:, None] + ref = ref_fp8_block_mean_pooling(k, k_scale, k_block_size) + # fp8 re-quant: ~1/256 rel error on top of bf16-level precision. + torch.testing.assert_close(got, ref, rtol=5e-2, atol=5e-3) + print(f" correctness: PASS (N={N}, D={D}, k_block_size={k_block_size}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed. + def fn(): + return fp8_native_block_mean_pooling_interface(k, k_scale, k_block_size) + + ms = do_bench(fn, warmup=50, rep=200) + num_blocks = (N + k_block_size - 1) // k_block_size + # Bytes moved: read N * D fp8 (K) + N * 4 f32 (scale) + write num_blocks * D fp8 + num_blocks * 4 f32. 
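+    # e.g. N=65536, D=128: ~8.4 MB of K plus ~0.26 MB of scales read and
+    # ~0.07 MB written per call; the kernel is firmly bandwidth-bound.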
+ bytes_moved = N * D + N * 4 + num_blocks * D + num_blocks * 4 + gbps = bytes_moved / (ms * 1e-3) / 1e9 + print(f" latency: {ms:.4f} ms ({gbps:.1f} GB/s)") + + +if __name__ == "__main__": + # (N, D, k_block_size, num_seqs) + for cfg in [ + (16384, 128, 128, 1), + (16384, 128, 128, 4), + (65536, 128, 128, 1), + (65536, 128, 128, 8), + (131072, 128, 128, 16), + ]: + test_fp8_block_mean_pooling(*cfg) diff --git a/examples/dsa_hisa/hisa.py b/examples/dsa_hisa/hisa.py new file mode 100644 index 0000000000..e9863874e1 --- /dev/null +++ b/examples/dsa_hisa/hisa.py @@ -0,0 +1,240 @@ +import torch +from tilelang.profiler import do_bench + +from fp8_block_mean_pooling import fp8_native_block_mean_pooling_interface +from pool_mqa_fp8 import pool_mqa_attn_return_logits_fp8_interface +from block_sparse_mqa_fp8 import fp8_native_block_sparse_mqa_attn_return_logits_interface +from clean_and_maintain_logits import clean_and_maintain_logits_interface +from tilelang_utils import prepare_ks_ke_from_cu_seqlens + + +def hisa_indexer( + q: torch.Tensor, # [M, H, D] fp8_e4m3fn + k: torch.Tensor, # [N, D] fp8_e4m3fn + k_scale: torch.Tensor, # [N] f32 + weights: torch.Tensor, # [M, H] f32 + cu_seqlen_ks: torch.Tensor, # [M] int32 — per-query K start (inclusive) + cu_seqlen_ke: torch.Tensor, # [M] int32 — per-query K end (exclusive) + *, + k_block_size: int, + block_topk: int, + topk_tokens: int, +) -> torch.Tensor: + """Run the full hisa prefill pipeline. + + Returns: ``[M, topk_tokens]`` int32 — each row is this query's top + ``topk_tokens`` K positions, expressed as offsets relative to + ``cu_seqlen_ks[m]`` (so ``0`` means the query's own K start). Slots + that fell outside ``[cu_seqlen_ks[m], cu_seqlen_ke[m])`` get ``-1``. + """ + # ------------------------------------------------------------------ + # Stage 0: fp8 mean-pool over K. Groups K into pool blocks of + # k_block_size tokens each; outputs one fp8 vector + f32 scale per + # pool block. Grid = (ceil(N/k_block_size),). + # ------------------------------------------------------------------ + blocked_k_fp8, blocked_k_scale = fp8_native_block_mean_pooling_interface( + k, + k_scale, + k_block_size, + ) # [Nb, D] fp8, [Nb] f32 + + # Translate the per-query K range from flat-token coords to + # pool-block coords (floor for start, ceil for end). + cu_seqlen_blocked_ks = cu_seqlen_ks // k_block_size + cu_seqlen_blocked_ke = (cu_seqlen_ke + k_block_size - 1) // k_block_size + + # ------------------------------------------------------------------ + # Stage 1: block-level Q·BlockedK score with ReLU + per-head weight + # reduction. Output is dense (kernel doesn't mask out-of-range). + # ------------------------------------------------------------------ + block_k_score = pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k_fp8, + blocked_k_scale, + weights, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + ) # [M, Nb] f32 + + # Mask out-of-range entries to -inf and force +inf on first / last + # valid block so torch.topk picks the boundary blocks. + clean_and_maintain_logits_interface( + block_k_score, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + ) + + # ------------------------------------------------------------------ + # Stage 1.5: top-block_topk selection. bfloat16 + sorted=False is + # ~40% faster than f32 and the downstream sparse_mqa doesn't rely + # on order. 
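+    # block_topk_eff guards tiny inputs where fewer than block_topk pool
+    # blocks exist.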
+ # ------------------------------------------------------------------ + block_topk_eff = min(block_topk, block_k_score.shape[-1]) + topk_block_indices = torch.topk( + block_k_score.bfloat16(), + k=block_topk_eff, + dim=-1, + sorted=False, + ).indices # [M, block_topk_eff] int64 + + # ------------------------------------------------------------------ + # Stage 2: fp8 fine-grained Q·K MQA over only the selected + # blocks' raw tokens (block_topk_eff blocks × k_block_size tokens + # per query). The kernel writes -inf for positions outside + # [cu_seqlen_ks[m], cu_seqlen_ke[m]). + # ------------------------------------------------------------------ + block_sparse_logits = fp8_native_block_sparse_mqa_attn_return_logits_interface( + q, + k, + k_scale, + topk_block_indices, + k_block_size, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) # [M, block_topk_eff * k_block_size] f32 + + # ------------------------------------------------------------------ + # Stage 2.5: top-topk_tokens selection over the block_topk_eff + # × k_block_size candidate tokens. Gives per-query slot ids. + # ------------------------------------------------------------------ + topk_tokens_eff = min(topk_tokens, block_sparse_logits.shape[-1]) + relevant_topk_indices = torch.topk( + block_sparse_logits, + k=topk_tokens_eff, + dim=-1, + ).indices # [M, topk_tokens_eff] int64 + + # ------------------------------------------------------------------ + # Stage 3 (post, Python): translate slot ids → absolute K token + # position → per-query relative offset (matches vLLM indexer + # output buffer). Slots whose relative offset falls outside the + # query's visible range are set to -1. + # ------------------------------------------------------------------ + # slot = block_id_in_topk × k_block_size + offset_in_block + # where block_id_in_topk ∈ [0, block_topk_eff) + # absolute_k = topk_block_indices[m, block_id_in_topk] × k_block_size + offset_in_block + absolute_topk_block_indices = torch.gather( + topk_block_indices, + dim=-1, + index=(relevant_topk_indices // k_block_size), + ) + topk_indices = absolute_topk_block_indices * k_block_size + (relevant_topk_indices % k_block_size) + topk_indices = topk_indices.to(torch.int32) + + # Relative to this query's K start. + topk_indices -= cu_seqlen_ks[:, None] + mask_lo = topk_indices >= 0 + mask_hi = topk_indices - (cu_seqlen_ke - cu_seqlen_ks)[:, None] < 0 + mask = mask_lo & mask_hi + topk_indices = topk_indices.masked_fill(~mask, -1) + + return topk_indices + + +def test_hisa( + M: int = 1024, + H: int = 64, + D: int = 128, + k_block_size: int = 128, + block_topk: int = 8, + topk_tokens: int = 256, + num_seqs: int = 1, +): + """End-to-end smoke + speed test packing `num_seqs` equal-length causal + sequences into the flat [M, H, D] Q and [N=M, D] K tensors. + + Per-token ``cu_ks / cu_ke`` are produced by + ``prepare_ks_ke_from_cu_seqlens`` so each query sees only the prefix + of its own sequence. Validity checks are done per-query (so each + sequence's tail queries have fewer valid candidate slots). 
+ """ + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + per_seq = M // num_seqs + N = M # causal self-attention, packed + + cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks = ks_long.to(torch.int32).contiguous() + cu_ke = ke_long.to(torch.int32).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + k_bf16 = torch.randn(N, D, device="cuda", dtype=torch.bfloat16) + k = k_bf16.to(torch.float8_e4m3fn) + k_scale = (0.1 + 0.01 * torch.rand(N, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + topk_indices = hisa_indexer( + q, + k, + k_scale, + weights, + cu_ks, + cu_ke, + k_block_size=k_block_size, + block_topk=block_topk, + topk_tokens=topk_tokens, + ) + + # Sanity checks. + assert topk_indices.shape == (M, topk_tokens), f"unexpected output shape {tuple(topk_indices.shape)}" + assert topk_indices.dtype == torch.int32 + + # Every non-(-1) offset must be within [0, cu_ke[m] - cu_ks[m]). + valid = topk_indices >= 0 + spans = (cu_ke - cu_ks)[:, None].expand_as(topk_indices) + in_range = topk_indices < spans + assert (valid == (valid & in_range)).all(), "some valid offset falls outside its query's K window" + + # Per-query expected number of valid slots = min(cu_ke[m] - cu_ks[m], + # topk_tokens) (clipped by K range and by block_topk × k_block_size). + expected_valid = torch.minimum( + (cu_ke - cu_ks).clamp(min=0), + torch.tensor(min(topk_tokens, block_topk * k_block_size), device=cu_ke.device), + ) + got_valid = valid.sum(dim=-1).to(torch.int32) + frac_match = (got_valid == expected_valid).float().mean().item() + print( + f" shape: {tuple(topk_indices.shape)} " + f"valid_frac: {valid.float().mean().item():.4f} " + f"per-query valid count match: {frac_match:.4f} " + f"(num_seqs={num_seqs}, per_seq={per_seq})" + ) + + # Speed. + def fn(): + return hisa_indexer( + q, + k, + k_scale, + weights, + cu_ks, + cu_ke, + k_block_size=k_block_size, + block_topk=block_topk, + topk_tokens=topk_tokens, + ) + + ms = do_bench(fn, warmup=20, rep=50) + print( + f" latency: {ms:.3f} ms " + f"(M={M}, H={H}, D={D}, k_block_size={k_block_size}, " + f"block_topk={block_topk}, topk_tokens={topk_tokens}, num_seqs={num_seqs})" + ) + + +if __name__ == "__main__": + # Ref path in block_sparse_mqa materialises [M, topk, kvB, D] fp32 so + # stay modest on M (reuse the sparse_mqa module's sizing intuition). + for cfg in [ + dict(M=1024, H=64, D=128, k_block_size=128, block_topk=16, topk_tokens=256, num_seqs=1), + dict(M=1024, H=64, D=128, k_block_size=128, block_topk=16, topk_tokens=256, num_seqs=4), + dict(M=4096, H=64, D=128, k_block_size=128, block_topk=32, topk_tokens=1024, num_seqs=1), + dict(M=4096, H=64, D=128, k_block_size=128, block_topk=32, topk_tokens=1024, num_seqs=4), + dict(M=8192, H=64, D=128, k_block_size=128, block_topk=64, topk_tokens=2048, num_seqs=1), + dict(M=8192, H=64, D=128, k_block_size=128, block_topk=64, topk_tokens=2048, num_seqs=8), + ]: + test_hisa(**cfg) + torch.cuda.empty_cache() diff --git a/examples/dsa_hisa/pool_mqa_fp8.py b/examples/dsa_hisa/pool_mqa_fp8.py new file mode 100644 index 0000000000..515b311ac4 --- /dev/null +++ b/examples/dsa_hisa/pool_mqa_fp8.py @@ -0,0 +1,257 @@ +"""Stage-1 kernel: prefill pool-MQA over pooled (blocked) K. 
+ +Input: fp8 Q ``[M, H, D]`` + fp8 BlockedK ``[Nb, D]`` + per-block f32 scale +``[Nb]`` + f32 Weights ``[M, H]`` + per-query ``cu_seqlen_blocked_ks/ke [M]``. + +For each query ``m`` and pool block ``n`` in ``[cu_seqlen_blocked_ks[m], +cu_seqlen_blocked_ke[m])``: + ``logits[m, n] = sum_h ReLU(Q[m, h] . BlockedK[n]) * BlockedKScale[n] * Weights[m, h]`` + +Out-of-range entries in the raw kernel output are undefined — caller should +zero-init the buffer or apply a separate mask kernel. +""" + +import tilelang +from tilelang import language as T +from tilelang.profiler import do_bench +import torch + +from tilelang_utils import prepare_ks_ke_from_cu_seqlens +from clean_and_maintain_logits import ( + clean_and_maintain_logits_interface, + ref_clean_and_maintain_logits, +) + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def pool_mqa_attn_return_logits_fp8( + IndexQ, + IndexBlockedK, + IndexBlockedKScale, + Logits, + Weights, + CuSeqLenBlockedKS, + CuSeqLenBlockedKE, + heads: int = 64, + index_dim: int = 128, + block_N: int = 256, + num_stages: int = 3, + threads: int = 512, + block_Q: int = 0, +): + # block_Q is the tile size for queries; `0` means "derive from heads". + if block_Q == 0: + block_Q = 128 // heads + fp8_dtype = T.float8_e4m3fn + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len, seq_len_blocked_kv = T.const("seq_len, seq_len_blocked_kv") + + IndexQ: T.Tensor[[seq_len * heads, index_dim], fp8_dtype] + IndexBlockedK: T.Tensor[[seq_len_blocked_kv, index_dim], fp8_dtype] + IndexBlockedKScale: T.Tensor[[seq_len_blocked_kv], accum_dtype] + Logits: T.Tensor[[seq_len, seq_len_blocked_kv], accum_dtype] + Weights: T.Tensor[[seq_len, heads], accum_dtype] + CuSeqLenBlockedKS: T.Tensor[[seq_len], index_dtype] + CuSeqLenBlockedKE: T.Tensor[[seq_len], index_dtype] + + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], fp8_dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], fp8_dtype) + index_k_scale_fragment = T.alloc_fragment([block_N], accum_dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min = T.min(cu_k_s_min, T.min(CuSeqLenBlockedKS[seq_len_i + bq_i], seq_len_blocked_kv)) + for bq_i in T.serial(block_Q): + cu_k_e_max = T.max(cu_k_e_max, T.min(CuSeqLenBlockedKE[seq_len_i + bq_i], seq_len_blocked_kv)) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined(T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages): + T.copy(IndexBlockedK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + T.copy(IndexBlockedKScale[cu_k_s_min + nbn_i * block_N], index_k_scale_fragment) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = T.max(s_reshaped[bn_i, bq_i, h_i] * index_k_scale_fragment[bn_i], 0) * weights[bq_i, h_i] + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in 
T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = logits[bn_i, bq_i] + + +def pool_mqa_attn_return_logits_fp8_interface( + q_fp8: torch.Tensor, + blocked_kv_fp8: torch.Tensor, + blocked_kv_scale: torch.Tensor, + weights_f32: torch.Tensor, + cu_seqlen_blocked_ks: torch.Tensor, + cu_seqlen_blocked_ke: torch.Tensor, + block_N: int = 256, +): + """Raw kernel invocation; zero-inits logits so positions the kernel + doesn't touch are 0 (matches the ref).""" + seq_len, heads, index_dim = q_fp8.shape + seq_len_blocked_kv = blocked_kv_fp8.shape[0] + + logits = torch.zeros([seq_len, seq_len_blocked_kv], device=q_fp8.device, dtype=torch.float32) + pool_mqa_attn_return_logits_fp8( + q_fp8.view(seq_len * heads, index_dim), + blocked_kv_fp8, + blocked_kv_scale, + logits, + weights_f32, + cu_seqlen_blocked_ks, + cu_seqlen_blocked_ke, + heads=heads, + index_dim=index_dim, + block_N=block_N, + ) + return logits + + +def ref_pool_mqa_fp8( + q_fp8: torch.Tensor, + blocked_kv_fp8: torch.Tensor, + blocked_kv_scale: torch.Tensor, + weights_f32: torch.Tensor, +) -> torch.Tensor: + """Spec: for each (m, n), logits[m, n] = sum_h ReLU(q[m,h] . k[n] * k_scale[n]) * w[m,h]. + Computes the full dense [M, Nb] grid — caller is responsible for any masking.""" + q_f = q_fp8.float() + k_f = blocked_kv_fp8.float() * blocked_kv_scale[:, None] + # score[m, n, h] = q[m, h] . k[n] + s = torch.einsum("mhd,nd->mnh", q_f, k_f) # [M, Nb, H] + logits = (s.clamp(min=0) * weights_f32[:, None, :]).sum(dim=-1) # [M, Nb] + return logits + + +def test_pool_mqa_fp8( + M: int = 32768, + H: int = 64, + D: int = 128, + k_block_size: int = 128, + block_N: int = 256, + num_seqs: int = 1, +): + """Correctness + speed test packing `num_seqs` equal-length causal + sequences into the [M, H, D] Q tensor. + + Per-query ``cu_seqlen_blocked_ks/ke`` is derived from the raw-token + packed ``cu_ks / cu_ke`` produced by ``prepare_ks_ke_from_cu_seqlens`` + (floor-divide / ceil-divide by ``k_block_size`` respectively). + + The kernel writes the per-tile ``[cu_k_s_min, cu_k_e_max)`` union of + visible K ranges — entries inside this union but outside an + individual query's visible range carry raw (unmasked) dot-product + values. To make correctness well-defined, we apply the + ``clean_and_maintain_logits`` mask (-inf for out-of-range, +inf for + the first/last valid block) to both the kernel output and the torch + reference before comparing — this mirrors what the hisa pipeline + does right after this kernel. + """ + torch.manual_seed(0) + assert M % num_seqs == 0, f"M ({M}) must be divisible by num_seqs ({num_seqs})" + per_seq = M // num_seqs + N_blocked = (M + k_block_size - 1) // k_block_size + assert N_blocked % block_N == 0, ( + f"N_blocked ({N_blocked}) must be a multiple of block_N ({block_N}). Pick M such that ceildiv(M, k_block_size) % block_N == 0." + ) + + # Per-token packed ks/ke (causal within each sequence), then translate + # to pool-block coords. 
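+    # (Floor for the start block, ceil for the end block, matching the
+    # translation hisa_indexer applies before calling this kernel.)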
+ cu_seqlens = torch.arange(num_seqs + 1, device="cuda", dtype=torch.long) * per_seq + ks_long, ke_long = prepare_ks_ke_from_cu_seqlens(cu_seqlens) + cu_ks_token = ks_long.to(torch.int32).contiguous() + cu_ke_token = ke_long.to(torch.int32).contiguous() + cu_blocked_ks = (cu_ks_token // k_block_size).contiguous() + cu_blocked_ke = ((cu_ke_token + k_block_size - 1) // k_block_size).contiguous() + + q_bf16 = torch.randn(M, H, D, device="cuda", dtype=torch.bfloat16) + q = q_bf16.to(torch.float8_e4m3fn) + blocked_k_bf16 = torch.randn(N_blocked, D, device="cuda", dtype=torch.bfloat16) + blocked_k = blocked_k_bf16.to(torch.float8_e4m3fn) + blocked_k_scale = (0.1 + 0.01 * torch.rand(N_blocked, device="cuda", dtype=torch.float32)).contiguous() + weights = torch.randn(M, H, device="cuda", dtype=torch.float32) + + # Correctness — kernel + post-mask. + got = pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k, + blocked_k_scale, + weights, + cu_blocked_ks, + cu_blocked_ke, + block_N=block_N, + ) + clean_and_maintain_logits_interface(got, cu_blocked_ks, cu_blocked_ke) + + ref = ref_pool_mqa_fp8(q, blocked_k, blocked_k_scale, weights) + ref = ref_clean_and_maintain_logits(ref, cu_blocked_ks, cu_blocked_ke) + + # After the mask, +/-inf positions must agree exactly. Compare the + # remaining finite values under an fp8×fp8 GEMM tolerance. + assert torch.equal(torch.isposinf(got), torch.isposinf(ref)), "pos-inf mask differs" + assert torch.equal(torch.isneginf(got), torch.isneginf(ref)), "neg-inf mask differs" + finite = torch.isfinite(got) & torch.isfinite(ref) + torch.testing.assert_close(got[finite], ref[finite], rtol=5e-2, atol=5e-2) + print(f" correctness: PASS (M={M}, H={H}, D={D}, N_blocked={N_blocked}, block_N={block_N}, num_seqs={num_seqs}, per_seq={per_seq})") + + # Speed (kernel only — excludes the post mask). + def fn(): + return pool_mqa_attn_return_logits_fp8_interface( + q, + blocked_k, + blocked_k_scale, + weights, + cu_blocked_ks, + cu_blocked_ke, + block_N=block_N, + ) + + ms = do_bench(fn, warmup=50, rep=200) + # FLOPs: fp8×fp8 GEMM dominates = 2 * M * H * Nb * D (mul+add). + total_flops = 2 * M * H * N_blocked * D + tflops = total_flops / (ms * 1e-3) / 1e12 + print(f" latency: {ms:.4f} ms ({tflops:.2f} fp8 TFLOPS)") + + +if __name__ == "__main__": + # M × k_block_size^-1 must be a multiple of block_N=256. + # With k_block_size=128 → N_blocked = M/128; need N_blocked % 256 == 0 + # → M % 32768 == 0. + # (M, H, D, k_block_size, block_N, num_seqs) + for cfg in [ + (32768, 64, 128, 128, 256, 1), + (32768, 64, 128, 128, 256, 4), + (65536, 64, 128, 128, 256, 1), + (65536, 64, 128, 128, 256, 8), + (131072, 64, 128, 128, 256, 16), + ]: + test_pool_mqa_fp8(*cfg) diff --git a/examples/dsa_hisa/tilelang_utils.py b/examples/dsa_hisa/tilelang_utils.py new file mode 100644 index 0000000000..80a1441c5b --- /dev/null +++ b/examples/dsa_hisa/tilelang_utils.py @@ -0,0 +1,314 @@ +import torch +import torch.nn.functional as F +import functools +from typing import Callable, Any, Tuple + + +def tensor_cache( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent result of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + If the function is called again with the same input tensors, it will return the cached result. + + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. 
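+
+    Note:
+        Cache hits are detected by object identity (``is``), not by value;
+        calling with an equal but freshly allocated tensor recomputes.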
+ + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + (last_args is not None and last_kwargs is not None) + and (len(args) == len(last_args) and len(kwargs) == len(last_kwargs)) + and all(a is b for a, b in zip(args, last_args, strict=False)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_lens( + lens: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + return F.pad(lens.cumsum(dim=0, dtype=dtype), (1, 0)) + + +@tensor_cache +def prepare_lens_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_position_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.cat([torch.arange(n, dtype=cu_seqlens.dtype, device=cu_seqlens.device) for n in prepare_lens(cu_seqlens).unbind()]) + + +@tensor_cache +def prepare_sequence_ids(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return prepare_position_ids(cu_seqlens).eq(0).cumsum(0) - 1 + + +@tensor_cache +def prepare_token_indices(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + position_ids = prepare_position_ids(cu_seqlens) + return torch.stack([prepare_sequence_ids(cu_seqlens), position_ids], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_cu_seqlens_from_position_ids( + position_ids: torch.LongTensor, + dtype: torch.dtype | None = torch.int32, +) -> torch.LongTensor: + starts = (position_ids == 0).nonzero(as_tuple=True)[0] + total_len = position_ids.new_tensor([position_ids.numel()]) + boundaries = torch.cat([starts, total_len]) + lens = torch.diff(boundaries) + cu_seqlens = prepare_cu_seqlens_from_lens(lens, dtype=dtype) + return cu_seqlens + + +@tensor_cache +def prepare_ks_ke_from_cu_seqlens( + cu_seqlens: torch.LongTensor, +) -> tuple[torch.LongTensor, torch.LongTensor]: + position_ids = prepare_position_ids(cu_seqlens) + sequence_ids = position_ids.eq(0).cumsum(0) - 1 + + ks = cu_seqlens[sequence_ids] + ke = ks + position_ids + 1 + + return ks, ke + + +@tensor_cache +def prepare_ks_ke_from_cu_seqlens_qk( + cu_seqlens_q: torch.LongTensor, + cu_seqlens_k: torch.LongTensor, +) -> tuple[torch.LongTensor, torch.LongTensor]: + position_ids_q = prepare_position_ids(cu_seqlens_q) + sequence_ids_q = position_ids_q.eq(0).cumsum(0) - 1 + + seqlens_q = prepare_lens(cu_seqlens_q) + seqlens_k = prepare_lens(cu_seqlens_k) + offset = seqlens_k - seqlens_q + + ks = cu_seqlens_k[sequence_ids_q] + ke = ks + position_ids_q + offset[sequence_ids_q] + 1 + + return ks, ke + + +def ceil_to_ue8m0(x: torch.Tensor): + assert x.view(-1).amax().item() > 0 + return torch.pow(2.0, torch.ceil(torch.log2(x.abs()))) + + +def per_custom_dims_cast_to_fp8(x: torch.Tensor, dims: Tuple[int], use_ue8m0: bool) -> Tuple[torch.Tensor, torch.Tensor]: + excluded_dims = tuple([i for i in range(x.dim()) if i not in set(dims)]) + x_amax = x.abs().float().amax(dim=excluded_dims, keepdim=True).clamp(1e-4) + sf = x_amax 
/ 448.0
+    sf = ceil_to_ue8m0(sf) if use_ue8m0 else sf
+    x_scaled = (x * (1.0 / sf)).to(torch.float8_e4m3fn)
+    return x_scaled, sf.squeeze()
+
+
+def get_abs_err(y, x):
+    x = x.to(torch.float32)
+    y = y.to(torch.float32)
+    return (x - y).flatten().abs().max().item()
+
+
+def get_err_ratio(y, x):
+    x = x.to(torch.float32)
+    y = y.to(torch.float32)
+    err = (x - y).flatten().square().mean().sqrt().item()
+    base = x.flatten().square().mean().sqrt().item()
+    return err / base
+
+
+def calculate_tensor_similarity(x, y, name="tensor"):
+    """
+    Calculate similarity between two tensors using a normalized dot product metric.
+
+    Unlike torch.testing.assert_close which uses absolute/relative tolerance based on
+    element-wise differences, this function computes a global similarity score:
+        sim = 2 * <x, y> / (||x||^2 + ||y||^2)
+
+    This metric is symmetric in x and y and measures a cosine-like similarity that also
+    penalizes magnitude differences between the two tensors. It returns 1 only for
+    identical tensors and values closer to 0 for dissimilar ones. This is particularly
+    useful for comparing tensors with varying magnitudes where relative errors matter
+    more than absolute differences.
+
+    Args:
+        x: First tensor to compare
+        y: Second tensor to compare
+        name: Name of the tensor for logging purposes
+
+    Returns:
+        Similarity score in range [-1, 1], where 1 means identical
+    """
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print(f"\033[33mWARNING: {name} all zero\033[0m")
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_tensors_similar(x, y, eps=1e-8, name="tensor", raise_assert=True):
+    """
+    Assert that two tensors are similar using a global similarity metric.
+
+    Key differences from torch.testing.assert_close:
+    - torch.testing.assert_close: Uses element-wise comparison with rtol/atol, checking
+      that |x - y| <= atol + rtol * |y| for each element. It's sensitive to outliers
+      and requires all elements to satisfy the tolerance.
+    - assert_tensors_similar: Uses a single global similarity score (1 - sim) where sim is the
+      normalized dot product. It's more robust to outliers and focuses on overall
+      tensor similarity rather than element-wise precision. This is better suited for
+      comparing large tensors where a few outlier elements shouldn't fail the test.
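+
+    For example (illustrative values, not from the original tests): y == x gives
+    sim = 1 and diff = 0, while y == 2 * x gives sim = 2 * 2||x||^2 / (5||x||^2) = 0.8,
+    i.e. diff = 0.2, far above the default eps of 1e-8.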
+
+    Args:
+        x: First tensor to compare
+        y: Second tensor to compare
+        eps: Maximum allowed difference (1 - similarity), default 1e-8
+        name: Name of the tensor for error messages
+        raise_assert: Whether to raise assertion error on failure
+    """
+    sim = calculate_tensor_similarity(x, y, name)
+    diff = 1.0 - sim
+    if not (0 <= diff <= eps):
+        print(f"\033[31mERROR: {name} similarity check failed, diff={diff:.2e} (threshold={eps:.2e})\033[0m")
+        if raise_assert:
+            assert False  # noqa: B011
+
+
+@tensor_cache
+def cal_seq_idx_for_q(cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, seq_len: int) -> torch.IntTensor:
+    seq_idx_for_q = torch.full((seq_len,), len(cu_seqlens_qs), dtype=torch.int32, device=cu_seqlens_qs.device)
+    for i in range(len(cu_seqlens_qs)):
+        seq_idx_for_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = i
+    return seq_idx_for_q
+
+
+@tensor_cache
+def cal_cu_seqlen_ks_for_q(
+    cu_seqlens_qs: torch.LongTensor, cu_seqlens_qe: torch.LongTensor, cu_seqlens_ks: torch.LongTensor, seq_len: int
+) -> torch.IntTensor:
+    cu_seqlen_ks_for_each_q = torch.gather(
+        input=torch.cat([cu_seqlens_ks, torch.full((1,), torch.iinfo(torch.int32).max, dtype=torch.int32, device=cu_seqlens_qs.device)]),
+        dim=0,
+        index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
+    )
+    return cu_seqlen_ks_for_each_q.int()
+
+
+@tensor_cache
+def cal_cu_seqlen_ke_for_q(
+    cu_seqlens_qs: torch.LongTensor,
+    cu_seqlens_qe: torch.LongTensor,
+    cu_seqlens_ks: torch.LongTensor,
+    cu_seqlens_ke: torch.LongTensor,
+    q_start_idxs: torch.LongTensor,
+    seq_len: int,
+    kv_stride: int,
+) -> torch.IntTensor:
+    cu_seqlen_ke_for_each_q = torch.gather(
+        input=torch.cat([cu_seqlens_ke, torch.zeros(1, dtype=torch.int32, device=cu_seqlens_qs.device)]),
+        dim=0,
+        index=cal_seq_idx_for_q(cu_seqlens_qs=cu_seqlens_qs, cu_seqlens_qe=cu_seqlens_qe, seq_len=seq_len).long(),
+    )
+    causal_cu_seqlen_ke_for_each_q = torch.zeros((seq_len,), dtype=torch.int32, device=cu_seqlens_qs.device)
+    for i in range(len(cu_seqlens_qs)):
+        causal_cu_seqlen_ke_for_each_q[cu_seqlens_qs[i] : cu_seqlens_qe[i]] = (
+            torch.arange(
+                q_start_idxs[i], q_start_idxs[i] + cu_seqlens_qe[i] - cu_seqlens_qs[i], dtype=torch.int32, device=cu_seqlens_qs.device
+            )
+            + 1
+        ) // kv_stride + cu_seqlens_ks[i]
+    cu_seqlen_ke_for_each_q = torch.minimum(causal_cu_seqlen_ke_for_each_q, cu_seqlen_ke_for_each_q)
+    return cu_seqlen_ke_for_each_q.int()
+
+
+def generate_random_cu_seqlens(per_cp_seqlen, cp_size=4, cp_rank=3, kv_stride=1, average_q_len=512):
+    total_seqlen = per_cp_seqlen * cp_size
+
+    seqlens = torch.randint(0, average_q_len * 2, (total_seqlen // average_q_len * 2,)).cuda()
+    last_seq_id = torch.where(seqlens.cumsum(0) >= total_seqlen)[0][0]
+    seqlens = seqlens[:last_seq_id]
+
+    if seqlens.sum() < total_seqlen:
+        seqlens = torch.cat([seqlens, torch.tensor([total_seqlen - seqlens.sum()]).cuda()])
+
+    cu_seqlens_cumsum = torch.cumsum(seqlens, dim=0)
+    cu_seqlens_k_cumsum = torch.cumsum(seqlens // kv_stride, dim=0)
+    cu_seqlens_qs = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_cumsum[:-1]])
+    cu_seqlens_ks = torch.cat([torch.tensor([0]).cuda(), cu_seqlens_k_cumsum[:-1]])
+    cu_seqlens_qe = cu_seqlens_cumsum.clone()
+    cu_seqlens_ke = cu_seqlens_k_cumsum.clone()
+
+    cu_seqlens_ks_for_each_q = cal_cu_seqlen_ks_for_q(
+        cu_seqlens_qs=cu_seqlens_qs,
+        cu_seqlens_qe=cu_seqlens_qe,
+        cu_seqlens_ks=cu_seqlens_ks,
+        seq_len=total_seqlen,
+    )
+    cu_seqlens_ke_for_each_q = 
cal_cu_seqlen_ke_for_q( + cu_seqlens_qs=cu_seqlens_qs, + cu_seqlens_qe=cu_seqlens_qe, + cu_seqlens_ks=cu_seqlens_ks, + cu_seqlens_ke=cu_seqlens_ke, + q_start_idxs=torch.zeros_like(cu_seqlens_qs), + seq_len=total_seqlen, + kv_stride=kv_stride, + ) + + assert per_cp_seqlen % 2 == 0 + per_chunk_seqlen = per_cp_seqlen // 2 + slice_short = slice(cp_rank * per_chunk_seqlen, (cp_rank + 1) * per_chunk_seqlen) + slice_long = slice( + total_seqlen - (cp_rank + 1) * per_chunk_seqlen, + total_seqlen - cp_rank * per_chunk_seqlen, + ) + ks = torch.cat( + [ + cu_seqlens_ks_for_each_q[slice_short], + cu_seqlens_ks_for_each_q[slice_long], + ] + ) + ke = torch.cat( + [ + cu_seqlens_ke_for_each_q[slice_short], + cu_seqlens_ke_for_each_q[slice_long], + ] + ) + assert len(ks) == len(ke) == per_cp_seqlen + return ks, ke diff --git a/examples/dsa_sparse_finetune/indexer_bwd.py b/examples/dsa_sparse_finetune/indexer_bwd.py index 68508ad4e4..54e02e4f18 100644 --- a/examples/dsa_sparse_finetune/indexer_bwd.py +++ b/examples/dsa_sparse_finetune/indexer_bwd.py @@ -13,10 +13,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tl.jit(pass_configs=pass_configs) diff --git a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py index d76eb02724..1066199cd0 100644 --- a/examples/dsa_sparse_finetune/indexer_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/indexer_topk_reducesum.py @@ -14,11 +14,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, - tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tl.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True, tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tl.jit(pass_configs=pass_configs) diff --git a/examples/dsa_sparse_finetune/sparse_mla_bwd.py b/examples/dsa_sparse_finetune/sparse_mla_bwd.py index 53e5f8bfea..ab0b4fc493 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_bwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_bwd.py @@ -78,10 +78,7 @@ def postprocess_kernel( @tilelang.jit( out_idx=[-2], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def bwd( H, @@ -226,17 +223,17 @@ def sparse_mla_bwd_kernel( if bi_i < BS // split_store: acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[bi_i + s * (BS // split_store), d_i] - for bi_i, d_i in T.Parallel(BS // split_store, D // 4): - T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i * 4], - acc_dkv_shared[bi_i, d_i * 4], + for bi_i, d_i in T.Parallel(BS // split_store, D): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, d_i], + acc_dkv_shared[bi_i, d_i], ) # Atomically update dKV, dKV_tail tensors - for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): - T.atomic_addx4( - dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i * 4], - acc_dkv_tail_shared[bi_i, d_i * 4], + for bi_i, d_i in T.Parallel(BS // split_store, D_tail): + T.atomic_add( + dKV[bos + Indices[bos + s_i, bz, i_i * BS + bi_i + s * (BS // split_store)], bz, D + d_i], + acc_dkv_tail_shared[bi_i, 
d_i], ) # Store the accumulated dQ diff --git a/examples/dsa_sparse_finetune/sparse_mla_fwd.py b/examples/dsa_sparse_finetune/sparse_mla_fwd.py index d875236952..fcde71928b 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_fwd.py +++ b/examples/dsa_sparse_finetune/sparse_mla_fwd.py @@ -9,10 +9,7 @@ @tilelang.jit( out_idx=[-2, -1], - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def sparse_mla_fwd( heads, diff --git a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py index a03bc74f51..2fff8dd20f 100644 --- a/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py +++ b/examples/dsa_sparse_finetune/sparse_mla_topk_reducesum.py @@ -12,10 +12,7 @@ FP32 = T.float32 INT32 = T.int32 -pass_configs = { - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, -} +pass_configs = {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @tilelang.jit(pass_configs=pass_configs) diff --git a/examples/eager_jit/eagerjit.en.ipynb b/examples/eager_jit/eagerjit.en.ipynb new file mode 100644 index 0000000000..6a2bf8453b --- /dev/null +++ b/examples/eager_jit/eagerjit.en.ipynb @@ -0,0 +1,977 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e0deecc", + "metadata": {}, + "outputs": [], + "source": [ + "import sys\n", + "from pathlib import Path\n", + "\n", + "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n", + "import tilelang\n", + "import torch\n", + "import tilelang.language as T" + ] + }, + { + "cell_type": "markdown", + "id": "1ca2c56d", + "metadata": {}, + "source": [ + "# Tilelang Eager JIT" + ] + }, + { + "cell_type": "markdown", + "id": "156e7370", + "metadata": {}, + "source": [ + "## Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "b070c109", + "metadata": {}, + "source": [ + "Tilelang Eager JIT merges JIT kernel generation and invocation into a single workflow.\n", + "\n", + "The function signature looks similar to Triton, but we add many enhancements; the most important one is allowing rich Tensor annotations:\n", + "\n", + "* If a Tensor has complex shape constraints, we can move its annotation into the function body.\n", + "* Use `T.const` or `T.dynamic` to create shape variables, then annotate complex Tensors with `T.Tensor`.\n", + "* Use `T.empty` to declare return tensors." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "60bf8954", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm(\n", + " A,\n", + " B,\n", + " out_dtype: T.dtype = T.float32,\n", + " block_M: int = 128,\n", + " block_N: int = 128,\n", + " block_K: int = 32,\n", + "):\n", + " M, N, K = T.const(\"M, N, K\")\n", + "\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + "\n", + " C = T.empty((M, N), out_dtype)\n", + "\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])\n", + " return C" + ] + }, + { + "cell_type": "markdown", + "id": "28f868fe", + "metadata": {}, + "source": [ + "Calling the function with Tensors directly triggers the full JIT compile-and-run pipeline:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "ee13394a", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B)\n", + "\n", + "# check output is correct\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "c6705091", + "metadata": {}, + "source": [ + "Changing the call arguments may trigger a recompilation when compilation parameters change:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "d8aab5b7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm(A, B, block_M=64, block_N=64)" + ] + }, + { + "cell_type": "markdown", + "id": "ce6b7391", + "metadata": {}, + "source": [ + "You can also explicitly call the `compile` method to build the kernel.\n", + "\n", + "1. `ker.compile` compiles the kernel\n", + "2. `ker.get_tir` retrieves the TIR\n", + "3. `ker.par_compile` compiles in parallel" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "f3cf3a2d", + "metadata": {}, + "outputs": [], + "source": [ + "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n", + "C = kernel(A, B)" + ] + }, + { + "cell_type": "markdown", + "id": "921761b5", + "metadata": {}, + "source": [ + "## More Tensor Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "4539e54e", + "metadata": {}, + "source": [ + "### Use macros to separate implementation" + ] + }, + { + "cell_type": "markdown", + "id": "ad96ba65", + "metadata": {}, + "source": [ + "Next, we implement a simple GEMM in several different ways. 
For convenience, we first write a macro that contains the core GEMM logic:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "171d4fe6", + "metadata": {}, + "outputs": [], + "source": [ + "@T.macro\n", + "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n", + " B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n", + " C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n", + " T.clear(C_local)\n", + " for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n", + " T.copy(A[bx * block_M, k * block_K], A_shared)\n", + " T.copy(B[k * block_K, by * block_N], B_shared)\n", + " T.gemm(A_shared, B_shared, C_local)\n", + " T.copy(C_local, C[bx * block_M, by * block_N])" + ] + }, + { + "cell_type": "markdown", + "id": "446a1acd", + "metadata": {}, + "source": [ + "### Use `T.dynamic` to mark dynamic shapes\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6a38aa95", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_dyn_K(A, B):\n", + " M, N, K = T.dynamic(\"M, N, K\")\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "fe6cfdc8", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_dyn_K(A, B)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "2ee97bf7", + "metadata": {}, + "source": [ + "### Use `T.StridedTensor` to annotate tensors with strides\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "9dde1dae", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def as_contingious(A):\n", + " M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n", + " A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n", + " B = T.empty((M, N), A.dtype)\n", + " block_M = 128\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n", + " T.copy(\n", + " A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n", + " )\n", + " return B" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "dec2c0a7", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 1024, device=\"cuda\")\n", + "B = as_contingious(A.T)\n", + "B_ref = A.T.contiguous()\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "f5fb20d6", + "metadata": {}, + "source": [ + "## More Annotation" + ] + }, + { + "cell_type": "markdown", + "id": "890df0a2", + "metadata": {}, + "source": [ + "### Use parameters directly as annotations" + ] + }, + { + "cell_type": "markdown", + "id": "e9a47d42", + "metadata": {}, + "source": [ + "You can directly use function parameters in the annotations." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "0fc17af6", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr(\n", + " A,\n", + " B,\n", + " M,\n", + " N,\n", + " K,\n", + "):\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "8e52a554", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "6b19ef90", + "metadata": {}, + "source": [ + "### Annotations for runtime variables" + ] + }, + { + "cell_type": "markdown", + "id": "bba5f27f", + "metadata": {}, + "source": [ + "Runtime variables work the same; if the function annotation becomes too long, you can move it into the function body." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "c1e7598a", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def gemm_ptr_dyn(A, B, M, N, K):\n", + " M: T.int32\n", + " N: T.int32\n", + " K: T.int32\n", + " A: T.Tensor[[M, K], T.float16]\n", + " B: T.Tensor[[K, N], T.float16]\n", + " C = T.empty((M, N), T.float32)\n", + " gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n", + " return C" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "9e9a4c88", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n", + "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n", + "C_ref = (A @ B).float()\n", + "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)" + ] + }, + { + "cell_type": "markdown", + "id": "81427765", + "metadata": {}, + "source": [ + "### Constraints for constants" + ] + }, + { + "cell_type": "markdown", + "id": "4d6b084b", + "metadata": {}, + "source": [ + "A constant annotation created by `T.const` must be used directly at least once, otherwise an error is raised." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "c90dd24f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Constexpr variable `M` is not used in any buffer shape or stride.\n", + "At least one **DIRECT** usage is required. Please check:\n", + "(1) the variable is not used\n", + "(2) all uses are indirect, e.g. M * 2, M * 3. 
(you can replace them with separate constexpr variables)\n", + "Buffer shapes: {A: [M * 2, M * 3]}\n", + "Buffer strides: {A: [M * 3, 1]}\n" + ] + } + ], + "source": [ + "@tilelang.jit\n", + "def example_wrong_kernel(A):\n", + " M = T.const(\"M\")\n", + " A: T.Tensor[[M * 2, M * 3], T.float32]\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + "\n", + "\n", + "try:\n", + " A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + " example_wrong_kernel(A)\n", + "except Exception as e:\n", + " print(e)" + ] + }, + { + "cell_type": "markdown", + "id": "e07e762b", + "metadata": {}, + "source": [ + "### Dynamic dimensions" + ] + }, + { + "cell_type": "markdown", + "id": "f48e5d7a", + "metadata": {}, + "source": [ + "If you want certain parameters in a Tensor annotation to change, it is recommended to switch to the `T.ptr` + `T.match_buffer` style." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "1d050321", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@tilelang.jit\n", + "def dyn_annot(\n", + " A: T.ptr, # 1. T.ptr type annotation\n", + " is_2d=False,\n", + "):\n", + " if is_2d:\n", + " M, N = T.const(\"M, N\")\n", + " # 2. dynamic shape annotation inside function body\n", + " A = T.match_buffer(A, [M, N], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0, 0]\n", + " else:\n", + " L = T.const(\"L\")\n", + " A = T.match_buffer(A, [L], T.float32)\n", + " with T.Kernel(1) as _:\n", + " A[0]\n", + "\n", + "\n", + "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n", + "dyn_annot(A, is_2d=True)" + ] + }, + { + "cell_type": "markdown", + "id": "2e9f1bb3", + "metadata": {}, + "source": [ + "### Default arguments" + ] + }, + { + "cell_type": "markdown", + "id": "f7fc9917", + "metadata": {}, + "source": [ + "Scalar annotations like `T.float32` can carry default values." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "42ec86a1", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def add_one(X, data: T.float32 = 1):\n", + " M, N = T.const(\"M, N\")\n", + " X: T.Tensor[[M, N], T.float32]\n", + " Y = T.empty((M, N), T.float32)\n", + " with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n", + " for i, j in T.Parallel(128, N):\n", + " Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n", + " return Y" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "d49e1120", + "metadata": {}, + "outputs": [], + "source": [ + "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n", + "Y = add_one(X)\n", + "torch.testing.assert_close(Y, X + 1)" + ] + }, + { + "cell_type": "markdown", + "id": "a02baedc", + "metadata": {}, + "source": [ + "## Overhead of argument matching" + ] + }, + { + "cell_type": "markdown", + "id": "860a2972", + "metadata": {}, + "source": [ + "EagerJIT has very small overhead; each additional constant annotation costs about 200 ns.\n", + "* 200 ns is roughly the cost of an FFI call that reads parameters from a `torch.Tensor`'s shape/stride." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc676e33", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Kernel call : 7.68 us\n", + "Parse cache key: 0.41 us\n" + ] + } + ], + "source": [ + "import time\n", + "\n", + "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n", + "\n", + "\n", + "@tilelang.jit\n", + "def dummy_kernel(A, B):\n", + " M, N = T.const(\"M, N\")\n", + " A: T.Tensor[[M, N], T.float16]\n", + " B: T.Tensor[[M, N], T.float16]\n", + " with T.Kernel(1) as _:\n", + " pass\n", + "\n", + "\n", + "# compile it first\n", + "dummy_kernel(A, B)\n", + "\n", + "\n", + "def eval_overhead(f):\n", + " start = time.perf_counter_ns()\n", + " for _ in range(10000):\n", + " f()\n", + " stop = time.perf_counter_ns()\n", + " return (stop - start) / 10000 / 1000\n", + "\n", + "\n", + "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n", + "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n", + "\n", + "print(f\"Kernel call : {kernel_call_overhead:.2f} us\")\n", + "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")" + ] + }, + { + "cell_type": "markdown", + "id": "39166cb4", + "metadata": {}, + "source": [ + "## Compilation and parallel compilation" + ] + }, + { + "cell_type": "markdown", + "id": "8c6fbe08", + "metadata": {}, + "source": [ + "Both EagerJIT and the original `jit` (i.e. LazyJIT) support parallel compilation.\n", + "\n", + "To avoid wasting memory on temporary `torch.Tensor` objects, you can use `T.Tensor` to create placeholders." + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "7222e57b", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Elaborating: 0%| | 0/8 [00:00,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ,\n", + " ]" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from itertools import product\n", + "\n", + "\n", + "def get_configs():\n", + " return [\n", + " {\n", + " \"A\": T.Tensor((1024, 1024), T.float32),\n", + " \"B\": T.Tensor((1024, 1024), T.float32),\n", + " \"block_M\": block_M,\n", + " \"block_N\": block_N,\n", + " \"block_K\": block_K,\n", + " }\n", + " for block_M, block_N, block_K in product([32, 64], repeat=3)\n", + " ]\n", + "\n", + "\n", + "gemm.par_compile(get_configs())" + ] + }, + { + "cell_type": "markdown", + "id": "5160d2cc", + "metadata": {}, + "source": [ + "## More convenient macros" + ] + }, + { + "cell_type": "markdown", + "id": "be44afc4", + "metadata": {}, + "source": [ + "tilelang's macros have been improved:\n", + "\n", + "1. Allow using `T.Ref` as an annotation, similar to C++ references.\n", + "2. Allow returning multiple values.\n", + "3. Allow nesting and recursion." + ] + }, + { + "cell_type": "markdown", + "id": "79575972", + "metadata": {}, + "source": [ + "### Passing references with `T.Ref`\n", + "\n", + "A `T.Ref` reference can point to a scalar variable or to an element of a buffer." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "90eaa6e5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo(x_handle: T.handle):\n", + " x = T.match_buffer(x_handle, (2,), strides=(1,))\n", + " # with T.block(\"root\"):\n", + " bx = T.launch_thread(\"blockIdx.x\", 1)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n", + " T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n", + " T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n", + " idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n", + " x[1] = T.float32(1.0)\n", + " _tmp: T.int32 = idx[0]\n", + " x[_tmp] = T.float32(1.0)" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def macro_with_ref(x: T.Ref):\n", + " x = 1 # noqa: F841\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo(x: T.Tensor((2,))):\n", + " with T.Kernel(1) as _:\n", + " # Supports constant indices\n", + " macro_with_ref(x[1])\n", + "\n", + " # Also supports variable indices\n", + " idx = T.alloc_var(T.int32, 0)\n", + " macro_with_ref(x[idx])\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "markdown", + "id": "7bb447a2", + "metadata": {}, + "source": [ + "### Pass macros as arguments\n", + "\n", + "You can pass a macro as a function argument." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dc7bb779", + "metadata": {}, + "outputs": [], + "source": [ + "@tilelang.jit\n", + "def element_wise(A, fn):\n", + " N = T.dynamic(\"N\")\n", + " A: T.Tensor[[N], T.float32]\n", + " B = T.empty((N,), dtype=A.dtype)\n", + " block_N = 128\n", + " with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n", + " for i in T.Parallel(block_N):\n", + " idx = bx * block_N + i\n", + " B[idx] = fn(A[idx])\n", + " return B\n", + "\n", + "\n", + "@T.macro\n", + "def add_one(x):\n", + " return x + 1" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "a89fdb44", + "metadata": {}, + "outputs": [], + "source": [ + "A = torch.randn(1024, device=\"cuda\")\n", + "B = element_wise(A, add_one)\n", + "B_ref = A + 1\n", + "torch.testing.assert_close(B, B_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "ef6e403a", + "metadata": {}, + "source": [ + "### Recursive macros\n", + "\n", + "You may not need this often, but macros can be recursive as long as the termination condition is known at compile time." 
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "7703cab5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@T.macro\n",
+    "def n31(x, var: T.Ref):\n",
+    "    if x == 1:\n",
+    "        pass\n",
+    "    elif x % 2 == 0:\n",
+    "        var = var // 2\n",
+    "        n31(x // 2, var)\n",
+    "    else:\n",
+    "        var = var * 3 + 1\n",
+    "        n31(x * 3 + 1, var)\n",
+    "\n",
+    "\n",
+    "@tilelang.jit\n",
+    "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
+    "    with T.Kernel(1) as _:\n",
+    "        n31(n, A[0])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "542ddd4e",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "tensor([18], device='cuda:0', dtype=torch.int32)"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n",
+    "foo(A, 5)\n",
+    "A"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "dc30c2d2",
+   "metadata": {},
+   "source": [
+    "### Macros returning multiple values"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "d5a2388f",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "# from tvm.script import tir as T\n",
+       "\n",
+       "@T.prim_func\n",
+       "def foo():\n",
+       "    # with T.block(\"root\"):\n",
+       "    x = T.launch_thread(\"blockIdx.x\", 32)\n",
+       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
+       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
+       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
+       "    with T.block(\"tilelang_root\"):\n",
+       "        T.reads()\n",
+       "        T.writes()\n",
+       "        s: T.int32 = T.sin(x)\n",
+       "        c: T.int32 = T.cos(x)\n",
+       "        a: T.int32 = s + c\n",
+       "        b: T.int32 = s - c\n",
+       "        T.evaluate(0)"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "@T.macro\n",
+    "def sincos(x):\n",
+    "    return T.sin(x), T.cos(x)\n",
+    "\n",
+    "\n",
+    "@T.prim_func\n",
+    "def foo():\n",
+    "    with T.Kernel(32) as x:\n",
+    "        s, c = sincos(x)\n",
+    "        a = s + c  # noqa: F841\n",
+    "        b = s - c  # noqa: F841\n",
+    "\n",
+    "\n",
+    "foo"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dd83fea7",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "tilelang-dev_0",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.12.9"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/examples/eager_jit/eagerjit.zh.ipynb b/examples/eager_jit/eagerjit.zh.ipynb
new file mode 100644
index 0000000000..0f7c9be99e
--- /dev/null
+++ b/examples/eager_jit/eagerjit.zh.ipynb
@@ -0,0 +1,977 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5e0deecc",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import sys\n",
+    "from pathlib import Path\n",
+    "\n",
+    "sys.path.insert(0, str(Path.cwd().parent.parent.absolute()))\n",
+    "import tilelang\n",
+    "import torch\n",
+    "import tilelang.language as T"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ca2c56d",
+   "metadata": {},
+   "source": [
+    "# Tilelang Eager JIT"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "156e7370",
+   "metadata": {},
+   "source": [
+    "## Tensor Annotation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b070c109",
+   "metadata": {},
+   "source": [
+    "Tilelang Eager JIT merges JIT kernel generation and invocation into a single workflow.\n",
+    "\n",
+    "The function signature looks similar to Triton, but we add many enhancements; the most important one is allowing rich Tensor annotations:\n",
+    "\n",
+    "* If a Tensor has complex shape constraints, we can move its annotation into the function body.\n",
+    "* Use `T.const` or `T.dynamic` to create shape variables, then annotate complex Tensors with `T.Tensor`.\n",
+    "* Use `T.empty` to declare return tensors."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "60bf8954",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def gemm(\n",
+    "    A,\n",
+    "    B,\n",
+    "    out_dtype: T.dtype = T.float32,\n",
+    "    block_M: int = 128,\n",
+    "    block_N: int = 128,\n",
+    "    block_K: int = 32,\n",
+    "):\n",
+    "    M, N, K = T.const(\"M, N, K\")\n",
+    "\n",
+    "    A: T.Tensor[[M, K], T.float16]\n",
+    "    B: T.Tensor[[K, N], T.float16]\n",
+    "\n",
+    "    C = T.empty((M, N), out_dtype)\n",
+    "\n",
+    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
+    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
+    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
+    "        C_local = T.alloc_fragment((block_M, block_N), out_dtype)\n",
+    "        T.clear(C_local)\n",
+    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
+    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
+    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
+    "            T.gemm(A_shared, B_shared, C_local)\n",
+    "        T.copy(C_local, C[bx * block_M, by * block_N])\n",
+    "    return C"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "28f868fe",
+   "metadata": {},
+   "source": [
+    "Calling the function with Tensors directly triggers the full JIT compile-and-run pipeline:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "ee13394a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
+    "C = gemm(A, B)\n",
+    "\n",
+    "# check output is correct\n",
+    "C_ref = (A @ B).float()\n",
+    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "c6705091",
+   "metadata": {},
+   "source": [
+    "Changing the call arguments may trigger a recompilation when compilation parameters change:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "d8aab5b7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(512, 1024, dtype=torch.float16, device=\"cuda\")\n",
+    "C = gemm(A, B, block_M=64, block_N=64)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ce6b7391",
+   "metadata": {},
+   "source": [
+    "You can also explicitly call the `compile` method to build the kernel.\n",
+    "\n",
+    "1. `ker.compile` compiles the kernel\n",
+    "2. `ker.get_tir` retrieves the TIR\n",
+    "3. `ker.par_compile` compiles in parallel"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "f3cf3a2d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "kernel = gemm.compile(A, B, block_M=64, block_N=64)\n",
+    "C = kernel(A, B)"
+   ]
+  },
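+  {
+   "cell_type": "markdown",
+   "id": "0a1b2c3d",
+   "metadata": {},
+   "source": [
+    "A quick sketch of inspecting the TIR (cell added for illustration; we assume `get_tir` accepts the same arguments as `compile` — adjust if the actual signature differs):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1e2f3a4b",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Assumption: get_tir mirrors compile's argument convention (see note above).\n",
+    "tir_mod = gemm.get_tir(A, B, block_M=64, block_N=64)\n",
+    "print(tir_mod)"
+   ]
+  },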
+  {
+   "cell_type": "markdown",
+   "id": "921761b5",
+   "metadata": {},
+   "source": [
+    "## More Tensor Annotation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4539e54e",
+   "metadata": {},
+   "source": [
+    "### Use macros to separate implementation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ad96ba65",
+   "metadata": {},
+   "source": [
+    "Next, we implement a simple GEMM in several different ways. For convenience, we first write a macro that contains the core GEMM logic:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "171d4fe6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@T.macro\n",
+    "def gemm_impl(A, B, C, M, N, K, block_M, block_N, block_K):\n",
+    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
+    "        A_shared = T.alloc_shared((block_M, block_K), A.dtype)\n",
+    "        B_shared = T.alloc_shared((block_K, block_N), B.dtype)\n",
+    "        C_local = T.alloc_fragment((block_M, block_N), C.dtype)\n",
+    "        T.clear(C_local)\n",
+    "        for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=3):\n",
+    "            T.copy(A[bx * block_M, k * block_K], A_shared)\n",
+    "            T.copy(B[k * block_K, by * block_N], B_shared)\n",
+    "            T.gemm(A_shared, B_shared, C_local)\n",
+    "        T.copy(C_local, C[bx * block_M, by * block_N])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "446a1acd",
+   "metadata": {},
+   "source": [
+    "### Use `T.dynamic` to mark dynamic shapes\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "id": "6a38aa95",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def gemm_dyn_K(A, B):\n",
+    "    M, N, K = T.dynamic(\"M, N, K\")\n",
+    "    A: T.Tensor[[M, K], T.float16]\n",
+    "    B: T.Tensor[[K, N], T.float16]\n",
+    "    C = T.empty((M, N), T.float32)\n",
+    "    gemm_impl(A, B, C, M, N, K, 128, 128, 32)\n",
+    "    return C"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "fe6cfdc8",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
+    "C = gemm_dyn_K(A, B)\n",
+    "C_ref = (A @ B).float()\n",
+    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2ee97bf7",
+   "metadata": {},
+   "source": [
+    "### Use `T.StridedTensor` to annotate tensors with strides\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "9dde1dae",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def as_contingious(A):\n",
+    "    M, N, dM, dN = T.dynamic(\"M, N, dM, dN\")\n",
+    "    A: T.StridedTensor[[M, N], [dM, dN], T.float32]\n",
+    "    B = T.empty((M, N), A.dtype)\n",
+    "    block_M = 128\n",
+    "    block_N = 128\n",
+    "    with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by):\n",
+    "        T.copy(\n",
+    "            A[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
+    "            B[bx * block_M : (bx + 1) * block_M, by * block_N : (by + 1) * block_N],\n",
+    "        )\n",
+    "    return B"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "id": "dec2c0a7",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 1024, device=\"cuda\")\n",
+    "B = as_contingious(A.T)\n",
+    "B_ref = A.T.contiguous()\n",
+    "torch.testing.assert_close(B, B_ref)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f5fb20d6",
+   "metadata": {},
+   "source": [
+    "## More Annotation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "890df0a2",
+   "metadata": {},
+   "source": [
+    "### Use parameters directly as annotations"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e9a47d42",
+   "metadata": {},
+   "source": [
+    "You can directly use function parameters in the annotations."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 11,
+   "id": "0fc17af6",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def gemm_ptr(\n",
+    "    A,\n",
+    "    B,\n",
+    "    M,\n",
+    "    N,\n",
+    "    K,\n",
+    "):\n",
+    "    A: T.Tensor[[M, K], T.float16]\n",
+    "    B: T.Tensor[[K, N], T.float16]\n",
+    "    C = T.empty((M, N), T.float32)\n",
+    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
+    "    return C"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 12,
+   "id": "8e52a554",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
+    "C = gemm_ptr(A, B, 1024, 256, 512)\n",
+    "C_ref = (A @ B).float()\n",
+    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6b19ef90",
+   "metadata": {},
+   "source": [
+    "### Annotations for runtime variables"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "bba5f27f",
+   "metadata": {},
+   "source": [
+    "Runtime variables work the same; if the function annotation becomes too long, you can move it into the function body."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 13,
+   "id": "c1e7598a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def gemm_ptr_dyn(A, B, M, N, K):\n",
+    "    M: T.int32\n",
+    "    N: T.int32\n",
+    "    K: T.int32\n",
+    "    A: T.Tensor[[M, K], T.float16]\n",
+    "    B: T.Tensor[[K, N], T.float16]\n",
+    "    C = T.empty((M, N), T.float32)\n",
+    "    gemm_impl(A, B, C, M, N, K, block_M=128, block_N=128, block_K=32)\n",
+    "    return C"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 14,
+   "id": "9e9a4c88",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, 512, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(512, 256, dtype=torch.float16, device=\"cuda\")\n",
+    "C = gemm_ptr_dyn(A, B, 1024, 256, 512)\n",
+    "C_ref = (A @ B).float()\n",
+    "torch.testing.assert_close(C, C_ref, rtol=1e-2, atol=1e-2)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "81427765",
+   "metadata": {},
+   "source": [
+    "### Constraints for constants"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "4d6b084b",
+   "metadata": {},
+   "source": [
+    "A constant annotation created by `T.const` must be used directly at least once, otherwise an error is raised."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 15,
+   "id": "c90dd24f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Constexpr variable `M` is not used in any buffer shape or stride.\n",
+      "At least one **DIRECT** usage is required. Please check:\n",
+      "(1) the variable is not used\n",
+      "(2) all uses are indirect, e.g. M * 2, M * 3. (you can replace them with separate constexpr variables)\n",
+      "Buffer shapes: {A: [M * 2, M * 3]}\n",
+      "Buffer strides: {A: [M * 3, 1]}\n"
+     ]
+    }
+   ],
+   "source": [
+    "@tilelang.jit\n",
+    "def example_wrong_kernel(A):\n",
+    "    M = T.const(\"M\")\n",
+    "    A: T.Tensor[[M * 2, M * 3], T.float32]\n",
+    "    with T.Kernel(1) as _:\n",
+    "        A[0, 0]\n",
+    "\n",
+    "\n",
+    "try:\n",
+    "    A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n",
+    "    example_wrong_kernel(A)\n",
+    "except Exception as e:\n",
+    "    print(e)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "e07e762b",
+   "metadata": {},
+   "source": [
+    "### Dynamic dimensions"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f48e5d7a",
+   "metadata": {},
+   "source": [
+    "If you want certain parameters in a Tensor annotation to change, it is recommended to switch to the `T.ptr` + `T.match_buffer` style."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 16,
+   "id": "1d050321",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "[]"
+      ]
+     },
+     "execution_count": 16,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "@tilelang.jit\n",
+    "def dyn_annot(\n",
+    "    A: T.ptr,  # 1. T.ptr type annotation\n",
+    "    is_2d=False,\n",
+    "):\n",
+    "    if is_2d:\n",
+    "        M, N = T.const(\"M, N\")\n",
+    "        # 2. dynamic shape annotation inside function body\n",
+    "        A = T.match_buffer(A, [M, N], T.float32)\n",
+    "        with T.Kernel(1) as _:\n",
+    "            A[0, 0]\n",
+    "    else:\n",
+    "        L = T.const(\"L\")\n",
+    "        A = T.match_buffer(A, [L], T.float32)\n",
+    "        with T.Kernel(1) as _:\n",
+    "            A[0]\n",
+    "\n",
+    "\n",
+    "A = torch.randn(64, 96, dtype=torch.float32, device=\"cuda\")\n",
+    "dyn_annot(A, is_2d=True)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "2e9f1bb3",
+   "metadata": {},
+   "source": [
+    "### Default arguments"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "f7fc9917",
+   "metadata": {},
+   "source": [
+    "Scalar annotations like `T.float32` can carry default values."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 17,
+   "id": "42ec86a1",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def add_one(X, data: T.float32 = 1):\n",
+    "    M, N = T.const(\"M, N\")\n",
+    "    X: T.Tensor[[M, N], T.float32]\n",
+    "    Y = T.empty((M, N), T.float32)\n",
+    "    with T.Kernel(T.ceildiv(M, 128), threads=128) as bx:\n",
+    "        for i, j in T.Parallel(128, N):\n",
+    "            Y[bx * 128 + i, j] = X[bx * 128 + i, j] + data\n",
+    "    return Y"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 18,
+   "id": "d49e1120",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "X = torch.randn(1024, 1024, dtype=torch.float32, device=\"cuda\")\n",
+    "Y = add_one(X)\n",
+    "torch.testing.assert_close(Y, X + 1)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a02baedc",
+   "metadata": {},
+   "source": [
+    "## Overhead of argument matching"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "860a2972",
+   "metadata": {},
+   "source": [
+    "EagerJIT has very small overhead; each additional constant annotation costs about 200 ns.\n",
+    "* 200 ns is roughly the cost of an FFI call that reads parameters from a `torch.Tensor`'s shape/stride."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "dc676e33",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Kernel call    : 7.68 us\n",
+      "Parse cache key: 0.41 us\n"
+     ]
+    }
+   ],
+   "source": [
+    "import time\n",
+    "\n",
+    "A = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
+    "B = torch.randn(128, 128, dtype=torch.float16, device=\"cuda\")\n",
+    "\n",
+    "\n",
+    "@tilelang.jit\n",
+    "def dummy_kernel(A, B):\n",
+    "    M, N = T.const(\"M, N\")\n",
+    "    A: T.Tensor[[M, N], T.float16]\n",
+    "    B: T.Tensor[[M, N], T.float16]\n",
+    "    with T.Kernel(1) as _:\n",
+    "        pass\n",
+    "\n",
+    "\n",
+    "# compile it first\n",
+    "dummy_kernel(A, B)\n",
+    "\n",
+    "\n",
+    "def eval_overhead(f):\n",
+    "    start = time.perf_counter_ns()\n",
+    "    for _ in range(10000):\n",
+    "        f()\n",
+    "    stop = time.perf_counter_ns()\n",
+    "    return (stop - start) / 10000 / 1000\n",
+    "\n",
+    "\n",
+    "kernel_call_overhead = eval_overhead(lambda: dummy_kernel(A, B))\n",
+    "parse_cache_key_overhead = eval_overhead(lambda: dummy_kernel.parse_cache_key(A, B))\n",
+    "\n",
+    "print(f\"Kernel call    : {kernel_call_overhead:.2f} us\")\n",
+    "print(f\"Parse cache key: {parse_cache_key_overhead:.2f} us\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "39166cb4",
+   "metadata": {},
+   "source": [
+    "## Compilation and parallel compilation"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "8c6fbe08",
+   "metadata": {},
+   "source": [
+    "Both EagerJIT and the original `jit` (i.e. LazyJIT) support parallel compilation.\n",
+    "\n",
+    "To avoid wasting memory on temporary `torch.Tensor` objects, you can use `T.Tensor` to create placeholders."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "7222e57b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "application/vnd.jupyter.widget-view+json": {
+       "model_id": "8a4e4eb3cd4445bda6e8693da31ef3b8",
+       "version_major": 2,
+       "version_minor": 0
+      },
+      "text/plain": [
+       "Elaborating:   0%|          | 0/8 [00:00,\n",
+       " ,\n",
+       " ,\n",
+       " ,\n",
+       " ,\n",
+       " ,\n",
+       " ,\n",
+       " ]"
+      ]
+     },
+     "execution_count": 20,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "from itertools import product\n",
+    "\n",
+    "\n",
+    "def get_configs():\n",
+    "    return [\n",
+    "        {\n",
+    "            \"A\": T.Tensor((1024, 1024), T.float32),\n",
+    "            \"B\": T.Tensor((1024, 1024), T.float32),\n",
+    "            \"block_M\": block_M,\n",
+    "            \"block_N\": block_N,\n",
+    "            \"block_K\": block_K,\n",
+    "        }\n",
+    "        for block_M, block_N, block_K in product([32, 64], repeat=3)\n",
+    "    ]\n",
+    "\n",
+    "\n",
+    "gemm.par_compile(get_configs())"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "5160d2cc",
+   "metadata": {},
+   "source": [
+    "## More convenient macros"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "be44afc4",
+   "metadata": {},
+   "source": [
+    "tilelang's macros have been improved:\n",
+    "\n",
+    "1. Allow using `T.Ref` as an annotation, similar to C++ references.\n",
+    "2. Allow returning multiple values.\n",
+    "3. Allow nesting and recursion."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "79575972",
+   "metadata": {},
+   "source": [
+    "### Passing references with `T.Ref`\n",
+    "\n",
+    "A `T.Ref` reference can point to a scalar variable or to an element of a buffer."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "id": "90eaa6e5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "# from tvm.script import tir as T\n",
+       "\n",
+       "@T.prim_func\n",
+       "def foo(x_handle: T.handle):\n",
+       "    x = T.match_buffer(x_handle, (2,), strides=(1,))\n",
+       "    # with T.block(\"root\"):\n",
+       "    bx = T.launch_thread(\"blockIdx.x\", 1)\n",
+       "    tx = T.launch_thread(\"threadIdx.x\", 128)\n",
+       "    ty = T.launch_thread(\"threadIdx.y\", 1)\n",
+       "    tz = T.launch_thread(\"threadIdx.z\", 1)\n",
+       "    with T.block(\"tilelang_root\"):\n",
+       "        T.reads()\n",
+       "        idx = T.Buffer((1,), \"int32\", scope=\"local.var\")\n",
+       "        T.writes(x[T.min(1, idx[0]):T.min(1, idx[0]) + (T.max(1, idx[0]) + 1 - T.min(1, idx[0]))])\n",
+       "        T.block_attr({\"tl.local_var_init\": {idx.data: 0}})\n",
+       "        idx = T.alloc_buffer((1,), \"int32\", data=idx.data, scope=\"local.var\")\n",
+       "        x[1] = T.float32(1.0)\n",
+       "        _tmp: T.int32 = idx[0]\n",
+       "        x[_tmp] = T.float32(1.0)"
+      ]
+     },
+     "execution_count": 21,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "@T.macro\n",
+    "def macro_with_ref(x: T.Ref):\n",
+    "    x = 1  # noqa: F841\n",
+    "\n",
+    "\n",
+    "@T.prim_func\n",
+    "def foo(x: T.Tensor((2,))):\n",
+    "    with T.Kernel(1) as _:\n",
+    "        # Supports constant indices\n",
+    "        macro_with_ref(x[1])\n",
+    "\n",
+    "        # Also supports variable indices\n",
+    "        idx = T.alloc_var(T.int32, 0)\n",
+    "        macro_with_ref(x[idx])\n",
+    "\n",
+    "\n",
+    "foo"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "7bb447a2",
+   "metadata": {},
+   "source": [
+    "### Pass macros as arguments\n",
+    "\n",
+    "You can pass a macro as a function argument."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "id": "dc7bb779",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@tilelang.jit\n",
+    "def element_wise(A, fn):\n",
+    "    N = T.dynamic(\"N\")\n",
+    "    A: T.Tensor[[N], T.float32]\n",
+    "    B = T.empty((N,), dtype=A.dtype)\n",
+    "    block_N = 128\n",
+    "    with T.Kernel(T.ceildiv(N, block_N), threads=128) as bx:\n",
+    "        for i in T.Parallel(block_N):\n",
+    "            idx = bx * block_N + i\n",
+    "            B[idx] = fn(A[idx])\n",
+    "    return B\n",
+    "\n",
+    "\n",
+    "@T.macro\n",
+    "def add_one(x):\n",
+    "    return x + 1"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "id": "a89fdb44",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "A = torch.randn(1024, device=\"cuda\")\n",
+    "B = element_wise(A, add_one)\n",
+    "B_ref = A + 1\n",
+    "torch.testing.assert_close(B, B_ref)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "ef6e403a",
+   "metadata": {},
+   "source": [
+    "### Recursive macros\n",
+    "\n",
+    "You may not need this often, but macros can be recursive as long as the termination condition is known at compile time."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "7703cab5",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "@T.macro\n",
+    "def n31(x, var: T.Ref):\n",
+    "    if x == 1:\n",
+    "        pass\n",
+    "    elif x % 2 == 0:\n",
+    "        var = var // 2\n",
+    "        n31(x // 2, var)\n",
+    "    else:\n",
+    "        var = var * 3 + 1\n",
+    "        n31(x * 3 + 1, var)\n",
+    "\n",
+    "\n",
+    "@tilelang.jit\n",
+    "def foo(A: T.Tensor[[1], T.int32], n: int):\n",
+    "    with T.Kernel(1) as _:\n",
+    "        n31(n, A[0])"
+   ]
+  },
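+  {
+   "cell_type": "markdown",
+   "id": "3c4d5e6f",
+   "metadata": {},
+   "source": [
+    "Worked trace (added for illustration): with `n = 5` and `A = [100]`, the compile-time argument follows the Collatz steps 5 → 16 → 8 → 4 → 2 → 1, while `A[0]` mirrors each step at runtime: 100 → 301 → 150 → 75 → 37 → 18."
+   ]
+  },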
"source": [ + "A = torch.tensor([100], dtype=torch.int32, device=\"cuda\")\n", + "foo(A, 5)\n", + "A" + ] + }, + { + "cell_type": "markdown", + "id": "dc30c2d2", + "metadata": {}, + "source": [ + "### Macro 返回多个值" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "id": "d5a2388f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "# from tvm.script import tir as T\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " # with T.block(\"root\"):\n", + " x = T.launch_thread(\"blockIdx.x\", 32)\n", + " tx = T.launch_thread(\"threadIdx.x\", 128)\n", + " ty = T.launch_thread(\"threadIdx.y\", 1)\n", + " tz = T.launch_thread(\"threadIdx.z\", 1)\n", + " with T.block(\"tilelang_root\"):\n", + " T.reads()\n", + " T.writes()\n", + " s: T.int32 = T.sin(x)\n", + " c: T.int32 = T.cos(x)\n", + " a: T.int32 = s + c\n", + " b: T.int32 = s - c\n", + " T.evaluate(0)" + ] + }, + "execution_count": 26, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "@T.macro\n", + "def sincos(x):\n", + " return T.sin(x), T.cos(x)\n", + "\n", + "\n", + "@T.prim_func\n", + "def foo():\n", + " with T.Kernel(32) as x:\n", + " s, c = sincos(x)\n", + " a = s + c # noqa: F841\n", + " b = s - c # noqa: F841\n", + "\n", + "\n", + "foo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dd83fea7", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "tilelang-dev_0", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/elementwise/example_elementwise_add.py b/examples/elementwise/example_elementwise_add.py index 32da940155..3d142ed542 100644 --- a/examples/elementwise/example_elementwise_add.py +++ b/examples/elementwise/example_elementwise_add.py @@ -1,5 +1,4 @@ import argparse -import itertools import torch import tilelang import tilelang.language as T @@ -9,15 +8,6 @@ def ref_program(x, y): return x + y -def get_configs(): - block_M = [64, 128, 256] - block_N = [64, 128, 256] - threads = [64, 128, 256] - configs = list(itertools.product(block_M, block_N, threads)) - return [{"block_M": bm, "block_N": bn, "threads": th} for bm, bn, th in configs] - - -@tilelang.autotune(configs=get_configs()) @tilelang.jit(out_idx=[-1]) def elementwise_add(M, N, block_M, block_N, in_dtype, out_dtype, threads): @T.prim_func @@ -42,12 +32,7 @@ def main(M=1024, N=1024, use_autotune=False): a = torch.randn(M, N, dtype=torch.float32, device="cuda") b = torch.randn(M, N, dtype=torch.float32, device="cuda") - if use_autotune: - kernel = elementwise_add(M, N, in_dtype=T.float32, out_dtype=T.float32) - else: - # Default config - config = {"block_M": 32, "block_N": 32, "threads": 128} - kernel = elementwise_add(M, N, **config, in_dtype=T.float32, out_dtype=T.float32) + kernel = elementwise_add(M, N, block_M=32, block_N=32, threads=128, in_dtype=T.float32, out_dtype=T.float32) out = kernel(a, b) torch.testing.assert_close(out, ref_program(a, b), rtol=1e-2, atol=1e-2) @@ -72,6 +57,5 @@ def run_regression_perf(): parser = argparse.ArgumentParser() parser.add_argument("--m", type=int, default=1024) parser.add_argument("--n", type=int, default=1024) - parser.add_argument("--use_autotune", action="store_true", 
default=False) args, _ = parser.parse_known_args() - main(args.m, args.n, args.use_autotune) + main(args.m, args.n) diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce.py b/examples/flash_attention/example_gqa_bwd_tma_reduce.py index fea547b6e6..4920d8cf06 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce.py @@ -5,8 +5,6 @@ from tilelang.contrib import nvcc import argparse -tilelang.disable_cache() - @tilelang.jit( out_idx=[3, 4], @@ -49,13 +47,13 @@ def flash_fwd( T.fill(logsum, 0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead - T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + T.fill(scores_max, T.cast(-1e30, accum_dtype)) loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) for k in T.Pipelined(loop_range, num_stages=1): T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) if is_causal: for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.Cast(accum_dtype, -1e30)) + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, T.cast(-1e30, accum_dtype)) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) @@ -211,14 +209,6 @@ def flash_bwd( dv_shared = T.alloc_shared([block_M, dim_v], accum_dtype) dq_shared = T.alloc_shared([block_N, dim_qk], accum_dtype) - T.annotate_layout( - { - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - } - ) - T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) T.clear(dv) @@ -389,7 +379,6 @@ def maybe_contiguous(x): block_M = 128 block_N = 32 mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD_V) - mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) delta = mod_prep(o, do) if ctx.use_atomic: @@ -403,11 +392,11 @@ def maybe_contiguous(x): dk = torch.zeros(shape_k, dtype=torch.float32, device=q.device) dv = torch.zeros(shape_v, dtype=torch.float32, device=q.device) kernel(q, k, v, do, lse, delta, dq, dk, dv) - dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split_novarlen( BATCH, H, N_CTX, D_HEAD_QK, D_HEAD_V, ctx.causal, block_M, block_N, threads=256, num_stages=2, groups=groups ) + mod_post = flashattn_bwd_postprocess(BATCH, H, HEAD_KV, N_CTX, D_HEAD_QK, D_HEAD_V) shape_q = [BATCH, N_CTX, H, D_HEAD_QK] shape_k = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_QK] # sum after kernel shape_v = [groups, BATCH, N_CTX, HEAD_KV, D_HEAD_V] # sum after kernel diff --git a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py index a9f45e077d..b09eec00c4 100644 --- a/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py +++ b/examples/flash_attention/example_gqa_bwd_tma_reduce_varlen.py @@ -76,7 +76,7 @@ def flash_fwd( T.fill(logsum, 0.0) # Warning: in causal/varlen/unaligned seqlen scenarios, the -inf will cause undefined behavior in exp ops # We should set it to negative large number instead - T.fill(scores_max, T.Cast(accum_dtype, -1e30)) + T.fill(scores_max, T.cast(-1e30, accum_dtype)) loop_range = T.ceildiv(k_current_seqlen, block_N) for k in T.Pipelined(loop_range, num_stages=1): for i, d in 
T.Parallel(block_N, dim_qk): @@ -91,12 +91,12 @@ def flash_fwd( (bx * block_M + i >= k * block_N + j) and (bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen), 0, - T.Cast(accum_dtype, -1e30), + T.cast(-1e30, accum_dtype), ) else: for i, j in T.Parallel(block_M, block_N): acc_s[i, j] = T.if_then_else( - bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.Cast(accum_dtype, -1e30) + bx * block_M + i < q_current_seqlen and k * block_N + j < k_current_seqlen, 0, T.cast(-1e30, accum_dtype) ) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, d in T.Parallel(block_N, dim_v): @@ -286,14 +286,6 @@ def flash_bwd( q_current_seqlen = q_end_idx - q_start_idx k_current_seqlen = k_end_idx - k_start_idx - T.annotate_layout( - { - dQ: make_dq_layout(dQ), - dK: make_dq_layout(dK), - dV: make_dq_layout(dV), - } - ) - T.copy(K[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], K_shared) T.copy(V[k_start_idx + by * block_M : k_start_idx + (by + 1) * block_M, bx // groups, :], V_shared) @@ -508,8 +500,8 @@ def forward( total_q = q_unpad.shape[0] total_kv = k_unpad.shape[0] - mod = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) - o_unpad, lse = mod(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) + kernel = flashattn_fwd(BATCH, total_q, total_kv, N_CTX, H, max_seqlen_q, D_HEAD_QK, D_HEAD_V, causal, block_M, block_N, groups) + o_unpad, lse = kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k) o = pad_input(o_unpad, indices_q, BATCH, N_CTX) ctx.save_for_backward(q_unpad, k_unpad, v_unpad, o_unpad, lse, seqlens_q, seqlens_k, cu_seqlens_q, cu_seqlens_k) ctx.batch = BATCH @@ -541,7 +533,6 @@ def maybe_contiguous(x): block_M = 128 block_N = 32 mod_prep = flashattn_bwd_preprocess(BATCH, H, total_q, N_CTX, ctx.max_seqlen_q, D_HEAD_V) - mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) delta = mod_prep(o, do, cu_seqlens_q) if ctx.use_atomic: @@ -565,7 +556,6 @@ def maybe_contiguous(x): dk = torch.zeros_like(k, dtype=torch.float32) dv = torch.zeros_like(v, dtype=torch.float32) kernel(q, k, v, do, lse_clone, delta, cu_seqlens_q, cu_seqlens_k, dq, dk, dv) - dq, dk, dv = mod_post(dq, dk, dv) else: kernel = flashattn_bwd_split( BATCH, @@ -583,6 +573,7 @@ def maybe_contiguous(x): num_stages=2, groups=groups, ) + mod_post = flashattn_bwd_postprocess(total_q, total_kv, H, HEAD_KV, D_HEAD_QK, D_HEAD_V) dq = torch.zeros_like(q, dtype=torch.float32) dk = torch.empty(groups, *k.shape, dtype=torch.float16, device=q.device) dv = torch.empty(groups, *v.shape, dtype=torch.float16, device=q.device) diff --git a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py b/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py deleted file mode 100644 index c0fe4e33d2..0000000000 --- a/examples/flash_attention/example_mha_bwd_bshd_wgmma_pipelined.py +++ /dev/null @@ -1,353 +0,0 @@ -import torch -import torch.nn.functional as F -import tilelang -import tilelang.language as T -from tilelang.profiler import do_bench -import argparse - - -@tilelang.jit( - out_idx=[3, 4], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, -) -def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape = [batch, seq_len, heads, dim] - dtype = T.float16 - accum_dtype = T.float32 - - @T.prim_func - def 
flash_fwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - Output: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - ): - with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): - Q_shared = T.alloc_shared([block_M, dim], dtype) - K_shared = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_N, dim], dtype) - acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) - acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) - acc_o = T.alloc_fragment([block_M, dim], accum_dtype) - scores_max = T.alloc_fragment([block_M], accum_dtype) - scores_max_prev = T.alloc_fragment([block_M], accum_dtype) - scores_scale = T.alloc_fragment([block_M], accum_dtype) - scores_sum = T.alloc_fragment([block_M], accum_dtype) - logsum = T.alloc_fragment([block_M], accum_dtype) - - T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) - T.fill(acc_o, 0) - T.fill(logsum, 0) - T.fill(scores_max, -T.infinity(accum_dtype)) - - loop_range = T.ceildiv((bx + 1) * block_M, block_N) if is_causal else T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_range, num_stages=1): - T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) - else: - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) - T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) - T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) - T.copy(scores_max, scores_max_prev) - T.reduce_max(acc_s, scores_max, dim=1, clear=False) - for i in T.Parallel(block_M): - scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) - for i in T.Parallel(block_M): - scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] *= scores_scale[i] - for i, j in T.Parallel(block_M, block_N): - acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) - T.copy(acc_s, acc_s_cast) - T.gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) - T.reduce_sum(acc_s, scores_sum, dim=1) - for i in T.Parallel(block_M): - logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] - for i, j in T.Parallel(block_M, dim): - acc_o[i, j] /= logsum[i] - T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) - for i in T.Parallel(block_M): - logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale - T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) - - return flash_fwd - - -@tilelang.jit( - out_idx=[2], - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - }, -) -def flashattn_bwd_preprocess(batch, heads, seq_len, dim): - dtype = T.float16 - accum_dtype = T.float32 - shape = [batch, seq_len, heads, dim] - blk = 32 - - @T.prim_func - def flash_bwd_prep( - O: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - ): - with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): - o = T.alloc_fragment([blk, blk], dtype) - do = T.alloc_fragment([blk, blk], dtype) - acc = T.alloc_fragment([blk, blk], accum_dtype) - delta = T.alloc_fragment([blk], accum_dtype) - T.clear(acc) - 
for k in range(T.ceildiv(dim, blk)): - T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) - T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) - for i, j in T.Parallel(blk, blk): - acc[i, j] += o[i, j] * do[i, j] - T.reduce_sum(acc, delta, 1) - T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) - - return flash_bwd_prep - - -@tilelang.jit( - pass_configs={ - tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, - } -) -def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): - sm_scale = (1.0 / dim) ** 0.5 - scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) - shape = [batch, seq_len, heads, dim] - dtype = T.float16 - accum_dtype = T.float32 - - @T.prim_func - def flash_bwd( - Q: T.Tensor(shape, dtype), # type: ignore - K: T.Tensor(shape, dtype), # type: ignore - V: T.Tensor(shape, dtype), # type: ignore - dO: T.Tensor(shape, dtype), # type: ignore - lse: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - Delta: T.Tensor([batch, heads, seq_len], accum_dtype), # type: ignore - dQ: T.Tensor(shape, accum_dtype), # type: ignore - dK: T.Tensor(shape, dtype), # type: ignore - dV: T.Tensor(shape, dtype), # type: ignore - ): - with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=256) as (bx, by, bz): - K_shared = T.alloc_shared([block_M, dim], dtype) - dsT_shared = T.alloc_shared([block_M, block_N], dtype) - # should not store K to local if dim is large - # K_local = T.alloc_fragment([block_M, dim], dtype) - # K_local_T = T.alloc_fragment([block_M, dim], dtype) - # V_local = T.alloc_fragment([block_M, dim], dtype) - q = T.alloc_shared([block_N, dim], dtype) - V_shared = T.alloc_shared([block_M, dim], dtype) - qkT = T.alloc_fragment([block_M, block_N], accum_dtype) - dsT = T.alloc_fragment([block_M, block_N], accum_dtype) - qkT_cast = T.alloc_fragment([block_M, block_N], dtype) - dsT_cast = T.alloc_fragment([block_M, block_N], dtype) - lse_shared = T.alloc_shared([block_N], accum_dtype) - delta = T.alloc_shared([block_N], accum_dtype) - do = T.alloc_shared([block_N, dim], dtype) - dv = T.alloc_fragment([block_M, dim], accum_dtype) - dk = T.alloc_fragment([block_M, dim], accum_dtype) - dq = T.alloc_fragment([block_N, dim], accum_dtype) - dv_shared = T.alloc_shared([block_M, dim], dtype) - dk_shared = T.alloc_shared([block_M, dim], dtype) - dq_shared = T.alloc_shared([block_N, dim], accum_dtype) - - T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) - T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) - T.clear(dv) - T.clear(dk) - loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 - loop_ed = T.ceildiv(seq_len, block_N) - for k in T.Pipelined(loop_st, loop_ed, num_stages=2): - T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) - T.clear(qkT) - T.gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) - T.clear(dsT) - T.gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - T.wait_wgmma(1) - - T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) - for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) - if is_causal: - for i, j in T.Parallel(block_M, block_N): - qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) - # We don't need to handle OOB positions for non-causal cases, - # since OOB values won't affect other positions here. 
- T.wait_wgmma(0) - T.copy(qkT, qkT_cast) - T.gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow, wg_wait=-1) - - T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) - - for i, j in T.Parallel(block_M, block_N): - dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale - T.gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow, wg_wait=1) - - T.copy(dsT_cast, dsT_shared) - T.clear(dq) - T.gemm(dsT_shared, K_shared, dq, transpose_A=True, wg_wait=1) - T.wait_wgmma(0) - T.copy(dq, dq_shared) - T.atomic_add(dQ[bz, k * block_N : (k + 1) * block_N, bx, :], dq_shared) - T.copy(dv, dv_shared) - T.copy(dk, dk_shared) - T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) - T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) - - return flash_bwd - - -class _attention(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, causal): - BATCH, N_CTX, H, D_HEAD = q.shape - block_M = 64 - block_N = 64 if D_HEAD <= 128 else 32 - mod = flashattn_fwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) - o, lse = mod(q, k, v) - ctx.save_for_backward(q, k, v, o, lse) - ctx.causal = causal - return o - - @staticmethod - def backward(ctx, do): - q, k, v, o, lse = ctx.saved_tensors - BATCH, N_CTX, H, D_HEAD = q.shape - - def maybe_contiguous(x): - if x.stride(-1) != 1: - return x.contiguous() - return x - - do, q, k, v, o = [maybe_contiguous(x) for x in (do, q, k, v, o)] - block_M = 128 - block_N = 128 if D_HEAD <= 64 else 32 - mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) - delta = mod_prep(o, do) - mod = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, ctx.causal, block_M, block_N) - shape = [BATCH, N_CTX, H, D_HEAD] - dq = torch.zeros(shape, dtype=torch.float32, device=q.device) - dk = torch.empty(shape, dtype=torch.float16, device=q.device) - dv = torch.empty(shape, dtype=torch.float16, device=q.device) - mod(q, k, v, do, lse, delta, dq, dk, dv) - dq = dq.to(torch.float16) - return dq, dk, dv, None - - -attention = _attention.apply - - -def ref_program(Q, K, V, is_causal): - dim = Q.size(-1) - scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) - scores = scores / torch.sqrt(torch.tensor(dim, dtype=scores.dtype)) - if is_causal: - seq_len = Q.size(1) - mask = torch.tril(torch.ones(seq_len, seq_len, device=scores.device)) - mask = mask.unsqueeze(0).unsqueeze(0) - scores = scores.masked_fill(mask == 0, float("-inf")) - attention_weights = F.softmax(scores, dim=-1) - output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V) - return output - - -def main( - BATCH: int = 8, - H: int = 32, - N_CTX: int = 1024, - D_HEAD: int = 64, - causal: bool = False, -): - flops_per_matmul = 2.0 * BATCH * H * N_CTX * N_CTX * D_HEAD - total_flops = 5 * flops_per_matmul - if causal: - total_flops *= 0.5 - Q = torch.empty(BATCH, N_CTX, H, D_HEAD, dtype=torch.half, device="cuda").normal_().requires_grad_() - K = torch.empty_like(Q).normal_().requires_grad_() - V = torch.empty_like(Q).normal_().requires_grad_() - dO = torch.randn_like(Q) - O = attention(Q, K, V, causal) - O.backward(dO, retain_graph=True) - dQ, Q.grad = Q.grad.clone(), None - dK, K.grad = K.grad.clone(), None - dV, V.grad = V.grad.clone(), None - - O_ref = ref_program(Q, K, V, causal) - O_ref.backward(dO, retain_graph=True) - dQ_ref, Q.grad = Q.grad.clone(), None - dK_ref, K.grad = K.grad.clone(), None - dV_ref, V.grad = V.grad.clone(), None - - assert torch.allclose(O, O_ref, rtol=1e-2, atol=1e-2) - assert torch.allclose(dV, dV_ref, rtol=1e-2, atol=1e-2) - assert torch.allclose(dK, dK_ref, 
rtol=1e-2, atol=1e-2) - assert torch.allclose(dQ, dQ_ref, rtol=1e-2, atol=1e-2) - print("All checks passed.✅") - - def run(): - O_ref.backward(dO, retain_graph=True) - - def run1(): - O.backward(dO, retain_graph=True) - - latency = do_bench(run, warmup=500) - print("torch: {:.2f} ms".format(latency)) - print("torch: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - latency = do_bench(run1, warmup=500) - print("tilelang: {:.2f} ms".format(latency)) - print("tilelang: {:.2f} TFlops".format(total_flops / latency * 1e-9)) - - -def run_regression_perf(): - BATCH = 1 - H = 32 - N_CTX = 256 - D_HEAD = 64 - causal = False - device = "cuda" - torch.manual_seed(0) - block_M = 128 - block_N = 128 if D_HEAD <= 64 else 32 - Q = torch.randn(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.half) - K = torch.randn_like(Q) - V = torch.randn_like(Q) - O = torch.randn_like(Q) - dO = torch.randn_like(Q) - lse = torch.zeros(BATCH, H, N_CTX, device=device, dtype=torch.float32) - with torch.no_grad(): - mod_prep = flashattn_bwd_preprocess(BATCH, H, N_CTX, D_HEAD) - kernel = flashattn_bwd(BATCH, H, N_CTX, D_HEAD, causal, block_M, block_N) - dQ = torch.zeros(BATCH, N_CTX, H, D_HEAD, device=device, dtype=torch.float32) - dK = torch.zeros_like(Q, dtype=torch.float16) - dV = torch.zeros_like(Q, dtype=torch.float16) - Delta = mod_prep(O, dO) - - from tilelang.profiler import do_bench - - def run_kernel_only(): - kernel(Q, K, V, dO, lse, Delta, dQ, dK, dV) - - return do_bench(run_kernel_only, backend="cupti") - - -if __name__ == "__main__": - parser = argparse.ArgumentParser() - parser.add_argument("--batch", type=int, default=8, help="Batch size") - parser.add_argument("--h", type=int, default=32, help="Number of heads") - parser.add_argument("--n_ctx", type=int, default=1024, help="Context size") - parser.add_argument("--d_head", type=int, default=64, help="Head dimension") - parser.add_argument("--causal", type=bool, default=False, help="Causal flag") - args = parser.parse_args() - main(args.batch, args.h, args.n_ctx, args.d_head, args.causal) diff --git a/examples/flash_attention/regression_example_flash_attention.py b/examples/flash_attention/regression_example_flash_attention.py index 8710bbb6e2..86bea2f86e 100644 --- a/examples/flash_attention/regression_example_flash_attention.py +++ b/examples/flash_attention/regression_example_flash_attention.py @@ -1,17 +1,12 @@ import tilelang.testing import example_gqa_fwd_bshd -import example_gqa_fwd_bshd_wgmma_pipelined import example_mha_fwd_bhsd -import example_mha_fwd_bhsd_wgmma_pipelined import example_mha_fwd_bshd -import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_varlen import example_gqa_bwd_tma_reduce_varlen import example_gqa_bwd -import example_gqa_bwd_wgmma_pipelined import example_mha_bwd_bshd import example_mha_bwd_bhsd -import example_mha_bwd_bshd_wgmma_pipelined def regression_example_gqa_bwd_tma_reduce_varlen(): @@ -22,10 +17,6 @@ def regression_example_gqa_bwd(): tilelang.testing.process_func(example_gqa_bwd.run_regression_perf) -def regression_example_gqa_bwd_wgmma_pipelined(): - tilelang.testing.process_func(example_gqa_bwd_wgmma_pipelined.run_regression_perf) - - def regression_example_mha_bwd_bshd(): tilelang.testing.process_func(example_mha_bwd_bshd.run_regression_perf) @@ -34,34 +25,16 @@ def regression_example_mha_bwd_bhsd(): tilelang.testing.process_func(example_mha_bwd_bhsd.run_regression_perf) -def regression_example_mha_bwd_bshd_wgmma_pipelined(): - 
tilelang.testing.process_func(example_mha_bwd_bshd_wgmma_pipelined.run_regression_perf) - - -def regression_example_gqa_fwd_bshd_wgmma_pipelined(): - tilelang.testing.process_func( - example_gqa_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 - ) - - def regression_example_gqa_fwd_bshd(): tilelang.testing.process_func( example_gqa_fwd_bshd.run_regression_perf, batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16 ) -def regression_example_mha_fwd_bhsd_wgmma_pipelined(): - tilelang.testing.process_func(example_mha_fwd_bhsd_wgmma_pipelined.run_regression_perf) - - def regression_example_mha_fwd_bhsd(): tilelang.testing.process_func(example_mha_fwd_bhsd.run_regression_perf) -def regression_example_mha_fwd_bshd_wgmma_pipelined(): - tilelang.testing.process_func(example_mha_fwd_bshd_wgmma_pipelined.run_regression_perf, batch=1, heads=32, seq_len=256) - - def regression_example_mha_fwd_bshd(): tilelang.testing.process_func(example_mha_fwd_bshd.run_regression_perf, batch=1, seq_len=256) diff --git a/examples/flash_attention/test_example_flash_attention.py b/examples/flash_attention/test_example_flash_attention.py index a74bf071b9..dc8b9d9266 100644 --- a/examples/flash_attention/test_example_flash_attention.py +++ b/examples/flash_attention/test_example_flash_attention.py @@ -1,22 +1,18 @@ import tilelang.testing import example_gqa_bwd -import example_gqa_bwd_wgmma_pipelined import example_mha_bwd_bshd import example_mha_bwd_bhsd -import example_mha_fwd_bhsd_wgmma_pipelined import example_gqa_fwd_bshd import example_mha_fwd_bshd -import example_gqa_fwd_bshd_wgmma_pipelined -import example_mha_fwd_bshd_wgmma_pipelined import example_mha_fwd_varlen -import example_mha_bwd_bshd_wgmma_pipelined import example_mha_fwd_bhsd import example_gqa_bwd_tma_reduce_varlen import example_gqa_fwd_varlen @tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_gqa_bwd_tma_reduce_varlen(): example_gqa_bwd_tma_reduce_varlen.main() @@ -26,12 +22,6 @@ def test_example_gqa_bwd(): example_gqa_bwd.main() -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_gqa_bwd_wgmma_pipelined(): - example_gqa_bwd_wgmma_pipelined.main() - - @tilelang.testing.requires_cuda def test_example_mha_bwd(): example_mha_bwd_bshd.main( @@ -54,40 +44,16 @@ def test_example_mha_bwd_bhsd(): ) -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_mha_bwd_wgmma_pipelined(): - example_mha_bwd_bshd_wgmma_pipelined.main(BATCH=1, H=32, N_CTX=256, D_HEAD=64, causal=False) - - -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_gqa_fwd_bshd_wgmma_pipelined(): - example_gqa_fwd_bshd_wgmma_pipelined.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) - - @tilelang.testing.requires_cuda def test_example_gqa_fwd_bshd(): example_gqa_fwd_bshd.main(batch=1, heads=16, seq_len=1024, dim=128, is_causal=False, groups=16, tune=False) -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def test_example_mha_fwd_bhsd_wgmma_pipelined(): - example_mha_fwd_bhsd_wgmma_pipelined.main() - - @tilelang.testing.requires_cuda def test_example_mha_fwd_bhsd(): example_mha_fwd_bhsd.main() -@tilelang.testing.requires_cuda -@tilelang.testing.requires_cuda_compute_version_ge(9, 0) -def 
test_example_mha_fwd_bshd_wgmma_pipelined(): - example_mha_fwd_bshd_wgmma_pipelined.main(batch=1, heads=32, seq_len=256) - - @tilelang.testing.requires_cuda def test_example_mha_fwd_bshd(): example_mha_fwd_bshd.main(batch=1, seq_len=256) diff --git a/examples/flash_attention_sm100/gqa_bwd_bshd.py b/examples/flash_attention_sm100/gqa_bwd_bshd.py new file mode 100644 index 0000000000..95e1c35d60 --- /dev/null +++ b/examples/flash_attention_sm100/gqa_bwd_bshd.py @@ -0,0 +1,345 @@ +"""Blackwell (SM100) GQA backward, BSHD layout. + +Q/dQ: [batch, seq_len, heads, dim]; K,V,dK,dV: [batch, seq_len, head_kv, dim]; head_kv = heads // groups. +dK/dV use atomic_add (multiple Q heads -> same KV head). +Pipeline (default): --variant ss. ts (optional): --variant ts. +""" + +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + + +@tilelang.jit(out_idx=[3, 4], pass_configs=PASS_CFG) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=1): + """Forward for LSE; K/V indexed by by // groups.""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + scale = (1.0 / dim) ** 0.5 * 1.44269504 + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.tcgen05_gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + 
scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.tcgen05_gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return main + + +@tilelang.jit(out_idx=[2], pass_configs=PASS_CFG) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.bfloat16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def main( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return main + + +def make_dq_layout(dQ): + return T.Layout( + dQ.shape, + lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2], + ) + + +@tilelang.jit(out_idx=[1], pass_configs=PASS_CFG) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.bfloat16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def main( + dQ: T.Tensor(shape, accum_dtype), + dQ_out: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return main + + +@tilelang.jit(pass_configs=PASS_CFG) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=1, threads=128, num_stages=2): + """GQA backward: K/V/dK/dV use bx // groups; dK/dV use atomic_add.""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + dO: T.Tensor(q_shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(q_shape, accum_dtype), + dK: T.Tensor(kv_shape, accum_dtype), + dV: T.Tensor(kv_shape, accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as 
(bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + qkT_shared = T.alloc_shared([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], accum_dtype) + dk_shared = T.alloc_shared([block_M, dim], accum_dtype) + + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx // groups, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx // groups, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.tcgen05_gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.tcgen05_gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.copy(qkT_cast, qkT_shared) + T.tcgen05_gemm(qkT_shared, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.copy(dsT_cast, dsT_shared) + T.tcgen05_gemm(dsT_shared, q, dk, policy=T.GemmWarpPolicy.FullRow) + T.clear(dq) + T.tcgen05_gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.atomic_add(dV[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dv_shared) + T.atomic_add(dK[bz, by * block_M : (by + 1) * block_M, bx // groups, :], dk_shared) + + return main + + +def flashattn_bwd_pipeline(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=1): + return flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + block_M, + block_N, + groups=groups, + threads=128, + num_stages=2, + ) + + +def flashattn_bwd_warp(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=1): + return flashattn_bwd( + batch, + heads, + seq_len, + dim, + is_causal, + block_M, + block_N, + groups=groups, + threads=256, + num_stages=2, + ) + + +def ref_program(Q, K, V, is_causal, groups=1): + """CPU reference for forward only; backward ref omitted for brevity.""" + dim = Q.size(-1) + K_f = K.cpu().float().repeat_interleave(groups, dim=2) + V_f = V.cpu().float().repeat_interleave(groups, dim=2) + Q_f = Q.cpu().float() + scores = 
torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f) + scores = scores / (dim**0.5) + if is_causal: + seq_len = Q_f.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + P = F.softmax(scores, dim=-1) + out_ref = torch.einsum("bhqk,bkhd->bqhd", P, V_f) + return out_ref.to(Q.dtype) + + +def main( + batch: int = 2, + heads: int = 4, + seq_len: int = 256, + dim: int = 128, + is_causal: bool = False, + groups: int = 1, + variant: str = "ss", +): + """Run GQA backward kernels (fwd + preprocess + bwd + postprocess).""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + block_M = 64 + block_N = 64 if dim <= 64 else 32 + bwd_fn = flashattn_bwd_warp if variant == "ts" else flashattn_bwd_pipeline + + kernel_fwd = flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=groups) + kernel_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim) + kernel_post = flashattn_bwd_postprocess(batch, heads, seq_len, dim) + kernel_bwd = bwd_fn(batch, heads, seq_len, dim, is_causal, block_M, block_N, groups=groups) + + Q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=torch.bfloat16) + K = torch.randn(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.bfloat16) + V = torch.randn(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.bfloat16) + dO = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=torch.bfloat16) + + O, lse = kernel_fwd(Q, K, V) + Delta = kernel_prep(O, dO) + dQ = torch.zeros(batch, seq_len, heads, dim, device="cuda", dtype=torch.float32) + dK = torch.zeros(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.float32) + dV = torch.zeros(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.float32) + kernel_bwd(Q, K, V, dO, lse, Delta, dQ, dK, dV) + _ = kernel_post(dQ) # dQ_out in output layout; not compared to ref (no backward ref) + print("Blackwell GQA bwd ({}): run OK (backward gradients not verified against ref).".format(variant)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--heads", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--dim", type=int, default=128) + parser.add_argument("--is_causal", action="store_true") + parser.add_argument("--groups", type=int, default=1, help="head_kv = heads // groups") + parser.add_argument( + "--variant", + choices=["ss", "ts"], + default="ss", + help="ss: pipeline (default); ts: 256 threads", + ) + args = parser.parse_args() + main( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.groups, + args.variant, + ) diff --git a/examples/flash_attention_sm100/gqa_fwd_bshd.py b/examples/flash_attention_sm100/gqa_fwd_bshd.py new file mode 100644 index 0000000000..775cb45dd1 --- /dev/null +++ b/examples/flash_attention_sm100/gqa_fwd_bshd.py @@ -0,0 +1,504 @@ +"""Blackwell (SM100) GQA forward, BSHD layout. + +Q: [batch, seq_len, heads, dim], K/V: [batch, seq_len, head_kv, dim], head_kv = heads // groups. +variant='ss': mma_ss for both GEMMs (128 threads, P via shared memory). +variant='ts': mma_ts for GEMM 2 (256 threads, P via tensor memory). +variant='wasp': warp-specialized pipeline (softmax/DMA/BMM warps); GEMM 2 mma_ts. 
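+
+A minimal usage sketch (hypothetical sizes, not from the benchmarks below; assumes an
+SM100 GPU and bfloat16 inputs). With heads=4 and groups=2, head_kv = heads // groups = 2:
+
+    kernel = flashattn(batch=2, heads=4, seq_len=256, dim=128, is_causal=False,
+                       groups=2, variant="ss")
+    Q = torch.randn(2, 256, 4, 128, device="cuda", dtype=torch.bfloat16)
+    K = torch.randn(2, 256, 2, 128, device="cuda", dtype=torch.bfloat16)  # [b, s, head_kv, d]
+    V = torch.randn_like(K)
+    out = kernel(Q, K, V)  # [batch, seq_len, heads, dim], returned via out_idx=[3]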
+""" + +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse + + +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + + +@tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) +def flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=128, + block_N=128, + variant="ss", +): + """GQA forward. variant='ss': mma_ss (128t, P via shared); 'ts': mma_ts (256t, P via TMEM).""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + use_ts = variant == "ts" + threads = 256 if use_ts else 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + + S_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + D_tmem = T.alloc_tmem([block_M, dim], accum_dtype) + mbar_s = T.alloc_barrier(1) + mbar_d = T.alloc_barrier(1) + + if use_ts: + P_tmem = T.alloc_tmem([block_M, block_N], dtype) + else: + P_shared = T.alloc_shared([block_M, block_N], dtype) + + S_reg = T.alloc_fragment([block_M, block_N], accum_dtype) + P_cast = T.alloc_fragment([block_M, block_N], dtype) + O_reg = T.alloc_fragment([block_M, dim], accum_dtype) + D_reg = T.alloc_fragment([block_M, dim], accum_dtype) + + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(O_reg, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min( + T.ceildiv(seq_len, block_N), + T.ceildiv((bx + 1) * block_M, block_N), + ) + if is_causal + else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared) + + T.tcgen05_gemm( + Q_shared, + K_shared, + S_tmem, + transpose_B=True, + mbar=mbar_s, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar_s, k % 2) + + T.copy(S_tmem, S_reg) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + S_reg[i, j], + -T.infinity(accum_dtype), + ) + else: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + k * block_N + j >= seq_len, + -T.infinity(accum_dtype), + S_reg[i, j], + ) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(S_reg, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, 
block_N): + S_reg[i, j] = T.exp2(S_reg[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(S_reg, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] *= scores_scale[i] + + T.copy(S_reg, P_cast) + if use_ts: + T.copy(P_cast, P_tmem) + P_operand = P_tmem + else: + T.copy(P_cast, P_shared) + P_operand = P_shared + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared) + + T.tcgen05_gemm( + P_operand, + V_shared, + D_tmem, + mbar=mbar_d, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar_d, k % 2) + + T.copy(D_tmem, D_reg) + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] += D_reg[i, j] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] /= logsum[i] + T.copy(O_reg, O_shared) + T.copy( + O_shared, + Output[bz, bx * block_M : (bx + 1) * block_M, by, :], + ) + + return main + + +flashattn_ss = flashattn +flashattn_ts = flashattn + + +@tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) +def flashattn_wasp( + batch, + heads, + seq_len, + dim, + is_causal, + groups=1, + block_M=128, + block_N=128, + threads=256, + num_stages=2, +): + """GQA warp-specialized pipeline: softmax(0-127)/DMA(128-159)/BMM(160-191); GEMM2 mma_ts.""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + scale = (1.0 / dim) ** 0.5 * 1.44269504 + q_shape = [batch, seq_len, heads, dim] + kv_shape = [batch, seq_len, head_kv, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), + K: T.Tensor(kv_shape, dtype), + V: T.Tensor(kv_shape, dtype), + Output: T.Tensor(q_shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared_0 = T.alloc_shared([block_N, dim], dtype) + K_shared_1 = T.alloc_shared([block_N, dim], dtype) + V_shared_0 = T.alloc_shared([block_N, dim], dtype) + V_shared_1 = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + + S_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + P_tmem = T.alloc_tmem([block_M, block_N], dtype) + O_tmem = T.alloc_tmem([block_M, dim], accum_dtype) + + mbar_dma1_empty = T.alloc_barrier([32] * num_stages) + mbar_dma1_full = T.alloc_barrier([32] * num_stages) + mbar_bmm1_empty = T.alloc_barrier([128] * num_stages) + mbar_bmm1_full = T.alloc_barrier([1] * num_stages) + mbar_dma2_empty = T.alloc_barrier([32] * num_stages) + mbar_dma2_full = T.alloc_barrier([32] * num_stages) + mbar_bmm2_full = T.alloc_barrier([1] * num_stages) + mbar_softmax_empty = T.alloc_barrier([32] * num_stages) + mbar_softmax_full = T.alloc_barrier([128] * num_stages) + mbar_correction_full = T.alloc_barrier([32] * num_stages) + + tid = T.get_thread_binding() + + S_reg = T.alloc_fragment([block_M, block_N], accum_dtype) + P_cast = T.alloc_fragment([block_M, block_N], dtype) + O_reg = T.alloc_fragment([block_M, dim], accum_dtype) + + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_rescale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + if tid < 128: + T.fill(O_reg, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(O_reg, O_tmem) + + loop_range = ( + T.min( + T.ceildiv(seq_len, 
block_N), + T.ceildiv((bx + 1) * block_M, block_N), + ) + if is_causal + else T.ceildiv(seq_len, block_N) + ) + + for k in T.serial(loop_range): + parity = (k // num_stages) & 1 + parity_inv = parity ^ 1 + stage_id = k % num_stages + is_clear_accum = k == 0 + + if tid >= 128 and tid < 160: + T.mbarrier_wait_parity(mbar_dma1_empty[stage_id], parity_inv) + + if k == 0: + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + + if stage_id == 0: + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared_0) + else: + T.copy(K[bz, k * block_N : (k + 1) * block_N, by // groups, :], K_shared_1) + T.mbarrier_arrive(mbar_dma1_full[stage_id]) + + T.mbarrier_wait_parity(mbar_dma2_empty[stage_id], parity_inv) + + if stage_id == 0: + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared_0) + else: + T.copy(V[bz, k * block_N : (k + 1) * block_N, by // groups, :], V_shared_1) + + T.mbarrier_arrive(mbar_dma2_full[stage_id]) + + elif tid >= 160 and tid < 192: + T.mbarrier_wait_parity(mbar_dma1_full[stage_id], parity) + T.mbarrier_wait_parity(mbar_bmm1_empty[stage_id], parity_inv) + + if stage_id == 0: + T.tcgen05_gemm( + Q_shared, + K_shared_0, + S_tmem, + transpose_B=True, + mbar=mbar_bmm1_full[stage_id], + clear_accum=True, + ) + else: + T.tcgen05_gemm( + Q_shared, + K_shared_1, + S_tmem, + transpose_B=True, + mbar=mbar_bmm1_full[stage_id], + clear_accum=True, + ) + T.mbarrier_arrive(mbar_dma1_empty[stage_id]) + + T.mbarrier_wait_parity(mbar_softmax_full[stage_id], parity) + T.mbarrier_wait_parity(mbar_dma2_full[stage_id], parity) + + if stage_id == 0: + T.tcgen05_gemm( + P_tmem, + V_shared_0, + O_tmem, + mbar=mbar_bmm2_full[stage_id], + clear_accum=is_clear_accum, + ) + else: + T.tcgen05_gemm( + P_tmem, + V_shared_1, + O_tmem, + mbar=mbar_bmm2_full[stage_id], + clear_accum=is_clear_accum, + ) + + T.mbarrier_arrive(mbar_softmax_empty[stage_id]) + T.mbarrier_arrive(mbar_dma2_empty[stage_id]) + + if k == loop_range - 1: + T.mbarrier_arrive(mbar_correction_full[0]) + + elif tid < 128: + T.mbarrier_wait_parity(mbar_softmax_empty[stage_id], parity_inv) + T.mbarrier_wait_parity(mbar_bmm1_full[stage_id], parity) + if k > 0: + prev_stage = (k - 1) % num_stages + prev_parity = ((k - 1) // num_stages) & 1 + T.mbarrier_wait_parity(mbar_bmm2_full[prev_stage], prev_parity) + + T.copy(O_tmem, O_reg) + T.copy(S_tmem, S_reg) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + S_reg[i, j], + -T.infinity(accum_dtype), + ) + else: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + k * block_N + j >= seq_len, + -T.infinity(accum_dtype), + S_reg[i, j], + ) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(S_reg, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_rescale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.exp2(S_reg[i, j] * scale - scores_max[i] * scale) + + T.reduce_sum(S_reg, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_rescale[i] + scores_sum[i] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] *= scores_rescale[i] + + T.copy(S_reg, P_cast) + T.copy(P_cast, P_tmem) + T.copy(O_reg, O_tmem) + + T.mbarrier_arrive(mbar_softmax_full[stage_id]) + 
T.mbarrier_arrive(mbar_bmm1_empty[stage_id]) + + if k == loop_range - 1: + T.mbarrier_wait_parity(mbar_correction_full[0], 0) + T.mbarrier_wait_parity(mbar_bmm2_full[stage_id], parity) + T.copy(O_tmem, O_reg) + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] /= logsum[i] + T.copy(O_reg, O_shared) + T.copy( + O_shared, + Output[bz, bx * block_M : (bx + 1) * block_M, by, :], + ) + + return main + + +flashattn_warp = flashattn_wasp + + +def ref_program(Q, K, V, is_causal, groups=1): + """CPU reference: K/V [b,s,head_kv,d], expand to heads for einsum.""" + assert Q.size(2) == K.size(2) * groups + dim = Q.size(-1) + K_f = K.cpu().float().repeat_interleave(groups, dim=2) + V_f = V.cpu().float().repeat_interleave(groups, dim=2) + Q_f = Q.cpu().float() + scores = torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f) + scores = scores / (dim**0.5) + if is_causal: + seq_len = Q_f.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + P = F.softmax(scores, dim=-1) + O = torch.einsum("bhqk,bkhd->bqhd", P, V_f) + return O.to(Q.dtype) + + +def main( + batch: int = 2, + heads: int = 4, + seq_len: int = 256, + dim: int = 128, + is_causal: bool = False, + groups: int = 1, + variant: str = "ss", +): + """Run GQA forward kernel (ss or ts variant) and benchmark.""" + if groups <= 0 or heads % groups != 0: + raise ValueError("groups must be a positive divisor of heads") + head_kv = heads // groups + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + print(f"=== Blackwell GQA Forward ({variant.upper()}) ===") + print(f"batch={batch}, heads={heads}, head_kv={head_kv}, groups={groups}, seq_len={seq_len}, dim={dim}, causal={is_causal}") + + if variant in ("ss", "ts"): + kernel = flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=128, + block_N=128, + variant=variant, + ) + else: + kernel = flashattn_wasp( + batch, + heads, + seq_len, + dim, + is_causal, + groups=groups, + block_M=128, + block_N=128, + threads=256, + num_stages=2, + ) + + Q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=torch.bfloat16) + K = torch.randn(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.bfloat16) + V = torch.randn(batch, seq_len, head_kv, dim, device="cuda", dtype=torch.bfloat16) + + out = kernel(Q, K, V) + ref = ref_program(Q, K, V, is_causal, groups).to(out.device) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + print("Correctness check passed.") + + latency = do_bench(lambda: kernel(Q, K, V), warmup=100) + print(f"Blackwell GQA fwd ({variant}): {latency:.2f} ms") + print(f"Blackwell GQA fwd ({variant}): {total_flops / latency * 1e-9:.2f} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--heads", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--dim", type=int, default=128) + parser.add_argument("--is_causal", action="store_true") + parser.add_argument("--groups", type=int, default=1, help="GQA: head_kv = heads // groups") + parser.add_argument( + "--variant", + choices=["ss", "ts", "wasp"], + default="ss", + help="ss: pipeline (default); ts: 256 threads; wasp: warp-specialized", + ) + args = parser.parse_args() + main( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.groups, + 
args.variant, + ) diff --git a/examples/flash_attention_sm100/mha_bwd_bshd.py b/examples/flash_attention_sm100/mha_bwd_bshd.py new file mode 100644 index 0000000000..45406a2eda --- /dev/null +++ b/examples/flash_attention_sm100/mha_bwd_bshd.py @@ -0,0 +1,309 @@ +"""Blackwell (SM100) MHA backward, BSHD layout. + +Pipeline (default): --variant ss or default. +ts (optional): --variant ts (256 threads, 2 stages). +""" + +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +import argparse + + +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + + +@tilelang.jit(out_idx=[3, 4], pass_configs=PASS_CFG) +def flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N): + """Forward to get O and LSE (for backward).""" + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=128) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + acc_s = T.alloc_fragment([block_M, block_N], accum_dtype) + acc_s_cast = T.alloc_fragment([block_M, block_N], dtype) + acc_o = T.alloc_fragment([block_M, dim], accum_dtype) + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(acc_o, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + loop_range = ( + T.min(T.ceildiv(seq_len, block_N), T.ceildiv((bx + 1) * block_M, block_N)) if is_causal else T.ceildiv(seq_len, block_N) + ) + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(bx * block_M + i >= k * block_N + j, 0, -T.infinity(acc_s.dtype)) + else: + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.if_then_else(k * block_N + j >= seq_len, -T.infinity(acc_s.dtype), 0) + T.tcgen05_gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + T.copy(scores_max, scores_max_prev) + T.reduce_max(acc_s, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] *= scores_scale[i] + for i, j in T.Parallel(block_M, block_N): + acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) + T.copy(acc_s, acc_s_cast) + T.tcgen05_gemm(acc_s_cast, V_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + T.reduce_sum(acc_s, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + for i, j in T.Parallel(block_M, dim): + acc_o[i, j] /= logsum[i] + T.copy(acc_o, 
Output[bz, bx * block_M : (bx + 1) * block_M, by, :]) + for i in T.Parallel(block_M): + logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale + T.copy(logsum, lse[bz, by, bx * block_M : (bx + 1) * block_M]) + + return main + + +@tilelang.jit(out_idx=[2], pass_configs=PASS_CFG) +def flashattn_bwd_preprocess(batch, heads, seq_len, dim): + dtype = T.bfloat16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 32 + + @T.prim_func + def main( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, blk), batch) as (bx, by, bz): + o = T.alloc_fragment([blk, blk], dtype) + do = T.alloc_fragment([blk, blk], dtype) + acc = T.alloc_fragment([blk, blk], accum_dtype) + delta = T.alloc_fragment([blk], accum_dtype) + T.clear(acc) + for k in range(T.ceildiv(dim, blk)): + T.copy(O[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], o) + T.copy(dO[bz, by * blk : (by + 1) * blk, bx, k * blk : (k + 1) * blk], do) + for i, j in T.Parallel(blk, blk): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, bx, by * blk : (by + 1) * blk]) + + return main + + +def make_dq_layout(dQ): + return T.Layout( + dQ.shape, + lambda b, l, h, d: [b, l // 8, h, d // 8, (d % 2), 4 * (l % 8) + (d % 8) // 2], + ) + + +@tilelang.jit(out_idx=[1], pass_configs=PASS_CFG) +def flashattn_bwd_postprocess(batch, heads, seq_len, dim): + dtype = T.bfloat16 + accum_dtype = T.float32 + shape = [batch, seq_len, heads, dim] + blk = 64 + + @T.prim_func + def main( + dQ: T.Tensor(shape, accum_dtype), + dQ_out: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, blk), heads, batch, threads=128) as (bx, by, bz): + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy( + dQ[bz, bx * blk : (bx + 1) * blk, by, :], + dQ_out[bz, bx * blk : (bx + 1) * blk, by, :], + ) + + return main + + +@tilelang.jit(pass_configs=PASS_CFG) +def flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2): + """Blackwell MHA backward. 
Pipeline default (128, 2); ts = (256, 2).""" + sm_scale = (1.0 / dim) ** 0.5 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + lse: T.Tensor([batch, heads, seq_len], accum_dtype), + Delta: T.Tensor([batch, heads, seq_len], accum_dtype), + dQ: T.Tensor(shape, accum_dtype), + dK: T.Tensor(shape, dtype), + dV: T.Tensor(shape, dtype), + ): + with T.Kernel(heads, T.ceildiv(seq_len, block_M), batch, threads=threads) as (bx, by, bz): + K_shared = T.alloc_shared([block_M, dim], dtype) + dsT_shared = T.alloc_shared([block_M, block_N], dtype) + q = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_M, dim], dtype) + qkT = T.alloc_fragment([block_M, block_N], accum_dtype) + dsT = T.alloc_fragment([block_M, block_N], accum_dtype) + qkT_cast = T.alloc_fragment([block_M, block_N], dtype) + dsT_cast = T.alloc_fragment([block_M, block_N], dtype) + lse_shared = T.alloc_shared([block_N], accum_dtype) + delta = T.alloc_shared([block_N], accum_dtype) + do = T.alloc_shared([block_N, dim], dtype) + dv = T.alloc_fragment([block_M, dim], accum_dtype) + dk = T.alloc_fragment([block_M, dim], accum_dtype) + dq = T.alloc_fragment([block_N, dim], accum_dtype) + dv_shared = T.alloc_shared([block_M, dim], dtype) + dk_shared = T.alloc_shared([block_M, dim], dtype) + + T.annotate_layout({dQ: make_dq_layout(dQ)}) + T.copy(K[bz, by * block_M : (by + 1) * block_M, bx, :], K_shared) + T.copy(V[bz, by * block_M : (by + 1) * block_M, bx, :], V_shared) + T.clear(dv) + T.clear(dk) + loop_st = T.floordiv(by * block_M, block_N) if is_causal else 0 + loop_ed = T.ceildiv(seq_len, block_N) + for k in T.Pipelined(loop_st, loop_ed, num_stages=num_stages): + T.copy(Q[bz, k * block_N : (k + 1) * block_N, bx, :], q) + T.clear(qkT) + T.tcgen05_gemm(K_shared, q, qkT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(lse[bz, bx, k * block_N : (k + 1) * block_N], lse_shared) + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.exp2(qkT[i, j] * scale - lse_shared[j]) + if is_causal: + for i, j in T.Parallel(block_M, block_N): + qkT[i, j] = T.if_then_else(by * block_M + i <= k * block_N + j, qkT[i, j], 0) + T.copy(dO[bz, k * block_N : (k + 1) * block_N, bx, :], do) + T.clear(dsT) + T.tcgen05_gemm(V_shared, do, dsT, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) + T.copy(qkT, qkT_cast) + T.tcgen05_gemm(qkT_cast, do, dv, policy=T.GemmWarpPolicy.FullRow) + + T.copy(Delta[bz, bx, k * block_N : (k + 1) * block_N], delta) + for i, j in T.Parallel(block_M, block_N): + dsT_cast[i, j] = qkT[i, j] * (dsT[i, j] - delta[j]) * sm_scale + T.tcgen05_gemm(dsT_cast, q, dk, policy=T.GemmWarpPolicy.FullRow) + + T.copy(dsT_cast, dsT_shared) + T.clear(dq) + T.tcgen05_gemm(dsT_shared, K_shared, dq, transpose_A=True) + for i, j in T.Parallel(block_N, dim): + T.atomic_add(dQ[bz, k * block_N + i, bx, j], dq[i, j]) + T.copy(dv, dv_shared) + T.copy(dk, dk_shared) + T.copy(dv_shared, dV[bz, by * block_M : (by + 1) * block_M, bx, :]) + T.copy(dk_shared, dK[bz, by * block_M : (by + 1) * block_M, bx, :]) + + return main + + +def flashattn_bwd_pipeline(batch, heads, seq_len, dim, is_causal, block_M, block_N): + """Pipeline (default): 128 threads, 2 stages.""" + return flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=128, num_stages=2) + + +def flashattn_bwd_warp(batch, heads, 
seq_len, dim, is_causal, block_M, block_N): + """ts: 256 threads, 2 stages. Use --variant ts to enable.""" + return flashattn_bwd(batch, heads, seq_len, dim, is_causal, block_M, block_N, threads=256, num_stages=2) + + +def ref_program(Q, K, V, is_causal): + """CPU reference forward (for validation); backward ref not implemented.""" + dim = Q.size(-1) + Q_f = Q.cpu().float() + K_f = K.cpu().float() + V_f = V.cpu().float() + scores = torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f) + scores = scores / (dim**0.5) + if is_causal: + seq_len = Q_f.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + P = F.softmax(scores, dim=-1) + out_ref = torch.einsum("bhqk,bkhd->bqhd", P, V_f) + return out_ref.to(Q.dtype) + + +def main( + batch: int = 2, + heads: int = 4, + seq_len: int = 256, + dim: int = 128, + is_causal: bool = False, + variant: str = "ss", +): + """Run MHA backward kernels (fwd + preprocess + bwd + postprocess).""" + block_M = 64 + block_N = 64 if dim <= 64 else 32 + use_ts = variant == "ts" + bwd_fn = flashattn_bwd_warp if use_ts else flashattn_bwd_pipeline + + kernel_fwd = flashattn_fwd(batch, heads, seq_len, dim, is_causal, block_M, block_N) + kernel_prep = flashattn_bwd_preprocess(batch, heads, seq_len, dim) + kernel_post = flashattn_bwd_postprocess(batch, heads, seq_len, dim) + kernel_bwd = bwd_fn(batch, heads, seq_len, dim, is_causal, block_M, block_N) + + Q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=torch.bfloat16) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + dO = torch.randn_like(Q) + O, lse = kernel_fwd(Q, K, V) + Delta = kernel_prep(O, dO) + dQ = torch.zeros(batch, seq_len, heads, dim, device="cuda", dtype=torch.float32) + dK = torch.empty_like(K, device="cuda") + dV = torch.empty_like(V, device="cuda") + kernel_bwd(Q, K, V, dO, lse, Delta, dQ, dK, dV) + _ = kernel_post(dQ) # dQ_out in output layout; not compared to ref (no backward ref) + print("Blackwell MHA bwd ({}): run OK (backward gradients not verified against ref).".format(variant)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=2) + parser.add_argument("--heads", type=int, default=4) + parser.add_argument("--seq_len", type=int, default=256) + parser.add_argument("--dim", type=int, default=128) + parser.add_argument("--is_causal", action="store_true") + parser.add_argument( + "--variant", + choices=["ss", "ts"], + default="ss", + help="ss: pipeline (default); ts: 256 threads", + ) + args = parser.parse_args() + main( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.variant, + ) diff --git a/examples/flash_attention_sm100/mha_fwd_bshd.py new file mode 100644 index 0000000000..db9f2472fd --- /dev/null +++ b/examples/flash_attention_sm100/mha_fwd_bshd.py @@ -0,0 +1,491 @@ +"""Blackwell (SM100) Flash Attention Forward using TCGEN05MMA with TMEM accumulators. + +Replaces the Hopper WGMMA-based Flash Attention for Blackwell GPUs. +Three variants: ss, ts, wasp. + - flashattn (variant='ss'): Both GEMMs use mma_ss (shared x shared -> TMEM), 128 threads. + - flashattn (variant='ts'): Single-path; GEMM 2 uses mma_ts (P_tmem x V_shared -> D_tmem), 256 threads. + - flashattn_wasp: Warp-specialized pipeline (softmax/DMA/BMM warps); GEMM 2 mma_ts. + If wasp fails (e.g. layout inference), fall back to ts.
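+ All variants keep the online-softmax statistics (running row max and log-sum) in registers; the score and output accumulators live in TMEM and are filled by tcgen05 MMAs synchronized through mbarriers. + Example: python mha_fwd_bshd.py --batch 1 --heads 16 --seq_len 16384 --dim 128 --variant wasp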
+""" + +import torch +import torch.nn.functional as F +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench +import argparse + + +PASS_CFG = {tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: False} + + +@tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) +def flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + variant="ss", +): + """Flash Attention forward. variant='ss': mma_ss (128t, P via shared); 'ts': mma_ts (256t, P via TMEM).""" + use_ts = variant == "ts" + threads = 256 if use_ts else 128 + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared = T.alloc_shared([block_N, dim], dtype) + V_shared = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + + S_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + D_tmem = T.alloc_tmem([block_M, dim], accum_dtype) + mbar_s = T.alloc_barrier(1) + mbar_d = T.alloc_barrier(1) + + if use_ts: + P_tmem = T.alloc_tmem([block_M, block_N], dtype) + else: + P_shared = T.alloc_shared([block_M, block_N], dtype) + + S_reg = T.alloc_fragment([block_M, block_N], accum_dtype) + P_cast = T.alloc_fragment([block_M, block_N], dtype) + O_reg = T.alloc_fragment([block_M, dim], accum_dtype) + D_reg = T.alloc_fragment([block_M, dim], accum_dtype) + + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_scale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + T.fill(O_reg, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + + loop_range = ( + T.min( + T.ceildiv(seq_len, block_N), + T.ceildiv((bx + 1) * block_M, block_N), + ) + if is_causal + else T.ceildiv(seq_len, block_N) + ) + + for k in T.Pipelined(loop_range, num_stages=1): + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared) + + # GEMM 1: S = Q @ K^T -> S_tmem (tcgen05mma_ss) + T.tcgen05_gemm( + Q_shared, + K_shared, + S_tmem, + transpose_B=True, + mbar=mbar_s, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar_s, k % 2) + + T.copy(S_tmem, S_reg) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + S_reg[i, j], + -T.infinity(accum_dtype), + ) + else: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + k * block_N + j >= seq_len, + -T.infinity(accum_dtype), + S_reg[i, j], + ) + + # Online softmax + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(S_reg, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.exp2(S_reg[i, j] * scale - scores_max[i] * scale) + T.reduce_sum(S_reg, scores_sum, dim=1) + for i in 
T.Parallel(block_M): + logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] *= scores_scale[i] + + T.copy(S_reg, P_cast) + if use_ts: + T.copy(P_cast, P_tmem) + P_operand = P_tmem + else: + T.copy(P_cast, P_shared) + P_operand = P_shared + + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared) + + # GEMM 2: D = P @ V -> D_tmem (ss: mma_ss; ts: mma_ts) + T.tcgen05_gemm( + P_operand, + V_shared, + D_tmem, + mbar=mbar_d, + clear_accum=True, + ) + T.mbarrier_wait_parity(mbar_d, k % 2) + + T.copy(D_tmem, D_reg) + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] += D_reg[i, j] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] /= logsum[i] + T.copy(O_reg, O_shared) + T.copy( + O_shared, + Output[bz, bx * block_M : (bx + 1) * block_M, by, :], + ) + + return main + + +flashattn_ss = flashattn +flashattn_ts = flashattn + + +@tilelang.jit(out_idx=[3], pass_configs=PASS_CFG) +def flashattn_wasp( + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + threads=256, + num_stages=2, +): + """Warp-specialized pipeline: softmax(0-127)/DMA(128-159)/BMM(160-191); GEMM2 mma_ts.""" + scale = (1.0 / dim) ** 0.5 * 1.44269504 + shape = [batch, seq_len, heads, dim] + dtype = T.bfloat16 + accum_dtype = T.float32 + + @T.prim_func + def main( + Q: T.Tensor(shape, dtype), + K: T.Tensor(shape, dtype), + V: T.Tensor(shape, dtype), + Output: T.Tensor(shape, dtype), + ): + with T.Kernel(T.ceildiv(seq_len, block_M), heads, batch, threads=threads) as (bx, by, bz): + Q_shared = T.alloc_shared([block_M, dim], dtype) + K_shared_0 = T.alloc_shared([block_N, dim], dtype) + K_shared_1 = T.alloc_shared([block_N, dim], dtype) + V_shared_0 = T.alloc_shared([block_N, dim], dtype) + V_shared_1 = T.alloc_shared([block_N, dim], dtype) + O_shared = T.alloc_shared([block_M, dim], dtype) + + S_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + P_tmem = T.alloc_tmem([block_M, block_N], dtype) + O_tmem = T.alloc_tmem([block_M, dim], accum_dtype) + + mbar_dma1_empty = T.alloc_barrier([32] * num_stages) + mbar_dma1_full = T.alloc_barrier([32] * num_stages) + mbar_bmm1_empty = T.alloc_barrier([128] * num_stages) + mbar_bmm1_full = T.alloc_barrier([1] * num_stages) + mbar_dma2_empty = T.alloc_barrier([32] * num_stages) + mbar_dma2_full = T.alloc_barrier([32] * num_stages) + mbar_bmm2_full = T.alloc_barrier([1] * num_stages) + mbar_softmax_empty = T.alloc_barrier([32] * num_stages) + mbar_softmax_full = T.alloc_barrier([128] * num_stages) + mbar_correction_full = T.alloc_barrier([32] * num_stages) + + tid = T.get_thread_binding() + + S_reg = T.alloc_fragment([block_M, block_N], accum_dtype) + P_cast = T.alloc_fragment([block_M, block_N], dtype) + O_reg = T.alloc_fragment([block_M, dim], accum_dtype) + + scores_max = T.alloc_fragment([block_M], accum_dtype) + scores_max_prev = T.alloc_fragment([block_M], accum_dtype) + scores_rescale = T.alloc_fragment([block_M], accum_dtype) + scores_sum = T.alloc_fragment([block_M], accum_dtype) + logsum = T.alloc_fragment([block_M], accum_dtype) + + if tid < 128: + T.fill(O_reg, 0) + T.fill(logsum, 0) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.copy(O_reg, O_tmem) + + loop_range = ( + T.min( + T.ceildiv(seq_len, block_N), + T.ceildiv((bx + 1) * block_M, block_N), + ) + if is_causal + else T.ceildiv(seq_len, block_N) + ) + + for k in T.serial(loop_range): + parity = (k // num_stages) & 1 + parity_inv = parity ^ 1 + stage_id = k % num_stages + is_clear_accum = k == 0 + + if tid >= 128 and tid < 160: + 
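# DMA warps (tid 128-159): wait until the consumer has released this stage's K buffer before overwriting it. +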
T.mbarrier_wait_parity(mbar_dma1_empty[stage_id], parity_inv) + + if k == 0: + T.copy(Q[bz, bx * block_M : (bx + 1) * block_M, by, :], Q_shared) + + if stage_id == 0: + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared_0) + else: + T.copy(K[bz, k * block_N : (k + 1) * block_N, by, :], K_shared_1) + T.mbarrier_arrive(mbar_dma1_full[stage_id]) + + T.mbarrier_wait_parity(mbar_dma2_empty[stage_id], parity_inv) + + if stage_id == 0: + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared_0) + else: + T.copy(V[bz, k * block_N : (k + 1) * block_N, by, :], V_shared_1) + + T.mbarrier_arrive(mbar_dma2_full[stage_id]) + + elif tid >= 160 and tid < 192: + T.mbarrier_wait_parity(mbar_dma1_full[stage_id], parity) + T.mbarrier_wait_parity(mbar_bmm1_empty[stage_id], parity_inv) + + if stage_id == 0: + T.tcgen05_gemm( + Q_shared, + K_shared_0, + S_tmem, + transpose_B=True, + mbar=mbar_bmm1_full[stage_id], + clear_accum=True, + ) + else: + T.tcgen05_gemm( + Q_shared, + K_shared_1, + S_tmem, + transpose_B=True, + mbar=mbar_bmm1_full[stage_id], + clear_accum=True, + ) + T.mbarrier_arrive(mbar_dma1_empty[stage_id]) + + T.mbarrier_wait_parity(mbar_softmax_full[stage_id], parity) + T.mbarrier_wait_parity(mbar_dma2_full[stage_id], parity) + + if stage_id == 0: + T.tcgen05_gemm( + P_tmem, + V_shared_0, + O_tmem, + mbar=mbar_bmm2_full[stage_id], + clear_accum=is_clear_accum, + ) + else: + T.tcgen05_gemm( + P_tmem, + V_shared_1, + O_tmem, + mbar=mbar_bmm2_full[stage_id], + clear_accum=is_clear_accum, + ) + + T.mbarrier_arrive(mbar_softmax_empty[stage_id]) + T.mbarrier_arrive(mbar_dma2_empty[stage_id]) + + if k == loop_range - 1: + T.mbarrier_arrive(mbar_correction_full[0]) + + elif tid < 128: + T.mbarrier_wait_parity(mbar_softmax_empty[stage_id], parity_inv) + T.mbarrier_wait_parity(mbar_bmm1_full[stage_id], parity) + if k > 0: + prev_stage = (k - 1) % num_stages + prev_parity = ((k - 1) // num_stages) & 1 + T.mbarrier_wait_parity(mbar_bmm2_full[prev_stage], prev_parity) + + T.copy(O_tmem, O_reg) + T.copy(S_tmem, S_reg) + + if is_causal: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + bx * block_M + i >= k * block_N + j, + S_reg[i, j], + -T.infinity(accum_dtype), + ) + else: + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.if_then_else( + k * block_N + j >= seq_len, + -T.infinity(accum_dtype), + S_reg[i, j], + ) + + T.copy(scores_max, scores_max_prev) + T.fill(scores_max, -T.infinity(accum_dtype)) + T.reduce_max(S_reg, scores_max, dim=1, clear=False) + for i in T.Parallel(block_M): + scores_max[i] = T.max(scores_max[i], scores_max_prev[i]) + for i in T.Parallel(block_M): + scores_rescale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) + for i, j in T.Parallel(block_M, block_N): + S_reg[i, j] = T.exp2(S_reg[i, j] * scale - scores_max[i] * scale) + + T.reduce_sum(S_reg, scores_sum, dim=1) + for i in T.Parallel(block_M): + logsum[i] = logsum[i] * scores_rescale[i] + scores_sum[i] + + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] *= scores_rescale[i] + + T.copy(S_reg, P_cast) + T.copy(P_cast, P_tmem) + T.copy(O_reg, O_tmem) + + T.mbarrier_arrive(mbar_softmax_full[stage_id]) + T.mbarrier_arrive(mbar_bmm1_empty[stage_id]) + + if k == loop_range - 1: + T.mbarrier_wait_parity(mbar_correction_full[0], 0) + T.mbarrier_wait_parity(mbar_bmm2_full[stage_id], parity) + T.copy(O_tmem, O_reg) + for i, j in T.Parallel(block_M, dim): + O_reg[i, j] /= logsum[i] + T.copy(O_reg, O_shared) + T.copy( + O_shared, + Output[bz, bx * block_M : (bx + 1) * 
block_M, by, :], + ) + + return main + + +flashattn_warp = flashattn_wasp + + +def ref_program(Q, K, V, is_causal): + """CPU reference computation to avoid cuBLAS issues on Blackwell.""" + Q_f = Q.cpu().float() + K_f = K.cpu().float() + V_f = V.cpu().float() + dim = Q_f.size(-1) + scores = torch.einsum("bqhd,bkhd->bhqk", Q_f, K_f) + scores = scores / (dim**0.5) + if is_causal: + seq_len = Q_f.size(1) + mask = torch.tril(torch.ones(seq_len, seq_len)) + mask = mask.unsqueeze(0).unsqueeze(0) + scores = scores.masked_fill(mask == 0, float("-inf")) + attention_weights = F.softmax(scores, dim=-1) + output = torch.einsum("bhqk,bkhd->bqhd", attention_weights, V_f) + return output.to(torch.bfloat16) + + +def main( + batch: int = 2, + heads: int = 4, + seq_len: int = 256, + dim: int = 128, + is_causal: bool = False, + variant: str = "ss", +): + """Run MHA forward kernel (ss / ts / wasp) and benchmark.""" + flops_per_matmul = 2.0 * batch * heads * seq_len * seq_len * dim + total_flops = 2 * flops_per_matmul + if is_causal: + total_flops *= 0.5 + + print(f"=== Blackwell Flash Attention ({variant.upper()}) ===") + print(f"batch={batch}, heads={heads}, seq_len={seq_len}, dim={dim}, causal={is_causal}") + + if variant in ("ss", "ts"): + kernel = flashattn( + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + variant=variant, + ) + else: + kernel = flashattn_wasp( + batch, + heads, + seq_len, + dim, + is_causal, + block_M=128, + block_N=128, + threads=256, + num_stages=2, + ) + + Q = torch.randn(batch, seq_len, heads, dim, device="cuda", dtype=torch.bfloat16) + K = torch.randn_like(Q) + V = torch.randn_like(Q) + + out = kernel(Q, K, V) + ref = ref_program(Q, K, V, is_causal).to(out.device) + torch.testing.assert_close(out, ref, rtol=1e-2, atol=1e-2) + print("Correctness check passed.") + + latency = do_bench(lambda: kernel(Q, K, V), warmup=100) + print(f"Blackwell ({variant}): {latency:.2f} ms") + print(f"Blackwell ({variant}): {total_flops / latency * 1e-9:.2f} TFlops") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch", type=int, default=1) + parser.add_argument("--heads", type=int, default=16) + parser.add_argument("--seq_len", type=int, default=16384) + parser.add_argument("--dim", type=int, default=128) + parser.add_argument("--is_causal", action="store_true") + parser.add_argument( + "--variant", + choices=["ss", "ts", "wasp"], + default="wasp", + help="ss: pipeline, 128 threads; ts: single-path, 256 threads (mma_ts); wasp: warp-specialized (use ts if wasp fails)", + ) + args = parser.parse_args() + print(args) + main( + args.batch, + args.heads, + args.seq_len, + args.dim, + args.is_causal, + args.variant, + ) diff --git a/examples/flash_decoding/example_gqa_decode.py index 9e6f360178..26b4801115 100644 --- a/examples/flash_decoding/example_gqa_decode.py +++ b/examples/flash_decoding/example_gqa_decode.py @@ -40,13 +40,13 @@ def get_heuristic_config() -> Tuple[Dict, int]: return cfg, sm_version -# TODO(lei): fix warp specialized and tma lower pass +# TODO(lei): fix warp specialized pass def get_pass_configs(): - return {tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} + return {tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} @autotune(configs=get_configs(), warmup=10, rep=10) -@tilelang.jit(out_idx=[6], pass_configs=get_pass_configs()) +@tilelang.jit(out_idx=[4], pass_configs=get_pass_configs()) def flashattn(batch,
heads, groups, seqlen_kv, dim, block_N, block_H, num_split, num_stages, threads): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, heads, dim] @@ -67,10 +67,10 @@ def flashattn_gqa_decode_split( K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), Output: T.Tensor(shape_o, dtype), ): + glse = T.alloc_global([batch, heads, num_split], dtype) + Output_partial = T.alloc_global(part_shape, dtype) # split with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) @@ -193,8 +193,6 @@ def flashattn_gqa_decode_no_split( K: T.Tensor(shape_k, dtype), V: T.Tensor(shape_v, dtype), mask: T.Tensor([batch, seqlen_kv, groups], "uint8"), - glse: T.Tensor([batch, heads, num_split], dtype), - Output_partial: T.Tensor(part_shape, dtype), Output: T.Tensor(shape_o, dtype), ): with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): @@ -259,7 +257,7 @@ def flashattn_gqa_decode_no_split( return flashattn_gqa_decode_no_split -def ref_program(query, key, value, mask, glse, Output_partial): +def ref_program(query, key, value, mask): # """ # Inputs: # - query (Tensor): [batch, heads, dim] @@ -417,12 +415,9 @@ def main(batch: int = 1, heads: int = 32, groups: int = 8, kv_seqlen: int = 8192 k = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) v = torch.randn(batch, kv_seqlen, groups, dim, device="cuda", dtype=torch.float16) mask = torch.randint(0, 2, (batch, kv_seqlen, groups), device="cuda", dtype=torch.uint8) - split = config["num_split"] - glse = torch.empty(batch, heads, split, device="cuda", dtype=torch.float16) - Output_partial = torch.empty(batch, heads, split, dim, device="cuda", dtype=torch.float16) - o = kernel(q, k, v, mask, glse, Output_partial) - o_ref = ref_program(q, k, v, mask, glse, Output_partial) - o_ref_split = ref_split_program(q, k, v, mask, glse, Output_partial) + o = kernel(q, k, v, mask) + o_ref = ref_program(q, k, v, mask) + o_ref_split = ref_split_program(q, k, v, mask) print(o) print(o_ref) diff --git a/examples/flash_decoding/example_gqa_decode_varlen_logits.py b/examples/flash_decoding/example_gqa_decode_varlen_logits.py index 30acd879e6..468be22302 100644 --- a/examples/flash_decoding/example_gqa_decode_varlen_logits.py +++ b/examples/flash_decoding/example_gqa_decode_varlen_logits.py @@ -1,10 +1,9 @@ import torch -import triton -import triton.language as tl import math import argparse import tilelang import tilelang.language as T +from tilelang.profiler import do_bench torch.manual_seed(0) @@ -21,167 +20,6 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim) -@triton.jit -def _fwd_inner( - q, - k_ptrs, - v_ptrs, - s_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N: tl.constexpr, -): - """Inner loop computation for attention""" - - for blk_idx in tl.range(lo, hi): - start_n = blk_idx * BLOCK_N - k = tl.load(k_ptrs + start_n * stride_kt, mask=offs_n[None, :] + start_n < seqlen) - v = tl.load(v_ptrs + start_n * stride_vt, mask=offs_n[:, None] + start_n < seqlen) - - qk = tl.dot(q, k) - qk *= softmax_scale - qk += tl.where(offs_n[None, :] + start_n < seqlen, 0, -1.0e9) - 
- row_max = tl.max(qk, 1) - tl.store(s_ptrs + offs_h * stride_sh + blk_idx * stride_sn, row_max, mask=mask_h) - - m_ij = tl.maximum(m_i, row_max) - qk -= m_ij[:, None] - p = tl.math.exp(qk) - l_ij = tl.sum(p, 1) - alpha = tl.math.exp(m_i - m_ij) - l_i = l_i * alpha + l_ij - m_i = m_ij - acc *= alpha[:, None] - p = p.to(v.type.element_ty) - acc += tl.dot(p, v) - - return m_i, l_i, acc - - -@triton.autotune( - configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [4, 8] for num_stages in [2, 4]], - key=["gqa_group_size", "BLOCK_N", "BLOCK_D", "BLOCK_H"], -) -@triton.jit -def _fwd_kernel_varlen( - Q, # [token_q = b, h_q, dim] - K, # [token_k, h_kv, dim] - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - stride_qt, - stride_qh, - stride_qd, - stride_kt, - stride_kh, - stride_kd, - stride_vt, - stride_vh, - stride_vd, - stride_ot, - stride_oh, - stride_od, - stride_sb, - stride_sh, - stride_sn, # bmask shape [b, q_h, seq/BLOCK_N] - gqa_group_size: tl.constexpr, - BLOCK_H: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_D: tl.constexpr, -): - off_z = tl.program_id(0) - off_h_for_kv = tl.program_id(1) - off_h_q = off_h_for_kv * gqa_group_size - - cu_k_start = tl.load(cu_seqlens_k + off_z) - cu_k_end = tl.load(cu_seqlens_k + off_z + 1) - - seqlen_k = cu_k_end - cu_k_start - - offs_h = tl.arange(0, BLOCK_H) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_D) - - Q_ptrs = Q + off_z * stride_qt + off_h_q * stride_qh - K_ptrs = K + (cu_k_start) * stride_kt + off_h_for_kv * stride_kh - V_ptrs = V + (cu_k_start) * stride_vt + off_h_for_kv * stride_vh - O_ptrs = O + off_z * stride_ot + off_h_q * stride_oh - S_ptrs = S + off_z * stride_sb + off_h_q * stride_sh - - mask_h = offs_h < gqa_group_size - q = tl.load(Q_ptrs + offs_d[None, :] * stride_qd + offs_h[:, None] * stride_qh, mask=mask_h[:, None]) - - if s_aux is not None: - sink = tl.load(s_aux + off_h_q + offs_h, mask=mask_h).to(tl.float32) - l_i = tl.zeros([BLOCK_H], dtype=tl.float32) - m_i = tl.zeros([BLOCK_H], dtype=tl.float32) + sink - else: - l_i = tl.full([BLOCK_H], 1.0, dtype=tl.float32) - m_i = tl.full([BLOCK_H], float("-inf"), dtype=tl.float32) - - acc = tl.zeros([BLOCK_H, BLOCK_D], dtype=tl.float32) - - k_ptrs = K_ptrs + offs_n[None, :] * stride_kt + offs_d[:, None] * stride_kd - v_ptrs = V_ptrs + offs_n[:, None] * stride_vt + offs_d[None, :] * stride_vd - - lo, hi = 0, tl.cdiv(seqlen_k, BLOCK_N) - m_i, l_i, acc = _fwd_inner( - q, - k_ptrs, - v_ptrs, - S_ptrs, - m_i, - l_i, - acc, - offs_h, - mask_h, - offs_n, - seqlen_k, - softmax_scale, - lo, - hi, - stride_kt, - stride_vt, - stride_sh, - stride_sn, - BLOCK_N, - ) - - if s_aux is not None: - sink = tl.math.exp(sink - m_i) - l_i = l_i + sink - acc = acc / l_i[:, None] - - else: - l_recip = 1 / l_i[:, None] - acc = acc * l_recip - - for blk_idx in tl.range(lo, hi): - s = tl.load(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, mask=mask_h) - s = tl.exp(s - m_i) / l_i - tl.store(S_ptrs + offs_h * stride_sh + blk_idx * stride_sn, s, mask=mask_h) - - acc = acc.to(O.dtype.element_ty) - - tl.store(O_ptrs + offs_h[:, None] * stride_oh + offs_d[None, :] * stride_od, acc, mask=mask_h[:, None]) - - def get_configs(): import itertools @@ -211,7 +49,6 @@ def flashattn( kv_group_num = heads // k_heads valid_block_H = min(block_H, kv_group_num) - # TODO: check if max_seqlen_kv is correct for varlen case @T.prim_func def flashattn_gqa_decode_no_split( @@ -223,7 +60,7 @@ def flashattn_gqa_decode_no_split( Output: T.Tensor(shape_o, dtype), S: 
T.Tensor(shape_s, dtype), ): - with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bx, by, bz): + with T.Kernel(batch, heads // valid_block_H, num_split, threads=threads) as (bid, hid, bz): Q_shared = T.alloc_shared([block_H, dim], dtype) K_shared = T.alloc_shared([block_N, dim], dtype) V_shared = T.alloc_shared([block_N, dim], dtype) @@ -236,12 +73,10 @@ def flashattn_gqa_decode_no_split( scores_scale = T.alloc_fragment([block_H], accum_dtype) scores_sum = T.alloc_fragment([block_H], accum_dtype) logsum = T.alloc_fragment([block_H], accum_dtype) - S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) - # S_fragment = T.alloc_fragment([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + S_shared = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], accum_dtype) + S_shared_cast = T.alloc_shared([block_H, math.ceil(max_seqlen_kv / block_N)], dtype) s_aux_shared = T.alloc_shared([block_H], T.float32) - bid = bx - hid = by cur_kv_head = hid // (kv_group_num // valid_block_H) cur_start_k = cu_seqlens_k[bid] @@ -253,30 +88,22 @@ def flashattn_gqa_decode_no_split( T.fill(logsum, 0) T.fill(scores_max, -T.infinity(accum_dtype)) - # loop_range = T.ceildiv((seqlen_kv // num_split), block_N) loop_range = T.ceildiv((cur_seqlen_k // num_split), block_N) for k in T.Pipelined(loop_range, num_stages=num_stages): T.copy(K[cur_start_k + k * block_N : cur_start_k + (k + 1) * block_N, cur_kv_head, :], K_shared) T.clear(acc_s) T.gemm(Q_shared, K_shared, acc_s, transpose_B=True, policy=T.GemmWarpPolicy.FullRow) for i, j in T.Parallel(block_H, block_N): - # acc_s[i, j] = T.if_then_else(mask_local[j] != 0 and k * block_N + j < cur_seqlen_k, acc_s[i, j], - # -T.infinity(accum_dtype)) acc_s[i, j] = T.if_then_else(k * block_N + j < cur_seqlen_k, acc_s[i, j], -T.infinity(accum_dtype)) T.copy(scores_max, scores_max_prev) T.fill(scores_max, -T.infinity(accum_dtype)) T.reduce_max(acc_s, scores_max, dim=1, clear=False) - # scores_max_prev is m_i - # scores_max is row_max->m_ij in triton T.copy(scores_max, S_shared[:, k]) - # scores_scale is alpha in triton for i in T.Parallel(block_H): scores_scale[i] = T.exp2(scores_max_prev[i] * scale - scores_max[i] * scale) for i, j in T.Parallel(block_H, block_N): acc_s[i, j] = T.exp2(acc_s[i, j] * scale - scores_max[i] * scale) T.reduce_sum(acc_s, scores_sum, dim=1) - # scores_sum is l_ij in triton - # logsum is l_i in triton for i in T.Parallel(block_H): logsum[i] = logsum[i] * scores_scale[i] + scores_sum[i] T.copy(acc_s, acc_s_cast) @@ -293,355 +120,109 @@ def flashattn_gqa_decode_no_split( acc_o[i, j] /= logsum[i] for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): S_shared[h, k] = T.exp2((S_shared[h, k] - scores_max[h]) * scale) / logsum[h] - # T.copy(S_shared, S_fragment) - # for h, k in T.Parallel(block_H, math.ceil(max_seqlen_kv / block_N)): - # S_fragment[h, k] = T.exp2((S_fragment[h, k] - scores_max[h]) * scale) / logsum[h] for i in T.Parallel(block_H): logsum[i] = T.log2(logsum[i]) + scores_max[i] * scale T.copy(acc_o[:valid_block_H, :], O_shared) T.copy(O_shared, Output[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - # T.copy(S_fragment, S_shared) - T.copy(S_shared[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) + T.copy(S_shared, S_shared_cast) + T.copy(S_shared_cast[:valid_block_H, :], S[bid, hid * valid_block_H : (hid + 1) * valid_block_H, :]) - # TODO: split version return flashattn_gqa_decode_no_split -def 
flash_attn_with_attn_pool_decode_tilelang( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, - tl_kernel=None, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - # assert K.is_contiguous() - # assert V.is_contiguous() +def ref_attention(q, k, v, k_seqlens, q_heads, sink=None): + """ + Compute reference attention output and weights. + Args: + q: [b, q_heads, head_size] + k, v: [b, kv_heads, max_seqlen, head_size] + k_seqlens: [b] actual sequence lengths + sink: [q_heads] optional sink values + Returns: output [b, q_heads, head_size], attn_weights [b, q_heads, max_seqlen] + """ + batch_size, kv_heads, max_seqlen, head_size = k.shape + softmax_scale = 1.0 / math.sqrt(head_size) - gqa_group_size = q_h // k_h + # Expand KV heads and compute attention scores + k = repeat_kv(k, q_heads // kv_heads) + v = repeat_kv(v, q_heads // kv_heads) + logits = torch.matmul(q.unsqueeze(2), k.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - O_tl = torch.zeros_like(Q) - S_tl = torch.zeros((batch, q_h, math.ceil(real_max_k_seqlen / block_size)), dtype=Q.dtype, device=Q.device) - O_tl, S_tl = tl_kernel(Q, K, V, cu_seqlens_k, s_aux) + # Mask invalid positions + mask = torch.arange(max_seqlen, device=q.device).expand(batch_size, -1) >= k_seqlens.unsqueeze(1) + logits.masked_fill_(mask.unsqueeze(1).unsqueeze(2), float("-inf")) - if use_per_kv_head_sparse_index: - S_tl = torch.max_pool2d(S_tl, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) - else: - S_tl = torch.max_pool2d(S_tl, kernel_size=(q_h, 1), stride=(q_h, 1)) - - return O_tl, S_tl - - -def flash_attn_with_attn_pool_decode( - Q: torch.Tensor, ## [tq = b, q_h, q_dim] - K: torch.Tensor, ## [tk, k_h, k_dim] - V: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_k: int, - real_max_k_seqlen: int, - num_split: int, - softmax_scale: float, - s_aux: torch.Tensor = None, - block_size: int = 64, - use_per_kv_head_sparse_index: bool = False, -): - num_tokens, q_h, head_size = Q.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = K.size(1) - - assert Q.dim() == K.dim() == 3 - assert Q.size(2) == K.size(2) - assert cu_seqlens_k.dim() == 1 - assert head_size in {64, 128, 256} - assert Q.is_contiguous() - # assert K.is_contiguous() - # assert V.is_contiguous() - - gqa_group_size = q_h // k_h - - BLOCK_D = head_size - BLOCK_N = block_size - BLOCK_H = 64 - - O = torch.zeros_like(Q) - S = torch.zeros((batch, q_h, math.ceil(max_seqlen_k / block_size)), dtype=Q.dtype, device=Q.device) - - def grid(META): - return (batch, k_h) - - with torch.cuda.device(Q.device.index): - _fwd_kernel_varlen[grid]( - Q, - K, - V, - O, - S, - s_aux, - softmax_scale, - cu_seqlens_k, - *Q.stride(), - *K.stride(), - *V.stride(), - *O.stride(), - *S.stride(), - gqa_group_size, - BLOCK_H=BLOCK_H, - BLOCK_N=BLOCK_N, - BLOCK_D=BLOCK_D, - ) - - if use_per_kv_head_sparse_index: - S = torch.max_pool2d(S, kernel_size=(gqa_group_size, 1), stride=(gqa_group_size, 1)) + if sink is None: + attn_weights = logits.softmax(dim=-1) else: - S = torch.max_pool2d(S, kernel_size=(q_h, 1), stride=(q_h, 1)) + # Sink 
attention: softmax with additional sink term + sink_expanded = sink.view(1, q_heads, 1, 1) + logits_max = torch.maximum(logits.max(dim=-1, keepdim=True).values, sink_expanded) + exp_logits = torch.exp(logits - logits_max) + attn_weights = exp_logits / (exp_logits.sum(dim=-1, keepdim=True) + torch.exp(sink_expanded - logits_max)) - return O, S + attn_weights.masked_fill_(mask.unsqueeze(1).unsqueeze(2), 0.0) + output = torch.matmul(attn_weights.to(v.dtype), v).squeeze(2) + return output, attn_weights.squeeze(2) def test_varlen_decode_main(args): - """Test decode kernel with variable sequence lengths""" - batch_size = args.batch_size - q_heads = args.q_heads - kv_heads = args.kv_heads - max_k_seqlen = args.k_seqlen # Use as max sequence length - real_max_k_seqlen = args.k_seqlen - head_size = args.head_size - block_size = args.block_size + """Test decode kernel with variable sequence lengths.""" + batch_size, q_heads, kv_heads = args.batch_size, args.q_heads, args.kv_heads + max_k_seqlen, head_size, block_size = args.k_seqlen, args.head_size, args.block_size dtype = torch.bfloat16 if args.dtype == T.bfloat16 else torch.float16 - print(f"Testing decode kernel with variable sequence lengths (max_k_seqlen={max_k_seqlen})") - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(f"Using sink attention with sink values: {sink}") - - # Generate variable length k sequences - k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) - print(f"k_seqlens: {k_seqlens}") - - # Generate cumulative sequence lengths for k - cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) - total_k_tokens = 0 - for i in range(batch_size): - cu_seqlens_k[i] = total_k_tokens - total_k_tokens += k_seqlens[i] - cu_seqlens_k[batch_size] = total_k_tokens - - print(f"cu_seqlens_k: {cu_seqlens_k}") - - # Generate tensors - Q is [batch_size, q_heads, head_size] for decode - q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) - k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - print(f"Actual max_seqlen_k: {max_seqlen_k}") - print(f"q_decode shape: {q_decode.shape}") - print(f"k_varlen shape: {k_varlen.shape}") - print(f"v_varlen shape: {v_varlen.shape}") - - num_tokens, q_h, head_size = q_decode.shape - batch = cu_seqlens_k.size(0) - 1 - k_h = k_varlen.size(1) - tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + # Make the test deterministic and independent of global RNG state. + # This avoids flaky allclose failures when run under xdist with different + # test ordering. 
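+ # torch.random.fork_rng snapshots the CPU (and listed CUDA) RNG states on entry + # and restores them on exit, so the seeding below cannot leak into later tests.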
+ cuda_devices = list(range(torch.cuda.device_count())) + with torch.random.fork_rng(devices=cuda_devices): + torch.manual_seed(0) + if cuda_devices: + torch.cuda.manual_seed_all(0) - # Test our decode kernel - O_triton, S_triton = flash_attn_with_attn_pool_decode( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - ) - O_tilelang, S_tilelang = flash_attn_with_attn_pool_decode_tilelang( - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - real_max_k_seqlen, - args.num_split, - softmax_scale, - s_aux=sink, - block_size=block_size, - tl_kernel=tl_kernel, - ) + # Generate variable length sequences and cumulative lengths + k_seqlens = torch.randint(max_k_seqlen // 4, max_k_seqlen + 1, size=(batch_size,)) + cu_seqlens_k = torch.zeros(batch_size + 1, device="cuda", dtype=torch.int32) + cu_seqlens_k[1:] = torch.cumsum(k_seqlens, dim=0).to(torch.int32).cuda() + total_k_tokens = cu_seqlens_k[-1].item() + + # Generate input tensors + q = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) + k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None + + # Run tilelang kernel + tl_kernel = flashattn(batch_size, q_heads, kv_heads, max_k_seqlen, total_k_tokens, head_size, args.test_sink) + O_tl, S_tl = tl_kernel(q, k_varlen, v_varlen, cu_seqlens_k, sink) + S_tl = torch.max_pool2d(S_tl, kernel_size=(q_heads, 1), stride=(q_heads, 1)) + + # Mask out invalid S positions for i in range(batch_size): - S_tilelang[i, :, math.ceil((cu_seqlens_k[i + 1].item() - cu_seqlens_k[i].item()) / block_size) :] = 0 - - # Create torch reference - pad tensors for comparison - k_padded_list = [] - v_padded_list = [] + valid_blocks = math.ceil(k_seqlens[i].item() / block_size) + S_tl[i, :, valid_blocks:] = 0 + # Prepare padded tensors for reference + actual_max = int(k_seqlens.max()) + k_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) + v_padded = torch.zeros(batch_size, kv_heads, actual_max, head_size, device="cuda", dtype=dtype) for i in range(batch_size): - actual_k_len = k_seqlens[i] - - # Extract and pad k, v for this batch - k_start = cu_seqlens_k[i] - k_end = cu_seqlens_k[i + 1] - - # Pad to max_seqlen_k - k_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - v_padded = torch.zeros(max_seqlen_k, kv_heads, head_size, device="cuda", dtype=dtype) - - k_padded[:actual_k_len] = k_varlen[k_start:k_end] - v_padded[:actual_k_len] = v_varlen[k_start:k_end] - - k_padded_list.append(k_padded) - v_padded_list.append(v_padded) - - # Stack to create batched tensors [b, max_seqlen, kv_heads, head_size] - k_padded_batched = torch.stack(k_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - v_padded_batched = torch.stack(v_padded_list, dim=0).transpose(1, 2) # [b, kv_heads, max_seqlen, head_size] - - # Expand q to match kv heads: [b, q_heads, 1, head_size] - q_expanded = q_decode.unsqueeze(2) # [b, q_heads, 1, head_size] - - print(f"q_expanded shape: {q_expanded.shape}") - print(f"k_padded_batched shape: {k_padded_batched.shape}") - print(f"v_padded_batched shape: {v_padded_batched.shape}") - - # Compute torch reference - k_repeat = repeat_kv(k_padded_batched, q_heads // 
kv_heads) # [b, q_heads, max_seqlen, head_size] - v_repeat = repeat_kv(v_padded_batched, q_heads // kv_heads) # [b, q_heads, max_seqlen, head_size] - - if sink is None: - # Standard attention computation: [b, q_heads, 1, head_size] @ [b, q_heads, head_size, max_seqlen] - attn_score = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] + seq_len = k_seqlens[i].item() + k_padded[i, :, :seq_len] = k_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) + v_padded[i, :, :seq_len] = v_varlen[cu_seqlens_k[i] : cu_seqlens_k[i + 1]].transpose(0, 1) - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_score[i, :, :, actual_k_len:] = float("-inf") - - attn_weights = attn_score.softmax(dim=-1) # [b, q_heads, 1, max_seqlen] - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights, v_repeat) # [b, q_heads, 1, head_size] - else: - # s_aux attention - logits = torch.matmul(q_expanded, k_repeat.transpose(-2, -1)) * softmax_scale # [b, q_heads, 1, max_seqlen] - - # Apply sequence length masking - for i in range(batch_size): - actual_k_len = k_seqlens[i] - logits[i, :, :, actual_k_len:] = float("-inf") - - sink_expanded = sink.view(1, q_heads, 1, 1) # [1, q_heads, 1, 1] - logits_max = torch.max(logits, dim=-1, keepdim=True).values - logits_or_sinks_max = torch.maximum(logits_max, sink_expanded) - sinks = torch.exp(sink_expanded - logits_or_sinks_max) - unnormalized_scores = torch.exp(logits - logits_or_sinks_max) - normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + sinks - attn_weights = unnormalized_scores / normalizer - - # Mask out invalid positions - for i in range(batch_size): - actual_k_len = k_seqlens[i] - attn_weights[i, :, :, actual_k_len:] = 0.0 - - # Compute output: [b, q_heads, 1, max_seqlen] @ [b, q_heads, max_seqlen, head_size] - O_torch = torch.matmul(attn_weights.to(v_repeat.dtype), v_repeat) # [b, q_heads, 1, head_size] - - O_torch = O_torch.squeeze(2) # [b, q_heads, head_size] - - # Compute attention score pooling for S - attn_score_pooled = torch.max_pool2d( - attn_weights.squeeze(2), # [b, q_heads, max_seqlen] - kernel_size=(q_heads, block_size), - stride=(q_heads, block_size), - ceil_mode=True, - ).to(dtype=torch.float16) # [b, 1, ceil(max_seqlen/block_size)] - - print(f"O_triton shape: {O_triton.shape}") - print(f"O_tilelang shape: {O_tilelang.shape}") - print(f"O_torch shape: {O_torch.shape}") - print(f"S_triton shape: {S_triton.shape}") - print(f"S_tilelang shape: {S_tilelang.shape}") - print(f"attn_score_pooled shape: {attn_score_pooled.shape}") + # Compute reference + O_ref, attn_weights = ref_attention(q, k_padded, v_padded, k_seqlens.cuda(), q_heads, sink) + S_ref = torch.max_pool2d(attn_weights, kernel_size=(q_heads, block_size), stride=(q_heads, block_size), ceil_mode=True).to(dtype) # Compare results - max_diff_o = torch.max(torch.abs(O_triton - O_torch)) - max_diff_o_tl = torch.max(torch.abs(O_tilelang - O_torch)) - print(f"Max difference in O: {max_diff_o.item()}") - print(f"Max difference in O_tilelang: {max_diff_o_tl.item()}") - - max_diff_s = torch.max(torch.abs(S_triton - attn_score_pooled)) - max_diff_s_tl = torch.max( - torch.abs( - S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)] - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)] - 
) - ) - print(f"Max difference in S: {max_diff_s.item()}") - print(f"Max difference in S_tilelang: {max_diff_s_tl.item()}") - - assert torch.allclose(O_triton, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o.item()}" - assert torch.allclose(S_triton, attn_score_pooled, atol=1e-2, rtol=1e-2), f"Score mismatch: {max_diff_s.item()}" - assert torch.allclose(O_tilelang, O_torch, atol=1e-2, rtol=1e-2), f"Output mismatch: {max_diff_o_tl.item()}" - assert torch.allclose( - S_tilelang[:, :, : math.ceil(max_seqlen_k / block_size)], - attn_score_pooled[:, :, : math.ceil(max_seqlen_k / block_size)], - atol=1e-2, - rtol=1e-2, - ), f"Score mismatch: {max_diff_s_tl.item()}" - + num_blocks = math.ceil(actual_max / block_size) + assert torch.allclose(O_tl, O_ref, atol=1e-2, rtol=1e-2), f"Output mismatch: {(O_tl - O_ref).abs().max()}" + assert torch.allclose(S_tl[:, :, :num_blocks], S_ref[:, :, :num_blocks], atol=1e-2, rtol=1e-2), "Score mismatch" print("✅ All tests passed!") -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. - """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def speed_benchmark_decode_comparison(args): """Speed benchmark for decode kernel""" batch_size = args.batch_size @@ -682,66 +263,25 @@ def speed_benchmark_decode_comparison(args): q_decode = torch.randn(batch_size, q_heads, head_size, device="cuda", dtype=dtype) k_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) v_varlen = torch.randn(total_k_tokens, kv_heads, head_size, device="cuda", dtype=dtype) - - softmax_scale = 1.0 / math.sqrt(head_size) - max_seqlen_k = int(k_seqlens.max()) - - # Generate sink values if needed - sink = None - if args.test_sink: - sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 # Small sink values - print(" Using sink attention with sink values") - - print("Setup complete:") - print(f" Total K tokens: {total_k_tokens}") - print(f" Actual max K seq len: {max_seqlen_k}") + sink = torch.randn(q_heads, device="cuda", dtype=torch.float32) * 0.1 if args.test_sink else None if args.test_varlen: print(f" K sequence lengths: {k_seqlens.tolist()}") - # Warmup - num_tokens, q_h, head_size = q_decode.shape + _, q_h, head_size = q_decode.shape batch = cu_seqlens_k.size(0) - 1 k_h = k_varlen.size(1) tl_kernel = flashattn(batch, q_h, k_h, args.k_seqlen, cu_seqlens_k[-1].item(), head_size, args.test_sink) + def run_once(): + tl_kernel(q_decode, k_varlen, v_varlen, cu_seqlens_k, sink) + # Benchmark print("⚡ Benchmarking Tilelang kernel (100 iterations)...") tilelang_time = do_bench( - flash_attn_with_attn_pool_decode_tilelang, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - False, - tl_kernel, + run_once, ) print(f"Average decode kernel time Tilelang: {tilelang_time:.3f} ms") - # Benchmark - print("⚡ Benchmarking Triton kernel (100 iterations)...") - triton_time = do_bench( - flash_attn_with_attn_pool_decode, - q_decode, - k_varlen, - v_varlen, - cu_seqlens_k, - 
max_seqlen_k, - args.k_seqlen, - 1, - softmax_scale, - sink, - block_size, - ) - print(f"Average decode kernel time Triton: {triton_time:.3f} ms") - - print(f"Speedup: {(triton_time / tilelang_time):.3f}") - def main(): args = argparse.Namespace( @@ -779,7 +319,4 @@ def main(): args.dtype = T.float16 args.num_split = 1 - if args.benchmark: - speed_benchmark_decode_comparison(args) - else: - test_varlen_decode_main(args) + speed_benchmark_decode_comparison(args) diff --git a/examples/flash_decoding/example_mha_inference.py index 24a90c57b5..f17d6abc75 100644 --- a/examples/flash_decoding/example_mha_inference.py +++ b/examples/flash_decoding/example_mha_inference.py @@ -8,7 +8,7 @@ num_split = 4 -@tilelang.jit(out_idx=[5]) +@tilelang.jit(out_idx=[3]) def flashattn(batch, heads, seqlen_q, seqlen_kv, dim, is_causal, block_M, block_N): scale = (1.0 / dim) ** 0.5 * 1.44269504 # log2(e) shape_q = [batch, seqlen_q, heads, dim] @@ -22,10 +22,10 @@ def flashattn_mha_inference( Q: T.Tensor(shape_q, dtype), K: T.Tensor(shape_kv, dtype), V: T.Tensor(shape_kv, dtype), - glse: T.Tensor([batch, heads, num_split, seqlen_q], dtype), - Output_partial: T.Tensor(part_shape, dtype), # [batch, seqlen_q, heads, num_split, dim] Output: T.Tensor(shape_q, dtype), ): + glse = T.alloc_global([batch, heads, num_split, seqlen_q], dtype) + Output_partial = T.alloc_global(part_shape, dtype) # [batch, seqlen_q, heads, num_split, dim] # split with T.Kernel(T.ceildiv(seqlen_q, block_M), heads * batch, num_split, threads=128) as (bx, by, bz): Q_shared = T.alloc_shared([block_M, dim], dtype) @@ -155,7 +155,7 @@ def flashattn_mha_inference( return flashattn_mha_inference -def ref_program(Q, K, V, glse, Output_partial, causal): +def ref_program(Q, K, V, causal): assert causal is False dim = Q.size(-1) scores = torch.einsum("bqhd,bkhd->bhqk", Q, K) diff --git a/examples/flash_decoding/test_example_flash_decoding.py index a02a920974..3181df2d56 100644 --- a/examples/flash_decoding/test_example_flash_decoding.py +++ b/examples/flash_decoding/test_example_flash_decoding.py @@ -1,18 +1,22 @@ +import os +import pytest import tilelang.testing import example_gqa_decode import example_mha_inference import example_gqa_decode_varlen_logits -import example_gqa_decode_varlen_logits_paged + +_is_cutedsl = os.environ.get("TILELANG_TARGET", "").lower() == "cutedsl" -# TODO(lei): fix the correctness of gqa decode on sm90 @tilelang.testing.requires_cuda @tilelang.testing.requires_cuda_compute_version_le(8, 9) +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_example_gqa_decode(): example_gqa_decode.main() +@pytest.mark.skipif(_is_cutedsl, reason="CuTeDSL backend does not support alloc_global yet") def test_example_example_mha_inference(): example_mha_inference.main(BATCH=1, H=32, Q_CTX=128, KV_CTX=2048, D_HEAD=128, causal=False) @@ -21,9 +25,5 @@ def test_example_example_gqa_decode_varlen_logits(): example_gqa_decode_varlen_logits.main() -def test_example_example_gqa_decode_varlen_logits_paged(): - example_gqa_decode_varlen_logits_paged.main() - - if __name__ == "__main__": tilelang.testing.main() diff --git a/examples/fusedmoe/example_fusedmoe_tilelang.py index 5c236dd802..d4a2ced46d 100644 ---
a/examples/fusedmoe/example_fusedmoe_tilelang.py +++ b/examples/fusedmoe/example_fusedmoe_tilelang.py @@ -8,7 +8,7 @@ from example_fusedmoe_torch import * -@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tilelang.jit(pass_configs={"tl.disable_warp_specialized": True}) def moe_forward_tilelang_shared( d_hidden, d_expert, @@ -93,7 +93,7 @@ def kernel_shared( return kernel_shared -@tilelang.jit(pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True}) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def moe_forward_tilelang_routed( d_hidden, d_expert, @@ -106,8 +106,6 @@ def moe_forward_tilelang_routed( block_dexpert=128, threads=256, num_stages=1, - k_pack=1, - coalesced_width=None, ): scale = 1.44269504 # log2(e) @@ -155,7 +153,7 @@ def kernel( gate_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) up_logits_local = T.alloc_fragment((block_token, block_dexpert), dtype=accum_dtype) - T.use_swizzle(10, enable=True) + T.use_swizzle(10) m_start_padded = bx * block_token @@ -172,24 +170,21 @@ def kernel( T.copy( input[m_start : m_start + block_token, k * block_dhidden : (k + 1) * block_dhidden], input_shared, - coalesced_width=coalesced_width, ) T.copy( routed_expert_gate[ cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden ], routed_expert_gate_shared, - coalesced_width=coalesced_width, ) - T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, k_pack=k_pack, transpose_B=True) + T.gemm(input_shared, routed_expert_gate_shared, gate_logits_local, transpose_B=True) T.copy( routed_expert_up[ cur_group_idx, by * block_dexpert : (by + 1) * block_dexpert, k * block_dhidden : (k + 1) * block_dhidden ], routed_expert_up_shared, - coalesced_width=coalesced_width, ) - T.gemm(input_shared, routed_expert_up_shared, up_logits_local, k_pack=k_pack, transpose_B=True) + T.gemm(input_shared, routed_expert_up_shared, up_logits_local, transpose_B=True) for i, j in T.Parallel(block_token, block_dexpert): gate_logits_local[i, j] = gate_logits_local[i, j] * (1.0 / (1.0 + T.exp2(-gate_logits_local[i, j] * scale))) @@ -205,7 +200,7 @@ def kernel( routed_expert_down_shared = T.alloc_shared((block_dhidden, block_dexpert), dtype=dtype) output_local = T.alloc_fragment((block_token, block_dhidden), dtype=accum_dtype) - T.use_swizzle(10, enable=True) + T.use_swizzle(10) m_start_padded = bx * block_token @@ -221,16 +216,14 @@ def kernel( T.copy( up_logits[m_start : m_start + block_token, k * block_dexpert : (k + 1) * block_dexpert], up_logits_shared, - coalesced_width=coalesced_width, ) T.copy( routed_expert_down[ cur_group_idx, by * block_dhidden : (by + 1) * block_dhidden, k * block_dexpert : (k + 1) * block_dexpert ], routed_expert_down_shared, - coalesced_width=coalesced_width, ) - T.gemm(up_logits_shared, routed_expert_down_shared, output_local, k_pack=k_pack, transpose_B=True) + T.gemm(up_logits_shared, routed_expert_down_shared, output_local, transpose_B=True) for i, j in T.Parallel(block_token, block_dhidden): if i < actual_rows: @@ -479,8 +472,6 @@ def custom_kernel(data: Tuple[torch.Tensor, Dict, Dict]) -> torch.Tensor: block_dexpert=128, threads=256, num_stages=1, - k_pack=1, - coalesced_width=2, ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -503,13 +494,8 @@ def main(d_hidden=7168, d_expert=2048, n_routed_experts=8, n_shared_experts=1, n } data = generate_input(**config) - - 
torch.cuda.synchronize() ref_output = ref_kernel(clone_data(data)).to(torch.float32) - torch.cuda.synchronize() tilelang_output = custom_kernel(clone_data(data)).to(torch.float32) - torch.cuda.synchronize() - torch.testing.assert_close(ref_output, tilelang_output, atol=1e-2, rtol=1e-2) print("✅ Tilelang and Torch match") @@ -554,8 +540,6 @@ def run_regression_perf( block_dexpert=128, threads=256, num_stages=1, - k_pack=1, - coalesced_width=2, ) moe = MoE(config, shared_kernel, routed_kernel, weights, padding_M=128) @@ -627,8 +611,11 @@ def run_routed_kernel_only(): moe.expert_output_routed, ) - return do_bench(run_routed_kernel_only, backend="cupti") + shared_latency = do_bench(run_shared_kernel_only, backend="cupti") + routed_latency = do_bench(run_routed_kernel_only, backend="cupti") + return (shared_latency + routed_latency) / 2 if __name__ == "__main__": + tilelang.disable_cache() main() diff --git a/examples/gdn/example_chunk_delta_bwd.py b/examples/gdn/example_chunk_delta_bwd.py index 4230df525e..466c471821 100644 --- a/examples/gdn/example_chunk_delta_bwd.py +++ b/examples/gdn/example_chunk_delta_bwd.py @@ -4,6 +4,7 @@ import tilelang import tilelang.language as T +from tilelang.profiler import do_bench print(tilelang.__file__, flush=True) @@ -544,31 +545,6 @@ def run_test( assert_similar(dv2_ref_torch, dv2_tilelang, 1e-5, "torch-tilelang", data="dv2") -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. - """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def main(): DK = 128 run_test( diff --git a/examples/gdn/example_chunk_delta_h.py b/examples/gdn/example_chunk_delta_h.py index 2ee84e7bf6..c34d9b5304 100644 --- a/examples/gdn/example_chunk_delta_h.py +++ b/examples/gdn/example_chunk_delta_h.py @@ -4,6 +4,7 @@ import tilelang import tilelang.language as T from tilelang.autotuner import autotune +from tilelang.profiler import do_bench # Add your fla repository path to sys.path # Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae @@ -224,31 +225,6 @@ def kernel( return kernel -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. 
- """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def run_test( B, S, diff --git a/examples/gdn/example_chunk_o.py b/examples/gdn/example_chunk_o.py index a4d7281f55..bb95f555f8 100644 --- a/examples/gdn/example_chunk_o.py +++ b/examples/gdn/example_chunk_o.py @@ -127,16 +127,15 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 < i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 < i_s2: + A_fragment[i_s1, i_s2] = 0 T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], V_shared) T.copy(A_fragment, A_shared) diff --git a/examples/gdn/example_chunk_o_bwd.py b/examples/gdn/example_chunk_o_bwd.py index e589818f4c..b369e03a8f 100644 --- a/examples/gdn/example_chunk_o_bwd.py +++ b/examples/gdn/example_chunk_o_bwd.py @@ -109,7 +109,7 @@ def prepare_output( @tilelang.jit( out_idx=[-4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tilelang_chunk_o_bwd_dqkwg( # task config @@ -359,31 +359,6 @@ def kernel( return kernel -def do_bench(fn, *args, warmup=10, rep=10, **kwargs): - """ - Do benchmark for a function. 
- """ - start_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - end_event = [torch.cuda.Event(enable_timing=True) for i in range(rep)] - for _ in range(warmup): - fn(*args, **kwargs) - - torch.cuda.synchronize() - for i in range(rep): - start_event[i].record() - fn(*args, **kwargs) - end_event[i].record() - torch.cuda.synchronize() - - # Record clocks - times = torch.tensor( - [s.elapsed_time(e) for s, e in zip(start_event, end_event)], - dtype=torch.float, - ) - - return times.mean().item() - - def run_test( B, S, diff --git a/examples/gdn/example_chunk_scaled_dot_kkt.py b/examples/gdn/example_chunk_scaled_dot_kkt.py index 8c7a4d573b..c16374fe8c 100644 --- a/examples/gdn/example_chunk_scaled_dot_kkt.py +++ b/examples/gdn/example_chunk_scaled_dot_kkt.py @@ -111,16 +111,15 @@ def kernel( for i_s1, i_s2 in T.Parallel(block_S, block_S): G_diff_local[i_s1, i_s2] = G_shared[i_s1] - G_shared[i_s2] for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2): - with T.Then(): - A_fragment[i_s1, i_s2] = A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]) - with T.Else(): - A_fragment[i_s1, i_s2] = 0 + A_fragment[i_s1, i_s2] = T.if_then_else( + G_diff_local[i_s1, i_s2] <= 0 and i_s1 > i_s2, + A_fragment[i_s1, i_s2] * T.exp(G_diff_local[i_s1, i_s2]), + 0, + ) else: for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - A_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + A_fragment[i_s1, i_s2] = 0 T.copy(A_fragment, A_shared) T.copy(A_shared, A[bb, bs * block_S : (bs + 1) * block_S, bh, :]) diff --git a/examples/gdn/example_cumsum.py b/examples/gdn/example_cumsum.py index 0760b49645..9d4ca0222e 100644 --- a/examples/gdn/example_cumsum.py +++ b/examples/gdn/example_cumsum.py @@ -20,9 +20,7 @@ import torch -@tilelang.jit( - out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True} -) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_chunk_local_cumsum_scalar( # task config B, diff --git a/examples/gdn/example_wy_fast_bwd_split.py b/examples/gdn/example_wy_fast_bwd_split.py index de8afc2b77..5711010025 100644 --- a/examples/gdn/example_wy_fast_bwd_split.py +++ b/examples/gdn/example_wy_fast_bwd_split.py @@ -94,7 +94,7 @@ def prepare_output( @tilelang.jit( out_idx=[-5, -4, -3, -2, -1], - pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) def tilelang_wy_fast_bwd( # task config @@ -247,7 +247,7 @@ def kernel( return kernel -@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}) def tilelang_wy_fast_bwd_split( # task config B, @@ -345,26 +345,25 @@ def kernel( T.copy(dA_shared, dA_fragment) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(i_s1 <= i_s2): # noqa: SIM117 - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 + if i_s1 <= i_s2: + dA_fragment[i_s1, i_s2] = 0 T.copy(dA_fragment, dA_shared) T.gemm(dA_shared, A_shared, dA_fragment, clear_accum=True, transpose_B=True) T.copy(dA_fragment, dA_shared) T.gemm(A_shared, dA_shared, dA_fragment, clear_accum=True, transpose_A=True) for i_s1, i_s2 in T.Parallel(block_S, block_S): 
- with T.If(i_s1 <= i_s2): - with T.Then(): - dA_fragment[i_s1, i_s2] = 0 - with T.Else(): - dA_fragment[i_s1, i_s2] = -dA_fragment[i_s1, i_s2] + dA_fragment[i_s1, i_s2] = T.if_then_else( + i_s1 <= i_s2, + 0, + -dA_fragment[i_s1, i_s2], + ) for i_s1, i_s2 in T.Parallel(block_S, block_S): - with T.If(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0): - with T.Then(): - dA_fragment[i_s1, i_s2] *= T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]) - with T.Else(): - dA_fragment[i_s1, i_s2] = 0 + dA_fragment[i_s1, i_s2] = T.if_then_else( + G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh] <= 0, + dA_fragment[i_s1, i_s2] * T.exp(G[bb, bs * block_S + i_s1, bh] - G[bb, bs * block_S + i_s2, bh]), + 0, + ) T.copy(dA_fragment, dA_shared) # acceptable dA diff diff --git a/examples/gdn/test_example_gdn_compilation.py b/examples/gdn/test_example_gdn_compilation.py index 6f9fa5d2f7..b3d255a70a 100644 --- a/examples/gdn/test_example_gdn_compilation.py +++ b/examples/gdn/test_example_gdn_compilation.py @@ -50,6 +50,7 @@ def test_example_wy_fast_compilation(): ) print(kernel.get_kernel_source()) W_tilelang, U_tilelang = kernel(K, V, Beta, G, A) + torch.cuda.synchronize() def test_example_wy_fast_bwd_split_compilation(): @@ -317,4 +318,4 @@ def test_example_chunk_delta_bwd_compilation(): if __name__ == "__main__": # tilelang.testing.main() - test_example_chunk_delta_bwd_compilation() + test_example_wy_fast_compilation() diff --git a/examples/gemm/example_gemm_autotune.py b/examples/gemm/example_gemm_autotune.py index 016d448a4c..052bd64c6d 100644 --- a/examples/gemm/example_gemm_autotune.py +++ b/examples/gemm/example_gemm_autotune.py @@ -107,7 +107,13 @@ def get_configs(M, N, K, with_roller=False, topk=20): return configs -def get_best_config(M, N, K, with_roller=False): +def get_best_config( + M, + N, + K, + with_roller: bool = False, + profile_backend: str = "event", +): def kernel( block_M=None, block_N=None, @@ -156,6 +162,7 @@ def main( supply_type=tl.TensorSupplyType.Integer, ref_prog=ref_program, skip_check=False, + backend=profile_backend, ) ) return autotuner.run(warmup=3, rep=20) @@ -207,10 +214,22 @@ def gemm_autotune( return gemm_autotune -def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False, with_roller: bool = False): - use_autotune = True +def main( + M: int = 4096, + N: int = 4096, + K: int = 4096, + use_autotune: bool = False, + with_roller: bool = False, + profile_backend: str = "event", +): if use_autotune: - result = get_best_config(M, N, K, with_roller) + result = get_best_config( + M, + N, + K, + with_roller=with_roller, + profile_backend=profile_backend, + ) print(result.config) kernel = result.kernel else: @@ -219,8 +238,13 @@ def main(M: int = 4096, N: int = 4096, K: int = 4096, use_autotune: bool = False # benchmark profiler = kernel.get_profiler(tensor_supply_type=tl.TensorSupplyType.Auto) - tilelang_latency = profiler.do_bench() - ref_latency = profiler.do_bench(ref_program) + tilelang_latency = profiler.do_bench( + backend=profile_backend, + ) + ref_latency = profiler.do_bench( + ref_program, + backend=profile_backend, + ) profiler.assert_allclose(ref_program, atol=1e-2, rtol=1e-2) print(f"TileLang latency: {tilelang_latency}") print(f"Ref latency: {ref_latency}") @@ -242,5 +266,13 @@ def run_regression_perf(M: int = 4096, N: int = 4096, K: int = 4096): parser.add_argument("--k", type=int, default=4096, help="Matrix dimension K") parser.add_argument("--use_autotune", action="store_true", 
default=False, help="Whether to use autotune for matmul configs") parser.add_argument("--with_roller", action="store_true", default=False, help="Whether to enable BitBLAS roller for search space") + parser.add_argument("--profile_backend", type=str, default="event", help="Profiler backend") args = parser.parse_args() - main(args.m, args.n, args.k, args.use_autotune, args.with_roller) + main( + args.m, + args.n, + args.k, + args.use_autotune, + args.with_roller, + args.profile_backend, + ) diff --git a/examples/gemm/example_gemm_intrinsics.py b/examples/gemm/example_gemm_intrinsics.py index d4bc9480ff..15e552587e 100644 --- a/examples/gemm/example_gemm_intrinsics.py +++ b/examples/gemm/example_gemm_intrinsics.py @@ -6,7 +6,6 @@ from tilelang.intrinsics.mma_macro_generator import ( TensorCoreIntrinEmitter, ) -from tilelang.transform import simplify_prim_func def make_swizzle_layout(shared_buf): @@ -25,7 +24,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, diff --git a/examples/gemm/regression_example_gemm.py b/examples/gemm/regression_example_gemm.py index 3583cf16ac..4976020598 100644 --- a/examples/gemm/regression_example_gemm.py +++ b/examples/gemm/regression_example_gemm.py @@ -2,7 +2,6 @@ import example_gemm import example_gemm_autotune import example_gemm_intrinsics -import example_gemm_schedule def regression_example_gemm_autotune(): @@ -13,10 +12,6 @@ def regression_example_gemm_intrinsics(): tilelang.testing.process_func(example_gemm_intrinsics.run_regression_perf, M=1024, N=1024, K=1024) -def regression_example_gemm_schedule(): - tilelang.testing.process_func(example_gemm_schedule.run_regression_perf) - - def regression_example_gemm(): tilelang.testing.process_func(example_gemm.run_regression_perf) diff --git a/examples/gemm/test_example_gemm.py b/examples/gemm/test_example_gemm.py index 5f69364be6..fb0ae3ab4b 100644 --- a/examples/gemm/test_example_gemm.py +++ b/examples/gemm/test_example_gemm.py @@ -1,7 +1,6 @@ import tilelang.testing import example_gemm_autotune import example_gemm_intrinsics -import example_gemm_schedule import example_gemm @@ -14,10 +13,6 @@ def test_example_gemm_intrinsics(): example_gemm_intrinsics.main(M=1024, N=1024, K=1024) -def test_example_gemm_schedule(): - example_gemm_schedule.main() - - def test_example_gemm(): example_gemm.main() diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd.py b/examples/gemm_fp8/example_tilelang_gemm_amd.py index 93f8c4980c..16a9d5f329 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_amd.py +++ b/examples/gemm_fp8/example_tilelang_gemm_amd.py @@ -2,6 +2,7 @@ import tilelang import tilelang.language as T from tilelang.utils.tensor import torch_assert_close +from tilelang.utils import determine_fp8_type, determine_torch_fp8_type import itertools @@ -17,8 +18,9 @@ def supply_prog(args): a_param, b_param = args M, K = a_param.shape N, _ = b_param.shape - a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) return [a, b] @@ -53,7 +55,7 @@ def get_configs(): ) @tilelang.jit(out_idx=[-1]) def fp8_matmul(M, N, K, block_M, block_N, block_K, num_stages, num_threads, k_pack, gemm_type): - dtype = T.float8_e4m3fnuz + 
dtype = determine_fp8_type() accum_dtype = T.float32 @T.prim_func @@ -104,8 +106,9 @@ def gemm_fp8_ss( def test_gemm_fp8(M, N, K): kernel = fp8_matmul(M, N, K) - a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) - b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=torch.float8_e4m3fnuz) + fp8_dtype = determine_torch_fp8_type() + a = (torch.randn(M, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) + b = (torch.randn(N, K, dtype=torch.float16, device="cuda") * 0.01).to(dtype=fp8_dtype) c = kernel(a, b) ref_c = ref_program(a, b) torch_assert_close(c, ref_c, rtol=1e-2, atol=1e-2) diff --git a/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py new file mode 100644 index 0000000000..fc7fb44003 --- /dev/null +++ b/examples/gemm_fp8/example_tilelang_gemm_amd_fp8_preshuffle.py @@ -0,0 +1,225 @@ +import torch +import itertools +import tilelang +import tilelang.testing +from tilelang import tvm as tvm +import tilelang.language as T +from tilelang.tileop.base import GemmWarpPolicy +from tilelang.layout import make_swizzled_layout +from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter +from tilelang.utils import determine_fp8_type + +tilelang.testing.set_random_seed(0) + + +def get_configs(): + block_Ms = [32, 64, 128] + block_Ns = [32, 64, 128] + block_Ks = [64, 128] + num_stages = [0, 1, 2] + + valid_configs = [] + + for m, n, k, stages in itertools.product(block_Ms, block_Ns, block_Ks, num_stages): + valid_configs.append( + { + "block_M": m, + "block_N": n, + "block_K": k, + "num_stages": stages, + } + ) + return valid_configs + + +@tilelang.autotune( + configs=get_configs(), +) +@tilelang.jit(out_idx=[-1]) +def tl_matmul( + M, + N, + K, + block_M, + block_N, + block_K, + num_stages, + k_pack=2, + num_threads=256, + in_dtype=None, + out_dtype=T.float32, + accum_dtype=T.float32, + a_transposed=False, + b_transposed=True, +): + if in_dtype is None: + in_dtype = determine_fp8_type() + b_preshuffle = True + warp_size = 64 + num_warps = num_threads // warp_size + + policy = GemmWarpPolicy.Square + m_warp, n_warp = policy.compute_warp_partition(block_M, block_N, num_warps) + + shared_scope = "shared" + warp_row_tiles = block_M // m_warp + warp_col_tiles = block_N // n_warp + + # MMA Wrapper to Auto Generate Code for MMA + mfma_emitter = MatrixCorePreshuffleIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + block_row_warps=m_warp, + block_col_warps=n_warp, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=block_K, + k_pack=k_pack, + b_preshuffle=b_preshuffle, + ) + local_size_a = mfma_emitter.local_size_a + local_size_b = mfma_emitter.local_size_b + + warp_rows = mfma_emitter.warp_rows + warp_cols = mfma_emitter.warp_cols + + micro_size_y = mfma_emitter.micro_size_y + micro_size_k = mfma_emitter.micro_size_k + pack_size_k = micro_size_k * k_pack + + A_shape = (K, M) if a_transposed else (M, K) + A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K) + + B_shape = ( + (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k) + if b_transposed + else (K // pack_size_k, N // micro_size_y, pack_size_k, micro_size_y) + ) + + @T.prim_func + def main( + A: T.Tensor(A_shape, in_dtype), + B: T.Tensor(B_shape, in_dtype), + C: T.Tensor((M, N), out_dtype), + ): + with 
T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by): + A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope) + A_local = T.alloc_local((warp_rows * local_size_a * k_pack), in_dtype) + B_local = T.alloc_local((warp_cols * local_size_b * k_pack), in_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + + T.annotate_layout( + { + A_shared: make_swizzled_layout(A_shared), + C_local: mfma_emitter.make_mfma_store_layout(C_local), + } + ) + + num_ko = K // block_K + num_ki = block_K // (k_pack * micro_size_k) + + # Improve L2 Cache + # T.use_swizzle(panel_size=10) + T.clear(C_local) + for ko in T.Pipelined(num_ko, num_stages=num_stages): + # Load A into shared memory + if a_transposed: + T.copy(A[ko * block_K, by * block_M], A_shared) + else: + T.copy(A[by * block_M, ko * block_K], A_shared) + + for ki in T.serial(0, num_ki): + mfma_emitter.ldmatrix_a( + A_local, + A_shared, + ki, + ) + mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx) + + # Perform Matrix Multiplication + mfma_emitter.mfma(A_local, B_local, C_local, ki) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def shuffle_weight( + x: torch.Tensor, + layout=(16, 32), + k_pack=1, + is_transpose=False, +) -> torch.Tensor: + IN, IK = layout + BK = IK * k_pack + BN = IN + + N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2]) + assert N % BN == 0 + assert K % BK == 0 + + x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN) + x = x.permute(0, 2, 1, 3) + return x.contiguous() + + +def assert_tl_matmul_correctness(M, N, K, k_pack=1, a_transposed=False, b_transposed=True): + in_dtype = determine_fp8_type() + out_dtype = T.float32 + accum_dtype = T.float32 + kernel = tl_matmul( + M, + N, + K, + k_pack=k_pack, + in_dtype=in_dtype, + out_dtype=out_dtype, + accum_dtype=accum_dtype, + a_transposed=a_transposed, + b_transposed=b_transposed, + ) + + src_code = kernel.get_kernel_source() + # src_code is the generated cuda source + assert src_code is not None + A_shape = (K, M) if a_transposed else (M, K) + B_shape = (N, K) if b_transposed else (K, N) + + A = (torch.rand(A_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype)) + B = (torch.rand(B_shape, device="cuda", dtype=torch.float16) / 10).to(getattr(torch, in_dtype)) + + B_preshuffle = shuffle_weight(B, k_pack=k_pack, is_transpose=b_transposed) + C = kernel(A, B_preshuffle) + + profiler = kernel.get_profiler() + latency = profiler.do_bench() + + # Ensure that the latency is not None + assert latency is not None + print("time: ", latency) + + if a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.half(), B.T.half()).to(getattr(torch, out_dtype)) + elif a_transposed and not b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.T.half(), B.half()).to(getattr(torch, out_dtype)) + elif not a_transposed and b_transposed: + # Get Reference Result + ref_c = torch.matmul(A.half(), B.T.half()).to(getattr(torch, out_dtype)) + else: + # Get Reference Result + ref_c = torch.matmul(A.half(), B.half()).to(getattr(torch, out_dtype)) + + torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) + + +def test_assert_tl_matmul(): + assert_tl_matmul_correctness(512, 512, 512, k_pack=2) + + +if __name__ == "__main__": + test_assert_tl_matmul() diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8.py b/examples/gemm_fp8/example_tilelang_gemm_fp8.py index 
0869979756..3b575c78e8 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8.py @@ -1,6 +1,7 @@ import torch import tilelang import tilelang.language as T +from tilelang.utils import determine_fp8_type def calc_diff(x, y): @@ -55,21 +56,24 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 1024, T.float8_e4m3fn) - test_gemm_fp8(1024, 1024, 1024, T.float8_e5m2) + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 1024, determine_fp8_type("e5m2")) def run_regression_perf(): M, N, K = 4096, 4096, 4096 - dtype = "float8_e4m3" + dtype = determine_fp8_type() kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - dtype = "float8_e5m2" - kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) - profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") - return (latency_e4m3 + latency_e5m2) / 2 + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + latency_e4m3 = profiler_e4m3.do_bench() + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py index a702e8ae0a..39c6fc333c 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_2xAcc.py @@ -1,6 +1,7 @@ import torch import tilelang import tilelang.language as T +from tilelang.utils import determine_fp8_type @tilelang.jit(out_idx=[-1]) @@ -73,21 +74,26 @@ def test_gemm_fp8(M, N, K, dtype): def main(): - test_gemm_fp8(1024, 1024, 8192, T.float8_e4m3fn) - test_gemm_fp8(1024, 1024, 8192, T.float8_e5m2) + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type()) + test_gemm_fp8(1024, 1024, 8192, determine_fp8_type("e5m2")) def run_regression_perf(): M, N, K = 1024, 1024, 8192 - dtype = "float8_e4m3" + dtype = determine_fp8_type() kernel_e4m3 = matmul(M, N, K, 128, 128, 64, dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - dtype = "float8_e5m2" - kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) - profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") - return (latency_e4m3 + latency_e5m2) / 2 + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() + if torch.version.hip is None: + dtype = determine_fp8_type("e5m2") + kernel_e5m2 = matmul(M, N, K, 128, 128, 64, dtype) + profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) + latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") + return (latency_e4m3 + latency_e5m2) / 2 + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py index 762885ec38..d9f749d9f2 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py +++ 
b/examples/gemm_fp8/example_tilelang_gemm_fp8_intrinsic.py @@ -4,11 +4,9 @@ from tvm import DataType import tilelang.language as T from tilelang.intrinsics import get_swizzle_layout -from tilelang.intrinsics.mma_macro_generator import ( - TensorCoreIntrinEmitter, -) -from tilelang.transform import simplify_prim_func -from tilelang.utils.tensor import map_torch_type +from tilelang.intrinsics.mma_macro_generator import TensorCoreIntrinEmitter +from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter +from tilelang.utils import determine_fp8_type tilelang.testing.set_random_seed(0) @@ -29,7 +27,6 @@ def transform_func(i, j): @tilelang.jit(out_idx=[2]) -@simplify_prim_func def tl_matmul( M, N, @@ -41,26 +38,17 @@ def tl_matmul( assert in_dtype in [ T.float16, T.float8_e4m3fn, + T.float8_e4m3fnuz, T.float8_e5m2, + T.float8_e5m2fnuz, T.int8, - ], "Currently only float16 and int8 are supported" + ], "Currently only float16, float8, and int8 are supported" assert out_dtype in [ T.float16, T.float32, T.int32, ], "Currently only float16, float32 and int32 are supported" - micro_size_x = micro_size_y = micro_size_k = 16 - - is_float8 = in_dtype in [ - T.float8_e4m3fn, - T.float8_e5m2, - T.float8_e4m3fn, - T.float8_e5m2fnuz, - ] - if out_dtype == T.int32 or is_float8: - micro_size_k = 32 - # This is a debug config block_row_warps = 2 block_col_warps = 2 @@ -80,6 +68,38 @@ def tl_matmul( B_shape = (N, K) A_shared_shape = (block_M, block_K) B_shared_shape = (block_N, block_K) + is_hip = torch.version.hip is not None + # MMA Wrapper to Auto Generate Code for MMA/MFMA + if is_hip: + mma_emitter = MatrixCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + else: + mma_emitter = TensorCoreIntrinEmitter( + a_dtype=in_dtype, + b_dtype=in_dtype, + accum_dtype=accum_dtype, + a_transposed=False, + b_transposed=True, + block_row_warps=block_row_warps, + block_col_warps=block_col_warps, + warp_row_tiles=warp_row_tiles, + warp_col_tiles=warp_col_tiles, + chunk=chunk, + ) + + micro_size_x = mma_emitter.M_DIM + micro_size_y = getattr(mma_emitter, "n_dim", getattr(mma_emitter, "N_DIM", micro_size_x)) + micro_size_k = mma_emitter.k_dim C_shared_shape = ( block_M // micro_size_x, block_N // micro_size_y, @@ -87,27 +107,12 @@ def tl_matmul( micro_size_y, ) - warp_size = 32 - threads = warp_size * (block_row_warps * block_col_warps) - local_size_a = (micro_size_x * micro_size_k) // warp_size - local_size_b = (micro_size_y * micro_size_k) // warp_size - local_size_c = (micro_size_x * micro_size_y) // warp_size - warp_rows = warp_row_tiles // micro_size_x - warp_cols = warp_col_tiles // micro_size_y - - # MMA Wrapper to Auto Generate Code for MMA - mma_emitter = TensorCoreIntrinEmitter( - a_dtype=in_dtype, - b_dtype=in_dtype, - accum_dtype=accum_dtype, - a_transposed=False, - b_transposed=True, - block_row_warps=block_row_warps, - block_col_warps=block_col_warps, - warp_row_tiles=warp_row_tiles, - warp_col_tiles=warp_col_tiles, - chunk=chunk, - ) + threads = mma_emitter.threads + local_size_a = mma_emitter.local_size_a + local_size_b = mma_emitter.local_size_b + local_size_c = mma_emitter.local_size_out + warp_rows = mma_emitter.warp_rows + warp_cols = mma_emitter.warp_cols @T.prim_func def gemm_fp8_intrinsic( @@ -160,7 +165,10 @@ def gemm_fp8_intrinsic( ) # Perform Matrix 
Multiplication - mma_emitter.mma(A_local, B_local, C_local) + if is_hip: + mma_emitter.mfma(A_local, B_local, C_local, ki) + else: + mma_emitter.mma(A_local, B_local, C_local) # Perform STMatrix mma_emitter.stmatrix( @@ -183,18 +191,17 @@ def gemm_fp8_intrinsic( def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): kernel = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) src_code = kernel.get_kernel_source() - print(src_code) # src_code is the generated cuda source assert src_code is not None - in_dtype = map_torch_type(in_dtype) - out_dtype = map_torch_type(out_dtype) - accum_dtype = map_torch_type(accum_dtype) + in_dtype = in_dtype.as_torch() + out_dtype = out_dtype.as_torch() + accum_dtype = accum_dtype.as_torch() if in_dtype in {torch.int8, torch.int32}: A = torch.randint(-128, 128, (M, K), dtype=torch.int8).to(in_dtype).cuda() B = torch.randint(-128, 128, (N, K), dtype=torch.int8).to(in_dtype).cuda() - elif in_dtype in {torch.float8_e4m3fn, torch.float8_e5m2}: + elif in_dtype in {torch.float8_e4m3fn, torch.float8_e4m3fnuz, torch.float8_e5m2, torch.float8_e5m2fnuz}: A = torch.randn(M, K).to(in_dtype).cuda() B = torch.randn(N, K).to(in_dtype).cuda() else: @@ -214,28 +221,27 @@ def assert_tl_matmul_correctness(M, N, K, in_dtype, out_dtype, accum_dtype): # Get Reference Result ref_c = torch.matmul(A.to(accum_dtype), B.T.to(accum_dtype)).to(out_dtype) - print(C) - print(ref_c) torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2) def main(): - assert_tl_matmul_correctness(128, 128, 128, T.float8_e4m3fn, T.float32, T.float32) - assert_tl_matmul_correctness(128, 128, 128, T.float8_e5m2, T.float32, T.float32) + e4m3_dtype = determine_fp8_type() + assert_tl_matmul_correctness(128, 128, 128, e4m3_dtype, T.float32, T.float32) + e5m2_dtype = determine_fp8_type("e5m2") + assert_tl_matmul_correctness(128, 128, 128, e5m2_dtype, T.float32, T.float32) def run_regression_perf(): M, N, K = 4096, 4096, 4096 - out_dtype, accum_dtype = "float32", "float32" - in_dtype = T.float8_e4m3fn + out_dtype, accum_dtype = T.float32, T.float32 + in_dtype = determine_fp8_type() kernel_e4m3 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) profiler_e4m3 = kernel_e4m3.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") - in_dtype = T.float8_e5m2 - kernel_e5m2 = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype) - profiler_e5m2 = kernel_e5m2.get_profiler(tilelang.TensorSupplyType.Integer) - latency_e5m2 = profiler_e5m2.do_bench(backend="cupti") - return (latency_e4m3 + latency_e5m2) / 2 + if torch.version.hip is None: + latency_e4m3 = profiler_e4m3.do_bench(backend="cupti") + else: + latency_e4m3 = profiler_e4m3.do_bench() + return latency_e4m3 if __name__ == "__main__": diff --git a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py index aa7e8b3608..72f09c2503 100644 --- a/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py +++ b/examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py @@ -1,7 +1,6 @@ import torch import tilelang import tilelang.language as T -from tilelang.utils.tensor import map_torch_type def matmul( @@ -41,14 +40,13 @@ def main( for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): T.copy(A[by * block_M, k * block_K], A_shared) T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm_v2( + T.tcgen05_gemm( A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, - wg_wait=-1, clear_accum=(k == 0), ) T.mbarrier_wait_parity(mbar, k % 2) @@ -75,8 
+73,8 @@ def calc_diff(x, y): threads = 256 for tvm_fp8_dtype in [T.float8_e4m3fn, T.float8_e5m2]: for tvm_acc_dtype in [T.float16, T.float32]: # , torch.float16]: - torch_fp8_dtype = map_torch_type(tvm_fp8_dtype) - torch_acc_dtype = map_torch_type(tvm_acc_dtype) + torch_fp8_dtype = tvm_fp8_dtype.as_torch() + torch_acc_dtype = tvm_acc_dtype.as_torch() print(f"running {tvm_fp8_dtype} -> {tvm_acc_dtype}") in_dtype, out_dtype, accum_dtype = tvm_fp8_dtype, tvm_acc_dtype, tvm_acc_dtype @@ -100,7 +98,6 @@ def calc_diff(x, y): out_idx=[2], target="cuda", pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, tilelang.PassConfigKey.TL_ENABLE_PTXAS_VERBOSE_OUTPUT: True, }, diff --git a/examples/gemm_int4/example_tilelang_gemm_int4.py b/examples/gemm_int4/example_tilelang_gemm_int4.py new file mode 100644 index 0000000000..3db000b616 --- /dev/null +++ b/examples/gemm_int4/example_tilelang_gemm_int4.py @@ -0,0 +1,113 @@ +"""Frontend int4 GEMM example for the T.gemm int4 path. + +This file intentionally models the desired TileLang frontend API: +- A/B are declared as T.int4 tensors +- the matmul is expressed with T.gemm(...) + +The example compiles the kernel, prints the generated CUDA source, and +checks correctness against a PyTorch reference. +""" + +import torch + +import tilelang +import tilelang.language as T + + +def matmul_nt_int4(M, N, K, block_M, block_N, block_K, threads=128): + @T.prim_func + def main( + A: T.Tensor((M, K), T.int4), + B: T.Tensor((N, K), T.int4), + C: T.Tensor((M, N), T.int32), + ): + with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), T.int4) + B_shared = T.alloc_shared((block_N, block_K), T.int4) + C_local = T.alloc_fragment((block_M, block_N), T.int32) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=3): + T.copy(A[by * block_M, ko * block_K], A_shared) + T.copy(B[bx * block_N, ko * block_K], B_shared) + # Frontend expectation: T.gemm should accept int4 operands directly. 
+ T.gemm(A_shared, B_shared, C_local, transpose_B=True) + + T.copy(C_local, C[by * block_M, bx * block_N]) + + return main + + +def compile_int4_gemm( + M=1024, + N=1024, + K=1024, + block_M=128, + block_N=128, + block_K=64, + threads=128, + print_cuda_source=True, +): + func = matmul_nt_int4(M, N, K, block_M, block_N, block_K, threads) + kernel = tilelang.compile(func, out_idx=-1) + print("Compilation succeeded.") + if print_cuda_source: + print(kernel.get_kernel_source()) + return func, kernel + + +def pack_int4(tensor: torch.Tensor) -> torch.Tensor: + if tensor.dtype != torch.int8: + raise TypeError(f"Expected torch.int8 logical int4 tensor, but got {tensor.dtype}.") + if tensor.ndim == 0 or tensor.shape[-1] % 2 != 0: + raise ValueError("The last dimension of a logical int4 tensor must be even for int8 packing.") + + tensor_i16 = tensor.to(torch.int16) + packed = (tensor_i16[..., ::2] & 0x0F) | ((tensor_i16[..., 1::2] & 0x0F) << 4) + return packed.to(torch.int8).contiguous() + + +def check_int4_gemm_correctness( + M=1024, + N=1024, + K=1024, + block_M=128, + block_N=128, + block_K=64, + threads=128, +): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA is required to run the int4 GEMM example.") + + _, kernel = compile_int4_gemm( + M=M, + N=N, + K=K, + block_M=block_M, + block_N=block_N, + block_K=block_K, + threads=threads, + ) + + A_logical = torch.randint(-8, 8, (M, K), device="cuda", dtype=torch.int8) + B_logical = torch.randint(-8, 8, (N, K), device="cuda", dtype=torch.int8) + + A_packed = pack_int4(A_logical) + B_packed = pack_int4(B_logical) + C = kernel(A_packed, B_packed) + torch.cuda.synchronize() + + ref_c = torch.matmul(A_logical.cpu().to(torch.int32), B_logical.cpu().to(torch.int32).T) + torch.testing.assert_close(C.cpu(), ref_c, rtol=0, atol=0) + print("Correctness check passed.") + return C, ref_c + + +def main(): + # check_int4_gemm_correctness(M=16, N=16, K=32, block_M=16, block_N=16, block_K=32) + # check_int4_gemm_correctness(M=16, N=16, K=64, block_M=16, block_N=16, block_K=64) + check_int4_gemm_correctness() + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sm100/README.md b/examples/gemm_sm100/README.md index d630d2d0d3..3ae66dde91 100644 --- a/examples/gemm_sm100/README.md +++ b/examples/gemm_sm100/README.md @@ -7,17 +7,30 @@ This directory contains examples for TileLang's experimental SM100 architecture ### 1. Manual TCGEN5.MMA Management Users must manually handle TCGEN5MMA operations using: - `T.alloc_tmem()` - Allocate Tensor Memory -- `T.gemm()` with `wg_wait=-1` - Launch TCGEN5MMA without waiting +- `T.tcgen05_gemm()` - Launch TCGEN5MMA without an implicit wait - Manual synchronization with mbarrier +For the default synchronous path, `T.gemm(..., mbar=...)` now inserts the +matching `mbarrier_wait_parity(...)` automatically after TCGEN5MMA issue. + ### 2. Manual mbarrier Synchronization TCGEN5MMA is asynchronous and requires explicit synchronization: ```python mbar = T.alloc_barrier(1) # expect-arrive-count = 1 -T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k==0) +T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, clear_accum=k==0) T.mbarrier_wait_parity(mbar, k%2) # Manual phase calculation required ``` +TileLang now has a conservative `InjectTcgen05Fence` pass on SM100+ that can +insert `tcgen05_before_thread_sync()` / `tcgen05_after_thread_sync()` around: +- `tvm_storage_sync("shared"|"shared.dyn")` +- linear `mbarrier_wait_parity(...) 
-> tcgen05/TMEM use` regions +- linear `tcgen05/TMEM use -> mbarrier_arrive(...)` regions + +This does **not** eliminate the need to structure the mbarrier protocol +explicitly in user code, and the examples in this directory still keep manual +fences where they make the handoff points obvious. + ## Examples ### TCGEN5MMA Example (`gemm_tcgen5mma.py`) @@ -61,8 +74,8 @@ def main( T.copy(B[bx * block_N, k * block_K], B_shared) # TCGEN5MMA computation: asynchronous launch, output to Tensor Memory - T.gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, - mbar=mbar, wg_wait=-1, clear_accum=k==0) + T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A=False, trans_B=True, + mbar=mbar, clear_accum=k==0) # Critical: wait for TCGEN5MMA completion T.mbarrier_wait_parity(mbar, k%2) @@ -84,7 +97,6 @@ block_M, block_N, block_K = 128, 256, 128 # Compile kernel jit_kernel = tilelang.compile(func, out_idx=[2], target="cuda", pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, # Required tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, # Required }) diff --git a/examples/gemm_sm100/gemm_mma.py b/examples/gemm_sm100/gemm_mma.py index 226e33c01e..e3a70df973 100644 --- a/examples/gemm_sm100/gemm_mma.py +++ b/examples/gemm_sm100/gemm_mma.py @@ -58,10 +58,7 @@ def main( func, out_idx=[2], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(jit_kernel.get_kernel_source()) # 3. Test the kernel in Python with PyTorch data diff --git a/examples/gemm_sm100/gemm_tcgen5mma.py b/examples/gemm_sm100/gemm_tcgen5mma.py index 523a94fea6..229908992a 100644 --- a/examples/gemm_sm100/gemm_tcgen5mma.py +++ b/examples/gemm_sm100/gemm_tcgen5mma.py @@ -38,9 +38,9 @@ def main( C_shared = T.alloc_shared((block_M, block_N), out_dtype) for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages): - T.copy(A[by * block_M, k * block_K], A_shared) - T.copy(B[bx * block_N, k * block_K], B_shared) - T.gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, wg_wait=-1, clear_accum=k == 0) + T.copy(A[by * block_M, k * block_K], A_shared) # not trans_A + T.copy(B[bx * block_N, k * block_K], B_shared) # trans_B + T.tcgen05_gemm(A_shared, B_shared, C_tmem, trans_A, trans_B, mbar=mbar, clear_accum=k == 0) T.mbarrier_wait_parity(mbar, k % 2) T.copy(C_tmem, C_local) @@ -52,10 +52,10 @@ def main( M, N, K = 4096, 4096, 8192 -block_M, block_N, block_K = 128, 256, 128 +block_M, block_N, block_K = 128, 128, 128 trans_A, trans_B = False, True in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float -num_stages = 2 +num_stages = 0 if block_N >= 256 or block_M >= 256 or block_K >= 256 else 2 threads = 256 func = matmul(M, N, K, block_M, block_N, block_K, trans_A, trans_B, in_dtype, out_dtype, accum_dtype, num_stages, threads) @@ -63,10 +63,7 @@ def main( func, out_idx=[2], target="cuda", - pass_configs={ - tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, - tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, - }, + pass_configs={tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True}, ) print(jit_kernel.get_kernel_source()) diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws.py b/examples/gemm_sm100/gemm_tcgen5mma_ws.py new file mode 100644 index 0000000000..b8f2adf41a --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws.py @@ -0,0 +1,164 @@ +# Non-persistent + +import torch +import tilelang +import 
tilelang.language as T +from tilelang.profiler import do_bench + + +@tilelang.jit +def gemm(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, use_tma_store=True): + M, N, K = T.const("M, N, K") + + k_iters = T.ceildiv(K, block_K) + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128) as (bx, by): + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + loaded = T.alloc_barrier([32] * num_stages) + consumed = T.alloc_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + + T.use_swizzle(8) + + if tx < 32: # warp 0: issue tma + for k in T.serial(k_iters): + T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], + B_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + T.mbarrier_arrive(loaded[k % num_stages]) + elif tx < 64: # warp 1: issue tcgen5 + for k in T.serial(k_iters): + T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) + T.tcgen05_gemm( + A_shared[k % num_stages, :, :], + B_shared[k % num_stages, :, :], + C_tmem, + mbar=consumed[k % num_stages], + clear_accum=k == 0, + ) + T.tcgen05_mma_arrive(tmem_full) + + # Wait for all tcgen5 to finish + T.mbarrier_wait_parity(tmem_full, 0) + T.copy(C_tmem, C_local) + if use_tma_store: + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) # STG256 + return C + + +@tilelang.jit +def gemm_2cta(A, B, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, use_tma_store=True): + M, N, K = T.const("M, N, K") + + k_iters = T.ceildiv(K, block_K) + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + with T.Kernel(T.ceildiv(M, block_M), T.ceildiv(N, block_N), threads=128, cluster_dims=2) as (bx, by): + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) # Each cta hold half of B + C_tmem = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_shared = T.alloc_shared((block_M, block_N), out_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + loaded = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1]) + + tx = T.get_thread_binding() + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) # todo: automatically assume this + + T.use_swizzle(16) # TL will perform auto threadblock swizzle with cluster + + if tx < 32: # warp 0: issue tma + for k in T.serial(k_iters): + T.mbarrier_wait_parity(consumed[k % num_stages], ((k // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 
1) * block_K], + A_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * (block_N // 2) : (by * 2 + cta_id + 1) * (block_N // 2)], + B_shared[k % num_stages, :, :], + barrier=loaded[k % num_stages], + ) + T.mbarrier_arrive(loaded[k % num_stages], 0) # arrive on leader cta's barrier + elif cta_id == 0 and tx < 64: # Only warp 1 on leader cta issues tcgen5 + for k in T.serial(k_iters): + T.mbarrier_wait_parity(loaded[k % num_stages], (k // num_stages) & 1) + T.tcgen05_gemm( + A_shared[k % num_stages, :, :], + B_shared[k % num_stages, :, :], + C_tmem, + mbar=consumed[k % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full, arrive_2cta=True) + + # Wait for all tcgen5 to finish + T.mbarrier_wait_parity(tmem_full, 0) + T.copy(C_tmem, C_local) + if use_tma_store: + T.copy(C_local, C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + return C + + +def main(): + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 64 + in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float + enable_2cta_tcgen5mma = True + num_stages = 6 if enable_2cta_tcgen5mma else 4 # Each cta only needs to load half of B, enabling larger stages + kernel = gemm_2cta if enable_2cta_tcgen5mma else gemm + + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + c = kernel(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages) + print(kernel.get_kernel_source(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)) + + ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed. 
✅") + + tl_latency = do_bench(lambda: kernel(a, b, block_M, block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti") + torch_latency = do_bench(lambda: a @ b, backend="cupti") + print(f"Tilelang latency: {tl_latency} ms") + print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") + print(f"Torch latency: {torch_latency} ms") + print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py new file mode 100644 index 0000000000..10ac66937a --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_clc.py @@ -0,0 +1,212 @@ +# Introduce CLC tile schedule + +import torch +import tilelang +import tilelang.language as T +from tilelang.profiler import do_bench + + +def get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id): + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + return bx, by + + +@tilelang.jit(pass_configs={tilelang.PassConfigKey.TL_DISABLE_THREAD_STORAGE_SYNC: True}) +def gemm_clc_persistent_2cta( + A, + B, + block_M, + block_N, + store_block_N, + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + group_size=8, + use_tma_store=True, +): + M, N, K = T.const("M, N, K") + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + m_blocks = T.ceildiv(M, block_M) + m_clusters = m_blocks // 2 + n_blocks = T.ceildiv(N, block_N) + total_cluster_tiles = m_clusters * n_blocks + k_blocks = T.ceildiv(K, block_K) + assert n_blocks % (2 * group_size) == 0 + + with T.Kernel(total_cluster_tiles * 2, threads=256, cluster_dims=2) as block_id: + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) + C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + loaded = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1] * 2) + tmem_empty = T.alloc_cluster_barrier([128 * 2] * 2) + schedule_arrived = T.alloc_cluster_barrier([1]) + schedule_finished = T.alloc_cluster_barrier([7]) + clc_result = T.alloc_shared((4,), "uint32", scope="shared") + schedule_valid = T.alloc_shared((1,), "int32") + schedule_tile_id = T.alloc_shared((1,), "int32") + + tx = T.get_thread_binding() + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) + + if tx < 32: # Producer (TMA loads) + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 0: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + tile_id = T.if_then_else( + work_iter == 0, + block_id // 2, + schedule_tile_id[0], + ) + bx, by = get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id) + + for k in T.serial(k_blocks): + phase = work_iter * k_blocks + k + T.mbarrier_wait_parity(consumed[phase % num_stages], ((phase // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[phase % 
num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * block_N // 2 : (by * 2 + cta_id + 1) * block_N // 2], + B_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.mbarrier_arrive(loaded[phase % num_stages], 0) + + elif cta_id == 0 and tx < 64: # MMA (cta_id 0 only) + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 32: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + T.mbarrier_wait_parity(tmem_empty[work_iter & 1], ((work_iter // 2) & 1) ^ 1) + for k in T.serial(k_blocks): + phase = work_iter * k_blocks + k + T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + if work_iter & 1 == 0: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_0, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + else: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_1, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full[work_iter & 1], arrive_2cta=True) + + elif 64 <= tx < 96: # CLC Scheduler (both CTAs) + for work_iter in T.unroll(total_cluster_tiles): + if tx == 64: + if cta_id == 0 and work_iter > 0: + T.mbarrier_wait_parity(schedule_finished, (work_iter - 1) & 1) + T.mbarrier_arrive_expect_tx(schedule_arrived, 16) + if cta_id == 0: + T.clc_try_cancel_multicast(clc_result, schedule_arrived) + T.mbarrier_wait_parity(schedule_arrived, work_iter & 1) + schedule_valid[0] = T.clc_is_canceled(clc_result) + schedule_tile_id[0] = T.cast(T.clc_get_first_ctaid_x(clc_result), "int32") // 2 + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + elif 128 <= tx < 256: # Epilogue + for work_iter in T.unroll(total_cluster_tiles): + if work_iter > 0: + T.mbarrier_wait_parity(schedule_arrived, (work_iter - 1) & 1) + if tx == 128: + T.mbarrier_arrive(schedule_finished, 0) + if schedule_valid[0] == 0: + break + + tile_id = T.if_then_else( + work_iter == 0, + block_id // 2, + schedule_tile_id[0], + ) + bx, by = get_swizzled_block_idx(tile_id, group_size, m_clusters, cta_id) + + T.mbarrier_wait_parity(tmem_full[work_iter & 1], (work_iter // 2) & 1) + T.sync_threads(1, 128) + if work_iter & 1 == 0: + T.copy(C_tmem_0, C_local) + else: + T.copy(C_tmem_1, C_local) + T.mbarrier_arrive(tmem_empty[work_iter & 1], 0) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.sync_threads(3, 128) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + T.sync_threads(3, 128) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + + return C + + +def main(): + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 64 + store_block_N = 64 + in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float + num_stages = 6 + l2_swizzle_group_size = 8 + + kernel_args = (block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages, l2_swizzle_group_size) + + # a = (torch.rand(M, K, device="cuda", dtype=torch.bfloat16) * 2 - 1) + # b = (torch.rand(K, N, device="cuda", dtype=torch.bfloat16) * 2 - 1) + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, 
N, device="cuda", dtype=torch.bfloat16) + print(gemm_clc_persistent_2cta.get_kernel_source(a, b, *kernel_args)) + c = gemm_clc_persistent_2cta(a, b, *kernel_args) + + ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed. ✅") + + tl_latency = do_bench(lambda: gemm_clc_persistent_2cta(a, b, *kernel_args), backend="cupti") + torch_latency = do_bench(lambda: a @ b, backend="cupti") + print(f"Tilelang latency: {tl_latency} ms") + print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") + print(f"Torch latency: {torch_latency} ms") + print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py new file mode 100644 index 0000000000..5a7d820220 --- /dev/null +++ b/examples/gemm_sm100/gemm_tcgen5mma_ws_persistent.py @@ -0,0 +1,298 @@ +# Persistent, num_epi_stages = 2 + +import torch +import tilelang +import tilelang.language as T +from tilelang.carver.arch import driver +from tilelang.profiler import do_bench + + +@tilelang.jit +def gemm_persistent( + A, + B, + block_M, + block_N, + store_block_N, # block_N for C_shared + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + use_tma_store=True, +): + M, N, K = T.const("M, N, K") + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + sm_num = driver.get_num_sms() + m_blocks = T.ceildiv(M, block_M) + n_blocks = T.ceildiv(N, block_N) + assert K % (2 * block_K) == 0 # for simplicity + k_blocks = T.ceildiv(K, block_K) + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 8 + assert n_blocks % (2 * group_size) == 0 # Please adjust group_size if not satisfied + + with T.Kernel(sm_num, threads=256) as (block_id): + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N), in_dtype) + C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + loaded = T.alloc_barrier([32] * num_stages) + consumed = T.alloc_barrier([1] * num_stages) + tmem_full = T.alloc_barrier([1] * 2) + tmem_empty = T.alloc_barrier([128] * 2) + + tx = T.get_thread_binding() + + if tx < 32: # warp 0: issue tma + for w in T.unroll(waves): + tile_id = sm_num * w + block_id + bx = (tile_id // group_size) % m_blocks + by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_blocks): + phase = w * k_blocks + k + T.mbarrier_wait_parity(consumed[phase % num_stages], ((phase // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.tma_copy( + B[k * block_K : (k + 1) * block_K, by * block_N : (by + 1) * block_N], + B_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.mbarrier_arrive(loaded[phase % num_stages]) + + elif tx < 64: # warp 1: issue tcgen5 + for w in T.unroll(waves): + tile_id = sm_num * w + block_id + bx = (tile_id // group_size) % m_blocks + by = (tile_id % group_size) + 
(tile_id // group_size) // m_blocks * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) + for k in T.serial(k_blocks): + phase = w * k_blocks + k + T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + if w & 1 == 0: + T.tcgen05_gemm( + A_shared[k % num_stages, :, :], + B_shared[k % num_stages, :, :], + C_tmem_0, + False, + False, + mbar=consumed[k % num_stages], + clear_accum=k == 0, + ) + else: + T.tcgen05_gemm( + A_shared[k % num_stages, :, :], + B_shared[k % num_stages, :, :], + C_tmem_1, + False, + False, + mbar=consumed[k % num_stages], + clear_accum=k == 0, + ) + T.tcgen05_mma_arrive(tmem_full[w & 1]) + + elif 128 <= tx < 256: # warp 4~7: epilogue + for w in T.unroll(waves): + tile_id = sm_num * w + block_id + bx = (tile_id // group_size) % m_blocks + by = (tile_id % group_size) + (tile_id // group_size) // m_blocks * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) + if (w & 1) == 0: + T.copy(C_tmem_0, C_local) + else: + T.copy(C_tmem_1, C_local) + T.mbarrier_arrive(tmem_empty[w & 1]) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + return C + + +@tilelang.jit +def gemm_persistent_2cta( + A, + B, + block_M, + block_N, + store_block_N, # block_N for C_shared + block_K, + in_dtype, + out_dtype, + accum_dtype, + num_stages, + use_tma_store=True, +): + M, N, K = T.const("M, N, K") + + A: T.Tensor[[M, K], in_dtype] + B: T.Tensor[[K, N], in_dtype] + C = T.empty((M, N), out_dtype) + + sm_num = driver.get_num_sms() + num_clusters = sm_num // 2 + m_blocks = T.ceildiv(M, block_M) + m_clusters = m_blocks // 2 + n_blocks = T.ceildiv(N, block_N) + assert K % (2 * block_K) == 0 # for simplicity + k_blocks = T.ceildiv(K, block_K) + waves = T.ceildiv(m_blocks * n_blocks, sm_num) + group_size = 8 # in cluster + assert n_blocks % (2 * group_size) == 0 # Please adjust group_size if not satisfied + + with T.Kernel(sm_num, threads=256, cluster_dims=2) as (block_id): + A_shared = T.alloc_shared((num_stages, block_M, block_K), in_dtype) + B_shared = T.alloc_shared((num_stages, block_K, block_N // 2), in_dtype) + C_tmem_0 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_tmem_1 = T.alloc_tmem([block_M, block_N], accum_dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + C_local_cast = T.alloc_fragment((block_M, block_N), out_dtype) + C_shared = T.alloc_shared((block_M, store_block_N), out_dtype) + loaded = T.alloc_cluster_barrier([32 * 2] * num_stages) + consumed = T.alloc_cluster_barrier([1] * num_stages) + tmem_full = T.alloc_cluster_barrier([1] * 2) + tmem_empty = T.alloc_cluster_barrier([128 * 2] * 2) + + tx = T.get_thread_binding() + cta_id = T.block_rank_in_cluster() + T.assume(cta_id < 2) # todo: automatically assume this + + if tx < 32: # warp 0: issue tma + for w in T.unroll(waves): + # manual threadblock swizzle + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + for k in T.serial(k_blocks): + phase = w * k_blocks + k + 
T.mbarrier_wait_parity(consumed[phase % num_stages], ((phase // num_stages) & 1) ^ 1) + T.tma_copy( + A[bx * block_M : (bx + 1) * block_M, k * block_K : (k + 1) * block_K], + A_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + + T.tma_copy( + B[k * block_K : (k + 1) * block_K, (by * 2 + cta_id) * block_N // 2 : (by * 2 + cta_id + 1) * block_N // 2], + B_shared[phase % num_stages, :, :], + barrier=loaded[phase % num_stages], + ) + T.mbarrier_arrive(loaded[phase % num_stages], 0) + + elif tx < 64 and cta_id == 0: # warp 1: issue tcgen5 + for w in T.unroll(waves): + # manual threadblock swizzle + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_empty[w & 1], ((w // 2) & 1) ^ 1) + for k in T.serial(k_blocks): + phase = w * k_blocks + k + T.mbarrier_wait_parity(loaded[phase % num_stages], (phase // num_stages) & 1) + if w & 1 == 0: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_0, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + else: + T.tcgen05_gemm( + A_shared[phase % num_stages, :, :], + B_shared[phase % num_stages, :, :], + C_tmem_1, + mbar=consumed[phase % num_stages], + clear_accum=k == 0, + use_2cta=True, + ) + T.tcgen05_mma_arrive(tmem_full[w & 1], arrive_2cta=True) + + elif 128 <= tx < 256: # warp 4~7: epilogue + for w in T.unroll(waves): + # manual threadblock swizzle + cluster_id = block_id // 2 + tile_id = num_clusters * w + cluster_id + bx_cluster = (tile_id // group_size) % m_clusters + bx = bx_cluster * 2 + cta_id + by = (tile_id % group_size) + (tile_id // group_size) // m_clusters * group_size + + if bx * block_M < M and by * block_N < N: + T.mbarrier_wait_parity(tmem_full[w & 1], (w // 2) & 1) + if (w & 1) == 0: + T.copy(C_tmem_0, C_local) + else: + T.copy(C_tmem_1, C_local) + T.mbarrier_arrive(tmem_empty[w & 1], 0) + + if use_tma_store: + for i in T.unroll(T.ceildiv(block_N, store_block_N)): + T.copy(C_local[:, i * store_block_N : (i + 1) * store_block_N], C_shared) + T.copy(C_shared, C[bx * block_M, by * block_N + i * store_block_N]) + else: + T.copy(C_local, C_local_cast) + T.copy(C_local_cast, C[bx * block_M, by * block_N]) + + return C + + +def main(): + M, N, K = 8192, 8192, 8192 + block_M, block_N, block_K = 128, 256, 64 + store_block_N = 64 + in_dtype, out_dtype, accum_dtype = T.bfloat16, T.bfloat16, T.float + enable_2cta_tcgen5mma = True + num_stages = 6 if enable_2cta_tcgen5mma else 4 # Each cta only needs to load half of B, enabling larger stages + kernel = gemm_persistent_2cta if enable_2cta_tcgen5mma else gemm_persistent + + a = torch.randn(M, K, device="cuda", dtype=torch.bfloat16) + b = torch.randn(K, N, device="cuda", dtype=torch.bfloat16) + print(kernel.get_kernel_source(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages)) + c = kernel(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages) + + ref_c = (a.to(torch.float) @ b.to(torch.float)).to(torch.bfloat16) + torch.testing.assert_close(c, ref_c, rtol=1e-2, atol=1e-2) + print("All checks passed. 
✅") + + tl_latency = do_bench( + lambda: kernel(a, b, block_M, block_N, store_block_N, block_K, in_dtype, out_dtype, accum_dtype, num_stages), backend="cupti" + ) + torch_latency = do_bench(lambda: a @ b, backend="cupti") + print(f"Tilelang latency: {tl_latency} ms") + print(f"Flops: {2 * M * N * K / (tl_latency / 1e3) / 1e12} TFLOPS") + print(f"Torch latency: {torch_latency} ms") + print(f"Flops: {2 * M * N * K / (torch_latency / 1e3) / 1e12} TFLOPS") + + +if __name__ == "__main__": + main() diff --git a/examples/gemm_sp/example_custom_compress.py b/examples/gemm_sp/example_custom_compress.py index 0544b82557..4b03ae83da 100644 --- a/examples/gemm_sp/example_custom_compress.py +++ b/examples/gemm_sp/example_custom_compress.py @@ -7,7 +7,7 @@ from tilelang.utils.sparse import randn_semi_sparse from tilelang.utils.tensor import torch_assert_close -from triton.testing import do_bench +from tilelang.profiler import do_bench import torch @@ -291,19 +291,17 @@ def kernel( return kernel -def main(m=16384, n=16384, k=16384, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=None, cfg="4090"): - if accum_dtype is None: - accum_dtype = T.float - kernel = matmul_sp_fp16_custom_compress(m, n, k, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout) +def main(M=1024, N=1024, K=1024, use_cutlass_layout=False, use_torch_compressor=False, accum_dtype=T.float, cfg="4090"): + kernel = matmul_sp_fp16_custom_compress(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype], use_cutlass_layout=use_cutlass_layout) - a = randn_semi_sparse(m, k, device="cuda", dtype=torch.half) - b = torch.randn(k, n, device="cuda", dtype=torch.half) + a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) + b = torch.randn(K, N, device="cuda", dtype=torch.half) if use_torch_compressor: assert not use_cutlass_layout, "torch sparse must be used with naive layout" a_sparse, e = torch_compress(a) else: - a_sparse, e = compress_kernel(m, k, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a) + a_sparse, e = compress_kernel(M, K, 32, 32, T.float16, use_cutlass_layout=use_cutlass_layout)(a) c = kernel(a_sparse, e, b) @@ -316,7 +314,7 @@ def main(m=16384, n=16384, k=16384, use_cutlass_layout=False, use_torch_compress latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) - total_flops = 2 * m * n * k + total_flops = 2 * M * N * K tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") @@ -330,8 +328,15 @@ def main(m=16384, n=16384, k=16384, use_cutlass_layout=False, use_torch_compress parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") parser.add_argument("--use_cutlass_layout", action="store_true", help="Use cutlass layout for E tensor") parser.add_argument("--use_torch_compressor", action="store_true", help="Use torch sparse for reference") - parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090"], default="4090") args = parser.parse_args() - accum_dtype = T.float if args.accum_dtype == "float" else T.float16 - main(args.m, args.n, args.k, args.use_cutlass_layout, args.use_torch_compressor, accum_dtype, args.cfg) + main( + M=args.m, + N=args.n, + K=args.k, + 
use_cutlass_layout=args.use_cutlass_layout, + use_torch_compressor=args.use_torch_compressor, + accum_dtype=args.accum_dtype, + cfg=args.cfg, + ) diff --git a/examples/gemm_sp/example_gemm_sp.py b/examples/gemm_sp/example_gemm_sp.py index 8163c84cc8..769ea67362 100644 --- a/examples/gemm_sp/example_gemm_sp.py +++ b/examples/gemm_sp/example_gemm_sp.py @@ -6,7 +6,7 @@ from tilelang.layout import make_cutlass_metadata_layout from tilelang.utils.sparse import compress, randn_semi_sparse from tilelang.contrib import nvcc -from triton.testing import do_bench +from tilelang.profiler import do_bench import torch @@ -97,13 +97,11 @@ def gemm_sp_fp16( return gemm_sp_fp16 -def main(m=16384, n=16384, k=16384, accum_dtype=None, cfg="4090"): - if accum_dtype is None: - accum_dtype = T.float - kernel = matmul_sp_fp16(m, n, k, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) +def main(M=1024, N=1024, K=1024, accum_dtype=T.float, cfg="h20"): + kernel = matmul_sp_fp16(M, N, K, accum_dtype, **DEFAULT_CONFIG[cfg][accum_dtype]) - a = randn_semi_sparse(m, k, device="cuda", dtype=torch.half) - b = torch.randn(k, n, device="cuda", dtype=torch.half) + a = randn_semi_sparse(M, K, device="cuda", dtype=torch.half) + b = torch.randn(K, N, device="cuda", dtype=torch.half) a_sparse, e = compress(a, transposed=False, block_k=DEFAULT_CONFIG[cfg][accum_dtype]["block_K"], arch=arch) c = kernel(a_sparse, e, b) @@ -117,7 +115,7 @@ def main(m=16384, n=16384, k=16384, accum_dtype=None, cfg="4090"): latency = do_bench(lambda: kernel(a_sparse, e, b)) ref_latency = do_bench(lambda: a @ b) - total_flops = 2 * m * n * k + total_flops = 2 * M * N * K tflops = total_flops / latency / 1e9 ref_tflops = total_flops / ref_latency / 1e9 print(f"Sparse TFLOPS: {tflops:.2f}, Latency: {latency / 1e3} s") @@ -129,8 +127,7 @@ def main(m=16384, n=16384, k=16384, accum_dtype=None, cfg="4090"): parser.add_argument("--m", type=int, default=16384, help="Matrix dimension M") parser.add_argument("--n", type=int, default=16384, help="Matrix dimension N") parser.add_argument("--k", type=int, default=16384, help="Matrix dimension K") - parser.add_argument("--accum_dtype", type=str, default="float", choices=["float", "float16"], help="Accumulation datatype") + parser.add_argument("--accum_dtype", type=str, default=T.float, choices=[T.float, T.float16], help="Accumulation datatype") parser.add_argument("--cfg", type=str, choices=["4090", "h20"], default="4090") args = parser.parse_args() - accum_dtype = T.float if args.accum_dtype == "float" else T.float16 - main(args.m, args.n, args.k, accum_dtype, args.cfg) + main(M=args.m, N=args.n, K=args.k, accum_dtype=args.accum_dtype, cfg=args.cfg) diff --git a/examples/gemm_sp/test_example_gemm_sp.py b/examples/gemm_sp/test_example_gemm_sp.py index fe26df1449..aa1a747f24 100644 --- a/examples/gemm_sp/test_example_gemm_sp.py +++ b/examples/gemm_sp/test_example_gemm_sp.py @@ -4,10 +4,14 @@ import example_gemm_sp +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_custom_compress(): example_custom_compress.main() +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version_eq(9, 0) def test_example_gemm_sp(): example_gemm_sp.main() diff --git a/examples/gemm_streamk/example_tilelang_gemm_streamk.py b/examples/gemm_streamk/example_tilelang_gemm_streamk.py index b2e8e93690..48dc175a96 100644 --- a/examples/gemm_streamk/example_tilelang_gemm_streamk.py +++ b/examples/gemm_streamk/example_tilelang_gemm_streamk.py @@ -158,7 +158,7 @@ def main(): 
         False,
         True,
         T.float16,
-        T.float16,
+        T.float32,  # fp32 for atomic add
         T.float32,
         2,
         64,
@@ -166,9 +166,10 @@
 
     print(kernel.get_kernel_source())
 
-    b_c = torch.zeros((m, n), device="cuda", dtype=torch.float16)
+    b_c = torch.zeros((m, n), device="cuda", dtype=torch.float32)
     kernel(A, B, b_c)
+    b_c = b_c.to(torch.float16)
 
     C = torch.matmul(A, B.T)
 
diff --git a/examples/gemv/example_gemv.py b/examples/gemv/example_gemv.py
index 8ca77a2e89..ddbe4fd7a6 100644
--- a/examples/gemv/example_gemv.py
+++ b/examples/gemv/example_gemv.py
@@ -194,7 +194,7 @@ def main(
             C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
         C_reduced = T.alloc_local((1,), accum_dtype)
         with T.attr(
-            T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+            T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]),
             "reduce_scope",
             T.reinterpret(T.uint64(0), dtype="handle"),
         ):
@@ -227,10 +227,7 @@ def get_block_template_configs():
     rep=20,
 )
 @tl.jit(
-    pass_configs={
-        tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
-        tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True,
-    },
+    pass_configs={tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True},
     out_idx=[2],
 )
 def gemv_alloc_reducer(
@@ -304,7 +301,7 @@ def main(
             C_accum[0] += A_local[k].astype(accum_dtype) * B_local[k].astype(accum_dtype)
         C_reduced = T.alloc_local((1,), accum_dtype)
         with T.attr(
-            T.comm_reducer(lambda x, y: x + y, [T.Cast(accum_dtype, 0)]),
+            T.comm_reducer(lambda x, y: x + y, [T.cast(0, accum_dtype)]),
             "reduce_scope",
             T.reinterpret(T.uint64(0), dtype="handle"),
         ):
diff --git a/examples/grouped_gemm/example_grouped_gemm_bwd.py b/examples/grouped_gemm/example_grouped_gemm_bwd.py
index 49cce0d1dd..339f8bc1ae 100644
--- a/examples/grouped_gemm/example_grouped_gemm_bwd.py
+++ b/examples/grouped_gemm/example_grouped_gemm_bwd.py
@@ -5,7 +5,7 @@ import tilelang.language as T
 
 
-@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
+@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_warp_specialized": True})
 def grouped_gemm_fwd(batch_sum, batch_count, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
     """
     args:
@@ -157,7 +157,7 @@ def construct_inputs(batch_sizes_list, K, M, trans_b, padding_M, device, dtype):
     return A, B, C, batch_sizes, batch_offsets, batch_padded_offsets
 
 
-@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_tma_lower": True, "tl.disable_warp_specialized": True})
+@tilelang.jit(out_idx=[2], pass_configs={"tl.disable_warp_specialized": True})
 def grouped_gemm_bwd(batch_sum, batch_count, M, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16):
     """
     args:
diff --git a/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py
new file mode 100644
index 0000000000..d57edcc6ca
--- /dev/null
+++ b/examples/grouped_gemm/example_grouped_gemm_fwd_ptr.py
@@ -0,0 +1,186 @@
+import argparse
+import math
+import time
+
+import torch
+
+import tilelang as tl
+import tilelang.language as T
+
+
+def make_ptr_table(tensors):
+    assert tensors, "pointer table requires at least one tensor"
+    device = tensors[0].device
+    return torch.tensor([tensor.data_ptr() for tensor in tensors], device=device, dtype=torch.int64)
+
+
+def torch_grouped_gemm_ptr(a_list, b_list):
+    assert len(a_list) == len(b_list), "A/B group count mismatch"
+    outputs = []
+    for a, b in zip(a_list, b_list):
+        assert a.shape[1] == b.shape[0], "incompatible GEMM shapes"
+        outputs.append(torch.matmul(a, b))
+    return
outputs + + +def grouped_gemm_ptr(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages=2, threads=128, dtype=T.float16): + # Keep per-group tensors separate and pass them via pointer tables. + # We currently use a common max_M storage shape per group because + # ptr-backed tensors with runtime-varying shapes are not stable enough yet. + # Multi-stage software pipelining on ptr-backed tensors is not correct yet. + # Keep a single-stage pipeline so the ptr path can still use T.copy lowering. + copy_num_stages = 1 + batch_count = len(batch_sizes_list) + max_M = max(batch_sizes_list) + batch_tile_offsets = [0] + for size in batch_sizes_list[:-1]: + batch_tile_offsets.append(batch_tile_offsets[-1] + math.ceil(size / block_M)) + total_m_blocks = sum(math.ceil(size / block_M) for size in batch_sizes_list) + accum_dtype = T.float32 + + @T.prim_func + def kernel( + A_ptrs: T.Tensor([batch_count], T.ptr), + B_ptrs: T.Tensor([batch_count], T.ptr), + C_ptrs: T.Tensor([batch_count], T.ptr), + batch_tile_offsets: T.Tensor([batch_count], T.int32), + ): + with T.Kernel(total_m_blocks, T.ceildiv(N, block_N), threads=threads) as (bx, by): + A_shared = T.alloc_shared((block_M, block_K), dtype) + B_shared = T.alloc_shared((block_K, block_N), dtype) + C_local = T.alloc_fragment((block_M, block_N), accum_dtype) + cur_batch_idx = T.alloc_var(dtype=T.int32) + cur_tile_offset = T.alloc_var(dtype=T.int32) + + cur_batch_idx = 0 + cur_tile_offset = 0 + for i in range(batch_count): + in_cur_batch_idx = bx >= batch_tile_offsets[i] + cur_batch_idx = T.if_then_else(in_cur_batch_idx, i, cur_batch_idx) + cur_tile_offset = T.if_then_else(in_cur_batch_idx, batch_tile_offsets[i], cur_tile_offset) + + m_start = (bx - cur_tile_offset) * block_M + A = T.make_tensor(A_ptrs[cur_batch_idx], (max_M, K), dtype) + B = T.make_tensor(B_ptrs[cur_batch_idx], (K, N), dtype) + C = T.make_tensor(C_ptrs[cur_batch_idx], (max_M, N), dtype) + + T.clear(C_local) + for ko in T.Pipelined(T.ceildiv(K, block_K), num_stages=copy_num_stages): + T.copy(A[m_start, ko * block_K], A_shared) + T.copy(B[ko * block_K, by * block_N], B_shared) + T.gemm(A_shared, B_shared, C_local) + + T.copy(C_local, C[m_start, by * block_N]) + + return kernel + + +def construct_inputs(batch_sizes_list, K, N, block_M, device, dtype): + max_M = max(batch_sizes_list) + batch_tile_offsets_list = [0] + for size in batch_sizes_list[:-1]: + batch_tile_offsets_list.append(batch_tile_offsets_list[-1] + math.ceil(size / block_M)) + # Each group owns an independent padded tensor; nothing is concatenated. 
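+    # Worked example (hypothetical sizes): batch_sizes_list=[16, 33] with
+    # block_M=32 gives per-group tile counts [1, 2], so batch_tile_offsets_list
+    # becomes [0, 1], matching the 1 + 2 = 3 m-blocks the kernel walks in total.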
+ a_list = [torch.zeros(max_M, K, device=device, dtype=dtype) for _ in batch_sizes_list] + b_list = [torch.randn(K, N, device=device, dtype=dtype) for _ in batch_sizes_list] + c_list = [torch.empty(max_M, N, device=device, dtype=dtype) for _ in batch_sizes_list] + for a, size in zip(a_list, batch_sizes_list): + a[:size].copy_(torch.randn(size, K, device=device, dtype=dtype)) + a_ptrs = make_ptr_table(a_list) + b_ptrs = make_ptr_table(b_list) + c_ptrs = make_ptr_table(c_list) + batch_tile_offsets = torch.tensor(batch_tile_offsets_list, device=device, dtype=torch.int32) + return a_list, b_list, c_list, a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets + + +def verify_outputs(outputs, refs, batch_sizes_list, atol=1e-2, rtol=1e-2): + for idx, (out, ref, batch_size) in enumerate(zip(outputs, refs, batch_sizes_list)): + try: + torch.testing.assert_close(out[:batch_size], ref, atol=atol, rtol=rtol) + except AssertionError as err: + raise AssertionError(f"group {idx}: {err}") from err + + +def benchmark(kernel, inputs, warmup=50, rep=100): + for _ in range(warmup): + kernel(*inputs) + torch.cuda.synchronize() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(rep): + kernel(*inputs) + end.record() + torch.cuda.synchronize() + return start.elapsed_time(end) / rep + + +def run_tilelang_grouped_gemm_ptr( + batch_sizes_list, + K, + N, + block_M, + block_N, + block_K, + num_stages=2, + threads=128, + profile=False, +): + device = torch.device("cuda") + dtype = torch.float16 + program = grouped_gemm_ptr(batch_sizes_list, K, N, block_M, block_N, block_K, num_stages, threads) + # The ptr-backed grouped GEMM example is intended to exercise the regular CUDA + # execution path; CuTeDSL does not support these handle tensors. 
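+    # An additional caveat (plain PyTorch semantics, not tilelang-specific): the
+    # pointer tables hold raw device addresses from Tensor.data_ptr(), so the
+    # tensors behind a_list/b_list/c_list must stay alive while the tables are
+    # in use; the int64 tables do not keep those allocations reachable.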
+ kernel = tl.compile( + program, + target="cuda", + execution_backend="auto", + pass_configs={"tl.disable_warp_specialized": True}, + ) + a_list, b_list, c_list, a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets = construct_inputs(batch_sizes_list, K, N, block_M, device, dtype) + refs = torch_grouped_gemm_ptr([a[:size] for a, size in zip(a_list, batch_sizes_list)], b_list) + + kernel(a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets) + verify_outputs(c_list, refs, batch_sizes_list) + print("✅ TileLang ptr-grouped-gemm matches PyTorch") + + if profile: + latency = benchmark(kernel, (a_ptrs, b_ptrs, c_ptrs, batch_tile_offsets)) + total_flops = sum(size * K * N * 2 for size in batch_sizes_list) + print(f"Latency: {latency:.4f} ms") + print(f"TFlops: {total_flops / (latency * 1e9):.4f}") + + +def test_grouped_gemm_ptr(): + run_tilelang_grouped_gemm_ptr([16, 33, 64], 128, 96, 32, 32, 32) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--batch_sizes", type=str, default="64,128,256", help="comma-separated per-group M sizes") + parser.add_argument("--K", type=int, default=4096, help="reduce dim") + parser.add_argument("--N", type=int, default=4096, help="output dim") + parser.add_argument("--profile", action="store_true", help="benchmark the kernel") + args = parser.parse_args() + + batch_sizes_list = [int(x.strip()) for x in args.batch_sizes.split(",") if x.strip()] + block_M = 64 + block_N = 128 + block_K = 64 + num_stages = 1 + threads = 256 + + t0 = time.time() + run_tilelang_grouped_gemm_ptr( + batch_sizes_list, + args.K, + args.N, + block_M, + block_N, + block_K, + num_stages=num_stages, + threads=threads, + profile=args.profile, + ) + print(f"End-to-end: {time.time() - t0:.3f} s") diff --git a/examples/grouped_gemm/test_example_grouped_gemm.py b/examples/grouped_gemm/test_example_grouped_gemm.py new file mode 100644 index 0000000000..dc0c945072 --- /dev/null +++ b/examples/grouped_gemm/test_example_grouped_gemm.py @@ -0,0 +1,59 @@ +import tilelang.testing + +import example_grouped_gemm_bwd +import example_grouped_gemm_fwd +import example_grouped_gemm_fwd_ptr + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_fwd_small(): + example_grouped_gemm_fwd.run_tilelang_grouped_gemm( + [5, 9, 13], + K=64, + M=96, + block_M=64, + block_N=64, + block_K=32, + trans_b=False, + num_stages=2, + threads=256, + profile=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_fwd_ptr_small(): + example_grouped_gemm_fwd_ptr.run_tilelang_grouped_gemm_ptr( + [5, 9, 13], + K=64, + N=96, + block_M=64, + block_N=64, + block_K=32, + num_stages=1, + threads=256, + profile=False, + ) + + +@tilelang.testing.requires_cuda +@tilelang.testing.requires_cuda_compute_version(9, 0) +def test_example_grouped_gemm_bwd_small(): + example_grouped_gemm_bwd.run_tilelang_grouped_gemm( + [5, 9, 13], + K=64, + M=96, + block_M=64, + block_N=64, + block_K=32, + trans_b=False, + num_stages=2, + threads=256, + profile=False, + ) + + +if __name__ == "__main__": + tilelang.testing.main() diff --git a/examples/kda/FLA_KDA/cumsum.py b/examples/kda/FLA_KDA/cumsum.py new file mode 100644 index 0000000000..0fb3368f6a --- /dev/null +++ b/examples/kda/FLA_KDA/cumsum.py @@ -0,0 +1,469 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, autotune_cache_kwargs, 
input_guard + +BS_LIST = [32, 64] + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_o = tl.cumsum(b_s, axis=0, reverse=True) + else: + b_o = tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: 
args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [32, 64, 128, 256] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_scalar_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_nh = tl.program_id(0) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + b_ss = tl.sum(b_s, 0) + if REVERSE: + b_o = -b_o + b_ss + b_s + b_o += b_z + if i_c >= 0: + b_z += b_ss + if HAS_SCALE: + b_o *= scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics( + { + "HAS_SCALE": lambda args: args["scale"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BT": BT}, num_warps=num_warps, num_stages=num_stages) + for BT in [16, 32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [1, 2, 3, 4] + ], + key=["B", "H", "S", "IS_VARLEN", "REVERSE"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_global_cumsum_vector_kernel( + s, + o, + scale, + cu_seqlens, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + HAS_SCALE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + T = eos - bos + + b_z = tl.zeros([BS], dtype=tl.float32) + NT = tl.cdiv(T, BT) + for i_c in range(NT): + i_t = NT - 1 - i_c if REVERSE else i_c + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + if REVERSE: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0, reverse=True) + else: + b_c = b_z[None, :] + tl.cumsum(b_s, axis=0) + if HAS_SCALE: + b_c *= scale + tl.store(p_o, b_c.to(p_o.dtype.element_ty), 
boundary_check=(0, 1)) + b_z += tl.sum(b_s, 0) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + s=g_org, + o=g, + scale=scale, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return g + + +@input_guard +def chunk_global_cumsum_scalar( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = s.shape + else: + B, T, H = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (N * H,) + chunk_global_cumsum_scalar_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum_vector( + s: torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = s.shape + else: + B, T, H, S = s.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BS = min(32, triton.next_power_of_2(S)) + + z = torch.empty_like(s, dtype=output_dtype or s.dtype) + grid = (triton.cdiv(S, BS), N * H) + chunk_global_cumsum_vector_kernel[grid]( + s=s, + o=z, + scale=scale, + cu_seqlens=cu_seqlens, + T=T, + B=B, + H=H, + S=S, + BS=BS, + HEAD_FIRST=head_first, + REVERSE=reverse, + ) + return z + + +@input_guard +def chunk_global_cumsum( + s: 
torch.Tensor, + reverse: bool = False, + cu_seqlens: torch.Tensor = None, + scale: float = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, +) -> torch.Tensor: + if cu_seqlens is not None: + assert s.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(s.shape) == 3: + return chunk_global_cumsum_scalar( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + elif len(s.shape) == 4: + return chunk_global_cumsum_vector( + s=s, + reverse=reverse, + cu_seqlens=cu_seqlens, + scale=scale, + head_first=head_first, + output_dtype=output_dtype, + ) + else: + raise ValueError( + f"Unsupported input shape {s.shape}, " + f"which should be [B, T, H]/[B, T, H, D] if `head_first=False` " + f"or [B, H, T]/[B, H, T, D] otherwise", + ) + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + scale: float = None, + cu_seqlens: torch.Tensor = None, + head_first: bool = False, + output_dtype: torch.dtype = torch.float, + chunk_indices: torch.LongTensor = None, + **kwargs, +) -> torch.Tensor: + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector( + g=g, + chunk_size=chunk_size, + reverse=reverse, + scale=scale, + cu_seqlens=cu_seqlens, + head_first=head_first, + output_dtype=output_dtype, + chunk_indices=chunk_indices, + ) + else: + raise ValueError( + f"Unsupported input shape {g.shape}, which should be (B, T, H, D) if `head_first=False` or (B, H, T, D) otherwise", + ) diff --git a/examples/kda/FLA_KDA/fla_chunk_delta.py b/examples/kda/FLA_KDA/fla_chunk_delta.py new file mode 100644 index 0000000000..3b0fc908d0 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_delta.py @@ -0,0 +1,579 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl +from .fla_utils import prepare_chunk_indices, exp, exp2, USE_CUDA_GRAPH, autotune_cache_kwargs + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + gk, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), 
tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + if SAVE_NEW_VALUE: + v_new += ((bos * H + i_h) * V).to(tl.int64) + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v + + if SAVE_NEW_VALUE: + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + m_t = 
(i_t * BT + tl.arange(0, BT)) < T + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + b_v = b_v * tl.where(m_t, exp2(b_g_last - b_g), 0)[:, None] + b_g_last = exp2(b_g_last) + else: + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 *= b_g_last + if K > 64: + b_h2 *= b_g_last + if K > 128: + b_h3 *= b_g_last + if K > 192: + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k1, mask=(o_k1 < K), other=0.0) + if USE_EXP2: + b_h1 *= exp2(b_gk_last1)[:, None] + else: + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k2, mask=(o_k2 < K), other=0.0) + if USE_EXP2: + b_h2 *= exp2(b_gk_last2)[:, None] + else: + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k3, mask=(o_k3 < K), other=0.0) + if USE_EXP2: + b_h3 *= exp2(b_gk_last3)[:, None] + else: + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + (bos + last_idx) * H * K + i_h * K + o_k4, mask=(o_k4 < K), other=0.0) + if USE_EXP2: + b_h4 *= exp2(b_gk_last4)[:, None] + else: + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v) + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, + "USE_INITIAL_STATE": lambda args: args["dh0"] is not None, + "USE_FINAL_STATE_GRADIENT": lambda args: args["dht"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in ([4, 3, 2]) + for BV in [64, 32] + ], + key=["H", "K", "V", "BT", "BV", "USE_G", "USE_EXP2"], + use_cuda_graph=USE_CUDA_GRAPH, + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def 
chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64( + q, + k, + w, + g, + gk, + dht, + dh0, + do, + dh, + dv, + dv2, + cu_seqlens, + chunk_offsets, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_GK: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + USE_FINAL_STATE_GRADIENT: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_dh1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_dh2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_dh3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_dh4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + q += ((bos * H + i_h) * K).to(tl.int64) + k += ((bos * H + i_h) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) + do += ((bos * H + i_h) * V).to(tl.int64) + dv += ((bos * H + i_h) * V).to(tl.int64) + dv2 += ((bos * H + i_h) * V).to(tl.int64) + dh += ((boh * H + i_h) * K * V).to(tl.int64) + if USE_GK: + gk += ((bos * H + i_h) * K).to(tl.int64) + + stride_v = H * V + stride_h = H * K * V + stride_k = H * K + if USE_INITIAL_STATE: + dh0 += i_nh * K * V + if USE_FINAL_STATE_GRADIENT: + dht += i_nh * K * V + + if USE_FINAL_STATE_GRADIENT: + p_dht1 = tl.make_block_ptr(dht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_dh1 += tl.load(p_dht1, boundary_check=(0, 1)) + if K > 64: + p_dht2 = tl.make_block_ptr(dht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_dh2 += tl.load(p_dht2, boundary_check=(0, 1)) + if K > 128: + p_dht3 = tl.make_block_ptr(dht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_dh3 += tl.load(p_dht3, boundary_check=(0, 1)) + if K > 192: + p_dht4 = tl.make_block_ptr(dht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_dh4 += tl.load(p_dht4, boundary_check=(0, 1)) + + for i_t in range(NT - 1, -1, -1): + p_dh1 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh1.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh2 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh2.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh3 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh3.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh4 = tl.make_block_ptr(dh + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh4, b_dh4.to(p_dh4.dtype.element_ty), boundary_check=(0, 1)) + + last_idx = min((i_t + 1) * BT, T) - 1 + if USE_G: + bg_last = tl.load(g + (bos + last_idx) * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + if USE_EXP2: + bg_last_exp = exp2(bg_last) + b_g_exp = exp2(b_g) + else: + bg_last_exp = exp(bg_last) + b_g_exp = exp(b_g) + + p_dv = tl.make_block_ptr(dv, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv2 = tl.make_block_ptr(dv2, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, 
(T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + + b_do = tl.load(p_do, boundary_check=(0, 1)) + + # Update dv + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load(gk + last_idx * H * K + o_k1, mask=(o_k1 < K), other=0.0) + b_dv = tl.dot(b_k, b_dh1.to(b_k.dtype)) + + if K > 64: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load(gk + last_idx * H * K + o_k2, mask=(o_k2 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh2.to(b_k.dtype)) + + if K > 128: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load(gk + last_idx * H * K + o_k3, mask=(o_k3 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh3.to(b_k.dtype)) + + if K > 192: + p_k = tl.make_block_ptr(k, (T, K), (stride_k, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + if USE_GK: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load(gk + last_idx * H * K + o_k4, mask=(o_k4 < K), other=0.0) + b_dv += tl.dot(b_k, b_dh4.to(b_k.dtype)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + if USE_EXP2: + b_dv *= tl.where(m_t, exp2(bg_last - b_g), 0)[:, None] + else: + b_dv *= tl.where(m_t, exp(bg_last - b_g), 0)[:, None] + b_dv += tl.load(p_dv, boundary_check=(0, 1)) + + tl.store(p_dv2, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + # Update dh + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + if USE_G: + b_dh1 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh1 *= exp2(b_gk_last1[:, None]) + else: + b_dh1 *= exp(b_gk_last1[:, None]) + b_dh1 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 64: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh2 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh2 *= exp2(b_gk_last2[:, None]) + else: + b_dh2 *= exp(b_gk_last2[:, None]) + b_dh2 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 128: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + if USE_G: + b_dh3 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh3 *= exp2(b_gk_last3[:, None]) + else: + b_dh3 *= exp(b_gk_last3[:, None]) + b_dh3 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + if K > 192: + p_q = tl.make_block_ptr(q, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + p_w = tl.make_block_ptr(w, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_w = tl.load(p_w, 
boundary_check=(0, 1)) + if USE_G: + b_dh4 *= bg_last_exp + b_q = b_q * b_g_exp[None, :] + if USE_GK: + if USE_EXP2: + b_dh4 *= exp2(b_gk_last4[:, None]) + else: + b_dh4 *= exp(b_gk_last4[:, None]) + b_dh4 += tl.dot(b_q.to(b_q.dtype), b_do.to(b_q.dtype)) * scale - tl.dot(b_w, b_dv.to(b_w.dtype)) + + if USE_INITIAL_STATE: + p_dh0 = tl.make_block_ptr(dh0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh0, b_dh1.to(p_dh0.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_dh1 = tl.make_block_ptr(dh0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh1, b_dh2.to(p_dh1.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_dh2 = tl.make_block_ptr(dh0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh2, b_dh3.to(p_dh2.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_dh3 = tl.make_block_ptr(dh0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_dh3, b_dh4.to(p_dh3.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, u.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + gk=gk, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return h, v_new, final_state + + +def chunk_gated_delta_rule_bwd_dhu( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + do: torch.Tensor, + dv: torch.Tensor, + g: torch.Tensor = None, + gk: torch.Tensor = None, + h0: torch.Tensor = None, + dht: torch.Tensor = None, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + B, T, H, K, V = *q.shape, do.shape[-1] + # N: the actual number of sequences in the batch with either equal or variable lengths + BT = 64 + assert K <= 256, "current kernel does not support head dimension being larger than 256." 
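+    # Rationale: these blockdim64 kernels keep at most four 64-wide K slices
+    # in registers (b_dh1..b_dh4 in the kernel above), so the head dimension
+    # tops out at 4 * 64 = 256.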
+ + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + + dh = q.new_empty(B, NT, H, K, V) + dh0 = torch.empty_like(h0, dtype=torch.float32) if h0 is not None else None + dv2 = torch.empty_like(dv) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_bwd_kernel_dhu_blockdim64[grid]( + q=q, + k=k, + w=w, + g=g, + gk=gk, + dht=dht, + dh0=dh0, + do=do, + dh=dh, + dv=dv, + dv2=dv2, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return dh, dh0, dv2 diff --git a/examples/kda/FLA_KDA/fla_chunk_inter.py b/examples/kda/FLA_KDA/fla_chunk_inter.py new file mode 100644 index 0000000000..e6de9bb28f --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_inter.py @@ -0,0 +1,193 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs, check_shared_mem + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dv, + dw, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + o_t = i_t * BT + tl.arange(0, BT) + m_k = o_k < K + m_t = o_t < T + m_last = o_t == min(T, i_t * BT + BT) - 1 + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dw += (bos * H + i_h) * K + dv += (bos * H + i_h) * V + dg += (bos * H + i_h) * K + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dw = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + 
p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_dv = tl.load(p_dv, boundary_check=(0, 1)) + b_dw += tl.dot(b_dv.to(b_v.dtype), b_h.to(b_v.dtype)) + + p_dw = tl.make_block_ptr(dw, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dw, -b_dw.to(p_dw.dtype.element_ty), boundary_check=(0, 1)) + + b_dgk *= exp2(b_gn) + b_dq *= scale + b_dq = b_dq * exp2(b_g) + b_dk = b_dk * tl.where(m_t[:, None], exp2(b_gn[None, :] - b_g), 0) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dg = b_q * b_dq - b_k * b_dk + m_last[:, None] * b_dgk + + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_dqkwg( + q: torch.Tensor, + k: torch.Tensor, + w: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: torch.Tensor, + do: torch.Tensor, + dh: torch.Tensor, + dv: torch.Tensor, + scale: float = None, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K, V = *k.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + dq = torch.empty_like(q, dtype=torch.float) + dk = torch.empty_like(k, dtype=torch.float) + dw = torch.empty_like(w) + dg = torch.empty_like(g) + + def grid(meta): + return (triton.cdiv(K, meta["BK"]), NT, B * H) + + chunk_kda_bwd_kernel_inter[grid]( + q=q, + k=k, + v=v, + g=g, + h=h, + do=do, + dh=dh, + dq=dq, + dk=dk, + dv=dv, + dw=dw, + dg=dg, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + ) + return dq, dk, dw, dg diff --git a/examples/kda/FLA_KDA/fla_chunk_intra.py b/examples/kda/FLA_KDA/fla_chunk_intra.py new file mode 100644 index 0000000000..244f05f1c1 --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_intra.py @@ -0,0 +1,650 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import autotune_cache_kwargs, exp2, prepare_chunk_indices +from .cumsum import chunk_local_cumsum + +IS_TF32_SUPPORTED = False +if IS_TF32_SUPPORTED: + SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32x3") +else: + SOLVE_TRIL_DOT_PRECISION = 
tl.constexpr("ieee") +SOLVE_TRIL_DOT_PRECISION = tl.constexpr("tf32") +# ============================================================================ +# Fused inter + solve_tril kernel: compute off-diagonal Akk and solve in one pass +# ============================================================================ + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK}, num_warps=num_warps) for BK in [32, 64] for num_warps in [1, 2, 4]], + key=["H", "K", "BC"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_kda_fwd_kernel_inter_solve_fused( + q, + k, + g, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. + Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag. + + This kernel: + 1. Computes off-diagonal Aqk blocks -> writes to global + 2. Computes off-diagonal Akk blocks -> keeps in registers + 3. Loads diagonal Akk blocks from Akk_diag (fp32) + 4. Does forward substitution on diagonals + 5. Computes merged Akk_inv + 6. Writes Akk_inv to Akk + """ + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if i_t * BT >= T: + return + + i_tc0 = i_t * BT + i_tc1 = i_t * BT + BC + i_tc2 = i_t * BT + 2 * BC + i_tc3 = i_t * BT + 3 * BC + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + Aqk += (bos * H + i_h) * BT + Akk += (bos * H + i_h) * BT + Akk_diag += (bos * H + i_h) * BC + + m_tc1 = (i_tc1 + tl.arange(0, BC)) < T + m_tc2 = (i_tc2 + tl.arange(0, BC)) < T + m_tc3 = (i_tc3 + tl.arange(0, BC)) < T + + b_Aqk10 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk10 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk20 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk21 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk21 = tl.zeros([BC, BC], dtype=tl.float32) + + b_Aqk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk30 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk31 = tl.zeros([BC, BC], dtype=tl.float32) + b_Aqk32 = tl.zeros([BC, BC], dtype=tl.float32) + b_Akk32 = tl.zeros([BC, BC], dtype=tl.float32) + + ################################################################################ + # 1. 
off-diagonal blocks + ################################################################################ + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k0 = tl.make_block_ptr(k, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + p_g0 = tl.make_block_ptr(g, (K, T), (1, H * K), (i_k * BK, i_tc0), (BK, BC), (0, 1)) + b_kt0 = tl.load(p_k0, boundary_check=(0, 1)).to(tl.float32) + b_gt0 = tl.load(p_g0, boundary_check=(0, 1)).to(tl.float32) + + b_kt1, b_gt1 = b_kt0, b_gt0 + b_kt2, b_gt2 = b_kt0, b_gt0 + if i_tc1 < T: + p_q1 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_k1 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + p_g1 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc1, i_k * BK), (BC, BK), (1, 0)) + + b_q1 = tl.load(p_q1, boundary_check=(0, 1)).to(tl.float32) + b_k1 = tl.load(p_k1, boundary_check=(0, 1)).to(tl.float32) + b_g1 = tl.load(p_g1, boundary_check=(0, 1)).to(tl.float32) + b_kt1 = tl.trans(b_k1) + b_gt1 = tl.trans(b_g1) + + b_gn1 = tl.load(g + i_tc1 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn1 = tl.where(m_tc1[:, None], exp2(b_g1 - b_gn1[None, :]), 0) + b_qg1 = b_q1 * b_gqn1 + b_kg1 = b_k1 * b_gqn1 + b_kgt = b_kt0 * exp2(b_gn1[:, None] - b_gt0) + b_Aqk10 += tl.dot(b_qg1, b_kgt) + b_Akk10 += tl.dot(b_kg1, b_kgt) + + if i_tc2 < T: + p_q2 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_k2 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + p_g2 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc2, i_k * BK), (BC, BK), (1, 0)) + + b_q2 = tl.load(p_q2, boundary_check=(0, 1)).to(tl.float32) + b_k2 = tl.load(p_k2, boundary_check=(0, 1)).to(tl.float32) + b_g2 = tl.load(p_g2, boundary_check=(0, 1)).to(tl.float32) + b_kt2 = tl.trans(b_k2) + b_gt2 = tl.trans(b_g2) + + b_gn2 = tl.load(g + i_tc2 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn2 = tl.where(m_tc2[:, None], exp2(b_g2 - b_gn2[None, :]), 0) + b_qg2 = b_q2 * b_gqn2 + b_kg2 = b_k2 * b_gqn2 + b_kgt = b_kt0 * exp2(b_gn2[:, None] - b_gt0) + b_Aqk20 += tl.dot(b_qg2, b_kgt) + b_Akk20 += tl.dot(b_kg2, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn2[:, None] - b_gt1) + b_Aqk21 += tl.dot(b_qg2, b_kgt) + b_Akk21 += tl.dot(b_kg2, b_kgt) + + if i_tc3 < T: + p_q3 = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_k3 = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + p_g3 = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_tc3, i_k * BK), (BC, BK), (1, 0)) + b_q3 = tl.load(p_q3, boundary_check=(0, 1)).to(tl.float32) + b_k3 = tl.load(p_k3, boundary_check=(0, 1)).to(tl.float32) + b_g3 = tl.load(p_g3, boundary_check=(0, 1)).to(tl.float32) + + b_gn3 = tl.load(g + i_tc3 * H * K + o_k, mask=m_k, other=0).to(tl.float32) + b_gqn3 = tl.where(m_tc3[:, None], exp2(b_g3 - b_gn3[None, :]), 0) + b_qg3 = b_q3 * b_gqn3 + b_kg3 = b_k3 * b_gqn3 + b_kgt = b_kt0 * exp2(b_gn3[:, None] - b_gt0) + b_Aqk30 += tl.dot(b_qg3, b_kgt) + b_Akk30 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt1 * exp2(b_gn3[:, None] - b_gt1) + b_Aqk31 += tl.dot(b_qg3, b_kgt) + b_Akk31 += tl.dot(b_kg3, b_kgt) + + b_kgt = b_kt2 * exp2(b_gn3[:, None] - b_gt2) + b_Aqk32 += tl.dot(b_qg3, b_kgt) + b_Akk32 += tl.dot(b_kg3, b_kgt) + + ################################################################################ + # 2. 
save off-diagonal Aqk blocks and prepare Akk + ################################################################################ + if i_tc1 < T: + p_Aqk10 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + tl.store(p_Aqk10, (b_Aqk10 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b1 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc1,), (BC,), (0,)) + b_b1 = tl.load(p_b1, boundary_check=(0,)).to(tl.float32) + b_Akk10 = b_Akk10 * b_b1[:, None] + if i_tc2 < T: + p_Aqk20 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Aqk21 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + tl.store(p_Aqk20, (b_Aqk20 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk21, (b_Aqk21 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b2 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc2,), (BC,), (0,)) + b_b2 = tl.load(p_b2, boundary_check=(0,)).to(tl.float32) + b_Akk20 = b_Akk20 * b_b2[:, None] + b_Akk21 = b_Akk21 * b_b2[:, None] + if i_tc3 < T: + p_Aqk30 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Aqk31 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Aqk32 = tl.make_block_ptr(Aqk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + tl.store(p_Aqk30, (b_Aqk30 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk31, (b_Aqk31 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Aqk32, (b_Aqk32 * scale).to(Aqk.dtype.element_ty), boundary_check=(0, 1)) + + p_b3 = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_tc3,), (BC,), (0,)) + b_b3 = tl.load(p_b3, boundary_check=(0,)).to(tl.float32) + b_Akk30 = b_Akk30 * b_b3[:, None] + b_Akk31 = b_Akk31 * b_b3[:, None] + b_Akk32 = b_Akk32 * b_b3[:, None] + + ################################################################################ + # 3. load diagonal Akk blocks + ################################################################################ + p_Akk00 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk_diag, (T, BC), (H * BC, 1), (i_tc3, 0), (BC, BC), (1, 0)) + # each diagonal block is stored contiguously: row i of block s is at Akk_diag[t=i_t*BT+s*BC+i, :BC] + b_Ai00 = tl.load(p_Akk00, boundary_check=(0, 1)).to(tl.float32) + b_Ai11 = tl.load(p_Akk11, boundary_check=(0, 1)).to(tl.float32) + b_Ai22 = tl.load(p_Akk22, boundary_check=(0, 1)).to(tl.float32) + b_Ai33 = tl.load(p_Akk33, boundary_check=(0, 1)).to(tl.float32) + + ################################################################################ + # 4. 
forward substitution on diagonals + ################################################################################ + o_i = tl.arange(0, BC) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + + b_Ai00 = -tl.where(m_A, b_Ai00, 0) + b_Ai11 = -tl.where(m_A, b_Ai11, 0) + b_Ai22 = -tl.where(m_A, b_Ai22, 0) + b_Ai33 = -tl.where(m_A, b_Ai33, 0) + + # Forward substitution: load from Akk_diag (stride H*BC, columns 0:BC) + for i in range(2, min(BC, T - i_tc0)): + b_a00 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a00 = tl.where(o_i < i, b_a00, 0.0) + b_a00 += tl.sum(b_a00[:, None] * b_Ai00, 0) + b_Ai00 = tl.where((o_i == i)[:, None], b_a00, b_Ai00) + for i in range(BC + 2, min(2 * BC, T - i_tc0)): + b_a11 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a11 = tl.where(o_i < i - BC, b_a11, 0.0) + b_a11 += tl.sum(b_a11[:, None] * b_Ai11, 0) + b_Ai11 = tl.where((o_i == i - BC)[:, None], b_a11, b_Ai11) + for i in range(2 * BC + 2, min(3 * BC, T - i_tc0)): + b_a22 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a22 = tl.where(o_i < i - 2 * BC, b_a22, 0.0) + b_a22 += tl.sum(b_a22[:, None] * b_Ai22, 0) + b_Ai22 = tl.where((o_i == i - 2 * BC)[:, None], b_a22, b_Ai22) + for i in range(3 * BC + 2, min(4 * BC, T - i_tc0)): + b_a33 = -tl.load(Akk_diag + (i_tc0 + i) * H * BC + o_i) + b_a33 = tl.where(o_i < i - 3 * BC, b_a33, 0.0) + b_a33 += tl.sum(b_a33[:, None] * b_Ai33, 0) + b_Ai33 = tl.where((o_i == i - 3 * BC)[:, None], b_a33, b_Ai33) + + b_Ai00 += m_I + b_Ai11 += m_I + b_Ai22 += m_I + b_Ai33 += m_I + + # ################################################################################ + # # 5. compute merged inverse using off-diagonals + # ################################################################################ + + # we used tf32x3 to maintain matrix inverse's precision whenever possible. + b_Ai10 = -tl.dot(tl.dot(b_Ai11, b_Akk10, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai21 = -tl.dot(tl.dot(b_Ai22, b_Akk21, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + b_Ai32 = -tl.dot(tl.dot(b_Ai33, b_Akk32, input_precision=SOLVE_TRIL_DOT_PRECISION), b_Ai22, input_precision=SOLVE_TRIL_DOT_PRECISION) + + b_Ai20 = -tl.dot( + b_Ai22, + tl.dot(b_Akk20, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk21, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai31 = -tl.dot( + b_Ai33, + tl.dot(b_Akk31, b_Ai11, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai21, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + b_Ai30 = -tl.dot( + b_Ai33, + tl.dot(b_Akk30, b_Ai00, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk31, b_Ai10, input_precision=SOLVE_TRIL_DOT_PRECISION) + + tl.dot(b_Akk32, b_Ai20, input_precision=SOLVE_TRIL_DOT_PRECISION), + input_precision=SOLVE_TRIL_DOT_PRECISION, + ) + + ################################################################################ + # 6. 
store full Akk_inv to Akk + ################################################################################ + + p_Akk00 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc0, 0), (BC, BC), (1, 0)) + p_Akk10 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, 0), (BC, BC), (1, 0)) + p_Akk11 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc1, BC), (BC, BC), (1, 0)) + p_Akk20 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 0), (BC, BC), (1, 0)) + p_Akk21 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, BC), (BC, BC), (1, 0)) + p_Akk22 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc2, 2 * BC), (BC, BC), (1, 0)) + p_Akk30 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 0), (BC, BC), (1, 0)) + p_Akk31 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, BC), (BC, BC), (1, 0)) + p_Akk32 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 2 * BC), (BC, BC), (1, 0)) + p_Akk33 = tl.make_block_ptr(Akk, (T, BT), (H * BT, 1), (i_tc3, 3 * BC), (BC, BC), (1, 0)) + + tl.store(p_Akk00, b_Ai00.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk10, b_Ai10.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk11, b_Ai11.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk20, b_Ai20.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk21, b_Ai21.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk22, b_Ai22.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk30, b_Ai30.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk31, b_Ai31.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk32, b_Ai32.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_Akk33, b_Ai33.to(Akk.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]], + key=["BK", "NC", "BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["B", "T"]) +def chunk_kda_bwd_kernel_intra( + q, + k, + g, + beta, + dAqk, + dAkk, + dq, + dq2, + dk, + dk2, + dg, + dg2, + db, + cu_seqlens, + chunk_indices, + B, + T, + H: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BC: tl.constexpr, + BK: tl.constexpr, + NC: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_kc, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + i_k, i_i = i_kc // NC, i_kc % NC + + all = B * T + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + else: + bos, eos = i_b * T, i_b * T + T + T = eos - bos + + i_ti = i_t * BT + i_i * BC + if i_ti >= T: + return + + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + g += (bos * H + i_h) * K + beta += bos * H + i_h + + dAqk += (bos * H + i_h) * BT + dAkk += (bos * H + i_h) * BT + dq += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + dg2 += (bos * H + i_h) * K + db += (i_k * all + bos) * H + i_h + + p_g = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_g = tl.load(p_g, boundary_check=(0, 1)) + + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_ti,), (BC,), (0,)) + b_b = 
tl.load(p_b, boundary_check=(0,)) + + b_dq2 = tl.zeros([BC, BK], dtype=tl.float32) + b_dk2 = tl.zeros([BC, BK], dtype=tl.float32) + if i_i > 0: + p_gn = g + i_ti * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(0, i_i): + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_dAqk = tl.make_block_ptr(dAqk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + p_dAkk = tl.make_block_ptr(dAkk, (T, BT), (H * BT, 1), (i_ti, i_j * BC), (BC, BC), (1, 0)) + # [BC, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_kg = b_k * exp2(b_gn[None, :] - b_gk) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + # [BC, BK] + b_dq2 += tl.dot(b_dAqk, b_kg) + b_dk2 += tl.dot(b_dAkk, b_kg) + b_gqn = exp2(b_g - b_gn[None, :]) + b_dq2 *= b_gqn + b_dk2 *= b_gqn + + o_i = tl.arange(0, BC) + m_dA = (i_ti + o_i) < T + o_dA = (i_ti + o_i) * H * BT + i_i * BC + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC] + b_dAqk = tl.load(dAqk + o_dA + j, mask=m_dA, other=0) + b_dAkk = tl.load(dAkk + o_dA + j, mask=m_dA, other=0) + # [BK] + b_kj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] >= j + # [BC, BK] + b_kgj = b_kj[None, :] * exp2(b_g - b_gkj[None, :]) + b_dq2 += tl.where(m_i, b_dAqk[:, None] * b_kgj, 0.0) + b_dk2 += tl.where(m_i, b_dAkk[:, None] * b_kgj, 0.0) + + p_kj += H * K + p_gkj += H * K + b_db = tl.sum(b_dk2 * b_k, 1) + b_dk2 *= b_b[:, None] + + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dq2 = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_db = tl.make_block_ptr(db, (T,), (H,), (i_ti,), (BC,), (0,)) + + b_dg2 = b_q * b_dq2 + b_dq2 = b_dq2 + tl.load(p_dq, boundary_check=(0, 1)) + tl.store(p_dq2, b_dq2.to(p_dq2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + tl.debug_barrier() + b_dkt = tl.zeros([BC, BK], dtype=tl.float32) + + NC = min(NC, tl.cdiv(T - i_t * BT, BC)) + if i_i < NC - 1: + p_gn = g + (min(i_ti + BC, T) - 1) * H * K + o_k + # [BK,] + b_gn = tl.load(p_gn, mask=m_k, other=0) + for i_j in range(i_i + 1, NC): + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT + i_j * BC, i_k * BK), (BC, BK), (1, 0)) + p_b = tl.make_block_ptr(beta, (T,), (H,), (i_t * BT + i_j * BC,), (BC,), (0,)) + p_dAqk = tl.make_block_ptr(dAqk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + p_dAkk = tl.make_block_ptr(dAkk, (BT, T), (1, H * BT), (i_i * BC, i_t * BT + i_j * BC), (BC, BC), (0, 1)) + # [BC] + b_b = tl.load(p_b, boundary_check=(0,)) + # [BC, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_kb = tl.load(p_k, 
boundary_check=(0, 1)) * b_b[:, None] + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + # [BC, BC] + b_dAqk = tl.load(p_dAqk, boundary_check=(0, 1)) + b_dAkk = tl.load(p_dAkk, boundary_check=(0, 1)) + + o_j = i_t * BT + i_j * BC + o_i + m_j = o_j < T + # [BC, BK] + b_gkn = tl.where(m_j[:, None], exp2(b_gk - b_gn[None, :]), 0) + b_qg = b_q * b_gkn + b_kbg = b_kb * b_gkn + # [BC, BK] + b_dkt += tl.dot(b_dAqk, b_qg) + tl.dot(b_dAkk, b_kbg) + b_dkt *= exp2(b_gn[None, :] - b_g) + + o_dA = i_ti * H * BT + i_i * BC + o_i + p_qj = q + i_ti * H * K + o_k # [bs, i_ti, i_h*block_h, i_k*bk:(i_k+1)*bk] + p_kj = k + i_ti * H * K + o_k + p_gkj = g + i_ti * H * K + o_k + p_bj = beta + i_ti * H + + for j in range(0, min(BC, T - i_t * BT - i_i * BC)): + # [BC,] + b_dAqk = tl.load(dAqk + o_dA + j * H * BT) + b_dAkk = tl.load(dAkk + o_dA + j * H * BT) + # [BK,] + b_qj = tl.load(p_qj, mask=m_k, other=0).to(tl.float32) + b_kbj = tl.load(p_kj, mask=m_k, other=0).to(tl.float32) * tl.load(p_bj) + b_gkj = tl.load(p_gkj, mask=m_k, other=0).to(tl.float32) + # [BC, BK] + m_i = o_i[:, None] <= j + b_gkq = exp2(b_gkj[None, :] - b_g) + b_dkt += tl.where(m_i, (b_dAkk[:, None] * b_kbj[None, :] + b_dAqk[:, None] * b_qj[None, :]) * b_gkq, 0.0) + + p_qj += H * K + p_kj += H * K + p_gkj += H * K + p_bj += H + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2, (T, K), (H * K, 1), (i_ti, i_k * BK), (BC, BK), (1, 0)) + + b_dg2 += (b_dk2 - b_dkt) * b_k + tl.load(p_dg, boundary_check=(0, 1)) + b_dk2 += tl.load(p_dk, boundary_check=(0, 1)) + b_dk2 += b_dkt + + tl.store(p_dk2, b_dk2.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg2.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_kda_bwd_intra( + q: torch.Tensor, + k: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + dAqk: torch.Tensor, + dAkk: torch.Tensor, + dq: torch.Tensor, + dk: torch.Tensor, + db: torch.Tensor, + dg: torch.Tensor, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, + chunk_size: int = 64, +): + B, T, H, K = k.shape + BT = chunk_size + BC = min(16, BT) + BK = min(32, triton.next_power_of_2(K)) + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + NC = triton.cdiv(BT, BC) + NK = triton.cdiv(K, BK) + + dq2 = torch.empty_like(q) + dk2 = torch.empty_like(k) + db2 = beta.new_empty(NK, *beta.shape, dtype=torch.float) + dg2 = torch.empty_like(dg, dtype=torch.float) + grid = (NK * NC, NT, B * H) + chunk_kda_bwd_kernel_intra[grid]( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dq2=dq2, + dk=dk, + dk2=dk2, + dg=dg, + dg2=dg2, + db=db2, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + B=B, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + BK=BK, + NC=NC, + ) + dq = dq2 + dk = dk2 + db = db2.sum(0).add_(db) + dg = chunk_local_cumsum( + dg2, + chunk_size=chunk_size, + reverse=True, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + ) + + return dq, dk, db, dg + + +def chunk_kda_fwd_inter_solve_fused( + q, + k, + gk, + beta, + Aqk, + Akk_diag, + Akk, + scale, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, +): + B, T, H, K = k.shape + assert K 
<= 256
+    BT = chunk_size
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+    BC = 16
+
+    grid = (NT, B * H)
+    chunk_kda_fwd_kernel_inter_solve_fused[grid](
+        q=q,
+        k=k,
+        g=gk,
+        beta=beta,
+        Aqk=Aqk,
+        Akk_diag=Akk_diag,
+        Akk=Akk,
+        scale=scale,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        K=K,
+        BT=BT,
+        BC=BC,
+    )
diff --git a/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py
new file mode 100644
index 0000000000..1dba202821
--- /dev/null
+++ b/examples/kda/FLA_KDA/fla_chunk_intra_token_parallel.py
@@ -0,0 +1,168 @@
+# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang
+# Token-parallel implementation of the KDA intra-chunk kernel
+
+import torch
+import triton
+import triton.language as tl
+
+from .fla_utils import exp2, autotune_cache_kwargs
+
+
+@triton.heuristics(
+    {
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
+@triton.autotune(
+    configs=[triton.Config({"BH": BH}, num_warps=num_warps) for BH in [1, 2, 4, 8] for num_warps in [1, 2, 4, 8]],
+    key=["K", "H"],
+    **autotune_cache_kwargs,
+)
+@triton.jit(do_not_specialize=["T", "N"])
+def chunk_kda_fwd_kernel_intra_token_parallel(
+    q,
+    k,
+    g,
+    beta,
+    Aqk,
+    Akk,
+    scale,
+    cu_seqlens,
+    N,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    BT: tl.constexpr,
+    BC: tl.constexpr,
+    BH: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_tg, i_hg = tl.program_id(0), tl.program_id(1)
+
+    if IS_VARLEN:
+        i_n = 0
+        left, right = 0, N
+
+        # Unrolled binary search over cu_seqlens for the sequence containing this token.
+        # Triton cannot early-exit a data-dependent loop, so the trip count is fixed:
+        # 20 iterations cover up to 2^20 (~1M) sequences, which is enough in practice.
+        for _ in range(20):
+            if left < right:
+                mid = (left + right) // 2
+                if i_tg < tl.load(cu_seqlens + mid + 1).to(tl.int32):
+                    right = mid
+                else:
+                    left = mid + 1
+        i_n = left
+
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+        i_t = i_tg - bos
+    else:
+        bos = (i_tg // T) * T
+        i_t = i_tg % T
+
+    if i_t >= T:
+        return
+
+    i_c = i_t // BT  # chunk index
+    i_s = (i_t % BT) // BC  # sub-chunk index
+    i_tc = i_c * BT  # first token of the chunk
+    i_ts = i_tc + i_s * BC  # first token of the sub-chunk
+
+    q += bos * H * K
+    k += bos * H * K
+    g += bos * H * K
+    Aqk += bos * H * BT
+    Akk += bos * H * BC
+    beta += bos * H
+
+    BK: tl.constexpr = triton.next_power_of_2(K)
+    o_h = tl.arange(0, BH)
+    o_k = tl.arange(0, BK)
+    m_h = (i_hg * BH + o_h) < H
+    m_k = o_k < K
+
+    p_q = tl.make_block_ptr(q + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0))
+    p_k = tl.make_block_ptr(k + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0))
+    p_g = tl.make_block_ptr(g + i_t * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0))
+    p_beta = tl.make_block_ptr(beta + i_t * H, (H,), (1,), (i_hg * BH,), (BH,), (0,))
+    # [BH, BK]
+    b_q = tl.load(p_q, boundary_check=(0, 1)).to(tl.float32)
+    b_k = tl.load(p_k, boundary_check=(0, 1)).to(tl.float32)
+    b_g = tl.load(p_g, boundary_check=(0, 1)).to(tl.float32)
+    b_k = b_k * tl.load(p_beta, boundary_check=(0,)).to(tl.float32)[:, None]
+
+    for j in range(i_ts, min(i_t + 1, min(T, i_ts + BC))):
+        p_kj = tl.make_block_ptr(k + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0))
+        p_gj = tl.make_block_ptr(g + j * H * K, (H, K), (K, 1), (i_hg * BH, 0), (BH, BK), (1, 0))
+        # [BH, BK]
+        b_kj = tl.load(p_kj, boundary_check=(0, 1)).to(tl.float32)
+ 
b_gj = tl.load(p_gj, boundary_check=(0, 1)).to(tl.float32) + + b_kgj = b_kj * exp2(b_g - b_gj) + + b_kgj = tl.where(m_k[None, :], b_kgj, 0.0) + # [BH] + b_Aqk = tl.sum(b_q * b_kgj, axis=1) * scale + b_Akk = tl.sum(b_k * b_kgj, axis=1) * tl.where(j < i_t, 1.0, 0.0) + + tl.store(Aqk + i_t * H * BT + (i_hg * BH + o_h) * BT + j % BT, b_Aqk.to(Aqk.dtype.element_ty), mask=m_h) + tl.store(Akk + i_t * H * BC + (i_hg * BH + o_h) * BC + j - i_ts, b_Akk.to(Akk.dtype.element_ty), mask=m_h) + + +def chunk_kda_fwd_intra_token_parallel( + q: torch.Tensor, + k: torch.Tensor, + gk: torch.Tensor, + beta: torch.Tensor, + Aqk: torch.Tensor, + Akk: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + sub_chunk_size: int = 16, +) -> None: + """ + Token-parallel implementation: each token gets its own thread block. + Supports both fixed-length and variable-length sequences. + Reduces wasted computation on padding. + + Writes directly to Aqk and Akk tensors (in-place). + + Args: + q: [B, T, H, K] + k: [B, T, H, K] + gk: [B, T, H, K] cumsum of gates + beta: [B, T, H] + Aqk: [B, T, H, BT] output tensor to write to + Akk: [B, T, H, BC] output tensor for diagonal blocks (fp32) + scale: attention scale + chunk_size: BT (default 64) + sub_chunk_size: BC (default 16) + """ + B, T, H, K = q.shape + N = len(cu_seqlens) - 1 if cu_seqlens is not None else B + BT = chunk_size + BC = sub_chunk_size + + def grid(meta): + return (B * T, triton.cdiv(H, meta["BH"])) + + chunk_kda_fwd_kernel_intra_token_parallel[grid]( + q=q, + k=k, + g=gk, + beta=beta, + Aqk=Aqk, + Akk=Akk, + scale=scale, + cu_seqlens=cu_seqlens, + N=N, + T=T, + H=H, + K=K, + BT=BT, + BC=BC, + ) + return Aqk, Akk diff --git a/examples/kda/FLA_KDA/fla_chunk_o.py b/examples/kda/FLA_KDA/fla_chunk_o.py new file mode 100644 index 0000000000..c29db9508f --- /dev/null +++ b/examples/kda/FLA_KDA/fla_chunk_o.py @@ -0,0 +1,546 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + + +from .fla_utils import prepare_chunk_indices, exp, exp2, autotune_cache_kwargs, check_shared_mem + + +BK_LIST = [32, 64] if check_shared_mem() else [16, 32] +BV_LIST = [64, 128] if check_shared_mem("ampere") else [16, 32] + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64] + for BV in [64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_fwd_kernel_o( + q, + v, + g, + h, + o, + A, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_EXP2: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :] + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + 
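+        # inter-chunk contribution: (q * scale * exp(g)) @ h pulls in the state
+        # carried over from all previous chunks, one key tile at a time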
p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_g = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_h = tl.make_block_ptr(h + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_q = (b_q * scale).to(b_q.dtype) + # [BT, BK] + b_g = tl.load(p_g, boundary_check=(0, 1)) + # [BT, BK] + if USE_EXP2: + b_qg = (b_q * exp2(b_g)).to(b_q.dtype) + else: + b_qg = (b_q * exp(b_g)).to(b_q.dtype) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + # works but dkw, owing to divine benevolence + # [BT, BV] + if i_k >= 0: + b_o += tl.dot(b_qg, b_h.to(b_qg.dtype)) + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + # [BT, BT] + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_A = tl.where(m_s, b_A, 0.0).to(b_v.dtype) + b_o += tl.dot(b_A, b_v) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BK_LIST + for BV in BV_LIST + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_dv( + k, + g, + A, + do, + dh, + dv, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + + b_A = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], b_A, 0.0) + # (SY 09/17) important to disallow tf32 here to maintain a good precision. 
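+    # b_A was loaded transposed and masked to the upper triangle, so this dot
+    # produces the intra-chunk part of dv with full fp32 accumulation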
+ b_dv = tl.dot(b_A, b_do.to(b_A.dtype), allow_tf32=False) + + for i_k in range(tl.cdiv(K, BK)): + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gk = tl.make_block_ptr(g + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_gn = g + (bos + min(i_t * BT + BT, T) - 1) * H * K + i_h * K + o_k + p_dh = tl.make_block_ptr(dh + (i_tg * H + i_h) * K * V, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + b_gn = exp(tl.load(p_gn, mask=m_k, other=0)[None, :] - b_gk) + b_k = (b_k * b_gn).to(b_k.dtype) + # [BT, BV] + # (SY 09/17) it is ok to have bf16 interchunk gradient contribution here + b_dv += tl.dot(b_k, b_dh.to(b_k.dtype)) + + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.heuristics( + { + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps) for BK in BK_LIST for BV in BV_LIST for num_warps in [2, 4, 8]], + key=["BT"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gla_bwd_kernel_inter( + q, + k, + v, + g, + h, + do, + dh, + dq, + dk, + dq2, + dk2, + dg, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_k, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + o_k = i_k * BK + tl.arange(0, BK) + m_k = o_k < K + + q += (bos * H + i_h) * K + k += (bos * H + i_h) * K + v += (bos * H + i_h) * V + g += (bos * H + i_h) * K + h += (i_tg * H + i_h) * K * V + do += (bos * H + i_h) * V + dh += (i_tg * H + i_h) * K * V + dq += (bos * H + i_h) * K + dk += (bos * H + i_h) * K + dq2 += (bos * H + i_h) * K + dk2 += (bos * H + i_h) * K + dg += (bos * H + i_h) * K + + p_gk = tl.make_block_ptr(g, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk = tl.load(p_gk, boundary_check=(0, 1)) + p_gn = g + (min(T, i_t * BT + BT) - 1) * H * K + o_k + b_gn = tl.load(p_gn, mask=m_k, other=0) + b_dq = tl.zeros([BT, BK], dtype=tl.float32) + b_dk = tl.zeros([BT, BK], dtype=tl.float32) + b_dgk = tl.zeros([BK], dtype=tl.float32) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_h = tl.make_block_ptr(h, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + p_dh = tl.make_block_ptr(dh, (V, K), (1, V), (i_v * BV, i_k * BK), (BV, BK), (0, 1)) + # [BT, BV] + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_do = tl.load(p_do, boundary_check=(0, 1)) + # [BV, BK] + b_h = tl.load(p_h, boundary_check=(0, 1)) + b_dh = tl.load(p_dh, boundary_check=(0, 1)) + + # [BK] + b_dgk += tl.sum(b_h * b_dh, axis=0) + # [BT, BK] + b_dq += tl.dot(b_do, 
b_h.to(b_do.dtype)) + b_dk += tl.dot(b_v, b_dh.to(b_v.dtype)) + + b_dgk *= exp(b_gn) + b_dq *= scale + b_dq = b_dq * exp(b_gk) + b_dk = b_dk * exp(b_gn[None, :] - b_gk) + + p_q = tl.make_block_ptr(q, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dq = tl.make_block_ptr(dq, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_q = tl.load(p_q, boundary_check=(0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_dgk += tl.sum(b_dk * b_k, axis=0) + b_dq += tl.load(p_dq, boundary_check=(0, 1)) + b_dk += tl.load(p_dk, boundary_check=(0, 1)) + b_dg = b_q * b_dq - b_k * b_dk + # tl.debug_barrier() + b_dg = b_dg - tl.cumsum(b_dg, axis=0) + tl.sum(b_dg, axis=0)[None, :] + b_dgk[None, :] + # Buggy due to strange triton compiler issue. + # m_s = tl.where(tl.arange(0, BT)[:, None] <= tl.arange(0, BT)[None, :], 1., 0.) + # b_dg = tl.dot(m_s, b_dg, allow_tf32=False) + b_dgk[None, :] + p_dq = tl.make_block_ptr(dq2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk2, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + tl.store(p_dq, b_dq.to(p_dq.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dk, b_dk.to(p_dk.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg, b_dg.to(p_dg.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gla_fwd_o_gk( + q: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + A: torch.Tensor, + h: torch.Tensor, + scale: float, + cu_seqlens: torch.LongTensor = None, + chunk_size: int = 64, + chunk_indices: torch.LongTensor = None, + use_exp2: bool = False, +): + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + + if chunk_indices is None and cu_seqlens is not None: + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_gla_fwd_kernel_o[grid]( + q=q, + v=v, + g=g, + h=h, + o=o, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + scale=scale, + T=T, + H=H, + K=K, + V=V, + BT=BT, + USE_EXP2=use_exp2, + ) + return o + + +NUM_WARPS = [2, 4] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_G_GAMMA": lambda args: args["g_gamma"] is not None, + "USE_A": lambda args: args["A"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in NUM_WARPS for num_stages in [2, 3, 4]], + key=["H", "K", "V", "BT", "BK", "BV", "USE_G"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_bwd_kernel_dv_local( + q, + k, + g, + g_gamma, + A, + do, + dv, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_G_GAMMA: tl.constexpr, + USE_A: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + 
i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+
+    # offset calculation
+    q += (bos * H + i_h) * K
+    k += (bos * H + i_h) * K
+    do += (bos * H + i_h) * V
+    dv += (bos * H + i_h) * V
+
+    if USE_A:
+        p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1))
+        b_A = tl.load(p_A, boundary_check=(0, 1))
+
+    o_t = i_t * BT + tl.arange(0, BT)
+    m_t = o_t < T
+    m_A = (o_t[:, None] <= o_t[None, :]) & (m_t[:, None] & m_t)
+    b_A = tl.where(m_A, b_A, 0).to(do.dtype.element_ty)
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_dv = tl.make_block_ptr(dv, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+        b_dv = tl.dot(b_A.to(b_do.dtype), b_do)
+        tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1))
+
+
+def chunk_bwd_dv_local(
+    q: torch.Tensor,
+    k: torch.Tensor,
+    do: torch.Tensor,
+    g: torch.Tensor = None,
+    g_gamma: torch.Tensor = None,
+    A: torch.Tensor = None,
+    scale: float = None,
+    cu_seqlens: torch.LongTensor = None,
+    chunk_size: int = 64,
+    chunk_indices: torch.LongTensor = None,
+) -> torch.Tensor:
+    B, T, H, K, V = *k.shape, do.shape[-1]
+    BT = chunk_size
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
+    # H100 can use a larger tile; Ampere-class GPUs get 64, older GPUs 32
+    if check_shared_mem("hopper", k.device.index):
+        CONST_TILING = 128
+    elif check_shared_mem("ampere", k.device.index):
+        CONST_TILING = 64
+    else:
+        CONST_TILING = 32
+    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
+    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+    dv = torch.empty_like(do)
+    grid = (NT, B * H)
+    chunk_bwd_kernel_dv_local[grid](
+        q=q,
+        k=k,
+        g=g,
+        g_gamma=g_gamma,
+        A=A,
+        do=do,
+        dv=dv,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        scale=scale,
+        T=T,
+        H=H,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+    )
+    return dv
+
+
+@triton.heuristics(
+    {
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [1, 2, 4, 8] for num_stages in [2, 3, 4]],
+    key=["BV", "BT"],
+    **autotune_cache_kwargs,
+)
+@triton.jit(do_not_specialize=["T"])
+def chunk_gla_bwd_kernel_dA(
+    v,
+    do,
+    dA,
+    cu_seqlens,
+    chunk_indices,
+    scale,
+    T,
+    H: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    T = eos - bos
+
+    b_dA = tl.zeros([BT, BT], dtype=tl.float32)
+    for i_v in range(tl.cdiv(V, BV)):
+        p_do = tl.make_block_ptr(do + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (V, T), (1, H * V), (i_v * BV, i_t * BT), (BV, BT), (0, 1))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_do = tl.load(p_do, boundary_check=(0, 1))
+
+        b_dA += tl.dot(b_do, b_v)
+
+    p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+    m_s = tl.arange(0, BT)[:, None] >= tl.arange(0, BT)[None, :]
+    b_dA = tl.where(m_s, b_dA * scale, 0.0)
+    tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1))
+
+
+def chunk_gla_bwd_dA(
+    v: torch.Tensor,
+    do: torch.Tensor,
+    scale: float,
+    cu_seqlens: torch.LongTensor = None,
+    chunk_size: int = 64,
+    chunk_indices: torch.LongTensor = None,
+):
+    B, T, H, V = v.shape
+    BT = chunk_size
+
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    BV = min(64, triton.next_power_of_2(V))
+
+    dA = v.new_empty(B, T, H, BT, dtype=torch.float32)
+    grid = (NT, B * H)
+    chunk_gla_bwd_kernel_dA[grid](
+        v=v,
+        do=do,
+        dA=dA,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        scale=scale,
+        T=T,
+        H=H,
+        V=V,
+        BT=BT,
+        BV=BV,
+    )
+    return dA
diff --git a/examples/kda/FLA_KDA/fla_utils.py b/examples/kda/FLA_KDA/fla_utils.py
new file mode 100644
index 0000000000..b278aec909
--- /dev/null
+++ b/examples/kda/FLA_KDA/fla_utils.py
@@ -0,0 +1,240 @@
+import contextlib
+import functools
+import inspect
+import os
+import warnings
+from collections.abc import Callable
+from typing import Any
+from packaging import version
+from enum import Enum
+
+import torch
+import triton
+import triton.language.extra.libdevice as tldevice
+
+
+device = "cuda"
+device_torch_lib = getattr(torch, device)
+
+exp = tldevice.fast_expf
+exp2 = tldevice.exp2
+log = tldevice.fast_logf
+log2 = tldevice.fast_log2f
+
+IS_NVIDIA_HOPPER = torch.cuda.is_available() and ("NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9)
+USE_CUDA_GRAPH = os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1"
+
+
+FLA_CACHE_RESULTS = os.getenv("FLA_CACHE_RESULTS", "1") == "1"
+SUPPORTS_AUTOTUNE_CACHE = "cache_results" in inspect.signature(triton.autotune).parameters
+autotune_cache_kwargs = {"cache_results": FLA_CACHE_RESULTS} if SUPPORTS_AUTOTUNE_CACHE else {}
+
+
+# error-checking helpers for comparing kernel outputs against a reference
+def get_abs_err(x, y):
+    return (x.detach() - y.detach()).flatten().abs().max().item()
+
+
+def get_err_ratio(x, y):
+    err = (x.detach() - y.detach()).flatten().square().mean().sqrt().item()
+    base = (x.detach()).flatten().square().mean().sqrt().item()
+    return err / (base + 1e-8)
+
+
+def assert_close(prefix, ref, tri, ratio, warning=False, err_atol=1e-6):
+    abs_atol = get_abs_err(ref, tri)
+    msg = f"{prefix:>16} diff: {abs_atol:.6f} ratio: {get_err_ratio(ref, tri):.6f}"
+    print(msg)
+    error_rate = get_err_ratio(ref, tri)
+    if abs_atol <= err_atol:
+        return
+    if warning or (error_rate < 0.01 or abs_atol <= 0.3):
+        if error_rate > ratio:
+            warnings.warn(msg, stacklevel=2)
+    else:
+        assert error_rate < ratio, msg
+
+
+def tensor_cache(
+    fn: Callable[..., torch.Tensor],
+) -> Callable[..., torch.Tensor]:
+    """
+    A decorator that caches the most recent result of a function with tensor inputs.
+
+    This decorator will store the output of the decorated function for the most recent set of input tensors.
+    If the function is called again with the same input tensors, it will return the cached result.
+
+    Args:
+        fn (Callable[..., torch.Tensor]):
+            The function to be decorated. It should take tensor inputs and return tensor outputs.
+
+    Returns:
+        Callable[..., torch.Tensor]:
+            A wrapped version of the input function with single-entry caching.
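+
+    A minimal illustrative sketch (the `lens` helper below is hypothetical):
+        >>> @tensor_cache
+        ... def lens(cu_seqlens: torch.Tensor) -> torch.Tensor:
+        ...     return torch.diff(cu_seqlens)
+        >>> cu = torch.tensor([0, 4, 9])
+        >>> lens(cu) is lens(cu)  # second call is served from the cache
+        True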
+ """ + last_args: tuple | None = None + last_kwargs: dict | None = None + last_result: Any = None + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal last_args, last_kwargs, last_result + + if ( + last_args is not None + and last_kwargs is not None + and len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + return last_result + + result = fn(*args, **kwargs) + last_args, last_kwargs, last_result = args, kwargs, result + return result + + return wrapper + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return torch.diff(cu_seqlens) + + +@tensor_cache +def prepare_chunk_indices( + cu_seqlens: torch.LongTensor, + chunk_size: int, +) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +# @functools.cache +# def get_multiprocessor_count(tensor_idx: int = 0) -> int: +# try: +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['multiprocessor_count'] +# except BaseException: +# # Maybe we use a NPU device. +# if triton.runtime.driver.active.get_current_target().backend == 'npu': +# return triton.runtime.driver.active.utils.get_device_properties(tensor_idx)['num_vectorcore'] +# else: +# return 1 +@functools.cache +def get_multiprocessor_count(tensor_idx: int = 0) -> int: + """ + Compatible across Triton versions: + - 2.0.x + - 2.1.0 + - 2.2.x and above + Supports CUDA and NPU. + """ + + # ---- Try the newer Triton 2.2+ API ---- + try: + drv = triton.runtime.driver.active + props = drv.utils.get_device_properties(tensor_idx) + return props.get("multiprocessor_count") or props.get("num_vectorcore") or 1 + except Exception: + pass + + # ---- Fallback: Triton 2.0 / 2.1 API ---- + try: + cuda = triton.runtime.driver.CudaDriver + dev = cuda.get_current_device() + props = cuda.get_device_properties(dev) + return props.get("multiprocessor_count", 1) + except Exception: + pass + + return 1 + + +def input_guard( + fn: Callable[..., torch.Tensor], +) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. 
+ """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = custom_device_ctx(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def check_pytorch_version(version_s: str = "2.4") -> bool: + return version.parse(torch.__version__) >= version.parse(version_s) + + +if check_pytorch_version("2.4"): + device = "cuda" + autocast_custom_fwd = functools.partial(torch.amp.custom_fwd, device_type=device) + autocast_custom_bwd = functools.partial(torch.amp.custom_bwd, device_type=device) + + def custom_device_ctx(index: int): + return device_torch_lib.device(index) +else: + assert device == "cuda", "Only cuda device is supported for PyTorch version < 2.4.0." + autocast_custom_fwd = device_torch_lib.amp.custom_fwd + autocast_custom_bwd = device_torch_lib.amp.custom_bwd + + def custom_device_ctx(index: int): + return torch.cuda.device(index) + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/examples/kda/FLA_KDA/fla_wy_fast.py b/examples/kda/FLA_KDA/fla_wy_fast.py new file mode 100644 index 0000000000..a042c2a5fe --- /dev/null +++ b/examples/kda/FLA_KDA/fla_wy_fast.py @@ -0,0 +1,312 @@ +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import torch +import triton +import triton.language as tl + +from .fla_utils import prepare_chunk_indices, exp2, autotune_cache_kwargs + + +@triton.heuristics( + { + "STORE_QG": lambda args: args["qg"] is not None, + "STORE_KG": lambda args: args["kg"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"DOT_PRECISION": DOT_PRECISION}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + for DOT_PRECISION in (["tf32x3", "ieee"]) + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], + **autotune_cache_kwargs, +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + q, + k, + qg, + kg, + v, + beta, + w, + u, + A, + gk, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + STORE_QG: tl.constexpr, + STORE_KG: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, +): + i_t, i_bh = 
tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32)
+        T = eos - bos
+    else:
+        bos, eos = i_b * T, i_b * T + T
+    p_b = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,))
+    b_b = tl.load(p_b, boundary_check=(0,))
+
+    p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0))
+    b_A = tl.load(p_A, boundary_check=(0, 1))
+
+    for i_v in range(tl.cdiv(V, BV)):
+        p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0))
+        b_v = tl.load(p_v, boundary_check=(0, 1))
+        b_vb = (b_v * b_b[:, None]).to(b_v.dtype)
+        b_u = tl.dot(b_A, b_vb, input_precision=DOT_PRECISION)
+        tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1))
+
+    for i_k in range(tl.cdiv(K, BK)):
+        p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_k = tl.load(p_k, boundary_check=(0, 1))
+        b_kb = b_k * b_b[:, None]  # scale k by beta
+
+        p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+        b_gk = tl.load(p_gk, boundary_check=(0, 1))
+        b_kb *= exp2(b_gk)
+        if STORE_QG:
+            p_q = tl.make_block_ptr(q + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+            p_qg = tl.make_block_ptr(qg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+            b_q = tl.load(p_q, boundary_check=(0, 1))
+            b_qg = b_q * exp2(b_gk)
+            tl.store(p_qg, b_qg.to(p_qg.dtype.element_ty), boundary_check=(0, 1))
+        if STORE_KG:
+            last_idx = min(i_t * BT + BT, T) - 1
+            o_k = i_k * BK + tl.arange(0, BK)
+            m_k = o_k < K
+            b_gn = tl.load(gk + ((bos + last_idx) * H + i_h) * K + o_k, mask=m_k, other=0.0)  # gate at the last position of the chunk
+            b_kg = b_k * tl.where((i_t * BT + tl.arange(0, BT) < T)[:, None], exp2(b_gn[None, :] - b_gk), 0)
+            p_kg = tl.make_block_ptr(kg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0))
+            tl.store(p_kg, b_kg.to(p_kg.dtype.element_ty), boundary_check=(0, 1))
+
+        b_w = tl.dot(b_A, b_kb.to(b_k.dtype))
+        tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1))
+
+
+@triton.heuristics(
+    {
+        "IS_VARLEN": lambda args: args["cu_seqlens"] is not None,
+    }
+)
+@triton.autotune(
+    configs=[triton.Config({}, num_warps=num_warps, num_stages=num_stages) for num_warps in [2, 4] for num_stages in [2, 3, 4]],
+    key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"],
+    **autotune_cache_kwargs,
+)
+@triton.jit(do_not_specialize=["T"])
+def prepare_wy_repr_bwd_kernel(
+    k,
+    v,
+    beta,
+    gk,
+    A,
+    dA,
+    dw,
+    du,
+    dk,
+    dk2,
+    dv,
+    db,
+    dg,
+    dg2,
+    cu_seqlens,
+    chunk_indices,
+    T,
+    H: tl.constexpr,
+    K: tl.constexpr,
+    V: tl.constexpr,
+    BT: tl.constexpr,
+    BK: tl.constexpr,
+    BV: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
+):
+    i_t, i_bh = tl.program_id(0), tl.program_id(1)
+    i_b, i_h = i_bh // H, i_bh % H
+    if IS_VARLEN:
+        i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32)
+        bos, eos = tl.load(cu_seqlens + 
i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + p_b = tl.make_block_ptr(beta + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_db = tl.make_block_ptr(db + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (BT, T), (1, H * BT), (0, i_t * BT), (BT, BT), (0, 1)) + + b_b = tl.load(p_b, boundary_check=(0,)) + b_db = tl.zeros([BT], dtype=tl.float32) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_dA = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr(k + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk = tl.make_block_ptr(dk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dk2 = tl.make_block_ptr(dk2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dw = tl.make_block_ptr(dw + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg = tl.make_block_ptr(dg + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_dg2 = tl.make_block_ptr(dg2 + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + + # [BT, BK] + b_k = tl.load(p_k, boundary_check=(0, 1)) + p_gk = tl.make_block_ptr(gk + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_gk_exp = exp2(tl.load(p_gk, boundary_check=(0, 1))) + b_kbg = b_k * b_b[:, None] * b_gk_exp + b_dw = tl.load(p_dw, boundary_check=(0, 1)) + + b_dA += tl.dot(b_dw, tl.trans(b_kbg).to(b_dw.dtype)) + b_dkbg = tl.dot(b_A, b_dw) + b_dk = b_dkbg * b_gk_exp * b_b[:, None] + tl.load(p_dk, boundary_check=(0, 1)) + b_db += tl.sum(b_dkbg * b_k * b_gk_exp, 1) + b_dg = b_kbg * b_dkbg + tl.load(p_dg, boundary_check=(0, 1)) + + tl.store(p_dk2, b_dk.to(p_dk2.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_dg2, b_dg.to(p_dg2.dtype.element_ty), boundary_check=(0, 1)) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_dv = tl.make_block_ptr(dv + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_du = tl.make_block_ptr(du + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_b[:, None]).to(b_v.dtype) + b_du = tl.load(p_du, boundary_check=(0, 1)) + b_dA += tl.dot(b_du, tl.trans(b_vb)) + b_dvb = tl.dot(b_A, b_du) + b_dv = b_dvb * b_b[:, None] + b_db += tl.sum(b_dvb * b_v, 1) + tl.store(p_dv, b_dv.to(p_dv.dtype.element_ty), boundary_check=(0, 1)) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_dA = tl.where(m_A, b_dA, 0) + b_dA = tl.dot(b_dA.to(b_A.dtype), b_A) + b_dA = tl.dot(b_A, b_dA.to(b_A.dtype)) + + b_dA = tl.where(m_A, -b_dA, 0) + + # if using gk, save dA first and handle dk in another kernel + p_dA = tl.make_block_ptr(dA + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_dA, b_dA.to(p_dA.dtype.element_ty), boundary_check=(0, 1)) + tl.store(p_db, b_db.to(p_db.dtype.element_ty), boundary_check=(0,)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + A: torch.Tensor, + q: torch.Tensor = None, + gk: torch.Tensor = None, + cu_seqlens: torch.LongTensor = None, + chunk_indices: torch.LongTensor = None, +) -> tuple[torch.Tensor, 
torch.Tensor, torch.Tensor, torch.Tensor]:
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    BT = A.shape[-1]
+    BK = 64
+    BV = 64
+
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+
+    w = torch.empty_like(k)
+    u = torch.empty_like(v)
+    qg = torch.empty_like(q) if q is not None else None
+    kg = torch.empty_like(k) if gk is not None else None
+    recompute_w_u_fwd_kernel[(NT, B * H)](
+        q=q,
+        k=k,
+        qg=qg,
+        kg=kg,
+        v=v,
+        beta=beta,
+        w=w,
+        u=u,
+        A=A,
+        gk=gk,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+    )
+    return w, u, qg, kg
+
+
+def prepare_wy_repr_bwd(
+    k: torch.Tensor,
+    v: torch.Tensor,
+    beta: torch.Tensor,
+    gk: torch.Tensor,
+    A: torch.Tensor,
+    dk: torch.Tensor,
+    dw: torch.Tensor,
+    du: torch.Tensor,
+    dg: torch.Tensor,
+    cu_seqlens: torch.LongTensor = None,
+    chunk_indices: torch.LongTensor = None,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    B, T, H, K, V = *k.shape, v.shape[-1]
+    BT = 64
+    if chunk_indices is None and cu_seqlens is not None:
+        chunk_indices = prepare_chunk_indices(cu_seqlens, BT)
+    NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices)
+    CONST_TILING = 64
+    BK = min(max(triton.next_power_of_2(K), 16), CONST_TILING)
+    BV = min(max(triton.next_power_of_2(V), 16), CONST_TILING)
+
+    dk2 = torch.empty_like(dk, dtype=torch.float)
+    dv = torch.empty_like(v)
+    dg2 = torch.empty_like(gk, dtype=torch.float)
+    dA = torch.empty_like(A, dtype=torch.float)
+    db = torch.empty_like(beta, dtype=torch.float)
+    prepare_wy_repr_bwd_kernel[(NT, B * H)](
+        k=k,
+        v=v,
+        beta=beta,
+        gk=gk,
+        A=A,
+        dA=dA,
+        dw=dw,
+        du=du,
+        dk=dk,
+        dk2=dk2,
+        dv=dv,
+        db=db,
+        dg=dg,
+        dg2=dg2,
+        cu_seqlens=cu_seqlens,
+        chunk_indices=chunk_indices,
+        T=T,
+        H=H,
+        K=K,
+        V=V,
+        BT=BT,
+        BK=BK,
+        BV=BV,
+    )
+    dk = dk2
+    dg = dg2
+
+    return dk, dv, db, dg, dA
diff --git a/examples/kda/README.md b/examples/kda/README.md
new file mode 100644
index 0000000000..f445a9f097
--- /dev/null
+++ b/examples/kda/README.md
@@ -0,0 +1,40 @@
+# KDA kernel implementation with TileLang
+## Requirements
+- TileLang: 0.1.6.post2+cuda.git729e66ca
+- triton: 3.2.0
+- FLA: commit 9714c5 (used for comparison)
+
+We copy the needed files and functions from flash-linear-attention into `FLA_KDA/` so the comparison is easy to run.
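+
+## Reference semantics
+
+Per chunk of size `BT`, `recompute_w_u_fwd` in `FLA_KDA/fla_wy_fast.py` computes `u = A @ (v * beta)`, `w = A @ (k * beta * 2^gk)`, the decayed query `qg = q * 2^gk`, and `kg = k * 2^(g_last - gk)`, where `gk` is the chunk-local cumulative log2 gate and `g_last` is its value at the chunk's last token. Below is a minimal PyTorch sketch of the same computation for the fixed-length (non-varlen) case; the helper name `reference_w_u` and the reshape/permute layout are illustrative choices, not part of the FLA API, and it assumes `T` is divisible by `BT`.
+
+```python
+import torch
+
+
+def reference_w_u(q, k, v, beta, A, gk, BT):
+    # q, k, gk: [B, T, H, K]; v: [B, T, H, V]; beta: [B, T, H]; A: [B, T, H, BT]
+    B, T, H = k.shape[:3]
+    NT = T // BT
+
+    def chunks(x):
+        # [B, T, H, D] -> [B, NT, H, BT, D] so matmuls act on [BT, D] tiles
+        return x.view(B, NT, BT, H, -1).permute(0, 1, 3, 2, 4)
+
+    qc, kc, vc, gc = chunks(q), chunks(k), chunks(v), chunks(gk)
+    bc = beta.view(B, NT, BT, H).permute(0, 1, 3, 2)  # [B, NT, H, BT]
+    Ac = A.view(B, NT, BT, H, BT).permute(0, 1, 3, 2, 4)  # [B, NT, H, BT, BT]
+    u = Ac @ (vc * bc[..., None])  # A @ (v * beta)
+    w = Ac @ (kc * bc[..., None] * gc.exp2())  # A @ (k * beta * 2^gk)
+    qg = qc * gc.exp2()  # decay applied to q
+    g_last = gc[..., -1:, :]  # gate at the chunk's last token
+    kg = kc * (g_last - gc).exp2()  # rescale k to the chunk boundary
+
+    def flat(x):
+        # [B, NT, H, BT, D] -> [B, T, H, D]
+        return x.permute(0, 1, 3, 2, 4).reshape(B, T, H, -1)
+
+    return flat(w), flat(u), flat(qg), flat(kg)
+```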
diff --git a/examples/kda/chunk_bwd_dqkwg.py b/examples/kda/chunk_bwd_dqkwg.py new file mode 100644 index 0000000000..d3d4df4b44 --- /dev/null +++ b/examples/kda/chunk_bwd_dqkwg.py @@ -0,0 +1,274 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_inter import chunk_kda_bwd_dqkwg +from test_utils_kda import do_bench, compare_tensors + +import torch + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + gate_dtype, +): + BS = S // chunk_size + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + v_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + w = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + h = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + do = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + dh = torch.randn(B, BS, H, DK, DV, dtype=input_dtype).cuda() + + return q, k, v_new, w, g, h, dv, do, dh + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + gate_dtype, +): + dq = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dk = torch.randn(B, S, H, DK, dtype=torch.float32).cuda() + dw = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, dw, dg + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-4, -3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def chunk_bwd_dqkwg( + B, + S, + H, + DK, + DV, + scale, + chunk_size, + input_dtype, + gate_dtype, + block_DK=32, + block_DV=32, + threads=32, + num_stages=0, +): + block_S = chunk_size + BS = S // block_S + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + H_shape = (B, BS, H, DK, DV) + + @T.prim_func + def kernel( + Q: T.Tensor(K_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + V: T.Tensor(V_shape, dtype=input_dtype), + G: T.Tensor(K_shape, dtype=gate_dtype), + h: T.Tensor(H_shape, dtype=input_dtype), + dv: T.Tensor(V_shape, dtype=input_dtype), + DO: T.Tensor(V_shape, dtype=input_dtype), + Dh: T.Tensor(H_shape, dtype=input_dtype), + dq: T.Tensor(K_shape, dtype=T.float32), + dk: T.Tensor(K_shape, dtype=T.float32), + dw: T.Tensor(K_shape, dtype=gate_dtype), + dg: T.Tensor(K_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK), T.ceildiv(S, block_S), B * H, threads=threads) as (bk, bs, bbh): + bb, bh = bbh // H, bbh % H + chunk_last_idx = T.min(S, (bs + 1) * block_S) - 1 + + dgkn_fragment = T.alloc_fragment((block_DK), dtype=T.float32) + dgkn_fragment_tmp = T.alloc_fragment((block_DK,), dtype=T.float32) + dq_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dk_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dw_fragment = T.alloc_fragment((block_S, block_DK), dtype=T.float32) + dgk_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + + h_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dh_shared = 
T.alloc_shared((block_DK, block_DV), dtype=input_dtype) + dgkn_shared = T.alloc_shared((block_DK, block_DV), dtype=input_dtype) # d of last token in a chunk + V_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + DV_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + G_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) # chunk G + Gn_shared = T.alloc_shared((block_DK), dtype=input_dtype) # chunk last token G + Q_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, block_DK), dtype=input_dtype) + + dkkn_shared = T.alloc_shared((block_S, block_DK), dtype=T.float32) + pp_shared = T.alloc_shared((block_DK), dtype=T.float32) + + T.clear(dgkn_fragment) + T.clear(dq_fragment) + T.clear(dk_fragment) + T.clear(dw_fragment) + + T.copy(G[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], G_shared) + T.copy(G[bb, chunk_last_idx, bh, bk * block_DK : (bk + 1) * block_DK], Gn_shared) + + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(h[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], h_shared) + T.copy(Dh[bb, bs, bh, bk * block_DK : (bk + 1) * block_DK, i_v * block_DV : (i_v + 1) * block_DV], dh_shared) + T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared) + T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.copy(dv[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], DV_shared) + # += reduce_sum + for i_k1, i_v1 in T.Parallel(block_DK, block_DV): + dgkn_shared[i_k1, i_v1] = h_shared[i_k1, i_v1] * dh_shared[i_k1, i_v1] + T.reduce_sum(dgkn_shared, dgkn_fragment_tmp, dim=1, clear=True) # [block_DK] + for i_ks in T.Parallel(block_DK): + dgkn_fragment[i_ks] += dgkn_fragment_tmp[i_ks] + T.gemm(DO_shared, h_shared, dq_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(V_shared, dh_shared, dk_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + T.gemm(DV_shared, h_shared, dw_fragment, transpose_B=True, clear_accum=False) # [block_S, block_DK] + # chunk last token + for i_k0 in T.Parallel(block_DK): + dgkn_fragment[i_k0] = dgkn_fragment[i_k0] * T.exp2(Gn_shared[i_k0]) + + for i_s, i_k in T.Parallel(block_S, block_DK): + dw_fragment[i_s, i_k] = -dw_fragment[i_s, i_k] + dq_fragment[i_s, i_k] = dq_fragment[i_s, i_k] * scale * T.exp2(G_shared[i_s, i_k]) + dk_fragment[i_s, i_k] = dk_fragment[i_s, i_k] * T.exp2(Gn_shared[i_k] - G_shared[i_s, i_k]) + + T.copy(dw_fragment, dw[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dq_fragment, dq[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + T.copy(dk_fragment, dk[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + T.copy(Q[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], Q_shared) + T.copy(K[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK], K_shared) + + for i_s2, i_k2 in T.Parallel(block_S, block_DK): + dkkn_shared[i_s2, i_k2] = dk_fragment[i_s2, i_k2] * K_shared[i_s2, i_k2] + T.reduce_sum(dkkn_shared, pp_shared, dim=0, clear=True) + for i_k3 in T.Parallel(block_DK): + pp_shared[i_k3] += dgkn_fragment[i_k3] + + for i_s4, i_k4 in T.Parallel(block_S, block_DK): + 
dgk_shared[i_s4, i_k4] = ( + Q_shared[i_s4, i_k4] * dq_fragment[i_s4, i_k4] + - K_shared[i_s4, i_k4] * dk_fragment[i_s4, i_k4] + + T.if_then_else(chunk_last_idx == bs * block_S + i_s4, pp_shared[i_k4], 0.0) + ) + + T.copy(dgk_shared, dg[bb, bs * block_S : (bs + 1) * block_S, bh, bk * block_DK : (bk + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + gate_dtype, + qk_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + q, k, v_new, w, g, h, dv, do, dh = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, gate_dtype)) + + dq_ref, dk_ref, dw_ref, dg_ref = chunk_kda_bwd_dqkwg( + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + + dq, dk, dw, dg = prepare_output(B, S, H, DK, DV, chunk_size, getattr(torch, gate_dtype)) + kernel = chunk_bwd_dqkwg( + B=B, S=S, H=H, DK=DK, DV=DV, scale=scale, chunk_size=chunk_size, input_dtype=input_dtype, gate_dtype=gate_dtype + ) + dq, dk, dw, dg = kernel(q, k, v_new, g, h, dv, do, dh) + + compare_tensors("dq", dq_ref, dq) + compare_tensors("dk", dk_ref, dk) + compare_tensors("dw", dw_ref, dw) + compare_tensors("dg", dg_ref, dg) + + fla_time = do_bench( + chunk_kda_bwd_dqkwg, + q=q, + k=k, + v=v_new, + w=w, + g=g, + h=h, + dv=dv, + do=do, + dh=dh, + scale=scale, + ) + tilelang_time = do_bench(kernel, q, k, v_new, g, h, dv, do, dh) + print("fla_time:", fla_time) + print("tilelang_time:", tilelang_time) + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="float32", + gate_dtype="float32", # gate must be float32 + qk_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_dv.py b/examples/kda/chunk_bwd_dv.py new file mode 100644 index 0000000000..cdbe0a899c --- /dev/null +++ b/examples/kda/chunk_bwd_dv.py @@ -0,0 +1,150 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune +import sys # noqa: F401 + +from FLA_KDA.fla_chunk_o import chunk_bwd_dv_local +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + do_dtype, +): + q = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=do_dtype).cuda() + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + A = torch.randn(B, S, H, chunk_size, dtype=input_dtype).cuda() + return q, k, DO, A + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + output_dtype, +): + dv = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dv + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_bwd_kernel_dv_local( + B, + S, + H, + DV, + input_dtype, + output_dtype, + do_dtype, + chunk_size, + block_DV=128, + 
threads=128, + num_stages=1, +): + block_S = BS = chunk_size + DO_shape = (B, S, H, DV) + A_shape = (B, S, H, BS) + + @T.prim_func + def kernel( + DO: T.Tensor(DO_shape, dtype=do_dtype), + A: T.Tensor(A_shape, dtype=input_dtype), + dv: T.Tensor(DO_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh): + bb, bh = bbh // H, bbh % H + + A_shared = T.alloc_shared((BS, BS), dtype=do_dtype) + DO_shared = T.alloc_shared((BS, block_DV), dtype=do_dtype) + dv_fragment = T.alloc_fragment((BS, block_DV), dtype=T.float32) + dv_shared = T.alloc_shared((BS, block_DV), dtype=output_dtype) + + T.copy(A[bb, bs * BS : (bs + 1) * BS, bh, :], A_shared) + for i_s1, i_s2 in T.Parallel(BS, BS): + A_shared[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, A_shared[i_s1, i_s2], 0.0) + for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages): + T.copy(DO[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV], DO_shared) + T.gemm(A_shared, DO_shared, dv_fragment, transpose_A=True, clear_accum=True) # transpose_A: A^T + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv[bb, bs * BS : (bs + 1) * BS, bh, i_v * block_DV : (i_v + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + scale, + input_dtype, + do_dtype, + output_dtype, + chunk_size, +): + q, k, DO, A = prepare_input(B, S, H, DK, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype)) + dv_ref = chunk_bwd_dv_local(q, k, do=DO, A=A) + + dv_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, output_dtype)) + kernel = tilelang_chunk_bwd_kernel_dv_local( + B=B, + S=S, + H=H, + DV=DV, + input_dtype=input_dtype, + output_dtype=output_dtype, + do_dtype=do_dtype, + chunk_size=chunk_size, + ) + dv_tilelang = kernel(DO, A) + compare_tensors("dv", dv_ref, dv_tilelang) + + fla_time = do_bench(chunk_bwd_dv_local, q, k, do=DO, A=A) + tilelang_time = do_bench(kernel, DO, A) + print("fla_time: ", fla_time) + print("tilelang_time: ", tilelang_time) + + +def main(): + run_test( + B=1, + S=1024 * 8, # 32768 + H=64, + DK=128, + DV=128, + scale=1.0, + input_dtype="bfloat16", + do_dtype="float32", + output_dtype="bfloat16", + chunk_size=64, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_bwd_gla_dA.py b/examples/kda/chunk_bwd_gla_dA.py new file mode 100644 index 0000000000..913fa9171f --- /dev/null +++ b/examples/kda/chunk_bwd_gla_dA.py @@ -0,0 +1,147 @@ +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_o import chunk_gla_bwd_dA +from test_utils_kda import compare_tensors, do_bench + +import torch + +torch.random.manual_seed(1) + + +def prepare_input( + B, + S, + H, + DV, + chunk_size, + input_dtype, + do_dtype, +): + DO = torch.randn(B, S, H, DV, dtype=do_dtype).cuda() + V_new = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return DO, V_new + + +def prepare_output( + B, + S, + H, + DV, + chunk_size, + d_type, +): + dA = torch.empty(B, S, H, chunk_size, dtype=d_type).cuda() + return dA + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=5) +@tilelang.jit(out_idx=[-1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def 
tilelang_chunk_bwd_kernel_dA_local(
+    B,
+    S,
+    H,
+    DV,
+    scale,
+    input_dtype,
+    da_dtype,
+    do_dtype,
+    chunk_size,
+    block_DV=128,
+    threads=128,
+    num_stages=1,
+):
+    block_S = BS = chunk_size
+    DO_shape = (B, S, H, DV)
+    V_shape = (B, S, H, DV)
+    dA_shape = (B, S, H, BS)
+
+    @T.prim_func
+    def kernel(
+        DO: T.Tensor(DO_shape, dtype=do_dtype),
+        V: T.Tensor(V_shape, dtype=input_dtype),
+        dA: T.Tensor(dA_shape, dtype=da_dtype),
+    ):
+        with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
+            bb, bh = bbh // H, bbh % H
+            do_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            V_shared = T.alloc_shared((block_S, block_DV), dtype=do_dtype)
+            dA_fragment = T.alloc_fragment((block_S, block_S), dtype=T.float32)
+
+            T.clear(dA_fragment)
+            for i_v in T.Pipelined(T.ceildiv(DV, block_DV), num_stages=num_stages):
+                T.copy(DO[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], do_shared)
+                T.copy(V[bb, bs * block_S : (bs + 1) * block_S, bh, i_v * block_DV : (i_v + 1) * block_DV], V_shared)
+                T.gemm(do_shared, V_shared, dA_fragment, transpose_B=True)
+            for i_s1, i_s2 in T.Parallel(block_S, block_S):
+                dA_fragment[i_s1, i_s2] = T.if_then_else(i_s1 >= i_s2, dA_fragment[i_s1, i_s2] * scale, 0.0)  # keep only the lower-triangular part
+            T.copy(dA_fragment, dA[bb, bs * block_S : (bs + 1) * block_S, bh, 0:block_S])
+
+    return kernel
+
+
+def run_test(
+    B,
+    S,
+    H,
+    DK,
+    DV,
+    scale,
+    input_dtype,
+    do_dtype,
+    da_dtype,
+    chunk_size,
+):
+    DO, V_new = prepare_input(B, S, H, DV, chunk_size, getattr(torch, input_dtype), getattr(torch, do_dtype))
+    print(DO.dtype, V_new.dtype)
+    dA_ref = chunk_gla_bwd_dA(v=V_new, do=DO, scale=scale)
+
+    dA_tilelang = prepare_output(B, S, H, DV, chunk_size, getattr(torch, da_dtype))
+    kernel = tilelang_chunk_bwd_kernel_dA_local(
+        B=B,
+        S=S,
+        H=H,
+        DV=DV,
+        scale=scale,
+        input_dtype=input_dtype,
+        da_dtype=da_dtype,
+        do_dtype=do_dtype,
+        chunk_size=chunk_size,
+    )
+    dA_tilelang = kernel(DO, V_new)
+    compare_tensors("dA", dA_ref, dA_tilelang)
+    fla_time = do_bench(chunk_gla_bwd_dA, v=V_new, do=DO, scale=scale)
+    tilelang_time = do_bench(kernel, DO, V_new)
+    print("fla_time:", fla_time)
+    print("tilelang_time:", tilelang_time)
+
+
+def main():
+    run_test(
+        B=1,
+        S=1024 * 8,  # 8192
+        H=64,
+        DK=128,
+        DV=128,
+        scale=1.0,
+        input_dtype="bfloat16",
+        do_dtype="bfloat16",
+        da_dtype="float32",
+        chunk_size=64,
+    )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/examples/kda/chunk_bwd_intra.py b/examples/kda/chunk_bwd_intra.py
new file mode 100644
index 0000000000..a4aa4f9d43
--- /dev/null
+++ b/examples/kda/chunk_bwd_intra.py
@@ -0,0 +1,492 @@
+# Reference: FLA_KDA/fla_chunk_intra.py
+import tilelang
+import tilelang.language as T
+from tilelang.autotuner import autotune
+
+from FLA_KDA.fla_chunk_intra import chunk_kda_bwd_intra
+from FLA_KDA.cumsum import chunk_local_cumsum
+from test_utils_kda import compare_tensors, do_bench
+
+import torch
+
+torch.random.manual_seed(0)
+torch.set_printoptions(profile="full")
+
+
+def prepare_input(
+    B,
+    S,
+    H,
+    DK,
+    chunk_size,
+    input_dtype,
+    output_dtype,
+    accum_dtype,
+    gate_dtype,
+    state_dtype,
+):
+    BT = chunk_size
+    q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
+    k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda()
+    g = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda()
+    beta = torch.randn(B, S, H, dtype=input_dtype).cuda()
+
+    # dAqk and dAkk are gradients w.r.t. 
Aqk and Akk + # Shape: (B, S, H, BT) + dAqk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + dAkk = torch.randn(B, S, H, BT, dtype=input_dtype).cuda() + + # Initial gradients (will be updated by the kernel) + dq = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + dk = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + db = torch.randn(B, S, H, dtype=input_dtype).cuda() + dg = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + + return q, k, g, beta, dAqk, dAkk, dq, dk, db, dg + + +def prepare_output( + B, + S, + H, + DK, + chunk_size, + NK, + output_dtype, + gate_dtype, + state_dtype, +): + dq = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + dk = torch.empty(B, S, H, DK, dtype=output_dtype).cuda() + db = torch.empty(NK, B, S, H, dtype=output_dtype).cuda() + dg = torch.empty(B, S, H, DK, dtype=gate_dtype).cuda() + return dq, dk, db, dg + + +def get_configs(): + import itertools + + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3] + _configs = list(itertools.product(threads, num_stages)) + + configs = [{"threads": c[0], "num_stages": c[1]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=5, rep=5) +@tilelang.jit( + out_idx=[-4, -3, -2, -1], +) +def tilelang_chunk_bwd_intra( + # task config + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + # kernel config + block_DK, + block_BC=16, + threads=128, + num_stages=0, +): + BT = chunk_size + BC = block_BC # sub-chunk size, typically 16 + + NC = BT // BC # number of sub-chunks + NT = T.ceildiv(S, BT) + NK = T.ceildiv(DK, block_DK) # number of K blocks + + K_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + G_shape = (B, S, H, DK) + BT_shape = (B, S, H, BT) # for dAqk and dAkk + + dq_shape = (B, S, H, DK) + dk_shape = (B, S, H, DK) + db_shape = (B, S, H) + db2_shape = (NK, B, S, H) + dg_shape = (B, S, H, DK) + + @T.prim_func + def kernel( + # input + q: T.Tensor(K_shape, dtype=input_dtype), + k: T.Tensor(K_shape, dtype=input_dtype), + g: T.Tensor(G_shape, dtype=gate_dtype), + beta: T.Tensor(Beta_shape, dtype=input_dtype), + dAqk: T.Tensor(BT_shape, dtype=input_dtype), + dAkk: T.Tensor(BT_shape, dtype=input_dtype), + dq: T.Tensor(dq_shape, dtype=input_dtype), + dk: T.Tensor(dk_shape, dtype=input_dtype), + db: T.Tensor(db_shape, dtype=input_dtype), + dg: T.Tensor(dg_shape, dtype=gate_dtype), + # output + dq2: T.Tensor(dq_shape, dtype=output_dtype), + dk2: T.Tensor(dk_shape, dtype=output_dtype), + db2: T.Tensor(db2_shape, dtype=output_dtype), + dg2: T.Tensor(dg_shape, dtype=gate_dtype), + ): + with T.Kernel(T.ceildiv(DK, block_DK) * NC, NT, B * H, threads=threads) as (i_kc, i_t, i_bh): + i_k, i_i = i_kc // NC, i_kc % NC + bb, bh = i_bh // H, i_bh % H + + # actual sub-chunk index + i_ti = i_t * BT + i_i * BC + + # current sub-chunk data + q_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + beta_shared = T.alloc_shared((BC,), dtype=input_dtype) + g_current_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + gn_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) # last token's g in current sub-chunk + + dq_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dk_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + dg_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + + # Allocate fragments + dq2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + dk2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + 
dg2_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + db_fragment = T.alloc_fragment((BC,), dtype=accum_dtype) + + # Initialize fragments + T.clear(dq2_fragment) + T.clear(dk2_fragment) + T.clear(dg2_fragment) + T.clear(db_fragment) + + # Temporary shared memory for previous sub-chunks + k_prev_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_prev_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + dAqk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_prev_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragment for b_kg computation + kg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + + kj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + gkj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kgj_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + dAqk_col = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col = T.alloc_shared((BC,), dtype=input_dtype) + + # Load g, q, k for current sub-chunk + T.copy(q[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_shared) + T.copy(k[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_shared) + T.copy(g[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_current_shared) + T.copy(beta[bb, i_ti : i_ti + BC, bh], beta_shared) + + if i_i > 0: + chunk_first_idx = i_ti # chunk first token idx + + T.copy(g[bb, chunk_first_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_shared) # Get the first token's g value (b_gn) + + # Loop over previous sub-chunks (i_j from 0 to i_i-1) + # Since i_i is computed from i_kc % NC and NC is small, we can use conditional blocks + # Process each possible previous sub-chunk with conditional execution + for i_j in T.Pipelined(i_i, num_stages=num_stages): # i_j is index ofprevious sub_chunks + prev_ti = i_t * BT + i_j * BC + T.copy(k[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_prev_shared) + T.copy(g[bb, prev_ti : prev_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_prev_shared) + + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAqk_prev_shared) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_j * BC : (i_j + 1) * BC], dAkk_prev_shared) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kg_fragment[i_bc, i_k2] = k_prev_shared[i_bc, i_k2] * T.exp2(gn_shared[i_k2] - g_prev_shared[i_bc, i_k2]) + + T.gemm(dAqk_prev_shared, kg_fragment, dq2_fragment, clear_accum=False) + T.gemm(dAkk_prev_shared, kg_fragment, dk2_fragment, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gqn = T.exp2(g_current_shared[i_bc, i_k2] - gn_shared[i_k2]) + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] * gqn + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * gqn + + # Process current sub-chunk diagonal + loop_length = T.min(BC, S - i_t * BT - i_i * BC) + for j in T.Pipelined(loop_length, num_stages=num_stages): + token_j_idx = i_ti + j + + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gkj_shared) + T.copy(dAqk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAqk_col) + T.copy(dAkk[bb, i_ti : i_ti + BC, bh, i_i * BC + j], dAkk_col) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + kgj_fragment[i_bc, i_k2] = kj_shared[i_k2] * T.exp2(g_current_shared[i_bc, i_k2] - gkj_shared[i_k2]) + dq2_fragment[i_bc, i_k2] += T.if_then_else(i_bc >= j, dAqk_col[i_bc] * kgj_fragment[i_bc, i_k2], 0.0) + dk2_fragment[i_bc, i_k2] += 
T.if_then_else(i_bc >= j, dAkk_col[i_bc] * kgj_fragment[i_bc, i_k2], 0.0) + + # Compute b_db = sum(b_dk2 * b_k, dim=1) + dk2_k_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_k_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * k_shared[i_bc, i_k2] + T.reduce_sum(dk2_k_fragment, db_fragment, dim=1, clear=True) + + # b_dk2 *= b_b[:, None] + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] = dk2_fragment[i_bc, i_k2] * beta_shared[i_bc] + + # Compute b_dg2 = b_q * b_dq2 (before adding dq to dq2) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = q_shared[i_bc, i_k2] * dq2_fragment[i_bc, i_k2] + + # Load dq and compute b_dq2 = b_dq2 + b_dq + T.copy(dq[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dq_shared) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dq2_fragment[i_bc, i_k2] = dq2_fragment[i_bc, i_k2] + dq_shared[i_bc, i_k2] + + # # Store results + T.copy(dq2_fragment, dq2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + T.copy(db_fragment, db2[i_k, bb, i_ti : i_ti + BC, bh]) + + # Initialize dkt_fragment for processing subsequent sub-chunks and lower triangular part + dkt_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + T.clear(dkt_fragment) + + # Temporary shared memory for subsequent sub-chunks + q_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + k_next_shared = T.alloc_shared((BC, block_DK), dtype=input_dtype) + g_next_shared = T.alloc_shared((BC, block_DK), dtype=gate_dtype) + beta_next_shared = T.alloc_shared((BC,), dtype=input_dtype) + dAqk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + dAkk_next_shared = T.alloc_shared((BC, BC), dtype=input_dtype) + + # Temporary fragments for computation + gkn_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + qg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + kbg_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + kbg_shared = T.alloc_shared((BC, block_DK), dtype=accum_dtype) + dkt_temp_fragment = T.alloc_fragment((BC, block_DK), dtype=accum_dtype) + # T.use_swizzle(10) + + NC_actual = T.min(NC, T.ceildiv(S - i_t * BT, BC)) # Process subsequent sub-chunks (i_j from i_i+1 to NC-1) + if i_i < NC_actual - 1: + # Get the last token's g value in current sub-chunk + chunk_last_idx = T.min(S, i_ti + BC) - 1 + gn_last_shared = T.alloc_shared((block_DK,), dtype=gate_dtype) + T.copy(g[bb, chunk_last_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gn_last_shared) + + # Loop over subsequent sub-chunks + for i_j in T.Pipelined(i_i + 1, NC_actual, num_stages=num_stages): + i_tj = i_t * BT + i_j * BC + + T.copy(q[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], q_next_shared) + T.copy(k[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], k_next_shared) + T.copy(g[bb, i_tj : i_tj + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], g_next_shared) + T.copy(beta[bb, i_tj : i_tj + BC, bh], beta_next_shared) + + T.copy(dAqk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAqk_next_shared) # [BC, BC] need transpose + T.copy(dAkk[bb, i_tj : i_tj + BC, bh, i_i * BC : (i_i + 1) * BC], dAkk_next_shared) # [BC, BC] need transpose + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + # kbg = k * beta + kbg_fragment[i_bc, i_k2] = k_next_shared[i_bc, i_k2] * beta_next_shared[i_bc] + gkn_shared[i_bc, i_k2] = T.if_then_else( + i_tj + i_bc < S, T.exp2(g_next_shared[i_bc, i_k2] - gn_last_shared[i_k2]), 0.0 + ) + + # 
Compute qg and kbg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + qg_shared[i_bc, i_k2] = q_next_shared[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + kbg_shared[i_bc, i_k2] = kbg_fragment[i_bc, i_k2] * gkn_shared[i_bc, i_k2] + + # Accumulate: dkt += dAqk^T @ qg + dAkk^T @ kbg + # Use transpose_A=True because dAqk/dAkk are loaded in (T, BT) layout but we need (BT, T) for gemm + T.gemm(dAqk_next_shared, qg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=True) + T.gemm(dAkk_next_shared, kbg_shared, dkt_temp_fragment, transpose_A=True, clear_accum=False) + + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] + dkt_temp_fragment[i_bc, i_k2] + + # Scale dkt by exp2(gn_last - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + g_scale = T.exp2(gn_last_shared[i_k2] - g_current_shared[i_bc, i_k2]) + dkt_fragment[i_bc, i_k2] = dkt_fragment[i_bc, i_k2] * g_scale + + # Process lower triangular part of current sub-chunk diagonal + # This corresponds to j <= i_bc in the diagonal block + qj_shared = T.alloc_shared((block_DK,), dtype=T.float32) + kj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + gj_shared_lower = T.alloc_shared((block_DK,), dtype=T.float32) + bj_local = T.alloc_local((1), dtype=input_dtype) + dAqk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + dAkk_col_lower = T.alloc_shared((BC,), dtype=input_dtype) + + gkq_fragment = T.alloc_fragment((BC, block_DK), dtype=T.float32) + # dkt_lower_temp = T.alloc_fragment((BC, block_DK), dtype=T.float32) + kbj_fragment = T.alloc_fragment((block_DK,), dtype=T.float32) + + max_token_j_idx = T.min(S, i_ti + BC) + for j in T.Pipelined(BC, num_stages=num_stages): + token_j_idx = i_ti + j + + if token_j_idx < max_token_j_idx: + T.copy(q[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], qj_shared) # [BK] + T.copy(k[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], kj_shared_lower) + T.copy(g[bb, token_j_idx, bh, i_k * block_DK : (i_k + 1) * block_DK], gj_shared_lower) + + bj_local[0] = beta[bb, token_j_idx, bh] + T.copy(dAqk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAqk_col_lower) # [BC] + T.copy(dAkk[bb, token_j_idx, bh, i_i * BC : (i_i + 1) * BC], dAkk_col_lower) + + # Compute kbj = kj * bj + for i_k2 in T.Parallel(block_DK): + kbj_fragment[i_k2] = kj_shared_lower[i_k2] * bj_local[0] + # Compute gkq = exp2(gj - g_current) + for i_bc, i_k2 in T.Parallel(BC, block_DK): + gkq_fragment[i_bc, i_k2] = T.exp2(gj_shared_lower[i_k2] - g_current_shared[i_bc, i_k2]) + + # Accumulate: dkt += (dAkk * kbj + dAqk * qj) * gkq for i_bc <= j + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dkt_fragment[i_bc, i_k2] += T.if_then_else( + i_bc <= j, + (dAkk_col_lower[i_bc] * kbj_fragment[i_k2] + dAqk_col_lower[i_bc] * qj_shared[i_k2]) * gkq_fragment[i_bc, i_k2], + 0.0, + ) + + # Load dk and dg + T.copy(dk[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dk_shared) + T.copy(dg[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], dg_shared) + + # Update dg2: dg2 += (dk2 - dkt) * k + dg + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dg2_fragment[i_bc, i_k2] = ( + dg2_fragment[i_bc, i_k2] + + (dk2_fragment[i_bc, i_k2] - dkt_fragment[i_bc, i_k2]) * k_shared[i_bc, i_k2] + + dg_shared[i_bc, i_k2] + ) + + # Update dk2: dk2 += dk + dkt + for i_bc, i_k2 in T.Parallel(BC, block_DK): + dk2_fragment[i_bc, i_k2] += dk_shared[i_bc, i_k2] + dkt_fragment[i_bc, i_k2] + + # Store dk2 and dg2 + T.copy(dk2_fragment, dk2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : 
(i_k + 1) * block_DK]) + T.copy(dg2_fragment, dg2[bb, i_ti : i_ti + BC, bh, i_k * block_DK : (i_k + 1) * block_DK]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + threads=128, + num_stages=0, + cu_seqlens=None, + chunk_indices=None, +): + q, k, g, beta, dAqk, dAkk, dq, dk, db, dg = prepare_input( + B, + S, + H, + DK, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + # Reference implementation + dq_ref, dk_ref, db_ref, dg_ref = chunk_kda_bwd_intra( + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + block_DK = min(64, tilelang.math.next_power_of_2(DK)) + NK = (DK + block_DK - 1) // block_DK + # TileLang implementation + kernel = tilelang_chunk_bwd_intra( + B=B, + S=S, + H=H, + DK=DK, + input_dtype=input_dtype, + output_dtype=output_dtype, + accum_dtype=accum_dtype, + gate_dtype=gate_dtype, + state_dtype=state_dtype, + chunk_size=chunk_size, + block_DK=block_DK, + ) + + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = prepare_output( + B, S, H, DK, chunk_size, NK, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + dq_tilelang, dk_tilelang, db_tilelang, dg_tilelang = kernel(q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + db_tilelang = db_tilelang.sum(0).add_(db) + dg_tilelang = chunk_local_cumsum( + dg_tilelang, + chunk_size=chunk_size, + reverse=True, + ) + + compare_tensors("dq", dq_tilelang, dq_ref) + compare_tensors("dk", dk_tilelang, dk_ref) + compare_tensors("db", db_tilelang, db_ref) + compare_tensors("dg", dg_tilelang, dg_ref) + + fla_time = do_bench( + chunk_kda_bwd_intra, + q=q, + k=k, + g=g, + beta=beta, + dAqk=dAqk, + dAkk=dAkk, + dq=dq, + dk=dk, + db=db, + dg=dg, + ) + tilelang_time = do_bench(kernel, q, k, g, beta, dAqk, dAkk, dq, dk, db, dg) + print(f"Fla time: {fla_time}") + print(f"Tilelang time: {tilelang_time}") + + +def main(): + DK = 128 + run_test( + B=1, + S=8192, + H=8, + DK=DK, + input_dtype=T.float32, + output_dtype=T.float32, + accum_dtype=T.float32, + gate_dtype=T.float32, + state_dtype=T.float32, + chunk_size=64, + threads=128, + num_stages=0, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_bwd.py b/examples/kda/chunk_delta_bwd.py new file mode 100644 index 0000000000..8c22488ca4 --- /dev/null +++ b/examples/kda/chunk_delta_bwd.py @@ -0,0 +1,309 @@ +# Reference: fla/ops/common/chunk_delta_h.py +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_bwd_dhu +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import do_bench, compare_tensors + +import torch +import torch.nn.functional as F + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, +): + Q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() * 0.01 + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + # Note: G should be in logspace and do chunkwise cumsum + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, chunk_size) + + h0 = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dht 
= torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + dO = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() * 0.01 + + dv = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + return Q, K, W, G, h0, dht, dO, dv + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + gate_dtype, + state_dtype, +): + BS = S // chunk_size + dh = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + dh0 = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + dv2 = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return dh, dh0, dv2 + + +def get_configs(): + import itertools + + block_DV = [32, 64, 128] + threads = [32, 64, 128, 256] + num_stages = [0, 1, 2, 3, 4] + _configs = list(itertools.product(block_DV, threads, num_stages)) + + configs = [{"block_DV": c[0], "threads": c[1], "num_stages": c[2]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=10, rep=10) +@tilelang.jit(out_idx=[-3, -2, -1]) +def tilelang_chunk_gated_delta_rule_bwd_dhu( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + # kernel config + block_DV=64, + threads=256, + num_stages=0, +): + block_S = chunk_size + # Should support cu_seqlen + BS = S // block_S + + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + W_shape = (B, S, H, DK) + G_shape = (B, S, H, DK) + h0_shape = (B, H, DK, DV) + dht_shape = (B, H, DK, DV) + dO_shape = (B, S, H, DV) + dv_shape = (B, S, H, DV) + + dh_shape = (B, BS, H, DK, DV) + dh0_shape = (B, H, DK, DV) + dv2_shape = (B, S, H, DV) + + @T.prim_func + def kernel( + # Input + Q: T.Tensor(Q_shape, dtype=input_dtype), + K: T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + GK: T.Tensor(G_shape, dtype=gate_dtype), + h0: T.Tensor(h0_shape, dtype=input_dtype), + dht: T.Tensor(dht_shape, dtype=input_dtype), + dO: T.Tensor(dO_shape, dtype=input_dtype), + dv: T.Tensor(dv_shape, dtype=input_dtype), + # Output + dh: T.Tensor(dh_shape, dtype=output_dtype), + dh0: T.Tensor(dh0_shape, dtype=state_dtype), + dv2: T.Tensor(dv2_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_dh_shared = T.alloc_shared((DK, block_DV), dtype=output_dtype) + b_dh_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_1 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + b_dh_fragment_2 = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + dv_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + dv_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dv_fragment_2 = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + dO_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + Q_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + + GK_last_shared = T.alloc_shared((DK,), dtype=gate_dtype) + + if use_final_state_gradient: + T.copy(dht[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_dh_shared) + T.copy(b_dh_shared, b_dh_fragment) + else: + T.clear(b_dh_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # The gradient should be stored in the reverse order + i_s_inv = T.ceildiv(S, block_S) - i_s - 1 # reverse indices + # Store the updated dh + 
T.copy(b_dh_fragment, b_dh_shared) + T.copy(b_dh_shared, dh[bb, i_s_inv, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + # Update dv + T.copy(K[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], K_shared) + T.gemm(K_shared, b_dh_shared, dv_fragment, clear_accum=True) + T.copy( + dv[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dv_shared + ) # copy old dv + T.copy(dv_shared, dv_fragment_2) + for i_s2, i_v in T.Parallel(block_S, block_DV): + dv_fragment[i_s2, i_v] = dv_fragment[i_s2, i_v] + dv_fragment_2[i_s2, i_v] + # Store the updated dv + T.copy(dv_fragment, dv_shared) + T.copy(dv_shared, dv2[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + # Update dh + T.copy(Q[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], Q_shared) # [block_S, DK] + T.copy(W[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, 0:DK], W_shared) # [block_S, DK] + T.copy( + dO[bb, i_s_inv * block_S : (i_s_inv + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], dO_shared + ) # [block_S, block_DV] + + if use_gk: + last_idx = T.min((i_s_inv + 1) * block_S, S) - 1 # chunk last token gk + T.copy(GK[bb, last_idx, bh, :], GK_last_shared) + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + T.gemm(Q_shared, dO_shared, b_dh_fragment_1, transpose_A=True, clear_accum=True) # [DK, block_DV] + + # dv_shared: [block_S, block_DV] + T.gemm(W_shared, dv_shared, b_dh_fragment_2, transpose_A=True, clear_accum=True) # [DK, block_DV] + for i_k, i_v in T.Parallel(DK, block_DV): + b_dh_fragment[i_k, i_v] += b_dh_fragment_1[i_k, i_v] * scale - b_dh_fragment_2[i_k, i_v] + + if use_initial_state: + T.copy(b_dh_fragment, dh0[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=64, + threads=256, + num_stages=0, + use_torch=False, +): + Q, K, W, G, h0, dht, dO, dv = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + getattr(torch, state_dtype), + ) + + dh_tilelang, dh0_tilelang, dv2_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, gate_dtype), getattr(torch, state_dtype) + ) + + # fla ref + print("fla running...", flush=True) + if use_gk: + dh_ref, dh0_ref, dv2_ref = chunk_gated_delta_rule_bwd_dhu( + q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, use_exp2=True + ) + + # tilelang + print("tilelang running...", flush=True) + kernel = tilelang_chunk_gated_delta_rule_bwd_dhu( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + scale, + use_gk, + use_initial_state, + use_final_state_gradient, + ) + dh_tilelang, dh0_tilelang, dv2_tilelang = kernel(Q, K, W, G, h0, dht, dO, dv) + + fla_time = do_bench( + chunk_gated_delta_rule_bwd_dhu, q=Q, k=K, w=W, do=dO, dv=dv, gk=G, h0=h0, dht=dht, scale=scale, chunk_size=chunk_size + ) + tilelang_time = do_bench(kernel, Q, K, W, G, h0, dht, dO, dv) + + print(f"fla time: {fla_time} ms") + print(f"tilelang time: {tilelang_time} ms") + + compare_tensors("dh", dh_ref, dh_tilelang) + compare_tensors("dh0", dh0_ref, dh0_tilelang) + 
compare_tensors("dv2", dv2_ref, dv2_tilelang) + + +def main(): + DK = 128 + run_test( + B=1, + S=1024 * 8, + H=64, + DK=DK, + DV=128, + input_dtype="bfloat16", + output_dtype="bfloat16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + scale=DK**-0.5, + use_gk=True, + use_initial_state=True, + use_final_state_gradient=True, + block_DV=32, + threads=128, + num_stages=1, + use_torch=False, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_delta_h_fwd.py b/examples/kda/chunk_delta_h_fwd.py new file mode 100644 index 0000000000..fbb8bd9882 --- /dev/null +++ b/examples/kda/chunk_delta_h_fwd.py @@ -0,0 +1,306 @@ +# Reference: fla/ops/common/chunk_delta_h.py + +import sys # noqa: F401 +import tilelang +import tilelang.language as T +from tilelang.autotuner import autotune + +# Add your fla repository path to sys.path +# Currently we use the fla repository from the flash-linear-attention project at commit id f03cb3ae +# sys.path.insert(0, "/your/path/to/flash-linear-attention") + +from FLA_KDA.fla_chunk_delta import chunk_gated_delta_rule_fwd_h +from FLA_KDA.cumsum import chunk_local_cumsum + +import torch +import torch.nn.functional as F + +from test_utils_kda import compare_tensors, do_bench + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + K = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + K = F.normalize(K, dim=-1, p=2) + W = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + W = F.normalize(W, dim=-1, p=2) + U = torch.randn(B, S, H, DV, dtype=input_dtype).cuda() + U = F.normalize(U, dim=-1, p=2) + G = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() + G = F.logsigmoid(G) + G = chunk_local_cumsum(G, chunk_size) + initial_state = torch.randn(B, H, DK, DV, dtype=input_dtype).cuda() + return K, W, U, G, initial_state + + +def prepare_output( + B, + S, + H, + DK, + DV, + chunk_size, + output_dtype, + state_dtype, +): + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + h = torch.empty(B, BS, H, DK, DV, dtype=output_dtype).cuda() + final_state = torch.empty(B, H, DK, DV, dtype=state_dtype).cuda() + V_new = torch.empty(B, S, H, DV, dtype=output_dtype).cuda() + return h, final_state, V_new + + +def get_configs(): + import itertools + + block_DK = [32, 64, 128] + block_DV = [32, 64, 128] + threads = [128, 256] + num_stages = [1, 2, 3] + _configs = list(itertools.product(block_DK, block_DV, threads, num_stages)) + + configs = [{"block_DK": c[0], "block_DV": c[1], "threads": c[2], "num_stages": c[3]} for c in _configs] + return configs + + +@autotune(configs=get_configs(), warmup=3, rep=5) +@tilelang.jit(out_idx=[-3, -2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_gated_delta_rule_fwd_h( + # task config + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + # kernel config + block_DK=64, + block_DV=32, + threads=128, + num_stages=1, +): + block_S = chunk_size + BS = (S + chunk_size - 1) // chunk_size # ceildiv to match kernel iteration + + K_shape = (B, S, H, DK) + V_shape = (B, S, H, DV) + W_shape = (B, S, H, DK) + U_shape = (B, S, H, DV) + GK_shape = (B, S, H, DK) + h_shape = (B, BS, H, DK, DV) + initial_state_shape = (B, H, DK, DV) + final_state_shape = (B, H, DK, DV) + + @T.prim_func + def kernel( + K: 
T.Tensor(K_shape, dtype=input_dtype), + W: T.Tensor(W_shape, dtype=input_dtype), + U: T.Tensor(U_shape, dtype=input_dtype), + GK: T.Tensor(GK_shape, dtype=gate_dtype), + initial_state: T.Tensor(initial_state_shape, dtype=input_dtype), + h: T.Tensor(h_shape, dtype=output_dtype), + final_state: T.Tensor(final_state_shape, dtype=state_dtype), + V_new: T.Tensor(V_shape, dtype=output_dtype), + ): + with T.Kernel(T.ceildiv(DV, block_DV), B * H, threads=threads) as (bv, bbh): + bb, bh = bbh // H, bbh % H + + b_h_shared = T.alloc_shared((DK, block_DV), dtype=input_dtype) + b_h_fragment = T.alloc_fragment((DK, block_DV), dtype=accum_dtype) + + U_shared = T.alloc_shared((block_S, block_DV), dtype=input_dtype) + U_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + W_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + V_new_fragment = T.alloc_fragment((block_S, block_DV), dtype=accum_dtype) + V_new_shared = T.alloc_shared((block_S, block_DV), dtype=output_dtype) + K_shared = T.alloc_shared((block_S, DK), dtype=input_dtype) + GK_last_shared = T.alloc_shared((DK), dtype=gate_dtype) + + if use_initial_state: + T.copy(initial_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV], b_h_shared) + T.copy(b_h_shared, b_h_fragment) + else: + T.clear(b_h_fragment) + + for i_s in T.Pipelined(T.ceildiv(S, block_S), num_stages=num_stages): + # Store previous result to the hidden tensor, like the epilogue + T.copy(b_h_shared, h[bb, i_s, bh, :, bv * block_DV : (bv + 1) * block_DV]) + + # Recurrence + T.copy(W[bb, i_s * block_S : (i_s + 1) * block_S, bh, :], W_shared) + T.gemm(W_shared, b_h_shared, V_new_fragment, clear_accum=True) + + # U - W * S + T.copy(U[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV], U_shared) + T.copy(U_shared, U_fragment) + for i_s2, i_v in T.Parallel(block_S, block_DV): + V_new_fragment[i_s2, i_v] = -V_new_fragment[i_s2, i_v] + U_fragment[i_s2, i_v] + + # Save V_new + if save_new_value: + T.copy(V_new_fragment, dst=V_new_shared) + T.copy(V_new_shared, V_new[bb, i_s * block_S : (i_s + 1) * block_S, bh, bv * block_DV : (bv + 1) * block_DV]) + + T.copy(K[bb, i_s * block_S : (i_s + 1) * block_S, bh, 0:DK], K_shared) + # use_gk + if use_gk: + T.copy(GK[bb, (i_s + 1) * block_S - 1, bh, :], GK_last_shared) # block last token + for i_k, i_v in T.Parallel(DK, block_DV): + b_h_fragment[i_k, i_v] *= T.exp2(GK_last_shared[i_k]) + + # Update intermediate results + T.copy(V_new_fragment, V_new_shared) + T.gemm(K_shared, V_new_shared, b_h_fragment, transpose_A=True) + + T.copy(b_h_fragment, b_h_shared) + + # Save final state + if store_final_state: + T.copy(b_h_fragment, final_state[bb, bh, 0:DK, bv * block_DV : (bv + 1) * block_DV]) + + return kernel + + +def run_test( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=64, + block_DV=32, + threads=128, + num_stages=0, +): + K, W, U, G, initial_state = prepare_input( + B, + S, + H, + DK, + DV, + chunk_size, + getattr(torch, input_dtype), + getattr(torch, output_dtype), + getattr(torch, accum_dtype), + getattr(torch, gate_dtype), + ) + h_ref, final_state_ref, V_new_ref = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) + ) + h_tilelang, final_state_tilelang, V_new_tilelang = prepare_output( + B, S, H, DK, DV, chunk_size, getattr(torch, output_dtype), getattr(torch, state_dtype) 
+ ) + + # fla ref + h_ref, V_new_ref, final_state_ref = chunk_gated_delta_rule_fwd_h( + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + + # tilelang + kernel = tilelang_chunk_gated_delta_rule_fwd_h( + B, + S, + H, + DK, + DV, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + state_dtype, + chunk_size, + use_gk, + use_initial_state, + store_final_state, + save_new_value, + ) + h_tilelang, final_state_tilelang, V_new_tilelang = kernel(K, W, U, G, initial_state) + + fla_time = do_bench( + chunk_gated_delta_rule_fwd_h, + k=K, + w=W, + u=U, + gk=G, + initial_state=initial_state, + output_final_state=store_final_state, + chunk_size=chunk_size, + save_new_value=save_new_value, + use_exp2=True, + ) + tilelang_time = do_bench(kernel, K, W, U, G, initial_state) + + # check correctness + compare_tensors("h", h_ref, h_tilelang) + compare_tensors("final_state", final_state_ref, final_state_tilelang) + compare_tensors("V_new", V_new_ref, V_new_tilelang) + + print(f"tilelang time: {tilelang_time} ms") + print(f"fla time: {fla_time} ms") + + +def main(): + run_test( + B=1, + S=8192, + H=64, + DK=128, + DV=128, + input_dtype="float16", + output_dtype="float16", + accum_dtype="float32", + gate_dtype="float32", + state_dtype="float32", + chunk_size=64, + use_gk=True, + use_initial_state=True, + store_final_state=True, + save_new_value=True, + block_DK=32, + block_DV=32, + threads=128, + num_stages=2, + ) + + +if __name__ == "__main__": + main() diff --git a/examples/kda/chunk_inter_solve_fused.py b/examples/kda/chunk_inter_solve_fused.py new file mode 100644 index 0000000000..940dc20c86 --- /dev/null +++ b/examples/kda/chunk_inter_solve_fused.py @@ -0,0 +1,566 @@ +import tilelang +import tilelang.language as T + +from FLA_KDA.fla_chunk_intra import chunk_kda_fwd_inter_solve_fused +from FLA_KDA.cumsum import chunk_local_cumsum +from test_utils_kda import compare_tensors, do_bench + +import torch +import torch.nn.functional as F + + +torch.random.manual_seed(42) + + +def prepare_input( + B, + S, + H, + DK, + chunk_size, + sub_chunk_size, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, +): + q = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + k = torch.randn(B, S, H, DK, dtype=input_dtype).cuda() + beta = torch.randn(B, S, H, dtype=input_dtype).cuda() + gk = torch.randn(B, S, H, DK, dtype=gate_dtype).cuda() # 需要是cumsum + gk = F.logsigmoid(gk) + gk = chunk_local_cumsum(gk, chunk_size) + + Aqk = torch.empty(B, S, H, chunk_size, dtype=input_dtype).cuda() + Akk_diag = torch.ones(B, S, H, sub_chunk_size, dtype=torch.float32).cuda() + + return q, k, gk, beta, Aqk, Akk_diag + + +def prepare_output( + B, + S, + H, + chunk_size, + sub_chunk_size, + output_dtype, +): + Akk = torch.empty(B, S, H, chunk_size, dtype=output_dtype).cuda() + return Akk + + +@tilelang.jit(out_idx=[-2, -1], pass_configs={tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True}) +def tilelang_chunk_kda_fwd_inter_fused( + B, + S, + H, + DK, + input_dtype, + output_dtype, + accum_dtype, + gate_dtype, + chunk_size, + sub_chunk_size, + scale, + block_DK=32, + threads=32, + num_stages=1, +): + block_S = BS = chunk_size + BC = sub_chunk_size + Q_shape = (B, S, H, DK) + K_shape = (B, S, H, DK) + GK_shape = (B, S, H, DK) + Beta_shape = (B, S, H) + Aqk_shape = (B, S, H, BS) + Akk_diag_shape = (B, S, H, BC) + """ + Fused kernel: compute inter-subchunk Akk + solve_tril in one pass. 
+    """
+    Fused kernel: compute inter-subchunk Akk + solve_tril in one pass.
+    Prerequisite: token_parallel has already computed diagonal Akk blocks in Akk_diag.
+
+    This kernel:
+    1. Computes off-diagonal Aqk blocks -> writes to global
+    2. Computes off-diagonal Akk blocks -> keeps in registers
+    3. Loads diagonal Akk blocks from Akk_diag (fp32)
+    4. Does forward substitution on diagonals
+    5. Computes merged Akk_inv
+    6. Writes Akk_inv to Akk
+    """
+
+    @T.prim_func
+    def kernel(
+        Q: T.Tensor(Q_shape, dtype=input_dtype),
+        K: T.Tensor(K_shape, dtype=input_dtype),
+        GK: T.Tensor(GK_shape, dtype=gate_dtype),
+        Beta: T.Tensor(Beta_shape, dtype=input_dtype),
+        Akk_diag: T.Tensor(Akk_diag_shape, dtype=T.float32),
+        Aqk: T.Tensor(Aqk_shape, dtype=output_dtype),
+        Akk: T.Tensor(Aqk_shape, dtype=output_dtype),
+    ):
+        with T.Kernel(T.ceildiv(S, block_S), B * H, threads=threads) as (bs, bbh):
+            bb, bh = bbh // H, bbh % H
+
+            Aqk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk10_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Aqk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk20_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Aqk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk21_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Aqk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk30_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Aqk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk31_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Aqk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk32_fragment = T.alloc_fragment((BC, BC), dtype=accum_dtype)
+            Akk10_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Akk20_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Akk21_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Akk30_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Akk31_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Akk32_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+
+            K0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            GK0_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            Q1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            K1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            GK1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            Q2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            K2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            GK2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            Q3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            K3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            GK3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+
+            Q_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            K_GK_scaled_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            b_kt_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+
+            b_gn1_shared = T.alloc_shared((block_DK,), dtype=T.float32)
+            b_gn2_shared = T.alloc_shared((block_DK,), dtype=T.float32)
+            b_gn3_shared = T.alloc_shared((block_DK,), dtype=T.float32)
+
+            b_gqn1_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            b_gqn2_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+            b_gqn3_shared = T.alloc_shared((BC, block_DK), dtype=T.float32)
+
+            beta_1_shared = T.alloc_shared((BC,), dtype=T.float32)
+            beta_2_shared = T.alloc_shared((BC,), dtype=T.float32)
+            beta_3_shared = T.alloc_shared((BC,), dtype=T.float32)
+            # Akk_inv
+            Ai_00_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_10_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_11_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_20_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_21_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_22_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_30_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_31_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_32_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            Ai_33_shared = T.alloc_shared((BC, BC), dtype=T.float32)
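+            # The ten Ai_* tiles cover the block lower triangle of the 4x4
+            # sub-chunk grid and will end up holding Akk_inv: diagonal tiles are
+            # loaded from Akk_diag and inverted by forward substitution, then the
+            # off-diagonal Akk tiles computed below are merged in (docstring
+            # steps 3-6).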
+            T.clear(Aqk10_fragment)
+            T.clear(Akk10_fragment)
+            T.clear(Aqk20_fragment)
+            T.clear(Akk20_fragment)
+            T.clear(Aqk21_fragment)
+            T.clear(Akk21_fragment)
+            T.clear(Aqk30_fragment)
+            T.clear(Akk30_fragment)
+            T.clear(Aqk31_fragment)
+            T.clear(Akk31_fragment)
+            T.clear(Aqk32_fragment)
+            T.clear(Akk32_fragment)
+
+            i_tc0 = bs * BS
+            i_tc1 = bs * BS + BC
+            i_tc2 = bs * BS + 2 * BC
+            i_tc3 = bs * BS + 3 * BC
+
+            ################################################################################
+            # 1. off-diagonal blocks
+            ################################################################################
+
+            for i_k in T.Pipelined(T.ceildiv(DK, block_DK), num_stages=num_stages):
+                T.copy(K[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K0_shared)
+                T.copy(GK[bb, bs * BS : bs * BS + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK0_shared)
+                if i_tc1 < S:
+                    T.copy(Q[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q1_shared)
+                    T.copy(K[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K1_shared)
+                    T.copy(GK[bb, i_tc1 : i_tc1 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK1_shared)
+                    T.copy(GK[bb, i_tc1, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn1_shared)  # GK of the first token in this sub-chunk
+                    for i_c1, i_k1 in T.Parallel(BC, block_DK):
+                        b_gqn1_shared[i_c1, i_k1] = T.if_then_else(
+                            i_tc1 + i_c1 < S, T.exp2(GK1_shared[i_c1, i_k1] - b_gn1_shared[i_k1]), 0.0
+                        )
+                        Q_GK_scaled_shared[i_c1, i_k1] = Q1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1]
+                        K_GK_scaled_shared[i_c1, i_k1] = K1_shared[i_c1, i_k1] * b_gqn1_shared[i_c1, i_k1]
+                        b_kt_shared[i_c1, i_k1] = K0_shared[i_c1, i_k1] * T.exp2(b_gn1_shared[i_k1] - GK0_shared[i_c1, i_k1])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk10_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk10_fragment, transpose_B=True)
+                if i_tc2 < S:
+                    T.copy(Q[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q2_shared)
+                    T.copy(K[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K2_shared)
+                    T.copy(GK[bb, i_tc2 : i_tc2 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK2_shared)
+                    T.copy(GK[bb, i_tc2, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn2_shared)
+                    for i_c2, i_k2 in T.Parallel(BC, block_DK):
+                        b_gqn2_shared[i_c2, i_k2] = T.if_then_else(
+                            i_tc2 + i_c2 < S, T.exp2(GK2_shared[i_c2, i_k2] - b_gn2_shared[i_k2]), 0.0
+                        )
+                        Q_GK_scaled_shared[i_c2, i_k2] = Q2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2]
+                        K_GK_scaled_shared[i_c2, i_k2] = K2_shared[i_c2, i_k2] * b_gqn2_shared[i_c2, i_k2]
+                        b_kt_shared[i_c2, i_k2] = K0_shared[i_c2, i_k2] * T.exp2(b_gn2_shared[i_k2] - GK0_shared[i_c2, i_k2])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk20_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk20_fragment, transpose_B=True)
+                    for i_c3, i_k3 in T.Parallel(BC, block_DK):
+                        b_kt_shared[i_c3, i_k3] = K1_shared[i_c3, i_k3] * T.exp2(b_gn2_shared[i_k3] - GK1_shared[i_c3, i_k3])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk21_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk21_fragment, transpose_B=True)
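+                # Every off-diagonal tile uses the same rescaling trick: the later
+                # sub-chunk is scaled by exp2(gk - gn) and the earlier one by
+                # exp2(gn - gk'), where gn is the gate at the later sub-chunk's
+                # first token, so each gemm accumulates
+                #   sum_k q[i, k] * k[j, k] * exp2(gk[i, k] - gk[j, k])
+                # with gn cancelling between the two factors.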
+                if i_tc3 < S:
+                    T.copy(Q[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], Q3_shared)
+                    T.copy(K[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], K3_shared)
+                    T.copy(GK[bb, i_tc3 : i_tc3 + BC, bh, i_k * block_DK : (i_k + 1) * block_DK], GK3_shared)
+                    T.copy(GK[bb, i_tc3, bh, i_k * block_DK : (i_k + 1) * block_DK], b_gn3_shared)
+                    for i_c4, i_k4 in T.Parallel(BC, block_DK):
+                        b_gqn3_shared[i_c4, i_k4] = T.if_then_else(
+                            i_tc3 + i_c4 < S, T.exp2(GK3_shared[i_c4, i_k4] - b_gn3_shared[i_k4]), 0.0
+                        )
+                        Q_GK_scaled_shared[i_c4, i_k4] = Q3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4]
+                        K_GK_scaled_shared[i_c4, i_k4] = K3_shared[i_c4, i_k4] * b_gqn3_shared[i_c4, i_k4]
+                        b_kt_shared[i_c4, i_k4] = K0_shared[i_c4, i_k4] * T.exp2(b_gn3_shared[i_k4] - GK0_shared[i_c4, i_k4])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk30_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk30_fragment, transpose_B=True)
+                    for i_c5, i_k5 in T.Parallel(BC, block_DK):
+                        b_kt_shared[i_c5, i_k5] = K1_shared[i_c5, i_k5] * T.exp2(b_gn3_shared[i_k5] - GK1_shared[i_c5, i_k5])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk31_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk31_fragment, transpose_B=True)
+                    for i_c6, i_k6 in T.Parallel(BC, block_DK):
+                        b_kt_shared[i_c6, i_k6] = K2_shared[i_c6, i_k6] * T.exp2(b_gn3_shared[i_k6] - GK2_shared[i_c6, i_k6])
+                    T.gemm(Q_GK_scaled_shared, b_kt_shared, Aqk32_fragment, transpose_B=True)
+                    T.gemm(K_GK_scaled_shared, b_kt_shared, Akk32_fragment, transpose_B=True)
+
+            ################################################################################
+            # 2. save off-diagonal Aqk blocks and prepare Akk
+            ################################################################################
+
+            if i_tc1 < S:
+                T.copy(Beta[bb, i_tc1 : i_tc1 + BC, bh], beta_1_shared)
+                for i_c21, i_c22 in T.Parallel(BC, BC):
+                    Aqk10_fragment[i_c21, i_c22] = Aqk10_fragment[i_c21, i_c22] * scale
+                    Akk10_fragment[i_c21, i_c22] = Akk10_fragment[i_c21, i_c22] * beta_1_shared[i_c21]
+                T.copy(Aqk10_fragment, Aqk[bb, i_tc1 : i_tc1 + BC, bh, 0:BC])
+                T.copy(Akk10_fragment, Akk10_shared)
+            if i_tc2 < S:
+                T.copy(Beta[bb, i_tc2 : i_tc2 + BC, bh], beta_2_shared)
+                for i_c23, i_c24 in T.Parallel(BC, BC):
+                    Aqk20_fragment[i_c23, i_c24] = Aqk20_fragment[i_c23, i_c24] * scale
+                    Aqk21_fragment[i_c23, i_c24] = Aqk21_fragment[i_c23, i_c24] * scale
+                    Akk20_fragment[i_c23, i_c24] = Akk20_fragment[i_c23, i_c24] * beta_2_shared[i_c23]
+                    Akk21_fragment[i_c23, i_c24] = Akk21_fragment[i_c23, i_c24] * beta_2_shared[i_c23]
+                T.copy(Aqk20_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, 0:BC])
+                T.copy(Aqk21_fragment, Aqk[bb, i_tc2 : i_tc2 + BC, bh, BC : 2 * BC])
+                T.copy(Akk20_fragment, Akk20_shared)
+                T.copy(Akk21_fragment, Akk21_shared)
+            if i_tc3 < S:
+                T.copy(Beta[bb, i_tc3 : i_tc3 + BC, bh], beta_3_shared)
+                for i_c25, i_c26 in T.Parallel(BC, BC):
+                    Aqk30_fragment[i_c25, i_c26] = Aqk30_fragment[i_c25, i_c26] * scale
+                    Aqk31_fragment[i_c25, i_c26] = Aqk31_fragment[i_c25, i_c26] * scale
+                    Aqk32_fragment[i_c25, i_c26] = Aqk32_fragment[i_c25, i_c26] * scale
+                    Akk30_fragment[i_c25, i_c26] = Akk30_fragment[i_c25, i_c26] * beta_3_shared[i_c25]
+                    Akk31_fragment[i_c25, i_c26] = Akk31_fragment[i_c25, i_c26] * beta_3_shared[i_c25]
+                    Akk32_fragment[i_c25, i_c26] = Akk32_fragment[i_c25, i_c26] * beta_3_shared[i_c25]
+                T.copy(Aqk30_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 0:BC])
+                T.copy(Aqk31_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, BC : 2 * BC])
+                T.copy(Aqk32_fragment, Aqk[bb, i_tc3 : i_tc3 + BC, bh, 2 * BC : 3 * BC])
+                T.copy(Akk30_fragment, Akk30_shared)
+                T.copy(Akk31_fragment, Akk31_shared)
+                T.copy(Akk32_fragment, Akk32_shared)
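+
+            # Akk_diag comes from the earlier token_parallel pass (see the
+            # docstring's prerequisite). Its strictly lower part is negated below
+            # so that forward substitution can invert each unit-lower-triangular
+            # diagonal sub-block in place.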
+
+            ################################################################################
+            # 3. load diagonal Akk blocks
+            ################################################################################
+
+            T.copy(Akk_diag[bb, i_tc0 : i_tc0 + BC, bh, :], Ai_00_shared)
+            T.copy(Akk_diag[bb, i_tc1 : i_tc1 + BC, bh, :], Ai_11_shared)
+            T.copy(Akk_diag[bb, i_tc2 : i_tc2 + BC, bh, :], Ai_22_shared)
+            T.copy(Akk_diag[bb, i_tc3 : i_tc3 + BC, bh, :], Ai_33_shared)
+            for i_c1, i_c2 in T.Parallel(BC, BC):
+                Ai_00_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_00_shared[i_c1, i_c2], 0)
+                Ai_11_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_11_shared[i_c1, i_c2], 0)
+                Ai_22_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_22_shared[i_c1, i_c2], 0)
+                Ai_33_shared[i_c1, i_c2] = T.if_then_else(i_c1 > i_c2, -Ai_33_shared[i_c1, i_c2], 0)
+
+            ################################################################################
+            # 4. forward substitution on diagonals
+            ################################################################################
+            a_00_shared = T.alloc_shared((BC,), dtype=T.float32)
+            Aa_mul_shared = T.alloc_shared((BC, BC), dtype=T.float32)
+            reduce_shared = T.alloc_shared((BC,), dtype=T.float32)
+            for i_i in T.Pipelined(2, T.min(BC, S - i_tc0), num_stages=num_stages):
+                T.copy(Akk_diag[bb, i_tc0 + i_i, bh, :], a_00_shared)  # load row i_i
+                for i_c in T.Parallel(BC):
+                    a_00_shared[i_c] = T.if_then_else(i_c < i_i, -a_00_shared[i_c], 0.0)  # mask: i_c < i_i