Skip to content

Commit 5d136c5

Browse files
author
Yehudit Kerido
committed
sdk tests with papermill
Signed-off-by: Yehudit Kerido <[email protected]>
1 parent 56f5bd9 commit 5d136c5

File tree

1 file changed

+113
-6
lines changed

1 file changed

+113
-6
lines changed

examples/v1beta1/sdk/tune-train-from-func.ipynb

+113-6
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,124 @@
8181
},
8282
{
8383
"cell_type": "markdown",
84-
"source": [
85-
"## Create Train Script for CNN Model\n",
86-
"\n",
87-
"This is simple **Convolutional Neural Network (CNN)** model for recognizing hand-written digits using [MNIST Dataset](http://yann.lecun.com/exdb/mnist/). "
88-
],
84+
"id": "ee4a3254",
8985
"metadata": {
9086
"collapsed": false,
87+
"jupyter": {
88+
"outputs_hidden": false
89+
},
9190
"pycharm": {
9291
"name": "#%% md\n"
9392
}
94-
}
93+
},
94+
"source": [
95+
"## Create Train Script for CNN Model\n",
96+
"\n",
97+
"This is simple **Convolutional Neural Network (CNN)** model for recognizing hand-written digits using [MNIST Dataset](http://yann.lecun.com/exdb/mnist/). "
98+
]
99+
},
100+
{
101+
"cell_type": "code",
102+
"execution_count": null,
103+
"id": "fce87ff7-bd14-40de-aaec-824a80021705",
104+
"metadata": {},
105+
"outputs": [],
106+
"source": [
107+
"def train_mnist_model(parameters):\n",
108+
" import tensorflow as tf\n",
109+
" import numpy as np\n",
110+
" import logging\n",
111+
"\n",
112+
" logging.basicConfig(\n",
113+
" format=\"%(asctime)s %(levelname)-8s %(message)s\",\n",
114+
" datefmt=\"%Y-%m-%dT%H:%M:%SZ\",\n",
115+
" level=logging.INFO,\n",
116+
" )\n",
117+
" logging.info(\"--------------------------------------------------------------------------------------\")\n",
118+
" logging.info(f\"Input Parameters: {parameters}\")\n",
119+
" logging.info(\"--------------------------------------------------------------------------------------\\n\\n\")\n",
120+
"\n",
121+
"\n",
122+
" # Get HyperParameters from the input params dict.\n",
123+
" lr = float(parameters[\"lr\"])\n",
124+
" num_epoch = int(parameters[\"num_epoch\"])\n",
125+
"\n",
126+
" # Set dist parameters and strategy.\n",
127+
" is_dist = parameters[\"is_dist\"]\n",
128+
" num_workers = parameters[\"num_workers\"]\n",
129+
" batch_size_per_worker = 64\n",
130+
" batch_size_global = batch_size_per_worker * num_workers\n",
131+
" strategy = tf.distribute.MultiWorkerMirroredStrategy(\n",
132+
" communication_options=tf.distribute.experimental.CommunicationOptions(\n",
133+
" implementation=tf.distribute.experimental.CollectiveCommunication.RING\n",
134+
" )\n",
135+
" )\n",
136+
"\n",
137+
" # Callback class for logging training.\n",
138+
" # Katib parses metrics in this format: <metric-name>=<metric-value>.\n",
139+
" class CustomCallback(tf.keras.callbacks.Callback):\n",
140+
" def on_epoch_end(self, epoch, logs=None):\n",
141+
" logging.info(\n",
142+
" \"Epoch {}/{}. accuracy={:.4f} - loss={:.4f}\".format(\n",
143+
" epoch+1, num_epoch, logs[\"accuracy\"], logs[\"loss\"]\n",
144+
" )\n",
145+
" )\n",
146+
"\n",
147+
" # Prepare MNIST Dataset.\n",
148+
" def mnist_dataset(batch_size):\n",
149+
" (x_train, y_train), _ = tf.keras.datasets.mnist.load_data()\n",
150+
" x_train = x_train / np.float32(255)\n",
151+
" y_train = y_train.astype(np.int64)\n",
152+
" train_dataset = (\n",
153+
" tf.data.Dataset.from_tensor_slices((x_train, y_train))\n",
154+
" .shuffle(60000)\n",
155+
" .repeat()\n",
156+
" .batch(batch_size)\n",
157+
" )\n",
158+
" return train_dataset\n",
159+
"\n",
160+
" # Build and compile CNN Model.\n",
161+
" def build_and_compile_cnn_model():\n",
162+
" model = tf.keras.Sequential(\n",
163+
" [\n",
164+
" tf.keras.layers.InputLayer(input_shape=(28, 28)),\n",
165+
" tf.keras.layers.Reshape(target_shape=(28, 28, 1)),\n",
166+
" tf.keras.layers.Conv2D(32, 3, activation=\"relu\"),\n",
167+
" tf.keras.layers.Flatten(),\n",
168+
" tf.keras.layers.Dense(128, activation=\"relu\"),\n",
169+
" tf.keras.layers.Dense(10),\n",
170+
" ]\n",
171+
" )\n",
172+
" model.compile(\n",
173+
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
174+
" optimizer=tf.keras.optimizers.SGD(learning_rate=lr),\n",
175+
" metrics=[\"accuracy\"],\n",
176+
" )\n",
177+
" return model\n",
178+
" \n",
179+
" # Download Dataset.\n",
180+
" dataset = mnist_dataset(batch_size_global)\n",
181+
"\n",
182+
" # For dist strategy we should build model under scope().\n",
183+
" if is_dist:\n",
184+
" logging.info(\"Running Distributed Training\")\n",
185+
" logging.info(\"--------------------------------------------------------------------------------------\\n\\n\")\n",
186+
" with strategy.scope():\n",
187+
" model = build_and_compile_cnn_model()\n",
188+
" else:\n",
189+
" logging.info(\"Running Single Worker Training\")\n",
190+
" logging.info(\"--------------------------------------------------------------------------------------\\n\\n\")\n",
191+
" model = build_and_compile_cnn_model()\n",
192+
" \n",
193+
" # Start Training.\n",
194+
" model.fit(\n",
195+
" dataset,\n",
196+
" epochs=num_epoch,\n",
197+
" steps_per_epoch=70,\n",
198+
" callbacks=[CustomCallback()],\n",
199+
" verbose=0,\n",
200+
" )"
201+
]
95202
},
96203
{
97204
"cell_type": "markdown",

0 commit comments

Comments
 (0)