Merge pull request #89 from orobix/feature/upgrade-timm
Upgrade timm
AlessandroPolidori authored Jan 11, 2024
2 parents b3f73a0 + 3de40de commit a1fd7d8
Showing 32 changed files with 405 additions and 63 deletions.
19 changes: 19 additions & 0 deletions CHANGELOG.md
@@ -2,6 +2,25 @@
# Changelog
All notable changes to this project will be documented in this file.

### [1.4.0]

#### Added

- Add new backbones for classification
- Add parameter to save a model summary for sklearn-based classification tasks
- Add results csv file for anomaly detection task
- Add a way to freeze backbone layers by index for the finetuning task

#### Updated

- Update timm requirements to 0.9.12

#### Fixed

- Fix ModelSignatureWrapper not returning the correct instance when the cpu, to, and half functions are called
- Fix failure in model logging on mlflow when half precision is used


### [1.3.8]

#### Updated
2 changes: 2 additions & 0 deletions docs/tutorials/examples/sklearn_classification.md
@@ -150,6 +150,7 @@ task:
automatic_batch_size:
starting_batch_size: 1024
disable: true
save_model_summary: false
output:
folder: classification_experiment
report: true
@@ -161,6 +162,7 @@ This will train a logistic regression classifier using a resnet18 backbone, resi
The backbone (in torchscript and pytorch format) will be saved along with the classifier. `test_full_data` is used to specify if a final test should be performed on all the data (after training on the training and validation datasets).

Optionally, it's possible to enable the automatic batch size finder by setting `automatic_batch_size.disable` to `false`. This will try to find the maximum batch size that can be used on the given device without running out of memory. The `starting_batch_size` parameter specifies where the search begins: the algorithm starts from this value and repeatedly halves it until a batch no longer runs out of memory.
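
As an aside, a minimal sketch of that halving search (illustrative only; the function name and the `try_forward` callable are assumptions, not quadra's actual implementation):

```python
import torch


def find_max_batch_size(try_forward, starting_batch_size: int = 1024) -> int:
    """Halve the batch size until one step fits in GPU memory.

    `try_forward` is any callable that runs a single step at the given
    batch size and raises torch.cuda.OutOfMemoryError when it does not fit.
    """
    batch_size = starting_batch_size
    while batch_size >= 1:
        try:
            try_forward(batch_size)
            return batch_size
        except torch.cuda.OutOfMemoryError:
            torch.cuda.empty_cache()  # release the failed allocation
            batch_size //= 2
    raise RuntimeError("even batch size 1 does not fit in memory")
```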
Finally, the `save_model_summary` parameter can be used to save the backbone information in a text file called `model_summary.txt` located in the root of the output folder.

### Run

12 changes: 7 additions & 5 deletions pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "quadra"
version = "1.3.8"
version = "1.4.0"
description = "Deep Learning experiment orchestration library"
authors = [
{ name = "Alessandro Polidori", email = "[email protected]" },
@@ -64,10 +64,12 @@ dependencies = [
"scikit-multilearn==0.2.*",
"tripy==1.0.*",
"h5py==3.8.*",
"timm==0.6.12", # required by smp
"segmentation-models-pytorch==0.3.2",
"anomalib@git+https://github.com/orobix/[email protected]+obx.1.2.7",
"timm==0.9.12",
# Currently the only version of smp supporting timm 0.9.12 is the following
"segmentation-models-pytorch@git+https://github.com/qubvel/segmentation_models.pytorch@7b381f899ed472a477a89d381689caf535b5d0a6",
"anomalib@git+https://github.com/orobix/[email protected]+obx.1.2.9",
"xxhash==3.2.*",
"torchinfo==1.8.*",
]

[project.optional-dependencies]
@@ -121,7 +123,7 @@ repository = "https://github.com/orobix/quadra"

# Adapted from https://realpython.com/pypi-publish-python-package/#version-your-package
[tool.bumpver]
current_version = "1.3.8"
current_version = "1.4.0"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "build: Bump version {old_version} -> {new_version}"
commit = true
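A quick, hypothetical way to confirm an existing environment matches the tightened pin above (not part of the repository):

```python
# Hypothetical environment check; `packaging` is the standard version-parsing library.
import timm
from packaging.version import Version

assert Version(timm.__version__) == Version("0.9.12"), timm.__version__
```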
2 changes: 1 addition & 1 deletion quadra/__init__.py
@@ -1,4 +1,4 @@
__version__ = "1.3.8"
__version__ = "1.4.0"


def get_version():
8 changes: 8 additions & 0 deletions quadra/configs/backbone/caformer_m36.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: caformer_m36.sail_in22k_ft_in1k
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 576
8 changes: 8 additions & 0 deletions quadra/configs/backbone/caformer_s36.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: caformer_s36.sail_in22k_ft_in1k
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 512
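The `output_dim` values in these new configs can be cross-checked against timm directly; a small sketch (using `pretrained=False` to avoid downloading weights, since only the feature dimension is being checked):

```python
import timm
import torch

# num_classes=0 turns the classification head into a pooled feature extractor.
model = timm.create_model("caformer_s36.sail_in22k_ft_in1k", pretrained=False, num_classes=0)
features = model(torch.randn(1, 3, 224, 224))
print(features.shape)  # expected: torch.Size([1, 512]), matching output_dim above
```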
8 changes: 8 additions & 0 deletions quadra/configs/backbone/convnextv2_base.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: convnextv2_base.fcmae
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 1024
8 changes: 8 additions & 0 deletions quadra/configs/backbone/convnextv2_femto.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: convnextv2_femto.fcmae
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 384
8 changes: 8 additions & 0 deletions quadra/configs/backbone/convnextv2_tiny.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: convnextv2_tiny.fcmae
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 768
12 changes: 12 additions & 0 deletions quadra/configs/backbone/dinov2_vitb14.yaml
@@ -0,0 +1,12 @@
model:
_target_: quadra.models.classification.TorchHubNetworkBuilder
repo_or_dir: facebookresearch/dinov2
model_name: dinov2_vitb14
pretrained: true
freeze: false
hyperspherical: false
metadata:
input_size: 224
output_dim: 768
patch_size: 14
nb_heads: 12
12 changes: 12 additions & 0 deletions quadra/configs/backbone/dinov2_vits14.yaml
@@ -0,0 +1,12 @@
model:
_target_: quadra.models.classification.TorchHubNetworkBuilder
repo_or_dir: facebookresearch/dinov2
model_name: dinov2_vits14
pretrained: true
freeze: false
hyperspherical: false
metadata:
input_size: 224
output_dim: 384
patch_size: 14
nb_heads: 6
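These two configs load through torch.hub rather than timm; the equivalent direct call, using the entrypoints published by facebookresearch/dinov2, would be:

```python
import torch

# Downloads the DINOv2 ViT-S/14 weights on first use.
model = torch.hub.load("facebookresearch/dinov2", "dinov2_vits14")
model.eval()

with torch.no_grad():
    embedding = model(torch.randn(1, 3, 224, 224))  # 224 is a multiple of patch_size 14
print(embedding.shape)  # torch.Size([1, 384]), matching output_dim above
```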
8 changes: 8 additions & 0 deletions quadra/configs/backbone/efficientnet_b0.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: tf_efficientnet_b0.ns_jft_in1k
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 1280
8 changes: 8 additions & 0 deletions quadra/configs/backbone/efficientnet_b1.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: tf_efficientnet_b1.ns_jft_in1k
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 1280
8 changes: 8 additions & 0 deletions quadra/configs/backbone/efficientnet_b2.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: tf_efficientnet_b2.ns_jft_in1k
pretrained: true
freeze: false
metadata:
input_size: 260
output_dim: 1408
8 changes: 8 additions & 0 deletions quadra/configs/backbone/efficientnet_b3.yaml
@@ -0,0 +1,8 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: tf_efficientnet_b3.ns_jft_in1k
pretrained: true
freeze: false
metadata:
input_size: 300
output_dim: 1536
2 changes: 1 addition & 1 deletion quadra/configs/backbone/resnet18.yaml
@@ -1,6 +1,6 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: resnet18
model_name: resnet18.tv_in1k # Use torchvision weights
pretrained: true
freeze: false
metadata:
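The `resnet18` → `resnet18.tv_in1k` change reflects timm 0.9's pretrained-tag naming, where each weight set carries an explicit suffix; the tags available for a model can be listed:

```python
import timm

# In timm >= 0.9, pretrained weight variants appear as "<model>.<tag>" names.
print(timm.list_models("resnet18*", pretrained=True))
# e.g. [..., 'resnet18.a1_in1k', ..., 'resnet18.tv_in1k', ...]
```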
9 changes: 9 additions & 0 deletions quadra/configs/backbone/tiny_vit_21m_224.yaml
@@ -0,0 +1,9 @@
model:
_target_: quadra.models.classification.TimmNetworkBuilder
model_name: tiny_vit_21m_224.dist_in22k_ft_in1k
pretrained: true
freeze: false
metadata:
input_size: 224
output_dim: 576
num_heads: 18 # Is it correct?
@@ -49,7 +49,6 @@ backbone:
freeze_parameters_name:

trainer:
precision: 32
max_epochs: 200
check_val_every_n_epoch: 1
log_every_n_steps: 1
2 changes: 1 addition & 1 deletion quadra/configs/model/anomalib/padim.yaml
@@ -11,7 +11,7 @@ dataset:
model:
name: padim
input_size: [224, 224]
backbone: resnet18
backbone: resnet18.tv_in1k
layers:
- layer1
- layer2
2 changes: 1 addition & 1 deletion quadra/configs/model/anomalib/patchcore.yaml
@@ -10,7 +10,7 @@ dataset:

model:
name: patchcore
backbone: wide_resnet50_2
backbone: resnet18.tv_in1k
layers:
- layer2
- layer3
1 change: 1 addition & 0 deletions quadra/configs/task/sklearn_classification.yaml
@@ -3,6 +3,7 @@ device: cuda:0
automatic_batch_size:
starting_batch_size: 1024
disable: true
save_model_summary: false
output:
folder: classification_experiment
report: true
2 changes: 1 addition & 1 deletion quadra/datamodules/classification.py
@@ -9,7 +9,7 @@
import torch
from sklearn.model_selection import train_test_split
from skmultilearn.model_selection import iterative_train_test_split
from timm.data.parsers.parser_image_folder import find_images_and_targets
from timm.data.readers.reader_image_folder import find_images_and_targets
from torch.utils.data import DataLoader

from quadra.datamodules.base import BaseDataModule
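This one-line change tracks timm's rename of the `timm.data.parsers` package to `timm.data.readers`. Code that had to support both version ranges could, as a sketch, fall back on import:

```python
try:
    # newer timm (the 0.9.12 pinned in this PR)
    from timm.data.readers.reader_image_folder import find_images_and_targets
except ImportError:
    # older timm (the 0.6.12 pinned before this PR)
    from timm.data.parsers.parser_image_folder import find_images_and_targets
```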
4 changes: 3 additions & 1 deletion quadra/datasets/anomaly.py
@@ -135,7 +135,9 @@ def make_anomaly_dataset(
samples_list = [
(str(path),) + filename.parts[-3:]
for filename in path.glob("**/*")
if filename.is_file() and os.path.splitext(filename)[-1].lower() in IMAGE_EXTENSIONS
if filename.is_file()
and os.path.splitext(filename)[-1].lower() in IMAGE_EXTENSIONS
and ".ipynb_checkpoints" not in str(filename)
]

if len(samples_list) == 0:
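The added condition skips the duplicate images Jupyter keeps under `.ipynb_checkpoints/`. The same filter in isolation, as a self-contained sketch (the extension tuple is an illustrative subset of IMAGE_EXTENSIONS):

```python
from pathlib import Path

IMAGE_EXTENSIONS = (".png", ".jpg", ".jpeg", ".bmp")  # illustrative subset


def list_images(root: str) -> list[Path]:
    """Collect image files recursively, skipping Jupyter checkpoint copies."""
    return [
        p
        for p in Path(root).glob("**/*")
        if p.is_file()
        and p.suffix.lower() in IMAGE_EXTENSIONS
        and ".ipynb_checkpoints" not in str(p)
    ]
```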
42 changes: 34 additions & 8 deletions quadra/models/base.py
@@ -6,10 +6,15 @@
import torch
from torch import nn

from quadra.utils.logger import get_logger

log = get_logger(__name__)


class ModelSignatureWrapper(nn.Module):
"""Model wrapper used to retrieve input shape. It can be used as a decorator of nn.Module, the first call to the
forward method will retrieve the input shape and store it in the input_shapes attribute.
It will also save the model summary in a file called model_summary.txt in the current working directory.
"""

def __init__(self, model: nn.Module):
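The summary file mentioned in this docstring is presumably produced with torchinfo, which this PR adds to the dependencies; the writing code itself is not shown in this diff, but a plausible sketch would be:

```python
# Assumed sketch: quadra's actual summary-writing code is not shown in this diff.
import torch
from torchinfo import summary
from torchvision.models import resnet18

model = resnet18()
stats = summary(model, input_size=(1, 3, 224, 224), verbose=0)

with open("model_summary.txt", "w") as f:
    f.write(str(stats))
```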
@@ -24,16 +29,13 @@ def __init__(self, model: nn.Module):
self.instance = self.instance.instance

def forward(self, *args: Any, **kwargs: Any) -> torch.Tensor:
"""Retrieve the input shape and forward the model."""
"""Retrieve the input shape and forward the model, if the input shape is already retrieved it will just forward
the model.
"""
if self.input_shapes is None and not self.disable:
try:
self.input_shapes = self._get_input_shapes(*args, **kwargs)
except Exception:
# Avoid circular import
# pylint: disable=import-outside-toplevel
from quadra.utils.utils import get_logger # noqa

log = get_logger(__name__)
log.warning(
"Failed to retrieve input shapes after forward! To export the model you'll need to "
"provide the input shapes manually setting the config.export.input_shapes parameter! "
@@ -46,7 +48,21 @@

def to(self, *args, **kwargs):
"""Handle calls to to method returning the underlying model."""
return ModelSignatureWrapper(self.instance.to(*args, **kwargs))
self.instance = self.instance.to(*args, **kwargs)

return self

def half(self, *args, **kwargs):
"""Handle calls to to method returning the underlying model."""
self.instance = self.instance.half(*args, **kwargs)

return self

def cpu(self, *args, **kwargs):
"""Handle calls to to method returning the underlying model."""
self.instance = self.instance.cpu(*args, **kwargs)

return self

def _get_input_shapes(self, *args: Any, **kwargs: Any) -> list[Any]:
"""Retrieve the input shapes from the input. Inputs will be in the same order as the forward method
@@ -115,7 +131,17 @@ def __setattr__(self, name: str, value: torch.Tensor | nn.Module) -> None:
setattr(self.instance, name, value)

def __getattribute__(self, __name: str) -> Any:
if __name in ["instance", "input_shapes", "__dict__", "forward", "_get_input_shapes", "_get_input_shape", "to"]:
if __name in [
"instance",
"input_shapes",
"__dict__",
"forward",
"_get_input_shapes",
"_get_input_shape",
"to",
"half",
"cpu",
]:
return super().__getattribute__(__name)

return getattr(self.instance, __name)
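
Taken together, these two hunks are the fix listed in the changelog: to, half, and cpu now mutate the wrapped instance and return the wrapper itself, and __getattribute__ routes those names to the wrapper instead of delegating them to the wrapped model. A stripped-down, standalone illustration of the pattern (not quadra's actual class):

```python
from torch import nn


class Passthrough(nn.Module):
    """Minimal stand-in for the wrapper pattern above."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.instance = model

    def half(self):
        # Mutate in place and return the wrapper, so the caller keeps
        # holding a Passthrough rather than the bare wrapped model.
        self.instance = self.instance.half()
        return self


wrapped = Passthrough(nn.Linear(4, 2))
assert isinstance(wrapped.half(), Passthrough)
```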