-
Notifications
You must be signed in to change notification settings - Fork 3.5k
Stepwise LR scheduler #20211
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Stepwise LR scheduler #20211
Changes from all commits
5dba6f9
3d8b2bf
7a55c5c
f4b01e5
4424d70
607363e
3099586
935a9c1
c70ef61
ebfedf6
3285d7a
4b7b719
5be642f
fc01630
7f748cf
06f0a0a
3c48c9e
63cd1f0
48a7c8e
64ed819
29af194
09bc52b
e96c474
2391336
a273722
eb98dce
e45a8f9
7adad14
7bb9697
4c77cb3
15052fb
e30a504
9dbbc8d
27047bf
ac5afed
c61fd46
e17fd6f
337c1c2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,13 +12,14 @@ | |
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from unittest import mock | ||
from unittest.mock import call | ||
from unittest.mock import call, patch | ||
|
||
import pytest | ||
import torch | ||
from torch import optim | ||
from torch.utils.data import DataLoader, TensorDataset | ||
|
||
from lightning.pytorch import Trainer | ||
from lightning.pytorch import LightningModule, Trainer | ||
from lightning.pytorch.callbacks import ModelCheckpoint | ||
from lightning.pytorch.core.optimizer import ( | ||
_configure_optimizers, | ||
|
@@ -657,3 +658,66 @@ def lr_scheduler_step(*_): ... | |
else: | ||
with pytest.raises(MisconfigurationException, match="CustomScheduler` doesn't follow"): | ||
_init_optimizers_and_lr_schedulers(model) | ||
|
||
|
||
@patch("torch.optim.lr_scheduler.StepLR.step") | ||
def test_lr_scheduler_step_across_epoch_boundaries(mocked_sched, tmp_path): | ||
class StepAcrossEpochsModel(LightningModule): | ||
def __init__(self): | ||
super().__init__() | ||
self.layer = torch.nn.Linear(32, 2) | ||
|
||
def forward(self, x): | ||
return self.layer(x) | ||
|
||
def training_step(self, batch, batch_idx): | ||
# Add print statement to track batch index and global step | ||
if hasattr(self, "trainer"): | ||
print(f"Batch idx: {batch_idx}, Global step: {self.trainer.global_step}") | ||
return {"loss": torch.tensor(0.1, requires_grad=True)} | ||
|
||
def train_dataloader(self): | ||
x = torch.randn(21, 32) | ||
y = torch.randn(21, 2) | ||
return DataLoader(TensorDataset(x, y), batch_size=3) | ||
|
||
def configure_optimizers(self): | ||
optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) | ||
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) | ||
return { | ||
"optimizer": optimizer, | ||
"lr_scheduler": { | ||
"scheduler": scheduler, | ||
"interval": "step", | ||
"frequency": 5, # Scheduler steps every 5 iterations | ||
}, | ||
} | ||
|
||
model = StepAcrossEpochsModel() | ||
|
||
# Trainer configuration for cross-epoch testing | ||
trainer = Trainer( | ||
default_root_dir=tmp_path, | ||
limit_train_batches=7, # More than `frequency` iterations per epoch | ||
max_epochs=3, # Test across multiple epochs | ||
) | ||
|
||
# Fit the model | ||
trainer.fit(model) | ||
|
||
# Debug print statements | ||
print(f"Mocked scheduler step calls: {mocked_sched.call_count}") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please remove the debug statements, I'd just convert them to asserts that compare the values with expected ones. |
||
print(f"Mocked scheduler call history: {mocked_sched.call_args_list}") | ||
|
||
# Calculate the total number of steps (iterations) and expected scheduler calls | ||
total_steps = 7 * 3 # Total iterations (7 batches per epoch * 3 epochs) | ||
expected_steps = (total_steps - 1) // 5 # Scheduler steps every 5 iterations | ||
|
||
print(f"Total steps: {total_steps}") | ||
print(f"Expected steps: {expected_steps}") | ||
|
||
# Assert that the scheduler was called the expected number of times | ||
# Allow for a small difference due to environment or rounding discrepancies | ||
assert abs(mocked_sched.call_count - expected_steps) <= 1, ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not sure why there should be rounding discrepancies. Shouldn't this be fully deterministic? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually the test was passing in my local environment but not in the CI / CD pipeline for some reason. I forgot to change it later. Let me correct it asap. |
||
f"Scheduler was called {mocked_sched.call_count} times, but expected {expected_steps} calls." | ||
) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Print statements in tests are not super helpful, just use
assert
s so the test will break if we don't get the expected value here.