Skip to content

Commit 41b141f

Browse files
authored
Merge pull request #31 from implement basic timer plumbing
implement basic timer plumbing
2 parents 45cf717 + b767100 commit 41b141f

File tree

3 files changed

+89
-10
lines changed

3 files changed

+89
-10
lines changed

ray_beam_runner/portability/execution.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -116,8 +116,8 @@ def ray_execute_bundle(
116116
):
117117
if isinstance(output, beam_fn_api_pb2.Elements.Timers) and not dry_run:
118118
output_buffers[
119-
expected_outputs[(output.transform_id, output.timer_family_id)]
120-
].append(output.data)
119+
stage_timers[(output.transform_id, output.timer_family_id)]
120+
].append(output.timers)
121121
if isinstance(output, beam_fn_api_pb2.Elements.Data) and not dry_run:
122122
output_buffers[expected_outputs[output.transform_id]].append(output.data)
123123

@@ -342,6 +342,9 @@ def put(self, pcoll, data_refs: List[ray.ObjectRef]):
342342
def get(self, pcoll) -> List[ray.ObjectRef]:
343343
return self.buffers[pcoll]
344344

345+
def clear(self, pcoll):
346+
self.buffers[pcoll].clear()
347+
345348

346349
@ray.remote
347350
class RayWatermarkManager(watermark_manager.WatermarkManager):
@@ -450,6 +453,7 @@ def _build_timer_coders_id_map(self):
450453
timer_coder_ids[
451454
(transform_id, id)
452455
] = timer_family_spec.timer_family_coder_id
456+
return timer_coder_ids
453457

454458
def __reduce__(self):
455459
# We need to implement custom serialization for this particular class

ray_beam_runner/portability/ray_fn_runner.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
from typing import Tuple
3030
from typing import Union
3131

32+
from apache_beam.coders.coder_impl import create_OutputStream
3233
from apache_beam.options import pipeline_options
3334
from apache_beam.options.value_provider import RuntimeValueProvider
3435
from apache_beam.pipeline import Pipeline
@@ -42,7 +43,7 @@
4243
from apache_beam.runners.portability.fn_api_runner import translations
4344
from apache_beam.runners.portability.fn_api_runner.execution import ListBuffer
4445
from apache_beam.transforms import environments
45-
from apache_beam.utils import proto_utils
46+
from apache_beam.utils import proto_utils, timestamp
4647

4748
import ray
4849
from ray_beam_runner.portability.context_management import RayBundleContextManager
@@ -334,6 +335,11 @@ def _run_bundle(
334335
)
335336
result = beam_fn_api_pb2.InstructionResponse.FromString(result_str)
336337

338+
(
339+
watermarks_by_transform_and_timer_family,
340+
newly_set_timers,
341+
) = self._collect_written_timers(bundle_context_manager)
342+
337343
# TODO(pabloem): Add support for splitting of results.
338344

339345
# After collecting deferred inputs, we 'pad' the structure with empty
@@ -346,9 +352,84 @@ def _run_bundle(
346352
# coder_impl=bundle_context_manager.get_input_coder_impl(
347353
# other_input))
348354

349-
newly_set_timers = {}
350355
return result, newly_set_timers, delayed_applications, output
351356

357+
@staticmethod
358+
def _collect_written_timers(
359+
bundle_context_manager: RayBundleContextManager,
360+
) -> Tuple[
361+
Dict[translations.TimerFamilyId, timestamp.Timestamp],
362+
Mapping[translations.TimerFamilyId, execution.PartitionableBuffer],
363+
]:
364+
"""Review output buffers, and collect written timers.
365+
This function reviews a stage that has just been run. The stage will have
366+
written timers to its output buffers. The function then takes the timers,
367+
and adds them to the `newly_set_timers` dictionary, and the
368+
timer_watermark_data dictionary.
369+
The function then returns the following two elements in a tuple:
370+
- timer_watermark_data: A dictionary mapping timer family to upcoming
371+
timestamp to fire.
372+
- newly_set_timers: A dictionary mapping timer family to timer buffers
373+
to be passed to the SDK upon firing.
374+
"""
375+
timer_watermark_data = {}
376+
newly_set_timers = {}
377+
378+
execution_context = bundle_context_manager.execution_context
379+
buffer_manager = execution_context.pcollection_buffers
380+
381+
for (
382+
transform_id,
383+
timer_family_id,
384+
), buffer_id in bundle_context_manager.stage_timers.items():
385+
timer_buffer = ray.get(buffer_manager.get.remote(buffer_id))
386+
387+
coder_id = bundle_context_manager._timer_coder_ids[
388+
(transform_id, timer_family_id)
389+
]
390+
391+
coder = execution_context.pipeline_context.coders[coder_id]
392+
timer_coder_impl = coder.get_impl()
393+
394+
timers_by_key_tag_and_window = {}
395+
if len(timer_buffer) >= 1:
396+
written_timers = ray.get(timer_buffer[0])
397+
# clear the timer buffer
398+
buffer_manager.clear.remote(buffer_id)
399+
400+
# deduplicate updates to the same timer
401+
for elements_timers in written_timers:
402+
for decoded_timer in timer_coder_impl.decode_all(elements_timers):
403+
key_tag_win = (
404+
decoded_timer.user_key,
405+
decoded_timer.dynamic_timer_tag,
406+
decoded_timer.windows[0],
407+
)
408+
if not decoded_timer.clear_bit:
409+
timers_by_key_tag_and_window[key_tag_win] = decoded_timer
410+
elif (
411+
decoded_timer.clear_bit
412+
and key_tag_win in timers_by_key_tag_and_window
413+
):
414+
del timers_by_key_tag_and_window[key_tag_win]
415+
if not timers_by_key_tag_and_window:
416+
continue
417+
418+
out = create_OutputStream()
419+
for decoded_timer in timers_by_key_tag_and_window.values():
420+
timer_coder_impl.encode_to_stream(decoded_timer, out, True)
421+
timer_watermark_data[(transform_id, timer_family_id)] = min(
422+
timer_watermark_data.get(
423+
(transform_id, timer_family_id), timestamp.MAX_TIMESTAMP
424+
),
425+
decoded_timer.hold_timestamp,
426+
)
427+
428+
buf = ListBuffer(coder_impl=timer_coder_impl)
429+
buf.append(out.get())
430+
newly_set_timers[(transform_id, timer_family_id)] = buf
431+
return timer_watermark_data, newly_set_timers
432+
352433

353434
class RayRunnerResult(runner.PipelineResult):
354435
def __init__(self, state):

ray_beam_runner/portability/ray_runner_test.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -433,7 +433,6 @@ def process_timer(self):
433433
# expected = [('fired', ts) for ts in (20, 200)]
434434
# assert_that(actual, equal_to(expected))
435435

436-
@unittest.skip("Timers not yet supported")
437436
def test_pardo_timers(self):
438437
timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK)
439438
state_spec = userstate.CombiningValueStateSpec("num_called", sum)
@@ -467,7 +466,6 @@ def process_timer(
467466
expected = [("fired", ts) for ts in (20, 200, 40, 400)]
468467
assert_that(actual, equal_to(expected))
469468

470-
@unittest.skip("Timers not yet supported")
471469
def test_pardo_timers_clear(self):
472470
timer_spec = userstate.TimerSpec("timer", userstate.TimeDomain.WATERMARK)
473471
clear_timer_spec = userstate.TimerSpec(
@@ -506,15 +504,12 @@ def process_clear_timer(self):
506504
expected = [("fired", ts) for ts in (20, 200)]
507505
assert_that(actual, equal_to(expected))
508506

509-
@unittest.skip("Timers not yet supported")
510507
def test_pardo_state_timers(self):
511508
self._run_pardo_state_timers(windowed=False)
512509

513-
@unittest.skip("Timers not yet supported")
514510
def test_pardo_state_timers_non_standard_coder(self):
515511
self._run_pardo_state_timers(windowed=False, key_type=Any)
516512

517-
@unittest.skip("Timers not yet supported")
518513
def test_windowed_pardo_state_timers(self):
519514
self._run_pardo_state_timers(windowed=True)
520515

@@ -587,7 +582,6 @@ def is_buffered_correctly(actual):
587582

588583
assert_that(actual, is_buffered_correctly)
589584

590-
@unittest.skip("Timers not yet supported")
591585
def test_pardo_dynamic_timer(self):
592586
class DynamicTimerDoFn(beam.DoFn):
593587
dynamic_timer_spec = userstate.TimerSpec(

0 commit comments

Comments
 (0)