diff --git a/src/tagstudio/core/library/alchemy/library.py b/src/tagstudio/core/library/alchemy/library.py index a25231e95..66ead4e6e 100644 --- a/src/tagstudio/core/library/alchemy/library.py +++ b/src/tagstudio/core/library/alchemy/library.py @@ -9,6 +9,7 @@ import re import shutil +import sys import time import unicodedata from collections.abc import Iterable, Iterator @@ -75,7 +76,6 @@ DB_VERSION_LEGACY_KEY, JSON_FILENAME, SQL_FILENAME, - TAG_CHILDREN_QUERY, ) from tagstudio.core.library.alchemy.db import make_tables from tagstudio.core.library.alchemy.enums import ( @@ -555,6 +555,20 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus: # Convert file extension list to ts_ignore file, if a .ts_ignore file does not exist self.migrate_sql_to_ts_ignore(library_dir) + session.execute( + text("CREATE INDEX IF NOT EXISTS idx_tags_name_shorthand ON tags (name, shorthand)") + ) + session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_tag_parents_child_id ON tag_parents (child_id)" + ) + ) + session.execute( + text( + "CREATE INDEX IF NOT EXISTS idx_tag_entries_entry_id ON tag_entries (entry_id)" + ) + ) + # Update DB_VERSION if loaded_db_version < DB_VERSION: self.set_version(DB_VERSION_CURRENT_KEY, DB_VERSION) @@ -1054,55 +1068,58 @@ def search_library( return res - def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]: + def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]: """Return a list of Tag records matching the query.""" + name = name or "" + name = name.lower() + + def sort_key(text: str): + priority = text.startswith(name) + p_ordering = len(text) if priority else sys.maxsize + return (not priority, p_ordering, text) + with Session(self.engine) as session: - query = select(Tag).outerjoin(TagAlias).order_by(func.lower(Tag.name)) - query = query.options( - selectinload(Tag.parent_tags), - selectinload(Tag.aliases), - ) - if limit > 0: - query = query.limit(limit) + query = select(Tag.id, Tag.name) + + if limit > 0 and not name: + query = query.order_by(Tag.name).limit(limit) if name: query = query.where( or_( Tag.name.icontains(name), Tag.shorthand.icontains(name), - TagAlias.name.icontains(name), ) ) - direct_tags = set(session.scalars(query)) - ancestor_tag_ids: list[Tag] = [] - for tag in direct_tags: - ancestor_tag_ids.extend( - list(session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag.id})) - ) + tags = list(session.execute(query)) - ancestor_tags = session.scalars( - select(Tag) - .where(Tag.id.in_(ancestor_tag_ids)) - .options(selectinload(Tag.parent_tags), selectinload(Tag.aliases)) - ) + if name: + query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name)) + tags.extend(session.execute(query)) - res = [ - direct_tags, - {at for at in ancestor_tags if at not in direct_tags}, - ] + tags.sort(key=lambda t: sort_key(t[1])) + # Use order from Tag.name or TagAlias.name depending on which comes first for each tag. + # Value=0 to avoid unnecessary copying of tag names. + tag_ids = list(dict((id, 0) for id, _ in tags).keys()) logger.info( "searching tags", search=name, limit=limit, statement=str(query), - results=len(res), + results=len(tag_ids), ) - session.expunge_all() + if limit <= 0: + limit = len(tag_ids) + tag_ids = tag_ids[:limit] - return res + hierarchy = self.get_tag_hierarchy(tag_ids) + direct_tags = [hierarchy.pop(id) for id in tag_ids] + ancestor_tags = list(hierarchy.values()) + ancestor_tags.sort(key=lambda t: sort_key(t.name)) + return direct_tags, ancestor_tags def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool: """Set the path field of an entry. diff --git a/src/tagstudio/qt/mixed/tag_search.py b/src/tagstudio/qt/mixed/tag_search.py index 53990c378..a0f2b0402 100644 --- a/src/tagstudio/qt/mixed/tag_search.py +++ b/src/tagstudio/qt/mixed/tag_search.py @@ -218,32 +218,13 @@ def update_tags(self, query: str | None = None): self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater() self.create_button_in_layout = False - # Get results for the search query - query_lower = "" if not query else query.lower() # Only use the tag limit if it's an actual number (aka not "All Tags") tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1 - tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit) - if self.exclude: - tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude} - tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude} - - # Sort and prioritize the results - results_0 = list(tag_results[0]) - results_0.sort(key=lambda tag: tag.name.lower()) - results_1 = list(tag_results[1]) - results_1.sort(key=lambda tag: tag.name.lower()) - raw_results = list(results_0 + results_1) - priority_results: set[Tag] = set() - all_results: list[Tag] = [] + direct_tags, ancestor_tags = self.lib.search_tags(name=query, limit=tag_limit) - if query and query.strip(): - for tag in raw_results: - if tag.name.lower().startswith(query_lower): - priority_results.add(tag) + all_results = [t for t in direct_tags if t.id not in self.exclude] + all_results.extend(t for t in ancestor_tags if t.id not in self.exclude) - all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [ - r for r in raw_results if r not in priority_results - ] if tag_limit > 0: all_results = all_results[:tag_limit] diff --git a/tests/test_library.py b/tests/test_library.py index 447344512..111e36116 100644 --- a/tests/test_library.py +++ b/tests/test_library.py @@ -131,10 +131,10 @@ def test_library_search(library: Library, entry_full: Entry): def test_tag_search(library: Library): tag = library.tags[0] - assert library.search_tags(tag.name.lower()) - assert library.search_tags(tag.name.upper()) - assert library.search_tags(tag.name[2:-2]) - assert library.search_tags(tag.name * 2) == [set(), set()] + assert library.search_tags(tag.name.lower())[0] + assert library.search_tags(tag.name.upper())[0] + assert library.search_tags(tag.name[2:-2])[0] + assert library.search_tags(tag.name * 2) == ([], []) def test_get_entry(library: Library, entry_min: Entry):