Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions sdks/python/apache_beam/transforms/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3341,6 +3341,18 @@ def infer_output_type(self, input_type):
return typehints.KV[
key_type, typehints.WindowedValue[value_type]] # type: ignore[misc]

def get_windowing(self, inputs):
  """Build the output windowing from the first input's windowing.

  Every field is carried over unchanged except the trigger, which is
  replaced by the continuation trigger of the input's trigger.
  """
  upstream = inputs[0].windowing
  continuation = upstream.triggerfn.get_continuation_trigger()
  return Windowing(
      windowfn=upstream.windowfn,
      triggerfn=continuation,
      accumulation_mode=upstream.accumulation_mode,
      timestamp_combiner=upstream.timestamp_combiner,
      allowed_lateness=upstream.allowed_lateness,
      environment_id=upstream.environment_id)

def expand(self, pcoll):
from apache_beam.transforms.trigger import DataLossReason
from apache_beam.transforms.trigger import DefaultTrigger
Expand Down
19 changes: 19 additions & 0 deletions sdks/python/apache_beam/transforms/ptransform_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from apache_beam.metrics import Metrics
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.options.pipeline_options import PipelineOptions
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.options.pipeline_options import StreamingOptions
from apache_beam.options.pipeline_options import TypeOptions
from apache_beam.portability import common_urns
Expand All @@ -61,6 +62,9 @@
from apache_beam.transforms.display import DisplayData
from apache_beam.transforms.display import DisplayDataItem
from apache_beam.transforms.ptransform import PTransform
from apache_beam.transforms.trigger import AccumulationMode
from apache_beam.transforms.trigger import AfterProcessingTime
from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
from apache_beam.transforms.window import TimestampedValue
from apache_beam.typehints import with_input_types
from apache_beam.typehints import with_output_types
Expand Down Expand Up @@ -510,6 +514,21 @@ def test_group_by_key_unbounded_global_default_trigger(self):
with TestPipeline(options=test_options) as pipeline:
pipeline | TestStream() | beam.GroupByKey()

def test_group_by_key_trigger(self):
  """GroupByKey output carries the continuation of the input trigger.

  The input is windowed with AfterProcessingTime(1); the windowing of the
  GroupByKey output is expected to use _AfterSynchronizedProcessingTime.
  """
  options = PipelineOptions(['--allow_unsafe_triggers'])
  options.view_as(StandardOptions).streaming = True
  with TestPipeline(runner='BundleBasedDirectRunner',
                    options=options) as pipeline:
    pcoll = pipeline | 'Start' >> beam.Create([(0, 0)])
    triggered = pcoll | 'Trigger' >> beam.WindowInto(
        window.GlobalWindows(),
        trigger=AfterProcessingTime(1),
        accumulation_mode=AccumulationMode.DISCARDING)
    output = triggered | 'Gbk' >> beam.GroupByKey()
    # assertIsInstance reports the actual object and expected class on
    # failure; assertTrue(isinstance(...)) only says "False is not true".
    self.assertIsInstance(
        output.windowing.triggerfn, _AfterSynchronizedProcessingTime)

def test_group_by_key_unsafe_trigger(self):
test_options = PipelineOptions()
test_options.view_as(TypeOptions).allow_unsafe_triggers = False
Expand Down
119 changes: 115 additions & 4 deletions sdks/python/apache_beam/transforms/trigger.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,7 @@ def from_runner_api(proto, context):
'after_each': AfterEach,
'after_end_of_window': AfterWatermark,
'after_processing_time': AfterProcessingTime,
# after_processing_time, after_synchronized_processing_time
'after_synchronized_processing_time': _AfterSynchronizedProcessingTime,
'always': Always,
'default': DefaultTrigger,
'element_count': AfterCount,
Expand All @@ -317,6 +317,17 @@ def from_runner_api(proto, context):
def to_runner_api(self, unused_context):
pass

@abstractmethod
def get_continuation_trigger(self):
  """Return the trigger to use after a GroupBy in place of this one.

  The continuation trigger preserves the intention of this trigger.
  Specifically, triggers that are time based and intended to provide
  speculative results should continue providing speculative results.
  Triggers that fire once (or multiple times) should continue firing once
  (or multiple times).

  Returns:
    Trigger to use after a GroupBy to preserve the intention of this
    trigger.
  """
  pass


class DefaultTrigger(TriggerFn):
"""Semantically Repeatedly(AfterWatermark()), but more optimized."""
Expand Down Expand Up @@ -366,6 +377,9 @@ def to_runner_api(self, unused_context):
def has_ontime_pane(self):
  """Always return True for this trigger.

  NOTE(review): presumably this flags that the trigger emits an on-time
  pane; the consumer of the flag is not visible here -- confirm.
  """
  return True

def get_continuation_trigger(self):
  """Return this same trigger instance as its own continuation trigger."""
  return self


class AfterProcessingTime(TriggerFn):
"""Fire exactly once after a specified delay from processing time."""
Expand Down Expand Up @@ -421,6 +435,11 @@ def to_runner_api(self, context):
def has_ontime_pane(self):
  """Always return False for this trigger.

  NOTE(review): presumably this flags that the trigger emits no on-time
  pane; the consumer of the flag is not visible here -- confirm.
  """
  return False

def get_continuation_trigger(self):
  """Return a new _AfterSynchronizedProcessingTime trigger.

  The returned trigger is constructed without arguments, so nothing from
  this trigger's own configuration is carried over to it.
  """
  # The continuation of an AfterProcessingTime trigger is an
  # _AfterSynchronizedProcessingTime trigger.
  return _AfterSynchronizedProcessingTime()


class Always(TriggerFn):
"""Repeatedly invoke the given trigger, never finishing."""
Expand Down Expand Up @@ -466,6 +485,9 @@ def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
always=beam_runner_api_pb2.Trigger.Always())

def get_continuation_trigger(self):
  """Return this same trigger instance as its own continuation trigger."""
  return self


class _Never(TriggerFn):
"""A trigger that never fires.
Expand Down Expand Up @@ -518,6 +540,9 @@ def to_runner_api(self, context):
return beam_runner_api_pb2.Trigger(
never=beam_runner_api_pb2.Trigger.Never())

def get_continuation_trigger(self):
  """Return this same trigger instance as its own continuation trigger."""
  return self


class AfterWatermark(TriggerFn):
"""Fire exactly once when the watermark passes the end of the window.
Expand All @@ -531,9 +556,19 @@ class AfterWatermark(TriggerFn):
LATE_TAG = _CombiningValueStateTag('is_late', any)

def __init__(self, early=None, late=None):
  """Store the optional early and late sub-triggers.

  Args:
    early: optional trigger. It is wrapped in Repeatedly unless it is
      falsy or already a Repeatedly instance.
    late: optional trigger, handled the same way as ``early``.
  """
  # Wrap through the helper only. The previous unconditional
  # ``Repeatedly(x) if x else None`` assignments nested an argument that was
  # already a Repeatedly inside a second Repeatedly.
  self.early = self._wrap_if_not_repeatedly(early)
  self.late = self._wrap_if_not_repeatedly(late)

@staticmethod
def _wrap_if_not_repeatedly(trigger):
  """Return ``trigger`` wrapped in Repeatedly when it needs wrapping.

  Falsy values (such as None) and existing Repeatedly instances are
  returned unchanged; anything else is wrapped in a new Repeatedly.
  """
  if not trigger or isinstance(trigger, Repeatedly):
    return trigger
  return Repeatedly(trigger)

def get_continuation_trigger(self):
  """Return an AfterWatermark built from the sub-triggers' continuations.

  A missing (falsy) early or late sub-trigger stays None in the result.
  """
  early_continuation = None
  if self.early:
    early_continuation = self.early.get_continuation_trigger()
  late_continuation = None
  if self.late:
    late_continuation = self.late.get_continuation_trigger()
  return AfterWatermark(early_continuation, late_continuation)

def __repr__(self):
qualifiers = []
Expand Down Expand Up @@ -692,6 +727,9 @@ def to_runner_api(self, unused_context):
def has_ontime_pane(self):
  """Always return False for this trigger.

  NOTE(review): presumably this flags that the trigger emits no on-time
  pane; the consumer of the flag is not visible here -- confirm.
  """
  return False

def get_continuation_trigger(self):
  """Return ``AfterCount(1)``, independent of this trigger's own state."""
  return AfterCount(1)


class Repeatedly(TriggerFn):
"""Repeatedly invoke the given trigger, never finishing."""
Expand Down Expand Up @@ -741,6 +779,9 @@ def to_runner_api(self, context):
def has_ontime_pane(self):
  """Return whatever the underlying trigger's has_ontime_pane() returns."""
  return self.underlying.has_ontime_pane()

def get_continuation_trigger(self):
  """Return a Repeatedly wrapping the underlying trigger's continuation."""
  inner_continuation = self.underlying.get_continuation_trigger()
  return Repeatedly(inner_continuation)


class _ParallelTriggerFn(TriggerFn, metaclass=ABCMeta):
def __init__(self, *triggers):
Expand Down Expand Up @@ -831,6 +872,12 @@ def to_runner_api(self, context):
def has_ontime_pane(self):
  """Return True if any sub-trigger's has_ontime_pane() is truthy.

  Returns False when there are no sub-triggers.
  """
  return any(subtrigger.has_ontime_pane() for subtrigger in self.triggers)

def get_continuation_trigger(self):
  """Return a trigger of this same class over the sub-trigger continuations.

  Each sub-trigger is replaced by its own continuation trigger, in order,
  and the results are passed positionally to this instance's class.
  """
  continuations = [
      subtrigger.get_continuation_trigger() for subtrigger in self.triggers
  ]
  return self.__class__(*continuations)


class AfterAny(_ParallelTriggerFn):
"""Fires when any subtrigger fires.
Expand Down Expand Up @@ -933,6 +980,13 @@ def to_runner_api(self, context):
def has_ontime_pane(self):
  """Return True if any sub-trigger's has_ontime_pane() is truthy.

  Returns False when there are no sub-triggers.
  """
  return any(subtrigger.has_ontime_pane() for subtrigger in self.triggers)

def get_continuation_trigger(self):
return Repeatedly(
Copy link
Collaborator

@shunping shunping Oct 2, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why the continuation trigger of an AfterEach trigger is a Repeatedly trigger? @tarun-google

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@shunping With the continuation trigger concept, we are injecting a new trigger after the GroupBy window, which gets evaluated every time a new pane of data is released by the first trigger. For example, if the initial trigger is AfterProcessingTime(5), which fires only once after 5 seconds, we add a new trigger after the GroupBy that acts as a pass-through layer when the first trigger fires. A lot of our triggers are one-time triggers.

But the point with AfterEach(condition1, condition2, ...) is that it is not a one-time trigger: it fires every time a condition is met. So, if we just made the continuation trigger AfterAny(), it would fire only once. We want the continuation trigger for AfterEach to fire every time a condition is met, not just once.

Reference:

  1. Java impl for AfterEach continuation trigger and its definition
  2. Definition of AfterAny

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, it's not an exact science. There are no real correctness criteria for a continuation trigger except "don't hold up data that is already triggered". And the only reason we don't make all of them Repeatedly(Always) is for the corner case of aligned processing time, where the user might be surprised if a downstream aggregation had many more outputs because it fired right away instead of waiting for everything aligned to the same processing time. TBH even then it is sort of meh.

AfterAny(
*(
subtrigger.get_continuation_trigger()
for subtrigger in self.triggers)))


class OrFinally(AfterAny):
@staticmethod
Expand Down Expand Up @@ -1643,3 +1697,60 @@ def __repr__(self):
state_str = '\n'.join(
'%s: %s' % (key, dict(state)) for key, state in self.state.items())
return 'timers: %s\nstate: %s' % (dict(self.timers), state_str)


class _AfterSynchronizedProcessingTime(TriggerFn):
  """A "runner's-discretion" trigger downstream of a GroupByKey
  with AfterProcessingTime trigger.

  In runners that directly execute this Python code, the trigger currently
  always fires, but this behavior is neither guaranteed nor required by
  runners, regardless of whether they execute triggers via Python.

  _AfterSynchronizedProcessingTime is experimental and internal-only. No
  backwards compatibility guarantees.
  """
  def __init__(self):
    """Create the trigger; it takes no arguments and sets no attributes."""

  def __repr__(self):
    return '_AfterSynchronizedProcessingTime()'

  def __eq__(self, other):
    # Instances carry no state, so equality depends only on the exact type.
    return type(other) is type(self)

  def __hash__(self):
    # Kept consistent with __eq__: every instance of a type hashes alike.
    return hash(type(self))

  def on_element(self, _element, _window, _context):
    """Do nothing when an element arrives."""

  def on_merge(self, _to_be_merged, _merge_result, _context):
    """Do nothing when windows are merged."""

  def should_fire(self, _time_domain, _timestamp, _window, _context):
    """Return True unconditionally, ignoring all arguments."""
    return True

  def on_fire(self, _timestamp, _window, _context):
    """Return False unconditionally.

    NOTE(review): presumably False means "not finished" in the TriggerFn
    contract, which is not visible here -- confirm.
    """
    return False

  def reset(self, _window, _context):
    """Do nothing; there is no state to clear."""

  @staticmethod
  def from_runner_api(_proto, _context):
    """Return a fresh instance; the proto and context are ignored."""
    return _AfterSynchronizedProcessingTime()

  def to_runner_api(self, _context):
    """Return a Trigger proto with after_synchronized_processing_time set."""
    payload = beam_runner_api_pb2.Trigger.AfterSynchronizedProcessingTime()
    return beam_runner_api_pb2.Trigger(
        after_synchronized_processing_time=payload)

  def has_ontime_pane(self):
    """Always return False for this trigger."""
    return False

  def get_continuation_trigger(self):
    """Return this same instance as its own continuation trigger."""
    return self
50 changes: 50 additions & 0 deletions sdks/python/apache_beam/transforms/trigger_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,6 +554,56 @@ def test_trigger_encoding(self):
TriggerFn.from_runner_api(trigger_fn.to_runner_api(context), context))


class ContinuationTriggerTest(unittest.TestCase):
  """Checks the result of get_continuation_trigger() per trigger type."""
  def test_after_all(self):
    source = AfterAll(AfterCount(2), AfterCount(5))
    expected = AfterAll(AfterCount(1), AfterCount(1))
    self.assertEqual(source.get_continuation_trigger(), expected)

  def test_after_any(self):
    source = AfterAny(AfterCount(2), AfterCount(5))
    expected = AfterAny(AfterCount(1), AfterCount(1))
    self.assertEqual(source.get_continuation_trigger(), expected)

  def test_after_count(self):
    # The continuation is AfterCount(1) whatever the original count was.
    expected = AfterCount(1)
    self.assertEqual(AfterCount(1).get_continuation_trigger(), expected)
    self.assertEqual(AfterCount(100).get_continuation_trigger(), expected)

  def test_after_each(self):
    source = AfterEach(AfterCount(2), AfterCount(5))
    expected = Repeatedly(AfterAny(AfterCount(1), AfterCount(1)))
    self.assertEqual(source.get_continuation_trigger(), expected)

  def test_after_processing_time(self):
    from apache_beam.transforms.trigger import _AfterSynchronizedProcessingTime
    source = AfterProcessingTime(10)
    expected = _AfterSynchronizedProcessingTime()
    self.assertEqual(source.get_continuation_trigger(), expected)

  def test_after_watermark(self):
    # Without early/late sub-triggers the continuation equals the original.
    self.assertEqual(
        AfterWatermark().get_continuation_trigger(), AfterWatermark())
    source = AfterWatermark(early=AfterCount(10), late=AfterCount(20))
    expected = AfterWatermark(early=AfterCount(1), late=AfterCount(1))
    self.assertEqual(source.get_continuation_trigger(), expected)

  def test_always(self):
    self.assertEqual(Always().get_continuation_trigger(), Always())

  def test_default(self):
    source = DefaultTrigger()
    self.assertEqual(source.get_continuation_trigger(), DefaultTrigger())

  def test_never(self):
    self.assertEqual(_Never().get_continuation_trigger(), _Never())

  def test_repeatedly(self):
    source = Repeatedly(AfterCount(10))
    expected = Repeatedly(AfterCount(1))
    self.assertEqual(source.get_continuation_trigger(), expected)


class TriggerPipelineTest(unittest.TestCase):
def test_after_processing_time(self):
test_options = PipelineOptions(
Expand Down
Loading