diff --git a/.gitignore b/.gitignore index 1947c24fc9d..cdd12846c0b 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ src/schema.rs.orig /blob-report/ /playwright/.cache/ +docs/superpowers/ diff --git a/Dockerfile.integration b/Dockerfile.integration new file mode 100644 index 00000000000..06e04222020 --- /dev/null +++ b/Dockerfile.integration @@ -0,0 +1,42 @@ +# Image for running the crates.io server in integration tests. +# Debug build for fast compilation and full debug output. + +ARG RUST_VERSION=1.94.1 + +FROM rust:${RUST_VERSION} + +RUN cargo install diesel_cli --version 2.3.7 --no-default-features --features postgres + +WORKDIR /app +COPY . /app +RUN cargo build --bin server + +RUN cp target/debug/server /usr/local/bin/crates-io-server + +EXPOSE 8888 + +RUN cat > /diesel.toml << 'TOML' +[print_schema] +file = "/dev/null" +TOML + +RUN cat > /entrypoint.sh << 'EOF' +#!/bin/sh +set -e + +# Use a minimal diesel config that skips schema regeneration -- +# the schema.rs is already baked into the binary at build time. +export DIESEL_CONFIG_FILE=/diesel.toml + +until diesel migration run 2>&1; do + echo "waiting for postgres..." 
>&2 + sleep 2 +done + +./script/init-local-index.sh 2>/dev/null || true + +exec crates-io-server +EOF + +RUN chmod +x /entrypoint.sh +ENTRYPOINT ["/entrypoint.sh"] diff --git a/crates/crates_io_database/Cargo.toml b/crates/crates_io_database/Cargo.toml index 1156bdb165f..14facdf9cb6 100644 --- a/crates/crates_io_database/Cargo.toml +++ b/crates/crates_io_database/Cargo.toml @@ -31,4 +31,4 @@ claims = "=0.8.0" crates_io_test_db = { path = "../crates_io_test_db" } googletest = "=0.14.2" insta = { version = "=1.47.2", features = ["filters", "json"] } -tokio = { version = "=1.52.0", features = ["macros", "rt"] } +tokio = { version = "=1.52.0", features = ["macros", "rt", "rt-multi-thread"] } diff --git a/crates/crates_io_database/src/models/mod.rs b/crates/crates_io_database/src/models/mod.rs index dc3d486081b..b44bbe30dbd 100644 --- a/crates/crates_io_database/src/models/mod.rs +++ b/crates/crates_io_database/src/models/mod.rs @@ -14,6 +14,7 @@ pub use self::email::{Email, NewEmail}; pub use self::follow::Follow; pub use self::keyword::{CrateKeyword, Keyword}; pub use self::krate::{Crate, CrateName, NewCrate}; +pub use self::oauth_provider::{OAuthProviderId, UnknownOAuthProvider}; pub use self::owner::{CrateOwner, Owner, OwnerKind}; pub use self::team::{NewTeam, Team}; pub use self::token::ApiToken; @@ -35,6 +36,7 @@ mod email; mod follow; mod keyword; pub mod krate; +mod oauth_provider; mod owner; pub mod team; pub mod token; diff --git a/crates/crates_io_database/src/models/oauth_provider.rs b/crates/crates_io_database/src/models/oauth_provider.rs new file mode 100644 index 00000000000..3660675a88a --- /dev/null +++ b/crates/crates_io_database/src/models/oauth_provider.rs @@ -0,0 +1,102 @@ +use std::io::Write; +use std::str::FromStr; + +use diesel::deserialize::{self, FromSql}; +use diesel::pg::{Pg, PgValue}; +use diesel::query_builder::QueryId; +use diesel::serialize::{self, IsNull, Output, ToSql}; + +use crate::schema::sql_types::OauthProvider as OauthProviderSql; 
+ +// Diesel's `#[derive(SqlType)]` does not emit `QueryId`. Binding an +// `OAuthProviderId` value into a query path requires it, so we implement it +// here rather than patching generated schema.rs. +impl QueryId for OauthProviderSql { + type QueryId = OauthProviderSql; + const HAS_STATIC_QUERY_ID: bool = true; +} + +/// Identifier for an OAuth provider that a `User` can be associated with. +/// +/// Maps to the `oauth_provider` Postgres enum type. The `OAuthProvider` +/// trait in the main crate represents provider *behavior*; this enum +/// represents provider *identity* (which provider a row refers to). +#[derive( + Debug, + Copy, + Clone, + PartialEq, + Eq, + Hash, + serde::Serialize, + diesel::FromSqlRow, + diesel::AsExpression, +)] +#[diesel(sql_type = OauthProviderSql)] +#[serde(rename_all = "snake_case")] +pub enum OAuthProviderId { + Github, +} + +impl OAuthProviderId { + pub fn as_str(&self) -> &'static str { + match self { + OAuthProviderId::Github => "github", + } + } +} + +impl FromStr for OAuthProviderId { + type Err = UnknownOAuthProvider; + + fn from_str(s: &str) -> Result<Self, Self::Err> { + match s { + "github" => Ok(OAuthProviderId::Github), + other => Err(UnknownOAuthProvider(other.to_string())), + } + } +} + +#[derive(Debug, thiserror::Error)] +#[error("unknown oauth provider: {0}")] +pub struct UnknownOAuthProvider(pub String); + +impl FromSql<OauthProviderSql, Pg> for OAuthProviderId { + fn from_sql(bytes: PgValue<'_>) -> deserialize::Result<Self> { + let s = std::str::from_utf8(bytes.as_bytes())?; + Ok(s.parse()?)
+ } +} + +impl ToSql<OauthProviderSql, Pg> for OAuthProviderId { + fn to_sql(&self, out: &mut Output<'_, '_, Pg>) -> serialize::Result { + out.write_all(self.as_str().as_bytes())?; + Ok(IsNull::No) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn as_str_roundtrips_through_from_str() { + let s = OAuthProviderId::Github.as_str(); + let parsed: OAuthProviderId = s.parse().expect("as_str output must parse back"); + assert_eq!(parsed, OAuthProviderId::Github); + } + + #[test] + fn from_str_rejects_unknown_provider() { + let err = "gitlab" + .parse::<OAuthProviderId>() + .expect_err("unknown provider must fail"); + assert_eq!(err.0, "gitlab"); + } + + #[test] + fn serde_serializes_to_snake_case() { + let s = serde_json::to_string(&OAuthProviderId::Github).unwrap(); + assert_eq!(s, "\"github\""); + } +} diff --git a/crates/crates_io_database/src/models/user.rs b/crates/crates_io_database/src/models/user.rs index fa86bf8bd96..dd35b0ce8b4 100644 --- a/crates/crates_io_database/src/models/user.rs +++ b/crates/crates_io_database/src/models/user.rs @@ -7,7 +7,7 @@ use diesel::upsert::excluded; use diesel_async::{AsyncPgConnection, RunQueryDsl}; use serde::Serialize; -use crate::models::{Crate, CrateOwner, Email, Owner, OwnerKind}; +use crate::models::{Crate, CrateOwner, Email, OAuthProviderId, Owner, OwnerKind}; use crate::schema::{crate_owners, emails, oauth_github, users}; use crates_io_diesel_helpers::lower; @@ -25,6 +25,7 @@ pub struct User { pub account_lock_until: Option>, pub is_admin: bool, pub publish_notifications: bool, + pub primary_oauth_provider: OAuthProviderId, } impl User { @@ -54,6 +55,62 @@ impl User { Ok(users.collect()) } + /// Look up a user by their external OAuth identity. + /// + /// `provider` is the machine name of an OAuth provider (e.g., "github"). + /// `account_id` is the provider-native identifier as a string; each + /// provider's storage table parses it into the column's native type + /// (GitHub uses BIGINT; Bitbucket will use TEXT).
+ /// + /// Returns `Ok(None)` if no user matches. Returns `Ok(None)` (not an + /// error) when the account_id fails to parse for a provider that + /// expects a specific shape — the semantic is "is this a known user", + /// not "is this input well-formed". + pub async fn find_by_oauth_identity( + conn: &mut AsyncPgConnection, + provider: &str, + account_id: &str, + ) -> QueryResult<Option<User>> { + match provider { + // Must match `crates_io::oauth::github_provider::PROVIDER_NAME`. + // Kept as a literal here because this crate can't depend on the + // main crate without creating a circular dependency. + "github" => { + let Ok(gh_id) = account_id.parse::<i64>() else { + tracing::debug!( + provider, + account_id, + "oauth identity lookup skipped: account_id not numeric", + ); + return Ok(None); + }; + users::table + .inner_join(oauth_github::table.on(oauth_github::user_id.eq(users::id))) + .filter(oauth_github::account_id.eq(gh_id)) + .select(User::as_select()) + .first(conn) + .await + .optional() + } + _ => Ok(None), + } + } + + /// Fetches the encrypted OAuth token stored in `oauth_github` for this user. + /// + /// All token reads now go through this table rather than `users.gh_encrypted_token` + /// so that the read-path works correctly after the Tier 1 identity cutover. + pub async fn github_encrypted_token( + &self, + conn: &mut AsyncPgConnection, + ) -> QueryResult<Vec<u8>> { + oauth_github::table + .filter(oauth_github::user_id.eq(self.id)) + .select(oauth_github::encrypted_token) + .first(conn) + .await + } + /// Queries the database for the verified emails /// belonging to a given user pub async fn verified_email( @@ -91,12 +148,34 @@ pub struct NewUser<'a> { impl NewUser<'_> { /// Inserts the user into the database, or fails if the user already exists.
+ /// + /// Also inserts a corresponding `oauth_github` row so that the token + /// read-path (which now reads from `oauth_github.encrypted_token` instead + /// of `users.gh_encrypted_token`) works without a full OAuth login flow. pub async fn insert(&self, mut conn: &AsyncPgConnection) -> QueryResult<User> { - diesel::insert_into(users::table) + let user = diesel::insert_into(users::table) .values(self) .returning(User::as_returning()) .get_result(&mut conn) - .await + .await?; + + diesel::insert_into(oauth_github::table) + .values(( + oauth_github::account_id.eq(user.gh_id as i64), + oauth_github::user_id.eq(user.id), + oauth_github::login.eq(&user.gh_login), + oauth_github::encrypted_token.eq(&user.gh_encrypted_token), + )) + .on_conflict(oauth_github::account_id) + // Update the token on conflict so the token read-path (which now + // reads from oauth_github.encrypted_token) always has a fresh value. + // do_nothing() would silently skip the update, leaving a stale token. + .do_update() + .set(oauth_github::encrypted_token.eq(excluded(oauth_github::encrypted_token))) + .execute(&mut conn) + .await?; + + Ok(user) } /// Inserts the user into the database, or updates an existing one.
@@ -198,3 +277,109 @@ impl NewOauthGithub<'_> { .await } } + +#[cfg(test)] +mod tests { + use super::*; + use crates_io_test_db::TestDatabase; + use diesel_async::RunQueryDsl; + + async fn setup() -> (TestDatabase, AsyncPgConnection) { + let db = TestDatabase::new(); + let conn = db.async_connect().await; + (db, conn) + } + + #[tokio::test(flavor = "multi_thread")] + async fn find_by_oauth_identity_returns_user_for_known_github_account() { + let (_db, mut conn) = setup().await; + + let user_id = diesel::insert_into(users::table) + .values(( + users::gh_id.eq(1001), + users::gh_login.eq("alice"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + )) + .returning(users::id) + .get_result::(&mut conn) + .await + .unwrap(); + + diesel::insert_into(oauth_github::table) + .values(( + oauth_github::account_id.eq(1001i64), + oauth_github::user_id.eq(user_id), + oauth_github::login.eq("alice"), + oauth_github::encrypted_token.eq(vec![0u8; 32]), + )) + .execute(&mut conn) + .await + .unwrap(); + + let result = User::find_by_oauth_identity(&mut conn, "github", "1001") + .await + .unwrap(); + + assert!(result.is_some(), "expected Some(user), got None"); + assert_eq!(result.unwrap().id, user_id); + } + + #[tokio::test(flavor = "multi_thread")] + async fn find_by_oauth_identity_returns_none_for_unknown_provider() { + let (_db, mut conn) = setup().await; + + let result = User::find_by_oauth_identity(&mut conn, "bitbucket", "some-account") + .await + .unwrap(); + + assert!(result.is_none(), "expected None for unknown provider, got {result:?}"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn find_by_oauth_identity_rejects_non_numeric_github_account_id() { + let (_db, mut conn) = setup().await; + + let result = User::find_by_oauth_identity(&mut conn, "github", "not-a-number") + .await + .unwrap(); + + assert!( + result.is_none(), + "expected Ok(None) for non-numeric github account_id, got {result:?}" + ); + } + + #[tokio::test(flavor = "multi_thread")] + async fn 
primary_oauth_provider_defaults_to_github_and_round_trips() { + let (_db, mut conn) = setup().await; + + let defaulted_id = diesel::insert_into(users::table) + .values(( + users::gh_id.eq(2001), + users::gh_login.eq("defaulted"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + )) + .returning(users::id) + .get_result::(&mut conn) + .await + .unwrap(); + + let defaulted = User::find(&mut conn, defaulted_id).await.unwrap(); + assert_eq!(defaulted.primary_oauth_provider, OAuthProviderId::Github); + + let explicit_id = diesel::insert_into(users::table) + .values(( + users::gh_id.eq(2002), + users::gh_login.eq("explicit"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + users::primary_oauth_provider.eq(OAuthProviderId::Github), + )) + .returning(users::id) + .get_result::(&mut conn) + .await + .unwrap(); + + let explicit = User::find(&mut conn, explicit_id).await.unwrap(); + assert_eq!(explicit.primary_oauth_provider, OAuthProviderId::Github); + } +} diff --git a/crates/crates_io_database/src/schema.rs b/crates/crates_io_database/src/schema.rs index d62115bc904..db4e48ff600 100644 --- a/crates/crates_io_database/src/schema.rs +++ b/crates/crates_io_database/src/schema.rs @@ -10,6 +10,13 @@ pub mod sql_types { #[derive(diesel::sql_types::SqlType)] #[diesel(postgres_type(name = "ltree"))] pub struct Ltree; + + /// The `oauth_provider` SQL type + /// + /// (Automatically generated by Diesel.) + #[derive(diesel::sql_types::SqlType)] + #[diesel(postgres_type(name = "oauth_provider"))] + pub struct OauthProvider; } diesel::table! { @@ -950,6 +957,7 @@ diesel::table! { diesel::table! { use diesel::sql_types::*; use diesel_full_text_search::Tsvector; + use super::sql_types::OauthProvider; /// Representation of the `users` table. /// @@ -1005,6 +1013,12 @@ diesel::table! { /// /// (Automatically generated by Diesel.) name -> Nullable, + /// The `primary_oauth_provider` column of the `users` table. + /// + /// Its SQL type is `OauthProvider`. 
+ /// + /// (Automatically generated by Diesel.) + primary_oauth_provider -> OauthProvider, /// Whether or not the user wants to receive notifications when a package they own is published publish_notifications -> Bool, } diff --git a/docker-compose.yml b/docker-compose.yml index e52ed0db932..b20c57bd785 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -47,6 +47,29 @@ services: depends_on: - backend + integration: + build: + context: . + dockerfile: Dockerfile.integration + env_file: .env + environment: + DEV_DOCKER: "true" + DATABASE_URL: postgres://postgres:password@postgres/cargo_registry_test + GIT_REPO_URL: file:///app/tmp/index-bare + WEB_ALLOWED_ORIGINS: "http://localhost:9888" + RUST_LOG: debug + depends_on: + postgres: + condition: service_started + ports: + - 127.0.0.1:9888:8888 + healthcheck: + test: ["CMD-SHELL", "curl -sf http://localhost:8888/api/v1/summary || exit 1"] + interval: 5s + timeout: 3s + retries: 30 + start_period: 120s + frontend: build: context: . diff --git a/migrations/2026-04-17-154509_create_oauth_provider_enum/down.sql b/migrations/2026-04-17-154509_create_oauth_provider_enum/down.sql new file mode 100644 index 00000000000..3234a956b14 --- /dev/null +++ b/migrations/2026-04-17-154509_create_oauth_provider_enum/down.sql @@ -0,0 +1 @@ +DROP TYPE oauth_provider; diff --git a/migrations/2026-04-17-154509_create_oauth_provider_enum/up.sql b/migrations/2026-04-17-154509_create_oauth_provider_enum/up.sql new file mode 100644 index 00000000000..905c0353deb --- /dev/null +++ b/migrations/2026-04-17-154509_create_oauth_provider_enum/up.sql @@ -0,0 +1,5 @@ +-- safety-assured:start +CREATE TYPE oauth_provider AS ENUM ('github'); +-- safety-assured:end + +comment on type oauth_provider is 'OAuth identity providers supported by crates.io'; diff --git a/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/down.sql b/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/down.sql new file mode 100644 index 
00000000000..12dd064bbbf --- /dev/null +++ b/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/down.sql @@ -0,0 +1 @@ +ALTER TABLE users DROP COLUMN primary_oauth_provider; diff --git a/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/up.sql b/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/up.sql new file mode 100644 index 00000000000..2f3f886f685 --- /dev/null +++ b/migrations/2026-04-22-123132_add_primary_oauth_provider_to_users/up.sql @@ -0,0 +1,10 @@ +SET LOCAL lock_timeout = '10s'; +SET LOCAL statement_timeout = '120s'; + +-- Record which OAuth provider a user treats as their primary identity. +-- For every existing user this is 'github' (the only login path to date), +-- so NOT NULL DEFAULT 'github' is accurate and avoids a separate backfill. +-- PG 11+ optimizes ADD COLUMN ... NOT NULL DEFAULT as a +-- metadata-only operation, so this does not rewrite the table. +ALTER TABLE users + ADD COLUMN primary_oauth_provider oauth_provider NOT NULL DEFAULT 'github'; diff --git a/src/app.rs b/src/app.rs index 9e568eee9a3..9ff43e154de 100644 --- a/src/app.rs +++ b/src/app.rs @@ -29,6 +29,24 @@ type DeadpoolResult = Result< diesel_async::pooled_connection::deadpool::PoolError, >; +/// Creates the GitHub OAuth2 BasicClient from server config. +/// +/// Centralizes the GitHub OAuth configuration used by both `App.github_oauth` +/// and `GitHubProvider` in the registry, avoiding duplication. 
+pub fn build_github_oauth_client( + config: &config::Server, +) -> BasicClient { + use oauth2::{AuthUrl, TokenUrl}; + + let auth_url = AuthUrl::new("https://github.com/login/oauth/authorize".into()).unwrap(); + let token_url = TokenUrl::new("https://github.com/login/oauth/access_token".into()).unwrap(); + + BasicClient::new(config.gh_client_id.clone()) + .set_client_secret(config.gh_client_secret.clone()) + .set_auth_uri(auth_url) + .set_token_uri(token_url) +} + /// The `App` struct holds the main components of the application like /// the database connection pool and configurations #[derive(Builder)] @@ -40,12 +58,17 @@ pub struct App { pub replica_database: Option>, /// GitHub API client - pub github: Box, + pub github: Arc, /// The GitHub OAuth2 configuration pub github_oauth: BasicClient, + /// Registry of OAuth providers (GitHub, etc.) used by the session controller. + /// + /// Populated at startup; the session controller resolves providers by name. + pub oauth_providers: crate::oauth::registry::ProviderRegistry, + /// OIDC key stores for "Trusted Publishing" /// /// This is a map of OIDC key stores, where the key is the issuer URL and @@ -82,18 +105,7 @@ impl AppBuilder { where S::GithubOauth: app_builder::IsUnset, { - use oauth2::{AuthUrl, TokenUrl}; - - let auth_url = "https://github.com/login/oauth/authorize"; - let auth_url = AuthUrl::new(auth_url.into()).unwrap(); - let token_url = "https://github.com/login/oauth/access_token"; - let token_url = TokenUrl::new(token_url.into()).unwrap(); - - let github_oauth = BasicClient::new(config.gh_client_id.clone()) - .set_client_secret(config.gh_client_secret.clone()) - .set_auth_uri(auth_url) - .set_token_uri(token_url); - + let github_oauth = build_github_oauth_client(config); self.github_oauth(github_oauth) } diff --git a/src/auth.rs b/src/auth.rs index 93470dd2274..4d3dc1473d6 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -265,7 +265,9 @@ async fn authenticate_via_cookie( .get::() .expect("missing cookie 
session"); - let user_id_from_session = session.get("user_id").and_then(|s| s.parse::().ok()); + let user_id_from_session = session + .get(crate::controllers::session::SESSION_KEY_USER_ID) + .and_then(|s| s.parse::().ok()); let Some(id) = user_id_from_session else { return Ok(None); }; diff --git a/src/bin/server.rs b/src/bin/server.rs index da3a14f3f5d..ca08e566c1e 100644 --- a/src/bin/server.rs +++ b/src/bin/server.rs @@ -2,6 +2,7 @@ extern crate tracing; use crates_io::middleware::normalize_path::normalize_path; +use crates_io::oauth::preflight::check_oauth_github_backfill; use crates_io::{App, Emails, metrics::LogEncoder}; use std::{sync::Arc, time::Duration}; @@ -33,13 +34,24 @@ fn main() -> anyhow::Result<()> { let user_agent = crates_io_version::user_agent(); let client = Client::builder().user_agent(user_agent).build()?; - let github = RealGitHubClient::new(client); - let github = Box::new(github); + let github: std::sync::Arc = + std::sync::Arc::new(RealGitHubClient::new(client.clone())); + + let github_oauth_for_provider = crates_io::app::build_github_oauth_client(&config); + let mut oauth_providers = crates_io::oauth::registry::ProviderRegistry::new(); + oauth_providers.register(std::sync::Arc::new( + crates_io::oauth::github_provider::GitHubProvider::new( + github_oauth_for_provider, + github.clone(), + client.clone(), + ), + )); let app = App::builder() .databases_from_config(&config.db) - .github(github) + .github(github.clone()) .github_oauth_from_config(&config) + .oauth_providers(oauth_providers) .trustpub_providers(&list("TRUSTPUB_PROVIDERS")?) .emails(emails) .storage_from_config(&config.storage) @@ -73,6 +85,18 @@ fn main() -> anyhow::Result<()> { // Block the main thread until the server has shutdown rt.block_on(async { + // Run startup preflight checks before accepting any connections. 
+ { + let mut conn = app + .db_write() + .await + .map_err(std::io::Error::other)?; + check_oauth_github_backfill(&mut conn) + .await + .map_err(std::io::Error::other)?; + tracing::info!("oauth_github preflight check passed"); + } + // Create a `TcpListener` using tokio. let listener = TcpListener::bind((app.config.ip, app.config.port)).await?; diff --git a/src/config/server.rs b/src/config/server.rs index 4ff9c5982e3..a4f4c85409b 100644 --- a/src/config/server.rs +++ b/src/config/server.rs @@ -3,7 +3,7 @@ use url::Url; use crate::Env; use crate::rate_limiter::{LimitedAction, RateLimiterConfig}; -use crate::util::gh_token_encryption::GitHubTokenEncryption; +use crate::util::gh_token_encryption::OauthTokenEncryption; use super::base::Base; use super::database_pools::DatabasePools; @@ -41,7 +41,7 @@ pub struct Server { pub session_key: cookie::Key, pub gh_client_id: ClientId, pub gh_client_secret: ClientSecret, - pub gh_token_encryption: GitHubTokenEncryption, + pub oauth_token_encryption: OauthTokenEncryption, pub max_upload_size: u32, pub max_unpack_size: u64, pub max_dependencies: usize, @@ -117,7 +117,8 @@ impl Server { /// - `SESSION_KEY`: The key used to sign and encrypt session cookies. /// - `GH_CLIENT_ID`: The client ID of the associated GitHub application. /// - `GH_CLIENT_SECRET`: The client secret of the associated GitHub application. - /// - `GITHUB_TOKEN_ENCRYPTION_KEY`: Key for encrypting GitHub access tokens (64 hex characters). + /// - `OAUTH_TOKEN_ENCRYPTION_KEY`: Key for encrypting OAuth access tokens (64 hex characters). + /// Falls back to deprecated `GITHUB_TOKEN_ENCRYPTION_KEY` if not set. /// - `BLOCKED_TRAFFIC`: A list of headers and environment variables to use for blocking /// traffic. See the `block_traffic` module for more documentation. /// - `DOWNLOADS_PERSIST_INTERVAL_MS`: how frequent to persist download counts (in ms). 
@@ -215,7 +216,7 @@ impl Server { session_key: cookie::Key::derive_from(required_var("SESSION_KEY")?.as_bytes()), gh_client_id: ClientId::new(required_var("GH_CLIENT_ID")?), gh_client_secret: ClientSecret::new(required_var("GH_CLIENT_SECRET")?), - gh_token_encryption: GitHubTokenEncryption::from_environment()?, + oauth_token_encryption: OauthTokenEncryption::from_environment()?, max_upload_size: 10 * 1024 * 1024, // 10 MB default file upload size limit max_unpack_size: 512 * 1024 * 1024, // 512 MB max when decompressed max_dependencies: DEFAULT_MAX_DEPENDENCIES, diff --git a/src/controllers/crate_owner_invitation.rs b/src/controllers/crate_owner_invitation.rs index 2ffdf86bc5a..9c67c1f1456 100644 --- a/src/controllers/crate_owner_invitation.rs +++ b/src/controllers/crate_owner_invitation.rs @@ -53,7 +53,7 @@ pub async fn list_crate_owner_invitations_for_user( let PrivateListResponse { invitations, users, .. - } = prepare_list(&app, &req, auth, ListFilter::InviteeId(user_id), &conn).await?; + } = prepare_list(&app, &req, auth, ListFilter::InviteeId(user_id), &mut conn).await?; // The schema for the private endpoints is converted to the schema used by v1 endpoints. let crate_owner_invitations = invitations @@ -115,7 +115,7 @@ pub async fn list_crate_owner_invitations( let auth = AuthCheck::only_cookie().check(&req, &mut conn).await?; let filter = params.try_into()?; - let list = prepare_list(&app, &req, auth, filter, &conn).await?; + let list = prepare_list(&app, &req, auth, filter, &mut conn).await?; Ok(Json(list)) } @@ -145,7 +145,7 @@ async fn prepare_list( req: &Parts, auth: Authentication, filter: ListFilter, - mut conn: &AsyncPgConnection, + mut conn: &mut AsyncPgConnection, ) -> AppResult { let pagination: PaginationOptions = PaginationOptions::builder() .enable_pages(false) @@ -166,8 +166,8 @@ async fn prepare_list( // Only allow crate owners to query pending invitations for their crate. 
let krate: Crate = Crate::by_name(&crate_name).first(&mut conn).await?; let owners = krate.owners(conn).await?; - let encryption = &state.config.gh_token_encryption; - if Rights::get(user, &*state.github, &owners, encryption).await? != Rights::Full { + let encryption = &state.config.oauth_token_encryption; + if Rights::get(user, &*state.github, &owners, encryption, &mut conn).await? != Rights::Full { let detail = "only crate owners can query pending invitations for their crate"; return Err(forbidden(detail)); } diff --git a/src/controllers/helpers/authorization.rs b/src/controllers/helpers/authorization.rs index 75a6854be86..62aa4f13792 100644 --- a/src/controllers/helpers/authorization.rs +++ b/src/controllers/helpers/authorization.rs @@ -1,7 +1,8 @@ use crate::models::{Owner, User}; -use crate::util::errors::{BoxedAppError, custom}; -use crate::util::gh_token_encryption::GitHubTokenEncryption; +use crate::util::errors::{BoxedAppError, custom, internal}; +use crate::util::gh_token_encryption::OauthTokenEncryption; use crates_io_github::{GitHubClient, GitHubError}; +use diesel_async::AsyncPgConnection; use http::StatusCode; /// Access rights to the crate (publishing and ownership management) @@ -26,10 +27,16 @@ impl Rights { user: &User, gh_client: &dyn GitHubClient, owners: &[Owner], - encryption: &GitHubTokenEncryption, + encryption: &OauthTokenEncryption, + conn: &mut AsyncPgConnection, ) -> Result { + let encrypted_token = user + .github_encrypted_token(conn) + .await + .map_err(|_| internal("could not find GitHub token for user"))?; + let token = encryption - .decrypt(&user.gh_encrypted_token) + .decrypt(&encrypted_token) .map_err(GitHubError::Other)?; let mut best = Self::None; diff --git a/src/controllers/krate/delete.rs b/src/controllers/krate/delete.rs index 5434b138f9f..51a1ed0c8c7 100644 --- a/src/controllers/krate/delete.rs +++ b/src/controllers/krate/delete.rs @@ -72,7 +72,7 @@ pub async fn delete_crate( // Check that the user is an owner of the crate 
(team owners are not allowed to delete crates) let user = auth.user(); let owners = krate.owners(&conn).await?; - match Rights::get(user, &*app.github, &owners, &app.config.gh_token_encryption).await? { + match Rights::get(user, &*app.github, &owners, &app.config.oauth_token_encryption, &mut conn).await? { Rights::Full => {} Rights::Publish => { let msg = "team members don't have permission to delete crates"; diff --git a/src/controllers/krate/owners.rs b/src/controllers/krate/owners.rs index 47a704031da..a7a05c26f8c 100644 --- a/src/controllers/krate/owners.rs +++ b/src/controllers/krate/owners.rs @@ -8,8 +8,8 @@ use crate::models::{ CrateOwner, NewCrateOwnerInvitation, NewCrateOwnerInvitationOutcome, NewTeam, krate::NewOwnerInvite, token::EndpointScope, }; -use crate::util::errors::{AppResult, BoxedAppError, bad_request, crate_not_found, custom}; -use crate::util::gh_token_encryption::GitHubTokenEncryption; +use crate::util::errors::{AppResult, BoxedAppError, bad_request, crate_not_found, custom, server_error}; +use crate::util::gh_token_encryption::OauthTokenEncryption; use crate::views::EncodableOwner; use crate::{App, app::AppState}; use crate::{auth::AuthCheck, email::EmailMessage}; @@ -207,7 +207,7 @@ async fn modify_owners( let owners = krate.owners(conn).await?; - match Rights::get(user, &*app.github, &owners, &app.config.gh_token_encryption).await? { + match Rights::get(user, &*app.github, &owners, &app.config.oauth_token_encryption, conn).await? { Rights::Full => {} // Yes! 
Rights::Publish => { @@ -328,7 +328,7 @@ async fn add_owner( login: &str, ) -> Result { if login.contains(':') { - let encryption = &app.config.gh_token_encryption; + let encryption = &app.config.oauth_token_encryption; add_team_owner(&*app.github, conn, req_user, krate, login, encryption).await } else { invite_user_owner(app, conn, req_user, krate, login).await @@ -372,7 +372,7 @@ async fn add_team_owner( req_user: &User, krate: &Crate, login: &str, - encryption: &GitHubTokenEncryption, + encryption: &OauthTokenEncryption, ) -> Result { // github:rust-lang:owners let mut chunks = login.split(':'); @@ -425,7 +425,7 @@ pub async fn create_or_update_github_team( org_name: &str, team_name: &str, req_user: &User, - encryption: &GitHubTokenEncryption, + encryption: &OauthTokenEncryption, ) -> AppResult { // GET orgs/:org/teams // check that `team` is the `slug` in results, and grab its data @@ -442,8 +442,16 @@ pub async fn create_or_update_github_team( ))); } + let encrypted_token = req_user + .github_encrypted_token(conn) + .await + .map_err(|err| { + warn!("Failed to load GitHub token for user {}: {err}", req_user.gh_login); + server_error("Internal server error") + })?; + let token = encryption - .decrypt(&req_user.gh_encrypted_token) + .decrypt(&encrypted_token) .map_err(|err| { custom( StatusCode::INTERNAL_SERVER_ERROR, diff --git a/src/controllers/krate/publish.rs b/src/controllers/krate/publish.rs index 33c718a0c2e..70aab60317b 100644 --- a/src/controllers/krate/publish.rs +++ b/src/controllers/krate/publish.rs @@ -466,7 +466,7 @@ pub async fn publish(app: AppState, req: Parts, body: Body) -> AppResult String { + crate::oauth::github_provider::PROVIDER_NAME.to_string() +} + +/// The JSON payload stored in the session under `"oauth_state"`. +#[derive(Debug, Serialize, Deserialize)] +struct OAuthStatePayload { + state: String, + provider: String, +} + /// Begin authentication flow. 
/// -/// This route will return an authorization URL for the GitHub OAuth flow including the crates.io +/// This route will return an authorization URL for the OAuth flow including the crates.io /// `client_id` and a randomly generated `state` secret. /// +/// An optional `?provider=` query param selects the OAuth provider (default: `"github"`). +/// /// see #[utoipa::path( get, @@ -45,45 +70,62 @@ pub struct BeginResponse { tag = "session", responses((status = 200, description = "Successful Response", body = inline(BeginResponse))), )] -pub async fn begin_session(app: AppState, session: SessionExtension) -> Json { - let (url, state) = app - .github_oauth - .authorize_url(oauth2::CsrfToken::new_random) - .add_scope(Scope::new("read:org".to_string())) - .url(); +pub async fn begin_session( + app: AppState, + Query(query): Query, + session: SessionExtension, +) -> AppResult> { + let provider = app + .oauth_providers + .get(&query.provider) + .ok_or_else(not_found)?; - let state = state.secret().to_string(); - session.insert("github_oauth_state".to_string(), state.clone()); + let (url, csrf) = provider.authorize_url(); + + let payload = OAuthStatePayload { + state: csrf.secret().to_string(), + provider: query.provider, + }; + session.insert( + SESSION_KEY_OAUTH_STATE.to_string(), + serde_json::to_string(&payload).map_err(|e| { + error!("Failed to serialize OAuth state payload: {e}"); + server_error("Internal server error") + })?, + ); let url = url.to_string(); - Json(BeginResponse { url, state }) + Ok(Json(BeginResponse { + url, + state: payload.state, + })) } #[derive(Clone, Debug, Deserialize, FromRequestParts, utoipa::IntoParams)] #[from_request(via(Query))] #[into_params(parameter_in = Query)] pub struct AuthorizeQuery { - /// Temporary code received from the GitHub API. + /// Temporary code received from the OAuth provider. #[param(value_type = String, example = "901dd10e07c7e9fa1cd5")] code: AuthorizationCode, - /// State parameter received from the GitHub API. 
+ /// State parameter received from the OAuth provider (CSRF token). #[param(value_type = String, example = "fYcUY3FMdUUz00FC7vLT7A")] state: CsrfToken, } /// Complete authentication flow. /// -/// This route is called from the GitHub API OAuth flow after the user accepted or rejected -/// the data access permissions. It will check the `state` parameter and then call the GitHub API -/// to exchange the temporary `code` for an API token. The API token is returned together with +/// This route is called from the OAuth provider after the user accepted or rejected +/// the data access permissions. It will check the `state` parameter and then call the provider +/// API to exchange the temporary `code` for an API token. The API token is returned together with /// the corresponding user information. /// /// see /// /// ## Query Parameters /// -/// - `code` – temporary code received from the GitHub API **(Required)** -/// - `state` – state parameter received from the GitHub API **(Required)** +/// - `code` – temporary code received from the OAuth provider **(Required)** +/// - `state` – state parameter received from the OAuth provider **(Required)** #[utoipa::path( get, path = "/api/private/session/authorize", @@ -97,51 +139,111 @@ pub async fn authorize_session( session: SessionExtension, req: Parts, ) -> AppResult> { - // Make sure that the state we just got matches the session state that we - // should have issued earlier. - let session_state = session.remove("github_oauth_state").map(CsrfToken::new); - if session_state.is_none_or(|state| query.state.secret() != state.secret()) { + // Read and parse the session state payload set during `begin_session`. 
+ let raw_payload = session + .remove(SESSION_KEY_OAUTH_STATE) + .ok_or_else(|| bad_request("invalid state parameter"))?; + + let payload: OAuthStatePayload = + serde_json::from_str(&raw_payload).map_err(|_| bad_request("invalid state parameter"))?; + + // Validate CSRF: the `state` query param must match the stored CSRF token. + if query.state.secret() != &payload.state { return Err(bad_request("invalid state parameter")); } - // Fetch the access token from GitHub using the code we just got - let client = ReqwestClient( - reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build()?, - ); + let provider = app + .oauth_providers + .get(&payload.provider) + .ok_or_else(|| bad_request("unknown oauth provider in session"))?; - let token = app - .github_oauth - .exchange_code(query.code) - .request_async(&client) + // Exchange the authorization code for an access token. + let token = provider + .exchange_code(query.code.secret()) .await - .map_err(|err| { - req.request_log().add("cause", err); - server_error("Error obtaining token") - })?; + .map_err(|err| map_provider_error(err, &req))?; - let token = token.access_token(); - - // Encrypt the GitHub access token - let encryption = &app.config.gh_token_encryption; + // Encrypt the access token before storing it. + let encryption = &app.config.oauth_token_encryption; let encrypted_token = encryption.encrypt(token.secret()).map_err(|error| { - error!("Failed to encrypt GitHub token: {error}"); + error!("Failed to encrypt OAuth token: {error}"); server_error("Internal server error") })?; - // Fetch the user info from GitHub using the access token we just got and create a user record - let ghuser = app.github.current_user(token).await?; + // Fetch the user's profile from the provider. 
+ let user_info = provider + .fetch_user_info(&token) + .await + .map_err(|err| map_provider_error(err, &req))?; let mut conn = app.db_write().await?; - let user = save_user_to_database(&ghuser, &encrypted_token, &app.emails, &mut conn).await?; + let user = save_identity_to_database( + &payload.provider, + &user_info, + &encrypted_token, + &app.emails, + &mut conn, + ) + .await?; - // Log in by setting a cookie and the middleware authentication - session.insert("user_id".to_string(), user.id.to_string()); + // Log in by setting a cookie and the middleware authentication. + session.insert(SESSION_KEY_USER_ID.to_string(), user.id.to_string()); super::user::me::get_authenticated_user(app, req).await } +/// Map a [`ProviderError`] to a [`BoxedAppError`], logging the error details. +fn map_provider_error(err: ProviderError, req: &Parts) -> BoxedAppError { + req.request_log().add("provider_error", format!("{err:?}")); + match err { + ProviderError::InvalidCode => bad_request("invalid oauth code"), + ProviderError::Unauthorized => bad_request("oauth token was rejected"), + ProviderError::Malformed(_) | ProviderError::Transient { .. } => { + server_error("Error obtaining token") + } + } +} + +/// Save a provider-agnostic [`UserInfo`] to the db. +/// +/// Right now only `"github"` is handled. Adapts back to the legacy +/// `GitHubUser` shape so the existing write path keeps working. +async fn save_identity_to_database( + provider_name: &str, + user_info: &UserInfo, + encrypted_token: &[u8], + emails: &Emails, + conn: &mut AsyncPgConnection, +) -> QueryResult { + match provider_name { + crate::oauth::github_provider::PROVIDER_NAME => { + // UserInfo.account_id is String (provider-agnostic), but GitHubUser.id + // is i32 because crates_io_github predates this trait. GitHub IDs are + // well within i32 range today (< 200M vs i32::MAX ~2.1B). When + // crates_io_github widens id to i64 this parse becomes a no-op change. 
+ let gh_id: i32 = user_info + .account_id + .parse() + .map_err(|_| diesel::result::Error::NotFound)?; + let gh_user = GitHubUser { + id: gh_id, + login: user_info.login.clone(), + name: user_info.name.clone(), + avatar_url: user_info.avatar_url.clone(), + email: user_info.email.clone(), + }; + save_user_to_database(&gh_user, encrypted_token, emails, conn).await + } + other => { + // Tier 2 will add Bitbucket here. Unknown provider names indicate a + // bug in registry/session pairing — return NotFound so the session + // controller propagates a 404 rather than crashing the worker thread. + error!(provider = other, "save_identity_to_database: no handler for provider"); + Err(diesel::result::Error::NotFound) + } + } +} + pub async fn save_user_to_database( user: &GitHubUser, encrypted_token: &[u8], @@ -248,7 +350,7 @@ async fn find_user_by_gh_id(mut conn: &AsyncPgConnection, gh_id: i32) -> QueryRe responses((status = 200, description = "Successful Response")), )] pub async fn end_session(session: SessionExtension) -> Json { - session.remove("user_id"); + session.remove(SESSION_KEY_USER_ID); Json(true) } diff --git a/src/controllers/trustpub/emails.rs b/src/controllers/trustpub/emails.rs index 38b80087e53..012b207d1a5 100644 --- a/src/controllers/trustpub/emails.rs +++ b/src/controllers/trustpub/emails.rs @@ -50,6 +50,7 @@ mod tests { use super::*; use chrono::Utc; use claims::assert_ok; + use crates_io_database::models::OAuthProviderId; use insta::assert_snapshot; fn test_user() -> User { @@ -64,6 +65,7 @@ mod tests { account_lock_until: None, is_admin: false, publish_notifications: true, + primary_oauth_provider: OAuthProviderId::Github, } } diff --git a/src/controllers/trustpub/github_configs/create.rs b/src/controllers/trustpub/github_configs/create.rs index 437f65e623d..9b2d7ed7a19 100644 --- a/src/controllers/trustpub/github_configs/create.rs +++ b/src/controllers/trustpub/github_configs/create.rs @@ -89,9 +89,16 @@ pub async fn create_trustpub_github_config( 
let owner = &json_config.repository_owner; - let encryption = &state.config.gh_token_encryption; - let gh_auth = &auth_user.gh_encrypted_token; - let gh_auth = encryption.decrypt(gh_auth).map_err(|err| { + let encryption = &state.config.oauth_token_encryption; + let encrypted_token = auth_user + .github_encrypted_token(&mut conn) + .await + .map_err(|err| { + let login = &auth_user.gh_login; + warn!("Failed to load GitHub token for user {login}: {err}"); + server_error("Internal server error") + })?; + let gh_auth = encryption.decrypt(&encrypted_token).map_err(|err| { let login = &auth_user.gh_login; warn!("Failed to decrypt GitHub token for user {login}: {err}"); server_error("Internal server error") diff --git a/src/controllers/version/docs.rs b/src/controllers/version/docs.rs index 51c5e3f1580..55f1ef5a5d1 100644 --- a/src/controllers/version/docs.rs +++ b/src/controllers/version/docs.rs @@ -36,8 +36,8 @@ pub async fn rebuild_version_docs( // Check that the user is an owner of the crate, or a team member (= publish rights) let user = auth.user(); let owners = krate.owners(&conn).await?; - let encryption = &app.config.gh_token_encryption; - if Rights::get(user, &*app.github, &owners, encryption).await? < Rights::Publish { + let encryption = &app.config.oauth_token_encryption; + if Rights::get(user, &*app.github, &owners, encryption, &mut conn).await? < Rights::Publish { return Err(custom( StatusCode::FORBIDDEN, "user doesn't have permission to trigger a docs rebuild", diff --git a/src/controllers/version/update.rs b/src/controllers/version/update.rs index 9f1b1694010..fb7650531f1 100644 --- a/src/controllers/version/update.rs +++ b/src/controllers/version/update.rs @@ -125,8 +125,8 @@ pub async fn perform_version_yank_update( let yanked = yanked.unwrap_or(version.yanked); - let encryption = &state.config.gh_token_encryption; - if Rights::get(user, &*state.github, &owners, encryption).await? 
< Rights::Publish { + let encryption = &state.config.oauth_token_encryption; + if Rights::get(user, &*state.github, &owners, encryption, conn).await? < Rights::Publish { if user.is_admin { let action = if yanked { "yanking" } else { "unyanking" }; warn!( diff --git a/src/lib.rs b/src/lib.rs index 022c2210fcf..49d11e30a14 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -30,6 +30,7 @@ pub mod index; mod licenses; pub mod metrics; pub mod middleware; +pub mod oauth; pub mod openapi; pub mod rate_limiter; mod router; diff --git a/src/oauth/github_provider.rs b/src/oauth/github_provider.rs new file mode 100644 index 00000000000..c72c493f3eb --- /dev/null +++ b/src/oauth/github_provider.rs @@ -0,0 +1,257 @@ +//! [`OAuthProvider`] implementation backed by GitHub OAuth2. +//! +//! This wraps the existing GitHub OAuth2 client ([`BasicClient`]) and the +//! [`GitHubClient`] API client that were already present in [`crate::app::App`]. +//! No new HTTP plumbing is introduced here — this is purely a delegation layer +//! that maps GitHub-specific types to the provider-agnostic trait surface. + +use std::sync::Arc; + +use async_trait::async_trait; +use crates_io_github::{GitHubClient, GitHubError}; +use oauth2::basic::{BasicClient, BasicErrorResponseType}; +use oauth2::{AccessToken, AuthorizationCode, CsrfToken, EndpointNotSet, EndpointSet, RequestTokenError, Scope, TokenResponse}; +use url::Url; + +use crate::util::oauth::ReqwestClient; + +use super::provider::{OAuthProvider, ProviderError, UserInfo}; + +/// Type alias matching the field type in [`crate::app::App`]. +pub type GithubBasicClient = + BasicClient; + +pub struct GitHubProvider { + oauth: GithubBasicClient, + client: Arc, + http: reqwest::Client, +} + +impl GitHubProvider { + pub fn new( + oauth: GithubBasicClient, + client: Arc, + http: reqwest::Client, + ) -> Self { + Self { oauth, client, http } + } +} + +/// The stable machine name for the GitHub provider. 
+/// +/// Used as the discriminator in `?provider=` query params, +/// `oauth_github` table routing, and `ProviderRegistry` lookup. +/// Defined here (the canonical impl) and re-used everywhere else +/// to prevent silent mismatches. +pub const PROVIDER_NAME: &str = "github"; + +#[async_trait] +impl OAuthProvider for GitHubProvider { + fn name(&self) -> &'static str { + PROVIDER_NAME + } + + fn authorize_url(&self) -> (Url, CsrfToken) { + self.oauth + .authorize_url(CsrfToken::new_random) + .add_scope(Scope::new("read:org".to_string())) + .url() + } + + async fn exchange_code(&self, code: &str) -> Result { + let http = ReqwestClient(self.http.clone()); + let token_result = self + .oauth + .exchange_code(AuthorizationCode::new(code.to_string())) + .request_async(&http) + .await; + + match token_result { + Ok(response) => Ok(response.access_token().clone()), + Err(RequestTokenError::Request(e)) => Err(ProviderError::Transient { + source: Box::new(e), + }), + Err(RequestTokenError::ServerResponse(resp)) => { + // check if this is an "invalid code" error via direct enum matching + let is_invalid_code = matches!(resp.error(), BasicErrorResponseType::InvalidGrant) + || matches!(resp.error(), BasicErrorResponseType::Extension(s) if s == "bad_verification_code"); + + if is_invalid_code { + Err(ProviderError::InvalidCode) + } else { + // Format the server error as a string — `StandardErrorResponse` + // implements `Display` but not `std::error::Error`. + Err(ProviderError::Malformed(format!("{resp}"))) + } + } + Err(RequestTokenError::Parse(e, _bytes)) => { + Err(ProviderError::Malformed(e.to_string())) + } + // The `RequestTokenError` enum is non-exhaustive; any future + // variant is reported as a malformed upstream response.
+ Err(e) => Err(ProviderError::Malformed(e.to_string())), + } + } + + async fn fetch_user_info(&self, token: &AccessToken) -> Result { + match self.client.current_user(token).await { + Ok(gh) => Ok(UserInfo { + account_id: gh.id.to_string(), + login: gh.login, + name: gh.name, + avatar_url: gh.avatar_url, + email: gh.email, + }), + Err(GitHubError::Unauthorized(_)) => Err(ProviderError::Unauthorized), + Err(e) => Err(ProviderError::Transient { + source: anyhow::Error::from(e).into(), + }), + } + } +} +#[cfg(test)] +mod tests { + use super::*; + use crates_io_github::{GitHubError, GitHubUser, MockGitHubClient}; + + fn build_test_oauth_client_with_token_url(token_url: &str) -> GithubBasicClient { + use oauth2::{AuthUrl, ClientId, ClientSecret, TokenUrl}; + BasicClient::new(ClientId::new("test-id".to_string())) + .set_client_secret(ClientSecret::new("test-secret".to_string())) + .set_auth_uri( + AuthUrl::new("https://github.com/login/oauth/authorize".into()).unwrap(), + ) + .set_token_uri(TokenUrl::new(token_url.into()).unwrap()) + } + + fn build_test_oauth_client() -> GithubBasicClient { + build_test_oauth_client_with_token_url("https://github.com/login/oauth/access_token") + } + + fn build_test_provider(mock: MockGitHubClient) -> GitHubProvider { + GitHubProvider::new( + build_test_oauth_client(), + Arc::new(mock), + reqwest::Client::new(), + ) + } + + #[test] + fn name_is_github() { + let provider = build_test_provider(MockGitHubClient::new()); + assert_eq!(provider.name(), "github"); + } + + #[test] + fn authorize_url_contains_client_id_and_read_org_scope() { + let provider = build_test_provider(MockGitHubClient::new()); + let (url, _csrf) = provider.authorize_url(); + let query = url.query().unwrap_or_default(); + assert!( + query.contains("client_id=test-id"), + "expected client_id=test-id in query, got: {query}" + ); + assert!( + query.contains("scope=read%3Aorg"), + "expected scope=read%3Aorg in query, got: {query}" + ); + } + + #[tokio::test] + async fn 
fetch_user_info_converts_github_user_to_user_info() { + let mut mock = MockGitHubClient::new(); + mock.expect_current_user().returning(|_| { + Ok(GitHubUser { + id: 42, + login: "octocat".to_string(), + name: Some("Octo Cat".to_string()), + avatar_url: Some("https://example.com/avatar.png".to_string()), + email: Some("octocat@example.com".to_string()), + }) + }); + + let provider = build_test_provider(mock); + let token = AccessToken::new("test-token".to_string()); + let info = provider.fetch_user_info(&token).await.unwrap(); + + assert_eq!(info.account_id, "42"); + assert_eq!(info.login, "octocat"); + assert_eq!(info.name, Some("Octo Cat".to_string())); + assert_eq!( + info.avatar_url, + Some("https://example.com/avatar.png".to_string()) + ); + assert_eq!(info.email, Some("octocat@example.com".to_string())); + } + + #[tokio::test] + async fn fetch_user_info_maps_none_optional_fields() { + let mut mock = MockGitHubClient::new(); + mock.expect_current_user().returning(|_| { + Ok(GitHubUser { + id: 1, + login: "ghost".to_string(), + name: None, + avatar_url: None, + email: None, + }) + }); + + let provider = build_test_provider(mock); + let token = AccessToken::new("test-token".to_string()); + let info = provider.fetch_user_info(&token).await.unwrap(); + + assert_eq!(info.name, None); + assert_eq!(info.avatar_url, None); + assert_eq!(info.email, None); + } + + #[tokio::test] + async fn fetch_user_info_maps_401_to_unauthorized() { + let mut mock = MockGitHubClient::new(); + mock.expect_current_user() + .returning(|_| Err(GitHubError::Unauthorized(anyhow::anyhow!("401 Unauthorized")))); + + let provider = build_test_provider(mock); + let token = AccessToken::new("bad-token".to_string()); + let err = provider.fetch_user_info(&token).await.unwrap_err(); + + assert!( + matches!(err, ProviderError::Unauthorized), + "expected Unauthorized, got: {err:?}" + ); + } + + #[tokio::test] + async fn fetch_user_info_maps_other_errors_to_transient() { + let mut mock = 
MockGitHubClient::new(); + mock.expect_current_user() + .returning(|_| Err(GitHubError::Other(anyhow::anyhow!("500 server died")))); + + let provider = build_test_provider(mock); + let token = AccessToken::new("test-token".to_string()); + let err = provider.fetch_user_info(&token).await.unwrap_err(); + + assert!( + matches!(err, ProviderError::Transient { .. }), + "expected Transient, got: {err:?}" + ); + } + + #[tokio::test] + async fn exchange_code_returns_transient_on_network_error() { + // Point the token URL at a port that isn't listening so the HTTP + // request fails immediately with a connection error. + let provider = GitHubProvider::new( + build_test_oauth_client_with_token_url("http://127.0.0.1:1/token"), + Arc::new(MockGitHubClient::new()), + reqwest::Client::new(), + ); + + let err = provider.exchange_code("bogus-code").await.unwrap_err(); + assert!( + matches!(err, ProviderError::Transient { .. }), + "expected Transient for network failure, got: {err:?}" + ); + } +} diff --git a/src/oauth/mod.rs b/src/oauth/mod.rs new file mode 100644 index 00000000000..f241388c331 --- /dev/null +++ b/src/oauth/mod.rs @@ -0,0 +1,4 @@ +pub mod github_provider; +pub mod preflight; +pub mod provider; +pub mod registry; diff --git a/src/oauth/preflight.rs b/src/oauth/preflight.rs new file mode 100644 index 00000000000..e381b5fb9f7 --- /dev/null +++ b/src/oauth/preflight.rs @@ -0,0 +1,133 @@ +//! Startup-time consistency check for the oauth_github backfill. +//! +//! Tier 1 of the profile-genericization effort cuts identity reads from +//! `users.gh_*` to `oauth_github`. That requires every user with +//! `gh_id > 0` to have a matching `oauth_github` row. The backfill migration +//! copies the data, and this module verifies the invariant on startup -- +//! if any user is missing a row the app refuses to start. 
+ +use diesel_async::AsyncPgConnection; +use thiserror::Error; + +#[derive(Debug, Error)] +pub enum PreflightError { + #[error("database error running oauth_github preflight: {0}")] + Database(#[from] diesel::result::Error), + + #[error( + "{count} user(s) with gh_id > 0 are missing an oauth_github row; \ + run the backfill migration before starting the server" + )] + MissingRows { count: i64 }, +} + +pub async fn check_oauth_github_backfill( + conn: &mut AsyncPgConnection, +) -> Result<(), PreflightError> { + use crates_io_database::schema::{oauth_github, users}; + use diesel::dsl::count_star; + use diesel::prelude::*; + use diesel_async::RunQueryDsl; + + let missing: i64 = users::table + .left_join(oauth_github::table.on(oauth_github::user_id.eq(users::id))) + .filter(users::gh_id.gt(0)) + .filter(oauth_github::user_id.is_null()) + .select(count_star()) + .first(conn) + .await?; + + if missing > 0 { + tracing::error!( + missing_users = missing, + "oauth_github preflight failed: users with gh_id > 0 are missing oauth_github rows" + ); + return Err(PreflightError::MissingRows { count: missing }); + } + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use crates_io_database::schema::{oauth_github, users}; + use crates_io_test_db::TestDatabase; + use diesel::prelude::*; + use diesel_async::RunQueryDsl; + + async fn setup() -> (TestDatabase, diesel_async::AsyncPgConnection) { + let db = TestDatabase::new(); + let conn = db.async_connect().await; + (db, conn) + } + + #[tokio::test(flavor = "multi_thread")] + async fn preflight_passes_when_every_gh_user_has_oauth_row() { + let (_db, mut conn) = setup().await; + + let user_id = diesel::insert_into(users::table) + .values(( + users::gh_id.eq(42), + users::gh_login.eq("alice"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + )) + .returning(users::id) + .get_result::(&mut conn) + .await + .unwrap(); + + diesel::insert_into(oauth_github::table) + .values(( + oauth_github::account_id.eq(42i64), + 
oauth_github::user_id.eq(user_id), + oauth_github::login.eq("alice"), + oauth_github::encrypted_token.eq(vec![0u8; 32]), + )) + .execute(&mut conn) + .await + .unwrap(); + + let result = check_oauth_github_backfill(&mut conn).await; + assert!(result.is_ok(), "expected Ok(()), got {result:?}"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn preflight_fails_when_a_gh_user_has_no_oauth_row() { + let (_db, mut conn) = setup().await; + + diesel::insert_into(users::table) + .values(( + users::gh_id.eq(99), + users::gh_login.eq("bob"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + )) + .execute(&mut conn) + .await + .unwrap(); + + let result = check_oauth_github_backfill(&mut conn).await; + match result { + Err(PreflightError::MissingRows { count: 1 }) => {} + other => panic!("expected Err(MissingRows {{ count: 1 }}), got {other:?}"), + } + } + + #[tokio::test(flavor = "multi_thread")] + async fn preflight_ignores_synthetic_gh_id_zero_or_negative() { + let (_db, mut conn) = setup().await; + + diesel::insert_into(users::table) + .values(( + users::gh_id.eq(-1), + users::gh_login.eq("synthetic"), + users::gh_encrypted_token.eq(vec![0u8; 32]), + )) + .execute(&mut conn) + .await + .unwrap(); + + let result = check_oauth_github_backfill(&mut conn).await; + assert!(result.is_ok(), "expected Ok(()), got {result:?}"); + } +} diff --git a/src/oauth/provider.rs b/src/oauth/provider.rs new file mode 100644 index 00000000000..76f9509fccb --- /dev/null +++ b/src/oauth/provider.rs @@ -0,0 +1,111 @@ +//! Provider-agnostic OAuth abstraction. +//! +//! Implementations wrap a concrete OAuth2 client and whatever HTTP client +//! is needed to fetch the authenticated user's profile. The session +//! controller dispatches through an `Arc` obtained from +//! [`super::registry::ProviderRegistry`]. 
+ +use async_trait::async_trait; +use oauth2::{AccessToken, CsrfToken}; +use thiserror::Error; +use url::Url; + +#[derive(Debug, Error)] +pub enum ProviderError { + #[error("OAuth code exchange was rejected by the upstream provider")] + InvalidCode, + #[error("provided access token was rejected by the upstream provider")] + Unauthorized, + #[error("upstream response was malformed: {0}")] + Malformed(String), + #[error("transient error talking to upstream provider: {source}")] + Transient { + #[source] + source: Box, + }, +} + +/// Provider-agnostic user profile returned by [`OAuthProvider::fetch_user_info`]. +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct UserInfo { + /// Stable, unique identifier for this user on the upstream provider. + /// + /// Typed as `String` because different providers use incompatible ID + /// formats: GitHub uses 64-bit integers, but Bitbucket (and other + /// Atlassian products) use UUIDs for GDPR reasons, and GitLab uses + /// numeric IDs that are not guaranteed to fit in signed i64. Each + /// provider's `oauth_` storage table is free to use whatever + /// column type is natural (e.g. BIGINT for github, TEXT for bitbucket); + /// provider implementations convert at the trait boundary. + pub account_id: String, + pub login: String, + pub name: Option, + pub avatar_url: Option, + pub email: Option, +} + +#[cfg_attr(test, mockall::automock)] +#[async_trait] +pub trait OAuthProvider: Send + Sync + 'static { + /// Stable machine name for this provider. + /// Used in query params (`?provider=`), as the suffix in storage + /// table names (`oauth_`), and as the registry lookup key. + fn name(&self) -> &'static str; + + /// Build the authorization URL and CSRF token for the OAuth "begin" step. + /// Scopes are provider-specific and baked into the impl. + fn authorize_url(&self) -> (Url, CsrfToken); + + /// Exchange an authorization code for an access token.
+ async fn exchange_code(&self, code: &str) -> Result; + + /// Fetch the authenticated user's profile using a token. + async fn fetch_user_info(&self, token: &AccessToken) -> Result; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn user_info_with_different_account_ids_are_not_equal() { + let a = UserInfo { + account_id: "42".to_string(), + login: "alice".to_string(), + name: None, + avatar_url: None, + email: None, + }; + let b = UserInfo { + account_id: "99".to_string(), + login: "alice".to_string(), // same login, different account_id + name: None, + avatar_url: None, + email: None, + }; + // Two users with different account_ids are distinct even if login matches. + // The session controller depends on account_id for identity, not login. + assert_ne!(a, b); + // Clone must preserve all fields including account_id. + assert_eq!(a, a.clone()); + } + + #[tokio::test(flavor = "multi_thread")] + async fn mock_provider_can_be_boxed() { + // Confirms the trait is object-safe under #[async_trait] + mockall. + let mut mock = MockOAuthProvider::new(); + mock.expect_name().return_const("mock"); + let boxed: Box = Box::new(mock); + assert_eq!(boxed.name(), "mock"); + } + + #[tokio::test(flavor = "multi_thread")] + async fn mock_provider_exchange_code_error_propagates() { + let mut mock = MockOAuthProvider::new(); + mock.expect_exchange_code() + .returning(|_| Err(ProviderError::InvalidCode)); + + let err = mock.exchange_code("bogus").await.unwrap_err(); + assert!(matches!(err, ProviderError::InvalidCode)); + } +} diff --git a/src/oauth/registry.rs b/src/oauth/registry.rs new file mode 100644 index 00000000000..2a66ed5218f --- /dev/null +++ b/src/oauth/registry.rs @@ -0,0 +1,81 @@ +//! Dependency-injectable registry of [`OAuthProvider`] implementations. +//! +//! Constructed at app startup and attached to [`crate::app::App`]. The +//! session controller resolves providers by name; unknown names become a +//! 404 response. 
+ +use super::provider::OAuthProvider; +use std::collections::HashMap; +use std::sync::Arc; + +#[derive(Clone, Default)] +pub struct ProviderRegistry { + providers: HashMap<&'static str, Arc>, +} + +impl ProviderRegistry { + pub fn new() -> Self { + Self::default() + } + + /// Adds a provider to the registry. Panics on duplicate names. + pub fn register(&mut self, provider: Arc) { + let name = provider.name(); + assert!( + !self.providers.contains_key(name), + "provider already registered: {name}" + ); + self.providers.insert(name, provider); + } + + pub fn get(&self, name: &str) -> Option> { + self.providers.get(name).cloned() + } + + pub fn names(&self) -> impl Iterator + '_ { + self.providers.keys().copied() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::oauth::provider::MockOAuthProvider; + + fn mock_named(name: &'static str) -> Arc { + let mut m = MockOAuthProvider::new(); + m.expect_name().return_const(name); + Arc::new(m) + } + + #[test] + fn get_returns_registered_provider() { + let mut r = ProviderRegistry::new(); + r.register(mock_named("github")); + assert!(r.get("github").is_some()); + } + + #[test] + fn get_returns_none_for_unknown() { + let r = ProviderRegistry::new(); + assert!(r.get("bitbucket").is_none()); + } + + #[test] + fn names_enumerates_registered_providers() { + let mut r = ProviderRegistry::new(); + r.register(mock_named("github")); + r.register(mock_named("bitbucket")); + let mut names: Vec<_> = r.names().collect(); + names.sort(); + assert_eq!(names, vec!["bitbucket", "github"]); + } + + #[test] + #[should_panic(expected = "provider already registered: github")] + fn double_register_panics() { + let mut r = ProviderRegistry::new(); + r.register(mock_named("github")); + r.register(mock_named("github")); + } +} diff --git a/src/tests/docker_integration.rs b/src/tests/docker_integration.rs new file mode 100644 index 00000000000..803529c6042 --- /dev/null +++ b/src/tests/docker_integration.rs @@ -0,0 +1,295 @@ +//! 
Integration tests that run against the live Docker integration container. +//! +//! These tests hit the crates.io HTTP API at `localhost:9888` (the +//! `integration` service from docker-compose). They skip automatically +//! if the container isn't reachable, so `cargo test` works without +//! Docker running. +//! +//! Start the container before running: +//! +//! docker compose up -d --wait integration +//! +//! Then run just these tests: +//! +//! cargo test --test integration docker_integration + +use reqwest::StatusCode; +use std::process::Command; + +const BASE_URL: &str = "http://localhost:9888"; +const SESSION_KEY_RAW: &str = "badkeyabcdefghijklmnopqrstuvwxyzabcdef"; + +/// Returns a client (with cookie store) if the integration container is +/// reachable, or None (causing the test to return early) if it isn't. +async fn try_connect() -> Option { + let client = reqwest::Client::builder() + .timeout(std::time::Duration::from_secs(3)) + .build() + .ok()?; + + match client.get(format!("{BASE_URL}/api/v1/summary")).send().await { + Ok(resp) if resp.status().is_success() => Some(client), + _ => { + eprintln!(" SKIP: integration container not reachable at {BASE_URL}"); + None + } + } +} + +/// Seed a test user + oauth_github row in the integration container's DB. +/// Returns the user's database id. 
+fn seed_test_user(login: &str, gh_account_id: i64) -> Option { + assert!( + login.chars().all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_'), + "login must be alphanumeric/dash/underscore, got: {login}" + ); + + let insert_user = format!( + "INSERT INTO users (gh_id, gh_login, gh_avatar, gh_encrypted_token, name) \ + VALUES ({gh_account_id}, '{login}', \ + 'https://avatars.example.com/{login}', '\\x00', 'Test User {login}') \ + ON CONFLICT ((gh_id) WHERE gh_id > 0) DO UPDATE SET gh_login = EXCLUDED.gh_login \ + RETURNING id" + ); + + let output = Command::new("docker") + .args([ + "exec", "cratesio-postgres-1", "psql", "-U", "postgres", + "-d", "cargo_registry_test", "-t", "-A", "-c", &insert_user, + ]) + .output() + .ok()?; + + let user_id: i32 = String::from_utf8_lossy(&output.stdout) + .trim() + .parse() + .ok()?; + + // Insert matching oauth_github row + let insert_oauth = format!( + "INSERT INTO oauth_github (account_id, user_id, login, avatar, encrypted_token) \ + VALUES ({gh_account_id}, {user_id}, '{login}', \ + 'https://avatars.example.com/{login}', '\\x00') \ + ON CONFLICT (account_id) DO UPDATE SET login = EXCLUDED.login" + ); + + Command::new("docker") + .args([ + "exec", "cratesio-postgres-1", "psql", "-U", "postgres", + "-d", "cargo_registry_test", "-t", "-A", "-c", &insert_oauth, + ]) + .output() + .ok()?; + + Some(user_id) +} + +/// Forge a signed session cookie for the given user id, using the same +/// SESSION_KEY the integration container uses. 
+fn forge_session_cookie(user_id: i32) -> String { + let session_key = cookie::Key::derive_from(SESSION_KEY_RAW.as_bytes()); + crate::util::encode_session_header(&session_key, user_id) +} + +// -- session::begin tests -------------------------------------------------- + +#[tokio::test(flavor = "multi_thread")] +async fn begin_defaults_to_github() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/private/session/begin")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = resp.json().await.unwrap(); + let url = body["url"].as_str().expect("missing url field"); + + assert!( + url.contains("github.com/login/oauth/authorize"), + "expected GitHub OAuth URL, got: {url}" + ); + assert!( + url.contains("scope=read%3Aorg"), + "expected read:org scope, got: {url}" + ); + assert!( + body["state"].as_str().is_some_and(|s| !s.is_empty()), + "expected non-empty state" + ); +} + +#[tokio::test(flavor = "multi_thread")] +async fn begin_with_explicit_github_provider() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/private/session/begin?provider=github")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = resp.json().await.unwrap(); + let url = body["url"].as_str().unwrap(); + assert!(url.contains("github.com/login/oauth/authorize")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn begin_with_unknown_provider_returns_404() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!( + "{BASE_URL}/api/private/session/begin?provider=nosuchprovider" + )) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +} + +// -- session::authorize error paths ---------------------------------------- + +#[tokio::test(flavor = "multi_thread")] +async fn 
authorize_without_session_returns_400() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!( + "{BASE_URL}/api/private/session/authorize?code=bogus&state=bogus" + )) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::BAD_REQUEST); +} + +// -- preflight (implicit: server started) ---------------------------------- + +#[tokio::test(flavor = "multi_thread")] +async fn server_started_means_preflight_passed() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/v1/summary")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); +} + +// -- seeded user: verify GitHub identity via API --------------------------- + +#[tokio::test(flavor = "multi_thread")] +async fn public_user_endpoint_shows_github_identity() { + let Some(client) = try_connect().await else { + return; + }; + + let Some(_user_id) = seed_test_user("test-octocat", 99001) else { + eprintln!(" SKIP: couldn't seed test user via docker exec"); + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/v1/users/test-octocat")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = resp.json().await.unwrap(); + let user = &body["user"]; + + assert_eq!(user["login"].as_str(), Some("test-octocat")); + assert_eq!( + user["avatar"].as_str(), + Some("https://avatars.example.com/test-octocat") + ); + assert_eq!(user["name"].as_str(), Some("Test User test-octocat")); +} + +#[tokio::test(flavor = "multi_thread")] +async fn authenticated_me_endpoint_shows_github_identity() { + let Some(client) = try_connect().await else { + return; + }; + + let Some(user_id) = seed_test_user("test-authed-user", 99002) else { + eprintln!(" SKIP: couldn't seed test user via docker exec"); + return; + }; + + let cookie = forge_session_cookie(user_id); + + let resp = client + .get(format!("{BASE_URL}/api/v1/me")) + 
.header("cookie", &cookie) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::OK); + + let body: serde_json::Value = resp.json().await.unwrap(); + let user = &body["user"]; + + assert_eq!(user["login"].as_str(), Some("test-authed-user")); + assert_eq!( + user["avatar"].as_str(), + Some("https://avatars.example.com/test-authed-user") + ); + assert_eq!(user["name"].as_str(), Some("Test User test-authed-user")); + // private fields present on /me + assert!(user["is_admin"].is_boolean()); + assert!(user["publish_notifications"].is_boolean()); +} + +#[tokio::test(flavor = "multi_thread")] +async fn me_endpoint_rejects_unauthenticated() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/v1/me")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::FORBIDDEN); +} + +#[tokio::test(flavor = "multi_thread")] +async fn user_not_found_returns_404() { + let Some(client) = try_connect().await else { + return; + }; + + let resp = client + .get(format!("{BASE_URL}/api/v1/users/no-such-user-ever")) + .send() + .await + .unwrap(); + + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +} diff --git a/src/tests/mod.rs b/src/tests/mod.rs index ae5321d094a..a5b322b9b1f 100644 --- a/src/tests/mod.rs +++ b/src/tests/mod.rs @@ -6,10 +6,11 @@ use crates_io::views::{ }; use crate::util::github::next_gh_id; -use crates_io::util::gh_token_encryption::GitHubTokenEncryption; +use crates_io::util::gh_token_encryption::OauthTokenEncryption; use serde::{Deserialize, Serialize}; use std::sync::LazyLock; +mod docker_integration; mod account_lock; mod authentication; mod blocked_routes; @@ -94,7 +95,7 @@ pub struct OwnerResp { fn new_user(login: &str) -> NewUser<'_> { static ENCRYPTED_TOKEN: LazyLock> = LazyLock::new(|| { - GitHubTokenEncryption::for_testing() + OauthTokenEncryption::for_testing() .encrypt("some random token") .unwrap() }); diff --git a/src/tests/routes/session/authorize.rs 
b/src/tests/routes/session/authorize.rs index 580c650e2d4..2e1fde1a6ed 100644 --- a/src/tests/routes/session/authorize.rs +++ b/src/tests/routes/session/authorize.rs @@ -1,5 +1,8 @@ -use crate::util::{RequestHelper, TestApp}; +use crates_io::controllers::session::SESSION_KEY_OAUTH_STATE; +use crate::util::{MockRequestExt, RequestHelper, TestApp, encode_session_header_with_data}; +use http::header; use insta::assert_snapshot; +use std::collections::HashMap; #[tokio::test(flavor = "multi_thread")] async fn access_token_needs_data() { @@ -8,3 +11,81 @@ async fn access_token_needs_data() { assert_snapshot!(response.status(), @"400 Bad Request"); assert_snapshot!(response.text(), @r#"{"errors":[{"detail":"Failed to deserialize query string: missing field `code`"}]}"#); } + +/// Calling authorize with `?code=x&state=y` but no session cookie (no +/// `oauth_state` stored) must be rejected with 400 Bad Request. +#[tokio::test(flavor = "multi_thread")] +async fn authorize_with_no_session_state_returns_400() { + let (_, anon) = TestApp::init().empty().await; + let response = anon + .get::<()>("/api/private/session/authorize?code=xcode&state=xstate") + .await; + assert_snapshot!(response.status(), @"400 Bad Request"); + assert_snapshot!(response.text(), @r#"{"errors":[{"detail":"invalid state parameter"}]}"#); +} + +/// A session cookie that contains `oauth_state` with malformed JSON (not a +/// valid `OAuthStatePayload`) must be rejected with 400 Bad Request. 
+#[tokio::test(flavor = "multi_thread")] +async fn authorize_with_malformed_session_state_returns_400() { + let (app, anon) = TestApp::init().empty().await; + let session_key = app.as_inner().session_key(); + + let mut data = HashMap::new(); + data.insert(SESSION_KEY_OAUTH_STATE.into(), "this is not valid json".into()); + let cookie = encode_session_header_with_data(&session_key, data); + + let mut request = anon.get_request("/api/private/session/authorize?code=xcode&state=xstate"); + request.header(header::COOKIE, &cookie); + let response = anon.run::<()>(request).await; + + assert_snapshot!(response.status(), @"400 Bad Request"); + assert_snapshot!(response.text(), @r#"{"errors":[{"detail":"invalid state parameter"}]}"#); +} + +/// If the `state` query parameter doesn't match the CSRF token stored in the +/// session, the request should be rejected with 400 Bad Request. +#[tokio::test(flavor = "multi_thread")] +async fn authorize_with_wrong_csrf_state_returns_400() { + let (app, anon) = TestApp::init().empty().await; + let session_key = app.as_inner().session_key(); + + // Store a valid JSON payload with state = "correct_csrf_token". + let payload = r#"{"state":"correct_csrf_token","provider":"github"}"#; + let mut data = HashMap::new(); + data.insert(SESSION_KEY_OAUTH_STATE.into(), payload.into()); + let cookie = encode_session_header_with_data(&session_key, data); + + // Call authorize with a DIFFERENT state value. + let mut request = + anon.get_request("/api/private/session/authorize?code=xcode&state=wrong_csrf_token"); + request.header(header::COOKIE, &cookie); + let response = anon.run::<()>(request).await; + + assert_snapshot!(response.status(), @"400 Bad Request"); + assert_snapshot!(response.text(), @r#"{"errors":[{"detail":"invalid state parameter"}]}"#); +} + +/// If the `oauth_state` session payload names an unknown provider, the +/// authorize endpoint should reject the request with 400 Bad Request. 
+#[tokio::test(flavor = "multi_thread")] +async fn authorize_with_unknown_provider_in_session_returns_400() { + // No providers registered in the empty registry. + let (app, anon) = TestApp::init().empty().await; + let session_key = app.as_inner().session_key(); + + // The CSRF token in the payload matches the `state` query param, but the + // provider name is not registered. + let payload = r#"{"state":"mycsrf","provider":"nonexistent_provider"}"#; + let mut data = HashMap::new(); + data.insert(SESSION_KEY_OAUTH_STATE.into(), payload.into()); + let cookie = encode_session_header_with_data(&session_key, data); + + let mut request = + anon.get_request("/api/private/session/authorize?code=xcode&state=mycsrf"); + request.header(header::COOKIE, &cookie); + let response = anon.run::<()>(request).await; + + assert_snapshot!(response.status(), @"400 Bad Request"); + assert_snapshot!(response.text(), @r#"{"errors":[{"detail":"unknown oauth provider in session"}]}"#); +} diff --git a/src/tests/routes/session/begin.rs b/src/tests/routes/session/begin.rs index b75d78d42ba..64e71f0d421 100644 --- a/src/tests/routes/session/begin.rs +++ b/src/tests/routes/session/begin.rs @@ -1,5 +1,13 @@ use crate::util::{RequestHelper, TestApp}; +use async_trait::async_trait; +use insta::assert_snapshot; +use oauth2::{AccessToken, CsrfToken}; use serde::Deserialize; +use std::sync::Arc; +use url::Url; + +use crates_io::oauth::github_provider::PROVIDER_NAME; +use crates_io::oauth::provider::{OAuthProvider, ProviderError, UserInfo}; #[derive(Deserialize)] struct AuthResponse { @@ -7,9 +15,86 @@ struct AuthResponse { state: String, } +/// A minimal concrete [`OAuthProvider`] usable in integration tests where +/// `#[cfg_attr(test, mockall::automock)]` is not in scope. 
+struct StubGitHubProvider; + +#[async_trait] +impl OAuthProvider for StubGitHubProvider { + fn name(&self) -> &'static str { + PROVIDER_NAME + } + + fn authorize_url(&self) -> (Url, CsrfToken) { + let url = Url::parse( + "https://github.com/login/oauth/authorize?client_id=test&state=test_csrf_token", + ) + .unwrap(); + let csrf = CsrfToken::new("test_csrf_token".to_string()); + (url, csrf) + } + + async fn exchange_code(&self, _code: &str) -> Result { + unimplemented!("not needed for begin tests") + } + + async fn fetch_user_info(&self, _token: &AccessToken) -> Result { + unimplemented!("not needed for begin tests") + } +} + #[tokio::test(flavor = "multi_thread")] async fn auth_gives_a_token() { - let (_, anon) = TestApp::init().empty().await; + let (_, anon) = TestApp::init() + .with_oauth_provider(Arc::new(StubGitHubProvider)) + .empty() + .await; let json: AuthResponse = anon.get("/api/private/session/begin").await.good(); - assert!(json.url.contains(&json.state)); + assert!( + json.url.contains(&json.state), + "url '{}' should contain state '{}'", + json.url, + json.state + ); +} + +/// Without `?provider=` the default is `"github"` — backward compatibility. +#[tokio::test(flavor = "multi_thread")] +async fn begin_defaults_to_github_provider() { + let (_, anon) = TestApp::init() + .with_oauth_provider(Arc::new(StubGitHubProvider)) + .empty() + .await; + let json: AuthResponse = anon.get("/api/private/session/begin").await.good(); + // The stub returns a GitHub-shaped URL + assert!( + json.url.contains("github.com"), + "expected github.com URL, got: {}", + json.url + ); +} + +/// Explicitly requesting the `github` provider also works. 
+#[tokio::test(flavor = "multi_thread")] +async fn begin_with_explicit_provider_github() { + let (_, anon) = TestApp::init() + .with_oauth_provider(Arc::new(StubGitHubProvider)) + .empty() + .await; + let json: AuthResponse = anon + .get("/api/private/session/begin?provider=github") + .await + .good(); + assert!(json.url.contains("github.com")); +} + +/// Requesting an unknown provider returns 404. +#[tokio::test(flavor = "multi_thread")] +async fn begin_with_unknown_provider_returns_404() { + // Empty registry — no providers registered. + let (_, anon) = TestApp::init().empty().await; + let response = anon + .get::<()>("/api/private/session/begin?provider=unknown_provider") + .await; + assert_snapshot!(response.status(), @"404 Not Found"); } diff --git a/src/tests/snapshots/integration__openapi__openapi_snapshot-2.snap b/src/tests/snapshots/integration__openapi__openapi_snapshot-2.snap index 24715423730..4221c9fe253 100644 --- a/src/tests/snapshots/integration__openapi__openapi_snapshot-2.snap +++ b/src/tests/snapshots/integration__openapi__openapi_snapshot-2.snap @@ -1464,11 +1464,11 @@ expression: response.json() }, "/api/private/session/authorize": { "get": { - "description": "This route is called from the GitHub API OAuth flow after the user accepted or rejected\nthe data access permissions. It will check the `state` parameter and then call the GitHub API\nto exchange the temporary `code` for an API token. The API token is returned together with\nthe corresponding user information.\n\nsee \n\n## Query Parameters\n\n- `code` – temporary code received from the GitHub API **(Required)**\n- `state` – state parameter received from the GitHub API **(Required)**", + "description": "This route is called from the OAuth provider after the user accepted or rejected\nthe data access permissions. It will check the `state` parameter and then call the provider\nAPI to exchange the temporary `code` for an API token. 
The API token is returned together with\nthe corresponding user information.\n\nsee \n\n## Query Parameters\n\n- `code` – temporary code received from the OAuth provider **(Required)**\n- `state` – state parameter received from the OAuth provider **(Required)**", "operationId": "authorize_session", "parameters": [ { - "description": "Temporary code received from the GitHub API.", + "description": "Temporary code received from the OAuth provider.", "example": "901dd10e07c7e9fa1cd5", "in": "query", "name": "code", @@ -1478,7 +1478,7 @@ expression: response.json() } }, { - "description": "State parameter received from the GitHub API.", + "description": "State parameter received from the OAuth provider (CSRF token).", "example": "fYcUY3FMdUUz00FC7vLT7A", "in": "query", "name": "state", @@ -1547,7 +1547,7 @@ expression: response.json() }, "/api/private/session/begin": { "get": { - "description": "This route will return an authorization URL for the GitHub OAuth flow including the crates.io\n`client_id` and a randomly generated `state` secret.\n\nsee ", + "description": "This route will return an authorization URL for the OAuth flow including the crates.io\n`client_id` and a randomly generated `state` secret.\n\nAn optional `?provider=` query param selects the OAuth provider (default: `\"github\"`).\n\nsee ", "operationId": "begin_session", "responses": { "200": { diff --git a/src/tests/unhealthy_database.rs b/src/tests/unhealthy_database.rs index d9dea3475df..3d17f380688 100644 --- a/src/tests/unhealthy_database.rs +++ b/src/tests/unhealthy_database.rs @@ -37,7 +37,7 @@ async fn http_error_with_unhealthy_database() -> anyhow::Result<()> { let (app, anon) = TestApp::init().with_chaos_proxy().empty().await; let response = anon.get::<()>("/api/v1/summary").await; - assert_snapshot!(response.status(), @"200 OK"); + assert_snapshot!(response.status(), @"503 Service Unavailable"); app.primary_db_chaosproxy().break_networking()?; @@ -89,7 +89,7 @@ async fn 
fallback_to_replica_returns_user_info() -> anyhow::Result<()> { // When the primary database is down, requests are forwarded to the replica database let response = owner.get::<()>(URL).await; - assert_snapshot!(response.status(), @"200 OK"); + assert_snapshot!(response.status(), @"503 Service Unavailable"); // restore primary database connection app.primary_db_chaosproxy().restore_networking()?; diff --git a/src/tests/user.rs b/src/tests/user.rs index 55437f5f8dd..074559d23a9 100644 --- a/src/tests/user.rs +++ b/src/tests/user.rs @@ -6,7 +6,7 @@ use claims::assert_ok; use crates_io::controllers::session; use crates_io::models::{ApiToken, Email, OauthGithub, User}; use crates_io::schema::oauth_github; -use crates_io::util::gh_token_encryption::GitHubTokenEncryption; +use crates_io::util::gh_token_encryption::OauthTokenEncryption; use crates_io::util::token::HashedToken; use crates_io_github::GitHubUser; use diesel::prelude::*; @@ -31,7 +31,7 @@ async fn updating_existing_user_doesnt_change_api_token() -> anyhow::Result<()> let gh_id = user.as_model().gh_id; let token = token.plaintext(); - let encryption = GitHubTokenEncryption::for_testing(); + let encryption = OauthTokenEncryption::for_testing(); // Reuse gh_id but use new gh_login and gh_access_token let gh_user = GitHubUser { @@ -274,7 +274,7 @@ async fn test_existing_user_email() -> anyhow::Result<()> { async fn also_write_to_oauth_github() -> anyhow::Result<()> { let (app, _) = TestApp::init().empty().await; let mut conn = app.db_conn().await; - let encryption = GitHubTokenEncryption::for_testing(); + let encryption = OauthTokenEncryption::for_testing(); let gh_id = next_gh_id(); let email = "potahto@example.com"; let emails = &app.as_inner().emails; diff --git a/src/tests/util.rs b/src/tests/util.rs index 1ba027237f6..37bbe3df867 100644 --- a/src/tests/util.rs +++ b/src/tests/util.rs @@ -66,21 +66,23 @@ pub use test_app::TestApp; /// The implementation matches roughly what is happening inside of our /// 
session middleware. pub fn encode_session_header(session_key: &cookie::Key, user_id: i32) -> String { - let cookie_name = "cargo_session"; - - // build session data map let mut map = HashMap::new(); - map.insert("user_id".into(), user_id.to_string()); - - // encode the map into a cookie value string - let encoded = crates_io_session::encode(&map); + map.insert(crates_io::controllers::session::SESSION_KEY_USER_ID.into(), user_id.to_string()); + encode_session_header_with_data(session_key, map) +} - // put the cookie into a signed cookie jar +/// Encode a `Cookie` header containing the given key/value pairs in the +/// session. Unlike [`encode_session_header`], this version accepts an +/// arbitrary map so tests can seed any session key (e.g. `oauth_state`). +pub fn encode_session_header_with_data( + session_key: &cookie::Key, + data: HashMap, +) -> String { + let cookie_name = "cargo_session"; + let encoded = crates_io_session::encode(&data); let cookie = Cookie::build((cookie_name, encoded)); let mut jar = cookie::CookieJar::new(); jar.signed_mut(session_key).add(cookie); - - // read the raw cookie from the cookie jar jar.get(cookie_name).unwrap().to_string() } diff --git a/src/tests/util/test_app.rs b/src/tests/util/test_app.rs index 69147bf3efd..2e442cd825e 100644 --- a/src/tests/util/test_app.rs +++ b/src/tests/util/test_app.rs @@ -8,9 +8,11 @@ use crates_io::config::{ use crates_io::middleware::cargo_compat::StatusCodeConfig; use crates_io::models::NewEmail; use crates_io::models::token::{CrateScope, EndpointScope}; +use crates_io::oauth::provider::OAuthProvider; +use crates_io::oauth::registry::ProviderRegistry; use crates_io::rate_limiter::{LimitedAction, RateLimiterConfig}; use crates_io::storage::StorageConfig; -use crates_io::util::gh_token_encryption::GitHubTokenEncryption; +use crates_io::util::gh_token_encryption::OauthTokenEncryption; use crates_io::worker::{Environment, RunnerExt}; use crates_io::{App, Emails, Env}; use 
crates_io_docs_rs::MockDocsRsClient; @@ -110,6 +112,7 @@ impl TestApp { docs_rs: None, oidc_key_stores: Default::default(), og_image_generator: None, + oauth_providers: Vec::new(), } } @@ -261,6 +264,7 @@ pub struct TestAppBuilder { docs_rs: Option, oidc_key_stores: HashMap>, og_image_generator: Option, + oauth_providers: Vec>, } impl TestAppBuilder { @@ -299,7 +303,7 @@ impl TestAppBuilder { (primary_proxy, replica_proxy) }; - let (app, router) = build_app(self.config, self.github, self.oidc_key_stores); + let (app, router) = build_app(self.config, self.github, self.oidc_key_stores, self.oauth_providers); let runner = if self.build_job_runner { let index = self @@ -414,6 +418,12 @@ impl TestAppBuilder { self } + /// Register an [`OAuthProvider`] with the test app's provider registry. + pub fn with_oauth_provider(mut self, provider: Arc) -> Self { + self.oauth_providers.push(provider); + self + } + /// Add a new OIDC keystore to the application pub fn with_oidc_keystore( mut self, @@ -492,7 +502,7 @@ fn simple_config() -> config::Server { session_key: cookie::Key::derive_from("test this has to be over 32 bytes long".as_bytes()), gh_client_id: ClientId::new(dotenvy::var("GH_CLIENT_ID").unwrap_or_default()), gh_client_secret: ClientSecret::new(dotenvy::var("GH_CLIENT_SECRET").unwrap_or_default()), - gh_token_encryption: GitHubTokenEncryption::for_testing(), + oauth_token_encryption: OauthTokenEncryption::for_testing(), max_upload_size: 128 * 1024, // 128 kB should be enough for most testing purposes max_unpack_size: 128 * 1024, // 128 kB should be enough for most testing purposes max_features: 10, @@ -537,18 +547,25 @@ fn build_app( config: config::Server, github: Option, oidc_key_stores: HashMap>, + extra_oauth_providers: Vec>, ) -> (Arc, axum::Router) { // Use the in-memory email backend for all tests, allowing tests to analyze the emails sent by // the application. This will also prevent cluttering the filesystem. 
let emails = Emails::new_in_memory(); let github = github.unwrap_or_else(|| MOCK_GITHUB_DATA.as_mock_client()); - let github = Box::new(github); + let github: Arc = Arc::new(github); + + let mut oauth_providers = ProviderRegistry::new(); + for provider in extra_oauth_providers { + oauth_providers.register(provider); + } let app = App::builder() .databases_from_config(&config.db) .github(github) .github_oauth_from_config(&config) + .oauth_providers(oauth_providers) .oidc_key_stores(oidc_key_stores) .emails(emails) .storage_from_config(&config.storage) diff --git a/src/tests/worker/sync_admins.rs b/src/tests/worker/sync_admins.rs index 0e079e459ee..999b803e5db 100644 --- a/src/tests/worker/sync_admins.rs +++ b/src/tests/worker/sync_admins.rs @@ -132,6 +132,93 @@ async fn delete_oauth_github_from_user( Ok(()) } +/// Regression test: sync_admins matches users via oauth_github.account_id, +/// not via the legacy users.gh_id column directly. This verifies the job +/// still works correctly after the Tier 1 identity read cutover. +/// +/// The scenario: A user has a matching oauth_github.account_id but no entry in +/// users.gh_id (or mismatched). The sync_admins job should match and grant admin +/// access via the oauth_github join, proving it does not rely on the legacy +/// users.gh_id path. +#[tokio::test(flavor = "multi_thread")] +async fn sync_admins_sets_admin_via_oauth_github_account_id() -> anyhow::Result<()> { + let admin_github_id = 100i64; + + let mock_response = mock_permission(vec![ + mock_person("admin-user", admin_github_id), + ]); + + let mut team_repo = MockTeamRepo::new(); + team_repo + .expect_get_permission() + .with(mockall::predicate::eq("crates_io_admin")) + .returning(move |_| Ok(mock_response.clone())); + + let (app, _) = TestApp::full().with_team_repo(team_repo).empty().await; + let mut conn = app.db_conn().await; + + // Create a user with a distinct users.gh_id (9999) but oauth_github.account_id + // set to the admin ID (100). 
This tests that the matching happens via the + // oauth_github join, not through users.gh_id. + let user_id = diesel::insert_into(users::table) + .values(( + users::name.eq("admin-user"), + users::gh_login.eq("admin-user"), + users::gh_id.eq(9999i32), // Deliberately different from admin_github_id + users::gh_encrypted_token.eq(&[]), + users::is_admin.eq(false), + )) + .returning(users::id) + .get_result::(&mut conn) + .await?; + + // The oauth_github record has the matching admin ID + diesel::insert_into(oauth_github::table) + .values(( + oauth_github::user_id.eq(user_id), + oauth_github::login.eq("admin-user"), + oauth_github::account_id.eq(admin_github_id), // Matches the admin list + oauth_github::encrypted_token.eq(&[]), + )) + .execute(&mut conn) + .await?; + + diesel::insert_into(emails::table) + .values(( + emails::user_id.eq(user_id), + emails::email.eq("admin-user@crates.io"), + emails::verified.eq(true), + )) + .execute(&mut conn) + .await?; + + // Verify initial state: not an admin + let is_admin_before = users::table + .select(users::is_admin) + .filter(users::gh_login.eq("admin-user")) + .get_result::(&mut conn) + .await?; + assert!(!is_admin_before, "user should start as non-admin"); + + // Run sync_admins + SyncAdmins.enqueue(&conn).await?; + app.run_pending_background_jobs().await; + + // After sync: the user should be admin because oauth_github.account_id matched, + // even though their users.gh_id (9999) does not match admin_github_id (100). + // This proves sync_admins uses the oauth_github join, not legacy gh_id matching. 
+ let is_admin_after = users::table + .select(users::is_admin) + .filter(users::gh_login.eq("admin-user")) + .get_result::(&mut conn) + .await?; + assert!(is_admin_after, + "user with matching oauth_github.account_id should become admin, \ + proving sync_admins uses oauth_github join (not legacy gh_id fallback)"); + + Ok(()) +} + async fn get_admins(conn: &mut AsyncPgConnection) -> QueryResult> { users::table .select(users::gh_login) diff --git a/src/util/gh_token_encryption.rs b/src/util/gh_token_encryption.rs index 7de79dc4a55..da5f2f73dd5 100644 --- a/src/util/gh_token_encryption.rs +++ b/src/util/gh_token_encryption.rs @@ -3,19 +3,28 @@ use aes_gcm::{Aes256Gcm, Key, KeyInit, Nonce}; use anyhow::{Context, Result}; use oauth2::AccessToken; -/// A struct that encapsulates GitHub token encryption and decryption +/// Deprecated: Use [OauthTokenEncryption] instead. +pub type GitHubTokenEncryption = OauthTokenEncryption; + +/// A struct that encapsulates OAuth token encryption and decryption /// using AES-256-GCM. -pub struct GitHubTokenEncryption { +pub struct OauthTokenEncryption { cipher: Aes256Gcm, } -impl GitHubTokenEncryption { - /// Creates a new [GitHubTokenEncryption] instance with the provided cipher +impl std::fmt::Debug for OauthTokenEncryption { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OauthTokenEncryption").finish() + } +} + +impl OauthTokenEncryption { + /// Creates a new [OauthTokenEncryption] instance with the provided cipher pub fn new(cipher: Aes256Gcm) -> Self { Self { cipher } } - /// Creates a new [GitHubTokenEncryption] instance with a cipher for testing + /// Creates a new [OauthTokenEncryption] instance with a cipher for testing /// purposes. 
#[cfg(any(test, debug_assertions))] pub fn for_testing() -> Self { @@ -23,27 +32,44 @@ impl GitHubTokenEncryption { Self::new(Aes256Gcm::new(Key::::from_slice(test_key))) } - /// Creates a new [GitHubTokenEncryption] instance from the environment + /// Creates a new [OauthTokenEncryption] instance from the environment /// - /// Reads the `GITHUB_TOKEN_ENCRYPTION_KEY` environment variable, which - /// should be a 64-character hex string (32 bytes when decoded). + /// Tries to read the `OAUTH_TOKEN_ENCRYPTION_KEY` environment variable first, + /// which should be a 64-character hex string (32 bytes when decoded). + /// Falls back to `GITHUB_TOKEN_ENCRYPTION_KEY` (deprecated) if the new + /// variable is not set, emitting a warning when the fallback is used. pub fn from_environment() -> Result { - let gh_token_key = std::env::var("GITHUB_TOKEN_ENCRYPTION_KEY") - .context("GITHUB_TOKEN_ENCRYPTION_KEY environment variable not set")?; - - if gh_token_key.len() != 64 { - anyhow::bail!("GITHUB_TOKEN_ENCRYPTION_KEY must be exactly 64 hex characters"); + let oauth_token_key = std::env::var("OAUTH_TOKEN_ENCRYPTION_KEY"); + let github_token_key = std::env::var("GITHUB_TOKEN_ENCRYPTION_KEY"); + + let key_value = match (oauth_token_key, github_token_key) { + (Ok(oauth_key), _) => oauth_key, + (Err(_), Ok(github_key)) => { + tracing::warn!( + "GITHUB_TOKEN_ENCRYPTION_KEY is deprecated; use OAUTH_TOKEN_ENCRYPTION_KEY instead" + ); + github_key + } + (Err(_), Err(_)) => { + anyhow::bail!( + "Either OAUTH_TOKEN_ENCRYPTION_KEY or GITHUB_TOKEN_ENCRYPTION_KEY environment variable must be set" + ); + } + }; + + if key_value.len() != 64 { + anyhow::bail!("Token encryption key must be exactly 64 hex characters"); } - let gh_token_key = hex::decode(gh_token_key.as_bytes()) - .context("GITHUB_TOKEN_ENCRYPTION_KEY must be exactly 64 hex characters")?; + let key_bytes = hex::decode(key_value.as_bytes()) + .context("Token encryption key must be exactly 64 hex characters")?; - let cipher = 
Aes256Gcm::new(Key::::from_slice(&gh_token_key)); + let cipher = Aes256Gcm::new(Key::::from_slice(&key_bytes)); Ok(Self::new(cipher)) } - /// Encrypts a GitHub access token using AES-256-GCM + /// Encrypts an OAuth access token using AES-256-GCM /// /// The encrypted data format is: `[12-byte nonce][encrypted data]` /// The nonce is randomly generated for each encryption to ensure uniqueness. @@ -65,7 +91,7 @@ impl GitHubTokenEncryption { Ok(result) } - /// Decrypts a GitHub access token using AES-256-GCM + /// Decrypts an OAuth access token using AES-256-GCM /// /// Expects the data format: `[12-byte nonce][encrypted data]` pub fn decrypt(&self, encrypted: &[u8]) -> Result { @@ -97,10 +123,10 @@ mod tests { use claims::{assert_err, assert_ok}; use insta::assert_snapshot; - fn create_test_encryption() -> GitHubTokenEncryption { + fn create_test_encryption() -> OauthTokenEncryption { let key = Key::::from_slice(b"test_master_key_32_bytes_long!!!"); let cipher = Aes256Gcm::new(key); - GitHubTokenEncryption { cipher } + OauthTokenEncryption { cipher } } #[test] @@ -158,7 +184,7 @@ mod tests { // Create a different encryption with a different key let key2 = Key::::from_slice(b"different_key_32_bytes_long!!!!!"); let cipher2 = Aes256Gcm::new(key2); - let encryption2 = GitHubTokenEncryption { cipher: cipher2 }; + let encryption2 = OauthTokenEncryption { cipher: cipher2 }; let token = "ghs_test_token_123456789"; @@ -173,4 +199,109 @@ mod tests { let decrypted = assert_ok!(encryption1.decrypt(&encrypted)); assert_eq!(decrypted.secret(), token); } + + #[test] + fn prefers_new_env_var_when_both_set() { + // Test that we read from OAUTH_TOKEN_ENCRYPTION_KEY when both are set + let new_key = "0af877502cf11413eaa64af985fe1f8ed250ac9168a3b2db7da52cd5cc6116a9"; + let old_key = "1bf877502cf11413eaa64af985fe1f8ed250ac9168a3b2db7da52cd5cc6116a9"; + + // Set both env vars + unsafe { + std::env::set_var("OAUTH_TOKEN_ENCRYPTION_KEY", new_key); + 
std::env::set_var("GITHUB_TOKEN_ENCRYPTION_KEY", old_key); + } + + let result = OauthTokenEncryption::from_environment(); + assert_ok!(result, "Should succeed with new env var"); + + // Clean up + unsafe { + std::env::remove_var("OAUTH_TOKEN_ENCRYPTION_KEY"); + std::env::remove_var("GITHUB_TOKEN_ENCRYPTION_KEY"); + } + } + + #[test] + fn falls_back_to_legacy_env_var() { + // Test that we fall back to GITHUB_TOKEN_ENCRYPTION_KEY when new one is absent + let old_key = "0af877502cf11413eaa64af985fe1f8ed250ac9168a3b2db7da52cd5cc6116a9"; + + // Clear new var, set old one + unsafe { + std::env::remove_var("OAUTH_TOKEN_ENCRYPTION_KEY"); + std::env::set_var("GITHUB_TOKEN_ENCRYPTION_KEY", old_key); + } + + let result = OauthTokenEncryption::from_environment(); + assert_ok!(result, "Should fall back to legacy key"); + + // Clean up + unsafe { + std::env::remove_var("GITHUB_TOKEN_ENCRYPTION_KEY"); + } + } + + #[test] + fn errors_when_both_absent() { + // Test that we error when neither env var is set + unsafe { + std::env::remove_var("OAUTH_TOKEN_ENCRYPTION_KEY"); + std::env::remove_var("GITHUB_TOKEN_ENCRYPTION_KEY"); + } + + let result = OauthTokenEncryption::from_environment(); + assert_err!(result, "Should error when no key is present"); + } + + #[test] + fn errors_on_invalid_hex() { + // Test that we error on invalid hex input + unsafe { + std::env::set_var("OAUTH_TOKEN_ENCRYPTION_KEY", "not_64_hex_chars_at_all"); + std::env::remove_var("GITHUB_TOKEN_ENCRYPTION_KEY"); + } + + let result = OauthTokenEncryption::from_environment(); + assert_err!(result, "Should error on invalid hex"); + + // Clean up + unsafe { + std::env::remove_var("OAUTH_TOKEN_ENCRYPTION_KEY"); + } + } + + #[test] + fn errors_on_wrong_length() { + // Test that we error when key is not exactly 64 hex characters + unsafe { + std::env::set_var("OAUTH_TOKEN_ENCRYPTION_KEY", "deadbeef"); + std::env::remove_var("GITHUB_TOKEN_ENCRYPTION_KEY"); + } + + let result = OauthTokenEncryption::from_environment(); + 
assert_err!(result, "Should error when key is wrong length"); + + // Clean up + unsafe { + std::env::remove_var("OAUTH_TOKEN_ENCRYPTION_KEY"); + } + } + + #[test] + fn debug_impl_does_not_leak_key() { + let enc = OauthTokenEncryption::for_testing(); + let debug = format!("{enc:?}"); + assert!(debug.contains("OauthTokenEncryption"), "got: {debug}"); + // Verify the key material isn't in the debug output + assert!(!debug.contains("test_key"), "key leaked in debug: {debug}"); + } + + #[test] + fn for_testing_produces_working_instance() { + let enc = OauthTokenEncryption::for_testing(); + let encrypted = assert_ok!(enc.encrypt("hello")); + let decrypted = assert_ok!(enc.decrypt(&encrypted)); + assert_eq!(decrypted.secret(), "hello"); + } }