@@ -71,10 +71,14 @@ def __init__(
7171 h2 .events .ResponseReceived
7272 | h2 .events .DataReceived
7373 | h2 .events .StreamEnded
74- | h2 .events .StreamReset ,
74+ | h2 .events .StreamReset
75+ | h2 .events .TrailersReceived ,
7576 ],
7677 ] = {}
7778
79+ # Mapping from stream ID to trailing headers
80+ self ._trailing_headers : dict [int , list [tuple [bytes , bytes ]]] = {}
81+
7882 # Connection terminated events are stored as state since
7983 # we need to handle them for all streams.
8084 self ._connection_terminated : h2 .events .ConnectionTerminated | None = None
@@ -152,16 +156,24 @@ async def handle_async_request(self, request: Request) -> Response:
152156 )
153157 trace .return_value = (status , headers )
154158
155- return Response (
159+ extensions = {
160+ "http_version" : b"HTTP/2" ,
161+ "network_stream" : self ._network_stream ,
162+ "stream_id" : stream_id ,
163+ }
164+
165+ http2_stream = HTTP2ConnectionByteStream (self , request , stream_id = stream_id )
166+
167+ response = Response (
156168 status = status ,
157169 headers = headers ,
158- content = HTTP2ConnectionByteStream (self , request , stream_id = stream_id ),
159- extensions = {
160- "http_version" : b"HTTP/2" ,
161- "network_stream" : self ._network_stream ,
162- "stream_id" : stream_id ,
163- },
170+ content = http2_stream ,
171+ extensions = extensions ,
164172 )
173+
174+ http2_stream .set_response (response )
175+
176+ return response
165177 except BaseException as exc : # noqa: PIE786
166178 with AsyncShieldCancellation ():
167179 kwargs = {"stream_id" : stream_id }
@@ -321,12 +333,21 @@ async def _receive_response_body(
321333 self ._h2_state .acknowledge_received_data (amount , stream_id )
322334 await self ._write_outgoing_data (request )
323335 yield event .data
336+ elif isinstance (event , h2 .events .TrailersReceived ):
337+ # Process trailing headers but continue receiving events
338+ # The trailing headers are already stored in self._trailing_headers
339+ continue
324340 elif isinstance (event , h2 .events .StreamEnded ):
325341 break
326342
327343 async def _receive_stream_event (
328344 self , request : Request , stream_id : int
329- ) -> h2 .events .ResponseReceived | h2 .events .DataReceived | h2 .events .StreamEnded :
345+ ) -> (
346+ h2 .events .ResponseReceived
347+ | h2 .events .DataReceived
348+ | h2 .events .StreamEnded
349+ | h2 .events .TrailersReceived
350+ ):
330351 """
331352 Return the next available event for a given stream ID.
332353
@@ -377,10 +398,19 @@ async def _receive_events(
377398 h2 .events .DataReceived ,
378399 h2 .events .StreamEnded ,
379400 h2 .events .StreamReset ,
401+ h2 .events .TrailersReceived ,
380402 ),
381403 ):
382404 if event .stream_id in self ._events :
383405 self ._events [event .stream_id ].append (event )
406+ if isinstance (event , h2 .events .TrailersReceived ):
407+ self ._trailing_headers [event .stream_id ] = []
408+ if event .headers is not None :
409+ for k , v in event .headers :
410+ if not k .startswith (b":" ):
411+ self ._trailing_headers [
412+ event .stream_id
413+ ].append ((k , v ))
384414
385415 elif isinstance (event , h2 .events .ConnectionTerminated ):
386416 self ._connection_terminated = event
@@ -409,6 +439,8 @@ async def _receive_remote_settings_change(
409439 async def _response_closed (self , stream_id : int ) -> None :
410440 await self ._max_streams_semaphore .release ()
411441 del self ._events [stream_id ]
442+ if stream_id in self ._trailing_headers :
443+ del self ._trailing_headers [stream_id ]
412444 async with self ._state_lock :
413445 if self ._connection_terminated and not self ._events :
414446 await self .aclose ()
@@ -567,6 +599,10 @@ def __init__(
567599 self ._request = request
568600 self ._stream_id = stream_id
569601 self ._closed = False
602+ self ._response : Response | None = None
603+
604+ def set_response (self , response : Response ) -> None :
605+ self ._response = response
570606
571607 async def __aiter__ (self ) -> typing .AsyncIterator [bytes ]:
572608 kwargs = {"request" : self ._request , "stream_id" : self ._stream_id }
@@ -576,6 +612,14 @@ async def __aiter__(self) -> typing.AsyncIterator[bytes]:
576612 request = self ._request , stream_id = self ._stream_id
577613 ):
578614 yield chunk
615+
616+ if (
617+ self ._response is not None
618+ and self ._stream_id in self ._connection ._trailing_headers
619+ ):
620+ self ._response .extensions ["trailing_headers" ] = (
621+ self ._connection ._trailing_headers [self ._stream_id ]
622+ )
579623 except BaseException as exc :
580624 # If we get an exception while streaming the response,
581625 # we want to close the response (and possibly the connection)
0 commit comments