@@ -43,17 +43,18 @@ import org.platanios.tensorflow.api.types.{DataType, INT32}
   *
   * @author Emmanouil Antonios Platanios
   */
-class AttentionWrapperCell[S, SS] private[attention] (
+class AttentionWrapperCell[S, SS, AS, ASS] private[attention] (
     val cell: RNNCell[Output, Shape, S, SS],
-    val attentions: Seq[Attention],
+    val attentions: Seq[Attention[AS, ASS]], // TODO: Allow for varying supported types in the sequence.
     val attentionLayerWeights: Seq[Output] = null,
     val cellInputFn: (Output, Output) => Output = (input, attention) => Basic.concatenate(Seq(input, attention), -1),
     val outputAttention: Boolean = true,
     val storeAlignmentsHistory: Boolean = false,
     val name: String = "AttentionWrapperCell"
 )(implicit
-    evS: WhileLoopVariable.Aux[S, SS]
-) extends RNNCell[Output, Shape, AttentionWrapperState[S, SS], (SS, Shape, Shape, Seq[Shape], Seq[Shape])] {
+    evS: WhileLoopVariable.Aux[S, SS],
+    evAS: WhileLoopVariable.Aux[AS, ASS]
+) extends RNNCell[Output, Shape, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]], (SS, Shape, Shape, Seq[Shape], Seq[Shape], Seq[ASS])] {
   private[this] val attentionLayersSize: Int = {
     if (attentionLayerWeights != null) {
       require(attentionLayerWeights.lengthCompare(attentions.size) == 0,
@@ -74,7 +75,7 @@ class AttentionWrapperCell[S, SS] private[attention] (
    *                         `initialCellState`.
    * @return Initial state for this attention cell wrapper.
    */
-  def initialState(initialCellState: S, dataType: DataType = null): AttentionWrapperState[S, SS] = {
+  def initialState(initialCellState: S, dataType: DataType = null): AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]] = {
     if (initialCellState == null) {
       null
     } else {
@@ -101,17 +102,18 @@ class AttentionWrapperCell[S, SS] private[attention] (
             attentions.map(_ => TensorArray.create(0, inferredDataType, dynamicSize = true))
           else
             Seq.empty
-        })
+        },
+        attentionState = attentions.map(_.initialState))
       }
     }
   }
 
   override def outputShape: Shape = if (outputAttention) Shape(attentionLayersSize) else cell.outputShape
 
-  override def stateShape: (SS, Shape, Shape, Seq[Shape], Seq[Shape]) = {
+  override def stateShape: (SS, Shape, Shape, Seq[Shape], Seq[Shape], Seq[ASS]) = {
     (cell.stateShape, Shape(1), Shape(attentionLayersSize),
       attentions.map(a => Output.constantValueAsShape(a.alignmentSize.expandDims(0)).getOrElse(Shape.unknown())),
-      attentions.map(_ => Shape.scalar()))
+      attentions.map(_ => Shape.scalar()), attentions.map(_.stateSize))
   }
 
   /** Performs a step using this attention-wrapped RNN cell.
@@ -129,7 +131,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
     * @return Next tuple.
     */
   override def forward(
-      input: Tuple[Output, AttentionWrapperState[S, SS]]): Tuple[Output, AttentionWrapperState[S, SS]] = {
+      input: Tuple[Output, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]]]
+  ): Tuple[Output, AttentionWrapperState[S, SS, Seq[AS], Seq[ASS]]] = {
     // Step 1: Calculate the true inputs to the cell based on the previous attention value.
     val cellInput = cellInputFn(input.output, input.state.attention)
     val nextTuple = cell.forward(Tuple(cellInput, input.state.cellState))
@@ -142,9 +145,9 @@ class AttentionWrapperCell[S, SS] private[attention] (
       Basic.identity(output, "CheckedCellOutput")
     }
     val weights = if (attentionLayerWeights != null) attentionLayerWeights else attentions.map(_ => null)
-    val (allAttentions, allAlignments) = (attentions, input.state.alignments, weights).zipped.map {
-      case (mechanism, previous, w) =>
-        val alignments = mechanism.alignment(checkedOutput, previous)
+    val (allAttentions, allAlignments, allStates) = (attentions, input.state.attentionState, weights).zipped.map {
+      case (mechanism, previousState, w) =>
+        val (alignments, state) = mechanism.alignment(checkedOutput, previousState)
         // Reshape from [batchSize, memoryTime] to [batchSize, 1, memoryTime]
         val expandedAlignments = alignments.expandDims(1)
         // Context is the inner product of alignments and values along the memory time dimension.
@@ -159,8 +162,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
           else
             context
         }
-        (attention, alignments)
-    }.unzip
+        (attention, alignments, state)
+    }.unzip3
     val histories = {
       if (storeAlignmentsHistory)
         input.state.alignmentsHistory.zip(allAlignments).map(p => p._1.write(input.state.time, p._2))
@@ -169,7 +172,8 @@ class AttentionWrapperCell[S, SS] private[attention] (
     }
     val one = Basic.constant(1)
     val attention = Basic.concatenate(allAttentions, one)
-    val nextState = AttentionWrapperState(nextTuple.state, input.state.time + one, attention, allAlignments, histories)
+    val nextState = AttentionWrapperState(
+      nextTuple.state, input.state.time + one, attention, allAlignments, histories, allStates)
     if (outputAttention)
       Tuple(attention, nextState)
     else
@@ -178,18 +182,19 @@ class AttentionWrapperCell[S, SS] private[attention] (
 }
 
 object AttentionWrapperCell {
-  def apply[S, SS](
+  def apply[S, SS, AS, ASS](
       cell: RNNCell[Output, Shape, S, SS],
-      attentions: Seq[Attention],
+      attentions: Seq[Attention[AS, ASS]],
       attentionLayerWeights: Seq[Output] = null,
       cellInputFn: (Output, Output) => Output = (input, attention) => Basic.concatenate(Seq(input, attention), -1),
       outputAttention: Boolean = true,
       storeAlignmentsHistory: Boolean = false,
       name: String = "AttentionWrapperCell"
   )(implicit
-      evS: WhileLoopVariable.Aux[S, SS]
-  ): AttentionWrapperCell[S, SS] = {
-    new AttentionWrapperCell[S, SS](
+      evS: WhileLoopVariable.Aux[S, SS],
+      evAS: WhileLoopVariable.Aux[AS, ASS]
+  ): AttentionWrapperCell[S, SS, AS, ASS] = {
+    new AttentionWrapperCell[S, SS, AS, ASS](
       cell, attentions, attentionLayerWeights, cellInputFn, outputAttention, storeAlignmentsHistory, name)
   }
 }
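
For reference, a minimal caller-side sketch (not part of the commit) of how the new attention-state type parameters propagate. The helper buildWrapper is hypothetical; its implicit evidence requirements simply mirror the apply signature above, and it relies only on names that appear in this diff:

// Hypothetical helper, assuming the same imports as AttentionWrapperCell are in scope.
// The attention mechanism's state types (AS, ASS) now travel alongside the cell state
// types (S, SS), so callers must supply WhileLoopVariable evidence for both pairs.
def buildWrapper[S, SS, AS, ASS](
    cell: RNNCell[Output, Shape, S, SS],
    attentions: Seq[Attention[AS, ASS]]
)(implicit
    evS: WhileLoopVariable.Aux[S, SS],
    evAS: WhileLoopVariable.Aux[AS, ASS]
): AttentionWrapperCell[S, SS, AS, ASS] = {
  // Remaining parameters (attentionLayerWeights, cellInputFn, outputAttention, name) use their defaults.
  AttentionWrapperCell(cell, attentions, storeAlignmentsHistory = true)
}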