-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
112 changed files
with
9,676 additions
and
42 deletions.
There are no files selected for viewing
46 changes: 46 additions & 0 deletions
46
...-1_running_federated_learning_applications/01.5_experiment_tracking/code/fl_job_mlflow.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from nvflare.app_opt.tracking.mlflow.mlflow_receiver import MLflowReceiver | ||
from src.network import SimpleNetwork | ||
|
||
from nvflare.app_opt.pt.job_config.fed_avg import FedAvgJob | ||
from nvflare.job_config.script_runner import ScriptRunner | ||
|
||
if __name__ == "__main__": | ||
n_clients = 5 | ||
num_rounds = 2 | ||
|
||
train_script = "src/client.py" | ||
|
||
job = FedAvgJob(name="fedavg", n_clients=n_clients, num_rounds=num_rounds, initial_model=SimpleNetwork()) | ||
receiver = MLflowReceiver( | ||
tracking_uri="file:///tmp/nvflare/jobs/workdir/server/simulate_job/mlruns", | ||
kw_args={ | ||
"experiment_name": "nvflare-fedavg-experiment", | ||
"run_name": "nvflare-fedavg-with-mlflow", | ||
"experiment_tags": {"mlflow.note.content": "## **NVFlare FedAvg experiment with MLflow**"}, | ||
"run_tags": {"mlflow.note.content": "## Federated Experiment tracking with MLflow.\n"}, | ||
}, | ||
) | ||
job.to_server(receiver) | ||
|
||
# Add clients | ||
for i in range(n_clients): | ||
executor = ScriptRunner( | ||
script=train_script, script_args="" # f"--batch_size 32 --data_path /tmp/data/site-{i}" | ||
) | ||
job.to(executor, f"site-{i + 1}") | ||
|
||
job.simulator_run(workspace="/tmp/nvflare/jobs/workdir", log_config="./log_config.json") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
316 changes: 316 additions & 0 deletions
316
...federated_learning_applications/01.5_experiment_tracking/experiment_tracking_mlflow.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,316 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "b75b2253-cba8-4579-907b-09311e0da587", | ||
"metadata": {}, | ||
"source": [ | ||
"# Experiment Tracking with MLflow\n", | ||
"\n", | ||
"If you would like to use MLflow for experiment tracking, NVFlare has `MLflowReceiver` available for use on the FL server to log to a MLflow tracking server.\n", | ||
"\n", | ||
"## Introduction to distributed experiment tracking\n", | ||
"\n", | ||
"In a federated computing setting, data is distributed across multiple devices or systems, and training is run on each device independently while preserving each client’s data privacy.\n", | ||
"\n", | ||
"Assuming a federated system consisting of one server and many clients and the server coordinating the ML training of clients, we can interact with ML experiment tracking tools in two different ways:\n", | ||
"\n", | ||
"- Client-side experiment tracking: Each client will directly send the log metrics/parameters to the ML experiment tracking server (like MLflow or Weights and Biases) or local file system (like tensorboard)\n", | ||
"- Aggregated experiment tracking: Clients will send the log metrics/parameters to the FL server, and the FL server will send the metrics to ML experiment tracking server or local file system\n", | ||
"\n", | ||
"NVFlare makes it possible for you to configure either way, but in this example we will demonstrate a server-side approach for aggregated experiment tracking.\n", | ||
"\n", | ||
"## Default in FedAvgJob\n", | ||
"\n", | ||
"The FedJob API makes it easy to create job congifurations, and by default the `TBAnalyticsReceiver` for TensorBoard streaming is included. You can specify your own analytics_receiver of type `AnalyticsReceiver` as a parameter if you want, but if left unspecified, `TBAnalyticsReceiver` is configured to be set up in `BaseFedJob` (nvflare/app_opt/pt/job_config/base_fed_job.py). \n", | ||
"\n", | ||
"The `TBAnalyticsReceiver` for TensorBoard streaming receives and records the logs during the experiment by saving them to Tensoboard event files on the FL server. See [this link](https://nvflare.readthedocs.io/en/main/programming_guide/experiment_tracking/experiment_tracking_log_writer.html#tools-sender-logwriter-and-receivers) for more details on the other available AnalyticsReceivers in NVFlare: MLflowReceiver and WandBReceiver." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4d9db032", | ||
"metadata": {}, | ||
"source": [ | ||
"## Add SummaryWriter and add_scalar for logging metrics\n", | ||
"\n", | ||
"To keep things simple, we start from the state of the code we had in part 1.1 earlier this chapter and make the few modifications needed to implement adding metrics for experiment tracking.\n", | ||
"\n", | ||
"### Add import from Client API \n", | ||
"\n", | ||
"In order to add SummaryWriter to the client training code, we need to import it with the following line (at the top of client.py):" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "f74bcbbc", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from nvflare.client.tracking import SummaryWriter" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d3052e9c", | ||
"metadata": {}, | ||
"source": [ | ||
"After that, we need to add the following line after `flare.init()`:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "26146142", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"summary_writer = SummaryWriter()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "4f460079", | ||
"metadata": {}, | ||
"source": [ | ||
"We can then use summary_writer to log. In this case, we have a running_loss available already, so we can use `add_scalar()` to log this:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "2a846954", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"summary_writer.add_scalar(tag=\"training_loss\", scalar=running_loss, global_step=global_step)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a6840044", | ||
"metadata": {}, | ||
"source": [ | ||
"Note that the global_step is included here, so we calculate the global step for it on the previous line:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "15e9b2fd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"global_step = epoch * n_loaders + i" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0d03f421", | ||
"metadata": {}, | ||
"source": [ | ||
"Also note that we log once every 100 steps to reduce the burden on the logging." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "66e03003", | ||
"metadata": {}, | ||
"source": [ | ||
"You can see the full contents of the updated training code in client.py:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "fdd7a99d", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.\n", | ||
"#\n", | ||
"# Licensed under the Apache License, Version 2.0 (the \"License\");\n", | ||
"# you may not use this file except in compliance with the License.\n", | ||
"# You may obtain a copy of the License at\n", | ||
"#\n", | ||
"# http://www.apache.org/licenses/LICENSE-2.0\n", | ||
"#\n", | ||
"# Unless required by applicable law or agreed to in writing, software\n", | ||
"# distributed under the License is distributed on an \"AS IS\" BASIS,\n", | ||
"# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", | ||
"# See the License for the specific language governing permissions and\n", | ||
"# limitations under the License.\n", | ||
"\n", | ||
"import os\n", | ||
"\n", | ||
"import torch\n", | ||
"from network import SimpleNetwork\n", | ||
"from torch import nn\n", | ||
"from torch.optim import SGD\n", | ||
"from torch.utils.data.dataloader import DataLoader\n", | ||
"from torchvision.datasets import CIFAR10\n", | ||
"from torchvision.transforms import Compose, Normalize, ToTensor\n", | ||
"\n", | ||
"import nvflare.client as flare\n", | ||
"\n", | ||
"from nvflare.client.tracking import SummaryWriter\n", | ||
"\n", | ||
"DATASET_PATH = \"/tmp/nvflare/data\"\n", | ||
"\n", | ||
"\n", | ||
"def main():\n", | ||
" batch_size = 4\n", | ||
" epochs = 1\n", | ||
" lr = 0.01\n", | ||
" model = SimpleNetwork()\n", | ||
" device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", | ||
" loss = nn.CrossEntropyLoss()\n", | ||
" optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)\n", | ||
" transforms = Compose(\n", | ||
" [\n", | ||
" ToTensor(),\n", | ||
" Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),\n", | ||
" ]\n", | ||
" )\n", | ||
"\n", | ||
" flare.init()\n", | ||
" summary_writer = SummaryWriter()\n", | ||
" sys_info = flare.system_info()\n", | ||
" site_name = sys_info[\"site_name\"]\n", | ||
"\n", | ||
" data_path = os.path.join(DATASET_PATH, site_name)\n", | ||
"\n", | ||
" train_dataset = CIFAR10(root=data_path, transform=transforms, download=True, train=True)\n", | ||
" train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n", | ||
" n_loaders = len(train_loader)\n", | ||
"\n", | ||
" print(\"number of loaders = \", n_loaders)\n", | ||
"\n", | ||
" round = 0\n", | ||
" last_loss = 0\n", | ||
" while flare.is_running():\n", | ||
" input_model = flare.receive()\n", | ||
" round = input_model.current_round\n", | ||
"\n", | ||
" print(f\"\\n\\nsite_name={site_name}, current_round={round + 1}\\n \")\n", | ||
"\n", | ||
" model.load_state_dict(input_model.params)\n", | ||
" model.to(device)\n", | ||
"\n", | ||
" steps = epochs * n_loaders\n", | ||
" for epoch in range(epochs):\n", | ||
" running_loss = 0.0\n", | ||
"\n", | ||
" for i, batch in enumerate(train_loader):\n", | ||
" images, labels = batch[0].to(device), batch[1].to(device)\n", | ||
"\n", | ||
" optimizer.zero_grad()\n", | ||
"\n", | ||
" predictions = model(images)\n", | ||
" cost = loss(predictions, labels)\n", | ||
" cost.backward()\n", | ||
" optimizer.step()\n", | ||
"\n", | ||
" running_loss += cost.cpu().detach().numpy() / batch_size\n", | ||
"\n", | ||
" if i % 100 == 0:\n", | ||
" global_step = epoch * n_loaders + i\n", | ||
" summary_writer.add_scalar(tag=\"training_loss\", scalar=running_loss, global_step=global_step)\n", | ||
"\n", | ||
" if i % 3000 == 0:\n", | ||
" print(\n", | ||
" f\"Round: {round + 1}, Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {running_loss / 3000}\"\n", | ||
" )\n", | ||
" running_loss = 0.0\n", | ||
"\n", | ||
" last_loss = {running_loss / (i + 1)}\n", | ||
" print(\n", | ||
" f\"site: {site_name}, round: {round + 1}, Epoch: {epoch + 1}/{epochs}, batch: {i + 1}, Loss: {last_loss}\"\n", | ||
" )\n", | ||
"\n", | ||
" print(\"Finished Training\")\n", | ||
"\n", | ||
" PATH = \"./cifar_net.pth\"\n", | ||
" torch.save(model.state_dict(), PATH)\n", | ||
"\n", | ||
" output_model = flare.FLModel(\n", | ||
" params=model.cpu().state_dict(),\n", | ||
" meta={\"NUM_STEPS_CURRENT_ROUND\": steps},\n", | ||
" )\n", | ||
"\n", | ||
" flare.send(output_model)\n", | ||
"\n", | ||
" print(\n", | ||
" f\"\\n\"\n", | ||
" f\"Result Summary\\n\"\n", | ||
" \" Training parameters:\\n\"\n", | ||
" \" number of clients = 5\\n\"\n", | ||
" f\" round = {round + 1},\\n\"\n", | ||
" f\" batch_size = {batch_size},\\n\"\n", | ||
" f\" epochs = {epochs},\\n\"\n", | ||
" f\" lr = {lr},\\n\"\n", | ||
" f\" total data batches = {n_loaders},\\n\"\n", | ||
" f\" Metrics: last_loss = {last_loss}\\n\"\n", | ||
" )\n", | ||
"\n", | ||
"\n", | ||
"if __name__ == \"__main__\":\n", | ||
" main()\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"!cat code/src/client.py" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "3a364ce1", | ||
"metadata": {}, | ||
"source": [ | ||
"## View tensorboard results\n", | ||
"\n", | ||
"In order to see the results, you can use the following command directed to the location of the tensorboard event files (by default the location for the server should be as follows using the default simulator path provided):\n", | ||
"\n", | ||
"```commandline\n", | ||
"tensorboard --logdir=/tmp/nvflare/jobs/workdir/server/simulate_job/tb_events\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bdd0eb76", | ||
"metadata": {}, | ||
"source": [ | ||
"Now, we know how experiment tracking can be achieved through metric logging and can be configured to work in a job with an `AnalyticsReceiver`. With this mechanism, we can stream various types of metric data.\n", | ||
"\n", | ||
"To continue, please see [Understanding FLARE federated learning Job structure](../01.6_job_structure_and_configuration/01.1.6.1_understanding_fl_job.ipynb)." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": ".venvpt", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.2" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Oops, something went wrong.