Skip to content

Commit

Permalink
hotfix: Batch size finder calling wrong super class functions (#85)
Browse files Browse the repository at this point in the history
* fix: Fix batch size finder callback not calling proper super functions

* fix: Fix hardcoded value

* build: Bump version 1.3.6 -> 1.3.7

* docs: Update changelog

Approved By: @rcmalli
  • Loading branch information
lorenzomammana authored Dec 5, 2023
1 parent 4a10a8c commit 8147e2d
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 10 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,13 @@
# Changelog
All notable changes to this project will be documented in this file.

### [1.3.7]

#### Fixed

- Fix BatchSizeFinder calling wrong super functions
- Fix ModelManager get_latest_version calling an hardcoded model

### [1.3.6]

#### Fixed
Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "quadra"
version = "1.3.6"
version = "1.3.7"
description = "Deep Learning experiment orchestration library"
authors = [
{ name = "Alessandro Polidori", email = "[email protected]" },
Expand Down Expand Up @@ -121,7 +121,7 @@ repository = "https://github.com/orobix/quadra"

# Adapted from https://realpython.com/pypi-publish-python-package/#version-your-package
[tool.bumpver]
current_version = "1.3.6"
current_version = "1.3.7"
version_pattern = "MAJOR.MINOR.PATCH"
commit_message = "build: Bump version {old_version} -> {new_version}"
commit = true
Expand Down
2 changes: 1 addition & 1 deletion quadra/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.3.6"
__version__ = "1.3.7"


def get_version():
Expand Down
20 changes: 14 additions & 6 deletions quadra/callbacks/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,19 @@ def __init__(
self.find_test_batch_size = find_test_batch_size
self.find_predict_batch_size = find_predict_batch_size

def on_train_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not self.find_train_batch_size:
def on_fit_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not self.find_train_batch_size or trainer.state.stage is None:
# If called during validation skip it as it will be triggered during on_validation_start
return None

if trainer.state.stage.value != "train":
return None

if not isinstance(pl_module.model, nn.Module):
raise ValueError("The model must be a nn.Module")
pl_module.model.train()
return super().on_train_epoch_start(trainer, pl_module)

return super().on_fit_start(trainer, pl_module)

def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not self.find_validation_batch_size:
Expand All @@ -92,7 +97,8 @@ def on_validation_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule
if not isinstance(pl_module.model, nn.Module):
raise ValueError("The model must be a nn.Module")
pl_module.model.eval()
return super().on_validation_epoch_start(trainer, pl_module)

return super().on_validation_start(trainer, pl_module)

def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not self.find_test_batch_size:
Expand All @@ -101,7 +107,8 @@ def on_test_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> N
if not isinstance(pl_module.model, nn.Module):
raise ValueError("The model must be a nn.Module")
pl_module.model.eval()
return super().on_test_epoch_start(trainer, pl_module)

return super().on_test_start(trainer, pl_module)

def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -> None:
if not self.find_predict_batch_size:
Expand All @@ -110,4 +117,5 @@ def on_predict_start(self, trainer: pl.Trainer, pl_module: pl.LightningModule) -
if not isinstance(pl_module.model, nn.Module):
raise ValueError("The model must be a nn.Module")
pl_module.model.eval()
return super().on_predict_epoch_start(trainer, pl_module)

return super().on_predict_start(trainer, pl_module)
2 changes: 1 addition & 1 deletion quadra/utils/model_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def get_latest_version(self, model_name: str) -> ModelVersion:
Returns:
The model version
"""
latest_version = max(int(x.version) for x in self.client.get_latest_versions("manager_model"))
latest_version = max(int(x.version) for x in self.client.get_latest_versions(model_name))
model_version = self.client.get_model_version(model_name, latest_version)

return model_version
Expand Down

0 comments on commit 8147e2d

Please sign in to comment.