diff --git a/benchmarks/README.md b/benchmarks/README.md
new file mode 100644
index 00000000..d5412a98
--- /dev/null
+++ b/benchmarks/README.md
@@ -0,0 +1,156 @@
+# Benchmarks
+
+Suite de benchmarks mantenible para Hypercorn. Su funcion es validar cambios de rendimiento con una metodologia reproducible y suficientemente estable como para vivir en `main`.
+
+## Objetivos
+
+- medir mejoras y regresiones reales en rutas criticas;
+- comparar ramas o refs contra un baseline reproducible;
+- separar benchmarks generales de escenarios dirigidos;
+- priorizar resultados robustos frente a cifras puntuales.
+
+## Metodologia por defecto
+
+Todos los comparadores `benchmarks/compare*.py` siguen estas reglas por defecto:
+
+- intercalan runs de `current` y `baseline`;
+- resumen por medianas de run;
+- exponen `mean`, `median` y `p95`;
+- dejan `--sequential` solo para diagnostico o compatibilidad historica.
+
+Esto reduce sesgos por orden de ejecucion, calentamiento, drift del sistema y ruido de la maquina.
+
+## Estructura
+
+- `benchmarks/app.py`
+  App ASGI minima usada como objetivo de los benchmarks de servidor.
+- `benchmarks/_runtime.py`
+  Infraestructura compartida de servidor, TLS, readiness, puertos y percentiles.
+- `benchmarks/_compare.py`
+  Infraestructura compartida de worktrees, intercalado, resumen y salida JSON.
+- `benchmarks/run_load.py`
+  Benchmark dirigido de HoL en HTTP/2.
+- `benchmarks/fragmented_body.py`
+  Benchmark dirigido de request body H2 muy fragmentado.
+- `benchmarks/general.py`
+  Benchmark general para `/fast` y otras rutas HTTP.
+- `benchmarks/ws.py`
+  Benchmark de eco WebSocket.
+- `benchmarks/h3.py`
+  Benchmark real de HTTP/3 sobre QUIC con `aioquic`.
+- `benchmarks/task_group.py`
+  Microbenchmark de `TaskGroup.spawn_app()`.
+- `benchmarks/compare*.py`
+  Comparadores contra otro ref o repo.
+
+## Escenarios disponibles
+
+### HoL HTTP/2
+
+Una sola conexion HTTP/2, un stream lento y varios streams rapidos multiplexados. Mide si existe bloqueo global a nivel de conexion.
+
+```bash
+python -m benchmarks.run_load
+python -m benchmarks.compare --baseline-ref upstream/main
+```
+
+### Body fragmentado HTTP/2
+
+Ejercita `QueuedStream` y el coste de entregar muchos `DATA` pequenos al app ASGI.
+
+```bash
+python -m benchmarks.fragmented_body
+python -m benchmarks.compare_fragmented_body --baseline-ref upstream/main
+```
+
+### Benchmark general HTTP
+
+Sirve para medir el camino rapido sin mezclarlo con escenarios artificiales.
+
+```bash
+python -m benchmarks.general --http-version 1.1
+python -m benchmarks.general --http-version 1.1 --tls
+python -m benchmarks.general --http-version 2
+python -m benchmarks.compare_general --baseline-ref upstream/main
+```
+
+El comparador general ejecuta:
+
+- HTTP/1.1 sin TLS
+- HTTP/1.1 con TLS
+- HTTP/2 con TLS
+
+### WebSocket echo
+
+Valida handshake y eco binario con payload configurable.
+
+```bash
+python -m benchmarks.ws --tls
+python -m benchmarks.compare_ws --baseline-ref upstream/main --tls
+```
+
+### HTTP/3 real
+
+Mide QUIC/H3 real con una conexion H3 multiplexada y un cliente `aioquic`.
+
+```bash
+python -m benchmarks.h3
+python -m benchmarks.compare_h3 --baseline-ref upstream/main
+```
+
+### TaskGroup
+
+Microbenchmark de investigacion para separar el coste fijo de `TaskGroup.spawn_app()` del servidor completo.
+
+```bash
+python -m benchmarks.task_group --mode asgi
+python -m benchmarks.task_group --mode wsgi
+python -m benchmarks.compare_task_group --mode asgi --baseline-ref upstream/main
+python -m benchmarks.compare_task_group --mode wsgi --baseline-ref upstream/main
+```
+
+## Comandos habituales
+
+Comparar contra `upstream/main`:
+
+```bash
+python -m benchmarks.compare_general --baseline-ref upstream/main --runs 6
+python -m benchmarks.compare --baseline-ref upstream/main --runs 6
+python -m benchmarks.compare_h3 --baseline-ref upstream/main --runs 4
+```
+
+Comparar contra un repo ya existente sin crear worktree:
+
+```bash
+python -m benchmarks.compare_general --baseline-path /ruta/a/otro/repo --runs 6
+```
+
+Guardar resultados en JSON:
+
+```bash
+python -m benchmarks.compare_general \
+    --baseline-ref upstream/main \
+    --runs 6 \
+    --output-json benchmarks/results/general.json
+```
+
+## Guia de interpretacion
+
+- `p95` es el guard-rail principal de cola larga.
+- `mean` y `median` ayudan a distinguir mejora general de mejora puntual.
+- `req/s` o `messages/s` sirven para leer throughput, pero no deben ocultar empeoramientos claros de latencia.
+- Los benchmarks dirigidos como HoL o body fragmentado validan cambios estructurales.
+- Los benchmarks generales detectan si una optimizacion local rompe el balance global.
+
+## Mantenimiento
+
+- Antes de aceptar un cambio de rendimiento, medir contra un baseline estable.
+- No usar una sola pasada secuencial para sacar conclusiones.
+- Si un escenario es nuevo, anadir primero un benchmark reproducible y luego el cambio.
+- Mantener la app de benchmark y los comparadores pequenos, explicitamente documentados y sin dependencias ocultas.
+
+## Notas
+
+- La suite usa `benchmarks/app.py` como app objetivo.
+- `tests/assets/cert.pem` y `tests/assets/key.pem` habilitan TLS y ALPN para H2.
+- Los resultados son comparaciones relativas entre ramas o refs, no numeros absolutos publicables.
diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/benchmarks/__init__.py @@ -0,0 +1 @@ + diff --git a/benchmarks/_compare.py b/benchmarks/_compare.py new file mode 100644 index 00000000..e36488a9 --- /dev/null +++ b/benchmarks/_compare.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import json +import shutil +import subprocess +import tempfile +from dataclasses import replace +from pathlib import Path +from typing import Any, Callable, Sequence, TypeVar + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +INTERLEAVED_METHODOLOGY = "interleaved median-of-runs" +SEQUENTIAL_METHODOLOGY = "sequential median-of-runs" + +T = TypeVar("T") + + +def create_worktree(ref: str, fetch: bool) -> tuple[Callable[[], None], Path]: + if fetch: + subprocess.run(["git", "fetch", "upstream"], cwd=PROJECT_ROOT, check=True) + + tempdir = Path(tempfile.mkdtemp(prefix="hypercorn-bench-")) + subprocess.run( + ["git", "worktree", "add", "--detach", str(tempdir), ref], + cwd=PROJECT_ROOT, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + + def cleanup() -> None: + subprocess.run( + ["git", "worktree", "remove", "--force", str(tempdir)], + cwd=PROJECT_ROOT, + check=True, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + ) + shutil.rmtree(tempdir, ignore_errors=True) + + return cleanup, tempdir + + +def percentage_improvement(current: float, baseline: float) -> float: + if baseline == 0: + return 0.0 + return ((baseline - current) / baseline) * 100 + + +def percentage_growth(current: float, baseline: float) -> float: + if baseline == 0: + return 0.0 + return ((current - baseline) / baseline) * 100 + + +def methodology_name(*, sequential: bool) -> str: + return SEQUENTIAL_METHODOLOGY if sequential else INTERLEAVED_METHODOLOGY + + +async def run_interleaved_async( + runs: int, + current_runner: Callable[[int], Any], + baseline_runner: Callable[[int], Any], + *, + 
interleave: bool, +) -> tuple[list[Any], list[Any]]: + current_runs: list[Any] = [] + baseline_runs: list[Any] = [] + for index in range(runs): + if interleave and (index % 2 == 1): + baseline_runs.append(await baseline_runner(index)) + current_runs.append(await current_runner(index)) + else: + current_runs.append(await current_runner(index)) + baseline_runs.append(await baseline_runner(index)) + return current_runs, baseline_runs + + +def run_interleaved_sync( + runs: int, + current_runner: Callable[[int], T], + baseline_runner: Callable[[int], T], + *, + interleave: bool, +) -> tuple[list[T], list[T]]: + current_runs: list[T] = [] + baseline_runs: list[T] = [] + for index in range(runs): + if interleave and (index % 2 == 1): + baseline_runs.append(baseline_runner(index)) + current_runs.append(current_runner(index)) + else: + current_runs.append(current_runner(index)) + baseline_runs.append(baseline_runner(index)) + return current_runs, baseline_runs + + +def summarize_dataclass_runs( + label: str, + runs: Sequence[T], + *, + extra_fields: dict[str, Callable[[Sequence[T]], Any]] | None = None, +) -> T: + if not runs: + raise ValueError("Expected at least one benchmark run") + + overrides: dict[str, Any] = { + "target_label": label, + "samples_ms": [sample for run in runs for sample in getattr(run, "samples_ms")], + "mean_ms": _median(getattr(run, "mean_ms") for run in runs), + "median_ms": _median(getattr(run, "median_ms") for run in runs), + "p95_ms": _median(getattr(run, "p95_ms") for run in runs), + "minimum_ms": _median(getattr(run, "minimum_ms") for run in runs), + "maximum_ms": _median(getattr(run, "maximum_ms") for run in runs), + } + if extra_fields is not None: + for field, aggregator in extra_fields.items(): + overrides[field] = aggregator(runs) + return replace(runs[0], **overrides) + + +def build_comparison_result( + current: Any, + baseline: Any, + *, + throughput_field: str | None = None, + throughput_delta_field: str | None = None, + 
throughput_improvement_field: str | None = None, +) -> dict[str, Any]: + payload = { + "current": current.__dict__, + "baseline": baseline.__dict__, + "delta_mean_ms": current.mean_ms - baseline.mean_ms, + "delta_median_ms": current.median_ms - baseline.median_ms, + "delta_p95_ms": current.p95_ms - baseline.p95_ms, + "improvement_mean_percent": percentage_improvement(current.mean_ms, baseline.mean_ms), + "improvement_median_percent": percentage_improvement(current.median_ms, baseline.median_ms), + "improvement_p95_percent": percentage_improvement(current.p95_ms, baseline.p95_ms), + } + if throughput_field is not None: + current_value = getattr(current, throughput_field) + baseline_value = getattr(baseline, throughput_field) + payload[throughput_delta_field or f"delta_{throughput_field}"] = current_value - baseline_value + payload[throughput_improvement_field or f"improvement_{throughput_field}_percent"] = percentage_growth( + current_value, baseline_value + ) + return payload + + +def write_json_output(payload: dict[str, Any], output_json: str | None) -> None: + encoded = json.dumps(payload, indent=2) + "\n" + print(encoded, end="") + if output_json is not None: + Path(output_json).write_text(encoded) + + +def _median(values) -> float: + ordered = sorted(values) + middle = len(ordered) // 2 + if len(ordered) % 2 == 1: + return float(ordered[middle]) + return (ordered[middle - 1] + ordered[middle]) / 2 diff --git a/benchmarks/_runtime.py b/benchmarks/_runtime.py new file mode 100644 index 00000000..d995ac5b --- /dev/null +++ b/benchmarks/_runtime.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +import asyncio +import os +import socket +import ssl +import subprocess +import sys +import time +from pathlib import Path + +from h2.connection import H2Connection +from h2.events import DataReceived, StreamEnded + +PROJECT_ROOT = Path(__file__).resolve().parent.parent +CERTFILE = PROJECT_ROOT / "tests" / "assets" / "cert.pem" +KEYFILE = PROJECT_ROOT / "tests" / 
"assets" / "key.pem" + + +class ServerProcess: + def __init__(self, server_repo: Path, *, tls: bool = True) -> None: + self.port = reserve_port() + self.server_repo = server_repo + self.tls = tls + self.process: subprocess.Popen[str] | None = None + + async def __aenter__(self) -> ServerProcess: + env = os.environ.copy() + env["PYTHONPATH"] = os.pathsep.join([str(self.server_repo / "src"), str(PROJECT_ROOT)]) + command = [ + sys.executable, + "-m", + "hypercorn", + "--bind", + f"127.0.0.1:{self.port}", + "--workers", + "1", + "--worker-class", + "asyncio", + ] + if self.tls: + command.extend(["--certfile", str(CERTFILE), "--keyfile", str(KEYFILE)]) + command.append("benchmarks.app:app") + self.process = subprocess.Popen( + command, + cwd=PROJECT_ROOT, + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + ) + await wait_for_ready(self.port, tls=self.tls) + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + if self.process is None: + return + self.process.terminate() + try: + await asyncio.wait_for(asyncio.to_thread(self.process.wait), timeout=5) + except asyncio.TimeoutError: + self.process.kill() + await asyncio.to_thread(self.process.wait) + + +async def wait_for_ready(port: int, *, tls: bool) -> None: + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + if not tls: + if await _wait_for_ready_http11(port): + return + await asyncio.sleep(0.05) + continue + + try: + reader, writer = await asyncio.open_connection( + "127.0.0.1", + port, + ssl=build_ssl_context(), + server_hostname="localhost", + ) + except OSError: + await asyncio.sleep(0.05) + continue + + conn = H2Connection() + conn.initiate_connection() + writer.write(conn.data_to_send()) + writer.write(build_ready_request(conn)) + await writer.drain() + + while True: + data = await asyncio.wait_for(reader.read(65535), timeout=1) + if not data: + break + for event in conn.receive_data(data): + if isinstance(event, DataReceived): + 
conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + elif isinstance(event, StreamEnded) and event.stream_id == 1: + writer.close() + await writer.wait_closed() + return + pending = conn.data_to_send() + if pending: + writer.write(pending) + await writer.drain() + writer.close() + await writer.wait_closed() + await asyncio.sleep(0.05) + raise RuntimeError(f"Timed out waiting for benchmark server on port {port}") + + +async def _wait_for_ready_http11(port: int) -> bool: + try: + reader, writer = await asyncio.open_connection("127.0.0.1", port) + except OSError: + return False + + try: + writer.write(b"GET /ready HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n") + await writer.drain() + status_line = await asyncio.wait_for(reader.readline(), timeout=1) + return status_line.startswith(b"HTTP/1.1 200") + except (OSError, asyncio.TimeoutError): + return False + finally: + writer.close() + await writer.wait_closed() + + +def build_ready_request(conn: H2Connection) -> bytes: + conn.send_headers( + 1, + [ + (":method", "GET"), + (":scheme", "https"), + (":authority", "localhost"), + (":path", "/ready"), + ], + end_stream=True, + ) + return conn.data_to_send() + + +def build_ssl_context(*, alpn_protocols: list[str] | None = None) -> ssl.SSLContext: + context = ssl.create_default_context() + context.check_hostname = False + context.verify_mode = ssl.CERT_NONE + context.set_alpn_protocols(alpn_protocols or ["h2"]) + return context + + +def percentile(values: list[float], ratio: float) -> float: + ordered = sorted(values) + index = max(0, min(len(ordered) - 1, round((len(ordered) - 1) * ratio))) + return ordered[index] + + +def reserve_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + sock.bind(("127.0.0.1", 0)) + return int(sock.getsockname()[1]) + + +def reserve_udp_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock: + sock.bind(("127.0.0.1", 0)) + return 
int(sock.getsockname()[1]) diff --git a/benchmarks/app.py b/benchmarks/app.py new file mode 100644 index 00000000..cd0afb8f --- /dev/null +++ b/benchmarks/app.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +import asyncio +from urllib.parse import parse_qs + + +async def app(scope, receive, send) -> None: + if scope["type"] == "websocket": + await _websocket_app(receive, send) + return + if scope["type"] != "http": + return + + path = scope["path"] + query = parse_qs(scope["query_string"].decode("ascii")) + delay_ms = int(query.get("delay_ms", ["0"])[0]) + + if path == "/ready": + await _send_response(send, 200, b"ready") + return + + body = bytearray() + while True: + message = await receive() + if message["type"] != "http.request": + continue + + body.extend(message.get("body", b"")) + if path == "/slow-read" and delay_ms > 0 and message.get("body", b"") != b"": + await asyncio.sleep(delay_ms / 1000) + + if not message.get("more_body", False): + break + + if path == "/fast": + await _send_response(send, 200, b"fast") + elif path == "/echo-body": + await _send_response(send, 200, str(len(body)).encode("ascii")) + elif path == "/slow-read": + await _send_response(send, 200, b"slow") + elif path == "/large-response": + chunks = int(query.get("chunks", ["64"])[0]) + chunk_size = int(query.get("chunk_size", ["16384"])[0]) + await _send_streaming_response(send, 200, chunks, chunk_size) + else: + await _send_response(send, 404, b"not-found") + + +async def _websocket_app(receive, send) -> None: + await send({"type": "websocket.accept"}) + + while True: + message = await receive() + if message["type"] == "websocket.disconnect": + return + if message["type"] != "websocket.receive": + continue + + if message.get("bytes") is not None: + await send({"type": "websocket.send", "bytes": message["bytes"]}) + else: + await send({"type": "websocket.send", "text": message["text"]}) + + +async def _send_response(send, status: int, body: bytes) -> None: + await send( + { + 
"type": "http.response.start", + "status": status, + "headers": [(b"content-length", str(len(body)).encode("ascii"))], + } + ) + await send({"type": "http.response.body", "body": body, "more_body": False}) + + +async def _send_streaming_response(send, status: int, chunks: int, chunk_size: int) -> None: + total_length = chunks * chunk_size + chunk = b"x" * chunk_size + await send( + { + "type": "http.response.start", + "status": status, + "headers": [(b"content-length", str(total_length).encode("ascii"))], + } + ) + for index in range(chunks): + await send( + { + "type": "http.response.body", + "body": chunk, + "more_body": index < (chunks - 1), + } + ) diff --git a/benchmarks/compare.py b/benchmarks/compare.py new file mode 100644 index 00000000..3cec54bf --- /dev/null +++ b/benchmarks/compare.py @@ -0,0 +1,83 @@ +from __future__ import annotations + +import argparse +import asyncio +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, + create_worktree, + methodology_name, + run_interleaved_async, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.run_load import BenchmarkConfig, run_benchmark + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + config = BenchmarkConfig( + warmup_iterations=args.warmup_iterations, + measured_iterations=args.measured_iterations, + fast_streams=args.fast_streams, + slow_chunks=args.slow_chunks, + slow_chunk_size=args.slow_chunk_size, + slow_delay_ms=args.slow_delay_ms, + ) + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + try: + current_runs, baseline_runs = await run_interleaved_async( + args.runs, + lambda index: run_benchmark(PROJECT_ROOT, f"current-run-{index + 1}", config), + lambda index: run_benchmark( + baseline_repo, + 
f"baseline:{args.baseline_ref}-run-{index + 1}", + config, + ), + interleave=not args.sequential, + ) + finally: + if cleanup is not None: + cleanup() + + current = summarize_dataclass_runs("current", current_runs) + baseline = summarize_dataclass_runs(f"baseline:{args.baseline_ref}", baseline_runs) + payload = { + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + **build_comparison_result(current, baseline), + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare local Hypercorn against upstream for the H2 HoL benchmark.") + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path", help="Optional existing repo path to use as baseline instead of creating a worktree.") + parser.add_argument("--no-fetch", action="store_true") + parser.add_argument("--warmup-iterations", type=int, default=2) + parser.add_argument("--measured-iterations", type=int, default=10) + parser.add_argument("--fast-streams", type=int, default=5) + parser.add_argument("--slow-chunks", type=int, default=12) + parser.add_argument("--slow-chunk-size", type=int, default=4096) + parser.add_argument("--slow-delay-ms", type=int, default=25) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/compare_fragmented_body.py b/benchmarks/compare_fragmented_body.py new file mode 100644 index 00000000..208fb508 --- /dev/null +++ b/benchmarks/compare_fragmented_body.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import argparse +import asyncio +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, 
+ create_worktree, + methodology_name, + run_interleaved_async, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.fragmented_body import FragmentedBodyBenchmarkConfig, run_fragmented_body_benchmark + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + config = FragmentedBodyBenchmarkConfig( + warmup_iterations=args.warmup_iterations, + measured_iterations=args.measured_iterations, + chunks=args.chunks, + chunk_size=args.chunk_size, + delay_ms=args.delay_ms, + ) + + try: + current_runs, baseline_runs = await run_interleaved_async( + args.runs, + lambda index: run_fragmented_body_benchmark( + PROJECT_ROOT, + f"current-fragmented-run-{index + 1}", + config, + ), + lambda index: run_fragmented_body_benchmark( + baseline_repo, + f"baseline-{args.baseline_ref}-fragmented-run-{index + 1}", + config, + ), + interleave=not args.sequential, + ) + finally: + if cleanup is not None: + cleanup() + + current = summarize_dataclass_runs("current-fragmented", current_runs) + baseline = summarize_dataclass_runs(f"baseline-{args.baseline_ref}-fragmented", baseline_runs) + payload = { + "baseline_ref": args.baseline_ref, + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + **build_comparison_result(current, baseline), + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Compare local Hypercorn against upstream for fragmented HTTP/2 request bodies." 
+ ) + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path") + parser.add_argument("--no-fetch", action="store_true") + parser.add_argument("--warmup-iterations", type=int, default=5) + parser.add_argument("--measured-iterations", type=int, default=50) + parser.add_argument("--chunks", type=int, default=256) + parser.add_argument("--chunk-size", type=int, default=128) + parser.add_argument("--delay-ms", type=int, default=1) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/compare_general.py b/benchmarks/compare_general.py new file mode 100644 index 00000000..7360c997 --- /dev/null +++ b/benchmarks/compare_general.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +import argparse +import asyncio +import statistics +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, + create_worktree, + methodology_name, + run_interleaved_async, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.general import GeneralBenchmarkConfig, run_general_benchmark + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + try: + results = {} + scenarios = [ + ("http1", "1.1", False), + ("http1_tls", "1.1", True), + ("http2", "2", True), + ] + for scenario_name, http_version, tls in scenarios: + config = GeneralBenchmarkConfig( + http_version=http_version, + tls=tls, + path=args.path, + concurrency=args.concurrency, + 
total_requests=args.total_requests, + warmup_requests=args.warmup_requests, + ) + current_runs, baseline_runs = await run_interleaved_async( + args.runs, + lambda index: run_general_benchmark( + PROJECT_ROOT, + f"current-{scenario_name}-run-{index + 1}", + config, + ), + lambda index: run_general_benchmark( + baseline_repo, + f"baseline-{args.baseline_ref}-{scenario_name}-run-{index + 1}", + config, + ), + interleave=not args.sequential, + ) + current = summarize_dataclass_runs( + f"current-{scenario_name}", + current_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "requests_per_second": lambda runs: statistics.median(run.requests_per_second for run in runs), + }, + ) + baseline = summarize_dataclass_runs( + f"baseline-{args.baseline_ref}-{scenario_name}", + baseline_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "requests_per_second": lambda runs: statistics.median(run.requests_per_second for run in runs), + }, + ) + results[scenario_name] = build_comparison_result( + current, + baseline, + throughput_field="requests_per_second", + throughput_delta_field="delta_rps", + throughput_improvement_field="improvement_rps_percent", + ) + finally: + if cleanup is not None: + cleanup() + + payload = { + "baseline_ref": args.baseline_ref, + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + "concurrency": args.concurrency, + "total_requests": args.total_requests, + "warmup_requests": args.warmup_requests, + "results": results, + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare local Hypercorn against upstream in general HTTP benchmarks.") + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path") + parser.add_argument("--no-fetch", action="store_true") + 
parser.add_argument("--path", default="/fast") + parser.add_argument("--concurrency", type=int, default=50) + parser.add_argument("--total-requests", type=int, default=500) + parser.add_argument("--warmup-requests", type=int, default=50) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/compare_h3.py b/benchmarks/compare_h3.py new file mode 100644 index 00000000..5715c9bb --- /dev/null +++ b/benchmarks/compare_h3.py @@ -0,0 +1,101 @@ +from __future__ import annotations + +import argparse +import asyncio +import statistics +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, + create_worktree, + methodology_name, + run_interleaved_async, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.h3 import H3BenchmarkConfig, run_h3_benchmark + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + config = H3BenchmarkConfig( + path=args.path, + concurrency=args.concurrency, + total_requests=args.total_requests, + warmup_requests=args.warmup_requests, + ) + + try: + current_runs, baseline_runs = await run_interleaved_async( + args.runs, + lambda index: run_h3_benchmark(PROJECT_ROOT, f"current-h3-run-{index + 1}", config), + lambda index: run_h3_benchmark( + baseline_repo, + f"baseline-{args.baseline_ref}-h3-run-{index + 1}", + config, + ), + interleave=not args.sequential, + ) + finally: + if cleanup is not None: + cleanup() + + current = summarize_dataclass_runs( + "current-h3", + 
current_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "requests_per_second": lambda runs: statistics.median(run.requests_per_second for run in runs), + }, + ) + baseline = summarize_dataclass_runs( + f"baseline-{args.baseline_ref}-h3", + baseline_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "requests_per_second": lambda runs: statistics.median(run.requests_per_second for run in runs), + }, + ) + payload = { + "baseline_ref": args.baseline_ref, + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + **build_comparison_result( + current, + baseline, + throughput_field="requests_per_second", + throughput_delta_field="delta_rps", + throughput_improvement_field="improvement_rps_percent", + ), + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare local Hypercorn against upstream in a real HTTP/3 benchmark.") + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path") + parser.add_argument("--no-fetch", action="store_true") + parser.add_argument("--path", default="/fast") + parser.add_argument("--concurrency", type=int, default=50) + parser.add_argument("--total-requests", type=int, default=500) + parser.add_argument("--warmup-requests", type=int, default=50) + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/compare_task_group.py b/benchmarks/compare_task_group.py new file mode 100644 index 00000000..816ec121 --- /dev/null +++ b/benchmarks/compare_task_group.py @@ -0,0 +1,102 @@ +from 
__future__ import annotations + +import argparse +import json +import os +import subprocess +import sys +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, + create_worktree, + methodology_name, + run_interleaved_sync, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.task_group import TaskGroupBenchmarkResult + + +def main() -> int: + parser = build_parser() + args = parser.parse_args() + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + try: + current_runs, baseline_runs = run_interleaved_sync( + args.runs, + lambda index: _run_for_repo(PROJECT_ROOT, f"current-{args.mode}-run-{index + 1}", args), + lambda index: _run_for_repo( + baseline_repo, + f"baseline-{args.baseline_ref}-{args.mode}-run-{index + 1}", + args, + ), + interleave=not args.sequential, + ) + finally: + if cleanup is not None: + cleanup() + + current = summarize_dataclass_runs(f"current-{args.mode}", current_runs) + baseline = summarize_dataclass_runs(f"baseline-{args.baseline_ref}-{args.mode}", baseline_runs) + payload = { + "baseline_ref": args.baseline_ref, + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + "mode": args.mode, + **build_comparison_result(current, baseline), + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare TaskGroup.spawn_app() overhead across repos.") + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path") + parser.add_argument("--no-fetch", action="store_true") + parser.add_argument("--mode", choices=["asgi", "wsgi"], default="asgi") + parser.add_argument("--warmup-iterations", type=int, default=200) + 
parser.add_argument("--measured-iterations", type=int, default=2000) + parser.add_argument("--runs", type=int, default=3) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +def _run_for_repo(server_repo: Path, label: str, args: argparse.Namespace) -> TaskGroupBenchmarkResult: + env = os.environ.copy() + env["PYTHONPATH"] = os.pathsep.join([str(server_repo / "src"), str(PROJECT_ROOT)]) + output = subprocess.check_output( + [ + sys.executable, + "-m", + "benchmarks.task_group", + "--server-repo", + str(server_repo), + "--label", + label, + "--mode", + args.mode, + "--warmup-iterations", + str(args.warmup_iterations), + "--measured-iterations", + str(args.measured_iterations), + ], + cwd=PROJECT_ROOT, + env=env, + text=True, + ) + return TaskGroupBenchmarkResult(**json.loads(output)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/benchmarks/compare_ws.py b/benchmarks/compare_ws.py new file mode 100644 index 00000000..eafbdf2a --- /dev/null +++ b/benchmarks/compare_ws.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import argparse +import asyncio +import statistics +from pathlib import Path + +from benchmarks._compare import ( + PROJECT_ROOT, + build_comparison_result, + create_worktree, + methodology_name, + run_interleaved_async, + summarize_dataclass_runs, + write_json_output, +) +from benchmarks.ws import WebsocketBenchmarkConfig, run_ws_benchmark + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + baseline_repo: Path + cleanup = None + if args.baseline_path is not None: + baseline_repo = Path(args.baseline_path).resolve() + else: + cleanup, baseline_repo = create_worktree(args.baseline_ref, fetch=not args.no_fetch) + + config = WebsocketBenchmarkConfig( + tls=args.tls, + path=args.path, + warmup_messages=args.warmup_messages, + measured_messages=args.measured_messages, 
+ payload_size=args.payload_size, + ) + + try: + current_runs, baseline_runs = await run_interleaved_async( + args.runs, + lambda index: run_ws_benchmark(PROJECT_ROOT, f"current-ws-run-{index + 1}", config), + lambda index: run_ws_benchmark( + baseline_repo, + f"baseline-{args.baseline_ref}-ws-run-{index + 1}", + config, + ), + interleave=not args.sequential, + ) + finally: + if cleanup is not None: + cleanup() + + current = summarize_dataclass_runs( + "current-ws", + current_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "messages_per_second": lambda runs: statistics.median(run.messages_per_second for run in runs), + }, + ) + baseline = summarize_dataclass_runs( + f"baseline-{args.baseline_ref}-ws", + baseline_runs, + extra_fields={ + "total_time_s": lambda runs: statistics.median(run.total_time_s for run in runs), + "messages_per_second": lambda runs: statistics.median(run.messages_per_second for run in runs), + }, + ) + payload = { + "baseline_ref": args.baseline_ref, + "runs": args.runs, + "methodology": methodology_name(sequential=args.sequential), + **build_comparison_result( + current, + baseline, + throughput_field="messages_per_second", + throughput_delta_field="delta_messages_per_second", + throughput_improvement_field="improvement_messages_per_second_percent", + ), + } + write_json_output(payload, args.output_json) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Compare local Hypercorn against upstream in websocket echo benchmarks.") + parser.add_argument("--baseline-ref", default="upstream/main") + parser.add_argument("--baseline-path") + parser.add_argument("--no-fetch", action="store_true") + parser.add_argument("--tls", action="store_true") + parser.add_argument("--path", default="/ws") + parser.add_argument("--warmup-messages", type=int, default=50) + parser.add_argument("--measured-messages", type=int, default=300) + 
parser.add_argument("--payload-size", type=int, default=65536) + parser.add_argument("--runs", type=int, default=1) + parser.add_argument("--sequential", action="store_true", help="Run all current runs and then all baseline runs.") + parser.add_argument("--output-json") + return parser + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/fragmented_body.py b/benchmarks/fragmented_body.py new file mode 100644 index 00000000..9c7f6393 --- /dev/null +++ b/benchmarks/fragmented_body.py @@ -0,0 +1,158 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +from h2.connection import H2Connection +from h2.events import DataReceived, ResponseReceived, StreamEnded + +from benchmarks._runtime import PROJECT_ROOT, ServerProcess, build_ssl_context, percentile + + +@dataclass +class FragmentedBodyBenchmarkConfig: + warmup_iterations: int + measured_iterations: int + chunks: int + chunk_size: int + delay_ms: int + + +@dataclass +class FragmentedBodyBenchmarkResult: + target_label: str + server_repo: str + warmup_iterations: int + measured_iterations: int + chunks: int + chunk_size: int + delay_ms: int + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + result = await run_fragmented_body_benchmark( + server_repo=Path(args.server_repo).resolve(), + label=args.label, + config=FragmentedBodyBenchmarkConfig( + warmup_iterations=args.warmup_iterations, + measured_iterations=args.measured_iterations, + chunks=args.chunks, + chunk_size=args.chunk_size, + delay_ms=args.delay_ms, + ), + ) + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + 
return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser( + description="Run an HTTP/2 fragmented request-body benchmark against Hypercorn." + ) + parser.add_argument("--server-repo", default=str(PROJECT_ROOT)) + parser.add_argument("--label", default="local") + parser.add_argument("--warmup-iterations", type=int, default=5) + parser.add_argument("--measured-iterations", type=int, default=50) + parser.add_argument("--chunks", type=int, default=256) + parser.add_argument("--chunk-size", type=int, default=128) + parser.add_argument("--delay-ms", type=int, default=1) + parser.add_argument("--output-json") + return parser + + +async def run_fragmented_body_benchmark( + server_repo: Path, label: str, config: FragmentedBodyBenchmarkConfig +) -> FragmentedBodyBenchmarkResult: + samples_ms: list[float] = [] + async with ServerProcess(server_repo) as server: + for _ in range(config.warmup_iterations): + await run_fragmented_body_iteration(server.port, config) + for _ in range(config.measured_iterations): + samples_ms.append(await run_fragmented_body_iteration(server.port, config)) + + return FragmentedBodyBenchmarkResult( + target_label=label, + server_repo=str(server_repo), + warmup_iterations=config.warmup_iterations, + measured_iterations=config.measured_iterations, + chunks=config.chunks, + chunk_size=config.chunk_size, + delay_ms=config.delay_ms, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +async def run_fragmented_body_iteration( + port: int, config: FragmentedBodyBenchmarkConfig +) -> float: + reader, writer = await asyncio.open_connection( + "127.0.0.1", + port, + ssl=build_ssl_context(), + server_hostname="localhost", + ) + conn = H2Connection() + conn.initiate_connection() + writer.write(conn.data_to_send()) + await writer.drain() + + total_bytes = 
config.chunks * config.chunk_size + stream_id = 1 + conn.send_headers( + stream_id, + [ + (":method", "POST"), + (":scheme", "https"), + (":authority", "localhost"), + (":path", f"/slow-read?delay_ms={config.delay_ms}"), + ("content-length", str(total_bytes)), + ], + end_stream=False, + ) + chunk = b"x" * config.chunk_size + for index in range(config.chunks): + conn.send_data(stream_id, chunk, end_stream=index == (config.chunks - 1)) + + start = time.perf_counter() + writer.write(conn.data_to_send()) + await writer.drain() + + while True: + data = await asyncio.wait_for(reader.read(65535), timeout=10) + if data == b"": + raise RuntimeError("Fragmented body benchmark connection closed unexpectedly") + for event in conn.receive_data(data): + if isinstance(event, DataReceived): + conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + elif isinstance(event, ResponseReceived): + continue + elif isinstance(event, StreamEnded) and event.stream_id == stream_id: + writer.close() + await writer.wait_closed() + return (time.perf_counter() - start) * 1000 + + pending = conn.data_to_send() + if pending: + writer.write(pending) + await writer.drain() diff --git a/benchmarks/general.py b/benchmarks/general.py new file mode 100644 index 00000000..db279601 --- /dev/null +++ b/benchmarks/general.py @@ -0,0 +1,171 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import httpx + +from benchmarks._runtime import PROJECT_ROOT, ServerProcess, percentile + + +@dataclass +class GeneralBenchmarkConfig: + http_version: str + tls: bool + path: str + concurrency: int + total_requests: int + warmup_requests: int + + +@dataclass +class GeneralBenchmarkResult: + target_label: str + server_repo: str + http_version: str + tls: bool + path: str + concurrency: int + total_requests: int + warmup_requests: int + total_time_s: float + 
requests_per_second: float + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + result = await run_general_benchmark( + server_repo=Path(args.server_repo).resolve(), + label=args.label, + config=GeneralBenchmarkConfig( + http_version=args.http_version, + tls=(args.tls or args.http_version == "2"), + path=args.path, + concurrency=args.concurrency, + total_requests=args.total_requests, + warmup_requests=args.warmup_requests, + ), + ) + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a general Hypercorn benchmark over /fast.") + parser.add_argument("--server-repo", default=str(PROJECT_ROOT)) + parser.add_argument("--label", default="local") + parser.add_argument("--http-version", choices=["1.1", "2"], default="1.1") + parser.add_argument("--tls", action="store_true", help="Use TLS for HTTP/1.1 runs. 
HTTP/2 always uses TLS.") + parser.add_argument("--path", default="/fast") + parser.add_argument("--concurrency", type=int, default=50) + parser.add_argument("--total-requests", type=int, default=500) + parser.add_argument("--warmup-requests", type=int, default=50) + parser.add_argument("--output-json") + return parser + + +async def run_general_benchmark( + server_repo: Path, label: str, config: GeneralBenchmarkConfig +) -> GeneralBenchmarkResult: + if config.http_version == "2" and not config.tls: + raise ValueError("HTTP/2 benchmark requires TLS") + + async with ServerProcess(server_repo, tls=config.tls) as server: + await run_requests( + server.port, + config.http_version, + config.tls, + config.path, + min(config.warmup_requests, config.total_requests), + max(1, min(config.concurrency, config.warmup_requests or config.total_requests)), + ) + total_time_s, samples_ms = await run_requests( + server.port, + config.http_version, + config.tls, + config.path, + config.total_requests, + config.concurrency, + ) + + return GeneralBenchmarkResult( + target_label=label, + server_repo=str(server_repo), + http_version=config.http_version, + tls=config.tls, + path=config.path, + concurrency=config.concurrency, + total_requests=config.total_requests, + warmup_requests=config.warmup_requests, + total_time_s=total_time_s, + requests_per_second=(config.total_requests / total_time_s) if total_time_s > 0 else 0.0, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +async def run_requests( + port: int, http_version: str, tls: bool, path: str, total_requests: int, concurrency: int +) -> tuple[float, list[float]]: + if total_requests <= 0: + return 0.0, [] + + scheme = "https" if tls else "http" + url = f"{scheme}://127.0.0.1:{port}{path}" + queue: asyncio.Queue[int | None] = asyncio.Queue() + for index in 
range(total_requests): + queue.put_nowait(index) + for _ in range(concurrency): + queue.put_nowait(None) + + latencies_ms: list[float] = [] + lock = asyncio.Lock() + limits = httpx.Limits( + max_connections=1 if http_version == "2" else concurrency, + max_keepalive_connections=1 if http_version == "2" else concurrency, + ) + start = time.perf_counter() + async with httpx.AsyncClient( + http2=(http_version == "2"), + verify=False if tls else True, + timeout=10.0, + limits=limits, + ) as client: + async def worker() -> None: + while True: + item = await queue.get() + if item is None: + return + req_start = time.perf_counter() + response = await client.get(url) + response.raise_for_status() + elapsed_ms = (time.perf_counter() - req_start) * 1000 + async with lock: + latencies_ms.append(elapsed_ms) + + await asyncio.gather(*(worker() for _ in range(concurrency))) + total_time_s = time.perf_counter() - start + return total_time_s, latencies_ms + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/h3.py b/benchmarks/h3.py new file mode 100644 index 00000000..6a0c9f7c --- /dev/null +++ b/benchmarks/h3.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import os +import ssl +import statistics +import subprocess +import sys +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +from aioquic.asyncio.client import connect +from aioquic.asyncio.protocol import QuicConnectionProtocol +from aioquic.h3.connection import H3_ALPN, H3Connection +from aioquic.h3.events import DataReceived, HeadersReceived +from aioquic.quic.configuration import QuicConfiguration +from aioquic.quic.events import ConnectionTerminated, QuicEvent + +from benchmarks._runtime import CERTFILE, KEYFILE, PROJECT_ROOT, percentile, reserve_udp_port + + +@dataclass +class H3BenchmarkConfig: + path: str + warmup_requests: int + total_requests: int + concurrency: int + + +@dataclass +class 
H3BenchmarkResult: + target_label: str + server_repo: str + path: str + warmup_requests: int + total_requests: int + concurrency: int + total_time_s: float + requests_per_second: float + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + result = await run_h3_benchmark( + server_repo=Path(args.server_repo).resolve(), + label=args.label, + config=H3BenchmarkConfig( + path=args.path, + warmup_requests=args.warmup_requests, + total_requests=args.total_requests, + concurrency=args.concurrency, + ), + ) + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a real HTTP/3 benchmark over QUIC.") + parser.add_argument("--server-repo", default=str(PROJECT_ROOT)) + parser.add_argument("--label", default="local") + parser.add_argument("--path", default="/fast") + parser.add_argument("--concurrency", type=int, default=50) + parser.add_argument("--total-requests", type=int, default=500) + parser.add_argument("--warmup-requests", type=int, default=50) + parser.add_argument("--output-json") + return parser + + +async def run_h3_benchmark(server_repo: Path, label: str, config: H3BenchmarkConfig) -> H3BenchmarkResult: + async with QuicServerProcess(server_repo) as server: + await run_requests(server.port, "/ready", min(config.warmup_requests, config.total_requests), max(1, min(config.concurrency, config.warmup_requests or config.total_requests))) + total_time_s, samples_ms = await run_requests( + server.port, config.path, config.total_requests, config.concurrency + ) + + return H3BenchmarkResult( + target_label=label, + server_repo=str(server_repo), + path=config.path, + 
warmup_requests=config.warmup_requests, + total_requests=config.total_requests, + concurrency=config.concurrency, + total_time_s=total_time_s, + requests_per_second=(config.total_requests / total_time_s) if total_time_s > 0 else 0.0, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +class QuicServerProcess: + def __init__(self, server_repo: Path) -> None: + self.port = reserve_udp_port() + self.server_repo = server_repo + self.process: subprocess.Popen[str] | None = None + + async def __aenter__(self) -> QuicServerProcess: + env = os.environ.copy() + env["PYTHONPATH"] = os.pathsep.join([str(self.server_repo / "src"), str(PROJECT_ROOT)]) + self.process = subprocess.Popen( + [ + sys.executable, + "-m", + "hypercorn", + "--quic-bind", + f"127.0.0.1:{self.port}", + "--workers", + "1", + "--worker-class", + "asyncio", + "--certfile", + str(CERTFILE), + "--keyfile", + str(KEYFILE), + "benchmarks.app:app", + ], + cwd=PROJECT_ROOT, + env=env, + stdout=subprocess.DEVNULL, + stderr=subprocess.DEVNULL, + text=True, + ) + await wait_for_ready(self.port) + return self + + async def __aexit__(self, exc_type, exc, tb) -> None: + if self.process is None: + return + self.process.terminate() + try: + await asyncio.wait_for(asyncio.to_thread(self.process.wait), timeout=5) + except asyncio.TimeoutError: + self.process.kill() + await asyncio.to_thread(self.process.wait) + + +class H3ClientProtocol(QuicConnectionProtocol): + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + self._http = H3Connection(self._quic) + self._responses: dict[int, bytearray] = {} + self._waiters: dict[int, asyncio.Future[tuple[list[tuple[bytes, bytes]], bytes]]] = {} + self._headers: dict[int, list[tuple[bytes, bytes]]] = {} + + async def get(self, path: str) -> tuple[list[tuple[bytes, bytes]], bytes]: + stream_id = 
self._quic.get_next_available_stream_id() + waiter = asyncio.get_running_loop().create_future() + self._responses[stream_id] = bytearray() + self._waiters[stream_id] = waiter + self._headers[stream_id] = [] + self._http.send_headers( + stream_id, + [ + (b":method", b"GET"), + (b":scheme", b"https"), + (b":authority", b"localhost"), + (b":path", path.encode("ascii")), + ], + end_stream=True, + ) + self.transmit() + return await waiter + + def quic_event_received(self, event: QuicEvent) -> None: + if isinstance(event, ConnectionTerminated): + error = RuntimeError("HTTP/3 connection terminated") + for waiter in list(self._waiters.values()): + if not waiter.done(): + waiter.set_exception(error) + self._waiters.clear() + return + + for http_event in self._http.handle_event(event): + if isinstance(http_event, HeadersReceived): + self._headers[http_event.stream_id] = http_event.headers + if http_event.stream_ended: + self._finish(http_event.stream_id) + elif isinstance(http_event, DataReceived): + self._responses[http_event.stream_id].extend(http_event.data) + if http_event.stream_ended: + self._finish(http_event.stream_id) + + def _finish(self, stream_id: int) -> None: + waiter = self._waiters.pop(stream_id) + headers = self._headers.pop(stream_id) + body = bytes(self._responses.pop(stream_id)) + if not waiter.done(): + waiter.set_result((headers, body)) + + +async def wait_for_ready(port: int) -> None: + deadline = time.monotonic() + 10 + while time.monotonic() < deadline: + try: + headers, _ = await request_once(port, "/ready") + if _status_code(headers) == 200: + return + except Exception: + await asyncio.sleep(0.05) + continue + raise RuntimeError(f"Timed out waiting for HTTP/3 benchmark server on port {port}") + + +async def run_requests(port: int, path: str, total_requests: int, concurrency: int) -> tuple[float, list[float]]: + if total_requests <= 0: + return 0.0, [] + + queue: asyncio.Queue[int | None] = asyncio.Queue() + for index in range(total_requests): + 
queue.put_nowait(index) + for _ in range(concurrency): + queue.put_nowait(None) + + latencies_ms: list[float] = [] + lock = asyncio.Lock() + async with connect( + "127.0.0.1", + port, + configuration=_build_quic_config(), + create_protocol=H3ClientProtocol, + wait_connected=True, + ) as client: + client = client # typing hint + + async def worker() -> None: + while True: + item = await queue.get() + if item is None: + return + start = time.perf_counter() + headers, _ = await client.get(path) # type: ignore[attr-defined] + if _status_code(headers) != 200: + raise RuntimeError(f"Unexpected HTTP/3 status for {path}: {_status_code(headers)}") + elapsed_ms = (time.perf_counter() - start) * 1000 + async with lock: + latencies_ms.append(elapsed_ms) + + start = time.perf_counter() + await asyncio.gather(*(worker() for _ in range(concurrency))) + total_time_s = time.perf_counter() - start + return total_time_s, latencies_ms + + +async def request_once(port: int, path: str) -> tuple[list[tuple[bytes, bytes]], bytes]: + async with connect( + "127.0.0.1", + port, + configuration=_build_quic_config(), + create_protocol=H3ClientProtocol, + wait_connected=True, + ) as client: + return await client.get(path) # type: ignore[attr-defined] + + +def _build_quic_config() -> QuicConfiguration: + config = QuicConfiguration(alpn_protocols=H3_ALPN, is_client=True, server_name="localhost") + config.verify_mode = ssl.CERT_NONE + return config + + +def _status_code(headers: list[tuple[bytes, bytes]]) -> int: + for name, value in headers: + if name == b":status": + return int(value) + raise RuntimeError("Missing :status header") +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/run_load.py b/benchmarks/run_load.py new file mode 100644 index 00000000..3fd6f8a4 --- /dev/null +++ b/benchmarks/run_load.py @@ -0,0 +1,185 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses 
import asdict, dataclass +from pathlib import Path + +from h2.connection import H2Connection +from h2.events import DataReceived, ResponseReceived, StreamEnded + +from benchmarks._runtime import PROJECT_ROOT, ServerProcess, build_ssl_context, percentile + + +@dataclass +class BenchmarkConfig: + warmup_iterations: int + measured_iterations: int + fast_streams: int + slow_chunks: int + slow_chunk_size: int + slow_delay_ms: int + + +@dataclass +class BenchmarkResult: + target_label: str + server_repo: str + warmup_iterations: int + measured_iterations: int + fast_streams: int + slow_chunks: int + slow_chunk_size: int + slow_delay_ms: int + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + + config = BenchmarkConfig( + warmup_iterations=args.warmup_iterations, + measured_iterations=args.measured_iterations, + fast_streams=args.fast_streams, + slow_chunks=args.slow_chunks, + slow_chunk_size=args.slow_chunk_size, + slow_delay_ms=args.slow_delay_ms, + ) + result = await run_benchmark( + server_repo=Path(args.server_repo).resolve(), + label=args.label, + config=config, + ) + + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a small HTTP/2 HoL benchmark against Hypercorn.") + parser.add_argument("--server-repo", default=str(PROJECT_ROOT), help="Path to the Hypercorn repo to benchmark.") + parser.add_argument("--label", default="local", help="Label to include in the output.") + parser.add_argument("--warmup-iterations", type=int, default=2) + parser.add_argument("--measured-iterations", type=int, default=10) + parser.add_argument("--fast-streams", type=int, default=5) + 
parser.add_argument("--slow-chunks", type=int, default=12) + parser.add_argument("--slow-chunk-size", type=int, default=4096) + parser.add_argument("--slow-delay-ms", type=int, default=25) + parser.add_argument("--output-json", help="Optional path to write the result as JSON.") + return parser + + +async def run_benchmark(server_repo: Path, label: str, config: BenchmarkConfig) -> BenchmarkResult: + samples_ms: list[float] = [] + async with ServerProcess(server_repo) as server: + for _ in range(config.warmup_iterations): + await run_h2_hol_iteration(server.port, config) + for _ in range(config.measured_iterations): + samples_ms.extend(await run_h2_hol_iteration(server.port, config)) + + return BenchmarkResult( + target_label=label, + server_repo=str(server_repo), + warmup_iterations=config.warmup_iterations, + measured_iterations=config.measured_iterations, + fast_streams=config.fast_streams, + slow_chunks=config.slow_chunks, + slow_chunk_size=config.slow_chunk_size, + slow_delay_ms=config.slow_delay_ms, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +async def run_h2_hol_iteration(port: int, config: BenchmarkConfig) -> list[float]: + reader, writer = await asyncio.open_connection( + "127.0.0.1", + port, + ssl=build_ssl_context(), + server_hostname="localhost", + ) + conn = H2Connection() + conn.initiate_connection() + writer.write(conn.data_to_send()) + await writer.drain() + + slow_stream_id = 1 + fast_stream_ids = [3 + (index * 2) for index in range(config.fast_streams)] + send_hol_workload(conn, slow_stream_id, fast_stream_ids, config) + start_times = {stream_id: time.perf_counter() for stream_id in fast_stream_ids} + writer.write(conn.data_to_send()) + await writer.drain() + + latencies_ms: list[float] = [] + while len(latencies_ms) < len(fast_stream_ids): + data = await 
asyncio.wait_for(reader.read(65535), timeout=5) + if data == b"": + raise RuntimeError("Benchmark connection closed unexpectedly") + for event in conn.receive_data(data): + if isinstance(event, DataReceived): + conn.acknowledge_received_data(event.flow_controlled_length, event.stream_id) + elif isinstance(event, ResponseReceived): + if event.stream_id not in start_times: + continue + elif isinstance(event, StreamEnded) and event.stream_id in start_times: + latencies_ms.append((time.perf_counter() - start_times[event.stream_id]) * 1000) + pending = conn.data_to_send() + if pending: + writer.write(pending) + await writer.drain() + + writer.close() + await writer.wait_closed() + return latencies_ms + + +def send_hol_workload( + conn: H2Connection, slow_stream_id: int, fast_stream_ids: list[int], config: BenchmarkConfig +) -> None: + total_slow_bytes = config.slow_chunks * config.slow_chunk_size + conn.send_headers( + slow_stream_id, + [ + (":method", "POST"), + (":scheme", "https"), + (":authority", "localhost"), + (":path", f"/slow-read?delay_ms={config.slow_delay_ms}"), + ("content-length", str(total_slow_bytes)), + ], + end_stream=False, + ) + for _ in range(config.slow_chunks): + conn.send_data(slow_stream_id, b"x" * config.slow_chunk_size, end_stream=False) + + for stream_id in fast_stream_ids: + conn.send_headers( + stream_id, + [ + (":method", "POST"), + (":scheme", "https"), + (":authority", "localhost"), + (":path", "/fast"), + ("content-length", "2"), + ], + end_stream=False, + ) + conn.send_data(stream_id, b"ok", end_stream=True) + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/task_group.py b/benchmarks/task_group.py new file mode 100644 index 00000000..d4b361f9 --- /dev/null +++ b/benchmarks/task_group.py @@ -0,0 +1,140 @@ +from __future__ import annotations + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses import asdict, dataclass +from pathlib import Path + 
+from benchmarks._runtime import PROJECT_ROOT, percentile +from hypercorn.app_wrappers import ASGIWrapper, WSGIWrapper +from hypercorn.asyncio.task_group import TaskGroup +from hypercorn.config import Config +from tests.wsgi_applications import wsgi_app_simple + +HTTP_SCOPE = { + "type": "http", + "asgi": {"spec_version": "2.1", "version": "3.0"}, + "http_version": "1.1", + "method": "GET", + "scheme": "http", + "path": "/", + "raw_path": b"/", + "query_string": b"", + "root_path": "", + "headers": [(b"host", b"localhost")], + "client": ("127.0.0.1", 1234), + "server": ("127.0.0.1", 8000), + "state": {}, +} + + +@dataclass +class TaskGroupBenchmarkConfig: + mode: str + warmup_iterations: int + measured_iterations: int + + +@dataclass +class TaskGroupBenchmarkResult: + target_label: str + server_repo: str + mode: str + warmup_iterations: int + measured_iterations: int + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + result = await run_task_group_benchmark( + Path(args.server_repo).resolve(), + args.label, + TaskGroupBenchmarkConfig( + mode=args.mode, + warmup_iterations=args.warmup_iterations, + measured_iterations=args.measured_iterations, + ), + ) + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Measure TaskGroup.spawn_app() overhead.") + parser.add_argument("--server-repo", default=str(PROJECT_ROOT)) + parser.add_argument("--label", default="local") + parser.add_argument("--mode", choices=["asgi", "wsgi"], default="asgi") + parser.add_argument("--warmup-iterations", type=int, default=200) + parser.add_argument("--measured-iterations", type=int, default=2000) + 
parser.add_argument("--output-json") + return parser + + +async def run_task_group_benchmark( + server_repo: Path, + label: str, + config: TaskGroupBenchmarkConfig, +) -> TaskGroupBenchmarkResult: + loop = asyncio.get_running_loop() + benchmark = _TaskGroupBenchmark(config.mode) + samples_ms = [] + async with TaskGroup(loop) as task_group: + for _ in range(config.warmup_iterations): + await benchmark.run_iteration(task_group) + for _ in range(config.measured_iterations): + start = time.perf_counter() + await benchmark.run_iteration(task_group) + samples_ms.append((time.perf_counter() - start) * 1000) + + return TaskGroupBenchmarkResult( + target_label=label, + server_repo=str(server_repo), + mode=config.mode, + warmup_iterations=config.warmup_iterations, + measured_iterations=config.measured_iterations, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +class _TaskGroupBenchmark: + def __init__(self, mode: str) -> None: + self._config = Config() + if mode == "asgi": + self._app = ASGIWrapper(self._asgi_app) + self._message = {"type": "http.disconnect"} + else: + self._app = WSGIWrapper(wsgi_app_simple, 2**16) + self._message = {"type": "http.request", "body": b"", "more_body": False} + + async def run_iteration(self, task_group: TaskGroup) -> None: + send_queue: asyncio.Queue = asyncio.Queue() + app_put = await task_group.spawn_app(self._app, self._config, HTTP_SCOPE, send_queue.put) + await app_put(self._message) + while True: + message = await send_queue.get() + if message is None: + return + + async def _asgi_app(self, scope, receive, send) -> None: + await receive() + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/benchmarks/ws.py b/benchmarks/ws.py new file mode 100644 index 00000000..17c4d8f0 --- /dev/null +++ b/benchmarks/ws.py @@ -0,0 +1,162 @@ +from __future__ 
import annotations + +import argparse +import asyncio +import json +import statistics +import time +from dataclasses import asdict, dataclass +from pathlib import Path + +import wsproto +import wsproto.events + +from benchmarks._runtime import PROJECT_ROOT, ServerProcess, build_ssl_context, percentile + + +@dataclass +class WebsocketBenchmarkConfig: + tls: bool + path: str + warmup_messages: int + measured_messages: int + payload_size: int + + +@dataclass +class WebsocketBenchmarkResult: + target_label: str + server_repo: str + tls: bool + path: str + warmup_messages: int + measured_messages: int + payload_size: int + total_time_s: float + messages_per_second: float + samples_ms: list[float] + mean_ms: float + median_ms: float + p95_ms: float + minimum_ms: float + maximum_ms: float + + +async def main() -> int: + parser = build_parser() + args = parser.parse_args() + result = await run_ws_benchmark( + server_repo=Path(args.server_repo).resolve(), + label=args.label, + config=WebsocketBenchmarkConfig( + tls=args.tls, + path=args.path, + warmup_messages=args.warmup_messages, + measured_messages=args.measured_messages, + payload_size=args.payload_size, + ), + ) + payload = asdict(result) + if args.output_json is not None: + Path(args.output_json).write_text(json.dumps(payload, indent=2) + "\n") + print(json.dumps(payload, indent=2)) + return 0 + + +def build_parser() -> argparse.ArgumentParser: + parser = argparse.ArgumentParser(description="Run a small websocket echo benchmark against Hypercorn.") + parser.add_argument("--server-repo", default=str(PROJECT_ROOT)) + parser.add_argument("--label", default="local") + parser.add_argument("--tls", action="store_true") + parser.add_argument("--path", default="/ws") + parser.add_argument("--warmup-messages", type=int, default=50) + parser.add_argument("--measured-messages", type=int, default=300) + parser.add_argument("--payload-size", type=int, default=65536) + parser.add_argument("--output-json") + return parser + + +async 
def run_ws_benchmark( + server_repo: Path, label: str, config: WebsocketBenchmarkConfig +) -> WebsocketBenchmarkResult: + async with ServerProcess(server_repo, tls=config.tls) as server: + reader, writer = await asyncio.open_connection( + "127.0.0.1", + server.port, + ssl=build_ssl_context(alpn_protocols=["http/1.1"]) if config.tls else None, + server_hostname="localhost" if config.tls else None, + ) + client = wsproto.WSConnection(wsproto.ConnectionType.CLIENT) + writer.write(client.send(wsproto.events.Request(host="localhost", target=config.path))) + await writer.drain() + await _wait_for_accept(client, reader) + + payload = b"x" * config.payload_size + for _ in range(config.warmup_messages): + await _round_trip(client, reader, writer, payload) + + start = time.perf_counter() + samples_ms = [] + for _ in range(config.measured_messages): + samples_ms.append(await _round_trip(client, reader, writer, payload)) + total_time_s = time.perf_counter() - start + + writer.write(client.send(wsproto.events.CloseConnection(code=1000))) + await writer.drain() + writer.close() + await writer.wait_closed() + + return WebsocketBenchmarkResult( + target_label=label, + server_repo=str(server_repo), + tls=config.tls, + path=config.path, + warmup_messages=config.warmup_messages, + measured_messages=config.measured_messages, + payload_size=config.payload_size, + total_time_s=total_time_s, + messages_per_second=(config.measured_messages / total_time_s) if total_time_s > 0 else 0.0, + samples_ms=samples_ms, + mean_ms=statistics.fmean(samples_ms), + median_ms=statistics.median(samples_ms), + p95_ms=percentile(samples_ms, 0.95), + minimum_ms=min(samples_ms), + maximum_ms=max(samples_ms), + ) + + +async def _wait_for_accept(client: wsproto.WSConnection, reader: asyncio.StreamReader) -> None: + while True: + data = await asyncio.wait_for(reader.read(65535), timeout=5) + if data == b"": + raise RuntimeError("Websocket benchmark connection closed during handshake") + client.receive_data(data) 
+ for event in client.events(): + if isinstance(event, wsproto.events.AcceptConnection): + return + if isinstance(event, wsproto.events.RejectConnection): + raise RuntimeError("Websocket benchmark connection was rejected") + + +async def _round_trip( + client: wsproto.WSConnection, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + payload: bytes, +) -> float: + start = time.perf_counter() + writer.write(client.send(wsproto.events.BytesMessage(data=payload))) + await writer.drain() + + while True: + data = await asyncio.wait_for(reader.read(65535), timeout=5) + if data == b"": + raise RuntimeError("Websocket benchmark connection closed unexpectedly") + client.receive_data(data) + for event in client.events(): + if isinstance(event, wsproto.events.BytesMessage): + return (time.perf_counter() - start) * 1000 + + +if __name__ == "__main__": + raise SystemExit(asyncio.run(main())) diff --git a/pyproject.toml b/pyproject.toml index b72f468f..141c2614 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -48,7 +48,6 @@ uvloop = ["uvloop"] dev = [ "httpx", "hypothesis", - "mock", "pytest", "pytest-asyncio", "pytest-trio", diff --git a/src/hypercorn/__main__.py b/src/hypercorn/__main__.py index b3980941..f710a2d9 100644 --- a/src/hypercorn/__main__.py +++ b/src/hypercorn/__main__.py @@ -276,7 +276,7 @@ def _convert_verify_mode(value: str) -> ssl.VerifyMode: if args.max_requests is not sentinel: config.max_requests = args.max_requests if args.max_requests_jitter is not sentinel: - config.max_requests_jitter = args.max_requests + config.max_requests_jitter = args.max_requests_jitter if args.pid is not sentinel: config.pid_path = args.pid if args.root_path is not sentinel: diff --git a/src/hypercorn/app_wrappers.py b/src/hypercorn/app_wrappers.py index 2f856e90..13d2d8ce 100644 --- a/src/hypercorn/app_wrappers.py +++ b/src/hypercorn/app_wrappers.py @@ -84,8 +84,9 @@ async def handle_http( await send({"type": "http.response.body", "body": b"", "more_body": False}) 
def run_app(self, environ: dict, send: Callable) -> None: - headers: list[tuple[bytes, bytes]] + headers: list[tuple[bytes, bytes]] = [] response_started = False + headers_sent = False status_code: int | None = None def start_response( @@ -93,7 +94,21 @@ def start_response( response_headers: list[tuple[str, str]], exc_info: Exception | None = None, ) -> None: - nonlocal headers, response_started, status_code + nonlocal headers, response_started, status_code, headers_sent + + if response_started and exc_info is None: + raise RuntimeError( + "start_response cannot be called again without the exc_info parameter" + ) + elif exc_info is not None: + try: + if headers_sent: + # The headers have already been sent and we can no longer change + # the status_code and headers. reraise this exception in accordance + # with the WSGI specification. + raise exc_info[1].with_traceback(exc_info[2]) + finally: + exc_info = None # Delete reference to exc_info to avoid circular references raw, _ = status.split(" ", 1) status_code = int(raw) @@ -106,16 +121,35 @@ def start_response( response_body = self.app(environ, start_response) try: - first_chunk = True for output in response_body: - if first_chunk: + # Per the WSGI specification in PEP-3333, the start_response callable + # must not actually transmit the response headers. Instead, it must + # store them for the server to transmit only after the first iteration + # of the application return value that yields a non-empty bytestring. + # + # We therefore delay sending the http.response.start event until after + # we receive a non-empty byte string from the application return value. + if output and not headers_sent: if not response_started: raise RuntimeError("WSGI app did not call start_response") + # Send the http.response.start event with the status and headers, flagging + # that this was completed so they aren't sent twice. 
send({"type": "http.response.start", "status": status_code, "headers": headers}) - first_chunk = False + headers_sent = True send({"type": "http.response.body", "body": output, "more_body": True}) + + # If we still haven't sent the headers by this point, then we received no + # non-empty byte strings from the application return value. This can happen when + # handling certain HTTP methods that don't include a response body like HEAD. + # In those cases we still need to send the http.response.start event with the + # status code and headers, but we need to ensure they haven't been sent previously. + if not headers_sent: + if not response_started: + raise RuntimeError("WSGI app did not call start_response") + + send({"type": "http.response.start", "status": status_code, "headers": headers}) finally: if hasattr(response_body, "close"): response_body.close() diff --git a/src/hypercorn/asyncio/tcp_server.py b/src/hypercorn/asyncio/tcp_server.py index 5612214c..7f88ca30 100644 --- a/src/hypercorn/asyncio/tcp_server.py +++ b/src/hypercorn/asyncio/tcp_server.py @@ -95,7 +95,10 @@ async def protocol_send(self, event: Event) -> None: async def _read_data(self) -> None: while not self.reader.at_eof(): try: - data = await asyncio.wait_for(self.reader.read(MAX_RECV), self.config.read_timeout) + if self.config.read_timeout is None: + data = await self.reader.read(MAX_RECV) + else: + data = await asyncio.wait_for(self.reader.read(MAX_RECV), self.config.read_timeout) except ( ConnectionError, OSError, @@ -124,7 +127,12 @@ async def _close(self) -> None: ConnectionResetError, RuntimeError, asyncio.CancelledError, - ): + TimeoutError, + ) as exc: + if isinstance(exc, TimeoutError): + transport = getattr(self.writer, "transport", None) + if transport is not None: + transport.abort() pass # Already closed finally: await self.idle_task.stop() diff --git a/src/hypercorn/asyncio/udp_server.py b/src/hypercorn/asyncio/udp_server.py index 2ef6f94a..6f328bce 100644 --- 
a/src/hypercorn/asyncio/udp_server.py +++ b/src/hypercorn/asyncio/udp_server.py @@ -29,7 +29,7 @@ def __init__( self.context = context self.loop = loop self.protocol: QuicProtocol - self.protocol_queue: asyncio.Queue = asyncio.Queue(10) + self.protocol_queue: asyncio.Queue = asyncio.Queue(config.quic_receive_queue_size) self.transport: asyncio.DatagramTransport | None = None self.state = state diff --git a/src/hypercorn/config.py b/src/hypercorn/config.py index 2911aff1..b451c1f0 100644 --- a/src/hypercorn/config.py +++ b/src/hypercorn/config.py @@ -94,9 +94,11 @@ class Config: logger_class = Logger loglevel: str = "INFO" max_app_queue_size: int = 10 + max_app_queue_bytes: int = 1024 * 1024 * BYTES max_requests: int | None = None max_requests_jitter: int = 0 pid_path: str | None = None + quic_receive_queue_size: int = 128 server_names: list[str] = [] shutdown_timeout = 60 * SECONDS ssl_handshake_timeout = 60 * SECONDS @@ -113,6 +115,10 @@ class Config: worker_class = "asyncio" workers = 1 wsgi_max_body_size = 16 * 1024 * 1024 * BYTES + _date_header_cache: tuple[int, tuple[bytes, bytes]] | None = None + _response_headers_cache: dict[ + tuple[str, bool, tuple[str, ...], tuple[tuple, ...]], tuple[tuple[bytes, bytes], ...] 
+ ] | None = None def set_cert_reqs(self, value: int) -> None: warnings.warn("Please use verify_mode instead", Warning) @@ -285,20 +291,49 @@ def _create_sockets( def response_headers(self, protocol: str) -> list[tuple[bytes, bytes]]: headers = [] if self.include_date_header: - headers.append((b"date", format_date_time(time()).encode("ascii"))) - if self.include_server_header: - headers.append((b"server", f"hypercorn-{protocol}".encode("ascii"))) - - for alt_svc_header in self.alt_svc_headers: - headers.append((b"alt-svc", alt_svc_header.encode())) - if len(self.alt_svc_headers) == 0 and self._quic_addresses: - from aioquic.h3.connection import H3_ALPN - - for version in H3_ALPN: - for addr in self._quic_addresses: - port = addr[1] - headers.append((b"alt-svc", b'%s=":%d"; ma=3600' % (version.encode(), port))) + headers.append(self._cached_date_header_value()) + headers.extend(self._cached_static_response_headers(protocol)) + return headers + def _cached_date_header_value(self) -> tuple[bytes, bytes]: + current_second = int(time()) + if self._date_header_cache is None or self._date_header_cache[0] != current_second: + self._date_header_cache = ( + current_second, + (b"date", format_date_time(current_second).encode("ascii")), + ) + return self._date_header_cache[1] + + def _cached_static_response_headers(self, protocol: str) -> tuple[tuple[bytes, bytes], ...]: + cache_key = ( + protocol, + self.include_server_header, + tuple(self.alt_svc_headers), + tuple(self._quic_addresses), + ) + if self._response_headers_cache is None: + self._response_headers_cache = {} + + headers = self._response_headers_cache.get(cache_key) + if headers is None: + headers_list = [] + if self.include_server_header: + headers_list.append((b"server", f"hypercorn-{protocol}".encode("ascii"))) + + for alt_svc_header in self.alt_svc_headers: + headers_list.append((b"alt-svc", alt_svc_header.encode())) + if len(self.alt_svc_headers) == 0 and self._quic_addresses: + from aioquic.h3.connection 
import H3_ALPN + + for version in H3_ALPN: + for addr in self._quic_addresses: + port = addr[1] + headers_list.append( + (b"alt-svc", b'%s=":%d"; ma=3600' % (version.encode(), port)) + ) + + headers = tuple(headers_list) + self._response_headers_cache[cache_key] = headers return headers def set_statsd_logger_class(self, statsd_logger: type[Logger]) -> None: diff --git a/src/hypercorn/events.py b/src/hypercorn/events.py index 2fa91c36..6477bd8a 100644 --- a/src/hypercorn/events.py +++ b/src/hypercorn/events.py @@ -8,17 +8,17 @@ class Event(ABC): pass -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class RawData(Event): data: bytes address: tuple[str, int] | None = None -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Closed(Event): pass -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Updated(Event): idle: bool diff --git a/src/hypercorn/logging.py b/src/hypercorn/logging.py index 04f84ab4..70b0647f 100644 --- a/src/hypercorn/logging.py +++ b/src/hypercorn/logging.py @@ -1,6 +1,7 @@ from __future__ import annotations import json +import re import logging import os import sys @@ -53,6 +54,7 @@ def _create_logger( class Logger: def __init__(self, config: Config) -> None: self.access_log_format = config.access_log_format + self.access_log_atoms = frozenset(re.findall(r"%\(([^)]+)\)s", self.access_log_format)) self.access_logger = _create_logger( "hypercorn.access", @@ -126,78 +128,158 @@ def atoms( This can be overidden and customised if desired. It should return a mapping between an access log format key and a value. 
""" - return AccessLogAtoms(request, response, request_time) + return AccessLogAtoms(request, response, request_time, self.access_log_atoms) def __getattr__(self, name: str) -> Any: return getattr(self.error_logger, name) class AccessLogAtoms(dict): + _HEADER_ATOM_RE = re.compile(r"^\{(.+)\}([ioe])$") + def __init__( - self, request: WWWScope, response: ResponseSummary | None, request_time: float + self, + request: WWWScope, + response: ResponseSummary | None, + request_time: float, + required_atoms: frozenset[str] | None = None, ) -> None: - for name, value in request["headers"]: - self[f"{{{name.decode('latin1').lower()}}}i"] = value.decode("latin1") - for name, value in os.environ.items(): - self[f"{{{name.lower()}}}e"] = value - protocol = request.get("http_version", "ws") - client = request.get("client") - if client is None: - remote_addr = None - elif len(client) == 2: - remote_addr = f"{client[0]}:{client[1]}" - elif len(client) == 1: - remote_addr = client[0] - else: # make sure not to throw UnboundLocalError - remote_addr = f"" - if request["type"] == "http": - method = request["method"] - else: - method = "GET" - query_string = request["query_string"].decode() - path_with_qs = request["path"] + ("?" 
+ query_string if query_string else "") - - status_code = "-" - status_phrase = "-" - if response is not None: - for name, value in response.get("headers", []): # type: ignore - self[f"{{{name.decode('latin1').lower()}}}o"] = value.decode("latin1") # type: ignore # noqa: E501 - status_code = str(response["status"]) - try: - status_phrase = HTTPStatus(response["status"]).phrase - except ValueError: - status_phrase = f"" - self.update( - { - "h": remote_addr, - "l": "-", - "t": time.strftime("[%d/%b/%Y:%H:%M:%S %z]"), - "r": f"{method} {request['path']} {protocol}", - "R": f"{method} {path_with_qs} {protocol}", - "s": status_code, - "st": status_phrase, - "S": request["scheme"], - "m": method, - "U": request["path"], - "Uq": path_with_qs, - "q": query_string, - "H": protocol, - "b": self["{Content-Length}o"], - "B": self["{Content-Length}o"], - "f": self["{Referer}i"], - "a": self["{User-Agent}i"], - "T": int(request_time), - "D": int(request_time * 1_000_000), - "L": f"{request_time:.6f}", - "p": f"<{os.getpid()}>", - } - ) + self._request = request + self._response = response + self._request_time = request_time + self._request_header_cache: dict[str, str] | None = None + self._response_header_cache: dict[str, str] | None = None + self._environ_cache: dict[str, str] | None = None + self._query_string: str | None = None + self._path_with_qs: str | None = None + self._status_code: str | None = None + self._status_phrase: str | None = None + if required_atoms is not None: + for atom in required_atoms: + value = self._compute(atom) + if value != "-": + self[atom.lower() if atom.startswith("{") else atom] = value def __getitem__(self, key: str) -> str: + normalized_key = key.lower() if key.startswith("{") else key try: - if key.startswith("{"): - return super().__getitem__(key.lower()) - else: - return super().__getitem__(key) + return super().__getitem__(normalized_key) except KeyError: + value = self._compute(normalized_key) + if value == "-": + return "-" + 
self[normalized_key] = value + return value + + def _compute(self, key: str) -> str: + if key == "h": + client = self._request.get("client") + if client is None: + return "-" + if len(client) == 2: + return f"{client[0]}:{client[1]}" + if len(client) == 1: + return client[0] + return f"" + if key == "l": + return "-" + if key == "t": + return time.strftime("[%d/%b/%Y:%H:%M:%S %z]") + if key == "s": + return self._get_status_code() + if key == "st": + return self._get_status_phrase() + if key == "S": + return self._request["scheme"] + if key == "m": + return self._request["method"] if self._request["type"] == "http" else "GET" + if key == "U": + return self._request["path"] + if key == "q": + return self._get_query_string() + if key == "H": + return self._request.get("http_version", "ws") + if key == "r": + return f"{self['m']} {self._request['path']} {self['H']}" + if key == "R": + return f"{self['m']} {self._get_path_with_qs()} {self['H']}" + if key == "Uq": + return self._get_path_with_qs() + if key == "b" or key == "B": + return self["{Content-Length}o"] + if key == "f": + return self["{Referer}i"] + if key == "a": + return self["{User-Agent}i"] + if key == "T": + return str(int(self._request_time)) + if key == "D": + return str(int(self._request_time * 1_000_000)) + if key == "L": + return f"{self._request_time:.6f}" + if key == "p": + return f"<{os.getpid()}>" + + match = self._HEADER_ATOM_RE.match(key) + if match is None: return "-" + + header_name, atom_type = match.groups() + if atom_type == "i": + return self._get_request_header(header_name) + if atom_type == "o": + return self._get_response_header(header_name) + if atom_type == "e": + return self._get_environ(header_name) + return "-" + + def _get_query_string(self) -> str: + if self._query_string is None: + self._query_string = self._request["query_string"].decode() + return self._query_string + + def _get_path_with_qs(self) -> str: + if self._path_with_qs is None: + query_string = 
self._get_query_string() + self._path_with_qs = self._request["path"] + ("?" + query_string if query_string else "") + return self._path_with_qs + + def _get_status_code(self) -> str: + if self._status_code is None: + self._status_code = "-" if self._response is None else str(self._response["status"]) + return self._status_code + + def _get_status_phrase(self) -> str: + if self._status_phrase is None: + status_code = self._get_status_code() + if status_code == "-": + self._status_phrase = "-" + else: + try: + self._status_phrase = HTTPStatus(int(status_code)).phrase + except ValueError: + self._status_phrase = f"" + return self._status_phrase + + def _get_request_header(self, name: str) -> str: + if self._request_header_cache is None: + self._request_header_cache = { + header_name.decode("latin1").lower(): value.decode("latin1") + for header_name, value in self._request["headers"] + } + return self._request_header_cache.get(name, "-") + + def _get_response_header(self, name: str) -> str: + if self._response is None: + return "-" + if self._response_header_cache is None: + self._response_header_cache = { + header_name.decode("latin1").lower(): value.decode("latin1") + for header_name, value in self._response.get("headers", []) # type: ignore[arg-type] + } + return self._response_header_cache.get(name, "-") + + def _get_environ(self, name: str) -> str: + if self._environ_cache is None: + self._environ_cache = {env_name.lower(): value for env_name, value in os.environ.items()} + return self._environ_cache.get(name, "-") diff --git a/src/hypercorn/middleware/dispatcher.py b/src/hypercorn/middleware/dispatcher.py index a070f21e..8b917282 100644 --- a/src/hypercorn/middleware/dispatcher.py +++ b/src/hypercorn/middleware/dispatcher.py @@ -10,6 +10,15 @@ MAX_QUEUE_SIZE = 10 +def _path_matches(path: str, mount_path: str) -> bool: + if path == mount_path: + return True + if mount_path == "/": + return path.startswith("/") + + return 
path.startswith(f"{mount_path.rstrip('/')}/") + + class _DispatcherMiddleware: def __init__(self, mounts: dict[str, ASGIFramework]) -> None: self.mounts = mounts @@ -19,7 +28,7 @@ async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> Non await self._handle_lifespan(scope, receive, send) else: for path, app in self.mounts.items(): - if scope["path"].startswith(path): + if _path_matches(scope["path"], path): local_scope = scope.copy() local_scope["root_path"] += path return await app(local_scope, receive, send) diff --git a/src/hypercorn/protocol/__init__.py b/src/hypercorn/protocol/__init__.py index 98f62dcc..97d38cd3 100644 --- a/src/hypercorn/protocol/__init__.py +++ b/src/hypercorn/protocol/__init__.py @@ -34,17 +34,7 @@ def __init__( self.state = state self.protocol: H11Protocol | H2Protocol if alpn_protocol == "h2": - self.protocol = H2Protocol( - self.app, - self.config, - self.context, - self.task_group, - self.state, - self.ssl, - self.client, - self.server, - self.send, - ) + self.protocol = self._create_h2() else: self.protocol = H11Protocol( self.app, @@ -65,32 +55,25 @@ async def handle(self, event: Event) -> None: try: return await self.protocol.handle(event) except H2ProtocolAssumedError as error: - self.protocol = H2Protocol( - self.app, - self.config, - self.context, - self.task_group, - self.state, - self.ssl, - self.client, - self.server, - self.send, - ) + self.protocol = self._create_h2() await self.protocol.initiate() if error.data != b"": return await self.protocol.handle(RawData(data=error.data)) except H2CProtocolRequiredError as error: - self.protocol = H2Protocol( - self.app, - self.config, - self.context, - self.task_group, - self.state, - self.ssl, - self.client, - self.server, - self.send, - ) + self.protocol = self._create_h2() await self.protocol.initiate(error.headers, error.settings) if error.data != b"": return await self.protocol.handle(RawData(data=error.data)) + + def _create_h2(self) -> H2Protocol: + 
return H2Protocol( + self.app, + self.config, + self.context, + self.task_group, + self.state, + self.ssl, + self.client, + self.server, + self.send, + ) diff --git a/src/hypercorn/protocol/events.py b/src/hypercorn/protocol/events.py index 71fb6727..3ed1b8c9 100644 --- a/src/hypercorn/protocol/events.py +++ b/src/hypercorn/protocol/events.py @@ -5,12 +5,12 @@ from hypercorn.typing import ConnectionState -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Event: stream_id: int -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Request(Event): headers: list[tuple[bytes, bytes]] http_version: str @@ -19,38 +19,38 @@ class Request(Event): state: ConnectionState -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Body(Event): data: bytes -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class EndBody(Event): pass -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Trailers(Event): headers: list[tuple[bytes, bytes]] -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Data(Event): data: bytes -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class EndData(Event): pass -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class Response(Event): headers: list[tuple[bytes, bytes]] status_code: int -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class InformationalResponse(Event): headers: list[tuple[bytes, bytes]] status_code: int @@ -60,6 +60,6 @@ def __post_init__(self) -> None: raise ValueError(f"Status code must be 1XX not {self.status_code}") -@dataclass(frozen=True) +@dataclass(frozen=True, slots=True) class StreamClosed(Event): pass diff --git a/src/hypercorn/protocol/h2.py b/src/hypercorn/protocol/h2.py index d50e0f82..cec72bf2 100644 --- a/src/hypercorn/protocol/h2.py +++ b/src/hypercorn/protocol/h2.py @@ -6,8 +6,6 @@ import h2.connection import h2.events import h2.exceptions -import priority - from .events import ( Body, Data, @@ -21,62 
+19,14 @@ Trailers, ) from .http_stream import HTTPStream +from .h2_send import BUFFER_HIGH_WATER, BufferCompleteError, H2SendScheduler, StreamBuffer +from .queued_stream import QueuedStream from .ws_stream import WSStream from ..config import Config from ..events import Closed, Event, RawData, Updated -from ..typing import AppWrapper, ConnectionState, Event as IOEvent, TaskGroup, WorkerContext +from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext from ..utils import filter_pseudo_headers -BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth) -BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2 - - -class BufferCompleteError(Exception): - pass - - -class StreamBuffer: - def __init__(self, event_class: type[IOEvent]) -> None: - self.buffer = bytearray() - self._complete = False - self._is_empty = event_class() - self._paused = event_class() - - async def drain(self) -> None: - await self._is_empty.wait() - - def set_complete(self) -> None: - self._complete = True - - async def close(self) -> None: - self._complete = True - self.buffer = bytearray() - await self._is_empty.set() - await self._paused.set() - - @property - def complete(self) -> bool: - return self._complete and len(self.buffer) == 0 - - async def push(self, data: bytes) -> None: - if self._complete: - raise BufferCompleteError() - self.buffer.extend(data) - await self._is_empty.clear() - if len(self.buffer) >= BUFFER_HIGH_WATER: - await self._paused.wait() - await self._paused.clear() - - async def pop(self, max_length: int) -> bytes: - length = min(len(self.buffer), max_length) - data = bytes(self.buffer[:length]) - del self.buffer[:length] - if len(data) < BUFFER_LOW_WATER: - await self._paused.set() - if len(self.buffer) == 0: - await self._is_empty.set() - return data - class H2Protocol: def __init__( @@ -116,15 +66,14 @@ def __init__( self.send = send self.server = server self.ssl = ssl - self.streams: dict[int, HTTPStream | WSStream] = {} - # The below are 
used by the sending task - self.has_data = self.context.event_class() - self.priority = priority.PriorityTree() - self.stream_buffers: dict[int, StreamBuffer] = {} + self.streams: dict[int, QueuedStream] = {} + self.sender = H2SendScheduler( + self.connection, self.context.event_class, self._flush + ) @property def idle(self) -> bool: - return len(self.streams) == 0 or all(stream.idle for stream in self.streams.values()) + return len(self.streams) == 0 async def initiate( self, headers: list[tuple[bytes, bytes]] | None = None, settings: bytes | None = None @@ -138,46 +87,7 @@ async def initiate( event = h2.events.RequestReceived(stream_id=1, headers=headers) await self._create_stream(event) await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) - self.task_group.spawn(self.send_task) - - async def send_task(self) -> None: - # This should be run in a separate task to the rest of this - # class. This allows it separately choose when to send, - # crucially in what order. - while not self.closed: - try: - stream_id = next(self.priority) - except priority.DeadlockError: - await self.has_data.wait() - await self.has_data.clear() - else: - await self._send_data(stream_id) - - async def _send_data(self, stream_id: int) -> None: - try: - chunk_size = min( - self.connection.local_flow_control_window(stream_id), - self.connection.max_outbound_frame_size, - ) - chunk_size = max(0, chunk_size) - data = await self.stream_buffers[stream_id].pop(chunk_size) - if data: - self.connection.send_data(stream_id, data) - await self._flush() - else: - self.priority.block(stream_id) - - if self.stream_buffers[stream_id].complete: - self.connection.end_stream(stream_id) - await self._flush() - del self.stream_buffers[stream_id] - self.priority.remove_stream(stream_id) - except (h2.exceptions.StreamClosedError, KeyError, h2.exceptions.ProtocolError): - # Stream or connection has closed whilst waiting to send - # data, not a problem - just force close it. 
- await self.stream_buffers[stream_id].close() - del self.stream_buffers[stream_id] - self.priority.remove_stream(stream_id) + self.task_group.spawn(self.sender.run, lambda: self.closed) async def handle(self, event: Event) -> None: if isinstance(event, RawData): @@ -193,7 +103,7 @@ async def handle(self, event: Event) -> None: stream_ids = list(self.streams.keys()) for stream_id in stream_ids: await self._close_stream(stream_id) - await self.has_data.set() + await self.sender.wake() async def stream_send(self, event: StreamEvent) -> None: try: @@ -206,25 +116,14 @@ async def stream_send(self, event: StreamEvent) -> None: ) await self._flush() elif isinstance(event, (Body, Data)): - self.priority.unblock(event.stream_id) - await self.has_data.set() - await self.stream_buffers[event.stream_id].push(event.data) + await self.sender.buffer(event.stream_id, event.data) elif isinstance(event, (EndBody, EndData)): - self.stream_buffers[event.stream_id].set_complete() - self.priority.unblock(event.stream_id) - await self.has_data.set() - await self.stream_buffers[event.stream_id].drain() + await self.sender.end(event.stream_id) elif isinstance(event, Trailers): - self.priority.unblock(event.stream_id) - await self.has_data.set() - await self.stream_buffers[event.stream_id].drain() - self.connection.send_headers(event.stream_id, event.headers, end_stream=True) - await self._flush() + await self.sender.trailers(event.stream_id, event.headers) elif isinstance(event, StreamClosed): await self._close_stream(event.stream_id) - idle = len(self.streams) == 0 or all( - stream.idle for stream in self.streams.values() - ) + idle = len(self.streams) == 0 if idle and self.context.terminated.is_set(): self.connection.close_connection() await self._flush() @@ -234,7 +133,6 @@ async def stream_send(self, event: StreamEvent) -> None: except ( BufferCompleteError, KeyError, - priority.MissingStreamError, h2.exceptions.ProtocolError, ): # Connection has closed whilst blocked on flow control 
or @@ -256,12 +154,16 @@ async def _handle_events(self, events: list[h2.events.Event]) -> None: if self.keep_alive_requests > self.config.keep_alive_max_requests: self.connection.close_connection() elif isinstance(event, h2.events.DataReceived): - await self.streams[event.stream_id].handle( - Body(stream_id=event.stream_id, data=event.data) - ) - self.connection.acknowledge_received_data( - event.flow_controlled_length, event.stream_id - ) + try: + await self.streams[event.stream_id].handle( + Body(stream_id=event.stream_id, data=event.data), + lambda length=event.flow_controlled_length, stream_id=event.stream_id: self._acknowledge_data( + length, stream_id + ), + ) + except KeyError: + # Data received while already closed, nothing to do. + pass elif isinstance(event, h2.events.StreamEnded): try: await self.streams[event.stream_id].handle(EndBody(stream_id=event.stream_id)) @@ -271,14 +173,14 @@ async def _handle_events(self, events: list[h2.events.Event]) -> None: pass elif isinstance(event, h2.events.StreamReset): await self._close_stream(event.stream_id) - await self._window_updated(event.stream_id) + await self.sender.window_updated(event.stream_id) elif isinstance(event, h2.events.WindowUpdated): - await self._window_updated(event.stream_id) + await self.sender.window_updated(event.stream_id) elif isinstance(event, h2.events.PriorityUpdated): - await self._priority_updated(event) + await self.sender.priority_updated(event) elif isinstance(event, h2.events.RemoteSettingsChanged): if h2.settings.SettingCodes.INITIAL_WINDOW_SIZE in event.changed_settings: - await self._window_updated(None) + await self.sender.window_updated(None) elif isinstance(event, h2.events.ConnectionTerminated): await self.send(Closed()) await self._flush() @@ -288,34 +190,6 @@ async def _flush(self) -> None: if data != b"": await self.send(RawData(data=data)) - async def _window_updated(self, stream_id: int | None) -> None: - if stream_id is None or stream_id == 0: - # Unblock all streams 
- for stream_id in list(self.stream_buffers.keys()): - self.priority.unblock(stream_id) - elif stream_id is not None and stream_id in self.stream_buffers: - self.priority.unblock(stream_id) - await self.has_data.set() - - async def _priority_updated(self, event: h2.events.PriorityUpdated) -> None: - try: - self.priority.reprioritize( - stream_id=event.stream_id, - depends_on=event.depends_on or None, - weight=event.weight, - exclusive=event.exclusive, - ) - except priority.MissingStreamError: - # Received PRIORITY frame before HEADERS frame - self.priority.insert_stream( - stream_id=event.stream_id, - depends_on=event.depends_on or None, - weight=event.weight, - exclusive=event.exclusive, - ) - self.priority.block(event.stream_id) - await self.has_data.set() - async def _create_stream(self, request: h2.events.RequestReceived) -> None: for name, value in request.headers: if name == b":method": @@ -324,7 +198,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: raw_path = value if method == "CONNECT": - self.streams[request.stream_id] = WSStream( + stream = WSStream( self.app, self.config, self.context, @@ -336,7 +210,7 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: request.stream_id, ) else: - self.streams[request.stream_id] = HTTPStream( + stream = HTTPStream( self.app, self.config, self.context, @@ -347,14 +221,14 @@ async def _create_stream(self, request: h2.events.RequestReceived) -> None: self.stream_send, request.stream_id, ) - self.stream_buffers[request.stream_id] = StreamBuffer(self.context.event_class) - try: - self.priority.insert_stream(request.stream_id) - except priority.DuplicateStreamError: - # Received PRIORITY frame before HEADERS frame - pass - else: - self.priority.block(request.stream_id) + self.streams[request.stream_id] = QueuedStream( + stream, + self.task_group, + self.context, + self.config.max_app_queue_size, + self.config.max_app_queue_bytes, + ) + 
self.sender.register_stream(request.stream_id) await self.streams[request.stream_id].handle( Request( @@ -397,4 +271,8 @@ async def _close_stream(self, stream_id: int) -> None: if stream_id in self.streams: stream = self.streams.pop(stream_id) await stream.handle(StreamClosed(stream_id=stream_id)) - await self.has_data.set() + await self.sender.close_stream(stream_id) + + async def _acknowledge_data(self, length: int, stream_id: int) -> None: + self.connection.acknowledge_received_data(length, stream_id) + await self._flush() diff --git a/src/hypercorn/protocol/h2_send.py b/src/hypercorn/protocol/h2_send.py new file mode 100644 index 00000000..68d7c563 --- /dev/null +++ b/src/hypercorn/protocol/h2_send.py @@ -0,0 +1,222 @@ +from __future__ import annotations + +from collections import deque +from collections.abc import Awaitable, Callable + +import h2.events +import h2.exceptions +import priority + +from ..typing import Event as IOEvent + +BUFFER_HIGH_WATER = 2 * 2**14 # Twice the default max frame size (two frames worth) +BUFFER_LOW_WATER = BUFFER_HIGH_WATER / 2 +MAX_BATCHED_SENDS = 16 + + +class BufferCompleteError(Exception): + pass + + +class StreamBuffer: + __slots__ = ("_chunks", "_head_offset", "_size", "_complete", "_is_empty", "_paused") + + def __init__(self, event_class: type[IOEvent]) -> None: + self._chunks: deque[memoryview] = deque() + self._head_offset = 0 + self._size = 0 + self._complete = False + self._is_empty = event_class() + self._paused = event_class() + + async def drain(self) -> None: + await self._is_empty.wait() + + def set_complete(self) -> None: + self._complete = True + + async def close(self) -> None: + self._complete = True + self._chunks = deque() + self._head_offset = 0 + self._size = 0 + await self._is_empty.set() + await self._paused.set() + + @property + def complete(self) -> bool: + return self._complete and self._size == 0 + + async def push(self, data: bytes) -> None: + if self._complete: + raise BufferCompleteError() + 
chunk = memoryview(data) + if len(chunk) > 0: + self._chunks.append(chunk) + self._size += len(chunk) + await self._is_empty.clear() + if self._size >= BUFFER_HIGH_WATER: + await self._paused.wait() + await self._paused.clear() + + async def pop(self, max_length: int) -> bytes | memoryview: + length = min(self._size, max_length) + if length == 0: + await self._is_empty.set() + return b"" + + remaining = length + parts: list[memoryview] = [] + while remaining > 0: + chunk = self._chunks[0] + available = len(chunk) - self._head_offset + take = min(available, remaining) + parts.append(chunk[self._head_offset : self._head_offset + take]) + self._head_offset += take + self._size -= take + remaining -= take + + if self._head_offset == len(chunk): + self._chunks.popleft() + self._head_offset = 0 + + if self._size <= BUFFER_LOW_WATER: + await self._paused.set() + if self._size == 0: + await self._is_empty.set() + if len(parts) == 1: + return parts[0] + return b"".join(part.tobytes() for part in parts) + + +class H2SendScheduler: + def __init__( + self, + connection: object, + event_class: type[IOEvent], + flush: Callable[[], Awaitable[None]], + ) -> None: + self.connection = connection + self.flush = flush + self.has_data = event_class() + self.priority = priority.PriorityTree() + self.stream_buffers: dict[int, StreamBuffer] = {} + self._event_class = event_class + + async def run(self, should_stop: Callable[[], bool]) -> None: + while not should_stop(): + try: + stream_id = next(self.priority) + except priority.DeadlockError: + await self.has_data.wait() + await self.has_data.clear() + else: + await self._send_ready_batch(stream_id) + + def register_stream(self, stream_id: int) -> None: + self.stream_buffers[stream_id] = StreamBuffer(self._event_class) + try: + self.priority.insert_stream(stream_id) + except priority.DuplicateStreamError: + # Received PRIORITY frame before HEADERS frame + pass + else: + self.priority.block(stream_id) + + async def buffer(self, stream_id: 
int, data: bytes) -> None: + self.priority.unblock(stream_id) + await self.has_data.set() + await self.stream_buffers[stream_id].push(data) + + async def end(self, stream_id: int) -> None: + self.stream_buffers[stream_id].set_complete() + self.priority.unblock(stream_id) + await self.has_data.set() + await self.stream_buffers[stream_id].drain() + + async def trailers(self, stream_id: int, headers: list[tuple[bytes, bytes]]) -> None: + self.priority.unblock(stream_id) + await self.has_data.set() + await self.stream_buffers[stream_id].drain() + self.connection.send_headers(stream_id, headers, end_stream=True) + await self.flush() + + async def wake(self) -> None: + await self.has_data.set() + + async def window_updated(self, stream_id: int | None) -> None: + if stream_id is None or stream_id == 0: + for pending_stream_id in list(self.stream_buffers.keys()): + self.priority.unblock(pending_stream_id) + elif stream_id in self.stream_buffers: + self.priority.unblock(stream_id) + await self.has_data.set() + + async def priority_updated(self, event: h2.events.PriorityUpdated) -> None: + try: + self.priority.reprioritize( + stream_id=event.stream_id, + depends_on=event.depends_on or None, + weight=event.weight, + exclusive=event.exclusive, + ) + except priority.MissingStreamError: + self.priority.insert_stream( + stream_id=event.stream_id, + depends_on=event.depends_on or None, + weight=event.weight, + exclusive=event.exclusive, + ) + self.priority.block(event.stream_id) + await self.has_data.set() + + async def close_stream(self, stream_id: int) -> None: + if stream_id not in self.stream_buffers: + return + + await self.stream_buffers[stream_id].close() + del self.stream_buffers[stream_id] + self.priority.remove_stream(stream_id) + await self.has_data.set() + + async def _send_ready_batch(self, stream_id: int) -> None: + needs_flush = False + sends = 0 + + while sends < MAX_BATCHED_SENDS: + needs_flush |= await self._send_data(stream_id) + sends += 1 + + try: + stream_id 
= next(self.priority) + except priority.DeadlockError: + break + + if needs_flush: + await self.flush() + + async def _send_data(self, stream_id: int) -> bool: + needs_flush = False + try: + chunk_size = min( + self.connection.local_flow_control_window(stream_id), + self.connection.max_outbound_frame_size, + ) + chunk_size = max(0, chunk_size) + data = await self.stream_buffers[stream_id].pop(chunk_size) + if data: + self.connection.send_data(stream_id, data) + needs_flush = True + else: + self.priority.block(stream_id) + + if self.stream_buffers[stream_id].complete: + self.connection.end_stream(stream_id) + needs_flush = True + del self.stream_buffers[stream_id] + self.priority.remove_stream(stream_id) + except (h2.exceptions.StreamClosedError, KeyError, h2.exceptions.ProtocolError): + await self.close_stream(stream_id) + return False + + return needs_flush diff --git a/src/hypercorn/protocol/h3.py b/src/hypercorn/protocol/h3.py index 1bcf584e..a005346d 100644 --- a/src/hypercorn/protocol/h3.py +++ b/src/hypercorn/protocol/h3.py @@ -21,6 +21,8 @@ Trailers, ) from .http_stream import HTTPStream +from .h3_send import H3SendScheduler +from .queued_stream import QueuedStream from .ws_stream import WSStream from ..config import Config from ..typing import AppWrapper, ConnectionState, TaskGroup, WorkerContext @@ -45,11 +47,14 @@ def __init__( self.config = config self.context = context self.connection = H3Connection(quic) + self.closed = False self.send = send self.server = server - self.streams: dict[int, HTTPStream | WSStream] = {} + self.streams: dict[int, QueuedStream] = {} self.task_group = task_group self.state = state + self.sender = H3SendScheduler(self.connection, self.context.event_class, self.send) + self.task_group.spawn(self.sender.run, lambda: self.closed) async def handle(self, quic_event: QuicEvent) -> None: for event in self.connection.handle_event(quic_event): @@ -69,27 +74,27 @@ async def handle(self, quic_event: QuicEvent) -> None: async def 
stream_send(self, event: StreamEvent) -> None: if isinstance(event, (InformationalResponse, Response)): - self.connection.send_headers( + await self.sender.headers( event.stream_id, [(b":status", b"%d" % event.status_code)] + event.headers + self.config.response_headers("h3"), ) - await self.send() elif isinstance(event, (Body, Data)): - self.connection.send_data(event.stream_id, event.data, False) - await self.send() + await self.sender.data(event.stream_id, event.data) elif isinstance(event, (EndBody, EndData)): - self.connection.send_data(event.stream_id, b"", True) - await self.send() + await self.sender.data(event.stream_id, b"", end_stream=True) elif isinstance(event, Trailers): - self.connection.send_headers(event.stream_id, event.headers) - await self.send() + await self.sender.headers(event.stream_id, event.headers) elif isinstance(event, StreamClosed): self.streams.pop(event.stream_id, None) elif isinstance(event, Request): await self._create_server_push(event.stream_id, event.raw_path, event.headers) + async def close(self) -> None: + self.closed = True + await self.sender.close() + async def _create_stream(self, request: HeadersReceived) -> None: for name, value in request.headers: if name == b":method": @@ -98,7 +103,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: raw_path = value if method == "CONNECT": - self.streams[request.stream_id] = WSStream( + stream = WSStream( self.app, self.config, self.context, @@ -110,7 +115,7 @@ async def _create_stream(self, request: HeadersReceived) -> None: request.stream_id, ) else: - self.streams[request.stream_id] = HTTPStream( + stream = HTTPStream( self.app, self.config, self.context, @@ -121,6 +126,13 @@ async def _create_stream(self, request: HeadersReceived) -> None: self.stream_send, request.stream_id, ) + self.streams[request.stream_id] = QueuedStream( + stream, + self.task_group, + self.context, + self.config.max_app_queue_size, + self.config.max_app_queue_bytes, + ) await 
self.streams[request.stream_id].handle( Request( diff --git a/src/hypercorn/protocol/h3_send.py b/src/hypercorn/protocol/h3_send.py new file mode 100644 index 00000000..201a3895 --- /dev/null +++ b/src/hypercorn/protocol/h3_send.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +from collections import deque +from collections.abc import Awaitable, Callable +from dataclasses import dataclass + +from ..typing import Event as IOEvent + +MAX_BATCHED_SENDS = 32 + + +class H3SendClosedError(RuntimeError): + pass + + +@dataclass +class _QueuedOperation: + apply: Callable[[], bool] + complete: IOEvent + error: Exception | None = None + + +class H3SendScheduler: + def __init__( + self, + connection: object, + event_class: type[IOEvent], + flush: Callable[[], Awaitable[None]], + ) -> None: + self.connection = connection + self.flush = flush + self.has_data = event_class() + self._event_class = event_class + self._closed = False + self._queue: deque[_QueuedOperation] = deque() + + async def run(self, should_stop: Callable[[], bool]) -> None: + while True: + if self._closed or should_stop(): + await self._fail_pending(H3SendClosedError("H3 send scheduler is closed")) + break + if not self._queue: + await self.has_data.wait() + await self.has_data.clear() + if self._closed or should_stop(): + await self._fail_pending(H3SendClosedError("H3 send scheduler is closed")) + break + + await self._send_ready_batch() + + async def headers( + self, stream_id: int, headers: list[tuple[bytes, bytes]], end_stream: bool = False + ) -> None: + await self._enqueue( + lambda: self._send_headers(stream_id, headers, end_stream=end_stream) + ) + + async def data(self, stream_id: int, data: bytes, end_stream: bool = False) -> None: + await self._enqueue(lambda: self._send_data(stream_id, data, end_stream=end_stream)) + + async def wake(self) -> None: + await self.has_data.set() + + async def close(self) -> None: + self._closed = True + await self._fail_pending(H3SendClosedError("H3 send 
scheduler is closed")) + await self.has_data.set() + + async def _enqueue(self, apply: Callable[[], bool]) -> None: + if self._closed: + raise H3SendClosedError("H3 send scheduler is closed") + operation = _QueuedOperation(apply=apply, complete=self._event_class()) + self._queue.append(operation) + await self.has_data.set() + await operation.complete.wait() + if operation.error is not None: + raise operation.error + + async def _fail_pending(self, error: Exception) -> None: + while self._queue: + operation = self._queue.popleft() + operation.error = error + await operation.complete.set() + + async def _send_ready_batch(self) -> None: + needs_flush = False + processed: list[_QueuedOperation] = [] + + try: + while self._queue and len(processed) < MAX_BATCHED_SENDS: + operation = self._queue.popleft() + processed.append(operation) + needs_flush |= operation.apply() + + if needs_flush: + await self.flush() + except Exception as error: + for operation in processed: + operation.error = error + finally: + for operation in processed: + await operation.complete.set() + + if self._queue: + await self.has_data.set() + + def _send_headers( + self, stream_id: int, headers: list[tuple[bytes, bytes]], end_stream: bool = False + ) -> bool: + self.connection.send_headers(stream_id, headers, end_stream=end_stream) + return True + + def _send_data(self, stream_id: int, data: bytes, end_stream: bool = False) -> bool: + self.connection.send_data(stream_id, data, end_stream) + return True diff --git a/src/hypercorn/protocol/http_stream.py b/src/hypercorn/protocol/http_stream.py index 206ad6d6..136f1f9d 100644 --- a/src/hypercorn/protocol/http_stream.py +++ b/src/hypercorn/protocol/http_stream.py @@ -60,16 +60,17 @@ def __init__( stream_id: int, ) -> None: self.app = app + self.app_put: Callable[[dict], Awaitable[None]] | None = None self.client = client self.closed = False self.config = config self.context = context self.response: HTTPResponseStartEvent - self.scope: HTTPScope + 
self.scope: HTTPScope | None = None self.send = send self.scheme = "https" if ssl else "http" self.server = server - self.start_time: float + self.start_time: float | None = None self.state = ASGIHTTPState.REQUEST self.stream_id = stream_id self.task_group = task_group @@ -119,14 +120,19 @@ async def handle(self, event: Event) -> None: self.closed = True elif isinstance(event, Body): - await self.app_put( - {"type": "http.request", "body": bytes(event.data), "more_body": True} - ) + if self.app_put is not None: + body = event.data if isinstance(event.data, bytes) else bytes(event.data) + await self.app_put({"type": "http.request", "body": body, "more_body": True}) elif isinstance(event, EndBody): - await self.app_put({"type": "http.request", "body": b"", "more_body": False}) + if self.app_put is not None: + await self.app_put({"type": "http.request", "body": b"", "more_body": False}) elif isinstance(event, StreamClosed): self.closed = True - if self.state != ASGIHTTPState.CLOSED: + if ( + self.state != ASGIHTTPState.CLOSED + and self.scope is not None + and self.start_time is not None + ): await self.config.log.access(self.scope, None, time() - self.start_time) if self.app_put is not None: await self.app_put({"type": "http.disconnect"}) @@ -189,8 +195,11 @@ async def app_send(self, message: ASGISendEvent | None) -> None: not suppress_body(self.scope["method"], int(self.response["status"])) and message.get("body", b"") != b"" ): + body = message.get("body", b"") + if not isinstance(body, bytes): + body = bytes(body) await self.send( - Body(stream_id=self.stream_id, data=bytes(message.get("body", b""))) + Body(stream_id=self.stream_id, data=body) ) if not message.get("more_body", False): diff --git a/src/hypercorn/protocol/queued_stream.py b/src/hypercorn/protocol/queued_stream.py new file mode 100644 index 00000000..e1ff6104 --- /dev/null +++ b/src/hypercorn/protocol/queued_stream.py @@ -0,0 +1,141 @@ +from __future__ import annotations + +from collections import 
deque +from collections.abc import Awaitable, Callable +from dataclasses import dataclass +from typing import Protocol + +from .events import Body, Data, Event, StreamClosed +from ..typing import TaskGroup, WorkerContext + + +class Stream(Protocol): + @property + def idle(self) -> bool: ... + + async def handle(self, event: Event) -> None: ... + + +@dataclass +class _QueuedEvent: + event: Event + callbacks: list[Callable[[], Awaitable[None]]] + chunks: list[bytes] | None = None + byte_size: int = 0 + + @classmethod + def create( + cls, event: Event, callback: Callable[[], Awaitable[None]] | None = None + ) -> _QueuedEvent: + callbacks = [] if callback is None else [callback] + if isinstance(event, (Body, Data)): + data = event.data if isinstance(event.data, bytes) else bytes(event.data) + return cls(event=event, callbacks=callbacks, chunks=[data], byte_size=len(data)) + return cls(event=event, callbacks=callbacks) + + def append(self, other: _QueuedEvent) -> None: + if self.chunks is None or other.chunks is None: + raise TypeError("Only data-carrying events can be merged") + self.chunks.extend(other.chunks) + self.callbacks.extend(other.callbacks) + self.byte_size += other.byte_size + + def materialize(self) -> Event: + if self.chunks is None: + return self.event + + data = self.chunks[0] if len(self.chunks) == 1 else b"".join(self.chunks) + if isinstance(self.event, Body): + return Body(stream_id=self.event.stream_id, data=data) + return Data(stream_id=self.event.stream_id, data=data) + + @property + def size_bytes(self) -> int: + return self.byte_size + + +class QueuedStream: + def __init__( + self, + stream: Stream, + task_group: TaskGroup, + context: WorkerContext, + max_queue_size: int = 0, + max_queue_bytes: int = 0, + ) -> None: + self._closed = False + self._has_events = context.event_class() + self._has_space = context.event_class() + self._queued_bytes = 0 + self._max_queue_bytes = max_queue_bytes + self._max_queue_size = max_queue_size + self._queue: 
deque[_QueuedEvent] = deque() + self._stream = stream + task_group.spawn(self._handle) + + @property + def idle(self) -> bool: + return len(self._queue) == 0 and self._stream.idle + + async def handle( + self, event: Event, callback: Callable[[], Awaitable[None]] | None = None + ) -> None: + queued_event = _QueuedEvent.create(event, callback) + while True: + queue_empty = len(self._queue) == 0 and self._queued_bytes == 0 + has_count_space = self._max_queue_size == 0 or len(self._queue) < self._max_queue_size + has_byte_space = ( + self._max_queue_bytes == 0 + or queue_empty + or (self._queued_bytes + queued_event.size_bytes) <= self._max_queue_bytes + ) + + if self._queue: + if has_byte_space and _merge_queued_events(self._queue[-1], queued_event): + self._queued_bytes += queued_event.size_bytes + await self._has_events.set() + return + + if has_count_space and has_byte_space: + break + + await self._has_space.wait() + await self._has_space.clear() + + self._queue.append(queued_event) + self._queued_bytes += queued_event.size_bytes + await self._has_events.set() + + async def _handle(self) -> None: + while True: + await self._has_events.wait() + while self._queue: + queued = self._queue.popleft() + self._queued_bytes -= queued.size_bytes + if ( + (self._max_queue_size > 0 and len(self._queue) < self._max_queue_size) + or (self._max_queue_bytes > 0 and self._queued_bytes < self._max_queue_bytes) + ): + await self._has_space.set() + if len(self._queue) == 0: + await self._has_events.clear() + + await self._stream.handle(queued.materialize()) + for callback in queued.callbacks: + await callback() + + if isinstance(queued.event, StreamClosed): + self._closed = True + + if self._closed and len(self._queue) == 0: + return + + +def _merge_queued_events(first: _QueuedEvent, second: _QueuedEvent) -> bool: + if isinstance(first.event, Body) and isinstance(second.event, Body): + first.append(second) + return True + if isinstance(first.event, Data) and isinstance(second.event, 
Data): + first.append(second) + return True + return False diff --git a/src/hypercorn/protocol/quic.py b/src/hypercorn/protocol/quic.py index 40819420..eda61160 100644 --- a/src/hypercorn/protocol/quic.py +++ b/src/hypercorn/protocol/quic.py @@ -120,6 +120,8 @@ async def _handle_events( event = connection.quic.next_event() while event is not None: if isinstance(event, ConnectionTerminated): + if connection.h3 is not None: + await connection.h3.close() await connection.task.stop() for cid in connection.cids: del self.connections[cid] diff --git a/src/hypercorn/protocol/ws_stream.py b/src/hypercorn/protocol/ws_stream.py index 136345f5..99f76ca5 100644 --- a/src/hypercorn/protocol/ws_stream.py +++ b/src/hypercorn/protocol/ws_stream.py @@ -184,17 +184,17 @@ def __init__( self.context = context self.task_group = task_group self.response: WebsocketResponseStartEvent - self.scope: WebsocketScope + self.scope: WebsocketScope | None = None self.send = send # RFC 8441 for HTTP/2 says use http or https, ASGI says ws or wss self.scheme = "wss" if ssl else "ws" self.server = server - self.start_time: float + self.start_time: float | None = None self.state = ASGIWebsocketState.HANDSHAKE self.stream_id = stream_id - self.connection: Connection - self.handshake: Handshake + self.connection: Connection | None = None + self.handshake: Handshake | None = None @property def idle(self) -> bool: @@ -235,10 +235,13 @@ async def handle(self, event: Event) -> None: self.app, self.config, self.scope, self.app_send ) await self.app_put({"type": "websocket.connect"}) - elif isinstance(event, (Body, Data)) and not self.handshake.accepted: + elif isinstance(event, (Body, Data)) and ( + self.handshake is None or not self.handshake.accepted + ): await self._send_error_response(400) self.closed = True elif isinstance(event, (Body, Data)): + assert self.connection is not None self.connection.receive_data(event.data) await self._handle_events() elif isinstance(event, StreamClosed): @@ -281,7 +284,10 
@@ async def app_send(self, message: ASGISendEvent | None) -> None: elif message["type"] == "websocket.send" and self.state == ASGIWebsocketState.CONNECTED: event: WSProtoEvent if message.get("bytes") is not None: - event = BytesMessage(data=bytes(message["bytes"])) + data = message["bytes"] + if not isinstance(data, bytes): + data = bytes(data) + event = BytesMessage(data=data) elif not isinstance(message["text"], str): raise TypeError(f"{message['text']} should be a str") else: @@ -323,6 +329,7 @@ async def _handle_events(self) -> None: elif isinstance(event, CloseConnection): if self.connection.state == ConnectionState.REMOTE_CLOSING: await self._send_wsproto_event(event.response()) + await self.send(EndData(stream_id=self.stream_id)) await self.send(StreamClosed(stream_id=self.stream_id)) async def _send_error_response(self, status_code: int) -> None: @@ -334,9 +341,10 @@ async def _send_error_response(self, status_code: int) -> None: ) ) await self.send(EndBody(stream_id=self.stream_id)) - await self.config.log.access( - self.scope, {"status": status_code, "headers": []}, time() - self.start_time - ) + if self.scope is not None and self.start_time is not None: + await self.config.log.access( + self.scope, {"status": status_code, "headers": []}, time() - self.start_time + ) async def _send_wsproto_event(self, event: WSProtoEvent) -> None: try: @@ -373,7 +381,10 @@ async def _send_rejection(self, message: WebsocketResponseBodyEvent) -> None: ) self.state = ASGIWebsocketState.RESPONSE if not body_suppressed: - await self.send(Body(stream_id=self.stream_id, data=bytes(message.get("body", b"")))) + body = message.get("body", b"") + if not isinstance(body, bytes): + body = bytes(body) + await self.send(Body(stream_id=self.stream_id, data=body)) if not message.get("more_body", False): self.state = ASGIWebsocketState.HTTPCLOSED await self.send(EndBody(stream_id=self.stream_id)) diff --git a/src/hypercorn/trio/lifespan.py b/src/hypercorn/trio/lifespan.py index 
087aa839..62ff1d70 100644 --- a/src/hypercorn/trio/lifespan.py +++ b/src/hypercorn/trio/lifespan.py @@ -90,7 +90,7 @@ async def wait_for_shutdown(self) -> None: with trio.fail_after(self.config.shutdown_timeout): await self.shutdown.wait() except trio.TooSlowError as error: - raise LifespanTimeoutError("startup") from error + raise LifespanTimeoutError("shutdown") from error async def asgi_receive(self) -> ASGIReceiveEvent: return await self.app_receive_channel.receive() diff --git a/src/hypercorn/trio/tcp_server.py b/src/hypercorn/trio/tcp_server.py index 5b890f08..835cca76 100644 --- a/src/hypercorn/trio/tcp_server.py +++ b/src/hypercorn/trio/tcp_server.py @@ -109,7 +109,8 @@ async def _read_data(self) -> None: ): break else: - await self.protocol.handle(RawData(bytes(data))) + payload = data if isinstance(data, bytes) else bytes(data) + await self.protocol.handle(RawData(payload)) if data == b"": break await self.protocol.handle(Closed()) diff --git a/src/hypercorn/trio/worker_context.py b/src/hypercorn/trio/worker_context.py index f9e08f17..7d359ae7 100644 --- a/src/hypercorn/trio/worker_context.py +++ b/src/hypercorn/trio/worker_context.py @@ -41,19 +41,25 @@ async def stop(self) -> None: class EventWrapper: def __init__(self) -> None: - self._event = trio.Event() + self._condition = trio.Condition() + self._is_set = False async def clear(self) -> None: - self._event = trio.Event() + async with self._condition: + self._is_set = False async def wait(self) -> None: - await self._event.wait() + async with self._condition: + while not self._is_set: + await self._condition.wait() async def set(self) -> None: - self._event.set() + async with self._condition: + self._is_set = True + self._condition.notify_all() def is_set(self) -> bool: - return self._event.is_set() + return self._is_set class WorkerContext: diff --git a/src/hypercorn/utils.py b/src/hypercorn/utils.py index 6c3b0c45..9ae9b1c4 100644 --- a/src/hypercorn/utils.py +++ b/src/hypercorn/utils.py @@ -78,6 
+78,16 @@ def filter_pseudo_headers(headers: list[tuple[bytes, bytes]]) -> list[tuple[byte return filtered_headers +def _resolve_application(module: Any, app_name: str, path: str) -> Any: + app = module + for attribute in app_name.split("."): + try: + app = getattr(app, attribute) + except AttributeError as error: + raise NoAppError(f"Cannot load application from '{path}', application not found.") from error + return app + + def load_application(path: str, wsgi_max_body_size: int) -> AppWrapper: mode: Literal["asgi", "wsgi"] | None = None if ":" not in path: @@ -102,12 +112,8 @@ def load_application(path: str, wsgi_max_body_size: int) -> AppWrapper: raise NoAppError(f"Cannot load application from '{path}', module not found.") else: raise - try: - app = eval(app_name, vars(module)) - except NameError: - raise NoAppError(f"Cannot load application from '{path}', application not found.") - else: - return wrap_app(app, wsgi_max_body_size, mode) + app = _resolve_application(module, app_name, path) + return wrap_app(app, wsgi_max_body_size, mode) def wrap_app( diff --git a/tests/assets/load_apps.py b/tests/assets/load_apps.py new file mode 100644 index 00000000..0850c97d --- /dev/null +++ b/tests/assets/load_apps.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from collections.abc import Callable + +from hypercorn.typing import Scope + + +async def app(scope: Scope, receive: Callable, send: Callable) -> None: + pass + + +class Container: + def __init__(self) -> None: + self.app = app + + +nested = Container() diff --git a/tests/asyncio/test_tcp_server.py b/tests/asyncio/test_tcp_server.py index 1aa2898b..761b366b 100644 --- a/tests/asyncio/test_tcp_server.py +++ b/tests/asyncio/test_tcp_server.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +from unittest.mock import AsyncMock, Mock import pytest @@ -8,10 +9,20 @@ from hypercorn.asyncio.tcp_server import TCPServer from hypercorn.asyncio.worker_context import WorkerContext from 
hypercorn.config import Config +from hypercorn.events import Closed, RawData from .helpers import MemoryReader, MemoryWriter from ..helpers import echo_framework +class TimeoutWriter(MemoryWriter): + def __init__(self) -> None: + super().__init__() + self.transport = Mock() + + async def wait_closed(self) -> None: + raise TimeoutError() + + @pytest.mark.asyncio async def test_completes_on_closed() -> None: event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() @@ -53,3 +64,58 @@ async def test_complets_on_half_close() -> None: data == b"HTTP/1.1 200 \r\ncontent-length: 348\r\ndate: Thu, 01 Jan 1970 01:23:20 GMT\r\nserver: hypercorn-h11\r\n\r\n" # noqa: E501 ) + + +@pytest.mark.asyncio +async def test_close_aborts_transport_on_wait_closed_timeout() -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + writer = TimeoutWriter() + + server = TCPServer( + ASGIWrapper(echo_framework), + event_loop, + Config(), + WorkerContext(None), + {}, + MemoryReader(), # type: ignore + writer, # type: ignore + ) + + await server._close() + + writer.transport.abort.assert_called_once_with() + + +@pytest.mark.asyncio +async def test_read_data_without_timeout_does_not_use_wait_for(monkeypatch) -> None: + event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() + config = Config() + config.read_timeout = None + reader = MemoryReader() + + server = TCPServer( + ASGIWrapper(echo_framework), + event_loop, + config, + WorkerContext(None), + {}, + reader, # type: ignore[arg-type] + MemoryWriter(), # type: ignore[arg-type] + ) + server.protocol = Mock() + server.protocol.handle = AsyncMock() + + async def forbidden_wait_for(*args, **kwargs): + raise AssertionError("asyncio.wait_for should not be used when read_timeout is None") + + monkeypatch.setattr(asyncio, "wait_for", forbidden_wait_for) + + await reader.send(b"body") + reader.close() + await server._read_data() + + assert server.protocol.handle.await_args_list == [ + 
((RawData(data=b"body"),),), + ((RawData(data=b""),),), + ((Closed(),),), + ] diff --git a/tests/asyncio/test_udp_server.py b/tests/asyncio/test_udp_server.py new file mode 100644 index 00000000..a56edf1d --- /dev/null +++ b/tests/asyncio/test_udp_server.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +import asyncio + +from hypercorn.asyncio.udp_server import UDPServer +from hypercorn.asyncio.worker_context import WorkerContext +from hypercorn.config import Config +from hypercorn.events import RawData + + +def test_udp_server_uses_configured_queue_size() -> None: + config = Config() + config.quic_receive_queue_size = 64 + + loop = asyncio.new_event_loop() + try: + server = UDPServer(None, loop, config, WorkerContext(None), {}) # type: ignore[arg-type] + + assert server.protocol_queue.maxsize == 64 + finally: + loop.close() + + +def test_udp_server_drops_datagrams_when_queue_is_full() -> None: + config = Config() + config.quic_receive_queue_size = 1 + + loop = asyncio.new_event_loop() + try: + server = UDPServer(None, loop, config, WorkerContext(None), {}) # type: ignore[arg-type] + + server.datagram_received(b"one", ("127.0.0.1", 4433)) + server.datagram_received(b"two", ("127.0.0.1", 4433)) + + assert server.protocol_queue.qsize() == 1 + queued = server.protocol_queue.get_nowait() + assert queued == RawData(data=b"one", address=("127.0.0.1", 4433)) + finally: + loop.close() diff --git a/tests/conftest.py b/tests/conftest.py index be84f59a..6c132abe 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,8 +1,17 @@ from __future__ import annotations +import sys +from pathlib import Path + import pytest from _pytest.monkeypatch import MonkeyPatch +PROJECT_ROOT = Path(__file__).resolve().parent.parent +SRC_PATH = PROJECT_ROOT / "src" + +if str(SRC_PATH) not in sys.path: + sys.path.insert(0, str(SRC_PATH)) + import hypercorn.config from hypercorn.typing import ConnectionState, HTTPScope diff --git a/tests/middleware/test_dispatcher.py 
b/tests/middleware/test_dispatcher.py index 2f57d46a..12c1151e 100644 --- a/tests/middleware/test_dispatcher.py +++ b/tests/middleware/test_dispatcher.py @@ -49,6 +49,33 @@ async def send(message: dict) -> None: ] +@pytest.mark.asyncio +async def test_dispatcher_middleware_matches_segments(http_scope: HTTPScope) -> None: + class EchoFramework: + async def __call__(self, scope: Scope, receive: Callable, send: Callable) -> None: + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [(b"content-length", b"2")], + } + ) + await send({"type": "http.response.body", "body": b"ok"}) + + app = AsyncioDispatcherMiddleware({"/api": EchoFramework()}) + + sent_events = [] + + async def send(message: dict) -> None: + sent_events.append(message) + + await app({**http_scope, **{"path": "/apiary"}}, None, send) # type: ignore + assert sent_events == [ + {"type": "http.response.start", "status": 404, "headers": [(b"content-length", b"0")]}, + {"type": "http.response.body"}, + ] + + class ScopeFramework: def __init__(self, name: str) -> None: self.name = name diff --git a/tests/protocol/test_h2.py b/tests/protocol/test_h2.py index b549496d..15e201f9 100644 --- a/tests/protocol/test_h2.py +++ b/tests/protocol/test_h2.py @@ -5,15 +5,35 @@ import pytest from h2.connection import H2Connection -from h2.events import ConnectionTerminated +from h2.events import ConnectionTerminated, DataReceived from hypercorn.asyncio.worker_context import EventWrapper, WorkerContext from hypercorn.config import Config from hypercorn.events import Closed, RawData +from hypercorn.protocol.events import Body, StreamClosed from hypercorn.protocol.h2 import BUFFER_HIGH_WATER, BufferCompleteError, H2Protocol, StreamBuffer +from hypercorn.protocol.h2_send import H2SendScheduler +from hypercorn.protocol.queued_stream import QueuedStream from hypercorn.typing import ConnectionState +class DummyTaskGroup: + def __init__(self) -> None: + self.tasks: list[asyncio.Task] = [] + + def 
spawn(self, func, *args) -> None: + self.tasks.append(asyncio.create_task(func(*args))) + + async def spawn_app(self, *args, **kwargs) -> AsyncMock: + return AsyncMock() + + async def aclose(self) -> None: + for task in self.tasks: + task.cancel() + if self.tasks: + await asyncio.gather(*self.tasks, return_exceptions=True) + + @pytest.mark.asyncio async def test_stream_buffer_push_and_pop() -> None: event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() @@ -25,10 +45,11 @@ async def _push_over_limit() -> bool: return True task = event_loop.create_task(_push_over_limit()) + await asyncio.sleep(0) assert not task.done() # Blocked as over high water await stream_buffer.pop(BUFFER_HIGH_WATER // 4) assert not task.done() # Blocked as over low water - await stream_buffer.pop(BUFFER_HIGH_WATER // 4) + await stream_buffer.pop((BUFFER_HIGH_WATER // 4) + 1) assert (await task) is True @@ -71,13 +92,66 @@ async def test_stream_buffer_complete() -> None: assert stream_buffer.complete +@pytest.mark.asyncio +async def test_stream_buffer_pop_across_chunks() -> None: + stream_buffer = StreamBuffer(EventWrapper) + await stream_buffer.push(b"abcd") + await stream_buffer.push(b"efgh") + + assert bytes(await stream_buffer.pop(6)) == b"abcdef" + assert bytes(await stream_buffer.pop(6)) == b"gh" + + +@pytest.mark.asyncio +async def test_send_scheduler_sends_data_and_closes_stream() -> None: + connection = Mock() + connection.local_flow_control_window.return_value = 5 + connection.max_outbound_frame_size = 5 + flush = AsyncMock() + scheduler = H2SendScheduler(connection, EventWrapper, flush) + + scheduler.register_stream(1) + await scheduler.buffer(1, b"hello") + scheduler.stream_buffers[1].set_complete() + await scheduler._send_ready_batch(1) + + connection.send_data.assert_called_once() + assert bytes(connection.send_data.call_args.args[1]) == b"hello" + connection.end_stream.assert_called_once_with(1) + flush.assert_awaited_once() + assert 1 not in scheduler.stream_buffers 
+ + +@pytest.mark.asyncio +async def test_send_scheduler_batches_flush_across_ready_streams() -> None: + connection = Mock() + connection.local_flow_control_window.return_value = 5 + connection.max_outbound_frame_size = 5 + flush = AsyncMock() + scheduler = H2SendScheduler(connection, EventWrapper, flush) + + scheduler.register_stream(1) + scheduler.register_stream(3) + await scheduler.buffer(1, b"hello") + await scheduler.buffer(3, b"world") + scheduler.stream_buffers[1].set_complete() + scheduler.stream_buffers[3].set_complete() + + await scheduler._send_ready_batch(1) + + assert connection.send_data.call_count == 2 + assert connection.end_stream.call_count == 2 + flush.assert_awaited_once() + + @pytest.mark.asyncio async def test_protocol_handle_protocol_error() -> None: + task_group = DummyTaskGroup() protocol = H2Protocol( Mock(), Config(), WorkerContext(None), - AsyncMock(), + task_group, # type: ignore[arg-type] ConnectionState({}), False, None, @@ -87,15 +161,17 @@ async def test_protocol_handle_protocol_error() -> None: await protocol.handle(RawData(data=b"broken nonsense\r\n\r\n")) protocol.send.assert_awaited() # type: ignore assert protocol.send.call_args_list == [call(Closed())] # type: ignore + await task_group.aclose() @pytest.mark.asyncio async def test_protocol_keep_alive_max_requests() -> None: + task_group = DummyTaskGroup() protocol = H2Protocol( Mock(), Config(), WorkerContext(None), - AsyncMock(), + task_group, # type: ignore[arg-type] ConnectionState({}), False, None, @@ -116,3 +192,121 @@ async def test_protocol_keep_alive_max_requests() -> None: protocol.send.assert_awaited() # type: ignore events = client.receive_data(protocol.send.call_args_list[1].args[0].data) # type: ignore assert isinstance(events[-1], ConnectionTerminated) + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_protocol_ignores_data_received_for_closed_stream() -> None: + task_group = DummyTaskGroup() + protocol = H2Protocol( + Mock(), + Config(), + 
WorkerContext(None), + task_group, # type: ignore[arg-type] + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) + protocol.connection.acknowledge_received_data = Mock() # type: ignore[method-assign] + protocol._flush = AsyncMock() # type: ignore[method-assign] + + await protocol._handle_events( + [DataReceived(stream_id=1, data=b"body", flow_controlled_length=4)] + ) + + protocol.connection.acknowledge_received_data.assert_not_called() + protocol._flush.assert_awaited_once() + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_protocol_idle_requires_no_registered_streams() -> None: + task_group = DummyTaskGroup() + protocol = H2Protocol( + Mock(), + Config(), + WorkerContext(None), + task_group, # type: ignore[arg-type] + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) + stream = Mock(spec=QueuedStream) + stream.idle = True + + assert protocol.idle is True + protocol.streams[1] = stream + assert protocol.idle is False + + del protocol.streams[1] + assert protocol.idle is True + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_protocol_does_not_block_other_streams_on_slow_stream() -> None: + class BlockingStream: + idle = False + + def __init__(self) -> None: + self.started = asyncio.Event() + self.release = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + if isinstance(event, Body): + self.started.set() + await self.release.wait() + + class RecordingStream: + idle = False + + def __init__(self) -> None: + self.body_received = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + if isinstance(event, Body): + self.body_received.set() + + task_group = DummyTaskGroup() + protocol = H2Protocol( + Mock(), + Config(), + WorkerContext(None), + task_group, # type: ignore[arg-type] + ConnectionState({}), + False, + None, + None, + AsyncMock(), + ) + protocol.connection.acknowledge_received_data = Mock() # type: ignore[method-assign] + 
protocol._flush = AsyncMock() # type: ignore[method-assign] + + slow_stream = BlockingStream() + fast_stream = RecordingStream() + protocol.streams[1] = QueuedStream(slow_stream, task_group, protocol.context) # type: ignore[arg-type] + protocol.streams[3] = QueuedStream(fast_stream, task_group, protocol.context) # type: ignore[arg-type] + + await protocol._handle_events( + [ + DataReceived(stream_id=1, data=b"a", flow_controlled_length=1), + DataReceived(stream_id=3, data=b"b", flow_controlled_length=1), + ] + ) + + await asyncio.wait_for(slow_stream.started.wait(), timeout=0.1) + await asyncio.wait_for(fast_stream.body_received.wait(), timeout=0.1) + protocol.connection.acknowledge_received_data.assert_called_once_with(1, 3) + + slow_stream.release.set() + await asyncio.wait_for(asyncio.sleep(0), timeout=0.1) + await asyncio.wait_for(asyncio.sleep(0), timeout=0.1) + assert protocol.connection.acknowledge_received_data.call_args_list == [call(1, 3), call(1, 1)] + + await protocol.streams[1].handle(StreamClosed(stream_id=1)) + await protocol.streams[3].handle(StreamClosed(stream_id=3)) + await asyncio.gather(*task_group.tasks) diff --git a/tests/protocol/test_h3.py b/tests/protocol/test_h3.py new file mode 100644 index 00000000..01564f11 --- /dev/null +++ b/tests/protocol/test_h3.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +import asyncio +from unittest.mock import AsyncMock, Mock + +import pytest + +from hypercorn.asyncio.worker_context import EventWrapper +from hypercorn.protocol.h3_send import H3SendClosedError, H3SendScheduler + + +class DummyTaskGroup: + def __init__(self) -> None: + self.tasks: list[asyncio.Task] = [] + + def spawn(self, func, *args) -> None: + self.tasks.append(asyncio.create_task(func(*args))) + + async def aclose(self) -> None: + for task in self.tasks: + task.cancel() + if self.tasks: + await asyncio.gather(*self.tasks, return_exceptions=True) + + +@pytest.mark.asyncio +async def 
test_send_scheduler_batches_flush_across_ready_events() -> None: + connection = Mock() + flush = AsyncMock() + scheduler = H3SendScheduler(connection, EventWrapper, flush) + task_group = DummyTaskGroup() + closed = False + task_group.spawn(scheduler.run, lambda: closed) + + tasks = [ + asyncio.create_task(scheduler.headers(1, [(b":status", b"200")])), + asyncio.create_task(scheduler.data(1, b"hello")), + asyncio.create_task(scheduler.data(1, b"", end_stream=True)), + ] + await asyncio.gather(*tasks) + + assert connection.send_headers.call_count == 1 + assert connection.send_data.call_count == 2 + flush.assert_awaited_once() + + closed = True + await scheduler.wake() + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_send_scheduler_propagates_flush_errors() -> None: + connection = Mock() + flush = AsyncMock(side_effect=RuntimeError("boom")) + scheduler = H3SendScheduler(connection, EventWrapper, flush) + task_group = DummyTaskGroup() + closed = False + task_group.spawn(scheduler.run, lambda: closed) + + with pytest.raises(RuntimeError, match="boom"): + await scheduler.data(1, b"hello") + + closed = True + await scheduler.wake() + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_send_scheduler_rejects_operations_after_close() -> None: + connection = Mock() + flush = AsyncMock() + scheduler = H3SendScheduler(connection, EventWrapper, flush) + + await scheduler.close() + + with pytest.raises(H3SendClosedError, match="closed"): + await scheduler.data(1, b"hello") diff --git a/tests/protocol/test_http_stream.py b/tests/protocol/test_http_stream.py index 3f82a02b..ed52e1a4 100644 --- a/tests/protocol/test_http_stream.py +++ b/tests/protocol/test_http_stream.py @@ -117,6 +117,28 @@ async def test_handle_body(stream: HTTPStream) -> None: ] +@pytest.mark.asyncio +async def test_handle_body_before_request_is_ignored() -> None: + stream = HTTPStream( + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 
1 + ) + stream.config._log = AsyncMock(spec=Logger) + + await stream.handle(Body(stream_id=1, data=b"data")) + + stream.send.assert_not_called() # type: ignore[attr-defined] + stream.config._log.access.assert_not_called() + + +@pytest.mark.asyncio +async def test_handle_body_memoryview(stream: HTTPStream) -> None: + await stream.handle(Body(stream_id=1, data=memoryview(b"data"))) # type: ignore[arg-type] + stream.app_put.assert_called() # type: ignore + assert stream.app_put.call_args_list == [ # type: ignore + call({"type": "http.request", "body": b"data", "more_body": True}) + ] + + @pytest.mark.asyncio async def test_handle_end_body(stream: HTTPStream) -> None: stream.app_put = AsyncMock() @@ -127,6 +149,20 @@ async def test_handle_end_body(stream: HTTPStream) -> None: ] +@pytest.mark.asyncio +async def test_handle_closed_before_request_does_not_log_or_disconnect() -> None: + stream = HTTPStream( + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 + ) + stream.config._log = AsyncMock(spec=Logger) + + await stream.handle(StreamClosed(stream_id=1)) + + assert stream.closed + stream.send.assert_not_called() # type: ignore[attr-defined] + stream.config._log.access.assert_not_called() + + @pytest.mark.asyncio async def test_handle_closed(stream: HTTPStream) -> None: await stream.handle( @@ -174,6 +210,55 @@ async def test_send_response(stream: HTTPStream) -> None: stream.config._log.access.assert_called() +@pytest.mark.asyncio +async def test_send_response_reuses_bytes_body(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send( + cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) + ) + body = b"Body" + await stream.app_send( + cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": body}) + ) + + sent_body = 
stream.send.call_args_list[1].args[0] + assert isinstance(sent_body, Body) + assert sent_body.data is body + + +@pytest.mark.asyncio +async def test_send_response_copies_non_bytes_body(stream: HTTPStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[], + raw_path=b"/?a=b", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send( + cast(HTTPResponseStartEvent, {"type": "http.response.start", "status": 200, "headers": []}) + ) + await stream.app_send( + cast(HTTPResponseBodyEvent, {"type": "http.response.body", "body": memoryview(b"Body")}) # type: ignore[arg-type] + ) + + sent_body = stream.send.call_args_list[1].args[0] + assert isinstance(sent_body, Body) + assert sent_body.data == b"Body" + + @pytest.mark.asyncio async def test_invalid_server_name(stream: HTTPStream) -> None: stream.config.server_names = ["hypercorn"] diff --git a/tests/protocol/test_queued_stream.py b/tests/protocol/test_queued_stream.py new file mode 100644 index 00000000..b78f9ab4 --- /dev/null +++ b/tests/protocol/test_queued_stream.py @@ -0,0 +1,209 @@ +from __future__ import annotations + +import asyncio + +import pytest + +from hypercorn.asyncio.worker_context import EventWrapper, WorkerContext +from hypercorn.protocol.events import Body, EndBody, StreamClosed +from hypercorn.protocol.queued_stream import QueuedStream, _QueuedEvent, _merge_queued_events + + +class DummyTaskGroup: + def __init__(self) -> None: + self.tasks: list[asyncio.Task] = [] + + def spawn(self, func, *args) -> None: + self.tasks.append(asyncio.create_task(func(*args))) + + async def aclose(self) -> None: + for task in self.tasks: + task.cancel() + if self.tasks: + await asyncio.gather(*self.tasks, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_queued_stream_backpressure_blocks_until_queue_has_space() -> None: + class BlockingStream: + idle = False + + def __init__(self) -> None: + self.started = asyncio.Event() + self.release = 
asyncio.Event() + + async def handle(self, event: Body | EndBody | StreamClosed) -> None: + if isinstance(event, Body): + self.started.set() + await self.release.wait() + + task_group = DummyTaskGroup() + stream = BlockingStream() + queued = QueuedStream(stream, task_group, WorkerContext(None), max_queue_size=1) + + await queued.handle(Body(stream_id=1, data=b"one")) + await asyncio.wait_for(stream.started.wait(), timeout=0.1) + + await queued.handle(EndBody(stream_id=1)) + task = asyncio.create_task(queued.handle(EndBody(stream_id=1))) + await asyncio.sleep(0) + assert not task.done() + + stream.release.set() + await asyncio.wait_for(task, timeout=0.1) + + await queued.handle(StreamClosed(stream_id=1)) + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_queued_stream_zero_max_queue_size_is_unbounded() -> None: + class RecordingStream: + idle = False + + def __init__(self) -> None: + self.events: list[Body | StreamClosed] = [] + self.release = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + self.events.append(event) + if isinstance(event, Body): + await self.release.wait() + + task_group = DummyTaskGroup() + stream = RecordingStream() + queued = QueuedStream(stream, task_group, WorkerContext(None), max_queue_size=0) + + await queued.handle(Body(stream_id=1, data=b"one")) + await queued.handle(Body(stream_id=1, data=b"two")) + await queued.handle(Body(stream_id=1, data=b"three")) + + stream.release.set() + await queued.handle(StreamClosed(stream_id=1)) + await asyncio.gather(*task_group.tasks, return_exceptions=True) + + +@pytest.mark.asyncio +async def test_queued_stream_allows_oversized_first_event() -> None: + class BlockingStream: + idle = False + + def __init__(self) -> None: + self.started = asyncio.Event() + self.release = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + if isinstance(event, Body): + self.started.set() + await self.release.wait() + + task_group = 
DummyTaskGroup() + stream = BlockingStream() + queued = QueuedStream( + stream, + task_group, + WorkerContext(None), + max_queue_size=1, + max_queue_bytes=4, + ) + + await asyncio.wait_for(queued.handle(Body(stream_id=1, data=b"abcdef")), timeout=0.1) + await asyncio.wait_for(stream.started.wait(), timeout=0.1) + + stream.release.set() + await queued.handle(StreamClosed(stream_id=1)) + await task_group.aclose() + + +@pytest.mark.asyncio +async def test_queued_stream_coalesces_consecutive_body_events() -> None: + class RecordingStream: + idle = False + + def __init__(self) -> None: + self.events: list[Body | StreamClosed] = [] + self.started = asyncio.Event() + self.release = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + self.events.append(event) + if isinstance(event, Body) and event.data == b"one": + self.started.set() + await self.release.wait() + + callbacks: list[str] = [] + + async def callback(value: str) -> None: + callbacks.append(value) + + task_group = DummyTaskGroup() + stream = RecordingStream() + queued = QueuedStream(stream, task_group, WorkerContext(None), max_queue_size=1) + + await queued.handle(Body(stream_id=1, data=b"one")) + await asyncio.wait_for(stream.started.wait(), timeout=0.1) + + await queued.handle(Body(stream_id=1, data=b"two"), lambda: callback("two")) + task = asyncio.create_task(queued.handle(Body(stream_id=1, data=b"three"), lambda: callback("three"))) + await asyncio.sleep(0) + assert task.done() + + stream.release.set() + await queued.handle(StreamClosed(stream_id=1)) + await asyncio.gather(*task_group.tasks, return_exceptions=True) + + assert stream.events == [ + Body(stream_id=1, data=b"one"), + Body(stream_id=1, data=b"twothree"), + StreamClosed(stream_id=1), + ] + assert callbacks == ["two", "three"] + + +@pytest.mark.asyncio +async def test_queued_stream_retains_chunks_until_dispatch() -> None: + class RecordingStream: + idle = False + + def __init__(self) -> None: + self.events: list[Body 
| StreamClosed] = [] + self.started = asyncio.Event() + self.release = asyncio.Event() + + async def handle(self, event: Body | StreamClosed) -> None: + self.events.append(event) + if isinstance(event, Body) and event.data == b"one": + self.started.set() + await self.release.wait() + + task_group = DummyTaskGroup() + stream = RecordingStream() + queued = QueuedStream(stream, task_group, WorkerContext(None), max_queue_size=1) + + await queued.handle(Body(stream_id=1, data=b"one")) + await asyncio.wait_for(stream.started.wait(), timeout=0.1) + + await queued.handle(Body(stream_id=1, data=b"two")) + await queued.handle(Body(stream_id=1, data=b"three")) + + assert queued._queue[-1].chunks == [b"two", b"three"] + + stream.release.set() + await queued.handle(StreamClosed(stream_id=1)) + await asyncio.gather(*task_group.tasks, return_exceptions=True) + + assert stream.events[1] == Body(stream_id=1, data=b"twothree") + + +def test_queued_stream_chunk_size_accounting() -> None: + first_event = Body(stream_id=1, data=b"aa") + second_event = Body(stream_id=1, data=b"bbb") + + first_queued_event = _QueuedEvent.create(first_event) + second_queued_event = _QueuedEvent.create(second_event) + + assert first_queued_event.size_bytes == 2 + assert second_queued_event.size_bytes == 3 + assert _merge_queued_events(first_queued_event, second_queued_event) is True + assert first_queued_event.size_bytes == 5 + assert first_queued_event.chunks == [b"aa", b"bbb"] diff --git a/tests/protocol/test_ws_stream.py b/tests/protocol/test_ws_stream.py index 6b656440..8f181273 100644 --- a/tests/protocol/test_ws_stream.py +++ b/tests/protocol/test_ws_stream.py @@ -6,6 +6,7 @@ import pytest import pytest_asyncio +import wsproto.connection from wsproto.events import BytesMessage, TextMessage from hypercorn.asyncio.task_group import TaskGroup @@ -230,6 +231,29 @@ async def test_handle_data_before_acceptance(stream: WSStream) -> None: ] +@pytest.mark.asyncio +async def 
test_handle_data_before_request_sends_error_without_crashing() -> None: + stream = WSStream( + AsyncMock(), Config(), WorkerContext(None), AsyncMock(), False, None, None, AsyncMock(), 1 + ) + stream.config._log = AsyncMock(spec=Logger) + + await stream.handle(Data(stream_id=1, data=b"X")) + + assert stream.closed + assert stream.send.call_args_list == [ # type: ignore[attr-defined] + call( + Response( + stream_id=1, + headers=[(b"content-length", b"0"), (b"connection", b"close")], + status_code=400, + ) + ), + call(EndBody(stream_id=1)), + ] + stream.config._log.access.assert_not_called() + + @pytest.mark.asyncio async def test_handle_connection(stream: WSStream) -> None: await stream.handle( @@ -251,6 +275,31 @@ async def test_handle_connection(stream: WSStream) -> None: ] +@pytest.mark.asyncio +async def test_handle_remote_close_sends_end_data_before_stream_closed(stream: WSStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[(b"sec-websocket-version", b"13")], + raw_path=b"/", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) + stream.send.reset_mock() # type: ignore[attr-defined] + + client = wsproto.connection.Connection(wsproto.ConnectionType.CLIENT) + await stream.handle(Data(stream_id=1, data=client.send(wsproto.events.CloseConnection(code=1000)))) + + assert stream.send.call_args_list == [ # type: ignore + call(Data(stream_id=1, data=b"\x88\x02\x03\xe8")), + call(EndData(stream_id=1)), + call(StreamClosed(stream_id=1)), + ] + + @pytest.mark.asyncio async def test_handle_closed(stream: WSStream) -> None: await stream.handle(StreamClosed(stream_id=1)) @@ -339,6 +388,32 @@ async def test_send_reject(stream: WSStream) -> None: stream.config._log.access.assert_called() +@pytest.mark.asyncio +async def test_send_reject_reuses_bytes_body(stream: WSStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + 
headers=[(b"sec-websocket-version", b"13")], + raw_path=b"/", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send( + cast( + WebsocketResponseStartEvent, + {"type": "websocket.http.response.start", "status": 200, "headers": []}, + ), + ) + body = b"Body" + await stream.app_send( + cast(WebsocketResponseBodyEvent, {"type": "websocket.http.response.body", "body": body}) + ) + + assert stream.send.call_args_list[1].args[0].data is body # type: ignore[attr-defined] + + @pytest.mark.asyncio async def test_invalid_server_name(stream: WSStream) -> None: stream.config.server_names = ["hypercorn"] @@ -441,6 +516,29 @@ async def test_send_connection(stream: WSStream) -> None: ] +@pytest.mark.asyncio +async def test_send_connection_reuses_bytes_payload(stream: WSStream) -> None: + await stream.handle( + Request( + stream_id=1, + http_version="2", + headers=[(b"sec-websocket-version", b"13")], + raw_path=b"/", + method="GET", + state=ConnectionState({}), + ) + ) + await stream.app_send(cast(WebsocketAcceptEvent, {"type": "websocket.accept"})) + stream._send_wsproto_event = AsyncMock() # type: ignore[method-assign] + payload = b"abc" + + await stream.app_send(cast(WebsocketSendEvent, {"type": "websocket.send", "bytes": payload})) + + event = stream._send_wsproto_event.await_args.args[0] # type: ignore[attr-defined] + assert isinstance(event, BytesMessage) + assert event.data is payload + + @pytest.mark.asyncio async def test_pings(stream: WSStream) -> None: event_loop: asyncio.AbstractEventLoop = asyncio.get_running_loop() diff --git a/tests/test___main__.py b/tests/test___main__.py index f24b3462..362e73a3 100644 --- a/tests/test___main__.py +++ b/tests/test___main__.py @@ -47,6 +47,7 @@ def test_load_config(monkeypatch: MonkeyPatch) -> None: ("--worker-class", "trio", "worker_class"), ("--keep-alive", 20, "keep_alive_timeout"), ("--keyfile", "/path", "keyfile"), + ("--max-requests-jitter", 7, "max_requests_jitter"), ("--pid", "/path", "pid_path"), 
("--root-path", "/path", "root_path"), ("--workers", 2, "workers"), diff --git a/tests/test_app_wrappers.py b/tests/test_app_wrappers.py index ca96a7d1..a0ffd33d 100644 --- a/tests/test_app_wrappers.py +++ b/tests/test_app_wrappers.py @@ -10,22 +10,22 @@ from hypercorn.app_wrappers import _build_environ, InvalidPathError, WSGIWrapper from hypercorn.typing import ASGIReceiveEvent, ASGISendEvent, ConnectionState, HTTPScope - - -def echo_body(environ: dict, start_response: Callable) -> list[bytes]: - status = "200 OK" - output = environ["wsgi.input"].read() - headers = [ - ("Content-Type", "text/plain; charset=utf-8"), - ("Content-Length", str(len(output))), - ] - start_response(status, headers) - return [output] +from .wsgi_applications import ( + wsgi_app_echo_body, + wsgi_app_generator, + wsgi_app_generator_delayed_start_response, + wsgi_app_generator_multiple_start_response_after_body, + wsgi_app_generator_no_body, + wsgi_app_multiple_start_response_no_exc_info, + wsgi_app_no_body, + wsgi_app_no_start_response, + wsgi_app_simple, +) @pytest.mark.trio async def test_wsgi_trio() -> None: - app = WSGIWrapper(echo_body, 2**16) + app = WSGIWrapper(wsgi_app_echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -52,12 +52,12 @@ async def _send(message: ASGISendEvent) -> None: await app(scope, receive_channel.receive, _send, trio.to_thread.run_sync, trio.from_thread.run) assert messages == [ + {"body": bytearray(b""), "type": "http.response.body", "more_body": True}, { "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], "status": 200, "type": "http.response.start", }, - {"body": bytearray(b""), "type": "http.response.body", "more_body": True}, {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, ] @@ -83,7 +83,7 @@ def _call_soon(func: Callable, *args: Any) -> Any: @pytest.mark.asyncio async def test_wsgi_asyncio() -> None: - app = WSGIWrapper(echo_body, 2**16) + app = 
WSGIWrapper(wsgi_app_echo_body, 2**16) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -100,21 +100,24 @@ async def test_wsgi_asyncio() -> None: "extensions": {}, "state": ConnectionState({}), } - messages = await _run_app(app, scope) + messages = await _run_app(app, scope, b"Hello, world!") assert messages == [ { - "headers": [(b"content-type", b"text/plain; charset=utf-8"), (b"content-length", b"0")], + "headers": [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"content-length", b"13"), + ], "status": 200, "type": "http.response.start", }, - {"body": bytearray(b""), "type": "http.response.body", "more_body": True}, - {"body": bytearray(b""), "type": "http.response.body", "more_body": False}, + {"body": b"Hello, world!", "type": "http.response.body", "more_body": True}, + {"body": b"", "type": "http.response.body", "more_body": False}, ] @pytest.mark.asyncio async def test_max_body_size() -> None: - app = WSGIWrapper(echo_body, 4) + app = WSGIWrapper(wsgi_app_echo_body, 4) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -138,13 +141,9 @@ async def test_max_body_size() -> None: ] -def no_start_response(environ: dict, start_response: Callable) -> list[bytes]: - return [b"result"] - - @pytest.mark.asyncio async def test_no_start_response() -> None: - app = WSGIWrapper(no_start_response, 2**16) + app = WSGIWrapper(wsgi_app_no_start_response, 2**16) scope: HTTPScope = { "http_version": "1.1", "asgi": {}, @@ -206,3 +205,151 @@ def test_build_environ_root_path() -> None: } with pytest.raises(InvalidPathError): _build_environ(scope, b"") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("wsgi_app", [wsgi_app_simple, wsgi_app_generator]) +async def test_wsgi_protocol(wsgi_app: Callable) -> None: + app = WSGIWrapper(wsgi_app, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": 
"http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + + messages = await _run_app(app, scope) + assert messages == [ + { + "headers": [(b"x-test-header", b"Test-Value")], + "status": 200, + "type": "http.response.start", + }, + {"body": b"Hello, ", "type": "http.response.body", "more_body": True}, + {"body": b"world!", "type": "http.response.body", "more_body": True}, + {"body": b"", "type": "http.response.body", "more_body": False}, + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("wsgi_app", [wsgi_app_no_body, wsgi_app_generator_no_body]) +async def test_wsgi_protocol_no_body(wsgi_app: Callable) -> None: + app = WSGIWrapper(wsgi_app, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + + messages = await _run_app(app, scope) + assert messages == [ + { + "headers": [(b"x-test-header", b"Test-Value")], + "status": 200, + "type": "http.response.start", + }, + {"body": b"", "type": "http.response.body", "more_body": False}, + ] + + +@pytest.mark.asyncio +async def test_wsgi_protocol_overwrite_start_response() -> None: + app = WSGIWrapper(wsgi_app_generator_delayed_start_response, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + + messages = await _run_app(app, scope) + assert messages == [ + {"body": b"", "type": "http.response.body", "more_body": True}, + { + "headers": [(b"x-test-header", b"New-Value")], + "status": 500, + "type": 
"http.response.start", + }, + {"body": b"Hello, ", "type": "http.response.body", "more_body": True}, + {"body": b"world!", "type": "http.response.body", "more_body": True}, + {"body": b"", "type": "http.response.body", "more_body": False}, + ] + + +@pytest.mark.asyncio +async def test_wsgi_protocol_multiple_start_response_no_exc_info() -> None: + app = WSGIWrapper(wsgi_app_multiple_start_response_no_exc_info, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + + with pytest.raises(RuntimeError): + await _run_app(app, scope) + + +@pytest.mark.asyncio +async def test_wsgi_protocol_multiple_start_response_after_body() -> None: + app = WSGIWrapper(wsgi_app_generator_multiple_start_response_after_body, 2**16) + scope: HTTPScope = { + "http_version": "1.1", + "asgi": {}, + "method": "GET", + "headers": [], + "path": "/", + "root_path": "/", + "query_string": b"a=b", + "raw_path": b"/", + "scheme": "http", + "type": "http", + "client": ("localhost", 80), + "server": None, + "extensions": {}, + "state": ConnectionState({}), + } + + with pytest.raises(ValueError): + await _run_app(app, scope) diff --git a/tests/test_config.py b/tests/test_config.py index 83fa2db6..1dacc335 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,7 +4,7 @@ import socket import ssl import sys -from unittest.mock import Mock, NonCallableMock +from unittest.mock import Mock, NonCallableMock, call import pytest from _pytest.monkeypatch import MonkeyPatch @@ -150,5 +150,23 @@ def test_response_headers(monkeypatch: MonkeyPatch) -> None: (b"date", b"Sat, 02 Dec 2017 15:43:15 GMT"), (b"server", b"hypercorn-test"), ] + assert config.response_headers("other") == [ + (b"date", b"Sat, 02 Dec 2017 15:43:15 GMT"), + 
(b"server", b"hypercorn-other"), + ] config.include_server_header = False assert config.response_headers("test") == [(b"date", b"Sat, 02 Dec 2017 15:43:15 GMT")] + + +def test_response_headers_cache_date_per_second(monkeypatch: MonkeyPatch) -> None: + timestamps = iter([1_512_229_395.1, 1_512_229_395.9, 1_512_229_396.0]) + formatter = Mock(side_effect=["Sat, 02 Dec 2017 15:43:15 GMT", "Sat, 02 Dec 2017 15:43:16 GMT"]) + monkeypatch.setattr(hypercorn.config, "time", lambda: next(timestamps)) + monkeypatch.setattr(hypercorn.config, "format_date_time", formatter) + + config = Config() + + assert config.response_headers("test")[0] == (b"date", b"Sat, 02 Dec 2017 15:43:15 GMT") + assert config.response_headers("test")[0] == (b"date", b"Sat, 02 Dec 2017 15:43:15 GMT") + assert config.response_headers("test")[0] == (b"date", b"Sat, 02 Dec 2017 15:43:16 GMT") + assert formatter.call_args_list == [call(1_512_229_395), call(1_512_229_396)] diff --git a/tests/test_import_path.py b/tests/test_import_path.py new file mode 100644 index 00000000..7e9c738d --- /dev/null +++ b/tests/test_import_path.py @@ -0,0 +1,11 @@ +from __future__ import annotations + +from pathlib import Path + +import hypercorn + + +def test_imports_use_local_src_tree() -> None: + package_path = Path(hypercorn.__file__).resolve() + assert package_path.parents[2] == Path(__file__).resolve().parent.parent + assert package_path.parts[-3:-1] == ("src", "hypercorn") diff --git a/tests/test_logging.py b/tests/test_logging.py index bb6f32b8..1b5903bb 100644 --- a/tests/test_logging.py +++ b/tests/test_logging.py @@ -2,6 +2,7 @@ import logging import os +import re import time import pytest @@ -42,6 +43,8 @@ def test_access_logger_init( else: assert isinstance(logger.access_logger.handlers[0], expected_handler_type) + assert logger.access_log_atoms == frozenset(re.findall(r"%\(([^)]+)\)s", logger.access_log_format)) + @pytest.mark.parametrize( "level, expected", @@ -110,6 +113,14 @@ def 
test_access_log_environ_atoms(http_scope: HTTPScope, response: ResponseSumma assert atoms["{random}e"] == "Environ" +def test_access_log_required_atoms_precompute(http_scope: HTTPScope, response: ResponseSummary) -> None: + os.environ["Random"] = "Environ" + atoms = AccessLogAtoms(http_scope, response, 0, frozenset({"h", "{random}e"})) + assert atoms["h"] == "127.0.0.1:80" + assert atoms["{random}e"] == "Environ" + assert atoms["{x-hypercorn}i"] == "Hypercorn" + + def test_nonstandard_status_code(http_scope: HTTPScope) -> None: atoms = AccessLogAtoms(http_scope, {"status": 441, "headers": []}, 0) assert atoms["st"] == "" diff --git a/tests/test_utils.py b/tests/test_utils.py index f9365012..520d7a56 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -5,11 +5,14 @@ import pytest +from hypercorn.app_wrappers import ASGIWrapper from hypercorn.typing import Scope from hypercorn.utils import ( + NoAppError, build_and_validate_headers, filter_pseudo_headers, is_asgi, + load_application, suppress_body, ) @@ -80,3 +83,13 @@ def test_filter_pseudo_headers_no_authority() -> None: [(b"host", b"quart"), (b":path", b"/"), (b"user-agent", b"something")] ) assert result == [(b"host", b"quart"), (b"user-agent", b"something")] + + +def test_load_application_resolves_nested_attributes() -> None: + wrapper = load_application("tests.assets.load_apps:nested.app", 1024) + assert isinstance(wrapper, ASGIWrapper) + + +def test_load_application_missing_nested_attribute() -> None: + with pytest.raises(NoAppError): + load_application("tests.assets.load_apps:nested.missing", 1024) diff --git a/tests/trio/test_lifespan.py b/tests/trio/test_lifespan.py index 1dbc0086..4348a216 100644 --- a/tests/trio/test_lifespan.py +++ b/tests/trio/test_lifespan.py @@ -38,6 +38,17 @@ async def _lifespan_failure( break +async def _slow_shutdown( + scope: Scope, receive: ASGIReceiveCallable, send: ASGISendCallable +) -> None: + while True: + message = await receive() + if message["type"] == 
"lifespan.startup": + await send({"type": "lifespan.startup.complete"}) + elif message["type"] == "lifespan.shutdown": + await trio.sleep(0.02) + + @pytest.mark.trio async def test_startup_failure() -> None: lifespan = Lifespan(ASGIWrapper(_lifespan_failure), Config(), {}) @@ -47,3 +58,15 @@ async def test_startup_failure() -> None: await lifespan.wait_for_startup() except ExceptionGroup as error: assert error.subgroup(LifespanFailureError) is not None + + +@pytest.mark.trio +async def test_shutdown_timeout_error(nursery: trio._core._run.Nursery) -> None: + config = Config() + config.shutdown_timeout = 0.01 + lifespan = Lifespan(ASGIWrapper(_slow_shutdown), config, {}) + nursery.start_soon(lifespan.handle_lifespan) + await lifespan.wait_for_startup() + with pytest.raises(LifespanTimeoutError) as exc_info: + await lifespan.wait_for_shutdown() + assert str(exc_info.value).startswith("Timeout whilst awaiting shutdown") diff --git a/tests/trio/test_worker_context.py b/tests/trio/test_worker_context.py new file mode 100644 index 00000000..07ef4ecf --- /dev/null +++ b/tests/trio/test_worker_context.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +import pytest +import trio + +from hypercorn.trio.worker_context import EventWrapper + + +@pytest.mark.trio +async def test_event_wrapper_waiters_survive_clear() -> None: + event = EventWrapper() + waiter_released = trio.Event() + + async def waiter() -> None: + await event.wait() + waiter_released.set() + + async with trio.open_nursery() as nursery: + nursery.start_soon(waiter) + await trio.sleep(0) + await event.clear() + assert not waiter_released.is_set() + await event.set() + with trio.fail_after(1): + await waiter_released.wait() + nursery.cancel_scope.cancel() + + +@pytest.mark.trio +async def test_event_wrapper_clear_resets_is_set() -> None: + event = EventWrapper() + + await event.set() + assert event.is_set() + + await event.clear() + + assert not event.is_set() diff --git a/tests/wsgi_applications.py 
b/tests/wsgi_applications.py new file mode 100644 index 00000000..a030552b --- /dev/null +++ b/tests/wsgi_applications.py @@ -0,0 +1,193 @@ +import sys +from collections.abc import Callable +from typing import Generator + + +def wsgi_app_echo_body(environ: dict, start_response: Callable) -> list[bytes]: + """Simple WSGI application which returns the request body as the response body.""" + status = "200 OK" + output = environ["wsgi.input"].read() + headers = [ + ("Content-Type", "text/plain; charset=utf-8"), + ("Content-Length", str(len(output))), + ] + start_response(status, headers) + return [output] + + +def wsgi_app_no_start_response(environ: dict, start_response: Callable) -> list[bytes]: + """Invalid WSGI application which fails to call start_response""" + return [b"result"] + + +def wsgi_app_simple(environ: dict, start_response: Callable) -> list[bytes]: + """ + A basic WSGI Application. + + It is valid to send multiple chunks of data, but the status code and headers + must be sent before the first non-empty chunk of body data is sent. + + Therefore, the headers must be sent immediately after the first non-empty + byte string is returned, but before continuing to iterate further. While sending + the headers before beginning iteration would technically work in this case, + this violates the WSGI spec and further examples prove that this behavior + is actually invalid. + """ + start_response("200 OK", [("X-Test-Header", "Test-Value")]) + return [b"Hello, ", b"world!"] + + +def wsgi_app_generator(environ: dict, start_response: Callable) -> Generator[bytes, None, None]: + """ + A synchronous generator usable as a valid WSGI Application. + + Notably, the WSGI specification ensures only that start_response() is called + before the first item is returned from the iterator. It does not have to + be immediately called when app(environ, start_response) is called. 
+ + Using a generator for a WSGI app will delay calling start_response() until after + something begins iterating on it, so only invoking the app and not iterating on + the returned iterable will not be sufficient to get the status code and headers. + + It is also valid to send multiple chunks of data, but the status code and headers + must be sent before the first non-empty chunk of body data is sent. + + Therefore it is not valid to send the status code and headers before iterating on + the returned generator. It is only valid to send status code and headers during + iteration of the generator, immediately after the first non-empty byte + string is returned, but before continuing to iterate further. + """ + start_response("200 OK", [("X-Test-Header", "Test-Value")]) + yield b"Hello, " + yield b"world!" + + +def wsgi_app_no_body(environ: dict, start_response: Callable) -> list[bytes]: + """ + A WSGI Application that does not yield up any body chunks when iterated on. + + This is most common when supporting HTTP methods such as HEAD, which is identical + to GET except that the server MUST NOT return a message body in the response. + + The iterable returned by this app will have no contents, immediately exiting + any for loops attempting to iterate on it. Even though no body was returned + from the application, this is still a valid HTTP request and MUST send the + status code and headers as the response. Failing to do so violates the + WSGI, ASGI, and HTTP specifications. + + Therefore, the status code and headers must be sent after the iteration completes, + as it is not valid to send them only during iteration. If headers are only sent + within the body of the for loop, this application will cause the server to fail + to send this information at all. 
However, care must be taken to check + whether the status code and headers were already sent during the iteration process, + as they may have been sent during the iteration process for applications with + non-empty bodies. If this isn't accounted for they will be sent twice in error. + """ + start_response("200 OK", [("X-Test-Header", "Test-Value")]) + return [] + + +def wsgi_app_generator_no_body( + environ: dict, start_response: Callable +) -> Generator[bytes, None, None]: + """ + A synchronous generator usable as a valid WSGI Application, which + does not yield up any body chunks when iterated on. + + This is a very complicated edge case. It is most commonly found when building a + generator based WSGI app with support for HTTP methods such as HEAD, which is + identical to GET except that the server MUST NOT return a message body in the response. + + 1. The application is subject to the same delay in calling start_response until + after the server has begun iterating on the returned generator object. + + 2. The status code and headers are also not available during iteration, as the + empty generator will immediately end any for loops that attempt to iterate on it. + + 3. Even though no body was returned from the application, this is still a valid + HTTP request and MUST send the status code and headers as the response. Failing + to do so violates the WSGI, ASGI, and HTTP specifications. + + Therefore, the status code and headers must be sent after the iteration completes, + as it is not valid to send them only during iteration. If headers are only sent + within the body of the for loop, this application will cause the server to fail + to send this information at all. However, care must be taken to check + whether the status code and headers were already sent during the iteration process, + as they may have been sent during the iteration process for applications with + non-empty bodies. If this isn't accounted for they will be sent twice in error. 
+ """ + start_response("200 OK", [("X-Test-Header", "Test-Value")]) + if False: + yield b"" # Unreachable yield makes this an empty generator # noqa + + +def wsgi_app_generator_delayed_start_response( + environ: dict, start_response: Callable +) -> Generator[bytes, None, None]: + """ + A synchronous generator usable as a valid WSGI Application, which calls start_response + a second time after yielding up empty chunks of body. + + This application exercises the ability for WSGI apps to change their status code + right up until the last possible second before the first non-empty chunk of body is + sent. The status code and headers must be buffered until the first non-empty chunk of body + is yielded by this generator, and should be overwritable until that time. + """ + # Initial 200 OK status that will be overwritten before any non-empty chunks of body are sent + start_response("200 OK", [("X-Test-Header", "Old-Value")]) + yield b"" + + try: + raise ValueError + except ValueError: + # start_response may be called more than once before the first non-empty byte string + # is yielded by this generator. However, it is a fatal error to call start_response() + # a second time without passing an exception tuple in via the exc_info argument. + start_response( + "500 Internal Server Error", [("X-Test-Header", "New-Value")], exc_info=sys.exc_info() + ) + + yield b"Hello, " + yield b"world!" + + +def wsgi_app_multiple_start_response_no_exc_info( + environ: dict, start_response: Callable +) -> list[bytes]: + """ + An invalid WSGI Application, which calls start_response a second time + without passing an exception tuple in via the exc_info argument. + + This is considered a fatal error in the WSGI specification and should raise an exception. 
+ """ + + # Calling start_response multiple times without exc_info should raise an error + start_response("200 OK", []) + start_response("202 Accepted", []) + return [] + + +def wsgi_app_generator_multiple_start_response_after_body( + environ: dict, start_response: Callable +) -> Generator[bytes, None, None]: + """ + An invalid WSGI Application, which calls start_response a second time + after the first non-empty byte string is returned. This should reraise the exception + as the headers and status code have already been sent. + + This is considered a fatal error in the WSGI specification and should raise an exception. + """ + + # Initial valid call; the fatal second start_response call happens after the body below + start_response("200 OK", []) + yield b"Hello, world!" + + try: + raise ValueError + except ValueError: + # start_response may not be called again after the first non-empty byte string is returned + # + # It is a fatal error to call start_response() a second time without passing an exception + # tuple in via the exc_info argument, so ensure we do that to avoid raising the wrong + # exception. + start_response("500 Internal Server Error", [], exc_info=sys.exc_info())