Fix TBPTT example (#20528)
* Fix TBPTT example

* Make example self-contained

* Update imports

* Add test
lantiga authored Jan 6, 2025
1 parent ee7fa43 commit 76f0c54
Showing 3 changed files with 123 additions and 22 deletions.
85 changes: 64 additions & 21 deletions docs/source-pytorch/common/tbptt.rst
@@ -12,48 +12,91 @@ hidden states should be kept in-between each time-dimension split.

.. code-block:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim
    from torch.utils.data import Dataset, DataLoader

    import lightning as L


    class AverageDataset(Dataset):
        def __init__(self, dataset_len=300, sequence_len=100):
            self.dataset_len = dataset_len
            self.sequence_len = sequence_len
            self.input_seq = torch.randn(dataset_len, sequence_len, 10)
            top, bottom = self.input_seq.chunk(2, -1)
            self.output_seq = top + bottom.roll(shifts=1, dims=-1)

        def __len__(self):
            return self.dataset_len

        def __getitem__(self, item):
            return self.input_seq[item], self.output_seq[item]


    class LitModel(L.LightningModule):
        def __init__(self):
            super().__init__()

            self.batch_size = 10
            self.in_features = 10
            self.out_features = 5
            self.hidden_dim = 20

            # 1. Switch to manual optimization
            self.automatic_optimization = False
            self.truncated_bptt_steps = 10

            self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
            self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

        def forward(self, x, hs):
            seq, hs = self.rnn(x, hs)
            return self.linear_out(seq), hs

        # 2. Remove the `hiddens` argument
        def training_step(self, batch, batch_idx):
            # 3. Split the batch in chunks along the time dimension
            x, y = batch
            split_x, split_y = [
                x.tensor_split(self.truncated_bptt_steps, dim=1),
                y.tensor_split(self.truncated_bptt_steps, dim=1),
            ]

            hiddens = None
            optimizer = self.optimizers()
            losses = []

            # 4. Perform the optimization in a loop
            for x, y in zip(split_x, split_y):
                y_pred, hiddens = self(x, hiddens)
                loss = F.mse_loss(y_pred, y)

                optimizer.zero_grad()
                self.manual_backward(loss)
                optimizer.step()

                # 5. "Truncate"
                hiddens = [h.detach() for h in hiddens]
                losses.append(loss.detach())

            avg_loss = sum(losses) / len(losses)
            self.log("train_loss", avg_loss, prog_bar=True)

            # 6. Remove the return of `hiddens`
            # Returning loss in manual optimization is not needed
            return None

        def configure_optimizers(self):
            return optim.Adam(self.parameters(), lr=0.001)

        def train_dataloader(self):
            return DataLoader(AverageDataset(), batch_size=self.batch_size)


    if __name__ == "__main__":
        model = LitModel()
        trainer = L.Trainer(max_epochs=5)
        trainer.fit(model)
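To make step 3 concrete, here is a minimal sketch of what ``tensor_split`` does along the time dimension, assuming the shapes produced by ``AverageDataset`` above (batch 10, sequence length 100, 10 features):

.. code-block:: python

    import torch

    # One batch as produced by AverageDataset: (batch, time, features)
    x = torch.randn(10, 100, 10)

    # truncated_bptt_steps = 10 -> 10 chunks along dim=1 (chunks are uneven
    # if the sequence length is not divisible by the number of splits)
    chunks = x.tensor_split(10, dim=1)
    print(len(chunks))      # 10
    print(chunks[0].shape)  # torch.Size([10, 10, 10])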
51 changes: 51 additions & 0 deletions tests/tests_pytorch/helpers/advanced_models.py
@@ -219,3 +219,54 @@ def configure_optimizers(self):

def train_dataloader(self):
return DataLoader(MNIST(root=_PATH_DATASETS, train=True, download=True), batch_size=128, num_workers=1)


class TBPTTModule(LightningModule):
def __init__(self):
super().__init__()

self.batch_size = 10
self.in_features = 10
self.out_features = 5
self.hidden_dim = 20

self.automatic_optimization = False
self.truncated_bptt_steps = 10

self.rnn = nn.LSTM(self.in_features, self.hidden_dim, batch_first=True)
self.linear_out = nn.Linear(in_features=self.hidden_dim, out_features=self.out_features)

def forward(self, x, hs):
seq, hs = self.rnn(x, hs)
return self.linear_out(seq), hs

def training_step(self, batch, batch_idx):
x, y = batch
split_x, split_y = [
x.tensor_split(self.truncated_bptt_steps, dim=1),
y.tensor_split(self.truncated_bptt_steps, dim=1),
]

hiddens = None
optimizer = self.optimizers()
losses = []

for x, y in zip(split_x, split_y):
y_pred, hiddens = self(x, hiddens)
loss = F.mse_loss(y_pred, y)

optimizer.zero_grad()
self.manual_backward(loss)
optimizer.step()

# "Truncate"
hiddens = [h.detach() for h in hiddens]
losses.append(loss.detach())

return

def configure_optimizers(self):
return torch.optim.Adam(self.parameters(), lr=0.001)

def train_dataloader(self):
return DataLoader(AverageDataset(), batch_size=self.batch_size)
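For context on the ``hiddens`` handling above: ``nn.LSTM`` returns its state as a pair ``(h_n, c_n)``, so the list comprehension detaches both tensors at each split boundary. A minimal sketch, assuming the module's dimensions (input 10, hidden 20, batch 10):

.. code-block:: python

    import torch
    import torch.nn as nn

    rnn = nn.LSTM(10, 20, batch_first=True)
    x = torch.randn(10, 10, 10)  # one TBPTT split: (batch, steps, features)

    seq, (h_n, c_n) = rnn(x)  # the state comes back as a (h_n, c_n) pair
    hiddens = [h.detach() for h in (h_n, c_n)]
    print([h.shape for h in hiddens])             # [torch.Size([1, 10, 20]), torch.Size([1, 10, 20])]
    print(any(h.requires_grad for h in hiddens))  # False: gradient flow stops at the split boundary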
9 changes: 8 additions & 1 deletion tests/tests_pytorch/helpers/test_models.py
@@ -17,7 +17,7 @@
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel

from tests_pytorch.helpers.advanced_models import BasicGAN, ParityModuleMNIST, ParityModuleRNN, TBPTTModule
from tests_pytorch.helpers.datamodules import ClassifDataModule, RegressDataModule
from tests_pytorch.helpers.runif import RunIf
from tests_pytorch.helpers.simple_models import ClassificationModel, RegressionModel
@@ -49,3 +49,10 @@ def test_models(tmp_path, data_class, model_class):
model.to_torchscript()
if data_class:
model.to_onnx(os.path.join(tmp_path, "my-model.onnx"), input_sample=dm.sample)


def test_tbptt(tmp_path):
model = TBPTTModule()

trainer = Trainer(default_root_dir=tmp_path, max_epochs=1)
trainer.fit(model)
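The new test gives the fixed example end-to-end coverage. Assuming the repository's standard pytest layout, it can be run on its own with ``pytest tests/tests_pytorch/helpers/test_models.py::test_tbptt``.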
