Skip to content

Commit d07f06b

Browse files
committed
Minor edit to the fetchable type trait.
1 parent 13f1089 commit d07f06b

File tree

1 file changed

+26
-2
lines changed

1 file changed

+26
-2
lines changed

api/src/main/scala/org/platanios/tensorflow/api/core/client/Fetchable.scala

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515

1616
package org.platanios.tensorflow.api.core.client
1717

18-
import org.platanios.tensorflow.api.ops.{Output, OutputIndexedSlices, SparseOutput}
19-
import org.platanios.tensorflow.api.tensors.{SparseTensor, Tensor, TensorIndexedSlices}
18+
import org.platanios.tensorflow.api.ops.{Output, OutputIndexedSlices, OutputLike, SparseOutput}
19+
import org.platanios.tensorflow.api.tensors.{SparseTensor, Tensor, TensorIndexedSlices, TensorLike}
2020
import org.platanios.tensorflow.api.utilities.Collections
2121

2222
import shapeless._
@@ -123,6 +123,30 @@ object Fetchable {
123123
}
124124
}
125125

126+
// TODO: Make this more elegant.
127+
128+
implicit val outputLikeFetchable: Aux[OutputLike, TensorLike] = new Fetchable[OutputLike] {
129+
override type ResultType = TensorLike
130+
131+
override def numberOfFetches(fetchable: OutputLike): Int = fetchable match {
132+
case _: Output => 1
133+
case _: OutputIndexedSlices => 3
134+
case _: SparseOutput => 3
135+
}
136+
137+
override def fetches(fetchable: OutputLike): Seq[Output] = fetchable match {
138+
case o: Output => Seq(o)
139+
case o: OutputIndexedSlices => Seq(o.indices, o.values, o.denseShape)
140+
case o: SparseOutput => Seq(o.indices, o.values, o.denseShape)
141+
}
142+
143+
override def segment(fetchable: OutputLike, values: Seq[Tensor]): (TensorLike, Seq[Tensor]) = fetchable match {
144+
case _: Output => (values.head, values.tail)
145+
case _: OutputIndexedSlices => (TensorIndexedSlices(values(0), values(1), values(2)), values.drop(3))
146+
case _: SparseOutput => (SparseTensor(values(0), values(1), values(2)), values.drop(3))
147+
}
148+
}
149+
126150
implicit def fetchableArray[T, R: ClassTag](implicit ev: Aux[T, R]): Aux[Array[T], Array[R]] = {
127151
new Fetchable[Array[T]] {
128152
override type ResultType = Array[R]

0 commit comments

Comments
 (0)