@@ -8,6 +8,7 @@ use std::{fs, io, mem, str, thread};
8
8
9
9
use jobserver:: { Acquired , Client } ;
10
10
use rustc_ast:: attr;
11
+ use rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ;
11
12
use rustc_data_structures:: fx:: { FxHashMap , FxIndexMap } ;
12
13
use rustc_data_structures:: memmap:: Mmap ;
13
14
use rustc_data_structures:: profiling:: { SelfProfilerRef , VerboseTimingGuard } ;
@@ -41,7 +42,7 @@ use tracing::debug;
41
42
use super :: link:: { self , ensure_removed} ;
42
43
use super :: lto:: { self , SerializedModule } ;
43
44
use super :: symbol_export:: symbol_name_for_instance_in_crate;
44
- use crate :: errors:: ErrorCreatingRemarkDir ;
45
+ use crate :: errors:: { AutodiffWithoutLto , ErrorCreatingRemarkDir } ;
45
46
use crate :: traits:: * ;
46
47
use crate :: {
47
48
CachedModuleCodegen , CodegenResults , CompiledModule , CrateInfo , ModuleCodegen , ModuleKind ,
@@ -120,6 +121,7 @@ pub struct ModuleConfig {
120
121
pub merge_functions : bool ,
121
122
pub emit_lifetime_markers : bool ,
122
123
pub llvm_plugins : Vec < String > ,
124
+ pub autodiff : Vec < config:: AutoDiff > ,
123
125
}
124
126
125
127
impl ModuleConfig {
@@ -280,6 +282,7 @@ impl ModuleConfig {
280
282
281
283
emit_lifetime_markers : sess. emit_lifetime_markers ( ) ,
282
284
llvm_plugins : if_regular ! ( sess. opts. unstable_opts. llvm_plugins. clone( ) , vec![ ] ) ,
285
+ autodiff : if_regular ! ( sess. opts. unstable_opts. autodiff. clone( ) , vec![ ] ) ,
283
286
}
284
287
}
285
288
@@ -401,6 +404,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
401
404
402
405
fn generate_lto_work < B : ExtraBackendMethods > (
403
406
cgcx : & CodegenContext < B > ,
407
+ autodiff : Vec < AutoDiffItem > ,
404
408
needs_fat_lto : Vec < FatLtoInput < B > > ,
405
409
needs_thin_lto : Vec < ( String , B :: ThinBuffer ) > ,
406
410
import_only_modules : Vec < ( SerializedModule < B :: ModuleBuffer > , WorkProduct ) > ,
@@ -411,9 +415,18 @@ fn generate_lto_work<B: ExtraBackendMethods>(
411
415
assert ! ( needs_thin_lto. is_empty( ) ) ;
412
416
let module =
413
417
B :: run_fat_lto ( cgcx, needs_fat_lto, import_only_modules) . unwrap_or_else ( |e| e. raise ( ) ) ;
418
+ if cgcx. lto == Lto :: Fat {
419
+ let _config = cgcx. config ( ModuleKind :: Regular ) ;
420
+ todo ! ( "fat LTO with autodiff is not yet implemented" ) ;
421
+ //module = unsafe { module.autodiff(cgcx, autodiff, config).unwrap() };
422
+ }
414
423
// We are adding a single work item, so the cost doesn't matter.
415
424
vec ! [ ( WorkItem :: LTO ( module) , 0 ) ]
416
425
} else {
426
+ if !autodiff. is_empty ( ) {
427
+ let dcx = cgcx. create_dcx ( ) ;
428
+ dcx. handle ( ) . emit_fatal ( AutodiffWithoutLto { } ) ;
429
+ }
417
430
assert ! ( needs_fat_lto. is_empty( ) ) ;
418
431
let ( lto_modules, copy_jobs) = B :: run_thin_lto ( cgcx, needs_thin_lto, import_only_modules)
419
432
. unwrap_or_else ( |e| e. raise ( ) ) ;
@@ -1041,6 +1054,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
1041
1054
/// Sent from a backend worker thread.
1042
1055
WorkItem { result : Result < WorkItemResult < B > , Option < WorkerFatalError > > , worker_id : usize } ,
1043
1056
1057
+ /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme.
1058
+ AddAutoDiffItems ( Vec < AutoDiffItem > ) ,
1059
+
1044
1060
/// The frontend has finished generating something (backend IR or a
1045
1061
/// post-LTO artifact) for a codegen unit, and it should be passed to the
1046
1062
/// backend. Sent from the main thread.
@@ -1367,6 +1383,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
1367
1383
1368
1384
// This is where we collect codegen units that have gone all the way
1369
1385
// through codegen and LLVM.
1386
+ let mut autodiff_items = Vec :: new ( ) ;
1370
1387
let mut compiled_modules = vec ! [ ] ;
1371
1388
let mut compiled_allocator_module = None ;
1372
1389
let mut needs_link = Vec :: new ( ) ;
@@ -1478,9 +1495,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
1478
1495
let needs_thin_lto = mem:: take ( & mut needs_thin_lto) ;
1479
1496
let import_only_modules = mem:: take ( & mut lto_import_only_modules) ;
1480
1497
1481
- for ( work, cost) in
1482
- generate_lto_work ( & cgcx, needs_fat_lto, needs_thin_lto, import_only_modules)
1483
- {
1498
+ for ( work, cost) in generate_lto_work (
1499
+ & cgcx,
1500
+ autodiff_items. clone ( ) ,
1501
+ needs_fat_lto,
1502
+ needs_thin_lto,
1503
+ import_only_modules,
1504
+ ) {
1484
1505
let insertion_index = work_items
1485
1506
. binary_search_by_key ( & cost, |& ( _, cost) | cost)
1486
1507
. unwrap_or_else ( |e| e) ;
@@ -1615,6 +1636,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
1615
1636
main_thread_state = MainThreadState :: Idle ;
1616
1637
}
1617
1638
1639
+ Message :: AddAutoDiffItems ( mut items) => {
1640
+ autodiff_items. append ( & mut items) ;
1641
+ }
1642
+
1618
1643
Message :: CodegenComplete => {
1619
1644
if codegen_state != Aborted {
1620
1645
codegen_state = Completed ;
@@ -2092,6 +2117,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
2092
2117
drop ( self . coordinator . sender . send ( Box :: new ( Message :: CodegenComplete :: < B > ) ) ) ;
2093
2118
}
2094
2119
2120
+ pub ( crate ) fn submit_autodiff_items ( & self , items : Vec < AutoDiffItem > ) {
2121
+ drop ( self . coordinator . sender . send ( Box :: new ( Message :: < B > :: AddAutoDiffItems ( items) ) ) ) ;
2122
+ }
2123
+
2095
2124
pub ( crate ) fn check_for_errors ( & self , sess : & Session ) {
2096
2125
self . shared_emitter_main . check ( sess, false ) ;
2097
2126
}
0 commit comments