diff --git a/src/libutil/serialise.cc b/src/libutil/serialise.cc index 640267a13..ba549c214 100644 --- a/src/libutil/serialise.cc +++ b/src/libutil/serialise.cc @@ -149,12 +149,17 @@ void writeLongLong(unsigned long long n, Sink & sink) } +void writeString(const unsigned char * buf, size_t len, Sink & sink) +{ + writeInt(len, sink); + sink(buf, len); + writePadding(len, sink); +} + + void writeString(const string & s, Sink & sink) { - size_t len = s.length(); - writeInt(len, sink); - sink((const unsigned char *) s.c_str(), len); - writePadding(len, sink); + writeString((const unsigned char *) s.c_str(), s.size(), sink); } @@ -208,6 +213,16 @@ unsigned long long readLongLong(Source & source) } +size_t readString(unsigned char * buf, size_t max, Source & source) +{ + size_t len = readInt(source); + if (len > max) throw Error("string is too long"); + source(buf, len); + readPadding(len, source); + return len; +} + + string readString(Source & source) { size_t len = readInt(source); diff --git a/src/libutil/serialise.hh b/src/libutil/serialise.hh index 25398b09d..efd8e2a06 100644 --- a/src/libutil/serialise.hh +++ b/src/libutil/serialise.hh @@ -114,12 +114,14 @@ struct StringSource : Source void writePadding(size_t len, Sink & sink); void writeInt(unsigned int n, Sink & sink); void writeLongLong(unsigned long long n, Sink & sink); +void writeString(const unsigned char * buf, size_t len, Sink & sink); void writeString(const string & s, Sink & sink); void writeStringSet(const StringSet & ss, Sink & sink); void readPadding(size_t len, Source & source); unsigned int readInt(Source & source); unsigned long long readLongLong(Source & source); +size_t readString(unsigned char * buf, size_t max, Source & source); string readString(Source & source); StringSet readStringSet(Source & source); diff --git a/src/nix-worker/nix-worker.cc b/src/nix-worker/nix-worker.cc index 695e4c38d..85e2105b2 100644 --- a/src/nix-worker/nix-worker.cc +++ b/src/nix-worker/nix-worker.cc @@ -56,7 +56,7 @@ static void tunnelStderr(const unsigned char * buf, size_t count) if (canSendStderr && myPid == getpid()) { try { writeInt(STDERR_NEXT, to); - writeString(string((char *) buf, count), to); + writeString(buf, count, to); to.flush(); } catch (...) { /* Write failed; that means that the other side is @@ -205,7 +205,7 @@ struct TunnelSink : Sink virtual void operator () (const unsigned char * data, size_t len) { writeInt(STDERR_WRITE, to); - writeString(string((const char *) data, len), to); + writeString(data, len, to); } }; @@ -224,16 +224,11 @@ struct TunnelSource : BufferedSource writeInt(STDERR_READ, to); writeInt(len, to); to.flush(); - string s = readString(from); // !!! inefficient + size_t n = readString(data, len, from); startWork(); - - if (s.empty()) throw EndOfFile("unexpected end-of-file"); - if (s.size() > len) throw Error("client sent too much data"); - - memcpy(data, (const unsigned char *) s.c_str(), s.size()); - - return s.size(); + if (n == 0) throw EndOfFile("unexpected end-of-file"); + return n; } };