Skip to content

Commit 7da1759

Browse files
committed
fix usage of autodiff macro with inner functions
1 parent be73c1f commit 7da1759

File tree

1 file changed

+76
-20
lines changed

1 file changed

+76
-20
lines changed

compiler/rustc_builtin_macros/src/autodiff.rs

+76-20
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ mod llvm_enzyme {
146146
}
147147
let dcx = ecx.sess.dcx();
148148
// first get the annotable item:
149-
let (sig, is_impl): (FnSig, bool) = match &item {
149+
let sig: FnSig = match &item {
150150
Annotatable::Item(iitem) => {
151151
let sig = match &iitem.kind {
152152
ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
@@ -155,7 +155,7 @@ mod llvm_enzyme {
155155
return vec![item];
156156
}
157157
};
158-
(sig.clone(), false)
158+
sig.clone()
159159
}
160160
Annotatable::AssocItem(assoc_item, _) => {
161161
let sig = match &assoc_item.kind {
@@ -165,7 +165,24 @@ mod llvm_enzyme {
165165
return vec![item];
166166
}
167167
};
168-
(sig.clone(), true)
168+
sig.clone()
169+
}
170+
Annotatable::Stmt(stmt) => {
171+
let sig = match &stmt.kind {
172+
ast::StmtKind::Item(iitem) => match &iitem.kind {
173+
ast::ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
174+
_ => {
175+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
176+
return vec![item];
177+
}
178+
},
179+
_ => {
180+
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
181+
return vec![item];
182+
}
183+
};
184+
185+
sig.clone()
169186
}
170187
_ => {
171188
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
@@ -189,6 +206,10 @@ mod llvm_enzyme {
189206
Annotatable::AssocItem(assoc_item, _) => {
190207
(assoc_item.vis.clone(), assoc_item.ident.clone())
191208
}
209+
Annotatable::Stmt(stmt) => match &stmt.kind {
210+
ast::StmtKind::Item(iitem) => (iitem.vis.clone(), iitem.ident.clone()),
211+
_ => unreachable!("stmt kind checked previously"),
212+
},
192213
_ => {
193214
dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
194215
return vec![item];
@@ -305,6 +326,22 @@ mod llvm_enzyme {
305326
}
306327
Annotatable::AssocItem(assoc_item.clone(), i)
307328
}
329+
Annotatable::Stmt(ref mut stmt) => {
330+
match stmt.kind {
331+
ast::StmtKind::Item(ref mut iitem) => {
332+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
333+
iitem.attrs.push(attr);
334+
}
335+
if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind))
336+
{
337+
iitem.attrs.push(inline_never.clone());
338+
}
339+
}
340+
_ => unreachable!("stmt kind checked previously"),
341+
};
342+
343+
Annotatable::Stmt(stmt.clone())
344+
}
308345
_ => {
309346
unreachable!("annotatable kind checked previously")
310347
}
@@ -315,24 +352,43 @@ mod llvm_enzyme {
315352
delim: rustc_ast::token::Delimiter::Parenthesis,
316353
tokens: ts,
317354
});
355+
318356
let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
319-
let d_annotatable = if is_impl {
320-
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
321-
let d_fn = P(ast::AssocItem {
322-
attrs: thin_vec![d_attr, inline_never],
323-
id: ast::DUMMY_NODE_ID,
324-
span,
325-
vis,
326-
ident: d_ident,
327-
kind: assoc_item,
328-
tokens: None,
329-
});
330-
Annotatable::AssocItem(d_fn, Impl)
331-
} else {
332-
let mut d_fn =
333-
ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
334-
d_fn.vis = vis;
335-
Annotatable::Item(d_fn)
357+
let d_annotatable = match &item {
358+
Annotatable::AssocItem(_, _) => {
359+
let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
360+
let d_fn = P(ast::AssocItem {
361+
attrs: thin_vec![d_attr, inline_never],
362+
id: ast::DUMMY_NODE_ID,
363+
span,
364+
vis,
365+
ident: d_ident,
366+
kind: assoc_item,
367+
tokens: None,
368+
});
369+
Annotatable::AssocItem(d_fn, Impl)
370+
}
371+
Annotatable::Item(_) => {
372+
let mut d_fn =
373+
ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
374+
d_fn.vis = vis;
375+
376+
Annotatable::Item(d_fn)
377+
}
378+
Annotatable::Stmt(_) => {
379+
let mut d_fn =
380+
ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
381+
d_fn.vis = vis;
382+
383+
Annotatable::Stmt(P(ast::Stmt {
384+
id: ast::DUMMY_NODE_ID,
385+
kind: ast::StmtKind::Item(d_fn),
386+
span,
387+
}))
388+
}
389+
_ => {
390+
unreachable!("item kind checked previously")
391+
}
336392
};
337393

338394
return vec![orig_annotatable, d_annotatable];

0 commit comments

Comments
 (0)