Skip to content

Commit 71182f2

Browse files
authored
Add method to determine server TLD (#8)
copied from Hochfrequenz/tmdsclient.py#22
1 parent cfcb124 commit 71182f2

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

src/bssclient/client/bssclient.py

+20
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
from aiohttp import BasicAuth, ClientSession, ClientTimeout
99
from more_itertools import chunked
10+
from yarl import URL
1011

1112
from bssclient.client.config import BssConfig
1213
from bssclient.models.aufgabe import AufgabeStats
@@ -27,6 +28,25 @@ def __init__(self, config: BssConfig):
2728
self._session: Optional[ClientSession] = None
2829
_logger.info("Instantiated BssClient with server_url %s", str(self._config.server_url))
2930

31+
def get_top_level_domain(self) -> URL | None:
32+
"""
33+
Returns the top level domain of the server_url; this is useful to differentiate prod from test systems.
34+
If the server_url is an IP address, None is returned.
35+
"""
36+
# this method is unit tested; check the testcases to understand its branches
37+
domain_parts = self._config.server_url.host.split(".") # type:ignore[union-attr]
38+
if all(x.isnumeric() for x in domain_parts):
39+
# seems like this is an IP address
40+
return None
41+
if not any(domain_parts):
42+
return self._config.server_url
43+
tld: str
44+
if domain_parts[-1] == "localhost":
45+
tld = ".".join(domain_parts[-1:])
46+
else:
47+
tld = ".".join(domain_parts[-2:])
48+
return URL(self._config.server_url.scheme + "://" + tld)
49+
3050
async def _get_session(self) -> ClientSession:
3151
"""
3252
returns a client session (that may be reused or newly created)

unittests/test_bss_client.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from yarl import URL
3+
4+
from bssclient import BssClient, BssConfig
5+
6+
7+
@pytest.mark.parametrize(
8+
"actual_url,expected_tld",
9+
[
10+
pytest.param(URL("https://bss.example.com"), URL("https://example.com")),
11+
pytest.param(URL("https://bss.prod.de"), URL("https://prod.de")),
12+
pytest.param(URL("https://test.de"), URL("https://test.de")),
13+
pytest.param(URL("https://localhost"), URL("https://localhost")),
14+
pytest.param(URL("http://test.localhost"), URL("http://localhost")),
15+
pytest.param(URL("http://foo.bar.test.localhost"), URL("http://localhost")),
16+
pytest.param(URL("http://1.2.3.4"), None),
17+
],
18+
)
19+
def test_get_tld(actual_url: URL, expected_tld: URL):
20+
config = BssConfig(server_url=actual_url, usr="user", pwd="password")
21+
client = BssClient(config)
22+
actual = client.get_top_level_domain()
23+
assert actual == expected_tld

0 commit comments

Comments
 (0)