This repository was archived by the owner on Dec 20, 2024. It is now read-only.

[Feature] Model Freezing ❄️ #189

Status: Open — wants to merge 59 commits into develop.

Changes from 1 commit (of 59).

Commits (all by icedoom888):
7e238e4  Introduced resume flag and checkpoint loading for transfer learning, …  (Oct 9, 2024)
08671d7  Added len of dataset computed dynamically  (Oct 10, 2024)
e2bd868  debugging validation  (Oct 22, 2024)
557a9f0  Merge branch 'develop' into feature/transfer-learning  (Oct 22, 2024)
544dddc  Small changes  (Oct 24, 2024)
a85619d  Removed prints  (Oct 25, 2024)
b87f795  Merge branch 'develop' into feature/transfer-learning  (Oct 25, 2024)
c8ce0b0  Not working  (Nov 18, 2024)
135eac5  small changes  (Nov 18, 2024)
3cebf18  Rebased on develop  (Nov 26, 2024)
db2a14f  Imputer changes  (Nov 26, 2024)
57f9026  Added sanitization of checkpoint, effective batch size, git pre-commit  (Nov 26, 2024)
039c16f  gpc  (Nov 26, 2024)
463c6a9  gpc  (Nov 26, 2024)
062f552  Merge branch 'develop' into feature/transfer-learning  (Nov 26, 2024)
2f4dd65  rebased on develop  (Nov 26, 2024)
c6d7519  New implementation: do not store modified checkpoint, load it directl…  (Nov 27, 2024)
bca0355  Added logging  (Nov 28, 2024)
aa6f207  Rebased on develop  (Nov 28, 2024)
7894cc0  Transfer learning working: implemented checkpoint cleaning with large…  (Nov 29, 2024)
eff4539  Reverted some changes concerning imputer issues  (Dec 3, 2024)
c1f854f  Reverted some changes concerning imputer issues  (Dec 3, 2024)
338387d  Cleaned code for final review  (Dec 3, 2024)
f739bf4  Changed changelog and assigned TODO correctly  (Dec 3, 2024)
7fd9a92  Changed changelog and assigned TODO correctly  (Dec 3, 2024)
315d59d  Merge branch 'develop' into feature/transfer-learning  (Dec 3, 2024)
1ac34d8  Addressed review: copy checkpoint before removing metadata file  (Dec 3, 2024)
b7697a1  Merge remote-tracking branch 'refs/remotes/origin/feature/transfer-le…  (Dec 3, 2024)
22ddeab  Merge branch 'develop' into feature/transfer-learning  (Dec 3, 2024)
0d4fa51  gpc passed  (Dec 3, 2024)
3265892  Removed logger in debugging mode  (Dec 4, 2024)
c325a9e  removed dataset length due to checkpointing issues  (Dec 5, 2024)
4709d46  Reintroduced correct config on graphtransformer  (Dec 5, 2024)
b0023f9  gpc passed  (Dec 5, 2024)
17d02f7  Merge branch 'develop' into feature/transfer-learning  (Dec 5, 2024)
6a8ac97  Removed patch for issue #57, code expects patched checkpoint already  (Dec 5, 2024)
4cd24bd  Merge branch 'develop' into feature/transfer-learning  (Dec 5, 2024)
355cca1  Removed new path name for patched checkpoint (ignoring fully issue #5…  (Dec 5, 2024)
b875ea0  Adapted changelog  (Dec 5, 2024)
b9b611b  Added Freezing functionality  (Dec 5, 2024)
0f0dff0  Added Freezing functionality  (Dec 5, 2024)
03c4adb  Tested ✅ waiting for transfer learning merge to happen  (Dec 6, 2024)
7d51c75  Switched logging to info from debug  (Dec 6, 2024)
7063407  Merge branch 'feature/transfer-learning' into feature/model_freezing  (Dec 6, 2024)
37f6090  Rebased on transfer learning develop  (Dec 6, 2024)
8c7d54c  GPC passed  (Dec 6, 2024)
4bce6f1  Changelog updated  (Dec 6, 2024)
bd32096  Completed merge and code check  (Dec 6, 2024)
da5fffb  Rebased on latest changes  (Dec 11, 2024)
6aac548  gpc  (Dec 11, 2024)
a7ab588  Merge branch 'develop' into feature/model_freezing  (Dec 17, 2024)
8478689  Changed docstring and pytorch lightning freeze  (Dec 17, 2024)
2eb2140  Addressed review  (Dec 17, 2024)
742a7a8  Changes for review  (Dec 17, 2024)
0b8a407  Refactor CHANGELOG  (Dec 17, 2024)
498a792  Merge branch 'develop' into feature/model_freezing  (Dec 18, 2024)
8797fb3  Rebased on develop  (Dec 18, 2024)
7705a7e  Added documentation  (Dec 18, 2024)
463bec4  Added documentation  (Dec 18, 2024)
Commit e2bd86804aec2501e627240613904cd791bb2464 — "debugging validation" (icedoom888, committed Oct 22, 2024)
src/anemoi/training/data/datamodule.py — 5 changes: 3 additions & 2 deletions

@@ -162,10 +162,11 @@ def _get_dataset(
         rollout: int = 1,
         label: str = "generic",
     ) -> NativeGridDataset:
+
         r = max(rollout, self.rollout)

-        # Compute effective batch size
-        effective_bs = self.config.dataloader.batch_size[label] *\
+        # Compute effective batch size
+        effective_bs = self.config.dataloader.batch_size['training'] *\
             self.config.hardware.num_gpus_per_node *\
             self.config.hardware.num_nodes //\
             self.config.hardware.num_gpus_per_model
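
The effective batch size computed above is the per-replica batch size scaled by the number of data-parallel model replicas: each model instance spans `num_gpus_per_model` GPUs, so the replica count is `num_gpus_per_node * num_nodes // num_gpus_per_model`. A worked example, using the config field names from the diff but entirely hypothetical values:

```python
# Hypothetical values, for illustration only.
batch_size = 2          # config.dataloader.batch_size['training']
num_gpus_per_node = 4
num_nodes = 2
num_gpus_per_model = 2  # each model instance is sharded across 2 GPUs

# 8 GPUs total / 2 GPUs per model = 4 data-parallel replicas, each with batch 2.
effective_bs = batch_size * num_gpus_per_node * num_nodes // num_gpus_per_model
print(effective_bs)  # 8
```

Note the change from `batch_size[label]` to `batch_size['training']`: the computation now always uses the training batch size, regardless of the dataset label passed in.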
src/anemoi/training/train/forecaster.py — 9 changes: 8 additions & 1 deletion

@@ -75,6 +75,8 @@ def __init__(
             config=DotDict(map_config_to_primitives(OmegaConf.to_container(config, resolve=True))),
         )

+        self.model = torch.compile(self.model)
+
         self.data_indices = data_indices

         self.save_hyperparameters()
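
A side note on the `torch.compile` call added above: in PyTorch 2.x it returns an optimized wrapper module rather than compiling in place, which is why the result is assigned back to `self.model`. One practical consequence, possibly related to the checkpoint-cleaning commits in this PR, is that the wrapper's `state_dict` keys gain a `_orig_mod.` prefix. A minimal sketch:

```python
import torch

model = torch.nn.Linear(8, 8)
compiled = torch.compile(model)        # wrapper; `model` itself is left unchanged
out = compiled(torch.randn(2, 8))      # first call triggers graph capture and compilation
print(list(compiled.state_dict())[0])  # '_orig_mod.weight'
```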
@@ -321,8 +323,11 @@ def on_train_epoch_end(self) -> None:
         self.rollout = min(self.rollout, self.rollout_max)

     def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
+        print('I am doing validation!!!')
         with torch.no_grad():
             val_loss, metrics, y_preds = self._step(batch, batch_idx, validation_mode=True)
+        print('Done step..')
+        print('Logging..')
         self.log(
             "val_wmse",
             val_loss,
@@ -333,7 +338,8 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
             batch_size=batch.shape[0],
             sync_dist=True,
         )
-        for mname, mvalue in metrics.items():
+        for i, (mname, mvalue) in enumerate(metrics.items()):
+            print(i)
             self.log(
                 "val_" + mname,
                 mvalue,
@@ -344,6 +350,7 @@ def validation_step(self, batch: torch.Tensor, batch_idx: int) -> None:
                 batch_size=batch.shape[0],
                 sync_dist=True,
             )
+        print('Done')
         return val_loss, y_preds

     def configure_optimizers(self) -> tuple[list[torch.optim.Optimizer], list[dict]]:
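
Finally, since the branch's goal is freezing, one detail that usually pairs with it (not shown in this commit) is building the optimizer over trainable parameters only, so frozen weights get no updates and no optimizer state. A minimal sketch, with an illustrative function name rather than the PR's actual `configure_optimizers` body:

```python
import torch


def build_optimizer(model: torch.nn.Module, lr: float = 1e-3) -> torch.optim.AdamW:
    """Create an optimizer that skips frozen (requires_grad=False) parameters."""
    trainable = (p for p in model.parameters() if p.requires_grad)
    return torch.optim.AdamW(trainable, lr=lr)
```

(Passing frozen parameters to the optimizer is usually harmless, since their gradients stay `None` and are skipped, but filtering them out keeps optimizer state and checkpoint size down.)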