diff --git a/CMakeLists.txt b/CMakeLists.txt
index eb66d8155..0b900c4b8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -104,6 +104,7 @@ set(PYPTO_SOURCES
     src/ir/op/tile_ops/transform.cpp
     src/ir/op/tile_ops/unary.cpp
     src/ir/op/tile_ops/cross_core.cpp
+    src/ir/op/tile_ops/utility.cpp
     src/ir/op/sync_ops/sync.cpp
     src/ir/op/sync_ops/cross_core.cpp
     src/ir/op/tensor_ops/broadcast.cpp
@@ -114,6 +115,7 @@ set(PYPTO_SOURCES
     src/ir/op/tensor_ops/reduction.cpp
     src/ir/op/tensor_ops/transform.cpp
     src/ir/op/tensor_ops/unary.cpp
+    src/ir/op/tensor_ops/utility.cpp
     src/ir/op/testing.cpp
     src/ir/op/type_inference.cpp
diff --git a/python/pypto/backend/pto_backend.py b/python/pypto/backend/pto_backend.py
index 443b09491..51e620e8d 100644
--- a/python/pypto/backend/pto_backend.py
+++ b/python/pypto/backend/pto_backend.py
@@ -330,6 +330,21 @@ def _generate_kernel_wrapper(func: _ir_core.Function, ptoas_code: str) -> str:
     3. ``kernel_entry`` wrapper with arg unpacking and forward call
     """
     header = _KERNEL_HEADER.format(func_name=func.name)
+    # TPRINT is guarded by #ifdef _DEBUG in pto-isa headers. Defining
+    # _DEBUG globally is too broad (it enables cce::printf calls that don't
+    # compile on simulation). Instead, provide a no-op fallback so the
+    # generated code compiles in all environments.
+    if "TPRINT" in ptoas_code:
+        header = header.replace(
+            "using namespace pto;",
+            "using namespace pto;\n\n"
+            "#ifndef _DEBUG\n"
+            "namespace pto {\n"
+            "template <typename T>\n"
+            "PTO_INST void TPRINT(T& /*src*/) {}\n"
+            "} // namespace pto\n"
+            "#endif",
+        )
     ptoas_body = _preprocess_ptoas_output(ptoas_code)
     unpacking_code, var_names = _generate_arg_unpacking(func)
     call_args = ", ".join(var_names)
diff --git a/python/pypto/ir/op/tensor_ops.py b/python/pypto/ir/op/tensor_ops.py
index 5b8066a07..f9b98e384 100644
--- a/python/pypto/ir/op/tensor_ops.py
+++ b/python/pypto/ir/op/tensor_ops.py
@@ -940,3 +940,19 @@ def scatter_update(
     op_args: list[Expr] = [input, index, src]
     kwargs: dict[str, Any] = {"dim": dim_val}
     return _ir_core.create_op_call("tensor.scatter_update", op_args, kwargs, actual_span)
+
+
+def runtime_print(tensor: Expr, span: Span | None = None) -> Call:
+    """Print tensor contents at runtime for debugging.
+
+    Generates a pto.tprint instruction in the compiled output.
+
+    Args:
+        tensor: Input tensor expression (TensorType)
+        span: Optional source span for debugging (auto-captured if not provided)
+
+    Returns:
+        Call expression (type is pass-through TensorType)
+    """
+    actual_span = _get_span_or_capture(span)
+    return _ir_core.create_op_call("tensor.runtime_print", [tensor], {}, actual_span)
diff --git a/python/pypto/ir/op/tile_ops.py b/python/pypto/ir/op/tile_ops.py
index c044edaf8..a259b2fbb 100644
--- a/python/pypto/ir/op/tile_ops.py
+++ b/python/pypto/ir/op/tile_ops.py
@@ -1950,3 +1950,19 @@ def tpop_from_aiv(
         op = _ir_core.get_op("tile.tpop_from_aiv")
         return _ir_core.Call(op, [], {"split": split}, resolved_type, actual_span)
     return _ir_core.create_op_call("tile.tpop_from_aiv", [], {"split": split}, actual_span)
+
+
+def runtime_print(tile: Expr, span: Span | None = None) -> Call:
+    """Print tile contents at runtime for debugging.
+
+    Generates a pto.tprint instruction in the compiled output.
+ + Args: + tile: Input tile expression (TileType) + span: Optional source span for debugging (auto-captured if not provided) + + Returns: + Call expression (type is pass-through TileType) + """ + actual_span = _get_span_or_capture(span) + return _ir_core.create_op_call("tile.runtime_print", [tile], {}, actual_span) diff --git a/python/pypto/language/__init__.py b/python/pypto/language/__init__.py index 2afad33e2..470c869f8 100644 --- a/python/pypto/language/__init__.py +++ b/python/pypto/language/__init__.py @@ -159,6 +159,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: row_min, row_sum, rsqrt, + runtime_print, slice, sqrt, sub, @@ -264,6 +265,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: "recip", "read", "write", + "runtime_print", # Promoted tile-only "create_tile", "fillpad", diff --git a/python/pypto/language/op/tensor_ops.py b/python/pypto/language/op/tensor_ops.py index 602bc8951..8e9dde619 100644 --- a/python/pypto/language/op/tensor_ops.py +++ b/python/pypto/language/op/tensor_ops.py @@ -59,6 +59,7 @@ "reshape", "transpose", "scatter_update", + "runtime_print", ] from pypto.ir.op import tensor_ops as _ir_ops @@ -779,3 +780,15 @@ def scatter_update( """ call_expr = _ir_ops.scatter_update(input.unwrap(), dim, index.unwrap(), src.unwrap()) return Tensor(expr=call_expr) + + +def runtime_print(tensor: Tensor) -> None: + """Print tensor contents at runtime for debugging. + + Generates a pto.tprint instruction in the compiled output. + This is a statement-only operation — no value is returned. 
+ + Args: + tensor: Input tensor to print + """ + _ir_ops.runtime_print(tensor.unwrap()) diff --git a/python/pypto/language/op/tile_ops.py b/python/pypto/language/op/tile_ops.py index 44b062981..747a65c26 100644 --- a/python/pypto/language/op/tile_ops.py +++ b/python/pypto/language/op/tile_ops.py @@ -108,6 +108,7 @@ "tpush_to_aic", "tpop_from_aic", "tpop_from_aiv", + "runtime_print", ] from pypto.ir.op import tile_ops as _ir_ops @@ -1547,3 +1548,15 @@ def sels(lhs: Tile, rhs: Tile, select_mode: int | float | Expr | Scalar) -> Tile select_mode_expr = select_mode.unwrap() if isinstance(select_mode, Scalar) else select_mode call_expr = _ir_ops.sels(lhs.unwrap(), rhs.unwrap(), select_mode_expr) return Tile(expr=call_expr) + + +def runtime_print(tile: Tile) -> None: + """Print tile contents at runtime for debugging. + + Generates a pto.tprint instruction in the compiled output. + This is a statement-only operation — no value is returned. + + Args: + tile: Input tile to print + """ + _ir_ops.runtime_print(tile.unwrap()) diff --git a/python/pypto/language/op/unified_ops.py b/python/pypto/language/op/unified_ops.py index a0f3165fe..da6bd06c0 100644 --- a/python/pypto/language/op/unified_ops.py +++ b/python/pypto/language/op/unified_ops.py @@ -53,6 +53,7 @@ "create_tile", "read", "write", + "runtime_print", ] from pypto.ir.utils import resolve_cast_mode @@ -552,3 +553,19 @@ def write(dst: Tensor | Tile, offset: IntLike | Sequence[IntLike], value: Scalar if isinstance(dst, Tile): return _tile.write(dst, offset, value) raise TypeError(f"write: expected Tensor or Tile, got {type(dst).__name__}") + + +def runtime_print(src: Tensor | Tile) -> None: + """Print tensor or tile contents at runtime for debugging. + + Generates a pto.tprint instruction in the compiled output. + This is a statement-only operation — no value is returned. 
+
+    Args:
+        src: Tensor or tile to print
+    """
+    if isinstance(src, Tensor):
+        return _tensor.runtime_print(src)
+    if isinstance(src, Tile):
+        return _tile.runtime_print(src)
+    raise TypeError(f"runtime_print: expected Tensor or Tile, got {type(src).__name__}")
diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp
index 520146936..9b926c029 100644
--- a/src/backend/common/pto_ops_common.cpp
+++ b/src/backend/common/pto_ops_common.cpp
@@ -364,7 +364,13 @@ static std::string MakePrintCodegenPTO(const std::string& pto_op_name, const CallPtr& op,
   CHECK(op->args_.size() == 1) << "Operation:" << pto_op_name << "] requires 1 argument, but got "
                                << op->args_.size();
   std::string src = codegen.GetExprAsCode(op->args_[0]);
-  codegen.Emit(pto_op_name + " ins(" + src + " | !pto.partition_tensor_view)");
+  std::string src_type = codegen.GetExprTypeAnnotation(op->args_[0]);
+  std::string line = pto_op_name + " ins(" + src;
+  if (!src_type.empty()) {
+    line += " : " + src_type;
+  }
+  line += ")";
+  codegen.Emit(line);
   return "";
 }
@@ -1234,9 +1240,11 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& excluded_ops) {
   reg("tile.mrgsort", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
     return MakeMrgSortCodegenPTO("pto.tmrgsort", op, codegen);
   });
-  reg("tile.print", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
+  auto make_tprint = [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
     return MakePrintCodegenPTO("pto.tprint", op, codegen);
-  });
+  };
+  reg("tile.runtime_print", make_tprint);
+  reg("tensor.runtime_print", make_tprint);
   // In-place accumulation ops (matmul_acc, gemv_acc): ptoas expects the
   // accumulator in ins() to be the same SSA value as outs(). InitMemRef
diff --git a/src/ir/op/tensor_ops/utility.cpp b/src/ir/op/tensor_ops/utility.cpp
new file mode 100644
index 000000000..dc4fe9b13
--- /dev/null
+++ b/src/ir/op/tensor_ops/utility.cpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) PyPTO Contributors.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * -----------------------------------------------------------------------------------------------------------
+ */
+
+/**
+ * @file utility.cpp
+ * @brief Utility tensor operations (print)
+ *
+ * This file implements utility/debugging operations for tensor-level programming.
+ */
+
+#include <any>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "pypto/core/logging.h"
+#include "pypto/ir/kind_traits.h"
+#include "pypto/ir/op_registry.h"
+#include "pypto/ir/type.h"
+
+namespace pypto {
+namespace ir {
+
+TypePtr DeduceTensorPrintType(const std::vector<ExprPtr>& args,
+                              const std::vector<std::pair<std::string, std::any>>& kwargs,
+                              const std::string& op_name) {
+  CHECK(args.size() == 1) << "The operator " << op_name << " requires 1 argument (tensor), but got "
+                          << args.size();
+  auto tensor_type = As<TensorType>(args[0]->GetType());
+  CHECK(tensor_type) << "The operator " << op_name << " requires argument to be a TensorType, but got "
+                     << args[0]->GetType()->TypeName();
+  // Pass-through: returns the input tensor type (print is a side-effect operation)
+  return tensor_type;
+}
+
+REGISTER_OP("tensor.runtime_print")
+    .set_op_category("TensorOp")
+    .set_description("Print tensor contents for debugging (generates pto.tprint)")
+    .add_argument("tensor", "Input tensor to print (TensorType)")
+    .f_deduce_type([](const std::vector<ExprPtr>& args,
+                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
+      return DeduceTensorPrintType(args, kwargs, "tensor.runtime_print");
+    });
+
+} // namespace ir
+} // namespace pypto
diff --git a/src/ir/op/tile_ops/utility.cpp b/src/ir/op/tile_ops/utility.cpp
new file mode 100644
index 000000000..169d66824
--- /dev/null
+++ b/src/ir/op/tile_ops/utility.cpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) PyPTO Contributors.
+ * This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+ * CANN Open Software License Agreement Version 2.0 (the "License").
+ * Please refer to the License for details. You may not use this file except in compliance with the License.
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE.
+ * See LICENSE in the root of the software repository for the full text of the License.
+ * -----------------------------------------------------------------------------------------------------------
+ */
+
+/**
+ * @file utility.cpp
+ * @brief Utility tile operations (print)
+ *
+ * This file implements utility/debugging operations for tile-level programming.
+ */
+
+#include <any>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "pypto/core/logging.h"
+#include "pypto/ir/kind_traits.h"
+#include "pypto/ir/op_registry.h"
+#include "pypto/ir/type.h"
+
+namespace pypto {
+namespace ir {
+
+TypePtr DeduceTilePrintType(const std::vector<ExprPtr>& args,
+                            const std::vector<std::pair<std::string, std::any>>& kwargs,
+                            const std::string& op_name) {
+  CHECK(args.size() == 1) << "The operator " << op_name << " requires 1 argument (tile), but got "
+                          << args.size();
+  auto tile_type = As<TileType>(args[0]->GetType());
+  CHECK(tile_type) << "The operator " << op_name << " requires argument to be a TileType, but got "
+                   << args[0]->GetType()->TypeName();
+  // Pass-through: returns the input tile type (print is a side-effect operation)
+  return tile_type;
+}
+
+REGISTER_OP("tile.runtime_print")
+    .set_op_category("TileOp")
+    .set_description("Print tile contents for debugging (generates pto.tprint)")
+    .add_argument("tile", "Input tile to print (TileType)")
+    .no_memory_spec()
+    .f_deduce_type([](const std::vector<ExprPtr>& args,
+                      const std::vector<std::pair<std::string, std::any>>& kwargs) {
+      return DeduceTilePrintType(args, kwargs, "tile.runtime_print");
+    });
+
+} // namespace ir
+} // namespace pypto
diff --git a/src/ir/transforms/op_conversion_registry.cpp b/src/ir/transforms/op_conversion_registry.cpp
index e388784e2..0b3d62678 100644
--- a/src/ir/transforms/op_conversion_registry.cpp
+++ b/src/ir/transforms/op_conversion_registry.cpp
@@ -158,6 +158,42 @@ OpConversionRegistry::OpConversionRegistry() {
   // Memory creation ops
   RegisterSimple("tensor.full", "tile.full");
 
+  // Utility ops — runtime_print needs a custom converter because the
+  // argument may still be a TensorType (e.g. printing a function parameter
+  // before any explicit tile.load). In that case we insert a tile.load
+  // prologue to materialise the tile, matching the tensor.fillpad pattern.
+  RegisterCustom(
+      "tensor.runtime_print",
+      [](const std::vector<ExprPtr>& args, const std::vector<std::pair<std::string, std::any>>& kwargs,
+         const Span& span) -> ConversionResult {
+        CHECK(args.size() == 1) << "tensor.runtime_print conversion expects 1 arg (input)";
+        auto& op_reg = OpRegistry::GetInstance();
+        const auto& input = args[0];
+
+        // Already a tile — pass through.
+        if (As<TileType>(input->GetType())) {
+          return ConversionResult{op_reg.Create("tile.runtime_print", {input}, span)};
+        }
+
+        auto tensor_type = As<TensorType>(input->GetType());
+        CHECK(tensor_type) << "tensor.runtime_print conversion: input must be TensorType or TileType, got "
+                           << input->GetType()->TypeName();
+
+        auto offsets = MakeZeroOffsetsTuple(tensor_type->shape_.size(), span);
+        auto shapes = MakeShapesTuple(tensor_type->shape_, span);
+
+        std::vector<std::pair<std::string, std::any>> load_kwargs = {{"target_memory", MemorySpace::Vec},
+                                                                     {"transpose", false}};
+        auto load_call = op_reg.Create("tile.load", {input, offsets, shapes, shapes}, load_kwargs, span);
+        auto load_var = std::make_shared<Var>("runtime_print_src", load_call->GetType(), span);
+
+        std::vector<StmtPtr> prologue;
+        prologue.push_back(std::make_shared<AssignStmt>(load_var, load_call, span));
+
+        auto print_call = op_reg.Create("tile.runtime_print", {load_var}, span);
+        return ConversionResult{std::move(prologue), print_call};
+      });
+
   // ────────────────────────────────────────────────────────────────────────
   // Broadcast-aware elementwise binary ops
   //
diff --git a/tests/st/runtime/test_runtime_print.py b/tests/st/runtime/test_runtime_print.py
new file mode 100644
index 000000000..08eb20625
--- /dev/null
+++ b/tests/st/runtime/test_runtime_print.py
@@ -0,0 +1,161 @@
+# Copyright (c) PyPTO Contributors.
+# This program is free software, you can redistribute it and/or modify it under the terms and conditions of
+# CANN Open Software License Agreement Version 2.0 (the "License").
+# Please refer to the License for details. You may not use this file except in compliance with the License.
+# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +""" +Runtime tests for pl.runtime_print() — verifies that pto.tprint codegen works +end-to-end without breaking kernel correctness. + +runtime_print is a debugging utility that emits pto.tprint instructions. +These tests verify that inserting runtime_print into a kernel does NOT +affect the computed results (it is a pure side-effect operation). +""" + +from typing import Any + +import pypto.language as pl +import pytest +from harness.core.harness import DataType, PTOTestCase, TensorSpec + +# ============================================================================= +# Program definitions +# ============================================================================= + + +@pl.program +class RuntimePrintTileProgram: + """Element-wise add with a runtime_print of the intermediate tile.""" + + @pl.function(type=pl.FunctionType.InCore) + def tile_add_with_print( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + c: pl.Out[pl.Tensor[[128, 128], pl.FP32]], + ) -> pl.Tensor[[128, 128], pl.FP32]: + tile_a: pl.Tile[[128, 128], pl.FP32] = pl.load(a, [0, 0], [128, 128]) + tile_b: pl.Tile[[128, 128], pl.FP32] = pl.load(b, [0, 0], [128, 128]) + tile_c: pl.Tile[[128, 128], pl.FP32] = pl.add(tile_a, tile_b) + pl.runtime_print(tile_c) + out_c = pl.store(tile_c, [0, 0], c) + return out_c + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + out_c: pl.Out[pl.Tensor[[128, 128], pl.FP32]], + ) -> pl.Tensor[[128, 128], pl.FP32]: + out_c = 
self.tile_add_with_print(a, b, out_c) + return out_c + + +@pl.program +class RuntimePrintTensorProgram: + """Element-wise add with a runtime_print of the input tensor.""" + + @pl.function(type=pl.FunctionType.InCore) + def tile_add_with_tensor_print( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + c: pl.Out[pl.Tensor[[128, 128], pl.FP32]], + ) -> pl.Tensor[[128, 128], pl.FP32]: + pl.runtime_print(a) + tile_a: pl.Tile[[128, 128], pl.FP32] = pl.load(a, [0, 0], [128, 128]) + tile_b: pl.Tile[[128, 128], pl.FP32] = pl.load(b, [0, 0], [128, 128]) + tile_c: pl.Tile[[128, 128], pl.FP32] = pl.add(tile_a, tile_b) + out_c = pl.store(tile_c, [0, 0], c) + return out_c + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + a: pl.Tensor[[128, 128], pl.FP32], + b: pl.Tensor[[128, 128], pl.FP32], + out_c: pl.Out[pl.Tensor[[128, 128], pl.FP32]], + ) -> pl.Tensor[[128, 128], pl.FP32]: + out_c = self.tile_add_with_tensor_print(a, b, out_c) + return out_c + + +# ============================================================================= +# Test cases +# ============================================================================= + + +class RuntimePrintTileTestCase(PTOTestCase): + """Test: element-wise add with pl.runtime_print(tile) after computation.""" + + __test__ = False + + def get_name(self) -> str: + return "runtime_print_tile_128x128" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("a", [128, 128], DataType.FP32, init_value=2.0), + TensorSpec("b", [128, 128], DataType.FP32, init_value=3.0), + TensorSpec("c", [128, 128], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return RuntimePrintTileProgram + + def compute_expected(self, tensors, params=None): + """Expected: c = a + b. 
runtime_print does not affect the result.""" + tensors["c"][:] = tensors["a"] + tensors["b"] + + +class RuntimePrintTensorTestCase(PTOTestCase): + """Test: element-wise add with pl.runtime_print(tensor) on input.""" + + __test__ = False + + def get_name(self) -> str: + return "runtime_print_tensor_128x128" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("a", [128, 128], DataType.FP32, init_value=2.0), + TensorSpec("b", [128, 128], DataType.FP32, init_value=3.0), + TensorSpec("c", [128, 128], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return RuntimePrintTensorProgram + + def compute_expected(self, tensors, params=None): + """Expected: c = a + b. runtime_print does not affect the result.""" + tensors["c"][:] = tensors["a"] + tensors["b"] + + +# ============================================================================= +# pytest test functions +# ============================================================================= + + +class TestRuntimePrint: + """Test suite for runtime_print — verifies codegen and correctness.""" + + def test_runtime_print_tile(self, test_runner): + """runtime_print(tile) should compile and run without affecting results.""" + test_case = RuntimePrintTileTestCase() + result = test_runner.run(test_case) + assert result.passed, f"runtime_print tile test failed: {result.error}" + + def test_runtime_print_tensor(self, test_runner): + """runtime_print(tensor) should compile and run without affecting results.""" + test_case = RuntimePrintTensorTestCase() + result = test_runner.run(test_case) + assert result.passed, f"runtime_print tensor test failed: {result.error}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/language/parser/test_runtime_print.py b/tests/ut/language/parser/test_runtime_print.py new file mode 100644 index 000000000..8ac1d5763 --- /dev/null +++ b/tests/ut/language/parser/test_runtime_print.py @@ -0,0 +1,197 @@ +# Copyright (c) PyPTO 
Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""Unit tests for pl.runtime_print() — runtime tile/tensor printing.""" + +import pypto.language as pl +import pytest +from pypto import ir +from pypto.language.parser.diagnostics.exceptions import InvalidOperationError + + +class TestRuntimePrintTile: + """Tests for pl.runtime_print() with tile arguments.""" + + def test_runtime_print_tile_creates_eval_stmt(self): + """runtime_print(tile) should create an EvalStmt wrapping a tile.runtime_print Call.""" + + @pl.function + def func(x: pl.Tensor[[32, 32], pl.FP32]) -> pl.Tensor[[32, 32], pl.FP32]: + tile: pl.Tile[[32, 32], pl.FP32] = pl.load(x, [0, 0], [32, 32]) + pl.runtime_print(tile) + return pl.store(tile, [0, 0], x) + + body = func.body + assert isinstance(body, ir.SeqStmts) + # AssignStmt (load) + EvalStmt (runtime_print) + ReturnStmt + assert len(body.stmts) == 3 + eval_stmt = body.stmts[1] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert call.op.name == "tile.runtime_print" + assert len(call.args) == 1 + + def test_tile_namespace_runtime_print(self): + """pl.tile.runtime_print(tile) should produce the same IR.""" + + @pl.function + def func(x: pl.Tensor[[32, 32], pl.FP32]) -> pl.Tensor[[32, 32], pl.FP32]: + tile: pl.Tile[[32, 32], pl.FP32] = 
pl.load(x, [0, 0], [32, 32]) + pl.tile.runtime_print(tile) + return pl.store(tile, [0, 0], x) + + body = func.body + assert isinstance(body, ir.SeqStmts) + eval_stmt = body.stmts[1] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert call.op.name == "tile.runtime_print" + + def test_runtime_print_tile_does_not_affect_data_flow(self): + """runtime_print should not change data flow — only adds an EvalStmt.""" + + @pl.function + def with_print(x: pl.Tensor[[32, 32], pl.FP32]) -> pl.Tensor[[32, 32], pl.FP32]: + tile: pl.Tile[[32, 32], pl.FP32] = pl.load(x, [0, 0], [32, 32]) + pl.runtime_print(tile) + return pl.store(tile, [0, 0], x) + + @pl.function + def without_print(x: pl.Tensor[[32, 32], pl.FP32]) -> pl.Tensor[[32, 32], pl.FP32]: + tile: pl.Tile[[32, 32], pl.FP32] = pl.load(x, [0, 0], [32, 32]) + return pl.store(tile, [0, 0], x) + + # Should NOT be structurally equal (with_print has an extra EvalStmt) + with_body = with_print.body + without_body = without_print.body + assert isinstance(with_body, ir.SeqStmts) + assert isinstance(without_body, ir.SeqStmts) + assert len(with_body.stmts) == len(without_body.stmts) + 1 + + def test_runtime_print_tile_roundtrip(self): + """Print IR → reparse → should produce structurally equal IR.""" + + @pl.program + class Before: + @pl.function + def main(self, x: pl.Tensor[[32, 32], pl.FP32]) -> pl.Tensor[[32, 32], pl.FP32]: + tile: pl.Tile[[32, 32], pl.FP32] = pl.load(x, [0, 0], [32, 32]) + pl.runtime_print(tile) + return pl.store(tile, [0, 0], x) + + printed = Before.as_python() + assert "pl.tile.runtime_print(" in printed + reparsed = pl.parse_program(printed) + ir.assert_structural_equal(Before, reparsed) + + def test_runtime_print_tile_type_preserved(self): + """The Call expression should have TileType matching the input tile.""" + + @pl.function + def func(x: pl.Tensor[[32, 32], pl.FP16]) -> pl.Tensor[[32, 32], pl.FP16]: + tile: pl.Tile[[32, 32], pl.FP16] = pl.load(x, 
[0, 0], [32, 32]) + pl.runtime_print(tile) + return pl.store(tile, [0, 0], x) + + body = func.body + assert isinstance(body, ir.SeqStmts) + eval_stmt = body.stmts[1] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert isinstance(call.type, ir.TileType) + + +class TestRuntimePrintTensor: + """Tests for pl.runtime_print() with tensor arguments.""" + + def test_runtime_print_tensor_creates_eval_stmt(self): + """runtime_print(tensor) should create an EvalStmt wrapping a tensor.runtime_print Call.""" + + @pl.function + def func(x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + pl.runtime_print(x) + return x + + body = func.body + assert isinstance(body, ir.SeqStmts) + # EvalStmt (runtime_print) + ReturnStmt + assert len(body.stmts) == 2 + eval_stmt = body.stmts[0] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.runtime_print" + assert len(call.args) == 1 + + def test_tensor_namespace_runtime_print(self): + """pl.tensor.runtime_print(tensor) should produce the same IR.""" + + @pl.function + def func(x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + pl.tensor.runtime_print(x) + return x + + body = func.body + assert isinstance(body, ir.SeqStmts) + eval_stmt = body.stmts[0] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.runtime_print" + + def test_runtime_print_tensor_roundtrip(self): + """Print IR → reparse → should produce structurally equal IR.""" + + @pl.program + class Before: + @pl.function + def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + pl.runtime_print(x) + return x + + printed = Before.as_python() + assert "pl.tensor.runtime_print(" in printed + reparsed = pl.parse_program(printed) + ir.assert_structural_equal(Before, reparsed) + + def test_runtime_print_tensor_type_preserved(self): 
+ """The Call expression should have TensorType matching the input tensor.""" + + @pl.function + def func(x: pl.Tensor[[64, 128], pl.FP16]) -> pl.Tensor[[64, 128], pl.FP16]: + pl.runtime_print(x) + return x + + body = func.body + assert isinstance(body, ir.SeqStmts) + eval_stmt = body.stmts[0] + assert isinstance(eval_stmt, ir.EvalStmt) + call = eval_stmt.expr + assert isinstance(call, ir.Call) + assert isinstance(call.type, ir.TensorType) + + +class TestRuntimePrintErrors: + """Tests for error cases.""" + + def test_runtime_print_requires_tile_or_tensor(self): + """runtime_print with scalar should raise an error.""" + with pytest.raises(InvalidOperationError): + + @pl.function + def func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: + pl.runtime_print(x) # type: ignore + return x + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])