|
9 | 9 | from __future__ import print_function |
10 | 10 | from __future__ import unicode_literals |
11 | 11 |
|
| 12 | +import sys |
12 | 13 | import logging |
13 | 14 |
|
14 | 15 | import numpy as np |
@@ -174,12 +175,10 @@ def version_4(cls, ctx, node, **kwargs): |
174 | 175 | if perm.is_const(): |
175 | 176 | # perms is passed as const |
176 | 177 | dims = perm.get_tensor_value() |
| 178 | + ctx.remove_input(node, node.input[1]) |
| 179 | + node.set_attr("perm", dims) |
177 | 180 | else: |
178 | | - # calculate perms from shape |
179 | | - shape = ctx.get_shape(node.input[1]) |
180 | | - dims = [i for i in range(len(shape) - 1, -1)] |
181 | | - ctx.remove_input(node, node.input[1]) |
182 | | - node.set_attr("perm", dims) |
| 181 | + utils.make_sure(False, "perm can't be dynamic in ONNX") |
183 | 182 | else: |
184 | 183 | # graph rewrite moved perm to attribute |
185 | 184 | pass |
@@ -356,7 +355,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt |
356 | 355 | # reshape indices into [sum(indices[:-1]), indices[-1]] |
357 | 356 | indices_shape = ctx.make_node("Shape", [indices], dtypes=[TensorProto.INT64]) |
358 | 357 | indices_size = ctx.make_node("Size", [indices]) |
359 | | - attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [-1]} |
| 358 | + attr = {"axes": [0], "ends": [sys.maxsize], "starts": [-1]} |
360 | 359 | inputs_map = {"data": indices_shape.output[0], **attr} |
361 | 360 | inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64]) |
362 | 361 | outter_shape = ctx.make_node("Div", |
@@ -414,7 +413,7 @@ def make_gathernd(ctx, params, indices, output, scope_name, t_params, shapes, dt |
414 | 413 | [inner_loop_shape.output[0], one_const.output[0]], |
415 | 414 | attr={"axis": 0}, |
416 | 415 | dtypes=[TensorProto.INT64]) |
417 | | - attr = {"axes": [0], "ends": [utils.get_max_value(np.int64)], "starts": [1]} |
| 416 | + attr = {"axes": [0], "ends": [sys.maxsize], "starts": [1]} |
418 | 417 | inputs_map = {"data": inner_loop_shape_.output[0], **attr} |
419 | 418 | output_inner_shape = GraphBuilder(ctx).make_slice(inputs_map, dtypes=[TensorProto.INT64]) |
420 | 419 | attr = {"axes": [0], "ends": [-1], "starts": [0]} |
|
0 commit comments