Skip to content

Commit c4af0ba

Browse files
committed
upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
1 parent 814ebca commit c4af0ba

File tree

24 files changed

+443
-25
lines changed

24 files changed

+443
-25
lines changed

Cargo.lock

+2
Original file line numberDiff line numberDiff line change
@@ -4233,6 +4233,7 @@ name = "rustc_monomorphize"
42334233
version = "0.0.0"
42344234
dependencies = [
42354235
"rustc_abi",
4236+
"rustc_ast",
42364237
"rustc_attr_parsing",
42374238
"rustc_data_structures",
42384239
"rustc_errors",
@@ -4242,6 +4243,7 @@ dependencies = [
42424243
"rustc_middle",
42434244
"rustc_session",
42444245
"rustc_span",
4246+
"rustc_symbol_mangling",
42454247
"rustc_target",
42464248
"serde",
42474249
"serde_json",

compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ struct UsageSets<'tcx> {
297297
/// Prepare sets of definitions that are relevant to deciding whether something
298298
/// is an "unused function" for coverage purposes.
299299
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
300-
let (all_mono_items, cgus) = tcx.collect_and_partition_mono_items(());
300+
let (all_mono_items, _, cgus) = tcx.collect_and_partition_mono_items(());
301301

302302
// Obtain a MIR body for each function participating in codegen, via an
303303
// arbitrary instance.

compiler/rustc_codegen_ssa/messages.ftl

+3
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,6 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp
367367
codegen_ssa_version_script_write_failure = failed to write version script: {$error}
368368
369369
codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload
370+
371+
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
372+

compiler/rustc_codegen_ssa/src/assert_module_sources.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr
4747
}
4848

4949
let available_cgus =
50-
tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect();
50+
tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect();
5151

5252
let mut ams = AssertModuleSource {
5353
tcx,

compiler/rustc_codegen_ssa/src/back/symbol_export.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ fn exported_symbols_provider_local(
293293
// external linkage is enough for monomorphization to be linked to.
294294
let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib;
295295

296-
let (_, cgus) = tcx.collect_and_partition_mono_items(());
296+
let (_, _, cgus) = tcx.collect_and_partition_mono_items(());
297297

298298
// The symbols created in this loop are sorted below it
299299
#[allow(rustc::potential_query_instability)]

compiler/rustc_codegen_ssa/src/back/write.rs

+33-4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
77
use std::{fs, io, mem, str, thread};
88

99
use rustc_ast::attr;
10+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1011
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
1112
use rustc_data_structures::jobserver::{self, Acquired};
1213
use rustc_data_structures::memmap::Mmap;
@@ -40,7 +41,7 @@ use tracing::debug;
4041
use super::link::{self, ensure_removed};
4142
use super::lto::{self, SerializedModule};
4243
use super::symbol_export::symbol_name_for_instance_in_crate;
43-
use crate::errors::ErrorCreatingRemarkDir;
44+
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
4445
use crate::traits::*;
4546
use crate::{
4647
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118119
pub merge_functions: bool,
119120
pub emit_lifetime_markers: bool,
120121
pub llvm_plugins: Vec<String>,
122+
pub autodiff: Vec<config::AutoDiff>,
121123
}
122124

123125
impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
266268

267269
emit_lifetime_markers: sess.emit_lifetime_markers(),
268270
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
271+
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
269272
}
270273
}
271274

@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389392

390393
fn generate_lto_work<B: ExtraBackendMethods>(
391394
cgcx: &CodegenContext<B>,
395+
autodiff: Vec<AutoDiffItem>,
392396
needs_fat_lto: Vec<FatLtoInput<B>>,
393397
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
394398
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@@ -399,9 +403,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
399403
assert!(needs_thin_lto.is_empty());
400404
let module =
401405
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
406+
if cgcx.lto == Lto::Fat {
407+
let _config = cgcx.config(ModuleKind::Regular);
408+
todo!("fat LTO with autodiff is not yet implemented");
409+
//module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
410+
}
402411
// We are adding a single work item, so the cost doesn't matter.
403412
vec![(WorkItem::LTO(module), 0)]
404413
} else {
414+
if !autodiff.is_empty() {
415+
let dcx = cgcx.create_dcx();
416+
dcx.handle().emit_fatal(AutodiffWithoutLto {});
417+
}
405418
assert!(needs_fat_lto.is_empty());
406419
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
407420
.unwrap_or_else(|e| e.raise());
@@ -1021,6 +1034,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10211034
/// Sent from a backend worker thread.
10221035
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
10231036

1037+
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1038+
AddAutoDiffItems(Vec<AutoDiffItem>),
1039+
10241040
/// The frontend has finished generating something (backend IR or a
10251041
/// post-LTO artifact) for a codegen unit, and it should be passed to the
10261042
/// backend. Sent from the main thread.
@@ -1348,6 +1364,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13481364

13491365
// This is where we collect codegen units that have gone all the way
13501366
// through codegen and LLVM.
1367+
let mut autodiff_items = Vec::new();
13511368
let mut compiled_modules = vec![];
13521369
let mut compiled_allocator_module = None;
13531370
let mut needs_link = Vec::new();
@@ -1459,9 +1476,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14591476
let needs_thin_lto = mem::take(&mut needs_thin_lto);
14601477
let import_only_modules = mem::take(&mut lto_import_only_modules);
14611478

1462-
for (work, cost) in
1463-
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1464-
{
1479+
for (work, cost) in generate_lto_work(
1480+
&cgcx,
1481+
autodiff_items.clone(),
1482+
needs_fat_lto,
1483+
needs_thin_lto,
1484+
import_only_modules,
1485+
) {
14651486
let insertion_index = work_items
14661487
.binary_search_by_key(&cost, |&(_, cost)| cost)
14671488
.unwrap_or_else(|e| e);
@@ -1596,6 +1617,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
15961617
main_thread_state = MainThreadState::Idle;
15971618
}
15981619

1620+
Message::AddAutoDiffItems(mut items) => {
1621+
autodiff_items.append(&mut items);
1622+
}
1623+
15991624
Message::CodegenComplete => {
16001625
if codegen_state != Aborted {
16011626
codegen_state = Completed;
@@ -2070,6 +2095,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20702095
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
20712096
}
20722097

2098+
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
2099+
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
2100+
}
2101+
20732102
pub(crate) fn check_for_errors(&self, sess: &Session) {
20742103
self.shared_emitter_main.check(sess, false);
20752104
}

compiler/rustc_codegen_ssa/src/base.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,8 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
618618

619619
// Run the monomorphization collector and partition the collected items into
620620
// codegen units.
621-
let codegen_units = tcx.collect_and_partition_mono_items(()).1;
621+
let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(());
622+
let autodiff_fncs = autodiff_fncs.to_vec();
622623

623624
// Force all codegen_unit queries so they are already either red or green
624625
// when compile_codegen_unit accesses them. We are not able to re-execute
@@ -689,6 +690,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
689690
);
690691
}
691692

693+
if !autodiff_fncs.is_empty() {
694+
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
695+
}
696+
692697
// For better throughput during parallel processing by LLVM, we used to sort
693698
// CGUs largest to smallest. This would lead to better thread utilization
694699
// by, for example, preventing a large CGU from being processed last and
@@ -1049,7 +1054,7 @@ pub(crate) fn provide(providers: &mut Providers) {
10491054
config::OptLevel::SizeMin => config::OptLevel::Default,
10501055
};
10511056

1052-
let (defids, _) = tcx.collect_and_partition_mono_items(cratenum);
1057+
let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum);
10531058

10541059
let any_for_speed = defids.items().any(|id| {
10551060
let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+134-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
use std::str::FromStr;
2+
13
use rustc_ast::attr::list_contains_name;
2-
use rustc_ast::{MetaItemInner, attr};
4+
use rustc_ast::expand::autodiff_attrs::{
5+
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
6+
};
7+
use rustc_ast::{MetaItem, MetaItemInner, attr};
38
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
49
use rustc_data_structures::fx::FxHashMap;
510
use rustc_errors::codes::*;
@@ -854,6 +859,133 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
854859
}
855860
}
856861

862+
/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
863+
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
864+
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
865+
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
866+
/// panic, unless we introduced a bug when parsing the autodiff macro.
867+
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
868+
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
869+
870+
let attrs =
871+
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
872+
873+
// check for exactly one autodiff attribute on placeholder functions.
874+
// There should only be one, since we generate a new placeholder per ad macro.
875+
// TODO: re-enable this. We should fix that rustc_autodiff isn't applied multiple times to the
876+
// source function.
877+
let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once";
878+
let attr = match attrs.len() {
879+
0 => return AutoDiffAttrs::error(),
880+
1 => attrs.get(0).unwrap(),
881+
_ => {
882+
attrs.get(0).unwrap()
883+
//tcx.dcx().struct_span_err(attrs[1].span, msg_once).with_note("more than one").emit();
884+
//return AutoDiffAttrs::error();
885+
}
886+
};
887+
888+
let list = attr.meta_item_list().unwrap_or_default();
889+
890+
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
891+
if list.len() == 0 {
892+
return AutoDiffAttrs::source();
893+
}
894+
895+
let [mode, input_activities @ .., ret_activity] = &list[..] else {
896+
tcx.dcx()
897+
.struct_span_err(attr.span, msg_once)
898+
.with_note("Implementation bug in autodiff_attrs. Please report this!")
899+
.emit();
900+
return AutoDiffAttrs::error();
901+
};
902+
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
903+
p1.segments.first().unwrap().ident
904+
} else {
905+
let msg = "autodiff attribute must contain autodiff mode";
906+
tcx.dcx().struct_span_err(attr.span, msg).with_note("empty argument list").emit();
907+
return AutoDiffAttrs::error();
908+
};
909+
910+
// parse mode
911+
let msg_mode = "mode should be either forward or reverse";
912+
let mode = match mode.as_str() {
913+
"Forward" => DiffMode::Forward,
914+
"Reverse" => DiffMode::Reverse,
915+
"ForwardFirst" => DiffMode::ForwardFirst,
916+
"ReverseFirst" => DiffMode::ReverseFirst,
917+
_ => {
918+
tcx.dcx().struct_span_err(attr.span, msg_mode).with_note("invalid mode").emit();
919+
return AutoDiffAttrs::error();
920+
}
921+
};
922+
923+
// First read the ret symbol from the attribute
924+
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
925+
p1.segments.first().unwrap().ident
926+
} else {
927+
let msg = "autodiff attribute must contain the return activity";
928+
tcx.dcx().struct_span_err(attr.span, msg).with_note("missing return activity").emit();
929+
return AutoDiffAttrs::error();
930+
};
931+
932+
// Then parse it into an actual DiffActivity
933+
let msg_unknown_ret_activity = "unknown return activity";
934+
let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) {
935+
Ok(x) => x,
936+
Err(_) => {
937+
tcx.dcx()
938+
.struct_span_err(attr.span, msg_unknown_ret_activity)
939+
.with_note("invalid return activity")
940+
.emit();
941+
return AutoDiffAttrs::error();
942+
}
943+
};
944+
945+
// Now parse all the intermediate (input) activities
946+
let msg_arg_activity = "autodiff attribute must contain the return activity";
947+
let mut arg_activities: Vec<DiffActivity> = vec![];
948+
for arg in input_activities {
949+
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
950+
p2.segments.first().unwrap().ident
951+
} else {
952+
tcx.dcx()
953+
.struct_span_err(attr.span, msg_arg_activity)
954+
.with_note("Implementation bug, please report this!")
955+
.emit();
956+
return AutoDiffAttrs::error();
957+
};
958+
959+
match DiffActivity::from_str(arg_symbol.as_str()) {
960+
Ok(arg_activity) => arg_activities.push(arg_activity),
961+
Err(_) => {
962+
tcx.dcx()
963+
.struct_span_err(attr.span, msg_unknown_ret_activity)
964+
.with_note("invalid input activity")
965+
.emit();
966+
return AutoDiffAttrs::error();
967+
}
968+
}
969+
}
970+
971+
let mut msg = "".to_string();
972+
for &input in &arg_activities {
973+
if !valid_input_activity(mode, input) {
974+
msg = format!("Invalid input activity {} for {} mode", input, mode);
975+
}
976+
}
977+
if !valid_ret_activity(mode, ret_activity) {
978+
msg = format!("Invalid return activity {} for {} mode", ret_activity, mode);
979+
}
980+
if msg != "".to_string() {
981+
tcx.dcx().struct_span_err(attr.span, msg).with_note("invalid activity").emit();
982+
return AutoDiffAttrs::error();
983+
}
984+
985+
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
986+
}
987+
857988
pub(crate) fn provide(providers: &mut Providers) {
858-
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
989+
*providers =
990+
Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers };
859991
}

compiler/rustc_codegen_ssa/src/errors.rs

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ pub(crate) struct CguNotRecorded<'a> {
3939
pub cgu_name: &'a str,
4040
}
4141

42+
#[derive(Diagnostic)]
43+
#[diag(codegen_ssa_autodiff_without_lto)]
44+
pub struct AutodiffWithoutLto;
45+
4246
#[derive(Diagnostic)]
4347
#[diag(codegen_ssa_unknown_reuse_kind)]
4448
pub(crate) struct UnknownReuseKind {

compiler/rustc_codegen_ssa/src/traits/write.rs

+2
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
1313
type ModuleBuffer: ModuleBufferMethods;
1414
type ThinData: Send + Sync;
1515
type ThinBuffer: ThinBufferMethods;
16+
//type TypeTree: Clone;
1617

1718
/// Merge all modules into main_module and returning it
1819
fn run_link(
@@ -37,6 +38,7 @@ pub trait WriteBackendMethods: 'static + Sized + Clone {
3738
) -> Result<(Vec<LtoModuleCodegen<Self>>, Vec<WorkProduct>), FatalError>;
3839
fn print_pass_timings(&self);
3940
fn print_statistics(&self);
41+
// does enzyme prep work, should do ad too.
4042
unsafe fn optimize(
4143
cgcx: &CodegenContext<Self>,
4244
dcx: DiagCtxtHandle<'_>,

compiler/rustc_interface/src/tests.rs

+1
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,7 @@ fn test_unstable_options_tracking_hash() {
760760
tracked!(allow_features, Some(vec![String::from("lang_items")]));
761761
tracked!(always_encode_mir, true);
762762
tracked!(assume_incomplete_release, true);
763+
tracked!(autodiff, vec![String::from("ad_flags")]);
763764
tracked!(binary_dep_depinfo, true);
764765
tracked!(box_noalias, false);
765766
tracked!(

compiler/rustc_middle/messages.ftl

+4
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1+
middle_autodiff_unsafe_inner_const_ref = reading from a `Duplicated` const {$ty} is unsafe
2+
3+
middle_unsupported_union = we don't support unions yet: '{$ty_name}'
4+
15
middle_adjust_for_foreign_abi_error =
26
target architecture {$arch} does not support `extern {$abi}` ABI
37

compiler/rustc_middle/src/arena.rs

+1
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ macro_rules! arena_types {
8787
[] codegen_unit: rustc_middle::mir::mono::CodegenUnit<'tcx>,
8888
[decode] attribute: rustc_hir::Attribute,
8989
[] name_set: rustc_data_structures::unord::UnordSet<rustc_span::Symbol>,
90+
[] autodiff_item: rustc_ast::expand::autodiff_attrs::AutoDiffItem,
9091
[] ordered_name_set: rustc_data_structures::fx::FxIndexSet<rustc_span::Symbol>,
9192
[] pats: rustc_middle::ty::PatternKind<'tcx>,
9293

0 commit comments

Comments
 (0)