Skip to content

Commit c95e938

Browse files
Vectorize search_n for CPUs with SSE4.2 but not AVX2 support; handle AVX2 tail (#5544)
Co-authored-by: Stephan T. Lavavej <[email protected]>
1 parent 395e9a6 commit c95e938

File tree

1 file changed

+99
-0
lines changed

1 file changed

+99
-0
lines changed

stl/src/vector_algorithms.cpp

Lines changed: 99 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3183,8 +3183,107 @@ namespace {
31833183
_Advance_bytes(_First, 32);
31843184
} while (_First != _Stop_at);
31853185

3186+
if (const size_t _Tail = _Length & 0x1C; _Tail != 0) {
3187+
const __m256i _Tail_mask = _Avx2_tail_mask_32(_Tail);
3188+
const __m256i _Data = _mm256_maskload_epi32(reinterpret_cast<const int*>(_First), _Tail_mask);
3189+
3190+
const __m256i _Cmp = _Traits::_Cmp_avx(_Comparand, _Data);
3191+
const uint32_t _Mask = _mm256_movemask_epi8(_mm256_and_si256(_Cmp, _Tail_mask));
3192+
3193+
const uint64_t _Msk_with_carry = uint64_t{_Carry} | (uint64_t{_Mask} << 32);
3194+
uint64_t _MskX = _Msk_with_carry;
3195+
3196+
_MskX = (_MskX >> sizeof(_Ty)) & _MskX;
3197+
3198+
if constexpr (sizeof(_Ty) == 1) {
3199+
_MskX = __ull_rshift(_MskX, _Sh1) & _MskX;
3200+
}
3201+
3202+
if constexpr (sizeof(_Ty) < 4) {
3203+
_MskX = __ull_rshift(_MskX, _Sh2) & _MskX;
3204+
}
3205+
3206+
if constexpr (sizeof(_Ty) < 8) {
3207+
_MskX = __ull_rshift(_MskX, _Sh3) & _MskX;
3208+
}
3209+
3210+
if (_MskX != 0) {
3211+
#ifdef _M_IX86
3212+
const uint32_t _MskLow = static_cast<uint32_t>(_MskX);
3213+
3214+
const int _Shift = _MskLow != 0
3215+
? static_cast<int>(_tzcnt_u32(_MskLow)) - 32
3216+
: static_cast<int>(_tzcnt_u32(static_cast<uint32_t>(_MskX >> 32)));
3217+
3218+
#elifdef _M_X64
3219+
const long long _Shift = static_cast<long long>(_tzcnt_u64(_MskX)) - 32;
3220+
#else
3221+
#error Unsupported architecture
3222+
#endif
3223+
_Advance_bytes(_First, _Shift);
3224+
return _First;
3225+
}
3226+
3227+
_Carry = static_cast<uint32_t>(__ull_rshift(_Msk_with_carry, static_cast<int>(_Tail)));
3228+
3229+
_Advance_bytes(_First, _Tail);
3230+
}
3231+
31863232
_Mid1 = static_cast<const _Ty*>(_First);
31873233
_Rewind_bytes(_First, _lzcnt_u32(~_Carry));
3234+
} else if constexpr (sizeof(_Ty) < 8) {
3235+
if (_Count <= 8 / sizeof(_Ty) && _Length >= 16 && _Use_sse42()) {
3236+
const int _Bytes_count = static_cast<int>(_Count * sizeof(_Ty));
3237+
const int _Sh1 = sizeof(_Ty) != 1 ? 0 : (_Bytes_count < 4 ? _Bytes_count - 2 : 2);
3238+
const int _Sh2 = sizeof(_Ty) >= 4 ? 0
3239+
: _Bytes_count < 4 ? 0
3240+
: (_Bytes_count < 8 ? _Bytes_count - 4 : 4);
3241+
3242+
const __m128i _Comparand = _Traits::_Set_sse(_Val);
3243+
3244+
const void* _Stop_at = _First;
3245+
_Advance_bytes(_Stop_at, _Length & ~size_t{0xF});
3246+
3247+
uint32_t _Carry = 0;
3248+
do {
3249+
const __m128i _Data = _mm_loadu_si128(reinterpret_cast<const __m128i*>(_First));
3250+
3251+
const __m128i _Cmp = _Traits::_Cmp_sse(_Comparand, _Data);
3252+
const uint32_t _Mask = _mm_movemask_epi8(_Cmp);
3253+
3254+
uint32_t _MskX = _Carry | (_Mask << 16);
3255+
3256+
_MskX = (_MskX >> sizeof(_Ty)) & _MskX;
3257+
3258+
if constexpr (sizeof(_Ty) == 1) {
3259+
_MskX = (_MskX >> _Sh1) & _MskX;
3260+
}
3261+
3262+
if constexpr (sizeof(_Ty) < 4) {
3263+
_MskX = (_MskX >> _Sh2) & _MskX;
3264+
}
3265+
3266+
if (_MskX != 0) {
3267+
unsigned long _Pos;
3268+
// CodeQL [SM02313] _Pos is always initialized: _MskX != 0 was checked right above.
3269+
_BitScanForward(&_Pos, _MskX);
3270+
_Advance_bytes(_First, static_cast<ptrdiff_t>(_Pos) - 16);
3271+
return _First;
3272+
}
3273+
3274+
_Carry = _Mask;
3275+
3276+
_Advance_bytes(_First, 16);
3277+
} while (_First != _Stop_at);
3278+
3279+
_Mid1 = static_cast<const _Ty*>(_First);
3280+
3281+
unsigned long _Carry_pos;
3282+
// Here, _Carry can't be 0xFFFF, because that would have been a match. Therefore:
3283+
// CodeQL [SM02313] _Carry_pos is always initialized: `(_Carry ^ 0xFFFF) != 0` is always true.
3284+
_BitScanReverse(&_Carry_pos, _Carry ^ 0xFFFF);
3285+
_Rewind_bytes(_First, 15 - static_cast<ptrdiff_t>(_Carry_pos));
3286+
}
31883287
}
31893288
#endif // ^^^ !defined(_M_ARM64EC) ^^^
31903289
auto _Match_start = static_cast<const _Ty*>(_First);

0 commit comments

Comments
 (0)