From fafa3440ab0d18e744bf6917e3d7a19926ed2ca2 Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Fri, 21 Nov 2025 18:18:02 -0500 Subject: [PATCH 1/6] Fix Generic ParamSpec with class constructors (#43) This commit fixes issue #43 where generic class constructors don't preserve type parameters when used with ParamSpec functions. Changes: 1. Modified CallTarget::Class to use TargetWithTParams in call.rs, allowing generic type parameters to be preserved through the call target system 2. Updated Type::Forall handling in call.rs to also set tparams for Class targets 3. Improved subset checking in subset.rs to instantiate fresh type variables for generic class constructors, allowing the variables to be properly unified during ParamSpec inference instead of being finalized to Unknown 4. Updated call_graph.rs to handle the new CallTarget::Class structure 5. Added helper methods to type_order.rs for accessing class tparams and instantiating fresh class instances The key fix is in subset.rs where we now call instantiate_fresh_class() to create fresh solver variables for the generic class's type parameters. These variables remain as solver variables (represented as @_) that can be unified during callable inference, instead of being immediately finalized to Unknown. Test result: Generic class constructors now preserve type parameters as @_ (type variables) instead of Unknown, allowing ParamSpec inference to work correctly. --- pyrefly/lib/alt/call.rs | 16 ++++++++----- pyrefly/lib/alt/class/targs.rs | 10 ++++++++ pyrefly/lib/report/pysa/call_graph.rs | 5 +++- pyrefly/lib/solver/subset.rs | 34 +++++++++++++++++++++++---- pyrefly/lib/solver/type_order.rs | 17 ++++++++++++++ pyrefly/lib/test/paramspec.rs | 3 +-- 6 files changed, 72 insertions(+), 13 deletions(-) diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 157a208c22..cc3bbb00a5 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -79,7 +79,7 @@ pub enum CallTarget { /// Method of a class. The `Type` is the self/cls argument. BoundMethod(Type, TargetWithTParams), /// A class object. - Class(ClassType, ConstructorKind), + Class(TargetWithTParams, ConstructorKind), /// A TypedDict. TypedDict(TypedDict), /// An overloaded function. @@ -213,7 +213,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Type::ClassDef(cls) => match self.instantiate(&cls) { // `instantiate` can only return `ClassType` or `TypedDict` Type::ClassType(cls) => CallTargetLookup::Ok(Box::new(CallTarget::Class( - cls, + TargetWithTParams(None, cls), ConstructorKind::BareClassName, ))), Type::TypedDict(typed_dict) => { @@ -223,12 +223,15 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { }, Type::Type(box Type::ClassType(cls)) | Type::Type(box Type::SelfType(cls)) => { CallTargetLookup::Ok(Box::new(CallTarget::Class( - cls, + TargetWithTParams(None, cls), ConstructorKind::TypeOfClass, ))) } Type::Type(box Type::Tuple(tuple)) => CallTargetLookup::Ok(Box::new( - CallTarget::Class(self.erase_tuple_type(tuple), ConstructorKind::TypeOfClass), + CallTarget::Class( + TargetWithTParams(None, self.erase_tuple_type(tuple)), + ConstructorKind::TypeOfClass, + ), )), Type::Type(box Type::Quantified(quantified)) => { CallTargetLookup::Ok(Box::new(CallTarget::Callable(TargetWithTParams( @@ -249,7 +252,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { match &mut target { CallTargetLookup::Ok( box (CallTarget::Callable(TargetWithTParams(x, _)) - | CallTarget::Function(TargetWithTParams(x, _))), + | CallTarget::Function(TargetWithTParams(x, _)) + | CallTarget::Class(TargetWithTParams(x, _), _)), ) => { *x = Some(forall.tparams); } @@ -756,7 +760,7 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { } }; let res = match call_target { - CallTarget::Class(cls, constructor_kind) => { + CallTarget::Class(TargetWithTParams(_tparams, cls), constructor_kind) => { if cls.has_qname("typing", "Any") { return self.error( errors, diff --git a/pyrefly/lib/alt/class/targs.rs b/pyrefly/lib/alt/class/targs.rs index e6a5a1b442..7a6102e446 100644 --- a/pyrefly/lib/alt/class/targs.rs +++ b/pyrefly/lib/alt/class/targs.rs @@ -244,6 +244,16 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { .1 } + /// Instantiates a class with fresh variables, returning both the handle and the type. + /// Used when the caller needs to finalize the quantified variables. + pub fn instantiate_fresh_class_with_handle(&self, cls: &Class) -> (QuantifiedHandle, Type) { + self.solver().fresh_quantified( + &self.get_class_tparams(cls), + self.instantiate(cls), + self.uniques, + ) + } + pub fn instantiate_fresh_tuple(&self) -> Type { let quantified = Quantified::type_var_tuple(Name::new_static("Ts"), self.uniques, None); let tparams = TParams::new(vec![TParam { diff --git a/pyrefly/lib/report/pysa/call_graph.rs b/pyrefly/lib/report/pysa/call_graph.rs index 1a75eb9ab6..275baeea5e 100644 --- a/pyrefly/lib/report/pysa/call_graph.rs +++ b/pyrefly/lib/report/pysa/call_graph.rs @@ -1690,7 +1690,10 @@ impl<'a> CallGraphVisitor<'a> { ) .into_call_callees() } - Some(CallTargetLookup::Ok(box crate::alt::call::CallTarget::Class(class_type, _))) => { + Some(CallTargetLookup::Ok(box crate::alt::call::CallTarget::Class( + crate::alt::call::TargetWithTParams(_tparams, class_type), + _, + ))) => { // Constructing a class instance. let (init_method, new_method) = self .module_context diff --git a/pyrefly/lib/solver/subset.rs b/pyrefly/lib/solver/subset.rs index d78ab48820..fb715d57f6 100644 --- a/pyrefly/lib/solver/subset.rs +++ b/pyrefly/lib/solver/subset.rs @@ -1139,10 +1139,36 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { Type::BoundMethod(_) | Type::Callable(_) | Type::Function(_), ) => self.is_subset_eq(&self.type_order.constructor_to_callable(got), want), (Type::ClassDef(got), Type::BoundMethod(_) | Type::Callable(_) | Type::Function(_)) => { - self.is_subset_eq( - &Type::type_form(self.type_order.promote_silently(got)), - want, - ) + let tparams = self.type_order.get_class_tparams(got); + if tparams.is_empty() { + // No type parameters, use existing logic + self.is_subset_eq( + &Type::type_form(self.type_order.promote_silently(got)), + want, + ) + } else { + // Generic class: instantiate fresh vars for the type parameters + // and check if the constructor callable matches, similar to Forall handling + let (vs, class_instance_type) = + self.type_order.instantiate_fresh_class_with_handle(got); + let result = match class_instance_type { + Type::ClassType(cls) => { + let callable = self.type_order.constructor_to_callable(&cls); + self.is_subset_eq(&callable, want) + } + _ => { + // Fall back if instantiation doesn't produce ClassType + self.is_subset_eq( + &Type::type_form(self.type_order.promote_silently(got)), + want, + ) + } + }; + result.and( + self.finish_quantified(vs) + .map_err(SubsetError::TypeVarSpecialization), + ) + } } (Type::ClassDef(got), Type::ClassDef(want)) => ok_or( self.type_order.has_superclass(got, want), diff --git a/pyrefly/lib/solver/type_order.rs b/pyrefly/lib/solver/type_order.rs index fc93283ecb..eb4d9fe479 100644 --- a/pyrefly/lib/solver/type_order.rs +++ b/pyrefly/lib/solver/type_order.rs @@ -30,6 +30,7 @@ use crate::types::typed_dict::TypedDict; use crate::types::typed_dict::TypedDictField; use crate::types::types::Forall; use crate::types::types::Forallable; +use crate::types::types::TParams; use crate::types::types::Type; /// `TypeOrder` provides a minimal API allowing `Subset` to request additional @@ -168,4 +169,20 @@ impl<'a, Ans: LookupAnswer> TypeOrder<'a, Ans> { ) -> Option { self.0.bind_boundmethod(m, is_subset) } + + pub fn get_class_tparams(self, cls: &Class) -> Arc { + self.0.get_class_tparams(cls) + } + + pub fn as_class_type_with_tparams(self, cls: &Class) -> ClassType { + self.0.as_class_type_unchecked(cls) + } + + pub fn instantiate_fresh_class(self, cls: &Class) -> Type { + self.0.instantiate_fresh_class(cls) + } + + pub fn instantiate_fresh_class_with_handle(self, cls: &Class) -> (QuantifiedHandle, Type) { + self.0.instantiate_fresh_class_with_handle(cls) + } } diff --git a/pyrefly/lib/test/paramspec.rs b/pyrefly/lib/test/paramspec.rs index 78129d7be8..60ba3151a7 100644 --- a/pyrefly/lib/test/paramspec.rs +++ b/pyrefly/lib/test/paramspec.rs @@ -59,7 +59,6 @@ reveal_type(foo2) # E: revealed type: (x: @_, y: @_) -> @_ ); testcase!( - bug = "Generic class constructors don't work with ParamSpec", test_param_spec_generic_constructor, r#" from typing import Callable, reveal_type @@ -70,7 +69,7 @@ class C[T]: def __init__(self, x: T) -> None: self.x = x c2 = identity(C) -reveal_type(c2) # E: revealed type: (x: Unknown) -> C[Unknown] +reveal_type(c2) # E: revealed type: (x: @_) -> C[@_] x: C[int] = c2(1) "#, ); From dedeeb2c3bdbf40b920630208e6ead002d7706df Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Fri, 21 Nov 2025 19:02:47 -0500 Subject: [PATCH 2/6] cargo fmt ran --- pyrefly/lib/alt/call.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index cc3bbb00a5..4ac9a2e43b 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -227,12 +227,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ConstructorKind::TypeOfClass, ))) } - Type::Type(box Type::Tuple(tuple)) => CallTargetLookup::Ok(Box::new( - CallTarget::Class( + Type::Type(box Type::Tuple(tuple)) => { + CallTargetLookup::Ok(Box::new(CallTarget::Class( TargetWithTParams(None, self.erase_tuple_type(tuple)), ConstructorKind::TypeOfClass, - ), - )), + ))) + } Type::Type(box Type::Quantified(quantified)) => { CallTargetLookup::Ok(Box::new(CallTarget::Callable(TargetWithTParams( None, From dc3565476bebbf7ea9f012b3c9e5160ca7df2b83 Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Fri, 21 Nov 2025 19:32:13 -0500 Subject: [PATCH 3/6] Revert "cargo fmt ran" This reverts commit dedeeb2c3bdbf40b920630208e6ead002d7706df. --- pyrefly/lib/alt/call.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index 4ac9a2e43b..cc3bbb00a5 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -227,12 +227,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ConstructorKind::TypeOfClass, ))) } - Type::Type(box Type::Tuple(tuple)) => { - CallTargetLookup::Ok(Box::new(CallTarget::Class( + Type::Type(box Type::Tuple(tuple)) => CallTargetLookup::Ok(Box::new( + CallTarget::Class( TargetWithTParams(None, self.erase_tuple_type(tuple)), ConstructorKind::TypeOfClass, - ))) - } + ), + )), Type::Type(box Type::Quantified(quantified)) => { CallTargetLookup::Ok(Box::new(CallTarget::Callable(TargetWithTParams( None, From a13c5a0ecc41ee30ce82486941d020c07a7b9e43 Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Fri, 21 Nov 2025 19:34:35 -0500 Subject: [PATCH 4/6] Apply cargo fmt to format code Run cargo fmt to ensure consistent code formatting across the project. --- pyrefly/lib/alt/call.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pyrefly/lib/alt/call.rs b/pyrefly/lib/alt/call.rs index cc3bbb00a5..4ac9a2e43b 100644 --- a/pyrefly/lib/alt/call.rs +++ b/pyrefly/lib/alt/call.rs @@ -227,12 +227,12 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { ConstructorKind::TypeOfClass, ))) } - Type::Type(box Type::Tuple(tuple)) => CallTargetLookup::Ok(Box::new( - CallTarget::Class( + Type::Type(box Type::Tuple(tuple)) => { + CallTargetLookup::Ok(Box::new(CallTarget::Class( TargetWithTParams(None, self.erase_tuple_type(tuple)), ConstructorKind::TypeOfClass, - ), - )), + ))) + } Type::Type(box Type::Quantified(quantified)) => { CallTargetLookup::Ok(Box::new(CallTarget::Callable(TargetWithTParams( None, From d2118b3c75c4d66ecc7384ec794a140e521fffb8 Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Wed, 26 Nov 2025 01:09:41 -0600 Subject: [PATCH 5/6] Fix generic ParamSpec functions to preserve generic types (#43) When a generic function or constructor is passed through a ParamSpec-using function, preserve its generic nature so each call creates fresh type variables rather than locking types on first call. - Add ParamSpecValue and Concatenate variants to Forallable enum - Detect generic params in create_paramspec_value() and wrap in Forall - Lift Forall wrapper to outer Callable in simplify_mut() - Handle Forall in callable.rs to instantiate fresh vars - Use transform_mut instead of visit_mut to recurse into ClassType's TArgs --- crates/pyrefly_types/src/display.rs | 18 +++++ crates/pyrefly_types/src/types.rs | 44 ++++++++++- pyrefly/lib/alt/callable.rs | 33 +++++++++ pyrefly/lib/alt/function.rs | 2 + pyrefly/lib/query.rs | 2 + pyrefly/lib/solver/solver.rs | 111 ++++++++++++++++++++++++++++ pyrefly/lib/solver/subset.rs | 76 +++++++++++++++++-- pyrefly/lib/test/paramspec.rs | 25 ++++++- 8 files changed, 300 insertions(+), 11 deletions(-) diff --git a/crates/pyrefly_types/src/display.rs b/crates/pyrefly_types/src/display.rs index b2a574fee1..87c69bebbb 100644 --- a/crates/pyrefly_types/src/display.rs +++ b/crates/pyrefly_types/src/display.rs @@ -561,6 +561,24 @@ impl<'a> TypeDisplayContext<'a> { output.write_str(ta.name.as_str()) } } + Type::Forall(box Forall { + tparams, + body: Forallable::ParamSpecValue(_), + }) => { + output.write_str("[")?; + output.write_str(&format!("{}", commas_iter(|| tparams.iter())))?; + output.write_str("]")?; + output.write_str("") + } + Type::Forall(box Forall { + tparams, + body: Forallable::Concatenate(_, _), + }) => { + output.write_str("[")?; + output.write_str(&format!("{}", commas_iter(|| tparams.iter())))?; + output.write_str("]")?; + output.write_str("") + } Type::Type(ty) => { output.write_str("type[")?; self.fmt_helper_generic(ty, false, output)?; diff --git a/crates/pyrefly_types/src/types.rs b/crates/pyrefly_types/src/types.rs index 07aaec080d..e371036682 100644 --- a/crates/pyrefly_types/src/types.rs +++ b/crates/pyrefly_types/src/types.rs @@ -548,11 +548,43 @@ impl Forall { /// These are things that can have Forall around them, so often you see `Forall` #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[derive(Visit, VisitMut, TypeEq)] +#[derive(TypeEq)] pub enum Forallable { TypeAlias(TypeAlias), Function(Function), Callable(Callable), + ParamSpecValue(ParamList), + Concatenate(Box<[Type]>, Box), +} + +impl Visit for Forallable { + fn recurse<'a>(&'a self, f: &mut dyn FnMut(&'a Type)) { + match self { + Self::TypeAlias(ta) => ta.visit(f), + Self::Function(func) => func.visit(f), + Self::Callable(callable) => callable.visit(f), + Self::ParamSpecValue(params) => params.visit(f), + Self::Concatenate(args, inner) => { + args.visit(f); + inner.visit(f); + } + } + } +} + +impl VisitMut for Forallable { + fn recurse_mut(&mut self, f: &mut dyn FnMut(&mut Type)) { + match self { + Self::TypeAlias(ta) => ta.visit_mut(f), + Self::Function(func) => func.visit_mut(f), + Self::Callable(callable) => callable.visit_mut(f), + Self::ParamSpecValue(params) => params.visit_mut(f), + Self::Concatenate(args, inner) => { + args.visit_mut(f); + inner.visit_mut(f); + } + } + } } impl Forallable { @@ -572,6 +604,8 @@ impl Forallable { Self::Function(func) => func.metadata.kind.function_name(), Self::Callable(_) => Cow::Owned(Name::new_static("")), Self::TypeAlias(ta) => Cow::Borrowed(&ta.name), + Self::ParamSpecValue(_) => Cow::Owned(Name::new_static("")), + Self::Concatenate(_, _) => Cow::Owned(Name::new_static("")), } } @@ -580,6 +614,10 @@ impl Forallable { Self::Function(func) => Type::Function(Box::new(func)), Self::Callable(callable) => Type::Callable(Box::new(callable)), Self::TypeAlias(ta) => Type::TypeAlias(Box::new(ta)), + Self::ParamSpecValue(params) => Type::ParamSpecValue(params), + Self::Concatenate(args, paramspec) => { + Type::Concatenate(args, Box::new(paramspec.as_type())) + } } } @@ -588,6 +626,8 @@ impl Forallable { Self::Function(func) => func.signature.is_typeguard(), Self::Callable(callable) => callable.is_typeguard(), Self::TypeAlias(_) => false, + Self::ParamSpecValue(_) => false, + Self::Concatenate(_, _) => false, } } @@ -596,6 +636,8 @@ impl Forallable { Self::Function(func) => func.signature.is_typeis(), Self::Callable(callable) => callable.is_typeis(), Self::TypeAlias(_) => false, + Self::ParamSpecValue(_) => false, + Self::Concatenate(_, _) => false, } } } diff --git a/pyrefly/lib/alt/callable.rs b/pyrefly/lib/alt/callable.rs index 92714da64e..9e996521f2 100644 --- a/pyrefly/lib/alt/callable.rs +++ b/pyrefly/lib/alt/callable.rs @@ -46,6 +46,7 @@ use crate::types::callable::Params; use crate::types::callable::Required; use crate::types::quantified::Quantified; use crate::types::tuple::Tuple; +use crate::types::types::Forallable; use crate::types::types::Type; use crate::types::types::Var; @@ -1043,6 +1044,38 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { Params::ParamSpec(concatenate, p) => { let p = self.solver().expand_vars(p); match p { + Type::Forall(box forall) + if matches!(forall.body, Forallable::ParamSpecValue(_)) => + { + // Instantiate fresh type variables for the Forall-wrapped ParamSpecValue + if let Forallable::ParamSpecValue(params_before) = forall.body { + let (_qs_params, instantiated_type) = self.solver().fresh_quantified( + &forall.tparams, + Type::ParamSpecValue(params_before), + self.uniques, + ); + + // Extract the instantiated params + if let Type::ParamSpecValue(params) = instantiated_type { + self.callable_infer_params( + callable_name, + ¶ms.prepend_types(&concatenate), + None, + self_arg, + args, + keywords, + range, + arg_errors, + call_errors, + context, + ) + } else { + unreachable!() + } + } else { + unreachable!() + } + } Type::ParamSpecValue(params) => self.callable_infer_params( callable_name, ¶ms.prepend_types(&concatenate), diff --git a/pyrefly/lib/alt/function.rs b/pyrefly/lib/alt/function.rs index 0e53dd2f85..f1c94ee073 100644 --- a/pyrefly/lib/alt/function.rs +++ b/pyrefly/lib/alt/function.rs @@ -1366,6 +1366,8 @@ impl<'a, Ans: LookupAnswer> AnswersSolver<'a, Ans> { })) }), Forallable::TypeAlias(_) => None, + Forallable::ParamSpecValue(_) => None, + Forallable::Concatenate(_, _) => None, }, Type::Callable(callable) => callable .split_first_param() diff --git a/pyrefly/lib/query.rs b/pyrefly/lib/query.rs index 89c075e595..9d394b1df8 100644 --- a/pyrefly/lib/query.rs +++ b/pyrefly/lib/query.rs @@ -880,6 +880,8 @@ impl<'a> CalleesWithLocation<'a> { Forallable::TypeAlias(t) => { self.callee_from_type(&t.as_type(), call_target, callee_range, call_arguments) } + Forallable::ParamSpecValue(_) => vec![], + Forallable::Concatenate(_, _) => vec![], }, Type::SelfType(c) | Type::ClassType(c) => { self.callee_from_mro(c.class_object(), "__call__", |_solver, c| { diff --git a/pyrefly/lib/solver/solver.rs b/pyrefly/lib/solver/solver.rs index 4f849d55fa..10490cb4ac 100644 --- a/pyrefly/lib/solver/solver.rs +++ b/pyrefly/lib/solver/solver.rs @@ -12,9 +12,11 @@ use std::cell::RefMut; use std::fmt; use std::fmt::Display; use std::mem; +use std::sync::Arc; use pyrefly_types::quantified::Quantified; use pyrefly_types::simplify::intersect; +use pyrefly_types::type_var::PreInferenceVariance; use pyrefly_types::types::TArgs; use pyrefly_util::gas::Gas; use pyrefly_util::lock::Mutex; @@ -41,11 +43,15 @@ use crate::error::context::TypeCheckKind; use crate::solver::type_order::TypeOrder; use crate::types::callable::Callable; use crate::types::callable::Function; +use crate::types::callable::Param; use crate::types::callable::Params; use crate::types::module::ModuleType; use crate::types::simplify::simplify_tuples; use crate::types::simplify::unions; use crate::types::simplify::unions_with_literals; +use crate::types::types::Forall; +use crate::types::types::Forallable; +use crate::types::types::TParam; use crate::types::types::TParams; use crate::types::types::Type; use crate::types::types::Var; @@ -351,6 +357,16 @@ impl Solver { } } + /// Get the Quantified backing a Var, if it is backed by Variable::Quantified or Variable::Unsolved + pub fn get_quantified_for_var(&self, var: Var) -> Option { + let lock = self.variables.lock(); + let variable = lock.get(var); + match &*variable { + Variable::Quantified(q) | Variable::Unsolved(q) => Some((**q).clone()), + _ => None, + } + } + /// Finish the type returned from a function call. This entails expanding solved variables and /// erasing unsolved variables without defaults from unions. pub fn finish_function_return(&self, mut t: Type) -> Type { @@ -489,6 +505,8 @@ impl Solver { ret, }) = callable { + // Clone kind for use in Forall case since it's moved into the closure + let kind_for_forall = kind.as_ref().map(|k| (*k).clone()); let new_callable = |c| { if let Some(k) = kind { Type::Function(Box::new(Function { @@ -505,6 +523,94 @@ impl Solver { let new_callable = new_callable(Callable::list(params, ret.clone())); *x = new_callable; } + Type::Forall(box Forall { + tparams: _, + body: Forallable::ParamSpecValue(paramlist), + }) => { + // Lift the Forall to wrap the entire Callable + // + // The params may contain Vars that need to be resolved to their + // representative Quantifieds (after all unifications are complete). + // The return type may also contain Vars that are unified with param Vars. + // + // Strategy: + // 1. Collect all Vars in params that are backed by Quantifieds + // 2. Find the representative Quantified for each (after unification) + // 3. Substitute Vars with representative Quantifieds in both params and ret + // 4. Build new tparams from the representative Quantifieds + + let mut params = paramlist.prepend_types(ts).into_owned(); + + // Collect Vars from params and their representative Quantifieds + let mut var_to_representative: SmallMap = SmallMap::new(); + let mut representative_quantifieds: Vec = Vec::new(); + + for param in params.items() { + let ty = match param { + Param::PosOnly(_, ty, _) => ty, + Param::Pos(_, ty, _) => ty, + Param::VarArg(_, ty) => ty, + Param::KwOnly(_, ty, _) => ty, + Param::Kwargs(_, ty) => ty, + }; + ty.universe(&mut |t| { + if let Type::Var(v) = t { + if let Some(q) = self.get_quantified_for_var(*v) { + var_to_representative.insert(*v, q.clone()); + if !representative_quantifieds.contains(&q) { + representative_quantifieds.push(q); + } + } + } + }); + } + + // Check the return type for Vars that are unified with param Vars. + // This handles both simple cases (return type is T) and complex cases + // (return type is C[T] where T is inside a complex type). + // + // We use transform_mut which recurses into the type structure first, + // then applies the transformation. This is necessary because visit_mut + // returns early if the downcast to Type succeeds (which it does for ClassType). + let mut final_ret = ret.clone(); + final_ret.transform_mut(&mut |t| { + if let Type::Var(v) = t { + if let Some(ret_q) = self.get_quantified_for_var(*v) { + if representative_quantifieds.contains(&ret_q) { + *t = Type::Quantified(Box::new(ret_q)); + } + } + } + }); + + // Substitute Vars in params with their representative Quantifieds + params.visit_mut(&mut |t| { + if let Type::Var(v) = t { + if let Some(q) = var_to_representative.get(v) { + *t = Type::Quantified(Box::new(q.clone())); + } + } + }); + + // Build new tparams from representative Quantifieds + let new_tparams: Vec = representative_quantifieds + .into_iter() + .map(|q| TParam { + quantified: q, + variance: PreInferenceVariance::PInvariant, + }) + .collect(); + + let forallable = if let Some(k) = kind_for_forall { + Forallable::Function(Function { + signature: Callable::list(params, final_ret), + metadata: k, + }) + } else { + Forallable::Callable(Callable::list(params, final_ret)) + }; + *x = forallable.forall(Arc::new(TParams::new(new_tparams))); + } Type::Ellipsis if ts.is_empty() => { *x = new_callable(Callable::ellipsis(ret.clone())); } @@ -1270,6 +1376,11 @@ pub struct Subset<'a, Ans: LookupAnswer> { } impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { + /// Get the Quantified backing a Var, if it is backed by Variable::Quantified + pub fn get_quantified_for_var(&self, var: Var) -> Option { + self.solver.get_quantified_for_var(var) + } + pub fn is_equal(&mut self, got: &Type, want: &Type) -> Result<(), SubsetError> { self.is_subset_eq(got, want)?; self.is_subset_eq(want, got) diff --git a/pyrefly/lib/solver/subset.rs b/pyrefly/lib/solver/subset.rs index fb715d57f6..4ef433be83 100644 --- a/pyrefly/lib/solver/subset.rs +++ b/pyrefly/lib/solver/subset.rs @@ -8,6 +8,7 @@ use std::cmp::Ordering; use std::collections::HashMap; use std::iter; +use std::sync::Arc; use itertools::EitherOrBoth; use itertools::Itertools; @@ -36,10 +37,13 @@ use crate::types::class::ClassType; use crate::types::quantified::QuantifiedKind; use crate::types::simplify::unions; use crate::types::tuple::Tuple; +use crate::types::type_var::PreInferenceVariance; use crate::types::type_var::Restriction; use crate::types::type_var::Variance; use crate::types::types::Forall; use crate::types::types::Forallable; +use crate::types::types::TParam; +use crate::types::types::TParams; use crate::types::types::Type; #[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)] @@ -567,6 +571,66 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { } } + /// Create a ParamSpecValue type, wrapping it in Forall if it contains quantified variables + /// (either directly as Type::Quantified or as Type::Var backed by Variable::Quantified) + /// + /// IMPORTANT: We do NOT substitute Vars with Quantifieds here. The Vars must be preserved + /// so that after unification (e.g., T unified with R), we can correctly resolve all + /// references to the same representative Quantified in simplify_mut. + fn create_paramspec_value(&self, params: ParamList) -> Type { + // Collect all quantified variables used in the params + // We look for both Type::Quantified and Type::Var that are backed by Variable::Quantified + let mut quantifieds = Vec::new(); + + for param in params.items() { + let ty = match param { + Param::PosOnly(_, ty, _) => ty, + Param::Pos(_, ty, _) => ty, + Param::VarArg(_, ty) => ty, + Param::KwOnly(_, ty, _) => ty, + Param::Kwargs(_, ty) => ty, + }; + + ty.universe(&mut |t| { + match t { + Type::Quantified(q) => { + if !quantifieds.contains(&(**q)) { + quantifieds.push((**q).clone()); + } + } + Type::Var(v) => { + // Check if this Var is backed by a Quantified + if let Some(q) = self.get_quantified_for_var(*v) { + if !quantifieds.contains(&q) { + quantifieds.push(q.clone()); + } + } + } + _ => {} + } + }); + } + + if quantifieds.is_empty() { + // No quantified variables, return plain ParamSpecValue + Type::ParamSpecValue(params) + } else { + // Create TParams from the quantified variables + // Note: We keep the original Vars in the params - substitution will happen + // later in simplify_mut after all unifications are complete. + let tparams: Vec = quantifieds + .into_iter() + .map(|q| TParam { + quantified: q, + variance: PreInferenceVariance::PInvariant, + }) + .collect(); + + // Wrap in Forall, keeping original params (with Vars, not substituted Quantifieds) + Forallable::ParamSpecValue(params).forall(Arc::new(TParams::new(tparams))) + } + } + fn is_paramlist_subset_of_paramspec( &mut self, got: &ParamList, @@ -579,10 +643,8 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { let args = ParamList::new_types(want_ts.to_owned()); let (pre, post) = got.items().split_at(args.len()); self.is_subset_param_list(pre, args.items())?; - self.is_subset_eq( - &Type::ParamSpecValue(ParamList::new(post.to_vec())), - want_pspec, - ) + let paramspec_value = self.create_paramspec_value(ParamList::new(post.to_vec())); + self.is_subset_eq(¶mspec_value, want_pspec) } fn is_paramspec_subset_of_paramlist( @@ -597,10 +659,8 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { let args = ParamList::new_types(got_ts.to_owned()); let (pre, post) = want.items().split_at(args.len()); self.is_subset_param_list(args.items(), pre)?; - self.is_subset_eq( - got_pspec, - &Type::ParamSpecValue(ParamList::new(post.to_vec())), - ) + let paramspec_value = self.create_paramspec_value(ParamList::new(post.to_vec())); + self.is_subset_eq(got_pspec, ¶mspec_value) } fn is_paramspec_subset_of_paramspec( diff --git a/pyrefly/lib/test/paramspec.rs b/pyrefly/lib/test/paramspec.rs index 60ba3151a7..a911ccffe6 100644 --- a/pyrefly/lib/test/paramspec.rs +++ b/pyrefly/lib/test/paramspec.rs @@ -54,7 +54,7 @@ def identity[**P, R](x: Callable[P, R]) -> Callable[P, R]: def foo[T](x: T, y: T) -> T: return x foo2 = identity(foo) -reveal_type(foo2) # E: revealed type: (x: @_, y: @_) -> @_ +reveal_type(foo2) # E: revealed type: [R](x: R, y: R) -> R "#, ); @@ -69,8 +69,29 @@ class C[T]: def __init__(self, x: T) -> None: self.x = x c2 = identity(C) -reveal_type(c2) # E: revealed type: (x: @_) -> C[@_] +reveal_type(c2) # E: revealed type: [T](x: T) -> C[T] x: C[int] = c2(1) +y = c2(1) +reveal_type(y) # E: revealed type: C[int] +"#, +); + +testcase!( + test_param_spec_generic_function_inference, + r#" +from typing import Callable, assert_type + +def f[T](x: T) -> T: + return x + +def g[**P, R](f: Callable[P, R]) -> Callable[P, R]: + def inner(*args: P.args, **kwargs: P.kwargs): + return f(*args, **kwargs) + return inner + +h = g(f) +assert_type(h(1), int) +assert_type(h(""), str) "#, ); From 5560e4dbed86ad4b88812df4b34766e6892c792f Mon Sep 17 00:00:00 2001 From: Jack <72348727+Jack-GitHub12@users.noreply.github.com> Date: Wed, 10 Dec 2025 02:31:59 -0600 Subject: [PATCH 6/6] Address review feedback on generic ParamSpec PR (#43) - Fix variance: use PUndefined instead of PInvariant since Quantified doesn't store variance info. PUndefined (Bivariant) is safer as it won't cause false errors for explicitly covariant/contravariant type vars. Added TODO noting proper fix would track variance through fresh_quantified. - Reduce code duplication: use existing param.as_type() helper instead of manual match statements in both solver.rs and subset.rs. - Remove unused Param import from solver.rs. --- pyrefly/lib/solver/solver.rs | 17 +++++++---------- pyrefly/lib/solver/subset.rs | 17 +++++++---------- 2 files changed, 14 insertions(+), 20 deletions(-) diff --git a/pyrefly/lib/solver/solver.rs b/pyrefly/lib/solver/solver.rs index 10490cb4ac..b434c94258 100644 --- a/pyrefly/lib/solver/solver.rs +++ b/pyrefly/lib/solver/solver.rs @@ -43,7 +43,6 @@ use crate::error::context::TypeCheckKind; use crate::solver::type_order::TypeOrder; use crate::types::callable::Callable; use crate::types::callable::Function; -use crate::types::callable::Param; use crate::types::callable::Params; use crate::types::module::ModuleType; use crate::types::simplify::simplify_tuples; @@ -546,14 +545,7 @@ impl Solver { let mut representative_quantifieds: Vec = Vec::new(); for param in params.items() { - let ty = match param { - Param::PosOnly(_, ty, _) => ty, - Param::Pos(_, ty, _) => ty, - Param::VarArg(_, ty) => ty, - Param::KwOnly(_, ty, _) => ty, - Param::Kwargs(_, ty) => ty, - }; - ty.universe(&mut |t| { + param.as_type().universe(&mut |t| { if let Type::Var(v) = t { if let Some(q) = self.get_quantified_for_var(*v) { var_to_representative.insert(*v, q.clone()); @@ -593,11 +585,16 @@ impl Solver { }); // Build new tparams from representative Quantifieds + // + // TODO: We use PUndefined (Bivariant) because Quantified doesn't store + // variance and Variable::Quantified only stores Quantified, not TParam. + // The proper fix would be to track variance through fresh_quantified. + // PUndefined is safe (won't cause false errors) but may miss some errors. let new_tparams: Vec = representative_quantifieds .into_iter() .map(|q| TParam { quantified: q, - variance: PreInferenceVariance::PInvariant, + variance: PreInferenceVariance::PUndefined, }) .collect(); diff --git a/pyrefly/lib/solver/subset.rs b/pyrefly/lib/solver/subset.rs index 4ef433be83..24762950a3 100644 --- a/pyrefly/lib/solver/subset.rs +++ b/pyrefly/lib/solver/subset.rs @@ -583,15 +583,7 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { let mut quantifieds = Vec::new(); for param in params.items() { - let ty = match param { - Param::PosOnly(_, ty, _) => ty, - Param::Pos(_, ty, _) => ty, - Param::VarArg(_, ty) => ty, - Param::KwOnly(_, ty, _) => ty, - Param::Kwargs(_, ty) => ty, - }; - - ty.universe(&mut |t| { + param.as_type().universe(&mut |t| { match t { Type::Quantified(q) => { if !quantifieds.contains(&(**q)) { @@ -618,11 +610,16 @@ impl<'a, Ans: LookupAnswer> Subset<'a, Ans> { // Create TParams from the quantified variables // Note: We keep the original Vars in the params - substitution will happen // later in simplify_mut after all unifications are complete. + // + // TODO: We use PUndefined (Bivariant) because Quantified doesn't store variance + // and Variable::Quantified only stores Quantified, not TParam. The proper fix + // would be to track variance through fresh_quantified -> Variable::Quantified. + // PUndefined is safe (won't cause false errors) but may miss some real errors. let tparams: Vec = quantifieds .into_iter() .map(|q| TParam { quantified: q, - variance: PreInferenceVariance::PInvariant, + variance: PreInferenceVariance::PUndefined, }) .collect();