@@ -41,16 +41,16 @@ def _min_norm_2d(vecs, dps):
41
41
if (i ,j ) not in dps :
42
42
dps [(i , j )] = 0.0
43
43
for k in range (len (vecs [i ])):
44
- dps [(i ,j )] += torch .dot (vecs [i ][k ], vecs [j ][k ]).data [ 0 ]
44
+ dps [(i ,j )] += torch .mul (vecs [i ][k ], vecs [j ][k ]).sum (). data . cpu ()
45
45
dps [(j , i )] = dps [(i , j )]
46
46
if (i ,i ) not in dps :
47
47
dps [(i , i )] = 0.0
48
48
for k in range (len (vecs [i ])):
49
- dps [(i ,i )] += torch .dot (vecs [i ][k ], vecs [i ][k ]).data [ 0 ]
49
+ dps [(i ,i )] += torch .mul (vecs [i ][k ], vecs [i ][k ]).sum (). data . cpu ()
50
50
if (j ,j ) not in dps :
51
51
dps [(j , j )] = 0.0
52
52
for k in range (len (vecs [i ])):
53
- dps [(j , j )] += torch .dot (vecs [j ][k ], vecs [j ][k ]).data [ 0 ]
53
+ dps [(j , j )] += torch .mul (vecs [j ][k ], vecs [j ][k ]).sum (). data . cpu ()
54
54
c ,d = MinNormSolver ._min_norm_element_from2 (dps [(i ,i )], dps [(i ,j )], dps [(j ,j )])
55
55
if d < dmin :
56
56
dmin = d
@@ -184,16 +184,16 @@ def gradient_normalizers(grads, losses, normalization_type):
184
184
gn = {}
185
185
if normalization_type == 'l2' :
186
186
for t in grads :
187
- gn [t ] = np .sqrt (np .sum ([gr .pow (2 ).sum ().data [ 0 ] for gr in grads [t ]]))
187
+ gn [t ] = np .sqrt (np .sum ([gr .pow (2 ).sum ().data . cpu () for gr in grads [t ]]))
188
188
elif normalization_type == 'loss' :
189
189
for t in grads :
190
190
gn [t ] = losses [t ]
191
191
elif normalization_type == 'loss+' :
192
192
for t in grads :
193
- gn [t ] = losses [t ] * np .sqrt (np .sum ([gr .pow (2 ).sum ().data [ 0 ] for gr in grads [t ]]))
193
+ gn [t ] = losses [t ] * np .sqrt (np .sum ([gr .pow (2 ).sum ().data . cpu () for gr in grads [t ]]))
194
194
elif normalization_type == 'none' :
195
195
for t in grads :
196
196
gn [t ] = 1.0
197
197
else :
198
198
print ('ERROR: Invalid Normalization Type' )
199
- return gn
199
+ return gn
0 commit comments