@@ -7,6 +7,7 @@ use std::sync::mpsc::{Receiver, Sender, channel};
77use  std:: { fs,  io,  mem,  str,  thread} ; 
88
99use  rustc_ast:: attr; 
10+ use  rustc_ast:: expand:: autodiff_attrs:: AutoDiffItem ; 
1011use  rustc_data_structures:: fx:: { FxHashMap ,  FxIndexMap } ; 
1112use  rustc_data_structures:: jobserver:: { self ,  Acquired } ; 
1213use  rustc_data_structures:: memmap:: Mmap ; 
@@ -40,7 +41,7 @@ use tracing::debug;
4041use  super :: link:: { self ,  ensure_removed} ; 
4142use  super :: lto:: { self ,  SerializedModule } ; 
4243use  super :: symbol_export:: symbol_name_for_instance_in_crate; 
43- use  crate :: errors:: ErrorCreatingRemarkDir ; 
44+ use  crate :: errors:: { AutodiffWithoutLto ,   ErrorCreatingRemarkDir } ; 
4445use  crate :: traits:: * ; 
4546use  crate :: { 
4647    CachedModuleCodegen ,  CodegenResults ,  CompiledModule ,  CrateInfo ,  ModuleCodegen ,  ModuleKind , 
@@ -118,6 +119,7 @@ pub struct ModuleConfig {
118119    pub  merge_functions :  bool , 
119120    pub  emit_lifetime_markers :  bool , 
120121    pub  llvm_plugins :  Vec < String > , 
122+     pub  autodiff :  Vec < config:: AutoDiff > , 
121123} 
122124
123125impl  ModuleConfig  { 
@@ -266,6 +268,7 @@ impl ModuleConfig {
266268
267269            emit_lifetime_markers :  sess. emit_lifetime_markers ( ) , 
268270            llvm_plugins :  if_regular ! ( sess. opts. unstable_opts. llvm_plugins. clone( ) ,  vec![ ] ) , 
271+             autodiff :  if_regular ! ( sess. opts. unstable_opts. autodiff. clone( ) ,  vec![ ] ) , 
269272        } 
270273    } 
271274
@@ -389,6 +392,7 @@ impl<B: WriteBackendMethods> CodegenContext<B> {
389392
390393fn  generate_lto_work < B :  ExtraBackendMethods > ( 
391394    cgcx :  & CodegenContext < B > , 
395+     autodiff :  Vec < AutoDiffItem > , 
392396    needs_fat_lto :  Vec < FatLtoInput < B > > , 
393397    needs_thin_lto :  Vec < ( String ,  B :: ThinBuffer ) > , 
394398    import_only_modules :  Vec < ( SerializedModule < B :: ModuleBuffer > ,  WorkProduct ) > , 
@@ -397,11 +401,19 @@ fn generate_lto_work<B: ExtraBackendMethods>(
397401
398402    if  !needs_fat_lto. is_empty ( )  { 
399403        assert ! ( needs_thin_lto. is_empty( ) ) ; 
400-         let  module =
404+         let  mut   module =
401405            B :: run_fat_lto ( cgcx,  needs_fat_lto,  import_only_modules) . unwrap_or_else ( |e| e. raise ( ) ) ; 
406+         if  cgcx. lto  == Lto :: Fat  { 
407+             let  config = cgcx. config ( ModuleKind :: Regular ) ; 
408+             module = unsafe  {  module. autodiff ( cgcx,  autodiff,  config) . unwrap ( )  } ; 
409+         } 
402410        // We are adding a single work item, so the cost doesn't matter. 
403411        vec ! [ ( WorkItem :: LTO ( module) ,  0 ) ] 
404412    }  else  { 
413+         if  !autodiff. is_empty ( )  { 
414+             let  dcx = cgcx. create_dcx ( ) ; 
415+             dcx. handle ( ) . emit_fatal ( AutodiffWithoutLto  { } ) ; 
416+         } 
405417        assert ! ( needs_fat_lto. is_empty( ) ) ; 
406418        let  ( lto_modules,  copy_jobs)  = B :: run_thin_lto ( cgcx,  needs_thin_lto,  import_only_modules) 
407419            . unwrap_or_else ( |e| e. raise ( ) ) ; 
@@ -1021,6 +1033,9 @@ pub(crate) enum Message<B: WriteBackendMethods> {
10211033/// Sent from a backend worker thread. 
10221034WorkItem  {  result :  Result < WorkItemResult < B > ,  Option < WorkerFatalError > > ,  worker_id :  usize  } , 
10231035
1036+     /// A vector containing all the AutoDiff tasks that we have to pass to Enzyme. 
1037+ AddAutoDiffItems ( Vec < AutoDiffItem > ) , 
1038+ 
10241039    /// The frontend has finished generating something (backend IR or a 
10251040/// post-LTO artifact) for a codegen unit, and it should be passed to the 
10261041/// backend. Sent from the main thread. 
@@ -1348,6 +1363,7 @@ fn start_executing_work<B: ExtraBackendMethods>(
13481363
13491364        // This is where we collect codegen units that have gone all the way 
13501365        // through codegen and LLVM. 
1366+         let  mut  autodiff_items = Vec :: new ( ) ; 
13511367        let  mut  compiled_modules = vec ! [ ] ; 
13521368        let  mut  compiled_allocator_module = None ; 
13531369        let  mut  needs_link = Vec :: new ( ) ; 
@@ -1459,9 +1475,13 @@ fn start_executing_work<B: ExtraBackendMethods>(
14591475                    let  needs_thin_lto = mem:: take ( & mut  needs_thin_lto) ; 
14601476                    let  import_only_modules = mem:: take ( & mut  lto_import_only_modules) ; 
14611477
1462-                     for  ( work,  cost)  in 
1463-                         generate_lto_work ( & cgcx,  needs_fat_lto,  needs_thin_lto,  import_only_modules) 
1464-                     { 
1478+                     for  ( work,  cost)  in  generate_lto_work ( 
1479+                         & cgcx, 
1480+                         autodiff_items. clone ( ) , 
1481+                         needs_fat_lto, 
1482+                         needs_thin_lto, 
1483+                         import_only_modules, 
1484+                     )  { 
14651485                        let  insertion_index = work_items
14661486                            . binary_search_by_key ( & cost,  |& ( _,  cost) | cost) 
14671487                            . unwrap_or_else ( |e| e) ; 
@@ -1596,6 +1616,10 @@ fn start_executing_work<B: ExtraBackendMethods>(
15961616                    main_thread_state = MainThreadState :: Idle ; 
15971617                } 
15981618
1619+                 Message :: AddAutoDiffItems ( mut  items)  => { 
1620+                     autodiff_items. append ( & mut  items) ; 
1621+                 } 
1622+ 
15991623                Message :: CodegenComplete  => { 
16001624                    if  codegen_state != Aborted  { 
16011625                        codegen_state = Completed ; 
@@ -2070,6 +2094,10 @@ impl<B: ExtraBackendMethods> OngoingCodegen<B> {
20702094        drop ( self . coordinator . sender . send ( Box :: new ( Message :: CodegenComplete :: < B > ) ) ) ; 
20712095    } 
20722096
2097+     pub ( crate )  fn  submit_autodiff_items ( & self ,  items :  Vec < AutoDiffItem > )  { 
2098+         drop ( self . coordinator . sender . send ( Box :: new ( Message :: < B > :: AddAutoDiffItems ( items) ) ) ) ; 
2099+     } 
2100+ 
20732101    pub ( crate )  fn  check_for_errors ( & self ,  sess :  & Session )  { 
20742102        self . shared_emitter_main . check ( sess,  false ) ; 
20752103    } 
0 commit comments