From 3c0d43b2e556bfb6cd8dd85c6b77369100a48dd3 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Sun, 26 Jan 2025 16:28:11 -0500 Subject: [PATCH 01/11] dev --- nvflare/fuel/f3/cellnet/connector_manager.py | 6 + nvflare/fuel/f3/communicator.py | 14 +- nvflare/fuel/f3/drivers/aio_grpc_driver.py | 3 +- nvflare/fuel/f3/drivers/aio_http_driver.py | 3 +- nvflare/fuel/f3/drivers/grpc_driver.py | 3 +- nvflare/fuel/f3/drivers/tcp_driver.py | 3 +- .../fed/app/deployer/base_client_deployer.py | 5 +- nvflare/private/fed/app/fl_conf.py | 3 + nvflare/private/fed/app/relay/relay.py | 133 ++++++++++++++++++ nvflare/private/fed/client/fed_client_base.py | 47 +++++-- nvflare/private/fed/server/fed_server.py | 5 + 11 files changed, 206 insertions(+), 19 deletions(-) create mode 100644 nvflare/private/fed/app/relay/relay.py diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 5c0e65a48f..99bcf07912 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -20,6 +20,7 @@ from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import CommError, Communicator, Mode +from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception, secure_format_traceback @@ -85,6 +86,11 @@ def __init__(self, communicator: Communicator, secure: bool, comm_configurator: self.adhoc_scheme = adhoc_conf.get(_KEY_SCHEME) self.adhoc_resources = adhoc_conf.get(_KEY_RESOURCES) + # default conn sec + conn_sec = self.int_resources.get(DriverParams.CONNECTION_SECURITY) + if not conn_sec: + self.int_resources[DriverParams.CONNECTION_SECURITY] = ConnectionSecurity.INSECURE + self.logger.debug(f"internal scheme={self.int_scheme}, resources={self.int_resources}") self.logger.debug(f"adhoc scheme={self.adhoc_scheme}, resources={self.adhoc_resources}") self.comm_config = comm_config diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index 02714fd84e..e59de1b58f 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import atexit +import copy import logging import os import weakref @@ -201,7 +202,9 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}") connect_url, listening_url = driver_class.get_urls(scheme, resources) - params = parse_url(listening_url) + extra_params = parse_url(listening_url) + params = copy.copy(resources) + params.update(extra_params) handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True) @@ -225,11 +228,18 @@ def add_connector_advanced( Raises: CommError: If any errors """ - + original_conn_sec = params.get(DriverParams.CONNECTION_SECURITY) if self.local_endpoint.conn_props: params.update(self.local_endpoint.conn_props) + if original_conn_sec: + # we do not allow the connection sec to be overwritten by the endpoint's conn_props + params[DriverParams.CONNECTION_SECURITY] = original_conn_sec + params[DriverParams.SECURE] = secure + + log.info(f"$$$ add_connector_advanced: {params=} {mode=} {original_conn_sec=}") + handle = self.conn_manager.add_connector(driver, params, mode) if not start: diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 972e8fc05f..ac82ec4dec 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -22,6 +22,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers.aio_context import AioContext @@ -409,7 +410,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure(resources) if secure: if use_aio_grpc(): scheme = "grpcs" diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index 61c953a867..cfcd7e2121 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -18,6 +18,7 @@ import websockets from websockets.exceptions import ConnectionClosedOK +from nvflare.fuel.f3.comm_config_utils import requires_secure from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers import net_utils @@ -120,7 +121,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure(resources) if secure: scheme = "https" diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 2cd2d0bb1b..f61302def1 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -20,6 +20,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver import ConnectorInfo @@ -274,7 +275,7 @@ def connect(self, connector: ConnectorInfo): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure(resources) if secure: if use_aio_grpc(): scheme = "nagrpcs" diff --git a/nvflare/fuel/f3/drivers/tcp_driver.py b/nvflare/fuel/f3/drivers/tcp_driver.py index f7aff1a75d..6ee11a4f4f 100644 --- a/nvflare/fuel/f3/drivers/tcp_driver.py +++ b/nvflare/fuel/f3/drivers/tcp_driver.py @@ -17,6 +17,7 @@ from socketserver import TCPServer, ThreadingTCPServer from typing import Any, Dict, List +from nvflare.fuel.f3.comm_config_utils import requires_secure from nvflare.fuel.f3.drivers.base_driver import BaseDriver from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams @@ -100,7 +101,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure(resources) if secure: scheme = "stcp" diff --git a/nvflare/private/fed/app/deployer/base_client_deployer.py b/nvflare/private/fed/app/deployer/base_client_deployer.py index b4c684c093..353d63ce20 100644 --- a/nvflare/private/fed/app/deployer/base_client_deployer.py +++ b/nvflare/private/fed/app/deployer/base_client_deployer.py @@ -45,8 +45,9 @@ def build(self, build_ctx): self.components = build_ctx["client_components"] self.handlers = build_ctx["client_handlers"] - def set_model_manager(self, model_manager): - self.model_manager = model_manager + relay_config = build_ctx.get("relay_config") + if relay_config: + self.client_config["relay_config"] = relay_config def create_fed_client(self, args, sp_target=None): if sp_target: diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 7a60876637..9cba54ef14 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -225,6 +225,8 @@ def __init__(self, workspace: Workspace, args, kv_list=None): config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False) + print(f"got all config files: {config_files}") + JsonConfigurator.__init__( self, config_file_name=config_files, @@ -324,6 +326,7 @@ def finalize_config(self, config_ctx: ConfigContext): "overseer_agent": self.overseer_agent, "client_components": self.components, "client_handlers": self.handlers, + "relay_config": self.config_data.get("relay"), } custom_validators = [self.app_validator] if self.app_validator else [] diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py new file mode 100644 index 0000000000..1a40c508c7 --- /dev/null +++ b/nvflare/private/fed/app/relay/relay.py @@ -0,0 +1,133 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import os +import sys +import threading + +from nvflare.apis.fl_constant import SecureTrainConst, WorkspaceConstants +from nvflare.apis.workspace import Workspace +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.cellnet.net_agent import NetAgent +from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.net_utils import SSL_ROOT_CERT +from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm +from nvflare.fuel.utils.config_service import ConfigService, search_file +from nvflare.fuel.utils.log_utils import configure_logging +from nvflare.fuel.utils.url_utils import make_url + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) + parser.add_argument("--relay_config", "-s", type=str, help="relay config json file", required=True) + parser.add_argument("--set", metavar="KEY=VALUE", nargs="*") + args = parser.parse_args() + return args + + +def main(args): + workspace = Workspace(root_dir=args.workspace) + for name in [WorkspaceConstants.RESTART_FILE, WorkspaceConstants.SHUTDOWN_FILE]: + try: + f = workspace.get_file_path_in_root(name) + if os.path.exists(f): + os.remove(f) + except Exception as ex: + print(f"Could not remove file '{name}': {ex}. Please check your system before starting FL.") + sys.exit(-1) + + relay_config_file = workspace.get_file_path_in_startup(args.relay_config) + with open(relay_config_file, "rt") as f: + relay_config = json.load(f) + + if not isinstance(relay_config, dict): + raise RuntimeError(f"invalid relay config file {args.relay_config}") + + my_identity = relay_config.get("identity") + if not my_identity: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing identity") + + parent = relay_config.get("parent") + if not parent: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent") + + parent_address = parent.get("address") + if not parent_address: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.address") + + parent_scheme = parent.get("scheme") + if not parent_scheme: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.scheme") + + parent_fqcn = parent.get("fqcn") + if not parent_fqcn: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.fqcn") + + configure_logging(workspace, workspace.get_root_dir()) + + stop_event = threading.Event() + + ConfigService.initialize( + section_files={}, + config_path=[args.workspace], + ) + + root_cert_path = search_file(SSL_ROOT_CERT, args.workspace) + if not root_cert_path: + raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {args.workspace}") + + credentials = { + DriverParams.CA_CERT.value: root_cert_path, + } + + conn_security = parent.get(SecureTrainConst.CONNECTION_SECURITY) + secure = True + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security + if conn_security == ConnectionSecurity.INSECURE: + secure = False + parent_url = make_url(parent_scheme, parent_address, secure) + + if parent_fqcn == FQCN.ROOT_SERVER: + my_fqcn = my_identity + root_url = parent_url + parent_url = None + else: + my_fqcn = FQCN.join([parent_fqcn, my_identity]) + root_url = None + + cell = Cell( + fqcn=my_fqcn, + root_url=root_url, + secure=secure, + credentials=credentials, + create_internal_listener=True, + parent_url=parent_url, + ) + net_agent = NetAgent(cell) + cell.start() + + # wait until stopped + print(f"started relay {my_identity=} {my_fqcn=}") + stop_event.wait() + + +if __name__ == "__main__": + args = parse_arguments() + rc = mpm.run(main_func=main, run_dir=args.workspace, args=args) + sys.exit(rc) diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 6a40717457..d7c98da5b1 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -32,6 +32,7 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.fuel.utils.url_utils import make_url from nvflare.security.logging import secure_format_exception from .client_status import ClientStatus @@ -188,13 +189,35 @@ def _create_cell(self, location, scheme): """ # Determine the CP's fqcn root_url = scheme + "://" + location + relay_fqcn = None + relay_url = None + relay_conn_security = None + conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) + + # relay info is set in the client's local/resources.json. + # If relay is used, then connect via the specified relay; if not, try to connect the Server directly + self.logger.info(f"Config data: {self.client_args=}") + self.logger.info(f"Args: {self.args=}") + relay_config = self.client_args.get("relay_config") + self.logger.info(f"got relay config: {relay_config}") + if relay_config: + if relay_config: + relay_fqcn = relay_config.get("fqcn") + scheme = relay_config.get("scheme") + addr = relay_config.get("address") + relay_conn_security = relay_config.get(SecureTrainConst.CONNECTION_SECURITY) + secure = True + if relay_conn_security == "insecure": + secure = False + relay_url = make_url(scheme, addr, secure) + self.logger.info(f"connect to server via relay: {relay_url=} {relay_fqcn=}") + else: + self.logger.info("no relay defined: connect to server directly") + else: + self.logger.info("no comm_config: connect to server directly") - # bridge_fqcn and bridge_url are set in the client's local/resources.json. - # If they are set, then connect via the specified bridge; if not, try to connect the Server directly - bridge_fqcn = self.client_args.get("bridge_fqcn") - bridge_url = self.client_args.get("bridge_url") - if bridge_fqcn: - cp_fqcn = FQCN.join([bridge_fqcn, self.client_name]) + if relay_fqcn: + cp_fqcn = FQCN.join([relay_fqcn, self.client_name]) root_url = None # do not connect to server if bridge is used else: cp_fqcn = self.client_name @@ -209,8 +232,10 @@ def _create_cell(self, location, scheme): # I am CP me = "CP" my_fqcn = cp_fqcn - parent_url = bridge_url + parent_url = relay_url create_internal_listener = True + if relay_conn_security: + conn_security = relay_conn_security if self.secure_train: root_cert = self.client_args[SecureTrainConst.SSL_ROOT_CERT] @@ -222,13 +247,13 @@ def _create_cell(self, location, scheme): DriverParams.CLIENT_CERT.value: ssl_cert, DriverParams.CLIENT_KEY.value: private_key, } - conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) - if conn_security: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security - set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) else: credentials = {} + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security + set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) + self.logger.info(f"{me=}: {my_fqcn=} {root_url=} {parent_url=}") self.cell = Cell( fqcn=my_fqcn, diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 354df3fda3..feaa73a161 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -426,6 +426,11 @@ def _validate_auth_headers(self, message: Message): self.logger.debug(f"skip special message {topic=} {channel=}") return None + if channel in ["_net_manager"]: + # skip internal net query messages for now to support relays + # TBD: need to add relay authentication + return None + client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" if not client_name: From d02bd9b585085126d42bbae74a8bdbf5007ffaf9 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Mon, 27 Jan 2025 14:52:01 -0500 Subject: [PATCH 02/11] relay dev --- nvflare/apis/fl_constant.py | 21 +++++- .../executors/client_api_launcher_executor.py | 6 +- nvflare/client/config.py | 4 +- nvflare/client/ex_process/api.py | 4 +- nvflare/fuel/f3/cellnet/connector_manager.py | 3 +- nvflare/fuel/f3/cellnet/core_cell.py | 3 +- nvflare/fuel/f3/drivers/driver_params.py | 6 -- nvflare/fuel/f3/drivers/grpc/utils.py | 3 +- nvflare/fuel/f3/drivers/net_utils.py | 3 +- nvflare/fuel/utils/pipe/cell_pipe.py | 4 +- nvflare/private/fed/app/fl_conf.py | 65 ++++++++++++++++++- nvflare/private/fed/app/relay/relay.py | 16 ++--- nvflare/private/fed/client/fed_client_base.py | 46 ++++--------- nvflare/private/fed/server/fed_server.py | 9 ++- .../private/fed/simulator/simulator_server.py | 1 - nvflare/private/json_configer.py | 2 + 16 files changed, 128 insertions(+), 68 deletions(-) diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index b45cb58582..828b6f80bd 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -443,7 +443,6 @@ class SecureTrainConst: SSL_ROOT_CERT = "ssl_root_cert" SSL_CERT = "ssl_cert" PRIVATE_KEY = "ssl_private_key" - CONNECTION_SECURITY = "connection_security" class FLMetaKey: @@ -552,3 +551,23 @@ class RunnerTask: INIT = "init" TASK_EXEC = "task_exec" END_RUN = "end_run" + + +class ConnPropKey: + + IDENTITY = "identity" + PARENT = "parent" + FQCN = "fqcn" + URL = "url" + SCHEME = "scheme" + ADDRESS = "address" + CONNECTION_SECURITY = "connection_security" + CP_FQCN = "cp_fqcn" + CONNECTION_PROPERTIES = "connection_properties" + RELAY_CONFIG = "relay_config" + + +class ConnectionSecurity: + INSECURE = "insecure" + TLS = "tls" + MTLS = "mtls" diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 3b470edf2c..ee6e257f02 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,7 +15,7 @@ import os from typing import Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.fl_context import FLContext from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.launcher_executor import LauncherExecutor @@ -138,9 +138,9 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): FLMetaKey.AUTH_TOKEN_SIGNATURE: signature, } - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) + conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) if conn_sec: - config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec + config_data[ConnPropKey.CONNECTION_SECURITY] = conn_sec config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=config_data, config_file_path=config_file_path) diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 477b132dc3..4346242712 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -16,7 +16,7 @@ import os from typing import Dict, Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.fuel.utils.config_factory import ConfigFactory @@ -157,7 +157,7 @@ def get_heartbeat_timeout(self): ) def get_connection_security(self): - return self.config.get(SecureTrainConst.CONNECTION_SECURITY) + return self.config.get(ConnPropKey.CONNECTION_SECURITY) def get_site_name(self): return self.config.get(FLMetaKey.SITE_NAME) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index bb94b70939..cc68c55ff1 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple from nvflare.apis.analytix import AnalyticsDataType -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.fl_model import FLModel from nvflare.client.api_spec import APISpec @@ -41,7 +41,7 @@ def _create_client_config(config: str) -> ClientConfig: site_name = client_config.get_site_name() conn_sec = client_config.get_connection_security() if conn_sec: - set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec) + set_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) # get message auth info and put them into Databus for CellPipe to use auth_token = client_config.get_auth_token() diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 99bcf07912..05d7517afa 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -15,12 +15,13 @@ import time from typing import Union +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.cellnet.defs import ConnectorRequirementKey from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import CommError, Communicator, Mode -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception, secure_format_traceback diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 4e8b68fa8b..738d9aa5ea 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -22,6 +22,7 @@ from typing import Dict, List, Tuple, Union from urllib.parse import urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager from nvflare.fuel.f3.cellnet.defs import ( @@ -43,7 +44,7 @@ from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState from nvflare.fuel.f3.message import Message diff --git a/nvflare/fuel/f3/drivers/driver_params.py b/nvflare/fuel/f3/drivers/driver_params.py index 54118855e3..c03c89618d 100644 --- a/nvflare/fuel/f3/drivers/driver_params.py +++ b/nvflare/fuel/f3/drivers/driver_params.py @@ -44,12 +44,6 @@ class DriverParams(str, Enum): IMPLEMENTED_CONN_SEC = "implemented_conn_sec" -class ConnectionSecurity: - INSECURE = "insecure" - TLS = "tls" - MTLS = "mtls" - - class DriverCap(str, Enum): SEND_HEARTBEAT = "send_heartbeat" diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py index bbad1a522f..7c1f0de896 100644 --- a/nvflare/fuel/f3/drivers/grpc/utils.py +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -13,8 +13,9 @@ # limitations under the License. import grpc +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams def use_aio_grpc(): diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 8aa4e1f60f..a59ef713ec 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -20,8 +20,9 @@ from typing import Any, Optional from urllib.parse import parse_qsl, urlencode, urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_error import CommError -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.argument_utils import str2bool from nvflare.security.logging import secure_format_exception diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 1554a0dd44..3150bd9386 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -17,7 +17,7 @@ import time from typing import Tuple, Union -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey, SystemVarName from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.cell import Message as CellMessage @@ -150,7 +150,7 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di DriverParams.CA_CERT.value: root_cert_path, } - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) + conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) if conn_sec: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 9cba54ef14..535692ff38 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -19,11 +19,14 @@ import sys from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FilterKey, SiteType, SystemConfigs +from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, FilterKey, SiteType, SystemConfigs from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node +from nvflare.fuel.utils.url_utils import make_url from nvflare.fuel.utils.wfconf import ConfigContext, ConfigError from nvflare.private.defs import SSLConstants from nvflare.private.json_configer import JsonConfigurator @@ -285,6 +288,48 @@ def build_component(self, config_dict): self.handlers.append(t) return t + def _determine_conn_props(self, client_name, config_data: dict): + relay_fqcn = None + relay_url = None + relay_conn_security = None + + # relay info is set in the client's relay__resources.json. + # If relay is used, then connect via the specified relay; if not, try to connect the Server directly + print(f"Config data: {config_data=}") + print(f"Args: {self.args=}") + relay_config = config_data.get(ConnPropKey.RELAY_CONFIG) + self.logger.info(f"got relay config: {relay_config}") + if relay_config: + if relay_config: + relay_fqcn = relay_config.get(ConnPropKey.FQCN) + scheme = relay_config.get(ConnPropKey.SCHEME) + addr = relay_config.get(ConnPropKey.ADDRESS) + relay_conn_security = relay_config.get(ConnPropKey.CONNECTION_SECURITY) + secure = True + if relay_conn_security == ConnectionSecurity.INSECURE: + secure = False + relay_url = make_url(scheme, addr, secure) + print(f"connect to server via relay: {relay_url=} {relay_fqcn=}") + else: + print("no relay defined: connect to server directly") + else: + print("no relay_config: connect to server directly") + + if relay_fqcn: + cp_fqcn = FQCN.join([relay_fqcn, client_name]) + else: + cp_fqcn = client_name + + result = { + ConnPropKey.CP_FQCN: cp_fqcn, + } + + if relay_fqcn: + result[ConnPropKey.FQCN] = relay_fqcn + result[ConnPropKey.URL] = relay_url + result[ConnPropKey.CONNECTION_SECURITY] = relay_conn_security + return result + def start_config(self, config_ctx: ConfigContext): """Start the config process. @@ -303,6 +348,23 @@ def start_config(self, config_ctx: ConfigContext): client[SSLConstants.CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.CERT]) if client.get(SSLConstants.ROOT_CERT): client[SSLConstants.ROOT_CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.ROOT_CERT]) + + client_name = self.cmd_vars.get("uid", None) + if not client_name: + raise ConfigError("missing 'uid' from command args") + + relay_config = self.config_data.get(ConnPropKey.RELAY_CONFIG) + print(f"got relay config: {relay_config}") + if relay_config: + set_scope_property(client_name, ConnPropKey.RELAY_CONFIG, relay_config) + + conn_sec = client.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + set_scope_property(client_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) + + conn_props = self._determine_conn_props(client_name, self.config_data) + set_scope_property(client_name, ConnPropKey.CONNECTION_PROPERTIES, conn_props) + except Exception: raise ValueError(f"Client config error: '{self.client_config_file_names}'") @@ -326,7 +388,6 @@ def finalize_config(self, config_ctx: ConfigContext): "overseer_agent": self.overseer_agent, "client_components": self.components, "client_handlers": self.handlers, - "relay_config": self.config_data.get("relay"), } custom_validators = [self.app_validator] if self.app_validator else [] diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py index 1a40c508c7..ec390ef06a 100644 --- a/nvflare/private/fed/app/relay/relay.py +++ b/nvflare/private/fed/app/relay/relay.py @@ -18,12 +18,12 @@ import sys import threading -from nvflare.apis.fl_constant import SecureTrainConst, WorkspaceConstants +from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, WorkspaceConstants from nvflare.apis.workspace import Workspace from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.drivers.net_utils import SSL_ROOT_CERT from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.config_service import ConfigService, search_file @@ -58,23 +58,23 @@ def main(args): if not isinstance(relay_config, dict): raise RuntimeError(f"invalid relay config file {args.relay_config}") - my_identity = relay_config.get("identity") + my_identity = relay_config.get(ConnPropKey.IDENTITY) if not my_identity: raise RuntimeError(f"invalid relay config file {args.relay_config}: missing identity") - parent = relay_config.get("parent") + parent = relay_config.get(ConnPropKey.PARENT) if not parent: raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent") - parent_address = parent.get("address") + parent_address = parent.get(ConnPropKey.ADDRESS) if not parent_address: raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.address") - parent_scheme = parent.get("scheme") + parent_scheme = parent.get(ConnPropKey.SCHEME) if not parent_scheme: raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.scheme") - parent_fqcn = parent.get("fqcn") + parent_fqcn = parent.get(ConnPropKey.FQCN) if not parent_fqcn: raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.fqcn") @@ -95,7 +95,7 @@ def main(args): DriverParams.CA_CERT.value: root_cert_path, } - conn_security = parent.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = parent.get(ConnPropKey.CONNECTION_SECURITY) secure = True if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index d7c98da5b1..49d1571505 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -18,13 +18,13 @@ from nvflare.apis.filter import Filter from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, SecureTrainConst, ServerCommandKey +from nvflare.apis.fl_constant import ConnPropKey, FLContextKey, SecureTrainConst, ServerCommandKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.overseer_spec import SP from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal -from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent @@ -32,7 +32,6 @@ from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.fuel.utils.url_utils import make_url from nvflare.security.logging import secure_format_exception from .client_status import ClientStatus @@ -189,39 +188,16 @@ def _create_cell(self, location, scheme): """ # Determine the CP's fqcn root_url = scheme + "://" + location - relay_fqcn = None - relay_url = None - relay_conn_security = None - conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) - - # relay info is set in the client's local/resources.json. - # If relay is used, then connect via the specified relay; if not, try to connect the Server directly - self.logger.info(f"Config data: {self.client_args=}") - self.logger.info(f"Args: {self.args=}") - relay_config = self.client_args.get("relay_config") - self.logger.info(f"got relay config: {relay_config}") - if relay_config: - if relay_config: - relay_fqcn = relay_config.get("fqcn") - scheme = relay_config.get("scheme") - addr = relay_config.get("address") - relay_conn_security = relay_config.get(SecureTrainConst.CONNECTION_SECURITY) - secure = True - if relay_conn_security == "insecure": - secure = False - relay_url = make_url(scheme, addr, secure) - self.logger.info(f"connect to server via relay: {relay_url=} {relay_fqcn=}") - else: - self.logger.info("no relay defined: connect to server directly") - else: - self.logger.info("no comm_config: connect to server directly") + conn_security = self.client_args.get(ConnPropKey.CONNECTION_SECURITY) + + conn_props = get_scope_property(self.client_name, ConnPropKey.CONNECTION_PROPERTIES) + self.logger.info(f"got connection_properties: {conn_props}") + relay_fqcn = conn_props.get(ConnPropKey.FQCN) if relay_fqcn: - cp_fqcn = FQCN.join([relay_fqcn, self.client_name]) - root_url = None # do not connect to server if bridge is used - else: - cp_fqcn = self.client_name + root_url = None # do not connect to server if relay is used + cp_fqcn = conn_props.get(ConnPropKey.CP_FQCN) if self.args.job_id: # I am CJ me = "CJ" @@ -232,8 +208,9 @@ def _create_cell(self, location, scheme): # I am CP me = "CP" my_fqcn = cp_fqcn - parent_url = relay_url + parent_url = conn_props.get(ConnPropKey.URL) create_internal_listener = True + relay_conn_security = conn_props.get(ConnPropKey.CONNECTION_SECURITY) if relay_conn_security: conn_security = relay_conn_security @@ -252,7 +229,6 @@ def _create_cell(self, location, scheme): if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security - set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) self.logger.info(f"{me=}: {my_fqcn=} {root_url=} {parent_url=}") self.cell = Cell( diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index feaa73a161..a6c13b8c87 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -25,6 +25,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ( ConfigVarName, + ConnPropKey, FLContextKey, MachineStatus, RunProcessKey, @@ -169,7 +170,7 @@ def deploy(self, args, grpc_args=None, secure_train=False): DriverParams.SERVER_KEY.value: private_key, } - conn_security = grpc_args.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = grpc_args.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: @@ -431,6 +432,10 @@ def _validate_auth_headers(self, message: Message): # TBD: need to add relay authentication return None + if channel == "cellnet.channel" and topic == "bye": + # skip cellnet goodbye + return None + client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" if not client_name: @@ -546,7 +551,7 @@ def create_job_cell(self, job_id, root_url, parent_url, secure_train, server_con DriverParams.SERVER_KEY.value: private_key, } - conn_security = server_config.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = server_config.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: diff --git a/nvflare/private/fed/simulator/simulator_server.py b/nvflare/private/fed/simulator/simulator_server.py index 881eed6039..13b273fa1b 100644 --- a/nvflare/private/fed/simulator/simulator_server.py +++ b/nvflare/private/fed/simulator/simulator_server.py @@ -79,7 +79,6 @@ def create_job_processing_context_properties(self, workspace, job_id): class SimulatorIdentityAsserter(IdentityAsserter): - def __init__(self, private_key_file: str, cert_file: str): self.private_key_file = private_key_file self.cert_file = cert_file diff --git a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index 3a987fde69..82f0cbd3f4 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -21,6 +21,7 @@ from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.dict_utils import augment from nvflare.fuel.utils.json_scanner import JsonObjectProcessor, JsonScanner, Node +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.wfconf import resolve_var_refs from nvflare.security.logging import secure_format_exception @@ -53,6 +54,7 @@ def __init__( sys_vars: system vars """ JsonObjectProcessor.__init__(self) + self.logger = get_obj_logger(self) if not isinstance(num_passes, int): raise TypeError(f"num_passes must be int but got {num_passes}") From e750a8865273d1fa3ed6127d3cce6cbe2afbb202 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Tue, 28 Jan 2025 16:53:04 -0500 Subject: [PATCH 03/11] dev --- nvflare/apis/fl_constant.py | 6 ++- nvflare/apis/job_launcher_spec.py | 1 + .../widgets/external_configurator.py | 16 ++++++- nvflare/app_opt/job_launcher/k8s_launcher.py | 6 ++- nvflare/fuel/f3/cellnet/connector_manager.py | 15 ++++-- nvflare/fuel/f3/cellnet/core_cell.py | 16 +++++-- nvflare/fuel/f3/communicator.py | 5 +- nvflare/fuel/f3/drivers/net_utils.py | 7 +++ nvflare/fuel/utils/config_service.py | 9 +++- nvflare/fuel/utils/url_utils.py | 47 ++++++++++++++++++ .../private/fed/app/client/worker_process.py | 7 +++ nvflare/private/fed/app/fl_conf.py | 48 +++++++++++++------ nvflare/private/fed/client/client_executor.py | 10 +++- nvflare/private/fed/client/fed_client_base.py | 33 ++++++++----- nvflare/utils/job_launcher_utils.py | 6 ++- 15 files changed, 188 insertions(+), 44 deletions(-) create mode 100644 nvflare/fuel/utils/url_utils.py diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index 828b6f80bd..ac42c12cc5 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -562,9 +562,11 @@ class ConnPropKey: SCHEME = "scheme" ADDRESS = "address" CONNECTION_SECURITY = "connection_security" - CP_FQCN = "cp_fqcn" - CONNECTION_PROPERTIES = "connection_properties" + RELAY_CONFIG = "relay_config" + CP_CONN_PROPS = "cp_conn_props" + RELAY_CONN_PROPS = "relay_conn_props" + ROOT_CONN_PROPS = "root_conn_props" class ConnectionSecurity: diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index 36edc62380..cb75130e6a 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -32,6 +32,7 @@ class JobProcessArgs: CLIENT_NAME = "client_name" ROOT_URL = "root_url" PARENT_URL = "parent_url" + PARENT_CONN_SEC = "parent_conn_sec" SERVICE_HOST = "service_host" SERVICE_PORT = "service_port" HA_MODE = "ha_mode" diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py index b9b8c5b14d..1d320dba91 100644 --- a/nvflare/app_common/widgets/external_configurator.py +++ b/nvflare/app_common/widgets/external_configurator.py @@ -16,10 +16,11 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.fl_constant import FLMetaKey, ConnPropKey from nvflare.apis.fl_context import FLContext from nvflare.client.config import write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components from nvflare.fuel.utils.validation_utils import check_object_type from nvflare.widgets.widget import Widget @@ -47,8 +48,19 @@ def __init__( def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.ABOUT_TO_START_RUN: components_data = self._export_all_components(fl_ctx) - components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name() + + site_name = fl_ctx.get_identity_name() + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + + components_data[FLMetaKey.SITE_NAME] = site_name components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() + components_data[FLMetaKey.AUTH_TOKEN] = auth_token + components_data[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature + + conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + components_data[ConnPropKey.CONNECTION_SECURITY] = conn_sec config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=components_data, config_file_path=config_file_path) diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 39af3675ed..12db4891f2 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -277,7 +277,11 @@ def get_module_args(self, job_id, fl_ctx: FLContext): def _job_args_dict(job_args: dict, arg_names: list) -> dict: result = {} for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + + n, v = e result[n] = v return result diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 05d7517afa..1b2184eb9b 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -50,6 +50,9 @@ def __init__(self, handle, connect_url: str, active: bool, params: dict): def get_connection_url(self): return self.connect_url + def get_connection_params(self): + return self.params + class ConnectorManager: """ @@ -159,7 +162,7 @@ def _validate_conn_config(config: dict, key: str) -> Union[None, dict]: return conn_config def _get_connector( - self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool + self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool, conn_resources=None ) -> Union[None, ConnectorData]: if active and not url: raise RuntimeError("url is required by not provided for active connector!") @@ -200,10 +203,10 @@ def _get_connector( try: if active: - handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required, conn_resources) connect_url = url elif url: - handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required, conn_resources) connect_url = url else: self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}") @@ -247,11 +250,13 @@ def get_internal_listener(self) -> Union[None, ConnectorData]: """ return self._get_connector(url="", active=False, internal=True, adhoc=False, secure=False) - def get_internal_connector(self, url: str) -> Union[None, ConnectorData]: + def get_internal_connector(self, url: str, conn_resources=None) -> Union[None, ConnectorData]: """ Try to get an internal listener. Args: url: """ - return self._get_connector(url=url, active=True, internal=True, adhoc=False, secure=False) + return self._get_connector( + url=url, active=True, internal=True, adhoc=False, secure=False, conn_resources=conn_resources + ) diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 738d9aa5ea..8c43d22d86 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -282,6 +282,7 @@ def __init__( credentials: dict, create_internal_listener: bool = False, parent_url: str = None, + parent_resources: dict = None, max_timeout=3600, bulk_check_interval=0.5, bulk_process_interval=0.5, @@ -297,6 +298,7 @@ def __init__( max_timeout: default timeout for send_and_receive create_internal_listener: whether to create an internal listener for child cells parent_url: url for connecting to parent cell + parent_resources: extra resources for making connection to parent FQCN is the names of all ancestor, concatenated with dots. @@ -324,6 +326,8 @@ def __init__( self.max_msg_size = comm_configurator.get_max_message_size() self.comm_configurator = comm_configurator + self.logger.info(f"XXXX {parent_resources=} {parent_url=}") + err = FQCN.validate(fqcn) if err: raise ValueError(f"Invalid FQCN '{fqcn}': {err}") @@ -371,6 +375,7 @@ def __init__( self.root_url = root_url self.create_internal_listener = create_internal_listener self.parent_url = parent_url + self.parent_resources = parent_resources self.bulk_check_interval = bulk_check_interval self.max_bulk_size = max_bulk_size self.bulk_checker = None @@ -567,7 +572,7 @@ def _set_bb_for_client_root(self): def _set_bb_for_client_child(self, parent_url: str, create_internal_listener: bool): if parent_url: - self._create_internal_connector(parent_url) + self._create_internal_connector(parent_url, self.parent_resources) if create_internal_listener: self._create_internal_listener() @@ -694,6 +699,11 @@ def get_internal_listener_url(self) -> Union[None, str]: return None return self.int_listener.get_connection_url() + def get_internal_listener_params(self) -> Union[None, dict]: + if not self.int_listener: + return None + return self.int_listener.get_connection_params() + def _add_adhoc_connector(self, to_cell: str, url: str): if self.bb_ext_connector: # it is possible that the server root offers connect url after the bb_ext_connector is created @@ -787,8 +797,8 @@ def _create_bb_external_connector(self): else: raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone external connector to {self.root_url}") - def _create_internal_connector(self, url: str): - self.bb_int_connector = self.connector_manager.get_internal_connector(url) + def _create_internal_connector(self, url: str, resources=None): + self.bb_int_connector = self.connector_manager.get_internal_connector(url, resources) if self.bb_int_connector: self.logger.info(f"{self.my_info.fqcn}: created backbone internal connector to {url} on parent") else: diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index e59de1b58f..aac9a4ba37 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -155,13 +155,14 @@ def register_message_receiver(self, app_id: int, receiver: MessageReceiver): self.conn_manager.register_message_receiver(app_id, receiver) - def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dict): + def add_connector(self, url: str, mode: Mode, secure: bool = False, resources=None) -> (str, dict): """Load a connector. The driver is selected based on the URL Args: url: The url to listen on or connect to, like "https://0:443". Use 0 for empty host mode: Active for connecting, Passive for listening secure: True if SSL is required. + resources: extra resources for creating connection Returns: A tuple of (A handle that can be used to delete connector, connector params) @@ -178,6 +179,8 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dic raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}") params = parse_url(url) + if resources: + params.update(resources) return self.add_connector_advanced(driver_class(), mode, params, secure, False), params def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index a59ef713ec..12f0fe0eb5 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -57,6 +57,7 @@ def ssl_required(params: dict) -> bool: def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: if not ssl_required(params): params[DriverParams.IMPLEMENTED_CONN_SEC.value] = "clear" + log.info("get_ssl_context: XXX clear") return None conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS) @@ -65,6 +66,10 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: ca_path = params.get(DriverParams.CA_CERT.value) cert_path = params.get(DriverParams.SERVER_CERT.value) key_path = params.get(DriverParams.SERVER_KEY.value) + + if not cert_path or not key_path: + raise RuntimeError(f"not cert or key for SSL server: {params=}") + if conn_security == ConnectionSecurity.TLS: # do not require client auth ctx.verify_mode = ssl.CERT_NONE @@ -83,6 +88,7 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: # no custom CA cert: use provisioned CA cert ca_path = params.get(DriverParams.CA_CERT.value) params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Flare CA Cert used" + log.info("get_ssl_context: XXX 1-way SSL") cert_path = None key_path = None else: @@ -91,6 +97,7 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: cert_path = params.get(DriverParams.CLIENT_CERT.value) key_path = params.get(DriverParams.CLIENT_KEY.value) params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client mTLS: Flare credentials used" + log.info("get_ssl_context: XXX 2-way SSL") if not ca_path: scheme = params.get(DriverParams.SCHEME.value, "Unknown") diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index 7095c59e8f..a829601872 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -116,7 +116,9 @@ def initialize(cls, section_files: Dict[str, str], config_path: List[str], parse if not os.path.isdir(d): raise ValueError(f"'{d}' is not a valid directory") - cls._config_path = config_path + for d in config_path: + if d not in cls._config_path: + cls._config_path.append(d) for section, file_basename in section_files.items(): cls._sections[section] = cls.load_config_dict(file_basename, cls._config_path) @@ -185,7 +187,10 @@ def load_configuration(cls, file_basename: str) -> Optional[Config]: Returns: config data loaded, or None if the config file is not found. """ - return ConfigFactory.load_config(file_basename, cls._config_path) + cls.logger.info(f"XXXX load_configuration: {file_basename=} {cls._config_path=}") + result = ConfigFactory.load_config(file_basename, cls._config_path) + cls.logger.info(f"XXXX loaded: {result}") + return result @classmethod def load_config_dict( diff --git a/nvflare/fuel/utils/url_utils.py b/nvflare/fuel/utils/url_utils.py new file mode 100644 index 0000000000..2219ebe424 --- /dev/null +++ b/nvflare/fuel/utils/url_utils.py @@ -0,0 +1,47 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_SECURE_SCHEME_MAPPING = {"tcp": "stcp", "grpc": "grpcs", "http": "https"} + + +def make_url(scheme: str, address, secure: bool) -> str: + if secure: + scheme = _SECURE_SCHEME_MAPPING.get(scheme) + if not scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + + if isinstance(address, str): + return f"{scheme}://{address}" + else: + port = None + if isinstance(address, (tuple, list)): + host = address[0] + if len(address) > 1: + port = address[1] + elif isinstance(address, dict): + host = address["host"] + port = address.get("port") + else: + raise ValueError(f"invalid address: {address}") + + if not isinstance(host, str): + raise ValueError(f"invalid host '{host}': must be str but got {type(host)}") + + if port: + if not isinstance(port, (str, int)): + raise ValueError(f"invalid port '{port}': must be str or int but got {type(port)}") + port_str = f":{port}" + else: + port_str = "" + return f"{scheme}://{host}{port_str}" diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index b43133abe5..30c8306511 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -154,6 +154,13 @@ def parse_arguments(): parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True) parser.add_argument("--sp_scheme", "-scheme", type=str, help="Sp connection scheme", required=True) parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True) + parser.add_argument( + "--parent_conn_sec", "-pcs", + type=str, + help="parent conn security", + required=False, + default="", + ) parser.add_argument( "--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True ) diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 535692ff38..25d758a6ca 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -320,15 +320,39 @@ def _determine_conn_props(self, client_name, config_data: dict): else: cp_fqcn = client_name - result = { - ConnPropKey.CP_FQCN: cp_fqcn, - } - if relay_fqcn: - result[ConnPropKey.FQCN] = relay_fqcn - result[ConnPropKey.URL] = relay_url - result[ConnPropKey.CONNECTION_SECURITY] = relay_conn_security - return result + relay_conn_props = { + ConnPropKey.FQCN: relay_fqcn, + ConnPropKey.URL: relay_url, + ConnPropKey.CONNECTION_SECURITY: relay_conn_security + } + set_scope_property(client_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) + + client = self.config_data["client"] + + if hasattr(self.args, "job_id") and self.args.job_id: + # this is CJ + sp_scheme = self.args.sp_scheme + sp_target = self.args.sp_target + root_url = f"{sp_scheme}://{sp_target}" + root_conn_props = { + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + ConnPropKey.URL: root_url, + ConnPropKey.CONNECTION_SECURITY: client.get(ConnPropKey.CONNECTION_SECURITY) + } + set_scope_property(client_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + ConnPropKey.URL: self.args.parent_url, + ConnPropKey.CONNECTION_SECURITY: self.args.parent_conn_sec, + } + else: + # this is CP + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + } + set_scope_property(client_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) def start_config(self, config_ctx: ConfigContext): """Start the config process. @@ -353,17 +377,11 @@ def start_config(self, config_ctx: ConfigContext): if not client_name: raise ConfigError("missing 'uid' from command args") - relay_config = self.config_data.get(ConnPropKey.RELAY_CONFIG) - print(f"got relay config: {relay_config}") - if relay_config: - set_scope_property(client_name, ConnPropKey.RELAY_CONFIG, relay_config) - conn_sec = client.get(ConnPropKey.CONNECTION_SECURITY) if conn_sec: set_scope_property(client_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) - conn_props = self._determine_conn_props(client_name, self.config_data) - set_scope_property(client_name, ConnPropKey.CONNECTION_PROPERTIES, conn_props) + self._determine_conn_props(client_name, self.config_data) except Exception: raise ValueError(f"Client config error: '{self.client_config_file_names}'") diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index cbf94777d5..8b4d85e999 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs +from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs, ConnPropKey from nvflare.apis.fl_context import FLContext from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.resource_manager_spec import ResourceManagerSpec @@ -201,6 +201,14 @@ def start_app( JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_client.json"), JobProcessArgs.OPTIONS: ("--set", command_options), } + + params = client.cell.get_internal_listener_params() + self.logger.info(f"XXXX get_internal_listener_params: {params} ") + if params: + parent_conn_sec = params.get(ConnPropKey.CONNECTION_SECURITY) + if parent_conn_sec: + job_args[JobProcessArgs.PARENT_CONN_SEC] = ("-pcs", parent_conn_sec) + fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 49d1571505..99ef38022c 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -188,31 +188,40 @@ def _create_cell(self, location, scheme): """ # Determine the CP's fqcn root_url = scheme + "://" + location - conn_security = self.client_args.get(ConnPropKey.CONNECTION_SECURITY) + root_conn_security = self.client_args.get(ConnPropKey.CONNECTION_SECURITY) - conn_props = get_scope_property(self.client_name, ConnPropKey.CONNECTION_PROPERTIES) - self.logger.info(f"got connection_properties: {conn_props}") + relay_conn_props = get_scope_property(self.client_name, ConnPropKey.RELAY_CONN_PROPS, {}) + self.logger.info(f"got {ConnPropKey.RELAY_CONN_PROPS}: {relay_conn_props}") - relay_fqcn = conn_props.get(ConnPropKey.FQCN) + relay_fqcn = relay_conn_props.get(ConnPropKey.FQCN) if relay_fqcn: root_url = None # do not connect to server if relay is used - cp_fqcn = conn_props.get(ConnPropKey.CP_FQCN) + cp_conn_props = get_scope_property(self.client_name, ConnPropKey.CP_CONN_PROPS) + cp_fqcn = cp_conn_props.get(ConnPropKey.FQCN) + parent_resources = None if self.args.job_id: # I am CJ me = "CJ" my_fqcn = FQCN.join([cp_fqcn, self.args.job_id]) - parent_url = self.args.parent_url + parent_url = cp_conn_props.get(ConnPropKey.URL) + parent_conn_sec = cp_conn_props.get(ConnPropKey.CONNECTION_SECURITY) create_internal_listener = False + if parent_conn_sec: + parent_resources = { + DriverParams.CONNECTION_SECURITY.value: parent_conn_sec + } else: # I am CP me = "CP" my_fqcn = cp_fqcn - parent_url = conn_props.get(ConnPropKey.URL) + parent_url = relay_conn_props.get(ConnPropKey.URL) create_internal_listener = True - relay_conn_security = conn_props.get(ConnPropKey.CONNECTION_SECURITY) + relay_conn_security = relay_conn_props.get(ConnPropKey.CONNECTION_SECURITY) if relay_conn_security: - conn_security = relay_conn_security + parent_resources = { + DriverParams.CONNECTION_SECURITY.value: relay_conn_security + } if self.secure_train: root_cert = self.client_args[SecureTrainConst.SSL_ROOT_CERT] @@ -227,8 +236,9 @@ def _create_cell(self, location, scheme): else: credentials = {} - if conn_security: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security + if root_conn_security: + # this is the default conn sec + credentials[DriverParams.CONNECTION_SECURITY.value] = root_conn_security self.logger.info(f"{me=}: {my_fqcn=} {root_url=} {parent_url=}") self.cell = Cell( @@ -238,6 +248,7 @@ def _create_cell(self, location, scheme): credentials=credentials, create_internal_listener=create_internal_listener, parent_url=parent_url, + parent_resources=parent_resources, ) self.cell.start() self.communicator.set_cell(self.cell) diff --git a/nvflare/utils/job_launcher_utils.py b/nvflare/utils/job_launcher_utils.py index 6865673869..a3013e25aa 100644 --- a/nvflare/utils/job_launcher_utils.py +++ b/nvflare/utils/job_launcher_utils.py @@ -24,7 +24,10 @@ def _job_args_str(job_args, arg_names) -> str: result = "" sep = "" for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + n, v = e result += f"{sep}{n} {v}" sep = " " return result @@ -45,6 +48,7 @@ def get_client_job_args(include_exe_module=True, include_set_options=True): JobProcessArgs.JOB_ID, JobProcessArgs.CLIENT_NAME, JobProcessArgs.PARENT_URL, + JobProcessArgs.PARENT_CONN_SEC, JobProcessArgs.TARGET, JobProcessArgs.SCHEME, JobProcessArgs.STARTUP_CONFIG_FILE, From a40f8d7da8b45a7fb0ce89aaa5ea396958c9f509 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Thu, 30 Jan 2025 17:40:47 -0500 Subject: [PATCH 04/11] support pipe_conn_type; fix to work with simulator --- nvflare/apis/fl_constant.py | 2 + .../executors/client_api_launcher_executor.py | 16 +-- nvflare/app_common/utils/export_utils.py | 39 +++++ .../widgets/external_configurator.py | 26 ++-- nvflare/client/config.py | 9 ++ nvflare/client/ex_process/api.py | 15 +- nvflare/fuel/f3/cellnet/core_cell.py | 2 - nvflare/fuel/f3/cellnet/net_manager.py | 7 + nvflare/fuel/f3/communicator.py | 3 - nvflare/fuel/f3/drivers/net_utils.py | 3 - nvflare/fuel/utils/config_service.py | 2 - nvflare/fuel/utils/pipe/cell_pipe.py | 135 +++++++++++++----- nvflare/job_config/script_runner.py | 77 ++++++---- .../private/fed/app/client/worker_process.py | 3 +- nvflare/private/fed/app/fl_conf.py | 6 +- nvflare/private/fed/app/relay/relay.py | 22 ++- nvflare/private/fed/client/client_executor.py | 3 +- .../private/fed/client/client_json_config.py | 19 ++- nvflare/private/fed/client/fed_client_base.py | 8 +- nvflare/private/fed/server/training_cmds.py | 7 + 20 files changed, 284 insertions(+), 120 deletions(-) create mode 100644 nvflare/app_common/utils/export_utils.py diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index ac42c12cc5..401dd801f7 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -541,6 +541,8 @@ class SystemVarName: WORKSPACE = "WORKSPACE" # directory of the workspace JOB_ID = "JOB_ID" # Job ID ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server) + CP_URL = "CP_URL" # URL to CP + RELAY_URL = "RELAY_URL" # URL to relay that the CP is connected to SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job PYTHONPATH = "PYTHONPATH" diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index ee6e257f02..35a99e0fe0 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,13 +15,12 @@ import os from typing import Optional -from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.fl_context import FLContext from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.launcher_executor import LauncherExecutor +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG -from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode @@ -126,22 +125,11 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout, } - site_name = fl_ctx.get_identity_name() - auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") - config_data = { ConfigKey.TASK_EXCHANGE: task_exchange_attributes, - FLMetaKey.SITE_NAME: site_name, - FLMetaKey.JOB_ID: fl_ctx.get_job_id(), - FLMetaKey.AUTH_TOKEN: auth_token, - FLMetaKey.AUTH_TOKEN_SIGNATURE: signature, } - conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) - if conn_sec: - config_data[ConnPropKey.CONNECTION_SECURITY] = conn_sec - + update_export_props(config_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=config_data, config_file_path=config_file_path) diff --git a/nvflare/app_common/utils/export_utils.py b/nvflare/app_common/utils/export_utils.py new file mode 100644 index 0000000000..ba52632c9c --- /dev/null +++ b/nvflare/app_common/utils/export_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey +from nvflare.apis.fl_context import FLContext +from nvflare.fuel.data_event.utils import get_scope_property + + +def update_export_props(props: dict, fl_ctx: FLContext): + site_name = fl_ctx.get_identity_name() + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + + props[FLMetaKey.SITE_NAME] = site_name + props[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() + props[FLMetaKey.AUTH_TOKEN] = auth_token + props[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature + + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + if root_conn_props: + props[ConnPropKey.ROOT_CONN_PROPS] = root_conn_props + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + props[ConnPropKey.CP_CONN_PROPS] = cp_conn_props + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + props[ConnPropKey.RELAY_CONN_PROPS] = relay_conn_props diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py index 1d320dba91..4ccb663823 100644 --- a/nvflare/app_common/widgets/external_configurator.py +++ b/nvflare/app_common/widgets/external_configurator.py @@ -16,8 +16,9 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLMetaKey, ConnPropKey +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.fl_context import FLContext +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG from nvflare.fuel.data_event.utils import get_scope_property @@ -48,20 +49,7 @@ def __init__( def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.ABOUT_TO_START_RUN: components_data = self._export_all_components(fl_ctx) - - site_name = fl_ctx.get_identity_name() - auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") - - components_data[FLMetaKey.SITE_NAME] = site_name - components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() - components_data[FLMetaKey.AUTH_TOKEN] = auth_token - components_data[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature - - conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) - if conn_sec: - components_data[ConnPropKey.CONNECTION_SECURITY] = conn_sec - + update_export_props(components_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=components_data, config_file_path=config_file_path) @@ -77,5 +65,11 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict: engine = fl_ctx.get_engine() all_components = engine.get_all_components() components = {i: all_components.get(i) for i in self._component_ids} - reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID] + reserved_keys = [ + FLMetaKey.SITE_NAME, + FLMetaKey.JOB_ID, + ConnPropKey.CP_CONN_PROPS, + ConnPropKey.ROOT_CONN_PROPS, + ConnPropKey.RELAY_CONN_PROPS, + ] return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER) diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 4346242712..06369a9954 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -159,6 +159,15 @@ def get_heartbeat_timeout(self): def get_connection_security(self): return self.config.get(ConnPropKey.CONNECTION_SECURITY) + def get_root_conn_props(self): + return self.config.get(ConnPropKey.ROOT_CONN_PROPS) + + def get_cp_conn_props(self): + return self.config.get(ConnPropKey.CP_CONN_PROPS) + + def get_relay_conn_props(self): + return self.config.get(ConnPropKey.RELAY_CONN_PROPS) + def get_site_name(self): return self.config.get(FLMetaKey.SITE_NAME) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index cc68c55ff1..73e9328a9e 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -39,9 +39,18 @@ def _create_client_config(config: str) -> ClientConfig: raise ValueError(f"config should be a string but got: {type(config)}") site_name = client_config.get_site_name() - conn_sec = client_config.get_connection_security() - if conn_sec: - set_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) + + root_conn_props = client_config.get_root_conn_props() + if root_conn_props: + set_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = client_config.get_cp_conn_props() + if cp_conn_props: + set_scope_property(site_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) + + relay_conn_props = client_config.get_relay_conn_props() + if relay_conn_props: + set_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) # get message auth info and put them into Databus for CellPipe to use auth_token = client_config.get_auth_token() diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 8c43d22d86..064781476b 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -326,8 +326,6 @@ def __init__( self.max_msg_size = comm_configurator.get_max_message_size() self.comm_configurator = comm_configurator - self.logger.info(f"XXXX {parent_resources=} {parent_url=}") - err = FQCN.validate(fqcn) if err: raise ValueError(f"Invalid FQCN '{fqcn}': {err}") diff --git a/nvflare/fuel/f3/cellnet/net_manager.py b/nvflare/fuel/f3/cellnet/net_manager.py index 82a1a0d980..e25232f2d3 100644 --- a/nvflare/fuel/f3/cellnet/net_manager.py +++ b/nvflare/fuel/f3/cellnet/net_manager.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.stats_pool import VALID_HIST_MODES, parse_hist_mode @@ -31,6 +32,12 @@ class NetManager(CommandModule): def __init__(self, agent: NetAgent, diagnose=False): self.agent = agent self.diagnose = diagnose + data_bus = DataBus() + data_bus.subscribe(["stop_cellnet"], self._stop_cellnet) + + def _stop_cellnet(self, topic: str, conn: Connection, db: DataBus): + self.agent.stop() + conn.append_string("Cellnet Stopped") def get_spec(self) -> CommandModuleSpec: return CommandModuleSpec( diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index aac9a4ba37..8e8fbd221b 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -240,9 +240,6 @@ def add_connector_advanced( params[DriverParams.CONNECTION_SECURITY] = original_conn_sec params[DriverParams.SECURE] = secure - - log.info(f"$$$ add_connector_advanced: {params=} {mode=} {original_conn_sec=}") - handle = self.conn_manager.add_connector(driver, params, mode) if not start: diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 12f0fe0eb5..a649566647 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -57,7 +57,6 @@ def ssl_required(params: dict) -> bool: def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: if not ssl_required(params): params[DriverParams.IMPLEMENTED_CONN_SEC.value] = "clear" - log.info("get_ssl_context: XXX clear") return None conn_security = params.get(DriverParams.CONNECTION_SECURITY.value, ConnectionSecurity.MTLS) @@ -88,7 +87,6 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: # no custom CA cert: use provisioned CA cert ca_path = params.get(DriverParams.CA_CERT.value) params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client TLS: Flare CA Cert used" - log.info("get_ssl_context: XXX 1-way SSL") cert_path = None key_path = None else: @@ -97,7 +95,6 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: cert_path = params.get(DriverParams.CLIENT_CERT.value) key_path = params.get(DriverParams.CLIENT_KEY.value) params[DriverParams.IMPLEMENTED_CONN_SEC] = "Client mTLS: Flare credentials used" - log.info("get_ssl_context: XXX 2-way SSL") if not ca_path: scheme = params.get(DriverParams.SCHEME.value, "Unknown") diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index a829601872..d8f59aaf6c 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -187,9 +187,7 @@ def load_configuration(cls, file_basename: str) -> Optional[Config]: Returns: config data loaded, or None if the config file is not found. """ - cls.logger.info(f"XXXX load_configuration: {file_basename=} {cls._config_path=}") result = ConfigFactory.load_config(file_basename, cls._config_path) - cls.logger.info(f"XXXX loaded: {result}") return result @classmethod diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 3150bd9386..376443ed8f 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -22,6 +22,7 @@ from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.cell import Message as CellMessage from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.utils import make_reply from nvflare.fuel.f3.drivers.driver_params import DriverParams @@ -44,12 +45,16 @@ _HEADER_HB_SEQ = _PREFIX + "hb_seq" -def _cell_fqcn(mode, site_name, token): +def _cell_fqcn(mode, site_name, token, parent_fqcn): # The FQCN of the cell must be unique in the whole cellnet. # We use the combination of mode, site_name, and token to derive the value of FQCN # Since the token is usually used across all sites, the "site_name" differentiate cell on one site from another. # The two peer pipes on the same site share the same site_name and token, but are differentiated by their modes. - return f"{site_name}_{token}_{mode}" + base = f"{site_name}_{token}_{mode}" + if parent_fqcn == FQCN.ROOT_SERVER: + return base + else: + return FQCN.join([parent_fqcn, base]) def _to_cell_message(msg: Message, extra=None) -> CellMessage: @@ -77,8 +82,11 @@ class _CellInfo: A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging). """ - def __init__(self, cell, net_agent): + def __init__(self, site_name, cell, net_agent, auth_token, token_signature): + self.site_name = site_name self.cell = cell + self.auth_token = auth_token + self.token_signature = token_signature self.net_agent = net_agent self.started = False self.pipes = [] @@ -114,21 +122,15 @@ class CellPipe(Pipe): _lock = threading.Lock() _cells_info = {} # (root_url, site_name, token) => _CellInfo - _auth_token = None - _token_signature = None - _site_name = None @classmethod - def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir): + def _build_cell(cls, site_name, fqcn, parent_conn_props, secure_mode, workspace_dir, logger): """Build a cell if necessary. The combination of (root_url, site_name, token) uniquely determine one cell. There can be multiple pipes on the same cell. Args: - root_url: root url of the cell net - mode: mode (passive or active) of the pipe - site_name: name of the site - token: the unique token + parent_conn_props: parent for this cell secure_mode: whether cellnet is in secure mode workspace_dir: workspace that contains startup kit for connecting to server. Needed only if secure_mode @@ -136,11 +138,8 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di """ with cls._lock: - cls._site_name = site_name - cell_key = f"{root_url}.{site_name}.{token}" - ci = cls._cells_info.get(cell_key) + ci = cls._cells_info.get(fqcn) if not ci: - credentials = {} if secure_mode: root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir) if not root_cert_path: @@ -149,35 +148,58 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di credentials = { DriverParams.CA_CERT.value: root_cert_path, } + else: + credentials = {} + + conn_sec = parent_conn_props.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + + parent_url = parent_conn_props.get(ConnPropKey.URL) - conn_sec = get_scope_property(site_name, ConnPropKey.CONNECTION_SECURITY) - if conn_sec: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + if FQCN.get_parent(fqcn): + # the cell has a parent: connect to the parent + cell_root = None + cell_parent_url = parent_url + else: + # the cell has no parent: the parent_url is the root of the cellnet + cell_root = parent_url + cell_parent_url = None cell = Cell( - fqcn=_cell_fqcn(mode, site_name, token), - root_url=root_url, + fqcn=fqcn, + root_url=cell_root, secure=secure_mode, credentials=credentials, + parent_url=cell_parent_url, create_internal_listener=False, ) - # set filter to add additional auth headers - cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers) - cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers) + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + token_signature = get_scope_property(site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") net_agent = NetAgent(cell) - ci = _CellInfo(cell, net_agent) - cls._cells_info[cell_key] = ci + ci = _CellInfo(site_name, cell, net_agent, auth_token, token_signature) + cls._cells_info[fqcn] = ci + + # set filter to add additional auth headers + cell.core_cell.add_outgoing_reply_filter( + channel="*", + topic="*", + cb=cls._add_auth_headers, + ci=ci, + ) + cell.core_cell.add_outgoing_request_filter( + channel="*", + topic="*", + cb=cls._add_auth_headers, + ci=ci, + ) return ci @classmethod - def _add_auth_headers(cls, message: CellMessage): - if not cls._auth_token: - cls._auth_token = get_scope_property(scope_name=cls._site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - cls._token_signature = get_scope_property(cls._site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") - - add_authentication_headers(message, cls._site_name, cls._auth_token, cls._token_signature) + def _add_auth_headers(cls, message: CellMessage, ci: _CellInfo): + add_authentication_headers(message, ci.site_name, ci.auth_token, ci.token_signature) def __init__( self, @@ -203,9 +225,9 @@ def __init__( self.site_name = site_name self.token = token - self.root_url = root_url self.secure_mode = secure_mode self.workspace_dir = workspace_dir + self.root_url = root_url # this section is needed by job config to prevent building cell when using SystemVarName arguments # TODO: enhance this part @@ -219,8 +241,49 @@ def __init__( check_str("site_name", site_name) check_str("workspace_dir", workspace_dir) + # determine the endpoint for this pipe to connect to + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + + if root_conn_props: + # Not in simulator + if not isinstance(root_conn_props, dict): + raise RuntimeError(f"expect root_conn_props for {site_name} to be dict but got {type(root_conn_props)}") + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + if not isinstance(cp_conn_props, dict): + raise RuntimeError(f"expect cp_conn_props to be dict but got {type(cp_conn_props)}") + + url_to_conns = { + root_conn_props.get(ConnPropKey.URL): root_conn_props, + cp_conn_props.get(ConnPropKey.URL): cp_conn_props, + } + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + if not isinstance(relay_conn_props, dict): + raise RuntimeError(f"expect relay_conn_props to be dict but got {type(relay_conn_props)}") + url_to_conns[relay_conn_props.get(ConnPropKey.URL)] = relay_conn_props + + if not root_url: + # root_url not specified - use CP! + root_url = cp_conn_props.get(ConnPropKey.URL) + self.root_url = root_url + + conn_props = url_to_conns.get(self.root_url) + if not conn_props: + raise RuntimeError(f"cannot determine conn props for '{root_url}'") + else: + # this is running in simulator + conn_props = { + ConnPropKey.URL: root_url, + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + } + mode = f"{mode}".strip().lower() # convert to lower case string - self.ci = self._build_cell(mode, root_url, site_name, token, secure_mode, workspace_dir) + fqcn = _cell_fqcn(mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) + + self.ci = self._build_cell(site_name, fqcn, conn_props, secure_mode, workspace_dir, self.logger) self.cell = self.ci.cell self.ci.add_pipe(self) @@ -231,7 +294,7 @@ def __init__( else: raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'") - self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token) + self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) self.received_msgs = queue.Queue() # contains Message(s), not CellMessage(s)! self.channel = None # the cellnet message channel self.pipe_lock = threading.Lock() # used to ensure no msg to be sent after closed @@ -366,16 +429,14 @@ def close(self): def export(self, export_mode: str) -> Tuple[str, dict]: if export_mode == ExportMode.SELF: mode = self.mode - root_url = self.root_url else: mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE - root_url = self.cell.get_root_url_for_child() export_args = { "mode": mode, "site_name": self.site_name, "token": self.token, - "root_url": root_url, + "root_url": self.root_url, "secure_mode": self.cell.core_cell.secure, "workspace_dir": self.workspace_dir, } diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index d39f65db61..7b3ca4f40b 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -14,6 +14,7 @@ from typing import Optional, Type, Union +from nvflare.apis.fl_constant import SystemVarName from nvflare.app_common.abstract.launcher import Launcher from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor @@ -22,8 +23,9 @@ from nvflare.app_common.widgets.metric_relay import MetricRelay from nvflare.client.config import ExchangeFormat, TransferType from nvflare.fuel.utils.import_utils import optional_import -from nvflare.fuel.utils.pipe.cell_pipe import CellPipe +from nvflare.fuel.utils.pipe.cell_pipe import CellPipe, Mode from nvflare.fuel.utils.pipe.pipe import Pipe +from nvflare.fuel.utils.validation_utils import check_str from .api import FedJob, validate_object_for_job @@ -35,6 +37,19 @@ class FrameworkType: TENSORFLOW = "tensorflow" +class PipeConnectType: + VIA_ROOT = "via_root" + VIA_CP = "via_cp" + VIA_RELAY = "via_relay" + + +_PIPE_CONNECT_URL = { + PipeConnectType.VIA_CP: "{" + SystemVarName.CP_URL + "}", + PipeConnectType.VIA_RELAY: "{" + SystemVarName.RELAY_URL + "}", + PipeConnectType.VIA_ROOT: "{" + SystemVarName.ROOT_URL + "}", +} + + class BaseScriptRunner: def __init__( self, @@ -52,6 +67,7 @@ def __init__( launcher: Optional[Launcher] = None, metric_relay: Optional[MetricRelay] = None, metric_pipe: Optional[Pipe] = None, + pipe_connect_type: str = None, ): """BaseScriptRunner is used with FedJob API to run or launch a script. @@ -102,6 +118,12 @@ def __init__( metric_pipe (Optional[Pipe], optional): An optional Pipe instance for passing metric data between components. This allows for real-time metric handling during execution. Defaults to `None`. + + pipe_connect_type: how pipe peers are to be connected: + Via Root: peers are both connected to the root of the cellnet + Via Relay: peers are both connected to the relay if a relay is used; otherwise via root. + Via CP: peers are both connected to the CP + If not specified, will be via CP. """ self._script = script self._script_args = script_args @@ -112,6 +134,7 @@ def __init__( self._params_exchange_format = params_exchange_format self._from_nvflare_converter_id = from_nvflare_converter_id self._to_nvflare_converter_id = to_nvflare_converter_id + self._pipe_connect_type = pipe_connect_type if self._framework == FrameworkType.PYTORCH: _, torch_ok = optional_import(module="torch") @@ -151,12 +174,35 @@ def __init__( elif executor is not None: validate_object_for_job("executor", executor, InProcessClientAPIExecutor) + if pipe_connect_type: + check_str("pipe_connect_type", pipe_connect_type) + valid_connect_types = [PipeConnectType.VIA_CP, PipeConnectType.VIA_RELAY, PipeConnectType.VIA_RELAY] + if pipe_connect_type not in valid_connect_types: + raise ValueError(f"invalid pipe_connect_type '{pipe_connect_type}': must be {valid_connect_types}") + self._metric_pipe = metric_pipe self._metric_relay = metric_relay self._task_pipe = task_pipe self._executor = executor self._launcher = launcher + def _create_cell_pipe(self): + ct = self._pipe_connect_type + if not ct: + ct = PipeConnectType.VIA_CP + conn_url = _PIPE_CONNECT_URL.get(ct) + if not conn_url: + raise RuntimeError(f"cannot determine pipe connect url for {self._pipe_connect_type}") + + return CellPipe( + mode=Mode.PASSIVE, + site_name="{" + SystemVarName.SITE_NAME + "}", + token="{" + SystemVarName.JOB_ID + "}", + root_url=conn_url, + secure_mode="{" + SystemVarName.SECURE_MODE + "}", + workspace_dir="{" + SystemVarName.WORKSPACE + "}", + ) + def add_to_fed_job(self, job: FedJob, ctx, **kwargs): """This method is used by Job API. @@ -172,18 +218,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): comp_ids = {} if self._launch_external_process: - task_pipe = ( - self._task_pipe - if self._task_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + task_pipe = self._task_pipe if self._task_pipe else self._create_cell_pipe() task_pipe_id = job.add_component("pipe", task_pipe, ctx) comp_ids["pipe_id"] = task_pipe_id @@ -211,18 +246,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): ) job.add_executor(executor, tasks=tasks, ctx=ctx) - metric_pipe = ( - self._metric_pipe - if self._metric_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + metric_pipe = self._metric_pipe if self._metric_pipe else self._create_cell_pipe() metric_pipe_id = job.add_component("metrics_pipe", metric_pipe, ctx) comp_ids["metric_pipe_id"] = metric_pipe_id @@ -295,6 +319,7 @@ def __init__( framework: FrameworkType = FrameworkType.PYTORCH, params_exchange_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: str = TransferType.FULL, + pipe_connect_type: str = PipeConnectType.VIA_CP, ): """ScriptRunner is used with FedJob API to run or launch a script. @@ -310,6 +335,7 @@ def __init__( params_exchange_format (str): The format to exchange the parameters. Defaults to ExchangeFormat.NUMPY. params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL. + pipe_connect_type (str): how pipe peers are to be connected """ super().__init__( script=script, @@ -319,4 +345,5 @@ def __init__( framework=framework, params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, + pipe_connect_type=pipe_connect_type, ) diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index 30c8306511..4aa8bea5a6 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -155,7 +155,8 @@ def parse_arguments(): parser.add_argument("--sp_scheme", "-scheme", type=str, help="Sp connection scheme", required=True) parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True) parser.add_argument( - "--parent_conn_sec", "-pcs", + "--parent_conn_sec", + "-pcs", type=str, help="parent conn security", required=False, diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 25d758a6ca..acc8c969a4 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -323,8 +323,8 @@ def _determine_conn_props(self, client_name, config_data: dict): if relay_fqcn: relay_conn_props = { ConnPropKey.FQCN: relay_fqcn, - ConnPropKey.URL: relay_url, - ConnPropKey.CONNECTION_SECURITY: relay_conn_security + ConnPropKey.URL: relay_url, + ConnPropKey.CONNECTION_SECURITY: relay_conn_security, } set_scope_property(client_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) @@ -338,7 +338,7 @@ def _determine_conn_props(self, client_name, config_data: dict): root_conn_props = { ConnPropKey.FQCN: FQCN.ROOT_SERVER, ConnPropKey.URL: root_url, - ConnPropKey.CONNECTION_SECURITY: client.get(ConnPropKey.CONNECTION_SECURITY) + ConnPropKey.CONNECTION_SECURITY: client.get(ConnPropKey.CONNECTION_SECURITY), } set_scope_property(client_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py index ec390ef06a..2996066e14 100644 --- a/nvflare/private/fed/app/relay/relay.py +++ b/nvflare/private/fed/app/relay/relay.py @@ -14,6 +14,7 @@ import argparse import json +import logging import os import sys import threading @@ -31,6 +32,18 @@ from nvflare.fuel.utils.url_utils import make_url +class CellnetMonitor: + def __init__(self, stop_event: threading.Event, workspace: str): + self.stop_event = stop_event + self.workspace = workspace + + def cellnet_stopped(self): + touch_file = os.path.join(self.workspace, WorkspaceConstants.SHUTDOWN_FILE) + with open(touch_file, "a"): + os.utime(touch_file, None) + self.stop_event.set() + + def parse_arguments(): parser = argparse.ArgumentParser() parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) @@ -80,7 +93,10 @@ def main(args): configure_logging(workspace, workspace.get_root_dir()) + logger = logging.getLogger() + stop_event = threading.Event() + monitor = CellnetMonitor(stop_event, args.workspace) ConfigService.initialize( section_files={}, @@ -119,12 +135,14 @@ def main(args): create_internal_listener=True, parent_url=parent_url, ) - net_agent = NetAgent(cell) + NetAgent(cell, agent_closed_cb=monitor.cellnet_stopped) cell.start() # wait until stopped - print(f"started relay {my_identity=} {my_fqcn=}") + logger.info(f"Started relay {my_identity=} {my_fqcn=} {root_url=} {parent_url=} {parent_fqcn=}") stop_event.wait() + cell.stop() + logger.info(f"Relay stopped.") if __name__ == "__main__": diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index 8b4d85e999..c1c178d20b 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs, ConnPropKey +from nvflare.apis.fl_constant import AdminCommandNames, ConnPropKey, FLContextKey, RunProcessKey, SystemConfigs from nvflare.apis.fl_context import FLContext from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.resource_manager_spec import ResourceManagerSpec @@ -203,7 +203,6 @@ def start_app( } params = client.cell.get_internal_listener_params() - self.logger.info(f"XXXX get_internal_listener_params: {params} ") if params: parent_conn_sec = params.get(ConnPropKey.CONNECTION_SECURITY) if parent_conn_sec: diff --git a/nvflare/private/fed/client/client_json_config.py b/nvflare/private/fed/client/client_json_config.py index acb66a3b05..2bc028ee50 100644 --- a/nvflare/private/fed/client/client_json_config.py +++ b/nvflare/private/fed/client/client_json_config.py @@ -16,8 +16,9 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import SystemConfigs, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, SystemConfigs, SystemVarName from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -64,11 +65,27 @@ def __init__( sp_target = args.sp_target sp_url = f"{sp_scheme}://{sp_target}" + # determine relay URL + # if relay is not used, use the root URL as relay URL + relay_conn_props = get_scope_property(args.client_name, ConnPropKey.RELAY_CONN_PROPS) + relay_url = None + if relay_conn_props: + relay_url = relay_conn_props.get(ConnPropKey.URL) + if not relay_url: + relay_url = sp_url + + if hasattr(args, "parent_url") and args.parent_url: + parent_url = args.parent_url + else: + parent_url = sp_url + sys_vars = { SystemVarName.JOB_ID: args.job_id, SystemVarName.SITE_NAME: args.client_name, SystemVarName.WORKSPACE: args.workspace, SystemVarName.ROOT_URL: sp_url, + SystemVarName.CP_URL: parent_url, + SystemVarName.RELAY_URL: relay_url, SystemVarName.SECURE_MODE: self.cmd_vars.get("secure_train", True), SystemVarName.JOB_CUSTOM_DIR: workspace_obj.get_app_custom_dir(args.job_id), } diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 99ef38022c..14215bed55 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -208,9 +208,7 @@ def _create_cell(self, location, scheme): parent_conn_sec = cp_conn_props.get(ConnPropKey.CONNECTION_SECURITY) create_internal_listener = False if parent_conn_sec: - parent_resources = { - DriverParams.CONNECTION_SECURITY.value: parent_conn_sec - } + parent_resources = {DriverParams.CONNECTION_SECURITY.value: parent_conn_sec} else: # I am CP me = "CP" @@ -219,9 +217,7 @@ def _create_cell(self, location, scheme): create_internal_listener = True relay_conn_security = relay_conn_props.get(ConnPropKey.CONNECTION_SECURITY) if relay_conn_security: - parent_resources = { - DriverParams.CONNECTION_SECURITY.value: relay_conn_security - } + parent_resources = {DriverParams.CONNECTION_SECURITY.value: relay_conn_security} if self.secure_train: root_cert = self.client_args[SecureTrainConst.SSL_ROOT_CERT] diff --git a/nvflare/private/fed/server/training_cmds.py b/nvflare/private/fed/server/training_cmds.py index 8ad9be7b67..7d31af86e0 100644 --- a/nvflare/private/fed/server/training_cmds.py +++ b/nvflare/private/fed/server/training_cmds.py @@ -18,6 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.fl_constant import AdminCommandNames, SiteType +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.hci.conn import Connection from nvflare.fuel.hci.proto import ConfirmMethod, MetaKey, MetaStatusValue, make_meta from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec @@ -160,6 +161,12 @@ def shutdown(self, conn: Connection, args: List[str]): conn.update_meta(make_meta(MetaStatusValue.ERROR, "failed to shut down all clients")) return + if target_type in [self.TARGET_TYPE_ALL]: + # shutdown the cellnet + data_bus = DataBus() + data_bus.publish(["stop_cellnet"], conn) + # time.sleep(2.0) + if target_type in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]: # shut down the server err = self._shutdown_app_on_server(conn) From 0bd9bb03799672529fa5f8bd78559d3424a3a871 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 10:10:18 -0500 Subject: [PATCH 05/11] fix formatting --- nvflare/app_common/widgets/external_configurator.py | 1 - nvflare/private/fed/app/relay/relay.py | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py index 4ccb663823..648dba64eb 100644 --- a/nvflare/app_common/widgets/external_configurator.py +++ b/nvflare/app_common/widgets/external_configurator.py @@ -21,7 +21,6 @@ from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG -from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components from nvflare.fuel.utils.validation_utils import check_object_type from nvflare.widgets.widget import Widget diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py index 2996066e14..a881e22e62 100644 --- a/nvflare/private/fed/app/relay/relay.py +++ b/nvflare/private/fed/app/relay/relay.py @@ -142,7 +142,7 @@ def main(args): logger.info(f"Started relay {my_identity=} {my_fqcn=} {root_url=} {parent_url=} {parent_fqcn=}") stop_event.wait() cell.stop() - logger.info(f"Relay stopped.") + logger.info(f"Relay {my_fqcn} stopped.") if __name__ == "__main__": From a36dba89e532f1514e1ff717638b3961ff18b6e2 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 10:23:06 -0500 Subject: [PATCH 06/11] add comm_config_utils --- nvflare/fuel/f3/comm_config_utils.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 nvflare/fuel/f3/comm_config_utils.py diff --git a/nvflare/fuel/f3/comm_config_utils.py b/nvflare/fuel/f3/comm_config_utils.py new file mode 100644 index 0000000000..6a9097ef73 --- /dev/null +++ b/nvflare/fuel/f3/comm_config_utils.py @@ -0,0 +1,26 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def requires_secure(resources: dict): + conn_sec = resources.get(DriverParams.CONNECTION_SECURITY) + if conn_sec: + if conn_sec == ConnectionSecurity.INSECURE: + return False + else: + return True + else: + return resources.get(DriverParams.SECURE) From 151864dd1ad9d6ed1b29f56558ace46641495748 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 10:43:20 -0500 Subject: [PATCH 07/11] fix unit test --- tests/unit_test/client/in_process/api_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/unit_test/client/in_process/api_test.py b/tests/unit_test/client/in_process/api_test.py index 835bd80b5d..a873d78499 100644 --- a/tests/unit_test/client/in_process/api_test.py +++ b/tests/unit_test/client/in_process/api_test.py @@ -65,8 +65,10 @@ def test_init_with_custom_interval(self): def test_init_subscriptions(self): client_api = InProcessClientAPI(self.task_metadata) xs = list(client_api.data_bus.subscribers.keys()) - xs.sort() - assert xs == [TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP] + + # Depending on the timing of this test, the data bus may have other subscribed topics + # since the data bus is a singleton! + assert set(xs).issuperset([TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP]) def local_result_callback(self, data, topic): pass From 3d79c7ae1fdd46494ddb710198b8cbbdff4a601e Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 12:36:45 -0500 Subject: [PATCH 08/11] change log level --- nvflare/private/fed/server/server_state.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nvflare/private/fed/server/server_state.py b/nvflare/private/fed/server/server_state.py index 5b0ba3ca9e..fa2a8cd414 100644 --- a/nvflare/private/fed/server/server_state.py +++ b/nvflare/private/fed/server/server_state.py @@ -91,11 +91,11 @@ def aux_communicate(self, fl_ctx: FLContext) -> dict: def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState: if sp: - self.logger.info( + self.logger.debug( f"handle_sd_callback Got SP: {sp.name=} {sp.fl_port=} {sp.primary=} {self.host=} {self.service_port=}" ) else: - self.logger.info("handle_sd_callback no SP!") + self.logger.debug("handle_sd_callback no SP!") if sp and sp.primary is True: if sp.name == self.host and sp.fl_port == self.service_port: From b50aefa2089b89ddf2a72eb868041324f5f548a4 Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 14:54:31 -0500 Subject: [PATCH 09/11] add test cases --- nvflare/fuel/f3/comm_config_utils.py | 14 ++- nvflare/fuel/f3/drivers/aio_grpc_driver.py | 4 +- nvflare/fuel/f3/drivers/aio_http_driver.py | 4 +- nvflare/fuel/f3/drivers/grpc_driver.py | 4 +- nvflare/fuel/f3/drivers/tcp_driver.py | 4 +- nvflare/fuel/utils/url_utils.py | 39 +++++++- .../fuel/f3/comm_config_utils_test.py | 90 +++++++++++++++++++ tests/unit_test/fuel/utils/url_utils_test.py | 66 ++++++++++++++ 8 files changed, 211 insertions(+), 14 deletions(-) create mode 100644 tests/unit_test/fuel/f3/comm_config_utils_test.py create mode 100644 tests/unit_test/fuel/utils/url_utils_test.py diff --git a/nvflare/fuel/f3/comm_config_utils.py b/nvflare/fuel/f3/comm_config_utils.py index 6a9097ef73..494d320646 100644 --- a/nvflare/fuel/f3/comm_config_utils.py +++ b/nvflare/fuel/f3/comm_config_utils.py @@ -15,12 +15,22 @@ from nvflare.fuel.f3.drivers.driver_params import DriverParams -def requires_secure(resources: dict): +def requires_secure_connection(resources: dict): + """Determine whether secure connection is required based on information in resources. + + Args: + resources: a dict that contains info for making connection + + Returns: whether secure connection is required + + """ conn_sec = resources.get(DriverParams.CONNECTION_SECURITY) if conn_sec: + # if connection security is specified, it takes precedence over the "secure" flag if conn_sec == ConnectionSecurity.INSECURE: return False else: return True else: - return resources.get(DriverParams.SECURE) + # Connection security is not specified, check the "secure" flag. + return resources.get(DriverParams.SECURE, False) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 55deec90d2..65544a2f84 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -22,7 +22,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers.aio_context import AioContext @@ -410,7 +410,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "grpcs" diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index cfcd7e2121..383f95d7cb 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -18,7 +18,7 @@ import websockets from websockets.exceptions import ConnectionClosedOK -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers import net_utils @@ -121,7 +121,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: scheme = "https" diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 764a61772e..bc37e54fce 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -20,7 +20,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver import ConnectorInfo @@ -275,7 +275,7 @@ def connect(self, connector: ConnectorInfo): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "nagrpcs" diff --git a/nvflare/fuel/f3/drivers/tcp_driver.py b/nvflare/fuel/f3/drivers/tcp_driver.py index 6ee11a4f4f..09256bca88 100644 --- a/nvflare/fuel/f3/drivers/tcp_driver.py +++ b/nvflare/fuel/f3/drivers/tcp_driver.py @@ -17,7 +17,7 @@ from socketserver import TCPServer, ThreadingTCPServer from typing import Any, Dict, List -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.drivers.base_driver import BaseDriver from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams @@ -101,7 +101,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: scheme = "stcp" diff --git a/nvflare/fuel/utils/url_utils.py b/nvflare/fuel/utils/url_utils.py index 2219ebe424..38af73b705 100644 --- a/nvflare/fuel/utils/url_utils.py +++ b/nvflare/fuel/utils/url_utils.py @@ -16,22 +16,53 @@ def make_url(scheme: str, address, secure: bool) -> str: + """Make a full URL based on specified info + + Args: + scheme: scheme of the url + address: host address. Multiple formats are supported: + str: this is a string that contains host name and optionally port number (e.g. localhost:1234) + dict: contains item "host" and optionally "port" + tuple or list: contains 1 or 2 items for host and port + secure: whether secure connection is required + + Returns: + + """ + secure_scheme = _SECURE_SCHEME_MAPPING.get(scheme) + if not secure_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + if secure: - scheme = _SECURE_SCHEME_MAPPING.get(scheme) - if not scheme: - raise ValueError(f"unsupported scheme '{scheme}'") + scheme = secure_scheme if isinstance(address, str): + if not address: + raise ValueError("address must not be empty") return f"{scheme}://{address}" else: port = None if isinstance(address, (tuple, list)): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") host = address[0] if len(address) > 1: port = address[1] elif isinstance(address, dict): - host = address["host"] + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + + host = address.get("host") + if not host: + raise ValueError(f"invalid address {address}: missing 'host'") + port = address.get("port") + if not port and len(address) > 1: + raise ValueError(f"invalid address {address}: missing 'port'") else: raise ValueError(f"invalid address: {address}") diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py new file mode 100644 index 0000000000..c703d9d4c3 --- /dev/null +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +class TestCommConfigUtils: + + @pytest.mark.parametrize( + "resources, expected", + [ + ({}, False), + ({"x": 1, "y": 2}, False), + ({DriverParams.SECURE.value: True}, True), + ({DriverParams.SECURE.value: False}, False), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE}, False), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS}, True), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS}, True), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: False, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: False, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: True, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: True, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + DriverParams.SECURE.value: True, + }, + False + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + }, + False + ), + ], + ) + def test_requires_secure_connection(self, resources, expected): + result = requires_secure_connection(resources) + assert result == expected diff --git a/tests/unit_test/fuel/utils/url_utils_test.py b/tests/unit_test/fuel/utils/url_utils_test.py new file mode 100644 index 0000000000..e4c338ffbc --- /dev/null +++ b/tests/unit_test/fuel/utils/url_utils_test.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.fuel.utils.url_utils import make_url + + +class TestUrlUtils: + + @pytest.mark.parametrize( + "scheme, address, secure, expected", + [ + ("tcp", "xyz.com", False, "tcp://xyz.com"), + ("tcp", "xyz.com:1234", False, "tcp://xyz.com:1234"), + ("tcp", "xyz.com:1234", True, "stcp://xyz.com:1234"), + ("grpc", "xyz.com", False, "grpc://xyz.com"), + ("grpc", "xyz.com:1234", False, "grpc://xyz.com:1234"), + ("grpc", "xyz.com:1234", True, "grpcs://xyz.com:1234"), + ("http", "xyz.com", False, "http://xyz.com"), + ("http", "xyz.com:1234", False, "http://xyz.com:1234"), + ("http", "xyz.com:1234", True, "https://xyz.com:1234"), + ("tcp", ("xyz.com",), False, "tcp://xyz.com"), + ("tcp", ("xyz.com", 1234), False, "tcp://xyz.com:1234"), + ("tcp", ["xyz.com"], False, "tcp://xyz.com"), + ("tcp", ["xyz.com", 1234], False, "tcp://xyz.com:1234"), + ("tcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("tcp", {"host": "xyz.com", "port": 1234}, False, "tcp://xyz.com:1234"), + ], + ) + def test_make_url(self, scheme, address, secure, expected): + result = make_url(scheme, address, secure) + assert result == expected + + @pytest.mark.parametrize( + "scheme, address, secure", + [ + ("tcp", "", False), + ("abc", "xyz.com:1234", False), + ("tcp", 1234, True), + ("grpc", [], False), + ("grpc", (), False), + ("grpc", {}, True), + ("http", [1234], False), + ("http", [1234, "xyz.com"], False), + ("http", ["xyz.com", 1234, 22], True), + ("http", (1234,), False), + ("http", (1234, "xyz.com"), False), + ("http", ("xyz.com", 1234, 22), True), + ("tcp", {"hosts": "xyz.com"}, False), + ("tcp", {"host": "xyz.com", "port": 1234, "extra": 2323}, False), + ], + ) + def test_make_url_error(self, scheme, address, secure): + with pytest.raises(ValueError): + make_url(scheme, address, secure) From ac4cc5aee41b8ab43033073aebf8302f2abc074d Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 15:14:06 -0500 Subject: [PATCH 10/11] reformat --- .../fuel/f3/comm_config_utils_test.py | 74 +++++++++---------- 1 file changed, 37 insertions(+), 37 deletions(-) diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py index c703d9d4c3..db6aba8475 100644 --- a/tests/unit_test/fuel/f3/comm_config_utils_test.py +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -31,57 +31,57 @@ class TestCommConfigUtils: ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS}, True), ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS}, True), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, - DriverParams.SECURE.value: False, - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: False, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, - DriverParams.SECURE.value: False, - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: False, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, - DriverParams.SECURE.value: True, - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: True, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, - DriverParams.SECURE.value: True, - }, - True + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: True, + }, + True, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, - DriverParams.SECURE.value: True, - }, - False + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + DriverParams.SECURE.value: True, + }, + False, ), ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, - }, - False + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + }, + False, ), ], ) From a5d547435950ff6ae783ef2ead68c211b475a4aa Mon Sep 17 00:00:00 2001 From: Yan Cheng Date: Fri, 31 Jan 2025 15:26:34 -0500 Subject: [PATCH 11/11] reformat --- .../fuel/f3/comm_config_utils_test.py | 71 +++---------------- 1 file changed, 11 insertions(+), 60 deletions(-) diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py index db6aba8475..61e83f799a 100644 --- a/tests/unit_test/fuel/f3/comm_config_utils_test.py +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -13,9 +13,7 @@ # limitations under the License. import pytest -from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_config_utils import requires_secure_connection -from nvflare.fuel.f3.drivers.driver_params import DriverParams class TestCommConfigUtils: @@ -25,64 +23,17 @@ class TestCommConfigUtils: [ ({}, False), ({"x": 1, "y": 2}, False), - ({DriverParams.SECURE.value: True}, True), - ({DriverParams.SECURE.value: False}, False), - ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE}, False), - ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS}, True), - ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS}, True), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, - DriverParams.SECURE.value: False, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, - DriverParams.SECURE.value: False, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, - DriverParams.SECURE.value: True, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, - DriverParams.SECURE.value: True, - }, - True, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, - DriverParams.SECURE.value: True, - }, - False, - ), - ( - { - DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, - }, - False, - ), + ({"secure": True}, True), + ({"secure": False}, False), + ({"connection_security": "insecure"}, False), + ({"connection_security": "tls"}, True), + ({"connection_security": "mtls"}, True), + ({"connection_security": "mtls", "secure": False}, True), + ({"connection_security": "mtls", "secure": True}, True), + ({"connection_security": "tls", "secure": False}, True), + ({"connection_security": "tls", "secure": True}, True), + ({"connection_security": "insecure", "secure": False}, False), + ({"connection_security": "insecure", "secure": True}, False), ], ) def test_requires_secure_connection(self, resources, expected):