diff --git a/nvflare/fuel/f3/comm_config_utils.py b/nvflare/fuel/f3/comm_config_utils.py index 6a9097ef73..494d320646 100644 --- a/nvflare/fuel/f3/comm_config_utils.py +++ b/nvflare/fuel/f3/comm_config_utils.py @@ -15,12 +15,22 @@ from nvflare.fuel.f3.drivers.driver_params import DriverParams -def requires_secure(resources: dict): +def requires_secure_connection(resources: dict): + """Determine whether secure connection is required based on information in resources. + + Args: + resources: a dict that contains info for making connection + + Returns: whether secure connection is required + + """ conn_sec = resources.get(DriverParams.CONNECTION_SECURITY) if conn_sec: + # if connection security is specified, it takes precedence over the "secure" flag if conn_sec == ConnectionSecurity.INSECURE: return False else: return True else: - return resources.get(DriverParams.SECURE) + # Connection security is not specified, check the "secure" flag. + return resources.get(DriverParams.SECURE, False) diff --git a/nvflare/fuel/f3/drivers/aio_grpc_driver.py b/nvflare/fuel/f3/drivers/aio_grpc_driver.py index 55deec90d2..65544a2f84 100644 --- a/nvflare/fuel/f3/drivers/aio_grpc_driver.py +++ b/nvflare/fuel/f3/drivers/aio_grpc_driver.py @@ -22,7 +22,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers.aio_context import AioContext @@ -410,7 +410,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "grpcs" diff --git a/nvflare/fuel/f3/drivers/aio_http_driver.py b/nvflare/fuel/f3/drivers/aio_http_driver.py index cfcd7e2121..383f95d7cb 100644 --- a/nvflare/fuel/f3/drivers/aio_http_driver.py +++ b/nvflare/fuel/f3/drivers/aio_http_driver.py @@ -18,7 +18,7 @@ import websockets from websockets.exceptions import ConnectionClosedOK -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import BytesAlike, Connection from nvflare.fuel.f3.drivers import net_utils @@ -121,7 +121,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: scheme = "https" diff --git a/nvflare/fuel/f3/drivers/grpc_driver.py b/nvflare/fuel/f3/drivers/grpc_driver.py index 764a61772e..bc37e54fce 100644 --- a/nvflare/fuel/f3/drivers/grpc_driver.py +++ b/nvflare/fuel/f3/drivers/grpc_driver.py @@ -20,7 +20,7 @@ import grpc from nvflare.fuel.f3.comm_config import CommConfigurator -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.comm_error import CommError from nvflare.fuel.f3.connection import Connection from nvflare.fuel.f3.drivers.driver import ConnectorInfo @@ -275,7 +275,7 @@ def connect(self, connector: ConnectorInfo): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: if use_aio_grpc(): scheme = "nagrpcs" diff --git a/nvflare/fuel/f3/drivers/tcp_driver.py b/nvflare/fuel/f3/drivers/tcp_driver.py index 6ee11a4f4f..09256bca88 100644 --- a/nvflare/fuel/f3/drivers/tcp_driver.py +++ b/nvflare/fuel/f3/drivers/tcp_driver.py @@ -17,7 +17,7 @@ from socketserver import TCPServer, ThreadingTCPServer from typing import Any, Dict, List -from nvflare.fuel.f3.comm_config_utils import requires_secure +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection from nvflare.fuel.f3.drivers.base_driver import BaseDriver from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams @@ -101,7 +101,7 @@ def shutdown(self): @staticmethod def get_urls(scheme: str, resources: dict) -> (str, str): - secure = requires_secure(resources) + secure = requires_secure_connection(resources) if secure: scheme = "stcp" diff --git a/nvflare/fuel/utils/url_utils.py b/nvflare/fuel/utils/url_utils.py index 2219ebe424..38af73b705 100644 --- a/nvflare/fuel/utils/url_utils.py +++ b/nvflare/fuel/utils/url_utils.py @@ -16,22 +16,53 @@ def make_url(scheme: str, address, secure: bool) -> str: + """Make a full URL based on specified info + + Args: + scheme: scheme of the url + address: host address. Multiple formats are supported: + str: this is a string that contains host name and optionally port number (e.g. localhost:1234) + dict: contains item "host" and optionally "port" + tuple or list: contains 1 or 2 items for host and port + secure: whether secure connection is required + + Returns: + + """ + secure_scheme = _SECURE_SCHEME_MAPPING.get(scheme) + if not secure_scheme: + raise ValueError(f"unsupported scheme '{scheme}'") + if secure: - scheme = _SECURE_SCHEME_MAPPING.get(scheme) - if not scheme: - raise ValueError(f"unsupported scheme '{scheme}'") + scheme = secure_scheme if isinstance(address, str): + if not address: + raise ValueError("address must not be empty") return f"{scheme}://{address}" else: port = None if isinstance(address, (tuple, list)): + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") host = address[0] if len(address) > 1: port = address[1] elif isinstance(address, dict): - host = address["host"] + if len(address) < 1: + raise ValueError("address must not be empty") + if len(address) > 2: + raise ValueError(f"invalid address {address}") + + host = address.get("host") + if not host: + raise ValueError(f"invalid address {address}: missing 'host'") + port = address.get("port") + if not port and len(address) > 1: + raise ValueError(f"invalid address {address}: missing 'port'") else: raise ValueError(f"invalid address: {address}") diff --git a/tests/unit_test/fuel/f3/comm_config_utils_test.py b/tests/unit_test/fuel/f3/comm_config_utils_test.py new file mode 100644 index 0000000000..c703d9d4c3 --- /dev/null +++ b/tests/unit_test/fuel/f3/comm_config_utils_test.py @@ -0,0 +1,90 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.apis.fl_constant import ConnectionSecurity +from nvflare.fuel.f3.comm_config_utils import requires_secure_connection +from nvflare.fuel.f3.drivers.driver_params import DriverParams + + +class TestCommConfigUtils: + + @pytest.mark.parametrize( + "resources, expected", + [ + ({}, False), + ({"x": 1, "y": 2}, False), + ({DriverParams.SECURE.value: True}, True), + ({DriverParams.SECURE.value: False}, False), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE}, False), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS}, True), + ({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS}, True), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: False, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: False, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS, + DriverParams.SECURE.value: True, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS, + DriverParams.SECURE.value: True, + }, + True + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + DriverParams.SECURE.value: True, + }, + False + ), + ( + { + DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE, + }, + False + ), + ], + ) + def test_requires_secure_connection(self, resources, expected): + result = requires_secure_connection(resources) + assert result == expected diff --git a/tests/unit_test/fuel/utils/url_utils_test.py b/tests/unit_test/fuel/utils/url_utils_test.py new file mode 100644 index 0000000000..e4c338ffbc --- /dev/null +++ b/tests/unit_test/fuel/utils/url_utils_test.py @@ -0,0 +1,66 @@ +# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest + +from nvflare.fuel.utils.url_utils import make_url + + +class TestUrlUtils: + + @pytest.mark.parametrize( + "scheme, address, secure, expected", + [ + ("tcp", "xyz.com", False, "tcp://xyz.com"), + ("tcp", "xyz.com:1234", False, "tcp://xyz.com:1234"), + ("tcp", "xyz.com:1234", True, "stcp://xyz.com:1234"), + ("grpc", "xyz.com", False, "grpc://xyz.com"), + ("grpc", "xyz.com:1234", False, "grpc://xyz.com:1234"), + ("grpc", "xyz.com:1234", True, "grpcs://xyz.com:1234"), + ("http", "xyz.com", False, "http://xyz.com"), + ("http", "xyz.com:1234", False, "http://xyz.com:1234"), + ("http", "xyz.com:1234", True, "https://xyz.com:1234"), + ("tcp", ("xyz.com",), False, "tcp://xyz.com"), + ("tcp", ("xyz.com", 1234), False, "tcp://xyz.com:1234"), + ("tcp", ["xyz.com"], False, "tcp://xyz.com"), + ("tcp", ["xyz.com", 1234], False, "tcp://xyz.com:1234"), + ("tcp", {"host": "xyz.com"}, False, "tcp://xyz.com"), + ("tcp", {"host": "xyz.com", "port": 1234}, False, "tcp://xyz.com:1234"), + ], + ) + def test_make_url(self, scheme, address, secure, expected): + result = make_url(scheme, address, secure) + assert result == expected + + @pytest.mark.parametrize( + "scheme, address, secure", + [ + ("tcp", "", False), + ("abc", "xyz.com:1234", False), + ("tcp", 1234, True), + ("grpc", [], False), + ("grpc", (), False), + ("grpc", {}, True), + ("http", [1234], False), + ("http", [1234, "xyz.com"], False), + ("http", ["xyz.com", 1234, 22], True), + ("http", (1234,), False), + ("http", (1234, "xyz.com"), False), + ("http", ("xyz.com", 1234, 22), True), + ("tcp", {"hosts": "xyz.com"}, False), + ("tcp", {"host": "xyz.com", "port": 1234, "extra": 2323}, False), + ], + ) + def test_make_url_error(self, scheme, address, secure): + with pytest.raises(ValueError): + make_url(scheme, address, secure)