diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 79f2a7c07..703dc678f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -137,12 +137,12 @@ jobs: run: pip install -v ./runtime - name: Test system tests - run: pytest tests/st -v --device=$DEVICE_ID --forked --precompile-workers=128 --pto-isa-commit=6eecebc5 + run: pytest tests/st -v --device=$DEVICE_ID --forked --precompile-workers=128 --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352 - name: Test swimlane output run: | pytest tests/st/runtime/test_perf_swimlane.py \ - -v --device=$DEVICE_ID --platform=a2a3 --runtime-profiling --forked --pto-isa-commit=6eecebc5 + -v --device=$DEVICE_ID --platform=a2a3 --runtime-profiling --forked --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352 system-tests-a5sim: runs-on: ubuntu-latest @@ -185,13 +185,13 @@ jobs: run: pip install -v ./runtime - name: Test A5 system tests (simulator) - run: pytest tests/st/runtime/test_assemble.py tests/st/runtime/test_mscatter.py tests/st/runtime/test_qwen3_decode_scope3_mixed.py tests/st/runtime/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=6eecebc5 + run: pytest tests/st/runtime/test_assemble.py tests/st/runtime/test_mscatter.py tests/st/runtime/test_qwen3_decode_scope3_mixed.py tests/st/runtime/test_dyn_orch_shape.py::TestDynOrchShapeOperations::test_dyn_orch_paged_attention -v --platform=a5sim --forked --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352 - name: Test A5 cross-core system tests (simulator) - run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=6eecebc5 + run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a5sim --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352 - name: Test A2A3 cross-core system tests (simulator) - run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=6eecebc5 + run: pytest tests/st/runtime/test_cross_core.py -v --forked --platform=a2a3sim --pto-isa-commit=5779238f2461460c1b2e87d0a6741bf10dfe4352 fuzz-tests-sim: runs-on: ubuntu-latest diff --git a/docs/en/user/02-operation_reference.md b/docs/en/user/02-operation_reference.md index da526f6af..7291b0472 100644 --- a/docs/en/user/02-operation_reference.md +++ b/docs/en/user/02-operation_reference.md @@ -24,7 +24,7 @@ Auto-selects between tensor and tile implementation based on input type. | `matmul_acc` | `(acc: T, lhs: T, rhs: T, a_trans=False, b_trans=False) -> T` | Matrix multiply with accumulation: `acc += lhs @ rhs` | | `row_max` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Row-wise max (tile path requires `tmp_tile`) | | `row_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Row-wise sum (tile path requires `tmp_tile`) | -| `col_sum` | `(input: Tile, tmp_tile: Tile) -> Tile` | Column-wise sum (tile-only, requires `tmp_tile`) | +| `col_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | Column-wise sum (tile-only). Passing `tmp_tile` activates binary-tree reduction; omitting it uses sequential reduction. | | `col_max` | `(input: Tile) -> Tile` | Column-wise max (tile-only) | | `col_min` | `(input: Tile) -> Tile` | Column-wise min (tile-only) | | `rsqrt` | `(input: T, high_precision: bool = False) -> T` | Reciprocal square root; `high_precision=True` selects the high-precision path (tensor input only — tile callers must use `pl.tile.rsqrt(src, tmp=...)`) | @@ -134,7 +134,7 @@ Transfer data between memory hierarchy levels. | `row_max` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise max (requires tmp buffer) | | `row_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise sum (requires tmp buffer) | | `row_min` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Row-wise min (requires tmp buffer) | -| `col_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | Column-wise sum (requires tmp buffer) | +| `col_sum` | `(tile: Tile, tmp_tile: Tile \| None = None) -> Tile` | Column-wise sum. Passing `tmp_tile` activates binary-tree reduction; omitting it uses sequential reduction. | | `col_max` | `(tile: Tile) -> Tile` | Column-wise max | | `col_min` | `(tile: Tile) -> Tile` | Column-wise min | | `sum` | `(tile: Tile, axis: int, keepdim: bool = False) -> Tile` | Sum along axis | diff --git a/docs/zh-cn/user/02-operation_reference.md b/docs/zh-cn/user/02-operation_reference.md index f2266f98b..6f85e0741 100644 --- a/docs/zh-cn/user/02-operation_reference.md +++ b/docs/zh-cn/user/02-operation_reference.md @@ -24,7 +24,7 @@ | `matmul_acc` | `(acc: T, lhs: T, rhs: T, a_trans=False, b_trans=False) -> T` | 带累加的矩阵乘法:`acc += lhs @ rhs` | | `row_max` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 行最大值(tile 路径需要 `tmp_tile`) | | `row_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 行求和(tile 路径需要 `tmp_tile`) | -| `col_sum` | `(input: Tile, tmp_tile: Tile) -> Tile` | 列求和(仅 tile,需要 `tmp_tile`) | +| `col_sum` | `(input: T, tmp_tile: Tile \| None = None) -> T` | 列求和(仅 tile);传入 `tmp_tile` 启用二叉树归约,省略时使用顺序归约 | | `col_max` | `(input: Tile) -> Tile` | 列最大值(仅 tile) | | `col_min` | `(input: Tile) -> Tile` | 列最小值(仅 tile) | | `rsqrt` | `(input: T, high_precision: bool = False) -> T` | 倒数平方根;`high_precision=True` 选择高精度路径(仅对 Tensor 输入生效,Tile 路径需要改用 `pl.tile.rsqrt(src, tmp=...)`) | @@ -129,7 +129,7 @@ | `row_max` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行最大值(需要临时缓冲区) | | `row_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行求和(需要临时缓冲区) | | `row_min` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 行最小值(需要临时缓冲区) | -| `col_sum` | `(tile: Tile, tmp_tile: Tile) -> Tile` | 列求和(需要临时缓冲区) | +| `col_sum` | `(tile: Tile, tmp_tile: Tile \| None = None) -> Tile` | 列求和;传入 `tmp_tile` 启用二叉树归约,省略时使用顺序归约 | | `col_max` | `(tile: Tile) -> Tile` | 列最大值 | | `col_min` | `(tile: Tile) -> Tile` | 列最小值 | | `sum` | `(tile: Tile, axis: int, keepdim: bool = False) -> Tile` | 沿轴求和 | diff --git a/python/pypto/ir/op/tile_ops.py b/python/pypto/ir/op/tile_ops.py index 9c7ce7355..588a34c26 100644 --- a/python/pypto/ir/op/tile_ops.py +++ b/python/pypto/ir/op/tile_ops.py @@ -1806,21 +1806,26 @@ def row_min(tile: Expr, tmp_tile: Expr, span: Span | None = None) -> Call: return _ir_core.create_op_call("tile.row_min", [tile, tmp_tile], {}, actual_span) -def col_sum(tile: Expr, tmp_tile: Expr, span: Span | None = None) -> Call: +def col_sum(tile: Expr, tmp_tile: Expr | None = None, span: Span | None = None) -> Call: """Column-wise sum reduction of a tile (reduces along axis=0, maps to TCOLSUM). Output shape is [1, N] for an [M, N] input. + Passing ``tmp_tile`` activates the binary-tree reduction path (O(log M) depth, + better precision). Omitting ``tmp_tile`` emits the sequential reduction path. + Args: tile: Input tile (TileType [M, N]) - tmp_tile: Temporary tile (TileType, same shape as input) + tmp_tile: Optional scratch tile (TileType, same shape/dtype as input) that + activates binary-tree reduction. span: Optional source span for debugging (auto-captured if not provided) Returns: Call expression for column-wise sum reduction (TileType [1, N]) """ actual_span = _get_span_or_capture(span) - return _ir_core.create_op_call("tile.col_sum", [tile, tmp_tile], {}, actual_span) + args = [tile] if tmp_tile is None else [tile, tmp_tile] + return _ir_core.create_op_call("tile.col_sum", args, {}, actual_span) def col_max(tile: Expr, span: Span | None = None) -> Call: diff --git a/python/pypto/language/op/tile_ops.py b/python/pypto/language/op/tile_ops.py index 031ce7024..a461ef601 100644 --- a/python/pypto/language/op/tile_ops.py +++ b/python/pypto/language/op/tile_ops.py @@ -883,17 +883,22 @@ def row_min(tile: Tile, tmp_tile: Tile) -> Tile: return Tile(expr=call_expr) -def col_sum(tile: Tile, tmp_tile: Tile) -> Tile: +def col_sum(tile: Tile, tmp_tile: Tile | None = None) -> Tile: """Column-wise sum reduction. + Passing ``tmp_tile`` activates the binary-tree reduction path; omitting it + uses the sequential path. + Args: tile: Input tile - tmp_tile: Temporary tile (same shape as input) + tmp_tile: Optional scratch tile (same shape/dtype as input) that selects + the binary-tree reduction path. Returns: Tile wrapping the col_sum operation """ - call_expr = _ir_ops.col_sum(tile.unwrap(), tmp_tile.unwrap()) + tmp_expr = None if tmp_tile is None else tmp_tile.unwrap() + call_expr = _ir_ops.col_sum(tile.unwrap(), tmp_expr) return Tile(expr=call_expr) diff --git a/python/pypto/language/op/unified_ops.py b/python/pypto/language/op/unified_ops.py index 34a9602e4..6c6816d6e 100644 --- a/python/pypto/language/op/unified_ops.py +++ b/python/pypto/language/op/unified_ops.py @@ -491,11 +491,10 @@ def row_min(input: T, tmp_tile: Tile | None = None) -> T: def col_sum(input: T, tmp_tile: Tile | None = None) -> T: """Column-wise sum reduction, dispatched by input type. - For Tile inputs, tmp_tile is required as a temporary buffer. + For Tile inputs, passing ``tmp_tile`` activates the binary-tree reduction + path; omitting it uses the sequential path. """ if isinstance(input, Tile): - if tmp_tile is None: - raise ValueError("col_sum on Tile requires tmp_tile argument") return _tile.col_sum(input, tmp_tile) _raise_type_dispatch_error("col_sum", input) diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp index 30f49b39e..37c87a5a3 100644 --- a/src/backend/common/pto_ops_common.cpp +++ b/src/backend/common/pto_ops_common.cpp @@ -1561,16 +1561,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set& exc .set_input_layout(1, ir::TileLayout::row_major) .set_output_layout(ir::TileLayout::row_major); } - // tile.col_sum (TCOLSUM): registered separately because PTOAS requires an isBinary attribute. - // isBinary is hardcoded to true (binary-tree reduction) — the sequential path (false) offers - // no advantage in precision or performance. + // tile.col_sum (TCOLSUM): accepts 1 arg (sequential) or 2 args (tile + tmp for binary-tree). + // PTOAS pairs tmp operand with isBinary attribute; both present or both absent. if (exclude_ops.count("tile.col_sum") == 0) { backend.RegisterOp("tile.col_sum") .f_codegen([](const ir::CallPtr& op, codegen::CodegenBase& codegen_base) { auto& codegen = dynamic_cast(codegen_base); - CHECK(op->args_.size() == 2) - << "tile.col_sum requires 2 arguments (tile, tmp_tile), but got " << op->args_.size(); - std::string config_attr = " {isBinary = true}"; + CHECK(op->args_.size() == 1 || op->args_.size() == 2) + << "tile.col_sum requires 1 or 2 arguments, but got " << op->args_.size(); + std::string config_attr = op->args_.size() == 2 ? " {isBinary = true}" : ""; codegen.Emit("pto.tcolsum " + GenerateInsOutsClause(op, codegen, config_attr)); return std::string(""); }); diff --git a/src/ir/op/tile_ops/reduction.cpp b/src/ir/op/tile_ops/reduction.cpp index 63fbe1e71..ca0f1a6fe 100644 --- a/src/ir/op/tile_ops/reduction.cpp +++ b/src/ir/op/tile_ops/reduction.cpp @@ -155,7 +155,8 @@ TypePtr DeduceTileRowReductionType(const std::vector& args, } // Type deduction for column reduction operations (reduces along first axis with keepdim=True) -// col_sum requires 2 arguments (tile + tmp_tile); col_max and col_min require 1 argument +// col_sum accepts 1 arg (sequential) or 2 args (tile + tmp_tile for binary-tree reduction). +// col_max and col_min require 1 argument. TypePtr DeduceTileColReductionType(const std::vector& args, const std::vector>& kwargs, const std::string& op_name) { @@ -272,23 +273,28 @@ REGISTER_OP("tile.row_min") REGISTER_OP("tile.col_sum") .set_op_category("TileOp") - .set_description("Column-wise sum reduction (reduces along axis=0, maps to TCOLSUM)") + .set_description( + "Column-wise sum reduction (reduces along axis=0, maps to TCOLSUM). " + "Passing an optional second tmp_tile activates the binary-tree reduction path.") .add_argument("tile", "Input tile (TileType)") - .add_argument("tmp_tile", "Temporary tile (TileType)") + .add_argument("tmp_tile", "Optional scratch tile for binary-tree reduction (TileType)") .set_input_memory(0, MemorySpace::Vec) .set_input_memory(1, MemorySpace::Vec) .set_output_memory(MemorySpace::Vec) .f_deduce_type([](const std::vector& args, const std::vector>& kwargs) { - CHECK(args.size() == 2) << "The operator tile.col_sum requires 2 arguments (tile, tmp_tile), but got " - << args.size(); - // Validate tmp_tile: must be TileType with matching dtype + CHECK(args.size() == 1 || args.size() == 2) + << "The operator tile.col_sum requires 1 or 2 arguments, but got " << args.size(); auto tile_type = As(args[0]->GetType()); - auto tmp_type = As(args[1]->GetType()); - CHECK(tmp_type) << "The operator tile.col_sum requires tmp_tile to be a TileType, but got " - << args[1]->GetType()->TypeName(); - CHECK(tmp_type->dtype_ == tile_type->dtype_) - << "The operator tile.col_sum requires tmp_tile dtype to match input dtype"; + CHECK(tile_type) << "The operator tile.col_sum requires first argument to be a TileType, but got " + << args[0]->GetType()->TypeName(); + if (args.size() == 2) { + auto tmp_type = As(args[1]->GetType()); + CHECK(tmp_type) << "The operator tile.col_sum requires tmp_tile to be a TileType, but got " + << args[1]->GetType()->TypeName(); + CHECK(tmp_type->dtype_ == tile_type->dtype_) + << "The operator tile.col_sum requires tmp_tile dtype to match input dtype"; + } return DeduceTileColReductionType(args, kwargs, "tile.col_sum"); }); diff --git a/tests/st/runtime/test_col_reduction.py b/tests/st/runtime/test_col_reduction.py index 29e75101e..d56010912 100644 --- a/tests/st/runtime/test_col_reduction.py +++ b/tests/st/runtime/test_col_reduction.py @@ -14,8 +14,9 @@ - Shapes: [32, 64] (tall), [16, 16] (square), [8, 128] (wide) - Dtypes: FP32, FP16 -col_sum requires a tmp_tile argument (same shape as input) because -TCOLSUM on a2a3 uses a 4-arg IMPL(dst, src, tmp, isBinary). +col_sum accepts an optional tmp_tile argument. Passing tmp_tile activates +the binary-tree reduction path (TCOLSUM 4-arg form); omitting it uses the +sequential reduction path (TCOLSUM 2-arg form). """ from typing import Any @@ -28,7 +29,7 @@ from pypto.ir.pass_manager import OptimizationStrategy # ============================================================================= -# Programs — col_sum (requires tmp_tile) +# Programs — col_sum (tmp_tile optional; provide it for binary-tree reduction) # ============================================================================= @@ -57,6 +58,28 @@ def orchestrator( return output +@pl.program +class ColSum_32x64_FP32_Sequential: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + input_tensor: pl.Tensor[[32, 64], pl.FP32], + output: pl.Out[pl.Tensor[[1, 64], pl.FP32]], + ) -> pl.Tensor[[1, 64], pl.FP32]: + tile: pl.Tile[[32, 64], pl.FP32] = pl.load(input_tensor, [0, 0], [32, 64]) + result: pl.Tile[[1, 64], pl.FP32] = pl.tile.col_sum(tile) + return pl.store(result, [0, 0], output) + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + input_tensor: pl.Tensor[[32, 64], pl.FP32], + output: pl.Out[pl.Tensor[[1, 64], pl.FP32]], + ) -> pl.Tensor[[1, 64], pl.FP32]: + output = self.kernel(input_tensor, output) + return output + + @pl.program class ColSum_16x16_FP32: @pl.function(type=pl.FunctionType.InCore) @@ -132,6 +155,72 @@ def orchestrator( return output +@pl.program +class ColSum_16x16_FP32_Sequential: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + input_tensor: pl.Tensor[[16, 16], pl.FP32], + output: pl.Out[pl.Tensor[[1, 16], pl.FP32]], + ) -> pl.Tensor[[1, 16], pl.FP32]: + tile: pl.Tile[[16, 16], pl.FP32] = pl.load(input_tensor, [0, 0], [16, 16]) + result: pl.Tile[[1, 16], pl.FP32] = pl.tile.col_sum(tile) + return pl.store(result, [0, 0], output) + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + input_tensor: pl.Tensor[[16, 16], pl.FP32], + output: pl.Out[pl.Tensor[[1, 16], pl.FP32]], + ) -> pl.Tensor[[1, 16], pl.FP32]: + output = self.kernel(input_tensor, output) + return output + + +@pl.program +class ColSum_8x128_FP32_Sequential: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + input_tensor: pl.Tensor[[8, 128], pl.FP32], + output: pl.Out[pl.Tensor[[1, 128], pl.FP32]], + ) -> pl.Tensor[[1, 128], pl.FP32]: + tile: pl.Tile[[8, 128], pl.FP32] = pl.load(input_tensor, [0, 0], [8, 128]) + result: pl.Tile[[1, 128], pl.FP32] = pl.tile.col_sum(tile) + return pl.store(result, [0, 0], output) + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + input_tensor: pl.Tensor[[8, 128], pl.FP32], + output: pl.Out[pl.Tensor[[1, 128], pl.FP32]], + ) -> pl.Tensor[[1, 128], pl.FP32]: + output = self.kernel(input_tensor, output) + return output + + +@pl.program +class ColSum_32x64_FP16_Sequential: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + input_tensor: pl.Tensor[[32, 64], pl.FP16], + output: pl.Out[pl.Tensor[[1, 64], pl.FP16]], + ) -> pl.Tensor[[1, 64], pl.FP16]: + tile: pl.Tile[[32, 64], pl.FP16] = pl.load(input_tensor, [0, 0], [32, 64]) + result: pl.Tile[[1, 64], pl.FP16] = pl.tile.col_sum(tile) + return pl.store(result, [0, 0], output) + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + input_tensor: pl.Tensor[[32, 64], pl.FP16], + output: pl.Out[pl.Tensor[[1, 64], pl.FP16]], + ) -> pl.Tensor[[1, 64], pl.FP16]: + output = self.kernel(input_tensor, output) + return output + + # ============================================================================= # Programs — col_max # ============================================================================= @@ -346,6 +435,29 @@ def compute_expected(self, tensors, params=None): tensors["output"][:] = torch.sum(tensors["input_tensor"], dim=0, keepdim=True) +class ColSum32x64FP32Sequential(PTOTestCase): + def get_name(self) -> str: + return "col_sum_32x64_fp32_sequential" + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("input_tensor", [32, 64], DataType.FP32, init_value=torch.randn), + TensorSpec("output", [1, 64], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return ColSum_32x64_FP32_Sequential + + def compute_expected(self, tensors, params=None): + tensors["output"][:] = torch.sum(tensors["input_tensor"], dim=0, keepdim=True) + + class ColSum16x16FP32(PTOTestCase): def get_name(self) -> str: return "col_sum_16x16_fp32" @@ -426,6 +538,80 @@ def compute_expected(self, tensors, params=None): tensors["output"][:] = buf[0:1] +class ColSum16x16FP32Sequential(PTOTestCase): + def get_name(self) -> str: + return "col_sum_16x16_fp32_sequential" + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("input_tensor", [16, 16], DataType.FP32, init_value=torch.randn), + TensorSpec("output", [1, 16], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return ColSum_16x16_FP32_Sequential + + def compute_expected(self, tensors, params=None): + tensors["output"][:] = torch.sum(tensors["input_tensor"], dim=0, keepdim=True) + + +class ColSum8x128FP32Sequential(PTOTestCase): + def get_name(self) -> str: + return "col_sum_8x128_fp32_sequential" + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("input_tensor", [8, 128], DataType.FP32, init_value=torch.randn), + TensorSpec("output", [1, 128], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return ColSum_8x128_FP32_Sequential + + def compute_expected(self, tensors, params=None): + tensors["output"][:] = torch.sum(tensors["input_tensor"], dim=0, keepdim=True) + + +class ColSum32x64FP16Sequential(PTOTestCase): + def get_name(self) -> str: + return "col_sum_32x64_fp16_sequential" + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("input_tensor", [32, 64], DataType.FP16, init_value=torch.randn), + TensorSpec("output", [1, 64], DataType.FP16, is_output=True), + ] + + def get_program(self) -> Any: + return ColSum_32x64_FP16_Sequential + + def compute_expected(self, tensors, params=None): + # Sequential reduction in FP16: accumulate rows one by one + inp = tensors["input_tensor"] + acc = inp[0].clone() + for i in range(1, inp.shape[0]): + acc = (acc + inp[i]).half() + tensors["output"][:] = acc.unsqueeze(0) + + # ============================================================================= # Test Cases — col_max # ============================================================================= @@ -644,6 +830,22 @@ def test_32x64_fp16(self, test_runner): result = test_runner.run(ColSum32x64FP16()) assert result.passed, f"Test failed: {result.error}" + def test_32x64_fp32_sequential(self, test_runner): + result = test_runner.run(ColSum32x64FP32Sequential()) + assert result.passed, f"Test failed: {result.error}" + + def test_16x16_fp32_sequential(self, test_runner): + result = test_runner.run(ColSum16x16FP32Sequential()) + assert result.passed, f"Test failed: {result.error}" + + def test_8x128_fp32_sequential(self, test_runner): + result = test_runner.run(ColSum8x128FP32Sequential()) + assert result.passed, f"Test failed: {result.error}" + + def test_32x64_fp16_sequential(self, test_runner): + result = test_runner.run(ColSum32x64FP16Sequential()) + assert result.passed, f"Test failed: {result.error}" + class TestColMax: """col_max: column-wise maximum across different shapes and dtypes.""" diff --git a/tests/ut/codegen/test_pto_codegen_ops.py b/tests/ut/codegen/test_pto_codegen_ops.py index 82949616d..aeb345ccc 100644 --- a/tests/ut/codegen/test_pto_codegen_ops.py +++ b/tests/ut/codegen/test_pto_codegen_ops.py @@ -1066,7 +1066,26 @@ def _generate_mlir(self, program_cls) -> str: return codegen_instance.generate(single) def test_col_sum_codegen(self): - """tile.col_sum emits pto.tcolsum with isBinary attribute.""" + """tile.col_sum without tmp_tile emits pto.tcolsum with no isBinary attribute.""" + + @pl.program + class Prog: + @pl.function(type=pl.FunctionType.InCore) + def main( + self, + input: pl.Tensor[[16, 16], pl.FP32], + output: pl.Tensor[[1, 16], pl.FP32], + ) -> pl.Tensor[[1, 16], pl.FP32]: + tile_in: pl.Tile[[16, 16], pl.FP32] = pl.load(input, [0, 0], [16, 16]) + result: pl.Tile[[1, 16], pl.FP32] = pl.tile.col_sum(tile_in) + return pl.store(result, [0, 0], output) + + mlir = self._generate_mlir(Prog) + assert "pto.tcolsum" in mlir, f"Expected pto.tcolsum in codegen output:\n{mlir}" + assert "isBinary" not in mlir, f"Expected no isBinary attribute in codegen output:\n{mlir}" + + def test_col_sum_codegen_binary(self): + """tile.col_sum with tmp_tile emits isBinary = true.""" @pl.program class Prog: @@ -1085,7 +1104,7 @@ def main( mlir = self._generate_mlir(Prog) assert "pto.tcolsum" in mlir, f"Expected pto.tcolsum in codegen output:\n{mlir}" - assert "isBinary = true" in mlir, f"Expected isBinary attribute in codegen output:\n{mlir}" + assert "isBinary = true" in mlir, f"Expected isBinary = true in codegen output:\n{mlir}" def test_col_max_codegen(self): """tile.col_max emits pto.tcolmax."""