Skip to content

Commit 8dac72b

Browse files
committed
Auto merge of #136428 - EnzymeAD:enable-autodiff, r=oli-obk
test building enzyme in CI 1) This PR fixes a significant compile-time regression by only running the expensive autodiff pipeline if users pass the newly introduced `Enable` value to the `-Zautodiff=` flag. It updates the test(s) accordingly, and it gives a nice error if users forget to pass it. 2) It fixes macOS support by explicitly linking against the Enzyme build folder. This doesn't cover macOS CI yet. 3) It fixes the issue that setting ENZYME_RUNPASS was ignored by Enzyme and in fact did not schedule Enzyme's opt pass. 4) It also re-enables support for various other values of the autodiff flag, which had been ignored since the refactor. 5) I merged some improvements to Enzyme core, which means we no longer depend on LLVM being built with the Plugin Interface enabled. 6) Unrelated to the other fixes, this changes `rustc_autodiff` to `EncodeCrossCrate::Yes`. It is not enough on its own to enable usage of Enzyme in libraries, but it is for sure a piece of the fixes needed to get this to work. try-job: x86_64-gnu r? `@oli-obk` Tracking: - #124509
2 parents b6d3be4 + 49e9630 commit 8dac72b

File tree

20 files changed

+237
-104
lines changed

20 files changed

+237
-104
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ use crate::{Ty, TyKind};
1717
/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
1818
/// as it's already done in the C++ and Julia frontend of Enzyme.
1919
///
20-
/// (FIXME) remove *First variants.
2120
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
2221
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
2322
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]

compiler/rustc_builtin_macros/src/autodiff.rs

+1
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,7 @@ mod llvm_enzyme {
242242
defaultness: ast::Defaultness::Final,
243243
sig: d_sig,
244244
generics: Generics::default(),
245+
contract: None,
245246
body: Some(d_body),
246247
});
247248
let mut rustc_ad_attr =

compiler/rustc_codegen_llvm/messages.ftl

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
codegen_llvm_autodiff_without_enable = using the autodiff feature requires -Z autodiff=Enable
12
codegen_llvm_autodiff_without_lto = using the autodiff feature requires using fat-lto
23
34
codegen_llvm_copy_bitcode = failed to copy bitcode to object file: {$err}

compiler/rustc_codegen_llvm/src/back/lto.rs

+62-23
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,42 @@ fn thin_lto(
586586
}
587587
}
588588

589+
fn enable_autodiff_settings(ad: &[config::AutoDiff], module: &mut ModuleCodegen<ModuleLlvm>) {
590+
for &val in ad {
591+
match val {
592+
config::AutoDiff::PrintModBefore => {
593+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
594+
}
595+
config::AutoDiff::PrintPerf => {
596+
llvm::set_print_perf(true);
597+
}
598+
config::AutoDiff::PrintAA => {
599+
llvm::set_print_activity(true);
600+
}
601+
config::AutoDiff::PrintTA => {
602+
llvm::set_print_type(true);
603+
}
604+
config::AutoDiff::Inline => {
605+
llvm::set_inline(true);
606+
}
607+
config::AutoDiff::LooseTypes => {
608+
llvm::set_loose_types(false);
609+
}
610+
config::AutoDiff::PrintSteps => {
611+
llvm::set_print(true);
612+
}
613+
// We handle this below
614+
config::AutoDiff::PrintModAfter => {}
615+
// This is required and already checked
616+
config::AutoDiff::Enable => {}
617+
}
618+
}
619+
// This helps with handling enums for now.
620+
llvm::set_strict_aliasing(false);
621+
// FIXME(ZuseZ4): Test this, since it was added a long time ago.
622+
llvm::set_rust_rules(true);
623+
}
624+
589625
pub(crate) fn run_pass_manager(
590626
cgcx: &CodegenContext<LlvmCodegenBackend>,
591627
dcx: DiagCtxtHandle<'_>,
@@ -604,34 +640,37 @@ pub(crate) fn run_pass_manager(
604640
let opt_stage = if thin { llvm::OptStage::ThinLTO } else { llvm::OptStage::FatLTO };
605641
let opt_level = config.opt_level.unwrap_or(config::OptLevel::No);
606642

607-
// If this rustc version was build with enzyme/autodiff enabled, and if users applied the
608-
// `#[autodiff]` macro at least once, then we will later call llvm_optimize a second time.
609-
debug!("running llvm pm opt pipeline");
643+
// The PostAD behavior is the same that we would have if no autodiff was used.
644+
// It will run the default optimization pipeline. If AD is enabled we select
645+
// the DuringAD stage, which will disable vectorization and loop unrolling, and
646+
// schedule two autodiff optimization + differentiation passes.
647+
// We then run the llvm_optimize function a second time, to optimize the code which we generated
648+
// in the enzyme differentiation pass.
649+
let enable_ad = config.autodiff.contains(&config::AutoDiff::Enable);
650+
let stage =
651+
if enable_ad { write::AutodiffStage::DuringAD } else { write::AutodiffStage::PostAD };
652+
653+
if enable_ad {
654+
enable_autodiff_settings(&config.autodiff, module);
655+
}
656+
610657
unsafe {
611-
write::llvm_optimize(
612-
cgcx,
613-
dcx,
614-
module,
615-
config,
616-
opt_level,
617-
opt_stage,
618-
write::AutodiffStage::DuringAD,
619-
)?;
658+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
620659
}
621-
// FIXME(ZuseZ4): Make this more granular
622-
if cfg!(llvm_enzyme) && !thin {
660+
661+
if cfg!(llvm_enzyme) && enable_ad {
662+
let opt_stage = llvm::OptStage::FatLTO;
663+
let stage = write::AutodiffStage::PostAD;
623664
unsafe {
624-
write::llvm_optimize(
625-
cgcx,
626-
dcx,
627-
module,
628-
config,
629-
opt_level,
630-
llvm::OptStage::FatLTO,
631-
write::AutodiffStage::PostAD,
632-
)?;
665+
write::llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, stage)?;
666+
}
667+
668+
// This is the final IR, so people should be able to inspect the optimized autodiff output.
669+
if config.autodiff.contains(&config::AutoDiff::PrintModAfter) {
670+
unsafe { llvm::LLVMDumpModule(module.module_llvm.llmod()) };
633671
}
634672
}
673+
635674
debug!("lto done");
636675
Ok(())
637676
}

compiler/rustc_codegen_llvm/src/back/write.rs

+5-10
Original file line numberDiff line numberDiff line change
@@ -564,19 +564,16 @@ pub(crate) unsafe fn llvm_optimize(
564564
// FIXME(ZuseZ4): In a future update we could figure out how to only optimize individual functions getting
565565
// differentiated.
566566

567+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
568+
let run_enzyme = autodiff_stage == AutodiffStage::DuringAD;
567569
let unroll_loops;
568570
let vectorize_slp;
569571
let vectorize_loop;
570-
let run_enzyme = cfg!(llvm_enzyme) && autodiff_stage == AutodiffStage::DuringAD;
571572

572573
// When we build rustc with enzyme/autodiff support, we want to postpone size-increasing
573574
// optimizations until after differentiation. Our pipeline is thus: (opt + enzyme), (full opt).
574575
// We therefore have two calls to llvm_optimize, if autodiff is used.
575-
//
576-
// FIXME(ZuseZ4): Before shipping on nightly,
577-
// we should make this more granular, or at least check that the user has at least one autodiff
578-
// call in their code, to justify altering the compilation pipeline.
579-
if cfg!(llvm_enzyme) && autodiff_stage != AutodiffStage::PostAD {
576+
if consider_ad && autodiff_stage != AutodiffStage::PostAD {
580577
unroll_loops = false;
581578
vectorize_slp = false;
582579
vectorize_loop = false;
@@ -706,10 +703,8 @@ pub(crate) unsafe fn optimize(
706703

707704
// If we know that we will later run AD, then we disable vectorization and loop unrolling.
708705
// Otherwise we pretend AD is already done and run the normal opt pipeline (=PostAD).
709-
// FIXME(ZuseZ4): Make this more granular, only set PreAD if we actually have autodiff
710-
// usages, not just if we build rustc with autodiff support.
711-
let autodiff_stage =
712-
if cfg!(llvm_enzyme) { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
706+
let consider_ad = cfg!(llvm_enzyme) && config.autodiff.contains(&config::AutoDiff::Enable);
707+
let autodiff_stage = if consider_ad { AutodiffStage::PreAD } else { AutodiffStage::PostAD };
713708
return unsafe {
714709
llvm_optimize(cgcx, dcx, module, config, opt_level, opt_stage, autodiff_stage)
715710
};

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

+9-4
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::back::write::llvm_err;
1010
use crate::builder::SBuilder;
1111
use crate::context::SimpleCx;
1212
use crate::declare::declare_simple_fn;
13-
use crate::errors::LlvmError;
13+
use crate::errors::{AutoDiffWithoutEnable, LlvmError};
1414
use crate::llvm::AttributePlace::Function;
1515
use crate::llvm::{Metadata, True};
1616
use crate::value::Value;
@@ -46,9 +46,6 @@ fn generate_enzyme_call<'ll>(
4646
let output = attrs.ret_activity;
4747

4848
// We have to pick the name depending on whether we want forward or reverse mode autodiff.
49-
// FIXME(ZuseZ4): The new pass based approach should not need the {Forward/Reverse}First method anymore, since
50-
// it will handle higher-order derivatives correctly automatically (in theory). Currently
51-
// higher-order derivatives fail, so we should debug that before adjusting this code.
5249
let mut ad_name: String = match attrs.mode {
5350
DiffMode::Forward => "__enzyme_fwddiff",
5451
DiffMode::Reverse => "__enzyme_autodiff",
@@ -291,6 +288,14 @@ pub(crate) fn differentiate<'ll>(
291288
let diag_handler = cgcx.create_dcx();
292289
let cx = SimpleCx { llmod: module.module_llvm.llmod(), llcx: module.module_llvm.llcx };
293290

291+
// First of all, did the user try to use autodiff without using the -Zautodiff=Enable flag?
292+
if !diff_items.is_empty()
293+
&& !cgcx.opts.unstable_opts.autodiff.contains(&rustc_session::config::AutoDiff::Enable)
294+
{
295+
let dcx = cgcx.create_dcx();
296+
return Err(dcx.handle().emit_almost_fatal(AutoDiffWithoutEnable));
297+
}
298+
294299
// Before dumping the module, we want all the TypeTrees to become part of the module.
295300
for item in diff_items.iter() {
296301
let name = item.source.clone();

compiler/rustc_codegen_llvm/src/errors.rs

+4-1
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,12 @@ impl<G: EmissionGuarantee> Diagnostic<'_, G> for ParseTargetMachineConfig<'_> {
9292

9393
#[derive(Diagnostic)]
9494
#[diag(codegen_llvm_autodiff_without_lto)]
95-
#[note]
9695
pub(crate) struct AutoDiffWithoutLTO;
9796

97+
#[derive(Diagnostic)]
98+
#[diag(codegen_llvm_autodiff_without_enable)]
99+
pub(crate) struct AutoDiffWithoutEnable;
100+
98101
#[derive(Diagnostic)]
99102
#[diag(codegen_llvm_lto_disallowed)]
100103
pub(crate) struct LtoDisallowed;

compiler/rustc_codegen_llvm/src/llvm/enzyme_ffi.rs

+94
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,97 @@ pub enum LLVMRustVerifierFailureAction {
3535
LLVMPrintMessageAction = 1,
3636
LLVMReturnStatusAction = 2,
3737
}
38+
39+
#[cfg(llvm_enzyme)]
40+
pub use self::Enzyme_AD::*;
41+
42+
#[cfg(llvm_enzyme)]
43+
pub mod Enzyme_AD {
44+
use libc::c_void;
45+
extern "C" {
46+
pub fn EnzymeSetCLBool(arg1: *mut ::std::os::raw::c_void, arg2: u8);
47+
}
48+
extern "C" {
49+
static mut EnzymePrintPerf: c_void;
50+
static mut EnzymePrintActivity: c_void;
51+
static mut EnzymePrintType: c_void;
52+
static mut EnzymePrint: c_void;
53+
static mut EnzymeStrictAliasing: c_void;
54+
static mut looseTypeAnalysis: c_void;
55+
static mut EnzymeInline: c_void;
56+
static mut RustTypeRules: c_void;
57+
}
58+
pub fn set_print_perf(print: bool) {
59+
unsafe {
60+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintPerf), print as u8);
61+
}
62+
}
63+
pub fn set_print_activity(print: bool) {
64+
unsafe {
65+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintActivity), print as u8);
66+
}
67+
}
68+
pub fn set_print_type(print: bool) {
69+
unsafe {
70+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrintType), print as u8);
71+
}
72+
}
73+
pub fn set_print(print: bool) {
74+
unsafe {
75+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymePrint), print as u8);
76+
}
77+
}
78+
pub fn set_strict_aliasing(strict: bool) {
79+
unsafe {
80+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeStrictAliasing), strict as u8);
81+
}
82+
}
83+
pub fn set_loose_types(loose: bool) {
84+
unsafe {
85+
EnzymeSetCLBool(std::ptr::addr_of_mut!(looseTypeAnalysis), loose as u8);
86+
}
87+
}
88+
pub fn set_inline(val: bool) {
89+
unsafe {
90+
EnzymeSetCLBool(std::ptr::addr_of_mut!(EnzymeInline), val as u8);
91+
}
92+
}
93+
pub fn set_rust_rules(val: bool) {
94+
unsafe {
95+
EnzymeSetCLBool(std::ptr::addr_of_mut!(RustTypeRules), val as u8);
96+
}
97+
}
98+
}
99+
100+
#[cfg(not(llvm_enzyme))]
101+
pub use self::Fallback_AD::*;
102+
103+
#[cfg(not(llvm_enzyme))]
104+
pub mod Fallback_AD {
    //! Stand-ins for the Enzyme option setters.
    //!
    //! They keep callers compiling when the Enzyme-backed module is not built; calling any of
    //! them panics via `unimplemented!()`.
    #![allow(unused_variables)]

    /// Expands to one `pub fn $name($arg: bool)` per entry, each with an `unimplemented!()` body.
    macro_rules! unavailable_setters {
        ($($name:ident($arg:ident)),* $(,)?) => {
            $(
                pub fn $name($arg: bool) {
                    unimplemented!()
                }
            )*
        };
    }

    unavailable_setters! {
        set_inline(val),
        set_print_perf(print),
        set_print_activity(print),
        set_print_type(print),
        set_print(print),
        set_strict_aliasing(strict),
        set_loose_types(loose),
        set_rust_rules(val),
    }
}

compiler/rustc_codegen_ssa/src/back/write.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,8 @@ fn generate_lto_work<B: ExtraBackendMethods>(
405405
B::run_fat_lto(cgcx, needs_fat_lto, import_only_modules).unwrap_or_else(|e| e.raise());
406406
if cgcx.lto == Lto::Fat && !autodiff.is_empty() {
407407
let config = cgcx.config(ModuleKind::Regular);
408-
module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
408+
module =
409+
unsafe { module.autodiff(cgcx, autodiff, config).unwrap_or_else(|e| e.raise()) };
409410
}
410411
// We are adding a single work item, so the cost doesn't matter.
411412
vec![(WorkItem::LTO(module), 0)]

compiler/rustc_feature/src/builtin_attrs.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -743,7 +743,7 @@ pub static BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[
743743
rustc_attr!(
744744
rustc_autodiff, Normal,
745745
template!(Word, List: r#""...""#), DuplicatesOk,
746-
EncodeCrossCrate::No, INTERNAL_UNSTABLE
746+
EncodeCrossCrate::Yes, INTERNAL_UNSTABLE
747747
),
748748

749749
// ==========================================================================

compiler/rustc_interface/src/tests.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ fn test_unstable_options_tracking_hash() {
759759
tracked!(allow_features, Some(vec![String::from("lang_items")]));
760760
tracked!(always_encode_mir, true);
761761
tracked!(assume_incomplete_release, true);
762-
tracked!(autodiff, vec![AutoDiff::Print]);
762+
tracked!(autodiff, vec![AutoDiff::Enable]);
763763
tracked!(binary_dep_depinfo, true);
764764
tracked!(box_noalias, false);
765765
tracked!(

compiler/rustc_llvm/llvm-wrapper/PassWrapper.cpp

+6-3
Original file line numberDiff line numberDiff line change
@@ -692,9 +692,12 @@ struct LLVMRustSanitizerOptions {
692692
bool SanitizeKernelAddressRecover;
693693
};
694694

695-
// This symbol won't be available or used when Enzyme is not enabled
695+
// This symbol won't be available or used when Enzyme is not enabled.
696+
// Always set AugmentPassBuilder to true, since it registers optimizations which
697+
// will improve the performance for Enzyme.
696698
#ifdef ENZYME
697-
extern "C" void registerEnzyme(llvm::PassBuilder &PB);
699+
extern "C" void registerEnzymeAndPassPipeline(llvm::PassBuilder &PB,
700+
/* augmentPassBuilder */ bool);
698701
#endif
699702

700703
extern "C" LLVMRustResult LLVMRustOptimize(
@@ -1023,7 +1026,7 @@ extern "C" LLVMRustResult LLVMRustOptimize(
10231026
// now load "-enzyme" pass:
10241027
#ifdef ENZYME
10251028
if (RunEnzyme) {
1026-
registerEnzyme(PB);
1029+
registerEnzymeAndPassPipeline(PB, true);
10271030
if (auto Err = PB.parsePassPipeline(MPM, "enzyme")) {
10281031
std::string ErrMsg = toString(std::move(Err));
10291032
LLVMRustSetLastError(ErrMsg.c_str());

compiler/rustc_monomorphize/src/partitioning/autodiff.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ pub(crate) fn find_autodiff_source_functions<'tcx>(
6666
let mut autodiff_items: Vec<AutoDiffItem> = vec![];
6767
for (item, instance) in autodiff_mono_items {
6868
let target_id = instance.def_id();
69-
let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
69+
let cg_fn_attr = &tcx.codegen_fn_attrs(target_id).autodiff_item;
7070
let Some(target_attrs) = cg_fn_attr else {
7171
continue;
7272
};

0 commit comments

Comments
 (0)