From 1d5f4ab06cb731ca4b74b4fe1db3d409a23bb963 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Sun, 7 Dec 2025 21:49:16 +0900 Subject: [PATCH 1/3] impl test --- pyrefly/lib/lsp/non_wasm/server.rs | 74 +++- pyrefly/lib/state/lsp.rs | 332 ++++++++++++++++++ pyrefly/lib/test/lsp/code_actions.rs | 102 ++++++ pyrefly/lib/test/lsp/lsp_interaction/basic.rs | 2 +- 4 files changed, 492 insertions(+), 18 deletions(-) diff --git a/pyrefly/lib/lsp/non_wasm/server.rs b/pyrefly/lib/lsp/non_wasm/server.rs index a205096929..607bbb7d30 100644 --- a/pyrefly/lib/lsp/non_wasm/server.rs +++ b/pyrefly/lib/lsp/non_wasm/server.rs @@ -469,7 +469,10 @@ pub fn capabilities( type_definition_provider: Some(TypeDefinitionProviderCapability::Simple(true)), implementation_provider: Some(ImplementationProviderCapability::Simple(true)), code_action_provider: Some(CodeActionProviderCapability::Options(CodeActionOptions { - code_action_kinds: Some(vec![CodeActionKind::QUICKFIX]), + code_action_kinds: Some(vec![ + CodeActionKind::QUICKFIX, + CodeActionKind::REFACTOR_EXTRACT, + ]), ..Default::default() })), completion_provider: Some(CompletionOptions { @@ -2294,26 +2297,63 @@ impl Server { let import_format = lsp_config.and_then(|c| c.import_format).unwrap_or_default(); let module_info = transaction.get_module_info(&handle)?; let range = self.from_lsp_range(uri, &module_info, params.range); - let code_actions = transaction - .local_quickfix_code_actions(&handle, range, import_format)? - .into_map(|(title, info, range, insert_text)| { - CodeActionOrCommand::CodeAction(CodeAction { - title, - kind: Some(CodeActionKind::QUICKFIX), + let mut actions = Vec::new(); + if let Some(quickfixes) = + transaction.local_quickfix_code_actions(&handle, range, import_format) + { + actions.extend( + quickfixes + .into_iter() + .map(|(title, info, range, insert_text)| { + CodeActionOrCommand::CodeAction(CodeAction { + title, + kind: Some(CodeActionKind::QUICKFIX), + edit: Some(WorkspaceEdit { + changes: Some(HashMap::from([( + uri.clone(), + vec![TextEdit { + range: info.to_lsp_range(range), + new_text: insert_text, + }], + )])), + ..Default::default() + }), + ..Default::default() + }) + }), + ); + } + if let Some(refactors) = transaction.extract_function_code_actions(&handle, range) { + for action in refactors { + let mut changes: HashMap> = HashMap::new(); + for (module, edit_range, new_text) in action.edits { + let Some(edit_uri) = module_info_to_uri(&module) else { + continue; + }; + changes.entry(edit_uri).or_default().push(TextEdit { + range: module.to_lsp_range(edit_range), + new_text, + }); + } + if changes.is_empty() { + continue; + } + actions.push(CodeActionOrCommand::CodeAction(CodeAction { + title: action.title, + kind: Some(action.kind), edit: Some(WorkspaceEdit { - changes: Some(HashMap::from([( - uri.clone(), - vec![TextEdit { - range: info.to_lsp_range(range), - new_text: insert_text, - }], - )])), + changes: Some(changes), ..Default::default() }), ..Default::default() - }) - }); - Some(code_actions) + })); + } + } + if actions.is_empty() { + None + } else { + Some(actions) + } } fn document_highlight( diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 28340e0cd7..6bb67b55b8 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -7,11 +7,13 @@ use std::cmp::Reverse; use std::collections::BTreeMap; +use std::collections::HashSet; use dupe::Dupe; use fuzzy_matcher::FuzzyMatcher; use fuzzy_matcher::skim::SkimMatcherV2; use itertools::Itertools; +use lsp_types::CodeActionKind; use lsp_types::CompletionItem; use lsp_types::CompletionItemKind; use lsp_types::CompletionItemLabelDetails; @@ -52,9 +54,12 @@ use ruff_python_ast::Identifier; use ruff_python_ast::Keyword; use ruff_python_ast::ModModule; use ruff_python_ast::Number; +use ruff_python_ast::Stmt; use ruff_python_ast::StmtImportFrom; use ruff_python_ast::UnaryOp; use ruff_python_ast::name::Name; +use ruff_python_ast::visitor; +use ruff_python_ast::visitor::Visitor; use ruff_text_size::Ranged; use ruff_text_size::TextRange; use ruff_text_size::TextSize; @@ -82,6 +87,13 @@ use crate::types::callable::Param; use crate::types::module::ModuleType; use crate::types::types::Type; +#[derive(Clone, Debug)] +pub struct LocalRefactorCodeAction { + pub title: String, + pub edits: Vec<(Module, TextRange, String)>, + pub kind: CodeActionKind, +} + fn default_true() -> bool { true } @@ -1713,6 +1725,129 @@ impl<'a> Transaction<'a> { Some(code_actions) } + pub fn extract_function_code_actions( + &self, + handle: &Handle, + selection: TextRange, + ) -> Option> { + if selection.is_empty() { + return None; + } + let Some(module_info) = self.get_module_info(handle) else { + return None; + }; + let Some(ast) = self.get_ast(handle) else { + return None; + }; + let selection_text = module_info.code_at(selection); + if selection_text.trim().is_empty() { + return None; + } + let module_len = + TextSize::try_from(module_info.contents().len()).unwrap_or(TextSize::new(0)); + let module_stmt_range = + find_enclosing_module_statement_range(ast.as_ref(), selection, module_len); + if selection_contains_disallowed_statements(ast.as_ref(), selection) { + return None; + } + let (load_refs, store_refs) = collect_identifier_refs(ast.as_ref(), selection); + if load_refs.is_empty() && store_refs.is_empty() { + return None; + } + let post_loads = + collect_post_selection_loads(ast.as_ref(), module_stmt_range, selection.end()); + let block_indent = detect_block_indent(selection_text); + let Some(mut dedented_body) = dedent_selection(selection_text) else { + return None; + }; + if dedented_body.trim().is_empty() { + return None; + } + if dedented_body.ends_with('\n') { + dedented_body.pop(); + if dedented_body.ends_with('\r') { + dedented_body.pop(); + } + } + let helper_name = generate_helper_name(module_info.contents()); + let mut params = Vec::new(); + let mut seen_params = HashSet::new(); + for ident in load_refs { + if seen_params.contains(&ident.name) { + continue; + } + if ident.synthetic_load { + seen_params.insert(ident.name.clone()); + params.push(ident.name.clone()); + continue; + } + let defs = self.find_definition(handle, ident.position, FindPreference::default()); + let Some(def) = defs.first() else { + continue; + }; + if def.module.path() != module_info.path() { + continue; + } + if !contains_range(module_stmt_range, def.definition_range) + || contains_range(selection, def.definition_range) + || def.definition_range.start() >= selection.start() + { + continue; + } + seen_params.insert(ident.name.clone()); + params.push(ident.name.clone()); + } + let mut returns = Vec::new(); + let mut seen_returns = HashSet::new(); + for ident in store_refs { + if seen_returns.contains(&ident.name) || !post_loads.contains(&ident.name) { + continue; + } + seen_returns.insert(ident.name.clone()); + returns.push(ident.name.clone()); + } + let helper_params = params.join(", "); + let indented_body = if dedented_body.trim().is_empty() { + " pass\n".to_owned() + } else { + indent_block(&dedented_body, " ") + }; + let mut helper_text = format!("def {helper_name}({helper_params}):\n{indented_body}"); + if !returns.is_empty() { + let return_expr = if returns.len() == 1 { + returns[0].clone() + } else { + returns.join(", ") + }; + helper_text.push_str(&format!(" return {return_expr}\n")); + } + helper_text.push('\n'); + let call_args = params.join(", "); + let call_expr = format!("{helper_name}({call_args})"); + let replacement_line = if returns.is_empty() { + format!("{block_indent}{call_expr}\n") + } else { + let lhs = if returns.len() == 1 { + returns[0].clone() + } else { + returns.join(", ") + }; + format!("{block_indent}{lhs} = {call_expr}\n") + }; + let helper_edit = ( + module_info.dupe(), + TextRange::at(module_stmt_range.start(), TextSize::new(0)), + helper_text, + ); + let call_edit = (module_info.dupe(), selection, replacement_line); + let action = LocalRefactorCodeAction { + title: format!("Extract into helper `{helper_name}`"), + edits: vec![helper_edit, call_edit], + kind: CodeActionKind::REFACTOR_EXTRACT, + }; + Some(vec![action]) + } + /// Determines whether a module is a third-party package. /// /// Checks if the module's path is located within any of the configured @@ -2974,6 +3109,203 @@ impl<'a> Transaction<'a> { } } +#[derive(Clone, Debug)] +struct IdentifierRef { + name: String, + position: TextSize, + synthetic_load: bool, +} + +fn contains_range(outer: TextRange, inner: TextRange) -> bool { + outer.start() <= inner.start() && outer.end() >= inner.end() +} + +fn ranges_overlap(a: TextRange, b: TextRange) -> bool { + a.start() < b.end() && b.start() < a.end() +} + +fn collect_identifier_refs( + ast: &ModModule, + selection: TextRange, +) -> (Vec, Vec) { + struct IdentifierCollector { + selection: TextRange, + loads: Vec, + stores: Vec, + } + + impl<'a> visitor::Visitor<'a> for IdentifierCollector { + fn visit_expr(&mut self, expr: &'a Expr) { + if contains_range(self.selection, expr.range()) + && let Expr::Name(name) = expr + { + let ident = IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: false, + }; + match name.ctx { + ExprContext::Load => self.loads.push(ident), + ExprContext::Store => self.stores.push(ident), + ExprContext::Del | ExprContext::Invalid => {} + } + } + visitor::walk_expr(self, expr); + } + + fn visit_stmt(&mut self, stmt: &'a Stmt) { + if contains_range(self.selection, stmt.range()) + && let Stmt::AugAssign(aug) = stmt + && let Expr::Name(name) = aug.target.as_ref() + { + self.loads.push(IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: true, + }); + } + visitor::walk_stmt(self, stmt); + } + } + + let mut collector = IdentifierCollector { + selection, + loads: Vec::new(), + stores: Vec::new(), + }; + collector.visit_body(&ast.body); + (collector.loads, collector.stores) +} + +fn selection_contains_disallowed_statements(ast: &ModModule, selection: TextRange) -> bool { + fn visit_stmt(stmt: &Stmt, selection: TextRange, found: &mut bool) { + if *found || !ranges_overlap(stmt.range(), selection) { + return; + } + if contains_range(selection, stmt.range()) { + match stmt { + Stmt::Return(_) + | Stmt::Break(_) + | Stmt::Continue(_) + | Stmt::Raise(_) + | Stmt::FunctionDef(_) + | Stmt::ClassDef(_) => { + *found = true; + return; + } + _ => {} + } + } + stmt.recurse(&mut |child| visit_stmt(child, selection, found)); + } + let mut found = false; + for stmt in &ast.body { + visit_stmt(stmt, selection, &mut found); + if found { + break; + } + } + found +} + +fn find_enclosing_module_statement_range( + ast: &ModModule, + selection: TextRange, + module_len: TextSize, +) -> TextRange { + for stmt in &ast.body { + if contains_range(stmt.range(), selection) { + return stmt.range(); + } + } + TextRange::new(TextSize::new(0), module_len) +} + +fn collect_post_selection_loads( + ast: &ModModule, + module_stmt_range: TextRange, + selection_end: TextSize, +) -> HashSet { + let mut loads = HashSet::new(); + ast.visit(&mut |expr: &Expr| { + if let Expr::Name(name) = expr + && matches!(name.ctx, ExprContext::Load) + && contains_range(module_stmt_range, name.range) + && name.range.start() > selection_end + { + loads.insert(name.id.to_string()); + } + }); + loads +} + +fn detect_block_indent(selection_text: &str) -> String { + for line in selection_text.lines() { + if line.trim().is_empty() { + continue; + } + return line + .chars() + .take_while(|c| c.is_whitespace()) + .collect::(); + } + String::new() +} + +fn dedent_selection(selection_text: &str) -> Option { + let mut min_indent = usize::MAX; + for line in selection_text.lines() { + if line.trim().is_empty() { + continue; + } + let indent = line.chars().take_while(|c| c.is_whitespace()).count(); + min_indent = min_indent.min(indent); + } + if min_indent == usize::MAX { + return None; + } + let mut dedented = String::new(); + for line in selection_text.lines() { + if line.trim().is_empty() { + dedented.push('\n'); + continue; + } + let dedented_line = line.chars().skip(min_indent).collect::(); + dedented.push_str(&dedented_line); + dedented.push('\n'); + } + if !selection_text.ends_with('\n') { + dedented.push('\n'); + } + Some(dedented) +} + +fn indent_block(block: &str, indent: &str) -> String { + let mut result = String::new(); + for line in block.lines() { + result.push_str(indent); + result.push_str(line); + result.push('\n'); + } + result +} + +fn generate_helper_name(source: &str) -> String { + let mut counter = 1; + loop { + let candidate = if counter == 1 { + "extracted_function".to_owned() + } else { + format!("extracted_function_{counter}") + }; + let needle = format!("def {candidate}("); + if !source.contains(&needle) { + return candidate; + } + counter += 1; + } +} + impl<'a> CancellableTransaction<'a> { /// Finds child class implementations of a method definition. /// Returns the ranges of child methods that reimplement the given parent method. diff --git a/pyrefly/lib/test/lsp/code_actions.rs b/pyrefly/lib/test/lsp/code_actions.rs index 203406f465..6095bffc4c 100644 --- a/pyrefly/lib/test/lsp/code_actions.rs +++ b/pyrefly/lib/test/lsp/code_actions.rs @@ -7,13 +7,16 @@ use pretty_assertions::assert_eq; use pyrefly_build::handle::Handle; +use pyrefly_python::module::Module; use ruff_text_size::TextRange; use ruff_text_size::TextSize; use crate::module::module_info::ModuleInfo; use crate::state::lsp::ImportFormat; +use crate::state::require::Require; use crate::state::state::State; use crate::test::util::get_batched_lsp_operations_report_allow_error; +use crate::test::util::mk_multi_file_state_assert_no_errors; fn apply_patch(info: &ModuleInfo, range: TextRange, patch: String) -> (String, String) { let before = info.contents().as_str().to_owned(); @@ -50,6 +53,49 @@ fn get_test_report(state: &State, handle: &Handle, position: TextSize) -> String report } +fn apply_refactor_edits_for_module( + module: &ModuleInfo, + edits: &[(Module, TextRange, String)], +) -> String { + let mut relevant_edits: Vec<(TextRange, String)> = edits + .iter() + .filter(|(edit_module, _, _)| edit_module.path() == module.path()) + .map(|(_, range, text)| (*range, text.clone())) + .collect(); + relevant_edits.sort_by_key(|(range, _)| range.start()); + let mut result = module.contents().as_str().to_owned(); + for (range, replacement) in relevant_edits.into_iter().rev() { + result.replace_range( + range.start().to_usize()..range.end().to_usize(), + &replacement, + ); + } + result +} + +fn find_marked_range(source: &str) -> TextRange { + let start_marker = "# EXTRACT-START"; + let end_marker = "# EXTRACT-END"; + let start_idx = source + .find(start_marker) + .expect("missing start marker for extract refactor test"); + let start_line_end = source[start_idx..] + .find('\n') + .map(|offset| start_idx + offset + 1) + .unwrap_or(source.len()); + let end_idx = source + .find(end_marker) + .expect("missing end marker for extract refactor test"); + let end_line_start = source[..end_idx] + .rfind('\n') + .map(|idx| idx + 1) + .unwrap_or(end_idx); + TextRange::new( + TextSize::try_from(start_line_end).unwrap(), + TextSize::try_from(end_line_start).unwrap(), + ) +} + #[test] fn basic_test() { let report = get_batched_lsp_operations_report_allow_error( @@ -236,3 +282,59 @@ my_export report.trim() ); } + +#[test] +fn extract_function_basic_refactor() { + let code = r#" +def process_data(data_list): + total_sum = 0 + for item in data_list: + # EXTRACT-START + squared_value = item * item + if squared_value > 100: + print(f"Large value detected: {squared_value}") + total_sum += squared_value + # EXTRACT-END + return total_sum + + +if __name__ == "__main__": + data = [1, 5, 12, 8, 15] + result = process_data(data) + print(f"The final sum is: {result}") +"#; + let (handles, state) = + mk_multi_file_state_assert_no_errors(&[("main", code)], Require::Everything); + let handle = handles.get("main").unwrap(); + let transaction = state.transaction(); + let module_info = transaction.get_module_info(handle).unwrap(); + let selection = find_marked_range(module_info.contents()); + let actions = transaction + .extract_function_code_actions(handle, selection) + .unwrap_or_default(); + assert!(!actions.is_empty(), "expected extract refactor action"); + let updated = apply_refactor_edits_for_module(&module_info, &actions[0].edits); + let expected = r#" +def extracted_function(item, total_sum): + squared_value = item * item + if squared_value > 100: + print(f"Large value detected: {squared_value}") + total_sum += squared_value + return total_sum + +def process_data(data_list): + total_sum = 0 + for item in data_list: + # EXTRACT-START + total_sum = extracted_function(item, total_sum) + # EXTRACT-END + return total_sum + + +if __name__ == "__main__": + data = [1, 5, 12, 8, 15] + result = process_data(data) + print(f"The final sum is: {result}") +"#; + assert_eq!(expected.trim(), updated.trim()); +} diff --git a/pyrefly/lib/test/lsp/lsp_interaction/basic.rs b/pyrefly/lib/test/lsp/lsp_interaction/basic.rs index 046af7cc91..1e022736d2 100644 --- a/pyrefly/lib/test/lsp/lsp_interaction/basic.rs +++ b/pyrefly/lib/test/lsp/lsp_interaction/basic.rs @@ -36,7 +36,7 @@ fn test_initialize_basic() { "definitionProvider": true, "typeDefinitionProvider": true, "codeActionProvider": { - "codeActionKinds": ["quickfix"] + "codeActionKinds": ["quickfix", "refactor.extract"] }, "completionProvider": { "triggerCharacters": [".", "'", "\""] From 44f108a8cf85346247ae43642b7c878d3487cccb Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Dec 2025 17:11:44 +0900 Subject: [PATCH 2/3] update by comment --- crates/pyrefly_python/src/docstring.rs | 28 ++ pyrefly/lib/state/lsp.rs | 328 +----------------- .../state/lsp/quick_fixes/extract_function.rs | 322 +++++++++++++++++ pyrefly/lib/state/lsp/quick_fixes/mod.rs | 1 + pyrefly/lib/test/lsp/code_actions.rs | 131 ++++++- 5 files changed, 475 insertions(+), 335 deletions(-) create mode 100644 pyrefly/lib/state/lsp/quick_fixes/extract_function.rs create mode 100644 pyrefly/lib/state/lsp/quick_fixes/mod.rs diff --git a/crates/pyrefly_python/src/docstring.rs b/crates/pyrefly_python/src/docstring.rs index d5ae77d227..e0f37ce4ad 100644 --- a/crates/pyrefly_python/src/docstring.rs +++ b/crates/pyrefly_python/src/docstring.rs @@ -131,6 +131,34 @@ fn dedented_lines_for_parsing(docstring: &str) -> Vec { .collect() } +/// Dedent a block of text while preserving blank lines, similar to how we handle docstrings. +pub fn dedent_block_preserving_layout(text: &str) -> Option { + if text.trim().is_empty() { + return None; + } + + let lines: Vec<&str> = text.lines().collect(); + if lines.is_empty() { + return None; + } + + let min_indent = minimal_indentation(lines.iter().copied()); + let mut dedented = String::new(); + for line in lines { + if line.trim().is_empty() { + dedented.push('\n'); + continue; + } + let start = min_indent.min(line.len()); + dedented.push_str(&line[start..]); + dedented.push('\n'); + } + if !text.ends_with('\n') { + dedented.push('\n'); + } + Some(dedented) +} + fn leading_space_count(line: &str) -> usize { line.as_bytes().iter().take_while(|c| **c == b' ').count() } diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 6bb67b55b8..53ab2d717d 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -7,13 +7,11 @@ use std::cmp::Reverse; use std::collections::BTreeMap; -use std::collections::HashSet; use dupe::Dupe; use fuzzy_matcher::FuzzyMatcher; use fuzzy_matcher::skim::SkimMatcherV2; use itertools::Itertools; -use lsp_types::CodeActionKind; use lsp_types::CompletionItem; use lsp_types::CompletionItemKind; use lsp_types::CompletionItemLabelDetails; @@ -54,12 +52,9 @@ use ruff_python_ast::Identifier; use ruff_python_ast::Keyword; use ruff_python_ast::ModModule; use ruff_python_ast::Number; -use ruff_python_ast::Stmt; use ruff_python_ast::StmtImportFrom; use ruff_python_ast::UnaryOp; use ruff_python_ast::name::Name; -use ruff_python_ast::visitor; -use ruff_python_ast::visitor::Visitor; use ruff_text_size::Ranged; use ruff_text_size::TextRange; use ruff_text_size::TextSize; @@ -87,12 +82,9 @@ use crate::types::callable::Param; use crate::types::module::ModuleType; use crate::types::types::Type; -#[derive(Clone, Debug)] -pub struct LocalRefactorCodeAction { - pub title: String, - pub edits: Vec<(Module, TextRange, String)>, - pub kind: CodeActionKind, -} +mod quick_fixes; + +use self::quick_fixes::extract_function::LocalRefactorCodeAction; fn default_true() -> bool { true @@ -1730,122 +1722,7 @@ impl<'a> Transaction<'a> { handle: &Handle, selection: TextRange, ) -> Option> { - if selection.is_empty() { - return None; - } - let Some(module_info) = self.get_module_info(handle) else { - return None; - }; - let Some(ast) = self.get_ast(handle) else { - return None; - }; - let selection_text = module_info.code_at(selection); - if selection_text.trim().is_empty() { - return None; - } - let module_len = - TextSize::try_from(module_info.contents().len()).unwrap_or(TextSize::new(0)); - let module_stmt_range = - find_enclosing_module_statement_range(ast.as_ref(), selection, module_len); - if selection_contains_disallowed_statements(ast.as_ref(), selection) { - return None; - } - let (load_refs, store_refs) = collect_identifier_refs(ast.as_ref(), selection); - if load_refs.is_empty() && store_refs.is_empty() { - return None; - } - let post_loads = - collect_post_selection_loads(ast.as_ref(), module_stmt_range, selection.end()); - let block_indent = detect_block_indent(selection_text); - let Some(mut dedented_body) = dedent_selection(selection_text) else { - return None; - }; - if dedented_body.trim().is_empty() { - return None; - } - if dedented_body.ends_with('\n') { - dedented_body.pop(); - if dedented_body.ends_with('\r') { - dedented_body.pop(); - } - } - let helper_name = generate_helper_name(module_info.contents()); - let mut params = Vec::new(); - let mut seen_params = HashSet::new(); - for ident in load_refs { - if seen_params.contains(&ident.name) { - continue; - } - if ident.synthetic_load { - seen_params.insert(ident.name.clone()); - params.push(ident.name.clone()); - continue; - } - let defs = self.find_definition(handle, ident.position, FindPreference::default()); - let Some(def) = defs.first() else { - continue; - }; - if def.module.path() != module_info.path() { - continue; - } - if !contains_range(module_stmt_range, def.definition_range) - || contains_range(selection, def.definition_range) - || def.definition_range.start() >= selection.start() - { - continue; - } - seen_params.insert(ident.name.clone()); - params.push(ident.name.clone()); - } - let mut returns = Vec::new(); - let mut seen_returns = HashSet::new(); - for ident in store_refs { - if seen_returns.contains(&ident.name) || !post_loads.contains(&ident.name) { - continue; - } - seen_returns.insert(ident.name.clone()); - returns.push(ident.name.clone()); - } - let helper_params = params.join(", "); - let indented_body = if dedented_body.trim().is_empty() { - " pass\n".to_owned() - } else { - indent_block(&dedented_body, " ") - }; - let mut helper_text = format!("def {helper_name}({helper_params}):\n{indented_body}"); - if !returns.is_empty() { - let return_expr = if returns.len() == 1 { - returns[0].clone() - } else { - returns.join(", ") - }; - helper_text.push_str(&format!(" return {return_expr}\n")); - } - helper_text.push('\n'); - let call_args = params.join(", "); - let call_expr = format!("{helper_name}({call_args})"); - let replacement_line = if returns.is_empty() { - format!("{block_indent}{call_expr}\n") - } else { - let lhs = if returns.len() == 1 { - returns[0].clone() - } else { - returns.join(", ") - }; - format!("{block_indent}{lhs} = {call_expr}\n") - }; - let helper_edit = ( - module_info.dupe(), - TextRange::at(module_stmt_range.start(), TextSize::new(0)), - helper_text, - ); - let call_edit = (module_info.dupe(), selection, replacement_line); - let action = LocalRefactorCodeAction { - title: format!("Extract into helper `{helper_name}`"), - edits: vec![helper_edit, call_edit], - kind: CodeActionKind::REFACTOR_EXTRACT, - }; - Some(vec![action]) + quick_fixes::extract_function::extract_function_code_actions(self, handle, selection) } /// Determines whether a module is a third-party package. @@ -3109,203 +2986,6 @@ impl<'a> Transaction<'a> { } } -#[derive(Clone, Debug)] -struct IdentifierRef { - name: String, - position: TextSize, - synthetic_load: bool, -} - -fn contains_range(outer: TextRange, inner: TextRange) -> bool { - outer.start() <= inner.start() && outer.end() >= inner.end() -} - -fn ranges_overlap(a: TextRange, b: TextRange) -> bool { - a.start() < b.end() && b.start() < a.end() -} - -fn collect_identifier_refs( - ast: &ModModule, - selection: TextRange, -) -> (Vec, Vec) { - struct IdentifierCollector { - selection: TextRange, - loads: Vec, - stores: Vec, - } - - impl<'a> visitor::Visitor<'a> for IdentifierCollector { - fn visit_expr(&mut self, expr: &'a Expr) { - if contains_range(self.selection, expr.range()) - && let Expr::Name(name) = expr - { - let ident = IdentifierRef { - name: name.id.to_string(), - position: name.range.start(), - synthetic_load: false, - }; - match name.ctx { - ExprContext::Load => self.loads.push(ident), - ExprContext::Store => self.stores.push(ident), - ExprContext::Del | ExprContext::Invalid => {} - } - } - visitor::walk_expr(self, expr); - } - - fn visit_stmt(&mut self, stmt: &'a Stmt) { - if contains_range(self.selection, stmt.range()) - && let Stmt::AugAssign(aug) = stmt - && let Expr::Name(name) = aug.target.as_ref() - { - self.loads.push(IdentifierRef { - name: name.id.to_string(), - position: name.range.start(), - synthetic_load: true, - }); - } - visitor::walk_stmt(self, stmt); - } - } - - let mut collector = IdentifierCollector { - selection, - loads: Vec::new(), - stores: Vec::new(), - }; - collector.visit_body(&ast.body); - (collector.loads, collector.stores) -} - -fn selection_contains_disallowed_statements(ast: &ModModule, selection: TextRange) -> bool { - fn visit_stmt(stmt: &Stmt, selection: TextRange, found: &mut bool) { - if *found || !ranges_overlap(stmt.range(), selection) { - return; - } - if contains_range(selection, stmt.range()) { - match stmt { - Stmt::Return(_) - | Stmt::Break(_) - | Stmt::Continue(_) - | Stmt::Raise(_) - | Stmt::FunctionDef(_) - | Stmt::ClassDef(_) => { - *found = true; - return; - } - _ => {} - } - } - stmt.recurse(&mut |child| visit_stmt(child, selection, found)); - } - let mut found = false; - for stmt in &ast.body { - visit_stmt(stmt, selection, &mut found); - if found { - break; - } - } - found -} - -fn find_enclosing_module_statement_range( - ast: &ModModule, - selection: TextRange, - module_len: TextSize, -) -> TextRange { - for stmt in &ast.body { - if contains_range(stmt.range(), selection) { - return stmt.range(); - } - } - TextRange::new(TextSize::new(0), module_len) -} - -fn collect_post_selection_loads( - ast: &ModModule, - module_stmt_range: TextRange, - selection_end: TextSize, -) -> HashSet { - let mut loads = HashSet::new(); - ast.visit(&mut |expr: &Expr| { - if let Expr::Name(name) = expr - && matches!(name.ctx, ExprContext::Load) - && contains_range(module_stmt_range, name.range) - && name.range.start() > selection_end - { - loads.insert(name.id.to_string()); - } - }); - loads -} - -fn detect_block_indent(selection_text: &str) -> String { - for line in selection_text.lines() { - if line.trim().is_empty() { - continue; - } - return line - .chars() - .take_while(|c| c.is_whitespace()) - .collect::(); - } - String::new() -} - -fn dedent_selection(selection_text: &str) -> Option { - let mut min_indent = usize::MAX; - for line in selection_text.lines() { - if line.trim().is_empty() { - continue; - } - let indent = line.chars().take_while(|c| c.is_whitespace()).count(); - min_indent = min_indent.min(indent); - } - if min_indent == usize::MAX { - return None; - } - let mut dedented = String::new(); - for line in selection_text.lines() { - if line.trim().is_empty() { - dedented.push('\n'); - continue; - } - let dedented_line = line.chars().skip(min_indent).collect::(); - dedented.push_str(&dedented_line); - dedented.push('\n'); - } - if !selection_text.ends_with('\n') { - dedented.push('\n'); - } - Some(dedented) -} - -fn indent_block(block: &str, indent: &str) -> String { - let mut result = String::new(); - for line in block.lines() { - result.push_str(indent); - result.push_str(line); - result.push('\n'); - } - result -} - -fn generate_helper_name(source: &str) -> String { - let mut counter = 1; - loop { - let candidate = if counter == 1 { - "extracted_function".to_owned() - } else { - format!("extracted_function_{counter}") - }; - let needle = format!("def {candidate}("); - if !source.contains(&needle) { - return candidate; - } - counter += 1; - } -} - impl<'a> CancellableTransaction<'a> { /// Finds child class implementations of a method definition. /// Returns the ranges of child methods that reimplement the given parent method. diff --git a/pyrefly/lib/state/lsp/quick_fixes/extract_function.rs b/pyrefly/lib/state/lsp/quick_fixes/extract_function.rs new file mode 100644 index 0000000000..5bb33d6d2c --- /dev/null +++ b/pyrefly/lib/state/lsp/quick_fixes/extract_function.rs @@ -0,0 +1,322 @@ +use std::collections::HashSet; + +use dupe::Dupe; +use lsp_types::CodeActionKind; +use pyrefly_build::handle::Handle; +use pyrefly_python::docstring::dedent_block_preserving_layout; +use pyrefly_python::module::Module; +use pyrefly_util::visit::Visit; +use ruff_python_ast::Expr; +use ruff_python_ast::ExprContext; +use ruff_python_ast::ModModule; +use ruff_python_ast::Stmt; +use ruff_python_ast::visitor::Visitor; +use ruff_text_size::Ranged; +use ruff_text_size::TextRange; +use ruff_text_size::TextSize; + +use crate::state::lsp::FindPreference; +use crate::state::lsp::Transaction; + +const HELPER_INDENT: &str = " "; + +/// Description of a refactor edit that stays within the local workspace. +#[derive(Clone, Debug)] +pub struct LocalRefactorCodeAction { + pub title: String, + pub edits: Vec<(Module, TextRange, String)>, + pub kind: CodeActionKind, +} + +/// Builds extract-function quick fix code actions for the supplied selection. +pub(crate) fn extract_function_code_actions( + transaction: &Transaction<'_>, + handle: &Handle, + selection: TextRange, +) -> Option> { + if selection.is_empty() { + return None; + } + let Some(module_info) = transaction.get_module_info(handle) else { + return None; + }; + let Some(ast) = transaction.get_ast(handle) else { + return None; + }; + let selection_text = module_info.code_at(selection); + if selection_text.trim().is_empty() { + return None; + } + let module_len = TextSize::try_from(module_info.contents().len()).unwrap_or(TextSize::new(0)); + let module_stmt_range = + find_enclosing_module_statement_range(ast.as_ref(), selection, module_len); + if selection_contains_disallowed_statements(ast.as_ref(), selection) { + return None; + } + let (load_refs, store_refs) = collect_identifier_refs(ast.as_ref(), selection); + if load_refs.is_empty() && store_refs.is_empty() { + return None; + } + let post_loads = collect_post_selection_loads(ast.as_ref(), module_stmt_range, selection.end()); + let block_indent = detect_block_indent(selection_text); + let Some(mut dedented_body) = dedent_block_preserving_layout(selection_text) else { + return None; + }; + if dedented_body.ends_with('\n') { + dedented_body.pop(); + if dedented_body.ends_with('\r') { + dedented_body.pop(); + } + } + + let helper_name = generate_helper_name(module_info.contents()); + let mut params = Vec::new(); + let mut seen_params = HashSet::new(); + for ident in load_refs { + if seen_params.contains(&ident.name) { + continue; + } + if ident.synthetic_load { + seen_params.insert(ident.name.clone()); + params.push(ident.name.clone()); + continue; + } + let defs = transaction.find_definition(handle, ident.position, FindPreference::default()); + let Some(def) = defs.first() else { + continue; + }; + if def.module.path() != module_info.path() { + continue; + } + if !module_stmt_range.contains_range(def.definition_range) + || selection.contains_range(def.definition_range) + || def.definition_range.start() >= selection.start() + { + continue; + } + seen_params.insert(ident.name.clone()); + params.push(ident.name.clone()); + } + + let mut returns = Vec::new(); + let mut seen_returns = HashSet::new(); + for ident in store_refs { + if seen_returns.contains(&ident.name) || !post_loads.contains(&ident.name) { + continue; + } + seen_returns.insert(ident.name.clone()); + returns.push(ident.name.clone()); + } + + let indented_body = prefix_lines_with(&dedented_body, HELPER_INDENT); + + let mut helper_text = if params.is_empty() { + format!("def {helper_name}():\n{indented_body}") + } else { + let helper_params = params.join(", "); + format!("def {helper_name}({helper_params}):\n{indented_body}") + }; + + if !returns.is_empty() && !returns.iter().all(|name| name.is_empty()) { + let return_expr = if returns.len() == 1 { + returns[0].clone() + } else { + returns.join(", ") + }; + helper_text.push_str(&format!("{HELPER_INDENT}return {return_expr}\n")); + } + helper_text.push('\n'); + + let call_args = params.join(", "); + let call_expr = format!("{helper_name}({call_args})"); + let replacement_line = if returns.is_empty() { + format!("{block_indent}{call_expr}\n") + } else { + let lhs = if returns.len() == 1 { + returns[0].clone() + } else { + returns.join(", ") + }; + format!("{block_indent}{lhs} = {call_expr}\n") + }; + + let helper_edit = ( + module_info.dupe(), + TextRange::at(module_stmt_range.start(), TextSize::new(0)), + helper_text, + ); + let call_edit = (module_info.dupe(), selection, replacement_line); + let action = LocalRefactorCodeAction { + title: format!("Extract into helper `{helper_name}`"), + edits: vec![helper_edit, call_edit], + kind: CodeActionKind::REFACTOR_EXTRACT, + }; + Some(vec![action]) +} + +#[derive(Clone, Debug)] +struct IdentifierRef { + /// Identifier string. + name: String, + /// Byte offset where the identifier was observed. + position: TextSize, + /// True when this "load" came from reading the left-hand side of an augmented assignment. + synthetic_load: bool, +} + +fn collect_identifier_refs( + ast: &ModModule, + selection: TextRange, +) -> (Vec, Vec) { + struct IdentifierCollector { + selection: TextRange, + loads: Vec, + stores: Vec, + } + + impl<'a> ruff_python_ast::visitor::Visitor<'a> for IdentifierCollector { + fn visit_expr(&mut self, expr: &'a Expr) { + if self.selection.contains_range(expr.range()) { + if let Expr::Name(name) = expr { + let ident = IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: false, + }; + match name.ctx { + ExprContext::Load => self.loads.push(ident), + ExprContext::Store => self.stores.push(ident), + ExprContext::Del | ExprContext::Invalid => {} + } + } + } + ruff_python_ast::visitor::walk_expr(self, expr); + } + + fn visit_stmt(&mut self, stmt: &'a Stmt) { + if self.selection.contains_range(stmt.range()) { + if let Stmt::AugAssign(aug) = stmt { + if let Expr::Name(name) = aug.target.as_ref() { + self.loads.push(IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: true, + }); + } + } + } + ruff_python_ast::visitor::walk_stmt(self, stmt); + } + } + + let mut collector = IdentifierCollector { + selection, + loads: Vec::new(), + stores: Vec::new(), + }; + collector.visit_body(&ast.body); + (collector.loads, collector.stores) +} + +fn selection_contains_disallowed_statements(ast: &ModModule, selection: TextRange) -> bool { + fn visit_stmt(stmt: &Stmt, selection: TextRange, found: &mut bool) { + if *found || stmt.range().intersect(selection).is_none() { + return; + } + if selection.contains_range(stmt.range()) { + match stmt { + Stmt::Return(_) + | Stmt::Break(_) + | Stmt::Continue(_) + | Stmt::Raise(_) + | Stmt::FunctionDef(_) + | Stmt::ClassDef(_) => { + *found = true; + return; + } + _ => {} + } + } + stmt.recurse(&mut |child| visit_stmt(child, selection, found)); + } + + let mut found = false; + for stmt in &ast.body { + visit_stmt(stmt, selection, &mut found); + if found { + break; + } + } + found +} + +fn find_enclosing_module_statement_range( + ast: &ModModule, + selection: TextRange, + module_len: TextSize, +) -> TextRange { + for stmt in &ast.body { + if stmt.range().contains_range(selection) { + return stmt.range(); + } + } + TextRange::new(TextSize::new(0), module_len) +} + +fn collect_post_selection_loads( + ast: &ModModule, + module_stmt_range: TextRange, + selection_end: TextSize, +) -> HashSet { + let mut loads = HashSet::new(); + ast.visit(&mut |expr: &Expr| { + if let Expr::Name(name) = expr { + if matches!(name.ctx, ExprContext::Load) + && module_stmt_range.contains_range(name.range) + && name.range.start() > selection_end + { + loads.insert(name.id.to_string()); + } + } + }); + loads +} + +fn detect_block_indent(selection_text: &str) -> String { + for line in selection_text.lines() { + if line.trim().is_empty() { + continue; + } + return line + .chars() + .take_while(|c| c.is_whitespace()) + .collect::(); + } + String::new() +} + +fn prefix_lines_with(block: &str, indent: &str) -> String { + let mut result = String::new(); + for line in block.lines() { + result.push_str(indent); + result.push_str(line); + result.push('\n'); + } + result +} + +fn generate_helper_name(source: &str) -> String { + let mut counter = 1; + loop { + let candidate = if counter == 1 { + "extracted_function".to_string() + } else { + format!("extracted_function_{counter}") + }; + let needle = format!("def {candidate}("); + if !source.contains(&needle) { + return candidate; + } + counter += 1; + } +} diff --git a/pyrefly/lib/state/lsp/quick_fixes/mod.rs b/pyrefly/lib/state/lsp/quick_fixes/mod.rs new file mode 100644 index 0000000000..907c9bc267 --- /dev/null +++ b/pyrefly/lib/state/lsp/quick_fixes/mod.rs @@ -0,0 +1 @@ +pub(crate) mod extract_function; diff --git a/pyrefly/lib/test/lsp/code_actions.rs b/pyrefly/lib/test/lsp/code_actions.rs index 6095bffc4c..4ed51003a8 100644 --- a/pyrefly/lib/test/lsp/code_actions.rs +++ b/pyrefly/lib/test/lsp/code_actions.rs @@ -96,6 +96,43 @@ fn find_marked_range(source: &str) -> TextRange { ) } +fn compute_extract_actions( + code: &str, +) -> ( + ModuleInfo, + Vec>, + Vec, +) { + let (handles, state) = + mk_multi_file_state_assert_no_errors(&[("main", code)], Require::Everything); + let handle = handles.get("main").unwrap(); + let transaction = state.transaction(); + let module_info = transaction.get_module_info(handle).unwrap(); + let selection = find_marked_range(module_info.contents()); + let actions = transaction + .extract_function_code_actions(handle, selection) + .unwrap_or_default(); + let edit_sets: Vec> = + actions.iter().map(|action| action.edits.clone()).collect(); + let titles = actions.iter().map(|action| action.title.clone()).collect(); + (module_info, edit_sets, titles) +} + +fn apply_first_extract_action(code: &str) -> Option { + let (module_info, actions, _) = compute_extract_actions(code); + let edits = actions.first()?; + Some(apply_refactor_edits_for_module(&module_info, edits)) +} + +fn assert_no_extract_action(code: &str) { + let (_, actions, _) = compute_extract_actions(code); + assert!( + actions.is_empty(), + "expected no extract-function actions, found {}", + actions.len() + ); +} + #[test] fn basic_test() { let report = get_batched_lsp_operations_report_allow_error( @@ -303,17 +340,7 @@ if __name__ == "__main__": result = process_data(data) print(f"The final sum is: {result}") "#; - let (handles, state) = - mk_multi_file_state_assert_no_errors(&[("main", code)], Require::Everything); - let handle = handles.get("main").unwrap(); - let transaction = state.transaction(); - let module_info = transaction.get_module_info(handle).unwrap(); - let selection = find_marked_range(module_info.contents()); - let actions = transaction - .extract_function_code_actions(handle, selection) - .unwrap_or_default(); - assert!(!actions.is_empty(), "expected extract refactor action"); - let updated = apply_refactor_edits_for_module(&module_info, &actions[0].edits); + let updated = apply_first_extract_action(code).expect("expected extract refactor action"); let expected = r#" def extracted_function(item, total_sum): squared_value = item * item @@ -338,3 +365,85 @@ if __name__ == "__main__": "#; assert_eq!(expected.trim(), updated.trim()); } + +#[test] +fn extract_function_method_scope_preserves_indent() { + let code = r#" +class Processor: + def consume(self, item): + print(item) + + def process(self, data_list): + for item in data_list: + # EXTRACT-START + squared_value = item * item + if squared_value > 10: + self.consume(squared_value) + # EXTRACT-END + return len(data_list) +"#; + let updated = apply_first_extract_action(code).expect("expected extract refactor action"); + let expected = r#" +def extracted_function(item, self): + squared_value = item * item + if squared_value > 10: + self.consume(squared_value) + +class Processor: + def consume(self, item): + print(item) + + def process(self, data_list): + for item in data_list: + # EXTRACT-START + extracted_function(item, self) + # EXTRACT-END + return len(data_list) +"#; + assert_eq!(expected.trim(), updated.trim()); +} + +#[test] +fn extract_function_rejects_empty_selection() { + let code = r#" +def sink(values): + for value in values: + # EXTRACT-START + # EXTRACT-END + print(value) +"#; + assert!( + apply_first_extract_action(code).is_none(), + "expected no refactor action for empty selection" + ); +} + +#[test] +fn extract_function_rejects_return_statement() { + let code = r#" +def sink(values): + # EXTRACT-START + return values[0] + # EXTRACT-END +"#; + assert_no_extract_action(code); +} + +#[test] +#[ignore = "multiple insertion point choices not yet supported"] +fn extract_function_offers_inner_function_option() { + let code = r#" +def outer(xs): + # EXTRACT-START + running = 0 + for x in xs: + running += x + # EXTRACT-END + return running +"#; + let (_, _, titles) = compute_extract_actions(code); + assert!( + titles.iter().any(|title| title.contains("module scope")), + "expected at least one extract action when control flow is simple" + ); +} From 99c9a6f0fc7da4c499ebfa29cab7b5e0fa27ac93 Mon Sep 17 00:00:00 2001 From: Asuka Minato Date: Tue, 9 Dec 2025 17:22:14 +0900 Subject: [PATCH 3/3] clippy --- crates/pyrefly_bundled/src/lib.rs | 8 +-- .../state/lsp/quick_fixes/extract_function.rs | 70 ++++++++----------- 2 files changed, 35 insertions(+), 43 deletions(-) diff --git a/crates/pyrefly_bundled/src/lib.rs b/crates/pyrefly_bundled/src/lib.rs index 52f798b5a3..b5e2c4d2fa 100644 --- a/crates/pyrefly_bundled/src/lib.rs +++ b/crates/pyrefly_bundled/src/lib.rs @@ -83,10 +83,10 @@ fn extract_pyi_files_from_archive(filter: PathFilter) -> anyhow::Result ruff_python_ast::visitor::Visitor<'a> for IdentifierCollector { fn visit_expr(&mut self, expr: &'a Expr) { - if self.selection.contains_range(expr.range()) { - if let Expr::Name(name) = expr { - let ident = IdentifierRef { - name: name.id.to_string(), - position: name.range.start(), - synthetic_load: false, - }; - match name.ctx { - ExprContext::Load => self.loads.push(ident), - ExprContext::Store => self.stores.push(ident), - ExprContext::Del | ExprContext::Invalid => {} - } + if self.selection.contains_range(expr.range()) + && let Expr::Name(name) = expr + { + let ident = IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: false, + }; + match name.ctx { + ExprContext::Load => self.loads.push(ident), + ExprContext::Store => self.stores.push(ident), + ExprContext::Del | ExprContext::Invalid => {} } } ruff_python_ast::visitor::walk_expr(self, expr); } fn visit_stmt(&mut self, stmt: &'a Stmt) { - if self.selection.contains_range(stmt.range()) { - if let Stmt::AugAssign(aug) = stmt { - if let Expr::Name(name) = aug.target.as_ref() { - self.loads.push(IdentifierRef { - name: name.id.to_string(), - position: name.range.start(), - synthetic_load: true, - }); - } - } + if self.selection.contains_range(stmt.range()) + && let Stmt::AugAssign(aug) = stmt + && let Expr::Name(name) = aug.target.as_ref() + { + self.loads.push(IdentifierRef { + name: name.id.to_string(), + position: name.range.start(), + synthetic_load: true, + }); } ruff_python_ast::visitor::walk_stmt(self, stmt); } @@ -270,13 +263,12 @@ fn collect_post_selection_loads( ) -> HashSet { let mut loads = HashSet::new(); ast.visit(&mut |expr: &Expr| { - if let Expr::Name(name) = expr { - if matches!(name.ctx, ExprContext::Load) - && module_stmt_range.contains_range(name.range) - && name.range.start() > selection_end - { - loads.insert(name.id.to_string()); - } + if let Expr::Name(name) = expr + && matches!(name.ctx, ExprContext::Load) + && module_stmt_range.contains_range(name.range) + && name.range.start() > selection_end + { + loads.insert(name.id.to_string()); } }); loads @@ -309,7 +301,7 @@ fn generate_helper_name(source: &str) -> String { let mut counter = 1; loop { let candidate = if counter == 1 { - "extracted_function".to_string() + "extracted_function".to_owned() } else { format!("extracted_function_{counter}") };