Commit d050205

refactor: extract repodata retrieval || add: package matching on more complicated patterns

YYYasin19 committed Jun 29, 2023
1 parent da14c25 commit d050205
Showing 5 changed files with 138 additions and 72 deletions.
22 changes: 16 additions & 6 deletions quetz/cli.py
@@ -72,7 +72,8 @@ def _alembic_config(db_url: str) -> AlembicConfig:
     script_location = "quetz:migrations"

     migration_modules = [
-        f"{ep.module}:versions" for ep in entry_points().select(group='quetz.migrations')
+        f"{ep.module}:versions"
+        for ep in entry_points().select(group='quetz.migrations')
     ]
     migration_modules.append("quetz:migrations/versions")

@@ -127,7 +128,9 @@ def _make_migrations(
             found = True

     if plugin_name != "quetz" and not found:
-        raise Exception(f"models entrypoint (quetz.models) for plugin {plugin_name} not registered")
+        raise Exception(
+            f"models entrypoint (quetz.models) for plugin {plugin_name} not registered"
+        )

     logger.info('Making DB migrations on %r for %r', db_url, plugin_name)
     if not alembic_config and db_url:
@@ -137,7 +140,9 @@ def _make_migrations(
     if plugin_name == "quetz":
         version_path = None  # Path(quetz.__file__).parent / 'migrations' / 'versions'
     else:
-        entry_point = tuple(entry_points().select(group='quetz.migrations', name=plugin_name))[0]
+        entry_point = tuple(
+            entry_points().select(group='quetz.migrations', name=plugin_name)
+        )[0]
         module = entry_point.load()
         version_path = str(Path(module.__file__).parent / "versions")

@@ -198,7 +203,9 @@ def _set_user_roles(db: Session, config: Config):
                     f"with identity from provider '{provider}'"
                 )
             elif user.role is not None and user.role != default_role:
-                logger.warning(f"user has already role {user.role} not assigning a new role")
+                logger.warning(
+                    f"user has already role {user.role} not assigning a new role"
+                )
             else:
                 user.role = role

@@ -429,7 +436,8 @@ def create(
     if _is_deployment(deployment_folder):
         if exists_ok:
             logger.info(
-                f'Quetz deployment already exists at {deployment_folder}.\n' f'Skipping creation.'
+                f'Quetz deployment already exists at {deployment_folder}.\n'
+                f'Skipping creation.'
             )
             return
         if delete and (copy_conf or create_conf):
@@ -725,7 +733,9 @@ def plugin(
         # Try to install pip if it's missing
         if conda_exe_path is not None:
             print("pip is missing, installing...")
-            subprocess.call([conda_exe_path, 'install', '--channel', 'conda-forge', 'pip'])
+            subprocess.call(
+                [conda_exe_path, 'install', '--channel', 'conda-forge', 'pip']
+            )
             pip_exe_path = find_executable('pip')

         if pip_exe_path is None:
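Note: the reformatted call sites above depend on plugins registering themselves under the quetz.migrations entry-point group. A minimal sketch of that lookup outside quetz, assuming Python 3.10+ importlib.metadata (whatever plugins happen to be installed determine the output):

    # Sketch: enumerate migration version paths the way _alembic_config does.
    from importlib.metadata import entry_points

    migration_modules = [
        f"{ep.module}:versions"
        for ep in entry_points().select(group='quetz.migrations')
    ]
    migration_modules.append("quetz:migrations/versions")
    print(migration_modules)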
4 changes: 3 additions & 1 deletion quetz/frontend.py
@@ -96,4 +96,6 @@ def register(app):
         frontend_dir = f"{sys.prefix}/share/quetz/frontend/"
     else:
         logger.info("Using basic fallback frontend")
-        frontend_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "basic_frontend")
+        frontend_dir = os.path.join(
+            os.path.dirname(os.path.realpath(__file__)), "basic_frontend"
+        )
2 changes: 1 addition & 1 deletion quetz/rest_models.py
@@ -130,7 +130,7 @@ class ChannelActionEnum(str, Enum):


 class ChannelMetadata(BaseModel):
-    includelist: Optional[List[str]] = Field(
+    includelist: Optional[Union[List[str], Dict[str, List]]] = Field(
         None,
         title="list of packages to include while creating a channel",
         nullable=True,
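With the widened type, a channel's includelist accepts either a flat list of package specs or a mapping from a source channel name to its own list (the dict shape matches the example comment in quetz/utils.py below). A minimal sketch of both shapes, assuming the model's remaining fields are all optional as includelist's None default suggests; package names are illustrative:

    # Sketch: the two includelist shapes accepted after this commit.
    from quetz.rest_models import ChannelMetadata

    flat = ChannelMetadata(includelist=["numpy", "pandas", "r-base"])
    per_channel = ChannelMetadata(
        includelist={"main": ["numpy", "pandas"], "r": ["r-base"]}
    )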
82 changes: 49 additions & 33 deletions quetz/tasks/mirror.py
@@ -6,7 +6,7 @@
 from concurrent.futures import ThreadPoolExecutor
 from http.client import IncompleteRead
 from tempfile import SpooledTemporaryFile
-from typing import List
+from typing import List, Union

 import requests
 from fastapi import HTTPException, status
@@ -23,7 +23,7 @@
 from quetz.errors import DBError
 from quetz.pkgstores import PackageStore
 from quetz.tasks import indexing
-from quetz.utils import TicToc, add_static_file, check_package_membership_pattern
+from quetz.utils import TicToc, add_static_file, check_package_membership

 # copy common subdirs from conda:
 # https://github.com/conda/conda/blob/a78a2387f26a188991d771967fc33aa1fb5bb810/conda/base/constants.py#L63
@@ -278,36 +278,22 @@ def _upload_package(file, channel_name, subdir):
     file.file.close()


-def initial_sync_mirror(
-    channel_name: str,
-    remote_repository: RemoteRepository,
-    arch: str,
-    dao: Dao,
-    pkgstore: PackageStore,
-    auth: authorization.Rules,
-    includelist: List[str] = None,
-    excludelist: List[str] = None,
-    skip_errors: bool = True,
-    use_repodata: bool = False,
-):
-    force = True  # needed for updating packages
-    logger.info(
-        f"Running channel mirroring {channel_name}/{arch} from {remote_repository.host}"
-    )
-
+def get_remote_repodata(
+    channel_name: str, arch: str, remote_repository: RemoteRepository
+) -> Union[dict, None]:
     repodata = {}
     for repodata_fn in ["repodata_from_packages.json", "repodata.json"]:
         try:
             repo_file = remote_repository.open(os.path.join(arch, repodata_fn))
             repodata = json.load(repo_file.file)
-            break
+            return repodata
         except RemoteServerError:
             logger.error(
                 f"can not get {repodata_fn} for channel {arch}/{channel_name}."
             )
             if repodata_fn == "repodata.json":
                 logger.error(f"Giving up for {channel_name}/{arch}.")
-                return
+                return None
             else:
                 logger.error("Trying next filename.")
                 continue
@@ -317,18 +303,41 @@
                 f"in channel {channel_name}"
             )
             if repodata_fn == "repodata.json":
-                return
+                return None
+
+    return {}

-    channel = dao.get_channel(channel_name)

+def initial_sync_mirror(
+    channel_name: str,
+    remote_repository: RemoteRepository,
+    arch: str,
+    dao: Dao,
+    pkgstore: PackageStore,
+    auth: authorization.Rules,
+    includelist: List[str] = None,
+    excludelist: List[str] = None,
+    skip_errors: bool = True,
+    use_repodata: bool = False,
+):
+    force = True  # needed for updating packages
+    logger.info(
+        f"Running channel mirroring {channel_name}/{arch} from {remote_repository.host}"
+    )
+
+    repodata = get_remote_repodata(channel_name, arch, remote_repository)
+    if not repodata:
+        return  # quit; error has already been logged.
+
+    packages = repodata.get("packages", {})
+
+    channel = dao.get_channel(channel_name)
     if not channel:
         logger.error(f"channel {channel_name} not found")
         return

     from quetz.main import handle_package_files

-    packages = repodata.get("packages", {})
-
     version_methods = [
         _check_checksum(dao, channel_name, arch, "sha256"),
         _check_checksum(dao, channel_name, arch, "md5"),
@@ -408,10 +417,13 @@ def handle_batch(update_batch):
     # TODO: also remove all packages that are not in the remote repository anymore
     # practically re-write the complete sync mechanism?

+    # SYNC: Remote -> Local
+    # for each package in the remote repository:
+    #   validate whether it should be downloaded to this channel
+    #   also: remove packages if they are not supposed to be in this channel anymore
     for repo_package_name, metadata in packages.items():
-        # if check_package_membership(repo_package_name, includelist, excludelist):
-        if check_package_membership_pattern(
-            repo_package_name, includelist, excludelist
+        if check_package_membership(
+            channel, repo_package_name, metadata, remote_host=remote_repository.host
         ):
             path = os.path.join(arch, repo_package_name)

@@ -447,24 +459,28 @@
     # handle final batch
     any_updated |= handle_batch(update_batch)

+    # SYNC: Local checks Remote
+    # Validate if all packages in this channel are still
+    # also present in the remote channel
+    # if not: add them to the remove batch as well
+    # TODO

     if remove_batch:
         logger.debug(f"Removing {len(remove_batch)} packages: {remove_batch}")
         package_specs_remove = set([p[1].split("-")[0] for p in remove_batch])
         # TODO: reuse route [DELETE /api/channels/{channel_name}/packages] logic
         for package_specs in package_specs_remove:
-            # TODO: only remove if it already exists of course
             dao.remove_package(channel_name, package_name=package_specs)
+            # TODO: is this needed every time?
             dao.cleanup_channel_db(channel_name, package_name=package_specs)
+            # only remove if exists
+            if pkgstore.file_exists(channel.name, package_specs):
+                pkgstore.delete_file(channel.name, destination=package_specs)

         any_updated |= True

     if any_updated:
-        indexing.update_indexes(
-            dao, pkgstore, channel_name, subdirs=[arch]
-        )  # build repodata
+        # build local repodata
+        indexing.update_indexes(dao, pkgstore, channel_name, subdirs=[arch])


 def create_packages_from_channeldata(
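The extracted helper returns the parsed repodata dict on success and None (or an empty dict) when every filename fails, which is exactly what the `if not repodata: return` guard in initial_sync_mirror checks. A standalone sketch of the same fallback pattern, using urllib in place of quetz's RemoteRepository; URL and arch are illustrative:

    # Sketch only: try repodata_from_packages.json first, then repodata.json,
    # mirroring get_remote_repodata's fallback order.
    import json
    from typing import Union
    from urllib.error import URLError
    from urllib.request import urlopen

    def fetch_repodata(base_url: str, arch: str) -> Union[dict, None]:
        for repodata_fn in ["repodata_from_packages.json", "repodata.json"]:
            try:
                with urlopen(f"{base_url}/{arch}/{repodata_fn}") as resp:
                    return json.load(resp)
            except URLError:
                if repodata_fn == "repodata.json":
                    return None  # both attempts failed; caller should bail out
        return None

    # fetch_repodata("https://conda.anaconda.org/conda-forge", "linux-64")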
100 changes: 69 additions & 31 deletions quetz/utils.py
@@ -26,47 +26,85 @@
 from .db_models import Channel, Package, PackageVersion, User


-def check_package_membership(package_name, includelist, excludelist):
-    if includelist:
-        for each_package in includelist:
-            if package_name.startswith(each_package):
-                return True
-        return False
-    elif excludelist:
-        for each_package in excludelist:
-            if package_name.startswith(each_package):
-                return False
-        return True
-    return True
-
-
-def _parse_package_spec(package_spec: str) -> tuple[str, str, str]:
+def _parse_package_spec(package_name: str, package_metadata) -> tuple[str, str, str]:
+    """Given a package name and metadata, return the package spec.
+
+    Args:
+        package_name (str): The package name in file format,
+            e.g. "numpy-1.23.4-py39hefdcf20_0.tar.bz2"
+        package_metadata (dict): Metadata of the package,
+            e.g. from repodata.json
+
+    Returns:
+        tuple[str, str, str]: (name, version, build-string)
+    """
     # spec = _parse_spec_str(package_spec)
     # return spec.get("name", ""), spec.get("version", ""), spec.get("build", "")
     # TODO: the package spec here looks like "numpy-1.23.4-py39hefdcf20_0.tar.bz2"
     # and does not have "="
-    spec = package_spec.split("-")
+    spec = package_name.split("-")
     return spec[0], spec[1] if len(spec) > 1 else "", spec[2] if len(spec) > 2 else ""


-def check_package_membership_pattern(
-    package_spec, include_pattern_list=[], exclude_pattern_list=[]
-):
-    # TODO: validate performance, can we save the MatchSpec instances between calls?
-    # might be okay for <100 packages to check against, but what about 1000s?
-    # TODO: matchspec vs package spec and build string matching with *
-    name, version, build = _parse_package_spec(package_spec)
-    for include_pattern in include_pattern_list:
-        # TODO: how do we get the build number?
-        include = MatchSpec(include_pattern).match(
+def _check_package_match(
+    package_spec: tuple[str, str, str],
+    include_or_exclude_list: list[str],
+) -> bool:
+    """
+    Check if the given package specification matches
+    the given include or exclude list.
+    Returns True if a match is found.
+    """
+    name, version, build = package_spec
+    for pattern in include_or_exclude_list:
+        # TODO: validate if this matches with our current implementation
+        if MatchSpec(pattern).match(
             {"name": name, "version": version, "build": build, "build_number": 0}
-        )
-        exclude = False  # TODO
-        if include and not exclude:
+        ):
             return True
-
-    else:
-        return False
+    return False
+
+
+def check_package_membership(
+    channel: Channel,
+    package_name: str,
+    package_metadata: dict,
+    remote_host: str,
+):
+    """
+    Check if a package should be in a channel according
+    to the rules defined in the channel metadata.
+
+    Args:
+        channel (Channel): Channel object returned from the database
+        package_name (str): name of the package in file format,
+            e.g. "numpy-1.23.4-py39hefdcf20_0.tar.bz2"
+        package_metadata (dict): package metadata,
+            information that can be found in repodata.json for example
+
+    Returns:
+        bool: if the package should be in this channel or not according to the rules.
+    """
+    package_spec = _parse_package_spec(package_name, package_metadata)
+    metadata = channel.load_channel_metadata()
+    include_package = True  # default: no includelist configured means include
+    if (includelist := metadata['includelist']) is not None:
+        # Example: { "main": ["numpy", "pandas"], "r": ["r-base"]}
+        if isinstance(includelist, dict):
+            channel_includelist = includelist[remote_host.split("/")[-1]]
+            include_package = _check_package_match(package_spec, channel_includelist)
+        # Example: ["numpy", "pandas", "r-base"]
+        elif isinstance(includelist, list):
+            include_package = _check_package_match(package_spec, includelist)

+    # TODO: implement excludelist

+    return include_package


 def add_static_file(contents, channel_name, subdir, fname, pkgstore, file_index=None):
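To see what the MatchSpec-based check buys over the old startswith comparison, here is a quick sketch, assuming conda is importable; the filename and patterns are illustrative. Note that the naive split keeps the ".tar.bz2" suffix in the build string, the quirk the TODO above points at:

    # Sketch: parse a filename like _parse_package_spec, then match patterns
    # against it the way _check_package_match does (dict form of MatchSpec.match
    # taken from the code above).
    from conda.models.match_spec import MatchSpec

    name, version, build = "numpy-1.23.4-py39hefdcf20_0.tar.bz2".split("-")

    for pattern in ["numpy", "numpy>=1.20", "numpy<1.0", "pandas"]:
        matched = MatchSpec(pattern).match(
            {"name": name, "version": version, "build": build, "build_number": 0}
        )
        print(f"{pattern!r} -> {matched}")
    # 'numpy' and 'numpy>=1.20' match; 'numpy<1.0' and 'pandas' do not,
    # a distinction plain startswith("numpy") cannot make.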
