Skip to content

Commit

Permalink
Merge pull request #44 from Becksteinlab/wait
Browse files Browse the repository at this point in the history
Add continue_after_disconnect option (wait packet implementation)
  • Loading branch information
ljwoods2 authored Dec 11, 2024
2 parents dd47bcc + 5a41ddf commit 57601ab
Show file tree
Hide file tree
Showing 8 changed files with 115 additions and 26 deletions.
10 changes: 5 additions & 5 deletions docs/source/protocol_v3.rst
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ and its associated body packet (if present) is described in detail.
* - :ref:`forces`
- 15
- ❌
* - :ref:`wait-flag`
* - :ref:`wait`
- 16
- ❌

Expand Down Expand Up @@ -489,10 +489,10 @@ forces were previously specified for this session in the :ref:`session info pack
.. versionadded:: 3

.. _wait-flag:
.. _wait:

Wait flag
^^^^^^^^^
Wait
^^^^

Sent from the receiver to the simulation engine any time after the :ref:`session info packet <session-info>`
has been sent to request that the simulation engine modify its waiting behavior mid-simulation either
Expand All @@ -513,7 +513,7 @@ The simulation engine's waiting behavior also applies when a receiver disconnect
.. code-block:: none
Header:
16 (int32) Wait flag
16 (int32) Wait
<val> (int32) Nonzero to set the simulation engine's waiting behavior to blocking, 0
to set the simulation engine's waiting behavior to non-blocking
Expand Down
4 changes: 3 additions & 1 deletion imdclient/IMD.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def __init__(
):
super(IMDReader, self).__init__(filename, **kwargs)

self._imdclient = None
logger.debug("IMDReader initializing")

if n_atoms is None:
Expand Down Expand Up @@ -125,6 +126,7 @@ def _format_hint(thing):
def close(self):
"""Gracefully shut down the reader. Stops the producer thread."""
logger.debug("IMDReader close() called")
self._imdclient.stop()
if self._imdclient is not None:
self._imdclient.stop()
# NOTE: removeme after testing
logger.debug("IMDReader shut down gracefully.")
45 changes: 37 additions & 8 deletions imdclient/IMDClient.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,10 @@ class IMDClient:
IMDFramebuffer will be filled with as many :class:`IMDFrame` fit in `buffer_size` bytes [``10MB``]
timeout : int, optional
Timeout for the socket in seconds [``5``]
continue_after_disconnect : bool, optional [``None``]
If True, the client will attempt to change the simulation engine's waiting behavior to
non-blocking after the client disconnects. If False, the client will attempt to change it
to blocking. If None, the client will not attempt to change the simulation engine's behavior.
**kwargs : dict (optional)
Additional keyword arguments to pass to the :class:`BaseIMDProducer` and :class:`IMDFrameBuffer`
"""
Expand All @@ -57,13 +61,15 @@ def __init__(
n_atoms,
socket_bufsize=None,
multithreaded=True,
continue_after_disconnect=None,
**kwargs,
):

self._stopped = False
self._conn = self._connect_to_server(host, port, socket_bufsize)
self._imdsinfo = self._await_IMD_handshake()
self._multithreaded = multithreaded
self._continue_after_disconnect = continue_after_disconnect

if self._multithreaded:
self._buf = IMDFrameBuffer(
Expand Down Expand Up @@ -201,7 +207,9 @@ def _connect_to_server(self, host, port, socket_bufsize):
# /proc/sys/net/core/rmem_default
# Max (linux):
# /proc/sys/net/core/rmem_max
conn.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize)
conn.setsockopt(
socket.SOL_SOCKET, socket.SO_RCVBUF, socket_bufsize
)
try:
logger.debug(f"IMDClient: Connecting to {host}:{port}")
conn.connect((host, port))
Expand Down Expand Up @@ -292,6 +300,17 @@ def _go(self):
self._conn.sendall(go)
logger.debug("IMDClient: Sent go packet to server")

if self._continue_after_disconnect is not None:
wait_behavior = (int)(not self._continue_after_disconnect)
wait_packet = create_header_bytes(
IMDHeaderType.IMD_WAIT, wait_behavior
)
self._conn.sendall(wait_packet)
logger.debug(
"IMDClient: Attempted to change wait behavior to %s",
not self._continue_after_disconnect
)

def _disconnect(self):
# MUST disconnect before stopping execution
# if simulation already ended, this method will do nothing
Expand Down Expand Up @@ -499,7 +518,14 @@ def _read(self, buf):

class IMDProducerV2(BaseIMDProducer):
def __init__(
self, conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
self,
conn,
buffer,
sinfo,
n_atoms,
multithreaded,
error_queue,
**kwargs,
):
super(IMDProducerV2, self).__init__(
conn, buffer, sinfo, n_atoms, multithreaded, error_queue, **kwargs
Expand Down Expand Up @@ -713,9 +739,6 @@ def _parse_imdframe(self):
).reshape((self._n_atoms, 3)),
)

def __del__(self):
logger.debug("IMDProducer: I am being deleted")


class IMDFrameBuffer:
"""
Expand Down Expand Up @@ -762,7 +785,9 @@ def __init__(
raise ValueError("pause_empty_proportion must be between 0 and 1")
self._pause_empty_proportion = pause_empty_proportion
if unpause_empty_proportion < 0 or unpause_empty_proportion > 1:
raise ValueError("unpause_empty_proportion must be between 0 and 1")
raise ValueError(
"unpause_empty_proportion must be between 0 and 1"
)
self._unpause_empty_proportion = unpause_empty_proportion

if buffer_size <= 0:
Expand Down Expand Up @@ -829,7 +854,9 @@ def wait_for_space(self):
logger.debug("IMDProducer: Noticing consumer finished")
raise EOFError
except Exception as e:
logger.debug(f"IMDProducer: Error waiting for space in buffer: {e}")
logger.debug(
f"IMDProducer: Error waiting for space in buffer: {e}"
)

def pop_empty_imdframe(self):
logger.debug("IMDProducer: Getting empty frame")
Expand Down Expand Up @@ -875,7 +902,9 @@ def pop_full_imdframe(self):
imdf = self._full_q.get()
else:
with self._full_imdf_avail:
while self._full_q.qsize() == 0 and not self._producer_finished:
while (
self._full_q.qsize() == 0 and not self._producer_finished
):
self._full_imdf_avail.wait()

if self._producer_finished and self._full_q.qsize() == 0:
Expand Down
1 change: 1 addition & 0 deletions imdclient/IMDProtocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ class IMDHeaderType(Enum):
IMD_BOX = 13
IMD_VELOCITIES = 14
IMD_FORCES = 15
IMD_WAIT = 16


def parse_energy_bytes(data, endianness):
Expand Down
46 changes: 46 additions & 0 deletions imdclient/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,3 +208,49 @@ def test_compare_imd_to_true_traj(
imd_u.trajectory[i - first_frame].forces,
atol=1e-03,
)

def test_continue_after_disconnect(
self, docker_client, topol, tmp_path, port
):
u = mda.Universe(
(tmp_path / topol),
f"imd://localhost:{port}",
continue_after_disconnect=True,
# Make sure LAMMPS topol can be read
# Does nothing if not LAMMPS
atom_style="id type x y z",
)
# Though we disconnect here, the simulation should continue
u.trajectory.close()
# Wait for the simulation to finish running
time.sleep(45)

# Now, attempt to reconnect- should fail,
# since the simulation should have continued
with pytest.raises(IOError):
u = mda.Universe(
(tmp_path / topol),
f"imd://localhost:{port}",
atom_style="id type x y z",
)

def test_wait_after_disconnect(self, docker_client, topol, tmp_path, port):
u = mda.Universe(
(tmp_path / topol),
f"imd://localhost:{port}",
# Could also use None here- just being explicit
continue_after_disconnect=False,
# Make sure LAMMPS topol can be read
# Does nothing if not LAMMPS
atom_style="id type x y z",
)
u.trajectory.close()
# Give the simulation engine
# enough time to finish running (though it shouldn't)
time.sleep(45)

u = mda.Universe(
(tmp_path / topol),
f"imd://localhost:{port}",
atom_style="id type x y z",
)
13 changes: 2 additions & 11 deletions imdclient/tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,22 +93,14 @@ def _send_handshakeV3(self):
positions,
wrapped_coords,
velocities,
forces
forces,
)
logger.debug(f"InThreadIMDServer: Sending session info")
self.conn.sendall(sinfo)

def join_accept_thread(self):
self.accept_thread.join()

def _expect_go(self):
logger.debug(f"InThreadIMDServer: Waiting for go")
head_buf = bytearray(IMDHEADERSIZE)
read_into_buf(self.conn, head_buf)
header = IMDHeader(head_buf)
if header.type != IMDHeaderType.IMD_GO:
raise ValueError("Expected IMD_GO packet, got something else")

def send_frames(self, start, end):
for i in range(start, end):
self.send_frame(i)
Expand All @@ -126,7 +118,7 @@ def send_frame(self, i):
)

self.conn.sendall(time_header + time)

if self.imdsinfo.energies:
energy_header = create_header_bytes(IMDHeaderType.IMD_ENERGIES, 1)
energies = create_energy_bytes(
Expand Down Expand Up @@ -183,7 +175,6 @@ def send_frame(self, i):

self.conn.sendall(force_header + force)


def expect_packet(self, packet_type, expected_length=None):
head_buf = bytearray(IMDHEADERSIZE)
read_into_buf(self.conn, head_buf)
Expand Down
18 changes: 18 additions & 0 deletions imdclient/tests/test_imdclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ def server_client_two_frame_buf(self, universe, imdsinfo, port):
buffer_size=imdframe_memsize(universe.trajectory.n_atoms, imdsinfo)
* 2,
)
server.join_accept_thread()
yield server, client
client.stop()
server.cleanup()
Expand All @@ -84,6 +85,7 @@ def server_client(self, universe, imdsinfo, port, request):
port,
universe.trajectory.n_atoms,
)
server.join_accept_thread()
yield server, client
client.stop()
server.cleanup()
Expand Down Expand Up @@ -149,6 +151,22 @@ def test_pause_resume_no_disconnect(self, server_client_two_frame_buf):
# server should receive disconnect from client (though it doesn't have to do anything)
server.expect_packet(IMDHeaderType.IMD_DISCONNECT)

@pytest.mark.parametrize("cont", [True, False])
def test_continue_after_disconnect(self, universe, imdsinfo, port, cont):
server = InThreadIMDServer(universe.trajectory)
server.set_imdsessioninfo(imdsinfo)
server.handshake_sequence("localhost", port, first_frame=False)
client = IMDClient(
f"localhost",
port,
universe.trajectory.n_atoms,
continue_after_disconnect=cont,
)
server.join_accept_thread()
server.expect_packet(
IMDHeaderType.IMD_WAIT, expected_length=(int)(not cont)
)


class TestIMDClientV3ContextManager:
@pytest.fixture
Expand Down
4 changes: 3 additions & 1 deletion imdclient/tests/test_imdreader.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,9 @@ def test_total_time(self, reader, ref):
decimal=ref.prec,
)

@pytest.mark.skip(reason="Stream-based reader can only be read iteratively")
@pytest.mark.skip(
reason="Stream-based reader can only be read iteratively"
)
def test_changing_dimensions(self, ref, reader):
if ref.changing_dimensions:
reader.rewind()
Expand Down

0 comments on commit 57601ab

Please sign in to comment.