Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 106 additions & 21 deletions src/benchmark_service/sandbox/daytona.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from __future__ import annotations

import asyncio
import logging
import shlex
import uuid
from collections.abc import AsyncGenerator, Awaitable, Callable, Mapping
from contextlib import suppress
from typing import Any, Literal
from typing import Any, Literal, ParamSpec, TypeVar, cast

from aiohttp import ClientConnectionError, ClientResponseError
from daytona import (
Expand Down Expand Up @@ -47,6 +48,11 @@
SnapshotSource,
)

# Generic parameter/return placeholders used to type the retry decorator
# factory so a decorated callable keeps its original signature.
_P = ParamSpec("_P")
_R = TypeVar("_R")

# Module-level logger; the retry hook emits its warnings through it.
logger = logging.getLogger(__name__)

_PTY_STATUS_CHECK_ATTEMPTS = 30
_STATUS_DIR = "/tmp/.sandbox-provider"
_REMOVED_SANDBOX_STATES = (SandboxState.DESTROYING, SandboxState.DESTROYED)
Expand All @@ -55,6 +61,8 @@
# Exception types that sandbox operations catch and translate through
# `_sandbox_error`.
_SANDBOX_OPERATION_ERRORS = (DaytonaError, ClientResponseError)
# Daytona SDK error types grouped as transient: connection, rate limit, timeout.
_TRANSIENT_DAYTONA_ERRORS = (DaytonaConnectionError, DaytonaRateLimitError, DaytonaTimeoutError)
# Rate-limit header name prefixes. The throttler name is the remainder of the
# header name after the prefix; header names are compared lower-cased.
_RETRY_AFTER_PREFIX = "retry-after-"
_RATE_LIMIT_REMAINING_PREFIX = "x-ratelimit-remaining-"
_RATE_LIMIT_RESET_PREFIX = "x-ratelimit-reset-"
# Throttler names that are reported verbatim; any other name is reported as
# "unknown".
_KNOWN_THROTTLERS = ("sandbox-create", "sandbox-lifecycle", "authenticated", "anonymous")
# NOTE(review): presumably lower-case fragments of error text that identify a
# delete conflict -- the code using them is not visible here, confirm.
_DELETE_CONFLICT_MESSAGES = ("state change in progress", "modified by another operation")
# NOTE(review): presumably HTTP statuses treated as "sandbox already removed"
# -- the code using them is not visible here, confirm.
_REMOVED_SANDBOX_CLIENT_STATUSES = (404, 502)
Expand Down Expand Up @@ -116,6 +124,62 @@ def _provider_retry_wait(retry_state: RetryCallState) -> float:
return _RATE_LIMIT_WAIT(retry_state)


def _provider_retry_callback(op: str) -> Callable[[RetryCallState], None]:
    """Build a ``before_sleep`` hook that logs each retry of operation ``op``.

    The returned hook always emits a ``daytona.retry_before_sleep`` warning
    describing the failed attempt. When the failure resolves to a rate-limit
    error it additionally emits ``daytona.rate_limit_retry`` with the throttler
    name and the matching remaining/reset header values.

    Args:
        op: Label identifying the operation being retried; copied into the
            ``op`` field of every log record.

    Returns:
        A callable accepting the retry state; it returns nothing.
    """

    def _hook(state: RetryCallState) -> None:
        outcome = state.outcome
        exc = outcome.exception() if outcome else None

        next_action = state.next_action
        sleep_seconds = next_action.sleep if next_action else None

        if exc:
            error_class = type(exc).__name__
            cause = exc.__cause__
            root_cause = _root_cause(exc)
        else:
            error_class = "unknown"
            cause = None
            root_cause = None

        retry_fields = {
            "op": op,
            "fn": state.fn.__name__ if state.fn else None,
            "attempt": state.attempt_number,
            "idle_for": state.idle_for,
            "sleep_seconds": sleep_seconds,
            "error_class": error_class,
            "cause_error_class": type(cause).__name__ if cause else None,
            "root_cause_error_class": type(root_cause).__name__ if root_cause else None,
        }
        logger.warning("daytona.retry_before_sleep", extra=retry_fields)

        # The second record is only relevant to rate-limit failures.
        if not exc:
            return
        rate_limit_error = _rate_limit_error(exc)
        if rate_limit_error is None:
            return

        throttler = _daytona_rate_limit_throttler(rate_limit_error)
        remaining = _daytona_rate_limit_header(rate_limit_error, _RATE_LIMIT_REMAINING_PREFIX, throttler)
        reset = _daytona_rate_limit_header(rate_limit_error, _RATE_LIMIT_RESET_PREFIX, throttler)
        rate_limit_fields = {
            "op": op,
            "throttler": throttler,
            "attempt": state.attempt_number,
            "sleep_seconds": sleep_seconds,
            "rate_limit_remaining": remaining,
            "rate_limit_reset": reset,
        }
        logger.warning("daytona.rate_limit_retry", extra=rate_limit_fields)

    return _hook


def _provider_retry(op: str) -> Callable[[Callable[_P, _R]], Callable[_P, _R]]:
    """Return the retry decorator used for provider operation ``op``.

    The policy retries only on ``SandboxConnectionError``, stops after three
    attempts, takes its delay from ``_provider_retry_wait``, logs before every
    sleep through ``_provider_retry_callback(op)``, and is configured with
    ``reraise=True``.

    Args:
        op: Operation label forwarded to the logging hook.

    Returns:
        A decorator, typed so the decorated callable keeps its signature.
    """
    decorator = retry(
        retry=retry_if_exception_type(SandboxConnectionError),
        stop=stop_after_attempt(3),
        wait=_provider_retry_wait,
        before_sleep=_provider_retry_callback(op),
        reraise=True,
    )
    # The cast only narrows the static type; it has no runtime effect.
    return cast(Callable[[Callable[_P, _R]], Callable[_P, _R]], decorator)


def _rate_limit_error(exc: BaseException) -> DaytonaRateLimitError | None:
if isinstance(exc, DaytonaRateLimitError):
return exc
Expand All @@ -124,6 +188,17 @@ def _rate_limit_error(exc: BaseException) -> DaytonaRateLimitError | None:
return None


def _root_cause(exc: BaseException) -> BaseException | None:
seen: set[int] = set()
cause = exc.__cause__ or exc.__context__
root = cause
while cause is not None and id(cause) not in seen:
root = cause
seen.add(id(cause))
cause = cause.__cause__ or cause.__context__
return root


def _message_contains(exc: BaseException, messages: tuple[str, ...]) -> bool:
error = str(exc).lower()
return any(message in error for message in messages)
Expand Down Expand Up @@ -164,6 +239,25 @@ def _is_transient_daytona_error(exc: DaytonaError | ClientResponseError) -> bool
return _message_contains(exc, _TRANSPORT_ERROR_MESSAGES)


def _daytona_rate_limit_throttler(exc: DaytonaRateLimitError) -> str:
    """Derive the throttler name from the headers of a rate-limit error.

    Header names are lower-cased and scanned in iteration order. The first one
    starting with the retry-after or rate-limit-remaining prefix decides the
    result: the text after the prefix is returned when it is one of
    ``_KNOWN_THROTTLERS``, otherwise ``"unknown"``.

    Args:
        exc: Rate-limit error whose ``headers`` are inspected.

    Returns:
        A known throttler name, or ``"unknown"`` when no header matches or the
        first matching header names an unrecognised throttler.
    """
    throttler_prefixes = (_RETRY_AFTER_PREFIX, _RATE_LIMIT_REMAINING_PREFIX)
    for raw_key in exc.headers:
        header_name = str(raw_key).lower()
        for prefix in throttler_prefixes:
            if not header_name.startswith(prefix):
                continue
            candidate = header_name.removeprefix(prefix)
            if candidate in _KNOWN_THROTTLERS:
                return candidate
            return "unknown"

    return "unknown"


def _daytona_rate_limit_header(exc: DaytonaRateLimitError, prefix: str, throttler: str) -> object | None:
    """Look up the ``<prefix><throttler>`` header on a rate-limit error.

    Args:
        exc: Rate-limit error whose ``headers`` are searched.
        prefix: Header name prefix.
        throttler: Throttler name appended to ``prefix``; ``"unknown"``
            short-circuits the lookup.

    Returns:
        Whatever ``_get_header`` returns for the combined name, or ``None``
        when ``throttler`` is ``"unknown"``.
    """
    if throttler != "unknown":
        return _get_header(exc.headers, f"{prefix}{throttler}")
    return None


def _parse_retry_after_seconds(value: object) -> float | None:
try:
seconds = float(str(value))
Expand All @@ -176,7 +270,7 @@ def _parse_retry_after_seconds(value: object) -> float | None:
return seconds


def _get_header(headers: dict[str, Any], header_name: str) -> object | None:
def _get_header(headers: Mapping[str, Any], header_name: str) -> object | None:
header_name = header_name.lower()
for key, value in headers.items():
if str(key).lower() == header_name:
Expand All @@ -202,15 +296,6 @@ def daytona_retry_after_seconds(exc: DaytonaRateLimitError) -> float | None:

return None


# Shared retry policy: retry only on SandboxConnectionError, stop after three
# attempts, take the delay from _provider_retry_wait, and pass reraise=True.
# NOTE(review): _provider_retry(op) builds the same policy plus a logging hook
# -- confirm which of the two is meant to remain.
_PROVIDER_RETRY = retry(
retry=retry_if_exception_type(SandboxConnectionError),
stop=stop_after_attempt(3),
wait=_provider_retry_wait,
reraise=True,
)


class DaytonaSandbox(Sandbox):
def __init__(self, sandbox: AsyncSandbox) -> None:
self._sandbox = sandbox
Expand Down Expand Up @@ -241,7 +326,7 @@ def _sandbox_error(self, exc: DaytonaError | ClientResponseError) -> SandboxErro
return SandboxConnectionError(f"Sandbox connection error for {self._sandbox_ref}: {exc}")
return SandboxError(f"Sandbox operation failed for {self._sandbox_ref}: {exc}")

@_PROVIDER_RETRY
@_provider_retry("sandbox.exec")
async def exec(
self,
command: str,
Expand Down Expand Up @@ -286,14 +371,14 @@ async def command(
with suppress(asyncio.CancelledError):
await exec_task

@_PROVIDER_RETRY
@_provider_retry("sandbox.upload_file")
async def upload_file(self, remote_path: str, content: bytes) -> None:
    """Upload ``content`` to ``remote_path`` through the sandbox filesystem API.

    Args:
        remote_path: Destination path passed to the underlying upload call.
        content: Raw bytes to upload.

    Raises:
        The error built by ``self._sandbox_error`` when the upload raises one
        of ``_SANDBOX_OPERATION_ERRORS``; the original is kept as its cause.
    """
    try:
        await self._sandbox.fs.upload_file(content, remote_path)
    except _SANDBOX_OPERATION_ERRORS as err:
        # Translate the low-level failure while preserving it as __cause__.
        raise self._sandbox_error(err) from err

@_PROVIDER_RETRY
@_provider_retry("sandbox.download_file")
async def download_file(self, remote_path: str) -> bytes:
try:
stream = await self._sandbox.fs.download_file_stream(remote_path)
Expand Down Expand Up @@ -350,7 +435,7 @@ async def on_data(data: bytes) -> None:
with suppress(Exception):
await self.exec(f"rm -f {shlex.quote(status_path)}")

@_PROVIDER_RETRY
@_provider_retry("pty.create")
async def _create_pty_session(
self,
session_id: str,
Expand All @@ -366,7 +451,7 @@ async def _create_pty_session(
await self._check_sandbox_alive()
raise self._sandbox_error(exc) from exc

@_PROVIDER_RETRY
@_provider_retry("pty.reconnect")
async def _reconnect_pty(
self,
session_id: str,
Expand All @@ -393,7 +478,7 @@ async def _reconnect_pty(
await self._check_sandbox_alive()
raise self._sandbox_error(exc) from exc

@_PROVIDER_RETRY
@_provider_retry("sandbox.health_check")
async def _check_sandbox_alive(self) -> None:
try:
await self._sandbox.refresh_data()
Expand Down Expand Up @@ -424,7 +509,7 @@ def _sandbox_error(self, exc: DaytonaError) -> SandboxError:
return SandboxConnectionError(f"Daytona sandbox provider connection error: {exc}")
return SandboxError(f"Daytona sandbox provider error: {exc}")

@_PROVIDER_RETRY
@_provider_retry("sandbox.create")
async def create_sandbox(self, request: SandboxCreateRequest) -> DaytonaSandbox:
existing = await self._find_reusable_sandbox(request.name)
if existing is not None:
Expand Down Expand Up @@ -483,7 +568,7 @@ async def _find_reusable_sandbox(self, name: str) -> AsyncSandbox | None:
except DaytonaError as exc:
raise self._sandbox_error(exc) from exc

@_PROVIDER_RETRY
@_provider_retry("sandbox.get")
async def get_sandbox(self, instance_id: str) -> DaytonaSandbox:
try:
return DaytonaSandbox(await self._daytona.get(instance_id))
Expand All @@ -492,7 +577,7 @@ async def get_sandbox(self, instance_id: str) -> DaytonaSandbox:
except DaytonaError as exc:
raise self._sandbox_error(exc) from exc

@_PROVIDER_RETRY
@_provider_retry("sandbox.delete")
async def delete_sandbox(self, instance_id: str) -> None:
try:
sandbox = await self._daytona.get(instance_id)
Expand All @@ -519,7 +604,7 @@ async def list_sandboxes(self, query: SandboxQuery) -> AsyncGenerator[DaytonaSan
continue
yield DaytonaSandbox(sandbox)

@_PROVIDER_RETRY
@_provider_retry("sandbox.list")
async def _list_sandboxes(self, query: SandboxQuery) -> list[AsyncSandbox]:
try:
daytona_query = ListSandboxesQuery(labels=query.labels, limit=query.page_size)
Expand Down
40 changes: 36 additions & 4 deletions tests/test_sandbox.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import logging
from types import SimpleNamespace
from typing import Any, Awaitable, Callable, cast

Expand Down Expand Up @@ -26,6 +27,12 @@
from benchmark_service.sandbox.daytona import DaytonaSandbox, DaytonaSandboxProvider, daytona_retry_after_seconds


def _log_record(records: list[logging.LogRecord], message: str) -> logging.LogRecord:
matches = [record for record in records if record.getMessage() == message]
assert len(matches) == 1
return matches[0]


def _client_response_error(status: int, message: str) -> ClientResponseError:
url = URL("https://daytona.example.test")
headers: CIMultiDict[str] = CIMultiDict()
Expand Down Expand Up @@ -434,16 +441,32 @@ def test_daytona_retry_after_uses_any_retry_after_header() -> None:
assert daytona_retry_after_seconds(exc) == 5


async def test_daytona_exec_retries_rate_limits() -> None:
async def test_daytona_exec_retries_rate_limits(caplog: pytest.LogCaptureFixture) -> None:
inner = InnerSandbox()
process = RateLimitedProcess()
inner.process = process
sandbox = DaytonaSandbox(cast(Any, inner))

await sandbox.exec("pytest")
with caplog.at_level(logging.WARNING, logger="benchmark_service.sandbox.daytona"):
await sandbox.exec("pytest")

assert process.attempts == 2

retry_log = _log_record(caplog.records, "daytona.retry_before_sleep")
assert getattr(retry_log, "op") == "sandbox.exec"
assert getattr(retry_log, "fn") == "exec"
assert getattr(retry_log, "attempt") == 1
assert getattr(retry_log, "sleep_seconds") == 0
assert getattr(retry_log, "error_class") == "SandboxConnectionError"
assert getattr(retry_log, "cause_error_class") == "DaytonaRateLimitError"
assert getattr(retry_log, "root_cause_error_class") == "DaytonaRateLimitError"

rate_limit_log = _log_record(caplog.records, "daytona.rate_limit_retry")
assert getattr(rate_limit_log, "op") == "sandbox.exec"
assert getattr(rate_limit_log, "throttler") == "sandbox-create"
assert getattr(rate_limit_log, "attempt") == 1
assert getattr(rate_limit_log, "sleep_seconds") == 0


async def test_daytona_exec_retries_failed_execute_command_errors() -> None:
"""Blank Daytona exec failures should be retried because they are usually transient.
Expand All @@ -463,17 +486,26 @@ async def test_daytona_exec_retries_failed_execute_command_errors() -> None:
assert process.attempts == 2


async def test_daytona_exec_retries_wrapped_connection_errors() -> None:
async def test_daytona_exec_retries_wrapped_connection_errors(caplog: pytest.LogCaptureFixture) -> None:
inner = InnerSandbox()
process = WrappedConnectionErrorProcess()
inner.process = process
sandbox = DaytonaSandbox(cast(Any, inner))

result = await sandbox.exec("pytest")
with caplog.at_level(logging.WARNING, logger="benchmark_service.sandbox.daytona"):
result = await sandbox.exec("pytest")

assert result.exit_code == 0
assert process.attempts == 2

retry_log = _log_record(caplog.records, "daytona.retry_before_sleep")
assert getattr(retry_log, "op") == "sandbox.exec"
assert getattr(retry_log, "attempt") == 1
assert getattr(retry_log, "sleep_seconds") == 2
assert getattr(retry_log, "error_class") == "SandboxConnectionError"
assert getattr(retry_log, "cause_error_class") == "DaytonaError"
assert getattr(retry_log, "root_cause_error_class") == "ClientConnectionError"


async def test_daytona_exec_retries_misclassified_transport_errors() -> None:
inner = InnerSandbox()
Expand Down
Loading