From 25c25dabcb8263f13c4dca80b2c8638c17d68e2e Mon Sep 17 00:00:00 2001 From: Zhaoqi Zhu Date: Wed, 3 Feb 2021 06:43:59 -0800 Subject: [PATCH] port & update pr16744 numpy gcd (#19547) * numpy-compatible gcd operator * use BinaryScalarRTCCompute * Update _op.py * Update np_elemwise_broadcast_op_extended.cc * fix * Update operator_tune.cc * fix kernel * add large tensor test * add gcd interoperability workload * Update test_numpy_interoperability.py * Update np_elemwise_broadcast_op_extended.cc * Update np_elemwise_broadcast_op_extended.cc * avoid ci linspce issue Co-authored-by: Hao Jin --- ci/docker/install/requirements | 2 +- python/mxnet/amp/lists/symbol_fp16.py | 2 + python/mxnet/ndarray/numpy/_op.py | 42 ++++++++++++++++- python/mxnet/numpy/multiarray.py | 40 +++++++++++++++- python/mxnet/numpy_dispatch_protocol.py | 1 + python/mxnet/symbol/numpy/_symbol.py | 33 ++++++++++++- .../numpy/np_elemwise_broadcast_op.cc | 8 ++++ src/common/cuda/rtc/forward_functions-inl.h | 43 +++++++++++++++++ src/operator/mshadow_op.h | 46 +++++++++++++++++++ .../np_elemwise_broadcast_op_extended.cc | 35 +++++++++++++- .../np_elemwise_broadcast_op_extended.cu | 6 +++ src/operator/operator_tune.cc | 1 + tests/nightly/test_np_large_array.py | 16 +++++++ .../unittest/test_numpy_interoperability.py | 7 +++ tests/python/unittest/test_numpy_op.py | 1 + 15 files changed, 278 insertions(+), 5 deletions(-) diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements index 8741e333feea..e3d90d994463 100644 --- a/ci/docker/install/requirements +++ b/ci/docker/install/requirements @@ -19,7 +19,7 @@ # the whole docker cache for the image # Required dependencies -numpy<1.20.0 +numpy>=1.17,<1.20.0 requests>=2.20.0,<3 graphviz<0.9.0,>=0.8.1 contextvars;python_version<"3.7" diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index 7242a702bb20..d58237b9e8de 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ 
b/python/mxnet/amp/lists/symbol_fp16.py @@ -245,6 +245,8 @@ '_npi_logistic', '_npi_lcm', '_npi_lcm_scalar', + '_npi_gcd', + '_npi_gcd_scalar', '_npi_linspace', '_npi_logical_not', '_npi_logical_and_scalar', diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py index 68eea624a1ed..a1f4bcf6cc7e 100644 --- a/python/mxnet/ndarray/numpy/_op.py +++ b/python/mxnet/ndarray/numpy/_op.py @@ -44,7 +44,7 @@ 'max', 'min', 'amax', 'amin', 'logical_and', 'logical_or', 'logical_xor', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', - 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', + 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd', 'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp', @@ -2081,6 +2081,46 @@ def expand_dims(a, axis): return _api_internal.expand_dims(a, axis) +@set_module('mxnet.ndarray.numpy') +@wrap_np_binary_func +def gcd(x1, x2, out=None, **kwargs): + """ + Returns the greatest common divisor of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing greatest common divisor. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). + + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. 
+ + Returns + ------- + y : ndarray or scalar + The greatest common divisor of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. + + See Also + -------- + lcm : The lowest common multiple + + Examples + -------- + >>> np.gcd(12, 20) + 4 + >>> np.gcd(np.arange(6, dtype=int), 20) + array([20, 1, 2, 1, 4, 5], dtype=int64) + """ + if isinstance(x1, numeric_types) and isinstance(x2, numeric_types): + return _np.gcd(x1, x2, out=out) + return _api_internal.gcd(x1, x2, out) + + @set_module('mxnet.ndarray.numpy') @wrap_np_binary_func def lcm(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index e260a41849ad..0f55c318d950 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -74,7 +74,7 @@ 'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot', 'triu_indices_from', 'triu_indices', 'tri', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', - 'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', + 'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron', 'equal', 'not_equal', 'interp', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul', @@ -3620,6 +3620,44 @@ def power(x1, x2, out=None, **kwargs): return _mx_nd_np.power(x1, x2, out=out) +@set_module('mxnet.numpy') +@wrap_np_binary_func +def gcd(x1, x2, out=None, **kwargs): + """ + Returns the greatest common divisor of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing greatest common divisor. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). 
+ + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + y : ndarray or scalar + The greatest common divisor of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. + + See Also + -------- + lcm : The lowest common multiple + + Examples + -------- + >>> np.gcd(12, 20) + 4 + >>> np.gcd(np.arange(6, dtype=int), 20) + array([20, 1, 2, 1, 4, 5], dtype=int64) + """ + return _mx_nd_np.gcd(x1, x2, out=out) + + @set_module('mxnet.numpy') @wrap_np_binary_func def lcm(x1, x2, out=None, **kwargs): diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 01f2b7257473..f047076e8193 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -249,6 +249,7 @@ def _register_array_function(): 'degrees', 'hypot', 'lcm', + 'gcd', # 'ldexp', 'subtract', 'multiply', diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py index c1df9729ecbf..821b6fa42d41 100644 --- a/python/mxnet/symbol/numpy/_symbol.py +++ b/python/mxnet/symbol/numpy/_symbol.py @@ -49,7 +49,7 @@ 'flatnonzero', 'tril_indices', 'amax', 'amin', 'max', 'min', 'logical_and', 'logical_or', 'logical_xor', 'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index', 'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr', - 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp', + 'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd', 'interp', 'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', @@ -1678,6 +1678,37 @@ def power(x1, x2, out=None, **kwargs): return _ufunc_helper(x1, x2, _npi.power, _np.power, _npi.power_scalar, _npi.rpower_scalar, out) +@set_module('mxnet.symbol.numpy') +@wrap_np_binary_func +def gcd(x1, x2, out=None, **kwargs): + """ + Returns the greatest common divisor of ``|x1|`` and ``|x2|`` + + Parameters + ---------- + x1, x2 : ndarrays or scalar values + The arrays for computing greatest common divisor. If x1.shape != x2.shape, + they must be broadcastable to a common shape (which may be the shape of + one or the other). + + out : ndarray or None, optional + A location into which the result is stored. If provided, it must have a shape + that the inputs broadcast to. If not provided or None, a freshly-allocated array + is returned. + + Returns + ------- + y : ndarray or scalar + The greatest common divisor of the absolute value of the inputs + This is a scalar if both `x1` and `x2` are scalars. 
+ + See Also + -------- + lcm : The lowest common multiple + """ + return _ufunc_helper(x1, x2, _npi.gcd, _np.gcd, _npi.gcd_scalar, None, out) + + @set_module('mxnet.symbol.numpy') @wrap_np_binary_func def matmul(a, b, out=None, **kwargs): diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc index f55f82c30a49..a411b067f1c0 100644 --- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc +++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc @@ -88,6 +88,14 @@ MXNET_REGISTER_API("_npi.lcm") UFuncHelper(args, ret, op, op_scalar, nullptr); }); +MXNET_REGISTER_API("_npi.gcd") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_gcd"); + const nnvm::Op* op_scalar = Op::Get("_npi_gcd_scalar"); + UFuncHelper(args, ret, op, op_scalar, nullptr); +}); + MXNET_REGISTER_API("_npi.logical_and") .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { using namespace runtime; diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h index 6568bae58318..f85916f1ef96 100644 --- a/src/common/cuda/rtc/forward_functions-inl.h +++ b/src/common/cuda/rtc/forward_functions-inl.h @@ -541,6 +541,49 @@ lcm(const DType a, const DType2 b) { } } +template +__device__ inline typename type_util::mixed_type::type +gcd(const DType a, const DType2 b) { + if (type_util::is_integral::value && + type_util::is_integral::value) { + DType A = a; + DType2 B = b; + // minus cases. + if (a < 0) { + A = -a; + } + if (b < 0) { + B = -b; + } + // handle zero-valued cases. 
+ DType c; + if (a == 0 && b != 0) { + c = B; + } else if (b == 0 && a != 0) { + c = A; + } else if (a == 0 && b == 0) { + c = 0; + } else { + DType tmp; + if (A < B) { + tmp = A; + A = B; + B = tmp; + } + while (A % B != 0) { + A = A % B; + tmp = A; + A = B; + B = tmp; + } + c = B; + } + return c; + } else { + return 0; + } +} + template __device__ inline typename type_util::mixed_type::type bitwise_xor(const DType a, const DType2 b) { diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 82170cdf8763..7c7c18f39c3c 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -1704,6 +1704,52 @@ struct nanprod_grad : public mxnet_op::tunable { #pragma GCC diagnostic ignored "-Wint-in-bool-context" #pragma GCC diagnostic ignored "-Wbool-compare" #endif + +/*! \brief used for computing binary greatest common divisor */ +struct gcd : public mxnet_op::tunable { + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + // minus cases. + if (a < 0) { + a = -a; + } + if (b < 0) { + b = -b; + } + // handle zero-valued cases. + DType c; + if (a == 0 && b != 0) { + c = b; + } else if (b == 0 && a != 0) { + c = a; + } else if (a == 0 && b == 0) { + c = 0; + } else { + DType tmp; + if (a < b) { + tmp = a; + a = b; + b = tmp; + } + while (a % b != 0) { + a = a % b; + tmp = a; + a = b; + b = tmp; + } + c = b; + } + return c; + } + + template + MSHADOW_XINLINE static typename enable_if::value, DType>::type + Map(DType a, DType b) { + return DType(0.0f); + } +}; + /*! 
\brief used for computing binary lowest common multiple */ struct lcm : public mxnet_op::tunable { template diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc index 90a48d4aee9f..188b6c853dbb 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc @@ -63,6 +63,39 @@ NNVM_REGISTER_OP(_backward_npi_copysign) .set_attr("FCompute", BinaryBroadcastBackwardUseIn); +NNVM_REGISTER_OP(_npi_gcd) +.set_num_inputs(2) +.set_num_outputs(1) +.set_attr("FListInputNames", +[](const NodeAttrs& attrs) { + return std::vector{"lhs", "rhs"}; +}) +.set_attr("FInferShape", BinaryBroadcastShape) +.set_attr("FInferType", ElemwiseIntType<2, 1>) +.set_attr("FInplaceOption", +[](const NodeAttrs& attrs){ + return std::vector >{{0, 0}, {1, 0}}; +}) +.set_attr("FGradient", MakeZeroGradNodes) +.set_attr("FCompute", BinaryBroadcastIntCompute) +.add_argument("lhs", "NDArray-or-Symbol", "First input to the function") +.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function"); + +NNVM_REGISTER_OP(_npi_gcd_scalar) +.set_num_inputs(1) +.set_num_outputs(1) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", ElemwiseShape<1, 1>) +.set_attr("FInferType", ElemwiseIntType<1, 1>) +.set_attr("FInplaceOption", + [](const NodeAttrs& attrs){ + return std::vector >{{0, 0}}; + }) +.set_attr("FGradient", MakeZeroGradNodes) +.add_argument("data", "NDArray-or-Symbol", "source input") +.add_arguments(NumpyBinaryScalarParam::__FIELDS__()) +.set_attr("FCompute", BinaryScalarOp::ComputeInt); + NNVM_REGISTER_OP(_npi_lcm) .set_num_inputs(2) .set_num_outputs(1) @@ -94,7 +127,7 @@ NNVM_REGISTER_OP(_npi_lcm_scalar) .set_attr("FGradient", MakeZeroGradNodes) .add_argument("data", "NDArray-or-Symbol", "source input") .add_arguments(NumpyBinaryScalarParam::__FIELDS__()) -.set_attr("FCompute", BinaryScalarOp::Compute); +.set_attr("FCompute", 
BinaryScalarOp::ComputeInt); NNVM_REGISTER_OP(_npi_bitwise_and) .set_num_inputs(2) diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu index b1d7e71bf17d..ff1cedff53db 100644 --- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu +++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu @@ -31,6 +31,9 @@ namespace op { NNVM_REGISTER_OP(_npi_copysign) .set_attr("FCompute", BinaryBroadcastRTCCompute{"copysign"}); +NNVM_REGISTER_OP(_npi_gcd) +.set_attr("FCompute", BinaryBroadcastRTCCompute{"gcd"}); + NNVM_REGISTER_OP(_npi_lcm) .set_attr("FCompute", BinaryBroadcastRTCCompute{"lcm"}); @@ -82,6 +85,9 @@ NNVM_REGISTER_OP(_npi_rarctan2_scalar) NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar) .set_attr("FCompute", BinaryScalarRTCBackward{"rarctan2_grad"}); +NNVM_REGISTER_OP(_npi_gcd_scalar) +.set_attr("FCompute", BinaryScalarRTCCompute{"gcd"}); + NNVM_REGISTER_OP(_npi_lcm_scalar) .set_attr("FCompute", BinaryScalarRTCCompute{"lcm"}); diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc index 9af336499d5c..557338e2b408 100644 --- a/src/operator/operator_tune.cc +++ b/src/operator/operator_tune.cc @@ -417,6 +417,7 @@ IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_xor); // IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient); // NOLINT() +IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::gcd); // NOLINT() IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>); // NOLINT() IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>); // NOLINT() diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py index 
a42325377482..fa103fdce699 100644 --- a/tests/nightly/test_np_large_array.py +++ b/tests/nightly/test_np_large_array.py @@ -906,6 +906,22 @@ def test_lcm(): assert inp1.grad[-1, -1] == 0 +@use_np +def test_gcd(): + inp1 = np.ones((2, INT_OVERFLOW), dtype='int32') + inp2 = np.ones((2, INT_OVERFLOW), dtype='int32') + inp1[-1, -1] = 12 + inp2[-1, -1] = 20 + inp1.attach_grad() + with mx.autograd.record(): + out = np.gcd(inp1, inp2) + out.backward() + assert out.shape == inp1.shape + assert out[-1, -1] == 4 + assert inp1.grad.shape == inp1.shape + assert inp1.grad[-1, -1] == 0 + + @use_np def test_log_family(): def batch_check(funcs, exp): diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py index a05bc79b9042..1fa7d5284399 100644 --- a/tests/python/unittest/test_numpy_interoperability.py +++ b/tests/python/unittest/test_numpy_interoperability.py @@ -1488,6 +1488,12 @@ def _add_workload_lcm(): OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) +def _add_workload_gcd(): + OpArgMngr.add_workload('gcd', np.array([24, 30], dtype=np.int8), np.array([20, 75], dtype=np.int8)) + OpArgMngr.add_workload('gcd', np.array([24, 30], dtype=np.uint8), np.array([20, 75], dtype=np.uint8)) + OpArgMngr.add_workload('gcd', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32)) + + def _add_workload_bitwise_or(): OpArgMngr.add_workload('bitwise_or', np.array([False, False, True, True], dtype=np.bool), np.array([False, True, False, True], dtype=np.bool)) @@ -3071,6 +3077,7 @@ def _prepare_workloads(): _add_workload_interp() _add_workload_hypot() _add_workload_lcm() + _add_workload_gcd() _add_workload_bitwise_and() _add_workload_bitwise_xor() _add_workload_bitwise_or() diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 650c420941e3..6bea5109b4c6 100644 --- a/tests/python/unittest/test_numpy_op.py +++ 
b/tests/python/unittest/test_numpy_op.py @@ -3021,6 +3021,7 @@ def hybrid_forward(self, F, a, b, *args, **kwargs): [[_np.float16, _np.float32, _np.float64], [_np.int32]]), 'power': (1.0, 3.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2], [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]), + 'gcd': (-100, 100, [None], None, [[_np.int32]]), 'lcm': (-100, 100, [None], None, [[_np.int32]]), 'bitwise_and': (-100, 100, [None], None, [[_np.int32]]), 'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),