diff --git a/src/libutil/async-collect.hh b/src/libutil/async-collect.hh new file mode 100644 index 000000000..9e0b8bad9 --- /dev/null +++ b/src/libutil/async-collect.hh @@ -0,0 +1,101 @@ +#pragma once +/// @file + +#include +#include +#include +#include +#include +#include + +namespace nix { + +template +class AsyncCollect +{ +public: + using Item = std::conditional_t, K, std::pair>; + +private: + kj::ForkedPromise allPromises; + std::list results; + size_t remaining; + + kj::ForkedPromise signal; + kj::Maybe>> notify; + + void oneDone(Item item) + { + results.emplace_back(std::move(item)); + remaining -= 1; + KJ_IF_MAYBE (n, notify) { + (*n)->fulfill(); + notify = nullptr; + } + } + + kj::Promise collectorFor(K key, kj::Promise promise) + { + if constexpr (std::is_void_v) { + return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); }); + } else { + return promise.then([this, key{std::move(key)}](V v) { + oneDone(Item{std::move(key), std::move(v)}); + }); + } + } + + kj::ForkedPromise waitForAll(kj::Array>> & promises) + { + kj::Vector> wrappers; + for (auto & [key, promise] : promises) { + wrappers.add(collectorFor(std::move(key), std::move(promise))); + } + + return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork(); + } + +public: + AsyncCollect(kj::Array>> && promises) + : allPromises(waitForAll(promises)) + , remaining(promises.size()) + , signal{nullptr} + { + } + + kj::Promise> next() + { + if (remaining == 0 && results.empty()) { + return {std::nullopt}; + } + + if (!results.empty()) { + auto result = std::move(results.front()); + results.pop_front(); + return {{std::move(result)}}; + } + + if (notify == nullptr) { + auto pair = kj::newPromiseAndFulfiller(); + notify = std::move(pair.fulfiller); + signal = pair.promise.fork(); + } + + return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] { + return next(); + }); + } +}; + +/** + * Collect the results of a list of promises, in order of completion. + * Once any input promise is rejected all promises that have not been + * resolved or rejected will be cancelled and the exception rethrown. + */ +template +AsyncCollect asyncCollect(kj::Array>> promises) +{ + return AsyncCollect(std::move(promises)); +} + +} diff --git a/src/libutil/meson.build b/src/libutil/meson.build index 89eeed133..afca4e021 100644 --- a/src/libutil/meson.build +++ b/src/libutil/meson.build @@ -53,6 +53,7 @@ libutil_headers = files( 'archive.hh', 'args/root.hh', 'args.hh', + 'async-collect.hh', 'async-semaphore.hh', 'backed-string-view.hh', 'box_ptr.hh', diff --git a/tests/unit/libutil/async-collect.cc b/tests/unit/libutil/async-collect.cc new file mode 100644 index 000000000..770374d21 --- /dev/null +++ b/tests/unit/libutil/async-collect.cc @@ -0,0 +1,104 @@ +#include "async-collect.hh" + +#include +#include +#include +#include +#include + +namespace nix { + +TEST(AsyncCollect, void) +{ + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto a = kj::newPromiseAndFulfiller(); + auto b = kj::newPromiseAndFulfiller(); + auto c = kj::newPromiseAndFulfiller(); + auto d = kj::newPromiseAndFulfiller(); + + auto collect = asyncCollect(kj::arr( + std::pair(1, std::move(a.promise)), + std::pair(2, std::move(b.promise)), + std::pair(3, std::move(c.promise)), + std::pair(4, std::move(d.promise)) + )); + + auto p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // collection is ordered + c.fulfiller->fulfill(); + b.fulfiller->fulfill(); + + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), 3); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), 2); + + p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // exceptions propagate + a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); }); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); + + // first exception aborts collection + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); +} + +TEST(AsyncCollect, nonVoid) +{ + kj::EventLoop loop; + kj::WaitScope waitScope(loop); + + auto a = kj::newPromiseAndFulfiller(); + auto b = kj::newPromiseAndFulfiller(); + auto c = kj::newPromiseAndFulfiller(); + auto d = kj::newPromiseAndFulfiller(); + + auto collect = asyncCollect(kj::arr( + std::pair(1, std::move(a.promise)), + std::pair(2, std::move(b.promise)), + std::pair(3, std::move(c.promise)), + std::pair(4, std::move(d.promise)) + )); + + auto p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // collection is ordered + c.fulfiller->fulfill(1); + b.fulfiller->fulfill(2); + + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), std::pair(3, 1)); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_EQ(p.wait(waitScope), std::pair(2, 2)); + + p = collect.next(); + ASSERT_FALSE(p.poll(waitScope)); + + // exceptions propagate + a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); }); + + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); + + // first exception aborts collection + p = collect.next(); + ASSERT_TRUE(p.poll(waitScope)); + ASSERT_THROW(p.wait(waitScope), kj::Exception); +} +} diff --git a/tests/unit/meson.build b/tests/unit/meson.build index 5baf10de2..525fdea0b 100644 --- a/tests/unit/meson.build +++ b/tests/unit/meson.build @@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency( ) libutil_tests_sources = files( + 'libutil/async-collect.cc', 'libutil/async-semaphore.cc', 'libutil/canon-path.cc', 'libutil/checked-arithmetic.cc',