Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions include/tvm/relax/attrs/vision.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*!
* \file tvm/relax/attrs/vision.h
* \brief Auxiliary attributes for vision operators.
*/
#ifndef TVM_RELAX_ATTRS_VISION_H_
#define TVM_RELAX_ATTRS_VISION_H_

#include <tvm/relax/expr.h>

namespace tvm {
namespace relax {

/*!
 * \brief Attributes used in AllClassNonMaximumSuppression operator.
 *
 * Holds a single string option, `output_format`, which is registered below
 * with a default of "onnx".
 */
struct AllClassNonMaximumSuppressionAttrs
: public tvm::AttrsNode<AllClassNonMaximumSuppressionAttrs> {
/*! \brief Requested output format; the description below names "onnx" and "tensorflow". */
String output_format;

// Declare the attrs node under its registry key and register its fields.
TVM_DECLARE_ATTRS(AllClassNonMaximumSuppressionAttrs,
"relax.attrs.AllClassNonMaximumSuppressionAttrs") {
// `output_format` falls back to "onnx" when the caller does not set it.
TVM_ATTR_FIELD(output_format)
.set_default("onnx")
.describe(
"Output format, onnx or tensorflow. Returns outputs in a way that can be easily "
"consumed by each frontend.");
}
}; // struct AllClassNonMaximumSuppressionAttrs

} // namespace relax
} // namespace tvm

#endif // TVM_RELAX_ATTRS_VISION_H_
51 changes: 50 additions & 1 deletion python/tvm/relax/frontend/onnx/onnx_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3109,6 +3109,55 @@ def _impl_v9(cls, bb, inputs, attr, params):
)


class NonMaxSuppression(OnnxOpConverter):
    """Converts an onnx NonMaxSuppression node into an equivalent Relax expression."""

    @classmethod
    def _impl_v10(cls, bb, inputs, attr, params):
        """Lower an ONNX NonMaxSuppression node to all-class NMS plus a trimming slice.

        Parameters
        ----------
        bb
            Block builder; used to normalize the NMS call so that its tuple
            outputs can be indexed.
        inputs
            Node inputs in order: boxes, scores, and then the optional
            max_output_boxes_per_class, iou_threshold and score_threshold.
            An optional input may be missing from the end of the list or be
            present as ``None``; both cases fall back to a zero default.
        attr
            Node attributes. Only ``center_point_box`` is read (default 0).
        params
            Unused here.

        Returns
        -------
        A ``dynamic_strided_slice`` of the first NMS output, starting at
        ``[0, 0]`` and ending at ``[nms_out[1], 3]``.
        """

        def optional_input(index, default_value, dtype):
            # A skipped optional input can show up either as a shorter input
            # list or as a None placeholder in the middle of the list; treat
            # both the same way instead of failing later on `None.struct_info`.
            if index < len(inputs) and inputs[index] is not None:
                return inputs[index]
            return relax.const([default_value], dtype)

        # Get parameter values
        boxes = inputs[0]
        scores = inputs[1]
        max_output_boxes_per_class = optional_input(2, 0, "int64")
        iou_threshold = optional_input(3, 0.0, "float32")
        score_threshold = optional_input(4, 0.0, "float32")

        boxes_dtype = boxes.struct_info.dtype
        if attr.get("center_point_box", 0) != 0:
            # Boxes are given as (x_center, y_center, width, height) along
            # axis 2; rewrite them as corner coordinates [y1, x1, y2, x2].
            xc, yc, w, h = relax.op.split(boxes, 4, axis=2)
            half_w = w / relax.const(2.0, boxes_dtype)
            half_h = h / relax.const(2.0, boxes_dtype)
            x1 = xc - half_w
            x2 = xc + half_w
            y1 = yc - half_h
            y2 = yc + half_h
            boxes = relax.op.concat([y1, x1, y2, x2], axis=2)

        def conditionally_squeeze_scalar(x):
            # Accept either a rank-0 value or a rank-1 value, dropping axis 0
            # of the latter so the NMS op always receives a scalar.
            rank = x.struct_info.ndim
            assert rank <= 1, "nms thresholds must be scalars"
            return relax.op.squeeze(x, [0]) if rank == 1 else x

        max_output_boxes_per_class = conditionally_squeeze_scalar(max_output_boxes_per_class)
        iou_threshold = conditionally_squeeze_scalar(iou_threshold)
        score_threshold = conditionally_squeeze_scalar(score_threshold)

        nms_out = bb.normalize(
            relax.op.vision.all_class_non_max_suppression(
                boxes,
                scores,
                max_output_boxes_per_class,
                iou_threshold,
                score_threshold,
            )
        )
        # Keep rows [0, nms_out[1]) and columns [0, 3) of the first output.
        # NOTE(review): presumably nms_out[1] is the number of valid selected
        # entries and the rest of nms_out[0] is padding -- confirm against the
        # operator definition.
        return relax.op.dynamic_strided_slice(
            nms_out[0],
            begin=relax.const([0, 0], dtype="int64"),
            end=relax.op.concat([nms_out[1], relax.const([3], dtype="int64")], axis=0),
            strides=relax.const([1, 1], dtype="int64"),
        )


class HardSigmoid(OnnxOpConverter):
"""Converts an onnx HardSigmoid node into an equivalent Relax expression."""

Expand Down Expand Up @@ -3499,7 +3548,7 @@ def _get_convert_map():
# "LRN": LRN,
# "MaxRoiPool": MaxRoiPool,
# "RoiAlign": RoiAlign,
# "NonMaxSuppression": NonMaxSuppression,
"NonMaxSuppression": NonMaxSuppression,
# "GridSample": GridSample,
# "Upsample": Upsample,
# others
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@
tanh,
trunc,
)
from .vision import all_class_non_max_suppression


def _register_op_make():
Expand Down
18 changes: 18 additions & 0 deletions python/tvm/relax/op/vision/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""VISION operators."""
from .nms import *
19 changes: 19 additions & 0 deletions python/tvm/relax/op/vision/_ffi_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constructor APIs"""
import tvm._ffi

# Initialize this module's FFI API using the "relax.op.vision" prefix.
# NOTE(review): presumably this exposes the functions registered under that
# prefix as attributes of this module -- confirm against tvm._ffi._init_api.
tvm._ffi._init_api("relax.op.vision", __name__)
Loading
Loading