Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/bindings/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ list(TRANSFORM BINDING_SOURCES PREPEND "${CMAKE_CURRENT_SOURCE_DIR}/")
set(HIERARCHICAL_SRC ${CMAKE_SOURCE_DIR}/src/common/hierarchical)

set(HIERARCHICAL_SOURCES
${HIERARCHICAL_SRC}/dist_chip_bootstrap_channel.cpp
${HIERARCHICAL_SRC}/types.cpp
${HIERARCHICAL_SRC}/tensormap.cpp
${HIERARCHICAL_SRC}/ring.cpp
Expand Down
11 changes: 10 additions & 1 deletion python/bindings/task_interface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -618,7 +618,16 @@ NB_MODULE(_task_interface, m) {
.def("malloc", &ChipWorker::malloc, nb::arg("size"))
.def("free", &ChipWorker::free, nb::arg("ptr"))
.def("copy_to", &ChipWorker::copy_to, nb::arg("dst"), nb::arg("src"), nb::arg("size"))
.def("copy_from", &ChipWorker::copy_from, nb::arg("dst"), nb::arg("src"), nb::arg("size"));
.def("copy_from", &ChipWorker::copy_from, nb::arg("dst"), nb::arg("src"), nb::arg("size"))
.def(
"comm_init", &ChipWorker::comm_init, nb::arg("rank"), nb::arg("nranks"), nb::arg("device_id"),
nb::arg("rootinfo_path")
)
.def("comm_alloc_windows", &ChipWorker::comm_alloc_windows, nb::arg("comm_handle"), nb::arg("win_size"))
.def("comm_get_local_window_base", &ChipWorker::comm_get_local_window_base, nb::arg("comm_handle"))
.def("comm_get_window_size", &ChipWorker::comm_get_window_size, nb::arg("comm_handle"))
.def("comm_barrier", &ChipWorker::comm_barrier, nb::arg("comm_handle"))
.def("comm_destroy", &ChipWorker::comm_destroy, nb::arg("comm_handle"));

// --- Standalone blob helpers ---
m.def(
Expand Down
72 changes: 72 additions & 0 deletions python/bindings/worker_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,15 +32,55 @@

#include "chip_worker.h"
#include "ring.h"
#include "dist_chip_bootstrap_channel.h"
#include "orchestrator.h"
#include "types.h"
#include "worker.h"
#include "worker_manager.h"

namespace nb = nanobind;

/**
 * Load a 32-bit integer from `ptr` with acquire ordering.
 *
 * Implemented with the compiler's `__atomic_load_n` builtin on every target
 * instead of per-architecture inline assembly: the builtin is a genuine atomic
 * access (the previous x86-64 path was a plain volatile load followed by a
 * compiler-only barrier) and leaves instruction selection to the compiler.
 *
 * @param ptr Address of the 32-bit field to read. The caller must pass a valid,
 *            suitably aligned pointer; nothing is checked here.
 * @return The value read from `*ptr`.
 */
inline int32_t load_i32_acquire(const volatile int32_t *ptr) {
    return __atomic_load_n(ptr, __ATOMIC_ACQUIRE);
}

/**
 * Store a 32-bit integer to `ptr` with release ordering.
 *
 * Implemented with the compiler's `__atomic_store_n` builtin on every target
 * instead of per-architecture inline assembly: the builtin is a genuine atomic
 * access (the previous x86-64 path was a compiler-only barrier followed by a
 * plain volatile store) and leaves instruction selection to the compiler.
 *
 * @param ptr   Address of the 32-bit field to write. The caller must pass a
 *              valid, suitably aligned pointer; nothing is checked here.
 * @param value Value written to `*ptr`.
 */
inline void store_i32_release(volatile int32_t *ptr, int32_t value) {
    __atomic_store_n(ptr, value, __ATOMIC_RELEASE);
}

inline void bind_worker(nb::module_ &m) {
// --- WorkerType ---
m.def(
"_mailbox_load_i32",
[](uint64_t addr) -> int32_t {
return load_i32_acquire(reinterpret_cast<const volatile int32_t *>(addr));
},
nb::arg("addr"), "Internal: acquire-load an int32 mailbox field."
);
m.def(
"_mailbox_store_i32",
[](uint64_t addr, int32_t value) {
store_i32_release(reinterpret_cast<volatile int32_t *>(addr), value);
},
nb::arg("addr"), nb::arg("value"), "Internal: release-store an int32 mailbox field."
);

nb::enum_<WorkerType>(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB);

// --- TaskState ---
Expand All @@ -61,6 +101,38 @@ inline void bind_worker(nb::module_ &m) {
// Bound as `_Orchestrator` because the Python user-facing `Orchestrator`
// wrapper (simpler.orchestrator.Orchestrator) holds a borrowed reference
// to this C++ type.
m.attr("DIST_MAILBOX_SIZE") = static_cast<int>(MAILBOX_SIZE);
m.attr("DIST_SUB_MAILBOX_SIZE") = static_cast<int>(MAILBOX_SIZE);
m.attr("DIST_CHIP_MAILBOX_SIZE") = static_cast<int>(MAILBOX_SIZE);
m.attr("DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE") = static_cast<int>(DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE);

nb::enum_<ChipBootstrapMailboxState>(m, "ChipBootstrapMailboxState")
.value("IDLE", ChipBootstrapMailboxState::IDLE)
.value("SUCCESS", ChipBootstrapMailboxState::SUCCESS)
.value("ERROR", ChipBootstrapMailboxState::ERROR);

nb::class_<DistChipBootstrapChannel>(m, "DistChipBootstrapChannel")
.def(
"__init__",
[](DistChipBootstrapChannel *self, uint64_t mailbox_ptr, size_t max_buffer_count) {
new (self) DistChipBootstrapChannel(reinterpret_cast<void *>(mailbox_ptr), max_buffer_count);
},
nb::arg("mailbox_ptr"), nb::arg("max_buffer_count")
)
.def("reset", &DistChipBootstrapChannel::reset)
.def(
"write_success", &DistChipBootstrapChannel::write_success, nb::arg("device_ctx"),
nb::arg("local_window_base"), nb::arg("actual_window_size"), nb::arg("buffer_ptrs")
)
.def("write_error", &DistChipBootstrapChannel::write_error, nb::arg("error_code"), nb::arg("message"))
.def_prop_ro("state", &DistChipBootstrapChannel::state)
.def_prop_ro("error_code", &DistChipBootstrapChannel::error_code)
.def_prop_ro("device_ctx", &DistChipBootstrapChannel::device_ctx)
.def_prop_ro("local_window_base", &DistChipBootstrapChannel::local_window_base)
.def_prop_ro("actual_window_size", &DistChipBootstrapChannel::actual_window_size)
.def_prop_ro("buffer_ptrs", &DistChipBootstrapChannel::buffer_ptrs)
.def_prop_ro("error_message", &DistChipBootstrapChannel::error_message);

nb::class_<Orchestrator>(m, "_Orchestrator")
.def(
"submit_next_level",
Expand Down
236 changes: 236 additions & 0 deletions python/simpler/task_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,21 +16,32 @@
from task_interface import DataType, ContinuousTensor, ChipStorageTaskArgs, make_tensor_arg
"""

from dataclasses import dataclass
from multiprocessing.shared_memory import SharedMemory

from _task_interface import ( # pyright: ignore[reportMissingImports]
CONTINUOUS_TENSOR_MAX_DIMS,
DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE,
DIST_CHIP_MAILBOX_SIZE,
DIST_MAILBOX_SIZE,
DIST_SUB_MAILBOX_SIZE,
MAILBOX_SIZE,
ArgDirection,
ChipBootstrapMailboxState,
ChipCallable,
ChipCallConfig,
ChipStorageTaskArgs,
ContinuousTensor,
CoreCallable,
DataType,
DistChipBootstrapChannel,
SubmitResult,
TaskArgs,
TaskState,
TensorArgType,
WorkerType,
_mailbox_load_i32,
_mailbox_store_i32,
_ChipWorker,
_Orchestrator,
_Worker,
Expand Down Expand Up @@ -65,10 +76,30 @@
"SubmitResult",
"_Worker",
"MAILBOX_SIZE",
"DistChipBootstrapChannel",
"_mailbox_load_i32",
"_mailbox_store_i32",
"ChipBootstrapMailboxState",
"DIST_MAILBOX_SIZE",
"read_args_from_blob",
"DIST_SUB_MAILBOX_SIZE",
"DIST_CHIP_MAILBOX_SIZE",
"DIST_CHIP_BOOTSTRAP_MAILBOX_SIZE",
"ChipBootstrapResult",
]


@dataclass
class ChipBootstrapResult:
    """Parent-visible reply from per-chip bootstrap.

    Typed counterpart of the dict returned by ``bootstrap()``; it is built by
    ``bootstrap_context()`` from that dict, coercing every value with ``int()``.
    """

    # Handle returned by ``comm_init()``; stays 0 when comm bootstrap is disabled.
    comm_handle: int
    # Value returned by ``comm_alloc_windows()``; 0 when comm bootstrap is disabled.
    device_ctx: int
    # Value returned by ``comm_get_local_window_base()``; 0 when comm bootstrap is disabled.
    local_window_base: int
    # Value returned by ``comm_get_window_size()``; 0 when comm bootstrap is disabled.
    actual_window_size: int
    # One address per configured buffer, each computed as ``local_window_base`` + running offset.
    buffer_ptrs: list[int]


# Lazy-loaded torch dtype → DataType map (avoids importing torch at module load)
_TORCH_DTYPE_MAP = None

Expand Down Expand Up @@ -225,6 +256,211 @@ def copy_from(self, dst, src, size):
"""Copy *size* bytes from worker *src* to host *dst*."""
self._impl.copy_from(int(dst), int(src), int(size))

def run_raw(self, callable, args, *, block_dim=1, aicpu_thread_num=3, enable_profiling=False):
    """Run a callable using raw pointer arguments.

    Every argument is coerced to a plain Python scalar (``int`` for the first
    four, ``bool`` for ``enable_profiling``) and forwarded positionally to
    ``_impl.run_raw``. Nothing is returned.
    """
    callable_addr = int(callable)
    args_addr = int(args)
    blocks = int(block_dim)
    threads = int(aicpu_thread_num)
    profiling = bool(enable_profiling)
    self._impl.run_raw(callable_addr, args_addr, blocks, threads, profiling)

def run_from_blob(self, callable, blob_ptr, block_dim=1, aicpu_thread_num=3, enable_profiling=False):
    """Run a callable from a length-prefixed TaskArgs blob.

    Every argument is coerced to a plain Python scalar (``int`` for the first
    four, ``bool`` for ``enable_profiling``) and forwarded positionally to
    ``_impl.run_from_blob``. Nothing is returned.
    """
    forwarded = (
        int(callable),
        int(blob_ptr),
        int(block_dim),
        int(aicpu_thread_num),
        bool(enable_profiling),
    )
    self._impl.run_from_blob(*forwarded)

def _copy_to_device(self, dev_ptr, host_ptr, size):
"""Internal helper for bootstrap staging into an existing device/window pointer."""
self._impl.copy_to(int(dev_ptr), int(host_ptr), int(size))

def _copy_from_device(self, host_ptr, dev_ptr, size):
"""Internal helper for bootstrap staging out of an existing device/window pointer."""
self._impl.copy_from(int(host_ptr), int(dev_ptr), int(size))

def comm_init(self, rank, nranks, device_id, rootinfo_path):
    """Create a communicator in the current chip child.

    ``rank``, ``nranks`` and ``device_id`` are coerced with ``int()`` and
    ``rootinfo_path`` with ``str()`` before being passed to ``_impl.comm_init``.
    The value it returns is converted to ``int`` and handed back to the caller.
    """
    handle = self._impl.comm_init(
        int(rank),
        int(nranks),
        int(device_id),
        str(rootinfo_path),
    )
    return int(handle)

def comm_alloc_windows(self, comm_handle, win_size):
    """Allocate the communicator-owned window and return the device context.

    Both arguments are coerced with ``int()`` before the call to
    ``_impl.comm_alloc_windows``; its result is returned as an ``int``.
    """
    handle = int(comm_handle)
    requested = int(win_size)
    device_ctx = self._impl.comm_alloc_windows(handle, requested)
    return int(device_ctx)

def comm_get_local_window_base(self, comm_handle):
    """Return the local base address of the communicator window.

    ``comm_handle`` is coerced with ``int()``; the result of
    ``_impl.comm_get_local_window_base`` is returned as an ``int``.
    """
    handle = int(comm_handle)
    base = self._impl.comm_get_local_window_base(handle)
    return int(base)

def comm_get_window_size(self, comm_handle):
    """Return the actual communicator window size.

    ``comm_handle`` is coerced with ``int()``; the result of
    ``_impl.comm_get_window_size`` is returned as an ``int``.
    """
    handle = int(comm_handle)
    window_bytes = self._impl.comm_get_window_size(handle)
    return int(window_bytes)

def comm_barrier(self, comm_handle):
    """Synchronize all ranks in the current communicator.

    ``comm_handle`` is coerced with ``int()`` and passed to
    ``_impl.comm_barrier``. Nothing is returned.
    """
    handle = int(comm_handle)
    self._impl.comm_barrier(handle)

def comm_destroy(self, comm_handle):
    """Destroy a communicator previously created by ``comm_init()``.

    ``comm_handle`` is coerced with ``int()`` and passed to
    ``_impl.comm_destroy``. Nothing is returned.
    """
    handle = int(comm_handle)
    self._impl.comm_destroy(handle)

def bootstrap(
self,
device_id,
*,
comm_rank=-1,
comm_nranks=0,
rootinfo_path="",
window_size=0,
win_sync_prefix=0,
buffer_sizes,
buffer_placements,
input_blobs,
):
"""Bootstrap per-chip runtime state before the first task submission."""
buffer_sizes = [int(size) for size in buffer_sizes]
buffer_placements = [str(placement) for placement in buffer_placements]
input_blobs = list(input_blobs)

if len(buffer_sizes) != len(buffer_placements):
raise ValueError("buffer_sizes and buffer_placements must have the same length")
if len(buffer_sizes) != len(input_blobs):
raise ValueError("input_blobs length must match buffer_sizes")

enable_comm = int(comm_rank) >= 0
comm_handle = 0
device_ctx = 0
local_window_base = 0
actual_window_size = 0
buffer_ptrs: list[int] = []

try:
if enable_comm:
if int(comm_nranks) <= 0:
raise ValueError("comm_nranks must be positive when comm bootstrap is enabled")
if not str(rootinfo_path):
raise ValueError("rootinfo_path is required when comm bootstrap is enabled")
comm_handle = self.comm_init(comm_rank, comm_nranks, device_id, rootinfo_path)

if not self.device_set:
self.set_device(int(device_id))
elif self.device_id != int(device_id):
raise ValueError("ChipWorker already bound to a different device")

if enable_comm:
device_ctx = self.comm_alloc_windows(comm_handle, window_size)
local_window_base = self.comm_get_local_window_base(comm_handle)
actual_window_size = self.comm_get_window_size(comm_handle)

win_offset = int(win_sync_prefix)
for size, placement, blob in zip(buffer_sizes, buffer_placements, input_blobs, strict=True):
if placement != "window":
raise ValueError(f"Unsupported bootstrap buffer placement: {placement}; only 'window' is allowed")
if not enable_comm:
raise ValueError("window placement requires comm bootstrap")
if win_offset + size > actual_window_size:
raise ValueError("window bootstrap buffers exceed the allocated communicator window size")

ptr = local_window_base + win_offset
win_offset += size
buffer_ptrs.append(ptr)

if blob is not None:
if not isinstance(blob, bytes):
raise ValueError("input blobs must be bytes or None")
if len(blob) != size:
raise ValueError("input blob size must match buffer size")
if size > 0:
import ctypes as _ct
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The import ctypes as _ct statement is inside a loop. While this works, it's inefficient to re-import the module on every iteration. For better performance and code style, please move this import to the top of the bootstrap method (e.g., on line 304).


host_buf = _ct.create_string_buffer(blob, size)
self._copy_to_device(ptr, _ct.addressof(host_buf), size)

if enable_comm:
self.comm_barrier(comm_handle)
except Exception:
if comm_handle != 0:
try:
self.comm_destroy(comm_handle)
except Exception:
pass
raise

return {
"comm_handle": comm_handle,
"device_ctx": device_ctx,
"local_window_base": local_window_base,
"actual_window_size": actual_window_size,
"buffer_ptrs": buffer_ptrs,
}

def shutdown_bootstrap(self, *, comm_handle=0, buffer_ptrs, buffer_placements):
    """Release per-chip runtime state previously created by ``bootstrap()``.

    The buffer lists are only validated (equal length, every placement equal
    to ``"window"``); no per-buffer release call is made here. A non-zero
    ``comm_handle`` is then passed to ``comm_destroy()``.

    Raises:
        ValueError: If the two lists differ in length or a placement other
            than ``"window"`` is present.
    """
    ptr_list = [int(ptr) for ptr in buffer_ptrs]
    placement_list = [str(placement) for placement in buffer_placements]

    if len(ptr_list) != len(placement_list):
        raise ValueError("buffer_ptrs and buffer_placements must have the same length")

    # Lengths are equal at this point, so walking the placements alone visits
    # the same entries a pairwise walk over both lists would.
    for placement in placement_list:
        if placement == "window":
            continue
        raise ValueError(
            f"Unsupported bootstrap buffer placement during shutdown: {placement}; only 'window' is allowed"
        )

    handle = int(comm_handle)
    if handle != 0:
        self.comm_destroy(handle)

@staticmethod
def _read_bootstrap_input_bytes(shm_name: str, size: int) -> bytes:
shm = SharedMemory(name=shm_name)
try:
if size == 0:
return b""
assert shm.buf is not None
return bytes(shm.buf[:size])
finally:
shm.close()

def bootstrap_context(self, device_id, chip_bootstrap_config) -> ChipBootstrapResult:
    """Bootstrap a chip child from a typed bootstrap config.

    For every entry of ``chip_bootstrap_config.buffers`` with
    ``load_from_host`` set, the staged input is looked up through
    ``chip_bootstrap_config.input_staging(name)`` and read from shared memory;
    other buffers get ``None``. Communicator settings come from the optional
    ``comm`` attribute of the config, and the disabled defaults (rank -1,
    everything else 0 or empty) are used when it is absent or ``None``. The
    dict returned by ``bootstrap()`` is converted into a
    ``ChipBootstrapResult``.
    """
    comm_cfg = getattr(chip_bootstrap_config, "comm", None)

    def _stage(buf):
        # Buffers that are not loaded from the host carry no input blob.
        if not buf.load_from_host:
            return None
        staged = chip_bootstrap_config.input_staging(buf.name)
        return self._read_bootstrap_input_bytes(staged.shm_name, staged.size)

    input_blobs = [_stage(buf) for buf in chip_bootstrap_config.buffers]

    if comm_cfg is None:
        comm_kwargs = {
            "comm_rank": -1,
            "comm_nranks": 0,
            "rootinfo_path": "",
            "window_size": 0,
            "win_sync_prefix": 0,
        }
    else:
        comm_kwargs = {
            "comm_rank": comm_cfg.rank,
            "comm_nranks": comm_cfg.nranks,
            "rootinfo_path": comm_cfg.rootinfo_path,
            "window_size": comm_cfg.window_size,
            "win_sync_prefix": comm_cfg.win_sync_prefix,
        }

    sizes = [buf.nbytes for buf in chip_bootstrap_config.buffers]
    placements = [buf.placement for buf in chip_bootstrap_config.buffers]
    reply = self.bootstrap(
        device_id,
        buffer_sizes=sizes,
        buffer_placements=placements,
        input_blobs=input_blobs,
        **comm_kwargs,
    )

    fields = {
        key: int(reply[key])
        for key in ("comm_handle", "device_ctx", "local_window_base", "actual_window_size")
    }
    return ChipBootstrapResult(buffer_ptrs=[int(ptr) for ptr in reply["buffer_ptrs"]], **fields)

def shutdown_bootstrap_context(self, chip_bootstrap_config, *, comm_handle=0, buffer_ptrs):
    """Release resources created by ``bootstrap_context``.

    Collects the ``placement`` of every entry in
    ``chip_bootstrap_config.buffers`` and delegates to
    ``shutdown_bootstrap()`` with the given handle and pointers.
    """
    placements = []
    for buf in chip_bootstrap_config.buffers:
        placements.append(buf.placement)
    self.shutdown_bootstrap(
        comm_handle=comm_handle,
        buffer_ptrs=buffer_ptrs,
        buffer_placements=placements,
    )

def copy_device_to_bytes(self, dev_ptr, size) -> bytes:
    """Copy a device buffer into a Python bytes object.

    ``size`` is coerced with ``int()``. A size of 0 returns ``b""`` without
    touching the device; otherwise a ctypes buffer of ``size`` bytes is filled
    through ``_copy_from_device()`` and its contents are returned.
    """
    nbytes = int(size)
    if nbytes != 0:
        import ctypes as _ct

        staging = _ct.create_string_buffer(nbytes)
        staging_addr = _ct.addressof(staging)
        self._copy_from_device(staging_addr, int(dev_ptr), nbytes)
        return staging.raw[:nbytes]
    return b""

@property
def device_id(self):
    """Device id as reported by the underlying ``_impl`` object."""
    impl = self._impl
    return impl.device_id
Expand Down
Loading