Skip to content

Commit ca34e28

Browse files
committed
Keep poke_interval/max_attempts as deprecated SageMakerTrigger aliases
Preserve backward compatibility for SageMakerTrigger after switching to the AwsBaseWaiterTrigger naming: poke_interval and max_attempts are still accepted as deprecated aliases for waiter_delay and waiter_max_attempts, emitting AirflowProviderDeprecationWarning. This keeps existing keyword callers working and lets deferred-task triggers serialized by older versions deserialize after upgrade.
1 parent 0353bce commit ca34e28

2 files changed

Lines changed: 40 additions & 0 deletions

File tree

providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,15 @@
1818
from __future__ import annotations
1919

2020
import asyncio
21+
import warnings
2122
from collections import Counter
2223
from collections.abc import AsyncIterator
2324
from enum import IntEnum
2425
from typing import TYPE_CHECKING
2526

2627
from botocore.exceptions import WaiterError
2728

29+
from airflow.exceptions import AirflowProviderDeprecationWarning
2830
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
2931
from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
3032
from airflow.providers.common.compat.sdk import AirflowException
@@ -46,6 +48,8 @@ class SageMakerTrigger(AwsBaseWaiterTrigger):
4648
:param region_name: The AWS region where the job is running. Used to build the hook.
4749
:param verify: Whether or not to verify SSL certificates. Used to build the hook.
4850
:param botocore_config: Configuration dictionary for the botocore client. Used to build the hook.
51+
:param poke_interval: (deprecated) use ``waiter_delay`` instead.
52+
:param max_attempts: (deprecated) use ``waiter_max_attempts`` instead.
4953
"""
5054

5155
def __init__(
@@ -58,7 +62,25 @@ def __init__(
5862
region_name: str | None = None,
5963
verify: bool | str | None = None,
6064
botocore_config: dict | None = None,
65+
poke_interval: int | None = None,
66+
max_attempts: int | None = None,
6167
):
68+
if poke_interval is not None:
69+
warnings.warn(
70+
"`poke_interval` is deprecated and will be removed in a future release. "
71+
"Please use `waiter_delay` instead.",
72+
AirflowProviderDeprecationWarning,
73+
stacklevel=2,
74+
)
75+
waiter_delay = poke_interval
76+
if max_attempts is not None:
77+
warnings.warn(
78+
"`max_attempts` is deprecated and will be removed in a future release. "
79+
"Please use `waiter_max_attempts` instead.",
80+
AirflowProviderDeprecationWarning,
81+
stacklevel=2,
82+
)
83+
waiter_max_attempts = max_attempts
6284
self.job_name = job_name
6385
self.job_type = job_type
6486
super().__init__(

providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import pytest
2323
from botocore.exceptions import WaiterError
2424

25+
from airflow.exceptions import AirflowProviderDeprecationWarning
2526
from airflow.providers.amazon.aws.hooks.sagemaker import SageMakerHook
2627
from airflow.providers.amazon.aws.triggers.sagemaker import SageMakerPipelineTrigger, SageMakerTrigger
2728
from airflow.triggers.base import TriggerEvent
@@ -54,6 +55,23 @@ def test_sagemaker_trigger_serialize(self):
5455
assert args["aws_conn_id"] == AWS_CONN_ID
5556
assert args["region_name"] == REGION_NAME
5657

58+
@pytest.mark.parametrize(
59+
("deprecated_kwarg", "canonical_attr", "value"),
60+
[
61+
("poke_interval", "waiter_delay", 17),
62+
("max_attempts", "attempts", 21),
63+
],
64+
)
65+
def test_sagemaker_trigger_deprecated_params(self, deprecated_kwarg, canonical_attr, value):
66+
with pytest.warns(AirflowProviderDeprecationWarning, match=deprecated_kwarg):
67+
trigger = SageMakerTrigger(
68+
job_name=JOB_NAME,
69+
job_type=JOB_TYPE,
70+
aws_conn_id=AWS_CONN_ID,
71+
**{deprecated_kwarg: value},
72+
)
73+
assert getattr(trigger, canonical_attr) == value
74+
5775
def test_sagemaker_trigger_hook_uses_generic_params(self):
5876
sagemaker_trigger = SageMakerTrigger(
5977
job_name=JOB_NAME,

0 commit comments

Comments
 (0)