diff --git a/CMakeLists.txt b/CMakeLists.txt index b5ef6a3cd..3c70fb744 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,6 +112,7 @@ set(PYPTO_SOURCES src/ir/op/tensor_ops/elementwise.cpp src/ir/op/tensor_ops/matmul.cpp src/ir/op/tensor_ops/memory.cpp + src/ir/op/tensor_ops/scatter.cpp src/ir/op/tensor_ops/scatter_update.cpp src/ir/op/tensor_ops/reduction.cpp src/ir/op/tensor_ops/transform.cpp diff --git a/docs/en/user/02-operation_reference.md b/docs/en/user/02-operation_reference.md index 0b4872056..9b3e9e26e 100644 --- a/docs/en/user/02-operation_reference.md +++ b/docs/en/user/02-operation_reference.md @@ -43,6 +43,7 @@ Operate on `Tensor` objects (DDR memory). | `transpose` | `(tensor: Tensor, axis1: int, axis2: int) -> Tensor` | Swap two axes | | `assemble` | `(target: Tensor, source: Tensor, offset: Sequence[IntLike]) -> Tensor` | Write source into target at offset | | `scatter_update` | `(input: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor` | Update rows of `input` at sparse positions given by `index` with values from `src`. `input`/`src`: 2D `[rows, d]` or 4D `[B, S, 1, d]`; `index`: 2D `[b, s]` integer. Only `dim=-2` is supported | +| `scatter_` | `(input: Tensor, dim: int, index: Tensor, src: Tensor \| float \| int) -> Tensor` | Element-level scatter: write `src` values into `input` at positions given by `index` along `dim`. Follows PyTorch `scatter_` semantics. Supports arbitrary rank and any valid `dim` in `[-rank, rank)`. `src` can be a tensor or a scalar | | `add` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise add | | `sub` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise subtract | | `mul` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise multiply | diff --git a/docs/zh-cn/user/02-operation_reference.md b/docs/zh-cn/user/02-operation_reference.md index 3835aa967..31b0ab5d8 100644 --- a/docs/zh-cn/user/02-operation_reference.md +++ b/docs/zh-cn/user/02-operation_reference.md @@ -40,6 +40,7 @@ | `transpose` | `(tensor: Tensor, axis1: int, axis2: int) -> Tensor` | 交换两个轴 | | `assemble` | `(target: Tensor, source: Tensor, offset: Sequence[IntLike]) -> Tensor` | 将 source 写入 target 的指定偏移 | | `scatter_update` | `(input: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor` | 按 `index` 指定的稀疏行位置,将 `src` 的行数据写入 `input`。`input`/`src`:2D `[rows, d]` 或 4D `[B, S, 1, d]`;`index`:2D `[b, s]` 整型。当前仅支持 `dim=-2` | +| `scatter_` | `(input: Tensor, dim: int, index: Tensor, src: Tensor \| float \| int) -> Tensor` | 逐元素散射:沿 `dim` 维度按 `index` 指定的位置将 `src` 写入 `input`。语义与 PyTorch `scatter_` 一致。支持任意维度和 `[-rank, rank)` 范围内的 `dim`。`src` 可以是张量或标量 | | `add` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素加法 | | `sub` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素减法 | | `mul` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素乘法 | diff --git a/python/pypto/debug/torch_codegen.py b/python/pypto/debug/torch_codegen.py index 3dd73fa85..106910598 100644 --- a/python/pypto/debug/torch_codegen.py +++ b/python/pypto/debug/torch_codegen.py @@ -210,6 +210,14 @@ def _handle_slice(a: list[str], _kw: dict[str, Any]) -> str: _OP_MAP: dict[str, OpHandler] = {} +def _scatter_handler(a: list[str], kw: dict[str, str]) -> str: + dim = kw.get("dim", "0") + reduce = kw.get("reduce", None) + if reduce and reduce != "none": + return f"{a[0]}.scatter_({dim}, {a[1]}, {a[2]}, reduce='{reduce}')" + return f"{a[0]}.scatter_({dim}, {a[1]}, {a[2]})" + + def _register_ops() -> None: m = _OP_MAP @@ -261,6 +269,9 @@ def _register_ops() -> None: # scatter_update m[f"{prefix}.scatter_update"] = lambda a, kw: f"{a[0]}.scatter_(-2, {a[1]}.expand_as({a[2]}), {a[2]})" + # scatter_ (element-level scatter) + m[f"{prefix}.scatter_"] = _scatter_handler + # broadcast ops - torch broadcasting handles these naturally m[f"{prefix}.row_expand_add"] = _binop("+") m[f"{prefix}.row_expand_sub"] = _binop("-") diff --git a/python/pypto/ir/op/tensor_ops.py b/python/pypto/ir/op/tensor_ops.py index 5b8066a07..84f6224b7 100644 --- a/python/pypto/ir/op/tensor_ops.py +++ b/python/pypto/ir/op/tensor_ops.py @@ -940,3 +940,74 @@ def scatter_update( op_args: list[Expr] = [input, index, src] kwargs: dict[str, Any] = {"dim": dim_val} return _ir_core.create_op_call("tensor.scatter_update", op_args, kwargs, actual_span) + + +def scatter_( + input: Expr, + *args: Expr | int | float, + dim: int | Expr | None = None, + index: Expr | None = None, + src: Expr | float | int | None = None, + reduce: str | None = None, + span: Span | None = None, +) -> Call: + """Element-level scatter into tensor along a dimension. + + For each position (i₀,…,iₙ) in index, sets: + input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ] + + Follows PyTorch ``torch.Tensor.scatter_`` semantics. + + Accepts call forms: + - scatter_(input, dim, index, src) + - scatter_(input, dim, index, src=1.0) + + Args: + input: Destination tensor (N-D). + dim: Dimension along which to scatter. + index: Index tensor (same rank as input, integer dtype). + src: Source tensor (same shape as index) or scalar value. + span: Optional source span for debugging (auto-captured if not provided). + + Returns: + Call expression returning the updated input tensor. + """ + if len(args) == 3 and dim is None and index is None and src is None: + dim, index, src = args + elif len(args) == 2 and dim is not None and index is None and src is None: + index, src = args + elif len(args) == 1 and dim is None and index is not None and src is not None: + dim = args[0] + elif len(args) != 0: + raise TypeError( + "scatter_ expects (input, dim, index, src), " + "(input, index, src, dim=...), or (input, dim, index=..., src=...)" + ) + + if dim is None or index is None or src is None: + raise TypeError("scatter_ requires input, dim, index, and src") + + actual_span = _get_span_or_capture(span) + if isinstance(dim, ConstInt): + dim_val = int(dim.value) + elif isinstance(dim, int): + dim_val = dim + else: + raise TypeError(f"dim must be int or ConstInt, got {type(dim)}") + + if not isinstance(index, Expr): + raise TypeError(f"index must be Expr, got {type(index)}") + + # src can be Expr or scalar (int → ConstInt, float → ConstFloat) + if isinstance(src, int): + src = ConstInt(src, DataType.INT32, actual_span) + elif isinstance(src, float): + src = ConstFloat(src, DataType.FP32, actual_span) + elif not isinstance(src, Expr): + raise TypeError(f"src must be Expr or scalar, got {type(src)}") + + op_args: list[Expr] = [input, index, src] + kwargs: dict[str, Any] = {"dim": dim_val} + if reduce is not None: + kwargs["reduce"] = reduce + return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span) diff --git a/python/pypto/language/__init__.py b/python/pypto/language/__init__.py index 035f7b4f9..f4ab1aed7 100644 --- a/python/pypto/language/__init__.py +++ b/python/pypto/language/__init__.py @@ -85,7 +85,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: tpush_to_aic, tpush_to_aiv, ) -from .op.tensor_ops import assemble, create_tensor, dim, full, scatter_update +from .op.tensor_ops import assemble, create_tensor, dim, full, scatter_, scatter_update from .op.tile_ops import ( MemRefType, abs, @@ -326,6 +326,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]: "dim", "full", "scatter_update", + "scatter_", "FunctionType", "ForKind", "Level", diff --git a/python/pypto/language/op/__init__.py b/python/pypto/language/op/__init__.py index b305a1126..5ea56493a 100644 --- a/python/pypto/language/op/__init__.py +++ b/python/pypto/language/op/__init__.py @@ -27,7 +27,7 @@ from . import tile_ops as tile # Promoted tensor-only ops (accessible as pl.create_tensor, etc.) -from .tensor_ops import assemble, dim, scatter_update +from .tensor_ops import assemble, dim, scatter_, scatter_update from .tensor_ops import create as create_tensor # Promoted tile-only ops (accessible as pl.load, etc.) @@ -191,4 +191,5 @@ "assemble", "dim", "scatter_update", + "scatter_", ] diff --git a/python/pypto/language/op/tensor_ops.py b/python/pypto/language/op/tensor_ops.py index 602bc8951..eae32df9a 100644 --- a/python/pypto/language/op/tensor_ops.py +++ b/python/pypto/language/op/tensor_ops.py @@ -59,6 +59,7 @@ "reshape", "transpose", "scatter_update", + "scatter_", ] from pypto.ir.op import tensor_ops as _ir_ops @@ -779,3 +780,38 @@ def scatter_update( """ call_expr = _ir_ops.scatter_update(input.unwrap(), dim, index.unwrap(), src.unwrap()) return Tensor(expr=call_expr) + + +def scatter_( + input: Tensor, + dim: int, + index: Tensor, + src: float | int | Tensor, + *, + reduce: str | None = None, +) -> Tensor: + """Element-level scatter: write src values into input at positions given by index along dim. + + For each element position (i0,...,in) in index, sets: + input[i0]...[i_{d-1}][index[i0...in]][i_{d+1}]...[in] = src[i0...in] + + Supports arbitrary rank and any valid dim in [-rank, rank). + src can be a tensor (same shape as index) or a scalar value. + + Args: + input: Destination tensor (N-D) + dim: Dimension along which to scatter + index: Index tensor (N-D, same rank as input) of integer dtype + src: Source tensor (same shape as index) or scalar value + reduce: Optional reduce mode ("add" or "multiply") + + Returns: + Tensor wrapping the scatter_ operation + """ + src_expr: float | int | Expr + if isinstance(src, (Tensor, Scalar)): + src_expr = src.unwrap() + else: + src_expr = src + call_expr = _ir_ops.scatter_(input.unwrap(), dim, index.unwrap(), src_expr, reduce=reduce) + return Tensor(expr=call_expr) diff --git a/src/backend/common/pto_ops_common.cpp b/src/backend/common/pto_ops_common.cpp index c4c051ec0..f521f4bdd 100644 --- a/src/backend/common/pto_ops_common.cpp +++ b/src/backend/common/pto_ops_common.cpp @@ -717,22 +717,46 @@ static std::string MakeTileAllocCodegenPTO(const CallPtr& op, codegen::CodegenBa return ""; // No MLIR emission - pto.alloc_tile generated from MemRefs in TileTypes } -// Compute a row-major flat offset string from a MakeTuple of indices and the shape of the container. +// Compute a row-major flat offset from a MakeTuple of indices and the shape, +// emitting proper arith.muli / arith.addi SSA operations. +// Returns the SSA name of the final flat-offset value (index type). static std::string ComputeFlatOffsetPTO(const ir::MakeTuplePtr& indices_tuple, const std::vector& shape, codegen::PTOCodegen& codegen) { const auto& indices = indices_tuple->elements_; INTERNAL_CHECK(indices.size() == shape.size()) << "Index count (" << indices.size() << ") must match shape rank (" << shape.size() << ")"; - std::ostringstream idx_oss; + // Helper: ensure an index element SSA value has `index` type. + // If the expression is a non-index integer (e.g. i32 from tile.read on an + // INT32 tile), emit arith.index_cast to convert it. + auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string { + if (auto var = ir::As(expr)) { + return codegen.EmitCastToIndex(var, ssa); + } + return ssa; + }; + + // For each dimension i, compute: index[i] * (shape[i+1] * shape[i+2] * ... * shape[rank-1]) + // then sum all terms with arith.addi. + std::string accumulator; for (size_t i = 0; i < indices.size(); ++i) { - if (i > 0) idx_oss << " + "; - idx_oss << codegen.GetExprAsCode(indices[i]); + std::string term = ensure_index(indices[i], codegen.GetExprAsCode(indices[i])); + // Multiply by each trailing dimension size for (size_t j = i + 1; j < shape.size(); ++j) { - idx_oss << " * " << codegen.GetExprAsCode(shape[j]); + std::string dim = codegen.GetExprAsCode(shape[j]); + std::string tmp = codegen.NewTemp(); + codegen.Emit(tmp + " = arith.muli " + term + ", " + dim + " : index"); + term = tmp; + } + if (accumulator.empty()) { + accumulator = term; + } else { + std::string tmp = codegen.NewTemp(); + codegen.Emit(tmp + " = arith.addi " + accumulator + ", " + term + " : index"); + accumulator = tmp; } } - return idx_oss.str(); + return accumulator; } // Get or emit a flat offset SSA value for a MakeTuple of indices and shape. @@ -932,6 +956,15 @@ static std::string MakeTensorDimCodegenPTO(const CallPtr& op, codegen::CodegenBa return ""; } +static std::string MakeSystemBarrierCodegenPTO(const std::string& pipe_name, const CallPtr& op, + codegen::CodegenBase& codegen_base) { + CHECK(op->args_.empty()) << "system.barrier_" << pipe_name << " expects 0 arguments, got " + << op->args_.size(); + auto& codegen = dynamic_cast(codegen_base); + codegen.Emit("pto.barrier #pto.pipe<" + pipe_name + ">"); + return ""; +} + // ============================================================================ // Cross-Core Communication Operations (TPUSH/TPOP) // ============================================================================ @@ -1334,6 +1367,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set& exc reg("tile.cast", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { return MakeTileCvtCodegenPTO("pto.tcvt", op, codegen); }); + reg("system.bar_v", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeSystemBarrierCodegenPTO("PIPE_V", op, codegen); + }); + reg("system.bar_m", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeSystemBarrierCodegenPTO("PIPE_M", op, codegen); + }); + reg("system.bar_all", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { + return MakeSystemBarrierCodegenPTO("PIPE_ALL", op, codegen); + }); // tile.full (TEXPANDS): output is row_major per ISA if (exclude_ops.count("tile.full") == 0) { backend.RegisterOp("tile.full") diff --git a/src/codegen/pto/pto_codegen.cpp b/src/codegen/pto/pto_codegen.cpp index df8ed1a6b..46fefe500 100644 --- a/src/codegen/pto/pto_codegen.cpp +++ b/src/codegen/pto/pto_codegen.cpp @@ -907,6 +907,10 @@ std::string PTOCodegen::GetExprAsCode(const ExprPtr& expr) { return GetVarName(var); } if (auto const_int = As(expr)) { + DataType dtype = const_int->dtype(); + if (dtype == DataType::INT32) { + return GetOrEmitI32Constant(static_cast(const_int->value_)); + } return GetIndexConstant(const_int->value_); } if (auto const_float = As(expr)) { @@ -1148,7 +1152,7 @@ std::string PTOCodegen::GetExprTypeAnnotation(const ir::ExprPtr& expr) { return "f32"; } if (auto const_int = As(expr)) { - return "index"; + return GetTypeString(const_int->dtype()); } return ""; } diff --git a/src/codegen/pto/pto_scalar_expr_codegen.cpp b/src/codegen/pto/pto_scalar_expr_codegen.cpp index 9f4517613..1acc50e16 100644 --- a/src/codegen/pto/pto_scalar_expr_codegen.cpp +++ b/src/codegen/pto/pto_scalar_expr_codegen.cpp @@ -171,7 +171,17 @@ void PTOCodegen::VisitExpr_(const ir::IterArgPtr& op) { } void PTOCodegen::VisitExpr_(const ir::ConstIntPtr& op) { - fs_.current_expr_value = GetOrEmitIndexConstant(op->value_); + DataType dtype = op->dtype(); + if (dtype == DataType::INDEX) { + fs_.current_expr_value = GetOrEmitIndexConstant(op->value_); + } else if (dtype == DataType::INT32) { + fs_.current_expr_value = GetOrEmitI32Constant(static_cast(op->value_)); + } else { + std::string result = NewTemp(); + std::string type_str = GetTypeString(dtype); + Emit(result + " = arith.constant " + std::to_string(op->value_) + " : " + type_str); + fs_.current_expr_value = result; + } } void PTOCodegen::VisitExpr_(const ir::ConstFloatPtr& op) { diff --git a/src/ir/op/tensor_ops/scatter.cpp b/src/ir/op/tensor_ops/scatter.cpp new file mode 100644 index 000000000..76f9fb80c --- /dev/null +++ b/src/ir/op/tensor_ops/scatter.cpp @@ -0,0 +1,118 @@ +/* + * Copyright (c) PyPTO Contributors. + * This program is free software, you can redistribute it and/or modify it under the terms and conditions of + * CANN Open Software License Agreement Version 2.0 (the "License"). + * Please refer to the License for details. You may not use this file except in compliance with the License. + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. + * See LICENSE in the root of the software repository for the full text of the License. + * ----------------------------------------------------------------------------------------------------------- + */ + +/** + * @file scatter.cpp + * @brief Element-level scatter tensor operation + * + * Implements tensor.scatter_, which writes values from a source tensor (or scalar) + * into an input tensor at positions specified by a per-element index tensor along + * a given dimension. Follows PyTorch torch.Tensor.scatter_ semantics: + * + * self[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ] + */ + +#include +#include +#include +#include +#include +#include + +#include "pypto/core/any_cast.h" +#include "pypto/core/logging.h" +#include "pypto/ir/expr.h" +#include "pypto/ir/kind_traits.h" +#include "pypto/ir/op_registry.h" +#include "pypto/ir/scalar_expr.h" +#include "pypto/ir/type.h" + +namespace pypto { +namespace ir { + +TypePtr DeduceTensorScatterType(const std::vector& args, + const std::vector>& kwargs) { + // tensor.scatter_(input, index, src) -> TensorType same as input + // input: N-D tensor + // index: N-D tensor of integer dtype (same rank as input) + // src: N-D tensor (same shape as index, dtype matches input) OR scalar + CHECK(args.size() == 3) << "tensor.scatter_ requires exactly 3 arguments (input, index, src), got " + << args.size(); + + auto input_type = As(args[0]->GetType()); + CHECK(input_type) << "tensor.scatter_: input must be TensorType, got " << args[0]->GetType()->TypeName(); + const size_t rank = input_type->shape_.size(); + CHECK(rank >= 1) << "tensor.scatter_: input must be at least 1D, got rank 0"; + + auto index_type = As(args[1]->GetType()); + CHECK(index_type) << "tensor.scatter_: index must be TensorType, got " << args[1]->GetType()->TypeName(); + CHECK(index_type->shape_.size() == rank) << "tensor.scatter_: index rank (" << index_type->shape_.size() + << ") must match input rank (" << rank << ")"; + CHECK(index_type->dtype_.IsInt()) << "tensor.scatter_: index dtype must be integer, got " + << index_type->dtype_.ToString(); + + // src can be TensorType or scalar (ConstFloat / ConstInt) + bool src_is_scalar = As(args[2]) || As(args[2]); + if (!src_is_scalar) { + auto src_type = As(args[2]->GetType()); + CHECK(src_type) << "tensor.scatter_: src must be TensorType or scalar, got " + << args[2]->GetType()->TypeName(); + CHECK(src_type->shape_.size() == rank) << "tensor.scatter_: src rank (" << src_type->shape_.size() + << ") must match input rank (" << rank << ")"; + CHECK(src_type->dtype_ == input_type->dtype_) + << "tensor.scatter_: src dtype (" << src_type->dtype_.ToString() << ") must match input dtype (" + << input_type->dtype_.ToString() << ")"; + } else { + // Validate scalar dtype compatibility with input dtype + if (As(args[2])) { + CHECK(input_type->dtype_.IsFloat()) + << "tensor.scatter_: float scalar src requires float input dtype, got " + << input_type->dtype_.ToString(); + } else if (As(args[2])) { + CHECK(input_type->dtype_.IsInt()) + << "tensor.scatter_: integer scalar src requires integer input dtype, got " + << input_type->dtype_.ToString(); + } + } + + // Validate dim kwarg + for (const auto& [key, val] : kwargs) { + if (key == "dim") { + int dim_val = AnyCast(val, "kwarg key: dim"); + int irank = static_cast(rank); + CHECK(dim_val >= -irank && dim_val < irank) << "tensor.scatter_: dim must be in [" << -irank << ", " + << irank << ") for " << rank << "D input, got " << dim_val; + } + } + + return std::make_shared(input_type->shape_, input_type->dtype_); +} + +REGISTER_OP("tensor.scatter_") + .set_op_category("TensorOp") + .set_description( + "Element-level scatter: write src values into input at positions given by index along dim. " + "For each element position (i₀,…,iₙ) in index, sets " + "input[i₀]…[i_{d-1}][index[i₀…iₙ]][i_{d+1}]…[iₙ] = src[i₀…iₙ]. " + "Supports arbitrary rank and any valid dim ∈ [-rank, rank). " + "src can be a tensor (same shape as index) or a scalar value.") + .add_argument("input", "Destination tensor (N-D)") + .add_argument("index", "Index tensor (N-D, same rank as input) of integer dtype") + .add_argument("src", "Source tensor (same shape as index) or scalar value") + .set_attr("dim") + .set_attr("reduce") + .f_deduce_type([](const std::vector& args, + const std::vector>& kwargs) { + return DeduceTensorScatterType(args, kwargs); + }); + +} // namespace ir +} // namespace pypto diff --git a/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp b/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp index f4781321e..12e3b64bd 100644 --- a/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp +++ b/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp @@ -720,6 +720,22 @@ class TensorToTileMutator : public TypePropagatingMutator { stmts.push_back(std::make_shared(tile_var, new_result, op->span_)); var_remap_[op->var_.get()] = tile_var; + // If the conversion result is an existing Var (e.g. in-place ops like scatter_ + // that return the modified input), record a direct alias instead of emitting a + // redundant assignment. This prevents PTO codegen from allocating a separate + // tile buffer for the result, which would break SSA data-flow (the writes to + // the original tile would become dead and the result tile would be uninitialised). + if (auto alias_var = As(new_result)) { + var_remap_[op->var_.get()] = alias_var; + // Remove the redundant assignment we just added and return only prologue stmts + stmts.pop_back(); + if (stmts.empty()) return std::make_shared(std::vector{}, op->span_); + } else if (auto iter_arg = As(new_result)) { + var_remap_[op->var_.get()] = iter_arg; + stmts.pop_back(); + if (stmts.empty()) return std::make_shared(std::vector{}, op->span_); + } + return SeqStmts::Flatten(std::move(stmts), op->span_); } @@ -1508,6 +1524,31 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func, size_t num_added_outputs = 0; std::unordered_set merged_return_indices; + // Collect InOut tensor params by base name for matching with return values. + // IterArg InOut returns are stored back to the existing InOut param instead + // of creating a new Out param. + std::unordered_map inout_param_by_base; + // Collect existing Out tensor params that are not yet referenced by any tile.store + // in the body. These can be reused for auto-inserted tile.store when the DSL + // cannot express pl.store() (e.g. scatter_ returns TensorType at DSL level). + std::vector unused_out_params; + for (size_t i = 0; i < func->params_.size(); ++i) { + if (func->param_directions_[i] == ParamDirection::InOut && As(func->params_[i]->GetType())) { + inout_param_by_base[auto_name::GetBaseName(func->params_[i]->name_hint_)] = func->params_[i]; + } else if (func->param_directions_[i] == ParamDirection::Out && + As(func->params_[i]->GetType())) { + // Check if this Out param is already used by a tile.store in the body + VarUseVisitor use_checker(func->params_[i].get()); + for (const auto& s : new_stmts) { + use_checker.CheckStmt(s); + if (use_checker.Found()) break; + } + if (!use_checker.Found()) { + unused_out_params.push_back(func->params_[i]); + } + } + } + if (return_stmt) { std::vector new_return_exprs; @@ -1708,11 +1749,52 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func, continue; } - // Add output tensor parameter - std::string out_name = MakeOutParamName(num_added_outputs); - auto out_param = std::make_shared(out_name, orig_tensor_type, span); - new_params.push_back(out_param); - new_param_directions.push_back(ParamDirection::Out); + // Determine target param: reuse existing InOut/Out param or create new Out param. + VarPtr out_param; + bool is_existing_param = false; + auto orig_ret_var = As(return_stmt->value_[i]); + if (orig_ret_var) { + std::string ret_base = auto_name::GetBaseName(orig_ret_var->name_hint_); + auto inout_it = inout_param_by_base.find(ret_base); + if (inout_it != inout_param_by_base.end()) { + out_param = inout_it->second; + is_existing_param = true; + inout_param_by_base.erase(inout_it); + } + } + // Try reusing an unused Out param with compatible tensor type (e.g. scatter_ + // where the DSL has pl.Out but cannot use pl.store because the result is + // still TensorType at parse time). + if (!is_existing_param && !unused_out_params.empty()) { + for (auto out_it = unused_out_params.begin(); out_it != unused_out_params.end(); ++out_it) { + auto out_type = As((*out_it)->GetType()); + if (out_type && out_type->dtype_ == orig_tensor_type->dtype_ && + out_type->shape_.size() == orig_tensor_type->shape_.size()) { + bool shapes_match = true; + for (size_t k = 0; k < out_type->shape_.size(); ++k) { + auto out_dim = As(out_type->shape_[k]); + auto orig_dim = As(orig_tensor_type->shape_[k]); + if (!out_dim || !orig_dim || out_dim->value_ != orig_dim->value_) { + shapes_match = false; + break; + } + } + if (shapes_match) { + out_param = *out_it; + is_existing_param = true; + unused_out_params.erase(out_it); + break; + } + } + } + } + if (!is_existing_param) { + // Add new output tensor parameter + std::string out_name = MakeOutParamName(num_added_outputs); + out_param = std::make_shared(out_name, orig_tensor_type, span); + new_params.push_back(out_param); + new_param_directions.push_back(ParamDirection::Out); + } if (auto loop_rewrite = RewriteReturnedAssembleLoopToStore(new_stmts, ret_expr, out_param, orig_tensor_type, op_registry)) { @@ -1723,7 +1805,7 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func, } new_return_types.push_back(orig_tensor_type); new_return_exprs.push_back(loop_rewrite->new_return_var); - ++num_added_outputs; + if (!is_existing_param) ++num_added_outputs; continue; } @@ -1737,7 +1819,7 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func, new_return_types.push_back(store_call->GetType()); new_return_exprs.push_back(store_var); - ++num_added_outputs; + if (!is_existing_param) ++num_added_outputs; } else { // Non-tile return values pass through new_return_types.push_back(ret_expr->GetType()); diff --git a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp index 999c0a61b..8b738b395 100644 --- a/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp +++ b/src/ir/transforms/flatten_tile_nd_to_2d_pass.cpp @@ -55,6 +55,28 @@ namespace { */ bool IsNdTile(const TileTypePtr& tile_type) { return tile_type && tile_type->shape_.size() > 2; } +/** + * @brief Flatten ND index tuple to 2D (merged_row, col). + * + * Given an ND index tuple (i0, i1, ..., i_{n-1}) and the tile shape [d0, d1, ..., d_{n-1}], + * compute merged_row = i0 * d1 * d2 * ... * d_{n-2} + i1 * d2 * ... + i_{n-2} and col = i_{n-1}. + * Each index element is substituted via var_map before use. + */ +ExprPtr FlattenNdIndicesToTwoD(const MakeTuplePtr& idx_tuple, const std::vector& nd_shape, + const std::unordered_map& var_map, const Span& span) { + const size_t rank = nd_shape.size(); + ExprPtr merged_row; + for (size_t k = 0; k + 1 < rank; ++k) { + ExprPtr term = Substitute(idx_tuple->elements_[k], var_map); + for (size_t j = k + 1; j + 1 < rank; ++j) { + term = MakeMul(term, nd_shape[j], span); + } + merged_row = merged_row ? MakeAdd(merged_row, term, span) : term; + } + ExprPtr col = Substitute(idx_tuple->elements_[rank - 1], var_map); + return std::make_shared(std::vector{merged_row, col}, span); +} + /** * @brief Extract a static int64_t from a ConstInt expression. * @@ -118,7 +140,7 @@ std::vector Make2DShapeExprs(int64_t merged, int64_t last, const Span& * Checks: * 1. All tile shapes are static (ConstInt) * 2. All tile reduce ops (tile.sum/max/min) on >2D tiles reduce the last axis - * 3. No tile.read/tile.write/tile.slice on >2D tiles + * 3. No tile.slice on >2D tiles */ class PreconditionChecker : public IRVisitor { public: @@ -162,8 +184,8 @@ class PreconditionChecker : public IRVisitor { } CheckStaticShape(As(call->GetType()), name); - // Disallow tile.read/tile.write/tile.slice on >2D tiles - if (name == "tile.read" || name == "tile.write" || name == "tile.slice") { + // Disallow tile.slice on >2D tiles (tile.read/tile.write are handled by the pass) + if (name == "tile.slice") { if (!call->args_.empty()) { auto input_tile = As(call->args_[0]->GetType()); CHECK(!IsNdTile(input_tile)) << "FlattenTileNdTo2D: " << name << " is not supported on >2D tiles"; @@ -407,6 +429,26 @@ std::vector TransformBody(const std::vector& stmts, FlattenCon // EvalStmt: substitute variables in the expression if (auto eval = As(stmt)) { + // tile.write in EvalStmt on >2D tiles: flatten ND indices to 2D + if (auto call = As(eval->expr_)) { + if (call->op_ && call->op_->name_ == "tile.write") { + auto orig_tile_type = As(call->args_[0]->GetType()); + if (orig_tile_type && IsNdTile(orig_tile_type)) { + std::vector new_args; + new_args.reserve(call->args_.size()); + new_args.push_back(Substitute(call->args_[0], ctx.var_map)); + auto idx_tuple = As(call->args_[1]); + INTERNAL_CHECK(idx_tuple) << "tile.write indices must be MakeTuple"; + new_args.push_back(FlattenNdIndicesToTwoD(idx_tuple, orig_tile_type->shape_, ctx.var_map, span)); + for (size_t i = 2; i < call->args_.size(); ++i) { + new_args.push_back(Substitute(call->args_[i], ctx.var_map)); + } + auto new_call = op_registry.Create("tile.write", new_args, call->kwargs_, span); + result.push_back(std::make_shared(new_call, eval->span_)); + continue; + } + } + } auto new_expr = Substitute(eval->expr_, ctx.var_map); if (new_expr != eval->expr_) { // Re-create tile ops via OpRegistry for proper type deduction @@ -529,6 +571,21 @@ std::vector TransformBody(const std::vector& stmts, FlattenCon shapes.push_back(dim); } new_args.push_back(std::make_shared(shapes, span)); + + // Ensure offsets rank matches shapes rank (DeduceTileStoreType requirement). + auto offsets_tuple = As(new_args[1]); + if (offsets_tuple && offsets_tuple->elements_.size() < tensor_rank) { + std::vector padded_offsets; + padded_offsets.reserve(tensor_rank); + size_t pad = tensor_rank - offsets_tuple->elements_.size(); + for (size_t i = 0; i < pad; ++i) { + padded_offsets.push_back(std::make_shared(0, DataType::INDEX, span)); + } + for (const auto& off : offsets_tuple->elements_) { + padded_offsets.push_back(off); + } + new_args[1] = std::make_shared(padded_offsets, span); + } } // Construct call directly: store result type = output tensor type (args[2]) @@ -601,6 +658,44 @@ std::vector TransformBody(const std::vector& stmts, FlattenCon } } + // ---- tile.read/tile.write on >2D tiles: flatten ND indices to 2D ---- + // tile.read(tile, (i0, i1, ..., in)) → tile.read(tile_2d, (merged_row, col)) + // where merged_row = i0 * d1 * d2 * ... * d_{n-2} + i1 * d2 * ... + i_{n-2} + // and col = i_{n-1} + // tile.write(tile, (i0, ..., in), val) → tile.write(tile_2d, (merged_row, col), val) + if (op_name == "tile.read" || op_name == "tile.write") { + auto orig_tile_type = As(call->args_[0]->GetType()); + if (orig_tile_type && IsNdTile(orig_tile_type)) { + std::vector new_args; + new_args.reserve(call->args_.size()); + // args[0]: tile (substitute) + new_args.push_back(Substitute(call->args_[0], ctx.var_map)); + // args[1]: indices tuple — flatten from ND to 2D + auto idx_tuple = As(call->args_[1]); + INTERNAL_CHECK(idx_tuple) << "tile.read/tile.write indices must be MakeTuple"; + new_args.push_back(FlattenNdIndicesToTwoD(idx_tuple, orig_tile_type->shape_, ctx.var_map, span)); + // Remaining args (e.g., value for tile.write) + for (size_t i = 2; i < call->args_.size(); ++i) { + new_args.push_back(Substitute(call->args_[i], ctx.var_map)); + } + auto new_call = op_registry.Create(op_name, new_args, call->kwargs_, span); + if (op_name == "tile.read") { + // tile.read returns scalar — assign to var + auto new_var = + std::make_shared(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_); + result.push_back(std::make_shared(new_var, new_call, assign->span_)); + ctx.Insert(assign->var_, new_var); + } else { + // tile.write returns tile — assign to var and update mapping + auto new_var = + std::make_shared(assign->var_->name_hint_, new_call->GetType(), assign->var_->span_); + result.push_back(std::make_shared(new_var, new_call, assign->span_)); + ctx.Insert(assign->var_, new_var); + } + continue; + } + } + // ---- All other tile ops (including tile.reshape) and non-tile ops: substitute args ---- { std::vector new_args; diff --git a/src/ir/transforms/op_conversion_registry.cpp b/src/ir/transforms/op_conversion_registry.cpp index f39568c35..15df2888b 100644 --- a/src/ir/transforms/op_conversion_registry.cpp +++ b/src/ir/transforms/op_conversion_registry.cpp @@ -14,6 +14,7 @@ #include #include #include +#include #include #include #include @@ -494,6 +495,205 @@ OpConversionRegistry::OpConversionRegistry() { return ConversionResult{std::move(prologue), scatter_call}; }); + // ──────────────────────────────────────────────────────────────────────── + // tensor.scatter_ → nested ForStmt with tile.read + tile.write + // + // For every element position (i0, ..., in) in the index tensor: + // idx = tile.read(index, [i0, ..., in]) + // src_val = tile.read(src, [i0, ..., in]) (or scalar src directly) + // Build write_indices: same as (i0, ..., in) but replace dim-th with idx + // tile.write(input, write_indices, src_val) + // + // For reduce="add": old = tile.read(input, write_indices); write old + src_val + // For reduce="multiply": old = tile.read(input, write_indices); write old * src_val + // + // Uses nested ForStmt loops (scf.for) which PTOAS compiles to C++ for loops. + // PTOAS eliminates single-trip scf.for loops, so real multi-iteration loops + // are required for tgetval/tsetval to compile. + // + // When input is a TensorType (global memory), keep the op unchanged for + // orchestration codegen. + // ──────────────────────────────────────────────────────────────────────── + + RegisterCustom( + "tensor.scatter_", + [](const std::vector& args, const std::vector>& kwargs, + const Span& span) -> ConversionResult { + CHECK(args.size() == 3) << "tensor.scatter_ conversion expects 3 args (input, index, src)"; + auto& op_reg = OpRegistry::GetInstance(); + + const auto& input = args[0]; + const auto& index = args[1]; + const auto& src = args[2]; + + // Global tensor input → keep as tensor.scatter_ (orchestration codegen handles it) + if (As(input->GetType())) { + return ConversionResult{op_reg.Create("tensor.scatter_", args, kwargs, span)}; + } + + auto input_tile_type = As(input->GetType()); + CHECK(input_tile_type) << "tensor.scatter_: unexpected input type: " << input->GetType()->TypeName(); + + const size_t rank = input_tile_type->shape_.size(); + + // Extract kwargs + int dim = GetKwargOr(kwargs, "dim", 0); + std::string reduce = GetKwargOr(kwargs, "reduce", std::string("none")); + // Normalize negative dim + if (dim < 0) dim += static_cast(rank); + CHECK(dim >= 0 && dim < static_cast(rank)) + << "tensor.scatter_: dim " << dim << " is out of range for rank " << rank; + CHECK(reduce == "none" || reduce == "add" || reduce == "multiply") + << "tensor.scatter_: unsupported reduce mode '" << reduce << "'"; + + std::vector prologue; + + // Load index to Vec tile if it is still a global tensor + ExprPtr index_tile = index; + auto index_tile_type = As(index->GetType()); + if (auto index_tensor_type = As(index->GetType())) { + auto offsets = MakeZeroOffsetsTuple(index_tensor_type->shape_.size(), span); + auto shapes = MakeShapesTuple(index_tensor_type->shape_, span); + std::vector> load_kw = {{"target_memory", MemorySpace::Vec}, + {"transpose", false}}; + auto load = op_reg.Create("tile.load", {index, offsets, shapes, shapes}, load_kw, span); + auto idx_var = std::make_shared("scatter_idx", load->GetType(), span); + prologue.push_back(std::make_shared(idx_var, load, span)); + index_tile = idx_var; + index_tile_type = As(load->GetType()); + } + + // Determine whether src is scalar + bool src_is_scalar = static_cast(As(src->GetType())); + + // Load src to Vec tile if it is a global tensor + ExprPtr src_tile = src; + if (!src_is_scalar) { + if (auto src_tensor_type = As(src->GetType())) { + auto offsets = MakeZeroOffsetsTuple(src_tensor_type->shape_.size(), span); + auto shapes = MakeShapesTuple(src_tensor_type->shape_, span); + std::vector> load_kw = {{"target_memory", MemorySpace::Vec}, + {"transpose", false}}; + auto load = op_reg.Create("tile.load", {src, offsets, shapes, shapes}, load_kw, span); + auto src_var = std::make_shared("scatter_src", load->GetType(), span); + prologue.push_back(std::make_shared(src_var, load, span)); + src_tile = src_var; + } + } + + CHECK(index_tile_type) << "tensor.scatter_: index must be a tile at this point"; + const DataType value_dtype = input_tile_type->dtype_; + + // Build nested ForStmt loops over each dimension of the index tile. + // Recursion: build_loop(d, loop_vars) creates ForStmt for dimension d, + // passing accumulated loop variables down to the body. + std::function&)> build_loop; + build_loop = [&](size_t d, std::vector& loop_vars) -> StmtPtr { + if (d == rank) { + // Innermost body: generate tile.read + tile.write statements + std::vector body_stmts; + + auto bind = [&](const std::string& name, const ExprPtr& expr) -> ExprPtr { + auto var = std::make_shared(name, expr->GetType(), span); + body_stmts.push_back(std::make_shared(var, expr, span)); + return var; + }; + + // 1. Read index value: idx_val = tile.read(index, [loop_vars...]) + std::vector idx_elems; + idx_elems.reserve(rank); + for (size_t k = 0; k < rank; ++k) { + idx_elems.push_back(loop_vars[k]); + } + auto read_indices = std::make_shared(idx_elems, span); + auto idx_val = + bind("scatter_idx_val", op_reg.Create("tile.read", {index_tile, read_indices}, {}, span)); + + // 2. Read or get src value + ExprPtr src_val; + if (src_is_scalar) { + src_val = src; + } else { + auto src_indices = std::make_shared(idx_elems, span); + src_val = + bind("scatter_src_val", op_reg.Create("tile.read", {src_tile, src_indices}, {}, span)); + } + + // Cast src_val to input dtype if needed + if (As(src_val->GetType())) { + DataType src_dtype = GetScalarDtype(src_val); + if (src_dtype != value_dtype) { + src_val = bind("scatter_src_cast", MakeCast(src_val, value_dtype, span)); + } + } + + // 3. Build write indices: same as loop_vars but dim-th replaced with idx_val + std::vector write_idx_elems; + write_idx_elems.reserve(rank); + for (size_t k = 0; k < rank; ++k) { + if (static_cast(k) == dim) { + write_idx_elems.push_back(idx_val); + } else { + write_idx_elems.push_back(loop_vars[k]); + } + } + auto write_indices = std::make_shared(write_idx_elems, span); + + // 4. Write (with optional reduce) + if (reduce == "none") { + body_stmts.push_back(std::make_shared( + op_reg.Create("tile.write", {input, write_indices, src_val}, {}, span), span)); + } else { + auto old_val = + bind("scatter_old_val", op_reg.Create("tile.read", {input, write_indices}, {}, span)); + ExprPtr new_val; + if (reduce == "add") { + new_val = bind("scatter_new_val", MakeAdd(old_val, src_val, span)); + } else { + new_val = bind("scatter_new_val", MakeMul(old_val, src_val, span)); + } + body_stmts.push_back(std::make_shared( + op_reg.Create("tile.write", {input, write_indices, new_val}, {}, span), span)); + } + + // 5. Vector barrier after tile.write (required for tsetval hardware sync) + body_stmts.push_back( + std::make_shared(op_reg.Create("system.bar_v", {}, {}, span), span)); + + return SeqStmts::Flatten(std::move(body_stmts), span); + } + + // Create ForStmt for dimension d + auto loop_var = std::make_shared("scatter_d" + std::to_string(d), + std::make_shared(DataType::INDEX), span); + loop_vars.push_back(loop_var); + auto body = build_loop(d + 1, loop_vars); + loop_vars.pop_back(); + + auto zero = std::make_shared(0, DataType::INDEX, span); + auto step = std::make_shared(1, DataType::INDEX, span); + auto extent = index_tile_type->shape_[d]; + + return std::make_shared(loop_var, zero, extent, step, + /*iter_args=*/std::vector{}, body, + /*return_vars=*/std::vector{}, span, ForKind::Sequential); + }; + + std::vector loop_vars; + loop_vars.reserve(rank); + auto nested_for = build_loop(0, loop_vars); + + // Emit bar_all before and after the scatter loops to ensure: + // - Pre-loop: TLOAD (index) and TEXPANDS (input fill) have completed + // - Post-loop: All tsetval writes are committed before tstore + prologue.push_back(std::make_shared(op_reg.Create("system.bar_all", {}, {}, span), span)); + prologue.push_back(nested_for); + prologue.push_back(std::make_shared(op_reg.Create("system.bar_all", {}, {}, span), span)); + + // The input tile is modified in-place via tile.write; return it as the result. + return ConversionResult{std::move(prologue), input}; + }); + // ──────────────────────────────────────────────────────────────────────── // tensor.create → tile.create // diff --git a/tests/st/runtime/test_scatter.py b/tests/st/runtime/test_scatter.py new file mode 100644 index 000000000..bfd7021b4 --- /dev/null +++ b/tests/st/runtime/test_scatter.py @@ -0,0 +1,486 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +""" +Runtime tests for tensor.scatter_ (element-level scatter). + +Validates that scatter_ correctly places values from a source into a destination +tensor at positions specified by an index tensor along a given dimension. +Follows PyTorch torch.Tensor.scatter_ semantics: + + input[i0]...[i_{d-1}][ index[i0...in] ][i_{d+1}]...[in] = src[i0...in] + +Test matrix: + - 2D dim=1: index selects destination columns + - 2D dim=0: index selects destination rows + - 2D scalar src: fill constant at scattered positions + - 3D dim=2: scatter along last dimension of a 3D tensor + +All tests use PTOAS strategy (BackendType.Ascend910B + OptimizationStrategy.Default). +scatter_ decomposes to scalar tgetval/tsetval loops so shapes are kept small. + +Design notes: + scatter_ operates on a local buffer (from pl.full) inside an InCore function. + The ConvertTensorToTileOps pass decomposes the tensor-level scatter_ into nested + for-loops with tile.read / tile.write. Because scatter_ returns a TensorType at + DSL level (not a TileType), the kernel cannot call pl.store() directly. Instead, + the kernel declares a pl.Out parameter and returns the scatter result; the pass + detects the unused Out param and auto-inserts tile.store to write the result back. +""" + +from typing import Any + +import pypto.language as pl +import pytest +import torch +from harness.core.harness import DataType, PTOTestCase, TensorSpec +from pypto.backend import BackendType +from pypto.ir.pass_manager import OptimizationStrategy + +# ── Deterministic index tensors ────────────────────────────────────────────── +# Pre-built so that golden_writer can inline them (numel <= 100). +# +# Tile alignment constraint: for RowMajor+NoneBox tiles, Cols*sizeof(DType) must +# be 32-byte aligned. For int32/float32 (4 bytes), Cols must be a multiple of 8. +# Shapes are chosen accordingly. + +# 2D dim=1: index [8, 8] with values in [0, 16) +_IDX_2D_DIM1 = torch.tensor( + [ + [0, 3, 7, 15, 1, 5, 9, 12], + [2, 6, 10, 14, 0, 4, 8, 11], + [3, 7, 13, 15, 1, 2, 6, 10], + [5, 9, 11, 14, 0, 4, 8, 12], + [1, 3, 7, 15, 2, 6, 10, 14], + [0, 4, 8, 11, 3, 7, 13, 15], + [1, 2, 6, 10, 5, 9, 11, 14], + [0, 4, 8, 12, 1, 5, 9, 13], + ], + dtype=torch.int32, +) + +# 2D dim=0: index [4, 8] with values in [0, 16) +_IDX_2D_DIM0 = torch.tensor( + [ + [0, 3, 7, 15, 1, 5, 9, 12], + [2, 6, 10, 14, 0, 4, 8, 11], + [3, 7, 13, 15, 1, 2, 6, 10], + [5, 9, 11, 14, 0, 4, 8, 12], + ], + dtype=torch.int32, +) + +# 3D dim=2: index [2, 4, 8] with values in [0, 8) +_IDX_3D_DIM2 = torch.tensor( + [ + [ + [0, 3, 7, 1, 5, 6, 2, 4], + [7, 0, 3, 5, 1, 2, 6, 4], + [2, 7, 3, 5, 0, 4, 6, 1], + [3, 0, 4, 7, 5, 2, 6, 1], + ], + [ + [1, 2, 6, 3, 4, 7, 0, 5], + [6, 1, 0, 7, 3, 4, 5, 2], + [0, 5, 6, 3, 1, 4, 2, 7], + [1, 2, 7, 3, 0, 5, 4, 6], + ], + ], + dtype=torch.int32, +) + + +# 2D small: index [2, 8] with values in [0, 8) — full permutation per row +_IDX_2D_SMALL = torch.tensor( + [ + [3, 0, 7, 1, 5, 2, 6, 4], + [7, 4, 1, 6, 0, 3, 2, 5], + ], + dtype=torch.int32, +) + +_SRC_2D_SMALL = torch.tensor( + [ + [10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0], + [11.0, 22.0, 33.0, 44.0, 55.0, 66.0, 77.0, 88.0], + ], +) + + +# --------------------------------------------------------------------------- +# 2D dim=1: scatter along columns +# --------------------------------------------------------------------------- + + +@pl.program +class Scatter2dDim1Program: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + index: pl.Tensor[[8, 8], pl.INT32], + src: pl.Tensor[[8, 8], pl.FP32], + out: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + buf: pl.Tensor[[8, 16], pl.FP32] = pl.full([8, 16], dtype=pl.FP32, value=1.0) + result: pl.Tensor[[8, 16], pl.FP32] = pl.scatter_(buf, dim=1, index=index, src=src) + return result + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + index: pl.Tensor[[8, 8], pl.INT32], + src: pl.Tensor[[8, 8], pl.FP32], + out: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + out = self.kernel(index, src, out) + return out + + +class Scatter2dDim1TestCase(PTOTestCase): + """2D scatter along dim=1: index selects destination columns.""" + + __test__ = False + + def get_name(self) -> str: + return "scatter_2d_dim1" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("index", [8, 8], DataType.INT32, init_value=_IDX_2D_DIM1), + TensorSpec("src", [8, 8], DataType.FP32, init_value=2.0), + TensorSpec("out", [8, 16], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return Scatter2dDim1Program + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + expected = torch.full([8, 16], 1.0) + expected.scatter_(1, tensors["index"].long(), tensors["src"].float()) + tensors["out"][:] = expected + + +# --------------------------------------------------------------------------- +# 2D dim=0: scatter along rows +# --------------------------------------------------------------------------- + + +@pl.program +class Scatter2dDim0Program: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + index: pl.Tensor[[4, 8], pl.INT32], + src: pl.Tensor[[4, 8], pl.FP32], + out: pl.Out[pl.Tensor[[16, 8], pl.FP32]], + ) -> pl.Tensor[[16, 8], pl.FP32]: + buf: pl.Tensor[[16, 8], pl.FP32] = pl.full([16, 8], dtype=pl.FP32, value=1.0) + result: pl.Tensor[[16, 8], pl.FP32] = pl.scatter_(buf, dim=0, index=index, src=src) + return result + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + index: pl.Tensor[[4, 8], pl.INT32], + src: pl.Tensor[[4, 8], pl.FP32], + out: pl.Out[pl.Tensor[[16, 8], pl.FP32]], + ) -> pl.Tensor[[16, 8], pl.FP32]: + out = self.kernel(index, src, out) + return out + + +class Scatter2dDim0TestCase(PTOTestCase): + """2D scatter along dim=0: index selects destination rows.""" + + __test__ = False + + def get_name(self) -> str: + return "scatter_2d_dim0" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("index", [4, 8], DataType.INT32, init_value=_IDX_2D_DIM0), + TensorSpec("src", [4, 8], DataType.FP32, init_value=2.0), + TensorSpec("out", [16, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return Scatter2dDim0Program + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + expected = torch.full([16, 8], 1.0) + expected.scatter_(0, tensors["index"].long(), tensors["src"].float()) + tensors["out"][:] = expected + + +# --------------------------------------------------------------------------- +# 2D dim=1 with scalar src: fill a constant at scattered positions +# --------------------------------------------------------------------------- + + +@pl.program +class Scatter2dScalarProgram: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + index: pl.Tensor[[8, 8], pl.INT32], + out: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + buf: pl.Tensor[[8, 16], pl.FP32] = pl.full([8, 16], dtype=pl.FP32, value=1.0) + result: pl.Tensor[[8, 16], pl.FP32] = pl.scatter_(buf, dim=1, index=index, src=99.0) + return result + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + index: pl.Tensor[[8, 8], pl.INT32], + out: pl.Out[pl.Tensor[[8, 16], pl.FP32]], + ) -> pl.Tensor[[8, 16], pl.FP32]: + out = self.kernel(index, out) + return out + + +class Scatter2dScalarTestCase(PTOTestCase): + """2D scatter with scalar src (99.0) along dim=1.""" + + __test__ = False + + def get_name(self) -> str: + return "scatter_2d_scalar" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("index", [8, 8], DataType.INT32, init_value=_IDX_2D_DIM1), + TensorSpec("out", [8, 16], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return Scatter2dScalarProgram + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + expected = torch.full([8, 16], 1.0) + expected.scatter_(1, tensors["index"].long(), 99.0) + tensors["out"][:] = expected + + +# --------------------------------------------------------------------------- +# 3D dim=2: scatter along last dimension of a 3D tensor +# --------------------------------------------------------------------------- + + +@pl.program +class Scatter3dDim2Program: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + index: pl.Tensor[[2, 4, 8], pl.INT32], + src: pl.Tensor[[2, 4, 8], pl.FP32], + out: pl.Out[pl.Tensor[[2, 4, 8], pl.FP32]], + ) -> pl.Tensor[[2, 4, 8], pl.FP32]: + buf: pl.Tensor[[2, 4, 8], pl.FP32] = pl.full([2, 4, 8], dtype=pl.FP32, value=1.0) + result: pl.Tensor[[2, 4, 8], pl.FP32] = pl.scatter_(buf, dim=2, index=index, src=src) + return result + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + index: pl.Tensor[[2, 4, 8], pl.INT32], + src: pl.Tensor[[2, 4, 8], pl.FP32], + out: pl.Out[pl.Tensor[[2, 4, 8], pl.FP32]], + ) -> pl.Tensor[[2, 4, 8], pl.FP32]: + out = self.kernel(index, src, out) + return out + + +class Scatter3dDim2TestCase(PTOTestCase): + """3D scatter along dim=2: index selects positions in last dimension.""" + + __test__ = False + + def get_name(self) -> str: + return "scatter_3d_dim2" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("index", [2, 4, 8], DataType.INT32, init_value=_IDX_3D_DIM2), + TensorSpec("src", [2, 4, 8], DataType.FP32, init_value=2.0), + TensorSpec("out", [2, 4, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return Scatter3dDim2Program + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + expected = torch.full([2, 4, 8], 1.0) + expected.scatter_(2, tensors["index"].long(), tensors["src"].float()) + tensors["out"][:] = expected + + +# --------------------------------------------------------------------------- +# 2D small dim=1: minimal scatter with unique values for tracing +# --------------------------------------------------------------------------- + + +@pl.program +class Scatter2dSmallProgram: + @pl.function(type=pl.FunctionType.InCore) + def kernel( + self, + index: pl.Tensor[[2, 8], pl.INT32], + src: pl.Tensor[[2, 8], pl.FP32], + out: pl.Out[pl.Tensor[[2, 8], pl.FP32]], + ) -> pl.Tensor[[2, 8], pl.FP32]: + buf: pl.Tensor[[2, 8], pl.FP32] = pl.full([2, 8], dtype=pl.FP32, value=0.0) + result: pl.Tensor[[2, 8], pl.FP32] = pl.scatter_(buf, dim=1, index=index, src=src) + return result + + @pl.function(type=pl.FunctionType.Orchestration) + def orchestrator( + self, + index: pl.Tensor[[2, 8], pl.INT32], + src: pl.Tensor[[2, 8], pl.FP32], + out: pl.Out[pl.Tensor[[2, 8], pl.FP32]], + ) -> pl.Tensor[[2, 8], pl.FP32]: + out = self.kernel(index, src, out) + return out + + +class Scatter2dSmallTestCase(PTOTestCase): + """2D small scatter along dim=1: unique values for easy tracing.""" + + __test__ = False + + def get_name(self) -> str: + return "scatter_2d_small" + + def define_tensors(self) -> list[TensorSpec]: + return [ + TensorSpec("index", [2, 8], DataType.INT32, init_value=_IDX_2D_SMALL), + TensorSpec("src", [2, 8], DataType.FP32, init_value=_SRC_2D_SMALL), + TensorSpec("out", [2, 8], DataType.FP32, is_output=True), + ] + + def get_program(self) -> Any: + return Scatter2dSmallProgram + + def get_strategy(self) -> OptimizationStrategy: + return OptimizationStrategy.Default + + def get_backend_type(self) -> BackendType: + return BackendType.Ascend910B + + def compute_expected(self, tensors, params=None): + expected = torch.full([2, 8], 0.0) + expected.scatter_(1, tensors["index"].long(), tensors["src"].float()) + tensors["out"][:] = expected + + +# --------------------------------------------------------------------------- +# Test suites +# --------------------------------------------------------------------------- + + +class TestScatterOperations: + """Test suite for tensor.scatter_ element-level scatter.""" + + def test_scatter_2d_small(self, test_runner): + """2D small scatter with unique values for tracing.""" + tc = Scatter2dSmallTestCase() + print(f"\n=== {tc.get_name()} ===") + print("dim=1") + print(f"index=\n{_IDX_2D_SMALL}") + print(f"src=\n{_SRC_2D_SMALL}") + expected = torch.full([2, 8], 0.0) + expected.scatter_(1, _IDX_2D_SMALL.long(), _SRC_2D_SMALL) + print(f"expected=\n{expected}") + result = test_runner.run(tc) + assert result.passed, f"Test failed: {result.error}" + + def test_scatter_2d_dim1(self, test_runner): + """2D scatter along dim=1: index selects destination columns.""" + tc = Scatter2dDim1TestCase() + print(f"\n=== {tc.get_name()} ===") + print("dim=1") + print(f"index=\n{_IDX_2D_DIM1}") + src = torch.full([8, 8], 2.0) + print(f"src=\n{src}") + expected = torch.full([8, 16], 1.0) + expected.scatter_(1, _IDX_2D_DIM1.long(), src) + print(f"expected=\n{expected}") + result = test_runner.run(tc) + assert result.passed, f"Test failed: {result.error}" + + def test_scatter_2d_dim0(self, test_runner): + """2D scatter along dim=0: index selects destination rows.""" + tc = Scatter2dDim0TestCase() + print(f"\n=== {tc.get_name()} ===") + print("dim=0") + print(f"index=\n{_IDX_2D_DIM0}") + src = torch.full([4, 8], 2.0) + print(f"src=\n{src}") + expected = torch.full([16, 8], 1.0) + expected.scatter_(0, _IDX_2D_DIM0.long(), src) + print(f"expected=\n{expected}") + result = test_runner.run(tc) + assert result.passed, f"Test failed: {result.error}" + + def test_scatter_2d_scalar(self, test_runner): + """2D scatter with scalar src (99.0) along dim=1.""" + tc = Scatter2dScalarTestCase() + print(f"\n=== {tc.get_name()} ===") + print("dim=1") + print(f"index=\n{_IDX_2D_DIM1}") + print("src=99.0") + expected = torch.full([8, 16], 1.0) + expected.scatter_(1, _IDX_2D_DIM1.long(), 99.0) + print(f"expected=\n{expected}") + result = test_runner.run(tc) + assert result.passed, f"Test failed: {result.error}" + + def test_scatter_3d_dim2(self, test_runner): + """3D scatter along dim=2: index selects positions in last dimension.""" + tc = Scatter3dDim2TestCase() + print(f"\n=== {tc.get_name()} ===") + print("dim=2") + print(f"index=\n{_IDX_3D_DIM2}") + src = torch.full([2, 4, 8], 2.0) + print(f"src=\n{src}") + expected = torch.full([2, 4, 8], 1.0) + expected.scatter_(2, _IDX_3D_DIM2.long(), src) + print(f"expected=\n{expected}") + result = test_runner.run(tc) + assert result.passed, f"Test failed: {result.error}" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/operators/test_tensor_ops.py b/tests/ut/ir/operators/test_tensor_ops.py index 5e45a12c3..266ae337c 100644 --- a/tests/ut/ir/operators/test_tensor_ops.py +++ b/tests/ut/ir/operators/test_tensor_ops.py @@ -1655,5 +1655,147 @@ def test_tensor_add_symbolic_shape_mismatch_shows_var_names(self): ir.op.tensor.add(tensor_a, tensor_b) +# ────────────────────────────────────────────────────────────────────────────── +# tensor.scatter_ tests +# ────────────────────────────────────────────────────────────────────────────── + + +def test_tensor_scatter_2d_dim1(): + """Test tensor.scatter_ with 2D input, dim=1, tensor src.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([3, 4], DataType.FP32) + index_type = ir.TensorType([3, 2], DataType.INT32) + src_type = ir.TensorType([3, 2], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + call = ir.op.tensor.scatter_(input_var, 1, index_var, src_var) + + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.scatter_" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.dtype == DataType.FP32 + assert len(result_type.shape) == 2 + + +def test_tensor_scatter_2d_dim0(): + """Test tensor.scatter_ with 2D input, dim=0, tensor src.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([4, 3], DataType.FP32) + index_type = ir.TensorType([2, 3], DataType.INT32) + src_type = ir.TensorType([2, 3], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + call = ir.op.tensor.scatter_(input_var, 0, index_var, src_var) + + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.scatter_" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert len(result_type.shape) == 2 + + +def test_tensor_scatter_scalar_src(): + """Test tensor.scatter_ with scalar src.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([3, 4], DataType.FP32) + index_type = ir.TensorType([3, 2], DataType.INT32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + + call = ir.op.tensor.scatter_(input_var, 1, index_var, 1.0) + + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.scatter_" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.dtype == DataType.FP32 + + +def test_tensor_scatter_negative_dim(): + """Test tensor.scatter_ with negative dim.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([3, 4], DataType.FP32) + index_type = ir.TensorType([3, 2], DataType.INT32) + src_type = ir.TensorType([3, 2], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + # dim=-1 should be valid (equivalent to dim=1 for 2D) + call = ir.op.tensor.scatter_(input_var, -1, index_var, src_var) + assert call.op.name == "tensor.scatter_" + + # dim=-2 should also be valid (equivalent to dim=0 for 2D) + call2 = ir.op.tensor.scatter_(input_var, -2, index_var, src_var) + assert call2.op.name == "tensor.scatter_" + + +def test_tensor_scatter_3d(): + """Test tensor.scatter_ with 3D input.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([2, 3, 4], DataType.FP32) + index_type = ir.TensorType([2, 3, 2], DataType.INT32) + src_type = ir.TensorType([2, 3, 2], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + call = ir.op.tensor.scatter_(input_var, 2, index_var, src_var) + assert call.op.name == "tensor.scatter_" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert len(result_type.shape) == 3 + + +def test_tensor_scatter_invalid_dim(): + """Test tensor.scatter_ rejects out-of-range dim.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([3, 4], DataType.FP32) + index_type = ir.TensorType([3, 2], DataType.INT32) + src_type = ir.TensorType([3, 2], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + with pytest.raises(ValueError, match="dim must be in"): + ir.op.tensor.scatter_(input_var, 2, index_var, src_var) + + with pytest.raises(ValueError, match="dim must be in"): + ir.op.tensor.scatter_(input_var, -3, index_var, src_var) + + +def test_tensor_scatter_index_rank_mismatch(): + """Test tensor.scatter_ rejects index with different rank than input.""" + span = ir.Span.unknown() + + input_type = ir.TensorType([3, 4], DataType.FP32) + index_type = ir.TensorType([3, 2, 1], DataType.INT32) # 3D vs 2D input + src_type = ir.TensorType([3, 2, 1], DataType.FP32) + + input_var = ir.Var("input", input_type, span) + index_var = ir.Var("index", index_type, span) + src_var = ir.Var("src", src_type, span) + + with pytest.raises(ValueError, match="index rank"): + ir.op.tensor.scatter_(input_var, 0, index_var, src_var) + + if __name__ == "__main__": pytest.main([__file__, "-v"]) diff --git a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py index fb8849529..bdc4be359 100644 --- a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py +++ b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py @@ -2445,5 +2445,111 @@ def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: assert "tensor.full" not in ir_str +class TestInOutParamHandling: + """Tests for InOut parameter handling in ConvertTensorToTileOps.""" + + def test_iter_arg_loop_adds_out_param(self): + """IterArg loop in InCore: pass adds Out param for the returned tile. + + The orchestration-level ForStmt iter_arg_mapping handles reuse at + the call-site level (AnalyzeIterArgMappings). Inside the InCore + function itself, the return tile gets a new Out param + tile.store. + """ + + @pl.program + class Input: + @pl.function + def main(self, x: pl.Tensor[[64], pl.FP32]) -> pl.Tensor[[64], pl.FP32]: + with pl.incore(): + for i in pl.range(10): + x = pl.add(x, x) + return x + + ctx = passes.PassContext([], passes.VerificationLevel.NONE) + with ctx: + prog = passes.convert_to_ssa()(Input) + prog = passes.outline_incore_scopes()(prog) + prog = passes.ctrl_flow_transform()(prog) + After = passes.convert_tensor_to_tile_ops()(prog) + + incore_func = After.get_function("main_incore_0") + assert incore_func is not None + + in_count = sum(1 for d in incore_func.param_directions if d == ir.ParamDirection.In) + out_count = sum(1 for d in incore_func.param_directions if d == ir.ParamDirection.Out) + assert in_count == 1, f"Expected 1 In param, got {in_count}" + assert out_count == 1, f"Expected 1 Out param (for return tile), got {out_count}" + + +class TestScatterConversion: + """Tests for tensor.scatter_ → for-loop tile.read/tile.write conversion.""" + + def test_scatter_local_tile_dim1_scalar_src(self): + """tensor.scatter_ with scalar src lowers to for-loop + tile.read/tile.write.""" + + @pl.program + class Before: + @pl.function(type=pl.FunctionType.InCore) + def main_incore_0( + self, + index: pl.Tensor[[3, 2], pl.INT32], + out: pl.Out[pl.Tensor[[3, 4], pl.FP32]], + ) -> pl.Tensor[[3, 4], pl.FP32]: + buf: pl.Tensor[[3, 4], pl.FP32] = pl.create_tensor([3, 4], dtype=pl.FP32) + result: pl.Tensor[[3, 4], pl.FP32] = pl.scatter_(buf, dim=1, index=index, src=1.0) + return result + + @pl.function + def main( + self, + index: pl.Tensor[[3, 2], pl.INT32], + out: pl.Out[pl.Tensor[[3, 4], pl.FP32]], + ) -> pl.Tensor[[3, 4], pl.FP32]: + return self.main_incore_0(index, out) + + After = passes.convert_tensor_to_tile_ops()(Before) + after_str = str(After) + + # tensor.scatter_ should be decomposed into element-wise for-loop + assert "tensor.scatter_" not in after_str + # New approach uses tile.read (for index) and tile.write (for scatter) + assert "tile.read" in after_str + assert "tile.write" in after_str + + def test_scatter_local_tile_dim0_tensor_src(self): + """tensor.scatter_ with dim=0 and tensor src lowers to for-loop + tile.read/tile.write.""" + + @pl.program + class Before: + @pl.function(type=pl.FunctionType.InCore) + def main_incore_0( + self, + index: pl.Tensor[[2, 3], pl.INT32], + src: pl.Tensor[[2, 3], pl.FP32], + out: pl.Out[pl.Tensor[[4, 3], pl.FP32]], + ) -> pl.Tensor[[4, 3], pl.FP32]: + buf: pl.Tensor[[4, 3], pl.FP32] = pl.create_tensor([4, 3], dtype=pl.FP32) + result: pl.Tensor[[4, 3], pl.FP32] = pl.scatter_(buf, dim=0, index=index, src=src) + return result + + @pl.function + def main( + self, + index: pl.Tensor[[2, 3], pl.INT32], + src: pl.Tensor[[2, 3], pl.FP32], + out: pl.Out[pl.Tensor[[4, 3], pl.FP32]], + ) -> pl.Tensor[[4, 3], pl.FP32]: + return self.main_incore_0(index, src, out) + + After = passes.convert_tensor_to_tile_ops()(Before) + after_str = str(After) + + # tensor.scatter_ should be decomposed into element-wise for-loop + assert "tensor.scatter_" not in after_str + # New approach uses tile.read (for index + src) and tile.write + assert "tile.read" in after_str + assert "tile.write" in after_str + + if __name__ == "__main__": pytest.main([__file__, "-v"])