From 0f41f2e281aa994858ba98245aea98af0b28c037 Mon Sep 17 00:00:00 2001 From: ljwoods2 Date: Wed, 27 Nov 2024 14:28:17 -0700 Subject: [PATCH 1/3] better error messages, context manager --- imdclient/IMDClient.py | 83 +++++++++++++++++++++++-------- imdclient/IMDREADER.py | 5 +- imdclient/tests/test_imdclient.py | 51 +++++++++++++++++++ imdclient/tests/test_imdreader.py | 20 ++++++-- 4 files changed, 132 insertions(+), 27 deletions(-) diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index 359f9d2..3e4fb6a 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -43,6 +43,8 @@ class IMDClient: Size of the socket buffer in bytes. Default is to use the system default buffer_size : int (optional) IMDFramebuffer will be filled with as many :class:`IMDFrame` fit in `buffer_size` [``10MB``] + timeout : int, optional + Timeout for the socket in seconds [``5``] **kwargs : dict (optional) Additional keyword arguments to pass to the :class:`BaseIMDProducer` and :class:`IMDFrameBuffer` """ @@ -68,8 +70,10 @@ def __init__( n_atoms, **kwargs, ) + self._error_queue = queue.Queue() else: self._buf = None + self._error_queue = None if self._imdsinfo.version == 2: self._producer = IMDProducerV2( self._conn, @@ -77,6 +81,7 @@ def __init__( self._imdsinfo, n_atoms, multithreaded, + self._error_queue, **kwargs, ) elif self._imdsinfo.version == 3: @@ -86,6 +91,7 @@ def __init__( self._imdsinfo, n_atoms, multithreaded, + self._error_queue, **kwargs, ) @@ -103,6 +109,10 @@ def signal_handler(self, sig, frame): def get_imdframe(self): """ + Returns + ------- + IMDFrame + The next frame from the IMD server Raises ------ EOFError @@ -116,6 +126,9 @@ def get_imdframe(self): # and doesn't need to be notified self._disconnect() self._stopped = True + + if self._error_queue.qsize(): + raise EOFError(f"{self._error_queue.get()}") raise EOFError else: try: @@ -125,9 +138,18 @@ def get_imdframe(self): raise EOFError def get_imdsessioninfo(self): + """ + Returns + ------- + IMDSessionInfo + Information about the IMD session + """ return self._imdsinfo def stop(self): + """ + Stop the client and close the connection + """ if self._multithreaded: if not self._stopped: self._buf.notify_consumer_finished() @@ -146,9 +168,7 @@ def _connect_to_server(self, host, port, socket_bufsize): # /proc/sys/net/core/rmem_default # Max (linux): # /proc/sys/net/core/rmem_max - conn.setsockopt( - socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize - ) + conn.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize) try: logger.debug(f"IMDClient: Connecting to {host}:{port}") conn.connect((host, port)) @@ -254,6 +274,13 @@ def _disconnect(self): finally: self._conn.close() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + return False + class BaseIMDProducer(threading.Thread): """ @@ -269,11 +296,14 @@ class BaseIMDProducer(threading.Thread): Information about the IMD session n_atoms : int Number of atoms in the simulation - multithreaded : bool, optional + multithreaded : bool If True, socket interaction will occur in a separate thread & frames will be buffered. Single-threaded, blocking IMDClient - should only be used in testing [[``True``]] - + should only be used in testing + error_queue: queue.Queue + Queue to hold errors produced by the producer thread + timeout : int, optional + Timeout for the socket in seconds [``5``] """ def __init__( @@ -282,7 +312,8 @@ def __init__( buffer, sinfo, n_atoms, - multithreaded=True, + multithreaded, + error_queue, timeout=5, **kwargs, ): @@ -291,6 +322,7 @@ def __init__( self._imdsinfo = sinfo self._paused = False + self.error_queue = error_queue # Timeout for first frame should be longer # than rest of frames self._timeout = timeout @@ -385,6 +417,7 @@ def run(self): logger.debug("IMDProducer: Simulation ended normally, cleaning up") except Exception as e: logger.debug("IMDProducer: An unexpected error occurred: %s", e) + self.error_queue.put(e) finally: logger.debug("IMDProducer: Stopping run loop") # Tell consumer not to expect more frames to be added @@ -400,9 +433,19 @@ def _expect_header(self, expected_type, expected_value=None): ) # Sometimes we do not care what the value is if expected_value is not None and header.length != expected_value: - raise RuntimeError( - f"IMDProducer: Expected header value {expected_value}, got {header.length}" - ) + if expected_type in [ + IMDHeaderType.IMD_FCOORDS, + IMDHeaderType.IMD_VELOCITIES, + IMDHeaderType.IMD_FORCES, + ]: + raise RuntimeError( + f"IMDProducer: Expected n_atoms value {expected_value}, got {header.length}. " + + "Ensure you are using the correct topology file." + ) + else: + raise RuntimeError( + f"IMDProducer: Expected header value {expected_value}, got {header.length}" + ) def _get_header(self): self._read(self._header) @@ -422,9 +465,11 @@ def _read(self, buf): class IMDProducerV2(BaseIMDProducer): - def __init__(self, conn, buffer, sinfo, n_atoms, multithreaded, **kwargs): + def __init__( + self, conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs + ): super(IMDProducerV2, self).__init__( - conn, buffer, sinfo, n_atoms, multithreaded, **kwargs + conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs ) self._energies = bytearray(IMDENERGYPACKETLENGTH) @@ -517,6 +562,7 @@ def __init__( sinfo, n_atoms, multithreaded, + error_queue, **kwargs, ): super(IMDProducerV3, self).__init__( @@ -525,6 +571,7 @@ def __init__( sinfo, n_atoms, multithreaded, + error_queue, **kwargs, ) # The body of an x/v/f packet should contain @@ -682,9 +729,7 @@ def __init__( raise ValueError("pause_empty_proportion must be between 0 and 1") self._pause_empty_proportion = pause_empty_proportion if unpause_empty_proportion < 0 or unpause_empty_proportion > 1: - raise ValueError( - "unpause_empty_proportion must be between 0 and 1" - ) + raise ValueError("unpause_empty_proportion must be between 0 and 1") self._unpause_empty_proportion = unpause_empty_proportion if buffer_size <= 0: @@ -751,9 +796,7 @@ def wait_for_space(self): logger.debug("IMDProducer: Noticing consumer finished") raise EOFError except Exception as e: - logger.debug( - f"IMDProducer: Error waiting for space in buffer: {e}" - ) + logger.debug(f"IMDProducer: Error waiting for space in buffer: {e}") def pop_empty_imdframe(self): logger.debug("IMDProducer: Getting empty frame") @@ -799,9 +842,7 @@ def pop_full_imdframe(self): imdf = self._full_q.get() else: with self._full_imdf_avail: - while ( - self._full_q.qsize() == 0 and not self._producer_finished - ): + while self._full_q.qsize() == 0 and not self._producer_finished: self._full_imdf_avail.wait() if self._producer_finished and self._full_q.qsize() == 0: diff --git a/imdclient/IMDREADER.py b/imdclient/IMDREADER.py index 96073d5..e9d7062 100644 --- a/imdclient/IMDREADER.py +++ b/imdclient/IMDREADER.py @@ -83,9 +83,8 @@ def _read_frame(self, frame): try: imdf = self._imdclient.get_imdframe() - except EOFError: - # Not strictly necessary, but for clarity - raise StopIteration + except EOFError as e: + raise e self._frame = frame self._load_imdframe_into_ts(imdf) diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 3880665..c8fe9b2 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -148,3 +148,54 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf): client.get_imdframe() # server should receive disconnect from client (though it doesn't have to do anything) server.expect_packet(IMDHeaderType.IMD_DISCONNECT) + + +class TestIMDClientV3ContextManager: + @pytest.fixture + def port(self): + return get_free_port() + + @pytest.fixture + def universe(self): + return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD) + + @pytest.fixture + def imdsinfo(self): + return create_default_imdsinfo_v3() + + @pytest.fixture + def server(self, universe, imdsinfo, port): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + yield server + server.cleanup() + + def test_context_manager_traj_unchanged(self, server, port, universe): + server.handshake_sequence("localhost", port, first_frame=False) + + i = 0 + with IMDClient( + "localhost", + port, + universe.trajectory.n_atoms, + ) as client: + server.send_frames(0, 5) + while i < 5: + + imdf = client.get_imdframe() + assert_allclose(universe.trajectory[i].time, imdf.time) + assert_allclose(universe.trajectory[i].dt, imdf.dt) + assert_allclose(universe.trajectory[i].data["step"], imdf.step) + assert_allclose( + universe.trajectory[i].positions, imdf.positions + ) + assert_allclose( + universe.trajectory[i].velocities, imdf.velocities + ) + assert_allclose(universe.trajectory[i].forces, imdf.forces) + assert_allclose( + universe.trajectory[i].triclinic_dimensions, imdf.box + ) + i += 1 + server.expect_packet(IMDHeaderType.IMD_DISCONNECT) + assert i == 5 diff --git a/imdclient/tests/test_imdreader.py b/imdclient/tests/test_imdreader.py index a7a7f4c..bfe4a34 100644 --- a/imdclient/tests/test_imdreader.py +++ b/imdclient/tests/test_imdreader.py @@ -171,9 +171,7 @@ def test_total_time(self, reader, ref): decimal=ref.prec, ) - @pytest.mark.skip( - reason="Stream-based reader can only be read iteratively" - ) + @pytest.mark.skip(reason="Stream-based reader can only be read iteratively") def test_changing_dimensions(self, ref, reader): if ref.changing_dimensions: reader.rewind() @@ -642,3 +640,19 @@ def test_subslice_fi_all_after_iteration_raises_error(self, reader): with pytest.raises(RuntimeError): for ts in sub_sliced_reader: pass + + +def test_n_atoms_mismatch(): + universe = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD) + port = get_free_port() + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(create_default_imdsinfo_v3()) + server.handshake_sequence("localhost", port, first_frame=True) + with pytest.raises( + EOFError, + match="IMDProducer: Expected n_atoms value 6, got 5. Ensure you are using the correct topology file.", + ): + IMDReader( + f"imd://localhost:{port}", + n_atoms=universe.trajectory.n_atoms + 1, + ) From 1d8eab455057fe5cb2123ba3470a546a61641355 Mon Sep 17 00:00:00 2001 From: ljwoods2 Date: Thu, 28 Nov 2024 13:21:40 -0700 Subject: [PATCH 2/3] jupyter signal intercept --- imdclient/IMDClient.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index 3e4fb6a..fa633cd 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -98,14 +98,42 @@ def __init__( self._go() if self._multithreaded: + # Disconnect MUST occur. This covers typical cases signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + + # Disconnect MUST occur. This covers Jupyter + Ipython cases + # since in jupyter, the signal handler is reset to the default + # by pre- and post- hooks + # https://stackoverflow.com/questions/70841648/jupyter-reverts-signal-handler-to-default-when-running-next-cell + try: + import IPython + except ImportError: + has_ipython = False + else: + has_ipython = True + + if has_ipython: + try: + from IPython import get_ipython + + if get_ipython() is not None: + kernel = get_ipython().kernel + kernel.pre_handler_hook = lambda: None + kernel.post_handler_hook = lambda: None + logger.debug("Running in Jupyter") + except NameError: + logger.debug("Running in non-jupyter IPython environment") + self._producer.start() def signal_handler(self, sig, frame): """Catch SIGINT to allow clean shutdown on CTRL+C This also ensures that main thread execution doesn't get stuck waiting in buf.pop_full_imdframe()""" + logger.debug("Intercepted signal") self.stop() + logger.debug("Shutdown success") def get_imdframe(self): """ @@ -152,9 +180,9 @@ def stop(self): """ if self._multithreaded: if not self._stopped: + self._stopped = True self._buf.notify_consumer_finished() self._disconnect() - self._stopped = True else: self._disconnect() From 746d94c742d72172fe40eaa0e0084ce1303d9467 Mon Sep 17 00:00:00 2001 From: ljwoods2 Date: Thu, 28 Nov 2024 22:23:09 -0700 Subject: [PATCH 3/3] atexit --- imdclient/IMDClient.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index c8bbd77..f158e2a 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -25,6 +25,7 @@ import numpy as np from typing import Union, Dict import signal +import atexit logger = logging.getLogger(__name__) @@ -98,11 +99,11 @@ def __init__( self._go() if self._multithreaded: - # Disconnect MUST occur. This covers typical cases + # Disconnect MUST occur. This covers typical cases (Python, IPython interpreter) signal.signal(signal.SIGINT, self.signal_handler) signal.signal(signal.SIGTERM, self.signal_handler) - # Disconnect MUST occur. This covers Jupyter + Ipython cases + # Disconnect and socket shutdown MUST occur. This covers Jupyter use # since in jupyter, the signal handler is reset to the default # by pre- and post- hooks # https://stackoverflow.com/questions/70841648/jupyter-reverts-signal-handler-to-default-when-running-next-cell @@ -125,6 +126,10 @@ def __init__( except NameError: logger.debug("Running in non-jupyter IPython environment") + # Final case: error is raised outside of IMDClient code + logger.debug("Registering atexit") + atexit.register(self.stop) + self._producer.start() def signal_handler(self, sig, frame): @@ -181,8 +186,8 @@ def stop(self): if self._multithreaded: if not self._stopped: self._stopped = True - self._buf.notify_consumer_finished() self._disconnect() + self._buf.notify_consumer_finished() else: self._disconnect()