Skip to content

Commit 5ac0ccd

Browse files
committed
Updated the beam search decoder so that it can handle non-full beams better.
1 parent 65cb3dd commit 5ac0ccd

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

api/src/main/scala/org/platanios/tensorflow/api/ops/seq2seq/decoders/BeamSearchDecoder.scala

Lines changed: 13 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -15,7 +15,7 @@
1515

1616
package org.platanios.tensorflow.api.ops.seq2seq.decoders
1717

18-
import org.platanios.tensorflow.api.core.{Indexer, NewAxis, Shape}
18+
import org.platanios.tensorflow.api.core.{NewAxis, Shape}
1919
import org.platanios.tensorflow.api.core.exception.{InvalidArgumentException, InvalidShapeException}
2020
import org.platanios.tensorflow.api.implicits.Implicits._
2121
import org.platanios.tensorflow.api.ops
@@ -128,11 +128,15 @@ class BeamSearchDecoder[S, SS](
128128
*/
129129
override def initialize(): (Output, Output, BeamSearchDecoder.State[S, SS]) = {
130130
Op.createWithNameScope(s"$name/Initialize", Set(batchSize.op)) {
131-
val finished = Basic.zeros(BOOLEAN, Basic.stack(Seq(batchSize, beamWidth)))
131+
val finished = Basic.oneHot(
132+
indices = Basic.zeros(INT32, batchSize.expandDims(0)), depth = beamWidth,
133+
onValue = false, offValue = true, dataType = BOOLEAN)
132134
val initialState = BeamSearchDecoder.State[S, SS](
133135
rnnState = processedInitialCellState,
134-
logProbabilities = Basic.zeros(
135-
evS.outputs(processedInitialCellState).head.dataType, Basic.stack(Seq(batchSize, beamWidth))),
136+
logProbabilities = Basic.oneHot(
137+
indices = Basic.zeros(INT32, batchSize.expandDims(0)), depth = beamWidth,
138+
onValue = 0.0f, offValue = Float.NegativeInfinity,
139+
dataType = evS.outputs(processedInitialCellState).head.dataType),
136140
finished = finished,
137141
sequenceLengths = Basic.zeros(INT64, Basic.stack(Seq(batchSize, beamWidth))))
138142
(finished, beginInput, initialState)
@@ -189,18 +193,10 @@ class BeamSearchDecoder[S, SS](
189193
val scores = lengthPenalty(totalLogProbabilities, newPredictionLengths)
190194

191195
// During the first time step we only consider the initial beam
192-
val scoresShape = Basic.shape(scores)
193-
val scoresFlat = ControlFlow.cond(
194-
time > 0,
195-
() => scores.reshape(Basic.stack(Seq(batchSize, -1))),
196-
() => scores(Indexer.::, 0))
197-
val numAvailableBeams = ControlFlow.cond(
198-
time > 0,
199-
() => scoresShape(1 ::).prod(),
200-
() => scoresShape(2 ::).prod())
196+
val scoresFlat = Basic.reshape(scores, Basic.stack(Seq(batchSize, -1)))
201197

202198
// Pick the next beams according to the specified successors function
203-
val nextBeamSize = Math.minimum(Basic.constant(beamWidth, INT32, name = "BeamWidth"), numAvailableBeams)
199+
val nextBeamSize = Basic.constant(beamWidth, INT32, name = "BeamWidth")
204200
val (nextBeamScores, wordIndices) = NN.topK(scoresFlat, nextBeamSize)
205201
nextBeamScores.setShape(Shape(staticBatchSize, beamWidth))
206202
wordIndices.setShape(Shape(staticBatchSize, beamWidth))
@@ -417,7 +413,9 @@ object BeamSearchDecoder {
417413

418414
/** Final outputs returned by the beam search after all decoding is finished.
419415
*
420-
* @param predictedIDs Tensor of shape `[T, batchSize, beamWidth]` containing the final prediction IDs.
416+
* @param predictedIDs Tensor of shape `[batchSize, T, beamWidth]` (or `[T, batchSize, beamWidth]`,
417+
* if `outputTimeMajor == true`) containing the final prediction IDs. The beams are ordered
418+
* from best to worst.
421419
* @param output State of the beam search at the end of decoding.
422420
*/
423421
case class FinalOutput(predictedIDs: ops.Output, output: Output)

0 commit comments

Comments
 (0)