|
| 1 | +from pathlib import Path |
| 2 | +from unittest.mock import patch |
| 3 | + |
| 4 | +import anyio |
| 5 | +import pytest |
| 6 | + |
| 7 | +from prefect import flow, task |
| 8 | +from prefect.filesystems import LocalFileSystem |
| 9 | +from prefect.results import ResultStore |
| 10 | + |
| 11 | + |
| 12 | +@pytest.fixture |
| 13 | +def custom_storage_block(tmpdir: Path): |
| 14 | + class Test(LocalFileSystem): |
| 15 | + _block_type_slug = "test" |
| 16 | + |
| 17 | + async def awrite_path(self, path: str, content: bytes) -> str: |
| 18 | + _path: Path = self._resolve_path(path) |
| 19 | + |
| 20 | + _path.parent.mkdir(exist_ok=True, parents=True) |
| 21 | + |
| 22 | + if _path.exists() and not _path.is_file(): |
| 23 | + raise ValueError(f"Path {_path} already exists and is not a file.") |
| 24 | + |
| 25 | + async with await anyio.open_file(_path, mode="wb") as f: |
| 26 | + await f.write(content) |
| 27 | + return str(_path) |
| 28 | + |
| 29 | + Test.register_type_and_schema() |
| 30 | + test = Test(basepath=str(tmpdir)) |
| 31 | + test.save("test", overwrite=True) |
| 32 | + return test |
| 33 | + |
| 34 | + |
| 35 | +async def test_async_method_used_in_async_context( |
| 36 | + custom_storage_block: LocalFileSystem, |
| 37 | +): |
| 38 | + # this is a regression test for https://github.com/PrefectHQ/prefect/issues/16486 |
| 39 | + with patch.object( |
| 40 | + custom_storage_block, "awrite_path", wraps=custom_storage_block.awrite_path |
| 41 | + ) as mock_awrite: |
| 42 | + |
| 43 | + @task(result_storage=custom_storage_block, result_storage_key="testing") |
| 44 | + async def t(): |
| 45 | + return "this is a test" |
| 46 | + |
| 47 | + @flow |
| 48 | + async def f(): |
| 49 | + return await t() |
| 50 | + |
| 51 | + result = await f() |
| 52 | + assert result == "this is a test" |
| 53 | + store = ResultStore(result_storage=custom_storage_block) |
| 54 | + stored_result_record = await store.aread("testing") |
| 55 | + |
| 56 | + assert stored_result_record.result == result == "this is a test" |
| 57 | + # Verify awrite_path was called |
| 58 | + mock_awrite.assert_awaited_once() |
| 59 | + assert mock_awrite.await_args[0][0] == "testing" # Check path argument |
0 commit comments