Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions dranspose/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def replay(args: argparse.Namespace) -> None:
args.port,
keepalive,
args.nworkers,
latency=args.latency,
loop=args.loop,
)


Expand Down Expand Up @@ -306,6 +308,15 @@ def create_parser() -> argparse.ArgumentParser:
default=2,
type=int,
)
parser_replay.add_argument(
"--latency",
help="time in seconds between replay frames",
default=0,
type=float,
)
parser_replay.add_argument(
"--loop", action="store_true", help="continuously loop over replay data"
)

return parser

Expand Down
23 changes: 17 additions & 6 deletions dranspose/replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import time
import traceback
from typing import ContextManager, Iterator, Any, Optional, IO, Tuple
import itertools

import cbor2
import uvicorn
Expand All @@ -22,6 +23,7 @@
from dranspose.event import (
InternalWorkerMessage,
EventData,
EventNumber,
ResultData,
message_tag_hook,
)
Expand Down Expand Up @@ -257,7 +259,8 @@ def replay(
broadcast_first: bool = True,
done_event: threading.Event | None = None,
start_event: threading.Event | None = None,
latency: float | None = None,
latency: float = 0,
loop: bool = False,
) -> None:
if source is not None:
sourcecls = utils.import_class(source)
Expand All @@ -267,6 +270,8 @@ def replay(
gens = [get_internals(f) for f in zmq_files]
else:
gens = []
if loop:
gens = list(map(itertools.cycle, gens))

workercls = utils.import_class(wclass)
logger.info("custom worker class %s", workercls)
Expand Down Expand Up @@ -299,15 +304,15 @@ def replay(
reducer_app, port=port or 5000, host="localhost", log_level="info"
)
server = Server(config)
# server.run()

first = True
last_output_event_number = -1

with server.run_in_thread(port):
cache = [None for _ in gens]
last_tick = 0.0
if start_event is not None:
start_event.wait()

while True:
try:
internals = [
Expand All @@ -316,13 +321,18 @@ def replay(
if len(internals) == 0:
break
lowestevn = min([ev.event_number for ev in internals])

lowinternals = []
cache = internals
for idx, ie in enumerate(internals):
if ie.event_number == lowestevn:
lowinternals.append(ie)
cache[idx] = None
event = EventData.from_internals(lowinternals)
if loop:
if last_output_event_number > event.event_number:
event.event_number = EventNumber(last_output_event_number + 1)
last_output_event_number = event.event_number
Comment thread
felix-engelmann marked this conversation as resolved.

dst_worker_ids = [random.randint(0, len(workers) - 1)]
if first and broadcast_first:
Expand All @@ -347,8 +357,10 @@ def replay(
reducer_app.state.parameters,
tick,
)
if latency is not None:
time.sleep(latency)
time.sleep(latency)
if stop_event is not None:
if stop_event.is_set():
raise StopIteration()
except StopIteration:
logger.debug("end of replay, calling finish")
_finish(workers, reducer, reducer_app.state.parameters)
Expand All @@ -366,4 +378,3 @@ def replay(
stop_event.wait()
except KeyboardInterrupt:
pass
Comment thread
felix-engelmann marked this conversation as resolved.
logger.info("replay finished")
Comment thread
felix-engelmann marked this conversation as resolved.
70 changes: 70 additions & 0 deletions tests/test_replay.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,76 @@ async def test_replay(
)


@pytest.mark.skipif("config.getoption('rust')", reason="rust does not support dumping")
@pytest.mark.asyncio
async def test_replay_looping(
controller: None,
reducer: Callable[[Optional[str]], Awaitable[None]],
create_worker: Callable[[WorkerName], Awaitable[Worker]],
create_ingester: Callable[[Ingester], Awaitable[Ingester]],
stream_eiger: Callable[[zmq.Context[Any], int, int], Coroutine[Any, Any, None]],
stream_orca: Callable[[zmq.Context[Any], int, int], Coroutine[Any, Any, None]],
stream_small: Callable[[zmq.Context[Any], int, int], Coroutine[Any, Any, None]],
tmp_path: Any,
) -> None:
p_eiger, p_prefix, uuid = await dump_data(
reducer,
create_worker,
create_ingester,
stream_eiger,
stream_orca,
stream_small,
tmp_path,
)
# read dump

par_file = generate_params(tmp_path)
stop_event = threading.Event()

thread = threading.Thread(
target=replay,
args=(
"tests.aux_payloads:TestWorker",
"tests.aux_payloads:TestReducer",
[p_eiger, f"{p_prefix}orca-ingester-{uuid}.cbors"],
None,
par_file,
),
kwargs={"port": 5010, "stop_event": stop_event, "loop": True, "latency": 0.1},
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The test sets a latency parameter, but does not check if the replay actually happens "slow" enough

)
thread.start()

async def check_poll_results() -> None:
print("Starting to poll hdf5-rest output with h5pyd")
Comment thread
felix-engelmann marked this conversation as resolved.
Outdated
await asyncio.sleep(2)
# NOTE: See https://github.com/felix-engelmann/dranspose/pull/58#discussion_r2861877247
# the results have length 11, though the stream seems to be length 10
single_run_len_results = 11
max_wait_time = 5
wait_step_duration = 0.5
wait_total_steps = int((max_wait_time // wait_step_duration) + 1)
for _ in range(wait_total_steps):
f = h5pyd.File("http://localhost:5010/", "r", timeout=5)
len_results = len(f.get("results", []))
print("Length of results:", len_results)
Comment thread
felix-engelmann marked this conversation as resolved.
Outdated

if len_results > single_run_len_results:
Comment thread
felix-engelmann marked this conversation as resolved.
Outdated
return
await asyncio.sleep(wait_step_duration)
assert False, "Results never had more than 10 entries"

try:
await check_poll_results()
except Exception as err:
raise Exception from err
finally:
logging.info("shut down server")
stop_event.set()
Comment thread
felix-engelmann marked this conversation as resolved.
thread.join()
await asyncio.sleep(0.1)
logging.info("thread joined")


@pytest.mark.skipif("config.getoption('rust')", reason="rust does not support dumping")
@pytest.mark.asyncio
async def test_replay_gzip(
Expand Down