From cecd3a1028184cb8e905ee14b2cbdc0c6a094219 Mon Sep 17 00:00:00 2001 From: nate nowack Date: Fri, 3 Jan 2025 13:56:19 -0600 Subject: [PATCH] consolidate hook calling in flow engine (#16596) --- src/prefect/filesystems.py | 8 ++++++-- src/prefect/flow_engine.py | 18 ++++-------------- src/prefect/workers/process.py | 2 +- tests/test_flows.py | 4 ++-- 4 files changed, 13 insertions(+), 19 deletions(-) diff --git a/src/prefect/filesystems.py b/src/prefect/filesystems.py index fee5f35c125c..20c0a45d23dd 100644 --- a/src/prefect/filesystems.py +++ b/src/prefect/filesystems.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import abc import urllib.parse from pathlib import Path @@ -92,7 +94,9 @@ class LocalFileSystem(WritableFileSystem, WritableDeploymentStorage): ) @field_validator("basepath", mode="before") - def cast_pathlib(cls, value): + def cast_pathlib(cls, value: str | Path | None) -> str | None: + if value is None: + return value return stringify_path(value) def _resolve_path(self, path: str, validate: bool = False) -> Path: @@ -132,7 +136,7 @@ async def get_directory( Defaults to copying the entire contents of the block's basepath to the current working directory. """ if not from_path: - from_path = Path(self.basepath).expanduser().resolve() + from_path = Path(self.basepath or ".").expanduser().resolve() else: from_path = self._resolve_path(from_path) diff --git a/src/prefect/flow_engine.py b/src/prefect/flow_engine.py index b0eb0ec91e35..2b1fdfd79e60 100644 --- a/src/prefect/flow_engine.py +++ b/src/prefect/flow_engine.py @@ -209,7 +209,7 @@ def client(self) -> SyncPrefectClient: def _resolve_parameters(self): if not self.parameters: - return {} + return resolved_parameters = {} for parameter, value in self.parameters.items(): @@ -277,7 +277,6 @@ def begin_run(self) -> State: ), ) self.short_circuit = True - self.call_hooks() new_state = Running() state = self.set_state(new_state) @@ -300,6 +299,7 @@ def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_type = state.type # type: ignore self._telemetry.update_state(state) + self.call_hooks(state) return state def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -711,8 +711,6 @@ def start(self) -> Generator[None, None, None]: ): self.begin_run() - if self.state.is_running(): - self.call_hooks() yield @contextmanager @@ -734,9 +732,6 @@ def run_context(self): except Exception as exc: self.logger.exception("Encountered exception during execution: %r", exc) self.handle_exception(exc) - finally: - if self.state.is_final() or self.state.is_cancelling(): - self.call_hooks() def call_flow_fn(self) -> Union[R, Coroutine[Any, Any, R]]: """ @@ -774,7 +769,7 @@ def client(self) -> PrefectClient: def _resolve_parameters(self): if not self.parameters: - return {} + return resolved_parameters = {} for parameter, value in self.parameters.items(): @@ -842,7 +837,6 @@ async def begin_run(self) -> State: ), ) self.short_circuit = True - await self.call_hooks() new_state = Running() state = await self.set_state(new_state) @@ -865,6 +859,7 @@ async def set_state(self, state: State, force: bool = False) -> State: self.flow_run.state_type = state.type # type: ignore self._telemetry.update_state(state) + await self.call_hooks(state) return state async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]": @@ -1279,8 +1274,6 @@ async def start(self) -> AsyncGenerator[None, None]: ): await self.begin_run() - if self.state.is_running(): - await self.call_hooks() yield @asynccontextmanager @@ -1302,9 +1295,6 @@ async def run_context(self): except Exception as exc: self.logger.exception("Encountered exception during execution: %r", exc) await self.handle_exception(exc) - finally: - if self.state.is_final() or self.state.is_cancelling(): - await self.call_hooks() async def call_flow_fn(self) -> Coroutine[Any, Any, R]: """ diff --git a/src/prefect/workers/process.py b/src/prefect/workers/process.py index a180b39bd822..4a89665f4fca 100644 --- a/src/prefect/workers/process.py +++ b/src/prefect/workers/process.py @@ -85,7 +85,7 @@ class ProcessJobConfiguration(BaseJobConfiguration): @field_validator("working_dir") @classmethod - def validate_command(cls, v): + def validate_command(cls, v: str) -> str: return validate_command(v) def prepare_for_flow_run( diff --git a/tests/test_flows.py b/tests/test_flows.py index 596ebc947cff..06739e72a267 100644 --- a/tests/test_flows.py +++ b/tests/test_flows.py @@ -3667,7 +3667,7 @@ def my_flow(): await my_flow(return_state=True) assert my_mock.mock_calls == [call("crashed")] - async def test_on_crashed_hook_not_called_on_sigterm_from_flow_with_cancelling_state( + async def test_on_crashed_hook_called_on_sigterm_from_flow_with_cancelling_state( self, mock_sigterm_handler ): my_mock = MagicMock() @@ -3691,7 +3691,7 @@ async def my_flow(): with pytest.raises(prefect.exceptions.TerminationSignal): await my_flow(return_state=True) - my_mock.assert_not_called() + my_mock.assert_called_once() def test_on_crashed_hooks_respect_env_var(self, monkeypatch): my_mock = MagicMock()