Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 5 additions & 53 deletions src/context/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -99,61 +99,15 @@ export const getEntitiesInRange = (
}

/**
* Parse import source from an import entity
* Get import source from an import entity
*
* Extracts the source module path from import signatures like:
* - `import { foo } from 'module'`
* - `import foo from 'module'`
* - `import * as foo from 'module'`
* Uses the pre-extracted source from AST parsing (works for all languages).
*
* @param entity - The import entity
* @returns The import source or empty string if not found
*/
const parseImportSource = (entity: ExtractedEntity): string => {
// Try to extract from signature using regex
// Common patterns: from 'source' or from "source"
const fromMatch = entity.signature.match(/from\s+['"]([^'"]+)['"]/)
if (fromMatch?.[1]) {
return fromMatch[1]
}

// For CommonJS style: require('source')
const requireMatch = entity.signature.match(/require\s*\(\s*['"]([^'"]+)['"]/)
if (requireMatch?.[1]) {
return requireMatch[1]
}

return ''
}

/**
* Check if an import is a default import
*
* @param entity - The import entity
* @returns Whether this is a default import
*/
const isDefaultImport = (entity: ExtractedEntity): boolean => {
// Default import patterns:
// import foo from 'module'
// But NOT: import { foo } from 'module'
// And NOT: import * as foo from 'module'
const signature = entity.signature
return (
/^import\s+\w+\s+from/.test(signature) &&
!/^import\s*\{/.test(signature) &&
!/^import\s*\*/.test(signature)
)
}

/**
* Check if an import is a namespace import
*
* @param entity - The import entity
* @returns Whether this is a namespace import
*/
const isNamespaceImport = (entity: ExtractedEntity): boolean => {
// Namespace import pattern: import * as foo from 'module'
return /^import\s*\*\s*as\s+\w+/.test(entity.signature)
const getImportSource = (entity: ExtractedEntity): string => {
return entity.source ?? ''
}

/**
Expand All @@ -178,9 +132,7 @@ export const getRelevantImports = (
// Map import entity to ImportInfo
const mapToImportInfo = (entity: ExtractedEntity): ImportInfo => ({
name: entity.name,
source: parseImportSource(entity),
isDefault: isDefaultImport(entity) || undefined,
isNamespace: isNamespaceImport(entity) || undefined,
source: getImportSource(entity),
})

// If not filtering, return all imports
Expand Down
9 changes: 8 additions & 1 deletion src/extract/fallback.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import type {
SyntaxNode,
} from '../types'
import { extractDocstring } from './docstring'
import { extractName, extractSignature } from './signature'
import { extractImportSource, extractName, extractSignature } from './signature'

/**
* Node types that represent extractable entities by language
Expand Down Expand Up @@ -176,6 +176,12 @@ function walkAndExtract(
// Extract docstring
const docstring = yield* extractDocstring(node, language, code)

// Extract import source for import entities
const source =
entityType === 'import'
? (extractImportSource(node, language) ?? undefined)
: undefined

// Create entity
const entity: ExtractedEntity = {
type: entityType,
Expand All @@ -192,6 +198,7 @@ function walkAndExtract(
},
parent: parentName,
node,
source,
}

entities.push(entity)
Expand Down
11 changes: 9 additions & 2 deletions src/extract/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import {
getEntityType,
} from './fallback'
import { type CompiledQuery, loadQuery, loadQuerySync } from './queries'
import { extractName, extractSignature } from './signature'
import { extractImportSource, extractName, extractSignature } from './signature'

/**
* Error when entity extraction fails
Expand Down Expand Up @@ -168,6 +168,12 @@ function matchesToEntities(
// Find parent entity
const parent = findParentEntityName(itemNode, rootNode, language)

// Extract import source for import entities
const source =
entityType === 'import'
? (extractImportSource(itemNode, language) ?? undefined)
: undefined

const entity: ExtractedEntity = {
type: entityType,
name,
Expand All @@ -183,6 +189,7 @@ function matchesToEntities(
},
parent,
node: itemNode,
source,
}

entities.push(entity)
Expand Down Expand Up @@ -359,4 +366,4 @@ export {
} from './fallback'
export type { CompiledQuery, QueryLoadError } from './queries'
export { clearQueryCache, loadQuery, loadQuerySync } from './queries'
export { extractName, extractSignature } from './signature'
export { extractImportSource, extractName, extractSignature } from './signature'
184 changes: 184 additions & 0 deletions src/extract/signature.ts
Original file line number Diff line number Diff line change
Expand Up @@ -356,3 +356,187 @@ export const extractSignature = (
export const getBodyDelimiter = (language: Language): string => {
return BODY_DELIMITERS[language]
}

/**
* Node types that represent import source/path by language
*/
const IMPORT_SOURCE_NODE_TYPES: readonly string[] = [
'string',
'string_literal',
'interpreted_string_literal', // Go
'source', // Some grammars use this field name
]

/**
* Extract the import source path from an import AST node
*
* Works for all supported languages by looking at the AST structure:
* - JS/TS: import { foo } from 'source' -> string child
* - Python: from source import foo -> 'module_name' field or dotted_name
* - Rust: use crate::module::item -> scoped_identifier or path
* - Go: import "source" -> interpreted_string_literal
* - Java: import package.Class -> scoped_identifier
*
* @param node - The import AST node
* @param language - The programming language
* @returns The import source path, or null if not found
*/
export const extractImportSource = (
node: SyntaxNode,
language: Language,
): string | null => {
// Try the 'source' field first (common in many grammars)
const sourceField = node.childForFieldName('source')
if (sourceField) {
return stripQuotes(sourceField.text)
}

// Language-specific extraction
switch (language) {
case 'typescript':
case 'javascript': {
// Look for string literal child (the 'from "..."' part)
for (const child of node.children) {
if (child.type === 'string') {
return stripQuotes(child.text)
}
}
break
}

case 'python': {
// For 'from X import Y', look for module_name field or dotted_name
const moduleNameField = node.childForFieldName('module_name')
if (moduleNameField) {
return moduleNameField.text
}
// For 'import X' style
const nameField = node.childForFieldName('name')
if (nameField) {
return nameField.text
}
// Fallback: look for dotted_name
for (const child of node.children) {
if (child.type === 'dotted_name') {
return child.text
}
}
break
}

case 'rust': {
// For 'use path::to::item', extract the path
// Look for scoped_identifier, use_wildcard, use_list, or identifier
const argumentField = node.childForFieldName('argument')
if (argumentField) {
// Get the path part (everything except the last segment if it's a use_list)
return extractRustUsePath(argumentField)
}
// Fallback: look for children that could be paths
for (const child of node.children) {
if (
child.type === 'scoped_identifier' ||
child.type === 'identifier' ||
child.type === 'use_wildcard'
) {
return extractRustUsePath(child)
}
}
break
}

case 'go': {
// For 'import "path"', look for import_spec or interpreted_string_literal
for (const child of node.children) {
// Single import: import "fmt" -> has import_spec child
if (child.type === 'import_spec') {
const pathNode = child.childForFieldName('path')
if (pathNode) {
return stripQuotes(pathNode.text)
}
// Fallback: look for string literal in import_spec
for (const specChild of child.children) {
if (specChild.type === 'interpreted_string_literal') {
return stripQuotes(specChild.text)
}
}
}
// Direct string literal (some Go grammars)
if (child.type === 'interpreted_string_literal') {
return stripQuotes(child.text)
}
// For import blocks: import ( "fmt" "os" )
if (child.type === 'import_spec_list') {
for (const spec of child.children) {
if (spec.type === 'import_spec') {
const pathNode = spec.childForFieldName('path')
if (pathNode) {
return stripQuotes(pathNode.text)
}
}
}
}
}
break
}

case 'java': {
// For 'import package.Class', look for scoped_identifier
for (const child of node.children) {
if (child.type === 'scoped_identifier') {
return child.text
}
}
break
}
}

// Fallback: look for any string-like child
for (const child of node.children) {
if (IMPORT_SOURCE_NODE_TYPES.includes(child.type)) {
return stripQuotes(child.text)
}
}

return null
}

/**
* Extract the path from a Rust use declaration
* For 'std::collections::HashMap', returns 'std::collections::HashMap'
* For 'std::collections::{HashMap, HashSet}', returns 'std::collections'
*/
const extractRustUsePath = (node: SyntaxNode): string => {
// If it's a use_list (e.g., {HashMap, HashSet}), get the parent path
if (node.type === 'use_list') {
return ''
}

// For scoped_identifier, check if the last part is a use_list
if (node.type === 'scoped_identifier') {
const lastChild = node.children[node.children.length - 1]
if (lastChild?.type === 'use_list') {
// Return everything except the use_list
const pathChild = node.childForFieldName('path')
if (pathChild) {
return pathChild.text
}
}
}

return node.text
}

/**
* Strip surrounding quotes from a string
*/
const stripQuotes = (str: string): string => {
if (
(str.startsWith('"') && str.endsWith('"')) ||
(str.startsWith("'") && str.endsWith("'")) ||
(str.startsWith('`') && str.endsWith('`'))
) {
return str.slice(1, -1)
}
return str
}
2 changes: 2 additions & 0 deletions src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ export interface ExtractedEntity {
parent: string | null
/** The underlying AST node */
node: SyntaxNode
/** Import source path (only for import entities) */
source?: string
}

/**
Expand Down
Loading