@@ -605,10 +605,36 @@ public struct UpSampling2D<Scalar: TensorFlowFloatingPoint>: Layer {
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         let shape = input.shape
         let (batchSize, height, width, channels) = (shape[0], shape[1], shape[2], shape[3])
-        let reshapeSize = Tensor<Int32>([batchSize, height, 1, width, 1, channels])
         let scaleOnes = Tensor<Scalar>(ones: [1, 1, size, 1, size, 1])
-        let upSampling = input.reshaped(toShape: reshapeSize) * scaleOnes
-        let upSampledShape = Tensor<Int32>([batchSize, height * size, width * size, channels])
-        return upSampling.reshaped(toShape: upSampledShape)
+        let upSampling = input.reshaped(to: [batchSize, height, 1, width, 1, channels]) * scaleOnes
+        return upSampling.reshaped(to: [batchSize, height * size, width * size, channels])
+    }
+}
+
+@_fixed_layout
+public struct Flatten<Scalar: TensorFlowFloatingPoint>: Layer {
+    @differentiable
+    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+        let batchSize = input.shape[0]
+        let remaining = input.shape[1..<input.rank].contiguousSize
+        return input.reshaped(to: [batchSize, remaining])
+    }
+}
+
+@_fixed_layout
+public struct Reshape<Scalar: TensorFlowFloatingPoint>: Layer {
+    @noDerivative public let shape: Tensor<Int32>
+
+    public init(shape: Tensor<Int32>) {
+        self.shape = shape
+    }
+
+    public init(_ shape: TensorShape) {
+        self.init(shape: Tensor(shape.dimensions))
+    }
+
+    @differentiable
+    public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
+        return input.reshaped(toShape: shape)
     }
 }
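
The rewritten UpSampling2D body is a reshape-and-broadcast trick rather than a call to a dedicated upsampling op: the input is expanded to [N, H, 1, W, 1, C], multiplied against a ones tensor of shape [1, 1, size, 1, size, 1] so broadcasting copies each pixel size times along both spatial axes, then collapsed back to [N, H * size, W * size, C]. A minimal standalone sketch of the same idea (the concrete shapes and the shape:scalars: tensor initializer are illustrative, not part of this diff):

import TensorFlow

// Nearest-neighbor 2x upsampling by reshape + broadcast, mirroring the
// hunk above. A [1, 2, 2, 1] image becomes [1, 4, 4, 1].
let size = 2
let image = Tensor<Float>(shape: [1, 2, 2, 1], scalars: [1, 2, 3, 4])
// Expand to [N, H, 1, W, 1, C]; multiplying by ones of shape
// [1, 1, size, 1, size, 1] broadcast-copies each pixel `size` times
// along both spatial axes.
let scaleOnes = Tensor<Float>(ones: [1, 1, size, 1, size, 1])
let expanded = image.reshaped(to: [1, 2, 1, 2, 1, 1]) * scaleOnes
// Collapse [1, 2, size, 2, size, 1] back into [1, 2 * size, 2 * size, 1].
let upsampled = expanded.reshaped(to: [1, 2 * size, 2 * size, 1])
// Rows of the result: [1, 1, 2, 2], [1, 1, 2, 2], [3, 3, 4, 4], [3, 3, 4, 4].

Besides reading more cleanly, reshaped(to:) takes a TensorShape directly, so the two intermediate Tensor<Int32> shape values the old code built by hand are no longer needed.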
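A usage sketch for the two new layers follows. Reshape's initializers are public in the diff itself; Flatten's argumentless initializer, Context(learningPhase:), and the randomNormal tensor initializer are assumptions about the surrounding swift-apis library at this revision, not shown in this hunk:

import TensorFlow

// Assumed: an inference-mode context, per the applied(to:in:) signatures above.
let context = Context(learningPhase: .inference)

// Flatten collapses every non-batch dimension into one:
// a [2, 3, 4, 5] input becomes [2, 60].
let flatten = Flatten<Float>()  // assumes a public init(), not shown in the diff
let x = Tensor<Float>(randomNormal: [2, 3, 4, 5])
let flat = flatten.applied(to: x, in: context)
// flat.shape == [2, 60]

// Reshape stores its target shape as a @noDerivative Tensor<Int32>, so the
// shape itself is held constant under differentiation.
let reshape = Reshape<Float>([2, 5, 12])  // via init(_: TensorShape) from the diff
let y = reshape.applied(to: flat, in: context)
// y.shape == [2, 5, 12]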