@@ -35,13 +35,11 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
3535 public let decay : Scalar
3636
3737 public init (
38- for _: __shared Model,
3938 learningRate: Scalar = 1e-3 ,
4039 beta1: Scalar = 0.9 ,
4140 beta2: Scalar = 0.999 ,
4241 epsilon: Scalar = 1e-8 ,
43- decay: Scalar = 0 ,
44- scalarType: Scalar . Type
42+ decay: Scalar = 0
4543 ) {
4644 precondition ( learningRate >= 0 , " Learning rate must be non-negative " )
4745 precondition ( 0 <= beta1 && beta1 <= 1 , " Beta parameter must be between 0 and 1 " )
@@ -55,6 +53,23 @@ public class Adam<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
5553 self . decay = decay
5654 }
5755
/// Creates an optimizer for the given model.
///
/// This convenience form delegates to `init(learningRate:beta1:beta2:epsilon:decay:)`.
/// The `model` and `scalarType` arguments are not used by the body (the model argument
/// is discarded via `_`); they are accepted so existing call sites keep compiling.
/// - Parameters:
///   - learningRate: The step size; defaults to `1e-3`. Must be non-negative.
///   - beta1: Decay rate for the first-moment estimate; defaults to `0.9`. Must be in `[0, 1]`.
///   - beta2: Decay rate for the second-moment estimate; defaults to `0.999`. Must be in `[0, 1]`.
///   - epsilon: Small constant added for numerical stability; defaults to `1e-8`.
///   - decay: Learning-rate decay; defaults to `0`.
///   - scalarType: Ignored; retained for source compatibility with older call sites.
public convenience init(
    for _: __shared Model,
    learningRate: Scalar = 1e-3,
    beta1: Scalar = 0.9,
    beta2: Scalar = 0.999,
    epsilon: Scalar = 1e-8,
    decay: Scalar = 0,
    scalarType: Scalar.Type
) {
    // Forward every hyperparameter to the designated initializer, which performs
    // the precondition checks.
    self.init(
        learningRate: learningRate,
        beta1: beta1,
        beta2: beta2,
        epsilon: epsilon,
        decay: decay)
}
72+
5873 private var step : Scalar = 0
5974 private var firstMoments = Model . AllDifferentiableVariables. zero
6075 private var secondMoments = Model . AllDifferentiableVariables. zero
@@ -84,12 +99,10 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
8499 public let decay : Scalar
85100
86101 public init (
87- for _: __shared Model,
88102 learningRate: Scalar = 0.001 ,
89103 rho: Scalar = 0.9 ,
90104 epsilon: Scalar = 1e-8 ,
91- decay: Scalar = 0 ,
92- scalarType: Scalar . Type
105+ decay: Scalar = 0
93106 ) {
94107 precondition ( learningRate >= 0 , " Learning rate must be non-negative " )
95108 precondition ( rho >= 0 , " Rho must be non-negative " )
@@ -101,6 +114,17 @@ public class RMSProp<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
101114 self . decay = decay
102115 }
103116
/// Creates an optimizer for the given model.
///
/// This convenience form delegates to `init(learningRate:rho:epsilon:decay:)`.
/// The `model` and `scalarType` arguments are not used by the body; they are
/// accepted so existing call sites keep compiling.
/// - Parameters:
///   - learningRate: The step size; defaults to `0.001`. Must be non-negative.
///   - rho: The gradient moving-average decay factor; defaults to `0.9`. Must be non-negative.
///   - epsilon: Small constant added for numerical stability; defaults to `1e-8`.
///   - decay: Learning-rate decay; defaults to `0`.
///   - scalarType: Ignored; retained for source compatibility with older call sites.
public convenience init(
    for _: __shared Model,
    learningRate: Scalar = 0.001,
    rho: Scalar = 0.9,
    epsilon: Scalar = 1e-8,
    decay: Scalar = 0,
    scalarType: Scalar.Type
) {
    // Forward to the designated initializer, which performs the precondition checks.
    self.init(learningRate: learningRate, rho: rho, epsilon: epsilon, decay: decay)
}
127+
104128 private var step : Scalar = 0
105129 private var alpha = Model . AllDifferentiableVariables. zero
106130
@@ -125,12 +149,10 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
125149 public let nesterov : Bool
126150
127151 public init (
128- for _: __shared Model,
129152 learningRate: Scalar = 0.01 ,
130153 momentum: Scalar = 0 ,
131154 decay: Scalar = 0 ,
132- nesterov: Bool = false ,
133- scalarType: Scalar . Type
155+ nesterov: Bool = false
134156 ) {
135157 precondition ( learningRate >= 0 , " Learning rate must be non-negative " )
136158 precondition ( momentum >= 0 , " Momentum must be non-negative " )
@@ -142,6 +164,17 @@ public class SGD<Model: Layer, Scalar: TensorFlowFloatingPoint>: Optimizer
142164 self . nesterov = nesterov
143165 }
144166
/// Creates an optimizer for the given model.
///
/// This convenience form delegates to `init(learningRate:momentum:decay:nesterov:)`.
/// The `model` and `scalarType` arguments are not used by the body; they are
/// accepted so existing call sites keep compiling.
/// - Parameters:
///   - learningRate: The step size; defaults to `0.01`. Must be non-negative.
///   - momentum: The momentum factor; defaults to `0`. Must be non-negative.
///   - decay: Learning-rate decay; defaults to `0`.
///   - nesterov: Whether to use Nesterov momentum; defaults to `false`.
///   - scalarType: Ignored; retained for source compatibility with older call sites.
public convenience init(
    for _: __shared Model,
    learningRate: Scalar = 0.01,
    momentum: Scalar = 0,
    decay: Scalar = 0,
    nesterov: Bool = false,
    scalarType: Scalar.Type
) {
    // Forward to the designated initializer, which performs the precondition checks.
    self.init(learningRate: learningRate, momentum: momentum, decay: decay, nesterov: nesterov)
}
177+
145178 private var step : Scalar = 0
146179 private var velocity = Model . AllDifferentiableVariables. zero
147180
0 commit comments