diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 286ef0dab6ae..2f872212640a 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2278,6 +2278,22 @@ def get_arg_infer_passes( # run(test, 1, 2) # we will use `test` for inference, since it will allow to infer also # argument *names* for P <: [x: int, y: int]. + if isinstance(p_actual, UnionType): + new_items = [] + for item in p_actual.items: + # narrow the union based on some approximations + p_item = get_proper_type(item) + if isinstance(p_item, CallableType) or ( + isinstance(p_item, Instance) + and find_member("__call__", p_item, p_item, is_operator=True) + is not None + ): + new_items.append(p_item) + if len(new_items) == 2: + break + + if len(new_items) == 1: + p_actual = new_items[0] if isinstance(p_actual, Instance): call_method = find_member("__call__", p_actual, p_actual, is_operator=True) if call_method is not None: diff --git a/mypy/typeops.py b/mypy/typeops.py index 1667e8431a17..0afaaa9a7a3c 100644 --- a/mypy/typeops.py +++ b/mypy/typeops.py @@ -458,21 +458,20 @@ def callable_corresponding_argument( if by_name is not None and by_pos is not None: if by_name == by_pos: return by_name - # If we're dealing with an optional pos-only and an optional - # name-only arg, merge them. This is the case for all functions - # taking both *args and **args, or a pair of functions like so: + # If we're dealing with an optional pos and an optional + # name arg, merge them. This is the case for all functions + # taking both *args and **args, or a functions like so: # def right(a: int = ...) -> None: ... - # def left(__a: int = ..., *, a: int = ...) -> None: ... - from mypy.subtypes import is_equivalent + # def left1(__a: int = ..., *, a: int = ...) -> None: ... + # def left2(x: int = ..., a: int = ...) -> None: ... - if ( - not (by_name.required or by_pos.required) - and by_pos.name is None - and by_name.pos is None - and is_equivalent(by_name.typ, by_pos.typ) - ): - return FormalArgument(by_name.name, by_pos.pos, by_name.typ, False) + from mypy.meet import meet_types + + if not (by_name.required or by_pos.required): + return FormalArgument( + by_name.name, by_pos.pos, meet_types(by_pos.typ, by_name.typ), False + ) return by_name if by_name is not None else by_pos diff --git a/test-data/unit/check-overloading.test b/test-data/unit/check-overloading.test index 2092f99487b0..949f6d3053f8 100644 --- a/test-data/unit/check-overloading.test +++ b/test-data/unit/check-overloading.test @@ -6768,3 +6768,34 @@ class D(Generic[T]): a: D[str] # E: Type argument "str" of "D" must be a subtype of "C" reveal_type(a.f(1)) # N: Revealed type is "builtins.int" reveal_type(a.f("x")) # N: Revealed type is "builtins.str" + +[case testOverloadWithTwoRelevantArgsWithDifferentType] +from typing import overload, Union + +@overload +def set(year: int) -> None: + ... + +@overload +def set() -> None: + ... + +# no error here: +def set(*args: object, **kw: int) -> None: + pass +[builtins fixtures/tuple.pyi] + +[case testOverloadWithTwoRelevantOptionalArgs] +from typing import overload + +@overload +def set(year: int) -> None: + ... + +@overload +def set() -> None: + ... + +# no error: +def set(x: int = 42, year: int = 42) -> None: + pass diff --git a/test-data/unit/check-parameter-specification.test b/test-data/unit/check-parameter-specification.test index 5530bc0ecbf9..d61f3f9b4536 100644 --- a/test-data/unit/check-parameter-specification.test +++ b/test-data/unit/check-parameter-specification.test @@ -2560,3 +2560,21 @@ def fn(f: MiddlewareFactory[P]) -> Capture[P]: ... reveal_type(fn(ServerErrorMiddleware)) # N: Revealed type is "__main__.Capture[[handler: Union[builtins.str, None] =, debug: builtins.bool =]]" [builtins fixtures/paramspec.pyi] + +[case testParamSpecInferenceWithAny] +from typing_extensions import ParamSpec +from typing import Any, Callable, Union + +P = ParamSpec("P") + +def into(f: Callable[P, None], *args: P.args, **kwargs: P.kwargs) -> None: + return None + +class C: + def f(self, y: bool = False, *, x: int = 42) -> None: + return None + +ex: Union[C, Any] = C() + +into(ex.f, x=-1) +[builtins fixtures/paramspec.pyi]