diff --git a/providers/amazon/docs/operators/s3_tables.rst b/providers/amazon/docs/operators/s3_tables.rst index bf625e54dff16..207fa848bf403 100644 --- a/providers/amazon/docs/operators/s3_tables.rst +++ b/providers/amazon/docs/operators/s3_tables.rst @@ -33,6 +33,20 @@ To create an Amazon S3 Tables table bucket, use :start-after: [START howto_operator_s3tables_create_table_bucket] :end-before: [END howto_operator_s3tables_create_table_bucket] +.. _howto/operator:S3TablesCreateNamespaceOperator: + +Create a Namespace +------------------ + +To create a namespace in an Amazon S3 Tables table bucket, use +:class:`~airflow.providers.amazon.aws.operators.s3_tables.S3TablesCreateNamespaceOperator`. + +.. exampleinclude:: /../../amazon/tests/system/amazon/aws/example_s3_tables.py + :language: python + :dedent: 4 + :start-after: [START howto_operator_s3tables_create_namespace] + :end-before: [END howto_operator_s3tables_create_namespace] + .. _howto/operator:S3TablesCreateTableOperator: Create a Table diff --git a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_tables.py b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_tables.py index faed3e2ecea39..53dff2f1f03f8 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_tables.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/operators/s3_tables.py @@ -252,3 +252,46 @@ def execute(self, context: Context) -> None: self.log.info("Deleting S3 Tables table bucket %s", self.table_bucket_arn) self.hook.conn.delete_table_bucket(tableBucketARN=self.table_bucket_arn) self.log.info("Deleted table bucket %s", self.table_bucket_arn) + + +class S3TablesCreateNamespaceOperator(AwsBaseOperator[S3TablesHook]): + """ + Create a namespace in an Amazon S3 Tables table bucket. + + .. seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:S3TablesCreateNamespaceOperator` + + :param table_bucket_arn: The ARN of the table bucket. (templated) + :param namespace: The namespace name to create. (templated) + :param if_exists: Behavior when the namespace already exists. + ``"fail"`` raises an error, ``"skip"`` logs and returns. + """ + + template_fields: Sequence[str] = aws_template_fields("table_bucket_arn", "namespace") + aws_hook_class = S3TablesHook + + def __init__( + self, + *, + table_bucket_arn: str, + namespace: str, + if_exists: Literal["fail", "skip"] = "skip", + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.table_bucket_arn = table_bucket_arn + self.namespace = namespace + self.if_exists = if_exists + + def execute(self, context: Context) -> str: + self.log.info("Creating namespace %s in %s", self.namespace, self.table_bucket_arn) + try: + self.hook.conn.create_namespace(tableBucketARN=self.table_bucket_arn, namespace=[self.namespace]) + except ClientError as e: + if e.response["Error"]["Code"] == "ConflictException" and self.if_exists == "skip": + self.log.info("Namespace %s already exists, skipping.", self.namespace) + else: + raise + self.log.info("Namespace %s created.", self.namespace) + return self.namespace diff --git a/providers/amazon/tests/system/amazon/aws/example_s3_tables.py b/providers/amazon/tests/system/amazon/aws/example_s3_tables.py index 52492cc3542b4..f0bba39ed28ca 100644 --- a/providers/amazon/tests/system/amazon/aws/example_s3_tables.py +++ b/providers/amazon/tests/system/amazon/aws/example_s3_tables.py @@ -19,6 +19,7 @@ from datetime import datetime from airflow.providers.amazon.aws.operators.s3_tables import ( + S3TablesCreateNamespaceOperator, S3TablesCreateTableBucketOperator, S3TablesCreateTableOperator, S3TablesDeleteTableBucketOperator, @@ -84,7 +85,13 @@ def delete_namespace(table_bucket_arn: str, namespace: str): table_bucket_name=bucket_name, ) # [END howto_operator_s3tables_create_table_bucket] - setup_namespace = create_namespace(table_bucket_arn=create_table_bucket.output, namespace=namespace) + # [START howto_operator_s3tables_create_namespace] + setup_namespace = S3TablesCreateNamespaceOperator( + task_id="create_namespace", + table_bucket_arn=create_table_bucket.output, + namespace=namespace, + ) + # [END howto_operator_s3tables_create_namespace] # [START howto_operator_s3tables_create_table] create_table = S3TablesCreateTableOperator( diff --git a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_tables.py b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_tables.py index 541c21cd5b8aa..e9b3bf714f6fe 100644 --- a/providers/amazon/tests/unit/amazon/aws/operators/test_s3_tables.py +++ b/providers/amazon/tests/unit/amazon/aws/operators/test_s3_tables.py @@ -25,6 +25,7 @@ from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook from airflow.providers.amazon.aws.hooks.s3_tables import S3TablesHook from airflow.providers.amazon.aws.operators.s3_tables import ( + S3TablesCreateNamespaceOperator, S3TablesCreateTableBucketOperator, S3TablesCreateTableOperator, S3TablesDeleteTableBucketOperator, @@ -231,3 +232,60 @@ def test_execute(self, mock_conn): def test_template_fields(self): validate_template_fields(self.operator) + + +NAMESPACE = "test_namespace" + + +class TestS3TablesCreateNamespaceOperator: + def setup_method(self): + self.operator = S3TablesCreateNamespaceOperator( + task_id="create_namespace", + table_bucket_arn=TABLE_BUCKET_ARN, + namespace=NAMESPACE, + ) + + @mock.patch.object(S3TablesHook, "conn", new_callable=mock.PropertyMock) + def test_execute(self, mock_conn): + mock_client = mock.MagicMock() + mock_conn.return_value = mock_client + + result = self.operator.execute({}) + + mock_client.create_namespace.assert_called_once_with( + tableBucketARN=TABLE_BUCKET_ARN, namespace=[NAMESPACE] + ) + assert result == NAMESPACE + + @mock.patch.object(S3TablesHook, "conn", new_callable=mock.PropertyMock) + def test_execute_skip_existing(self, mock_conn): + mock_client = mock.MagicMock() + mock_client.create_namespace.side_effect = ClientError( + {"Error": {"Code": "ConflictException", "Message": "Already exists"}}, + "CreateNamespace", + ) + mock_conn.return_value = mock_client + + result = self.operator.execute({}) + assert result == NAMESPACE + + @mock.patch.object(S3TablesHook, "conn", new_callable=mock.PropertyMock) + def test_execute_fail_on_conflict(self, mock_conn): + op = S3TablesCreateNamespaceOperator( + task_id="create_namespace", + table_bucket_arn=TABLE_BUCKET_ARN, + namespace=NAMESPACE, + if_exists="fail", + ) + mock_client = mock.MagicMock() + mock_client.create_namespace.side_effect = ClientError( + {"Error": {"Code": "ConflictException", "Message": "Already exists"}}, + "CreateNamespace", + ) + mock_conn.return_value = mock_client + + with pytest.raises(ClientError): + op.execute({}) + + def test_template_fields(self): + validate_template_fields(self.operator)