Commit 5444ff1

Incorporate code review feedback on PR #694
Parent: 888e435

4 files changed: +13 −23 lines changed

opacus/grad_sample/embedding.py (1 addition, 1 deletion)

@@ -15,10 +15,10 @@
 
 from typing import Dict, List
 
-from opacus.grad_sample import embedding_norm_sample
 import torch
 import torch.nn as nn
 
+from opacus.grad_sample import embedding_norm_sample
 from .utils import register_grad_sampler, register_norm_sampler
 
 

opacus/grad_sample/embedding_norm_sample.py (4 additions, 6 deletions)

@@ -46,12 +46,10 @@ def compute_embedding_norm_sample(
     activations: [tensor([[1, 1],
           [2, 0],
           [2, 0]])]
-    backprops: tensor([[0.2000],
-          [0.2000],
-          [0.3000],
-          [0.1000],
-          [0.3000],
-          [0.1000]])
+    backprops: tensor([[[0.2], [0.2]],
+          [[0.3], [0.1]],
+          [[0.3], [0.1]]])
+    backprops.shape: torch.Size([3, 2, 1])
 
     Intermediate values:
     input_ids: tensor([[1, 1],
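As a sanity check on the corrected docstring example, here is a minimal sketch that reproduces the per-sample gradient norms these inputs imply. The naive per-sample scatter-add loop is an assumption for illustration only, not Opacus's vectorized implementation:

import torch

# Inputs copied from the docstring example above.
input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]])  # shape [3, 2]
backprops = torch.tensor(
    [[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]]
)  # shape [3, 2, 1]
num_embeddings, embedding_dim = 3, 1

norms = []
for ids, grads in zip(input_ids, backprops):
    # Accumulate gradients of repeated ids within one sample, then take the
    # L2 norm over the full [num_embeddings, embedding_dim] weight gradient.
    grad = torch.zeros(num_embeddings, embedding_dim)
    grad.index_add_(0, ids, grads)
    norms.append(grad.norm())

print(torch.stack(norms))  # tensor([0.4000, 0.3162, 0.3162])

Note how the new [3, 2, 1] backprops shape keeps one row per example, which is what makes the per-sample accumulation above well defined; the old flat [6, 1] shape lost the batch boundary.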

opacus/tests/grad_sample_module_fast_gradient_clipping_test.py (1 addition, 1 deletion)

@@ -351,7 +351,7 @@ def test_norm_calculation(self):
         diff = flat_norms_normal - flat_norms_gc
 
         logging.info(f"Diff = {diff}")
-        msg = "Fail: Gradients from vanilla DP-SGD and from fast gradient clipping are different"
+        msg = "Fail: Gradient norms from vanilla DP-SGD and from fast gradient clipping are different"
         assert torch.allclose(flat_norms_normal, flat_norms_gc, atol=1e-3), msg
 
     def test_gradient_calculation(self):

opacus/tests/grad_samples/embedding_norm_sample_test.py (7 additions, 15 deletions)

@@ -15,9 +15,9 @@
 
 import unittest
 
-from opacus.grad_sample import embedding_norm_sample
 import torch
 import torch.nn as nn
+from opacus.grad_sample import embedding_norm_sample
 
 
 class TestComputeEmbeddingNormSample(unittest.TestCase):
@@ -36,15 +36,11 @@ def test_compute_embedding_norm_sample(self):
         # Example input ids (activations). Shape: [3, 2]
         input_ids = torch.tensor([[1, 1], [2, 0], [2, 0]], dtype=torch.long)
 
-        # Example gradients with respect to the embedding output (backprops).
-        # Shape: [6, 1]
-        grad_values = torch.tensor(
-            [[0.2], [0.2], [0.3], [0.1], [0.3], [0.1]], dtype=torch.float32
+        # Example backprops. Shape: [3, 2, 1]
+        backprops = torch.tensor(
+            [[[0.2], [0.2]], [[0.3], [0.1]], [[0.3], [0.1]]], dtype=torch.float32
         )
 
-        # Simulate backprop through embedding layer
-        backprops = grad_values
-
         # Wrap input_ids in a list as expected by the norm sample function
         activations = [input_ids]
 
@@ -70,17 +66,17 @@ def test_compute_embedding_norm_sample_with_non_one_embedding_dim(self):
 
         # Manually set weights for the embedding layer for testing
         embedding_layer.weight = nn.Parameter(
-            torch.tensor([[0.1], [0.2], [0.3]], dtype=torch.float32)
+            torch.tensor([[0.1, 0.1], [0.2, 0.2], [0.3, 0.3]], dtype=torch.float32)
         )
 
         # Example input ids (activations). Shape: [6, 1, 1].
         input_ids = torch.tensor(
             [[[1]], [[1]], [[2]], [[0]], [[2]], [[0]]], dtype=torch.long
         )
 
-        # Example gradients per input id, with embedding_dim=2.
+        # Example backprops per input id, with embedding_dim=2.
         # Shape: [6, 1, 1, 2]
-        grad_values = torch.tensor(
+        backprops = torch.tensor(
             [
                 [[[0.2, 0.2]]],
                 [[[0.2, 0.2]]],
@@ -92,9 +88,6 @@ def test_compute_embedding_norm_sample_with_non_one_embedding_dim(self):
             dtype=torch.float32,
         )
 
-        # Simulate backprop through embedding layer
-        backprops = grad_values
-
         # Wrap input_ids in a list as expected by the grad norm function
         activations = [input_ids]
 
@@ -211,7 +204,6 @@ def test_compute_embedding_norm_sample_with_extra_activations_per_example(self):
         expected_norms = torch.tensor(
             [0.0150, 0.0071, 0.0005, 0.0081, 0.0039], dtype=torch.float32
         )
-        print("expected_norms: ", expected_norms)
         computed_norms = result[embedding_layer.weight]
 
         # Verify the computed norms match the expected norms
