Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 76 additions & 14 deletions crates/ide-assists/src/handlers/generate_enum_is_method.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use std::slice;

use ide_db::assists::GroupLabel;
use itertools::Itertools;
use stdx::to_lower_snake_case;
use syntax::ast::HasVisibility;
use syntax::ast::{self, AstNode, HasName};

use crate::utils;
use crate::{
AssistContext, AssistId, Assists,
utils::{add_method_to_adt, find_struct_impl},
Expand Down Expand Up @@ -41,20 +41,20 @@ use crate::{
// ```
pub(crate) fn generate_enum_is_method(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
let variant = ctx.find_node_at_offset::<ast::Variant>()?;
let variant_name = variant.name()?;
let parent_enum = ast::Adt::Enum(variant.parent_enum());
let pattern_suffix = match variant.kind() {
ast::StructKind::Record(_) => " { .. }",
ast::StructKind::Tuple(_) => "(..)",
ast::StructKind::Unit => "",
};

let variants = variant
.parent_enum()
.variant_list()?
.variants()
.filter(utils::selected(ctx.selection_trimmed()))
.collect::<Vec<_>>();
let methods = variants.iter().map(Method::new).collect::<Option<Vec<_>>>()?;
let enum_name = parent_enum.name()?;
let enum_lowercase_name = to_lower_snake_case(&enum_name.to_string()).replace('_', " ");
let fn_name = format!("is_{}", &to_lower_snake_case(&variant_name.text()));
let fn_names = methods.iter().map(|it| it.fn_name.clone()).collect::<Vec<_>>();

// Return early if we've found an existing new fn
let impl_def = find_struct_impl(ctx, &parent_enum, slice::from_ref(&fn_name))?;
let impl_def = find_struct_impl(ctx, &parent_enum, &fn_names)?;

let target = variant.syntax().text_range();
acc.add_group(
Expand All @@ -64,21 +64,47 @@ pub(crate) fn generate_enum_is_method(acc: &mut Assists, ctx: &AssistContext<'_>
target,
|builder| {
let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{v} "));
let method = format!(
" /// Returns `true` if the {enum_lowercase_name} is [`{variant_name}`].
let method = methods
.iter()
.map(|Method { pattern_suffix, fn_name, variant_name }| {
format!(
" \
/// Returns `true` if the {enum_lowercase_name} is [`{variant_name}`].
///
/// [`{variant_name}`]: {enum_name}::{variant_name}
#[must_use]
{vis}fn {fn_name}(&self) -> bool {{
matches!(self, Self::{variant_name}{pattern_suffix})
}}",
);
)
})
.join("\n\n");

add_method_to_adt(builder, &parent_enum, impl_def, &method);
},
)
}

struct Method {
pattern_suffix: &'static str,
fn_name: String,
variant_name: ast::Name,
}

impl Method {
fn new(variant: &ast::Variant) -> Option<Self> {
let pattern_suffix = match variant.kind() {
ast::StructKind::Record(_) => " { .. }",
ast::StructKind::Tuple(_) => "(..)",
ast::StructKind::Unit => "",
};

let variant_name = variant.name()?;
let fn_name = format!("is_{}", &to_lower_snake_case(&variant_name.text()));
Some(Method { pattern_suffix, fn_name, variant_name })
}
}

#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
Expand Down Expand Up @@ -113,6 +139,42 @@ impl Variant {
);
}

#[test]
fn test_generate_enum_is_from_multiple_variant() {
check_assist(
generate_enum_is_method,
r#"
enum Variant {
Undefined,
$0Minor,
M$0ajor,
}"#,
r#"enum Variant {
Undefined,
Minor,
Major,
}

impl Variant {
/// Returns `true` if the variant is [`Minor`].
///
/// [`Minor`]: Variant::Minor
#[must_use]
fn is_minor(&self) -> bool {
matches!(self, Self::Minor)
}

/// Returns `true` if the variant is [`Major`].
///
/// [`Major`]: Variant::Major
#[must_use]
fn is_major(&self) -> bool {
matches!(self, Self::Major)
}
}"#,
);
}

#[test]
fn test_generate_enum_is_already_implemented() {
check_assist_not_applicable(
Expand Down
164 changes: 132 additions & 32 deletions crates/ide-assists/src/handlers/generate_enum_projection_method.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
use std::slice;

use ide_db::assists::GroupLabel;
use itertools::Itertools;
use stdx::to_lower_snake_case;
use syntax::ast::HasVisibility;
use syntax::ast::{self, AstNode, HasName};

use crate::utils;
use crate::{
AssistContext, AssistId, Assists,
utils::{add_method_to_adt, find_struct_impl},
Expand Down Expand Up @@ -128,29 +127,21 @@ fn generate_enum_projection_method(
} = props;

let variant = ctx.find_node_at_offset::<ast::Variant>()?;
let variant_name = variant.name()?;
let parent_enum = ast::Adt::Enum(variant.parent_enum());

let (pattern_suffix, field_type, bound_name) = match variant.kind() {
ast::StructKind::Record(record) => {
let (field,) = record.fields().collect_tuple()?;
let name = field.name()?.to_string();
let ty = field.ty()?;
let pattern_suffix = format!(" {{ {name} }}");
(pattern_suffix, ty, name)
}
ast::StructKind::Tuple(tuple) => {
let (field,) = tuple.fields().collect_tuple()?;
let ty = field.ty()?;
("(v)".to_owned(), ty, "v".to_owned())
}
ast::StructKind::Unit => return None,
};

let fn_name = format!("{fn_name_prefix}_{}", &to_lower_snake_case(&variant_name.text()));
let variants = variant
.parent_enum()
.variant_list()?
.variants()
.filter(utils::selected(ctx.selection_trimmed()))
.collect::<Vec<_>>();
let methods = variants
.iter()
.map(|variant| Method::new(variant, fn_name_prefix))
.collect::<Option<Vec<_>>>()?;
let fn_names = methods.iter().map(|it| it.fn_name.clone()).collect::<Vec<_>>();

// Return early if we've found an existing new fn
let impl_def = find_struct_impl(ctx, &parent_enum, slice::from_ref(&fn_name))?;
let impl_def = find_struct_impl(ctx, &parent_enum, &fn_names)?;

let target = variant.syntax().text_range();
acc.add_group(
Expand All @@ -161,29 +152,66 @@ fn generate_enum_projection_method(
|builder| {
let vis = parent_enum.visibility().map_or(String::new(), |v| format!("{v} "));

let field_type_syntax = field_type.syntax();
let must_use = if ctx.config.assist_emit_must_use { "#[must_use]\n " } else { "" };

let must_use = if ctx.config.assist_emit_must_use {
"#[must_use]\n "
} else {
""
};

let method = format!(
" {must_use}{vis}fn {fn_name}({self_param}) -> {return_prefix}{field_type_syntax}{return_suffix} {{
let method = methods
.iter()
.map(|Method { pattern_suffix, field_type, bound_name, fn_name, variant_name }| {
format!(
" \
{must_use}{vis}fn {fn_name}({self_param}) -> {return_prefix}{field_type}{return_suffix} {{
if let Self::{variant_name}{pattern_suffix} = self {{
{happy_case}({bound_name})
}} else {{
{sad_case}
}}
}}"
);
)
})
.join("\n\n");

add_method_to_adt(builder, &parent_enum, impl_def, &method);
},
)
}

struct Method {
pattern_suffix: String,
field_type: ast::Type,
bound_name: String,
fn_name: String,
variant_name: ast::Name,
}

impl Method {
fn new(variant: &ast::Variant, fn_name_prefix: &str) -> Option<Self> {
let variant_name = variant.name()?;
let fn_name = format!("{fn_name_prefix}_{}", &to_lower_snake_case(&variant_name.text()));

match variant.kind() {
ast::StructKind::Record(record) => {
let (field,) = record.fields().collect_tuple()?;
let name = field.name()?.to_string();
let field_type = field.ty()?;
let pattern_suffix = format!(" {{ {name} }}");
Some(Method { pattern_suffix, field_type, bound_name: name, fn_name, variant_name })
}
ast::StructKind::Tuple(tuple) => {
let (field,) = tuple.fields().collect_tuple()?;
let field_type = field.ty()?;
Some(Method {
pattern_suffix: "(v)".to_owned(),
field_type,
bound_name: "v".to_owned(),
variant_name,
fn_name,
})
}
ast::StructKind::Unit => None,
}
}
}

#[cfg(test)]
mod tests {
use crate::tests::{check_assist, check_assist_not_applicable};
Expand Down Expand Up @@ -216,6 +244,42 @@ impl Value {
);
}

#[test]
fn test_generate_enum_multiple_try_into_tuple_variant() {
check_assist(
generate_enum_try_into_method,
r#"
enum Value {
Unit(()),
$0Number(i32),
Text(String)$0,
}"#,
r#"enum Value {
Unit(()),
Number(i32),
Text(String),
}

impl Value {
fn try_into_number(self) -> Result<i32, Self> {
if let Self::Number(v) = self {
Ok(v)
} else {
Err(self)
}
}

fn try_into_text(self) -> Result<String, Self> {
if let Self::Text(v) = self {
Ok(v)
} else {
Err(self)
}
}
}"#,
);
}

#[test]
fn test_generate_enum_try_into_already_implemented() {
check_assist_not_applicable(
Expand Down Expand Up @@ -323,6 +387,42 @@ impl Value {
);
}

#[test]
fn test_generate_enum_as_multiple_tuple_variant() {
check_assist(
generate_enum_as_method,
r#"
enum Value {
Unit(()),
$0Number(i32),
Text(String)$0,
}"#,
r#"enum Value {
Unit(()),
Number(i32),
Text(String),
}

impl Value {
fn as_number(&self) -> Option<&i32> {
if let Self::Number(v) = self {
Some(v)
} else {
None
}
}

fn as_text(&self) -> Option<&String> {
if let Self::Text(v) = self {
Some(v)
} else {
None
}
}
}"#,
);
}

#[test]
fn test_generate_enum_as_record_variant() {
check_assist(
Expand Down
7 changes: 7 additions & 0 deletions crates/ide-assists/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,13 @@ pub(crate) fn trimmed_text_range(source_file: &SourceFile, initial_range: TextRa
trimmed_range
}

pub(crate) fn selected<T: AstNode>(range: TextRange) -> impl Fn(&T) -> bool {
let empty_style = range.is_empty();
move |node| {
node.syntax().text_range().intersect(range).is_some_and(|it| it.is_empty() == empty_style)
}
}

/// Convert a list of function params to a list of arguments that can be passed
/// into a function call.
pub(crate) fn convert_param_list_to_arg_list(list: ast::ParamList) -> ast::ArgList {
Expand Down
Loading