diff --git a/.gitignore b/.gitignore index 4046166..67eceef 100644 --- a/.gitignore +++ b/.gitignore @@ -170,3 +170,5 @@ cython_debug/ *.mp4 *.mkv + +/cccv/cache_models/ diff --git a/README.md b/README.md index f63bc63..bc48edd 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ pip install cccv ### Start -#### cv2 +#### Load a registered model in cccv a simple example to use the SISR (Single Image Super-Resolution) model to process an image @@ -37,6 +37,22 @@ img = model.inference_image(img) cv2.imwrite("test_out.jpg", img) ``` +#### Load a custom model from remote repository or local path + +a simple example to use [remote repository](https://github.com/EutropicAI/cccv_demo_remote_model) or local path, auto register the model then load + +```python +import cv2 +import numpy as np + +from cccv import AutoModel, SRBaseModel + +# remote repo +model: SRBaseModel = AutoModel.from_pretrained("https://github.com/EutropicAI/cccv_demo_remote_model") +# local path +model: SRBaseModel = AutoModel.from_pretrained("/path/to/cccv_demo_model") +``` + #### VapourSynth a simple example to use the VapourSynth to process a video diff --git a/cccv/__init__.py b/cccv/__init__.py index 7a602b2..91d0965 100644 --- a/cccv/__init__.py +++ b/cccv/__init__.py @@ -26,6 +26,6 @@ from cccv.arch import ARCH_REGISTRY from cccv.auto import AutoConfig, AutoModel -from cccv.config import CONFIG_REGISTRY, BaseConfig, SRBaseConfig, VFIBaseConfig, VSRBaseConfig +from cccv.config import CONFIG_REGISTRY, AutoBaseConfig, BaseConfig, SRBaseConfig, VFIBaseConfig, VSRBaseConfig from cccv.model import MODEL_REGISTRY, AuxiliaryBaseModel, CCBaseModel, SRBaseModel, VFIBaseModel, VSRBaseModel from cccv.type import ArchType, BaseModelInterface, ConfigType, ModelType diff --git a/cccv/arch/sr/dat_arch.py b/cccv/arch/sr/dat_arch.py index ed65e2b..5eec59f 100644 --- a/cccv/arch/sr/dat_arch.py +++ b/cccv/arch/sr/dat_arch.py @@ -365,8 +365,7 @@ def __init__( elif idx == 1: W_sp, H_sp = self.split_size[0], self.split_size[1] else: - print("ERROR MODE", idx) - exit(0) + raise ValueError(f"[CCCV] ERROR MODE: invalid idx {idx}, expected 0 or 1") self.H_sp = H_sp self.W_sp = W_sp diff --git a/cccv/arch/sr/upcunet_arch.py b/cccv/arch/sr/upcunet_arch.py index 56aa39e..af702aa 100644 --- a/cccv/arch/sr/upcunet_arch.py +++ b/cccv/arch/sr/upcunet_arch.py @@ -376,8 +376,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro): t2 = tile_mode * 2 crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode) else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] @@ -526,8 +525,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro): t2 = tile_mode * 2 crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode) else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] x = F.pad(x, (18, 18 + pw - w0, 18, 18 + ph - h0), "reflect") @@ -767,8 +765,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro): t4 = tile_mode * 4 crop_size = (((h0 - 1) // t4 * t4 + t4) // tile_mode, ((w0 - 1) // t4 * t4 + t4) // tile_mode) else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect") @@ -916,8 +913,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro): t4 = tile_mode * 4 crop_size = (((h0 - 1) // t4 * t4 + t4) // tile_mode, ((w0 - 1) // t4 * t4 + t4) // tile_mode) else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] x = F.pad(x, (14, 14 + pw - w0, 14, 14 + ph - h0), "reflect") @@ -1162,8 +1158,7 @@ def forward(self, x, tile_mode, cache_mode, alpha, pro): t2 = tile_mode * 2 crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode) else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect") @@ -1323,8 +1318,7 @@ def forward_gap_sync(self, x, tile_mode, alpha, pro): t2 = tile_mode * 2 crop_size = (((h0 - 1) // t2 * t2 + t2) // tile_mode, ((w0 - 1) // t2 * t2 + t2) // tile_mode) # 5.6G else: - print("tile_mode config error") - os._exit(233) + raise ValueError("[CCCV] tile_mode config error: invalid tile_mode value") ph = ((h0 - 1) // crop_size[0] + 1) * crop_size[0] pw = ((w0 - 1) // crop_size[1] + 1) * crop_size[1] x = F.pad(x, (19, 19 + pw - w0, 19, 19 + ph - h0), "reflect") diff --git a/cccv/arch/vfi/drba_arch.py b/cccv/arch/vfi/drba_arch.py index a4c27ee..fa57fb2 100644 --- a/cccv/arch/vfi/drba_arch.py +++ b/cccv/arch/vfi/drba_arch.py @@ -1,4 +1,6 @@ # type: ignore +import warnings + import numpy as np import torch import torch.nn as nn @@ -61,7 +63,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i] ) if ensemble: - print("warning: ensemble is not supported since RIFEv4.21") + warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2) else: wf0 = warp(f0, flow[:, :2]) wf1 = warp(f1, flow[:, 2:4]) @@ -71,7 +73,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa scale=scale_list[i], ) if ensemble: - print("warning: ensemble is not supported since RIFEv4.21") + warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2) else: mask = m0 flow = flow + fd @@ -83,7 +85,7 @@ def inference(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fa mask = torch.sigmoid(mask) merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) if not fastmode: - print("contextnet is removed") + warnings.warn("[CCCV] contextnet is removed", stacklevel=2) """ c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) diff --git a/cccv/arch/vfi/ifnet_arch.py b/cccv/arch/vfi/ifnet_arch.py index 3d7ab51..40087a3 100644 --- a/cccv/arch/vfi/ifnet_arch.py +++ b/cccv/arch/vfi/ifnet_arch.py @@ -1,4 +1,6 @@ # type: ignore +import warnings + import torch import torch.nn as nn import torch.nn.functional as F @@ -43,7 +45,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals torch.cat((img0[:, :3], img1[:, :3], f0, f1, timestep), 1), None, scale=scale_list[i] ) if ensemble: - print("warning: ensemble is not supported since RIFEv4.21") + warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2) else: wf0 = warp(f0, flow[:, :2]) wf1 = warp(f1, flow[:, 2:4]) @@ -53,7 +55,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals scale=scale_list[i], ) if ensemble: - print("warning: ensemble is not supported since RIFEv4.21") + warnings.warn("[CCCV] ensemble is not supported since RIFEv4.21", stacklevel=2) else: mask = m0 flow = flow + fd @@ -65,7 +67,7 @@ def forward(self, x, timestep=0.5, scale_list=None, fastmode=True, ensemble=Fals mask = torch.sigmoid(mask) merged[4] = warped_img0 * mask + warped_img1 * (1 - mask) if not fastmode: - print("contextnet is removed") + warnings.warn("[CCCV] contextnet is removed", stacklevel=2) """ c0 = self.contextnet(img0, flow[:, :2]) c1 = self.contextnet(img1, flow[:, 2:4]) diff --git a/cccv/arch/vfi/vfi_utils/softsplat.py b/cccv/arch/vfi/vfi_utils/softsplat.py index 5f82e98..823d523 100644 --- a/cccv/arch/vfi/vfi_utils/softsplat.py +++ b/cccv/arch/vfi/vfi_utils/softsplat.py @@ -61,7 +61,7 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): strKey += str(objValue.stride()) elif True: - print(strVariable, type(objValue)) + print(f"[CCCV] {strVariable}, {type(objValue)}") # end # end @@ -106,10 +106,10 @@ def cuda_kernel(strFunction: str, strKernel: str, objVariables: typing.Dict): strKernel = strKernel.replace("{{type}}", "long") elif isinstance(objValue, torch.Tensor): - print(strVariable, objValue.dtype) + print(f"[CCCV] {strVariable}, {objValue.dtype}") elif True: - print(strVariable, type(objValue)) + print(f"[CCCV] {strVariable}, {type(objValue)}") # end # end diff --git a/cccv/auto/config.py b/cccv/auto/config.py index e625835..c7e055b 100644 --- a/cccv/auto/config.py +++ b/cccv/auto/config.py @@ -1,32 +1,87 @@ +import importlib.util +import json +import warnings +from pathlib import Path from typing import Any, Optional, Union -from cccv.config import CONFIG_REGISTRY, BaseConfig +from cccv.config import CONFIG_REGISTRY, AutoBaseConfig from cccv.type import ConfigType +from cccv.util.remote import git_clone class AutoConfig: @staticmethod def from_pretrained( - pretrained_model_name: Union[ConfigType, str], + pretrained_model_name_or_path: Union[ConfigType, str, Path], + *, + model_dir: Optional[Union[Path, str]] = None, **kwargs: Any, ) -> Any: """ - Get a config instance of a pretrained model configuration. + Get a config instance of a pretrained model configuration, can be a registered config name or a local path or a git url. - :param pretrained_model_name: The name of the pretrained model configuration + :param pretrained_model_name_or_path: + :param model_dir: The path to cache the downloaded model configuration. Should be a full path. If None, use default cache path. :return: """ - return CONFIG_REGISTRY.get(pretrained_model_name) + if "pretrained_model_name" in kwargs: + warnings.warn( + "[CCCV] 'pretrained_model_name' is deprecated, please use 'pretrained_model_name_or_path' instead.", + DeprecationWarning, + stacklevel=2, + ) + pretrained_model_name_or_path = kwargs.pop("pretrained_model_name") - @staticmethod - def register(config: Union[BaseConfig, Any], name: Optional[str] = None) -> None: - """ - Register the given config class instance under the name BaseConfig.name or the given name. - Can be used as a function call. See docstring of this class for usage. + # 1. check if it's a registered config name, early return if found + if isinstance(pretrained_model_name_or_path, ConfigType): + pretrained_model_name_or_path = pretrained_model_name_or_path.value + if str(pretrained_model_name_or_path) in CONFIG_REGISTRY: + return CONFIG_REGISTRY.get(str(pretrained_model_name_or_path)) - :param config: The config class instance to register - :param name: The name to register the config class instance under. If None, use BaseConfig.name - :return: - """ - # used as a function call - CONFIG_REGISTRY.register(obj=config, name=name) + # 2. check is a url or not, if it's a url, git clone it to model_dir then replace pretrained_model_name_or_path with the local path (Path) + if str(pretrained_model_name_or_path).startswith("http"): + pretrained_model_name_or_path = git_clone( + git_url=str(pretrained_model_name_or_path), + model_dir=model_dir, + **kwargs, + ) + + # 3. check if it's a real path + dir_path = Path(str(pretrained_model_name_or_path)) + + if not dir_path.exists() or not dir_path.is_dir(): + raise ValueError(f"[CCCV] model configuration '{dir_path}' is not a valid config name or path") + + # load config,json from the directory + config_path = dir_path / "config.json" + # check if config.json exists + if not config_path.exists(): + raise FileNotFoundError(f"[CCCV] no valid config.json not found in {dir_path}") + + with open(config_path, "r", encoding="utf-8") as f: + config_dict = json.load(f) + + for k in ["arch", "model", "name"]: + if k not in config_dict: + raise KeyError( + f"[CCCV] no key '{k}' in config.json in {dir_path}, you should provide a valid config.json contain a key '{k}'" + ) + + # auto import all .py files in the directory to register the arch, model and config + try: + for py_file in dir_path.glob("*.py"): + spec = importlib.util.spec_from_file_location(py_file.stem, py_file) + if spec is None or spec.loader is None: + continue + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + except Exception as e: + raise ImportError(f"[CCCV] failed register model from {dir_path}, error: {e}, please check your .py files") + + if "path" not in config_dict or config_dict["path"] is None or config_dict["path"] == "": + # add the path to the config_dict + config_dict["path"] = str(dir_path / config_dict["name"]) + + # convert config_dict to pydantic model + cfg = AutoBaseConfig.model_validate(config_dict) + return cfg diff --git a/cccv/auto/model.py b/cccv/auto/model.py index b6d0885..6fd5f4e 100644 --- a/cccv/auto/model.py +++ b/cccv/auto/model.py @@ -1,8 +1,10 @@ +from pathlib import Path from typing import Any, Optional, Tuple, Union import torch -from cccv.config import CONFIG_REGISTRY, BaseConfig +from cccv.auto.config import AutoConfig +from cccv.config import BaseConfig from cccv.model import MODEL_REGISTRY from cccv.type import ConfigType @@ -10,7 +12,8 @@ class AutoModel: @staticmethod def from_pretrained( - pretrained_model_name: Union[ConfigType, str], + pretrained_model_name_or_path: Union[ConfigType, str, Path], + *, device: Optional[torch.device] = None, fp16: bool = True, compile: bool = False, @@ -18,14 +21,14 @@ def from_pretrained( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, - model_dir: Optional[str] = None, + model_dir: Optional[Union[Path, str]] = None, gh_proxy: Optional[str] = None, **kwargs: Any, ) -> Any: """ - Get a model instance from a pretrained model name. + Get a model instance from a registered config name or a local path or a git url. - :param pretrained_model_name: The name of the pretrained model. It should be registered in CONFIG_REGISTRY. + :param pretrained_model_name_or_path: :param device: inference device :param fp16: use fp16 precision or not :param compile: use torch.compile or not @@ -37,8 +40,8 @@ def from_pretrained( :param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/ :return: """ + config = AutoConfig.from_pretrained(pretrained_model_name_or_path, model_dir=model_dir, **kwargs) - config = CONFIG_REGISTRY.get(pretrained_model_name) return AutoModel.from_config( config=config, device=device, @@ -56,6 +59,7 @@ def from_pretrained( @staticmethod def from_config( config: Union[BaseConfig, Any], + *, device: Optional[torch.device] = None, fp16: bool = True, compile: bool = False, @@ -63,14 +67,14 @@ def from_config( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, - model_dir: Optional[str] = None, + model_dir: Optional[Union[Path, str]] = None, gh_proxy: Optional[str] = None, **kwargs: Any, ) -> Any: """ Get a model instance from a config. - :param config: The config object. It should be registered in CONFIG_REGISTRY. + :param config: The config object. We suggest use cccv.BaseConfig or its subclass. :param device: inference device :param fp16: use fp16 precision or not :param compile: use torch.compile or not @@ -99,29 +103,3 @@ def from_config( ) return model - - @staticmethod - def register(obj: Optional[Any] = None, name: Optional[str] = None) -> Any: - """ - Register the given object under the name `obj.__name__` or the given name. - Can be used as either a decorator or not. See docstring of this class for usage. - - :param obj: The object to register. If None, this is being used as a decorator. - :param name: The name to register the object under. If None, use `obj.__name__`. - :return: - """ - if obj is None: - # used as a decorator - def deco(func_or_class: Any) -> Any: - _name = name - if _name is None: - _name = func_or_class.__name__ - MODEL_REGISTRY.register(obj=func_or_class, name=_name) - return func_or_class - - return deco - - # used as a function call - if name is None: - name = obj.__name__ - MODEL_REGISTRY.register(obj=obj, name=name) diff --git a/cccv/cache_models/__init__.py b/cccv/cache_models/__init__.py deleted file mode 100644 index f05df63..0000000 --- a/cccv/cache_models/__init__.py +++ /dev/null @@ -1,118 +0,0 @@ -import hashlib -import os -import sys -from pathlib import Path -from typing import Any, Optional - -from tenacity import retry, stop_after_attempt, stop_after_delay, wait_random -from torch.hub import download_url_to_file - -from cccv.config import BaseConfig - -if getattr(sys, "frozen", False): - # frozen - _IS_FROZEN_ = True - CACHE_PATH = Path(sys.executable).parent.absolute() / "cache_models" - if not CACHE_PATH.exists(): - os.makedirs(CACHE_PATH) -else: - # unfrozen - _IS_FROZEN_ = False - CACHE_PATH = Path(__file__).resolve().parent.absolute() - - -def get_file_sha256(file_path: str, blocksize: int = 1 << 20) -> str: - sha256 = hashlib.sha256() - with open(file_path, "rb") as f: - while True: - data = f.read(blocksize) - if not data: - break - sha256.update(data) - return sha256.hexdigest() - - -def load_file_from_url( - config: BaseConfig, - force_download: bool = False, - progress: bool = True, - model_dir: Optional[str] = None, - gh_proxy: Optional[str] = None, - **kwargs: Any, -) -> str: - """ - Load file form http url, will download models if necessary. - - Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py - - :param config: The config object. - :param force_download: Whether to force download the file. - :param progress: Whether to show the download progress. - :param model_dir: The path to save the downloaded model. Should be a full path. If None, use default cache path. - :param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/ - :return: - """ - - CCCV_REMOTE_MODEL_ZOO = os.environ.get( - "CCCV_REMOTE_MODEL_ZOO", "https://github.com/EutropicAI/cccv/releases/download/model_zoo/" - ) - CCCV_CACHE_MODEL_DIR = os.environ.get("CCCV_CACHE_MODEL_DIR", str(CACHE_PATH)) - - if model_dir is None: - model_dir = str(CCCV_CACHE_MODEL_DIR) - print( - f"[CCCV] Using default cache model path {model_dir}, override it by setting environment variable CCCV_CACHE_MODEL_DIR" - ) - - cached_file_path = os.path.abspath(os.path.join(model_dir, config.name)) - - if config.url is not None: - _url: str = str(config.url) - else: - print( - f"[CCCV] Fetching models from {CCCV_REMOTE_MODEL_ZOO}, override it by setting environment variable CCCV_REMOTE_MODEL_ZOO" - ) - if not CCCV_REMOTE_MODEL_ZOO.endswith("/"): - CCCV_REMOTE_MODEL_ZOO += "/" - _url = CCCV_REMOTE_MODEL_ZOO + config.name - - _gh_proxy = gh_proxy - if _gh_proxy is not None and _url.startswith("https://github.com"): - if not _gh_proxy.endswith("/"): - _gh_proxy += "/" - _url = _gh_proxy + _url - - if not os.path.exists(cached_file_path) or force_download: - if _gh_proxy is not None: - print(f"[CCCV] Using github proxy: {_gh_proxy}") - print(f"[CCCV] Downloading: {_url} to {cached_file_path}\n") - - @retry(wait=wait_random(min=3, max=5), stop=stop_after_delay(10) | stop_after_attempt(30)) - def _download() -> None: - try: - download_url_to_file(url=_url, dst=cached_file_path, hash_prefix=None, progress=progress) - except Exception as e: - print(f"[CCCV] Download failed: {e}, retrying...") - raise e - - _download() - - if config.hash is not None: - get_hash = get_file_sha256(cached_file_path) - if get_hash != config.hash: - raise ValueError( - f"[CCCV] File {cached_file_path} hash mismatched with config hash {config.hash}, compare with {get_hash}" - ) - - return cached_file_path - - -if __name__ == "__main__": - # get all model files sha256 - for root, _, files in os.walk(CACHE_PATH): - for file in files: - if not file.endswith(".pth") and not file.endswith(".pt"): - continue - file_path = os.path.join(root, file) - name = os.path.basename(file_path) - print(f"[CCCV] {name}: {get_file_sha256(file_path)}") diff --git a/cccv/config/__init__.py b/cccv/config/__init__.py index 283a196..4c43937 100644 --- a/cccv/config/__init__.py +++ b/cccv/config/__init__.py @@ -3,7 +3,7 @@ CONFIG_REGISTRY: RegistryConfigInstance = RegistryConfigInstance("CONFIG") -from cccv.config.base_config import BaseConfig, SRBaseConfig, VSRBaseConfig, VFIBaseConfig +from cccv.config.base_config import BaseConfig, SRBaseConfig, VSRBaseConfig, VFIBaseConfig, AutoBaseConfig # Auxiliary Network diff --git a/cccv/config/base_config.py b/cccv/config/base_config.py index e34b75f..03c0d02 100644 --- a/cccv/config/base_config.py +++ b/cccv/config/base_config.py @@ -1,6 +1,6 @@ from typing import Optional, Union -from pydantic import BaseModel, FilePath, HttpUrl +from pydantic import BaseModel, ConfigDict, FilePath, HttpUrl from cccv.type.arch import ArchType from cccv.type.model import ModelType @@ -15,6 +15,10 @@ class BaseConfig(BaseModel): model: Union[ModelType, str] +class AutoBaseConfig(BaseConfig): + model_config = ConfigDict(extra="allow") + + class AuxiliaryBaseConfig(BaseConfig): pass diff --git a/cccv/config/sr/srcnn_config.py b/cccv/config/sr/srcnn_config.py index 7d04f11..1a6403b 100644 --- a/cccv/config/sr/srcnn_config.py +++ b/cccv/config/sr/srcnn_config.py @@ -6,7 +6,7 @@ class SRCNNConfig(SRBaseConfig): - arch: ArchType = ArchType.SRCNN + arch: Union[ArchType, str] = ArchType.SRCNN model: Union[ModelType, str] = ModelType.SRBaseModel scale: int = 2 num_channels: int = 1 diff --git a/cccv/model/base_model.py b/cccv/model/base_model.py index 7c6b333..9c879eb 100644 --- a/cccv/model/base_model.py +++ b/cccv/model/base_model.py @@ -1,14 +1,16 @@ import sys +import warnings from inspect import signature -from typing import Any, Optional, Tuple +from pathlib import Path +from typing import Any, Optional, Tuple, Union import torch from cccv.arch import ARCH_REGISTRY -from cccv.cache_models import load_file_from_url from cccv.config import BaseConfig from cccv.type import BaseModelInterface from cccv.util.device import DEFAULT_DEVICE +from cccv.util.remote import load_file_from_url class CCBaseModel(BaseModelInterface): @@ -37,7 +39,7 @@ def __init__( tile: Optional[Tuple[int, int]] = (128, 128), tile_pad: int = 8, pad_img: Optional[Tuple[int, int]] = None, - model_dir: Optional[str] = None, + model_dir: Optional[Union[Path, str]] = None, gh_proxy: Optional[str] = None, **kwargs: Any, ) -> None: @@ -56,7 +58,7 @@ def __init__( self.tile: Optional[Tuple[int, int]] = tile self.tile_pad: int = tile_pad self.pad_img: Optional[Tuple[int, int]] = pad_img - self.model_dir: Optional[str] = model_dir + self.model_dir: Optional[Union[Path, str]] = model_dir self.gh_proxy: Optional[str] = gh_proxy # post-hook: edit parameters here if needed @@ -72,7 +74,7 @@ def __init__( try: self.model = self.model.half() except Exception as e: - print(f"[CCCV] Warning: {e}. \nfp16 is not supported on this model, fallback to fp32.") + warnings.warn(f"[CCCV] {e}. fp16 is not supported on this model, fallback to fp32.", stacklevel=2) self.fp16 = False self.model = self.load_model() @@ -86,7 +88,7 @@ def __init__( self.compile_backend = "inductor" self.model = torch.compile(self.model, backend=self.compile_backend) except Exception as e: - print(f"[CCCV] Error: {e}, compile is not supported on this model.") + warnings.warn(f"[CCCV] {e}, compile is not supported on this model.", stacklevel=2) def post_init_hook(self) -> None: """ @@ -105,14 +107,14 @@ def get_state_dict(self) -> Any: cfg: BaseConfig = self.config if cfg.path is not None: - state_dict_path = str(cfg.path) + state_dict_path = cfg.path else: try: state_dict_path = load_file_from_url( config=cfg, force_download=False, model_dir=self.model_dir, gh_proxy=self.gh_proxy ) except Exception as e: - print(f"[CCCV] Error: {e}, try force download the model...") + warnings.warn(f"[CCCV] Error: {e}, try force download the model...", stacklevel=2) state_dict_path = load_file_from_url( config=cfg, force_download=True, model_dir=self.model_dir, gh_proxy=self.gh_proxy ) diff --git a/cccv/util/device.py b/cccv/util/device.py index e1532a5..92894ae 100644 --- a/cccv/util/device.py +++ b/cccv/util/device.py @@ -1,4 +1,5 @@ import sys +import warnings import torch @@ -10,7 +11,7 @@ def default_device() -> torch.device: try: return torch.device("mps" if torch.backends.mps.is_available() else "cpu") except Exception as e: - print(f"[CCCV] Error: {e}, MPS is not available, use CPU instead.") + warnings.warn(f"[CCCV] {e}, MPS is not available, use CPU instead.", stacklevel=2) return torch.device("cpu") diff --git a/cccv/util/registry.py b/cccv/util/registry.py index c160600..ff6c3a9 100644 --- a/cccv/util/registry.py +++ b/cccv/util/registry.py @@ -41,7 +41,7 @@ def __init__(self, name: str) -> None: def _do_register(self, name: str, obj: Any) -> None: if name in self._obj_map: - print("[CCCV] An object named '{}' was already registered in '{}' registry!".format(name, self._name)) + raise KeyError(f"[CCCV] An object named '{name}' was already registered in '{self._name}' registry!") else: self._obj_map[name] = obj @@ -69,7 +69,7 @@ def deco(func_or_class: Any) -> Any: def get(self, name: str) -> Any: ret = self._obj_map.get(name) if ret is None: - raise KeyError("[CCCV] No object named '{}' found in '{}' registry!".format(name, self._name)) + raise KeyError(f"[CCCV] No object named '{name}' found in '{self._name}' registry!") return ret def __contains__(self, name: str) -> bool: diff --git a/cccv/util/remote.py b/cccv/util/remote.py new file mode 100644 index 0000000..1598726 --- /dev/null +++ b/cccv/util/remote.py @@ -0,0 +1,173 @@ +import hashlib +import os +import shutil +import subprocess +import sys +import warnings +from pathlib import Path +from typing import Any, Optional, Union + +from tenacity import retry, stop_after_attempt, stop_after_delay, wait_random +from torch.hub import download_url_to_file + +from cccv.config import BaseConfig + +if getattr(sys, "frozen", False): + # frozen + _IS_FROZEN_ = True + CACHE_PATH = Path(sys.executable).parent.absolute() / "cache_models" +else: + # unfrozen + _IS_FROZEN_ = False + CACHE_PATH = Path(__file__).resolve().parent.parent.absolute() / "cache_models" + + +CCCV_CACHE_MODEL_DIR = os.environ.get("CCCV_CACHE_MODEL_DIR", str(CACHE_PATH)) + +CCCV_REMOTE_MODEL_ZOO = os.environ.get( + "CCCV_REMOTE_MODEL_ZOO", "https://github.com/EutropicAI/cccv/releases/download/model_zoo/" +) + + +def get_cache_dir(model_dir: Optional[Union[Path, str]] = None) -> Path: + if model_dir is None or str(model_dir) == "": + model_dir = str(CCCV_CACHE_MODEL_DIR) + print( + f"[CCCV] Using default cache model path {model_dir}, override it by setting environment variable CCCV_CACHE_MODEL_DIR" + ) + if not os.path.exists(model_dir): + os.makedirs(model_dir) + return Path(model_dir) + + +def get_file_sha256(file_path: Union[Path, str], blocksize: int = 1 << 20) -> str: + sha256 = hashlib.sha256() + with open(file_path, "rb") as f: + while True: + data = f.read(blocksize) + if not data: + break + sha256.update(data) + return sha256.hexdigest() + + +def load_file_from_url( + config: BaseConfig, + force_download: bool = False, + progress: bool = True, + model_dir: Optional[Union[Path, str]] = None, + gh_proxy: Optional[str] = None, + **kwargs: Any, +) -> Path: + """ + Load file form http url, will download models if necessary. + + Reference: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + + :param config: The config object. + :param force_download: Whether to force download the file. + :param progress: Whether to show the download progress. + :param model_dir: The path to save the downloaded model. Should be a full path. If None, use default cache path. + :param gh_proxy: The proxy for downloading from github release. Example: https://github.abskoop.workers.dev/ + :return: + """ + model_dir = get_cache_dir(model_dir) + cached_file_path = model_dir / config.name + + if config.url is not None: + _url: str = str(config.url) + else: + remote_zoo = CCCV_REMOTE_MODEL_ZOO + print( + f"[CCCV] Fetching models from {remote_zoo}, override it by setting environment variable CCCV_REMOTE_MODEL_ZOO" + ) + if not remote_zoo.endswith("/"): + remote_zoo += "/" + _url = remote_zoo + config.name + + _gh_proxy = gh_proxy + if _gh_proxy is not None and _url.startswith("https://github.com"): + if not _gh_proxy.endswith("/"): + _gh_proxy += "/" + _url = _gh_proxy + _url + + if not cached_file_path.exists() or force_download: + if _gh_proxy is not None: + print(f"[CCCV] Using github proxy: {_gh_proxy}") + print(f"[CCCV] Downloading: {_url} to {cached_file_path}\n") + + @retry(wait=wait_random(min=3, max=5), stop=stop_after_delay(10) | stop_after_attempt(30)) + def _download() -> None: + try: + download_url_to_file(url=_url, dst=str(cached_file_path), hash_prefix=None, progress=progress) + except Exception as e: + warnings.warn(f"[CCCV] Download failed: {e}, retrying...", stacklevel=2) + raise e + + _download() + + if config.hash is not None: + get_hash = get_file_sha256(cached_file_path) + if get_hash != config.hash: + raise ValueError( + f"[CCCV] File {cached_file_path} hash mismatched with config hash {config.hash}, compare with {get_hash}" + ) + + return cached_file_path + + +def git_clone(git_url: str, model_dir: Optional[Union[Path, str]] = None, **kwargs: Any) -> Path: + """ + Clone or update a git repository. We suggest use HuggingFace repo instead of GitHub repo for larger models. + + :param git_url: GitHub repository URL + :param model_dir: Directory to clone into + :param **kwargs: Additional git options (branch, commit_hash, etc.) + :return: Path to the cloned repository + """ + if not shutil.which("git"): + warnings.warn( + "[CCCV] git is not installed or not in the system's PATH. " + "Please install git to use models from remote git repositories.", + stacklevel=2, + ) + + model_dir = get_cache_dir(model_dir) + # get repo name from url + repo_name = git_url.split("/")[-1].replace(".git", "") + clone_dir = model_dir / repo_name + + if clone_dir.exists() and (clone_dir / ".git").exists(): + print(f"[CCCV] Repository exists, updating: {clone_dir}") + subprocess.run(["git", "-C", str(clone_dir), "pull"], check=True) + + # if branch or commit_hash is specified, checkout to that + if "branch" in kwargs: + subprocess.run(["git", "-C", str(clone_dir), "checkout", kwargs["branch"]], check=True) + if "commit_hash" in kwargs: + subprocess.run(["git", "-C", str(clone_dir), "reset", "--hard", kwargs["commit_hash"]], check=True) + else: + # clone the repo if not exists + print(f"[CCCV] Cloning repository: {git_url} -> {clone_dir}") + command = ["git", "clone", git_url, str(clone_dir)] + + if "branch" in kwargs: + command.extend(["--branch", kwargs["branch"]]) + + subprocess.run(command, check=True) + + if "commit_hash" in kwargs: + subprocess.run(["git", "-C", str(clone_dir), "reset", "--hard", kwargs["commit_hash"]], check=True) + + return clone_dir + + +if __name__ == "__main__": + # get all model files sha256 + for root, _, files in os.walk(get_cache_dir()): + for file in files: + if not file.endswith(".pth") and not file.endswith(".pt"): + continue + file_path = os.path.join(root, file) + name = os.path.basename(file_path) + print(f"[CCCV] {name}: {get_file_sha256(file_path)}") diff --git a/example/register.py b/example/register.py index 7cb8245..1b11aec 100644 --- a/example/register.py +++ b/example/register.py @@ -1,6 +1,6 @@ from typing import Any -from cccv import ArchType, AutoConfig, AutoModel, SRBaseModel +from cccv import CONFIG_REGISTRY, MODEL_REGISTRY, ArchType, AutoModel, SRBaseModel from cccv.config import RealESRGANConfig # define your own config name and model name @@ -17,11 +17,11 @@ scale=2, ) -AutoConfig.register(cfg) +CONFIG_REGISTRY.register(cfg) # extend from cccv.SRBaseModel then implement your own model -@AutoModel.register(name=model_name) +@MODEL_REGISTRY.register(name=model_name) class TESTMODEL(SRBaseModel): def load_model(self) -> Any: print("Override load_model function here") diff --git a/example/remote.py b/example/remote.py new file mode 100644 index 0000000..10fb070 --- /dev/null +++ b/example/remote.py @@ -0,0 +1,10 @@ +import cv2 +import numpy as np + +from cccv import AutoModel, SRBaseModel + +model: SRBaseModel = AutoModel.from_pretrained("https://github.com/EutropicAI/cccv_demo_remote_model") + +img = cv2.imdecode(np.fromfile("../assets/test.jpg", dtype=np.uint8), cv2.IMREAD_COLOR) +img = model.inference_image(img) +cv2.imwrite("../assets/test_remote_example_out.jpg", img) diff --git a/tests/test_auto_class.py b/tests/test_auto_class.py index 326c255..a63bfcc 100644 --- a/tests/test_auto_class.py +++ b/tests/test_auto_class.py @@ -1,8 +1,18 @@ from typing import Any -from cccv import ArchType, AutoConfig, AutoModel +import cv2 + +from cccv import CONFIG_REGISTRY, MODEL_REGISTRY, ArchType, AutoModel from cccv.config import RealESRGANConfig from cccv.model import SRBaseModel +from tests.util import ( + ASSETS_PATH, + CCCV_DEVICE, + CCCV_FP16, + CCCV_TILE, + calculate_image_similarity, + load_image, +) def test_auto_class_register() -> None: @@ -17,12 +27,26 @@ def test_auto_class_register() -> None: scale=2, ) - AutoConfig.register(cfg) + CONFIG_REGISTRY.register(cfg) - @AutoModel.register(name=model_name) + @MODEL_REGISTRY.register(name=model_name) class TESTMODEL(SRBaseModel): def get_cfg(self) -> Any: return self.config model: TESTMODEL = AutoModel.from_pretrained(cfg_name) assert model.get_cfg() == cfg + + +class Test_AutoModel: + def test_model_from_remote_repo(self) -> None: + model: SRBaseModel = AutoModel.from_pretrained( + "https://github.com/EutropicAI/cccv_demo_remote_model", device=CCCV_DEVICE, fp16=CCCV_FP16, tile=CCCV_TILE + ) + + img1 = load_image() + img2 = model.inference_image(img1) + + cv2.imwrite(str(ASSETS_PATH / "test_remote_repo_test_out.jpg"), img2) + + assert calculate_image_similarity(img1, img2) diff --git a/tests/test_cache_models.py b/tests/test_remote.py similarity index 71% rename from tests/test_cache_models.py rename to tests/test_remote.py index 37c0be0..49c7ca4 100644 --- a/tests/test_cache_models.py +++ b/tests/test_remote.py @@ -1,5 +1,5 @@ from cccv import CONFIG_REGISTRY, ConfigType -from cccv.cache_models import load_file_from_url +from cccv.util.remote import get_cache_dir, git_clone, load_file_from_url def test_cache_models() -> None: @@ -17,3 +17,8 @@ def test_cache_models_with_gh_proxy() -> None: force_download=True, gh_proxy="https://github.abskoop.workers.dev", ) + + +def test_git_clone() -> None: + clone_dir = git_clone("https://github.com/EutropicAI/cccv_demo_remote_model") + assert clone_dir == get_cache_dir() / "cccv_demo_remote_model" diff --git a/tests/test_tile.py b/tests/test_tile.py index e0e4b6f..0fa37d1 100644 --- a/tests/test_tile.py +++ b/tests/test_tile.py @@ -6,8 +6,7 @@ from cccv import AutoModel, ConfigType from cccv.model import SRBaseModel, tile_sr - -from .util import ASSETS_PATH, CCCV_DEVICE, calculate_image_similarity, compare_image_size, load_image +from tests.util import ASSETS_PATH, CCCV_DEVICE, calculate_image_similarity, compare_image_size, load_image def test_tile_sr() -> None: diff --git a/tests/test_util.py b/tests/test_util.py index a767cef..fa9ff81 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -15,8 +15,7 @@ resize, ssim_matlab, ) - -from .util import calculate_image_similarity, load_image +from tests.util import calculate_image_similarity, load_image def test_device() -> None: