diff --git a/src/libstore/filetransfer.cc b/src/libstore/filetransfer.cc index 9010aa382..aba8b7fbb 100644 --- a/src/libstore/filetransfer.cc +++ b/src/libstore/filetransfer.cc @@ -64,7 +64,8 @@ struct curlFileTransfer : public FileTransfer } phase = initialSetup; std::promise metadataPromise; std::packaged_task doneCallback; - std::function dataCallback; + // return false from dataCallback to pause the transfer without consuming data + std::function dataCallback; CURL * req; // must never be nullptr std::string statusMsg; @@ -104,7 +105,7 @@ struct curlFileTransfer : public FileTransfer const Headers & headers, ActivityId parentAct, std::invocable auto doneCallback, - std::function dataCallback, + std::function dataCallback, std::optional uploadData, bool noBody ) @@ -198,15 +199,16 @@ struct curlFileTransfer : public FileTransfer try { maybeFinishSetup(); - bodySize += realSize; - if (successfulStatuses.count(getHTTPStatus()) && this->dataCallback) { + if (!dataCallback({static_cast(contents), realSize})) { + return CURL_WRITEFUNC_PAUSE; + } writtenToSink += realSize; - dataCallback({static_cast(contents), realSize}); } else { this->downloadData.append(static_cast(contents), realSize); } + bodySize += realSize; return realSize; } catch (...) { callbackException = std::current_exception(); @@ -501,6 +503,13 @@ struct curlFileTransfer : public FileTransfer fail(std::move(exc)); } } + + void unpause() + { + auto lock = fileTransfer.state_.lock(); + lock->unpause.push_back(shared_from_this()); + fileTransfer.wakeup(); + } }; struct State @@ -512,6 +521,7 @@ struct curlFileTransfer : public FileTransfer }; bool quit = false; std::priority_queue, std::vector>, EmbargoComparator> incoming; + std::vector> unpause; }; Sync state_; @@ -627,6 +637,10 @@ struct curlFileTransfer : public FileTransfer { auto state(state_.lock()); + for (auto & item : state->unpause) { + curl_easy_pause(item->req, CURLPAUSE_CONT); + } + state->unpause.clear(); while (!state->incoming.empty()) { auto item = state->incoming.top(); if (item->embargo <= now) { @@ -826,14 +840,14 @@ struct curlFileTransfer : public FileTransfer download thread. (Hopefully sleeping will throttle the sender.) */ if (state->data.size() > 1024 * 1024) { - debug("download buffer is full; going to sleep"); - state.wait_for(state->request, std::chrono::seconds(10)); + return false; } /* Append data to the buffer and wake up the calling thread. */ state->data.append(data); state->avail.notify_one(); + return true; }, std::move(data), noBody @@ -842,10 +856,17 @@ struct curlFileTransfer : public FileTransfer struct TransferSource : Source { const std::shared_ptr> _state; + std::shared_ptr transfer; std::string chunk; std::string_view buffered; - explicit TransferSource(const std::shared_ptr> & state) : _state(state) {} + explicit TransferSource( + const std::shared_ptr> & state, std::shared_ptr transfer + ) + : _state(state) + , transfer(std::move(transfer)) + { + } ~TransferSource() { @@ -868,6 +889,7 @@ struct curlFileTransfer : public FileTransfer return; } + transfer->unpause(); state.wait(state->avail); } @@ -910,7 +932,7 @@ struct curlFileTransfer : public FileTransfer }; auto metadata = item->metadataPromise.get_future().get(); - auto source = make_box_ptr(_state); + auto source = make_box_ptr(_state, item); auto lock(_state->lock()); source->awaitData(lock); return {std::move(metadata), std::move(source)}; diff --git a/tests/unit/libstore/filetransfer.cc b/tests/unit/libstore/filetransfer.cc index 5885a8059..6690c4a94 100644 --- a/tests/unit/libstore/filetransfer.cc +++ b/tests/unit/libstore/filetransfer.cc @@ -27,9 +27,28 @@ namespace { struct Reply { std::string status, headers; - std::function content; -}; + std::function(int)> content; + Reply( + std::string_view status, std::string_view headers, std::function content + ) + : Reply(status, headers, [content](int round) { + return round == 0 ? std::optional(content()) : std::nullopt; + }) + { + } + + Reply( + std::string_view status, + std::string_view headers, + std::function(int)> content + ) + : status(status) + , headers(headers) + , content(content) + { + } +}; } namespace nix { @@ -89,25 +108,44 @@ serveHTTP(std::vector replies) throw SysError(errno, "accept() failed"); } - auto send = [&](std::string_view bit) { - while (!bit.empty()) { - auto written = ::write(conn.get(), bit.data(), bit.size()); - if (written < 0) { - throw SysError(errno, "write() failed"); - } - bit.remove_prefix(written); - } - }; - const auto & reply = replies[at++ % replies.size()]; - send("HTTP/1.1 "); - send(reply.status); - send("\r\n"); - send(reply.headers); - send("\r\n"); - send(reply.content()); - ::shutdown(conn.get(), SHUT_RDWR); + std::thread([=, conn{std::move(conn)}] { + auto send = [&](std::string_view bit) { + while (!bit.empty()) { + auto written = ::write(conn.get(), bit.data(), bit.size()); + if (written < 0) { + throw SysError(errno, "write() failed"); + } + bit.remove_prefix(written); + } + }; + + send("HTTP/1.1 "); + send(reply.status); + send("\r\n"); + send(reply.headers); + send("\r\n"); + for (int round = 0; ; round++) { + if (auto content = reply.content(round); content.has_value()) { + send(*content); + } else { + break; + } + } + ::shutdown(conn.get(), SHUT_WR); + for (;;) { + char buf[1]; + switch (read(conn.get(), buf, 1)) { + case 0: + return; // remote closed + case 1: + continue; // connection still held open by remote + default: + throw SysError(errno, "read() failed"); + } + } + }).detach(); } }, std::move(listener), @@ -219,4 +257,35 @@ TEST(FileTransfer, usesIntermediateLinkHeaders) ASSERT_EQ(result.immutableUrl, "http://foo"); } +TEST(FileTransfer, stalledReaderDoesntBlockOthers) +{ + auto [port, srv] = serveHTTP({ + {"200 ok", + "content-length: 100000000\r\n", + [](int round) mutable { + return round < 100 ? std::optional(std::string(1'000'000, ' ')) : std::nullopt; + }}, + }); + auto ft = makeFileTransfer(0); + auto [_result1, data1] = ft->download(fmt("http://[::1]:%d", port)); + auto [_result2, data2] = ft->download(fmt("http://[::1]:%d", port)); + auto drop = [](Source & source, size_t size) { + char buf[1000]; + while (size > 0) { + auto round = std::min(size, sizeof(buf)); + source(buf, round); + size -= round; + } + }; + // read 10M of each of the 100M, then the rest. neither reader should + // block the other, nor should it take that long to copy 200MB total. + drop(*data1, 10'000'000); + drop(*data2, 10'000'000); + drop(*data1, 90'000'000); + drop(*data2, 90'000'000); + + ASSERT_THROW(drop(*data1, 1), EndOfFile); + ASSERT_THROW(drop(*data2, 1), EndOfFile); +} + }