libutil: add async collection mechanism
like kj::joinPromisesFailFast this allows waiting for the results of
multiple promises at once, but unlike it not all input promises must
be complete (or any of them failed) for results to become available.
Change-Id: I0e4a37e7bd90651d56b33d0bc5afbadc56cde70c
This commit is contained in:
parent
ca9256a789
commit
531d040e8c
101
src/libutil/async-collect.hh
Normal file
101
src/libutil/async-collect.hh
Normal file
|
@ -0,0 +1,101 @@
|
|||
#pragma once
|
||||
/// @file
|
||||
|
||||
#include <kj/async.h>
|
||||
#include <kj/common.h>
|
||||
#include <kj/vector.h>
|
||||
#include <list>
|
||||
#include <optional>
|
||||
#include <type_traits>
|
||||
|
||||
namespace nix {
|
||||
|
||||
template<typename K, typename V>
|
||||
class AsyncCollect
|
||||
{
|
||||
public:
|
||||
using Item = std::conditional_t<std::is_void_v<V>, K, std::pair<K, V>>;
|
||||
|
||||
private:
|
||||
kj::ForkedPromise<void> allPromises;
|
||||
std::list<Item> results;
|
||||
size_t remaining;
|
||||
|
||||
kj::ForkedPromise<void> signal;
|
||||
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> notify;
|
||||
|
||||
void oneDone(Item item)
|
||||
{
|
||||
results.emplace_back(std::move(item));
|
||||
remaining -= 1;
|
||||
KJ_IF_MAYBE (n, notify) {
|
||||
(*n)->fulfill();
|
||||
notify = nullptr;
|
||||
}
|
||||
}
|
||||
|
||||
kj::Promise<void> collectorFor(K key, kj::Promise<V> promise)
|
||||
{
|
||||
if constexpr (std::is_void_v<V>) {
|
||||
return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); });
|
||||
} else {
|
||||
return promise.then([this, key{std::move(key)}](V v) {
|
||||
oneDone(Item{std::move(key), std::move(v)});
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
kj::ForkedPromise<void> waitForAll(kj::Array<std::pair<K, kj::Promise<V>>> & promises)
|
||||
{
|
||||
kj::Vector<kj::Promise<void>> wrappers;
|
||||
for (auto & [key, promise] : promises) {
|
||||
wrappers.add(collectorFor(std::move(key), std::move(promise)));
|
||||
}
|
||||
|
||||
return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork();
|
||||
}
|
||||
|
||||
public:
|
||||
AsyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> && promises)
|
||||
: allPromises(waitForAll(promises))
|
||||
, remaining(promises.size())
|
||||
, signal{nullptr}
|
||||
{
|
||||
}
|
||||
|
||||
kj::Promise<std::optional<Item>> next()
|
||||
{
|
||||
if (remaining == 0 && results.empty()) {
|
||||
return {std::nullopt};
|
||||
}
|
||||
|
||||
if (!results.empty()) {
|
||||
auto result = std::move(results.front());
|
||||
results.pop_front();
|
||||
return {{std::move(result)}};
|
||||
}
|
||||
|
||||
if (notify == nullptr) {
|
||||
auto pair = kj::newPromiseAndFulfiller<void>();
|
||||
notify = std::move(pair.fulfiller);
|
||||
signal = pair.promise.fork();
|
||||
}
|
||||
|
||||
return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] {
|
||||
return next();
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
/**
|
||||
* Collect the results of a list of promises, in order of completion.
|
||||
* Once any input promise is rejected all promises that have not been
|
||||
* resolved or rejected will be cancelled and the exception rethrown.
|
||||
*/
|
||||
template<typename K, typename V>
|
||||
AsyncCollect<K, V> asyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> promises)
|
||||
{
|
||||
return AsyncCollect<K, V>(std::move(promises));
|
||||
}
|
||||
|
||||
}
|
|
@ -53,6 +53,7 @@ libutil_headers = files(
|
|||
'archive.hh',
|
||||
'args/root.hh',
|
||||
'args.hh',
|
||||
'async-collect.hh',
|
||||
'async-semaphore.hh',
|
||||
'backed-string-view.hh',
|
||||
'box_ptr.hh',
|
||||
|
|
104
tests/unit/libutil/async-collect.cc
Normal file
104
tests/unit/libutil/async-collect.cc
Normal file
|
@ -0,0 +1,104 @@
|
|||
#include "async-collect.hh"
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <kj/array.h>
|
||||
#include <kj/async.h>
|
||||
#include <kj/exception.h>
|
||||
#include <stdexcept>
|
||||
|
||||
namespace nix {
|
||||
|
||||
TEST(AsyncCollect, void)
|
||||
{
|
||||
kj::EventLoop loop;
|
||||
kj::WaitScope waitScope(loop);
|
||||
|
||||
auto a = kj::newPromiseAndFulfiller<void>();
|
||||
auto b = kj::newPromiseAndFulfiller<void>();
|
||||
auto c = kj::newPromiseAndFulfiller<void>();
|
||||
auto d = kj::newPromiseAndFulfiller<void>();
|
||||
|
||||
auto collect = asyncCollect(kj::arr(
|
||||
std::pair(1, std::move(a.promise)),
|
||||
std::pair(2, std::move(b.promise)),
|
||||
std::pair(3, std::move(c.promise)),
|
||||
std::pair(4, std::move(d.promise))
|
||||
));
|
||||
|
||||
auto p = collect.next();
|
||||
ASSERT_FALSE(p.poll(waitScope));
|
||||
|
||||
// collection is ordered
|
||||
c.fulfiller->fulfill();
|
||||
b.fulfiller->fulfill();
|
||||
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_EQ(p.wait(waitScope), 3);
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_EQ(p.wait(waitScope), 2);
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_FALSE(p.poll(waitScope));
|
||||
|
||||
// exceptions propagate
|
||||
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||
|
||||
// first exception aborts collection
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||
}
|
||||
|
||||
TEST(AsyncCollect, nonVoid)
|
||||
{
|
||||
kj::EventLoop loop;
|
||||
kj::WaitScope waitScope(loop);
|
||||
|
||||
auto a = kj::newPromiseAndFulfiller<int>();
|
||||
auto b = kj::newPromiseAndFulfiller<int>();
|
||||
auto c = kj::newPromiseAndFulfiller<int>();
|
||||
auto d = kj::newPromiseAndFulfiller<int>();
|
||||
|
||||
auto collect = asyncCollect(kj::arr(
|
||||
std::pair(1, std::move(a.promise)),
|
||||
std::pair(2, std::move(b.promise)),
|
||||
std::pair(3, std::move(c.promise)),
|
||||
std::pair(4, std::move(d.promise))
|
||||
));
|
||||
|
||||
auto p = collect.next();
|
||||
ASSERT_FALSE(p.poll(waitScope));
|
||||
|
||||
// collection is ordered
|
||||
c.fulfiller->fulfill(1);
|
||||
b.fulfiller->fulfill(2);
|
||||
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_EQ(p.wait(waitScope), std::pair(3, 1));
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_EQ(p.wait(waitScope), std::pair(2, 2));
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_FALSE(p.poll(waitScope));
|
||||
|
||||
// exceptions propagate
|
||||
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
|
||||
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||
|
||||
// first exception aborts collection
|
||||
p = collect.next();
|
||||
ASSERT_TRUE(p.poll(waitScope));
|
||||
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||
}
|
||||
}
|
|
@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency(
|
|||
)
|
||||
|
||||
libutil_tests_sources = files(
|
||||
'libutil/async-collect.cc',
|
||||
'libutil/async-semaphore.cc',
|
||||
'libutil/canon-path.cc',
|
||||
'libutil/checked-arithmetic.cc',
|
||||
|
|
Loading…
Reference in a new issue