4242from apache_beam .runners .portability .fn_api_runner import translations
4343from apache_beam .runners .portability .fn_api_runner .execution import ListBuffer
4444from apache_beam .transforms import environments
45- from apache_beam .utils import proto_utils
45+ from apache_beam .utils import proto_utils , timestamp
4646
4747import ray
4848from ray_beam_runner .portability .context_management import RayBundleContextManager
@@ -227,7 +227,9 @@ def _run_stage(
227227 bundle_context_manager (execution.BundleContextManager): A description of
228228 the stage to execute, and its context.
229229 """
230+
230231 bundle_context_manager .setup ()
232+
231233 runner_execution_context .worker_manager .register_process_bundle_descriptor (
232234 bundle_context_manager .process_bundle_descriptor
233235 )
@@ -246,6 +248,8 @@ def _run_stage(
246248 for k in bundle_context_manager .transform_to_buffer_coder
247249 }
248250
251+ watermark_manager = runner_execution_context .watermark_manager
252+
249253 final_result = None # type: Optional[beam_fn_api_pb2.InstructionResponse]
250254
251255 while True :
@@ -262,19 +266,26 @@ def _run_stage(
262266
263267 final_result = merge_stage_results (final_result , last_result )
264268 if not delayed_applications and not fired_timers :
269+ # Processing has completed; marking all outputs as completed
270+ for output_pc in bundle_outputs :
271+ _ , update_output_pc = translations .split_buffer_id (output_pc )
272+ watermark_manager .set_pcoll_produced_watermark .remote (
273+ update_output_pc , timestamp .MAX_TIMESTAMP
274+ )
265275 break
266276 else :
267- # TODO: Enable following assertion after watermarking is implemented
268- # assert (ray.get(
269- # runner_execution_context.watermark_manager
270- # .get_stage_node.remote(
271- # bundle_context_manager.stage.name)).output_watermark()
272- # < timestamp.MAX_TIMESTAMP), (
273- # 'wrong timestamp for %s. '
274- # % ray.get(
275- # runner_execution_context.watermark_manager
276- # .get_stage_node.remote(
277- # bundle_context_manager.stage.name)))
277+ assert (
278+ ray .get (
279+ watermark_manager .get_stage_node .remote (
280+ bundle_context_manager .stage .name
281+ )
282+ ).output_watermark ()
283+ < timestamp .MAX_TIMESTAMP
284+ ), "wrong timestamp for %s. " % ray .get (
285+ watermark_manager .get_stage_node .remote (
286+ bundle_context_manager .stage .name
287+ )
288+ )
278289 input_data = delayed_applications
279290 input_timers = fired_timers
280291
@@ -288,6 +299,20 @@ def _run_stage(
288299 # TODO(pabloem): Make sure that side inputs are being stored somewhere.
289300 # runner_execution_context.commit_side_inputs_to_state(data_side_input)
290301
302+ # assert that the output watermark was correctly set for this stage
303+ stage_node = ray .get (
304+ runner_execution_context .watermark_manager .get_stage_node .remote (
305+ bundle_context_manager .stage .name
306+ )
307+ )
308+ assert (
309+ stage_node .output_watermark () == timestamp .MAX_TIMESTAMP
310+ ), "wrong output watermark for %s. Expected %s, but got %s." % (
311+ stage_node ,
312+ timestamp .MAX_TIMESTAMP ,
313+ stage_node .output_watermark (),
314+ )
315+
291316 return final_result
292317
293318 def _run_bundle (
@@ -346,6 +371,21 @@ def _run_bundle(
346371 # coder_impl=bundle_context_manager.get_input_coder_impl(
347372 # other_input))
348373
374+ # TODO: replace placeholder sets when timers are implemented
375+ watermark_updates = fn_runner .FnApiRunner ._build_watermark_updates (
376+ runner_execution_context ,
377+ transform_to_buffer_coder .keys (),
378+ set (), # expected_timers
379+ set (), # pcolls_with_da
380+ delayed_applications .keys (),
381+ set (), # watermarks_by_transform_and_timer_family
382+ )
383+
384+ for pc_name , watermark in watermark_updates .items ():
385+ runner_execution_context .watermark_manager .set_pcoll_watermark .remote (
386+ pc_name , watermark
387+ )
388+
349389 newly_set_timers = {}
350390 return result , newly_set_timers , delayed_applications , output
351391
0 commit comments