When using the `Informer` model with `pred_len == 1`, we observe an index error:

`IndexError: max(): Expected reduction dim 2 to have non-zero size`
The reason, we believe, is the following:
The `forward` method takes in `x_dec`, which has shape `(b, 1, f)`, and passes it to the `self.decoder` module (a `DecoderLayer`).
Inside, self-attention is performed on `x_dec` (denoted `x`), which in turn calls the `self.inner_attention` module, i.e. the `forward` call of the `ProbAttention` class.
In this step `keys = x = x_dec`, so the line `_, L_K, _, _ = keys.shape` sets `L_K == 1`, which in turn sets `U_part` to `0`, since a log is taken over `L_K` in the definition `U_part = self.factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)`.
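The degenerate value can be checked in isolation. A minimal sketch of the `U_part` computation, assuming `factor = 5` (the `ProbAttention` default may differ in your configuration):

```python
import numpy as np

# Sketch of the U_part computation from ProbAttention._prob_QK.
factor = 5  # assumed sampling factor; not taken from a specific config
L_K = 1     # decoder self-attention key length when pred_len == 1

# np.log(1) == 0, so U_part collapses to 0 regardless of factor
U_part = factor * np.ceil(np.log(L_K)).astype('int').item()  # c*ln(L_k)
print(U_part)  # 0
```

With `U_part == 0`, zero key positions are sampled, which produces the empty `Q_K_sample` tensor described below.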
The last call is the `_prob_QK` method of the `ProbAttention` class, which performs `M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)` and raises `IndexError: max(): Expected reduction dim 2 to have non-zero size` (`Q_K_sample` is an empty tensor).
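Since `Q_K_sample`'s last dimension has size `U_part == 0`, `max(-1)` has nothing to reduce over. One possible workaround (our assumption, not an official fix) would be to clamp the sampled sizes so at least one key/query is always drawn:

```python
import numpy as np

# Hypothetical clamp for the sample sizes in ProbAttention._prob_QK.
factor = 5        # assumed sampling factor
L_K = L_Q = 1     # key/query lengths in decoder self-attention with pred_len == 1

U_part = factor * int(np.ceil(np.log(L_K)))  # c*ln(L_k) -> 0 when L_K == 1
u = factor * int(np.ceil(np.log(L_Q)))       # c*ln(L_q) -> 0 when L_Q == 1

# clamp into [1, L] so the subsequent max(-1) reduction is never empty
U_part = max(min(U_part, L_K), 1)
u = max(min(u, L_Q), 1)
print(U_part, u)  # 1 1
```

This only sketches the guard; whether clamping (rather than, say, falling back to full attention for very short sequences) is the right fix is up to the maintainers.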