Skip to content

Commit 530591f

Browse files
committed
Filter generic param list to only params used in function
1 parent f7de23c commit 530591f

File tree

2 files changed

+151
-14
lines changed

2 files changed

+151
-14
lines changed

crates/hir/src/lib.rs

+9
Original file line numberDiff line numberDiff line change
@@ -3280,6 +3280,15 @@ impl Type {
32803280
let tys = hir_ty::replace_errors_with_variables(&(self.ty.clone(), to.ty.clone()));
32813281
hir_ty::could_coerce(db, self.env.clone(), &tys)
32823282
}
3283+
3284+
pub fn as_type_param(&self, db: &dyn HirDatabase) -> Option<TypeParam> {
3285+
match self.ty.kind(Interner) {
3286+
TyKind::Placeholder(p) => Some(TypeParam {
3287+
id: TypeParamId::from_unchecked(hir_ty::from_placeholder_idx(db, *p)),
3288+
}),
3289+
_ => None,
3290+
}
3291+
}
32833292
}
32843293

32853294
#[derive(Debug)]

crates/ide-assists/src/handlers/extract_function.rs

+142-14
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::iter;
22

33
use ast::make;
44
use either::Either;
5-
use hir::{HirDisplay, InFile, Local, ModuleDef, Semantics, TypeInfo};
5+
use hir::{HirDisplay, InFile, Local, ModuleDef, PathResolution, Semantics, TypeInfo, TypeParam};
66
use ide_db::{
77
defs::{Definition, NameRefClass},
88
famous_defs::FamousDefs,
@@ -469,6 +469,24 @@ impl FunctionBody {
469469
}
470470
}
471471

472+
fn descendants(&self) -> impl Iterator<Item = SyntaxNode> {
473+
match self {
474+
FunctionBody::Expr(expr) => expr.syntax().descendants(),
475+
FunctionBody::Span { parent, .. } => parent.syntax().descendants(),
476+
}
477+
}
478+
479+
fn descendant_paths(&self) -> impl Iterator<Item = ast::Path> {
480+
self.descendants().filter_map(|node| {
481+
match_ast! {
482+
match node {
483+
ast::Path(it) => Some(it),
484+
_ => None
485+
}
486+
}
487+
})
488+
}
489+
472490
fn from_expr(expr: ast::Expr) -> Option<Self> {
473491
match expr {
474492
ast::Expr::BreakExpr(it) => it.expr().map(Self::Expr),
@@ -678,11 +696,16 @@ impl FunctionBody {
678696
parent_loop.get_or_insert(loop_.syntax().clone());
679697
}
680698
};
681-
let (is_const, expr, ty, generic_param_list, where_clause) = loop {
699+
700+
let (mut generic_param_list, mut where_clause) = (None, None);
701+
let (is_const, expr, ty) = loop {
682702
let anc = ancestors.next()?;
683703
break match_ast! {
684704
match anc {
685-
ast::ClosureExpr(closure) => (false, closure.body(), infer_expr_opt(closure.body()), closure.generic_param_list(), None),
705+
ast::ClosureExpr(closure) => {
706+
generic_param_list = closure.generic_param_list();
707+
(false, closure.body(), infer_expr_opt(closure.body()))
708+
},
686709
ast::BlockExpr(block_expr) => {
687710
let (constness, block) = match block_expr.modifier() {
688711
Some(ast::BlockModifier::Const(_)) => (true, block_expr),
@@ -691,7 +714,7 @@ impl FunctionBody {
691714
_ => continue,
692715
};
693716
let expr = Some(ast::Expr::BlockExpr(block));
694-
(constness, expr.clone(), infer_expr_opt(expr), None, None)
717+
(constness, expr.clone(), infer_expr_opt(expr))
695718
},
696719
ast::Fn(fn_) => {
697720
let func = sema.to_def(&fn_)?;
@@ -701,23 +724,25 @@ impl FunctionBody {
701724
ret_ty = async_ret;
702725
}
703726
}
704-
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty), fn_.generic_param_list(), fn_.where_clause())
727+
generic_param_list = fn_.generic_param_list();
728+
where_clause = fn_.where_clause();
729+
(fn_.const_token().is_some(), fn_.body().map(ast::Expr::BlockExpr), Some(ret_ty))
705730
},
706731
ast::Static(statik) => {
707-
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)), None, None)
732+
(true, statik.body(), Some(sema.to_def(&statik)?.ty(sema.db)))
708733
},
709734
ast::ConstArg(ca) => {
710-
(true, ca.expr(), infer_expr_opt(ca.expr()), None, None)
735+
(true, ca.expr(), infer_expr_opt(ca.expr()))
711736
},
712737
ast::Const(konst) => {
713-
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)), None, None)
738+
(true, konst.body(), Some(sema.to_def(&konst)?.ty(sema.db)))
714739
},
715740
ast::ConstParam(cp) => {
716-
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)), None, None)
741+
(true, cp.default_val(), Some(sema.to_def(&cp)?.ty(sema.db)))
717742
},
718743
ast::ConstBlockPat(cbp) => {
719744
let expr = cbp.block_expr().map(ast::Expr::BlockExpr);
720-
(true, expr.clone(), infer_expr_opt(expr), None, None)
745+
(true, expr.clone(), infer_expr_opt(expr))
721746
},
722747
ast::Variant(__) => return None,
723748
ast::Meta(__) => return None,
@@ -1321,8 +1346,7 @@ fn format_function(
13211346
let const_kw = if fun.mods.is_const { "const " } else { "" };
13221347
let async_kw = if fun.control_flow.is_async { "async " } else { "" };
13231348
let unsafe_kw = if fun.control_flow.is_unsafe { "unsafe " } else { "" };
1324-
let generic_params = format_generic_param_list(fun);
1325-
let where_clause = format_where_clause(fun);
1349+
let (generic_params, where_clause) = format_generic_params_and_where_clause(ctx, fun);
13261350
match ctx.config.snippet_cap {
13271351
Some(_) => format_to!(
13281352
fn_def,
@@ -1357,13 +1381,50 @@ fn format_function(
13571381
fn_def
13581382
}
13591383

1360-
fn format_generic_param_list(fun: &Function) -> String {
1384+
fn format_generic_params_and_where_clause(ctx: &AssistContext, fun: &Function) -> (String, String) {
1385+
(format_generic_param_list(fun, ctx), format_where_clause(fun))
1386+
}
1387+
1388+
fn format_generic_param_list(fun: &Function, ctx: &AssistContext) -> String {
1389+
let type_params_in_descendant_paths =
1390+
fun.body.descendant_paths().filter_map(|it| match ctx.sema.resolve_path(&it) {
1391+
Some(PathResolution::TypeParam(type_param)) => Some(type_param),
1392+
_ => None,
1393+
});
1394+
1395+
let type_params_in_params = fun.params.iter().filter_map(|p| p.ty.as_type_param(ctx.db()));
1396+
1397+
let used_type_params: Vec<TypeParam> =
1398+
type_params_in_descendant_paths.chain(type_params_in_params).collect();
1399+
13611400
match &fun.mods.generic_param_list {
1362-
Some(it) => format!("{}", it),
1401+
Some(list) => {
1402+
let filtered_generic_params = filter_generic_param_list(ctx, list, used_type_params);
1403+
if filtered_generic_params.is_empty() {
1404+
return "".to_string();
1405+
}
1406+
format!("{}", make::generic_param_list(filtered_generic_params))
1407+
}
13631408
None => "".to_string(),
13641409
}
13651410
}
13661411

1412+
fn filter_generic_param_list(
1413+
ctx: &AssistContext,
1414+
list: &ast::GenericParamList,
1415+
used_type_params: Vec<TypeParam>,
1416+
) -> Vec<ast::GenericParam> {
1417+
list.generic_params()
1418+
.filter(|p| match p {
1419+
ast::GenericParam::ConstParam(_) | ast::GenericParam::LifetimeParam(_) => true,
1420+
ast::GenericParam::TypeParam(type_param) => match &ctx.sema.to_def(type_param) {
1421+
Some(def) => used_type_params.iter().contains(def),
1422+
_ => false,
1423+
},
1424+
})
1425+
.collect()
1426+
}
1427+
13671428
fn format_where_clause(fun: &Function) -> String {
13681429
match &fun.mods.where_clause {
13691430
Some(it) => format!(" {}", it),
@@ -4764,6 +4825,73 @@ fn $0fun_name<T: Debug>(i: T) {
47644825
);
47654826
}
47664827

4828+
#[test]
4829+
fn preserve_generics_from_body() {
4830+
check_assist(
4831+
extract_function,
4832+
r#"
4833+
fn func<T: Default>() -> T {
4834+
$0T::default()$0
4835+
}
4836+
"#,
4837+
r#"
4838+
fn func<T: Default>() -> T {
4839+
fun_name()
4840+
}
4841+
4842+
fn $0fun_name<T: Default>() -> T {
4843+
T::default()
4844+
}
4845+
"#,
4846+
);
4847+
}
4848+
4849+
#[test]
4850+
fn filter_unused_generics() {
4851+
check_assist(
4852+
extract_function,
4853+
r#"
4854+
fn func<T: Debug, U: Copy>(i: T, u: U) {
4855+
bar(u);
4856+
$0foo(i);$0
4857+
}
4858+
"#,
4859+
r#"
4860+
fn func<T: Debug, U: Copy>(i: T, u: U) {
4861+
bar(u);
4862+
fun_name(i);
4863+
}
4864+
4865+
fn $0fun_name<T: Debug>(i: T) {
4866+
foo(i);
4867+
}
4868+
"#,
4869+
);
4870+
}
4871+
4872+
#[test]
4873+
fn empty_generic_param_list() {
4874+
check_assist(
4875+
extract_function,
4876+
r#"
4877+
fn func<T: Debug>(t: T, i: u32) {
4878+
bar(t);
4879+
$0foo(i);$0
4880+
}
4881+
"#,
4882+
r#"
4883+
fn func<T: Debug>(t: T, i: u32) {
4884+
bar(t);
4885+
fun_name(i);
4886+
}
4887+
4888+
fn $0fun_name(i: u32) {
4889+
foo(i);
4890+
}
4891+
"#,
4892+
);
4893+
}
4894+
47674895
#[test]
47684896
fn preserve_where_clause() {
47694897
check_assist(

0 commit comments

Comments
 (0)