Skip to content

Commit

Permalink
Alignment train optimization (facebookresearch#2200)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: fairinternal/fairseq-py#2200

The expected alignment for p-choose is the performance bottleneck that needs to be optimized. The solution is to implement a custom operator to reduce the kernel launch overhead, and optimize the implementations of some operations.

Some key optimizations:

* Use a contiguous alpha array to avoid array concatenation. The original version created an array for each slice of alpha and concatenated them at the end.
* Implement cumprod using prod operation directly. It used log-cumsum-exp operations before.
* Implement cumprod using the CUDA CUB library, which is more efficient than the scan operation in PyTorch.

Reviewed By: cndn

Differential Revision: D30033767

fbshipit-source-id: 853c1c2d366838d6bcfa0863999f217a394e46a7
  • Loading branch information
renganxu authored and facebook-github-bot committed Oct 6, 2021
1 parent dd3bd3c commit ecea95c
Show file tree
Hide file tree
Showing 8 changed files with 695 additions and 26 deletions.
166 changes: 166 additions & 0 deletions examples/operators/alignment_train_cpu.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/

#include <torch/extension.h> // @manual=//caffe2:torch_extension

#include <algorithm>
#include <vector>

namespace {

// Fills cumprod_1mp with the exclusive cumulative product of (1 - p_choose)
// along the innermost (src) dimension of a [bsz, tgt_len, src_len] buffer.
//
// Inclusive cumprod: [x1, x1*x2, x1*x2*x3, ..., prod_{i=1}^{n} x_i]
// Exclusive cumprod: [1,  x1,    x1*x2,    ..., prod_{i=1}^{n-1} x_i]
//
// PyTorch provides cumprod but no exclusive mode, hence this helper.
template <typename T>
void exclusiveCumprod(
    const T* p_choose,
    T* cumprod_1mp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len) {
  const uint32_t total = bsz * tgt_len * src_len;
  const uint32_t num_rows = bsz * tgt_len;

  // Step 1: element-wise complement, cumprod_1mp = 1 - p_choose.
  for (uint32_t i = 0; i < total; i++) {
    cumprod_1mp[i] = 1 - p_choose[i];
  }

  // Step 2: in-place exclusive product scan over each src row. A running
  // product is carried so every element is read before it is overwritten.
  for (uint32_t row = 0; row < num_rows; row++) {
    T* row_ptr = cumprod_1mp + row * src_len;
    T running = (T)1.0;
    for (uint32_t src = 0; src < src_len; src++) {
      const T value = row_ptr[src];
      row_ptr[src] = running;
      running = running * value;
    }
  }
}

// Element-wise clamp of cumprod_1mp into [min_val, max_val], writing the
// result to cumprod_1mp_clamp. The three dimensions only determine the total
// element count, so the loop is flattened.
template <typename T>
void clamp(
    const T* cumprod_1mp,
    T* cumprod_1mp_clamp,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    T min_val,
    T max_val) {
  const uint32_t total = bsz * tgt_len * src_len;
  for (uint32_t i = 0; i < total; i++) {
    cumprod_1mp_clamp[i] =
        std::min<T>(std::max<T>(cumprod_1mp[i], min_val), max_val);
  }
}

// Computes the expected alignment alpha from the selection probabilities
// p_choose for monotonic attention:
//   a_i = p_i * cumprod(1 - p_i) * cumsum(a_{i-1} / cumprod(1 - p_i))
//
// p_choose:          [bsz, tgt_len, src_len] input, row-major
// cumprod_1mp:       [bsz, tgt_len, src_len] scratch
// cumprod_1mp_clamp: [bsz, tgt_len, src_len] scratch
// alpha:             output; NOTE(review): the original comment said
//   "bsz + 1, tgt_len, src_len", but only bsz*tgt_len*src_len entries are
//   ever indexed here — confirm the intended shape with the caller.
//
// NOTE(review): for tgt == 0, src > 0 the scan reads alpha[b][0][src] before
// writing it, so the caller appears to be expected to pass a zero-initialized
// alpha — confirm.
template <typename T>
void alignmentTrainCPUImpl(
    const T* p_choose,
    T* alpha,
    uint32_t bsz,
    uint32_t tgt_len,
    uint32_t src_len,
    float eps) {
  const uint32_t elements = bsz * tgt_len * src_len;
  // RAII scratch buffers. The original paired `new T[]` with `free()`, which
  // is undefined behavior; std::vector releases correctly on every exit path.
  std::vector<T> cumprod_1mp(elements);
  std::vector<T> cumprod_1mp_clamp(elements);

  // cumprod_1mp = exclusive cumprod(1 - p_choose) along the src dimension.
  exclusiveCumprod<T>(p_choose, cumprod_1mp.data(), bsz, tgt_len, src_len);
  // Clamp the denominator into [eps, 1] so the division below is safe.
  clamp<T>(
      cumprod_1mp.data(),
      cumprod_1mp_clamp.data(),
      bsz,
      tgt_len,
      src_len,
      (T)eps,
      (T)1.0);

  // Initialize alpha[:, 0, 0] = 1: all probability mass starts at the first
  // source position of the first target step.
  for (uint32_t b = 0; b < bsz; b++) {
    alpha[b * tgt_len * src_len] = 1.0;
  }

  for (uint32_t tgt = 0; tgt < tgt_len; tgt++) {
    for (uint32_t b = 0; b < bsz; b++) {
      uint32_t alpha_idx, inout_idx;
      T prev_scan = 0, curr_scan, out;
      for (uint32_t src = 0; src < src_len; src++) {
        // Running cumsum of alpha[previous tgt row] / cumprod_1mp_clamp.
        if (tgt == 0) {
          // First target step reads the initialization row itself:
          // alpha index is [b][0][src].
          alpha_idx = b * tgt_len * src_len + src;
        } else {
          // alpha index is [b][tgt - 1][src].
          alpha_idx = b * tgt_len * src_len + (tgt - 1) * src_len + src;
        }
        // Input/output index is [b][tgt][src].
        inout_idx = b * tgt_len * src_len + tgt * src_len + src;
        curr_scan = prev_scan + alpha[alpha_idx] / cumprod_1mp_clamp[inout_idx];

        out = curr_scan * p_choose[inout_idx] * cumprod_1mp[inout_idx];
        // Clamp to [0, 1] to guard against floating-point drift.
        alpha[inout_idx] = std::min<T>(std::max<T>(out, 0), 1.0);
        prev_scan = curr_scan;
      }
    }
  }
}

// CPU entry point: extracts the [bsz, tgt_len, src_len] dimensions from
// p_choose and dispatches the templated implementation over the supported
// floating-point dtypes (including Half and BFloat16).
//
// p_choose: input selection probabilities, 3-D.
// alpha:    output tensor for the expected alignment, written in place.
// eps:      clamp floor forwarded to the implementation.
void alignmentTrainCPU(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  const uint32_t batch_size = p_choose.size(0);
  const uint32_t target_length = p_choose.size(1);
  const uint32_t source_length = p_choose.size(2);

  AT_DISPATCH_FLOATING_TYPES_AND2(
      torch::ScalarType::Half,
      torch::ScalarType::BFloat16,
      p_choose.scalar_type(),
      "alignmentCPUImpl",
      [&]() {
        alignmentTrainCPUImpl<scalar_t>(
            p_choose.data_ptr<scalar_t>(),
            alpha.data_ptr<scalar_t>(),
            batch_size,
            target_length,
            source_length,
            eps);
      });
}

// Register the CPU operator with the Python extension module under the name
// `alignment_train_cpu`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cpu",
      &alignmentTrainCPU,
      "expected_alignment_from_p_choose (CPU)");
}

} // namespace
31 changes: 31 additions & 0 deletions examples/operators/alignment_train_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/

#include "alignment_train_cuda.h"
#include "utils.h"

namespace {

// Python-facing CUDA entry point: validates the input tensors and forwards to
// the CUDA implementation (alignmentTrainCUDAWrapper, declared in
// alignment_train_cuda.h).
void alignmentTrainCUDA(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps) {
  // NOTE(review): CHECK_INPUT comes from utils.h — presumably asserts each
  // tensor is a contiguous CUDA tensor; confirm against its definition.
  CHECK_INPUT(p_choose);
  CHECK_INPUT(alpha);

  alignmentTrainCUDAWrapper(p_choose, alpha, eps);
}

// Register the CUDA operator with the Python extension module under the name
// `alignment_train_cuda`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "alignment_train_cuda",
      &alignmentTrainCUDA,
      "expected_alignment_from_p_choose (CUDA)");
}

} // namespace
16 changes: 16 additions & 0 deletions examples/operators/alignment_train_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <torch/extension.h> // @manual=//caffe2:torch_extension

// Launches the CUDA implementation of the expected-alignment computation.
// p_choose: input selection probabilities; alpha: output expected alignment,
// written in place; eps: numerical floor — presumably mirrors the CPU
// operator's clamp on cumprod(1 - p_choose), confirm in the .cu definition.
void alignmentTrainCUDAWrapper(
    const torch::Tensor& p_choose,
    torch::Tensor& alpha,
    float eps);
Loading

0 comments on commit ecea95c

Please sign in to comment.