Skip to content

Commit 56b8e89

Browse files
committed
Support nested generics and where clauses
Instead of only looking for generics and where clauses on the immediate parent, we can traverse all the `ancestors` and collect all the generics and where clauses that may be in scope.
1 parent ea7b8b8 commit 56b8e89

File tree

1 file changed

+186
-43
lines changed

1 file changed

+186
-43
lines changed

crates/ide-assists/src/handlers/extract_function.rs

+186-43
Original file line numberDiff line numberDiff line change
@@ -284,8 +284,8 @@ struct ContainerInfo {
284284
parent_loop: Option<SyntaxNode>,
285285
/// The function's return type, const's type etc.
286286
ret_type: Option<hir::Type>,
287-
generic_param_list: Option<ast::GenericParamList>,
288-
where_clause: Option<ast::WhereClause>,
287+
generic_param_lists: Vec<ast::GenericParamList>,
288+
where_clauses: Vec<ast::WhereClause>,
289289
}
290290

291291
/// Control flow that is exported from extracted function
@@ -715,17 +715,11 @@ impl FunctionBody {
715715
}
716716
};
717717

718-
let mut generic_param_list = None;
719-
let mut where_clause = None;
720-
721718
let (is_const, expr, ty) = loop {
722719
let anc = ancestors.next()?;
723720
break match_ast! {
724721
match anc {
725-
ast::ClosureExpr(closure) => {
726-
generic_param_list = closure.generic_param_list();
727-
(false, closure.body(), infer_expr_opt(closure.body()))
728-
},
722+
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body())),
729723
ast::BlockExpr(block_expr) => {
730724
let (constness, block) = match block_expr.modifier() {
731725
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
@@ -744,8 +738,6 @@ impl FunctionBody {
744738
ret_ty = async_ret;
745739
}
746740
}
747-
generic_param_list = fn_.generic_param_list();
748-
where_clause = fn_.where_clause();
749741
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
750742
},
751743
ast::Static(statik) => {
@@ -790,13 +782,18 @@ impl FunctionBody {
790782
container_tail.zip(self.tail_expr()).map_or(false, |(container_tail, body_tail)| {
791783
container_tail.syntax().text_range().contains_range(body_tail.syntax().text_range())
792784
});
785+
786+
let parent = self.parent()?;
787+
let generic_param_lists = parent_generic_param_lists(&parent);
788+
let where_clauses = parent_where_clauses(&parent);
789+
793790
Some(ContainerInfo {
794791
is_in_tail,
795792
is_const,
796793
parent_loop,
797794
ret_type: ty,
798-
generic_param_list,
799-
where_clause,
795+
generic_param_lists,
796+
where_clauses,
800797
})
801798
}
802799

@@ -954,6 +951,26 @@ impl FunctionBody {
954951
}
955952
}
956953

954+
fn parent_where_clauses(parent: &SyntaxNode) -> Vec<ast::WhereClause> {
955+
let mut where_clause: Vec<ast::WhereClause> = parent
956+
.ancestors()
957+
.filter_map(ast::AnyHasGenericParams::cast)
958+
.filter_map(|it| it.where_clause())
959+
.collect();
960+
where_clause.reverse();
961+
where_clause
962+
}
963+
964+
fn parent_generic_param_lists(parent: &SyntaxNode) -> Vec<ast::GenericParamList> {
965+
let mut generic_param_list: Vec<ast::GenericParamList> = parent
966+
.ancestors()
967+
.filter_map(ast::AnyHasGenericParams::cast)
968+
.filter_map(|it| it.generic_param_list())
969+
.collect();
970+
generic_param_list.reverse();
971+
generic_param_list
972+
}
973+
957974
/// checks if relevant var is used with `&mut` access inside body
958975
fn has_exclusive_usages(ctx: &AssistContext, usages: &LocalUsages, body: &FunctionBody) -> bool {
959976
usages
@@ -1393,14 +1410,14 @@ fn format_function(
13931410

13941411
format_to!(fn_def, "{}", params);
13951412

1396-
if let Some(where_clause) = where_clause {
1397-
format_to!(fn_def, " {}", where_clause);
1398-
}
1399-
14001413
if let Some(ret_ty) = ret_ty {
14011414
format_to!(fn_def, " {}", ret_ty);
14021415
}
14031416

1417+
if let Some(where_clause) = where_clause {
1418+
format_to!(fn_def, " {}", where_clause);
1419+
}
1420+
14041421
format_to!(fn_def, " {}", body);
14051422

14061423
fn_def
@@ -1412,34 +1429,32 @@ fn make_generic_params_and_where_clause(
14121429
) -> (Option<ast::GenericParamList>, Option<ast::WhereClause>) {
14131430
let used_type_params = fun.type_params(ctx);
14141431

1415-
let generic_param_list = fun
1416-
.mods
1417-
.generic_param_list
1418-
.as_ref()
1419-
.map(|parent_params| make_generic_param_list(ctx, parent_params, &used_type_params))
1420-
.flatten();
1421-
1422-
let where_clause =
1423-
fun.mods.where_clause.as_ref().map(|parent_where_clause| {
1424-
make_where_clause(ctx, parent_where_clause, &used_type_params)
1425-
});
1432+
let generic_param_list = make_generic_param_list(ctx, fun, &used_type_params);
1433+
let where_clause = make_where_clause(ctx, fun, &used_type_params);
14261434

14271435
(generic_param_list, where_clause)
14281436
}
14291437

14301438
fn make_generic_param_list(
14311439
ctx: &AssistContext,
1432-
parent_params: &ast::GenericParamList,
1440+
fun: &Function,
14331441
used_type_params: &[TypeParam],
14341442
) -> Option<ast::GenericParamList> {
1435-
let required_generic_params: Vec<ast::GenericParam> = parent_params
1436-
.generic_params()
1437-
.filter(|param| param_is_required(ctx, param, used_type_params))
1438-
.collect();
1439-
if required_generic_params.is_empty() {
1440-
None
1443+
let mut generic_params = fun
1444+
.mods
1445+
.generic_param_lists
1446+
.iter()
1447+
.flat_map(|parent_params| {
1448+
parent_params
1449+
.generic_params()
1450+
.filter(|param| param_is_required(ctx, param, used_type_params))
1451+
})
1452+
.peekable();
1453+
1454+
if generic_params.peek().is_some() {
1455+
Some(make::generic_param_list(generic_params))
14411456
} else {
1442-
Some(make::generic_param_list(required_generic_params))
1457+
None
14431458
}
14441459
}
14451460

@@ -1451,21 +1466,33 @@ fn param_is_required(
14511466
match param {
14521467
ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => false,
14531468
ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
1454-
Some(def) => used_type_params.iter().contains(def),
1469+
Some(def) => used_type_params.contains(def),
14551470
_ => false,
14561471
},
14571472
}
14581473
}
14591474

14601475
fn make_where_clause(
14611476
ctx: &AssistContext,
1462-
parent_where_clause: &ast::WhereClause,
1477+
fun: &Function,
14631478
used_type_params: &[TypeParam],
1464-
) -> ast::WhereClause {
1465-
let preds = parent_where_clause
1466-
.predicates()
1467-
.filter(|pred| pred_is_required(ctx, pred, used_type_params));
1468-
make::where_clause(preds)
1479+
) -> Option<ast::WhereClause> {
1480+
let mut predicates = fun
1481+
.mods
1482+
.where_clauses
1483+
.iter()
1484+
.flat_map(|parent_where_clause| {
1485+
parent_where_clause
1486+
.predicates()
1487+
.filter(|pred| pred_is_required(ctx, pred, used_type_params))
1488+
})
1489+
.peekable();
1490+
1491+
if predicates.peek().is_some() {
1492+
Some(make::where_clause(predicates))
1493+
} else {
1494+
None
1495+
}
14691496
}
14701497

14711498
fn pred_is_required(
@@ -5052,6 +5079,122 @@ fn func<T, U>(i: T, u: U) where T: Debug, U: Copy {
50525079
fn $0fun_name<T>(i: T) where T: Debug {
50535080
foo(i);
50545081
}
5082+
"#,
5083+
);
5084+
}
5085+
5086+
#[test]
5087+
fn nested_generics() {
5088+
check_assist(
5089+
extract_function,
5090+
r#"
5091+
struct Struct<T: Into<i32>>(T);
5092+
impl <T: Into<i32> + Copy> Struct<T> {
5093+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5094+
let t = self.0;
5095+
$0t.into() + v.into()$0
5096+
}
5097+
}
5098+
"#,
5099+
r#"
5100+
struct Struct<T: Into<i32>>(T);
5101+
impl <T: Into<i32> + Copy> Struct<T> {
5102+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5103+
let t = self.0;
5104+
fun_name(t, v)
5105+
}
5106+
}
5107+
5108+
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
5109+
t.into() + v.into()
5110+
}
5111+
"#,
5112+
);
5113+
}
5114+
5115+
#[test]
5116+
fn filters_unused_nested_generics() {
5117+
check_assist(
5118+
extract_function,
5119+
r#"
5120+
struct Struct<T: Into<i32>, U: Debug>(T, U);
5121+
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
5122+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5123+
let t = self.0;
5124+
$0t.into() + v.into()$0
5125+
}
5126+
}
5127+
"#,
5128+
r#"
5129+
struct Struct<T: Into<i32>, U: Debug>(T, U);
5130+
impl <T: Into<i32> + Copy, U: Debug> Struct<T, U> {
5131+
fn func<V: Into<i32>>(&self, v: V) -> i32 {
5132+
let t = self.0;
5133+
fun_name(t, v)
5134+
}
5135+
}
5136+
5137+
fn $0fun_name<T: Into<i32> + Copy, V: Into<i32>>(t: T, v: V) -> i32 {
5138+
t.into() + v.into()
5139+
}
5140+
"#,
5141+
);
5142+
}
5143+
5144+
#[test]
5145+
fn nested_where_clauses() {
5146+
check_assist(
5147+
extract_function,
5148+
r#"
5149+
struct Struct<T>(T) where T: Into<i32>;
5150+
impl <T> Struct<T> where T: Into<i32> + Copy {
5151+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5152+
let t = self.0;
5153+
$0t.into() + v.into()$0
5154+
}
5155+
}
5156+
"#,
5157+
r#"
5158+
struct Struct<T>(T) where T: Into<i32>;
5159+
impl <T> Struct<T> where T: Into<i32> + Copy {
5160+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5161+
let t = self.0;
5162+
fun_name(t, v)
5163+
}
5164+
}
5165+
5166+
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
5167+
t.into() + v.into()
5168+
}
5169+
"#,
5170+
);
5171+
}
5172+
5173+
#[test]
5174+
fn filters_unused_nested_where_clauses() {
5175+
check_assist(
5176+
extract_function,
5177+
r#"
5178+
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
5179+
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
5180+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5181+
let t = self.0;
5182+
$0t.into() + v.into()$0
5183+
}
5184+
}
5185+
"#,
5186+
r#"
5187+
struct Struct<T, U>(T, U) where T: Into<i32>, U: Debug;
5188+
impl <T, U> Struct<T, U> where T: Into<i32> + Copy, U: Debug {
5189+
fn func<V>(&self, v: V) -> i32 where V: Into<i32> {
5190+
let t = self.0;
5191+
fun_name(t, v)
5192+
}
5193+
}
5194+
5195+
fn $0fun_name<T, V>(t: T, v: V) -> i32 where T: Into<i32> + Copy, V: Into<i32> {
5196+
t.into() + v.into()
5197+
}
50555198
"#,
50565199
);
50575200
}

0 commit comments

Comments
 (0)