Skip to content

Commit d678093

Browse files
neginraooffacebook-github-bot
authored andcommitted
[ONNX] Extend op registration to next opsets (pytorch#32943)
Summary: Currently, custom ops are registered for a specific opset version. For example, all torchvision custom ops are registered for opset 11, and cannot be exported into higher opset versions. This PR extends op registration to higher opset versions. Pull Request resolved: pytorch#32943 Reviewed By: hl475 Differential Revision: D19739406 Pulled By: houseroad fbshipit-source-id: dd8b616de3a69a529d135fdd02608a17a8e421bc
1 parent 3b2f267 commit d678093

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

Diff for: test/onnx/test_custom_ops.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def symbolic_custom_add(g, self, other):
4444
y = torch.randn(2, 3, 4, requires_grad=False)
4545

4646
model = CustomAddModel()
47-
onnxir, _ = do_export(model, (x, y))
47+
onnxir, _ = do_export(model, (x, y), opset_version=11)
4848
onnx_model = onnx.ModelProto.FromString(onnxir)
4949
prepared = c2.prepare(onnx_model)
5050
caffe2_out = prepared.run(inputs=[x.cpu().numpy(), y.cpu().numpy()])

Diff for: test/onnx/test_pytorch_onnx_onnxruntime.py

+1-4
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from torch.nn.utils import rnn as rnn_utils
1616
from model_defs.lstm_flattening_result import LstmFlatteningResult
1717
from model_defs.rnn_model_with_packed_sequence import RnnModelWithPackedSequence
18-
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfUnsupportedOpsetVersion, skipIfNoLapack, enableScriptTest
18+
from test_pytorch_common import skipIfUnsupportedMinOpsetVersion, skipIfNoLapack, enableScriptTest
1919
from test_pytorch_common import BATCH_SIZE
2020
from test_pytorch_common import RNN_BATCH_SIZE, RNN_SEQUENCE_LENGTH, RNN_INPUT_SIZE, RNN_HIDDEN_SIZE
2121
import model_defs.word_language_model as word_language_model
@@ -233,7 +233,6 @@ def run_word_language_model(self, model_name):
233233
# Only support CPU version, since tracer is not working in GPU RNN.
234234
self.run_test(model, (x, model.hidden))
235235

236-
@skipIfUnsupportedOpsetVersion([12])
237236
@skipIfUnsupportedMinOpsetVersion(11)
238237
def test_faster_rcnn(self):
239238
model = torchvision.models.detection.faster_rcnn.fasterrcnn_resnet50_fpn(pretrained=True, min_size=200,
@@ -273,15 +272,13 @@ def get_test_images(self):
273272
images = [image]
274273
return images
275274

276-
@skipIfUnsupportedOpsetVersion([12])
277275
@skipIfUnsupportedMinOpsetVersion(11)
278276
def test_mask_rcnn(self):
279277
model = torchvision.models.detection.mask_rcnn.maskrcnn_resnet50_fpn(pretrained=True, min_size=200,
280278
max_size=300)
281279
images = self.get_test_images()
282280
self.run_test(model, (images,), rtol=1e-3, atol=1e-5)
283281

284-
@skipIfUnsupportedOpsetVersion([12])
285282
@skipIfUnsupportedMinOpsetVersion(11)
286283
def test_keypoint_rcnn(self):
287284
class KeyPointRCNN(torch.nn.Module):

Diff for: torch/onnx/utils.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -860,7 +860,11 @@ def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version):
860860
raise RuntimeError("Failed to register operator {}. The domain {} is already a used domain."
861861
.format(symbolic_name, ns))
862862
import torch.onnx.symbolic_registry as sym_registry
863-
sym_registry.register_op(op_name, symbolic_fn, ns, opset_version)
863+
from torch.onnx.symbolic_helper import _onnx_stable_opsets
864+
865+
for version in _onnx_stable_opsets:
866+
if version >= opset_version:
867+
sym_registry.register_op(op_name, symbolic_fn, ns, version)
864868

865869
# This helper function ensures dynamic axes argument is following the expected format
866870
def _validate_dynamic_axes(dynamic_axes, model, input_names, output_names):

0 commit comments

Comments
 (0)