Skip to content
This repository was archived by the owner on Sep 2, 2024. It is now read-only.

Commit d45eb26

Browse files
authored
Merge pull request #25 from shaanrockz/master
fix dot product for newer version of torch
2 parents 5787487 + f295df0 commit d45eb26

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

multi_task/min_norm_solvers.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -41,16 +41,16 @@ def _min_norm_2d(vecs, dps):
4141
if (i,j) not in dps:
4242
dps[(i, j)] = 0.0
4343
for k in range(len(vecs[i])):
44-
dps[(i,j)] += torch.dot(vecs[i][k], vecs[j][k]).data[0]
44+
dps[(i,j)] += torch.mul(vecs[i][k], vecs[j][k]).sum().data.cpu()
4545
dps[(j, i)] = dps[(i, j)]
4646
if (i,i) not in dps:
4747
dps[(i, i)] = 0.0
4848
for k in range(len(vecs[i])):
49-
dps[(i,i)] += torch.dot(vecs[i][k], vecs[i][k]).data[0]
49+
dps[(i,i)] += torch.mul(vecs[i][k], vecs[i][k]).sum().data.cpu()
5050
if (j,j) not in dps:
5151
dps[(j, j)] = 0.0
5252
for k in range(len(vecs[i])):
53-
dps[(j, j)] += torch.dot(vecs[j][k], vecs[j][k]).data[0]
53+
dps[(j, j)] += torch.mul(vecs[j][k], vecs[j][k]).sum().data.cpu()
5454
c,d = MinNormSolver._min_norm_element_from2(dps[(i,i)], dps[(i,j)], dps[(j,j)])
5555
if d < dmin:
5656
dmin = d
@@ -184,16 +184,16 @@ def gradient_normalizers(grads, losses, normalization_type):
184184
gn = {}
185185
if normalization_type == 'l2':
186186
for t in grads:
187-
gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
187+
gn[t] = np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
188188
elif normalization_type == 'loss':
189189
for t in grads:
190190
gn[t] = losses[t]
191191
elif normalization_type == 'loss+':
192192
for t in grads:
193-
gn[t] = losses[t] * np.sqrt(np.sum([gr.pow(2).sum().data[0] for gr in grads[t]]))
193+
gn[t] = losses[t] * np.sqrt(np.sum([gr.pow(2).sum().data.cpu() for gr in grads[t]]))
194194
elif normalization_type == 'none':
195195
for t in grads:
196196
gn[t] = 1.0
197197
else:
198198
print('ERROR: Invalid Normalization Type')
199-
return gn
199+
return gn

0 commit comments

Comments
 (0)