diff --git a/providers/src/airflow/providers/google/cloud/openlineage/mixins.py b/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
index ce7a14e03ae32..d94aed19cfb85 100644
--- a/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
+++ b/providers/src/airflow/providers/google/cloud/openlineage/mixins.py
@@ -19,11 +19,13 @@
 
 import copy
 import json
+import logging
 import traceback
 from typing import TYPE_CHECKING, cast
 
 if TYPE_CHECKING:
     from airflow.providers.common.compat.openlineage.facet import (
+        ColumnLineageDatasetFacet,
         Dataset,
         InputDataset,
         OutputDataset,
@@ -36,8 +38,12 @@
 
 BIGQUERY_NAMESPACE = "bigquery"
 
+log = logging.getLogger(__name__)
+
 
 class _BigQueryOpenLineageMixin:
+    """Mixin for BigQueryInsertJobOperator to extract OpenLineage metadata."""
+
     def get_openlineage_facets_on_complete(self, _):
         """
         Retrieve OpenLineage data for a COMPLETE BigQuery job.
@@ -70,8 +76,7 @@ def get_openlineage_facets_on_complete(self, _):
         from airflow.providers.openlineage.sqlparser import SQLParser
 
         if not self.job_id:
-            if hasattr(self, "log"):
-                self.log.warning("No BigQuery job_id was found by OpenLineage.")
+            self.log.warning("No BigQuery job_id was found by OpenLineage.")
             return OperatorLineage()
 
         if not self.hook:
@@ -113,8 +118,7 @@ def get_facets(self, job_id: str):
         inputs = []
         outputs = []
         run_facets: dict[str, RunFacet] = {}
-        if hasattr(self, "log"):
-            self.log.debug("Extracting data from bigquery job: `%s`", job_id)
+        self.log.debug("Extracting data from bigquery job: `%s`", job_id)
         try:
             job = self.client.get_job(job_id=job_id)  # type: ignore
             props = job._properties
@@ -125,8 +129,7 @@ def get_facets(self, job_id: str):
             run_facets["bigQueryJob"] = self._get_bigquery_job_run_facet(props)
 
             if get_from_nullable_chain(props, ["statistics", "numChildJobs"]):
-                if hasattr(self, "log"):
-                    self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")
+                self.log.debug("Found SCRIPT job. Extracting lineage from child jobs instead.")
                 # SCRIPT job type has no input / output information but spawns child jobs that have one
                 # https://cloud.google.com/bigquery/docs/information-schema-jobs#multi-statement_query_job
                 for child_job_id in self.client.list_jobs(parent_job=job_id):
@@ -138,8 +141,7 @@ def get_facets(self, job_id: str):
                 inputs, _output = self._get_inputs_outputs_from_job(props)
                 outputs.append(_output)
         except Exception as e:
-            if hasattr(self, "log"):
-                self.log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True)
+            self.log.warning("Cannot retrieve job details from BigQuery.Client. %s", e, exc_info=True)
             exception_msg = traceback.format_exc()
             run_facets.update(
                 {
@@ -152,8 +154,53 @@ def get_facets(self, job_id: str):
         deduplicated_outputs = self._deduplicate_outputs(outputs)
         return inputs, deduplicated_outputs, run_facets
 
+    @staticmethod
+    def _merge_column_lineage_facets(facets: list[ColumnLineageDatasetFacet]) -> ColumnLineageDatasetFacet:
+        """
+        Merge multiple column lineage facets into a single facet.
+
+        Note:
+            Transformation information will be lost if present.
+
+        Args:
+            facets: A list of column lineage facets to be merged.
+
+        Returns:
+            A single merged column lineage facet.
+ """ + from collections import defaultdict + + from airflow.providers.common.compat.openlineage.facet import ( + ColumnLineageDatasetFacet, + Fields, + InputField, + ) + + merged_fields: dict[str, set[tuple[str, str, str]]] = defaultdict(set) + + for facet in facets: + for field_name, field in facet.fields.items(): + merged_fields[field_name].update( + (input_field.namespace, input_field.name, input_field.field) + for input_field in field.inputFields + ) + + return ColumnLineageDatasetFacet( + fields={ + field_name: Fields( + inputFields=[ + InputField(namespace, name, column) + for namespace, name, column in sorted(input_fields) + ], + transformationType="", + transformationDescription="", + ) + for field_name, input_fields in merged_fields.items() + } + ) + def _deduplicate_outputs(self, outputs: list[OutputDataset | None]) -> list[OutputDataset]: - # Sources are the same so we can compare only names + # Namespaces are the same so we can compare only names final_outputs = {} for single_output in outputs: if not single_output: @@ -167,6 +214,21 @@ def _deduplicate_outputs(self, outputs: list[OutputDataset | None]) -> list[Outp # if the rowCount or size can be summed together. if single_output.outputFacets: single_output.outputFacets.pop("outputStatistics", None) + + # If both outputs contain Column Level Lineage Facet - merge the facets + if ( + single_output.facets + and final_outputs[key].facets + and "columnLineage" in single_output.facets # type: ignore + and "columnLineage" in final_outputs[key].facets # type: ignore + ): + single_output.facets["columnLineage"] = self._merge_column_lineage_facets( + [ + single_output.facets["columnLineage"], # type: ignore + final_outputs[key].facets["columnLineage"], # type: ignore + ] + ) + final_outputs[key] = single_output return list(final_outputs.values()) @@ -178,14 +240,22 @@ def _get_inputs_outputs_from_job( input_tables = get_from_nullable_chain(properties, ["statistics", "query", "referencedTables"]) or [] output_table = get_from_nullable_chain(properties, ["configuration", "query", "destinationTable"]) - inputs = [(self._get_input_dataset(input_table)) for input_table in input_tables] - if output_table: - output = self._get_output_dataset(output_table) - dataset_stat_facet = self._get_statistics_dataset_facet(properties) - output.outputFacets = output.outputFacets or {} - if dataset_stat_facet: - output.outputFacets["outputStatistics"] = dataset_stat_facet + inputs = [ + (self._get_input_dataset(input_table)) + for input_table in input_tables + if input_table != output_table # Output table is in `referencedTables` and needs to be removed + ] + if not output_table: + return inputs, None + + output = self._get_output_dataset(output_table) + if dataset_stat_facet := self._get_statistics_dataset_facet(properties): + output.outputFacets = output.outputFacets or {} + output.outputFacets["outputStatistics"] = dataset_stat_facet + if cll_facet := self._get_column_level_lineage_facet(properties, output, inputs): + output.facets = output.facets or {} + output.facets["columnLineage"] = cll_facet return inputs, output @staticmethod @@ -225,6 +295,71 @@ def _get_statistics_dataset_facet( return OutputStatisticsOutputDatasetFacet(rowCount=int(out_rows), size=int(out_bytes)) return None + def _get_column_level_lineage_facet( + self, properties: dict, output: OutputDataset, inputs: list[InputDataset] + ) -> ColumnLineageDatasetFacet | None: + """ + Extract column-level lineage information from a BigQuery job and return it as a facet. 
+
+        The Column Level Lineage Facet will NOT be returned if any of the following conditions is met:
+        - The parsed result does not contain column lineage information.
+        - The parsed result does not contain exactly one output table.
+        - The parsed result has a different output table than the output table from the BQ job.
+        - The parsed result has at least one input table not present in the input tables from the BQ job.
+        - The parsed result has a column not present in the schema of the given dataset from the BQ job.
+
+        Args:
+            properties: The properties of the BigQuery job.
+            output: The output dataset for which the column lineage is being extracted.
+            inputs: The input datasets of the BigQuery job.
+
+        Returns:
+            The extracted Column Lineage Dataset Facet, or None if conditions are not met.
+        """
+        from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain
+        from airflow.providers.openlineage.sqlparser import SQLParser
+
+        # Extract SQL query and parse it
+        self.log.debug("Extracting column-level lineage facet from BigQuery query.")
+        query = get_from_nullable_chain(properties, ["configuration", "query", "query"]) or ""
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string(SQLParser.normalize_sql(query)))
+
+        if parse_result is None or parse_result.column_lineage == []:
+            self.log.debug("No column-level lineage found in the SQL query. Facet generation skipped.")
+            return None
+
+        default_dataset, default_project = self._extract_default_dataset_and_project(
+            properties,
+            self.project_id,  # type: ignore[attr-defined]
+        )
+
+        # Verify if the output table id from the parse result matches the BQ job output table
+        if not self._validate_output_table_id(
+            parse_result,
+            output,
+            default_project,
+            default_dataset,
+        ):
+            return None
+
+        # Verify if all columns from parse results are present in the output dataset schema
+        if not self._validate_output_columns(parse_result, output):
+            return None
+
+        input_tables_from_parse_result = self._extract_parsed_input_tables(
+            parse_result, default_project, default_dataset
+        )
+        input_tables_from_bq = {input_ds.name: self._extract_column_names(input_ds) for input_ds in inputs}
+
+        # Verify if all datasets from parse results are present in bq job input datasets
+        if not self._validate_input_tables(input_tables_from_parse_result, input_tables_from_bq):
+            return None
+
+        # Verify if all columns from parse results are present in their respective bq job input datasets
+        if not self._validate_input_columns(input_tables_from_parse_result, input_tables_from_bq):
+            return None
+
+        return self._generate_column_lineage_facet(parse_result, default_project, default_dataset)
+
     def _get_input_dataset(self, table: dict) -> InputDataset:
         from airflow.providers.common.compat.openlineage.facet import InputDataset
 
@@ -273,8 +408,7 @@ def _get_table_schema_safely(self, table_name: str) -> SchemaDatasetFacet | None
         try:
             return self._get_table_schema(table_name)
         except Exception as e:
-            if hasattr(self, "log"):
-                self.log.warning("Could not extract output schema from bigquery. %s", e)
+            self.log.warning("Could not extract output schema from bigquery. %s", e)
             return None
 
     def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None:
@@ -303,3 +437,155 @@ def _get_table_schema(self, table: str) -> SchemaDatasetFacet | None:
                 for field in fields
             ]
         )
+
+    @staticmethod
+    def _get_qualified_name_from_parse_result(table, default_project: str, default_dataset: str) -> str:
+        """Get the qualified name of a table from the parse result."""
+        return ".".join(
+            (
+                table.database or default_project,
+                table.schema or default_dataset,
+                table.name,
+            )
+        )
+
+    @staticmethod
+    def _extract_default_dataset_and_project(properties: dict, default_project: str) -> tuple[str, str]:
+        """Extract the default dataset and project from the BigQuery job properties."""
+        from airflow.providers.google.cloud.openlineage.utils import get_from_nullable_chain
+
+        default_dataset_obj = get_from_nullable_chain(
+            properties, ["configuration", "query", "defaultDataset"]
+        )
+        default_dataset = default_dataset_obj.get("datasetId", "") if default_dataset_obj else ""
+        default_project = (
+            default_dataset_obj.get("projectId", default_project) if default_dataset_obj else default_project
+        )
+        return default_dataset, default_project
+
+    def _validate_output_table_id(
+        self, parse_result, output: OutputDataset, default_project: str, default_dataset: str
+    ) -> bool:
+        """Check if the output table id from the parse result matches the BQ job output table."""
+        if len(parse_result.out_tables) != 1:
+            self.log.debug(
+                "Invalid output tables in the parse result: `%s`. Expected exactly one output table.",
+                parse_result.out_tables,
+            )
+            return False
+
+        parsed_output_table = self._get_qualified_name_from_parse_result(
+            parse_result.out_tables[0], default_project, default_dataset
+        )
+        if parsed_output_table != output.name:
+            self.log.debug(
+                "Mismatch between parsed output table `%s` and BQ job output table `%s`.",
+                parsed_output_table,
+                output.name,
+            )
+            return False
+        return True
+
+    @staticmethod
+    def _extract_column_names(dataset: Dataset) -> list[str]:
+        """Extract column names from a dataset's schema."""
+        from airflow.providers.common.compat.openlineage.facet import SchemaDatasetFacet
+
+        return [
+            f.name
+            for f in dataset.facets.get("schema", SchemaDatasetFacet(fields=[])).fields  # type: ignore[union-attr]
+            if dataset.facets
+        ]
+
+    def _validate_output_columns(self, parse_result, output: OutputDataset) -> bool:
+        """Validate if all descendant columns in parse result exist in output dataset schema."""
+        output_column_names = self._extract_column_names(output)
+        missing_columns = [
+            lineage.descendant.name
+            for lineage in parse_result.column_lineage
+            if lineage.descendant.name not in output_column_names
+        ]
+        if missing_columns:
+            self.log.debug(
+                "Output dataset schema is missing columns from the parse result: `%s`.", missing_columns
+            )
+            return False
+        return True
+
+    def _extract_parsed_input_tables(
+        self, parse_result, default_project: str, default_dataset: str
+    ) -> dict[str, list[str]]:
+        """Extract input tables and their columns from the parse result."""
+        input_tables: dict[str, list[str]] = {}
+        for lineage in parse_result.column_lineage:
+            for column_meta in lineage.lineage:
+                if not column_meta.origin:
+                    self.log.debug(
+                        "Column `%s` lacks origin information. Skipping facet generation.", column_meta.name
+                    )
+                    return {}
+
+                input_table_id = self._get_qualified_name_from_parse_result(
+                    column_meta.origin, default_project, default_dataset
+                )
+                input_tables.setdefault(input_table_id, []).append(column_meta.name)
+        return input_tables
+
+    def _validate_input_tables(
+        self, parsed_input_tables: dict[str, list[str]], input_tables_from_bq: dict[str, list[str]]
+    ) -> bool:
+        """Validate if all parsed input tables exist in the BQ job's input datasets."""
+        if not parsed_input_tables:
+            self.log.debug("No input tables found in the parse result. Facet generation skipped.")
+            return False
+        if missing_tables := set(parsed_input_tables) - set(input_tables_from_bq):
+            self.log.debug(
+                "Parsed input tables not found in the BQ job's input datasets: `%s`.", missing_tables
+            )
+            return False
+        return True
+
+    def _validate_input_columns(
+        self, parsed_input_tables: dict[str, list[str]], input_tables_from_bq: dict[str, list[str]]
+    ) -> bool:
+        """Validate if all parsed input columns exist in their respective BQ job input table schemas."""
+        if not parsed_input_tables:
+            self.log.debug("No input tables found in the parse result. Facet generation skipped.")
+            return False
+        for table, columns in parsed_input_tables.items():
+            if missing_columns := set(columns) - set(input_tables_from_bq.get(table, [])):
+                self.log.debug(
+                    "Input table `%s` is missing columns from the parse result: `%s`.", table, missing_columns
+                )
+                return False
+        return True
+
+    def _generate_column_lineage_facet(
+        self, parse_result, default_project: str, default_dataset: str
+    ) -> ColumnLineageDatasetFacet:
+        """Generate the ColumnLineageDatasetFacet based on the parsed result."""
+        from airflow.providers.common.compat.openlineage.facet import (
+            ColumnLineageDatasetFacet,
+            Fields,
+            InputField,
+        )
+
+        return ColumnLineageDatasetFacet(
+            fields={
+                lineage.descendant.name: Fields(
+                    inputFields=[
+                        InputField(
+                            namespace=BIGQUERY_NAMESPACE,
+                            name=self._get_qualified_name_from_parse_result(
+                                column_meta.origin, default_project, default_dataset
+                            ),
+                            field=column_meta.name,
+                        )
+                        for column_meta in lineage.lineage
+                    ],
+                    transformationType="",
+                    transformationDescription="",
+                )
+                for lineage in parse_result.column_lineage
+            }
+        )
diff --git a/providers/tests/google/cloud/openlineage/test_mixins.py b/providers/tests/google/cloud/openlineage/test_mixins.py
index fb047ddc2d1c8..17a00a806d41f 100644
--- a/providers/tests/google/cloud/openlineage/test_mixins.py
+++ b/providers/tests/google/cloud/openlineage/test_mixins.py
@@ -16,15 +16,21 @@
 # under the License.
 
 from __future__ import annotations
 
+import copy
 import json
+import logging
 import os
 from unittest.mock import MagicMock
 
 import pytest
 
 from airflow.providers.common.compat.openlineage.facet import (
+    ColumnLineageDatasetFacet,
+    Dataset,
     ExternalQueryRunFacet,
+    Fields,
     InputDataset,
+    InputField,
     OutputDataset,
     OutputStatisticsOutputDatasetFacet,
     SchemaDatasetFacet,
@@ -34,6 +40,69 @@
 from airflow.providers.google.cloud.openlineage.utils import (
     BigQueryJobRunFacet,
 )
+from airflow.providers.openlineage.sqlparser import SQLParser
+
+QUERY_JOB_PROPERTIES = {
+    "configuration": {
+        "query": {
+            "query": """
+                INSERT INTO dest_project.dest_dataset.dest_table
+                SELECT a, b, c FROM source_project.source_dataset.source_table
+                UNION ALL
+                SELECT a, b, c FROM source_table2
+            """,
+            "defaultDataset": {"datasetId": "default_dataset", "projectId": "default_project"},
+        }
+    }
+}
+OUTPUT_DATASET = OutputDataset(
+    namespace="bigquery",
+    name="dest_project.dest_dataset.dest_table",
+    facets={
+        "schema": SchemaDatasetFacet(
+            fields=[
+                SchemaDatasetFacetFields("a", "STRING"),
+                SchemaDatasetFacetFields("b", "STRING"),
+                SchemaDatasetFacetFields("c", "STRING"),
+                SchemaDatasetFacetFields("d", "STRING"),
+                SchemaDatasetFacetFields("e", "STRING"),
+                SchemaDatasetFacetFields("f", "STRING"),
+                SchemaDatasetFacetFields("g", "STRING"),
+            ]
+        )
+    },
+)
+INPUT_DATASETS = [
+    InputDataset(
+        namespace="bigquery",
+        name="source_project.source_dataset.source_table",
+        facets={
+            "schema": SchemaDatasetFacet(
+                fields=[
+                    SchemaDatasetFacetFields("a", "STRING"),
+                    SchemaDatasetFacetFields("b", "STRING"),
+                    SchemaDatasetFacetFields("c", "STRING"),
+                    SchemaDatasetFacetFields("x", "STRING"),
+                ]
+            )
+        },
+    ),
+    InputDataset(
+        namespace="bigquery",
+        name="default_project.default_dataset.source_table2",
+        facets={
+            "schema": SchemaDatasetFacet(
+                fields=[
+                    SchemaDatasetFacetFields("a", "STRING"),
+                    SchemaDatasetFacetFields("b", "STRING"),
+                    SchemaDatasetFacetFields("c", "STRING"),
+                    SchemaDatasetFacetFields("y", "STRING"),
+                ]
+            )
+        },
+    ),
+    InputDataset("bigquery", "some.random.tb"),
+]
 
 
 def read_common_json_file(rel: str):
@@ -64,6 +133,7 @@ def setup_method(self):
         class BQOperator(_BigQueryOpenLineageMixin):
             sql = ""
             job_id = "job_id"
+            project_id = "project_id"
             location = None
 
             @property
@@ -202,6 +272,100 @@ def test_deduplicate_outputs(self):
         assert second_result.name == "d2"
         assert second_result.facets == {"t20": "t20"}
 
+    def test_deduplicate_outputs_with_cll(self):
+        outputs = [
+            None,
+            OutputDataset(
+                name="a.b.c",
+                namespace="bigquery",
+                facets={
+                    "columnLineage": ColumnLineageDatasetFacet(
+                        fields={
+                            "c": Fields(
+                                inputFields=[InputField("bigquery", "a.b.1", "c")],
+                                transformationType="",
+                                transformationDescription="",
+                            ),
+                            "d": Fields(
+                                inputFields=[InputField("bigquery", "a.b.2", "d")],
+                                transformationType="",
+                                transformationDescription="",
+                            ),
+                        }
+                    )
+                },
+            ),
+            OutputDataset(
+                name="a.b.c",
+                namespace="bigquery",
+                facets={
+                    "columnLineage": ColumnLineageDatasetFacet(
+                        fields={
+                            "c": Fields(
+                                inputFields=[InputField("bigquery", "a.b.3", "x")],
+                                transformationType="",
+                                transformationDescription="",
+                            ),
+                            "e": Fields(
+                                inputFields=[InputField("bigquery", "a.b.1", "e")],
+                                transformationType="",
+                                transformationDescription="",
+                            ),
+                        }
+                    )
+                },
+            ),
+            OutputDataset(
+                name="x.y.z",
+                namespace="bigquery",
+                facets={
+                    "columnLineage": ColumnLineageDatasetFacet(
+                        fields={
+                            "c": Fields(
+                                inputFields=[InputField("bigquery", "a.b.3", "x")],
+                                transformationType="",
+                                transformationDescription="",
+                            )
+                        }
+                    )
+                },
+            ),
+        ]
+        result = self.operator._deduplicate_outputs(outputs)
+        assert len(result) == 2
+        first_result = result[0]
+        assert first_result.name == "a.b.c"
+        assert first_result.facets["columnLineage"] == ColumnLineageDatasetFacet(
+            fields={
+                "c": Fields(
+                    inputFields=[InputField("bigquery", "a.b.1", "c"), InputField("bigquery", "a.b.3", "x")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+                "d": Fields(
+                    inputFields=[InputField("bigquery", "a.b.2", "d")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+                "e": Fields(
+                    inputFields=[InputField("bigquery", "a.b.1", "e")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+            }
+        )
+        second_result = result[1]
+        assert second_result.name == "x.y.z"
+        assert second_result.facets["columnLineage"] == ColumnLineageDatasetFacet(
+            fields={
+                "c": Fields(
+                    inputFields=[InputField("bigquery", "a.b.3", "x")],
+                    transformationType="",
+                    transformationDescription="",
+                )
+            }
+        )
+
     @pytest.mark.parametrize("cache", (None, "false", False, 0))
     def test_get_job_run_facet_no_cache_and_with_bytes(self, cache):
         properties = {
@@ -259,3 +423,348 @@ def test_get_statistics_dataset_facet_with_stats(self):
         result = self.operator._get_statistics_dataset_facet(properties)
         assert result.rowCount == 123
         assert result.size == 321
+
+    def test_get_column_level_lineage_facet(self):
+        result = self.operator._get_column_level_lineage_facet(
+            QUERY_JOB_PROPERTIES, OUTPUT_DATASET, INPUT_DATASETS
+        )
+        assert result == ColumnLineageDatasetFacet(
+            fields={
+                col: Fields(
+                    inputFields=[
+                        InputField("bigquery", "default_project.default_dataset.source_table2", col),
+                        InputField("bigquery", "source_project.source_dataset.source_table", col),
+                    ],
+                    transformationType="",
+                    transformationDescription="",
+                )
+                for col in ("a", "b", "c")
+            }
+        )
+
+    def test_get_column_level_lineage_facet_early_exit_empty_cll_from_parser(self):
+        properties = {"configuration": {"query": {"query": "SELECT 1"}}}
+        assert (
+            self.operator._get_column_level_lineage_facet(properties, OUTPUT_DATASET, INPUT_DATASETS) is None
+        )
+        assert self.operator._get_column_level_lineage_facet({}, OUTPUT_DATASET, INPUT_DATASETS) is None
+
+    def test_get_column_level_lineage_facet_early_exit_output_table_id_mismatch(self):
+        output = copy.deepcopy(OUTPUT_DATASET)
+        output.name = "different.name.table"
+        assert (
+            self.operator._get_column_level_lineage_facet(QUERY_JOB_PROPERTIES, output, INPUT_DATASETS)
+            is None
+        )
+
+    def test_get_column_level_lineage_facet_early_exit_output_columns_mismatch(self):
+        output = copy.deepcopy(OUTPUT_DATASET)
+        output.facets["schema"].fields = [
+            SchemaDatasetFacetFields("different_col", "STRING"),
+        ]
+        assert (
+            self.operator._get_column_level_lineage_facet(QUERY_JOB_PROPERTIES, output, INPUT_DATASETS)
+            is None
+        )
+
+    def test_get_column_level_lineage_facet_early_exit_wrong_parsed_input_tables(self):
+        properties = {
+            "configuration": {
+                "query": {
+                    "query": """
+                    INSERT INTO dest_project.dest_dataset.dest_table
+                    SELECT a, b, c FROM some.wrong.source_table
+                    """,
+                }
+            }
+        }
+        assert (
+            self.operator._get_column_level_lineage_facet(properties, OUTPUT_DATASET, INPUT_DATASETS) is None
+        )
+
+    def test_get_column_level_lineage_facet_early_exit_wrong_parsed_input_columns(self):
+        properties = {
+            "configuration": {
+                "query": {
+                    "query": """
+                    INSERT INTO dest_project.dest_dataset.dest_table
+                    SELECT wrong_col, wrong2, wrong3 FROM source_project.source_dataset.source_table
+                    """,
+                }
+            }
+        }
+        assert (
+            self.operator._get_column_level_lineage_facet(properties, OUTPUT_DATASET, INPUT_DATASETS) is None
+        )
+
+    def test_operator_log(self):
+        class OperatorWithLog(_BigQueryOpenLineageMixin):
+            log = logging.getLogger("test")
+
+        assert isinstance(OperatorWithLog().log, logging.Logger)
+
+    def test_get_qualified_name_from_parse_result(self):
+        class _Table:  # Replacement for SQL parser TableMeta
+            database = "project"
+            schema = "dataset"
+            name = "table"
+
+        class _TableNoSchema:  # Replacement for SQL parser TableMeta
+            database = None
+            schema = "dataset"
+            name = "table"
+
+        class _TableNoSchemaNoDb:  # Replacement for SQL parser TableMeta
+            database = None
+            schema = None
+            name = "table"
+
+        result = self.operator._get_qualified_name_from_parse_result(
+            table=_Table(),
+            default_project="default_project",
+            default_dataset="default_dataset",
+        )
+        assert result == "project.dataset.table"
+
+        result = self.operator._get_qualified_name_from_parse_result(
+            table=_TableNoSchema(),
+            default_project="default_project",
+            default_dataset="default_dataset",
+        )
+        assert result == "default_project.dataset.table"
+
+        result = self.operator._get_qualified_name_from_parse_result(
+            table=_TableNoSchemaNoDb(),
+            default_project="default_project",
+            default_dataset="default_dataset",
+        )
+        assert result == "default_project.default_dataset.table"
+
+    def test_extract_default_dataset_and_project(self):
+        properties = {"configuration": {"query": {"defaultDataset": {"datasetId": "default_dataset"}}}}
+        result = self.operator._extract_default_dataset_and_project(properties, "default_project")
+        assert result == ("default_dataset", "default_project")
+
+        properties = {
+            "configuration": {
+                "query": {"defaultDataset": {"datasetId": "default_dataset", "projectId": "default_project"}}
+            }
+        }
+        result = self.operator._extract_default_dataset_and_project(properties, "another_project")
+        assert result == ("default_dataset", "default_project")
+
+        result = self.operator._extract_default_dataset_and_project({}, "default_project")
+        assert result == ("", "default_project")
+
+    def test_validate_output_table_id_no_table(self):
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("SELECT 1"))
+        assert parse_result.out_tables == []
+        assert self.operator._validate_output_table_id(parse_result, None, None, None) is False
+
+    def test_validate_output_table_id_multiple_tables(self):
+        query = "INSERT INTO a.b.c VALUES (1); INSERT INTO d.e.f VALUES (2);"
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string(query))
+        assert len(parse_result.out_tables) == 2
+        assert self.operator._validate_output_table_id(parse_result, None, None, None) is False
+
+    def test_validate_output_table_id_mismatch(self):
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("INSERT INTO a.b.c VALUES (1)"))
+        assert len(parse_result.out_tables) == 1
+        assert parse_result.out_tables[0].qualified_name == "a.b.c"
+        assert (
+            self.operator._validate_output_table_id(parse_result, OutputDataset("", "d.e.f"), None, None)
+            is False
+        )
+
+    def test_validate_output_table_id(self):
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("INSERT INTO a.b.c VALUES (1)"))
+        assert len(parse_result.out_tables) == 1
+        assert parse_result.out_tables[0].qualified_name == "a.b.c"
+        assert (
+            self.operator._validate_output_table_id(parse_result, OutputDataset("", "a.b.c"), None, None)
+            is True
+        )
+
+    def test_validate_output_table_id_query_with_table_name_only(self):
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("INSERT INTO c VALUES (1)"))
+        assert len(parse_result.out_tables) == 1
+        assert parse_result.out_tables[0].qualified_name == "c"
+        assert (
+            self.operator._validate_output_table_id(parse_result, OutputDataset("", "a.b.c"), "a", "b")
+            is True
+        )
+
+    def test_extract_column_names_dataset_without_schema(self):
+        assert self.operator._extract_column_names(Dataset("a", "b")) == []
+
+    def test_extract_column_names_dataset_with_schema(self):
+        ds = Dataset(
+            "a",
+            "b",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaDatasetFacetFields("col1", "STRING"),
+                        SchemaDatasetFacetFields("col2", "STRING"),
+                    ]
+                )
+            },
+        )
+        assert self.operator._extract_column_names(ds) == ["col1", "col2"]
+
+    def test_validate_output_columns_mismatch(self):
+        ds = OutputDataset(
+            "a",
+            "b",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaDatasetFacetFields("col1", "STRING"),
+                        SchemaDatasetFacetFields("col2", "STRING"),
+                    ]
+                )
+            },
+        )
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("SELECT a , b FROM c"))
+        assert self.operator._validate_output_columns(parse_result, ds) is False
+
+    def test_validate_output_columns(self):
+        ds = OutputDataset(
+            "a",
+            "b",
+            facets={
+                "schema": SchemaDatasetFacet(
+                    fields=[
+                        SchemaDatasetFacetFields("a", "STRING"),
+                        SchemaDatasetFacetFields("b", "STRING"),
+                    ]
+                )
+            },
+        )
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("SELECT a , b FROM c"))
+        assert self.operator._validate_output_columns(parse_result, ds) is True
+
+    def test_extract_parsed_input_tables(self):
+        query = "INSERT INTO x SELECT a, b from project1.ds1.tb1; INSERT INTO y SELECT c, d from tb2;"
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string(query))
+        assert self.operator._extract_parsed_input_tables(parse_result, "default_project", "default_ds") == {
+            "project1.ds1.tb1": ["a", "b"],
+            "default_project.default_ds.tb2": ["c", "d"],
+        }
+
+    def test_extract_parsed_input_tables_no_cll(self):
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string("SELECT 1"))
+        assert self.operator._extract_parsed_input_tables(parse_result, "p", "d") == {}
+
+    def test_validate_input_tables_mismatch(self):
+        result = self.operator._validate_input_tables({"a": None, "b": None}, {"a": None, "c": None})
+        assert result is False
+
+    def test_validate_input_tables_bq_has_more_tables(self):
+        result = self.operator._validate_input_tables({"a": None}, {"a": None, "c": None})
+        assert result is True
+
+    def test_validate_input_tables_empty(self):
+        result = self.operator._validate_input_tables({}, {"a": None, "c": None})
+        assert result is False
+
+    def test_validate_input_columns_mismatch(self):
+        result = self.operator._validate_input_columns(
+            {"a": ["1", "2"], "b": ["3", "4"]}, {"a": ["1", "2", "3"], "c": ["4", "5"]}
+        )
+        assert result is False
+
+    def test_validate_input_columns_bq_has_more_cols(self):
+        result = self.operator._validate_input_columns(
+            {"a": ["1", "2"]}, {"a": ["1", "2", "3"], "c": ["4", "5"]}
+        )
+        assert result is True
+
+    def test_validate_input_columns_empty(self):
+        result = self.operator._validate_input_columns({}, {"a": ["1", "2", "3"], "c": ["4", "5"]})
+        assert result is False
+
+    def test_generate_column_lineage_facet(self):
+        query = "INSERT INTO b.c SELECT c, d from tb2;"
+        parse_result = SQLParser("bigquery").parse(SQLParser.split_sql_string(query))
+        result = self.operator._generate_column_lineage_facet(parse_result, "default_project", "default_ds")
+        assert result == ColumnLineageDatasetFacet(
+            fields={
+                "c": Fields(
+                    inputFields=[InputField("bigquery", "default_project.default_ds.tb2", "c")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+                "d": Fields(
+                    inputFields=[InputField("bigquery", "default_project.default_ds.tb2", "d")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+            }
+        )
+
+    def test_merge_column_lineage_facets(self):
+        result = self.operator._merge_column_lineage_facets(
+            [
+                ColumnLineageDatasetFacet(
+                    fields={
+                        "c": Fields(
+                            inputFields=[InputField("bigquery", "a.b.1", "c")],
+                            transformationType="",
+                            transformationDescription="",
+                        ),
+                        "d": Fields(
+                            inputFields=[InputField("bigquery", "a.b.2", "d")],
+                            transformationType="",
+                            transformationDescription="",
+                        ),
+                    }
+                ),
+                ColumnLineageDatasetFacet(
+                    fields={
+                        "c": Fields(
+                            inputFields=[InputField("bigquery", "a.b.3", "x")],
+                            transformationType="",
+                            transformationDescription="",
+                        ),
+                        "e": Fields(
+                            inputFields=[InputField("bigquery", "a.b.1", "e")],
+                            transformationType="",
+                            transformationDescription="",
+                        ),
+                    }
+                ),
+                ColumnLineageDatasetFacet(
+                    fields={
+                        "c": Fields(
+                            inputFields=[InputField("bigquery", "a.b.3", "x")],
+                            transformationType="",
+                            transformationDescription="",
+                        )
+                    }
+                ),
+            ]
+        )
+        assert result == ColumnLineageDatasetFacet(
+            fields={
+                "c": Fields(
+                    inputFields=[InputField("bigquery", "a.b.1", "c"), InputField("bigquery", "a.b.3", "x")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+                "d": Fields(
+                    inputFields=[InputField("bigquery", "a.b.2", "d")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+                "e": Fields(
+                    inputFields=[InputField("bigquery", "a.b.1", "e")],
+                    transformationType="",
+                    transformationDescription="",
+                ),
+            }
+        )
diff --git a/providers/tests/google/cloud/utils/job_details.json b/providers/tests/google/cloud/utils/job_details.json
index f12ec1321d57f..b533e0faa6704 100644
--- a/providers/tests/google/cloud/utils/job_details.json
+++ b/providers/tests/google/cloud/utils/job_details.json
@@ -225,11 +225,18 @@
         "billingTier": 1,
         "totalSlotMs": "825",
         "cacheHit": false,
-        "referencedTables": [{
-            "projectId": "airflow-openlineage",
-            "datasetId": "new_dataset",
-            "tableId": "test_table"
-        }],
+        "referencedTables": [
+            {
+                "projectId": "airflow-openlineage",
+                "datasetId": "new_dataset",
+                "tableId": "test_table"
+            },
+            {
+                "projectId": "airflow-openlineage",
+                "datasetId": "new_dataset",
+                "tableId": "output_table"
+            }
+        ],
         "statementType": "SELECT"
     },
     "totalSlotMs": "825"
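
A minimal usage sketch (not part of the diff) of the new _merge_column_lineage_facets helper, showing how columnLineage facets from two child jobs writing to the same output table are combined. It assumes the google and common.compat providers from this change are installed; the table and column names below are illustrative only.

# Sketch only: dataset/column names are hypothetical.
from airflow.providers.common.compat.openlineage.facet import (
    ColumnLineageDatasetFacet,
    Fields,
    InputField,
)
from airflow.providers.google.cloud.openlineage.mixins import _BigQueryOpenLineageMixin

# Two facets describing lineage of the same output column "total",
# e.g. produced by two child jobs of a SCRIPT job.
facet_a = ColumnLineageDatasetFacet(
    fields={
        "total": Fields(
            inputFields=[InputField("bigquery", "proj.ds.orders", "amount")],
            transformationType="",
            transformationDescription="",
        )
    }
)
facet_b = ColumnLineageDatasetFacet(
    fields={
        "total": Fields(
            inputFields=[InputField("bigquery", "proj.ds.refunds", "amount")],
            transformationType="",
            transformationDescription="",
        )
    }
)

# Static method, so it can be called on the class directly.
merged = _BigQueryOpenLineageMixin._merge_column_lineage_facets([facet_a, facet_b])

# inputFields are de-duplicated and sorted; transformation metadata is dropped.
assert [f.name for f in merged.fields["total"].inputFields] == [
    "proj.ds.orders",
    "proj.ds.refunds",
]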