Skip to content

Commit 4e8335f

Browse files
authored
example: sine function model prediction with litdata & pytorch-lightning (#517)
* example: sine function model prediction with litdata & pytorch-lightning * update * update * update * update * clear jupyter notebook output
1 parent ede445d commit 4e8335f

File tree

6 files changed

+300
-0
lines changed

6 files changed

+300
-0
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,3 +116,6 @@ lightning_logs
116116

117117
# status.json file
118118
status.json
119+
120+
# use the name below for the optimized dataset directory in the examples
121+
example_optimize_dataset

.pre-commit-config.yaml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,10 @@ repos:
4444
hooks:
4545
- id: codespell
4646
additional_dependencies: [tomli]
47+
exclude: >
48+
(?x)^(
49+
.*\.ipynb
50+
)$
4751
#args: ["--write-changes"] # uncomment if you want to get automatic fixing
4852

4953
- repo: https://github.com/astral-sh/ruff-pre-commit
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
3+
import litdata as ld
4+
5+
6+
def sine_function(x: float):
    """Build one dataset sample pairing an input value with its sine.

    Args:
        x: The point at which to evaluate the sine. The script passes in the
            floating-point values produced by ``np.linspace``.

    Returns:
        A dict with the keys ``"x"`` (the input, unchanged) and ``"sine"``
        (the result of ``np.sin(x)``).
    """
    # You can use any key:value pairs. Note that their types must not change between samples, and Python lists must
    # always contain the same number of elements with the same types.
    return {"x": x, "sine": np.sin(x)}
12+
13+
14+
if __name__ == "__main__":
    # 1000 evenly spaced sample points from -5 to 5; each one is handed to `sine_function`.
    sample_points = list(np.linspace(-5, 5, 1000))

    optimize_options = {
        "output_dir": "example_optimize_dataset",  # optimized data is stored here
        "num_workers": 4,  # The number of workers on the same machine
        "chunk_size": 50,  # number of items in each chunk (1000/50 = 20 chunks should be made)
        "mode": "overwrite",  # if optimized dataset already exists in dir, overwrite it.
    }

    # The optimize function writes data in an optimized format.
    ld.optimize(fn=sine_function, inputs=sample_points, **optimize_options)
Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
# ruff: noqa: RET504
2+
import lightning as L
3+
import torch
4+
import torch.nn.functional as F
5+
from torch import nn
6+
7+
import litdata as ld
8+
9+
10+
class SineModule(L.LightningModule):
    """Small fully connected network that regresses ``sin(x)`` from a scalar ``x``.

    Every step method reads the keys ``"x"`` (inputs) and ``"sine"`` (targets)
    from the batch and scores the prediction with mean squared error.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(1, 32)
        self.fc2 = nn.Linear(32, 32)
        self.fc3 = nn.Linear(32, 8)
        self.fc4 = nn.Linear(8, 1)

    def forward(self, x):
        """Run the four linear layers; ``x`` must have a trailing dimension of size 1."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        return F.tanh(self.fc4(x))  # for output to be in -1 to 1

    def _compute_loss(self, batch):
        """Return the MSE between the model's prediction and ``batch["sine"]``."""
        x, y = batch["x"], batch["sine"]
        x = x.view(x.size(0), -1)  # one row per sample
        preds = self.forward(x)

        # Drop only the trailing feature dimension. A bare ``squeeze()`` would also
        # remove the batch dimension when the batch holds a single sample, leaving a
        # 0-d prediction to be compared against a 1-element target.
        return F.mse_loss(preds.squeeze(-1), y)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        return self._compute_loss(batch)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        self.log("test_loss", self._compute_loss(batch))

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        self.log("val_loss", self._compute_loss(batch))

    def configure_optimizers(self):
        """Optimize all parameters with Adam at a learning rate of 1e-3."""
        return torch.optim.Adam(self.parameters(), lr=1e-3)
55+
56+
57+
class SineDataModule(L.LightningDataModule):
    """Data module that streams the optimized sine dataset and serves train/val/test loaders.

    Args:
        data_dir: Directory passed to ``ld.StreamingDataset``.
        batch_size: Batch size used by every dataloader.
        num_workers: Worker processes used by every dataloader. Defaults to 7,
            the value previously hard-coded in each loader.
    """

    def __init__(self, data_dir: str, batch_size: int = 4, num_workers: int = 7):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

    def setup(self, stage: str):
        dataset = ld.StreamingDataset(self.data_dir)
        # NOTE(review): the fractions sum to 0.9, so presumably about 10% of the
        # samples end up in none of the splits; confirm this is intended.
        self.train_dataset, self.val_dataset, self.test_dataset = ld.train_test_split(dataset, splits=[0.7, 0.1, 0.1])

    def _make_dataloader(self, dataset):
        """Build a streaming dataloader over ``dataset`` with the shared settings."""
        return ld.StreamingDataLoader(
            dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            # Persistent workers only make sense when worker processes exist.
            persistent_workers=self.num_workers > 0,
        )

    def train_dataloader(self):
        return self._make_dataloader(self.train_dataset)

    def val_dataloader(self):
        return self._make_dataloader(self.val_dataset)

    def test_dataloader(self):
        return self._make_dataloader(self.test_dataset)
81+
82+
83+
# ======================================================
84+
85+
86+
if __name__ == "__main__":
    # Model and data source; the directory is the one written by the optimize script.
    sine_model = SineModule()
    sine_data = SineDataModule("example_optimize_dataset")

    trainer_options = {"max_epochs": 100, "accelerator": "cpu", "precision": "64-true"}
    trainer = L.Trainer(**trainer_options)

    # Fit first, then evaluate the fitted model on the test split.
    trainer.fit(sine_model, sine_data)
    trainer.test(sine_model, sine_data)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Sine function model prediction with `LitData` & `PyTorch Lightning`
2+
3+
<a target="_blank" href="https://lightning.ai/deependu/studios/sine-function-model-prediction-with-litdata-and-pytorch-lightning"><img src="https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/app-2/studio-badge.svg" alt="Open in Studio"/>
4+
</a>
5+
6+
- Check out this example in [Lightning Studio](https://lightning.ai/deependu/studios/sine-function-model-prediction-with-litdata-and-pytorch-lightning)
7+
8+
---
9+
10+
## Steps
11+
12+
- Prepare the optimized dataset. [Check the `01-optimize.py` file](./01-optimize.py)
13+
14+
- Train the model with the LitData `StreamingDataset` & `StreamingDataLoader` and a PyTorch Lightning `LightningDataModule`. [Check the model training code](./02-model_training.py)
15+
16+
- Visualize the prediction. [Check the Jupyter notebook](./main.ipynb)
17+
18+
![visualize prediction](https://storage.googleapis.com/lightning-avatars/litpages/01jphhqptdw8t8sbrdxgdbj3np/5e809ecf-6781-4089-9f48-654519db7c34.png)
Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"metadata": {},
7+
"outputs": [],
8+
"source": [
9+
"import lightning as L\n",
10+
"import torch\n",
11+
"import torch.nn.functional as F\n",
12+
"from torch import nn\n",
13+
"\n",
14+
"\n",
15+
"# ruff: noqa: RET504\n",
16+
"class SineModule(L.LightningModule):\n",
17+
" def __init__(self):\n",
18+
" super().__init__()\n",
19+
" self.fc1 = nn.Linear(1, 32)\n",
20+
" self.fc2 = nn.Linear(32, 32)\n",
21+
" self.fc3 = nn.Linear(32, 8)\n",
22+
" self.fc4 = nn.Linear(8, 1)\n",
23+
"\n",
24+
" def forward(self, x):\n",
25+
" x = F.relu(self.fc1(x))\n",
26+
" x = F.relu(self.fc2(x))\n",
27+
" x = F.relu(self.fc3(x))\n",
28+
" x = F.tanh(self.fc4(x)) # for output to be in -1 to 1\n",
29+
" return x\n",
30+
"\n",
31+
" def training_step(self, batch, batch_idx):\n",
32+
" # training_step defines the train loop.\n",
33+
" x, y = batch[\"x\"], batch[\"sine\"]\n",
34+
" x = x.view(x.size(0), -1)\n",
35+
" x = F.relu(self.fc1(x))\n",
36+
" x = F.relu(self.fc2(x))\n",
37+
" x = F.relu(self.fc3(x))\n",
38+
" x = F.tanh(self.fc4(x)) # for output to be in -1 to 1\n",
39+
"\n",
40+
" loss = F.mse_loss(x.squeeze(), y)\n",
41+
" return loss\n",
42+
"\n",
43+
" def test_step(self, batch, batch_idx):\n",
44+
" # this is the test loop\n",
45+
" x, y = batch[\"x\"], batch[\"sine\"]\n",
46+
" x = x.view(x.size(0), -1)\n",
47+
" x = F.relu(self.fc1(x))\n",
48+
" x = F.relu(self.fc2(x))\n",
49+
" x = F.relu(self.fc3(x))\n",
50+
" x = F.tanh(self.fc4(x)) # for output to be in -1 to 1\n",
51+
"\n",
52+
" test_loss = F.mse_loss(x.squeeze(), y)\n",
53+
" self.log(\"test_loss\", test_loss)\n",
54+
"\n",
55+
" def validation_step(self, batch, batch_idx):\n",
56+
" # this is the validation loop\n",
57+
" x, y = batch[\"x\"], batch[\"sine\"]\n",
58+
" x = x.view(x.size(0), -1)\n",
59+
" x = F.relu(self.fc1(x))\n",
60+
" x = F.relu(self.fc2(x))\n",
61+
" x = F.relu(self.fc3(x))\n",
62+
" x = F.tanh(self.fc4(x)) # for output to be in -1 to 1\n",
63+
"\n",
64+
" val_loss = F.mse_loss(x.squeeze(), y)\n",
65+
" self.log(\"val_loss\", val_loss)\n",
66+
"\n",
67+
" def configure_optimizers(self):\n",
68+
" optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)\n",
69+
" return optimizer"
70+
]
71+
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": null,
75+
"metadata": {},
76+
"outputs": [],
77+
"source": [
78+
"model = SineModule.load_from_checkpoint(\"lightning_logs/version_0/checkpoints/epoch=99-step=17500.ckpt\")"
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": null,
84+
"metadata": {},
85+
"outputs": [],
86+
"source": [
87+
"model"
88+
]
89+
},
90+
{
91+
"cell_type": "code",
92+
"execution_count": null,
93+
"metadata": {},
94+
"outputs": [],
95+
"source": [
96+
"import numpy as np\n",
97+
"\n",
98+
"model.eval()\n",
99+
"\n",
100+
"x = np.linspace(-5, 5, 100)\n",
101+
"original_sine = np.sin(x)\n",
102+
"\n",
103+
"y = []\n",
104+
"\n",
105+
"with torch.no_grad():\n",
106+
" for _x in x:\n",
107+
" _x = torch.Tensor([_x])\n",
108+
" y_hat = model(_x)\n",
109+
" y.append(y_hat)"
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": null,
115+
"metadata": {},
116+
"outputs": [],
117+
"source": [
118+
"import matplotlib.pyplot as plt"
119+
]
120+
},
121+
{
122+
"cell_type": "code",
123+
"execution_count": null,
124+
"metadata": {},
125+
"outputs": [],
126+
"source": [
127+
"plt.plot(x, y, color=\"red\", label=\"predicted\") # Red color for y\n",
128+
"# Blue color for original_sine\n",
129+
"plt.plot(x, original_sine, color=\"blue\", label=\"original sine\")\n",
130+
"\n",
131+
"plt.legend() # Show labels in the plot\n",
132+
"plt.xlabel(\"X-axis\")\n",
133+
"plt.ylabel(\"Y-axis\")\n",
134+
"plt.title(\"Comparison of y and original_sine\")\n",
135+
"plt.show()"
136+
]
137+
}
138+
],
139+
"metadata": {
140+
"kernelspec": {
141+
"display_name": "litdata",
142+
"language": "python",
143+
"name": "python3"
144+
},
145+
"language_info": {
146+
"codemirror_mode": {
147+
"name": "ipython",
148+
"version": 3
149+
},
150+
"file_extension": ".py",
151+
"mimetype": "text/x-python",
152+
"name": "python",
153+
"nbconvert_exporter": "python",
154+
"pygments_lexer": "ipython3",
155+
"version": "3.10.15"
156+
}
157+
},
158+
"nbformat": 4,
159+
"nbformat_minor": 2
160+
}

0 commit comments

Comments
 (0)