diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index fe42ea7dc..c6f376231 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -17,7 +17,7 @@ jobs:
       CT2_VERBOSE: 1
     strategy:
       matrix:
-        os: [ubuntu-20.04, macos-11]
+        os: [ubuntu-20.04]
         backend: [mkl, dnnl]
 
     steps:
@@ -33,17 +33,6 @@ jobs:
           sudo sh -c 'echo "deb https://apt.repos.intel.com/oneapi all main" > /etc/apt/sources.list.d/oneAPI.list'
           sudo apt-get update
 
-      - name: Store URL for downloading Intel oneAPI to environment variable
-        if: startsWith(matrix.os, 'macos')
-        run: |
-          echo 'ONEAPI_INSTALLER_URL=https://registrationcenter-download.intel.com/akdlm/irc_nas/19080/m_BaseKit_p_2023.0.0.25441_offline.dmg' >> $GITHUB_ENV
-
-      - name: Install Intel oneAPI
-        if: startsWith(matrix.os, 'macos')
-        run: |
-          wget -q $ONEAPI_INSTALLER_URL
-          hdiutil attach -noverify -noautofsck $(basename $ONEAPI_INSTALLER_URL)
-
       - name: Configure with MKL
         if: startsWith(matrix.os, 'ubuntu') && matrix.backend == 'mkl'
         env:
@@ -61,20 +50,6 @@ jobs:
           sudo apt-get install -y intel-oneapi-dnnl-devel=$DNNL_VERSION intel-oneapi-dnnl=$DNNL_VERSION
           cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON -DWITH_MKL=OFF -DOPENMP_RUNTIME=COMP -DWITH_DNNL=ON .
 
-      - name: Configure with MKL
-        if: startsWith(matrix.os, 'macos') && matrix.backend == 'mkl'
-        env:
-          CT2_USE_MKL: 1
-        run: |
-          sudo /Volumes/$(basename $ONEAPI_INSTALLER_URL .dmg)/bootstrapper.app/Contents/MacOS/bootstrapper --silent --eula accept --components intel.oneapi.mac.mkl.devel
-          cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON .
-
-      - name: Configure with DNNL
-        if: startsWith(matrix.os, 'macos') && matrix.backend == 'dnnl'
-        run: |
-          sudo /Volumes/$(basename $ONEAPI_INSTALLER_URL .dmg)/bootstrapper.app/Contents/MacOS/bootstrapper --silent --eula accept --components intel.oneapi.mac.dnnl
-          cmake -DCMAKE_INSTALL_PREFIX=$PWD/install -DBUILD_TESTS=ON -DWITH_MKL=OFF -DWITH_DNNL=ON .
-
       - name: Build
         run: |
           make install
@@ -157,12 +132,12 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
       matrix:
-        os: [ubuntu-20.04, macos-11, windows-2019]
+        os: [ubuntu-20.04, windows-2019]
         arch: [auto64]
         include:
         - os: ubuntu-20.04
           arch: aarch64
-        - os: macos-11
+        - os: macos-12
           arch: arm64
 
     steps:
@@ -206,7 +181,7 @@ jobs:
     runs-on: ${{ matrix.os }}
     strategy:
      matrix:
-        os: [ubuntu-20.04, macos-11, windows-2019]
+        os: [ubuntu-20.04, windows-2019]
 
     steps:
       - name: Set up Python 3.8
@@ -231,11 +206,6 @@ jobs:
         run: |
           pip install *cp38*manylinux*x86_64.whl
 
-      - name: Install wheel
-        if: startsWith(matrix.os, 'macos')
-        run: |
-          pip install *cp38*macosx*x86_64.whl
-
       - name: Install wheel
         if: startsWith(matrix.os, 'windows')
         shell: bash
diff --git a/CMakeLists.txt b/CMakeLists.txt
index ac94aac57..5b7d0bf7c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -21,6 +21,7 @@ option(BUILD_CLI "Compile the clients" ON)
 option(BUILD_TESTS "Compile the tests" OFF)
 option(BUILD_SHARED_LIBS "Build shared libraries" ON)
 option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
+option(WITH_FLASH_ATTN "Compile with Flash Attention 2" OFF)
 
 if(ENABLE_PROFILING)
   message(STATUS "Enable profiling support")
@@ -485,8 +486,10 @@ if (WITH_CUDA)
   set(CUDA_HOST_COMPILER ${CMAKE_CXX_COMPILER})
 
   # flags for flash attention
-  list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
-  list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
+  if (WITH_FLASH_ATTN)
+    list(APPEND CUDA_NVCC_FLAGS "--expt-relaxed-constexpr")
+    list(APPEND CUDA_NVCC_FLAGS "--expt-extended-lambda")
+  endif()
 
   message(STATUS "NVCC host compiler: ${CUDA_HOST_COMPILER}")
   message(STATUS "NVCC compilation flags: ${CUDA_NVCC_FLAGS}")
@@ -570,77 +573,84 @@ if (WITH_CUDA)
     src/ops/topp_mask_gpu.cu
     src/ops/quantize_gpu.cu
     src/ops/nccl_ops_gpu.cu
-    src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
     src/ops/awq/gemm_gpu.cu
     src/ops/awq/gemv_gpu.cu
     src/ops/awq/dequantize_gpu.cu
     )
 
+  if (WITH_FLASH_ATTN)
+    add_definitions(-DCT2_WITH_FLASH_ATTN)
+    cuda_add_library(${PROJECT_NAME}
+      src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
+      )
+
+    set_source_files_properties(
+      src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
+      src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
+      PROPERTIES COMPILE_FLAGS "--use_fast_math")
+  endif()
+
-  set_source_files_properties(
-    src/ops/flash-attention/flash_fwd_hdim32_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim32_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim64_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim64_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim96_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim96_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim128_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim128_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim160_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim160_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim192_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim192_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim224_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim224_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim256_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_hdim256_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim32_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim32_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim64_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim64_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim96_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim96_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim128_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim128_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim160_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim160_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim192_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim192_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim224_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim224_fp16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim256_bf16_sm80.cu
-    src/ops/flash-attention/flash_fwd_split_hdim256_fp16_sm80.cu
-    PROPERTIES COMPILE_FLAGS "--use_fast_math")
 elseif(WITH_CUDNN)
   message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
 else()
diff --git a/python/tools/prepare_build_environment_linux.sh b/python/tools/prepare_build_environment_linux.sh
index 89f8293f6..b47b865f5 100755
--- a/python/tools/prepare_build_environment_linux.sh
+++ b/python/tools/prepare_build_environment_linux.sh
@@ -19,9 +19,12 @@ if [ "$CIBW_ARCHS" == "aarch64" ]; then
     rm -r OpenBLAS-*
 else
-    # Install CUDA 12.2:
     yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel7/x86_64/cuda-rhel7.repo
+    # mirrorlist.centos.org no longer exists, so point the repos at vault.centos.org instead.
+    sed -i s/mirror.centos.org/vault.centos.org/g /etc/yum.repos.d/*.repo
+    sed -i s/^#.*baseurl=http/baseurl=http/g /etc/yum.repos.d/*.repo
+    sed -i s/^mirrorlist=http/#mirrorlist=http/g /etc/yum.repos.d/*.repo
     yum install --setopt=obsoletes=0 -y \
         cuda-nvcc-12-2-12.2.140-1 \
         cuda-cudart-devel-12-2-12.2.140-1 \
diff --git a/python/tools/prepare_test_environment.sh b/python/tools/prepare_test_environment.sh
index 599e4c673..3e40184d3 100755
--- a/python/tools/prepare_test_environment.sh
+++ b/python/tools/prepare_test_environment.sh
@@ -3,6 +3,9 @@
 set -e
 set -x
 
+# Force the use of pip < 24.1
+python -m pip install 'pip<24.1'
+
 # Install test rquirements
 pip cache purge
 pip --no-cache-dir install -r python/tests/requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
diff --git a/src/ops/flash_attention_gpu.cu b/src/ops/flash_attention_gpu.cu
index dae401db7..af0e1ea56 100644
--- a/src/ops/flash_attention_gpu.cu
+++ b/src/ops/flash_attention_gpu.cu
@@ -1,8 +1,9 @@
 #include "ctranslate2/ops/flash_attention.h"
+#ifdef CT2_WITH_FLASH_ATTN
 #include "ctranslate2/ops/flash-attention/flash.h"
 #include "ctranslate2/ops/flash-attention/static_switch.h"
+#endif
 #include "ctranslate2/ops/transpose.h"
-#include "ctranslate2/ops/slide.h"
 #include "cuda/utils.h"
 #include "dispatch.h"
 
@@ -13,6 +14,7 @@ namespace ctranslate2 {
   namespace ops {
 
+#ifdef CT2_WITH_FLASH_ATTN
     static void set_params_fprop(Flash_fwd_params &params,
                                  // sizes
                                  const size_t b,
@@ -188,7 +190,7 @@ namespace ctranslate2 {
     }
 
     static const ops::Transpose transpose_op({0, 2, 1, 3});
-
+#endif
     template<>
     void FlashAttention::compute<Device::CUDA>(StorageView& queries,
                                                StorageView& keys,
@@ -203,6 +205,7 @@ namespace ctranslate2 {
                                                const bool rotary_interleave,
                                                StorageView* alibi,
                                                dim_t offset) const {
+#ifdef CT2_WITH_FLASH_ATTN
      const Device device = queries.device();
      const DataType dtype = queries.dtype();
 
@@ -357,6 +360,9 @@ namespace ctranslate2 {
         output.reshape({batch_size, 1, num_heads_k * seqlen_q, head_size_og});
         softmax_lse = softmax_lse.reshape({batch_size, num_heads_k * seqlen_q, 1});
       }
+#else
+      throw std::runtime_error("Flash attention 2 is not supported");
+#endif
     }
   }
 }