Skip to content

Commit

Permalink
use task.run(context=context) where implemented + use optimistic patt…
Browse files Browse the repository at this point in the history
…ern in CopyIncrementally

this will work around issues where task.run does not implement the context pattern e.g. RunFunction. These commands will always run locally
  • Loading branch information
leo-schick committed May 4, 2022
1 parent 2b0f5f7 commit dca0bb9
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
9 changes: 2 additions & 7 deletions mara_pipelines/commands/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ..incremental_processing import file_dependencies
from ..incremental_processing import incremental_copy_status
from ..logging import logger
from ..contexts import ExecutionContext, _LocalShellExecutionContext
from ..contexts import ExecutionContext


class _SQLCommand(pipelines.Command):
Expand Down Expand Up @@ -265,12 +265,7 @@ def target_db_alias(self):
return self._target_db_alias or config.default_db_alias()

def run(self, context: ExecutionContext = None) -> bool:
if isinstance(context, _LocalShellExecutionContext):
run_shell_command = context.run_shell_command
elif context is None:
run_shell_command = shell.run_shell_command
else:
raise ValueError('The context must inherit type _LocalShellExecutionContext')
run_shell_command = context.run_shell_command if context else shell.run_shell_command

# retrieve the highest current value for the modification comparison (e.g.: the highest timestamp)
# We intentionally use the command line here (rather than sqlalchemy) to avoid forcing people python drivers,
Expand Down
9 changes: 7 additions & 2 deletions mara_pipelines/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,8 +443,13 @@ def __init__(self, task: pipelines.Task, event_queue: multiprocessing.Queue, sta
self.task = task
self.event_queue = event_queue
self.status_queue = status_queue
self.context = context
self.start_time = datetime.datetime.now(tz.utc)
self.run_kargs = {}

# add dynamic kargs for self.task.run(...)
from inspect import signature
if 'context' in signature(task.run):
self.run_kargs['context'] = self.context

def run(self):
# redirect stdout and stderr to queue
Expand All @@ -454,7 +459,7 @@ def run(self):
attempt = 0
try:
while True:
if not self.task.run(context=self.context):
if not self.task.run(**self.run_kargs):
max_retries = self.task.max_retries or config.default_task_max_retries()
if attempt < max_retries:
attempt += 1
Expand Down

0 comments on commit dca0bb9

Please sign in to comment.