diff --git a/python/pypto/ir/op/tile_ops.py b/python/pypto/ir/op/tile_ops.py index 9c7ce7355..a1fe6d241 100644 --- a/python/pypto/ir/op/tile_ops.py +++ b/python/pypto/ir/op/tile_ops.py @@ -1919,6 +1919,7 @@ def slice( shape: Sequence[int | Expr] | _ir_core.MakeTuple, offset: Sequence[int | Expr] | _ir_core.MakeTuple, valid_shape: Sequence[int | Expr] | _ir_core.MakeTuple | None = None, + pad_value: PadValue | int | float | None = None, span: Span | None = None, ) -> Call: """Create a slice of a tile with static shape and optional valid shape. @@ -1929,6 +1930,13 @@ def slice( offset: Offset dimensions for the slice, or a MakeTuple valid_shape: Valid shape dimensions, or a MakeTuple. When omitted, shape is reused as the valid shape. + pad_value: Optional padding mode for out-of-valid-shape elements. + Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or + the literal sugars ``0``, ``math.inf``, ``-math.inf`` (normalized + via :func:`normalize_pad_value`). ``PadValue.null`` is passed + through unchanged and means "no padding". When omitted (``None``), + the kwarg is not forwarded — the deducer defaults to + ``PadValue.null``. span: Optional source span for debugging (auto-captured if not provided) Returns: @@ -1951,7 +1959,15 @@ def slice( ) args.append(valid_shape_tuple) - return _ir_core.create_op_call("tile.slice", args, {}, actual_span) + kwargs: dict[str, Any] = {} + if pad_value is not None: + # PadValue.null is a legal "no padding" signal for slice (unlike + # fillpad, which requires a real padding mode). Pass it through; + # normalize the rest via the shared helper so numeric sugar and + # validation match tile.fillpad exactly. + kwargs["pad_value"] = pad_value if pad_value is PadValue.null else normalize_pad_value(pad_value) + + return _ir_core.create_op_call("tile.slice", args, kwargs, actual_span) def reshape( diff --git a/python/pypto/language/op/tile_ops.py b/python/pypto/language/op/tile_ops.py index 031ce7024..298bde554 100644 --- a/python/pypto/language/op/tile_ops.py +++ b/python/pypto/language/op/tile_ops.py @@ -15,6 +15,7 @@ Accessed as ``pl.tile.*`` """ +import warnings from collections.abc import Sequence from typing import overload @@ -1214,6 +1215,7 @@ def slice( shape: Sequence[IntLike], offset: Sequence[IntLike], valid_shape: Sequence[IntLike] | None = None, + pad_value: PadValue | int | float | None = None, ) -> Tile: """Create a slice of a tile with static shape and optional valid shape. @@ -1223,10 +1225,26 @@ def slice( offset: Offset dimensions for the slice valid_shape: Valid shape dimensions. When omitted, shape is reused as the logical valid shape. + pad_value: Optional padding mode for out-of-valid-shape elements. + ``None`` or ``PadValue.null`` means no padding (the default). + Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or + the literal sugars ``0``, ``math.inf``, ``-math.inf`` (same + spelling as :func:`tile.fillpad`). Only meaningful when + ``valid_shape`` is smaller than ``shape``. Returns: Tile wrapping the slice operation """ + if pad_value is not None and pad_value is not PadValue.null and valid_shape is None: + warnings.warn( + f"tile.slice received pad_value={pad_value!r} but no valid_shape. " + f"pad_value has no effect unless valid_shape is smaller than shape. " + f"If you intend to narrow the valid region later via " + f"tile.set_validshape, you can ignore this warning; otherwise " + f"pass valid_shape=... to tile.slice.", + stacklevel=2, + ) + tile_expr = tile.unwrap() normalized_valid_shape = None if valid_shape is None else _normalize_intlike(valid_shape) call_expr = _ir_ops.slice( @@ -1234,6 +1252,7 @@ def slice( _normalize_intlike(shape), _normalize_intlike(offset), normalized_valid_shape, + pad_value=pad_value, ) return Tile(expr=call_expr) diff --git a/src/ir/op/tile_ops/transform.cpp b/src/ir/op/tile_ops/transform.cpp index a330d5b88..118e9c37b 100644 --- a/src/ir/op/tile_ops/transform.cpp +++ b/src/ir/op/tile_ops/transform.cpp @@ -188,6 +188,20 @@ TypePtr DeduceTileSliceType(const std::vector& args, tile_view.blayout = InferTileLayoutFromShape(new_shape); + // Read optional pad_value kwarg (default PadValue::null = no padding). + PadValue pad_value = PadValue::null; + for (const auto& [k, v] : kwargs) { + if (k != "pad_value") continue; + CHECK(v.type() == typeid(PadValue)) + << "tile.slice pad_value must be a PadValue enum, got " << v.type().name(); + pad_value = std::any_cast(v); + CHECK(pad_value == PadValue::null || pad_value == PadValue::zero || pad_value == PadValue::max || + pad_value == PadValue::min) + << "tile.slice pad_value has invalid enum value: " << static_cast(pad_value); + break; + } + tile_view.pad = pad_value; + return std::make_shared(new_shape, tile_type->dtype_, std::nullopt, tile_view); } @@ -298,6 +312,7 @@ REGISTER_OP("tile.slice") .add_argument("offset", "Offset dimensions (TupleType of ScalarType(INT64/UINT64/INDEX))") .add_argument("valid_shape", "Optional logical valid shape (TupleType of ScalarType(INT64/UINT64/INDEX))") .set_output_memory_inherit_input() + .set_attr("pad_value") .f_deduce_type([](const std::vector& args, const std::vector>& kwargs) { return DeduceTileSliceType(args, kwargs); diff --git a/tests/ut/ir/operators/test_op_registry.py b/tests/ut/ir/operators/test_op_registry.py index 220944c56..b09a4d784 100644 --- a/tests/ut/ir/operators/test_op_registry.py +++ b/tests/ut/ir/operators/test_op_registry.py @@ -485,6 +485,12 @@ def test_fillpad_kwarg_schema(): assert tile_fillpad_inplace_op.has_attr("pad_value") +def test_tile_slice_pad_value_kwarg_schema(): + """Test that tile.slice declares pad_value in its kwarg schema.""" + tile_slice_op = ir.get_op("tile.slice") + assert tile_slice_op.has_attr("pad_value") + + class TestOpMemorySpecRegistry: """Test that op memory specs are correctly registered and queryable.""" diff --git a/tests/ut/ir/operators/test_tile_ops.py b/tests/ut/ir/operators/test_tile_ops.py index 1a73a7a8e..5a207ffce 100644 --- a/tests/ut/ir/operators/test_tile_ops.py +++ b/tests/ut/ir/operators/test_tile_ops.py @@ -9,6 +9,8 @@ """Unit tests for tile operations.""" +import math + import pypto.language as pl import pytest from pypto import DataType, ir @@ -1017,6 +1019,115 @@ def test_tile_slice_rejects_dynamic_shape(self): with pytest.raises(ValueError, match="compile-time constant"): tile.slice(tile_var, [8, valid_n], [0, 0]) + @staticmethod + def _make_slice_tile_var(): + """Build a [16, 32] FP16 tile Var for slice pad_value tests.""" + span = ir.Span.unknown() + dim16 = ir.ConstInt(16, DataType.INT32, span) + dim32 = ir.ConstInt(32, DataType.INT32, span) + tile_type = ir.TileType([dim16, dim32], DataType.FP16) + return ir.Var("tile", tile_type, span) + + def test_tile_slice_with_pad_value_zero(self): + """tile.slice writes pad_value=zero to the output tile_view.pad.""" + tile_var = self._make_slice_tile_var() + call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.zero) + + assert isinstance(call, ir.Call) + assert call.op.name == "tile.slice" + result_type = call.type + assert isinstance(result_type, ir.TileType) + assert result_type.tile_view is not None + assert result_type.tile_view.pad == ir.PadValue.zero + assert len(result_type.tile_view.valid_shape) == 2 + assert isinstance(result_type.tile_view.valid_shape[0], ir.ConstInt) + assert result_type.tile_view.valid_shape[0].value == 8 + assert isinstance(result_type.tile_view.valid_shape[1], ir.ConstInt) + assert result_type.tile_view.valid_shape[1].value == 4 + + def test_tile_slice_with_pad_value_min(self): + """tile.slice writes pad_value=min to the output tile_view.pad.""" + tile_var = self._make_slice_tile_var() + call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.min) + + result_type = call.type + assert isinstance(result_type, ir.TileType) + assert result_type.tile_view is not None + assert result_type.tile_view.pad == ir.PadValue.min + + def test_tile_slice_with_pad_value_max(self): + """tile.slice writes pad_value=max to the output tile_view.pad.""" + tile_var = self._make_slice_tile_var() + call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.max) + + result_type = call.type + assert isinstance(result_type, ir.TileType) + assert result_type.tile_view is not None + assert result_type.tile_view.pad == ir.PadValue.max + + def test_tile_slice_default_pad_is_null(self): + """tile.slice without pad_value defaults to PadValue.null (backward compat).""" + tile_var = self._make_slice_tile_var() + call = tile.slice(tile_var, [8, 16], [0, 0]) + + result_type = call.type + assert isinstance(result_type, ir.TileType) + assert result_type.tile_view is not None + assert result_type.tile_view.pad == ir.PadValue.null + + def test_tile_slice_rejects_bad_pad_value(self): + """tile.slice rejects a non-PadValue pad_value kwarg via registry validation.""" + tile_var = self._make_slice_tile_var() + span = tile_var.span + shape_tuple = ir.MakeTuple( + [ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(16, DataType.INT32, span)], span + ) + offset_tuple = ir.MakeTuple( + [ir.ConstInt(0, DataType.INT32, span), ir.ConstInt(0, DataType.INT32, span)], span + ) + valid_shape_tuple = ir.MakeTuple( + [ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(4, DataType.INT32, span)], span + ) + with pytest.raises(TypeError, match="'pad_value'.*incompatible type"): + ir.create_op_call( + "tile.slice", + [tile_var, shape_tuple, offset_tuple, valid_shape_tuple], + {"pad_value": 5}, + span, + ) + + def test_tile_slice_accepts_numeric_sugar_pad_value(self): + """tile.slice maps 0 / math.inf / -math.inf onto PadValue zero/max/min.""" + tile_var = self._make_slice_tile_var() + for literal, expected_pad in [ + (0, ir.PadValue.zero), + (math.inf, ir.PadValue.max), + (-math.inf, ir.PadValue.min), + ]: + call = tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=literal) + result_type = call.type + assert isinstance(result_type, ir.TileType) + assert result_type.tile_view is not None + assert result_type.tile_view.pad == expected_pad + + def test_tile_slice_rejects_bad_numeric_pad_value_at_python_boundary(self): + """Non-sugar numeric values are rejected at the Python API boundary.""" + tile_var = self._make_slice_tile_var() + with pytest.raises(ValueError, match="fillpad pad_value"): + tile.slice(tile_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=5) + + def test_tile_slice_pad_without_valid_shape_warns(self): + """DSL emits a UserWarning when pad_value is set but valid_shape is None.""" + span = ir.Span.unknown() + dim16 = ir.ConstInt(16, DataType.INT32, span) + dim32 = ir.ConstInt(32, DataType.INT32, span) + tile_type = ir.TileType([dim16, dim32], DataType.FP16) + tile_var = ir.Var("tile", tile_type, span) + + tile_arg = pl.Tile(expr=tile_var) + with pytest.warns(UserWarning, match="pad_value has no effect"): + pl.tile.slice(tile_arg, [8, 16], [0, 0], pad_value=pl.PadValue.zero) + def test_tile_reshape(self): """Test tile.reshape operation.""" span = ir.Span.unknown()