From f19e7f697470a9ed4a02c68f65f55ee2add998b4 Mon Sep 17 00:00:00 2001 From: eldritch horrors Date: Tue, 19 Mar 2024 18:02:22 +0100 Subject: [PATCH] libutil: basic generator type with mapping Change-Id: I2cebcefa0148b631fb30df4c8cfa92167a407e34 --- src/libutil/generator.hh | 181 ++++++++++++++++++++++++++++++++ tests/unit/libutil/generator.cc | 141 +++++++++++++++++++++++++ 2 files changed, 322 insertions(+) create mode 100644 src/libutil/generator.hh create mode 100644 tests/unit/libutil/generator.cc diff --git a/src/libutil/generator.hh b/src/libutil/generator.hh new file mode 100644 index 000000000..ff524fe5d --- /dev/null +++ b/src/libutil/generator.hh @@ -0,0 +1,181 @@ +#pragma once + +#include "overloaded.hh" + +#include +#include +#include +#include + +namespace nix { + +template +struct Generator : private Generator +{ + struct promise_type; + using handle_type = std::coroutine_handle; + + explicit Generator(handle_type h) : Generator{h, h.promise().state} {} + + using Generator::operator bool; + using Generator::operator(); + + operator Generator &() & + { + return *this; + } + operator Generator() && + { + return std::move(*this); + } +}; + +template +struct Generator +{ + template + friend struct Generator::promise_type; + + struct promise_state; + + struct _link + { + std::coroutine_handle<> handle{}; + promise_state * state{}; + }; + + struct promise_state + { + std::variant<_link, T> value{}; + std::exception_ptr exception{}; + _link parent{}; + }; + + // NOTE coroutine handles are LiteralType, own a memory resource (that may + // itself own unique resources), and are "typically TriviallyCopyable". we + // need to take special care to wrap this into a less footgunny interface, + // which mostly means move-only. + Generator(Generator && other) + { + swap(other); + } + + Generator & operator=(Generator && other) + { + Generator(std::move(other)).swap(*this); + return *this; + } + + ~Generator() + { + if (h) { + h.destroy(); + } + } + + explicit operator bool() + { + return ensure(); + } + + T operator()() + { + ensure(); + auto result = std::move(*current); + current = nullptr; + return result; + } + +protected: + std::coroutine_handle<> h{}; + _link active{}; + T * current{}; + + Generator(std::coroutine_handle<> h, promise_state & state) : h(h), active(h, &state) {} + + void swap(Generator & other) + { + std::swap(h, other.h); + std::swap(active, other.active); + std::swap(current, other.current); + } + + bool ensure() + { + while (!current && active.handle) { + active.handle.resume(); + auto & p = *active.state; + if (p.exception) { + std::rethrow_exception(p.exception); + } else if (active.handle.done()) { + active = p.parent; + } else { + std::visit( + overloaded{ + [&](_link & inner) { + auto base = inner.state; + while (base->parent.handle) { + base = base->parent.state; + } + base->parent = active; + active = inner; + }, + [&](T & value) { current = &value; }, + }, + p.value + ); + } + } + return current; + } +}; + +template +struct Generator::promise_type +{ + Generator::promise_state state; + Transform convert; + std::optional> inner; + + Generator get_return_object() + { + return Generator(handle_type::from_promise(*this)); + } + std::suspend_always initial_suspend() + { + return {}; + } + std::suspend_always final_suspend() noexcept + { + return {}; + } + void unhandled_exception() + { + state.exception = std::current_exception(); + } + + template + requires requires(Transform t, From && f) { + { + t(std::forward(f)) + } -> std::convertible_to; + } + std::suspend_always yield_value(From && from) + { + state.value = convert(std::forward(from)); + return {}; + } + + template + requires requires(Transform t, From f) { static_cast>(t(std::move(f))); } + std::suspend_always yield_value(From from) + { + inner = static_cast>(convert(std::move(from))); + state.value = inner->active; + return {}; + } + + void return_void() {} +}; + +} diff --git a/tests/unit/libutil/generator.cc b/tests/unit/libutil/generator.cc new file mode 100644 index 000000000..291a7c24c --- /dev/null +++ b/tests/unit/libutil/generator.cc @@ -0,0 +1,141 @@ +#include "generator.hh" + +#include +#include +#include + +namespace nix { + +TEST(Generator, yields) +{ + auto g = []() -> Generator { + co_yield 1; + co_yield 2; + }(); + + ASSERT_TRUE(bool(g)); + ASSERT_EQ(g(), 1); + ASSERT_EQ(g(), 2); + ASSERT_FALSE(bool(g)); +} + +TEST(Generator, nests) +{ + auto g = []() -> Generator { + co_yield 1; + co_yield []() -> Generator { + co_yield 9; + co_yield []() -> Generator { + co_yield 99; + co_yield 100; + }(); + }(); + + auto g2 = []() -> Generator { + co_yield []() -> Generator { + co_yield 2000; + co_yield 2001; + }(); + co_yield 1001; + }(); + + co_yield g2(); + co_yield std::move(g2); + co_yield 2; + }(); + + ASSERT_TRUE(bool(g)); + ASSERT_EQ(g(), 1); + ASSERT_EQ(g(), 9); + ASSERT_EQ(g(), 99); + ASSERT_EQ(g(), 100); + ASSERT_EQ(g(), 2000); + ASSERT_EQ(g(), 2001); + ASSERT_EQ(g(), 1001); + ASSERT_EQ(g(), 2); + ASSERT_FALSE(bool(g)); +} + +TEST(Generator, nestsExceptions) +{ + auto g = []() -> Generator { + co_yield 1; + co_yield []() -> Generator { + co_yield 9; + throw 1; + co_yield 10; + }(); + co_yield 2; + }(); + + ASSERT_TRUE(bool(g)); + ASSERT_EQ(g(), 1); + ASSERT_EQ(g(), 9); + ASSERT_THROW(g(), int); +} + +TEST(Generator, exception) +{ + { + auto g = []() -> Generator { + throw 1; + co_return; + }(); + + ASSERT_THROW(void(bool(g)), int); + } + { + auto g = []() -> Generator { + throw 1; + co_return; + }(); + + ASSERT_THROW(g(), int); + } +} + +namespace { +struct Transform +{ + int state = 0; + + std::pair operator()(std::integral auto x) + { + return {x, state++}; + } + + Generator, Transform> operator()(const char *) + { + co_yield 9; + co_yield 19; + } + + Generator, Transform> operator()(Generator && inner) + { + return [](auto g) mutable -> Generator, Transform> { + while (g) { + co_yield g(); + } + }(std::move(inner)); + } +}; +} + +TEST(Generator, transform) +{ + auto g = []() -> Generator, Transform> { + co_yield int32_t(-1); + co_yield ""; + std::cerr << "1\n"; + co_yield []() -> Generator { co_yield 7; }(); + co_yield 20; + }(); + + ASSERT_EQ(g(), (std::pair{4294967295, 0})); + ASSERT_EQ(g(), (std::pair{9, 0})); + ASSERT_EQ(g(), (std::pair{19, 1})); + ASSERT_EQ(g(), (std::pair{7, 0})); + ASSERT_EQ(g(), (std::pair{20, 1})); +} + +}