Skip to content

Commit 9737f53

Browse files
Add flower metrics streaming example (#2764)
* Add flower metrics streaming example * Fix format * Use context and RecordSet * Undo stuff * Update to new style * Update hello-flwr-pt_tb_streaming * Remove debug msgs * Update readme * Use flower job * Add missing code * Make client api type an arg
1 parent e956fea commit 9737f53

File tree

17 files changed

+424
-48
lines changed

17 files changed

+424
-48
lines changed

examples/hello-world/hello-flower/README.md

+21-6
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,35 @@ $ tree jobs/hello-flwr-pt/app/custom
1717
```
1818
Note, this code is adapted from Flower's [app-pytorch](https://github.com/adap/flower/tree/main/examples/app-pytorch) example.
1919

20-
## Install dependencies
20+
## 1. Install dependencies
2121
If you haven't already, we recommend creating a virtual environment.
2222
```bash
2323
python3 -m venv nvflare_flwr
2424
source nvflare_flwr/bin/activate
2525
```
26-
To run a job with NVFlare, we first need to install its dependencies.
26+
27+
## 2.1 Run a simulation
28+
29+
To run the flwr-pt job with NVFlare, we first need to install its dependencies.
2730
```bash
28-
pip install ./jobs/hello-flwr-pt/app/custom
31+
pip install ./flwr-pt/
2932
```
3033

31-
## Run a simulation
32-
3334
Next, we run 2 Flower clients and the Flower Server in parallel using NVFlare's simulator.
3435
```bash
35-
nvflare simulator jobs/hello-flwr-pt -n 2 -t 2 -w /tmp/nvflare/flwr
36+
python job.py
37+
```
38+
39+
## 2.2 Run a simulation with TensorBoard streaming
40+
41+
To run the flwr-pt_tb_streaming job with NVFlare, we first need to install its dependencies.
42+
```bash
43+
pip install ./flwr-pt-metrics/
44+
```
45+
46+
Next, we run 2 Flower clients and the Flower Server in parallel using NVFlare while streaming
47+
the TensorBoard metrics to the server at each iteration using NVFlare's metric streaming.
48+
49+
```bash
50+
python job_with_metric.py
3651
```
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from flwr.client import ClientApp, NumPyClient
17+
from flwr.common import Context
18+
from flwr.common.record import MetricsRecord, RecordSet
19+
20+
from .task import DEVICE, Net, get_weights, load_data, set_weights, test, train
21+
22+
# Load model and data (simple CNN, CIFAR-10)
23+
net = Net().to(DEVICE)
24+
trainloader, testloader = load_data()
25+
26+
import nvflare.client as flare
27+
28+
# NVFlare metric writer; the NVFlare interface itself is initialized by flare.init() below
29+
from nvflare.client.tracking import SummaryWriter
30+
31+
flare.init()
32+
33+
34+
# Define FlowerClient and client_fn
35+
class FlowerClient(NumPyClient):
    """Flower NumPy client that trains the model and writes metrics via NVFlare.

    A step counter is kept in the Flower ``Context`` state under the ``"step"``
    metrics record. It is used as the global step of every scalar written
    through the NVFlare ``SummaryWriter``.
    """

    def __init__(self, context: Context):
        """Attach ``context`` to the client and make sure a step counter exists.

        Args:
            context: Flower context handed to ``client_fn``.
        """
        super().__init__()
        self.writer = SummaryWriter()
        self.set_context(context)
        # Only seed the counter when the context does not already carry one.
        if "step" not in context.state.metrics_records:
            self.set_step(0)

    def set_step(self, step: int):
        """Store ``step`` in the context state.

        Only the ``"step"`` metrics record is written, so any other records
        held in ``context.state`` are left untouched.
        """
        context = self.get_context()
        context.state.metrics_records["step"] = MetricsRecord({"step": step})
        self.set_context(context)

    def get_step(self):
        """Return the step counter stored in the context state as an ``int``."""
        context = self.get_context()
        return int(context.state.metrics_records["step"]["step"])

    def fit(self, parameters, config):
        """Train for one epoch, write the training metrics and advance the step.

        Args:
            parameters: Model weights to load into the module-level ``net``.
            config: Fit configuration passed by Flower (not used here).

        Returns:
            A tuple ``(weights, num_train_examples, metrics)`` where ``metrics``
            is the dictionary returned by ``train``.
        """
        step = self.get_step()
        set_weights(net, parameters)
        results = train(net, trainloader, testloader, epochs=1, device=DEVICE)

        for tag in ("train_loss", "train_accuracy", "val_loss", "val_accuracy"):
            self.writer.add_scalar(tag, results[tag], step)

        self.set_step(step + 1)

        return get_weights(net), len(trainloader.dataset), results

    def evaluate(self, parameters, config):
        """Evaluate the given weights on the test loader and write the metrics.

        Args:
            parameters: Model weights to load into the module-level ``net``.
            config: Evaluate configuration passed by Flower (not used here).

        Returns:
            A tuple ``(loss, num_test_examples, {"accuracy": accuracy})``.
        """
        set_weights(net, parameters)
        step = self.get_step()
        loss, accuracy = test(net, testloader)

        self.writer.add_scalar("test_loss", loss, step)
        self.writer.add_scalar("test_accuracy", accuracy, step)

        return loss, len(testloader.dataset), {"accuracy": accuracy}
75+
76+
77+
def client_fn(context: Context):
    """Build a ``FlowerClient`` for ``context`` and return its ``to_client()`` form."""
    numpy_client = FlowerClient(context)
    return numpy_client.to_client()
80+
81+
82+
# Flower ClientApp
83+
# The ClientApp builds its clients through client_fn
app = ClientApp(client_fn=client_fn)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from typing import List, Tuple
15+
16+
from flwr.common import Metrics, ndarrays_to_parameters
17+
from flwr.server import ServerApp, ServerConfig
18+
from flwr.server.strategy import FedAvg
19+
20+
from .task import Net, get_weights
21+
22+
23+
# Define metric aggregation function
24+
def weighted_average(metrics: List[Tuple[int, Metrics]]) -> Metrics:
    """Aggregate per-client fit metrics into example-weighted averages.

    Args:
        metrics: One ``(num_examples, client_metrics)`` pair per client. Every
            ``client_metrics`` mapping must contain ``train_loss``,
            ``train_accuracy``, ``val_loss`` and ``val_accuracy``.

    Returns:
        A mapping with the same four keys, each holding the average of the
        client values weighted by ``num_examples``. An empty mapping is
        returned when the total number of examples is zero (which includes an
        empty ``metrics`` list), since no weighted average is defined then.
    """
    total_examples = sum(num_examples for num_examples, _ in metrics)
    if total_examples == 0:
        # Avoid a ZeroDivisionError when there is nothing to weight by.
        return {}

    metric_names = ("train_loss", "train_accuracy", "val_loss", "val_accuracy")

    # Weight each client's value by its example count, then normalize.
    return {
        name: sum(num_examples * m[name] for num_examples, m in metrics) / total_examples
        for name in metric_names
    }
40+
41+
42+
# Initial global model: weights of a freshly constructed Net
ndarrays = get_weights(Net())
parameters = ndarrays_to_parameters(ndarrays)


# FedAvg over all available clients, with fit metrics aggregated by weighted_average
strategy = FedAvg(
    initial_parameters=parameters,
    fit_metrics_aggregation_fn=weighted_average,
    min_available_clients=2,
    fraction_fit=1.0,  # Select all available clients
    fraction_evaluate=0.0,  # Disable evaluation
)


# Server configuration
config = ServerConfig(num_rounds=3)


# Flower ServerApp
app = ServerApp(strategy=strategy, config=config)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
[build-system]
2+
requires = ["hatchling"]
3+
build-backend = "hatchling.build"
4+
5+
[project]
6+
name = "flwr_pt_tb_streaming"
7+
version = "1.0.0"
8+
description = ""
9+
license = "Apache-2.0"
10+
dependencies = [
11+
"flwr[simulation]>=1.11.0,<2.0",
12+
"nvflare~=2.5.0rc",
13+
"torch==2.2.1",
14+
"torchvision==0.17.1",
15+
]
16+
17+
[tool.hatch.build.targets.wheel]
18+
packages = ["."]
19+
20+
[tool.flwr.app]
21+
publisher = "nvidia"
22+
23+
[tool.flwr.app.components]
24+
serverapp = "flwr_pt_tb_streaming.server:app"
25+
clientapp = "flwr_pt_tb_streaming.client:app"
26+
27+
[tool.flwr.app.config]
28+
num-server-rounds = 3
29+
30+
[tool.flwr.federations]
31+
default = "local-simulation"
32+
33+
[tool.flwr.federations.local-simulation]
34+
options.num-supernodes = 2
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""flwr_pt."""

examples/hello-world/hello-flower/jobs/hello-flwr-pt/app/custom/flwr_pt/client.py examples/hello-world/hello-flower/flwr-pt/flwr_pt/client.py

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
from flwr.client import ClientApp, NumPyClient
1516
from flwr.common import Context
1617

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
from collections import OrderedDict
15+
from logging import INFO
16+
17+
import torch
18+
import torch.nn as nn
19+
import torch.nn.functional as F
20+
from flwr.common.logger import log
21+
from torch.utils.data import DataLoader
22+
from torchvision.datasets import CIFAR10
23+
from torchvision.transforms import Compose, Normalize, ToTensor
24+
25+
# Use the first CUDA GPU when PyTorch reports one as available, otherwise the CPU
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
26+
27+
28+
class Net(nn.Module):
    """Small CNN adapted from 'PyTorch: A 60 Minute Blitz', producing 10 outputs."""

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order as before so that the
        # state_dict key order (used by get_weights/set_weights) is unchanged.
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv2 = nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5)
        self.fc1 = nn.Linear(in_features=16 * 5 * 5, out_features=120)
        self.fc2 = nn.Linear(in_features=120, out_features=84)
        self.fc3 = nn.Linear(in_features=84, out_features=10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run conv -> relu -> pool twice, flatten, then three linear layers."""
        features = self.conv1(x)
        features = self.pool(F.relu(features))
        features = self.conv2(features)
        features = self.pool(F.relu(features))

        flat = features.view(-1, 16 * 5 * 5)

        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        return self.fc3(hidden)
47+
48+
49+
def load_data():
    """Load CIFAR-10 and return ``(trainloader, testloader)``.

    Both splits are stored under ``./data`` (downloaded when requested by
    ``download=True``) and share the same tensor conversion + normalization.
    The training loader shuffles with batch size 32; the test loader uses the
    ``DataLoader`` defaults.
    """
    normalize = Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    transform = Compose([ToTensor(), normalize])

    train_split = CIFAR10("./data", train=True, download=True, transform=transform)
    test_split = CIFAR10("./data", train=False, download=True, transform=transform)

    train_loader = DataLoader(train_split, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_split)
    return train_loader, test_loader
55+
56+
57+
def train(net, trainloader, valloader, epochs, device):
    """Train ``net`` on ``trainloader`` and report train/validation metrics.

    Args:
        net: Model to train; it is moved to ``device``.
        trainloader: Loader yielding ``(images, labels)`` training batches.
        valloader: Loader used for the validation metrics.
        epochs: Number of passes over ``trainloader``.
        device: Device used for both training and the follow-up evaluation.

    Returns:
        A dict with ``train_loss``, ``train_accuracy``, ``val_loss`` and
        ``val_accuracy`` as computed by ``test`` after training.
    """
    log(INFO, "Starting training...")
    net.to(device)  # move model to GPU if available
    criterion = torch.nn.CrossEntropyLoss().to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    net.train()
    for _ in range(epochs):
        for images, labels in trainloader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss = criterion(net(images), labels)
            loss.backward()
            optimizer.step()

    # Evaluate on the same device the model was trained on, instead of
    # silently moving it to the module-level default device.
    train_loss, train_acc = test(net, trainloader, device)
    val_loss, val_acc = test(net, valloader, device)

    return {
        "train_loss": train_loss,
        "train_accuracy": train_acc,
        "val_loss": val_loss,
        "val_accuracy": val_acc,
    }


def test(net, testloader, device=None):
    """Evaluate ``net`` on ``testloader``.

    Args:
        net: Model to evaluate; it is moved to the evaluation device.
        testloader: Loader yielding ``(images, labels)`` batches. Its
            ``dataset`` length is used as the accuracy denominator.
        device: Device to evaluate on. Defaults to the module-level ``DEVICE``
            when omitted, which matches the previous behavior.

    Returns:
        A tuple ``(loss, accuracy)``. ``loss`` is the sum of the per-batch
        criterion values (it is not averaged over batches); ``accuracy`` is
        the fraction of correctly classified samples, or ``0.0`` for an empty
        dataset.
    """
    if device is None:
        device = DEVICE
    net.to(device)
    criterion = torch.nn.CrossEntropyLoss()
    correct, loss = 0, 0.0

    # Evaluate in eval mode, then restore whatever mode the caller had set.
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for images, labels in testloader:
                images, labels = images.to(device), labels.to(device)
                outputs = net(images)
                loss += criterion(outputs, labels).item()
                correct += (torch.max(outputs, 1)[1] == labels).sum().item()
    finally:
        net.train(was_training)

    num_samples = len(testloader.dataset)
    # Guard against an empty dataset rather than raising ZeroDivisionError.
    accuracy = correct / num_samples if num_samples else 0.0
    return loss, accuracy
97+
98+
99+
def get_weights(net):
    """Return every ``net.state_dict()`` entry as a NumPy array, in state-dict order."""
    return [tensor.cpu().numpy() for tensor in net.state_dict().values()]
101+
102+
103+
def set_weights(net, parameters):
    """Load ``parameters`` into ``net``.

    The arrays are paired positionally with the keys of ``net.state_dict()``,
    converted with ``torch.tensor`` and loaded with ``strict=True``.
    """
    state_dict = OrderedDict()
    for key, array in zip(net.state_dict().keys(), parameters):
        state_dict[key] = torch.tensor(array)
    net.load_state_dict(state_dict, strict=True)
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from nvflare.app_opt.flower.flower_job import FlowerJob
16+
17+
if __name__ == "__main__":
    # Build an NVFlare job from the Flower app content in ./flwr-pt
    flower_job = FlowerJob(name="flwr-pt", flower_content="./flwr-pt")

    # Export the job under "jobs", then run it in the simulator with 2 clients
    flower_job.export_job("jobs")
    flower_job.simulator_run("/tmp/nvflare/flwr-pt", gpu="0", n_clients=2)

0 commit comments

Comments
 (0)