Skip to content

[experiment] How expensive is if_cause? #139594

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 5 additions & 94 deletions compiler/rustc_hir_typeck/src/_match.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use rustc_errors::{Applicability, Diag};
use rustc_hir::def::{CtorOf, DefKind, Res};
use rustc_hir::def_id::LocalDefId;
use rustc_hir::{self as hir, ExprKind, PatKind};
use rustc_hir::{self as hir, ExprKind, HirId, PatKind};
use rustc_hir_pretty::ty_to_string;
use rustc_middle::ty::{self, Ty};
use rustc_span::Span;
use rustc_trait_selection::traits::{
IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
};
use tracing::{debug, instrument};

Expand Down Expand Up @@ -412,105 +412,16 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

pub(crate) fn if_cause(
&self,
span: Span,
cond_span: Span,
then_expr: &'tcx hir::Expr<'tcx>,
expr_id: HirId,
else_expr: &'tcx hir::Expr<'tcx>,
then_ty: Ty<'tcx>,
else_ty: Ty<'tcx>,
tail_defines_return_position_impl_trait: Option<LocalDefId>,
) -> ObligationCause<'tcx> {
let mut outer_span = if self.tcx.sess.source_map().is_multiline(span) {
// The `if`/`else` isn't in one line in the output, include some context to make it
// clear it is an if/else expression:
// ```
// LL | let x = if true {
// | _____________-
// LL || 10i32
// || ----- expected because of this
// LL || } else {
// LL || 10u32
// || ^^^^^ expected `i32`, found `u32`
// LL || };
// ||_____- `if` and `else` have incompatible types
// ```
Some(span)
} else {
// The entire expression is in one line, only point at the arms
// ```
// LL | let x = if true { 10i32 } else { 10u32 };
// | ----- ^^^^^ expected `i32`, found `u32`
// | |
// | expected because of this
// ```
None
};

let (error_sp, else_id) = if let ExprKind::Block(block, _) = &else_expr.kind {
let block = block.innermost_block();

// Avoid overlapping spans that aren't as readable:
// ```
// 2 | let x = if true {
// | _____________-
// 3 | | 3
// | | - expected because of this
// 4 | | } else {
// | |____________^
// 5 | ||
// 6 | || };
// | || ^
// | ||_____|
// | |______if and else have incompatible types
// | expected integer, found `()`
// ```
// by not pointing at the entire expression:
// ```
// 2 | let x = if true {
// | ------- `if` and `else` have incompatible types
// 3 | 3
// | - expected because of this
// 4 | } else {
// | ____________^
// 5 | |
// 6 | | };
// | |_____^ expected integer, found `()`
// ```
if block.expr.is_none()
&& block.stmts.is_empty()
&& let Some(outer_span) = &mut outer_span
&& let Some(cond_span) = cond_span.find_ancestor_inside(*outer_span)
{
*outer_span = outer_span.with_hi(cond_span.hi())
}

(self.find_block_span(block), block.hir_id)
} else {
(else_expr.span, else_expr.hir_id)
};

let then_id = if let ExprKind::Block(block, _) = &then_expr.kind {
let block = block.innermost_block();
// Exclude overlapping spans
if block.expr.is_none() && block.stmts.is_empty() {
outer_span = None;
}
block.hir_id
} else {
then_expr.hir_id
};
let error_sp = self.find_block_span_from_hir_id(else_expr.hir_id);

// Finally construct the cause:
self.cause(
error_sp,
ObligationCauseCode::IfExpression(Box::new(IfExpressionCause {
else_id,
then_id,
then_ty,
else_ty,
outer_span,
tail_defines_return_position_impl_trait,
})),
ObligationCauseCode::IfExpression { expr_id, tail_defines_return_position_impl_trait },
)
}

Expand Down
44 changes: 17 additions & 27 deletions compiler/rustc_hir_typeck/src/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ use rustc_hir_analysis::hir_ty_lowering::HirTyLowerer;
use rustc_infer::infer::relate::RelateResult;
use rustc_infer::infer::{Coercion, DefineOpaqueTypes, InferOk, InferResult};
use rustc_infer::traits::{
IfExpressionCause, MatchExpressionArmCause, Obligation, PredicateObligation,
PredicateObligations,
MatchExpressionArmCause, Obligation, PredicateObligation, PredicateObligations,
};
use rustc_middle::span_bug;
use rustc_middle::ty::adjustment::{
Expand Down Expand Up @@ -1695,39 +1694,30 @@ impl<'tcx, 'exprs, E: AsCoercionSite> CoerceMany<'tcx, 'exprs, E> {
);
}
}
ObligationCauseCode::IfExpression(box IfExpressionCause {
then_id,
else_id,
then_ty,
else_ty,
ObligationCauseCode::IfExpression {
expr_id,
tail_defines_return_position_impl_trait: Some(rpit_def_id),
..
}) => {
} => {
let hir::Node::Expr(hir::Expr {
kind: hir::ExprKind::If(_, then_expr, Some(else_expr)),
..
}) = fcx.tcx.hir_node(expr_id)
else {
unreachable!();
};
err = fcx.err_ctxt().report_mismatched_types(
cause,
fcx.param_env,
expected,
found,
coercion_error,
);
let then_span = fcx.find_block_span_from_hir_id(then_id);
let else_span = fcx.find_block_span_from_hir_id(else_id);
// don't suggest wrapping either blocks in `if .. {} else {}`
let is_empty_arm = |id| {
let hir::Node::Block(blk) = fcx.tcx.hir_node(id) else {
return false;
};
if blk.expr.is_some() || !blk.stmts.is_empty() {
return false;
}
let Some((_, hir::Node::Expr(expr))) =
fcx.tcx.hir_parent_iter(id).nth(1)
else {
return false;
};
matches!(expr.kind, hir::ExprKind::If(..))
};
if !is_empty_arm(then_id) && !is_empty_arm(else_id) {
let then_span = fcx.find_block_span_from_hir_id(then_expr.hir_id);
let else_span = fcx.find_block_span_from_hir_id(else_expr.hir_id);
// Don't suggest wrapping whole block in `Box::new`.
if then_span != then_expr.span && else_span != else_expr.span {
let then_ty = fcx.typeck_results.borrow().expr_ty(then_expr);
let else_ty = fcx.typeck_results.borrow().expr_ty(else_expr);
self.suggest_boxing_tail_for_return_position_impl_trait(
fcx,
&mut err,
Expand Down
14 changes: 4 additions & 10 deletions compiler/rustc_hir_typeck/src/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -566,7 +566,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
ascribed_ty
}
ExprKind::If(cond, then_expr, opt_else_expr) => {
self.check_expr_if(cond, then_expr, opt_else_expr, expr.span, expected)
self.check_expr_if(expr.hir_id, cond, then_expr, opt_else_expr, expr.span, expected)
}
ExprKind::DropTemps(e) => self.check_expr_with_expectation(e, expected),
ExprKind::Array(args) => self.check_expr_array(args, expected, expr),
Expand Down Expand Up @@ -1298,6 +1298,7 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {
// or 'if-else' expression.
fn check_expr_if(
&self,
expr_id: HirId,
cond_expr: &'tcx hir::Expr<'tcx>,
then_expr: &'tcx hir::Expr<'tcx>,
opt_else_expr: Option<&'tcx hir::Expr<'tcx>>,
Expand Down Expand Up @@ -1337,15 +1338,8 @@ impl<'a, 'tcx> FnCtxt<'a, 'tcx> {

let tail_defines_return_position_impl_trait =
self.return_position_impl_trait_from_match_expectation(orig_expected);
let if_cause = self.if_cause(
sp,
cond_expr.span,
then_expr,
else_expr,
then_ty,
else_ty,
tail_defines_return_position_impl_trait,
);
let if_cause =
self.if_cause(expr_id, else_expr, tail_defines_return_position_impl_trait);

coerce.coerce(self, &if_cause, else_expr, else_ty);

Expand Down
15 changes: 8 additions & 7 deletions compiler/rustc_infer/src/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ use rustc_middle::ty::{
TyVid, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable, TypeVisitableExt, TypingEnv,
TypingMode, fold_regions,
};
use rustc_span::{Span, Symbol};
use rustc_span::{DUMMY_SP, Span, Symbol};
use snapshot::undo_log::InferCtxtUndoLogs;
use tracing::{debug, instrument};
use type_variable::TypeVariableOrigin;
Expand Down Expand Up @@ -1560,15 +1560,16 @@ impl<'tcx> InferCtxt<'tcx> {
}
}

/// Given a [`hir::HirId`] for a block, get the span of its last expression
/// or statement, peeling off any inner blocks.
/// Given a [`hir::HirId`] for a block (or an expr of a block), get the span
/// of its last expression or statement, peeling off any inner blocks.
pub fn find_block_span_from_hir_id(&self, hir_id: hir::HirId) -> Span {
match self.tcx.hir_node(hir_id) {
hir::Node::Block(blk) => self.find_block_span(blk),
// The parser was in a weird state if either of these happen, but
// it's better not to panic.
hir::Node::Block(blk)
| hir::Node::Expr(&hir::Expr { kind: hir::ExprKind::Block(blk, _), .. }) => {
self.find_block_span(blk)
}
hir::Node::Expr(e) => e.span,
_ => rustc_span::DUMMY_SP,
_ => DUMMY_SP,
}
}
}
18 changes: 5 additions & 13 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -332,7 +332,11 @@ pub enum ObligationCauseCode<'tcx> {
},

/// Computing common supertype in an if expression
IfExpression(Box<IfExpressionCause<'tcx>>),
IfExpression {
expr_id: HirId,
// Is the expectation of this match expression an RPIT?
tail_defines_return_position_impl_trait: Option<LocalDefId>,
},

/// Computing common supertype of an if expression with no else counter-part
IfExpressionWithNoElse,
Expand Down Expand Up @@ -548,18 +552,6 @@ pub struct PatternOriginExpr {
pub peeled_prefix_suggestion_parentheses: bool,
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
#[derive(TypeFoldable, TypeVisitable, HashStable, TyEncodable, TyDecodable)]
pub struct IfExpressionCause<'tcx> {
pub then_id: HirId,
pub else_id: HirId,
pub then_ty: Ty<'tcx>,
pub else_ty: Ty<'tcx>,
pub outer_span: Option<Span>,
// Is the expectation of this match expression an RPIT?
pub tail_defines_return_position_impl_trait: Option<LocalDefId>,
}

#[derive(Clone, Debug, PartialEq, Eq, HashStable, TyEncodable, TyDecodable)]
#[derive(TypeVisitable, TypeFoldable)]
pub struct DerivedCause<'tcx> {
Expand Down
61 changes: 46 additions & 15 deletions compiler/rustc_trait_selection/src/error_reporting/infer/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ use crate::infer;
use crate::infer::relate::{self, RelateResult, TypeRelation};
use crate::infer::{InferCtxt, InferCtxtExt as _, TypeTrace, ValuePairs};
use crate::solve::deeply_normalize_for_diagnostics;
use crate::traits::{
IfExpressionCause, MatchExpressionArmCause, ObligationCause, ObligationCauseCode,
};
use crate::traits::{MatchExpressionArmCause, ObligationCause, ObligationCauseCode};

mod note_and_explain;
mod suggest;
Expand Down Expand Up @@ -613,28 +611,61 @@ impl<'a, 'tcx> TypeErrCtxt<'a, 'tcx> {
}
}
},
ObligationCauseCode::IfExpression(box IfExpressionCause {
then_id,
else_id,
then_ty,
else_ty,
outer_span,
..
}) => {
let then_span = self.find_block_span_from_hir_id(then_id);
let else_span = self.find_block_span_from_hir_id(else_id);
if let hir::Node::Expr(e) = self.tcx.hir_node(else_id)
&& let hir::ExprKind::If(_cond, _then, None) = e.kind
ObligationCauseCode::IfExpression { expr_id, .. } => {
let hir::Node::Expr(&hir::Expr {
kind: hir::ExprKind::If(cond_expr, then_expr, Some(else_expr)),
span: expr_span,
..
}) = self.tcx.hir_node(expr_id)
else {
return;
};
let then_span = self.find_block_span_from_hir_id(then_expr.hir_id);
let then_ty = self
.typeck_results
.as_ref()
.expect("if expression only expected inside FnCtxt")
.expr_ty(then_expr);
let else_span = self.find_block_span_from_hir_id(else_expr.hir_id);
let else_ty = self
.typeck_results
.as_ref()
.expect("if expression only expected inside FnCtxt")
.expr_ty(else_expr);
if let hir::ExprKind::If(_cond, _then, None) = else_expr.kind
&& else_ty.is_unit()
{
// Account for `let x = if a { 1 } else if b { 2 };`
err.note("`if` expressions without `else` evaluate to `()`");
err.note("consider adding an `else` block that evaluates to the expected type");
}
err.span_label(then_span, "expected because of this");

let outer_span = if self.tcx.sess.source_map().is_multiline(expr_span) {
if then_span.hi() == expr_span.hi() || else_span.hi() == expr_span.hi() {
// Point at condition only if either block has the same end point as
// the whole expression, since that'll cause awkward overlapping spans.
Some(expr_span.shrink_to_lo().to(cond_expr.peel_drop_temps().span))
} else {
Some(expr_span)
}
} else {
None
};
if let Some(sp) = outer_span {
err.span_label(sp, "`if` and `else` have incompatible types");
}

let then_id = if let hir::ExprKind::Block(then_blk, _) = then_expr.kind {
then_blk.hir_id
} else {
then_expr.hir_id
};
let else_id = if let hir::ExprKind::Block(else_blk, _) = else_expr.kind {
else_blk.hir_id
} else {
else_expr.hir_id
};
if let Some(subdiag) = self.suggest_remove_semi_or_return_binding(
Some(then_id),
then_ty,
Expand Down
Loading
Loading