diff --git a/nats/aio/client.py b/nats/aio/client.py index 76fa0480..15a3fb40 100644 --- a/nats/aio/client.py +++ b/nats/aio/client.py @@ -38,6 +38,7 @@ TypedDict, Union, ) +from typing_extensions import AsyncIterator from urllib.parse import ParseResult, urlparse from collections import UserString from io import BytesIO @@ -69,26 +70,26 @@ ) from .transport import TcpTransport, Transport, WebSocketTransport -__version__ = '2.9.0' -__lang__ = 'python3' +__version__ = "2.9.0" +__lang__ = "python3" _logger = logging.getLogger(__name__) PROTOCOL = 1 -INFO_OP = b'INFO' -CONNECT_OP = b'CONNECT' -PING_OP = b'PING' -PONG_OP = b'PONG' -OK_OP = b'+OK' -ERR_OP = b'-ERR' -_CRLF_ = b'\r\n' +INFO_OP = b"INFO" +CONNECT_OP = b"CONNECT" +PING_OP = b"PING" +PONG_OP = b"PONG" +OK_OP = b"+OK" +ERR_OP = b"-ERR" +_CRLF_ = b"\r\n" _CRLF_LEN_ = len(_CRLF_) -_SPC_ = b' ' +_SPC_ = b" " _SPC_BYTE_ = 32 EMPTY = "" PING_PROTO = PING_OP + _CRLF_ PONG_PROTO = PONG_OP + _CRLF_ -DEFAULT_INBOX_PREFIX = b'_INBOX' +DEFAULT_INBOX_PREFIX = b"_INBOX" DEFAULT_PENDING_SIZE = 2 * 1024 * 1024 DEFAULT_BUFFER_SIZE = 32768 @@ -103,7 +104,7 @@ DEFAULT_DRAIN_TIMEOUT = 30 # in seconds MAX_CONTROL_LINE_SIZE = 1024 -NATS_HDR_LINE = bytearray(b'NATS/1.0') +NATS_HDR_LINE = bytearray(b"NATS/1.0") NATS_HDR_LINE_SIZE = len(NATS_HDR_LINE) NO_RESPONDERS_STATUS = "503" CTRL_STATUS = "100" @@ -138,7 +139,6 @@ class Srv: class ServerVersion: - def __init__(self, server_version: str) -> None: self._server_version = server_version self._major_version: Optional[int] = None @@ -148,10 +148,10 @@ def __init__(self, server_version: str) -> None: # TODO(@orsinium): use cached_property def parse_version(self) -> None: - v = (self._server_version).split('-') + v = (self._server_version).split("-") if len(v) > 1: self._dev_version = v[1] - tokens = v[0].split('.') + tokens = v[0].split(".") n = len(tokens) if n > 1: self._major_version = int(tokens[0]) @@ -182,7 +182,7 @@ def patch(self) -> int: def dev(self) -> str: if not self._dev_version: self.parse_version() - return self._dev_version or '' + return self._dev_version or "" def __repr__(self) -> str: return f"" @@ -193,7 +193,7 @@ async def _default_error_callback(ex: Exception) -> None: Provides a default way to handle async errors if the user does not provide one. """ - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) class Client: @@ -287,12 +287,12 @@ def __init__(self) -> None: self.options: Dict[str, Any] = {} self.stats = { - 'in_msgs': 0, - 'out_msgs': 0, - 'in_bytes': 0, - 'out_bytes': 0, - 'reconnects': 0, - 'errors_received': 0, + "in_msgs": 0, + "out_msgs": 0, + "in_bytes": 0, + "out_bytes": 0, + "reconnects": 0, + "errors_received": 0, } async def connect( @@ -420,8 +420,13 @@ async def subscribe_handler(msg): """ - for cb in [error_cb, disconnected_cb, closed_cb, reconnected_cb, - discovered_server_cb]: + for cb in [ + error_cb, + disconnected_cb, + closed_cb, + reconnected_cb, + discovered_server_cb, + ]: if cb and not asyncio.iscoroutinefunction(cb): raise errors.InvalidCallbackTypeError @@ -461,12 +466,12 @@ async def subscribe_handler(msg): self.options["token"] = token self.options["connect_timeout"] = connect_timeout self.options["drain_timeout"] = drain_timeout - self.options['tls_handshake_first'] = tls_handshake_first + self.options["tls_handshake_first"] = tls_handshake_first if tls: - self.options['tls'] = tls + self.options["tls"] = tls if tls_hostname: - self.options['tls_hostname'] = tls_hostname + self.options["tls_hostname"] = tls_hostname # Check if the username or password was set in the server URI server_auth_configured = False @@ -502,7 +507,9 @@ async def subscribe_handler(msg): try: await self._select_next_server() await self._process_connect_init() - assert self._current_server, "the current server must be set by _select_next_server" + assert ( + self._current_server + ), "the current server must be set by _select_next_server" self._current_server.reconnects = 0 break except errors.NoServersError as e: @@ -542,7 +549,7 @@ def _setup_nkeys_jwt_connect(self) -> None: def user_cb() -> bytearray: contents = None - with open(creds[0], 'rb') as f: + with open(creds[0], "rb") as f: contents = bytearray(os.fstat(f.fileno()).st_size) f.readinto(contents) # type: ignore[attr-defined] return contents @@ -551,7 +558,7 @@ def user_cb() -> bytearray: def sig_cb(nonce: str) -> bytes: seed = None - with open(creds[1], 'rb') as f: + with open(creds[1], "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] kp = nkeys.from_seed(seed) @@ -587,7 +594,6 @@ def sig_cb(nonce: str) -> bytes: self._signature_cb = sig_cb def _read_creds_user_nkey(self, creds: str | UserString) -> bytearray: - def get_user_seed(f): for line in f: # Detect line where the NKEY would start and end, @@ -616,7 +622,6 @@ def get_user_seed(f): return get_user_seed(f) def _read_creds_user_jwt(self, creds: str | RawCredentials): - def get_user_jwt(f): user_jwt = None while True: @@ -625,7 +630,7 @@ def get_user_jwt(f): user_jwt = bytearray(f.readline()) break # Remove trailing line break but reusing same memory view. - return user_jwt[:len(user_jwt) - 1] + return user_jwt[: len(user_jwt) - 1] if isinstance(creds, UserString): return get_user_jwt(BytesIO(creds.data.encode())) @@ -634,7 +639,9 @@ def get_user_jwt(f): return get_user_jwt(f) def _setup_nkeys_seed_connect(self) -> None: - assert self._nkeys_seed or self._nkeys_seed_str, "Client.connect must be called first" + assert ( + self._nkeys_seed or self._nkeys_seed_str + ), "Client.connect must be called first" import os import nkeys @@ -645,7 +652,7 @@ def _get_nkeys_seed() -> nkeys.KeyPair: seed = bytearray(self._nkeys_seed_str.encode()) else: creds = self._nkeys_seed - with open(creds, 'rb') as f: + with open(creds, "rb") as f: seed = bytearray(os.fstat(f.fileno()).st_size) f.readinto(seed) # type: ignore[attr-defined] key_pair = nkeys.from_seed(seed) @@ -686,25 +693,26 @@ async def _close(self, status: int, do_cbs: bool = True) -> None: # Kick the flusher once again so that Task breaks and avoid pending futures. await self._flush_pending() - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( + if ( + self._ping_interval_task is not None + and not self._ping_interval_task.cancelled() ): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() - if self._reconnection_task is not None and not self._reconnection_task.done( - ): + if self._reconnection_task is not None and not self._reconnection_task.done(): self._reconnection_task.cancel() # Wait for the reconnection task to be done which should be soon. try: - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( + if ( + self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled() ): await asyncio.wait_for( self._reconnection_task_future, @@ -788,9 +796,7 @@ async def drain(self) -> None: self._status = Client.DRAINING_SUBS try: - await asyncio.wait_for( - drain_is_done, self.options["drain_timeout"] - ) + await asyncio.wait_for(drain_is_done, self.options["drain_timeout"]) except asyncio.TimeoutError: drain_is_done.exception() drain_is_done.cancel() @@ -805,9 +811,9 @@ async def drain(self) -> None: async def publish( self, subject: str, - payload: bytes = b'', - reply: str = '', - headers: Optional[Dict[str, str]] = None + payload: bytes = b"", + reply: str = "", + headers: Optional[Dict[str, str]] = None, ) -> None: """ Publishes a NATS message. @@ -861,16 +867,17 @@ async def main(): payload_size = len(payload) if not self.is_connected: - if self._max_pending_size <= 0 or payload_size + self._pending_data_size > self._max_pending_size: + if ( + self._max_pending_size <= 0 + or payload_size + self._pending_data_size > self._max_pending_size + ): # Cannot publish during a reconnection when the buffering is disabled, # or if pending buffer is already full. raise errors.OutboundBufferLimitError if payload_size > self._max_payload: raise errors.MaxPayloadError - await self._send_publish( - subject, reply, payload, payload_size, headers - ) + await self._send_publish(subject, reply, payload, payload_size, headers) async def _send_publish( self, @@ -900,15 +907,15 @@ async def _send_publish( # Skip empty keys continue hdr.extend(key.encode()) - hdr.extend(b': ') + hdr.extend(b": ") value = v.strip() hdr.extend(value.encode()) hdr.extend(_CRLF_) hdr.extend(_CRLF_) pub_cmd = prot_command.hpub_cmd(subject, reply, hdr, payload) - self.stats['out_msgs'] += 1 - self.stats['out_bytes'] += payload_size + self.stats["out_msgs"] += 1 + self.stats["out_bytes"] += payload_size await self._send_command(pub_cmd) if self._flush_queue is not None and self._flush_queue.empty(): await self._flush_pending() @@ -931,10 +938,10 @@ async def subscribe( If a callback isn't provided, messages can be retrieved via an asynchronous iterator on the returned subscription object. """ - if not subject or (' ' in subject): + if not subject or (" " in subject): raise errors.BadSubjectError - if queue and (' ' in queue): + if queue and (" " in queue): raise errors.BadSubjectError if self.is_closed: @@ -979,17 +986,15 @@ async def _init_request_sub(self) -> None: self._resp_map = {} self._resp_sub_prefix = self._inbox_prefix[:] - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") self._resp_sub_prefix.extend(self._nuid.next()) - self._resp_sub_prefix.extend(b'.') + self._resp_sub_prefix.extend(b".") resp_mux_subject = self._resp_sub_prefix[:] - resp_mux_subject.extend(b'*') - await self.subscribe( - resp_mux_subject.decode(), cb=self._request_sub_callback - ) + resp_mux_subject.extend(b"*") + await self.subscribe(resp_mux_subject.decode(), cb=self._request_sub_callback) async def _request_sub_callback(self, msg: Msg) -> None: - token = msg.subject[len(self._inbox_prefix) + 22 + 2:] + token = msg.subject[len(self._inbox_prefix) + 22 + 2 :] future = self._resp_map.get(token) if not future: @@ -1001,7 +1006,7 @@ async def _request_sub_callback(self, msg: Msg) -> None: async def request( self, subject: str, - payload: bytes = b'', + payload: bytes = b"", timeout: float = 0.5, old_style: bool = False, headers: Optional[Dict[str, Any]] = None, @@ -1014,15 +1019,15 @@ async def request( """ if old_style: # FIXME: Support headers in old style requests. - return await self._request_old_style( - subject, payload, timeout=timeout - ) + return await self._request_old_style(subject, payload, timeout=timeout) else: msg = await self._request_new_style( subject, payload, timeout=timeout, headers=headers ) - if msg.headers and msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if ( + msg.headers + and msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS + ): raise errors.NoRespondersError return msg @@ -1052,9 +1057,7 @@ async def _request_new_style( self._resp_map[token.decode()] = future # Publish the request - await self.publish( - subject, payload, reply=inbox.decode(), headers=headers - ) + await self.publish(subject, payload, reply=inbox.decode(), headers=headers) # Wait for the response or give up on timeout. try: @@ -1062,6 +1065,42 @@ async def _request_new_style( except asyncio.TimeoutError: raise errors.TimeoutError + async def request_many( + self, + subject: str, + payload: bytes = b"", + max_wait: float = 1.5, + max_interval: float | None = None, + max_msgs: int | None = None, + ) -> AsyncIterator[Msg]: + reply_inbox = self.new_inbox() + subscription = await self.subscribe(reply_inbox) + await self.publish(subject, payload, reply=reply_inbox) + + async def generate_messages(): + try: + msg_count = 0 + start_time = time.monotonic() + deadline = start_time + max_wait + + while max_msgs is None or msg_count < max_msgs: + remaining = deadline - time.monotonic() + if remaining <= 0: + break + + timeout = min(remaining, max_interval or remaining) + + try: + msg = await subscription.next_msg(timeout=timeout) + yield msg + msg_count += 1 + except asyncio.TimeoutError: + break + finally: + await subscription.unsubscribe() + + return generate_messages() + def new_inbox(self) -> str: """ new_inbox returns a unique inbox that can be used @@ -1074,7 +1113,7 @@ def new_inbox(self) -> str: msg = sub.next_msg() """ next_inbox = self._inbox_prefix[:] - next_inbox.extend(b'.') + next_inbox.extend(b".") next_inbox.extend(self._nuid.next()) return next_inbox.decode() @@ -1097,8 +1136,7 @@ async def _request_old_style( try: msg = await asyncio.wait_for(future, timeout) if msg.headers: - if msg.headers.get(nats.js.api.Header.STATUS - ) == NO_RESPONDERS_STATUS: + if msg.headers.get(nats.js.api.Header.STATUS) == NO_RESPONDERS_STATUS: raise errors.NoRespondersError return msg except asyncio.TimeoutError: @@ -1198,8 +1236,7 @@ def is_connecting(self) -> bool: @property def is_draining(self) -> bool: return ( - self._status == Client.DRAINING_SUBS - or self._status == Client.DRAINING_PUBS + self._status == Client.DRAINING_SUBS or self._status == Client.DRAINING_PUBS ) @property @@ -1220,11 +1257,11 @@ def connected_server_version(self) -> ServerVersion: def ssl_context(self) -> ssl.SSLContext: ssl_context: Optional[ssl.SSLContext] = None if "tls" in self.options: - ssl_context = self.options.get('tls') + ssl_context = self.options.get("tls") else: ssl_context = ssl.create_default_context() if ssl_context is None: - raise errors.Error('nats: no ssl context provided') + raise errors.Error("nats: no ssl context provided") return ssl_context async def _send_command(self, cmd: bytes, priority: bool = False) -> None: @@ -1233,7 +1270,10 @@ async def _send_command(self, cmd: bytes, priority: bool = False) -> None: else: self._pending.append(cmd) self._pending_data_size += len(cmd) - if self._max_pending_size > 0 and self._pending_data_size > self._max_pending_size: + if ( + self._max_pending_size > 0 + and self._pending_data_size > self._max_pending_size + ): # Only flush force timeout on publish await self._flush_pending(force_flush=True) @@ -1298,10 +1338,14 @@ def _setup_server_pool(self, connect_url: Union[List[str]]) -> None: except ValueError: raise errors.Error("nats: invalid connect url option") # make sure protocols aren't mixed - if not (all(server.uri.scheme in ("nats", "tls") - for server in self._server_pool) - or all(server.uri.scheme in ("ws", "wss") - for server in self._server_pool)): + if not ( + all( + server.uri.scheme in ("nats", "tls") for server in self._server_pool + ) + or all( + server.uri.scheme in ("ws", "wss") for server in self._server_pool + ) + ): raise errors.Error( "nats: mixing of websocket and non websocket URLs is not allowed" ) @@ -1329,8 +1373,10 @@ async def _select_next_server(self) -> None: # Not yet exceeded max_reconnect_attempts so can still use # this server in the future. self._server_pool.append(s) - if s.last_attempt is not None and now < s.last_attempt + self.options[ - "reconnect_time_wait"]: + if ( + s.last_attempt is not None + and now < s.last_attempt + self.options["reconnect_time_wait"] + ): # Backoff connecting to server if we attempted recently. await asyncio.sleep(self.options["reconnect_time_wait"]) try: @@ -1347,13 +1393,13 @@ async def _select_next_server(self) -> None: s.uri, ssl_context=self.ssl_context, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) else: await self._transport.connect( s.uri, buffer_size=DEFAULT_BUFFER_SIZE, - connect_timeout=self.options['connect_timeout'] + connect_timeout=self.options["connect_timeout"], ) self._current_server = s break @@ -1409,7 +1455,9 @@ async def _process_op_err(self, e: Exception) -> None: self._status = Client.RECONNECTING self._ps.reset() - if self._reconnection_task is not None and not self._reconnection_task.cancelled( + if ( + self._reconnection_task is not None + and not self._reconnection_task.cancelled() ): # Cancel the previous task in case it may still be running. self._reconnection_task.cancel() @@ -1424,16 +1472,16 @@ async def _process_op_err(self, e: Exception) -> None: async def _attempt_reconnect(self) -> None: assert self._current_server, "Client.connect must be called first" - if self._reading_task is not None and not self._reading_task.cancelled( - ): + if self._reading_task is not None and not self._reading_task.cancelled(): self._reading_task.cancel() - if self._ping_interval_task is not None and not self._ping_interval_task.cancelled( + if ( + self._ping_interval_task is not None + and not self._ping_interval_task.cancelled() ): self._ping_interval_task.cancel() - if self._flusher_task is not None and not self._flusher_task.cancelled( - ): + if self._flusher_task is not None and not self._flusher_task.cancelled(): self._flusher_task.cancel() if self._transport is not None: @@ -1450,8 +1498,7 @@ async def _attempt_reconnect(self) -> None: if self.is_closed: return - if "dont_randomize" not in self.options or not self.options[ - "dont_randomize"]: + if "dont_randomize" not in self.options or not self.options["dont_randomize"]: shuffle(self._server_pool) # Create a future that the client can use to control waiting @@ -1487,9 +1534,7 @@ async def _attempt_reconnect(self) -> None: # auto unsubscribe the number of messages we have left max_msgs = sub._max_msgs - sub._received - sub_cmd = prot_command.sub_cmd( - sub._subject, sub._queue, sid - ) + sub_cmd = prot_command.sub_cmd(sub._subject, sub._queue, sid) self._transport.write(sub_cmd) if max_msgs > 0: @@ -1525,24 +1570,26 @@ async def _attempt_reconnect(self) -> None: except asyncio.CancelledError: break - if self._reconnection_task_future is not None and not self._reconnection_task_future.cancelled( + if ( + self._reconnection_task_future is not None + and not self._reconnection_task_future.cancelled() ): self._reconnection_task_future.set_result(True) def _connect_command(self) -> bytes: - ''' + """ Generates a JSON string with the params to be used when sending CONNECT to the server. ->> CONNECT {"lang": "python3"} - ''' + """ options = { "verbose": self.options["verbose"], "pedantic": self.options["pedantic"], "lang": __lang__, "version": __version__, - "protocol": PROTOCOL + "protocol": PROTOCOL, } if "headers" in self._server_info: options["headers"] = self._server_info["headers"] @@ -1560,8 +1607,10 @@ def _connect_command(self) -> bytes: options["nkey"] = self._public_nkey # In case there is no password, then consider handle # sending a token instead. - elif self.options["user"] is not None and self.options[ - "password"] is not None: + elif ( + self.options["user"] is not None + and self.options["password"] is not None + ): options["user"] = self.options["user"] options["pass"] = self.options["password"] elif self.options["token"] is not None: @@ -1579,7 +1628,7 @@ def _connect_command(self) -> bytes: options["echo"] = not self.options["no_echo"] connect_opts = json.dumps(options, sort_keys=True) - return b''.join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) + return b"".join([CONNECT_OP + _SPC_ + connect_opts.encode() + _CRLF_]) async def _process_ping(self) -> None: """ @@ -1598,8 +1647,7 @@ async def _process_pong(self) -> None: self._pongs_received += 1 self._pings_outstanding = 0 - def _is_control_message(self, data, header: Dict[str, - str]) -> Optional[str]: + def _is_control_message(self, data, header: Dict[str, str]) -> Optional[str]: if len(data) > 0: return None status = header.get(nats.js.api.Header.STATUS) @@ -1628,9 +1676,9 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # if raw_headers[0] == _SPC_BYTE_: # Special handling for status messages. - line = headers[len(NATS_HDR_LINE) + 1:] + line = headers[len(NATS_HDR_LINE) + 1 :] status = line[:STATUS_MSG_LEN] - desc = line[STATUS_MSG_LEN + 1:len(line) - _CRLF_LEN_ - _CRLF_LEN_] + desc = line[STATUS_MSG_LEN + 1 : len(line) - _CRLF_LEN_ - _CRLF_LEN_] stripped_status = status.strip().decode() # Process as status only when it is a valid integer. @@ -1640,7 +1688,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # Move the raw_headers to end of line i = raw_headers.find(_CRLF_) - raw_headers = raw_headers[i + _CRLF_LEN_:] + raw_headers = raw_headers[i + _CRLF_LEN_ :] if len(desc) > 0: # Heartbeat messages can have both headers and inline status, @@ -1648,9 +1696,7 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: i = desc.find(_CRLF_) if i > 0: hdr[nats.js.api.Header.DESCRIPTION] = desc[:i].decode() - parsed_hdr = self._hdr_parser.parsebytes( - desc[i + _CRLF_LEN_:] - ) + parsed_hdr = self._hdr_parser.parsebytes(desc[i + _CRLF_LEN_ :]) for k, v in parsed_hdr.items(): hdr[k] = v else: @@ -1665,15 +1711,14 @@ async def _process_headers(self, headers) -> Optional[Dict[str, str]]: # # NATS/1.0\r\nfoo: bar\r\nhello: world # - raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_:] + raw_headers = headers[NATS_HDR_LINE_SIZE + _CRLF_LEN_ :] try: if parse_email: parsed_hdr = parse_email(raw_headers).headers else: parsed_hdr = { k.strip(): v.strip() - for k, v in self._hdr_parser.parsebytes(raw_headers - ).items() + for k, v in self._hdr_parser.parsebytes(raw_headers).items() } if hdr: hdr.update(parsed_hdr) @@ -1706,8 +1751,8 @@ async def _process_msg( Process MSG sent by server. """ payload_size = len(data) - self.stats['in_msgs'] += 1 - self.stats['in_bytes'] += payload_size + self.stats["in_msgs"] += 1 + self.stats["in_bytes"] += payload_size sub = self._subs.get(sid) if not sub: @@ -1780,17 +1825,17 @@ async def _process_msg( try: sub._pending_size += payload_size # allow setting pending_bytes_limit to 0 to disable - if sub._pending_bytes_limit > 0 and sub._pending_size >= sub._pending_bytes_limit: + if ( + sub._pending_bytes_limit > 0 + and sub._pending_size >= sub._pending_bytes_limit + ): # Subtract the bytes since the message will be thrown away # so it would not be pending data. sub._pending_size -= payload_size await self._error_cb( errors.SlowConsumerError( - subject=msg.subject, - reply=msg.reply, - sid=sid, - sub=sub + subject=msg.subject, reply=msg.reply, sid=sid, sub=sub ) ) return @@ -1859,23 +1904,26 @@ def _process_info( with latest updates from cluster to enable server discovery. """ assert self._current_server, "Client.connect must be called first" - if 'connect_urls' in info: - if info['connect_urls']: + if "connect_urls" in info: + if info["connect_urls"]: connect_urls = [] - for connect_url in info['connect_urls']: - scheme = '' - if self._current_server.uri.scheme == 'tls': - scheme = 'tls' + for connect_url in info["connect_urls"]: + scheme = "" + if self._current_server.uri.scheme == "tls": + scheme = "tls" else: - scheme = 'nats' + scheme = "nats" uri = urlparse(f"{scheme}://{connect_url}") srv = Srv(uri) srv.discovered = True # Check whether we should reuse the original hostname. - if 'tls_required' in self._server_info and self._server_info['tls_required'] \ - and self._host_is_ip(uri.hostname): + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._host_is_ip(uri.hostname) + ): srv.tls_name = self._current_server.uri.hostname # Filter for any similar server in the server pool already. @@ -1891,7 +1939,11 @@ def _process_info( for srv in connect_urls: self._server_pool.append(srv) - if not initial_connection and connect_urls and self._discovered_server_cb: + if ( + not initial_connection + and connect_urls + and self._discovered_server_cb + ): self._discovered_server_cb() def _host_is_ip(self, connect_url: Optional[str]) -> bool: @@ -1922,13 +1974,13 @@ async def _process_connect_init(self) -> None: else: hostname = self._current_server.uri.hostname - handshake_first = self.options['tls_handshake_first'] + handshake_first = self.options["tls_handshake_first"] if handshake_first: await self._transport.connect_tls( hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) connection_completed = self._transport.readline() @@ -1955,17 +2007,20 @@ async def _process_connect_init(self) -> None: self._process_info(srv_info, initial_connection=True) - if 'version' in self._server_info: - self._current_server.server_version = self._server_info['version'] + if "version" in self._server_info: + self._current_server.server_version = self._server_info["version"] - if 'max_payload' in self._server_info: + if "max_payload" in self._server_info: self._max_payload = self._server_info["max_payload"] - if 'client_id' in self._server_info: + if "client_id" in self._server_info: self._client_id = self._server_info["client_id"] - if 'tls_required' in self._server_info and self._server_info[ - 'tls_required'] and self._current_server.uri.scheme != "ws": + if ( + "tls_required" in self._server_info + and self._server_info["tls_required"] + and self._current_server.uri.scheme != "ws" + ): if not handshake_first: await self._transport.drain() # just in case something is left @@ -1974,7 +2029,7 @@ async def _process_connect_init(self) -> None: hostname, self.ssl_context, DEFAULT_BUFFER_SIZE, - self.options['connect_timeout'], + self.options["connect_timeout"], ) # Refresh state of parser upon reconnect. @@ -1987,9 +2042,7 @@ async def _process_connect_init(self) -> None: await self._transport.drain() if self.options["verbose"]: future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if OK_OP in next_op: # Do nothing pass @@ -2000,15 +2053,13 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for errors.AuthorizationError for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) self._transport.write(PING_PROTO) await self._transport.drain() future = self._transport.readline() - next_op = await asyncio.wait_for( - future, self.options["connect_timeout"] - ) + next_op = await asyncio.wait_for(future, self.options["connect_timeout"]) if PONG_PROTO in next_op: self._status = Client.CONNECTED @@ -2019,14 +2070,12 @@ async def _process_connect_init(self) -> None: # FIXME: Maybe handling could be more special here, # checking for ErrAuthorization for example. # await self._process_err(err_msg) - raise errors.Error("nats: " + err_msg.rstrip('\r\n')) + raise errors.Error("nats: " + err_msg.rstrip("\r\n")) if PONG_PROTO in next_op: self._status = Client.CONNECTED - self._reading_task = asyncio.get_running_loop().create_task( - self._read_loop() - ) + self._reading_task = asyncio.get_running_loop().create_task(self._read_loop()) self._pongs = [] self._pings_outstanding = 0 self._ping_interval_task = asyncio.get_running_loop().create_task( @@ -2034,13 +2083,9 @@ async def _process_connect_init(self) -> None: ) # Task for kicking the flusher queue - self._flusher_task = asyncio.get_running_loop().create_task( - self._flusher() - ) + self._flusher_task = asyncio.get_running_loop().create_task(self._flusher()) - async def _send_ping( - self, future: Optional[asyncio.Future] = None - ) -> None: + async def _send_ping(self, future: Optional[asyncio.Future] = None) -> None: assert self._transport, "Client.connect must be called first" if future is None: future = asyncio.Future() @@ -2085,8 +2130,7 @@ async def _ping_interval(self) -> None: continue try: self._pings_outstanding += 1 - if self._pings_outstanding > self.options[ - "max_outstanding_pings"]: + if self._pings_outstanding > self.options["max_outstanding_pings"]: await self._process_op_err(ErrStaleConnection()) return await self._send_ping() @@ -2124,7 +2168,7 @@ async def _read_loop(self) -> None: except asyncio.CancelledError: break except Exception as ex: - _logger.error('nats: encountered error', exc_info=ex) + _logger.error("nats: encountered error", exc_info=ex) break # except asyncio.InvalidStateError: # pass diff --git a/tests/test_client.py b/tests/test_client.py index 5548aa07..5dce86b0 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -867,6 +867,109 @@ async def worker_handler(msg): await nc.close() + @async_test + async def test_requests_many_max_msgs(self): + nc = await nats.connect() + + async def handler(msg): + for _ in range(0, 10): + await msg.respond(b'OK') + + await nc.subscribe("foo", cb=handler) + + msgs = [] + async for msg in await nc.request_many("foo", b'', max_msgs=5, + max_wait=1.0): + msgs.append(msg) + + assert len(msgs) == 5 + assert all(msg.data == b'OK' for msg in msgs) + await nc.close() + + @async_test + async def test_requests_many_max_wait(self): + nc = await nats.connect() + + async def handler(msg): + await msg.respond(b'OK') + await msg.respond(b'OK') + + await nc.subscribe("foo", cb=handler) + + start_time = time.monotonic() + msgs = [] + async for msg in await nc.request_many("foo", b'', max_wait=0.5): + msgs.append(msg) + end_time = time.monotonic() + + assert 0.5 <= end_time - start_time < 0.6 # Allow small overhead + assert len( + msgs + ) < 5 # Should receive fewer than 5 messages in 0.5 seconds + assert all(msg.data == b'OK' for msg in msgs) + await nc.close() + + @async_test + async def test_requests_many_max_interval(self): + nc = await nats.connect() + responses_sent = 0 + + async def handler(msg): + nonlocal responses_sent + if responses_sent == 0: + await msg.respond(b'OK') + responses_sent += 1 + else: + await asyncio.sleep(0.2) # Delay to exceed max_interval + + await nc.subscribe("foo", cb=handler) + + msgs = [] + async for msg in await nc.request_many("foo", b'', max_interval=0.1, + max_wait=1.0): + msgs.append(msg) + + assert len( + msgs + ) == 1 # Should only receive one message before max_interval is exceeded + assert msgs[0].data == b'OK' + await nc.close() + + @async_test + async def test_requests_many_no_responses(self): + nc = await nats.connect() + + async def handler(msg): + pass # No response + + await nc.subscribe("foo", cb=handler) + + msgs = [] + async for msg in await nc.request_many("foo", b'', max_wait=0.5): + msgs.append(msg) + + assert len(msgs) == 0 # Should receive no messages + await nc.close() + + @async_test + async def test_requests_many_unsubscribe_during_iteration(self): + nc = await nats.connect() + sub = await nc.subscribe("foo") + + async def handler(msg): + await msg.respond(b'OK') + await sub.unsubscribe() + + await nc.subscribe("foo", cb=handler) + + msgs = [] + async for msg in await nc.request_many("foo", b'', max_wait=1.0): + msgs.append(msg) + + assert len(msgs) == 1 + assert msgs[0].data == b'OK' + await nc.close() + @async_test async def test_custom_inbox_prefix(self): nc = NATS()