Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
535 changes: 243 additions & 292 deletions passwords/api/Cargo.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions passwords/api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ serde_json = "1.0"
http = "1.0"
tower_governor = "0.8.0"
dashmap = "6"
axum-prometheus = "0.10"
4 changes: 4 additions & 0 deletions passwords/api/fly.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ primary_region = 'ord'
method = 'GET'
path = '/api/v2/generate'

[metrics]
port = 8000
path = "/metrics"

[[vm]]
memory = '256mb'
cpus = 1
Expand Down
78 changes: 62 additions & 16 deletions passwords/api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,13 @@ use axum::{
routing::{get, post},
Json, Router,
};
use axum_prometheus::metrics_exporter_prometheus::PrometheusHandle;
use axum_prometheus::PrometheusMetricLayer;
use db::DbError;
use encrypt::{generate_password, Credentials, CryptoError};
use env::EnvVars;
use serde::Deserialize;
use std::sync::OnceLock;
use tower_governor::governor::GovernorConfigBuilder;
use tower_governor::GovernorLayer;
use tower_http::cors::CorsLayer;
Expand All @@ -30,7 +33,28 @@ const RATE_LIMIT_REPLENISH_PERIOD_MS: u64 = 100;

/// Maximum burst size — the number of requests a client can make
/// before being throttled.
const RATE_LIMIT_BURST_SIZE: u32 = 10;
pub const RATE_LIMIT_BURST_SIZE: u32 = 10;

// ---------------------------------------------------------------------------
// Router configuration
// ---------------------------------------------------------------------------

/// Configuration for building the application router.
///
/// Use [`RouterConfig::default()`] for production settings, or construct
/// manually to override values (e.g. in tests).
pub struct RouterConfig {
/// Maximum number of requests a client can make before being throttled.
pub burst_size: u32,
}

impl Default for RouterConfig {
fn default() -> Self {
Self {
burst_size: RATE_LIMIT_BURST_SIZE,
}
}
}

fn is_valid_key_length(key: &str) -> bool {
key.len() <= MAX_KEY_LENGTH
Expand Down Expand Up @@ -251,15 +275,30 @@ async fn delete_user(creds: Credentials) -> Result<StatusCode, Error> {
Ok(StatusCode::OK)
}

// ---------------------------------------------------------------------------
// Prometheus metrics (initialized at most once per process)
// ---------------------------------------------------------------------------

/// Stores the Prometheus metric layer and handle so that
/// `PrometheusMetricLayer::pair()` (which installs a global recorder) is
/// called at most once. Subsequent calls to `prometheus_pair()` clone the
/// stored values.
static PROMETHEUS: OnceLock<(PrometheusMetricLayer<'static>, PrometheusHandle)> = OnceLock::new();

/// Return a `(layer, handle)` pair, creating it on the first call and
/// cloning the cached values on every subsequent call.
fn prometheus_pair() -> (PrometheusMetricLayer<'static>, PrometheusHandle) {
PROMETHEUS
.get_or_init(PrometheusMetricLayer::pair)
.clone()
}

// ---------------------------------------------------------------------------
// Application builder
// ---------------------------------------------------------------------------

/// Build the application router, using the supplied burst size for the
/// rate limiter. The normal entry point `build_router()` calls this with
/// [`RATE_LIMIT_BURST_SIZE`]. Tests may pass a much larger value to avoid
/// accidental 429s during their busy request sequences.
pub fn build_router_with_burst(burst_size: u32) -> Router {
/// Register all application routes (including conditional test-only routes).
fn app_routes() -> Router {
let app = Router::new()
.route("/api/v2/generate", get(generate))
.route("/api/v2/user", post(create_user).put(update_user))
Expand All @@ -277,6 +316,20 @@ pub fn build_router_with_burst(burst_size: u32) -> Router {
#[cfg(any(test, debug_assertions, feature = "test-helpers"))]
let app = app.route("/api/v2/user", axum::routing::delete(delete_user));

app
}

/// Build the application [`Router`] with middleware configured via [`RouterConfig`].
///
/// Safe to call multiple times — the Prometheus recorder is initialised once
/// and reused.
pub fn build_router(config: RouterConfig) -> Router {
let burst_size = config.burst_size;
let (prometheus_layer, metric_handle) = prometheus_pair();

let app = app_routes()
.route("/metrics", get(|| async move { metric_handle.render() }));

// Build the rate limiter configuration.
let mut rate_limit_builder = GovernorConfigBuilder::default()
.const_per_millisecond(RATE_LIMIT_REPLENISH_PERIOD_MS)
Expand All @@ -285,20 +338,13 @@ pub fn build_router_with_burst(burst_size: u32) -> Router {
.finish()
.expect("invalid rate-limit configuration");

// Layers wrap routes that were registered *before* the .layer() call.
// Order (outermost → innermost): CORS → rate-limit → tracing → handler.
// CORS must be outermost so preflight OPTIONS responses are never blocked
// by the rate limiter.
app.layer(TraceLayer::new_for_http())
// .layer() is last-added = outermost; read bottom-to-top for execution order.
app.layer(prometheus_layer)
.layer(TraceLayer::new_for_http())
.layer(GovernorLayer::new(rate_limit_config))
.layer(cors_layer())
}

/// Convenience wrapper used throughout the production binary.
pub fn build_router() -> Router {
build_router_with_burst(RATE_LIMIT_BURST_SIZE)
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
4 changes: 2 additions & 2 deletions passwords/api/src/main.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use passwords::{build_router, db};
use passwords::{build_router, db, RouterConfig};
use std::net::SocketAddr;
use tokio::net::TcpListener;
use tracing_subscriber::EnvFilter;
Expand All @@ -17,7 +17,7 @@ async fn main() -> Result<(), anyhow::Error> {
}
db::connect().await?;

let app = build_router().into_make_service_with_connect_info::<SocketAddr>();
let app = build_router(RouterConfig::default()).into_make_service_with_connect_info::<SocketAddr>();
let listener = TcpListener::bind("0.0.0.0:8000").await?;
tracing::info!(addr = %listener.local_addr()?, "listening");
axum::serve(listener, app).await?;
Expand Down
4 changes: 2 additions & 2 deletions passwords/api/tests/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use axum::body::Body;
use axum::{Router, middleware::from_fn, extract::ConnectInfo};
use http::Request;
use mongodb::bson::oid::ObjectId;
use passwords::build_router_with_burst;
use passwords::{build_router, RouterConfig};
use passwords::db;
use std::net::SocketAddr;
use std::sync::LazyLock;
Expand All @@ -32,7 +32,7 @@ static APP: LazyLock<Router> = LazyLock::new(|| {
db::connect().await.expect("Failed to connect to test DB");
// use a very large burst so ordinary tests aren't disrupted by our
// rate limiter; stress test will create its own router below.
build_router_with_burst(1_000_000)
build_router(RouterConfig { burst_size: 1_000_000 })
})
});

Expand Down
4 changes: 2 additions & 2 deletions passwords/api/tests/integration_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use axum::body::Body;
use axum::{middleware::from_fn, extract::ConnectInfo};
use common::{app, body_string, parse_json, run, TestUser, WithAuth};
use http::{Request, StatusCode};
use passwords::build_router;
use passwords::{build_router, RouterConfig};
use std::net::SocketAddr;
use std::time::Duration;
use tower::ServiceExt;
Expand Down Expand Up @@ -69,7 +69,7 @@ fn test_rate_limiting() {
// observe throttling. We still need the connect-info middleware that
// `app()` adds, so copy that behaviour.
let addr = SocketAddr::from(([127, 0, 0, 1], 0));
let limited = build_router()
let limited = build_router(RouterConfig { burst_size: 10 })
.layer(from_fn(move |mut req: Request<Body>, next: axum::middleware::Next| async move {
req.extensions_mut().insert(ConnectInfo(addr));
next.run(req).await
Expand Down
49 changes: 49 additions & 0 deletions passwords/api/tests/metrics_tests.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
//! Metrics endpoint tests for the MapoPass API.
//!
//! Verifies that the Prometheus metrics endpoint is working correctly.
//!
//! ## Running
//!
//! ```sh
//! # From passwords/api/:
//! cargo test --test metrics_tests --features test-helpers
//! ```

mod common;

use axum::body::Body;
use common::{app, body_string, run};
use http::{Request, StatusCode};
use tower::ServiceExt;

#[test]
fn test_metrics_endpoint() {
run(async {
// Make a request first so that metrics are recorded.
let warmup = Request::builder()
.method("GET")
.uri("/api/v2/generate")
.body(Body::empty())
.unwrap();
let warmup_res = app().oneshot(warmup).await.unwrap();
assert_eq!(warmup_res.status(), StatusCode::OK);

// Now fetch /metrics and verify the Prometheus exposition format.
let req = Request::builder()
.method("GET")
.uri("/metrics")
.body(Body::empty())
.unwrap();
let res = app().oneshot(req).await.unwrap();
assert_eq!(res.status(), StatusCode::OK);
let body = body_string(res).await;
assert!(
body.contains("axum_http_requests_total"),
"expected axum_http_requests_total counter in /metrics response",
);
assert!(
body.contains("axum_http_requests_duration_seconds"),
"expected axum_http_requests_duration_seconds histogram in /metrics response",
);
});
}
Loading