Skip to content

Commit b523c48

Browse files
authored
Merge pull request #25 from simple-repository/feature/refined-search
Refine the search implementation
2 parents 057e290 + 27452f3 commit b523c48

File tree

3 files changed

+665
-190
lines changed

3 files changed

+665
-190
lines changed

simple_repository_browser/_search.py

Lines changed: 287 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from __future__ import annotations
2+
13
import dataclasses
24
from enum import Enum
35
import re
@@ -45,72 +47,289 @@ def normalise_name(name: str) -> str:
4547
return re.sub(r"[-_.]+", "-", name).lower()
4648

4749

48-
# A safe SQL statement must not use *any* user-defined input in the resulting first argument (the SQL query),
49-
# rather any user input MUST be provided as part of the arguments (second part of the value), which will be passed
50-
# to SQLITE to deal with.
51-
SafeSQLStmt = typing.Tuple[str, typing.Tuple[typing.Any, ...]]
52-
53-
54-
def prepare_name(term: Filter) -> SafeSQLStmt:
55-
if term.value.startswith('"'):
56-
# Match the phrase precisely.
57-
value = term.value[1:-1]
58-
else:
59-
value = normalise_name(term.value)
60-
value = value.replace("*", "%")
61-
return "canonical_name LIKE ?", (f"%{value}%",)
62-
63-
64-
def prepare_summary(term: Filter) -> SafeSQLStmt:
65-
if term.value.startswith('"'):
66-
# Match the phrase precisely.
67-
value = term.value[1:-1]
68-
else:
69-
value = term.value
70-
value = value.replace("*", "%")
71-
return "summary LIKE ?", (f"%{value}%",)
72-
73-
74-
def build_sql(term: typing.Union[Term, typing.Tuple[Term, ...]]) -> SafeSQLStmt:
75-
# Return the query and params to be used in SQL. The query MUST NOT be produced from untrusted input, as that is vulnerable to SQL injection.
76-
# Instead, any user input must be in the parameters, which undergo SQLite's built-in cleaning.
77-
if isinstance(term, tuple):
78-
if len(term) == 0:
79-
return "", ()
80-
81-
# No known query can produce a multi-value term
82-
assert len(term) == 1
83-
return build_sql(term[0])
84-
85-
if isinstance(term, Filter):
86-
if term.filter_on == FilterOn.name_or_summary:
87-
sql1, terms1 = prepare_name(term)
88-
sql2, terms2 = prepare_summary(term)
89-
return f"({sql1} OR {sql2})", terms1 + terms2
90-
elif term.filter_on == FilterOn.name:
91-
return prepare_name(term)
92-
elif term.filter_on == FilterOn.summary:
93-
return prepare_summary(term)
50+
@dataclasses.dataclass(frozen=True)
class SQLBuilder:
    """Frozen bundle of a WHERE clause, an ORDER BY clause and their parameters.

    The ``with_*`` methods never mutate; they hand back a modified copy.
    """

    # WHERE condition without the "WHERE" keyword; "" means no filtering.
    where_clause: str
    # Values bound to the "?" placeholders of ``where_clause``, in order.
    where_params: tuple[typing.Any, ...]
    # ORDER BY fragment, inserted into the query verbatim.
    order_clause: str
    # Values bound to the "?" placeholders of ``order_clause``, in order.
    order_params: tuple[typing.Any, ...]
    # Context gathered while the WHERE clause was being built.
    search_context: SearchContext

    def build_complete_query(
        self,
        base_select: str,
        limit: int,
        offset: int,
    ) -> tuple[str, tuple[typing.Any, ...]]:
        """Assemble the final statement, including ``LIMIT ? OFFSET ?``.

        Args:
            base_select: Leading part of the statement (placed first, verbatim).
            limit: Value bound to the ``LIMIT`` placeholder.
            offset: Value bound to the ``OFFSET`` placeholder.

        Returns:
            ``(sql, params)`` with params ordered WHERE, ORDER BY, limit, offset
            so they line up with the placeholders in ``sql``.
        """
        where_part = ""
        if self.where_clause:
            where_part = f"WHERE {self.where_clause}"

        sql = "{} {} {} LIMIT ? OFFSET ?".format(
            base_select, where_part, self.order_clause
        )
        paging = (limit, offset)
        return sql, self.where_params + self.order_params + paging

    def with_where(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder:
        """Return a copy whose WHERE clause and parameters are replaced."""
        replaced = dataclasses.replace(
            self,
            where_clause=clause,
            where_params=params,
        )
        return replaced

    def with_order(self, clause: str, params: tuple[typing.Any, ...]) -> SQLBuilder:
        """Return a copy whose ORDER BY clause and parameters are replaced."""
        replaced = dataclasses.replace(
            self,
            order_clause=clause,
            order_params=params,
        )
        return replaced
78+
79+
80+
@dataclasses.dataclass(frozen=True)
class SearchContext:
    """Names and patterns gathered while the WHERE clause is being built.

    Instances are immutable; every method returns either ``self`` (when
    nothing would change) or a new instance.
    """

    # Distinct exact names, in first-seen order.
    exact_names: tuple[str, ...] = ()
    # Distinct fuzzy patterns, in first-seen order.
    fuzzy_patterns: tuple[str, ...] = ()

    def with_exact_name(self, name: str) -> SearchContext:
        """Return a context that also records ``name`` (no-op if present)."""
        if name in self.exact_names:
            return self
        extended = self.exact_names + (name,)
        return dataclasses.replace(self, exact_names=extended)

    def with_fuzzy_pattern(self, pattern: str) -> SearchContext:
        """Return a context that also records ``pattern`` (no-op if present)."""
        if pattern in self.fuzzy_patterns:
            return self
        extended = self.fuzzy_patterns + (pattern,)
        return dataclasses.replace(self, fuzzy_patterns=extended)

    def merge(self, other: SearchContext) -> SearchContext:
        """Combine with ``other``; used when joining the sides of AND/OR.

        Entries of ``self`` come first; entries of ``other`` that ``self``
        already holds are skipped.
        """
        extra_names = [
            candidate
            for candidate in other.exact_names
            if candidate not in self.exact_names
        ]
        extra_patterns = [
            candidate
            for candidate in other.fuzzy_patterns
            if candidate not in self.fuzzy_patterns
        ]
        return dataclasses.replace(
            self,
            exact_names=self.exact_names + tuple(extra_names),
            fuzzy_patterns=self.fuzzy_patterns + tuple(extra_patterns),
        )
115+
116+
117+
class SearchCompiler:
118+
"""Extensible visitor-pattern compiler for search terms to SQL.
119+
120+
Uses AST-style method dispatch: visit_TermName maps to handle_term_TermName.
121+
Subclasses can override specific handlers for customisation.
122+
"""
123+
124+
@classmethod
125+
def compile(cls, term: Term | None) -> SQLBuilder:
126+
"""Compile search terms into SQL WHERE and ORDER BY clauses."""
127+
if term is None:
128+
return SQLBuilder(
129+
where_clause="",
130+
where_params=(),
131+
order_clause="",
132+
order_params=(),
133+
search_context=SearchContext(),
134+
)
135+
136+
# Build WHERE clause and collect context
137+
context = SearchContext()
138+
where_clause, where_params, final_context = cls._visit_term(term, context)
139+
140+
# Build ORDER BY clause based on collected context
141+
order_clause, order_params = cls._build_ordering_from_context(final_context)
142+
143+
return SQLBuilder(
144+
where_clause=where_clause,
145+
where_params=where_params,
146+
order_clause=order_clause,
147+
order_params=order_params,
148+
search_context=final_context,
149+
)
150+
151+
@classmethod
152+
def _visit_term(
153+
cls, term: Term, context: SearchContext
154+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
155+
"""Dispatch to appropriate handler using AST-style method naming."""
156+
method_name = f"handle_term_{type(term).__name__}"
157+
handler = getattr(cls, method_name, None)
158+
if handler is None:
159+
raise ValueError(f"No handler for term type {type(term).__name__}")
160+
return handler(term, context)
161+
162+
@classmethod
163+
def handle_term_Filter(
164+
cls, term: Filter, context: SearchContext
165+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
166+
"""Dispatch to field-specific filter handler."""
167+
match term.filter_on:
168+
case FilterOn.name_or_summary:
169+
return cls.handle_filter_name_or_summary(term, context)
170+
case FilterOn.name:
171+
return cls.handle_filter_name(term, context)
172+
case FilterOn.summary:
173+
return cls.handle_filter_summary(term, context)
174+
case _:
175+
raise ValueError(f"Unhandled filter on {term.filter_on}")
176+
177+
@classmethod
178+
def handle_term_And(
179+
cls, term: And, context: SearchContext
180+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
181+
lhs_sql, lhs_params, lhs_context = cls._visit_term(term.lhs, context)
182+
rhs_sql, rhs_params, rhs_context = cls._visit_term(term.rhs, context)
183+
184+
merged_context = lhs_context.merge(rhs_context)
185+
return f"({lhs_sql} AND {rhs_sql})", lhs_params + rhs_params, merged_context
186+
187+
@classmethod
188+
def handle_term_Or(
189+
cls, term: Or, context: SearchContext
190+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
191+
lhs_sql, lhs_params, lhs_context = cls._visit_term(term.lhs, context)
192+
rhs_sql, rhs_params, rhs_context = cls._visit_term(term.rhs, context)
193+
194+
merged_context = lhs_context.merge(rhs_context)
195+
return f"({lhs_sql} OR {rhs_sql})", lhs_params + rhs_params, merged_context
196+
197+
@classmethod
198+
def handle_term_Not(
199+
cls, term: Not, context: SearchContext
200+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
201+
inner_sql, inner_params, _ = cls._visit_term(term.term, context)
202+
return f"(NOT {inner_sql})", inner_params, context
203+
204+
@classmethod
205+
def handle_filter_name(
206+
cls, term: Filter, context: SearchContext
207+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
208+
if term.value.startswith('"'):
209+
# Exact quoted match
210+
value = term.value[1:-1]
211+
normalised = normalise_name(value)
212+
new_context = context.with_exact_name(normalised)
213+
return "canonical_name = ?", (normalised,), new_context
94214
else:
95-
raise ValueError(f"Unhandled filter on {term.filter_on}")
96-
elif isinstance(term, And):
97-
sql1, terms1 = build_sql(term.lhs)
98-
sql2, terms2 = build_sql(term.rhs)
99-
return f"({sql1} AND {sql2})", terms1 + terms2
100-
elif isinstance(term, Or):
101-
sql1, terms1 = build_sql(term.lhs)
102-
sql2, terms2 = build_sql(term.rhs)
103-
return f"({sql1} OR {sql2})", terms1 + terms2
104-
elif isinstance(term, Not):
105-
sql1, terms1 = build_sql(term.term)
106-
return f"(Not {sql1})", terms1
107-
else:
108-
raise ValueError(f"unknown term type {type(term)}")
109-
110-
111-
def query_to_sql(query) -> SafeSQLStmt:
112-
terms = parse(query)
113-
return build_sql(terms)
215+
normalised = normalise_name(term.value)
216+
if "*" in term.value:
217+
# Fuzzy wildcard search - respect wildcard position
218+
# "numpy*" > "numpy%", "*numpy" > "%numpy", "*numpy*" > "%numpy%"
219+
pattern = normalised.replace("*", "%")
220+
new_context = context.with_fuzzy_pattern(pattern)
221+
return "canonical_name LIKE ?", (pattern,), new_context
222+
else:
223+
# Simple name search
224+
new_context = context.with_exact_name(normalised)
225+
return "canonical_name LIKE ?", (f"%{normalised}%",), new_context
226+
227+
@classmethod
228+
def handle_filter_summary(
229+
cls, term: Filter, context: SearchContext
230+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
231+
if term.value.startswith('"'):
232+
value = term.value[1:-1]
233+
else:
234+
value = term.value
235+
value = value.replace("*", "%")
236+
return "summary LIKE ?", (f"%{value}%",), context
237+
238+
@classmethod
239+
def handle_filter_name_or_summary(
240+
cls, term: Filter, context: SearchContext
241+
) -> tuple[str, tuple[typing.Any, ...], SearchContext]:
242+
"""Handle filtering across both name and summary fields."""
243+
name_sql, name_params, name_context = cls.handle_filter_name(term, context)
244+
summary_sql, summary_params, _ = cls.handle_filter_summary(term, context)
245+
246+
combined_sql = f"({name_sql} OR {summary_sql})"
247+
combined_params = name_params + summary_params
248+
return combined_sql, combined_params, name_context
249+
250+
@classmethod
251+
def _build_ordering_from_context(
252+
cls, context: SearchContext
253+
) -> tuple[str, tuple[typing.Any, ...]]:
254+
"""Build mixed ordering for exact names and fuzzy patterns."""
255+
256+
exact_names, fuzzy_patterns = context.exact_names, context.fuzzy_patterns
257+
order_parts = []
258+
all_params = []
259+
260+
# Build single comprehensive CASE statement for priority
261+
case_conditions = []
262+
263+
# Add exact match conditions (priority 0)
264+
for name in exact_names:
265+
case_conditions.append(f"WHEN canonical_name = ? THEN 0")
266+
all_params.append(name)
267+
268+
# Add fuzzy pattern conditions (priority 1)
269+
for pattern in fuzzy_patterns:
270+
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 1")
271+
all_params.append(f"%{pattern}%")
272+
273+
# Add exact-related conditions (priority 2)
274+
for name in exact_names:
275+
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 2") # prefix
276+
case_conditions.append(f"WHEN canonical_name LIKE ? THEN 2") # suffix
277+
all_params.extend([f"{name}%", f"%{name}"])
278+
279+
if case_conditions:
280+
cond = "\n".join(case_conditions)
281+
priority_expr = f"""
282+
CASE
283+
{cond}
284+
ELSE 3
285+
END
286+
"""
287+
order_parts.append(priority_expr)
288+
289+
# Length-based ordering for fuzzy matches (reuse same pattern logic)
290+
if fuzzy_patterns:
291+
length_conditions = []
292+
for pattern in fuzzy_patterns:
293+
length_conditions.append(f"canonical_name LIKE ?")
294+
all_params.append(f"%{pattern}%")
295+
296+
length_expr = f"CASE WHEN ({' OR '.join(length_conditions)}) THEN LENGTH(canonical_name) ELSE 999999 END"
297+
order_parts.append(length_expr)
298+
299+
# Prefix distance for exact names
300+
if exact_names:
301+
distance_conditions = []
302+
for name in exact_names:
303+
distance_conditions.append(
304+
f"WHEN INSTR(canonical_name, ?) > 0 THEN (INSTR(canonical_name, ?) - 1)"
305+
)
306+
all_params.extend([name, name])
307+
308+
if distance_conditions:
309+
cond = "\n".join(distance_conditions)
310+
distance_expr = f"""
311+
CASE
312+
{cond}
313+
ELSE 999999
314+
END
315+
"""
316+
order_parts.append(distance_expr)
317+
318+
# Alphabetical fallback
319+
order_parts.append("canonical_name")
320+
321+
order_clause = f"ORDER BY {', '.join(order_parts)}"
322+
return order_clause, tuple(all_params)
323+
324+
325+
def build_sql(term: Term | None) -> SQLBuilder:
    """Compile ``term`` into WHERE / ORDER BY clauses via ``SearchCompiler``."""
    compiled = SearchCompiler.compile(term)
    return compiled
328+
329+
330+
def query_to_sql(query) -> SQLBuilder:
    """Parse the raw ``query`` text and compile the result with ``build_sql``."""
    return build_sql(parse(query))
114333

115334

116335
grammar = parsley.makeGrammar(
@@ -141,8 +360,8 @@ def query_to_sql(query) -> SafeSQLStmt:
141360
|filter:filter -> filter
142361
|'-' filters:filters -> Not(filters)
143362
)
144-
search_terms = (filters+:filters -> tuple(filters)
145-
| -> ())
363+
search_terms = (filters:filters -> filters
364+
| -> None)
146365
"""),
147366
{
148367
"And": And,
@@ -154,21 +373,8 @@ def query_to_sql(query) -> SafeSQLStmt:
154373
)
155374

156375

157-
def parse(query: str) -> typing.Tuple[Term, ...]:
376+
def parse(query: str) -> Term | None:
    """Run the ``search_terms`` grammar rule over the whitespace-stripped query."""
    stripped = query.strip()
    return grammar(stripped).search_terms()
159378

160379

161380
ParseError = parsley.ParseError
162-
163-
164-
def simple_name_from_query(terms: typing.Tuple[Term, ...]) -> typing.Optional[str]:
165-
"""If possible, give a simple (normalized) package name which represents the query terms provided"""
166-
for term in terms:
167-
if isinstance(term, Filter):
168-
if term.filter_on in [FilterOn.name_or_summary, FilterOn.name]:
169-
if "*" in term.value or '"' in term.value:
170-
break
171-
return normalise_name(term.value)
172-
else:
173-
break
174-
return None

0 commit comments

Comments
 (0)