diff --git a/crates/pyrefly_types/src/display.rs b/crates/pyrefly_types/src/display.rs index b2a574fee1..583706fed2 100644 --- a/crates/pyrefly_types/src/display.rs +++ b/crates/pyrefly_types/src/display.rs @@ -7,9 +7,11 @@ //! Display a type. The complexity comes from if we have two classes with the same name, //! we want to display disambiguating information (e.g. module name or location). +use std::cell::RefCell; use std::fmt; use std::fmt::Display; +use dupe::Dupe; use pyrefly_python::module_name::ModuleName; use pyrefly_python::qname::QName; use pyrefly_util::display::Fmt; @@ -89,6 +91,8 @@ pub struct TypeDisplayContext<'a> { /// Should we display for IDE Hover? This makes type names more readable but less precise. hover: bool, always_display_module_name: bool, + /// Modules encountered while formatting, used downstream (e.g. to decide which imports are required). + modules: RefCell>, } impl<'a> TypeDisplayContext<'a> { @@ -234,6 +238,9 @@ impl<'a> TypeDisplayContext<'a> { name: &str, output: &mut impl TypeOutput, ) -> fmt::Result { + self.modules + .borrow_mut() + .insert(ModuleName::from_str(module)); if self.always_display_module_name { // write!(f, "{module}.{name}") output.write_str(&format!("{}.{}", module, name)) @@ -243,6 +250,16 @@ impl<'a> TypeDisplayContext<'a> { } } + pub fn referenced_modules(&self) -> SmallSet { + let mut modules = self.modules.borrow().clone(); + for info in self.qnames.values() { + for module in info.info.keys() { + modules.insert(module.dupe()); + } + } + modules + } + fn fmt_helper_generic( &self, t: &Type, diff --git a/pyrefly/lib/lsp/non_wasm/server.rs b/pyrefly/lib/lsp/non_wasm/server.rs index 2939242ba1..fb249aa8f7 100644 --- a/pyrefly/lib/lsp/non_wasm/server.rs +++ b/pyrefly/lib/lsp/non_wasm/server.rs @@ -2621,22 +2621,31 @@ impl Server { )?; let res = t .into_iter() - .filter_map(|(text_size, label_text, _locations)| { + .filter_map(|hint| { // If the url is a notebook cell, filter out inlay hints for other cells - if info.to_cell_for_lsp(text_size) != maybe_cell_idx { + if info.to_cell_for_lsp(hint.position) != maybe_cell_idx { return None; } - let position = info.to_lsp_position(text_size); + let position = info.to_lsp_position(hint.position); // The range is half-open, so the end position is exclusive according to the spec. if position >= range.start && position < range.end { + let mut text_edits = Vec::with_capacity(1 + hint.import_edits.len()); + text_edits.push(TextEdit { + range: Range::new(position, position), + new_text: hint.label.clone(), + }); + for (offset, import_text) in hint.import_edits { + let insert_position = info.to_lsp_position(offset); + text_edits.push(TextEdit { + range: Range::new(insert_position, insert_position), + new_text: import_text, + }); + } Some(InlayHint { position, - label: InlayHintLabel::String(label_text.clone()), + label: InlayHintLabel::String(hint.label), kind: None, - text_edits: Some(vec![TextEdit { - range: Range::new(position, position), - new_text: label_text, - }]), + text_edits: Some(text_edits), tooltip: None, padding_left: None, padding_right: None, diff --git a/pyrefly/lib/playground.rs b/pyrefly/lib/playground.rs index eb28564fc8..3d3614f6fc 100644 --- a/pyrefly/lib/playground.rs +++ b/pyrefly/lib/playground.rs @@ -521,9 +521,12 @@ impl Playground { .get_module_info(handle) .zip(transaction.inlay_hints(handle, Default::default())) .map(|(info, hints)| { - hints.into_map(|(position, label, _locations)| { - let position = Position::from_display_pos(info.display_pos(position)); - InlayHint { label, position } + hints.into_map(|hint| { + let position = Position::from_display_pos(info.display_pos(hint.position)); + InlayHint { + label: hint.label, + position, + } }) }) .unwrap_or_default() diff --git a/pyrefly/lib/state/import_tracker.rs b/pyrefly/lib/state/import_tracker.rs new file mode 100644 index 0000000000..c6bab25a19 --- /dev/null +++ b/pyrefly/lib/state/import_tracker.rs @@ -0,0 +1,212 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * This source code is licensed under the MIT license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Helpers for harvesting imports and formatting type strings for inlay hints. + +use std::cmp::Reverse; + +use dupe::Dupe; +use pyrefly_python::module_name::ModuleName; +use ruff_python_ast::ModModule; +use ruff_python_ast::Stmt; +use ruff_python_ast::StmtImport; +use starlark_map::small_set::SmallSet; + +use crate::types::display::TypeDisplayContext; +use crate::types::types::Type; + +/// Tracks imports already present in a module and can determine which modules are still missing +/// for a given set of referenced modules. Also supports alias-aware replacement when displaying +/// type strings. +#[derive(Default)] +pub struct ImportTracker { + canonical_modules: SmallSet, + alias_modules: Vec<(ModuleName, String)>, +} + +impl ImportTracker { + /// Build an import tracker from the top-level `import ...` statements in a module. + pub fn from_ast(ast: &ModModule) -> Self { + let mut tracker = Self::default(); + for stmt in &ast.body { + if let Stmt::Import(stmt_import) = stmt { + tracker.record_import(stmt_import); + } + } + tracker + .alias_modules + .sort_by_key(|(module, _)| Reverse(module.as_str().len())); + tracker + } + + /// Record an `import ...` statement into the tracker. + pub fn record_import(&mut self, stmt_import: &StmtImport) { + for alias in &stmt_import.names { + let module_name = ModuleName::from_str(alias.name.as_str()); + if let Some(asname) = &alias.asname { + self.alias_modules + .push((module_name, asname.id.to_string())); + } else { + self.canonical_modules.insert(module_name); + } + } + } + + /// Replace any module prefixes that have been imported under an alias (e.g. `import typing as t`). + pub fn apply_aliases(&self, text: &str) -> String { + if self.alias_modules.is_empty() { + return text.to_owned(); + } + let bytes = text.as_bytes(); + let mut result = String::with_capacity(text.len()); + let mut i = 0; + while i < bytes.len() { + let mut replaced = false; + for (module, alias) in &self.alias_modules { + let module_str = module.as_str(); + if module_str.is_empty() { + continue; + } + let module_bytes = module_str.as_bytes(); + if i + module_bytes.len() <= bytes.len() + && &bytes[i..i + module_bytes.len()] == module_bytes + && Self::is_boundary(bytes, i, i + module_bytes.len()) + { + result.push_str(alias); + i += module_bytes.len(); + replaced = true; + break; + } + } + if !replaced { + result.push(bytes[i] as char); + i += 1; + } + } + result + } + + /// Modules that are referenced in the type string but not yet imported (excluding builtins/current). + pub fn missing_modules( + &self, + modules: &SmallSet, + current_module: ModuleName, + ) -> SmallSet { + let mut missing = SmallSet::new(); + for module in modules.iter() { + let module = module.dupe(); + if module.as_str().is_empty() + || module == current_module + || module == ModuleName::builtins() + || module == ModuleName::extra_builtins() + { + continue; + } + if self.module_is_imported(module) { + continue; + } + missing.insert(module); + } + missing + } + + fn module_is_imported(&self, module: ModuleName) -> bool { + self.alias_for(module).is_some() || self.has_canonical(module) + } + + fn alias_for(&self, module: ModuleName) -> Option { + let target = module.as_str(); + for (alias_module, alias_name) in &self.alias_modules { + let alias_module_str = alias_module.as_str(); + if alias_module_str.is_empty() { + continue; + } + if target == alias_module_str { + return Some(alias_name.clone()); + } + if target.len() > alias_module_str.len() + && target.starts_with(alias_module_str) + && target.as_bytes()[alias_module_str.len()] == b'.' + { + let remainder = &target[alias_module_str.len()..]; + return Some(format!("{alias_name}{remainder}")); + } + } + None + } + + fn has_canonical(&self, module: ModuleName) -> bool { + let target = module.as_str(); + self.canonical_modules.iter().any(|imported| { + let imported_str = imported.as_str(); + imported_str == target + || (target.len() > imported_str.len() + && target.starts_with(imported_str) + && target.as_bytes()[imported_str.len()] == b'.') + }) + } + + fn is_boundary(bytes: &[u8], start: usize, end: usize) -> bool { + (start == 0 || !Self::is_ident(bytes[start - 1])) + && (end == bytes.len() || !Self::is_ident(bytes[end])) + } + + fn is_ident(byte: u8) -> bool { + matches!(byte, b'a'..=b'z' | b'A'..=b'Z' | b'0'..=b'9' | b'_') + } +} + +/// Produce a user-facing type string (without module qualifiers) together with all referenced modules +/// (captured with module qualification) so callers can insert the necessary imports. +pub fn format_type_for_annotation(ty: &Type) -> (String, SmallSet) { + // First pass: force module names so referenced_modules collects everything, but ignore the text. + let mut module_ctx = TypeDisplayContext::new(&[ty]); + module_ctx.always_display_module_name_except_builtins(); + let _ = module_ctx.display(ty).to_string(); + let modules = module_ctx.referenced_modules(); + + // Second pass: produce a concise label without module qualifiers. + let display_ctx = TypeDisplayContext::new(&[ty]); + let text = display_ctx.display(ty).to_string(); + (text, modules) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn aliases_are_applied_at_boundaries_only() { + let module = ModuleName::from_str("typing"); + let mut tracker = ImportTracker::default(); + tracker.alias_modules.push((module, "t".to_owned())); + assert_eq!(tracker.apply_aliases("typing.Literal"), "t.Literal"); + // Do not replace inside longer identifiers + assert_eq!(tracker.apply_aliases("mytyping"), "mytyping"); + } + + #[test] + fn missing_modules_skips_builtin_and_current() { + let tracker = ImportTracker::default(); + let mut modules = SmallSet::new(); + let current = ModuleName::from_str("pkg.mod"); + modules.insert(current.dupe()); + modules.insert(ModuleName::builtins()); + modules.insert(ModuleName::from_str("typing")); + let missing = tracker.missing_modules(&modules, current); + assert!(missing.contains(&ModuleName::from_str("typing"))); + assert_eq!(missing.len(), 1); + } + + #[test] + fn format_type_collects_modules_but_returns_short_label() { + let ty = Type::LiteralString; + let (text, modules) = format_type_for_annotation(&ty); + assert_eq!(text, "LiteralString"); + assert!(modules.contains(&ModuleName::from_str("typing"))); + } +} diff --git a/pyrefly/lib/state/lsp.rs b/pyrefly/lib/state/lsp.rs index 3dc2c6596b..de90d17349 100644 --- a/pyrefly/lib/state/lsp.rs +++ b/pyrefly/lib/state/lsp.rs @@ -67,6 +67,8 @@ use crate::state::ide::IntermediateDefinition; use crate::state::ide::import_regular_import_edit; use crate::state::ide::insert_import_edit; use crate::state::ide::key_to_intermediate_definition; +use crate::state::import_tracker::ImportTracker; +use crate::state::import_tracker::format_type_for_annotation; use crate::state::lsp_attributes::AttributeContext; use crate::state::require::Require; use crate::state::state::CancellableTransaction; @@ -306,6 +308,18 @@ pub enum AnnotationKind { Variable, } +#[derive(Clone, Debug)] +pub struct InlayHintWithEdits { + pub position: TextSize, + pub label: String, + pub import_edits: Vec<(TextSize, String)>, +} + +struct RenderedTypeHint { + text: String, + import_edits: Vec<(TextSize, String)>, +} + impl IdentifierWithContext { fn from_stmt_import(id: &Identifier, alias: &Alias) -> Self { let identifier = id.clone(); @@ -2648,6 +2662,202 @@ impl<'a> Transaction<'a> { res.sort_by_key(|(score, _, _, _)| Reverse(*score)); res.into_map(|(_, handle, name, export)| (handle, name, export)) } + + pub fn inlay_hints( + &self, + handle: &Handle, + inlay_hint_config: InlayHintConfig, + ) -> Option> { + let is_interesting = |e: &Expr, ty: &Type, class_name: Option<&Name>| { + !ty.is_any() + && match e { + Expr::Tuple(tuple) => { + !tuple.elts.is_empty() && tuple.elts.iter().all(|x| !Ast::is_literal(x)) + } + Expr::Call(ExprCall { func, .. }) => { + if let Expr::Name(name) = &**func + && let Some(class_name) = class_name + { + *name.id() != *class_name + } else if let Expr::Attribute(attr) = &**func + && let Some(class_name) = class_name + { + *attr.attr.id() != *class_name + } else { + true + } + } + _ => !Ast::is_literal(e), + } + }; + let bindings = self.get_bindings(handle)?; + let ast = self.get_ast(handle); + let import_tracker = ast.as_deref().map(ImportTracker::from_ast); + let mut res: Vec = Vec::new(); + for idx in bindings.keys::() { + match bindings.idx_to_key(idx) { + key @ Key::ReturnType(id) => { + if inlay_hint_config.function_return_types { + match bindings.get(bindings.key_to_idx(&Key::Definition(*id))) { + Binding::Function(x, _pred, _class_meta) => { + if matches!(&bindings.get(idx), Binding::ReturnType(ret) if !ret.kind.has_return_annotation()) + && let Some(mut ty) = self.get_type(handle, key) + && !ty.is_any() + { + let fun = bindings.get(bindings.get(*x).undecorated_idx); + if fun.def.is_async + && let Some(Some((_, _, return_ty))) = self + .ad_hoc_solve(handle, |solver| { + solver.unwrap_coroutine(&ty) + }) + { + ty = return_ty; + } + let rendered = self.render_type_hint( + &ty, + handle, + import_tracker.as_ref(), + ast.as_deref(), + ); + res.push(InlayHintWithEdits { + position: fun.def.parameters.range.end(), + label: format!(" -> {}", rendered.text), + import_edits: rendered.import_edits, + }); + } + } + _ => {} + } + } + } + key @ Key::Definition(_) + if inlay_hint_config.variable_types + && let Some(ty) = self.get_type(handle, key) => + { + let e = match bindings.get(idx) { + Binding::NameAssign(_, None, e, _) => Some(&**e), + Binding::Expr(None, e) => Some(e), + _ => None, + }; + // If the inferred type is a class type w/ no type arguments and the + // RHS is a call to a function that's the same name as the inferred class, + // we assume it's a constructor and do not display an inlay hint + let class_name = if let Type::ClassType(cls) = &ty + && cls.targs().is_empty() + { + Some(cls.name()) + } else { + None + }; + if let Some(e) = e + && is_interesting(e, &ty, class_name) + { + let rendered = self.render_type_hint( + &ty, + handle, + import_tracker.as_ref(), + ast.as_deref(), + ); + res.push(InlayHintWithEdits { + position: key.range().end(), + label: format!(": {}", rendered.text), + import_edits: rendered.import_edits, + }); + } + } + _ => {} + } + } + + if inlay_hint_config.call_argument_names != AllOffPartial::Off { + res.extend(self.add_inlay_hints_for_positional_function_args(handle)); + } + + Some(res) + } + + fn add_inlay_hints_for_positional_function_args( + &self, + handle: &Handle, + ) -> Vec { + let mut param_hints: Vec = Vec::new(); + + if let Some(mod_module) = self.get_ast(handle) { + let function_calls = Self::collect_function_calls_from_ast(mod_module); + + for call in function_calls { + if let Some(answers) = self.get_answers(handle) { + let callee_type = if let Some((overloads, chosen_idx)) = + answers.get_all_overload_trace(call.arguments.range) + { + // If we have overload information, use the chosen overload + overloads + .get(chosen_idx.unwrap_or_default()) + .map(|c| Type::Callable(Box::new(c.clone()))) + } else { + // Otherwise, try to get the type of the callee directly + answers.get_type_trace(call.func.range()) + }; + + if let Some(params) = + callee_type.and_then(Self::normalize_singleton_function_type_into_params) + { + for (arg_idx, arg) in call.arguments.args.iter().enumerate() { + // Skip keyword arguments - they already show their parameter name + let is_keyword_arg = call + .arguments + .keywords + .iter() + .any(|kw| kw.value.range() == arg.range()); + + if !is_keyword_arg + && let Some( + Param::Pos(name, _, _) + | Param::PosOnly(Some(name), _, _) + | Param::KwOnly(name, _, _), + ) = params.get(arg_idx) + && name.as_str() != "self" + && name.as_str() != "cls" + { + param_hints.push(InlayHintWithEdits { + position: arg.range().start(), + label: format!("{}= ", name.as_str()), + import_edits: Vec::new(), + }); + } + } + } + } + } + } + + param_hints.sort_by_key(|hint| hint.position); + param_hints + } + + fn render_type_hint( + &self, + ty: &Type, + handle: &Handle, + tracker: Option<&ImportTracker>, + ast: Option<&ModModule>, + ) -> RenderedTypeHint { + let (mut text, modules) = format_type_for_annotation(ty); + let mut import_edits = Vec::new(); + if let (Some(tracker), Some(ast)) = (tracker, ast) { + text = tracker.apply_aliases(&text); + for module in tracker + .missing_modules(&modules, handle.module()) + .into_iter() + { + if let Some(handle_to_import) = self.import_handle(handle, module, None).finding() { + let (position, insert_text) = import_regular_import_edit(ast, handle_to_import); + import_edits.push((position, insert_text)); + } + } + } + RenderedTypeHint { text, import_edits } + } } impl<'a> CancellableTransaction<'a> { diff --git a/pyrefly/lib/state/mod.rs b/pyrefly/lib/state/mod.rs index 358473b832..f6bc86516c 100644 --- a/pyrefly/lib/state/mod.rs +++ b/pyrefly/lib/state/mod.rs @@ -9,6 +9,7 @@ pub mod dirty; pub mod epoch; pub mod errors; pub mod ide; +pub mod import_tracker; pub mod load; pub mod loader; pub mod lsp; diff --git a/pyrefly/lib/test/lsp/inlay_hint.rs b/pyrefly/lib/test/lsp/inlay_hint.rs index 223241482f..a926c80b94 100644 --- a/pyrefly/lib/test/lsp/inlay_hint.rs +++ b/pyrefly/lib/test/lsp/inlay_hint.rs @@ -22,14 +22,14 @@ fn generate_inlay_hint_report(code: &str, hint_config: InlayHintConfig) -> Strin report.push_str(name); report.push_str(".py\n"); let handle = handles.get(name).unwrap(); - for (pos, hint, _) in state + for hint in state .transaction() .inlay_hints(handle, hint_config) .unwrap() { - report.push_str(&code_frame_of_source_at_position(code, pos)); + report.push_str(&code_frame_of_source_at_position(code, hint.position)); report.push_str(" inlay-hint: `"); - report.push_str(&hint); + report.push_str(&hint.label); report.push_str("`\n\n"); } report.push('\n'); diff --git a/pyrefly/lib/test/lsp/lsp_interaction/inlay_hint.rs b/pyrefly/lib/test/lsp/lsp_interaction/inlay_hint.rs index beda108aed..141ec1d4cc 100644 --- a/pyrefly/lib/test/lsp/lsp_interaction/inlay_hint.rs +++ b/pyrefly/lib/test/lsp/lsp_interaction/inlay_hint.rs @@ -36,6 +36,10 @@ fn test_inlay_hint_default_config() { "textEdits":[{ "newText":" -> tuple[Literal[1], Literal[2]]", "range":{"end":{"character":21,"line":6},"start":{"character":21,"line":6}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] }, { @@ -44,6 +48,10 @@ fn test_inlay_hint_default_config() { "textEdits":[{ "newText":": tuple[Literal[1], Literal[2]]", "range":{"end":{"character":6,"line":11},"start":{"character":6,"line":11}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] }, { @@ -52,6 +60,10 @@ fn test_inlay_hint_default_config() { "textEdits":[{ "newText":" -> Literal[0]", "range":{"end":{"character":15,"line":14},"start":{"character":15,"line":14}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] } ]), @@ -154,6 +166,10 @@ fn test_inlay_hint_disable_variables() { "textEdits":[{ "newText":" -> tuple[Literal[1], Literal[2]]", "range":{"end":{"character":21,"line":6},"start":{"character":21,"line":6}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] }, { @@ -162,6 +178,10 @@ fn test_inlay_hint_disable_variables() { "textEdits":[{ "newText":" -> Literal[0]", "range":{"end":{"character":15,"line":14},"start":{"character":15,"line":14}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] }]), ); @@ -199,6 +219,10 @@ fn test_inlay_hint_disable_returns() { "textEdits":[{ "newText":": tuple[Literal[1], Literal[2]]", "range":{"end":{"character":6,"line":11},"start":{"character":6,"line":11}} + }, + { + "newText":"import typing\n", + "range":{"end":{"character":0,"line":6},"start":{"character":0,"line":6}} }] }]), ); diff --git a/pyrefly/lib/test/lsp/lsp_interaction/notebook_inlay_hint.rs b/pyrefly/lib/test/lsp/lsp_interaction/notebook_inlay_hint.rs index 5edbf1a4d6..dc33fbf514 100644 --- a/pyrefly/lib/test/lsp/lsp_interaction/notebook_inlay_hint.rs +++ b/pyrefly/lib/test/lsp/lsp_interaction/notebook_inlay_hint.rs @@ -39,6 +39,9 @@ fn test_inlay_hints() { "textEdits": [{ "newText": " -> tuple[Literal[1], Literal[2]]", "range": {"end": {"character": 21, "line": 0}, "start": {"character": 21, "line": 0}} + }, { + "newText": "import typing\n", + "range": {"end": {"character": 0, "line": 0}, "start": {"character": 0, "line": 0}} }] }]), ); @@ -52,6 +55,9 @@ fn test_inlay_hints() { "textEdits": [{ "newText": ": tuple[Literal[1], Literal[2]]", "range": {"end": {"character": 6, "line": 0}, "start": {"character": 6, "line": 0}} + }, { + "newText": "import typing\n", + "range": {"end": {"character": 0, "line": 0}, "start": {"character": 0, "line": 0}} }] }]), ); @@ -65,6 +71,9 @@ fn test_inlay_hints() { "textEdits": [{ "newText": " -> Literal[0]", "range": {"end": {"character": 15, "line": 0}, "start": {"character": 15, "line": 0}} + }, { + "newText": "import typing\n", + "range": {"end": {"character": 0, "line": 0}, "start": {"character": 0, "line": 0}} }] }]), );