diff --git a/tests/test_socket.py b/tests/test_socket.py index dedc722d3..117caa07c 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -560,6 +560,7 @@ def test_shadow(self): assert s2._shadow_obj is s assert s.underlying != p.underlying assert s2.underlying == s.underlying + assert s2.context is s.context s3 = zmq.Socket(s) assert s3._shadow_obj is s assert s3.underlying == s.underlying diff --git a/zmq/sugar/socket.py b/zmq/sugar/socket.py index 0f10f3d7b..2d8ee018a 100644 --- a/zmq/sugar/socket.py +++ b/zmq/sugar/socket.py @@ -133,6 +133,7 @@ def __init__( shadow: Socket | int = 0, copy_threshold: int | None = None, ): + shadow_context: zmq.Context | None = None if isinstance(ctx_or_socket, zmq.Socket): # positional Socket(other_socket) shadow = ctx_or_socket @@ -145,6 +146,8 @@ def __init__( # hold a reference to the shadow object self._shadow_obj = shadow if not isinstance(shadow, int): + if isinstance(shadow, zmq.Socket): + shadow_context = shadow.context try: shadow = cast(int, shadow.underlying) except AttributeError: @@ -159,6 +162,9 @@ def __init__( shadow=shadow_address, copy_threshold=copy_threshold, ) + if self._shadow_obj and shadow_context: + # keep self.context reference if shadowing a Socket object + self.context = shadow_context try: socket_type = cast(int, self.get(zmq.TYPE))