@@ -196,6 +196,16 @@ public struct Dense<Scalar: TensorFlowFloatingPoint>: Layer {
     public typealias Activation = @differentiable (Tensor<Scalar>) -> Tensor<Scalar>
     @noDerivative public let activation: Activation

+    public init(
+        weight: Tensor<Scalar>,
+        bias: Tensor<Scalar>,
+        activation: @escaping Activation
+    ) {
+        self.weight = weight
+        self.bias = bias
+        self.activation = activation
+    }
+
     @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return activation(matmul(input, weight) + bias)
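Usage note (not part of the diff): the new memberwise initializer lets callers build a Dense layer directly from pre-constructed parameter tensors. A minimal sketch; the shapes and the choice of `relu` are illustrative assumptions, not taken from the commit:

    import TensorFlow

    // A 784 -> 10 fully connected layer assembled from explicit parameters.
    let dense = Dense<Float>(
        weight: Tensor(glorotUniform: [784, 10]),
        bias: Tensor(zeros: [10]),
        activation: relu)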
@@ -245,6 +255,20 @@ public struct Conv2D<Scalar: TensorFlowFloatingPoint>: Layer {
     @noDerivative public let strides: (Int32, Int32)
     @noDerivative public let padding: Padding

+    public init(
+        filter: Tensor<Scalar>,
+        bias: Tensor<Scalar>,
+        activation: @escaping Activation,
+        strides: (Int, Int),
+        padding: Padding
+    ) {
+        self.filter = filter
+        self.bias = bias
+        self.activation = activation
+        (self.strides.0, self.strides.1) = (Int32(strides.0), Int32(strides.1))
+        self.padding = padding
+    }
+
     @differentiable
     public func applied(to input: Tensor<Scalar>, in _: Context) -> Tensor<Scalar> {
         return activation(input.convolved2D(withFilter: filter,
@@ -268,7 +292,7 @@ public extension Conv2D where Scalar.RawSignificand: FixedWidthInteger {
             filter: Tensor(glorotUniform: filterTensorShape),
             bias: Tensor(zeros: TensorShape([Int32(filterShape.3)])),
             activation: activation,
-            strides: (Int32(strides.0), Int32(strides.1)),
+            strides: strides,
             padding: padding)
     }

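Usage note (not part of the diff): with this change the shape-based convenience initializer forwards its `Int` strides untouched, and the memberwise initializer performs the `Int32` conversion itself. An illustrative sketch; the filter shape, channel counts, and activation are assumptions:

    import TensorFlow

    // 3x3 filters, 16 input channels, 32 output channels (illustrative values).
    let conv = Conv2D<Float>(
        filter: Tensor(glorotUniform: [3, 3, 16, 32]),
        bias: Tensor(zeros: [32]),
        activation: relu,
        strides: (1, 1),   // plain Int values; converted to Int32 inside the initializer
        padding: .same)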
@@ -322,6 +346,25 @@ public struct BatchNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The running variance.
     @noDerivative public let runningVariance: Parameter<Scalar>

+    /// The batch dimension.
+    public init(
+        axis: Int,
+        momentum: Tensor<Scalar>,
+        offset: Tensor<Scalar>,
+        scale: Tensor<Scalar>,
+        epsilon: Tensor<Scalar>,
+        runningMean: Tensor<Scalar>,
+        runningVariance: Tensor<Scalar>
+    ) {
+        self.axis = Int32(axis)
+        self.momentum = momentum
+        self.offset = offset
+        self.scale = scale
+        self.epsilon = epsilon
+        self.runningMean = Parameter(runningMean)
+        self.runningVariance = Parameter(runningVariance)
+    }
+
     @differentiable
     private func applyingTraining(to input: Tensor<Scalar>) -> Tensor<Scalar> {
         let positiveAxis = (input.rank + axis) % input.rank
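Usage note (not part of the diff): the BatchNorm initializer takes every stored property, wrapping the running statistics in `Parameter`. A sketch for a 64-feature layer; all values below are assumptions chosen for illustration:

    import TensorFlow

    let norm = BatchNorm<Float>(
        axis: -1,                         // normalize over the last (feature) axis
        momentum: Tensor(0.99),
        offset: Tensor(zeros: [64]),
        scale: Tensor(ones: [64]),
        epsilon: Tensor(0.001),
        runningMean: Tensor(zeros: [64]),
        runningVariance: Tensor(ones: [64]))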
@@ -390,6 +433,18 @@ public struct MaxPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding

+    public init(
+        poolSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
+        padding: Padding
+    ) {
+        (self.poolSize.0, self.poolSize.1, self.poolSize.2, self.poolSize.3)
+            = (Int32(poolSize.0), Int32(poolSize.1), Int32(poolSize.2), Int32(poolSize.3))
+        (self.strides.0, self.strides.1, self.strides.2, self.strides.3)
+            = (Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3))
+        self.padding = padding
+    }
+
     public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
         self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
         self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
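Usage note (not part of the diff): the new 4-tuple initializer spells out all four pooling dimensions, while the existing 2-tuple initializer keeps the batch and channel dimensions fixed at 1; AvgPool2D gains the same initializer in the next hunk. The two layers below should be equivalent (window sizes are illustrative):

    import TensorFlow

    let explicit = MaxPool2D<Float>(poolSize: (1, 2, 2, 1), strides: (1, 2, 2, 1), padding: .valid)
    let shorthand = MaxPool2D<Float>(poolSize: (2, 2), strides: (2, 2))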
@@ -413,6 +468,18 @@ public struct AvgPool2D<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The padding algorithm for pooling.
     @noDerivative let padding: Padding

+    public init(
+        poolSize: (Int, Int, Int, Int),
+        strides: (Int, Int, Int, Int),
+        padding: Padding
+    ) {
+        (self.poolSize.0, self.poolSize.1, self.poolSize.2, self.poolSize.3)
+            = (Int32(poolSize.0), Int32(poolSize.1), Int32(poolSize.2), Int32(poolSize.3))
+        (self.strides.0, self.strides.1, self.strides.2, self.strides.3)
+            = (Int32(strides.0), Int32(strides.1), Int32(strides.2), Int32(strides.3))
+        self.padding = padding
+    }
+
     public init(poolSize: (Int, Int), strides: (Int, Int), padding: Padding = .valid) {
         self.poolSize = (1, Int32(poolSize.0), Int32(poolSize.1), 1)
         self.strides = (1, Int32(strides.0), Int32(strides.1), 1)
@@ -437,13 +504,27 @@ public struct LayerNorm<Scalar: TensorFlowFloatingPoint>: Layer {
     /// The variance epsilon value.
     @noDerivative public let epsilon: Tensor<Scalar>

+    public init(
+        offset: Tensor<Scalar>,
+        scale: Tensor<Scalar>,
+        axis: Int,
+        epsilon: Tensor<Scalar>
+    ) {
+        self.offset = offset
+        self.scale = scale
+        self.axis = Int32(axis)
+        self.epsilon = epsilon
+    }
+
     public init(featureCount: Int,
                 axis: Int,
                 epsilon: Tensor<Scalar> = Tensor(0.001)) {
-        self.scale = Tensor<Scalar>(ones: [Int32(featureCount)])
-        self.offset = Tensor<Scalar>(zeros: [Int32(featureCount)])
-        self.axis = Int32(axis)
-        self.epsilon = epsilon
+        self.init(
+            offset: Tensor(zeros: [Int32(featureCount)]),
+            scale: Tensor(ones: [Int32(featureCount)]),
+            axis: axis,
+            epsilon: epsilon
+        )
     }

     @differentiable
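Usage note (not part of the diff): `init(featureCount:axis:epsilon:)` now delegates to the memberwise initializer, so both spellings below should produce the same layer (the feature count is an illustrative assumption):

    import TensorFlow

    let byFeatureCount = LayerNorm<Float>(featureCount: 64, axis: -1)
    let memberwise = LayerNorm<Float>(
        offset: Tensor(zeros: [64]),
        scale: Tensor(ones: [64]),
        axis: -1,
        epsilon: Tensor(0.001))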