@@ -2,7 +2,7 @@ use std::iter;
2
2
3
3
use ast:: make;
4
4
use either:: Either ;
5
- use hir:: { HirDisplay , InFile , Local , ModuleDef , Semantics , TypeInfo } ;
5
+ use hir:: { HirDisplay , InFile , Local , ModuleDef , PathResolution , Semantics , TypeInfo , TypeParam } ;
6
6
use ide_db:: {
7
7
defs:: { Definition , NameRefClass } ,
8
8
famous_defs:: FamousDefs ,
@@ -469,6 +469,24 @@ impl FunctionBody {
469
469
}
470
470
}
471
471
472
+ fn descendants ( & self ) -> impl Iterator < Item = SyntaxNode > {
473
+ match self {
474
+ FunctionBody :: Expr ( expr) => expr. syntax ( ) . descendants ( ) ,
475
+ FunctionBody :: Span { parent, .. } => parent. syntax ( ) . descendants ( ) ,
476
+ }
477
+ }
478
+
479
+ fn descendant_paths ( & self ) -> impl Iterator < Item = ast:: Path > {
480
+ self . descendants ( ) . filter_map ( |node| {
481
+ match_ast ! {
482
+ match node {
483
+ ast:: Path ( it) => Some ( it) ,
484
+ _ => None
485
+ }
486
+ }
487
+ } )
488
+ }
489
+
472
490
fn from_expr ( expr : ast:: Expr ) -> Option < Self > {
473
491
match expr {
474
492
ast:: Expr :: BreakExpr ( it) => it. expr ( ) . map ( Self :: Expr ) ,
@@ -678,11 +696,16 @@ impl FunctionBody {
678
696
parent_loop. get_or_insert ( loop_. syntax ( ) . clone ( ) ) ;
679
697
}
680
698
} ;
681
- let ( is_const, expr, ty, generic_param_list, where_clause) = loop {
699
+
700
+ let ( mut generic_param_list, mut where_clause) = ( None , None ) ;
701
+ let ( is_const, expr, ty) = loop {
682
702
let anc = ancestors. next ( ) ?;
683
703
break match_ast ! {
684
704
match anc {
685
- ast:: ClosureExpr ( closure) => ( false , closure. body( ) , infer_expr_opt( closure. body( ) ) , closure. generic_param_list( ) , None ) ,
705
+ ast:: ClosureExpr ( closure) => {
706
+ generic_param_list = closure. generic_param_list( ) ;
707
+ ( false , closure. body( ) , infer_expr_opt( closure. body( ) ) )
708
+ } ,
686
709
ast:: BlockExpr ( block_expr) => {
687
710
let ( constness, block) = match block_expr. modifier( ) {
688
711
Some ( ast:: BlockModifier :: Const ( _) ) => ( true , block_expr) ,
@@ -691,7 +714,7 @@ impl FunctionBody {
691
714
_ => continue ,
692
715
} ;
693
716
let expr = Some ( ast:: Expr :: BlockExpr ( block) ) ;
694
- ( constness, expr. clone( ) , infer_expr_opt( expr) , None , None )
717
+ ( constness, expr. clone( ) , infer_expr_opt( expr) )
695
718
} ,
696
719
ast:: Fn ( fn_) => {
697
720
let func = sema. to_def( & fn_) ?;
@@ -701,23 +724,25 @@ impl FunctionBody {
701
724
ret_ty = async_ret;
702
725
}
703
726
}
704
- ( fn_. const_token( ) . is_some( ) , fn_. body( ) . map( ast:: Expr :: BlockExpr ) , Some ( ret_ty) , fn_. generic_param_list( ) , fn_. where_clause( ) )
727
+ generic_param_list = fn_. generic_param_list( ) ;
728
+ where_clause = fn_. where_clause( ) ;
729
+ ( fn_. const_token( ) . is_some( ) , fn_. body( ) . map( ast:: Expr :: BlockExpr ) , Some ( ret_ty) )
705
730
} ,
706
731
ast:: Static ( statik) => {
707
- ( true , statik. body( ) , Some ( sema. to_def( & statik) ?. ty( sema. db) ) , None , None )
732
+ ( true , statik. body( ) , Some ( sema. to_def( & statik) ?. ty( sema. db) ) )
708
733
} ,
709
734
ast:: ConstArg ( ca) => {
710
- ( true , ca. expr( ) , infer_expr_opt( ca. expr( ) ) , None , None )
735
+ ( true , ca. expr( ) , infer_expr_opt( ca. expr( ) ) )
711
736
} ,
712
737
ast:: Const ( konst) => {
713
- ( true , konst. body( ) , Some ( sema. to_def( & konst) ?. ty( sema. db) ) , None , None )
738
+ ( true , konst. body( ) , Some ( sema. to_def( & konst) ?. ty( sema. db) ) )
714
739
} ,
715
740
ast:: ConstParam ( cp) => {
716
- ( true , cp. default_val( ) , Some ( sema. to_def( & cp) ?. ty( sema. db) ) , None , None )
741
+ ( true , cp. default_val( ) , Some ( sema. to_def( & cp) ?. ty( sema. db) ) )
717
742
} ,
718
743
ast:: ConstBlockPat ( cbp) => {
719
744
let expr = cbp. block_expr( ) . map( ast:: Expr :: BlockExpr ) ;
720
- ( true , expr. clone( ) , infer_expr_opt( expr) , None , None )
745
+ ( true , expr. clone( ) , infer_expr_opt( expr) )
721
746
} ,
722
747
ast:: Variant ( __) => return None ,
723
748
ast:: Meta ( __) => return None ,
@@ -1321,8 +1346,7 @@ fn format_function(
1321
1346
let const_kw = if fun. mods . is_const { "const " } else { "" } ;
1322
1347
let async_kw = if fun. control_flow . is_async { "async " } else { "" } ;
1323
1348
let unsafe_kw = if fun. control_flow . is_unsafe { "unsafe " } else { "" } ;
1324
- let generic_params = format_generic_param_list ( fun) ;
1325
- let where_clause = format_where_clause ( fun) ;
1349
+ let ( generic_params, where_clause) = format_generic_params_and_where_clause ( ctx, fun) ;
1326
1350
match ctx. config . snippet_cap {
1327
1351
Some ( _) => format_to ! (
1328
1352
fn_def,
@@ -1357,13 +1381,50 @@ fn format_function(
1357
1381
fn_def
1358
1382
}
1359
1383
1360
- fn format_generic_param_list ( fun : & Function ) -> String {
1384
+ fn format_generic_params_and_where_clause ( ctx : & AssistContext , fun : & Function ) -> ( String , String ) {
1385
+ ( format_generic_param_list ( fun, ctx) , format_where_clause ( fun) )
1386
+ }
1387
+
1388
+ fn format_generic_param_list ( fun : & Function , ctx : & AssistContext ) -> String {
1389
+ let type_params_in_descendant_paths =
1390
+ fun. body . descendant_paths ( ) . filter_map ( |it| match ctx. sema . resolve_path ( & it) {
1391
+ Some ( PathResolution :: TypeParam ( type_param) ) => Some ( type_param) ,
1392
+ _ => None ,
1393
+ } ) ;
1394
+
1395
+ let type_params_in_params = fun. params . iter ( ) . filter_map ( |p| p. ty . as_type_param ( ctx. db ( ) ) ) ;
1396
+
1397
+ let used_type_params: Vec < TypeParam > =
1398
+ type_params_in_descendant_paths. chain ( type_params_in_params) . collect ( ) ;
1399
+
1361
1400
match & fun. mods . generic_param_list {
1362
- Some ( it) => format ! ( "{}" , it) ,
1401
+ Some ( list) => {
1402
+ let filtered_generic_params = filter_generic_param_list ( ctx, list, used_type_params) ;
1403
+ if filtered_generic_params. is_empty ( ) {
1404
+ return "" . to_string ( ) ;
1405
+ }
1406
+ format ! ( "{}" , make:: generic_param_list( filtered_generic_params) )
1407
+ }
1363
1408
None => "" . to_string ( ) ,
1364
1409
}
1365
1410
}
1366
1411
1412
+ fn filter_generic_param_list (
1413
+ ctx : & AssistContext ,
1414
+ list : & ast:: GenericParamList ,
1415
+ used_type_params : Vec < TypeParam > ,
1416
+ ) -> Vec < ast:: GenericParam > {
1417
+ list. generic_params ( )
1418
+ . filter ( |p| match p {
1419
+ ast:: GenericParam :: ConstParam ( _) | ast:: GenericParam :: LifetimeParam ( _) => true ,
1420
+ ast:: GenericParam :: TypeParam ( type_param) => match & ctx. sema . to_def ( type_param) {
1421
+ Some ( def) => used_type_params. iter ( ) . contains ( def) ,
1422
+ _ => false ,
1423
+ } ,
1424
+ } )
1425
+ . collect ( )
1426
+ }
1427
+
1367
1428
fn format_where_clause ( fun : & Function ) -> String {
1368
1429
match & fun. mods . where_clause {
1369
1430
Some ( it) => format ! ( " {}" , it) ,
@@ -4764,6 +4825,73 @@ fn $0fun_name<T: Debug>(i: T) {
4764
4825
) ;
4765
4826
}
4766
4827
4828
+ #[ test]
4829
+ fn preserve_generics_from_body ( ) {
4830
+ check_assist (
4831
+ extract_function,
4832
+ r#"
4833
+ fn func<T: Default>() -> T {
4834
+ $0T::default()$0
4835
+ }
4836
+ "# ,
4837
+ r#"
4838
+ fn func<T: Default>() -> T {
4839
+ fun_name()
4840
+ }
4841
+
4842
+ fn $0fun_name<T: Default>() -> T {
4843
+ T::default()
4844
+ }
4845
+ "# ,
4846
+ ) ;
4847
+ }
4848
+
4849
+ #[ test]
4850
+ fn filter_unused_generics ( ) {
4851
+ check_assist (
4852
+ extract_function,
4853
+ r#"
4854
+ fn func<T: Debug, U: Copy>(i: T, u: U) {
4855
+ bar(u);
4856
+ $0foo(i);$0
4857
+ }
4858
+ "# ,
4859
+ r#"
4860
+ fn func<T: Debug, U: Copy>(i: T, u: U) {
4861
+ bar(u);
4862
+ fun_name(i);
4863
+ }
4864
+
4865
+ fn $0fun_name<T: Debug>(i: T) {
4866
+ foo(i);
4867
+ }
4868
+ "# ,
4869
+ ) ;
4870
+ }
4871
+
4872
+ #[ test]
4873
+ fn empty_generic_param_list ( ) {
4874
+ check_assist (
4875
+ extract_function,
4876
+ r#"
4877
+ fn func<T: Debug>(t: T, i: u32) {
4878
+ bar(t);
4879
+ $0foo(i);$0
4880
+ }
4881
+ "# ,
4882
+ r#"
4883
+ fn func<T: Debug>(t: T, i: u32) {
4884
+ bar(t);
4885
+ fun_name(i);
4886
+ }
4887
+
4888
+ fn $0fun_name(i: u32) {
4889
+ foo(i);
4890
+ }
4891
+ "# ,
4892
+ ) ;
4893
+ }
4894
+
4767
4895
#[ test]
4768
4896
fn preserve_where_clause ( ) {
4769
4897
check_assist (
0 commit comments