Skip to content

Commit

Permalink
add compare tests
Browse files Browse the repository at this point in the history
  • Loading branch information
liuyubin0 committed Mar 6, 2025
1 parent 811094c commit 8cb4d0d
Show file tree
Hide file tree
Showing 6 changed files with 288 additions and 1 deletion.
59 changes: 59 additions & 0 deletions ntt/include/nncase/ntt/arch/riscv64/primitive_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -1548,6 +1548,65 @@ REGISTER_RVV_WHERE_OP(float)
REGISTER_RVV_KERNEL_4_1(EQUAL_FLOAT32)
REGISTER_RVV_COMPARE_OP(equal, float, equal_float32)

#define NOT_EQUAL_FLOAT32(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t not_equal_float32( \
const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \
const size_t vl) { \
auto mask = __riscv_vmfne_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \
auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \
}

REGISTER_RVV_KERNEL_4_1(NOT_EQUAL_FLOAT32)
REGISTER_RVV_COMPARE_OP(not_equal, float, not_equal_float32)

#define LESS_FLOAT32(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t less_float32(const vfloat32m##lmul1##_t &v1, \
const vfloat32m##lmul1##_t &v2, \
const size_t vl) { \
auto mask = __riscv_vmflt_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \
auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \
}

REGISTER_RVV_KERNEL_4_1(LESS_FLOAT32)
REGISTER_RVV_COMPARE_OP(less, float, less_float32)

#define LESS_OR_EQUAL_FLOAT32(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t less_or_equal_float32( \
const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \
const size_t vl) { \
auto mask = __riscv_vmfle_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \
auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \
}

REGISTER_RVV_KERNEL_4_1(LESS_OR_EQUAL_FLOAT32)
REGISTER_RVV_COMPARE_OP(less_or_equal, float, less_or_equal_float32)

#define GREATER_FLOAT32(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t greater_float32(const vfloat32m##lmul1##_t &v1, \
const vfloat32m##lmul1##_t &v2, \
const size_t vl) { \
auto mask = __riscv_vmfgt_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \
auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \
}

REGISTER_RVV_KERNEL_4_1(GREATER_FLOAT32)
REGISTER_RVV_COMPARE_OP(greater, float, greater_float32)

#define GREATER_OR_EQUAL_FLOAT32(lmul1, lmul2, mlen) \
inline vuint8m##lmul2##_t greater_or_equal_float32( \
const vfloat32m##lmul1##_t &v1, const vfloat32m##lmul1##_t &v2, \
const size_t vl) { \
auto mask = __riscv_vmfge_vv_f32m##lmul1##_b##mlen(v1, v2, vl); \
auto zeros = __riscv_vmv_v_x_u8m##lmul2(0, vl); \
return __riscv_vmerge_vxm_u8m##lmul2(zeros, 0xFF, mask, vl); \
}

REGISTER_RVV_KERNEL_4_1(GREATER_OR_EQUAL_FLOAT32)
REGISTER_RVV_COMPARE_OP(greater_or_equal, float, greater_or_equal_float32)

// scatterND

Expand Down
2 changes: 1 addition & 1 deletion ntt/test/ctest/test_ntt_compare_equal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ template <typename T, size_t vl> void test_vector() {
_TEST_VECTOR(T, 4) \
_TEST_VECTOR(T, 8)

TEST(UnaryTestAdd, vector) {
TEST(CompareTestEqual, vector) {
TEST_VECTOR(float)
// TEST_VECTOR(int32_t)
// TEST_VECTOR(int64_t)
Expand Down
57 changes: 57 additions & 0 deletions ntt/test/ctest/test_ntt_compare_greater.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>
#include <string_view>

using namespace nncase;
using namespace ortki;

template <typename T, size_t vl> void test_vector() {
ntt::vector<T, vl> ntt_lhs, ntt_rhs;
NttTest::init_tensor(ntt_lhs, static_cast<T>(-10), static_cast<T>(10));
NttTest::init_tensor(ntt_rhs, static_cast<T>(-10), static_cast<T>(10));
auto ntt_output1 = ntt::greater(ntt_lhs, ntt_rhs);
auto ort_lhs = NttTest::ntt2ort(ntt_lhs);
auto ort_rhs = NttTest::ntt2ort(ntt_rhs);
auto ort_output = ortki_Greater(ort_lhs, ort_rhs);
ntt::vector<bool, vl> ntt_output2;
NttTest::ort2ntt(ort_output, ntt_output2);
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

#define _TEST_VECTOR(T, lmul) \
test_vector<T, (NTT_VLEN) / (sizeof(T) * 8) * lmul>();

#define TEST_VECTOR(T) \
_TEST_VECTOR(T, 1) \
_TEST_VECTOR(T, 2) \
_TEST_VECTOR(T, 4) \
_TEST_VECTOR(T, 8)

TEST(CompareTestGreater, vector) {
TEST_VECTOR(float)
// TEST_VECTOR(int32_t)
// TEST_VECTOR(int64_t)
}

int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
57 changes: 57 additions & 0 deletions ntt/test/ctest/test_ntt_compare_greater_or_equal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>
#include <string_view>

using namespace nncase;
using namespace ortki;

template <typename T, size_t vl> void test_vector() {
ntt::vector<T, vl> ntt_lhs, ntt_rhs;
NttTest::init_tensor(ntt_lhs, static_cast<T>(-10), static_cast<T>(10));
NttTest::init_tensor(ntt_rhs, static_cast<T>(-10), static_cast<T>(10));
auto ntt_output1 = ntt::greater_or_equal(ntt_lhs, ntt_rhs);
auto ort_lhs = NttTest::ntt2ort(ntt_lhs);
auto ort_rhs = NttTest::ntt2ort(ntt_rhs);
auto ort_output = ortki_GreaterOrEqual(ort_lhs, ort_rhs);
ntt::vector<bool, vl> ntt_output2;
NttTest::ort2ntt(ort_output, ntt_output2);
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

#define _TEST_VECTOR(T, lmul) \
test_vector<T, (NTT_VLEN) / (sizeof(T) * 8) * lmul>();

#define TEST_VECTOR(T) \
_TEST_VECTOR(T, 1) \
_TEST_VECTOR(T, 2) \
_TEST_VECTOR(T, 4) \
_TEST_VECTOR(T, 8)

TEST(CompareTestGreaterOrEqual, vector) {
TEST_VECTOR(float)
// TEST_VECTOR(int32_t)
// TEST_VECTOR(int64_t)
}

int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
57 changes: 57 additions & 0 deletions ntt/test/ctest/test_ntt_compare_less.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>
#include <string_view>

using namespace nncase;
using namespace ortki;

template <typename T, size_t vl> void test_vector() {
ntt::vector<T, vl> ntt_lhs, ntt_rhs;
NttTest::init_tensor(ntt_lhs, static_cast<T>(-10), static_cast<T>(10));
NttTest::init_tensor(ntt_rhs, static_cast<T>(-10), static_cast<T>(10));
auto ntt_output1 = ntt::less(ntt_lhs, ntt_rhs);
auto ort_lhs = NttTest::ntt2ort(ntt_lhs);
auto ort_rhs = NttTest::ntt2ort(ntt_rhs);
auto ort_output = ortki_Less(ort_lhs, ort_rhs);
ntt::vector<bool, vl> ntt_output2;
NttTest::ort2ntt(ort_output, ntt_output2);
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

#define _TEST_VECTOR(T, lmul) \
test_vector<T, (NTT_VLEN) / (sizeof(T) * 8) * lmul>();

#define TEST_VECTOR(T) \
_TEST_VECTOR(T, 1) \
_TEST_VECTOR(T, 2) \
_TEST_VECTOR(T, 4) \
_TEST_VECTOR(T, 8)

TEST(CompareTestLess, vector) {
TEST_VECTOR(float)
// TEST_VECTOR(int32_t)
// TEST_VECTOR(int64_t)
}

int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}
57 changes: 57 additions & 0 deletions ntt/test/ctest/test_ntt_compare_less_or_equal.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright 2019-2024 Canaan Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ntt_test.h"
#include "ortki_helper.h"
#include <gtest/gtest.h>
#include <iostream>
#include <nncase/ntt/ntt.h>
#include <ortki/operators.h>
#include <string_view>

using namespace nncase;
using namespace ortki;

template <typename T, size_t vl> void test_vector() {
ntt::vector<T, vl> ntt_lhs, ntt_rhs;
NttTest::init_tensor(ntt_lhs, static_cast<T>(-10), static_cast<T>(10));
NttTest::init_tensor(ntt_rhs, static_cast<T>(-10), static_cast<T>(10));
auto ntt_output1 = ntt::less_or_equal(ntt_lhs, ntt_rhs);
auto ort_lhs = NttTest::ntt2ort(ntt_lhs);
auto ort_rhs = NttTest::ntt2ort(ntt_rhs);
auto ort_output = ortki_LessOrEqual(ort_lhs, ort_rhs);
ntt::vector<bool, vl> ntt_output2;
NttTest::ort2ntt(ort_output, ntt_output2);
EXPECT_TRUE(NttTest::compare_tensor(ntt_output1, ntt_output2));
}

#define _TEST_VECTOR(T, lmul) \
test_vector<T, (NTT_VLEN) / (sizeof(T) * 8) * lmul>();

#define TEST_VECTOR(T) \
_TEST_VECTOR(T, 1) \
_TEST_VECTOR(T, 2) \
_TEST_VECTOR(T, 4) \
_TEST_VECTOR(T, 8)

TEST(CompareTestLessOrEqual, vector) {
TEST_VECTOR(float)
// TEST_VECTOR(int32_t)
// TEST_VECTOR(int64_t)
}

int main(int argc, char *argv[]) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

0 comments on commit 8cb4d0d

Please sign in to comment.