libutil: add async collection mechanism

like kj::joinPromisesFailFast this allows waiting for the results of
multiple promises at once, but unlike it not all input promises must
be complete (or any of them failed) for results to become available.

Change-Id: I0e4a37e7bd90651d56b33d0bc5afbadc56cde70c
This commit is contained in:
eldritch horrors 2024-09-24 00:21:16 +02:00
parent ca9256a789
commit 531d040e8c
4 changed files with 207 additions and 0 deletions

View file

@ -0,0 +1,101 @@
#pragma once
/// @file
#include <kj/async.h>
#include <kj/common.h>
#include <kj/vector.h>
#include <list>
#include <optional>
#include <type_traits>
namespace nix {
template<typename K, typename V>
class AsyncCollect
{
public:
using Item = std::conditional_t<std::is_void_v<V>, K, std::pair<K, V>>;
private:
kj::ForkedPromise<void> allPromises;
std::list<Item> results;
size_t remaining;
kj::ForkedPromise<void> signal;
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> notify;
void oneDone(Item item)
{
results.emplace_back(std::move(item));
remaining -= 1;
KJ_IF_MAYBE (n, notify) {
(*n)->fulfill();
notify = nullptr;
}
}
kj::Promise<void> collectorFor(K key, kj::Promise<V> promise)
{
if constexpr (std::is_void_v<V>) {
return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); });
} else {
return promise.then([this, key{std::move(key)}](V v) {
oneDone(Item{std::move(key), std::move(v)});
});
}
}
kj::ForkedPromise<void> waitForAll(kj::Array<std::pair<K, kj::Promise<V>>> & promises)
{
kj::Vector<kj::Promise<void>> wrappers;
for (auto & [key, promise] : promises) {
wrappers.add(collectorFor(std::move(key), std::move(promise)));
}
return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork();
}
public:
AsyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> && promises)
: allPromises(waitForAll(promises))
, remaining(promises.size())
, signal{nullptr}
{
}
kj::Promise<std::optional<Item>> next()
{
if (remaining == 0 && results.empty()) {
return {std::nullopt};
}
if (!results.empty()) {
auto result = std::move(results.front());
results.pop_front();
return {{std::move(result)}};
}
if (notify == nullptr) {
auto pair = kj::newPromiseAndFulfiller<void>();
notify = std::move(pair.fulfiller);
signal = pair.promise.fork();
}
return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] {
return next();
});
}
};
/**
* Collect the results of a list of promises, in order of completion.
* Once any input promise is rejected all promises that have not been
* resolved or rejected will be cancelled and the exception rethrown.
*/
template<typename K, typename V>
AsyncCollect<K, V> asyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> promises)
{
return AsyncCollect<K, V>(std::move(promises));
}
}

View file

@ -53,6 +53,7 @@ libutil_headers = files(
'archive.hh', 'archive.hh',
'args/root.hh', 'args/root.hh',
'args.hh', 'args.hh',
'async-collect.hh',
'async-semaphore.hh', 'async-semaphore.hh',
'backed-string-view.hh', 'backed-string-view.hh',
'box_ptr.hh', 'box_ptr.hh',

View file

@ -0,0 +1,104 @@
#include "async-collect.hh"
#include <gtest/gtest.h>
#include <kj/array.h>
#include <kj/async.h>
#include <kj/exception.h>
#include <stdexcept>
namespace nix {
TEST(AsyncCollect, void)
{
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto a = kj::newPromiseAndFulfiller<void>();
auto b = kj::newPromiseAndFulfiller<void>();
auto c = kj::newPromiseAndFulfiller<void>();
auto d = kj::newPromiseAndFulfiller<void>();
auto collect = asyncCollect(kj::arr(
std::pair(1, std::move(a.promise)),
std::pair(2, std::move(b.promise)),
std::pair(3, std::move(c.promise)),
std::pair(4, std::move(d.promise))
));
auto p = collect.next();
ASSERT_FALSE(p.poll(waitScope));
// collection is ordered
c.fulfiller->fulfill();
b.fulfiller->fulfill();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_EQ(p.wait(waitScope), 3);
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_EQ(p.wait(waitScope), 2);
p = collect.next();
ASSERT_FALSE(p.poll(waitScope));
// exceptions propagate
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_THROW(p.wait(waitScope), kj::Exception);
// first exception aborts collection
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_THROW(p.wait(waitScope), kj::Exception);
}
TEST(AsyncCollect, nonVoid)
{
kj::EventLoop loop;
kj::WaitScope waitScope(loop);
auto a = kj::newPromiseAndFulfiller<int>();
auto b = kj::newPromiseAndFulfiller<int>();
auto c = kj::newPromiseAndFulfiller<int>();
auto d = kj::newPromiseAndFulfiller<int>();
auto collect = asyncCollect(kj::arr(
std::pair(1, std::move(a.promise)),
std::pair(2, std::move(b.promise)),
std::pair(3, std::move(c.promise)),
std::pair(4, std::move(d.promise))
));
auto p = collect.next();
ASSERT_FALSE(p.poll(waitScope));
// collection is ordered
c.fulfiller->fulfill(1);
b.fulfiller->fulfill(2);
ASSERT_TRUE(p.poll(waitScope));
ASSERT_EQ(p.wait(waitScope), std::pair(3, 1));
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_EQ(p.wait(waitScope), std::pair(2, 2));
p = collect.next();
ASSERT_FALSE(p.poll(waitScope));
// exceptions propagate
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_THROW(p.wait(waitScope), kj::Exception);
// first exception aborts collection
p = collect.next();
ASSERT_TRUE(p.poll(waitScope));
ASSERT_THROW(p.wait(waitScope), kj::Exception);
}
}

View file

@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency(
) )
libutil_tests_sources = files( libutil_tests_sources = files(
'libutil/async-collect.cc',
'libutil/async-semaphore.cc', 'libutil/async-semaphore.cc',
'libutil/canon-path.cc', 'libutil/canon-path.cc',
'libutil/checked-arithmetic.cc', 'libutil/checked-arithmetic.cc',