Skip to content

Commit 27891c5

Browse files
committed
Add Callable trait
1 parent b988efc commit 27891c5

File tree

144 files changed

+937
-924
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

144 files changed

+937
-924
lines changed

compiler/rustc_error_codes/src/error_codes/E0183.md

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ impl FnOnce<()> for MyClosure { // ok!
3131
println!("{}", self.foo);
3232
}
3333
}
34+
35+
impl std::ops::Callable<()> for MyClosure {}
3436
```
3537

3638
The arguments must be a tuple representing the argument list.

compiler/rustc_hir/src/lang_items.rs

+1
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,7 @@ language_item_table! {
207207
Fn, kw::Fn, fn_trait, Target::Trait, GenericRequirement::Exact(1);
208208
FnMut, sym::fn_mut, fn_mut_trait, Target::Trait, GenericRequirement::Exact(1);
209209
FnOnce, sym::fn_once, fn_once_trait, Target::Trait, GenericRequirement::Exact(1);
210+
Callable, sym::callable, callable_trait, Target::Trait, GenericRequirement::Exact(1);
210211

211212
FnOnceOutput, sym::fn_once_output, fn_once_output, Target::AssocTy, GenericRequirement::None;
212213

compiler/rustc_hir_typeck/src/callee.rs

+55-9
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,14 @@ use super::method::MethodCallee;
33
use super::{Expectation, FnCtxt, TupleArgumentsFlag};
44

55
use crate::type_error_struct;
6+
use hir::LangItem;
67
use rustc_ast::util::parser::PREC_POSTFIX;
78
use rustc_errors::{struct_span_err, Applicability, Diagnostic, ErrorGuaranteed, StashKey};
89
use rustc_hir as hir;
910
use rustc_hir::def::{self, CtorKind, Namespace, Res};
1011
use rustc_hir::def_id::DefId;
1112
use rustc_hir_analysis::autoderef::Autoderef;
13+
use rustc_infer::traits::ObligationCauseCode;
1214
use rustc_infer::{
1315
infer,
1416
traits::{self, Obligation},
@@ -22,7 +24,6 @@ use rustc_middle::ty::adjustment::{
2224
};
2325
use rustc_middle::ty::SubstsRef;
2426
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
25-
use rustc_span::def_id::LocalDefId;
2627
use rustc_span::symbol::{sym, Ident};
2728
use rustc_span::Span;
2829
use rustc_target::spec::abi;
@@ -66,7 +67,7 @@ pub fn check_legal_trait_for_method_call(
6667
#[derive(Debug)]
6768
enum CallStep<'tcx> {
6869
Builtin(Ty<'tcx>),
69-
DeferredClosure(LocalDefId, ty::FnSig<'tcx>),
70+
DeferredClosure(Ty<'tcx>, ty::FnSig<'tcx>),
7071
/// E.g., enum variant constructors.
7172
Overloaded(MethodCallee<'tcx>),
7273
}
@@ -173,7 +174,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
173174
closure_substs: substs,
174175
},
175176
);
176-
return Some(CallStep::DeferredClosure(def_id, closure_sig));
177+
return Some(CallStep::DeferredClosure(adjusted_ty, closure_sig));
177178
}
178179
}
179180

@@ -375,7 +376,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
375376
arg_exprs: &'tcx [hir::Expr<'tcx>],
376377
expected: Expectation<'tcx>,
377378
) -> Ty<'tcx> {
378-
let (fn_sig, def_id) = match *callee_ty.kind() {
379+
let fn_sig = match *callee_ty.kind() {
379380
ty::FnDef(def_id, subst) => {
380381
let fn_sig = self.tcx.fn_sig(def_id).subst(self.tcx, subst);
381382

@@ -403,9 +404,9 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
403404
.emit();
404405
}
405406
}
406-
(fn_sig, Some(def_id))
407+
fn_sig
407408
}
408-
ty::FnPtr(sig) => (sig, None),
409+
ty::FnPtr(sig) => sig,
409410
_ => {
410411
for arg in arg_exprs {
411412
self.check_expr(arg);
@@ -459,7 +460,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
459460
arg_exprs,
460461
fn_sig.c_variadic,
461462
TupleArgumentsFlag::DontTupleArguments,
462-
def_id,
463+
callee_ty,
464+
);
465+
466+
self.check_callable(
467+
call_expr.hir_id,
468+
call_expr.span,
469+
callee_ty,
470+
fn_sig.inputs().iter().copied(),
463471
);
464472

465473
if fn_sig.abi == abi::Abi::RustCall {
@@ -705,12 +713,37 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
705713
err.emit()
706714
}
707715

716+
/// Enforces that things being called actually are callable
717+
#[instrument(skip(self, arguments))]
718+
pub(super) fn check_callable(
719+
&self,
720+
hir_id: hir::HirId,
721+
span: Span,
722+
callable_ty: Ty<'tcx>,
723+
arguments: impl IntoIterator<Item = Ty<'tcx>>,
724+
) {
725+
if callable_ty.references_error() {
726+
return;
727+
}
728+
729+
let cause = self.cause(span, ObligationCauseCode::MiscObligation);
730+
731+
let arguments_tuple = self.tcx.mk_tup_from_iter(arguments.into_iter());
732+
let pred = ty::Binder::dummy(ty::TraitRef::from_lang_item(
733+
self.tcx,
734+
LangItem::Callable,
735+
span,
736+
[callable_ty, arguments_tuple],
737+
));
738+
self.register_predicate(Obligation::new(self.tcx, cause, self.param_env, pred));
739+
}
740+
708741
fn confirm_deferred_closure_call(
709742
&self,
710743
call_expr: &'tcx hir::Expr<'tcx>,
711744
arg_exprs: &'tcx [hir::Expr<'tcx>],
712745
expected: Expectation<'tcx>,
713-
closure_def_id: LocalDefId,
746+
closure_ty: Ty<'tcx>,
714747
fn_sig: ty::FnSig<'tcx>,
715748
) -> Ty<'tcx> {
716749
// `fn_sig` is the *signature* of the closure being called. We
@@ -733,7 +766,14 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
733766
arg_exprs,
734767
fn_sig.c_variadic,
735768
TupleArgumentsFlag::TupleArguments,
736-
Some(closure_def_id.to_def_id()),
769+
closure_ty,
770+
);
771+
772+
self.check_callable(
773+
call_expr.hir_id,
774+
call_expr.span,
775+
closure_ty,
776+
fn_sig.inputs()[0].tuple_fields(),
737777
);
738778

739779
fn_sig.output()
@@ -804,6 +844,12 @@ impl<'a, 'tcx> DeferredCallResolution<'tcx> {
804844
let mut adjustments = self.adjustments;
805845
adjustments.extend(autoref);
806846
fcx.apply_adjustments(self.callee_expr, adjustments);
847+
fcx.check_callable(
848+
self.callee_expr.hir_id,
849+
self.callee_expr.span,
850+
fcx.tcx.mk_fn_def(method_callee.def_id, method_callee.substs),
851+
method_sig.inputs().iter().copied(),
852+
);
807853

808854
fcx.write_method_call(self.call_expr.hir_id, method_callee);
809855
}

compiler/rustc_hir_typeck/src/fn_ctxt/_impl.rs

+1
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
313313
result
314314
}
315315

316+
#[instrument(level = "trace", skip(self, span), ret)]
316317
pub(in super::super) fn normalize<T>(&self, span: Span, value: T) -> T
317318
where
318319
T: TypeFoldable<TyCtxt<'tcx>>,

compiler/rustc_hir_typeck/src/fn_ctxt/checks.rs

+21-8
Original file line numberDiff line numberDiff line change
@@ -94,14 +94,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
9494
expected: Expectation<'tcx>,
9595
) -> Ty<'tcx> {
9696
let method = method.ok().filter(|method| !method.references_error());
97-
let Some(method) = method else {
97+
let Some(method) = method else {
9898
let err_inputs = self.err_args(args_no_rcvr.len());
9999

100100
let err_inputs = match tuple_arguments {
101101
DontTupleArguments => err_inputs,
102102
TupleArguments => vec![self.tcx.mk_tup(&err_inputs)],
103103
};
104104

105+
let err = self.tcx.ty_error_misc();
106+
let callee_ty = method.map_or(err, |method| self.tcx.mk_fn_def(method.def_id, method.substs));
105107
self.check_argument_types(
106108
sp,
107109
expr,
@@ -110,9 +112,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
110112
args_no_rcvr,
111113
false,
112114
tuple_arguments,
113-
method.map(|method| method.def_id),
115+
callee_ty,
114116
);
115-
return self.tcx.ty_error_misc();
117+
self.check_callable(expr.hir_id, sp, callee_ty, std::iter::once(err).chain(err_inputs));
118+
return err;
116119
};
117120

118121
// HACK(eddyb) ignore self in the definition (see above).
@@ -122,6 +125,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
122125
method.sig.output(),
123126
&method.sig.inputs()[1..],
124127
);
128+
let callee_ty = self.tcx.mk_fn_def(method.def_id, method.substs);
125129
self.check_argument_types(
126130
sp,
127131
expr,
@@ -130,14 +134,17 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
130134
args_no_rcvr,
131135
method.sig.c_variadic,
132136
tuple_arguments,
133-
Some(method.def_id),
137+
callee_ty,
134138
);
135139

140+
self.check_callable(expr.hir_id, sp, callee_ty, method.sig.inputs().iter().copied());
141+
136142
method.sig.output()
137143
}
138144

139145
/// Generic function that factors out common logic from function calls,
140146
/// method calls and overloaded operators.
147+
#[instrument(level = "trace", skip(self, call_expr, provided_args))]
141148
pub(in super::super) fn check_argument_types(
142149
&self,
143150
// Span enclosing the call site
@@ -154,8 +161,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
154161
c_variadic: bool,
155162
// Whether the arguments have been bundled in a tuple (ex: closures)
156163
tuple_arguments: TupleArgumentsFlag,
157-
// The DefId for the function being called, for better error messages
158-
fn_def_id: Option<DefId>,
164+
// The callee type (e.g. function item ZST, function pointer or closure).
165+
// Note that this may have surprising function definitions, like often just
166+
// referring to `FnOnce::call_once`, but with an appropriate `Self` type.
167+
callee_ty: Ty<'tcx>,
159168
) {
160169
let tcx = self.tcx;
161170

@@ -230,7 +239,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
230239
let minimum_input_count = expected_input_tys.len();
231240
let provided_arg_count = provided_args.len();
232241

233-
let is_const_eval_select = matches!(fn_def_id, Some(def_id) if
242+
let is_const_eval_select = matches!(*callee_ty.kind(), ty::FnDef(def_id, _) if
234243
self.tcx.def_kind(def_id) == hir::def::DefKind::Fn
235244
&& self.tcx.is_intrinsic(def_id)
236245
&& self.tcx.item_name(def_id) == sym::const_eval_select);
@@ -456,7 +465,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
456465
provided_args,
457466
c_variadic,
458467
err_code,
459-
fn_def_id,
468+
match *callee_ty.kind() {
469+
ty::Generator(did, ..) | ty::Closure(did, _) | ty::FnDef(did, _) => Some(did),
470+
ty::FnPtr(..) | ty::Error(_) => None,
471+
ref kind => span_bug!(call_span, "invalid call argument type: {kind:?}"),
472+
},
460473
call_span,
461474
call_expr,
462475
);

compiler/rustc_hir_typeck/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -430,7 +430,7 @@ fn report_unexpected_variant_res(
430430
/// # fn f(x: (isize, isize)) {}
431431
/// f((1, 2));
432432
/// ```
433-
#[derive(Copy, Clone, Eq, PartialEq)]
433+
#[derive(Copy, Clone, Eq, PartialEq, Debug)]
434434
enum TupleArgumentsFlag {
435435
DontTupleArguments,
436436
TupleArguments,

compiler/rustc_hir_typeck/src/method/confirm.rs

+12-5
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
8686
ConfirmContext { fcx, span, self_expr, call_expr, skip_record_for_diagnostics: false }
8787
}
8888

89+
#[instrument(level = "trace", skip(self), ret)]
8990
fn confirm(
9091
&mut self,
9192
unadjusted_self_ty: Ty<'tcx>,
@@ -99,7 +100,8 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
99100
let rcvr_substs = self.fresh_receiver_substs(self_ty, &pick);
100101
let all_substs = self.instantiate_method_substs(&pick, segment, rcvr_substs);
101102

102-
debug!("rcvr_substs={rcvr_substs:?}, all_substs={all_substs:?}");
103+
debug!(?rcvr_substs);
104+
debug!(?all_substs);
103105

104106
// Create the final signature for the method, replacing late-bound regions.
105107
let (method_sig, method_predicates) = self.instantiate_method_sig(&pick, all_substs);
@@ -125,11 +127,9 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
125127
// could alter our Self-type, except for normalizing the receiver from the
126128
// signature (which is also done during probing).
127129
let method_sig_rcvr = self.normalize(self.span, method_sig.inputs()[0]);
128-
debug!(
129-
"confirm: self_ty={:?} method_sig_rcvr={:?} method_sig={:?} method_predicates={:?}",
130-
self_ty, method_sig_rcvr, method_sig, method_predicates
131-
);
130+
debug!(?self_ty, ?method_sig_rcvr, ?pick);
132131
self.unify_receivers(self_ty, method_sig_rcvr, &pick, all_substs);
132+
let inputs = method_sig.inputs();
133133

134134
let (method_sig, method_predicates) =
135135
self.normalize(self.span, (method_sig, method_predicates));
@@ -150,6 +150,13 @@ impl<'a, 'tcx> ConfirmContext<'a, 'tcx> {
150150
);
151151
}
152152

153+
self.check_callable(
154+
self.call_expr.hir_id,
155+
self.span,
156+
self.tcx.mk_fn_def(pick.item.def_id, all_substs),
157+
inputs.iter().copied(),
158+
);
159+
153160
// Create the final `MethodCallee`.
154161
let callee = MethodCallee {
155162
def_id: pick.item.def_id,

compiler/rustc_hir_typeck/src/method/mod.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ use rustc_hir as hir;
1818
use rustc_hir::def::{CtorOf, DefKind, Namespace};
1919
use rustc_hir::def_id::DefId;
2020
use rustc_infer::infer::{self, InferOk};
21+
use rustc_macros::TypeVisitable;
2122
use rustc_middle::query::Providers;
2223
use rustc_middle::traits::ObligationCause;
2324
use rustc_middle::ty::subst::{InternalSubsts, SubstsRef};
@@ -33,7 +34,7 @@ pub fn provide(providers: &mut Providers) {
3334
probe::provide(providers);
3435
}
3536

36-
#[derive(Clone, Copy, Debug)]
37+
#[derive(Clone, Copy, Debug, TypeVisitable)]
3738
pub struct MethodCallee<'tcx> {
3839
/// Impl method ID, for inherent methods, or trait method ID, otherwise.
3940
pub def_id: DefId,

compiler/rustc_middle/src/ty/consts/int.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,7 @@ impl TryFrom<ScalarInt> for char {
418418

419419
#[inline]
420420
fn try_from(int: ScalarInt) -> Result<Self, Self::Error> {
421-
let Ok(bits) = int.to_bits(Size::from_bytes(std::mem::size_of::<char>())) else {
421+
let Ok(bits) = int.to_bits(Size::from_bytes(std::mem::size_of::<char>())) else {
422422
return Err(CharTryFromScalarInt);
423423
};
424424
match char::from_u32(bits.try_into().unwrap()) {

compiler/rustc_span/src/symbol.rs

+1
Original file line numberDiff line numberDiff line change
@@ -450,6 +450,7 @@ symbols! {
450450
call,
451451
call_mut,
452452
call_once,
453+
callable,
453454
caller_location,
454455
capture_disjoint_fields,
455456
cause,

compiler/rustc_trait_selection/src/solve/assembly/mod.rs

+9
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,12 @@ pub(super) trait GoalKind<'tcx>:
236236
goal: Goal<'tcx, Self>,
237237
) -> QueryResult<'tcx>;
238238

239+
// `Callable` is implemented for all function items, function definitions and closures.
240+
fn consider_builtin_callable_candidate(
241+
ecx: &mut EvalCtxt<'_, 'tcx>,
242+
goal: Goal<'tcx, Self>,
243+
) -> QueryResult<'tcx>;
244+
239245
// A generator (that comes from an `async` desugaring) is known to implement
240246
// `Future<Output = O>`, where `O` is given by the generator's return type
241247
// that was computed during type-checking.
@@ -284,6 +290,7 @@ pub(super) trait GoalKind<'tcx>:
284290
}
285291

286292
impl<'tcx> EvalCtxt<'_, 'tcx> {
293+
#[instrument(level = "trace", skip(self), ret)]
287294
pub(super) fn assemble_and_evaluate_candidates<G: GoalKind<'tcx>>(
288295
&mut self,
289296
goal: Goal<'tcx, G>,
@@ -430,6 +437,8 @@ impl<'tcx> EvalCtxt<'_, 'tcx> {
430437
G::consider_builtin_tuple_candidate(self, goal)
431438
} else if lang_items.pointee_trait() == Some(trait_def_id) {
432439
G::consider_builtin_pointee_candidate(self, goal)
440+
} else if lang_items.callable_trait() == Some(trait_def_id) {
441+
G::consider_builtin_callable_candidate(self, goal)
433442
} else if lang_items.future_trait() == Some(trait_def_id) {
434443
G::consider_builtin_future_candidate(self, goal)
435444
} else if lang_items.gen_trait() == Some(trait_def_id) {

0 commit comments

Comments
 (0)