|  | 
| 23 | 23 | from smdebug.core.utils import SagemakerSimulator, ScriptSimulator | 
| 24 | 24 | 
 | 
| 25 | 25 | 
 | 
|  | 26 | +class CustomCrossEntropyLoss(nn.modules.loss._WeightedLoss): | 
|  | 27 | +    __constants__ = ["weight", "ignore_index", "reduction"] | 
|  | 28 | + | 
|  | 29 | +    def __init__( | 
|  | 30 | +        self, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction="mean" | 
|  | 31 | +    ): | 
|  | 32 | +        super(CustomCrossEntropyLoss, self).__init__(weight, size_average, reduce, reduction) | 
|  | 33 | +        self.ignore_index = ignore_index | 
|  | 34 | + | 
|  | 35 | +    def forward(self, input, target): | 
|  | 36 | +        return F.cross_entropy( | 
|  | 37 | +            input, | 
|  | 38 | +            target, | 
|  | 39 | +            weight=self.weight, | 
|  | 40 | +            ignore_index=self.ignore_index, | 
|  | 41 | +            reduction=self.reduction, | 
|  | 42 | +        ) | 
|  | 43 | + | 
|  | 44 | + | 
| 26 | 45 | @pytest.mark.skipif( | 
| 27 | 46 |     torch.__version__ == "1.7.0", | 
| 28 | 47 |     reason="Disabling the test temporarily until we root cause the version incompatibility", | 
| 29 | 48 | ) | 
| 30 | 49 | @pytest.mark.parametrize("script_mode", [False]) | 
| 31 | 50 | @pytest.mark.parametrize("use_loss_module", [True, False]) | 
| 32 |  | -def test_pytorch(script_mode, use_loss_module): | 
|  | 51 | +@pytest.mark.parametrize("use_custom_loss_module", [True, False]) | 
|  | 52 | +def test_pytorch(script_mode, use_loss_module, use_custom_loss_module): | 
| 33 | 53 |     smd.del_hook() | 
| 34 | 54 | 
 | 
| 35 | 55 |     sim_class = ScriptSimulator if script_mode else SagemakerSimulator | 
| 36 | 56 |     with sim_class() as sim: | 
| 37 | 57 |         trainloader, testloader = get_dataloaders() | 
| 38 | 58 |         net = Net() | 
| 39 |  | -        criterion = nn.CrossEntropyLoss() | 
|  | 59 | +        if use_custom_loss_module: | 
|  | 60 | +            criterion = CustomCrossEntropyLoss() | 
|  | 61 | +        else: | 
|  | 62 | +            criterion = nn.CrossEntropyLoss() | 
| 40 | 63 |         optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9) | 
| 41 | 64 | 
 | 
| 42 | 65 |         if script_mode: | 
|  | 
0 commit comments