Skip to content

Commit 7836c0b

Browse files
author
Duncan Blythe
committed
added example to the readme
1 parent 249be45 commit 7836c0b

File tree

2 files changed

+92
-2
lines changed

2 files changed

+92
-2
lines changed

README.md

+91-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,92 @@
11
# torchapply
2-
Apply a torch model to some datapoints
2+
Apply a torch model to some datapoints.
3+
4+
Here's an example:
5+
6+
```python
7+
import torch
8+
from torch import tensor
9+
10+
11+
class Main(torch.nn.Module):
12+
def __init__(self, model_0, model_1):
13+
super().__init__()
14+
self.model_0 = model_0
15+
self.model_1 = model_1
16+
self.dictionary = {'apple': 0, 'orange': 1, 'pear': 2}
17+
18+
def preprocess(self, arg):
19+
return [
20+
{
21+
'a': {'b': self.dictionary[arg[0]['a']['b']]},
22+
'c': self.dictionary[arg[0]['c']]
23+
},
24+
torch.tensor([self.dictionary[x] for x in arg[1]])
25+
]
26+
27+
def forward(self, args):
28+
return self.model_0(args[0]), self.model_1(args[1])
29+
30+
def postprocess(self, arg):
31+
total = [arg[0]['a']['b'].sum(), arg[0]['c'].sum(), arg[1].sum()]
32+
return {'score': sum(total), 'decision': sum(total) > 0}
33+
34+
35+
class ModelA(torch.nn.Module):
36+
def forward(self, args):
37+
return {'b': torch.randn(args['b'].shape[0], 10)}
38+
39+
40+
class ModelC(torch.nn.Module):
41+
def forward(self, args):
42+
return torch.randn(args.shape[0], 10)
43+
44+
45+
class Model1(torch.nn.Module):
46+
def forward(self, args):
47+
return torch.randn(args.shape[0], 10)
48+
49+
50+
class Model0(torch.nn.Module):
51+
def __init__(self, model_a, model_c):
52+
super().__init__()
53+
self.model_a = model_a
54+
self.model_c = model_c
55+
56+
def forward(self, args):
57+
return {'a': self.model_a(args['a']), 'c': self.model_c(args['c'])}
58+
59+
60+
model = Main(
61+
model_0=Model0(
62+
model_a=ModelA(),
63+
model_c=ModelC()
64+
),
65+
model_1=Model1()
66+
)
67+
```
68+
69+
Apply to a single datapoint:
70+
71+
```python
72+
from torchapply import apply_model
73+
74+
apply_model(
75+
model,
76+
({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple')),
77+
single=True
78+
)
79+
```
80+
81+
Apply to multiple datapoints:
82+
83+
```python
84+
from torchapply import apply_model
85+
86+
apply_model(
87+
model,
88+
[({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple')) for _ in range(10)],
89+
single=False
90+
)
91+
```
92+

torchapply/version.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "0.0.2"
1+
__version__ = "0.0.3"

0 commit comments

Comments
 (0)