@@ -370,8 +370,10 @@ def __init__(
370370 raise ValueError ("Error, the number of output channels has to be an integer multiple of the group size" )
371371 self .groupsize_in = in_channels // self .groups
372372 self .groupsize_out = out_channels // self .groups
373+ # keep the legacy `groupsize` attribute as an alias of `groupsize_in` for backward compatibility
374+ self .groupsize = self .groupsize_in
373375 scale = math .sqrt (1.0 / self .groupsize_in / self .kernel_size )
374- self .weight = nn .Parameter (scale * torch .randn (self .groups , self .groupsize_out , self .groupsize_in , self .kernel_size ))
376+ self .weight = nn .Parameter (scale * torch .randn (self .groups * self .groupsize_out , self .groupsize_in , self .kernel_size ))
375377
376378 if bias :
377379 self .bias = nn .Parameter (torch .zeros (out_channels ))
@@ -496,7 +498,7 @@ def __init__(
496498 self .psi = _get_psi (self .kernel_size , self .psi_idx , self .psi_vals , self .nlat_in , self .nlon_in , self .nlat_out , self .nlon_out )
497499
498500 def extra_repr (self ):
499- return f"in_shape={ (self .nlat_in , self .nlon_in )} , out_shape={ (self .nlat_out , self .nlon_out )} , in_chans={ self .groupsize * self .groups } , out_chans={ self .weight .shape [0 ]} , filter_basis={ self .filter_basis } , kernel_shape={ self .kernel_shape } , groups={ self .groups } "
501+ return f"in_shape={ (self .nlat_in , self .nlon_in )} , out_shape={ (self .nlat_out , self .nlon_out )} , in_chans={ self .groupsize_in * self .groups } , out_chans={ self .weight .shape [0 ]} , filter_basis={ self .filter_basis } , kernel_shape={ self .kernel_shape } , groups={ self .groups } "
500502
501503 @property
502504 def psi_idx (self ):
@@ -524,7 +526,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
524526 # weight multiplication
525527 B , H , W , _ , K = xpc .shape
526528 xpc = xpc .reshape (B , H , W , self .groups , self .groupsize_in , K )
527- outp = torch .einsum ("bxygck,gock->bxygo" , xpc , self .weight )
529+ outp = torch .einsum ("bxygck,gock->bxygo" , xpc , self .weight . reshape ( self . groups , self . groupsize_out , self . groupsize_in , self . kernel_size ) )
528530 outp = outp .reshape (B , H , W , - 1 ).contiguous ()
529531
530532 # permute output
@@ -538,7 +540,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
538540 x = x .reshape (B , self .groups , self .groupsize_in , K , H , W )
539541
540542 # weight multiplication
541- out = torch .einsum ("bgckxy,gock->bgoxy" , x , self .weight )
543+ out = torch .einsum ("bgckxy,gock->bgoxy" , x , self .weight . reshape ( self . groups , self . groupsize_out , self . groupsize_in , self . kernel_size ) )
542544 out = out .reshape (B , - 1 , H , W ).contiguous ()
543545
544546 if self .bias is not None :
@@ -657,7 +659,7 @@ def __init__(
657659 self .psi_st = _get_psi (self .kernel_size , self .psi_idx , self .psi_vals , self .nlat_in , self .nlon_in , self .nlat_out , self .nlon_out , semi_transposed = True )
658660
659661 def extra_repr (self ):
660- return f"in_shape={ (self .nlat_in , self .nlon_in )} , out_shape={ (self .nlat_out , self .nlon_out )} , in_chans={ self .groupsize * self .groups } , out_chans={ self .weight .shape [0 ]} , filter_basis={ self .filter_basis } , kernel_shape={ self .kernel_shape } , groups={ self .groups } "
662+ return f"in_shape={ (self .nlat_in , self .nlon_in )} , out_shape={ (self .nlat_out , self .nlon_out )} , in_chans={ self .groupsize_in * self .groups } , out_chans={ self .weight .shape [0 ]} , filter_basis={ self .filter_basis } , kernel_shape={ self .kernel_shape } , groups={ self .groups } "
661663
662664 @property
663665 def psi_idx (self ):
@@ -674,7 +676,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
674676
675677 # weight multiplication
676678 xp = xp .reshape (B , H , W , self .groups , self .groupsize_in )
677- xpc = torch .einsum ("bxygc,gock->bxygok" , xp , self .weight )
679+ xpc = torch .einsum ("bxygc,gock->bxygok" , xp , self .weight . reshape ( self . groups , self . groupsize_out , self . groupsize_in , self . kernel_size ) )
678680 xpc = xpc .reshape (B , H , W , - 1 , self .kernel_size ).contiguous ()
679681
680682 # disco contraction
@@ -695,7 +697,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
695697 else :
696698 # weight multiplication
697699 x = x .reshape (B , self .groups , self .groupsize_in , H , W )
698- xc = torch .einsum ("bgcxy,gock->bgokxy" , x , self .weight )
700+ xc = torch .einsum ("bgcxy,gock->bgokxy" , x , self .weight . reshape ( self . groups , self . groupsize_out , self . groupsize_in , self . kernel_size ) )
699701 xc = xc .reshape (B , self .groups * self .groupsize_out , - 1 , H , W ).contiguous ()
700702
701703 # disco contraction
0 commit comments