Skip to content

Commit

Permalink
refactor: Moved source and destination hooks into cached properties
Browse files Browse the repository at this point in the history
  • Loading branch information
davidblain-infrabel authored and potiuk committed Jan 2, 2025
1 parent 8a1d2de commit 757ab68
Showing 1 changed file with 17 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from __future__ import annotations

from collections.abc import Sequence
from functools import cached_property
from typing import TYPE_CHECKING, Any

import jinja2
Expand Down Expand Up @@ -126,17 +127,21 @@ def render_template_fields(
if isinstance(commit_every, str):
self.insert_args["commit_every"] = int(commit_every)

def paginated_execute(self, context: Context):
destination_hook = BaseHook.get_hook(
self.destination_conn_id, hook_params=self.destination_hook_params
)
@cached_property
def source_hook(self) -> BaseHook:
return BaseHook.get_hook(self.source_conn_id, hook_params=self.source_hook_params)

@cached_property
def destination_hook(self) -> BaseHook:
return BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params)

def paginated_execute(self, context: Context):
if self.preoperator:
run = getattr(destination_hook, "run", None)
run = getattr(self.destination_hook, "run", None)
if not callable(run):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
f"({type(destination_hook).__name__}) has no `run` method"
f"({type(self.destination_hook).__name__}) has no `run` method"
)
self.log.info("Running preoperator")
self.log.info(self.preoperator)
Expand All @@ -155,28 +160,24 @@ def paginated_execute(self, context: Context):
method_name=self.execute_complete.__name__,
)
else:
source_hook = BaseHook.get_hook(
self.source_conn_id, hook_params=self.source_hook_params
)

self.log.info("Extracting data from %s", self.source_conn_id)
self.log.info("Executing: \n %s", self.sql)

get_records = getattr(source_hook, "get_records", None)
get_records = getattr(self.source_hook, "get_records", None)

if not callable(get_records):
raise RuntimeError(
f"Hook for connection {self.source_conn_id!r} "
f"({type(source_hook).__name__}) has no `get_records` method"
f"({type(self.source_hook).__name__}) has no `get_records` method"
)

results = get_records(self.sql)
insert_rows = getattr(destination_hook, "insert_rows", None)
insert_rows = getattr(self.destination_hook, "insert_rows", None)

if not callable(insert_rows):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
f"({type(destination_hook).__name__}) has no `insert_rows` method"
f"({type(self.destination_hook).__name__}) has no `insert_rows` method"
)

self.log.info("Inserting rows into %s", self.destination_conn_id)
Expand All @@ -191,17 +192,15 @@ def execute_complete(
if event.get("status") == "failure":
raise AirflowException(event.get("message"))

destination_hook = BaseHook.get_hook(self.destination_conn_id, hook_params=self.destination_hook_params)

results = event.get("results")

if results:
insert_rows = getattr(destination_hook, "insert_rows", None)
insert_rows = getattr(self.destination_hook, "insert_rows", None)

if not callable(insert_rows):
raise RuntimeError(
f"Hook for connection {self.destination_conn_id!r} "
f"({type(destination_hook).__name__}) has no `insert_rows` method"
f"({type(self.destination_hook).__name__}) has no `insert_rows` method"
)

map_index = context["ti"].map_index
Expand Down

0 comments on commit 757ab68

Please sign in to comment.