|
1 | 1 | # torchapply
|
2 |
| -Apply a torch model to some datapoints |
| 2 | +Apply a torch model to some datapoints. |
| 3 | + |
| 4 | +Here's an example: |
| 5 | + |
| 6 | +```python |
| 7 | +import torch |
| 8 | +from torch import tensor |
| 9 | + |
| 10 | + |
| 11 | +class Main(torch.nn.Module): |
| 12 | + def __init__(self, model_0, model_1): |
| 13 | + super().__init__() |
| 14 | + self.model_0 = model_0 |
| 15 | + self.model_1 = model_1 |
| 16 | + self.dictionary = {'apple': 0, 'orange': 1, 'pear': 2} |
| 17 | + |
| 18 | + def preprocess(self, arg): |
| 19 | + return [ |
| 20 | + { |
| 21 | + 'a': {'b': self.dictionary[arg[0]['a']['b']]}, |
| 22 | + 'c': self.dictionary[arg[0]['c']] |
| 23 | + }, |
| 24 | + torch.tensor([self.dictionary[x] for x in arg[1]]) |
| 25 | + ] |
| 26 | + |
| 27 | + def forward(self, args): |
| 28 | + return self.model_0(args[0]), self.model_1(args[1]) |
| 29 | + |
| 30 | + def postprocess(self, arg): |
| 31 | + total = [arg[0]['a']['b'].sum(), arg[0]['c'].sum(), arg[1].sum()] |
| 32 | + return {'score': sum(total), 'decision': sum(total) > 0} |
| 33 | + |
| 34 | + |
| 35 | +class ModelA(torch.nn.Module): |
| 36 | + def forward(self, args): |
| 37 | + return {'b': torch.randn(args['b'].shape[0], 10)} |
| 38 | + |
| 39 | + |
| 40 | +class ModelC(torch.nn.Module): |
| 41 | + def forward(self, args): |
| 42 | + return torch.randn(args.shape[0], 10) |
| 43 | + |
| 44 | + |
| 45 | +class Model1(torch.nn.Module): |
| 46 | + def forward(self, args): |
| 47 | + return torch.randn(args.shape[0], 10) |
| 48 | + |
| 49 | + |
| 50 | +class Model0(torch.nn.Module): |
| 51 | + def __init__(self, model_a, model_c): |
| 52 | + super().__init__() |
| 53 | + self.model_a = model_a |
| 54 | + self.model_c = model_c |
| 55 | + |
| 56 | + def forward(self, args): |
| 57 | + return {'a': self.model_a(args['a']), 'c': self.model_c(args['c'])} |
| 58 | + |
| 59 | + |
| 60 | +model = Main( |
| 61 | + model_0=Model0( |
| 62 | + model_a=ModelA(), |
| 63 | + model_c=ModelC() |
| 64 | + ), |
| 65 | + model_1=Model1() |
| 66 | +) |
| 67 | +``` |
| 68 | + |
| 69 | +Apply to a single datapoint: |
| 70 | + |
| 71 | +```python |
| 72 | +from torchapply import apply_model |
| 73 | + |
| 74 | +apply_model( |
| 75 | + model, |
| 76 | + ({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple')), |
| 77 | + single=True |
| 78 | +) |
| 79 | +``` |
| 80 | + |
| 81 | +Apply to multiple datapoints: |
| 82 | + |
| 83 | +```python |
| 84 | +from torchapply import apply_model |
| 85 | + |
| 86 | +apply_model( |
| 87 | + model, |
| 88 | + [({'a': {'b': 'orange'}, 'c': 'pear'}, ('apple', 'apple')) for _ in range(10)], |
| 89 | + single=False |
| 90 | +) |
| 91 | +``` |
| 92 | + |
0 commit comments