Skip to content

Commit

Permalink
Merge pull request #34 from Becksteinlab/context-manager
Browse files Browse the repository at this point in the history
Better error messages, Context manager
  • Loading branch information
ljwoods2 authored Nov 29, 2024
2 parents ba6c8d5 + 746d94c commit d666383
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 17 deletions.
5 changes: 2 additions & 3 deletions imdclient/IMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def _read_frame(self, frame):

try:
imdf = self._imdclient.get_imdframe()
except EOFError:
# Not strictly necessary, but for clarity
raise StopIteration
except EOFError as e:
raise e

self._frame = frame
self._load_imdframe_into_ts(imdf)
Expand Down
104 changes: 93 additions & 11 deletions imdclient/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np
from typing import Union, Dict
import signal
import atexit

logger = logging.getLogger(__name__)

Expand All @@ -43,6 +44,8 @@ class IMDClient:
Size of the socket buffer in bytes. Default is to use the system default
buffer_size : int (optional)
IMDFramebuffer will be filled with as many :class:`IMDFrame` fit in `buffer_size` bytes [``10MB``]
timeout : int, optional
Timeout for the socket in seconds [``5``]
**kwargs : dict (optional)
Additional keyword arguments to pass to the :class:`BaseIMDProducer` and :class:`IMDFrameBuffer`
"""
Expand All @@ -68,15 +71,18 @@ def __init__(
n_atoms,
**kwargs,
)
self._error_queue = queue.Queue()
else:
self._buf = None
self._error_queue = None
if self._imdsinfo.version == 2:
self._producer = IMDProducerV2(
self._conn,
self._buf,
self._imdsinfo,
n_atoms,
multithreaded,
self._error_queue,
**kwargs,
)
elif self._imdsinfo.version == 3:
Expand All @@ -86,23 +92,60 @@ def __init__(
self._imdsinfo,
n_atoms,
multithreaded,
self._error_queue,
**kwargs,
)

self._go()

if self._multithreaded:
# Disconnect MUST occur. This covers typical cases (Python, IPython interpreter)
signal.signal(signal.SIGINT, self.signal_handler)
signal.signal(signal.SIGTERM, self.signal_handler)

# Disconnect and socket shutdown MUST occur. This covers Jupyter use
# since in jupyter, the signal handler is reset to the default
# by pre- and post- hooks
# https://stackoverflow.com/questions/70841648/jupyter-reverts-signal-handler-to-default-when-running-next-cell
try:
import IPython
except ImportError:
has_ipython = False
else:
has_ipython = True

if has_ipython:
try:
from IPython import get_ipython

if get_ipython() is not None:
kernel = get_ipython().kernel
kernel.pre_handler_hook = lambda: None
kernel.post_handler_hook = lambda: None
logger.debug("Running in Jupyter")
except NameError:
logger.debug("Running in non-jupyter IPython environment")

# Final case: error is raised outside of IMDClient code
logger.debug("Registering atexit")
atexit.register(self.stop)

self._producer.start()

def signal_handler(self, sig, frame):
"""Catch SIGINT to allow clean shutdown on CTRL+C
This also ensures that main thread execution doesn't get stuck
waiting in buf.pop_full_imdframe()"""
logger.debug("Intercepted signal")
self.stop()
logger.debug("Shutdown success")

def get_imdframe(self):
"""
Returns
-------
IMDFrame
The next frame from the IMD server
Raises
------
EOFError
Expand All @@ -116,6 +159,9 @@ def get_imdframe(self):
# and doesn't need to be notified
self._disconnect()
self._stopped = True

if self._error_queue.qsize():
raise EOFError(f"{self._error_queue.get()}")
raise EOFError
else:
try:
Expand All @@ -125,14 +171,23 @@ def get_imdframe(self):
raise EOFError

def get_imdsessioninfo(self):
"""
Returns
-------
IMDSessionInfo
Information about the IMD session
"""
return self._imdsinfo

def stop(self):
"""
Stop the client and close the connection
"""
if self._multithreaded:
if not self._stopped:
self._buf.notify_consumer_finished()
self._disconnect()
self._stopped = True
self._disconnect()
self._buf.notify_consumer_finished()
else:
self._disconnect()

Expand Down Expand Up @@ -252,6 +307,13 @@ def _disconnect(self):
finally:
self._conn.close()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
return False


class BaseIMDProducer(threading.Thread):
"""
Expand All @@ -267,11 +329,14 @@ class BaseIMDProducer(threading.Thread):
Information about the IMD session
n_atoms : int
Number of atoms in the simulation
multithreaded : bool, optional
multithreaded : bool
If True, socket interaction will occur in a separate thread &
frames will be buffered. Single-threaded, blocking IMDClient
should only be used in testing [[``True``]]
should only be used in testing
error_queue: queue.Queue
Queue to hold errors produced by the producer thread
timeout : int, optional
Timeout for the socket in seconds [``5``]
"""

def __init__(
Expand All @@ -280,7 +345,8 @@ def __init__(
buffer,
sinfo,
n_atoms,
multithreaded=True,
multithreaded,
error_queue,
timeout=5,
**kwargs,
):
Expand All @@ -289,6 +355,7 @@ def __init__(
self._imdsinfo = sinfo
self._paused = False

self.error_queue = error_queue
# Timeout for first frame should be longer
# than rest of frames
self._timeout = timeout
Expand Down Expand Up @@ -383,6 +450,7 @@ def run(self):
logger.debug("IMDProducer: Simulation ended normally, cleaning up")
except Exception as e:
logger.debug("IMDProducer: An unexpected error occurred: %s", e)
self.error_queue.put(e)
finally:
logger.debug("IMDProducer: Stopping run loop")
# Tell consumer not to expect more frames to be added
Expand All @@ -398,9 +466,19 @@ def _expect_header(self, expected_type, expected_value=None):
)
# Sometimes we do not care what the value is
if expected_value is not None and header.length != expected_value:
raise RuntimeError(
f"IMDProducer: Expected header value {expected_value}, got {header.length}"
)
if expected_type in [
IMDHeaderType.IMD_FCOORDS,
IMDHeaderType.IMD_VELOCITIES,
IMDHeaderType.IMD_FORCES,
]:
raise RuntimeError(
f"IMDProducer: Expected n_atoms value {expected_value}, got {header.length}. "
+ "Ensure you are using the correct topology file."
)
else:
raise RuntimeError(
f"IMDProducer: Expected header value {expected_value}, got {header.length}"
)

def _get_header(self):
self._read(self._header)
Expand All @@ -420,9 +498,11 @@ def _read(self, buf):


class IMDProducerV2(BaseIMDProducer):
def __init__(self, conn, buffer, sinfo, n_atoms, multithreaded, **kwargs):
def __init__(
self, conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
):
super(IMDProducerV2, self).__init__(
conn, buffer, sinfo, n_atoms, multithreaded, **kwargs
conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
)

self._energies = bytearray(IMDENERGYPACKETLENGTH)
Expand Down Expand Up @@ -515,6 +595,7 @@ def __init__(
sinfo,
n_atoms,
multithreaded,
error_queue,
**kwargs,
):
super(IMDProducerV3, self).__init__(
Expand All @@ -523,6 +604,7 @@ def __init__(
sinfo,
n_atoms,
multithreaded,
error_queue,
**kwargs,
)
# The body of an x/v/f packet should contain
Expand Down
51 changes: 51 additions & 0 deletions imdclient/tests/test_imdclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,54 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf):
client.get_imdframe()
# server should receive disconnect from client (though it doesn't have to do anything)
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)


class TestIMDClientV3ContextManager:
@pytest.fixture
def port(self):
return get_free_port()

@pytest.fixture
def universe(self):
return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)

@pytest.fixture
def imdsinfo(self):
return create_default_imdsinfo_v3()

@pytest.fixture
def server(self, universe, imdsinfo, port):
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(imdsinfo)
yield server
server.cleanup()

def test_context_manager_traj_unchanged(self, server, port, universe):
server.handshake_sequence("localhost", port, first_frame=False)

i = 0
with IMDClient(
"localhost",
port,
universe.trajectory.n_atoms,
) as client:
server.send_frames(0, 5)
while i < 5:

imdf = client.get_imdframe()
assert_allclose(universe.trajectory[i].time, imdf.time)
assert_allclose(universe.trajectory[i].dt, imdf.dt)
assert_allclose(universe.trajectory[i].data["step"], imdf.step)
assert_allclose(
universe.trajectory[i].positions, imdf.positions
)
assert_allclose(
universe.trajectory[i].velocities, imdf.velocities
)
assert_allclose(universe.trajectory[i].forces, imdf.forces)
assert_allclose(
universe.trajectory[i].triclinic_dimensions, imdf.box
)
i += 1
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)
assert i == 5
20 changes: 17 additions & 3 deletions imdclient/tests/test_imdreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ def test_total_time(self, reader, ref):
decimal=ref.prec,
)

@pytest.mark.skip(
reason="Stream-based reader can only be read iteratively"
)
@pytest.mark.skip(reason="Stream-based reader can only be read iteratively")
def test_changing_dimensions(self, ref, reader):
if ref.changing_dimensions:
reader.rewind()
Expand Down Expand Up @@ -642,3 +640,19 @@ def test_subslice_fi_all_after_iteration_raises_error(self, reader):
with pytest.raises(RuntimeError):
for ts in sub_sliced_reader:
pass


def test_n_atoms_mismatch():
universe = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)
port = get_free_port()
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(create_default_imdsinfo_v3())
server.handshake_sequence("localhost", port, first_frame=True)
with pytest.raises(
EOFError,
match="IMDProducer: Expected n_atoms value 6, got 5. Ensure you are using the correct topology file.",
):
IMDReader(
f"imd://localhost:{port}",
n_atoms=universe.trajectory.n_atoms + 1,
)

0 comments on commit d666383

Please sign in to comment.