Skip to content

Commit 0910ec9

Browse files
committed
Adds decorators for marking functions for no input/output translation
Useful in cases where arguments or return values would never contain synchronicity objects and/or awaitables that the user wants unwrapped, and when arguments or return values are large enough that nested inspection of them would cause a signficiant performance penalty
1 parent e4e4c59 commit 0910ec9

File tree

2 files changed

+94
-22
lines changed

2 files changed

+94
-22
lines changed

synchronicity/synchronizer.py

+50-22
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def __init__(
122122

123123
# Special attribute to mark something as non-wrappable
124124
self._nowrap_attr = "_sync_nonwrap_%d" % id(self)
125+
self._no_input_translation_attr = "_sync_no_input_translation_%d" % id(self)
126+
self._no_output_unwrapping_attr = "_sync_no_output_translation_%d" % id(self)
125127

126128
# Prep a synchronized context manager in case one is returned and needs translation
127129
self._ctx_mgr_cls = contextlib._AsyncGeneratorContextManager
@@ -285,13 +287,16 @@ def _translate_in(self, obj):
285287
def _translate_out(self, obj, interface):
286288
return self._recurse_map(lambda scalar: self._translate_scalar_out(scalar, interface), obj)
287289

288-
def _translate_coro_out(self, coro, interface):
290+
def _translate_coro_out(self, coro, interface, original_func):
289291
async def unwrap_coro():
290-
return self._translate_out(await coro, interface)
292+
res = await coro
293+
if not getattr(original_func, self._no_output_unwrapping_attr, False):
294+
return self._translate_out(res, interface)
295+
return res
291296

292297
return unwrap_coro()
293298

294-
def _run_function_sync(self, coro, interface):
299+
def _run_function_sync(self, coro, interface, original_func):
295300
if self._is_inside_loop():
296301
raise Exception("Deadlock detected: calling a sync function from the synchronizer loop")
297302

@@ -300,9 +305,11 @@ def _run_function_sync(self, coro, interface):
300305
loop = self._get_loop(start=True)
301306
fut = asyncio.run_coroutine_threadsafe(coro, loop)
302307
value = fut.result()
303-
return self._translate_out(value, interface)
308+
if not getattr(original_func, self._no_output_unwrapping_attr, False):
309+
return self._translate_out(value, interface)
310+
return value
304311

305-
def _run_function_sync_future(self, coro, interface):
312+
def _run_function_sync_future(self, coro, interface, original_func):
306313
coro = wrap_coro_exception(coro)
307314
coro = self._wrap_check_async_leakage(coro)
308315
loop = self._get_loop(start=True)
@@ -311,7 +318,7 @@ def _run_function_sync_future(self, coro, interface):
311318
coro = self._translate_coro_out(coro, interface)
312319
return asyncio.run_coroutine_threadsafe(coro, loop)
313320

314-
async def _run_function_async(self, coro, interface):
321+
async def _run_function_async(self, coro, interface, original_func):
315322
coro = wrap_coro_exception(coro)
316323
coro = self._wrap_check_async_leakage(coro)
317324
loop = self._get_loop(start=True)
@@ -321,16 +328,19 @@ async def _run_function_async(self, coro, interface):
321328
c_fut = asyncio.run_coroutine_threadsafe(coro, loop)
322329
a_fut = asyncio.wrap_future(c_fut)
323330
value = await a_fut
324-
return self._translate_out(value, interface)
325331

326-
def _run_generator_sync(self, gen, interface):
332+
if not getattr(original_func, self._no_output_unwrapping_attr, False):
333+
return self._translate_out(value, interface)
334+
return value
335+
336+
def _run_generator_sync(self, gen, interface, original_func):
327337
value, is_exc = None, False
328338
while True:
329339
try:
330340
if is_exc:
331-
value = self._run_function_sync(gen.athrow(value), interface)
341+
value = self._run_function_sync(gen.athrow(value), interface, original_func)
332342
else:
333-
value = self._run_function_sync(gen.asend(value), interface)
343+
value = self._run_function_sync(gen.asend(value), interface, original_func)
334344
except UserCodeException as uc_exc:
335345
raise uc_exc.exc from None
336346
except StopAsyncIteration:
@@ -342,14 +352,14 @@ def _run_generator_sync(self, gen, interface):
342352
value = exc
343353
is_exc = True
344354

345-
async def _run_generator_async(self, gen, interface):
355+
async def _run_generator_async(self, gen, interface, original_func):
346356
value, is_exc = None, False
347357
while True:
348358
try:
349359
if is_exc:
350-
value = await self._run_function_async(gen.athrow(value), interface)
360+
value = await self._run_function_async(gen.athrow(value), interface, original_func)
351361
else:
352-
value = await self._run_function_async(gen.asend(value), interface)
362+
value = await self._run_function_async(gen.asend(value), interface, original_func)
353363
except UserCodeException as uc_exc:
354364
raise uc_exc.exc from None
355365
except StopAsyncIteration:
@@ -403,8 +413,9 @@ def f_wrapped(*args, **kwargs):
403413

404414
# If this gets called with an argument that represents an external type,
405415
# translate it into an internal type
406-
args = self._translate_in(args)
407-
kwargs = self._translate_in(kwargs)
416+
if not getattr(f, self._no_input_translation_attr, False):
417+
args = self._translate_in(args)
418+
kwargs = self._translate_in(kwargs)
408419

409420
# Call the function
410421
res = f(*args, **kwargs)
@@ -417,14 +428,14 @@ def f_wrapped(*args, **kwargs):
417428
if not allow_futures:
418429
raise Exception("Can not return future for this function")
419430
elif is_coroutine:
420-
return self._run_function_sync_future(res, interface)
431+
return self._run_function_sync_future(res, interface, f)
421432
elif is_asyncgen:
422433
raise Exception("Can not return futures for generators")
423434
else:
424435
return res
425436
elif is_coroutine:
426437
if interface in (Interface.ASYNC, Interface._ASYNC_WITH_BLOCKING_TYPES):
427-
coro = self._run_function_async(res, interface)
438+
coro = self._run_function_async(res, interface, f)
428439
if not is_coroutinefunction:
429440
# If this is a non-async function that returns a coroutine,
430441
# then this is the exit point, and we need to unwrap any
@@ -435,7 +446,7 @@ def f_wrapped(*args, **kwargs):
435446
elif interface == Interface.BLOCKING:
436447
# This is the exit point, so we need to unwrap the exception here
437448
try:
438-
return self._run_function_sync(res, interface)
449+
return self._run_function_sync(res, interface, f)
439450
except UserCodeException as uc_exc:
440451
# Used to skip a frame when called from `proxy_method`.
441452
if unwrap_user_excs and not (Interface.BLOCKING and include_aio_interface):
@@ -446,9 +457,9 @@ def f_wrapped(*args, **kwargs):
446457
# Note that the _run_generator_* functions handle their own
447458
# unwrapping of exceptions (this happens during yielding)
448459
if interface in (Interface.ASYNC, Interface._ASYNC_WITH_BLOCKING_TYPES):
449-
return self._run_generator_async(res, interface)
460+
return self._run_generator_async(res, interface, f)
450461
elif interface == Interface.BLOCKING:
451-
return self._run_generator_sync(res, interface)
462+
return self._run_generator_sync(res, interface, f)
452463
else:
453464
if inspect.isfunction(res) or isinstance(res, functools.partial): # TODO: HACKY HACK
454465
# TODO: this is needed for decorator wrappers that returns functions
@@ -458,11 +469,17 @@ def f_wrapped(*args, **kwargs):
458469
args = self._translate_in(args)
459470
kwargs = self._translate_in(kwargs)
460471
f_res = res(*args, **kwargs)
461-
return self._translate_out(f_res, interface)
472+
if not getattr(f, self._no_output_unwrapping_attr, False):
473+
return self._translate_out(f_res, interface)
474+
else:
475+
return f_res
462476

463477
return f_wrapped
464478

465-
return self._translate_out(res, interface)
479+
if not getattr(f, self._no_output_unwrapping_attr, False):
480+
return self._translate_out(res, interface)
481+
else:
482+
return res
466483

467484
self._update_wrapper(f_wrapped, f, _name, interface, target_module=target_module)
468485
setattr(f_wrapped, self._original_attr, f)
@@ -701,6 +718,17 @@ def nowrap(self, obj):
701718
setattr(obj, self._nowrap_attr, True)
702719
return obj
703720

721+
def no_input_translation(self, obj):
722+
setattr(obj, self._no_input_translation_attr, True)
723+
return obj
724+
725+
def no_output_translation(self, obj):
726+
setattr(obj, self._no_output_unwrapping_attr, True)
727+
return obj
728+
729+
def no_io_translation(self, obj):
730+
return self.no_input_translation(self.no_output_translation(obj))
731+
704732
# New interface that (almost) doesn't mutate objects
705733
def create_blocking(self, obj, name: Optional[str] = None, target_module: Optional[str] = None):
706734
wrapped = self._wrap(obj, Interface.BLOCKING, name, target_module=target_module)

test/synchronicity_test.py

+44
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import concurrent.futures
33
import inspect
44
from typing import Coroutine
5+
from unittest.mock import MagicMock, ANY
56

67
import pytest
78
import time
@@ -490,3 +491,46 @@ async def get_self(self):
490491
self_from_aio_interface = await original.get_self.aio()
491492
assert self_from_aio_interface == original
492493
assert isinstance(self_from_aio_interface, BlockingFoo)
494+
495+
496+
497+
def test_no_input_translation(monkeypatch):
498+
s = Synchronizer()
499+
@s.create_blocking
500+
def does_input_translation(arg: float) -> str:
501+
return str(arg)
502+
503+
@s.create_blocking
504+
@s.no_input_translation
505+
async def without_input_translation(arg: float) -> str:
506+
return str(arg)
507+
508+
in_translate_spy = MagicMock(wraps=s._translate_scalar_in)
509+
monkeypatch.setattr(s, "_translate_scalar_in", in_translate_spy)
510+
does_input_translation(3.14) # test without decorator, this *should* do input translation
511+
in_translate_spy.assert_called_once_with(3.14)
512+
513+
in_translate_spy.reset_mock()
514+
without_input_translation(3.14) # test without decorator, this *should* do input translation
515+
in_translate_spy.assert_not_called()
516+
517+
518+
def test_no_output_translation(monkeypatch):
519+
s = Synchronizer()
520+
@s.create_blocking
521+
def does_input_translation(arg: float) -> str:
522+
return str(arg)
523+
524+
@s.create_blocking
525+
@s.no_output_translation
526+
async def without_output_translation(arg: float) -> str:
527+
return str(arg)
528+
529+
out_translate_spy = MagicMock(wraps=s._translate_scalar_out)
530+
monkeypatch.setattr(s, "_translate_scalar_out", out_translate_spy)
531+
does_input_translation(3.14) # test without decorator, this *should* do input translation
532+
out_translate_spy.assert_called_once_with("3.14", Interface.BLOCKING)
533+
534+
out_translate_spy.reset_mock()
535+
without_output_translation(3.14) # test without decorator, this *should* do input translation
536+
out_translate_spy.assert_not_called()

0 commit comments

Comments
 (0)