From cec65f17c19edff5265025d20040df360b149ad8 Mon Sep 17 00:00:00 2001
From: Minh-Thuc <46375464+minhthuc2502@users.noreply.github.com>
Date: Tue, 5 Mar 2024 09:56:43 +0100
Subject: [PATCH] support tensor parallel (#1599)
* tensor parallel support
* add docs
* small fix
* fix adding bias multiple times in layer output.
---
CMakeLists.txt | 36 ++-
README.md | 1 +
cmake/FindNCCL.cmake | 28 ++
docker/Dockerfile | 20 +-
docs/parallel.md | 44 ++-
include/ctranslate2/devices.h | 30 +++
include/ctranslate2/layers/attention.h | 4 +-
include/ctranslate2/layers/common.h | 4 +-
include/ctranslate2/layers/transformer.h | 3 +
include/ctranslate2/models/model.h | 12 +-
include/ctranslate2/ops/nccl_ops.h | 35 +++
include/ctranslate2/ops/ops.h | 1 +
include/ctranslate2/replica_pool.h | 2 +
include/ctranslate2/utils.h | 3 +
python/cpp/encoder.cc | 6 +-
python/cpp/generator.cc | 6 +-
python/cpp/module.cc | 1 +
python/cpp/module.h | 1 +
python/cpp/mpi.cc | 30 +++
python/cpp/replica_pool.h | 6 +
python/cpp/translator.cc | 8 +-
python/cpp/wav2vec2.cc | 6 +-
python/cpp/whisper.cc | 6 +-
python/ctranslate2/__init__.py | 1 +
python/ctranslate2/specs/model_spec.py | 3 +
python/ctranslate2/specs/transformer_spec.py | 13 +
.../tools/prepare_build_environment_linux.sh | 14 +-
src/cuda/mpi_stub.cc | 94 +++++++
src/cuda/mpi_stub.h | 18 ++
src/cuda/nccl_stub.cc | 93 +++++++
src/cuda/utils.h | 22 ++
src/devices.cc | 99 +++++++
src/layers/attention.cc | 31 ++-
src/layers/common.cc | 45 +++-
src/layers/transformer.cc | 42 ++-
src/models/model.cc | 254 +++++++++++++++++-
src/ops/nccl_ops.cc | 23 ++
src/ops/nccl_ops_cpu.cc | 23 ++
src/ops/nccl_ops_gpu.cu | 93 +++++++
src/utils.cc | 1 -
tools/benchmark_tensor_parallel/README.md | 18 ++
tools/benchmark_tensor_parallel/benchmark.py | 172 ++++++++++++
.../requirements.txt | 3 +
43 files changed, 1313 insertions(+), 42 deletions(-)
create mode 100644 cmake/FindNCCL.cmake
create mode 100644 include/ctranslate2/ops/nccl_ops.h
create mode 100644 python/cpp/mpi.cc
create mode 100644 src/cuda/mpi_stub.cc
create mode 100644 src/cuda/mpi_stub.h
create mode 100644 src/cuda/nccl_stub.cc
create mode 100644 src/ops/nccl_ops.cc
create mode 100644 src/ops/nccl_ops_cpu.cc
create mode 100644 src/ops/nccl_ops_gpu.cu
create mode 100644 tools/benchmark_tensor_parallel/README.md
create mode 100644 tools/benchmark_tensor_parallel/benchmark.py
create mode 100644 tools/benchmark_tensor_parallel/requirements.txt
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 1089106cc..a32a45fe7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -20,6 +20,7 @@ option(ENABLE_PROFILING "Compile with profiling support" OFF)
option(BUILD_CLI "Compile the clients" ON)
option(BUILD_TESTS "Compile the tests" OFF)
option(BUILD_SHARED_LIBS "Build shared libraries" ON)
+option(WITH_TENSOR_PARALLEL "Compile with NCCL and MPI backend" OFF)
if(ENABLE_PROFILING)
message(STATUS "Enable profiling support")
@@ -179,6 +180,8 @@ set(SOURCES
src/ops/topp_mask.cc
src/ops/topp_mask_cpu.cc
src/ops/transpose.cc
+ src/ops/nccl_ops.cc
+ src/ops/nccl_ops_cpu.cc
src/padder.cc
src/profiler.cc
src/random.cc
@@ -191,7 +194,7 @@ set(SOURCES
src/utils.cc
src/vocabulary.cc
src/vocabulary_map.cc
- )
+)
set(LIBRARIES
${CMAKE_THREAD_LIBS_INIT}
spdlog::spdlog_header_only
@@ -419,6 +422,24 @@ endif()
if (WITH_CUDA)
find_package(CUDA 11.0 REQUIRED)
+ list(APPEND CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake)
+ if (WITH_TENSOR_PARALLEL)
+ find_package(MPI REQUIRED)
+ find_package(NCCL REQUIRED)
+ include_directories(${NCCL_INCLUDE_DIR})
+ include_directories(${MPI_INCLUDE_PATH})
+ if(CUDA_DYNAMIC_LOADING)
+ list(APPEND SOURCES src/cuda/mpi_stub.cc)
+ list(APPEND SOURCES src/cuda/nccl_stub.cc)
+ add_definitions(-DCT2_WITH_CUDA_DYNAMIC_LOADING)
+ else ()
+ list(APPEND LIBRARIES ${NCCL_LIBRARY})
+ list(APPEND LIBRARIES ${MPI_LIBRARIES})
+ endif ()
+ add_definitions(-DCT2_WITH_TENSOR_PARALLEL)
+ endif ()
+ include_directories(${CUDA_TOOLKIT_ROOT_DIR}/include)
+
add_definitions(-DCT2_WITH_CUDA)
if(MSVC)
if(BUILD_SHARED_LIBS)
@@ -522,7 +543,8 @@ if (WITH_CUDA)
src/ops/topk_gpu.cu
src/ops/topp_mask_gpu.cu
src/ops/quantize_gpu.cu
- )
+ src/ops/nccl_ops_gpu.cu
+ )
elseif(WITH_CUDNN)
message(FATAL_ERROR "WITH_CUDNN=ON requires WITH_CUDA=ON")
else()
@@ -546,6 +568,10 @@ target_include_directories(${PROJECT_NAME} BEFORE
PRIVATE ${PRIVATE_INCLUDE_DIRECTORIES}
)
+if (WITH_TENSOR_PARALLEL AND CUDA_DYNAMIC_LOADING)
+ target_compile_options(${PROJECT_NAME} PRIVATE -DOMPI_SKIP_MPICXX)
+endif()
+
if(BUILD_TESTS)
add_subdirectory(tests)
endif()
@@ -587,6 +613,11 @@ configure_file(cmake/${PROJECT_NAME}Config.cmake
COPYONLY
)
+configure_file(cmake/FindNCCL.cmake
+ "${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/FindNCCL.cmake"
+ COPYONLY
+)
+
set(ConfigPackageLocation ${CMAKE_INSTALL_LIBDIR}/cmake/${PROJECT_NAME})
if(BUILD_SHARED_LIBS)
@@ -603,6 +634,7 @@ endif()
install(
FILES
cmake/${PROJECT_NAME}Config.cmake
+ cmake/FindNCCL.cmake
"${CMAKE_CURRENT_BINARY_DIR}/${PROJECT_NAME}/${PROJECT_NAME}ConfigVersion.cmake"
DESTINATION
${ConfigPackageLocation}
diff --git a/README.md b/README.md
index 5a9f51abd..53fd07430 100644
--- a/README.md
+++ b/README.md
@@ -34,6 +34,7 @@ The project is production-oriented and comes with [backward compatibility guaran
* **Lightweight on disk**
Quantization can make the models 4 times smaller on disk with minimal accuracy loss.
* **Simple integration**
The project has few dependencies and exposes simple APIs in [Python](https://opennmt.net/CTranslate2/python/overview.html) and C++ to cover most integration needs.
* **Configurable and interactive decoding**
[Advanced decoding features](https://opennmt.net/CTranslate2/decoding.html) allow autocompleting a partial sequence and returning alternatives at a specific location in the sequence.
+* **Tensor parallelism for distributed inference**
Some of these features are difficult to achieve with standard deep learning frameworks and are the motivation for this project.
diff --git a/cmake/FindNCCL.cmake b/cmake/FindNCCL.cmake
new file mode 100644
index 000000000..c5f0e31e8
--- /dev/null
+++ b/cmake/FindNCCL.cmake
@@ -0,0 +1,28 @@
+# Find the NCCL libraries
+#
+# The following variables are optionally searched for defaults
+# NCCL_ROOT_DIR: Base directory where all NCCL components are found
+#
+# The following are set after configuration is done:
+# NCCL_FOUND
+# NCCL_INCLUDE_DIR
+# NCCL_LIBRARY
+
+find_path(NCCL_INCLUDE_DIR NAMES nccl.h
+ PATHS ${NCCL_ROOT_DIR}/include
+)
+
+find_library(NCCL_LIBRARY NAMES nccl
+ PATHS ${NCCL_ROOT_DIR}/lib ${NCCL_ROOT_DIR}/lib64)
+
+include(FindPackageHandleStandardArgs)
+find_package_handle_standard_args(NCCL DEFAULT_MSG NCCL_INCLUDE_DIR
+ NCCL_LIBRARY)
+
+if (NCCL_FOUND)
+ message(STATUS "Found NCCL (include: ${NCCL_INCLUDE_DIR}, library:
+ ${NCCL_LIBRARY})")
+ mark_as_advanced(NCCL_INCLUDE_DIR NCCL_LIBRARY)
+ set(NCCL_VERSION "${NCCL_MAJOR}.${NCCL_MINOR}.${NCCL_PATCH}")
+
+endif ()
diff --git a/docker/Dockerfile b/docker/Dockerfile
index bfc7dfcbf..c1d5a47fb 100644
--- a/docker/Dockerfile
+++ b/docker/Dockerfile
@@ -35,6 +35,16 @@ RUN wget -q https://github.com/oneapi-src/oneDNN/archive/refs/tags/v${ONEDNN_VER
cd .. && \
rm -r oneDNN-*
+ENV OPENMPI_VERSION=4.1.6
+RUN wget -q https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2 && \
+ tar xf *.tar.bz2 && \
+ rm *.tar.bz2 && \
+ cd openmpi-* && \
+ ./configure && \
+ make -j$(nproc) install && \
+ cd .. && \
+ rm -r openmpi-*
+
COPY third_party third_party
COPY cli cli
COPY include include
@@ -50,13 +60,14 @@ ENV CUDA_NVCC_FLAGS=${CUDA_NVCC_FLAGS:-"-Xfatbin=-compress-all"}
ARG CUDA_ARCH_LIST
ENV CUDA_ARCH_LIST=${CUDA_ARCH_LIST:-"Common"}
ENV CTRANSLATE2_ROOT=/opt/ctranslate2
+ENV LD_LIBRARY_PATH=/usr/local/lib/:${LD_LIBRARY_PATH}
-RUN mkdir build && \
- cd build && \
+RUN mkdir build_tmp && \
+ cd build_tmp && \
cmake -DCMAKE_INSTALL_PREFIX=${CTRANSLATE2_ROOT} \
-DWITH_CUDA=ON -DWITH_CUDNN=ON -DWITH_MKL=ON -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP \
-DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="${CXX_FLAGS}" \
- -DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" .. && \
+ -DCUDA_NVCC_FLAGS="${CUDA_NVCC_FLAGS}" -DCUDA_ARCH_LIST="${CUDA_ARCH_LIST}" -DWITH_TENSOR_PARALLEL=ON .. && \
VERBOSE=1 make -j$(nproc) install
ENV LANG=en_US.UTF-8
@@ -74,6 +85,9 @@ RUN apt-get update && \
apt-get install -y --no-install-recommends \
libcublas-12-2 \
libcudnn8=8.9.7.29-1+cuda12.2 \
+ libnccl2=2.19.3-1+cuda12.2 \
+ libopenmpi3=4.0.3-0ubuntu1 \
+ openmpi-bin \
libgomp1 \
python3-pip \
&& \
diff --git a/docs/parallel.md b/docs/parallel.md
index 604fea122..ba827d7b2 100644
--- a/docs/parallel.md
+++ b/docs/parallel.md
@@ -42,8 +42,50 @@ Parallelization with multiple Python threads is possible because all computation
```
## Model and tensor parallelism
+Models used with [`Translator`](python/ctranslate2.Translator.rst) and [`Generator`](python/ctranslate2.Generator.rst) can be split across multiple GPUs.
+This is useful when the model is too large to fit on a single GPU.
-These types of parallelism are not yet implemented in CTranslate2.
+```python
+translator = ctranslate2.Translator(model_path, device="cuda", tensor_parallel=True)
+```
+
+Set up the environment:
+* Install [Open MPI](https://www.open-mpi.org/)
+* Configure Open MPI by creating a host file, e.g. ``hostfile``:
+```bash
+[ip address or dns] slots=nbGPU1
+[other ip address or dns] slots=nbGPU2
+```
+
+Run:
+* Run the application with multiple processes to use tensor parallelism:
+```bash
+mpirun -np nbGPUExpected -hostfile hostfile python3 script
+```
+
+If you want to use tensor parallelism across multiple machines, some additional configuration is needed:
+* Make sure the master and the slave nodes can connect to each other over SSH with public key authentication
+* Export all necessary environment variables from the master to the slaves, as in the example below:
+```bash
+mpirun -x VIRTUAL_ENV_PROMPT -x PATH -x VIRTUAL_ENV -x _ -x LD_LIBRARY_PATH -np nbGPUExpected -hostfile hostfile python3 script
+```
+See the [Open MPI documentation](https://www.open-mpi.org/doc/) for more information.
+
+* In this mode, the application runs as multiple processes. You can restrict output to the master process with the following check (a complete example is given after it):
+```python
+if ctranslate2.MpiInfo.getCurRank() == 0:
+ print(...)
+```
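+
+Putting these pieces together, here is a minimal sketch of a script to launch with `mpirun` as shown above; the model directory `ende_ctranslate2` and the pre-tokenized input are placeholders for your own converted model and tokenizer output:
+```python
+import ctranslate2
+
+# Each MPI process loads its partition of the model on its own GPU.
+translator = ctranslate2.Translator(
+    "ende_ctranslate2", device="cuda", tensor_parallel=True
+)
+
+# Every rank takes part in the computation.
+results = translator.translate_batch([["▁Hello", "▁world", "!"]])
+
+# Only the master process prints the result.
+if ctranslate2.MpiInfo.getCurRank() == 0:
+    print(results[0].hypotheses[0])
+```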
+
+```{note}
+Running a model in tensor parallel mode on a single machine can boost performance, but sharing the model across multiple machines
+can be slower because of the network latency between them.
+```
+
+```{note}
+In tensor parallel mode, `inter_threads` is still supported to run multiple workers. However, `device_index` no longer has any effect
+because tensor parallel mode only considers the GPUs available on the system and the number of GPUs you want to use (see the sketch below).
+```
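+
+As a small sketch of the note above (the model path is a placeholder for your own converted model), multiple workers can still be requested with `inter_threads`, while the GPU placement is derived from the MPI rank rather than from `device_index`:
+```python
+import ctranslate2
+
+# `device_index` would be ignored here: each process selects its GPU from its MPI rank.
+# `inter_threads` still controls how many workers serve requests in each process.
+translator = ctranslate2.Translator(
+    "ende_ctranslate2",
+    device="cuda",
+    tensor_parallel=True,
+    inter_threads=2,
+)
+```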
## Asynchronous execution
diff --git a/include/ctranslate2/devices.h b/include/ctranslate2/devices.h
index 2691efc3a..674713b8f 100644
--- a/include/ctranslate2/devices.h
+++ b/include/ctranslate2/devices.h
@@ -2,6 +2,10 @@
#include
#include
+#include
+#ifdef CT2_WITH_TENSOR_PARALLEL
+# include <nccl.h>
+#endif
namespace ctranslate2 {
@@ -45,4 +49,30 @@ namespace ctranslate2 {
int _new_index;
};
+ extern int my_rank;
+ extern int local_rank;
+ extern int n_ranks;
+
+ class ScopedMPISetter {
+ public:
+ ScopedMPISetter();
+ ~ScopedMPISetter();
+
+ static int getNRanks();
+ static int getCurRank();
+ static int getLocalRank();
+
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ static ncclComm_t getNcclComm();
+#endif
+
+ static void finalize();
+
+ private:
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ static uint64_t getHostHash(const char *string);
+ static void getHostName(char *hostname, int maxlen);
+ static std::vector<ncclComm_t*> _nccl_comms;
+#endif
+ };
}
diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index b342f4faa..d2deb5e03 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -43,7 +43,7 @@ namespace ctranslate2 {
}
bool multi_query() const {
- return _num_heads_kv == 1;
+ return _multi_query;
}
static StorageView prepare_length_mask(const StorageView& lengths,
@@ -53,6 +53,7 @@ namespace ctranslate2 {
const bool multi_query = false);
private:
+ const bool _tensor_parallel;
const dim_t _num_heads;
const bool _self_attention;
const bool _is_decoder;
@@ -68,6 +69,7 @@ namespace ctranslate2 {
const StorageView* _relative_position_values;
dim_t _maximum_relative_position;
const float _queries_scale;
+ const bool _multi_query;
const dim_t _num_heads_kv;
const bool _merge_time_and_head_dims;
const dim_t _cache_time_dim;
diff --git a/include/ctranslate2/layers/common.h b/include/ctranslate2/layers/common.h
index cb8586b78..6c69275b8 100644
--- a/include/ctranslate2/layers/common.h
+++ b/include/ctranslate2/layers/common.h
@@ -127,7 +127,8 @@ namespace ctranslate2 {
public:
Dense(const models::Model& model,
const std::string& scope,
- const ops::ActivationType* activation_type = nullptr);
+ const ops::ActivationType* activation_type = nullptr,
+ const bool is_layer_out = false);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
@@ -147,6 +148,7 @@ namespace ctranslate2 {
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
+ const bool _is_layer_out;
};
class LayerNorm : public Layer
diff --git a/include/ctranslate2/layers/transformer.h b/include/ctranslate2/layers/transformer.h
index 61b9fae47..a7183a30d 100644
--- a/include/ctranslate2/layers/transformer.h
+++ b/include/ctranslate2/layers/transformer.h
@@ -34,6 +34,7 @@ namespace ctranslate2 {
const Dense _ff1;
const std::unique_ptr _ff1_noact;
const Dense _ff2;
+ const bool _tensor_parallel;
};
class TransformerEncoderLayer : public Layer
@@ -149,6 +150,7 @@ namespace ctranslate2 {
const std::unique_ptr _output_norm;
const std::vector> _layers;
const std::unique_ptr _position_encoder;
+ const bool _tensor_parallel;
};
class TransformerDecoder : public Decoder
@@ -211,6 +213,7 @@ namespace ctranslate2 {
bool _average_alignment_heads;
Dense _proj;
const dim_t _sliding_window;
+ const bool _tensor_parallel;
};
}
diff --git a/include/ctranslate2/models/model.h b/include/ctranslate2/models/model.h
index 43a4ea5b9..1bd7a4c14 100644
--- a/include/ctranslate2/models/model.h
+++ b/include/ctranslate2/models/model.h
@@ -26,11 +26,13 @@ namespace ctranslate2 {
static std::shared_ptr load(const std::string& path,
Device device = Device::CPU,
int device_index = 0,
- ComputeType compute_type = ComputeType::DEFAULT);
+ ComputeType compute_type = ComputeType::DEFAULT,
+ bool tensor_parallel = false);
static std::shared_ptr load(ModelReader& model_reader,
Device device = Device::CPU,
int device_index = 0,
- ComputeType compute_type = ComputeType::DEFAULT);
+ ComputeType compute_type = ComputeType::DEFAULT,
+ bool tensor_parallel = false);
virtual std::unique_ptr as_sequence_to_sequence() const;
virtual std::unique_ptr as_sequence_generator() const;
@@ -78,6 +80,10 @@ namespace ctranslate2 {
return _binary_version >= 5;
}
+ bool tensor_parallel() const {
+ return _tensor_parallel;
+ }
+
virtual bool use_global_int16_scale() const {
return true;
}
@@ -163,6 +169,7 @@ namespace ctranslate2 {
ComputeType _effective_compute_type = ComputeType::DEFAULT;
dim_t _preferred_size_multiple = 1;
std::unordered_map> _variable_index;
+ bool _tensor_parallel = false;
};
template<>
@@ -191,6 +198,7 @@ namespace ctranslate2 {
std::vector device_indices = {0};
size_t num_replicas_per_device = 1;
ComputeType compute_type = ComputeType::DEFAULT;
+ bool tensor_parallel = false;
};
// Base class for replicas.
diff --git a/include/ctranslate2/ops/nccl_ops.h b/include/ctranslate2/ops/nccl_ops.h
new file mode 100644
index 000000000..d610d972f
--- /dev/null
+++ b/include/ctranslate2/ops/nccl_ops.h
@@ -0,0 +1,35 @@
+#pragma once
+
+#include "op.h"
+
+namespace ctranslate2 {
+ namespace ops {
+ class ReduceAll : public Op {
+ public:
+ enum class RED_OP {
+ SUM,
+ PROD,
+ MIN,
+ MAX,
+ AVG
+ };
+
+ explicit ReduceAll(RED_OP op = RED_OP::SUM);
+ void operator()(const StorageView& input, StorageView& output) const;
+ private:
+ RED_OP _reduce_op;
+
+ template <Device D, typename T>
+ void compute(const StorageView& input, StorageView& output) const;
+ };
+
+ class GatherAll : public Op {
+ public:
+ explicit GatherAll();
+ void operator()(const StorageView& input, StorageView& output) const;
+ private:
+ template <Device D, typename T>
+ void compute(const StorageView& input, StorageView& output) const;
+ };
+ }
+}
\ No newline at end of file
diff --git a/include/ctranslate2/ops/ops.h b/include/ctranslate2/ops/ops.h
index 051c81acc..f03d0211a 100644
--- a/include/ctranslate2/ops/ops.h
+++ b/include/ctranslate2/ops/ops.h
@@ -37,3 +37,4 @@
#include "rotary.h"
#include "alibi_add.h"
#include "slide.h"
+#include "nccl_ops.h"
diff --git a/include/ctranslate2/replica_pool.h b/include/ctranslate2/replica_pool.h
index efc9824d1..8c8e15d8e 100644
--- a/include/ctranslate2/replica_pool.h
+++ b/include/ctranslate2/replica_pool.h
@@ -34,11 +34,13 @@ namespace ctranslate2 {
const Device device,
const ComputeType compute_type = ComputeType::DEFAULT,
const std::vector& device_indices = {0},
+ const bool tensor_parallel = false,
const ReplicaPoolConfig& config = {}) {
models::ModelLoader model_loader(model_path);
model_loader.device = device;
model_loader.device_indices = device_indices;
model_loader.compute_type = compute_type;
+ model_loader.tensor_parallel = tensor_parallel;
initialize_pool(model_loader, config);
}
diff --git a/include/ctranslate2/utils.h b/include/ctranslate2/utils.h
index c8e7ef78b..23c58cb82 100644
--- a/include/ctranslate2/utils.h
+++ b/include/ctranslate2/utils.h
@@ -4,6 +4,7 @@
#include
#include
#include
+#include "ctranslate2/types.h"
namespace ctranslate2 {
@@ -92,5 +93,7 @@ namespace ctranslate2 {
#endif
#define THROW_RUNTIME_ERROR(MESSAGE) THROW_EXCEPTION(std::runtime_error, MESSAGE)
#define THROW_INVALID_ARGUMENT(MESSAGE) THROW_EXCEPTION(std::invalid_argument, MESSAGE)
+#define SAFE_DIVIDE(x, y) ((y != 0 && (x % y == 0)) ? (x / y) : (throw std::runtime_error("Division has a remainder. " \
+ "The model can't be run in tensor parallel mode on " + std::to_string(y) + " nodes")))
}
diff --git a/python/cpp/encoder.cc b/python/cpp/encoder.cc
index ea8b1a430..9a50923ac 100644
--- a/python/cpp/encoder.cc
+++ b/python/cpp/encoder.cc
@@ -71,7 +71,7 @@ namespace ctranslate2 {
>>> encoder.forward_batch([["▁Hello", "▁world", "!"]])
)pbdoc")
- .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
+ .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
@@ -80,6 +80,7 @@ namespace ctranslate2 {
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
+ py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes the encoder.
@@ -96,6 +97,7 @@ namespace ctranslate2 {
max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
+ tensor_parallel: Run the model with tensor parallel mode.
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
@@ -111,6 +113,8 @@ namespace ctranslate2 {
"Number of encoders backing this instance.")
.def_property_readonly("num_queued_batches", &EncoderWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
+ .def_property_readonly("tensor_parallel", &EncoderWrapper::tensor_parallel,
+ "Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &EncoderWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
diff --git a/python/cpp/generator.cc b/python/cpp/generator.cc
index 981c6da68..93b1a229a 100644
--- a/python/cpp/generator.cc
+++ b/python/cpp/generator.cc
@@ -128,7 +128,7 @@ namespace ctranslate2 {
>>> generator.generate_batch([[""]], max_length=50, sampling_topk=20)
)pbdoc")
- .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
+ .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
@@ -137,6 +137,7 @@ namespace ctranslate2 {
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
+ py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes the generator.
@@ -153,6 +154,7 @@ namespace ctranslate2 {
max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
+ tensor_parallel: Run the model with tensor parallel mode.
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
@@ -168,6 +170,8 @@ namespace ctranslate2 {
"Number of generators backing this instance.")
.def_property_readonly("num_queued_batches", &GeneratorWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
+ .def_property_readonly("tensor_parallel", &GeneratorWrapper::tensor_parallel,
+ "Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &GeneratorWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
diff --git a/python/cpp/module.cc b/python/cpp/module.cc
index 4a9e47561..4489d5314 100644
--- a/python/cpp/module.cc
+++ b/python/cpp/module.cc
@@ -87,4 +87,5 @@ PYBIND11_MODULE(_ext, m)
ctranslate2::python::register_encoder(m);
ctranslate2::python::register_whisper(m);
ctranslate2::python::register_wav2vec2(m);
+ ctranslate2::python::register_mpi(m);
}
diff --git a/python/cpp/module.h b/python/cpp/module.h
index 01fdbdf59..9c9a9a2ff 100644
--- a/python/cpp/module.h
+++ b/python/cpp/module.h
@@ -18,6 +18,7 @@ namespace ctranslate2 {
void register_translator(py::module& m);
void register_whisper(py::module& m);
void register_wav2vec2(py::module& m);
+ void register_mpi(py::module& m);
}
}
diff --git a/python/cpp/mpi.cc b/python/cpp/mpi.cc
new file mode 100644
index 000000000..01abf1157
--- /dev/null
+++ b/python/cpp/mpi.cc
@@ -0,0 +1,30 @@
+#include "module.h"
+
+#include <ctranslate2/devices.h>
+
+#include "utils.h"
+
+namespace ctranslate2 {
+ namespace python {
+
+ void register_mpi(py::module& m) {
+ py::class_<ScopedMPISetter>(
+ m, "MpiInfo",
+ R"pbdoc(
+ An object to manage the MPI communication between processes.
+ It provides information about the MPI connection.
+ )pbdoc")
+
+ .def_static("getNRanks", &ScopedMPISetter::getNRanks,
+ "Get the number of gpus running for the current model.")
+
+ .def_static("getCurRank", &ScopedMPISetter::getCurRank,
+ "Get the current rank of process.")
+
+ .def_static("getLocalRank", &ScopedMPISetter::getLocalRank,
+ "Get the current GPU id used by process.")
+ ;
+ }
+
+ }
+}
diff --git a/python/cpp/replica_pool.h b/python/cpp/replica_pool.h
index a735ea363..d71bf6b96 100644
--- a/python/cpp/replica_pool.h
+++ b/python/cpp/replica_pool.h
@@ -44,6 +44,7 @@ namespace ctranslate2 {
size_t inter_threads,
size_t intra_threads,
long max_queued_batches,
+ bool tensor_parallel,
py::object files)
: _model_loader(create_model_reader(model_path, files))
{
@@ -53,6 +54,7 @@ namespace ctranslate2 {
_model_loader.device_indices = std::visit(DeviceIndexResolver(), device_index);
_model_loader.compute_type = std::visit(ComputeTypeResolver(device), compute_type);
_model_loader.num_replicas_per_device = inter_threads;
+ _model_loader.tensor_parallel = tensor_parallel;
_pool_config.num_threads_per_replica = intra_threads;
_pool_config.max_queued_batches = max_queued_batches;
@@ -77,6 +79,10 @@ namespace ctranslate2 {
return compute_type_to_str(model()->effective_compute_type());
}
+ bool tensor_parallel() const {
+ return _model_loader.tensor_parallel;
+ }
+
size_t num_replicas() const {
return _pool->num_replicas();
}
diff --git a/python/cpp/translator.cc b/python/cpp/translator.cc
index b46d7ab9e..8e4a8a4be 100644
--- a/python/cpp/translator.cc
+++ b/python/cpp/translator.cc
@@ -33,6 +33,7 @@ namespace ctranslate2 {
size_t inter_threads,
size_t intra_threads,
long max_queued_batches,
+ bool tensor_parallel,
py::object files)
: ReplicaPoolHelper(model_path,
device,
@@ -41,6 +42,7 @@ namespace ctranslate2 {
inter_threads,
intra_threads,
max_queued_batches,
+ tensor_parallel,
files)
, _device(_model_loader.device)
, _device_index(_model_loader.device_indices)
@@ -378,7 +380,7 @@ namespace ctranslate2 {
>>> translator.translate_batch([["▁Hello", "▁world", "!"]])
)pbdoc")
- .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
+ .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
@@ -387,6 +389,7 @@ namespace ctranslate2 {
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
+ py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes the translator.
@@ -403,6 +406,7 @@ namespace ctranslate2 {
max_queued_batches: Maximum numbers of batches in the queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
+ tensor_parallel: Run the model with tensor parallel mode.
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
@@ -418,6 +422,8 @@ namespace ctranslate2 {
"Number of translators backing this instance.")
.def_property_readonly("num_queued_batches", &TranslatorWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
+ .def_property_readonly("tensor_parallel", &TranslatorWrapper::tensor_parallel,
+ "Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &TranslatorWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
diff --git a/python/cpp/wav2vec2.cc b/python/cpp/wav2vec2.cc
index ced116cb4..343caa158 100644
--- a/python/cpp/wav2vec2.cc
+++ b/python/cpp/wav2vec2.cc
@@ -27,7 +27,7 @@ namespace ctranslate2 {
https://github.com/facebookresearch/fairseq/tree/main/examples/wav2vec
)pbdoc")
- .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
+ .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
@@ -36,6 +36,7 @@ namespace ctranslate2 {
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
+ py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes a Wav2Vec2 model from a converted model.
@@ -52,6 +53,7 @@ namespace ctranslate2 {
max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
+ tensor_parallel: Run the model with tensor parallel mode.
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
@@ -67,6 +69,8 @@ namespace ctranslate2 {
"Number of model workers backing this instance.")
.def_property_readonly("num_queued_batches", &Wav2Vec2Wrapper::num_queued_batches,
"Number of batches waiting to be processed.")
+ .def_property_readonly("tensor_parallel", &Wav2Vec2Wrapper::tensor_parallel,
+ "Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &Wav2Vec2Wrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
diff --git a/python/cpp/whisper.cc b/python/cpp/whisper.cc
index cb1b45a7d..47be8ece7 100644
--- a/python/cpp/whisper.cc
+++ b/python/cpp/whisper.cc
@@ -163,7 +163,7 @@ namespace ctranslate2 {
.def_property_readonly("num_languages", &WhisperWrapper::num_languages,
"Returns the number of languages supported.")
- .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, py::object>(),
+ .def(py::init<const std::string&, const std::string&, const std::variant<int, std::vector<int>>&, const StringOrMap&, size_t, size_t, long, bool, py::object>(),
py::arg("model_path"),
py::arg("device")="cpu",
py::kw_only(),
@@ -172,6 +172,7 @@ namespace ctranslate2 {
py::arg("inter_threads")=1,
py::arg("intra_threads")=0,
py::arg("max_queued_batches")=0,
+ py::arg("tensor_parallel")=false,
py::arg("files")=py::none(),
R"pbdoc(
Initializes a Whisper model from a converted model.
@@ -188,6 +189,7 @@ namespace ctranslate2 {
max_queued_batches: Maximum numbers of batches in the worker queue (-1 for unlimited,
0 for an automatic value). When the queue is full, future requests will block
until a free slot is available.
+ tensor_parallel: Run the model with tensor parallel mode.
files: Load model files from the memory. This argument is a dictionary mapping
file names to file contents as file-like or bytes objects. If this is set,
:obj:`model_path` acts as an identifier for this model.
@@ -203,6 +205,8 @@ namespace ctranslate2 {
"Number of model workers backing this instance.")
.def_property_readonly("num_queued_batches", &WhisperWrapper::num_queued_batches,
"Number of batches waiting to be processed.")
+ .def_property_readonly("tensor_parallel", &WhisperWrapper::tensor_parallel,
+ "Run model with tensor parallel mode.")
.def_property_readonly("num_active_batches", &WhisperWrapper::num_active_batches,
"Number of batches waiting to be processed or currently processed.")
diff --git a/python/ctranslate2/__init__.py b/python/ctranslate2/__init__.py
index a80997645..88da68aec 100644
--- a/python/ctranslate2/__init__.py
+++ b/python/ctranslate2/__init__.py
@@ -30,6 +30,7 @@
GenerationResult,
GenerationStepResult,
Generator,
+ MpiInfo,
ScoringResult,
StorageView,
TranslationResult,
diff --git a/python/ctranslate2/specs/model_spec.py b/python/ctranslate2/specs/model_spec.py
index 4cb765636..28b2e4f21 100644
--- a/python/ctranslate2/specs/model_spec.py
+++ b/python/ctranslate2/specs/model_spec.py
@@ -291,6 +291,9 @@ def to_dict(self):
if not key.startswith("_")
}
+ def add_attribute(self, key, value):
+ self.__dict__[key] = value
+
def save_as_json(self, path):
"""Saves the configuration as a JSON file."""
with open(path, "w", encoding="utf-8") as config_file:
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index 9de261c58..c3f8d91be 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -45,6 +45,7 @@ def __init__(
rms_norm: Use the root mean square layer normalization.
multi_query_attention: Use multi-query attention.
"""
+ self.multi_query_attention = multi_query_attention
self.num_heads = np.dtype("int16").type(num_heads)
self.pre_norm = pre_norm
self.activation = np.dtype("int8").type(activation)
@@ -207,6 +208,9 @@ def __init__(
for _ in range(num_layers)
]
self.start_from_zero_embedding = False
+ self.multi_query_attention = multi_query_attention or (
+ num_heads_kv != num_heads
+ )
if project_in_out:
self.project_in = common_spec.LinearSpec()
@@ -339,6 +343,9 @@ def __init__(
super().__init__()
self.encoder = encoder
self.decoder = decoder
+ self._config.add_attribute(
+ "multi_query_attention", self.encoder.multi_query_attention
+ )
@classmethod
def from_config(
@@ -467,6 +474,9 @@ def __init__(self, decoder: TransformerDecoderSpec):
super().__init__()
self.decoder = decoder
+ self._config.add_attribute(
+ "multi_query_attention", self.decoder.multi_query_attention
+ )
@classmethod
def from_config(
@@ -608,6 +618,9 @@ def __init__(
super().__init__()
self.encoder = encoder
+ self._config.add_attribute(
+ "multi_query_attention", self.encoder.multi_query_attention
+ )
if pooling_layer:
self.pooler_dense = common_spec.LinearSpec()
diff --git a/python/tools/prepare_build_environment_linux.sh b/python/tools/prepare_build_environment_linux.sh
index 0350e01e7..89f8293f6 100755
--- a/python/tools/prepare_build_environment_linux.sh
+++ b/python/tools/prepare_build_environment_linux.sh
@@ -27,7 +27,8 @@ else
cuda-cudart-devel-12-2-12.2.140-1 \
libcurand-devel-12-2-10.3.3.141-1 \
libcudnn8-devel-8.9.7.29-1.cuda12.2 \
- libcublas-devel-12-2-12.2.5.6-1
+ libcublas-devel-12-2-12.2.5.6-1 \
+ libnccl-devel-2.19.3-1+cuda12.2
ln -s cuda-12.2 /usr/local/cuda
ONEAPI_VERSION=2023.2.0
@@ -44,6 +45,15 @@ else
cd ..
rm -r oneDNN-*
+ OPENMPI_VERSION=4.1.6
+ curl -L -O https://download.open-mpi.org/release/open-mpi/v4.1/openmpi-${OPENMPI_VERSION}.tar.bz2
+ tar xf *.tar.bz2 && rm *.tar.bz2
+ cd openmpi-*
+ ./configure
+ make -j$(nproc) install
+ cd ..
+ rm -r openmpi-*
+ export LD_LIBRARY_PATH="/usr/local/lib/:$LD_LIBRARY_PATH"
fi
mkdir build-release && cd build-release
@@ -51,7 +61,7 @@ mkdir build-release && cd build-release
if [ "$CIBW_ARCHS" == "aarch64" ]; then
cmake -DCMAKE_BUILD_TYPE=Release -DBUILD_CLI=OFF -DWITH_MKL=OFF -DOPENMP_RUNTIME=COMP -DCMAKE_PREFIX_PATH="/opt/OpenBLAS" -DWITH_OPENBLAS=ON -DWITH_RUY=ON ..
else
- cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msse4.1" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DWITH_CUDA=ON -DWITH_CUDNN=ON -DCUDA_DYNAMIC_LOADING=ON -DCUDA_NVCC_FLAGS="-Xfatbin=-compress-all" -DCUDA_ARCH_LIST="Common" ..
+ cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_CXX_FLAGS="-msse4.1" -DBUILD_CLI=OFF -DWITH_DNNL=ON -DOPENMP_RUNTIME=COMP -DWITH_CUDA=ON -DWITH_CUDNN=ON -DCUDA_DYNAMIC_LOADING=ON -DCUDA_NVCC_FLAGS="-Xfatbin=-compress-all" -DCUDA_ARCH_LIST="Common" -DWITH_TENSOR_PARALLEL=ON ..
fi
VERBOSE=1 make -j$(nproc) install
diff --git a/src/cuda/mpi_stub.cc b/src/cuda/mpi_stub.cc
new file mode 100644
index 000000000..a2a69c4da
--- /dev/null
+++ b/src/cuda/mpi_stub.cc
@@ -0,0 +1,94 @@
+#include
+#include
+
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
+#include
+
+#define OPENMPI_LIBNAME "libmpi.so." STR(OMPI_MAJOR_VERSION) STR(0)
+
+namespace ctranslate2 {
+
+ template <typename Signature>
+ static Signature load_symbol(void* handle, const char* name, const char* library_name) {
+ void* symbol = dlsym(handle, name);
+ if (!symbol)
+ throw std::runtime_error("Cannot load symbol " + std::string(name)
+ + " from library " + std::string(library_name));
+ return reinterpret_cast<Signature>(symbol);
+ }
+
+ static void* get_so_handle() {
+ static auto so_handle = []() {
+ void* handle = dlopen(OPENMPI_LIBNAME, RTLD_LAZY);
+ return handle;
+ }();
+ return so_handle;
+ }
+
+ template <typename Signature>
+ static Signature load_symbol(const char* name) {
+ void* handle = get_so_handle();
+ if (!handle)
+ throw std::runtime_error("Library " + std::string(OPENMPI_LIBNAME)
+ + " is not found or cannot be loaded");
+ return load_symbol<Signature>(handle, name, OPENMPI_LIBNAME);
+ }
+
+ template <typename Signature>
+ static Signature load_symbol_global(const char* name) {
+ void* handle = get_so_handle();
+ if (!handle)
+ return nullptr;
+ return load_symbol<Signature>(handle, name, OPENMPI_LIBNAME);
+ }
+}
+
+extern "C" {
+
+ int MPI_Allgather(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
+ void *recvbuf, int recvcount,
+ MPI_Datatype recvtype, MPI_Comm comm) {
+ using Signature = int(*)(const void *sendbuf, int sendcount, MPI_Datatype sendtype,
+ void *recvbuf, int recvcount,
+ MPI_Datatype recvtype, MPI_Comm comm);
+ static auto func = ctranslate2::load_symbol("MPI_Allgather");
+ return func(sendbuf, sendcount, sendtype, recvbuf, recvcount, recvtype, comm);
+ }
+
+ int MPI_Bcast(void *buffer, int count, MPI_Datatype datatype,
+ int root, MPI_Comm comm) {
+ using Signature = int(*)(void *buffer, int count, MPI_Datatype datatype,
+ int root, MPI_Comm comm);
+ static auto func = ctranslate2::load_symbol("MPI_Bcast");
+ return func(buffer, count, datatype, root, comm);
+ }
+
+ int MPI_Init(int *argc, char ***argv) {
+ using Signature = int(*)(int *argc, char ***argv);
+ static auto func = ctranslate2::load_symbol("MPI_Init");
+ return func(argc, argv);
+ }
+
+ int MPI_Finalize(void) {
+ using Signature = int(*)(void);
+ static auto func = ctranslate2::load_symbol("MPI_Finalize");
+ return func();
+ }
+
+ int MPI_Comm_rank(MPI_Comm comm, int *rank) {
+ using Signature = int(*)(MPI_Comm comm, int *size);
+ static auto func = ctranslate2::load_symbol("MPI_Comm_rank");
+ return func(comm, rank);
+ }
+
+ int MPI_Comm_size(MPI_Comm comm, int *size) {
+ using Signature = int(*)(MPI_Comm comm, int *size);
+ static auto func = ctranslate2::load_symbol("MPI_Comm_size");
+ return func(comm, size);
+ }
+}
+struct ompi_predefined_datatype_t* stub_mpi_datatype_null = ctranslate2::load_symbol_global<struct ompi_predefined_datatype_t*>("ompi_mpi_datatype_null");
+struct ompi_predefined_datatype_t* stub_ompi_mpi_byte = ctranslate2::load_symbol_global<struct ompi_predefined_datatype_t*>("ompi_mpi_byte");
+struct ompi_predefined_communicator_t* stub_ompi_mpi_comm_world = ctranslate2::load_symbol_global<struct ompi_predefined_communicator_t*>("ompi_mpi_comm_world");
\ No newline at end of file
diff --git a/src/cuda/mpi_stub.h b/src/cuda/mpi_stub.h
new file mode 100644
index 000000000..83803900a
--- /dev/null
+++ b/src/cuda/mpi_stub.h
@@ -0,0 +1,18 @@
+#pragma once
+
+#include <mpi.h>
+
+#ifdef CT2_WITH_CUDA_DYNAMIC_LOADING
+extern struct ompi_predefined_datatype_t* stub_mpi_datatype_null;
+#define STUB_MPI_DATATYPE_NULL OMPI_PREDEFINED_GLOBAL(MPI_Datatype, *stub_mpi_datatype_null)
+
+extern struct ompi_predefined_datatype_t* stub_ompi_mpi_byte;
+#define STUB_MPI_BYTE OMPI_PREDEFINED_GLOBAL(MPI_Datatype, *stub_ompi_mpi_byte)
+
+extern struct ompi_predefined_communicator_t* stub_ompi_mpi_comm_world;
+#define STUB_MPI_COMM_WORLD OMPI_PREDEFINED_GLOBAL(MPI_Comm, *stub_ompi_mpi_comm_world)
+#else
+#define STUB_MPI_DATATYPE_NULL MPI_DATATYPE_NULL
+#define STUB_MPI_BYTE MPI_BYTE
+#define STUB_MPI_COMM_WORLD MPI_COMM_WORLD
+#endif
\ No newline at end of file
diff --git a/src/cuda/nccl_stub.cc b/src/cuda/nccl_stub.cc
new file mode 100644
index 000000000..669518cb2
--- /dev/null
+++ b/src/cuda/nccl_stub.cc
@@ -0,0 +1,93 @@
+#include
+
+#include
+
+#define STR_HELPER(x) #x
+#define STR(x) STR_HELPER(x)
+
+#include
+#define NCCL_LIBNAME "libnccl.so." STR(NCCL_MAJOR)
+
+#include
+
+namespace ctranslate2 {
+
+ template <typename Signature>
+ static Signature load_symbol(void* handle, const char* name, const char* library_name) {
+ void* symbol = dlsym(handle, name);
+ if (!symbol)
+ throw std::runtime_error("Cannot load symbol " + std::string(name)
+ + " from library " + std::string(library_name));
+ return reinterpret_cast<Signature>(symbol);
+ }
+ static inline void log_nccl_version(void* handle) {
+ using Signature = ncclResult_t(*)(int*);
+ const auto nccl_get_version = load_symbol<Signature>(handle,
+ "ncclGetVersion",
+ NCCL_LIBNAME);
+ int version = 0;
+ nccl_get_version(&version);
+ spdlog::info("Loaded nccl library version {}", version);
+ }
+
+ static void* get_so_handle() {
+ static auto so_handle = []() {
+ void* handle = dlopen(NCCL_LIBNAME, RTLD_LAZY);
+ if (!handle)
+ throw std::runtime_error("Library " + std::string(NCCL_LIBNAME)
+ + " is not found or cannot be loaded");
+ log_nccl_version(handle);
+ return handle;
+ }();
+ return so_handle;
+ }
+
+ template <typename Signature>
+ static Signature load_symbol(const char* name) {
+ void* handle = get_so_handle();
+ return load_symbol<Signature>(handle, name, NCCL_LIBNAME);
+ }
+
+}
+
+extern "C" {
+ ncclResult_t ncclGetUniqueId(ncclUniqueId* uniqueId) {
+ using Signature = ncclResult_t(*)(ncclUniqueId* uniqueId);
+ static auto func = ctranslate2::load_symbol("ncclGetUniqueId");
+ return func(uniqueId);
+ }
+
+ ncclResult_t ncclCommInitRank(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank) {
+ using Signature = ncclResult_t(*)(ncclComm_t* comm, int nranks, ncclUniqueId commId, int rank);
+ static auto func = ctranslate2::load_symbol("ncclCommInitRank");
+ return func(comm, nranks, commId, rank);
+ }
+
+ ncclResult_t ncclCommDestroy(ncclComm_t comm) {
+ using Signature = ncclResult_t(*)(ncclComm_t comm);
+ static auto func = ctranslate2::load_symbol("ncclCommDestroy");
+ return func(comm);
+ }
+
+ ncclResult_t ncclCommAbort(ncclComm_t comm) {
+ using Signature = ncclResult_t(*)(ncclComm_t comm);
+ static auto func = ctranslate2::load_symbol("ncclCommAbort");
+ return func(comm);
+ }
+
+ ncclResult_t ncclAllReduce(const void* sendbuff, void* recvbuff, size_t count,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream) {
+ using Signature = ncclResult_t(*)(const void* sendbuff, void* recvbuff, size_t count,
+ ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, cudaStream_t stream);
+ static auto func = ctranslate2::load_symbol("ncclAllReduce");
+ return func(sendbuff, recvbuff, count, datatype, op, comm, stream);
+ }
+
+ ncclResult_t ncclAllGather(const void* sendbuff, void* recvbuff, size_t sendcount,
+ ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream) {
+ using Signature = ncclResult_t(*)(const void* sendbuff, void* recvbuff, size_t sendcount,
+ ncclDataType_t datatype, ncclComm_t comm, cudaStream_t stream);
+ static auto func = ctranslate2::load_symbol("ncclAllGather");
+ return func(sendbuff, recvbuff, sendcount, datatype, comm, stream);
+ }
+}
diff --git a/src/cuda/utils.h b/src/cuda/utils.h
index 29bc99a39..8c1c134fe 100644
--- a/src/cuda/utils.h
+++ b/src/cuda/utils.h
@@ -6,6 +6,10 @@
#include
#include
+#ifdef CT2_WITH_TENSOR_PARALLEL
+# include <mpi.h>
+# include <nccl.h>
+#endif
#ifdef CT2_WITH_CUDNN
# include
#endif
@@ -16,6 +20,24 @@
namespace ctranslate2 {
namespace cuda {
+#ifdef CT2_WITH_TENSOR_PARALLEL
+#define MPI_CHECK(ans) \
+ { \
+ int e = ans; \
+ if( e != MPI_SUCCESS ) \
+ THROW_RUNTIME_ERROR("MPI failed with error " \
+ + std::to_string(e)); \
+ }
+
+#define NCCL_CHECK(ans) \
+ { \
+ ncclResult_t r = ans; \
+ if( r != ncclSuccess ) \
+ THROW_RUNTIME_ERROR("NCCL failed with error " \
+ + std::to_string(r)); \
+ }
+#endif
+
#define CUDA_CHECK(ans) \
{ \
cudaError_t code = (ans); \
diff --git a/src/devices.cc b/src/devices.cc
index 3822cc3c3..47582f8be 100644
--- a/src/devices.cc
+++ b/src/devices.cc
@@ -3,6 +3,9 @@
#ifdef CT2_WITH_CUDA
# include "cuda/utils.h"
#endif
+#ifdef CT2_WITH_TENSOR_PARALLEL
+# include <unistd.h>
+#endif
#include "device_dispatch.h"
@@ -115,5 +118,101 @@ namespace ctranslate2 {
(void)device;
#endif
}
+ // Initialize the static member variable
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ std::vector<ncclComm_t*> ScopedMPISetter::_nccl_comms;
+#endif
+ int my_rank = 0;
+ int local_rank = 0;
+ int n_ranks = 1;
+
+ ScopedMPISetter::ScopedMPISetter() {
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ // initializing MPI
+ MPI_CHECK(MPI_Init(nullptr, nullptr));
+ MPI_CHECK(MPI_Comm_rank(STUB_MPI_COMM_WORLD, &my_rank));
+ MPI_CHECK(MPI_Comm_size(STUB_MPI_COMM_WORLD, &n_ranks));
+
+ uint64_t hostHashs[n_ranks];
+ char hostname[1024];
+ getHostName(hostname, 1024);
+ hostHashs[my_rank] = getHostHash(hostname);
+ MPI_CHECK(MPI_Allgather(MPI_IN_PLACE, 0, STUB_MPI_DATATYPE_NULL,
+ hostHashs, sizeof(uint64_t), STUB_MPI_BYTE, STUB_MPI_COMM_WORLD));
+ for (int p = 0; p < n_ranks; p++) {
+ if (p == my_rank) {
+ break;
+ }
+ if (hostHashs[p] == hostHashs[my_rank]) {
+ local_rank++;
+ }
+ }
+ atexit(finalize);
+#endif
+ }
+
+ ScopedMPISetter::~ScopedMPISetter() = default;
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ uint64_t ScopedMPISetter::getHostHash(const char *string) {
+ // Based on DJB2, result = result * 33 + char
+ uint64_t result = 5381;
+ for (int c = 0; string[c] != '\0'; c++) {
+ result = ((result << 5) + result) + string[c];
+ }
+ return result;
+ }
+
+ void ScopedMPISetter::getHostName(char *hostname, int maxlen) {
+ gethostname(hostname, maxlen);
+ for (int i = 0; i < maxlen; i++) {
+ if (hostname[i] == '.') {
+ hostname[i] = '\0';
+ return;
+ }
+ }
+ }
+
+ ncclComm_t ScopedMPISetter::getNcclComm() {
+ static thread_local ncclComm_t comm;
+ static thread_local ncclUniqueId id;
+
+ if (comm == nullptr) {
+ int nRanks = ScopedMPISetter::getNRanks();
+ int myRank = ScopedMPISetter::getCurRank();
+ if (myRank == 0) {
+ ncclGetUniqueId(&id);
+ }
+ MPI_CHECK(MPI_Bcast((void *) &id, sizeof(id), STUB_MPI_BYTE, 0, STUB_MPI_COMM_WORLD));
+ NCCL_CHECK(ncclCommInitRank(&comm, nRanks, id, myRank));
+ _nccl_comms.push_back(&comm);
+ }
+ return comm;
+ }
+#endif
+
+ void ScopedMPISetter::finalize() {
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ for (auto* comm : _nccl_comms) {
+ //finalizing NCCL
+ if (*comm) {
+ NCCL_CHECK(ncclCommAbort(*comm));
+ NCCL_CHECK(ncclCommDestroy(*comm));
+ }
+ }
+ MPI_CHECK(MPI_Finalize());
+#endif
+ }
+
+ int ScopedMPISetter::getNRanks() {
+ return n_ranks;
+ }
+
+ int ScopedMPISetter::getCurRank() {
+ return my_rank;
+ }
+
+ int ScopedMPISetter::getLocalRank() {
+ return local_rank;
+ }
}
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 2a066ba86..cf6074b2a 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -343,7 +343,10 @@ namespace ctranslate2 {
std::vector layers;
layers.reserve(num_linear_layers);
for (dim_t i = 0; i < num_linear_layers; ++i)
- layers.emplace_back(model, scope + "/linear_" + std::to_string(i));
+ if (i == (num_linear_layers - 1)) {
+ layers.emplace_back(model, scope + "/linear_" + std::to_string(i), nullptr, true);
+ } else
+ layers.emplace_back(model, scope + "/linear_" + std::to_string(i));
return layers;
}
@@ -376,11 +379,12 @@ namespace ctranslate2 {
bool pre_norm,
bool is_decoder,
Alibi* alibi)
- : _num_heads(num_heads)
+ : _tensor_parallel(model.tensor_parallel())
+ , _num_heads(_tensor_parallel ? SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks()) : num_heads)
, _self_attention(self_attention)
, _is_decoder(is_decoder)
, _linear(make_linear_layers(model, scope, self_attention))
- , _d_model(_linear.back().output_size())
+ , _d_model(_tensor_parallel ? SAFE_DIVIDE(_linear.back().output_size(), ScopedMPISetter::getNRanks()) : _linear.back().output_size())
, _d_head(model.get_attribute_with_default(scope + "/head_dim", _d_model / _num_heads))
, _pre_norm(pre_norm)
, _layer_norm(build_optional_layer(model, scope + "/layer_norm"))
@@ -392,11 +396,13 @@ namespace ctranslate2 {
, _queries_scale(model.get_attribute_with_default(
scope + "/queries_scale",
1.f / std::sqrt(static_cast(_d_head))))
- , _num_heads_kv(model.get_flag_with_default(scope + "/multi_query", false)
+ , _multi_query(model.get_flag_with_default(scope + "/multi_query", false))
+ , _num_heads_kv(_multi_query
? 1
- : model.get_attribute_with_default(scope + "/num_heads_kv",
- _num_heads))
- , _merge_time_and_head_dims(_num_heads_kv == 1
+ : (_tensor_parallel ? model.get_attribute_with_default(scope + "/num_heads_kv",
+ _num_heads * ScopedMPISetter::getNRanks()) / ScopedMPISetter::getNRanks()
+ : model.get_attribute_with_default(scope + "/num_heads_kv", _num_heads)))
+ , _merge_time_and_head_dims(_multi_query
&& !_relative_attention_bias
&& !_relative_position_keys
&& !_relative_position_values)
@@ -458,7 +464,7 @@ namespace ctranslate2 {
if (cached_keys == nullptr || cached_keys->empty()) {
_linear[1](values, fused_proj);
- if (_num_heads_kv == 1) {
+ if (_multi_query) {
if (values_padder)
values_padder->add_padding(fused_proj);
ops::Split(2, {_d_head, _d_head})(fused_proj, keys_proj, values_proj);
@@ -476,7 +482,7 @@ namespace ctranslate2 {
if (queries_proj.dim(1) == 1 && cached_keys)
beam_size = queries_proj.dim(0) / cached_keys->dim(0);
- if (_num_heads_kv == 1) {
+ if (_multi_query) {
if (queries_padder)
queries_padder->add_padding(queries_proj);
queries_proj.reshape({queries_proj.dim(0) / beam_size, -1, _d_head});
@@ -592,6 +598,13 @@ namespace ctranslate2 {
_linear.back()(context, output);
+ if (_tensor_parallel) {
+ Shape shape = output.shape();
+ StorageView tmp(std::move(shape), output.dtype(), output.device());
+ ops::ReduceAll ops_reduce_all(ops::ReduceAll::RED_OP::SUM);
+ ops_reduce_all(output, tmp);
+ output = std::move(tmp);
+ }
if (_layer_norm) {
ops::Add()(queries, output, output);
diff --git a/src/layers/common.cc b/src/layers/common.cc
index 5f70c4336..22b2a55bd 100644
--- a/src/layers/common.cc
+++ b/src/layers/common.cc
@@ -265,7 +265,8 @@ namespace ctranslate2 {
Dense::Dense(const models::Model& model,
const std::string& scope,
- const ops::ActivationType* activation_type)
+ const ops::ActivationType* activation_type,
+ const bool is_layer_out)
: _packed_weight(false)
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
@@ -294,6 +295,7 @@ namespace ctranslate2 {
/*shift_to_uint8=*/bool(_u8_shift_compensation),
/*round_before_cast=*/model.round_before_cast_in_quantization())
, _dequantize_op(activation_type)
+ , _is_layer_out(is_layer_out)
{
}
@@ -339,13 +341,50 @@ namespace ctranslate2 {
const StorageView* compensation = (_partial_u8_shift_compensation.empty()
? _u8_shift_compensation
: &_partial_u8_shift_compensation);
+
+ bool affected_by_tp = ScopedMPISetter::getNRanks() > 1 && _is_layer_out;
if (_quantized_gemm) {
const auto device = input.device();
StorageView qinput(_weight.dtype(), device);
StorageView qinput_scale(_qscale->dtype(), device);
StorageView qoutput(DataType::INT32, device);
- _quantize_op(input, qinput, qinput_scale);
+ const StorageView* pinput = &input;
+
+ if (affected_by_tp) {
+ StorageView input_reshaped(input.shape(), input.dtype(), input.device());
+ Shape shape = input.shape();
+ dim_t batch_size = shape[0];
+ dim_t depth = shape[shape.size() - 1];
+ dim_t length = shape[shape.size() - 2];
+ StorageView input_gather_all({1, depth * ScopedMPISetter::getNRanks(), batch_size * length}, input.dtype(), input.device());
+ ops::Transpose transpose_op({0, 2, 1});
+ // Transpose input B x L x D -> B x D x L
+ if (batch_size > 1) {
+ input_reshaped.shallow_copy(const_cast(input));
+ input_reshaped.reshape({1, batch_size * length, depth});
+ pinput = &input_reshaped;
+ }
+ StorageView input_t(input.dtype(), input.device());
+ transpose_op(*pinput, input_t);
+ ops::GatherAll gather_ops;
+ gather_ops(input_t, input_gather_all);
+ input_t.resize({1, batch_size * length, depth * ScopedMPISetter::getNRanks()});
+ transpose_op(input_gather_all, input_t);
+ StorageView qinput_tmp(_weight.dtype(), device);
+ _quantize_op(input_t, qinput_tmp, qinput_scale);
+ dim_t index = _weight.dim(-1) * ScopedMPISetter::getCurRank();
+ dim_t size = _weight.dim(-1);
+ ops::Slide(-1, index, size)(qinput_tmp, qinput);
+ if (batch_size > 1)
+ qinput.reshape({batch_size, length, depth});
+ }
+ else {
+ _quantize_op(input, qinput, qinput_scale);
+ }
+
_gemm_op(qinput, *weight, qoutput, compensation);
+ if (affected_by_tp && ScopedMPISetter::getCurRank() == 0)
+ bias = nullptr;
_dequantize_op(qoutput,
qinput_scale,
*qscale,
@@ -354,6 +393,8 @@ namespace ctranslate2 {
output,
bias);
} else {
+ if (affected_by_tp && ScopedMPISetter::getCurRank() == 0)
+ bias = nullptr;
_gemm_op(input, *weight, output, nullptr, bias);
}
}
diff --git a/src/layers/transformer.cc b/src/layers/transformer.cc
index 056a01f99..97b5669c1 100644
--- a/src/layers/transformer.cc
+++ b/src/layers/transformer.cc
@@ -14,7 +14,8 @@ namespace ctranslate2 {
, _activation_type(activation_type)
, _ff1(model, scope + "/linear_0", &_activation_type)
, _ff1_noact(build_optional_layer(model, scope + "/linear_0_noact"))
- , _ff2(model, scope + "/linear_1") {
+ , _ff2(model, scope + "/linear_1", nullptr, true)
+ , _tensor_parallel(model.tensor_parallel()) {
}
void FeedForwardNetwork::operator()(const StorageView& input, StorageView& output) const {
@@ -29,7 +30,6 @@ namespace ctranslate2 {
StorageView inner(dtype, device);
_ff1(*x, inner);
-
if (_ff1_noact) {
StorageView linear(dtype, device);
(*_ff1_noact)(*x, linear);
@@ -38,6 +38,14 @@ namespace ctranslate2 {
_ff2(inner, output);
+ if (_tensor_parallel) {
+ Shape shape = output.shape();
+ StorageView tmp(std::move(shape), output.dtype(), output.device());
+ ops::ReduceAll red_op(ops::ReduceAll::RED_OP::SUM);
+ red_op(output, tmp);
+ output = std::move(tmp);
+ }
+
if (_layer_norm) {
ops::Add()(input, output, output);
@@ -250,6 +258,7 @@ namespace ctranslate2 {
, _position_encoder(_layers.front()->get_self_attention().has_positional_embeddings()
? nullptr
: build_position_encoder(model, scope + "/position_encodings", _embeddings))
+ , _tensor_parallel(model.tensor_parallel())
{
}
@@ -278,8 +287,12 @@ namespace ctranslate2 {
padder->remove_padding(input);
}
+ int num_heads = _num_heads;
+ if (_tensor_parallel) {
+ num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
+ }
lengths_mask = std::make_unique(
- layers::MultiHeadAttention::prepare_length_mask(*lengths, _num_heads, max_time));
+ layers::MultiHeadAttention::prepare_length_mask(*lengths, num_heads, max_time));
}
StorageView position_bias(output.dtype(), output.device());
@@ -334,7 +347,8 @@ namespace ctranslate2 {
: build_position_encoder(model, scope + "/position_encodings", _embeddings))
, _with_encoder_attention(_layers.front()->has_cross_attention())
, _proj(model, scope + "/projection")
- , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0)) {
+ , _sliding_window(model.get_attribute_with_default(scope + "/sliding_window", 0))
+ , _tensor_parallel(model.tensor_parallel()) {
dim_t alignment_layer = (
model.get_attribute_with_default(scope + "/alignment_layer", -1));
@@ -497,13 +511,19 @@ namespace ctranslate2 {
input_padder->remove_padding(layer_in);
}
+ dim_t num_heads = _num_heads;
+ if (_tensor_parallel) {
+ num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
+ }
+
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
*lengths,
- _num_heads,
+ num_heads,
max_time,
/*mask_future=*/true,
multi_query);
+
if (step > 0)
ops::Add()(lengths_mask, StorageView(int32_t(step)), lengths_mask);
@@ -527,10 +547,14 @@ namespace ctranslate2 {
}
if (memory_lengths) {
+ dim_t num_heads = _num_heads;
+ if (_tensor_parallel) {
+ num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
+ }
const dim_t beam_size = batch_size / memory_lengths->dim(0);
memory_lengths_mask = std::make_unique(
layers::MultiHeadAttention::prepare_length_mask(*memory_lengths,
- _num_heads,
+ num_heads,
beam_size > 1 ? beam_size : max_time));
}
}
@@ -585,9 +609,13 @@ namespace ctranslate2 {
if (i > 0) {
auto max_tokens = _sliding_window + layer_in_chunk->dim(1);
StorageView tmp_lengths = StorageView(Shape{layer_in_chunk->dim(0)}, int32_t(max_tokens), device);
+ int num_heads = _num_heads;
+ if (_tensor_parallel) {
+ num_heads = SAFE_DIVIDE(num_heads, ScopedMPISetter::getNRanks());
+ }
StorageView lengths_mask = layers::MultiHeadAttention::prepare_length_mask(
tmp_lengths,
- _num_heads,
+ num_heads,
max_tokens,
/*mask_future=*/true,
multi_query);
diff --git a/src/models/model.cc b/src/models/model.cc
index 7855f9583..97bf3d1b5 100644
--- a/src/models/model.cc
+++ b/src/models/model.cc
@@ -5,6 +5,7 @@
#include "ctranslate2/models/model_factory.h"
#include "ctranslate2/ops/ops.h"
#include "ctranslate2/utils.h"
+#include <regex>
#ifdef CT2_WITH_CUDA
# include "cuda/utils.h"
@@ -17,6 +18,27 @@ namespace ctranslate2 {
static const std::string binary_file = "model.bin";
static const std::string config_file = "config.json";
+ enum class VARIABLE_TYPE {
+ ATTN_LINEAR_0_WEIGHT,
+ ATTN_LINEAR_0_WEIGHT_SCALE,
+ ATTN_LINEAR_0_BIAS,
+ ATTN_LINEAR_1_WEIGHT,
+ ATTN_LINEAR_1_WEIGHT_SCALE,
+ ATTN_LINEAR_1_BIAS,
+ ATTN_LINEAR_2_WEIGHT,
+ SELF_ATTN_LINEAR_0_WEIGHT,
+ SELF_ATTN_LINEAR_0_WEIGHT_SCALE,
+ SELF_ATTN_LINEAR_0_BIAS,
+ SELF_ATTN_LINEAR_1_WEIGHT,
+ FFN_LINEAR_0_WEIGHT,
+ FFN_LINEAR_0_BIAS,
+ FFN_LINEAR_0_WEIGHT_SCALE,
+ FFN_LINEAR_0_NOACT_WEIGHT,
+ FFN_LINEAR_0_NOACT_WEIGHT_SCALE,
+ FFN_LINEAR_0_NOACT_BIAS,
+ FFN_LINEAR_1_WEIGHT,
+ OTHERS,
+ };
static inline void report_stream_error(const std::streampos position,
const size_t read_size,
@@ -84,13 +106,13 @@ namespace ctranslate2 {
return;
// Move variables back to the CPU device.
- if (src_device != Device::CPU) {
+ if (src_device != Device::CPU && dst_device == Device::CPU) {
ScopedDeviceSetter scoped_device_setter(src_device, src_device_index);
move_variables_to_device(variables, Device::CPU);
}
// Move variables to the destination device.
- if (dst_device != Device::CPU) {
+ if (src_device == Device::CPU && dst_device != Device::CPU) {
ScopedDeviceSetter scoped_device_setter(dst_device, dst_device_index);
move_variables_to_device(variables, dst_device);
}
@@ -389,6 +411,31 @@ namespace ctranslate2 {
}
}
+ static void split_variables(StorageView variable, int dim, std::vector& partitions_size, std::vector& outputs)
+ {
+ if (variable.rank() < 1 || variable.rank() > 2)
+ throw std::runtime_error("Unsupported split variables which has the rank of matrix more than 2."
+ "Current variable has the rank " + std::to_string(variable.rank()));
+
+ //std::vector outputs(num, StorageView(variable.dtype(), variable.device()));
+
+ size_t num = partitions_size.size();
+ std::vector p_outputs(num);
+
+ for (int i = 0; i < num; ++i) {
+ p_outputs[i] = &outputs[i];
+ }
+ ops::Split(dim, partitions_size)(variable, p_outputs);
+ }
+
+ static bool replace(std::string& str, const std::string& from, const std::string& to) {
+ size_t start_pos = str.find(from);
+ if (start_pos == std::string::npos)
+ return false;
+ str.replace(start_pos, from.length(), to);
+ return true;
+ }
+
static void check_version(const size_t saved_version,
const size_t current_version,
const std::string& version_type) {
@@ -403,24 +450,114 @@ namespace ctranslate2 {
+ "(Forward compatibility is not guaranteed.)");
}
+ static VARIABLE_TYPE classify_variable(const std::string& name) {
+ std::regex pattern_self_attn("/self_attention/linear_(\\d+)/(\\w+)");
+ std::regex pattern_attn("/attention/linear_(\\d+)/(\\w+)");
+ std::regex pattern_ffn("/ffn/linear_(\\d+)(\\w*)/(\\w+)");
+
+ std::smatch match;
+
+ if (std::regex_search(name, match, pattern_self_attn)) {
+ int layer_number = std::stoi(match[1]);
+ std::string parameterName = match[2];
+
+ switch (layer_number) {
+ case 0:
+ if (parameterName == "bias")
+ return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_BIAS;
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT;
+ else
+ return VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT_SCALE;
+ case 1:
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::SELF_ATTN_LINEAR_1_WEIGHT;
+ default:
+ return VARIABLE_TYPE::OTHERS;
+ };
+ }
+ else if (std::regex_search(name, match, pattern_attn)) {
+ int layer_number = std::stoi(match[1]);
+ std::string parameterName = match[2];
+
+ switch (layer_number) {
+ case 0:
+ if (parameterName == "bias")
+ return VARIABLE_TYPE::ATTN_LINEAR_0_BIAS;
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::ATTN_LINEAR_0_WEIGHT;
+ return VARIABLE_TYPE::ATTN_LINEAR_0_WEIGHT_SCALE;
+ case 1:
+ if (parameterName == "bias")
+ return VARIABLE_TYPE::ATTN_LINEAR_1_BIAS;
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT;
+ return VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT_SCALE;
+ case 2:
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::ATTN_LINEAR_2_WEIGHT;
+ default:
+ return VARIABLE_TYPE::OTHERS;
+ };
+ }
+ else if (std::regex_search(name, match, pattern_ffn)) {
+ int layer_number = std::stoi(match[1]);
+ std::string noact = match[2];
+ std::string parameterName = match[3];
+
+ switch (layer_number) {
+ case 0:
+ if (noact == "noact" && parameterName == "bias")
+ return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_BIAS;
+ if (noact == "noact" && parameterName == "weight")
+ return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_WEIGHT;
+ if (noact == "noact")
+ return VARIABLE_TYPE::FFN_LINEAR_0_NOACT_WEIGHT_SCALE;
+ if (parameterName == "bias")
+ return VARIABLE_TYPE::FFN_LINEAR_0_BIAS;
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::FFN_LINEAR_0_WEIGHT;
+ return VARIABLE_TYPE::FFN_LINEAR_0_WEIGHT_SCALE;
+ case 1:
+ if (parameterName == "weight")
+ return VARIABLE_TYPE::FFN_LINEAR_1_WEIGHT;
+ default:
+ return VARIABLE_TYPE::OTHERS;
+ };
+ }
+
+ return VARIABLE_TYPE::OTHERS;
+ }
+
std::shared_ptr<const Model> Model::load(const std::string& path,
Device device,
int device_index,
- ComputeType compute_type) {
+ ComputeType compute_type,
+ bool tensor_parallel) {
ModelFileReader model_reader(path);
- return load(model_reader, device, device_index, compute_type);
+ return load(model_reader, device, device_index, compute_type, tensor_parallel);
}
std::shared_ptr<const Model> Model::load(ModelReader& model_reader,
Device device,
int device_index,
- ComputeType compute_type) {
+ ComputeType compute_type,
+ bool tensor_parallel) {
{
// Log the system configuration the first time a model is loaded.
static std::once_flag log_once;
std::call_once(log_once, log_system_config);
}
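+ // In tensor parallel mode, the MPI rank selects the GPU used by this process and
+ // the shard of the weights that is loaded below.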
+ int world_size = 1;
+ int current_index = 0;
+ if (tensor_parallel) {
+ ScopedMPISetter mpi_setter = ScopedMPISetter();
+ device_index = ScopedMPISetter::getLocalRank();
+ current_index = ScopedMPISetter::getCurRank();
+ world_size = ScopedMPISetter::getNRanks();
+ }
+
{
// Check that the device and device index are valid.
set_device_index(device, device_index);
@@ -448,6 +585,7 @@ namespace ctranslate2 {
auto model = create_model(spec);
model->_binary_version = binary_version;
model->_spec_revision = spec_revision;
+ model->_tensor_parallel = tensor_parallel;
check_version(spec_revision, model->current_spec_revision(), "revision");
@@ -460,6 +598,19 @@ namespace ctranslate2 {
// Load the variables.
const auto num_variables = consume<uint32_t>(model_file);
model->_variable_index.reserve(num_variables);
+
+ // Check the configuration options needed for tensor parallelism.
+ bool multi_query_attention = false;
+ if (tensor_parallel) {
+ if (model->config.contains("multi_query_attention"))
+ multi_query_attention = model->config["multi_query_attention"];
+ else
+ spdlog::warn("Running the model in tensor parallel mode without the multi_query_attention option in "
+ "config.json may lead to errors; consider reconverting the model with the latest converters.");
+ }
+
for (uint32_t i = 0; i < num_variables; ++i) {
auto name = consume<std::string>(model_file);
const size_t rank = consume<uint8_t>(model_file);
@@ -481,6 +632,89 @@ namespace ctranslate2 {
StorageView variable(std::move(shape), dtype);
consume(model_file, num_bytes, static_cast<char*>(variable.buffer()));
+ if (tensor_parallel) {
+ int outer_dim = 0;
+ int inner_dim = 1;
+ static dim_t model_dim = 0;
+ static dim_t total_dim = 0;
+
+ auto variable_type = classify_variable(name);
+ if (variable_type != VARIABLE_TYPE::OTHERS) {
+ std::vector<StorageView> outputs(world_size, StorageView(variable.dtype(), variable.device()));
+ switch (variable_type) {
+ case VARIABLE_TYPE::SELF_ATTN_LINEAR_1_WEIGHT:
+ case VARIABLE_TYPE::ATTN_LINEAR_2_WEIGHT:
+ case VARIABLE_TYPE::FFN_LINEAR_1_WEIGHT:
+ {
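+ // Output projections (attention linear_1/linear_2 and the second FFN linear) are
+ // split along their input dimension so each rank consumes only its slice of the
+ // sharded activations.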
+ dim_t output_per_partition_dim = SAFE_DIVIDE(variable.dim(inner_dim), world_size);
+ std::vector<int> partitions_size(world_size, output_per_partition_dim);
+ split_variables(std::move(variable), inner_dim, partitions_size, outputs);
+ break;
+ }
+ case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT:
+ case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_WEIGHT_SCALE:
+ case VARIABLE_TYPE::SELF_ATTN_LINEAR_0_BIAS:
+ {
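+ // The fused QKV parameters are split along the output dimension and the Q, K and V
+ // shards that belong to the same rank are concatenated back into one fused block.
+ // With multi-query attention the K/V projections are narrower than Q, so their
+ // partition sizes are computed separately.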
+ std::vector<int> partitions_size;
+ if (multi_query_attention) {
+ if (model_dim == 0) {
+ model_dim = variable.dim(-1);
+ total_dim = variable.dim(outer_dim);
+ }
+ dim_t q_dim = SAFE_DIVIDE(model_dim, world_size);
+ dim_t kv_dim = SAFE_DIVIDE((total_dim - model_dim), (2 * world_size));
+ partitions_size = std::vector<int>(world_size, q_dim);
+ std::vector<int> kv_part(world_size * 2, kv_dim);
+ partitions_size.insert(partitions_size.end(), kv_part.begin(), kv_part.end());
+ }
+ else {
+ dim_t dim_per_kqv_per_partition = SAFE_DIVIDE(variable.dim(outer_dim) / 3, world_size);
+ partitions_size = std::vector<int>(3 * world_size, dim_per_kqv_per_partition);
+ }
+ std::vector<StorageView> outputs_tmp(partitions_size.size(),
+ StorageView(variable.dtype(), variable.device()));
+ split_variables(std::move(variable), outer_dim, partitions_size, outputs_tmp);
+ for (int i = 0; i < world_size; i++) {
+ std::vector<const StorageView*> output_linear = {&outputs_tmp[i], &outputs_tmp[i + world_size],
+ &outputs_tmp[i + world_size * 2]};
+ StorageView tmp(variable.dtype(), variable.device());
+ ops::Concat(static_cast<dim_t>(outer_dim))(output_linear, tmp);
+ outputs[i] = std::move(tmp);
+ }
+ break;
+ }
+ case VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT:
+ case VARIABLE_TYPE::ATTN_LINEAR_1_WEIGHT_SCALE:
+ case VARIABLE_TYPE::ATTN_LINEAR_1_BIAS:
+ {
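+ // Cross-attention linear_1 holds the fused key/value projection: split it along the
+ // output dimension and regroup each rank's K and V shards into a single block.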
+ std::vector<int> partitions_size;
+ dim_t dim_per_kv_per_partition = SAFE_DIVIDE(variable.dim(outer_dim) / 2, world_size);
+ partitions_size = std::vector<int>(2 * world_size, dim_per_kv_per_partition);
+ std::vector<StorageView> outputs_tmp(partitions_size.size(),
+ StorageView(variable.dtype(), variable.device()));
+ split_variables(std::move(variable), outer_dim, partitions_size, outputs_tmp);
+ for (int i = 0; i < world_size; i++) {
+ std::vector<const StorageView*> output_linear = {&outputs_tmp[i], &outputs_tmp[i + world_size]};
+ StorageView tmp(variable.dtype(), variable.device());
+ ops::Concat(static_cast<dim_t>(outer_dim))(output_linear, tmp);
+ outputs[i] = std::move(tmp);
+ }
+ break;
+ }
+ default:
+ {
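+ // Remaining tensor parallel variables (e.g. the query projection and the first FFN
+ // linear) are simply split along their output dimension.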
+ dim_t output_per_partition_dim = SAFE_DIVIDE(variable.dim(outer_dim), world_size);
+ std::vector<int> partitions_size(world_size, output_per_partition_dim);
+ split_variables(std::move(variable), outer_dim, partitions_size, outputs);
+ }
+ };
+ if (outputs.size() > current_index && !outputs[current_index].empty())
+ variable = std::move(outputs[current_index]);
+ }
+ }
+
model->register_variable(std::move(name), std::move(variable));
}
@@ -558,16 +792,24 @@ namespace ctranslate2 {
if (device == Device::CUDA && !cuda::have_same_compute_capability(device_indices))
throw std::invalid_argument("Cannot use multiple GPUs with different Compute Capabilities "
"for the same model");
+ if (tensor_parallel && device != Device::CUDA) {
+ throw std::invalid_argument("Tensor Parallel mode can run only on cuda");
+ }
#endif
std::vector> models;
+ if (tensor_parallel && (device_indices.size() > 1)) {
+ spdlog::warn("Running model in mode tensor parallel does not support"
+ " running independently a model in each device");
+ }
+
models.reserve(device_indices.size() * num_replicas_per_device);
for (const size_t device_index : device_indices) {
std::shared_ptr model;
if (models.empty())
- model = Model::load(*model_reader, device, device_index, compute_type);
+ model = Model::load(*model_reader, device, device_index, compute_type, tensor_parallel);
else
model = models.back()->copy_to(device, device_index);
diff --git a/src/ops/nccl_ops.cc b/src/ops/nccl_ops.cc
new file mode 100644
index 000000000..756ce0332
--- /dev/null
+++ b/src/ops/nccl_ops.cc
@@ -0,0 +1,23 @@
+#include "ctranslate2/ops/nccl_ops.h"
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ ReduceAll::ReduceAll(ReduceAll::RED_OP op)
+ : _reduce_op(op) {
+ }
+
+ void ReduceAll::operator()(const StorageView& input, StorageView& output) const {
+ PROFILE("ReduceAll");
+ DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute<D, T>(input, output)));
+ }
+
+ GatherAll::GatherAll() = default;
+
+ void GatherAll::operator()(const StorageView& input, StorageView& output) const {
+ PROFILE("ReduceAll");
+ DEVICE_AND_TYPE_DISPATCH(input.device(), input.dtype(), (compute(input, output)));
+ }
+ }
+}
diff --git a/src/ops/nccl_ops_cpu.cc b/src/ops/nccl_ops_cpu.cc
new file mode 100644
index 000000000..d0f63750b
--- /dev/null
+++ b/src/ops/nccl_ops_cpu.cc
@@ -0,0 +1,23 @@
+#include "ctranslate2/ops/nccl_ops.h"
+#include "dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+ template <Device D, typename T>
+ void ReduceAll::compute(const StorageView& /*input*/, StorageView& /*output*/) const {
+ throw std::runtime_error("ReduceAll is not supported on the CPU device");
+ }
+
+ template <Device D, typename T>
+ void GatherAll::compute(const StorageView& /*input*/, StorageView& /*output*/) const {
+ throw std::runtime_error("GatherAll is not supported on the CPU device");
+ }
+ #define DECLARE_IMPL(T) \
+ template void ReduceAll::compute<Device::CPU, T>(const StorageView&, \
+ StorageView&) const; \
+ template void GatherAll::compute<Device::CPU, T>(const StorageView&, \
+ StorageView&) const;
+ DECLARE_ALL_TYPES(DECLARE_IMPL)
+ }
+}
diff --git a/src/ops/nccl_ops_gpu.cu b/src/ops/nccl_ops_gpu.cu
new file mode 100644
index 000000000..1b607ef6a
--- /dev/null
+++ b/src/ops/nccl_ops_gpu.cu
@@ -0,0 +1,93 @@
+#include "ctranslate2/ops/nccl_ops.h"
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ #include <nccl.h>
+ #include "cuda/utils.h"
+#endif
+#include "type_dispatch.h"
+
+namespace ctranslate2 {
+ namespace ops {
+
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ ncclDataType_t getNcclDataTypeFromDataType(DataType type) {
+ switch (type) {
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
+ case DataType::BFLOAT16:
+ return ncclBfloat16;
+#endif
+ case DataType::FLOAT16:
+ return ncclFloat16;
+ case DataType::FLOAT32:
+ return ncclFloat32;
+ case DataType::INT32:
+ return ncclInt32;
+ case DataType::INT8:
+ return ncclInt8;
+ default:
+ throw std::invalid_argument("The current datatype " + std::to_string(static_cast(type)) +
+ " is not supported for the mode tensor parallel ");
+ }
+ }
+
+ ncclRedOp_t redop_to_nccl_op(ReduceAll::RED_OP op) {
+ switch (op) {
+ case ReduceAll::RED_OP::SUM:
+ return ncclSum;
+ case ReduceAll::RED_OP::PROD:
+ return ncclProd;
+ case ReduceAll::RED_OP::MAX:
+ return ncclMax;
+ case ReduceAll::RED_OP::MIN:
+ return ncclMin;
+#if NCCL_VERSION_CODE >= NCCL_VERSION(2,10,0)
+ case ReduceAll::RED_OP::AVG:
+ return ncclAvg;
+#endif
+ default:
+ throw std::runtime_error("the current reduce operation " + std::to_string(static_cast(op)) + " is not supported");
+ }
+ }
+#endif
+
+ template <Device D, typename T>
+ void ReduceAll::compute(const StorageView& input, StorageView& output) const {
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ // All-reduce the local tensor across all ranks through the shared NCCL communicator.
+ dim_t data_size = input.size();
+ ncclComm_t comm = ScopedMPISetter::getNcclComm();
+ ncclDataType_t ncclDataType = getNcclDataTypeFromDataType(input.dtype());
+ ncclRedOp_t ncclOp = redop_to_nccl_op(_reduce_op);
+ NCCL_CHECK(ncclAllReduce(input.data<T>(), output.data<T>(),
+ data_size, ncclDataType, ncclOp,
+ comm, cuda::get_cuda_stream()));
+
+ cudaStreamSynchronize(cuda::get_cuda_stream());
+#endif
+ (void)input;
+ (void)output;
+ }
+
+ template <Device D, typename T>
+ void GatherAll::compute(const StorageView& input, StorageView& output) const {
+#ifdef CT2_WITH_TENSOR_PARALLEL
+ // Gather each rank's local tensor into the output through the shared NCCL communicator.
+ dim_t data_size = input.size();
+ ncclComm_t comm = ScopedMPISetter::getNcclComm();
+ ncclDataType_t ncclDataType = getNcclDataTypeFromDataType(input.dtype());
+ NCCL_CHECK(ncclAllGather(input.data<T>(), output.data<T>(),
+ data_size, ncclDataType,
+ comm, cuda::get_cuda_stream()));
+
+ cudaStreamSynchronize(cuda::get_cuda_stream());
+#endif
+ (void)input;
+ (void)output;
+ }
+#define DECLARE_IMPL(T) \
+ template void GatherAll::compute<Device::CUDA, T>(const StorageView&, \
+ StorageView&) const; \
+ template void ReduceAll::compute<Device::CUDA, T>(const StorageView&, \
+ StorageView&) const;
+ DECLARE_ALL_TYPES(DECLARE_IMPL)
+ }
+}
diff --git a/src/utils.cc b/src/utils.cc
index f0eb29509..4f8bde57c 100644
--- a/src/utils.cc
+++ b/src/utils.cc
@@ -189,5 +189,4 @@ namespace ctranslate2 {
return features;
}
-
}
diff --git a/tools/benchmark_tensor_parallel/README.md b/tools/benchmark_tensor_parallel/README.md
new file mode 100644
index 000000000..3ee92f2ca
--- /dev/null
+++ b/tools/benchmark_tensor_parallel/README.md
@@ -0,0 +1,18 @@
+## Benchmark tools
+
+This directory contains a script to benchmark the tensor parallelism mode.
+
+### Requirements
+
+* Python 3
+* Follow this [doc](../../docs/parallel.md#model-and-tensor-parallelism) to configure the environment.
+
+```bash
+python3 -m pip install -r requirements.txt
+```
+
+### Usage
+
+```text
+mpirun -np 2 -hostfile hostfile python3 benchmark.py --mode <mode> --model_path <model_path> --src <source_file> --target <target_file> --batch_size <batch_size>
+```
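+
+For example, to benchmark the parallel mode on 2 GPUs (the model directory and data files below are placeholders for your own converted model and inputs):
+
+```bash
+mpirun -np 2 -hostfile hostfile python3 benchmark.py --mode parallel --model_path ./llama-2-7b-ct2 --src prompts.txt --target output.txt --batch_size 8
+```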
\ No newline at end of file
diff --git a/tools/benchmark_tensor_parallel/benchmark.py b/tools/benchmark_tensor_parallel/benchmark.py
new file mode 100644
index 000000000..12a3e35fd
--- /dev/null
+++ b/tools/benchmark_tensor_parallel/benchmark.py
@@ -0,0 +1,172 @@
+import ctranslate2
+import argparse
+import os
+import collections
+import time
+import GPUtil
+import sentencepiece as spm
+import concurrent.futures
+
+B_INST, E_INST = "[INST]", "[/INST]"
+B_SYS, E_SYS = "<<SYS>>\n", "\n<</SYS>>\n\n"
+
+
+class BenchmarkResult(
+ collections.namedtuple(
+ "BenchmarkResult",
+ (
+ "generation_time",
+ "num_tokens",
+ "max_gpu_mem",
+ ),
+ )
+):
+ pass
+
+
+def build_prompt(sp, inputs):
+ prompt_tokens = []
+ for question in inputs:
+ input_tokens = [""] + sp.encode_as_pieces(
+ f"{B_INST} {question.strip()} {E_INST}"
+ )
+ prompt_tokens.append(input_tokens)
+ return prompt_tokens
+
+
+def count_tokens(generated_token):
+ count = 0
+ for output in generated_token:
+ count += len(output)
+ return count
+
+
+def avg_tokens(generated_token):
+ return count_tokens(generated_token) / len(generated_token)
+
+
+def process_prompt(generator, max_generation_length, generated_token, prompt):
+ step_results = generator.generate_tokens(
+ prompt,
+ max_length=max_generation_length,
+ sampling_temperature=0.6,
+ sampling_topk=20,
+ sampling_topp=1,
+ )
+ for step_result in step_results:
+ batch_id = step_result.batch_id
+ generated_token[batch_id].append(step_result.token)
+
+
+def benchmark_generation(generator,
+ sp,
+ prompt_tokens,
+ generated_file,
+ mode,
+ batch_size):
+ max_generation_length = 512
+ generated_token = [[] for _ in range(len(prompt_tokens))]
+ generated_text = ["" for _ in range(len(prompt_tokens))]
+ tokens_buffer = []
+ elapsed_time = None
+ num_tokens = 0
+
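+ # "sequence" mode runs the prompt batches one after another on the generator,
+ # while "parallel" mode submits them concurrently through a thread pool.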
+ if mode == "sequence":
+ start_all = time.time()
+ for i in range(0, len(prompt_tokens), batch_size):
+ step_results = generator.generate_tokens(
+ prompt_tokens[i:i + batch_size],
+ max_length=max_generation_length,
+ sampling_temperature=0.6,
+ sampling_topk=20,
+ sampling_topp=1,
+ )
+ for step_result in step_results:
+ batch_id = step_result.batch_id
+ generated_token[batch_id].append(step_result.token)
+ end_all = time.time()
+ elapsed_time = end_all - start_all
+ num_tokens = count_tokens(generated_token)
+ elif mode == "parallel":
+ nb_process = len(prompt_tokens) // batch_size + 1
+ start_all = time.time()
+ with concurrent.futures.ThreadPoolExecutor(max_workers=nb_process) as executor:
+ futures = [executor.submit(process_prompt, generator, max_generation_length, generated_token,
+ prompt_tokens[index:index + batch_size])
+ for index in range(0, len(prompt_tokens), batch_size)]
+ num_tokens = count_tokens(generated_token)
+ end_all = time.time()
+ elapsed_time = end_all - start_all
+
+ memory_gpus = float(GPUtil.getGPUs()[0].memoryUsed)
+
+ # save answer to file
+ for index in range(0, len(generated_token)):
+ for token in generated_token[index]:
+ is_new_word = token.startswith("▁")
+ if is_new_word and tokens_buffer:
+ word = sp.decode(tokens_buffer)
+ if word:
+ if generated_text[index]:
+ word = ' ' + word
+ generated_text[index] += word
+ tokens_buffer = []
+ tokens_buffer.append(token)
+ if tokens_buffer:
+ word = sp.decode(tokens_buffer)
+ if generated_text[index]:
+ word = ' ' + word
+ generated_text[index] += word
+ tokens_buffer = []
+
+ # write result to target file
+ target_file = os.path.abspath(generated_file)
+ if ctranslate2.MpiInfo.getCurRank() == 0:
+ with open(target_file, 'w') as file:
+ for index in range(len(generated_text)):
+ file.write(f"answer{index}: ")
+ file.write(generated_text[index])
+ file.write(f"\n\n")
+
+ return BenchmarkResult(elapsed_time, num_tokens, memory_gpus)
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter
+ )
+ parser.add_argument(
+ "--mode",
+ choices=["sequence", "parallel"],
+ default="sequence",
+ help="benchmark in parallel or sequence mode",
+ )
+ parser.add_argument("--model_path", type=str, help="model path")
+ parser.add_argument("--src", type=str, help="source file")
+ parser.add_argument("--target", type=str, help="target file")
+ parser.add_argument("--batch_size", type=int, help="batch size")
+ args = parser.parse_args()
+
+ print("Loading the model...")
+ generator = ctranslate2.Generator(args.model_path, device="cuda", tensor_parallel=True, inter_threads=2)
+ sp = spm.SentencePieceProcessor(os.path.join(args.model_path, "tokenizer.model"))
+
+ if not os.path.exists(args.src):
+ raise Exception("No source file found: " + args.src)
+ # Open the file in read mode
+ with open(args.src, 'r') as file:
+ # Read all lines from the file and create a list
+ inputs = file.readlines()
+
+ prompt_tokens = build_prompt(sp, inputs)
+ result = benchmark_generation(generator, sp, prompt_tokens, args.target, args.mode, args.batch_size)
+ if ctranslate2.MpiInfo.getCurRank() == 0:
+ print("Benchmark result (%d sample(s)):" % len(prompt_tokens))
+ print("- Generation time: %.2f s" % result.generation_time)
+ print("- Number of tokens: %d" % result.num_tokens)
+ print("- Throughput: %.1f" % (result.num_tokens / result.generation_time))
+ print("- max. GPU memory usage: %dMB" % int(result.max_gpu_mem))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tools/benchmark_tensor_parallel/requirements.txt b/tools/benchmark_tensor_parallel/requirements.txt
new file mode 100644
index 000000000..533257d10
--- /dev/null
+++ b/tools/benchmark_tensor_parallel/requirements.txt
@@ -0,0 +1,3 @@
+ctranslate2>=4.1.0
+sentencepiece
+GPUtil
\ No newline at end of file