diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py index ebafeb850657f..84918bd75f675 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker.py @@ -852,6 +852,9 @@ def _check_if_model_exists(self, model_name: str, describe_func: Callable[[str], def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, dict]: validated_event = validate_execute_complete_event(event) + if validated_event["status"] != "success": + raise RuntimeError(f"Error while running transform job: {validated_event}") + self.log.info("SageMaker job %s completed.", validated_event["job_name"]) return self.serialize_result(validated_event["job_name"]) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py index 40c86115c48e8..de8ed7841e856 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_transform.py @@ -341,6 +341,33 @@ def test_operator_failed_before_defer(self, _, mock_transform, mock_describe_tra self.sagemaker.execute(context=None) assert not mock_defer.called + @mock.patch.object(sagemaker, "serialize", return_value="") + @mock.patch.object(SageMakerHook, "describe_model", return_value={"ModelName": "model_name"}) + @mock.patch.object( + SageMakerHook, + "describe_transform_job", + return_value={ + "ModelName": "model_name", + "TransformJobStatus": "Failed", + "FailureReason": "it failed", + }, + ) + def test_execute_complete_raises_when_job_failed_during_deferred_wait( + self, mock_describe_transform_job, mock_describe_model, mock_serialize + ): + # When the transform job fails during the deferred wait, the trigger (an + # AwsBaseWaiterTrigger) yields {"status": "error", ...} instead of raising, so + # execute_complete must reject a non-success status rather than report the task as + # successful — matching every other SageMaker operator's execute_complete. + event = { + "status": "error", + "message": "Error while waiting for transform job: terminal failure", + "job_name": "job_name", + } + + with pytest.raises(RuntimeError, match="Error while running transform job"): + self.sagemaker.execute_complete(context=None, event=event) + @mock.patch("airflow.providers.amazon.aws.operators.sagemaker.SageMakerTransformOperator.defer") @mock.patch.object(SageMakerHook, "describe_model") @mock.patch.object(