-
Notifications
You must be signed in to change notification settings - Fork 58
feat(op): implement tensor.scatter_ element-level scatter #898
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -940,3 +940,74 @@ def scatter_update( | |
| op_args: list[Expr] = [input, index, src] | ||
| kwargs: dict[str, Any] = {"dim": dim_val} | ||
| return _ir_core.create_op_call("tensor.scatter_update", op_args, kwargs, actual_span) | ||
|
|
||
|
|
||
| def scatter_( | ||
| input: Expr, | ||
| *args: Expr | int | float, | ||
| dim: int | Expr | None = None, | ||
| index: Expr | None = None, | ||
| src: Expr | float | int | None = None, | ||
| reduce: str | None = None, | ||
| span: Span | None = None, | ||
| ) -> Call: | ||
| """Element-level scatter into tensor along a dimension. | ||
|
|
||
| For each position (i₀,…,iₙ) in index, sets: | ||
| input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ] | ||
|
|
||
| Follows PyTorch ``torch.Tensor.scatter_`` semantics. | ||
|
|
||
| Accepts call forms: | ||
| - scatter_(input, dim, index, src) | ||
| - scatter_(input, dim, index, src=1.0) | ||
|
|
||
| Args: | ||
| input: Destination tensor (N-D). | ||
| dim: Dimension along which to scatter. | ||
| index: Index tensor (same rank as input, integer dtype). | ||
| src: Source tensor (same shape as index) or scalar value. | ||
| span: Optional source span for debugging (auto-captured if not provided). | ||
|
|
||
| Returns: | ||
| Call expression returning the updated input tensor. | ||
| """ | ||
| if len(args) == 3 and dim is None and index is None and src is None: | ||
| dim, index, src = args | ||
| elif len(args) == 2 and dim is not None and index is None and src is None: | ||
| index, src = args | ||
| elif len(args) == 1 and dim is None and index is not None and src is not None: | ||
| dim = args[0] | ||
| elif len(args) != 0: | ||
| raise TypeError( | ||
| "scatter_ expects (input, dim, index, src), " | ||
| "(input, index, src, dim=...), or (input, dim, index=..., src=...)" | ||
| ) | ||
|
|
||
| if dim is None or index is None or src is None: | ||
| raise TypeError("scatter_ requires input, dim, index, and src") | ||
|
|
||
| actual_span = _get_span_or_capture(span) | ||
| if isinstance(dim, ConstInt): | ||
| dim_val = int(dim.value) | ||
| elif isinstance(dim, int): | ||
| dim_val = dim | ||
| else: | ||
| raise TypeError(f"dim must be int or ConstInt, got {type(dim)}") | ||
|
|
||
| if not isinstance(index, Expr): | ||
| raise TypeError(f"index must be Expr, got {type(index)}") | ||
|
|
||
| # src can be Expr or scalar (int → ConstInt, float → ConstFloat) | ||
| if isinstance(src, int): | ||
| src = ConstInt(src, DataType.INT32, actual_span) | ||
| elif isinstance(src, float): | ||
| src = ConstFloat(src, DataType.FP32, actual_span) | ||
| elif not isinstance(src, Expr): | ||
| raise TypeError(f"src must be Expr or scalar, got {type(src)}") | ||
|
|
||
| op_args: list[Expr] = [input, index, src] | ||
| kwargs: dict[str, Any] = {"dim": dim_val} | ||
| if reduce is not None: | ||
| kwargs["reduce"] = reduce | ||
| return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span) | ||
|
Comment on lines
+945
to
+1013
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The def scatter_(
input: Expr,
*args: Expr | int | float,
dim: int | Expr | None = None,
index: Expr | None = None,
src: Expr | float | int | None = None,
reduce: str | None = None,
span: Span | None = None,
) -> Call:
"""Element-level scatter into tensor along a dimension.
For each position (i₀,…,iₙ) in index, sets:
input[i₀]…[i_{d-1}][ index[i₀…iₙ] ][i_{d+1}]…[iₙ] = src[i₀…iₙ]
Follows PyTorch ``torch.Tensor.scatter_`` semantics.
Accepts call forms:
- scatter_(input, dim, index, src)
- scatter_(input, dim, index, src=1.0)
Args:
input: Destination tensor (N-D).
dim: Dimension along which to scatter.
index: Index tensor (same rank as input, integer dtype).
src: Source tensor (same shape as index) or scalar value.
reduce: Optional reduction mode ("add" or "multiply").
span: Optional source span for debugging (auto-captured if not provided).
Returns:
Call expression returning the updated input tensor.
"""
if len(args) == 3 and dim is None and index is None and src is None:
dim, index, src = args
elif len(args) == 2 and dim is not None and index is None and src is None:
index, src = args
elif len(args) == 1 and dim is None and index is not None and src is not None:
dim = args[0]
elif len(args) != 0:
raise TypeError(
"scatter_ expects (input, dim, index, src), "
"(input, index, src, dim=...), or (input, dim, index=..., src=...)"
)
if dim is None or index is None or src is None:
raise TypeError("scatter_ requires input, dim, index, and src")
actual_span = _get_span_or_capture(span)
if isinstance(dim, ConstInt):
dim_val = int(dim.value)
elif isinstance(dim, int):
dim_val = dim
else:
raise TypeError(f"dim must be int or ConstInt, got {type(dim)}")
if not isinstance(index, Expr):
raise TypeError(f"index must be Expr, got {type(index)}")
# src can be Expr or scalar (int/float → ConstFloat)
if isinstance(src, (int, float)):
src = ConstFloat(float(src), DataType.FP32, actual_span)
elif not isinstance(src, Expr):
raise TypeError(f"src must be Expr or scalar, got {type(src)}")
op_args: list[Expr] = [input, index, src]
kwargs: dict[str, Any] = {"dim": dim_val}
if reduce is not None:
kwargs["reduce"] = reduce
return _ir_core.create_op_call("tensor.scatter_", op_args, kwargs, actual_span)References
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Fixed: pl.scatter_() now passes reduce kwarg through both the IR and language layers. |
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -59,6 +59,7 @@ | |||||
| "reshape", | ||||||
| "transpose", | ||||||
| "scatter_update", | ||||||
| "scatter_", | ||||||
| ] | ||||||
|
|
||||||
| from pypto.ir.op import tensor_ops as _ir_ops | ||||||
|
|
@@ -779,3 +780,38 @@ def scatter_update( | |||||
| """ | ||||||
| call_expr = _ir_ops.scatter_update(input.unwrap(), dim, index.unwrap(), src.unwrap()) | ||||||
| return Tensor(expr=call_expr) | ||||||
|
|
||||||
|
|
||||||
| def scatter_( | ||||||
| input: Tensor, | ||||||
| dim: int, | ||||||
| index: Tensor, | ||||||
| src: float | int | Tensor, | ||||||
| *, | ||||||
| reduce: str | None = None, | ||||||
| ) -> Tensor: | ||||||
| """Element-level scatter: write src values into input at positions given by index along dim. | ||||||
|
|
||||||
| For each element position (i0,...,in) in index, sets: | ||||||
| input[i0]...[i_{d-1}][index[i0...in]][i_{d+1}]...[in] = src[i0...in] | ||||||
|
|
||||||
| Supports arbitrary rank and any valid dim in [-rank, rank). | ||||||
| src can be a tensor (same shape as index) or a scalar value. | ||||||
|
|
||||||
| Args: | ||||||
| input: Destination tensor (N-D) | ||||||
| dim: Dimension along which to scatter | ||||||
| index: Index tensor (N-D, same rank as input) of integer dtype | ||||||
| src: Source tensor (same shape as index) or scalar value | ||||||
| reduce: Optional reduce mode ("add" or "multiply") | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chain🏁 Script executed: #!/bin/bash
# Check if reduce validation exists anywhere in the scatter implementation chain
rg -n "reduce" --type cpp src/ir/op/tensor_ops/scatter.cpp | head -20
rg -n "reduce.*add\|reduce.*multiply" --type cpp src/ir/transforms/op_conversion_registry.cpp | head -10Repository: hw-native-sys/pypto Length of output: 103 🏁 Script executed: # Check the context around line 806 in tensor_ops.py
cd python/pypto/language/op && sed -n '795,820p' tensor_ops.pyRepository: hw-native-sys/pypto Length of output: 957 🏁 Script executed: # Look at the full scatter function implementation
cd python/pypto/language/op && grep -n "def scatter" tensor_ops.pyRepository: hw-native-sys/pypto Length of output: 104 🏁 Script executed: # Check the C++ scatter implementation for reduce handling
cat src/ir/op/tensor_ops/scatter.cpp | head -150Repository: hw-native-sys/pypto Length of output: 5645 🏁 Script executed: # Search for where the reduce parameter is used in C++ code
rg -B5 -A10 "reduce" --type cpp src/ir/op/tensor_ops/scatter.cppRepository: hw-native-sys/pypto Length of output: 661 🏁 Script executed: # Search for reduce parameter usage in tests or examples
rg -n "reduce" --type python python/pypto | grep -E "(test|example)" | head -20Repository: hw-native-sys/pypto Length of output: 93 🏁 Script executed: # Search for scatter tests
fd -e py "test.*scatter" python/Repository: hw-native-sys/pypto Length of output: 45 🏁 Script executed: # Look for where scatter_ is actually called with reduce parameter
rg -B2 -A2 'scatter_.*reduce' --type python python/Repository: hw-native-sys/pypto Length of output: 93 🏁 Script executed: # Search for scatter tests and reduce usage
fd -type f -name "*scatter*" python/Repository: hw-native-sys/pypto Length of output: 233 🏁 Script executed: # Search for reduce parameter usage in Python files
rg "reduce\s*=" --type py python/ | head -20Repository: hw-native-sys/pypto Length of output: 396 🏁 Script executed: # Look at the actual backend to understand what reduce modes are supported
rg -n "reduce" src/ | grep -E "(add|multiply)" | head -20Repository: hw-native-sys/pypto Length of output: 541 🏁 Script executed: # Check if there are any comments or documentation about valid reduce modes
rg -B3 -A3 "reduce.*add.*multiply" --type py --type cpp --type txtRepository: hw-native-sys/pypto Length of output: 1343 Update docstring to include all valid The Python docstring lists only - reduce: Optional reduce mode ("add" or "multiply")
+ reduce: Optional reduce mode ("none", "add", or "multiply")📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||
|
|
||||||
| Returns: | ||||||
| Tensor wrapping the scatter_ operation | ||||||
| """ | ||||||
| src_expr: float | int | Expr | ||||||
| if isinstance(src, (Tensor, Scalar)): | ||||||
| src_expr = src.unwrap() | ||||||
| else: | ||||||
| src_expr = src | ||||||
| call_expr = _ir_ops.scatter_(input.unwrap(), dim, index.unwrap(), src_expr, reduce=reduce) | ||||||
| return Tensor(expr=call_expr) | ||||||
|
coderabbitai[bot] marked this conversation as resolved.
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -717,22 +717,46 @@ static std::string MakeTileAllocCodegenPTO(const CallPtr& op, codegen::CodegenBa | |
| return ""; // No MLIR emission - pto.alloc_tile generated from MemRefs in TileTypes | ||
| } | ||
|
|
||
| // Compute a row-major flat offset string from a MakeTuple of indices and the shape of the container. | ||
| // Compute a row-major flat offset from a MakeTuple of indices and the shape, | ||
| // emitting proper arith.muli / arith.addi SSA operations. | ||
| // Returns the SSA name of the final flat-offset value (index type). | ||
| static std::string ComputeFlatOffsetPTO(const ir::MakeTuplePtr& indices_tuple, | ||
| const std::vector<ir::ExprPtr>& shape, codegen::PTOCodegen& codegen) { | ||
| const auto& indices = indices_tuple->elements_; | ||
| INTERNAL_CHECK(indices.size() == shape.size()) | ||
| << "Index count (" << indices.size() << ") must match shape rank (" << shape.size() << ")"; | ||
|
|
||
| std::ostringstream idx_oss; | ||
| // Helper: ensure an index element SSA value has `index` type. | ||
| // If the expression is a non-index integer (e.g. i32 from tile.read on an | ||
| // INT32 tile), emit arith.index_cast to convert it. | ||
| auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string { | ||
| if (auto var = ir::As<ir::Var>(expr)) { | ||
| return codegen.EmitCastToIndex(var, ssa); | ||
| } | ||
| return ssa; | ||
| }; | ||
|
|
||
| // For each dimension i, compute: index[i] * (shape[i+1] * shape[i+2] * ... * shape[rank-1]) | ||
| // then sum all terms with arith.addi. | ||
| std::string accumulator; | ||
| for (size_t i = 0; i < indices.size(); ++i) { | ||
| if (i > 0) idx_oss << " + "; | ||
| idx_oss << codegen.GetExprAsCode(indices[i]); | ||
| std::string term = ensure_index(indices[i], codegen.GetExprAsCode(indices[i])); | ||
| // Multiply by each trailing dimension size | ||
| for (size_t j = i + 1; j < shape.size(); ++j) { | ||
| idx_oss << " * " << codegen.GetExprAsCode(shape[j]); | ||
| std::string dim = codegen.GetExprAsCode(shape[j]); | ||
| std::string tmp = codegen.NewTemp(); | ||
| codegen.Emit(tmp + " = arith.muli " + term + ", " + dim + " : index"); | ||
|
Comment on lines
+729
to
+748
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Cast all non- Line 733 only normalizes plain 🔧 Suggested fix- auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
- if (auto var = ir::As<ir::Var>(expr)) {
- return codegen.EmitCastToIndex(var, ssa);
- }
- return ssa;
- };
+ auto ensure_index = [&](const ir::ExprPtr& expr, const std::string& ssa) -> std::string {
+ auto scalar_type = As<ScalarType>(expr->GetType());
+ INTERNAL_CHECK(scalar_type) << "flat index expression must be scalar";
+ if (scalar_type->dtype_ == DataType::INDEX) {
+ return ssa;
+ }
+ CHECK(!scalar_type->dtype_.IsFloat()) << "flat index expression must be integer/index typed";
+ std::string idx_ssa = codegen.NewTemp();
+ codegen.Emit(idx_ssa + " = arith.index_cast " + ssa + " : " +
+ codegen.GetTypeString(scalar_type->dtype_) + " to index");
+ return idx_ssa;
+ };🤖 Prompt for AI Agents |
||
| term = tmp; | ||
| } | ||
| if (accumulator.empty()) { | ||
| accumulator = term; | ||
| } else { | ||
| std::string tmp = codegen.NewTemp(); | ||
| codegen.Emit(tmp + " = arith.addi " + accumulator + ", " + term + " : index"); | ||
| accumulator = tmp; | ||
| } | ||
| } | ||
| return idx_oss.str(); | ||
| return accumulator; | ||
| } | ||
|
|
||
| // Get or emit a flat offset SSA value for a MakeTuple of indices and shape. | ||
|
|
@@ -932,6 +956,15 @@ static std::string MakeTensorDimCodegenPTO(const CallPtr& op, codegen::CodegenBa | |
| return ""; | ||
| } | ||
|
|
||
| static std::string MakeSystemBarrierCodegenPTO(const std::string& pipe_name, const CallPtr& op, | ||
| codegen::CodegenBase& codegen_base) { | ||
| CHECK(op->args_.empty()) << "system.barrier_" << pipe_name << " expects 0 arguments, got " | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 为什么要插入barrier,这个工作应该是ptoas做的 |
||
| << op->args_.size(); | ||
| auto& codegen = dynamic_cast<codegen::PTOCodegen&>(codegen_base); | ||
| codegen.Emit("pto.barrier #pto.pipe<" + pipe_name + ">"); | ||
| return ""; | ||
| } | ||
|
|
||
| // ============================================================================ | ||
| // Cross-Core Communication Operations (TPUSH/TPOP) | ||
| // ============================================================================ | ||
|
|
@@ -1334,6 +1367,15 @@ void RegisterPTOOps(Backend& backend, const std::unordered_set<std::string>& exc | |
| reg("tile.cast", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||
| return MakeTileCvtCodegenPTO("pto.tcvt", op, codegen); | ||
| }); | ||
| reg("system.bar_v", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||
| return MakeSystemBarrierCodegenPTO("PIPE_V", op, codegen); | ||
| }); | ||
| reg("system.bar_m", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||
| return MakeSystemBarrierCodegenPTO("PIPE_M", op, codegen); | ||
| }); | ||
| reg("system.bar_all", [](const ir::CallPtr& op, codegen::CodegenBase& codegen) { | ||
| return MakeSystemBarrierCodegenPTO("PIPE_ALL", op, codegen); | ||
| }); | ||
| // tile.full (TEXPANDS): output is row_major per ISA | ||
| if (exclude_ops.count("tile.full") == 0) { | ||
| backend.RegisterOp("tile.full") | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.