Skip to content

Commit cf798ec

Browse files
martindurant and dgegen authored
Add semaphore to AsyncFileSystemWrapper (#1908)
Co-authored-by: dgegen <[email protected]>
1 parent c46db87 commit cf798ec

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

fsspec/implementations/asyn_wrapper.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from fsspec.asyn import AsyncFileSystem, running_async
77

88

9-
def async_wrapper(func, obj=None):
9+
def async_wrapper(func, obj=None, semaphore=None):
1010
"""
1111
Wraps a synchronous function to make it awaitable.
1212
@@ -16,6 +16,8 @@ def async_wrapper(func, obj=None):
1616
The synchronous function to wrap.
1717
obj : object, optional
1818
The instance to bind the function to, if applicable.
19+
semaphore : asyncio.Semaphore, optional
20+
A semaphore to limit concurrent calls.
1921
2022
Returns
2123
-------
@@ -25,6 +27,9 @@ def async_wrapper(func, obj=None):
2527

2628
@functools.wraps(func)
async def wrapper(*args, **kwargs):
    # Runs the wrapped synchronous ``func`` in a worker thread and awaits
    # its result. (Comments are used here rather than a docstring because
    # functools.wraps replaces this function's __doc__ with func's.)
    #
    # Compare against None explicitly: the intent is "was a semaphore
    # supplied?", not "is the semaphore truthy?". A semaphore-like object
    # that happens to evaluate falsy must still limit concurrency.
    if semaphore is not None:
        # Hold the semaphore for the whole threaded call so that no more
        # calls run at once than the semaphore permits.
        async with semaphore:
            return await asyncio.to_thread(func, *args, **kwargs)
    return await asyncio.to_thread(func, *args, **kwargs)
2934

3035
return wrapper
@@ -52,6 +57,8 @@ def __init__(
5257
asynchronous=None,
5358
target_protocol=None,
5459
target_options=None,
60+
semaphore=None,
61+
max_concurrent_tasks=None,
5562
**kwargs,
5663
):
5764
if asynchronous is None:
@@ -62,6 +69,7 @@ def __init__(
6269
else:
6370
self.sync_fs = fsspec.filesystem(target_protocol, **target_options)
6471
self.protocol = self.sync_fs.protocol
72+
self.semaphore = semaphore
6573
self._wrap_all_sync_methods()
6674

6775
@property
@@ -83,7 +91,7 @@ def _wrap_all_sync_methods(self):
8391

8492
method = getattr(self.sync_fs, method_name)
8593
if callable(method) and not inspect.iscoroutinefunction(method):
86-
async_method = async_wrapper(method, obj=self)
94+
async_method = async_wrapper(method, obj=self, semaphore=self.semaphore)
8795
setattr(self, f"_{method_name}", async_method)
8896

8997
@classmethod

fsspec/implementations/tests/test_asyn_wrapper.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,49 @@
11
import asyncio
22
import os
3+
from itertools import cycle
34

45
import pytest
56

67
import fsspec
8+
from fsspec.asyn import AsyncFileSystem
79
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
810
from fsspec.implementations.local import LocalFileSystem
911

1012
from .test_local import csv_files, filetexts
1113

1214

15+
class LockedFileSystem(AsyncFileSystem):
    """
    Mock async file system standing in for a backend that cannot serve
    overlapping requests.

    Each simulated I/O call sleeps for a configurable delay while holding a
    lock; a call that arrives while the lock is already held fails at once
    with ``RuntimeError("Concurrent requests!")``.
    """

    def __init__(
        self,
        asynchronous: bool = False,
        delays=None,
    ) -> None:
        self.lock = asyncio.Lock()
        # Without explicit delays, alternate between a 0.03 s and a 0.01 s
        # sleep; the sequence repeats endlessly.
        if delays is None:
            delays = (0.03, 0.01)
        self.delays = cycle(delays)

        super().__init__(asynchronous=asynchronous)

    async def _cat_file(self, path, start=None, end=None) -> bytes:
        # The "content" of a file is simply its own path, encoded to bytes;
        # ``start`` and ``end`` are accepted but not used.
        await self._simulate_io_operation(path)
        payload = path.encode()
        return payload

    async def _await_io(self) -> None:
        # Sleep for the next delay in the repeating sequence.
        pause = next(self.delays)
        await asyncio.sleep(pause)

    async def _simulate_io_operation(self, path) -> None:
        # Reject the call if another one is in flight, then hold the lock
        # for the duration of the simulated I/O.
        await self._check_active()
        async with self.lock:
            await self._await_io()

    async def _check_active(self) -> None:
        # A held lock means a previous request has not finished yet.
        busy = self.lock.locked()
        if busy:
            raise RuntimeError("Concurrent requests!")
45+
46+
1347
@pytest.mark.asyncio
1448
async def test_is_async_default():
1549
fs = fsspec.filesystem("file")
@@ -161,3 +195,26 @@ def test_open(tmpdir):
161195
)
162196
with of as f:
163197
assert f.read() == b"hello"
198+
199+
200+
@pytest.mark.asyncio
async def test_semaphore_synchronous():
    """
    With a semaphore of size one, calls gathered together against the
    wrapped LockedFileSystem all complete and return the encoded paths.
    """
    limiter = asyncio.Semaphore(1)
    fs = AsyncFileSystemWrapper(
        LockedFileSystem(), asynchronous=False, semaphore=limiter
    )

    paths = [f"path_{i}" for i in range(1, 3)]
    pending = [fs._cat_file(path) for path in paths]
    results = await asyncio.gather(*pending)

    expected = {path.encode() for path in paths}
    assert set(results) == expected
210+
211+
212+
@pytest.mark.asyncio
async def test_deadlock_when_asynchronous():
    """
    With a semaphore of size three, two gathered calls are expected to make
    the wrapped LockedFileSystem raise its "Concurrent requests!" error.
    """
    limiter = asyncio.Semaphore(3)
    fs = AsyncFileSystemWrapper(
        LockedFileSystem(), asynchronous=False, semaphore=limiter
    )
    paths = [f"path_{i}" for i in range(1, 3)]

    with pytest.raises(RuntimeError, match="Concurrent requests!"):
        pending = [fs._cat_file(path) for path in paths]
        await asyncio.gather(*pending)

0 commit comments

Comments
 (0)