Commit
Merge branch 'main' into release-1.8.1
tatiana authored Dec 30, 2024
2 parents d5da609 + 7dc9411 commit fe55831
Showing 13 changed files with 306 additions and 12 deletions.
73 changes: 72 additions & 1 deletion cosmos/airflow/graph.py
@@ -1,5 +1,6 @@
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Union

from airflow.models import BaseOperator
@@ -74,6 +75,26 @@ def calculate_leaves(tasks_ids: list[str], nodes: dict[str, DbtNode]) -> list[st
return leaves


def exclude_detached_tests_if_needed(
node: DbtNode,
task_args: dict[str, str],
detached_from_parent: dict[str, DbtNode] | None = None,
) -> None:
"""
Add exclude statements if there are tests associated with the model that should be run detached from the model/tests.
Changes task_args in place.
"""
if detached_from_parent is None:
detached_from_parent = {}
exclude: list[str] = task_args.get("exclude", []) # type: ignore
tests_detached_from_this_node: list[DbtNode] = detached_from_parent.get(node.unique_id, []) # type: ignore
for test_node in tests_detached_from_this_node:
exclude.append(test_node.resource_name.split(".")[0])
if exclude:
task_args["exclude"] = exclude # type: ignore


def create_test_task_metadata(
test_task_name: str,
execution_mode: ExecutionMode,
@@ -82,6 +103,7 @@ def create_test_task_metadata(
on_warning_callback: Callable[..., Any] | None = None,
node: DbtNode | None = None,
render_config: RenderConfig | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
) -> TaskMetadata:
"""
Create the metadata that will be used to instantiate the Airflow Task that will be used to run the Dbt test node.
Expand All @@ -92,11 +114,13 @@ def create_test_task_metadata(
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List.
:param node: If the test relates to a specific node, the node reference
:param detached_from_parent: Dictionary mapping node ids to the child tests that should be run detached from them
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
task_args = dict(task_args)
task_args["on_warning_callback"] = on_warning_callback
extra_context = {}
detached_from_parent = detached_from_parent or {}

task_owner = ""
airflow_task_config = {}
@@ -119,6 +143,9 @@ def create_task_metadata(
task_args["selector"] = render_config.selector
task_args["exclude"] = render_config.exclude

if node:
exclude_detached_tests_if_needed(node, task_args, detached_from_parent)

return TaskMetadata(
id=test_task_name,
owner=task_owner,
@@ -192,6 +219,7 @@ def create_task_metadata(
normalize_task_id: Callable[..., Any] | None = None,
test_behavior: TestBehavior = TestBehavior.AFTER_ALL,
on_warning_callback: Callable[..., Any] | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
) -> TaskMetadata | None:
"""
Create the metadata that will be used to instantiate the Airflow Task used to run the Dbt node.
@@ -205,6 +233,7 @@ def create_task_metadata(
If it is False, then use the name as a prefix for the task id, otherwise do not.
:param on_warning_callback: A callback function called on warnings with additional Context variables “test_names”
and “test_results” of type List. This param is available for the dbt test and dbt source freshness commands.
:param detached_from_parent: Dictionary mapping node ids to the child tests that should be run detached from them
:returns: The metadata necessary to instantiate the source dbt node as an Airflow task.
"""
dbt_resource_to_class = create_dbt_resource_to_class(test_behavior)
@@ -218,6 +247,7 @@ def create_task_metadata(
}

if test_behavior == TestBehavior.BUILD and node.resource_type in SUPPORTED_BUILD_RESOURCES:
exclude_detached_tests_if_needed(node, args, detached_from_parent)
task_id, args = _get_task_id_and_args(
node, args, use_task_group, normalize_task_id, "build", include_resource_type=True
)
@@ -268,6 +298,17 @@ def create_task_metadata(
return None


def is_detached_test(node: DbtNode) -> bool:
"""
Identify whether the node should be rendered detached from its parent. Conditions that should be met:
* is a test
* has multiple parents
"""
if node.resource_type == DbtResourceType.TEST and len(node.depends_on) > 1:
return True
return False


def generate_task_or_group(
dag: DAG,
task_group: TaskGroup | None,
@@ -279,9 +320,11 @@ def generate_task_or_group(
test_indirect_selection: TestIndirectSelection,
on_warning_callback: Callable[..., Any] | None,
normalize_task_id: Callable[..., Any] | None = None,
detached_from_parent: dict[str, DbtNode] | None = None,
**kwargs: Any,
) -> BaseOperator | TaskGroup | None:
task_or_group: BaseOperator | TaskGroup | None = None
detached_from_parent = detached_from_parent or {}

use_task_group = (
node.resource_type in TESTABLE_DBT_RESOURCES
@@ -299,12 +342,13 @@ def generate_task_or_group(
normalize_task_id=normalize_task_id,
test_behavior=test_behavior,
on_warning_callback=on_warning_callback,
detached_from_parent=detached_from_parent,
)

# In most cases, we'll map one DBT node to one Airflow task
# The exception is test nodes, since it would be too slow to run test tasks individually.
# If test_behaviour=="after_each", each model task will be bundled with a test task, using TaskGroup
if task_meta and node.resource_type != DbtResourceType.TEST:
if task_meta and not node.resource_type == DbtResourceType.TEST:
if use_task_group:
with TaskGroup(dag=dag, group_id=node.name, parent_group=task_group) as model_task_group:
task = create_airflow_task(task_meta, dag, task_group=model_task_group)
@@ -315,12 +359,14 @@ def generate_task_or_group(
task_args=task_args,
node=node,
on_warning_callback=on_warning_callback,
detached_from_parent=detached_from_parent,
)
test_task = create_airflow_task(test_meta, dag, task_group=model_task_group)
task >> test_task
task_or_group = model_task_group
else:
task_or_group = create_airflow_task(task_meta, dag, task_group=task_group)

return task_or_group


@@ -405,6 +451,16 @@ def build_airflow_graph(
tasks_map: dict[str, Union[TaskGroup, BaseOperator]] = {}
task_or_group: TaskGroup | BaseOperator

# Identify test nodes that should be run detached from the associated dbt resource nodes because they
# have multiple parents
detached_from_parent = defaultdict(list)
detached_nodes = {}
for node_id, node in nodes.items():
if is_detached_test(node):
detached_nodes[node_id] = node
for parent_id in node.depends_on:
detached_from_parent[parent_id].append(node)

for node_id, node in nodes.items():
conversion_function = node_converters.get(node.resource_type, generate_task_or_group)
if conversion_function != generate_task_or_group:
@@ -425,11 +481,26 @@ def build_airflow_graph(
on_warning_callback=on_warning_callback,
normalize_task_id=normalize_task_id,
node=node,
detached_from_parent=detached_from_parent,
)
if task_or_group is not None:
logger.debug(f"Conversion of <{node.unique_id}> was successful!")
tasks_map[node_id] = task_or_group

# Handle detached test nodes
for node_id, node in detached_nodes.items():
test_meta = create_test_task_metadata(
f"{node.resource_name.split('.')[0]}_test",
execution_mode,
test_indirect_selection,
task_args=task_args,
on_warning_callback=on_warning_callback,
render_config=render_config,
node=node,
)
test_task = create_airflow_task(test_meta, dag, task_group=task_group)
tasks_map[node_id] = test_task

# If test_behaviour=="after_all", there will be one test task, run by the end of the DAG
# The end of a DAG is defined by the DAG leaf tasks (tasks which do not have downstream tasks)
if test_behavior == TestBehavior.AFTER_ALL:
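
To make the detached-test handling introduced in cosmos/airflow/graph.py concrete, here is a minimal, self-contained sketch. FakeNode and the sample node ids are illustrative stand-ins for Cosmos' DbtNode and dbt's generated unique ids, not part of this commit; the logic mirrors is_detached_test, the grouping loop added to build_airflow_graph, and exclude_detached_tests_if_needed.

from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class FakeNode:
    # Illustrative stand-in for cosmos.dbt.graph.DbtNode (only the fields used here).
    unique_id: str
    resource_type: str
    resource_name: str
    depends_on: list = field(default_factory=list)


def is_detached_test(node: FakeNode) -> bool:
    # A test with more than one parent should be rendered as its own task.
    return node.resource_type == "test" and len(node.depends_on) > 1


nodes = {
    "model.my_dbt_project.model_a": FakeNode("model.my_dbt_project.model_a", "model", "model_a"),
    "model.my_dbt_project.combined_model": FakeNode("model.my_dbt_project.combined_model", "model", "combined_model"),
    "test.my_dbt_project.custom_test_combined_model": FakeNode(
        "test.my_dbt_project.custom_test_combined_model",
        "test",
        "custom_test_combined_model",
        depends_on=["model.my_dbt_project.model_a", "model.my_dbt_project.combined_model"],
    ),
}

# Mirrors build_airflow_graph: collect detached tests and index them by parent.
detached_from_parent = defaultdict(list)
detached_nodes = {}
for node_id, node in nodes.items():
    if is_detached_test(node):
        detached_nodes[node_id] = node
        for parent_id in node.depends_on:
            detached_from_parent[parent_id].append(node)

# Mirrors exclude_detached_tests_if_needed: each parent excludes its detached tests
# so they are not run again inside the parent's own test/build task.
task_args: dict = {}
parent = nodes["model.my_dbt_project.combined_model"]
exclude = task_args.get("exclude", [])
for test_node in detached_from_parent.get(parent.unique_id, []):
    exclude.append(test_node.resource_name.split(".")[0])
if exclude:
    task_args["exclude"] = exclude

print(list(detached_nodes))  # the test rendered as its own Airflow task
print(task_args)             # {'exclude': ['custom_test_combined_model']}
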
15 changes: 8 additions & 7 deletions dev/dags/dbt/jaffle_shop/models/schema.yml
@@ -39,13 +39,14 @@ models:
- not_null
description: This is a unique identifier for an order

- name: customer_id
description: Foreign key to the customers table
tests:
- not_null
- relationships:
to: ref('customers')
field: customer_id
# Comment so we don't have a standalone test relationships_orders_customer_id__customer_id__ref_customers__test
#- name: customer_id
# description: Foreign key to the customers table
# tests:
# - not_null
# - relationships:
# to: ref('customers')
# field: customer_id

- name: order_date
description: Date (UTC) that the order was placed
12 changes: 12 additions & 0 deletions dev/dags/dbt/multiple_parents_test/dbt_project.yml
@@ -0,0 +1,12 @@
name: 'my_dbt_project'
version: '1.0.0'
config-version: 2

profile: 'default'

model-paths: ["models"]
test-paths: ["tests"]

models:
my_dbt_project:
materialized: view
@@ -0,0 +1,17 @@
{% test custom_test_combined_model(model) %}
WITH source_data AS (
SELECT id FROM {{ ref('model_a') }}
),
combined_data AS (
SELECT id FROM {{ model }}
)
SELECT
s.id
FROM
source_data s
LEFT JOIN
combined_data c
ON s.id = c.id
WHERE
c.id IS NULL
{% endtest %}
16 changes: 16 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/combined_model.sql
@@ -0,0 +1,16 @@
-- Combine data from model_a and model_b
WITH model_a AS (
SELECT * FROM {{ ref('model_a') }}
),
model_b AS (
SELECT * FROM {{ ref('model_b') }}
)
SELECT
a.id,
a.name,
b.created_at
FROM
model_a AS a
JOIN
model_b AS b
ON a.id = b.id
4 changes: 4 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/model_a.sql
@@ -0,0 +1,4 @@
-- Create a simple table
SELECT 1 AS id, 'Alice' AS name
UNION ALL
SELECT 2 AS id, 'Bob' AS name
4 changes: 4 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/model_b.sql
@@ -0,0 +1,4 @@
-- Create another simple table
SELECT 1 AS id, '2024-12-25'::date AS created_at
UNION ALL
SELECT 2 AS id, '2024-12-26'::date AS created_at
32 changes: 32 additions & 0 deletions dev/dags/dbt/multiple_parents_test/models/schema.yml
@@ -0,0 +1,32 @@
version: 2

models:
- name: model_a
description: "A simple model with user data"
tests:
- unique:
column_name: id

- name: model_b
description: "A simple model with date data"
tests:
- unique:
column_name: id

- name: combined_model
description: "Combines data from model_a and model_b"
columns:
- name: id
tests:
- not_null

- name: name
tests:
- not_null

- name: created_at
tests:
- not_null

tests:
- custom_test_combined_model: {}
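
Because the custom_test_combined_model macro above selects from ref('model_a') while the test itself is declared on combined_model in this schema.yml, dbt records two parents for the generated test node. A hedged illustration of that dependency shape (ids abbreviated, not the exact manifest values):

# Illustrative only: the parents dbt would record for the generated test node.
test_depends_on = [
    "model.my_dbt_project.model_a",         # from ref('model_a') inside the test macro
    "model.my_dbt_project.combined_model",  # from the model the test is attached to
]
assert len(test_depends_on) > 1  # the condition is_detached_test() checks for
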
12 changes: 12 additions & 0 deletions dev/dags/dbt/multiple_parents_test/profiles.yml
@@ -0,0 +1,12 @@
default:
target: dev
outputs:
dev:
type: postgres
host: "{{ env_var('POSTGRES_HOST') }}"
user: "{{ env_var('POSTGRES_USER') }}"
password: "{{ env_var('POSTGRES_PASSWORD') }}"
port: "{{ env_var('POSTGRES_PORT') | int }}"
dbname: "{{ env_var('POSTGRES_DB') }}"
schema: "{{ env_var('POSTGRES_SCHEMA') }}"
threads: 4
34 changes: 34 additions & 0 deletions dev/dags/example_tests_multiple_parents.py
@@ -0,0 +1,34 @@
"""
An example DAG that uses Cosmos to render a dbt project into an Airflow DAG.
"""

import os
from datetime import datetime
from pathlib import Path

from cosmos import DbtDag, ProfileConfig, ProjectConfig
from cosmos.profiles import PostgresUserPasswordProfileMapping

DEFAULT_DBT_ROOT_PATH = Path(__file__).parent / "dbt"
DBT_ROOT_PATH = Path(os.getenv("DBT_ROOT_PATH", DEFAULT_DBT_ROOT_PATH))

profile_config = ProfileConfig(
profile_name="default",
target_name="dev",
profile_mapping=PostgresUserPasswordProfileMapping(
conn_id="example_conn",
profile_args={"schema": "public"},
disable_event_tracking=True,
),
)

example_multiple_parents_test = DbtDag(
# dbt/cosmos-specific parameters
project_config=ProjectConfig(
DBT_ROOT_PATH / "multiple_parents_test",
),
profile_config=profile_config,
# normal dag parameters
start_date=datetime(2023, 1, 1),
dag_id="example_multiple_parents_test",
)
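
As a hedged usage sketch (not part of this commit), the rendered structure can be eyeballed locally by printing the task IDs. The exact names depend on dbt's generated test ids and on having a working Airflow environment with the example_conn connection, but with this change the multi-parent custom test is expected to appear as a standalone *_test task rather than inside a single model's task group.

# Illustrative only: list the Airflow task IDs Cosmos rendered for this project.
if __name__ == "__main__":
    for task_id in sorted(example_multiple_parents_test.task_ids):
        print(task_id)
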
2 changes: 1 addition & 1 deletion tests/dbt/parser/test_project.py
@@ -74,7 +74,7 @@ def test_LegacyDbtProject__handle_config_file():

dbt_project._handle_config_file(SAMPLE_YML_PATH)

assert len(dbt_project.tests) == 12
assert len(dbt_project.tests) == 10
assert "not_null_customer_id_customers" in dbt_project.tests
sample_test = dbt_project.tests["not_null_customer_id_customers"]
assert sample_test.type == DbtModelType.DBT_TEST
4 changes: 2 additions & 2 deletions tests/dbt/test_graph.py
@@ -1609,9 +1609,9 @@ def test_save_dbt_ls_cache(mock_variable_set, mock_datetime, tmp_dbt_project_dir
hash_dir, hash_args = version.split(",")
assert hash_args == "d41d8cd98f00b204e9800998ecf8427e"
if sys.platform == "darwin":
assert hash_dir == "2b0b0c3d243f9bfdda0f60b56ab65836"
assert hash_dir == "fa5edac64de49909d4b8cbc4dc8abd4f"
else:
assert hash_dir == "cd0535d9a4acb972d74e49eaab85fb6f"
assert hash_dir == "9c9f712b6f6f1ace880dfc7f5f4ff051"


@pytest.mark.integration