Skip to content

Commit

Permalink
Replace munkres with lapjv for track assignment (#5564)
Browse files Browse the repository at this point in the history
Fixes #5207.

This PR replaces the `munkres` library with `lap` (Linear Assignment
Problem solver) for computing optimal track assignments during the
auto-tagging process. The main changes are:

- Remove `munkres` dependency and add `lap` and `numpy` dependencies
- Refactor the track assignment code to use `lap.lapjv()` instead of
`Munkres().compute()`
- Simplify cost matrix construction using list comprehension
- Move config value reading out of the `track_distance` function into
cached helper functions to improve performance

The motivation for this change comes from benchmark comparisons showing
that LAPJV (implemented in the `lap` library) significantly outperforms
the Munkres/Hungarian algorithm for the linear assignment problem. See
detailed benchmarks at: https://github.com/berhane/LAP-solvers

The change should provide better performance for track matching,
especially with larger albums, while maintaining the same assignment
results.

## Testing Notes
- All existing tests pass without modification
- Track assignments produce identical results 
- No behavioral changes in auto-tagging
  • Loading branch information
snejus authored Dec 27, 2024
2 parents 2277e2a + 4c8d75f commit faf7529
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 39 deletions.
52 changes: 27 additions & 25 deletions beets/autotag/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import re
from collections.abc import Iterable, Sequence
from enum import IntEnum
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, Union, cast
from functools import cache
from typing import TYPE_CHECKING, Any, NamedTuple, TypeVar, cast

from munkres import Munkres
import lap
import numpy as np

from beets import config, logging, plugins
from beets.autotag import (
Expand Down Expand Up @@ -126,21 +128,15 @@ def assign_items(
of objects of the two types.
"""
# Construct the cost matrix.
costs: list[list[Distance]] = []
for item in items:
row = []
for track in tracks:
row.append(track_distance(item, track))
costs.append(row)

costs = [[float(track_distance(i, t)) for t in tracks] for i in items]
# Find a minimum-cost bipartite matching.
log.debug("Computing track assignment...")
matching = Munkres().compute(costs)
cost, _, assigned_idxs = lap.lapjv(np.array(costs), extend_cost=True)
log.debug("...done.")

# Produce the output matching.
mapping = {items[i]: tracks[j] for (i, j) in matching}
extra_items = list(set(items) - set(mapping.keys()))
mapping = {items[i]: tracks[t] for (t, i) in enumerate(assigned_idxs)}
extra_items = list(set(items) - mapping.keys())
extra_items.sort(key=lambda i: (i.disc, i.track, i.title))
extra_tracks = list(set(tracks) - set(mapping.values()))
extra_tracks.sort(key=lambda t: (t.index, t.title))
Expand All @@ -154,6 +150,18 @@ def track_index_changed(item: Item, track_info: TrackInfo) -> bool:
return item.track not in (track_info.medium_index, track_info.index)


@cache
def get_track_length_grace() -> float:
    """Return the ``match.track_length_grace`` option as a number.

    The result is memoized with ``functools.cache``, so the configuration
    is only consulted on the first call; later changes to the option are
    not picked up unless the cache is cleared.
    """
    grace_option = config["match"]["track_length_grace"]
    return grace_option.as_number()


@cache
def get_track_length_max() -> float:
    """Return the ``match.track_length_max`` option as a number.

    The result is memoized with ``functools.cache``, so the configuration
    is only consulted on the first call; later changes to the option are
    not picked up unless the cache is cleared.
    """
    max_option = config["match"]["track_length_max"]
    return max_option.as_number()


def track_distance(
item: Item,
track_info: TrackInfo,
Expand All @@ -162,23 +170,17 @@ def track_distance(
"""Determines the significance of a track metadata change. Returns a
Distance object. `incl_artist` indicates that a distance component should
be included for the track artist (i.e., for various-artist releases).
``track_length_grace`` and ``track_length_max`` configuration options are
cached because this function is called many times during the matching
process and their access comes with a performance overhead.
"""
dist = hooks.Distance()

# Length.
if track_info.length:
item_length = cast(float, item.length)
track_length_grace = cast(
Union[float, int],
config["match"]["track_length_grace"].as_number(),
)
track_length_max = cast(
Union[float, int],
config["match"]["track_length_max"].as_number(),
)

diff = abs(item_length - track_info.length) - track_length_grace
dist.add_ratio("track_length", diff, track_length_max)
if info_length := track_info.length:
diff = abs(item.length - info_length) - get_track_length_grace()
dist.add_ratio("track_length", diff, get_track_length_max())

# Title.
dist.add_string("track_title", item.title, track_info.title)
Expand Down
3 changes: 3 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ Bug fixes:
:bug:`5265`
:bug:`5371`
:bug:`4715`
* :ref:`import-cmd`: Fix ``MemoryError`` and improve performance when tagging
large albums by replacing the ``munkres`` library with ``lap.lapjv``.
:bug:`5207`

For packagers:

Expand Down
81 changes: 68 additions & 13 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,10 @@ python = ">=3.9,<4"
colorama = { version = "*", markers = "sys_platform == 'win32'" }
confuse = ">=1.5.0"
jellyfish = "*"
lap = ">=0.5.12"
mediafile = ">=0.12.0"
munkres = ">=1.0.0"
musicbrainzngs = ">=0.4"
numpy = ">=1.24.4"
platformdirs = ">=3.5.0"
pyyaml = "*"
typing_extensions = { version = "*", python = "<=3.10" }
Expand Down

0 comments on commit faf7529

Please sign in to comment.