diff --git a/nvflare/apis/fl_constant.py b/nvflare/apis/fl_constant.py index b45cb58582..cafc8cb2e8 100644 --- a/nvflare/apis/fl_constant.py +++ b/nvflare/apis/fl_constant.py @@ -443,7 +443,6 @@ class SecureTrainConst: SSL_ROOT_CERT = "ssl_root_cert" SSL_CERT = "ssl_cert" PRIVATE_KEY = "ssl_private_key" - CONNECTION_SECURITY = "connection_security" class FLMetaKey: @@ -467,6 +466,7 @@ class FLMetaKey: class CellMessageAuthHeaderKey: CLIENT_NAME = "client_name" + SSID = "ssid" TOKEN = "__token__" TOKEN_SIGNATURE = "__token_signature__" @@ -542,6 +542,8 @@ class SystemVarName: WORKSPACE = "WORKSPACE" # directory of the workspace JOB_ID = "JOB_ID" # Job ID ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server) + CP_URL = "CP_URL" # URL to CP + RELAY_URL = "RELAY_URL" # URL to relay that the CP is connected to SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job PYTHONPATH = "PYTHONPATH" @@ -552,3 +554,27 @@ class RunnerTask: INIT = "init" TASK_EXEC = "task_exec" END_RUN = "end_run" + + +class ConnPropKey: + + PROJECT_NAME = "project_name" + SERVER_IDENTITY = "server_identity" + IDENTITY = "identity" + PARENT = "parent" + FQCN = "fqcn" + URL = "url" + SCHEME = "scheme" + ADDRESS = "address" + CONNECTION_SECURITY = "connection_security" + + RELAY_CONFIG = "relay_config" + CP_CONN_PROPS = "cp_conn_props" + RELAY_CONN_PROPS = "relay_conn_props" + ROOT_CONN_PROPS = "root_conn_props" + + +class ConnectionSecurity: + CLEAR = "clear" + TLS = "tls" + MTLS = "mtls" diff --git a/nvflare/apis/job_launcher_spec.py b/nvflare/apis/job_launcher_spec.py index 36edc62380..cb75130e6a 100644 --- a/nvflare/apis/job_launcher_spec.py +++ b/nvflare/apis/job_launcher_spec.py @@ -32,6 +32,7 @@ class JobProcessArgs: CLIENT_NAME = "client_name" ROOT_URL = "root_url" PARENT_URL = "parent_url" + PARENT_CONN_SEC = "parent_conn_sec" SERVICE_HOST = "service_host" SERVICE_PORT = "service_port" HA_MODE = "ha_mode" diff --git a/nvflare/app_common/executors/client_api_launcher_executor.py b/nvflare/app_common/executors/client_api_launcher_executor.py index 3b470edf2c..35a99e0fe0 100644 --- a/nvflare/app_common/executors/client_api_launcher_executor.py +++ b/nvflare/app_common/executors/client_api_launcher_executor.py @@ -15,13 +15,12 @@ import os from typing import Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst from nvflare.apis.fl_context import FLContext from nvflare.app_common.app_constant import AppConstants from nvflare.app_common.executors.launcher_executor import LauncherExecutor +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG -from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.attributes_exportable import ExportMode @@ -126,22 +125,11 @@ def prepare_config_for_launch(self, fl_ctx: FLContext): ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout, } - site_name = fl_ctx.get_identity_name() - auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") - config_data = { ConfigKey.TASK_EXCHANGE: task_exchange_attributes, - FLMetaKey.SITE_NAME: site_name, - FLMetaKey.JOB_ID: fl_ctx.get_job_id(), - FLMetaKey.AUTH_TOKEN: auth_token, - FLMetaKey.AUTH_TOKEN_SIGNATURE: signature, } - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) - if conn_sec: - config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec - + update_export_props(config_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=config_data, config_file_path=config_file_path) diff --git a/nvflare/app_common/utils/export_utils.py b/nvflare/app_common/utils/export_utils.py new file mode 100644 index 0000000000..ba52632c9c --- /dev/null +++ b/nvflare/app_common/utils/export_utils.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey +from nvflare.apis.fl_context import FLContext +from nvflare.fuel.data_event.utils import get_scope_property + + +def update_export_props(props: dict, fl_ctx: FLContext): + site_name = fl_ctx.get_identity_name() + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + + props[FLMetaKey.SITE_NAME] = site_name + props[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() + props[FLMetaKey.AUTH_TOKEN] = auth_token + props[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature + + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + if root_conn_props: + props[ConnPropKey.ROOT_CONN_PROPS] = root_conn_props + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + props[ConnPropKey.CP_CONN_PROPS] = cp_conn_props + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + props[ConnPropKey.RELAY_CONN_PROPS] = relay_conn_props diff --git a/nvflare/app_common/widgets/external_configurator.py b/nvflare/app_common/widgets/external_configurator.py index b9b8c5b14d..648dba64eb 100644 --- a/nvflare/app_common/widgets/external_configurator.py +++ b/nvflare/app_common/widgets/external_configurator.py @@ -16,8 +16,9 @@ from typing import List from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import FLMetaKey +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.fl_context import FLContext +from nvflare.app_common.utils.export_utils import update_export_props from nvflare.client.config import write_config_to_file from nvflare.client.constants import CLIENT_API_CONFIG from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components @@ -47,9 +48,7 @@ def __init__( def handle_event(self, event_type: str, fl_ctx: FLContext): if event_type == EventType.ABOUT_TO_START_RUN: components_data = self._export_all_components(fl_ctx) - components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name() - components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id() - + update_export_props(components_data, fl_ctx) config_file_path = self._get_external_config_file_path(fl_ctx) write_config_to_file(config_data=components_data, config_file_path=config_file_path) @@ -65,5 +64,11 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict: engine = fl_ctx.get_engine() all_components = engine.get_all_components() components = {i: all_components.get(i) for i in self._component_ids} - reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID] + reserved_keys = [ + FLMetaKey.SITE_NAME, + FLMetaKey.JOB_ID, + ConnPropKey.CP_CONN_PROPS, + ConnPropKey.ROOT_CONN_PROPS, + ConnPropKey.RELAY_CONN_PROPS, + ] return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER) diff --git a/nvflare/app_opt/job_launcher/k8s_launcher.py b/nvflare/app_opt/job_launcher/k8s_launcher.py index 39af3675ed..12db4891f2 100644 --- a/nvflare/app_opt/job_launcher/k8s_launcher.py +++ b/nvflare/app_opt/job_launcher/k8s_launcher.py @@ -277,7 +277,11 @@ def get_module_args(self, job_id, fl_ctx: FLContext): def _job_args_dict(job_args: dict, arg_names: list) -> dict: result = {} for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + + n, v = e result[n] = v return result diff --git a/nvflare/client/config.py b/nvflare/client/config.py index 477b132dc3..06369a9954 100644 --- a/nvflare/client/config.py +++ b/nvflare/client/config.py @@ -16,7 +16,7 @@ import os from typing import Dict, Optional -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.fuel.utils.config_factory import ConfigFactory @@ -157,7 +157,16 @@ def get_heartbeat_timeout(self): ) def get_connection_security(self): - return self.config.get(SecureTrainConst.CONNECTION_SECURITY) + return self.config.get(ConnPropKey.CONNECTION_SECURITY) + + def get_root_conn_props(self): + return self.config.get(ConnPropKey.ROOT_CONN_PROPS) + + def get_cp_conn_props(self): + return self.config.get(ConnPropKey.CP_CONN_PROPS) + + def get_relay_conn_props(self): + return self.config.get(ConnPropKey.RELAY_CONN_PROPS) def get_site_name(self): return self.config.get(FLMetaKey.SITE_NAME) diff --git a/nvflare/client/ex_process/api.py b/nvflare/client/ex_process/api.py index bb94b70939..73e9328a9e 100644 --- a/nvflare/client/ex_process/api.py +++ b/nvflare/client/ex_process/api.py @@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple from nvflare.apis.analytix import AnalyticsDataType -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey from nvflare.apis.utils.analytix_utils import create_analytic_dxo from nvflare.app_common.abstract.fl_model import FLModel from nvflare.client.api_spec import APISpec @@ -39,9 +39,18 @@ def _create_client_config(config: str) -> ClientConfig: raise ValueError(f"config should be a string but got: {type(config)}") site_name = client_config.get_site_name() - conn_sec = client_config.get_connection_security() - if conn_sec: - set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec) + + root_conn_props = client_config.get_root_conn_props() + if root_conn_props: + set_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = client_config.get_cp_conn_props() + if cp_conn_props: + set_scope_property(site_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) + + relay_conn_props = client_config.get_relay_conn_props() + if relay_conn_props: + set_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) # get message auth info and put them into Databus for CellPipe to use auth_token = client_config.get_auth_token() diff --git a/nvflare/fuel/f3/cellnet/connector_manager.py b/nvflare/fuel/f3/cellnet/connector_manager.py index 5c0e65a48f..567b6cb0ed 100644 --- a/nvflare/fuel/f3/cellnet/connector_manager.py +++ b/nvflare/fuel/f3/cellnet/connector_manager.py @@ -15,11 +15,13 @@ import time from typing import Union +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.common.excepts import ConfigError from nvflare.fuel.f3.cellnet.defs import ConnectorRequirementKey from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import CommError, Communicator, Mode +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.security.logging import secure_format_exception, secure_format_traceback @@ -48,6 +50,9 @@ def __init__(self, handle, connect_url: str, active: bool, params: dict): def get_connection_url(self): return self.connect_url + def get_connection_params(self): + return self.params + class ConnectorManager: """ @@ -85,6 +90,11 @@ def __init__(self, communicator: Communicator, secure: bool, comm_configurator: self.adhoc_scheme = adhoc_conf.get(_KEY_SCHEME) self.adhoc_resources = adhoc_conf.get(_KEY_RESOURCES) + # default conn sec + conn_sec = self.int_resources.get(DriverParams.CONNECTION_SECURITY) + if not conn_sec: + self.int_resources[DriverParams.CONNECTION_SECURITY] = ConnectionSecurity.CLEAR + self.logger.debug(f"internal scheme={self.int_scheme}, resources={self.int_resources}") self.logger.debug(f"adhoc scheme={self.adhoc_scheme}, resources={self.adhoc_resources}") self.comm_config = comm_config @@ -152,7 +162,7 @@ def _validate_conn_config(config: dict, key: str) -> Union[None, dict]: return conn_config def _get_connector( - self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool + self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool, conn_resources=None ) -> Union[None, ConnectorData]: if active and not url: raise RuntimeError("url is required by not provided for active connector!") @@ -193,10 +203,10 @@ def _get_connector( try: if active: - handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required, conn_resources) connect_url = url elif url: - handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required) + handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required, conn_resources) connect_url = url else: self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}") @@ -240,11 +250,13 @@ def get_internal_listener(self) -> Union[None, ConnectorData]: """ return self._get_connector(url="", active=False, internal=True, adhoc=False, secure=False) - def get_internal_connector(self, url: str) -> Union[None, ConnectorData]: + def get_internal_connector(self, url: str, conn_resources=None) -> Union[None, ConnectorData]: """ Try to get an internal listener. Args: url: """ - return self._get_connector(url=url, active=True, internal=True, adhoc=False, secure=False) + return self._get_connector( + url=url, active=True, internal=True, adhoc=False, secure=False, conn_resources=conn_resources + ) diff --git a/nvflare/fuel/f3/cellnet/core_cell.py b/nvflare/fuel/f3/cellnet/core_cell.py index 7228e7a8c8..3b5aec82c3 100644 --- a/nvflare/fuel/f3/cellnet/core_cell.py +++ b/nvflare/fuel/f3/cellnet/core_cell.py @@ -22,6 +22,7 @@ from typing import Dict, List, Tuple, Union from urllib.parse import urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager from nvflare.fuel.f3.cellnet.defs import ( @@ -43,7 +44,7 @@ from nvflare.fuel.f3.comm_config import CommConfigurator from nvflare.fuel.f3.communicator import Communicator, MessageReceiver from nvflare.fuel.f3.connection import Connection -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState from nvflare.fuel.f3.message import Message @@ -281,6 +282,7 @@ def __init__( credentials: dict, create_internal_listener: bool = False, parent_url: str = None, + parent_resources: dict = None, max_timeout=3600, bulk_check_interval=0.5, bulk_process_interval=0.5, @@ -296,6 +298,7 @@ def __init__( max_timeout: default timeout for send_and_receive create_internal_listener: whether to create an internal listener for child cells parent_url: url for connecting to parent cell + parent_resources: extra resources for making connection to parent FQCN is the names of all ancestor, concatenated with dots. @@ -331,7 +334,7 @@ def __init__( # If configured, use it; otherwise keep the original value of 'secure'. conn_security = credentials.get(DriverParams.CONNECTION_SECURITY.value) if conn_security: - if conn_security == ConnectionSecurity.INSECURE: + if conn_security == ConnectionSecurity.CLEAR: secure = False else: secure = True @@ -370,6 +373,7 @@ def __init__( self.root_url = root_url self.create_internal_listener = create_internal_listener self.parent_url = parent_url + self.parent_resources = parent_resources self.bulk_check_interval = bulk_check_interval self.max_bulk_size = max_bulk_size self.bulk_checker = None @@ -566,7 +570,7 @@ def _set_bb_for_client_root(self): def _set_bb_for_client_child(self, parent_url: str, create_internal_listener: bool): if parent_url: - self._create_internal_connector(parent_url) + self._create_internal_connector(parent_url, self.parent_resources) if create_internal_listener: self._create_internal_listener() @@ -693,6 +697,11 @@ def get_internal_listener_url(self) -> Union[None, str]: return None return self.int_listener.get_connection_url() + def get_internal_listener_params(self) -> Union[None, dict]: + if not self.int_listener: + return None + return self.int_listener.get_connection_params() + def _add_adhoc_connector(self, to_cell: str, url: str): if self.bb_ext_connector: # it is possible that the server root offers connect url after the bb_ext_connector is created @@ -786,8 +795,8 @@ def _create_bb_external_connector(self): else: raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone external connector to {self.root_url}") - def _create_internal_connector(self, url: str): - self.bb_int_connector = self.connector_manager.get_internal_connector(url) + def _create_internal_connector(self, url: str, resources=None): + self.bb_int_connector = self.connector_manager.get_internal_connector(url, resources) if self.bb_int_connector: self.logger.info(f"{self.my_info.fqcn}: created backbone internal connector to {url} on parent") else: diff --git a/nvflare/fuel/f3/cellnet/net_manager.py b/nvflare/fuel/f3/cellnet/net_manager.py index 82a1a0d980..e25232f2d3 100644 --- a/nvflare/fuel/f3/cellnet/net_manager.py +++ b/nvflare/fuel/f3/cellnet/net_manager.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.stats_pool import VALID_HIST_MODES, parse_hist_mode @@ -31,6 +32,12 @@ class NetManager(CommandModule): def __init__(self, agent: NetAgent, diagnose=False): self.agent = agent self.diagnose = diagnose + data_bus = DataBus() + data_bus.subscribe(["stop_cellnet"], self._stop_cellnet) + + def _stop_cellnet(self, topic: str, conn: Connection, db: DataBus): + self.agent.stop() + conn.append_string("Cellnet Stopped") def get_spec(self) -> CommandModuleSpec: return CommandModuleSpec( diff --git a/nvflare/fuel/f3/comm_config_utils.py b/nvflare/fuel/f3/comm_config_utils.py new file mode 100644 index 0000000000..5ffd4e07b0 --- /dev/null +++ b/nvflare/fuel/f3/comm_config_utils.py @@ -0,0 +1,36 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +def requires_secure_connection(resources: dict): + """Determine whether secure connection is required based on information in resources. + + Args: + resources: a dict that contains info for making connection + + Returns: whether secure connection is required + + """ + conn_sec = resources.get(DriverParams.CONNECTION_SECURITY.value) + if conn_sec: + # if connection security is specified, it takes precedence over the "secure" flag + if conn_sec == ConnectionSecurity.CLEAR: + return False + else: + return True + else: + # Connection security is not specified, check the "secure" flag. + return resources.get(DriverParams.SECURE.value, False) diff --git a/nvflare/fuel/f3/communicator.py b/nvflare/fuel/f3/communicator.py index dbd4e86298..7cd7681f36 100644 --- a/nvflare/fuel/f3/communicator.py +++ b/nvflare/fuel/f3/communicator.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import atexit +import copy import logging import os import weakref @@ -152,13 +153,14 @@ def register_message_receiver(self, app_id: int, receiver: MessageReceiver): self.conn_manager.register_message_receiver(app_id, receiver) - def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dict): + def add_connector(self, url: str, mode: Mode, secure: bool = False, resources=None) -> (str, dict): """Load a connector. The driver is selected based on the URL Args: url: The url to listen on or connect to, like "https://0:443". Use 0 for empty host mode: Active for connecting, Passive for listening secure: True if SSL is required. + resources: extra resources for creating connection Returns: A tuple of (A handle that can be used to delete connector, connector params) @@ -175,6 +177,8 @@ def add_connector(self, url: str, mode: Mode, secure: bool = False) -> (str, dic raise CommError(CommError.NOT_SUPPORTED, f"No driver found for URL {url}") params = parse_url(url) + if resources: + params.update(resources) return self.add_connector_advanced(driver_class(), mode, params, secure, False), params def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): @@ -199,7 +203,9 @@ def start_listener(self, scheme: str, resources: dict) -> (str, str, dict): raise CommError(CommError.NOT_SUPPORTED, f"No driver found for scheme {scheme}") connect_url, listening_url = driver_class.get_urls(scheme, resources) - params = parse_url(listening_url) + extra_params = parse_url(listening_url) + params = copy.copy(resources) + params.update(extra_params) handle = self.add_connector_advanced(driver_class(), Mode.PASSIVE, params, False, True) @@ -223,10 +229,14 @@ def add_connector_advanced( Raises: CommError: If any errors """ - + original_conn_sec = params.get(DriverParams.CONNECTION_SECURITY) if self.local_endpoint.conn_props: params.update(self.local_endpoint.conn_props) + if original_conn_sec: + # we do not allow the connection sec to be overwritten by the endpoint's conn_props + params[DriverParams.CONNECTION_SECURITY] = original_conn_sec + params[DriverParams.SECURE] = secure handle = self.conn_manager.add_connector(driver, params, mode) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index aa71b0441a..65544a2f84 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -22,6 +22,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers.aio_context import AioContext @@ -409,7 +410,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "grpcs" diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index 61c953a867..383f95d7cb 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -18,6 +18,7 @@ import websockets from websockets.exceptions import ConnectionClosedOK +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers import net_utils @@ -120,7 +121,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: scheme = "https" diff --git a/nvflare/fuel/f3/drivers/driver_params.py b/nvflare/fuel/f3/drivers/driver_params.py index 54118855e3..c03c89618d 100644 --- a/nvflare/fuel/f3/drivers/driver_params.py +++ b/nvflare/fuel/f3/drivers/driver_params.py @@ -44,12 +44,6 @@ class DriverParams(str, Enum): IMPLEMENTED_CONN_SEC = "implemented_conn_sec" -class ConnectionSecurity: - INSECURE = "insecure" - TLS = "tls" - MTLS = "mtls" - - class DriverCap(str, Enum): SEND_HEARTBEAT = "send_heartbeat" diff --git a/nvflare/fuel/f3/drivers/grpc/utils.py b/nvflare/fuel/f3/drivers/grpc/utils.py index bbad1a522f..7c1f0de896 100644 --- a/nvflare/fuel/f3/drivers/grpc/utils.py +++ b/nvflare/fuel/f3/drivers/grpc/utils.py @@ -13,8 +13,9 @@ # limitations under the License. import grpc +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams def use_aio_grpc(): diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 7cdd444360..bc37e54fce 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -20,6 +20,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver import ConnectorInfo @@ -274,7 +275,7 @@ def connect(self, connector: ConnectorInfo): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "nagrpcs" diff --git a/nvflare/fuel/f3/drivers/net_utils.py b/nvflare/fuel/f3/drivers/net_utils.py index 8aa4e1f60f..a649566647 100644 --- a/nvflare/fuel/f3/drivers/net_utils.py +++ b/nvflare/fuel/f3/drivers/net_utils.py @@ -20,8 +20,9 @@ from typing import Any, Optional from urllib.parse import parse_qsl, urlencode, urlparse +from nvflare.apis.fl_constant import ConnectionSecurity from nvflare.fuel.f3.comm_error import CommError -from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams +from nvflare.fuel.f3.drivers.driver_params import DriverParams from nvflare.fuel.utils.argument_utils import str2bool from nvflare.security.logging import secure_format_exception @@ -64,6 +65,10 @@ def get_ssl_context(params: dict, ssl_server: bool) -> Optional[SSLContext]: ca_path = params.get(DriverParams.CA_CERT.value) cert_path = params.get(DriverParams.SERVER_CERT.value) key_path = params.get(DriverParams.SERVER_KEY.value) + + if not cert_path or not key_path: + raise RuntimeError(f"not cert or key for SSL server: {params=}") + if conn_security == ConnectionSecurity.TLS: # do not require client auth ctx.verify_mode = ssl.CERT_NONE diff --git a/nvflare/fuel/f3/drivers/tcp_driver.py b/nvflare/fuel/f3/drivers/tcp_driver.py index f7aff1a75d..09256bca88 100644 --- a/nvflare/fuel/f3/drivers/tcp_driver.py +++ b/nvflare/fuel/f3/drivers/tcp_driver.py @@ -17,6 +17,7 @@ from socketserver import TCPServer, ThreadingTCPServer from typing import Any, Dict, List +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.drivers.base_driver import BaseDriver from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams @@ -100,7 +101,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = resources.get(DriverParams.SECURE) + secure = requires_secure_connection(resources) if secure: scheme = "stcp" diff --git a/nvflare/fuel/sec/authn.py b/nvflare/fuel/sec/authn.py index 4bbe2f66b3..2aa466a5ce 100644 --- a/nvflare/fuel/sec/authn.py +++ b/nvflare/fuel/sec/authn.py @@ -12,10 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. from nvflare.apis.fl_constant import CellMessageAuthHeaderKey +from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.message import Message +from nvflare.fuel.utils.validation_utils import check_object_type, check_str -def add_authentication_headers(msg: Message, client_name: str, auth_token, token_signature): +def add_authentication_headers(msg: Message, client_name: str, auth_token, token_signature, ssid=None): """Add authentication headers to the specified message. Args: @@ -23,6 +25,7 @@ def add_authentication_headers(msg: Message, client_name: str, auth_token, token client_name: name of the client auth_token: authentication token token_signature: token signature + ssid: optional SSID Returns: @@ -30,5 +33,52 @@ def add_authentication_headers(msg: Message, client_name: str, auth_token, token if client_name: msg.set_header(CellMessageAuthHeaderKey.CLIENT_NAME, client_name) + if ssid: + msg.set_header(CellMessageAuthHeaderKey.SSID, ssid) + msg.set_header(CellMessageAuthHeaderKey.TOKEN, auth_token if auth_token else "NA") msg.set_header(CellMessageAuthHeaderKey.TOKEN_SIGNATURE, token_signature if token_signature else "NA") + + +def set_add_auth_headers_filters(cell: Cell, client_name: str, auth_token: str, token_signature: str, ssid=None): + """Set filters for adding auth headers. + + Args: + cell: the cell to add the filters to. + client_name: name of the client + auth_token: authentication token + token_signature: token signature + ssid: SSID, optional + + Returns: None + + """ + check_object_type("cell", cell, Cell) + + if client_name: + check_str("client_name", client_name) + + check_str("auth_token", auth_token) + check_str("token_signature", token_signature) + + if ssid: + check_str("ssid", ssid) + + cell.core_cell.add_outgoing_reply_filter( + channel="*", + topic="*", + cb=add_authentication_headers, + client_name=client_name, + auth_token=auth_token, + token_signature=token_signature, + ssid=ssid, + ) + cell.core_cell.add_outgoing_request_filter( + channel="*", + topic="*", + cb=add_authentication_headers, + client_name=client_name, + auth_token=auth_token, + token_signature=token_signature, + ssid=ssid, + ) diff --git a/nvflare/fuel/utils/config_service.py b/nvflare/fuel/utils/config_service.py index 7095c59e8f..d8f59aaf6c 100644 --- a/nvflare/fuel/utils/config_service.py +++ b/nvflare/fuel/utils/config_service.py @@ -116,7 +116,9 @@ def initialize(cls, section_files: Dict[str, str], config_path: List[str], parse if not os.path.isdir(d): raise ValueError(f"'{d}' is not a valid directory") - cls._config_path = config_path + for d in config_path: + if d not in cls._config_path: + cls._config_path.append(d) for section, file_basename in section_files.items(): cls._sections[section] = cls.load_config_dict(file_basename, cls._config_path) @@ -185,7 +187,8 @@ def load_configuration(cls, file_basename: str) -> Optional[Config]: Returns: config data loaded, or None if the config file is not found. """ - return ConfigFactory.load_config(file_basename, cls._config_path) + result = ConfigFactory.load_config(file_basename, cls._config_path) + return result @classmethod def load_config_dict( diff --git a/nvflare/fuel/utils/pipe/cell_pipe.py b/nvflare/fuel/utils/pipe/cell_pipe.py index 1554a0dd44..7d266166b9 100644 --- a/nvflare/fuel/utils/pipe/cell_pipe.py +++ b/nvflare/fuel/utils/pipe/cell_pipe.py @@ -17,15 +17,16 @@ import time from typing import Tuple, Union -from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey, SystemVarName from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.cell import Message as CellMessage from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent from nvflare.fuel.f3.cellnet.utils import make_reply from nvflare.fuel.f3.drivers.driver_params import DriverParams -from nvflare.fuel.sec.authn import add_authentication_headers +from nvflare.fuel.sec.authn import set_add_auth_headers_filters from nvflare.fuel.utils.attributes_exportable import ExportMode from nvflare.fuel.utils.config_service import search_file from nvflare.fuel.utils.constants import Mode @@ -44,12 +45,16 @@ _HEADER_HB_SEQ = _PREFIX + "hb_seq" -def _cell_fqcn(mode, site_name, token): +def _cell_fqcn(mode, site_name, token, parent_fqcn): # The FQCN of the cell must be unique in the whole cellnet. # We use the combination of mode, site_name, and token to derive the value of FQCN # Since the token is usually used across all sites, the "site_name" differentiate cell on one site from another. # The two peer pipes on the same site share the same site_name and token, but are differentiated by their modes. - return f"{site_name}_{token}_{mode}" + base = f"{site_name}_{token}_{mode}" + if parent_fqcn == FQCN.ROOT_SERVER: + return base + else: + return FQCN.join([parent_fqcn, base]) def _to_cell_message(msg: Message, extra=None) -> CellMessage: @@ -77,8 +82,11 @@ class _CellInfo: A cell could be used by multiple pipes (e.g. one pipe for task interaction, another for metrics logging). """ - def __init__(self, cell, net_agent): + def __init__(self, site_name, cell, net_agent, auth_token, token_signature): + self.site_name = site_name self.cell = cell + self.auth_token = auth_token + self.token_signature = token_signature self.net_agent = net_agent self.started = False self.pipes = [] @@ -114,21 +122,15 @@ class CellPipe(Pipe): _lock = threading.Lock() _cells_info = {} # (root_url, site_name, token) => _CellInfo - _auth_token = None - _token_signature = None - _site_name = None @classmethod - def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_dir): + def _build_cell(cls, site_name, fqcn, parent_conn_props, secure_mode, workspace_dir, logger): """Build a cell if necessary. The combination of (root_url, site_name, token) uniquely determine one cell. There can be multiple pipes on the same cell. Args: - root_url: root url of the cell net - mode: mode (passive or active) of the pipe - site_name: name of the site - token: the unique token + parent_conn_props: parent for this cell secure_mode: whether cellnet is in secure mode workspace_dir: workspace that contains startup kit for connecting to server. Needed only if secure_mode @@ -136,11 +138,8 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di """ with cls._lock: - cls._site_name = site_name - cell_key = f"{root_url}.{site_name}.{token}" - ci = cls._cells_info.get(cell_key) + ci = cls._cells_info.get(fqcn) if not ci: - credentials = {} if secure_mode: root_cert_path = search_file(SSL_ROOT_CERT, workspace_dir) if not root_cert_path: @@ -149,35 +148,42 @@ def _build_cell(cls, mode, root_url, site_name, token, secure_mode, workspace_di credentials = { DriverParams.CA_CERT.value: root_cert_path, } + else: + credentials = {} - conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY) - if conn_sec: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + conn_sec = parent_conn_props.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_sec + + parent_url = parent_conn_props.get(ConnPropKey.URL) + + if FQCN.get_parent(fqcn): + # the cell has a parent: connect to the parent + cell_root = None + cell_parent_url = parent_url + else: + # the cell has no parent: the parent_url is the root of the cellnet + cell_root = parent_url + cell_parent_url = None cell = Cell( - fqcn=_cell_fqcn(mode, site_name, token), - root_url=root_url, + fqcn=fqcn, + root_url=cell_root, secure=secure_mode, credentials=credentials, + parent_url=cell_parent_url, create_internal_listener=False, ) - # set filter to add additional auth headers - cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=cls._add_auth_headers) - cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=cls._add_auth_headers) + auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") + token_signature = get_scope_property(site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") net_agent = NetAgent(cell) - ci = _CellInfo(cell, net_agent) - cls._cells_info[cell_key] = ci - return ci - - @classmethod - def _add_auth_headers(cls, message: CellMessage): - if not cls._auth_token: - cls._auth_token = get_scope_property(scope_name=cls._site_name, key=FLMetaKey.AUTH_TOKEN, default="NA") - cls._token_signature = get_scope_property(cls._site_name, FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA") + ci = _CellInfo(site_name, cell, net_agent, auth_token, token_signature) + cls._cells_info[fqcn] = ci - add_authentication_headers(message, cls._site_name, cls._auth_token, cls._token_signature) + set_add_auth_headers_filters(cell, ci.site_name, ci.auth_token, ci.token_signature) + return ci def __init__( self, @@ -203,9 +209,9 @@ def __init__( self.site_name = site_name self.token = token - self.root_url = root_url self.secure_mode = secure_mode self.workspace_dir = workspace_dir + self.root_url = root_url # this section is needed by job config to prevent building cell when using SystemVarName arguments # TODO: enhance this part @@ -219,8 +225,49 @@ def __init__( check_str("site_name", site_name) check_str("workspace_dir", workspace_dir) + # determine the endpoint for this pipe to connect to + root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS) + + if root_conn_props: + # Not in simulator + if not isinstance(root_conn_props, dict): + raise RuntimeError(f"expect root_conn_props for {site_name} to be dict but got {type(root_conn_props)}") + + cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS) + if cp_conn_props: + if not isinstance(cp_conn_props, dict): + raise RuntimeError(f"expect cp_conn_props to be dict but got {type(cp_conn_props)}") + + url_to_conns = { + root_conn_props.get(ConnPropKey.URL): root_conn_props, + cp_conn_props.get(ConnPropKey.URL): cp_conn_props, + } + + relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS) + if relay_conn_props: + if not isinstance(relay_conn_props, dict): + raise RuntimeError(f"expect relay_conn_props to be dict but got {type(relay_conn_props)}") + url_to_conns[relay_conn_props.get(ConnPropKey.URL)] = relay_conn_props + + if not root_url: + # root_url not specified - use CP! + root_url = cp_conn_props.get(ConnPropKey.URL) + self.root_url = root_url + + conn_props = url_to_conns.get(self.root_url) + if not conn_props: + raise RuntimeError(f"cannot determine conn props for '{root_url}'") + else: + # this is running in simulator + conn_props = { + ConnPropKey.URL: root_url, + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + } + mode = f"{mode}".strip().lower() # convert to lower case string - self.ci = self._build_cell(mode, root_url, site_name, token, secure_mode, workspace_dir) + fqcn = _cell_fqcn(mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) + + self.ci = self._build_cell(site_name, fqcn, conn_props, secure_mode, workspace_dir, self.logger) self.cell = self.ci.cell self.ci.add_pipe(self) @@ -231,7 +278,7 @@ def __init__( else: raise ValueError(f"invalid mode {mode} - must be 'active' or 'passive'") - self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token) + self.peer_fqcn = _cell_fqcn(peer_mode, site_name, token, conn_props.get(ConnPropKey.FQCN)) self.received_msgs = queue.Queue() # contains Message(s), not CellMessage(s)! self.channel = None # the cellnet message channel self.pipe_lock = threading.Lock() # used to ensure no msg to be sent after closed @@ -366,16 +413,14 @@ def close(self): def export(self, export_mode: str) -> Tuple[str, dict]: if export_mode == ExportMode.SELF: mode = self.mode - root_url = self.root_url else: mode = Mode.ACTIVE if self.mode == Mode.PASSIVE else Mode.PASSIVE - root_url = self.cell.get_root_url_for_child() export_args = { "mode": mode, "site_name": self.site_name, "token": self.token, - "root_url": root_url, + "root_url": self.root_url, "secure_mode": self.cell.core_cell.secure, "workspace_dir": self.workspace_dir, } diff --git a/nvflare/fuel/utils/url_utils.py b/nvflare/fuel/utils/url_utils.py new file mode 100644 index 0000000000..e80b7fbfa9 --- /dev/null +++ b/nvflare/fuel/utils/url_utils.py @@ -0,0 +1,95 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_SECURE_SCHEME_MAPPING = {"tcp": "stcp", "grpc": "grpcs", "http": "https"} +_CLEAR_SCHEME_MAPPING = {"stcp": "tcp", "grpcs": "grpc", "https": "http"} + + +def make_url(scheme: str, address, secure: bool) -> str: + """Make a full URL based on specified info + + Args: + scheme: scheme of the url + address: host address. Multiple formats are supported: + str: this is a string that contains host name and optionally port number (e.g. localhost:1234) + dict: contains item "host" and optionally "port" + tuple or list: contains 1 or 2 items for host and port + secure: whether secure connection is required + + Returns: + + """ + if secure: + if scheme in _SECURE_SCHEME_MAPPING.values(): + # already secure scheme + secure_scheme = scheme + else: + secure_scheme = _SECURE_SCHEME_MAPPING.get(scheme) + + if not secure_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + + scheme = secure_scheme + else: + if scheme in _CLEAR_SCHEME_MAPPING.values(): + # already clear scheme + clear_scheme = scheme + else: + clear_scheme = _CLEAR_SCHEME_MAPPING.get(scheme) + + if not clear_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + + scheme = clear_scheme + + if isinstance(address, str): + if not address: + raise ValueError("address must not be empty") + return f"{scheme}://{address}" + else: + port = None + if isinstance(address, (tuple, list)): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + host = address[0] + if len(address) > 1: + port = address[1] + elif isinstance(address, dict): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + + host = address.get("host") + if not host: + raise ValueError(f"invalid address {address}: missing 'host'") + + port = address.get("port") + if not port and len(address) > 1: + raise ValueError(f"invalid address {address}: missing 'port'") + else: + raise ValueError(f"invalid address: {address}") + + if not isinstance(host, str): + raise ValueError(f"invalid host '{host}': must be str but got {type(host)}") + + if port: + if not isinstance(port, (str, int)): + raise ValueError(f"invalid port '{port}': must be str or int but got {type(port)}") + port_str = f":{port}" + else: + port_str = "" + return f"{scheme}://{host}{port_str}" diff --git a/nvflare/job_config/script_runner.py b/nvflare/job_config/script_runner.py index d39f65db61..7b3ca4f40b 100644 --- a/nvflare/job_config/script_runner.py +++ b/nvflare/job_config/script_runner.py @@ -14,6 +14,7 @@ from typing import Optional, Type, Union +from nvflare.apis.fl_constant import SystemVarName from nvflare.app_common.abstract.launcher import Launcher from nvflare.app_common.executors.client_api_launcher_executor import ClientAPILauncherExecutor from nvflare.app_common.executors.in_process_client_api_executor import InProcessClientAPIExecutor @@ -22,8 +23,9 @@ from nvflare.app_common.widgets.metric_relay import MetricRelay from nvflare.client.config import ExchangeFormat, TransferType from nvflare.fuel.utils.import_utils import optional_import -from nvflare.fuel.utils.pipe.cell_pipe import CellPipe +from nvflare.fuel.utils.pipe.cell_pipe import CellPipe, Mode from nvflare.fuel.utils.pipe.pipe import Pipe +from nvflare.fuel.utils.validation_utils import check_str from .api import FedJob, validate_object_for_job @@ -35,6 +37,19 @@ class FrameworkType: TENSORFLOW = "tensorflow" +class PipeConnectType: + VIA_ROOT = "via_root" + VIA_CP = "via_cp" + VIA_RELAY = "via_relay" + + +_PIPE_CONNECT_URL = { + PipeConnectType.VIA_CP: "{" + SystemVarName.CP_URL + "}", + PipeConnectType.VIA_RELAY: "{" + SystemVarName.RELAY_URL + "}", + PipeConnectType.VIA_ROOT: "{" + SystemVarName.ROOT_URL + "}", +} + + class BaseScriptRunner: def __init__( self, @@ -52,6 +67,7 @@ def __init__( launcher: Optional[Launcher] = None, metric_relay: Optional[MetricRelay] = None, metric_pipe: Optional[Pipe] = None, + pipe_connect_type: str = None, ): """BaseScriptRunner is used with FedJob API to run or launch a script. @@ -102,6 +118,12 @@ def __init__( metric_pipe (Optional[Pipe], optional): An optional Pipe instance for passing metric data between components. This allows for real-time metric handling during execution. Defaults to `None`. + + pipe_connect_type: how pipe peers are to be connected: + Via Root: peers are both connected to the root of the cellnet + Via Relay: peers are both connected to the relay if a relay is used; otherwise via root. + Via CP: peers are both connected to the CP + If not specified, will be via CP. """ self._script = script self._script_args = script_args @@ -112,6 +134,7 @@ def __init__( self._params_exchange_format = params_exchange_format self._from_nvflare_converter_id = from_nvflare_converter_id self._to_nvflare_converter_id = to_nvflare_converter_id + self._pipe_connect_type = pipe_connect_type if self._framework == FrameworkType.PYTORCH: _, torch_ok = optional_import(module="torch") @@ -151,12 +174,35 @@ def __init__( elif executor is not None: validate_object_for_job("executor", executor, InProcessClientAPIExecutor) + if pipe_connect_type: + check_str("pipe_connect_type", pipe_connect_type) + valid_connect_types = [PipeConnectType.VIA_CP, PipeConnectType.VIA_RELAY, PipeConnectType.VIA_RELAY] + if pipe_connect_type not in valid_connect_types: + raise ValueError(f"invalid pipe_connect_type '{pipe_connect_type}': must be {valid_connect_types}") + self._metric_pipe = metric_pipe self._metric_relay = metric_relay self._task_pipe = task_pipe self._executor = executor self._launcher = launcher + def _create_cell_pipe(self): + ct = self._pipe_connect_type + if not ct: + ct = PipeConnectType.VIA_CP + conn_url = _PIPE_CONNECT_URL.get(ct) + if not conn_url: + raise RuntimeError(f"cannot determine pipe connect url for {self._pipe_connect_type}") + + return CellPipe( + mode=Mode.PASSIVE, + site_name="{" + SystemVarName.SITE_NAME + "}", + token="{" + SystemVarName.JOB_ID + "}", + root_url=conn_url, + secure_mode="{" + SystemVarName.SECURE_MODE + "}", + workspace_dir="{" + SystemVarName.WORKSPACE + "}", + ) + def add_to_fed_job(self, job: FedJob, ctx, **kwargs): """This method is used by Job API. @@ -172,18 +218,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): comp_ids = {} if self._launch_external_process: - task_pipe = ( - self._task_pipe - if self._task_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + task_pipe = self._task_pipe if self._task_pipe else self._create_cell_pipe() task_pipe_id = job.add_component("pipe", task_pipe, ctx) comp_ids["pipe_id"] = task_pipe_id @@ -211,18 +246,7 @@ def add_to_fed_job(self, job: FedJob, ctx, **kwargs): ) job.add_executor(executor, tasks=tasks, ctx=ctx) - metric_pipe = ( - self._metric_pipe - if self._metric_pipe - else CellPipe( - mode="PASSIVE", - site_name="{SITE_NAME}", - token="{JOB_ID}", - root_url="{ROOT_URL}", - secure_mode="{SECURE_MODE}", - workspace_dir="{WORKSPACE}", - ) - ) + metric_pipe = self._metric_pipe if self._metric_pipe else self._create_cell_pipe() metric_pipe_id = job.add_component("metrics_pipe", metric_pipe, ctx) comp_ids["metric_pipe_id"] = metric_pipe_id @@ -295,6 +319,7 @@ def __init__( framework: FrameworkType = FrameworkType.PYTORCH, params_exchange_format: ExchangeFormat = ExchangeFormat.NUMPY, params_transfer_type: str = TransferType.FULL, + pipe_connect_type: str = PipeConnectType.VIA_CP, ): """ScriptRunner is used with FedJob API to run or launch a script. @@ -310,6 +335,7 @@ def __init__( params_exchange_format (str): The format to exchange the parameters. Defaults to ExchangeFormat.NUMPY. params_transfer_type (str): How to transfer the parameters. FULL means the whole model parameters are sent. DIFF means that only the difference is sent. Defaults to TransferType.FULL. + pipe_connect_type (str): how pipe peers are to be connected """ super().__init__( script=script, @@ -319,4 +345,5 @@ def __init__( framework=framework, params_exchange_format=params_exchange_format, params_transfer_type=params_transfer_type, + pipe_connect_type=pipe_connect_type, ) diff --git a/nvflare/lighter/constants.py b/nvflare/lighter/constants.py index 99d765f770..76131ff14c 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from nvflare.apis.fl_constant import ConnectionSecurity class WorkDir: @@ -66,10 +67,9 @@ class ProvisionMode: class ConnSecurity: - CLEAR = "clear" - INSECURE = "insecure" - TLS = "tls" - MTLS = "mtls" + CLEAR = ConnectionSecurity.CLEAR + TLS = ConnectionSecurity.TLS + MTLS = ConnectionSecurity.MTLS class AdminRole: diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index c80208ba6a..7c2f00dce8 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -115,7 +115,7 @@ def _build_overseer(self, overseer: Participant, ctx: ProvisionContext): @staticmethod def _build_conn_properties(site: Participant, ctx: ProvisionContext, site_config: dict): - valid_values = [ConnSecurity.CLEAR, ConnSecurity.INSECURE, ConnSecurity.TLS, ConnSecurity.MTLS] + valid_values = [ConnSecurity.CLEAR, ConnSecurity.TLS, ConnSecurity.MTLS] conn_security = site.get_prop_fb(PropKey.CONN_SECURITY) if conn_security: assert isinstance(conn_security, str) @@ -124,8 +124,6 @@ def _build_conn_properties(site: Participant, ctx: ProvisionContext, site_config if conn_security not in valid_values: raise ValueError(f"invalid connection_security '{conn_security}': must be in {valid_values}") - if conn_security in [ConnSecurity.CLEAR, ConnSecurity.INSECURE]: - conn_security = ConnSecurity.INSECURE site_config["connection_security"] = conn_security custom_ca_cert = site.get_prop_fb(PropKey.CUSTOM_CA_CERT) diff --git a/nvflare/private/defs.py b/nvflare/private/defs.py index 0a452e262a..3a2af90122 100644 --- a/nvflare/private/defs.py +++ b/nvflare/private/defs.py @@ -143,11 +143,12 @@ class AppFolderConstants: class CellMessageHeaderKeys: CLIENT_NAME = CellMessageAuthHeaderKey.CLIENT_NAME + CLIENT_TYPE = "client_type" TOKEN = CellMessageAuthHeaderKey.TOKEN TOKEN_SIGNATURE = CellMessageAuthHeaderKey.TOKEN_SIGNATURE CLIENT_IP = "client_ip" PROJECT_NAME = "project_name" - SSID = "ssid" + SSID = CellMessageAuthHeaderKey.SSID UNAUTHENTICATED = "unauthenticated" JOB_ID = "job_id" JOB_IDS = "job_ids" @@ -155,6 +156,11 @@ class CellMessageHeaderKeys: ABORT_JOBS = "abort_jobs" +class ClientType: + RELAY = "relay" + REGULAR = "regular" + + AUTH_CLIENT_NAME_FOR_SJ = "server_job" diff --git a/nvflare/private/fed/app/client/worker_process.py b/nvflare/private/fed/app/client/worker_process.py index b43133abe5..4aa8bea5a6 100644 --- a/nvflare/private/fed/app/client/worker_process.py +++ b/nvflare/private/fed/app/client/worker_process.py @@ -154,6 +154,14 @@ def parse_arguments(): parser.add_argument("--sp_target", "-g", type=str, help="Sp target", required=True) parser.add_argument("--sp_scheme", "-scheme", type=str, help="Sp connection scheme", required=True) parser.add_argument("--parent_url", "-p", type=str, help="parent_url", required=True) + parser.add_argument( + "--parent_conn_sec", + "-pcs", + type=str, + help="parent conn security", + required=False, + default="", + ) parser.add_argument( "--fed_client", "-s", type=str, help="an aggregation server specification json file", required=True ) diff --git a/nvflare/private/fed/app/deployer/base_client_deployer.py b/nvflare/private/fed/app/deployer/base_client_deployer.py index b4c684c093..353d63ce20 100644 --- a/nvflare/private/fed/app/deployer/base_client_deployer.py +++ b/nvflare/private/fed/app/deployer/base_client_deployer.py @@ -45,8 +45,9 @@ def build(self, build_ctx): self.components = build_ctx["client_components"] self.handlers = build_ctx["client_handlers"] - def set_model_manager(self, model_manager): - self.model_manager = model_manager + relay_config = build_ctx.get("relay_config") + if relay_config: + self.client_config["relay_config"] = relay_config def create_fed_client(self, args, sp_target=None): if sp_target: diff --git a/nvflare/private/fed/app/fl_conf.py b/nvflare/private/fed/app/fl_conf.py index 7a60876637..d87a2d9b48 100644 --- a/nvflare/private/fed/app/fl_conf.py +++ b/nvflare/private/fed/app/fl_conf.py @@ -19,11 +19,14 @@ import sys from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FilterKey, SiteType, SystemConfigs +from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, FilterKey, SiteType, SystemConfigs from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node +from nvflare.fuel.utils.url_utils import make_url from nvflare.fuel.utils.wfconf import ConfigContext, ConfigError from nvflare.private.defs import SSLConstants from nvflare.private.json_configer import JsonConfigurator @@ -225,6 +228,8 @@ def __init__(self, workspace: Workspace, args, kv_list=None): config_files = workspace.get_config_files_for_startup(is_server=False, for_job=True if args.job_id else False) + print(f"got all config files: {config_files}") + JsonConfigurator.__init__( self, config_file_name=config_files, @@ -283,6 +288,72 @@ def build_component(self, config_dict): self.handlers.append(t) return t + def _determine_conn_props(self, client_name, config_data: dict): + relay_fqcn = None + relay_url = None + relay_conn_security = None + + # relay info is set in the client's relay__resources.json. + # If relay is used, then connect via the specified relay; if not, try to connect the Server directly + print(f"Config data: {config_data=}") + print(f"Args: {self.args=}") + relay_config = config_data.get(ConnPropKey.RELAY_CONFIG) + self.logger.info(f"got relay config: {relay_config}") + if relay_config: + if relay_config: + relay_fqcn = relay_config.get(ConnPropKey.FQCN) + scheme = relay_config.get(ConnPropKey.SCHEME) + addr = relay_config.get(ConnPropKey.ADDRESS) + relay_conn_security = relay_config.get(ConnPropKey.CONNECTION_SECURITY) + secure = True + if relay_conn_security == ConnectionSecurity.CLEAR: + secure = False + relay_url = make_url(scheme, addr, secure) + print(f"connect to server via relay: {relay_url=} {relay_fqcn=}") + else: + print("no relay defined: connect to server directly") + else: + print("no relay_config: connect to server directly") + + if relay_fqcn: + cp_fqcn = FQCN.join([relay_fqcn, client_name]) + else: + cp_fqcn = client_name + + if relay_fqcn: + relay_conn_props = { + ConnPropKey.FQCN: relay_fqcn, + ConnPropKey.URL: relay_url, + ConnPropKey.CONNECTION_SECURITY: relay_conn_security, + } + set_scope_property(client_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props) + + client = self.config_data["client"] + + if hasattr(self.args, "job_id") and self.args.job_id: + # this is CJ + sp_scheme = self.args.sp_scheme + sp_target = self.args.sp_target + root_url = f"{sp_scheme}://{sp_target}" + root_conn_props = { + ConnPropKey.FQCN: FQCN.ROOT_SERVER, + ConnPropKey.URL: root_url, + ConnPropKey.CONNECTION_SECURITY: client.get(ConnPropKey.CONNECTION_SECURITY), + } + set_scope_property(client_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props) + + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + ConnPropKey.URL: self.args.parent_url, + ConnPropKey.CONNECTION_SECURITY: self.args.parent_conn_sec, + } + else: + # this is CP + cp_conn_props = { + ConnPropKey.FQCN: cp_fqcn, + } + set_scope_property(client_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props) + def start_config(self, config_ctx: ConfigContext): """Start the config process. @@ -301,6 +372,17 @@ def start_config(self, config_ctx: ConfigContext): client[SSLConstants.CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.CERT]) if client.get(SSLConstants.ROOT_CERT): client[SSLConstants.ROOT_CERT] = self.workspace.get_file_path_in_startup(client[SSLConstants.ROOT_CERT]) + + client_name = self.cmd_vars.get("uid", None) + if not client_name: + raise ConfigError("missing 'uid' from command args") + + conn_sec = client.get(ConnPropKey.CONNECTION_SECURITY) + if conn_sec: + set_scope_property(client_name, ConnPropKey.CONNECTION_SECURITY, conn_sec) + + self._determine_conn_props(client_name, self.config_data) + except Exception: raise ValueError(f"Client config error: '{self.client_config_file_names}'") diff --git a/nvflare/private/fed/app/relay/relay.py b/nvflare/private/fed/app/relay/relay.py new file mode 100644 index 0000000000..1fc2d4a0da --- /dev/null +++ b/nvflare/private/fed/app/relay/relay.py @@ -0,0 +1,224 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import json +import logging +import os +import sys +import threading + +from nvflare.apis.fl_constant import ConnectionSecurity, ConnPropKey, ReservedKey, WorkspaceConstants +from nvflare.apis.fl_context import FLContext +from nvflare.apis.signal import Signal +from nvflare.apis.utils.decomposers import flare_decomposers +from nvflare.apis.workspace import Workspace +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.cellnet.net_agent import NetAgent +from nvflare.fuel.f3.drivers.driver_params import DriverParams +from nvflare.fuel.f3.drivers.net_utils import SSL_ROOT_CERT, enhance_credential_info +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm +from nvflare.fuel.sec.authn import set_add_auth_headers_filters +from nvflare.fuel.utils.argument_utils import parse_vars +from nvflare.fuel.utils.config_service import ConfigService, search_file +from nvflare.fuel.utils.log_utils import configure_logging +from nvflare.fuel.utils.url_utils import make_url +from nvflare.private.defs import ClientType +from nvflare.private.fed.authenticator import Authenticator, validate_auth_headers +from nvflare.private.fed.utils.identity_utils import TokenVerifier + + +class CellnetMonitor: + def __init__(self, stop_event: threading.Event, workspace: str): + self.stop_event = stop_event + self.workspace = workspace + + def cellnet_stopped(self): + touch_file = os.path.join(self.workspace, WorkspaceConstants.SHUTDOWN_FILE) + with open(touch_file, "a"): + os.utime(touch_file, None) + self.stop_event.set() + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--workspace", "-m", type=str, help="WORKSPACE folder", required=True) + parser.add_argument("--relay_config", "-s", type=str, help="relay config json file", required=True) + parser.add_argument("--set", metavar="KEY=VALUE", nargs="*") + args = parser.parse_args() + return args + + +def main(args): + workspace = Workspace(root_dir=args.workspace) + for name in [WorkspaceConstants.RESTART_FILE, WorkspaceConstants.SHUTDOWN_FILE]: + try: + f = workspace.get_file_path_in_root(name) + if os.path.exists(f): + os.remove(f) + except Exception as ex: + print(f"Could not remove file '{name}': {ex}. Please check your system before starting FL.") + sys.exit(-1) + + configure_logging(workspace, workspace.get_root_dir()) + logger = logging.getLogger() + + relay_config_file = workspace.get_file_path_in_startup(args.relay_config) + with open(relay_config_file, "rt") as f: + relay_config = json.load(f) + + if not isinstance(relay_config, dict): + raise RuntimeError(f"invalid relay config file {args.relay_config}") + + project_name = relay_config.get(ConnPropKey.PROJECT_NAME) + if not project_name: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing {ConnPropKey.PROJECT_NAME}") + + server_identity = relay_config.get(ConnPropKey.SERVER_IDENTITY) + if not server_identity: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing {ConnPropKey.SERVER_IDENTITY}") + + my_identity = relay_config.get(ConnPropKey.IDENTITY) + if not my_identity: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing {ConnPropKey.IDENTITY}") + + parent = relay_config.get(ConnPropKey.PARENT) + if not parent: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing {ConnPropKey.PARENT}") + + parent_address = parent.get(ConnPropKey.ADDRESS) + if not parent_address: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.address") + + parent_scheme = parent.get(ConnPropKey.SCHEME) + if not parent_scheme: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.scheme") + + parent_fqcn = parent.get(ConnPropKey.FQCN) + if not parent_fqcn: + raise RuntimeError(f"invalid relay config file {args.relay_config}: missing parent.fqcn") + + cmd_vars = parse_vars(args.set) + secure_train = cmd_vars.get("secure_train", False) + logger.info(f"{cmd_vars=} {secure_train=}") + + stop_event = threading.Event() + monitor = CellnetMonitor(stop_event, args.workspace) + + ConfigService.initialize( + section_files={}, + config_path=[args.workspace], + ) + + root_cert_path = search_file(SSL_ROOT_CERT, args.workspace) + if not root_cert_path: + raise ValueError(f"cannot find {SSL_ROOT_CERT} from config path {args.workspace}") + + credentials = { + DriverParams.CA_CERT.value: root_cert_path, + } + enhance_credential_info(credentials) + + logger.info(f"{credentials=}") + + conn_security = parent.get(ConnPropKey.CONNECTION_SECURITY) + secure_conn = True + if conn_security: + credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security + if conn_security == ConnectionSecurity.CLEAR: + secure_conn = False + parent_url = make_url(parent_scheme, parent_address, secure_conn) + + if parent_fqcn == FQCN.ROOT_SERVER: + my_fqcn = my_identity + root_url = parent_url + parent_url = None + else: + my_fqcn = FQCN.join([parent_fqcn, my_identity]) + root_url = None + + flare_decomposers.register() + + cell = Cell( + fqcn=my_fqcn, + root_url=root_url, + secure=secure_conn, + credentials=credentials, + create_internal_listener=True, + parent_url=parent_url, + ) + NetAgent(cell, agent_closed_cb=monitor.cellnet_stopped) + cell.start() + + # authenticate + authenticator = Authenticator( + cell=cell, + project_name=project_name, + client_name=my_identity, + client_type=ClientType.RELAY, + expected_sp_identity=server_identity, + secure_mode=secure_train, + root_cert_file=credentials.get(DriverParams.CA_CERT.value), + private_key_file=credentials.get(DriverParams.CLIENT_KEY.value), + cert_file=credentials.get(DriverParams.CLIENT_CERT.value), + msg_timeout=5.0, + retry_interval=2.0, + ) + + abort_signal = Signal() + shared_fl_ctx = FLContext() + shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: my_identity}) + token, token_signature, ssid, token_verifier = authenticator.authenticate( + shared_fl_ctx=shared_fl_ctx, + abort_signal=abort_signal, + ) + + if secure_train: + if not isinstance(token_verifier, TokenVerifier): + raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}") + + set_add_auth_headers_filters(cell, my_identity, token, token_signature, ssid) + + cell.core_cell.add_incoming_filter( + channel="*", + topic="*", + cb=_validate_auth_headers, + token_verifier=token_verifier, + logger=logger, + ) + + logger.info(f"Successfully authenticated to {server_identity}: {token=} {ssid=}") + + # wait until stopped + logger.info(f"Started relay {my_identity=} {my_fqcn=} {root_url=} {parent_url=} {parent_fqcn=}") + stop_event.wait() + cell.stop() + logger.info(f"Relay {my_fqcn} stopped.") + + +def _validate_auth_headers(message: CellMessage, token_verifier: TokenVerifier, logger): + """Validate auth headers from messages that go through the server. + Args: + message: the message to validate + Returns: + """ + return validate_auth_headers(message, token_verifier, logger) + + +if __name__ == "__main__": + args = parse_arguments() + rc = mpm.run(main_func=main, run_dir=args.workspace, args=args) + sys.exit(rc) diff --git a/nvflare/private/fed/app/server/runner_process.py b/nvflare/private/fed/app/server/runner_process.py index c0ac5aa576..f93712e4a5 100644 --- a/nvflare/private/fed/app/server/runner_process.py +++ b/nvflare/private/fed/app/server/runner_process.py @@ -23,13 +23,12 @@ from nvflare.apis.fl_constant import ConfigVarName, JobConstants, SiteType, SystemConfigs from nvflare.apis.workspace import Workspace from nvflare.fuel.common.excepts import ConfigError -from nvflare.fuel.f3.message import Message as CellMessage from nvflare.fuel.f3.mpm import MainProcessMonitor as mpm -from nvflare.fuel.sec.authn import add_authentication_headers +from nvflare.fuel.sec.authn import set_add_auth_headers_filters from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.log_utils import configure_logging, get_script_logger -from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ, AppFolderConstants, CellMessageHeaderKeys +from nvflare.private.defs import AUTH_CLIENT_NAME_FOR_SJ, AppFolderConstants from nvflare.private.fed.app.fl_conf import FLServerStarterConfiger from nvflare.private.fed.app.utils import monitor_parent_process from nvflare.private.fed.server.server_app_runner import ServerAppRunner @@ -112,8 +111,13 @@ def main(args): ) # set filter to add additional auth headers - server.cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) - server.cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=_add_auth_headers, config=args) + set_add_auth_headers_filters( + cell=server.cell, + client_name=AUTH_CLIENT_NAME_FOR_SJ, + auth_token=args.job_id, + token_signature=args.token_signature, + ssid=args.ssid, + ) server.server_state = HotState(host=args.host, port=args.port, ssid=args.ssid) @@ -145,16 +149,6 @@ def main(args): raise e -def _add_auth_headers(message: CellMessage, config): - message.set_header(CellMessageHeaderKeys.SSID, config.ssid) - add_authentication_headers( - message, - client_name=AUTH_CLIENT_NAME_FOR_SJ, - auth_token=config.job_id, - token_signature=config.token_signature, - ) - - def parse_arguments(): """FL Server program starting point.""" parser = argparse.ArgumentParser() diff --git a/nvflare/private/fed/authenticator.py b/nvflare/private/fed/authenticator.py new file mode 100644 index 0000000000..8a905b4963 --- /dev/null +++ b/nvflare/private/fed/authenticator.py @@ -0,0 +1,322 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import socket +import time +import traceback +import uuid + +from nvflare.apis.fl_constant import ServerCommandKey +from nvflare.apis.fl_context import FLContext +from nvflare.apis.fl_exception import FLCommunicationError +from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal +from nvflare.fuel.f3.cellnet.cell import Cell +from nvflare.fuel.f3.cellnet.core_cell import make_reply as make_cellnet_reply +from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey +from nvflare.fuel.f3.cellnet.defs import ReturnCode +from nvflare.fuel.f3.cellnet.defs import ReturnCode as F3ReturnCode +from nvflare.fuel.f3.cellnet.fqcn import FQCN +from nvflare.fuel.f3.message import Message +from nvflare.fuel.f3.message import Message as CellMessage +from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, new_cell_message +from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, TokenVerifier, load_crt_bytes + + +def _get_client_ip(): + """Return localhost IP. + + More robust than ``socket.gethostbyname(socket.gethostname())``. See + https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib/28950776#28950776 + for more details. + + Returns: + The host IP + + """ + s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + try: + s.connect(("10.255.255.255", 1)) # doesn't even have to be reachable + ip = s.getsockname()[0] + except Exception: + ip = "127.0.0.1" + finally: + s.close() + return ip + + +class Authenticator: + def __init__( + self, + cell: Cell, + project_name: str, + client_name: str, + client_type: str, + expected_sp_identity: str, + secure_mode: bool, + root_cert_file: str, + private_key_file: str, + cert_file: str, + msg_timeout: float, + retry_interval: float, + ): + """Authenticator is to be used to register a client to the Server. + + Args: + cell: the communication cell + project_name: name of the project + client_name: name of the client + client_type: type of the client: regular or relay + expected_sp_identity: identity of the service provider (i.e. server) + secure_mode: whether the project is in secure training mode + root_cert_file: file path of the root cert + private_key_file: file path of the private key + cert_file: file path of the client's certificate + msg_timeout: timeout for authentication messages + retry_interval: interval between tries + """ + self.cell = cell + self.project_name = project_name + self.client_name = client_name + self.client_type = client_type + self.expected_sp_identity = expected_sp_identity + self.root_cert_file = root_cert_file + self.private_key_file = private_key_file + self.cert_file = cert_file + self.msg_timeout = msg_timeout + self.retry_interval = retry_interval + self.secure_mode = secure_mode + self.logger = get_obj_logger(self) + + def _challenge_server(self): + # ask server for its info and make sure that it matches expected host + my_nonce = str(uuid.uuid4()) + headers = {IdentityChallengeKey.COMMON_NAME: self.client_name, IdentityChallengeKey.NONCE: my_nonce} + challenge = new_cell_message(headers, None) + result = self.cell.send_request( + target=FQCN.ROOT_SERVER, + channel=CellChannel.SERVER_MAIN, + topic=CellChannelTopic.Challenge, + request=challenge, + timeout=self.msg_timeout, + ) + return_code = result.get_header(MessageHeaderKey.RETURN_CODE) + error = result.get_header(MessageHeaderKey.ERROR, "") + self.logger.info(f"challenge result: {return_code} {error}") + if return_code != ReturnCode.OK: + if return_code in [ReturnCode.TARGET_UNREACHABLE, ReturnCode.COMM_ERROR]: + # trigger retry + return None, None + err = result.get_header(MessageHeaderKey.ERROR, "") + raise FLCommunicationError(f"failed to challenge server: {return_code}: {err}") + + reply = result.payload + assert isinstance(reply, Shareable) + server_nonce = reply.get(IdentityChallengeKey.NONCE) + cert_bytes = reply.get(IdentityChallengeKey.CERT) + server_cert = load_crt_bytes(cert_bytes) + server_signature = reply.get(IdentityChallengeKey.SIGNATURE) + server_cn = reply.get(IdentityChallengeKey.COMMON_NAME) + + if server_cn != self.expected_sp_identity: + raise FLCommunicationError( + f"expected server identity is '{self.expected_sp_identity}' but got '{server_cn}'" + ) + + # Use IdentityVerifier to validate: + # - the server cert can be validated with the root cert. Note that all sites have the same root cert! + # - the asserted CN matches the CN on the server cert + # - signature received from the server is valid + id_verifier = IdentityVerifier(root_cert_file=self.root_cert_file) + id_verifier.verify_common_name( + asserter_cert=server_cert, asserted_cn=server_cn, nonce=my_nonce, signature=server_signature + ) + + self.logger.info(f"verified server identity '{self.expected_sp_identity}'") + return server_nonce, TokenVerifier(server_cert) + + def authenticate(self, shared_fl_ctx: FLContext, abort_signal: Signal): + """Register the client with the FLARE Server. + + Note that the client no longer needs to be directly connected with the Server! + + Since the client may be connected with the Server indirectly (e.g. via bridge nodes or proxy), in the secure + mode, the client authentication cannot be based on the connection's TLS cert. Instead, the server and the + client will explicitly authenticate each other using their provisioned PKI credentials, as follows: + + 1. Make sure that the Server is authentic. The client sends a Challenge request with a random nonce. + The server is expected to return the following in its reply: + - its cert and common name (Server_CN) + - signature on the received client nonce + Server_CN + - a random Server Nonce. This will be used for the server to validate the client's identity in the + Registration request. + + The client then validates to make sure: + - the Server_CN is the same as presented in the server cert + - the Server_CN is the same as configured in the client's config (fed_client.json) + - the signature is valid + + 2. Client sends Registration request that contains: + - client cert and common name (Client_CN) + - signature on the received Server Nonce + Client_CN + + The Server then validates to make sure: + - the Client_CN is the same as presented in the client cert + - the signature is valid + + NOTE: we do not explicitly validate certs' expiration time. This is because currently the same certs are + also used for SSL connections, which already validate expiration. + + Args: + fl_ctx: FLContext + + Returns: + The client's token + + """ + local_ip = _get_client_ip() + shareable = Shareable() + shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + + token_verifier = None + if self.secure_mode: + # explicitly authenticate with the Server + while True: + server_nonce, token_verifier = self._challenge_server() + + if abort_signal.triggered: + return None, None, None, None + + if server_nonce is None: + # retry + self.logger.info(f"re-challenge after {self.retry_interval} seconds") + time.sleep(self.retry_interval) + else: + break + + id_asserter = IdentityAsserter(private_key_file=self.private_key_file, cert_file=self.cert_file) + cn_signature = id_asserter.sign_common_name(nonce=server_nonce) + shareable[IdentityChallengeKey.CERT] = id_asserter.cert_data + shareable[IdentityChallengeKey.SIGNATURE] = cn_signature + shareable[IdentityChallengeKey.COMMON_NAME] = id_asserter.cn + self.logger.info(f"sent identity info for client {self.client_name}") + + headers = { + CellMessageHeaderKeys.CLIENT_NAME: self.client_name, + CellMessageHeaderKeys.CLIENT_TYPE: self.client_type, + CellMessageHeaderKeys.CLIENT_IP: local_ip, + CellMessageHeaderKeys.PROJECT_NAME: self.project_name, + } + login_message = new_cell_message(headers, shareable) + + self.logger.info("Trying to register with server ...") + while True: + try: + result = self.cell.send_request( + target=FQCN.ROOT_SERVER, + channel=CellChannel.SERVER_MAIN, + topic=CellChannelTopic.Register, + request=login_message, + timeout=self.msg_timeout, + ) + + if not isinstance(result, Message): + raise FLCommunicationError(f"expect result to be Message but got {type(result)}") + + return_code = result.get_header(MessageHeaderKey.RETURN_CODE) + self.logger.info(f"register RC: {return_code}") + if return_code == ReturnCode.UNAUTHENTICATED: + reason = result.get_header(MessageHeaderKey.ERROR) + self.logger.error(f"registration rejected: {reason}") + raise FLCommunicationError("error:client_registration " + reason) + + payload = result.payload + if not isinstance(payload, dict): + raise FLCommunicationError(f"expect payload to be dict but got {type(payload)}") + + token = payload.get(CellMessageHeaderKeys.TOKEN) + token_signature = payload.get(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA") + ssid = payload.get(CellMessageHeaderKeys.SSID) + if not token and not abort_signal.triggered: + time.sleep(self.retry_interval) + else: + break + + except Exception as ex: + traceback.print_exc() + raise FLCommunicationError("error:client_registration", ex) + + # make sure token_verifier works + if token_verifier: + if not isinstance(token_verifier, TokenVerifier): + raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}") + + if token_verifier and token_signature: + valid = token_verifier.verify(client_name=self.client_name, token=token, signature=token_signature) + if valid: + self.logger.info("Verified received token and signature successfully") + else: + raise RuntimeError("invalid token or verifier!") + + return token, token_signature, ssid, token_verifier + + +def validate_auth_headers(message: CellMessage, token_verifier: TokenVerifier, logger): + """Validate auth headers from messages that go through the server. + + Args: + message: the message to validate + token_verifier: the TokenVerifier to be used to verify the token and signature + + Returns: + """ + headers = message.headers + logger.debug(f"**** _validate_auth_headers: {headers=}") + topic = message.get_header(MessageHeaderKey.TOPIC) + channel = message.get_header(MessageHeaderKey.CHANNEL) + + origin = message.get_header(MessageHeaderKey.ORIGIN) + + if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN: + # skip: client not registered yet + logger.debug(f"skip special message {topic=} {channel=}") + return None + + client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) + err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" + if not client_name: + err = "missing client name" + logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + token = message.get_header(CellMessageHeaderKeys.TOKEN) + if not token: + err = "missing auth token" + logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE) + if not signature: + err = "missing auth token signature" + logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + if not token_verifier.verify(client_name, token, signature): + err = "invalid auth token signature" + logger.error(f"{err_text}: {err}") + return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) + + # all good + logger.debug(f"auth headers valid from {origin}: {topic=} {channel=}") + return None diff --git a/nvflare/private/fed/client/client_executor.py b/nvflare/private/fed/client/client_executor.py index cbf94777d5..c1c178d20b 100644 --- a/nvflare/private/fed/client/client_executor.py +++ b/nvflare/private/fed/client/client_executor.py @@ -17,7 +17,7 @@ from abc import ABC, abstractmethod from nvflare.apis.event_type import EventType -from nvflare.apis.fl_constant import AdminCommandNames, FLContextKey, RunProcessKey, SystemConfigs +from nvflare.apis.fl_constant import AdminCommandNames, ConnPropKey, FLContextKey, RunProcessKey, SystemConfigs from nvflare.apis.fl_context import FLContext from nvflare.apis.job_launcher_spec import JobLauncherSpec, JobProcessArgs from nvflare.apis.resource_manager_spec import ResourceManagerSpec @@ -201,6 +201,13 @@ def start_app( JobProcessArgs.STARTUP_CONFIG_FILE: ("-s", "fed_client.json"), JobProcessArgs.OPTIONS: ("--set", command_options), } + + params = client.cell.get_internal_listener_params() + if params: + parent_conn_sec = params.get(ConnPropKey.CONNECTION_SECURITY) + if parent_conn_sec: + job_args[JobProcessArgs.PARENT_CONN_SEC] = ("-pcs", parent_conn_sec) + fl_ctx.set_prop(key=FLContextKey.JOB_PROCESS_ARGS, value=job_args, private=True, sticky=False) job_handle = job_launcher.launch_job(job_meta, fl_ctx) self.logger.info(f"Launch job_id: {job_id} with job launcher: {type(job_launcher)} ") diff --git a/nvflare/private/fed/client/client_json_config.py b/nvflare/private/fed/client/client_json_config.py index acb66a3b05..2bc028ee50 100644 --- a/nvflare/private/fed/client/client_json_config.py +++ b/nvflare/private/fed/client/client_json_config.py @@ -16,8 +16,9 @@ from nvflare.apis.executor import Executor from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import SystemConfigs, SystemVarName +from nvflare.apis.fl_constant import ConnPropKey, SystemConfigs, SystemVarName from nvflare.apis.workspace import Workspace +from nvflare.fuel.data_event.utils import get_scope_property from nvflare.fuel.utils.argument_utils import parse_vars from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.json_scanner import Node @@ -64,11 +65,27 @@ def __init__( sp_target = args.sp_target sp_url = f"{sp_scheme}://{sp_target}" + # determine relay URL + # if relay is not used, use the root URL as relay URL + relay_conn_props = get_scope_property(args.client_name, ConnPropKey.RELAY_CONN_PROPS) + relay_url = None + if relay_conn_props: + relay_url = relay_conn_props.get(ConnPropKey.URL) + if not relay_url: + relay_url = sp_url + + if hasattr(args, "parent_url") and args.parent_url: + parent_url = args.parent_url + else: + parent_url = sp_url + sys_vars = { SystemVarName.JOB_ID: args.job_id, SystemVarName.SITE_NAME: args.client_name, SystemVarName.WORKSPACE: args.workspace, SystemVarName.ROOT_URL: sp_url, + SystemVarName.CP_URL: parent_url, + SystemVarName.RELAY_URL: relay_url, SystemVarName.SECURE_MODE: self.cmd_vars.get("secure_train", True), SystemVarName.JOB_CUSTOM_DIR: workspace_obj.get_app_custom_dir(args.job_id), } diff --git a/nvflare/private/fed/client/communicator.py b/nvflare/private/fed/client/communicator.py index 50d86cf30e..14a1b6fd67 100644 --- a/nvflare/private/fed/client/communicator.py +++ b/nvflare/private/fed/client/communicator.py @@ -12,10 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -import socket import time -import traceback -import uuid from typing import List, Optional from nvflare.apis.event_type import EventType @@ -26,43 +23,28 @@ from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.shareable import Shareable +from nvflare.apis.signal import Signal from nvflare.apis.utils.fl_context_utils import gen_new_peer_ctx from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell -from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey, ReturnCode +from nvflare.fuel.f3.cellnet.defs import MessageHeaderKey, ReturnCode from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.utils import format_size -from nvflare.fuel.f3.message import Message as CellMessage -from nvflare.fuel.sec.authn import add_authentication_headers +from nvflare.fuel.sec.authn import set_add_auth_headers_filters from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.private.defs import CellChannel, CellChannelTopic, CellMessageHeaderKeys, SpecialTaskName, new_cell_message +from nvflare.private.defs import ( + CellChannel, + CellChannelTopic, + CellMessageHeaderKeys, + ClientType, + SpecialTaskName, + new_cell_message, +) +from nvflare.private.fed.authenticator import Authenticator from nvflare.private.fed.client.client_engine_internal_spec import ClientEngineInternalSpec -from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, load_crt_bytes from nvflare.security.logging import secure_format_exception -def _get_client_ip(): - """Return localhost IP. - - More robust than ``socket.gethostbyname(socket.gethostname())``. See - https://stackoverflow.com/questions/166506/finding-local-ip-addresses-using-pythons-stdlib/28950776#28950776 - for more details. - - Returns: - The host IP - - """ - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - s.connect(("10.255.255.255", 1)) # doesn't even have to be reachable - ip = s.getsockname()[0] - except Exception: - ip = "127.0.0.1" - finally: - s.close() - return ip - - class Communicator: def __init__( self, @@ -88,7 +70,6 @@ def __init__( self.secure_train = secure_train self.verbose = False - self.should_stop = False self.heartbeat_done = False self.client_state_processors = client_state_processors self.compression = compression @@ -102,78 +83,37 @@ def __init__( self.token_signature = None self.ssid = None self.client_name = None + self.token_verifier = None + self.abort_signal = Signal() self.logger = get_obj_logger(self) + """ + To call set_add_auth_headers_filters, both cell and token must be available. + The set_cell is called when cell becomes available, set_auth is called when token becomes available. + In CP, set_cell happens before set_auth, hence we call set_add_auth_headers_filters in set_auth for CP. + In CJ, set_auth happens before set_cell, hence we call set_add_auth_headers_filters in set_cell for CJ. + """ + def set_auth(self, client_name, token, token_signature, ssid): self.ssid = ssid self.token_signature = token_signature self.token = token self.client_name = client_name - # put auth properties in databus so that they can be used elsewhere + if self.cell: + # for CP + set_add_auth_headers_filters(self.cell, client_name, token, token_signature, ssid) + + # put auth properties in data bus so that they can be used elsewhere set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN, value=token) set_scope_property(scope_name=client_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, value=token_signature) def set_cell(self, cell): self.cell = cell - - # set filter to add additional auth headers - cell.core_cell.add_outgoing_reply_filter(channel="*", topic="*", cb=self._add_auth_headers) - cell.core_cell.add_outgoing_request_filter(channel="*", topic="*", cb=self._add_auth_headers) - - def _add_auth_headers(self, message: CellMessage): - if self.ssid: - message.set_header(CellMessageHeaderKeys.SSID, self.ssid) - - # Note that auth info (client_name, token and signature) is not available until the client is fully - # authenticated. - add_authentication_headers(message, self.client_name, self.token, self.token_signature) - - def _challenge_server(self, client_name, expected_host, root_cert_file): - # ask server for its info and make sure that it matches expected host - my_nonce = str(uuid.uuid4()) - headers = {IdentityChallengeKey.COMMON_NAME: client_name, IdentityChallengeKey.NONCE: my_nonce} - challenge = new_cell_message(headers, None) - result = self.cell.send_request( - target=FQCN.ROOT_SERVER, - channel=CellChannel.SERVER_MAIN, - topic=CellChannelTopic.Challenge, - request=challenge, - timeout=self.maint_msg_timeout, - ) - return_code = result.get_header(MessageHeaderKey.RETURN_CODE) - error = result.get_header(MessageHeaderKey.ERROR, "") - self.logger.info(f"challenge result: {return_code} {error}") - if return_code != ReturnCode.OK: - if return_code in [ReturnCode.TARGET_UNREACHABLE, ReturnCode.COMM_ERROR]: - # trigger retry - return None - err = result.get_header(MessageHeaderKey.ERROR, "") - raise FLCommunicationError(f"failed to challenge server: {return_code}: {err}") - - reply = result.payload - assert isinstance(reply, Shareable) - server_nonce = reply.get(IdentityChallengeKey.NONCE) - cert_bytes = reply.get(IdentityChallengeKey.CERT) - server_cert = load_crt_bytes(cert_bytes) - server_signature = reply.get(IdentityChallengeKey.SIGNATURE) - server_cn = reply.get(IdentityChallengeKey.COMMON_NAME) - - if server_cn != expected_host: - raise FLCommunicationError(f"expected server identity is '{expected_host}' but got '{server_cn}'") - - # Use IdentityVerifier to validate: - # - the server cert can be validated with the root cert. Note that all sites have the same root cert! - # - the asserted CN matches the CN on the server cert - # - signature received from the server is valid - id_verifier = IdentityVerifier(root_cert_file=root_cert_file) - id_verifier.verify_common_name( - asserter_cert=server_cert, asserted_cn=server_cn, nonce=my_nonce, signature=server_signature - ) - - self.logger.info(f"verified server identity '{expected_host}'") - return server_nonce + if self.token: + # for CJ + set_add_auth_headers_filters(self.cell, self.client_name, self.token, self.token_signature, self.ssid) def client_registration(self, client_name, project_name, fl_ctx: FLContext): """Register the client with the FLARE Server. @@ -223,12 +163,14 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): raise RuntimeError("Client cell could not be created. Failed to login the client.") time.sleep(0.5) - local_ip = _get_client_ip() - shareable = Shareable() shared_fl_ctx = gen_new_peer_ctx(fl_ctx) - shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) + private_key_file = None + root_cert_file = None + cert_file = None secure_mode = fl_ctx.get_prop(FLContextKey.SECURE_MODE, False) + expected_host = None + if secure_mode: # explicitly authenticate with the Server expected_host = None @@ -253,60 +195,24 @@ def client_registration(self, client_name, project_name, fl_ctx: FLContext): cert_file = client_config.get(SecureTrainConst.SSL_CERT) root_cert_file = client_config.get(SecureTrainConst.SSL_ROOT_CERT) - while True: - server_nonce = self._challenge_server(client_name, expected_host, root_cert_file) - if server_nonce is None and not self.should_stop: - # retry - self.logger.info(f"re-challenge after {self.client_register_interval} seconds") - time.sleep(self.client_register_interval) - else: - break - - id_asserter = IdentityAsserter(private_key_file=private_key_file, cert_file=cert_file) - cn_signature = id_asserter.sign_common_name(nonce=server_nonce) - shareable[IdentityChallengeKey.CERT] = id_asserter.cert_data - shareable[IdentityChallengeKey.SIGNATURE] = cn_signature - shareable[IdentityChallengeKey.COMMON_NAME] = id_asserter.cn - self.logger.info(f"sent identity info for client {client_name}") - - headers = { - CellMessageHeaderKeys.CLIENT_NAME: client_name, - CellMessageHeaderKeys.CLIENT_IP: local_ip, - CellMessageHeaderKeys.PROJECT_NAME: project_name, - } - login_message = new_cell_message(headers, shareable) - - self.logger.info("Trying to register with server ...") - while True: - try: - result = self.cell.send_request( - target=FQCN.ROOT_SERVER, - channel=CellChannel.SERVER_MAIN, - topic=CellChannelTopic.Register, - request=login_message, - timeout=self.maint_msg_timeout, - ) - return_code = result.get_header(MessageHeaderKey.RETURN_CODE) - self.logger.info(f"register RC: {return_code}") - if return_code == ReturnCode.UNAUTHENTICATED: - reason = result.get_header(MessageHeaderKey.ERROR) - self.logger.error(f"registration rejected: {reason}") - raise FLCommunicationError("error:client_registration " + reason) - - token = result.get_header(CellMessageHeaderKeys.TOKEN) - token_signature = result.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE, "NA") - ssid = result.get_header(CellMessageHeaderKeys.SSID) - if not token and not self.should_stop: - time.sleep(self.client_register_interval) - else: - self.set_auth(client_name, token, token_signature, ssid) - break - - except Exception as ex: - traceback.print_exc() - raise FLCommunicationError("error:client_registration", ex) - - return token, token_signature, ssid + authenticator = Authenticator( + cell=self.cell, + project_name=project_name, + client_name=client_name, + client_type=ClientType.REGULAR, + expected_sp_identity=expected_host, + secure_mode=secure_mode, + root_cert_file=root_cert_file, + private_key_file=private_key_file, + cert_file=cert_file, + msg_timeout=self.maint_msg_timeout, + retry_interval=self.client_register_interval, + ) + + token, signature, ssid, token_verifier = authenticator.authenticate(shared_fl_ctx, self.abort_signal) + self.token_verifier = token_verifier + self.set_auth(client_name, token, signature, ssid) + return token, signature, ssid def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): """Get a task from server. @@ -326,7 +232,6 @@ def pull_task(self, project_name, token, ssid, fl_ctx: FLContext, timeout=None): shareable = Shareable() shared_fl_ctx = gen_new_peer_ctx(fl_ctx) shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) - client_name = fl_ctx.get_identity_name() task_message = new_cell_message( { CellMessageHeaderKeys.PROJECT_NAME: project_name, @@ -444,10 +349,10 @@ def quit_remote(self, servers, task_name, token, ssid, fl_ctx: FLContext): server's reply to the last message """ + self.abort_signal.trigger(True) shared_fl_ctx = gen_new_peer_ctx(fl_ctx) shareable = Shareable() shareable.set_header(ServerCommandKey.PEER_FL_CONTEXT, shared_fl_ctx) - client_name = fl_ctx.get_identity_name() quit_message = new_cell_message( { CellMessageHeaderKeys.PROJECT_NAME: task_name, diff --git a/nvflare/private/fed/client/fed_client_base.py b/nvflare/private/fed/client/fed_client_base.py index 6a40717457..14215bed55 100644 --- a/nvflare/private/fed/client/fed_client_base.py +++ b/nvflare/private/fed/client/fed_client_base.py @@ -18,13 +18,13 @@ from nvflare.apis.filter import Filter from nvflare.apis.fl_component import FLComponent -from nvflare.apis.fl_constant import FLContextKey, SecureTrainConst, ServerCommandKey +from nvflare.apis.fl_constant import ConnPropKey, FLContextKey, SecureTrainConst, ServerCommandKey from nvflare.apis.fl_context import FLContext from nvflare.apis.fl_exception import FLCommunicationError from nvflare.apis.overseer_spec import SP from nvflare.apis.shareable import Shareable from nvflare.apis.signal import Signal -from nvflare.fuel.data_event.utils import set_scope_property +from nvflare.fuel.data_event.utils import get_scope_property, set_scope_property from nvflare.fuel.f3.cellnet.cell import Cell from nvflare.fuel.f3.cellnet.fqcn import FQCN from nvflare.fuel.f3.cellnet.net_agent import NetAgent @@ -188,29 +188,36 @@ def _create_cell(self, location, scheme): """ # Determine the CP's fqcn root_url = scheme + "://" + location + root_conn_security = self.client_args.get(ConnPropKey.CONNECTION_SECURITY) - # bridge_fqcn and bridge_url are set in the client's local/resources.json. - # If they are set, then connect via the specified bridge; if not, try to connect the Server directly - bridge_fqcn = self.client_args.get("bridge_fqcn") - bridge_url = self.client_args.get("bridge_url") - if bridge_fqcn: - cp_fqcn = FQCN.join([bridge_fqcn, self.client_name]) - root_url = None # do not connect to server if bridge is used - else: - cp_fqcn = self.client_name + relay_conn_props = get_scope_property(self.client_name, ConnPropKey.RELAY_CONN_PROPS, {}) + self.logger.info(f"got {ConnPropKey.RELAY_CONN_PROPS}: {relay_conn_props}") + + relay_fqcn = relay_conn_props.get(ConnPropKey.FQCN) + if relay_fqcn: + root_url = None # do not connect to server if relay is used + cp_conn_props = get_scope_property(self.client_name, ConnPropKey.CP_CONN_PROPS) + cp_fqcn = cp_conn_props.get(ConnPropKey.FQCN) + parent_resources = None if self.args.job_id: # I am CJ me = "CJ" my_fqcn = FQCN.join([cp_fqcn, self.args.job_id]) - parent_url = self.args.parent_url + parent_url = cp_conn_props.get(ConnPropKey.URL) + parent_conn_sec = cp_conn_props.get(ConnPropKey.CONNECTION_SECURITY) create_internal_listener = False + if parent_conn_sec: + parent_resources = {DriverParams.CONNECTION_SECURITY.value: parent_conn_sec} else: # I am CP me = "CP" my_fqcn = cp_fqcn - parent_url = bridge_url + parent_url = relay_conn_props.get(ConnPropKey.URL) create_internal_listener = True + relay_conn_security = relay_conn_props.get(ConnPropKey.CONNECTION_SECURITY) + if relay_conn_security: + parent_resources = {DriverParams.CONNECTION_SECURITY.value: relay_conn_security} if self.secure_train: root_cert = self.client_args[SecureTrainConst.SSL_ROOT_CERT] @@ -222,13 +229,13 @@ def _create_cell(self, location, scheme): DriverParams.CLIENT_CERT.value: ssl_cert, DriverParams.CLIENT_KEY.value: private_key, } - conn_security = self.client_args.get(SecureTrainConst.CONNECTION_SECURITY) - if conn_security: - credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security - set_scope_property(self.client_name, SecureTrainConst.CONNECTION_SECURITY, conn_security) else: credentials = {} + if root_conn_security: + # this is the default conn sec + credentials[DriverParams.CONNECTION_SECURITY.value] = root_conn_security + self.logger.info(f"{me=}: {my_fqcn=} {root_url=} {parent_url=}") self.cell = Cell( fqcn=my_fqcn, @@ -237,6 +244,7 @@ def _create_cell(self, location, scheme): credentials=credentials, create_internal_listener=create_internal_listener, parent_url=parent_url, + parent_resources=parent_resources, ) self.cell.start() self.communicator.set_cell(self.cell) diff --git a/nvflare/private/fed/server/client_manager.py b/nvflare/private/fed/server/client_manager.py index 1e1c58ad11..2dade0ad57 100644 --- a/nvflare/private/fed/server/client_manager.py +++ b/nvflare/private/fed/server/client_manager.py @@ -23,7 +23,7 @@ from nvflare.apis.shareable import Shareable from nvflare.fuel.f3.cellnet.defs import IdentityChallengeKey, MessageHeaderKey from nvflare.fuel.utils.log_utils import get_obj_logger -from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, InternalFLContextKey +from nvflare.private.defs import CellMessageHeaderKeys, ClientRegSession, ClientType, InternalFLContextKey from nvflare.private.fed.utils.identity_utils import IdentityVerifier, load_crt_bytes from nvflare.security.logging import secure_format_exception @@ -57,10 +57,17 @@ def authenticate(self, request, fl_ctx: FLContext) -> Optional[Client]: # new client join with self.lock: - self.clients.update({client.token: client}) + client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE) + if client_type == ClientType.REGULAR: + self.clients.update({client.token: client}) + client_kind = "client" + else: + # do not update self.clients for non-regular clients + client_kind = client_type + self.logger.info( - "Client: New client {} joined. Sent token: {}. Total clients: {}".format( - client.name + "@" + client_ip, client.token, len(self.clients) + "Client: New {} {} joined. Sent token: {}. Total clients: {}".format( + client_kind, client.name + "@" + client_ip, client.token, len(self.clients) ) ) return client diff --git a/nvflare/private/fed/server/fed_server.py b/nvflare/private/fed/server/fed_server.py index 354df3fda3..f1008e5951 100644 --- a/nvflare/private/fed/server/fed_server.py +++ b/nvflare/private/fed/server/fed_server.py @@ -25,6 +25,7 @@ from nvflare.apis.fl_component import FLComponent from nvflare.apis.fl_constant import ( ConfigVarName, + ConnPropKey, FLContextKey, MachineStatus, RunProcessKey, @@ -63,13 +64,15 @@ CellChannelTopic, CellMessageHeaderKeys, ClientRegSession, + ClientType, InternalFLContextKey, JobFailureMsgKey, new_cell_message, ) +from nvflare.private.fed.authenticator import validate_auth_headers from nvflare.private.fed.server.server_command_agent import ServerCommandAgent from nvflare.private.fed.server.server_runner import ServerRunner -from nvflare.private.fed.utils.identity_utils import IdentityAsserter +from nvflare.private.fed.utils.identity_utils import IdentityAsserter, TokenVerifier from nvflare.security.logging import secure_format_exception from nvflare.widgets.fed_event import ServerFedEventRunner @@ -169,7 +172,7 @@ def deploy(self, args, grpc_args=None, secure_train=False): DriverParams.SERVER_KEY.value: private_key, } - conn_security = grpc_args.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = grpc_args.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: @@ -401,12 +404,15 @@ def _add_auth_headers(self, message: Message): """ origin = message.get_header(MessageHeaderKey.ORIGIN) dest = message.get_header(MessageHeaderKey.DESTINATION) - if origin == FQCN.ROOT_SERVER and dest == origin: - if not self.my_own_token_signature: - self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) - add_authentication_headers( - message, self.my_own_auth_client_name, self.my_own_token, self.my_own_token_signature - ) + channel = message.get_header(MessageHeaderKey.CHANNEL) + topic = message.get_header(MessageHeaderKey.TOPIC) + if not self.my_own_token_signature: + self.my_own_token_signature = self.sign_auth_token(self.my_own_auth_client_name, self.my_own_token) + + add_authentication_headers( + message, self.my_own_auth_client_name, self.my_own_token, self.my_own_token_signature + ) + self.logger.debug(f"added auth headers: {origin=} {dest=} {channel=} {topic=}") def _validate_auth_headers(self, message: Message): """Validate auth headers from messages that go through the server. @@ -414,45 +420,17 @@ def _validate_auth_headers(self, message: Message): message: the message to validate Returns: """ - headers = message.headers - self.logger.debug(f"**** _validate_auth_headers: {headers=}") - topic = message.get_header(MessageHeaderKey.TOPIC) - channel = message.get_header(MessageHeaderKey.CHANNEL) - - origin = message.get_header(MessageHeaderKey.ORIGIN) - - if topic in [CellChannelTopic.Register, CellChannelTopic.Challenge] and channel == CellChannel.SERVER_MAIN: - # skip: client not registered yet - self.logger.debug(f"skip special message {topic=} {channel=}") + id_asserter = self._get_id_asserter() + if not id_asserter: return None - client_name = message.get_header(CellMessageHeaderKeys.CLIENT_NAME) - err_text = f"unauthenticated msg ({channel=} {topic=}) received from {origin}" - if not client_name: - err = "missing client name" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - token = message.get_header(CellMessageHeaderKeys.TOKEN) - if not token: - err = "missing auth token" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - signature = message.get_header(CellMessageHeaderKeys.TOKEN_SIGNATURE) - if not signature: - err = "missing auth token signature" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - if not self.verify_auth_token(client_name, token, signature): - err = "invalid auth token signature" - self.logger.error(f"{err_text}: {err}") - return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error=err) - - # all good - self.logger.debug(f"auth valid from {origin}: {topic=} {channel=}") - return None + token_verifier = TokenVerifier(id_asserter.cert) + + return validate_auth_headers( + message=message, + token_verifier=token_verifier, + logger=self.logger, + ) def sign_auth_token(self, client_name: str, token: str): id_asserter = self._get_id_asserter() @@ -464,7 +442,9 @@ def verify_auth_token(self, client_name: str, token: str, signature): id_asserter = self._get_id_asserter() if not id_asserter: return True - return id_asserter.verify_signature(client_name + token, signature) + + token_verifier = TokenVerifier(id_asserter.cert) + return token_verifier.verify(client_name, token, signature) def _check_regs(self): while True: @@ -541,7 +521,7 @@ def create_job_cell(self, job_id, root_url, parent_url, secure_train, server_con DriverParams.SERVER_KEY.value: private_key, } - conn_security = server_config.get(SecureTrainConst.CONNECTION_SECURITY) + conn_security = server_config.get(ConnPropKey.CONNECTION_SECURITY) if conn_security: credentials[DriverParams.CONNECTION_SECURITY.value] = conn_security else: @@ -710,20 +690,22 @@ def register_client(self, request: Message) -> Message: client = self.client_manager.authenticate(request, fl_ctx) if client and client.token: - self.tokens[client.token] = self.task_meta_info(client.name) - if self.admin_server: - self.admin_server.client_heartbeat(client.token, client.name, client.get_fqcn()) + client_type = request.get_header(CellMessageHeaderKeys.CLIENT_TYPE) + if client_type == ClientType.REGULAR: + self.tokens[client.token] = self.task_meta_info(client.name) + if self.admin_server: + self.admin_server.client_heartbeat(client.token, client.name, client.get_fqcn()) token_signature = self.sign_auth_token(client.name, client.token) - headers = { + result = { CellMessageHeaderKeys.TOKEN: client.token, CellMessageHeaderKeys.TOKEN_SIGNATURE: token_signature, CellMessageHeaderKeys.SSID: self.server_state.ssid, } else: - headers = {} + result = {} self.engine.fire_event(EventType.CLIENT_REGISTER_PROCESSED, fl_ctx=fl_ctx) - return self._generate_reply(headers=headers, payload=None, fl_ctx=fl_ctx) + return self._generate_reply(headers={}, payload=result, fl_ctx=fl_ctx) except NotAuthenticated as e: self.logger.error(f"Failed to authenticate the register_client: {secure_format_exception(e)}") return make_cellnet_reply(rc=F3ReturnCode.UNAUTHENTICATED, error="register_client unauthenticated") diff --git a/nvflare/private/fed/server/server_state.py b/nvflare/private/fed/server/server_state.py index 5b0ba3ca9e..fa2a8cd414 100644 --- a/nvflare/private/fed/server/server_state.py +++ b/nvflare/private/fed/server/server_state.py @@ -91,11 +91,11 @@ def aux_communicate(self, fl_ctx: FLContext) -> dict: def handle_sd_callback(self, sp: SP, fl_ctx: FLContext) -> ServerState: if sp: - self.logger.info( + self.logger.debug( f"handle_sd_callback Got SP: {sp.name=} {sp.fl_port=} {sp.primary=} {self.host=} {self.service_port=}" ) else: - self.logger.info("handle_sd_callback no SP!") + self.logger.debug("handle_sd_callback no SP!") if sp and sp.primary is True: if sp.name == self.host and sp.fl_port == self.service_port: diff --git a/nvflare/private/fed/server/training_cmds.py b/nvflare/private/fed/server/training_cmds.py index 8ad9be7b67..7d31af86e0 100644 --- a/nvflare/private/fed/server/training_cmds.py +++ b/nvflare/private/fed/server/training_cmds.py @@ -18,6 +18,7 @@ from nvflare.apis.client import Client from nvflare.apis.fl_constant import AdminCommandNames, SiteType +from nvflare.fuel.data_event.data_bus import DataBus from nvflare.fuel.hci.conn import Connection from nvflare.fuel.hci.proto import ConfirmMethod, MetaKey, MetaStatusValue, make_meta from nvflare.fuel.hci.reg import CommandModule, CommandModuleSpec, CommandSpec @@ -160,6 +161,12 @@ def shutdown(self, conn: Connection, args: List[str]): conn.update_meta(make_meta(MetaStatusValue.ERROR, "failed to shut down all clients")) return + if target_type in [self.TARGET_TYPE_ALL]: + # shutdown the cellnet + data_bus = DataBus() + data_bus.publish(["stop_cellnet"], conn) + # time.sleep(2.0) + if target_type in [self.TARGET_TYPE_SERVER, self.TARGET_TYPE_ALL]: # shut down the server err = self._shutdown_app_on_server(conn) diff --git a/nvflare/private/fed/simulator/simulator_server.py b/nvflare/private/fed/simulator/simulator_server.py index 881eed6039..13b273fa1b 100644 --- a/nvflare/private/fed/simulator/simulator_server.py +++ b/nvflare/private/fed/simulator/simulator_server.py @@ -79,7 +79,6 @@ def create_job_processing_context_properties(self, workspace, job_id): class SimulatorIdentityAsserter(IdentityAsserter): - def __init__(self, private_key_file: str, cert_file: str): self.private_key_file = private_key_file self.cert_file = cert_file diff --git a/nvflare/private/fed/utils/identity_utils.py b/nvflare/private/fed/utils/identity_utils.py index d8a8a44850..10d948d6a3 100644 --- a/nvflare/private/fed/utils/identity_utils.py +++ b/nvflare/private/fed/utils/identity_utils.py @@ -14,6 +14,7 @@ from cryptography.x509.oid import NameOID +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.lighter.utils import ( load_crt, load_crt_bytes, @@ -110,3 +111,18 @@ def verify_common_name(self, asserted_cn: str, nonce: str, asserter_cert, signat except Exception as ex: raise InvalidCNSignature(f"cannot verify common name signature: {secure_format_exception(ex)}") return True + + +class TokenVerifier: + def __init__(self, cert): + self.cert = cert + self.public_key = cert.public_key() + self.logger = get_obj_logger(self) + + def verify(self, client_name, token, signature): + try: + verify_content(content=client_name + token, signature=signature, public_key=self.public_key) + return True + except Exception as ex: + self.logger.error(f"exception verifying token: {client_name=} {token=}: {secure_format_exception(ex)}") + return False diff --git a/nvflare/private/json_configer.py b/nvflare/private/json_configer.py index aca538658b..aaecbcbd1b 100644 --- a/nvflare/private/json_configer.py +++ b/nvflare/private/json_configer.py @@ -22,6 +22,7 @@ from nvflare.fuel.utils.config_service import ConfigService from nvflare.fuel.utils.dict_utils import augment from nvflare.fuel.utils.json_scanner import JsonObjectProcessor, JsonScanner, Node +from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.fuel.utils.wfconf import resolve_var_refs from nvflare.security.logging import secure_format_exception @@ -54,6 +55,7 @@ def __init__( sys_vars: system vars """ JsonObjectProcessor.__init__(self) + self.logger = get_obj_logger(self) if not isinstance(num_passes, int): raise TypeError(f"num_passes must be int but got {num_passes}") diff --git a/nvflare/utils/job_launcher_utils.py b/nvflare/utils/job_launcher_utils.py index 6865673869..a3013e25aa 100644 --- a/nvflare/utils/job_launcher_utils.py +++ b/nvflare/utils/job_launcher_utils.py @@ -24,7 +24,10 @@ def _job_args_str(job_args, arg_names) -> str: result = "" sep = "" for name in arg_names: - n, v = job_args[name] + e = job_args.get(name) + if not e: + continue + n, v = e result += f"{sep}{n} {v}" sep = " " return result @@ -45,6 +48,7 @@ def get_client_job_args(include_exe_module=True, include_set_options=True): JobProcessArgs.JOB_ID, JobProcessArgs.CLIENT_NAME, JobProcessArgs.PARENT_URL, + JobProcessArgs.PARENT_CONN_SEC, JobProcessArgs.TARGET, JobProcessArgs.SCHEME, JobProcessArgs.STARTUP_CONFIG_FILE, diff --git a/tests/unit_test/client/in_process/api_test.py b/tests/unit_test/client/in_process/api_test.py index 835bd80b5d..a873d78499 100644 --- a/tests/unit_test/client/in_process/api_test.py +++ b/tests/unit_test/client/in_process/api_test.py @@ -65,8 +65,10 @@ def test_init_with_custom_interval(self): def test_init_subscriptions(self): client_api = InProcessClientAPI(self.task_metadata) xs = list(client_api.data_bus.subscribers.keys()) - xs.sort() - assert xs == [TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP] + + # Depending on the timing of this test, the data bus may have other subscribed topics + # since the data bus is a singleton! + assert set(xs).issuperset([TOPIC_ABORT, TOPIC_GLOBAL_RESULT, TOPIC_STOP]) def local_result_callback(self, data, topic): pass diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py new file mode 100644 index 0000000000..b8da01935f --- /dev/null +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -0,0 +1,49 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection +from nvflare.fuel.f3.drivers.driver_params import DriverParams + +CS = DriverParams.CONNECTION_SECURITY.value +S = DriverParams.SECURE.value +IS = ConnectionSecurity.CLEAR +T = ConnectionSecurity.TLS +M = ConnectionSecurity.MTLS + + +class TestCommConfigUtils: + + @pytest.mark.parametrize( + "resources, expected", + [ + ({}, False), + ({"x": 1, "y": 2}, False), + ({S: True}, True), + ({S: False}, False), + ({CS: IS}, False), + ({CS: T}, True), + ({CS: M}, True), + ({CS: M, S: False}, True), + ({CS: M, S: True}, True), + ({CS: T, S: False}, True), + ({CS: T, S: True}, True), + ({CS: IS, S: False}, False), + ({CS: IS, S: True}, False), + ], + ) + def test_requires_secure_connection(self, resources, expected): + result = requires_secure_connection(resources) + assert result == expected diff --git a/tests/unit_test/fuel/utils/url_utils_test.py b/tests/unit_test/fuel/utils/url_utils_test.py new file mode 100644 index 0000000000..a466d4a5c4 --- /dev/null +++ b/tests/unit_test/fuel/utils/url_utils_test.py @@ -0,0 +1,72 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.fuel.utils.url_utils import make_url + + +class TestUrlUtils: + + @pytest.mark.parametrize( + "scheme, address, secure, expected", + [ + ("tcp", "xyz.com", False, "tcp://xyz.com"), + ("tcp", "xyz.com:1234", False, "tcp://xyz.com:1234"), + ("tcp", "xyz.com:1234", True, "stcp://xyz.com:1234"), + ("grpc", "xyz.com", False, "grpc://xyz.com"), + ("grpc", "xyz.com:1234", False, "grpc://xyz.com:1234"), + ("grpc", "xyz.com:1234", True, "grpcs://xyz.com:1234"), + ("http", "xyz.com", False, "http://xyz.com"), + ("http", "xyz.com:1234", False, "http://xyz.com:1234"), + ("http", "xyz.com:1234", True, "https://xyz.com:1234"), + ("tcp", ("xyz.com",), False, "tcp://xyz.com"), + ("tcp", ("xyz.com", 1234), False, "tcp://xyz.com:1234"), + ("tcp", ["xyz.com"], False, "tcp://xyz.com"), + ("tcp", ["xyz.com", 1234], False, "tcp://xyz.com:1234"), + ("tcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("tcp", {"host": "xyz.com", "port": 1234}, False, "tcp://xyz.com:1234"), + ("stcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("https", {"host": "xyz.com"}, False, "http://xyz.com"), + ("grpcs", {"host": "xyz.com"}, False, "grpc://xyz.com"), + ("stcp", {"host": "xyz.com"}, True, "stcp://xyz.com"), + ("https", {"host": "xyz.com"}, True, "https://xyz.com"), + ("grpcs", {"host": "xyz.com"}, True, "grpcs://xyz.com"), + ], + ) + def test_make_url(self, scheme, address, secure, expected): + result = make_url(scheme, address, secure) + assert result == expected + + @pytest.mark.parametrize( + "scheme, address, secure", + [ + ("tcp", "", False), + ("abc", "xyz.com:1234", False), + ("tcp", 1234, True), + ("grpc", [], False), + ("grpc", (), False), + ("grpc", {}, True), + ("http", [1234], False), + ("http", [1234, "xyz.com"], False), + ("http", ["xyz.com", 1234, 22], True), + ("http", (1234,), False), + ("http", (1234, "xyz.com"), False), + ("http", ("xyz.com", 1234, 22), True), + ("tcp", {"hosts": "xyz.com"}, False), + ("tcp", {"host": "xyz.com", "port": 1234, "extra": 2323}, False), + ], + ) + def test_make_url_error(self, scheme, address, secure): + with pytest.raises(ValueError): + make_url(scheme, address, secure)