From b26b22474d04f06ce7fd11adebee272c2493e128 Mon Sep 17 00:00:00 2001 From: Ziyue Xu Date: Thu, 30 Jan 2025 16:53:13 -0500 Subject: [PATCH] Update streaming example (#3195) Fixes # . ### Description Convert job templates to job API, add end-to-end example with memory comparison, update Readme ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Quick tests passed locally by running `./runtest.sh`. - [ ] In-line docstrings updated. - [ ] Documentation updated. --- examples/advanced/streaming/README.md | 107 +++++++++------ .../advanced/streaming/container_stream.sh | 3 + examples/advanced/streaming/file_stream.sh | 3 + .../app/config/config_fed_client.json | 23 ---- .../app/config/config_fed_server.json | 20 --- .../dict_streaming/app/custom/__init__.py | 13 -- .../streaming/jobs/dict_streaming/meta.json | 10 -- .../app/config/config_fed_client.json | 23 ---- .../app/config/config_fed_server.json | 19 --- .../file_streaming/app/custom/__init__.py | 13 -- .../streaming/jobs/file_streaming/meta.json | 10 -- .../streaming/regular_transmission.sh | 4 + .../streaming/simple_dict_streaming_job.py | 53 ++++++++ .../streaming/simple_file_streaming_job.py | 54 ++++++++ .../simple_controller.py} | 0 .../trainer.py => src/simple_executor.py} | 3 +- .../simple_streaming_controller.py} | 31 +++-- .../simple_streaming_executor.py} | 4 +- .../standalone_file_streaming.py} | 1 + .../streaming/src/streaming_controller.py | 126 ++++++++++++++++++ .../streaming/src/streaming_executor.py | 110 +++++++++++++++ examples/advanced/streaming/streaming_job.py | 80 +++++++++++ .../advanced/streaming/utils/log_memory.sh | 9 ++ 23 files changed, 531 insertions(+), 188 deletions(-) create mode 100644 examples/advanced/streaming/container_stream.sh create mode 100644 examples/advanced/streaming/file_stream.sh delete mode 100755 examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json delete mode 100755 examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json delete mode 100644 examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py delete mode 100644 examples/advanced/streaming/jobs/dict_streaming/meta.json delete mode 100755 examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json delete mode 100755 examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json delete mode 100644 examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py delete mode 100644 examples/advanced/streaming/jobs/file_streaming/meta.json create mode 100644 examples/advanced/streaming/regular_transmission.sh create mode 100644 examples/advanced/streaming/simple_dict_streaming_job.py create mode 100644 examples/advanced/streaming/simple_file_streaming_job.py rename examples/advanced/streaming/{jobs/file_streaming/app/custom/controller.py => src/simple_controller.py} (100%) rename examples/advanced/streaming/{jobs/file_streaming/app/custom/trainer.py => src/simple_executor.py} (97%) rename examples/advanced/streaming/{jobs/dict_streaming/app/custom/streaming_controller.py => src/simple_streaming_controller.py} (97%) rename examples/advanced/streaming/{jobs/dict_streaming/app/custom/streaming_executor.py => src/simple_streaming_executor.py} (96%) rename 
examples/advanced/streaming/{jobs/file_streaming/app/custom/file_streaming.py => src/standalone_file_streaming.py} (99%) create mode 100644 examples/advanced/streaming/src/streaming_controller.py create mode 100644 examples/advanced/streaming/src/streaming_executor.py create mode 100644 examples/advanced/streaming/streaming_job.py create mode 100644 examples/advanced/streaming/utils/log_memory.sh diff --git a/examples/advanced/streaming/README.md b/examples/advanced/streaming/README.md index 6ab160c7d5..d97c8145bf 100644 --- a/examples/advanced/streaming/README.md +++ b/examples/advanced/streaming/README.md @@ -1,68 +1,99 @@ -# Object Streaming Examples +# Object Streaming ## Overview -The examples here demonstrate how to use object streamers to send large file/objects memory efficiently. +The examples here demonstrate how to use object streamers to send large objects in a memory-efficient manner. -The object streamer uses less memory because it sends files by chunks (default chunk size is 1MB) and -it sends containers entry by entry. +The current default setting is to send and receive large objects in full, so extra memory must be allocated to hold the received message. +This works fine when the message is small, but can become a limitation when the model is large, e.g. for large language models. -For example, if you have a dict with 10 1GB entries, it will take 10GB extra space to send the dict without -streaming. It only requires extra 1GB to serialize the entry using streaming. +To reduce memory usage, we can stream the message instead: when sending a large object (e.g. a dict), +the streamer sends the container entry by entry (e.g. one dict item at a time); further, if we save the object to a file, +the streamer can send the file in chunks (default chunk size is 1MB). +Thus, the memory demand is reduced to the size of the largest entry for container streaming, while nearly no extra memory is needed for file +streaming. For example, sending a dict with 10 1GB entries takes 10GB of extra space without streaming. +With container streaming, it only requires an extra 1GB; and if the dict is saved to a file before sending, only 1MB of extra space is needed. + +All examples are run with the NVFlare Simulator via the [JobAPI](https://nvflare.readthedocs.io/en/main/programming_guide/fed_job_api.html). ## Concepts ### Object Streamer - -ObjectStreamer is a base class to stream an object piece by piece. The `StreamableEngine` built in the NVFlare can +`ObjectStreamer` is the base class to stream an object piece by piece. The `StreamableEngine` built into NVFlare can stream any implementation of `ObjectStreamer`. -Following implementations are included in NVFlare, +The following implementations are included in NVFlare: -* `FileStreamer`: It can be used to stream a file -* `ContainerStreamer`: This class can stream a container entry by entry. Currently, dict, list and set are supported +* `ContainerStreamer`: This class is used to stream a container entry by entry. Currently, dict, list, and set are supported. +* `FileStreamer`: This class is used to stream a file. -The container streamer can only stream the top level entries. All the sub entries of a top entry are sent at once with -the top entry. +Note that the container streamer splits the stream by the top-level entries. All the sub-entries of a top-level entry are expected to be +sent as a whole, so the memory usage is determined by the largest top-level entry.
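To make the memory argument concrete, here is a minimal illustrative sketch in plain Python (these generator helpers are only an analogy, not the NVFlare streamer API): reading a file in fixed-size chunks keeps at most one chunk in memory, and iterating a container entry by entry keeps at most one top-level entry in flight.

```python
# Illustrative sketch only: plain Python generators, not the NVFlare streamer API.
def iter_file_chunks(path, chunk_size=1024 * 1024):
    """Yield a large file one chunk (default 1MB) at a time."""
    with open(path, "rb") as f:
        while chunk := f.read(chunk_size):
            yield chunk  # peak extra memory: roughly one chunk


def iter_container_entries(container: dict):
    """Yield a container one top-level entry at a time."""
    for key, value in container.items():
        yield key, value  # peak extra memory: roughly the largest top-level entry
```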
### Object Retriever - -`ObjectRetriever` is designed to request an object to be streamed from a remote site. It automatically sets up the streaming +Building upon the streamers, `ObjectRetriever` is designed for easier integration with existing code: it requests an object to be streamed from a remote site. It automatically sets up the streaming on both ends and handles the coordination. -Currently, following implementations are available, - -* `FileRetriever`: It's used to retrieve a file from remote site using FileStreamer. -* `ContainerRetriever`: This class can be used to retrieve a container from remote site using ContainerStreamer. +Similarly, the following implementations are available: -To use ContainerRetriever, the container must be given a name and added on the sending site, +* `ContainerRetriever`: This class is used to retrieve a container from a remote site using `ContainerStreamer`. +* `FileRetriever`: This class is used to retrieve a file from a remote site using `FileStreamer`. +Note that to use `ContainerRetriever`, the container must be given a name and added on the sending site: ``` ContainerRetriever.add_container("model", model_dict) ``` -## Example Jobs +## Simple Examples +First, we demonstrate how to use the streamer directly, without a retriever: +```commandline +python simple_file_streaming_job.py +``` +Note that in this example, the file streaming is relatively "standalone": the `FileReceiver` and `FileSender` +are used directly as components, and no training workflow is involved. Since NVFlare requires an executor, +a dummy executor is used. + +Although this file streaming setup is simple, it is not very practical for real-world applications, because +in most cases we need to send an object when it is generated at a certain point in the workflow, rather than standalone. In such cases, +a retriever is more convenient to use: +```commandline +python simple_dict_streaming_job.py +``` +In this second example, the `ContainerRetriever` is set up on both the server and the client, and automatically handles the streaming. +It couples closely with the workflow, making it easier to define what to send and where to retrieve it. + +## Full-scale Examples and Comparisons +The two simple examples above illustrate the basic usage of streaming with small random messages. In the following, +we demonstrate how to use the streamer with a retriever in a workflow with a real large language model object, +and compare the memory usage with and without streaming. To track the memory usage, we use a simple script, `utils/log_memory.sh`. +Note that the tracked usage is not fully accurate, but it is sufficient to give a rough idea. + +All three settings (regular transmission, container streaming, and file streaming) are integrated in the same script to avoid extra variability. +To run the examples: +```commandline +bash regular_transmission.sh +``` +```commandline +bash container_stream.sh +``` +```commandline +bash file_stream.sh +``` -### file_streaming job +We then compare the peak memory usage of the three settings. The results are shown below; +note that these numbers come from one experiment on one machine and can vary considerably depending on the system and the environment.
+ +| Setting | Peak Memory Usage (MB) | Job Finishing Time (s) | +| --- | --- | --- | +| Regular Transmission | 42,427 | 47 +| Container Streaming | 23,265 | 50 +| File Streaming | 19,176 | 170 + +As shown, the memory usage is significantly reduced by using streaming, especially for file streaming, +while file streaming takes much longer time to finish the job. -This job uses the FileStreamer object to send a large file from server to client. -It demonstrates following mechanisms: -1. It uses components to handle the file transferring. No training workflow is used. - Since executor is required by NVFlare, a dummy executor is created. -2. It shows how to use the streamer directly without an object retriever. -The job creates a temporary file to test. You can run the job in POC or using simulator as follows, -``` -nvflare simulator -n 1 -t 1 jobs/file_streaming -``` -### dict_streaming job -This job demonstrate how to send a dict from server to client using object retriever. -It creates a task called "retrieve_dict" to tell client to get ready for the streaming. -The example can be run in simulator like this, -``` -nvflare simulator -n 1 -t 1 jobs/dict_streaming -``` diff --git a/examples/advanced/streaming/container_stream.sh b/examples/advanced/streaming/container_stream.sh new file mode 100644 index 0000000000..695a7ac922 --- /dev/null +++ b/examples/advanced/streaming/container_stream.sh @@ -0,0 +1,3 @@ +pkill -9 python +bash utils/log_memory.sh >>/tmp/nvflare/workspace/container.txt & +python streaming_job.py --retriever_mode container diff --git a/examples/advanced/streaming/file_stream.sh b/examples/advanced/streaming/file_stream.sh new file mode 100644 index 0000000000..b1e9754f14 --- /dev/null +++ b/examples/advanced/streaming/file_stream.sh @@ -0,0 +1,3 @@ +pkill -9 python +bash utils/log_memory.sh >>/tmp/nvflare/workspace/file.txt & +python streaming_job.py --retriever_mode file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json deleted file mode 100755 index c2d85b8b48..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_client.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "format_version": 2, - "cell_wait_timeout": 5.0, - "executors": [ - { - "tasks": ["*"], - "executor": { - "path": "streaming_executor.StreamingExecutor", - "args": { - "dict_retriever_id": "dict_retriever" - } - } - } - ], - "components": [ - { - "id": "dict_retriever", - "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", - "args": { - } - } - ] -} \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json deleted file mode 100755 index fd847e0175..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/config/config_fed_server.json +++ /dev/null @@ -1,20 +0,0 @@ -{ - "format_version": 2, - "components": [ - { - "id": "dict_retriever", - "path": "nvflare.app_common.streamers.container_retriever.ContainerRetriever", - "args": { - } - } - ], - "workflows": [ - { - "id": "controller", - "path": "streaming_controller.StreamingController", - "args": { - "dict_retriever_id": "dict_retriever" - } - } - ] -} \ No newline at end of file diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py 
b/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py deleted file mode 100644 index 341a77c5bc..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/examples/advanced/streaming/jobs/dict_streaming/meta.json b/examples/advanced/streaming/jobs/dict_streaming/meta.json deleted file mode 100644 index 0fcb99272c..0000000000 --- a/examples/advanced/streaming/jobs/dict_streaming/meta.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "file_streaming", - "resource_spec": {}, - "min_clients" : 1, - "deploy_map": { - "app": [ - "@ALL" - ] - } -} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json deleted file mode 100755 index 5ac09cbb4f..0000000000 --- a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_client.json +++ /dev/null @@ -1,23 +0,0 @@ -{ - "format_version": 2, - "executors": [ - { - "tasks": [ - "train" - ], - "executor": { - "path": "trainer.TestTrainer", - "args": {} - } - } - ], - "task_result_filters": [], - "task_data_filters": [], - "components": [ - { - "id": "sender", - "path": "file_streaming.FileSender", - "args": {} - } - ] -} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json b/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json deleted file mode 100755 index 1c0be95c54..0000000000 --- a/examples/advanced/streaming/jobs/file_streaming/app/config/config_fed_server.json +++ /dev/null @@ -1,19 +0,0 @@ -{ - "format_version": 2, - "task_data_filters": [], - "task_result_filters": [], - "components": [ - { - "id": "receiver", - "path": "file_streaming.FileReceiver", - "args": {} - } - ], - "workflows": [ - { - "id": "controller", - "path": "controller.SimpleController", - "args": {} - } - ] -} diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py b/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py deleted file mode 100644 index 341a77c5bc..0000000000 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
diff --git a/examples/advanced/streaming/jobs/file_streaming/meta.json b/examples/advanced/streaming/jobs/file_streaming/meta.json deleted file mode 100644 index 0fcb99272c..0000000000 --- a/examples/advanced/streaming/jobs/file_streaming/meta.json +++ /dev/null @@ -1,10 +0,0 @@ -{ - "name": "file_streaming", - "resource_spec": {}, - "min_clients" : 1, - "deploy_map": { - "app": [ - "@ALL" - ] - } -} diff --git a/examples/advanced/streaming/regular_transmission.sh b/examples/advanced/streaming/regular_transmission.sh new file mode 100644 index 0000000000..f90a9b68a2 --- /dev/null +++ b/examples/advanced/streaming/regular_transmission.sh @@ -0,0 +1,4 @@ +pkill -9 python +mkdir /tmp/nvflare/workspace/ +bash utils/log_memory.sh >>/tmp/nvflare/workspace/regular.txt & +python streaming_job.py diff --git a/examples/advanced/streaming/simple_dict_streaming_job.py b/examples/advanced/streaming/simple_dict_streaming_job.py new file mode 100644 index 0000000000..4e84cfb130 --- /dev/null +++ b/examples/advanced/streaming/simple_dict_streaming_job.py @@ -0,0 +1,53 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.simple_streaming_controller import SimpleStreamingController +from src.simple_streaming_executor import SimpleStreamingExecutor + +from nvflare import FedJob +from nvflare.app_common.streamers.container_retriever import ContainerRetriever + + +def main(): + # Create the FedJob + job = FedJob(name="simple_dict_streaming", min_clients=1) + + # Define dict_retriever component and send to both server and clients + dict_retriever = ContainerRetriever() + job.to_server(dict_retriever, id="dict_retriever") + job.to_clients(dict_retriever, id="dict_retriever") + + # Define the controller workflow and send to server + controller = SimpleStreamingController(dict_retriever_id="dict_retriever") + job.to_server(controller) + + # Define the executor and send to clients + executor = SimpleStreamingExecutor(dict_retriever_id="dict_retriever") + job.to_clients(executor, tasks=["*"]) + + # Export the job + job_dir = "/tmp/nvflare/workspace/jobs/simple_dict_streaming" + print("job_dir=", job_dir) + job.export_job(job_dir) + + # Run the job + work_dir = "/tmp/nvflare/workspace/works/simple_dict_streaming" + print("workspace_dir=", work_dir) + + # starting the monitoring + job.simulator_run(work_dir, n_clients=1, threads=1) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/streaming/simple_file_streaming_job.py b/examples/advanced/streaming/simple_file_streaming_job.py new file mode 100644 index 0000000000..7506fdda53 --- /dev/null +++ b/examples/advanced/streaming/simple_file_streaming_job.py @@ -0,0 +1,54 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from src.simple_controller import SimpleController +from src.simple_executor import SimpleExecutor +from src.standalone_file_streaming import FileReceiver, FileSender + +from nvflare import FedJob + + +def main(): + # Create the FedJob + job = FedJob(name="simple_file_streaming", min_clients=1) + + # Define the controller workflow, and file receiver + # and send to server + controller = SimpleController() + receiver = FileReceiver() + job.to_server(controller) + job.to_server(receiver, id="receiver") + + # Define the executor, and file sender + # and send to clients + executor = SimpleExecutor() + sender = FileSender() + job.to_clients(executor, tasks=["train"]) + job.to_clients(sender, id="sender") + + # Export the job + job_dir = "/tmp/nvflare/workspace/jobs/simple_file_streaming" + print("job_dir=", job_dir) + job.export_job(job_dir) + + # Run the job + work_dir = "/tmp/nvflare/workspace/works/simple_file_streaming" + print("workspace_dir=", work_dir) + + # starting the monitoring + job.simulator_run(work_dir, n_clients=1, threads=1) + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py b/examples/advanced/streaming/src/simple_controller.py similarity index 100% rename from examples/advanced/streaming/jobs/file_streaming/app/custom/controller.py rename to examples/advanced/streaming/src/simple_controller.py diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py b/examples/advanced/streaming/src/simple_executor.py similarity index 97% rename from examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py rename to examples/advanced/streaming/src/simple_executor.py index 216d69f793..3811b4a23c 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/trainer.py +++ b/examples/advanced/streaming/src/simple_executor.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + from nvflare.apis.dxo import DXO, DataKind from nvflare.apis.event_type import EventType from nvflare.apis.executor import Executor @@ -19,7 +20,7 @@ from nvflare.apis.signal import Signal -class TestTrainer(Executor): +class SimpleExecutor(Executor): def __init__(self): super().__init__() self.aborted = False diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py b/examples/advanced/streaming/src/simple_streaming_controller.py similarity index 97% rename from examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py rename to examples/advanced/streaming/src/simple_streaming_controller.py index 1f1700d1a8..9ec402055b 100644 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_controller.py +++ b/examples/advanced/streaming/src/simple_streaming_controller.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+ from random import randbytes from nvflare.apis.controller_spec import Client, ClientTask, Task @@ -21,10 +22,8 @@ from nvflare.apis.signal import Signal from nvflare.app_common.streamers.container_retriever import ContainerRetriever -STREAM_TOPIC = "rtr_file_stream" - -class StreamingController(Controller): +class SimpleStreamingController(Controller): def __init__(self, dict_retriever_id=None, task_timeout=60, task_check_period: float = 0.5): Controller.__init__(self, task_check_period=task_check_period) self.dict_retriever_id = dict_retriever_id @@ -38,6 +37,19 @@ def start_controller(self, fl_ctx: FLContext): def stop_controller(self, fl_ctx: FLContext): pass + def handle_event(self, event_type: str, fl_ctx: FLContext): + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.dict_retriever_id: + c = engine.get_component(self.dict_retriever_id) + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.dict_retriever = c + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): s = Shareable() s["name"] = "model" @@ -66,19 +78,6 @@ def process_result_of_unknown_task( ): pass - def handle_event(self, event_type: str, fl_ctx: FLContext): - if event_type == EventType.START_RUN: - engine = fl_ctx.get_engine() - if self.dict_retriever_id: - c = engine.get_component(self.dict_retriever_id) - if not isinstance(c, ContainerRetriever): - self.system_panic( - f"invalid dict_retriever {self.dict_retriever_id}, wrong type: {type(c)}", - fl_ctx, - ) - return - self.dict_retriever = c - @staticmethod def _get_test_model() -> dict: model = {} diff --git a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py b/examples/advanced/streaming/src/simple_streaming_executor.py similarity index 96% rename from examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py rename to examples/advanced/streaming/src/simple_streaming_executor.py index a238f82aff..c01827c302 100644 --- a/examples/advanced/streaming/jobs/dict_streaming/app/custom/streaming_executor.py +++ b/examples/advanced/streaming/src/simple_streaming_executor.py @@ -21,7 +21,7 @@ from nvflare.app_common.streamers.container_retriever import ContainerRetriever -class StreamingExecutor(Executor): +class SimpleStreamingExecutor(Executor): def __init__(self, dict_retriever_id=None): Executor.__init__(self) self.dict_retriever_id = dict_retriever_id @@ -41,7 +41,7 @@ def handle_event(self, event_type: str, fl_ctx: FLContext): self.dict_retriever = c def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: - self.log_info(fl_ctx, f"got task {task_name}: {shareable}") + self.log_info(fl_ctx, f"got task {task_name}") if task_name == "retrieve_dict": name = shareable.get("name") if not name: diff --git a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py b/examples/advanced/streaming/src/standalone_file_streaming.py similarity index 99% rename from examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py rename to examples/advanced/streaming/src/standalone_file_streaming.py index 9b49b230e5..1c4c44b87b 100644 --- a/examples/advanced/streaming/jobs/file_streaming/app/custom/file_streaming.py +++ b/examples/advanced/streaming/src/standalone_file_streaming.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. + import os import tempfile from threading import Thread diff --git a/examples/advanced/streaming/src/streaming_controller.py b/examples/advanced/streaming/src/streaming_controller.py new file mode 100644 index 0000000000..c3f64c1f4d --- /dev/null +++ b/examples/advanced/streaming/src/streaming_controller.py @@ -0,0 +1,126 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import torch +from transformers import AutoModelForCausalLM + +from nvflare.apis.controller_spec import Client, ClientTask, Task +from nvflare.apis.event_type import EventType +from nvflare.apis.fl_context import FLContext +from nvflare.apis.impl.controller import Controller +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever +from nvflare.app_common.streamers.file_retriever import FileRetriever + + +class StreamingController(Controller): + def __init__(self, retriever_mode=None, retriever_id=None, task_timeout=200, task_check_period: float = 0.5): + Controller.__init__(self, task_check_period=task_check_period) + self.retriever_mode = retriever_mode + self.retriever_id = retriever_id + self.retriever = None + self.task_timeout = task_timeout + + def start_controller(self, fl_ctx: FLContext): + self.file_name, self.model = self._get_test_model() + if self.retriever_mode == "container": + self.retriever.add_container("model", self.model) + + def stop_controller(self, fl_ctx: FLContext): + pass + + def handle_event(self, event_type: str, fl_ctx: FLContext): + # perform initialization and checks + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.retriever_mode: + c = engine.get_component(self.retriever_id) + if self.retriever_mode == "container": + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid container_retriever {self.retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.retriever = c + elif self.retriever_mode == "file": + if not isinstance(c, FileRetriever): + self.system_panic( + f"invalid file_retriever {self.retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.retriever = c + else: + self.system_panic( + f"invalid retriever_mode {self.retriever_mode}", + fl_ctx, + ) + return + + def control_flow(self, abort_signal: Signal, fl_ctx: FLContext): + s = Shareable() + # set shareable payload + if self.retriever_mode == "container": + s["model"] = "model" + elif self.retriever_mode == "file": + s["model"] = self.file_name + else: + s["model"] = self.model + task = Task(name="retrieve_model", data=s, timeout=self.task_timeout) + self.broadcast_and_wait( + task=task, + fl_ctx=fl_ctx, + min_responses=1, + abort_signal=abort_signal, + ) + client_resps = {} + for ct in task.client_tasks: + assert isinstance(ct, ClientTask) + resp = ct.result + if 
resp is None: + resp = "no answer" + else: + assert isinstance(resp, Shareable) + self.log_info(fl_ctx, f"got resp {resp} from client {ct.client.name}") + resp = resp.get_return_code() + client_resps[ct.client.name] = resp + return {"status": "OK", "data": client_resps} + + def process_result_of_unknown_task( + self, client: Client, task_name: str, client_task_id: str, result: Shareable, fl_ctx: FLContext + ): + pass + + @staticmethod + def _get_test_model(): + model_name = "meta-llama/llama-3.2-1b" + # load model to dict + model = AutoModelForCausalLM.from_pretrained( + model_name, + torch_dtype=torch.float32, + device_map="auto", + use_cache=False, + ) + params = model.state_dict() + for key in params: + params[key] = params[key].cpu().numpy() + + # save params dict to a npz file + file_name = "model.npz" + np.savez(file_name, **params) + + return file_name, params diff --git a/examples/advanced/streaming/src/streaming_executor.py b/examples/advanced/streaming/src/streaming_executor.py new file mode 100644 index 0000000000..228db2c226 --- /dev/null +++ b/examples/advanced/streaming/src/streaming_executor.py @@ -0,0 +1,110 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +import numpy as np + +from nvflare.apis.event_type import EventType +from nvflare.apis.executor import Executor +from nvflare.apis.fl_constant import ReturnCode +from nvflare.apis.fl_context import FLContext +from nvflare.apis.shareable import Shareable, make_reply +from nvflare.apis.signal import Signal +from nvflare.app_common.streamers.container_retriever import ContainerRetriever +from nvflare.app_common.streamers.file_retriever import FileRetriever + + +class StreamingExecutor(Executor): + def __init__(self, retriever_mode=None, retriever_id=None, task_timeout=200): + Executor.__init__(self) + self.retriever_mode = retriever_mode + self.retriever_id = retriever_id + self.retriever = None + self.task_timeout = task_timeout + + def handle_event(self, event_type: str, fl_ctx: FLContext): + # perform initialization and checks + if event_type == EventType.START_RUN: + engine = fl_ctx.get_engine() + if self.retriever_mode: + c = engine.get_component(self.retriever_id) + if self.retriever_mode == "container": + if not isinstance(c, ContainerRetriever): + self.system_panic( + f"invalid container_retriever {self.retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.retriever = c + elif self.retriever_mode == "file": + if not isinstance(c, FileRetriever): + self.system_panic( + f"invalid file_retriever {self.retriever_id}, wrong type: {type(c)}", + fl_ctx, + ) + return + self.retriever = c + else: + self.system_panic( + f"invalid retriever_mode {self.retriever_mode}", + fl_ctx, + ) + return + + def execute(self, task_name: str, shareable: Shareable, fl_ctx: FLContext, abort_signal: Signal) -> Shareable: + self.log_info(fl_ctx, f"got task {task_name}") + if task_name == "retrieve_model": + model = shareable.get("model") + if not model: 
+ self.log_error(fl_ctx, "missing model info in request") + return make_reply(ReturnCode.BAD_TASK_DATA) + + if self.retriever_mode is None: + self.log_info(fl_ctx, f"received container type: {type(model)} size: {len(model)}") + return make_reply(ReturnCode.OK) + elif self.retriever_mode == "container": + rc, result = self.retriever.retrieve_container( + from_site="server", + fl_ctx=fl_ctx, + timeout=self.task_timeout, + name=model, + ) + if rc != ReturnCode.OK: + self.log_error(fl_ctx, f"failed to retrieve {model}: {rc}") + return make_reply(rc) + self.log_info(fl_ctx, f"received container type: {type(result)} size: {len(result)}") + return make_reply(ReturnCode.OK) + elif self.retriever_mode == "file": + rc, result = self.retriever.retrieve_file( + from_site="server", + fl_ctx=fl_ctx, + timeout=self.task_timeout, + file_name=model, + ) + if rc != ReturnCode.OK: + self.log_error(fl_ctx, f"failed to retrieve file {model}: {rc}") + return make_reply(rc) + # rename the received file to its original name + rename_path = os.path.join(os.path.dirname(result), model) + os.rename(result, rename_path) + self.log_info(fl_ctx, f"received file: {result}, renamed to: {rename_path}") + # Load local model + result = dict(np.load(rename_path)) + self.log_info(fl_ctx, f"loaded file content type: {type(result)} size: {len(result)}") + + return make_reply(ReturnCode.OK) + else: + self.log_error(fl_ctx, f"got unknown task {task_name}") + return make_reply(ReturnCode.TASK_UNKNOWN) diff --git a/examples/advanced/streaming/streaming_job.py b/examples/advanced/streaming/streaming_job.py new file mode 100644 index 0000000000..9569462157 --- /dev/null +++ b/examples/advanced/streaming/streaming_job.py @@ -0,0 +1,80 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import argparse + +from src.streaming_controller import StreamingController +from src.streaming_executor import StreamingExecutor + +from nvflare import FedJob +from nvflare.app_common.streamers.container_retriever import ContainerRetriever +from nvflare.app_common.streamers.file_retriever import FileRetriever + + +def main(): + args = define_parser() + retriever_mode = args.retriever_mode + + # Create the FedJob + job = FedJob(name="streaming", min_clients=1) + + if retriever_mode: + if retriever_mode == "file": + retriever = FileRetriever(source_dir="./", dest_dir="./") + job_dir = "/tmp/nvflare/workspace/jobs/file_streaming" + work_dir = "/tmp/nvflare/workspace/works/file_streaming" + elif retriever_mode == "container": + retriever = ContainerRetriever() + job_dir = "/tmp/nvflare/workspace/jobs/container_streaming" + work_dir = "/tmp/nvflare/workspace/works/container_streaming" + else: + raise ValueError(f"invalid retriever_mode {retriever_mode}") + job.to_server(retriever, id="retriever") + job.to_clients(retriever, id="retriever") + + controller = StreamingController(retriever_mode=retriever_mode, retriever_id="retriever") + job.to_server(controller) + + executor = StreamingExecutor(retriever_mode=retriever_mode, retriever_id="retriever") + job.to_clients(executor, tasks=["*"]) + else: + job_dir = "/tmp/nvflare/workspace/jobs/regular_streaming" + work_dir = "/tmp/nvflare/workspace/works/regular_streaming" + controller = StreamingController() + job.to_server(controller) + executor = StreamingExecutor() + job.to_clients(executor, tasks=["*"]) + + # Export the job + print("job_dir=", job_dir) + job.export_job(job_dir) + + # Run the job + print("workspace_dir=", work_dir) + job.simulator_run(work_dir, n_clients=1, threads=1) + + +def define_parser(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--retriever_mode", + type=str, + default=None, + help="Retriever mode, default is None, can be 'container' or 'file'", + ) + return parser.parse_args() + + +if __name__ == "__main__": + main() diff --git a/examples/advanced/streaming/utils/log_memory.sh b/examples/advanced/streaming/utils/log_memory.sh new file mode 100644 index 0000000000..33f5af5944 --- /dev/null +++ b/examples/advanced/streaming/utils/log_memory.sh @@ -0,0 +1,9 @@ +#!/bin/bash -e + +echo " date time $(free -m | grep total | sed -E 's/^ (.*)/\1/g')" +counter=1 +while [ $counter -le 400 ]; do + echo "$(date '+%Y-%m-%d %H:%M:%S') $(free -m | grep Mem: | sed 's/Mem://g')" + sleep 0.5 + ((counter++)) +done
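For reference, the peak numbers in the comparison table can be read off the logs written by `utils/log_memory.sh`. Below is a small, hypothetical helper (not part of this PR; the log format is assumed from the script above) that scans a log such as `/tmp/nvflare/workspace/regular.txt` for the maximum `used` value reported by `free -m`.

```python
# Hypothetical helper, not included in this PR: report peak "used" memory (MB)
# from a log written by utils/log_memory.sh. Each data line looks like:
#   YYYY-MM-DD HH:MM:SS  <total> <used> <free> <shared> <buff/cache> <available>
import sys


def peak_used_mb(log_path: str) -> int:
    peak = 0
    with open(log_path) as f:
        for line in f:
            fields = line.split()
            try:
                used = int(fields[3])  # columns: date, time, total, used, ...
            except (IndexError, ValueError):
                continue  # skip the header line(s)
            peak = max(peak, used)
    return peak


if __name__ == "__main__":
    print(f"peak used memory: {peak_used_mb(sys.argv[1])} MB")
```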