@@ -146,7 +146,7 @@ mod llvm_enzyme {
146
146
}
147
147
let dcx = ecx. sess . dcx ( ) ;
148
148
// first get the annotable item:
149
- let ( sig, is_impl ) : ( FnSig , bool ) = match & item {
149
+ let sig: FnSig = match & item {
150
150
Annotatable :: Item ( iitem) => {
151
151
let sig = match & iitem. kind {
152
152
ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
@@ -155,7 +155,7 @@ mod llvm_enzyme {
155
155
return vec ! [ item] ;
156
156
}
157
157
} ;
158
- ( sig. clone ( ) , false )
158
+ sig. clone ( )
159
159
}
160
160
Annotatable :: AssocItem ( assoc_item, _) => {
161
161
let sig = match & assoc_item. kind {
@@ -165,7 +165,24 @@ mod llvm_enzyme {
165
165
return vec ! [ item] ;
166
166
}
167
167
} ;
168
- ( sig. clone ( ) , true )
168
+ sig. clone ( )
169
+ }
170
+ Annotatable :: Stmt ( stmt) => {
171
+ let sig = match & stmt. kind {
172
+ ast:: StmtKind :: Item ( iitem) => match & iitem. kind {
173
+ ast:: ItemKind :: Fn ( box ast:: Fn { sig, .. } ) => sig,
174
+ _ => {
175
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
176
+ return vec ! [ item] ;
177
+ }
178
+ } ,
179
+ _ => {
180
+ dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
181
+ return vec ! [ item] ;
182
+ }
183
+ } ;
184
+
185
+ sig. clone ( )
169
186
}
170
187
_ => {
171
188
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
@@ -189,6 +206,10 @@ mod llvm_enzyme {
189
206
Annotatable :: AssocItem ( assoc_item, _) => {
190
207
( assoc_item. vis . clone ( ) , assoc_item. ident . clone ( ) )
191
208
}
209
+ Annotatable :: Stmt ( stmt) => match & stmt. kind {
210
+ ast:: StmtKind :: Item ( iitem) => ( iitem. vis . clone ( ) , iitem. ident . clone ( ) ) ,
211
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
212
+ } ,
192
213
_ => {
193
214
dcx. emit_err ( errors:: AutoDiffInvalidApplication { span : item. span ( ) } ) ;
194
215
return vec ! [ item] ;
@@ -305,6 +326,22 @@ mod llvm_enzyme {
305
326
}
306
327
Annotatable :: AssocItem ( assoc_item. clone ( ) , i)
307
328
}
329
+ Annotatable :: Stmt ( ref mut stmt) => {
330
+ match stmt. kind {
331
+ ast:: StmtKind :: Item ( ref mut iitem) => {
332
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & attr. kind ) ) {
333
+ iitem. attrs . push ( attr) ;
334
+ }
335
+ if !iitem. attrs . iter ( ) . any ( |a| same_attribute ( & a. kind , & inline_never. kind ) )
336
+ {
337
+ iitem. attrs . push ( inline_never. clone ( ) ) ;
338
+ }
339
+ }
340
+ _ => unreachable ! ( "stmt kind checked previously" ) ,
341
+ } ;
342
+
343
+ Annotatable :: Stmt ( stmt. clone ( ) )
344
+ }
308
345
_ => {
309
346
unreachable ! ( "annotatable kind checked previously" )
310
347
}
@@ -315,24 +352,43 @@ mod llvm_enzyme {
315
352
delim : rustc_ast:: token:: Delimiter :: Parenthesis ,
316
353
tokens : ts,
317
354
} ) ;
355
+
318
356
let d_attr = outer_normal_attr ( & rustc_ad_attr, new_id, span) ;
319
- let d_annotatable = if is_impl {
320
- let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
321
- let d_fn = P ( ast:: AssocItem {
322
- attrs : thin_vec ! [ d_attr, inline_never] ,
323
- id : ast:: DUMMY_NODE_ID ,
324
- span,
325
- vis,
326
- ident : d_ident,
327
- kind : assoc_item,
328
- tokens : None ,
329
- } ) ;
330
- Annotatable :: AssocItem ( d_fn, Impl )
331
- } else {
332
- let mut d_fn =
333
- ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
334
- d_fn. vis = vis;
335
- Annotatable :: Item ( d_fn)
357
+ let d_annotatable = match & item {
358
+ Annotatable :: AssocItem ( _, _) => {
359
+ let assoc_item: AssocItemKind = ast:: AssocItemKind :: Fn ( asdf) ;
360
+ let d_fn = P ( ast:: AssocItem {
361
+ attrs : thin_vec ! [ d_attr, inline_never] ,
362
+ id : ast:: DUMMY_NODE_ID ,
363
+ span,
364
+ vis,
365
+ ident : d_ident,
366
+ kind : assoc_item,
367
+ tokens : None ,
368
+ } ) ;
369
+ Annotatable :: AssocItem ( d_fn, Impl )
370
+ }
371
+ Annotatable :: Item ( _) => {
372
+ let mut d_fn =
373
+ ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
374
+ d_fn. vis = vis;
375
+
376
+ Annotatable :: Item ( d_fn)
377
+ }
378
+ Annotatable :: Stmt ( _) => {
379
+ let mut d_fn =
380
+ ecx. item ( span, d_ident, thin_vec ! [ d_attr, inline_never] , ItemKind :: Fn ( asdf) ) ;
381
+ d_fn. vis = vis;
382
+
383
+ Annotatable :: Stmt ( P ( ast:: Stmt {
384
+ id : ast:: DUMMY_NODE_ID ,
385
+ kind : ast:: StmtKind :: Item ( d_fn) ,
386
+ span,
387
+ } ) )
388
+ }
389
+ _ => {
390
+ unreachable ! ( "item kind checked previously" )
391
+ }
336
392
} ;
337
393
338
394
return vec ! [ orig_annotatable, d_annotatable] ;
0 commit comments