Skip to content

Commit 93e6b94

Browse files
refactor: Remove gradient logging from the default callback configs, add automatic batch size finder callback to the `all` callback config
1 parent abab4e4 commit 93e6b94

File tree

3 files changed

+13
-8
lines changed

3 files changed

+13
-8
lines changed

quadra/configs/callbacks/all.yaml

+13
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,16 @@ progress_bar:
3030
lightning_trainer_setup:
3131
_target_: quadra.callbacks.lightning.LightningTrainerBaseSetup
3232
log_every_n_steps: 1
33+
34+
batch_size_finder:
35+
_target_: quadra.callbacks.lightning.BatchSizeFinder
36+
mode: power
37+
steps_per_trial: 3
38+
init_val: 2
39+
max_trials: 5 # Max 64
40+
batch_arg_name: batch_size
41+
disable: false
42+
find_train_batch_size: true
43+
find_validation_batch_size: false
44+
find_test_batch_size: false
45+
find_predict_batch_size: false

quadra/configs/callbacks/default.yaml

-4
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,6 @@ model_checkpoint:
99
filename: "epoch_{epoch:03d}"
1010
auto_insert_metric_name: False
1111

12-
log_gradients:
13-
_target_: quadra.callbacks.mlflow.LogGradients
14-
norm: 2
15-
disable: True
1612
lr_monitor:
1713
_target_: pytorch_lightning.callbacks.LearningRateMonitor
1814
logging_interval: "epoch"

quadra/configs/callbacks/default_anomalib.yaml

-4
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,6 @@ upload_ckpts_as_artifact:
4444
upload_best_only: true
4545
delete_after_upload: true
4646
upload: false
47-
log_gradients:
48-
_target_: quadra.callbacks.mlflow.LogGradients
49-
norm: 2
50-
disable: true
5147
lr_monitor:
5248
_target_: pytorch_lightning.callbacks.LearningRateMonitor
5349
logging_interval: "epoch"

0 commit comments

Comments
 (0)