Skip to content

Commit

Permalink
formatting changes
Browse files Browse the repository at this point in the history
  • Loading branch information
chesterxgchen committed Feb 2, 2025
1 parent c743b18 commit 3e8211e
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,15 @@ class FedAvg(BaseFedAvg):
initial_model (nn.Module, optional): initial PyTorch model
"""

def __init__(self, *args, stop_cond: str, num_rounds: int, save_filename: str = "FL_global_model.pt", initial_model=None, **kwargs):
def __init__(
self,
*args,
stop_cond: str,
num_rounds: int,
save_filename: str = "FL_global_model.pt",
initial_model=None,
**kwargs,
):
super().__init__(*args, **kwargs)

self.stop_cond = stop_cond
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
import argparse
import os


from src.server import FedAvg
from src.network import SimpleNetwork

from nvflare.job_config.api import FedJob
from src.server import FedAvg

from nvflare.fuel_opt.statsd.statsd_reporter import StatsDReporter
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.metrics.job_metrics_collector import JobMetricsCollector

Expand All @@ -40,7 +38,6 @@ def main(job_configs_dir):
train_script = "src/client.py"
job_name = "fedavg"


job = FedJob(name=job_name, min_clients=num_clients)

controller = FedAvg(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import argparse
import os

from src.server import FedAvg
from src.network import SimpleNetwork

from nvflare.job_config.api import FedJob
from src.server import FedAvg

from nvflare.app_common.widgets.convert_to_fed_event import ConvertToFedEvent
from nvflare.fuel_opt.statsd.statsd_reporter import StatsDReporter
from nvflare.job_config.api import FedJob
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.metrics.job_metrics_collector import JobMetricsCollector
from nvflare.metrics.metrics_keys import METRICS_EVENT_TYPE
Expand Down

0 comments on commit 3e8211e

Please sign in to comment.