Skip to content

Commit bba672a

Browse files
committed
zmq: implement task for heeartbeats on clients to detect reconnection
Heartbeat task sends HEARTBEAT to all the clients (ie. Monitor) at client.HEARTBEAT_TIMEOUT intervals. Clients do not reply, just process the message. If client detects longer delay between two heartbeats, the client will send CONNECT to evaluator in addition; ie. getting the connection re-established after a break. This is to simulate re-connection. Each CONNECT_MSG will then trigger sending FullSnapshot from the ensemble evaluator. Initially HEARTBEAT_TIMEOUT is set to 5 seconds while Monitor accepts 10 seconds at max as a delay. Additionally, initial connection will now undergo same amount of retries as standard messages.
1 parent 61a6776 commit bba672a

6 files changed

Lines changed: 101 additions & 24 deletions

File tree

src/_ert/forward_model_runner/client.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ class ClientConnectionError(Exception):
1818
CONNECT_MSG = b"CONNECT"
1919
DISCONNECT_MSG = b"DISCONNECT"
2020
ACK_MSG = b"ACK"
21+
HEARTBEAT_MSG = b"BEAT"
22+
HEARTBEAT_TIMEOUT = 5.0
2123

2224

2325
class Client:
@@ -83,7 +85,7 @@ async def connect(self) -> None:
8385
await self._term_receiver_task()
8486
self._receiver_task = asyncio.create_task(self._receiver())
8587
try:
86-
await self.send(CONNECT_MSG, retries=1)
88+
await self.send(CONNECT_MSG)
8789
except ClientConnectionError:
8890
await self._term_receiver_task()
8991
self.term()
@@ -93,11 +95,23 @@ async def process_message(self, msg: str) -> None:
9395
raise NotImplementedError("Only monitor can receive messages!")
9496

9597
async def _receiver(self) -> None:
98+
last_heartbeat_time: float | None = None
9699
while True:
97100
try:
98101
_, raw_msg = await self.socket.recv_multipart()
99102
if raw_msg == ACK_MSG:
100103
self._ack_event.set()
104+
elif raw_msg == HEARTBEAT_MSG:
105+
if (
106+
last_heartbeat_time
107+
and (asyncio.get_running_loop().time() - last_heartbeat_time)
108+
> 2 * HEARTBEAT_TIMEOUT
109+
):
110+
await self.socket.send_multipart([b"", CONNECT_MSG])
111+
logger.warning(
112+
f"{self.dealer_id} heartbeat failed - reconnecting."
113+
)
114+
last_heartbeat_time = asyncio.get_running_loop().time()
101115
else:
102116
await self.process_message(raw_msg.decode("utf-8"))
103117
except zmq.ZMQError as exc:
@@ -144,5 +158,5 @@ async def send(self, message: str | bytes, retries: int | None = None) -> None:
144158
self.socket.connect(self.url)
145159
backoff = min(backoff * 2, 10) # Exponential backoff
146160
raise ClientConnectionError(
147-
f"{self.dealer_id} Failed to send {message!r} after retries!"
161+
f"{self.dealer_id} Failed to send {message!r} to {self.url} after retries!"
148162
)

src/ert/ensemble_evaluator/evaluator.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@
2727
event_from_json,
2828
event_to_json,
2929
)
30-
from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
30+
from _ert.forward_model_runner.client import (
31+
ACK_MSG,
32+
CONNECT_MSG,
33+
DISCONNECT_MSG,
34+
HEARTBEAT_MSG,
35+
HEARTBEAT_TIMEOUT,
36+
)
3137
from ert.ensemble_evaluator import identifiers as ids
3238

3339
from ._ensemble import FMStepSnapshot
@@ -51,7 +57,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
5157
self._ensemble: Ensemble = ensemble
5258

5359
self._events: asyncio.Queue[Event] = asyncio.Queue()
54-
self._events_to_send: asyncio.Queue[Event] = asyncio.Queue()
60+
self._events_to_send: asyncio.Queue[Event | bytes] = asyncio.Queue()
5561
self._manifest_queue: asyncio.Queue[Any] = asyncio.Queue()
5662

5763
self._ee_tasks: list[asyncio.Task[None]] = []
@@ -72,14 +78,24 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):
7278
self._dispatchers_empty: asyncio.Event = asyncio.Event()
7379
self._dispatchers_empty.set()
7480

81+
async def _do_heartbeat_clients(self) -> None:
82+
await self._server_started.wait()
83+
while True:
84+
if self._clients_connected:
85+
await self._events_to_send.put(HEARTBEAT_MSG)
86+
await asyncio.sleep(HEARTBEAT_TIMEOUT)
87+
7588
async def _publisher(self) -> None:
7689
await self._server_started.wait()
7790
while True:
7891
event = await self._events_to_send.get()
7992
for identity in self._clients_connected:
80-
await self._router_socket.send_multipart(
81-
[identity, b"", event_to_json(event).encode("utf-8")]
82-
)
93+
if isinstance(event, bytes):
94+
await self._router_socket.send_multipart([identity, b"", event])
95+
else:
96+
await self._router_socket.send_multipart(
97+
[identity, b"", event_to_json(event).encode("utf-8")]
98+
)
8399
self._events_to_send.task_done()
84100

85101
async def _append_message(self, snapshot_update_event: EnsembleSnapshot) -> None:
@@ -197,6 +213,8 @@ def ensemble(self) -> Ensemble:
197213

198214
async def handle_client(self, dealer: bytes, frame: bytes) -> None:
199215
if frame == CONNECT_MSG:
216+
if dealer in self._clients_connected:
217+
logger.warning(f"{dealer!r} wants to reconnect.")
200218
self._clients_connected.add(dealer)
201219
self._clients_empty.clear()
202220
current_snapshot_dict = self._ensemble.snapshot.to_dict()
@@ -342,6 +360,7 @@ async def _start_running(self) -> None:
342360
raise ValueError("no config for evaluator")
343361
self._ee_tasks = [
344362
asyncio.create_task(self._server(), name="server_task"),
363+
asyncio.create_task(self._do_heartbeat_clients(), name="beat_task"),
345364
asyncio.create_task(
346365
self._batch_events_into_buffer(), name="dispatcher_task"
347366
),

tests/ert/unit_tests/ensemble_evaluator/test_ensemble_client.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1+
import asyncio
2+
13
import pytest
24

5+
import _ert.forward_model_runner.client
36
from _ert.forward_model_runner.client import Client, ClientConnectionError
47
from tests.ert.utils import MockZMQServer
58

@@ -18,12 +21,12 @@ async def test_invalid_server():
1821
async def test_successful_sending(unused_tcp_port):
1922
host = "localhost"
2023
url = f"tcp://{host}:{unused_tcp_port}"
21-
messages_c1 = ["test_1", "test_2", "test_3"]
22-
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as c1:
23-
for message in messages_c1:
24-
await c1.send(message)
24+
messages = ["test_1", "test_2", "test_3"]
25+
async with MockZMQServer(unused_tcp_port) as mock_server, Client(url) as client:
26+
for message in messages:
27+
await client.send(message)
2528

26-
for msg in messages_c1:
29+
for msg in messages:
2730
assert msg in mock_server.messages
2831

2932

@@ -32,18 +35,41 @@ async def test_retry(unused_tcp_port):
3235
host = "localhost"
3336
url = f"tcp://{host}:{unused_tcp_port}"
3437
client_connection_error_set = False
35-
messages_c1 = ["test_1", "test_2", "test_3"]
38+
messages = ["test_1", "test_2", "test_3"]
3639
async with (
3740
MockZMQServer(unused_tcp_port, signal=2) as mock_server,
38-
Client(url, ack_timeout=0.5) as c1,
41+
Client(url, ack_timeout=0.5) as client,
3942
):
40-
for message in messages_c1:
43+
for message in messages:
4144
try:
42-
await c1.send(message, retries=1)
45+
await client.send(message, retries=1)
4346
except ClientConnectionError:
4447
client_connection_error_set = True
4548
mock_server.signal(0)
4649
assert client_connection_error_set
4750
assert mock_server.messages.count("test_1") == 2
4851
assert mock_server.messages.count("test_2") == 1
4952
assert mock_server.messages.count("test_3") == 1
53+
54+
55+
async def test_reconnect_when_missing_heartbeat(unused_tcp_port, monkeypatch):
56+
host = "localhost"
57+
url = f"tcp://{host}:{unused_tcp_port}"
58+
monkeypatch.setattr(_ert.forward_model_runner.client, "HEARTBEAT_TIMEOUT", 0.1)
59+
60+
async with (
61+
MockZMQServer(unused_tcp_port, signal=3) as mock_server,
62+
Client(url) as client,
63+
):
64+
await client.send("start", retries=1)
65+
66+
await mock_server.do_heartbeat()
67+
await asyncio.sleep(1)
68+
await mock_server.do_heartbeat()
69+
await client.send("stop", retries=1)
70+
71+
# when reconnection happens CONNECT message is sent again
72+
assert mock_server.messages.count("CONNECT") == 2
73+
assert mock_server.messages.count("DISCONNECT") == 1
74+
assert "start" in mock_server.messages
75+
assert "stop" in mock_server.messages

tests/ert/unit_tests/ensemble_evaluator/test_monitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ async def mock_event_handler(router_socket):
5555
assert msg == DISCONNECT_MSG
5656

5757

58-
async def test_no_connection_established(make_ee_config):
58+
async def test_no_connection_established(monkeypatch, make_ee_config):
5959
ee_config = make_ee_config()
60+
monkeypatch.setattr(Monitor, "DEFAULT_MAX_RETRIES", 0)
6061
monitor = Monitor(ee_config.get_connection_info())
6162
monitor._ack_timeout = 0.1
6263
with pytest.raises(ClientConnectionError):

tests/ert/unit_tests/forward_model_runner/test_event_reporter.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,7 @@ def test_report_with_failed_reporter_but_finished_jobs(unused_tcp_port):
213213
def test_report_with_reconnected_reporter_but_finished_jobs(unused_tcp_port):
214214
# this is to show when the reporter fails but reconnects
215215
# reporter still manages to send events and completes fine
216-
# see assert reporter._timeout_timestamp is not None
217-
# meaning Finish event initiated _timeout but timeout wasn't reached since
218-
# it finished succesfully
216+
# see reporter._event_publisher for more details.
219217

220218
host = "localhost"
221219
url = f"tcp://{host}:{unused_tcp_port}"

tests/ert/utils.py

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,12 @@
99
import zmq
1010
import zmq.asyncio
1111

12-
from _ert.forward_model_runner.client import ACK_MSG, CONNECT_MSG, DISCONNECT_MSG
12+
from _ert.forward_model_runner.client import (
13+
ACK_MSG,
14+
CONNECT_MSG,
15+
DISCONNECT_MSG,
16+
HEARTBEAT_MSG,
17+
)
1318
from _ert.threading import ErtThread
1419
from ert.scheduler.event import FinishedEvent, StartedEvent
1520

@@ -64,16 +69,18 @@ def wait_until(func, interval=0.5, timeout=30):
6469
class MockZMQServer:
6570
def __init__(self, port, signal=0):
6671
"""Mock ZMQ server for testing
67-
signal = 0: normal operation
72+
signal = 0: normal operation, receive messages but don't store CONNECT and DISCONNECT messages
6873
signal = 1: don't send ACK and don't receive messages
6974
signal = 2: don't send ACK, but receive messages
75+
signal = 3: normal operation, and store also CONNECT and DISCONNECT messages
7076
"""
7177
self.port = port
7278
self.messages = []
7379
self.value = signal
7480
self.loop = None
7581
self.server_task = None
7682
self.handler_task = None
83+
self.dealers = set()
7784

7885
def start_event_loop(self):
7986
asyncio.set_event_loop(self.loop)
@@ -116,13 +123,25 @@ async def mock_zmq_server(self):
116123
def signal(self, value):
117124
self.value = value
118125

126+
async def do_heartbeat(self):
127+
for dealer in self.dealers:
128+
await self.router_socket.send_multipart([dealer, b"", HEARTBEAT_MSG])
129+
119130
async def _handler(self):
120131
while True:
121132
try:
122133
dealer, __, frame = await self.router_socket.recv_multipart()
123-
if frame in {CONNECT_MSG, DISCONNECT_MSG} or self.value == 0:
134+
if frame == CONNECT_MSG:
135+
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
136+
self.dealers.add(dealer)
137+
elif frame == DISCONNECT_MSG:
138+
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
139+
self.dealers.discard(dealer)
140+
elif self.value in {0, 3}:
124141
await self.router_socket.send_multipart([dealer, b"", ACK_MSG])
125-
if frame not in {CONNECT_MSG, DISCONNECT_MSG} and self.value != 1:
142+
if (
143+
self.value in {0, 2} and frame not in {CONNECT_MSG, DISCONNECT_MSG}
144+
) or self.value == 3:
126145
self.messages.append(frame.decode("utf-8"))
127146
except asyncio.CancelledError:
128147
break

0 commit comments

Comments
 (0)