Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ set(PYPTO_SOURCES
src/ir/op/tensor_ops/elementwise.cpp
src/ir/op/tensor_ops/matmul.cpp
src/ir/op/tensor_ops/memory.cpp
src/ir/op/tensor_ops/scatter.cpp
src/ir/op/tensor_ops/scatter_update.cpp
src/ir/op/tensor_ops/reduction.cpp
src/ir/op/tensor_ops/transform.cpp
Expand Down
1 change: 1 addition & 0 deletions docs/en/user/02-operation_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ Operate on `Tensor` objects (DDR memory).
| `transpose` | `(tensor: Tensor, axis1: int, axis2: int) -> Tensor` | Swap two axes |
| `assemble` | `(target: Tensor, source: Tensor, offset: Sequence[IntLike]) -> Tensor` | Write source into target at offset |
| `scatter_update` | `(input: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor` | Update rows of `input` at sparse positions given by `index` with values from `src`. `input`/`src`: 2D `[rows, d]` or 4D `[B, S, 1, d]`; `index`: 2D `[b, s]` integer. Only `dim=-2` is supported |
| `scatter_` | `(input: Tensor, dim: int, index: Tensor, src: Tensor \| float \| int) -> Tensor` | Element-level scatter: write `src` values into `input` at positions given by `index` along `dim`. Follows PyTorch `scatter_` semantics. Supports arbitrary rank and any valid `dim` in `[-rank, rank)`. `src` can be a tensor or a scalar |
| `add` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise add |
| `sub` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise subtract |
| `mul` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | Element-wise multiply |
Expand Down
1 change: 1 addition & 0 deletions docs/zh-cn/user/02-operation_reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
| `transpose` | `(tensor: Tensor, axis1: int, axis2: int) -> Tensor` | 交换两个轴 |
| `assemble` | `(target: Tensor, source: Tensor, offset: Sequence[IntLike]) -> Tensor` | 将 source 写入 target 的指定偏移 |
| `scatter_update` | `(input: Tensor, dim: int, index: Tensor, src: Tensor) -> Tensor` | 按 `index` 指定的稀疏行位置,将 `src` 的行数据写入 `input`。`input`/`src`:2D `[rows, d]` 或 4D `[B, S, 1, d]`;`index`:2D `[b, s]` 整型。当前仅支持 `dim=-2` |
| `scatter_` | `(input: Tensor, dim: int, index: Tensor, src: Tensor \| float \| int) -> Tensor` | 逐元素散射:沿 `dim` 维度按 `index` 指定的位置将 `src` 写入 `input`。语义与 PyTorch `scatter_` 一致。支持任意维度和 `[-rank, rank)` 范围内的 `dim`。`src` 可以是张量或标量 |
| `add` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素加法 |
| `sub` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素减法 |
| `mul` | `(lhs: Tensor, rhs: Tensor \| int \| float \| Scalar) -> Tensor` | 逐元素乘法 |
Expand Down
11 changes: 11 additions & 0 deletions python/pypto/debug/torch_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,14 @@ def _handle_slice(a: list[str], _kw: dict[str, Any]) -> str:
_OP_MAP: dict[str, OpHandler] = {}


def _scatter_handler(a: list[str], kw: dict[str, str]) -> str:
dim = kw.get("dim", "0")
reduce = kw.get("reduce", None)
if reduce and reduce != "none":
return f"{a[0]}.scatter_({dim}, {a[1]}, {a[2]}, reduce='{reduce}')"
return f"{a[0]}.scatter_({dim}, {a[1]}, {a[2]})"


def _register_ops() -> None:
m = _OP_MAP

Expand Down Expand Up @@ -261,6 +269,9 @@ def _register_ops() -> None:
# scatter_update
m[f"{prefix}.scatter_update"] = lambda a, kw: f"{a[0]}.scatter_(-2, {a[1]}.expand_as({a[2]}), {a[2]})"

# scatter_ (element-level scatter)
m[f"{prefix}.scatter_"] = _scatter_handler

# broadcast ops - torch broadcasting handles these naturally
m[f"{prefix}.row_expand_add"] = _binop("+")
m[f"{prefix}.row_expand_sub"] = _binop("-")
Expand Down
71 changes: 71 additions & 0 deletions python/pypto/ir/op/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -940,3 +940,74 @@ def scatter_update(
op_args: list[Expr] = [input, index, src]
kwargs: dict[str, Any] = {"dim": dim_val}
return _ir_core.create_op_call("tensor.scatter_update", op_args, kwargs, actual_span)


def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    reduce: str | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i0,...,in) in ``index``, sets:
        input[i0]...[i_{d-1}][ index[i0...in] ][i_{d+1}]...[in] = src[i0...in]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        reduce: Optional reduction mode forwarded to the op's kwargs.
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    # Map the supported positional call forms onto the dim/index/src keywords.
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)

    # dim may arrive as a plain Python int or a ConstInt IR node.
    if isinstance(dim, ConstInt):
        resolved_dim = int(dim.value)
    elif isinstance(dim, int):
        resolved_dim = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # Promote scalar src to an IR constant (int -> ConstInt, float -> ConstFloat).
    if isinstance(src, int):
        src = ConstInt(src, DataType.INT32, actual_span)
    elif isinstance(src, float):
        src = ConstFloat(src, DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_kwargs: dict[str, Any] = {"dim": resolved_dim}
    if reduce is not None:
        op_kwargs["reduce"] = reduce
    return _ir_core.create_op_call("tensor.scatter_", [input, index, src], op_kwargs, actual_span)
Comment on lines +945 to +1013
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The language layer pl.scatter_ calls this function with a reduce keyword argument, but this function's signature doesn't accept it. This will lead to a TypeError at runtime. The reduce argument should be added to the signature and passed to create_op_call. Additionally, ensure that user-provided arguments are validated at this level to provide clear error messages.

def scatter_(
    input: Expr,
    *args: Expr | int | float,
    dim: int | Expr | None = None,
    index: Expr | None = None,
    src: Expr | float | int | None = None,
    reduce: str | None = None,
    span: Span | None = None,
) -> Call:
    """Element-level scatter into tensor along a dimension.

    For each position (i₀,…,iₙ) in index, sets:
      input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]

    Follows PyTorch ``torch.Tensor.scatter_`` semantics.

    Accepts call forms:
    - scatter_(input, dim, index, src)
    - scatter_(input, dim, index, src=1.0)

    Args:
        input: Destination tensor (N-D).
        dim: Dimension along which to scatter.
        index: Index tensor (same rank as input, integer dtype).
        src: Source tensor (same shape as index) or scalar value.
        reduce: Optional reduction mode ("add" or "multiply").
        span: Optional source span for debugging (auto-captured if not provided).

    Returns:
        Call expression returning the updated input tensor.
    """
    if len(args) == 3 and dim is None and index is None and src is None:
        dim, index, src = args
    elif len(args) == 2 and dim is not None and index is None and src is None:
        index, src = args
    elif len(args) == 1 and dim is None and index is not None and src is not None:
        dim = args[0]
    elif len(args) != 0:
        raise TypeError(
            "scatter_ expects (input, dim, index, src), "
            "(input, index, src, dim=...), or (input, dim, index=..., src=...)"
        )

    if dim is None or index is None or src is None:
        raise TypeError("scatter_ requires input, dim, index, and src")

    actual_span = _get_span_or_capture(span)
    if isinstance(dim, ConstInt):
        dim_val = int(dim.value)
    elif isinstance(dim, int):
        dim_val = dim
    else:
        raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")

    if not isinstance(index, Expr):
        raise TypeError(f"index must be Expr, got {type(index)}")

    # src can be Expr or scalar (int/float → ConstFloat)
    if isinstance(src, (int, float)):
        src = ConstFloat(float(src), DataType.FP32, actual_span)
    elif not isinstance(src, Expr):
        raise TypeError(f"src must be Expr or scalar, got {type(src)}")

    op_args: list[Expr] = [input, index, src]
    kwargs: dict[str, Any] = {"dim": dim_val}
    if reduce is not None:
        kwargs["reduce"] = reduce
    return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)
References
  1. Validate user-provided arguments for DSL functions at the parser level to provide early and clear error messages, rather than relying solely on backend C++ validation.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed: pl.scatter_() now passes reduce kwarg through both the IR and language layers.

3 changes: 2 additions & 1 deletion python/pypto/language/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]:
tpush_to_aic,
tpush_to_aiv,
)
from .op.tensor_ops import assemble, create_tensor, dim, full, scatter_update
from .op.tensor_ops import assemble, create_tensor, dim, full, scatter_, scatter_update
from .op.tile_ops import (
MemRefType,
abs,
Expand Down Expand Up @@ -326,6 +326,7 @@ def scalar_func(x: pl.Scalar[pl.FP32]) -> pl.Scalar[pl.FP32]:
"dim",
"full",
"scatter_update",
"scatter_",
"FunctionType",
"ForKind",
"Level",
Expand Down
3 changes: 2 additions & 1 deletion python/pypto/language/op/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from . import tile_ops as tile

# Promoted tensor-only ops (accessible as pl.create_tensor, etc.)
from .tensor_ops import assemble, dim, scatter_update
from .tensor_ops import assemble, dim, scatter_, scatter_update
from .tensor_ops import create as create_tensor

# Promoted tile-only ops (accessible as pl.load, etc.)
Expand Down Expand Up @@ -191,4 +191,5 @@
"assemble",
"dim",
"scatter_update",
"scatter_",
]
36 changes: 36 additions & 0 deletions python/pypto/language/op/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@
"reshape",
"transpose",
"scatter_update",
"scatter_",
]

from pypto.ir.op import tensor_ops as _ir_ops
Expand Down Expand Up @@ -779,3 +780,38 @@ def scatter_update(
"""
call_expr = _ir_ops.scatter_update(input.unwrap(), dim, index.unwrap(), src.unwrap())
return Tensor(expr=call_expr)


def scatter_(
    input: Tensor,
    dim: int,
    index: Tensor,
    src: float | int | Tensor,
    *,
    reduce: str | None = None,
) -> Tensor:
    """Element-level scatter: write src values into input at positions given by index along dim.

    For each element position (i0,...,in) in index, sets:
        input[i0]...[i_{d-1}][index[i0...in]][i_{d+1}]...[in] = src[i0...in]

    Supports arbitrary rank and any valid dim in [-rank, rank).
    src can be a tensor (same shape as index) or a scalar value.

    Args:
        input: Destination tensor (N-D)
        dim: Dimension along which to scatter
        index: Index tensor (N-D, same rank as input) of integer dtype
        src: Source tensor (same shape as index) or scalar value
        reduce: Optional reduce mode ("none", "add", or "multiply")

    Returns:
        Tensor wrapping the scatter_ operation
    """
    # Tensor/Scalar wrappers carry an underlying IR Expr; plain Python scalars
    # are forwarded as-is and promoted to IR constants in the IR layer.
    src_expr: float | int | Expr
    if isinstance(src, (Tensor, Scalar)):
        src_expr = src.unwrap()
    else:
        src_expr = src
    call_expr = _ir_ops.scatter_(input.unwrap(), dim, index.unwrap(), src_expr, reduce=reduce)
    return Tensor(expr=call_expr)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
54 changes: 48 additions & 6 deletions src/backend/common/pto_ops_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -717,22 +717,46 @@ static std::string MakeTileAllocCodegenPTO(const CallPtr& op, codegen::CodegenBa
return ""; // No MLIR emission - pto.alloc_tile generated from MemRefs in TileTypes
}

// Compute a row-major flat offset from a MakeTuple of indices and the shape,
// emitting proper arith.muli / arith.addi SSA operations.
// Returns the SSA name of the final flat-offset value (index type).
static std::string ComputeFlatOffsetPTO(const ir::MakeTuplePtr& indices_tuple,
                                        const std::vector<ir::ExprPtr>& shape, codegen::PTOCodegen& codegen) {
  const auto& indices = indices_tuple->elements_;
  INTERNAL_CHECK(indices.size() == shape.size())
      << "Index count (" << indices.size() << ") must match shape rank (" << shape.size() << ")";

  // Helper: normalize any scalar index expression to `index` type.
  // Previously only ir::Var was cast; a ConstInt, IterArg, or computed scalar
  // (e.g. idx32 + 1) would keep its i32/i64 type and make the emitted
  // `arith.muli/addi ... : index` ill-typed. Decide from the expression's
  // scalar dtype instead so every term is index-typed.
  auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
    auto scalar_type = As<ScalarType>(expr->GetType());
    INTERNAL_CHECK(scalar_type) << "flat index expression must be scalar";
    if (scalar_type->dtype_ == DataType::INDEX) {
      return ssa;  // Already index-typed; no cast needed.
    }
    CHECK(!scalar_type->dtype_.IsFloat()) << "flat index expression must be integer/index typed";
    std::string idx_ssa = codegen.NewTemp();
    codegen.Emit(idx_ssa + " = arith.index_cast " + ssa + " : " +
                 codegen.GetTypeString(scalar_type->dtype_) + " to index");
    return idx_ssa;
  };

  // For each dimension i, compute: index[i] * (shape[i+1] * ... * shape[rank-1])
  // then sum all terms with arith.addi.
  std::string accumulator;
  for (size_t i = 0; i < indices.size(); ++i) {
    std::string term = ensure_index(indices[i], codegen.GetExprAsCode(indices[i]));
    // Multiply by each trailing dimension size.
    for (size_t j = i + 1; j < shape.size(); ++j) {
      std::string dim = codegen.GetExprAsCode(shape[j]);
      std::string tmp = codegen.NewTemp();
      codegen.Emit(tmp + " = arith.muli " + term + ", " + dim + " : index");
      term = tmp;
    }
    if (accumulator.empty()) {
      accumulator = term;
    } else {
      std::string tmp = codegen.NewTemp();
      codegen.Emit(tmp + " = arith.addi " + accumulator + ", " + term + " : index");
      accumulator = tmp;
    }
  }
  // NOTE(review): for a rank-0 tuple this returns an empty string — presumably
  // unreachable given the rank check above; confirm against callers.
  return accumulator;
}

// Get or emit a flat offset SSA value for a MakeTuple of indices and shape.
Expand Down Expand Up @@ -932,6 +956,15 @@ static std::string MakeTensorDimCodegenPTO(const CallPtr& op, codegen::CodegenBa
return "";
}

static std::string MakeSystemBarrierCodegenPTO(const std::string& pipe_name, const CallPtr& op,
codegen::CodegenBase& codegen_base) {
CHECK(op->args_.empty()) << "system.barrier_" << pipe_name << " expects 0 arguments, got "
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

为什么要插入barrier,这个工作应该是ptoas做的

<< op->args_.size();
auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base);
codegen.Emit("pto.barrier #pto.pipe<" + pipe_name + ">");
return "";
}

// ============================================================================
// Cross-Core Communication Operations (TPUSH/TPOP)
// ============================================================================
Expand Down Expand Up @@ -1334,6 +1367,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc
reg("tile.cast", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
return MakeTileCvtCodegenPTO("pto.tcvt", op, codegen);
});
reg("system.bar_v", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
return MakeSystemBarrierCodegenPTO("PIPE_V", op, codegen);
});
reg("system.bar_m", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
return MakeSystemBarrierCodegenPTO("PIPE_M", op, codegen);
});
reg("system.bar_all", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) {
return MakeSystemBarrierCodegenPTO("PIPE_ALL", op, codegen);
});
// tile.full (TEXPANDS): output is row_major per ISA
if (exclude_ops.count("tile.full") == 0) {
backend.RegisterOp("tile.full")
Expand Down
6 changes: 5 additions & 1 deletion src/codegen/pto/pto_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -907,6 +907,10 @@ std::string PTOCodegen::GetExprAsCode(const ExprPtr& expr) {
return GetVarName(var);
}
if (auto const_int = As<ir::ConstInt>(expr)) {
DataType dtype = const_int->dtype();
if (dtype == DataType::INT32) {
return GetOrEmitI32Constant(static_cast<int32_t>(const_int->value_));
}
return GetIndexConstant(const_int->value_);
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.
if (auto const_float = As<ir::ConstFloat>(expr)) {
Expand Down Expand Up @@ -1148,7 +1152,7 @@ std::string PTOCodegen::GetExprTypeAnnotation(const ir::ExprPtr& expr) {
return "f32";
}
if (auto const_int = As<ir::ConstInt>(expr)) {
return "index";
return GetTypeString(const_int->dtype());
}
return "";
}
Expand Down
12 changes: 11 additions & 1 deletion src/codegen/pto/pto_scalar_expr_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,17 @@ void PTOCodegen::VisitExpr_(const ir::IterArgPtr& op) {
}

void PTOCodegen::VisitExpr_(const ir::ConstIntPtr& op) {
  // Materialize an integer constant whose MLIR type matches the node's dtype.
  // INDEX and INT32 go through the memoized constant caches; any other
  // integer dtype gets a fresh arith.constant of that exact type.
  const DataType dtype = op->dtype();
  if (dtype == DataType::INDEX) {
    fs_.current_expr_value = GetOrEmitIndexConstant(op->value_);
    return;
  }
  if (dtype == DataType::INT32) {
    fs_.current_expr_value = GetOrEmitI32Constant(static_cast<int32_t>(op->value_));
    return;
  }
  const std::string ssa = NewTemp();
  Emit(ssa + " = arith.constant " + std::to_string(op->value_) + " : " + GetTypeString(dtype));
  fs_.current_expr_value = ssa;
}

void PTOCodegen::VisitExpr_(const ir::ConstFloatPtr& op) {
Expand Down
Loading
Loading