From 6aa47a4850d755ad55035d7c653e2c4864f1abfb Mon Sep 17 00:00:00 2001 From: Tarun Annapareddy Date: Thu, 25 Sep 2025 09:04:46 -0700 Subject: [PATCH 1/4] Add AfterSynchronizedProcessingTime as continuation trigger --- sdks/python/apache_beam/transforms/core.py | 12 ++ .../apache_beam/transforms/ptransform_test.py | 19 +++ sdks/python/apache_beam/transforms/trigger.py | 120 +++++++++++++++++- .../apache_beam/transforms/trigger_test.py | 50 ++++++++ 4 files changed, 196 insertions(+), 5 deletions(-) diff --git a/sdks/python/apache_beam/transforms/core.py b/sdks/python/apache_beam/transforms/core.py index 2304faf478f9..cbd78d8222e8 100644 --- a/sdks/python/apache_beam/transforms/core.py +++ b/sdks/python/apache_beam/transforms/core.py @@ -3341,6 +3341,18 @@ def infer_output_type(self, input_type): return typehints.KV[ key_type, typehints.WindowedValue[value_type]] # type: ignore[misc] + def get_windowing(self, inputs): + # Switch to the continuation trigger associated with the current trigger. 
+ windowing = inputs[0].windowing + triggerfn = windowing.triggerfn.get_continuation_trigger() + return Windowing( + windowfn=windowing.windowfn, + triggerfn=triggerfn, + accumulation_mode=windowing.accumulation_mode, + timestamp_combiner=windowing.timestamp_combiner, + allowed_lateness=windowing.allowed_lateness, + environment_id=windowing.environment_id) + def expand(self, pcoll): from apache_beam.transforms.trigger import DataLossReason from apache_beam.transforms.trigger import DefaultTrigger diff --git a/sdks/python/apache_beam/transforms/ptransform_test.py b/sdks/python/apache_beam/transforms/ptransform_test.py index 3df33bcd8be6..ea736dceddb1 100644 --- a/sdks/python/apache_beam/transforms/ptransform_test.py +++ b/sdks/python/apache_beam/transforms/ptransform_test.py @@ -47,6 +47,7 @@ from apache_beam.metrics import Metrics from apache_beam.metrics.metric import MetricsFilter from apache_beam.options.pipeline_options import PipelineOptions +from apache_beam.options.pipeline_options import StandardOptions from apache_beam.options.pipeline_options import StreamingOptions from apache_beam.options.pipeline_options import TypeOptions from apache_beam.portability import common_urns @@ -61,6 +62,9 @@ from apache_beam.transforms.display import DisplayData from apache_beam.transforms.display import DisplayDataItem from apache_beam.transforms.ptransform import PTransform +from apache_beam.transforms.trigger import AccumulationMode +from apache_beam.transforms.trigger import AfterProcessingTime +from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime from apache_beam.transforms.window import TimestampedValue from apache_beam.typehints import with_input_types from apache_beam.typehints import with_output_types @@ -510,6 +514,21 @@ def test_group_by_key_unbounded_global_default_trigger(self): with TestPipeline(options=test_options) as pipeline: pipeline | TestStream() | beam.GroupByKey() + def test_group_by_key_trigger(self): + options = 
PipelineOptions(['--allow_unsafe_triggers']) + options.view_as(StandardOptions).streaming = True + with TestPipeline(runner='BundleBasedDirectRunner', + options=options) as pipeline: + pcoll = pipeline | 'Start' >> beam.Create([(0, 0)]) + triggered = pcoll | 'Trigger' >> beam.WindowInto( + window.GlobalWindows(), + trigger=AfterProcessingTime(1), + accumulation_mode=AccumulationMode.DISCARDING) + output = triggered | 'Gbk' >> beam.GroupByKey() + self.assertTrue( + isinstance( + output.windowing.triggerfn, _AfterSynchronizedProcessingTime)) + def test_group_by_key_unsafe_trigger(self): test_options = PipelineOptions() test_options.view_as(TypeOptions).allow_unsafe_triggers = False diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 7d573a58e3f1..ce9a0de13aa1 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -304,7 +304,7 @@ def from_runner_api(proto, context): 'after_each': AfterEach, 'after_end_of_window': AfterWatermark, 'after_processing_time': AfterProcessingTime, - # after_processing_time, after_synchronized_processing_time + 'after_synchronized_processing_time': _AfterSynchronizedProcessingTime, 'always': Always, 'default': DefaultTrigger, 'element_count': AfterCount, @@ -317,6 +317,16 @@ def from_runner_api(proto, context): def to_runner_api(self, unused_context): pass + @abstractmethod + def get_continuation_trigger(self): + """Returns: + Trigger to use after a GroupBy to preserve the intention of this + trigger. Specifically, triggers that are time based and intended + to provide speculative results should continue providing speculative + results. Triggers that fire once (or multiple times) should + continue firing once (or multiple times). 
+ """ + pass class DefaultTrigger(TriggerFn): """Semantically Repeatedly(AfterWatermark()), but more optimized.""" @@ -365,7 +375,9 @@ def to_runner_api(self, unused_context): def has_ontime_pane(self): return True - + + def get_continuation_trigger(self): + return self class AfterProcessingTime(TriggerFn): """Fire exactly once after a specified delay from processing time.""" @@ -420,6 +432,12 @@ def to_runner_api(self, context): def has_ontime_pane(self): return False + + def get_continuation_trigger(self): + # The continuation of an AfterProcessingTime trigger is an + # _AfterSynchronizedProcessingTime trigger. + return _AfterSynchronizedProcessingTime() + class Always(TriggerFn): @@ -465,6 +483,9 @@ def from_runner_api(proto, context): def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( always=beam_runner_api_pb2.Trigger.Always()) + + def get_continuation_trigger(self): + return self class _Never(TriggerFn): @@ -517,6 +538,9 @@ def from_runner_api(proto, context): def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( never=beam_runner_api_pb2.Trigger.Never()) + + def get_continuation_trigger(self): + return self class AfterWatermark(TriggerFn): @@ -531,9 +555,19 @@ class AfterWatermark(TriggerFn): LATE_TAG = _CombiningValueStateTag('is_late', any) def __init__(self, early=None, late=None): - # TODO(zhoufek): Maybe don't wrap early/late if they are already Repeatedly - self.early = Repeatedly(early) if early else None - self.late = Repeatedly(late) if late else None + self.early = self._wrap_if_not_repeatedly(early) + self.late = self._wrap_if_not_repeatedly(late) + + @staticmethod + def _wrap_if_not_repeatedly(trigger): + if trigger and not isinstance(trigger, Repeatedly): + return Repeatedly(trigger) + return trigger + + def get_continuation_trigger(self): + return AfterWatermark( + self.early.get_continuation_trigger() if self.early else None, + self.late.get_continuation_trigger() if self.late else None) def __repr__(self): 
qualifiers = [] @@ -691,6 +725,9 @@ def to_runner_api(self, unused_context): def has_ontime_pane(self): return False + + def get_continuation_trigger(self): + return AfterCount(1) class Repeatedly(TriggerFn): @@ -740,6 +777,9 @@ def to_runner_api(self, context): def has_ontime_pane(self): return self.underlying.has_ontime_pane() + + def get_continuation_trigger(self): + return Repeatedly(self.underlying.get_continuation_trigger()) class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta): @@ -830,6 +870,12 @@ def to_runner_api(self, context): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) + + def get_continuation_trigger(self): + return self.__class__( + *( + subtrigger.get_continuation_trigger() + for subtrigger in self.triggers)) class AfterAny(_ParallelTriggerFn): @@ -932,6 +978,13 @@ def to_runner_api(self, context): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) + + def get_continuation_trigger(self): + return Repeatedly( + AfterAny( + *( + subtrigger.get_continuation_trigger() + for subtrigger in self.triggers))) class OrFinally(AfterAny): @@ -1643,3 +1696,60 @@ def __repr__(self): state_str = '\n'.join( '%s: %s' % (key, dict(state)) for key, state in self.state.items()) return 'timers: %s\nstate: %s' % (dict(self.timers), state_str) + + +class _AfterSynchronizedProcessingTime(TriggerFn): + """A "runner's-discretion" trigger downstream of a GroupByKey + with AfterProcessingTime trigger. + + In runners that directly execute this + Python code, the trigger currently always fires, + but this behavior is neither guaranteed nor + required by runners, regardless of whether they + execute triggers via Python. + + _AfterSynchronizedProcessingTime is experimental + and internal-only. No backwards compatibility + guarantees. 
+ """ + def __init__(self): + pass + + def __repr__(self): + return '_AfterSynchronizedProcessingTime()' + + def __eq__(self, other): + return type(self) == type(other) + + def __hash__(self): + return hash(type(self)) + + def on_element(self, _element, _window, _context): + pass + + def on_merge(self, _to_be_merged, _merge_result, _context): + pass + + def should_fire(self, _time_domain, _timestamp, _window, _context): + return True + + def on_fire(self, _timestamp, _window, _context): + return False + + def reset(self, _window, _context): + pass + + @staticmethod + def from_runner_api(_proto, _context): + return _AfterSynchronizedProcessingTime() + + def to_runner_api(self, _context): + return beam_runner_api_pb2.Trigger( + after_synchronized_processing_time=beam_runner_api_pb2.Trigger. + AfterSynchronizedProcessingTime()) + + def has_ontime_pane(self): + return False + + def get_continuation_trigger(self): + return self \ No newline at end of file diff --git a/sdks/python/apache_beam/transforms/trigger_test.py b/sdks/python/apache_beam/transforms/trigger_test.py index b9a8cdc594b5..9f9b7fe51a9f 100644 --- a/sdks/python/apache_beam/transforms/trigger_test.py +++ b/sdks/python/apache_beam/transforms/trigger_test.py @@ -554,6 +554,56 @@ def test_trigger_encoding(self): TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context)) +class ContinuationTriggerTest(unittest.TestCase): + def test_after_all(self): + self.assertEqual( + AfterAll(AfterCount(2), AfterCount(5)).get_continuation_trigger(), + AfterAll(AfterCount(1), AfterCount(1))) + + def test_after_any(self): + self.assertEqual( + AfterAny(AfterCount(2), AfterCount(5)).get_continuation_trigger(), + AfterAny(AfterCount(1), AfterCount(1))) + + def test_after_count(self): + self.assertEqual(AfterCount(1).get_continuation_trigger(), AfterCount(1)) + self.assertEqual(AfterCount(100).get_continuation_trigger(), AfterCount(1)) + + def test_after_each(self): + self.assertEqual( + AfterEach(AfterCount(2), 
AfterCount(5)).get_continuation_trigger(), + Repeatedly(AfterAny(AfterCount(1), AfterCount(1)))) + + def test_after_processing_time(self): + from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime + self.assertEqual( + AfterProcessingTime(10).get_continuation_trigger(), + _AfterSynchronizedProcessingTime()) + + def test_after_watermark(self): + self.assertEqual( + AfterWatermark().get_continuation_trigger(), AfterWatermark()) + self.assertEqual( + AfterWatermark(early=AfterCount(10), + late=AfterCount(20)).get_continuation_trigger(), + AfterWatermark(early=AfterCount(1), late=AfterCount(1))) + + def test_always(self): + self.assertEqual(Always().get_continuation_trigger(), Always()) + + def test_default(self): + self.assertEqual( + DefaultTrigger().get_continuation_trigger(), DefaultTrigger()) + + def test_never(self): + self.assertEqual(_Never().get_continuation_trigger(), _Never()) + + def test_repeatedly(self): + self.assertEqual( + Repeatedly(AfterCount(10)).get_continuation_trigger(), + Repeatedly(AfterCount(1))) + + class TriggerPipelineTest(unittest.TestCase): def test_after_processing_time(self): test_options = PipelineOptions( From d5e42f263248ae4aaa9f10e91eac2afcc2e96534 Mon Sep 17 00:00:00 2001 From: Tarun Annapareddy Date: Thu, 25 Sep 2025 09:10:36 -0700 Subject: [PATCH 2/4] fix trailing space --- sdks/python/apache_beam/transforms/trigger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index ce9a0de13aa1..039a1155eda3 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -1752,4 +1752,5 @@ def has_ontime_pane(self): return False def get_continuation_trigger(self): - return self \ No newline at end of file + return self + \ No newline at end of file From d9cf1597f7bc53af7b5716e7b5c7b8c65578c6a6 Mon Sep 17 00:00:00 2001 From: Tarun Annapareddy Date: Thu, 25 Sep 
2025 09:12:30 -0700 Subject: [PATCH 3/4] fix trailing space --- sdks/python/apache_beam/transforms/trigger.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 039a1155eda3..466d47f3f62f 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -1753,4 +1753,3 @@ def has_ontime_pane(self): def get_continuation_trigger(self): return self - \ No newline at end of file From 14a601921aa1f5068d890e8e7bf945af1eb41ad0 Mon Sep 17 00:00:00 2001 From: Tarun Annapareddy Date: Thu, 25 Sep 2025 10:12:53 -0700 Subject: [PATCH 4/4] fix formatting --- sdks/python/apache_beam/transforms/trigger.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/sdks/python/apache_beam/transforms/trigger.py b/sdks/python/apache_beam/transforms/trigger.py index 466d47f3f62f..cc9922dd158f 100644 --- a/sdks/python/apache_beam/transforms/trigger.py +++ b/sdks/python/apache_beam/transforms/trigger.py @@ -328,6 +328,7 @@ def get_continuation_trigger(self): """ pass + class DefaultTrigger(TriggerFn): """Semantically Repeatedly(AfterWatermark()), but more optimized.""" def __init__(self): @@ -375,10 +376,11 @@ def to_runner_api(self, unused_context): def has_ontime_pane(self): return True - + def get_continuation_trigger(self): return self + class AfterProcessingTime(TriggerFn): """Fire exactly once after a specified delay from processing time.""" @@ -432,14 +434,13 @@ def to_runner_api(self, context): def has_ontime_pane(self): return False - + def get_continuation_trigger(self): # The continuation of an AfterProcessingTime trigger is an # _AfterSynchronizedProcessingTime trigger. 
return _AfterSynchronizedProcessingTime() - class Always(TriggerFn): """Repeatedly invoke the given trigger, never finishing.""" def __init__(self): @@ -483,7 +484,7 @@ def from_runner_api(proto, context): def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( always=beam_runner_api_pb2.Trigger.Always()) - + def get_continuation_trigger(self): return self @@ -538,7 +539,7 @@ def from_runner_api(proto, context): def to_runner_api(self, context): return beam_runner_api_pb2.Trigger( never=beam_runner_api_pb2.Trigger.Never()) - + def get_continuation_trigger(self): return self @@ -725,7 +726,7 @@ def to_runner_api(self, unused_context): def has_ontime_pane(self): return False - + def get_continuation_trigger(self): return AfterCount(1) @@ -777,7 +778,7 @@ def to_runner_api(self, context): def has_ontime_pane(self): return self.underlying.has_ontime_pane() - + def get_continuation_trigger(self): return Repeatedly(self.underlying.get_continuation_trigger()) @@ -870,7 +871,7 @@ def to_runner_api(self, context): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) - + def get_continuation_trigger(self): return self.__class__( *( @@ -978,7 +979,7 @@ def to_runner_api(self, context): def has_ontime_pane(self): return any(t.has_ontime_pane() for t in self.triggers) - + def get_continuation_trigger(self): return Repeatedly( AfterAny(