Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion python/pypto/ir/op/tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,7 @@ def slice(
shape: Sequence[int | Expr] | _ir_core.MakeTuple,
offset: Sequence[int | Expr] | _ir_core.MakeTuple,
valid_shape: Sequence[int | Expr] | _ir_core.MakeTuple | None = None,
pad_value: PadValue | None = None,
Comment thread
Hzfengsy marked this conversation as resolved.
Outdated
span: Span | None = None,
) -> Call:
"""Create a slice of a tile with static shape and optional valid shape.
Expand All @@ -1929,6 +1930,9 @@ def slice(
offset: Offset dimensions for the slice, or a MakeTuple
valid_shape: Valid shape dimensions, or a MakeTuple. When omitted, shape
is reused as the valid shape.
pad_value: Optional padding mode for out-of-valid-shape elements
(PadValue.zero, PadValue.max, or PadValue.min). When omitted, no
padding is applied (equivalent to PadValue.null).
Comment thread
Hzfengsy marked this conversation as resolved.
Outdated
span: Optional source span for debugging (auto-captured if not provided)

Returns:
Expand All @@ -1951,7 +1955,11 @@ def slice(
)
args.append(valid_shape_tuple)

return _ir_core.create_op_call("tile.slice", args, {}, actual_span)
kwargs: dict[str, Any] = {}
if pad_value is not None:
kwargs["pad_value"] = pad_value
Comment thread
Hzfengsy marked this conversation as resolved.
Outdated

return _ir_core.create_op_call("tile.slice", args, kwargs, actual_span)


def reshape(
Expand Down
16 changes: 16 additions & 0 deletions python/pypto/language/op/tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Accessed as ``pl.tile.*``
"""

import warnings
from collections.abc import Sequence
from typing import overload

Expand Down Expand Up @@ -1214,6 +1215,7 @@ def slice(
shape: Sequence[IntLike],
offset: Sequence[IntLike],
valid_shape: Sequence[IntLike] | None = None,
pad_value: PadValue | None = None,
Comment thread
Hzfengsy marked this conversation as resolved.
Outdated
) -> Tile:
"""Create a slice of a tile with static shape and optional valid shape.

Expand All @@ -1223,17 +1225,31 @@ def slice(
offset: Offset dimensions for the slice
valid_shape: Valid shape dimensions. When omitted, shape is reused as the
logical valid shape.
pad_value: Optional padding mode (PadValue.zero, PadValue.max, or
PadValue.min) applied to out-of-valid-shape elements. Only
meaningful when ``valid_shape`` is smaller than ``shape``.
Comment thread
Hzfengsy marked this conversation as resolved.
Outdated

Returns:
Tile wrapping the slice operation
"""
if pad_value is not None and pad_value != PadValue.null and valid_shape is None:
warnings.warn(
f"tile.slice received pad_value={pad_value!r} but no valid_shape. "
f"pad_value has no effect unless valid_shape is smaller than shape. "
f"If you intend to narrow the valid region later via "
f"tile.set_validshape, you can ignore this warning; otherwise "
f"pass valid_shape=... to tile.slice.",
stacklevel=2,
)

tile_expr = tile.unwrap()
normalized_valid_shape = None if valid_shape is None else _normalize_intlike(valid_shape)
call_expr = _ir_ops.slice(
tile_expr,
_normalize_intlike(shape),
_normalize_intlike(offset),
normalized_valid_shape,
pad_value=pad_value,
)
return Tile(expr=call_expr)

Expand Down
15 changes: 15 additions & 0 deletions src/ir/op/tile_ops/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ TypePtr DeduceTileSliceType(const std::vector<ExprPtr>& args,

tile_view.blayout = InferTileLayoutFromShape(new_shape);

// Read optional pad_value kwarg (default PadValue::null = no padding).
PadValue pad_value = PadValue::null;
for (const auto& [k, v] : kwargs) {
if (k != "pad_value") continue;
CHECK(v.type() == typeid(PadValue))
<< "tile.slice pad_value must be a PadValue enum, got " << v.type().name();
pad_value = std::any_cast<PadValue>(v);
CHECK(pad_value == PadValue::null || pad_value == PadValue::zero || pad_value == PadValue::max ||
pad_value == PadValue::min)
<< "tile.slice pad_value has invalid enum value: " << static_cast<int>(pad_value);
break;
}
tile_view.pad = pad_value;

return std::make_shared<TileType>(new_shape, tile_type->dtype_, std::nullopt, tile_view);
}

Expand Down Expand Up @@ -298,6 +312,7 @@ REGISTER_OP("tile.slice")
.add_argument("offset", "Offset dimensions (TupleType of ScalarType(INT64/UINT64/INDEX))")
.add_argument("valid_shape", "Optional logical valid shape (TupleType of ScalarType(INT64/UINT64/INDEX))")
.set_output_memory_inherit_input()
.set_attr<PadValue>("pad_value")
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceTileSliceType(args, kwargs);
Expand Down
6 changes: 6 additions & 0 deletions tests/ut/ir/operators/test_op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ def test_fillpad_kwarg_schema():
assert tile_fillpad_inplace_op.has_attr("pad_value")


def test_tile_slice_pad_value_kwarg_schema():
"""Test that tile.slice declares pad_value in its kwarg schema."""
tile_slice_op = ir.get_op("tile.slice")
assert tile_slice_op.has_attr("pad_value")


class TestOpMemorySpecRegistry:
"""Test that op memory specs are correctly registered and queryable."""

Expand Down
89 changes: 89 additions & 0 deletions tests/ut/ir/operators/test_tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,95 @@ def test_tile_slice_rejects_dynamic_shape(self):
with pytest.raises(ValueError, match="compile-time constant"):
tile.slice(tile_var, [8, valid_n], [0, 0])

@staticmethod
def _make_slice_tile_var():
"""Build a [16, 32] FP16 tile Var for slice pad_value tests."""
span = ir.Span.unknown()
dim16 = ir.ConstInt(16, DataType.INT32, span)
dim32 = ir.ConstInt(32, DataType.INT32, span)
tile_type = ir.TileType([dim16, dim32], DataType.FP16)
return ir.Var("tile", tile_type, span)

def test_tile_slice_with_pad_value_zero(self):
"""tile.slice writes pad_value=zero to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.zero)

assert isinstance(call, ir.Call)
assert call.op.name == "tile.slice"
result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.zero
assert len(result_type.tile_view.valid_shape) == 2
assert isinstance(result_type.tile_view.valid_shape[0], ir.ConstInt)
assert result_type.tile_view.valid_shape[0].value == 8
assert isinstance(result_type.tile_view.valid_shape[1], ir.ConstInt)
assert result_type.tile_view.valid_shape[1].value == 4

def test_tile_slice_with_pad_value_min(self):
"""tile.slice writes pad_value=min to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.min)

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.min

def test_tile_slice_with_pad_value_max(self):
"""tile.slice writes pad_value=max to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.max)

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.max

def test_tile_slice_default_pad_is_null(self):
"""tile.slice without pad_value defaults to PadValue.null (backward compat)."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0])

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.null

def test_tile_slice_rejects_bad_pad_value(self):
"""tile.slice rejects a non-PadValue pad_value kwarg via registry validation."""
tile_var = self._make_slice_tile_var()
span = tile_var.span
shape_tuple = ir.MakeTuple(
[ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(16, DataType.INT32, span)], span
)
offset_tuple = ir.MakeTuple(
[ir.ConstInt(0, DataType.INT32, span), ir.ConstInt(0, DataType.INT32, span)], span
)
valid_shape_tuple = ir.MakeTuple(
[ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(4, DataType.INT32, span)], span
)
with pytest.raises(TypeError, match="'pad_value'.*incompatible type"):
ir.create_op_call(
"tile.slice",
[tile_var, shape_tuple, offset_tuple, valid_shape_tuple],
{"pad_value": 5},
span,
)

def test_tile_slice_pad_without_valid_shape_warns(self):
"""DSL emits a UserWarning when pad_value is set but valid_shape is None."""
span = ir.Span.unknown()
dim16 = ir.ConstInt(16, DataType.INT32, span)
dim32 = ir.ConstInt(32, DataType.INT32, span)
tile_type = ir.TileType([dim16, dim32], DataType.FP16)
tile_var = ir.Var("tile", tile_type, span)

tile_arg = pl.Tile(expr=tile_var)
with pytest.warns(UserWarning, match="pad_value has no effect"):
pl.tile.slice(tile_arg, [8, 16], [0, 0], pad_value=pl.PadValue.zero)

def test_tile_reshape(self):
"""Test tile.reshape operation."""
span = ir.Span.unknown()
Expand Down
Loading