Merged
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
@@ -137,12 +137,12 @@ jobs:
run: pip install -v ./runtime

- name: Test system tests
run: pytest tests/st -v --device=$DEVICE_ID --forked --precompile-workers=128 --pto-isa-commit=6eecebc5
run: pytest tests/st -v --device=$DEVICE_ID --forked --precompile-workers=128 --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352

- name: Test swimlane output
run: |
pytest tests/st/runtime/test_perf_swimlane.py \
-v --device=$DEVICE_ID --platform=a2a3 --runtime-profiling --forked --pto-isa-commit=6eecebc5
-v --device=$DEVICE_ID --platform=a2a3 --runtime-profiling --forked --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352

system-tests-a5sim:
runs-on: ubuntu-latest
@@ -185,13 +185,13 @@ jobs:
run: pip install -v ./runtime

- name: Test A5 system tests (simulator)
run: pytest tests/st/runtime/test_assemble.py tests/st/runtime/test_mscatter.py tests/st/runtime/test_qwen3_decode_scope3_mixed.py tests/st/runtime/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=6eecebc5
run: pytest tests/st/runtime/test_assemble.py tests/st/runtime/test_mscatter.py tests/st/runtime/test_qwen3_decode_scope3_mixed.py tests/st/runtime/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352

- name: Test A5 cross-core system tests (simulator)
run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=6eecebc5
run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352

- name: Test A2A3 cross-core system tests (simulator)
run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=6eecebc5
run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352

fuzz-tests-sim:
runs-on: ubuntu-latest
4 changes: 2 additions & 2 deletions docs/en/user/02-operation_reference.md
@@ -24,7 +24,7 @@ Auto-selects between tensor and tile implementation based on input type.
| `matmul_acc` | `(acc: T, lhs: T, rhs: T, a_trans=False, b_trans=False) -> T` | Matrix multiply with accumulation: `acc += lhs @ rhs` |
| `row_max` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Row-wise max (tile path requires `tmp_tile`) |
| `row_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Row-wise sum (tile path requires `tmp_tile`) |
| `col_sum` | `(input: Tile, tmp_tile: Tile) -> Tile` | Column-wise sum (tile-only, requires `tmp_tile`) |
| `col_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Column-wise sum (tile-only). Passing `tmp_tile` activates binary-tree reduction; omitting it uses sequential reduction. |
| `col_max` | `(input: Tile) -> Tile` | Column-wise max (tile-only) |
| `col_min` | `(input: Tile) -> Tile` | Column-wise min (tile-only) |
| `rsqrt` | `(input: T, high_precision: bool = False) -> T` | Reciprocal square root; `high_precision=True` selects the high-precision path (tensor input only — tile callers must use `pl.tile.rsqrt(src, tmp=...)`) |
@@ -134,7 +134,7 @@ Transfer data between memory hierarchy levels.
| `row_max` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise max (requires tmp buffer) |
| `row_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise sum (requires tmp buffer) |
| `row_min` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise min (requires tmp buffer) |
| `col_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Column-wise sum (requires tmp buffer) |
| `col_sum` | `(tile: Tile, tmp_tile: Tile \| None = None) -> Tile` | Column-wise sum. Passing `tmp_tile` activates binary-tree reduction; omitting it uses sequential reduction. |
| `col_max` | `(tile: Tile) -> Tile` | Column-wise max |
| `col_min` | `(tile: Tile) -> Tile` | Column-wise min |
| `sum` | `(tile: Tile, axis: int, keepdim: bool = False) -> Tile` | Sum along axis |
4 changes: 2 additions & 2 deletions docs/zh-cn/user/02-operation_reference.md
@@ -24,7 +24,7 @@
| `matmul_acc` | `(acc: T, lhs: T, rhs: T, a_trans=False, b_trans=False) -> T` | 带累加的矩阵乘法:`acc += lhs @ rhs` |
| `row_max` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 行最大值(tile 路径需要 `tmp_tile`) |
| `row_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 行求和(tile 路径需要 `tmp_tile`) |
| `col_sum` | `(input: Tile, tmp_tile: Tile) -> Tile` | 列求和(仅 tile,需要 `tmp_tile` |
| `col_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 列求和(仅 tile);传入 `tmp_tile` 启用二叉树归约,省略时使用顺序归约 |
| `col_max` | `(input: Tile) -> Tile` | 列最大值(仅 tile) |
| `col_min` | `(input: Tile) -> Tile` | 列最小值(仅 tile) |
| `rsqrt` | `(input: T, high_precision: bool = False) -> T` | 倒数平方根;`high_precision=True` 选择高精度路径(仅对 Tensor 输入生效,Tile 路径需要改用 `pl.tile.rsqrt(src, tmp=...)`) |
@@ -129,7 +129,7 @@
| `row_max` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行最大值(需要临时缓冲区) |
| `row_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行求和(需要临时缓冲区) |
| `row_min` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行最小值(需要临时缓冲区) |
| `col_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 列求和(需要临时缓冲区) |
| `col_sum` | `(tile: Tile, tmp_tile: Tile \| None = None) -> Tile` | 列求和;传入 `tmp_tile` 启用二叉树归约,省略时使用顺序归约 |
| `col_max` | `(tile: Tile) -> Tile` | 列最大值 |
| `col_min` | `(tile: Tile) -> Tile` | 列最小值 |
| `sum` | `(tile: Tile, axis: int, keepdim: bool = False) -> Tile` | 沿轴求和 |
11 changes: 8 additions & 3 deletions python/pypto/ir/op/tile_ops.py
@@ -1806,21 +1806,26 @@ def row_min(tile: Expr, tmp_tile: Expr, span: Span | None = None) -> Call:
return _ir_core.create_op_call("tile.row_min", [tile, tmp_tile], {}, actual_span)


def col_sum(tile: Expr, tmp_tile: Expr, span: Span | None = None) -> Call:
def col_sum(tile: Expr, tmp_tile: Expr | None = None, span: Span | None = None) -> Call:
"""Column-wise sum reduction of a tile (reduces along axis=0, maps to TCOLSUM).

Output shape is [1, N] for an [M, N] input.

Passing ``tmp_tile`` activates the binary-tree reduction path (O(log M) depth,
better precision). Omitting ``tmp_tile`` emits the sequential reduction path.

Args:
tile: Input tile (TileType [M, N])
tmp_tile: Temporary tile (TileType, same shape as input)
tmp_tile: Optional scratch tile (TileType, same shape/dtype as input) that
activates binary-tree reduction.
span: Optional source span for debugging (auto-captured if not provided)

Returns:
Call expression for column-wise sum reduction (TileType [1, N])
"""
actual_span = _get_span_or_capture(span)
return _ir_core.create_op_call("tile.col_sum", [tile, tmp_tile], {}, actual_span)
args = [tile] if tmp_tile is None else [tile, tmp_tile]
return _ir_core.create_op_call("tile.col_sum", args, {}, actual_span)


def col_max(tile: Expr, span: Span | None = None) -> Call:
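The docstring above contrasts a sequential path with a binary-tree path of O(log M) depth. The following is a minimal pure-Python sketch of that contrast on plain lists of rows. It is an illustration only: the function names are hypothetical, and it does not claim to reproduce what the TCOLSUM instruction actually does with `tmp_tile`.

```python
def col_sum_sequential(rows):
    """Fold rows into an accumulator one at a time (M - 1 dependent additions)."""
    acc = list(rows[0])
    for row in rows[1:]:
        acc = [a + b for a, b in zip(acc, row)]
    return [acc]  # shape [1, N] for an [M, N] input


def col_sum_binary_tree(rows):
    """Combine rows pairwise, level by level; returns (result, tree depth)."""
    level = [list(r) for r in rows]
    depth = 0
    while len(level) > 1:
        nxt = []
        # Add neighbouring rows in pairs.
        for i in range(0, len(level) - 1, 2):
            nxt.append([a + b for a, b in zip(level[i], level[i + 1])])
        # An unpaired last row is carried to the next level unchanged.
        if len(level) % 2 == 1:
            nxt.append(level[-1])
        level = nxt
        depth += 1
    return [level[0]], depth
```

For an input with M rows the tree variant needs ceil(log2 M) levels, and the intermediate levels are the kind of data a scratch buffer such as `tmp_tile` would plausibly hold. Pairwise summation also tends to accumulate less floating-point rounding error than a long sequential chain, which is consistent with the "better precision" remark in the docstring.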
11 changes: 8 additions & 3 deletions python/pypto/language/op/tile_ops.py
@@ -883,17 +883,22 @@ def row_min(tile: Tile, tmp_tile: Tile) -> Tile:
return Tile(expr=call_expr)


def col_sum(tile: Tile, tmp_tile: Tile) -> Tile:
def col_sum(tile: Tile, tmp_tile: Tile | None = None) -> Tile:
"""Column-wise sum reduction.

Passing ``tmp_tile`` activates the binary-tree reduction path; omitting it
uses the sequential path.

Args:
tile: Input tile
tmp_tile: Temporary tile (same shape as input)
tmp_tile: Optional scratch tile (same shape/dtype as input) that selects
the binary-tree reduction path.

Returns:
Tile wrapping the col_sum operation
"""
call_expr = _ir_ops.col_sum(tile.unwrap(), tmp_tile.unwrap())
tmp_expr = None if tmp_tile is None else tmp_tile.unwrap()
call_expr = _ir_ops.col_sum(tile.unwrap(), tmp_expr)
return Tile(expr=call_expr)


5 changes: 2 additions & 3 deletions python/pypto/language/op/unified_ops.py
@@ -491,11 +491,10 @@ def row_min(input: T, tmp_tile: Tile | None = None) -> T:
def col_sum(input: T, tmp_tile: Tile | None = None) -> T:
"""Column-wise sum reduction, dispatched by input type.

For Tile inputs, tmp_tile is required as a temporary buffer.
For Tile inputs, passing ``tmp_tile`` activates the binary-tree reduction
path; omitting it uses the sequential path.
"""
if isinstance(input, Tile):
if tmp_tile is None:
raise ValueError("col_sum on Tile requires tmp_tile argument")
return _tile.col_sum(input, tmp_tile)
_raise_type_dispatch_error("col_sum", input)
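To make the new call behaviour concrete, here is a small self-contained sketch of the three Python layers touched by this change (unified dispatch, tile wrapper, IR arg list). The `Tile` class and `tile_col_sum` helper are hypothetical stand-ins, not the real pypto types; only the control flow is taken from the diff.

```python
class Tile:
    """Hypothetical stand-in for pypto's Tile wrapper; holds just a name."""

    def __init__(self, name):
        self.name = name

    def unwrap(self):
        return self.name


def tile_col_sum(tile, tmp_tile=None):
    """Mirror of the tile-level op: the arg list has one or two entries."""
    tmp_expr = None if tmp_tile is None else tmp_tile.unwrap()
    args = [tile.unwrap()] if tmp_expr is None else [tile.unwrap(), tmp_expr]
    return ("tile.col_sum", args)


def col_sum(input, tmp_tile=None):
    """Mirror of the unified op: Tile inputs no longer require tmp_tile."""
    if isinstance(input, Tile):
        return tile_col_sum(input, tmp_tile)
    raise TypeError(f"col_sum does not support input of type {type(input).__name__}")
```

With these stand-ins, `col_sum(Tile("x"))` builds a one-argument call (sequential path) and `col_sum(Tile("x"), Tile("tmp"))` builds a two-argument call (binary-tree path). Before this change the first form raised `ValueError`.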

11 changes: 5 additions & 6 deletions src/backend/common/pto_ops_common.cpp
@@ -1561,16 +1561,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
.set_input_layout(1, ir::TileLayout::row_major)
.set_output_layout(ir::TileLayout::row_major);
}
// tile.col_sum (TCOLSUM): registered separately because PTOAS requires an isBinary attribute.
// isBinary is hardcoded to true (binary-tree reduction) — the sequential path (false) offers
// no advantage in precision or performance.
// tile.col_sum (TCOLSUM): accepts 1 arg (sequential) or 2 args (tile + tmp for binary-tree).
// PTOAS pairs tmp operand with isBinary attribute; both present or both absent.
if (exclude_ops.count("tile.col_sum") == 0) {
backend.RegisterOp("tile.col_sum")
.f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen_base) {
auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
CHECK(op->args_.size() == 2)
<< "tile.col_sum requires 2 arguments (tile, tmp_tile), but got " << op->args_.size();
std::string config_attr = " {isBinary = true}";
CHECK(op->args_.size() == 1 || op->args_.size() == 2)
<< "tile.col_sum requires 1 or 2 arguments, but got " << op->args_.size();
std::string config_attr = op->args_.size() == 2 ? " {isBinary = true}" : "";
codegen.Emit("pto.tcolsum " + GenerateInsOutsClause(op, codegen, config_attr));
return std::string("");
});
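The attribute selection in the codegen lambda above can be modelled in a few lines of Python. This is a sketch of the decision logic only: `tcolsum_config_attr` is a hypothetical name, and the full emitted `pto.tcolsum` line is not reproduced because the output of `GenerateInsOutsClause` is not shown in this diff.

```python
def tcolsum_config_attr(num_args):
    """Return the attribute suffix chosen for a tile.col_sum call.

    One argument selects the sequential path (no attribute); two arguments
    (tile + tmp) select the binary-tree path via isBinary = true.
    """
    if num_args not in (1, 2):
        raise ValueError(
            f"tile.col_sum requires 1 or 2 arguments, but got {num_args}"
        )
    return " {isBinary = true}" if num_args == 2 else ""
```

The point of the pairing is that the tmp operand and the `isBinary` attribute always appear together or not at all, so a single arity check decides both.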
28 changes: 17 additions & 11 deletions src/ir/op/tile_ops/reduction.cpp
@@ -155,7 +155,8 @@ TypePtr DeduceTileRowReductionType(const std::vector<ExprPtr>& args,
}

// Type deduction for column reduction operations (reduces along first axis with keepdim=True)
// col_sum requires 2 arguments (tile + tmp_tile); col_max and col_min require 1 argument
// col_sum accepts 1 arg (sequential) or 2 args (tile + tmp_tile for binary-tree reduction).
// col_max and col_min require 1 argument.
TypePtr DeduceTileColReductionType(const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs,
const std::string& op_name) {
@@ -272,23 +273,28 @@ REGISTER_OP("tile.row_min")

REGISTER_OP("tile.col_sum")
.set_op_category("TileOp")
.set_description("Column-wise sum reduction (reduces along axis=0, maps to TCOLSUM)")
.set_description(
"Column-wise sum reduction (reduces along axis=0, maps to TCOLSUM). "
"Passing an optional second tmp_tile activates the binary-tree reduction path.")
.add_argument("tile", "Input tile (TileType)")
.add_argument("tmp_tile", "Temporary tile (TileType)")
.add_argument("tmp_tile", "Optional scratch tile for binary-tree reduction (TileType)")
.set_input_memory(0, MemorySpace::Vec)
.set_input_memory(1, MemorySpace::Vec)
.set_output_memory(MemorySpace::Vec)
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
CHECK(args.size() == 2) << "The operator tile.col_sum requires 2 arguments (tile, tmp_tile), but got "
<< args.size();
// Validate tmp_tile: must be TileType with matching dtype
CHECK(args.size() == 1 || args.size() == 2)
<< "The operator tile.col_sum requires 1 or 2 arguments, but got " << args.size();
auto tile_type = As<TileType>(args[0]->GetType());
auto tmp_type = As<TileType>(args[1]->GetType());
CHECK(tmp_type) << "The operator tile.col_sum requires tmp_tile to be a TileType, but got "
<< args[1]->GetType()->TypeName();
CHECK(tmp_type->dtype_ == tile_type->dtype_)
<< "The operator tile.col_sum requires tmp_tile dtype to match input dtype";
CHECK(tile_type) << "The operator tile.col_sum requires first argument to be a TileType, but got "
<< args[0]->GetType()->TypeName();
if (args.size() == 2) {
auto tmp_type = As<TileType>(args[1]->GetType());
CHECK(tmp_type) << "The operator tile.col_sum requires tmp_tile to be a TileType, but got "
<< args[1]->GetType()->TypeName();
CHECK(tmp_type->dtype_ == tile_type->dtype_)
<< "The operator tile.col_sum requires tmp_tile dtype to match input dtype";
}
return DeduceTileColReductionType(args, kwargs, "tile.col_sum");
});
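The validation order in the updated `f_deduce_type` can be sketched in Python as follows. `TileType` here is a hypothetical two-field stand-in for the IR type, and the `[1, N]` result shape is taken from the `col_sum` docstring rather than from `DeduceTileColReductionType`, whose body is not part of this diff.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TileType:
    """Hypothetical stand-in for the IR TileType: a shape and a dtype name."""
    shape: tuple
    dtype: str


def deduce_col_sum_type(arg_types):
    """Validate 1 or 2 argument types and return the [1, N] result type."""
    if len(arg_types) not in (1, 2):
        raise ValueError(
            f"tile.col_sum requires 1 or 2 arguments, but got {len(arg_types)}"
        )
    tile_type = arg_types[0]
    if not isinstance(tile_type, TileType):
        raise TypeError("tile.col_sum requires first argument to be a TileType")
    if len(arg_types) == 2:
        # tmp_tile is only validated when it is actually passed.
        tmp_type = arg_types[1]
        if not isinstance(tmp_type, TileType):
            raise TypeError("tile.col_sum requires tmp_tile to be a TileType")
        if tmp_type.dtype != tile_type.dtype:
            raise ValueError("tile.col_sum requires tmp_tile dtype to match input dtype")
    _, n = tile_type.shape
    return TileType((1, n), tile_type.dtype)
```

As in the C++ above, the sketch checks only the dtype of `tmp_tile`. The docstrings describe it as having the same shape as the input, but no shape check appears in this hunk.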
