Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions python/simpler/task_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,20 @@
from task_interface import DataType, ContinuousTensor, ChipStorageTaskArgs, make_tensor_arg
"""

import ctypes
from dataclasses import dataclass, field
from multiprocessing.shared_memory import SharedMemory
from typing import Optional

from _task_interface import ( # pyright: ignore[reportMissingImports]
CHIP_BOOTSTRAP_MAILBOX_SIZE,
CONTINUOUS_TENSOR_MAX_DIMS,
MAILBOX_ERROR_MSG_SIZE,
MAILBOX_OFF_ERROR_MSG,
MAILBOX_SIZE,
ArgDirection,
ChipBootstrapChannel,
ChipBootstrapMailboxState,
ChipCallable,
ChipCallConfig,
ChipStorageTaskArgs,
Expand Down Expand Up @@ -70,6 +78,15 @@
"MAILBOX_OFF_ERROR_MSG",
"MAILBOX_ERROR_MSG_SIZE",
"read_args_from_blob",
# Chip bootstrap (L5)
"CHIP_BOOTSTRAP_MAILBOX_SIZE",
"ChipBootstrapChannel",
"ChipBootstrapMailboxState",
"ChipCommBootstrapConfig",
"ChipBufferSpec",
"HostBufferStaging",
"ChipBootstrapConfig",
"ChipBootstrapResult",
]


Expand Down Expand Up @@ -143,6 +160,95 @@ def scalar_to_uint64(value) -> int:
return int(value) & 0xFFFFFFFFFFFFFFFF


@dataclass
class ChipCommBootstrapConfig:
"""Per-chip communicator bring-up knobs consumed by `ChipWorker.bootstrap_context`.

A ``ChipBootstrapConfig`` with ``comm=None`` skips the communicator step
entirely; in that mode ``cfg.buffers`` must be empty because
``placement="window"`` is the only supported placement in L5 and the
window only exists once a communicator has been brought up. Comm-less
configs are used by validation / error-path tests that need to trip
``bootstrap_context`` before it reaches any communicator call.
"""

rank: int
nranks: int
rootinfo_path: str
window_size: int
"""Requested per-rank window size in bytes. HCCL may round this up — the
actual allocation is reported back via
``ChipBootstrapResult.actual_window_size`` and must be what callers use
when slicing the window."""


@dataclass
class ChipBufferSpec:
"""A named slice of the per-rank communicator window.

Buffers are placed sequentially inside the window in declaration order —
``ChipBootstrapResult.buffer_ptrs`` is 1:1 aligned with the ``buffers``
list so downstream code (L6's ``ChipContext``) can build a ``name → ptr``
dict by zipping the two.
"""

name: str
dtype: str
count: int
placement: str
nbytes: int
load_from_host: bool = False
store_to_host: bool = False


@dataclass
class HostBufferStaging:
"""A POSIX shared-memory region staged by the parent for one named buffer.

The parent creates the ``SharedMemory`` object and fills it with the input
bytes *before* forking; the child attaches read-only via
``SharedMemory(name=shm_name)`` and does not unlink it.
"""

name: str
shm_name: str
size: int


@dataclass
class ChipBootstrapConfig:
"""Inputs to `ChipWorker.bootstrap_context` for one chip child."""

comm: Optional[ChipCommBootstrapConfig] = None
buffers: list[ChipBufferSpec] = field(default_factory=list)
host_inputs: list[HostBufferStaging] = field(default_factory=list)
host_outputs: list[HostBufferStaging] = field(default_factory=list)

def input_staging(self, buffer_name: str) -> HostBufferStaging:
for s in self.host_inputs:
if s.name == buffer_name:
return s
raise KeyError(buffer_name)

def output_staging(self, buffer_name: str) -> HostBufferStaging:
for s in self.host_outputs:
if s.name == buffer_name:
return s
raise KeyError(buffer_name)


@dataclass
class ChipBootstrapResult:
"""Return value of `ChipWorker.bootstrap_context` — and the tuple the
`ChipBootstrapChannel` publishes to the parent on success.
"""

device_ctx: int
local_window_base: int
actual_window_size: int
buffer_ptrs: list[int]


class ChipWorker:
"""Unified execution interface wrapping the host runtime C API.

Expand Down Expand Up @@ -267,6 +373,110 @@ def comm_destroy(self, comm_handle: int) -> None:
"""Destroy the communicator and release its resources."""
self._impl.comm_destroy(int(comm_handle))

def bootstrap_context(
self,
device_id: int,
cfg: ChipBootstrapConfig,
channel: Optional[ChipBootstrapChannel] = None,
) -> ChipBootstrapResult:
"""One-shot per-chip bootstrap: set device, build communicator, slice window,
stage inputs from host shared memory, and (optionally) publish the result.

Runs inside a forked chip child. If ``channel`` is provided (the L6
integration path), the result is written as SUCCESS or — on any
exception — as ERROR (code=1, ``"<ExceptionType>: <message>"``) before
the exception is re-raised. Standalone callers can pass
``channel=None`` and consume the return value directly.

The HCCL comm handle produced by ``comm_init`` is stashed on
``self._comm_handle`` so ``shutdown_bootstrap()`` can release it later;
``finalize()`` is intentionally *not* wired to this handle — teardown
ordering is the caller's (L6's) responsibility.
"""
try:
self.set_device(device_id)

device_ctx = 0
local_base = 0
actual_size = 0
if cfg.comm is not None:
handle = self.comm_init(cfg.comm.rank, cfg.comm.nranks, cfg.comm.rootinfo_path)
if handle == 0:
raise RuntimeError(f"comm_init returned 0 handle (rank={cfg.comm.rank}, nranks={cfg.comm.nranks})")
self._comm_handle = handle
device_ctx = self.comm_alloc_windows(handle, cfg.comm.window_size)
if device_ctx == 0:
raise RuntimeError("comm_alloc_windows returned null device_ctx")
local_base = self.comm_get_local_window_base(handle)
actual_size = self.comm_get_window_size(handle)

offset = 0
buffer_ptrs: list[int] = []
for spec in cfg.buffers:
if spec.placement != "window":
raise ValueError(f"ChipBufferSpec.placement={spec.placement!r}; only 'window' is supported")
if cfg.comm is None:
raise ValueError("ChipBufferSpec requires comm; cfg.comm is None")
if offset + spec.nbytes > actual_size:
raise ValueError(
f"buffer '{spec.name}' (nbytes={spec.nbytes}) at offset={offset} "
f"overflows window size {actual_size}"
)
buffer_ptrs.append(local_base + offset)
offset += spec.nbytes

for spec, ptr in zip(cfg.buffers, buffer_ptrs):
if not spec.load_from_host:
continue
staging = cfg.input_staging(spec.name)
if staging.size != spec.nbytes:
raise ValueError(f"host_inputs[{spec.name!r}].size={staging.size} != buffer.nbytes={spec.nbytes}")
if staging.size == 0:
continue
shm = SharedMemory(name=staging.shm_name)
try:
buf = shm.buf
assert buf is not None
host_ptr = ctypes.addressof(ctypes.c_char.from_buffer(buf))
self.copy_to(ptr, host_ptr, staging.size)
finally:
shm.close()
Comment thread
ChaoWao marked this conversation as resolved.

result = ChipBootstrapResult(
device_ctx=device_ctx,
local_window_base=local_base,
actual_window_size=actual_size,
buffer_ptrs=buffer_ptrs,
)
if channel is not None:
channel.write_success(
result.device_ctx,
result.local_window_base,
result.actual_window_size,
result.buffer_ptrs,
)
return result
except Exception as e:
if channel is not None:
channel.write_error(1, f"{type(e).__name__}: {e}")
raise

def shutdown_bootstrap(self) -> None:
"""Release the communicator handle stashed by ``bootstrap_context``.

Idempotent — safe to call multiple times, and safe to call if
``bootstrap_context`` was never invoked. ``finalize()`` does *not*
chain into this method, so L6 must call ``shutdown_bootstrap()``
before ``finalize()`` (or after, if the comm handle was already
destroyed — the zero-handle guard makes a second call a no-op).
"""
handle = getattr(self, "_comm_handle", 0)
if handle != 0:
try:
self.comm_destroy(handle)
finally:
self._comm_handle = 0

@property
def device_id(self):
return self._impl.device_id
Expand Down
33 changes: 31 additions & 2 deletions src/common/platform_comm/comm_sim.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

#include <cerrno>
#include <chrono>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
Expand All @@ -56,6 +57,16 @@ constexpr int FTRUNCATE_POLL_INTERVAL_US = 1000;
constexpr int BARRIER_POLL_INTERVAL_US = 50;
constexpr int DESTROY_POLL_INTERVAL_US = 1000;

// macOS's PSHMNAMLEN is 31 (name length excluding the null terminator). Linux
// accepts up to NAME_MAX (255), but we pick the tighter value so the same
// backend runs on both. The name layout below is fully constant-width so we
// can static_assert on it at compile time.
constexpr size_t SHM_NAME_MAX_LEN = 31;
constexpr size_t SHM_NAME_PREFIX_LEN = 9; // "/simpler_"
constexpr size_t SHM_NAME_HEX_FIELD = 8; // %08x: exactly 8 hex chars
constexpr size_t SHM_NAME_LEN = SHM_NAME_PREFIX_LEN + SHM_NAME_HEX_FIELD + 1 /*underscore*/ + SHM_NAME_HEX_FIELD;
static_assert(SHM_NAME_LEN <= SHM_NAME_MAX_LEN, "shm name exceeds macOS PSHMNAMLEN");

struct SharedHeader {
volatile int nranks;
volatile int alloc_done;
Expand All @@ -80,10 +91,28 @@ struct SharedHeader {
// parent PID and therefore a fresh name. Cross-node / cross-parent launches
// on sim are out of scope; callers relying on those topologies must use the
// HCCL backend.
//
// Name layout is fixed-width `"/simpler_%08x_%08x"` = 26 bytes (plus NUL), well
// under macOS's PSHMNAMLEN=31. The width is constant-propagated into
// SHM_NAME_LEN above so a future format-string change gets caught by the
// static_assert at compile time rather than by an EFILENAMEMAXEXCEEDED at
// runtime on macOS. PID is truncated to its low 32 bits (pid_t is int32_t on
// every target we support) and the 64-bit rootinfo-path hash is xor-folded to
// 32 bits; both are still collision-resistant for the canonical
// "one driver spawns N ranks" launch pattern.
std::string make_shm_name(const char *rootinfo_path) {
size_t h = std::hash<std::string>{}(rootinfo_path ? rootinfo_path : "default");
char buf[96];
std::snprintf(buf, sizeof(buf), "/simpler_comm_%d_%zx", static_cast<int>(getppid()), h);
uint32_t h32 = static_cast<uint32_t>(h ^ (h >> 32));
char buf[SHM_NAME_LEN + 1];
int written = std::snprintf(buf, sizeof(buf), "/simpler_%08x_%08x", static_cast<uint32_t>(getppid()), h32);
// Defensive runtime check: snprintf returns -1 only on I/O / encoding
// errors, and the static_assert above already pins the upper bound of a
// successful write, so this is really an "impossible path" guard for the
// libc-misbehaving edge case.
if (written < 0 || static_cast<size_t>(written) != SHM_NAME_LEN) {
std::fprintf(stderr, "[comm_sim] snprintf produced unexpected length %d for shm name\n", written);
return {};
}
return {buf};
}

Expand Down
Loading
Loading