diff --git a/imdclient/IMD.py b/imdclient/IMD.py index 96073d5..e9d7062 100644 --- a/imdclient/IMD.py +++ b/imdclient/IMD.py @@ -83,9 +83,8 @@ def _read_frame(self, frame): try: imdf = self._imdclient.get_imdframe() - except EOFError: - # Not strictly necessary, but for clarity - raise StopIteration + except EOFError as e: + raise e self._frame = frame self._load_imdframe_into_ts(imdf) diff --git a/imdclient/IMDClient.py b/imdclient/IMDClient.py index afe3552..f158e2a 100644 --- a/imdclient/IMDClient.py +++ b/imdclient/IMDClient.py @@ -25,6 +25,7 @@ import numpy as np from typing import Union, Dict import signal +import atexit logger = logging.getLogger(__name__) @@ -43,6 +44,8 @@ class IMDClient: Size of the socket buffer in bytes. Default is to use the system default buffer_size : int (optional) IMDFramebuffer will be filled with as many :class:`IMDFrame` fit in `buffer_size` bytes [``10MB``] + timeout : int, optional + Timeout for the socket in seconds [``5``] **kwargs : dict (optional) Additional keyword arguments to pass to the :class:`BaseIMDProducer` and :class:`IMDFrameBuffer` """ @@ -68,8 +71,10 @@ def __init__( n_atoms, **kwargs, ) + self._error_queue = queue.Queue() else: self._buf = None + self._error_queue = None if self._imdsinfo.version == 2: self._producer = IMDProducerV2( self._conn, @@ -77,6 +82,7 @@ def __init__( self._imdsinfo, n_atoms, multithreaded, + self._error_queue, **kwargs, ) elif self._imdsinfo.version == 3: @@ -86,23 +92,60 @@ def __init__( self._imdsinfo, n_atoms, multithreaded, + self._error_queue, **kwargs, ) self._go() if self._multithreaded: + # Disconnect MUST occur. This covers typical cases (Python, IPython interpreter) signal.signal(signal.SIGINT, self.signal_handler) + signal.signal(signal.SIGTERM, self.signal_handler) + + # Disconnect and socket shutdown MUST occur. This covers Jupyter use + # since in jupyter, the signal handler is reset to the default + # by pre- and post- hooks + # https://stackoverflow.com/questions/70841648/jupyter-reverts-signal-handler-to-default-when-running-next-cell + try: + import IPython + except ImportError: + has_ipython = False + else: + has_ipython = True + + if has_ipython: + try: + from IPython import get_ipython + + if get_ipython() is not None: + kernel = get_ipython().kernel + kernel.pre_handler_hook = lambda: None + kernel.post_handler_hook = lambda: None + logger.debug("Running in Jupyter") + except NameError: + logger.debug("Running in non-jupyter IPython environment") + + # Final case: error is raised outside of IMDClient code + logger.debug("Registering atexit") + atexit.register(self.stop) + self._producer.start() def signal_handler(self, sig, frame): """Catch SIGINT to allow clean shutdown on CTRL+C This also ensures that main thread execution doesn't get stuck waiting in buf.pop_full_imdframe()""" + logger.debug("Intercepted signal") self.stop() + logger.debug("Shutdown success") def get_imdframe(self): """ + Returns + ------- + IMDFrame + The next frame from the IMD server Raises ------ EOFError @@ -116,6 +159,9 @@ def get_imdframe(self): # and doesn't need to be notified self._disconnect() self._stopped = True + + if self._error_queue.qsize(): + raise EOFError(f"{self._error_queue.get()}") raise EOFError else: try: @@ -125,14 +171,23 @@ def get_imdframe(self): raise EOFError def get_imdsessioninfo(self): + """ + Returns + ------- + IMDSessionInfo + Information about the IMD session + """ return self._imdsinfo def stop(self): + """ + Stop the client and close the connection + """ if self._multithreaded: if not self._stopped: - self._buf.notify_consumer_finished() - self._disconnect() self._stopped = True + self._disconnect() + self._buf.notify_consumer_finished() else: self._disconnect() @@ -252,6 +307,13 @@ def _disconnect(self): finally: self._conn.close() + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + return False + class BaseIMDProducer(threading.Thread): """ @@ -267,11 +329,14 @@ class BaseIMDProducer(threading.Thread): Information about the IMD session n_atoms : int Number of atoms in the simulation - multithreaded : bool, optional + multithreaded : bool If True, socket interaction will occur in a separate thread & frames will be buffered. Single-threaded, blocking IMDClient - should only be used in testing [[``True``]] - + should only be used in testing + error_queue: queue.Queue + Queue to hold errors produced by the producer thread + timeout : int, optional + Timeout for the socket in seconds [``5``] """ def __init__( @@ -280,7 +345,8 @@ def __init__( buffer, sinfo, n_atoms, - multithreaded=True, + multithreaded, + error_queue, timeout=5, **kwargs, ): @@ -289,6 +355,7 @@ def __init__( self._imdsinfo = sinfo self._paused = False + self.error_queue = error_queue # Timeout for first frame should be longer # than rest of frames self._timeout = timeout @@ -383,6 +450,7 @@ def run(self): logger.debug("IMDProducer: Simulation ended normally, cleaning up") except Exception as e: logger.debug("IMDProducer: An unexpected error occurred: %s", e) + self.error_queue.put(e) finally: logger.debug("IMDProducer: Stopping run loop") # Tell consumer not to expect more frames to be added @@ -398,9 +466,19 @@ def _expect_header(self, expected_type, expected_value=None): ) # Sometimes we do not care what the value is if expected_value is not None and header.length != expected_value: - raise RuntimeError( - f"IMDProducer: Expected header value {expected_value}, got {header.length}" - ) + if expected_type in [ + IMDHeaderType.IMD_FCOORDS, + IMDHeaderType.IMD_VELOCITIES, + IMDHeaderType.IMD_FORCES, + ]: + raise RuntimeError( + f"IMDProducer: Expected n_atoms value {expected_value}, got {header.length}. " + + "Ensure you are using the correct topology file." + ) + else: + raise RuntimeError( + f"IMDProducer: Expected header value {expected_value}, got {header.length}" + ) def _get_header(self): self._read(self._header) @@ -420,9 +498,11 @@ def _read(self, buf): class IMDProducerV2(BaseIMDProducer): - def __init__(self, conn, buffer, sinfo, n_atoms, multithreaded, **kwargs): + def __init__( + self, conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs + ): super(IMDProducerV2, self).__init__( - conn, buffer, sinfo, n_atoms, multithreaded, **kwargs + conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs ) self._energies = bytearray(IMDENERGYPACKETLENGTH) @@ -515,6 +595,7 @@ def __init__( sinfo, n_atoms, multithreaded, + error_queue, **kwargs, ): super(IMDProducerV3, self).__init__( @@ -523,6 +604,7 @@ def __init__( sinfo, n_atoms, multithreaded, + error_queue, **kwargs, ) # The body of an x/v/f packet should contain diff --git a/imdclient/tests/test_imdclient.py b/imdclient/tests/test_imdclient.py index 3880665..c8fe9b2 100644 --- a/imdclient/tests/test_imdclient.py +++ b/imdclient/tests/test_imdclient.py @@ -148,3 +148,54 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf): client.get_imdframe() # server should receive disconnect from client (though it doesn't have to do anything) server.expect_packet(IMDHeaderType.IMD_DISCONNECT) + + +class TestIMDClientV3ContextManager: + @pytest.fixture + def port(self): + return get_free_port() + + @pytest.fixture + def universe(self): + return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD) + + @pytest.fixture + def imdsinfo(self): + return create_default_imdsinfo_v3() + + @pytest.fixture + def server(self, universe, imdsinfo, port): + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(imdsinfo) + yield server + server.cleanup() + + def test_context_manager_traj_unchanged(self, server, port, universe): + server.handshake_sequence("localhost", port, first_frame=False) + + i = 0 + with IMDClient( + "localhost", + port, + universe.trajectory.n_atoms, + ) as client: + server.send_frames(0, 5) + while i < 5: + + imdf = client.get_imdframe() + assert_allclose(universe.trajectory[i].time, imdf.time) + assert_allclose(universe.trajectory[i].dt, imdf.dt) + assert_allclose(universe.trajectory[i].data["step"], imdf.step) + assert_allclose( + universe.trajectory[i].positions, imdf.positions + ) + assert_allclose( + universe.trajectory[i].velocities, imdf.velocities + ) + assert_allclose(universe.trajectory[i].forces, imdf.forces) + assert_allclose( + universe.trajectory[i].triclinic_dimensions, imdf.box + ) + i += 1 + server.expect_packet(IMDHeaderType.IMD_DISCONNECT) + assert i == 5 diff --git a/imdclient/tests/test_imdreader.py b/imdclient/tests/test_imdreader.py index 18279dd..4312e27 100644 --- a/imdclient/tests/test_imdreader.py +++ b/imdclient/tests/test_imdreader.py @@ -171,9 +171,7 @@ def test_total_time(self, reader, ref): decimal=ref.prec, ) - @pytest.mark.skip( - reason="Stream-based reader can only be read iteratively" - ) + @pytest.mark.skip(reason="Stream-based reader can only be read iteratively") def test_changing_dimensions(self, ref, reader): if ref.changing_dimensions: reader.rewind() @@ -642,3 +640,19 @@ def test_subslice_fi_all_after_iteration_raises_error(self, reader): with pytest.raises(RuntimeError): for ts in sub_sliced_reader: pass + + +def test_n_atoms_mismatch(): + universe = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD) + port = get_free_port() + server = InThreadIMDServer(universe.trajectory) + server.set_imdsessioninfo(create_default_imdsinfo_v3()) + server.handshake_sequence("localhost", port, first_frame=True) + with pytest.raises( + EOFError, + match="IMDProducer: Expected n_atoms value 6, got 5. Ensure you are using the correct topology file.", + ): + IMDReader( + f"imd://localhost:{port}", + n_atoms=universe.trajectory.n_atoms + 1, + )