diff --git a/CMakeLists.txt b/CMakeLists.txt index 54057727..d80e87f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -11,6 +11,8 @@ if(NOT ENABLE_WEBSOCKET) return() endif() +option(ENABLE_WEBSOCKET_CLIENT_TLS "Enable TLS (wss://) for obs-websocket client mode" ON) + # Find Qt find_package(Qt6 REQUIRED Core Widgets Svg Network) @@ -28,6 +30,11 @@ find_package(Websocketpp 0.8 REQUIRED) # Find Asio find_package(Asio 1.12.1 REQUIRED) +if(ENABLE_WEBSOCKET_CLIENT_TLS) + # Find OpenSSL (required for TLS client mode) + find_package(OpenSSL REQUIRED) +endif() + add_library(obs-websocket MODULE) add_library(OBS::websocket ALIAS obs-websocket) @@ -49,12 +56,15 @@ target_sources( target_sources( obs-websocket PRIVATE # cmake-format: sortable + src/websocketclient/WebSocketClient.cpp + src/websocketclient/WebSocketClient.h src/websocketserver/rpc/WebSocketSession.h src/websocketserver/types/WebSocketCloseCode.h src/websocketserver/types/WebSocketOpCode.h + src/websocketserver/WebSocketProtocol.cpp + src/websocketserver/WebSocketProtocol.h src/websocketserver/WebSocketServer.cpp - src/websocketserver/WebSocketServer.h - src/websocketserver/WebSocketServer_Protocol.cpp) + src/websocketserver/WebSocketServer.h) target_sources( obs-websocket @@ -131,7 +141,8 @@ target_sources(obs-websocket PRIVATE plugin-macros.generated.h) target_compile_definitions( obs-websocket PRIVATE ASIO_STANDALONE $<$:PLUGIN_TESTS> - $<$:_WEBSOCKETPP_CPP11_STL_> $<$:_WIN32_WINNT=0x0603>) + $<$:_WEBSOCKETPP_CPP11_STL_> $<$:_WIN32_WINNT=0x0603> + OBS_WEBSOCKET_CLIENT_TLS=$) target_compile_options( obs-websocket @@ -165,6 +176,10 @@ target_link_libraries( Asio::Asio qrcodegencpp::qrcodegencpp) +if(ENABLE_WEBSOCKET_CLIENT_TLS) + target_link_libraries(obs-websocket PRIVATE OpenSSL::SSL OpenSSL::Crypto) +endif() + target_link_options(obs-websocket PRIVATE $<$:/IGNORE:4099>) set_target_properties_obs( diff --git a/README.md b/README.md index 2e5e8142..66d89c47 100644 --- a/README.md +++ b/README.md @@ -21,7 +21,7 @@ Binaries **for OBS Studio < 28.0.0** on Windows, MacOS, and Linux are available It is **highly recommended** to keep obs-websocket protected with a password against unauthorized control. obs-websocket generates a password for you automatically when you load it for the first time. To change this, open the "obs-websocket Settings" dialog under OBS' "Tools" menu. In the settings dialog, you can enable or disable authentication and set a password for it. -(Psst. You can use `--websocket_port`(value), `--websocket_password`(value), `--websocket_debug`(flag) and `--websocket_ipv4_only`(flag) on the command line to override the configured values.) +(Psst. You can use `--websocket_port`(value), `--websocket_password`(value), `--websocket_debug`(flag) and `--websocket_ipv4_only`(flag) on the command line to override configured **server** values. For outbound client mode, you can use `--websocket_client_enabled`(flag), `--websocket_client_url`(value), `--websocket_client_password`(value), `--websocket_client_allow_insecure`(flag), and `--websocket_client_allow_invalid_cert`(flag).) ### Possible use cases @@ -61,6 +61,14 @@ Here's a list of available language APIs for obs-websocket: The 5.x server is a typical WebSocket server running by default on port 4455 (the port number can be changed in the Settings dialog under `Tools`). The protocol we use is documented in [PROTOCOL.md](docs/generated/protocol.md). +### Client mode (outbound) + +obs-websocket can optionally initiate an outbound WebSocket connection to a remote controller while preserving the existing protocol semantics (OBS still sends `Hello`, the remote controller `Identifies`). Client mode is disabled by default and lives in the obs-websocket Settings dialog. Enter a full WebSocket connection URL using `ws://` or `wss://` (for example `wss://controller.example.com:4455`). + +Client connection URLs must not include credentials, a path, a query string, or a fragment. If a port is provided, it must be in the range `1-65534`. + +TLS (`wss://`) is recommended, while unencrypted `ws://` connections are gated behind an explicit unsafe toggle. TLS support requires OpenSSL and can be toggled at build time with `-DENABLE_WEBSOCKET_CLIENT_TLS=ON|OFF` (default ON). Server mode and outbound client mode can be enabled independently and run at the same time. + We'd like to know what you're building with obs-websocket! If you do something in this fashion, feel free to drop a message in `#project-showoff` in the [discord server!](https://discord.gg/WBaSQ3A) ## Contributors diff --git a/data/locale/en-US.ini b/data/locale/en-US.ini index ea566aa9..f1c6df9c 100644 --- a/data/locale/en-US.ini +++ b/data/locale/en-US.ini @@ -1,6 +1,6 @@ OBSWebSocket.Plugin.Description="Remote-control of OBS Studio through WebSocket" -OBSWebSocket.Settings.DialogTitle="WebSocket Server Settings" +OBSWebSocket.Settings.DialogTitle="WebSocket Settings" OBSWebSocket.Settings.PluginSettingsTitle="Plugin Settings" OBSWebSocket.Settings.ServerEnable="Enable WebSocket server" @@ -23,7 +23,49 @@ OBSWebSocket.Settings.Save.UserPasswordWarningInfoText="Are you sure you want to OBSWebSocket.Settings.Save.PasswordInvalidErrorTitle="Error: Invalid Configuration" OBSWebSocket.Settings.Save.PasswordInvalidErrorMessage="You must use a password that is 6 or more characters." +OBSWebSocket.Settings.ClientSettingsTitle="Client Settings (Outbound)" +OBSWebSocket.Settings.ClientEnable="Enable WebSocket client mode (outbound)" +OBSWebSocket.Settings.ClientHost="Client Connection" +OBSWebSocket.Settings.ClientHostPlaceholder="wss://controller.example.com:4455" +OBSWebSocket.Settings.ClientPort="Client Port" +OBSWebSocket.Settings.ClientUseTls="Use TLS (wss://)" +OBSWebSocket.Settings.ClientAllowInsecure="Allow unencrypted ws:// connections (unsafe)" +OBSWebSocket.Settings.ClientAllowInvalidCert="Allow invalid TLS certificates (unsafe)" +OBSWebSocket.Settings.ClientAuthRequired="Require Authentication" +OBSWebSocket.Settings.ClientPassword="Client Password" +OBSWebSocket.Settings.Show="Show" +OBSWebSocket.Settings.Hide="Hide" +OBSWebSocket.Settings.ClientGeneratePassword="Generate Password" +OBSWebSocket.Settings.ClientStatus="Client Status" +OBSWebSocket.Settings.ClientStatusUnavailable="Unavailable" +OBSWebSocket.Settings.ClientStatusDisabled="Disabled" +OBSWebSocket.Settings.ClientStatusConnecting="Connecting" +OBSWebSocket.Settings.ClientStatusConnected="Connected" +OBSWebSocket.Settings.ClientStatusDisconnected="Disconnected" +OBSWebSocket.Settings.ClientStatusError="Error" +OBSWebSocket.Settings.ClientWarningText="Client mode initiates an outbound connection to a remote controller. Only enable for trusted endpoints." +OBSWebSocket.Settings.ClientHostInvalidTitle="Error: Invalid Client Connection" +OBSWebSocket.Settings.ClientHostInvalidMessage="Client connection is required when client mode is enabled." +OBSWebSocket.Settings.ClientConnectionInvalidMessage="Client connection must be a valid URL (for example: wss://controller.example.com:4455)." +OBSWebSocket.Settings.ClientConnectionInvalidSchemeMessage="Client connection must use ws:// or wss://." +OBSWebSocket.Settings.ClientConnectionCredentialsMessage="Client connection must not include username or password." +OBSWebSocket.Settings.ClientConnectionPathMessage="Client connection must not include a path, query, or fragment." +OBSWebSocket.Settings.ClientConnectionPortMessage="Client connection port must be between 1 and 65534." +OBSWebSocket.Settings.ClientInsecureBlockedTitle="Error: Insecure Client Mode" +OBSWebSocket.Settings.ClientInsecureBlockedMessage="Unencrypted ws:// is disabled. Enable the unsafe toggle to allow ws://." +OBSWebSocket.Settings.ClientTlsDisabledMessage="TLS is disabled in this build. Rebuild with ENABLE_WEBSOCKET_CLIENT_TLS=ON to allow wss://." +OBSWebSocket.Settings.Save.ClientPasswordWarningTitle="Warning: Potential Security Issue" +OBSWebSocket.Settings.Save.ClientPasswordWarningMessage="obs-websocket stores the client password as plain text. Using a password generated by obs-websocket is highly recommended." +OBSWebSocket.Settings.Save.ClientPasswordWarningInfoText="Are you sure you want to use your own password?" +OBSWebSocket.Settings.ClientEnableWarningTitle="Warning: Client Mode" +OBSWebSocket.Settings.ClientEnableWarningMessage="Client mode allows remote control of OBS by connecting to a remote server." +OBSWebSocket.Settings.ClientEnableWarningInfoText="Only enable this for endpoints you trust." +OBSWebSocket.Settings.ClientInsecureWarningTitle="Warning: Insecure Client Settings" +OBSWebSocket.Settings.ClientInsecureWarningMessage="You are allowing insecure client settings (unencrypted or invalid certificates)." +OBSWebSocket.Settings.ClientInsecureWarningInfoText="Only proceed if you understand the risks." + OBSWebSocket.SessionTable.Title="Connected WebSocket Sessions" +OBSWebSocket.SessionTable.ServerOnlyNote="Shows inbound connections to this OBS instance only. Outbound client mode appears above." OBSWebSocket.SessionTable.RemoteAddressColumnTitle="Remote Address" OBSWebSocket.SessionTable.SessionDurationColumnTitle="Session Duration" OBSWebSocket.SessionTable.MessagesInOutColumnTitle="Messages In/Out" diff --git a/src/Config.cpp b/src/Config.cpp index 34b3a18a..41ba21d9 100644 --- a/src/Config.cpp +++ b/src/Config.cpp @@ -19,6 +19,7 @@ with this program. If not, see #include +#include #include #include "Config.h" @@ -41,11 +42,89 @@ with this program. If not, see #define PARAM_ALERTS "alerts_enabled" #define PARAM_AUTHREQUIRED "auth_required" #define PARAM_PASSWORD "server_password" +#define PARAM_CLIENT_ENABLED "client_enabled" +#define PARAM_CLIENT_HOST "client_host" +#define PARAM_CLIENT_PORT "client_port" +#define PARAM_CLIENT_USE_TLS "client_use_tls" +#define PARAM_CLIENT_ALLOW_INSECURE "client_allow_insecure" +#define PARAM_CLIENT_ALLOW_INVALID_CERT "client_allow_invalid_cert" +#define PARAM_CLIENT_AUTH_REQUIRED "client_auth_required" +#define PARAM_CLIENT_PASSWORD "client_password" #define CMDLINE_WEBSOCKET_PORT "websocket_port" #define CMDLINE_WEBSOCKET_IPV4_ONLY "websocket_ipv4_only" #define CMDLINE_WEBSOCKET_PASSWORD "websocket_password" #define CMDLINE_WEBSOCKET_DEBUG "websocket_debug" +#define CMDLINE_WEBSOCKET_CLIENT_ENABLED "websocket_client_enabled" +#define CMDLINE_WEBSOCKET_CLIENT_URL "websocket_client_url" +#define CMDLINE_WEBSOCKET_CLIENT_PASSWORD "websocket_client_password" +#define CMDLINE_WEBSOCKET_CLIENT_ALLOW_INSECURE "websocket_client_allow_insecure" +#define CMDLINE_WEBSOCKET_CLIENT_ALLOW_INVALID_CERT "websocket_client_allow_invalid_cert" + +namespace { + struct ParsedClientEndpoint { + bool valid = false; + bool useTls = false; + std::string host; + bool hasPort = false; + uint16_t port = 0; + std::string error; + }; + + ParsedClientEndpoint ParseClientEndpoint(QString endpointInput) + { + ParsedClientEndpoint parsed; + QString trimmed = endpointInput.trimmed(); + if (trimmed.isEmpty()) { + parsed.error = "value is empty"; + return parsed; + } + + QUrl endpoint(trimmed, QUrl::StrictMode); + if (!endpoint.isValid()) { + parsed.error = "value is not a valid URL"; + return parsed; + } + + QString scheme = endpoint.scheme().toLower(); + if (scheme != "ws" && scheme != "wss") { + parsed.error = "URL scheme must be ws:// or wss://"; + return parsed; + } + + if (endpoint.host().isEmpty()) { + parsed.error = "URL host is empty"; + return parsed; + } + + if (!endpoint.userName().isEmpty() || !endpoint.password().isEmpty()) { + parsed.error = "URL must not include credentials"; + return parsed; + } + + QString path = endpoint.path(); + if ((path.size() > 1) || endpoint.hasQuery() || endpoint.hasFragment()) { + parsed.error = "URL must not include a path, query, or fragment"; + return parsed; + } + + int parsedPort = endpoint.port(-1); + if (parsedPort != -1) { + if (parsedPort < 1 || parsedPort > 65534) { + parsed.error = "URL port must be between 1 and 65534"; + return parsed; + } + + parsed.hasPort = true; + parsed.port = static_cast(parsedPort); + } + + parsed.valid = true; + parsed.useTls = (scheme == "wss"); + parsed.host = endpoint.host().toStdString(); + return parsed; + } +} void Config::Load(json config) { @@ -72,6 +151,22 @@ void Config::Load(json config) AuthRequired = config[PARAM_AUTHREQUIRED]; if (config.contains(PARAM_PASSWORD) && config[PARAM_PASSWORD].is_string()) ServerPassword = config[PARAM_PASSWORD]; + if (config.contains(PARAM_CLIENT_ENABLED) && config[PARAM_CLIENT_ENABLED].is_boolean()) + ClientEnabled = config[PARAM_CLIENT_ENABLED]; + if (config.contains(PARAM_CLIENT_HOST) && config[PARAM_CLIENT_HOST].is_string()) + ClientHost = config[PARAM_CLIENT_HOST]; + if (config.contains(PARAM_CLIENT_PORT) && config[PARAM_CLIENT_PORT].is_number_unsigned()) + ClientPort = config[PARAM_CLIENT_PORT]; + if (config.contains(PARAM_CLIENT_USE_TLS) && config[PARAM_CLIENT_USE_TLS].is_boolean()) + ClientUseTls = config[PARAM_CLIENT_USE_TLS]; + if (config.contains(PARAM_CLIENT_ALLOW_INSECURE) && config[PARAM_CLIENT_ALLOW_INSECURE].is_boolean()) + ClientAllowInsecure = config[PARAM_CLIENT_ALLOW_INSECURE]; + if (config.contains(PARAM_CLIENT_ALLOW_INVALID_CERT) && config[PARAM_CLIENT_ALLOW_INVALID_CERT].is_boolean()) + ClientAllowInvalidCert = config[PARAM_CLIENT_ALLOW_INVALID_CERT]; + if (config.contains(PARAM_CLIENT_AUTH_REQUIRED) && config[PARAM_CLIENT_AUTH_REQUIRED].is_boolean()) + ClientAuthRequired = config[PARAM_CLIENT_AUTH_REQUIRED]; + if (config.contains(PARAM_CLIENT_PASSWORD) && config[PARAM_CLIENT_PASSWORD].is_string()) + ClientPassword = config[PARAM_CLIENT_PASSWORD]; // Set server password and save it to the config before processing overrides, // so that there is always a true configured password regardless of if @@ -87,6 +182,11 @@ void Config::Load(json config) Save(); } + if (ClientHost.empty()) + ClientHost = "127.0.0.1"; + if (ClientPassword.empty()) + ClientPassword = ServerPassword; + // If there are migrated settings, write them to disk before processing arguments. if (!config.empty()) Save(); @@ -126,6 +226,49 @@ void Config::Load(json config) blog(LOG_INFO, "[Config::Load] --websocket_debug passed. Enabling debug logging."); DebugEnabled = true; } + + // Process `--websocket_client_enabled` override + if (Utils::Platform::GetCommandLineFlagSet(CMDLINE_WEBSOCKET_CLIENT_ENABLED)) { + blog(LOG_INFO, "[Config::Load] --websocket_client_enabled passed. Enabling WebSocket client mode."); + ClientEnabled = true; + } + + // Process `--websocket_client_url` override + QString clientUrlArgument = Utils::Platform::GetCommandLineArgument(CMDLINE_WEBSOCKET_CLIENT_URL); + if (clientUrlArgument != "") { + auto parsedEndpoint = ParseClientEndpoint(clientUrlArgument); + if (parsedEndpoint.valid) { + blog(LOG_INFO, "[Config::Load] --websocket_client_url passed. Overriding client endpoint."); + ClientHost = parsedEndpoint.host; + ClientUseTls = parsedEndpoint.useTls; + if (parsedEndpoint.hasPort) + ClientPort = parsedEndpoint.port; + } else { + blog(LOG_WARNING, "[Config::Load] Ignoring --websocket_client_url override: %s", + parsedEndpoint.error.c_str()); + } + } + + // Process `--websocket_client_password` override + QString clientPasswordArgument = Utils::Platform::GetCommandLineArgument(CMDLINE_WEBSOCKET_CLIENT_PASSWORD); + if (clientPasswordArgument != "") { + blog(LOG_INFO, "[Config::Load] --websocket_client_password passed. Overriding WebSocket client password."); + ClientAuthRequired = true; + ClientPassword = clientPasswordArgument.toStdString(); + } + + // Process `--websocket_client_allow_insecure` override + if (Utils::Platform::GetCommandLineFlagSet(CMDLINE_WEBSOCKET_CLIENT_ALLOW_INSECURE)) { + blog(LOG_INFO, "[Config::Load] --websocket_client_allow_insecure passed. Enabling insecure ws:// client mode."); + ClientAllowInsecure = true; + } + + // Process `--websocket_client_allow_invalid_cert` override + if (Utils::Platform::GetCommandLineFlagSet(CMDLINE_WEBSOCKET_CLIENT_ALLOW_INVALID_CERT)) { + blog(LOG_INFO, + "[Config::Load] --websocket_client_allow_invalid_cert passed. Allowing invalid TLS certs for client mode."); + ClientAllowInvalidCert = true; + } } void Config::Save() @@ -144,6 +287,14 @@ void Config::Save() config[PARAM_AUTHREQUIRED] = AuthRequired.load(); config[PARAM_PASSWORD] = ServerPassword; } + config[PARAM_CLIENT_ENABLED] = ClientEnabled.load(); + config[PARAM_CLIENT_HOST] = ClientHost; + config[PARAM_CLIENT_PORT] = ClientPort.load(); + config[PARAM_CLIENT_USE_TLS] = ClientUseTls.load(); + config[PARAM_CLIENT_ALLOW_INSECURE] = ClientAllowInsecure.load(); + config[PARAM_CLIENT_ALLOW_INVALID_CERT] = ClientAllowInvalidCert.load(); + config[PARAM_CLIENT_AUTH_REQUIRED] = ClientAuthRequired.load(); + config[PARAM_CLIENT_PASSWORD] = ClientPassword; if (Utils::Json::SetJsonFileContent(configFilePath, config)) blog(LOG_DEBUG, "[Config::Save] Saved config."); diff --git a/src/Config.h b/src/Config.h index f0b52bd7..e43c89ef 100644 --- a/src/Config.h +++ b/src/Config.h @@ -41,6 +41,15 @@ struct Config { std::atomic AlertsEnabled = false; std::atomic AuthRequired = true; std::string ServerPassword; + + std::atomic ClientEnabled = false; + std::string ClientHost = "127.0.0.1"; + std::atomic ClientPort = 4455; + std::atomic ClientUseTls = true; + std::atomic ClientAllowInsecure = false; + std::atomic ClientAllowInvalidCert = false; + std::atomic ClientAuthRequired = true; + std::string ClientPassword; }; json MigrateGlobalConfigData(); diff --git a/src/forms/SettingsDialog.cpp b/src/forms/SettingsDialog.cpp index eaf2027e..d9b2b63d 100644 --- a/src/forms/SettingsDialog.cpp +++ b/src/forms/SettingsDialog.cpp @@ -19,7 +19,11 @@ with this program. If not, see #include #include +#include #include +#include + +#include #include #include @@ -27,6 +31,7 @@ with this program. If not, see #include "../obs-websocket.h" #include "../Config.h" #include "../websocketserver/WebSocketServer.h" +#include "../websocketclient/WebSocketClient.h" #include "../utils/Crypto.h" QString GetToolTipIconHtml() @@ -37,35 +42,135 @@ QString GetToolTipIconHtml() return iconTemplate.arg(iconFile); } +struct ParsedClientConnection { + bool valid = false; + bool useTls = false; + std::string host; + std::optional port; + const char *errorTitleKey = "OBSWebSocket.Settings.ClientHostInvalidTitle"; + const char *errorMessageKey = "OBSWebSocket.Settings.ClientHostInvalidMessage"; +}; + +QString BuildClientConnectionInput(const std::string &host, uint16_t port, bool useTls) +{ + if (host.empty()) + return {}; + + QString hostValue = QString::fromStdString(host); + if (hostValue.contains(':') && !hostValue.startsWith('[') && !hostValue.endsWith(']')) + hostValue = QString("[%1]").arg(hostValue); + + return QString("%1://%2:%3").arg(useTls ? "wss" : "ws", hostValue).arg(port); +} + +ParsedClientConnection ParseClientConnectionInput(const QString &connectionInput) +{ + ParsedClientConnection parsed; + QString trimmed = connectionInput.trimmed(); + if (trimmed.isEmpty()) + return parsed; + + QUrl endpoint(trimmed, QUrl::StrictMode); + if (!endpoint.isValid()) { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientConnectionInvalidMessage"; + return parsed; + } + + QString scheme = endpoint.scheme().toLower(); + if (scheme != "ws" && scheme != "wss") { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientConnectionInvalidSchemeMessage"; + return parsed; + } + + if (endpoint.host().isEmpty()) { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientHostInvalidMessage"; + return parsed; + } + + if (!endpoint.userName().isEmpty() || !endpoint.password().isEmpty()) { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientConnectionCredentialsMessage"; + return parsed; + } + + QString path = endpoint.path(); + if ((path.size() > 1) || endpoint.hasQuery() || endpoint.hasFragment()) { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientConnectionPathMessage"; + return parsed; + } + + int parsedPort = endpoint.port(-1); + if (parsedPort != -1) { + if (parsedPort < 1 || parsedPort > 65534) { + parsed.errorMessageKey = "OBSWebSocket.Settings.ClientConnectionPortMessage"; + return parsed; + } + parsed.port = static_cast(parsedPort); + } + + parsed.valid = true; + parsed.useTls = (scheme == "wss"); + parsed.host = endpoint.host().toStdString(); + return parsed; +} + +void SetClientPasswordVisible(Ui::SettingsDialog *ui, bool visible) +{ + ui->clientPasswordLineEdit->setEchoMode(visible ? QLineEdit::Normal : QLineEdit::Password); + ui->showClientPasswordButton->setText( + obs_module_text(visible ? "OBSWebSocket.Settings.Hide" : "OBSWebSocket.Settings.Show")); +} + SettingsDialog::SettingsDialog(QWidget *parent) : QDialog(parent, Qt::Dialog), ui(new Ui::SettingsDialog), connectInfo(new ConnectInfo), sessionTableTimer(new QTimer), - passwordManuallyEdited(false) + passwordManuallyEdited(false), + clientPasswordManuallyEdited(false) { ui->setupUi(this); ui->websocketSessionTable->horizontalHeader()->resizeSection(3, 100); // Resize Session Table column widths ui->websocketSessionTable->horizontalHeader()->resizeSection(4, 100); + // Ensure the window cannot be resized below its full content height. + const int contentMinHeight = sizeHint().height(); + if (minimumHeight() < contentMinHeight) + setMinimumHeight(contentMinHeight); + if (height() < contentMinHeight) + resize(width(), contentMinHeight); + // Remove the ? button on dialogs on Windows setWindowFlags(windowFlags() & ~Qt::WindowContextHelpButtonHint); // Set the appropriate tooltip icon for the theme + ui->clientWarningToolTipLabel->setText(GetToolTipIconHtml()); ui->enableDebugLoggingToolTipLabel->setText(GetToolTipIconHtml()); connect(sessionTableTimer, &QTimer::timeout, this, &SettingsDialog::FillSessionTable); connect(ui->buttonBox, &QDialogButtonBox::clicked, this, &SettingsDialog::DialogButtonClicked); #if QT_VERSION >= QT_VERSION_CHECK(6, 7, 0) + connect(ui->enableWebSocketServerCheckBox, &QCheckBox::checkStateChanged, this, &SettingsDialog::UpdateServerUiState); connect(ui->enableAuthenticationCheckBox, &QCheckBox::checkStateChanged, this, &SettingsDialog::EnableAuthenticationCheckBoxChanged); + connect(ui->enableWebSocketClientCheckBox, &QCheckBox::checkStateChanged, this, &SettingsDialog::UpdateClientUiState); + connect(ui->clientAuthRequiredCheckBox, &QCheckBox::checkStateChanged, this, &SettingsDialog::UpdateClientUiState); + connect(ui->clientAllowInsecureCheckBox, &QCheckBox::checkStateChanged, this, &SettingsDialog::UpdateClientUiState); #else + connect(ui->enableWebSocketServerCheckBox, &QCheckBox::stateChanged, this, &SettingsDialog::UpdateServerUiState); connect(ui->enableAuthenticationCheckBox, &QCheckBox::stateChanged, this, &SettingsDialog::EnableAuthenticationCheckBoxChanged); + connect(ui->enableWebSocketClientCheckBox, &QCheckBox::stateChanged, this, &SettingsDialog::UpdateClientUiState); + connect(ui->clientAuthRequiredCheckBox, &QCheckBox::stateChanged, this, &SettingsDialog::UpdateClientUiState); + connect(ui->clientAllowInsecureCheckBox, &QCheckBox::stateChanged, this, &SettingsDialog::UpdateClientUiState); #endif + connect(ui->clientHostLineEdit, &QLineEdit::textChanged, this, &SettingsDialog::UpdateClientUiState); connect(ui->generatePasswordButton, &QPushButton::clicked, this, &SettingsDialog::GeneratePasswordButtonClicked); + connect(ui->generateClientPasswordButton, &QPushButton::clicked, this, + &SettingsDialog::GenerateClientPasswordButtonClicked); + connect(ui->showClientPasswordButton, &QPushButton::clicked, this, &SettingsDialog::ToggleClientPasswordVisibility); connect(ui->showConnectInfoButton, &QPushButton::clicked, this, &SettingsDialog::ShowConnectInfoButtonClicked); connect(ui->serverPasswordLineEdit, &QLineEdit::textEdited, this, &SettingsDialog::PasswordEdited); + connect(ui->clientPasswordLineEdit, &QLineEdit::textEdited, this, &SettingsDialog::ClientPasswordEdited); } SettingsDialog::~SettingsDialog() @@ -93,6 +198,7 @@ void SettingsDialog::showEvent(QShowEvent *) } passwordManuallyEdited = false; + clientPasswordManuallyEdited = false; RefreshData(); @@ -133,6 +239,19 @@ void SettingsDialog::RefreshData() ui->serverPasswordLineEdit->setEnabled(conf->AuthRequired); ui->generatePasswordButton->setEnabled(conf->AuthRequired); + ui->enableWebSocketClientCheckBox->setChecked(conf->ClientEnabled); + ui->clientHostLineEdit->setText( + BuildClientConnectionInput(conf->ClientHost, conf->ClientPort.load(), conf->ClientUseTls.load())); + ui->clientAllowInsecureCheckBox->setChecked(conf->ClientAllowInsecure); + ui->clientAllowInvalidCertCheckBox->setChecked(conf->ClientAllowInvalidCert); + ui->clientAuthRequiredCheckBox->setChecked(conf->ClientAuthRequired); + ui->clientPasswordLineEdit->setText(QString::fromStdString(conf->ClientPassword)); + SetClientPasswordVisible(ui, false); + + UpdateServerUiState(); + UpdateClientUiState(); + UpdateClientStatus(); + FillSessionTable(); } @@ -182,10 +301,116 @@ void SettingsDialog::SaveFormData() } } - bool needsRestart = (conf->ServerEnabled != ui->enableWebSocketServerCheckBox->isChecked()) || - (conf->ServerPort != ui->serverPortSpinBox->value()) || - (ui->enableAuthenticationCheckBox->isChecked() && - conf->ServerPassword != ui->serverPasswordLineEdit->text().toStdString()); + const bool clientEnabled = ui->enableWebSocketClientCheckBox->isChecked(); + const bool clientAuthRequired = ui->clientAuthRequiredCheckBox->isChecked(); + const ParsedClientConnection parsedClientConnection = ParseClientConnectionInput(ui->clientHostLineEdit->text()); + + if (clientEnabled && !parsedClientConnection.valid) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text(parsedClientConnection.errorTitleKey)); + msgBox.setText(obs_module_text(parsedClientConnection.errorMessageKey)); + msgBox.setStandardButtons(QMessageBox::Ok); + msgBox.exec(); + return; + } + + const bool clientUseTls = parsedClientConnection.valid ? parsedClientConnection.useTls : conf->ClientUseTls.load(); + const std::string clientHost = parsedClientConnection.valid ? parsedClientConnection.host : conf->ClientHost; + const uint16_t clientPort = parsedClientConnection.valid && parsedClientConnection.port.has_value() + ? *parsedClientConnection.port + : conf->ClientPort.load(); + +#if !OBS_WEBSOCKET_CLIENT_TLS + if (clientEnabled && clientUseTls) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.ClientHostInvalidTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.ClientTlsDisabledMessage")); + msgBox.setStandardButtons(QMessageBox::Ok); + msgBox.exec(); + return; + } +#endif + + if (clientEnabled && !clientUseTls && !ui->clientAllowInsecureCheckBox->isChecked()) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.ClientInsecureBlockedTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.ClientInsecureBlockedMessage")); + msgBox.setStandardButtons(QMessageBox::Ok); + msgBox.exec(); + return; + } + + if (clientEnabled && clientAuthRequired && ui->clientPasswordLineEdit->text().length() < 6) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.Save.PasswordInvalidErrorTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.Save.PasswordInvalidErrorMessage")); + msgBox.setStandardButtons(QMessageBox::Ok); + msgBox.exec(); + return; + } + + if (clientEnabled && clientAuthRequired && clientPasswordManuallyEdited && + (conf->ClientPassword != ui->clientPasswordLineEdit->text().toStdString())) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.Save.ClientPasswordWarningTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.Save.ClientPasswordWarningMessage")); + msgBox.setInformativeText(obs_module_text("OBSWebSocket.Settings.Save.ClientPasswordWarningInfoText")); + msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::No); + msgBox.setDefaultButton(QMessageBox::No); + int ret = msgBox.exec(); + + switch (ret) { + case QMessageBox::Yes: + break; + case QMessageBox::No: + default: + ui->clientPasswordLineEdit->setText(QString::fromStdString(conf->ClientPassword)); + return; + } + } + + if (!conf->ClientEnabled && clientEnabled) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.ClientEnableWarningTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.ClientEnableWarningMessage")); + msgBox.setInformativeText(obs_module_text("OBSWebSocket.Settings.ClientEnableWarningInfoText")); + msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::No); + msgBox.setDefaultButton(QMessageBox::No); + int ret = msgBox.exec(); + if (ret != QMessageBox::Yes) { + RefreshData(); + return; + } + } + + const bool insecureChanged = (ui->clientAllowInsecureCheckBox->isChecked() && !conf->ClientAllowInsecure) || + (!clientUseTls && conf->ClientUseTls) || + (ui->clientAllowInvalidCertCheckBox->isChecked() && !conf->ClientAllowInvalidCert); + if (clientEnabled && insecureChanged) { + QMessageBox msgBox; + msgBox.setWindowTitle(obs_module_text("OBSWebSocket.Settings.ClientInsecureWarningTitle")); + msgBox.setText(obs_module_text("OBSWebSocket.Settings.ClientInsecureWarningMessage")); + msgBox.setInformativeText(obs_module_text("OBSWebSocket.Settings.ClientInsecureWarningInfoText")); + msgBox.setStandardButtons(QMessageBox::Yes | QMessageBox::No); + msgBox.setDefaultButton(QMessageBox::No); + int ret = msgBox.exec(); + if (ret != QMessageBox::Yes) { + RefreshData(); + return; + } + } + + bool serverNeedsRestart = (conf->ServerEnabled != ui->enableWebSocketServerCheckBox->isChecked()) || + (conf->ServerPort != ui->serverPortSpinBox->value()) || + (ui->enableAuthenticationCheckBox->isChecked() && + conf->ServerPassword != ui->serverPasswordLineEdit->text().toStdString()); + + bool clientNeedsRestart = (conf->ClientEnabled != clientEnabled) || (conf->ClientHost != clientHost) || + (conf->ClientPort != clientPort) || (conf->ClientUseTls != clientUseTls) || + (conf->ClientAllowInsecure != ui->clientAllowInsecureCheckBox->isChecked()) || + (conf->ClientAllowInvalidCert != ui->clientAllowInvalidCertCheckBox->isChecked()) || + (conf->ClientAuthRequired != clientAuthRequired) || + (clientAuthRequired && conf->ClientPassword != ui->clientPasswordLineEdit->text().toStdString()); conf->ServerEnabled = ui->enableWebSocketServerCheckBox->isChecked(); conf->AlertsEnabled = ui->enableSystemTrayAlertsCheckBox->isChecked(); @@ -194,12 +419,21 @@ void SettingsDialog::SaveFormData() conf->AuthRequired = ui->enableAuthenticationCheckBox->isChecked(); conf->ServerPassword = ui->serverPasswordLineEdit->text().toStdString(); + conf->ClientEnabled = clientEnabled; + conf->ClientHost = clientHost; + conf->ClientPort = clientPort; + conf->ClientUseTls = clientUseTls; + conf->ClientAllowInsecure = ui->clientAllowInsecureCheckBox->isChecked(); + conf->ClientAllowInvalidCert = ui->clientAllowInvalidCertCheckBox->isChecked(); + conf->ClientAuthRequired = clientAuthRequired; + conf->ClientPassword = ui->clientPasswordLineEdit->text().toStdString(); + conf->Save(); RefreshData(); connectInfo->RefreshData(); - if (needsRestart) { + if (serverNeedsRestart) { blog(LOG_INFO, "[SettingsDialog::SaveFormData] A setting was changed which requires a server restart."); auto server = GetWebSocketServer(); server->Stop(); @@ -207,6 +441,17 @@ void SettingsDialog::SaveFormData() server->Start(); } } + + if (clientNeedsRestart) { + blog(LOG_INFO, "[SettingsDialog::SaveFormData] A setting was changed which requires a client reconnect."); + auto client = GetWebSocketClient(); + if (client) { + client->Stop(); + if (conf->ClientEnabled) { + client->Start(); + } + } + } } void SettingsDialog::FillSessionTable() @@ -214,6 +459,7 @@ void SettingsDialog::FillSessionTable() auto webSocketServer = GetWebSocketServer(); if (!webSocketServer) { blog(LOG_ERROR, "[SettingsDialog::FillSessionTable] Unable to fetch websocket server instance!"); + UpdateClientStatus(); return; } @@ -264,10 +510,18 @@ void SettingsDialog::FillSessionTable() i++; } + + UpdateClientStatus(); } void SettingsDialog::EnableAuthenticationCheckBoxChanged() { + if (!ui->enableWebSocketServerCheckBox->isChecked()) { + ui->serverPasswordLineEdit->setEnabled(false); + ui->generatePasswordButton->setEnabled(false); + return; + } + if (ui->enableAuthenticationCheckBox->isChecked()) { ui->serverPasswordLineEdit->setEnabled(true); ui->generatePasswordButton->setEnabled(true); @@ -277,6 +531,98 @@ void SettingsDialog::EnableAuthenticationCheckBoxChanged() } } +void SettingsDialog::UpdateServerUiState() +{ + auto conf = GetConfig(); + bool serverEnabled = ui->enableWebSocketServerCheckBox->isChecked(); + + ui->serverSettingsGroupBox->setEnabled(serverEnabled); + if (!serverEnabled) { + ui->serverPasswordLineEdit->setEnabled(false); + ui->generatePasswordButton->setEnabled(false); + return; + } + + if (conf && conf->PortOverridden) + ui->serverPortSpinBox->setEnabled(false); + + if (conf && conf->PasswordOverridden) { + ui->enableAuthenticationCheckBox->setEnabled(false); + ui->serverPasswordLineEdit->setEnabled(false); + ui->generatePasswordButton->setEnabled(false); + } else { + ui->enableAuthenticationCheckBox->setEnabled(true); + EnableAuthenticationCheckBoxChanged(); + } +} + +void SettingsDialog::UpdateClientUiState() +{ + QSignalBlocker blockClientAllowInvalid(ui->clientAllowInvalidCertCheckBox); + QSignalBlocker blockClientAllowInsecure(ui->clientAllowInsecureCheckBox); + + bool clientEnabled = ui->enableWebSocketClientCheckBox->isChecked(); + bool authRequired = ui->clientAuthRequiredCheckBox->isChecked(); + + ui->clientSettingsGroupBox->setEnabled(clientEnabled); + + ui->clientHostLineEdit->setEnabled(clientEnabled); + ui->clientAllowInsecureCheckBox->setEnabled(clientEnabled); + ui->clientAuthRequiredCheckBox->setEnabled(clientEnabled); + +#if !OBS_WEBSOCKET_CLIENT_TLS + ui->clientAllowInvalidCertCheckBox->setChecked(false); + ui->clientAllowInvalidCertCheckBox->setEnabled(false); + ui->clientAllowInsecureCheckBox->setChecked(true); + ui->clientAllowInsecureCheckBox->setEnabled(false); +#else + bool connectionUsesTls = ui->clientHostLineEdit->text().trimmed().startsWith("wss://", Qt::CaseInsensitive); + if (!connectionUsesTls) + ui->clientAllowInvalidCertCheckBox->setChecked(false); + ui->clientAllowInvalidCertCheckBox->setEnabled(clientEnabled && connectionUsesTls); +#endif + + ui->clientPasswordLineEdit->setEnabled(clientEnabled && authRequired); + ui->showClientPasswordButton->setEnabled(clientEnabled && authRequired); + ui->generateClientPasswordButton->setEnabled(clientEnabled && authRequired); +} + +void SettingsDialog::UpdateClientStatus() +{ + auto client = GetWebSocketClient(); + if (!client) { + ui->clientStatusValueLabel->setText(obs_module_text("OBSWebSocket.Settings.ClientStatusUnavailable")); + return; + } + + auto status = client->GetStatus(); + QString statusText; + switch (status.state) { + case WebSocketClient::State::Disabled: + statusText = obs_module_text("OBSWebSocket.Settings.ClientStatusDisabled"); + break; + case WebSocketClient::State::Connecting: + statusText = obs_module_text("OBSWebSocket.Settings.ClientStatusConnecting"); + break; + case WebSocketClient::State::Connected: + statusText = obs_module_text("OBSWebSocket.Settings.ClientStatusConnected"); + break; + case WebSocketClient::State::Disconnected: + statusText = obs_module_text("OBSWebSocket.Settings.ClientStatusDisconnected"); + break; + case WebSocketClient::State::Error: + statusText = obs_module_text("OBSWebSocket.Settings.ClientStatusError"); + break; + } + + if (!status.endpoint.empty()) + statusText += QString(" (%1)").arg(QString::fromStdString(status.endpoint)); + if (!status.lastError.empty()) + statusText += QString(" - %1").arg(QString::fromStdString(status.lastError)); + + ui->clientStatusValueLabel->setText(statusText); +} + void SettingsDialog::GeneratePasswordButtonClicked() { QString newPassword = QString::fromStdString(Utils::Crypto::GeneratePassword()); @@ -285,6 +631,20 @@ void SettingsDialog::GeneratePasswordButtonClicked() passwordManuallyEdited = false; } +void SettingsDialog::GenerateClientPasswordButtonClicked() +{ + QString newPassword = QString::fromStdString(Utils::Crypto::GeneratePassword()); + ui->clientPasswordLineEdit->setText(newPassword); + ui->clientPasswordLineEdit->selectAll(); + clientPasswordManuallyEdited = false; +} + +void SettingsDialog::ToggleClientPasswordVisibility() +{ + bool showing = ui->clientPasswordLineEdit->echoMode() != QLineEdit::Password; + SetClientPasswordVisible(ui, !showing); +} + void SettingsDialog::ShowConnectInfoButtonClicked() { if (obs_video_active()) { @@ -315,3 +675,8 @@ void SettingsDialog::PasswordEdited() { passwordManuallyEdited = true; } + +void SettingsDialog::ClientPasswordEdited() +{ + clientPasswordManuallyEdited = true; +} diff --git a/src/forms/SettingsDialog.h b/src/forms/SettingsDialog.h index 64f90c9b..43c0dad6 100644 --- a/src/forms/SettingsDialog.h +++ b/src/forms/SettingsDialog.h @@ -43,13 +43,20 @@ private Q_SLOTS: void SaveFormData(); void FillSessionTable(); void EnableAuthenticationCheckBoxChanged(); + void UpdateServerUiState(); + void UpdateClientUiState(); + void UpdateClientStatus(); void GeneratePasswordButtonClicked(); + void GenerateClientPasswordButtonClicked(); + void ToggleClientPasswordVisibility(); void ShowConnectInfoButtonClicked(); void PasswordEdited(); + void ClientPasswordEdited(); private: Ui::SettingsDialog *ui; ConnectInfo *connectInfo; QTimer *sessionTableTimer; bool passwordManuallyEdited; + bool clientPasswordManuallyEdited; }; diff --git a/src/forms/SettingsDialog.ui b/src/forms/SettingsDialog.ui index 6eb40824..6c48fcaa 100644 --- a/src/forms/SettingsDialog.ui +++ b/src/forms/SettingsDialog.ui @@ -7,13 +7,13 @@ 0 0 675 - 565 + 700 675 - 0 + 700 @@ -70,14 +70,73 @@ + + + + Qt::Horizontal + + + QSizePolicy::Fixed + + + + 150 + 20 + + + + + + + 2 + + + + + OBSWebSocket.Settings.ClientWarningText + + + OBSWebSocket.Settings.ClientEnable + + + + + + + OBSWebSocket.Settings.ClientWarningText + + + true + + + Qt::AlignCenter + + + + + + + Qt::Horizontal + + + + 40 + 20 + + + + + + + OBSWebSocket.Settings.AlertsEnable - + 2 @@ -124,7 +183,7 @@ - + 0 @@ -217,10 +276,112 @@ - - - - + + + + + + + 0 + 0 + + + + OBSWebSocket.Settings.ClientSettingsTitle + + + + Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter + + + 2 + + + + + OBSWebSocket.Settings.ClientHost + + + + + + + OBSWebSocket.Settings.ClientHostPlaceholder + + + + + + + OBSWebSocket.Settings.ClientAllowInsecure + + + + + + + OBSWebSocket.Settings.ClientAllowInvalidCert + + + + + + + OBSWebSocket.Settings.ClientAuthRequired + + + + + + + OBSWebSocket.Settings.ClientPassword + + + + + + + + + QLineEdit::Password + + + + + + + OBSWebSocket.Settings.Show + + + + + + + OBSWebSocket.Settings.ClientGeneratePassword + + + + + + + + + OBSWebSocket.Settings.ClientStatus + + + + + + + OBSWebSocket.Settings.ClientStatusDisabled + + + + + + + + 0 @@ -230,7 +391,7 @@ OBSWebSocket.SessionTable.Title - + @@ -281,9 +442,19 @@ + + + + OBSWebSocket.SessionTable.ServerOnlyNote + + + true + + + - + diff --git a/src/obs-websocket.cpp b/src/obs-websocket.cpp index 7c9071ee..84336048 100644 --- a/src/obs-websocket.cpp +++ b/src/obs-websocket.cpp @@ -26,6 +26,7 @@ with this program. If not, see #include "Config.h" #include "WebSocketApi.h" #include "websocketserver/WebSocketServer.h" +#include "websocketclient/WebSocketClient.h" #include "eventhandler/EventHandler.h" #include "forms/SettingsDialog.h" @@ -46,6 +47,7 @@ ConfigPtr _config; EventHandlerPtr _eventHandler; WebSocketApiPtr _webSocketApi; WebSocketServerPtr _webSocketServer; +WebSocketClientPtr _webSocketClient; SettingsDialog *_settingsDialog = nullptr; void OnWebSocketApiVendorEvent(std::string vendorName, std::string eventType, obs_data_t *obsEventData); @@ -87,6 +89,11 @@ bool obs_module_load(void) _webSocketServer->SetClientSubscriptionCallback(std::bind(&EventHandler::ProcessSubscriptionChange, _eventHandler.get(), std::placeholders::_1, std::placeholders::_2)); + // Initialize the WebSocket client (outbound) + _webSocketClient = std::make_shared(); + _webSocketClient->SetClientSubscriptionCallback(std::bind(&EventHandler::ProcessSubscriptionChange, _eventHandler.get(), + std::placeholders::_1, std::placeholders::_2)); + // Initialize the settings dialog obs_frontend_push_ui_translation(obs_module_get_string); QMainWindow *mainWindow = static_cast(obs_frontend_get_main_window()); @@ -121,6 +128,11 @@ void obs_module_post_load(void) blog(LOG_INFO, "[obs_module_post_load] WebSocket server is enabled, starting..."); _webSocketServer->Start(); } + + if (_config->ClientEnabled) { + blog(LOG_INFO, "[obs_module_post_load] WebSocket client mode is enabled, starting..."); + _webSocketClient->Start(); + } } void obs_module_unload(void) @@ -133,10 +145,22 @@ void obs_module_unload(void) _webSocketServer->Stop(); } + // Shutdown the WebSocket client if it is running + if (_webSocketClient) { + blog_debug("[obs_module_unload] WebSocket client is running. Stopping..."); + _webSocketClient->Stop(); + } + // Release the WebSocket server _webSocketServer->SetClientSubscriptionCallback(nullptr); _webSocketServer = nullptr; + // Release the WebSocket client + if (_webSocketClient) { + _webSocketClient->SetClientSubscriptionCallback(nullptr); + _webSocketClient = nullptr; + } + // Release the plugin/script api _webSocketApi = nullptr; @@ -179,6 +203,11 @@ WebSocketServerPtr GetWebSocketServer() return _webSocketServer; } +WebSocketClientPtr GetWebSocketClient() +{ + return _webSocketClient; +} + bool IsDebugEnabled() { return !_config || _config->DebugEnabled; @@ -212,6 +241,8 @@ void OnWebSocketApiVendorEvent(std::string vendorName, std::string eventType, ob broadcastEventData["eventData"] = eventData; _webSocketServer->BroadcastEvent(EventSubscription::Vendors, "VendorEvent", broadcastEventData); + if (_webSocketClient) + _webSocketClient->BroadcastEvent(EventSubscription::Vendors, "VendorEvent", broadcastEventData); } // Sent from: EventHandler @@ -219,6 +250,8 @@ void OnEvent(uint64_t requiredIntent, std::string eventType, json eventData, uin { if (_webSocketServer) _webSocketServer->BroadcastEvent(requiredIntent, eventType, eventData, rpcVersion); + if (_webSocketClient) + _webSocketClient->BroadcastEvent(requiredIntent, eventType, eventData, rpcVersion); if (_webSocketApi) _webSocketApi->BroadcastEvent(requiredIntent, eventType, eventData, rpcVersion); } @@ -228,6 +261,8 @@ void OnObsReady(bool ready) { if (_webSocketServer) _webSocketServer->SetObsReady(ready); + if (_webSocketClient) + _webSocketClient->SetObsReady(ready); if (_webSocketApi) _webSocketApi->SetObsReady(ready); } diff --git a/src/obs-websocket.h b/src/obs-websocket.h index c6c0f349..57568a64 100644 --- a/src/obs-websocket.h +++ b/src/obs-websocket.h @@ -40,6 +40,9 @@ typedef std::shared_ptr WebSocketApiPtr; class WebSocketServer; typedef std::shared_ptr WebSocketServerPtr; +class WebSocketClient; +typedef std::shared_ptr WebSocketClientPtr; + os_cpu_usage_info_t *GetCpuUsageInfo(); ConfigPtr GetConfig(); @@ -50,4 +53,6 @@ WebSocketApiPtr GetWebSocketApi(); WebSocketServerPtr GetWebSocketServer(); +WebSocketClientPtr GetWebSocketClient(); + bool IsDebugEnabled(); diff --git a/src/websocketclient/WebSocketClient.cpp b/src/websocketclient/WebSocketClient.cpp new file mode 100644 index 00000000..4f033265 --- /dev/null +++ b/src/websocketclient/WebSocketClient.cpp @@ -0,0 +1,1076 @@ +/* +obs-websocket +Copyright (C) 2016-2021 Stephane Lepin +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#include "WebSocketClient.h" + +#include +#include +#include +#include +#include + +#include +#include + +#include +#if OBS_WEBSOCKET_CLIENT_TLS +#include +#include +#include +#else +#include +#endif + +#include "../Config.h" +#include "../obs-websocket.h" +#include "../utils/Compat.h" +#include "../utils/Crypto.h" + +namespace { + constexpr int kHeartbeatSeconds = 30; + constexpr int kReconnectBaseSeconds = 1; + constexpr int kReconnectMaxSeconds = 30; + constexpr uint8_t kEncodingJson = 0; + constexpr uint8_t kEncodingMsgPack = 1; + + uint64_t NowSeconds() + { + return static_cast(QDateTime::currentSecsSinceEpoch()); + } + + int BackoffSeconds(uint32_t attempt) + { + uint32_t shift = std::min(attempt, 5); + int delay = kReconnectBaseSeconds << shift; + return std::min(delay, kReconnectMaxSeconds); + } + + std::string NormalizeHostForUri(const std::string &host) + { + if (host.empty()) + return host; + if (host.find(':') != std::string::npos && host.front() != '[' && host.find(']') == std::string::npos) + return "[" + host + "]"; + return host; + } + + std::string TrimWhitespace(const std::string &value) + { + const auto start = value.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) + return {}; + const auto end = value.find_last_not_of(" \t\r\n"); + return value.substr(start, end - start + 1); + } + + bool ContainsWhitespace(const std::string &value) + { + return std::any_of(value.begin(), value.end(), [](unsigned char ch) { return std::isspace(ch) != 0; }); + } + + bool LooksLikeHostWithPort(const std::string &host) + { + if (host.empty()) + return false; + + const auto schemePos = host.find("://"); + if (schemePos != std::string::npos) + return false; + + const auto bracketPos = host.find(']'); + if (host.front() == '[' && bracketPos != std::string::npos) { + if (bracketPos + 1 < host.size() && host[bracketPos + 1] == ':') { + const std::string portPart = host.substr(bracketPos + 2); + return !portPart.empty() && std::all_of(portPart.begin(), portPart.end(), + [](unsigned char ch) { return std::isdigit(ch) != 0; }); + } + return false; + } + + const auto colonPos = host.find(':'); + if (colonPos == std::string::npos) + return false; + if (host.find(':', colonPos + 1) != std::string::npos) + return false; + const std::string portPart = host.substr(colonPos + 1); + return !portPart.empty() && + std::all_of(portPart.begin(), portPart.end(), [](unsigned char ch) { return std::isdigit(ch) != 0; }); + } + +} // namespace + +class WebSocketClientTransport { +public: + using OpenHandler = std::function; + using CloseHandler = std::function; + using FailHandler = std::function; + using MessageHandler = + std::function; +#if OBS_WEBSOCKET_CLIENT_TLS + using TlsInitHandler = std::function(websocketpp::connection_hdl)>; +#else + using TlsInitHandler = std::function; +#endif + + virtual ~WebSocketClientTransport() = default; + virtual void Init(bool debug) = 0; + virtual void SetHandlers(OpenHandler open, CloseHandler close, FailHandler fail, MessageHandler message) = 0; + virtual void SetTlsInitHandler(TlsInitHandler handler) = 0; + virtual websocketpp::lib::error_code Connect(const std::string &uri, websocketpp::connection_hdl &outHdl, + const std::vector &subprotocols) = 0; + virtual void Run() = 0; + virtual void Stop() = 0; + virtual void Reset() = 0; + virtual void Send(const websocketpp::connection_hdl &hdl, const std::string &payload, + websocketpp::frame::opcode::value opcode, websocketpp::lib::error_code &ec) = 0; + virtual void Ping(const websocketpp::connection_hdl &hdl, const std::string &payload, websocketpp::lib::error_code &ec) = 0; + virtual void Close(const websocketpp::connection_hdl &hdl, uint16_t code, const std::string &reason, + websocketpp::lib::error_code &ec) = 0; + virtual std::string GetRemoteEndpoint(const websocketpp::connection_hdl &hdl) = 0; + virtual std::string GetSelectedSubprotocol(const websocketpp::connection_hdl &hdl) = 0; + virtual std::string GetLocalCloseReason(const websocketpp::connection_hdl &hdl) = 0; + virtual uint16_t GetLocalCloseCode(const websocketpp::connection_hdl &hdl) = 0; + virtual std::string GetFailReason(const websocketpp::connection_hdl &hdl) = 0; + virtual asio::io_service &GetIoService() = 0; + virtual void Post(std::function fn) = 0; +}; + +class WebSocketClientTransportPlain : public WebSocketClientTransport { +public: + using Client = websocketpp::client; + + void Init(bool debug) override + { + _client.get_alog().clear_channels(websocketpp::log::alevel::all); + _client.get_elog().clear_channels(websocketpp::log::elevel::all); + _client.init_asio(); + + if (debug) { + _client.get_alog().set_channels(websocketpp::log::alevel::all); + _client.get_alog().clear_channels(websocketpp::log::alevel::frame_header | + websocketpp::log::alevel::frame_payload | + websocketpp::log::alevel::control); + _client.get_elog().set_channels(websocketpp::log::elevel::all); + _client.get_alog().clear_channels(websocketpp::log::elevel::devel | websocketpp::log::elevel::library); + } else { + _client.get_alog().clear_channels(websocketpp::log::alevel::all); + _client.get_elog().clear_channels(websocketpp::log::elevel::all); + } + } + + void SetHandlers(OpenHandler open, CloseHandler close, FailHandler fail, MessageHandler message) override + { + _client.set_open_handler(open); + _client.set_close_handler(close); + _client.set_fail_handler(fail); + _client.set_message_handler([message](websocketpp::connection_hdl hdl, Client::message_ptr msg) { + message(hdl, msg->get_opcode(), msg->get_payload()); + }); + } + + void SetTlsInitHandler(TlsInitHandler) override {} + + websocketpp::lib::error_code Connect(const std::string &uri, websocketpp::connection_hdl &outHdl, + const std::vector &subprotocols) override + { + websocketpp::lib::error_code ec; + Client::connection_ptr connection = _client.get_connection(uri, ec); + if (ec) + return ec; + for (const auto &subprotocol : subprotocols) + connection->add_subprotocol(subprotocol); + outHdl = connection->get_handle(); + _client.connect(connection); + return ec; + } + + void Run() override { _client.run(); } + void Stop() override { _client.stop(); } + void Reset() override { _client.reset(); } + + void Send(const websocketpp::connection_hdl &hdl, const std::string &payload, websocketpp::frame::opcode::value opcode, + websocketpp::lib::error_code &ec) override + { + _client.send(hdl, payload, opcode, ec); + } + + void Ping(const websocketpp::connection_hdl &hdl, const std::string &payload, websocketpp::lib::error_code &ec) override + { + _client.ping(hdl, payload, ec); + } + + void Close(const websocketpp::connection_hdl &hdl, uint16_t code, const std::string &reason, + websocketpp::lib::error_code &ec) override + { + _client.close(hdl, code, reason, ec); + } + + std::string GetRemoteEndpoint(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_remote_endpoint(); + } + + std::string GetSelectedSubprotocol(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_subprotocol(); + } + + std::string GetLocalCloseReason(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_local_close_reason(); + } + + uint16_t GetLocalCloseCode(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_local_close_code(); + } + + std::string GetFailReason(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_ec().message(); + } + + asio::io_service &GetIoService() override { return _client.get_io_service(); } + + void Post(std::function fn) override { _client.get_io_service().post(std::move(fn)); } + +private: + Client _client; +}; + +#if OBS_WEBSOCKET_CLIENT_TLS +class WebSocketClientTransportTls : public WebSocketClientTransport { +public: + using Client = websocketpp::client; + + void Init(bool debug) override + { + _client.get_alog().clear_channels(websocketpp::log::alevel::all); + _client.get_elog().clear_channels(websocketpp::log::elevel::all); + _client.init_asio(); + + if (debug) { + _client.get_alog().set_channels(websocketpp::log::alevel::all); + _client.get_alog().clear_channels(websocketpp::log::alevel::frame_header | + websocketpp::log::alevel::frame_payload | + websocketpp::log::alevel::control); + _client.get_elog().set_channels(websocketpp::log::elevel::all); + _client.get_alog().clear_channels(websocketpp::log::elevel::devel | websocketpp::log::elevel::library); + } else { + _client.get_alog().clear_channels(websocketpp::log::alevel::all); + _client.get_elog().clear_channels(websocketpp::log::elevel::all); + } + } + + void SetHandlers(OpenHandler open, CloseHandler close, FailHandler fail, MessageHandler message) override + { + _client.set_open_handler(open); + _client.set_close_handler(close); + _client.set_fail_handler(fail); + _client.set_message_handler([message](websocketpp::connection_hdl hdl, Client::message_ptr msg) { + message(hdl, msg->get_opcode(), msg->get_payload()); + }); + } + + void SetTlsInitHandler(TlsInitHandler handler) override { _client.set_tls_init_handler(handler); } + + websocketpp::lib::error_code Connect(const std::string &uri, websocketpp::connection_hdl &outHdl, + const std::vector &subprotocols) override + { + websocketpp::lib::error_code ec; + Client::connection_ptr connection = _client.get_connection(uri, ec); + if (ec) + return ec; + for (const auto &subprotocol : subprotocols) + connection->add_subprotocol(subprotocol); + outHdl = connection->get_handle(); + _client.connect(connection); + return ec; + } + + void Run() override { _client.run(); } + void Stop() override { _client.stop(); } + void Reset() override { _client.reset(); } + + void Send(const websocketpp::connection_hdl &hdl, const std::string &payload, websocketpp::frame::opcode::value opcode, + websocketpp::lib::error_code &ec) override + { + _client.send(hdl, payload, opcode, ec); + } + + void Ping(const websocketpp::connection_hdl &hdl, const std::string &payload, websocketpp::lib::error_code &ec) override + { + _client.ping(hdl, payload, ec); + } + + void Close(const websocketpp::connection_hdl &hdl, uint16_t code, const std::string &reason, + websocketpp::lib::error_code &ec) override + { + _client.close(hdl, code, reason, ec); + } + + std::string GetRemoteEndpoint(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_remote_endpoint(); + } + + std::string GetSelectedSubprotocol(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_subprotocol(); + } + + std::string GetLocalCloseReason(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_local_close_reason(); + } + + uint16_t GetLocalCloseCode(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_local_close_code(); + } + + std::string GetFailReason(const websocketpp::connection_hdl &hdl) override + { + auto connection = _client.get_con_from_hdl(hdl); + return connection->get_ec().message(); + } + + asio::io_service &GetIoService() override { return _client.get_io_service(); } + + void Post(std::function fn) override { _client.get_io_service().post(std::move(fn)); } + +private: + Client _client; +}; +#endif + +WebSocketClient::WebSocketClient() : _protocol() +{ + UpdateStatus(State::Disabled); +} + +WebSocketClient::~WebSocketClient() +{ + Stop(); +} + +void WebSocketClient::Start() +{ + if (_shouldRun.load()) + return; + + ClientConfigSnapshot config = GetConfigSnapshot(); + if (!config.enabled) { + UpdateStatus(State::Disabled, FormatEndpointForUi(config)); + return; + } + + _reconnectAttempt = 0; + UpdateReconnectAttempt(0); + + _shouldRun = true; + _clientThread = std::thread(&WebSocketClient::ClientRunner, this); +} + +void WebSocketClient::Stop() +{ + StopInternal(true); +} + +void WebSocketClient::Restart() +{ + StopInternal(true); + Start(); +} + +void WebSocketClient::StopInternal(bool joinThread) +{ + if (!_shouldRun.load() && !joinThread) + return; + + _shouldRun = false; + _reconnectCv.notify_all(); + _connected = false; + _connecting = false; + + { + std::lock_guard lock(_transportMutex); + if (_transport) + _transport->Stop(); + } + + if (joinThread && _clientThread.joinable()) + _clientThread.join(); + + StopHeartbeat(); + _reconnectAttempt = 0; + UpdateReconnectAttempt(0); + + ClientConfigSnapshot config = GetConfigSnapshot(); + UpdateStatus(config.enabled ? State::Disconnected : State::Disabled, FormatEndpointForUi(config)); +} + +void WebSocketClient::SetObsReady(bool ready) +{ + _protocol.SetObsReady(ready); +} + +void WebSocketClient::SetClientSubscriptionCallback(WebSocketProtocol::ClientSubscriptionCallback cb) +{ + { + std::lock_guard lock(_subscriptionMutex); + _clientSubscriptionCallback = cb; + } + + _protocol.SetClientSubscriptionCallback([this](bool type, uint64_t subs) { + WebSocketProtocol::ClientSubscriptionCallback callback; + { + std::lock_guard lock(_subscriptionMutex); + _subscriptionActive = type; + callback = _clientSubscriptionCallback; + } + if (callback) + callback(type, subs); + }); +} + +WebSocketClient::Status WebSocketClient::GetStatus() +{ + std::lock_guard lock(_statusMutex); + return _status; +} + +void WebSocketClient::BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData, + uint8_t rpcVersion) +{ + if (!_connected.load() || !_protocol.IsObsReady()) + return; + + _protocol.GetThreadPool()->start(Utils::Compat::CreateFunctionRunnable([=]() { + SessionPtr session; + { + std::lock_guard lock(_sessionMutex); + session = _session; + } + if (!session || !session->IsIdentified()) + return; + if (rpcVersion && session->RpcVersion() != rpcVersion) + return; + if ((session->EventSubscriptions() & requiredIntent) == 0) + return; + + json eventMessage; + eventMessage["op"] = 5; + eventMessage["d"]["eventType"] = eventType; + eventMessage["d"]["eventIntent"] = requiredIntent; + if (eventData.is_object()) + eventMessage["d"]["eventData"] = eventData; + + SendPayload(eventMessage); + if (IsDebugEnabled() && (EventSubscription::All & requiredIntent) != 0) + blog(LOG_INFO, "[WebSocketClient::BroadcastEvent] Outgoing event:\n%s", eventMessage.dump(2).c_str()); + })); +} + +void WebSocketClient::UpdateStatus(State state, const std::string &endpoint, const std::string &error) +{ + std::lock_guard lock(_statusMutex); + _status.state = state; + if (!endpoint.empty()) + _status.endpoint = endpoint; + if (!error.empty()) + _status.lastError = error; + else if (state != State::Error) + _status.lastError.clear(); + _status.lastStateChange = NowSeconds(); +} + +void WebSocketClient::UpdateReconnectAttempt(uint32_t attempt) +{ + std::lock_guard lock(_statusMutex); + _status.reconnectAttempt = attempt; +} + +WebSocketClient::ClientConfigSnapshot WebSocketClient::GetConfigSnapshot() +{ + ClientConfigSnapshot snapshot; + auto conf = GetConfig(); + if (!conf) + return snapshot; + + snapshot.enabled = conf->ClientEnabled.load(); + snapshot.host = conf->ClientHost; + snapshot.port = conf->ClientPort.load(); + snapshot.useTls = conf->ClientUseTls.load(); + snapshot.allowInsecure = conf->ClientAllowInsecure.load(); + snapshot.allowInvalidCert = conf->ClientAllowInvalidCert.load(); + snapshot.authRequired = conf->ClientAuthRequired.load(); + snapshot.password = conf->ClientPassword; + snapshot.debug = conf->DebugEnabled.load(); + return snapshot; +} + +std::string WebSocketClient::BuildEndpoint(const ClientConfigSnapshot &config, std::string &error) +{ + std::string hostInput = TrimWhitespace(config.host); + if (hostInput.empty()) { + error = "Client host is empty."; + return {}; + } + if (hostInput.find("://") != std::string::npos) { + error = "Client host should not include a scheme (ws:// or wss://)."; + return {}; + } + if (hostInput.find('/') != std::string::npos) { + error = "Client host should not include a path."; + return {}; + } + if (ContainsWhitespace(hostInput)) { + error = "Client host must not contain whitespace."; + return {}; + } + if (LooksLikeHostWithPort(hostInput)) { + error = "Client host should not include a port; use the Client Port field."; + return {}; + } +#if !OBS_WEBSOCKET_CLIENT_TLS + if (config.useTls) { + error = obs_module_text("OBSWebSocket.Settings.ClientTlsDisabledMessage"); + return {}; + } +#endif + if (!config.useTls && !config.allowInsecure) { + error = "Unencrypted ws:// is disabled. Enable the unsafe toggle to allow it."; + return {}; + } + + std::string host = NormalizeHostForUri(hostInput); + std::string scheme = config.useTls ? "wss" : "ws"; + return scheme + "://" + host + ":" + std::to_string(config.port); +} + +std::string WebSocketClient::FormatEndpointForUi(const ClientConfigSnapshot &config) +{ + std::string hostInput = TrimWhitespace(config.host); + if (hostInput.empty()) + return {}; + if (hostInput.find("://") != std::string::npos || hostInput.find('/') != std::string::npos || + ContainsWhitespace(hostInput) || LooksLikeHostWithPort(hostInput)) { + return hostInput; + } + std::string host = NormalizeHostForUri(hostInput); + std::string scheme = config.useTls ? "wss" : "ws"; + return scheme + "://" + host + ":" + std::to_string(config.port); +} + +void WebSocketClient::ClientRunner() +{ + _activeConfig = GetConfigSnapshot(); + + while (_shouldRun.load()) { + _activeConfig = GetConfigSnapshot(); + if (!_activeConfig.enabled) { + UpdateStatus(State::Disabled, FormatEndpointForUi(_activeConfig)); + break; + } + + std::string error; + std::string endpoint = BuildEndpoint(_activeConfig, error); + if (endpoint.empty()) { + UpdateStatus(State::Error, FormatEndpointForUi(_activeConfig), error); + uint32_t attempt = _reconnectAttempt.fetch_add(1) + 1; + UpdateReconnectAttempt(attempt); + int delay = BackoffSeconds(attempt); + std::unique_lock lock(_reconnectMutex); + _reconnectCv.wait_for(lock, std::chrono::seconds(delay), [this]() { return !_shouldRun.load(); }); + continue; + } + + UpdateStatus(State::Connecting, endpoint); + _connecting = true; + + { + std::lock_guard lock(_transportMutex); +#if OBS_WEBSOCKET_CLIENT_TLS + if (_activeConfig.useTls) + _transport = std::make_shared(); + else + _transport = std::make_shared(); +#else + _transport = std::make_shared(); +#endif + + _transport->Init(_activeConfig.debug); + _transport->SetHandlers([this](websocketpp::connection_hdl hdl) { HandleOpen(hdl); }, + [this](websocketpp::connection_hdl hdl) { HandleClose(hdl); }, + [this](websocketpp::connection_hdl hdl) { HandleFail(hdl); }, + [this](websocketpp::connection_hdl hdl, websocketpp::frame::opcode::value opCode, + const std::string &payload) { HandleMessage(hdl, opCode, payload); }); + +#if OBS_WEBSOCKET_CLIENT_TLS + if (_activeConfig.useTls) { + std::string hostForVerify = TrimWhitespace(_activeConfig.host); + bool allowInvalid = _activeConfig.allowInvalidCert; + _transport->SetTlsInitHandler([hostForVerify, allowInvalid](websocketpp::connection_hdl) { + auto ctx = std::make_shared(asio::ssl::context::tls_client); + ctx->set_options(asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 | + asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use); + if (allowInvalid) { + ctx->set_verify_mode(asio::ssl::verify_none); + } else { + ctx->set_verify_mode(asio::ssl::verify_peer); + ctx->set_default_verify_paths(); + ctx->set_verify_callback(asio::ssl::rfc2818_verification(hostForVerify)); + } + return ctx; + }); + } +#endif + } + + websocketpp::connection_hdl hdl; + websocketpp::lib::error_code ec; + { + std::lock_guard lock(_transportMutex); + ec = _transport->Connect(endpoint, hdl, {"obswebsocket.json", "obswebsocket.msgpack"}); + } + if (ec) { + UpdateStatus(State::Error, endpoint, ec.message()); + _connecting = false; + uint32_t attempt = _reconnectAttempt.fetch_add(1) + 1; + UpdateReconnectAttempt(attempt); + int delay = BackoffSeconds(attempt); + std::unique_lock lock(_reconnectMutex); + _reconnectCv.wait_for(lock, std::chrono::seconds(delay), [this]() { return !_shouldRun.load(); }); + continue; + } + + { + std::lock_guard lock(_sessionMutex); + _connection = hdl; + _hasConnection = true; + } + + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (transportSnapshot) + transportSnapshot->Run(); + + StopHeartbeat(); + SessionPtr session; + { + std::lock_guard lock(_sessionMutex); + session = _session; + } + NotifySubscriptionStopIfActive(session); + { + std::lock_guard lock(_sessionMutex); + _session.reset(); + _hasConnection = false; + } + + _connecting = false; + _connected = false; + + { + std::lock_guard lock(_transportMutex); + if (_transport) { + _transport->Reset(); + _transport.reset(); + } + } + + if (!_shouldRun.load()) + break; + + uint32_t attempt = _reconnectAttempt.fetch_add(1) + 1; + UpdateReconnectAttempt(attempt); + UpdateStatus(State::Disconnected, endpoint); + + int delay = BackoffSeconds(attempt); + std::unique_lock lock(_reconnectMutex); + _reconnectCv.wait_for(lock, std::chrono::seconds(delay), [this]() { return !_shouldRun.load(); }); + } +} + +void WebSocketClient::HandleOpen(websocketpp::connection_hdl hdl) +{ + auto conf = GetConfig(); + if (!conf) { + HandleFail(hdl); + return; + } + + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + SessionPtr session = std::make_shared(); + session->SetRemoteAddress(transportSnapshot->GetRemoteEndpoint(hdl)); + session->SetConnectedAt(NowSeconds()); + session->SetAuthenticationRequired(_activeConfig.authRequired); + + std::string selectedSubprotocol = transportSnapshot->GetSelectedSubprotocol(hdl); + if (selectedSubprotocol == "obswebsocket.msgpack") + session->SetEncoding(kEncodingMsgPack); + else + session->SetEncoding(kEncodingJson); + + if (session->AuthenticationRequired()) { + _authenticationSalt = Utils::Crypto::GenerateSalt(); + _authenticationSecret = Utils::Crypto::GenerateSecret(_activeConfig.password, _authenticationSalt); + session->SetSecret(_authenticationSecret); + session->SetChallenge(Utils::Crypto::GenerateSalt()); + } + + { + std::lock_guard lock(_sessionMutex); + _session = session; + } + + json helloMessageData; + helloMessageData["obsStudioVersion"] = obs_get_version_string(); + helloMessageData["obsWebSocketVersion"] = OBS_WEBSOCKET_VERSION; + helloMessageData["rpcVersion"] = OBS_WEBSOCKET_RPC_VERSION; + if (session->AuthenticationRequired()) { + helloMessageData["authentication"] = json::object(); + helloMessageData["authentication"]["challenge"] = session->Challenge(); + helloMessageData["authentication"]["salt"] = _authenticationSalt; + } + json helloMessage; + helloMessage["op"] = 0; + helloMessage["d"] = helloMessageData; + + blog(LOG_INFO, "[WebSocketClient::HandleOpen] Connected to %s", session->RemoteAddress().c_str()); + if (IsDebugEnabled()) + blog_debug("[WebSocketClient::HandleOpen] Sending Op 0 (Hello) message:\n%s", helloMessage.dump(2).c_str()); + + SendPayload(helloMessage); + + UpdateStatus(State::Connected, FormatEndpointForUi(_activeConfig)); + _connected = true; + _connecting = false; + _reconnectAttempt = 0; + UpdateReconnectAttempt(0); + StartHeartbeat(); +} + +void WebSocketClient::HandleClose(websocketpp::connection_hdl hdl) +{ + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + SessionPtr session; + { + std::lock_guard lock(_sessionMutex); + session = _session; + } + + NotifySubscriptionStopIfActive(session); + + uint16_t closeCode = transportSnapshot->GetLocalCloseCode(hdl); + std::string closeReason = transportSnapshot->GetLocalCloseReason(hdl); + if (!closeReason.empty()) + UpdateStatus(State::Disconnected, FormatEndpointForUi(_activeConfig), closeReason); + else + UpdateStatus(State::Disconnected, FormatEndpointForUi(_activeConfig)); + + _connected = false; + _connecting = false; + StopHeartbeat(); + + blog(LOG_INFO, "[WebSocketClient::HandleClose] WebSocket client disconnected (code %u): %s", closeCode, + closeReason.c_str()); +} + +void WebSocketClient::HandleFail(websocketpp::connection_hdl hdl) +{ + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + SessionPtr session; + { + std::lock_guard lock(_sessionMutex); + session = _session; + } + NotifySubscriptionStopIfActive(session); + + std::string error = transportSnapshot->GetFailReason(hdl); + UpdateStatus(State::Error, FormatEndpointForUi(_activeConfig), error); + _connected = false; + _connecting = false; + StopHeartbeat(); + + blog(LOG_WARNING, "[WebSocketClient::HandleFail] WebSocket client connection failed: %s", error.c_str()); +} + +void WebSocketClient::HandleMessage(websocketpp::connection_hdl hdl, websocketpp::frame::opcode::value opCode, + const std::string &payload) +{ + UNUSED_PARAMETER(hdl); + _protocol.GetThreadPool()->start(Utils::Compat::CreateFunctionRunnable([=]() { + SessionPtr session; + { + std::lock_guard lock(_sessionMutex); + session = _session; + } + if (!session) + return; + + session->IncrementIncomingMessages(); + + json incomingMessage; + uint8_t sessionEncoding = session->Encoding(); + + if (sessionEncoding == kEncodingJson) { + if (opCode != websocketpp::frame::opcode::text) { + CloseWithCode(WebSocketCloseCode::MessageDecodeError, + "Your session encoding is set to Json, but a binary message was received."); + return; + } + try { + incomingMessage = json::parse(payload); + } catch (json::parse_error &e) { + CloseWithCode(WebSocketCloseCode::MessageDecodeError, + std::string("Unable to decode Json: ") + e.what()); + return; + } + } else if (sessionEncoding == kEncodingMsgPack) { + if (opCode != websocketpp::frame::opcode::binary) { + CloseWithCode(WebSocketCloseCode::MessageDecodeError, + "Your session encoding is set to MsgPack, but a text message was received."); + return; + } + try { + incomingMessage = json::from_msgpack(payload); + } catch (json::parse_error &e) { + CloseWithCode(WebSocketCloseCode::MessageDecodeError, + std::string("Unable to decode MsgPack: ") + e.what()); + return; + } + } + + blog_debug("[WebSocketClient::HandleMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str()); + + WebSocketProtocol::ProcessResult ret; + + if (!incomingMessage.is_object()) { + ret.closeCode = WebSocketCloseCode::MessageDecodeError; + ret.closeReason = "You sent a non-object payload."; + goto skipProcessing; + } + + if (!session->IsIdentified() && incomingMessage.contains("request-type")) { + ret.closeCode = WebSocketCloseCode::UnsupportedRpcVersion; + ret.closeReason = + "You appear to be attempting to connect with the pre-5.0.0 plugin protocol. Check to make sure your client is updated."; + goto skipProcessing; + } + + if (!incomingMessage.contains("op")) { + ret.closeCode = WebSocketCloseCode::UnknownOpCode; + ret.closeReason = "Your request is missing an `op`."; + goto skipProcessing; + } + + if (!incomingMessage["op"].is_number()) { + ret.closeCode = WebSocketCloseCode::UnknownOpCode; + ret.closeReason = "Your `op` is not a number."; + goto skipProcessing; + } + + _protocol.ProcessMessage(session, ret, incomingMessage["op"], incomingMessage["d"]); + + skipProcessing: + if (ret.closeCode != WebSocketCloseCode::DontClose) { + CloseWithCode(ret.closeCode, ret.closeReason); + return; + } + + if (!ret.result.is_null()) { + SendPayload(ret.result); + } + })); +} + +void WebSocketClient::NotifySubscriptionStopIfActive(const SessionPtr &session) +{ + if (!session || !session->IsIdentified()) + return; + + bool active = false; + { + std::lock_guard lock(_subscriptionMutex); + active = _subscriptionActive; + } + if (!active) + return; + + _protocol.NotifyClientSubscriptionChange(false, session->EventSubscriptions()); +} + +void WebSocketClient::SendPayload(const json &payload) +{ + SessionPtr session; + websocketpp::connection_hdl hdl; + { + std::lock_guard lock(_sessionMutex); + if (!_hasConnection || !_session) + return; + session = _session; + hdl = _connection; + } + + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + uint8_t encoding = session->Encoding(); + std::string payloadData; + websocketpp::frame::opcode::value opcode; + + if (encoding == kEncodingJson) { + payloadData = payload.dump(); + opcode = websocketpp::frame::opcode::text; + } else { + auto msgPackData = json::to_msgpack(payload); + payloadData = std::string(msgPackData.begin(), msgPackData.end()); + opcode = websocketpp::frame::opcode::binary; + } + + session->IncrementOutgoingMessages(); + transportSnapshot->Post([transportSnapshot, hdl, payloadData, opcode]() { + websocketpp::lib::error_code ec; + transportSnapshot->Send(hdl, payloadData, opcode, ec); + if (ec) + blog(LOG_WARNING, "[WebSocketClient::SendPayload] Sending message failed: %s", ec.message().c_str()); + }); +} + +void WebSocketClient::CloseWithCode(WebSocketCloseCode::WebSocketCloseCode code, const std::string &reason) +{ + std::shared_ptr transportSnapshot; + websocketpp::connection_hdl hdl; + { + std::lock_guard lock(_sessionMutex); + if (!_hasConnection) + return; + hdl = _connection; + } + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + transportSnapshot->Post([transportSnapshot, hdl, code, reason]() { + websocketpp::lib::error_code ec; + transportSnapshot->Close(hdl, code, reason, ec); + }); +} + +void WebSocketClient::StartHeartbeat() +{ + std::shared_ptr transportSnapshot; + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + _heartbeatTimer = std::make_unique(transportSnapshot->GetIoService()); + ScheduleHeartbeat(); +} + +void WebSocketClient::StopHeartbeat() +{ + if (_heartbeatTimer) { + std::error_code ec; + _heartbeatTimer->cancel(ec); + _heartbeatTimer.reset(); + } +} + +void WebSocketClient::ScheduleHeartbeat() +{ + if (!_heartbeatTimer) + return; + + _heartbeatTimer->expires_after(std::chrono::seconds(kHeartbeatSeconds)); + _heartbeatTimer->async_wait([this](const std::error_code &ec) { + if (ec || !_connected.load()) + return; + + std::shared_ptr transportSnapshot; + websocketpp::connection_hdl hdl; + { + std::lock_guard lock(_sessionMutex); + if (!_hasConnection) + return; + hdl = _connection; + } + { + std::lock_guard lock(_transportMutex); + transportSnapshot = _transport; + } + if (!transportSnapshot) + return; + + transportSnapshot->Post([transportSnapshot, hdl]() { + websocketpp::lib::error_code ec; + transportSnapshot->Ping(hdl, "ping", ec); + if (ec) + blog(LOG_WARNING, "[WebSocketClient::Heartbeat] Ping failed: %s", ec.message().c_str()); + }); + + ScheduleHeartbeat(); + }); +} diff --git a/src/websocketclient/WebSocketClient.h b/src/websocketclient/WebSocketClient.h new file mode 100644 index 00000000..1952177d --- /dev/null +++ b/src/websocketclient/WebSocketClient.h @@ -0,0 +1,135 @@ +/* +obs-websocket +Copyright (C) 2016-2021 Stephane Lepin +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "../utils/Json.h" +#include "../websocketserver/WebSocketProtocol.h" +#include "../websocketserver/rpc/WebSocketSession.h" + +class WebSocketClientTransport; + +class WebSocketClient { +public: + enum class State { Disabled, Connecting, Connected, Disconnected, Error }; + + struct Status { + State state = State::Disabled; + std::string endpoint; + std::string lastError; + uint64_t lastStateChange = 0; + uint32_t reconnectAttempt = 0; + }; + + WebSocketClient(); + ~WebSocketClient(); + + void Start(); + void Stop(); + void Restart(); + + void SetObsReady(bool ready); + void SetClientSubscriptionCallback(WebSocketProtocol::ClientSubscriptionCallback cb); + + void BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData = nullptr, + uint8_t rpcVersion = 0); + + Status GetStatus(); + +private: + void ClientRunner(); + void StopInternal(bool joinThread); + + void UpdateStatus(State state, const std::string &endpoint = {}, const std::string &error = {}); + void UpdateReconnectAttempt(uint32_t attempt); + + struct ClientConfigSnapshot { + bool enabled = false; + std::string host; + uint16_t port = 4455; + bool useTls = true; + bool allowInsecure = false; + bool allowInvalidCert = false; + bool authRequired = true; + std::string password; + bool debug = false; + }; + + ClientConfigSnapshot GetConfigSnapshot(); + std::string BuildEndpoint(const ClientConfigSnapshot &config, std::string &error); + std::string FormatEndpointForUi(const ClientConfigSnapshot &config); + + void HandleOpen(websocketpp::connection_hdl hdl); + void HandleClose(websocketpp::connection_hdl hdl); + void HandleFail(websocketpp::connection_hdl hdl); + void HandleMessage(websocketpp::connection_hdl hdl, websocketpp::frame::opcode::value opCode, const std::string &payload); + + void SendPayload(const json &payload); + void CloseWithCode(WebSocketCloseCode::WebSocketCloseCode code, const std::string &reason); + + void NotifySubscriptionStopIfActive(const SessionPtr &session); + + void StartHeartbeat(); + void StopHeartbeat(); + void ScheduleHeartbeat(); + + std::atomic _shouldRun = false; + std::atomic _connected = false; + std::atomic _connecting = false; + std::atomic _reconnectAttempt = 0; + + std::mutex _statusMutex; + Status _status; + + std::mutex _sessionMutex; + SessionPtr _session; + websocketpp::connection_hdl _connection; + bool _hasConnection = false; + + std::string _authenticationSecret; + std::string _authenticationSalt; + + std::mutex _transportMutex; + std::shared_ptr _transport; + + std::unique_ptr _heartbeatTimer; + + std::thread _clientThread; + std::mutex _reconnectMutex; + std::condition_variable _reconnectCv; + + std::mutex _subscriptionMutex; + WebSocketProtocol::ClientSubscriptionCallback _clientSubscriptionCallback; + bool _subscriptionActive = false; + + WebSocketProtocol _protocol; + ClientConfigSnapshot _activeConfig; +}; diff --git a/src/websocketserver/WebSocketServer_Protocol.cpp b/src/websocketserver/WebSocketProtocol.cpp similarity index 82% rename from src/websocketserver/WebSocketServer_Protocol.cpp rename to src/websocketserver/WebSocketProtocol.cpp index 686a98f6..6eae2ac0 100644 --- a/src/websocketserver/WebSocketServer_Protocol.cpp +++ b/src/websocketserver/WebSocketProtocol.cpp @@ -20,14 +20,13 @@ with this program. If not, see #include #include -#include "WebSocketServer.h" +#include "WebSocketProtocol.h" #include "../requesthandler/RequestHandler.h" #include "../requesthandler/RequestBatchHandler.h" #include "../obs-websocket.h" #include "../Config.h" #include "../utils/Crypto.h" #include "../utils/Platform.h" -#include "../utils/Compat.h" static bool IsSupportedRpcVersion(uint8_t requestedVersion) { @@ -54,7 +53,35 @@ static json ConstructRequestResult(RequestResult requestResult, const json &requ return ret; } -void WebSocketServer::SetSessionParameters(SessionPtr session, ProcessResult &ret, const json &payloadData) +WebSocketProtocol::WebSocketProtocol() : _threadPool() {} + +void WebSocketProtocol::SetObsReady(bool ready) +{ + _obsReady = ready; +} + +bool WebSocketProtocol::IsObsReady() const +{ + return _obsReady.load(); +} + +void WebSocketProtocol::SetClientSubscriptionCallback(ClientSubscriptionCallback cb) +{ + _clientSubscriptionCallback = cb; +} + +void WebSocketProtocol::NotifyClientSubscriptionChange(bool type, uint64_t eventSubscriptions) +{ + if (_clientSubscriptionCallback) + _clientSubscriptionCallback(type, eventSubscriptions); +} + +QThreadPool *WebSocketProtocol::GetThreadPool() +{ + return &_threadPool; +} + +void WebSocketProtocol::SetSessionParameters(SessionPtr session, ProcessResult &ret, const json &payloadData) { if (payloadData.contains("eventSubscriptions")) { if (!payloadData["eventSubscriptions"].is_number_unsigned()) { @@ -66,8 +93,8 @@ void WebSocketServer::SetSessionParameters(SessionPtr session, ProcessResult &re } } -void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::ProcessResult &ret, - WebSocketOpCode::WebSocketOpCode opCode, json &payloadData) +void WebSocketProtocol::ProcessMessage(SessionPtr session, WebSocketProtocol::ProcessResult &ret, + WebSocketOpCode::WebSocketOpCode opCode, json &payloadData) { if (!payloadData.is_object()) { if (payloadData.is_null()) { @@ -351,61 +378,3 @@ void WebSocketServer::ProcessMessage(SessionPtr session, WebSocketServer::Proces return; } } - -// It isn't consistent to directly call the WebSocketServer from the events system, but it would also be dumb to make it unnecessarily complicated. -void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData, - uint8_t rpcVersion) -{ - if (!_server.is_listening() || !_obsReady) - return; - - _threadPool.start(Utils::Compat::CreateFunctionRunnable([=]() { - // Populate message object - json eventMessage; - eventMessage["op"] = 5; - eventMessage["d"]["eventType"] = eventType; - eventMessage["d"]["eventIntent"] = requiredIntent; - if (eventData.is_object()) - eventMessage["d"]["eventData"] = eventData; - - // Initialize objects. The broadcast process only dumps the data when its needed. - std::string messageJson; - std::string messageMsgPack; - - // Recurse connected sessions and send the event to suitable sessions. - std::unique_lock lock(_sessionMutex); - for (auto &it : _sessions) { - if (!it.second->IsIdentified()) - continue; - if (rpcVersion && it.second->RpcVersion() != rpcVersion) - continue; - if ((it.second->EventSubscriptions() & requiredIntent) != 0) { - websocketpp::lib::error_code errorCode; - switch (it.second->Encoding()) { - case WebSocketEncoding::Json: - if (messageJson.empty()) - messageJson = eventMessage.dump(); - _server.send((websocketpp::connection_hdl)it.first, messageJson, - websocketpp::frame::opcode::text, errorCode); - it.second->IncrementOutgoingMessages(); - break; - case WebSocketEncoding::MsgPack: - if (messageMsgPack.empty()) { - auto msgPackData = json::to_msgpack(eventMessage); - messageMsgPack = std::string(msgPackData.begin(), msgPackData.end()); - } - _server.send((websocketpp::connection_hdl)it.first, messageMsgPack, - websocketpp::frame::opcode::binary, errorCode); - it.second->IncrementOutgoingMessages(); - break; - } - if (errorCode) - blog(LOG_ERROR, "[WebSocketServer::BroadcastEvent] Error sending event message: %s", - errorCode.message().c_str()); - } - } - lock.unlock(); - if (IsDebugEnabled() && (EventSubscription::All & requiredIntent) != 0) // Don't log high volume events - blog(LOG_INFO, "[WebSocketServer::BroadcastEvent] Outgoing event:\n%s", eventMessage.dump(2).c_str()); - })); -} diff --git a/src/websocketserver/WebSocketProtocol.h b/src/websocketserver/WebSocketProtocol.h new file mode 100644 index 00000000..e1a4033b --- /dev/null +++ b/src/websocketserver/WebSocketProtocol.h @@ -0,0 +1,59 @@ +/* +obs-websocket +Copyright (C) 2016-2021 Stephane Lepin +Copyright (C) 2020-2021 Kyle Manning + +This program is free software; you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation; either version 2 of the License, or +(at your option) any later version. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License along +with this program. If not, see +*/ + +#pragma once + +#include +#include +#include + +#include + +#include "rpc/WebSocketSession.h" +#include "types/WebSocketCloseCode.h" +#include "types/WebSocketOpCode.h" +#include "../utils/Json.h" + +class WebSocketProtocol { +public: + using ClientSubscriptionCallback = std::function; + + struct ProcessResult { + WebSocketCloseCode::WebSocketCloseCode closeCode = WebSocketCloseCode::DontClose; + std::string closeReason; + json result; + }; + + WebSocketProtocol(); + + void SetObsReady(bool ready); + bool IsObsReady() const; + void SetClientSubscriptionCallback(ClientSubscriptionCallback cb); + void NotifyClientSubscriptionChange(bool type, uint64_t eventSubscriptions); + QThreadPool *GetThreadPool(); + + void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, json &payloadData); + +private: + void SetSessionParameters(SessionPtr session, ProcessResult &ret, const json &payloadData); + + QThreadPool _threadPool; + std::atomic _obsReady = false; + ClientSubscriptionCallback _clientSubscriptionCallback; +}; diff --git a/src/websocketserver/WebSocketServer.cpp b/src/websocketserver/WebSocketServer.cpp index d2c26742..8692eb59 100644 --- a/src/websocketserver/WebSocketServer.cpp +++ b/src/websocketserver/WebSocketServer.cpp @@ -152,7 +152,7 @@ void WebSocketServer::Stop() } lock.unlock(); - _threadPool.waitForDone(); + _protocol.GetThreadPool()->waitForDone(); // This can delay the thread that it is running on. Bad but kinda required. while (_sessions.size() > 0) @@ -310,8 +310,8 @@ void WebSocketServer::onClose(websocketpp::connection_hdl hdl) lock.unlock(); // If client was identified, announce unsubscription - if (isIdentified && _clientSubscriptionCallback) - _clientSubscriptionCallback(false, eventSubscriptions); + if (isIdentified) + _protocol.NotifyClientSubscriptionChange(false, eventSubscriptions); // Build SessionState object for signal WebSocketSessionState state; @@ -349,7 +349,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, { auto opCode = message->get_opcode(); std::string payload = message->get_payload(); - _threadPool.start(Utils::Compat::CreateFunctionRunnable([=]() { + _protocol.GetThreadPool()->start(Utils::Compat::CreateFunctionRunnable([=]() { std::unique_lock lock(_sessionMutex); SessionPtr session; try { @@ -401,7 +401,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, blog_debug("[WebSocketServer::onMessage] Incoming message (decoded):\n%s", incomingMessage.dump(2).c_str()); - ProcessResult ret; + WebSocketProtocol::ProcessResult ret; // Verify incoming message is an object if (!incomingMessage.is_object()) { @@ -433,7 +433,7 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, goto skipProcessing; } - ProcessMessage(session, ret, incomingMessage["op"], incomingMessage["d"]); + _protocol.ProcessMessage(session, ret, incomingMessage["op"], incomingMessage["d"]); skipProcessing: if (ret.closeCode != WebSocketCloseCode::DontClose) { @@ -462,3 +462,61 @@ void WebSocketServer::onMessage(websocketpp::connection_hdl hdl, } })); } + +// It isn't consistent to directly call the WebSocketServer from the events system, but it would also be dumb to make it unnecessarily complicated. +void WebSocketServer::BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData, + uint8_t rpcVersion) +{ + if (!_server.is_listening() || !_protocol.IsObsReady()) + return; + + _protocol.GetThreadPool()->start(Utils::Compat::CreateFunctionRunnable([=]() { + // Populate message object + json eventMessage; + eventMessage["op"] = 5; + eventMessage["d"]["eventType"] = eventType; + eventMessage["d"]["eventIntent"] = requiredIntent; + if (eventData.is_object()) + eventMessage["d"]["eventData"] = eventData; + + // Initialize objects. The broadcast process only dumps the data when its needed. + std::string messageJson; + std::string messageMsgPack; + + // Recurse connected sessions and send the event to suitable sessions. + std::unique_lock lock(_sessionMutex); + for (auto &it : _sessions) { + if (!it.second->IsIdentified()) + continue; + if (rpcVersion && it.second->RpcVersion() != rpcVersion) + continue; + if ((it.second->EventSubscriptions() & requiredIntent) != 0) { + websocketpp::lib::error_code errorCode; + switch (it.second->Encoding()) { + case WebSocketEncoding::Json: + if (messageJson.empty()) + messageJson = eventMessage.dump(); + _server.send((websocketpp::connection_hdl)it.first, messageJson, + websocketpp::frame::opcode::text, errorCode); + it.second->IncrementOutgoingMessages(); + break; + case WebSocketEncoding::MsgPack: + if (messageMsgPack.empty()) { + auto msgPackData = json::to_msgpack(eventMessage); + messageMsgPack = std::string(msgPackData.begin(), msgPackData.end()); + } + _server.send((websocketpp::connection_hdl)it.first, messageMsgPack, + websocketpp::frame::opcode::binary, errorCode); + it.second->IncrementOutgoingMessages(); + break; + } + if (errorCode) + blog(LOG_ERROR, "[WebSocketServer::BroadcastEvent] Error sending event message: %s", + errorCode.message().c_str()); + } + } + lock.unlock(); + if (IsDebugEnabled() && (EventSubscription::All & requiredIntent) != 0) // Don't log high volume events + blog(LOG_INFO, "[WebSocketServer::BroadcastEvent] Outgoing event:\n%s", eventMessage.dump(2).c_str()); + })); +} diff --git a/src/websocketserver/WebSocketServer.h b/src/websocketserver/WebSocketServer.h index 6bb406b5..1511f110 100644 --- a/src/websocketserver/WebSocketServer.h +++ b/src/websocketserver/WebSocketServer.h @@ -21,12 +21,12 @@ with this program. If not, see #include #include -#include #include #include #include #include +#include "WebSocketProtocol.h" #include "rpc/WebSocketSession.h" #include "types/WebSocketCloseCode.h" #include "types/WebSocketOpCode.h" @@ -57,26 +57,20 @@ class WebSocketServer : QObject { void InvalidateSession(websocketpp::connection_hdl hdl); void BroadcastEvent(uint64_t requiredIntent, const std::string &eventType, const json &eventData = nullptr, uint8_t rpcVersion = 0); - inline void SetObsReady(bool ready) { _obsReady = ready; } + inline void SetObsReady(bool ready) { _protocol.SetObsReady(ready); } inline bool IsListening() { return _server.is_listening(); } std::vector GetWebSocketSessions(); - inline QThreadPool *GetThreadPool() { return &_threadPool; } + inline QThreadPool *GetThreadPool() { return _protocol.GetThreadPool(); } // Callback for when a client subscribes or unsubscribes. `true` for sub, `false` for unsub - typedef std::function ClientSubscriptionCallback; // bool type, uint64_t eventSubscriptions - inline void SetClientSubscriptionCallback(ClientSubscriptionCallback cb) { _clientSubscriptionCallback = cb; } + using ClientSubscriptionCallback = WebSocketProtocol::ClientSubscriptionCallback; // bool type, uint64_t eventSubscriptions + inline void SetClientSubscriptionCallback(ClientSubscriptionCallback cb) { _protocol.SetClientSubscriptionCallback(cb); } signals: void ClientConnected(WebSocketSessionState state); void ClientDisconnected(WebSocketSessionState state, uint16_t closeCode); private: - struct ProcessResult { - WebSocketCloseCode::WebSocketCloseCode closeCode = WebSocketCloseCode::DontClose; - std::string closeReason; - json result; - }; - void ServerRunner(); bool onValidate(websocketpp::connection_hdl hdl); @@ -84,21 +78,13 @@ class WebSocketServer : QObject { void onClose(websocketpp::connection_hdl hdl); void onMessage(websocketpp::connection_hdl hdl, websocketpp::server::message_ptr message); - static void SetSessionParameters(SessionPtr session, WebSocketServer::ProcessResult &ret, const json &payloadData); - void ProcessMessage(SessionPtr session, ProcessResult &ret, WebSocketOpCode::WebSocketOpCode opCode, json &payloadData); - - QThreadPool _threadPool; - std::thread _serverThread; websocketpp::server _server; + WebSocketProtocol _protocol; std::string _authenticationSecret; std::string _authenticationSalt; std::mutex _sessionMutex; std::map> _sessions; - - std::atomic _obsReady = false; - - ClientSubscriptionCallback _clientSubscriptionCallback; };