diff --git a/crates/pyrefly_types/src/display.rs b/crates/pyrefly_types/src/display.rs index b2a574fee1..87c69bebbb 100644 --- a/crates/pyrefly_types/src/display.rs +++ b/crates/pyrefly_types/src/display.rs @@ -561,6 +561,24 @@ impl<'a> TypeDisplayContext<'a> { output.write_str(ta.name.as_str()) } } + Type::Forall(box Forall { + tparams, + body: Forallable::ParamSpecValue(_), + }) => { + output.write_str("[")?; + output.write_str(&format!("{}", commas_iter(|| tparams.iter())))?; + output.write_str("]")?; + output.write_str("") + } + Type::Forall(box Forall { + tparams, + body: Forallable::Concatenate(_, _), + }) => { + output.write_str("[")?; + output.write_str(&format!("{}", commas_iter(|| tparams.iter())))?; + output.write_str("]")?; + output.write_str("") + } Type::Type(ty) => { output.write_str("type[")?; self.fmt_helper_generic(ty, false, output)?; diff --git a/crates/pyrefly_types/src/types.rs b/crates/pyrefly_types/src/types.rs index 07aaec080d..e371036682 100644 --- a/crates/pyrefly_types/src/types.rs +++ b/crates/pyrefly_types/src/types.rs @@ -548,11 +548,43 @@ impl Forall { /// These are things that can have Forall around them, so often you see `Forall` #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[derive(Visit, VisitMut, TypeEq)] +#[derive(TypeEq)] pub enum Forallable { TypeAlias(TypeAlias), Function(Function), Callable(Callable), + ParamSpecValue(ParamList), + Concatenate(Box<[Type]>, Box), +} + +impl Visit for Forallable { + fn recurse<'a>(&'a self, f: &mut dyn FnMut(&'a Type)) { + match self { + Self::TypeAlias(ta) => ta.visit(f), + Self::Function(func) => func.visit(f), + Self::Callable(callable) => callable.visit(f), + Self::ParamSpecValue(params) => params.visit(f), + Self::Concatenate(args, inner) => { + args.visit(f); + inner.visit(f); + } + } + } +} + +impl VisitMut for Forallable { + fn recurse_mut(&mut self, f: &mut dyn FnMut(&mut Type)) { + match self { + Self::TypeAlias(ta) => ta.visit_mut(f), + Self::Function(func) => func.visit_mut(f), + Self::Callable(callable) => callable.visit_mut(f), + Self::ParamSpecValue(params) => params.visit_mut(f), + Self::Concatenate(args, inner) => { + args.visit_mut(f); + inner.visit_mut(f); + } + } + } } impl Forallable { @@ -572,6 +604,8 @@ impl Forallable { Self::Function(func) => func.metadata.kind.function_name(), Self::Callable(_) => Cow::Owned(Name::new_static("")), Self::TypeAlias(ta) => Cow::Borrowed(&ta.name), + Self::ParamSpecValue(_) => Cow::Owned(Name::new_static("")), + Self::Concatenate(_, _) => Cow::Owned(Name::new_static("")), } } @@ -580,6 +614,10 @@ impl Forallable { Self::Function(func) => Type::Function(Box::new(func)), Self::Callable(callable) => Type::Callable(Box::new(callable)), Self::TypeAlias(ta) => Type::TypeAlias(Box::new(ta)), + Self::ParamSpecValue(params) => Type::ParamSpecValue(params), + Self::Concatenate(args, paramspec) => { + Type::Concatenate(args, Box::new(paramspec.as_type())) + } } } @@ -588,6 +626,8 @@ impl Forallable { Self::Function(func) => func.signature.is_typeguard(), Self::Callable(callable) => callable.is_typeguard(), Self::TypeAlias(_) => false, + Self::ParamSpecValue(_) => false, + Self::Concatenate(_, _) => false, } } @@ -596,6 +636,8 @@ impl Forallable { Self::Function(func) => func.signature.is_typeis(), Self::Callable(callable) => callable.is_typeis(), Self::TypeAlias(_) => false, + Self::ParamSpecValue(_) => false, + Self::Concatenate(_, _) => false, } } } diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 157a208c22..4ac9a2e43b 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -79,7 +79,7 @@ pub enum CallTarget { /// Method of a class. The `Type` is the self/cls argument. BoundMethod(Type, TargetWithTParams), /// A class object. - Class(ClassType, ConstructorKind), + Class(TargetWithTParams, ConstructorKind), /// A TypedDict. TypedDict(TypedDict), /// An overloaded function. @@ -213,7 +213,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Type::ClassDef(cls) => match self.instantiate(&cls) { // `instantiate` can only return `ClassType` or `TypedDict` Type::ClassType(cls) => CallTargetLookup::Ok(Box::new(CallTarget::Class( - cls, + TargetWithTParams(None, cls), ConstructorKind::BareClassName, ))), Type::TypedDict(typed_dict) => { @@ -223,13 +223,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }, Type::Type(box Type::ClassType(cls)) | Type::Type(box Type::SelfType(cls)) => { CallTargetLookup::Ok(Box::new(CallTarget::Class( - cls, + TargetWithTParams(None, cls), + ConstructorKind::TypeOfClass, + ))) + } + Type::Type(box Type::Tuple(tuple)) => { + CallTargetLookup::Ok(Box::new(CallTarget::Class( + TargetWithTParams(None, self.erase_tuple_type(tuple)), ConstructorKind::TypeOfClass, ))) } - Type::Type(box Type::Tuple(tuple)) => CallTargetLookup::Ok(Box::new( - CallTarget::Class(self.erase_tuple_type(tuple), ConstructorKind::TypeOfClass), - )), Type::Type(box Type::Quantified(quantified)) => { CallTargetLookup::Ok(Box::new(CallTarget::Callable(TargetWithTParams( None, @@ -249,7 +252,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { match &mut target { CallTargetLookup::Ok( box (CallTarget::Callable(TargetWithTParams(x, _)) - | CallTarget::Function(TargetWithTParams(x, _))), + | CallTarget::Function(TargetWithTParams(x, _)) + | CallTarget::Class(TargetWithTParams(x, _), _)), ) => { *x = Some(forall.tparams); } @@ -756,7 +760,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } }; let res = match call_target { - CallTarget::Class(cls, constructor_kind) => { + CallTarget::Class(TargetWithTParams(_tparams, cls), constructor_kind) => { if cls.has_qname("typing", "Any") { return self.error( errors, diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 92714da64e..9e996521f2 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -46,6 +46,7 @@ use crate::types::callable::Params; use crate::types::callable::Required; use crate::types::quantified::Quantified; use crate::types::tuple::Tuple; +use crate::types::types::Forallable; use crate::types::types::Type; use crate::types::types::Var; @@ -1043,6 +1044,38 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Params::ParamSpec(concatenate, p) => { let p = self.solver().expand_vars(p); match p { + Type::Forall(box forall) + if matches!(forall.body, Forallable::ParamSpecValue(_)) => + { + // Instantiate fresh type variables for the Forall-wrapped ParamSpecValue + if let Forallable::ParamSpecValue(params_before) = forall.body { + let (_qs_params, instantiated_type) = self.solver().fresh_quantified( + &forall.tparams, + Type::ParamSpecValue(params_before), + self.uniques, + ); + + // Extract the instantiated params + if let Type::ParamSpecValue(params) = instantiated_type { + self.callable_infer_params( + callable_name, + ¶ms.prepend_types(&concatenate), + None, + self_arg, + args, + keywords, + range, + arg_errors, + call_errors, + context, + ) + } else { + unreachable!() + } + } else { + unreachable!() + } + } Type::ParamSpecValue(params) => self.callable_infer_params( callable_name, ¶ms.prepend_types(&concatenate), diff --git a/pyrefly/lib/alt/class/targs.rs b/pyrefly/lib/alt/class/targs.rs index e6a5a1b442..7a6102e446 100644 --- a/pyrefly/lib/alt/class/targs.rs +++ b/pyrefly/lib/alt/class/targs.rs @@ -244,6 +244,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .1 } + /// Instantiates a class with fresh variables, returning both the handle and the type. + /// Used when the caller needs to finalize the quantified variables. + pub fn instantiate_fresh_class_with_handle(&self, cls: &Class) -> (QuantifiedHandle, Type) { + self.solver().fresh_quantified( + &self.get_class_tparams(cls), + self.instantiate(cls), + self.uniques, + ) + } + pub fn instantiate_fresh_tuple(&self) -> Type { let quantified = Quantified::type_var_tuple(Name::new_static("Ts"), self.uniques, None); let tparams = TParams::new(vec![TParam { diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 0e53dd2f85..f1c94ee073 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -1366,6 +1366,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { })) }), Forallable::TypeAlias(_) => None, + Forallable::ParamSpecValue(_) => None, + Forallable::Concatenate(_, _) => None, }, Type::Callable(callable) => callable .split_first_param() diff --git a/pyrefly/lib/query.rs b/pyrefly/lib/query.rs index 89c075e595..9d394b1df8 100644 --- a/pyrefly/lib/query.rs +++ b/pyrefly/lib/query.rs @@ -880,6 +880,8 @@ impl<'a> CalleesWithLocation<'a> { Forallable::TypeAlias(t) => { self.callee_from_type(&t.as_type(), call_target, callee_range, call_arguments) } + Forallable::ParamSpecValue(_) => vec![], + Forallable::Concatenate(_, _) => vec![], }, Type::SelfType(c) | Type::ClassType(c) => { self.callee_from_mro(c.class_object(), "__call__", |_solver, c| { diff --git a/pyrefly/lib/report/pysa/call_graph.rs b/pyrefly/lib/report/pysa/call_graph.rs index 1a75eb9ab6..275baeea5e 100644 --- a/pyrefly/lib/report/pysa/call_graph.rs +++ b/pyrefly/lib/report/pysa/call_graph.rs @@ -1690,7 +1690,10 @@ impl<'a> CallGraphVisitor<'a> { ) .into_call_callees() } - Some(CallTargetLookup::Ok(box crate::alt::call::CallTarget::Class(class_type, _))) => { + Some(CallTargetLookup::Ok(box crate::alt::call::CallTarget::Class( + crate::alt::call::TargetWithTParams(_tparams, class_type), + _, + ))) => { // Constructing a class instance. let (init_method, new_method) = self .module_context diff --git a/pyrefly/lib/solver/solver.rs b/pyrefly/lib/solver/solver.rs index 4f849d55fa..b434c94258 100644 --- a/pyrefly/lib/solver/solver.rs +++ b/pyrefly/lib/solver/solver.rs @@ -12,9 +12,11 @@ use std::cell::RefMut; use std::fmt; use std::fmt::Display; use std::mem; +use std::sync::Arc; use pyrefly_types::quantified::Quantified; use pyrefly_types::simplify::intersect; +use pyrefly_types::type_var::PreInferenceVariance; use pyrefly_types::types::TArgs; use pyrefly_util::gas::Gas; use pyrefly_util::lock::Mutex; @@ -46,6 +48,9 @@ use crate::types::module::ModuleType; use crate::types::simplify::simplify_tuples; use crate::types::simplify::unions; use crate::types::simplify::unions_with_literals; +use crate::types::types::Forall; +use crate::types::types::Forallable; +use crate::types::types::TParam; use crate::types::types::TParams; use crate::types::types::Type; use crate::types::types::Var; @@ -351,6 +356,16 @@ impl Solver { } } + /// Get the Quantified backing a Var, if it is backed by Variable::Quantified or Variable::Unsolved + pub fn get_quantified_for_var(&self, var: Var) -> Option { + let lock = self.variables.lock(); + let variable = lock.get(var); + match &*variable { + Variable::Quantified(q) | Variable::Unsolved(q) => Some((**q).clone()), + _ => None, + } + } + /// Finish the type returned from a function call. This entails expanding solved variables and /// erasing unsolved variables without defaults from unions. pub fn finish_function_return(&self, mut t: Type) -> Type { @@ -489,6 +504,8 @@ impl Solver { ret, }) = callable { + // Clone kind for use in Forall case since it's moved into the closure + let kind_for_forall = kind.as_ref().map(|k| (*k).clone()); let new_callable = |c| { if let Some(k) = kind { Type::Function(Box::new(Function { @@ -505,6 +522,92 @@ impl Solver { let new_callable = new_callable(Callable::list(params, ret.clone())); *x = new_callable; } + Type::Forall(box Forall { + tparams: _, + body: Forallable::ParamSpecValue(paramlist), + }) => { + // Lift the Forall to wrap the entire Callable + // + // The params may contain Vars that need to be resolved to their + // representative Quantifieds (after all unifications are complete). + // The return type may also contain Vars that are unified with param Vars. + // + // Strategy: + // 1. Collect all Vars in params that are backed by Quantifieds + // 2. Find the representative Quantified for each (after unification) + // 3. Substitute Vars with representative Quantifieds in both params and ret + // 4. Build new tparams from the representative Quantifieds + + let mut params = paramlist.prepend_types(ts).into_owned(); + + // Collect Vars from params and their representative Quantifieds + let mut var_to_representative: SmallMap = SmallMap::new(); + let mut representative_quantifieds: Vec = Vec::new(); + + for param in params.items() { + param.as_type().universe(&mut |t| { + if let Type::Var(v) = t { + if let Some(q) = self.get_quantified_for_var(*v) { + var_to_representative.insert(*v, q.clone()); + if !representative_quantifieds.contains(&q) { + representative_quantifieds.push(q); + } + } + } + }); + } + + // Check the return type for Vars that are unified with param Vars. + // This handles both simple cases (return type is T) and complex cases + // (return type is C[T] where T is inside a complex type). + // + // We use transform_mut which recurses into the type structure first, + // then applies the transformation. This is necessary because visit_mut + // returns early if the downcast to Type succeeds (which it does for ClassType). + let mut final_ret = ret.clone(); + final_ret.transform_mut(&mut |t| { + if let Type::Var(v) = t { + if let Some(ret_q) = self.get_quantified_for_var(*v) { + if representative_quantifieds.contains(&ret_q) { + *t = Type::Quantified(Box::new(ret_q)); + } + } + } + }); + + // Substitute Vars in params with their representative Quantifieds + params.visit_mut(&mut |t| { + if let Type::Var(v) = t { + if let Some(q) = var_to_representative.get(v) { + *t = Type::Quantified(Box::new(q.clone())); + } + } + }); + + // Build new tparams from representative Quantifieds + // + // TODO: We use PUndefined (Bivariant) because Quantified doesn't store + // variance and Variable::Quantified only stores Quantified, not TParam. + // The proper fix would be to track variance through fresh_quantified. + // PUndefined is safe (won't cause false errors) but may miss some errors. + let new_tparams: Vec = representative_quantifieds + .into_iter() + .map(|q| TParam { + quantified: q, + variance: PreInferenceVariance::PUndefined, + }) + .collect(); + + let forallable = if let Some(k) = kind_for_forall { + Forallable::Function(Function { + signature: Callable::list(params, final_ret), + metadata: k, + }) + } else { + Forallable::Callable(Callable::list(params, final_ret)) + }; + *x = forallable.forall(Arc::new(TParams::new(new_tparams))); + } Type::Ellipsis if ts.is_empty() => { *x = new_callable(Callable::ellipsis(ret.clone())); } @@ -1270,6 +1373,11 @@ pub struct Subset<'a, Ans: LookupAnswer> { } impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { + /// Get the Quantified backing a Var, if it is backed by Variable::Quantified + pub fn get_quantified_for_var(&self, var: Var) -> Option { + self.solver.get_quantified_for_var(var) + } + pub fn is_equal(&mut self, got: &Type, want: &Type) -> Result<(), SubsetError> { self.is_subset_eq(got, want)?; self.is_subset_eq(want, got) diff --git a/pyrefly/lib/solver/subset.rs b/pyrefly/lib/solver/subset.rs index d78ab48820..24762950a3 100644 --- a/pyrefly/lib/solver/subset.rs +++ b/pyrefly/lib/solver/subset.rs @@ -8,6 +8,7 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::iter; +use std::sync::Arc; use itertools::EitherOrBoth; use itertools::Itertools; @@ -36,10 +37,13 @@ use crate::types::class::ClassType; use crate::types::quantified::QuantifiedKind; use crate::types::simplify::unions; use crate::types::tuple::Tuple; +use crate::types::type_var::PreInferenceVariance; use crate::types::type_var::Restriction; use crate::types::type_var::Variance; use crate::types::types::Forall; use crate::types::types::Forallable; +use crate::types::types::TParam; +use crate::types::types::TParams; use crate::types::types::Type; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -567,6 +571,63 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { } } + /// Create a ParamSpecValue type, wrapping it in Forall if it contains quantified variables + /// (either directly as Type::Quantified or as Type::Var backed by Variable::Quantified) + /// + /// IMPORTANT: We do NOT substitute Vars with Quantifieds here. The Vars must be preserved + /// so that after unification (e.g., T unified with R), we can correctly resolve all + /// references to the same representative Quantified in simplify_mut. + fn create_paramspec_value(&self, params: ParamList) -> Type { + // Collect all quantified variables used in the params + // We look for both Type::Quantified and Type::Var that are backed by Variable::Quantified + let mut quantifieds = Vec::new(); + + for param in params.items() { + param.as_type().universe(&mut |t| { + match t { + Type::Quantified(q) => { + if !quantifieds.contains(&(**q)) { + quantifieds.push((**q).clone()); + } + } + Type::Var(v) => { + // Check if this Var is backed by a Quantified + if let Some(q) = self.get_quantified_for_var(*v) { + if !quantifieds.contains(&q) { + quantifieds.push(q.clone()); + } + } + } + _ => {} + } + }); + } + + if quantifieds.is_empty() { + // No quantified variables, return plain ParamSpecValue + Type::ParamSpecValue(params) + } else { + // Create TParams from the quantified variables + // Note: We keep the original Vars in the params - substitution will happen + // later in simplify_mut after all unifications are complete. + // + // TODO: We use PUndefined (Bivariant) because Quantified doesn't store variance + // and Variable::Quantified only stores Quantified, not TParam. The proper fix + // would be to track variance through fresh_quantified -> Variable::Quantified. + // PUndefined is safe (won't cause false errors) but may miss some real errors. + let tparams: Vec = quantifieds + .into_iter() + .map(|q| TParam { + quantified: q, + variance: PreInferenceVariance::PUndefined, + }) + .collect(); + + // Wrap in Forall, keeping original params (with Vars, not substituted Quantifieds) + Forallable::ParamSpecValue(params).forall(Arc::new(TParams::new(tparams))) + } + } + fn is_paramlist_subset_of_paramspec( &mut self, got: &ParamList, @@ -579,10 +640,8 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { let args = ParamList::new_types(want_ts.to_owned()); let (pre, post) = got.items().split_at(args.len()); self.is_subset_param_list(pre, args.items())?; - self.is_subset_eq( - &Type::ParamSpecValue(ParamList::new(post.to_vec())), - want_pspec, - ) + let paramspec_value = self.create_paramspec_value(ParamList::new(post.to_vec())); + self.is_subset_eq(¶mspec_value, want_pspec) } fn is_paramspec_subset_of_paramlist( @@ -597,10 +656,8 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { let args = ParamList::new_types(got_ts.to_owned()); let (pre, post) = want.items().split_at(args.len()); self.is_subset_param_list(args.items(), pre)?; - self.is_subset_eq( - got_pspec, - &Type::ParamSpecValue(ParamList::new(post.to_vec())), - ) + let paramspec_value = self.create_paramspec_value(ParamList::new(post.to_vec())); + self.is_subset_eq(got_pspec, ¶mspec_value) } fn is_paramspec_subset_of_paramspec( @@ -1139,10 +1196,36 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { Type::BoundMethod(_) | Type::Callable(_) | Type::Function(_), ) => self.is_subset_eq(&self.type_order.constructor_to_callable(got), want), (Type::ClassDef(got), Type::BoundMethod(_) | Type::Callable(_) | Type::Function(_)) => { - self.is_subset_eq( - &Type::type_form(self.type_order.promote_silently(got)), - want, - ) + let tparams = self.type_order.get_class_tparams(got); + if tparams.is_empty() { + // No type parameters, use existing logic + self.is_subset_eq( + &Type::type_form(self.type_order.promote_silently(got)), + want, + ) + } else { + // Generic class: instantiate fresh vars for the type parameters + // and check if the constructor callable matches, similar to Forall handling + let (vs, class_instance_type) = + self.type_order.instantiate_fresh_class_with_handle(got); + let result = match class_instance_type { + Type::ClassType(cls) => { + let callable = self.type_order.constructor_to_callable(&cls); + self.is_subset_eq(&callable, want) + } + _ => { + // Fall back if instantiation doesn't produce ClassType + self.is_subset_eq( + &Type::type_form(self.type_order.promote_silently(got)), + want, + ) + } + }; + result.and( + self.finish_quantified(vs) + .map_err(SubsetError::TypeVarSpecialization), + ) + } } (Type::ClassDef(got), Type::ClassDef(want)) => ok_or( self.type_order.has_superclass(got, want), diff --git a/pyrefly/lib/solver/type_order.rs b/pyrefly/lib/solver/type_order.rs index fc93283ecb..eb4d9fe479 100644 --- a/pyrefly/lib/solver/type_order.rs +++ b/pyrefly/lib/solver/type_order.rs @@ -30,6 +30,7 @@ use crate::types::typed_dict::TypedDict; use crate::types::typed_dict::TypedDictField; use crate::types::types::Forall; use crate::types::types::Forallable; +use crate::types::types::TParams; use crate::types::types::Type; /// `TypeOrder` provides a minimal API allowing `Subset` to request additional @@ -168,4 +169,20 @@ impl<'a, Ans: LookupAnswer> TypeOrder<'a, Ans> { ) -> Option { self.0.bind_boundmethod(m, is_subset) } + + pub fn get_class_tparams(self, cls: &Class) -> Arc { + self.0.get_class_tparams(cls) + } + + pub fn as_class_type_with_tparams(self, cls: &Class) -> ClassType { + self.0.as_class_type_unchecked(cls) + } + + pub fn instantiate_fresh_class(self, cls: &Class) -> Type { + self.0.instantiate_fresh_class(cls) + } + + pub fn instantiate_fresh_class_with_handle(self, cls: &Class) -> (QuantifiedHandle, Type) { + self.0.instantiate_fresh_class_with_handle(cls) + } } diff --git a/pyrefly/lib/test/paramspec.rs b/pyrefly/lib/test/paramspec.rs index 78129d7be8..a911ccffe6 100644 --- a/pyrefly/lib/test/paramspec.rs +++ b/pyrefly/lib/test/paramspec.rs @@ -54,12 +54,11 @@ def identity[**P, R](x: Callable[P, R]) -> Callable[P, R]: def foo[T](x: T, y: T) -> T: return x foo2 = identity(foo) -reveal_type(foo2) # E: revealed type: (x: @_, y: @_) -> @_ +reveal_type(foo2) # E: revealed type: [R](x: R, y: R) -> R "#, ); testcase!( - bug = "Generic class constructors don't work with ParamSpec", test_param_spec_generic_constructor, r#" from typing import Callable, reveal_type @@ -70,8 +69,29 @@ class C[T]: def __init__(self, x: T) -> None: self.x = x c2 = identity(C) -reveal_type(c2) # E: revealed type: (x: Unknown) -> C[Unknown] +reveal_type(c2) # E: revealed type: [T](x: T) -> C[T] x: C[int] = c2(1) +y = c2(1) +reveal_type(y) # E: revealed type: C[int] +"#, +); + +testcase!( + test_param_spec_generic_function_inference, + r#" +from typing import Callable, assert_type + +def f[T](x: T) -> T: + return x + +def g[**P, R](f: Callable[P, R]) -> Callable[P, R]: + def inner(*args: P.args, **kwargs: P.kwargs): + return f(*args, **kwargs) + return inner + +h = g(f) +assert_type(h(1), int) +assert_type(h(""), str) "#, );