Skip to content

Commit 8c7116e

Browse files
feat: add minst pytorch example.
Signed-off-by: Electronic-Waste <[email protected]>
1 parent 09523cd commit 8c7116e

File tree

1 file changed

+313
-0
lines changed

1 file changed

+313
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,313 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Tune and Train with Push-based Metrics Collection Using MNIST\n",
8+
"\n",
9+
"In this Notebook we are going to do the following:\n",
10+
"- Train PyTorch MNIST image classification model(CNN).\n",
11+
"- Improve the model HyperParameters with [Kubeflow Katib](https://www.kubeflow.org/docs/components/katib/overview/).\n",
12+
"- Use Push-based Metrics Collection to efficiently collect metrics in the training containers."
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"## Install Kubeflow Python SDKs\n",
20+
"\n",
21+
"You need to install Kubeflow SDKs to run this Notebook."
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": null,
27+
"metadata": {},
28+
"outputs": [],
29+
"source": [
30+
"# TODO (Electronic-Waste): Change to release version when SDK with the updated `tune()` is published.\n",
31+
"%pip install git+https://github.com/kubeflow/katib.git#subdirectory=sdk/python/v1beta1"
32+
]
33+
},
34+
{
35+
"cell_type": "markdown",
36+
"metadata": {},
37+
"source": [
38+
"## Create Train Script for CNN Model\n",
39+
"\n",
40+
"This is simple **Convolutional Neural Network (CNN)** model for recognizing hand-written digits using [MNIST Dataset](https://yann.lecun.com/exdb/mnist/)."
41+
]
42+
},
43+
{
44+
"cell_type": "code",
45+
"execution_count": 1,
46+
"metadata": {},
47+
"outputs": [],
48+
"source": [
49+
"def train_mnist_model(parameters):\n",
50+
" import torch\n",
51+
" import logging\n",
52+
" import kubeflow.katib as katib\n",
53+
" from torchvision import datasets, transforms\n",
54+
"\n",
55+
" logging.basicConfig(\n",
56+
" format=\"%(asctime)s %(levelname)-8s %(message)s\",\n",
57+
" datefmt=\"%Y-%m-%dT%H:%M:%SZ\",\n",
58+
" level=logging.INFO,\n",
59+
" )\n",
60+
" logging.info(\"--------------------------------------------------------------------------------------\")\n",
61+
" logging.info(f\"Input Parameters: {parameters}\")\n",
62+
" logging.info(\"--------------------------------------------------------------------------------------\\n\\n\")\n",
63+
"\n",
64+
" # Get HyperParameters from the input params dict.\n",
65+
" lr = float(parameters[\"lr\"])\n",
66+
" momentum = float(parameters[\"momentum\"])\n",
67+
" batch_size = int(parameters[\"batch_size\"])\n",
68+
" num_epoch = int(parameters[\"num_epoch\"])\n",
69+
" log_interval = int(parameters[\"log_interval\"])\n",
70+
"\n",
71+
" # Prepare MNIST Dataset.\n",
72+
" def mnist_train_dataset(batch_size):\n",
73+
" return torch.utils.data.DataLoader(\n",
74+
" datasets.FashionMNIST(\n",
75+
" \"./data\",\n",
76+
" train=True,\n",
77+
" download=True,\n",
78+
" transform=transforms.Compose([transforms.ToTensor()]),\n",
79+
" ),\n",
80+
" batch_size=batch_size,\n",
81+
" shuffle=True,\n",
82+
" )\n",
83+
"\n",
84+
" def mnist_test_dataset(batch_size):\n",
85+
" return torch.utils.data.DataLoader(\n",
86+
" datasets.FashionMNIST(\n",
87+
" \"./data\", train=False, transform=transforms.Compose([transforms.ToTensor()])\n",
88+
" ),\n",
89+
" batch_size=batch_size,\n",
90+
" shuffle=False,\n",
91+
" )\n",
92+
" \n",
93+
" # Build CNN Model.\n",
94+
" def build_and_compile_cnn_model():\n",
95+
" return torch.nn.Sequential(\n",
96+
" torch.nn.Conv2d(1, 20, 5, 1),\n",
97+
" torch.nn.ReLU(),\n",
98+
" torch.nn.MaxPool2d(2, 2),\n",
99+
" \n",
100+
" torch.nn.Conv2d(20, 50, 5, 1),\n",
101+
" torch.nn.ReLU(),\n",
102+
" torch.nn.MaxPool2d(2, 2),\n",
103+
" \n",
104+
" torch.nn.Flatten(),\n",
105+
" \n",
106+
" torch.nn.Linear(4 * 4 * 50, 500),\n",
107+
" torch.nn.ReLU(),\n",
108+
" \n",
109+
" torch.nn.Linear(500, 10),\n",
110+
" torch.nn.LogSoftmax(dim=1)\n",
111+
" )\n",
112+
" \n",
113+
" # Train CNN Model.\n",
114+
" def train_cnn_model(model, train_loader, optimizer, epoch):\n",
115+
" model.train()\n",
116+
" for batch_idx, (data, target) in enumerate(train_loader):\n",
117+
" optimizer.zero_grad()\n",
118+
" output = model(data)\n",
119+
" loss = torch.nn.functional.nll_loss(output, target)\n",
120+
" loss.backward()\n",
121+
" optimizer.step()\n",
122+
" if batch_idx % log_interval == 0:\n",
123+
" msg = \"Train Epoch: {} [{}/{} ({:.0f}%)]\\tloss={:.4f}\".format(\n",
124+
" epoch,\n",
125+
" batch_idx * len(data),\n",
126+
" len(train_loader.dataset),\n",
127+
" 100.0 * batch_idx / len(train_loader),\n",
128+
" loss.item(),\n",
129+
" )\n",
130+
" logging.info(msg)\n",
131+
" \n",
132+
" # Test CNN Model and report training metrics\n",
133+
" def test_cnn_model(model, test_loader):\n",
134+
" model.eval()\n",
135+
" test_loss = 0\n",
136+
" correct = 0\n",
137+
" with torch.no_grad():\n",
138+
" for data, target in test_loader:\n",
139+
" output = model(data)\n",
140+
" test_loss += torch.nn.functional.nll_loss(\n",
141+
" output, target, reduction=\"sum\"\n",
142+
" ).item() # sum up batch loss\n",
143+
" pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability\n",
144+
" correct += pred.eq(target.view_as(pred)).sum().item()\n",
145+
" \n",
146+
" test_loss /= len(test_loader.dataset)\n",
147+
" test_accuracy = float(correct) / len(test_loader.dataset)\n",
148+
" katib.report_metrics({ # report metrics directly without outputing logs\n",
149+
" \"accuracy\": test_accuracy, \n",
150+
" \"loss\": test_loss,\n",
151+
" })\n",
152+
"\n",
153+
" # Download dataset and construct loaders for training and testing\n",
154+
" train_loader = mnist_train_dataset(batch_size)\n",
155+
" test_loader = mnist_test_dataset(batch_size)\n",
156+
"\n",
157+
" # Build Model and Optimizer\n",
158+
" model = build_and_compile_cnn_model()\n",
159+
" optimizer = torch.optim.SGD(model.parameters(), lr, momentum)\n",
160+
"\n",
161+
" # Train Model and report metrics\n",
162+
" for epoch_idx in range(1, num_epoch + 1):\n",
163+
" train_cnn_model(model, train_loader, optimizer, epoch_idx)\n",
164+
" test_cnn_model(model, test_loader)\n",
165+
"\n"
166+
]
167+
},
168+
{
169+
"cell_type": "markdown",
170+
"metadata": {},
171+
"source": [
172+
"## Start Model Tuning with Katib\n",
173+
"\n",
174+
"If you want to improve your model, you can run HyperParameter tuning with Katib.\n",
175+
"\n",
176+
"The following example uses **Random Search** algorithm to tune HyperParameters.\n",
177+
"\n",
178+
"We are going to tune `learning rate` and `momentum`."
179+
]
180+
},
181+
{
182+
"cell_type": "code",
183+
"execution_count": 2,
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"import kubeflow.katib as katib\n",
188+
"\n",
189+
"# Set parameters with their distribution for HyperParameter Tuning with Katib.\n",
190+
"parameters = {\n",
191+
" \"lr\": katib.search.double(min=0.01, max=0.03),\n",
192+
" \"momentum\": katib.search.double(min=0.3, max=0.7),\n",
193+
" \"num_epoch\": 1,\n",
194+
" \"batch_size\": 64,\n",
195+
" \"log_interval\": 10\n",
196+
"}\n",
197+
"\n",
198+
"# Start the Katib Experiment.\n",
199+
"# TODO (Electronic-Waste): \n",
200+
"# 1. Change `kubeflow-katib` to release version when `0.18.0` is ready.\n",
201+
"# 2. Change `base_image` to official image when `kubeflow-katib` release version `0.18.0` is ready.\n",
202+
"exp_name = \"tune-mnist\"\n",
203+
"katib_client = katib.KatibClient(namespace=\"kubeflow\")\n",
204+
"\n",
205+
"katib_client.tune(\n",
206+
" name=exp_name,\n",
207+
" objective=train_mnist_model, # Objective function.\n",
208+
" base_image=\"docker.io/electronicwaste/pytorch:gitv1\",\n",
209+
" parameters=parameters, # HyperParameters to tune.\n",
210+
" algorithm_name=\"random\", # Alorithm to use.\n",
211+
" objective_metric_name=\"accuracy\", # Katib is going to optimize \"accuracy\".\n",
212+
" additional_metric_names=[\"loss\"], # Katib is going to collect these metrics in addition to the objective metric.\n",
213+
" max_trial_count=12, # Trial Threshold.\n",
214+
" parallel_trial_count=2,\n",
215+
" packages_to_install=[\"git+https://github.com/kubeflow/katib.git@master#subdirectory=sdk/python/v1beta1\"],\n",
216+
" metrics_collector_config={\"kind\": \"Push\"},\n",
217+
")"
218+
]
219+
},
220+
{
221+
"cell_type": "markdown",
222+
"metadata": {},
223+
"source": [
224+
"### Access to Katib UI\n",
225+
"\n",
226+
"You can check created experiment in the Katib UI.\n",
227+
"\n"
228+
]
229+
},
230+
{
231+
"cell_type": "markdown",
232+
"metadata": {},
233+
"source": [
234+
"### Get the Best HyperParameters from the Katib Experiment\n",
235+
"\n",
236+
"You can get the best HyperParameters from the most optimal Katib Trial."
237+
]
238+
},
239+
{
240+
"cell_type": "code",
241+
"execution_count": 4,
242+
"metadata": {},
243+
"outputs": [
244+
{
245+
"name": "stdout",
246+
"output_type": "stream",
247+
"text": [
248+
"Katib Experiment is Succeeded: True\n",
249+
"\n",
250+
"Current Optimal Trial:\n",
251+
"{'best_trial_name': 'tune-mnist-xqwfhr9w',\n",
252+
" 'observation': {'metrics': [{'latest': '0.8276',\n",
253+
" 'max': '0.8276',\n",
254+
" 'min': '0.8276',\n",
255+
" 'name': 'accuracy'},\n",
256+
" {'latest': '0.48769191679954527',\n",
257+
" 'max': '0.48769191679954527',\n",
258+
" 'min': '0.48769191679954527',\n",
259+
" 'name': 'loss'}]},\n",
260+
" 'parameter_assignments': [{'name': 'lr', 'value': '0.024527727574297616'},\n",
261+
" {'name': 'momentum', 'value': '0.6490973329748595'}]}\n"
262+
]
263+
}
264+
],
265+
"source": [
266+
"status = katib_client.is_experiment_succeeded(exp_name)\n",
267+
"print(f\"Katib Experiment is Succeeded: {status}\\n\")\n",
268+
"\n",
269+
"best_hps = katib_client.get_optimal_hyperparameters(exp_name)\n",
270+
"print(f\"Current Optimal Trial:\\n{best_hps}\")"
271+
]
272+
},
273+
{
274+
"cell_type": "markdown",
275+
"metadata": {},
276+
"source": [
277+
"## Delete Katib Experiment\n",
278+
"\n",
279+
"When jobs are finished, you can delete the resources."
280+
]
281+
},
282+
{
283+
"cell_type": "code",
284+
"execution_count": 5,
285+
"metadata": {},
286+
"outputs": [],
287+
"source": [
288+
"katib_client.delete_experiment(exp_name)"
289+
]
290+
}
291+
],
292+
"metadata": {
293+
"kernelspec": {
294+
"display_name": "katib",
295+
"language": "python",
296+
"name": "python3"
297+
},
298+
"language_info": {
299+
"codemirror_mode": {
300+
"name": "ipython",
301+
"version": 3
302+
},
303+
"file_extension": ".py",
304+
"mimetype": "text/x-python",
305+
"name": "python",
306+
"nbconvert_exporter": "python",
307+
"pygments_lexer": "ipython3",
308+
"version": "3.10.14"
309+
}
310+
},
311+
"nbformat": 4,
312+
"nbformat_minor": 2
313+
}

0 commit comments

Comments
 (0)