Skip to content

Commit 9f45ae0

Browse files
committed
Direct implementation for map, flat_map and defer_io
1 parent 1250d4b commit 9f45ae0

File tree

1 file changed

+119
-87
lines changed

1 file changed

+119
-87
lines changed

raffiot/io.py

Lines changed: 119 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -201,27 +201,43 @@ def on_failure(
201201
)
202202
)
203203

204-
205204
# IO PURE 0 VALUE
206-
# IO AP 1 FUN ARG
207-
# IO FLATTEN 2 TOWER
208-
# IO DEFER 3 DEFERED
209-
# IO READ 4
210-
# IO CONTRA_MAP_READ 5 FUN MAIN
211-
# IO RAISE 6 ERROR
212-
# IO CATCH 7 MAIN HANDLER
213-
# IO MAP_ERROR 8 MAIN FUN
214-
# IO PANIC 9 EXCEPTION
215-
# IO RECOVER 10 MAIN HANDLER
216-
# IO MAP_PANIC 11 MAIN FUN
217-
205+
# IO MAP 1 MAIN FUN
206+
# IO FLATMAP 2 MAIN HANDLER
207+
# IO AP 3 FUN ARG
208+
# IO FLATTEN 4 TOWER
209+
# IO DEFER 5 DEFERED
210+
# IO DEFER_IO 6 DEFERED
211+
# IO READ 7
212+
# IO CONTRA_MAP_READ 8 FUN MAIN
213+
# IO RAISE 9 ERROR
214+
# IO CATCH 10 MAIN HANDLER
215+
# IO MAP_ERROR 11 MAIN FUN
216+
# IO PANIC 12 EXCEPTION
217+
# IO RECOVER 13 MAIN HANDLER
218+
# IO MAP_PANIC 14 MAIN FUN
218219

219220
def pure(a: A) -> IO[R, E, A]:
220221
"""
221222
An always successful computation returning a.
222223
"""
223-
return IO(0, a)
224+
return IO(0,a)
225+
226+
def map(main: IO[R, E, A], f: Callable[[A], A2]) -> IO[R, E, A2]:
227+
"""
228+
Transform the computed value with f if the computation is successful.
229+
Do nothing otherwise.
230+
"""
231+
if main.tag in [9,12]:
232+
return main
233+
return IO(1, (main, f))
224234

235+
def flat_map(main: IO[R, E, A], f: Callable[[A], IO[R, E, A2]]) -> IO[R, E, A2]:
236+
"""
237+
Chain two computations.
238+
The result of the first one (main) can be used in the second (f).
239+
"""
240+
return IO(2, (main, f))
225241

226242
def ap(fun: IO[R, E, Callable[[X], A]], arg: IO[R, E, X]) -> IO[R, E, A]:
227243
"""
@@ -232,16 +248,8 @@ def ap(fun: IO[R, E, Callable[[X], A]], arg: IO[R, E, X]) -> IO[R, E, A]:
232248
then fun.ap(arg) computes f(x): A
233249
"""
234250
if fun.tag == 0 and arg.tag == 0:
235-
return IO(3, lambda: fun.fields(arg.fields))
236-
return IO(1, (fun, arg))
237-
238-
239-
def map(main: IO[R, E, A], f: Callable[[A], A2]) -> IO[R, E, A2]:
240-
"""
241-
Transform the computed value with f if the computation is successful.
242-
Do nothing otherwise.
243-
"""
244-
return IO(1, (IO(0, f), main))
251+
return IO(5, lambda: fun.fields(arg.fields))
252+
return IO(3, (fun, arg))
245253

246254

247255
def flatten(tower: IO[R, E, IO[R, E, A]]) -> IO[R, E, A]:
@@ -250,18 +258,7 @@ def flatten(tower: IO[R, E, IO[R, E, A]]) -> IO[R, E, A]:
250258
"""
251259
if tower.tag == 0:
252260
return tower.fields
253-
return IO(2, tower)
254-
255-
256-
def flat_map(main: IO[R, E, A], f: Callable[[A], IO[R, E, A2]]) -> IO[R, E, A2]:
257-
"""
258-
Chain two computations.
259-
The result of the first one (main) can be used in the second (f).
260-
"""
261-
if main.tag == 0:
262-
IO(2, IO(3, lambda: f(main.fields)))
263-
return IO(2, IO(1, (IO(0, f), main)))
264-
261+
return IO(4,tower)
265262

266263
def defer(deferred: Callable[[], A]) -> IO[R, E, A]:
267264
"""
@@ -288,7 +285,7 @@ def defer(deferred: Callable[[], A]) -> IO[R, E, A]:
288285
>>> hello.run(None)
289286
"Hello World!" is printed again
290287
"""
291-
return IO(3, deferred)
288+
return IO(5,deferred)
292289

293290

294291
def defer_io(deferred: Callable[[], IO[R, E, A]]) -> IO[R, E, A]:
@@ -309,7 +306,7 @@ def defer_io(deferred: Callable[[], IO[R, E, A]]) -> IO[R, E, A]:
309306
>> return defer_io(lambda: f())
310307
>> f().run(None)
311308
"""
312-
return IO(2, IO(3, deferred))
309+
return IO(6, deferred)
313310

314311

315312
def read() -> IO[R, E, R]:
@@ -322,24 +319,24 @@ def read() -> IO[R, E, R]:
322319
323320
Please note that the contra_map_read method can transform this value r.
324321
"""
325-
return IO(4, None)
322+
return IO(7, None)
326323

327324

328325
def contra_map_read(fun: Callable[[R], R2], main: IO[R2, E, A]) -> IO[R, E2, A]:
329326
"""
330327
Transform the context with f.
331328
Note that f is not from R to R2 but from R2 to R!
332329
"""
333-
if main.tag in [0, 3, 6, 9]:
330+
if main.tag in [0,5,9,12]:
334331
return main
335-
return IO(5, (fun, main))
332+
return IO(8, (fun, main))
336333

337334

338335
def error(err: E) -> IO[R, E, A]:
339336
"""
340337
Computation that fails on the error err.
341338
"""
342-
return IO(6, err)
339+
return IO(9,err)
343340

344341

345342
def catch(main: IO[R, E, A], handler: Callable[[E], IO[R, E, A]]) -> IO[R, E, A]:
@@ -348,26 +345,26 @@ def catch(main: IO[R, E, A], handler: Callable[[E], IO[R, E, A]]) -> IO[R, E, A]
348345
349346
On error, call the handler with the error.
350347
"""
351-
if main.tag in [0, 3, 4, 9]:
348+
if main.tag in [0,5,7,12]:
352349
return main
353-
return IO(7, (main, handler))
350+
return IO(10, (main, handler))
354351

355352

356353
def map_error(main: IO[R, E2, A], fun: Callable[[E2], E]) -> IO[R, E2, A]:
357354
"""
358355
Transform the stored error if the computation fails on an error.
359356
Do nothing otherwise.
360357
"""
361-
if main.tag in [0, 3, 4, 9]:
358+
if main.tag in [0,5,7,12]:
362359
return main
363-
return IO(8, (main, fun))
360+
return IO(11, (main, fun))
364361

365362

366363
def panic(exception: Exception) -> IO[R, E, A]:
367364
"""
368365
Computation that fails with the panic exception.
369366
"""
370-
return IO(9, exception)
367+
return IO(12, exception)
371368

372369

373370
def recover(
@@ -378,20 +375,19 @@ def recover(
378375
379376
On panic, call the handler with the exception.
380377
"""
381-
if main.tag in [0, 4, 6]:
378+
if main.tag in [0,7,9]:
382379
return main
383-
return IO(10, (main, handler))
380+
return IO(13, (main, handler))
384381

385382

386383
def map_panic(main: IO[R, E, A], fun: Callable[[Exception], Exception]) -> IO[R, E, A]:
387384
"""
388385
Transform the exception stored if the computation fails on a panic.
389386
Do nothing otherwise.
390387
"""
391-
if main.tag in [0, 4, 6]:
388+
if main.tag in [0,7,9]:
392389
return main
393-
return IO(11, (main, fun))
394-
390+
return IO(14, (main, fun))
395391

396392
def from_result(r: Result[E, A]) -> IO[R, E, A]:
397393
"""
@@ -417,66 +413,85 @@ def run(main_context: R, main_io: IO[R, E, A]) -> Result[E, A]:
417413
cont = (0,)
418414
arg = None
419415
# CONT ID 0
420-
# CONT AP1 1 CONT CONTEXT IO
421-
# CONT AP2 2 CONT FUN
422-
# CONT FLATTEN 3 CONT CONTEXT
423-
# CONT CATCH 4 CONT CONTEXT HANDLER
424-
# CONT MAP_ERROR 5 CONT FUN
425-
# CONT RECOVER 6 CONT CONTEXT HANDLER
426-
# CONT MAP_PANIC 7 CONT FUN
416+
# CONT MAP 1 CONT FUN
417+
# CONT FLATMAP1 2 CONT CONTEXT HANDLER
418+
# CONT AP1 3 CONT CONTEXT IO
419+
# CONT AP2 4 CONT FUN
420+
# CONT FLATTEN 5 CONT CONTEXT
421+
# CONT CATCH 6 CONT CONTEXT HANDLER
422+
# CONT MAP_ERROR 7 CONT FUN
423+
# CONT RECOVER 8 CONT CONTEXT HANDLER
424+
# CONT MAP_PANIC 9 CONT FUN
427425

428426
while True:
429427
# Eval IO
430428
while True:
431429
tag = io.tag
432-
if tag == 0: # PURE
430+
if tag == 0: # PURE
433431
arg = result.Ok(io.fields)
434432
break
435-
if tag == 1: # AP
436-
cont = (1, cont, context, io.fields[1])
433+
if tag == 1: # MAP
434+
#run(context, io.main).map(io.f)
435+
cont = (1, cont, io.fields[1])
436+
io = io.fields[0]
437+
continue
438+
if tag == 2: # FLATMAP
439+
#run(context, io.main).flat_map(lambda x: run(context, io.f(x)))
440+
cont = (2, cont, context, io.fields[1])
441+
io = io.fields[0]
442+
continue
443+
if tag == 3: # AP
444+
cont = (3, cont, context, io.fields[1])
437445
io = io.fields[0]
438446
continue
439-
if tag == 2: # FLATTEN
440-
cont = (3, cont, context)
447+
if tag == 4: # FLATTEN
448+
cont = (5, cont, context)
441449
io = io.fields
442450
continue
443-
if tag == 3: # DEREF
451+
if tag == 5: # DEREF
444452
try:
445453
arg = result.Ok(io.fields())
446454
except Exception as exception:
447455
arg = result.Panic(exception)
448456
break
449-
if tag == 4: # READ
457+
if tag == 6: # DEREF_IO
458+
try:
459+
io = io.fields()
460+
continue
461+
except Exception as exception:
462+
arg = result.Panic(exception)
463+
break
464+
if tag == 7: # READ
450465
arg = result.Ok(context)
451466
break
452-
if tag == 5: # MAP READ
467+
if tag == 8: # MAP READ
453468
try:
454469
context = io.fields[0](context)
455470
io = io.fields[1]
456471
continue
457472
except Exception as exception:
458473
arg = result.Panic(exception)
459474
break
460-
if tag == 6: # RAISE
475+
if tag == 9: # RAISE
461476
arg = result.Error(io.fields)
462477
break
463-
if tag == 7: # CATCH
464-
cont = (4, cont, context, io.fields[1])
478+
if tag == 10: # CATCH
479+
cont = (6, cont, context, io.fields[1])
465480
io = io.fields[0]
466481
continue
467-
if tag == 8: # MAP ERROR
468-
cont = (5, cont, io.fields[1])
482+
if tag == 11: # MAP ERROR
483+
cont = (7, cont, io.fields[1])
469484
io = io.fields[0]
470485
continue
471-
if tag == 9: # PANIC
486+
if tag == 12: # PANIC
472487
arg = result.Panic(io.fields)
473488
break
474-
if tag == 10: # RECOVER
475-
cont = (6, cont, context, io.fields[1])
489+
if tag == 13: # RECOVER
490+
cont = (8, cont, context, io.fields[1])
476491
io = io.fields[0]
477492
continue
478-
if tag == 11: # MAP PANIC
479-
cont = (7, cont, io.fields[1])
493+
if tag == 14: # MAP PANIC
494+
cont = (9, cont, io.fields[1])
480495
io = io.fields[0]
481496
continue
482497
arg = result.Panic(_MatchError(f"{io} should be an IO"))
@@ -485,29 +500,47 @@ def run(main_context: R, main_io: IO[R, E, A]) -> Result[E, A]:
485500
# Eval Cont
486501
while True:
487502
tag = cont[0]
488-
if tag == 0: # Cont ID
503+
if tag == 0: # Cont ID
489504
return arg
490-
if tag == 1: # Cont AP1
505+
if tag == 1: # Cont MAP
506+
try:
507+
arg = arg.map(cont[2])
508+
except Exception as exception:
509+
arg = result.Panic(exception)
510+
cont = cont[1]
511+
continue
512+
if tag == 2: # Cont FLATMAP
513+
try:
514+
if isinstance(arg, result.Ok):
515+
io = cont[3](arg.success)
516+
context = cont[2]
517+
cont = cont[1]
518+
break
519+
except Exception as exception:
520+
arg = result.Panic(exception)
521+
cont = cont[1]
522+
continue
523+
if tag == 3: # Cont AP1
491524
context = cont[2]
492525
io = cont[3]
493-
cont = (2, cont[1], arg)
526+
cont = (4, cont[1], arg)
494527
break
495-
if tag == 2: # Cont AP2
528+
if tag == 4: # Cont AP2
496529
try:
497530
arg = cont[2].ap(arg)
498531
except Exception as exception:
499532
arg = result.Panic(exception)
500533
cont = cont[1]
501534
continue
502-
if tag == 3: # Cont Flatten
535+
if tag == 5: # Cont Flatten
503536
if isinstance(arg, result.Ok):
504537
context = cont[2]
505538
io = arg.success
506539
cont = cont[1]
507540
break
508541
cont = cont[1]
509542
continue
510-
if tag == 4: # Cont CATCH
543+
if tag == 6: # Cont CATCH
511544
try:
512545
if isinstance(arg, result.Error):
513546
io = cont[3](arg.error)
@@ -518,14 +551,14 @@ def run(main_context: R, main_io: IO[R, E, A]) -> Result[E, A]:
518551
arg = result.Panic(exception)
519552
cont = cont[1]
520553
continue
521-
if tag == 5: # Cont MAP ERROR
554+
if tag == 7: # Cont MAP ERROR
522555
try:
523556
arg = arg.map_error(cont[2])
524557
except Exception as exception:
525558
arg = result.Panic(exception)
526559
cont = cont[1]
527560
continue
528-
if tag == 6: # Cont RECOVER
561+
if tag == 8: # Cont RECOVER
529562
try:
530563
if isinstance(arg, result.Panic):
531564
io = cont[3](arg.exception)
@@ -536,7 +569,7 @@ def run(main_context: R, main_io: IO[R, E, A]) -> Result[E, A]:
536569
arg = result.Panic(exception)
537570
cont = cont[1]
538571
continue
539-
if tag == 7: # CONT MAP PANIC
572+
if tag == 9: # CONT MAP PANIC
540573
try:
541574
arg = arg.map_panic(cont[2])
542575
except Exception as exception:
@@ -545,7 +578,6 @@ def run(main_context: R, main_io: IO[R, E, A]) -> Result[E, A]:
545578
continue
546579
raise _MatchError(f"{cont} should be a Cont")
547580

548-
549581
def safe(f: Callable[..., IO[R, E, A]]) -> Callable[..., IO[R, E, A]]:
550582
"""
551583
Ensures a function retuning an IO never raise any exception but returns a

0 commit comments

Comments
 (0)