diff --git a/docs/en/dev/ir/02-types.md b/docs/en/dev/ir/02-types.md index 27f2e50bd..a6c7af5ef 100644 --- a/docs/en/dev/ir/02-types.md +++ b/docs/en/dev/ir/02-types.md @@ -50,6 +50,11 @@ tensor_with_view = ir.TensorType([128, 256], DataType.FP32, memref=None, tensor_ # With valid_shape tensor_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128]) +# With pad mode for out-of-valid-shape accesses (symmetric with TileView) +tensor_view = ir.TensorView( + stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128], pad=ir.PadValue.zero +) + # Different layouts nd_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND) # ND layout dn_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.DN) # DN layout @@ -70,6 +75,15 @@ tensor_with_both = ir.TensorType([128, 256], DataType.FP16, memref=memref, tenso - `DN`: DN layout - `NZ`: NZ layout +**TensorView fields:** + +- `stride`: stride for each dimension +- `layout`: `TensorLayout.ND` / `DN` / `NZ` +- `valid_shape`: optional valid-region dimensions (empty means use full shape) +- `pad`: `PadValue.null` (default) / `zero` / `max` / `min` — padding mode used + when loads/slices read outside the `valid_shape`. Peer of `TileView.pad`; + `tensor.slice(..., pad_value=PadValue.zero)` writes this field. + ### TileType Specialized tensor with optional memory and view information for hardware-optimized operations. diff --git a/docs/zh-cn/dev/ir/02-types.md b/docs/zh-cn/dev/ir/02-types.md index 72e4dd4d4..2e57432d9 100644 --- a/docs/zh-cn/dev/ir/02-types.md +++ b/docs/zh-cn/dev/ir/02-types.md @@ -50,6 +50,11 @@ tensor_with_view = ir.TensorType([128, 256], DataType.FP32, memref=None, tensor_ # With valid_shape tensor_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128]) +# With pad mode for out-of-valid-shape accesses (symmetric with TileView) +tensor_view = ir.TensorView( + stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128], pad=ir.PadValue.zero +) + # Different layouts nd_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND) # ND layout dn_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.DN) # DN layout @@ -70,6 +75,15 @@ tensor_with_both = ir.TensorType([128, 256], DataType.FP16, memref=memref, tenso - `DN`:DN 布局 - `NZ`:NZ 布局 +**TensorView 字段:** + +- `stride`:每个维度的步长 +- `layout`:`TensorLayout.ND` / `DN` / `NZ` +- `valid_shape`:可选的有效区域维度(为空表示使用完整 shape) +- `pad`:`PadValue.null`(默认)/ `zero` / `max` / `min`,用于访问超出 + `valid_shape` 部分时的填充模式。与 `TileView.pad` 对称; + `tensor.slice(..., pad_value=PadValue.zero)` 会写入该字段。 + ### TileType 专用张量类型,带可选内存和视图信息,用于硬件优化操作。 diff --git a/include/pypto/ir/transforms/utils/memref_utils.h b/include/pypto/ir/transforms/utils/memref_utils.h index 8fdadaf04..2288e3a58 100644 --- a/include/pypto/ir/transforms/utils/memref_utils.h +++ b/include/pypto/ir/transforms/utils/memref_utils.h @@ -92,7 +92,7 @@ inline std::optional RemapTensorViewExprs(const std::optionallayout, std::move(new_valid_shape)); + return TensorView(std::move(new_stride), tensor_view->layout, std::move(new_valid_shape), tensor_view->pad); } template diff --git a/include/pypto/ir/type.h b/include/pypto/ir/type.h index 725f1c4b9..bec08729c 100644 --- a/include/pypto/ir/type.h +++ b/include/pypto/ir/type.h @@ -153,18 +153,35 @@ std::string TensorLayoutToString(TensorLayout layout); */ TensorLayout StringToTensorLayout(const std::string& str); +/** + * @brief Pad mode enumeration (shared by TileView and TensorView) + * + * Defines the padding mode applied when a tile/tensor view access falls + * outside `valid_shape` but still within the physical shape: + * - null: No padding + * - zero: Pad with zero + * - max: Pad with maximum value of the element type + * - min: Pad with minimum value of the element type + */ +enum class PadValue { + null, ///< No padding + zero, ///< Zero padding + max, ///< Max value padding + min ///< Min value padding +}; + /** * @brief Tensor view representation * - * Represents the view information for a tensor, including stride and layout. - * The shape is stored in TensorType itself, so TensorView only needs - * stride and layout information. + * Represents the view information for a tensor, including stride, layout, + * valid_shape, and pad mode. The shape is stored in TensorType itself. */ struct TensorView { std::vector stride; ///< Stride for each dimension TensorLayout layout; ///< Tensor layout type std::vector - valid_shape; ///< Valid shape for each dimension (optional, empty means use full shape) + valid_shape; ///< Valid shape for each dimension (optional, empty means use full shape) + PadValue pad = PadValue::null; ///< Pad mode for accesses outside valid_shape but within shape /** * @brief Default constructor with ND layout and empty stride/valid_shape @@ -177,9 +194,11 @@ struct TensorView { * @param stride Stride for each dimension * @param layout Tensor layout type * @param valid_shape Valid shape for each dimension (optional, defaults to empty) + * @param pad Pad mode (optional, defaults to PadValue::null) */ - TensorView(std::vector stride, TensorLayout layout, std::vector valid_shape = {}) - : stride(std::move(stride)), layout(layout), valid_shape(std::move(valid_shape)) {} + TensorView(std::vector stride, TensorLayout layout, std::vector valid_shape = {}, + PadValue pad = PadValue::null) + : stride(std::move(stride)), layout(layout), valid_shape(std::move(valid_shape)), pad(pad) {} /** * @brief Constructor with integer stride and valid_shape (auto-converted to ConstInt) @@ -187,9 +206,10 @@ struct TensorView { * @param stride Stride for each dimension (int64, converted to ConstInt with INDEX dtype) * @param layout Tensor layout type * @param valid_shape Valid shape for each dimension (int64, defaults to empty) + * @param pad Pad mode (optional, defaults to PadValue::null) */ TensorView(const std::vector& stride, TensorLayout layout, - const std::vector& valid_shape = {}); + const std::vector& valid_shape = {}, PadValue pad = PadValue::null); /** * @brief Get field descriptors for reflection-based visitation @@ -199,7 +219,8 @@ struct TensorView { static constexpr auto GetFieldDescriptors() { return std::make_tuple(reflection::UsualField(&TensorView::stride, "stride"), reflection::UsualField(&TensorView::layout, "layout"), - reflection::UsualField(&TensorView::valid_shape, "valid_shape")); + reflection::UsualField(&TensorView::valid_shape, "valid_shape"), + reflection::UsualField(&TensorView::pad, "pad")); } }; @@ -227,22 +248,6 @@ std::string TileLayoutToString(TileLayout layout); */ TileLayout StringToTileLayout(const std::string& str); -/** - * @brief Tile pad enumeration - * - * Defines the padding mode for out-of-bound tile accesses: - * - null: No padding - * - zero: Pad with zero - * - max: Pad with maximum value of the element type - * - min: Pad with minimum value of the element type - */ -enum class PadValue { - null, ///< No padding - zero, ///< Zero padding - max, ///< Max value padding - min ///< Min value padding -}; - /** * @brief Tile view representation * diff --git a/python/bindings/modules/ir.cpp b/python/bindings/modules/ir.cpp index 5f877d5fe..ae3ee96f8 100644 --- a/python/bindings/modules/ir.cpp +++ b/python/bindings/modules/ir.cpp @@ -291,18 +291,30 @@ void BindIR(nb::module_& m) { .value("NZ", TensorLayout::NZ, "NZ layout") .export_values(); + // PadValue enum - must be before both TensorView and TileView since both carry it + nb::enum_(ir, "PadValue", "Pad mode enumeration for tile/tensor views") + .value("null", PadValue::null, "No padding") + .value("zero", PadValue::zero, "Zero padding") + .value("max", PadValue::max, "Max value padding") + .value("min", PadValue::min, "Min value padding") + .export_values(); + // TensorView - struct for tensor view information - must be before TensorType - nb::class_(ir, "TensorView", "Tensor view representation with stride, layout and valid shape") + nb::class_(ir, "TensorView", + "Tensor view representation with stride, layout, valid shape, and pad mode") .def(nb::init<>(), "Create an empty tensor view") - .def(nb::init&, TensorLayout, const std::vector&>(), + .def(nb::init&, TensorLayout, const std::vector&, PadValue>(), nb::arg("stride"), nb::arg("layout"), nb::arg("valid_shape") = std::vector{}, - "Create a tensor view with stride, layout and optional valid shape") - .def(nb::init&, TensorLayout, const std::vector&>(), + nb::arg("pad") = PadValue::null, + "Create a tensor view with stride, layout, optional valid shape, and optional pad") + .def(nb::init&, TensorLayout, const std::vector&, PadValue>(), nb::arg("stride"), nb::arg("layout"), nb::arg("valid_shape") = std::vector{}, - "Create a tensor view with integer stride, layout and optional integer valid shape") + nb::arg("pad") = PadValue::null, + "Create a tensor view with integer stride, layout, optional integer valid shape, and optional pad") .def_rw("stride", &TensorView::stride, "Stride for each dimension") .def_rw("layout", &TensorView::layout, "Tensor layout type") - .def_rw("valid_shape", &TensorView::valid_shape, "Valid shape for each dimension"); + .def_rw("valid_shape", &TensorView::valid_shape, "Valid shape for each dimension") + .def_rw("pad", &TensorView::pad, "Pad mode for out-of-valid-shape accesses"); // TensorType - const shared_ptr auto tensor_type_class = nb::class_(ir, "TensorType", "Tensor type representation"); @@ -392,14 +404,6 @@ void BindIR(nb::module_& m) { .value("col_major", TileLayout::col_major, "Column-major layout") .export_values(); - // PadValue enum - must be before TileView - nb::enum_(ir, "PadValue", "Tile pad mode enumeration") - .value("null", PadValue::null, "No padding") - .value("zero", PadValue::zero, "Zero padding") - .value("max", PadValue::max, "Max value padding") - .value("min", PadValue::min, "Min value padding") - .export_values(); - // TileView - struct for tile view information nb::class_( ir, "TileView", diff --git a/python/pypto/ir/op/tensor_ops.py b/python/pypto/ir/op/tensor_ops.py index 56cfd541a..b18c7be77 100644 --- a/python/pypto/ir/op/tensor_ops.py +++ b/python/pypto/ir/op/tensor_ops.py @@ -162,6 +162,7 @@ def slice( shape: list[int | Expr] | _ir_core.MakeTuple, offset: list[int | Expr] | _ir_core.MakeTuple, valid_shape: list[int | Expr] | _ir_core.MakeTuple | None = None, + pad_value: PadValue | int | float | None = None, span: Span | None = None, ) -> Call: """Create a slice of a tensor with new shape and offset. @@ -171,6 +172,13 @@ def slice( shape: New shape dimensions, or a MakeTuple offset: Offset dimensions for the slice, or a MakeTuple valid_shape: Valid shape dimensions (optional, defaults to empty) + pad_value: Optional padding mode for out-of-valid-shape elements. + Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or + the literal sugars ``0``, ``math.inf``, ``-math.inf`` (normalized + via :func:`normalize_pad_value`). ``PadValue.null`` is passed + through unchanged and means "no padding". When omitted (``None``), + the kwarg is not forwarded — the deducer defaults to + ``PadValue.null``. span: Optional source span for debugging (auto-captured if not provided) Returns: @@ -184,7 +192,16 @@ def slice( args = [tensor, shape_tuple, offset_tuple] if valid_shape is not None: args.append(_to_make_tuple(valid_shape, actual_span)) - return _ir_core.create_op_call("tensor.slice", args, {}, actual_span) + + kwargs: dict[str, Any] = {} + if pad_value is not None: + # PadValue.null is a legal "no padding" signal for slice (unlike + # fillpad, which requires a real padding mode). Pass it through; + # normalize the rest via the shared helper so numeric sugar and + # validation match tensor.fillpad exactly. + kwargs["pad_value"] = pad_value if pad_value is PadValue.null else normalize_pad_value(pad_value) + + return _ir_core.create_op_call("tensor.slice", args, kwargs, actual_span) def fillpad( diff --git a/python/pypto/ir/type.py b/python/pypto/ir/type.py index 4104e09af..430e7794b 100644 --- a/python/pypto/ir/type.py +++ b/python/pypto/ir/type.py @@ -140,12 +140,13 @@ def __call__( stride: Sequence[Expr | int] | None = None, layout: TensorLayout | None = None, valid_shape: Sequence[Expr | int] | None = None, + pad: PadValue = PadValue.null, ) -> "_TensorViewBase": - if stride is None and layout is None and valid_shape is None: + if stride is None and layout is None and valid_shape is None and pad == PadValue.null: return _TensorViewBase() if layout is None: - raise ValueError("layout is required when stride or valid_shape is provided") - return _TensorViewBase(_normalize_seq(stride), layout, _normalize_seq(valid_shape)) + raise ValueError("layout is required when stride, valid_shape, or pad is provided") + return _TensorViewBase(_normalize_seq(stride), layout, _normalize_seq(valid_shape), pad) class TensorView(metaclass=_TensorViewMeta): diff --git a/python/pypto/language/op/tensor_ops.py b/python/pypto/language/op/tensor_ops.py index 0884365a3..a12a59e0a 100644 --- a/python/pypto/language/op/tensor_ops.py +++ b/python/pypto/language/op/tensor_ops.py @@ -13,6 +13,7 @@ that accept and return Tensor types instead of raw Expr/Call objects. """ +import warnings from collections.abc import Sequence from typing import overload @@ -168,6 +169,7 @@ def slice( shape: Sequence[IntLike], offset: Sequence[IntLike], valid_shape: Sequence[IntLike] | None = None, + pad_value: PadValue | int | float | None = None, ) -> Tensor: """Create a slice of a tensor with new shape and optional valid shape. @@ -176,10 +178,26 @@ def slice( shape: New shape dimensions offset: Offset dimensions for the slice valid_shape: Valid shape dimensions. When omitted, the full shape is valid. + pad_value: Optional padding mode for out-of-valid-shape elements. + ``None`` or ``PadValue.null`` means no padding (the default). + Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or + the literal sugars ``0``, ``math.inf``, ``-math.inf`` (same + spelling as :func:`tensor.fillpad`). Only meaningful when + ``valid_shape`` is smaller than ``shape``. Returns: Tensor wrapping the slice operation """ + if pad_value is not None and pad_value is not PadValue.null and valid_shape is None: + warnings.warn( + f"tensor.slice received pad_value={pad_value!r} but no valid_shape. " + f"pad_value has no effect unless valid_shape is smaller than shape. " + f"If you intend to narrow the valid region later via " + f"tensor.set_validshape, you can ignore this warning; otherwise " + f"pass valid_shape=... to tensor.slice.", + stacklevel=2, + ) + tensor_expr = tensor.unwrap() normalized_valid_shape = None if valid_shape is None else _normalize_intlike(valid_shape) call_expr = _ir_ops.slice( @@ -187,6 +205,7 @@ def slice( _normalize_intlike(shape), _normalize_intlike(offset), normalized_valid_shape, + pad_value=pad_value, ) return Tensor(expr=call_expr) diff --git a/python/pypto/pypto_core/ir.pyi b/python/pypto/pypto_core/ir.pyi index ccd60e471..cc0839075 100644 --- a/python/pypto/pypto_core/ir.pyi +++ b/python/pypto/pypto_core/ir.pyi @@ -366,7 +366,7 @@ class PadValue(enum.Enum): """Min value padding.""" class TensorView: - """Tensor view representation with stride, layout and valid shape.""" + """Tensor view representation with stride, layout, valid shape, and pad mode.""" stride: Sequence[Expr] """Stride for each dimension.""" @@ -377,6 +377,9 @@ class TensorView: valid_shape: Sequence[Expr] """Valid shape for each dimension (empty means use full shape).""" + pad: PadValue + """Pad mode for out-of-valid-shape accesses (default PadValue.null).""" + @overload def __init__(self) -> None: """Create an empty tensor view with default ND layout.""" @@ -387,13 +390,15 @@ class TensorView: stride: Sequence[Expr | int | Scalar], layout: TensorLayout, valid_shape: Sequence[Expr | int | Scalar] = ..., + pad: PadValue = ..., ) -> None: - """Create a tensor view with stride, layout and optional valid shape. + """Create a tensor view with stride, layout, optional valid shape, and optional pad. Args: stride: Stride for each dimension (Expr, int, or Scalar/DynVar) layout: Tensor layout type (ND, DN, or NZ) valid_shape: Valid shape for each dimension (optional, defaults to empty) + pad: Pad mode for out-of-valid-shape accesses (defaults to PadValue.null) """ class TensorType(ShapedType): diff --git a/src/ir/op/tensor_ops/memory.cpp b/src/ir/op/tensor_ops/memory.cpp index 39e246a6b..d86c3aa88 100644 --- a/src/ir/op/tensor_ops/memory.cpp +++ b/src/ir/op/tensor_ops/memory.cpp @@ -199,12 +199,30 @@ TypePtr DeduceTensorSliceType(const std::vector& args, } } - // View preserves dtype but has new shape (which can have different rank than input) - // If valid_shape is provided as 4th argument, store it in TensorView + // Read optional pad_value kwarg (default PadValue::null = no padding). + PadValue pad_value = PadValue::null; + for (const auto& [k, v] : kwargs) { + if (k != "pad_value") continue; + CHECK(v.type() == typeid(PadValue)) + << "tensor.slice pad_value must be a PadValue enum, got " << v.type().name(); + pad_value = std::any_cast(v); + CHECK(pad_value == PadValue::null || pad_value == PadValue::zero || pad_value == PadValue::max || + pad_value == PadValue::min) + << "tensor.slice pad_value has invalid enum value: " << static_cast(pad_value); + break; + } + + // View preserves dtype but has new shape (which can have different rank than input). + // If valid_shape is provided as 4th argument or pad_value is set, build a TensorView. if (args.size() == 4) { auto valid_shape_tuple = As(args[3]); CHECK(valid_shape_tuple) << "tensor.slice valid_shape (4th argument) must be a MakeTuple"; - TensorView tensor_view({}, TensorLayout::ND, valid_shape_tuple->elements_); + TensorView tensor_view({}, TensorLayout::ND, valid_shape_tuple->elements_, pad_value); + return std::make_shared(new_shape, tensor_type->dtype_, std::nullopt, + std::make_optional(std::move(tensor_view))); + } + if (pad_value != PadValue::null) { + TensorView tensor_view(std::vector{}, TensorLayout::ND, std::vector{}, pad_value); return std::make_shared(new_shape, tensor_type->dtype_, std::nullopt, std::make_optional(std::move(tensor_view))); } @@ -356,6 +374,7 @@ REGISTER_OP("tensor.slice") .add_argument("shape", "New shape dimensions (TupleType of ScalarType(INT64))") .add_argument("offset", "Offset dimensions (TupleType of ScalarType(INT64))") .set_output_memory_inherit_input() + .set_attr("pad_value") .f_deduce_type([](const std::vector& args, const std::vector>& kwargs) { return DeduceTensorSliceType(args, kwargs); diff --git a/src/ir/serialization/deserializer.cpp b/src/ir/serialization/deserializer.cpp index 289bd3a39..5c8617ac4 100644 --- a/src/ir/serialization/deserializer.cpp +++ b/src/ir/serialization/deserializer.cpp @@ -292,7 +292,22 @@ class IRDeserializer::Impl : public detail::DeserializerContext { } else { CHECK(false) << "Unknown TensorLayout: " << layout_str; } + } else if (key == "pad") { + std::string pad_str; + p->val.convert(pad_str); + if (pad_str == "null") { + tensor_view.pad = PadValue::null; + } else if (pad_str == "zero") { + tensor_view.pad = PadValue::zero; + } else if (pad_str == "max") { + tensor_view.pad = PadValue::max; + } else if (pad_str == "min") { + tensor_view.pad = PadValue::min; + } else { + CHECK(false) << "Unknown PadValue: " << pad_str; + } } + // Older serialized IR may omit "pad"; default stays PadValue::null. } return tensor_view; diff --git a/src/ir/serialization/serializer.cpp b/src/ir/serialization/serializer.cpp index 807faed14..993f97980 100644 --- a/src/ir/serialization/serializer.cpp +++ b/src/ir/serialization/serializer.cpp @@ -357,6 +357,24 @@ class IRSerializer::Impl { // Serialize layout enum tv_map["layout"] = msgpack::object(TensorLayoutToString(tensor_view->layout), zone); + // Serialize pad enum (same string encoding as TileView::pad) + std::string pad_str; + switch (tensor_view->pad) { + case PadValue::null: + pad_str = "null"; + break; + case PadValue::zero: + pad_str = "zero"; + break; + case PadValue::max: + pad_str = "max"; + break; + case PadValue::min: + pad_str = "min"; + break; + } + tv_map["pad"] = msgpack::object(pad_str, zone); + return msgpack::object(tv_map, zone); } diff --git a/src/ir/transforms/convert_to_ssa_pass.cpp b/src/ir/transforms/convert_to_ssa_pass.cpp index 2bd1f7ff1..c51836e4c 100644 --- a/src/ir/transforms/convert_to_ssa_pass.cpp +++ b/src/ir/transforms/convert_to_ssa_pass.cpp @@ -298,7 +298,7 @@ class SSAConverter { auto [st, st_changed] = SubstExprVec(tv.stride); if (vs_changed || st_changed) { changed = true; - new_tv = TensorView(std::move(st), tv.layout, std::move(vs)); + new_tv = TensorView(std::move(st), tv.layout, std::move(vs), tv.pad); } } if (changed) { diff --git a/src/ir/transforms/op_conversion_registry.cpp b/src/ir/transforms/op_conversion_registry.cpp index c27933b08..6a2e2cba5 100644 --- a/src/ir/transforms/op_conversion_registry.cpp +++ b/src/ir/transforms/op_conversion_registry.cpp @@ -214,10 +214,23 @@ void OpConversionRegistry::RegisterMemoryOps() { const auto& shape = args[1]; const auto& offset = args[2]; + // Extract pad_value kwarg (if any) to forward to the emitted tile.slice. + std::vector> forward_kwargs; + for (const auto& kv : kwargs) { + if (kv.first == "pad_value") { + forward_kwargs.push_back(kv); + break; + } + } + auto tensor_type = As(input->GetType()); auto tile_type = As(input->GetType()); if (tensor_type) { + // The tile.load path does not currently accept pad_value. If the user set + // pad_value on a tensor.slice over a TensorType input, the pad intent is + // lost here — a follow-up tile.fillpad is the workaround until tile.load + // grows its own pad_value kwarg. auto valid_shapes = (args.size() == 4) ? args[3] : shape; std::vector> load_kwargs = {{"target_memory", MemorySpace::Vec}, {"transpose", false}}; @@ -231,7 +244,7 @@ void OpConversionRegistry::RegisterMemoryOps() { if (args.size() == 4) { slice_args.push_back(args[3]); } - auto slice_call = op_reg.Create("tile.slice", slice_args, span); + auto slice_call = op_reg.Create("tile.slice", slice_args, forward_kwargs, span); return ConversionResult{slice_call}; } diff --git a/src/ir/transforms/python_printer.cpp b/src/ir/transforms/python_printer.cpp index ab9fa590c..e8460728f 100644 --- a/src/ir/transforms/python_printer.cpp +++ b/src/ir/transforms/python_printer.cpp @@ -1931,12 +1931,13 @@ std::string IRPythonPrinter::PrintTensorView(const TensorView& tensor_view, bool has_stride = !tensor_view.stride.empty(); bool has_non_default_layout = (tensor_view.layout != TensorLayout::ND); + bool has_non_default_pad = (tensor_view.pad != PadValue::null); - // If valid_shape matched and stride/layout are at defaults, skip TensorView entirely - if (first && !has_stride && !has_non_default_layout) return ""; + // If all fields are at defaults, skip TensorView entirely + if (first && !has_stride && !has_non_default_layout && !has_non_default_pad) return ""; // When TensorView is non-trivial, always emit both stride and layout to satisfy - // the C++ constructor signature TensorView(stride, layout, valid_shape=[]). + // the C++ constructor signature TensorView(stride, layout, valid_shape=[], pad=null). // Omitting either required arg causes TypeError when Python eagerly evaluates // function parameter annotations during exec() in the text parser. maybe_comma(); @@ -1950,6 +1951,26 @@ std::string IRPythonPrinter::PrintTensorView(const TensorView& tensor_view, maybe_comma(); oss << "layout=" << prefix_ << ".TensorLayout." << TensorLayoutToString(tensor_view.layout); + // pad — omit if null (default) + if (has_non_default_pad) { + maybe_comma(); + oss << "pad=" << prefix_ << ".PadValue."; + switch (tensor_view.pad) { + case PadValue::null: + oss << "null"; + break; + case PadValue::zero: + oss << "zero"; + break; + case PadValue::max: + oss << "max"; + break; + case PadValue::min: + oss << "min"; + break; + } + } + oss << ")"; return oss.str(); } diff --git a/src/ir/transforms/simplify_pass.cpp b/src/ir/transforms/simplify_pass.cpp index 539811159..1fc44afdd 100644 --- a/src/ir/transforms/simplify_pass.cpp +++ b/src/ir/transforms/simplify_pass.cpp @@ -319,7 +319,7 @@ class SimplifyMutator : public arith::IRMutatorWithAnalyzer { auto new_vs = SimplifyExprVec(tv.valid_shape, &view_changed); if (view_changed) { changed = true; - new_tv = TensorView(std::move(new_stride), tv.layout, std::move(new_vs)); + new_tv = TensorView(std::move(new_stride), tv.layout, std::move(new_vs), tv.pad); } } if (!changed) return type; diff --git a/src/ir/transforms/structural_equal.cpp b/src/ir/transforms/structural_equal.cpp index 5bce9b2fc..66a591e6e 100644 --- a/src/ir/transforms/structural_equal.cpp +++ b/src/ir/transforms/structural_equal.cpp @@ -1035,6 +1035,13 @@ bool StructuralEqualImpl::EqualType(const TypePtr& lhs, const TypePt } return false; } + // Compare pad + if (lhs_tv.pad != rhs_tv.pad) { + if constexpr (AssertMode) { + ThrowMismatch("TensorView pad mismatch", IRNodePtr(), IRNodePtr(), "", ""); + } + return false; + } } return true; } else if (auto lhs_tile = As(lhs)) { diff --git a/src/ir/type.cpp b/src/ir/type.cpp index 587993ee5..fe1d6176a 100644 --- a/src/ir/type.cpp +++ b/src/ir/type.cpp @@ -115,8 +115,8 @@ ShapedType::ShapedType(DataType dtype, const std::vector& shape, std::o } TensorView::TensorView(const std::vector& stride_ints, TensorLayout layout_, - const std::vector& valid_shape_ints) - : layout(layout_) { + const std::vector& valid_shape_ints, PadValue pad_) + : layout(layout_), pad(pad_) { for (int64_t s : stride_ints) { stride.push_back(std::make_shared(s, DataType::INDEX, Span::unknown())); } diff --git a/tests/ut/ir/operators/test_op_registry.py b/tests/ut/ir/operators/test_op_registry.py index b09a4d784..9330331a1 100644 --- a/tests/ut/ir/operators/test_op_registry.py +++ b/tests/ut/ir/operators/test_op_registry.py @@ -491,6 +491,12 @@ def test_tile_slice_pad_value_kwarg_schema(): assert tile_slice_op.has_attr("pad_value") +def test_tensor_slice_pad_value_kwarg_schema(): + """Test that tensor.slice declares pad_value in its kwarg schema.""" + tensor_slice_op = ir.get_op("tensor.slice") + assert tensor_slice_op.has_attr("pad_value") + + class TestOpMemorySpecRegistry: """Test that op memory specs are correctly registered and queryable.""" diff --git a/tests/ut/ir/operators/test_tensor_ops.py b/tests/ut/ir/operators/test_tensor_ops.py index dc5b9fa4f..9db52d020 100644 --- a/tests/ut/ir/operators/test_tensor_ops.py +++ b/tests/ut/ir/operators/test_tensor_ops.py @@ -17,6 +17,9 @@ - Python helper functions """ +import math + +import pypto.language as pl import pytest from pypto import DataType, ir from pypto.ir.op import tensor @@ -987,6 +990,105 @@ def test_tensor_slice_with_valid_shape(): assert len(result_type.tensor_view.valid_shape) == 2 +def _make_slice_tensor_var(): + """Build a [16, 32] FP16 tensor Var for slice pad_value tests.""" + span = ir.Span.unknown() + dim16 = ir.ConstInt(16, DataType.INT32, span) + dim32 = ir.ConstInt(32, DataType.INT32, span) + tensor_type = ir.TensorType([dim16, dim32], DataType.FP16) + return ir.Var("t", tensor_type, span) + + +def test_tensor_slice_with_pad_value(): + """tensor.slice writes pad_value=zero to the output tensor_view.pad.""" + tensor_var = _make_slice_tensor_var() + call = tensor.slice(tensor_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=ir.PadValue.zero) + + assert isinstance(call, ir.Call) + assert call.op.name == "tensor.slice" + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.tensor_view is not None + assert result_type.tensor_view.pad == ir.PadValue.zero + assert len(result_type.tensor_view.valid_shape) == 2 + + # Sanity-check min/max variants reach the same field. + for pad in (ir.PadValue.min, ir.PadValue.max): + call_p = tensor.slice(tensor_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=pad) + result_type_p = call_p.type + assert isinstance(result_type_p, ir.TensorType) + assert result_type_p.tensor_view is not None + assert result_type_p.tensor_view.pad == pad + + +def test_tensor_slice_default_pad_is_null(): + """tensor.slice without pad_value defaults to PadValue.null (backward compat).""" + tensor_var = _make_slice_tensor_var() + + # No tensor_view created when both valid_shape and pad_value are absent. + call = tensor.slice(tensor_var, [8, 16], [0, 0]) + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.tensor_view is None + + # With only valid_shape provided, tensor_view is present and pad defaults to null. + call_vs = tensor.slice(tensor_var, [8, 16], [0, 0], valid_shape=[8, 4]) + result_type_vs = call_vs.type + assert isinstance(result_type_vs, ir.TensorType) + assert result_type_vs.tensor_view is not None + assert result_type_vs.tensor_view.pad == ir.PadValue.null + + +def test_tensor_slice_rejects_bad_pad_value(): + """tensor.slice rejects a non-PadValue pad_value kwarg via registry validation.""" + tensor_var = _make_slice_tensor_var() + span = tensor_var.span + shape_tuple = ir.MakeTuple( + [ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(16, DataType.INT32, span)], span + ) + offset_tuple = ir.MakeTuple( + [ir.ConstInt(0, DataType.INT32, span), ir.ConstInt(0, DataType.INT32, span)], span + ) + valid_shape_tuple = ir.MakeTuple( + [ir.ConstInt(8, DataType.INT32, span), ir.ConstInt(4, DataType.INT32, span)], span + ) + with pytest.raises(TypeError, match="'pad_value'.*incompatible type"): + ir.create_op_call( + "tensor.slice", + [tensor_var, shape_tuple, offset_tuple, valid_shape_tuple], + {"pad_value": 5}, + span, + ) + + +def test_tensor_slice_accepts_numeric_sugar_pad_value(): + """tensor.slice maps 0 / math.inf / -math.inf onto PadValue zero/max/min.""" + tensor_var = _make_slice_tensor_var() + for literal, expected_pad in [ + (0, ir.PadValue.zero), + (math.inf, ir.PadValue.max), + (-math.inf, ir.PadValue.min), + ]: + call = tensor.slice(tensor_var, [8, 16], [0, 0], valid_shape=[8, 4], pad_value=literal) + result_type = call.type + assert isinstance(result_type, ir.TensorType) + assert result_type.tensor_view is not None + assert result_type.tensor_view.pad == expected_pad + + +def test_tensor_slice_pad_without_valid_shape_warns(): + """DSL emits a UserWarning when pad_value is set but valid_shape is None.""" + span = ir.Span.unknown() + dim16 = ir.ConstInt(16, DataType.INT32, span) + dim32 = ir.ConstInt(32, DataType.INT32, span) + tensor_type = ir.TensorType([dim16, dim32], DataType.FP16) + tensor_var = ir.Var("t", tensor_type, span) + + tensor_arg = pl.Tensor(expr=tensor_var) + with pytest.warns(UserWarning, match="pad_value has no effect"): + pl.tensor.slice(tensor_arg, [8, 16], [0, 0], pad_value=pl.PadValue.zero) + + def test_tensor_fillpad_clears_valid_shape(): """Test tensor.fillpad materializes a full-valid tensor view.""" span = ir.Span.unknown() diff --git a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py index 0008a996f..180924a84 100644 --- a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py +++ b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py @@ -1154,6 +1154,32 @@ def expected_body(ib, tiles, extras=()): ) _assert_convert_equal(before, expected) + def test_local_tensor_slice_with_pad_value_forwards_to_tile_slice(self): + """tensor.slice(..., pad_value=X) on a local tensor lowers to tile.slice(..., pad_value=X).""" + in_specs: list[InSpec] = [("x", [8, 32], DataType.FP32)] + + def before_body(ib, ins): + t = ib.let("t", tensor_ops.create([16, 64], DataType.FP32)) + s = ib.let( + "s", + tensor_ops.slice(t, [8, 32], [0, 0], valid_shape=[8, 8], pad_value=PadValue.min), + ) + return ib.let("y", tensor_ops.add(s, ins[0])) + + def expected_body(ib, tiles): + t_tile = ib.let("t_tile", tile_ops.create([16, 64], DataType.FP32)) + s_tile = ib.let( + "s_tile", + tile_ops.slice(t_tile, [8, 32], [0, 0], valid_shape=[8, 8], pad_value=PadValue.min), + ) + return ib.let("y_tile", tile_ops.add(s_tile, tiles[0])) + + before = _make_before(in_specs=in_specs, out_shape=[8, 32], out_dtype=DataType.FP32, body=before_body) + expected = _make_expected( + in_specs=in_specs, out_shape=[8, 32], out_dtype=DataType.FP32, body=expected_body + ) + _assert_convert_equal(before, expected) + def test_tensor_fillpad_converts_to_tile_fillpad(self): """tensor.fillpad should lower to tile.fillpad after loading the tensor.""" before, expected = _make_pair( diff --git a/tests/ut/ir/transforms/test_serialization.py b/tests/ut/ir/transforms/test_serialization.py index 971ec643a..122cb06ab 100644 --- a/tests/ut/ir/transforms/test_serialization.py +++ b/tests/ut/ir/transforms/test_serialization.py @@ -766,6 +766,32 @@ def test_tiletype_without_memory_space_survives_round_trip(self): # Verify structural equality ir.assert_structural_equal(var, restored_var, enable_auto_mapping=True) + def test_tensortype_tensorview_pad_survives_round_trip(self): + """TensorView::pad is preserved through serialize/deserialize.""" + span = ir.Span.unknown() + shape = [ + ir.ConstInt(16, DataType.INT64, span), + ir.ConstInt(16, DataType.INT64, span), + ] + tensor_view = ir.TensorView( + stride=[], + layout=ir.TensorLayout.ND, + pad=ir.PadValue.zero, + ) + tensor_type = ir.TensorType(shape, DataType.FP32, None, tensor_view) + var = ir.Var("tensor_var", tensor_type, span) + + serialized = ir.serialize(var) + restored = ir.deserialize(serialized) + restored_var = cast(ir.Var, restored) + + restored_tensor_type = restored_var.type + assert isinstance(restored_tensor_type, ir.TensorType) + assert restored_tensor_type.tensor_view is not None + assert restored_tensor_type.tensor_view.pad == ir.PadValue.zero + + ir.assert_structural_equal(var, restored_var, enable_auto_mapping=True) + def test_tiletype_with_memref_and_memory_space(self): """TileType with both memref and memory_space preserves both.""" # Create MemRef