diff --git a/medcat-den/README.md b/medcat-den/README.md index bab9013d..37d00670 100644 --- a/medcat-den/README.md +++ b/medcat-den/README.md @@ -133,5 +133,7 @@ However, there's a set of environmental variables that can be set in order to cu | MEDCAT_DEN_LOCAL_CACHE_EXPIRATION_TIME | int | The expriation time for local cache (in seconds) | The default is 10 days | | MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE | int | The maximum size of the cache in bytes | The default is 100 GB | | MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY | str | The eviction policy for the local cache | The default is LRU | +| MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED | bool | Whether to allow locally fine-tuned models to be pushed to remote dens | Defaults to False | +| MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE | bool | Whether to allow local fine tuning for remote dens | Defaults to False | When creating a den, the resolver will use the explicitly passed values first, and if none are provided, it will default to the ones defined in the environmental variables. 
diff --git a/medcat-den/src/medcat_den/config.py b/medcat-den/src/medcat_den/config.py index 56927f83..b1d7ebc6 100644 --- a/medcat-den/src/medcat_den/config.py +++ b/medcat-den/src/medcat_den/config.py @@ -17,6 +17,8 @@ class LocalDenConfig(DenConfig): class RemoteDenConfig(DenConfig): host: str credentials: dict + allow_local_fine_tune: bool + allow_push_fine_tuned: bool class LocalCacheConfig(BaseModel): diff --git a/medcat-den/src/medcat_den/den.py b/medcat-den/src/medcat_den/den.py index 4734b9f0..eb91331c 100644 --- a/medcat-den/src/medcat_den/den.py +++ b/medcat-den/src/medcat_den/den.py @@ -1,6 +1,7 @@ -from typing import Protocol, Optional, runtime_checkable +from typing import Protocol, Optional, runtime_checkable, Union from medcat.cat import CAT +from medcat.data.mctexport import MedCATTrainerExport from medcat_den.base import ModelInfo from medcat_den.wrappers import CATWrapper @@ -102,6 +103,55 @@ def delete_model(self, model_info: ModelInfo, """ pass + def finetune_model(self, model_info: ModelInfo, + data: Union[list[str], MedCATTrainerExport] + ) -> ModelInfo: + """Finetune the model on the remote den. + + This is an optional API that is (generally) only available + for remote dens. The idea is that the data is sent to the remote + den and the finetuning is done on the remote. + + If raw data is given, unless already present remotely, it will be + uploaded to the remote den. + + Args: + model_info (ModelInfo): The model info + data (Union[list[str], MedCATTrainerExport]): The list of project + ids (already on remote) or the trainer export to train on. + + Returns: + ModelInfo: The resulting model. + + Raises: + UnsupportedAPIException: If the den does not support this API. + """ + + def evaluate_model(self, model_info: ModelInfo, + data: Union[list[str], MedCATTrainerExport]) -> dict: + """Evaluate model on remote den. + + This is an optional API that is (generally) only available + for remote dens. 
The idea is that the data is sent to the remote + den and the metrics are gathered on the remote. + + If raw data is given, unless already present remotely, it will be + uploaded to the remote den. + + Args: + model_info (ModelInfo): The model info. + data (Union[list[str], MedCATTrainerExport]): The list of project + ids (already on remote) or the trainer export to evaluate on. + + Returns: + dict: The resulting metrics. + """ + pass + + +class UnsupportedAPIException(ValueError): + pass + def get_default_den( type_: Optional[DenType] = None, @@ -112,6 +162,8 @@ def get_default_den( expiration_time: Optional[int] = None, max_size: Optional[int] = None, eviction_policy: Optional[str] = None, + remote_allow_local_fine_tune: Optional[str] = None, + remote_allow_push_fine_tuned: Optional[str] = None, ) -> Den: """Get the default den. @@ -137,6 +189,10 @@ def get_default_den( Policies avialable: LRU (`least-recently-used`), LRS (`least-recently-stored`), LFU (`least-frequently-used`), and `none` (disables evictions). + remote_allow_local_fine_tune (Optional[str]): Whether to allow local + fine tuning of remote models. + remote_allow_push_fine_tuned (Optional[str]): Whether to allow pushing + of locally fine-tuned models to the remote. Returns: Den: The resolved den. 
@@ -144,7 +200,8 @@ def get_default_den( # NOTE: doing dynamic import to avoid circular imports from medcat_den.resolver import resolve return resolve(type_, location, host, credentials, local_cache_path, - expiration_time, max_size, eviction_policy) + expiration_time, max_size, eviction_policy, + remote_allow_local_fine_tune, remote_allow_push_fine_tuned) def get_default_user_local_den( diff --git a/medcat-den/src/medcat_den/den_impl/file_den.py b/medcat-den/src/medcat_den/den_impl/file_den.py index a7ea7ce8..109ebe6b 100644 --- a/medcat-den/src/medcat_den/den_impl/file_den.py +++ b/medcat-den/src/medcat_den/den_impl/file_den.py @@ -1,4 +1,4 @@ -from typing import Optional, cast, Any +from typing import Optional, cast, Any, Union import json from datetime import datetime @@ -7,8 +7,10 @@ import shutil from medcat.cat import CAT +from medcat.data.mctexport import MedCATTrainerExport -from medcat_den.den import Den, DuplicateModelException +from medcat_den.den import ( + Den, DuplicateModelException, UnsupportedAPIException) from medcat_den.backend import DenType from medcat_den.base import ModelInfo from medcat_den.wrappers import CATWrapper @@ -162,7 +164,8 @@ def fetch_model(self, model_info: ModelInfo) -> CATWrapper: model_path = self._get_model_zip_path(model_info) return cast( CATWrapper, - CATWrapper.load_model_pack(model_path, model_info=model_info)) + CATWrapper.load_model_pack(model_path, model_info=model_info, + den_cnf=self._cnf)) def push_model(self, cat: CAT, description: str) -> None: if isinstance(cat, CATWrapper): @@ -220,3 +223,17 @@ def delete_model(self, model_info: ModelInfo, folder_path = zip_path.removesuffix(".zip") if os.path.exists(folder_path): shutil.rmtree(folder_path) + + def finetune_model(self, model_info: ModelInfo, + data: Union[list[str], MedCATTrainerExport]): + raise UnsupportedAPIException( + "Local den does not support finetuning on the den. " + "Use a remote den instead or perform training locally." 
+ ) + + def evaluate_model(self, model_info: ModelInfo, + data: Union[list[str], MedCATTrainerExport]) -> dict: + raise UnsupportedAPIException( + "Local den does not support evaluation on the den. " + "Use a remote den instead or perform evaluation locally." + ) diff --git a/medcat-den/src/medcat_den/resolver/resolver.py b/medcat-den/src/medcat_den/resolver/resolver.py index 58342f6b..fa97ff82 100644 --- a/medcat-den/src/medcat_den/resolver/resolver.py +++ b/medcat-den/src/medcat_den/resolver/resolver.py @@ -38,6 +38,12 @@ MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE = "MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE" MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY = ( "MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY") +MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE = ( + "MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE") +MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED = ( + "MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED") + +ALLOW_OPTION_LOWERCASE = ("true", "yes", "1", "y") def is_writable(path: str, propgate: bool = True) -> bool: @@ -52,7 +58,10 @@ def _init_den_cnf( type_: Optional[DenType] = None, location: Optional[str] = None, host: Optional[str] = None, - credentials: Optional[dict] = None,) -> DenConfig: + credentials: Optional[dict] = None, + remote_allow_local_fine_tune: Optional[str] = None, + remote_allow_push_fine_tuned: Optional[str] = None, + ) -> DenConfig: # Priority: args > env > defaults type_in = ( type_ @@ -82,13 +91,27 @@ def _init_den_cnf( den_cnf = LocalDenConfig(type=type_final, location=location_final) else: + host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST) if not host: raise ValueError("Need to specify a host for remote den") if not credentials: raise ValueError("Need to specify credentials for remote den") - den_cnf = RemoteDenConfig(type=type_final, - host=host, - credentials=credentials) + # NOTE: these will default to False when nothing is specified + # because "None" is not in ALLOW_OPTION_LOWERCASE + allow_local_fine_tune = str( + remote_allow_local_fine_tune or + 
os.getenv(MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE) + ).lower() in ALLOW_OPTION_LOWERCASE + allow_push_fine_tuned = str( + remote_allow_push_fine_tuned or + os.getenv(MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED) + ).lower() in ALLOW_OPTION_LOWERCASE + den_cnf = RemoteDenConfig( + type=type_final, + host=host, + credentials=credentials, + allow_local_fine_tune=allow_local_fine_tune, + allow_push_fine_tuned=allow_push_fine_tuned) return den_cnf @@ -101,8 +124,12 @@ def resolve( expiration_time: Optional[int] = None, max_size: Optional[int] = None, eviction_policy: Optional[str] = None, + remote_allow_local_fine_tune: Optional[str] = None, + remote_allow_push_fine_tuned: Optional[str] = None, ) -> Den: - den_cnf = _init_den_cnf(type_, location, host, credentials) + den_cnf = _init_den_cnf(type_, location, host, credentials, + remote_allow_local_fine_tune, + remote_allow_push_fine_tuned) den = resolve_from_config(den_cnf) lc_cnf = _init_lc_cnf( local_cache_path, expiration_time, max_size, eviction_policy) @@ -126,19 +153,13 @@ def _resolve_local(config: LocalDenConfig) -> LocalFileDen: def resolve_from_config(config: DenConfig) -> Den: if isinstance(config, LocalDenConfig): return _resolve_local(config) - # TODO: support remote (e) - # elif type_final == DenType.MEDCATTERY: - # host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST) - # if host is None: - # raise ValueError("Remote DEN requires a host address") - # # later you’d plug in MedcatteryRemoteDen, MLFlowDen, etc. - # return MedCATteryDen(host=host, credentials=credentials) elif has_registered_remote_den(config.type): den_cls = get_registered_remote_den(config.type) den = den_cls(cnf=config) if not isinstance(den, Den): raise ValueError( - f"Registered den class for {config.type} is not a Den") + f"Registered den class for {config.type} is not a Den. 
" + f"Got {type(den)}: {den}") return den else: raise ValueError( diff --git a/medcat-den/src/medcat_den/wrappers.py b/medcat-den/src/medcat_den/wrappers.py index c3f00804..a46c9318 100644 --- a/medcat-den/src/medcat_den/wrappers.py +++ b/medcat-den/src/medcat_den/wrappers.py @@ -3,8 +3,11 @@ from medcat.cat import CAT from medcat.utils.defaults import DEFAULT_PACK_NAME from medcat.storage.serialisers import AvailableSerialisers +from medcat.trainer import Trainer +from medcat.data.mctexport import MedCATTrainerExport from medcat_den.base import ModelInfo +from medcat_den.config import DenConfig, RemoteDenConfig class CATWrapper(CAT): @@ -20,6 +23,7 @@ class CATWrapper(CAT): """ _model_info: ModelInfo + _den_cnf: DenConfig def save_model_pack( self, target_folder: str, pack_name: str = DEFAULT_PACK_NAME, @@ -54,19 +58,36 @@ def save_model_pack( if not force_save_local and not is_injected_for_save(): raise CannotSaveOnDiskException( f"Cannot save model on disk: {CATWrapper.__doc__}") + if (is_injected_for_save() and isinstance( + self._den_cnf, RemoteDenConfig) and + not self._den_cnf.allow_push_fine_tuned): + # NOTE: should there be a check whether this is a base model? + raise CannotSendToRemoteException( + "Cannot save fine-tuned model onto a remote den. " + "In order to make full use of the remote den capabilities, " + "use the den API to fine tune a model directly on the den. 
" + "See `Den.finetune_model` for details or set the config " + "option of `allow_push_fine_tuned` to True" + ) return super().save_model_pack( target_folder, pack_name, serialiser_type, make_archive, only_archive, add_hash_to_pack_name, change_description) + @property + def trainer(self) -> Trainer: + tr = super().trainer + return WrappedTrainer(self._den_cnf, tr) + @classmethod def load_model_pack(cls, model_pack_path: str, config_dict: Optional[dict] = None, addon_config_dict: Optional[dict[str, dict]] = None, model_info: Optional[ModelInfo] = None, + den_cnf: Optional[DenConfig] = None, ) -> 'CAT': """Load the model pack from file. - This also + This may also disallow model load from disk in certain scenarios. Args: model_pack_path (str): The model pack path. @@ -80,6 +101,9 @@ def load_model_pack(cls, model_pack_path: str, model_inof (Optional[ModelInfo]): The base model info based on which the model was originally fetched. Should not be left None. + den_cnf (Optional[DenConfig]): The config for the den being + used. Should not be left None. + Raises: ValueError: If the saved data does not represent a model pack. 
@@ -95,10 +119,45 @@ def load_model_pack(cls, model_pack_path: str, cat.__class__ = CATWrapper if model_info is None: raise CannotWrapModel("Model info must be provided") + if den_cnf is None: + raise CannotWrapModel("den_cnf must be provided") cat._model_info = model_info + cat._den_cnf = den_cnf return cat +class WrappedTrainer(Trainer): + + def __init__(self, den_cnf: DenConfig, delegate: Trainer): + super().__init__(delegate.cdb, delegate.caller, delegate._pipeline) + self._den_cnf = den_cnf + + def train_supervised_raw( + self, data: MedCATTrainerExport, reset_cui_count: bool = False, + nepochs: int = 1, print_stats: int = 0, use_filters: bool = False, + terminate_last: bool = False, use_overlaps: bool = False, + use_cui_doc_limit: bool = False, test_size: float = 0, + devalue_others: bool = False, use_groups: bool = False, + never_terminate: bool = False, + train_from_false_positives: bool = False, + extra_cui_filter: Optional[set[str]] = None, + disable_progress: bool = False, train_addons: bool = False): + if (isinstance(self._den_cnf, RemoteDenConfig) and + not self._den_cnf.allow_local_fine_tune): + raise NotAllowedToFineTuneLocallyException( + "You are not allowed to fine-tune remote models locally. " + "Please use the `Den.finetune_model` method directly to " + "fine tune on the remote den, or if required, set the " + "`allow_local_fine_tune` config value to `True`." 
+ ) + return super().train_supervised_raw( + data, reset_cui_count, nepochs, print_stats, use_filters, + terminate_last, use_overlaps, use_cui_doc_limit, test_size, + devalue_others, use_groups, never_terminate, + train_from_false_positives, extra_cui_filter, disable_progress, + train_addons) + + class CannotWrapModel(ValueError): def __init__(self, *args): @@ -109,3 +168,15 @@ class CannotSaveOnDiskException(ValueError): def __init__(self, *args): super().__init__(*args) + + +class CannotSendToRemoteException(ValueError): + + def __init__(self, *args): + super().__init__(*args) + + +class NotAllowedToFineTuneLocallyException(ValueError): + + def __init__(self, *args): + super().__init__(*args) diff --git a/medcat-den/tests/test_backend_registration.py b/medcat-den/tests/test_backend_registration.py index d2e36aa4..e5d5de09 100644 --- a/medcat-den/tests/test_backend_registration.py +++ b/medcat-den/tests/test_backend_registration.py @@ -1,43 +1,15 @@ import pytest +from unittest.mock import MagicMock -from medcat.cat import CAT - -from medcat_den.base import ModelInfo -from medcat_den.wrappers import CATWrapper from medcat_den.backend import DenType, _remote_den_map, register_remote_den from medcat_den.resolver import resolve +from medcat_den.den import Den -class FakeDen: - def __init__(self, **kwargs): - return - - @property - def den_type(self) -> DenType: - return DenType.MEDCATTERY - - def list_available_models(self) -> list[ModelInfo]: - return [] - - def list_available_base_models(self) -> list[ModelInfo]: - return [] - - def list_available_derivative_models(self, model: ModelInfo - ) -> list[ModelInfo]: - return [] - - def fetch_model(self, model_info: ModelInfo) -> CATWrapper: - return - - def push_model(self, cat: CAT, description: str) -> None: - return - - def _push_model_from_file(self, file_path: str, description: str) -> None: - return +class FakeDen(MagicMock): - def delete_model(self, model_info: ModelInfo, - allow_delete_base_models: 
bool = False) -> None: - return + def __init__(self, *args, **kw): + super().__init__(*args, **kw, spec=Den) @pytest.fixture() diff --git a/medcat-den/tests/test_remote_den_disallows.py b/medcat-den/tests/test_remote_den_disallows.py new file mode 100644 index 00000000..567c0fcb --- /dev/null +++ b/medcat-den/tests/test_remote_den_disallows.py @@ -0,0 +1,149 @@ +from typing import cast +from medcat_den.config import RemoteDenConfig +from medcat_den.backend import DenType +from medcat_den.injection import injected_den +from medcat_den.wrappers import CATWrapper, CannotSendToRemoteException +from medcat_den.wrappers import NotAllowedToFineTuneLocallyException +from medcat_den.den import Den +from medcat_den.den_impl.file_den import LocalFileDen +from medcat_den.base import ModelInfo + +from medcat.cat import CAT + +import pytest + +from .test_file_system_den import def_model_pack, den, UNSUP_TRAIN_EXAMPLE # noqa + + +def get_wrapped_model_pack( + in_model_pack: CAT, cnf: RemoteDenConfig) -> CATWrapper: # noqa + # make it a wrapper + in_model_pack.__class__ = CATWrapper + model_pack = cast(CATWrapper, in_model_pack) + # set required stuff, mostly the config + model_pack._den_cnf = cnf + # set model info + model_pack._model_info = ModelInfo.from_model_pack(model_pack) + return model_pack + + +@pytest.fixture +def cnf_disallow_all(): + return RemoteDenConfig(type=DenType.MEDCATTERY, + host="ABC", + credentials={"A": "B"}, + allow_local_fine_tune=False, + allow_push_fine_tuned=False, + ) + + +@pytest.fixture +def cnf_allow_push_only(): + return RemoteDenConfig(type=DenType.MEDCATTERY, + host="ABC", + credentials={"A": "B"}, + allow_local_fine_tune=False, + allow_push_fine_tuned=True, + ) + + +@pytest.fixture +def cnf_allow_finetune_only(): + return RemoteDenConfig(type=DenType.MEDCATTERY, + host="ABC", + credentials={"A": "B"}, + allow_local_fine_tune=True, + allow_push_fine_tuned=False, + ) + + +@pytest.fixture +def cnf_allow_both(): + return 
RemoteDenConfig(type=DenType.MEDCATTERY, + host="ABC", + credentials={"A": "B"}, + allow_local_fine_tune=True, + allow_push_fine_tuned=True, + ) + + +@pytest.fixture +def den_disallow_all(den: LocalFileDen, cnf_disallow_all: RemoteDenConfig) -> Den: # noqa + # NOTE: local den with remote config + den._cnf = cnf_disallow_all + return den + + +@pytest.fixture +def den_allow_push_only(den: LocalFileDen, cnf_allow_push_only: RemoteDenConfig) -> Den: # noqa + # NOTE: local den with remote config + den._cnf = cnf_allow_push_only + return den + + +@pytest.fixture +def den_allow_finetune_only(den: LocalFileDen, cnf_allow_finetune_only: RemoteDenConfig) -> Den: # noqa + # NOTE: local den with remote config + den._cnf = cnf_allow_finetune_only + return den + + +def test_can_normally_push(def_model_pack: CAT, den: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den._cnf) + with injected_den(lambda: den, inject_save=True): + # do some training + model_pack.trainer.train_unsupervised(UNSUP_TRAIN_EXAMPLE) + # should be able to just send to den + model_pack.save_model_pack("Did some fine-tuning") + + + +def test_can_disallow_push_all(def_model_pack: CAT, den_disallow_all: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den_disallow_all._cnf) + with injected_den(lambda: den_disallow_all, inject_save=True): + # do some training + model_pack.trainer.train_unsupervised(UNSUP_TRAIN_EXAMPLE) + # attempt to save to den + with pytest.raises(CannotSendToRemoteException): + model_pack.save_model_pack("Did some fine-tuning") + + +def test_can_disallow_save_finetune_only(def_model_pack: CAT, den_allow_finetune_only: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den_allow_finetune_only._cnf) + with injected_den(lambda: den_allow_finetune_only, inject_save=True): + # do some training + model_pack.trainer.train_unsupervised(UNSUP_TRAIN_EXAMPLE) + # attempt to save to den + with 
pytest.raises(CannotSendToRemoteException): + model_pack.save_model_pack("Did some fine-tuning") + + + +def test_can_normally_fine_tune(def_model_pack: CAT, den: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den._cnf) + with injected_den(lambda: den, inject_save=True): + # should be able to just do some supervised training + model_pack.trainer.train_supervised_raw({"projects": []}) + + +def test_can_disallow_fine_tune_all(def_model_pack: CAT, den_disallow_all: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den_disallow_all._cnf) + with injected_den(lambda: den_disallow_all, inject_save=True): + # attempt to save do supervised training + with pytest.raises(NotAllowedToFineTuneLocallyException): + model_pack.trainer.train_supervised_raw({"projects": []}) + + +def test_can_disallow_fine_tune_push_only(def_model_pack: CAT, den_allow_push_only: LocalFileDen): # noqa + model_pack = get_wrapped_model_pack( + def_model_pack, den_allow_push_only._cnf) + with injected_den(lambda: den_allow_push_only, inject_save=True): + # attempt to save do supervised training + with pytest.raises(NotAllowedToFineTuneLocallyException): + model_pack.trainer.train_supervised_raw({"projects": []}) + model_pack.save_model_pack("Did some fine-tuning")