|
1 | 1 | import asyncio
|
2 | 2 | import os
|
| 3 | +from itertools import cycle |
3 | 4 |
|
4 | 5 | import pytest
|
5 | 6 |
|
6 | 7 | import fsspec
|
| 8 | +from fsspec.asyn import AsyncFileSystem |
7 | 9 | from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
|
8 | 10 | from fsspec.implementations.local import LocalFileSystem
|
9 | 11 |
|
10 | 12 | from .test_local import csv_files, filetexts
|
11 | 13 |
|
12 | 14 |
|
| 15 | +class LockedFileSystem(AsyncFileSystem): |
| 16 | + """ |
| 17 | + A mock file system that simulates a synchronous locking file systems with delays. |
| 18 | + """ |
| 19 | + |
| 20 | + def __init__( |
| 21 | + self, |
| 22 | + asynchronous: bool = False, |
| 23 | + delays=None, |
| 24 | + ) -> None: |
| 25 | + self.lock = asyncio.Lock() |
| 26 | + self.delays = cycle((0.03, 0.01) if delays is None else delays) |
| 27 | + |
| 28 | + super().__init__(asynchronous=asynchronous) |
| 29 | + |
| 30 | + async def _cat_file(self, path, start=None, end=None) -> bytes: |
| 31 | + await self._simulate_io_operation(path) |
| 32 | + return path.encode() |
| 33 | + |
| 34 | + async def _await_io(self) -> None: |
| 35 | + await asyncio.sleep(next(self.delays)) |
| 36 | + |
| 37 | + async def _simulate_io_operation(self, path) -> None: |
| 38 | + await self._check_active() |
| 39 | + async with self.lock: |
| 40 | + await self._await_io() |
| 41 | + |
| 42 | + async def _check_active(self) -> None: |
| 43 | + if self.lock.locked(): |
| 44 | + raise RuntimeError("Concurrent requests!") |
| 45 | + |
| 46 | + |
13 | 47 | @pytest.mark.asyncio
|
14 | 48 | async def test_is_async_default():
|
15 | 49 | fs = fsspec.filesystem("file")
|
@@ -161,3 +195,26 @@ def test_open(tmpdir):
|
161 | 195 | )
|
162 | 196 | with of as f:
|
163 | 197 | assert f.read() == b"hello"
|
| 198 | + |
| 199 | + |
| 200 | +@pytest.mark.asyncio |
| 201 | +async def test_semaphore_synchronous(): |
| 202 | + fs = AsyncFileSystemWrapper( |
| 203 | + LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(1) |
| 204 | + ) |
| 205 | + |
| 206 | + paths = [f"path_{i}" for i in range(1, 3)] |
| 207 | + results = await asyncio.gather(*(fs._cat_file(path) for path in paths)) |
| 208 | + |
| 209 | + assert set(results) == {path.encode() for path in paths} |
| 210 | + |
| 211 | + |
| 212 | +@pytest.mark.asyncio |
| 213 | +async def test_deadlock_when_asynchronous(): |
| 214 | + fs = AsyncFileSystemWrapper( |
| 215 | + LockedFileSystem(), asynchronous=False, semaphore=asyncio.Semaphore(3) |
| 216 | + ) |
| 217 | + paths = [f"path_{i}" for i in range(1, 3)] |
| 218 | + |
| 219 | + with pytest.raises(RuntimeError, match="Concurrent requests!"): |
| 220 | + await asyncio.gather(*(fs._cat_file(path) for path in paths)) |
0 commit comments