Skip to content

Commit

Permalink
add test cases
Browse files Browse the repository at this point in the history
  • Loading branch information
yanchengnv committed Jan 31, 2025
1 parent 6047f16 commit b50aefa
Show file tree
Hide file tree
Showing 8 changed files with 211 additions and 14 deletions.
14 changes: 12 additions & 2 deletions nvflare/fuel/f3/comm_config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,22 @@
from nvflare.fuel.f3.drivers.driver_params import DriverParams


def requires_secure(resources: dict):
def requires_secure_connection(resources: dict):
"""Determine whether secure connection is required based on information in resources.
Args:
resources: a dict that contains info for making connection
Returns: whether secure connection is required
"""
conn_sec = resources.get(DriverParams.CONNECTION_SECURITY)
if conn_sec:
# if connection security is specified, it takes precedence over the "secure" flag
if conn_sec == ConnectionSecurity.INSECURE:
return False
else:
return True
else:
return resources.get(DriverParams.SECURE)
# Connection security is not specified, check the "secure" flag.
return resources.get(DriverParams.SECURE, False)
4 changes: 2 additions & 2 deletions nvflare/fuel/f3/drivers/aio_grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import grpc

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_config_utils import requires_secure
from nvflare.fuel.f3.comm_config_utils import requires_secure_connection
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers.aio_context import AioContext
Expand Down Expand Up @@ -410,7 +410,7 @@ def shutdown(self):

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
secure = requires_secure(resources)
secure = requires_secure_connection(resources)
if secure:
if use_aio_grpc():
scheme = "grpcs"
Expand Down
4 changes: 2 additions & 2 deletions nvflare/fuel/f3/drivers/aio_http_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import websockets
from websockets.exceptions import ConnectionClosedOK

from nvflare.fuel.f3.comm_config_utils import requires_secure
from nvflare.fuel.f3.comm_config_utils import requires_secure_connection
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers import net_utils
Expand Down Expand Up @@ -121,7 +121,7 @@ def shutdown(self):

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
secure = requires_secure(resources)
secure = requires_secure_connection(resources)
if secure:
scheme = "https"

Expand Down
4 changes: 2 additions & 2 deletions nvflare/fuel/f3/drivers/grpc_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import grpc

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_config_utils import requires_secure
from nvflare.fuel.f3.comm_config_utils import requires_secure_connection
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import Connection
from nvflare.fuel.f3.drivers.driver import ConnectorInfo
Expand Down Expand Up @@ -275,7 +275,7 @@ def connect(self, connector: ConnectorInfo):

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
secure = requires_secure(resources)
secure = requires_secure_connection(resources)
if secure:
if use_aio_grpc():
scheme = "nagrpcs"
Expand Down
4 changes: 2 additions & 2 deletions nvflare/fuel/f3/drivers/tcp_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from socketserver import TCPServer, ThreadingTCPServer
from typing import Any, Dict, List

from nvflare.fuel.f3.comm_config_utils import requires_secure
from nvflare.fuel.f3.comm_config_utils import requires_secure_connection
from nvflare.fuel.f3.drivers.base_driver import BaseDriver
from nvflare.fuel.f3.drivers.driver import ConnectorInfo, Driver
from nvflare.fuel.f3.drivers.driver_params import DriverCap, DriverParams
Expand Down Expand Up @@ -101,7 +101,7 @@ def shutdown(self):

@staticmethod
def get_urls(scheme: str, resources: dict) -> (str, str):
secure = requires_secure(resources)
secure = requires_secure_connection(resources)
if secure:
scheme = "stcp"

Expand Down
39 changes: 35 additions & 4 deletions nvflare/fuel/utils/url_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,22 +16,53 @@


def make_url(scheme: str, address, secure: bool) -> str:
"""Make a full URL based on specified info
Args:
scheme: scheme of the url
address: host address. Multiple formats are supported:
str: this is a string that contains host name and optionally port number (e.g. localhost:1234)
dict: contains item "host" and optionally "port"
tuple or list: contains 1 or 2 items for host and port
secure: whether secure connection is required
Returns:
"""
secure_scheme = _SECURE_SCHEME_MAPPING.get(scheme)
if not secure_scheme:
raise ValueError(f"unsupported scheme '{scheme}'")

if secure:
scheme = _SECURE_SCHEME_MAPPING.get(scheme)
if not scheme:
raise ValueError(f"unsupported scheme '{scheme}'")
scheme = secure_scheme

if isinstance(address, str):
if not address:
raise ValueError("address must not be empty")
return f"{scheme}://{address}"
else:
port = None
if isinstance(address, (tuple, list)):
if len(address) < 1:
raise ValueError("address must not be empty")
if len(address) > 2:
raise ValueError(f"invalid address {address}")
host = address[0]
if len(address) > 1:
port = address[1]
elif isinstance(address, dict):
host = address["host"]
if len(address) < 1:
raise ValueError("address must not be empty")
if len(address) > 2:
raise ValueError(f"invalid address {address}")

host = address.get("host")
if not host:
raise ValueError(f"invalid address {address}: missing 'host'")

port = address.get("port")
if not port and len(address) > 1:
raise ValueError(f"invalid address {address}: missing 'port'")
else:
raise ValueError(f"invalid address: {address}")

Expand Down
90 changes: 90 additions & 0 deletions tests/unit_test/fuel/f3/comm_config_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from nvflare.apis.fl_constant import ConnectionSecurity
from nvflare.fuel.f3.comm_config_utils import requires_secure_connection
from nvflare.fuel.f3.drivers.driver_params import DriverParams


class TestCommConfigUtils:

@pytest.mark.parametrize(
"resources, expected",
[
({}, False),
({"x": 1, "y": 2}, False),
({DriverParams.SECURE.value: True}, True),
({DriverParams.SECURE.value: False}, False),
({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE}, False),
({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS}, True),
({DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS}, True),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS,
DriverParams.SECURE.value: False,
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS,
DriverParams.SECURE.value: False,
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.MTLS,
DriverParams.SECURE.value: True,
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.TLS,
DriverParams.SECURE.value: True,
},
True
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE,
DriverParams.SECURE.value: True,
},
False
),
(
{
DriverParams.CONNECTION_SECURITY.value: ConnectionSecurity.INSECURE,
},
False
),
],
)
def test_requires_secure_connection(self, resources, expected):
result = requires_secure_connection(resources)
assert result == expected
66 changes: 66 additions & 0 deletions tests/unit_test/fuel/utils/url_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest

from nvflare.fuel.utils.url_utils import make_url


class TestUrlUtils:

@pytest.mark.parametrize(
"scheme, address, secure, expected",
[
("tcp", "xyz.com", False, "tcp://xyz.com"),
("tcp", "xyz.com:1234", False, "tcp://xyz.com:1234"),
("tcp", "xyz.com:1234", True, "stcp://xyz.com:1234"),
("grpc", "xyz.com", False, "grpc://xyz.com"),
("grpc", "xyz.com:1234", False, "grpc://xyz.com:1234"),
("grpc", "xyz.com:1234", True, "grpcs://xyz.com:1234"),
("http", "xyz.com", False, "http://xyz.com"),
("http", "xyz.com:1234", False, "http://xyz.com:1234"),
("http", "xyz.com:1234", True, "https://xyz.com:1234"),
("tcp", ("xyz.com",), False, "tcp://xyz.com"),
("tcp", ("xyz.com", 1234), False, "tcp://xyz.com:1234"),
("tcp", ["xyz.com"], False, "tcp://xyz.com"),
("tcp", ["xyz.com", 1234], False, "tcp://xyz.com:1234"),
("tcp", {"host": "xyz.com"}, False, "tcp://xyz.com"),
("tcp", {"host": "xyz.com", "port": 1234}, False, "tcp://xyz.com:1234"),
],
)
def test_make_url(self, scheme, address, secure, expected):
result = make_url(scheme, address, secure)
assert result == expected

@pytest.mark.parametrize(
"scheme, address, secure",
[
("tcp", "", False),
("abc", "xyz.com:1234", False),
("tcp", 1234, True),
("grpc", [], False),
("grpc", (), False),
("grpc", {}, True),
("http", [1234], False),
("http", [1234, "xyz.com"], False),
("http", ["xyz.com", 1234, 22], True),
("http", (1234,), False),
("http", (1234, "xyz.com"), False),
("http", ("xyz.com", 1234, 22), True),
("tcp", {"hosts": "xyz.com"}, False),
("tcp", {"host": "xyz.com", "port": 1234, "extra": 2323}, False),
],
)
def test_make_url_error(self, scheme, address, secure):
with pytest.raises(ValueError):
make_url(scheme, address, secure)

0 comments on commit b50aefa

Please sign in to comment.