Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

<generator>: Make nested types of generator ADL-proof #4464

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
132 changes: 74 additions & 58 deletions stl/inc/generator
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,18 @@ using _Gen_reference_t = conditional_t<is_void_v<_Vty>, _Rty&&, _Rty>;
template <class _Ref>
using _Gen_yield_t = conditional_t<is_reference_v<_Ref>, _Ref, const _Ref&>;

template <class>
struct _Gen_promise_base_provider {
class _Base;
};

template <class, class>
struct _Gen_iter_provider {
class _Iterator;
};

template <class _Yielded>
class _Gen_promise_base {
class _Gen_promise_base_provider<_Yielded>::_Base {
public:
_STL_INTERNAL_STATIC_ASSERT(is_reference_v<_Yielded>);

Expand Down Expand Up @@ -220,20 +230,21 @@ public:
template <class _Rty, class _Vty, class _Alloc, class _Unused>
requires same_as<_Gen_yield_t<_Gen_reference_t<_Rty, _Vty>>, _Yielded>
_NODISCARD auto yield_value(_RANGES elements_of<generator<_Rty, _Vty, _Alloc>&&, _Unused> _Elem) noexcept {
return _Nested_awaitable<_Rty, _Vty, _Alloc>{std::move(_Elem.range)};
using _Nested_awaitable = _Nested_awaitable_provider<_Rty, _Vty, _Alloc>::_Awaitable;
return _Nested_awaitable{std::move(_Elem.range)};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should use _STD in product code.

}

template <_RANGES input_range _Rng, class _Alloc>
requires convertible_to<_RANGES range_reference_t<_Rng>, _Yielded>
_NODISCARD auto yield_value(_RANGES elements_of<_Rng, _Alloc> _Elem) {
using _Vty = _RANGES range_value_t<_Rng>;
return _Nested_awaitable<_Yielded, _Vty, _Alloc>{
[](allocator_arg_t, _Alloc, _RANGES iterator_t<_Rng> _It,
const _RANGES sentinel_t<_Rng> _Se) -> generator<_Yielded, _Vty, _Alloc> {
for (; _It != _Se; ++_It) {
co_yield static_cast<_Yielded>(*_It);
}
}(allocator_arg, _Elem.allocator, _RANGES begin(_Elem.range), _RANGES end(_Elem.range))};
using _Vty = _RANGES range_value_t<_Rng>;
using _Nested_awaitable = _Nested_awaitable_provider<_Yielded, _Vty, _Alloc>::_Awaitable;
return _Nested_awaitable{[](allocator_arg_t, _Alloc, _RANGES iterator_t<_Rng> _It,
const _RANGES sentinel_t<_Rng> _Se) -> generator<_Yielded, _Vty, _Alloc> {
for (; _It != _Se; ++_It) {
co_yield static_cast<_Yielded>(*_It);
}
}(allocator_arg, _Elem.allocator, _RANGES begin(_Elem.range), _RANGES end(_Elem.range))};
}

void await_transform() = delete;
Expand All @@ -259,20 +270,20 @@ private:
template <class _Promise>
constexpr void await_suspend(coroutine_handle<_Promise> _Handle) noexcept {
#ifdef __cpp_lib_is_pointer_interconvertible // TRANSITION, LLVM-48860
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Gen_promise_base, _Promise>);
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Base, _Promise>);
#endif // ^^^ no workaround ^^^

_Gen_promise_base& _Current = _Handle.promise();
_Current._Ptr = _STD addressof(_Val);
_Base& _Current = _Handle.promise();
_Current._Ptr = _STD addressof(_Val);
}

constexpr void await_resume() const noexcept {}
};

struct _Nest_info {
exception_ptr _Except;
coroutine_handle<_Gen_promise_base> _Parent;
coroutine_handle<_Gen_promise_base> _Root;
coroutine_handle<_Base> _Parent;
coroutine_handle<_Base> _Root;
};

struct _Final_awaiter {
Expand All @@ -283,81 +294,86 @@ private:
template <class _Promise>
_NODISCARD coroutine_handle<> await_suspend(coroutine_handle<_Promise> _Handle) noexcept {
#ifdef __cpp_lib_is_pointer_interconvertible // TRANSITION, LLVM-48860
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Gen_promise_base, _Promise>);
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Base, _Promise>);
#endif // ^^^ no workaround ^^^

_Gen_promise_base& _Current = _Handle.promise();
_Base& _Current = _Handle.promise();
if (!_Current._Info) {
return _STD noop_coroutine();
}

coroutine_handle<_Gen_promise_base> _Cont = _Current._Info->_Parent;
_Current._Info->_Root.promise()._Top = _Cont;
_Current._Info = nullptr;
coroutine_handle<_Base> _Cont = _Current._Info->_Parent;
_Current._Info->_Root.promise()._Top = _Cont;
_Current._Info = nullptr;
return _Cont;
}

void await_resume() noexcept {}
};

template <class _Rty, class _Vty, class _Alloc>
struct _Nested_awaitable {
_STL_INTERNAL_STATIC_ASSERT(same_as<_Gen_yield_t<_Gen_reference_t<_Rty, _Vty>>, _Yielded>);
struct _Nested_awaitable_provider {
struct _Awaitable {
_STL_INTERNAL_STATIC_ASSERT(same_as<_Gen_yield_t<_Gen_reference_t<_Rty, _Vty>>, _Yielded>);

_Nest_info _Nested;
generator<_Rty, _Vty, _Alloc> _Gen;
_Nest_info _Nested;
generator<_Rty, _Vty, _Alloc> _Gen;

explicit _Nested_awaitable(generator<_Rty, _Vty, _Alloc>&& _Gen_) noexcept : _Gen(_STD move(_Gen_)) {}
explicit _Awaitable(generator<_Rty, _Vty, _Alloc>&& _Gen_) noexcept : _Gen(_STD move(_Gen_)) {}

_NODISCARD bool await_ready() noexcept {
return !_Gen._Coro;
}
_NODISCARD bool await_ready() noexcept {
return !_Gen._Coro;
}

template <class _Promise>
_NODISCARD coroutine_handle<_Gen_promise_base> await_suspend(coroutine_handle<_Promise> _Current) noexcept {
template <class _Promise>
_NODISCARD coroutine_handle<_Base> await_suspend(coroutine_handle<_Promise> _Current) noexcept {
#ifdef __cpp_lib_is_pointer_interconvertible // TRANSITION, LLVM-48860
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Gen_promise_base, _Promise>);
_STL_INTERNAL_STATIC_ASSERT(is_pointer_interconvertible_base_of_v<_Base, _Promise>);
#endif // ^^^ no workaround ^^^
auto _Target = coroutine_handle<_Gen_promise_base>::from_address(_Gen._Coro.address());
_Nested._Parent = coroutine_handle<_Gen_promise_base>::from_address(_Current.address());
_Gen_promise_base& _Parent_promise = _Nested._Parent.promise();
if (_Parent_promise._Info) {
_Nested._Root = _Parent_promise._Info->_Root;
} else {
_Nested._Root = _Nested._Parent;
auto _Target = coroutine_handle<_Base>::from_address(_Gen._Coro.address());
_Nested._Parent = coroutine_handle<_Base>::from_address(_Current.address());
_Base& _Parent_promise = _Nested._Parent.promise();
if (_Parent_promise._Info) {
_Nested._Root = _Parent_promise._Info->_Root;
} else {
_Nested._Root = _Nested._Parent;
}
_Nested._Root.promise()._Top = _Target;
_Target.promise()._Info = _STD addressof(_Nested);
return _Target;
}
_Nested._Root.promise()._Top = _Target;
_Target.promise()._Info = _STD addressof(_Nested);
return _Target;
}

void await_resume() {
if (_Nested._Except) {
_STD rethrow_exception(_STD move(_Nested._Except));
void await_resume() {
if (_Nested._Except) {
_STD rethrow_exception(_STD move(_Nested._Except));
}
}
}
};
};

template <class, class>
friend class _Gen_iter;
friend struct _Gen_iter_provider;

// _Top and _Info are mutually exclusive, and could potentially be merged.
coroutine_handle<_Gen_promise_base> _Top = coroutine_handle<_Gen_promise_base>::from_promise(*this);
add_pointer_t<_Yielded> _Ptr = nullptr;
_Nest_info* _Info = nullptr;
coroutine_handle<_Base> _Top = coroutine_handle<_Base>::from_promise(*this);
add_pointer_t<_Yielded> _Ptr = nullptr;
_Nest_info* _Info = nullptr;
};

template <class _Yielded>
using _Gen_promise_base = _Gen_promise_base_provider<_Yielded>::_Base;

struct _Gen_secret_tag {};

template <class _Value, class _Ref>
class _Gen_iter {
class _Gen_iter_provider<_Value, _Ref>::_Iterator {
public:
using value_type = _Value;
using difference_type = ptrdiff_t;

_Gen_iter(_Gen_iter&& _That) noexcept : _Coro{_STD exchange(_That._Coro, {})} {}
_Iterator(_Iterator&& _That) noexcept : _Coro{_STD exchange(_That._Coro, {})} {}

_Gen_iter& operator=(_Gen_iter&& _That) noexcept {
_Iterator& operator=(_Iterator&& _That) noexcept {
_Coro = _STD exchange(_That._Coro, {});
return *this;
}
Expand All @@ -367,7 +383,7 @@ public:
return static_cast<_Ref>(*_Coro.promise()._Top.promise()._Ptr);
}

_Gen_iter& operator++() {
_Iterator& operator++() {
_STL_ASSERT(!_Coro.done(), "Can't increment generator end iterator");
_Coro.promise()._Top.resume();
return *this;
Expand All @@ -377,7 +393,7 @@ public:
++*this;
}

_NODISCARD_FRIEND bool operator==(const _Gen_iter& _It, default_sentinel_t) noexcept /* strengthened */
_NODISCARD_FRIEND bool operator==(const _Iterator& _It, default_sentinel_t) noexcept /* strengthened */
{
return _It._Coro.done();
}
Expand All @@ -386,7 +402,7 @@ private:
template <class, class, class>
friend class generator;

explicit _Gen_iter(_Gen_secret_tag, coroutine_handle<_Gen_promise_base<_Gen_yield_t<_Ref>>> _Coro_) noexcept
explicit _Iterator(_Gen_secret_tag, coroutine_handle<_Gen_promise_base<_Gen_yield_t<_Ref>>> _Coro_) noexcept
: _Coro{_Coro_} {}

coroutine_handle<_Gen_promise_base<_Gen_yield_t<_Ref>>> _Coro;
Expand Down Expand Up @@ -440,11 +456,11 @@ public:
return *this;
}

_NODISCARD _Gen_iter<_Value, _Ref> begin() {
_NODISCARD _Gen_iter_provider<_Value, _Ref>::_Iterator begin() {
// Pre: _Coro is suspended at its initial suspend point
_STL_ASSERT(_Coro, "Can't call begin on moved-from generator");
_Coro.resume();
return _Gen_iter<_Value, _Ref>{
return typename _Gen_iter_provider<_Value, _Ref>::_Iterator{
_Gen_secret_tag{}, coroutine_handle<_Gen_promise_base<_Gen_yield_t<_Ref>>>::from_address(_Coro.address())};
}

Expand Down
19 changes: 11 additions & 8 deletions stl/inc/ranges
Original file line number Diff line number Diff line change
Expand Up @@ -86,20 +86,23 @@ namespace ranges {
requires (_Extent != dynamic_extent)
inline constexpr auto _Compile_time_max_size<const span<_Ty, _Extent>> = _Extent;

#ifdef __cpp_lib_byte
using _Elements_alloc_type = byte;
#else
using _Elements_alloc_type = char;
#endif

_EXPORT_STD template <range _Rng, class _Alloc = allocator<_Elements_alloc_type>>
#if defined(__cpp_lib_byte)
_EXPORT_STD template <range _Rng, class _Alloc = allocator<byte>>
#else // ^^^ defined(__cpp_lib_byte) / !defined(__cpp_lib_byte) vvv
_EXPORT_STD template <range _Rng, class _Alloc>
#endif // ^^^ !defined(__cpp_lib_byte) ^^^
struct elements_of {
/* [[no_unique_address]] */ _Rng range;
/* [[no_unique_address]] */ _Alloc allocator{};
};

template <class _Rng, class _Alloc = allocator<_Elements_alloc_type>>
#if defined(__cpp_lib_byte)
template <class _Rng, class _Alloc = allocator<byte>>
elements_of(_Rng&&, _Alloc = {}) -> elements_of<_Rng&&, _Alloc>;
#else // ^^^ defined(__cpp_lib_byte) / !defined(__cpp_lib_byte) vvv
template <class _Rng, class _Alloc>
elements_of(_Rng&&, _Alloc) -> elements_of<_Rng&&, _Alloc>;
#endif // ^^^ !defined(__cpp_lib_byte) ^^^
#endif // _HAS_CXX23

// clang-format off
Expand Down
39 changes: 39 additions & 0 deletions tests/std/tests/P2502R2_generator/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#include <random>
#include <ranges>
#include <sstream>
#include <type_traits>
#include <utility>
#include <vector>

namespace ranges = std::ranges;
Expand Down Expand Up @@ -281,6 +283,39 @@ void arbitrary_range_test() {

assert(ranges::equal(yield_arbitrary_ranges(), std::array{40, 30, 20, 10, 0, 1, 2, 3, 500, 400, 300}));
}

#ifndef _M_CEE // TRANSITION, VSO-1659496
template <class T>
struct holder {
T t;
};

struct incomplete;

void adl_proof_test() {
using validator = holder<incomplete>*;
auto yield_range = []() -> std::generator<validator> {
co_yield ranges::elements_of(
ranges::views::repeat(nullptr, 42) | ranges::views::transform([](std::nullptr_t) { return validator{}; }));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should include <cstddef> for std::nullptr_t.

};

using R = decltype(yield_range());
static_assert(ranges::input_range<R>);

using It = ranges::iterator_t<R>;
static_assert(std::is_same_v<decltype(&std::declval<It&>()), It*>);

using Promise = R::promise_type;
static_assert(std::is_same_v<decltype(&std::declval<Promise&>()), Promise*>);

std::size_t i = 0;
for (const auto elem : yield_range()) {
++i;
assert(elem == nullptr);
}
assert(i == 42);
}
#endif // ^^^ no workaround ^^^
#endif // ^^^ no workaround ^^^

int main() {
Expand Down Expand Up @@ -339,5 +374,9 @@ int main() {
#if !(defined(__clang__) && defined(_M_IX86)) // TRANSITION, LLVM-56507
recursive_test();
arbitrary_range_test();

#ifndef _M_CEE // TRANSITION, VSO-1659496
adl_proof_test();
#endif // ^^^ no workaround ^^^
#endif // ^^^ no workaround ^^^
}