-
Notifications
You must be signed in to change notification settings - Fork 10
Expand file tree
/
Copy pathmodel_factory.py
More file actions
32 lines (24 loc) · 861 Bytes
/
model_factory.py
File metadata and controls
32 lines (24 loc) · 861 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
"""Python file to instantite the model and the transform that goes with it."""
from data import data_transforms
from model import Net
class ModelFactory:
    """Build a model and the data transform that goes with it.

    The model and the transform are both selected by ``model_name`` and are
    created once, when the factory is constructed. The accessor methods
    return those stored objects unchanged.
    """

    # Names accepted by init_model() and init_transform(); only used to make
    # the "not implemented" error messages actionable.
    _SUPPORTED_MODELS = ("basic_cnn",)

    def __init__(self, model_name: str):
        """Create the model and the transform selected by ``model_name``.

        Args:
            model_name: Identifier of the model to build.

        Raises:
            NotImplementedError: If ``model_name`` is not a supported name.
        """
        self.model_name = model_name
        self.model = self.init_model()
        self.transform = self.init_transform()

    def _not_implemented(self, what: str) -> NotImplementedError:
        """Build the error raised when ``self.model_name`` is unsupported.

        The message keeps the historical ``"<what> not implemented"`` prefix
        and appends the offending name plus the supported ones, so a typo in
        the model name can be diagnosed from the traceback alone.
        """
        supported = ", ".join(self._SUPPORTED_MODELS)
        return NotImplementedError(
            f"{what} not implemented: {self.model_name!r} "
            f"(supported: {supported})"
        )

    def init_model(self):
        """Return a new ``Net`` instance when ``model_name`` is "basic_cnn".

        Raises:
            NotImplementedError: For any other ``model_name``.
        """
        if self.model_name == "basic_cnn":
            return Net()
        raise self._not_implemented("Model")

    def init_transform(self):
        """Return ``data_transforms`` when ``model_name`` is "basic_cnn".

        Raises:
            NotImplementedError: For any other ``model_name``.
        """
        if self.model_name == "basic_cnn":
            return data_transforms
        raise self._not_implemented("Transform")

    def get_model(self):
        """Return the model created at construction time."""
        return self.model

    def get_transform(self):
        """Return the transform selected at construction time."""
        return self.transform

    def get_all(self):
        """Return the stored ``(model, transform)`` pair as a tuple."""
        return self.model, self.transform