From 9c40b48c6a25b5c27814f360a71e76bb35738a87 Mon Sep 17 00:00:00 2001 From: "Stephan T. Lavavej" Date: Thu, 7 Mar 2024 21:29:04 -0800 Subject: [PATCH] Fix vectorized `ranges::find` with `unreachable_sentinel` to properly mask the beginning and handle unaligned pointers (#4450) --- stl/src/vector_algorithms.cpp | 32 ++++------- .../VSO_0000000_vector_algorithms/test.cpp | 54 +++++++++++++++++++ 2 files changed, 65 insertions(+), 21 deletions(-) diff --git a/stl/src/vector_algorithms.cpp b/stl/src/vector_algorithms.cpp index f3a364884b..bf9de5a308 100644 --- a/stl/src/vector_algorithms.cpp +++ b/stl/src/vector_algorithms.cpp @@ -1844,7 +1844,10 @@ namespace { template const void* __stdcall __std_find_trivial_unsized_impl(const void* _First, const _Ty _Val) noexcept { #ifndef _M_ARM64EC - if (_Use_avx2()) { + if ((reinterpret_cast(_First) & (sizeof(_Ty) - 1)) != 0) { + // _First isn't aligned to sizeof(_Ty), so we need to use the scalar fallback below. + // This can happen with 8-byte elements on x86's 4-aligned stack. It can also happen with packed structs. + } else if (_Use_avx2()) { _Zeroupper_on_exit _Guard; // TRANSITION, DevCom-10331414 // We read by vector-sized pieces, and we align pointers to vector-sized boundary. @@ -1862,16 +1865,8 @@ namespace { unsigned int _Bingo = static_cast(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand))); _Bingo &= _Mask; - if (_Bingo != 0) { - unsigned long _Offset = _tzcnt_u32(_Bingo); - _Advance_bytes(_First, _Offset); - return _First; - } for (;;) { - _Data = _mm256_load_si256(static_cast(_First)); - _Bingo = static_cast(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand))); - if (_Bingo != 0) { unsigned long _Offset = _tzcnt_u32(_Bingo); _Advance_bytes(_First, _Offset); @@ -1879,10 +1874,11 @@ namespace { } _Advance_bytes(_First, 32); - } - } - if (_Traits::_Sse_available()) { + _Data = _mm256_load_si256(static_cast(_First)); + _Bingo = static_cast(_mm256_movemask_epi8(_Traits::_Cmp_avx(_Data, _Comparand))); + } + } else if (_Traits::_Sse_available()) { // We read by vector-sized pieces, and we align pointers to vector-sized boundary. // From start partial piece we mask out matches that don't belong to the range. // This makes sure we never cross page boundary, thus we read 'as if' sequentially. @@ -1898,17 +1894,8 @@ namespace { unsigned int _Bingo = static_cast(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand))); _Bingo &= _Mask; - if (_Bingo != 0) { - unsigned long _Offset; - _BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable] - _Advance_bytes(_First, _Offset); - return _First; - } for (;;) { - _Data = _mm_load_si128(static_cast(_First)); - _Bingo = static_cast(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand))); - if (_Bingo != 0) { unsigned long _Offset; _BitScanForward(&_Offset, _Bingo); // lgtm [cpp/conditionallyuninitializedvariable] @@ -1917,6 +1904,9 @@ namespace { } _Advance_bytes(_First, 16); + + _Data = _mm_load_si128(static_cast(_First)); + _Bingo = static_cast(_mm_movemask_epi8(_Traits::_Cmp_sse(_Data, _Comparand))); } } #endif // !_M_ARM64EC diff --git a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp index 087220a63a..510c7b9935 100644 --- a/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp +++ b/tests/std/tests/VSO_0000000_vector_algorithms/test.cpp @@ -140,6 +140,53 @@ void test_find(mt19937_64& gen) { } } +#if _HAS_CXX20 +template +struct NormalArrayWrapper { + T m_arr[N]; +}; + +// Also test GH-4454 "vector_algorithms.cpp: __std_find_trivial_unsized_impl assumes N-byte elements are N-aligned" +#pragma pack(push, 1) +template +struct PackedArrayWrapper { + uint8_t m_ignored; // to misalign the following array + T m_arr[N]; +}; +#pragma pack(pop) + +// GH-4449 : ranges::find with unreachable_sentinel / __std_find_trivial_unsized_1 gives wrong result +template class ArrayWrapper> +void test_gh_4449_impl() { + constexpr T desired_val{11}; + constexpr T unwanted_val{22}; + + ArrayWrapper wrapper; + auto& arr = wrapper.m_arr; + + constexpr int mid1 = 64; + constexpr int mid2 = 192; + + ranges::fill(arr, arr + mid1, desired_val); + ranges::fill(arr + mid1, arr + mid2, unwanted_val); + ranges::fill(arr + mid2, end(arr), desired_val); + + for (int idx = mid1; idx <= mid2; ++idx) { // when idx == mid2, the value is immediately found + const auto where = ranges::find(arr + idx, unreachable_sentinel, desired_val); + + assert(where == arr + mid2); + + arr[idx] = desired_val; // get ready for the next iteration + } +} + +template +void test_gh_4449() { + test_gh_4449_impl(); + test_gh_4449_impl(); +} +#endif // _HAS_CXX20 + #if _HAS_CXX23 template void test_case_find_last(const vector& input, T v) { @@ -371,6 +418,13 @@ void test_vector_algorithms(mt19937_64& gen) { test_find(gen); test_find(gen); +#if _HAS_CXX20 + test_gh_4449(); + test_gh_4449(); + test_gh_4449(); + test_gh_4449(); +#endif // _HAS_CXX20 + #if _HAS_CXX23 test_find_last(gen); test_find_last(gen);