Skip to content

Commit

Permalink
Fix vectorized ranges::find with unreachable_sentinel to properly…
Browse files Browse the repository at this point in the history
… mask the beginning and handle unaligned pointers (#4450)
  • Loading branch information
StephanTLavavej committed Mar 8, 2024
1 parent 0407db6 commit 9c40b48
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 21 deletions.
32 changes: 11 additions & 21 deletions stl/src/vector_algorithms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1844,7 +1844,10 @@ namespace {
template <class _Traits, class _Ty>
const void* __stdcall __std_find_trivial_unsized_impl(const void* _First, const _Ty _Val) noexcept {
#ifndef _M_ARM64EC
if (_Use_avx2()) {
if ((reinterpret_cast<uintptr_t>(_First) & (sizeof(_Ty) - 1)) != 0) {
// _First isn't aligned to sizeof(_Ty), so we need to use the scalar fallback below.
// This can happen with 8-byte elements on x86's 4-aligned stack. It can also happen with packed structs.
} else if (_Use_avx2()) {
_Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414

// We read by vector-sized pieces, and we align pointers to vector-sized boundary.
Expand All @@ -1862,27 +1865,20 @@ namespace {
unsigned int _Bingo = static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand)));

_Bingo &= _Mask;
if (_Bingo != 0) {
unsigned long _Offset = _tzcnt_u32(_Bingo);
_Advance_bytes(_First, _Offset);
return _First;
}

for (;;) {
_Data = _mm256_load_si256(static_cast<const __m256i*>(_First));
_Bingo = static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand)));

if (_Bingo != 0) {
unsigned long _Offset = _tzcnt_u32(_Bingo);
_Advance_bytes(_First, _Offset);
return _First;
}

_Advance_bytes(_First, 32);
}
}

if (_Traits::_Sse_available()) {
_Data = _mm256_load_si256(static_cast<const __m256i*>(_First));
_Bingo = static_cast<unsigned int>(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand)));
}
} else if (_Traits::_Sse_available()) {
// We read by vector-sized pieces, and we align pointers to vector-sized boundary.
// From start partial piece we mask out matches that don't belong to the range.
// This makes sure we never cross page boundary, thus we read 'as if' sequentially.
Expand All @@ -1898,17 +1894,8 @@ namespace {
unsigned int _Bingo = static_cast<unsigned int>(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand)));

_Bingo &= _Mask;
if (_Bingo != 0) {
unsigned long _Offset;
_BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
_Advance_bytes(_First, _Offset);
return _First;
}

for (;;) {
_Data = _mm_load_si128(static_cast<const __m128i*>(_First));
_Bingo = static_cast<unsigned int>(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand)));

if (_Bingo != 0) {
unsigned long _Offset;
_BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable]
Expand All @@ -1917,6 +1904,9 @@ namespace {
}

_Advance_bytes(_First, 16);

_Data = _mm_load_si128(static_cast<const __m128i*>(_First));
_Bingo = static_cast<unsigned int>(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand)));
}
}
#endif // !_M_ARM64EC
Expand Down
54 changes: 54 additions & 0 deletions tests/std/tests/VSO_0000000_vector_algorithms/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,53 @@ void test_find(mt19937_64& gen) {
}
}

#if _HAS_CXX20
template <class T, size_t N>
struct NormalArrayWrapper {
T m_arr[N];
};

// Also test GH-4454 "vector_algorithms.cpp: __std_find_trivial_unsized_impl assumes N-byte elements are N-aligned"
#pragma pack(push, 1)
template <class T, size_t N>
struct PackedArrayWrapper {
uint8_t m_ignored; // to misalign the following array
T m_arr[N];
};
#pragma pack(pop)

// GH-4449 <xutility>: ranges::find with unreachable_sentinel / __std_find_trivial_unsized_1 gives wrong result
template <class T, template <class, size_t> class ArrayWrapper>
void test_gh_4449_impl() {
constexpr T desired_val{11};
constexpr T unwanted_val{22};

ArrayWrapper<T, 256> wrapper;
auto& arr = wrapper.m_arr;

constexpr int mid1 = 64;
constexpr int mid2 = 192;

ranges::fill(arr, arr + mid1, desired_val);
ranges::fill(arr + mid1, arr + mid2, unwanted_val);
ranges::fill(arr + mid2, end(arr), desired_val);

for (int idx = mid1; idx <= mid2; ++idx) { // when idx == mid2, the value is immediately found
const auto where = ranges::find(arr + idx, unreachable_sentinel, desired_val);

assert(where == arr + mid2);

arr[idx] = desired_val; // get ready for the next iteration
}
}

template <class T>
void test_gh_4449() {
test_gh_4449_impl<T, NormalArrayWrapper>();
test_gh_4449_impl<T, PackedArrayWrapper>();
}
#endif // _HAS_CXX20

#if _HAS_CXX23
template <class T>
void test_case_find_last(const vector<T>& input, T v) {
Expand Down Expand Up @@ -371,6 +418,13 @@ void test_vector_algorithms(mt19937_64& gen) {
test_find<long long>(gen);
test_find<unsigned long long>(gen);

#if _HAS_CXX20
test_gh_4449<uint8_t>();
test_gh_4449<uint16_t>();
test_gh_4449<uint32_t>();
test_gh_4449<uint64_t>();
#endif // _HAS_CXX20

#if _HAS_CXX23
test_find_last<char>(gen);
test_find_last<signed char>(gen);
Expand Down

0 comments on commit 9c40b48

Please sign in to comment.