@@ -28,9 +28,11 @@ pub fn py_init(fnname: &Ident, name: &Ident, doc: syn::Lit) -> TokenStream {
28
28
}
29
29
30
30
/// Finds and takes care of the #[pyfn(...)] in `#[pymodule]`
31
- pub fn process_functions_in_module ( func : & mut syn:: ItemFn ) {
31
+ pub fn process_functions_in_module ( func : & mut syn:: ItemFn ) -> syn :: Result < ( ) > {
32
32
let mut stmts: Vec < syn:: Stmt > = Vec :: new ( ) ;
33
33
34
+ let module_name = get_module_name ( & func. sig ) ?;
35
+
34
36
for stmt in func. block . stmts . iter_mut ( ) {
35
37
if let syn:: Stmt :: Item ( syn:: Item :: Fn ( ref mut func) ) = stmt {
36
38
if let Some ( ( module_name, python_name, pyfn_attrs) ) =
@@ -45,12 +47,32 @@ pub fn process_functions_in_module(func: &mut syn::ItemFn) {
45
47
}
46
48
} ;
47
49
stmts. extend ( item. block . stmts . into_iter ( ) ) ;
50
+ } else if is_pyfunction ( & func. attrs ) {
51
+ // This is done lazily so that empty #[pymodules] with an argument _: &PyModule
52
+ // don't raise an error.
53
+ let module_name = match module_name {
54
+ ModuleName :: Ident ( i) => i,
55
+ ModuleName :: Wild ( u) => {
56
+ return Err ( syn:: Error :: new_spanned (
57
+ u,
58
+ "module name is required for `#[pyfunction]`" ,
59
+ ) )
60
+ }
61
+ } ;
62
+
63
+ let function_wrapper_ident = function_wrapper_ident ( & func. sig . ident ) ;
64
+ let stmt: syn:: Stmt = syn:: parse_quote! {
65
+ #module_name. add_wrapped( & #function_wrapper_ident) ?;
66
+ } ;
67
+ stmts. push ( stmt) ;
48
68
}
49
69
} ;
50
70
stmts. push ( stmt. clone ( ) ) ;
51
71
}
52
72
53
73
func. block . stmts = stmts;
74
+
75
+ Ok ( ( ) )
54
76
}
55
77
56
78
/// Transforms a rust fn arg parsed with syn into a method::FnArg
@@ -124,6 +146,48 @@ fn extract_pyfn_attrs(
124
146
Some ( ( modname?, fnname?, fn_attrs) )
125
147
}
126
148
149
+ /// Check if the function has a #[pyfunction] attribute
150
+ fn is_pyfunction ( attrs : & Vec < syn:: Attribute > ) -> bool {
151
+ for attr in attrs. iter ( ) {
152
+ match attr. parse_meta ( ) {
153
+ Ok ( syn:: Meta :: Path ( ref path) ) if path. is_ident ( "pyfunction" ) => return true ,
154
+ _ => { }
155
+ }
156
+ }
157
+
158
+ return false ;
159
+ }
160
+
161
+ enum ModuleName < ' a > {
162
+ Ident ( & ' a syn:: Ident ) ,
163
+ Wild ( & ' a syn:: token:: Underscore ) ,
164
+ }
165
+
166
+ fn get_module_name ( sig : & syn:: Signature ) -> syn:: Result < ModuleName > {
167
+ // module is second argument to #[pymodule]
168
+ if sig. inputs . len ( ) != 2 {
169
+ return Err ( syn:: Error :: new_spanned (
170
+ & sig,
171
+ "#[pymodule] expects two arguments" ,
172
+ ) ) ;
173
+ }
174
+
175
+ match & sig. inputs [ 1 ] {
176
+ syn:: FnArg :: Typed ( syn:: PatType { pat, .. } ) => match & * * pat {
177
+ syn:: Pat :: Ident ( syn:: PatIdent { ident, .. } ) => Ok ( ModuleName :: Ident ( ident) ) ,
178
+ syn:: Pat :: Wild ( syn:: PatWild {
179
+ underscore_token : u,
180
+ ..
181
+ } ) => Ok ( ModuleName :: Wild ( u) ) ,
182
+ _ => Err ( syn:: Error :: new_spanned ( pat, "expected #[pymodule] name" ) ) ,
183
+ } ,
184
+ syn:: FnArg :: Receiver ( _) => Err ( syn:: Error :: new_spanned (
185
+ & sig. inputs [ 1 ] ,
186
+ "expected module argument" ,
187
+ ) ) ,
188
+ }
189
+ }
190
+
127
191
/// Coordinates the naming of a the add-function-to-python-module function
128
192
fn function_wrapper_ident ( name : & Ident ) -> Ident {
129
193
// Make sure this ident matches the one of wrap_pyfunction
0 commit comments