Merge "libutil: add async collection mechanism" into main
This commit is contained in:
commit
619a93bd54
101
src/libutil/async-collect.hh
Normal file
101
src/libutil/async-collect.hh
Normal file
|
@ -0,0 +1,101 @@
|
||||||
|
#pragma once
|
||||||
|
/// @file
|
||||||
|
|
||||||
|
#include <kj/async.h>
|
||||||
|
#include <kj/common.h>
|
||||||
|
#include <kj/vector.h>
|
||||||
|
#include <list>
|
||||||
|
#include <optional>
|
||||||
|
#include <type_traits>
|
||||||
|
|
||||||
|
namespace nix {
|
||||||
|
|
||||||
|
template<typename K, typename V>
|
||||||
|
class AsyncCollect
|
||||||
|
{
|
||||||
|
public:
|
||||||
|
using Item = std::conditional_t<std::is_void_v<V>, K, std::pair<K, V>>;
|
||||||
|
|
||||||
|
private:
|
||||||
|
kj::ForkedPromise<void> allPromises;
|
||||||
|
std::list<Item> results;
|
||||||
|
size_t remaining;
|
||||||
|
|
||||||
|
kj::ForkedPromise<void> signal;
|
||||||
|
kj::Maybe<kj::Own<kj::PromiseFulfiller<void>>> notify;
|
||||||
|
|
||||||
|
void oneDone(Item item)
|
||||||
|
{
|
||||||
|
results.emplace_back(std::move(item));
|
||||||
|
remaining -= 1;
|
||||||
|
KJ_IF_MAYBE (n, notify) {
|
||||||
|
(*n)->fulfill();
|
||||||
|
notify = nullptr;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kj::Promise<void> collectorFor(K key, kj::Promise<V> promise)
|
||||||
|
{
|
||||||
|
if constexpr (std::is_void_v<V>) {
|
||||||
|
return promise.then([this, key{std::move(key)}] { oneDone(std::move(key)); });
|
||||||
|
} else {
|
||||||
|
return promise.then([this, key{std::move(key)}](V v) {
|
||||||
|
oneDone(Item{std::move(key), std::move(v)});
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kj::ForkedPromise<void> waitForAll(kj::Array<std::pair<K, kj::Promise<V>>> & promises)
|
||||||
|
{
|
||||||
|
kj::Vector<kj::Promise<void>> wrappers;
|
||||||
|
for (auto & [key, promise] : promises) {
|
||||||
|
wrappers.add(collectorFor(std::move(key), std::move(promise)));
|
||||||
|
}
|
||||||
|
|
||||||
|
return kj::joinPromisesFailFast(wrappers.releaseAsArray()).fork();
|
||||||
|
}
|
||||||
|
|
||||||
|
public:
|
||||||
|
AsyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> && promises)
|
||||||
|
: allPromises(waitForAll(promises))
|
||||||
|
, remaining(promises.size())
|
||||||
|
, signal{nullptr}
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
kj::Promise<std::optional<Item>> next()
|
||||||
|
{
|
||||||
|
if (remaining == 0 && results.empty()) {
|
||||||
|
return {std::nullopt};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!results.empty()) {
|
||||||
|
auto result = std::move(results.front());
|
||||||
|
results.pop_front();
|
||||||
|
return {{std::move(result)}};
|
||||||
|
}
|
||||||
|
|
||||||
|
if (notify == nullptr) {
|
||||||
|
auto pair = kj::newPromiseAndFulfiller<void>();
|
||||||
|
notify = std::move(pair.fulfiller);
|
||||||
|
signal = pair.promise.fork();
|
||||||
|
}
|
||||||
|
|
||||||
|
return signal.addBranch().exclusiveJoin(allPromises.addBranch()).then([this] {
|
||||||
|
return next();
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Collect the results of a list of promises, in order of completion.
|
||||||
|
* Once any input promise is rejected all promises that have not been
|
||||||
|
* resolved or rejected will be cancelled and the exception rethrown.
|
||||||
|
*/
|
||||||
|
template<typename K, typename V>
|
||||||
|
AsyncCollect<K, V> asyncCollect(kj::Array<std::pair<K, kj::Promise<V>>> promises)
|
||||||
|
{
|
||||||
|
return AsyncCollect<K, V>(std::move(promises));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
|
@ -53,6 +53,7 @@ libutil_headers = files(
|
||||||
'archive.hh',
|
'archive.hh',
|
||||||
'args/root.hh',
|
'args/root.hh',
|
||||||
'args.hh',
|
'args.hh',
|
||||||
|
'async-collect.hh',
|
||||||
'async-semaphore.hh',
|
'async-semaphore.hh',
|
||||||
'backed-string-view.hh',
|
'backed-string-view.hh',
|
||||||
'box_ptr.hh',
|
'box_ptr.hh',
|
||||||
|
|
104
tests/unit/libutil/async-collect.cc
Normal file
104
tests/unit/libutil/async-collect.cc
Normal file
|
@ -0,0 +1,104 @@
|
||||||
|
#include "async-collect.hh"
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <kj/array.h>
|
||||||
|
#include <kj/async.h>
|
||||||
|
#include <kj/exception.h>
|
||||||
|
#include <stdexcept>
|
||||||
|
|
||||||
|
namespace nix {
|
||||||
|
|
||||||
|
TEST(AsyncCollect, void)
|
||||||
|
{
|
||||||
|
kj::EventLoop loop;
|
||||||
|
kj::WaitScope waitScope(loop);
|
||||||
|
|
||||||
|
auto a = kj::newPromiseAndFulfiller<void>();
|
||||||
|
auto b = kj::newPromiseAndFulfiller<void>();
|
||||||
|
auto c = kj::newPromiseAndFulfiller<void>();
|
||||||
|
auto d = kj::newPromiseAndFulfiller<void>();
|
||||||
|
|
||||||
|
auto collect = asyncCollect(kj::arr(
|
||||||
|
std::pair(1, std::move(a.promise)),
|
||||||
|
std::pair(2, std::move(b.promise)),
|
||||||
|
std::pair(3, std::move(c.promise)),
|
||||||
|
std::pair(4, std::move(d.promise))
|
||||||
|
));
|
||||||
|
|
||||||
|
auto p = collect.next();
|
||||||
|
ASSERT_FALSE(p.poll(waitScope));
|
||||||
|
|
||||||
|
// collection is ordered
|
||||||
|
c.fulfiller->fulfill();
|
||||||
|
b.fulfiller->fulfill();
|
||||||
|
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_EQ(p.wait(waitScope), 3);
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_EQ(p.wait(waitScope), 2);
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_FALSE(p.poll(waitScope));
|
||||||
|
|
||||||
|
// exceptions propagate
|
||||||
|
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||||
|
|
||||||
|
// first exception aborts collection
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(AsyncCollect, nonVoid)
|
||||||
|
{
|
||||||
|
kj::EventLoop loop;
|
||||||
|
kj::WaitScope waitScope(loop);
|
||||||
|
|
||||||
|
auto a = kj::newPromiseAndFulfiller<int>();
|
||||||
|
auto b = kj::newPromiseAndFulfiller<int>();
|
||||||
|
auto c = kj::newPromiseAndFulfiller<int>();
|
||||||
|
auto d = kj::newPromiseAndFulfiller<int>();
|
||||||
|
|
||||||
|
auto collect = asyncCollect(kj::arr(
|
||||||
|
std::pair(1, std::move(a.promise)),
|
||||||
|
std::pair(2, std::move(b.promise)),
|
||||||
|
std::pair(3, std::move(c.promise)),
|
||||||
|
std::pair(4, std::move(d.promise))
|
||||||
|
));
|
||||||
|
|
||||||
|
auto p = collect.next();
|
||||||
|
ASSERT_FALSE(p.poll(waitScope));
|
||||||
|
|
||||||
|
// collection is ordered
|
||||||
|
c.fulfiller->fulfill(1);
|
||||||
|
b.fulfiller->fulfill(2);
|
||||||
|
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_EQ(p.wait(waitScope), std::pair(3, 1));
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_EQ(p.wait(waitScope), std::pair(2, 2));
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_FALSE(p.poll(waitScope));
|
||||||
|
|
||||||
|
// exceptions propagate
|
||||||
|
a.fulfiller->rejectIfThrows([] { throw std::runtime_error("test"); });
|
||||||
|
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||||
|
|
||||||
|
// first exception aborts collection
|
||||||
|
p = collect.next();
|
||||||
|
ASSERT_TRUE(p.poll(waitScope));
|
||||||
|
ASSERT_THROW(p.wait(waitScope), kj::Exception);
|
||||||
|
}
|
||||||
|
}
|
|
@ -39,6 +39,7 @@ liblixutil_test_support = declare_dependency(
|
||||||
)
|
)
|
||||||
|
|
||||||
libutil_tests_sources = files(
|
libutil_tests_sources = files(
|
||||||
|
'libutil/async-collect.cc',
|
||||||
'libutil/async-semaphore.cc',
|
'libutil/async-semaphore.cc',
|
||||||
'libutil/canon-path.cc',
|
'libutil/canon-path.cc',
|
||||||
'libutil/checked-arithmetic.cc',
|
'libutil/checked-arithmetic.cc',
|
||||||
|
|
Loading…
Reference in a new issue