Commit 39c986e

Updated the attention mechanisms so that they support arbitrary attention state types.
1 parent 5ac0ccd commit 39c986e

5 files changed: +138 -74 lines changed


api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/attention/Attention.scala

Lines changed: 64 additions & 21 deletions
@@ -15,8 +15,10 @@
 
 package org.platanios.tensorflow.api.ops.rnn.attention
 
+import org.platanios.tensorflow.api.core.Shape
 import org.platanios.tensorflow.api.core.exception.InvalidShapeException
 import org.platanios.tensorflow.api.implicits.Implicits._
+import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable
 import org.platanios.tensorflow.api.ops.{Basic, Checks, Math, NN, Op, Output}
 import org.platanios.tensorflow.api.types.{DataType, INT32}
 
@@ -38,27 +40,31 @@ import scala.language.postfixOps
   *
   * @author Emmanouil Antonios Platanios
   */
-abstract class Attention(
+abstract class Attention[AS, ASS](
     protected val memory: Output,
     protected val memorySequenceLengths: Output = null,
     val checkInnerDimensionsDefined: Boolean = true,
     val scoreMaskValue: Output = Float.NegativeInfinity,
     val name: String = "Attention"
+)(implicit
+    evAS: WhileLoopVariable.Aux[AS, ASS]
 ) {
-  lazy val values: Output = Op.createWithNameScope(s"$name/Initialization") {
+  lazy val values: Output = Op.createWithNameScope(s"$name/Values") {
     Attention.maybeMaskValues(memory, memorySequenceLengths, checkInnerDimensionsDefined)
   }
 
   lazy val keys: Output = values
 
-  lazy val batchSize: Output = Op.createWithNameScope(s"$name/Initialization") {
+  lazy val batchSize: Output = Op.createWithNameScope(s"$name/BatchSize") {
     Attention.dimSize(keys, 0)
   }
 
-  lazy val alignmentSize: Output = Op.createWithNameScope(s"$name/Initialization") {
+  lazy val alignmentSize: Output = Op.createWithNameScope(s"$name/AlignmentSize") {
     Attention.dimSize(keys, 1)
   }
 
+  def stateSize: ASS
+
   lazy val dataType: DataType = keys.dataType
 
   /** Initial alignment value.
@@ -69,42 +75,79 @@ abstract class Attention(
     * The default behavior is to return a tensor of all zeros.
     */
   lazy val initialAlignment: Output = {
-    Op.createWithNameScope(s"$name/InitialAlignments", Set(batchSize.op)) {
+    Op.createWithNameScope(s"$name/InitialAlignment", Set(batchSize.op)) {
       val fullShape = Basic.stack(Seq(batchSize, alignmentSize.cast(batchSize.dataType)), axis = 0)
       Basic.zeros(dataType, fullShape)
     }
   }
 
+  /** Initial state value.
+    *
+    * This is important for attention mechanisms that use the previous alignment to calculate the alignment at the
+    * next time step (e.g., monotonic attention).
+    *
+    * The default behavior is to return the same output as `initialAlignment`.
+    */
+  def initialState: AS
+
   /** Computes an alignment tensor given the provided query and previous alignment tensor.
     *
     * The previous alignment tensor is important for attention mechanisms that use the previous alignment to calculate
     * the attention at the next time step, such as monotonic attention mechanisms.
     *
-    * @param  query             Query tensor.
-    * @param  previousAlignment Previous alignment tensor.
-    * @return Alignment tensor.
+    * TODO: Figure out how to generalize the "next state" functionality.
+    *
+    * @param  query         Query tensor.
+    * @param  previousState Previous alignment tensor.
+    * @return Tuple containing the alignment tensor and the next attention state.
     */
-  final def alignment(query: Output, previousAlignment: Output): Output = Op.createWithNameScope(name) {
-    val unmaskedScore = score(query, previousAlignment)
-    val maskedScore = Attention.maybeMaskScore(unmaskedScore, memorySequenceLengths, scoreMaskValue)
-    probability(maskedScore, previousAlignment)
-  }
+  def alignment(query: Output, previousState: AS): (Output, AS)
 
   /** Computes an alignment score for `query`.
     *
-    * @param  query             Query tensor.
-    * @param  previousAlignment Previous alignment tensor.
+    * @param  query Query tensor.
+    * @param  state Current attention mechanism state (defaults to the previous alignment tensor). The data type of
+    *               this tensor matches that of `values` and its shape is `[batchSize, alignmentSize]`, where
+    *               `alignmentSize` is the memory's maximum time.
     * @return Score tensor.
     */
-  protected def score(query: Output, previousAlignment: Output): Output
+  protected def score(query: Output, state: AS): Output
 
   /** Computes alignment probabilities for `score`.
     *
-    * @param  score             Alignment score tensor.
-    * @param  previousAlignment Previous alignment tensor.
+    * @param  score Alignment score tensor.
+    * @param  state Current attention mechanism state (defaults to the previous alignment tensor). The data type of
+    *               this tensor matches that of `values` and its shape is `[batchSize, alignmentSize]`, where
+    *               `alignmentSize` is the memory's maximum time.
     * @return Alignment probabilities tensor.
     */
-  protected def probability(score: Output, previousAlignment: Output): Output = NN.softmax(score, name = "Probability")
+  protected def probability(score: Output, state: AS): Output = NN.softmax(score, name = "Probability")
+}
+
+/** Base class for attention models that use as state the previous alignment. */
+abstract class SimpleAttention(
+    override protected val memory: Output,
+    override protected val memorySequenceLengths: Output = null,
+    override val checkInnerDimensionsDefined: Boolean = true,
+    override val scoreMaskValue: Output = Float.NegativeInfinity,
+    override val name: String = "SimpleAttention"
+) extends Attention[Output, Shape](memory, memorySequenceLengths, checkInnerDimensionsDefined, scoreMaskValue, name) {
+  override def stateSize: Shape = {
+    Output.constantValueAsShape(alignmentSize).getOrElse(Shape.unknown())
+  }
+
+  override def initialState: Output = {
+    Op.createWithNameScope(s"$name/InitialState", Set(batchSize.op)) {
+      Basic.identity(initialAlignment)
+    }
+  }
+
+  override def alignment(query: Output, previousState: Output): (Output, Output) = Op.createWithNameScope(name) {
+    val unmaskedScore = score(query, previousState)
+    val maskedScore = Attention.maybeMaskScore(unmaskedScore, memorySequenceLengths, scoreMaskValue)
+    val alignment = probability(maskedScore, previousState)
+    (alignment, alignment)
+  }
 }
 
 object Attention {
@@ -117,7 +160,7 @@ object Attention {
 
   /** Potentially masks the provided values tensor based on the provided sequence lengths. */
   @throws[InvalidShapeException]
-  private[Attention] def maybeMaskValues(
+  private[attention] def maybeMaskValues(
       values: Output, sequenceLengths: Output, checkInnerDimensionsDefined: Boolean
   ): Output = {
     if (checkInnerDimensionsDefined && !values.shape(2 ::).isFullyDefined)
@@ -152,7 +195,7 @@ object Attention {
   }
 
   /** Potentially masks the provided score tensor based on the provided sequence lengths. */
-  private[Attention] def maybeMaskScore(
+  private[attention] def maybeMaskScore(
      score: Output, sequenceLengths: Output, scoreMaskValue: Output
  ): Output = {
    if (sequenceLengths != null) {
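
With this change, a scoring-only attention mechanism no longer threads the previous alignment through `alignment` itself: it can extend `SimpleAttention` and implement only `score`, inheriting the `Output`/`Shape` state handling. The following is a minimal, hypothetical sketch, not part of this commit; the `DotProductAttention` class, its constructor, and the particular `Math.matmul`/`Basic.squeeze` score computation are illustrative assumptions about the usual shapes ([batchSize, numUnits] queries against [batchSize, memoryTime, numUnits] keys).

import org.platanios.tensorflow.api.implicits.Implicits._
import org.platanios.tensorflow.api.ops.{Basic, Math, Output}
import org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention

// Hypothetical example (not in the repository): an unscaled dot-product mechanism. Its attention
// state is simply the previous alignment tensor, as provided by the SimpleAttention base class.
class DotProductAttention(
    override protected val memory: Output,
    override protected val memorySequenceLengths: Output = null,
    override val name: String = "DotProductAttention"
) extends SimpleAttention(memory, memorySequenceLengths, checkInnerDimensionsDefined = true, name = name) {
  // `query` is assumed to have shape [batchSize, numUnits] and `keys` shape
  // [batchSize, memoryTime, numUnits]; the un-normalized score is their inner product over `numUnits`.
  override protected def score(query: Output, state: Output): Output = {
    Basic.squeeze(Math.matmul(keys, query.expandDims(-1)), Seq(2))
  }
}

Mechanisms that genuinely need a different state (e.g., monotonic attention) would instead extend `Attention[AS, ASS]` directly and provide `stateSize`, `initialState`, and an `alignment` that returns the next state alongside the alignment.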

api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/attention/AttentionWrapperCell.scala

Lines changed: 25 additions & 20 deletions
@@ -43,17 +43,18 @@ import org.platanios.tensorflow.api.types.{DataType, INT32}
   *
   * @author Emmanouil Antonios Platanios
   */
-class AttentionWrapperCell[S, SS] private[attention] (
+class AttentionWrapperCell[S, SS, AS, ASS] private[attention] (
     val cell: RNNCell[Output, Shape, S, SS],
-    val attentions: Seq[Attention],
+    val attentions: Seq[Attention[AS, ASS]], // TODO: Allow for varying supported types in the sequence.
     val attentionLayerWeights: Seq[Output] = null,
     val cellInputFn: (Output, Output) => Output = (input, attention) => Basic.concatenate(Seq(input, attention), -1),
     val outputAttention: Boolean = true,
     val storeAlignmentsHistory: Boolean = false,
     val name: String = "AttentionWrapperCell"
 )(implicit
-    evS: WhileLoopVariable.Aux[S, SS]
-) extends RNNCell[Output, Shape, AttentionWrapperState[S, SS], (SS, Shape, Shape, Seq[Shape], Seq[Shape])] {
+    evS: WhileLoopVariable.Aux[S, SS],
+    evAS: WhileLoopVariable.Aux[AS, ASS]
+) extends RNNCell[Output, Shape, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]], (SS, Shape, Shape, Seq[Shape], Seq[Shape], Seq[ASS])] {
   private[this] val attentionLayersSize: Int = {
     if (attentionLayerWeights != null) {
       require(attentionLayerWeights.lengthCompare(attentions.size) == 0,
@@ -74,7 +75,7 @@ class AttentionWrapperCell[S, SS] private[attention] (
     *         `initialCellState`.
     * @return Initial state for this attention cell wrapper.
     */
-  def initialState(initialCellState: S, dataType: DataType = null): AttentionWrapperState[S, SS] = {
+  def initialState(initialCellState: S, dataType: DataType = null): AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]] = {
     if (initialCellState == null) {
       null
     } else {
@@ -101,17 +102,18 @@ class AttentionWrapperCell[S, SS] private[attention] (
            attentions.map(_ => TensorArray.create(0, inferredDataType, dynamicSize = true))
          else
            Seq.empty
-        })
+        },
+        attentionState = attentions.map(_.initialState))
       }
     }
   }
 
   override def outputShape: Shape = if (outputAttention) Shape(attentionLayersSize) else cell.outputShape
 
-  override def stateShape: (SS, Shape, Shape, Seq[Shape], Seq[Shape]) = {
+  override def stateShape: (SS, Shape, Shape, Seq[Shape], Seq[Shape], Seq[ASS]) = {
     (cell.stateShape, Shape(1), Shape(attentionLayersSize),
         attentions.map(a => Output.constantValueAsShape(a.alignmentSize.expandDims(0)).getOrElse(Shape.unknown())),
-        attentions.map(_ => Shape.scalar()))
+        attentions.map(_ => Shape.scalar()), attentions.map(_.stateSize))
   }
 
   /** Performs a step using this attention-wrapped RNN cell.
@@ -129,7 +131,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
     * @return Next tuple.
     */
   override def forward(
-      input: Tuple[Output, AttentionWrapperState[S, SS]]): Tuple[Output, AttentionWrapperState[S, SS]] = {
+      input: Tuple[Output, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]]]
+  ): Tuple[Output, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]]] = {
     // Step 1: Calculate the true inputs to the cell based on the previous attention value.
     val cellInput = cellInputFn(input.output, input.state.attention)
     val nextTuple = cell.forward(Tuple(cellInput, input.state.cellState))
@@ -142,9 +145,9 @@ class AttentionWrapperCell[S, SS] private[attention] (
       Basic.identity(output, "CheckedCellOutput")
     }
     val weights = if (attentionLayerWeights != null) attentionLayerWeights else attentions.map(_ => null)
-    val (allAttentions, allAlignments) = (attentions, input.state.alignments, weights).zipped.map {
-      case (mechanism, previous, w) =>
-        val alignments = mechanism.alignment(checkedOutput, previous)
+    val (allAttentions, allAlignments, allStates) = (attentions, input.state.attentionState, weights).zipped.map {
+      case (mechanism, previousState, w) =>
+        val (alignments, state) = mechanism.alignment(checkedOutput, previousState)
         // Reshape from [batchSize, memoryTime] to [batchSize, 1, memoryTime]
         val expandedAlignments = alignments.expandDims(1)
         // Context is the inner product of alignments and values along the memory time dimension.
@@ -159,8 +162,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
           else
             context
         }
-        (attention, alignments)
-    }.unzip
+        (attention, alignments, state)
+    }.unzip3
     val histories = {
       if (storeAlignmentsHistory)
         input.state.alignmentsHistory.zip(allAlignments).map(p => p._1.write(input.state.time, p._2))
@@ -169,7 +172,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
     }
     val one = Basic.constant(1)
     val attention = Basic.concatenate(allAttentions, one)
-    val nextState = AttentionWrapperState(nextTuple.state, input.state.time + one, attention, allAlignments, histories)
+    val nextState = AttentionWrapperState(
+      nextTuple.state, input.state.time + one, attention, allAlignments, histories, allStates)
     if (outputAttention)
       Tuple(attention, nextState)
     else
@@ -178,18 +182,19 @@ class AttentionWrapperCell[S, SS] private[attention] (
 }
 
 object AttentionWrapperCell {
-  def apply[S, SS](
+  def apply[S, SS, AS, ASS](
       cell: RNNCell[Output, Shape, S, SS],
-      attentions: Seq[Attention],
+      attentions: Seq[Attention[AS, ASS]],
       attentionLayerWeights: Seq[Output] = null,
       cellInputFn: (Output, Output) => Output = (input, attention) => Basic.concatenate(Seq(input, attention), -1),
       outputAttention: Boolean = true,
       storeAlignmentsHistory: Boolean = false,
       name: String = "AttentionWrapperCell"
   )(implicit
-      evS: WhileLoopVariable.Aux[S, SS]
-  ): AttentionWrapperCell[S, SS] = {
-    new AttentionWrapperCell[S, SS](
+      evS: WhileLoopVariable.Aux[S, SS],
+      evAS: WhileLoopVariable.Aux[AS, ASS]
+  ): AttentionWrapperCell[S, SS, AS, ASS] = {
+    new AttentionWrapperCell[S, SS, AS, ASS](
       cell, attentions, attentionLayerWeights, cellInputFn, outputAttention, storeAlignmentsHistory, name)
   }
 }
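
Because `SimpleAttention` fixes the attention state types to `Output`/`Shape`, wrapping a cell with such mechanisms yields an `AttentionWrapperCell[S, SS, Output, Shape]`. A rough usage sketch under that assumption follows; the `wrapWithAttention` helper is illustrative and not part of this commit, and it presumes that `RNNCell` lives under `org.platanios.tensorflow.api.ops.rnn.cell` and that implicit `WhileLoopVariable` evidence for `Output`/`Shape` is available.

import org.platanios.tensorflow.api.core.Shape
import org.platanios.tensorflow.api.ops.Output
import org.platanios.tensorflow.api.ops.control_flow.WhileLoopVariable
import org.platanios.tensorflow.api.ops.rnn.attention.{Attention, AttentionWrapperCell}
import org.platanios.tensorflow.api.ops.rnn.cell.RNNCell

// Hypothetical helper: wraps any RNN cell with alignment-state attention mechanisms
// (e.g., LuongAttention or BahdanauAttention, which now extend SimpleAttention).
def wrapWithAttention[S, SS](
    cell: RNNCell[Output, Shape, S, SS],
    attentions: Seq[Attention[Output, Shape]]
)(implicit evS: WhileLoopVariable.Aux[S, SS]): AttentionWrapperCell[S, SS, Output, Shape] = {
  // The evidence for the attention state type (evAS) is resolved implicitly for Output/Shape,
  // mirroring how the cell state evidence (evS) is resolved.
  AttentionWrapperCell(cell, attentions, outputAttention = true)
}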

api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/attention/BahdanauAttention.scala

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ class BahdanauAttention(
     protected val probabilityFn: (Output) => Output = NN.softmax(_, name = "Probability"),
     override val scoreMaskValue: Output = Float.NegativeInfinity,
     override val name: String = "BahdanauAttention"
-) extends Attention(memory, memorySequenceLengths, checkInnerDimensionsDefined = true, scoreMaskValue, name) {
+) extends SimpleAttention(memory, memorySequenceLengths, checkInnerDimensionsDefined = true, scoreMaskValue, name) {
   override lazy val keys: Output = NN.linear(values, memoryWeights)
 
   @throws[InvalidArgumentException]

api/src/main/scala/org/platanios/tensorflow/api/ops/rnn/attention/LuongAttention.scala

Lines changed: 1 addition & 1 deletion
@@ -50,7 +50,7 @@ class LuongAttention(
     protected val probabilityFn: (Output) => Output = NN.softmax(_, name = "Probability"),
     override val scoreMaskValue: Output = Float.NegativeInfinity,
     override val name: String = "LuongAttention"
-) extends Attention(memory, memorySequenceLengths, checkInnerDimensionsDefined = true, scoreMaskValue, name) {
+) extends SimpleAttention(memory, memorySequenceLengths, checkInnerDimensionsDefined = true, scoreMaskValue, name) {
   override lazy val keys: Output = NN.linear(values, memoryWeights)
 
   @throws[InvalidArgumentException]
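
Since both `BahdanauAttention` and `LuongAttention` now extend `SimpleAttention`, their attention state is concretely the previous alignment: an `Output` of shape `[batchSize, alignmentSize]` initialized from `initialAlignment`. A small, illustrative sketch of what a caller sees (the helper below is hypothetical, and `mechanism` stands for any configured instance of either class):

import org.platanios.tensorflow.api.ops.Output
import org.platanios.tensorflow.api.ops.rnn.attention.SimpleAttention

// For SimpleAttention subclasses, `alignment` returns the alignment twice: once as the output and
// once as the next state, so the state fed into the next step is just the current alignment.
def firstAlignment(mechanism: SimpleAttention, query: Output): Output =
  mechanism.alignment(query, mechanism.initialState)._1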
