Skip to content

Commit f2abf82

Browse files
committed
Auto merge of #132894 - frank-king:feature/where-refactor, r=cjgillot
Refactor `where` predicates, and reserve for attributes support Refactor `WherePredicate` to `WherePredicateKind`, and reserve for attributes support in `where` predicates. This is a part of #115590 and is split from #132388. r? petrochenkov
2 parents 6d22ff1 + 161221d commit f2abf82

File tree

44 files changed

+395
-409
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+395
-409
lines changed

compiler/rustc_ast/src/ast.rs

+9-14
Original file line numberDiff line numberDiff line change
@@ -428,7 +428,15 @@ impl Default for WhereClause {
428428

429429
/// A single predicate in a where-clause.
430430
#[derive(Clone, Encodable, Decodable, Debug)]
431-
pub enum WherePredicate {
431+
pub struct WherePredicate {
432+
pub kind: WherePredicateKind,
433+
pub id: NodeId,
434+
pub span: Span,
435+
}
436+
437+
/// Predicate kind in where-clause.
438+
#[derive(Clone, Encodable, Decodable, Debug)]
439+
pub enum WherePredicateKind {
432440
/// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
433441
BoundPredicate(WhereBoundPredicate),
434442
/// A lifetime predicate (e.g., `'a: 'b + 'c`).
@@ -437,22 +445,11 @@ pub enum WherePredicate {
437445
EqPredicate(WhereEqPredicate),
438446
}
439447

440-
impl WherePredicate {
441-
pub fn span(&self) -> Span {
442-
match self {
443-
WherePredicate::BoundPredicate(p) => p.span,
444-
WherePredicate::RegionPredicate(p) => p.span,
445-
WherePredicate::EqPredicate(p) => p.span,
446-
}
447-
}
448-
}
449-
450448
/// A type bound.
451449
///
452450
/// E.g., `for<'c> Foo: Send + Clone + 'c`.
453451
#[derive(Clone, Encodable, Decodable, Debug)]
454452
pub struct WhereBoundPredicate {
455-
pub span: Span,
456453
/// Any generics from a `for` binding.
457454
pub bound_generic_params: ThinVec<GenericParam>,
458455
/// The type being bounded.
@@ -466,7 +463,6 @@ pub struct WhereBoundPredicate {
466463
/// E.g., `'a: 'b + 'c`.
467464
#[derive(Clone, Encodable, Decodable, Debug)]
468465
pub struct WhereRegionPredicate {
469-
pub span: Span,
470466
pub lifetime: Lifetime,
471467
pub bounds: GenericBounds,
472468
}
@@ -476,7 +472,6 @@ pub struct WhereRegionPredicate {
476472
/// E.g., `T = int`.
477473
#[derive(Clone, Encodable, Decodable, Debug)]
478474
pub struct WhereEqPredicate {
479-
pub span: Span,
480475
pub lhs_ty: P<Ty>,
481476
pub rhs_ty: P<Ty>,
482477
}

compiler/rustc_ast/src/mut_visit.rs

+20-12
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,11 @@ pub trait MutVisitor: Sized {
332332
}
333333

334334
fn visit_where_predicate(&mut self, where_predicate: &mut WherePredicate) {
335-
walk_where_predicate(self, where_predicate);
335+
walk_where_predicate(self, where_predicate)
336+
}
337+
338+
fn visit_where_predicate_kind(&mut self, kind: &mut WherePredicateKind) {
339+
walk_where_predicate_kind(self, kind)
336340
}
337341

338342
fn visit_vis(&mut self, vis: &mut Visibility) {
@@ -1065,26 +1069,30 @@ fn walk_where_clause<T: MutVisitor>(vis: &mut T, wc: &mut WhereClause) {
10651069
vis.visit_span(span);
10661070
}
10671071

1068-
fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
1069-
match pred {
1070-
WherePredicate::BoundPredicate(bp) => {
1071-
let WhereBoundPredicate { span, bound_generic_params, bounded_ty, bounds } = bp;
1072+
pub fn walk_where_predicate<T: MutVisitor>(vis: &mut T, pred: &mut WherePredicate) {
1073+
let WherePredicate { kind, id, span } = pred;
1074+
vis.visit_id(id);
1075+
vis.visit_where_predicate_kind(kind);
1076+
vis.visit_span(span);
1077+
}
1078+
1079+
pub fn walk_where_predicate_kind<T: MutVisitor>(vis: &mut T, kind: &mut WherePredicateKind) {
1080+
match kind {
1081+
WherePredicateKind::BoundPredicate(bp) => {
1082+
let WhereBoundPredicate { bound_generic_params, bounded_ty, bounds } = bp;
10721083
bound_generic_params.flat_map_in_place(|param| vis.flat_map_generic_param(param));
10731084
vis.visit_ty(bounded_ty);
10741085
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
1075-
vis.visit_span(span);
10761086
}
1077-
WherePredicate::RegionPredicate(rp) => {
1078-
let WhereRegionPredicate { span, lifetime, bounds } = rp;
1087+
WherePredicateKind::RegionPredicate(rp) => {
1088+
let WhereRegionPredicate { lifetime, bounds } = rp;
10791089
vis.visit_lifetime(lifetime);
10801090
visit_vec(bounds, |bound| vis.visit_param_bound(bound, BoundKind::Bound));
1081-
vis.visit_span(span);
10821091
}
1083-
WherePredicate::EqPredicate(ep) => {
1084-
let WhereEqPredicate { span, lhs_ty, rhs_ty } = ep;
1092+
WherePredicateKind::EqPredicate(ep) => {
1093+
let WhereEqPredicate { lhs_ty, rhs_ty } = ep;
10851094
vis.visit_ty(lhs_ty);
10861095
vis.visit_ty(rhs_ty);
1087-
vis.visit_span(span);
10881096
}
10891097
}
10901098
}

compiler/rustc_ast/src/visit.rs

+15-5
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,9 @@ pub trait Visitor<'ast>: Sized {
192192
fn visit_where_predicate(&mut self, p: &'ast WherePredicate) -> Self::Result {
193193
walk_where_predicate(self, p)
194194
}
195+
fn visit_where_predicate_kind(&mut self, k: &'ast WherePredicateKind) -> Self::Result {
196+
walk_where_predicate_kind(self, k)
197+
}
195198
fn visit_fn(&mut self, fk: FnKind<'ast>, _: Span, _: NodeId) -> Self::Result {
196199
walk_fn(self, fk)
197200
}
@@ -794,22 +797,29 @@ pub fn walk_where_predicate<'a, V: Visitor<'a>>(
794797
visitor: &mut V,
795798
predicate: &'a WherePredicate,
796799
) -> V::Result {
797-
match predicate {
798-
WherePredicate::BoundPredicate(WhereBoundPredicate {
800+
let WherePredicate { kind, id: _, span: _ } = predicate;
801+
visitor.visit_where_predicate_kind(kind)
802+
}
803+
804+
pub fn walk_where_predicate_kind<'a, V: Visitor<'a>>(
805+
visitor: &mut V,
806+
kind: &'a WherePredicateKind,
807+
) -> V::Result {
808+
match kind {
809+
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
799810
bounded_ty,
800811
bounds,
801812
bound_generic_params,
802-
span: _,
803813
}) => {
804814
walk_list!(visitor, visit_generic_param, bound_generic_params);
805815
try_visit!(visitor.visit_ty(bounded_ty));
806816
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
807817
}
808-
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span: _ }) => {
818+
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds }) => {
809819
try_visit!(visitor.visit_lifetime(lifetime, LifetimeCtxt::Bound));
810820
walk_list!(visitor, visit_param_bound, bounds, BoundKind::Bound);
811821
}
812-
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span: _ }) => {
822+
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty }) => {
813823
try_visit!(visitor.visit_ty(lhs_ty));
814824
try_visit!(visitor.visit_ty(rhs_ty));
815825
}

compiler/rustc_ast_lowering/src/index.rs

+4-9
Original file line numberDiff line numberDiff line change
@@ -381,15 +381,10 @@ impl<'a, 'hir> Visitor<'hir> for NodeCollector<'a, 'hir> {
381381
}
382382

383383
fn visit_where_predicate(&mut self, predicate: &'hir WherePredicate<'hir>) {
384-
match predicate {
385-
WherePredicate::BoundPredicate(pred) => {
386-
self.insert(pred.span, pred.hir_id, Node::WhereBoundPredicate(pred));
387-
self.with_parent(pred.hir_id, |this| {
388-
intravisit::walk_where_predicate(this, predicate)
389-
})
390-
}
391-
_ => intravisit::walk_where_predicate(self, predicate),
392-
}
384+
self.insert(predicate.span, predicate.hir_id, Node::WherePredicate(predicate));
385+
self.with_parent(predicate.hir_id, |this| {
386+
intravisit::walk_where_predicate(this, predicate)
387+
});
393388
}
394389

395390
fn visit_array_length(&mut self, len: &'hir ArrayLen<'hir>) {

compiler/rustc_ast_lowering/src/item.rs

+21-25
Original file line numberDiff line numberDiff line change
@@ -1401,7 +1401,7 @@ impl<'hir> LoweringContext<'_, 'hir> {
14011401
// keep track of the Span info. Now, `<dyn HirTyLowerer>::add_implicit_sized_bound`
14021402
// checks both param bounds and where clauses for `?Sized`.
14031403
for pred in &generics.where_clause.predicates {
1404-
let WherePredicate::BoundPredicate(bound_pred) = pred else {
1404+
let WherePredicateKind::BoundPredicate(bound_pred) = &pred.kind else {
14051405
continue;
14061406
};
14071407
let compute_is_param = || {
@@ -1538,9 +1538,9 @@ impl<'hir> LoweringContext<'_, 'hir> {
15381538
}
15391539
});
15401540
let span = self.lower_span(span);
1541-
1542-
match kind {
1543-
GenericParamKind::Const { .. } => None,
1541+
let hir_id = self.next_id();
1542+
let kind = self.arena.alloc(match kind {
1543+
GenericParamKind::Const { .. } => return None,
15441544
GenericParamKind::Type { .. } => {
15451545
let def_id = self.local_def_id(id).to_def_id();
15461546
let hir_id = self.next_id();
@@ -1555,38 +1555,36 @@ impl<'hir> LoweringContext<'_, 'hir> {
15551555
let ty_id = self.next_id();
15561556
let bounded_ty =
15571557
self.ty_path(ty_id, param_span, hir::QPath::Resolved(None, ty_path));
1558-
Some(hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
1559-
hir_id: self.next_id(),
1558+
hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
15601559
bounded_ty: self.arena.alloc(bounded_ty),
15611560
bounds,
1562-
span,
15631561
bound_generic_params: &[],
15641562
origin,
1565-
}))
1563+
})
15661564
}
15671565
GenericParamKind::Lifetime => {
15681566
let ident = self.lower_ident(ident);
15691567
let lt_id = self.next_node_id();
15701568
let lifetime = self.new_named_lifetime(id, lt_id, ident);
1571-
Some(hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
1569+
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
15721570
lifetime,
1573-
span,
15741571
bounds,
15751572
in_where_clause: false,
1576-
}))
1573+
})
15771574
}
1578-
}
1575+
});
1576+
Some(hir::WherePredicate { hir_id, span, kind })
15791577
}
15801578

15811579
fn lower_where_predicate(&mut self, pred: &WherePredicate) -> hir::WherePredicate<'hir> {
1582-
match pred {
1583-
WherePredicate::BoundPredicate(WhereBoundPredicate {
1580+
let hir_id = self.lower_node_id(pred.id);
1581+
let span = self.lower_span(pred.span);
1582+
let kind = self.arena.alloc(match &pred.kind {
1583+
WherePredicateKind::BoundPredicate(WhereBoundPredicate {
15841584
bound_generic_params,
15851585
bounded_ty,
15861586
bounds,
1587-
span,
1588-
}) => hir::WherePredicate::BoundPredicate(hir::WhereBoundPredicate {
1589-
hir_id: self.next_id(),
1587+
}) => hir::WherePredicateKind::BoundPredicate(hir::WhereBoundPredicate {
15901588
bound_generic_params: self
15911589
.lower_generic_params(bound_generic_params, hir::GenericParamSource::Binder),
15921590
bounded_ty: self
@@ -1595,12 +1593,10 @@ impl<'hir> LoweringContext<'_, 'hir> {
15951593
bounds,
15961594
ImplTraitContext::Disallowed(ImplTraitPosition::Bound),
15971595
),
1598-
span: self.lower_span(*span),
15991596
origin: PredicateOrigin::WhereClause,
16001597
}),
1601-
WherePredicate::RegionPredicate(WhereRegionPredicate { lifetime, bounds, span }) => {
1602-
hir::WherePredicate::RegionPredicate(hir::WhereRegionPredicate {
1603-
span: self.lower_span(*span),
1598+
WherePredicateKind::RegionPredicate(WhereRegionPredicate { lifetime, bounds }) => {
1599+
hir::WherePredicateKind::RegionPredicate(hir::WhereRegionPredicate {
16041600
lifetime: self.lower_lifetime(lifetime),
16051601
bounds: self.lower_param_bounds(
16061602
bounds,
@@ -1609,15 +1605,15 @@ impl<'hir> LoweringContext<'_, 'hir> {
16091605
in_where_clause: true,
16101606
})
16111607
}
1612-
WherePredicate::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty, span }) => {
1613-
hir::WherePredicate::EqPredicate(hir::WhereEqPredicate {
1608+
WherePredicateKind::EqPredicate(WhereEqPredicate { lhs_ty, rhs_ty }) => {
1609+
hir::WherePredicateKind::EqPredicate(hir::WhereEqPredicate {
16141610
lhs_ty: self
16151611
.lower_ty(lhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
16161612
rhs_ty: self
16171613
.lower_ty(rhs_ty, ImplTraitContext::Disallowed(ImplTraitPosition::Bound)),
1618-
span: self.lower_span(*span),
16191614
})
16201615
}
1621-
}
1616+
});
1617+
hir::WherePredicate { hir_id, span, kind }
16221618
}
16231619
}

compiler/rustc_ast_passes/src/ast_validation.rs

+17-15
Original file line numberDiff line numberDiff line change
@@ -1200,14 +1200,15 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
12001200
validate_generic_param_order(self.dcx(), &generics.params, generics.span);
12011201

12021202
for predicate in &generics.where_clause.predicates {
1203-
if let WherePredicate::EqPredicate(predicate) = predicate {
1204-
deny_equality_constraints(self, predicate, generics);
1203+
let span = predicate.span;
1204+
if let WherePredicateKind::EqPredicate(predicate) = &predicate.kind {
1205+
deny_equality_constraints(self, predicate, span, generics);
12051206
}
12061207
}
12071208
walk_list!(self, visit_generic_param, &generics.params);
12081209
for predicate in &generics.where_clause.predicates {
1209-
match predicate {
1210-
WherePredicate::BoundPredicate(bound_pred) => {
1210+
match &predicate.kind {
1211+
WherePredicateKind::BoundPredicate(bound_pred) => {
12111212
// This is slightly complicated. Our representation for poly-trait-refs contains a single
12121213
// binder and thus we only allow a single level of quantification. However,
12131214
// the syntax of Rust permits quantification in two places in where clauses,
@@ -1504,9 +1505,10 @@ impl<'a> Visitor<'a> for AstValidator<'a> {
15041505
fn deny_equality_constraints(
15051506
this: &AstValidator<'_>,
15061507
predicate: &WhereEqPredicate,
1508+
predicate_span: Span,
15071509
generics: &Generics,
15081510
) {
1509-
let mut err = errors::EqualityInWhere { span: predicate.span, assoc: None, assoc2: None };
1511+
let mut err = errors::EqualityInWhere { span: predicate_span, assoc: None, assoc2: None };
15101512

15111513
// Given `<A as Foo>::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
15121514
if let TyKind::Path(Some(qself), full_path) = &predicate.lhs_ty.kind
@@ -1550,7 +1552,7 @@ fn deny_equality_constraints(
15501552
}
15511553
}
15521554
err.assoc = Some(errors::AssociatedSuggestion {
1553-
span: predicate.span,
1555+
span: predicate_span,
15541556
ident: *ident,
15551557
param: param.ident,
15561558
path: pprust::path_to_string(&assoc_path),
@@ -1580,23 +1582,23 @@ fn deny_equality_constraints(
15801582
// We're removing th eonly where bound left, remove the whole thing.
15811583
generics.where_clause.span
15821584
} else {
1583-
let mut span = predicate.span;
1585+
let mut span = predicate_span;
15841586
let mut prev: Option<Span> = None;
15851587
let mut preds = generics.where_clause.predicates.iter().peekable();
15861588
// Find the predicate that shouldn't have been in the where bound list.
15871589
while let Some(pred) = preds.next() {
1588-
if let WherePredicate::EqPredicate(pred) = pred
1589-
&& pred.span == predicate.span
1590+
if let WherePredicateKind::EqPredicate(_) = pred.kind
1591+
&& pred.span == predicate_span
15901592
{
15911593
if let Some(next) = preds.peek() {
15921594
// This is the first predicate, remove the trailing comma as well.
1593-
span = span.with_hi(next.span().lo());
1595+
span = span.with_hi(next.span.lo());
15941596
} else if let Some(prev) = prev {
15951597
// Remove the previous comma as well.
15961598
span = span.with_lo(prev.hi());
15971599
}
15981600
}
1599-
prev = Some(pred.span());
1601+
prev = Some(pred.span);
16001602
}
16011603
span
16021604
};
@@ -1613,8 +1615,8 @@ fn deny_equality_constraints(
16131615
if let TyKind::Path(None, full_path) = &predicate.lhs_ty.kind {
16141616
// Given `A: Foo, Foo::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
16151617
for bounds in generics.params.iter().map(|p| &p.bounds).chain(
1616-
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
1617-
WherePredicate::BoundPredicate(p) => Some(&p.bounds),
1618+
generics.where_clause.predicates.iter().filter_map(|pred| match &pred.kind {
1619+
WherePredicateKind::BoundPredicate(p) => Some(&p.bounds),
16181620
_ => None,
16191621
}),
16201622
) {
@@ -1637,8 +1639,8 @@ fn deny_equality_constraints(
16371639
// Given `A: Foo, A::Bar = RhsTy`, suggest `A: Foo<Bar = RhsTy>`.
16381640
if let [potential_param, potential_assoc] = &full_path.segments[..] {
16391641
for (ident, bounds) in generics.params.iter().map(|p| (p.ident, &p.bounds)).chain(
1640-
generics.where_clause.predicates.iter().filter_map(|pred| match pred {
1641-
WherePredicate::BoundPredicate(p)
1642+
generics.where_clause.predicates.iter().filter_map(|pred| match &pred.kind {
1643+
WherePredicateKind::BoundPredicate(p)
16421644
if let ast::TyKind::Path(None, path) = &p.bounded_ty.kind
16431645
&& let [segment] = &path.segments[..] =>
16441646
{

compiler/rustc_ast_passes/src/feature_gate.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -345,8 +345,8 @@ impl<'a> Visitor<'a> for PostExpansionVisitor<'a> {
345345

346346
fn visit_generics(&mut self, g: &'a ast::Generics) {
347347
for predicate in &g.where_clause.predicates {
348-
match predicate {
349-
ast::WherePredicate::BoundPredicate(bound_pred) => {
348+
match &predicate.kind {
349+
ast::WherePredicateKind::BoundPredicate(bound_pred) => {
350350
// A type bound (e.g., `for<'c> Foo: Send + Clone + 'c`).
351351
self.check_late_bound_lifetime_defs(&bound_pred.bound_generic_params);
352352
}

0 commit comments

Comments
 (0)