From 48a5a0aae9782d2df66f91cdc5f8cf46985026c8 Mon Sep 17 00:00:00 2001 From: Kacper Muda Date: Thu, 2 Jan 2025 09:14:19 +0100 Subject: [PATCH] feat: automatically inject OL info into spark job in DataprocInstantiateInlineWorkflowTemplateOperator (#44697) Signed-off-by: Kacper Muda --- docs/exts/templates/openlineage.rst.jinja2 | 2 + .../google/cloud/openlineage/utils.py | 64 +++++ .../google/cloud/operators/dataproc.py | 13 + .../google/cloud/openlineage/test_utils.py | 123 +++++++++ .../google/cloud/operators/test_dataproc.py | 236 ++++++++++++++++++ 5 files changed, 438 insertions(+) diff --git a/docs/exts/templates/openlineage.rst.jinja2 b/docs/exts/templates/openlineage.rst.jinja2 index 217e634457c70..af5798d5d51a9 100644 --- a/docs/exts/templates/openlineage.rst.jinja2 +++ b/docs/exts/templates/openlineage.rst.jinja2 @@ -38,6 +38,8 @@ apache-airflow-providers-google - Parent Job Information - :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocCreateBatchOperator` - Parent Job Information +- :class:`~airflow.providers.google.cloud.operators.dataproc.DataprocInstantiateInlineWorkflowTemplateOperator` + - Parent Job Information :class:`~airflow.providers.common.sql.operators.sql.SQLExecuteQueryOperator` diff --git a/providers/src/airflow/providers/google/cloud/openlineage/utils.py b/providers/src/airflow/providers/google/cloud/openlineage/utils.py index 0f3dcb5d4be92..c6fbadba953ee 100644 --- a/providers/src/airflow/providers/google/cloud/openlineage/utils.py +++ b/providers/src/airflow/providers/google/cloud/openlineage/utils.py @@ -622,3 +622,67 @@ def inject_openlineage_properties_into_dataproc_batch( batch_with_ol_config = _replace_dataproc_batch_properties(batch=batch, new_properties=properties) return batch_with_ol_config + + +def inject_openlineage_properties_into_dataproc_workflow_template( + template: dict, context: Context, inject_parent_job_info: bool +) -> dict: + """ + Inject OpenLineage properties into Spark jobs in Workflow Template. + + Function is not removing any configuration or modifying the jobs in any other way, + apart from adding desired OpenLineage properties to Dataproc job definition if not already present. + + Note: + Any modification to job will be skipped if: + - OpenLineage provider is not accessible. + - The job type is not supported. + - Automatic parent job information injection is disabled. + - Any OpenLineage properties with parent job information are already present + in the Spark job definition. + + Args: + template: The original Dataproc Workflow Template definition. + context: The Airflow context in which the job is running. + inject_parent_job_info: Flag indicating whether to inject parent job information. + + Returns: + The modified Workflow Template definition with OpenLineage properties injected, if applicable. + """ + if not inject_parent_job_info: + log.debug("Automatic injection of OpenLineage information is disabled.") + return template + + if not _is_openlineage_provider_accessible(): + log.warning( + "Could not access OpenLineage provider for automatic OpenLineage " + "properties injection. No action will be performed." 
+ ) + return template + + final_jobs = [] + for single_job_definition in template["jobs"]: + step_id = single_job_definition["step_id"] + log.debug("Injecting OpenLineage properties into Workflow step: `%s`", step_id) + + if (job_type := _extract_supported_job_type_from_dataproc_job(single_job_definition)) is None: + log.debug( + "Could not find a supported Dataproc job type for automatic OpenLineage " + "properties injection. No action will be performed.", + ) + final_jobs.append(single_job_definition) + continue + + properties = single_job_definition[job_type].get("properties", {}) + + properties = inject_parent_job_information_into_spark_properties( + properties=properties, context=context + ) + + job_with_ol_config = _replace_dataproc_job_properties( + job=single_job_definition, job_type=job_type, new_properties=properties + ) + final_jobs.append(job_with_ol_config) + + template["jobs"] = final_jobs + return template diff --git a/providers/src/airflow/providers/google/cloud/operators/dataproc.py b/providers/src/airflow/providers/google/cloud/operators/dataproc.py index 5e64f7d920707..1d5ced10283c9 100644 --- a/providers/src/airflow/providers/google/cloud/operators/dataproc.py +++ b/providers/src/airflow/providers/google/cloud/operators/dataproc.py @@ -57,6 +57,7 @@ from airflow.providers.google.cloud.openlineage.utils import ( inject_openlineage_properties_into_dataproc_batch, inject_openlineage_properties_into_dataproc_job, + inject_openlineage_properties_into_dataproc_workflow_template, ) from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator from airflow.providers.google.cloud.triggers.dataproc import ( @@ -1825,6 +1826,9 @@ def __init__( deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), polling_interval_seconds: int = 10, cancel_on_kill: bool = True, + openlineage_inject_parent_job_info: bool = conf.getboolean( + "openlineage", "spark_inject_parent_job_info", fallback=False + ), **kwargs, ) -> None: super().__init__(**kwargs) @@ -1844,11 +1848,20 @@ def __init__( self.polling_interval_seconds = polling_interval_seconds self.cancel_on_kill = cancel_on_kill self.operation_name: str | None = None + self.openlineage_inject_parent_job_info = openlineage_inject_parent_job_info def execute(self, context: Context): self.log.info("Instantiating Inline Template") hook = DataprocHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain) project_id = self.project_id or hook.project_id + if self.openlineage_inject_parent_job_info: + self.log.info("Automatic injection of OpenLineage information into Spark properties is enabled.") + self.template = inject_openlineage_properties_into_dataproc_workflow_template( + template=self.template, + context=context, + inject_parent_job_info=self.openlineage_inject_parent_job_info, + ) + operation = hook.instantiate_inline_workflow_template( template=self.template, project_id=project_id, diff --git a/providers/tests/google/cloud/openlineage/test_utils.py b/providers/tests/google/cloud/openlineage/test_utils.py index 58949125f8433..b5e451debe5fa 100644 --- a/providers/tests/google/cloud/openlineage/test_utils.py +++ b/providers/tests/google/cloud/openlineage/test_utils.py @@ -48,6 +48,7 @@ get_identity_column_lineage_facet, inject_openlineage_properties_into_dataproc_batch, inject_openlineage_properties_into_dataproc_job, + inject_openlineage_properties_into_dataproc_workflow_template, merge_column_lineage_facets, ) @@ -829,3 +830,125 @@ def 
test_inject_openlineage_properties_into_dataproc_batch(mock_is_ol_accessible } result = inject_openlineage_properties_into_dataproc_batch(batch, context, True) assert result == expected_batch + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_workflow_template_provider_not_accessible( + mock_is_accessible, +): + mock_is_accessible.return_value = False + template = {"workflow": "template"} # It does not matter what the dict is, we should return it unmodified + result = inject_openlineage_properties_into_dataproc_workflow_template(template, None, True) + assert result == template + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +@patch("airflow.providers.google.cloud.openlineage.utils._extract_supported_job_type_from_dataproc_job") +def test_inject_openlineage_properties_into_dataproc_workflow_template_no_inject_parent_job_info( + mock_extract_job_type, mock_is_accessible +): + mock_is_accessible.return_value = True + mock_extract_job_type.return_value = "sparkJob" + inject_parent_job_info = False + template = {"workflow": "template"} # It does not matter what the dict is, we should return it unmodified + result = inject_openlineage_properties_into_dataproc_workflow_template( + template, None, inject_parent_job_info + ) + assert result == template + + +@patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") +def test_inject_openlineage_properties_into_dataproc_workflow_template(mock_is_ol_accessible): + mock_is_ol_accessible.return_value = True + context = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) + } + template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + expected_template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { # Injected properties + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Not modified because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + 
"spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + } + result = inject_openlineage_properties_into_dataproc_workflow_template(template, context, True) + assert result == expected_template diff --git a/providers/tests/google/cloud/operators/test_dataproc.py b/providers/tests/google/cloud/operators/test_dataproc.py index 5d4a9b0d79c8d..f79a0bdba0ce9 100644 --- a/providers/tests/google/cloud/operators/test_dataproc.py +++ b/providers/tests/google/cloud/operators/test_dataproc.py @@ -2356,6 +2356,242 @@ def test_wait_for_operation_on_execute(self, mock_hook): ) mock_op.return_value.result.assert_not_called() + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_parent_job_info_injection(self, mock_hook, mock_ol_accessible): + mock_ol_accessible.return_value = True + context = { + "ti": MagicMock( + dag_id="dag_id", + task_id="task_id", + try_number=1, + map_index=1, + logical_date=dt.datetime(2024, 11, 11), + ) + } + template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + "parameters": [ + { + "name": "ZONE", + "fields": [ + "placement.clusterSelector.zone", + ], + } + ], + } + expected_template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + "properties": { # Injected properties + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobName": "dag_id.task_id", + "spark.openlineage.parentJobNamespace": "default", + "spark.openlineage.parentRunId": "01931885-2800-7be7-aa8d-aaa15c337267", + }, + }, + }, + { + "step_id": "job_2", + "pyspark_job": { # Not modified because it's already present + "main_python_file_uri": "gs://bucket2/spark_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + "spark.openlineage.parentJobNamespace": "test", + }, + }, + }, + { + "step_id": "job_3", + "hive_job": { # Not modified because it's unsupported job type + "main_python_file_uri": "gs://bucket3/hive_job.py", + "properties": { + "spark.sql.shuffle.partitions": "1", + }, + }, + }, + ], + "parameters": [ + { + "name": "ZONE", + "fields": [ + "placement.clusterSelector.zone", + ], + } + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + openlineage_inject_parent_job_info=True, + ) + op.execute(context=context) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + 
mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=expected_template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_parent_job_info_injection_skipped_by_default_unless_enabled( + self, mock_hook, mock_ol_accessible + ): + mock_ol_accessible.return_value = True + + template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + }, + } + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + # not passing openlineage_inject_parent_job_info, should be False by default + ) + op.execute(context=MagicMock()) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + + @mock.patch("airflow.providers.google.cloud.openlineage.utils._is_openlineage_provider_accessible") + @mock.patch(DATAPROC_PATH.format("DataprocHook")) + def test_execute_openlineage_parent_job_info_injection_skipped_when_ol_not_accessible( + self, mock_hook, mock_ol_accessible + ): + mock_ol_accessible.return_value = False + + template = { + "id": "test-workflow", + "placement": { + "cluster_selector": { + "zone": "europe-central2-c", + "cluster_labels": {"key": "value"}, + } + }, + "jobs": [ + { + "step_id": "job_1", + "pyspark_job": { + "main_python_file_uri": "gs://bucket1/spark_job.py", + }, + } + ], + } + + op = DataprocInstantiateInlineWorkflowTemplateOperator( + task_id=TASK_ID, + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + gcp_conn_id=GCP_CONN_ID, + impersonation_chain=IMPERSONATION_CHAIN, + openlineage_inject_parent_job_info=True, + ) + op.execute(context=MagicMock()) + mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN) + mock_hook.return_value.instantiate_inline_workflow_template.assert_called_once_with( + template=template, + region=GCP_REGION, + project_id=GCP_PROJECT, + request_id=REQUEST_ID, + retry=RETRY, + timeout=TIMEOUT, + metadata=METADATA, + ) + @pytest.mark.db_test @pytest.mark.need_serialized_dag
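Example usage (a minimal sketch, not part of the diff above; the DAG id, project, region, bucket and step names are placeholders): with this change, parent job information injection for inline workflow templates can be enabled per task through the new openlineage_inject_parent_job_info argument, which defaults to the [openlineage] spark_inject_parent_job_info configuration option (False when unset), so existing DAGs are unaffected.

from datetime import datetime

from airflow import DAG
from airflow.providers.google.cloud.operators.dataproc import (
    DataprocInstantiateInlineWorkflowTemplateOperator,
)

with DAG(dag_id="dataproc_workflow_ol_example", start_date=datetime(2025, 1, 1), schedule=None):
    # Hypothetical task: only jobs with a supported Spark job type (e.g. pyspark_job)
    # receive the spark.openlineage.parent* properties; other job types and jobs that
    # already define parent job properties are left untouched.
    instantiate_workflow = DataprocInstantiateInlineWorkflowTemplateOperator(
        task_id="instantiate_inline_workflow",
        project_id="my-project",          # placeholder
        region="europe-central2",         # placeholder
        template={
            "id": "my-workflow",          # placeholder
            "placement": {
                "cluster_selector": {"cluster_labels": {"key": "value"}},
            },
            "jobs": [
                {
                    "step_id": "spark_step",
                    "pyspark_job": {
                        "main_python_file_uri": "gs://my-bucket/spark_job.py",  # placeholder
                    },
                },
            ],
        },
        # New in this patch; enables injection for this task regardless of the
        # global [openlineage] spark_inject_parent_job_info setting.
        openlineage_inject_parent_job_info=True,
    )

This mirrors the existing parent job information support for DataprocCreateBatchOperator listed in the OpenLineage documentation template above.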