Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -78,9 +78,8 @@ __attribute__((visibility("default"))) void aicpu_orchestration_entry(const Chip

// A/B layout: [num_groups, grid_k, incore_loop, tile_size, tile_size]
// C layout: [incore_loop * num_groups, tile_size, tile_size]
for (int group_idx = 0; group_idx < num_groups; group_idx++) {
PTO2_PARALLEL_FOR(group_idx, num_groups) {
PTO2_SCOPE_GUARD();

uint32_t c_elem_offset = static_cast<uint32_t>(static_cast<uint64_t>(group_idx) * group_tile_elems);
uint32_t c_view_offsets[1] = {c_elem_offset};
Tensor C_view = ext_C.view(group_shapes, c_view_offsets);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ typedef struct PTO2RuntimeOps {
PTO2Runtime *rt, const Tensor &tensor, uint32_t ndims, const uint32_t indices[], uint64_t value
);
TaskOutputTensors (*alloc_tensors)(PTO2Runtime *rt, const Arg &args);

// Parallel for iteration isolation
void (*parallel_for_begin)(PTO2Runtime *rt);
void (*parallel_iter_begin)(PTO2Runtime *rt);
void (*parallel_for_end)(PTO2Runtime *rt);
} PTO2RuntimeOps;

/**
Expand Down Expand Up @@ -255,6 +260,21 @@ static inline void pto2_rt_scope_end() {
rt->ops->scope_end(rt);
}

static inline void pto2_rt_parallel_for_begin() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_for_begin(rt);
}

static inline void pto2_rt_parallel_iter_begin() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_iter_begin(rt);
}

static inline void pto2_rt_parallel_for_end() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_for_end(rt);
}

static inline void pto2_rt_orchestration_done() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->orchestration_done(rt);
Expand Down Expand Up @@ -381,6 +401,41 @@ class PTO2ScopeGuard {
*/
#define PTO2_SCOPE() if (PTO2_SCOPE_GUARD(); true)

/**
* Combined RAII guard + loop controller for PTO2_PARALLEL_FOR.
* Construction calls parallel_for_begin; destruction calls parallel_for_end.
* next() drives per-iteration parallel_iter_begin bookkeeping.
*/
class PTO2ParallelForLoop {
public: // NOLINT(whitespace/indent)
explicit PTO2ParallelForLoop(int count) :
rt_(pto2_current_runtime()),
count_(count) {
rt_->ops->parallel_for_begin(rt_);
}
~PTO2ParallelForLoop() { rt_->ops->parallel_for_end(rt_); }
bool next(int var) {
if (var >= count_) return false;
rt_->ops->parallel_iter_begin(rt_);
return true;
}

private: // NOLINT(whitespace/indent)
PTO2Runtime *rt_;
int count_;
};

/**
* Parallel for loop with automatic iteration isolation:
* PTO2_PARALLEL_FOR(i, N) {
* submit_iter_tasks(i);
* }
* Body is a genuine for-loop body; break/continue work naturally.
*/
#define PTO2_PARALLEL_FOR(var, count) \
if (PTO2ParallelForLoop _pfl_##var(count); true) \
for (int var = 0; _pfl_##var.next(var); ++var)

// =============================================================================
// Orchestration Config
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,30 @@ void pto2_scope_end(PTO2OrchestratorState *orch) {
#endif
}

// =============================================================================
// Parallel For Iteration Isolation
// =============================================================================

/** Enter a parallel-for region: push an iteration frame for the current ring. */
void pto2_parallel_for_begin(PTO2OrchestratorState *orch) {
    if (orch->fatal) return;
    orch->tensor_map.push_iter_frame(orch->current_ring_id());
}

/**
 * Start one iteration: record the current ring's next local task id as the
 * visibility boundary so lookups stop seeing entries produced by earlier
 * iterations of this parallel-for.
 */
void pto2_parallel_iter_begin(PTO2OrchestratorState *orch) {
    if (orch->fatal) return;
    PTO2TensorMap &map = orch->tensor_map;
    // Overflowed (or empty) iter stack: skip filtering — plain for-loop behavior.
    if (map.iter_stack_top < 0 || map.iter_stack_top >= PTO2_MAX_PARALLEL_DEPTH) return;
    const uint8_t ring_id = orch->current_ring_id();
    map.set_iter_start(orch->rings[ring_id].task_allocator.next_local_id());
}

/** Leave a parallel-for region: pop the iteration frame. */
void pto2_parallel_for_end(PTO2OrchestratorState *orch) {
    if (orch->fatal) return;
    // NOTE(review): if `fatal` flips between begin and end, push/pop become
    // unbalanced — presumably acceptable because fatal abandons orchestration;
    // confirm `fatal` is sticky once set.
    orch->tensor_map.pop_iter_frame();
}

// =============================================================================
// Task Submission
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,29 @@ void pto2_scope_begin(PTO2OrchestratorState *orch);
*/
void pto2_scope_end(PTO2OrchestratorState *orch);

// =============================================================================
// Parallel For Iteration Isolation
// =============================================================================

/**
* Begin a parallel for region.
* Pushes an iteration frame onto the iter_stack.
*/
void pto2_parallel_for_begin(PTO2OrchestratorState *orch);

/**
* Begin a parallel for iteration.
* Records the current ring's next local_id as the iteration boundary.
* Does NOT create a scope — scope management is fully explicit.
*/
void pto2_parallel_iter_begin(PTO2OrchestratorState *orch);

/**
* End a parallel for region.
* Pops the iteration frame from the iter_stack.
*/
void pto2_parallel_for_end(PTO2OrchestratorState *orch);

// =============================================================================
// Task Submission
// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@ class PTO2TaskAllocator {

uint64_t heap_top() const { return heap_top_; }
uint64_t heap_capacity() const { return heap_size_; }
// Local id the allocator will assign to the next task on this ring; consumed
// by parallel-for iteration isolation as the per-iteration visibility boundary.
int32_t next_local_id() const { return local_task_id_; }

private:
// --- Task Ring ---
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ void pto2_rt_scope_begin(PTO2Runtime *rt) { pto2_scope_begin(&rt->orchestrator);

void pto2_rt_scope_end(PTO2Runtime *rt) { pto2_scope_end(&rt->orchestrator); }

// PTO2RuntimeOps trampolines: route parallel-for isolation calls to the
// orchestrator embedded in the runtime.
static void pto2_rt_parallel_for_begin(PTO2Runtime *rt) {
    pto2_parallel_for_begin(&rt->orchestrator);
}

static void pto2_rt_parallel_iter_begin(PTO2Runtime *rt) {
    pto2_parallel_iter_begin(&rt->orchestrator);
}

static void pto2_rt_parallel_for_end(PTO2Runtime *rt) {
    pto2_parallel_for_end(&rt->orchestrator);
}

void pto2_rt_orchestration_done(PTO2Runtime *rt) { pto2_orchestrator_done(&rt->orchestrator); }

static bool is_fatal_impl(PTO2Runtime *rt) { return rt->orchestrator.fatal; }
Expand Down Expand Up @@ -224,6 +230,9 @@ static const PTO2RuntimeOps s_runtime_ops = {
.get_tensor_data = pto2_get_tensor_data,
.set_tensor_data = pto2_set_tensor_data,
.alloc_tensors = alloc_tensors_impl,
.parallel_for_begin = pto2_rt_parallel_for_begin,
.parallel_iter_begin = pto2_rt_parallel_iter_begin,
.parallel_for_end = pto2_rt_parallel_for_end,
};

// =============================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,11 @@ struct PTO2RuntimeOps {
PTO2Runtime *rt, const Tensor &tensor, uint32_t ndims, const uint32_t indices[], uint64_t value
);
TaskOutputTensors (*alloc_tensors)(PTO2Runtime *rt, const Arg &args);

// Parallel for iteration isolation
void (*parallel_for_begin)(PTO2Runtime *rt);
void (*parallel_iter_begin)(PTO2Runtime *rt);
void (*parallel_for_end)(PTO2Runtime *rt);
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@
#define PTO2_MAX_SCOPE_DEPTH 64 // Maximum nesting depth
#define PTO2_SCOPE_TASKS_INIT_CAP 65536 // Initial capacity for scope task buffer

// Parallel for iteration isolation
#define PTO2_MAX_PARALLEL_DEPTH 8 // Max nesting depth for iteration filtering; deeper levels degrade gracefully

// Ready queue
#define PTO2_READY_QUEUE_SIZE 65536 // Per-shape queue size

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,10 @@ bool PTO2TensorMap::init(
for (int r = 0; r < PTO2_MAX_RING_DEPTH; r++) {
last_task_alives[r] = 0;
last_cleanup[r] = 0;
active_iter_start[r] = -1;
}
iter_stack_top = -1;
active_filter_mask = 0;

return true;
}
Expand Down
78 changes: 73 additions & 5 deletions src/a2a3/runtime/tensormap_and_ringbuffer/runtime/pto_tensormap.h
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,68 @@ struct PTO2TensorMap {
// Per-ring validity threshold (for lazy invalidation)
int32_t last_task_alives[PTO2_MAX_RING_DEPTH]; // Cached from shared memory per ring

// Per-ring active iteration threshold (lookup hot-path cache).
// active_iter_start[r] >= 0 : entries on ring r with local_id < this are filtered.
// active_iter_start[r] == -1 : no active filter on ring r.
// active_filter_mask bit r mirrors (active_iter_start[r] >= 0) for a single branch test.
int32_t active_iter_start[PTO2_MAX_RING_DEPTH]{};
uint32_t active_filter_mask{0};

// Parallel for iteration isolation stack.
// Each PTO2_PARALLEL_FOR pushes a frame; each iteration updates the frame's
// iter_start. Lookup filters entries whose local_id < iter_start on the
// matching ring. Nesting beyond PTO2_MAX_PARALLEL_DEPTH degrades gracefully
// (no filtering for the overflow level, full dependency visibility).
//
// The stack itself is the source of truth for nesting/pop semantics; lookup,
// however, consumes a denormalized per-ring cache (active_iter_start +
// active_filter_mask) so the hot path is O(1) regardless of stack depth.
// For same-ring nesting, the inner frame's threshold dominates (it is always
// >= the outer's since next_local_id() is monotonic), so the cache simply
// tracks the innermost frame per ring; on pop we restore the saved outer.
struct PTO2IterFrame {
int32_t iter_start_local_id; // -1 = before first iter; >= 0 = boundary
int32_t saved_prev_iter_start; // value of active_iter_start[ring_id] before this frame
uint8_t ring_id; // ring this parallel for operates on
};
PTO2IterFrame iter_stack[PTO2_MAX_PARALLEL_DEPTH];
int32_t iter_stack_top{-1}; // -1 = no active parallel for

// =============================================================================
// Iter-stack helpers (maintain frames + per-ring cache atomically)
// =============================================================================

// Push a frame on parallel_for_begin. The new frame has no active threshold
// yet (iter_start_local_id == -1); active_iter_start[ring] is unchanged. The
// previous value is saved in the frame so pop can restore it.
void push_iter_frame(uint8_t ring_id) {
    int32_t top = ++iter_stack_top;
    // Bounds-check both directions: top >= depth is the documented graceful
    // overflow (the extra level runs unfiltered); top < 0 can only occur after
    // unbalanced pops drove iter_stack_top below -1, and indexing iter_stack
    // with it would write out of bounds.
    if (top < 0 || top >= PTO2_MAX_PARALLEL_DEPTH) return;
    iter_stack[top] = {-1, active_iter_start[ring_id], ring_id};
}

// Update the top frame's iter_start on parallel_iter_begin.
void set_iter_start(int32_t iter_start_local_id) {
int32_t top = iter_stack_top;
if (top < 0 || top >= PTO2_MAX_PARALLEL_DEPTH) return;
uint8_t ring_id = iter_stack[top].ring_id;
iter_stack[top].iter_start_local_id = iter_start_local_id;
active_iter_start[ring_id] = iter_start_local_id;
active_filter_mask |= (1u << ring_id);
}

// Pop a frame on parallel_for_end, restoring the outer threshold.
void pop_iter_frame() {
    // Unbalanced pop (e.g. begin skipped but end executed): nothing to
    // restore, and decrementing below -1 would later let push_iter_frame
    // compute a negative stack index.
    if (iter_stack_top < 0) return;
    int32_t top = iter_stack_top--;
    if (top >= PTO2_MAX_PARALLEL_DEPTH) return; // overflow frame was never stored
    const PTO2IterFrame &frame = iter_stack[top];
    uint8_t ring_id = frame.ring_id;
    active_iter_start[ring_id] = frame.saved_prev_iter_start;
    if (frame.saved_prev_iter_start < 0) {
        active_filter_mask &= ~(1u << ring_id);
    }
}

// Per-ring cleanup progress (for periodic cleanup_retired)
int32_t last_cleanup[PTO2_MAX_RING_DEPTH]{};

Expand Down Expand Up @@ -320,9 +382,9 @@ struct PTO2TensorMap {
#if PTO2_TENSORMAP_PROFILING
chain_len++;
#endif
// Skip stale entries (no chain truncation — entries from different
// rings can be interleaved, so a stale entry from one ring does NOT
// imply subsequent entries from other rings are also stale)
// Skip entries that are either stale (producer retired) or from prior
// iterations of the current parallel-for. Both checks are unified in
// entry_valid() to avoid extracting ring/local twice.
if (!entry_valid(*cur_entry)) {
cur_entry = next_entry;
continue;
Expand Down Expand Up @@ -450,10 +512,16 @@ struct PTO2TensorMap {
}

/**
 * Check if entry is visible in the current execution context:
 * 1. Producer has not retired (not stale).
 * 2. Not from a prior iteration of the active parallel-for on the same ring.
 */
bool entry_valid(const PTO2TensorMapEntry &entry) const {
    uint8_t ring = entry.producer_task_id.ring();
    int32_t local = static_cast<int32_t>(entry.producer_task_id.local());
    // Stale: the producing task has already retired on its ring.
    if (local < last_task_alives[ring]) return false;
    // Filtered: hidden behind the active parallel-for iteration boundary.
    // The leading `active_filter_mask &&` is a cheap fast path for the common
    // no-filter case. NOTE(review): `active_filter_mask >> ring` assumes
    // ring < 32 — confirm PTO2_MAX_RING_DEPTH <= 32.
    if (active_filter_mask && ((active_filter_mask >> ring) & 1u) && local < active_iter_start[ring]) return false;
    return true;
}

void remove_entry(PTO2TensorMapEntry &entry) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,11 @@ typedef struct PTO2RuntimeOps {
PTO2Runtime *rt, const Tensor &tensor, uint32_t ndims, const uint32_t indices[], uint64_t value
);
TaskOutputTensors (*alloc_tensors)(PTO2Runtime *rt, const Arg &args);

// Parallel for iteration isolation
void (*parallel_for_begin)(PTO2Runtime *rt);
void (*parallel_iter_begin)(PTO2Runtime *rt);
void (*parallel_for_end)(PTO2Runtime *rt);
} PTO2RuntimeOps;

/**
Expand Down Expand Up @@ -255,6 +260,21 @@ static inline void pto2_rt_scope_end() {
rt->ops->scope_end(rt);
}

static inline void pto2_rt_parallel_for_begin() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_for_begin(rt);
}

static inline void pto2_rt_parallel_iter_begin() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_iter_begin(rt);
}

static inline void pto2_rt_parallel_for_end() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->parallel_for_end(rt);
}

static inline void pto2_rt_orchestration_done() {
PTO2Runtime *rt = pto2_current_runtime();
rt->ops->orchestration_done(rt);
Expand Down Expand Up @@ -381,6 +401,41 @@ class PTO2ScopeGuard {
*/
#define PTO2_SCOPE() if (PTO2_SCOPE_GUARD(); true)

/**
* Combined RAII guard + loop controller for PTO2_PARALLEL_FOR.
* Construction calls parallel_for_begin; destruction calls parallel_for_end.
* next() drives per-iteration parallel_iter_begin bookkeeping.
*/
class PTO2ParallelForLoop {
public: // NOLINT(whitespace/indent)
explicit PTO2ParallelForLoop(int count) :
rt_(pto2_current_runtime()),
count_(count) {
rt_->ops->parallel_for_begin(rt_);
}
~PTO2ParallelForLoop() { rt_->ops->parallel_for_end(rt_); }
bool next(int var) {
if (var >= count_) return false;
rt_->ops->parallel_iter_begin(rt_);
return true;
}

private: // NOLINT(whitespace/indent)
PTO2Runtime *rt_;
int count_;
};

/**
* Parallel for loop with automatic iteration isolation:
* PTO2_PARALLEL_FOR(i, N) {
* submit_iter_tasks(i);
* }
* Body is a genuine for-loop body; break/continue work naturally.
*/
#define PTO2_PARALLEL_FOR(var, count) \
if (PTO2ParallelForLoop _pfl_##var(count); true) \
for (int var = 0; _pfl_##var.next(var); ++var)

// =============================================================================
// Orchestration Config
// =============================================================================
Expand Down
Loading
Loading