Skip to content

Commit 44033ba

Browse files
committed
Support #[pyfunction] inside #[pymodule]
1 parent 1518767 commit 44033ba

File tree

4 files changed

+94
-2
lines changed

4 files changed

+94
-2
lines changed

CHANGELOG.md

+4
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
77

88
## Unreleased
99

10+
### Added
11+
12+
* Support for `#[pyfunction]` inside `#[pymodule]`. [#693](https://github.com/PyO3/pyo3/pull/693)
13+
1014
## [0.8.4]
1115

1216
### Added

pyo3-derive-backend/src/module.rs

+65-1
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@ pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::Lit) -> TokenStream {
2828
}
2929

3030
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
31-
pub fn process_functions_in_module(func: &mut syn::ItemFn) {
31+
pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> {
3232
let mut stmts: Vec<syn::Stmt> = Vec::new();
3333

34+
let module_name = get_module_name(&func.sig)?;
35+
3436
for stmt in func.block.stmts.iter_mut() {
3537
if let syn::Stmt::Item(syn::Item::Fn(ref mut func)) = stmt {
3638
if let Some((module_name, python_name, pyfn_attrs)) =
@@ -45,12 +47,32 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) {
4547
}
4648
};
4749
stmts.extend(item.block.stmts.into_iter());
50+
} else if is_pyfunction(&func.attrs) {
51+
// This is done lazily so that empty #[pymodules] with an argument _: &PyModule
52+
// don't raise an error.
53+
let module_name = match module_name {
54+
ModuleName::Ident(i) => i,
55+
ModuleName::Wild(u) => {
56+
return Err(syn::Error::new_spanned(
57+
u,
58+
"module name is required for `#[pyfunction]`",
59+
))
60+
}
61+
};
62+
63+
let function_wrapper_ident = function_wrapper_ident(&func.sig.ident);
64+
let stmt: syn::Stmt = syn::parse_quote! {
65+
#module_name.add_wrapped(&#function_wrapper_ident)?;
66+
};
67+
stmts.push(stmt);
4868
}
4969
};
5070
stmts.push(stmt.clone());
5171
}
5272

5373
func.block.stmts = stmts;
74+
75+
Ok(())
5476
}
5577

5678
/// Transforms a rust fn arg parsed with syn into a method::FnArg
@@ -124,6 +146,48 @@ fn extract_pyfn_attrs(
124146
Some((modname?, fnname?, fn_attrs))
125147
}
126148

149+
/// Check if the function has a #[pyfunction] attribute
150+
fn is_pyfunction(attrs: &Vec<syn::Attribute>) -> bool {
151+
for attr in attrs.iter() {
152+
match attr.parse_meta() {
153+
Ok(syn::Meta::Path(ref path)) if path.is_ident("pyfunction") => return true,
154+
_ => {}
155+
}
156+
}
157+
158+
return false;
159+
}
160+
161+
enum ModuleName<'a> {
162+
Ident(&'a syn::Ident),
163+
Wild(&'a syn::token::Underscore),
164+
}
165+
166+
fn get_module_name(sig: &syn::Signature) -> syn::Result<ModuleName> {
167+
// module is second argument to #[pymodule]
168+
if sig.inputs.len() != 2 {
169+
return Err(syn::Error::new_spanned(
170+
&sig,
171+
"#[pymodule] expects two arguments",
172+
));
173+
}
174+
175+
match &sig.inputs[1] {
176+
syn::FnArg::Typed(syn::PatType { pat, .. }) => match &**pat {
177+
syn::Pat::Ident(syn::PatIdent { ident, .. }) => Ok(ModuleName::Ident(ident)),
178+
syn::Pat::Wild(syn::PatWild {
179+
underscore_token: u,
180+
..
181+
}) => Ok(ModuleName::Wild(u)),
182+
_ => Err(syn::Error::new_spanned(pat, "expected #[pymodule] name")),
183+
},
184+
syn::FnArg::Receiver(_) => Err(syn::Error::new_spanned(
185+
&sig.inputs[1],
186+
"expected module argument",
187+
)),
188+
}
189+
}
190+
127191
/// Coordinates the naming of a the add-function-to-python-module function
128192
fn function_wrapper_ident(name: &Ident) -> Ident {
129193
// Make sure this ident matches the one of wrap_pyfunction

pyo3cls/src/lib.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@ pub fn pymodule(attr: TokenStream, input: TokenStream) -> TokenStream {
2525
parse_macro_input!(attr as syn::Ident)
2626
};
2727

28-
process_functions_in_module(&mut ast);
28+
if let Err(err) = process_functions_in_module(&mut ast) {
29+
return err.to_compile_error().into();
30+
}
2931

3032
let doc = match get_doc(&ast.attrs, None, false) {
3133
Ok(doc) => doc,

tests/test_module.rs

+22
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,28 @@ fn test_raw_idents() {
148148
py_assert!(py, module, "module.move() == 42");
149149
}
150150

151+
#[pymodule]
152+
fn pyfunction_module(_py: Python, module: &PyModule) -> PyResult<()> {
153+
#[pyfunction]
154+
fn foobar() -> usize {
155+
101
156+
}
157+
158+
Ok(())
159+
}
160+
161+
#[test]
162+
fn test_pyfunction_module() {
163+
use pyo3::wrap_pymodule;
164+
165+
let gil = Python::acquire_gil();
166+
let py = gil.python();
167+
168+
let module = wrap_pymodule!(pyfunction_module)(py);
169+
170+
py_assert!(py, module, "module.foobar() == 101");
171+
}
172+
151173
#[pyfunction]
152174
fn subfunction() -> String {
153175
"Subfunction".to_string()

0 commit comments

Comments
 (0)