util.{hh,cc}: Split out unix-domain-socket.{hh,cc}

Change-Id: I3f9a628e0f8998b6146f5caa8ae9842361a66b8b
This commit is contained in:
Tom Hubrecht 2024-05-28 14:41:48 +02:00
parent e81ed5f12d
commit 5b5a75979a
10 changed files with 142 additions and 114 deletions

View file

@ -15,6 +15,7 @@
#include "personality.hh" #include "personality.hh"
#include "namespaces.hh" #include "namespaces.hh"
#include "child.hh" #include "child.hh"
#include "unix-domain-socket.hh"
#include <regex> #include <regex>
#include <queue> #include <queue>

View file

@ -4,6 +4,7 @@
#include "processes.hh" #include "processes.hh"
#include "signals.hh" #include "signals.hh"
#include "finally.hh" #include "finally.hh"
#include "unix-domain-socket.hh"
#include <functional> #include <functional>
#include <queue> #include <queue>

View file

@ -1,4 +1,5 @@
#include "uds-remote-store.hh" #include "uds-remote-store.hh"
#include "unix-domain-socket.hh"
#include "worker-protocol.hh" #include "worker-protocol.hh"
#include <sys/types.h> #include <sys/types.h>

View file

@ -77,13 +77,4 @@ void closeOnExec(int fd);
MakeError(EndOfFile, Error); MakeError(EndOfFile, Error);
/**
* Create a Unix domain socket.
*/
AutoCloseFD createUnixDomainSocket();
/**
* Create a Unix domain socket in listen mode.
*/
AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode);
} }

View file

@ -35,6 +35,7 @@ libutil_sources = files(
'tarfile.cc', 'tarfile.cc',
'terminal.cc', 'terminal.cc',
'thread-pool.cc', 'thread-pool.cc',
'unix-domain-socket.cc',
'url.cc', 'url.cc',
'url-name.cc', 'url-name.cc',
'util.cc', 'util.cc',
@ -102,6 +103,7 @@ libutil_headers = files(
'thread-pool.hh', 'thread-pool.hh',
'topo-sort.hh', 'topo-sort.hh',
'types.hh', 'types.hh',
'unix-domain-socket.hh',
'url-parts.hh', 'url-parts.hh',
'url-name.hh', 'url-name.hh',
'url.hh', 'url.hh',

View file

@ -0,0 +1,105 @@
#include "file-system.hh"
#include "processes.hh"
#include "unix-domain-socket.hh"
#include "util.hh"
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>
namespace nix {
AutoCloseFD createUnixDomainSocket()
{
AutoCloseFD fdSocket{socket(PF_UNIX, SOCK_STREAM
#ifdef SOCK_CLOEXEC
| SOCK_CLOEXEC
#endif
, 0)};
if (!fdSocket)
throw SysError("cannot create Unix domain socket");
closeOnExec(fdSocket.get());
return fdSocket;
}
AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
{
auto fdSocket = nix::createUnixDomainSocket();
bind(fdSocket.get(), path);
chmodPath(path.c_str(), mode);
if (listen(fdSocket.get(), 100) == -1)
throw SysError("cannot listen on socket '%1%'", path);
return fdSocket;
}
static void bindConnectProcHelper(
std::string_view operationName, auto && operation,
int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
// Casting between types like these legacy C library interfaces
// require is forbidden in C++. To maintain backwards
// compatibility, the implementation of the bind/connect functions
// contains some hints to the compiler that allow for this
// special case.
auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr);
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pipe pipe;
pipe.create();
Pid pid = startProcess([&] {
try {
pipe.readSide.close();
Path dir = dirOf(path);
if (chdir(dir.c_str()) == -1)
throw SysError("chdir to '%s' failed", dir);
std::string base(baseNameOf(path));
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
writeFull(pipe.writeSide.get(), "0\n");
} catch (SysError & e) {
writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo));
} catch (...) {
writeFull(pipe.writeSide.get(), "-1\n");
}
});
pipe.writeSide.close();
auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get())));
if (!errNo || *errNo == -1)
throw Error("cannot %s to socket at '%s'", operationName, path);
else if (*errNo > 0) {
errno = *errNo;
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
}
void bind(int fd, const std::string & path)
{
unlink(path.c_str());
bindConnectProcHelper("bind", ::bind, fd, path);
}
void connect(int fd, const std::string & path)
{
bindConnectProcHelper("connect", ::connect, fd, path);
}
}

View file

@ -0,0 +1,31 @@
#pragma once
///@file
#include "file-descriptor.hh"
#include "types.hh"
#include <unistd.h>
namespace nix {
/**
* Create a Unix domain socket.
*/
AutoCloseFD createUnixDomainSocket();
/**
* Create a Unix domain socket in listen mode.
*/
AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode);
/**
* Bind a Unix domain socket to a path.
*/
void bind(int fd, const std::string & path);
/**
* Connect to a Unix domain socket.
*/
void connect(int fd, const std::string & path);
}

View file

@ -459,100 +459,6 @@ void unshareFilesystem()
#endif #endif
} }
AutoCloseFD createUnixDomainSocket()
{
AutoCloseFD fdSocket{socket(PF_UNIX, SOCK_STREAM
#ifdef SOCK_CLOEXEC
| SOCK_CLOEXEC
#endif
, 0)};
if (!fdSocket)
throw SysError("cannot create Unix domain socket");
closeOnExec(fdSocket.get());
return fdSocket;
}
AutoCloseFD createUnixDomainSocket(const Path & path, mode_t mode)
{
auto fdSocket = nix::createUnixDomainSocket();
bind(fdSocket.get(), path);
chmodPath(path.c_str(), mode);
if (listen(fdSocket.get(), 100) == -1)
throw SysError("cannot listen on socket '%1%'", path);
return fdSocket;
}
static void bindConnectProcHelper(
std::string_view operationName, auto && operation,
int fd, const std::string & path)
{
struct sockaddr_un addr;
addr.sun_family = AF_UNIX;
// Casting between types like these legacy C library interfaces
// require is forbidden in C++. To maintain backwards
// compatibility, the implementation of the bind/connect functions
// contains some hints to the compiler that allow for this
// special case.
auto * psaddr = reinterpret_cast<struct sockaddr *>(&addr);
if (path.size() + 1 >= sizeof(addr.sun_path)) {
Pipe pipe;
pipe.create();
Pid pid = startProcess([&] {
try {
pipe.readSide.close();
Path dir = dirOf(path);
if (chdir(dir.c_str()) == -1)
throw SysError("chdir to '%s' failed", dir);
std::string base(baseNameOf(path));
if (base.size() + 1 >= sizeof(addr.sun_path))
throw Error("socket path '%s' is too long", base);
memcpy(addr.sun_path, base.c_str(), base.size() + 1);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
writeFull(pipe.writeSide.get(), "0\n");
} catch (SysError & e) {
writeFull(pipe.writeSide.get(), fmt("%d\n", e.errNo));
} catch (...) {
writeFull(pipe.writeSide.get(), "-1\n");
}
});
pipe.writeSide.close();
auto errNo = string2Int<int>(chomp(drainFD(pipe.readSide.get())));
if (!errNo || *errNo == -1)
throw Error("cannot %s to socket at '%s'", operationName, path);
else if (*errNo > 0) {
errno = *errNo;
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
} else {
memcpy(addr.sun_path, path.c_str(), path.size() + 1);
if (operation(fd, psaddr, sizeof(addr)) == -1)
throw SysError("cannot %s to socket at '%s'", operationName, path);
}
}
void bind(int fd, const std::string & path)
{
unlink(path.c_str());
bindConnectProcHelper("bind", ::bind, fd, path);
}
void connect(int fd, const std::string & path)
{
bindConnectProcHelper("connect", ::connect, fd, path);
}
std::string showBytes(uint64_t bytes) std::string showBytes(uint64_t bytes)
{ {

View file

@ -406,17 +406,6 @@ struct MaintainCount
}; };
/**
* Bind a Unix domain socket to a path.
*/
void bind(int fd, const std::string & path);
/**
* Connect to a Unix domain socket.
*/
void connect(int fd, const std::string & path);
/** /**
* A Rust/Python-like enumerate() iterator adapter. * A Rust/Python-like enumerate() iterator adapter.
* *

View file

@ -14,6 +14,7 @@
#include "legacy.hh" #include "legacy.hh"
#include "signals.hh" #include "signals.hh"
#include "daemon.hh" #include "daemon.hh"
#include "unix-domain-socket.hh"
#include <algorithm> #include <algorithm>
#include <climits> #include <climits>