Skip to content

Commit 4c05366

Browse files
committed
Split fn_ctxt/adjust_fulfillment_errors from fn_ctxt/checks
1 parent e1eaa2d commit 4c05366

File tree

2 files changed

+374
-370
lines changed

2 files changed

+374
-370
lines changed

compiler/rustc_hir_typeck/src/fn_ctxt/adjust_fulfillment_errors.rs

+373-1
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,382 @@
11
use crate::FnCtxt;
22
use rustc_hir as hir;
33
use rustc_hir::def::Res;
4-
use rustc_middle::ty::{self, DefIdTree, Ty};
4+
use rustc_hir::def_id::DefId;
5+
use rustc_infer::traits::ObligationCauseCode;
6+
use rustc_middle::ty::{self, DefIdTree, Ty, TypeSuperVisitable, TypeVisitable, TypeVisitor};
7+
use rustc_span::{self, Span};
58
use rustc_trait_selection::traits;
69

10+
use std::ops::ControlFlow;
11+
712
impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
13+
pub fn adjust_fulfillment_error_for_expr_obligation(
14+
&self,
15+
error: &mut traits::FulfillmentError<'tcx>,
16+
) -> bool {
17+
let (traits::ExprItemObligation(def_id, hir_id, idx) | traits::ExprBindingObligation(def_id, _, hir_id, idx))
18+
= *error.obligation.cause.code().peel_derives() else { return false; };
19+
let hir = self.tcx.hir();
20+
let hir::Node::Expr(expr) = hir.get(hir_id) else { return false; };
21+
22+
let Some(unsubstituted_pred) =
23+
self.tcx.predicates_of(def_id).instantiate_identity(self.tcx).predicates.into_iter().nth(idx)
24+
else { return false; };
25+
26+
let generics = self.tcx.generics_of(def_id);
27+
let predicate_substs = match unsubstituted_pred.kind().skip_binder() {
28+
ty::PredicateKind::Clause(ty::Clause::Trait(pred)) => pred.trait_ref.substs,
29+
ty::PredicateKind::Clause(ty::Clause::Projection(pred)) => pred.projection_ty.substs,
30+
_ => ty::List::empty(),
31+
};
32+
33+
let find_param_matching = |matches: &dyn Fn(&ty::ParamTy) -> bool| {
34+
predicate_substs.types().find_map(|ty| {
35+
ty.walk().find_map(|arg| {
36+
if let ty::GenericArgKind::Type(ty) = arg.unpack()
37+
&& let ty::Param(param_ty) = ty.kind()
38+
&& matches(param_ty)
39+
{
40+
Some(arg)
41+
} else {
42+
None
43+
}
44+
})
45+
})
46+
};
47+
48+
// Prefer generics that are local to the fn item, since these are likely
49+
// to be the cause of the unsatisfied predicate.
50+
let mut param_to_point_at = find_param_matching(&|param_ty| {
51+
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) == def_id
52+
});
53+
// Fall back to generic that isn't local to the fn item. This will come
54+
// from a trait or impl, for example.
55+
let mut fallback_param_to_point_at = find_param_matching(&|param_ty| {
56+
self.tcx.parent(generics.type_param(param_ty, self.tcx).def_id) != def_id
57+
&& param_ty.name != rustc_span::symbol::kw::SelfUpper
58+
});
59+
// Finally, the `Self` parameter is possibly the reason that the predicate
60+
// is unsatisfied. This is less likely to be true for methods, because
61+
// method probe means that we already kinda check that the predicates due
62+
// to the `Self` type are true.
63+
let mut self_param_to_point_at =
64+
find_param_matching(&|param_ty| param_ty.name == rustc_span::symbol::kw::SelfUpper);
65+
66+
// Finally, for ambiguity-related errors, we actually want to look
67+
// for a parameter that is the source of the inference type left
68+
// over in this predicate.
69+
if let traits::FulfillmentErrorCode::CodeAmbiguity = error.code {
70+
fallback_param_to_point_at = None;
71+
self_param_to_point_at = None;
72+
param_to_point_at =
73+
self.find_ambiguous_parameter_in(def_id, error.root_obligation.predicate);
74+
}
75+
76+
if self.closure_span_overlaps_error(error, expr.span) {
77+
return false;
78+
}
79+
80+
match &expr.kind {
81+
hir::ExprKind::Path(qpath) => {
82+
if let hir::Node::Expr(hir::Expr {
83+
kind: hir::ExprKind::Call(callee, args),
84+
hir_id: call_hir_id,
85+
span: call_span,
86+
..
87+
}) = hir.get_parent(expr.hir_id)
88+
&& callee.hir_id == expr.hir_id
89+
{
90+
if self.closure_span_overlaps_error(error, *call_span) {
91+
return false;
92+
}
93+
94+
for param in
95+
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
96+
.into_iter()
97+
.flatten()
98+
{
99+
if self.blame_specific_arg_if_possible(
100+
error,
101+
def_id,
102+
param,
103+
*call_hir_id,
104+
callee.span,
105+
None,
106+
args,
107+
)
108+
{
109+
return true;
110+
}
111+
}
112+
}
113+
// Notably, we only point to params that are local to the
114+
// item we're checking, since those are the ones we are able
115+
// to look in the final `hir::PathSegment` for. Everything else
116+
// would require a deeper search into the `qpath` than I think
117+
// is worthwhile.
118+
if let Some(param_to_point_at) = param_to_point_at
119+
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
120+
{
121+
return true;
122+
}
123+
}
124+
hir::ExprKind::MethodCall(segment, receiver, args, ..) => {
125+
for param in [param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
126+
.into_iter()
127+
.flatten()
128+
{
129+
if self.blame_specific_arg_if_possible(
130+
error,
131+
def_id,
132+
param,
133+
hir_id,
134+
segment.ident.span,
135+
Some(receiver),
136+
args,
137+
) {
138+
return true;
139+
}
140+
}
141+
if let Some(param_to_point_at) = param_to_point_at
142+
&& self.point_at_generic_if_possible(error, def_id, param_to_point_at, segment)
143+
{
144+
return true;
145+
}
146+
}
147+
hir::ExprKind::Struct(qpath, fields, ..) => {
148+
if let Res::Def(
149+
hir::def::DefKind::Struct | hir::def::DefKind::Variant,
150+
variant_def_id,
151+
) = self.typeck_results.borrow().qpath_res(qpath, hir_id)
152+
{
153+
for param in
154+
[param_to_point_at, fallback_param_to_point_at, self_param_to_point_at]
155+
{
156+
if let Some(param) = param {
157+
let refined_expr = self.point_at_field_if_possible(
158+
def_id,
159+
param,
160+
variant_def_id,
161+
fields,
162+
);
163+
164+
match refined_expr {
165+
None => {}
166+
Some((refined_expr, _)) => {
167+
error.obligation.cause.span = refined_expr
168+
.span
169+
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
170+
.unwrap_or(refined_expr.span);
171+
return true;
172+
}
173+
}
174+
}
175+
}
176+
}
177+
if let Some(param_to_point_at) = param_to_point_at
178+
&& self.point_at_path_if_possible(error, def_id, param_to_point_at, qpath)
179+
{
180+
return true;
181+
}
182+
}
183+
_ => {}
184+
}
185+
186+
false
187+
}
188+
189+
fn point_at_path_if_possible(
190+
&self,
191+
error: &mut traits::FulfillmentError<'tcx>,
192+
def_id: DefId,
193+
param: ty::GenericArg<'tcx>,
194+
qpath: &hir::QPath<'tcx>,
195+
) -> bool {
196+
match qpath {
197+
hir::QPath::Resolved(_, path) => {
198+
if let Some(segment) = path.segments.last()
199+
&& self.point_at_generic_if_possible(error, def_id, param, segment)
200+
{
201+
return true;
202+
}
203+
}
204+
hir::QPath::TypeRelative(_, segment) => {
205+
if self.point_at_generic_if_possible(error, def_id, param, segment) {
206+
return true;
207+
}
208+
}
209+
_ => {}
210+
}
211+
212+
false
213+
}
214+
215+
fn point_at_generic_if_possible(
216+
&self,
217+
error: &mut traits::FulfillmentError<'tcx>,
218+
def_id: DefId,
219+
param_to_point_at: ty::GenericArg<'tcx>,
220+
segment: &hir::PathSegment<'tcx>,
221+
) -> bool {
222+
let own_substs = self
223+
.tcx
224+
.generics_of(def_id)
225+
.own_substs(ty::InternalSubsts::identity_for_item(self.tcx, def_id));
226+
let Some((index, _)) = own_substs
227+
.iter()
228+
.filter(|arg| matches!(arg.unpack(), ty::GenericArgKind::Type(_)))
229+
.enumerate()
230+
.find(|(_, arg)| **arg == param_to_point_at) else { return false };
231+
let Some(arg) = segment
232+
.args()
233+
.args
234+
.iter()
235+
.filter(|arg| matches!(arg, hir::GenericArg::Type(_)))
236+
.nth(index) else { return false; };
237+
error.obligation.cause.span = arg
238+
.span()
239+
.find_ancestor_in_same_ctxt(error.obligation.cause.span)
240+
.unwrap_or(arg.span());
241+
true
242+
}
243+
244+
fn find_ambiguous_parameter_in<T: TypeVisitable<'tcx>>(
245+
&self,
246+
item_def_id: DefId,
247+
t: T,
248+
) -> Option<ty::GenericArg<'tcx>> {
249+
struct FindAmbiguousParameter<'a, 'tcx>(&'a FnCtxt<'a, 'tcx>, DefId);
250+
impl<'tcx> TypeVisitor<'tcx> for FindAmbiguousParameter<'_, 'tcx> {
251+
type BreakTy = ty::GenericArg<'tcx>;
252+
fn visit_ty(&mut self, ty: Ty<'tcx>) -> std::ops::ControlFlow<Self::BreakTy> {
253+
if let Some(origin) = self.0.type_var_origin(ty)
254+
&& let rustc_infer::infer::type_variable::TypeVariableOriginKind::TypeParameterDefinition(_, Some(def_id)) =
255+
origin.kind
256+
&& let generics = self.0.tcx.generics_of(self.1)
257+
&& let Some(index) = generics.param_def_id_to_index(self.0.tcx, def_id)
258+
&& let Some(subst) = ty::InternalSubsts::identity_for_item(self.0.tcx, self.1)
259+
.get(index as usize)
260+
{
261+
ControlFlow::Break(*subst)
262+
} else {
263+
ty.super_visit_with(self)
264+
}
265+
}
266+
}
267+
t.visit_with(&mut FindAmbiguousParameter(self, item_def_id)).break_value()
268+
}
269+
270+
fn closure_span_overlaps_error(
271+
&self,
272+
error: &traits::FulfillmentError<'tcx>,
273+
span: Span,
274+
) -> bool {
275+
if let traits::FulfillmentErrorCode::CodeSelectionError(
276+
traits::SelectionError::OutputTypeParameterMismatch(_, expected, _),
277+
) = error.code
278+
&& let ty::Closure(def_id, _) | ty::Generator(def_id, ..) = expected.skip_binder().self_ty().kind()
279+
&& span.overlaps(self.tcx.def_span(*def_id))
280+
{
281+
true
282+
} else {
283+
false
284+
}
285+
}
286+
287+
fn point_at_field_if_possible(
288+
&self,
289+
def_id: DefId,
290+
param_to_point_at: ty::GenericArg<'tcx>,
291+
variant_def_id: DefId,
292+
expr_fields: &[hir::ExprField<'tcx>],
293+
) -> Option<(&'tcx hir::Expr<'tcx>, Ty<'tcx>)> {
294+
let def = self.tcx.adt_def(def_id);
295+
296+
let identity_substs = ty::InternalSubsts::identity_for_item(self.tcx, def_id);
297+
let fields_referencing_param: Vec<_> = def
298+
.variant_with_id(variant_def_id)
299+
.fields
300+
.iter()
301+
.filter(|field| {
302+
let field_ty = field.ty(self.tcx, identity_substs);
303+
Self::find_param_in_ty(field_ty.into(), param_to_point_at)
304+
})
305+
.collect();
306+
307+
if let [field] = fields_referencing_param.as_slice() {
308+
for expr_field in expr_fields {
309+
// Look for the ExprField that matches the field, using the
310+
// same rules that check_expr_struct uses for macro hygiene.
311+
if self.tcx.adjust_ident(expr_field.ident, variant_def_id) == field.ident(self.tcx)
312+
{
313+
return Some((expr_field.expr, self.tcx.type_of(field.did)));
314+
}
315+
}
316+
}
317+
318+
None
319+
}
320+
321+
/// - `blame_specific_*` means that the function will recursively traverse the expression,
322+
/// looking for the most-specific-possible span to blame.
323+
///
324+
/// - `point_at_*` means that the function will only go "one level", pointing at the specific
325+
/// expression mentioned.
326+
///
327+
/// `blame_specific_arg_if_possible` will find the most-specific expression anywhere inside
328+
/// the provided function call expression, and mark it as responsible for the fullfillment
329+
/// error.
330+
fn blame_specific_arg_if_possible(
331+
&self,
332+
error: &mut traits::FulfillmentError<'tcx>,
333+
def_id: DefId,
334+
param_to_point_at: ty::GenericArg<'tcx>,
335+
call_hir_id: hir::HirId,
336+
callee_span: Span,
337+
receiver: Option<&'tcx hir::Expr<'tcx>>,
338+
args: &'tcx [hir::Expr<'tcx>],
339+
) -> bool {
340+
let ty = self.tcx.type_of(def_id);
341+
if !ty.is_fn() {
342+
return false;
343+
}
344+
let sig = ty.fn_sig(self.tcx).skip_binder();
345+
let args_referencing_param: Vec<_> = sig
346+
.inputs()
347+
.iter()
348+
.enumerate()
349+
.filter(|(_, ty)| Self::find_param_in_ty((**ty).into(), param_to_point_at))
350+
.collect();
351+
// If there's one field that references the given generic, great!
352+
if let [(idx, _)] = args_referencing_param.as_slice()
353+
&& let Some(arg) = receiver
354+
.map_or(args.get(*idx), |rcvr| if *idx == 0 { Some(rcvr) } else { args.get(*idx - 1) }) {
355+
356+
error.obligation.cause.span = arg.span.find_ancestor_in_same_ctxt(error.obligation.cause.span).unwrap_or(arg.span);
357+
358+
if let hir::Node::Expr(arg_expr) = self.tcx.hir().get(arg.hir_id) {
359+
// This is more specific than pointing at the entire argument.
360+
self.blame_specific_expr_if_possible(error, arg_expr)
361+
}
362+
363+
error.obligation.cause.map_code(|parent_code| {
364+
ObligationCauseCode::FunctionArgumentObligation {
365+
arg_hir_id: arg.hir_id,
366+
call_hir_id,
367+
parent_code,
368+
}
369+
});
370+
return true;
371+
} else if args_referencing_param.len() > 0 {
372+
// If more than one argument applies, then point to the callee span at least...
373+
// We have chance to fix this up further in `point_at_generics_if_possible`
374+
error.obligation.cause.span = callee_span;
375+
}
376+
377+
false
378+
}
379+
8380
/**
9381
* Recursively searches for the most-specific blamable expression.
10382
* For example, if you have a chain of constraints like:

0 commit comments

Comments
 (0)