Skip to content

Commit 0ad1543

Browse files
committed
WIP: implement pin-project for &pin mut|const T
1 parent fda35a6 commit 0ad1543

File tree

12 files changed

+209
-19
lines changed

12 files changed

+209
-19
lines changed

compiler/rustc_feature/src/builtin_attrs.rs

+6
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,12 @@ pub static BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
583583
EncodeCrossCrate::Yes, min_generic_const_args, experimental!(type_const),
584584
),
585585

586+
// Pinned fields `#[pin]`
587+
gated!(
588+
pin, Normal, template!(Word), ErrorFollowing, EncodeCrossCrate::Yes, pin_ergonomics,
589+
experimental!(pin),
590+
),
591+
586592
// ==========================================================================
587593
// Internal attributes: Stability, deprecation, and unsafe:
588594
// ==========================================================================

compiler/rustc_hir_analysis/src/collect.rs

+6-1
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use std::iter;
2020
use std::ops::Bound;
2121

2222
use rustc_abi::ExternAbi;
23-
use rustc_ast::Recovered;
23+
use rustc_ast::{Pinnedness, Recovered};
2424
use rustc_data_structures::fx::{FxHashSet, FxIndexMap};
2525
use rustc_data_structures::unord::UnordMap;
2626
use rustc_errors::{
@@ -1061,6 +1061,11 @@ fn lower_variant<'tcx>(
10611061
vis: tcx.visibility(f.def_id),
10621062
safety: f.safety,
10631063
value: f.default.map(|v| v.def_id.to_def_id()),
1064+
pinnedness: if tcx.has_attr(f.def_id, sym::pin) {
1065+
Pinnedness::Pinned
1066+
} else {
1067+
Pinnedness::Not
1068+
},
10641069
})
10651070
.collect();
10661071
let recovered = match def {

compiler/rustc_hir_typeck/src/pat.rs

+35-9
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ struct TopInfo<'tcx> {
8888

8989
#[derive(Copy, Clone)]
9090
struct PatInfo<'tcx> {
91+
pinnedness: ast::Pinnedness,
9192
binding_mode: ByRef,
9293
max_ref_mutbl: MutblCap,
9394
top_info: TopInfo<'tcx>,
@@ -302,6 +303,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
302303
) {
303304
let top_info = TopInfo { expected, origin_expr, span, hir_id: pat.hir_id };
304305
let pat_info = PatInfo {
306+
pinnedness: ast::Pinnedness::Not,
305307
binding_mode: ByRef::No,
306308
max_ref_mutbl: MutblCap::Mut,
307309
top_info,
@@ -400,11 +402,13 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
400402
let pat_info = PatInfo { current_depth: old_pat_info.current_depth + 1, ..old_pat_info };
401403

402404
match pat.kind {
403-
// Peel off a `&` or `&mut` from the scrutinee type. See the examples in
404-
// `tests/ui/rfcs/rfc-2005-default-binding-mode`.
405+
// Peel off a `&`, `&mut`, `&pin const` or `&pin mut` from the scrutinee type.
406+
// See the examples in `tests/ui/rfcs/rfc-2005-default-binding-mode`
407+
// and `tests/ui/async-await/pin-ergonomics/project-pattern-match`.
405408
_ if let AdjustMode::Peel = adjust_mode
406409
&& pat.default_binding_modes
407-
&& let ty::Ref(_, inner_ty, inner_mutability) = *expected.kind() =>
410+
&& let Some((_, pinnedness, inner_ty, inner_mutability)) =
411+
expected.is_ref_or_pin_ref(self.tcx) =>
408412
{
409413
debug!("inspecting {:?}", expected);
410414

@@ -428,6 +432,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
428432
ByRef::Yes(Mutability::Not) => Mutability::Not,
429433
});
430434

435+
let pinnedness = pinnedness.max(pat_info.pinnedness);
436+
431437
let mut max_ref_mutbl = pat_info.max_ref_mutbl;
432438
if self.downgrade_mut_inside_shared() {
433439
binding_mode = binding_mode.cap_ref_mutability(max_ref_mutbl.as_mutbl());
@@ -438,7 +444,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
438444
debug!("default binding mode is now {:?}", binding_mode);
439445

440446
// Use the old pat info to keep `current_depth` to its old value.
441-
let new_pat_info = PatInfo { binding_mode, max_ref_mutbl, ..old_pat_info };
447+
let new_pat_info =
448+
PatInfo { pinnedness, binding_mode, max_ref_mutbl, ..old_pat_info };
442449
// Recurse with the new expected type.
443450
self.check_pat_inner(pat, opt_path_res, adjust_mode, inner_ty, new_pat_info)
444451
}
@@ -790,7 +797,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
790797
expected: Ty<'tcx>,
791798
pat_info: PatInfo<'tcx>,
792799
) -> Ty<'tcx> {
793-
let PatInfo { binding_mode: def_br, top_info: ti, .. } = pat_info;
800+
let PatInfo { pinnedness, binding_mode: def_br, top_info: ti, .. } = pat_info;
794801

795802
// Determine the binding mode...
796803
let bm = match user_bind_annot {
@@ -883,6 +890,17 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
883890
ByRef::No => expected, // As above, `T <: typeof(x)` is required, but we use equality, see (note_1).
884891
};
885892

893+
// Wrapping the type into `Pin` if the pattern is pinned
894+
let eq_ty = if pinnedness == ast::Pinnedness::Pinned {
895+
Ty::new_adt(
896+
self.tcx,
897+
self.tcx.adt_def(self.tcx.require_lang_item(hir::LangItem::Pin, Some(pat.span))),
898+
self.tcx.mk_args(&[eq_ty.into()]),
899+
)
900+
} else {
901+
eq_ty
902+
};
903+
886904
// We have a concrete type for the local, so we do not need to taint it and hide follow up errors *using* the local.
887905
let _ = self.demand_eqtype_pat(pat.span, eq_ty, local_ty, &ti);
888906

@@ -1386,6 +1404,10 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
13861404
for (i, subpat) in subpats.iter().enumerate_and_adjust(variant.fields.len(), ddpos) {
13871405
let field = &variant.fields[FieldIdx::from_usize(i)];
13881406
let field_ty = self.field_ty(subpat.span, field, args);
1407+
// If the field is not marked as `#[pin]`, then remove the
1408+
// pinnedness in `pat_info`.
1409+
let pinnedness = pat_info.pinnedness.min(field.pinnedness);
1410+
let pat_info = PatInfo { pinnedness, ..pat_info };
13891411
self.check_pat(subpat, field_ty, pat_info);
13901412

13911413
self.tcx.check_stability(
@@ -1642,11 +1664,11 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
16421664
for field in fields {
16431665
let span = field.span;
16441666
let ident = tcx.adjust_ident(field.ident, variant.def_id);
1645-
let field_ty = match used_fields.entry(ident) {
1667+
let (field_ty, pinnedness) = match used_fields.entry(ident) {
16461668
Occupied(occupied) => {
16471669
let guar = self.error_field_already_bound(span, field.ident, *occupied.get());
16481670
result = Err(guar);
1649-
Ty::new_error(tcx, guar)
1671+
(Ty::new_error(tcx, guar), ast::Pinnedness::Not)
16501672
}
16511673
Vacant(vacant) => {
16521674
vacant.insert(span);
@@ -1655,15 +1677,19 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
16551677
.map(|(i, f)| {
16561678
self.write_field_index(field.hir_id, *i);
16571679
self.tcx.check_stability(f.did, Some(field.hir_id), span, None);
1658-
self.field_ty(span, f, args)
1680+
(self.field_ty(span, f, args), f.pinnedness)
16591681
})
16601682
.unwrap_or_else(|| {
16611683
inexistent_fields.push(field);
1662-
Ty::new_misc_error(tcx)
1684+
(Ty::new_misc_error(tcx), ast::Pinnedness::Not)
16631685
})
16641686
}
16651687
};
16661688

1689+
// If the field is not marked as `#[pin]`, then remove the
1690+
// pinnedness in `pat_info`.
1691+
let pinnedness = pat_info.pinnedness.min(pinnedness);
1692+
let pat_info = PatInfo { pinnedness, ..pat_info };
16671693
self.check_pat(field.pat, field_ty, pat_info);
16681694
}
16691695

compiler/rustc_metadata/src/rmeta/decoder.rs

+13-3
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use rustc_session::Session;
3232
use rustc_session::config::TargetModifier;
3333
use rustc_session::cstore::{CrateSource, ExternCrate};
3434
use rustc_span::hygiene::HygieneDecodeContext;
35-
use rustc_span::{BytePos, DUMMY_SP, Pos, SpanData, SpanDecoder, SyntaxContext, kw};
35+
use rustc_span::{BytePos, DUMMY_SP, Pos, SpanData, SpanDecoder, SyntaxContext, kw, sym};
3636
use tracing::debug;
3737

3838
use crate::creader::CStore;
@@ -1084,6 +1084,7 @@ impl<'a> CrateMetadataRef<'a> {
10841084
kind: DefKind,
10851085
index: DefIndex,
10861086
parent_did: DefId,
1087+
sess: &'a Session,
10871088
) -> (VariantIdx, ty::VariantDef) {
10881089
let adt_kind = match kind {
10891090
DefKind::Variant => ty::AdtKind::Enum,
@@ -1112,6 +1113,15 @@ impl<'a> CrateMetadataRef<'a> {
11121113
vis: self.get_visibility(did.index),
11131114
safety: self.get_safety(did.index),
11141115
value: self.get_default_field(did.index),
1116+
pinnedness: if self
1117+
.get_item_attrs(did.index, sess)
1118+
.find(|attr| attr.has_name(sym::pin))
1119+
.is_some()
1120+
{
1121+
ast::Pinnedness::Pinned
1122+
} else {
1123+
ast::Pinnedness::Not
1124+
},
11151125
})
11161126
.collect(),
11171127
parent_did,
@@ -1144,12 +1154,12 @@ impl<'a> CrateMetadataRef<'a> {
11441154
let kind = self.def_kind(index);
11451155
match kind {
11461156
DefKind::Ctor(..) => None,
1147-
_ => Some(self.get_variant(kind, index, did)),
1157+
_ => Some(self.get_variant(kind, index, did, tcx.sess)),
11481158
}
11491159
})
11501160
.collect()
11511161
} else {
1152-
std::iter::once(self.get_variant(kind, item_id, did)).collect()
1162+
std::iter::once(self.get_variant(kind, item_id, did, tcx.sess)).collect()
11531163
};
11541164

11551165
variants.sort_by_key(|(idx, _)| *idx);

compiler/rustc_middle/src/ty/mod.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -1379,6 +1379,7 @@ pub struct FieldDef {
13791379
pub vis: Visibility<DefId>,
13801380
pub safety: hir::Safety,
13811381
pub value: Option<DefId>,
1382+
pub pinnedness: ast::Pinnedness,
13821383
}
13831384

13841385
impl PartialEq for FieldDef {
@@ -1391,9 +1392,9 @@ impl PartialEq for FieldDef {
13911392
// of `FieldDef` changes, a compile-error will be produced, reminding
13921393
// us to revisit this assumption.
13931394

1394-
let Self { did: lhs_did, name: _, vis: _, safety: _, value: _ } = &self;
1395+
let Self { did: lhs_did, name: _, vis: _, safety: _, value: _, pinnedness: _ } = &self;
13951396

1396-
let Self { did: rhs_did, name: _, vis: _, safety: _, value: _ } = other;
1397+
let Self { did: rhs_did, name: _, vis: _, safety: _, value: _, pinnedness: _ } = other;
13971398

13981399
let res = lhs_did == rhs_did;
13991400

@@ -1420,7 +1421,7 @@ impl Hash for FieldDef {
14201421
// of `FieldDef` changes, a compile-error will be produced, reminding
14211422
// us to revisit this assumption.
14221423

1423-
let Self { did, name: _, vis: _, safety: _, value: _ } = &self;
1424+
let Self { did, name: _, vis: _, safety: _, value: _, pinnedness: _ } = &self;
14241425

14251426
did.hash(s)
14261427
}

compiler/rustc_middle/src/ty/util.rs

+21
Original file line numberDiff line numberDiff line change
@@ -1545,6 +1545,27 @@ impl<'tcx> Ty<'tcx> {
15451545
ty
15461546
}
15471547

1548+
/// Destructs a reference type `&'a [mut] T` or a pinned reference type
1549+
/// `&'a pin const|mut T` into `'a` the region, `[pin]` the pinnedness,
1550+
/// `T` the inner type, and `mut|const` the mutability.
1551+
pub fn is_ref_or_pin_ref(
1552+
self,
1553+
tcx: TyCtxt<'tcx>,
1554+
) -> Option<(ty::Region<'tcx>, ty::Pinnedness, Ty<'tcx>, ty::Mutability)> {
1555+
match self.kind() {
1556+
&ty::Ref(region, inner_ty, mutbl) => {
1557+
Some((region, ty::Pinnedness::Not, inner_ty, mutbl))
1558+
}
1559+
ty::Adt(adt, args)
1560+
if tcx.is_lang_item(adt.did(), hir::LangItem::Pin)
1561+
&& let &ty::Ref(region, inner_ty, mutbl) = args.type_at(0).kind() =>
1562+
{
1563+
Some((region, ty::Pinnedness::Pinned, inner_ty, mutbl))
1564+
}
1565+
_ => None,
1566+
}
1567+
}
1568+
15481569
// FIXME(compiler-errors): Think about removing this.
15491570
#[inline]
15501571
pub fn outer_exclusive_binder(self) -> ty::DebruijnIndex {

compiler/rustc_mir_build/src/thir/pattern/mod.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -348,8 +348,8 @@ impl<'a, 'tcx> PatCtxt<'a, 'tcx> {
348348
// x's type, which is &T, where we want T (the type being matched).
349349
let var_ty = ty;
350350
if let hir::ByRef::Yes(_) = mode.0 {
351-
if let ty::Ref(_, rty, _) = ty.kind() {
352-
ty = *rty;
351+
if let Some((_, _, rty, _)) = ty.is_ref_or_pin_ref(self.tcx) {
352+
ty = rty;
353353
} else {
354354
bug!("`ref {}` has wrong type {}", ident, ty);
355355
}

compiler/rustc_passes/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,10 @@ passes_pass_by_value =
618618
`pass_by_value` attribute should be applied to a struct, enum or type alias
619619
.label = is not a struct, enum or type alias
620620
621+
passes_pin_bad_location =
622+
`pin` attribute should be applied to a field
623+
.label = expect field, but apllied to this {$target}
624+
621625
passes_proc_macro_bad_sig = {$kind} has incorrect signature
622626
623627
passes_remove_fields =

compiler/rustc_passes/src/check_attr.rs

+15
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,7 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
263263
}
264264
[sym::linkage, ..] => self.check_linkage(attr, span, target),
265265
[sym::rustc_pub_transparent, ..] => self.check_rustc_pub_transparent(attr.span(), span, attrs),
266+
[sym::pin, ..] => self.check_pin(attr.span(), span, target),
266267
[
267268
// ok
268269
sym::allow
@@ -2644,6 +2645,20 @@ impl<'tcx> CheckAttrVisitor<'tcx> {
26442645
}
26452646
}
26462647
}
2648+
2649+
/// Checkes if `#[pin]` is applied to a field definition.
2650+
fn check_pin(&self, attr_span: Span, span: Span, target: Target) {
2651+
match target {
2652+
Target::Field => {}
2653+
_ => {
2654+
self.dcx().emit_err(errors::PinBadLocation {
2655+
attr_span,
2656+
span,
2657+
target: target.name(),
2658+
});
2659+
}
2660+
}
2661+
}
26472662
}
26482663

26492664
impl<'tcx> Visitor<'tcx> for CheckAttrVisitor<'tcx> {

compiler/rustc_passes/src/errors.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1928,3 +1928,13 @@ pub(crate) struct UnsupportedAttributesInWhere {
19281928
#[primary_span]
19291929
pub span: MultiSpan,
19301930
}
1931+
1932+
#[derive(Diagnostic)]
1933+
#[diag(passes_pin_bad_location)]
1934+
pub(crate) struct PinBadLocation {
1935+
#[primary_span]
1936+
pub attr_span: Span,
1937+
#[label]
1938+
pub span: Span,
1939+
pub target: &'static str,
1940+
}

compiler/rustc_pattern_analysis/src/rustc.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@ use std::ops::ControlFlow;
44

55
use rustc_abi::{FIRST_VARIANT, FieldIdx, Integer, VariantIdx};
66
use rustc_arena::DroplessArena;
7-
use rustc_hir::HirId;
87
use rustc_hir::def_id::DefId;
8+
use rustc_hir::{HirId, LangItem};
99
use rustc_index::{Idx, IndexVec};
1010
use rustc_middle::middle::stability::EvalResult;
1111
use rustc_middle::mir::{self, Const};
@@ -484,6 +484,8 @@ impl<'p, 'tcx: 'p> RustcPatCtxt<'p, 'tcx> {
484484
ctor = match ty.kind() {
485485
// This is a box pattern.
486486
ty::Adt(adt, ..) if adt.is_box() => Struct,
487+
// This is a pin ref pattern.
488+
ty::Adt(adt, ..) if self.tcx.is_lang_item(adt.did(), LangItem::Pin) => Ref,
487489
ty::Ref(..) => Ref,
488490
_ => span_bug!(
489491
pat.span,

0 commit comments

Comments
 (0)