Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions medcat-den/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -133,5 +133,7 @@ However, there's a set of environmental variables that can be set in order to cu
| MEDCAT_DEN_LOCAL_CACHE_EXPIRATION_TIME | int | The expiration time for local cache (in seconds) | The default is 10 days |
| MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE | int | The maximum size of the cache in bytes | The default is 100 GB |
| MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY | str | The eviction policy for the local cache | The default is LRU |
| MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED | bool | Whether to allow locally fine-tuned models to be pushed to remote dens | Defaults to False |
| MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE | bool | Whether to allow local fine tuning for remote dens | Defaults to False |

When creating a den, the resolver will use the explicitly passed values first, and if none are provided, it will default to the ones defined in the environmental variables.
2 changes: 2 additions & 0 deletions medcat-den/src/medcat_den/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ class LocalDenConfig(DenConfig):
class RemoteDenConfig(DenConfig):
    """Configuration for a den hosted on a remote server."""

    # Address of the remote den host.
    host: str
    # Credentials used to authenticate against the remote host.
    # NOTE(review): schema of this dict is not visible here — confirm
    # expected keys against the remote den implementations.
    credentials: dict
    # Whether models fetched from this remote den may be fine-tuned locally.
    allow_local_fine_tune: bool
    # Whether locally fine-tuned models may be pushed back to this remote den.
    allow_push_fine_tuned: bool


class LocalCacheConfig(BaseModel):
Expand Down
61 changes: 59 additions & 2 deletions medcat-den/src/medcat_den/den.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Protocol, Optional, runtime_checkable
from typing import Protocol, Optional, runtime_checkable, Union

from medcat.cat import CAT
from medcat.data.mctexport import MedCATTrainerExport

from medcat_den.base import ModelInfo
from medcat_den.wrappers import CATWrapper
Expand Down Expand Up @@ -102,6 +103,55 @@ def delete_model(self, model_info: ModelInfo,
"""
pass

def finetune_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]
                   ) -> ModelInfo:
    """Finetune the model on the remote den.

    This is an optional API that is (generally) only available
    for remote dens. The idea is that the data is sent to the remote
    den and the finetuning is done on the remote.

    If raw data is given, unless already present remotely, it will be
    uploaded to the remote den.

    Args:
        model_info (ModelInfo): The model info.
        data (Union[list[str], MedCATTrainerExport]): The list of project
            ids (already on remote) or the trainer export to train on.

    Returns:
        ModelInfo: The resulting model.

    Raises:
        UnsupportedAPIException: If the den does not support this API.
    """
    # NOTE(review): this default body is a no-op stub and implicitly
    # returns None despite the declared ModelInfo return type. Dens that
    # support the API override it; others raise UnsupportedAPIException
    # (see e.g. the local file den implementation).

def evaluate_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]) -> dict:
    """Evaluate model on remote den.

    This is an optional API that is (generally) only available
    for remote dens. The idea is that the data is sent to the remote
    den and the metrics are gathered on the remote.

    If raw data is given, unless already present remotely, it will be
    uploaded to the remote den.

    Args:
        model_info (ModelInfo): The model info.
        data (Union[list[str], MedCATTrainerExport]): The list of project
            ids (already on remote) or the trainer export to evaluate on.

    Returns:
        dict: The resulting metrics.

    Raises:
        UnsupportedAPIException: If the den does not support this API.
    """
    pass


class UnsupportedAPIException(ValueError):
    """Raised when a den does not support an optional API method.

    E.g. a local den raises this for den-side finetuning/evaluation,
    which are only meaningful for remote dens.
    """
    pass


def get_default_den(
type_: Optional[DenType] = None,
Expand All @@ -112,6 +162,8 @@ def get_default_den(
expiration_time: Optional[int] = None,
max_size: Optional[int] = None,
eviction_policy: Optional[str] = None,
remote_allow_local_fine_tune: Optional[str] = None,
remote_allow_push_fine_tuned: Optional[str] = None,
) -> Den:
"""Get the default den.

Expand All @@ -137,14 +189,19 @@ def get_default_den(
Policies avialable: LRU (`least-recently-used`),
LRS (`least-recently-stored`), LFU (`least-frequently-used`),
and `none` (disables evictions).
remote_allow_local_fine_tune (Optional[str]): Whether to allow local
fine tuning of remote models.
remote_allow_push_fine_tuned (Optional[str]): Whether to allow pushing
of locally fine-tuned models to the remote

Returns:
Den: The resolved den.
"""
# NOTE: doing dynamic import to avoid circular imports
from medcat_den.resolver import resolve
return resolve(type_, location, host, credentials, local_cache_path,
expiration_time, max_size, eviction_policy)
expiration_time, max_size, eviction_policy,
remote_allow_local_fine_tune, remote_allow_push_fine_tuned)


def get_default_user_local_den(
Expand Down
23 changes: 20 additions & 3 deletions medcat-den/src/medcat_den/den_impl/file_den.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, cast, Any
from typing import Optional, cast, Any, Union

import json
from datetime import datetime
Expand All @@ -7,8 +7,10 @@
import shutil

from medcat.cat import CAT
from medcat.data.mctexport import MedCATTrainerExport

from medcat_den.den import Den, DuplicateModelException
from medcat_den.den import (
Den, DuplicateModelException, UnsupportedAPIException)
from medcat_den.backend import DenType
from medcat_den.base import ModelInfo
from medcat_den.wrappers import CATWrapper
Expand Down Expand Up @@ -162,7 +164,8 @@ def fetch_model(self, model_info: ModelInfo) -> CATWrapper:
model_path = self._get_model_zip_path(model_info)
return cast(
CATWrapper,
CATWrapper.load_model_pack(model_path, model_info=model_info))
CATWrapper.load_model_pack(model_path, model_info=model_info,
den_cnf=self._cnf))

def push_model(self, cat: CAT, description: str) -> None:
if isinstance(cat, CATWrapper):
Expand Down Expand Up @@ -220,3 +223,17 @@ def delete_model(self, model_info: ModelInfo,
folder_path = zip_path.removesuffix(".zip")
if os.path.exists(folder_path):
shutil.rmtree(folder_path)

def finetune_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]):
    """Den-side finetuning is not supported for a local (file-based) den.

    Args:
        model_info (ModelInfo): The model info (unused).
        data (Union[list[str], MedCATTrainerExport]): The project ids or
            trainer export (unused).

    Raises:
        UnsupportedAPIException: Always — there is no remote compute to
            finetune on; train locally instead.
    """
    raise UnsupportedAPIException(
        "Local den does not support finetuning on the den. "
        "Use a remote den instead or perform training locally."
    )

def evaluate_model(self, model_info: ModelInfo,
                   data: Union[list[str], MedCATTrainerExport]) -> dict:
    """Den-side evaluation is not supported for a local (file-based) den.

    Args:
        model_info (ModelInfo): The model info (unused).
        data (Union[list[str], MedCATTrainerExport]): The project ids or
            trainer export (unused).

    Raises:
        UnsupportedAPIException: Always — there is no remote compute to
            evaluate on; evaluate locally instead.
    """
    raise UnsupportedAPIException(
        "Local den does not support evaluation on the den. "
        "Use a remote den instead or perform evaluation locally."
    )
47 changes: 34 additions & 13 deletions medcat-den/src/medcat_den/resolver/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@
MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE = "MEDCAT_DEN_LOCAL_CACHE_MAX_SIZE"
MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY = (
"MEDCAT_DEN_LOCAL_CACHE_EVICTION_POLICY")
MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE = (
"MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE")
MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED = (
"MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED")

ALLOW_OPTION_LOWERCASE = ("true", "yes", "1", "y")


def is_writable(path: str, propgate: bool = True) -> bool:
Expand All @@ -52,7 +58,10 @@ def _init_den_cnf(
type_: Optional[DenType] = None,
location: Optional[str] = None,
host: Optional[str] = None,
credentials: Optional[dict] = None,) -> DenConfig:
credentials: Optional[dict] = None,
remote_allow_local_fine_tune: Optional[str] = None,
remote_allow_push_fine_tuned: Optional[str] = None,
) -> DenConfig:
# Priority: args > env > defaults
type_in = (
type_
Expand Down Expand Up @@ -82,13 +91,27 @@ def _init_den_cnf(
den_cnf = LocalDenConfig(type=type_final,
location=location_final)
else:
host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST)
if not host:
raise ValueError("Need to specify a host for remote den")
if not credentials:
raise ValueError("Need to specify credentials for remote den")
den_cnf = RemoteDenConfig(type=type_final,
host=host,
credentials=credentials)
# NOTE: these will default to False when nothing is specified
# because "None" is not in ALLOW_OPTION_LOWERCASE
allow_local_fine_tune = str(
remote_allow_local_fine_tune or
os.getenv(MEDCAT_DEN_REMOTE_ALLOW_LOCAL_FINE_TUNE)
).lower() in ALLOW_OPTION_LOWERCASE
allow_push_fine_tuned = str(
remote_allow_push_fine_tuned or
os.getenv(MEDCAT_DEN_REMOTE_ALLOW_PUSH_FINETUNED)
).lower() in ALLOW_OPTION_LOWERCASE
den_cnf = RemoteDenConfig(
type=type_final,
host=host,
credentials=credentials,
allow_local_fine_tune=allow_local_fine_tune,
allow_push_fine_tuned=allow_push_fine_tuned)
return den_cnf


Expand All @@ -101,8 +124,12 @@ def resolve(
expiration_time: Optional[int] = None,
max_size: Optional[int] = None,
eviction_policy: Optional[str] = None,
remote_allow_local_fine_tune: Optional[str] = None,
remote_allow_push_fine_tuned: Optional[str] = None,
) -> Den:
den_cnf = _init_den_cnf(type_, location, host, credentials)
den_cnf = _init_den_cnf(type_, location, host, credentials,
remote_allow_local_fine_tune,
remote_allow_push_fine_tuned)
den = resolve_from_config(den_cnf)
lc_cnf = _init_lc_cnf(
local_cache_path, expiration_time, max_size, eviction_policy)
Expand All @@ -126,19 +153,13 @@ def _resolve_local(config: LocalDenConfig) -> LocalFileDen:
def resolve_from_config(config: DenConfig) -> Den:
if isinstance(config, LocalDenConfig):
return _resolve_local(config)
# TODO: support remote (e)
# elif type_final == DenType.MEDCATTERY:
# host = host or os.getenv(MEDCAT_DEN_REMOTE_HOST)
# if host is None:
# raise ValueError("Remote DEN requires a host address")
# # later you’d plug in MedcatteryRemoteDen, MLFlowDen, etc.
# return MedCATteryDen(host=host, credentials=credentials)
elif has_registered_remote_den(config.type):
den_cls = get_registered_remote_den(config.type)
den = den_cls(cnf=config)
if not isinstance(den, Den):
raise ValueError(
f"Registered den class for {config.type} is not a Den")
f"Registered den class for {config.type} is not a Den. "
f"Got {type(den)}: {den}")
return den
else:
raise ValueError(
Expand Down
73 changes: 72 additions & 1 deletion medcat-den/src/medcat_den/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from medcat.cat import CAT
from medcat.utils.defaults import DEFAULT_PACK_NAME
from medcat.storage.serialisers import AvailableSerialisers
from medcat.trainer import Trainer
from medcat.data.mctexport import MedCATTrainerExport
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more a note to self - this export model should be used in medcat-trainer

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There should be some consolidation on that end probably. But I'm not sure this will be too useful. I've only really gathered the parts of the schema that are useful for the library. The trainer adds a bunch of other stuff that's not really used during superivsed training, so they don't appear in the schema for the TypedDict that is MedCATTrainerExport .


from medcat_den.base import ModelInfo
from medcat_den.config import DenConfig, RemoteDenConfig


class CATWrapper(CAT):
Expand All @@ -20,6 +23,7 @@ class CATWrapper(CAT):
"""

_model_info: ModelInfo
_den_cnf: DenConfig

def save_model_pack(
self, target_folder: str, pack_name: str = DEFAULT_PACK_NAME,
Expand Down Expand Up @@ -54,19 +58,36 @@ def save_model_pack(
if not force_save_local and not is_injected_for_save():
raise CannotSaveOnDiskException(
f"Cannot save model on disk: {CATWrapper.__doc__}")
if (is_injected_for_save() and isinstance(
self._den_cnf, RemoteDenConfig) and
not self._den_cnf.allow_push_fine_tuned):
# NOTE: should there be a check whether this is a base model?
raise CannotSendToRemoteException(
"Cannot save fine-tuned model onto a remote den."
"In order to make full use of the remote den capabilities, "
"use the den API to fine tune a model directly on the den. "
"See `Den.finetune_model` for details or set the config "
"option of `allow_push_fine_tuned` to True"
)
return super().save_model_pack(
target_folder, pack_name, serialiser_type, make_archive,
only_archive, add_hash_to_pack_name, change_description)

@property
def trainer(self) -> Trainer:
    """A policy-enforcing trainer.

    The underlying trainer is wrapped so that den configuration
    (e.g. whether local fine-tuning is allowed) is checked on training.
    """
    return WrappedTrainer(self._den_cnf, super().trainer)

@classmethod
def load_model_pack(cls, model_pack_path: str,
config_dict: Optional[dict] = None,
addon_config_dict: Optional[dict[str, dict]] = None,
model_info: Optional[ModelInfo] = None,
den_cnf: Optional[DenConfig] = None,
) -> 'CAT':
"""Load the model pack from file.

This also
This may also disallow model load from disk in certain secnarios.

Args:
model_pack_path (str): The model pack path.
Expand All @@ -80,6 +101,9 @@ def load_model_pack(cls, model_pack_path: str,
model_inof (Optional[ModelInfo]): The base model info based on
which the model was originally fetched. Should not be
left None.
den_cnf: (Optional[DenConfig]): The config for the den being
used. Should not be left None.


Raises:
ValueError: If the saved data does not represent a model pack.
Expand All @@ -95,10 +119,45 @@ def load_model_pack(cls, model_pack_path: str,
cat.__class__ = CATWrapper
if model_info is None:
raise CannotWrapModel("Model info must be provided")
if den_cnf is None:
raise CannotWrapModel("den_cnf must be provided")
cat._model_info = model_info
cat._den_cnf = den_cnf
return cat


class WrappedTrainer(Trainer):
    """A trainer that enforces den configuration policy.

    Delegates to the regular trainer, but refuses supervised training
    when the den config forbids local fine-tuning of remote models.
    """

    def __init__(self, den_cnf: DenConfig, delegate: Trainer):
        # Rebuild on top of the delegate's own components.
        super().__init__(delegate.cdb, delegate.caller, delegate._pipeline)
        self._den_cnf = den_cnf

    def train_supervised_raw(
            self, data: MedCATTrainerExport, reset_cui_count: bool = False,
            nepochs: int = 1, print_stats: int = 0, use_filters: bool = False,
            terminate_last: bool = False, use_overlaps: bool = False,
            use_cui_doc_limit: bool = False, test_size: float = 0,
            devalue_others: bool = False, use_groups: bool = False,
            never_terminate: bool = False,
            train_from_false_positives: bool = False,
            extra_cui_filter: Optional[set[str]] = None,
            disable_progress: bool = False, train_addons: bool = False):
        """Run supervised training, unless forbidden by the den config."""
        cnf = self._den_cnf
        forbidden = (isinstance(cnf, RemoteDenConfig)
                     and not cnf.allow_local_fine_tune)
        if forbidden:
            raise NotAllowedToFineTuneLocallyException(
                "You are not allowed to fine-tune remote models locally. "
                "Please use the `Den.finetune_model` method directly to "
                "fine tune on the remote den, or if required, set the "
                "`allow_local_fine_tune` config value to `True`."
            )
        return super().train_supervised_raw(
            data,
            reset_cui_count=reset_cui_count,
            nepochs=nepochs,
            print_stats=print_stats,
            use_filters=use_filters,
            terminate_last=terminate_last,
            use_overlaps=use_overlaps,
            use_cui_doc_limit=use_cui_doc_limit,
            test_size=test_size,
            devalue_others=devalue_others,
            use_groups=use_groups,
            never_terminate=never_terminate,
            train_from_false_positives=train_from_false_positives,
            extra_cui_filter=extra_cui_filter,
            disable_progress=disable_progress,
            train_addons=train_addons)


class CannotWrapModel(ValueError):

def __init__(self, *args):
Expand All @@ -109,3 +168,15 @@ class CannotSaveOnDiskException(ValueError):

def __init__(self, *args):
super().__init__(*args)


class CannotSendToRemoteException(ValueError):
    """Raised when a locally fine-tuned model may not be pushed to a remote den.

    See `CATWrapper.save_model_pack` and the `allow_push_fine_tuned`
    remote den config option.
    """

    # Fix: this previously defined `__call__` delegating to
    # `super().__call__`, which does not exist on ValueError instances
    # (AttributeError if ever invoked). The sibling exception classes in
    # this module define `__init__` — mirror that pattern.
    def __init__(self, *args):
        super().__init__(*args)


class NotAllowedToFineTuneLocallyException(ValueError):
    """Raised when local fine-tuning of a remote model is not allowed.

    See `WrappedTrainer.train_supervised_raw` and the
    `allow_local_fine_tune` remote den config option.
    """

    # Fix: this previously defined `__call__` delegating to
    # `super().__call__`, which does not exist on ValueError instances
    # (AttributeError if ever invoked). The sibling exception classes in
    # this module define `__init__` — mirror that pattern.
    def __init__(self, *args):
        super().__init__(*args)
Loading
Loading