Commit

Address comments
YuanTingHsieh committed Jan 28, 2025
1 parent 515ddd3 commit ca7b69c
Showing 10 changed files with 341 additions and 160 deletions.
96 changes: 47 additions & 49 deletions examples/hello-world/hello-pt/src/hello-pt_cifar10_fl.py
@@ -22,15 +22,15 @@
 from torchvision.datasets import CIFAR10
 from torchvision.transforms import Compose, Normalize, ToTensor

-from nvflare.client import FlareClientContext, FLModel
+import nvflare.client as flare
 from nvflare.client.tracking import SummaryWriter

 DATASET_PATH = "/tmp/nvflare/data"


 def main():
     batch_size = 4
-    epochs = 5
+    epochs = 2
     lr = 0.01
     model = SimpleNetwork()
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -43,55 +43,53 @@ def main():
         ]
     )

-    with FlareClientContext() as flare:
-        sys_info = flare.system_info()
-        client_name = sys_info["site_name"]
+    flare.init()
+    sys_info = flare.system_info()
+    client_name = sys_info["site_name"]

-        train_dataset = CIFAR10(
-            root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
-        )
-        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
+    train_dataset = CIFAR10(
+        root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
+    )
+    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

-        summary_writer = SummaryWriter()
-        while flare.is_running():
-            input_model = flare.receive()
-            print(f"current_round={input_model.current_round}")
-
-            model.load_state_dict(input_model.params)
-            model.to(device)
-
-            steps = epochs * len(train_loader)
-            for epoch in range(epochs):
-                running_loss = 0.0
-                for i, batch in enumerate(train_loader):
-                    images, labels = batch[0].to(device), batch[1].to(device)
-                    optimizer.zero_grad()
-
-                    predictions = model(images)
-                    cost = loss(predictions, labels)
-                    cost.backward()
-                    optimizer.step()
-
-                    running_loss += cost.cpu().detach().numpy() / images.size()[0]
-                    if i % 3000 == 0:
-                        print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
-                        global_step = input_model.current_round * steps + epoch * len(train_loader) + i
-                        summary_writer.add_scalar(
-                            tag="loss_for_each_batch", scalar=running_loss, global_step=global_step
-                        )
-                        running_loss = 0.0
-
-            print("Finished Training")
-
-            PATH = "./cifar_net.pth"
-            torch.save(model.state_dict(), PATH)
-
-            output_model = FLModel(
-                params=model.cpu().state_dict(),
-                meta={"NUM_STEPS_CURRENT_ROUND": steps},
-            )
-
-            flare.send(output_model)
+    summary_writer = SummaryWriter()
+    while flare.is_running():
+        input_model = flare.receive()
+        print(f"current_round={input_model.current_round}")
+
+        model.load_state_dict(input_model.params)
+        model.to(device)
+
+        steps = epochs * len(train_loader)
+        for epoch in range(epochs):
+            running_loss = 0.0
+            for i, batch in enumerate(train_loader):
+                images, labels = batch[0].to(device), batch[1].to(device)
+                optimizer.zero_grad()
+
+                predictions = model(images)
+                cost = loss(predictions, labels)
+                cost.backward()
+                optimizer.step()
+
+                running_loss += cost.cpu().detach().numpy() / images.size()[0]
+                if i % 3000 == 0:
+                    print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
+                    global_step = input_model.current_round * steps + epoch * len(train_loader) + i
+                    summary_writer.add_scalar(tag="loss_for_each_batch", scalar=running_loss, global_step=global_step)
+                    running_loss = 0.0
+
+        print("Finished Training")
+
+        PATH = "./cifar_net.pth"
+        torch.save(model.state_dict(), PATH)
+
+        output_model = flare.FLModel(
+            params=model.cpu().state_dict(),
+            meta={"NUM_STEPS_CURRENT_ROUND": steps},
+        )
+
+        flare.send(output_model)


 if __name__ == "__main__":
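Both examples import SimpleNetwork from simple_network.py, which is not part of this commit. As a rough sketch only (the real module in the example may differ), a CIFAR-10 classifier of this kind is a small CNN along these lines:

import torch
import torch.nn as nn
import torch.nn.functional as F


class SimpleNetwork(nn.Module):
    """Hypothetical stand-in for simple_network.SimpleNetwork (not shown in this diff)."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)  # 3x32x32 CIFAR-10 input -> 6x28x28
        self.pool = nn.MaxPool2d(2, 2)  # halves the spatial dimensions
        self.conv2 = nn.Conv2d(6, 16, 5)  # 6x14x14 -> 16x10x10
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)  # 10 CIFAR-10 classes

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # flatten everything except the batch dimension
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)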
99 changes: 99 additions & 0 deletions examples/hello-world/hello-pt/src/hello-pt_cifar10_fl_v2.py
@@ -0,0 +1,99 @@
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch
from simple_network import SimpleNetwork
from torch import nn
from torch.optim import SGD
from torch.utils.data.dataloader import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor

import nvflare.client as flare
from nvflare.client import FLModel
from nvflare.client.tracking import SummaryWriter

DATASET_PATH = "/tmp/nvflare/data"


def main():
    batch_size = 4
    epochs = 2
    lr = 0.01
    model = SimpleNetwork()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    loss = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
    transforms = Compose(
        [
            ToTensor(),
            Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]
    )

    with flare.init() as ctx:
        sys_info = flare.system_info(ctx=ctx)
        client_name = sys_info["site_name"]

        train_dataset = CIFAR10(
            root=os.path.join(DATASET_PATH, client_name), transform=transforms, download=True, train=True
        )
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

        summary_writer = SummaryWriter(ctx=ctx)
        while flare.is_running(ctx=ctx):
            input_model = flare.receive(ctx=ctx)
            print(f"current_round={input_model.current_round}")

            model.load_state_dict(input_model.params)
            model.to(device)

            steps = epochs * len(train_loader)
            for epoch in range(epochs):
                running_loss = 0.0
                for i, batch in enumerate(train_loader):
                    images, labels = batch[0].to(device), batch[1].to(device)
                    optimizer.zero_grad()

                    predictions = model(images)
                    cost = loss(predictions, labels)
                    cost.backward()
                    optimizer.step()

                    running_loss += cost.cpu().detach().numpy() / images.size()[0]
                    if i % 3000 == 0:
                        print(f"Epoch: {epoch}/{epochs}, Iteration: {i}, Loss: {running_loss / 3000}")
                        global_step = input_model.current_round * steps + epoch * len(train_loader) + i
                        summary_writer.add_scalar(
                            tag="loss_for_each_batch", scalar=running_loss, global_step=global_step
                        )
                        running_loss = 0.0

            print("Finished Training")

            PATH = "./cifar_net.pth"
            torch.save(model.state_dict(), PATH)

            output_model = FLModel(
                params=model.cpu().state_dict(),
                meta={"NUM_STEPS_CURRENT_ROUND": steps},
            )

            flare.send(output_model, ctx=ctx)


if __name__ == "__main__":
    main()
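For comparison, removing the NVFlare calls from either variant leaves plain centralized CIFAR-10 training; the federated versions only add where the weights come from (flare.receive) and where the local update goes (flare.send). A minimal sketch, assuming the same SimpleNetwork, transforms, and hyperparameters as above, with a placeholder local data path:

import torch
from simple_network import SimpleNetwork
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose, Normalize, ToTensor


def local_train(epochs=2, batch_size=4, lr=0.01):
    # Same model, loss, optimizer, and transforms as the federated examples above.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = SimpleNetwork().to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=lr, momentum=0.9)
    transforms = Compose([ToTensor(), Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

    # "/tmp/cifar10_local" is a placeholder path, not one used by the example.
    dataset = CIFAR10(root="/tmp/cifar10_local", transform=transforms, download=True, train=True)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    for _ in range(epochs):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            loss_fn(model(images), labels).backward()
            optimizer.step()

    torch.save(model.state_dict(), "./cifar_net_local.pth")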
4 changes: 2 additions & 2 deletions nvflare/client/__init__.py
@@ -19,7 +19,6 @@
 from nvflare.app_common.abstract.fl_model import FLModel as FLModel
 from nvflare.app_common.abstract.fl_model import ParamsType as ParamsType

-from .api import FlareClientContext as FlareClientContext
 from .api import get_config as get_config
 from .api import get_job_id as get_job_id
 from .api import get_site_name as get_site_name
@@ -31,7 +30,8 @@
 from .api import log as log
 from .api import receive as receive
 from .api import send as send
+from .api import shutdown as shutdown
 from .api import system_info as system_info
 from .decorator import evaluate as evaluate
 from .decorator import train as train
-from .ipc.ipc_agent import IPCAgent as IPCAgent
+from .ipc.ipc_agent import IPCAgent
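
With shutdown back in the package namespace, the module-level client API keeps its full lifecycle: init, receive/send loop, shutdown. A minimal skeleton sketch of that flow (not tied to this example; the local-training step is elided):

import nvflare.client as flare


def run_client():
    flare.init()  # attach this process to the FL system
    try:
        while flare.is_running():
            input_model = flare.receive()  # FLModel carrying the current global weights
            # ... run local training starting from input_model.params ...
            output_model = flare.FLModel(params=input_model.params)  # placeholder update
            flare.send(output_model)  # return the local result to the server
    finally:
        flare.shutdown()  # release client-side resources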