Skip to content

Commit

Permalink
Add global round info for logging metrics global step (#2258)
Browse files Browse the repository at this point in the history
* Add global round info for logging metrics global step

* Fix CI

* Update how the current round is obtained

* Improvements

* Fixes
  • Loading branch information
nvkevlu authored Jan 8, 2024
1 parent 1094ff2 commit 1d2bf4d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
Expand All @@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
"""Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
Args:
data_path (str): Path that the data will be stored at. Defaults to "~/data".
lr (float, optional): Learning rate. Defaults to 0.01
epochs (int, optional): Epochs. Defaults to 5
exclude_vars (list): List of variables to exclude during model loading.
Expand All @@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
self.loss = None
self.device = None
self.model = None

self.data_path = data_path
self.lr = lr
self.epochs = epochs
Expand Down Expand Up @@ -147,6 +149,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha

def local_train(self, fl_ctx, abort_signal):
# Basic training
current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
for epoch in range(self.epochs):
self.model.train()
running_loss = 0.0
Expand Down Expand Up @@ -174,12 +177,12 @@ def local_train(self, fl_ctx, abort_signal):
)

# Stream training loss at each step
current_step = len(self.train_loader) * epoch + i
current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
self.writer.log_metrics({"train_loss": cost.item(), "running_loss": running_loss}, current_step)

# Stream validation accuracy at the end of each epoch
metric = self.local_validate(abort_signal)
self.writer.log_metric("validation_accuracy", metric, epoch)
self.writer.log_metric("validation_accuracy", metric, current_step)

def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
Expand Down
9 changes: 6 additions & 3 deletions examples/advanced/experiment-tracking/pt/learner_with_tb.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
class PTLearner(Learner):
def __init__(
self,
data_path="/tmp/nvflare/tensorboard-streaming",
data_path="~/data",
lr=0.01,
epochs=5,
exclude_vars=None,
Expand All @@ -52,6 +52,7 @@ def __init__(
"""Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
Args:
data_path (str): Path that the data will be stored at. Defaults to "~/data".
lr (float, optional): Learning rate. Defaults to 0.01
epochs (int, optional): Epochs. Defaults to 5
exclude_vars (list): List of variables to exclude during model loading.
Expand All @@ -71,6 +72,7 @@ def __init__(
self.loss = None
self.device = None
self.model = None

self.data_path = data_path
self.lr = lr
self.epochs = epochs
Expand Down Expand Up @@ -150,6 +152,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha

def local_train(self, fl_ctx, abort_signal):
# Basic training
current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
for epoch in range(self.epochs):
self.model.train()
running_loss = 0.0
Expand All @@ -173,12 +176,12 @@ def local_train(self, fl_ctx, abort_signal):
running_loss = 0.0

# Stream training loss at each step
current_step = len(self.train_loader) * epoch + i
current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
self.writer.add_scalar("train_loss", cost.item(), current_step)

# Stream validation accuracy at the end of each epoch
metric = self.local_validate(abort_signal)
self.writer.add_scalar("validation_accuracy", metric, epoch)
self.writer.add_scalar("validation_accuracy", metric, current_step)

def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from torchvision.transforms import Compose, Normalize, ToTensor

from nvflare.apis.dxo import DXO, DataKind, MetaKey, from_shareable
from nvflare.apis.fl_constant import ReservedKey, ReturnCode
from nvflare.apis.fl_constant import FLContextKey, ReservedKey, ReturnCode
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable, make_reply
from nvflare.apis.signal import Signal
Expand All @@ -44,6 +44,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
"""Simple PyTorch Learner that trains and validates a simple network on the CIFAR10 dataset.
Args:
data_path (str): Path that the data will be stored at. Defaults to "~/data".
lr (float, optional): Learning rate. Defaults to 0.01
epochs (int, optional): Epochs. Defaults to 5
exclude_vars (list): List of variables to exclude during model loading.
Expand All @@ -63,6 +64,7 @@ def __init__(self, data_path="~/data", lr=0.01, epochs=5, exclude_vars=None, ana
self.loss = None
self.device = None
self.model = None

self.data_path = data_path
self.lr = lr
self.epochs = epochs
Expand Down Expand Up @@ -141,6 +143,7 @@ def train(self, data: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Sha

def local_train(self, fl_ctx, abort_signal):
# Basic training
current_round = fl_ctx.get_prop(FLContextKey.TASK_DATA).get_header("current_round")
for epoch in range(self.epochs):
self.model.train()
running_loss = 0.0
Expand All @@ -164,12 +167,12 @@ def local_train(self, fl_ctx, abort_signal):
running_loss = 0.0

# Stream training loss at each step
current_step = len(self.train_loader) * epoch + i
current_step = self.n_iterations * self.epochs * current_round + self.n_iterations * epoch + i
self.writer.log({"train_loss": cost.item()}, current_step)

# Stream validation accuracy at the end of each epoch
metric = self.local_validate(abort_signal)
self.writer.log({"validation_accuracy": metric}, epoch)
self.writer.log({"validation_accuracy": metric}, current_step)

def get_model_for_validation(self, model_name: str, fl_ctx: FLContext) -> Shareable:
run_dir = fl_ctx.get_engine().get_workspace().get_run_dir(fl_ctx.get_job_id())
Expand Down

0 comments on commit 1d2bf4d

Please sign in to comment.