Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better error messages, Context manager #34

Merged
merged 4 commits into from
Nov 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions imdclient/IMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,8 @@ def _read_frame(self, frame):

try:
imdf = self._imdclient.get_imdframe()
except EOFError:
# Not strictly necessary, but for clarity
raise StopIteration
except EOFError as e:
raise e

self._frame = frame
self._load_imdframe_into_ts(imdf)
Expand Down
104 changes: 93 additions & 11 deletions imdclient/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import numpy as np
from typing import Union, Dict
import signal
import atexit

logger = logging.getLogger(__name__)

Expand All @@ -43,6 +44,8 @@ class IMDClient:
Size of the socket buffer in bytes. Default is to use the system default
buffer_size : int (optional)
IMDFramebuffer will be filled with as many :class:`IMDFrame` fit in `buffer_size` bytes [``10MB``]
timeout : int, optional
Timeout for the socket in seconds [``5``]
**kwargs : dict (optional)
Additional keyword arguments to pass to the :class:`BaseIMDProducer` and :class:`IMDFrameBuffer`
"""
Expand All @@ -68,15 +71,18 @@ def __init__(
n_atoms,
**kwargs,
)
self._error_queue = queue.Queue()
else:
self._buf = None
self._error_queue = None
if self._imdsinfo.version == 2:
self._producer = IMDProducerV2(
self._conn,
self._buf,
self._imdsinfo,
n_atoms,
multithreaded,
self._error_queue,
**kwargs,
)
elif self._imdsinfo.version == 3:
Expand All @@ -86,23 +92,60 @@ def __init__(
self._imdsinfo,
n_atoms,
multithreaded,
self._error_queue,
**kwargs,
)

self._go()

if self._multithreaded:
# Disconnect MUST occur. This covers typical cases (Python, IPython interpreter)
signal.signal(signal.SIGINT, self.signal_handler)
signal.signal(signal.SIGTERM, self.signal_handler)

# Disconnect and socket shutdown MUST occur. This covers Jupyter use
# since in jupyter, the signal handler is reset to the default
# by pre- and post- hooks
# https://stackoverflow.com/questions/70841648/jupyter-reverts-signal-handler-to-default-when-running-next-cell
try:
import IPython
except ImportError:
has_ipython = False
else:
has_ipython = True

if has_ipython:
try:
from IPython import get_ipython

if get_ipython() is not None:
kernel = get_ipython().kernel
kernel.pre_handler_hook = lambda: None
kernel.post_handler_hook = lambda: None
logger.debug("Running in Jupyter")
except NameError:
logger.debug("Running in non-jupyter IPython environment")

# Final case: error is raised outside of IMDClient code
logger.debug("Registering atexit")
atexit.register(self.stop)

self._producer.start()

def signal_handler(self, sig, frame):
"""Catch SIGINT to allow clean shutdown on CTRL+C
This also ensures that main thread execution doesn't get stuck
waiting in buf.pop_full_imdframe()"""
logger.debug("Intercepted signal")
self.stop()
logger.debug("Shutdown success")

def get_imdframe(self):
"""
Returns
-------
IMDFrame
The next frame from the IMD server
Raises
------
EOFError
Expand All @@ -116,6 +159,9 @@ def get_imdframe(self):
# and doesn't need to be notified
self._disconnect()
self._stopped = True

if self._error_queue.qsize():
raise EOFError(f"{self._error_queue.get()}")
raise EOFError
else:
try:
Expand All @@ -125,14 +171,23 @@ def get_imdframe(self):
raise EOFError

def get_imdsessioninfo(self):
"""
Returns
-------
IMDSessionInfo
Information about the IMD session
"""
return self._imdsinfo

def stop(self):
"""
Stop the client and close the connection
"""
if self._multithreaded:
if not self._stopped:
self._buf.notify_consumer_finished()
self._disconnect()
self._stopped = True
self._disconnect()
self._buf.notify_consumer_finished()
else:
self._disconnect()

Expand Down Expand Up @@ -252,6 +307,13 @@ def _disconnect(self):
finally:
self._conn.close()

def __enter__(self):
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.stop()
return False


class BaseIMDProducer(threading.Thread):
"""
Expand All @@ -267,11 +329,14 @@ class BaseIMDProducer(threading.Thread):
Information about the IMD session
n_atoms : int
Number of atoms in the simulation
multithreaded : bool, optional
multithreaded : bool
If True, socket interaction will occur in a separate thread &
frames will be buffered. Single-threaded, blocking IMDClient
should only be used in testing [[``True``]]

should only be used in testing
error_queue: queue.Queue
Queue to hold errors produced by the producer thread
timeout : int, optional
Timeout for the socket in seconds [``5``]
"""

def __init__(
Expand All @@ -280,7 +345,8 @@ def __init__(
buffer,
sinfo,
n_atoms,
multithreaded=True,
multithreaded,
error_queue,
timeout=5,
**kwargs,
):
Expand All @@ -289,6 +355,7 @@ def __init__(
self._imdsinfo = sinfo
self._paused = False

self.error_queue = error_queue
# Timeout for first frame should be longer
# than rest of frames
self._timeout = timeout
Expand Down Expand Up @@ -383,6 +450,7 @@ def run(self):
logger.debug("IMDProducer: Simulation ended normally, cleaning up")
except Exception as e:
logger.debug("IMDProducer: An unexpected error occurred: %s", e)
self.error_queue.put(e)
finally:
logger.debug("IMDProducer: Stopping run loop")
# Tell consumer not to expect more frames to be added
Expand All @@ -398,9 +466,19 @@ def _expect_header(self, expected_type, expected_value=None):
)
# Sometimes we do not care what the value is
if expected_value is not None and header.length != expected_value:
raise RuntimeError(
f"IMDProducer: Expected header value {expected_value}, got {header.length}"
)
if expected_type in [
IMDHeaderType.IMD_FCOORDS,
IMDHeaderType.IMD_VELOCITIES,
IMDHeaderType.IMD_FORCES,
]:
raise RuntimeError(
f"IMDProducer: Expected n_atoms value {expected_value}, got {header.length}. "
+ "Ensure you are using the correct topology file."
)
else:
raise RuntimeError(
f"IMDProducer: Expected header value {expected_value}, got {header.length}"
)

def _get_header(self):
self._read(self._header)
Expand All @@ -420,9 +498,11 @@ def _read(self, buf):


class IMDProducerV2(BaseIMDProducer):
def __init__(self, conn, buffer, sinfo, n_atoms, multithreaded, **kwargs):
def __init__(
self, conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
):
super(IMDProducerV2, self).__init__(
conn, buffer, sinfo, n_atoms, multithreaded, **kwargs
conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
)

self._energies = bytearray(IMDENERGYPACKETLENGTH)
Expand Down Expand Up @@ -515,6 +595,7 @@ def __init__(
sinfo,
n_atoms,
multithreaded,
error_queue,
**kwargs,
):
super(IMDProducerV3, self).__init__(
Expand All @@ -523,6 +604,7 @@ def __init__(
sinfo,
n_atoms,
multithreaded,
error_queue,
**kwargs,
)
# The body of an x/v/f packet should contain
Expand Down
51 changes: 51 additions & 0 deletions imdclient/tests/test_imdclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,54 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf):
client.get_imdframe()
# server should receive disconnect from client (though it doesn't have to do anything)
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)


class TestIMDClientV3ContextManager:
@pytest.fixture
def port(self):
return get_free_port()

@pytest.fixture
def universe(self):
return mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)

@pytest.fixture
def imdsinfo(self):
return create_default_imdsinfo_v3()

@pytest.fixture
def server(self, universe, imdsinfo, port):
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(imdsinfo)
yield server
server.cleanup()

def test_context_manager_traj_unchanged(self, server, port, universe):
server.handshake_sequence("localhost", port, first_frame=False)

i = 0
with IMDClient(
"localhost",
port,
universe.trajectory.n_atoms,
) as client:
server.send_frames(0, 5)
while i < 5:

imdf = client.get_imdframe()
assert_allclose(universe.trajectory[i].time, imdf.time)
assert_allclose(universe.trajectory[i].dt, imdf.dt)
assert_allclose(universe.trajectory[i].data["step"], imdf.step)
assert_allclose(
universe.trajectory[i].positions, imdf.positions
)
assert_allclose(
universe.trajectory[i].velocities, imdf.velocities
)
assert_allclose(universe.trajectory[i].forces, imdf.forces)
assert_allclose(
universe.trajectory[i].triclinic_dimensions, imdf.box
)
i += 1
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)
assert i == 5
20 changes: 17 additions & 3 deletions imdclient/tests/test_imdreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ def test_total_time(self, reader, ref):
decimal=ref.prec,
)

@pytest.mark.skip(
reason="Stream-based reader can only be read iteratively"
)
@pytest.mark.skip(reason="Stream-based reader can only be read iteratively")
def test_changing_dimensions(self, ref, reader):
if ref.changing_dimensions:
reader.rewind()
Expand Down Expand Up @@ -642,3 +640,19 @@ def test_subslice_fi_all_after_iteration_raises_error(self, reader):
with pytest.raises(RuntimeError):
for ts in sub_sliced_reader:
pass


def test_n_atoms_mismatch():
universe = mda.Universe(COORDINATES_TOPOLOGY, COORDINATES_H5MD)
port = get_free_port()
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(create_default_imdsinfo_v3())
server.handshake_sequence("localhost", port, first_frame=True)
with pytest.raises(
EOFError,
match="IMDProducer: Expected n_atoms value 6, got 5. Ensure you are using the correct topology file.",
):
IMDReader(
f"imd://localhost:{port}",
n_atoms=universe.trajectory.n_atoms + 1,
)
Loading