Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/fix position bias in tp #1713

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@

### Fixes and improvements

## [v4.3.0](https://github.com/OpenNMT/CTranslate2/releases/tag/v4.3.0) (2024-05-17)

### New features
* Support phi-3 (8k and 128k) (#1700 and #1680)

### Fixes and improvements
* Fix regression Flash Attention (#1695)

## [v4.2.1](https://github.com/OpenNMT/CTranslate2/releases/tag/v4.2.1) (2024-04-24)

Note: Because of the increasing of package's size (> 100 MB), the release v4.2.0 was pushed unsuccessfully.
Expand Down
2 changes: 1 addition & 1 deletion python/ctranslate2/version.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Version information."""

__version__ = "4.2.1"
__version__ = "4.3.0"
2 changes: 1 addition & 1 deletion src/devices.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ namespace ctranslate2 {
for (auto* comm : _nccl_comms) {
//finalizing NCCL
if (*comm) {
NCCL_CHECK(ncclCommAbort(*comm));
NCCL_CHECK(ncclCommFinalize(*comm));
NCCL_CHECK(ncclCommDestroy(*comm));
}
}
Expand Down
15 changes: 13 additions & 2 deletions src/layers/attention.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#include "ctranslate2/layers/attention.h"
#include "ctranslate2/ops/split.h"
#include "ctranslate2/utils.h"


#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -210,11 +212,20 @@ namespace ctranslate2 {
is_decoder,
with_cache ? key_length - 1 : 0);
}
StorageView* position_bias_per_gpu = position_bias;
StorageView position_bias_tmp(position_bias->dtype(), position_bias->device());
if (ScopedMPISetter::getCurRank() != 0) {
const dim_t num_head_per_gpu = SAFE_DIVIDE(position_bias->dim(0), ScopedMPISetter::getNRanks());
ops::Slide slide_ops(0, num_head_per_gpu * ScopedMPISetter::getCurRank(),
num_head_per_gpu, true);
slide_ops(*position_bias, position_bias_tmp);
position_bias_per_gpu = &position_bias_tmp;
}

DEVICE_AND_TYPE_DISPATCH(output.device(), output.dtype(),
primitives<D>::add_batch_broadcast(position_bias->data<T>(),
primitives<D>::add_batch_broadcast(position_bias_per_gpu->data<T>(),
output.data<T>(),
position_bias->size(),
position_bias_per_gpu->size(),
output.size()));
}

Expand Down
Loading