@@ -4,31 +4,33 @@ use rustpython_parser::{
4
4
self ,
5
5
ast:: { Mod , Stmt } ,
6
6
} ;
7
+ use std:: collections:: HashMap ;
7
8
use std:: collections:: HashSet ;
8
9
use std:: fs;
9
- use std:: { collections:: HashMap , sync:: Arc } ;
10
10
11
11
use crate :: ast_visit;
12
12
use crate :: indexing;
13
13
use crate :: package_discovery;
14
14
15
- pub type Imports = HashMap < String , HashSet < String > > ;
15
+ pub type Imports < ' a > = HashMap < & ' a str , HashSet < & ' a str > > ;
16
16
17
- pub fn discover_imports ( modules_by_pypath : & indexing:: ModulesByPypath ) -> Result < Imports > {
17
+ pub fn discover_imports < ' a > (
18
+ modules_by_pypath : & ' a indexing:: ModulesByPypath ,
19
+ ) -> Result < Imports < ' a > > {
18
20
modules_by_pypath
19
21
. values ( )
20
22
. par_bridge ( )
21
23
. map ( |module| {
22
- let imports = get_imports_for_module ( Arc :: clone ( module) , modules_by_pypath) ?;
23
- Ok ( ( module. pypath . clone ( ) , imports) )
24
+ let imports = get_imports_for_module ( module, modules_by_pypath) ?;
25
+ Ok ( ( module. pypath . as_str ( ) , imports) )
24
26
} )
25
27
. collect :: < Result < Imports > > ( )
26
28
}
27
29
28
- fn get_imports_for_module (
29
- module : Arc < package_discovery:: Module > ,
30
- modules_by_pypath : & indexing:: ModulesByPypath ,
31
- ) -> Result < HashSet < String > > {
30
+ fn get_imports_for_module < ' a > (
31
+ module : & ' a package_discovery:: Module ,
32
+ modules_by_pypath : & ' a indexing:: ModulesByPypath ,
33
+ ) -> Result < HashSet < & ' a str > > {
32
34
let code = fs:: read_to_string ( & module. path ) ?;
33
35
let ast = rustpython_parser:: parse (
34
36
& code,
@@ -44,7 +46,7 @@ fn get_imports_for_module(
44
46
} ;
45
47
46
48
let mut visitor = ImportVisitor {
47
- module : Arc :: clone ( & module ) ,
49
+ module,
48
50
modules_by_pypath,
49
51
imports : HashSet :: new ( ) ,
50
52
} ;
@@ -54,9 +56,9 @@ fn get_imports_for_module(
54
56
}
55
57
56
58
struct ImportVisitor < ' a > {
57
- module : Arc < package_discovery:: Module > ,
58
- modules_by_pypath : & ' a indexing:: ModulesByPypath ,
59
- imports : HashSet < String > ,
59
+ module : & ' a package_discovery:: Module ,
60
+ modules_by_pypath : & ' a indexing:: ModulesByPypath < ' a > ,
61
+ imports : HashSet < & ' a str > ,
60
62
}
61
63
62
64
impl < ' a > ast_visit:: StatementVisitor for ImportVisitor < ' a > {
@@ -65,9 +67,9 @@ impl<'a> ast_visit::StatementVisitor for ImportVisitor<'a> {
65
67
Stmt :: Import ( stmt) => {
66
68
for name in stmt. names . iter ( ) {
67
69
for pypath in [ name. name . to_string ( ) , format ! ( "{}.__init__" , name. name) ] {
68
- match self . modules_by_pypath . get ( & pypath) {
70
+ match self . modules_by_pypath . get ( pypath. as_str ( ) ) {
69
71
Some ( imported_module) => {
70
- self . imports . insert ( imported_module. pypath . clone ( ) ) ;
72
+ self . imports . insert ( imported_module. pypath . as_str ( ) ) ;
71
73
}
72
74
None => continue ,
73
75
}
@@ -116,9 +118,9 @@ impl<'a> ast_visit::StatementVisitor for ImportVisitor<'a> {
116
118
format ! ( "{}.{}.__init__" , & pypath_prefix, name. name) ,
117
119
format ! ( "{}.__init__" , & pypath_prefix) ,
118
120
] {
119
- match self . modules_by_pypath . get ( & pypath) {
121
+ match self . modules_by_pypath . get ( pypath. as_str ( ) ) {
120
122
Some ( imported_module) => {
121
- self . imports . insert ( imported_module. pypath . clone ( ) ) ;
123
+ self . imports . insert ( imported_module. pypath . as_str ( ) ) ;
122
124
break ;
123
125
}
124
126
None => continue ,
@@ -143,10 +145,10 @@ mod tests {
143
145
fn test_get_imports_for_module ( ) {
144
146
let root_package_path = Path :: new ( "./example" ) ;
145
147
let root_package = package_discovery:: discover_package ( root_package_path) . unwrap ( ) ;
146
- let modules_by_pypath = indexing:: get_modules_by_pypath ( Arc :: clone ( & root_package) ) . unwrap ( ) ;
148
+ let modules_by_pypath = indexing:: get_modules_by_pypath ( & root_package) . unwrap ( ) ;
147
149
148
150
let module = modules_by_pypath. get ( "example.__init__" ) . unwrap ( ) ;
149
- let imports = get_imports_for_module ( Arc :: clone ( module) , & modules_by_pypath) . unwrap ( ) ;
151
+ let imports = get_imports_for_module ( module, & modules_by_pypath) . unwrap ( ) ;
150
152
assert_eq ! (
151
153
imports,
152
154
[
@@ -167,12 +169,11 @@ mod tests {
167
169
"example.child5.__init__" ,
168
170
]
169
171
. into_iter( )
170
- . map( |i| i. to_string( ) )
171
172
. collect:: <HashSet <_>>( )
172
173
) ;
173
174
174
175
let module = modules_by_pypath. get ( "example.child.__init__" ) . unwrap ( ) ;
175
- let imports = get_imports_for_module ( Arc :: clone ( module) , & modules_by_pypath) . unwrap ( ) ;
176
+ let imports = get_imports_for_module ( module, & modules_by_pypath) . unwrap ( ) ;
176
177
assert_eq ! (
177
178
imports,
178
179
[
@@ -193,12 +194,11 @@ mod tests {
193
194
"example.child5.__init__" ,
194
195
]
195
196
. into_iter( )
196
- . map( |i| i. to_string( ) )
197
197
. collect:: <HashSet <_>>( )
198
198
) ;
199
199
200
200
let module = modules_by_pypath. get ( "example.z" ) . unwrap ( ) ;
201
- let imports = get_imports_for_module ( Arc :: clone ( module) , & modules_by_pypath) . unwrap ( ) ;
201
+ let imports = get_imports_for_module ( module, & modules_by_pypath) . unwrap ( ) ;
202
202
assert_eq ! (
203
203
imports,
204
204
[
@@ -219,12 +219,11 @@ mod tests {
219
219
"example.child5.__init__" ,
220
220
]
221
221
. into_iter( )
222
- . map( |i| i. to_string( ) )
223
222
. collect:: <HashSet <_>>( )
224
223
) ;
225
224
226
225
let module = modules_by_pypath. get ( "example.child.c_z" ) . unwrap ( ) ;
227
- let imports = get_imports_for_module ( Arc :: clone ( module) , & modules_by_pypath) . unwrap ( ) ;
226
+ let imports = get_imports_for_module ( module, & modules_by_pypath) . unwrap ( ) ;
228
227
assert_eq ! (
229
228
imports,
230
229
[
@@ -245,7 +244,6 @@ mod tests {
245
244
"example.child5.__init__" ,
246
245
]
247
246
. into_iter( )
248
- . map( |i| i. to_string( ) )
249
247
. collect:: <HashSet <_>>( )
250
248
) ;
251
249
}
0 commit comments