From 44033ba35959b43ef1ccb85ff67db707313a74ae Mon Sep 17 00:00:00 2001 From: David Hewitt <1939362+davidhewitt@users.noreply.github.com> Date: Wed, 18 Dec 2019 09:58:28 +0000 Subject: [PATCH] Support `#[pyfunction]` inside `#[pymodule]` --- CHANGELOG.md | 4 ++ pyo3-derive-backend/src/module.rs | 66 ++++++++++++++++++++++++++++++- pyo3cls/src/lib.rs | 4 +- tests/test_module.rs | 22 +++++++++++ 4 files changed, 94 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index fadf9d99cde..289d367bccb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0. ## Unreleased +### Added + +* Support for `#[pyfunction]` inside `#[pymodule]`. [#693](https://github.com/PyO3/pyo3/pull/693) + ## [0.8.4] ### Added diff --git a/pyo3-derive-backend/src/module.rs b/pyo3-derive-backend/src/module.rs index 3ee1433b5eb..cd0fdc207f3 100644 --- a/pyo3-derive-backend/src/module.rs +++ b/pyo3-derive-backend/src/module.rs @@ -28,9 +28,11 @@ pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::Lit) -> TokenStream { } /// Finds and takes care of the #[pyfn(...)] in `#[pymodule]` -pub fn process_functions_in_module(func: &mut syn::ItemFn) { +pub fn process_functions_in_module(func: &mut syn::ItemFn) -> syn::Result<()> { let mut stmts: Vec = Vec::new(); + let module_name = get_module_name(&func.sig)?; + for stmt in func.block.stmts.iter_mut() { if let syn::Stmt::Item(syn::Item::Fn(ref mut func)) = stmt { if let Some((module_name, python_name, pyfn_attrs)) = @@ -45,12 +47,32 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) { } }; stmts.extend(item.block.stmts.into_iter()); + } else if is_pyfunction(&func.attrs) { + // This is done lazily so that empty #[pymodules] with an argument _: &PyModule + // don't raise an error. + let module_name = match module_name { + ModuleName::Ident(i) => i, + ModuleName::Wild(u) => { + return Err(syn::Error::new_spanned( + u, + "module name is required for `#[pyfunction]`", + )) + } + }; + + let function_wrapper_ident = function_wrapper_ident(&func.sig.ident); + let stmt: syn::Stmt = syn::parse_quote! { + #module_name.add_wrapped(&#function_wrapper_ident)?; + }; + stmts.push(stmt); } }; stmts.push(stmt.clone()); } func.block.stmts = stmts; + + Ok(()) } /// Transforms a rust fn arg parsed with syn into a method::FnArg @@ -124,6 +146,48 @@ fn extract_pyfn_attrs( Some((modname?, fnname?, fn_attrs)) } +/// Check if the function has a #[pyfunction] attribute +fn is_pyfunction(attrs: &Vec) -> bool { + for attr in attrs.iter() { + match attr.parse_meta() { + Ok(syn::Meta::Path(ref path)) if path.is_ident("pyfunction") => return true, + _ => {} + } + } + + return false; +} + +enum ModuleName<'a> { + Ident(&'a syn::Ident), + Wild(&'a syn::token::Underscore), +} + +fn get_module_name(sig: &syn::Signature) -> syn::Result { + // module is second argument to #[pymodule] + if sig.inputs.len() != 2 { + return Err(syn::Error::new_spanned( + &sig, + "#[pymodule] expects two arguments", + )); + } + + match &sig.inputs[1] { + syn::FnArg::Typed(syn::PatType { pat, .. }) => match &**pat { + syn::Pat::Ident(syn::PatIdent { ident, .. }) => Ok(ModuleName::Ident(ident)), + syn::Pat::Wild(syn::PatWild { + underscore_token: u, + .. + }) => Ok(ModuleName::Wild(u)), + _ => Err(syn::Error::new_spanned(pat, "expected #[pymodule] name")), + }, + syn::FnArg::Receiver(_) => Err(syn::Error::new_spanned( + &sig.inputs[1], + "expected module argument", + )), + } +} + /// Coordinates the naming of a the add-function-to-python-module function fn function_wrapper_ident(name: &Ident) -> Ident { // Make sure this ident matches the one of wrap_pyfunction diff --git a/pyo3cls/src/lib.rs b/pyo3cls/src/lib.rs index e85a645adc4..541392f3320 100644 --- a/pyo3cls/src/lib.rs +++ b/pyo3cls/src/lib.rs @@ -25,7 +25,9 @@ pub fn pymodule(attr: TokenStream, input: TokenStream) -> TokenStream { parse_macro_input!(attr as syn::Ident) }; - process_functions_in_module(&mut ast); + if let Err(err) = process_functions_in_module(&mut ast) { + return err.to_compile_error().into(); + } let doc = match get_doc(&ast.attrs, None, false) { Ok(doc) => doc, diff --git a/tests/test_module.rs b/tests/test_module.rs index cda0d8356c2..1bea8e8e330 100644 --- a/tests/test_module.rs +++ b/tests/test_module.rs @@ -148,6 +148,28 @@ fn test_raw_idents() { py_assert!(py, module, "module.move() == 42"); } +#[pymodule] +fn pyfunction_module(_py: Python, module: &PyModule) -> PyResult<()> { + #[pyfunction] + fn foobar() -> usize { + 101 + } + + Ok(()) +} + +#[test] +fn test_pyfunction_module() { + use pyo3::wrap_pymodule; + + let gil = Python::acquire_gil(); + let py = gil.python(); + + let module = wrap_pymodule!(pyfunction_module)(py); + + py_assert!(py, module, "module.foobar() == 101"); +} + #[pyfunction] fn subfunction() -> String { "Subfunction".to_string()