@@ -122,6 +122,8 @@ def __init__(
122
122
123
123
# Special attribute to mark something as non-wrappable
124
124
self ._nowrap_attr = "_sync_nonwrap_%d" % id (self )
125
+ self ._no_input_translation_attr = "_sync_no_input_translation_%d" % id (self )
126
+ self ._no_output_unwrapping_attr = "_sync_no_output_translation_%d" % id (self )
125
127
126
128
# Prep a synchronized context manager in case one is returned and needs translation
127
129
self ._ctx_mgr_cls = contextlib ._AsyncGeneratorContextManager
@@ -285,13 +287,16 @@ def _translate_in(self, obj):
285
287
def _translate_out (self , obj , interface ):
286
288
return self ._recurse_map (lambda scalar : self ._translate_scalar_out (scalar , interface ), obj )
287
289
288
- def _translate_coro_out (self , coro , interface ):
290
+ def _translate_coro_out (self , coro , interface , original_func ):
289
291
async def unwrap_coro ():
290
- return self ._translate_out (await coro , interface )
292
+ res = await coro
293
+ if not getattr (original_func , self ._no_output_unwrapping_attr , False ):
294
+ return self ._translate_out (res , interface )
295
+ return res
291
296
292
297
return unwrap_coro ()
293
298
294
- def _run_function_sync (self , coro , interface ):
299
+ def _run_function_sync (self , coro , interface , original_func ):
295
300
if self ._is_inside_loop ():
296
301
raise Exception ("Deadlock detected: calling a sync function from the synchronizer loop" )
297
302
@@ -300,9 +305,11 @@ def _run_function_sync(self, coro, interface):
300
305
loop = self ._get_loop (start = True )
301
306
fut = asyncio .run_coroutine_threadsafe (coro , loop )
302
307
value = fut .result ()
303
- return self ._translate_out (value , interface )
308
+ if not getattr (original_func , self ._no_output_unwrapping_attr , False ):
309
+ return self ._translate_out (value , interface )
310
+ return value
304
311
305
- def _run_function_sync_future (self , coro , interface ):
312
+ def _run_function_sync_future (self , coro , interface , original_func ):
306
313
coro = wrap_coro_exception (coro )
307
314
coro = self ._wrap_check_async_leakage (coro )
308
315
loop = self ._get_loop (start = True )
@@ -311,7 +318,7 @@ def _run_function_sync_future(self, coro, interface):
311
318
coro = self ._translate_coro_out (coro , interface )
312
319
return asyncio .run_coroutine_threadsafe (coro , loop )
313
320
314
- async def _run_function_async (self , coro , interface ):
321
+ async def _run_function_async (self , coro , interface , original_func ):
315
322
coro = wrap_coro_exception (coro )
316
323
coro = self ._wrap_check_async_leakage (coro )
317
324
loop = self ._get_loop (start = True )
@@ -321,16 +328,19 @@ async def _run_function_async(self, coro, interface):
321
328
c_fut = asyncio .run_coroutine_threadsafe (coro , loop )
322
329
a_fut = asyncio .wrap_future (c_fut )
323
330
value = await a_fut
324
- return self ._translate_out (value , interface )
325
331
326
- def _run_generator_sync (self , gen , interface ):
332
+ if not getattr (original_func , self ._no_output_unwrapping_attr , False ):
333
+ return self ._translate_out (value , interface )
334
+ return value
335
+
336
+ def _run_generator_sync (self , gen , interface , original_func ):
327
337
value , is_exc = None , False
328
338
while True :
329
339
try :
330
340
if is_exc :
331
- value = self ._run_function_sync (gen .athrow (value ), interface )
341
+ value = self ._run_function_sync (gen .athrow (value ), interface , original_func )
332
342
else :
333
- value = self ._run_function_sync (gen .asend (value ), interface )
343
+ value = self ._run_function_sync (gen .asend (value ), interface , original_func )
334
344
except UserCodeException as uc_exc :
335
345
raise uc_exc .exc from None
336
346
except StopAsyncIteration :
@@ -342,14 +352,14 @@ def _run_generator_sync(self, gen, interface):
342
352
value = exc
343
353
is_exc = True
344
354
345
- async def _run_generator_async (self , gen , interface ):
355
+ async def _run_generator_async (self , gen , interface , original_func ):
346
356
value , is_exc = None , False
347
357
while True :
348
358
try :
349
359
if is_exc :
350
- value = await self ._run_function_async (gen .athrow (value ), interface )
360
+ value = await self ._run_function_async (gen .athrow (value ), interface , original_func )
351
361
else :
352
- value = await self ._run_function_async (gen .asend (value ), interface )
362
+ value = await self ._run_function_async (gen .asend (value ), interface , original_func )
353
363
except UserCodeException as uc_exc :
354
364
raise uc_exc .exc from None
355
365
except StopAsyncIteration :
@@ -403,8 +413,9 @@ def f_wrapped(*args, **kwargs):
403
413
404
414
# If this gets called with an argument that represents an external type,
405
415
# translate it into an internal type
406
- args = self ._translate_in (args )
407
- kwargs = self ._translate_in (kwargs )
416
+ if not getattr (f , self ._no_input_translation_attr , False ):
417
+ args = self ._translate_in (args )
418
+ kwargs = self ._translate_in (kwargs )
408
419
409
420
# Call the function
410
421
res = f (* args , ** kwargs )
@@ -417,14 +428,14 @@ def f_wrapped(*args, **kwargs):
417
428
if not allow_futures :
418
429
raise Exception ("Can not return future for this function" )
419
430
elif is_coroutine :
420
- return self ._run_function_sync_future (res , interface )
431
+ return self ._run_function_sync_future (res , interface , f )
421
432
elif is_asyncgen :
422
433
raise Exception ("Can not return futures for generators" )
423
434
else :
424
435
return res
425
436
elif is_coroutine :
426
437
if interface in (Interface .ASYNC , Interface ._ASYNC_WITH_BLOCKING_TYPES ):
427
- coro = self ._run_function_async (res , interface )
438
+ coro = self ._run_function_async (res , interface , f )
428
439
if not is_coroutinefunction :
429
440
# If this is a non-async function that returns a coroutine,
430
441
# then this is the exit point, and we need to unwrap any
@@ -435,7 +446,7 @@ def f_wrapped(*args, **kwargs):
435
446
elif interface == Interface .BLOCKING :
436
447
# This is the exit point, so we need to unwrap the exception here
437
448
try :
438
- return self ._run_function_sync (res , interface )
449
+ return self ._run_function_sync (res , interface , f )
439
450
except UserCodeException as uc_exc :
440
451
# Used to skip a frame when called from `proxy_method`.
441
452
if unwrap_user_excs and not (Interface .BLOCKING and include_aio_interface ):
@@ -446,9 +457,9 @@ def f_wrapped(*args, **kwargs):
446
457
# Note that the _run_generator_* functions handle their own
447
458
# unwrapping of exceptions (this happens during yielding)
448
459
if interface in (Interface .ASYNC , Interface ._ASYNC_WITH_BLOCKING_TYPES ):
449
- return self ._run_generator_async (res , interface )
460
+ return self ._run_generator_async (res , interface , f )
450
461
elif interface == Interface .BLOCKING :
451
- return self ._run_generator_sync (res , interface )
462
+ return self ._run_generator_sync (res , interface , f )
452
463
else :
453
464
if inspect .isfunction (res ) or isinstance (res , functools .partial ): # TODO: HACKY HACK
454
465
# TODO: this is needed for decorator wrappers that returns functions
@@ -458,11 +469,17 @@ def f_wrapped(*args, **kwargs):
458
469
args = self ._translate_in (args )
459
470
kwargs = self ._translate_in (kwargs )
460
471
f_res = res (* args , ** kwargs )
461
- return self ._translate_out (f_res , interface )
472
+ if not getattr (f , self ._no_output_unwrapping_attr , False ):
473
+ return self ._translate_out (f_res , interface )
474
+ else :
475
+ return f_res
462
476
463
477
return f_wrapped
464
478
465
- return self ._translate_out (res , interface )
479
+ if not getattr (f , self ._no_output_unwrapping_attr , False ):
480
+ return self ._translate_out (res , interface )
481
+ else :
482
+ return res
466
483
467
484
self ._update_wrapper (f_wrapped , f , _name , interface , target_module = target_module )
468
485
setattr (f_wrapped , self ._original_attr , f )
@@ -701,6 +718,17 @@ def nowrap(self, obj):
701
718
setattr (obj , self ._nowrap_attr , True )
702
719
return obj
703
720
721
+ def no_input_translation (self , obj ):
722
+ setattr (obj , self ._no_input_translation_attr , True )
723
+ return obj
724
+
725
+ def no_output_translation (self , obj ):
726
+ setattr (obj , self ._no_output_unwrapping_attr , True )
727
+ return obj
728
+
729
+ def no_io_translation (self , obj ):
730
+ return self .no_input_translation (self .no_output_translation (obj ))
731
+
704
732
# New interface that (almost) doesn't mutate objects
705
733
def create_blocking (self , obj , name : Optional [str ] = None , target_module : Optional [str ] = None ):
706
734
wrapped = self ._wrap (obj , Interface .BLOCKING , name , target_module = target_module )
0 commit comments