Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- This migration is a data backfill and has no meaningful inverse. Rolling
-- it back would require distinguishing backfilled rows from rows inserted
-- by login, which we cannot do. The `oauth_github` table itself is dropped
-- by the earlier migration 2026-01-20-162913_oauth_github_table/down.sql.
Copy link
Copy Markdown
Member

@carols10cents carols10cents Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah any comment corrections should be made in a separate commit, possibly a separate PR

View changes since the review

-- Intentionally a no-op.
SELECT 1;
26 changes: 26 additions & 0 deletions migrations/2026-04-15-221937-0000_backfill_oauth_github/up.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
-- Backfill oauth_github for users who haven't logged in since the table
-- was created. Batched by id range, idempotent via ON CONFLICT DO NOTHING.

SET LOCAL lock_timeout = '10s';
SET LOCAL statement_timeout = '120s';
Copy link
Copy Markdown
Member

@carols10cents carols10cents Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you elaborate on why this commit is needed? After we merged #12792, we ran the SQL in https://github.com/rust-lang/crates.io/blob/150d165e4d5ecda24b60f146589db4def7fd34bf/migrations/data_oauth_github.sql outside of the migration framework that runs on deploy, so the backfill has already been handled.

View changes since the review


DO $$
DECLARE
lo INT;
hi INT;
pos INT;
BEGIN
SELECT MIN(id), MAX(id) INTO lo, hi FROM users WHERE gh_id > 0;
IF lo IS NULL THEN RETURN; END IF;

pos := lo;
WHILE pos <= hi LOOP
INSERT INTO oauth_github (account_id, user_id, encrypted_token, login, avatar)
SELECT gh_id, id, gh_encrypted_token, gh_login, gh_avatar
FROM users
WHERE gh_id > 0 AND id >= pos AND id < pos + 5000
ON CONFLICT (account_id) DO NOTHING;

pos := pos + 5000;
END LOOP;
END $$;
38 changes: 25 additions & 13 deletions src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,24 @@ type DeadpoolResult = Result<
diesel_async::pooled_connection::deadpool::PoolError,
>;

/// Creates the GitHub OAuth2 BasicClient from server config.
///
/// Centralizes the GitHub OAuth configuration used by both `App.github_oauth`
/// and `GitHubProvider` in the registry, avoiding duplication.
pub fn build_github_oauth_client(
config: &config::Server,
) -> BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet> {
use oauth2::{AuthUrl, TokenUrl};

let auth_url = AuthUrl::new("https://github.com/login/oauth/authorize".into()).unwrap();
let token_url = TokenUrl::new("https://github.com/login/oauth/access_token".into()).unwrap();

BasicClient::new(config.gh_client_id.clone())
.set_client_secret(config.gh_client_secret.clone())
.set_auth_uri(auth_url)
.set_token_uri(token_url)
}

/// The `App` struct holds the main components of the application like
/// the database connection pool and configurations
#[derive(Builder)]
Expand All @@ -40,12 +58,17 @@ pub struct App {
pub replica_database: Option<DeadpoolPool<AsyncPgConnection>>,

/// GitHub API client
pub github: Box<dyn GitHubClient>,
pub github: Arc<dyn GitHubClient>,

/// The GitHub OAuth2 configuration
pub github_oauth:
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>,

/// Registry of OAuth providers (GitHub, etc.) used by the session controller.
///
/// Populated at startup; the session controller resolves providers by name.
pub oauth_providers: crate::oauth::registry::ProviderRegistry,

/// OIDC key stores for "Trusted Publishing"
///
/// This is a map of OIDC key stores, where the key is the issuer URL and
Expand Down Expand Up @@ -82,18 +105,7 @@ impl<S: app_builder::State> AppBuilder<S> {
where
S::GithubOauth: app_builder::IsUnset,
{
use oauth2::{AuthUrl, TokenUrl};

let auth_url = "https://github.com/login/oauth/authorize";
let auth_url = AuthUrl::new(auth_url.into()).unwrap();
let token_url = "https://github.com/login/oauth/access_token";
let token_url = TokenUrl::new(token_url.into()).unwrap();

let github_oauth = BasicClient::new(config.gh_client_id.clone())
.set_client_secret(config.gh_client_secret.clone())
.set_auth_uri(auth_url)
.set_token_uri(token_url);

let github_oauth = build_github_oauth_client(config);
self.github_oauth(github_oauth)
}

Expand Down
32 changes: 29 additions & 3 deletions src/bin/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
extern crate tracing;

use crates_io::middleware::normalize_path::normalize_path;
use crates_io::oauth::preflight::check_oauth_github_backfill;
use crates_io::{App, Emails, metrics::LogEncoder};
use std::{sync::Arc, time::Duration};

Expand Down Expand Up @@ -33,13 +34,26 @@ fn main() -> anyhow::Result<()> {
let user_agent = crates_io_version::user_agent();
let client = Client::builder().user_agent(user_agent).build()?;

let github = RealGitHubClient::new(client);
let github = Box::new(github);
let github: std::sync::Arc<dyn crates_io_github::GitHubClient> =
std::sync::Arc::new(RealGitHubClient::new(client.clone()));
Copy link
Copy Markdown
Member

@carols10cents carols10cents Apr 24, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this commit should be pulled out into a separate PR from anything having to do with usernames-- we want to introduce the crates.io username first, separately from adding any other oauth provider.

View changes since the review


// Build the ProviderRegistry — currently only GitHub, wired up here so
// Commit 3 can route session::begin/authorize through it.
let github_oauth_for_provider = crates_io::app::build_github_oauth_client(&config);
let mut oauth_providers = crates_io::oauth::registry::ProviderRegistry::new();
oauth_providers.register(std::sync::Arc::new(
crates_io::oauth::github_provider::GitHubProvider::new(
github_oauth_for_provider,
github.clone(),
client.clone(),
),
));

let app = App::builder()
.databases_from_config(&config.db)
.github(github)
.github(github.clone())
.github_oauth_from_config(&config)
.oauth_providers(oauth_providers)
.trustpub_providers(&list("TRUSTPUB_PROVIDERS")?)
.emails(emails)
.storage_from_config(&config.storage)
Expand Down Expand Up @@ -73,6 +87,18 @@ fn main() -> anyhow::Result<()> {

// Block the main thread until the server has shutdown
rt.block_on(async {
// Run startup preflight checks before accepting any connections.
{
let mut conn = app
.db_write()
.await
.map_err(std::io::Error::other)?;
check_oauth_github_backfill(&mut conn)
.await
.map_err(std::io::Error::other)?;
tracing::info!("oauth_github preflight check passed");
}

// Create a `TcpListener` using tokio.
let listener = TcpListener::bind((app.config.ip, app.config.port)).await?;

Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub mod index;
mod licenses;
pub mod metrics;
pub mod middleware;
pub mod oauth;
pub mod openapi;
pub mod rate_limiter;
mod router;
Expand Down
230 changes: 230 additions & 0 deletions src/oauth/github_provider.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
//! [`OAuthProvider`] implementation backed by GitHub OAuth2.
//!
//! This wraps the existing GitHub OAuth2 client ([`BasicClient`]) and the
//! [`GitHubClient`] API client that were already present in [`crate::app::App`].
//! No new HTTP plumbing is introduced here — this is purely a delegation layer
//! that maps GitHub-specific types to the provider-agnostic trait surface.

use std::sync::Arc;

use async_trait::async_trait;
use crates_io_github::{GitHubClient, GitHubError};
use oauth2::basic::{BasicClient, BasicErrorResponseType};
use oauth2::{AccessToken, AuthorizationCode, CsrfToken, EndpointNotSet, EndpointSet, RequestTokenError, Scope, TokenResponse};
use url::Url;

use crate::util::oauth::ReqwestClient;

use super::provider::{OAuthProvider, ProviderError, UserInfo};

/// Type alias matching the field type in [`crate::app::App`].
pub type GithubBasicClient =
BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>;

pub struct GitHubProvider {
oauth: GithubBasicClient,
client: Arc<dyn GitHubClient>,
http: reqwest::Client,
}

impl GitHubProvider {
pub fn new(
oauth: GithubBasicClient,
client: Arc<dyn GitHubClient>,
http: reqwest::Client,
) -> Self {
Self { oauth, client, http }
}
}

#[async_trait]
impl OAuthProvider for GitHubProvider {
fn name(&self) -> &'static str {
"github"
}

fn authorize_url(&self) -> (Url, CsrfToken) {
self.oauth
.authorize_url(CsrfToken::new_random)
.add_scope(Scope::new("read:org".to_string()))
.url()
}

async fn exchange_code(&self, code: &str) -> Result<AccessToken, ProviderError> {
let http = ReqwestClient(self.http.clone());
let token_result = self
.oauth
.exchange_code(AuthorizationCode::new(code.to_string()))
.request_async(&http)
.await;

match token_result {
Ok(response) => Ok(response.access_token().clone()),
Err(RequestTokenError::Request(e)) => Err(ProviderError::Transient {
source: Box::new(e),
}),
Err(RequestTokenError::ServerResponse(resp)) => {
// check if this is an "invalid code" error via direct enum matching
let is_invalid_code = matches!(resp.error(), BasicErrorResponseType::InvalidGrant)
|| matches!(resp.error(), BasicErrorResponseType::Extension(s) if s == "bad_verification_code");

if is_invalid_code {
Err(ProviderError::InvalidCode)
} else {
// Format the server error as a string — `StandardErrorResponse`
// implements `Display` but not `std::error::Error`.
Err(ProviderError::Malformed(format!("{resp}")))
}
}
Err(RequestTokenError::Parse(e, _bytes)) => {
Err(ProviderError::Malformed(e.to_string()))
}
// The `RequestTokenError` enum is non-exhaustive; any future
// variant is treated as a transient infrastructure error.
Err(e) => Err(ProviderError::Malformed(e.to_string())),
}
}

async fn fetch_user_info(&self, token: &AccessToken) -> Result<UserInfo, ProviderError> {
match self.client.current_user(token).await {
Ok(gh) => Ok(UserInfo {
account_id: gh.id.to_string(),
login: gh.login,
name: gh.name,
avatar_url: gh.avatar_url,
email: gh.email,
}),
Err(GitHubError::Unauthorized(_)) => Err(ProviderError::Unauthorized),
Err(e) => Err(ProviderError::Transient {
source: anyhow::Error::from(e).into(),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crates_io_github::{GitHubError, GitHubUser, MockGitHubClient};

fn build_test_oauth_client() -> GithubBasicClient {
use oauth2::{AuthUrl, ClientId, ClientSecret, TokenUrl};
BasicClient::new(ClientId::new("test-id".to_string()))
.set_client_secret(ClientSecret::new("test-secret".to_string()))
.set_auth_uri(
AuthUrl::new("https://github.com/login/oauth/authorize".into()).unwrap(),
)
.set_token_uri(
TokenUrl::new("https://github.com/login/oauth/access_token".into()).unwrap(),
)
}

fn build_test_provider(mock: MockGitHubClient) -> GitHubProvider {
GitHubProvider::new(
build_test_oauth_client(),
Arc::new(mock),
reqwest::Client::new(),
)
}

#[test]
fn name_is_github() {
let provider = build_test_provider(MockGitHubClient::new());
assert_eq!(provider.name(), "github");
}

#[test]
fn authorize_url_contains_client_id_and_read_org_scope() {
let provider = build_test_provider(MockGitHubClient::new());
let (url, _csrf) = provider.authorize_url();
let query = url.query().unwrap_or_default();
assert!(
query.contains("client_id=test-id"),
"expected client_id=test-id in query, got: {query}"
);
assert!(
query.contains("scope=read%3Aorg"),
"expected scope=read%3Aorg in query, got: {query}"
);
}

#[tokio::test]
async fn fetch_user_info_converts_github_user_to_user_info() {
let mut mock = MockGitHubClient::new();
mock.expect_current_user().returning(|_| {
Ok(GitHubUser {
id: 42,
login: "octocat".to_string(),
name: Some("Octo Cat".to_string()),
avatar_url: Some("https://example.com/avatar.png".to_string()),
email: Some("octocat@example.com".to_string()),
})
});

let provider = build_test_provider(mock);
let token = AccessToken::new("test-token".to_string());
let info = provider.fetch_user_info(&token).await.unwrap();

assert_eq!(info.account_id, "42");
assert_eq!(info.login, "octocat");
assert_eq!(info.name, Some("Octo Cat".to_string()));
assert_eq!(
info.avatar_url,
Some("https://example.com/avatar.png".to_string())
);
assert_eq!(info.email, Some("octocat@example.com".to_string()));
}

#[tokio::test]
async fn fetch_user_info_maps_none_optional_fields() {
let mut mock = MockGitHubClient::new();
mock.expect_current_user().returning(|_| {
Ok(GitHubUser {
id: 1,
login: "ghost".to_string(),
name: None,
avatar_url: None,
email: None,
})
});

let provider = build_test_provider(mock);
let token = AccessToken::new("test-token".to_string());
let info = provider.fetch_user_info(&token).await.unwrap();

assert_eq!(info.name, None);
assert_eq!(info.avatar_url, None);
assert_eq!(info.email, None);
}

#[tokio::test]
async fn fetch_user_info_maps_401_to_unauthorized() {
let mut mock = MockGitHubClient::new();
mock.expect_current_user()
.returning(|_| Err(GitHubError::Unauthorized(anyhow::anyhow!("401 Unauthorized"))));

let provider = build_test_provider(mock);
let token = AccessToken::new("bad-token".to_string());
let err = provider.fetch_user_info(&token).await.unwrap_err();

assert!(
matches!(err, ProviderError::Unauthorized),
"expected Unauthorized, got: {err:?}"
);
}

#[tokio::test]
async fn fetch_user_info_maps_other_errors_to_transient() {
let mut mock = MockGitHubClient::new();
mock.expect_current_user()
.returning(|_| Err(GitHubError::Other(anyhow::anyhow!("500 server died"))));

let provider = build_test_provider(mock);
let token = AccessToken::new("test-token".to_string());
let err = provider.fetch_user_info(&token).await.unwrap_err();

assert!(
matches!(err, ProviderError::Transient { .. }),
"expected Transient, got: {err:?}"
);
}
}
4 changes: 4 additions & 0 deletions src/oauth/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod github_provider;
pub mod preflight;
pub mod provider;
pub mod registry;
Loading