Skip to content

Commit c19c4b9

Browse files
authored
Rollup merge of #133429 - EnzymeAD:autodiff-middle, r=oli-obk
Autodiff Upstreaming - rustc_codegen_ssa, rustc_middle This PR should not be merged until the rustc_codegen_llvm part is merged. I will also alter it a little based on what get's shaved off from the cg_llvm PR, and address some of the feedback I received in the other PR (including cleanups). I am putting it already up to 1) Discuss with `@jieyouxu` if there is more work needed to add tests to this and 2) Pray that there is someone reviewing who can tell me why some of my autodiff invocations get lost. Re 1: My test require fat-lto. I also modify the compilation pipeline. So if there are any other llvm-ir tests in the same compilation unit then I will likely break them. Luckily there are two groups who currently have the same fat-lto requirement for their GPU code which I have for my autodiff code and both groups have some plans to enable support for thin-lto. Once either that work pans out, I'll copy it over for this feature. I will also work on not changing the optimization pipeline for functions not differentiated, but that will require some thoughts and engineering, so I think it would be good to be able to run the autodiff tests isolated from the rest for now. Can you guide me here please? For context, here are some of my tests in the samples folder: https://github.com/EnzymeAD/rustbook Re 2: This is a pretty serious issue, since it effectively prevents publishing libraries making use of autodiff: EnzymeAD#173. For some reason my dummy code persists till the end, so the code which calls autodiff, deletes the dummy, and inserts the code to compute the derivative never gets executed. To me it looks like the rustc_autodiff attribute just get's dropped, but I don't know WHY? Any help would be super appreciated, as rustc queries look a bit voodoo to me. Tracking: - #124509 r? `@jieyouxu`
2 parents ae9dbf1 + 1f30517 commit c19c4b9

File tree

27 files changed

+481
-37
lines changed

27 files changed

+481
-37
lines changed

Cargo.lock

+2
Original file line numberDiff line numberDiff line change
@@ -4234,6 +4234,7 @@ name = "rustc_monomorphize"
42344234
version = "0.0.0"
42354235
dependencies = [
42364236
"rustc_abi",
4237+
"rustc_ast",
42374238
"rustc_attr_parsing",
42384239
"rustc_data_structures",
42394240
"rustc_errors",
@@ -4243,6 +4244,7 @@ dependencies = [
42434244
"rustc_middle",
42444245
"rustc_session",
42454246
"rustc_span",
4247+
"rustc_symbol_mangling",
42464248
"rustc_target",
42474249
"serde",
42484250
"serde_json",

compiler/rustc_ast/src/expand/autodiff_attrs.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ pub struct AutoDiffItem {
7979
pub target: String,
8080
pub attrs: AutoDiffAttrs,
8181
}
82+
8283
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
8384
pub struct AutoDiffAttrs {
8485
/// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
@@ -231,7 +232,7 @@ impl AutoDiffAttrs {
231232
self.ret_activity == DiffActivity::ActiveOnly
232233
}
233234

234-
pub fn error() -> Self {
235+
pub const fn error() -> Self {
235236
AutoDiffAttrs {
236237
mode: DiffMode::Error,
237238
ret_activity: DiffActivity::None,

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+7-7
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ fn generate_enzyme_call<'ll>(
6262
// add outer_fn name to ad_name to make it unique, in case users apply autodiff to multiple
6363
// functions. Unwrap will only panic, if LLVM gave us an invalid string.
6464
let name = llvm::get_value_name(outer_fn);
65-
let outer_fn_name = std::ffi::CStr::from_bytes_with_nul(name).unwrap().to_str().unwrap();
66-
ad_name.push_str(outer_fn_name.to_string().as_str());
65+
let outer_fn_name = std::str::from_utf8(name).unwrap();
66+
ad_name.push_str(outer_fn_name);
6767

6868
// Let us assume the user wrote the following function square:
6969
//
@@ -255,14 +255,14 @@ fn generate_enzyme_call<'ll>(
255255
// have no debug info to copy, which would then be ok.
256256
trace!("no dbg info");
257257
}
258+
258259
// Now that we copied the metadata, get rid of dummy code.
259-
llvm::LLVMRustEraseInstBefore(entry, last_inst);
260-
llvm::LLVMRustEraseInstFromParent(last_inst);
260+
llvm::LLVMRustEraseInstUntilInclusive(entry, last_inst);
261261

262-
if cx.val_ty(outer_fn) != cx.type_void() {
263-
builder.ret(call);
264-
} else {
262+
if cx.val_ty(call) == cx.type_void() {
265263
builder.ret_void();
264+
} else {
265+
builder.ret(call);
266266
}
267267

268268
// Let's crash in case that we messed something up above and generated invalid IR.

compiler/rustc_codegen_llvm/src/coverageinfo/mapgen.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ struct UsageSets<'tcx> {
298298
/// Prepare sets of definitions that are relevant to deciding whether something
299299
/// is an "unused function" for coverage purposes.
300300
fn prepare_usage_sets<'tcx>(tcx: TyCtxt<'tcx>) -> UsageSets<'tcx> {
301-
let MonoItemPartitions { all_mono_items, codegen_units } =
301+
let MonoItemPartitions { all_mono_items, codegen_units, .. } =
302302
tcx.collect_and_partition_mono_items(());
303303

304304
// Obtain a MIR body for each function participating in codegen, via an

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@ use crate::llvm::Bool;
77
extern "C" {
88
// Enzyme
99
pub fn LLVMRustHasMetadata(I: &Value, KindID: c_uint) -> bool;
10-
pub fn LLVMRustEraseInstBefore(BB: &BasicBlock, I: &Value);
10+
pub fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
1111
pub fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
1212
pub fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
1313
pub fn LLVMRustEraseInstFromParent(V: &Value);
1414
pub fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
15+
pub fn LLVMDumpModule(M: &Module);
16+
pub fn LLVMDumpValue(V: &Value);
1517
pub fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
1618

1719
pub fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;

compiler/rustc_codegen_ssa/messages.ftl

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ codegen_ssa_archive_build_failure = failed to build archive at `{$path}`: {$erro
1616
1717
codegen_ssa_atomic_compare_exchange = Atomic compare-exchange intrinsic missing failure memory ordering
1818
19+
codegen_ssa_autodiff_without_lto = using the autodiff feature requires using fat-lto
20+
1921
codegen_ssa_binary_output_to_tty = option `-o` or `--emit` is used to write binary output type `{$shorthand}` to stdout, but stdout is a tty
2022
2123
codegen_ssa_cgu_not_recorded =

compiler/rustc_codegen_ssa/src/back/write.rs

+33-5
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
77
use std::{fs, io, mem, str, thread};
88

99
use rustc_ast::attr;
10+
use rustc_ast::expand::autodiff_attrs::AutoDiffItem;
1011
use rustc_data_structures::fx::{FxHashMap, FxIndexMap};
1112
use rustc_data_structures::jobserver::{self, Acquired};
1213
use rustc_data_structures::memmap::Mmap;
@@ -40,7 +41,7 @@ use tracing::debug;
4041
use super::link::{self, ensure_removed};
4142
use super::lto::{self, SerializedModule};
4243
use super::symbol_export::symbol_name_for_instance_in_crate;
43-
use crate::errors::ErrorCreatingRemarkDir;
44+
use crate::errors::{AutodiffWithoutLto, ErrorCreatingRemarkDir};
4445
use crate::traits::*;
4546
use crate::{
4647
CachedModuleCodegen, CodegenResults, CompiledModule, CrateInfo, ModuleCodegen, ModuleKind,
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118119
pub merge_functions: bool,
119120
pub emit_lifetime_markers: bool,
120121
pub llvm_plugins: Vec<String>,
122+
pub autodiff: Vec<config::AutoDiff>,
121123
}
122124

123125
impl ModuleConfig {
@@ -266,6 +268,7 @@ impl ModuleConfig {
266268

267269
emit_lifetime_markers: sess.emit_lifetime_markers(),
268270
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
271+
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
269272
}
270273
}
271274

@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389392

390393
fn generate_lto_work<B: ExtraBackendMethods>(
391394
cgcx: &CodegenContext<B>,
395+
autodiff: Vec<AutoDiffItem>,
392396
needs_fat_lto: Vec<FatLtoInput<B>>,
393397
needs_thin_lto: Vec<(String, B::ThinBuffer)>,
394398
import_only_modules: Vec<(SerializedModule<B::ModuleBuffer>, WorkProduct)>,
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
397401

398402
if !needs_fat_lto.is_empty() {
399403
assert!(needs_thin_lto.is_empty());
400-
let module =
404+
let mut module =
401405
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
406+
if cgcx.lto == Lto::Fat {
407+
let config = cgcx.config(ModuleKind::Regular);
408+
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
409+
}
402410
// We are adding a single work item, so the cost doesn't matter.
403411
vec![(WorkItem::LTO(module), 0)]
404412
} else {
413+
if !autodiff.is_empty() {
414+
let dcx = cgcx.create_dcx();
415+
dcx.handle().emit_fatal(AutodiffWithoutLto {});
416+
}
405417
assert!(needs_fat_lto.is_empty());
406418
let (lto_modules, copy_jobs) = B::run_thin_lto(cgcx, needs_thin_lto, import_only_modules)
407419
.unwrap_or_else(|e| e.raise());
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10211033
/// Sent from a backend worker thread.
10221034
WorkItem { result: Result<WorkItemResult<B>, Option<WorkerFatalError>>, worker_id: usize },
10231035

1036+
/// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1037+
AddAutoDiffItems(Vec<AutoDiffItem>),
1038+
10241039
/// The frontend has finished generating something (backend IR or a
10251040
/// post-LTO artifact) for a codegen unit, and it should be passed to the
10261041
/// backend. Sent from the main thread.
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13481363

13491364
// This is where we collect codegen units that have gone all the way
13501365
// through codegen and LLVM.
1366+
let mut autodiff_items = Vec::new();
13511367
let mut compiled_modules = vec![];
13521368
let mut compiled_allocator_module = None;
13531369
let mut needs_link = Vec::new();
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14591475
let needs_thin_lto = mem::take(&mut needs_thin_lto);
14601476
let import_only_modules = mem::take(&mut lto_import_only_modules);
14611477

1462-
for (work, cost) in
1463-
generate_lto_work(&cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1464-
{
1478+
for (work, cost) in generate_lto_work(
1479+
&cgcx,
1480+
autodiff_items.clone(),
1481+
needs_fat_lto,
1482+
needs_thin_lto,
1483+
import_only_modules,
1484+
) {
14651485
let insertion_index = work_items
14661486
.binary_search_by_key(&cost, |&(_, cost)| cost)
14671487
.unwrap_or_else(|e| e);
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
15961616
main_thread_state = MainThreadState::Idle;
15971617
}
15981618

1619+
Message::AddAutoDiffItems(mut items) => {
1620+
autodiff_items.append(&mut items);
1621+
}
1622+
15991623
Message::CodegenComplete => {
16001624
if codegen_state != Aborted {
16011625
codegen_state = Completed;
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20702094
drop(self.coordinator.sender.send(Box::new(Message::CodegenComplete::<B>)));
20712095
}
20722096

2097+
pub(crate) fn submit_autodiff_items(&self, items: Vec<AutoDiffItem>) {
2098+
drop(self.coordinator.sender.send(Box::new(Message::<B>::AddAutoDiffItems(items))));
2099+
}
2100+
20732101
pub(crate) fn check_for_errors(&self, sess: &Session) {
20742102
self.shared_emitter_main.check(sess, false);
20752103
}

compiler/rustc_codegen_ssa/src/base.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ use rustc_middle::middle::debugger_visualizer::{DebuggerVisualizerFile, Debugger
1818
use rustc_middle::middle::exported_symbols::SymbolExportKind;
1919
use rustc_middle::middle::{exported_symbols, lang_items};
2020
use rustc_middle::mir::BinOp;
21-
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem};
21+
use rustc_middle::mir::mono::{CodegenUnit, CodegenUnitNameBuilder, MonoItem, MonoItemPartitions};
2222
use rustc_middle::query::Providers;
2323
use rustc_middle::ty::layout::{HasTyCtxt, HasTypingEnv, LayoutOf, TyAndLayout};
2424
use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
@@ -624,7 +624,9 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
624624

625625
// Run the monomorphization collector and partition the collected items into
626626
// codegen units.
627-
let codegen_units = tcx.collect_and_partition_mono_items(()).codegen_units;
627+
let MonoItemPartitions { codegen_units, autodiff_items, .. } =
628+
tcx.collect_and_partition_mono_items(());
629+
let autodiff_fncs = autodiff_items.to_vec();
628630

629631
// Force all codegen_unit queries so they are already either red or green
630632
// when compile_codegen_unit accesses them. We are not able to re-execute
@@ -695,6 +697,10 @@ pub fn codegen_crate<B: ExtraBackendMethods>(
695697
);
696698
}
697699

700+
if !autodiff_fncs.is_empty() {
701+
ongoing_codegen.submit_autodiff_items(autodiff_fncs);
702+
}
703+
698704
// For better throughput during parallel processing by LLVM, we used to sort
699705
// CGUs largest to smallest. This would lead to better thread utilization
700706
// by, for example, preventing a large CGU from being processed last and

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

+117-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
1+
use std::str::FromStr;
2+
13
use rustc_ast::attr::list_contains_name;
2-
use rustc_ast::{MetaItemInner, attr};
4+
use rustc_ast::expand::autodiff_attrs::{
5+
AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
6+
};
7+
use rustc_ast::{MetaItem, MetaItemInner, attr};
38
use rustc_attr_parsing::{InlineAttr, InstructionSetAttr, OptimizeAttr};
49
use rustc_data_structures::fx::FxHashMap;
510
use rustc_errors::codes::*;
@@ -13,6 +18,7 @@ use rustc_middle::middle::codegen_fn_attrs::{
1318
};
1419
use rustc_middle::mir::mono::Linkage;
1520
use rustc_middle::query::Providers;
21+
use rustc_middle::span_bug;
1622
use rustc_middle::ty::{self as ty, TyCtxt};
1723
use rustc_session::parse::feature_err;
1824
use rustc_session::{Session, lint};
@@ -65,6 +71,13 @@ fn codegen_fn_attrs(tcx: TyCtxt<'_>, did: LocalDefId) -> CodegenFnAttrs {
6571
codegen_fn_attrs.flags |= CodegenFnAttrFlags::TRACK_CALLER;
6672
}
6773

74+
// If our rustc version supports autodiff/enzyme, then we call our handler
75+
// to check for any `#[rustc_autodiff(...)]` attributes.
76+
if cfg!(llvm_enzyme) {
77+
let ad = autodiff_attrs(tcx, did.into());
78+
codegen_fn_attrs.autodiff_item = ad;
79+
}
80+
6881
// When `no_builtins` is applied at the crate level, we should add the
6982
// `no-builtins` attribute to each function to ensure it takes effect in LTO.
7083
let crate_attrs = tcx.hir().attrs(rustc_hir::CRATE_HIR_ID);
@@ -856,6 +869,109 @@ impl<'a> MixedExportNameAndNoMangleState<'a> {
856869
}
857870
}
858871

872+
/// We now check the #\[rustc_autodiff\] attributes which we generated from the #[autodiff(...)]
873+
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
874+
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
875+
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
876+
/// panic, unless we introduced a bug when parsing the autodiff macro.
877+
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> Option<AutoDiffAttrs> {
878+
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
879+
880+
let attrs =
881+
attrs.filter(|attr| attr.name_or_empty() == sym::rustc_autodiff).collect::<Vec<_>>();
882+
883+
// check for exactly one autodiff attribute on placeholder functions.
884+
// There should only be one, since we generate a new placeholder per ad macro.
885+
// FIXME(ZuseZ4): re-enable this check. Currently we add multiple, which doesn't cause harm but
886+
// looks strange e.g. under cargo-expand.
887+
let attr = match &attrs[..] {
888+
[] => return None,
889+
[attr] => attr,
890+
// These two attributes are the same and unfortunately duplicated due to a previous bug.
891+
[attr, _attr2] => attr,
892+
_ => {
893+
//FIXME(ZuseZ4): Once we fixed our parser, we should also prohibit the two-attribute
894+
//branch above.
895+
span_bug!(attrs[1].span, "cg_ssa: rustc_autodiff should only exist once per source");
896+
}
897+
};
898+
899+
let list = attr.meta_item_list().unwrap_or_default();
900+
901+
// empty autodiff attribute macros (i.e. `#[autodiff]`) are used to mark source functions
902+
if list.is_empty() {
903+
return Some(AutoDiffAttrs::source());
904+
}
905+
906+
let [mode, input_activities @ .., ret_activity] = &list[..] else {
907+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode and activities");
908+
};
909+
let mode = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = mode {
910+
p1.segments.first().unwrap().ident
911+
} else {
912+
span_bug!(attr.span, "rustc_autodiff attribute must contain mode");
913+
};
914+
915+
// parse mode
916+
let mode = match mode.as_str() {
917+
"Forward" => DiffMode::Forward,
918+
"Reverse" => DiffMode::Reverse,
919+
"ForwardFirst" => DiffMode::ForwardFirst,
920+
"ReverseFirst" => DiffMode::ReverseFirst,
921+
_ => {
922+
span_bug!(mode.span, "rustc_autodiff attribute contains invalid mode");
923+
}
924+
};
925+
926+
// First read the ret symbol from the attribute
927+
let ret_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
928+
p1.segments.first().unwrap().ident
929+
} else {
930+
span_bug!(attr.span, "rustc_autodiff attribute must contain the return activity");
931+
};
932+
933+
// Then parse it into an actual DiffActivity
934+
let Ok(ret_activity) = DiffActivity::from_str(ret_symbol.as_str()) else {
935+
span_bug!(ret_symbol.span, "invalid return activity");
936+
};
937+
938+
// Now parse all the intermediate (input) activities
939+
let mut arg_activities: Vec<DiffActivity> = vec![];
940+
for arg in input_activities {
941+
let arg_symbol = if let MetaItemInner::MetaItem(MetaItem { path: ref p2, .. }) = arg {
942+
match p2.segments.first() {
943+
Some(x) => x.ident,
944+
None => {
945+
span_bug!(
946+
arg.span(),
947+
"rustc_autodiff attribute must contain the input activity"
948+
);
949+
}
950+
}
951+
} else {
952+
span_bug!(arg.span(), "rustc_autodiff attribute must contain the input activity");
953+
};
954+
955+
match DiffActivity::from_str(arg_symbol.as_str()) {
956+
Ok(arg_activity) => arg_activities.push(arg_activity),
957+
Err(_) => {
958+
span_bug!(arg_symbol.span, "invalid input activity");
959+
}
960+
}
961+
}
962+
963+
for &input in &arg_activities {
964+
if !valid_input_activity(mode, input) {
965+
span_bug!(attr.span, "Invalid input activity {} for {} mode", input, mode);
966+
}
967+
}
968+
if !valid_ret_activity(mode, ret_activity) {
969+
span_bug!(attr.span, "Invalid return activity {} for {} mode", ret_activity, mode);
970+
}
971+
972+
Some(AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities })
973+
}
974+
859975
pub(crate) fn provide(providers: &mut Providers) {
860976
*providers = Providers { codegen_fn_attrs, should_inherit_track_caller, ..*providers };
861977
}

0 commit comments

Comments
 (0)