From 8cb4d0d4d6b15666fbb43df8cf8b4e07bd1508e6 Mon Sep 17 00:00:00 2001 From: liuyubin0 Date: Thu, 6 Mar 2025 10:56:04 +0800 Subject: [PATCH] add compare tests --- .../nncase/ntt/arch/riscv64/primitive_ops.h | 59 +++++++++++++++++++ ntt/test/ctest/test_ntt_compare_equal.cpp | 2 +- ntt/test/ctest/test_ntt_compare_greater.cpp | 57 ++++++++++++++++++ .../test_ntt_compare_greater_or_equal.cpp | 57 ++++++++++++++++++ ntt/test/ctest/test_ntt_compare_less.cpp | 57 ++++++++++++++++++ .../ctest/test_ntt_compare_less_or_equal.cpp | 57 ++++++++++++++++++ 6 files changed, 288 insertions(+), 1 deletion(-) create mode 100644 ntt/test/ctest/test_ntt_compare_greater.cpp create mode 100644 ntt/test/ctest/test_ntt_compare_greater_or_equal.cpp create mode 100644 ntt/test/ctest/test_ntt_compare_less.cpp create mode 100644 ntt/test/ctest/test_ntt_compare_less_or_equal.cpp diff --git a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h index 381b3a20d..2dde915f9 100644 --- a/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h +++ b/ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h @@ -1548,6 +1548,65 @@ REGISTER_RVV_WHERE_OP(float) REGISTER_RVV_KERNEL_4_1(EQUAL_FLOAT32) REGISTER_RVV_COMPARE_OP(equal, float, equal_float32) +#define NOT_EQUAL_FLOAT32(lmul1, lmul2, mlen) \ + inline vuint8m##lmul2##_t not_equal_float32( \ + const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \ + const size_t vl) { \ + auto mask = __riscv_vmfne_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \ + auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \ + return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \ + } + +REGISTER_RVV_KERNEL_4_1(NOT_EQUAL_FLOAT32) +REGISTER_RVV_COMPARE_OP(not_equal, float, not_equal_float32) + +#define LESS_FLOAT32(lmul1, lmul2, mlen) \ + inline vuint8m##lmul2##_t less_float32(const vfloat32m##lmul1##_t &v1, \ + const vfloat32m##lmul1##_t &v2, \ + const size_t vl) { \ + auto mask = __riscv_vmflt_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \ + auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \ + return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \ + } + +REGISTER_RVV_KERNEL_4_1(LESS_FLOAT32) +REGISTER_RVV_COMPARE_OP(less, float, less_float32) + +#define LESS_OR_EQUAL_FLOAT32(lmul1, lmul2, mlen) \ + inline vuint8m##lmul2##_t less_or_equal_float32( \ + const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \ + const size_t vl) { \ + auto mask = __riscv_vmfle_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \ + auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \ + return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \ + } + +REGISTER_RVV_KERNEL_4_1(LESS_OR_EQUAL_FLOAT32) +REGISTER_RVV_COMPARE_OP(less_or_equal, float, less_or_equal_float32) + +#define GREATER_FLOAT32(lmul1, lmul2, mlen) \ + inline vuint8m##lmul2##_t greater_float32(const vfloat32m##lmul1##_t &v1, \ + const vfloat32m##lmul1##_t &v2, \ + const size_t vl) { \ + auto mask = __riscv_vmfgt_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \ + auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \ + return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \ + } + +REGISTER_RVV_KERNEL_4_1(GREATER_FLOAT32) +REGISTER_RVV_COMPARE_OP(greater, float, greater_float32) + +#define GREATER_OR_EQUAL_FLOAT32(lmul1, lmul2, mlen) \ + inline vuint8m##lmul2##_t greater_or_equal_float32( \ + const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \ + const size_t vl) { \ + auto mask = __riscv_vmfge_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \ + auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \ + return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \ + } + +REGISTER_RVV_KERNEL_4_1(GREATER_OR_EQUAL_FLOAT32) +REGISTER_RVV_COMPARE_OP(greater_or_equal, float, greater_or_equal_float32) // scatterND diff --git a/ntt/test/ctest/test_ntt_compare_equal.cpp b/ntt/test/ctest/test_ntt_compare_equal.cpp index 9e17fd498..66b6faee5 100644 --- a/ntt/test/ctest/test_ntt_compare_equal.cpp +++ b/ntt/test/ctest/test_ntt_compare_equal.cpp @@ -45,7 +45,7 @@ template void test_vector() { _TEST_VECTOR(T, 4) \ _TEST_VECTOR(T, 8) -TEST(UnaryTestAdd, vector) { +TEST(CompareTestEqual, vector) { TEST_VECTOR(float) // TEST_VECTOR(int32_t) // TEST_VECTOR(int64_t) diff --git a/ntt/test/ctest/test_ntt_compare_greater.cpp b/ntt/test/ctest/test_ntt_compare_greater.cpp new file mode 100644 index 000000000..52a062615 --- /dev/null +++ b/ntt/test/ctest/test_ntt_compare_greater.cpp @@ -0,0 +1,57 @@ +/* Copyright 2019-2024 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ntt_test.h" +#include "ortki_helper.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace ortki; + +template void test_vector() { + ntt::vector ntt_lhs, ntt_rhs; + NttTest::init_tensor(ntt_lhs, static_cast(-10), static_cast(10)); + NttTest::init_tensor(ntt_rhs, static_cast(-10), static_cast(10)); + auto ntt_output1 = ntt::greater(ntt_lhs, ntt_rhs); + auto ort_lhs = NttTest::ntt2ort(ntt_lhs); + auto ort_rhs = NttTest::ntt2ort(ntt_rhs); + auto ort_output = ortki_Greater(ort_lhs, ort_rhs); + ntt::vector ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + +#define _TEST_VECTOR(T, lmul) \ + test_vector(); + +#define TEST_VECTOR(T) \ + _TEST_VECTOR(T, 1) \ + _TEST_VECTOR(T, 2) \ + _TEST_VECTOR(T, 4) \ + _TEST_VECTOR(T, 8) + +TEST(CompareTestGreater, vector) { + TEST_VECTOR(float) + // TEST_VECTOR(int32_t) + // TEST_VECTOR(int64_t) +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/ntt/test/ctest/test_ntt_compare_greater_or_equal.cpp b/ntt/test/ctest/test_ntt_compare_greater_or_equal.cpp new file mode 100644 index 000000000..fe2173b3c --- /dev/null +++ b/ntt/test/ctest/test_ntt_compare_greater_or_equal.cpp @@ -0,0 +1,57 @@ +/* Copyright 2019-2024 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ntt_test.h" +#include "ortki_helper.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace ortki; + +template void test_vector() { + ntt::vector ntt_lhs, ntt_rhs; + NttTest::init_tensor(ntt_lhs, static_cast(-10), static_cast(10)); + NttTest::init_tensor(ntt_rhs, static_cast(-10), static_cast(10)); + auto ntt_output1 = ntt::greater_or_equal(ntt_lhs, ntt_rhs); + auto ort_lhs = NttTest::ntt2ort(ntt_lhs); + auto ort_rhs = NttTest::ntt2ort(ntt_rhs); + auto ort_output = ortki_GreaterOrEqual(ort_lhs, ort_rhs); + ntt::vector ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + +#define _TEST_VECTOR(T, lmul) \ + test_vector(); + +#define TEST_VECTOR(T) \ + _TEST_VECTOR(T, 1) \ + _TEST_VECTOR(T, 2) \ + _TEST_VECTOR(T, 4) \ + _TEST_VECTOR(T, 8) + +TEST(CompareTestGreaterOrEqual, vector) { + TEST_VECTOR(float) + // TEST_VECTOR(int32_t) + // TEST_VECTOR(int64_t) +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/ntt/test/ctest/test_ntt_compare_less.cpp b/ntt/test/ctest/test_ntt_compare_less.cpp new file mode 100644 index 000000000..af107b67a --- /dev/null +++ b/ntt/test/ctest/test_ntt_compare_less.cpp @@ -0,0 +1,57 @@ +/* Copyright 2019-2024 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ntt_test.h" +#include "ortki_helper.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace ortki; + +template void test_vector() { + ntt::vector ntt_lhs, ntt_rhs; + NttTest::init_tensor(ntt_lhs, static_cast(-10), static_cast(10)); + NttTest::init_tensor(ntt_rhs, static_cast(-10), static_cast(10)); + auto ntt_output1 = ntt::less(ntt_lhs, ntt_rhs); + auto ort_lhs = NttTest::ntt2ort(ntt_lhs); + auto ort_rhs = NttTest::ntt2ort(ntt_rhs); + auto ort_output = ortki_Less(ort_lhs, ort_rhs); + ntt::vector ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + +#define _TEST_VECTOR(T, lmul) \ + test_vector(); + +#define TEST_VECTOR(T) \ + _TEST_VECTOR(T, 1) \ + _TEST_VECTOR(T, 2) \ + _TEST_VECTOR(T, 4) \ + _TEST_VECTOR(T, 8) + +TEST(CompareTestLess, vector) { + TEST_VECTOR(float) + // TEST_VECTOR(int32_t) + // TEST_VECTOR(int64_t) +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/ntt/test/ctest/test_ntt_compare_less_or_equal.cpp b/ntt/test/ctest/test_ntt_compare_less_or_equal.cpp new file mode 100644 index 000000000..9f668db3a --- /dev/null +++ b/ntt/test/ctest/test_ntt_compare_less_or_equal.cpp @@ -0,0 +1,57 @@ +/* Copyright 2019-2024 Canaan Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "ntt_test.h" +#include "ortki_helper.h" +#include +#include +#include +#include +#include + +using namespace nncase; +using namespace ortki; + +template void test_vector() { + ntt::vector ntt_lhs, ntt_rhs; + NttTest::init_tensor(ntt_lhs, static_cast(-10), static_cast(10)); + NttTest::init_tensor(ntt_rhs, static_cast(-10), static_cast(10)); + auto ntt_output1 = ntt::less_or_equal(ntt_lhs, ntt_rhs); + auto ort_lhs = NttTest::ntt2ort(ntt_lhs); + auto ort_rhs = NttTest::ntt2ort(ntt_rhs); + auto ort_output = ortki_LessOrEqual(ort_lhs, ort_rhs); + ntt::vector ntt_output2; + NttTest::ort2ntt(ort_output, ntt_output2); + EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2)); +} + +#define _TEST_VECTOR(T, lmul) \ + test_vector(); + +#define TEST_VECTOR(T) \ + _TEST_VECTOR(T, 1) \ + _TEST_VECTOR(T, 2) \ + _TEST_VECTOR(T, 4) \ + _TEST_VECTOR(T, 8) + +TEST(CompareTestLessOrEqual, vector) { + TEST_VECTOR(float) + // TEST_VECTOR(int32_t) + // TEST_VECTOR(int64_t) +} + +int main(int argc, char *argv[]) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file