Skip to content

Commit

Permalink
Prevent get_connection from being called in example_dags (apache#45704)
Browse files Browse the repository at this point in the history
Follow up after apache#45690

We already had protection against example dags not using database, but
it turns out that just calling get_connection() of the BaseHook involves
calling out to secrets manager which - depending on the configuration,
providers and where it is called - might cause external calls, timeout
and various side effects.

This PR adds explicit test for that. As part of the change we also
added `--load-example-dags` and `--load-default-connections` to
breeze shell as it allows to easily test the case where default
connections are loaded in the database.

Note that the "example_bedrock_retrieve_and_generate" explicitly
avoided attempting to load the connections by specifying aws_conn_id
to None, because it was likely causing problems with fetching SSM
when get_connection was actually called during dag parsing, so this
aws_conn_id = None would also bypass this check, but we can't do
much about it - at least after this change, the contributor
will see failing test with explicit "get_connection() should not
be called during DAG parsing".

That also makes the example dag more of a "real" example as it does not
nullify the connection id and it can use "aws_default" connection to
actually ... be a good example. Also it allows to run the example dag as
system test for someone who would like to do it with "aws_default" as
a connection id to connect to their AWS account.
  • Loading branch information
potiuk authored Jan 16, 2025
1 parent d3fc6c4 commit dce8482
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 30 deletions.
24 changes: 16 additions & 8 deletions dev/breeze/doc/images/output_shell.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion dev/breeze/doc/images/output_shell.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
efd5cdefe6f99a82f8a9eb611f4cc25e
1f9defd0443e2de2496b75a00c5af2cb
40 changes: 23 additions & 17 deletions dev/breeze/src/airflow_breeze/commands/developer_commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,22 @@ def run(self):
envvar="START_WEBSERVER_WITH_EXAMPLES",
)

# Reusable click option: `-e` / `--load-example-dags`.
# Declared as a boolean flag (is_flag=True); it can also be supplied through
# the LOAD_EXAMPLES environment variable (envvar=...).
option_load_example_dags = click.option(
"-e",
"--load-example-dags",
help="Enable configuration to load example DAGs when starting Airflow.",
is_flag=True,
envvar="LOAD_EXAMPLES",
)

# Reusable click option: `-c` / `--load-default-connections`.
# Declared as a boolean flag (is_flag=True); it can also be supplied through
# the LOAD_DEFAULT_CONNECTIONS environment variable (envvar=...).
option_load_default_connections = click.option(
"-c",
"--load-default-connections",
help="Enable configuration to load default connections when starting Airflow.",
is_flag=True,
envvar="LOAD_DEFAULT_CONNECTIONS",
)


@main.command()
@click.argument("extra-args", nargs=-1, type=click.UNPROCESSED)
Expand Down Expand Up @@ -288,6 +304,8 @@ def run(self):
@option_install_airflow_with_constraints_default_true
@option_install_selected_providers
@option_installation_package_format
@option_load_example_dags
@option_load_default_connections
@option_all_integration
@option_keep_env_variables
@option_max_time
Expand Down Expand Up @@ -343,6 +361,8 @@ def shell(
install_airflow_python_client: bool,
integration: tuple[str, ...],
keep_env_variables: bool,
load_example_dags: bool,
load_default_connections: bool,
max_time: int | None,
mount_sources: str,
mysql_version: str,
Expand Down Expand Up @@ -412,6 +432,8 @@ def shell(
install_selected_providers=install_selected_providers,
integration=integration,
keep_env_variables=keep_env_variables,
load_example_dags=load_example_dags,
load_default_connections=load_default_connections,
mount_sources=mount_sources,
mysql_version=mysql_version,
no_db_cleanup=no_db_cleanup,
Expand Down Expand Up @@ -447,22 +469,6 @@ def shell(
sys.exit(result.returncode)


option_load_example_dags = click.option(
"-e",
"--load-example-dags",
help="Enable configuration to load example DAGs when starting Airflow.",
is_flag=True,
envvar="LOAD_EXAMPLES",
)

option_load_default_connection = click.option(
"-c",
"--load-default-connections",
help="Enable configuration to load default connections when starting Airflow.",
is_flag=True,
envvar="LOAD_DEFAULT_CONNECTIONS",
)

option_executor_start_airflow = click.option(
"--executor",
type=click.Choice(START_AIRFLOW_ALLOWED_EXECUTORS, case_sensitive=False),
Expand Down Expand Up @@ -506,7 +512,7 @@ def shell(
@option_installation_package_format
@option_install_selected_providers
@option_all_integration
@option_load_default_connection
@option_load_default_connections
@option_load_example_dags
@option_mount_sources
@option_mysql_version
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@
"options": [
"--python",
"--integration",
"--load-example-dags",
"--load-default-connections",
"--standalone-dag-processor",
"--start-webserver-with-examples",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -442,9 +442,8 @@ def delete_opensearch_policies(collection_name: str):
test_context = sys_test_context_task()
env_id = test_context["ENV_ID"]

aoss_client = OpenSearchServerlessHook(aws_conn_id=None)
bedrock_agent_client = BedrockAgentHook(aws_conn_id=None)

aoss_client = OpenSearchServerlessHook()
bedrock_agent_client = BedrockAgentHook()
region_name = boto3.session.Session().region_name

naming_prefix = "bedrock-kb-"
Expand Down
22 changes: 21 additions & 1 deletion tests/always/test_example_dags.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
from glob import glob
from importlib import metadata as importlib_metadata
from pathlib import Path
from unittest.mock import patch

import pytest
from packaging.specifiers import SpecifierSet
from packaging.version import Version

from airflow.models import DagBag
from airflow.hooks.base import BaseHook
from airflow.models import Connection, DagBag
from airflow.utils import yaml

from tests_common.test_utils.asserts import assert_queries_count
Expand Down Expand Up @@ -209,3 +211,21 @@ def test_should_not_do_database_queries(example: str):
dag_folder=example,
include_examples=False,
)


@pytest.mark.db_test
@pytest.mark.parametrize("example", example_not_excluded_dags(xfail_db_exception=True))
def test_should_not_run_hook_connections(example: str):
    """Check that parsing an example DAG never calls ``BaseHook.get_connection()``.

    ``BaseHook.get_connection`` is replaced with a mock for the duration of the
    parse, the DAG folder given by ``example`` is loaded through ``DagBag``, and
    the test fails if the mock recorded any call.
    """
    with patch.object(BaseHook, "get_connection") as get_connection_mock:
        # Hand back an empty Connection so code that does call the hook keeps
        # running and the failure is reported by the assertion below.
        get_connection_mock.return_value = Connection()

        # Constructing the DagBag is what triggers parsing of the example DAG.
        DagBag(dag_folder=example, include_examples=False)

        calls = get_connection_mock.call_count
        assert calls == 0, (
            "BaseHook.get_connection() should not be called during DAG parsing. "
            f"It was called {calls} times. Please make sure that no "
            "connections are created during DAG parsing. NOTE! Do not set conn_id to None to avoid it, just make "
            "sure that you do not create connection object in the `__init__` method of your operator."
        )

0 comments on commit dce8482

Please sign in to comment.