libstore: de-callback-ify FileTransfer

also add a few more tests for exception propagation behavior. using
packaged_tasks and futures (which only allow a single call to a few
of their methods) introduces error paths that weren't there before.

Change-Id: I42ca5236f156fefec17df972f6e9be45989cf805
This commit is contained in:
eldritch horrors 2024-04-28 02:04:08 +02:00
parent 28a98d152c
commit b66451ae7f
3 changed files with 164 additions and 49 deletions

View file

@ -48,7 +48,7 @@ struct curlFileTransfer : public FileTransfer
FileTransferResult result; FileTransferResult result;
Activity act; Activity act;
bool done = false; // whether either the success or failure function has been called bool done = false; // whether either the success or failure function has been called
Callback<FileTransferResult> callback; std::packaged_task<FileTransferResult(std::exception_ptr, FileTransferResult)> callback;
std::function<void(TransferItem &, std::string_view data)> dataCallback; std::function<void(TransferItem &, std::string_view data)> dataCallback;
CURL * req = 0; CURL * req = 0;
bool active = false; // whether the handle has been added to the multi object bool active = false; // whether the handle has been added to the multi object
@ -83,14 +83,17 @@ struct curlFileTransfer : public FileTransfer
TransferItem(curlFileTransfer & fileTransfer, TransferItem(curlFileTransfer & fileTransfer,
const FileTransferRequest & request, const FileTransferRequest & request,
Callback<FileTransferResult> && callback, std::invocable<std::exception_ptr> auto callback,
std::function<void(TransferItem &, std::string_view data)> dataCallback) std::function<void(TransferItem &, std::string_view data)> dataCallback)
: fileTransfer(fileTransfer) : fileTransfer(fileTransfer)
, request(request) , request(request)
, act(*logger, lvlTalkative, actFileTransfer, , act(*logger, lvlTalkative, actFileTransfer,
fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri), fmt(request.data ? "uploading '%s'" : "downloading '%s'", request.uri),
{request.uri}, request.parentAct) {request.uri}, request.parentAct)
, callback(std::move(callback)) , callback([cb{std::move(callback)}] (std::exception_ptr ex, FileTransferResult r) {
cb(ex);
return r;
})
, dataCallback(std::move(dataCallback)) , dataCallback(std::move(dataCallback))
{ {
requestHeaders = curl_slist_append(requestHeaders, "Accept-Encoding: zstd, br, gzip, deflate, bzip2, xz"); requestHeaders = curl_slist_append(requestHeaders, "Accept-Encoding: zstd, br, gzip, deflate, bzip2, xz");
@ -123,7 +126,7 @@ struct curlFileTransfer : public FileTransfer
{ {
assert(!done); assert(!done);
done = true; done = true;
callback.rethrow(ex); callback(ex, std::move(result));
} }
template<class T> template<class T>
@ -369,7 +372,7 @@ struct curlFileTransfer : public FileTransfer
result.cached = httpStatus == 304; result.cached = httpStatus == 304;
act.progress(result.bodySize, result.bodySize); act.progress(result.bodySize, result.bodySize);
done = true; done = true;
callback(std::move(result)); callback(nullptr, std::move(result));
} }
else { else {
@ -623,7 +626,7 @@ struct curlFileTransfer : public FileTransfer
} }
} }
void enqueueItem(std::shared_ptr<TransferItem> item) std::shared_ptr<TransferItem> enqueueItem(std::shared_ptr<TransferItem> item)
{ {
if (item->request.data if (item->request.data
&& !item->request.uri.starts_with("http://") && !item->request.uri.starts_with("http://")
@ -637,10 +640,11 @@ struct curlFileTransfer : public FileTransfer
state->incoming.push(item); state->incoming.push(item);
} }
wakeup(); wakeup();
return item;
} }
#if ENABLE_S3 #if ENABLE_S3
std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri) static std::tuple<std::string, std::string, Store::Params> parseS3Uri(std::string uri)
{ {
auto [path, params] = splitUriAndParams(uri); auto [path, params] = splitUriAndParams(uri);
@ -655,22 +659,29 @@ struct curlFileTransfer : public FileTransfer
} }
#endif #endif
void enqueueFileTransfer(const FileTransferRequest & request, std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) override
Callback<FileTransferResult> callback) override
{ {
enqueueFileTransfer(request, std::move(callback), {}); return enqueueFileTransfer(
request,
[](std::exception_ptr ex) {
if (ex) {
std::rethrow_exception(ex);
}
},
{}
);
} }
void enqueueFileTransfer(const FileTransferRequest & request, std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request,
Callback<FileTransferResult> callback, std::invocable<std::exception_ptr> auto callback,
std::function<void(TransferItem &, std::string_view data)> dataCallback) std::function<void(TransferItem &, std::string_view data)> dataCallback)
{ {
/* Ugly hack to support s3:// URIs. */ /* Ugly hack to support s3:// URIs. */
if (request.uri.starts_with("s3://")) { if (request.uri.starts_with("s3://")) {
// FIXME: do this on a worker thread // FIXME: do this on a worker thread
try { return std::async(std::launch::deferred, [uri{request.uri}] {
#if ENABLE_S3 #if ENABLE_S3
auto [bucketName, key, params] = parseS3Uri(request.uri); auto [bucketName, key, params] = parseS3Uri(uri);
std::string profile = getOr(params, "profile", ""); std::string profile = getOr(params, "profile", "");
std::string region = getOr(params, "region", Aws::Region::US_EAST_1); std::string region = getOr(params, "region", Aws::Region::US_EAST_1);
@ -683,19 +694,19 @@ struct curlFileTransfer : public FileTransfer
auto s3Res = s3Helper.getObject(bucketName, key); auto s3Res = s3Helper.getObject(bucketName, key);
FileTransferResult res; FileTransferResult res;
if (!s3Res.data) if (!s3Res.data)
throw FileTransferError(NotFound, "S3 object '%s' does not exist", request.uri); throw FileTransferError(NotFound, "S3 object '%s' does not exist", uri);
res.data = std::move(*s3Res.data); res.data = std::move(*s3Res.data);
callback(std::move(res)); return res;
#else #else
throw nix::Error("cannot download '%s' because Lix is not built with S3 support", request.uri); throw nix::Error("cannot download '%s' because Lix is not built with S3 support", uri);
#endif #endif
} catch (...) { callback.rethrow(); } });
return;
} }
enqueueItem(std::make_shared<TransferItem>( return enqueueItem(std::make_shared<TransferItem>(
*this, request, std::move(callback), std::move(dataCallback) *this, request, std::move(callback), std::move(dataCallback)
)); ))
->callback.get_future();
} }
void download(FileTransferRequest && request, Sink & sink) override void download(FileTransferRequest && request, Sink & sink) override
@ -724,18 +735,15 @@ struct curlFileTransfer : public FileTransfer
state->request.notify_one(); state->request.notify_one();
}); });
enqueueFileTransfer(request, enqueueFileTransfer(
{[_state](std::future<FileTransferResult> fut) { request,
[_state](std::exception_ptr ex) {
auto state(_state->lock()); auto state(_state->lock());
state->done = true; state->done = true;
try { state->exc = ex;
fut.get();
} catch (...) {
state->exc = std::current_exception();
}
state->avail.notify_one(); state->avail.notify_one();
state->request.notify_one(); state->request.notify_one();
}}, },
[_state](TransferItem & transfer, std::string_view data) { [_state](TransferItem & transfer, std::string_view data) {
auto state(_state->lock()); auto state(_state->lock());
@ -758,7 +766,8 @@ struct curlFileTransfer : public FileTransfer
thread. */ thread. */
state->data.append(data); state->data.append(data);
state->avail.notify_one(); state->avail.notify_one();
}); }
);
std::unique_ptr<FinishSink> decompressor; std::unique_ptr<FinishSink> decompressor;
@ -827,20 +836,6 @@ ref<FileTransfer> makeFileTransfer()
return makeCurlFileTransfer(); return makeCurlFileTransfer();
} }
std::future<FileTransferResult> FileTransfer::enqueueFileTransfer(const FileTransferRequest & request)
{
auto promise = std::make_shared<std::promise<FileTransferResult>>();
enqueueFileTransfer(request,
{[promise](std::future<FileTransferResult> fut) {
try {
promise->set_value(fut.get());
} catch (...) {
promise->set_exception(std::current_exception());
}
}});
return promise->get_future();
}
FileTransferResult FileTransfer::download(const FileTransferRequest & request) FileTransferResult FileTransfer::download(const FileTransferRequest & request)
{ {
return enqueueFileTransfer(request).get(); return enqueueFileTransfer(request).get();

View file

@ -95,10 +95,7 @@ struct FileTransfer
* the download. The future may throw a FileTransferError * the download. The future may throw a FileTransferError
* exception. * exception.
*/ */
virtual void enqueueFileTransfer(const FileTransferRequest & request, virtual std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request) = 0;
Callback<FileTransferResult> callback) = 0;
std::future<FileTransferResult> enqueueFileTransfer(const FileTransferRequest & request);
/** /**
* Synchronously download a file. * Synchronously download a file.

View file

@ -1,12 +1,114 @@
#include "filetransfer.hh" #include "filetransfer.hh"
#include <cstdint>
#include <exception>
#include <future> #include <future>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <netinet/in.h>
#include <stdexcept>
#include <string_view>
#include <sys/poll.h>
#include <sys/socket.h>
#include <thread>
#include <unistd.h>
// local server tests don't work on darwin without some incantations
// the horrors do not want to look up. contributions welcome though!
#if __APPLE__
#define NOT_ON_DARWIN(n) DISABLED_##n
#else
#define NOT_ON_DARWIN(n) n
#endif
using namespace std::chrono_literals; using namespace std::chrono_literals;
namespace nix { namespace nix {
static std::tuple<uint16_t, AutoCloseFD>
serveHTTP(std::string_view status, std::string_view headers, std::function<std::string_view()> content)
{
AutoCloseFD listener(::socket(AF_INET6, SOCK_STREAM, 0));
if (!listener) {
throw SysError(errno, "socket() failed");
}
Pipe trigger;
trigger.create();
sockaddr_in6 addr = {
.sin6_family = AF_INET6,
.sin6_addr = IN6ADDR_LOOPBACK_INIT,
};
socklen_t len = sizeof(addr);
if (::bind(listener.get(), reinterpret_cast<const sockaddr *>(&addr), sizeof(addr)) < 0) {
throw SysError(errno, "bind() failed");
}
if (::getsockname(listener.get(), reinterpret_cast<sockaddr *>(&addr), &len) < 0) {
throw SysError(errno, "getsockname() failed");
}
if (::listen(listener.get(), 1) < 0) {
throw SysError(errno, "listen() failed");
}
std::thread(
[status, headers, content](AutoCloseFD socket, AutoCloseFD trigger) {
while (true) {
pollfd pfds[2] = {
{
.fd = socket.get(),
.events = POLLIN,
},
{
.fd = trigger.get(),
.events = POLLHUP,
},
};
if (::poll(pfds, 2, -1) <= 0) {
throw SysError(errno, "poll() failed");
}
if (pfds[1].revents & POLLHUP) {
return;
}
if (!(pfds[0].revents & POLLIN)) {
continue;
}
AutoCloseFD conn(::accept(socket.get(), nullptr, nullptr));
if (!conn) {
throw SysError(errno, "accept() failed");
}
auto send = [&](std::string_view bit) {
while (!bit.empty()) {
auto written = ::write(conn.get(), bit.data(), bit.size());
if (written < 0) {
throw SysError(errno, "write() failed");
}
bit.remove_prefix(written);
}
};
send("HTTP/1.1 ");
send(status);
send("\r\n");
send(headers);
send("\r\n");
send(content());
::shutdown(conn.get(), SHUT_RDWR);
}
},
std::move(listener),
std::move(trigger.readSide)
)
.detach();
return {
ntohs(addr.sin6_port),
std::move(trigger.writeSide),
};
}
TEST(FileTransfer, exceptionAbortsDownload) TEST(FileTransfer, exceptionAbortsDownload)
{ {
struct Done struct Done
@ -29,4 +131,25 @@ TEST(FileTransfer, exceptionAbortsDownload)
(void) new auto(std::move(reset)); (void) new auto(std::move(reset));
} }
} }
TEST(FileTransfer, NOT_ON_DARWIN(reportsSetupErrors))
{
auto [port, srv] = serveHTTP("404 not found", "", [] { return ""; });
auto ft = makeFileTransfer();
ASSERT_THROW(
ft->download(FileTransferRequest(fmt("http://[::1]:%d/index", port))),
FileTransferError);
}
TEST(FileTransfer, NOT_ON_DARWIN(reportsTransferError))
{
auto [port, srv] = serveHTTP("200 ok", "content-length: 100\r\n", [] {
std::this_thread::sleep_for(10ms);
return "";
});
auto ft = makeFileTransfer();
FileTransferRequest req(fmt("http://[::1]:%d/index", port));
req.baseRetryTimeMs = 0;
ASSERT_THROW(ft->download(req), FileTransferError);
}
} }