diff --git a/CHANGELOG.md b/CHANGELOG.md index 7be4e6fa..3778ff93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,10 @@ All versions prior to 0.0.9 are untracked. ### Fixed +* Fixed bug with the `--fix` flag where new requirements were sometimes being + appended to requirement files instead of patching the existing requirement + ([#577](https://github.com/pypa/pip-audit/pull/577)) + * Fixed a crash caused by auditing requirements files that refer to other requirements files ([#568](https://github.com/pypa/pip-audit/pull/568)) diff --git a/pip_audit/_dependency_source/requirement.py b/pip_audit/_dependency_source/requirement.py index 63d0b1c5..33c9993f 100644 --- a/pip_audit/_dependency_source/requirement.py +++ b/pip_audit/_dependency_source/requirement.py @@ -14,6 +14,7 @@ from typing import IO, Iterator from packaging.specifiers import SpecifierSet +from packaging.utils import canonicalize_name from pip_requirements_parser import InstallRequirement, InvalidRequirementLine, RequirementsFile from pip_audit._dependency_source import DependencyFixError, DependencySource, DependencySourceError @@ -203,7 +204,10 @@ def _fix_file(self, filename: Path, fix_version: ResolvedFixVersion) -> None: with filename.open("w") as f: found = False for req in reqs: - if isinstance(req, InstallRequirement) and req.name == fix_version.dep.name: + if ( + isinstance(req, InstallRequirement) + and canonicalize_name(req.name) == fix_version.dep.canonical_name + ): found = True if req.specifier.contains( fix_version.dep.version diff --git a/test/dependency_source/test_requirement.py b/test/dependency_source/test_requirement.py index 1eeee699..a5ac69ad 100644 --- a/test/dependency_source/test_requirement.py +++ b/test/dependency_source/test_requirement.py @@ -198,6 +198,52 @@ def test_requirement_source_fix(req_file): ) +def test_requirement_source_fix_roundtrip(req_file): + req_path = req_file() + with open(req_path, "w") as f: + f.write("flask==0.5") + + source = requirement.RequirementSource([req_path]) + specs = list(source.collect()) + + flask_dep: ResolvedDependency | None = None + for spec in specs: + if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask": + flask_dep = spec + break + assert flask_dep is not None + assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5")) + + flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0")) + source.fix(flask_fix) + + with open(req_path) as f: + assert f.read().strip() == "flask==1.0" + + +def test_requirement_source_fix_roundtrip_non_canonical_name(req_file): + req_path = req_file() + with open(req_path, "w") as f: + f.write("Flask==0.5") + + source = requirement.RequirementSource([req_path]) + specs = list(source.collect()) + + flask_dep: ResolvedDependency | None = None + for spec in specs: + if isinstance(spec, ResolvedDependency) and spec.canonical_name == "flask": + flask_dep = spec + break + assert flask_dep is not None + assert flask_dep == ResolvedDependency(name="Flask", version=Version("0.5")) + + flask_fix = ResolvedFixVersion(dep=flask_dep, version=Version("1.0")) + source.fix(flask_fix) + + with open(req_path) as f: + assert f.read().strip() == "Flask==1.0" + + def test_requirement_source_fix_multiple_files(req_file): _check_fixes( ["flask==0.5", "requests==2.0\nflask==0.5"],