Skip to content

Commit fe9ea35

Browse files
committed
upstream rustc_codegen_ssa/rustc_middle changes for enzyme/autodiff
1 parent c1db4dc commit fe9ea35

File tree

29 files changed

+531
-40
lines changed

29 files changed

+531
-40
lines changed

Cargo.lock

+2
Original file line numberDiff line numberDiff line change
@@ -4121,6 +4121,7 @@ dependencies = [
41214121
name = "rustc_monomorphize"
41224122
version = "0.0.0"
41234123
dependencies = [
4124+
"rustc_ast",
41244125
"rustc_data_structures",
41254126
"rustc_errors",
41264127
"rustc_fluent_macro",
@@ -4129,6 +4130,7 @@ dependencies = [
41294130
"rustc_middle",
41304131
"rustc_session",
41314132
"rustc_span",
4133+
"rustc_symbol_mangling",
41324134
"rustc_target",
41334135
"serde",
41344136
"serde_json",

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+2-11
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
use std::fmt::{self, Display, Formatter};
77
use std::str::FromStr;
88

9-
use crate::expand::typetree::TypeTree;
109
use crate::expand::{Decodable, Encodable, HashStable_Generic};
1110
use crate::ptr::P;
1211
use crate::{Ty, TyKind};
@@ -79,10 +78,6 @@ pub struct AutoDiffItem {
7978
/// The name of the function being generated
8079
pub target: String,
8180
pub attrs: AutoDiffAttrs,
82-
/// Describe the memory layout of input types
83-
pub inputs: Vec<TypeTree>,
84-
/// Describe the memory layout of the output type
85-
pub output: TypeTree,
8681
}
8782
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8883
pub struct AutoDiffAttrs {
@@ -266,18 +261,14 @@ impl AutoDiffAttrs {
266261
self,
267262
source: String,
268263
target: String,
269-
inputs: Vec<TypeTree>,
270-
output: TypeTree,
271264
) -> AutoDiffItem {
272-
AutoDiffItem { source, target, inputs, output, attrs: self }
265+
AutoDiffItem { source, target, attrs: self }
273266
}
274267
}
275268

276269
impl fmt::Display for AutoDiffItem {
277270
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
278271
write!(f, "Differentiating {} -> {}", self.source, self.target)?;
279-
write!(f, " with attributes: {:?}", self.attrs)?;
280-
write!(f, " with inputs: {:?}", self.inputs)?;
281-
write!(f, " with output: {:?}", self.output)
272+
write!(f, " with attributes: {:?}", self.attrs)
282273
}
283274
}

compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,7 @@ struct UsageSets<'tcx> {
427427
/// Prepare sets of definitions that are relevant to deciding whether something
428428
/// is an "unused function" for coverage purposes.
429429
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
430-
let (all_mono_items, cgus) = tcx.collect_and_partition_mono_items(());
430+
let (all_mono_items, _, cgus) = tcx.collect_and_partition_mono_items(());
431431

432432
// Obtain a MIR body for each function participating in codegen, via an
433433
// arbitrary instance.

compiler/rustc_codegen_ssa/messages.ftl

+3
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,6 @@ codegen_ssa_use_cargo_directive = use the `cargo:rustc-link-lib` directive to sp
351351
codegen_ssa_version_script_write_failure = failed to write version script: {$error}
352352
353353
codegen_ssa_visual_studio_not_installed = you may need to install Visual Studio build tools with the "C++ build tools" workload
354+
355+
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
356+

compiler/rustc_codegen_ssa/src/assert_module_sources.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub fn assert_module_sources(tcx: TyCtxt<'_>, set_reuse: &dyn Fn(&mut CguReuseTr
4848
}
4949

5050
let available_cgus =
51-
tcx.collect_and_partition_mono_items(()).1.iter().map(|cgu| cgu.name()).collect();
51+
tcx.collect_and_partition_mono_items(()).2.iter().map(|cgu| cgu.name()).collect();
5252

5353
let mut ams = AssertModuleSource {
5454
tcx,

compiler/rustc_codegen_ssa/src/back/symbol_export.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,7 @@ fn exported_symbols_provider_local(
293293
// external linkage is enough for monomorphization to be linked to.
294294
let need_visibility = tcx.sess.target.dynamic_linking && !tcx.sess.target.only_cdylib;
295295

296-
let (_, cgus) = tcx.collect_and_partition_mono_items(());
296+
let (_, _, cgus) = tcx.collect_and_partition_mono_items(());
297297

298298
// The symbols created in this loop are sorted below it
299299
#[allow(rustc::potential_query_instability)]

compiler/rustc_codegen_ssa/src/back/write.rs

+33-4
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::{fs, io, mem, str, thread};
88

99
use jobserver::{Acquired, Client};
1010
use rustc_ast::attr;
11+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1112
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
1213
use rustc_data_structures::memmap::Mmap;
1314
use rustc_data_structures::profiling::{SelfProfilerRef, VerboseTimingGuard};
@@ -41,7 +42,7 @@ use tracing::debug;
4142
use super::link::{self, ensure_removed};
4243
use super::lto::{self, SerializedModule};
4344
use super::symbol_export::symbol_name_for_instance_in_crate;
44-
use crate::errors::ErrorCreatingRemarkDir;
45+
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
4546
use crate::traits::*;
4647
use crate::{
4748
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@@ -120,6 +121,7 @@ pub struct ModuleConfig {
120121
pub merge_functions: bool,
121122
pub emit_lifetime_markers: bool,
122123
pub llvm_plugins: Vec<String>,
124+
pub autodiff: Vec<config::AutoDiff>,
123125
}
124126

125127
impl ModuleConfig {
@@ -280,6 +282,7 @@ impl ModuleConfig {
280282

281283
emit_lifetime_markers: sess.emit_lifetime_markers(),
282284
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
285+
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
283286
}
284287
}
285288

@@ -401,6 +404,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
401404

402405
fn generate_lto_work<B: ExtraBackendMethods>(
403406
cgcx: &CodegenContext<B>,
407+
autodiff: Vec<AutoDiffItem>,
404408
needs_fat_lto: Vec<FatLtoInput<B>>,
405409
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
406410
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@@ -411,9 +415,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
411415
assert!(needs_thin_lto.is_empty());
412416
let module =
413417
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
418+
if cgcx.lto == Lto::Fat {
419+
let _config = cgcx.config(ModuleKind::Regular);
420+
todo!("fat LTO with autodiff is not yet implemented");
421+
//module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
422+
}
414423
// We are adding a single work item, so the cost doesn't matter.
415424
vec![(WorkItem::LTO(module), 0)]
416425
} else {
426+
if !autodiff.is_empty() {
427+
let dcx = cgcx.create_dcx();
428+
dcx.handle().emit_fatal(AutodiffWithoutLto {});
429+
}
417430
assert!(needs_fat_lto.is_empty());
418431
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
419432
.unwrap_or_else(|e| e.raise());
@@ -1041,6 +1054,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10411054
/// Sent from a backend worker thread.
10421055
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
10431056

1057+
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1058+
AddAutoDiffItems(Vec<AutoDiffItem>),
1059+
10441060
/// The frontend has finished generating something (backend IR or a
10451061
/// post-LTO artifact) for a codegen unit, and it should be passed to the
10461062
/// backend. Sent from the main thread.
@@ -1367,6 +1383,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13671383

13681384
// This is where we collect codegen units that have gone all the way
13691385
// through codegen and LLVM.
1386+
let mut autodiff_items = Vec::new();
13701387
let mut compiled_modules = vec![];
13711388
let mut compiled_allocator_module = None;
13721389
let mut needs_link = Vec::new();
@@ -1478,9 +1495,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14781495
let needs_thin_lto = mem::take(&mut needs_thin_lto);
14791496
let import_only_modules = mem::take(&mut lto_import_only_modules);
14801497

1481-
for (work, cost) in
1482-
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1483-
{
1498+
for (work, cost) in generate_lto_work(
1499+
&cgcx,
1500+
autodiff_items.clone(),
1501+
needs_fat_lto,
1502+
needs_thin_lto,
1503+
import_only_modules,
1504+
) {
14841505
let insertion_index = work_items
14851506
.binary_search_by_key(&cost, |&(_, cost)| cost)
14861507
.unwrap_or_else(|e| e);
@@ -1615,6 +1636,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
16151636
main_thread_state = MainThreadState::Idle;
16161637
}
16171638

1639+
Message::AddAutoDiffItems(mut items) => {
1640+
autodiff_items.append(&mut items);
1641+
}
1642+
16181643
Message::CodegenComplete => {
16191644
if codegen_state != Aborted {
16201645
codegen_state = Completed;
@@ -2092,6 +2117,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20922117
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
20932118
}
20942119

2120+
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
2121+
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
2122+
}
2123+
20952124
pub(crate) fn check_for_errors(&self, sess: &Session) {
20962125
self.shared_emitter_main.check(sess, false);
20972126
}

compiler/rustc_codegen_ssa/src/base.rs

+7-2
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,8 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
621621

622622
// Run the monomorphization collector and partition the collected items into
623623
// codegen units.
624-
let codegen_units = tcx.collect_and_partition_mono_items(()).1;
624+
let (_, autodiff_fncs, codegen_units) = tcx.collect_and_partition_mono_items(());
625+
let autodiff_fncs = autodiff_fncs.to_vec();
625626

626627
// Force all codegen_unit queries so they are already either red or green
627628
// when compile_codegen_unit accesses them. We are not able to re-execute
@@ -692,6 +693,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
692693
);
693694
}
694695

696+
if !autodiff_fncs.is_empty() {
697+
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
698+
}
699+
695700
// For better throughput during parallel processing by LLVM, we used to sort
696701
// CGUs largest to smallest. This would lead to better thread utilization
697702
// by, for example, preventing a large CGU from being processed last and
@@ -1051,7 +1056,7 @@ pub(crate) fn provide(providers: &mut Providers) {
10511056
config::OptLevel::SizeMin => config::OptLevel::Default,
10521057
};
10531058

1054-
let (defids, _) = tcx.collect_and_partition_mono_items(cratenum);
1059+
let (defids, _, _) = tcx.collect_and_partition_mono_items(cratenum);
10551060

10561061
let any_for_speed = defids.items().any(|id| {
10571062
let CodegenFnAttrs { optimize, .. } = tcx.codegen_fn_attrs(*id);

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+134-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
use rustc_ast::{MetaItemInner, MetaItemKind, ast, attr};
1+
use std::str::FromStr;
2+
3+
use rustc_ast::expand::autodiff_attrs::{
4+
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
5+
};
6+
use rustc_ast::{MetaItem, MetaItemInner, MetaItemKind, ast, attr};
27
use rustc_attr::{InlineAttr, InstructionSetAttr, OptimizeAttr, list_contains_name};
38
use rustc_data_structures::fx::FxHashMap;
49
use rustc_errors::codes::*;
@@ -779,6 +784,133 @@ fn check_link_name_xor_ordinal(
779784
}
780785
}
781786

787+
/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
788+
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
789+
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
790+
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
791+
/// panic, unless we introduced a bug when parsing the autodiff macro.
792+
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
793+
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
794+
795+
let attrs =
796+
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
797+
798+
// check for exactly one autodiff attribute on placeholder functions.
799+
// There should only be one, since we generate a new placeholder per ad macro.
800+
// TODO: re-enable this. We should fix that rustc_autodiff isn't applied multiple times to the
801+
// source function.
802+
let msg_once = "cg_ssa: implementation bug. Autodiff attribute can only be applied once";
803+
let attr = match attrs.len() {
804+
0 => return AutoDiffAttrs::error(),
805+
1 => attrs.get(0).unwrap(),
806+
_ => {
807+
attrs.get(0).unwrap()
808+
//tcx.dcx().struct_span_err(attrs[1].span, msg_once).with_note("more than one").emit();
809+
//return AutoDiffAttrs::error();
810+
}
811+
};
812+
813+
let list = attr.meta_item_list().unwrap_or_default();
814+
815+
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
816+
if list.len() == 0 {
817+
return AutoDiffAttrs::source();
818+
}
819+
820+
let [mode, input_activities @ .., ret_activity] = &list[..] else {
821+
tcx.dcx()
822+
.struct_span_err(attr.span, msg_once)
823+
.with_note("Implementation bug in autodiff_attrs. Please report this!")
824+
.emit();
825+
return AutoDiffAttrs::error();
826+
};
827+
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
828+
p1.segments.first().unwrap().ident
829+
} else {
830+
let msg = "autodiff attribute must contain autodiff mode";
831+
tcx.dcx().struct_span_err(attr.span, msg).with_note("empty argument list").emit();
832+
return AutoDiffAttrs::error();
833+
};
834+
835+
// parse mode
836+
let msg_mode = "mode should be either forward or reverse";
837+
let mode = match mode.as_str() {
838+
"Forward" => DiffMode::Forward,
839+
"Reverse" => DiffMode::Reverse,
840+
"ForwardFirst" => DiffMode::ForwardFirst,
841+
"ReverseFirst" => DiffMode::ReverseFirst,
842+
_ => {
843+
tcx.dcx().struct_span_err(attr.span, msg_mode).with_note("invalid mode").emit();
844+
return AutoDiffAttrs::error();
845+
}
846+
};
847+
848+
// First read the ret symbol from the attribute
849+
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
850+
p1.segments.first().unwrap().ident
851+
} else {
852+
let msg = "autodiff attribute must contain the return activity";
853+
tcx.dcx().struct_span_err(attr.span, msg).with_note("missing return activity").emit();
854+
return AutoDiffAttrs::error();
855+
};
856+
857+
// Then parse it into an actual DiffActivity
858+
let msg_unknown_ret_activity = "unknown return activity";
859+
let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) {
860+
Ok(x) => x,
861+
Err(_) => {
862+
tcx.dcx()
863+
.struct_span_err(attr.span, msg_unknown_ret_activity)
864+
.with_note("invalid return activity")
865+
.emit();
866+
return AutoDiffAttrs::error();
867+
}
868+
};
869+
870+
// Now parse all the intermediate (input) activities
871+
let msg_arg_activity = "autodiff attribute must contain the return activity";
872+
let mut arg_activities: Vec<DiffActivity> = vec![];
873+
for arg in input_activities {
874+
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
875+
p2.segments.first().unwrap().ident
876+
} else {
877+
tcx.dcx()
878+
.struct_span_err(attr.span, msg_arg_activity)
879+
.with_note("Implementation bug, please report this!")
880+
.emit();
881+
return AutoDiffAttrs::error();
882+
};
883+
884+
match DiffActivity::from_str(arg_symbol.as_str()) {
885+
Ok(arg_activity) => arg_activities.push(arg_activity),
886+
Err(_) => {
887+
tcx.dcx()
888+
.struct_span_err(attr.span, msg_unknown_ret_activity)
889+
.with_note("invalid input activity")
890+
.emit();
891+
return AutoDiffAttrs::error();
892+
}
893+
}
894+
}
895+
896+
let mut msg = "".to_string();
897+
for &input in &arg_activities {
898+
if !valid_input_activity(mode, input) {
899+
msg = format!("Invalid input activity {} for {} mode", input, mode);
900+
}
901+
}
902+
if !valid_ret_activity(mode, ret_activity) {
903+
msg = format!("Invalid return activity {} for {} mode", ret_activity, mode);
904+
}
905+
if msg != "".to_string() {
906+
tcx.dcx().struct_span_err(attr.span, msg).with_note("invalid activity").emit();
907+
return AutoDiffAttrs::error();
908+
}
909+
910+
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
911+
}
912+
782913
pub(crate) fn provide(providers: &mut Providers) {
783-
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
914+
*providers =
915+
Providers { codegen_fn_attrs, should_inherit_track_caller, autodiff_attrs, ..*providers };
784916
}

compiler/rustc_codegen_ssa/src/errors.rs

+4
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ pub(crate) struct CguNotRecorded<'a> {
3737
pub cgu_name: &'a str,
3838
}
3939

40+
#[derive(Diagnostic)]
41+
#[diag(codegen_ssa_autodiff_without_lto)]
42+
pub struct AutodiffWithoutLto;
43+
4044
#[derive(Diagnostic)]
4145
#[diag(codegen_ssa_unknown_reuse_kind)]
4246
pub(crate) struct UnknownReuseKind {

0 commit comments

Comments
 (0)