Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion python/pypto/ir/op/tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1919,6 +1919,7 @@ def slice(
shape: Sequence[int | Expr] | _ir_core.MakeTuple,
offset: Sequence[int | Expr] | _ir_core.MakeTuple,
valid_shape: Sequence[int | Expr] | _ir_core.MakeTuple | None = None,
pad_value: PadValue | int | float | None = None,
span: Span | None = None,
) -> Call:
"""Create a slice of a tile with static shape and optional valid shape.
Expand All @@ -1929,6 +1930,13 @@ def slice(
offset: Offset dimensions for the slice, or a MakeTuple
valid_shape: Valid shape dimensions, or a MakeTuple. When omitted, shape
is reused as the valid shape.
pad_value: Optional padding mode for out-of-valid-shape elements.
Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or
the literal sugars ``0``, ``math.inf``, ``-math.inf`` (normalized
via :func:`normalize_pad_value`). ``PadValue.null`` is passed
through unchanged and means "no padding". When omitted (``None``),
the kwarg is not forwarded — the deducer defaults to
``PadValue.null``.
span: Optional source span for debugging (auto-captured if not provided)

Returns:
Expand All @@ -1951,7 +1959,15 @@ def slice(
)
args.append(valid_shape_tuple)

return _ir_core.create_op_call("tile.slice", args, {}, actual_span)
kwargs: dict[str, Any] = {}
if pad_value is not None:
# PadValue.null is a legal "no padding" signal for slice (unlike
# fillpad, which requires a real padding mode). Pass it through;
# normalize the rest via the shared helper so numeric sugar and
# validation match tile.fillpad exactly.
kwargs["pad_value"] = pad_value if pad_value is PadValue.null else normalize_pad_value(pad_value)

return _ir_core.create_op_call("tile.slice", args, kwargs, actual_span)


def reshape(
Expand Down
19 changes: 19 additions & 0 deletions python/pypto/language/op/tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
Accessed as ``pl.tile.*``
"""

import warnings
from collections.abc import Sequence
from typing import overload

Expand Down Expand Up @@ -1214,6 +1215,7 @@ def slice(
shape: Sequence[IntLike],
offset: Sequence[IntLike],
valid_shape: Sequence[IntLike] | None = None,
pad_value: PadValue | int | float | None = None,
) -> Tile:
"""Create a slice of a tile with static shape and optional valid shape.

Expand All @@ -1223,17 +1225,34 @@ def slice(
offset: Offset dimensions for the slice
valid_shape: Valid shape dimensions. When omitted, shape is reused as the
logical valid shape.
pad_value: Optional padding mode for out-of-valid-shape elements.
``None`` or ``PadValue.null`` means no padding (the default).
Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or
the literal sugars ``0``, ``math.inf``, ``-math.inf`` (same
spelling as :func:`tile.fillpad`). Only meaningful when
``valid_shape`` is smaller than ``shape``.

Returns:
Tile wrapping the slice operation
"""
if pad_value is not None and pad_value is not PadValue.null and valid_shape is None:
warnings.warn(
f"tile.slice received pad_value={pad_value!r} but no valid_shape. "
f"pad_value has no effect unless valid_shape is smaller than shape. "
f"If you intend to narrow the valid region later via "
f"tile.set_validshape, you can ignore this warning; otherwise "
f"pass valid_shape=... to tile.slice.",
stacklevel=2,
)

tile_expr = tile.unwrap()
normalized_valid_shape = None if valid_shape is None else _normalize_intlike(valid_shape)
call_expr = _ir_ops.slice(
tile_expr,
_normalize_intlike(shape),
_normalize_intlike(offset),
normalized_valid_shape,
pad_value=pad_value,
)
return Tile(expr=call_expr)

Expand Down
15 changes: 15 additions & 0 deletions src/ir/op/tile_ops/transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,20 @@ TypePtr DeduceTileSliceType(const std::vector<ExprPtr>& args,

tile_view.blayout = InferTileLayoutFromShape(new_shape);

// Read optional pad_value kwarg (default PadValue::null = no padding).
PadValue pad_value = PadValue::null;
for (const auto& [k, v] : kwargs) {
if (k != "pad_value") continue;
CHECK(v.type() == typeid(PadValue))
<< "tile.slice pad_value must be a PadValue enum, got " << v.type().name();
pad_value = std::any_cast<PadValue>(v);
CHECK(pad_value == PadValue::null || pad_value == PadValue::zero || pad_value == PadValue::max ||
pad_value == PadValue::min)
<< "tile.slice pad_value has invalid enum value: " << static_cast<int>(pad_value);
break;
}
tile_view.pad = pad_value;

return std::make_shared<TileType>(new_shape, tile_type->dtype_, std::nullopt, tile_view);
}

Expand Down Expand Up @@ -298,6 +312,7 @@ REGISTER_OP("tile.slice")
.add_argument("offset", "Offset dimensions (TupleType of ScalarType(INT64/UINT64/INDEX))")
.add_argument("valid_shape", "Optional logical valid shape (TupleType of ScalarType(INT64/UINT64/INDEX))")
.set_output_memory_inherit_input()
.set_attr<PadValue>("pad_value")
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceTileSliceType(args, kwargs);
Expand Down
6 changes: 6 additions & 0 deletions tests/ut/ir/operators/test_op_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,12 @@ def test_fillpad_kwarg_schema():
assert tile_fillpad_inplace_op.has_attr("pad_value")


def test_tile_slice_pad_value_kwarg_schema():
"""Test that tile.slice declares pad_value in its kwarg schema."""
tile_slice_op = ir.get_op("tile.slice")
assert tile_slice_op.has_attr("pad_value")


class TestOpMemorySpecRegistry:
"""Test that op memory specs are correctly registered and queryable."""

Expand Down
111 changes: 111 additions & 0 deletions tests/ut/ir/operators/test_tile_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

"""Unit tests for tile operations."""

import math

import pypto.language as pl
import pytest
from pypto import DataType, ir
Expand Down Expand Up @@ -1017,6 +1019,115 @@ def test_tile_slice_rejects_dynamic_shape(self):
with pytest.raises(ValueError, match="compile-time constant"):
tile.slice(tile_var, [8, valid_n], [0, 0])

@staticmethod
def _make_slice_tile_var():
"""Build a [16, 32] FP16 tile Var for slice pad_value tests."""
span = ir.Span.unknown()
dim16 = ir.ConstInt(16, DataType.INT32, span)
dim32 = ir.ConstInt(32, DataType.INT32, span)
tile_type = ir.TileType([dim16, dim32], DataType.FP16)
return ir.Var("tile", tile_type, span)

def test_tile_slice_with_pad_value_zero(self):
"""tile.slice writes pad_value=zero to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.zero)

assert isinstance(call, ir.Call)
assert call.op.name == "tile.slice"
result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.zero
assert len(result_type.tile_view.valid_shape) == 2
assert isinstance(result_type.tile_view.valid_shape[0], ir.ConstInt)
assert result_type.tile_view.valid_shape[0].value == 8
assert isinstance(result_type.tile_view.valid_shape[1], ir.ConstInt)
assert result_type.tile_view.valid_shape[1].value == 4

def test_tile_slice_with_pad_value_min(self):
"""tile.slice writes pad_value=min to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.min)

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.min

def test_tile_slice_with_pad_value_max(self):
"""tile.slice writes pad_value=max to the output tile_view.pad."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.max)

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.max

def test_tile_slice_default_pad_is_null(self):
"""tile.slice without pad_value defaults to PadValue.null (backward compat)."""
tile_var = self._make_slice_tile_var()
call = tile.slice(tile_var, [8, 16], [0, 0])

result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == ir.PadValue.null

def test_tile_slice_rejects_bad_pad_value(self):
"""tile.slice rejects a non-PadValue pad_value kwarg via registry validation."""
tile_var = self._make_slice_tile_var()
span = tile_var.span
shape_tuple = ir.MakeTuple(
[ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(16, DataType.INT32, span)], span
)
offset_tuple = ir.MakeTuple(
[ir.ConstInt(0, DataType.INT32, span), ir.ConstInt(0, DataType.INT32, span)], span
)
valid_shape_tuple = ir.MakeTuple(
[ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(4, DataType.INT32, span)], span
)
with pytest.raises(TypeError, match="'pad_value'.*incompatible type"):
ir.create_op_call(
"tile.slice",
[tile_var, shape_tuple, offset_tuple, valid_shape_tuple],
{"pad_value": 5},
span,
)

def test_tile_slice_accepts_numeric_sugar_pad_value(self):
"""tile.slice maps 0 / math.inf / -math.inf onto PadValue zero/max/min."""
tile_var = self._make_slice_tile_var()
for literal, expected_pad in [
(0, ir.PadValue.zero),
(math.inf, ir.PadValue.max),
(-math.inf, ir.PadValue.min),
]:
call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=literal)
result_type = call.type
assert isinstance(result_type, ir.TileType)
assert result_type.tile_view is not None
assert result_type.tile_view.pad == expected_pad

def test_tile_slice_rejects_bad_numeric_pad_value_at_python_boundary(self):
"""Non-sugar numeric values are rejected at the Python API boundary."""
tile_var = self._make_slice_tile_var()
with pytest.raises(ValueError, match="fillpad pad_value"):
tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=5)

def test_tile_slice_pad_without_valid_shape_warns(self):
"""DSL emits a UserWarning when pad_value is set but valid_shape is None."""
span = ir.Span.unknown()
dim16 = ir.ConstInt(16, DataType.INT32, span)
dim32 = ir.ConstInt(32, DataType.INT32, span)
tile_type = ir.TileType([dim16, dim32], DataType.FP16)
tile_var = ir.Var("tile", tile_type, span)

tile_arg = pl.Tile(expr=tile_var)
with pytest.warns(UserWarning, match="pad_value has no effect"):
pl.tile.slice(tile_arg, [8, 16], [0, 0], pad_value=pl.PadValue.zero)

def test_tile_reshape(self):
"""Test tile.reshape operation."""
span = ir.Span.unknown()
Expand Down
Loading