From e84ae13733174cf92a2e7c7b233fdfaeb7e5ea77 Mon Sep 17 00:00:00 2001
From: Steve Ahn <38049807+steveahnahn@users.noreply.github.com>
Date: Fri, 26 Jun 2026 09:22:08 -0700
Subject: [PATCH] Fix SageMakerTransformOperator succeeding on a failed deferred job

A deferred SageMaker transform job that fails during the wait was reported
as a successful task: execute_complete did not check the trigger event
status, so the failed job's description was pushed as the result and
downstream tasks ran against missing or invalid transform output with no
error surfaced.

Since the trigger was migrated to AwsBaseWaiterTrigger it yields a
{"status": "error"} event on failure instead of raising, so the status
guard that every sibling SageMaker operator has became load-bearing here
too.
---
 .../amazon/aws/operators/sagemaker.py         |  3 +++
 .../aws/operators/test_sagemaker_transform.py | 27 +++++++++++++++++++
 2 files changed, 30 insertions(+)

diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
index ebafeb850657f..84918bd75f675 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py
@@ -852,6 +852,9 @@ def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str],
     def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]:
         validated_event = validate_execute_complete_event(event)
 
+        if validated_event["status"] != "success":
+            raise RuntimeError(f"Error while running transform job: {validated_event}")
+
         self.log.info("SageMaker job %s completed.", validated_event["job_name"])
         return self.serialize_result(validated_event["job_name"])
 
diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py
index 40c86115c48e8..de8ed7841e856 100644
--- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py
+++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py
@@ -341,6 +341,33 @@ def test_operator_failed_before_defer(self, _, mock_transform, mock_describe_tra
             self.sagemaker.execute(context=None)
         assert not mock_defer.called
 
+    @mock.patch.object(sagemaker, "serialize", return_value="")
+    @mock.patch.object(SageMakerHook, "describe_model", return_value={"ModelName": "model_name"})
+    @mock.patch.object(
+        SageMakerHook,
+        "describe_transform_job",
+        return_value={
+            "ModelName": "model_name",
+            "TransformJobStatus": "Failed",
+            "FailureReason": "it failed",
+        },
+    )
+    def test_execute_complete_raises_when_job_failed_during_deferred_wait(
+        self, mock_describe_transform_job, mock_describe_model, mock_serialize
+    ):
+        # When the transform job fails during the deferred wait, the trigger (an
+        # AwsBaseWaiterTrigger) yields {"status": "error", ...} instead of raising, so
+        # execute_complete must reject a non-success status rather than report the task as
+        # successful — matching every other SageMaker operator's execute_complete.
+        event = {
+            "status": "error",
+            "message": "Error while waiting for transform job: terminal failure",
+            "job_name": "job_name",
+        }
+
+        with pytest.raises(RuntimeError, match="Error while running transform job"):
+            self.sagemaker.execute_complete(context=None, event=event)
+
     @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer")
     @mock.patch.object(SageMakerHook, "describe_model")
     @mock.patch.object(