diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 338933d9f..c9ab45f81 100644
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -263,6 +263,8 @@ void ConfigParser::addOptionsModel(cli::CLIWrapper& cli) {
       8);
   cli.add<bool>("--transformer-no-projection",
       "Omit linear projection after multi-head attention (transformer)");
+  cli.add<bool>("--transformer-rnn-projection",
+      "Add linear projection after rnn layer (transformer)");
   cli.add<bool>("--transformer-pool",
       "Pool encoder states instead of using cross attention (selects first encoder state, best used with special token)");
   cli.add<int>("--transformer-dim-ffn",
diff --git a/src/graph/expression_operators.cpp b/src/graph/expression_operators.cpp
index ca5e68054..09049f98f 100644
--- a/src/graph/expression_operators.cpp
+++ b/src/graph/expression_operators.cpp
@@ -676,13 +676,29 @@ Expr affine(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
   }
 }
 
-Expr affineWithRelu(Expr a, Expr b, Expr bias, bool transA, bool transB, float scale) {
-  auto graph = a->graph();
+// @TODO: unify all these
+Expr affineWithReluDropout(Expr x, Expr W, Expr bias, float dropProb) {
+  auto graph = x->graph();
+  if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu) {
+    // not doing any dropout in inference mode
+    return Expression<AffineWithReluNodeOp>(x, W, bias);
+  } else {
+    Expr output = affine(x, W, bias);
+    int dimModel = output->shape()[-1];
+    int dimTime  = output->shape()[-2];
+    output = dropoutReluInplace(output, dropProb, {dimTime, dimModel});
+    return output;
+  }
+}
 
-  if(graph->isInference() && graph->getDeviceId().type == DeviceType::gpu)
-    return Expression<AffineWithReluNodeOp>(a, b, bias, transA, transB, scale);
-  else
-    return relu(affine(a, b, bias, transA, transB, scale));
+Expr dropoutReluInplace(Expr x, float dropProb, Shape shape) {
+  if(dropProb == 0) {
+    return relu(x);
+  } else {
+    auto graph = x->graph();
+    auto mask = graph->dropoutMask(dropProb, shape);
+    return Expression<DropoutReluInplaceNodeOp>(x, mask);
+  }
 }
 
 // @TODO: Not a great place to check this
diff --git a/src/graph/expression_operators.h b/src/graph/expression_operators.h
index 1e98047f9..5d9ceab36 100644
--- a/src/graph/expression_operators.h
+++ b/src/graph/expression_operators.h
@@ -493,12 +493,10 @@ Expr affine(Expr a,
 /**
  * As above, but efficiently applies relu transformation to output. For inference only.
  */
-Expr affineWithRelu(Expr a,
-                    Expr b,
-                    Expr bias,
-                    bool transA = false,
-                    bool transB = false,
-                    float scalar = 1.f);
+Expr affineWithReluDropout(Expr a,
+                           Expr b,
+                           Expr bias,
+                           float dropProb = 0.f);
 
 /**
  * Computes the dot product of CSR-tensor @p A with @p B.
@@ -971,6 +969,7 @@ static inline Expr dropout(Expr x, float dropProb, Shape shape) {
   return dropout(x, mask);
 }
 
+
 /**
  * Performs dropout with a given probably.
  */
@@ -980,6 +979,8 @@ static inline Expr dropout(Expr x, float dropProb) {
   return dropout(x, dropProb, x->shape());
 }
 
+Expr dropoutReluInplace(Expr x, float dropProb, Shape shape);
+
 /**
  * Shifts the elements of an expression by a per-axis offset @p shift
  *   padded with @p padValue.
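Taken together, affineWithReluDropout and dropoutReluInplace give the training path a fused affine + ReLU + dropout, while GPU inference keeps the existing AffineWithReluNodeOp and skips dropout. A minimal usage sketch, assuming a standard marian ExpressionGraph and transformer-style shapes; the helper name ffnBlockSketch and the parameter names ffn_W1/ffn_b1/ffn_W2/ffn_b2 are illustrative and not part of this change:

#include "marian.h"

using namespace marian;

// Sketch: a two-layer FFN block built on the new fused operator.
// Input x has the usual layout [..., dimTime, dimModel].
Expr ffnBlockSketch(Ptr<ExpressionGraph> graph, Expr x, int dimFfn, float dropProb) {
  int dimModel = x->shape()[-1];
  auto W1 = graph->param("ffn_W1", {dimModel, dimFfn}, inits::glorotUniform());
  auto b1 = graph->param("ffn_b1", {1, dimFfn},        inits::zeros());
  auto W2 = graph->param("ffn_W2", {dimFfn, dimModel}, inits::glorotUniform());
  auto b2 = graph->param("ffn_b2", {1, dimModel},      inits::zeros());

  // Training: affine -> in-place dropout+ReLU with a {dimTime, dimModel} mask.
  // GPU inference: a single AffineWithReluNodeOp, no dropout.
  auto h = affineWithReluDropout(x, W1, b1, dropProb);
  return affine(h, W2, b2);
}

With dropProb = 0 the training path reduces to relu(affine(x, W1, b1)), matching the previous affineWithRelu behaviour.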
diff --git a/src/graph/node_operators_binary.h b/src/graph/node_operators_binary.h
index 7a4824ef0..292554bd0 100644
--- a/src/graph/node_operators_binary.h
+++ b/src/graph/node_operators_binary.h
@@ -431,16 +431,13 @@ class AffineWithReluNodeOp : public NaryNodeOp {
 public:
   AffineWithReluNodeOp(Expr a,
                        Expr b,
-                       Expr bias,
-                       bool transA,
-                       bool transB,
-                       float scalar)
-      : NaryNodeOp({a, b, bias}, newShape(a, b, transA, transB)),
-        transA_(transA),
-        transB_(transB),
-        scalar_(scalar) {
-    ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
-             "AffineWithReluNodeOp currently only supported for inference on GPU");
+                       Expr bias)
+      : NaryNodeOp({a, b, bias}, newShape(a, b, false, false)),
+        transA_(false),
+        transB_(false),
+        scalar_(1.0) {
+    ABORT_IF(!graph()->isInference(),
+             "AffineWithReluNodeOp currently only supported for inference");
   }
 
   Shape newShape(Expr a, Expr b, bool transA, bool transB) {
@@ -464,8 +461,8 @@ class AffineWithReluNodeOp : public NaryNodeOp {
   }
 
   NodeOps forwardOps() override {
-    ABORT_IF(!graph()->isInference() || graph()->getDeviceId().type != DeviceType::gpu,
-             "AffineWithReluNodeOp currently only supported for inference on GPU");
+    ABORT_IF(!graph()->isInference(),
+             "AffineWithReluNodeOp currently only supported for inference");
 
     return {
       NodeOp(Affine(val_,
diff --git a/src/graph/node_operators_unary.h b/src/graph/node_operators_unary.h
index 448b4c4a4..27121fa6d 100644
--- a/src/graph/node_operators_unary.h
+++ b/src/graph/node_operators_unary.h
@@ -858,6 +858,8 @@ class ReshapeNodeOp : public UnaryNodeOp {
   }
 };
 
+
+
 // @TODO: add version with access to backward step
 // This allows to attach a lambda function to any node during the execution. It is a non-operation otherwise
 // i.e. doesn't consume any memory or take any time to execute (it's a reshape onto itself) other than the
@@ -897,6 +899,45 @@ class CallbackNodeOp : public ReshapeNodeOp {
   }
 };
 
+class DropoutReluInplaceNodeOp : public ReshapeNodeOp {
+private:
+  Expr mask_;
+
+public:
+  DropoutReluInplaceNodeOp(Expr node, Expr mask)
+      : ReshapeNodeOp(node, node->shape()),
+        mask_(mask) {}
+
+  void forward() override {
+    using namespace marian::functional;
+    Element(_1 = ReLU(_1 * _2), val(), mask_->val());
+  }
+
+  void backward() override {
+    using namespace marian::functional;
+    Element(_1 = _1 * ReLUback(_2) * _3, grad(), val(), mask_->val());
+  }
+
+  const std::string type() override { return "dropoutReluInplace"; }
+
+  virtual size_t hash() override {
+    size_t seed = ReshapeNodeOp::hash();
+    util::hash_combine(seed, mask_->hash());
+    return seed;
+  }
+
+  virtual bool equal(Expr node) override {
+    if(!ReshapeNodeOp::equal(node))
+      return false;
+    auto cnode = std::dynamic_pointer_cast<DropoutReluInplaceNodeOp>(node);
+    if(!cnode)
+      return false;
+    if(mask_ != cnode->mask_)
+      return false;
+    return true;
+  }
+};
+
 // @TODO: review if still required as this is an ugly hack anyway.
 // Memory less operator that clips gradients during backward step
 // Executes this as an additional operation on the gradient.
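Element-wise, the new node overwrites its value buffer with relu(x * mask) in the forward pass and, since it is a reshape onto its child, rescales the shared gradient buffer in place by relu'(y) * mask in the backward pass; the mask from dropoutMask already carries the usual inverted-dropout scaling, so no further normalization is applied here. A scalar sketch of the same arithmetic, for illustration only (the node itself runs marian::functional::Element over whole tensors):

#include <algorithm>

// forward: y = relu(x * m), written over the value buffer in place
inline float dropoutReluForward(float x, float m) {
  return std::max(x * m, 0.f);
}

// backward: the gradient shared with the child is rescaled in place,
// g <- g * relu'(y) * m, with relu'(y) = 1 if y > 0 and 0 otherwise
inline float dropoutReluBackward(float g, float y, float m) {
  return g * (y > 0.f ? 1.f : 0.f) * m;
}

Because the node reshapes onto its child, the only extra allocation is the mask tensor itself.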
diff --git a/src/layers/generic.h b/src/layers/generic.h
index b423befeb..df11a2337 100644
--- a/src/layers/generic.h
+++ b/src/layers/generic.h
@@ -234,12 +234,16 @@ static inline Expr denseInline(Expr x,
   auto b = graph->param(prefix + "_b" + suffix, {1, outDim}, inits::zeros());
 
   if(actName == "relu") {
-    x = affineWithRelu(x, W, b); // speed optimization for inference, @TODO: handle better in future layer framework
+    x = affineWithReluDropout(x, W, b, dropProb); // fused operator for transformer FFN
   } else {
     x = affine(x, W, b);
     x = activationByName(actName)(x);
+
+    int dimModel = x->shape()[-1];
+    int dimTime  = x->shape()[-2];
+    x = dropout(x, dropProb, {dimTime, dimModel});
   }
-  x = dropout(x, dropProb); // @TODO: check for infernce?
+
   return x;
 }
diff --git a/src/models/encoder_classifier.h b/src/models/encoder_classifier.h
index 265bdacbd..552e428f2 100644
--- a/src/models/encoder_classifier.h
+++ b/src/models/encoder_classifier.h
@@ -116,6 +116,7 @@ class EncoderClassifier : public EncoderClassifierBase {
     modelFeatures_.insert("transformer-heads");
     modelFeatures_.insert("transformer-no-projection");
+    modelFeatures_.insert("transformer-rnn-projection");
     modelFeatures_.insert("transformer-dim-ffn");
     modelFeatures_.insert("transformer-ffn-depth");
     modelFeatures_.insert("transformer-ffn-activation");
diff --git a/src/models/encoder_decoder.cpp b/src/models/encoder_decoder.cpp
index a6f4dd3dc..6a298ed0d 100644
--- a/src/models/encoder_decoder.cpp
+++ b/src/models/encoder_decoder.cpp
@@ -37,6 +37,7 @@ EncoderDecoder::EncoderDecoder(Ptr<ExpressionGraph> graph, Ptr<Options> options)
   modelFeatures_.insert("transformer-heads");
   modelFeatures_.insert("transformer-no-projection");
+  modelFeatures_.insert("transformer-rnn-projection");
   modelFeatures_.insert("transformer-dim-ffn");
   modelFeatures_.insert("transformer-decoder-dim-ffn");
   modelFeatures_.insert("transformer-ffn-depth");
diff --git a/src/models/encoder_pooler.h b/src/models/encoder_pooler.h
index 7bd17c41a..124d873c5 100644
--- a/src/models/encoder_pooler.h
+++ b/src/models/encoder_pooler.h
@@ -130,6 +130,7 @@ class EncoderPooler : public EncoderPoolerBase {
     modelFeatures_.insert("transformer-heads");
     modelFeatures_.insert("transformer-no-projection");
+    modelFeatures_.insert("transformer-rnn-projection");
     modelFeatures_.insert("transformer-dim-ffn");
     modelFeatures_.insert("transformer-ffn-depth");
     modelFeatures_.insert("transformer-ffn-activation");
diff --git a/src/models/transformer.h b/src/models/transformer.h
index d87594e0e..243d2c7fc 100644
--- a/src/models/transformer.h
+++ b/src/models/transformer.h
@@ -170,8 +170,11 @@ class Transformer : public EncoderOrDecoderBase {
     auto output = input;
     for(auto op : ops) {
       // dropout
-      if (op == 'd')
-        output = dropout(output, dropProb);
+      if (op == 'd') {
+        int dimModel = output->shape()[-1];
+        int dimTime  = output->shape()[-2];
+        output = dropout(output, dropProb, {dimTime, dimModel});
+      }
       // layer normalization
       else if (op == 'n')
         output = layerNorm(output, prefix, "_pre");
@@ -435,7 +438,7 @@ class Transformer : public EncoderOrDecoderBase {
 
     // the stack of FF layers
     for(int i = 1; i < depthFfn; ++i)
-      output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
+      output = denseInline(output, prefix, /*suffix=*/std::to_string(i), dimFfn, initFn, actName, ffnDropProb);
     output = denseInline(output, prefix, /*suffix=*/std::to_string(depthFfn), dimModel, initFn);
 
     auto opsPost = opt<std::string>("transformer-postprocess");
@@ -538,6 +541,13 @@ class Transformer : public EncoderOrDecoderBase {
     decoderState = rnn->lastCellStates()[0];
     output = transposeTimeBatch(output);
 
+    if(opt<bool>("transformer-rnn-projection", false)) {
+      int dimModel = output->shape()[-1];
+      auto Wo = graph_->param(prefix + "_Wo", {dimModel, dimModel}, inits::glorotUniform(true, true, depthScaling_ ? 1.f / sqrtf((float)depth_) : 1.f));
+      auto bo = graph_->param(prefix + "_bo", {1, dimModel}, inits::zeros());
+      output = affine(output, Wo, bo); // [-4: beam depth, -3: batch size, -2: 1, -1: vector dim]
+    }
+
     auto opsPost = opt<std::string>("transformer-postprocess");
     output = postProcess(prefix + "_ffn", opsPost, output, input, dropProb);
diff --git a/src/tensors/gpu/element.inc b/src/tensors/gpu/element.inc
index ade8b4892..edec0e1a7 100755
--- a/src/tensors/gpu/element.inc
+++ b/src/tensors/gpu/element.inc
@@ -70,6 +70,9 @@
 template void marian::gpu::Element, marian::functional::BinaryFunctor >, marian::functional::Capture>, marian::functional::BinaryFunctor, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor > >, marian::functional::Capture> > > >>(marian::functional::Assign, marian::functional::BinaryFunctor >, marian::functional::Capture>, marian::functional::BinaryFunctor, marian::functional::Capture> >, marian::functional::Capture> >, marian::functional::UnaryFunctor > >, marian::functional::Capture> > > >, IntrusivePtr);
 template void marian::gpu::Element, marian::functional::UnaryFunctor > >, marian::Tensor >(marian::functional::Assign, marian::functional::UnaryFunctor > >, marian::Tensor, marian::Tensor);
 template void marian::gpu::Element, marian::functional::UnaryFunctor > >, marian::Tensor >(marian::functional::Assign, marian::functional::UnaryFunctor > >, marian::Tensor, marian::Tensor);
+template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::BinaryFunctor, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::BinaryFunctor, marian::functional::Capture> >, marian::functional::Capture> >, IntrusivePtr, IntrusivePtr);
+template void marian::gpu::Element, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > > >, IntrusivePtr >(marian::functional::Assign, marian::functional::UnaryFunctor, marian::functional::Assignee<2> > > >, IntrusivePtr, IntrusivePtr);
+template void marian::gpu::Element, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > >, marian::functional::Assignee<3> > >, IntrusivePtr, IntrusivePtr >(marian::functional::Assign, marian::functional::BinaryFunctor, marian::functional::UnaryFunctor > >, marian::functional::Assignee<3> > >, IntrusivePtr, IntrusivePtr, IntrusivePtr);
 // How to add new specializations:
 // When you use a new specialization, it will cause a link error of this form (example):
 //   .../src/tensors/tensor_operators.h:41: undefined reference to `void marian::gpu::Element ( ... )'
diff --git a/src/tests/dropout.cpp b/src/tests/dropout.cpp
index 367029fe8..97a30b4f6 100644
--- a/src/tests/dropout.cpp
+++ b/src/tests/dropout.cpp
@@ -7,7 +7,7 @@ using namespace marian;
 
 int main(int argc, char** argv) {
-  auto c = New<Config>(argc, argv);
+  auto c = parseOptions(argc, argv, cli::mode::scoring, false);
 
   auto type = c->get<size_t>("cpu-threads") > 0
                   ? DeviceType::cpu
@@ -20,11 +20,7 @@ int main(int argc, char** argv) {
   for(int i = 0; i < 10; ++i) {
     g->clear();
-    auto mask1 = g->dropoutMask(0.2, {10, 3072});
-    auto mask2 = g->dropoutMask(0.3, {1, 3072});
-    auto mask = mask1 + mask2;
-    debug(mask1, "mask1");
-    debug(mask2, "mask2");
+    auto mask = g->dropoutMask(0.2, {1000, 16384});
     debug(mask, "mask");
     g->forward();
   }
diff --git a/src/tests/units/operator_tests.cpp b/src/tests/units/operator_tests.cpp
index f3b5fda34..236823fe4 100644
--- a/src/tests/units/operator_tests.cpp
+++ b/src/tests/units/operator_tests.cpp
@@ -595,7 +595,7 @@ void tests(DeviceType device, Type floatType = Type::float32) {
     auto aff1 = affine(A, B, bias);
     auto aff2 = dot(A, B) + bias;
 
-    auto affRelu1 = affineWithRelu(A, B, bias);
+    auto affRelu1 = affineWithReluDropout(A, B, bias);
     auto affRelu2 = relu(dot(A, B) + bias);
 
     graph->forward();
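One pattern recurring in denseInline and in the transformer pre/post-processing changes is passing an explicit {dimTime, dimModel} shape to dropout rather than the full tensor shape: the mask is drawn once per (time, model) position and broadcast over the leading beam/batch dimensions during the element-wise multiply, so the mask stays small and fewer random values are generated. A hypothetical helper capturing that pattern (the name dropoutTimeModel is not part of this change):

#include "marian.h"

using namespace marian;

// Sketch: dropout with a mask of shape {dimTime, dimModel}, broadcast over
// the leading (beam, batch) dimensions of x during the element-wise multiply.
static inline Expr dropoutTimeModel(Expr x, float dropProb) {
  if(dropProb == 0.f)
    return x;
  int dimModel = x->shape()[-1];
  int dimTime  = x->shape()[-2];
  return dropout(x, dropProb, {dimTime, dimModel});
}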