|
14 | 14 |
|
15 | 15 | from src.net import Net
|
16 | 16 |
|
17 |
| -from nvflare import FedJob, ScriptExecutor |
18 | 17 | from nvflare.apis.dxo import DataKind
|
19 | 18 | from nvflare.app_common.aggregators.intime_accumulate_model_aggregator import InTimeAccumulateWeightedAggregator
|
20 |
| -from nvflare.app_common.ccwf import ( |
21 |
| - CrossSiteEvalClientController, |
22 |
| - CrossSiteEvalServerController, |
23 |
| - SwarmClientController, |
24 |
| - SwarmServerController, |
25 |
| -) |
| 19 | +from nvflare.app_common.ccwf.ccwf_job import CCWFJob, CrossSiteEvalConfig, SwarmClientConfig, SwarmServerConfig |
26 | 20 | from nvflare.app_common.ccwf.comps.simple_model_shareable_generator import SimpleModelShareableGenerator
|
| 21 | +from nvflare.app_common.executors.script_executor import ScriptExecutor |
27 | 22 | from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
|
28 | 23 |
|
29 | 24 | if __name__ == "__main__":
|
30 | 25 | n_clients = 2
|
31 | 26 | num_rounds = 3
|
32 | 27 | train_script = "src/train_eval_submit.py"
|
33 | 28 |
|
34 |
| - job = FedJob(name="cifar10_swarm") |
35 |
| - |
36 |
| - controller = SwarmServerController( |
37 |
| - num_rounds=num_rounds, |
| 29 | + job = CCWFJob(name="cifar10_swarm") |
| 30 | + aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) |
| 31 | + job.add_swarm( |
| 32 | + server_config=SwarmServerConfig(num_rounds=num_rounds), |
| 33 | + client_config=SwarmClientConfig( |
| 34 | + executor=ScriptExecutor(task_script_path=train_script, evaluate_task_name="validate"), |
| 35 | + aggregator=aggregator, |
| 36 | + persistor=PTFileModelPersistor(model=Net()), |
| 37 | + shareable_generator=SimpleModelShareableGenerator(), |
| 38 | + ), |
| 39 | + cse_config=CrossSiteEvalConfig(eval_task_timeout=300), |
38 | 40 | )
|
39 |
| - job.to(controller, "server") |
40 |
| - controller = CrossSiteEvalServerController(eval_task_timeout=300) |
41 |
| - job.to(controller, "server") |
42 |
| - |
43 |
| - # Define the initial server model |
44 |
| - job.to(Net(), "server") |
45 |
| - |
46 |
| - for i in range(n_clients): |
47 |
| - executor = ScriptExecutor(task_script_path=train_script, evaluate_task_name="validate") |
48 |
| - job.to(executor, f"site-{i}", gpu=0, tasks=["train", "validate", "submit_model"]) |
49 |
| - |
50 |
| - # In swarm learning, each client acts also as an aggregator |
51 |
| - aggregator = InTimeAccumulateWeightedAggregator(expected_data_kind=DataKind.WEIGHTS) |
52 |
| - |
53 |
| - # In swarm learning, each client uses a model persistor and shareable_generator |
54 |
| - persistor = PTFileModelPersistor(model=Net()) |
55 |
| - shareable_generator = SimpleModelShareableGenerator() |
56 |
| - |
57 |
| - persistor_id = job.as_id(persistor) |
58 |
| - client_controller = SwarmClientController( |
59 |
| - aggregator_id=job.as_id(aggregator), |
60 |
| - persistor_id=persistor_id, |
61 |
| - shareable_generator_id=job.as_id(shareable_generator), |
62 |
| - ) |
63 |
| - job.to(client_controller, f"site-{i}", tasks=["swarm_*"]) |
64 |
| - |
65 |
| - client_controller = CrossSiteEvalClientController(persistor_id=persistor_id) |
66 |
| - job.to(client_controller, f"site-{i}", tasks=["cse_*"]) |
67 | 41 |
|
68 | 42 | # job.export_job("/tmp/nvflare/jobs/job_config")
|
69 |
| - job.simulator_run("/tmp/nvflare/jobs/workdir") |
| 43 | + job.simulator_run("/tmp/nvflare/jobs/workdir", n_clients=n_clients, gpu="0") |
0 commit comments