Skip to content

Proposal: Outsource installation of packages (models, data plugins) - WARNING - breaking API change #346

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scivision/io/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
from .reader import load_pretrained_model, load_dataset, _parse_url, _get_model_configs
from .installer import install_package, package_from_config
from .checker import check_package, package_from_config
from .wrapper import PretrainedModel, Datasource
24 changes: 3 additions & 21 deletions scivision/io/installer.py → scivision/io/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
# -*- coding: utf-8 -*-

import importlib
import subprocess
import sys


def _package_exists(config: dict) -> bool:
Expand All @@ -25,31 +23,15 @@ def package_from_config(config: dict, branch: str = "main") -> str:
return f"git+{install_str}@{install_branch}#egg={config['import']}"


def _install(package, pip_install_args=None):
"""Install a package using pip."""

if pip_install_args is None:
pip_install_args = []

subprocess.check_call(
[sys.executable, "-m", "pip", "install", *pip_install_args, package]
)


def install_package(
def check_package(
config: dict,
allow_install=False, # allowed values: True, False, or the string "force"
branch: str = "main",
):
"""Install the python package if it doesn't exist."""
"""Check if the python package exists."""
package = package_from_config(config, branch)
exists = _package_exists(config)

if allow_install == "force" or (allow_install and not exists):
# if a package is not already installed, there is little harm
# to passing the extra arguments, so these cases are combined
_install(package, pip_install_args=["--force-reinstall", "--no-cache-dir"])
elif not exists:
if not exists:
raise Exception(
"Package does not exist. Try installing it with: \n"
f"`!pip install -e {package}`"
Expand Down
13 changes: 5 additions & 8 deletions scivision/io/reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import yaml

from ..koala import koala
from .installer import install_package
from .checker import check_package
from .wrapper import PretrainedModel, Datasource

import warnings
Expand Down Expand Up @@ -120,7 +120,6 @@ def _get_model_configs(
def load_pretrained_model(
path: os.PathLike,
branch: str = "main",
allow_install: bool = False,
model_selection: str = "default",
load_multiple: bool = False,
*args,
Expand All @@ -134,8 +133,6 @@ def load_pretrained_model(
The filename, path or URL of a pretrained model description.
branch : str, default = main
Specify the name of a github branch if loading from github.
allow_install : bool, default = False
Allow installation of remote package via pip.
model_selection : str, default = default
Specify the name of the model if there is > 1 in the model repo package.
load_multiple : bool, default = False
Expand Down Expand Up @@ -170,8 +167,8 @@ def load_pretrained_model(
# make sure a model at least has an input to the function
assert "X" in config["prediction_fn"]["args"].keys()

# try to install the package if necessary
install_package(config, allow_install=allow_install, branch=branch)
# Raise exception if package not installed and suggest how
check_package(config, branch=branch)

loaded_models.append(PretrainedModel(config))
if load_multiple:
Expand Down Expand Up @@ -245,7 +242,7 @@ def load_data_from_plugin(
The dataset to be visualised, loaded via xarray.
"""

# install the package
install_package(config, allow_install=True, branch=branch)
# Raise exception if package not installed and suggest how
check_package(config, branch=branch)

return Datasource(config)
17 changes: 15 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from scivision.io import install_package
from scivision.io import package_from_config
from scivision.io import PretrainedModel

import subprocess
import sys
import fsspec
import yaml
import pytest
Expand Down Expand Up @@ -31,4 +33,15 @@ def KOALA(request):


# Install the model package so it can be used in tests
install_package(imagenet_model_config, allow_install=True)
subprocess.check_call([sys.executable, "-m", "pip", "install", package_from_config(imagenet_model_config)])


# Set up some global vars for tests that require an example data plugin
file = fsspec.open('tests/test_data_plugin.yml')
with file as config_file:
stream = config_file.read()
data_plugin_config = yaml.safe_load(stream)


# Install the data plugin package so it can be used in tests
subprocess.check_call([sys.executable, "-m", "pip", "install", package_from_config(data_plugin_config)])
8 changes: 8 additions & 0 deletions tests/test_data_plugin.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
name: scivision_sentinel2_stac
url: https://github.com/alan-turing-institute/scivision_sentinel2_stac.git
import: scivision_sentinel2_stac
class: scivision_sentinel2_stac
args: None
func:
call: get_images
args: None