diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 2f7a408bb1a47..7f94448a8d402 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -319,6 +319,7 @@ connectTimeoutMS connexion containerConfiguration containerd +ContainerEntrypoint ContainerGroup containerinstance ContainerPort @@ -693,6 +694,7 @@ Gantt gantt gapic gapped +gb gbq gcc gcloud @@ -826,6 +828,7 @@ ImageAnnotatorClient imageORfile imagePullPolicy imagePullSecrets +ImageUri imageVersion Imap imap @@ -859,6 +862,7 @@ InstanceFlexibilityPolicy InstanceGroupConfig InstanceSelection instanceTemplates +InstanceType instantiation integrations interdependencies @@ -876,6 +880,7 @@ IPv4 ipv4 IPv6 ipv6 +ipynb iPython irreproducible IRSA @@ -1050,6 +1055,7 @@ masterType Matomo matomo Maxime +MaxRuntimeInSeconds mb md mediawiki @@ -1373,6 +1379,8 @@ Qubole qubole QuboleCheckHook Quboles +querybook +Querybooks queryParameters querystring queueing @@ -1887,8 +1895,10 @@ views virtualenv virtualenvs vm +VolumeKmsKeyId VolumeMount volumeMounts +VolumeSizeInGB vpc WaiterModel wape diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index e1fc18617cb49..0b887e4481e78 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -38,6 +38,7 @@ "jsonpath_ng>=1.5.3", "python3-saml>=1.16.0", "redshift_connector>=2.0.918", + "sagemaker-studio>=1.0.9", "watchtower>=3.0.0,!=3.3.0,<4" ], "devel-deps": [ diff --git a/providers/amazon/README.rst b/providers/amazon/README.rst index 0fec71dbae179..c61e4133b02a1 100644 --- a/providers/amazon/README.rst +++ b/providers/amazon/README.rst @@ -67,6 +67,7 @@ PIP package Version required ``PyAthena`` ``>=3.0.10`` ``jmespath`` ``>=0.7.0`` ``python3-saml`` ``>=1.16.0`` +``sagemaker-studio`` ``>=1.0.9`` ========================================== ====================== Cross provider package dependencies diff --git a/providers/amazon/docs/operators/sagemakerunifiedstudio.rst 
b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst new file mode 100644 index 0000000000000..33833cf395de7 --- /dev/null +++ b/providers/amazon/docs/operators/sagemakerunifiedstudio.rst @@ -0,0 +1,60 @@ + .. Licensed to the Apache Software Foundation (ASF) under one + or more contributor license agreements. See the NOTICE file + distributed with this work for additional information + regarding copyright ownership. The ASF licenses this file + to you under the Apache License, Version 2.0 (the + "License"); you may not use this file except in compliance + with the License. You may obtain a copy of the License at + + .. http://www.apache.org/licenses/LICENSE-2.0 + + .. Unless required by applicable law or agreed to in writing, + software distributed under the License is distributed on an + "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + KIND, either express or implied. See the License for the + specific language governing permissions and limitations + under the License. + +=============================== +Amazon SageMaker Unified Studio +=============================== + +`Amazon SageMaker Unified Studio `__ is a unified development experience that +brings together AWS data, analytics, artificial intelligence (AI), and machine learning (ML) services. +It provides a place to build, deploy, execute, and monitor end-to-end workflows from a single interface. +This helps drive collaboration across teams and facilitate agile development. + +Airflow provides operators to orchestrate Notebooks, Querybooks, and Visual ETL jobs within SageMaker Unified Studio Workflows. + +Prerequisite Tasks +------------------ + +To use these operators, you must do a few things: + + * Create a SageMaker Unified Studio domain and project, following the instruction in `AWS documentation `__. + * Within your project: + * Navigate to the "Compute > Workflow environments" tab, and click "Create" to create a new MWAA environment. 
+ * Create a Notebook, Querybook, or Visual ETL job and save it to your project. + +Operators +--------- + +.. _howto/operator:SageMakerNotebookOperator: + +Create an Amazon SageMaker Unified Studio Workflow +================================================== + +To create an Amazon SageMaker Unified Studio workflow to orchestrate your notebook, querybook, and visual ETL runs you can use +:class:`~airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookOperator`. + +.. exampleinclude:: /../../providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_sagemaker_unified_studio_notebook] + :end-before: [END howto_operator_sagemaker_unified_studio_notebook] + + +Reference +--------- + +* `What is Amazon SageMaker Unified Studio `__ diff --git a/providers/amazon/provider.yaml b/providers/amazon/provider.yaml index 107bd35bb091d..1aa2947fea578 100644 --- a/providers/amazon/provider.yaml +++ b/providers/amazon/provider.yaml @@ -234,6 +234,12 @@ integrations: how-to-guide: - /docs/apache-airflow-providers-amazon/operators/sagemaker.rst tags: [aws] + - integration-name: Amazon SageMaker Unified Studio + external-doc-url: https://aws.amazon.com/sagemaker/unified-studio/ + logo: /docs/integration-logos/Amazon-SageMaker_light-bg@4x.png + how-to-guide: + - /docs/apache-airflow-providers-amazon/operators/sagemakerunifiedstudio.rst + tags: [aws] - integration-name: Amazon SecretsManager external-doc-url: https://aws.amazon.com/secrets-manager/ logo: /docs/integration-logos/AWS-Secrets-Manager_light-bg@4x.png @@ -402,6 +408,9 @@ operators: - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.operators.sagemaker + - integration-name: Amazon SageMaker Unified Studio + python-modules: + - airflow.providers.amazon.aws.operators.sagemaker_unified_studio - integration-name: Amazon Simple Notification Service (SNS) python-modules: - 
airflow.providers.amazon.aws.operators.sns @@ -503,6 +512,9 @@ sensors: - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.sensors.sagemaker + - integration-name: Amazon SageMaker Unified Studio + python-modules: + - airflow.providers.amazon.aws.sensors.sagemaker_unified_studio - integration-name: Amazon Simple Queue Service (SQS) python-modules: - airflow.providers.amazon.aws.sensors.sqs @@ -627,6 +639,9 @@ hooks: - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.hooks.sagemaker + - integration-name: Amazon SageMaker Unified Studio + python-modules: + - airflow.providers.amazon.aws.hooks.sagemaker_unified_studio - integration-name: Amazon Simple Email Service (SES) python-modules: - airflow.providers.amazon.aws.hooks.ses @@ -699,6 +714,9 @@ triggers: - integration-name: Amazon SageMaker python-modules: - airflow.providers.amazon.aws.triggers.sagemaker + - integration-name: Amazon SageMaker Unified Studio + python-modules: + - airflow.providers.amazon.aws.triggers.sagemaker_unified_studio - integration-name: AWS Glue python-modules: - airflow.providers.amazon.aws.triggers.glue @@ -734,7 +752,6 @@ triggers: python-modules: - airflow.providers.amazon.aws.triggers.dms - transfers: - source-integration-name: Amazon DynamoDB target-integration-name: Amazon Simple Storage Service (S3) @@ -837,6 +854,7 @@ extra-links: - airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink - airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink - airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink + - airflow.providers.amazon.aws.links.sagemaker_unified_studio.SageMakerUnifiedStudioLink - airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink - airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink - airflow.providers.amazon.aws.links.comprehend.ComprehendPiiEntitiesDetectionLink diff --git a/providers/amazon/pyproject.toml 
b/providers/amazon/pyproject.toml index 7586425e3f3c3..14d5498613650 100644 --- a/providers/amazon/pyproject.toml +++ b/providers/amazon/pyproject.toml @@ -75,6 +75,7 @@ dependencies = [ "PyAthena>=3.0.10", "jmespath>=0.7.0", "python3-saml>=1.16.0", + "sagemaker-studio>=1.0.9", ] # The optional dependencies should be modified in place in the generated file diff --git a/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..4ad327b51c5ff --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio.py @@ -0,0 +1,188 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook hook.""" + +from __future__ import annotations + +import time + +from sagemaker_studio import ClientConfig +from sagemaker_studio.sagemaker_studio_api import SageMakerStudioAPI + +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner + + +class SageMakerNotebookHook(BaseHook): + """ + Interact with Sagemaker Unified Studio Workflows. + + This hook provides a wrapper around the Sagemaker Workflows Notebook Execution API. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import SageMakerNotebookHook + + notebook_hook = SageMakerNotebookHook( + input_config={"input_path": "path/to/notebook.ipynb", "input_params": {"param1": "value1"}}, + output_config={"output_uri": "folder/output/location/prefix", "output_formats": "NOTEBOOK"}, + execution_name="notebook_execution", + waiter_delay=10, + waiter_max_attempts=1440, + ) + + :param execution_name: The name of the notebook job to be executed, this is same as task_id. + :param input_config: Configuration for the input file. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params': {'param1': 'value1'}} + :param output_config: Configuration for the output format. It should include an output_formats parameter to specify the output format. + Example: {'output_formats': ['NOTEBOOK']} + :param compute: compute configuration to use for the notebook execution. This is a required attribute + if the execution is on a remote compute. + Example: { "instance_type": "ml.m5.large", "volume_size_in_gb": 30, "volume_kms_key_id": "", "image_uri": "string", "container_entrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. 
+ Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. + Example: { "md_analytics": "logs" } + :param waiter_delay: Interval in seconds to check the task execution status. + :param waiter_max_attempts: Number of attempts to wait before returning FAILED. + """ + + def __init__( + self, + execution_name: str, + input_config: dict | None = None, + output_config: dict | None = None, + compute: dict | None = None, + termination_condition: dict | None = None, + tags: dict | None = None, + waiter_delay: int = 10, + waiter_max_attempts: int = 1440, + *args, + **kwargs, + ): + super().__init__(*args, **kwargs) + self._sagemaker_studio = SageMakerStudioAPI(self._get_sagemaker_studio_config()) + self.execution_name = execution_name + self.input_config = input_config or {} + self.output_config = output_config or {"output_formats": ["NOTEBOOK"]} + self.compute = compute + self.termination_condition = termination_condition or {} + self.tags = tags or {} + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def _get_sagemaker_studio_config(self): + config = ClientConfig() + config.overrides["execution"] = {"local": is_local_runner()} + return config + + def _format_start_execution_input_config(self): + config = { + "notebook_config": { + "input_path": self.input_config.get("input_path"), + "input_parameters": self.input_config.get("input_params"), + }, + } + + return config + + def _format_start_execution_output_config(self): + output_formats = self.output_config.get("output_formats") + config = { + "notebook_config": { + "output_formats": output_formats, + } + } + return config + + def start_notebook_execution(self): + start_execution_params = { + "execution_name": self.execution_name, + "execution_type": "NOTEBOOK", + "input_config": self._format_start_execution_input_config(), + "output_config": self._format_start_execution_output_config(), + "termination_condition": 
self.termination_condition, + "tags": self.tags, + } + if self.compute: + start_execution_params["compute"] = self.compute + else: + start_execution_params["compute"] = {"instance_type": "ml.m4.xlarge"} + + # Log (not print) the request so it lands in the task log at INFO level. + self.log.info("Starting notebook execution with params: %s", start_execution_params) + return self._sagemaker_studio.execution_client.start_execution(**start_execution_params) + + def wait_for_execution_completion(self, execution_id, context): + wait_attempts = 0 + while wait_attempts < self.waiter_max_attempts: + wait_attempts += 1 + time.sleep(self.waiter_delay) + response = self._sagemaker_studio.execution_client.get_execution(execution_id=execution_id) + error_message = response.get("error_details", {}).get("error_message") + status = response["status"] + if "files" in response: + self._set_xcom_files(response["files"], context) + if "s3_path" in response: + self._set_xcom_s3_path(response["s3_path"], context) + + ret = self._handle_state(execution_id, status, error_message) + if ret: + return ret + + # If timeout, handle state FAILED with timeout message + return self._handle_state(execution_id, "FAILED", "Execution timed out") + + def _set_xcom_files(self, files, context): + if not context: + error_message = "context is required" + raise AirflowException(error_message) + for file in files: + context["ti"].xcom_push( + key=f"{file['display_name']}.{file['file_format']}", + value=file["file_path"], + ) + + def _set_xcom_s3_path(self, s3_path, context): + if not context: + error_message = "context is required" + raise AirflowException(error_message) + context["ti"].xcom_push( + key="s3_path", + value=s3_path, + ) + + def _handle_state(self, execution_id, status, error_message): + finished_states = ["COMPLETED"] + in_progress_states = ["IN_PROGRESS", "STOPPING"] + + if status in in_progress_states: + info_message = f"Execution {execution_id} is still in progress with state:{status}, will check for a terminal status again in {self.waiter_delay}" + self.log.info(info_message) + return None + execution_message =
f"Exiting Execution {execution_id} State: {status}" + if status in finished_states: + self.log.info(execution_message) + return {"Status": status, "ExecutionId": execution_id} + else: + log_error_message = f"Execution {execution_id} failed with error: {error_message}" + self.log.error(log_error_message) + if error_message == "": + error_message = execution_message + raise AirflowException(error_message) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..802a1fbfff88d --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/links/sagemaker_unified_studio.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.links.base_aws import BASE_AWS_CONSOLE_LINK, BaseAwsLink + + +class SageMakerUnifiedStudioLink(BaseAwsLink): + """Helper class for constructing Amazon SageMaker Unified Studio Links.""" + + name = "Amazon SageMaker Unified Studio" + key = "sagemaker_unified_studio" + format_str = BASE_AWS_CONSOLE_LINK + "/datazone/home?region={region_name}" diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..c872c56afa634 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio.py @@ -0,0 +1,155 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""This module contains the Amazon SageMaker Unified Studio Notebook operator.""" + +from __future__ import annotations + +from functools import cached_property +from typing import TYPE_CHECKING + +from airflow.configuration import conf +from airflow.exceptions import AirflowException +from airflow.models import BaseOperator +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.providers.amazon.aws.links.sagemaker_unified_studio import ( + SageMakerUnifiedStudioLink, +) +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import ( + SageMakerNotebookJobTrigger, +) + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SageMakerNotebookOperator(BaseOperator): + """ + Provides Artifact execution functionality for Sagemaker Unified Studio Workflows. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import SageMakerNotebookOperator + + notebook_operator = SageMakerNotebookOperator( + task_id="notebook_task", + input_config={"input_path": "path/to/notebook.ipynb", "input_params": ""}, + output_config={"output_format": "ipynb"}, + wait_for_completion=True, + waiter_delay=10, + waiter_max_attempts=1440, + ) + + :param task_id: A unique, meaningful id for the task. + :param input_config: Configuration for the input file. Input path should be specified as a relative path. + The provided relative path will be automatically resolved to an absolute path within + the context of the user's home directory in the IDE. Input params should be a dict. + Example: {'input_path': 'folder/input/notebook.ipynb', 'input_params':{'key': 'value'}} + :param output_config: Configuration for the output format. It should include an output_format parameter to control + the format of the notebook execution output. + Example: {"output_formats": ["NOTEBOOK"]} + :param compute: compute configuration to use for the artifact execution. 
This is a required attribute + if the execution is on a remote compute. + Example: { "InstanceType": "ml.m5.large", "VolumeSizeInGB": 30, "VolumeKmsKeyId": "", "ImageUri": "string", "ContainerEntrypoint": [ "string" ]} + :param termination_condition: conditions to match to terminate the remote execution. + Example: { "MaxRuntimeInSeconds": 3600 } + :param tags: tags to be associated with the remote execution runs. + Example: { "md_analytics": "logs" } + :param wait_for_completion: Indicates whether to wait for the notebook execution to complete. If True, wait for completion; if False, don't wait. + :param waiter_delay: Interval in seconds to check the notebook execution status. + :param waiter_max_attempts: Number of attempts to wait before returning FAILED. + :param deferrable: If True, the operator will wait asynchronously for the job to complete. + This implies waiting for completion. This mode requires aiobotocore module to be installed. + (default: False) + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerNotebookOperator` + """ + + operator_extra_links = (SageMakerUnifiedStudioLink(),) + + def __init__( + self, + task_id: str, + input_config: dict, + output_config: dict | None = None, + compute: dict | None = None, + termination_condition: dict | None = None, + tags: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 10, + waiter_max_attempts: int = 1440, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(task_id=task_id, **kwargs) + self.execution_name = task_id + self.input_config = input_config + self.output_config = output_config or {"output_formats": ["NOTEBOOK"]} + self.compute = compute or {} + self.termination_condition = termination_condition or {} + self.tags = tags or {} + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + 
self.waiter_max_attempts = waiter_max_attempts + self.deferrable = deferrable + self.input_kwargs = kwargs + + @cached_property + def notebook_execution_hook(self): + if not self.input_config: + raise AirflowException("input_config is required") + + if "input_path" not in self.input_config: + raise AirflowException("input_path is a required field in the input_config") + + return SageMakerNotebookHook( + input_config=self.input_config, + output_config=self.output_config, + execution_name=self.execution_name, + compute=self.compute, + termination_condition=self.termination_condition, + tags=self.tags, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ) + + def execute(self, context: Context): + notebook_execution = self.notebook_execution_hook.start_notebook_execution() + execution_id = notebook_execution["execution_id"] + + if self.deferrable: + self.defer( + trigger=SageMakerNotebookJobTrigger( + execution_id=execution_id, + execution_name=self.execution_name, + waiter_delay=self.waiter_delay, + waiter_max_attempts=self.waiter_max_attempts, + ), + method_name="execute_complete", + ) + elif self.wait_for_completion: + response = self.notebook_execution_hook.wait_for_execution_completion(execution_id, context) + status = response["Status"] + log_info_message = ( + f"Notebook Execution: {self.execution_name} Status: {status}. Run Id: {execution_id}" + ) + self.log.info(log_info_message) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..ab32b50dbe89b --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio.py @@ -0,0 +1,73 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook sensor.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.sensors.base import BaseSensorOperator + +if TYPE_CHECKING: + from airflow.utils.context import Context + + +class SageMakerNotebookSensor(BaseSensorOperator): + """ + Waits for a Sagemaker Workflows Notebook execution to reach any of the status below. 
+ + 'FAILED', 'STOPPED', 'COMPLETED' + + :param execution_id: The Sagemaker Workflows Notebook running execution identifier + :param execution_name: The Sagemaker Workflows Notebook unique execution name + """ + + def __init__(self, *, execution_id: str, execution_name: str, **kwargs): + super().__init__(**kwargs) + self.execution_id = execution_id + self.execution_name = execution_name + self.success_state = ["COMPLETED"] + self.in_progress_states = ["PENDING", "RUNNING"] + + def hook(self): + return SageMakerNotebookHook(execution_name=self.execution_name) + + # override from base sensor + def poke(self, context=None): + status = self.hook().get_execution_status(execution_id=self.execution_id) + + if status in self.success_state: + log_info_message = f"Exiting Execution {self.execution_id} State: {status}" + self.log.info(log_info_message) + return True + elif status in self.in_progress_states: + return False + else: + error_message = f"Exiting Execution {self.execution_id} State: {status}" + self.log.info(error_message) + raise AirflowException(error_message) + + def execute(self, context: Context): + # This will invoke poke method in the base sensor + log_info_message = f"Polling Sagemaker Workflows Artifact execution: {self.execution_name} and execution id: {self.execution_id}" + self.log.info(log_info_message) + super().execute(context=context) diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..e9285e9d8dd8c --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio.py @@ -0,0 +1,66 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook job trigger.""" + +from __future__ import annotations + +from airflow.triggers.base import BaseTrigger + + +class SageMakerNotebookJobTrigger(BaseTrigger): + """ + Watches for a notebook job, triggers when it finishes. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import SageMakerNotebookJobTrigger + + notebook_trigger = SageMakerNotebookJobTrigger( + execution_id="notebook_job_1234", + execution_name="notebook_task", + waiter_delay=10, + waiter_max_attempts=1440, + ) + + :param execution_id: A unique, meaningful id for the task. + :param execution_name: A unique, meaningful name for the task. + :param waiter_delay: Interval in seconds to check the notebook execution status. + :param waiter_max_attempts: Number of attempts to wait before returning FAILED. + """ + + def __init__(self, execution_id, execution_name, waiter_delay, waiter_max_attempts, **kwargs): + super().__init__(**kwargs) + self.execution_id = execution_id + self.execution_name = execution_name + self.waiter_delay = waiter_delay + self.waiter_max_attempts = waiter_max_attempts + + def serialize(self): + return ( + # dynamically generate the fully qualified name of the class + self.__class__.__module__ + "." 
+ self.__class__.__qualname__, { "execution_id": self.execution_id, "execution_name": self.execution_name, "waiter_delay": self.waiter_delay, "waiter_max_attempts": self.waiter_max_attempts, }, ) + async def run(self): pass diff --git a/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py b/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py new file mode 100644 index 0000000000000..63862239bdd26 --- /dev/null +++ b/providers/amazon/src/airflow/providers/amazon/aws/utils/sagemaker_unified_studio.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +"""This module contains utils for the Amazon SageMaker Unified Studio Notebook plugin.""" + +from __future__ import annotations + +import os + +workflows_env_key = "WORKFLOWS_ENV" + + +def is_local_runner(): + return os.getenv(workflows_env_key, "") == "Local" diff --git a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py index 69a1a80fc4d02..cef82566eeb8f 100644 --- a/providers/amazon/src/airflow/providers/amazon/get_provider_info.py +++ b/providers/amazon/src/airflow/providers/amazon/get_provider_info.py @@ -275,6 +275,15 @@ def get_provider_info(): "how-to-guide": ["/docs/apache-airflow-providers-amazon/operators/sagemaker.rst"], "tags": ["aws"], }, + { + "integration-name": "Amazon SageMaker Unified Studio", + "external-doc-url": "https://aws.amazon.com/sagemaker/unified-studio/", + "logo": "/docs/integration-logos/Amazon-SageMaker_light-bg@4x.png", + "how-to-guide": [ + "/docs/apache-airflow-providers-amazon/operators/sagemakerunifiedstudio.rst" + ], + "tags": ["aws"], + }, { "integration-name": "Amazon SecretsManager", "external-doc-url": "https://aws.amazon.com/secrets-manager/", @@ -491,6 +500,10 @@ def get_provider_info(): "integration-name": "Amazon SageMaker", "python-modules": ["airflow.providers.amazon.aws.operators.sagemaker"], }, + { + "integration-name": "Amazon SageMaker Unified Studio", + "python-modules": ["airflow.providers.amazon.aws.operators.sagemaker_unified_studio"], + }, { "integration-name": "Amazon Simple Notification Service (SNS)", "python-modules": ["airflow.providers.amazon.aws.operators.sns"], @@ -628,6 +641,10 @@ def get_provider_info(): "integration-name": "Amazon SageMaker", "python-modules": ["airflow.providers.amazon.aws.sensors.sagemaker"], }, + { + "integration-name": "Amazon SageMaker Unified Studio", + "python-modules": ["airflow.providers.amazon.aws.sensors.sagemaker_unified_studio"], + }, { "integration-name": "Amazon Simple Queue 
Service (SQS)", "python-modules": ["airflow.providers.amazon.aws.sensors.sqs"], @@ -781,6 +798,10 @@ def get_provider_info(): "integration-name": "Amazon SageMaker", "python-modules": ["airflow.providers.amazon.aws.hooks.sagemaker"], }, + { + "integration-name": "Amazon SageMaker Unified Studio", + "python-modules": ["airflow.providers.amazon.aws.hooks.sagemaker_unified_studio"], + }, { "integration-name": "Amazon Simple Email Service (SES)", "python-modules": ["airflow.providers.amazon.aws.hooks.ses"], @@ -878,6 +899,10 @@ def get_provider_info(): "integration-name": "Amazon SageMaker", "python-modules": ["airflow.providers.amazon.aws.triggers.sagemaker"], }, + { + "integration-name": "Amazon SageMaker Unified Studio", + "python-modules": ["airflow.providers.amazon.aws.triggers.sagemaker_unified_studio"], + }, { "integration-name": "AWS Glue", "python-modules": [ @@ -1072,6 +1097,7 @@ def get_provider_info(): "airflow.providers.amazon.aws.links.glue.GlueJobRunDetailsLink", "airflow.providers.amazon.aws.links.logs.CloudWatchEventsLink", "airflow.providers.amazon.aws.links.sagemaker.SageMakerTransformJobLink", + "airflow.providers.amazon.aws.links.sagemaker_unified_studio.SageMakerUnifiedStudioLink", "airflow.providers.amazon.aws.links.step_function.StateMachineDetailsLink", "airflow.providers.amazon.aws.links.step_function.StateMachineExecutionsDetailsLink", "airflow.providers.amazon.aws.links.comprehend.ComprehendPiiEntitiesDetectionLink", @@ -1354,6 +1380,7 @@ def get_provider_info(): "PyAthena>=3.0.10", "jmespath>=0.7.0", "python3-saml>=1.16.0", + "sagemaker-studio>=1.0.9", ], "optional-dependencies": { "pandas": ["pandas>=2.1.2,<2.2"], diff --git a/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py b/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..8a4a5c14c6698 --- /dev/null +++ b/providers/amazon/tests/system/amazon/aws/example_sagemaker_unified_studio.py @@ -0,0 
+1,166 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from datetime import datetime + +import pytest + +from airflow.decorators import task +from airflow.models.baseoperator import chain +from airflow.models.dag import DAG +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from system.amazon.aws.utils import ENV_ID_KEY, SystemTestContextBuilder + +from tests_common.test_utils.version_compat import AIRFLOW_V_2_10_PLUS + +""" +Prerequisites: The account which runs this test must manually have the following: +1. An IAM IDC organization set up in the testing region with a user initialized +2. A SageMaker Unified Studio Domain (with default VPC and roles) +3. A project within the SageMaker Unified Studio Domain +4. A notebook (test_notebook.ipynb) placed in the project's s3 path + +This test will emulate a DAG run in the shared MWAA environment inside a SageMaker Unified Studio Project. +The setup tasks will set up the project and configure the test runner to emulate an MWAA instance. +Then, the SageMakerNotebookOperator will run a test notebook. This should spin up a SageMaker training job, run the notebook, and exit successfully. 
+""" + +pytestmark = pytest.mark.skipif(not AIRFLOW_V_2_10_PLUS, reason="Test requires Airflow 2.10+") + +DAG_ID = "example_sagemaker_unified_studio" + +# Externally fetched variables: +DOMAIN_ID_KEY = "DOMAIN_ID" +PROJECT_ID_KEY = "PROJECT_ID" +ENVIRONMENT_ID_KEY = "ENVIRONMENT_ID" +S3_PATH_KEY = "S3_PATH" +REGION_NAME_KEY = "REGION_NAME" + +sys_test_context_task = ( + SystemTestContextBuilder() + .add_variable(DOMAIN_ID_KEY) + .add_variable(PROJECT_ID_KEY) + .add_variable(ENVIRONMENT_ID_KEY) + .add_variable(S3_PATH_KEY) + .add_variable(REGION_NAME_KEY) + .build() +) + + +def get_mwaa_environment_params( + domain_id: str, + project_id: str, + environment_id: str, + s3_path: str, + region_name: str, +): + AIRFLOW_PREFIX = "AIRFLOW__WORKFLOWS__" + + parameters = {} + parameters[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_ID"] = domain_id + parameters[f"{AIRFLOW_PREFIX}DATAZONE_PROJECT_ID"] = project_id + parameters[f"{AIRFLOW_PREFIX}DATAZONE_ENVIRONMENT_ID"] = environment_id + parameters[f"{AIRFLOW_PREFIX}DATAZONE_SCOPE_NAME"] = "dev" + parameters[f"{AIRFLOW_PREFIX}DATAZONE_STAGE"] = "prod" + parameters[f"{AIRFLOW_PREFIX}DATAZONE_ENDPOINT"] = f"https://datazone.{region_name}.api.aws" + parameters[f"{AIRFLOW_PREFIX}PROJECT_S3_PATH"] = s3_path + parameters[f"{AIRFLOW_PREFIX}DATAZONE_DOMAIN_REGION"] = region_name + return parameters + + +@task +def mock_mwaa_environment(parameters: dict): + """ + Sets several environment variables in the container to emulate an MWAA environment provisioned + within SageMaker Unified Studio. When running in the ECSExecutor, this is a no-op. 
+ """ + import os + + for key, value in parameters.items(): + os.environ[key] = value + + +with DAG( + DAG_ID, + schedule="@once", + start_date=datetime(2021, 1, 1), + tags=["example"], + catchup=False, +) as dag: + test_context = sys_test_context_task() + + test_env_id = test_context[ENV_ID_KEY] + domain_id = test_context[DOMAIN_ID_KEY] + project_id = test_context[PROJECT_ID_KEY] + environment_id = test_context[ENVIRONMENT_ID_KEY] + s3_path = test_context[S3_PATH_KEY] + region_name = test_context[REGION_NAME_KEY] + + mock_mwaa_environment_params = get_mwaa_environment_params( + domain_id, + project_id, + environment_id, + s3_path, + region_name, + ) + + setup_mwaa_environment = mock_mwaa_environment(mock_mwaa_environment_params) + + # [START howto_operator_sagemaker_unified_studio_notebook] + notebook_path = "test_notebook.ipynb" # This should be the path to your .ipynb, .sqlnb, or .vetl file in your project. + + run_notebook = SageMakerNotebookOperator( + task_id="run-notebook", + input_config={"input_path": notebook_path, "input_params": {}}, + output_config={"output_formats": ["NOTEBOOK"]}, # optional + compute={ + "instance_type": "ml.m5.large", + "volume_size_in_gb": 30, + }, # optional + termination_condition={"max_runtime_in_seconds": 600}, # optional + tags={}, # optional + wait_for_completion=True, # optional + waiter_delay=5, # optional + deferrable=False, # optional + executor_config={ # optional + "overrides": {"containerOverrides": {"environment": mock_mwaa_environment_params}} + }, + ) + # [END howto_operator_sagemaker_unified_studio_notebook] + + chain( + # TEST SETUP + test_context, + setup_mwaa_environment, + # TEST BODY + run_notebook, + ) + + from tests_common.test_utils.watcher import watcher + + # This test needs watcher in order to properly mark success/failure + # when "tearDown" task with trigger rule is part of the DAG + list(dag.tasks) >> watcher() + + +from tests_common.test_utils.system_tests import get_test_run # noqa: E402 + +# Needed 
to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest) +test_run = get_test_run(dag) diff --git a/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..179d997740c58 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/hooks/test_sagemaker_unified_studio.py @@ -0,0 +1,201 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock, call, patch + +import pytest +from sagemaker_studio.models.execution import ExecutionClient + +from airflow.exceptions import AirflowException +from airflow.models import TaskInstance +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio import ( + SageMakerNotebookHook, +) +from airflow.utils.session import create_session + +pytestmark = pytest.mark.db_test + + +class TestSageMakerNotebookHook: + @pytest.fixture(autouse=True) + def setup(self): + with patch( + "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerStudioAPI", + autospec=True, + ) as mock_sdk: + self.execution_name = "test-execution" + self.waiter_delay = 10 + sdk_instance = mock_sdk.return_value + sdk_instance.execution_client = MagicMock(spec=ExecutionClient) + sdk_instance.execution_client.start_execution.return_value = { + "execution_id": "execution_id", + "execution_name": "execution_name", + } + self.hook = SageMakerNotebookHook( + input_config={ + "input_path": "test-data/notebook/test_notebook.ipynb", + "input_params": {"key": "value"}, + }, + output_config={"output_formats": ["NOTEBOOK"]}, + execution_name=self.execution_name, + waiter_delay=self.waiter_delay, + compute={"instance_type": "ml.c4.2xlarge"}, + ) + + self.hook._sagemaker_studio = mock_sdk + self.files = [ + {"display_name": "file1.txt", "url": "http://example.com/file1.txt"}, + {"display_name": "file2.txt", "url": "http://example.com/file2.txt"}, + ] + self.context = { + "ti": MagicMock(spec=TaskInstance), + } + self.s3Path = "S3Path" + yield + + def test_format_input_config(self): + expected_config = { + "notebook_config": { + "input_path": "test-data/notebook/test_notebook.ipynb", + "input_parameters": {"key": "value"}, + } + } + + config = self.hook._format_start_execution_input_config() + assert config == expected_config + + def test_format_output_config(self): + expected_config = { + "notebook_config": { + 
"output_formats": ["NOTEBOOK"], + } + } + + config = self.hook._format_start_execution_output_config() + assert config == expected_config + + def test_format_output_config_default(self): + no_output_config_hook = SageMakerNotebookHook( + input_config={ + "input_path": "test-data/notebook/test_notebook.ipynb", + "input_params": {"key": "value"}, + }, + execution_name=self.execution_name, + waiter_delay=self.waiter_delay, + ) + + no_output_config_hook._sagemaker_studio = self.hook._sagemaker_studio + expected_config = {"notebook_config": {"output_formats": ["NOTEBOOK"]}} + + config = no_output_config_hook._format_start_execution_output_config() + assert config == expected_config + + def test_start_notebook_execution(self): + self.hook._sagemaker_studio = MagicMock() + self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient) + + self.hook._sagemaker_studio.execution_client.start_execution.return_value = {"executionId": "123456"} + result = self.hook.start_notebook_execution() + assert result == {"executionId": "123456"} + self.hook._sagemaker_studio.execution_client.start_execution.assert_called_once() + + @patch("time.sleep", return_value=None) # To avoid actual sleep during tests + def test_wait_for_execution_completion(self, mock_sleep): + execution_id = "123456" + self.hook._sagemaker_studio = MagicMock() + self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient) + self.hook._sagemaker_studio.execution_client.get_execution.return_value = {"status": "COMPLETED"} + + result = self.hook.wait_for_execution_completion(execution_id, {}) + assert result == {"Status": "COMPLETED", "ExecutionId": execution_id} + self.hook._sagemaker_studio.execution_client.get_execution.assert_called() + mock_sleep.assert_called_once() + + @patch("time.sleep", return_value=None) + def test_wait_for_execution_completion_failed(self, mock_sleep): + execution_id = "123456" + self.hook._sagemaker_studio = MagicMock() + 
self.hook._sagemaker_studio.execution_client = MagicMock(spec=ExecutionClient) + self.hook._sagemaker_studio.execution_client.get_execution.return_value = { + "status": "FAILED", + "error_details": {"error_message": "Execution failed"}, + } + + with pytest.raises(AirflowException, match="Execution failed"): + self.hook.wait_for_execution_completion(execution_id, self.context) + + def test_handle_in_progress_state(self): + execution_id = "123456" + states = ["IN_PROGRESS", "STOPPING"] + + for status in states: + result = self.hook._handle_state(execution_id, status, None) + assert result is None + + def test_handle_finished_state(self): + execution_id = "123456" + states = ["COMPLETED"] + + for status in states: + result = self.hook._handle_state(execution_id, status, None) + assert result == {"Status": status, "ExecutionId": execution_id} + + def test_handle_failed_state(self): + execution_id = "123456" + status = "FAILED" + error_message = "Execution failed" + with pytest.raises(AirflowException, match=error_message): + self.hook._handle_state(execution_id, status, error_message) + + status = "STOPPED" + error_message = "" + with pytest.raises(AirflowException, match=f"Exiting Execution {execution_id} State: {status}"): + self.hook._handle_state(execution_id, status, error_message) + + def test_handle_unexpected_state(self): + execution_id = "123456" + status = "PENDING" + error_message = f"Exiting Execution {execution_id} State: {status}" + with pytest.raises(AirflowException, match=error_message): + self.hook._handle_state(execution_id, status, error_message) + + @patch( + "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerNotebookHook._set_xcom_files" + ) + def test_set_xcom_files(self, mock_set_xcom_files): + with create_session(): + self.hook._set_xcom_files(self.files, self.context) + expected_call = call(self.files, self.context) + mock_set_xcom_files.assert_called_once_with(*expected_call.args, **expected_call.kwargs) + + def 
test_set_xcom_files_negative_missing_context(self): + with pytest.raises(AirflowException, match="context is required"): + self.hook._set_xcom_files(self.files, {}) + + @patch( + "airflow.providers.amazon.aws.hooks.sagemaker_unified_studio.SageMakerNotebookHook._set_xcom_s3_path" + ) + def test_set_xcom_s3_path(self, mock_set_xcom_s3_path): + with create_session(): + self.hook._set_xcom_s3_path(self.s3Path, self.context) + expected_call = call(self.s3Path, self.context) + mock_set_xcom_s3_path.assert_called_once_with(*expected_call.args, **expected_call.kwargs) + + def test_set_xcom_s3_path_negative_missing_context(self): + with pytest.raises(AirflowException, match="context is required"): + self.hook._set_xcom_s3_path(self.s3Path, {}) diff --git a/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..c55d1231fd83f --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/links/test_sagemaker_unified_studio.py @@ -0,0 +1,32 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from airflow.providers.amazon.aws.links.sagemaker_unified_studio import SageMakerUnifiedStudioLink +from unit.amazon.aws.links.test_base_aws import BaseAwsLinksTestCase + + +class TestSageMakerUnifiedStudioLink(BaseAwsLinksTestCase): + link_class = SageMakerUnifiedStudioLink + + def test_extra_link(self): + self.assert_extra_link_url( + expected_url=("https://console.aws.amazon.com/datazone/home?region=us-east-1"), + region_name="us-east-1", + aws_partition="aws", + job_name="test_job_name", + ) diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb b/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb new file mode 100644 index 0000000000000..395eff4ef6255 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_notebook.ipynb @@ -0,0 +1,61 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "437d7d66", + "metadata": {}, + "outputs": [], + "source": [ + "# Licensed to the Apache Software Foundation (ASF) under one\n", + "# or more contributor license agreements. See the NOTICE file\n", + "# distributed with this work for additional information\n", + "# regarding copyright ownership. The ASF licenses this file\n", + "# to you under the Apache License, Version 2.0 (the\n", + "# \"License\"); you may not use this file except in compliance\n", + "# with the License. You may obtain a copy of the License at\n", + "#\n", + "# http://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing,\n", + "# software distributed under the License is distributed on an\n", + "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n", + "# KIND, either express or implied. See the License for the\n", + "# specific language governing permissions and limitations\n", + "# under the License." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a734f58df854b5fa", + "metadata": {}, + "outputs": [], + "source": [ + "def add(num1, num2):\n", + " return num1 + num2" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..87e9b004db192 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_sagemaker_unified_studio.py @@ -0,0 +1,176 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.operators.sagemaker_unified_studio import ( + SageMakerNotebookOperator, +) +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio import ( + SageMakerNotebookJobTrigger, +) + + +class TestSageMakerNotebookOperator: + def test_init(self): + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={ + "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb", + }, + output_config={"output_format": "ipynb"}, + ) + + assert operator.task_id == "test_id" + assert operator.input_config == { + "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb", + } + assert operator.output_config == {"output_format": "ipynb"} + + def test_only_required_params_init(self): + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={ + "notebook_path": "tests/amazon/aws/operators/test_notebook.ipynb", + }, + ) + assert isinstance(operator, SageMakerNotebookOperator) + + @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + def test_execute_success(self, mock_notebook_hook): # Mock the NotebookHook and its execute method + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.start_notebook_execution.return_value = { + "execution_id": "123456", + "executionType": "test", + } + + # Create the operator + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={"input_path": "test_input_path"}, + output_config={"output_uri": "test_output_uri", "output_format": "ipynb"}, + ) + + # Execute the operator + operator.execute({}) + mock_hook_instance.start_notebook_execution.assert_called_once_with() + mock_hook_instance.wait_for_execution_completion.assert_called_once_with("123456", {}) + + 
@patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + def test_execute_failure_missing_input_config(self, mock_notebook_hook): + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={}, + output_config={"output_uri": "test_output_uri", "output_format": "ipynb"}, + ) + + with pytest.raises(AirflowException, match="input_config is required"): + operator.execute({}) + + mock_notebook_hook.assert_not_called() + + @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + def test_execute_failure_missing_input_path(self, mock_notebook_hook): + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={"invalid_key": "test_input_path"}, + output_config={"output_uri": "test_output_uri", "output_format": "ipynb"}, + ) + + with pytest.raises(AirflowException, match="input_path is a required field in the input_config"): + operator.execute({}) + + mock_notebook_hook.assert_not_called() + + @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + def test_execute_with_wait_for_completion(self, mock_notebook_hook): + # Mock the execute and job_completion methods of NotebookHook + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.start_notebook_execution.return_value = { + "execution_id": "123456", + "executionType": "test", + } + mock_hook_instance.wait_for_execution_completion.return_value = {"Status": "COMPLETED"} + + # Create the operator with wait_for_completion set to True + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={"input_path": "test_input_path"}, + output_config={"output_uri": "test_output_uri", "output_format": "ipynb"}, + wait_for_completion=True, + ) + # Execute the operator + operator.execute({}) + + # Verify that execute and wait_for_execution_completion methods are called + mock_hook_instance.start_notebook_execution.assert_called_once_with() + 
 mock_hook_instance.wait_for_execution_completion.assert_called_once_with("123456", {}) + + @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + @patch.object(SageMakerNotebookOperator, "defer") + def test_execute_with_deferrable(self, mock_defer, mock_notebook_hook): + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.start_notebook_execution.return_value = { + "execution_id": "123456", + "executionType": "test", + } + + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={"input_path": "test_input_path"}, + output_config={"output_format": "ipynb"}, + deferrable=True, + ) + + operator.execute({}) + + mock_hook_instance.start_notebook_execution.assert_called_once_with() + mock_defer.assert_called_once() + trigger_call = mock_defer.call_args[1]["trigger"] + assert isinstance(trigger_call, SageMakerNotebookJobTrigger) + assert trigger_call.execution_id == "123456" + assert trigger_call.execution_name == "test_id" + assert trigger_call.waiter_delay == 10 + mock_hook_instance.wait_for_execution_completion.assert_not_called() + + @patch("airflow.providers.amazon.aws.operators.sagemaker_unified_studio.SageMakerNotebookHook") + def test_execute_without_wait_for_completion(self, mock_notebook_hook): + # Mock the execute method of NotebookHook + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.start_notebook_execution.return_value = { + "execution_id": "123456", + "executionType": "test", + } + + # Create the operator with wait_for_completion set to False + operator = SageMakerNotebookOperator( + task_id="test_id", + input_config={"input_path": "test_input_path"}, + output_config={"output_uri": "test_output_uri", "output_format": "ipynb"}, + wait_for_completion=False, + ) + + # Execute the operator + operator.execute({}) + + # Verify that start_notebook_execution is called and wait_for_execution_completion is not + 
mock_hook_instance.start_notebook_execution.assert_called_once_with() + mock_hook_instance.wait_for_execution_completion.assert_not_called() diff --git a/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..46b4e40bf7cd0 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker_unified_studio.py @@ -0,0 +1,105 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest + +from airflow.exceptions import AirflowException +from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio import ( + SageMakerNotebookSensor, +) +from airflow.utils.context import Context + + +class TestSageMakerNotebookSensor: + def test_init(self): + # Test the initialization of the sensor + sensor = SageMakerNotebookSensor( + task_id="test_task", + execution_id="test_execution_id", + execution_name="test_execution_name", + ) + assert sensor.execution_id == "test_execution_id" + assert sensor.execution_name == "test_execution_name" + assert sensor.success_state == ["COMPLETED"] + assert sensor.in_progress_states == ["PENDING", "RUNNING"] + + @patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook") + def test_poke_success_state(self, mock_notebook_hook): + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.get_execution_status.return_value = "COMPLETED" + + sensor = SageMakerNotebookSensor( + task_id="test_task", + execution_id="test_execution_id", + execution_name="test_execution_name", + ) + + # Test the poke method + result = sensor.poke() + assert result is True + mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id") + + @patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook") + def test_poke_failure_state(self, mock_notebook_hook): + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.get_execution_status.return_value = "FAILED" + + sensor = SageMakerNotebookSensor( + task_id="test_task", + execution_id="test_execution_id", + execution_name="test_execution_name", + ) + + # Test the poke method and assert exception + with pytest.raises(AirflowException, match="Exiting Execution test_execution_id State: FAILED"): + sensor.poke() + + 
mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id") + + @patch("airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookHook") + def test_poke_in_progress_state(self, mock_notebook_hook): + mock_hook_instance = mock_notebook_hook.return_value + mock_hook_instance.get_execution_status.return_value = "RUNNING" + + sensor = SageMakerNotebookSensor( + task_id="test_task", + execution_id="test_execution_id", + execution_name="test_execution_name", + ) + + # Test the poke method + result = sensor.poke() + assert result is False + mock_hook_instance.get_execution_status.assert_called_once_with(execution_id="test_execution_id") + + @patch.object(SageMakerNotebookSensor, "poke", return_value=True) + def test_execute_calls_poke(self, mock_poke): + # Create the sensor + sensor = SageMakerNotebookSensor( + task_id="test_task", + execution_id="test_execution_id", + execution_name="test_execution_name", + ) + + context = MagicMock(spec=Context) + sensor.execute(context=context) + + # Assert that the poke method was called + mock_poke.assert_called_once_with(context) diff --git a/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py b/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py new file mode 100644 index 0000000000000..0730923fae1e7 --- /dev/null +++ b/providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker_unified_studio.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import os + +from airflow.providers.amazon.aws.utils.sagemaker_unified_studio import is_local_runner, workflows_env_key + + +def test_is_local_runner_false(): + assert not is_local_runner() + + +def test_is_local_runner_true(): + os.environ[workflows_env_key] = "Local" + assert is_local_runner() + + +def test_is_local_runner_false_with_env_var(): + os.environ[workflows_env_key] = "False" + assert not is_local_runner() + + +def test_is_local_runner_false_with_env_var_empty(): + os.environ[workflows_env_key] = "" + assert not is_local_runner() + + +def test_is_local_runner_false_with_env_var_invalid(): + os.environ[workflows_env_key] = "random string" + assert not is_local_runner() + + +def test_is_local_runner_false_with_string_int(): + os.environ[workflows_env_key] = "1" + assert not is_local_runner() diff --git a/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt b/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt new file mode 100644 index 0000000000000..7ad85fd17468f --- /dev/null +++ b/providers/fab/src/airflow/providers/3rd-party-licenses/LICENSES-ui.txt @@ -0,0 +1,89 @@ +Apache Airflow +Copyright 2016-2023 The Apache Software Foundation + +This product includes software developed at The Apache Software +Foundation (http://www.apache.org/). 
+
+=======================================================================
+css-loader|5.2.7:
+-----
+MIT
+Copyright JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+'Software'), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED 'AS IS', WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
+IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
+CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
+TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
+SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+
+webpack-contrib/css-loader
+
+
+moment|2.30.1:
+-----
+MIT
+Copyright (c) JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person
+obtaining a copy of this software and associated documentation
+files (the "Software"), to deal in the Software without
+restriction, including without limitation the rights to use,
+copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following
+conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
+HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
+WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
+OTHER DEALINGS IN THE SOFTWARE.
+
+https://github.com/moment/moment.git
+
+
+moment-timezone|0.5.47:
+-----
+MIT
+The MIT License (MIT)
+
+Copyright (c) JS Foundation and other contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
+the Software, and to permit persons to whom the Software is furnished to do so,
+subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
+FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
+COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
+IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
+CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+https://github.com/moment/moment-timezone.git
+
+
diff --git a/tests/always/test_project_structure.py b/tests/always/test_project_structure.py
index 310994cf38e95..cd73fbce1ec26 100644
--- a/tests/always/test_project_structure.py
+++ b/tests/always/test_project_structure.py
@@ -71,6 +71,7 @@ def test_providers_modules_should_have_tests(self):
             "providers/amazon/tests/unit/amazon/aws/sensors/test_emr.py",
             "providers/amazon/tests/unit/amazon/aws/sensors/test_sagemaker.py",
             "providers/amazon/tests/unit/amazon/aws/test_exceptions.py",
+            "providers/amazon/tests/unit/amazon/aws/triggers/test_sagemaker_unified_studio.py",
             "providers/amazon/tests/unit/amazon/aws/triggers/test_step_function.py",
             "providers/amazon/tests/unit/amazon/aws/utils/test_rds.py",
             "providers/amazon/tests/unit/amazon/aws/utils/test_sagemaker.py",
@@ -603,6 +604,8 @@ class TestAmazonProviderProjectStructure(ExampleCoverageTest):
         # These operations take a lot of time, there are commented out in the system tests for this reason
         "airflow.providers.amazon.aws.operators.dms.DmsStartReplicationOperator",
         "airflow.providers.amazon.aws.operators.dms.DmsStopReplicationOperator",
+        # These modules are used in the SageMakerNotebookOperator and therefore don't have their own examples
+        "airflow.providers.amazon.aws.sensors.sagemaker_unified_studio.SageMakerNotebookSensor",
     }

 DEPRECATED_CLASSES = {