From 16d606d03e415803dadcced92e41057c75b3da1f Mon Sep 17 00:00:00 2001 From: Damian Shaw Date: Fri, 9 Aug 2024 04:12:13 -0400 Subject: [PATCH] feat: Allow provider to filter unsatisfied names, when backtracking (#145) * Allow provider to narrow backtrack selection * formatting * Throw specific error if narrowed_unstatisfied_names is empty * Increase mccabe complexity * Update docs * Add functional tests for narrow_requirement_selection * Add news entry * update docs of `get_preference` --- news/145.feature | 3 + pyproject.toml | 3 + src/resolvelib/providers.py | 59 +++++++++++++++++++ src/resolvelib/resolvers/resolution.py | 32 +++++++++- .../python/test_resolvers_python.py | 42 +++++++++++-- 5 files changed, 131 insertions(+), 8 deletions(-) create mode 100644 news/145.feature diff --git a/news/145.feature b/news/145.feature new file mode 100644 index 0000000..65dcd9e --- /dev/null +++ b/news/145.feature @@ -0,0 +1,3 @@ +New `narrow_requirement_selection` provider method giving option for +providers to reduce the number of times sort key `get_preference` is +called in long running backtrack diff --git a/pyproject.toml b/pyproject.toml index be55ff5..3622199 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,9 @@ exclude = [ "*.pyi" ] +[tool.ruff.lint.mccabe] +max-complexity = 12 + [tool.mypy] warn_unused_configs = true diff --git a/src/resolvelib/providers.py b/src/resolvelib/providers.py index 6d8bc47..524e3d8 100644 --- a/src/resolvelib/providers.py +++ b/src/resolvelib/providers.py @@ -40,6 +40,13 @@ def get_preference( ) -> Preference: """Produce a sort key for given requirement based on preference. + As this is a sort key it will be called O(n) times per backtrack + step, where n is the number of `identifier`s. If you have a check + which is expensive in some sense, e.g. 
it needs to make O(n) checks + per call or takes significant wall clock time, consider using + `narrow_requirement_selection` to filter the `identifier`s, which + is applied before this sort key is called. + The preference is defined as "I think this requirement should be resolved first". The lower the return value is, the more preferred this group of arguments is. @@ -135,3 +142,55 @@ def get_dependencies(self, candidate: CT) -> Iterable[RT]: specifies as its dependencies. """ raise NotImplementedError + + def narrow_requirement_selection( + self, + identifiers: Iterable[KT], + resolutions: Mapping[KT, CT], + candidates: Mapping[KT, Iterator[CT]], + information: Mapping[KT, Iterator[RequirementInformation[RT, CT]]], + backtrack_causes: Sequence[RequirementInformation[RT, CT]], + ) -> Iterable[KT]: + """ + An optional method to narrow the selection of requirements being + considered during resolution. This method is called O(1) time per + backtrack step. + + :param identifiers: An iterable of `identifiers` as returned by + ``identify()``. These identify all requirements currently being + considered. + :param resolutions: A mapping of candidates currently pinned by the + resolver. Each key is an identifier, and the value is a candidate + that may conflict with requirements from ``information``. + :param candidates: A mapping of each dependency's possible candidates. + Each value is an iterator of candidates. + :param information: A mapping of requirement information for each package. + Each value is an iterator of *requirement information*. + :param backtrack_causes: A sequence of *requirement information* that are + the requirements causing the resolver to most recently + backtrack. + + A *requirement information* instance is a named tuple with two members: + + * ``requirement`` specifies a requirement contributing to the current + list of candidates. 
+ * ``parent`` specifies the candidate that provides (is depended on for) + the requirement, or ``None`` to indicate a root requirement. + + Must return a non-empty subset of `identifiers`, with the default + implementation being to return `identifiers` unchanged. Those `identifiers` + will then be passed to the sort key `get_preference` to pick the most + preferred requirement to attempt to pin, unless `narrow_requirement_selection` + returns only 1 requirement, in which case that will be used without + calling the sort key `get_preference`. + + This method is designed to be used by the provider to optimize the + dependency resolution, e.g. if a check cost is O(m) and it can be done + against all identifiers at once then filtering the requirement selection + here will cost O(m) but making it part of the sort key in `get_preference` + will cost O(m*n), where n is the number of `identifiers`. + + Returns: + Iterable[KT]: A non-empty subset of `identifiers`. + """ + return identifiers diff --git a/src/resolvelib/resolvers/resolution.py b/src/resolvelib/resolvers/resolution.py index 6c0bf50..da3c66e 100644 --- a/src/resolvelib/resolvers/resolution.py +++ b/src/resolvelib/resolvers/resolution.py @@ -411,8 +411,36 @@ def resolve(self, requirements: Iterable[RT], max_rounds: int) -> State[RT, CT, # keep track of satisfied names to calculate diff after pinning satisfied_names = set(self.state.criteria.keys()) - set(unsatisfied_names) - # Choose the most preferred unpinned criterion to try. 
- name = min(unsatisfied_names, key=self._get_preference) + if len(unsatisfied_names) > 1: + narrowed_unstatisfied_names = list( + self._p.narrow_requirement_selection( + identifiers=unsatisfied_names, + resolutions=self.state.mapping, + candidates=IteratorMapping( + self.state.criteria, + operator.attrgetter("candidates"), + ), + information=IteratorMapping( + self.state.criteria, + operator.attrgetter("information"), + ), + backtrack_causes=self.state.backtrack_causes, + ) + ) + else: + narrowed_unstatisfied_names = unsatisfied_names + + # The provider must return at least one name; an empty selection is an error + if not narrowed_unstatisfied_names: + raise RuntimeError("narrow_requirement_selection returned 0 names") + + # If there is only 1 unsatisfied name skip calling self._get_preference + if len(narrowed_unstatisfied_names) > 1: + # Choose the most preferred unpinned criterion to try. + name = min(narrowed_unstatisfied_names, key=self._get_preference) + else: + name = narrowed_unstatisfied_names[0] + failure_criterion = self._attempt_to_pin_criterion(name) if failure_criterion: diff --git a/tests/functional/python/test_resolvers_python.py b/tests/functional/python/test_resolvers_python.py index 18c1550..c1e3038 100644 --- a/tests/functional/python/test_resolvers_python.py +++ b/tests/functional/python/test_resolvers_python.py @@ -121,6 +121,24 @@ def get_dependencies(self, candidate): return list(self._iter_dependencies(candidate)) +class PythonInputProviderNarrowRequirements(PythonInputProvider): + def narrow_requirement_selection( + self, identifiers, resolutions, candidates, information, backtrack_causes + ): + # Consider requirements that have 0 candidates (a resolution end point + # that can be backtracked from) or 1 candidate (speeds up situations where + # every requirement is pinned to 1 specific version) + number_of_candidates = defaultdict(list) + for identifier in identifiers: + number_of_candidates[len(list(candidates[identifier]))].append(identifier) + + 
min_candidates = min(number_of_candidates.keys()) + if min_candidates in (0, 1): + return number_of_candidates[min_candidates] + + return identifiers + + INPUTS_DIR = os.path.abspath(os.path.join(__file__, "..", "inputs")) CASE_DIR = os.path.join(INPUTS_DIR, "case") @@ -133,20 +151,32 @@ def get_dependencies(self, candidate): } -@pytest.fixture( - params=[ +def create_params(provider_class): + return [ pytest.param( - os.path.join(CASE_DIR, n), + (os.path.join(CASE_DIR, n), provider_class), marks=pytest.mark.xfail(strict=True, reason=XFAIL_CASES[n]), ) if n in XFAIL_CASES - else os.path.join(CASE_DIR, n) + else (os.path.join(CASE_DIR, n), provider_class) + for n in CASE_NAMES + ] + + +@pytest.fixture( + params=[ + *create_params(PythonInputProvider), + *create_params(PythonInputProviderNarrowRequirements), + ], + ids=[ + f"{n[:-5]}-{cls.__name__}" + for cls in [PythonInputProvider, PythonInputProviderNarrowRequirements] for n in CASE_NAMES ], - ids=[n[:-5] for n in CASE_NAMES], ) def provider(request): - return PythonInputProvider(request.param) + path, provider_class = request.param + return provider_class(path) def _format_confliction(exception):