@@ -228,7 +228,9 @@ def _run_stage(
228228          bundle_context_manager (execution.BundleContextManager): A description of 
229229            the stage to execute, and its context. 
230230        """ 
231+ 
231232        bundle_context_manager .setup ()
233+ 
232234        runner_execution_context .worker_manager .register_process_bundle_descriptor (
233235            bundle_context_manager .process_bundle_descriptor 
234236        )
@@ -247,6 +249,8 @@ def _run_stage(
247249            for  k  in  bundle_context_manager .transform_to_buffer_coder 
248250        }
249251
252+         watermark_manager  =  runner_execution_context .watermark_manager 
253+ 
250254        final_result  =  None   # type: Optional[beam_fn_api_pb2.InstructionResponse] 
251255
252256        while  True :
@@ -263,19 +267,28 @@ def _run_stage(
263267
264268            final_result  =  merge_stage_results (final_result , last_result )
265269            if  not  delayed_applications  and  not  fired_timers :
270+                 # Processing has completed; marking all outputs as completed 
271+                 # TODO: why is it necessary to set both the watermark and produced_watermark? 
272+                 # How do they interact? 
273+                 for  output_pc  in  bundle_outputs :
274+                     _ , update_output_pc  =  translations .split_buffer_id (output_pc )
275+                     watermark_manager .set_pcoll_produced_watermark .remote (
276+                         update_output_pc , timestamp .MAX_TIMESTAMP 
277+                     )
266278                break 
267279            else :
268-                 # TODO: Enable following assertion after watermarking is implemented 
269-                 # assert (ray.get( 
270-                 # runner_execution_context.watermark_manager 
271-                 # .get_stage_node.remote( 
272-                 #     bundle_context_manager.stage.name)).output_watermark() 
273-                 #         < timestamp.MAX_TIMESTAMP), ( 
274-                 #     'wrong timestamp for %s. ' 
275-                 #     % ray.get( 
276-                 #     runner_execution_context.watermark_manager 
277-                 #     .get_stage_node.remote( 
278-                 #     bundle_context_manager.stage.name))) 
280+                 assert  (
281+                     ray .get (
282+                         watermark_manager .get_stage_node .remote (
283+                             bundle_context_manager .stage .name 
284+                         )
285+                     ).output_watermark ()
286+                     <  timestamp .MAX_TIMESTAMP 
287+                 ), "wrong timestamp for %s. "  %  ray .get (
288+                     watermark_manager .get_stage_node .remote (
289+                         bundle_context_manager .stage .name 
290+                     )
291+                 )
279292                input_data  =  delayed_applications 
280293                input_timers  =  fired_timers 
281294
@@ -289,6 +302,20 @@ def _run_stage(
289302        # TODO(pabloem): Make sure that side inputs are being stored somewhere. 
290303        # runner_execution_context.commit_side_inputs_to_state(data_side_input) 
291304
305+         # assert that the output watermark was correctly set for this stage 
306+         stage_node  =  ray .get (
307+             runner_execution_context .watermark_manager .get_stage_node .remote (
308+                 bundle_context_manager .stage .name 
309+             )
310+         )
311+         assert  (
312+             stage_node .output_watermark () ==  timestamp .MAX_TIMESTAMP 
313+         ), "wrong output watermark for %s. Expected %s, but got %s."  %  (
314+             stage_node ,
315+             timestamp .MAX_TIMESTAMP ,
316+             stage_node .output_watermark (),
317+         )
318+ 
292319        return  final_result 
293320
294321    def  _run_bundle (
@@ -352,6 +379,21 @@ def _run_bundle(
352379        #           coder_impl=bundle_context_manager.get_input_coder_impl( 
353380        #               other_input)) 
354381
382+         # TODO: fill in expected timers and pcolls with delayed applications (pcolls_with_da)
383+         watermark_updates  =  fn_runner .FnApiRunner ._build_watermark_updates (
384+             runner_execution_context ,
385+             transform_to_buffer_coder .keys (),
386+             bundle_context_manager .stage_timers .keys (),  # expected_timers 
387+             set (),  # pcolls_with_da 
388+             delayed_applications .keys (),
389+             watermarks_by_transform_and_timer_family ,
390+         )
391+ 
392+         for  pc_name , watermark  in  watermark_updates .items ():
393+             runner_execution_context .watermark_manager .set_pcoll_watermark .remote (
394+                 pc_name , watermark 
395+             )
396+ 
355397        return  result , newly_set_timers , delayed_applications , output 
356398
357399    @staticmethod  
0 commit comments