diff --git a/src/libutil/compression.cc b/src/libutil/compression.cc index 678557a58..e78d76500 100644 --- a/src/libutil/compression.cc +++ b/src/libutil/compression.cc @@ -137,53 +137,55 @@ struct NoneSink : CompressionSink void writeUnbuffered(std::string_view data) override { nextSink(data); } }; -struct BrotliDecompressionSink : ChunkedCompressionSink +struct BrotliDecompressionSource : Source { - Sink & nextSink; - BrotliDecoderState * state; - bool finished = false; + static constexpr size_t BUF_SIZE = 32 * 1024; + std::unique_ptr buf; + size_t avail_in = 0; + const uint8_t * next_in; - BrotliDecompressionSink(Sink & nextSink) : nextSink(nextSink) + Source * inner; + std::unique_ptr state; + + BrotliDecompressionSource(Source & inner) + : buf(std::make_unique(BUF_SIZE)) + , inner(&inner) + , state{ + BrotliDecoderCreateInstance(nullptr, nullptr, nullptr), BrotliDecoderDestroyInstance} { - state = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); - if (!state) + if (!state) { throw CompressionError("unable to initialize brotli decoder"); + } } - ~BrotliDecompressionSink() + size_t read(char * data, size_t len) override { - BrotliDecoderDestroyInstance(state); - } + uint8_t * out = (uint8_t *) data; + const auto * begin = out; - void finish() override - { - flush(); - writeInternal({}); - } + try { + while (len && !BrotliDecoderIsFinished(state.get())) { + checkInterrupt(); - void writeInternal(std::string_view data) override - { - auto next_in = (const uint8_t *) data.data(); - size_t avail_in = data.size(); - uint8_t * next_out = outbuf; - size_t avail_out = sizeof(outbuf); + while (avail_in == 0) { + avail_in = inner->read(buf.get(), BUF_SIZE); + next_in = (const uint8_t *) buf.get(); + } - while (!finished && (!data.data() || avail_in)) { - checkInterrupt(); - - if (!BrotliDecoderDecompressStream(state, - &avail_in, &next_in, - &avail_out, &next_out, - nullptr)) - throw CompressionError("error while decompressing brotli file"); - - if (avail_out < sizeof(outbuf) || avail_in == 0) { - nextSink({(char *) outbuf, sizeof(outbuf) - avail_out}); - next_out = outbuf; - avail_out = sizeof(outbuf); + if (!BrotliDecoderDecompressStream( + state.get(), &avail_in, &next_in, &len, &out, nullptr + )) + { + throw CompressionError("error while decompressing brotli file"); + } } + } catch (EndOfFile &) { + } - finished = BrotliDecoderIsFinished(state); + if (begin != out) { + return out - begin; + } else { + throw EndOfFile("brotli stream exhausted"); } } }; @@ -202,7 +204,19 @@ std::unique_ptr makeDecompressionSink(const std::string & method, Si if (method == "none" || method == "") return std::make_unique(nextSink); else if (method == "br") - return std::make_unique(nextSink); + return sourceToSink([&](Source & source) { + BrotliDecompressionSource wrapped{source}; + wrapped.drainInto(nextSink); + // special handling because sourceToSink is screwy: try + // to read the source one final time and fail when that + // succeeds (to reject trailing garbage in input data). + try { + char buf; + source(&buf, 1); + throw Error("garbage at end of brotli stream detected"); + } catch (EndOfFile &) { + } + }); else return sourceToSink([&](Source & source) { auto decompressionSource = std::make_unique(source);