From 531d040e8c2d211408c84ae23421aaa45b3b5a7a Mon Sep 17 00:00:00 2001 From: eldritch horrors Date: Tue, 24 Sep 2024 00:21:16 +0200 Subject: [PATCH] libutil: add async collection mechanism like kj::joinPromisesFailFast this allows waiting for the results of multiple promises at once, but unlike it not all input promises must be complete (or any of them failed) for results to become available. Change-Id: I0e4a37e7bd90651d56b33d0bc5afbadc56cde70c --- src/libutil/async-collect.hh | 101 +++++++++++++++++++++++++++ src/libutil/meson.build | 1 + tests/unit/libutil/async-collect.cc | 104 ++++++++++++++++++++++++++++ tests/unit/meson.build | 1 + 4 files changed, 207 insertions(+) create mode 100644 src/libutil/async-collect.hh create mode 100644 tests/unit/libutil/async-collect.cc diff --git a/src/libutil/async-collect.hh b/src/libutil/async-collect.hh new file mode 100644 index 000000000..9e0b8bad9 --- /dev/null +++ b/src/libutil/async-collect.hh @@ -0,0 +1,101 @@ +#pragma once +/// @file + +#include +#include +#include +#include +#include +#include + +namespace nix { + +template +class AsyncCollect +{ +public: + using Item = std::conditional_t, K, std::pair>; + +private: + kj::ForkedPromise allPromises; + std::list results; + size_t remaining; + + kj::ForkedPromise signal; + kj::Maybe>> notify; + + void oneDone(Item item) + { + results.emplace_back(std::move(item)); + remaining -= 1; + KJ_IF_MAYBE (n, notify) { + (*n)->fulfill(); + notify = nullptr; + } + } + + kj::Promise collectorFor(K key, kj::Promise promise) + { + if constexpr (std::is_void_v) { + return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); }); + } else { + return promise.then([this, key{std::move(key)}](V v) { + oneDone(Item{std::move(key), std::move(v)}); + }); + } + } + + kj::ForkedPromise waitForAll(kj::Array>> & promises) + { + kj::Vector> wrappers; + for (auto & [key, promise] : promises) { + wrappers.add(collectorFor(std::move(key), std::move(promise))); + } + + return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork(); + } + +public: + AsyncCollect(kj::Array>> && promises) + : allPromises(waitForAll(promises)) + , remaining(promises.size()) + , signal{nullptr} + { + } + + kj::Promise> next() + { + if (remaining == 0 && results.empty()) { + return {std::nullopt}; + } + + if (!results.empty()) { + auto result = std::move(results.front()); + results.pop_front(); + return {{std::move(result)}}; + } + + if (notify == nullptr) { + auto pair = kj::newPromiseAndFulfiller(); + notify = std::move(pair.fulfiller); + signal = pair.promise.fork(); + } + + return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] { + return next(); + }); + } +}; + +/** + * Collect the results of a list of promises, in order of completion. + * Once any input promise is rejected all promises that have not been + * resolved or rejected will be cancelled and the exception rethrown. + */ +template +AsyncCollect asyncCollect(kj::Array>> promises) +{ + return AsyncCollect(std::move(promises)); +} + +} diff --git a/src/libutil/meson.build b/src/libutil/meson.build index 89eeed133..afca4e021 100644 --- a/src/libutil/meson.build +++ b/src/libutil/meson.build @@ -53,6 +53,7 @@ libutil_headers = files( 'archive.hh', 'args/root.hh', 'args.hh', + 'async-collect.hh', 'async-semaphore.hh', 'backed-string-view.hh', 'box_ptr.hh', diff --git a/tests/unit/libutil/async-collect.cc b/tests/unit/libutil/async-collect.cc new file mode 100644 index 000000000..770374d21 --- /dev/null +++ b/tests/unit/libutil/async-collect.cc @@ -0,0 +1,104 @@ +#include "async-collect.hh" + +#include +#include +#include +#include +#include + +namespace nix { + +TEST(AsyncCollect, void) +{ + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto a = kj::newPromiseAndFulfiller(); + auto b = kj::newPromiseAndFulfiller(); + auto c = kj::newPromiseAndFulfiller(); + auto d = kj::newPromiseAndFulfiller(); + + auto collect = asyncCollect(kj::arr( + std::pair(1, std::move(a.promise)), + std::pair(2, std::move(b.promise)), + std::pair(3, std::move(c.promise)), + std::pair(4, std::move(d.promise)) + )); + + auto p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // collection is ordered + c.fulfiller->fulfill(); + b.fulfiller->fulfill(); + + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), 3); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), 2); + + p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // exceptions propagate + a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); }); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); + + // first exception aborts collection + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); +} + +TEST(AsyncCollect, nonVoid) +{ + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto a = kj::newPromiseAndFulfiller(); + auto b = kj::newPromiseAndFulfiller(); + auto c = kj::newPromiseAndFulfiller(); + auto d = kj::newPromiseAndFulfiller(); + + auto collect = asyncCollect(kj::arr( + std::pair(1, std::move(a.promise)), + std::pair(2, std::move(b.promise)), + std::pair(3, std::move(c.promise)), + std::pair(4, std::move(d.promise)) + )); + + auto p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // collection is ordered + c.fulfiller->fulfill(1); + b.fulfiller->fulfill(2); + + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), std::pair(3, 1)); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), std::pair(2, 2)); + + p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // exceptions propagate + a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); }); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); + + // first exception aborts collection + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); +} +} diff --git a/tests/unit/meson.build b/tests/unit/meson.build index 5baf10de2..525fdea0b 100644 --- a/tests/unit/meson.build +++ b/tests/unit/meson.build @@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency( ) libutil_tests_sources = files( + 'libutil/async-collect.cc', 'libutil/async-semaphore.cc', 'libutil/canon-path.cc', 'libutil/checked-arithmetic.cc',