2929from typing import Tuple
3030from typing import Union
3131
32+ from apache_beam .coders .coder_impl import create_OutputStream
3233from apache_beam .options import pipeline_options
3334from apache_beam .options .value_provider import RuntimeValueProvider
3435from apache_beam .pipeline import Pipeline
4243from apache_beam .runners .portability .fn_api_runner import translations
4344from apache_beam .runners .portability .fn_api_runner .execution import ListBuffer
4445from apache_beam .transforms import environments
45- from apache_beam .utils import proto_utils
46+ from apache_beam .utils import proto_utils , timestamp
4647
4748import ray
4849from ray_beam_runner .portability .context_management import RayBundleContextManager
@@ -334,6 +335,11 @@ def _run_bundle(
334335 )
335336 result = beam_fn_api_pb2 .InstructionResponse .FromString (result_str )
336337
338+ (
339+ watermarks_by_transform_and_timer_family ,
340+ newly_set_timers ,
341+ ) = self ._collect_written_timers (bundle_context_manager )
342+
337343 # TODO(pabloem): Add support for splitting of results.
338344
339345 # After collecting deferred inputs, we 'pad' the structure with empty
@@ -346,9 +352,84 @@ def _run_bundle(
346352 # coder_impl=bundle_context_manager.get_input_coder_impl(
347353 # other_input))
348354
349- newly_set_timers = {}
350355 return result , newly_set_timers , delayed_applications , output
351356
357+ @staticmethod
358+ def _collect_written_timers (
359+ bundle_context_manager : RayBundleContextManager ,
360+ ) -> Tuple [
361+ Dict [translations .TimerFamilyId , timestamp .Timestamp ],
362+ Mapping [translations .TimerFamilyId , execution .PartitionableBuffer ],
363+ ]:
364+ """Review output buffers, and collect written timers.
365+ This function reviews a stage that has just been run. The stage will have
366+ written timers to its output buffers. The function then takes the timers,
367+ and adds them to the `newly_set_timers` dictionary, and the
368+ timer_watermark_data dictionary.
369+ The function then returns the following two elements in a tuple:
370+ - timer_watermark_data: A dictionary mapping timer family to upcoming
371+ timestamp to fire.
372+ - newly_set_timers: A dictionary mapping timer family to timer buffers
373+ to be passed to the SDK upon firing.
374+ """
375+ timer_watermark_data = {}
376+ newly_set_timers = {}
377+
378+ execution_context = bundle_context_manager .execution_context
379+ buffer_manager = execution_context .pcollection_buffers
380+
381+ for (
382+ transform_id ,
383+ timer_family_id ,
384+ ), buffer_id in bundle_context_manager .stage_timers .items ():
385+ timer_buffer = ray .get (buffer_manager .get .remote (buffer_id ))
386+
387+ coder_id = bundle_context_manager ._timer_coder_ids [
388+ (transform_id , timer_family_id )
389+ ]
390+
391+ coder = execution_context .pipeline_context .coders [coder_id ]
392+ timer_coder_impl = coder .get_impl ()
393+
394+ timers_by_key_tag_and_window = {}
395+ if len (timer_buffer ) >= 1 :
396+ written_timers = ray .get (timer_buffer [0 ])
397+ # clear the timer buffer
398+ buffer_manager .clear .remote (buffer_id )
399+
400+ # deduplicate updates to the same timer
401+ for elements_timers in written_timers :
402+ for decoded_timer in timer_coder_impl .decode_all (elements_timers ):
403+ key_tag_win = (
404+ decoded_timer .user_key ,
405+ decoded_timer .dynamic_timer_tag ,
406+ decoded_timer .windows [0 ],
407+ )
408+ if not decoded_timer .clear_bit :
409+ timers_by_key_tag_and_window [key_tag_win ] = decoded_timer
410+ elif (
411+ decoded_timer .clear_bit
412+ and key_tag_win in timers_by_key_tag_and_window
413+ ):
414+ del timers_by_key_tag_and_window [key_tag_win ]
415+ if not timers_by_key_tag_and_window :
416+ continue
417+
418+ out = create_OutputStream ()
419+ for decoded_timer in timers_by_key_tag_and_window .values ():
420+ timer_coder_impl .encode_to_stream (decoded_timer , out , True )
421+ timer_watermark_data [(transform_id , timer_family_id )] = min (
422+ timer_watermark_data .get (
423+ (transform_id , timer_family_id ), timestamp .MAX_TIMESTAMP
424+ ),
425+ decoded_timer .hold_timestamp ,
426+ )
427+
428+ buf = ListBuffer (coder_impl = timer_coder_impl )
429+ buf .append (out .get ())
430+ newly_set_timers [(transform_id , timer_family_id )] = buf
431+ return timer_watermark_data , newly_set_timers
432+
352433
353434class RayRunnerResult (runner .PipelineResult ):
354435 def __init__ (self , state ):
0 commit comments