diff --git a/.gitignore b/.gitignore
index ae3d542..facac88 100644
--- a/.gitignore
+++ b/.gitignore
@@ -6,11 +6,13 @@
 dist/
 build/
 images/
 outputs/
+output/
+DATASETS/
 multirun/
 exp/
 handling/
 tests/
-scripts/
+wandb/
 *.code-workspace
 configs/train/*
diff --git a/animaloc/data/transforms.py b/animaloc/data/transforms.py
index 881c466..89e7d36 100644
--- a/animaloc/data/transforms.py
+++ b/animaloc/data/transforms.py
@@ -191,6 +191,59 @@ def __call__(
 
         return image, target
 
+@TRANSFORMS.register()
+class AnimalDensity:
+    ''' Compute an image-level (tile) label: presence/absence or normalized animal density '''
+
+    def __init__(
+        self,
+        max_animals: float = 100.0,
+        anno_type: str = 'binary'
+        ) -> None:
+        '''
+        Args:
+            max_animals (float, optional): normalization constant for the density label,
+                i.e. the expected maximum number of animals per tile. Defaults to 100.0
+            anno_type (str, optional): 'binary' for a presence/absence label, 'density'
+                for a normalized count label. Defaults to 'binary'
+        '''
+
+        assert anno_type in ['binary', 'density'], \
+            f'Annotations type must be \'binary\' or \'density\', got \'{anno_type}\''
+        self.max_animals = max_animals
+        self.anno_type = anno_type
+
+    def __call__(
+        self,
+        image: Union[PIL.Image.Image, torch.Tensor],
+        target: Dict[str, torch.Tensor]
+        ) -> Tuple[torch.Tensor, torch.Tensor]:
+        '''
+        Args:
+            image (PIL.Image.Image or torch.Tensor): image of reference [C,H,W], only for
+                pipeline convenience, original size is kept
+            target (dict): target containing at least 'boxes' (or 'points') and 'labels'
+                keys, with torch.Tensor as value. Labels must be integers!
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]
+                the image and its image-level label
+        '''
+
+        if isinstance(image, PIL.Image.Image):
+            image = torchvision.transforms.ToTensor()(image)
+
+        if self.anno_type == 'binary': # binary case (empty or not empty image)
+            if len(target['labels']): # we have annotations = not empty
+                target['labels'] = torch.as_tensor([1], dtype=torch.int64)
+            else: # we don't have annotations = empty
+                target['labels'] = torch.as_tensor([0], dtype=torch.int64)
+        elif self.anno_type == 'density': # density case (normalized count)
+            if len(target['labels']): # we have annotations = not empty
+                # float dtype: an int64 cast would truncate the density ratio to 0
+                target['labels'] = torch.as_tensor([len(target['labels'])/self.max_animals], dtype=torch.float32)
+            else: # we don't have annotations = empty
+                target['labels'] = torch.as_tensor([0.], dtype=torch.float32)
+
+        return image, target['labels'].float()
+
 @TRANSFORMS.register()
 class PointsToMask:
     ''' Convert points annotations to mask with a buffer option '''
diff --git a/animaloc/datasets/folder.py b/animaloc/datasets/folder.py
index 4f1dd2c..edc2c99 100644
--- a/animaloc/datasets/folder.py
+++ b/animaloc/datasets/folder.py
@@ -77,13 +77,13 @@ def __init__(
         self.folder_images = [i for i in os.listdir(self.root_dir)
                                 if i.endswith(('.JPG','.jpg','.JPEG','.jpeg'))]
-        self._img_names = self.folder_images
+        self._img_names = self.folder_images
         self.anno_keys = self.data.columns
-        self.data['from_folder'] = 0
-
+        self.data['from_folder'] = 0 # images listed in the CSV (annotated)
+
         folder_only_images = numpy.setdiff1d(self.folder_images, self.data['images'].unique().tolist())
         folder_df = pandas.DataFrame(data=dict(images = folder_only_images))
-        folder_df['from_folder'] = 1
+        folder_df['from_folder'] = 1 # images found only in the folder (no annotations)
 
         self.data = pandas.concat([self.data, folder_df], ignore_index=True).convert_dtypes()
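The new AnimalDensity transform returns an (image, label) pair rather than an (image, target) pair, so it has to sit last in a dataset's end transforms. A minimal usage sketch (the CSVDataset keywords and the paths mirror the training script later in this diff; treat it as an illustration, not part of the patch):

    import albumentations as A
    from animaloc.datasets import CSVDataset
    from animaloc.data.transforms import AnimalDensity

    tile_dataset = CSVDataset(
        csv_file='/herdnet/DATASETS/Train_patches_stratified/gt.csv',
        root_dir='/herdnet/DATASETS/Train_patches_stratified',
        albu_transforms=[A.Normalize(p=1.0)],
        end_transforms=[AnimalDensity(anno_type='binary')]  # returns (image, label.float())
        )
    image, label = tile_dataset[0]  # label is 0. (empty tile) or 1. (non-empty tile)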
diff --git a/animaloc/eval/evaluators.py b/animaloc/eval/evaluators.py
index 4f6fe75..80b7d86 100644
--- a/animaloc/eval/evaluators.py
+++ b/animaloc/eval/evaluators.py
@@ -20,6 +20,7 @@
 import numpy
 import wandb
 import matplotlib
+from itertools import chain
 
 matplotlib.use('Agg')
@@ -198,11 +199,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
                 if i % self.print_freq == 0 or i == len(self.dataloader) - 1:
                     fig = self._vizual(image = images, target = targets, output = output)
                     wandb.log({'validation_vizuals': fig})
-
-                output = self.prepare_feeding(targets, output)
-
-                iter_metrics.feed(**output)
-                iter_metrics.aggregate()
+                for b in range(images.shape[0]): # feed the metrics sample by sample
+                    batch_output = self.prepare_feeding(
+                        dict(labels=targets['labels'][b], points=targets['points'][b]),
+                        (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
+                    iter_metrics.feed(**batch_output)
+                    iter_metrics.aggregate()
                 if log_meters:
                     logger.add_meter('n', sum(iter_metrics.tp) + sum(iter_metrics.fn))
                     logger.add_meter('recall', round(iter_metrics.recall(),2))
@@ -224,8 +224,10 @@ def evaluate(self, returns: str = 'recall', wandb_flag: bool = False, viz: bool
                     })
                     iter_metrics.flush()
-
-                self.metrics.feed(**output)
+                for b in range(images.shape[0]): # same per-sample feeding for the global metrics
+                    batch_output = self.prepare_feeding(
+                        dict(labels=targets['labels'][b], points=targets['points'][b]),
+                        (output[0][b].unsqueeze(0), output[1][b].unsqueeze(0)))
+                    self.metrics.feed(**batch_output)
 
         self._stored_metrics = self.metrics.copy()
@@ -345,14 +347,16 @@ def post_stitcher(self, output: torch.Tensor) -> Any:
 
     def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.Tensor]) -> dict:
 
-        gt_coords = [p[::-1] for p in targets['points'].squeeze(0).tolist()]
-        gt_labels = targets['labels'].squeeze(0).tolist()
-
+        gt_coords = [p[::-1] for p in targets['points'].tolist()]
+        gt_labels = targets['labels'].tolist()
+
         gt = dict(
             loc = gt_coords,
             labels = gt_labels
         )
+
         up = True
         if self.stitcher is not None:
             up = False
@@ -363,8 +367,8 @@ def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: List[torch.T
         preds = dict(
             loc = locs[0],
             labels = labels[0],
-            scores = scores[0],
-            dscores = dscores[0]
+            scores = scores[0], # class scores
+            dscores = dscores[0] # heatmap scores
         )
 
         return dict(gt = gt, preds = preds, est_count = counts[0])
@@ -390,6 +394,27 @@ def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor
 
         return dict(gt = gt, preds = preds, est_count = est_counts)
 
+@EVALUATORS.register()
+class DensityMapEvaluator(Evaluator):
+    ''' Evaluator for density-map outputs: counts are estimated by integrating the map per class '''
+
+    def prepare_data(self, images: Any, targets: Any) -> tuple:
+        return images.to(self.device), targets
+
+    def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor) -> dict:
+
+        gt_coords = [p[::-1] for p in targets['points'].squeeze(0).tolist()]
+        gt_labels = targets['labels'].squeeze(0).tolist()
+
+        gt = dict(loc = gt_coords, labels = gt_labels)
+        preds = dict(loc = [], labels = [], scores = [])
+
+        # keep only the argmax class at each pixel, then sum each channel
+        _, idx = torch.max(output, dim=1)
+        masks = F.one_hot(idx, num_classes=output.shape[1]).permute(0,3,1,2)
+        output = (output * masks)
+        est_counts = output[0].sum(2).sum(1).tolist()
+
+        return dict(gt = gt, preds = preds, est_count = est_counts)
+
 @EVALUATORS.register()
 class FasterRCNNEvaluator(Evaluator):
 
@@ -420,4 +445,23 @@ def prepare_feeding(self, targets: List[dict], output: List[dict]) -> dict:
         num_classes = self.metrics.num_classes - 1
         counts = [preds['labels'].count(i+1) for i in range(num_classes)]
 
-        return dict(gt = gt, preds = preds, est_count = counts)
\ No newline at end of file
+        return dict(gt = gt, preds = preds, est_count = counts)
+
+@EVALUATORS.register()
+class TileEvaluator(Evaluator):
+    ''' Evaluator for image-level (tile) classification from a single-logit head '''
+
+    def prepare_data(self, images: Any, targets: Any) -> tuple:
+        return images.to(self.device), targets
+
+    def prepare_feeding(self, targets: Dict[str, torch.Tensor], output: torch.Tensor) -> dict:
+
+        gt_labels = list(chain.from_iterable(targets[0].tolist()))
+        gt_labels = [int(l+1) for l in gt_labels] # shift {0,1} to {1,2}, 0 is reserved for background
+
+        scores = list(chain.from_iterable(output.tolist()))
+        labels = [2 if s > 0 else 1 for s in scores] # logit > 0 <=> sigmoid(logit) > 0.5
+
+        return dict(gt = gt_labels, preds = labels)
\ No newline at end of file
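A note on TileEvaluator's thresholding, as a minimal sketch in plain PyTorch (nothing HerdNet-specific assumed): since sigmoid(s) > 0.5 exactly when s > 0, the raw logits from DLAEncoder's single-logit head can be thresholded at zero without computing the sigmoid, and the resulting labels are shifted to {1, 2} because class 0 is reserved for background in ImageLevelMetrics.

    import torch

    logits = torch.tensor([-1.3, 0.2, 4.1])               # raw outputs of the binary head
    labels = [2 if s > 0 else 1 for s in logits.tolist()]  # [1, 2, 2]
    assert labels == [1, 2, 2]
    assert all((torch.sigmoid(s) > 0.5) == (s > 0) for s in logits)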
diff --git a/animaloc/eval/metrics.py b/animaloc/eval/metrics.py
index c039e5d..ef7a2d5 100644
--- a/animaloc/eval/metrics.py
+++ b/animaloc/eval/metrics.py
@@ -636,7 +636,7 @@ def __init__(self, num_classes: int = 2) -> None:
         num_classes = num_classes + 1 # for convenience
         super().__init__(0, num_classes)
 
-    def feed(self, gt: int, pred: int) -> tuple:
+    def feed(self, gt: int, preds: int) -> tuple:
         '''
         Args:
            gt (int): numeric ground truth label
@@ -644,19 +644,44 @@ def feed(self, gt: int, pred: int) -> tuple:
         '''
 
         gt = dict(labels=[gt], loc=[(0,0)])
-        preds = dict(labels=[pred], loc=[(0,0)])
+        preds = dict(labels=[preds], loc=[(0,0)])
 
         super().feed(gt, preds)
 
     def matching(self, gt: dict, pred: dict) -> None:
         gt_lab = gt['labels'][0]
         p_lab = pred['labels'][0]
+        for g, p in zip(gt_lab, p_lab): # element-wise over the batch, TODO: to be confirmed
+            if g == p:
+                self.tp[g-1] += 1
+            else:
+                self.fp[p-1] += 1
+                self.fn[g-1] += 1
+
+        self._confusion_matrix += confusion_matrix(gt_lab, p_lab, labels=list(range(1, self.num_classes)))
+
+@METRICS.register()
+class RegressionMetrics(Metrics):
+    ''' Metrics class for regression type tasks '''
 
-        if gt_lab == p_lab:
-            self.tp[gt_lab-1] += 1
-        else:
-            self.fp[p_lab-1] += 1
-            self.fn[gt_lab-1] += 1
+    def __init__(self, num_classes: int = 2) -> None:
+        num_classes = num_classes + 1 # for convenience
+        super().__init__(0, num_classes)
+        self.abs_errors = [] # per-sample absolute errors, for a later MAE
+
+    def feed(self, gt: float, pred: float) -> tuple:
+        '''
+        Args:
+            gt (float): numeric ground truth value
+            pred (float): numeric predicted value
+        '''
+
+        gt = dict(labels=[gt], loc=[(0,0)])
+        preds = dict(labels=[pred], loc=[(0,0)])
 
-        self._confusion_matrix += confusion_matrix(
-            [gt_lab], [p_lab], labels=list(range(self.num_classes-1)))
\ No newline at end of file
+        super().feed(gt, preds)
+
+    def matching(self, gt: dict, pred: dict) -> None:
+        gt_lab = gt['labels'][0]
+        p_lab = pred['labels'][0]
+
+        diff = abs(gt_lab - p_lab) # absolute error (L1); note: math.abs does not exist in Python
+        self.abs_errors.append(diff)
diff --git a/animaloc/models/__init__.py b/animaloc/models/__init__.py
index 725a01b..59d753c 100644
--- a/animaloc/models/__init__.py
+++ b/animaloc/models/__init__.py
@@ -21,5 +21,6 @@
 from .herdnet import *
 from .utils import *
 from .ss_dla import *
+from .dla_backbone import *
 
 __all__ = ['MODELS', *MODELS.registry_names]
\ No newline at end of file
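A hedged sketch of turning RegressionMetrics' accumulated errors into a mean absolute error: `abs_errors` is the accumulator introduced above, while a standalone `mae()` helper is an assumption, not an existing part of the Metrics API.

    def mae(metrics: 'RegressionMetrics') -> float:
        # mean absolute error over everything fed so far (0.0 if nothing was fed)
        return sum(metrics.abs_errors) / max(len(metrics.abs_errors), 1)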
diff --git a/animaloc/models/dla_backbone.py b/animaloc/models/dla_backbone.py
new file mode 100644
index 0000000..3ff61c6
--- /dev/null
+++ b/animaloc/models/dla_backbone.py
@@ -0,0 +1,93 @@
+__copyright__ = \
+    """
+    Copyright (C) 2022 University of Liège, Gembloux Agro-Bio Tech, Forest Is Life
+    All rights reserved.
+
+    This source code is under the CC BY-NC-SA-4.0 license (https://creativecommons.org/licenses/by-nc-sa/4.0/).
+    It is to be used for academic research purposes only, no commercial use is permitted.
+
+    Please contact the author Alexandre Delplanque (alexandre.delplanque@uliege.be) for any questions.
+
+    Last modification: March 29, 2023
+    """
+__author__ = "Alexandre Delplanque"
+__license__ = "CC BY-NC-SA 4.0"
+__version__ = "0.2.0"
+
+
+import torch
+import torch.nn as nn
+
+from .register import MODELS
+
+from . import dla as dla_modules
+
+
+@MODELS.register()
+class DLAEncoder(nn.Module):
+    ''' DLA encoder architecture with an image-level classification head '''
+
+    def __init__(
+        self,
+        num_layers: int = 34,
+        num_classes: int = 2,
+        pretrained: bool = True,
+        ):
+        '''
+        Args:
+            num_layers (int, optional): number of layers of DLA. Defaults to 34.
+            num_classes (int, optional): number of output classes, background included.
+                Kept for API compatibility; the current head is a single-logit binary
+                classifier. Defaults to 2.
+            pretrained (bool, optional): set False to disable pretrained DLA encoder parameters
+                from ImageNet. Defaults to True.
+        '''
+
+        super(DLAEncoder, self).__init__()
+
+        base_name = 'dla{}'.format(num_layers)
+
+        self.num_classes = num_classes
+
+        # backbone
+        base = dla_modules.__dict__[base_name](pretrained=pretrained, return_levels=True)
+        setattr(self, 'base_0', base)
+        setattr(self, 'channels_0', base.channels)
+
+        channels = self.channels_0
+
+        # bottleneck conv
+        self.bottleneck_conv = nn.Conv2d(
+            channels[-1], channels[-1],
+            kernel_size=1, stride=1,
+            padding=0, bias=True
+            )
+        # global average over the last feature map (16x16 for 512x512 inputs)
+        self.pooling = nn.AvgPool2d(kernel_size=16, stride=1, padding=0)
+        self.cls_head = nn.Linear(channels[-1], 1) # single-logit binary head
+
+    def forward(self, input: torch.Tensor):
+
+        encode = self.base_0(input) # last level: Nx512x16x16
+        bottleneck = self.bottleneck_conv(encode[-1])
+        bottleneck = self.pooling(bottleneck) # Nx512x1x1
+        bottleneck = torch.reshape(bottleneck, (bottleneck.size()[0], -1)) # Nx512
+        cls = self.cls_head(bottleneck)
+
+        # raw logits are returned; apply a sigmoid outside the model
+        # (e.g., train with BCEWithLogitsLoss)
+        return cls
+
+    def freeze(self, layers: list) -> None:
+        ''' Freeze all layers mentioned in the input list '''
+        for layer in layers:
+            self._freeze_layer(layer)
+
+    def _freeze_layer(self, layer_name: str) -> None:
+        for param in getattr(self, layer_name).parameters():
+            param.requires_grad = False
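A quick shape walk-through for the new encoder, as a hedged sketch: it assumes 512x512 inputs, so DLA's last level is 512 channels at 16x16, which AvgPool2d(kernel_size=16) reduces to 1x1 before the linear head.

    import torch
    from animaloc.models import DLAEncoder

    model = DLAEncoder(num_layers=34, pretrained=False).eval()
    x = torch.randn(2, 3, 512, 512)       # N x C x H x W
    with torch.no_grad():
        logits = model(x)                 # N x 1 raw logits
    probs = torch.sigmoid(logits)         # probability that a tile is non-empty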
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000..3e20f0e
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,38 @@
+ARG PYTORCH="1.11.0"
+ARG CUDA="11.3"
+ARG CUDNN="8"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel AS base
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX 9.0"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+
+RUN conda update conda && conda install pip && conda clean -afy
+WORKDIR /herdnet
+COPY environment-dev.yml ./
+RUN conda env update -f environment-dev.yml -n base && conda clean -afy
+
+FROM base AS dev
+# Update the NVIDIA GPG keys; solves a weird apt problem, see
+# https://github.com/NVIDIA/nvidia-container-toolkit/issues/258#issuecomment-1903945418
+RUN \
+    rm /etc/apt/sources.list.d/cuda.list && \
+    rm /etc/apt/sources.list.d/nvidia-ml.list && \
+    apt-key del 7fa2af80 && \
+    apt-get update && apt-get install -y --no-install-recommends wget && \
+    wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/cuda-keyring_1.0-1_all.deb && \
+    dpkg -i cuda-keyring_1.0-1_all.deb && \
+    apt-get update
+
+RUN apt-get update \
+    && apt-get install -y --no-install-recommends python3-pyqt5 python3-pyqt5.qtwebengine unzip git \
+    && rm -rf /var/lib/apt/lists/*
+
+COPY requirements-dev.txt ./
+RUN pip install -r requirements-dev.txt
+COPY . ./
+# pre-commit run --all-files fails w/o this line
+RUN git init
+RUN pip install -e .
diff --git a/docker/build b/docker/build
new file mode 100755
index 0000000..f66da1f
--- /dev/null
+++ b/docker/build
@@ -0,0 +1,63 @@
+#!/bin/bash
+
+set -e
+
+if [[ -n "${HERDNET_DEBUG}" ]]; then
+    set -x
+fi
+
+source $(dirname "$0")/env
+
+function usage() {
+    echo -n \
+        "Usage: $(basename "$0")
+Build the herdnet container. Must be run from the repository root.
+
+Options:
+--no-cache    Do not use cache when building the images
+--pull        Attempt to pull upstream images before building
+"
+}
+
+while [[ $# -gt 0 ]]
+do
+    key="$1"
+    case $key in
+        --help)
+            usage
+            exit 0
+            ;;
+
+        --no-cache)
+            NO_CACHE="--no-cache"
+            shift
+            ;;
+
+        --pull)
+            PULL="--pull"
+            shift
+            ;;
+    esac
+done
+
+if [ "${BASH_SOURCE[0]}" = "${0}" ]; then
+    if [ "${1:-}" = "--help" ]; then
+        usage
+    else
+        docker build $PULL $NO_CACHE --target dev -t herdnet:local -f docker/Dockerfile .
+    fi
+
+    # Wipe out the debug container if it still exists. It is now obsolete.
+    if [ "$(docker ps -aqf name=$DOCKER_CONTAINER_NAME)" ]; then
+        # Container exists
+        if [ "$(docker ps -qf name=$DOCKER_CONTAINER_NAME)" ]; then
+            # Container is running
+            echo "Kill debug container"
+            docker kill $DOCKER_CONTAINER_NAME
+        fi
+        echo "Remove debug container"
+        docker rm -f $DOCKER_CONTAINER_NAME
+    fi
+fi
diff --git a/docker/env b/docker/env
new file mode 100644
index 0000000..49f4a96
--- /dev/null
+++ b/docker/env
@@ -0,0 +1,2 @@
+DOCKER_TAG_DEV=local
+DOCKER_CONTAINER_NAME=herdnet
\ No newline at end of file
diff --git a/docker/jupyter b/docker/jupyter
new file mode 100644
index 0000000..45642b2
--- /dev/null
+++ b/docker/jupyter
@@ -0,0 +1,30 @@
+#!/bin/bash
+
+set -e
+
+if [[ -n "${HERDNET_DEBUG}" ]]; then
+    set -x
+fi
+
+source $(dirname "$0")/env
+
+function usage() {
+    echo -n \
+        "Usage: $(basename "$0")
+Launches a Jupyter notebook in a docker container with all prerequisites installed.
+"
+}
+
+if [ "${BASH_SOURCE[0]}" = "${0}" ]; then
+    docker run --rm -it \
+        -v `pwd`:/herdnet \
+        --entrypoint jupyter \
+        -p 8888:8888 \
+        herdnet:local \
+        lab \
+        --ip=0.0.0.0 \
+        --port=8888 \
+        --no-browser \
+        --allow-root \
+        --notebook-dir=/herdnet
+fi
diff --git a/docker/run b/docker/run
new file mode 100755
index 0000000..edc7826
--- /dev/null
+++ b/docker/run
@@ -0,0 +1,26 @@
+#!/bin/bash
+
+set -e
+
+if [[ -n "${HERDNET_DEBUG}" ]]; then
+    set -x
+fi
+
+source $(dirname "$0")/env
+
+function usage() {
+    echo -n \
+        "Usage: $(basename "$0")
+Run a console in a docker container with all prerequisites installed.
+"
+}
+
+if [ "${BASH_SOURCE[0]}" = "${0}" ]; then
+    # set HERDNET_DATASETS to point at your local datasets folder
+    docker run --rm -it --gpus all \
+        -v "${HERDNET_DATASETS:-$HOME/DATASETS}":/herdnet/DATASETS \
+        --ipc=host \
+        -v `pwd`:/herdnet \
+        -p 8000:8000 \
+        --entrypoint /bin/bash \
+        herdnet:local
+fi
\ No newline at end of file
diff --git a/environment-dev.yml b/environment-dev.yml
new file mode 100644
index 0000000..27b6d83
--- /dev/null
+++ b/environment-dev.yml
@@ -0,0 +1,7 @@
+name: herdnet
+channels:
+  - pytorch
+  - defaults
+dependencies:
+  - pip
+  # pip packages are installed from requirements-dev.txt in the Dockerfile
diff --git a/requirements-dev.txt b/requirements-dev.txt
new file mode 100644
index 0000000..efb630b
--- /dev/null
+++ b/requirements-dev.txt
@@ -0,0 +1,15 @@
+black
+codespell
+flake8
+importlib-metadata
+ipywidgets
+ipython
+jupyterlab
+matplotlib
+albumentations
+pandas
+#opencv-python
+#opencv-python-headless
+hydra-core
+wandb
+gdown
diff --git a/scripts/base_herdnet_finetuning_stratified_v1.py b/scripts/base_herdnet_finetuning_stratified_v1.py
new file mode 100644
index 0000000..7c10160
--- /dev/null
+++ b/scripts/base_herdnet_finetuning_stratified_v1.py
@@ -0,0 +1,351 @@
+# -*- coding: utf-8 -*-
+"""Base_Herdnet_Finetuning_stratified_v1
+
+Automatically generated by Colaboratory.
+
+Original file is located at
+    https://colab.research.google.com/drive/1h9HchpvlPruHShl2Oo4JGA1-O6-6d_3D
+
+# DEMO - Training and testing HerdNet on nadir aerial images
+
+## Installations
+"""
+
+# Check GPU and install the dependencies (Colab-only commands, kept disabled)
+""" !nvidia-smi
+
+!pip install h5py
+!pip install typing-extensions
+!pip install wheel
+!pip install albumentations>=1.0.3
+!pip install fiftyone>=0.14.3
+!pip install hydra-core>=1.1.0
+!pip install opencv-python>=4.5.1.48
+!pip install pandas>=1.2.3
+!pip install pillow>=8.2.0
+!pip install scikit-image>=0.18.1
+!pip install scikit-learn>=1.0.2
+!pip install scipy>=1.6.2
+!pip install wandb>=0.10.33
+!pip install numpy>=1.20.0 """
+
+import sys
+
+#!wandb login
+import wandb
+import random
+
+"""## Create datasets"""
+
+# Set the seed
+from animaloc.utils.seed import set_seed
+
+set_seed(9292)
+
+#### Downloading and unzipping the files (run once, kept disabled)
+#!pip install --upgrade --no-cache-dir gdown
+
+# Download and unzip the Train zip file
+#!gdown https://drive.google.com/uc?id=1mI6Ve5v3sAj9h502g75GD1lZSYy4-FR4 -O /herdnet/DATASETS/Train_patches_stratified.zip
+#!unzip -oq /herdnet/DATASETS/Train_patches_stratified.zip -d /herdnet/DATASETS/Train_patches_stratified
+
+# Download and unzip the val zip file
+#!gdown https://drive.google.com/uc?id=1-1lGSZVk-ts0TMo0n-sbwlBGKHhgh9O9 -O /herdnet/DATASETS/val_patches_stratified.zip
+#!unzip -oq /herdnet/DATASETS/val_patches_stratified.zip -d /herdnet/DATASETS/val_patches_stratified
+
+# Download and unzip the test zip file
+#!gdown https://drive.google.com/uc?id=1-1r9sQlC-NxgcSvKKl0WPEmOpkzRV4KB -O /herdnet/DATASETS/test_patches_stratified.zip
+#!unzip -oq /herdnet/DATASETS/test_patches_stratified.zip -d /herdnet/DATASETS/test_patches_stratified
+
+# %%
+# %matplotlib inline
+# Show some sample patches and their annotations
+import matplotlib.pyplot as plt
+from animaloc.datasets import CSVDataset
+from animaloc.data.batch_utils import show_batch, collate_fn
+from torch.utils.data import DataLoader
+import torch
+import albumentations as A
+
+batch_size = 8
+NUM_WORKERS = 8
+csv_path = '/herdnet/DATASETS/Train_patches_stratified/gt.csv'
+image_path = '/herdnet/DATASETS/Train_patches_stratified'
+dataset = CSVDataset(csv_path, image_path, [A.Normalize()])
+dataloader = DataLoader(dataset, batch_size=batch_size, collate_fn=collate_fn, shuffle=True, num_workers=NUM_WORKERS)
+
+sample_batch = next(iter(dataloader))
+for i in range(len(sample_batch[1])):
+    points = sample_batch[1][i]['points'].numpy()
+    bbox = []
+    for pt in points:
+        # draw a small 4x4 box around each point annotation
+        bbox.append([pt[0]-2, pt[1]-2, pt[0]+2, pt[1]+2])
+    print(len(sample_batch[1][i]['labels']))
+    sample_batch[1][i]['annotations'] = torch.tensor(bbox)
+plt.figure(figsize=(16,2))
+show_batch(sample_batch)
+plt.savefig('/herdnet/show_patch.pdf')
+
+# %%
+# Training, validation and test datasets
+from animaloc.data.transforms import MultiTransformsWrapper, DownSample, PointsToMask, FIDT
+
+patch_size = 512
+num_classes = 2
+down_ratio = 2
+
+train_dataset = CSVDataset(
+    csv_file = '/herdnet/DATASETS/Train_patches_stratified/gt.csv',
+    root_dir = '/herdnet/DATASETS/Train_patches_stratified',
+    albu_transforms = [
+        A.VerticalFlip(p=0.5),
+        A.HorizontalFlip(p=0.5),
+        A.RandomRotate90(p=0.5),
+        A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.2),
+        A.Blur(blur_limit=15, p=0.2),
+        A.Normalize(p=1.0)
+        ],
+    end_transforms = [MultiTransformsWrapper([
+        FIDT(num_classes=num_classes, down_ratio=down_ratio),
+        PointsToMask(radius=2, num_classes=num_classes, squeeze=True, down_ratio=int(patch_size//16))
+        ])]
+    )
+
+val_dataset = CSVDataset(
+    csv_file = '/herdnet/DATASETS/val_patches_stratified/gt.csv',
+    root_dir = '/herdnet/DATASETS/val_patches_stratified',
+    albu_transforms = [A.Normalize(p=1.0)],
+    end_transforms = [DownSample(down_ratio=down_ratio, anno_type='point')]
+    )
+
+test_dataset = CSVDataset(
+    csv_file = '/herdnet/DATASETS/test_patches_stratified/gt.csv',
+    root_dir = '/herdnet/DATASETS/test_patches_stratified',
+    albu_transforms = [A.Normalize(p=1.0)],
+    end_transforms = [DownSample(down_ratio=down_ratio, anno_type='point')]
+    )
+
+# Dataloaders
+train_dataloader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=NUM_WORKERS)
+
+val_dataloader = DataLoader(dataset=val_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)
+
+test_dataloader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=NUM_WORKERS)
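+# Hedged sanity check (illustrative, disabled like the other optional blocks):
+# pull one training batch and inspect its structure; the exact target layout
+# depends on the MultiTransformsWrapper above.
+if 0:
+    batch = next(iter(train_dataloader))
+    images = batch[0]
+    print(type(batch), images.shape if torch.is_tensor(images) else type(images))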
+"""## Define HerdNet for training"""
+
+# Path to the initial .pth file; if unset, the released weights are downloaded
+import gdown
+import torch
+from pathlib import Path
+
+dir_path = Path('/herdnet/output')
+dir_path.mkdir(parents=True, exist_ok=True)
+pth_path = '/herdnet/DATASETS/20220413_herdnet_model.pth'
+if not pth_path:
+    gdown.download(
+        'https://drive.google.com/uc?export=download&id=1-WUnBC4BJMVkNvRqalF_HzA1_pRkQTI_',
+        '/herdnet/output/20220413_herdnet_model.pth'
+        )
+    pth_path = '/herdnet/output/20220413_herdnet_model.pth'
+
+from animaloc.models import HerdNet
+from animaloc.models import LossWrapper
+from animaloc.train.losses import FocalLoss
+from torch.nn import CrossEntropyLoss
+
+pretrained = False
+
+herdnet = HerdNet(pretrained=False, num_classes=num_classes, down_ratio=down_ratio).cuda()
+if not pretrained:
+    # load the released weights; strict=False tolerates head mismatches
+    pretrained_dict = torch.load(pth_path)['model_state_dict']
+    herdnet_dict = herdnet.state_dict()
+    pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in herdnet_dict}
+    herdnet.load_state_dict(pretrained_dict, strict=False)
+
+losses = [
+    {'loss': FocalLoss(reduction='mean'), 'idx': 0, 'idy': 0, 'lambda': 1.0, 'name': 'focal_loss'},
+    {'loss': CrossEntropyLoss(reduction='mean'), 'idx': 1, 'idy': 1, 'lambda': 1.0, 'name': 'ce_loss'}
+    ]
+
+herdnet = LossWrapper(herdnet, losses=losses)
+
+############# Get model layers ###########################
+def get_parameter_names(model): # map parameter name -> index
+    param_dict = dict()
+    for l, (name, param) in enumerate(model.named_parameters()):
+        param_dict[name] = l
+    return param_dict
+
+result = get_parameter_names(herdnet)
+print(result)
+
+"""# Freeze the layers (different options)
+1. half of a layer and other layers
+"""
+
+# Freeze half of a specified layer, plus a list of whole layers.
+# NOTE: requires_grad is a per-tensor flag, so "half of a layer" can only be
+# approximated by freezing the first half of that layer's parameter tensors.
+def freeze_parts(model, get_parameter_names, layers_to_freeze, freeze_layer_half=None, lr=0.0001, unfreeze=False):
+    params_to_update = []
+
+    half_names = [name for name, _ in model.named_parameters()
+                  if freeze_layer_half is not None and freeze_layer_half in name]
+    half_cut = len(half_names) // 2
+
+    for name, param in model.named_parameters():
+        res = any(ele in name for ele in layers_to_freeze)
+        param.requires_grad = unfreeze if res else not unfreeze
+
+        # freeze the first half of the target layer's parameter tensors
+        if name in half_names:
+            param.requires_grad = unfreeze if half_names.index(name) < half_cut else not unfreeze
+
+        if param.requires_grad:
+            params_to_update.append({
+                "params": param,
+                "lr": lr,
+                })
+            print(f"Trainable parameter: {name}")
+        else:
+            print(f"Frozen parameter: {name}")
+
+    return params_to_update
+
+# freeze half of one layer + the other listed layers
+params_to_update = freeze_parts(herdnet.model, get_parameter_names, layers_to_freeze=['base_layer','level0','level1','level2','level3'], freeze_layer_half='level4', lr=0.0001, unfreeze=False)
+
+"""# Freeze a complete layer"""
+
+# Freeze whole layers only (alternative to the function above, same name on purpose:
+# in the notebook you run one cell or the other)
+def freeze_parts(model, get_parameter_names, layers_to_freeze, lr, unfreeze=False):
+    params_to_update = []
+
+    for l, (name, param) in enumerate(model.named_parameters()):
+        res = any(ele in name for ele in layers_to_freeze)
+        param.requires_grad = unfreeze if res else not unfreeze
+
+        if param.requires_grad:
+            params_to_update.append({
+                "params": param,
+                "lr": lr,
+                })
+            print(f"Trainable parameter: {name}")
+        else:
+            print(f"Frozen parameter: {name}")
+
+    return params_to_update
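+# Hedged usage sketch for the whole-layer variant (not called in this run):
+# freeze the DLA stem and early levels, keep the rest trainable.
+if 0:
+    params_to_update = freeze_parts(herdnet.model, get_parameter_names,
+                                    layers_to_freeze=['base_layer', 'level0', 'level1'],
+                                    lr=1e-4, unfreeze=False)
+    n_trainable = sum(p.numel() for p in herdnet.parameters() if p.requires_grad)
+    print(f'Trainable parameters: {n_trainable}')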
+"""## Create the Trainer"""
+
+from torch.optim import Adam
+from animaloc.train import Trainer
+from animaloc.eval import PointsMetrics, HerdNetStitcher, HerdNetEvaluator
+from animaloc.utils.useful_funcs import mkdir
+
+work_dir = '/herdnet/output'
+mkdir(work_dir)
+
+lr = 1e-4
+weight_decay = 1e-3
+epochs = 100
+
+optimizer = Adam(params_to_update, lr=lr, weight_decay=weight_decay)
+
+metrics = PointsMetrics(radius=20, num_classes=num_classes)
+
+stitcher = HerdNetStitcher(
+    model=herdnet,
+    size=(patch_size,patch_size),
+    overlap=160,
+    down_ratio=down_ratio,
+    reduction='mean'
+    )
+
+evaluator = HerdNetEvaluator(
+    model=herdnet,
+    dataloader=val_dataloader,
+    metrics=metrics,
+    stitcher=None, # stitcher disabled here; enable it for full-size images
+    work_dir=work_dir,
+    header='validation'
+    )
+
+trainer = Trainer(
+    model=herdnet,
+    train_dataloader=train_dataloader,
+    optimizer=optimizer,
+    num_epochs=epochs,
+    evaluator=evaluator,
+    # val_dataloader=val_dataloader, # for loss-based validation instead
+    work_dir=work_dir
+    )
+
+"""## Start training"""
+
+import wandb
+if wandb.run is not None:
+    wandb.finish()
+wandb.init(project="herdnet-finetuning")
+
+trainer.start(warmup_iters=100, checkpoints='best', select='max', validate_on='f1_score', wandb_flag=True)
+
+"""## Test the model"""
+
+# Save the fine-tuned parameters (note: herdnet is wrapped in LossWrapper here)
+pth_path = '/herdnet/output/fine_tuned_base.pth'
+torch.save(herdnet.state_dict(), pth_path)
+
+if 0: # reload the fine-tuned weights into a fresh model
+    herdnet = HerdNet()
+    herdnet.load_state_dict(torch.load(pth_path))
+if 0: # or load trained parameters from a checkpoint dict
+    from animaloc.models import load_model
+    checkpoint = torch.load(pth_path, map_location='cpu')
+    herdnet.load_state_dict(checkpoint['model_state_dict'])
+    herdnet = load_model(herdnet, pth_path=pth_path)
+
+# Create output folder
+test_dir = '/herdnet/test_output'
+mkdir(test_dir)
+
+# Create an Evaluator
+test_evaluator = HerdNetEvaluator(
+    model=herdnet,
+    dataloader=test_dataloader,
+    metrics=metrics,
+    stitcher=stitcher,
+    work_dir=test_dir,
+    header='test'
+    )
+
+# Start testing
+test_f1_score = test_evaluator.evaluate(returns='f1_score')
+
+# Print global F1 score (%)
+print(f"F1 score = {test_f1_score * 100:0.0f}%")
+
+# Get the detections
+detections = test_evaluator.results
+detections
\ No newline at end of file
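A closing note on the detections: if `test_evaluator.results` is a pandas DataFrame (its exact type is not pinned down by this diff, so treat this as a hedged sketch), the detections can be exported for later inspection:

    # assumes `detections` is a pandas DataFrame; adjust if the structure differs
    detections.to_csv('/herdnet/test_output/detections.csv', index=False)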