5
5
since these may assume that MROs are ready.
6
6
"""
7
7
from collections import defaultdict
8
- from typing import cast , Optional , List , Sequence , Set , Iterable , TypeVar , Dict , Tuple , Any , Union
8
+ from typing import (
9
+ cast , Optional , List , Sequence , Set , Iterable , TypeVar , Dict , Tuple , Any , Union , Callable
10
+ )
9
11
from typing_extensions import Type as TypingType
10
12
import itertools
11
13
import sys
@@ -919,7 +921,8 @@ def separate_union_literals(t: UnionType) -> Tuple[Sequence[LiteralType], Sequen
919
921
return literal_items , union_items
920
922
921
923
922
- def infer_impl_from_parts (impl : OverloadPart , types : List [CallableType ], fallback : Instance ):
924
+ def infer_impl_from_parts (impl : OverloadPart , types : List [CallableType ], fallback : Instance ,
925
+ named_type : Callable [[str , List [Type ]], Type ]):
923
926
impl_func = impl if isinstance (impl , FuncDef ) else impl .func
924
927
# infer the types of the impl from the overload types
925
928
arg_types : Dict [str , List [Type ]] = defaultdict (list )
@@ -930,8 +933,13 @@ def infer_impl_from_parts(impl: OverloadPart, types: List[CallableType], fallbac
930
933
if arg_name and arg_name in impl_func .arg_names :
931
934
if arg_type not in arg_types [arg_name ]:
932
935
arg_types [arg_name ].append (arg_type )
933
- if tp .ret_type not in ret_types :
934
- ret_types .append (tp .ret_type )
936
+ t = get_proper_type (tp .ret_type )
937
+ if isinstance (t , Instance ) and t .type .fullname == "typing.Coroutine" :
938
+ ret_type = t .args [2 ]
939
+ else :
940
+ ret_type = tp .ret_type
941
+ if ret_type not in ret_types :
942
+ ret_types .append (ret_type )
935
943
arg_types2 = {
936
944
name : UnionType .make_union (it )
937
945
for name , it in arg_types .items ()
@@ -943,6 +951,12 @@ def infer_impl_from_parts(impl: OverloadPart, types: List[CallableType], fallbac
943
951
for arg_name , arg_kind in zip (impl_func .arg_names , impl_func .arg_kinds )
944
952
]
945
953
ret_type = UnionType .make_union (ret_types )
954
+
955
+ if impl_func .is_coroutine :
956
+ # if the impl is a coroutine, then assume the parts are also, if not need annotation
957
+ any_type = AnyType (TypeOfAny .special_form )
958
+ ret_type = named_type ("typing.Coroutine" , [any_type , any_type , ret_type ])
959
+
946
960
# use unanalyzed_type because we would have already tried to infer from defaults
947
961
if impl_func .unanalyzed_type :
948
962
assert isinstance (impl_func .unanalyzed_type , CallableType )
0 commit comments