Skip to content

Commit

Permalink
zmq: implement task for heartbeats on clients to detect reconnection
Browse files Browse the repository at this point in the history
The heartbeat task sends HEARTBEAT to all clients (i.e. Monitor) at client.HEARTBEAT_TIMEOUT intervals.
Clients do not reply; they just process the message. If a client detects a longer delay between two heartbeats,
it additionally sends CONNECT to the evaluator, i.e. the connection is re-established after a break.
This is to simulate re-connection. Each CONNECT_MSG then triggers sending a FullSnapshot from the ensemble evaluator.
Initially HEARTBEAT_TIMEOUT is set to 5 seconds, while the Monitor accepts a delay of at most 10 seconds.
Additionally, the initial connection now undergoes the same number of retries as standard messages.
  • Loading branch information
xjules committed Jan 20, 2025
1 parent fd06848 commit 266cbd0
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 23 deletions.
18 changes: 16 additions & 2 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class ClientConnectionError(Exception):
CONNECT_MSG = b"CONNECT"
DISCONNECT_MSG = b"DISCONNECT"
ACK_MSG = b"ACK"
HEARTBEAT_MSG = b"BEAT"
HEARTBEAT_TIMEOUT = 5.0


class Client:
Expand Down Expand Up @@ -83,7 +85,7 @@ async def connect(self) -> None:
await self._term_receiver_task()
self._receiver_task = asyncio.create_task(self._receiver())
try:
await self.send(CONNECT_MSG, retries=1)
await self.send(CONNECT_MSG)
except ClientConnectionError:
await self._term_receiver_task()
self.term()
Expand All @@ -93,11 +95,23 @@ async def process_message(self, msg: str) -> None:
raise NotImplementedError("Only monitor can receive messages!")

async def _receiver(self) -> None:
last_beat_time: float | None = None
while True:
try:
_, raw_msg = await self.socket.recv_multipart()
if raw_msg == ACK_MSG:
self._ack_event.set()
elif raw_msg == HEARTBEAT_MSG:
if (
last_beat_time
and (asyncio.get_running_loop().time() - last_beat_time)
> 2 * HEARTBEAT_TIMEOUT
):
await self.socket.send_multipart([b"", CONNECT_MSG])
logger.warning(
f"{self.dealer_id} heartbeat failed - reconnecting."
)
last_beat_time = asyncio.get_running_loop().time()
else:
await self.process_message(raw_msg.decode("utf-8"))
except zmq.ZMQError as exc:
Expand Down Expand Up @@ -144,5 +158,5 @@ async def send(self, message: str | bytes, retries: int | None = None) -> None:
self.socket.connect(self.url)
backoff = min(backoff * 2, 10) # Exponential backoff
raise ClientConnectionError(
f"{self.dealer_id} Failed to send {message!r} after retries!"
f"{self.dealer_id} Failed to send {message!r} to {self.url} after retries!"
)
29 changes: 24 additions & 5 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,13 @@
event_from_json,
event_to_json,
)
from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
HEARTBEAT_MSG,
HEARTBEAT_TIMEOUT,
)
from ert.ensemble_evaluator import identifiers as ids

from ._ensemble import FMStepSnapshot
Expand All @@ -51,7 +57,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._ensemble: Ensemble = ensemble

self._events: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
self._events_to_send: asyncio.Queue[Event | bytes] = asyncio.Queue()
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()

self._ee_tasks: list[asyncio.Task[None]] = []
Expand All @@ -72,14 +78,24 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
self._dispatchers_empty: asyncio.Event = asyncio.Event()
self._dispatchers_empty.set()

async def _do_heartbeat_clients(self) -> None:
    """Periodically queue a heartbeat for the connected clients.

    Waits until the server has started, then loops forever: whenever at
    least one client is registered in ``_clients_connected`` a
    ``HEARTBEAT_MSG`` is put on ``_events_to_send``. Every round ends with
    an ``asyncio.sleep`` of ``HEARTBEAT_TIMEOUT``, whether or not a
    heartbeat was queued.
    """
    await self._server_started.wait()
    while True:
        anyone_listening = bool(self._clients_connected)
        if anyone_listening:
            await self._events_to_send.put(HEARTBEAT_MSG)
        await asyncio.sleep(HEARTBEAT_TIMEOUT)

async def _publisher(self) -> None:
    """Forward every queued item to all currently connected clients.

    Waits until the server has started, then loops forever taking items
    from ``_events_to_send``. Items that are already ``bytes`` (such as the
    heartbeat message) are sent verbatim; anything else is serialized with
    ``event_to_json`` and UTF-8 encoded. Each item is sent exactly once per
    client as the multipart message ``[identity, b"", payload]`` and is
    then marked done on the queue.
    """
    await self._server_started.wait()
    while True:
        event = await self._events_to_send.get()
        # Iterate over a snapshot: this coroutine awaits inside the loop and
        # the client set is mutated elsewhere (e.g. when a client connects),
        # which would otherwise invalidate the live set iterator.
        recipients = tuple(self._clients_connected)
        if recipients:
            # Build the payload once per event instead of once per client.
            if isinstance(event, bytes):
                payload = event
            else:
                payload = event_to_json(event).encode("utf-8")
            for identity in recipients:
                await self._router_socket.send_multipart([identity, b"", payload])
        self._events_to_send.task_done()

async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None:
Expand Down Expand Up @@ -197,6 +213,8 @@ def ensemble(self) -> Ensemble:

async def handle_client(self, dealer: bytes, frame: bytes) -> None:
if frame == CONNECT_MSG:
if dealer in self._clients_connected:
logger.warning(f"{dealer!r} wants to reconnect.")
self._clients_connected.add(dealer)
self._clients_empty.clear()
current_snapshot_dict = self._ensemble.snapshot.to_dict()
Expand Down Expand Up @@ -342,6 +360,7 @@ async def _start_running(self) -> None:
raise ValueError("no config for evaluator")
self._ee_tasks = [
asyncio.create_task(self._server(), name="server_task"),
asyncio.create_task(self._do_heartbeat_clients(), name="beat_task"),
asyncio.create_task(
self._batch_events_into_buffer(), name="dispatcher_task"
),
Expand Down
44 changes: 35 additions & 9 deletions tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import asyncio

import pytest

import _ert.forward_model_runner.client
from _ert.forward_model_runner.client import Client, ClientConnectionError
from tests.ert.utils import MockZMQServer

Expand All @@ -18,12 +21,12 @@ async def test_invalid_server():
async def test_successful_sending(unused_tcp_port):
    """Every message sent by the client ends up in the mock server's log."""
    host = "localhost"
    url = f"tcp://{host}:{unused_tcp_port}"
    messages = ["test_1", "test_2", "test_3"]
    # A single server/client pair: a second pair on the same port (left over
    # from the pre-rename version of this test) must not be set up.
    async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as client:
        for message in messages:
            await client.send(message)

    for msg in messages:
        assert msg in mock_server.messages


Expand All @@ -32,18 +35,41 @@ async def test_retry(unused_tcp_port):
host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
client_connection_error_set = False
messages_c1 = ["test_1", "test_2", "test_3"]
messages = ["test_1", "test_2", "test_3"]
async with (
MockZMQServer(unused_tcp_port, signal=2) as mock_server,
Client(url, ack_timeout=0.5) as c1,
Client(url, ack_timeout=0.5) as client,
):
for message in messages_c1:
for message in messages:
try:
await c1.send(message, retries=1)
await client.send(message, retries=1)
except ClientConnectionError:
client_connection_error_set = True
mock_server.signal(0)
assert client_connection_error_set
assert mock_server.messages.count("test_1") == 2
assert mock_server.messages.count("test_2") == 1
assert mock_server.messages.count("test_3") == 1


async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch):
    """A heartbeat arriving too late makes the client send CONNECT again."""
    # Shrink the heartbeat interval so the one-second pause between the two
    # heartbeats below exceeds the delay the client tolerates.
    monkeypatch.setattr(_ert.forward_model_runner.client, "HEARTBEAT_TIMEOUT", 0.1)
    url = f"tcp://localhost:{unused_tcp_port}"

    async with (
        MockZMQServer(unused_tcp_port, signal=3) as server,
        Client(url) as dealer,
    ):
        await dealer.send("start", retries=1)

        await server.do_heartbeat()
        await asyncio.sleep(1)
        await server.do_heartbeat()
        await dealer.send("stop", retries=1)

    received = server.messages
    # when reconnection happens CONNECT message is sent again
    assert received.count("CONNECT") == 2
    assert received.count("DISCONNECT") == 1
    assert "start" in received
    assert "stop" in received
Original file line number Diff line number Diff line change
Expand Up @@ -213,9 +213,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port):
def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port):
# this is to show when the reporter fails but reconnects
# reporter still manages to send events and completes fine
# see assert reporter._timeout_timestamp is not None
# meaning Finish event initiated _timeout but timeout wasn't reached since
# it finished successfully
# see reporter._event_publisher for more details.

host = "localhost"
url = f"tcp://{host}:{unused_tcp_port}"
Expand Down
27 changes: 23 additions & 4 deletions tests/ert/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,12 @@
import zmq
import zmq.asyncio

from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
from _ert.forward_model_runner.client import (
ACK_MSG,
CONNECT_MSG,
DISCONNECT_MSG,
HEARTBEAT_MSG,
)
from _ert.threading import ErtThread
from ert.scheduler.event import FinishedEvent, StartedEvent

Expand Down Expand Up @@ -64,16 +69,18 @@ def wait_until(func, interval=0.5, timeout=30):
class MockZMQServer:
def __init__(self, port, signal=0):
    """Mock ZMQ server for testing
    signal = 0: normal operation, receive messages but don't store CONNECT and DISCONNECT messages
    signal = 1: don't send ACK and don't receive messages
    signal = 2: don't send ACK, but receive messages
    signal = 3: normal operation, and store also CONNECT and DISCONNECT messages
    """
    self.port = port
    # Decoded frames recorded by the handler.
    self.messages = []
    # Current behaviour mode (see the docstring); can be changed via signal().
    self.value = signal
    self.loop = None
    self.server_task = None
    self.handler_task = None
    # Dealer identities that sent CONNECT and have not sent DISCONNECT;
    # these are the recipients of do_heartbeat().
    self.dealers = set()

def start_event_loop(self):
asyncio.set_event_loop(self.loop)
Expand Down Expand Up @@ -116,13 +123,25 @@ async def mock_zmq_server(self):
def signal(self, value):
    """Replace the behaviour mode stored in ``self.value``."""
    self.value = value

async def do_heartbeat(self):
    """Send one ``HEARTBEAT_MSG`` to every dealer in ``self.dealers``."""
    beat_frames = ([peer, b"", HEARTBEAT_MSG] for peer in self.dealers)
    for frames in beat_frames:
        await self.router_socket.send_multipart(frames)

async def _handler(self):
while True:
try:
dealer, __, frame = await self.router_socket.recv_multipart()
if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0:
if frame == CONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.add(dealer)
elif frame == DISCONNECT_MSG:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
self.dealers.discard(dealer)
elif self.value in {0, 3}:
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1:
if (
self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG}
) or self.value == 3:
self.messages.append(frame.decode("utf-8"))
except asyncio.CancelledError:
break
Expand Down

0 comments on commit 266cbd0

Please sign in to comment.