-
Notifications
You must be signed in to change notification settings - Fork 710
Genericize OAuth profiles #13477
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Genericize OAuth profiles #13477
Changes from 2 commits
319fa39
633d64d
6662fb9
239b17c
12b4ca8
22b3133
9c6eadc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| -- This migration is a data backfill and has no meaningful inverse. Rolling | ||
| -- it back would require distinguishing backfilled rows from rows inserted | ||
| -- by login, which we cannot do. The `oauth_github` table itself is dropped | ||
| -- by the earlier migration 2026-01-20-162913_oauth_github_table/down.sql. | ||
| -- Intentionally a no-op. | ||
| SELECT 1; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,26 @@ | ||
| -- Backfill oauth_github for users who haven't logged in since the table | ||
| -- was created. Batched by id range, idempotent via ON CONFLICT DO NOTHING. | ||
|
|
||
| SET LOCAL lock_timeout = '10s'; | ||
| SET LOCAL statement_timeout = '120s'; | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you elaborate on why this commit is needed? After we merged #12792, we ran the SQL in https://github.com/rust-lang/crates.io/blob/150d165e4d5ecda24b60f146589db4def7fd34bf/migrations/data_oauth_github.sql outside of the migration framework that runs on deploy, so the backfill has already been handled. |
||
|
|
||
| DO $$ | ||
| DECLARE | ||
| lo INT; | ||
| hi INT; | ||
| pos INT; | ||
| BEGIN | ||
| SELECT MIN(id), MAX(id) INTO lo, hi FROM users WHERE gh_id > 0; | ||
| IF lo IS NULL THEN RETURN; END IF; | ||
|
|
||
| pos := lo; | ||
| WHILE pos <= hi LOOP | ||
| INSERT INTO oauth_github (account_id, user_id, encrypted_token, login, avatar) | ||
| SELECT gh_id, id, gh_encrypted_token, gh_login, gh_avatar | ||
| FROM users | ||
| WHERE gh_id > 0 AND id >= pos AND id < pos + 5000 | ||
| ON CONFLICT (account_id) DO NOTHING; | ||
|
|
||
| pos := pos + 5000; | ||
| END LOOP; | ||
| END $$; | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
| extern crate tracing; | ||
|
|
||
| use crates_io::middleware::normalize_path::normalize_path; | ||
| use crates_io::oauth::preflight::check_oauth_github_backfill; | ||
| use crates_io::{App, Emails, metrics::LogEncoder}; | ||
| use std::{sync::Arc, time::Duration}; | ||
|
|
||
|
|
@@ -33,13 +34,26 @@ fn main() -> anyhow::Result<()> { | |
| let user_agent = crates_io_version::user_agent(); | ||
| let client = Client::builder().user_agent(user_agent).build()?; | ||
|
|
||
| let github = RealGitHubClient::new(client); | ||
| let github = Box::new(github); | ||
| let github: std::sync::Arc<dyn crates_io_github::GitHubClient> = | ||
| std::sync::Arc::new(RealGitHubClient::new(client.clone())); | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. IMO this commit should be pulled out into a separate PR from anything having to do with usernames-- we want to introduce the crates.io username first, separately from adding any other oauth provider. |
||
|
|
||
| // Build the ProviderRegistry — currently only GitHub, wired up here so | ||
| // Commit 3 can route session::begin/authorize through it. | ||
| let github_oauth_for_provider = crates_io::app::build_github_oauth_client(&config); | ||
| let mut oauth_providers = crates_io::oauth::registry::ProviderRegistry::new(); | ||
| oauth_providers.register(std::sync::Arc::new( | ||
| crates_io::oauth::github_provider::GitHubProvider::new( | ||
| github_oauth_for_provider, | ||
| github.clone(), | ||
| client.clone(), | ||
| ), | ||
| )); | ||
|
|
||
| let app = App::builder() | ||
| .databases_from_config(&config.db) | ||
| .github(github) | ||
| .github(github.clone()) | ||
| .github_oauth_from_config(&config) | ||
| .oauth_providers(oauth_providers) | ||
| .trustpub_providers(&list("TRUSTPUB_PROVIDERS")?) | ||
| .emails(emails) | ||
| .storage_from_config(&config.storage) | ||
|
|
@@ -73,6 +87,18 @@ fn main() -> anyhow::Result<()> { | |
|
|
||
| // Block the main thread until the server has shutdown | ||
| rt.block_on(async { | ||
| // Run startup preflight checks before accepting any connections. | ||
| { | ||
| let mut conn = app | ||
| .db_write() | ||
| .await | ||
| .map_err(std::io::Error::other)?; | ||
| check_oauth_github_backfill(&mut conn) | ||
| .await | ||
| .map_err(std::io::Error::other)?; | ||
| tracing::info!("oauth_github preflight check passed"); | ||
| } | ||
|
|
||
| // Create a `TcpListener` using tokio. | ||
| let listener = TcpListener::bind((app.config.ip, app.config.port)).await?; | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,230 @@ | ||
| //! [`OAuthProvider`] implementation backed by GitHub OAuth2. | ||
| //! | ||
| //! This wraps the existing GitHub OAuth2 client ([`BasicClient`]) and the | ||
| //! [`GitHubClient`] API client that were already present in [`crate::app::App`]. | ||
| //! No new HTTP plumbing is introduced here — this is purely a delegation layer | ||
| //! that maps GitHub-specific types to the provider-agnostic trait surface. | ||
|
|
||
| use std::sync::Arc; | ||
|
|
||
| use async_trait::async_trait; | ||
| use crates_io_github::{GitHubClient, GitHubError}; | ||
| use oauth2::basic::{BasicClient, BasicErrorResponseType}; | ||
| use oauth2::{AccessToken, AuthorizationCode, CsrfToken, EndpointNotSet, EndpointSet, RequestTokenError, Scope, TokenResponse}; | ||
| use url::Url; | ||
|
|
||
| use crate::util::oauth::ReqwestClient; | ||
|
|
||
| use super::provider::{OAuthProvider, ProviderError, UserInfo}; | ||
|
|
||
| /// Type alias matching the field type in [`crate::app::App`]. | ||
| pub type GithubBasicClient = | ||
| BasicClient<EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet>; | ||
|
|
||
| pub struct GitHubProvider { | ||
| oauth: GithubBasicClient, | ||
| client: Arc<dyn GitHubClient>, | ||
| http: reqwest::Client, | ||
| } | ||
|
|
||
| impl GitHubProvider { | ||
| pub fn new( | ||
| oauth: GithubBasicClient, | ||
| client: Arc<dyn GitHubClient>, | ||
| http: reqwest::Client, | ||
| ) -> Self { | ||
| Self { oauth, client, http } | ||
| } | ||
| } | ||
|
|
||
| #[async_trait] | ||
| impl OAuthProvider for GitHubProvider { | ||
| fn name(&self) -> &'static str { | ||
| "github" | ||
| } | ||
|
|
||
| fn authorize_url(&self) -> (Url, CsrfToken) { | ||
| self.oauth | ||
| .authorize_url(CsrfToken::new_random) | ||
| .add_scope(Scope::new("read:org".to_string())) | ||
| .url() | ||
| } | ||
|
|
||
| async fn exchange_code(&self, code: &str) -> Result<AccessToken, ProviderError> { | ||
| let http = ReqwestClient(self.http.clone()); | ||
| let token_result = self | ||
| .oauth | ||
| .exchange_code(AuthorizationCode::new(code.to_string())) | ||
| .request_async(&http) | ||
| .await; | ||
|
|
||
| match token_result { | ||
| Ok(response) => Ok(response.access_token().clone()), | ||
| Err(RequestTokenError::Request(e)) => Err(ProviderError::Transient { | ||
| source: Box::new(e), | ||
| }), | ||
| Err(RequestTokenError::ServerResponse(resp)) => { | ||
| // check if this is an "invalid code" error via direct enum matching | ||
| let is_invalid_code = matches!(resp.error(), BasicErrorResponseType::InvalidGrant) | ||
| || matches!(resp.error(), BasicErrorResponseType::Extension(s) if s == "bad_verification_code"); | ||
|
|
||
| if is_invalid_code { | ||
| Err(ProviderError::InvalidCode) | ||
| } else { | ||
| // Format the server error as a string — `StandardErrorResponse` | ||
| // implements `Display` but not `std::error::Error`. | ||
| Err(ProviderError::Malformed(format!("{resp}"))) | ||
| } | ||
| } | ||
| Err(RequestTokenError::Parse(e, _bytes)) => { | ||
| Err(ProviderError::Malformed(e.to_string())) | ||
| } | ||
| // The `RequestTokenError` enum is non-exhaustive; any future | ||
| // variant is treated as a transient infrastructure error. | ||
| Err(e) => Err(ProviderError::Malformed(e.to_string())), | ||
| } | ||
| } | ||
|
|
||
| async fn fetch_user_info(&self, token: &AccessToken) -> Result<UserInfo, ProviderError> { | ||
| match self.client.current_user(token).await { | ||
| Ok(gh) => Ok(UserInfo { | ||
| account_id: gh.id.to_string(), | ||
| login: gh.login, | ||
| name: gh.name, | ||
| avatar_url: gh.avatar_url, | ||
| email: gh.email, | ||
| }), | ||
| Err(GitHubError::Unauthorized(_)) => Err(ProviderError::Unauthorized), | ||
| Err(e) => Err(ProviderError::Transient { | ||
| source: anyhow::Error::from(e).into(), | ||
| }), | ||
| } | ||
| } | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
| use crates_io_github::{GitHubError, GitHubUser, MockGitHubClient}; | ||
|
|
||
| fn build_test_oauth_client() -> GithubBasicClient { | ||
| use oauth2::{AuthUrl, ClientId, ClientSecret, TokenUrl}; | ||
| BasicClient::new(ClientId::new("test-id".to_string())) | ||
| .set_client_secret(ClientSecret::new("test-secret".to_string())) | ||
| .set_auth_uri( | ||
| AuthUrl::new("https://github.com/login/oauth/authorize".into()).unwrap(), | ||
| ) | ||
| .set_token_uri( | ||
| TokenUrl::new("https://github.com/login/oauth/access_token".into()).unwrap(), | ||
| ) | ||
| } | ||
|
|
||
| fn build_test_provider(mock: MockGitHubClient) -> GitHubProvider { | ||
| GitHubProvider::new( | ||
| build_test_oauth_client(), | ||
| Arc::new(mock), | ||
| reqwest::Client::new(), | ||
| ) | ||
| } | ||
|
|
||
| #[test] | ||
| fn name_is_github() { | ||
| let provider = build_test_provider(MockGitHubClient::new()); | ||
| assert_eq!(provider.name(), "github"); | ||
| } | ||
|
|
||
| #[test] | ||
| fn authorize_url_contains_client_id_and_read_org_scope() { | ||
| let provider = build_test_provider(MockGitHubClient::new()); | ||
| let (url, _csrf) = provider.authorize_url(); | ||
| let query = url.query().unwrap_or_default(); | ||
| assert!( | ||
| query.contains("client_id=test-id"), | ||
| "expected client_id=test-id in query, got: {query}" | ||
| ); | ||
| assert!( | ||
| query.contains("scope=read%3Aorg"), | ||
| "expected scope=read%3Aorg in query, got: {query}" | ||
| ); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn fetch_user_info_converts_github_user_to_user_info() { | ||
| let mut mock = MockGitHubClient::new(); | ||
| mock.expect_current_user().returning(|_| { | ||
| Ok(GitHubUser { | ||
| id: 42, | ||
| login: "octocat".to_string(), | ||
| name: Some("Octo Cat".to_string()), | ||
| avatar_url: Some("https://example.com/avatar.png".to_string()), | ||
| email: Some("octocat@example.com".to_string()), | ||
| }) | ||
| }); | ||
|
|
||
| let provider = build_test_provider(mock); | ||
| let token = AccessToken::new("test-token".to_string()); | ||
| let info = provider.fetch_user_info(&token).await.unwrap(); | ||
|
|
||
| assert_eq!(info.account_id, "42"); | ||
| assert_eq!(info.login, "octocat"); | ||
| assert_eq!(info.name, Some("Octo Cat".to_string())); | ||
| assert_eq!( | ||
| info.avatar_url, | ||
| Some("https://example.com/avatar.png".to_string()) | ||
| ); | ||
| assert_eq!(info.email, Some("octocat@example.com".to_string())); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn fetch_user_info_maps_none_optional_fields() { | ||
| let mut mock = MockGitHubClient::new(); | ||
| mock.expect_current_user().returning(|_| { | ||
| Ok(GitHubUser { | ||
| id: 1, | ||
| login: "ghost".to_string(), | ||
| name: None, | ||
| avatar_url: None, | ||
| email: None, | ||
| }) | ||
| }); | ||
|
|
||
| let provider = build_test_provider(mock); | ||
| let token = AccessToken::new("test-token".to_string()); | ||
| let info = provider.fetch_user_info(&token).await.unwrap(); | ||
|
|
||
| assert_eq!(info.name, None); | ||
| assert_eq!(info.avatar_url, None); | ||
| assert_eq!(info.email, None); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn fetch_user_info_maps_401_to_unauthorized() { | ||
| let mut mock = MockGitHubClient::new(); | ||
| mock.expect_current_user() | ||
| .returning(|_| Err(GitHubError::Unauthorized(anyhow::anyhow!("401 Unauthorized")))); | ||
|
|
||
| let provider = build_test_provider(mock); | ||
| let token = AccessToken::new("bad-token".to_string()); | ||
| let err = provider.fetch_user_info(&token).await.unwrap_err(); | ||
|
|
||
| assert!( | ||
| matches!(err, ProviderError::Unauthorized), | ||
| "expected Unauthorized, got: {err:?}" | ||
| ); | ||
| } | ||
|
|
||
| #[tokio::test] | ||
| async fn fetch_user_info_maps_other_errors_to_transient() { | ||
| let mut mock = MockGitHubClient::new(); | ||
| mock.expect_current_user() | ||
| .returning(|_| Err(GitHubError::Other(anyhow::anyhow!("500 server died")))); | ||
|
|
||
| let provider = build_test_provider(mock); | ||
| let token = AccessToken::new("test-token".to_string()); | ||
| let err = provider.fetch_user_info(&token).await.unwrap_err(); | ||
|
|
||
| assert!( | ||
| matches!(err, ProviderError::Transient { .. }), | ||
| "expected Transient, got: {err:?}" | ||
| ); | ||
| } | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| pub mod github_provider; | ||
| pub mod preflight; | ||
| pub mod provider; | ||
| pub mod registry; |
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah any comment corrections should be made in a separate commit, possibly a separate PR
View changes since the review