@@ -87,14 +87,25 @@ def __to_device__(self, x):
87
87
y = []
88
88
for i in x :
89
89
if isinstance (i , torch .Tensor ):
90
- y .append (i .to (self .dev ))
90
+ if i .dtype == torch .float32 :
91
+ y .append (i .to (self .dev , dtype = self .model .dtype ))
92
+ else :
93
+ y .append (i .to (self .dev ))
91
94
else :
92
95
pass
93
96
return y
94
97
elif isinstance (x , np .ndarray ):
95
- return torch .from_numpy (x ).to (self .device )
98
+ if x .dtype == np .float32 :
99
+ return torch .from_numpy (x ).to (
100
+ device = self .device , dtype = self .model .dtype
101
+ )
102
+ else :
103
+ return torch .from_numpy (x ).to (self .device )
96
104
elif isinstance (x , torch .Tensor ):
97
- return x .to (self .device )
105
+ if x .dtype == torch .float32 :
106
+ return x .to (device = self .device , dtype = self .model .dtype )
107
+ else :
108
+ return x .to (self .device )
98
109
99
110
def _get_overrides (self ):
100
111
overrides = []
@@ -116,6 +127,8 @@ def initialize(self, context):
116
127
main_conf = context .manifest ["model" ]["modelName" ].replace ("_" , "/" )
117
128
# set model to training true before calling the training step
118
129
self .model .train ()
130
+ # TODO: model ignores the precision set in the config
131
+ # self.model.double()
119
132
try :
120
133
with initialize (
121
134
config_path = "conf/" + "/" .join (main_conf .split ("/" )[0 :- 1 ]),
@@ -174,6 +187,9 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
174
187
log .info ("Streaming optimization handler called." )
175
188
result = defaultdict (list )
176
189
self .optimization_step = 0
190
+ if hasattr (self .trainer , "serve_context" ) and self .trainer .serve_context :
191
+ self .trainer .serve_context = None
192
+ self .trainer .serve_context = context
177
193
self .context = context
178
194
start_time = time .time ()
179
195
@@ -189,6 +205,7 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
189
205
td = self .preprocess (data )
190
206
# get the dataloader for returned dict
191
207
dataloader = td ["dataloader" ]
208
+ self .trainer .dataloader_length = len (dataloader )
192
209
# iterate over the dataloader
193
210
for batch_idx , batch in enumerate (torch .utils .data .DataLoader (dataloader )):
194
211
self .model .optimization_step = 0
@@ -209,14 +226,14 @@ def handle(self, data: typing.Mapping[str, typing.Any], context: typing.Any):
209
226
# batch[key] = (
210
227
# f"Running batch_idx {batch_idx} with completion percentage {float((batch_idx + 1)/len(dataloader) * 100):.2f}%."
211
228
# )
212
- log .info (batch [key ])
213
- send_intermediate_predict_response (
214
- batch ,
215
- self .context .request_ids ,
216
- "Intermediate response from the model." ,
217
- 200 ,
218
- self .context ,
219
- )
229
+ # log.info(batch[key])
230
+ # send_intermediate_predict_response(
231
+ # batch,
232
+ # self.context.request_ids,
233
+ # "Intermediate response from the model.",
234
+ # 200,
235
+ # self.context,
236
+ # )
220
237
if self .keys :
221
238
for k in self .keys :
222
239
result [k ].append (batch [k ])
0 commit comments