forked from facebookresearch/fairseq
Alignment train optimization (facebookresearch#2200)
Summary: Pull Request resolved: fairinternal/fairseq-py#2200

The expected alignment for p-choose is the performance bottleneck that needs to be optimized. The solution is to implement a custom operator to reduce the kernel launch overhead and to optimize the implementations of some operations. Some key optimizations:

* Use a contiguous alpha array to avoid array concatenation. The original version creates an array for each slice of alpha and concatenates them at the end.
* Implement cumprod using the prod operation directly. It used log-cumsum-exp operations before.
* Implement cumprod using the CUDA CUB library, which is more efficient than the scan operation in PyTorch.

Reviewed By: cndn

Differential Revision: D30033767

fbshipit-source-id: 853c1c2d366838d6bcfa0863999f217a394e46a7
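The CUB-based cumprod mentioned in the summary lives in the CUDA kernel file, which is not among the files shown in this excerpt. Purely as a rough sketch of the idea: an exclusive cumulative product over one row can be expressed as a device-wide exclusive scan with a multiplication operator. The names ProdOp and exclusiveCumprodRow below are illustrative and not taken from the commit.

// Illustrative sketch only: exclusive cumprod of a single row of (1 - p_choose)
// via CUB. Compile with nvcc; ProdOp and exclusiveCumprodRow are hypothetical.
#include <cub/cub.cuh>
#include <cuda_runtime.h>

struct ProdOp {
  template <typename T>
  __device__ __forceinline__ T operator()(const T& a, const T& b) const {
    return a * b;
  }
};

// d_in, d_out: device pointers to src_len floats.
// After the scan, d_out[0] = 1 and d_out[i] = d_in[0] * ... * d_in[i - 1].
cudaError_t exclusiveCumprodRow(const float* d_in, float* d_out, int src_len) {
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  // First call only computes the required temporary storage size.
  cub::DeviceScan::ExclusiveScan(
      d_temp, temp_bytes, d_in, d_out, ProdOp(), 1.0f, src_len);
  cudaMalloc(&d_temp, temp_bytes);
  // Second call performs the actual exclusive scan with identity 1.
  cudaError_t err = cub::DeviceScan::ExclusiveScan(
      d_temp, temp_bytes, d_in, d_out, ProdOp(), 1.0f, src_len);
  cudaFree(d_temp);
  return err;
}

The CPU file in this commit implements the same exclusive cumprod with an explicit loop instead, as shown below.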
1 parent dd3bd3c, commit ecea95c
Showing 8 changed files with 695 additions and 26 deletions.
@@ -0,0 +1,166 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <torch/extension.h> // @manual=//caffe2:torch_extension
#include <algorithm> // std::min, std::max
#include <cstdint> // uint32_t

namespace {

template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  // cumprod_1mp = 1 - p_choose
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      for (uint32_t src = 0; src < src_len; src++) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        cumprod_1mp[idx] = 1 - p_choose[idx];
      }
    }
  }

  // Implementing exclusive cumprod in the innermost dimension
  // cumprod_1mp = cumprod(1 - p_choose)
  // There is cumprod in pytorch, however there is no exclusive mode.
  // cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
  // exclusive means
  // cumprod(x) = [1, x1, x1x2, x1x2x3, ..., prod_{i=1}^{n-1} x_i]
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      uint32_t idx_offset = b * tgt_len * src_len + tgt * src_len;
      T prev = cumprod_1mp[idx_offset];
      // index [b][tgt][0]
      cumprod_1mp[idx_offset] = (T)1.0;
      T curr;
      for (uint32_t src = 1; src < src_len; src++) {
        uint32_t idx = idx_offset + src;
        curr = cumprod_1mp[idx];
        cumprod_1mp[idx] = cumprod_1mp[idx - 1] * prev;
        prev = curr;
      }
    }
  }
}

template <typename T>
void clamp(
    const T* cumprod_1mp,
    T* cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  for (uint32_t b = 0; b < bsz; b++) {
    for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
      for (uint32_t src = 0; src < src_len; src++) {
        uint32_t idx = b * tgt_len * src_len + tgt * src_len + src;
        if (cumprod_1mp[idx] < min_val) {
          cumprod_1mp_clamp[idx] = min_val;
        } else if (cumprod_1mp[idx] > max_val) {
          cumprod_1mp_clamp[idx] = max_val;
        } else {
          cumprod_1mp_clamp[idx] = cumprod_1mp[idx];
        }
      }
    }
  }
}

template <typename T>
void alignmentTrainCPUImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  // p_choose: bsz, tgt_len, src_len
  // cumprod_1mp: bsz, tgt_len, src_len
  // cumprod_1mp_clamp: bsz, tgt_len, src_len
  // alpha: bsz, tgt_len, src_len

  uint32_t elements = bsz * tgt_len * src_len;
  T* cumprod_1mp = new T[elements];
  T* cumprod_1mp_clamp = new T[elements];

  exclusiveCumprod<T>(p_choose, cumprod_1mp, bsz, tgt_len, src_len);
  clamp<T>(
      cumprod_1mp, cumprod_1mp_clamp, bsz, tgt_len, src_len, (T)eps, (T)1.0);

  // alpha_i = p_i * cumprod(1 - p_i) * cumsum(alpha_{i-1} / clamp(cumprod(1 - p_i), eps, 1))

  // Initialize alpha [:, 0, 0]
  for (uint32_t b = 0; b < bsz; b++) {
    alpha[b * tgt_len * src_len] = 1.0;
  }

  for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
    for (uint32_t b = 0; b < bsz; b++) {
      uint32_t alpha_idx, inout_idx;
      T prev_scan = 0, curr_scan, out;
      for (uint32_t src = 0; src < src_len; src++) {
        // Apply scan/cumsum
        if (tgt == 0) {
          // alpha index is [b][tgt][src]
          alpha_idx = b * tgt_len * src_len + src;
        } else {
          // alpha index is [b][tgt-1][src]
          alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src;
        }
        // input index is [b][tgt][src]
        inout_idx = b * tgt_len * src_len + tgt * src_len + src;
        curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];

        out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx];
        alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0);
        prev_scan = curr_scan;
      }
    }
  }

  // Release the scratch buffers allocated with new[] above.
  delete[] cumprod_1mp;
  delete[] cumprod_1mp_clamp;
}

void alignmentTrainCPU(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  uint32_t bsz = p_choose.size(0);
  uint32_t tgt_len = p_choose.size(1);
  uint32_t src_len = p_choose.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentCPUImpl",
      [&]() {
        alignmentTrainCPUImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            bsz,
            tgt_len,
            src_len,
            eps);
      });
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cpu",
      &alignmentTrainCPU,
      "expected_alignment_from_p_choose (CPU)");
}

} // namespace
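The operator above can be exercised with a small C++ harness. The sketch below is illustrative only: it assumes it is appended to the same translation unit (alignmentTrainCPU sits in an anonymous namespace), that the tensors are contiguous float, and that alpha is pre-zeroed, which the indexing in alignmentTrainCPUImpl appears to rely on for the first target row. The function name and shapes are made up for illustration.

// Hypothetical harness (not part of the commit), appended to the same .cpp so
// that the anonymous-namespace function above is visible. torch/extension.h,
// already included above, provides the tensor API.
#include <iostream>

void runAlignmentTrainCpuExample() {
  const int64_t bsz = 2, tgt_len = 4, src_len = 6;
  // Per-step selection probabilities in (0, 1).
  torch::Tensor p_choose =
      torch::sigmoid(torch::randn({bsz, tgt_len, src_len})).contiguous();
  // alpha starts zeroed: the operator writes alpha[:, 0, 0] = 1 itself and
  // appears to rely on the rest of the first target row already being 0.
  torch::Tensor alpha = torch::zeros({bsz, tgt_len, src_len});
  alignmentTrainCPU(p_choose, alpha, /*eps=*/1e-6f);
  // alpha now holds the expected alignment for each target step.
  std::cout << alpha[0] << std::endl;
}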
@@ -0,0 +1,31 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "alignment_train_cuda.h"
#include "utils.h"

namespace {

void alignmentTrainCUDA(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  CHECK_INPUT(p_choose);
  CHECK_INPUT(alpha);

  alignmentTrainCUDAWrapper(p_choose, alpha, eps);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cuda",
      &alignmentTrainCUDA,
      "expected_alignment_from_p_choose (CUDA)");
}

} // namespace
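CHECK_INPUT comes from utils.h, which is not part of this excerpt. Purely as an assumption about the usual pattern in PyTorch C++ extensions, and not the actual contents of that header, such macros are commonly defined along these lines:

// Assumed sketch of utils.h-style checks; the real header is not shown here.
#include <torch/extension.h>

#define CHECK_CUDA(x) \
  TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
  TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
  CHECK_CUDA(x);       \
  CHECK_CONTIGUOUS(x)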
@@ -0,0 +1,16 @@
/**
 * Copyright 2017-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <torch/extension.h> // @manual=//caffe2:torch_extension

void alignmentTrainCUDAWrapper(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps);
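The .cu file that implements alignmentTrainCUDAWrapper is not included in this excerpt. The following is only a hedged sketch of how such a wrapper is often structured when compiled as a .cu file: one thread per batch element serially replays the same recurrence as the CPU code above, without the CUB-based scan the summary describes. Every name carrying the Naive suffix is hypothetical.

// Illustrative sketch, not the commit's kernel. One thread handles one batch
// element and reproduces the CPU recurrence (exclusive cumprod, clamp, cumsum).
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

template <typename T>
__global__ void alignmentTrainNaiveKernel(
    const T* p_choose,
    T* alpha,
    int bsz,
    int tgt_len,
    int src_len,
    float eps) {
  const int b = blockIdx.x * blockDim.x + threadIdx.x;
  if (b >= bsz) {
    return;
  }
  const T* p = p_choose + (size_t)b * tgt_len * src_len;
  T* a = alpha + (size_t)b * tgt_len * src_len;
  a[0] = (T)1.0; // alpha[b][0][0] = 1; rest of row 0 assumed pre-zeroed
  for (int tgt = 0; tgt < tgt_len; tgt++) {
    const T* a_prev = (tgt == 0) ? a : a + (size_t)(tgt - 1) * src_len;
    const T* p_row = p + (size_t)tgt * src_len;
    T* a_row = a + (size_t)tgt * src_len;
    T scan = 0;
    T excl = (T)1.0; // running exclusive cumprod of (1 - p) along src
    for (int src = 0; src < src_len; src++) {
      T denom = excl < (T)eps ? (T)eps : (excl > (T)1.0 ? (T)1.0 : excl);
      scan += a_prev[src] / denom;
      T out = scan * p_row[src] * excl;
      a_row[src] = out < (T)0.0 ? (T)0.0 : (out > (T)1.0 ? (T)1.0 : out);
      excl *= (T)1.0 - p_row[src];
    }
  }
}

void alignmentTrainCUDAWrapperNaive(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  const int bsz = static_cast<int>(p_choose.size(0));
  const int tgt_len = static_cast<int>(p_choose.size(1));
  const int src_len = static_cast<int>(p_choose.size(2));
  const int threads = 128;
  const int blocks = (bsz + threads - 1) / threads;

  // Half/BFloat16 are omitted here for simplicity (the CPU binding above also
  // dispatches on those types).
  AT_DISPATCH_FLOATING_TYPES(
      p_choose.scalar_type(), "alignmentTrainNaive", [&]() {
        alignmentTrainNaiveKernel<scalar_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                p_choose.data_ptr<scalar_t>(),
                alpha.data_ptr<scalar_t>(),
                bsz,
                tgt_len,
                src_len,
                eps);
      });
}

A production kernel would parallelize the inner cumsum/cumprod across source positions, for example with CUB block- or device-wide scans, which is exactly the optimization the commit summary refers to.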