Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions aws_lambda_powertools/shared/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,6 @@

# Idempotency constants
IDEMPOTENCY_DISABLED_ENV: str = "POWERTOOLS_IDEMPOTENCY_DISABLED"

# Circuit breaker constants
CIRCUIT_BREAKER_DISABLED_ENV: str = "POWERTOOLS_CIRCUIT_BREAKER_DISABLED"
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
"""
Circuit Breaker utility for protecting unhealthy downstream dependencies.

!!! warning "Alpha / experimental"
This utility is published under the `_alpha` namespace while we collect
feedback. The public API may change in a backwards-incompatible way before it
is promoted to GA. Pin your version and follow the tracking discussion before
relying on it in production.
"""

from aws_lambda_powertools.utilities.circuit_breaker_alpha.circuit_breaker import circuit_breaker
from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import (
CircuitBreakerConfigError,
CircuitBreakerError,
CircuitBreakerOpenError,
CircuitBreakerPersistenceError,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import (
CircuitInfo,
CircuitState,
CircuitTransition,
)

__all__ = (
"circuit_breaker",
"CircuitBreakerConfig",
"CircuitInfo",
"CircuitState",
"CircuitTransition",
"CircuitBreakerError",
"CircuitBreakerOpenError",
"CircuitBreakerConfigError",
"CircuitBreakerPersistenceError",
)
203 changes: 203 additions & 0 deletions aws_lambda_powertools/utilities/circuit_breaker_alpha/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""
Orchestrator for the Circuit Breaker utility.

:class:`CircuitBreakerHandler` owns the state machine and the per-environment failure
counter; the persistence layer owns the shared truth. This split keeps the healthy
path write-free: failures are counted locally and only persisted on a state transition.
"""

from __future__ import annotations

import datetime
import logging
import uuid
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.utilities.circuit_breaker_alpha.exceptions import CircuitBreakerOpenError
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitState, CircuitTransition

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.base import (
CircuitBreakerPersistenceLayer,
)
from aws_lambda_powertools.utilities.circuit_breaker_alpha.states import CircuitInfo

logger = logging.getLogger(__name__)

# Per-environment, per-circuit consecutive counters. Module-level so they survive across
# invocations within the same execution environment, the same way idempotency caches do.
_LOCAL_FAILURES: dict[str, int] = {}
_LOCAL_SUCCESSES: dict[str, int] = {}

# Stable per-environment identifier used to claim the half-open probe lock.
_ENVIRONMENT_ID = uuid.uuid4().hex


class CircuitBreakerHandler:
"""
Drive a single protected call through the circuit breaker state machine.

A new handler is created per invocation by the decorator. It reads the shared state,
routes the call (run, short-circuit, or probe), and records the outcome.

Parameters
----------
function : Callable
The protected function.
name : str
Circuit name.
config : CircuitBreakerConfig
Circuit configuration.
persistence_store : CircuitBreakerPersistenceLayer
Shared state store.
on_circuit_open : Callable | None
Callback invoked with the protected call's own ``*args``/``**kwargs`` plus a
trailing ``circuit`` keyword argument when the circuit is open. If ``None``, an
open circuit raises :class:`CircuitBreakerOpenError`.
function_args : tuple
Positional arguments the protected function was called with.
function_kwargs : dict
Keyword arguments the protected function was called with.
"""

def __init__(
self,
function: Callable,
name: str,
config: CircuitBreakerConfig,
persistence_store: CircuitBreakerPersistenceLayer,
on_circuit_open: Callable | None = None,
on_transition: Callable | None = None,
function_args: tuple | None = None,
function_kwargs: dict | None = None,
):
self.function = function
self.name = name
self.config = config
self.on_circuit_open = on_circuit_open
self.on_transition = on_transition
self.fn_args = function_args or ()
self.fn_kwargs = function_kwargs or {}

persistence_store.configure(config=config, circuit_name=name)
self.persistence_store = persistence_store

def handle(self) -> Any:
"""
Evaluate the circuit and route the call.

Returns
-------
Any
The protected function's result when the call runs, or the
``on_circuit_open`` callback's return value when the circuit is open.

Raises
------
CircuitBreakerOpenError
If the circuit is open and no callback is registered.
"""
record = self.persistence_store.get_state(self.name)

if record.state == CircuitState.CLOSED:
return self._call_closed()

if record.state == CircuitState.OPEN:
# ``opened_at`` may legitimately be 0 (epoch); treat only None as missing.
opened_at = record.opened_at if record.opened_at is not None else self._now()
if self._now() >= opened_at + self.config.recovery_timeout:
# Recovery window elapsed: try to become the single prober.
if self.persistence_store.try_acquire_half_open(self.name, _ENVIRONMENT_ID, opened_at):
self._notify(CircuitState.OPEN, CircuitState.HALF_OPEN, opened_at=opened_at)
return self._call_probe()
return self._open_response(record.to_circuit_info())

# HALF_OPEN: only the environment that owns the probe lock runs.
if record.half_open_owner == _ENVIRONMENT_ID:
return self._call_probe()
return self._open_response(record.to_circuit_info())

def _call_closed(self) -> Any:
"""Run the protected call while the circuit is closed, tracking failures."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
failures = _LOCAL_FAILURES.get(self.name, 0) + 1
_LOCAL_FAILURES[self.name] = failures
if failures >= self.config.failure_threshold:
logger.debug("Circuit '%s' tripping CLOSED to OPEN after %d failures.", self.name, failures)
opened_at = self._now()
self.persistence_store.save_open(self.name, failure_count=failures, opened_at=opened_at)
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.CLOSED, CircuitState.OPEN, opened_at=opened_at)
raise
else:
_LOCAL_FAILURES[self.name] = 0
return result

def _call_probe(self) -> Any:
"""Run a probe during half-open, closing or reopening based on the outcome."""
try:
result = self.function(*self.fn_args, **self.fn_kwargs)
except Exception as exc:
if not self.config.counts_as_failure(exc):
raise
logger.debug("Circuit '%s' probe failed; reopening.", self.name)
opened_at = self._now()
self.persistence_store.save_reopen(self.name, opened_at=opened_at)
_LOCAL_SUCCESSES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.OPEN, opened_at=opened_at)
raise
else:
successes = _LOCAL_SUCCESSES.get(self.name, 0) + 1
_LOCAL_SUCCESSES[self.name] = successes
if successes >= self.config.success_threshold:
logger.debug("Circuit '%s' closing after %d probe successes.", self.name, successes)
self.persistence_store.save_closed(self.name)
_LOCAL_SUCCESSES[self.name] = 0
_LOCAL_FAILURES[self.name] = 0
self._notify(CircuitState.HALF_OPEN, CircuitState.CLOSED)
return result

def _open_response(self, circuit: CircuitInfo) -> Any:
"""Produce the response for an open circuit: callback result or raise."""
if self.on_circuit_open is not None:
# Forward the protected call's arguments unchanged: positional stay positional,
# keyword stay keyword. The circuit snapshot is passed as a keyword argument so
# it never collides with positionalized kwargs nor depends on dict ordering.
return self.on_circuit_open(*self.fn_args, **self.fn_kwargs, circuit=circuit)
raise CircuitBreakerOpenError(
f"Circuit '{self.name}' is open.",
circuit=circuit,
)

def _notify(self, from_state: CircuitState, to_state: CircuitState, opened_at: int | None = None) -> None:
"""
Fire the ``on_transition`` hook for a state change.

Called only on real transitions, never on the hot path. Any exception the hook
raises is swallowed and logged: observability must never break the protected call.
"""
if self.on_transition is None:
return
try:
self.on_transition(
CircuitTransition(
circuit_name=self.name,
from_state=from_state,
to_state=to_state,
opened_at=opened_at,
),
)
except Exception:
logger.warning("on_transition hook for circuit '%s' raised; ignoring.", self.name, exc_info=True)

@staticmethod
def _now() -> int:
"""Current unix timestamp in seconds."""
return int(datetime.datetime.now().timestamp())
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""
Primary interface for the Circuit Breaker utility.
"""

from __future__ import annotations

import functools
import logging
import os
import warnings
from typing import TYPE_CHECKING, Any

from aws_lambda_powertools.shared import constants
from aws_lambda_powertools.shared.functions import strtobool
from aws_lambda_powertools.utilities.circuit_breaker_alpha.base import CircuitBreakerHandler
from aws_lambda_powertools.utilities.circuit_breaker_alpha.config import CircuitBreakerConfig
from aws_lambda_powertools.warnings import PowertoolsUserWarning

if TYPE_CHECKING:
from collections.abc import Callable

from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence.base import (
CircuitBreakerPersistenceLayer,
)

logger = logging.getLogger(__name__)


def circuit_breaker(
name: str,
persistence_store: CircuitBreakerPersistenceLayer,
on_circuit_open: Callable | None = None,
on_transition: Callable | None = None,
config: CircuitBreakerConfig | None = None,
) -> Callable:
"""
Protect a function that calls an unhealthy-prone downstream with a circuit breaker.

Wrap the function that makes the downstream call, not the whole Lambda handler, so a
tripped circuit reflects one dependency rather than unrelated handler logic.

When the circuit is open the protected function is not called. Instead, if an
``on_circuit_open`` callback is registered it runs and its return value becomes the
call's result; otherwise :class:`CircuitBreakerOpenError` is raised.

Parameters
----------
name : str
Unique circuit name. Each name is an independent circuit; a function calling
several backends should use one circuit per backend.
persistence_store : CircuitBreakerPersistenceLayer
Shared state store (for example ``CircuitBreakerDynamoDBPersistence``).
on_circuit_open : Callable | None
Called when the circuit is open, with the protected function's own arguments
(positional stay positional, keyword stay keyword) plus a trailing ``circuit``
keyword argument carrying a ``CircuitInfo``. Its return value becomes the call's
result. If ``None``, an open circuit raises ``CircuitBreakerOpenError``.
on_transition : Callable | None
Called with a single ``CircuitTransition`` argument whenever the circuit changes
state (open, probe, close, reopen). Fires only on transitions, never on the
per-invocation hot path, so it is a safe place to emit a CloudWatch metric. Any
exception it raises is swallowed and logged so observability never breaks the
protected call.
config : CircuitBreakerConfig | None
Tunables. Defaults to ``CircuitBreakerConfig()`` when omitted.

Returns
-------
Callable
The decorated function.

Example
-------
**Protect a payment backend, buffering rejected requests**

from aws_lambda_powertools.utilities.circuit_breaker_alpha import circuit_breaker, CircuitInfo
from aws_lambda_powertools.utilities.circuit_breaker_alpha.persistence import (
CircuitBreakerDynamoDBPersistence,
)

persistence = CircuitBreakerDynamoDBPersistence(table_name="CircuitBreakerState")

def buffer(order: dict, circuit: CircuitInfo):
sqs.send_message(QueueUrl=url, MessageBody=json.dumps(order))

@circuit_breaker(name="payment-backend", persistence_store=persistence, on_circuit_open=buffer)
def charge(order: dict) -> dict:
return payment_api.charge(order)
"""
config = config or CircuitBreakerConfig()

def decorator(function: Callable) -> Callable:
@functools.wraps(function)
def wrapper(*args, **kwargs) -> Any:
# Skip the circuit entirely when disabled (development only).
if strtobool(os.getenv(constants.CIRCUIT_BREAKER_DISABLED_ENV, "false")):
warnings.warn(
message="Disabling the circuit breaker is intended for development environments only "
"and should not be used in production.",
category=PowertoolsUserWarning,
stacklevel=2,
)
return function(*args, **kwargs)

handler = CircuitBreakerHandler(
function=function,
name=name,
config=config,
persistence_store=persistence_store,
on_circuit_open=on_circuit_open,
on_transition=on_transition,
function_args=args,
function_kwargs=kwargs,
)
return handler.handle()

return wrapper

return decorator
Loading
Loading