Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions compiler/rustc_mir_transform/src/coroutine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,8 @@ struct SelfArgVisitor<'tcx> {
}

impl<'tcx> SelfArgVisitor<'tcx> {
fn new(tcx: TyCtxt<'tcx>, elem: ProjectionElem<Local, Ty<'tcx>>) -> Self {
Self { tcx, new_base: Place { local: SELF_ARG, projection: tcx.mk_place_elems(&[elem]) } }
fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
Self { tcx, new_base }
}
}

Expand All @@ -146,16 +146,14 @@ impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
assert_ne!(*local, SELF_ARG);
}

fn visit_place(&mut self, place: &mut Place<'tcx>, context: PlaceContext, location: Location) {
fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
if place.local == SELF_ARG {
replace_base(place, self.new_base, self.tcx);
} else {
self.visit_local(&mut place.local, context, location);
}

for elem in place.projection.iter() {
if let PlaceElem::Index(local) = elem {
assert_ne!(local, SELF_ARG);
}
for elem in place.projection.iter() {
if let PlaceElem::Index(local) = elem {
assert_ne!(local, SELF_ARG);
}
}
}
Expand All @@ -176,6 +174,7 @@ const SELF_ARG: Local = Local::from_u32(1);
const CTX_ARG: Local = Local::from_u32(2);

/// A `yield` point in the coroutine.
#[derive(Debug)]
struct SuspensionPoint<'tcx> {
/// State discriminant used when suspending or resuming at this point.
state: usize,
Expand Down Expand Up @@ -520,32 +519,56 @@ fn make_aggregate_adt<'tcx>(

#[tracing::instrument(level = "trace", skip(tcx, body))]
fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let coroutine_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls[SELF_ARG].ty;

let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);

// Replace the by value coroutine argument
body.local_decls.raw[1].ty = ref_coroutine_ty;
body.local_decls[SELF_ARG].ty = ref_coroutine_ty;

// Add a deref to accesses of the coroutine state
SelfArgVisitor::new(tcx, ProjectionElem::Deref).visit_body(body);
SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
}

#[tracing::instrument(level = "trace", skip(tcx, body))]
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
let ref_coroutine_ty = body.local_decls.raw[1].ty;
let coroutine_ty = body.local_decls[SELF_ARG].ty;

let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);

let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
let pin_adt_ref = tcx.adt_def(pin_did);
let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);

// Replace the by ref coroutine argument
body.local_decls.raw[1].ty = pin_ref_coroutine_ty;
body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;

let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));

// Add the Pin field access to accesses of the coroutine state
SelfArgVisitor::new(tcx, ProjectionElem::Field(FieldIdx::ZERO, ref_coroutine_ty))
.visit_body(body);
SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);

let source_info = SourceInfo::outermost(body.span);
let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);

let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
// Miri requires retags to be the very first thing in the body.
// We insert this assignment just after.
let insert_point = statements
.iter()
.position(|stmt| !matches!(stmt.kind, StatementKind::Retag(..)))
.unwrap_or(statements.len());
statements.insert(
insert_point,
Statement::new(
source_info,
StatementKind::Assign(Box::new((
unpinned_local.into(),
Rvalue::Use(Operand::Copy(pin_field)),
))),
),
);
}

/// Transforms the `body` of the coroutine applying the following transforms:
Expand Down Expand Up @@ -634,7 +657,7 @@ fn replace_resume_ty_local<'tcx>(
// We have to replace the `ResumeTy` that is used for type and borrow checking
// with `&mut Context<'_>` in MIR.
#[cfg(debug_assertions)]
{
if local_ty != context_mut_ref {
if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
assert_eq!(*resume_ty_adt, expected_adt);
Expand Down Expand Up @@ -1297,8 +1320,6 @@ fn create_coroutine_resume_function<'tcx>(
let default_block = insert_term_block(body, TerminatorKind::Unreachable);
insert_switch(body, cases, &transform, default_block);

make_coroutine_state_argument_indirect(tcx, body);

match transform.coroutine_kind {
CoroutineKind::Coroutine(_)
| CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
Expand All @@ -1307,17 +1328,9 @@ fn create_coroutine_resume_function<'tcx>(
}
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
}

// Make sure we remove dead blocks to remove
// unrelated code from the drop part of the function
simplify::remove_dead_blocks(body);

pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
dumper.dump_mir(body);
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
make_coroutine_state_argument_indirect(tcx, body);
}
}
}

Expand Down Expand Up @@ -1674,6 +1687,21 @@ impl<'tcx> crate::MirPass<'tcx> for StateTransform {
// Create the Coroutine::resume / Future::poll function
create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);

if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
dumper.dump_mir(body);
}

pm::run_passes_no_validate(
tcx,
body,
&[
&crate::abort_unwinding_calls::AbortUnwindingCalls,
&crate::simplify::SimplifyCfg::PostStateTransform,
&crate::simplify::SimplifyLocals::PostStateTransform,
],
None,
);

// Run derefer to fix Derefs that are not in the first place
deref_finder(tcx, body, false);
}
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_mir_transform/src/coroutine/drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -684,12 +684,13 @@ pub(super) fn create_coroutine_drop_shim_async<'tcx>(
let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);

make_coroutine_state_argument_indirect(tcx, &mut body);

match transform.coroutine_kind {
// Iterator::next doesn't accept a pinned argument,
// unlike for all other coroutine kinds.
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
make_coroutine_state_argument_indirect(tcx, &mut body);
}

_ => {
make_coroutine_state_argument_pinned(tcx, &mut body);
}
Expand Down
11 changes: 8 additions & 3 deletions compiler/rustc_mir_transform/src/dataflow_const_prop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,12 @@ impl<'tcx> crate::MirPass<'tcx> for DataflowConstProp {
return;
}

// Avoid computing layout inside coroutines, since their `optimized_mir` is used for layout
// computation, which can create a cycle.
if body.coroutine.is_some() {
return;
}

// We want to have a somewhat linear runtime w.r.t. the number of statements/terminators.
// Let's call this number `n`. Dataflow analysis has `O(h*n)` transfer function
// applications, where `h` is the height of the lattice. Because the height of our lattice
Expand Down Expand Up @@ -237,9 +243,8 @@ impl<'a, 'tcx> ConstAnalysis<'a, 'tcx> {
TerminatorKind::Drop { place, .. } => {
state.flood_with(place.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM);
}
TerminatorKind::Yield { .. } => {
// They would have an effect, but are not allowed in this phase.
bug!("encountered disallowed terminator");
TerminatorKind::Yield { resume_arg, .. } => {
state.flood_with(resume_arg.as_ref(), &self.map, FlatSet::<Scalar>::BOTTOM);
}
TerminatorKind::SwitchInt { discr, targets } => {
return self.handle_switch_int(discr, targets, state);
Expand Down
20 changes: 12 additions & 8 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1911,14 +1911,18 @@ impl<'tcx> MutVisitor<'tcx> for VnState<'_, '_, 'tcx> {
}

fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
if let Terminator { kind: TerminatorKind::Call { destination, .. }, .. } = terminator {
if let Some(local) = destination.as_local()
&& self.ssa.is_ssa(local)
{
let ty = self.local_decls[local].ty;
let opaque = self.new_opaque(ty);
self.assign(local, opaque);
}
let destination = match terminator.kind {
TerminatorKind::Call { destination, .. } => Some(destination),
TerminatorKind::Yield { resume_arg, .. } => Some(resume_arg),
_ => None,
};
if let Some(destination) = destination
&& let Some(local) = destination.as_local()
&& self.ssa.is_ssa(local)
{
let ty = self.local_decls[local].ty;
let opaque = self.new_opaque(ty);
self.assign(local, opaque);
}
// Function calls and ASM may invalidate (nested) derefs. We must handle them carefully.
// Currently, only preserving derefs for trivial terminators like SwitchInt and Goto.
Expand Down
7 changes: 4 additions & 3 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -610,9 +610,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
| TerminatorKind::Unreachable
| TerminatorKind::CoroutineDrop => bug!("{term:?} has no terminators"),
// Disallowed during optimizations.
TerminatorKind::FalseEdge { .. }
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
TerminatorKind::FalseEdge { .. } | TerminatorKind::FalseUnwind { .. } => {
bug!("{term:?} invalid")
}
// Cannot reason about inline asm.
TerminatorKind::InlineAsm { .. } => return,
// `SwitchInt` is handled specially.
Expand All @@ -621,6 +621,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
TerminatorKind::Goto { .. } => None,
// Flood the overwritten place, and progress through.
TerminatorKind::Drop { place: destination, .. }
| TerminatorKind::Yield { resume_arg: destination, .. }
| TerminatorKind::Call { destination, .. } => Some(destination),
// Ignore, as this can be a no-op at codegen time.
TerminatorKind::Assert { .. } => None,
Expand Down
10 changes: 6 additions & 4 deletions compiler/rustc_mir_transform/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ mod ssa;
macro_rules! declare_passes {
(
$(
$vis:vis mod $mod_name:ident : $($pass_name:ident $( { $($ident:ident),* } )?),+ $(,)?;
$vis:vis mod $mod_name:ident : $($pass_name:ident $( { $($ident:ident),* $(,)? } )?),+ $(,)?;
)*
) => {
$(
Expand Down Expand Up @@ -181,12 +181,14 @@ declare_passes! {
PreOptimizations,
Final,
MakeShim,
AfterUnreachableEnumBranching
AfterUnreachableEnumBranching,
PostStateTransform,
},
SimplifyLocals {
BeforeConstProp,
AfterGVN,
Final
Final,
PostStateTransform,
};
mod simplify_branches : SimplifyConstCondition {
AfterConstProp,
Expand Down Expand Up @@ -626,7 +628,6 @@ fn run_runtime_lowering_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
&add_retag::AddRetag,
&erase_deref_temps::EraseDerefTemps,
&elaborate_box_derefs::ElaborateBoxDerefs,
&coroutine::StateTransform,
&Lint(known_panics_lint::KnownPanicsLint),
];
pm::run_passes_no_validate(tcx, body, passes, Some(MirPhase::Runtime(RuntimePhase::Initial)));
Expand Down Expand Up @@ -732,6 +733,7 @@ pub(crate) fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'
&simplify::SimplifyLocals::Final,
&multiple_return_terminators::MultipleReturnTerminators,
&large_enums::EnumSizeOpt { discrepancy: 128 },
&coroutine::StateTransform,
// Some cleanup necessary at least for LLVM and potentially other codegen backends.
&add_call_guards::CriticalCallEdges,
// Cleanup for human readability, off by default.
Expand Down
1 change: 0 additions & 1 deletion compiler/rustc_mir_transform/src/shim.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ fn make_shim<'tcx>(tcx: TyCtxt<'tcx>, instance: ty::InstanceKind<'tcx>) -> Body<
tcx,
&mut body,
&[
&mentioned_items::MentionedItems,
&abort_unwinding_calls::AbortUnwindingCalls,
&add_call_guards::CriticalCallEdges,
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ fn build_adrop_for_coroutine_shim<'tcx>(
body.source.instance = instance;
body.phase = MirPhase::Runtime(RuntimePhase::Initial);
body.var_debug_info.clear();
body.mentioned_items = None;
let pin_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Pin, span));
let args = tcx.mk_args(&[proxy_ref.into()]);
let pin_proxy_ref = Ty::new_adt(tcx, pin_adt_ref, args);
Expand Down
6 changes: 6 additions & 0 deletions compiler/rustc_mir_transform/src/simplify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ pub(super) enum SimplifyCfg {
Final,
MakeShim,
AfterUnreachableEnumBranching,
/// Extra run introduced by `StateTransform`.
PostStateTransform,
}

impl SimplifyCfg {
Expand All @@ -72,6 +74,7 @@ impl SimplifyCfg {
SimplifyCfg::AfterUnreachableEnumBranching => {
"SimplifyCfg-after-unreachable-enum-branching"
}
SimplifyCfg::PostStateTransform => "SimplifyCfg-post-StateTransform",
}
}
}
Expand Down Expand Up @@ -416,6 +419,8 @@ pub(super) enum SimplifyLocals {
BeforeConstProp,
AfterGVN,
Final,
/// Extra run introduced by `StateTransform`.
PostStateTransform,
}

impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
Expand All @@ -424,6 +429,7 @@ impl<'tcx> crate::MirPass<'tcx> for SimplifyLocals {
SimplifyLocals::BeforeConstProp => "SimplifyLocals-before-const-prop",
SimplifyLocals::AfterGVN => "SimplifyLocals-after-value-numbering",
SimplifyLocals::Final => "SimplifyLocals-final",
SimplifyLocals::PostStateTransform => "SimplifyLocals-post-StateTransform",
}
}

Expand Down
4 changes: 2 additions & 2 deletions compiler/rustc_mir_transform/src/validate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,7 @@ impl<'a, 'tcx> Visitor<'tcx> for CfgChecker<'a, 'tcx> {
if self.body.coroutine.is_none() {
self.fail(location, "`Yield` cannot appear outside coroutine bodies");
}
if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) {
if self.body.phase >= MirPhase::Runtime(RuntimePhase::Optimized) {
self.fail(location, "`Yield` should have been replaced by coroutine lowering");
}
self.check_edge(location, *resume, EdgeKind::Normal);
Expand Down Expand Up @@ -489,7 +489,7 @@ impl<'a, 'tcx> Visitor<'tcx> for CfgChecker<'a, 'tcx> {
if self.body.coroutine.is_none() {
self.fail(location, "`CoroutineDrop` cannot appear outside coroutine bodies");
}
if self.body.phase >= MirPhase::Runtime(RuntimePhase::Initial) {
if self.body.phase >= MirPhase::Runtime(RuntimePhase::Optimized) {
self.fail(
location,
"`CoroutineDrop` should have been replaced by coroutine lowering",
Expand Down
22 changes: 0 additions & 22 deletions tests/crashes/140303.rs

This file was deleted.

Loading
Loading