|
27 | 27 | from gluonts.transform import ( |
28 | 28 | AddObservedValuesIndicator, |
29 | 29 | AsNumpyArray, |
30 | | - CausalMeanValueImputation, |
31 | 30 | ExpandDimArray, |
32 | 31 | TestSplitSampler, |
33 | 32 | Transformation, |
@@ -82,7 +81,6 @@ def __init__( |
82 | 81 | module: Optional[MoiraiModule] = None, |
83 | 82 | patch_size: int | str = "auto", |
84 | 83 | num_samples: int = 100, |
85 | | - mode: str = "direct", |
86 | 84 | ): |
87 | 85 | assert (module is not None) or ( |
88 | 86 | module_kwargs is not None |
@@ -334,139 +332,22 @@ def forward( |
334 | 332 | idx = val_loss.argmin(dim=0) |
335 | 333 | return preds[idx, torch.arange(len(idx), device=idx.device)] |
336 | 334 | else: |
337 | | - if self.hparams.mode == "direct": |
338 | | - distr = self._get_distr( |
339 | | - self.hparams.patch_size, |
340 | | - past_target, |
341 | | - past_observed_target, |
342 | | - past_is_pad, |
343 | | - feat_dynamic_real, |
344 | | - observed_feat_dynamic_real, |
345 | | - past_feat_dynamic_real, |
346 | | - past_observed_feat_dynamic_real, |
347 | | - ) |
348 | | - preds = distr.sample( |
349 | | - torch.Size((num_samples or self.hparams.num_samples,)) |
350 | | - ) |
351 | | - return self._format_preds( |
352 | | - self.hparams.patch_size, preds, past_target.shape[-1] |
353 | | - ) |
354 | | - |
355 | | - elif self.hparams.mode == "autoregressive": |
356 | | - context_step = self.context_token_length(self.hparams.patch_size) |
357 | | - context_token = self.hparams.target_dim * context_step |
358 | | - predict_step = self.prediction_token_length(self.hparams.patch_size) |
359 | | - predict_token = self.hparams.target_dim * predict_step |
360 | | - |
361 | | - ( |
362 | | - target, |
363 | | - observed_mask, |
364 | | - sample_id, |
365 | | - time_id, |
366 | | - variate_id, |
367 | | - prediction_mask, |
368 | | - ) = self._convert( |
369 | | - self.hparams.patch_size, |
370 | | - past_target, |
371 | | - past_observed_target, |
372 | | - past_is_pad, |
373 | | - feat_dynamic_real=feat_dynamic_real, |
374 | | - observed_feat_dynamic_real=observed_feat_dynamic_real, |
375 | | - past_feat_dynamic_real=past_feat_dynamic_real, |
376 | | - past_observed_feat_dynamic_real=past_observed_feat_dynamic_real, |
377 | | - ) |
378 | | - patch_size = ( |
379 | | - torch.ones_like(time_id, dtype=torch.long) * self.hparams.patch_size |
380 | | - ) |
381 | | - |
382 | | - pred_index = torch.arange( |
383 | | - start=context_step - 1, end=context_token, step=context_step |
384 | | - ) |
385 | | - assign_index = torch.arange( |
386 | | - start=context_token, |
387 | | - end=context_token + predict_token, |
388 | | - step=predict_step, |
389 | | - ) |
390 | | - |
391 | | - if predict_step == 1: |
392 | | - distr = self.module( |
393 | | - target, |
394 | | - observed_mask, |
395 | | - sample_id, |
396 | | - time_id, |
397 | | - variate_id, |
398 | | - prediction_mask, |
399 | | - patch_size, |
400 | | - ) |
401 | | - preds = distr.sample( |
402 | | - torch.Size((num_samples or self.hparams.num_samples,)) |
403 | | - ) |
404 | | - preds[..., assign_index, :] = preds[..., pred_index, :] |
405 | | - return self._format_preds( |
406 | | - self.hparams.patch_size, preds, self.hparams.target_dim |
407 | | - ) |
408 | | - else: |
409 | | - distr = self.module( |
410 | | - target, |
411 | | - observed_mask, |
412 | | - sample_id, |
413 | | - time_id, |
414 | | - variate_id, |
415 | | - prediction_mask, |
416 | | - patch_size, |
417 | | - ) |
418 | | - preds = distr.sample(torch.Size((self.hparams.num_samples,))) |
419 | | - |
420 | | - expand_target = target.unsqueeze(0).repeat( |
421 | | - self.hparams.num_samples, 1, 1, 1 |
422 | | - ) |
423 | | - expand_prediction_mask = prediction_mask.unsqueeze(0).repeat( |
424 | | - self.hparams.num_samples, 1, 1 |
425 | | - ) |
426 | | - expand_observed_mask = observed_mask.unsqueeze(0).expand( |
427 | | - self.hparams.num_samples, -1, -1, -1 |
428 | | - ) |
429 | | - expand_sample_id = sample_id.unsqueeze(0).expand( |
430 | | - self.hparams.num_samples, -1, -1 |
431 | | - ) |
432 | | - expand_time_id = time_id.unsqueeze(0).expand( |
433 | | - self.hparams.num_samples, -1, -1 |
434 | | - ) |
435 | | - expand_variate_id = variate_id.unsqueeze(0).expand( |
436 | | - self.hparams.num_samples, -1, -1 |
437 | | - ) |
438 | | - expand_patch_size = patch_size.unsqueeze(0).expand( |
439 | | - self.hparams.num_samples, -1, -1 |
440 | | - ) |
441 | | - |
442 | | - expand_target[..., assign_index, :] = preds[..., pred_index, :] |
443 | | - expand_prediction_mask[..., assign_index] = False |
444 | | - |
445 | | - remain_step = predict_step - 1 |
446 | | - while remain_step > 0: |
447 | | - distr = self.module( |
448 | | - expand_target, |
449 | | - expand_observed_mask, |
450 | | - expand_sample_id, |
451 | | - expand_time_id, |
452 | | - expand_variate_id, |
453 | | - expand_prediction_mask, |
454 | | - expand_patch_size, |
455 | | - ) |
456 | | - preds = distr.sample(torch.Size((1,))) |
457 | | - _, _, bs, token, ps = preds.shape |
458 | | - preds = preds.view(-1, bs, token, ps) |
459 | | - |
460 | | - pred_index = assign_index |
461 | | - assign_index = assign_index + 1 |
462 | | - expand_target[..., assign_index, :] = preds[..., pred_index, :] |
463 | | - expand_prediction_mask[..., assign_index] = False |
464 | | - |
465 | | - remain_step -= 1 |
466 | | - |
467 | | - return self._format_preds( |
468 | | - self.hparams.patch_size, expand_target, self.hparams.target_dim |
469 | | - ) |
| 335 | + distr = self._get_distr( |
| 336 | + self.hparams.patch_size, |
| 337 | + past_target, |
| 338 | + past_observed_target, |
| 339 | + past_is_pad, |
| 340 | + feat_dynamic_real, |
| 341 | + observed_feat_dynamic_real, |
| 342 | + past_feat_dynamic_real, |
| 343 | + past_observed_feat_dynamic_real, |
| 344 | + ) |
| 345 | + preds = distr.sample( |
| 346 | + torch.Size((num_samples or self.hparams.num_samples,)) |
| 347 | + ) |
| 348 | + return self._format_preds( |
| 349 | + self.hparams.patch_size, preds, past_target.shape[-1] |
| 350 | + ) |
470 | 351 |
|
471 | 352 | def _val_loss( |
472 | 353 | self, |
@@ -1066,20 +947,12 @@ def get_default_transform(self) -> Transformation: |
1066 | 947 | dtype=np.float32, |
1067 | 948 | ) |
1068 | 949 | if self.hparams.target_dim == 1: |
1069 | | - transform += AddObservedValuesIndicator( |
1070 | | - target_field="target", |
1071 | | - output_field="observed_target", |
1072 | | - imputation_method=CausalMeanValueImputation(), |
1073 | | - dtype=bool, |
1074 | | - ) |
1075 | 950 | transform += ExpandDimArray(field="target", axis=0) |
1076 | | - transform += ExpandDimArray(field="observed_target", axis=0) |
1077 | | - else: |
1078 | | - transform += AddObservedValuesIndicator( |
1079 | | - target_field="target", |
1080 | | - output_field="observed_target", |
1081 | | - dtype=bool, |
1082 | | - ) |
| 951 | + transform += AddObservedValuesIndicator( |
| 952 | + target_field="target", |
| 953 | + output_field="observed_target", |
| 954 | + dtype=bool, |
| 955 | + ) |
1083 | 956 |
|
1084 | 957 | if self.hparams.feat_dynamic_real_dim > 0: |
1085 | 958 | transform += AsNumpyArray( |
|
0 commit comments