Merge pull request #9279 from tfc/util-improv

Util improvements
This commit is contained in:
John Ericson 2024-01-16 15:05:28 -05:00 committed by GitHub
commit 799e662cbe
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 51 additions and 37 deletions

View file

@ -96,7 +96,7 @@ void drainFD(int fd, Sink & sink, bool block)
throw SysError("making file descriptor non-blocking");
}
Finally finally([&]() {
Finally finally([&] {
if (!block) {
if (fcntl(fd, F_SETFL, saved) == -1)
throw SysError("making file descriptor blocking");
@ -114,7 +114,7 @@ void drainFD(int fd, Sink & sink, bool block)
throw SysError("reading from file");
}
else if (rd == 0) break;
else sink({(char *) buf.data(), (size_t) rd});
else sink({reinterpret_cast<char *>(buf.data()), size_t(rd)});
}
}

View file

@ -90,7 +90,7 @@ Path canonPath(PathView path, bool resolveSymlinks)
/* Normal component; copy it. */
else {
s += '/';
if (const auto slash = path.find('/'); slash == std::string::npos) {
if (const auto slash = path.find('/'); slash == path.npos) {
s += path;
path = {};
} else {
@ -116,14 +116,18 @@ Path canonPath(PathView path, bool resolveSymlinks)
}
}
return s.empty() ? "/" : std::move(s);
if (s.empty()) {
s = "/";
}
return s;
}
Path dirOf(const PathView path)
{
Path::size_type pos = path.rfind('/');
if (pos == std::string::npos)
if (pos == path.npos)
return ".";
return pos == 0 ? "/" : Path(path, 0, pos);
}
@ -139,7 +143,7 @@ std::string_view baseNameOf(std::string_view path)
last -= 1;
auto pos = path.rfind('/', last);
if (pos == std::string::npos)
if (pos == path.npos)
pos = 0;
else
pos += 1;

View file

@ -131,7 +131,7 @@ void killUser(uid_t uid)
users to which the current process can send signals. So we
fork a process, switch to uid, and send a mass kill. */
Pid pid = startProcess([&]() {
Pid pid = startProcess([&] {
if (setuid(uid) == -1)
throw SysError("setting uid");
@ -168,11 +168,12 @@ void killUser(uid_t uid)
//////////////////////////////////////////////////////////////////////
using ChildWrapperFunction = std::function<void()>;
/* Wrapper around vfork to prevent the child process from clobbering
the caller's stack frame in the parent. */
static pid_t doFork(bool allowVfork, std::function<void()> fun) __attribute__((noinline));
static pid_t doFork(bool allowVfork, std::function<void()> fun)
static pid_t doFork(bool allowVfork, ChildWrapperFunction & fun) __attribute__((noinline));
static pid_t doFork(bool allowVfork, ChildWrapperFunction & fun)
{
#ifdef __linux__
pid_t pid = allowVfork ? vfork() : fork();
@ -188,8 +189,8 @@ static pid_t doFork(bool allowVfork, std::function<void()> fun)
#if __linux__
static int childEntry(void * arg)
{
auto main = (std::function<void()> *) arg;
(*main)();
auto & fun = *reinterpret_cast<ChildWrapperFunction*>(arg);
fun();
return 1;
}
#endif
@ -197,7 +198,7 @@ static int childEntry(void * arg)
pid_t startProcess(std::function<void()> fun, const ProcessOptions & options)
{
std::function<void()> wrapper = [&]() {
ChildWrapperFunction wrapper = [&] {
if (!options.allowVfork)
logger = makeSimpleLogger();
try {
@ -225,11 +226,11 @@ pid_t startProcess(std::function<void()> fun, const ProcessOptions & options)
assert(!(options.cloneFlags & CLONE_VM));
size_t stackSize = 1 * 1024 * 1024;
auto stack = (char *) mmap(0, stackSize,
PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0);
auto stack = static_cast<char *>(mmap(0, stackSize,
PROT_WRITE | PROT_READ, MAP_PRIVATE | MAP_ANONYMOUS | MAP_STACK, -1, 0));
if (stack == MAP_FAILED) throw SysError("allocating stack");
Finally freeStack([&]() { munmap(stack, stackSize); });
Finally freeStack([&] { munmap(stack, stackSize); });
pid = clone(childEntry, stack + stackSize, options.cloneFlags | SIGCHLD, &wrapper);
#else
@ -308,7 +309,7 @@ void runProgram2(const RunOptions & options)
}
/* Fork. */
Pid pid = startProcess([&]() {
Pid pid = startProcess([&] {
if (options.environment)
replaceEnv(*options.environment);
if (options.standardOut && dup2(out.writeSide.get(), STDOUT_FILENO) == -1)
@ -350,7 +351,7 @@ void runProgram2(const RunOptions & options)
std::promise<void> promise;
Finally doJoin([&]() {
Finally doJoin([&] {
if (writerThread.joinable())
writerThread.join();
});
@ -358,7 +359,7 @@ void runProgram2(const RunOptions & options)
if (source) {
in.readSide.close();
writerThread = std::thread([&]() {
writerThread = std::thread([&] {
try {
std::vector<char> buf(8 * 1024);
while (true) {

View file

@ -179,7 +179,7 @@ std::unique_ptr<InterruptCallback> createInterruptCallback(std::function<void()>
auto token = interruptCallbacks->nextToken++;
interruptCallbacks->callbacks.emplace(token, callback);
auto res = std::make_unique<InterruptCallbackImpl>();
std::unique_ptr<InterruptCallbackImpl> res {new InterruptCallbackImpl{}};
res->token = token;
return std::unique_ptr<InterruptCallback>(res.release());

View file

@ -38,6 +38,14 @@ AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
return fdSocket;
}
static struct sockaddr* safeSockAddrPointerCast(struct sockaddr_un *addr) {
// Casting between types like these legacy C library interfaces require
// is forbidden in C++.
// To maintain backwards compatibility, the implementation of the
// bind function contains some hints to the compiler that allow for this
// special case.
return reinterpret_cast<struct sockaddr *>(addr);
}
void bind(int fd, const std::string & path)
{
@ -45,9 +53,10 @@ void bind(int fd, const std::string & path)
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pid pid = startProcess([&]() {
Pid pid = startProcess([&] {
Path dir = dirOf(path);
if (chdir(dir.c_str()) == -1)
throw SysError("chdir to '%s' failed", dir);
@ -55,7 +64,7 @@ void bind(int fd, const std::string & path)
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (bind(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1)
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
_exit(0);
});
@ -64,7 +73,7 @@ void bind(int fd, const std::string & path)
throw Error("cannot bind to socket '%s'", path);
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (bind(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1)
if (bind(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot bind to socket '%s'", path);
}
}
@ -74,11 +83,12 @@ void connect(int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
auto psaddr {safeSockAddrPointerCast(&addr)};
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pipe pipe;
pipe.create();
Pid pid = startProcess([&]() {
Pid pid = startProcess([&] {
try {
pipe.readSide.close();
Path dir = dirOf(path);
@ -88,7 +98,7 @@ void connect(int fd, const std::string & path)
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1)
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
writeFull(pipe.writeSide.get(), "0\n");
} catch (SysError & e) {
@ -107,7 +117,7 @@ void connect(int fd, const std::string & path)
}
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (connect(fd, (struct sockaddr *) &addr, sizeof(addr)) == -1)
if (connect(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot connect to socket at '%s'", path);
}
}

View file

@ -52,9 +52,9 @@ template<class C> C tokenizeString(std::string_view s, std::string_view separato
{
C result;
auto pos = s.find_first_not_of(separators, 0);
while (pos != std::string_view::npos) {
while (pos != s.npos) {
auto end = s.find_first_of(separators, pos + 1);
if (end == std::string_view::npos) end = s.size();
if (end == s.npos) end = s.size();
result.insert(result.end(), std::string(s, pos, end - pos));
pos = s.find_first_not_of(separators, end);
}
@ -69,7 +69,7 @@ template std::vector<std::string> tokenizeString(std::string_view s, std::string
std::string chomp(std::string_view s)
{
size_t i = s.find_last_not_of(" \n\r\t");
return i == std::string_view::npos ? "" : std::string(s, 0, i + 1);
return i == s.npos ? "" : std::string(s, 0, i + 1);
}
@ -89,7 +89,7 @@ std::string replaceStrings(
{
if (from.empty()) return res;
size_t pos = 0;
while ((pos = res.find(from, pos)) != std::string::npos) {
while ((pos = res.find(from, pos)) != res.npos) {
res.replace(pos, from.size(), to);
pos += to.size();
}
@ -102,7 +102,7 @@ std::string rewriteStrings(std::string s, const StringMap & rewrites)
for (auto & i : rewrites) {
if (i.first == i.second) continue;
size_t j = 0;
while ((j = s.find(i.first, j)) != std::string::npos)
while ((j = s.find(i.first, j)) != s.npos)
s.replace(j, i.first.size(), i.second);
}
return s;
@ -122,12 +122,11 @@ bool hasSuffix(std::string_view s, std::string_view suffix)
}
std::string toLower(const std::string & s)
std::string toLower(std::string s)
{
std::string r(s);
for (auto & c : r)
for (auto & c : s)
c = std::tolower(c);
return r;
return s;
}
@ -135,7 +134,7 @@ std::string shellEscape(const std::string_view s)
{
std::string r;
r.reserve(s.size() + 2);
r += "'";
r += '\'';
for (auto & i : s)
if (i == '\'') r += "'\\''"; else r += i;
r += '\'';
@ -184,7 +183,7 @@ std::string base64Encode(std::string_view s)
std::string base64Decode(std::string_view s)
{
constexpr char npos = -1;
constexpr std::array<char, 256> base64DecodeChars = [&]() {
constexpr std::array<char, 256> base64DecodeChars = [&] {
std::array<char, 256> result{};
for (auto& c : result)
c = npos;

View file

@ -180,7 +180,7 @@ bool hasSuffix(std::string_view s, std::string_view suffix);
/**
* Convert a string to lower case.
*/
std::string toLower(const std::string & s);
std::string toLower(std::string s);
/**