diff --git a/examples/workers/l3/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp b/examples/workers/l3/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp new file mode 100644 index 000000000..3289f9bfb --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/kernels/aic/kernel_local_linear.cpp @@ -0,0 +1,113 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * FFN local linear kernel (AIC) — partial_local = x_shard @ w_shard. + * + * Stage 1 of the 2-stage FFN tensor-parallel demo: each rank computes its + * local matmul into per-rank device memory (partial_local). Stage 2 + * (kernel_allreduce_sum.cpp) then sums partial_local across ranks into y. + * + * args layout (ChipStorageTaskArgs — see ffn_local_orch.cpp): + * tensor(0) = x_shard INPUT (M x K, host-backed) + * tensor(1) = w_shard INPUT (K x N, host-backed) + * tensor(2) = partial_local OUTPUT_EXISTING (M x N, per-rank device mem) + */ + +#include + +#include +#include +#include + +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +template +AICORE constexpr inline T CeilAlign(T num_1, T num_2) { + if (num_2 == 0) { + return 0; + } + return (num_1 + num_2 - 1) / num_2 * num_2; +} + +static __aicore__ void local_linear_impl(__gm__ Tensor *x_tensor, __gm__ Tensor *w_tensor, __gm__ Tensor *out_tensor) { + __gm__ float *x_ptr = reinterpret_cast<__gm__ float *>(x_tensor->buffer.addr) + x_tensor->start_offset; + __gm__ float *w_ptr = reinterpret_cast<__gm__ float *>(w_tensor->buffer.addr) + w_tensor->start_offset; + __gm__ float *out_ptr = reinterpret_cast<__gm__ float *>(out_tensor->buffer.addr) + out_tensor->start_offset; + + constexpr int TILE = 64; + constexpr int block_align = C0_SIZE_BYTE / sizeof(float); + constexpr int M = CeilAlign(TILE, 16); + constexpr int K = CeilAlign(TILE, block_align); + constexpr int N = CeilAlign(TILE, block_align); + + using GlobalData = + GlobalTensor, Stride<1 * TILE * TILE, 1 * TILE * TILE, TILE * TILE, TILE, 1>>; + using TileMatA = Tile; + using TileMatB = Tile; + using LeftTile = TileLeft; + using RightTile = TileRight; + using AccTile = TileAcc; + + GlobalData x_global(x_ptr); + GlobalData w_global(w_ptr); + GlobalData out_global(out_ptr); + + TileMatA x_mat; + TileMatB w_mat; + TASSIGN(x_mat, 0x0); + TASSIGN(w_mat, 0x20000); + + LeftTile x_tile; + RightTile w_tile; + AccTile out_tile; + TASSIGN(x_tile, 0x0); + TASSIGN(w_tile, 0x0); + TASSIGN(out_tile, 0x0); + + TLOAD(x_mat, x_global); + TLOAD(w_mat, w_global); + + set_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + wait_flag(PIPE_MTE2, PIPE_MTE1, EVENT_ID0); + + TMOV(x_tile, x_mat); + TMOV(w_tile, w_mat); + + set_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + wait_flag(PIPE_MTE1, PIPE_M, EVENT_ID0); + + TMATMUL(out_tile, x_tile, w_tile); + + set_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + wait_flag(PIPE_M, PIPE_FIX, EVENT_ID0); + + TSTORE(out_global, out_tile); + + set_flag(PIPE_FIX, PIPE_S, EVENT_ID7); + wait_flag(PIPE_FIX, PIPE_S, EVENT_ID7); +} + +extern "C" __aicore__ void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *x_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *w_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *out_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + local_linear_impl(x_tensor, w_tensor, out_tensor); +} diff --git a/examples/workers/l3/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp b/examples/workers/l3/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp new file mode 100644 index 000000000..f16fce1b4 --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/kernels/aiv/kernel_allreduce_sum.cpp @@ -0,0 +1,150 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * AllReduce-sum kernel — publish/notify/wait/accumulate over partial_local. + * + * Stage 2 of the FFN tensor-parallel demo. Each rank holds a per-rank + * ``partial_local`` (= x_shard @ w_shard, written by the AIC matmul kernel + * in the previous stage); we sum it across ranks into ``y``. Cross-rank + * exchange goes through ``scratch``, which is laid out as: + * + * [ mailbox: nranks * M*N floats | signal tail: nranks int32 slots ] + * + * Each rank publishes its partial_local into peer's mailbox slot + * mailbox[my_rank], notifies the peer's signal[my_rank], waits until its + * own signal tail has been bumped by every peer, then sums local + + * mailbox[peer] for each peer into sum_tile and stores into y. + * + * args layout (ChipStorageTaskArgs — see allreduce_sum_orch.cpp): + * tensor(0) = partial_local INPUT (M x N, per-rank device mem) + * tensor(1) = y OUTPUT_EXISTING (M x N, host-backed) + * tensor(2) = scratch INOUT (HCCL-window slot, cross-rank) + * scalar(0) = nranks + * scalar(1) = CommContext device pointer + */ + +#include + +#include +#include +#include +#include + +#include "platform_comm/comm_context.h" +#include "tensor.h" + +using namespace pto; + +#ifndef __gm__ +#define __gm__ +#endif + +#ifndef __aicore__ +#define __aicore__ [aicore] +#endif + +static constexpr int kRows = 64; +static constexpr int kCols = 64; +static constexpr int kElemsPerPartial = kRows * kCols; +static constexpr int kMaxSupportedRanks = 16; + +template +AICORE inline __gm__ T *CommRemotePtr(__gm__ CommContext *ctx, __gm__ T *local_ptr, int peer_rank) { + uint64_t local_base = ctx->windowsIn[ctx->rankId]; + uint64_t offset = reinterpret_cast(local_ptr) - local_base; + return reinterpret_cast<__gm__ T *>(ctx->windowsIn[peer_rank] + offset); +} + +extern "C" __aicore__ __attribute__((always_inline)) void kernel_entry(__gm__ int64_t *args) { + __gm__ Tensor *partial_local_tensor = reinterpret_cast<__gm__ Tensor *>(args[0]); + __gm__ Tensor *y_tensor = reinterpret_cast<__gm__ Tensor *>(args[1]); + __gm__ Tensor *scratch_tensor = reinterpret_cast<__gm__ Tensor *>(args[2]); + int nranks = static_cast(args[3]); + __gm__ CommContext *comm_ctx = reinterpret_cast<__gm__ CommContext *>(args[4]); + + __gm__ float *partial_local_ptr = + reinterpret_cast<__gm__ float *>(partial_local_tensor->buffer.addr) + partial_local_tensor->start_offset; + __gm__ float *y_ptr = reinterpret_cast<__gm__ float *>(y_tensor->buffer.addr) + y_tensor->start_offset; + __gm__ float *mailbox_ptr = + reinterpret_cast<__gm__ float *>(scratch_tensor->buffer.addr) + scratch_tensor->start_offset; + // Signal slots sit at the tail of the scratch buffer, after nranks * M*N floats. + __gm__ int32_t *signal_base = reinterpret_cast<__gm__ int32_t *>(mailbox_ptr + nranks * kElemsPerPartial); + + using MatrixGlobal = GlobalTensor, Stride<1, 1, 1, kCols, 1>>; + using MatrixTile = Tile; + + int my_rank = static_cast(comm_ctx->rankId); + + if (nranks <= 0 || nranks > kMaxSupportedRanks) { + pipe_barrier(PIPE_ALL); + return; + } + + MatrixGlobal partial_local_global(partial_local_ptr); + + MatrixTile sum_tile(kRows, kCols); + MatrixTile tmp_tile(kRows, kCols); + MatrixTile staging_tile(kRows, kCols); + TASSIGN(sum_tile, 0x0); + TASSIGN(tmp_tile, 0x10000); + TASSIGN(staging_tile, 0x20000); + + TLOAD(sum_tile, partial_local_global); + pipe_barrier(PIPE_ALL); + + // Phase 1: publish my partial_local into every peer's mailbox slot mailbox[my_rank]. + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ float *remote_mailbox_base = CommRemotePtr(comm_ctx, mailbox_ptr, peer); + __gm__ float *remote_slot_ptr = remote_mailbox_base + my_rank * kElemsPerPartial; + MatrixGlobal remote_slot(remote_slot_ptr); + pto::comm::TPUT(remote_slot, partial_local_global, staging_tile); + } + pipe_barrier(PIPE_ALL); + + // Phase 2: notify peer's signal[my_rank] slot, then wait for every peer to notify ours. + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ int32_t *remote_signal_slot = CommRemotePtr(comm_ctx, signal_base + my_rank, peer); + pto::comm::Signal remote_sig(remote_signal_slot); + pto::comm::TNOTIFY(remote_sig, (int32_t)1, pto::comm::NotifyOp::AtomicAdd); + } + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + pto::comm::Signal local_sig(signal_base + peer); + pto::comm::TWAIT(local_sig, (int32_t)1, pto::comm::WaitCmp::GE); + } + pipe_barrier(PIPE_ALL); + + // Phase 3: accumulate every peer's mailbox slot into sum_tile. + for (int peer = 0; peer < nranks; ++peer) { + if (peer == my_rank) { + continue; + } + __gm__ float *mailbox_slot_ptr = mailbox_ptr + peer * kElemsPerPartial; + MatrixGlobal mailbox_slot(mailbox_slot_ptr); + TLOAD(tmp_tile, mailbox_slot); + pipe_barrier(PIPE_ALL); + TADD(sum_tile, sum_tile, tmp_tile); + pipe_barrier(PIPE_ALL); + } + + // Phase 4: store sum_tile into y (per-rank device output). + MatrixGlobal y_global(y_ptr); + TSTORE(y_global, sum_tile); + pipe_barrier(PIPE_ALL); +} diff --git a/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp b/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp new file mode 100644 index 000000000..4c94ae2cf --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/allreduce_sum_orch.cpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * AllReduce-sum orchestration — AIV publish/notify/wait/accumulate shim. + * + * tensor(0) partial_local INPUT (per-rank device mem; producer = ffn_local kernel) + * tensor(1) y OUTPUT_EXISTING (per-rank host-backed) + * tensor(2) scratch INOUT (HCCL-window slot — mailbox + signal tail) + * scalar(0) nranks + * scalar(1) CommContext device pointer + */ + +#include + +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +allreduce_sum_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{.expected_arg_count = 5}; // 3 tensors + 2 scalars +} + +__attribute__((visibility("default"))) void allreduce_sum_orchestration(const ChipStorageTaskArgs &orch_args) { + Tensor partial_local = from_tensor_arg(orch_args.tensor(0)); + Tensor y = from_tensor_arg(orch_args.tensor(1)); + Tensor scratch = from_tensor_arg(orch_args.tensor(2)); + + Arg params; + params.add_input(partial_local); + params.add_output(y); + params.add_inout(scratch); + params.add_scalar(orch_args.scalar(0)); // nranks + params.add_scalar(orch_args.scalar(1)); // CommContext + pto2_rt_submit_aiv_task(1, params); +} + +} // extern "C" diff --git a/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp b/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp new file mode 100644 index 000000000..b8b4859f8 --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/kernels/orchestration/ffn_local_orch.cpp @@ -0,0 +1,43 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ +/** + * FFN local linear orchestration — AIC matmul shim. + * + * tensor(0) x_shard INPUT + * tensor(1) w_shard INPUT + * tensor(2) partial_local OUTPUT_EXISTING + */ + +#include + +#include "pto_orchestration_api.h" + +extern "C" { + +__attribute__((visibility("default"))) PTO2OrchestrationConfig +ffn_local_orchestration_config(const ChipStorageTaskArgs &orch_args) { + (void)orch_args; + return PTO2OrchestrationConfig{.expected_arg_count = 3}; +} + +__attribute__((visibility("default"))) void ffn_local_orchestration(const ChipStorageTaskArgs &orch_args) { + Tensor x_shard = from_tensor_arg(orch_args.tensor(0)); + Tensor w_shard = from_tensor_arg(orch_args.tensor(1)); + Tensor partial_local = from_tensor_arg(orch_args.tensor(2)); + + Arg params; + params.add_input(x_shard); + params.add_input(w_shard); + params.add_output(partial_local); + pto2_rt_submit_aic_task(0, params); +} + +} // extern "C" diff --git a/examples/workers/l3/ffn_tp_parallel/main.py b/examples/workers/l3/ffn_tp_parallel/main.py new file mode 100644 index 000000000..2ab2a0020 --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/main.py @@ -0,0 +1,301 @@ +#!/usr/bin/env python3 +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""End-to-end FFN tensor-parallel demo — two-stage orchestration. + +Per rank, in one orch_fn: + + Stage 1 (AIC matmul): partial_local = x_shard @ w_shard + Stage 2 (AIV reduce): y = sum_over_ranks(partial_local) + +partial_local is a per-rank torch.share_memory_() tensor; it is the OUTPUT of +stage 1 and the INPUT of stage 2. Because both submits see the same +``buffer.addr``, the framework's TensorMap discovers the producer/consumer +edge automatically — no manual barriers in Python. Cross-rank exchange in +stage 2 still goes through a per-chip ``scratch`` HCCL-window buffer (laid +out as ``[mailbox: nranks * M*N floats | signal tail: nranks int32 slots]``). + +Hardware only. Run: + python examples/workers/l3/ffn_tp_parallel/main.py -d 0-1 +""" + +from __future__ import annotations + +import argparse +import os +import sys + +# Workaround for the duplicate-libomp abort when homebrew numpy and pip torch +# coexist in one macOS process. Harmless on Linux. Must be set before +# ``import torch``. See docs/macos-libomp-collision.md. +os.environ.setdefault("KMP_DUPLICATE_LIB_OK", "TRUE") + +import torch # noqa: E402 +from simpler.task_interface import ( # noqa: E402 + ArgDirection, + ChipBootstrapConfig, + ChipBufferSpec, + ChipCallable, + ChipCallConfig, + ChipCommBootstrapConfig, + ChipContext, + ContinuousTensor, + CoreCallable, + DataType, + TaskArgs, + TensorArgType, +) +from simpler.worker import Worker # noqa: E402 + +from simpler_setup.elf_parser import extract_text_section # noqa: E402 +from simpler_setup.kernel_compiler import KernelCompiler # noqa: E402 +from simpler_setup.pto_isa import ensure_pto_isa_root # noqa: E402 +from simpler_setup.torch_interop import make_tensor_arg # noqa: E402 + +HERE = os.path.dirname(os.path.abspath(__file__)) + +# Must match TILE / kRows / kCols in the AIC and AIV kernels. +M = 64 +K = 64 +N = 64 +DTYPE_NBYTES = 4 # float32 +PARTIAL_NBYTES = M * N * DTYPE_NBYTES + + +def parse_device_range(spec: str) -> list[int]: + if "-" in spec: + lo, hi = (int(x) for x in spec.split("-")) + ids = list(range(lo, hi + 1)) + else: + ids = [int(spec)] + if len(ids) != 2: + raise ValueError(f"ffn_tp_parallel needs exactly 2 devices, got {ids}") + return ids + + +def _kernel_compiler(platform: str) -> tuple[KernelCompiler, str, list[str], list[str]]: + kc = KernelCompiler(platform=platform) + runtime = "tensormap_and_ringbuffer" + pto_isa_root = ensure_pto_isa_root(clone_protocol="https") + include_dirs = kc.get_orchestration_include_dirs(runtime) + # The allreduce_sum kernel resolves CommContext from + # "platform_comm/comm_context.h" under src/common/. + kernel_include_dirs = list(include_dirs) + [str(kc.project_root / "src" / "common")] + return kc, pto_isa_root, list(include_dirs), kernel_include_dirs + + +def build_ffn_local_callable(platform: str) -> ChipCallable: + """AIC matmul: x_shard @ w_shard -> partial_local.""" + kc, pto_isa_root, _, kernel_include_dirs = _kernel_compiler(platform) + runtime = "tensormap_and_ringbuffer" + + kernel_bytes = kc.compile_incore( + source_path=os.path.join(HERE, "kernels/aic/kernel_local_linear.cpp"), + core_type="aic", + pto_isa_root=pto_isa_root, + extra_include_dirs=kernel_include_dirs, + ) + kernel_bytes = extract_text_section(kernel_bytes) + + orch_bytes = kc.compile_orchestration( + runtime_name=runtime, + source_path=os.path.join(HERE, "kernels/orchestration/ffn_local_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.IN, ArgDirection.OUT], + func_name="ffn_local_orchestration", + binary=orch_bytes, + children=[(0, core_callable)], + ) + + +def build_allreduce_sum_callable(platform: str) -> ChipCallable: + """AIV cross-rank sum (4-phase publish/notify/wait/accumulate).""" + kc, pto_isa_root, _, kernel_include_dirs = _kernel_compiler(platform) + runtime = "tensormap_and_ringbuffer" + + kernel_bytes = kc.compile_incore( + source_path=os.path.join(HERE, "kernels/aiv/kernel_allreduce_sum.cpp"), + core_type="aiv", + pto_isa_root=pto_isa_root, + extra_include_dirs=kernel_include_dirs, + ) + kernel_bytes = extract_text_section(kernel_bytes) + + orch_bytes = kc.compile_orchestration( + runtime_name=runtime, + source_path=os.path.join(HERE, "kernels/orchestration/allreduce_sum_orch.cpp"), + ) + core_callable = CoreCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + binary=kernel_bytes, + ) + return ChipCallable.build( + signature=[ArgDirection.IN, ArgDirection.OUT, ArgDirection.INOUT], + func_name="allreduce_sum_orchestration", + binary=orch_bytes, + children=[(1, core_callable)], + ) + + +def make_rank_inputs(rank: int) -> tuple[torch.Tensor, torch.Tensor]: + """Match golden formula from PR #522 (golden.py).""" + x = (torch.arange(M * K, dtype=torch.float32).reshape(M, K) + float(rank) * 0.25) / 32.0 + w = (torch.arange(K * N, dtype=torch.float32).reshape(K, N) + float(rank + 1) * 0.5) / 48.0 + return x, w + + +def run(device_ids: list[int]) -> int: + nranks = len(device_ids) + # scratch = mailbox(nranks * M*N floats) + signal tail (nranks int32). + scratch_count = nranks * M * N + scratch_nbytes = scratch_count * DTYPE_NBYTES + nranks * 4 + window_size = max(scratch_nbytes, 4 * 1024) + + rootinfo_path = f"/tmp/pto_ffn_tp_parallel_rootinfo_{os.getpid()}.bin" + try: + os.unlink(rootinfo_path) + except FileNotFoundError: + pass + + print(f"[ffn_tp_parallel] devices={device_ids} nranks={nranks} M={M} K={K} N={N}") + + # Per-rank host tensors via torch.share_memory_(): inputs, partial_local + # (stage1 output / stage2 input), and final y (stage2 output). + host_x_shards = [make_rank_inputs(r)[0].share_memory_() for r in range(nranks)] + host_w_shards = [make_rank_inputs(r)[1].share_memory_() for r in range(nranks)] + host_partial = [torch.zeros(M, N, dtype=torch.float32).share_memory_() for _ in range(nranks)] + host_y = [torch.zeros(M, N, dtype=torch.float32).share_memory_() for _ in range(nranks)] + + cfgs = [ + ChipBootstrapConfig( + comm=ChipCommBootstrapConfig( + rank=rank, + nranks=nranks, + rootinfo_path=rootinfo_path, + window_size=window_size, + ), + buffers=[ + ChipBufferSpec( + name="scratch", + dtype="float32", + count=scratch_count, + nbytes=scratch_nbytes, + ), + ], + ) + for rank in range(nranks) + ] + + print("[ffn_tp_parallel] compiling kernels...") + ffn_local_cc = build_ffn_local_callable("a2a3") + allreduce_cc = build_allreduce_sum_callable("a2a3") + + worker = Worker( + level=3, + platform="a2a3", + runtime="tensormap_and_ringbuffer", + device_ids=device_ids, + num_sub_workers=0, + chip_bootstrap_configs=cfgs, + ) + + try: + print("[ffn_tp_parallel] init worker (forks chip children + bootstraps HCCL)...") + worker.init() + + contexts: list[ChipContext] = worker.chip_contexts + assert len(contexts) == nranks + for i, ctx in enumerate(contexts): + print( + f"[ffn_tp_parallel] chip {i}: device={ctx.device_id} rank={ctx.rank}/{ctx.nranks} " + f"window=[0x{ctx.local_window_base:x} +{ctx.actual_window_size}B] " + f"scratch=0x{ctx.buffer_ptrs['scratch']:x}" + ) + + def orch_fn(orch, _args, cfg): + for i, ctx in enumerate(contexts): + # Stage 1: AIC matmul. partial_local is OUTPUT_EXISTING here; + # the framework records its buffer.addr as a producer. + a1 = TaskArgs() + a1.add_tensor(make_tensor_arg(host_x_shards[i]), TensorArgType.INPUT) + a1.add_tensor(make_tensor_arg(host_w_shards[i]), TensorArgType.INPUT) + a1.add_tensor(make_tensor_arg(host_partial[i]), TensorArgType.OUTPUT_EXISTING) + orch.submit_next_level(ffn_local_cc, a1, cfg, worker=i) + + # Stage 2: AIV cross-rank sum. Tagging partial_local INPUT + # with the same buffer.addr makes TensorMap auto-link this + # task as a consumer of stage 1, no explicit barrier needed. + a2 = TaskArgs() + a2.add_tensor(make_tensor_arg(host_partial[i]), TensorArgType.INPUT) + a2.add_tensor(make_tensor_arg(host_y[i]), TensorArgType.OUTPUT_EXISTING) + a2.add_tensor( + ContinuousTensor.make( + data=ctx.buffer_ptrs["scratch"], + shapes=(scratch_count,), + dtype=DataType.FLOAT32, + child_memory=True, + ), + TensorArgType.INOUT, + ) + a2.add_scalar(ctx.nranks) + a2.add_scalar(ctx.device_ctx) + orch.submit_next_level(allreduce_cc, a2, cfg, worker=i) + + print("[ffn_tp_parallel] running 2-chip 2-stage DAG...") + worker.run(orch_fn, args=None, config=ChipCallConfig()) + + # Golden: every rank's y should equal sum over r of x_shard[r] @ w_shard[r]. + expected = torch.zeros(M, N, dtype=torch.float32) + for r in range(nranks): + x, w = make_rank_inputs(r) + expected += x @ w + + # Match scene_test's _compare_outputs: torch.allclose(rtol, atol), + # which evaluates |a-e| <= atol + rtol*|e|. #522's golden.py uses + # rtol=atol=1e-4. + rtol, atol = 1e-4, 1e-4 + ok = True + for i in range(nranks): + diff = torch.abs(host_y[i] - expected) + rel = diff / torch.clamp(torch.abs(expected), min=1e-12) + print(f"[ffn_tp_parallel] chip {i}: max|y-exp|={float(diff.max()):.3e} max_rel={float(rel.max()):.3e}") + if not torch.allclose(host_y[i], expected, rtol=rtol, atol=atol): + ok = False + for j in range(min(4, M * N)): + flat_y = host_y[i].flatten() + flat_e = expected.flatten() + print(f" y[{j}]={float(flat_y[j])!r} expected={float(flat_e[j])!r}") + + if not ok: + print("[ffn_tp_parallel] golden check FAILED") + return 1 + print("[ffn_tp_parallel] all ranks matched golden ✅") + return 0 + finally: + worker.close() + try: + os.unlink(rootinfo_path) + except FileNotFoundError: + pass + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter) + parser.add_argument("-d", "--device", default="0-1", help="Device range, e.g. '0-1'. Two chips required.") + cli = parser.parse_args() + return run(parse_device_range(cli.device)) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/examples/workers/l3/ffn_tp_parallel/test_ffn_tp_parallel.py b/examples/workers/l3/ffn_tp_parallel/test_ffn_tp_parallel.py new file mode 100644 index 000000000..6e66f1867 --- /dev/null +++ b/examples/workers/l3/ffn_tp_parallel/test_ffn_tp_parallel.py @@ -0,0 +1,26 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- +"""Hardware ST for examples/workers/l3/ffn_tp_parallel.""" + +import os +from importlib.machinery import SourceFileLoader + +import pytest + +_main = SourceFileLoader("ffn_tp_parallel_main", os.path.join(os.path.dirname(__file__), "main.py")).load_module() +run = _main.run + + +@pytest.mark.requires_hardware +@pytest.mark.platforms(["a2a3"]) +@pytest.mark.runtime("tensormap_and_ringbuffer") +@pytest.mark.device_count(2) +def test_ffn_tp_parallel(st_device_ids): + rc = run([int(d) for d in st_device_ids]) + assert rc == 0