Merged PR 25733: Fused inplace ReLU and Dropout in transformer FFN layer
* First attempt at fused inplace ReLU and Dropout in transformer FFN layer
* Adds optional output projection to SSRU.

For large FFN blocks with dropout enabled, this gives roughly a 20-25% speed improvement during training.
emjotde committed Sep 26, 2022
1 parent cfc33f5 commit 1f2929d
Showing 13 changed files with 109 additions and 36 deletions.
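
Note for readers skimming the diff below: the core of this change is a single in-place element-wise node that applies the (pre-scaled) dropout mask and the ReLU in one pass over the FFN's affine output, with the backward step reusing the stored activation. The following stand-alone C++ sketch shows that arithmetic on plain arrays; it is an illustration of the idea only, not the Marian kernel. The function names, the raw-pointer interface, and the assumption that the mask already holds 0 or 1/(1-p) entries (inverted dropout) are all chosen just for this example.

#include <cstddef>

// Forward, in place on the affine output: out[i] = ReLU(out[i] * mask[i]).
void fusedReluDropoutForward(float* out, const float* mask, std::size_t n) {
  for(std::size_t i = 0; i < n; ++i) {
    float v = out[i] * mask[i];
    out[i] = v > 0.0f ? v : 0.0f;
  }
}

// Backward, also in place: grad[i] *= ReLU'(val[i]) * mask[i],
// where val is the output saved by the forward pass.
void fusedReluDropoutBackward(float* grad, const float* val, const float* mask, std::size_t n) {
  for(std::size_t i = 0; i < n; ++i)
    grad[i] *= (val[i] > 0.0f ? 1.0f : 0.0f) * mask[i];
}
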
2 changes: 2 additions & 0 deletions src/common/config_parser.cpp
@@ -263,6 +263,8 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
8);
cli.add<bool>("--transformer-no-projection",
"Omit linear projection after multi-head attention (transformer)");
cli.add<bool>("--transformer-rnn-projection",
"Add linear projection after rnn layer (transformer)");
cli.add<bool>("--transformer-pool",
"Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
cli.add<int>("--transformer-dim-ffn",
28 changes: 22 additions & 6 deletions src/graph/expression_operators.cpp
@@ -676,13 +676,29 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
}
}

Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
auto graph = a->graph();
// @TODO: unify all these
Expr affineWithReluDropout(Expr x, Expr W, Expr bias, float dropProb) {
auto graph = x->graph();
if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu) {
// not doing any dropout in inference mode
return Expression<AffineWithReluNodeOp>(x, W, bias);
} else {
Expr output = affine(x, W, bias);
int dimModel = output->shape()[-1];
int dimTime = output->shape()[-2];
output = dropoutReluInplace(output, dropProb, {dimTime, dimModel});
return output;
}
}

if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
else
return relu(affine(a, b, bias, transA, transB, scale));
Expr dropoutReluInplace(Expr x, float dropProb, Shape shape) {
if(dropProb == 0) {
return relu(x);
} else {
auto graph = x->graph();
auto mask = graph->dropoutMask(dropProb, shape);
return Expression<DropoutReluInplaceNodeOp>(x, mask);
}
}

// @TODO: Not a great place to check this
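
At the call sites changed further down (src/layers/generic.h and src/tests/units/operator_tests.cpp) the new entry point is used as a drop-in replacement for the unfused sequence it supersedes. A rough sketch of the intended equivalence, using only operators that appear in this diff; the tensor names, shapes, and the dropout probability are illustrative:

// Fused path: at inference on GPU this lowers to AffineWithReluNodeOp,
// otherwise to affine() followed by the in-place ReLU+dropout node.
Expr fused = affineWithReluDropout(x, W, b, /*dropProb=*/0.1f);

// Unfused reference it replaces (identical up to the random dropout mask):
Expr ref     = relu(affine(x, W, b));
int dimModel = ref->shape()[-1];
int dimTime  = ref->shape()[-2];
ref          = dropout(ref, /*dropProb=*/0.1f, {dimTime, dimModel});
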
13 changes: 7 additions & 6 deletions src/graph/expression_operators.h
@@ -493,12 +493,10 @@ Expr affine(Expr a,
/**
* As above, but efficiently applies relu transformation to output. For inference only.
*/
Expr affineWithRelu(Expr a,
Expr b,
Expr bias,
bool transA = false,
bool transB = false,
float scalar = 1.f);
Expr affineWithReluDropout(Expr a,
Expr b,
Expr bias,
float dropProb = 0.f);

/**
* Computes the dot product of CSR-tensor @p A with @p B.
@@ -971,6 +969,7 @@ static inline Expr dropout(Expr x, float dropProb, Shape shape) {
return dropout(x, mask);
}


/**
 * Performs dropout with a given probability.
*/
@@ -980,6 +979,8 @@ static inline Expr dropout(Expr x, float dropProb) {
return dropout(x, dropProb, x->shape());
}

Expr dropoutReluInplace(Expr x, float dropProb, Shape shape);

/**
* Shifts the elements of an expression by a per-axis offset @p shift
* padded with @p padValue.
21 changes: 9 additions & 12 deletions src/graph/node_operators_binary.h
@@ -431,16 +431,13 @@ class AffineWithReluNodeOp : public NaryNodeOp {
public:
AffineWithReluNodeOp(Expr a,
Expr b,
Expr bias,
bool transA,
bool transB,
float scalar)
: NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
transA_(transA),
transB_(transB),
scalar_(scalar) {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
Expr bias)
: NaryNodeOp({a, b, bias}, newShape(a, b, false, false)),
transA_(false),
transB_(false),
scalar_(1.0) {
ABORT_IF(!graph()->isInference(),
"AffineWithReluNodeOp currently only supported for inference");
}

Shape newShape(Expr a, Expr b, bool transA, bool transB) {
@@ -464,8 +461,8 @@ class AffineWithReluNodeOp : public NaryNodeOp {
}

NodeOps forwardOps() override {
ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
"AffineWithReluNodeOp currently only supported for inference on GPU");
ABORT_IF(!graph()->isInference(),
"AffineWithReluNodeOp currently only supported for inference");

return {
NodeOp(Affine(val_,
41 changes: 41 additions & 0 deletions src/graph/node_operators_unary.h
@@ -858,6 +858,8 @@ class ReshapeNodeOp : public UnaryNodeOp {
}
};



// @TODO: add version with access to backward step
// This allows attaching a lambda function to any node during execution. It is a non-operation otherwise
// i.e. doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
@@ -897,6 +899,45 @@ class CallbackNodeOp : public ReshapeNodeOp {
}
};

class DropoutReluInplaceNodeOp : public ReshapeNodeOp {
private:
Expr mask_;

public:
DropoutReluInplaceNodeOp(Expr node, Expr mask)
: ReshapeNodeOp(node, node->shape()),
mask_(mask) {}

void forward() override {
using namespace marian::functional;
Element(_1 = ReLU(_1 * _2), val(), mask_->val());
}

void backward() override {
using namespace marian::functional;
Element(_1 = _1 * ReLUback(_2) * _3, grad(), val(), mask_->val());
}

const std::string type() override { return "dropoutReluInplace"; }

virtual size_t hash() override {
size_t seed = ReshapeNodeOp::hash();
util::hash_combine(seed, mask_->hash());
return seed;
}

virtual bool equal(Expr node) override {
if(!ReshapeNodeOp::equal(node))
return false;
auto cnode = std::dynamic_pointer_cast<DropoutReluInplaceNodeOp>(node);
if(!cnode)
return false;
if(mask_ != cnode->mask_)
return false;
return true;
}
};

// @TODO: review if still required as this is an ugly hack anyway.
// Memory less operator that clips gradients during backward step
// Executes this as an additional operation on the gradient.
8 changes: 6 additions & 2 deletions src/layers/generic.h
@@ -234,12 +234,16 @@ static inline Expr denseInline(Expr x,
auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());

if(actName == "relu") {
x = affineWithRelu(x, W, b); // speed optimization for inference, @TODO: handle better in future layer framework
x = affineWithReluDropout(x, W, b, dropProb); // fused operator for transformer FFN
} else {
x = affine(x, W, b);
x = activationByName(actName)(x);

int dimModel = x->shape()[-1];
int dimTime = x->shape()[-2];
x = dropout(x, dropProb, {dimTime, dimModel});
}
x = dropout(x, dropProb); // @TODO: check for inference?

return x;
}

1 change: 1 addition & 0 deletions src/models/encoder_classifier.h
@@ -116,6 +116,7 @@ class EncoderClassifier : public EncoderClassifierBase {

modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
modelFeatures_.insert("transformer-rnn-projection");
modelFeatures_.insert("transformer-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
modelFeatures_.insert("transformer-ffn-activation");
1 change: 1 addition & 0 deletions src/models/encoder_decoder.cpp
@@ -37,6 +37,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)

modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
modelFeatures_.insert("transformer-rnn-projection");
modelFeatures_.insert("transformer-dim-ffn");
modelFeatures_.insert("transformer-decoder-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
1 change: 1 addition & 0 deletions src/models/encoder_pooler.h
@@ -130,6 +130,7 @@ class EncoderPooler : public EncoderPoolerBase {

modelFeatures_.insert("transformer-heads");
modelFeatures_.insert("transformer-no-projection");
modelFeatures_.insert("transformer-rnn-projection");
modelFeatures_.insert("transformer-dim-ffn");
modelFeatures_.insert("transformer-ffn-depth");
modelFeatures_.insert("transformer-ffn-activation");
16 changes: 13 additions & 3 deletions src/models/transformer.h
@@ -170,8 +170,11 @@ class Transformer : public EncoderOrDecoderBase {
auto output = input;
for(auto op : ops) {
// dropout
if (op == 'd')
output = dropout(output, dropProb);
if (op == 'd') {
int dimModel = output->shape()[-1];
int dimTime = output->shape()[-2];
output = dropout(output, dropProb, {dimTime, dimModel});
}
// layer normalization
else if (op == 'n')
output = layerNorm(output, prefix, "_pre");
@@ -435,7 +438,7 @@

// the stack of FF layers
for(int i = 1; i < depthFfn; ++i)
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel, initFn);

auto opsPost = opt<std::string>("transformer-postprocess");
@@ -538,6 +541,13 @@
decoderState = rnn->lastCellStates()[0];
output = transposeTimeBatch(output);

if(opt<bool>("transformer-rnn-projection", false)) {
int dimModel = output->shape()[-1];
auto Wo = graph_->param(prefix + "_Wo", {dimModel, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
auto bo = graph_->param(prefix + "_bo", {1, dimModel}, inits::zeros());
output = affine(output, Wo, bo); // [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
}

auto opsPost = opt<std::string>("transformer-postprocess");
output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);

3 changes: 3 additions & 0 deletions src/tensors/gpu/element.inc
@@ -70,6 +70,9 @@ template void marian::gpu::Element<marian::functional::Assign<marian::functional
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >>(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Sgn, marian::functional::Assignee<1> >, marian::functional::Capture>, marian::functional::BinaryFunctor<marian::functional::elem::Pow, marian::functional::Capture, marian::functional::BinaryFunctor<marian::functional::elem::Clip, marian::functional::UnaryFunctor<marian::functional::elem::Floor, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::UnaryFunctor<marian::functional::elem::Abs, marian::functional::BinaryFunctor<marian::functional::elem::Div, marian::functional::Assignee<1>, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor<marian::functional::elem::Log, marian::functional::Capture> > >, marian::functional::Capture> > > >, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Cos, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::Tan, marian::functional::Assignee<2> > >, marian::Tensor, marian::Tensor);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::BinaryFunctor<marian::functional::elem::NEq, marian::functional::Assignee<2>, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::sReLU, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::UnaryFunctor<marian::functional::elem::sReLU, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::Assignee<2> > > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
template void marian::gpu::Element<marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::sReLUBack, marian::functional::Assignee<2> > >, marian::functional::Assignee<3> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase> >(marian::functional::Assign<marian::functional::Var<1>, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::BinaryFunctor<marian::functional::elem::Mult, marian::functional::Assignee<1>, marian::functional::UnaryFunctor<marian::functional::elem::sReLUBack, marian::functional::Assignee<2> > >, marian::functional::Assignee<3> > >, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>, IntrusivePtr<marian::TensorBase>);
// How to add new specializations:
// When you use a new specialization, it will cause a link error of this form (example):
// .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element<marian::functional::Assign< ... > ( ... )'
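
For readability: the last two of the three specializations added above are the compiler-generated names for the forward and backward element-wise calls of the new DropoutReluInplaceNodeOp. In the functional shorthand the node itself uses, they read as follows (a reading aid quoted from the node's forward()/backward() above, not standalone code):

Element(_1 = ReLU(_1 * _2), val(), mask_->val());                  // forward: apply mask, then ReLU, in place
Element(_1 = _1 * ReLUback(_2) * _3, grad(), val(), mask_->val()); // backward: grad *= ReLU'(val) * mask
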
8 changes: 2 additions & 6 deletions src/tests/dropout.cpp
@@ -7,7 +7,7 @@
using namespace marian;

int main(int argc, char** argv) {
auto c = New<Config>(argc, argv);
auto c = parseOptions(argc, argv, cli::mode::scoring, false);

auto type = c->get<size_t>("cpu-threads") > 0
? DeviceType::cpu
@@ -20,11 +20,7 @@ int main(int argc, char** argv) {

for(int i = 0; i < 10; ++i) {
g->clear();
auto mask1 = g->dropoutMask(0.2, {10, 3072});
auto mask2 = g->dropoutMask(0.3, {1, 3072});
auto mask = mask1 + mask2;
debug(mask1, "mask1");
debug(mask2, "mask2");
auto mask = g->dropoutMask(0.2, {1000, 16384});
debug(mask, "mask");
g->forward();
}
2 changes: 1 addition & 1 deletion src/tests/units/operator_tests.cpp
@@ -595,7 +595,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
auto aff1 = affine(A, B, bias);
auto aff2 = dot(A, B) + bias;

auto affRelu1 = affineWithRelu(A, B, bias);
auto affRelu1 = affineWithReluDropout(A, B, bias);
auto affRelu2 = relu(dot(A, B) + bias);

graph->forward();
