diff --git a/README.md b/README.md
index 8ef586a..3dede46 100644
--- a/README.md
+++ b/README.md
@@ -31,6 +31,6 @@ Seer is an original, strong UCI chess engine. Seer relies on a neural network es
 The latest network can be found [here](https://github.com/connormcmonigle/seer-training/releases)
 ```
 cd build
-wget -O eval.bin https://github.com/connormcmonigle/seer-training/releases/download/0x35ddef41/q0x35ddef41.bin
+wget -O eval.bin https://github.com/connormcmonigle/seer-training/releases/download/0x2291e0ff/q0x2291e0ff.bin
 make pgo EVALFILE=eval.bin
 ```
diff --git a/build/makefile b/build/makefile
index c6b878c..175d2ec 100644
--- a/build/makefile
+++ b/build/makefile
@@ -2,7 +2,7 @@
 EXE = seer
 CXX = g++
 CXXSTANDARD = 17
-EVALFILE = weights/q0x35ddef41.bin
+EVALFILE = weights/q0x2291e0ff.bin
 OPSLIMIT = 1000000000
 
 CXXSRC += $(wildcard ../src/*.cc )
diff --git a/include/nnue/dense_relu_affine_layer.h b/include/nnue/dense_relu_affine_layer.h
index a4da958..12bceea 100644
--- a/include/nnue/dense_relu_affine_layer.h
+++ b/include/nnue/dense_relu_affine_layer.h
@@ -22,49 +22,63 @@
 #include
 #include
+#include
 #include
 
 namespace nnue {
 
-template <std::size_t dim0, std::size_t dim1, typename T>
+template <std::size_t dim0, std::size_t dim1, typename T, typename O = dot_type<T>>
 struct dense_relu_affine_layer {
   static constexpr std::size_t W_numel = dim0 * dim1;
   static constexpr std::size_t b_numel = dim1;
 
   alignas(simd::alignment) T W[W_numel];
-  alignas(simd::alignment) dot_type<T> b[b_numel];
+  alignas(simd::alignment) O b[b_numel];
 
   [[nodiscard]] constexpr std::size_t num_parameters() const noexcept { return W_numel + b_numel; }
 
-  [[nodiscard]] inline aligned_vector<dot_type<T>, dim1> forward(const aligned_vector<T, dim0>& x) const noexcept {
-    auto result = aligned_vector<dot_type<T>, dim1>::from(b);
+  [[nodiscard]] inline aligned_vector<O, dim1> forward_relu(const aligned_vector<T, dim0>& x) const noexcept {
+    auto result = aligned_vector<O, dim1>::from(b);
     simd::relu_matrix_vector_product<dim0, dim1>(W, x.data, result.data);
     return result;
   }
 
-  [[nodiscard]] inline aligned_vector<dot_type<T>, dim1> forward(const aligned_slice<T, dim0>& x) const noexcept {
-    auto result = aligned_vector<dot_type<T>, dim1>::from(b);
+  [[nodiscard]] inline aligned_vector<O, dim1> forward_relu(const aligned_slice<T, dim0>& x) const noexcept {
+    auto result = aligned_vector<O, dim1>::from(b);
     simd::relu_matrix_vector_product<dim0, dim1>(W, x.data, result.data);
     return result;
   }
 
+  [[nodiscard]] inline aligned_vector<O, dim1> forward_crelu255(const aligned_vector<T, dim0>& x) const noexcept {
+    auto result = aligned_vector<O, dim1>::from(b);
+    simd::crelu255_matrix_vector_product<dim0, dim1>(W, x.data, result.data);
+    return result;
+  }
+
+  [[nodiscard]] inline aligned_vector<O, dim1> forward_crelu255(const aligned_slice<T, dim0>& x) const noexcept {
+    auto result = aligned_vector<O, dim1>::from(b);
+    simd::crelu255_matrix_vector_product<dim0, dim1>(W, x.data, result.data);
+    return result;
+  }
+
   template <typename streamer_type>
-  [[maybe_unused]] dense_relu_affine_layer<dim0, dim1, T>& load_(streamer_type& streamer) noexcept {
-    streamer.template stream<T>(W, W_numel).template stream<dot_type<T>>(b, b_numel);
+  [[maybe_unused]] dense_relu_affine_layer<dim0, dim1, T, O>& load_(streamer_type& streamer) noexcept {
+    streamer.template stream<T>(W, W_numel).template stream<O>(b, b_numel);
     return *this;
   }
 
   template <typename exporter_type>
-  [[maybe_unused]] const dense_relu_affine_layer<dim0, dim1, T>& write_(exporter_type& exporter) const noexcept {
-    exporter.template write<T>(W, W_numel).template write<dot_type<T>>(b, b_numel);
+  [[maybe_unused]] const dense_relu_affine_layer<dim0, dim1, T, O>& write_(exporter_type& exporter) const noexcept {
+    exporter.template write<T>(W, W_numel).template write<O>(b, b_numel);
     return *this;
   }
 
-  [[nodiscard]] dense_relu_affine_layer<dim0, dim1, T> half_input_flipped() const noexcept {
+  [[nodiscard]] dense_relu_affine_layer<dim0, dim1, T, O> half_input_flipped() const noexcept {
     static_assert(dim0 % 2 == 0);
     constexpr std::size_t half_dim0 = dim0 / 2;
-    dense_relu_affine_layer<dim0, dim1, T> result = *this;
+    dense_relu_affine_layer<dim0, dim1, T, O> result = *this;
+
     for (std::size_t i(0); i < W_numel; i += dim0) {
       for (std::size_t j(0); j < half_dim0; ++j) { std::iter_swap(result.W + i + j, result.W + half_dim0 + i + j); }
     }
@@ -72,13 +86,23 @@ struct dense_relu_affine_layer {
     return result;
   }
 
-  template <typename Q>
-  [[nodiscard]] dense_relu_affine_layer<dim0, dim1, Q> quantized(const T& weight_scale, const T& bias_scale) const noexcept {
-    static_assert(std::is_floating_point_v<T> && std::is_integral_v<Q>);
-    dense_relu_affine_layer<dim0, dim1, Q> result{};
-#pragma omp simd
-    for (std::size_t i = 0; i < W_numel; ++i) { result.W[i] = static_cast<Q>(std::round(weight_scale * W[i])); }
-    for (std::size_t i = 0; i < b_numel; ++i) { result.b[i] = static_cast<dot_type<Q>>(std::round(bias_scale * b[i])); }
+  template <typename Q, typename P = dot_type<Q>>
+  [[nodiscard]] dense_relu_affine_layer<dim0, dim1, Q, P> quantized(const T& weight_scale, const T& bias_scale) const noexcept {
+    static_assert(std::is_floating_point_v<T> && std::is_integral_v<Q> && std::is_integral_v<P> && std::is_integral_v<dot_type<Q>>);
+    dense_relu_affine_layer<dim0, dim1, Q, P> result{};
+
+    for (std::size_t i = 0; i < W_numel; ++i) {
+      const float lower_limit = static_cast<float>(std::numeric_limits<Q>::min());
+      const float upper_limit = static_cast<float>(std::numeric_limits<Q>::max());
+      result.W[i] = static_cast<Q>(std::clamp(std::round(weight_scale * W[i]), lower_limit, upper_limit));
+    }
+
+    for (std::size_t i = 0; i < b_numel; ++i) {
+      const float lower_limit = static_cast<float>(std::numeric_limits<P>::min());
+      const float upper_limit = static_cast<float>(std::numeric_limits<P>::max());
+      result.b[i] = static_cast<P>(std::clamp(std::round(bias_scale * b[i]), lower_limit, upper_limit));
+    }
+
     return result;
   }
 };
diff --git a/include/nnue/eval.h b/include/nnue/eval.h
index e429c9e..be7c1c1 100644
--- a/include/nnue/eval.h
+++ b/include/nnue/eval.h
@@ -90,10 +90,10 @@ struct eval : public chess::sided
   template <typename F>
   [[nodiscard]] inline propagate_data> propagate(const bool pov, F&& final_output_encoder) const noexcept {
-    const auto x1 = (pov ? weights_->white_fc0 : weights_->black_fc0).forward(base_).dequantized(weights::dequantization_scale);
-    const auto x2 = concat(x1, weights_->fc1.forward(x1));
-    const auto x3 = concat(x2, weights_->fc2.forward(x2));
-    return propagate_data(final_output_encoder(x3), weights_->fc3.forward(x3).item());
+    const auto x1 = (pov ?
weights_->white_fc0 : weights_->black_fc0).forward_crelu255(base_).dequantized(weights::dequantization_scale); + const auto x2 = concat(x1, weights_->fc1.forward_relu(x1)); + const auto x3 = concat(x2, weights_->fc2.forward_relu(x2)); + return propagate_data(final_output_encoder(x3), weights_->fc3.forward_relu(x3).item()); } template diff --git a/include/nnue/simd.h b/include/nnue/simd.h index a2ad3d4..a82164c 100644 --- a/include/nnue/simd.h +++ b/include/nnue/simd.h @@ -43,15 +43,6 @@ inline void aligned_free(void* ptr) { #endif } -#if defined(__AVX512BW__) -struct vector_512 { - using integral_type = __m512i; - using float_type = __m512; - static_assert(sizeof(integral_type) == sizeof(float_type)); - static constexpr std::size_t size = sizeof(integral_type); -}; -#endif - #if defined(__AVX2__) struct vector_256 { using integral_type = __m256i; @@ -70,9 +61,7 @@ struct vector_128 { }; #endif -#if defined(__AVX512BW__) -constexpr std::size_t alignment = vector_512::size; -#elif defined(__AVX2__) +#if defined(__AVX2__) constexpr std::size_t alignment = vector_256::size; #elif defined(__SSSE3__) constexpr std::size_t alignment = vector_128::size; @@ -142,299 +131,15 @@ inline void relu_matrix_vector_product(const T0* matrix, const T0* input, T1* ou } } -#if defined(__AVX512BW__) -template -struct int16_add_x128 { - static constexpr std::size_t num_units = 4; - static constexpr bool available = divides>; - - static inline void f(std::int16_t* a, const std::int16_t* b) noexcept { - for (std::size_t i(0); i < dim; i += num_units * per_unit) { - __m512i* a_0 = (__m512i*)(a + i + 0 * per_unit); - *a_0 = _mm512_add_epi16(*a_0, _mm512_load_si512((__m512i*)(b + i + 0 * per_unit))); - - __m512i* a_1 = (__m512i*)(a + i + 1 * per_unit); - *a_1 = _mm512_add_epi16(*a_1, _mm512_load_si512((__m512i*)(b + i + 1 * per_unit))); - - __m512i* a_2 = (__m512i*)(a + i + 2 * per_unit); - *a_2 = _mm512_add_epi16(*a_2, _mm512_load_si512((__m512i*)(b + i + 2 * per_unit))); - - __m512i* a_3 = (__m512i*)(a + i + 3 * per_unit); - *a_3 = _mm512_add_epi16(*a_3, _mm512_load_si512((__m512i*)(b + i + 3 * per_unit))); - } - } -}; - -template -inline void add(std::int16_t* a, const std::int16_t* b) noexcept { - return overload_set>::f(a, b); -} - -template -struct int16_sub_x128 { - static constexpr std::size_t num_units = 4; - static constexpr bool available = divides>; - - static inline void f(std::int16_t* a, const std::int16_t* b) noexcept { - for (std::size_t i(0); i < dim; i += num_units * per_unit) { - __m512i* a_0 = (__m512i*)(a + i + 0 * per_unit); - *a_0 = _mm512_sub_epi16(*a_0, _mm512_load_si512((__m512i*)(b + i + 0 * per_unit))); - - __m512i* a_1 = (__m512i*)(a + i + 1 * per_unit); - *a_1 = _mm512_sub_epi16(*a_1, _mm512_load_si512((__m512i*)(b + i + 1 * per_unit))); - - __m512i* a_2 = (__m512i*)(a + i + 2 * per_unit); - *a_2 = _mm512_sub_epi16(*a_2, _mm512_load_si512((__m512i*)(b + i + 2 * per_unit))); - - __m512i* a_3 = (__m512i*)(a + i + 3 * per_unit); - *a_3 = _mm512_sub_epi16(*a_3, _mm512_load_si512((__m512i*)(b + i + 3 * per_unit))); - } - } -}; - -template -inline void sub(std::int16_t* a, const std::int16_t* b) { - return overload_set>::f(a, b); -} - -template -struct int16_add_add_sub_x128 { - static constexpr std::size_t num_units = 4; - static constexpr bool available = divides>; - - static inline void f(const std::int16_t* a_0, const std::int16_t* a_1, const std::int16_t* s_0, std::int16_t* out) noexcept { - for (std::size_t i(0); i < dim; i += num_units * per_unit) { - { - const __m512i a_0_0 = 
_mm512_load_si512((__m512i*)(a_0 + i + 0 * per_unit)); - const __m512i a_1_0 = _mm512_load_si512((__m512i*)(a_1 + i + 0 * per_unit)); - const __m512i s_0_0 = _mm512_load_si512((__m512i*)(s_0 + i + 0 * per_unit)); - __m512i* out_0 = (__m512i*)(out + i + 0 * per_unit); - *out_0 = _mm512_add_epi16(a_0_0, _mm512_sub_epi16(a_1_0, s_0_0)); - } - - { - const __m512i a_0_1 = _mm512_load_si512((__m512i*)(a_0 + i + 1 * per_unit)); - const __m512i a_1_1 = _mm512_load_si512((__m512i*)(a_1 + i + 1 * per_unit)); - const __m512i s_0_1 = _mm512_load_si512((__m512i*)(s_0 + i + 1 * per_unit)); - __m512i* out_1 = (__m512i*)(out + i + 1 * per_unit); - *out_1 = _mm512_add_epi16(a_0_1, _mm512_sub_epi16(a_1_1, s_0_1)); - } - - { - const __m512i a_0_2 = _mm512_load_si512((__m512i*)(a_0 + i + 2 * per_unit)); - const __m512i a_1_2 = _mm512_load_si512((__m512i*)(a_1 + i + 2 * per_unit)); - const __m512i s_0_2 = _mm512_load_si512((__m512i*)(s_0 + i + 2 * per_unit)); - __m512i* out_2 = (__m512i*)(out + i + 2 * per_unit); - *out_2 = _mm512_add_epi16(a_0_2, _mm512_sub_epi16(a_1_2, s_0_2)); - } - - { - const __m512i a_0_3 = _mm512_load_si512((__m512i*)(a_0 + i + 3 * per_unit)); - const __m512i a_1_3 = _mm512_load_si512((__m512i*)(a_1 + i + 3 * per_unit)); - const __m512i s_0_3 = _mm512_load_si512((__m512i*)(s_0 + i + 3 * per_unit)); - __m512i* out_3 = (__m512i*)(out + i + 3 * per_unit); - *out_3 = _mm512_add_epi16(a_0_3, _mm512_sub_epi16(a_1_3, s_0_3)); - } - } - } -}; - -template -inline void add_add_sub(const std::int16_t* a_0, const std::int16_t* a_1, const std::int16_t* s_0, std::int16_t* out) noexcept { - return overload_set>::f(a_0, a_1, s_0, out); -} - -template -struct int16_add_add_sub_sub_x128 { - static constexpr std::size_t num_units = 4; - static constexpr bool available = divides>; - - static inline void - f(const std::int16_t* a_0, const std::int16_t* a_1, const std::int16_t* s_0, const std::int16_t* s_1, std::int16_t* out) noexcept { - for (std::size_t i(0); i < dim; i += num_units * per_unit) { - { - const __m512i a_0_0 = _mm512_load_si512((__m512i*)(a_0 + i + 0 * per_unit)); - const __m512i a_1_0 = _mm512_load_si512((__m512i*)(a_1 + i + 0 * per_unit)); - const __m512i s_0_0 = _mm512_load_si512((__m512i*)(s_0 + i + 0 * per_unit)); - const __m512i s_1_0 = _mm512_load_si512((__m512i*)(s_1 + i + 0 * per_unit)); - __m512i* out_0 = (__m512i*)(out + i + 0 * per_unit); - *out_0 = _mm512_add_epi16(_mm512_sub_epi16(a_0_0, s_0_0), _mm512_sub_epi16(a_1_0, s_1_0)); - } - - { - const __m512i a_0_1 = _mm512_load_si512((__m512i*)(a_0 + i + 1 * per_unit)); - const __m512i a_1_1 = _mm512_load_si512((__m512i*)(a_1 + i + 1 * per_unit)); - const __m512i s_0_1 = _mm512_load_si512((__m512i*)(s_0 + i + 1 * per_unit)); - const __m512i s_1_1 = _mm512_load_si512((__m512i*)(s_1 + i + 1 * per_unit)); - __m512i* out_1 = (__m512i*)(out + i + 1 * per_unit); - *out_1 = _mm512_add_epi16(_mm512_sub_epi16(a_0_1, s_0_1), _mm512_sub_epi16(a_1_1, s_1_1)); - } - - { - const __m512i a_0_2 = _mm512_load_si512((__m512i*)(a_0 + i + 2 * per_unit)); - const __m512i a_1_2 = _mm512_load_si512((__m512i*)(a_1 + i + 2 * per_unit)); - const __m512i s_0_2 = _mm512_load_si512((__m512i*)(s_0 + i + 2 * per_unit)); - const __m512i s_1_2 = _mm512_load_si512((__m512i*)(s_1 + i + 2 * per_unit)); - __m512i* out_2 = (__m512i*)(out + i + 2 * per_unit); - *out_2 = _mm512_add_epi16(_mm512_sub_epi16(a_0_2, s_0_2), _mm512_sub_epi16(a_1_2, s_1_2)); - } - - { - const __m512i a_0_3 = _mm512_load_si512((__m512i*)(a_0 + i + 3 * per_unit)); - const __m512i a_1_3 = 
_mm512_load_si512((__m512i*)(a_1 + i + 3 * per_unit)); - const __m512i s_0_3 = _mm512_load_si512((__m512i*)(s_0 + i + 3 * per_unit)); - const __m512i s_1_3 = _mm512_load_si512((__m512i*)(s_1 + i + 3 * per_unit)); - __m512i* out_3 = (__m512i*)(out + i + 3 * per_unit); - *out_3 = _mm512_add_epi16(_mm512_sub_epi16(a_0_3, s_0_3), _mm512_sub_epi16(a_1_3, s_1_3)); - } - } - } -}; - -template -inline void -add_add_sub_sub(const std::int16_t* a_0, const std::int16_t* a_1, const std::int16_t* s_0, const std::int16_t* s_1, std::int16_t* out) noexcept { - return overload_set>::f(a_0, a_1, s_0, s_1, out); -} - -template -struct float_relu_matrix_vector_product_x8_x1 { - static constexpr bool available = divides>; - - static inline void f(const float* matrix, const float* input, float* output) { - const __m256 zero = _mm256_setzero_ps(); - for (std::size_t i(0); i < dim1; ++i) { - __m256 sum = _mm256_setzero_ps(); - - for (std::size_t j(0); j < dim0; j += per_unit) { - const __m256 input_region = _mm256_max_ps(zero, _mm256_load_ps(input + j)); - sum = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + i * dim0 + j), input_region), sum); - } - - const __m128 reduced_4 = _mm_add_ps(_mm256_castps256_ps128(sum), _mm256_extractf128_ps(sum, 0x1)); - const __m128 reduced_2 = _mm_add_ps(reduced_4, _mm_movehl_ps(reduced_4, reduced_4)); - const __m128 reduced_1 = _mm_add_ss(reduced_2, _mm_shuffle_ps(reduced_2, reduced_2, 0x1)); - - output[i] += _mm_cvtss_f32(reduced_1); - } - } -}; - -template -struct float_relu_matrix_vector_product_x8_x8 { - static constexpr std::size_t num_units = 8; - static constexpr bool available = divides && divides>; - - static inline void f(const float* matrix, const float* input, float* output) noexcept { - const __m256 zero = _mm256_setzero_ps(); - __m256* v_output = (__m256*)output; - constexpr std::size_t output_step = num_units / per_unit; - for (std::size_t i(0); i < dim1; i += num_units, v_output += output_step) { - __m256 sum_0 = _mm256_setzero_ps(); - __m256 sum_1 = _mm256_setzero_ps(); - __m256 sum_2 = _mm256_setzero_ps(); - __m256 sum_3 = _mm256_setzero_ps(); - __m256 sum_4 = _mm256_setzero_ps(); - __m256 sum_5 = _mm256_setzero_ps(); - __m256 sum_6 = _mm256_setzero_ps(); - __m256 sum_7 = _mm256_setzero_ps(); - - for (std::size_t j(0); j < dim0; j += per_unit) { - const __m256 input_region = _mm256_max_ps(zero, _mm256_load_ps(input + j)); - sum_0 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 0) * dim0 + j), input_region), sum_0); - sum_1 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 1) * dim0 + j), input_region), sum_1); - sum_2 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 2) * dim0 + j), input_region), sum_2); - sum_3 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 3) * dim0 + j), input_region), sum_3); - sum_4 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 4) * dim0 + j), input_region), sum_4); - sum_5 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 5) * dim0 + j), input_region), sum_5); - sum_6 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 6) * dim0 + j), input_region), sum_6); - sum_7 = _mm256_add_ps(_mm256_mul_ps(_mm256_load_ps(matrix + (i + 7) * dim0 + j), input_region), sum_7); - } - - const __m256 sum_01 = _mm256_hadd_ps(sum_0, sum_1); - const __m256 sum_23 = _mm256_hadd_ps(sum_2, sum_3); - const __m256 sum_45 = _mm256_hadd_ps(sum_4, sum_5); - const __m256 sum_67 = _mm256_hadd_ps(sum_6, sum_7); - - const __m256 sum_0123 = _mm256_hadd_ps(sum_01, sum_23); - const __m256 sum_4567 = 
_mm256_hadd_ps(sum_45, sum_67); - - const __m256 sum_01234567 = _mm256_add_ps(_mm256_permute2f128_ps(sum_0123, sum_4567, 0x20), _mm256_permute2f128_ps(sum_0123, sum_4567, 0x31)); - - *v_output = _mm256_add_ps(*v_output, sum_01234567); - } - } -}; - -template -struct int16_relu_matrix_vector_product_x32_x8 { - static constexpr std::size_t num_units = 8; - static constexpr bool available = divides && divides>; - - static inline void f(const std::int16_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { - const __m512i zero = _mm512_setzero_si512(); - - const __m512i mm512_unpacklo_epi128_permutationx2var = - _mm512_set_epi32(0x17, 0x16, 0x15, 0x14, 0x07, 0x06, 0x05, 0x04, 0x13, 0x12, 0x11, 0x10, 0x03, 0x02, 0x01, 0x00); - - const __m512i mm512_unpackhi_epi128_permutationx2var = - _mm512_set_epi32(0x1f, 0x1e, 0x1d, 0x1c, 0x0f, 0x0e, 0x0d, 0x0c, 0x1b, 0x1a, 0x19, 0x18, 0x0b, 0x0a, 0x09, 0x08); - - __m256i* v_output = (__m256i*)output; - constexpr std::size_t output_step = num_units / per_unit; - for (std::size_t i(0); i < dim1; i += num_units, v_output += output_step) { - __m512i sum_0 = _mm512_setzero_si512(); - __m512i sum_1 = _mm512_setzero_si512(); - __m512i sum_2 = _mm512_setzero_si512(); - __m512i sum_3 = _mm512_setzero_si512(); - __m512i sum_4 = _mm512_setzero_si512(); - __m512i sum_5 = _mm512_setzero_si512(); - __m512i sum_6 = _mm512_setzero_si512(); - __m512i sum_7 = _mm512_setzero_si512(); - - for (std::size_t j(0); j < dim0; j += per_unit) { - const __m512i input_region = _mm512_max_epi16(zero, _mm512_load_si512((__m512i*)(input + j))); - sum_0 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 0) * dim0 + j)), input_region), sum_0); - sum_1 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 1) * dim0 + j)), input_region), sum_1); - sum_2 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 2) * dim0 + j)), input_region), sum_2); - sum_3 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 3) * dim0 + j)), input_region), sum_3); - sum_4 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 4) * dim0 + j)), input_region), sum_4); - sum_5 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 5) * dim0 + j)), input_region), sum_5); - sum_6 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 6) * dim0 + j)), input_region), sum_6); - sum_7 = _mm512_add_epi32(_mm512_madd_epi16(_mm512_load_si512((__m512i*)(matrix + (i + 7) * dim0 + j)), input_region), sum_7); - } - - const __m512i sum_01 = _mm512_add_epi32(_mm512_unpacklo_epi32(sum_0, sum_1), _mm512_unpackhi_epi32(sum_0, sum_1)); - const __m512i sum_23 = _mm512_add_epi32(_mm512_unpacklo_epi32(sum_2, sum_3), _mm512_unpackhi_epi32(sum_2, sum_3)); - const __m512i sum_45 = _mm512_add_epi32(_mm512_unpacklo_epi32(sum_4, sum_5), _mm512_unpackhi_epi32(sum_4, sum_5)); - const __m512i sum_67 = _mm512_add_epi32(_mm512_unpacklo_epi32(sum_6, sum_7), _mm512_unpackhi_epi32(sum_6, sum_7)); - - const __m512i sum_0123 = _mm512_add_epi32(_mm512_unpacklo_epi64(sum_01, sum_23), _mm512_unpackhi_epi64(sum_01, sum_23)); - const __m512i sum_4567 = _mm512_add_epi32(_mm512_unpacklo_epi64(sum_45, sum_67), _mm512_unpackhi_epi64(sum_45, sum_67)); - - const __m512i sum_512_01234567 = _mm512_add_epi32( - _mm512_permutex2var_epi32(sum_0123, mm512_unpacklo_epi128_permutationx2var, sum_4567), - _mm512_permutex2var_epi32(sum_0123, 
mm512_unpackhi_epi128_permutationx2var, sum_4567)); - - const __m256i sum_256_01234567 = _mm256_add_epi32(_mm512_castsi512_si256(sum_512_01234567), _mm512_extracti64x4_epi64(sum_512_01234567, 0x1)); - - *v_output = _mm256_add_epi32(*v_output, sum_256_01234567); - } +template +inline void crelu255_matrix_vector_product(const T0* matrix, const T1* input, T2* output) noexcept { +#pragma omp simd + for (std::size_t i = 0; i < dim1; ++i) { + for (std::size_t j = 0; j < dim0; ++j) { output[i] += static_cast(std::min(std::max(input[j], T1{0}), T1{255})) * static_cast((matrix + i * dim0)[j]); } } -}; - -template -inline void relu_matrix_vector_product(const float* matrix, const float* input, float* output) noexcept { - return overload_set, float_relu_matrix_vector_product_x8_x1>::f( - matrix, input, output); -} - -template -inline void relu_matrix_vector_product(const std::int16_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { - return overload_set>::f(matrix, input, output); } -#elif defined(__AVX2__) +#if defined(__AVX2__) template struct int16_add_x64 { static constexpr std::size_t num_units = 4; @@ -659,12 +364,15 @@ struct float_relu_matrix_vector_product_x8_x8 { }; template -struct int16_relu_matrix_vector_product_x16_x8 { +struct int16_crelu255_matrix_vector_product_x32_x8 { static constexpr std::size_t num_units = 8; - static constexpr bool available = divides && divides>; + static constexpr bool available = divides && divides>; - static inline void f(const std::int16_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { - const __m256i zero = _mm256_setzero_si256(); + static inline __m256i mm256_maddubs_epi16_coalesced(const __m256i& a, const __m256i& b) { + return _mm256_madd_epi16(_mm256_set1_epi16(1), _mm256_maddubs_epi16(a, b)); + } + + static inline void f(const std::int8_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { __m256i* v_output = (__m256i*)output; constexpr std::size_t output_step = num_units / per_unit; for (std::size_t i(0); i < dim1; i += num_units, v_output += output_step) { @@ -677,16 +385,19 @@ struct int16_relu_matrix_vector_product_x16_x8 { __m256i sum_6 = _mm256_setzero_si256(); __m256i sum_7 = _mm256_setzero_si256(); - for (std::size_t j(0); j < dim0; j += per_unit) { - const __m256i input_region = _mm256_max_epi16(zero, _mm256_load_si256((__m256i*)(input + j))); - sum_0 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 0) * dim0 + j)), input_region), sum_0); - sum_1 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 1) * dim0 + j)), input_region), sum_1); - sum_2 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 2) * dim0 + j)), input_region), sum_2); - sum_3 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 3) * dim0 + j)), input_region), sum_3); - sum_4 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 4) * dim0 + j)), input_region), sum_4); - sum_5 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 5) * dim0 + j)), input_region), sum_5); - sum_6 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 6) * dim0 + j)), input_region), sum_6); - sum_7 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_load_si256((__m256i*)(matrix + (i + 7) * dim0 + j)), input_region), sum_7); + for (std::size_t j(0); j < dim0; j += per_unit) { + const __m256i input_region_0 = _mm256_load_si256((__m256i*)(input + j + 0 * 
per_unit)); + const __m256i input_region_1 = _mm256_load_si256((__m256i*)(input + j + 1 * per_unit)); + const __m256i input_region = _mm256_permute4x64_epi64(_mm256_packus_epi16(input_region_0, input_region_1), 0b11011000); + + sum_0 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 0) * dim0 + j))), sum_0); + sum_1 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 1) * dim0 + j))), sum_1); + sum_2 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 2) * dim0 + j))), sum_2); + sum_3 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 3) * dim0 + j))), sum_3); + sum_4 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 4) * dim0 + j))), sum_4); + sum_5 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 5) * dim0 + j))), sum_5); + sum_6 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 6) * dim0 + j))), sum_6); + sum_7 = _mm256_add_epi32(mm256_maddubs_epi16_coalesced(input_region, _mm256_load_si256((__m256i*)(matrix + (i + 7) * dim0 + j))), sum_7); } const __m256i sum_01 = _mm256_hadd_epi32(sum_0, sum_1); @@ -712,8 +423,8 @@ inline void relu_matrix_vector_product(const float* matrix, const float* input, } template -inline void relu_matrix_vector_product(const std::int16_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { - return overload_set>::f(matrix, input, output); +inline void crelu255_matrix_vector_product(const std::int8_t* matrix, const std::int16_t* input, std::int32_t* output) noexcept { + return overload_set>::f(matrix, input, output); } #elif defined(__SSSE3__) diff --git a/include/nnue/weights.h b/include/nnue/weights.h index 963e970..9e69f19 100644 --- a/include/nnue/weights.h +++ b/include/nnue/weights.h @@ -36,8 +36,9 @@ namespace nnue { struct weights { using parameter_type = float; using quantized_parameter_type = std::int16_t; + using half_quantized_parameter_type = std::int8_t; - static constexpr std::size_t base_dim = 768; + static constexpr std::size_t base_dim = 1024; static constexpr parameter_type shared_quantization_scale = static_cast(512); static constexpr parameter_type fc0_weight_quantization_scale = static_cast(1024); @@ -47,11 +48,11 @@ struct weights { weights_streamer::signature_type signature_{0}; sparse_affine_layer shared{}; - dense_relu_affine_layer fc0{}; + dense_relu_affine_layer<2 * base_dim, 8, parameter_type> fc0{}; - dense_relu_affine_layer fc1{}; - dense_relu_affine_layer fc2{}; - dense_relu_affine_layer fc3{}; + dense_relu_affine_layer<8, 8, parameter_type> fc1{}; + dense_relu_affine_layer<16, 8, parameter_type> fc2{}; + dense_relu_affine_layer<24, 1, parameter_type> fc3{}; [[nodiscard]] constexpr const weights_streamer::signature_type& signature() const noexcept { return signature_; } @@ -65,7 +66,9 @@ struct weights { quantized.signature_ = signature_; quantized.shared = shared.quantized(shared_quantization_scale); - quantized.fc0 = fc0.quantized(fc0_weight_quantization_scale, fc0_bias_quantization_scale); + + quantized.fc0 = + fc0.quantized(fc0_weight_quantization_scale, fc0_bias_quantization_scale); quantized.white_fc0 = quantized.fc0; quantized.black_fc0 = quantized.white_fc0.half_input_flipped(); @@ -97,20 +100,21 @@ struct weights { struct 
quantized_weights {
   using parameter_type = weights::parameter_type;
   using quantized_parameter_type = weights::quantized_parameter_type;
+  using half_quantized_parameter_type = weights::half_quantized_parameter_type;
 
-  static constexpr std::size_t base_dim = 768;
+  static constexpr std::size_t base_dim = 1024;
 
   weights_streamer::signature_type signature_{0};
 
   sparse_affine_layer shared{};
-  dense_relu_affine_layer fc0{};
-  dense_relu_affine_layer white_fc0{};
-  dense_relu_affine_layer black_fc0{};
+  dense_relu_affine_layer<2 * base_dim, 8, half_quantized_parameter_type, quantized_parameter_type> fc0{};
+  dense_relu_affine_layer<2 * base_dim, 8, half_quantized_parameter_type, quantized_parameter_type> white_fc0{};
+  dense_relu_affine_layer<2 * base_dim, 8, half_quantized_parameter_type, quantized_parameter_type> black_fc0{};
 
-  dense_relu_affine_layer fc1{};
-  dense_relu_affine_layer fc2{};
-  dense_relu_affine_layer fc3{};
+  dense_relu_affine_layer<8, 8, parameter_type> fc1{};
+  dense_relu_affine_layer<16, 8, parameter_type> fc2{};
+  dense_relu_affine_layer<24, 1, parameter_type> fc3{};
 
   [[nodiscard]] constexpr const weights_streamer::signature_type& signature() const noexcept { return signature_; }
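
Note on the numeric contract behind the new kernels: fc0 activations are clamped to [0, 255] (CReLU) so that, after `_mm256_packus_epi16`, each activation fits an unsigned 8-bit lane and can be multiplied against the `std::int8_t` weights with `_mm256_maddubs_epi16`, with the pairwise products coalesced into `std::int32_t` accumulators on top of the biases; the saturating `std::clamp` added to `quantized()` keeps the int8 weights themselves from wrapping. The standalone scalar sketch below mirrors that contract; the names (`toy_crelu255_mvp`, `W`, `x`, `out`) and the toy dimensions are illustrative only and are not part of this diff.

```cpp
#include <algorithm>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>

// Scalar model of the crelu255 matrix-vector product:
//   out[i] += sum_j clamp(x[j], 0, 255) * W[i * dim0 + j], accumulated in int32.
template <std::size_t dim0, std::size_t dim1>
void toy_crelu255_mvp(const std::int8_t* W, const std::int16_t* x, std::int32_t* out) {
  for (std::size_t i = 0; i < dim1; ++i) {
    for (std::size_t j = 0; j < dim0; ++j) {
      // CReLU: negatives become 0, anything above 255 saturates to 255,
      // so the clamped activation always fits an unsigned 8-bit lane.
      const std::int32_t a = std::clamp<std::int32_t>(x[j], 0, 255);
      out[i] += a * static_cast<std::int32_t>(W[i * dim0 + j]);
    }
  }
}

int main() {
  constexpr std::size_t dim0 = 4;
  constexpr std::size_t dim1 = 2;
  const std::array<std::int8_t, dim0 * dim1> W{1, -2, 3, -4, 5, -6, 7, -8};
  const std::array<std::int16_t, dim0> x{-7, 300, 12, 255};  // clamps to {0, 255, 12, 255}
  std::array<std::int32_t, dim1> out{10, 20};                // biases, as in forward_crelu255
  toy_crelu255_mvp<dim0, dim1>(W.data(), x.data(), out.data());
  std::printf("%d %d\n", static_cast<int>(out[0]), static_cast<int>(out[1]));  // prints: -1484 -3466
  return 0;
}
```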