Skip to content

Commit c1dd8d1

Browse files
committed
add custom loss
1 parent 311a6f4 commit c1dd8d1

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

smdebug/pytorch/hook.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Third Party
44
import torch
55
import torch.distributed as dist
6+
from torch.nn.modules.loss import _Loss
67

78
# First Party
89
from smdebug.core.collection import DEFAULT_PYTORCH_COLLECTIONS, CollectionKeys
@@ -154,6 +155,9 @@ def forward_hook(self, module, inputs, outputs):
154155
if not self._get_collections_to_save_for_step():
155156
return
156157

158+
if isinstance(module, _Loss):
159+
module._module_name = module._get_name()
160+
157161
module_name = module._module_name
158162
# This overwhelms the logs; turn back on if you really need it
159163
# logger.debug("Processing the global step {0} for module {1}".format(self.step, module_name))

tests/zero_code_change/test_pytorch_integration.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,20 +23,43 @@
2323
from smdebug.core.utils import SagemakerSimulator, ScriptSimulator
2424

2525

26+
class CustomCrossEntropyLoss(nn.modules.loss._WeightedLoss):
    """Cross-entropy criterion implemented as a user-defined loss module.

    Behaves exactly like ``nn.CrossEntropyLoss`` by delegating to
    ``F.cross_entropy``; exists so the test can exercise hook handling
    of custom ``_Loss`` subclasses rather than the built-in one.
    """

    __constants__ = ["weight", "ignore_index", "reduction"]

    def __init__(
        self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean"
    ):
        # _WeightedLoss resolves the legacy size_average/reduce pair into
        # the modern `reduction` string and registers `weight`.
        super().__init__(weight, size_average, reduce, reduction)
        self.ignore_index = ignore_index

    def forward(self, input, target):
        # Delegate to the functional API with this module's stored options,
        # matching nn.CrossEntropyLoss semantics.
        return F.cross_entropy(
            input, target, weight=self.weight,
            ignore_index=self.ignore_index, reduction=self.reduction
        )
43+
44+
2645
@pytest.mark.skipif(
2746
torch.__version__ == "1.7.0",
2847
reason="Disabling the test temporarily until we root cause the version incompatibility",
2948
)
3049
@pytest.mark.parametrize("script_mode", [False])
3150
@pytest.mark.parametrize("use_loss_module", [True, False])
32-
def test_pytorch(script_mode, use_loss_module):
51+
@pytest.mark.parametrize("use_custom_loss_module", [True, False])
52+
def test_pytorch(script_mode, use_loss_module, use_custom_loss_module):
3353
smd.del_hook()
3454

3555
sim_class = ScriptSimulator if script_mode else SagemakerSimulator
3656
with sim_class() as sim:
3757
trainloader, testloader = get_dataloaders()
3858
net = Net()
39-
criterion = nn.CrossEntropyLoss()
59+
if use_custom_loss_module:
60+
criterion = CustomCrossEntropyLoss()
61+
else:
62+
criterion = nn.CrossEntropyLoss()
4063
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
4164

4265
if script_mode:

0 commit comments

Comments (0)