From a52b9fc8cf8561b135cbc843b67956c9fac90f88 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 7 Mar 2025 15:47:54 -0800 Subject: [PATCH 1/3] Add support for persistent H3 connections. --- dns/asyncquery.py | 21 +++++++++++++++++---- dns/query.py | 24 +++++++++++++++++++----- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 883e8afc..1c1c79d2 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -536,7 +536,7 @@ async def https( source_port: int = 0, # pylint: disable=W0613 one_rr_per_rrset: bool = False, ignore_trailing: bool = False, - client: Optional["httpx.AsyncClient"] = None, + client: Optional["httpx.AsyncClient|dns.quic.AsyncQuicConnection"] = None, path: str = "/dns-query", post: bool = True, verify: Union[bool, str] = True, @@ -591,6 +591,8 @@ async def https( parsed.hostname, family # pyright: ignore ) bootstrap_address = random.choice(list(answers.addresses())) + if client and not isinstance(client, dns.quic.AsyncQuicConnection): # pyright: ignore + raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.") return await _http3( q, bootstrap_address, @@ -603,13 +605,14 @@ async def https( ignore_trailing, verify=verify, post=post, + connection=client, ) if not have_doh: raise NoDOH # pragma: no cover # pylint: disable=possibly-used-before-assignment if client and not isinstance(client, httpx.AsyncClient): # pyright: ignore - raise ValueError("session parameter must be an httpx.AsyncClient") + raise ValueError("client parameter must be an httpx.AsyncClient") # pylint: enable=possibly-used-before-assignment wire = q.to_wire() @@ -711,6 +714,7 @@ async def _http3( backend: Optional[dns.asyncbackend.Backend] = None, hostname: Optional[str] = None, post: bool = True, + connection: Optional[dns.quic.AsyncQuicConnection] = None, ) -> dns.message.Message: if not dns.quic.have_quic: raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover @@ -722,13 +726,22 @@ async def _http3( q.id = 0 wire = q.to_wire() - (cfactory, mfactory) = dns.quic.factories_for_backend(backend) + the_connection: dns.quic.AsyncQuicConnection + if connection: + cfactory = dns.quic.null_factory + mfactory = dns.quic.null_factory + the_connection = connection + else: + (cfactory, mfactory) = dns.quic.factories_for_backend(backend) async with cfactory() as context: async with mfactory( context, verify_mode=verify, server_name=hostname, h3=True ) as the_manager: - the_connection = the_manager.connect(where, port, source, source_port) + if not connection: + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) stream = await the_connection.make_stream(timeout) async with stream: diff --git a/dns/query.py b/dns/query.py index b7ebe1ec..661884fe 100644 --- a/dns/query.py +++ b/dns/query.py @@ -491,6 +491,8 @@ def https( assert parsed.hostname is not None # pyright: ignore answers = resolver.resolve_name(parsed.hostname, family) # pyright: ignore bootstrap_address = random.choice(list(answers.addresses())) + if session and not isinstance(session, dns.quic.SyncQuicConnection): # pyright: ignore + raise ValueError("session parameter must be a dns.quic.SyncQuicConnection.") return _http3( q, bootstrap_address, @@ -503,6 +505,7 @@ def https( ignore_trailing, verify=verify, post=post, + connection=session, ) if not have_doh: @@ -629,6 +632,7 @@ def _http3( verify: Union[bool, str] = True, hostname: Optional[str] = None, post: bool = True, + connection: Optional[dns.quic.SyncQuicConnection] = None, ) -> dns.message.Message: if not dns.quic.have_quic: raise NoDOH("DNS-over-HTTP3 is not available.") # pragma: no cover @@ -640,14 +644,24 @@ def _http3( q.id = 0 wire = q.to_wire() - manager = dns.quic.SyncQuicManager( - verify_mode=verify, server_name=hostname, h3=True # pyright: ignore - ) + the_connection: dns.quic.SyncQuicConnection + the_manager: dns.quic.SyncQuicManager + if connection: + manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) + the_connection = connection + else: + manager = dns.quic.SyncQuicManager( + verify_mode=verify, server_name=hostname, h3=True # pyright: ignore + ) + the_manager = manager # for type checking happiness with manager: - connection = manager.connect(where, port, source, source_port) + if not connection: + the_connection = the_manager.connect( # pyright: ignore + where, port, source, source_port + ) (start, expiration) = _compute_times(timeout) - with connection.make_stream(timeout) as stream: + with the_connection.make_stream(timeout) as stream: stream.send_h3(url, wire, post) wire = stream.receive(_remaining(expiration)) _check_status(stream.headers(), where, wire) From 45c2e3fdcb7ff69f019218c71ef65c5bc98c0d52 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Fri, 7 Mar 2025 15:52:35 -0800 Subject: [PATCH 2/3] Make mypy happy. --- dns/asyncquery.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 1c1c79d2..24dedfa3 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -593,6 +593,7 @@ async def https( bootstrap_address = random.choice(list(answers.addresses())) if client and not isinstance(client, dns.quic.AsyncQuicConnection): # pyright: ignore raise ValueError("client parameter must be a dns.quic.AsyncQuicConnection.") + assert client is None or isinstance(client, dns.quic.AsyncQuicConnection) return await _http3( q, bootstrap_address, From 4b42667d6f8b173599489de3f6debea1df2abcb7 Mon Sep 17 00:00:00 2001 From: Brian Wellington Date: Mon, 17 Mar 2025 10:29:44 -0700 Subject: [PATCH 3/3] Make pyright happy. --- dns/asyncquery.py | 7 ++++--- dns/query.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 24dedfa3..03e384ad 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -731,7 +731,6 @@ async def _http3( if connection: cfactory = dns.quic.null_factory mfactory = dns.quic.null_factory - the_connection = connection else: (cfactory, mfactory) = dns.quic.factories_for_backend(backend) @@ -739,12 +738,14 @@ async def _http3( async with mfactory( context, verify_mode=verify, server_name=hostname, h3=True ) as the_manager: - if not connection: + if connection: + the_connection = connection + else: the_connection = the_manager.connect( # pyright: ignore where, port, source, source_port ) (start, expiration) = _compute_times(timeout) - stream = await the_connection.make_stream(timeout) + stream = await the_connection.make_stream(timeout) # pyright: ignore async with stream: # note that send_h3() does not need await stream.send_h3(url, wire, post) diff --git a/dns/query.py b/dns/query.py index 661884fe..b81ffd18 100644 --- a/dns/query.py +++ b/dns/query.py @@ -648,7 +648,6 @@ def _http3( the_manager: dns.quic.SyncQuicManager if connection: manager: contextlib.AbstractContextManager = contextlib.nullcontext(None) - the_connection = connection else: manager = dns.quic.SyncQuicManager( verify_mode=verify, server_name=hostname, h3=True # pyright: ignore @@ -656,12 +655,14 @@ def _http3( the_manager = manager # for type checking happiness with manager: - if not connection: + if connection: + the_connection = connection + else: the_connection = the_manager.connect( # pyright: ignore where, port, source, source_port ) (start, expiration) = _compute_times(timeout) - with the_connection.make_stream(timeout) as stream: + with the_connection.make_stream(timeout) as stream: # pyright: ignore stream.send_h3(url, wire, post) wire = stream.receive(_remaining(expiration)) _check_status(stream.headers(), where, wire)