diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 883e8afc..03e384ad 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -536,7 +536,7 @@ async def https( source_port: int = 0, # pylint: disable=W0613 one_rr_per_rrset: bool = False, ignore_trailing: bool = False, - client: Optional["httpx.AsyncClient"] = None, + client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None, path: str = "/dns-query", post: bool = True, verify: Union[bool, str] = True, @@ -591,6 +591,9 @@ async def https( parsed.hostname, family # pyright: ignore ) bootstrap_address = random.choice(list(answers.addresses())) + if client and not isinstance(client, dns.quic.AsyncQuicConnection): # pyright: ignore + raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.") + assert client is None or isinstance(client, dns.quic.AsyncQuicConnection) return await _http3( q, bootstrap_address, @@ -603,13 +606,14 @@ async def https( ignore_trailing, verify=verify, post=post, + connection=client, ) if not have_doh: raise NoDOH # pragma: no cover # pylint: disable=possibly-used-before-assignment if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore - raise ValueError("session parameter must be an httpx.AsyncClient") + raise ValueError("client parameter must be an httpx.AsyncClient") # pylint: enable=possibly-used-before-assignment wire = q.to_wire() @@ -711,6 +715,7 @@ async def _http3( backend: Optional[dns.asyncbackend.Backend] = None, hostname: Optional[str] = None, post: bool = True, + connection: Optional[dns.quic.AsyncQuicConnection] = None, ) -> dns.message.Message: if not dns.quic.have_quic: raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover @@ -722,15 +727,25 @@ async def _http3( q.id = 0 wire = q.to_wire() - (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) async with cfactory() as context: async with mfactory( context, verify_mode=verify, server_name=hostname, h3=True ) as the_manager: - the_connection = the_manager.connect(where, port, source, source_port) + if connection: + the_connection = connection + else: + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) - stream = await the_connection.make_stream(timeout) + stream = await the_connection.make_stream(timeout) # pyright: ignore async with stream: # note that send_h3() does not need await stream.send_h3(url, wire, post) diff --git a/dns/query.py b/dns/query.py index b7ebe1ec..b81ffd18 100644 --- a/dns/query.py +++ b/dns/query.py @@ -491,6 +491,8 @@ def https( assert parsed.hostname is not None # pyright: ignore answers = resolver.resolve_name(parsed.hostname, family) # pyright: ignore bootstrap_address = random.choice(list(answers.addresses())) + if session and not isinstance(session, dns.quic.SyncQuicConnection): # pyright: ignore + raise ValueError("session parameter must be a dns.quic.SyncQuicConnection.") return _http3( q, bootstrap_address, @@ -503,6 +505,7 @@ def https( ignore_trailing, verify=verify, post=post, + connection=session, ) if not have_doh: @@ -629,6 +632,7 @@ def _http3( verify: Union[bool, str] = True, hostname: Optional[str] = None, post: bool = True, + connection: Optional[dns.quic.SyncQuicConnection] = None, ) -> dns.message.Message: if not dns.quic.have_quic: raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover @@ -640,14 +644,25 @@ def _http3( q.id = 0 wire = q.to_wire() - manager = dns.quic.SyncQuicManager( - verify_mode=verify, server_name=hostname, h3=True # pyright: ignore - ) + the_connection: dns.quic.SyncQuicConnection + the_manager: dns.quic.SyncQuicManager + if connection: + manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) + else: + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=hostname, h3=True # pyright: ignore + ) + the_manager = manager # for type checking happiness with manager: - connection = manager.connect(where, port, source, source_port) + if connection: + the_connection = connection + else: + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) - with connection.make_stream(timeout) as stream: + with the_connection.make_stream(timeout) as stream: # pyright: ignore stream.send_h3(url, wire, post) wire = stream.receive(_remaining(expiration)) _check_status(stream.headers(), where, wire)