Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 45 additions & 28 deletions src/tagstudio/core/library/alchemy/library.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import re
import shutil
import sys
import time
import unicodedata
from collections.abc import Iterable, Iterator
Expand Down Expand Up @@ -75,7 +76,6 @@
DB_VERSION_LEGACY_KEY,
JSON_FILENAME,
SQL_FILENAME,
TAG_CHILDREN_QUERY,
)
from tagstudio.core.library.alchemy.db import make_tables
from tagstudio.core.library.alchemy.enums import (
Expand Down Expand Up @@ -555,6 +555,20 @@ def open_sqlite_library(self, library_dir: Path, is_new: bool) -> LibraryStatus:
# Convert file extension list to ts_ignore file, if a .ts_ignore file does not exist
self.migrate_sql_to_ts_ignore(library_dir)

session.execute(
text("CREATE INDEX IF NOT EXISTS idx_tags_name_shorthand ON tags (name, shorthand)")
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_parents_child_id ON tag_parents (child_id)"
)
)
session.execute(
text(
"CREATE INDEX IF NOT EXISTS idx_tag_entries_entry_id ON tag_entries (entry_id)"
)
)

# Update DB_VERSION
if loaded_db_version < DB_VERSION:
self.set_version(DB_VERSION_CURRENT_KEY, DB_VERSION)
Expand Down Expand Up @@ -1054,55 +1068,58 @@ def search_library(

return res

def search_tags(self, name: str | None, limit: int = 100) -> list[set[Tag]]:
def search_tags(self, name: str | None, limit: int = 100) -> tuple[list[Tag], list[Tag]]:
"""Return a list of Tag records matching the query."""
name = name or ""
name = name.lower()

def sort_key(text: str):
priority = text.startswith(name)
p_ordering = len(text) if priority else sys.maxsize
return (not priority, p_ordering, text)

with Session(self.engine) as session:
query = select(Tag).outerjoin(TagAlias).order_by(func.lower(Tag.name))
query = query.options(
selectinload(Tag.parent_tags),
selectinload(Tag.aliases),
)
if limit > 0:
query = query.limit(limit)
query = select(Tag.id, Tag.name)

if limit > 0 and not name:
query = query.order_by(Tag.name).limit(limit)

if name:
query = query.where(
or_(
Tag.name.icontains(name),
Tag.shorthand.icontains(name),
TagAlias.name.icontains(name),
)
)

direct_tags = set(session.scalars(query))
ancestor_tag_ids: list[Tag] = []
for tag in direct_tags:
ancestor_tag_ids.extend(
list(session.scalars(TAG_CHILDREN_QUERY, {"tag_id": tag.id}))
)
tags = list(session.execute(query))

ancestor_tags = session.scalars(
select(Tag)
.where(Tag.id.in_(ancestor_tag_ids))
.options(selectinload(Tag.parent_tags), selectinload(Tag.aliases))
)
if name:
query = select(TagAlias.tag_id, TagAlias.name).where(TagAlias.name.icontains(name))
tags.extend(session.execute(query))

res = [
direct_tags,
{at for at in ancestor_tags if at not in direct_tags},
]
tags.sort(key=lambda t: sort_key(t[1]))
# Use order from Tag.name or TagAlias.name depending on which comes first for each tag.
# Value=0 to avoid unnecessary copying of tag names.
tag_ids = list(dict((id, 0) for id, _ in tags).keys())

logger.info(
"searching tags",
search=name,
limit=limit,
statement=str(query),
results=len(res),
results=len(tag_ids),
)

session.expunge_all()
if limit <= 0:
limit = len(tag_ids)
tag_ids = tag_ids[:limit]

return res
hierarchy = self.get_tag_hierarchy(tag_ids)
direct_tags = [hierarchy.pop(id) for id in tag_ids]
ancestor_tags = list(hierarchy.values())
ancestor_tags.sort(key=lambda t: sort_key(t.name))
return direct_tags, ancestor_tags

def update_entry_path(self, entry_id: int | Entry, path: Path) -> bool:
"""Set the path field of an entry.
Expand Down
25 changes: 3 additions & 22 deletions src/tagstudio/qt/mixed/tag_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,32 +218,13 @@ def update_tags(self, query: str | None = None):
self.scroll_layout.takeAt(self.scroll_layout.count() - 1).widget().deleteLater()
self.create_button_in_layout = False

# Get results for the search query
query_lower = "" if not query else query.lower()
# Only use the tag limit if it's an actual number (aka not "All Tags")
tag_limit = TagSearchPanel.tag_limit if isinstance(TagSearchPanel.tag_limit, int) else -1
tag_results: list[set[Tag]] = self.lib.search_tags(name=query, limit=tag_limit)
if self.exclude:
tag_results[0] = {t for t in tag_results[0] if t.id not in self.exclude}
tag_results[1] = {t for t in tag_results[1] if t.id not in self.exclude}

# Sort and prioritize the results
results_0 = list(tag_results[0])
results_0.sort(key=lambda tag: tag.name.lower())
results_1 = list(tag_results[1])
results_1.sort(key=lambda tag: tag.name.lower())
raw_results = list(results_0 + results_1)
priority_results: set[Tag] = set()
all_results: list[Tag] = []
direct_tags, ancestor_tags = self.lib.search_tags(name=query, limit=tag_limit)

if query and query.strip():
for tag in raw_results:
if tag.name.lower().startswith(query_lower):
priority_results.add(tag)
all_results = [t for t in direct_tags if t.id not in self.exclude]
all_results.extend(t for t in ancestor_tags if t.id not in self.exclude)

all_results = sorted(list(priority_results), key=lambda tag: len(tag.name)) + [
r for r in raw_results if r not in priority_results
]
if tag_limit > 0:
all_results = all_results[:tag_limit]

Expand Down
8 changes: 4 additions & 4 deletions tests/test_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,10 @@ def test_library_search(library: Library, entry_full: Entry):
def test_tag_search(library: Library):
tag = library.tags[0]

assert library.search_tags(tag.name.lower())
assert library.search_tags(tag.name.upper())
assert library.search_tags(tag.name[2:-2])
assert library.search_tags(tag.name * 2) == [set(), set()]
assert library.search_tags(tag.name.lower())[0]
assert library.search_tags(tag.name.upper())[0]
assert library.search_tags(tag.name[2:-2])[0]
assert library.search_tags(tag.name * 2) == ([], [])


def test_get_entry(library: Library, entry_min: Entry):
Expand Down