Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_lambda_powertools/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,6 @@

# Idempotency constants
IDEMPOTENCY_DISABLED_ENV: str = "POWERTOOLS_IDEMPOTENCY_DISABLED"

# Circuit breaker constants
CIRCUIT_BREAKER_DISABLED_ENV: str = "POWERTOOLS_CIRCUIT_BREAKER_DISABLED"
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Circuit Breaker utility for protecting unhealthy downstream dependencies.

!!! warning "Alpha / experimental"
This utility is published under the `_alpha` namespace while we collect
feedback. The public API may change in a backwards-incompatible way before it
is promoted to GA. Pin your version and follow the tracking discussion before
relying on it in production.
"""

from aws_lambda_powertools.utilities.circuit_breaker_alpha.circuit_breaker import circuit_breaker
from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import (
CircuitBreakerConfigError,
CircuitBreakerError,
CircuitBreakerOpenError,
CircuitBreakerPersistenceError,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import (
CircuitInfo,
CircuitState,
CircuitTransition,
)

__all__ = (
"circuit_breaker",
"CircuitBreakerConfig",
"CircuitInfo",
"CircuitState",
"CircuitTransition",
"CircuitBreakerError",
"CircuitBreakerOpenError",
"CircuitBreakerConfigError",
"CircuitBreakerPersistenceError",
)
203 changes: 203 additions & 0 deletions aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""
Orchestrator for the Circuit Breaker utility.

:class:`CircuitBreakerHandler` owns the state machine and the per-environment failure
counter; the persistence layer owns the shared truth. This split keeps the healthy
path write-free: failures are counted locally and only persisted on a state transition.
"""

from __future__ import annotations

import datetime
import logging
import uuid
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import CircuitBreakerOpenError
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitState, CircuitTransition

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.base import (
CircuitBreakerPersistenceLayer,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitInfo

logger = logging.getLogger(__name__)

# Per-environment, per-circuit consecutive counters. Module-level so they survive across
# invocations within the same execution environment, the same way idempotency caches do.
_LOCAL_FAILURES: dict[str, int] = {}
_LOCAL_SUCCESSES: dict[str, int] = {}

# Stable per-environment identifier used to claim the half-open probe lock.
_ENVIRONMENT_ID = uuid.uuid4().hex


class CircuitBreakerHandler:
"""
Drive a single protected call through the circuit breaker state machine.

A new handler is created per invocation by the decorator. It reads the shared state,
routes the call (run, short-circuit, or probe), and records the outcome.

Parameters
----------
function : Callable
The protected function.
name : str
Circuit name.
config : CircuitBreakerConfig
Circuit configuration.
persistence_store : CircuitBreakerPersistenceLayer
Shared state store.
on_circuit_open : Callable | None
Callback invoked with the protected call's own ``*args``/``**kwargs`` plus a
trailing ``circuit`` keyword argument when the circuit is open. If ``None``, an
open circuit raises :class:`CircuitBreakerOpenError`.
function_args : tuple
Positional arguments the protected function was called with.
function_kwargs : dict
Keyword arguments the protected function was called with.
"""

def __init__(
self,
function: Callable,
name: str,
config: CircuitBreakerConfig,
persistence_store: CircuitBreakerPersistenceLayer,
on_circuit_open: Callable | None = None,
on_transition: Callable | None = None,
function_args: tuple | None = None,
function_kwargs: dict | None = None,
):
self.function = function
self.name = name
self.config = config
self.on_circuit_open = on_circuit_open
self.on_transition = on_transition
self.fn_args = function_args or ()
self.fn_kwargs = function_kwargs or {}

persistence_store.configure(config=config, circuit_name=name)
self.persistence_store = persistence_store

def handle(self) -> Any:
"""
Evaluate the circuit and route the call.

Returns
-------
Any
The protected function's result when the call runs, or the
``on_circuit_open`` callback's return value when the circuit is open.

Raises
------
CircuitBreakerOpenError
If the circuit is open and no callback is registered.
"""
record = self.persistence_store.get_state(self.name)

if record.state == CircuitState.CLOSED:
return self._call_closed()

if record.state == CircuitState.OPEN:
# ``opened_at`` may legitimately be 0 (epoch); treat only None as missing.
opened_at = record.opened_at if record.opened_at is not None else self._now()
if self._now() >= opened_at + self.config.recovery_timeout:
# Recovery window elapsed: try to become the single prober.
if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, opened_at):
self._notify(CircuitState.OPEN, CircuitState.HALF_OPEN, opened_at=opened_at)
return self._call_probe()
return self._open_response(record.to_circuit_info())

# HALF_OPEN: only the environment that owns the probe lock runs.
if record.half_open_owner == _ENVIRONMENT_ID:
return self._call_probe()
return self._open_response(record.to_circuit_info())

def _call_closed(self) -> Any:
"""Run the protected call while the circuit is closed, tracking failures."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
failures = _LOCAL_FAILURES.get(self.name, 0) + 1
_LOCAL_FAILURES[self.name] = failures
if failures >= self.config.failure_threshold:
logger.debug("Circuit '%s' tripping CLOSED to OPEN after %d failures.", self.name, failures)
opened_at = self._now()
self.persistence_store.save_open(self.name, failure_count=failures, opened_at=opened_at)
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.CLOSED, CircuitState.OPEN, opened_at=opened_at)
raise
else:
_LOCAL_FAILURES[self.name] = 0
return result

def _call_probe(self) -> Any:
"""Run a probe during half-open, closing or reopening based on the outcome."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
logger.debug("Circuit '%s' probe failed; reopening.", self.name)
opened_at = self._now()
self.persistence_store.save_reopen(self.name, opened_at=opened_at)
_LOCAL_SUCCESSES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.OPEN, opened_at=opened_at)
raise
else:
successes = _LOCAL_SUCCESSES.get(self.name, 0) + 1
_LOCAL_SUCCESSES[self.name] = successes
if successes >= self.config.success_threshold:
logger.debug("Circuit '%s' closing after %d probe successes.", self.name, successes)
self.persistence_store.save_closed(self.name)
_LOCAL_SUCCESSES[self.name] = 0
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.CLOSED)
return result

def _open_response(self, circuit: CircuitInfo) -> Any:
"""Produce the response for an open circuit: callback result or raise."""
if self.on_circuit_open is not None:
# Forward the protected call's arguments unchanged: positional stay positional,
# keyword stay keyword. The circuit snapshot is passed as a keyword argument so
# it never collides with positionalized kwargs nor depends on dict ordering.
return self.on_circuit_open(*self.fn_args, **self.fn_kwargs, circuit=circuit)
raise CircuitBreakerOpenError(
f"Circuit '{self.name}' is open.",
circuit=circuit,
)

def _notify(self, from_state: CircuitState, to_state: CircuitState, opened_at: int | None = None) -> None:
"""
Fire the ``on_transition`` hook for a state change.

Called only on real transitions, never on the hot path. Any exception the hook
raises is swallowed and logged: observability must never break the protected call.
"""
if self.on_transition is None:
return
try:
self.on_transition(
CircuitTransition(
circuit_name=self.name,
from_state=from_state,
to_state=to_state,
opened_at=opened_at,
),
)
except Exception:
logger.warning("on_transition hook for circuit '%s' raised; ignoring.", self.name, exc_info=True)

@staticmethod
def _now() -> int:
"""Current unix timestamp in seconds."""
return int(datetime.datetime.now().timestamp())
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Primary interface for the Circuit Breaker utility.
"""

from __future__ import annotations

import functools
import logging
import os
import warnings
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import strtobool
from aws_lambda_powertools.utilities.circuit_breaker_alpha.base import CircuitBreakerHandler
from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.warnings import PowertoolsUserWarning

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.base import (
CircuitBreakerPersistenceLayer,
)

logger = logging.getLogger(__name__)


def circuit_breaker(
name: str,
persistence_store: CircuitBreakerPersistenceLayer,
on_circuit_open: Callable | None = None,
on_transition: Callable | None = None,
config: CircuitBreakerConfig | None = None,
) -> Callable:
"""
Protect a function that calls an unhealthy-prone downstream with a circuit breaker.

Wrap the function that makes the downstream call, not the whole Lambda handler, so a
tripped circuit reflects one dependency rather than unrelated handler logic.

When the circuit is open the protected function is not called. Instead, if an
``on_circuit_open`` callback is registered it runs and its return value becomes the
call's result; otherwise :class:`CircuitBreakerOpenError` is raised.

Parameters
----------
name : str
Unique circuit name. Each name is an independent circuit; a function calling
several backends should use one circuit per backend.
persistence_store : CircuitBreakerPersistenceLayer
Shared state store (for example ``CircuitBreakerDynamoDBPersistence``).
on_circuit_open : Callable | None
Called when the circuit is open, with the protected function's own arguments
(positional stay positional, keyword stay keyword) plus a trailing ``circuit``
keyword argument carrying a ``CircuitInfo``. Its return value becomes the call's
result. If ``None``, an open circuit raises ``CircuitBreakerOpenError``.
on_transition : Callable | None
Called with a single ``CircuitTransition`` argument whenever the circuit changes
state (open, probe, close, reopen). Fires only on transitions, never on the
per-invocation hot path, so it is a safe place to emit a CloudWatch metric. Any
exception it raises is swallowed and logged so observability never breaks the
protected call.
config : CircuitBreakerConfig | None
Tunables. Defaults to ``CircuitBreakerConfig()`` when omitted.

Returns
-------
Callable
The decorated function.

Example
-------
**Protect a payment backend, buffering rejected requests**

from aws_lambda_powertools.utilities.circuit_breaker_alpha import circuit_breaker, CircuitInfo
from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence import (
CircuitBreakerDynamoDBPersistence,
)

persistence = CircuitBreakerDynamoDBPersistence(table_name="CircuitBreakerState")

def buffer(order: dict, circuit: CircuitInfo):
sqs.send_message(QueueUrl=url, MessageBody=json.dumps(order))

@circuit_breaker(name="payment-backend", persistence_store=persistence, on_circuit_open=buffer)
def charge(order: dict) -> dict:
return payment_api.charge(order)
"""
config = config or CircuitBreakerConfig()

def decorator(function: Callable) -> Callable:
@functools.wraps(function)
def wrapper(*args, **kwargs) -> Any:
# Skip the circuit entirely when disabled (development only).
if strtobool(os.getenv(constants.CIRCUIT_BREAKER_DISABLED_ENV, "false")):
warnings.warn(
message="Disabling the circuit breaker is intended for development environments only "
"and should not be used in production.",
category=PowertoolsUserWarning,
stacklevel=2,
)
return function(*args, **kwargs)

handler = CircuitBreakerHandler(
function=function,
name=name,
config=config,
persistence_store=persistence_store,
on_circuit_open=on_circuit_open,
on_transition=on_transition,
function_args=args,
function_kwargs=kwargs,
)
return handler.handle()

return wrapper

return decorator
Loading