Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions python/bindings/worker_bind.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <nanobind/stl/string.h>
#include <nanobind/stl/vector.h>

#include <cstdint>
#include <stdexcept>

#include "chip_bootstrap_channel.h"
Expand All @@ -40,6 +41,40 @@

namespace nb = nanobind;

// ---------------------------------------------------------------------------
// Mailbox acquire/release helpers (exposed to Python as _mailbox_load_i32 /
// _mailbox_store_i32). Mirror WorkerThread::read_mailbox_state /
// write_mailbox_state in worker_manager.cpp so the Python side of the mailbox
// handshake uses the same memory order as the C++ side. Without these, a
// plain struct.pack_into("i", ...) on the Python child followed by the parent
// C++ acquire-load on aarch64 can observe the state flip before the
// preceding error-field writes are visible.
// Acquire-load the 32-bit word located at the raw address `addr` and return
// its value. `addr` is reinterpreted as a pointer and dereferenced as-is;
// nothing in this function validates it, so the caller must pass the address
// of a live, readable int32_t.
inline int32_t mailbox_load_i32(uint64_t addr) {
volatile int32_t *ptr = reinterpret_cast<volatile int32_t *>(addr);
int32_t v;
#if defined(__aarch64__)
// LDAR is the AArch64 load-acquire instruction. The "memory" clobber also
// keeps the compiler from moving other memory accesses across this asm.
__asm__ volatile("ldar %w0, [%1]" : "=r"(v) : "r"(ptr) : "memory");
#elif defined(__x86_64__)
// Plain volatile load, then a compiler-only barrier (empty asm with a
// "memory" clobber; no fence instruction is emitted).
// NOTE(review): presumably relies on x86-64 hardware load ordering to supply
// the acquire semantics -- confirm.
v = *ptr;
__asm__ volatile("" ::: "memory");
#else
// Any other target: use the compiler's acquire-ordered atomic load builtin.
__atomic_load(ptr, &v, __ATOMIC_ACQUIRE);
#endif
return v;
}

// Release-store the 32-bit value `v` to the word located at the raw address
// `addr`. `addr` is reinterpreted as a pointer and written through as-is;
// nothing in this function validates it, so the caller must pass the address
// of a live, writable int32_t.
inline void mailbox_store_i32(uint64_t addr, int32_t v) {
volatile int32_t *ptr = reinterpret_cast<volatile int32_t *>(addr);
#if defined(__aarch64__)
// STLR is the AArch64 store-release instruction. The "memory" clobber also
// keeps the compiler from moving other memory accesses across this asm.
__asm__ volatile("stlr %w0, [%1]" : : "r"(v), "r"(ptr) : "memory");
#elif defined(__x86_64__)
// Compiler-only barrier first, so earlier writes are not reordered by the
// compiler past the state store; then a plain volatile store.
// NOTE(review): presumably relies on x86-64 hardware store ordering to supply
// the release semantics -- confirm.
__asm__ volatile("" ::: "memory");
*ptr = v;
#else
// Any other target: use the compiler's release-ordered atomic store builtin.
__atomic_store(ptr, &v, __ATOMIC_RELEASE);
#endif
}

inline void bind_worker(nb::module_ &m) {
// --- WorkerType ---
nb::enum_<WorkerType>(m, "WorkerType").value("NEXT_LEVEL", WorkerType::NEXT_LEVEL).value("SUB", WorkerType::SUB);
Expand Down Expand Up @@ -279,4 +314,22 @@ inline void bind_worker(nb::module_ &m) {
.def_prop_ro("actual_window_size", &ChipBootstrapChannel::actual_window_size)
.def_prop_ro("buffer_ptrs", &ChipBootstrapChannel::buffer_ptrs)
.def_prop_ro("error_message", &ChipBootstrapChannel::error_message);

// Private mailbox acquire/release helpers — only for simpler.worker. The
// underscore prefix keeps them out of the public surface; they do not
// appear in task_interface.__all__.
m.def(
"_mailbox_load_i32",
[](uint64_t addr) -> int32_t {
return mailbox_load_i32(addr);
},
nb::arg("addr"), "Acquire-load a 32-bit mailbox word at `addr`."
);
m.def(
"_mailbox_store_i32",
[](uint64_t addr, int32_t value) {
mailbox_store_i32(addr, value);
},
nb::arg("addr"), nb::arg("value"), "Release-store a 32-bit mailbox word at `addr`."
);
}
53 changes: 36 additions & 17 deletions python/simpler/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ def my_l4_orch(orch, args, config):
from multiprocessing.shared_memory import SharedMemory
from typing import Any, Callable, Optional

from _task_interface import ( # pyright: ignore[reportMissingImports]
_mailbox_load_i32,
_mailbox_store_i32,
)

from .orchestrator import Orchestrator
from .task_interface import (
MAILBOX_ERROR_MSG_SIZE,
Expand Down Expand Up @@ -117,6 +122,16 @@ def _mailbox_addr(shm: SharedMemory) -> int:
return ctypes.addressof(ctypes.c_char.from_buffer(buf))


def _buffer_field_addr(buf, offset: int) -> int:
"""Absolute address of a field inside a shared-memory buffer.

Used to feed `_mailbox_load_i32` / `_mailbox_store_i32`, which operate on
raw pointers so the acquire/release semantics match the C++ side
(worker_manager.cpp::read_mailbox_state / write_mailbox_state).
"""
return ctypes.addressof(ctypes.c_char.from_buffer(buf)) + offset


def _write_error(buf, code: int, msg: str = "") -> None:
"""Write an (error code, message) tuple into the mailbox error region.

Expand Down Expand Up @@ -185,8 +200,9 @@ def _sub_worker_loop(buf, registry: dict) -> None:
error-message region; the parent's ``WorkerThread::dispatch_process``
rethrows it as ``std::runtime_error``.
"""
state_addr = _buffer_field_addr(buf, _OFF_STATE)
while True:
state = struct.unpack_from("i", buf, _OFF_STATE)[0]
state = _mailbox_load_i32(state_addr)
if state == _TASK_READY:
cid = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0]
fn = registry.get(int(cid))
Expand All @@ -203,7 +219,7 @@ def _sub_worker_loop(buf, registry: dict) -> None:
code = 1
msg = _format_exc("sub_worker", e)
_write_error(buf, code, msg)
struct.pack_into("i", buf, _OFF_STATE, _TASK_DONE)
_mailbox_store_i32(state_addr, _TASK_DONE)
elif state == _SHUTDOWN:
break

Expand Down Expand Up @@ -237,12 +253,13 @@ def _chip_process_loop(
return

mailbox_addr = ctypes.addressof(ctypes.c_char.from_buffer(buf))
state_addr = mailbox_addr + _OFF_STATE
args_ptr = mailbox_addr + _OFF_ARGS
sys.stderr.write(f"[chip_process pid={os.getpid()} dev={device_id}] ready\n")
sys.stderr.flush()

while True:
state = struct.unpack_from("i", buf, _OFF_STATE)[0]
state = _mailbox_load_i32(state_addr)
if state == _TASK_READY:
callable_ptr = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0]
block_dim = struct.unpack_from("i", buf, _OFF_BLOCK_DIM)[0]
Expand All @@ -257,7 +274,7 @@ def _chip_process_loop(
code = 1
msg = _format_exc(f"chip_process dev={device_id}", e)
_write_error(buf, code, msg)
struct.pack_into("i", buf, _OFF_STATE, _TASK_DONE)
_mailbox_store_i32(state_addr, _TASK_DONE)
elif state == _CONTROL_REQUEST:
sub_cmd = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0]
code = 0
Expand All @@ -284,7 +301,7 @@ def _chip_process_loop(
code = 1
msg = _format_exc(f"chip_process dev={device_id} ctrl={int(sub_cmd)}", e)
_write_error(buf, code, msg)
struct.pack_into("i", buf, _OFF_STATE, _CONTROL_DONE)
_mailbox_store_i32(state_addr, _CONTROL_DONE)
elif state == _SHUTDOWN:
cw.finalize()
break
Expand Down Expand Up @@ -312,8 +329,9 @@ def _child_worker_loop(
``inner_worker.run(orch_fn, args, cfg)`` which opens its own scope,
runs the orch function, and drains.
"""
state_addr = _buffer_field_addr(buf, _OFF_STATE)
while True:
state = struct.unpack_from("i", buf, _OFF_STATE)[0]
state = _mailbox_load_i32(state_addr)
if state == _TASK_READY:
cid = struct.unpack_from("Q", buf, _OFF_CALLABLE)[0]
orch_fn = registry.get(int(cid))
Expand All @@ -331,7 +349,7 @@ def _child_worker_loop(
code = 1
msg = _format_exc(f"child_worker level={inner_worker.level}", e)
_write_error(buf, code, msg)
struct.pack_into("i", buf, _OFF_STATE, _TASK_DONE)
_mailbox_store_i32(state_addr, _TASK_DONE)
elif state == _SHUTDOWN:
inner_worker.close()
break
Expand Down Expand Up @@ -449,7 +467,7 @@ def _init_hierarchical(self) -> None:
for _ in range(n_sub):
shm = SharedMemory(create=True, size=MAILBOX_SIZE)
assert shm.buf is not None
struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE)
_mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE)
self._sub_shms.append(shm)

# 2. Prepare chip-worker config (L3 only — L4+ has Worker children instead)
Expand All @@ -472,14 +490,14 @@ def _init_hierarchical(self) -> None:
for _ in device_ids:
shm = SharedMemory(create=True, size=MAILBOX_SIZE)
assert shm.buf is not None
struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE)
_mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE)
self._chip_shms.append(shm)

# 3. Allocate next-level Worker child mailboxes (L4+ only).
for _ in self._next_level_workers:
shm = SharedMemory(create=True, size=MAILBOX_SIZE)
assert shm.buf is not None
struct.pack_into("i", shm.buf, _OFF_STATE, _IDLE)
_mailbox_store_i32(_buffer_field_addr(shm.buf, _OFF_STATE), _IDLE)
self._next_level_shms.append(shm)

# 4. Construct the _Worker *before* fork so the HeapRing mmap
Expand Down Expand Up @@ -584,21 +602,22 @@ def _chip_control(self, worker_id: int, sub_cmd: int, arg0: int = 0, arg1: int =
shm = self._chip_shms[worker_id]
buf = shm.buf
assert buf is not None
state_addr = _buffer_field_addr(buf, _OFF_STATE)
_write_error(buf, 0, "")
struct.pack_into("Q", buf, _OFF_CALLABLE, sub_cmd)
struct.pack_into("Q", buf, _CTRL_OFF_ARG0, arg0)
struct.pack_into("Q", buf, _CTRL_OFF_ARG1, arg1)
struct.pack_into("Q", buf, _CTRL_OFF_ARG2, arg2)
struct.pack_into("i", buf, _OFF_STATE, _CONTROL_REQUEST)
while struct.unpack_from("i", buf, _OFF_STATE)[0] != _CONTROL_DONE:
_mailbox_store_i32(state_addr, _CONTROL_REQUEST)
while _mailbox_load_i32(state_addr) != _CONTROL_DONE:
pass
error = struct.unpack_from("i", buf, _OFF_ERROR)[0]
if error != 0:
err_msg = _read_error_msg(buf)
struct.pack_into("i", buf, _OFF_STATE, _IDLE)
_mailbox_store_i32(state_addr, _IDLE)
raise RuntimeError(f"chip control command {sub_cmd} failed on worker {worker_id}: {err_msg}")
result = struct.unpack_from("Q", buf, _CTRL_OFF_RESULT)[0]
struct.pack_into("i", buf, _OFF_STATE, _IDLE)
_mailbox_store_i32(state_addr, _IDLE)
return result

def malloc(self, size: int, worker_id: int = 0) -> int:
Expand Down Expand Up @@ -705,7 +724,7 @@ def close(self) -> None:
for shm in self._sub_shms:
buf = shm.buf
assert buf is not None
struct.pack_into("i", buf, _OFF_STATE, _SHUTDOWN)
_mailbox_store_i32(_buffer_field_addr(buf, _OFF_STATE), _SHUTDOWN)
for pid in self._sub_pids:
os.waitpid(pid, 0)
for shm in self._sub_shms:
Expand All @@ -716,7 +735,7 @@ def close(self) -> None:
for shm in self._chip_shms:
buf = shm.buf
assert buf is not None
struct.pack_into("i", buf, _OFF_STATE, _SHUTDOWN)
_mailbox_store_i32(_buffer_field_addr(buf, _OFF_STATE), _SHUTDOWN)
for pid in self._chip_pids:
os.waitpid(pid, 0)
for shm in self._chip_shms:
Expand All @@ -728,7 +747,7 @@ def close(self) -> None:
for shm in self._next_level_shms:
buf = shm.buf
assert buf is not None
struct.pack_into("i", buf, _OFF_STATE, _SHUTDOWN)
_mailbox_store_i32(_buffer_field_addr(buf, _OFF_STATE), _SHUTDOWN)
for pid in self._next_level_pids:
os.waitpid(pid, 0)
for shm in self._next_level_shms:
Expand Down
Loading
Loading