Description
Does SDGL not need the lag parameter on the PEMS04 dataset?
# add time window
x_tra, y_tra = Add_Window_Horizon(data_train, args.lag, args.horizon, single)
x_val, y_val = Add_Window_Horizon(data_val, args.lag, args.horizon, single)
x_test, y_test = Add_Window_Horizon(data_test, args.lag, args.horizon, single)
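For context, here is a minimal sketch of what a helper with this call shape usually does. It assumes that SDGL's Add_Window_Horizon follows the common AGCRN-style signature `(data, window, horizon, single)`. The lowercase `add_window_horizon` below is a hypothetical reconstruction for illustration, not SDGL's actual code. Under that assumption, `args.lag` is the input window length and `args.horizon` is the number of future steps to predict.

```python
import numpy as np


def add_window_horizon(data, window=3, horizon=1, single=False):
    """Slide over the time axis (axis 0) of `data`.

    Hypothetical AGCRN-style reconstruction, not SDGL's code.
    Returns X with shape (S, window, ...) and Y with shape (S, horizon, ...),
    or Y with shape (S, 1, ...) when single=True.
    """
    num_samples = len(data) - horizon - window + 1
    xs, ys = [], []
    for i in range(num_samples):
        xs.append(data[i : i + window])  # past `window` steps as input
        if single:
            # keep only the step `horizon` ahead of the input window
            ys.append(data[i + window + horizon - 1 : i + window + horizon])
        else:
            # keep all `horizon` future steps
            ys.append(data[i + window : i + window + horizon])
    return np.array(xs), np.array(ys)


# Toy data: 100 time steps, 307 sensors (as in PEMS04), 1 feature.
data = np.random.rand(100, 307, 1)
x, y = add_window_horizon(data, 12, 12, False)
print(x.shape, y.shape)  # (77, 12, 307, 1) (77, 12, 307, 1)
```

With lag = horizon = 12, each sample pairs 12 past steps with 12 future steps. That is why predict ends in a time dimension of 12.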
After removing the args.lag argument, I get the following error:
Traceback (most recent call last):
  File "/root/autodl-tmp/SDGL-main/SDGL/Pems4/train_pems.py", line 218, in <module>
    main()
  File "/root/autodl-tmp/SDGL-main/SDGL/Pems4/train_pems.py", line 116, in main
    metrics = engine.train(trainx, trainy, pred_time_embed=None, iter=iter)
  File "/root/autodl-tmp/SDGL-main/SDGL/Pems4/../../SDGL/Pems4/engine.py", line 49, in train
    loss = torch.mean(gl_loss) * self.gc_order + self.loss_usual(predict, real) # self.loss(predict, real, 0.0)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/modules/loss.py", line 933, in forward
    return F.smooth_l1_loss(input, target, reduction=self.reduction, beta=self.beta)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/nn/functional.py", line 3238, in smooth_l1_loss
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
  File "/root/miniconda3/lib/python3.8/site-packages/torch/functional.py", line 76, in broadcast_tensors
    return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined]
RuntimeError: The size of tensor a (12) must match the size of tensor b (0) at non-singleton dimension 3
I found that the shapes of predict and real do not match:
if gl_loss is None:
    loss = self.loss(predict, real, 0.0)
else:
    loss = torch.mean(gl_loss) * self.gc_order + self.loss_usual(predict, real)  # self.loss(predict, real, 0.0)
    # predict.shape: (64, 1, 307, 12), real.shape: (64, 1, 307, 0)
What is causing this? Thanks.
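The following sketch shows one plausible cause. It assumes that the helper takes its arguments positionally as `(data, window, horizon, single)`. If so, deleting args.lag from the call shifts args.horizon into `window` and `single` into `horizon`. When `single` is False, it behaves like 0 in arithmetic and slicing, so every target window comes out empty. The `make_targets` function and the `(B, T, N, C)` to `(B, C, N, T)` transpose are hypothetical illustrations, not SDGL code.

```python
import numpy as np


def make_targets(data, window, horizon):
    # Hypothetical: only the target-slicing half of an AGCRN-style helper.
    num_samples = len(data) - horizon - window + 1
    return np.array([data[i + window : i + window + horizon]
                     for i in range(num_samples)])


data = np.random.rand(100, 307, 1)  # toy (T, N, C) data
args_horizon, single = 12, False

# Intended call: window=lag=12, horizon=12.
y_ok = make_targets(data, 12, args_horizon)
# Without args.lag, positions shift: window=args.horizon, horizon=single.
y_bad = make_targets(data, args_horizon, single)  # False acts as 0

# Assumed reshape of targets from (B, T, N, C) to (B, C, N, T).
print(y_ok.transpose(0, 3, 2, 1).shape)   # (77, 1, 307, 12)
print(y_bad.transpose(0, 3, 2, 1).shape)  # (89, 1, 307, 0)
```

If this is what happens in SDGL, the zero-length last dimension of real would match the error above. In that case, keeping all four arguments in the call, or passing them by keyword, would avoid the silent shift.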