Skip to content

Commit 24a4b34

Browse files
committed
Support nested generics and where clauses
Instead of only looking for generics and where clauses on the immediate parent, we can traverse all the `ancestors` and collect all the generics and where clauses that may be in scope.
1 parent ec71899 commit 24a4b34

File tree

1 file changed

+186
-43
lines changed

1 file changed

+186
-43
lines changed

crates/ide-assists/src/handlers/extract_function.rs

+186-43
Original file line numberDiff line numberDiff line change
@@ -296,8 +296,8 @@ struct ContainerInfo {
296296
parent_loop: Option<SyntaxNode>,
297297
/// The function's return type, const's type etc.
298298
ret_type: Option<hir::Type>,
299-
generic_param_list: Option<ast::GenericParamList>,
300-
where_clause: Option<ast::WhereClause>,
299+
generic_param_lists: Vec<ast::GenericParamList>,
300+
where_clauses: Vec<ast::WhereClause>,
301301
}
302302

303303
/// Control flow that is exported from extracted function
@@ -754,17 +754,11 @@ impl FunctionBody {
754754
}
755755
};
756756

757-
let mut generic_param_list = None;
758-
let mut where_clause = None;
759-
760757
let (is_const, expr, ty) = loop {
761758
let anc = ancestors.next()?;
762759
break match_ast! {
763760
match anc {
764-
ast::ClosureExpr(closure) => {
765-
generic_param_list = closure.generic_param_list();
766-
(false, closure.body(), infer_expr_opt(closure.body()))
767-
},
761+
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())),
768762
ast::BlockExpr(block_expr) => {
769763
let (constness, block) = match block_expr.modifier() {
770764
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
@@ -783,8 +777,6 @@ impl FunctionBody {
783777
ret_ty = async_ret;
784778
}
785779
}
786-
generic_param_list = fn_.generic_param_list();
787-
where_clause = fn_.where_clause();
788780
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
789781
},
790782
ast::Static(statik) => {
@@ -829,13 +821,18 @@ impl FunctionBody {
829821
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
830822
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
831823
});
824+
825+
let parent = self.parent()?;
826+
let generic_param_lists = parent_generic_param_lists(&parent);
827+
let where_clauses = parent_where_clauses(&parent);
828+
832829
Some(ContainerInfo {
833830
is_in_tail,
834831
is_const,
835832
parent_loop,
836833
ret_type: ty,
837-
generic_param_list,
838-
where_clause,
834+
generic_param_lists,
835+
where_clauses,
839836
})
840837
}
841838

@@ -993,6 +990,26 @@ impl FunctionBody {
993990
}
994991
}
995992

993+
fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
994+
let mut where_clause: Vec<ast::WhereClause> = parent
995+
.ancestors()
996+
.filter_map(ast::AnyHasGenericParams::cast)
997+
.filter_map(|it| it.where_clause())
998+
.collect();
999+
where_clause.reverse();
1000+
where_clause
1001+
}
1002+
1003+
fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
1004+
let mut generic_param_list: Vec<ast::GenericParamList> = parent
1005+
.ancestors()
1006+
.filter_map(ast::AnyHasGenericParams::cast)
1007+
.filter_map(|it| it.generic_param_list())
1008+
.collect();
1009+
generic_param_list.reverse();
1010+
generic_param_list
1011+
}
1012+
9961013
/// checks if relevant var is used with `&mut` access inside body
9971014
fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
9981015
usages
@@ -1428,14 +1445,14 @@ fn format_function(
14281445

14291446
format_to!(fn_def, "{}", params);
14301447

1431-
if let Some(where_clause) = where_clause {
1432-
format_to!(fn_def, " {}", where_clause);
1433-
}
1434-
14351448
if let Some(ret_ty) = ret_ty {
14361449
format_to!(fn_def, " {}", ret_ty);
14371450
}
14381451

1452+
if let Some(where_clause) = where_clause {
1453+
format_to!(fn_def, " {}", where_clause);
1454+
}
1455+
14391456
format_to!(fn_def, " {}", body);
14401457

14411458
fn_def
@@ -1447,34 +1464,32 @@ fn make_generic_params_and_where_clause(
14471464
) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
14481465
let used_type_params = fun.type_params(ctx);
14491466

1450-
let generic_param_list = fun
1451-
.mods
1452-
.generic_param_list
1453-
.as_ref()
1454-
.map(|parent_params| make_generic_param_list(ctx, parent_params, &used_type_params))
1455-
.flatten();
1456-
1457-
let where_clause =
1458-
fun.mods.where_clause.as_ref().map(|parent_where_clause| {
1459-
make_where_clause(ctx, parent_where_clause, &used_type_params)
1460-
});
1467+
let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params);
1468+
let where_clause = make_where_clause(ctx, fun, &used_type_params);
14611469

14621470
(generic_param_list, where_clause)
14631471
}
14641472

14651473
fn make_generic_param_list(
14661474
ctx: &AssistContext,
1467-
parent_params: &ast::GenericParamList,
1475+
fun: &Function,
14681476
used_type_params: &[TypeParam],
14691477
) -> Option<ast::GenericParamList> {
1470-
let required_generic_params: Vec<ast::GenericParam> = parent_params
1471-
.generic_params()
1472-
.filter(|param| param_is_required(ctx, param, used_type_params))
1473-
.collect();
1474-
if required_generic_params.is_empty() {
1475-
None
1478+
let mut generic_params = fun
1479+
.mods
1480+
.generic_param_lists
1481+
.iter()
1482+
.flat_map(|parent_params| {
1483+
parent_params
1484+
.generic_params()
1485+
.filter(|param| param_is_required(ctx, param, used_type_params))
1486+
})
1487+
.peekable();
1488+
1489+
if generic_params.peek().is_some() {
1490+
Some(make::generic_param_list(generic_params))
14761491
} else {
1477-
Some(make::generic_param_list(required_generic_params))
1492+
None
14781493
}
14791494
}
14801495

@@ -1486,21 +1501,33 @@ fn param_is_required(
14861501
match param {
14871502
ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false,
14881503
ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
1489-
Some(def) => used_type_params.iter().contains(def),
1504+
Some(def) => used_type_params.contains(def),
14901505
_ => false,
14911506
},
14921507
}
14931508
}
14941509

14951510
fn make_where_clause(
14961511
ctx: &AssistContext,
1497-
parent_where_clause: &ast::WhereClause,
1512+
fun: &Function,
14981513
used_type_params: &[TypeParam],
1499-
) -> ast::WhereClause {
1500-
let preds = parent_where_clause
1501-
.predicates()
1502-
.filter(|pred| pred_is_required(ctx, pred, used_type_params));
1503-
make::where_clause(preds)
1514+
) -> Option<ast::WhereClause> {
1515+
let mut predicates = fun
1516+
.mods
1517+
.where_clauses
1518+
.iter()
1519+
.flat_map(|parent_where_clause| {
1520+
parent_where_clause
1521+
.predicates()
1522+
.filter(|pred| pred_is_required(ctx, pred, used_type_params))
1523+
})
1524+
.peekable();
1525+
1526+
if predicates.peek().is_some() {
1527+
Some(make::where_clause(predicates))
1528+
} else {
1529+
None
1530+
}
15041531
}
15051532

15061533
fn pred_is_required(
@@ -5149,6 +5176,122 @@ fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
51495176
fn $0fun_name<T>(i: T) where T: Debug {
51505177
foo(i);
51515178
}
5179+
"#,
5180+
);
5181+
}
5182+
5183+
#[test]
5184+
fn nested_generics() {
5185+
check_assist(
5186+
extract_function,
5187+
r#"
5188+
struct Struct<T: Into<i32>>(T);
5189+
impl <T: Into<i32> + Copy> Struct<T> {
5190+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5191+
let t = self.0;
5192+
$0t.into() + v.into()$0
5193+
}
5194+
}
5195+
"#,
5196+
r#"
5197+
struct Struct<T: Into<i32>>(T);
5198+
impl <T: Into<i32> + Copy> Struct<T> {
5199+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5200+
let t = self.0;
5201+
fun_name(t, v)
5202+
}
5203+
}
5204+
5205+
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
5206+
t.into() + v.into()
5207+
}
5208+
"#,
5209+
);
5210+
}
5211+
5212+
#[test]
5213+
fn filters_unused_nested_generics() {
5214+
check_assist(
5215+
extract_function,
5216+
r#"
5217+
struct Struct<T: Into<i32>, U: Debug>(T, U);
5218+
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
5219+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5220+
let t = self.0;
5221+
$0t.into() + v.into()$0
5222+
}
5223+
}
5224+
"#,
5225+
r#"
5226+
struct Struct<T: Into<i32>, U: Debug>(T, U);
5227+
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
5228+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5229+
let t = self.0;
5230+
fun_name(t, v)
5231+
}
5232+
}
5233+
5234+
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
5235+
t.into() + v.into()
5236+
}
5237+
"#,
5238+
);
5239+
}
5240+
5241+
#[test]
5242+
fn nested_where_clauses() {
5243+
check_assist(
5244+
extract_function,
5245+
r#"
5246+
struct Struct<T>(T) where T: Into<i32>;
5247+
impl <T> Struct<T> where T: Into<i32> + Copy {
5248+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5249+
let t = self.0;
5250+
$0t.into() + v.into()$0
5251+
}
5252+
}
5253+
"#,
5254+
r#"
5255+
struct Struct<T>(T) where T: Into<i32>;
5256+
impl <T> Struct<T> where T: Into<i32> + Copy {
5257+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5258+
let t = self.0;
5259+
fun_name(t, v)
5260+
}
5261+
}
5262+
5263+
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
5264+
t.into() + v.into()
5265+
}
5266+
"#,
5267+
);
5268+
}
5269+
5270+
#[test]
5271+
fn filters_unused_nested_where_clauses() {
5272+
check_assist(
5273+
extract_function,
5274+
r#"
5275+
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
5276+
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
5277+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5278+
let t = self.0;
5279+
$0t.into() + v.into()$0
5280+
}
5281+
}
5282+
"#,
5283+
r#"
5284+
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
5285+
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
5286+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5287+
let t = self.0;
5288+
fun_name(t, v)
5289+
}
5290+
}
5291+
5292+
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
5293+
t.into() + v.into()
5294+
}
51525295
"#,
51535296
);
51545297
}

0 commit comments

Comments
 (0)