diff --git a/swingmusic/api/__init__.py b/swingmusic/api/__init__.py index 162ca2fd..87b4b43b 100644 --- a/swingmusic/api/__init__.py +++ b/swingmusic/api/__init__.py @@ -34,6 +34,8 @@ auth, stream, backup_and_restore, + albumartist, # Added this line + ) # TODO: Move this description to a separate file @@ -113,6 +115,8 @@ def user_lookup_callback(_jwt_header, jwt_data): app.register_api(lyrics.api) app.register_api(backup_and_restore.api) app.register_api(collections.api) + app.register_api(albumartist.api) # Added this line + # Plugins app.register_api(plugins.api) app.register_api(lyrics_plugin.api) diff --git a/swingmusic/api/albumartist.py b/swingmusic/api/albumartist.py new file mode 100644 index 00000000..9acf7fc4 --- /dev/null +++ b/swingmusic/api/albumartist.py @@ -0,0 +1,249 @@ +# swingmusic/api/albumartist.py (New file) +""" +Contains all the album artist(s) routes. +""" + +import math +import random +from datetime import datetime +from itertools import groupby +from typing import Any + +from flask_openapi3 import APIBlueprint, Tag +from pydantic import Field +from pydantic import BaseModel +from swingmusic.api.apischemas import ( + AlbumLimitSchema, + ArtistHashSchema, + ArtistLimitSchema, + TrackLimitSchema, +) + +from swingmusic.config import UserConfig +from swingmusic.db.userdata import SimilarArtistTable +from swingmusic.lib.sortlib import sort_tracks + +from swingmusic.serializers.album import serialize_for_card_many +from swingmusic.serializers.artist import serialize_for_cards, serialize_for_card +from swingmusic.serializers.track import serialize_track + +from swingmusic.store.albums import AlbumStore +from swingmusic.store.albumartists import AlbumArtistStore +from swingmusic.store.tracks import TrackStore +from swingmusic.utils.stats import get_track_group_stats + +bp_tag = Tag(name="Album Artist", description="Single album artist") +api = APIBlueprint("albumartist", __name__, url_prefix="/albumartist", abp_tags=[bp_tag]) + + +class 
class GetAlbumArtistAlbumsQuery(AlbumLimitSchema):
    """Query parameters for the album artist's albums endpoint."""

    all: bool = Field(
        description="Whether to ignore albumlimit and return all albums", default=False
    )


class GetAlbumArtistQuery(TrackLimitSchema, GetAlbumArtistAlbumsQuery):
    """Query parameters for the album artist page (track limit + album limit)."""

    albumlimit: int = Field(7, description="The number of albums to return")


class SearchAlbumArtistsQuery(BaseModel):
    """Query parameters for the album artist name search."""

    query: str = Field(default="", description="Search query for album artist names")
    limit: int = Field(default=50, description="Maximum number of results to return")


def _serialize_track_with_plays(track) -> dict:
    """Serialize a track and attach its play-count caption as ``help_text``."""
    if track.playcount == 0:
        help_text = "unplayed"
    else:
        help_text = f"{track.playcount} play{'' if track.playcount == 1 else 's'}"

    return {**serialize_track(track), "help_text": help_text}


def _albumhash_of(track) -> str:
    """Key function: the hash of the album a track belongs to."""
    return track.albumhash


# NOTE(review): the rule must carry the ``artisthash`` URL variable, because
# every handler below reads ``path.artisthash`` from ArtistHashSchema. The
# static ``/search`` and ``/stats`` rules further down are expected to win over
# this dynamic rule in Werkzeug's matching -- confirm with a request test.
@api.get("/<string:artisthash>")
def get_album_artist(path: ArtistHashSchema, query: GetAlbumArtistQuery):
    """
    Get album artist

    Returns album artist data, tracks and genres for the given artisthash.
    Responds with a 404 error payload when the hash is not in the store.
    """
    artisthash = path.artisthash
    limit = query.limit

    entry = AlbumArtistStore.albumartistmap.get(artisthash)
    if entry is None:
        return {"error": "Album artist not found"}, 404

    tracks = AlbumArtistStore.get_album_artist_tracks(artisthash)
    tracks = sort_tracks(tracks, key="playcount", reverse=True)
    tcount = len(tracks)

    artist = entry.artist
    if artist.albumcount == 0 and tcount < 10:
        # Few loose tracks and no albums: show every track.
        limit = tcount

    # The store uses 0 for "no release date known"; do not turn that into
    # 1970 and a bogus "70s" genre.
    year = 0
    if artist.date:
        try:
            year = datetime.fromtimestamp(artist.date).year
        except (ValueError, OverflowError, OSError):
            # fromtimestamp raises any of these for out-of-range timestamps.
            year = 0

    genres = [*artist.genres]
    if year:
        decade = str(math.floor(year / 10) * 10)[2:] + "s"
        genres.insert(0, {"name": decade, "genrehash": decade})

    # Stats and total duration cover ALL tracks, so compute them before trimming.
    stats = get_track_group_stats(tracks)
    duration = sum(t.duration for t in tracks) if tracks else 0

    if limit and limit != -1:
        tracks = tracks[:limit]
    serialized_tracks = [_serialize_track_with_plays(t) for t in tracks]

    # The albums handler reads ``query.limit``; hand it the album limit.
    query.limit = query.albumlimit
    albums = get_album_artist_albums(path, query)

    return {
        "artist": {
            **serialize_for_card(artist),
            "duration": duration,
            "trackcount": tcount,
            "albumcount": artist.albumcount,
            "genres": genres,
            "is_favorite": artist.is_favorite,
        },
        "tracks": serialized_tracks,
        "albums": albums,
        "stats": stats,
    }


@api.get("/<string:artisthash>/albums")
def get_album_artist_albums(path: ArtistHashSchema, query: GetAlbumArtistAlbumsQuery):
    """
    Get album artist albums.

    Returns the artist's releases split into ``albums``, ``appearances``,
    ``compilations`` and ``singles_and_eps`` (each newest first and trimmed to
    ``limit`` unless ``all`` is set), plus ``artistname``.
    """
    return_all = query.all
    artisthash = path.artisthash
    limit = query.limit

    entry = AlbumArtistStore.albumartistmap.get(artisthash)
    if entry is None:
        return {"error": "Album artist not found"}, 404

    albums = AlbumStore.get_albums_by_hashes(entry.albumhashes)
    tracks = TrackStore.get_tracks_by_trackhashes(entry.trackhashes)

    # Get any missing albums from tracks. The known-hash set is built once
    # instead of once per track.
    known_albumhashes = {a.albumhash for a in albums}
    missing_albumhashes = {
        t.albumhash for t in tracks if t.albumhash not in known_albumhashes
    }

    albums.extend(AlbumStore.get_albums_by_hashes(missing_albumhashes))
    albumdict = {a.albumhash: a for a in albums}

    config = UserConfig()
    # groupby only merges ADJACENT items, so the tracks must be sorted by the
    # grouping key first; otherwise check_type sees partial track lists.
    tracks_by_album = groupby(sorted(tracks, key=_albumhash_of), key=_albumhash_of)
    for albumhash, album_tracks in tracks_by_album:
        album = albumdict.get(albumhash)

        if album:
            album.check_type(list(album_tracks), config.showAlbumsAsSingles)

    all_albums = sorted(albumdict.values(), key=lambda a: a.date, reverse=True)

    res: dict[str, Any] = {
        "albums": [],
        "appearances": [],
        "compilations": [],
        "singles_and_eps": [],
    }

    for album in all_albums:
        if album.type in ("single", "ep"):
            res["singles_and_eps"].append(album)
        elif album.type == "compilation":
            res["compilations"].append(album)
        elif (
            album.albumhash in missing_albumhashes
            or artisthash not in album.artisthashes
        ):
            res["appearances"].append(album)
        else:
            res["albums"].append(album)

    if return_all:
        limit = len(all_albums)

    # Serialize each category, trimmed to the limit.
    for key, value in res.items():
        res[key] = serialize_for_card_many(value[:limit])

    res["artistname"] = entry.artist.name
    return res


@api.get("/<string:artisthash>/tracks")
def get_all_album_artist_tracks(path: ArtistHashSchema):
    """
    Get album artist tracks

    Returns all tracks by a given album artist, most played first. An unknown
    artisthash yields an empty list.
    """
    tracks = AlbumArtistStore.get_album_artist_tracks(path.artisthash)
    tracks = sort_tracks(tracks, key="playcount", reverse=True)

    return [_serialize_track_with_plays(t) for t in tracks]


@api.get("/<string:artisthash>/similar")
def get_similar_album_artists(path: ArtistHashSchema, query: ArtistLimitSchema):
    """
    Get similar album artists.

    Returns up to ``limit`` randomly chosen similar artists that exist in the
    album artist store; a negative limit returns all of them.
    """
    limit = query.limit
    result = SimilarArtistTable.get_by_hash(path.artisthash)

    if result is None:
        return []

    # Only similar artists that are present in the album artist store are kept.
    similar_hashes = result.get_artist_hash_set()
    similar = AlbumArtistStore.get_artists_by_hashes(similar_hashes)

    # A negative limit means "no limit" elsewhere in this module; it must not
    # reach random.sample (ValueError) or a slice (drops the last item).
    if limit is None or limit < 0:
        return serialize_for_cards(similar)

    if len(similar) > limit:
        similar = random.sample(similar, limit)

    return serialize_for_cards(similar[:limit])


@api.get("/search")
def search_album_artists(query: SearchAlbumArtistsQuery):
    """
    Search album artists by name.

    An empty query returns an empty list.
    """
    if not query.query:
        return []

    results = AlbumArtistStore.search_album_artists(query.query, query.limit)
    return serialize_for_cards(results)


@api.get("/stats")
def get_album_artist_stats():
    """
    Get album artist statistics.
    """
    return AlbumArtistStore.get_stats()
+ """ + return AlbumArtistStore.get_stats() \ No newline at end of file diff --git a/swingmusic/api/getall/__init__.py b/swingmusic/api/getall/__init__.py index 2127b49b..f809a284 100644 --- a/swingmusic/api/getall/__init__.py +++ b/swingmusic/api/getall/__init__.py @@ -1,3 +1,4 @@ +# swingmusic/api/getall/__init__.py (Updated) from flask_openapi3 import Tag from flask_openapi3 import APIBlueprint from pydantic import BaseModel, Field @@ -6,6 +7,7 @@ from swingmusic.api.apischemas import GenericLimitSchema from swingmusic.store.albums import AlbumStore from swingmusic.store.artists import ArtistStore +from swingmusic.store.albumartists import AlbumArtistStore from swingmusic.serializers.album import serialize_for_card as serialize_album from swingmusic.serializers.artist import serialize_for_card as serialize_artist @@ -16,6 +18,7 @@ seconds_to_time_string, timestamp_to_time_passed, ) +from swingmusic.utils.article_utils import get_sort_key bp_tag = Tag(name="Get all", description="List all items") api = APIBlueprint("getall", __name__, url_prefix="/getall", abp_tags=[bp_tag]) @@ -42,7 +45,7 @@ class GetAllItemsQuery(GenericLimitSchema): class GetAllItemsPath(BaseModel): itemtype: str = Field( - description="The type of items to return (albums | artists)", + description="The type of items to return (albums | artists | albumartists)", example="albums", default="albums", ) @@ -53,22 +56,27 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery): """ Get all items - Used to show all albums or artists in the library + Used to show all albums, artists, or album artists in the library Sort keys: - - - Both albums and artists: `duration`, `created_date`, `playcount`, `playduration`, `lastplayed`, `trackcount` + Both albums and artists: `duration`, `created_date`, `playcount`, `playduration`, `lastplayed`, `trackcount` Albums only: `title`, `albumartists`, `date` Artists only: `name`, `albumcount` + Album Artists only: `name`, `albumcount` """ is_albums = 
path.itemtype == "albums" is_artists = path.itemtype == "artists" + is_album_artists = path.itemtype == "albumartists" if is_albums: items = AlbumStore.get_flat_list() elif is_artists: items = ArtistStore.get_flat_list() + elif is_album_artists: + items = AlbumArtistStore.get_flat_list() + else: + return {"error": "Invalid item type. Must be 'albums', 'artists', or 'albumartists'"}, 400 total = len(items) @@ -87,17 +95,28 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery): sort_is_date = is_albums and sort == "date" sort_is_artist = is_albums and sort == "albumartists" - sort_is_artist_trackcount = is_artists and sort == "trackcount" - sort_is_artist_albumcount = is_artists and sort == "albumcount" + sort_is_artist_trackcount = (is_artists or is_album_artists) and sort == "trackcount" + sort_is_artist_albumcount = (is_artists or is_album_artists) and sort == "albumcount" + sort_is_artist_name = (is_artists or is_album_artists) and sort == "name" lambda_sort = lambda x: getattr(x, sort) lambda_sort_casefold = lambda x: getattr(x, sort).casefold() + # Special handling for different sort types if sort_is_artist: lambda_sort = lambda x: getattr(x, sort)[0]["name"].casefold() + elif sort_is_artist_name: + # Use article-aware sorting for artist names + lambda_sort = lambda x: get_sort_key(getattr(x, sort)) + lambda_sort_casefold = lambda_sort # Already handles casefolding + # Apply sorting try: - sorted_items = sorted(items, key=lambda_sort_casefold, reverse=reverse) + if sort_is_artist_name: + # Use the article-aware sorting function + sorted_items = sorted(items, key=lambda_sort, reverse=reverse) + else: + sorted_items = sorted(items, key=lambda_sort_casefold, reverse=reverse) except AttributeError: sorted_items = sorted(items, key=lambda_sort, reverse=reverse) @@ -107,46 +126,51 @@ def get_all_items(path: GetAllItemsPath, query: GetAllItemsQuery): for item in items: item_dict = serialize_album(item) if is_albums else serialize_artist(item) - if 
sort_is_date: - item_dict["help_text"] = datetime.fromtimestamp(item.date).year + if is_albums: + item_dict["help_text"] = f"{item.trackcount} track{'' if item.trackcount == 1 else 's'}" + item_dict["time"] = timestamp_to_time_passed(item.created_date) + else: # artists or album artists + tracks_text = f"{item.trackcount} track{'' if item.trackcount == 1 else 's'}" + albums_text = f"{item.albumcount} album{'' if item.albumcount == 1 else 's'}" + item_dict["help_text"] = f"{albums_text} • {tracks_text}" + item_dict["time"] = timestamp_to_time_passed(item.created_date) - if sort_is_create_date: - date = create_new_date(datetime.fromtimestamp(item.created_date)) - timeago = date_string_to_time_passed(date) - item_dict["help_text"] = timeago - - if sort_is_count: - item_dict["help_text"] = ( - f"{format_number(item.trackcount)} track{'' if item.trackcount == 1 else 's'}" - ) - - if sort_is_duration: - item_dict["help_text"] = seconds_to_time_string(item.duration) - - if sort_is_artist_trackcount: - item_dict["help_text"] = ( - f"{format_number(item.trackcount)} track{'' if item.trackcount == 1 else 's'}" - ) - - if sort_is_artist_albumcount: - item_dict["help_text"] = ( - f"{format_number(item.albumcount)} album{'' if item.albumcount == 1 else 's'}" - ) - - if sort_is_playcount: - item_dict["help_text"] = ( - f"{format_number(item.playcount)} play{'' if item.playcount == 1 else 's'}" - ) + album_list.append(item_dict) - if sort_is_lastplayed: - if item.playduration == 0: - item_dict["help_text"] = "Never played" - else: - item_dict["help_text"] = timestamp_to_time_passed(item.lastplayed) + # Calculate pagination info + has_more = (start + limit) < total + next_start = start + limit if has_more else None - if sort_is_playduration: - item_dict["help_text"] = seconds_to_time_string(item.playduration) + return { + "items": album_list, + "total": total, + "start": start, + "limit": limit, + "has_more": has_more, + "next_start": next_start, + } - album_list.append(item_dict) 
- return {"items": album_list, "total": total} +@api.get("/stats") +def get_library_stats(): + """ + Get library statistics + + Returns counts for albums, artists, album artists, and tracks + """ + albums = AlbumStore.get_flat_list() + artists = ArtistStore.get_flat_list() + album_artists = AlbumArtistStore.get_flat_list() + + # Calculate total tracks + total_tracks = sum(album.trackcount for album in albums) + total_duration = sum(getattr(album, 'duration', 0) for album in albums) + + return { + "albums": len(albums), + "artists": len(artists), + "album_artists": len(album_artists), + "tracks": total_tracks, + "total_duration": total_duration, + "duration_formatted": seconds_to_time_string(total_duration) if total_duration else "0:00" + } \ No newline at end of file diff --git a/swingmusic/lib/index.py b/swingmusic/lib/index.py index 28c0a59f..5f7e7697 100644 --- a/swingmusic/lib/index.py +++ b/swingmusic/lib/index.py @@ -1,3 +1,4 @@ +# swingmusic/lib/index.py (Updated) import gc from time import time from swingmusic.lib.mapstuff import ( @@ -11,6 +12,7 @@ from swingmusic.lib.tagger import IndexTracks from swingmusic.store.albums import AlbumStore from swingmusic.store.artists import ArtistStore +from swingmusic.store.albumartists import AlbumArtistStore # New import from swingmusic.store.folder import FolderStore from swingmusic.store.tracks import TrackStore from swingmusic.utils.threading import background @@ -24,6 +26,7 @@ def __init__(self) -> None: TrackStore.load_all_tracks(key) AlbumStore.load_albums(key) ArtistStore.load_artists(key) + AlbumArtistStore.load_album_artists(key) # Load album artists FolderStore.load_filepaths() # NOTE: Rebuild recently added items on the homepage store @@ -42,4 +45,4 @@ def __init__(self) -> None: @background def index_everything(): - return IndexEverything() + return IndexEverything() \ No newline at end of file diff --git a/swingmusic/store/albumartists.py b/swingmusic/store/albumartists.py new file mode 100644 index 
# swingmusic/store/albumartists.py
"""
Store for managing album artists (different from track artists).
Album artists represent the main artist(s) for an album, not individual track artists.
"""

from typing import Iterable
from collections import defaultdict

from swingmusic.models import Artist
from swingmusic.utils.auth import get_current_userid
from swingmusic.utils.hashing import create_hash
from swingmusic.store.tracks import TrackStore

# Key of the most recently started load; an older load stops as soon as it
# notices that a newer one has replaced this value.
ALBUMARTIST_LOAD_KEY = ""


def _new_aggregate() -> dict:
    """Return an empty per-artist accumulator used while scanning tracks."""
    return {
        "name": "",
        "albumhashes": set(),
        "trackhashes": set(),
        "genres": [],
        "playcount": 0,
        "playduration": 0,
        "lastplayed": 0,
        "date": 0,
        "duration": 0,
    }


def _unique_genres(genres: Iterable) -> list[dict]:
    """
    Normalize genres to ``{"name", "genrehash"}`` dicts and drop duplicates.

    Plain-string genres are hashed BEFORE the duplicate check, so repeated
    strings collapse into one entry. Dict genres without a ``genrehash`` get
    one derived from their name; entries with neither are skipped.
    First-seen order is preserved.
    """
    unique: list[dict] = []
    seen: set[str] = set()

    for genre in genres:
        if isinstance(genre, str):
            genre = {"name": genre, "genrehash": create_hash(genre)}
        elif isinstance(genre, dict):
            if not genre.get("genrehash"):
                name = genre.get("name")
                if not name:
                    continue
                genre = {**genre, "genrehash": create_hash(name)}
        else:
            continue

        genrehash = genre["genrehash"]
        if genrehash in seen:
            continue

        seen.add(genrehash)
        unique.append(genre)

    return unique


class AlbumArtistMapEntry:
    """An album artist together with the hashes of its albums and tracks."""

    def __init__(
        self, artist: Artist, albumhashes: set[str], trackhashes: set[str]
    ) -> None:
        self.artist = artist
        self.albumhashes: set[str] = albumhashes
        self.trackhashes: set[str] = trackhashes

    def increment_playcount(self, duration: int, timestamp: int, playcount: int = 1):
        """Record a play: update last-played time, play duration and play count."""
        self.artist.lastplayed = timestamp
        self.artist.playduration += duration
        self.artist.playcount += playcount

    def toggle_favorite_user(self, userid: int | None = None):
        """Toggle the favorite flag for ``userid`` (defaults to the current user)."""
        if userid is None:
            userid = get_current_userid()

        self.artist.toggle_favorite_user(userid)

    def set_color(self, color: str):
        """Set the artist's color."""
        self.artist.color = color


class AlbumArtistStore:
    """In-memory index of album artists, keyed by artisthash."""

    albumartistmap: dict[str, AlbumArtistMapEntry] = {}

    @classmethod
    def load_album_artists(cls, instance_key: str):
        """
        Loads all album artists from track data into the store.

        The new map is built on the side and swapped in at the end, so
        concurrent readers keep seeing the previous data while a reload runs.
        The load stops without touching the store if a newer load has started
        (``ALBUMARTIST_LOAD_KEY`` no longer matches ``instance_key``).
        """
        global ALBUMARTIST_LOAD_KEY
        ALBUMARTIST_LOAD_KEY = instance_key

        print("Loading album artists... ", end="")

        aggregates: dict[str, dict] = defaultdict(_new_aggregate)

        # Pass 1: fold every track into the accumulator of each of its album artists.
        for track in TrackStore.get_flat_list():
            if instance_key != ALBUMARTIST_LOAD_KEY:
                return

            for albumartist in track.albumartists:
                data = aggregates[albumartist["artisthash"]]
                if not data["name"]:
                    data["name"] = albumartist["name"]

                data["albumhashes"].add(track.albumhash)

                # Count a track's length once per artist, even if the same
                # trackhash shows up more than once in the flat list.
                if track.trackhash not in data["trackhashes"]:
                    data["trackhashes"].add(track.trackhash)
                    data["duration"] += track.duration

                if track.genres:
                    data["genres"].extend(track.genres)

                data["playcount"] += track.playcount
                data["playduration"] += track.playduration

                if track.lastplayed > data["lastplayed"]:
                    data["lastplayed"] = track.lastplayed

                # Keep the earliest positive date; 0 means "no date known yet".
                if data["date"] == 0 or (0 < track.date < data["date"]):
                    data["date"] = track.date

        # Pass 2: turn the accumulators into Artist objects.
        new_map: dict[str, AlbumArtistMapEntry] = {}
        for artisthash, data in aggregates.items():
            if instance_key != ALBUMARTIST_LOAD_KEY:
                return

            genres = _unique_genres(data["genres"])

            artist = Artist(
                artisthash=artisthash,
                name=data["name"],
                albumcount=len(data["albumhashes"]),
                trackcount=len(data["trackhashes"]),
                playcount=data["playcount"],
                playduration=data["playduration"],
                lastplayed=data["lastplayed"],
                date=data["date"],
                genres=genres,
                genrehashes=" ".join(g.get("genrehash", "") for g in genres),
                image="",  # Will be populated later if needed
                color="",  # Will be populated later if needed
                duration=data["duration"],
                # NOTE(review): this is the earliest release date, not the time
                # the artist was added to the library -- confirm the intended source.
                created_date=data["date"],
            )

            new_map[artisthash] = AlbumArtistMapEntry(
                artist=artist,
                albumhashes=data["albumhashes"],
                trackhashes=data["trackhashes"],
            )

        cls.albumartistmap = new_map
        print("Done!")

    @classmethod
    def get_flat_list(cls):
        """
        Returns a flat list of all album artists.
        """
        return [entry.artist for entry in cls.albumartistmap.values()]

    @classmethod
    def get_artist_by_hash(cls, artisthash: str) -> Artist | None:
        """
        Returns an album artist by its hash, or None when it is unknown.
        """
        entry = cls.albumartistmap.get(artisthash)
        return entry.artist if entry else None

    @classmethod
    def get_artists_by_hashes(cls, artisthashes: Iterable[str]) -> list[Artist]:
        """
        Returns album artists by their hashes; unknown hashes are skipped.
        """
        artistmap = cls.albumartistmap
        return [
            artistmap[artisthash].artist
            for artisthash in artisthashes
            if artisthash in artistmap
        ]

    @classmethod
    def get_album_artist_tracks(cls, artisthash: str):
        """
        Returns all tracks for a given album artist ([] when it is unknown).
        """
        entry = cls.albumartistmap.get(artisthash)
        if entry is None:
            return []

        return TrackStore.get_tracks_by_trackhashes(entry.trackhashes)

    @classmethod
    def get_albums_by_artisthash(cls, artisthash: str):
        """
        Returns all albums for a given album artist ([] when it is unknown).
        """
        # Imported here rather than at module level, as in the original code.
        from swingmusic.store.albums import AlbumStore

        entry = cls.albumartistmap.get(artisthash)
        if entry is None:
            return []

        return AlbumStore.get_albums_by_hashes(entry.albumhashes)

    @classmethod
    def search_album_artists(cls, query: str, limit: int = 50) -> list[Artist]:
        """
        Search album artists by name.

        Case-insensitive substring match; results are sorted by name and
        trimmed to ``limit``.
        """
        # casefold() rather than lower(), so non-ASCII names compare correctly.
        needle = query.casefold()
        results = [
            entry.artist
            for entry in cls.albumartistmap.values()
            if needle in entry.artist.name.casefold()
        ]

        results.sort(key=lambda artist: artist.name.casefold())
        return results[:limit]

    @classmethod
    def get_stats(cls) -> dict:
        """
        Get statistics about album artists.
        """
        entries = list(cls.albumartistmap.values())
        return {
            "total_album_artists": len(entries),
            "artists_with_albums": sum(1 for entry in entries if entry.albumhashes),
            "total_albums": sum(len(entry.albumhashes) for entry in entries),
            "total_tracks": sum(len(entry.trackhashes) for entry in entries),
        }
"""
Utility functions for handling articles in artist and album names for sorting purposes.
"""

import re
from typing import List

# Common articles in various languages. Some entries repeat across languages
# ("la", "un", "a", ...); duplicates are dropped when the pattern is built.
ARTICLES: List[str] = [
    # English
    "the", "a", "an",
    # Spanish
    "el", "la", "los", "las", "un", "una", "unos", "unas",
    # French
    "le", "la", "les", "un", "une", "des",
    # German
    "der", "die", "das", "ein", "eine", "einen", "einem", "einer",
    # Italian
    "il", "lo", "la", "gli", "le", "un", "uno", "una",
    # Portuguese
    "o", "a", "os", "as", "um", "uma", "uns", "umas",
    # Dutch
    "de", "het", "een",
]

# One leading article followed by whitespace or a hyphen, case-insensitive.
# Compiled once at import time: the sort-key helpers run once per sorted item,
# so rebuilding the pattern on every call is wasted work.
_LEADING_ARTICLE_RE = re.compile(
    r"^(?:"
    + "|".join(re.escape(article) for article in dict.fromkeys(ARTICLES))
    + r")(?:\s+|-)",
    re.IGNORECASE,
)


def remove_articles_for_sorting(name: str) -> str:
    """
    Removes one leading article from a name for sorting purposes.

    Examples:
        "The Beatles" -> "Beatles"
        "A Perfect Circle" -> "Perfect Circle"
        "The B-52's" -> "B-52's"
        "Los Angeles" -> "Angeles"  # if it's actually an artist name
        "The" -> "The"              # nothing would be left, so it is kept

    Args:
        name: The artist or album name

    Returns:
        The name with surrounding whitespace and one leading article removed.
        Falsy or non-string input is returned unchanged; a whitespace-only
        string yields "". If stripping the article would leave nothing, the
        (stripped) name is returned instead.
    """
    if not name or not isinstance(name, str):
        return name

    stripped = name.strip()
    if not stripped:
        return ""

    remainder = _LEADING_ARTICLE_RE.sub("", stripped, count=1).strip()
    return remainder or stripped


def get_sort_key(name: str) -> str:
    """
    Generate a sort key for a name by removing articles and case-folding.

    Args:
        name: The name to generate a sort key for

    Returns:
        A normalized sort key; "" for empty, None or non-string input.
    """
    if not name or not isinstance(name, str):
        return ""

    # casefold() rather than lower(): it also normalizes characters such as
    # "ß", giving a consistent case-insensitive order for non-ASCII names.
    return remove_articles_for_sorting(name).casefold()


def get_artist_sort_key(artist_data) -> str:
    """
    Get sort key for artist data (handles both dict and object formats).

    Args:
        artist_data: Either an Artist object with .name attribute or dict with 'name' key

    Returns:
        Sort key for the artist, or "" when no name can be found
    """
    if hasattr(artist_data, 'name'):
        return get_sort_key(artist_data.name)

    if isinstance(artist_data, dict) and 'name' in artist_data:
        return get_sort_key(artist_data['name'])

    return ""


# Test cases for validation
if __name__ == "__main__":
    test_cases = [
        ("The Beatles", "Beatles"),
        ("The B-52's", "B-52's"),
        ("A Perfect Circle", "Perfect Circle"),
        ("An Officer and a Gentleman", "Officer and a Gentleman"),
        ("Beatles", "Beatles"),  # No change
        ("", ""),  # Empty string
        ("The", "The"),  # Just an article - should return original
        ("Los Tigres del Norte", "Tigres del Norte"),
        ("Le Tigre", "Tigre"),
        ("Der Eisendrache", "Eisendrache"),
    ]

    print("Testing article removal:")
    for input_name, expected in test_cases:
        result = remove_articles_for_sorting(input_name)
        status = "✓" if result == expected else "✗"
        print(f"{status} '{input_name}' -> '{result}' (expected: '{expected}')")
# test_article_sorting.py
"""
Tests for the article handling used when sorting artist names.

Runnable under pytest, or directly as a script to print a short report.
"""

import sys
import os

# Add the project root to the path so we can import our modules
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))

from swingmusic.utils.article_utils import remove_articles_for_sorting, get_sort_key


# (input, expected_output) pairs for plain article removal
ARTICLE_REMOVAL_CASES = [
    ("The Beatles", "Beatles"),
    ("The B-52's", "B-52's"),
    ("A Perfect Circle", "Perfect Circle"),
    ("An Officer and a Gentleman", "Officer and a Gentleman"),
    ("Beatles", "Beatles"),  # No change
    ("", ""),  # Empty string
    ("The", "The"),  # Just an article - should return original
    ("Los Tigres del Norte", "Tigres del Norte"),
    ("Le Tigre", "Tigre"),
    ("Der Eisendrache", "Eisendrache"),
    ("AC/DC", "AC/DC"),  # No articles
    ("The Who", "Who"),
    ("A-ha", "ha"),
    ("The 1975", "1975"),
    ("U2", "U2"),  # No change
    ("R.E.M.", "R.E.M."),  # No change
    ("The Smashing Pumpkins", "Smashing Pumpkins"),
    ("Sublime", "Sublime"),  # No change
]

# (input, expected_output) pairs for unusual names
EDGE_CASES = [
    ("", ""),  # Empty string
    ("The", "The"),  # Just an article
    ("A", "A"),  # Single letter that's an article
    ("The The", "The"),  # Band name "The The"
    ("Los Los", "Los"),  # Repeated articles
    ("El Niño", "Niño"),  # Spanish article
    ("Die Antwoord", "Antwoord"),  # German article
    ("Le Loup", "Loup"),  # French article
]

ARTIST_NAMES = [
    "Soundtrack",
    "spoken intro",
    "Stephen Lynch",
    "Stephen Trask",
    "Steve Howe",
    "Sublime",
    "Sublime featuring Mad Lion",
    "Swedish House Mafia",
    "System of a Down",
    "Talking Heads",
    "The Alan Parsons Project",
    "The Avett Brothers",
    "The B-52's",
    "The Cars",
    "The Beatles",
    "The Who",
    "The Smashing Pumpkins",
    "A Perfect Circle",
    "An Officer and a Gentleman",
    "Beatles",  # Shares its sort key with "The Beatles"
]


def _mismatches(cases):
    """Return ``(input, actual, expected)`` for every case that does not match."""
    return [
        (raw, remove_articles_for_sorting(raw), expected)
        for raw, expected in cases
        if remove_articles_for_sorting(raw) != expected
    ]


def test_article_removal():
    """Leading articles are stripped; names without one are left alone."""
    failures = _mismatches(ARTICLE_REMOVAL_CASES)
    assert not failures, f"unexpected results (input, actual, expected): {failures}"


def test_sorting_behavior():
    """Article-aware sort keys file names under the word after the article."""
    sorted_artists = sorted(ARTIST_NAMES, key=get_sort_key)
    position = {name: index for index, name in enumerate(sorted_artists)}

    # Nothing is lost or duplicated by sorting.
    assert sorted(sorted_artists) == sorted(ARTIST_NAMES)

    # "The ..." names are filed under the following word, not under "T".
    assert sorted_artists[0] == "The Alan Parsons Project"
    assert position["The B-52's"] < position["The Beatles"] < position["The Cars"]
    assert position["The B-52's"] < position["Sublime"]
    assert position["Talking Heads"] < position["The Who"]

    # Names that differ only by their article share a key and end up adjacent.
    assert get_sort_key("The Beatles") == get_sort_key("Beatles")
    assert abs(position["The Beatles"] - position["Beatles"]) == 1


def test_edge_cases():
    """Empty, article-only and repeated-article names are handled."""
    failures = _mismatches(EDGE_CASES)
    assert not failures, f"unexpected results (input, actual, expected): {failures}"

    # A missing name sorts as the empty string.
    assert get_sort_key(None) == ""
    assert get_sort_key("") == ""

    # Whitespace-only names carry no visible content after processing.
    assert remove_articles_for_sorting("   ").strip() == ""


if __name__ == "__main__":
    print("Article Sorting Test Suite")
    print("=" * 50)

    failed = False
    for test in (test_article_removal, test_sorting_behavior, test_edge_cases):
        try:
            test()
        except AssertionError as exc:
            failed = True
            print(f"✗ {test.__name__}: {exc}")
        else:
            print(f"✓ {test.__name__}")

    if failed:
        print("\n❌ Some tests failed. Please check the implementation.")
    else:
        print("\n🎉 All tests passed! Article-aware sorting is working correctly.")

    sys.exit(1 if failed else 0)