Commit ce780c2

Merge branch 'main' of https://github.com/moverseai/moai

2 parents df4f771 + 7d47896

File tree

2 files changed: +32 -11 lines


moai/serve/model.py (+4)

@@ -137,6 +137,10 @@ def initialize(self, context, extract_files=True):
             log.error(f"An error has occured while loading the model:\n{e}")
         self.model = self.model.to(self.device)
         self.model.eval()
+        if hasattr(cfg, "archive") and hasattr(cfg.archive, "model_precision"):
+            if cfg.archive.model_precision == "double":
+                log.info("Setting model to double precision.")
+                self.model.double()
         self.initialized = True
         log.info(
             f"Model ({type(self.model.model if hasattr(self.model, 'model') else self.model)}) loaded successfully."

moai/serve/streaming_optimizer.py (+28 -11)
@@ -87,14 +87,25 @@ def __to_device__(self, x):
             y = []
             for i in x:
                 if isinstance(i, torch.Tensor):
-                    y.append(i.to(self.dev))
+                    if i.dtype == torch.float32:
+                        y.append(i.to(self.dev, dtype=self.model.dtype))
+                    else:
+                        y.append(i.to(self.dev))
                 else:
                     pass
             return y
         elif isinstance(x, np.ndarray):
-            return torch.from_numpy(x).to(self.device)
+            if x.dtype == np.float32:
+                return torch.from_numpy(x).to(
+                    device=self.device, dtype=self.model.dtype
+                )
+            else:
+                return torch.from_numpy(x).to(self.device)
         elif isinstance(x, torch.Tensor):
-            return x.to(self.device)
+            if x.dtype == torch.float32:
+                return x.to(device=self.device, dtype=self.model.dtype)
+            else:
+                return x.to(self.device)

    def _get_overrides(self):
        overrides = []
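The dtype-matching transfer above is easy to check in isolation. Here is a standalone sketch of the same pattern: float32 inputs are cast to the model's parameter dtype on transfer, while other dtypes (e.g. integer labels) keep theirs. The free function and its argument names are illustrative, not moai's actual signature.

import numpy as np
import torch

def to_device(x, device, model_dtype):
    """Move x to `device`, casting float32 data to `model_dtype`."""
    if isinstance(x, list):
        y = []
        for i in x:
            if isinstance(i, torch.Tensor):
                # cast float32 tensors so inputs match the model's precision
                if i.dtype == torch.float32:
                    y.append(i.to(device, dtype=model_dtype))
                else:
                    y.append(i.to(device))
            # non-tensor items are skipped, mirroring the diff's `pass`
        return y
    elif isinstance(x, np.ndarray):
        t = torch.from_numpy(x)
        if t.dtype == torch.float32:
            return t.to(device=device, dtype=model_dtype)
        return t.to(device)
    elif isinstance(x, torch.Tensor):
        if x.dtype == torch.float32:
            return x.to(device=device, dtype=model_dtype)
        return x.to(device)
    return x

# float32 features follow a double-precision model; int64 labels are untouched
feats = np.random.rand(2, 3).astype(np.float32)
labels = torch.tensor([0, 1])
print(to_device(feats, "cpu", torch.float64).dtype)   # torch.float64
print(to_device(labels, "cpu", torch.float64).dtype)  # torch.int64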
@@ -116,6 +127,8 @@ def initialize(self, context):
         main_conf = context.manifest["model"]["modelName"].replace("_", "/")
         # set model to training true before calling the training step
         self.model.train()
+        # TODO: model ignores the precision set in the config
+        # self.model.double()
         try:
             with initialize(
                 config_path="conf/" + "/".join(main_conf.split("/")[0:-1]),
@@ -174,6 +187,9 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
         log.info("Streaming optimization handler called.")
         result = defaultdict(list)
         self.optimization_step = 0
+        if hasattr(self.trainer, "serve_context") and self.trainer.serve_context:
+            self.trainer.serve_context = None
+        self.trainer.serve_context = context
         self.context = context
         start_time = time.time()

@@ -189,6 +205,7 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
         td = self.preprocess(data)
         # get the dataloader for returned dict
         dataloader = td["dataloader"]
+        self.trainer.dataloader_length = len(dataloader)
         # iterate over the dataloader
         for batch_idx, batch in enumerate(torch.utils.data.DataLoader(dataloader)):
             self.model.optimization_step = 0
@@ -209,14 +226,14 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
                 # batch[key] = (
                 #     f"Running batch_idx {batch_idx} with completion percentage {float((batch_idx + 1)/len(dataloader) * 100):.2f}%."
                 # )
-                log.info(batch[key])
-                send_intermediate_predict_response(
-                    batch,
-                    self.context.request_ids,
-                    "Intermediate response from the model.",
-                    200,
-                    self.context,
-                )
+                # log.info(batch[key])
+                # send_intermediate_predict_response(
+                #     batch,
+                #     self.context.request_ids,
+                #     "Intermediate response from the model.",
+                #     200,
+                #     self.context,
+                # )
             if self.keys:
                 for k in self.keys:
                     result[k].append(batch[k])
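The call commented out above is TorchServe's streaming hook. For reference, a minimal sketch of how a handler streams partial results with send_intermediate_predict_response; the import path and argument order follow TorchServe's documented API (verify against your TorchServe version), and the handler body is purely illustrative.

from ts.protocol.otf_message_handler import send_intermediate_predict_response

def handle(data, context):
    # stream partial results back to the client before the final response
    for step in range(3):
        send_intermediate_predict_response(
            [f"step {step} done"],  # intermediate payload
            context.request_ids,    # request id map from the serving context
            "Intermediate response from the model.",
            200,                    # HTTP status code
            context,
        )
    return ["final response"]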
