Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add ConnectOptions dataclass #610

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 129 additions & 48 deletions nats/aio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@

import asyncio
import base64
import inspect
import ipaddress
import json
import logging
import ssl
import string
import time
from collections import UserString
from dataclasses import dataclass
from dataclasses import dataclass, field, replace, fields
from email.parser import BytesParser
from io import BytesIO
from pathlib import Path
Expand Down Expand Up @@ -182,7 +183,45 @@ async def _default_error_callback(ex: Exception) -> None:
Provides a default way to handle async errors if the user
does not provide one.
"""
_logger.error("nats: encountered error", exc_info=ex)
_logger.error('nats: encountered error', exc_info=ex)


@dataclass
class ConnectOptions:
servers: Union[str, List[str]] = field(default_factory=lambda: ["nats://localhost:4222"])
error_cb: Optional[ErrorCallback] = None
disconnected_cb: Optional[Callback] = None
closed_cb: Optional[Callback] = None
discovered_server_cb: Optional[Callback] = None
reconnected_cb: Optional[Callback] = None
name: Optional[str] = None
pedantic: bool = False
verbose: bool = False
allow_reconnect: bool = True
connect_timeout: int = DEFAULT_CONNECT_TIMEOUT
reconnect_time_wait: int = DEFAULT_RECONNECT_TIME_WAIT
max_reconnect_attempts: int = DEFAULT_MAX_RECONNECT_ATTEMPTS
ping_interval: int = DEFAULT_PING_INTERVAL
max_outstanding_pings: int = DEFAULT_MAX_OUTSTANDING_PINGS
dont_randomize: bool = False
flusher_queue_size: int = DEFAULT_MAX_FLUSHER_QUEUE_SIZE
no_echo: bool = False
tls: Optional[ssl.SSLContext] = None
tls_hostname: Optional[str] = None
tls_handshake_first: bool = False
user: Optional[str] = None
password: Optional[str] = None
token: Optional[str] = None
drain_timeout: int = DEFAULT_DRAIN_TIMEOUT
signature_cb: Optional[SignatureCallback] = None
user_jwt_cb: Optional[JWTCallback] = None
user_credentials: Optional[Credentials] = None
nkeys_seed: Optional[str] = None
nkeys_seed_str: Optional[str] = None
inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX
pending_size: int = DEFAULT_PENDING_SIZE
flush_timeout: Optional[float] = None



class Client:
Expand Down Expand Up @@ -319,6 +358,7 @@ async def connect(
inbox_prefix: Union[str, bytes] = DEFAULT_INBOX_PREFIX,
pending_size: int = DEFAULT_PENDING_SIZE,
flush_timeout: Optional[float] = None,
config: Optional[ConnectOptions] = None
) -> None:
"""
Establishes a connection to NATS.
Expand Down Expand Up @@ -409,58 +449,84 @@ async def subscribe_handler(msg):

"""

# Get the signature of the connect method
sig = inspect.signature(self.connect)

# Get the default values from the signature
default_values = {
k: v.default
for k, v in sig.parameters.items()
if v.default is not inspect.Parameter.empty
}

# Create a dictionary of the arguments and their values
kwargs = {k: v for k, v in locals().items() if k != "self"}

# Extract the config object from kwargs
config = kwargs.pop("config", None)

# Override only if the value differs from the default
kwargs = {
k: v
for k, v in kwargs.items()
if k in default_values and v != default_values[k]
}

config = self._merge_config(config, **kwargs)

# Set up callbacks
for cb in [
error_cb,
disconnected_cb,
closed_cb,
reconnected_cb,
discovered_server_cb,
config.error_cb,
config.disconnected_cb,
config.closed_cb,
config.reconnected_cb,
config.discovered_server_cb,
]:
if cb and not asyncio.iscoroutinefunction(cb):
raise errors.InvalidCallbackTypeError

self._setup_server_pool(servers)
self._error_cb = error_cb or _default_error_callback
self._closed_cb = closed_cb
self._discovered_server_cb = discovered_server_cb
self._reconnected_cb = reconnected_cb
self._disconnected_cb = disconnected_cb
self._setup_server_pool(config.servers)
self._error_cb = config.error_cb or _default_error_callback
self._closed_cb = config.closed_cb
self._discovered_server_cb = config.discovered_server_cb
self._reconnected_cb = config.reconnected_cb
self._disconnected_cb = config.disconnected_cb

# Custom inbox prefix
if isinstance(inbox_prefix, str):
inbox_prefix = inbox_prefix.encode()
assert isinstance(inbox_prefix, bytes)
self._inbox_prefix = bytearray(inbox_prefix)
if isinstance(config.inbox_prefix, str):
config.inbox_prefix = config.inbox_prefix.encode()
assert isinstance(config.inbox_prefix, bytes)
self._inbox_prefix = bytearray(config.inbox_prefix)

# NKEYS support
self._signature_cb = signature_cb
self._user_jwt_cb = user_jwt_cb
self._user_credentials = user_credentials
self._nkeys_seed = nkeys_seed
self._nkeys_seed_str = nkeys_seed_str
self._signature_cb = config.signature_cb
self._user_jwt_cb = config.user_jwt_cb
self._user_credentials = config.user_credentials
self._nkeys_seed = config.nkeys_seed
self._nkeys_seed_str = config.nkeys_seed_str

# Customizable options
self.options["verbose"] = verbose
self.options["pedantic"] = pedantic
self.options["name"] = name
self.options["allow_reconnect"] = allow_reconnect
self.options["dont_randomize"] = dont_randomize
self.options["reconnect_time_wait"] = reconnect_time_wait
self.options["max_reconnect_attempts"] = max_reconnect_attempts
self.options["ping_interval"] = ping_interval
self.options["max_outstanding_pings"] = max_outstanding_pings
self.options["no_echo"] = no_echo
self.options["user"] = user
self.options["password"] = password
self.options["token"] = token
self.options["connect_timeout"] = connect_timeout
self.options["drain_timeout"] = drain_timeout
self.options["tls_handshake_first"] = tls_handshake_first

if tls:
self.options["tls"] = tls
if tls_hostname:
self.options["tls_hostname"] = tls_hostname
self.options["verbose"] = config.verbose
self.options["pedantic"] = config.pedantic
self.options["name"] = config.name
self.options["allow_reconnect"] = config.allow_reconnect
self.options["dont_randomize"] = config.dont_randomize
self.options["reconnect_time_wait"] = config.reconnect_time_wait
self.options["max_reconnect_attempts"] = config.max_reconnect_attempts
self.options["ping_interval"] = config.ping_interval
self.options["max_outstanding_pings"] = config.max_outstanding_pings
self.options["no_echo"] = config.no_echo
self.options["user"] = config.user
self.options["password"] = config.password
self.options["token"] = config.token
self.options["connect_timeout"] = config.connect_timeout
self.options["drain_timeout"] = config.drain_timeout
self.options["tls_handshake_first"] = config.tls_handshake_first

if config.tls:
self.options["tls"] = config.tls
if config.tls_hostname:
self.options["tls_hostname"] = config.tls_hostname

# Check if the username or password was set in the server URI
server_auth_configured = False
Expand All @@ -469,7 +535,7 @@ async def subscribe_handler(msg):
if server.uri.username or server.uri.password:
server_auth_configured = True
break
if user or password or token or server_auth_configured:
if config.user or config.password or config.token or server_auth_configured:
self._auth_configured = True

if (self._user_credentials is not None or self._nkeys_seed is not None
Expand All @@ -478,13 +544,13 @@ async def subscribe_handler(msg):
self._setup_nkeys_connect()

# Queue used to trigger flushes to the socket.
self._flush_queue = asyncio.Queue(maxsize=flusher_queue_size)
self._flush_queue = asyncio.Queue(maxsize=config.flusher_queue_size)

# Max size of buffer used for flushing commands to the server.
self._max_pending_size = pending_size
self._max_pending_size = config.pending_size

# Max duration for a force flush (happens when a buffer is full).
self._flush_timeout = flush_timeout
self._flush_timeout = config.flush_timeout

if self.options["dont_randomize"] is False:
shuffle(self._server_pool)
Expand Down Expand Up @@ -517,6 +583,21 @@ async def subscribe_handler(msg):
self._current_server.last_attempt = time.monotonic()
self._current_server.reconnects += 1

def _merge_config(
self, config: Optional[ConnectOptions], **kwargs
) -> ConnectOptions:
if not config:
config = ConnectOptions()

defaults = {f.name: f.default for f in fields(ConnectOptions)}

# Override only if the value differs from the default
updated = {
k: v for k, v in kwargs.items() if k in defaults and v != defaults[k]
}

return replace(config, **updated)

def _setup_nkeys_connect(self) -> None:
if self._user_credentials is not None:
self._setup_nkeys_jwt_connect()
Expand Down Expand Up @@ -1265,7 +1346,7 @@ async def _flush_pending(
except asyncio.CancelledError:
pass

def _setup_server_pool(self, connect_url: Union[List[str]]) -> None:
def _setup_server_pool(self, connect_url: Union[str, List[str]]) -> None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Intentional change?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes this throws a typing error -- Union[List[str]] is just List[str], and _setup_server_pool accepts both str and List[str] (via isinstance(...).

It is not exactly related, but is in the connect call and should be fixed... do you think it belongs in a separate PR?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Okay, just checking. I'm doing a lint fix PR so that should take care of it. Can just leave it in for now.

if isinstance(connect_url, str):
try:
if "nats://" in connect_url or "tls://" in connect_url:
Expand Down