Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions llvm/lib/Support/raw_socket_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,16 @@ static std::error_code getLastSocketErrorCode() {
#endif
}

static sockaddr_un setSocketAddr(StringRef SocketPath) {
static Expected<sockaddr_un> setSocketAddr(StringRef SocketPath) {
struct sockaddr_un Addr;
memset(&Addr, 0, sizeof(Addr));
Addr.sun_family = AF_UNIX;

if (sizeof(sockaddr_un::sun_path) <= SocketPath.size())
return make_error<StringError>(
std::make_error_code(std::errc::filename_too_long),
"Socket path exceeds sockaddr_un::sun_path size limit");

strncpy(Addr.sun_path, SocketPath.str().c_str(), sizeof(Addr.sun_path) - 1);
return Addr;
}
Expand All @@ -90,8 +96,11 @@ static Expected<int> getSocketFD(StringRef SocketPath) {
// off the handshake (and SO_PEERCRED/getpeereid support).
setsockopt(Socket, SOL_SOCKET, SO_PEERCRED, NULL, 0);
#endif
struct sockaddr_un Addr = setSocketAddr(SocketPath);
if (::connect(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
Expected<struct sockaddr_un> Addr = setSocketAddr(SocketPath);
if (!Addr)
return Addr.takeError();

if (::connect(Socket, (struct sockaddr *)&*Addr, sizeof(*Addr)) == -1) {
::close(Socket);
return llvm::make_error<StringError>(getLastSocketErrorCode(),
"Connect socket failed");
Expand Down Expand Up @@ -167,8 +176,11 @@ Expected<ListeningSocket> ListeningSocket::createUnix(StringRef SocketPath,
// off the handshake (and SO_PEERCRED/getpeereid support).
setsockopt(Socket, SOL_SOCKET, SO_PEERCRED, NULL, 0);
#endif
struct sockaddr_un Addr = setSocketAddr(SocketPath);
if (::bind(Socket, (struct sockaddr *)&Addr, sizeof(Addr)) == -1) {
Expected<struct sockaddr_un> Addr = setSocketAddr(SocketPath);
if (!Addr)
return Addr.takeError();

if (::bind(Socket, (struct sockaddr *)&*Addr, sizeof(*Addr)) == -1) {
// Grab error code from call to ::bind before calling ::close
std::error_code EC = getLastSocketErrorCode();
::close(Socket);
Expand Down
103 changes: 35 additions & 68 deletions llvm/unittests/Support/raw_socket_stream_test.cpp
Original file line number Diff line number Diff line change
@@ -1,13 +1,8 @@
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/SmallString.h"
#include "llvm/Config/llvm-config.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/Support/FileUtilities.h"
#include "llvm/Support/raw_socket_stream.h"
#include "llvm/Testing/Support/Error.h"
#include "gtest/gtest.h"
#include <future>
#include <stdlib.h>
#include <thread>

Expand All @@ -28,29 +23,40 @@ bool hasUnixSocketSupport() {
return true;
}

TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

struct raw_socket_streamTest : ::testing::Test {
SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("client_server_comms-%%%%%%.sock", SocketPath,
true);
auto Cleanup = llvm::scope_exit([&] { std::remove(SocketPath.c_str()); });

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());

ListeningSocket ServerListener = std::move(*MaybeServerListener);

std::optional<ListeningSocket> ServerListener;

void SetUp() override {
if (!hasUnixSocketSupport())
GTEST_SKIP();

llvm::sys::fs::createUniquePath("llvm-%%%%%%%%.sock", SocketPath, true);
Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
if (!MaybeServerListener) {
std::error_code EC = errorToErrorCode(MaybeServerListener.takeError());
if (EC == std::errc::filename_too_long)
GTEST_SKIP() << EC.message() << ": " << SocketPath;
FAIL() << EC.message();
return;
}

ServerListener.emplace(std::move(*MaybeServerListener));
}

void TearDown() override { std::remove(SocketPath.c_str()); }
};

TEST_F(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
raw_socket_stream::createConnectedUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());

raw_socket_stream &Client = **MaybeClient;

Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept();
ServerListener->accept();
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());

raw_socket_stream &Server = **MaybeServer;
Expand All @@ -61,33 +67,20 @@ TEST(raw_socket_streamTest, CLIENT_TO_SERVER_AND_SERVER_TO_CLIENT) {
char Bytes[8];
ssize_t BytesRead = Server.read(Bytes, 8);

std::string string(Bytes, 8);
std::string Str(Bytes, 8);
ASSERT_EQ(Server.has_error(), false);

ASSERT_EQ(8, BytesRead);
ASSERT_EQ("01234567", string);
ASSERT_EQ("01234567", Str);
}

TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("read_with_timeout-%%%%%%.sock", SocketPath,
true);
auto Cleanup = llvm::scope_exit([&] { std::remove(SocketPath.c_str()); });

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);

TEST_F(raw_socket_streamTest, READ_WITH_TIMEOUT) {
Expected<std::unique_ptr<raw_socket_stream>> MaybeClient =
raw_socket_stream::createConnectedUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeClient, llvm::Succeeded());

Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept();
ServerListener->accept();
ASSERT_THAT_EXPECTED(MaybeServer, llvm::Succeeded());
raw_socket_stream &Server = **MaybeServer;

Expand All @@ -99,49 +92,23 @@ TEST(raw_socket_streamTest, READ_WITH_TIMEOUT) {
Server.clear_error();
}

TEST(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("accept_with_timeout-%%%%%%.sock", SocketPath,
true);
auto Cleanup = llvm::scope_exit([&] { std::remove(SocketPath.c_str()); });

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);

TEST_F(raw_socket_streamTest, ACCEPT_WITH_TIMEOUT) {
Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept(std::chrono::milliseconds(100));
ServerListener->accept(std::chrono::milliseconds(100));
ASSERT_EQ(llvm::errorToErrorCode(MaybeServer.takeError()),
std::errc::timed_out);
}

TEST(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
if (!hasUnixSocketSupport())
GTEST_SKIP();

SmallString<100> SocketPath;
llvm::sys::fs::createUniquePath("accept_with_shutdown-%%%%%%.sock",
SocketPath, true);
auto Cleanup = llvm::scope_exit([&] { std::remove(SocketPath.c_str()); });

Expected<ListeningSocket> MaybeServerListener =
ListeningSocket::createUnix(SocketPath);
ASSERT_THAT_EXPECTED(MaybeServerListener, llvm::Succeeded());
ListeningSocket ServerListener = std::move(*MaybeServerListener);

TEST_F(raw_socket_streamTest, ACCEPT_WITH_SHUTDOWN) {
// Create a separate thread to close the socket after a delay. Simulates a
// signal handler calling ServerListener::shutdown
std::thread CloseThread([&]() {
std::this_thread::sleep_for(std::chrono::milliseconds(500));
ServerListener.shutdown();
ServerListener->shutdown();
});

Expected<std::unique_ptr<raw_socket_stream>> MaybeServer =
ServerListener.accept();
ServerListener->accept();

// Wait for the CloseThread to finish
CloseThread.join();
Expand Down