Skip to content

Commit 5967d02

Browse files
committed
Locking
1 parent 2b09396 commit 5967d02

File tree

3 files changed

+91
-1
lines changed

3 files changed

+91
-1
lines changed

raffiot/__internal.py

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@
1010

1111
from typing_extensions import final
1212
from functools import total_ordering
13+
import threading
14+
from queue import Queue
1315

14-
__all__ = ["IOTag", "ContTag", "ResultTag", "Scheduled"]
16+
__all__ = ["IOTag", "ContTag", "ResultTag", "Scheduled", "Lock", "Semaphore"]
1517

1618

1719
@final
@@ -43,6 +45,7 @@ class IOTag(Enum):
4345
WAIT = 24 # FIBERS
4446
SLEEP_UNTIL = 25 # EPOCH IN SECONDS
4547
REC = 26 # FUN
48+
LOCK = 27 # LOCK
4649

4750

4851
@final
@@ -86,3 +89,60 @@ def __lt__(self, other):
8689
if self.__schedule == other.__schedule:
8790
return hash(self.__fiber) < hash(other.__fiber)
8891
return self.__schedule < other.__schedule
92+
93+
94+
@final
95+
class Lock:
96+
97+
__slots__ = ["lock", "fiber", "__nb_taken", "waiting"]
98+
99+
def __init__(self):
100+
self.lock = threading.Lock()
101+
self.fiber = None
102+
self.__nb_taken = 0
103+
self.waiting = Queue()
104+
105+
def acquire(self, fiber) -> bool:
106+
if self.fiber is None:
107+
self.__nb_taken = 1
108+
return True
109+
if self.fiber is fiber:
110+
self.fiber = fiber
111+
self.__nb_taken += 1
112+
return True
113+
return False
114+
115+
def release(self):
116+
with self.lock:
117+
self.__nb_taken -= 1
118+
if self.__nb_taken == 0:
119+
if self.waiting.empty():
120+
self.fiber = None
121+
return
122+
self.fiber = self.waiting.get()
123+
self.__nb_taken = 1
124+
self.fiber._Fiber__monitor._Monitor__resume(self.fiber)
125+
126+
@final
127+
class Semaphore:
128+
129+
__slots__ = ["lock", "tokens", "waiting"]
130+
131+
def __init__(self, tokens: int):
132+
self.lock = threading.Lock()
133+
self.tokens = tokens
134+
self.waiting = Queue()
135+
136+
def acquire(self, fiber) -> bool:
137+
if self.tokens > 0:
138+
self.tokens -= 1
139+
return True
140+
return False
141+
142+
def release(self):
143+
with self.lock:
144+
if self.waiting.empty():
145+
self.tokens += 1
146+
return
147+
fiber = self.waiting.get()
148+
fiber._Fiber__monitor._Monitor__resume(fiber)

raffiot/io.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,6 +351,8 @@ def __str__(self) -> str:
351351
return f"SleepUntil({self.__fields})"
352352
if self.__tag == IOTag.REC:
353353
return f"Rec({self.__fields})"
354+
if self.__tag == IOTag.LOCK:
355+
return f"Lock({self.__fields})"
354356

355357
def __repr__(self):
356358
return str(self)
@@ -1078,6 +1080,24 @@ def callback(r):
10781080
arg_tag = ResultTag.PANIC
10791081
arg_value = exception
10801082
break
1083+
if tag == IOTag.LOCK:
1084+
lock = io._IO__fields
1085+
try:
1086+
with lock.lock:
1087+
if lock.acquire(self):
1088+
arg_tag = ResultTag.OK
1089+
arg_value = None
1090+
break
1091+
else:
1092+
self.__io = IO(IOTag.PURE, None)
1093+
self.__context = context
1094+
self.__cont = cont
1095+
lock.waiting.put(self)
1096+
return
1097+
except Exception as exception:
1098+
arg_tag = ResultTag.PANIC
1099+
arg_value = exception
1100+
break
10811101
arg_tag = ResultTag.PANIC
10821102
arg_value = _MatchError(f"{io} should be an IO")
10831103
break

raffiot/resource.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from raffiot import io, _MatchError
1616
from raffiot.io import IO
1717
from raffiot.result import Result, Ok, Error, Panic
18+
import __internal
1819

1920
R = TypeVar("R")
2021
E = TypeVar("E")
@@ -593,3 +594,12 @@ def sleep(seconds: float) -> Resource[R, E, None]:
593594
:return:
594595
"""
595596
return lift_io(io.sleep(seconds))
597+
598+
599+
def lock() -> Resource[None, None, None]:
600+
new_lock = __internal.Lock()
601+
return Resource(IO(__internal.IOTag.LOCK, new_lock).then(io.pure((None, io.defer(new_lock.release)))))
602+
603+
def semaphore(tokens: int) -> Resource[None, None, None]:
604+
new_semaphore = __internal.Semaphore(tokens)
605+
return Resource(IO(__internal.IOTag.LOCK, new_semaphore).then(io.pure((None, io.defer(new_semaphore.release)))))

0 commit comments

Comments
 (0)