|
15 | 15 |
|
16 | 16 | package org.platanios.tensorflow.api.ops.seq2seq.decoders
|
17 | 17 |
|
18 |
| -import org.platanios.tensorflow.api.core.{Indexer, NewAxis, Shape} |
| 18 | +import org.platanios.tensorflow.api.core.{NewAxis, Shape} |
19 | 19 | import org.platanios.tensorflow.api.core.exception.{InvalidArgumentException, InvalidShapeException}
|
20 | 20 | import org.platanios.tensorflow.api.implicits.Implicits._
|
21 | 21 | import org.platanios.tensorflow.api.ops
|
@@ -128,11 +128,15 @@ class BeamSearchDecoder[S, SS](
|
128 | 128 | */
|
129 | 129 | override def initialize(): (Output, Output, BeamSearchDecoder.State[S, SS]) = {
|
130 | 130 | Op.createWithNameScope(s"$name/Initialize", Set(batchSize.op)) {
|
131 |
| - val finished = Basic.zeros(BOOLEAN, Basic.stack(Seq(batchSize, beamWidth))) |
| 131 | + val finished = Basic.oneHot( |
| 132 | + indices = Basic.zeros(INT32, batchSize.expandDims(0)), depth = beamWidth, |
| 133 | + onValue = false, offValue = true, dataType = BOOLEAN) |
132 | 134 | val initialState = BeamSearchDecoder.State[S, SS](
|
133 | 135 | rnnState = processedInitialCellState,
|
134 |
| - logProbabilities = Basic.zeros( |
135 |
| - evS.outputs(processedInitialCellState).head.dataType, Basic.stack(Seq(batchSize, beamWidth))), |
| 136 | + logProbabilities = Basic.oneHot( |
| 137 | + indices = Basic.zeros(INT32, batchSize.expandDims(0)), depth = beamWidth, |
| 138 | + onValue = 0.0f, offValue = Float.NegativeInfinity, |
| 139 | + dataType = evS.outputs(processedInitialCellState).head.dataType), |
136 | 140 | finished = finished,
|
137 | 141 | sequenceLengths = Basic.zeros(INT64, Basic.stack(Seq(batchSize, beamWidth))))
|
138 | 142 | (finished, beginInput, initialState)
|
@@ -189,18 +193,10 @@ class BeamSearchDecoder[S, SS](
|
189 | 193 | val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths)
|
190 | 194 |
|
191 | 195 | // During the first time step we only consider the initial beam
|
192 |
| - val scoresShape = Basic.shape(scores) |
193 |
| - val scoresFlat = ControlFlow.cond( |
194 |
| - time > 0, |
195 |
| - () => scores.reshape(Basic.stack(Seq(batchSize, -1))), |
196 |
| - () => scores(Indexer.::, 0)) |
197 |
| - val numAvailableBeams = ControlFlow.cond( |
198 |
| - time > 0, |
199 |
| - () => scoresShape(1 ::).prod(), |
200 |
| - () => scoresShape(2 ::).prod()) |
| 196 | + val scoresFlat = Basic.reshape(scores, Basic.stack(Seq(batchSize, -1))) |
201 | 197 |
|
202 | 198 | // Pick the next beams according to the specified successors function
|
203 |
| - val nextBeamSize = Math.minimum(Basic.constant(beamWidth, INT32, name = "BeamWidth"), numAvailableBeams) |
| 199 | + val nextBeamSize = Basic.constant(beamWidth, INT32, name = "BeamWidth") |
204 | 200 | val (nextBeamScores, wordIndices) = NN.topK(scoresFlat, nextBeamSize)
|
205 | 201 | nextBeamScores.setShape(Shape(staticBatchSize, beamWidth))
|
206 | 202 | wordIndices.setShape(Shape(staticBatchSize, beamWidth))
|
@@ -417,7 +413,9 @@ object BeamSearchDecoder {
|
417 | 413 |
|
418 | 414 | /** Final outputs returned by the beam search after all decoding is finished.
|
419 | 415 | *
|
420 |
| - * @param predictedIDs Tensor of shape `[T, batchSize, beamWidth]` containing the final prediction IDs. |
| 416 | + * @param predictedIDs Tensor of shape `[batchSize, T, beamWidth]` (or `[T, batchSize, beamWidth]`, |
| 417 | + * if `outputTimeMajor == true`) containing the final prediction IDs. The beams are ordered |
| 418 | + * from best to worst. |
421 | 419 | * @param output State of the beam search at the end of decoding.
|
422 | 420 | */
|
423 | 421 | case class FinalOutput(predictedIDs: ops.Output, output: Output)
|
|
0 commit comments