diff --git a/src/i3dpt.py b/src/i3dpt.py index 99d7487..cb07ba9 100644 --- a/src/i3dpt.py +++ b/src/i3dpt.py @@ -421,8 +421,8 @@ def load_conv3d(state_dict, name_pt, sess, name_tf, bias=False, bn=True): out_planes = conv_weights_rs.shape[0] state_dict[name_pt + '.batch3d.weight'] = torch.ones(out_planes) - state_dict[name_pt + - '.batch3d.bias'] = torch.from_numpy(beta.squeeze()) + state_dict[name_pt + '.batch3d.bias'] = torch.from_numpy(beta.squeeze()) + state_dict[name_pt + '.batch3d.running_mean'] = torch.from_numpy(moving_mean.squeeze()) state_dict[name_pt