diff --git a/nats/aio/client.py b/nats/aio/client.py index ff924fe8..cb80d9c2 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -1270,50 +1270,59 @@ async def _flush_pending( except asyncio.CancelledError: pass - def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: - if isinstance(connect_url, str): - try: - if "nats://" in connect_url or "tls://" in connect_url: - # Closer to how the Go client handles this. - # e.g. nats://localhost:4222 - uri = urlparse(connect_url) - elif "ws://" in connect_url or "wss://" in connect_url: - uri = urlparse(connect_url) - elif ":" in connect_url: - # Expand the scheme for the user - # e.g. localhost:4222 - uri = urlparse(f"nats://{connect_url}") - else: - # Just use the endpoint with the default NATS port. - # e.g. demo.nats.io - uri = urlparse(f"nats://{connect_url}:4222") - - # In case only endpoint with scheme was set. - # e.g. nats://demo.nats.io or localhost: - # the ws and wss do not need a default port as the transport will assume 80 and 443, respectively - if uri.port is None and uri.scheme not in ("ws", "wss"): - uri = urlparse(f"nats://{uri.hostname}:4222") - except ValueError: - raise errors.Error("nats: invalid connect url option") + def _normalize_url(self, url: str): + """ + Normalizes and validates a connection URL to ensure compatibility. + Adds default schemes and ports if missing, and checks for a valid hostname. + """ + try: + if "nats://" in url or "tls://" in url: + # Closer to how the Go client handles this. + # e.g. nats://localhost:4222 + uri = urlparse(url) + elif "ws://" in url or "wss://" in url: + uri = urlparse(url) + elif ":" in url: + # Expand the scheme for the user + # e.g. localhost:4222 + uri = urlparse(f"nats://{url}") + else: + # Just use the endpoint with the default NATS port. + # e.g. demo.nats.io + uri = urlparse(f"nats://{url}:4222") + + # In case only endpoint with scheme was set. + # e.g. nats://demo.nats.io or localhost: + # the ws and wss do not need a default port as the transport will assume 80 and 443, respectively + if uri.port is None and uri.scheme not in ("ws", "wss"): + # Set default NATS port 4222 if no port is specified + uri = urlparse(f"nats://{uri.hostname}:4222") if uri.hostname is None or uri.hostname == "none": + # Validates the hostname raise errors.Error("nats: invalid hostname in connect url") + + return uri + except ValueError: + raise errors.Error("nats: invalid connect url option") + + def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: + + if isinstance(connect_url, str) and "," in connect_url: + connect_url = connect_url.strip().split(",") + + if isinstance(connect_url, str): + uri = self._normalize_url(connect_url) self._server_pool.append(Srv(uri)) + elif isinstance(connect_url, list): - try: - for server in connect_url: - uri = urlparse(server) - self._server_pool.append(Srv(uri)) - except ValueError: - raise errors.Error("nats: invalid connect url option") - # make sure protocols aren't mixed - if not (all(server.uri.scheme in ("nats", "tls") - for server in self._server_pool) - or all(server.uri.scheme in ("ws", "wss") - for server in self._server_pool)): - raise errors.Error( - "nats: mixing of websocket and non websocket URLs is not allowed" - ) + for url in connect_url: + uri = self._normalize_url(url) + self._server_pool.append(Srv(uri)) + + if not (all(server.uri.scheme in ("nats", "tls") for server in self._server_pool) + or all(server.uri.scheme in ("ws", "wss") for server in self._server_pool)): + raise errors.Error("nats: mixing of websocket and non-websocket URLs is not allowed") else: raise errors.Error("nats: invalid connect url option") diff --git a/tests/test_client.py b/tests/test_client.py index 905effb8..02a632a6 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -88,6 +88,31 @@ def test_connect_syntax_sugar(self): ]) self.assertEqual(3, len(nc._server_pool)) + nc._setup_server_pool([ + "nats://127.0.0.1", "nats://127.0.0.1", + "nats://127.0.0.1" + ]) + self.assertEqual(3, len(nc._server_pool)) + + nc._setup_server_pool([ + "127.0.0.1", "127.0.0.1", + "127.0.0.1" + ]) + self.assertEqual(3, len(nc._server_pool)) + self.assertEqual(4222, nc._server_pool[0].uri.port) + + nc._setup_server_pool( + "127.0.0.1, 127.0.0.1, 127.0.0.1" + ) + self.assertEqual(3, len(nc._server_pool)) + self.assertEqual(4222, nc._server_pool[0].uri.port) + + nc._setup_server_pool([ + "nats://127.0.0.1:4222", "nats://127.0.0.1:4223", + "nats://127.0.0.1:4224" + ]) + self.assertEqual(3, len(nc._server_pool)) + nc = NATS() nc._setup_server_pool("nats://127.0.0.1:4222") self.assertEqual(1, len(nc._server_pool))