Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support relay - Part 1 #3198

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
25 changes: 24 additions & 1 deletion nvflare/apis/fl_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,6 @@ class SecureTrainConst:
SSL_ROOT_CERT = "ssl_root_cert"
SSL_CERT = "ssl_cert"
PRIVATE_KEY = "ssl_private_key"
CONNECTION_SECURITY = "connection_security"


class FLMetaKey:
Expand Down Expand Up @@ -542,6 +541,8 @@ class SystemVarName:
WORKSPACE = "WORKSPACE" # directory of the workspace
JOB_ID = "JOB_ID" # Job ID
ROOT_URL = "ROOT_URL" # the URL of the Service Provider (server)
CP_URL = "CP_URL" # URL to CP
RELAY_URL = "RELAY_URL" # URL to relay that the CP is connected to
SECURE_MODE = "SECURE_MODE" # whether the system is running in secure mode
JOB_CUSTOM_DIR = "JOB_CUSTOM_DIR" # custom dir of the job
PYTHONPATH = "PYTHONPATH"
Expand All @@ -552,3 +553,25 @@ class RunnerTask:
INIT = "init"
TASK_EXEC = "task_exec"
END_RUN = "end_run"


class ConnPropKey:

IDENTITY = "identity"
PARENT = "parent"
FQCN = "fqcn"
URL = "url"
SCHEME = "scheme"
ADDRESS = "address"
CONNECTION_SECURITY = "connection_security"

RELAY_CONFIG = "relay_config"
CP_CONN_PROPS = "cp_conn_props"
RELAY_CONN_PROPS = "relay_conn_props"
ROOT_CONN_PROPS = "root_conn_props"


class ConnectionSecurity:
INSECURE = "insecure"
TLS = "tls"
MTLS = "mtls"
1 change: 1 addition & 0 deletions nvflare/apis/job_launcher_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class JobProcessArgs:
CLIENT_NAME = "client_name"
ROOT_URL = "root_url"
PARENT_URL = "parent_url"
PARENT_CONN_SEC = "parent_conn_sec"
SERVICE_HOST = "service_host"
SERVICE_PORT = "service_port"
HA_MODE = "ha_mode"
Expand Down
16 changes: 2 additions & 14 deletions nvflare/app_common/executors/client_api_launcher_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,12 @@
import os
from typing import Optional

from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.app_constant import AppConstants
from nvflare.app_common.executors.launcher_executor import LauncherExecutor
from nvflare.app_common.utils.export_utils import update_export_props
from nvflare.client.config import ConfigKey, ExchangeFormat, TransferType, write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.data_event.utils import get_scope_property
from nvflare.fuel.utils.attributes_exportable import ExportMode


Expand Down Expand Up @@ -126,22 +125,11 @@ def prepare_config_for_launch(self, fl_ctx: FLContext):
ConfigKey.HEARTBEAT_TIMEOUT: self.heartbeat_timeout,
}

site_name = fl_ctx.get_identity_name()
auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

config_data = {
ConfigKey.TASK_EXCHANGE: task_exchange_attributes,
FLMetaKey.SITE_NAME: site_name,
FLMetaKey.JOB_ID: fl_ctx.get_job_id(),
FLMetaKey.AUTH_TOKEN: auth_token,
FLMetaKey.AUTH_TOKEN_SIGNATURE: signature,
}

conn_sec = get_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY)
if conn_sec:
config_data[SecureTrainConst.CONNECTION_SECURITY] = conn_sec

update_export_props(config_data, fl_ctx)
config_file_path = self._get_external_config_file_path(fl_ctx)
write_config_to_file(config_data=config_data, config_file_path=config_file_path)

Expand Down
39 changes: 39 additions & 0 deletions nvflare/app_common/utils/export_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.fuel.data_event.utils import get_scope_property


def update_export_props(props: dict, fl_ctx: FLContext):
site_name = fl_ctx.get_identity_name()
auth_token = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN, default="NA")
signature = get_scope_property(scope_name=site_name, key=FLMetaKey.AUTH_TOKEN_SIGNATURE, default="NA")

props[FLMetaKey.SITE_NAME] = site_name
props[FLMetaKey.JOB_ID] = fl_ctx.get_job_id()
props[FLMetaKey.AUTH_TOKEN] = auth_token
props[FLMetaKey.AUTH_TOKEN_SIGNATURE] = signature

root_conn_props = get_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS)
if root_conn_props:
props[ConnPropKey.ROOT_CONN_PROPS] = root_conn_props

cp_conn_props = get_scope_property(site_name, ConnPropKey.CP_CONN_PROPS)
if cp_conn_props:
props[ConnPropKey.CP_CONN_PROPS] = cp_conn_props

relay_conn_props = get_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS)
if relay_conn_props:
props[ConnPropKey.RELAY_CONN_PROPS] = relay_conn_props
15 changes: 10 additions & 5 deletions nvflare/app_common/widgets/external_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,9 @@
from typing import List

from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import FLMetaKey
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.utils.export_utils import update_export_props
from nvflare.client.config import write_config_to_file
from nvflare.client.constants import CLIENT_API_CONFIG
from nvflare.fuel.utils.attributes_exportable import ExportMode, export_components
Expand Down Expand Up @@ -47,9 +48,7 @@ def __init__(
def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.ABOUT_TO_START_RUN:
components_data = self._export_all_components(fl_ctx)
components_data[FLMetaKey.SITE_NAME] = fl_ctx.get_identity_name()
components_data[FLMetaKey.JOB_ID] = fl_ctx.get_job_id()

update_export_props(components_data, fl_ctx)
config_file_path = self._get_external_config_file_path(fl_ctx)
write_config_to_file(config_data=components_data, config_file_path=config_file_path)

Expand All @@ -65,5 +64,11 @@ def _export_all_components(self, fl_ctx: FLContext) -> dict:
engine = fl_ctx.get_engine()
all_components = engine.get_all_components()
components = {i: all_components.get(i) for i in self._component_ids}
reserved_keys = [FLMetaKey.SITE_NAME, FLMetaKey.JOB_ID]
reserved_keys = [
FLMetaKey.SITE_NAME,
FLMetaKey.JOB_ID,
ConnPropKey.CP_CONN_PROPS,
ConnPropKey.ROOT_CONN_PROPS,
ConnPropKey.RELAY_CONN_PROPS,
]
return export_components(components=components, reserved_keys=reserved_keys, export_mode=ExportMode.PEER)
6 changes: 5 additions & 1 deletion nvflare/app_opt/job_launcher/k8s_launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,11 @@ def get_module_args(self, job_id, fl_ctx: FLContext):
def _job_args_dict(job_args: dict, arg_names: list) -> dict:
result = {}
for name in arg_names:
n, v = job_args[name]
e = job_args.get(name)
if not e:
continue

n, v = e
result[n] = v
return result

Expand Down
13 changes: 11 additions & 2 deletions nvflare/client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import os
from typing import Dict, Optional

from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.fuel.utils.config_factory import ConfigFactory


Expand Down Expand Up @@ -157,7 +157,16 @@ def get_heartbeat_timeout(self):
)

def get_connection_security(self):
return self.config.get(SecureTrainConst.CONNECTION_SECURITY)
return self.config.get(ConnPropKey.CONNECTION_SECURITY)

def get_root_conn_props(self):
return self.config.get(ConnPropKey.ROOT_CONN_PROPS)

def get_cp_conn_props(self):
return self.config.get(ConnPropKey.CP_CONN_PROPS)

def get_relay_conn_props(self):
return self.config.get(ConnPropKey.RELAY_CONN_PROPS)

def get_site_name(self):
return self.config.get(FLMetaKey.SITE_NAME)
Expand Down
17 changes: 13 additions & 4 deletions nvflare/client/ex_process/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from typing import Any, Dict, Optional, Tuple

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.fl_constant import FLMetaKey, SecureTrainConst
from nvflare.apis.fl_constant import ConnPropKey, FLMetaKey
from nvflare.apis.utils.analytix_utils import create_analytic_dxo
from nvflare.app_common.abstract.fl_model import FLModel
from nvflare.client.api_spec import APISpec
Expand All @@ -39,9 +39,18 @@ def _create_client_config(config: str) -> ClientConfig:
raise ValueError(f"config should be a string but got: {type(config)}")

site_name = client_config.get_site_name()
conn_sec = client_config.get_connection_security()
if conn_sec:
set_scope_property(site_name, SecureTrainConst.CONNECTION_SECURITY, conn_sec)

root_conn_props = client_config.get_root_conn_props()
if root_conn_props:
set_scope_property(site_name, ConnPropKey.ROOT_CONN_PROPS, root_conn_props)

cp_conn_props = client_config.get_cp_conn_props()
if cp_conn_props:
set_scope_property(site_name, ConnPropKey.CP_CONN_PROPS, cp_conn_props)

relay_conn_props = client_config.get_relay_conn_props()
if relay_conn_props:
set_scope_property(site_name, ConnPropKey.RELAY_CONN_PROPS, relay_conn_props)

# get message auth info and put them into Databus for CellPipe to use
auth_token = client_config.get_auth_token()
Expand Down
22 changes: 17 additions & 5 deletions nvflare/fuel/f3/cellnet/connector_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
import time
from typing import Union

from nvflare.apis.fl_constant import ConnectionSecurity
from nvflare.fuel.common.excepts import ConfigError
from nvflare.fuel.f3.cellnet.defs import ConnectorRequirementKey
from nvflare.fuel.f3.cellnet.fqcn import FqcnInfo
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import CommError, Communicator, Mode
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.utils.log_utils import get_obj_logger
from nvflare.security.logging import secure_format_exception, secure_format_traceback

Expand Down Expand Up @@ -48,6 +50,9 @@ def __init__(self, handle, connect_url: str, active: bool, params: dict):
def get_connection_url(self):
return self.connect_url

def get_connection_params(self):
return self.params


class ConnectorManager:
"""
Expand Down Expand Up @@ -85,6 +90,11 @@ def __init__(self, communicator: Communicator, secure: bool, comm_configurator:
self.adhoc_scheme = adhoc_conf.get(_KEY_SCHEME)
self.adhoc_resources = adhoc_conf.get(_KEY_RESOURCES)

# default conn sec
conn_sec = self.int_resources.get(DriverParams.CONNECTION_SECURITY)
if not conn_sec:
self.int_resources[DriverParams.CONNECTION_SECURITY] = ConnectionSecurity.INSECURE

self.logger.debug(f"internal scheme={self.int_scheme}, resources={self.int_resources}")
self.logger.debug(f"adhoc scheme={self.adhoc_scheme}, resources={self.adhoc_resources}")
self.comm_config = comm_config
Expand Down Expand Up @@ -152,7 +162,7 @@ def _validate_conn_config(config: dict, key: str) -> Union[None, dict]:
return conn_config

def _get_connector(
self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool
self, url: str, active: bool, internal: bool, adhoc: bool, secure: bool, conn_resources=None
) -> Union[None, ConnectorData]:
if active and not url:
raise RuntimeError("url is required by not provided for active connector!")
Expand Down Expand Up @@ -193,10 +203,10 @@ def _get_connector(

try:
if active:
handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.ACTIVE, ssl_required, conn_resources)
connect_url = url
elif url:
handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required)
handle, conn_params = self.communicator.add_connector(url, Mode.PASSIVE, ssl_required, conn_resources)
connect_url = url
else:
self.logger.info(f"{os.getpid()}: Try start_listener Listener resources: {reqs}")
Expand Down Expand Up @@ -240,11 +250,13 @@ def get_internal_listener(self) -> Union[None, ConnectorData]:
"""
return self._get_connector(url="", active=False, internal=True, adhoc=False, secure=False)

def get_internal_connector(self, url: str) -> Union[None, ConnectorData]:
def get_internal_connector(self, url: str, conn_resources=None) -> Union[None, ConnectorData]:
"""
Try to get an internal listener.

Args:
url:
"""
return self._get_connector(url=url, active=True, internal=True, adhoc=False, secure=False)
return self._get_connector(
url=url, active=True, internal=True, adhoc=False, secure=False, conn_resources=conn_resources
)
17 changes: 13 additions & 4 deletions nvflare/fuel/f3/cellnet/core_cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import Dict, List, Tuple, Union
from urllib.parse import urlparse

from nvflare.apis.fl_constant import ConnectionSecurity
from nvflare.fuel.f3.cellnet.connector_manager import ConnectorManager
from nvflare.fuel.f3.cellnet.credential_manager import CredentialManager
from nvflare.fuel.f3.cellnet.defs import (
Expand All @@ -43,7 +44,7 @@
from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.communicator import Communicator, MessageReceiver
from nvflare.fuel.f3.connection import Connection
from nvflare.fuel.f3.drivers.driver_params import ConnectionSecurity, DriverParams
from nvflare.fuel.f3.drivers.driver_params import DriverParams
from nvflare.fuel.f3.drivers.net_utils import enhance_credential_info
from nvflare.fuel.f3.endpoint import Endpoint, EndpointMonitor, EndpointState
from nvflare.fuel.f3.message import Message
Expand Down Expand Up @@ -281,6 +282,7 @@ def __init__(
credentials: dict,
create_internal_listener: bool = False,
parent_url: str = None,
parent_resources: dict = None,
max_timeout=3600,
bulk_check_interval=0.5,
bulk_process_interval=0.5,
Expand All @@ -296,6 +298,7 @@ def __init__(
max_timeout: default timeout for send_and_receive
create_internal_listener: whether to create an internal listener for child cells
parent_url: url for connecting to parent cell
parent_resources: extra resources for making connection to parent

FQCN is the names of all ancestor, concatenated with dots.

Expand Down Expand Up @@ -370,6 +373,7 @@ def __init__(
self.root_url = root_url
self.create_internal_listener = create_internal_listener
self.parent_url = parent_url
self.parent_resources = parent_resources
self.bulk_check_interval = bulk_check_interval
self.max_bulk_size = max_bulk_size
self.bulk_checker = None
Expand Down Expand Up @@ -566,7 +570,7 @@ def _set_bb_for_client_root(self):

def _set_bb_for_client_child(self, parent_url: str, create_internal_listener: bool):
if parent_url:
self._create_internal_connector(parent_url)
self._create_internal_connector(parent_url, self.parent_resources)

if create_internal_listener:
self._create_internal_listener()
Expand Down Expand Up @@ -693,6 +697,11 @@ def get_internal_listener_url(self) -> Union[None, str]:
return None
return self.int_listener.get_connection_url()

def get_internal_listener_params(self) -> Union[None, dict]:
if not self.int_listener:
return None
return self.int_listener.get_connection_params()

def _add_adhoc_connector(self, to_cell: str, url: str):
if self.bb_ext_connector:
# it is possible that the server root offers connect url after the bb_ext_connector is created
Expand Down Expand Up @@ -786,8 +795,8 @@ def _create_bb_external_connector(self):
else:
raise RuntimeError(f"{self.my_info.fqcn}: cannot create backbone external connector to {self.root_url}")

def _create_internal_connector(self, url: str):
self.bb_int_connector = self.connector_manager.get_internal_connector(url)
def _create_internal_connector(self, url: str, resources=None):
self.bb_int_connector = self.connector_manager.get_internal_connector(url, resources)
if self.bb_int_connector:
self.logger.info(f"{self.my_info.fqcn}: created backbone internal connector to {url} on parent")
else:
Expand Down
Loading
Loading