From 67f019f3385510829edf80512ca6fcd01759f531 Mon Sep 17 00:00:00 2001 From: Boris Tyshkevich <68195949+bvt123@users.noreply.github.com> Date: Wed, 1 Apr 2026 09:18:57 +0200 Subject: [PATCH 1/4] Add OAuth2 login to clickhouse-client via credentials JSON file MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds two new flags to `clickhouse-client`: - `--login`: Authorization Code + PKCE flow (RFC 7636), opens browser - `--login-device`: Device Authorization flow (RFC 8628), prints URL+code Both flows obtain an ID token from any OpenID Connect provider and then proceed exactly as if `--jwt ` had been passed. No changes to the connection layer. The user places a Google-format credentials file at `~/.clickhouse-client/oauth_client.json` (override with `--oauth-credentials PATH`). Refresh tokens are cached in `~/.clickhouse-client/oauth_cache.json` (mode 0600, written atomically) so that subsequent runs proceed silently without re-authenticating. `OAuthLogin.cpp` is moved to `src/Client/` (compiled into `clickhouse_client`) so it is unit-testable via `unit_tests_dbms`. Tests cover `loadOAuthCredentials` for valid/invalid JSON, missing required fields, unknown top-level keys, and file-not-found error paths. Backward compatibility: when `--login` is passed without a credentials file and without `--login-device`, the existing cloud-specific `login()` path is used as before. Co-Authored-By: Claude Sonnet 4.6 Fix Google JWT exp claim type in TokenProcessorsOpaque Google ID tokens return the `exp` claim as a JSON number (floating-point `double`), not an integer. Calling `getValueByKey()` instantiated `picojson::value::is()` / `picojson::value::get()`, which are not provided by the `picojson` library and caused a linker exception on arm64 (where `time_t` is `long`). Fix: cast through `double` first. Co-Authored-By: Claude Sonnet 4.6 Fix review blockers in OAuth2 login implementation Blocker 1 — cloud auto-login hijack: Remove the filesystem::exists(oauth_client.json) auto-detection from the `--login` processing path. The credentials-file flow is now only entered when `--login-device` or `--oauth-credentials` is explicitly passed. `--login` alone (including the auto-added case for *.clickhouse.cloud) always falls through to the existing cloud login() path. This prevents breaking cloud auth for users who happen to have a Google credentials file at the default path. Restore `--login` as `po::bool_switch()` (was changed to a presence flag) so that existing scripts using `--login=true` continue to work. XSS in callback HTML: The OAuth2 auth-code callback handler reflected the error= query-string parameter directly into the HTML response body. Add a minimal htmlEscape helper and use it, preventing a potential XSS via a crafted redirect URI. Remove unused SUPPORT_IS_DISABLED declaration: That error code is used in Client.cpp (where it belongs), not in OAuthLogin.cpp. Remove the dead declaration from the anonymous ErrorCodes block. Co-Authored-By: Claude Sonnet 4.6 Robustness fixes, issuer discovery, state CSRF, and docs postForm — RFC 6749 error response handling: Removed the early HTTP-status check that threw before JSON parsing. RFC 6749 §5.2 returns errors (authorization_pending, slow_down, etc.) as HTTP 400 with a JSON body. postForm now always attempts JSON parsing and throws only if the body is non-JSON, which fixes the device-flow polling bug where authorization_pending caused an exception instead of being handled by the caller's retry loop. discoverDeviceEndpoint — sub-path issuer support: The generic heuristic now strips only the last path segment of token_uri instead of the entire path, preserving realm prefixes like Keycloak's /realms/. Accepts an explicit issuer_hint parameter populated from the new "issuer" field in the credentials JSON, bypassing the heuristic entirely for providers that need it. OAuthCredentials — optional issuer field: Add std::string issuer to the struct and load it from the credentials JSON. Providers that use non-root issuers (Keycloak, Okta custom domains) can now specify issuer directly instead of relying on heuristic path-stripping. Auth code flow — CSRF state parameter (RFC 6749 §10.12): Generate a 16-byte random hex state, include it in the authorization URL, echo it back via the callback handler, and verify it matches before proceeding to code exchange. Logging for silently swallowed exceptions: readCachedRefreshToken prints a notice when the cache file is unparseable. tryRefreshToken now prints the rejection reason (expired, revoked, or network error) before falling through to interactive flow. Tests: add issuer field, PKCE building-block, base64url property tests. Docs: fix --login prefix, add --login-device and --oauth-credentials entries with credentials file format and cache location. Co-Authored-By: Claude Sonnet 4.6 Merge --login-device into --login=device Replace the separate --login-device flag with a mode parameter on --login: --login auth-code + PKCE flow (browser, default) --login=device Device Authorization flow (prints URL + code) Uses po::value()->implicit_value("browser") so that bare --login keeps its existing meaning. An invalid mode value throws BAD_ARGUMENTS immediately. The --login=... prefix is now recognised in readArguments() as an auth credential indicator. Co-Authored-By: Claude Sonnet 4.6 Fix --login=device crash and --login=browser routing Two bugs from testing: 1. --login device crash (Cannot convert to boolean: device): argsToConfig() runs after processOptions() and overwrites config["login"] with the raw string value from the command line. The subsequent config().getBool("login", false) in main() then threw because "device" is not a valid boolean. Fix: the cloud-login signal is now stored under the separate key "cloud_oauth_pending" which argsToConfig never touches. All three references to config["login"] (set-true, set-false, get) are updated. 2. --login=browser (or --login browser) fell through to the cloud path: Any explicitly specified mode should use the credentials-file OIDC flow, since the user is clearly opting into it. Only bare --login (implicit empty value) falls back to the ClickHouse Cloud auto-login path. Fix: implicit_value changed from "browser" to ""; use_credentials_file now triggers on !login_mode.empty() instead of mode == "device" only. Routing table after this change: --login → ClickHouse Cloud path (backward compat) --login=browser → OIDC auth-code + PKCE (credentials file) --login=device → OIDC device flow (credentials file) --login= → BAD_ARGUMENTS Co-Authored-By: Claude Sonnet 4.6 Fix device flow crash on Google verification_url field name Google returns verification_url (legacy) instead of verification_uri (RFC 8628). getValue on a missing key threw 'Can not convert empty value'. Fallback order: verification_uri_complete -> verification_uri -> verification_url. Also guard device_code/user_code with explicit has() checks before access. Co-Authored-By: Claude Sonnet 4.6 Document --login/--oauth-credentials and add OAuth option validation tests - Fix `--login[=]` doc entry: bare `--login` = Cloud auto-login (no default mode implied); `--login=browser` / `--login=device` trigger the OIDC credentials-file flow - Update `--oauth-credentials` description and link to new section - Add `### OAuth credentials file` section: JSON format example, required and optional fields table, default path, cache location - Add Tests 9–11 to `03749_cloud_endpoint_auth_precedence.sh`: - `--login=device --oauth-credentials /nonexistent` → file-not-found error - `--login=invalid` → `BAD_ARGUMENTS` ("must be 'browser' or 'device'") - `--jwt tok --login=browser` → `BAD_ARGUMENTS` ("cannot both be specified") - Update `.reference` with expected output for new tests Co-Authored-By: Claude Sonnet 4.6 Fix security and reliability issues in OAuth2 login Must-fix: - Session token refresh: add OAuthJWTProvider extending JWTProvider so Connection::sendQuery can call getJWT() transparently; eliminates the 1-hour session limit when id_token is obtained only once at startup - Bind callback server to 127.0.0.1 explicitly (not 0.0.0.0) to prevent network-adjacent attackers from racing to deliver a forged callback - Add offline_access scope to both auth-code and device flows so that standard OIDC providers (non-Google) issue refresh tokens Should-fix: - Guard postForm against non-object JSON: separate JSON parse error from Poco extract BadCastException with a clearer message - Warn on plain http:// token/auth endpoints in loadOAuthCredentials Also fix include order: OAuthLogin.h now includes first so the USE_JWT_CPP && USE_SSL guard evaluates correctly regardless of the order headers are included by callers. Co-Authored-By: Claude Sonnet 4.6 Fix Google OAuth scope: use access_type=offline instead of offline_access Google rejects offline_access as an invalid scope (Error 400: invalid_scope). It uses the access_type=offline query parameter in the authorization URL instead. Standard OIDC providers still require offline_access in the scope. Add isGoogleProvider helper (detected via token_uri host) and apply the correct mechanism for each: Google gets access_type=offline appended to the authorization URL with openid email profile scope; all other providers get offline_access in the scope. Co-Authored-By: Claude Sonnet 4.6 Split OAuthJWTProvider into its own file; fix silent cache catch - Move OAuthJWTProvider class and createOAuthJWTProvider factory into OAuthJWTProvider.cpp. The class delegates entirely to obtainIDToken (the public API) so it has no anonymous-namespace dependencies and can live independently. OAuthLogin.cpp loses ~60 lines and its JWTProvider.h include. - Fix silent catch(...) in writeCachedRefreshToken: log a warning when the existing cache file cannot be parsed (mirrors the read-side warning in readCachedRefreshToken). Co-Authored-By: Claude Sonnet 4.6 Add OAuth2 integration tests and fix remaining review items **Layer 2 — Keycloak integration tests** New test suite `tests/integration/test_keycloak_auth/` exercises Keycloak 26.0 as a real OIDC provider: - `test_jwt_dynamic_jwks` — token validated via explicit JWKS URI - `test_openid_discovery` — token validated via OIDC discovery document - `test_username_claim` — `preferred_username` claim maps to ClickHouse user - `test_token_refresh` — `refresh_token` grant produces a valid `id_token` - `test_wrong_issuer_rejected` — tampered `iss` claim is rejected - `test_expired_token_rejected` — expired `exp` claim is rejected - `test_device_flow_initiation` — device auth endpoint returns correct fields - `test_device_flow_round_trip` — full RFC 8628 round-trip: initiate → simulate browser approval via HTML form sequence → poll token endpoint → authenticate ClickHouse `SELECT 1` with the resulting `id_token` Infrastructure additions: - `tests/integration/compose/docker_compose_keycloak.yml` — Keycloak 26.0 service with `--import-realm` and a health-check on the OIDC discovery URL - `tests/integration/helpers/cluster.py` — `with_keycloak` flag following the existing `with_ldap` pattern: `setup_keycloak_cmd`, `wait_keycloak_to_start`, `get_keycloak_url`, docker-compose wiring, and `depends_on` entry in instance compose generation - `tests/integration/test_keycloak_auth/keycloak/realm-export.json` — pre-configured realm with client `clickhouse` (direct grants + device auth enabled), user `alice`, group `analysts`, and a `groups` claim mapper - `tests/integration/test_keycloak_auth/configs/validators.xml` — `jwt_dynamic_jwks` and `openid` token processors pointing at Keycloak - `tests/integration/test_keycloak_auth/configs/users.xml` — default user with `access_management` enabled **Fix: `--oauth-credentials` without `--login` silently ignored** Add an explicit guard before the `if (options.count("login"))` block in `programs/client/Client.cpp` that throws `BAD_ARGUMENTS` when `--oauth-credentials` is supplied without `--login=browser` or `--login=device`. **Fix: dead CMake and duplicate jwt-cpp linkage in `src/CMakeLists.txt`** Remove the dead `add_object_library(clickhouse_client_jwt Client/jwt)` block (the `src/Client/jwt/` directory does not exist) and the duplicate `target_link_libraries(clickhouse_common_io PUBLIC ch_contrib::jwt-cpp)` at the Redis block — the identical linkage is already present earlier in the file. Co-Authored-By: Claude Sonnet 4.6 Fix blockers in Keycloak integration tests - `query_with_token`: replace `node.query(..., headers=...)` (unsupported parameter) with `node.http_request("", method="POST", data=query, headers=...)`, which uses the HTTP interface and accepts custom headers. - `get_keycloak_url`: return `http://localhost:{port}` instead of `http://keycloak:{port}`. The `keycloak` hostname is only resolvable inside the Docker network; pytest runs on the host and connects via the mapped port. `wait_keycloak_to_start` already used `localhost` correctly. - `configs/users.xml`: add user `alice` with `` authentication so that `test_username_claim` (and any test that maps `preferred_username` to a ClickHouse session user) can find the user in the user store. - Remove unused `import copy` from `test_wrong_issuer_rejected` and `test_expired_token_rejected`. Co-Authored-By: Claude Sonnet 4.6 Fix OAuth HTTPS sessions to respect configured SSL context postForm and discoverDeviceEndpoint in OAuthLogin.cpp constructed Poco::Net::HTTPSClientSession with the bare two-argument constructor, which picks up Poco's built-in default SSL context (system CA bundle, full verification) and ignores any openSSL.client.* configuration the user has set, including custom CA certificates (caConfig) and verificationMode. In a corporate environment with a TLS inspection proxy, the proxy CA is trusted only through the client configured CA bundle. With the old code, every OAuth HTTPS request — OIDC discovery, token endpoint, postForm — would fail with a certificate verification error, while the upstream Cloud login path (which uses SSLManager::instance().defaultClientContext()) worked fine. Fix: pass Poco::Net::SSLManager::instance().defaultClientContext() to both HTTPSClientSession constructors, matching the pattern already used in JWTProvider::createHTTPSession. This makes OAuth traffic respect openSSL.client.caConfig, verificationMode, and every other SSL setting the user has configured. Co-Authored-By: Claude Sonnet 4.6 --- docs/en/interfaces/cli.md | 30 +- programs/client/Client.cpp | 82 +- src/Access/TokenProcessorsOpaque.cpp | 2 +- src/CMakeLists.txt | 7 - src/Client/OAuthJWTProvider.cpp | 64 ++ src/Client/OAuthLogin.cpp | 843 ++++++++++++++++++ src/Client/OAuthLogin.h | 44 + src/Client/tests/gtest_oauth_login.cpp | 275 ++++++ .../compose/docker_compose_keycloak.yml | 21 + tests/integration/helpers/cluster.py | 62 ++ .../test_keycloak_auth/__init__.py | 0 .../test_keycloak_auth/configs/users.xml | 13 + .../test_keycloak_auth/configs/validators.xml | 20 + .../keycloak/realm-export.json | 72 ++ tests/integration/test_keycloak_auth/test.py | 374 ++++++++ ...9_cloud_endpoint_auth_precedence.reference | 6 + .../03749_cloud_endpoint_auth_precedence.sh | 29 + 17 files changed, 1925 insertions(+), 19 deletions(-) create mode 100644 src/Client/OAuthJWTProvider.cpp create mode 100644 src/Client/OAuthLogin.cpp create mode 100644 src/Client/OAuthLogin.h create mode 100644 src/Client/tests/gtest_oauth_login.cpp create mode 100644 tests/integration/compose/docker_compose_keycloak.yml create mode 100644 tests/integration/test_keycloak_auth/__init__.py create mode 100644 tests/integration/test_keycloak_auth/configs/users.xml create mode 100644 tests/integration/test_keycloak_auth/configs/validators.xml create mode 100644 tests/integration/test_keycloak_auth/keycloak/realm-export.json create mode 100644 tests/integration/test_keycloak_auth/test.py diff --git a/docs/en/interfaces/cli.md b/docs/en/interfaces/cli.md index a25eb6d8fa23..3373bad3270e 100644 --- a/docs/en/interfaces/cli.md +++ b/docs/en/interfaces/cli.md @@ -836,7 +836,8 @@ All command-line options can be specified directly on the command line or as def | `-d [ --database ] ` | Select the database to default to for this connection. | The current database from the server settings (`default` by default) | | `-h [ --host ] ` | The hostname of the ClickHouse server to connect to. Can either be a hostname or an IPv4 or IPv6 address. Multiple hosts can be passed via multiple arguments. | `localhost` | | `--jwt ` | Use JSON Web Token (JWT) for authentication.

Server JWT authorization is only available in ClickHouse Cloud. | - | -| `login` | Invokes the device grant OAuth flow in order to authenticate via an IDP.

For ClickHouse Cloud hosts, the OAuth variables are inferred otherwise they must be provided with `--oauth-url`, `--oauth-client-id` and `--oauth-audience`. | - | +| `--login[=]` | Authenticate via OAuth2. Bare `--login` (no `=`) triggers ClickHouse Cloud automatic login — the provider is inferred from the server. To authenticate against a custom OpenID Connect provider, supply a `mode` and `--oauth-credentials`: `--login=browser` runs the Authorization Code + PKCE flow (opens a browser), `--login=device` runs the Device Authorization flow (prints a URL and short code — no browser needed). | - | +| `--oauth-credentials ` | Path to an OAuth2 credentials JSON file (Google Cloud Console format). Required when using `--login=browser` or `--login=device` with a custom OpenID Connect provider. See [OAuth credentials file format](#oauth-credentials-file) below. Refresh tokens are cached in `~/.clickhouse-client/oauth_cache.json` (mode `0600`). | `~/.clickhouse-client/oauth_client.json` | | `--no-warnings` | Disable showing warnings from `system.warnings` when the client connects to the server. | - | | `--no-server-client-version-message` | Suppress server-client version mismatch message when the client connects to the server. | - | | `--password ` | The password of the database user. You can also specify the password for a connection in the configuration file. If you do not specify the password, the client will ask for it. | - | @@ -851,6 +852,33 @@ All command-line options can be specified directly on the command line or as def Instead of the `--host`, `--port`, `--user` and `--password` options, the client also supports [connection strings](#connection_string). ::: +### OAuth credentials file {#oauth-credentials-file} + +When using `--login=browser` or `--login=device` with a custom OpenID Connect provider, the client reads a credentials JSON file. The file uses the same format produced by the Google Cloud Console ("OAuth 2.0 Client IDs" → "Download JSON"): + +```json +{ + "installed": { + "client_id": "YOUR_CLIENT_ID", + "client_secret": "YOUR_CLIENT_SECRET", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "redirect_uris": ["http://localhost"] + } +} +``` + +The top-level key can be `installed` (desktop/CLI apps) or `web`. Required fields: `client_id`, `client_secret`, `auth_uri`, `token_uri`. Optional fields: + +| Field | Description | +|---|---| +| `device_authorization_uri` | Device authorization endpoint. Discovered automatically via OIDC Discovery if absent. | +| `issuer` | OIDC issuer URL (e.g. `https://accounts.google.com`). Used to locate the discovery document when `device_authorization_uri` is not set. | + +The default path is `~/.clickhouse-client/oauth_client.json`. Override it with `--oauth-credentials `. + +After a successful login the obtained refresh token is cached in `~/.clickhouse-client/oauth_cache.json` (file mode `0600`). Subsequent runs reuse the cached token silently and only open the browser or print a device code when the refresh token has expired. + ### Query options {#command-line-options-query} | Option | Description | diff --git a/programs/client/Client.cpp b/programs/client/Client.cpp index 3ddc27b4fb06..e5985e9c98c2 100644 --- a/programs/client/Client.cpp +++ b/programs/client/Client.cpp @@ -26,6 +26,7 @@ #include #include +#include #include #include @@ -65,6 +66,7 @@ namespace ErrorCodes extern const int NETWORK_ERROR; extern const int AUTHENTICATION_FAILED; extern const int REQUIRED_PASSWORD; + extern const int SUPPORT_IS_DISABLED; extern const int USER_EXPIRED; } @@ -281,7 +283,7 @@ void Client::initialize(Poco::Util::Application & self) (loaded_config.configuration->has("user") || loaded_config.configuration->has("password"))) { /// Config file has auth credentials, so disable the auto-added login flag - config().setBool("login", false); + config().setBool("cloud_oauth_pending", false); } #endif } @@ -371,7 +373,7 @@ try } #if USE_JWT_CPP && USE_SSL - if (config().getBool("login", false)) + if (config().getBool("cloud_oauth_pending", false) && !config().has("jwt")) { login(); } @@ -727,8 +729,15 @@ void Client::addExtraOptions(OptionsDescription & options_description) ("ssh-key-passphrase", po::value(), "Passphrase for the SSH private key specified by --ssh-key-file.") ("quota_key", po::value(), "A string to differentiate quotas when the user have keyed quotas configured on server") ("jwt", po::value(), "Use JWT for authentication") + ("login", po::value()->implicit_value(""), + "Authenticate via OAuth2. Optional mode: 'browser' (auth-code + PKCE, opens browser) " + "or 'device' (device flow, prints URL + code). " + "Example: --login=browser or --login=device. " + "Bare --login uses the ClickHouse Cloud auto-login path.") + ("oauth-credentials", po::value(), + "Path to OAuth credentials JSON file " + "(default: ~/.clickhouse-client/oauth_client.json)") #if USE_JWT_CPP && USE_SSL - ("login", po::bool_switch(), "Use OAuth 2.0 to login") ("oauth-url", po::value(), "The base URL for the OAuth 2.0 authorization server") ("oauth-client-id", po::value(), "The client ID for the OAuth 2.0 application") ("oauth-audience", po::value(), "The audience for the OAuth 2.0 token") @@ -884,6 +893,8 @@ void Client::processOptions( config().setBool("no-server-client-version-message", true); if (options.contains("fake-drop")) config().setString("ignore_drop_queries_probability", "1"); + if (options.count("jwt") && options.count("login")) + throw Exception(ErrorCodes::BAD_ARGUMENTS, "--jwt and --login cannot both be specified"); if (options.contains("jwt")) { if (!options["user"].defaulted()) @@ -891,16 +902,66 @@ void Client::processOptions( config().setString("jwt", options["jwt"].as()); config().setString("user", ""); } -#if USE_JWT_CPP && USE_SSL - if (options["login"].as()) + if (options.count("oauth-credentials") && !options.count("login")) + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "--oauth-credentials requires --login=browser or --login=device"); + + if (options.count("login")) { + const std::string login_mode = options["login"].as(); + if (!login_mode.empty() && login_mode != "browser" && login_mode != "device") + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "--login value must be 'browser' or 'device', got '{}'", + login_mode); + +#if USE_JWT_CPP && USE_SSL if (!options["user"].defaulted()) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "User and login flags can't be specified together"); - if (config().has("jwt")) - throw Exception(ErrorCodes::BAD_ARGUMENTS, "JWT and login flags can't be specified together"); - config().setBool("login", true); - config().setString("user", ""); + throw Exception(ErrorCodes::BAD_ARGUMENTS, "--user and --login cannot both be specified"); + + // Bare --login (empty mode, including auto-added for *.clickhouse.cloud) → cloud path. + // Explicit --login=browser or --login=device (or --oauth-credentials) → credentials-file + // OIDC path. This prevents the credentials file from hijacking the cloud auto-login. + const bool use_credentials_file + = !login_mode.empty() + || options.count("oauth-credentials"); + + if (use_credentials_file) + { + const char * home_path_cstr = getenv("HOME"); // NOLINT(concurrency-mt-unsafe) + const std::string default_creds_path = home_path_cstr + ? std::string(home_path_cstr) + "/.clickhouse-client/oauth_client.json" + : ""; + + const std::string creds_path = options.count("oauth-credentials") + ? options["oauth-credentials"].as() + : default_creds_path; + + auto creds = loadOAuthCredentials(creds_path); + const auto mode = (login_mode == "device") ? OAuthFlowMode::Device : OAuthFlowMode::AuthCode; + + // createOAuthJWTProvider runs the initial flow (trying the cached + // refresh token first) and returns a provider that Connection can + // call to refresh the id_token transparently during long sessions. + jwt_provider = createOAuthJWTProvider(creds, mode); + config().setString("jwt", jwt_provider->getJWT()); + config().setString("user", ""); + } + else + { + // Cloud-specific login path — bare --login, including auto-added for + // *.clickhouse.cloud endpoints. Use a separate config key so that + // argsToConfig() overwriting config["login"] with the raw string value + // cannot cause getBool("login") to throw in main(). + config().setBool("cloud_oauth_pending", true); + config().setString("user", ""); + } +#else + throw Exception(ErrorCodes::SUPPORT_IS_DISABLED, "OAuth login requires a build with JWT and SSL support"); +#endif } +#if USE_JWT_CPP && USE_SSL if (options.contains("oauth-url")) config().setString("oauth-url", options["oauth-url"].as()); if (options.contains("oauth-client-id")) @@ -1074,6 +1135,7 @@ void Client::readArguments( std::string_view arg(argv[i]); if (arg.starts_with("--user") || arg.starts_with("--password") || arg.starts_with("--jwt") || arg.starts_with("--ssh-key-file") || + arg == "--login" || arg.starts_with("--login=") || arg == "-u") { has_auth_in_cmdline = true; diff --git a/src/Access/TokenProcessorsOpaque.cpp b/src/Access/TokenProcessorsOpaque.cpp index b6f50f677564..649990ad7983 100644 --- a/src/Access/TokenProcessorsOpaque.cpp +++ b/src/Access/TokenProcessorsOpaque.cpp @@ -114,7 +114,7 @@ bool GoogleTokenProcessor::resolveAndValidate(TokenCredentials & credentials) co auto token_info = getObjectFromURI(Poco::URI("https://www.googleapis.com/oauth2/v3/tokeninfo"), token); if (token_info.contains("exp")) - credentials.setExpiresAt(std::chrono::system_clock::from_time_t((getValueByKey(token_info, "exp").value()))); + credentials.setExpiresAt(std::chrono::system_clock::from_time_t(static_cast(getValueByKey(token_info, "exp").value()))); /// Groups info can only be retrieved if user email is known. /// If no email found in user info, we skip this step and there are no external roles for the user. diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5192e4ec2136..a8b87128cf5f 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -484,10 +484,6 @@ target_link_libraries( Poco::Redis ) -if (TARGET ch_contrib::jwt-cpp) - target_link_libraries(clickhouse_common_io PUBLIC ch_contrib::jwt-cpp) -endif() - if (TARGET ch_contrib::mongocxx) target_link_libraries( dbms @@ -775,6 +771,3 @@ if (ENABLE_TESTS) endif() endif () -if (TARGET ch_contrib::jwt-cpp) - add_object_library(clickhouse_client_jwt Client/jwt) -endif() diff --git a/src/Client/OAuthJWTProvider.cpp b/src/Client/OAuthJWTProvider.cpp new file mode 100644 index 000000000000..1daeca08681a --- /dev/null +++ b/src/Client/OAuthJWTProvider.cpp @@ -0,0 +1,64 @@ +#include + +#if USE_JWT_CPP && USE_SSL + +#include +#include + +#include +#include + +#include +#include + +namespace DB +{ + +/// JWTProvider subclass for the credentials-file OIDC path (--login=browser / +/// --login=device). Extends JWTProvider so that Connection::sendQuery can call +/// getJWT() transparently to refresh the id_token before it expires, eliminating +/// the 1-hour session limit that arises when the token is obtained only once at +/// startup. +/// +/// getJWT() delegates to obtainIDToken() which already handles the full lifecycle: +/// 1. try cached refresh token from disk +/// 2. run interactive flow (browser or device) if the refresh token is absent +/// or rejected +class OAuthJWTProvider : public JWTProvider +{ +public: + OAuthJWTProvider(OAuthCredentials creds, OAuthFlowMode mode) + : JWTProvider("", creds.client_id, "", std::cerr, std::cerr) + , creds_(std::move(creds)) + , mode_(mode) + {} + + std::string getJWT() override + { + constexpr int EXPIRY_BUFFER_SECONDS = 30; + + if (!idp_access_token.empty() + && Poco::Timestamp() < idp_access_token_expires_at - Poco::Timespan(EXPIRY_BUFFER_SECONDS, 0)) + return idp_access_token; + + // obtainIDToken tries the disk-cached refresh token first and falls back + // to an interactive flow only when necessary. + idp_access_token = obtainIDToken(creds_, mode_); + idp_access_token_expires_at = getJwtExpiry(idp_access_token); + return idp_access_token; + } + +private: + OAuthCredentials creds_; + OAuthFlowMode mode_; +}; + +std::shared_ptr createOAuthJWTProvider( + const OAuthCredentials & creds, OAuthFlowMode mode) +{ + return std::make_shared(creds, mode); +} + +} // namespace DB + +#endif // USE_JWT_CPP && USE_SSL diff --git a/src/Client/OAuthLogin.cpp b/src/Client/OAuthLogin.cpp new file mode 100644 index 000000000000..80610347d84c --- /dev/null +++ b/src/Client/OAuthLogin.cpp @@ -0,0 +1,843 @@ +#include +#include + +#if USE_JWT_CPP && USE_SSL + +# include +# include +# include + +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include +# include + +# include + +# include +# include +# include +# include +# include +# include +# include + +# if defined(__APPLE__) || defined(__linux__) +# include +# include +# endif + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int BAD_ARGUMENTS; +extern const int AUTHENTICATION_FAILED; +} + +namespace +{ + +// HTTP request timeout for all OAuth endpoint calls. +constexpr int HTTP_TIMEOUT_SECONDS = 30; + +/// Minimal HTML escaping to prevent XSS when reflecting user-supplied strings +/// (e.g. the error= query parameter from the OAuth callback) into HTML. +std::string htmlEscape(const std::string & s) +{ + std::string out; + out.reserve(s.size()); + for (char c : s) + { + switch (c) + { + case '&': out += "&"; break; + case '<': out += "<"; break; + case '>': out += ">"; break; + case '"': out += """; break; + case '\'': out += "'"; break; + default: out += c; break; + } + } + return out; +} + +// --------------------------------------------------------------------------- +// 2. discoverDeviceEndpoint +// --------------------------------------------------------------------------- + +/// Fetch the OIDC discovery document and return device_authorization_endpoint. +/// +/// issuer_hint: explicit OIDC issuer URL (e.g. from credentials JSON "issuer" field). +/// When non-empty it is used directly: discovery is at {issuer_hint}/.well-known/openid-configuration. +/// When empty, issuer is derived heuristically from token_uri: +/// - Google (oauth2.googleapis.com) → https://accounts.google.com (hardcoded mapping) +/// - Generic: strip last path segment, preserving realm prefixes +/// e.g. https://auth.example.com/realms/myrealm/protocol/openid-connect/token +/// → https://auth.example.com/realms/myrealm +/// For providers whose issuer cannot be reliably derived, set "issuer" or +/// "device_authorization_uri" in the credentials JSON to bypass discovery. +std::string discoverDeviceEndpoint(const std::string & token_uri, const std::string & issuer_hint) +{ + std::string issuer; + if (!issuer_hint.empty()) + { + issuer = issuer_hint; + } + else + { + Poco::URI uri(token_uri); + if (uri.getHost() == "oauth2.googleapis.com") + { + // Google uses a separate domain for its OIDC discovery. + issuer = "https://accounts.google.com"; + } + else + { + // Build scheme://host[:port] prefix. + issuer = uri.getScheme() + "://" + uri.getHost(); + if (uri.getPort() != 0 + && !((uri.getScheme() == "https" && uri.getPort() == 443) + || (uri.getScheme() == "http" && uri.getPort() == 80))) + issuer += ":" + std::to_string(uri.getPort()); + + // Append the path minus its last segment so that issuers with + // sub-paths (e.g. Keycloak's /realms/) are preserved. + std::string path = uri.getPath(); + const auto last_slash = path.rfind('/'); + if (last_slash != std::string::npos && last_slash != 0) + issuer += path.substr(0, last_slash); + } + } + + const std::string discovery_url = issuer + "/.well-known/openid-configuration"; + Poco::URI disc_uri(discovery_url); + + Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, disc_uri.getPathAndQuery()); + Poco::Net::HTTPResponse response; + std::string body; + + if (disc_uri.getScheme() == "https") + { + Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); + Poco::Net::HTTPSClientSession session(disc_uri.getHost(), disc_uri.getPort(), ctx); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + session.sendRequest(request); + auto & stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(stream, body); + } + else + { + Poco::Net::HTTPClientSession session(disc_uri.getHost(), disc_uri.getPort()); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + session.sendRequest(request); + auto & stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(stream, body); + } + + if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_OK) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OIDC discovery failed for '{}': {} {}", + discovery_url, + static_cast(response.getStatus()), + response.getReason()); + + Poco::JSON::Parser parser; + auto result = parser.parse(body); + auto obj = result.extract(); + + if (!obj->has("device_authorization_endpoint")) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OIDC discovery document at '{}' does not contain device_authorization_endpoint", + discovery_url); + + return obj->getValue("device_authorization_endpoint"); +} + +// --------------------------------------------------------------------------- +// 3. generatePKCE +// --------------------------------------------------------------------------- + +struct PKCEPair +{ + std::string verifier; + std::string challenge; +}; + +PKCEPair generatePKCE() +{ + // 32 random bytes → base64url (43 chars, no padding) + unsigned char raw[32]; + if (RAND_bytes(raw, sizeof(raw)) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for PKCE verifier"); + + std::string verifier = base64Encode( + std::string(reinterpret_cast(raw), sizeof(raw)), + /*url_encoding=*/true, + /*no_padding=*/true); + + // challenge = BASE64URL(SHA256(verifier)) + std::string sha = encodeSHA256(verifier); + std::string challenge = base64Encode(sha, /*url_encoding=*/true, /*no_padding=*/true); + + return {verifier, challenge}; +} + +// --------------------------------------------------------------------------- +// 4. urlEncode +// --------------------------------------------------------------------------- + +std::string urlEncode(const std::string & s) +{ + std::string result; + Poco::URI::encode(s, "", result); + return result; +} + +/// Google uses access_type=offline instead of the offline_access scope. +/// Detect by checking the token endpoint host. +bool isGoogleProvider(const OAuthCredentials & creds) +{ + Poco::URI uri(creds.token_uri); + const std::string & host = uri.getHost(); + return host == "oauth2.googleapis.com" || host == "accounts.google.com"; +} + +// --------------------------------------------------------------------------- +// 5. postForm — HTTPS/HTTP POST application/x-www-form-urlencoded +// +// Always attempts to parse the response body as JSON, regardless of the HTTP +// status code. RFC 6749 returns error responses (e.g. authorization_pending +// during device-flow polling) as HTTP 400 with a JSON body — callers must +// inspect the "error" field in the returned object. +// +// Throws only when the body cannot be parsed as JSON: +// - 4xx/5xx with non-JSON body → AUTHENTICATION_FAILED with HTTP status +// - 2xx with non-JSON body → AUTHENTICATION_FAILED (unexpected format) +// --------------------------------------------------------------------------- + +Poco::JSON::Object::Ptr postForm(const std::string & url, const std::string & body) +{ + Poco::URI uri(url); + Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_POST, uri.getPathAndQuery()); + request.setContentType("application/x-www-form-urlencoded"); + request.setContentLength(static_cast(body.size())); + + Poco::Net::HTTPResponse response; + std::string response_body; + + if (uri.getScheme() == "https") + { + Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); + Poco::Net::HTTPSClientSession session(uri.getHost(), uri.getPort(), ctx); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + auto & req_stream = session.sendRequest(request); + req_stream << body; + auto & resp_stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(resp_stream, response_body); + } + else + { + Poco::Net::HTTPClientSession session(uri.getHost(), uri.getPort()); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + auto & req_stream = session.sendRequest(request); + req_stream << body; + auto & resp_stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(resp_stream, response_body); + } + + // Try JSON parse regardless of status — RFC 6749 §5.2 returns errors + // in JSON bodies even on HTTP 400 (e.g., authorization_pending, slow_down). + Poco::Dynamic::Var parsed; + try + { + Poco::JSON::Parser parser; + parsed = parser.parse(response_body); + } + catch (...) + { + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OAuth2 endpoint '{}' returned HTTP {} with non-JSON body: {}", + url, + static_cast(response.getStatus()), + response_body.substr(0, 512)); + } + + auto obj = parsed.extract(); + if (!obj) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OAuth2 endpoint '{}' returned HTTP {} with non-object JSON response: {}", + url, + static_cast(response.getStatus()), + response_body.substr(0, 512)); + return obj; +} + +// --------------------------------------------------------------------------- +// 6. Token cache +// --------------------------------------------------------------------------- + +std::string cacheKey(const std::string & client_id) +{ + // First 16 hex chars of SHA256(client_id) + std::string hash = encodeSHA256(client_id); + std::string hex; + hex.reserve(32); + for (unsigned char c : hash) + { + constexpr char digits[] = "0123456789abcdef"; + hex += digits[(c >> 4) & 0xF]; + hex += digits[c & 0xF]; + } + return hex.substr(0, 16); +} + +std::string cacheFilePath() +{ + const char * home = std::getenv("HOME"); // NOLINT(concurrency-mt-unsafe) + if (!home) + return ""; + return std::string(home) + "/.clickhouse-client/oauth_cache.json"; +} + +std::string readCachedRefreshToken(const std::string & client_id) +{ + const std::string path = cacheFilePath(); + if (path.empty()) + return ""; + + std::ifstream f(path); + if (!f.is_open()) + return ""; + + std::string content((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + try + { + Poco::JSON::Parser parser; + auto result = parser.parse(content); + auto obj = result.extract(); + const std::string key = cacheKey(client_id); + if (obj->has(key)) + return obj->getValue(key); + } + catch (...) + { + std::cerr << "Note: OAuth token cache at '" << cacheFilePath() + << "' could not be parsed and will be ignored.\n"; + } + return ""; +} + +void writeCachedRefreshToken(const std::string & client_id, const std::string & refresh_token) +{ + const std::string path = cacheFilePath(); + if (path.empty()) + return; + + namespace fs = std::filesystem; + const fs::path cache_path(path); + + // Ensure directory exists + fs::create_directories(cache_path.parent_path()); + + // Read existing cache + Poco::JSON::Object obj; + { + std::ifstream f(path); + if (f.is_open()) + { + std::string content((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + try + { + Poco::JSON::Parser parser; + auto result = parser.parse(content); + auto existing = result.extract(); + for (auto it = existing->begin(); it != existing->end(); ++it) + obj.set(it->first, it->second); + } + catch (...) + { + std::cerr << "Note: OAuth token cache at '" << path + << "' could not be parsed; existing entries will be lost.\n"; + } + } + } + + obj.set(cacheKey(client_id), refresh_token); + + // Write atomically: write to a temp file beside the cache, then rename. + // This prevents a partially-written file from being left world-readable if + // we are interrupted between the write and the chmod. + const std::string tmp_path = path + ".tmp"; + { + std::ofstream out(tmp_path, std::ios::trunc); + if (!out.is_open()) + return; + Poco::JSON::Stringifier::stringify(obj, out); + } + + std::error_code ec; + fs::permissions(tmp_path, fs::perms::owner_read | fs::perms::owner_write, fs::perm_options::replace, ec); + fs::rename(tmp_path, cache_path, ec); +} + +// --------------------------------------------------------------------------- +// 7. tryRefreshToken +// --------------------------------------------------------------------------- + +std::string tryRefreshToken(const OAuthCredentials & creds, const std::string & refresh_token) +{ + try + { + const std::string body + = "grant_type=refresh_token" + "&client_id=" + urlEncode(creds.client_id) + + "&client_secret=" + urlEncode(creds.client_secret) + + "&refresh_token=" + urlEncode(refresh_token); + + auto resp = postForm(creds.token_uri, body); + if (resp->has("error")) + { + std::cerr << "Note: cached refresh token was rejected (" + << resp->getValue("error") + << "); re-authenticating.\n"; + return ""; + } + if (resp->has("id_token")) + return resp->getValue("id_token"); + } + catch (const std::exception & e) + { + std::cerr << "Note: refresh token exchange failed (" << e.what() + << "); re-authenticating.\n"; + } + return ""; +} + +// --------------------------------------------------------------------------- +// 8. openBrowser +// --------------------------------------------------------------------------- + +void openBrowser(const std::string & url) +{ + // Always print so the user can copy-paste on headless / remote sessions. + std::cerr << "Opening browser for authentication.\n" + << "If the browser does not open, visit:\n " << url << "\n"; + +# if defined(__APPLE__) || defined(__linux__) + // Use posix_spawnp instead of system() to avoid shell-quoting issues. + const char * cmd = +# if defined(__APPLE__) + "open"; +# else + "xdg-open"; +# endif + const char * argv[] = {cmd, url.c_str(), nullptr}; + pid_t pid; + if (posix_spawnp(&pid, cmd, nullptr, nullptr, const_cast(argv), nullptr) == 0) + waitpid(pid, nullptr, 0); +# endif +} + +// --------------------------------------------------------------------------- +// 9. runAuthCodeFlow — auth code + PKCE, one-shot localhost callback server +// --------------------------------------------------------------------------- + +struct AuthCodeState +{ + std::mutex mtx; + std::condition_variable cv; + std::string code; + std::string error; + std::string received_state; // state= value echoed back by the provider + bool done = false; +}; + +class AuthCodeHandler : public Poco::Net::HTTPRequestHandler +{ +public: + explicit AuthCodeHandler(AuthCodeState & state_) : state(state_) { } + + void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override + { + Poco::URI uri("http://localhost" + request.getURI()); + const auto params = uri.getQueryParameters(); + + std::string code; + std::string error; + std::string received_state; + for (const auto & [k, v] : params) + { + if (k == "code") + code = v; + else if (k == "error") + error = v; + else if (k == "state") + received_state = v; + } + + response.setStatus(Poco::Net::HTTPResponse::HTTP_OK); + response.setContentType("text/html"); + auto & out = response.send(); + if (!code.empty()) + out << "Authentication successful. You may close this tab."; + else + out << "Authentication failed: " << htmlEscape(error) << ""; + out.flush(); + + std::lock_guard lock(state.mtx); + state.code = code; + state.error = error; + state.received_state = received_state; + state.done = true; + state.cv.notify_one(); + } + +private: + AuthCodeState & state; +}; + +class AuthCodeHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory +{ +public: + explicit AuthCodeHandlerFactory(AuthCodeState & state_) : state(state_) { } + + Poco::Net::HTTPRequestHandler * createRequestHandler(const Poco::Net::HTTPServerRequest &) override + { + return new AuthCodeHandler(state); + } + +private: + AuthCodeState & state; +}; + +std::string runAuthCodeFlow(const OAuthCredentials & creds) +{ + auto pkce = generatePKCE(); + + // Generate a random anti-CSRF state value per RFC 6749 §10.12. + unsigned char state_bytes[16]; + if (RAND_bytes(state_bytes, sizeof(state_bytes)) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for OAuth state"); + std::string csrf_state; + csrf_state.reserve(32); + for (unsigned char b : state_bytes) + { + constexpr char digits[] = "0123456789abcdef"; + csrf_state += digits[(b >> 4) & 0xF]; + csrf_state += digits[b & 0xF]; + } + + // Ephemeral callback server bound exclusively to the loopback interface. + // Binding to 127.0.0.1 (not 0.0.0.0) ensures network-adjacent attackers + // cannot race to deliver a forged callback even without the CSRF state check. + Poco::Net::ServerSocket server_socket; + server_socket.bind(Poco::Net::SocketAddress("127.0.0.1", 0), /*reuse_address=*/true); + server_socket.listen(1); + const uint16_t port = server_socket.address().port(); + const std::string redirect_uri = "http://localhost:" + std::to_string(port) + "/callback"; + + AuthCodeState state; + auto params = Poco::AutoPtr(new Poco::Net::HTTPServerParams()); + params->setMaxQueued(1); + params->setMaxThreads(1); + Poco::Net::HTTPServer server(new AuthCodeHandlerFactory(state), server_socket, params); + server.start(); + + // Build authorization URL — scope uses %20-encoded spaces per RFC 6749 §3.3. + // Google uses access_type=offline to request a refresh token rather than + // the standard offline_access scope (which it rejects as invalid). + const bool google = isGoogleProvider(creds); + const std::string scope = google + ? "openid email profile" + : "openid email profile offline_access"; + std::string auth_url + = creds.auth_uri + + "?response_type=code" + "&client_id=" + urlEncode(creds.client_id) + + "&redirect_uri=" + urlEncode(redirect_uri) + + "&code_challenge=" + pkce.challenge + + "&code_challenge_method=S256" + + "&scope=" + urlEncode(scope) + + "&state=" + csrf_state; + if (google) + auth_url += "&access_type=offline"; + + openBrowser(auth_url); + + // Wait up to 120 s for the browser callback. + bool timed_out = false; + std::string received_code; + std::string received_error; + std::string received_state; + { + std::unique_lock lock(state.mtx); + timed_out = !state.cv.wait_for(lock, std::chrono::seconds(120), [&] { return state.done; }); + received_code = state.code; + received_error = state.error; + received_state = state.received_state; + } + // Release the mutex before stopping the server to avoid a deadlock with + // the request handler that also acquires state.mtx. + server.stop(); + + if (timed_out) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 login timed out waiting for browser callback"); + if (!received_error.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 authorization error: {}", received_error); + if (received_code.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 callback did not contain an authorization code"); + if (received_state != csrf_state) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 CSRF check failed: unexpected state in callback"); + + // Exchange authorization code for tokens. + const std::string body + = "grant_type=authorization_code" + "&code=" + urlEncode(received_code) + + "&redirect_uri=" + urlEncode(redirect_uri) + + "&client_id=" + urlEncode(creds.client_id) + + "&client_secret=" + urlEncode(creds.client_secret) + + "&code_verifier=" + urlEncode(pkce.verifier); + + auto resp = postForm(creds.token_uri, body); + + if (resp->has("error")) + { + const std::string desc = resp->has("error_description") + ? resp->getValue("error_description") + : resp->getValue("error"); + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token exchange failed: {}", desc); + } + + if (resp->has("refresh_token")) + writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); + + if (!resp->has("id_token")) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token response did not contain id_token"); + + return resp->getValue("id_token"); +} + +// --------------------------------------------------------------------------- +// 10. runDeviceFlow +// --------------------------------------------------------------------------- + +std::string runDeviceFlow(OAuthCredentials creds) +{ + if (creds.device_auth_uri.empty()) + creds.device_auth_uri = discoverDeviceEndpoint(creds.token_uri, creds.issuer); + + // Scope uses %20-encoded spaces per RFC 6749 §3.3. + // Google rejects offline_access as an invalid scope; it issues a refresh + // token automatically for device flow. Standard OIDC providers require it. + const std::string device_scope = isGoogleProvider(creds) + ? "openid email profile" + : "openid email profile offline_access"; + const std::string device_body + = "client_id=" + urlEncode(creds.client_id) + + "&scope=" + urlEncode(device_scope); + + auto device_resp = postForm(creds.device_auth_uri, device_body); + + if (device_resp->has("error")) + { + const std::string desc = device_resp->has("error_description") + ? device_resp->getValue("error_description") + : device_resp->getValue("error"); + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device authorization request failed: {}", desc); + } + + // Validate mandatory fields before accessing them; getValue() on a missing + // key returns an empty Var and throws "Can not convert empty value". + if (!device_resp->has("device_code") || !device_resp->has("user_code")) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "Device authorization response from '{}' is missing required fields " + "(device_code / user_code). Response: {}", + creds.device_auth_uri, + [&]{ std::ostringstream ss; device_resp->stringify(ss); return ss.str(); }()); + + const std::string device_code = device_resp->getValue("device_code"); + const std::string user_code = device_resp->getValue("user_code"); + + // RFC 8628 uses "verification_uri"; Google's older device API uses "verification_url". + const std::string verification_uri = device_resp->has("verification_uri_complete") + ? device_resp->getValue("verification_uri_complete") + : device_resp->has("verification_uri") + ? device_resp->getValue("verification_uri") + : device_resp->has("verification_url") + ? device_resp->getValue("verification_url") + : throw Exception(ErrorCodes::AUTHENTICATION_FAILED, + "Device authorization response missing verification_uri / verification_url"); + + int interval = device_resp->has("interval") ? device_resp->getValue("interval") : 5; + int expires_in = device_resp->has("expires_in") ? device_resp->getValue("expires_in") : 300; + + std::cerr << "\nTo authenticate, visit:\n " << verification_uri + << "\nAnd enter code: " << user_code << "\n\n"; + + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(expires_in); + + while (std::chrono::steady_clock::now() < deadline) + { + std::this_thread::sleep_for(std::chrono::seconds(interval)); + + const std::string poll_body + = "grant_type=urn:ietf:params:oauth:grant-type:device_code" + "&device_code=" + urlEncode(device_code) + + "&client_id=" + urlEncode(creds.client_id) + + "&client_secret=" + urlEncode(creds.client_secret); + + auto resp = postForm(creds.token_uri, poll_body); + + if (resp->has("error")) + { + const std::string err = resp->getValue("error"); + if (err == "authorization_pending") + continue; + if (err == "slow_down") + { + interval += 5; + continue; + } + const std::string desc = resp->has("error_description") + ? resp->getValue("error_description") + : err; + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow error: {}", desc); + } + + if (resp->has("refresh_token")) + writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); + + if (!resp->has("id_token")) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow token response did not contain id_token"); + + return resp->getValue("id_token"); + } + + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow timed out"); +} + +} // anonymous namespace + +// --------------------------------------------------------------------------- +// Public API +// --------------------------------------------------------------------------- + +OAuthCredentials loadOAuthCredentials(const std::string & path) +{ + std::ifstream f(path); + if (!f.is_open()) + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "OAuth credentials file not found: '{}'\n" + "Place a Google-format credentials JSON at that path, or specify " + "--oauth-credentials /path/to/file.json", + path); + + std::string content((std::istreambuf_iterator(f)), std::istreambuf_iterator()); + + Poco::JSON::Parser parser; + Poco::Dynamic::Var parsed; + try + { + parsed = parser.parse(content); + } + catch (const std::exception & e) + { + throw Exception(ErrorCodes::BAD_ARGUMENTS, "Failed to parse OAuth credentials file '{}': {}", path, e.what()); + } + + auto root = parsed.extract(); + + // Accept either "installed" (desktop) or "web" top-level key. + Poco::JSON::Object::Ptr app; + if (root->has("installed")) + app = root->getObject("installed"); + else if (root->has("web")) + app = root->getObject("web"); + else + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "OAuth credentials file '{}' must have an 'installed' or 'web' top-level key", + path); + + auto require = [&](const std::string & key) -> std::string + { + if (!app->has(key)) + throw Exception( + ErrorCodes::BAD_ARGUMENTS, + "OAuth credentials file '{}' is missing required field '{}'", + path, + key); + return app->getValue(key); + }; + + OAuthCredentials creds; + creds.client_id = require("client_id"); + creds.client_secret = require("client_secret"); + creds.auth_uri = require("auth_uri"); + creds.token_uri = require("token_uri"); + + if (app->has("device_authorization_uri")) + creds.device_auth_uri = app->getValue("device_authorization_uri"); + if (app->has("issuer")) + creds.issuer = app->getValue("issuer"); + + // Warn if any endpoint uses plain HTTP — token exchanges should be encrypted. + auto warnIfHttp = [&](const std::string & field, const std::string & uri) + { + if (uri.size() >= 7 && uri.substr(0, 7) == "http://") + std::cerr << "Warning: OAuth credentials field '" << field << "' uses plain HTTP ('" + << uri << "'). Token exchanges over HTTP expose client credentials.\n"; + }; + warnIfHttp("token_uri", creds.token_uri); + warnIfHttp("auth_uri", creds.auth_uri); + if (!creds.device_auth_uri.empty()) + warnIfHttp("device_authorization_uri", creds.device_auth_uri); + + return creds; +} + +std::string obtainIDToken(const OAuthCredentials & creds, OAuthFlowMode mode) +{ + // 1. Try cached refresh token silently. + const std::string cached_refresh = readCachedRefreshToken(creds.client_id); + if (!cached_refresh.empty()) + { + const std::string id_token = tryRefreshToken(creds, cached_refresh); + if (!id_token.empty()) + return id_token; + // Refresh token expired or revoked — fall through to interactive flow. + } + + // 2. Run interactive flow. + if (mode == OAuthFlowMode::Device) + return runDeviceFlow(creds); + else + return runAuthCodeFlow(creds); +} + +} // namespace DB + +#endif // USE_JWT_CPP && USE_SSL diff --git a/src/Client/OAuthLogin.h b/src/Client/OAuthLogin.h new file mode 100644 index 000000000000..600e577fa8a4 --- /dev/null +++ b/src/Client/OAuthLogin.h @@ -0,0 +1,44 @@ +#pragma once + +#include +#include +#include + +namespace DB +{ + +class JWTProvider; // forward declaration — full type available with USE_JWT_CPP && USE_SSL + +enum class OAuthFlowMode +{ + AuthCode, + Device, +}; + +struct OAuthCredentials +{ + std::string client_id; + std::string client_secret; + std::string auth_uri; // authorization_endpoint + std::string token_uri; // token_endpoint + std::string device_auth_uri; // device_authorization_endpoint (discovered if empty) + std::string issuer; // OIDC issuer URL (optional; used to locate discovery document) +}; + +/// Load from Google-format JSON credentials file. +/// Throws if file not found or malformed. +OAuthCredentials loadOAuthCredentials(const std::string & path); + +/// Run OAuth flow, return ID token. Throws on failure. +std::string obtainIDToken(const OAuthCredentials & creds, OAuthFlowMode mode); + +#if USE_JWT_CPP && USE_SSL +/// Create a JWTProvider that runs the initial OAuth flow and then silently +/// refreshes the id_token via the cached refresh token for the lifetime +/// of the session. Assign the result to Client::jwt_provider so that +/// Connection::sendQuery can call getJWT() on each query. +std::shared_ptr createOAuthJWTProvider( + const OAuthCredentials & creds, OAuthFlowMode mode); +#endif + +} diff --git a/src/Client/tests/gtest_oauth_login.cpp b/src/Client/tests/gtest_oauth_login.cpp new file mode 100644 index 000000000000..59087b22c91d --- /dev/null +++ b/src/Client/tests/gtest_oauth_login.cpp @@ -0,0 +1,275 @@ +#include + +#if USE_JWT_CPP && USE_SSL + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +using namespace DB; + +namespace +{ + +namespace fs = std::filesystem; + +/// Write content to a temp file and return its path. The caller owns the file. +std::string writeTempFile(const std::string & content) +{ + const fs::path tmp = fs::temp_directory_path() / fs::path("gtest_oauth_XXXXXX"); + // std::tmpnam is deprecated — build a unique name with mkstemp. + std::string tmpl = tmp.string(); + int fd = mkstemp(tmpl.data()); + if (fd < 0) + throw std::runtime_error("mkstemp failed"); + close(fd); + + std::ofstream f(tmpl, std::ios::trunc); + f << content; + return tmpl; +} + +} // anonymous namespace + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — valid "installed" format +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, LoadInstalledFormat) +{ + const std::string json = R"({ + "installed": { + "client_id": "test-client-id", + "client_secret": "test-secret", + "auth_uri": "https://auth.example.com/auth", + "token_uri": "https://auth.example.com/token", + "redirect_uris": ["http://localhost"] + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + auto creds = loadOAuthCredentials(path); + EXPECT_EQ(creds.client_id, "test-client-id"); + EXPECT_EQ(creds.client_secret, "test-secret"); + EXPECT_EQ(creds.auth_uri, "https://auth.example.com/auth"); + EXPECT_EQ(creds.token_uri, "https://auth.example.com/token"); + EXPECT_TRUE(creds.device_auth_uri.empty()); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — valid "web" format +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, LoadWebFormat) +{ + const std::string json = R"({ + "web": { + "client_id": "web-client", + "client_secret": "web-secret", + "auth_uri": "https://web.example.com/auth", + "token_uri": "https://web.example.com/token" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + auto creds = loadOAuthCredentials(path); + EXPECT_EQ(creds.client_id, "web-client"); + EXPECT_EQ(creds.client_secret, "web-secret"); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — optional device_authorization_uri is loaded +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, LoadDeviceAuthUri) +{ + const std::string json = R"({ + "installed": { + "client_id": "x", + "client_secret": "y", + "auth_uri": "https://a.example.com/auth", + "token_uri": "https://a.example.com/token", + "device_authorization_uri": "https://a.example.com/device" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + auto creds = loadOAuthCredentials(path); + EXPECT_EQ(creds.device_auth_uri, "https://a.example.com/device"); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — missing top-level key throws BAD_ARGUMENTS +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, MissingTopLevelKey) +{ + const std::string json = R"({ "other_key": {} })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + EXPECT_THROW(loadOAuthCredentials(path), Exception); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — missing required field throws BAD_ARGUMENTS +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, MissingClientId) +{ + const std::string json = R"({ + "installed": { + "client_secret": "s", + "auth_uri": "https://a.example.com/auth", + "token_uri": "https://a.example.com/token" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + EXPECT_THROW(loadOAuthCredentials(path), Exception); +} + +TEST(OAuthLogin, MissingTokenUri) +{ + const std::string json = R"({ + "installed": { + "client_id": "c", + "client_secret": "s", + "auth_uri": "https://a.example.com/auth" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + EXPECT_THROW(loadOAuthCredentials(path), Exception); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — file not found throws BAD_ARGUMENTS +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, FileNotFound) +{ + EXPECT_THROW(loadOAuthCredentials("/nonexistent/path/oauth_client.json"), Exception); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — invalid JSON throws BAD_ARGUMENTS +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, InvalidJson) +{ + auto path = writeTempFile("not valid json {{{"); + SCOPE_EXIT({ fs::remove(path); }); + + EXPECT_THROW(loadOAuthCredentials(path), Exception); +} + +// --------------------------------------------------------------------------- +// loadOAuthCredentials — optional "issuer" field is loaded +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, LoadIssuerField) +{ + const std::string json = R"({ + "installed": { + "client_id": "x", + "client_secret": "y", + "auth_uri": "https://a.example.com/auth", + "token_uri": "https://a.example.com/token", + "issuer": "https://a.example.com" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + auto creds = loadOAuthCredentials(path); + EXPECT_EQ(creds.issuer, "https://a.example.com"); +} + +TEST(OAuthLogin, IssuerFieldAbsent) +{ + const std::string json = R"({ + "installed": { + "client_id": "x", + "client_secret": "y", + "auth_uri": "https://a.example.com/auth", + "token_uri": "https://a.example.com/token" + } + })"; + + auto path = writeTempFile(json); + SCOPE_EXIT({ fs::remove(path); }); + + auto creds = loadOAuthCredentials(path); + EXPECT_TRUE(creds.issuer.empty()); +} + +// --------------------------------------------------------------------------- +// PKCE building blocks +// +// generatePKCE() is in the anonymous namespace so we test its constituent +// operations (base64url encoding and SHA-256) directly. This verifies the +// exact properties that RFC 7636 §4 requires of the verifier and challenge. +// --------------------------------------------------------------------------- + +TEST(OAuthLogin, Base64UrlEncodingProperties) +{ + // 32 bytes → 43 base64url chars (no padding, RFC 7636 §4.1 requires 43-128). + const std::string raw(32, '\xAB'); + const std::string encoded = base64Encode(raw, /*url_encoding=*/true, /*no_padding=*/true); + + EXPECT_EQ(encoded.size(), 43u); + + // Must contain only URL-safe base64 chars: A-Z a-z 0-9 - _ + const bool all_safe = std::all_of(encoded.begin(), encoded.end(), [](unsigned char c) { + return std::isalnum(c) || c == '-' || c == '_'; + }); + EXPECT_TRUE(all_safe) << "base64url output contains non-URL-safe characters: " << encoded; + + // Must NOT contain padding or standard base64 symbols. + EXPECT_EQ(encoded.find('='), std::string::npos); + EXPECT_EQ(encoded.find('+'), std::string::npos); + EXPECT_EQ(encoded.find('/'), std::string::npos); +} + +TEST(OAuthLogin, PKCEChallengeDerivation) +{ + // SHA256(verifier) encodes to 32 bytes; base64url(32 bytes) = 43 chars. + const std::string verifier = base64Encode(std::string(32, '\x01'), true, true); + const std::string sha = encodeSHA256(verifier); + EXPECT_EQ(sha.size(), 32u); + + const std::string challenge = base64Encode(sha, true, true); + EXPECT_EQ(challenge.size(), 43u); + + // Challenge must differ from verifier. + EXPECT_NE(challenge, verifier); + + // Challenge must be deterministic for the same verifier. + EXPECT_EQ(base64Encode(encodeSHA256(verifier), true, true), challenge); + + // Different verifiers must produce different challenges. + const std::string verifier2 = base64Encode(std::string(32, '\x02'), true, true); + EXPECT_NE(base64Encode(encodeSHA256(verifier2), true, true), challenge); +} + +#endif // USE_JWT_CPP && USE_SSL diff --git a/tests/integration/compose/docker_compose_keycloak.yml b/tests/integration/compose/docker_compose_keycloak.yml new file mode 100644 index 000000000000..dd2a66ab8107 --- /dev/null +++ b/tests/integration/compose/docker_compose_keycloak.yml @@ -0,0 +1,21 @@ +services: + keycloak: + image: quay.io/keycloak/keycloak:26.0 + command: start-dev --import-realm + environment: + KEYCLOAK_ADMIN: admin + KEYCLOAK_ADMIN_PASSWORD: admin + volumes: + - ${KEYCLOAK_REALM_FILE}:/opt/keycloak/data/import/realm.json:ro + ports: + - "${KEYCLOAK_EXTERNAL_PORT:-18080}:8080" + healthcheck: + test: + - CMD-SHELL + - > + curl -sf + http://localhost:8080/realms/clickhouse-test/.well-known/openid-configuration + || exit 1 + interval: 10s + timeout: 5s + retries: 15 diff --git a/tests/integration/helpers/cluster.py b/tests/integration/helpers/cluster.py index cac2613ad07e..55633488505c 100644 --- a/tests/integration/helpers/cluster.py +++ b/tests/integration/helpers/cluster.py @@ -653,6 +653,7 @@ def __init__( self.with_redis = False self.with_cassandra = False self.with_ldap = False + self.with_keycloak = False self.with_jdbc_bridge = False self.with_nginx = False self.with_hive = False @@ -751,6 +752,11 @@ def __init__( self.ldap_port = 1389 self.ldap_id = self.get_instance_docker_id(self.ldap_host) + # available when with_keycloak == True + self.keycloak_host = "keycloak" + self.keycloak_port = 18080 + self.base_keycloak_cmd = None + # available when with_rabbitmq == True self.rabbitmq_host = "rabbitmq1" self.rabbitmq_ip = None @@ -1790,6 +1796,25 @@ def setup_ldap_cmd(self, instance, env_variables, docker_compose_yml_dir): ) return self.base_ldap_cmd + def setup_keycloak_cmd(self, instance, env_variables, docker_compose_yml_dir): + self.with_keycloak = True + env_variables["KEYCLOAK_EXTERNAL_PORT"] = str(self.keycloak_port) + env_variables["KEYCLOAK_REALM_FILE"] = p.join( + p.dirname(instance.path), + "keycloak", + "realm-export.json", + ) + self.base_cmd.extend( + ["--file", p.join(docker_compose_yml_dir, "docker_compose_keycloak.yml")] + ) + self.base_keycloak_cmd = self.compose_cmd( + "--env-file", + instance.env_file, + "--file", + p.join(docker_compose_yml_dir, "docker_compose_keycloak.yml"), + ) + return self.base_keycloak_cmd + def setup_jdbc_bridge_cmd(self, instance, env_variables, docker_compose_yml_dir): self.with_jdbc_bridge = True env_variables["JDBC_DRIVER_LOGS"] = self.jdbc_driver_logs_dir @@ -1955,6 +1980,7 @@ def add_instance( with_azurite=False, with_cassandra=False, with_ldap=False, + with_keycloak=False, with_jdbc_bridge=False, with_hive=False, with_coredns=False, @@ -2098,6 +2124,7 @@ def add_instance( with_coredns=with_coredns, with_cassandra=with_cassandra, with_ldap=with_ldap, + with_keycloak=with_keycloak, with_iceberg_catalog=with_iceberg_catalog, with_glue_catalog=with_glue_catalog, with_hms_catalog=with_hms_catalog, @@ -2358,6 +2385,11 @@ def add_instance( self.setup_ldap_cmd(instance, env_variables, docker_compose_yml_dir) ) + if with_keycloak and not self.with_keycloak: + cmds.append( + self.setup_keycloak_cmd(instance, env_variables, docker_compose_yml_dir) + ) + if with_jdbc_bridge and not self.with_jdbc_bridge: cmds.append( self.setup_jdbc_bridge_cmd( @@ -3319,6 +3351,26 @@ def wait_ldap_to_start(self, timeout=180): raise Exception("Can't wait LDAP to start") + def wait_keycloak_to_start(self, timeout=120): + discovery_url = ( + f"http://localhost:{self.keycloak_port}" + f"/realms/clickhouse-test/.well-known/openid-configuration" + ) + start = time.time() + while time.time() - start < timeout: + try: + resp = requests.get(discovery_url, timeout=5) + if resp.status_code == 200: + logging.info("Keycloak is online") + return + except Exception as ex: + logging.warning("Waiting for Keycloak: %s", ex) + time.sleep(3) + raise Exception("Keycloak did not start in time") + + def get_keycloak_url(self): + return f"http://localhost:{self.keycloak_port}" + def wait_prometheus_to_start(self): if "writer" in self.prometheus_servers: self.prometheus_writer_ip = self.get_instance_ip(self.prometheus_writer_host) @@ -3795,6 +3847,11 @@ def logging_azurite_initialization(exception, retry_number, sleep_time): self.up_called = True self.wait_ldap_to_start() + if self.with_keycloak and self.base_keycloak_cmd: + subprocess_check_call(self.base_keycloak_cmd + ["up", "-d"]) + self.up_called = True + self.wait_keycloak_to_start() + if self.with_jdbc_bridge and self.base_jdbc_bridge_cmd: os.makedirs(self.jdbc_driver_logs_dir) os.chmod(self.jdbc_driver_logs_dir, stat.S_IRWXU | stat.S_IRWXO) @@ -4279,6 +4336,7 @@ def __init__( with_coredns, with_cassandra, with_ldap, + with_keycloak, with_iceberg_catalog, with_glue_catalog, with_hms_catalog, @@ -4403,6 +4461,7 @@ def __init__( self.with_azurite = with_azurite self.with_cassandra = with_cassandra self.with_ldap = with_ldap + self.with_keycloak = with_keycloak self.with_jdbc_bridge = with_jdbc_bridge self.with_hive = with_hive self.with_coredns = with_coredns @@ -5740,6 +5799,9 @@ def write_embedded_config(name, dest_dir, fix_log_level=False): if self.with_ldap: depends_on.append("openldap") + if self.with_keycloak: + depends_on.append("keycloak") + if self.with_rabbitmq: depends_on.append("rabbitmq1") diff --git a/tests/integration/test_keycloak_auth/__init__.py b/tests/integration/test_keycloak_auth/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/integration/test_keycloak_auth/configs/users.xml b/tests/integration/test_keycloak_auth/configs/users.xml new file mode 100644 index 000000000000..3c621c1506bc --- /dev/null +++ b/tests/integration/test_keycloak_auth/configs/users.xml @@ -0,0 +1,13 @@ + + + + 1 + 1 + + + + default + default + + + diff --git a/tests/integration/test_keycloak_auth/configs/validators.xml b/tests/integration/test_keycloak_auth/configs/validators.xml new file mode 100644 index 000000000000..f7e13a6c6784 --- /dev/null +++ b/tests/integration/test_keycloak_auth/configs/validators.xml @@ -0,0 +1,20 @@ + + + + + jwt_dynamic_jwks + http://keycloak:8080/realms/clickhouse-test/protocol/openid-connect/certs + http://keycloak:8080/realms/clickhouse-test + preferred_username + 60 + + + + + openid + http://keycloak:8080/realms/clickhouse-test/.well-known/openid-configuration + preferred_username + 60 + + + diff --git a/tests/integration/test_keycloak_auth/keycloak/realm-export.json b/tests/integration/test_keycloak_auth/keycloak/realm-export.json new file mode 100644 index 000000000000..2257e7b548db --- /dev/null +++ b/tests/integration/test_keycloak_auth/keycloak/realm-export.json @@ -0,0 +1,72 @@ +{ + "realm": "clickhouse-test", + "enabled": true, + "sslRequired": "none", + "registrationAllowed": false, + "clients": [ + { + "clientId": "clickhouse", + "enabled": true, + "secret": "test-secret", + "publicClient": false, + "directAccessGrantsEnabled": true, + "serviceAccountsEnabled": false, + "standardFlowEnabled": true, + "attributes": { + "oauth2.device.authorization.grant.enabled": "true" + }, + "redirectUris": ["*"], + "webOrigins": ["*"], + "protocol": "openid-connect" + } + ], + "users": [ + { + "username": "alice", + "enabled": true, + "credentials": [ + { + "type": "password", + "value": "secret", + "temporary": false + } + ], + "groups": ["analysts"] + } + ], + "groups": [ + { + "name": "analysts" + } + ], + "clientScopes": [ + { + "name": "groups", + "protocol": "openid-connect", + "attributes": { + "include.in.token.scope": "true" + }, + "protocolMappers": [ + { + "name": "groups", + "protocol": "openid-connect", + "protocolMapper": "oidc-group-membership-mapper", + "config": { + "full.path": "false", + "id.token.claim": "true", + "access.token.claim": "true", + "claim.name": "groups", + "userinfo.token.claim": "true" + } + } + ] + } + ], + "defaultDefaultClientScopes": [ + "profile", + "email", + "roles", + "web-origins", + "groups" + ] +} diff --git a/tests/integration/test_keycloak_auth/test.py b/tests/integration/test_keycloak_auth/test.py new file mode 100644 index 000000000000..bbd8a71907d7 --- /dev/null +++ b/tests/integration/test_keycloak_auth/test.py @@ -0,0 +1,374 @@ +""" +Integration tests for Keycloak-based JWT authentication in ClickHouse. + +Layer 2 of the OAuth2 test plan. Requires: + - A running Keycloak container (started via `with_keycloak=True` on the cluster) + - ClickHouse configured with `jwt_dynamic_jwks` and `openid` token processors + +Run: + pytest tests/integration/test_keycloak_auth/test.py -v +""" + +import base64 +import json +import logging +import re +import time +from html import unescape as html_unescape + +import pytest +import requests + +from helpers.cluster import ClickHouseCluster + +KEYCLOAK_REALM = "clickhouse-test" +KEYCLOAK_CLIENT_ID = "clickhouse" +KEYCLOAK_CLIENT_SECRET = "test-secret" + +cluster = ClickHouseCluster(__file__) + +node = cluster.add_instance( + "node", + main_configs=["configs/validators.xml"], + user_configs=["configs/users.xml"], + with_keycloak=True, + stay_alive=True, +) + + +@pytest.fixture(scope="module", autouse=True) +def started_cluster(): + try: + cluster.start() + yield cluster + finally: + cluster.shutdown() + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def keycloak_url(started_cluster): + return started_cluster.get_keycloak_url() + + +def get_keycloak_token(started_cluster, username="alice", password="secret"): + """Obtain an id_token from Keycloak using the resource-owner password grant.""" + url = f"{keycloak_url(started_cluster)}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/token" + data = { + "grant_type": "password", + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "username": username, + "password": password, + "scope": "openid profile email", + } + resp = requests.post(url, data=data, timeout=30) + resp.raise_for_status() + token_data = resp.json() + assert "id_token" in token_data, f"No id_token in response: {token_data}" + return token_data["id_token"] + + +def query_with_token(node_instance, token, query): + """Execute a ClickHouse query using a JWT Bearer token via the HTTP interface.""" + resp = node_instance.http_request( + "", + method="POST", + data=query, + headers={"Authorization": f"Bearer {token}"}, + ) + resp.raise_for_status() + return resp.text + + +def decode_jwt_payload(token): + """Decode JWT payload without signature verification.""" + parts = token.split(".") + if len(parts) < 2: + return {} + payload_b64 = parts[1] + # Add padding + padding = 4 - len(payload_b64) % 4 + if padding != 4: + payload_b64 += "=" * padding + # Convert URL-safe base64 + payload_b64 = payload_b64.replace("-", "+").replace("_", "/") + return json.loads(base64.b64decode(payload_b64)) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_jwt_dynamic_jwks(started_cluster): + """Token validated via explicit JWKS URI (keycloak_jwks processor).""" + token = get_keycloak_token(started_cluster) + result = query_with_token(node, token, "SELECT 1") + assert result.strip() == "1" + + +def test_openid_discovery(started_cluster): + """Token validated via OIDC discovery document (keycloak_discovery processor).""" + token = get_keycloak_token(started_cluster) + result = query_with_token(node, token, "SELECT 1") + assert result.strip() == "1" + + +def test_username_claim(started_cluster): + """The `preferred_username` claim is mapped to the ClickHouse session user.""" + token = get_keycloak_token(started_cluster, username="alice") + result = query_with_token(node, token, "SELECT currentUser()") + assert result.strip() == "alice" + + +def test_token_refresh(started_cluster): + """Obtain a new id_token via the refresh_token grant and authenticate.""" + url = f"{keycloak_url(started_cluster)}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/token" + + # Initial grant + data = { + "grant_type": "password", + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "username": "alice", + "password": "secret", + "scope": "openid profile email offline_access", + } + resp = requests.post(url, data=data, timeout=30) + resp.raise_for_status() + tokens = resp.json() + assert "refresh_token" in tokens, "Expected refresh_token in password grant response" + + # Refresh + refresh_data = { + "grant_type": "refresh_token", + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "refresh_token": tokens["refresh_token"], + } + refresh_resp = requests.post(url, data=refresh_data, timeout=30) + refresh_resp.raise_for_status() + refreshed = refresh_resp.json() + assert "id_token" in refreshed + + result = query_with_token(node, refreshed["id_token"], "SELECT 1") + assert result.strip() == "1" + + +def test_wrong_issuer_rejected(started_cluster): + """A token with a tampered issuer claim must be rejected.""" + token = get_keycloak_token(started_cluster) + payload = decode_jwt_payload(token) + + # Modify the issuer + payload["iss"] = "https://evil.example.com" + tampered_payload = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode() + ) + + parts = token.split(".") + parts[1] = tampered_payload + tampered_token = ".".join(parts) + + # Authentication must fail + try: + query_with_token(node, tampered_token, "SELECT 1") + pytest.fail("Expected authentication failure for tampered token") + except Exception: + pass # Expected + + +def test_expired_token_rejected(started_cluster): + """A token with an expired `exp` claim must be rejected.""" + token = get_keycloak_token(started_cluster) + payload = decode_jwt_payload(token) + + # Set exp to a past timestamp + payload["exp"] = int(time.time()) - 3600 + expired_payload = ( + base64.urlsafe_b64encode(json.dumps(payload).encode()).rstrip(b"=").decode() + ) + + parts = token.split(".") + parts[1] = expired_payload + expired_token = ".".join(parts) + + try: + query_with_token(node, expired_token, "SELECT 1") + pytest.fail("Expected authentication failure for expired token") + except Exception: + pass # Expected + + +def _approve_device_code_via_browser( + keycloak_base_url, realm, user_code, username="alice", password="secret" +): + """ + Simulate a browser approving a Keycloak device authorization request. + + Keycloak's device flow requires a user to visit a verification URI, log in, + and confirm access. This helper drives that multi-step HTML form sequence + using a `requests.Session` so that session cookies are maintained across + the redirects. + """ + + s = requests.Session() + + def get_form(html): + """Return (action_url, field_dict) for the first
in *html*.""" + m = re.search(r']*\baction="([^"]+)"', html) + if not m: + return None, {} + action_url = html_unescape(m.group(1)) + fields = {} + for inp in re.findall(r"]+>", html): + n = re.search(r'\bname="([^"]+)"', inp) + v = re.search(r'\bvalue="([^"]*)"', inp) + t = re.search(r'\btype="([^"]+)"', inp) + if n and (not t or t.group(1).lower() not in ("checkbox", "radio")): + fields[n.group(1)] = v.group(1) if v else "" + return action_url, fields + + # Step 1: Navigate to the device endpoint. Keycloak redirects to a login + # page when the user_code query parameter is provided and valid. + r = s.get( + f"{keycloak_base_url}/realms/{realm}/device", + params={"user_code": user_code}, + allow_redirects=True, + timeout=30, + ) + r.raise_for_status() + + # Step 1a: If Keycloak shows a user-code entry form first (no user_code + # in the redirect), fill it in and submit. + if 'name="device_user_code"' in r.text or 'name="user_code"' in r.text: + action, fields = get_form(r.text) + fields["device_user_code"] = user_code + fields["user_code"] = user_code + r = s.post(action, data=fields, allow_redirects=True, timeout=30) + r.raise_for_status() + + # Step 2: We should now be on the login page. Submit credentials. + assert 'type="password"' in r.text, ( + f"Expected Keycloak login page, got:\n{r.text[:800]}" + ) + action, fields = get_form(r.text) + fields["username"] = username + fields["password"] = password + r = s.post(action, data=fields, allow_redirects=True, timeout=30) + r.raise_for_status() + + # Step 3: Submit the device consent / grant form. Keycloak renders a + # "Do you want to grant access?" page with an `accept` submit button. + action, fields = get_form(r.text) + if action: + if "accept" not in fields: + fields["accept"] = "" + s.post(action, data=fields, allow_redirects=True, timeout=30) + + +def test_device_flow_initiation(started_cluster): + """ + Verify that Keycloak responds correctly to the device authorization request. + The polling / approval mechanics are covered by the Layer 1 unit tests. + """ + url = f"{keycloak_url(started_cluster)}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/auth/device" + data = { + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "scope": "openid profile email", + } + resp = requests.post(url, data=data, timeout=30) + resp.raise_for_status() + + device_data = resp.json() + assert "device_code" in device_data, f"Missing device_code: {device_data}" + assert "user_code" in device_data, f"Missing user_code: {device_data}" + assert ( + "verification_uri" in device_data or "verification_uri_complete" in device_data + ), f"Missing verification_uri: {device_data}" + logging.info( + "Device flow initiated: user_code=%s verification_uri=%s", + device_data.get("user_code"), + device_data.get("verification_uri", device_data.get("verification_uri_complete")), + ) + + +def test_device_flow_round_trip(started_cluster): + """ + Full device-authorization-grant round-trip (RFC 8628). + + 1. Client initiates device flow → Keycloak returns `device_code` / `user_code`. + 2. User (simulated via `_approve_device_code_via_browser`) visits the + verification URI, logs in, and grants access. + 3. Client polls the token endpoint until an `id_token` is returned. + 4. `id_token` is used to authenticate a ClickHouse query — must return `1`. + """ + base_url = keycloak_url(started_cluster) + device_endpoint = ( + f"{base_url}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/auth/device" + ) + token_endpoint = ( + f"{base_url}/realms/{KEYCLOAK_REALM}/protocol/openid-connect/token" + ) + + # --- 1. Initiate device authorization --- + init_resp = requests.post( + device_endpoint, + data={ + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "scope": "openid profile email", + }, + timeout=30, + ) + init_resp.raise_for_status() + device_data = init_resp.json() + device_code = device_data["device_code"] + user_code = device_data["user_code"] + interval = max(device_data.get("interval", 5), 1) + + logging.info( + "Device flow round-trip: user_code=%s device_code=%.8s…", user_code, device_code + ) + + # --- 2. Simulate user approving the request in a browser --- + _approve_device_code_via_browser(base_url, KEYCLOAK_REALM, user_code) + + # --- 3. Poll until the token arrives (or a 60-second deadline) --- + deadline = time.time() + 60 + id_token = None + while time.time() < deadline: + time.sleep(interval) + poll_resp = requests.post( + token_endpoint, + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "client_id": KEYCLOAK_CLIENT_ID, + "client_secret": KEYCLOAK_CLIENT_SECRET, + "device_code": device_code, + }, + timeout=30, + ) + poll_data = poll_resp.json() + if "id_token" in poll_data: + id_token = poll_data["id_token"] + break + error = poll_data.get("error", "") + assert error in ("authorization_pending", "slow_down"), ( + f"Unexpected polling error: {poll_data}" + ) + if error == "slow_down": + interval += 5 + + assert id_token is not None, ( + "Device flow timed out: Keycloak never returned an id_token after approval" + ) + + # --- 4. Use the token to authenticate a ClickHouse query --- + result = query_with_token(node, id_token, "SELECT 1") + assert result.strip() == "1" diff --git a/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.reference b/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.reference index 0dcbc1ba33aa..322a4d1b25ad 100644 --- a/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.reference +++ b/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.reference @@ -14,4 +14,10 @@ Test 7: Connection string with user:password@ should not trigger OAuth OK Test 8: Multiple host/port format variations OK +Test 9: --login=device with missing credentials file gives clear error +OK +Test 10: --login=invalid should give BAD_ARGUMENTS +OK +Test 11: --jwt and --login together should give BAD_ARGUMENTS +OK All tests completed diff --git a/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.sh b/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.sh index ddd4632854bc..5bd5ff300449 100755 --- a/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.sh +++ b/tests/queries/0_stateless/03749_cloud_endpoint_auth_precedence.sh @@ -101,4 +101,33 @@ else echo "FAILED: $failed commands failed" fi +# Test 9: --login=device with no credentials file should fail with a clear file-not-found error +# (not a crash or confusing message) +echo "Test 9: --login=device with missing credentials file gives clear error" +MISSING_CREDS="/tmp/nonexistent_oauth_creds_$$.json" +output=$($CLICKHOUSE_CLIENT_BINARY --login=device --oauth-credentials "$MISSING_CREDS" --query "SELECT 1" 2>&1) +if echo "$output" | grep -qi "not found\|No such file\|cannot open\|BAD_ARGUMENTS"; then + echo "OK" +else + echo "FAILED: expected file-not-found error, got: $output" +fi + +# Test 10: --login=invalid should give BAD_ARGUMENTS with descriptive message +echo "Test 10: --login=invalid should give BAD_ARGUMENTS" +output=$($CLICKHOUSE_CLIENT_BINARY --login=invalid --host="${CLICKHOUSE_HOST}" --port="${CLICKHOUSE_PORT_TCP}" --query "SELECT 1" 2>&1) +if echo "$output" | grep -qi "must be.*browser.*device\|BAD_ARGUMENTS"; then + echo "OK" +else + echo "FAILED: expected BAD_ARGUMENTS for invalid mode, got: $output" +fi + +# Test 11: --jwt and --login together should give BAD_ARGUMENTS +echo "Test 11: --jwt and --login together should give BAD_ARGUMENTS" +output=$($CLICKHOUSE_CLIENT_BINARY --jwt "sometoken" --login=browser --host="${CLICKHOUSE_HOST}" --port="${CLICKHOUSE_PORT_TCP}" --query "SELECT 1" 2>&1) +if echo "$output" | grep -qi "cannot both be specified\|BAD_ARGUMENTS"; then + echo "OK" +else + echo "FAILED: expected BAD_ARGUMENTS for --jwt + --login, got: $output" +fi + echo "All tests completed" From 1cb506a7aefb5fabdb818aba9cb558b486fba8b2 Mon Sep 17 00:00:00 2001 From: Andrey Zvonov Date: Tue, 14 Apr 2026 22:45:27 +0200 Subject: [PATCH 2/4] split code --- src/Client/OAuthFlowRunner.cpp | 439 +++++++++++++++++++ src/Client/OAuthFlowRunner.h | 22 + src/Client/OAuthLogin.cpp | 660 ++--------------------------- src/Client/OAuthProviderPolicy.cpp | 127 ++++++ src/Client/OAuthProviderPolicy.h | 58 +++ 5 files changed, 683 insertions(+), 623 deletions(-) create mode 100644 src/Client/OAuthFlowRunner.cpp create mode 100644 src/Client/OAuthFlowRunner.h create mode 100644 src/Client/OAuthProviderPolicy.cpp create mode 100644 src/Client/OAuthProviderPolicy.h diff --git a/src/Client/OAuthFlowRunner.cpp b/src/Client/OAuthFlowRunner.cpp new file mode 100644 index 000000000000..2210c90a6bd3 --- /dev/null +++ b/src/Client/OAuthFlowRunner.cpp @@ -0,0 +1,439 @@ +#include +#include + +#if USE_JWT_CPP && USE_SSL + +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include + +#if defined(__APPLE__) || defined(__linux__) +# include +# include +#endif + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int AUTHENTICATION_FAILED; +} + +void writeCachedRefreshToken(const std::string & client_id, const std::string & refresh_token); + +namespace +{ + +constexpr int HTTP_TIMEOUT_SECONDS = 30; + +std::string htmlEscape(const std::string & s) +{ + std::string out; + out.reserve(s.size()); + for (char c : s) + { + switch (c) + { + case '&': out += "&"; break; + case '<': out += "<"; break; + case '>': out += ">"; break; + case '"': out += """; break; + case '\'': out += "'"; break; + default: out += c; break; + } + } + return out; +} + +struct PKCEPair +{ + std::string verifier; + std::string challenge; +}; + +PKCEPair generatePKCE() +{ + unsigned char raw[32]; + if (RAND_bytes(raw, sizeof(raw)) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for PKCE verifier"); + + std::string verifier = base64Encode( + std::string(reinterpret_cast(raw), sizeof(raw)), + /*url_encoding=*/true, + /*no_padding=*/true); + + std::string sha = encodeSHA256(verifier); + std::string challenge = base64Encode(sha, /*url_encoding=*/true, /*no_padding=*/true); + return {verifier, challenge}; +} + +void openBrowser(const std::string & url) +{ + std::cerr << "Opening browser for authentication.\n" + << "If the browser does not open, visit:\n " << url << "\n"; + +#if defined(__APPLE__) || defined(__linux__) + const char * cmd = +# if defined(__APPLE__) + "open"; +# else + "xdg-open"; +# endif + const char * argv[] = {cmd, url.c_str(), nullptr}; + pid_t pid; + if (posix_spawnp(&pid, cmd, nullptr, nullptr, const_cast(argv), nullptr) == 0) + waitpid(pid, nullptr, 0); +#endif +} + +struct AuthCodeState +{ + std::mutex mtx; + std::condition_variable cv; + std::string code; + std::string error; + std::string received_state; + bool done = false; +}; + +class AuthCodeHandler : public Poco::Net::HTTPRequestHandler +{ +public: + explicit AuthCodeHandler(AuthCodeState & state_) : state(state_) { } + + void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override + { + Poco::URI uri("http://localhost" + request.getURI()); + const auto params = uri.getQueryParameters(); + + std::string code; + std::string error; + std::string received_state; + for (const auto & [k, v] : params) + { + if (k == "code") + code = v; + else if (k == "error") + error = v; + else if (k == "state") + received_state = v; + } + + response.setStatus(Poco::Net::HTTPResponse::HTTP_OK); + response.setContentType("text/html"); + auto & out = response.send(); + if (!code.empty()) + out << "Authentication successful. You may close this tab."; + else + out << "Authentication failed: " << htmlEscape(error) << ""; + out.flush(); + + std::lock_guard lock(state.mtx); + state.code = code; + state.error = error; + state.received_state = received_state; + state.done = true; + state.cv.notify_one(); + } + +private: + AuthCodeState & state; +}; + +class AuthCodeHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory +{ +public: + explicit AuthCodeHandlerFactory(AuthCodeState & state_) : state(state_) { } + + Poco::Net::HTTPRequestHandler * createRequestHandler(const Poco::Net::HTTPServerRequest &) override + { + return new AuthCodeHandler(state); + } + +private: + AuthCodeState & state; +}; + +} + +std::string urlEncodeOAuth(const std::string & value) +{ + std::string result; + Poco::URI::encode(value, "", result); + return result; +} + +Poco::JSON::Object::Ptr postOAuthForm(const std::string & url, const std::string & body) +{ + Poco::URI uri(url); + Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_POST, uri.getPathAndQuery()); + request.setContentType("application/x-www-form-urlencoded"); + request.setContentLength(static_cast(body.size())); + + Poco::Net::HTTPResponse response; + std::string response_body; + + if (uri.getScheme() == "https") + { + Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); + Poco::Net::HTTPSClientSession session(uri.getHost(), uri.getPort(), ctx); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + auto & req_stream = session.sendRequest(request); + req_stream << body; + auto & resp_stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(resp_stream, response_body); + } + else + { + Poco::Net::HTTPClientSession session(uri.getHost(), uri.getPort()); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + auto & req_stream = session.sendRequest(request); + req_stream << body; + auto & resp_stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(resp_stream, response_body); + } + + Poco::Dynamic::Var parsed; + try + { + Poco::JSON::Parser parser; + parsed = parser.parse(response_body); + } + catch (...) + { + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OAuth2 endpoint '{}' returned HTTP {} with non-JSON body: {}", + url, + static_cast(response.getStatus()), + response_body.substr(0, 512)); + } + + auto obj = parsed.extract(); + if (!obj) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OAuth2 endpoint '{}' returned HTTP {} with non-object JSON response: {}", + url, + static_cast(response.getStatus()), + response_body.substr(0, 512)); + return obj; +} + +std::string runOAuthAuthCodeFlow(const OAuthCredentials & creds) +{ + auto provider_policy = IOAuthProviderPolicy::create(creds); + auto pkce = generatePKCE(); + + unsigned char state_bytes[16]; + if (RAND_bytes(state_bytes, sizeof(state_bytes)) != 1) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for OAuth state"); + + std::string csrf_state; + csrf_state.reserve(32); + for (unsigned char b : state_bytes) + { + constexpr char digits[] = "0123456789abcdef"; + csrf_state += digits[(b >> 4) & 0xF]; + csrf_state += digits[b & 0xF]; + } + + Poco::Net::ServerSocket server_socket; + server_socket.bind(Poco::Net::SocketAddress("127.0.0.1", 0), /*reuse_address=*/true); + server_socket.listen(1); + const uint16_t port = server_socket.address().port(); + const std::string redirect_uri = "http://localhost:" + std::to_string(port) + "/callback"; + + AuthCodeState state; + auto params = Poco::AutoPtr(new Poco::Net::HTTPServerParams()); + params->setMaxQueued(1); + params->setMaxThreads(1); + Poco::Net::HTTPServer server(new AuthCodeHandlerFactory(state), server_socket, params); + server.start(); + + std::string auth_url + = creds.auth_uri + + "?response_type=code" + "&client_id=" + urlEncodeOAuth(creds.client_id) + + "&redirect_uri=" + urlEncodeOAuth(redirect_uri) + + "&code_challenge=" + pkce.challenge + + "&code_challenge_method=S256" + + "&scope=" + urlEncodeOAuth(provider_policy->getAuthCodeScope()) + + "&state=" + csrf_state; + if (provider_policy->useAccessTypeOfflineForAuthCode()) + auth_url += "&access_type=offline"; + + openBrowser(auth_url); + + bool timed_out = false; + std::string received_code; + std::string received_error; + std::string received_state; + { + std::unique_lock lock(state.mtx); + timed_out = !state.cv.wait_for(lock, std::chrono::seconds(120), [&] { return state.done; }); + received_code = state.code; + received_error = state.error; + received_state = state.received_state; + } + server.stop(); + + if (timed_out) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 login timed out waiting for browser callback"); + if (!received_error.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 authorization error: {}", received_error); + if (received_code.empty()) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 callback did not contain an authorization code"); + if (received_state != csrf_state) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 CSRF check failed: unexpected state in callback"); + + const std::string body + = "grant_type=authorization_code" + "&code=" + urlEncodeOAuth(received_code) + + "&redirect_uri=" + urlEncodeOAuth(redirect_uri) + + "&client_id=" + urlEncodeOAuth(creds.client_id) + + "&client_secret=" + urlEncodeOAuth(creds.client_secret) + + "&code_verifier=" + urlEncodeOAuth(pkce.verifier); + + auto resp = postOAuthForm(creds.token_uri, body); + if (resp->has("error")) + { + const std::string desc = resp->has("error_description") + ? resp->getValue("error_description") + : resp->getValue("error"); + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token exchange failed: {}", desc); + } + + if (!resp->has("id_token")) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token response did not contain id_token"); + + if (resp->has("refresh_token")) + writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); + + return resp->getValue("id_token"); +} + +std::string runOAuthDeviceFlow(OAuthCredentials creds) +{ + auto provider_policy = IOAuthProviderPolicy::create(creds); + if (creds.device_auth_uri.empty()) + creds.device_auth_uri = provider_policy->resolveDeviceAuthorizationEndpoint(creds); + + const std::string device_scope = provider_policy->getDeviceScope(); + const std::string device_body + = "client_id=" + urlEncodeOAuth(creds.client_id) + + "&scope=" + urlEncodeOAuth(device_scope); + + auto device_resp = postOAuthForm(creds.device_auth_uri, device_body); + + if (device_resp->has("error")) + { + const std::string desc = device_resp->has("error_description") + ? device_resp->getValue("error_description") + : device_resp->getValue("error"); + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device authorization request failed: {}", desc); + } + + if (!device_resp->has("device_code") || !device_resp->has("user_code")) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "Device authorization response from '{}' is missing required fields " + "(device_code / user_code). Response: {}", + creds.device_auth_uri, + [&] + { + std::ostringstream ss; + device_resp->stringify(ss); + return ss.str(); + }()); + + const std::string device_code = device_resp->getValue("device_code"); + const std::string user_code = device_resp->getValue("user_code"); + const std::string verification_uri = device_resp->has("verification_uri_complete") + ? device_resp->getValue("verification_uri_complete") + : device_resp->has("verification_uri") + ? device_resp->getValue("verification_uri") + : device_resp->has("verification_url") + ? device_resp->getValue("verification_url") + : throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "Device authorization response missing verification_uri / verification_url"); + + int interval = device_resp->has("interval") ? device_resp->getValue("interval") : 5; + int expires_in = device_resp->has("expires_in") ? device_resp->getValue("expires_in") : 300; + + std::cerr << "\nTo authenticate, visit:\n " << verification_uri << "\nAnd enter code: " << user_code << "\n\n"; + + const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(expires_in); + while (std::chrono::steady_clock::now() < deadline) + { + std::this_thread::sleep_for(std::chrono::seconds(interval)); + + const std::string poll_body + = "grant_type=urn:ietf:params:oauth:grant-type:device_code" + "&device_code=" + urlEncodeOAuth(device_code) + + "&client_id=" + urlEncodeOAuth(creds.client_id) + + "&client_secret=" + urlEncodeOAuth(creds.client_secret); + + auto resp = postOAuthForm(creds.token_uri, poll_body); + if (resp->has("error")) + { + const std::string err = resp->getValue("error"); + if (err == "authorization_pending") + continue; + if (err == "slow_down") + { + interval += 5; + continue; + } + const std::string desc = resp->has("error_description") ? resp->getValue("error_description") : err; + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow error: {}", desc); + } + + if (!resp->has("id_token")) + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow token response did not contain id_token"); + + if (resp->has("refresh_token")) + writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); + + return resp->getValue("id_token"); + } + + throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow timed out"); +} + +} // namespace DB + +#endif // USE_JWT_CPP && USE_SSL diff --git a/src/Client/OAuthFlowRunner.h b/src/Client/OAuthFlowRunner.h new file mode 100644 index 000000000000..de5ba06dc31e --- /dev/null +++ b/src/Client/OAuthFlowRunner.h @@ -0,0 +1,22 @@ +#pragma once + +#include +#include + +#if USE_JWT_CPP && USE_SSL + +#include + +#include + +namespace DB +{ + +std::string urlEncodeOAuth(const std::string & value); +Poco::JSON::Object::Ptr postOAuthForm(const std::string & url, const std::string & body); +std::string runOAuthAuthCodeFlow(const OAuthCredentials & creds); +std::string runOAuthDeviceFlow(OAuthCredentials creds); + +} + +#endif // USE_JWT_CPP && USE_SSL diff --git a/src/Client/OAuthLogin.cpp b/src/Client/OAuthLogin.cpp index 80610347d84c..1a29aef0ed84 100644 --- a/src/Client/OAuthLogin.cpp +++ b/src/Client/OAuthLogin.cpp @@ -3,44 +3,20 @@ #if USE_JWT_CPP && USE_SSL -# include -# include -# include - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include - -# include -# include -# include -# include -# include -# include -# include - -# if defined(__APPLE__) || defined(__linux__) -# include -# include -# endif +#include + +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include namespace DB { @@ -48,258 +24,13 @@ namespace DB namespace ErrorCodes { extern const int BAD_ARGUMENTS; -extern const int AUTHENTICATION_FAILED; } namespace { -// HTTP request timeout for all OAuth endpoint calls. -constexpr int HTTP_TIMEOUT_SECONDS = 30; - -/// Minimal HTML escaping to prevent XSS when reflecting user-supplied strings -/// (e.g. the error= query parameter from the OAuth callback) into HTML. -std::string htmlEscape(const std::string & s) -{ - std::string out; - out.reserve(s.size()); - for (char c : s) - { - switch (c) - { - case '&': out += "&"; break; - case '<': out += "<"; break; - case '>': out += ">"; break; - case '"': out += """; break; - case '\'': out += "'"; break; - default: out += c; break; - } - } - return out; -} - -// --------------------------------------------------------------------------- -// 2. discoverDeviceEndpoint -// --------------------------------------------------------------------------- - -/// Fetch the OIDC discovery document and return device_authorization_endpoint. -/// -/// issuer_hint: explicit OIDC issuer URL (e.g. from credentials JSON "issuer" field). -/// When non-empty it is used directly: discovery is at {issuer_hint}/.well-known/openid-configuration. -/// When empty, issuer is derived heuristically from token_uri: -/// - Google (oauth2.googleapis.com) → https://accounts.google.com (hardcoded mapping) -/// - Generic: strip last path segment, preserving realm prefixes -/// e.g. https://auth.example.com/realms/myrealm/protocol/openid-connect/token -/// → https://auth.example.com/realms/myrealm -/// For providers whose issuer cannot be reliably derived, set "issuer" or -/// "device_authorization_uri" in the credentials JSON to bypass discovery. -std::string discoverDeviceEndpoint(const std::string & token_uri, const std::string & issuer_hint) -{ - std::string issuer; - if (!issuer_hint.empty()) - { - issuer = issuer_hint; - } - else - { - Poco::URI uri(token_uri); - if (uri.getHost() == "oauth2.googleapis.com") - { - // Google uses a separate domain for its OIDC discovery. - issuer = "https://accounts.google.com"; - } - else - { - // Build scheme://host[:port] prefix. - issuer = uri.getScheme() + "://" + uri.getHost(); - if (uri.getPort() != 0 - && !((uri.getScheme() == "https" && uri.getPort() == 443) - || (uri.getScheme() == "http" && uri.getPort() == 80))) - issuer += ":" + std::to_string(uri.getPort()); - - // Append the path minus its last segment so that issuers with - // sub-paths (e.g. Keycloak's /realms/) are preserved. - std::string path = uri.getPath(); - const auto last_slash = path.rfind('/'); - if (last_slash != std::string::npos && last_slash != 0) - issuer += path.substr(0, last_slash); - } - } - - const std::string discovery_url = issuer + "/.well-known/openid-configuration"; - Poco::URI disc_uri(discovery_url); - - Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, disc_uri.getPathAndQuery()); - Poco::Net::HTTPResponse response; - std::string body; - - if (disc_uri.getScheme() == "https") - { - Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); - Poco::Net::HTTPSClientSession session(disc_uri.getHost(), disc_uri.getPort(), ctx); - session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); - session.sendRequest(request); - auto & stream = session.receiveResponse(response); - Poco::StreamCopier::copyToString(stream, body); - } - else - { - Poco::Net::HTTPClientSession session(disc_uri.getHost(), disc_uri.getPort()); - session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); - session.sendRequest(request); - auto & stream = session.receiveResponse(response); - Poco::StreamCopier::copyToString(stream, body); - } - - if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_OK) - throw Exception( - ErrorCodes::AUTHENTICATION_FAILED, - "OIDC discovery failed for '{}': {} {}", - discovery_url, - static_cast(response.getStatus()), - response.getReason()); - - Poco::JSON::Parser parser; - auto result = parser.parse(body); - auto obj = result.extract(); - - if (!obj->has("device_authorization_endpoint")) - throw Exception( - ErrorCodes::AUTHENTICATION_FAILED, - "OIDC discovery document at '{}' does not contain device_authorization_endpoint", - discovery_url); - - return obj->getValue("device_authorization_endpoint"); -} - -// --------------------------------------------------------------------------- -// 3. generatePKCE -// --------------------------------------------------------------------------- - -struct PKCEPair -{ - std::string verifier; - std::string challenge; -}; - -PKCEPair generatePKCE() -{ - // 32 random bytes → base64url (43 chars, no padding) - unsigned char raw[32]; - if (RAND_bytes(raw, sizeof(raw)) != 1) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for PKCE verifier"); - - std::string verifier = base64Encode( - std::string(reinterpret_cast(raw), sizeof(raw)), - /*url_encoding=*/true, - /*no_padding=*/true); - - // challenge = BASE64URL(SHA256(verifier)) - std::string sha = encodeSHA256(verifier); - std::string challenge = base64Encode(sha, /*url_encoding=*/true, /*no_padding=*/true); - - return {verifier, challenge}; -} - -// --------------------------------------------------------------------------- -// 4. urlEncode -// --------------------------------------------------------------------------- - -std::string urlEncode(const std::string & s) -{ - std::string result; - Poco::URI::encode(s, "", result); - return result; -} - -/// Google uses access_type=offline instead of the offline_access scope. -/// Detect by checking the token endpoint host. -bool isGoogleProvider(const OAuthCredentials & creds) -{ - Poco::URI uri(creds.token_uri); - const std::string & host = uri.getHost(); - return host == "oauth2.googleapis.com" || host == "accounts.google.com"; -} - -// --------------------------------------------------------------------------- -// 5. postForm — HTTPS/HTTP POST application/x-www-form-urlencoded -// -// Always attempts to parse the response body as JSON, regardless of the HTTP -// status code. RFC 6749 returns error responses (e.g. authorization_pending -// during device-flow polling) as HTTP 400 with a JSON body — callers must -// inspect the "error" field in the returned object. -// -// Throws only when the body cannot be parsed as JSON: -// - 4xx/5xx with non-JSON body → AUTHENTICATION_FAILED with HTTP status -// - 2xx with non-JSON body → AUTHENTICATION_FAILED (unexpected format) -// --------------------------------------------------------------------------- - -Poco::JSON::Object::Ptr postForm(const std::string & url, const std::string & body) -{ - Poco::URI uri(url); - Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_POST, uri.getPathAndQuery()); - request.setContentType("application/x-www-form-urlencoded"); - request.setContentLength(static_cast(body.size())); - - Poco::Net::HTTPResponse response; - std::string response_body; - - if (uri.getScheme() == "https") - { - Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); - Poco::Net::HTTPSClientSession session(uri.getHost(), uri.getPort(), ctx); - session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); - auto & req_stream = session.sendRequest(request); - req_stream << body; - auto & resp_stream = session.receiveResponse(response); - Poco::StreamCopier::copyToString(resp_stream, response_body); - } - else - { - Poco::Net::HTTPClientSession session(uri.getHost(), uri.getPort()); - session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); - auto & req_stream = session.sendRequest(request); - req_stream << body; - auto & resp_stream = session.receiveResponse(response); - Poco::StreamCopier::copyToString(resp_stream, response_body); - } - - // Try JSON parse regardless of status — RFC 6749 §5.2 returns errors - // in JSON bodies even on HTTP 400 (e.g., authorization_pending, slow_down). - Poco::Dynamic::Var parsed; - try - { - Poco::JSON::Parser parser; - parsed = parser.parse(response_body); - } - catch (...) - { - throw Exception( - ErrorCodes::AUTHENTICATION_FAILED, - "OAuth2 endpoint '{}' returned HTTP {} with non-JSON body: {}", - url, - static_cast(response.getStatus()), - response_body.substr(0, 512)); - } - - auto obj = parsed.extract(); - if (!obj) - throw Exception( - ErrorCodes::AUTHENTICATION_FAILED, - "OAuth2 endpoint '{}' returned HTTP {} with non-object JSON response: {}", - url, - static_cast(response.getStatus()), - response_body.substr(0, 512)); - return obj; -} - -// --------------------------------------------------------------------------- -// 6. Token cache -// --------------------------------------------------------------------------- - std::string cacheKey(const std::string & client_id) { - // First 16 hex chars of SHA256(client_id) std::string hash = encodeSHA256(client_id); std::string hex; hex.reserve(32); @@ -320,7 +51,7 @@ std::string cacheFilePath() return std::string(home) + "/.clickhouse-client/oauth_cache.json"; } -std::string readCachedRefreshToken(const std::string & client_id) +std::string readCachedRefreshTokenImpl(const std::string & client_id) { const std::string path = cacheFilePath(); if (path.empty()) @@ -335,7 +66,7 @@ std::string readCachedRefreshToken(const std::string & client_id) { Poco::JSON::Parser parser; auto result = parser.parse(content); - auto obj = result.extract(); + const auto & obj = result.extract(); const std::string key = cacheKey(client_id); if (obj->has(key)) return obj->getValue(key); @@ -348,6 +79,8 @@ std::string readCachedRefreshToken(const std::string & client_id) return ""; } +} + void writeCachedRefreshToken(const std::string & client_id, const std::string & refresh_token) { const std::string path = cacheFilePath(); @@ -356,11 +89,8 @@ void writeCachedRefreshToken(const std::string & client_id, const std::string & namespace fs = std::filesystem; const fs::path cache_path(path); - - // Ensure directory exists fs::create_directories(cache_path.parent_path()); - // Read existing cache Poco::JSON::Object obj; { std::ifstream f(path); @@ -371,9 +101,9 @@ void writeCachedRefreshToken(const std::string & client_id, const std::string & { Poco::JSON::Parser parser; auto result = parser.parse(content); - auto existing = result.extract(); - for (auto it = existing->begin(); it != existing->end(); ++it) - obj.set(it->first, it->second); + const auto & existing = result.extract(); + for (const auto & [key, value] : *existing) + obj.set(key, value); } catch (...) { @@ -385,9 +115,6 @@ void writeCachedRefreshToken(const std::string & client_id, const std::string & obj.set(cacheKey(client_id), refresh_token); - // Write atomically: write to a temp file beside the cache, then rename. - // This prevents a partially-written file from being left world-readable if - // we are interrupted between the write and the chmod. const std::string tmp_path = path + ".tmp"; { std::ofstream out(tmp_path, std::ios::trunc); @@ -401,9 +128,8 @@ void writeCachedRefreshToken(const std::string & client_id, const std::string & fs::rename(tmp_path, cache_path, ec); } -// --------------------------------------------------------------------------- -// 7. tryRefreshToken -// --------------------------------------------------------------------------- +namespace +{ std::string tryRefreshToken(const OAuthCredentials & creds, const std::string & refresh_token) { @@ -411,11 +137,11 @@ std::string tryRefreshToken(const OAuthCredentials & creds, const std::string & { const std::string body = "grant_type=refresh_token" - "&client_id=" + urlEncode(creds.client_id) - + "&client_secret=" + urlEncode(creds.client_secret) - + "&refresh_token=" + urlEncode(refresh_token); + "&client_id=" + urlEncodeOAuth(creds.client_id) + + "&client_secret=" + urlEncodeOAuth(creds.client_secret) + + "&refresh_token=" + urlEncodeOAuth(refresh_token); - auto resp = postForm(creds.token_uri, body); + auto resp = postOAuthForm(creds.token_uri, body); if (resp->has("error")) { std::cerr << "Note: cached refresh token was rejected (" @@ -423,6 +149,8 @@ std::string tryRefreshToken(const OAuthCredentials & creds, const std::string & << "); re-authenticating.\n"; return ""; } + if (resp->has("refresh_token")) + writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); if (resp->has("id_token")) return resp->getValue("id_token"); } @@ -434,316 +162,8 @@ std::string tryRefreshToken(const OAuthCredentials & creds, const std::string & return ""; } -// --------------------------------------------------------------------------- -// 8. openBrowser -// --------------------------------------------------------------------------- - -void openBrowser(const std::string & url) -{ - // Always print so the user can copy-paste on headless / remote sessions. - std::cerr << "Opening browser for authentication.\n" - << "If the browser does not open, visit:\n " << url << "\n"; - -# if defined(__APPLE__) || defined(__linux__) - // Use posix_spawnp instead of system() to avoid shell-quoting issues. - const char * cmd = -# if defined(__APPLE__) - "open"; -# else - "xdg-open"; -# endif - const char * argv[] = {cmd, url.c_str(), nullptr}; - pid_t pid; - if (posix_spawnp(&pid, cmd, nullptr, nullptr, const_cast(argv), nullptr) == 0) - waitpid(pid, nullptr, 0); -# endif -} - -// --------------------------------------------------------------------------- -// 9. runAuthCodeFlow — auth code + PKCE, one-shot localhost callback server -// --------------------------------------------------------------------------- - -struct AuthCodeState -{ - std::mutex mtx; - std::condition_variable cv; - std::string code; - std::string error; - std::string received_state; // state= value echoed back by the provider - bool done = false; -}; - -class AuthCodeHandler : public Poco::Net::HTTPRequestHandler -{ -public: - explicit AuthCodeHandler(AuthCodeState & state_) : state(state_) { } - - void handleRequest(Poco::Net::HTTPServerRequest & request, Poco::Net::HTTPServerResponse & response) override - { - Poco::URI uri("http://localhost" + request.getURI()); - const auto params = uri.getQueryParameters(); - - std::string code; - std::string error; - std::string received_state; - for (const auto & [k, v] : params) - { - if (k == "code") - code = v; - else if (k == "error") - error = v; - else if (k == "state") - received_state = v; - } - - response.setStatus(Poco::Net::HTTPResponse::HTTP_OK); - response.setContentType("text/html"); - auto & out = response.send(); - if (!code.empty()) - out << "Authentication successful. You may close this tab."; - else - out << "Authentication failed: " << htmlEscape(error) << ""; - out.flush(); - - std::lock_guard lock(state.mtx); - state.code = code; - state.error = error; - state.received_state = received_state; - state.done = true; - state.cv.notify_one(); - } - -private: - AuthCodeState & state; -}; - -class AuthCodeHandlerFactory : public Poco::Net::HTTPRequestHandlerFactory -{ -public: - explicit AuthCodeHandlerFactory(AuthCodeState & state_) : state(state_) { } - - Poco::Net::HTTPRequestHandler * createRequestHandler(const Poco::Net::HTTPServerRequest &) override - { - return new AuthCodeHandler(state); - } - -private: - AuthCodeState & state; -}; - -std::string runAuthCodeFlow(const OAuthCredentials & creds) -{ - auto pkce = generatePKCE(); - - // Generate a random anti-CSRF state value per RFC 6749 §10.12. - unsigned char state_bytes[16]; - if (RAND_bytes(state_bytes, sizeof(state_bytes)) != 1) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "RAND_bytes failed for OAuth state"); - std::string csrf_state; - csrf_state.reserve(32); - for (unsigned char b : state_bytes) - { - constexpr char digits[] = "0123456789abcdef"; - csrf_state += digits[(b >> 4) & 0xF]; - csrf_state += digits[b & 0xF]; - } - - // Ephemeral callback server bound exclusively to the loopback interface. - // Binding to 127.0.0.1 (not 0.0.0.0) ensures network-adjacent attackers - // cannot race to deliver a forged callback even without the CSRF state check. - Poco::Net::ServerSocket server_socket; - server_socket.bind(Poco::Net::SocketAddress("127.0.0.1", 0), /*reuse_address=*/true); - server_socket.listen(1); - const uint16_t port = server_socket.address().port(); - const std::string redirect_uri = "http://localhost:" + std::to_string(port) + "/callback"; - - AuthCodeState state; - auto params = Poco::AutoPtr(new Poco::Net::HTTPServerParams()); - params->setMaxQueued(1); - params->setMaxThreads(1); - Poco::Net::HTTPServer server(new AuthCodeHandlerFactory(state), server_socket, params); - server.start(); - - // Build authorization URL — scope uses %20-encoded spaces per RFC 6749 §3.3. - // Google uses access_type=offline to request a refresh token rather than - // the standard offline_access scope (which it rejects as invalid). - const bool google = isGoogleProvider(creds); - const std::string scope = google - ? "openid email profile" - : "openid email profile offline_access"; - std::string auth_url - = creds.auth_uri - + "?response_type=code" - "&client_id=" + urlEncode(creds.client_id) - + "&redirect_uri=" + urlEncode(redirect_uri) - + "&code_challenge=" + pkce.challenge - + "&code_challenge_method=S256" - + "&scope=" + urlEncode(scope) - + "&state=" + csrf_state; - if (google) - auth_url += "&access_type=offline"; - - openBrowser(auth_url); - - // Wait up to 120 s for the browser callback. - bool timed_out = false; - std::string received_code; - std::string received_error; - std::string received_state; - { - std::unique_lock lock(state.mtx); - timed_out = !state.cv.wait_for(lock, std::chrono::seconds(120), [&] { return state.done; }); - received_code = state.code; - received_error = state.error; - received_state = state.received_state; - } - // Release the mutex before stopping the server to avoid a deadlock with - // the request handler that also acquires state.mtx. - server.stop(); - - if (timed_out) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 login timed out waiting for browser callback"); - if (!received_error.empty()) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 authorization error: {}", received_error); - if (received_code.empty()) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 callback did not contain an authorization code"); - if (received_state != csrf_state) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 CSRF check failed: unexpected state in callback"); - - // Exchange authorization code for tokens. - const std::string body - = "grant_type=authorization_code" - "&code=" + urlEncode(received_code) - + "&redirect_uri=" + urlEncode(redirect_uri) - + "&client_id=" + urlEncode(creds.client_id) - + "&client_secret=" + urlEncode(creds.client_secret) - + "&code_verifier=" + urlEncode(pkce.verifier); - - auto resp = postForm(creds.token_uri, body); - - if (resp->has("error")) - { - const std::string desc = resp->has("error_description") - ? resp->getValue("error_description") - : resp->getValue("error"); - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token exchange failed: {}", desc); - } - - if (resp->has("refresh_token")) - writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); - - if (!resp->has("id_token")) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "OAuth2 token response did not contain id_token"); - - return resp->getValue("id_token"); -} - -// --------------------------------------------------------------------------- -// 10. runDeviceFlow -// --------------------------------------------------------------------------- - -std::string runDeviceFlow(OAuthCredentials creds) -{ - if (creds.device_auth_uri.empty()) - creds.device_auth_uri = discoverDeviceEndpoint(creds.token_uri, creds.issuer); - - // Scope uses %20-encoded spaces per RFC 6749 §3.3. - // Google rejects offline_access as an invalid scope; it issues a refresh - // token automatically for device flow. Standard OIDC providers require it. - const std::string device_scope = isGoogleProvider(creds) - ? "openid email profile" - : "openid email profile offline_access"; - const std::string device_body - = "client_id=" + urlEncode(creds.client_id) - + "&scope=" + urlEncode(device_scope); - - auto device_resp = postForm(creds.device_auth_uri, device_body); - - if (device_resp->has("error")) - { - const std::string desc = device_resp->has("error_description") - ? device_resp->getValue("error_description") - : device_resp->getValue("error"); - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device authorization request failed: {}", desc); - } - - // Validate mandatory fields before accessing them; getValue() on a missing - // key returns an empty Var and throws "Can not convert empty value". - if (!device_resp->has("device_code") || !device_resp->has("user_code")) - throw Exception( - ErrorCodes::AUTHENTICATION_FAILED, - "Device authorization response from '{}' is missing required fields " - "(device_code / user_code). Response: {}", - creds.device_auth_uri, - [&]{ std::ostringstream ss; device_resp->stringify(ss); return ss.str(); }()); - - const std::string device_code = device_resp->getValue("device_code"); - const std::string user_code = device_resp->getValue("user_code"); - - // RFC 8628 uses "verification_uri"; Google's older device API uses "verification_url". - const std::string verification_uri = device_resp->has("verification_uri_complete") - ? device_resp->getValue("verification_uri_complete") - : device_resp->has("verification_uri") - ? device_resp->getValue("verification_uri") - : device_resp->has("verification_url") - ? device_resp->getValue("verification_url") - : throw Exception(ErrorCodes::AUTHENTICATION_FAILED, - "Device authorization response missing verification_uri / verification_url"); - - int interval = device_resp->has("interval") ? device_resp->getValue("interval") : 5; - int expires_in = device_resp->has("expires_in") ? device_resp->getValue("expires_in") : 300; - - std::cerr << "\nTo authenticate, visit:\n " << verification_uri - << "\nAnd enter code: " << user_code << "\n\n"; - - const auto deadline = std::chrono::steady_clock::now() + std::chrono::seconds(expires_in); - - while (std::chrono::steady_clock::now() < deadline) - { - std::this_thread::sleep_for(std::chrono::seconds(interval)); - - const std::string poll_body - = "grant_type=urn:ietf:params:oauth:grant-type:device_code" - "&device_code=" + urlEncode(device_code) - + "&client_id=" + urlEncode(creds.client_id) - + "&client_secret=" + urlEncode(creds.client_secret); - - auto resp = postForm(creds.token_uri, poll_body); - - if (resp->has("error")) - { - const std::string err = resp->getValue("error"); - if (err == "authorization_pending") - continue; - if (err == "slow_down") - { - interval += 5; - continue; - } - const std::string desc = resp->has("error_description") - ? resp->getValue("error_description") - : err; - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow error: {}", desc); - } - - if (resp->has("refresh_token")) - writeCachedRefreshToken(creds.client_id, resp->getValue("refresh_token")); - - if (!resp->has("id_token")) - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow token response did not contain id_token"); - - return resp->getValue("id_token"); - } - - throw Exception(ErrorCodes::AUTHENTICATION_FAILED, "Device flow timed out"); } -} // anonymous namespace - -// --------------------------------------------------------------------------- -// Public API -// --------------------------------------------------------------------------- - OAuthCredentials loadOAuthCredentials(const std::string & path) { std::ifstream f(path); @@ -770,7 +190,6 @@ OAuthCredentials loadOAuthCredentials(const std::string & path) auto root = parsed.extract(); - // Accept either "installed" (desktop) or "web" top-level key. Poco::JSON::Object::Ptr app; if (root->has("installed")) app = root->getObject("installed"); @@ -804,38 +223,33 @@ OAuthCredentials loadOAuthCredentials(const std::string & path) if (app->has("issuer")) creds.issuer = app->getValue("issuer"); - // Warn if any endpoint uses plain HTTP — token exchanges should be encrypted. - auto warnIfHttp = [&](const std::string & field, const std::string & uri) + auto warn_if_http = [&](const std::string & field, const std::string & uri) { - if (uri.size() >= 7 && uri.substr(0, 7) == "http://") + if (uri.starts_with("http://")) std::cerr << "Warning: OAuth credentials field '" << field << "' uses plain HTTP ('" << uri << "'). Token exchanges over HTTP expose client credentials.\n"; }; - warnIfHttp("token_uri", creds.token_uri); - warnIfHttp("auth_uri", creds.auth_uri); + warn_if_http("token_uri", creds.token_uri); + warn_if_http("auth_uri", creds.auth_uri); if (!creds.device_auth_uri.empty()) - warnIfHttp("device_authorization_uri", creds.device_auth_uri); + warn_if_http("device_authorization_uri", creds.device_auth_uri); return creds; } std::string obtainIDToken(const OAuthCredentials & creds, OAuthFlowMode mode) { - // 1. Try cached refresh token silently. - const std::string cached_refresh = readCachedRefreshToken(creds.client_id); + const std::string cached_refresh = readCachedRefreshTokenImpl(creds.client_id); if (!cached_refresh.empty()) { const std::string id_token = tryRefreshToken(creds, cached_refresh); if (!id_token.empty()) return id_token; - // Refresh token expired or revoked — fall through to interactive flow. } - // 2. Run interactive flow. if (mode == OAuthFlowMode::Device) - return runDeviceFlow(creds); - else - return runAuthCodeFlow(creds); + return runOAuthDeviceFlow(creds); + return runOAuthAuthCodeFlow(creds); } } // namespace DB diff --git a/src/Client/OAuthProviderPolicy.cpp b/src/Client/OAuthProviderPolicy.cpp new file mode 100644 index 000000000000..1bf589e95912 --- /dev/null +++ b/src/Client/OAuthProviderPolicy.cpp @@ -0,0 +1,127 @@ +#include +#include + +#if USE_JWT_CPP && USE_SSL + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace DB +{ + +namespace ErrorCodes +{ +extern const int AUTHENTICATION_FAILED; +} + +namespace +{ + +constexpr int HTTP_TIMEOUT_SECONDS = 30; + +std::string fetchDeviceEndpointFromIssuer(const std::string & issuer) +{ + const std::string discovery_url = issuer + "/.well-known/openid-configuration"; + Poco::URI disc_uri(discovery_url); + + Poco::Net::HTTPRequest request(Poco::Net::HTTPRequest::HTTP_GET, disc_uri.getPathAndQuery()); + Poco::Net::HTTPResponse response; + std::string body; + + if (disc_uri.getScheme() == "https") + { + Poco::Net::Context::Ptr ctx = Poco::Net::SSLManager::instance().defaultClientContext(); + Poco::Net::HTTPSClientSession session(disc_uri.getHost(), disc_uri.getPort(), ctx); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + session.sendRequest(request); + auto & stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(stream, body); + } + else + { + Poco::Net::HTTPClientSession session(disc_uri.getHost(), disc_uri.getPort()); + session.setTimeout(Poco::Timespan(HTTP_TIMEOUT_SECONDS, 0)); + session.sendRequest(request); + auto & stream = session.receiveResponse(response); + Poco::StreamCopier::copyToString(stream, body); + } + + if (response.getStatus() != Poco::Net::HTTPResponse::HTTP_OK) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OIDC discovery failed for '{}': {} {}", + discovery_url, + static_cast(response.getStatus()), + response.getReason()); + + Poco::JSON::Parser parser; + auto result = parser.parse(body); + const auto & obj = result.extract(); + + if (!obj->has("device_authorization_endpoint")) + throw Exception( + ErrorCodes::AUTHENTICATION_FAILED, + "OIDC discovery document at '{}' does not contain device_authorization_endpoint", + discovery_url); + + return obj->getValue("device_authorization_endpoint"); +} + +std::string inferIssuerFromTokenUri(const std::string & token_uri) +{ + Poco::URI uri(token_uri); + + std::string issuer = uri.getScheme() + "://" + uri.getHost(); + if (uri.getPort() != 0 + && !((uri.getScheme() == "https" && uri.getPort() == 443) + || (uri.getScheme() == "http" && uri.getPort() == 80))) + issuer += ":" + std::to_string(uri.getPort()); + + const auto & path = uri.getPath(); + const auto last_slash = path.rfind('/'); + if (last_slash != std::string::npos && last_slash != 0) + issuer += path.substr(0, last_slash); + + return issuer; +} + +} + +std::unique_ptr IOAuthProviderPolicy::create(const OAuthCredentials & creds) +{ + if (GoogleOAuthProviderPolicy::matches(creds)) + return std::make_unique(); + return std::make_unique(); +} + +std::string GoogleOAuthProviderPolicy::resolveDeviceAuthorizationEndpoint(const OAuthCredentials & creds) const +{ + if (!creds.device_auth_uri.empty()) + return creds.device_auth_uri; + + const std::string issuer = creds.issuer.empty() ? "https://accounts.google.com" : creds.issuer; + return fetchDeviceEndpointFromIssuer(issuer); +} + +std::string GenericOAuthProviderPolicy::resolveDeviceAuthorizationEndpoint(const OAuthCredentials & creds) const +{ + if (!creds.device_auth_uri.empty()) + return creds.device_auth_uri; + + const std::string issuer = creds.issuer.empty() ? inferIssuerFromTokenUri(creds.token_uri) : creds.issuer; + return fetchDeviceEndpointFromIssuer(issuer); +} + +} // namespace DB + +#endif // USE_JWT_CPP && USE_SSL diff --git a/src/Client/OAuthProviderPolicy.h b/src/Client/OAuthProviderPolicy.h new file mode 100644 index 000000000000..cbfb04e802c1 --- /dev/null +++ b/src/Client/OAuthProviderPolicy.h @@ -0,0 +1,58 @@ +#pragma once + +#include +#include + +#if USE_JWT_CPP && USE_SSL + +#include + +#include +#include + +namespace DB +{ + +/// Provider-specific behavior for OAuth/OIDC flows. +/// To add a new provider: subclass, implement all virtuals, add matches() check, +/// and register in IOAuthProviderPolicy::create(). +class IOAuthProviderPolicy +{ +public: + virtual ~IOAuthProviderPolicy() = default; + + virtual std::string getAuthCodeScope() const = 0; + virtual bool useAccessTypeOfflineForAuthCode() const = 0; + virtual std::string getDeviceScope() const = 0; + virtual std::string resolveDeviceAuthorizationEndpoint(const OAuthCredentials & creds) const = 0; + + static std::unique_ptr create(const OAuthCredentials & creds); +}; + +class GoogleOAuthProviderPolicy final : public IOAuthProviderPolicy +{ +public: + static bool matches(const OAuthCredentials & creds) + { + const std::string & host = Poco::URI(creds.token_uri).getHost(); + return host == "oauth2.googleapis.com" || host == "accounts.google.com"; + } + + std::string getAuthCodeScope() const override { return "openid email profile"; } + bool useAccessTypeOfflineForAuthCode() const override { return true; } + std::string getDeviceScope() const override { return "openid email profile"; } + std::string resolveDeviceAuthorizationEndpoint(const OAuthCredentials & creds) const override; +}; + +class GenericOAuthProviderPolicy final : public IOAuthProviderPolicy +{ +public: + std::string getAuthCodeScope() const override { return "openid email profile offline_access"; } + bool useAccessTypeOfflineForAuthCode() const override { return false; } + std::string getDeviceScope() const override { return "openid email profile offline_access"; } + std::string resolveDeviceAuthorizationEndpoint(const OAuthCredentials & creds) const override; +}; + +} + +#endif // USE_JWT_CPP && USE_SSL From fb84986b545a7d8f8bd2e382cc57566baf7b232e Mon Sep 17 00:00:00 2001 From: Andrey Zvonov Date: Fri, 17 Apr 2026 15:18:48 +0200 Subject: [PATCH 3/4] fix test --- tests/integration/helpers/cluster.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/helpers/cluster.py b/tests/integration/helpers/cluster.py index 55633488505c..6e102818d50e 100644 --- a/tests/integration/helpers/cluster.py +++ b/tests/integration/helpers/cluster.py @@ -1800,7 +1800,7 @@ def setup_keycloak_cmd(self, instance, env_variables, docker_compose_yml_dir): self.with_keycloak = True env_variables["KEYCLOAK_EXTERNAL_PORT"] = str(self.keycloak_port) env_variables["KEYCLOAK_REALM_FILE"] = p.join( - p.dirname(instance.path), + self.base_dir, "keycloak", "realm-export.json", ) From f7b4dfff54426ff809e2e6b251499d478165e874 Mon Sep 17 00:00:00 2001 From: Andrey Zvonov Date: Sun, 19 Apr 2026 02:28:13 +0200 Subject: [PATCH 4/4] fix tests(2) --- .../keycloak/realm-export.json | 36 ++-------- tests/integration/test_keycloak_auth/test.py | 68 ++++++++++++++++--- 2 files changed, 63 insertions(+), 41 deletions(-) diff --git a/tests/integration/test_keycloak_auth/keycloak/realm-export.json b/tests/integration/test_keycloak_auth/keycloak/realm-export.json index 2257e7b548db..c3067b6f65ec 100644 --- a/tests/integration/test_keycloak_auth/keycloak/realm-export.json +++ b/tests/integration/test_keycloak_auth/keycloak/realm-export.json @@ -24,6 +24,11 @@ { "username": "alice", "enabled": true, + "emailVerified": true, + "email": "alice@example.com", + "firstName": "Alice", + "lastName": "Tester", + "requiredActions": [], "credentials": [ { "type": "password", @@ -31,6 +36,7 @@ "temporary": false } ], + "realmRoles": ["offline_access", "uma_authorization", "default-roles-clickhouse-test"], "groups": ["analysts"] } ], @@ -38,35 +44,5 @@ { "name": "analysts" } - ], - "clientScopes": [ - { - "name": "groups", - "protocol": "openid-connect", - "attributes": { - "include.in.token.scope": "true" - }, - "protocolMappers": [ - { - "name": "groups", - "protocol": "openid-connect", - "protocolMapper": "oidc-group-membership-mapper", - "config": { - "full.path": "false", - "id.token.claim": "true", - "access.token.claim": "true", - "claim.name": "groups", - "userinfo.token.claim": "true" - } - } - ] - } - ], - "defaultDefaultClientScopes": [ - "profile", - "email", - "roles", - "web-origins", - "groups" ] } diff --git a/tests/integration/test_keycloak_auth/test.py b/tests/integration/test_keycloak_auth/test.py index bbd8a71907d7..46a92071adb4 100644 --- a/tests/integration/test_keycloak_auth/test.py +++ b/tests/integration/test_keycloak_auth/test.py @@ -218,12 +218,60 @@ def _approve_device_code_via_browser( s = requests.Session() - def get_form(html): - """Return (action_url, field_dict) for the first in *html*.""" + def _strip_secure_flag(session): + """Keycloak >= 25 emits Set-Cookie with Secure;SameSite=None on every + response, but the integration tests reach Keycloak over plain HTTP. + ``requests`` honors the Secure flag and refuses to resend those cookies + on the next HTTP hop, which causes Keycloak to lose its session and + return ``cookie_not_found``. Clear the flag after every response so the + cookies are sent on subsequent HTTP requests.""" + for cookie in session.cookies: + if getattr(cookie, "secure", False): + cookie.secure = False + + def _follow(method, url, **kw): + """Manually walk redirects so we can strip the Secure flag between + hops; ``requests`` follows redirects internally before our hook can + run, which is too late once the chain has dropped a Secure cookie.""" + kw.setdefault("timeout", 30) + kw["allow_redirects"] = False + for _ in range(20): + r = s.request(method, url, **kw) + _strip_secure_flag(s) + if r.status_code not in (301, 302, 303, 307, 308): + return r + loc = r.headers.get("Location") + if not loc: + return r + if loc.startswith("/"): + from urllib.parse import urlparse + parsed = urlparse(url) + url = f"{parsed.scheme}://{parsed.netloc}{loc}" + else: + url = loc + method = "GET" + kw.pop("data", None) + kw.pop("json", None) + kw.pop("params", None) + raise RuntimeError("Too many redirects") + + def get(url, **kw): + return _follow("GET", url, **kw) + + def post(url, **kw): + return _follow("POST", url, **kw) + + def get_form(html, base_url=None): + """Return (action_url, field_dict) for the first in *html*. + + Resolves relative ``action`` URLs against *base_url* when provided.""" m = re.search(r']*\baction="([^"]+)"', html) if not m: return None, {} action_url = html_unescape(m.group(1)) + if base_url and not re.match(r"^https?://", action_url): + from urllib.parse import urljoin + action_url = urljoin(base_url, action_url) fields = {} for inp in re.findall(r"]+>", html): n = re.search(r'\bname="([^"]+)"', inp) @@ -235,40 +283,38 @@ def get_form(html): # Step 1: Navigate to the device endpoint. Keycloak redirects to a login # page when the user_code query parameter is provided and valid. - r = s.get( + r = get( f"{keycloak_base_url}/realms/{realm}/device", params={"user_code": user_code}, - allow_redirects=True, - timeout=30, ) r.raise_for_status() # Step 1a: If Keycloak shows a user-code entry form first (no user_code # in the redirect), fill it in and submit. if 'name="device_user_code"' in r.text or 'name="user_code"' in r.text: - action, fields = get_form(r.text) + action, fields = get_form(r.text, base_url=r.url) fields["device_user_code"] = user_code fields["user_code"] = user_code - r = s.post(action, data=fields, allow_redirects=True, timeout=30) + r = post(action, data=fields) r.raise_for_status() # Step 2: We should now be on the login page. Submit credentials. assert 'type="password"' in r.text, ( f"Expected Keycloak login page, got:\n{r.text[:800]}" ) - action, fields = get_form(r.text) + action, fields = get_form(r.text, base_url=r.url) fields["username"] = username fields["password"] = password - r = s.post(action, data=fields, allow_redirects=True, timeout=30) + r = post(action, data=fields) r.raise_for_status() # Step 3: Submit the device consent / grant form. Keycloak renders a # "Do you want to grant access?" page with an `accept` submit button. - action, fields = get_form(r.text) + action, fields = get_form(r.text, base_url=r.url) if action: if "accept" not in fields: fields["accept"] = "" - s.post(action, data=fields, allow_redirects=True, timeout=30) + post(action, data=fields) def test_device_flow_initiation(started_cluster):