Skip to content

Commit

Permalink
consolidate hook calling in flow engine (#16596)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Jan 3, 2025
1 parent 45fc50f commit cecd3a1
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 19 deletions.
8 changes: 6 additions & 2 deletions src/prefect/filesystems.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import abc
import urllib.parse
from pathlib import Path
Expand Down Expand Up @@ -92,7 +94,9 @@ class LocalFileSystem(WritableFileSystem, WritableDeploymentStorage):
)

@field_validator("basepath", mode="before")
def cast_pathlib(cls, value):
def cast_pathlib(cls, value: str | Path | None) -> str | None:
if value is None:
return value
return stringify_path(value)

def _resolve_path(self, path: str, validate: bool = False) -> Path:
Expand Down Expand Up @@ -132,7 +136,7 @@ async def get_directory(
Defaults to copying the entire contents of the block's basepath to the current working directory.
"""
if not from_path:
from_path = Path(self.basepath).expanduser().resolve()
from_path = Path(self.basepath or ".").expanduser().resolve()
else:
from_path = self._resolve_path(from_path)

Expand Down
18 changes: 4 additions & 14 deletions src/prefect/flow_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def client(self) -> SyncPrefectClient:

def _resolve_parameters(self):
if not self.parameters:
return {}
return

resolved_parameters = {}
for parameter, value in self.parameters.items():
Expand Down Expand Up @@ -277,7 +277,6 @@ def begin_run(self) -> State:
),
)
self.short_circuit = True
self.call_hooks()

new_state = Running()
state = self.set_state(new_state)
Expand All @@ -300,6 +299,7 @@ def set_state(self, state: State, force: bool = False) -> State:
self.flow_run.state_type = state.type # type: ignore

self._telemetry.update_state(state)
self.call_hooks(state)
return state

def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -711,8 +711,6 @@ def start(self) -> Generator[None, None, None]:
):
self.begin_run()

if self.state.is_running():
self.call_hooks()
yield

@contextmanager
Expand All @@ -734,9 +732,6 @@ def run_context(self):
except Exception as exc:
self.logger.exception("Encountered exception during execution: %r", exc)
self.handle_exception(exc)
finally:
if self.state.is_final() or self.state.is_cancelling():
self.call_hooks()

def call_flow_fn(self) -> Union[R, Coroutine[Any, Any, R]]:
"""
Expand Down Expand Up @@ -774,7 +769,7 @@ def client(self) -> PrefectClient:

def _resolve_parameters(self):
if not self.parameters:
return {}
return

resolved_parameters = {}
for parameter, value in self.parameters.items():
Expand Down Expand Up @@ -842,7 +837,6 @@ async def begin_run(self) -> State:
),
)
self.short_circuit = True
await self.call_hooks()

new_state = Running()
state = await self.set_state(new_state)
Expand All @@ -865,6 +859,7 @@ async def set_state(self, state: State, force: bool = False) -> State:
self.flow_run.state_type = state.type # type: ignore

self._telemetry.update_state(state)
await self.call_hooks(state)
return state

async def result(self, raise_on_failure: bool = True) -> "Union[R, State, None]":
Expand Down Expand Up @@ -1279,8 +1274,6 @@ async def start(self) -> AsyncGenerator[None, None]:
):
await self.begin_run()

if self.state.is_running():
await self.call_hooks()
yield

@asynccontextmanager
Expand All @@ -1302,9 +1295,6 @@ async def run_context(self):
except Exception as exc:
self.logger.exception("Encountered exception during execution: %r", exc)
await self.handle_exception(exc)
finally:
if self.state.is_final() or self.state.is_cancelling():
await self.call_hooks()

async def call_flow_fn(self) -> Coroutine[Any, Any, R]:
"""
Expand Down
2 changes: 1 addition & 1 deletion src/prefect/workers/process.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ProcessJobConfiguration(BaseJobConfiguration):

@field_validator("working_dir")
@classmethod
def validate_command(cls, v):
def validate_command(cls, v: str) -> str:
return validate_command(v)

def prepare_for_flow_run(
Expand Down
4 changes: 2 additions & 2 deletions tests/test_flows.py
Original file line number Diff line number Diff line change
Expand Up @@ -3667,7 +3667,7 @@ def my_flow():
await my_flow(return_state=True)
assert my_mock.mock_calls == [call("crashed")]

async def test_on_crashed_hook_not_called_on_sigterm_from_flow_with_cancelling_state(
async def test_on_crashed_hook_called_on_sigterm_from_flow_with_cancelling_state(
self, mock_sigterm_handler
):
my_mock = MagicMock()
Expand All @@ -3691,7 +3691,7 @@ async def my_flow():

with pytest.raises(prefect.exceptions.TerminationSignal):
await my_flow(return_state=True)
my_mock.assert_not_called()
my_mock.assert_called_once()

def test_on_crashed_hooks_respect_env_var(self, monkeypatch):
my_mock = MagicMock()
Expand Down

0 comments on commit cecd3a1

Please sign in to comment.