|
15 | 15 |
|
16 | 16 | package org.platanios.tensorflow.api.core.client
|
17 | 17 |
|
18 |
| -import org.platanios.tensorflow.api.ops.{Output, OutputIndexedSlices, SparseOutput} |
19 |
| -import org.platanios.tensorflow.api.tensors.{SparseTensor, Tensor, TensorIndexedSlices} |
| 18 | +import org.platanios.tensorflow.api.ops.{Output, OutputIndexedSlices, OutputLike, SparseOutput} |
| 19 | +import org.platanios.tensorflow.api.tensors.{SparseTensor, Tensor, TensorIndexedSlices, TensorLike} |
20 | 20 | import org.platanios.tensorflow.api.utilities.Collections
|
21 | 21 |
|
22 | 22 | import shapeless._
|
@@ -123,6 +123,30 @@ object Fetchable {
|
123 | 123 | }
|
124 | 124 | }
|
125 | 125 |
|
| 126 | + // TODO: Make this more elegant. |
| 127 | + |
| 128 | + implicit val outputLikeFetchable: Aux[OutputLike, TensorLike] = new Fetchable[OutputLike] { |
| 129 | + override type ResultType = TensorLike |
| 130 | + |
| 131 | + override def numberOfFetches(fetchable: OutputLike): Int = fetchable match { |
| 132 | + case _: Output => 1 |
| 133 | + case _: OutputIndexedSlices => 3 |
| 134 | + case _: SparseOutput => 3 |
| 135 | + } |
| 136 | + |
| 137 | + override def fetches(fetchable: OutputLike): Seq[Output] = fetchable match { |
| 138 | + case o: Output => Seq(o) |
| 139 | + case o: OutputIndexedSlices => Seq(o.indices, o.values, o.denseShape) |
| 140 | + case o: SparseOutput => Seq(o.indices, o.values, o.denseShape) |
| 141 | + } |
| 142 | + |
| 143 | + override def segment(fetchable: OutputLike, values: Seq[Tensor]): (TensorLike, Seq[Tensor]) = fetchable match { |
| 144 | + case _: Output => (values.head, values.tail) |
| 145 | + case _: OutputIndexedSlices => (TensorIndexedSlices(values(0), values(1), values(2)), values.drop(3)) |
| 146 | + case _: SparseOutput => (SparseTensor(values(0), values(1), values(2)), values.drop(3)) |
| 147 | + } |
| 148 | + } |
| 149 | + |
126 | 150 | implicit def fetchableArray[T, R: ClassTag](implicit ev: Aux[T, R]): Aux[Array[T], Array[R]] = {
|
127 | 151 | new Fetchable[Array[T]] {
|
128 | 152 | override type ResultType = Array[R]
|
|
0 commit comments