
Commit ab45708

Merge pull request #1086 from JanFSchulte/softmaxfix_torch
Fix softmax parsing in pytorch and add test
2 parents f9a2412 + a306e3f commit ab45708

File tree: 3 files changed, +26 / −3 lines

hls4ml/converters/pytorch/core.py

Lines changed: 11 additions & 2 deletions
@@ -62,9 +62,13 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
             layer['activation'] = 'ThresholdedReLU'
             if layer['activ_param'] < 0:
                 raise Exception('negative threshold values not supported')
-
-        if hasattr(node, 'dim'):
+        if hasattr(class_object, 'dim'):
             layer['axis'] = class_object.dim
+        if layer['class_name'] == 'Softmax' and layer['axis'] is None:
+            layer['axis'] = -1
+        if 'IOType' in config:
+            if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
+                raise Exception('dim needs to be -1 for io_stream')
     else:
         if layer['class_name'] in ['ReLU', 'Sigmoid', 'Tanh']:
             layer['class_name'] = 'Activation'
@@ -80,6 +84,11 @@ def parse_activation_layer(operation, layer_name, input_names, input_shapes, nod
             layer['activation'] = 'ThresholdedReLU'
         if 'dim' in node.kwargs:
             layer['axis'] = node.kwargs['dim']
+        if layer['class_name'] == 'Softmax' and layer['axis'] is None:
+            layer['axis'] = -1
+        if 'IOType' in config:
+            if layer['class_name'] == 'Softmax' and config['IOType'] == 'io_stream' and layer['axis'] != -1:
+                raise Exception('dim needs to be -1 for io_stream')

     output_shape = input_shapes[0]
     return layer, output_shape
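In short, both the module path and the functional path of the parser now treat a Softmax without an explicit `dim` as acting on the last axis, and reject a Softmax over any other axis when the project uses `io_stream`. A minimal usage sketch of where this surfaces during conversion; the exact `config_from_pytorch_model` / `convert_from_pytorch_model` signatures and the input-shape argument are assumptions and may differ between hls4ml versions:

```python
# Sketch only: converting a PyTorch Softmax through hls4ml.
# Signatures and the output_dir value are illustrative, not from this commit.
import torch.nn as nn
import hls4ml

model = nn.Sequential(nn.Linear(16, 8), nn.Softmax(dim=-1))
config = hls4ml.utils.config_from_pytorch_model(model)

# With io_type='io_stream' the parser now requires the softmax axis to be the
# last one; nn.Softmax(dim=0) here would raise "dim needs to be -1 for io_stream".
hls_model = hls4ml.converters.convert_from_pytorch_model(
    model,
    (None, 16),  # example input shape, assumed argument
    hls_config=config,
    io_type='io_stream',
    output_dir='hls4mlprj_softmax_example',  # illustrative path
)
```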

hls4ml/model/optimizer/passes/convert_to_channels_last.py

Lines changed: 5 additions & 1 deletion
@@ -94,7 +94,11 @@ def transform(self, model, node):
                 node.add_output_variable(shape, dims)

             # Have to transpose back before flattening to get correct order of elements in the flattened tensor
-            if isinstance(node, Reshape) and len(node.attributes['target_shape']) == 1:
+            if (
+                isinstance(node, Reshape)
+                and len(node.attributes['target_shape']) == 1
+                and not model.config.config['HLSConfig']['Model']['ChannelsLastConversion'] == "internal"
+            ):
                 previous_node = node.get_input_node(node.inputs[0])
                 input = previous_node.name
                 outshape = previous_node.get_output_variable().shape
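The extra transpose-before-flatten is now skipped when the channels-last handling is done "internally" rather than by the full conversion pass. A rough sketch of how that flag can end up in the key the pass checks; the `channels_last_conversion` keyword is an assumption about the config helper, so setting the dictionary key directly is shown as a fallback:

```python
# Sketch only: requesting 'internal' channels-last conversion so the Reshape
# special case above is bypassed. The keyword argument is an assumption;
# verify it against your hls4ml version.
config = hls4ml.utils.config_from_pytorch_model(model, channels_last_conversion='internal')
# Equivalent fallback: set the key the optimizer pass reads (stored under
# 'HLSConfig' once the model is converted).
config['Model']['ChannelsLastConversion'] = 'internal'
```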

test/pytest/test_pytorch_api.py

Lines changed: 10 additions & 0 deletions
@@ -63,6 +63,7 @@ def test_linear(backend, io_type):
 @pytest.mark.parametrize(
     "activation_function",
     [
+        nn.Softmax(dim=-1),
         nn.ReLU(),
         nn.Tanh(),
         nn.LeakyReLU(negative_slope=1.0),
@@ -119,6 +120,14 @@ def forward(self, x):
         return nn.functional.relu(x)


+class SoftmaxModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        return nn.functional.softmax(x, dim=-1)
+
+
 class TanHModel(nn.Module):
     def __init__(self):
         super().__init__()
@@ -162,6 +171,7 @@ def forward(self, x):
 @pytest.mark.parametrize(
     "activation_function",
     [
+        SoftmaxModel(),
         ReLuModel(),
         TanHModel(),
         LeakyReLuModel(),
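The new cases run through the same parametrized activation tests as the existing activations, once as an `nn.Softmax` module and once as the functional `SoftmaxModel`. Roughly, each case compares the hls4ml prediction against the PyTorch output on random input; a sketch of that check with illustrative names and tolerances (the real assertions live in test_pytorch_api.py):

```python
# Illustrative sketch of the comparison the parametrized test performs;
# function name and tolerance are assumptions, not copied from the test file.
import numpy as np
import torch


def check_against_pytorch(activation_function, hls_model, X_input):
    pytorch_prediction = activation_function(torch.Tensor(X_input)).detach().numpy()
    hls_prediction = hls_model.predict(X_input).reshape(pytorch_prediction.shape)
    np.testing.assert_allclose(hls_prediction, pytorch_prediction, rtol=0, atol=0.05)
```

Since the parametrized test ids contain the activation class name, just these cases can be selected with pytest's `-k` filter (e.g. `-k Softmax`).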
