Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions docs/en/dev/ir/02-types.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ tensor_with_view = ir.TensorType([128, 256], DataType.FP32, memref=None, tensor_
# With valid_shape
tensor_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128])

# With pad mode for out-of-valid-shape accesses (symmetric with TileView)
tensor_view = ir.TensorView(
stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128], pad=ir.PadValue.zero
)

# Different layouts
nd_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND) # ND layout
dn_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.DN) # DN layout
Expand All @@ -70,6 +75,15 @@ tensor_with_both = ir.TensorType([128, 256], DataType.FP16, memref=memref, tenso
- `DN`: DN layout
- `NZ`: NZ layout

**TensorView fields:**

- `stride`: stride for each dimension
- `layout`: `TensorLayout.ND` / `DN` / `NZ`
- `valid_shape`: optional valid-region dimensions (empty means use full shape)
- `pad`: `PadValue.null` (default) / `zero` / `max` / `min` — padding mode applied
when a load or slice reads outside `valid_shape` but still within the full shape.
Counterpart of `TileView.pad`; for example,
`tensor.slice(..., pad_value=PadValue.zero)` sets this field.

### TileType

Specialized tensor with optional memory and view information for hardware-optimized operations.
Expand Down
14 changes: 14 additions & 0 deletions docs/zh-cn/dev/ir/02-types.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ tensor_with_view = ir.TensorType([128, 256], DataType.FP32, memref=None, tensor_
# With valid_shape
tensor_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128])

# With pad mode for out-of-valid-shape accesses (symmetric with TileView)
tensor_view = ir.TensorView(
stride=[1, 128], layout=ir.TensorLayout.ND, valid_shape=[64, 128], pad=ir.PadValue.zero
)

# Different layouts
nd_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.ND) # ND layout
dn_view = ir.TensorView(stride=[1, 128], layout=ir.TensorLayout.DN) # DN layout
Expand All @@ -70,6 +75,15 @@ tensor_with_both = ir.TensorType([128, 256], DataType.FP16, memref=memref, tenso
- `DN`:DN 布局
- `NZ`:NZ 布局

**TensorView 字段:**

- `stride`:每个维度的步长
- `layout`:`TensorLayout.ND` / `DN` / `NZ`
- `valid_shape`:可选的有效区域维度(为空表示使用完整 shape)
- `pad`:`PadValue.null`(默认)/ `zero` / `max` / `min`,访问超出 `valid_shape`
但仍在完整 shape 范围内的区域时所使用的填充模式。与 `TileView.pad` 对应;例如
`tensor.slice(..., pad_value=PadValue.zero)` 会设置该字段。

### TileType

专用张量类型,带可选内存和视图信息,用于硬件优化操作。
Expand Down
2 changes: 1 addition & 1 deletion include/pypto/ir/transforms/utils/memref_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ inline std::optional<TensorView> RemapTensorViewExprs(const std::optional<Tensor
return tensor_view;
}
changed = true;
return TensorView(std::move(new_stride), tensor_view->layout, std::move(new_valid_shape));
return TensorView(std::move(new_stride), tensor_view->layout, std::move(new_valid_shape), tensor_view->pad);
}

template <typename RemapExprFn>
Expand Down
53 changes: 29 additions & 24 deletions include/pypto/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,18 +153,35 @@ std::string TensorLayoutToString(TensorLayout layout);
*/
TensorLayout StringToTensorLayout(const std::string& str);

/**
* @brief Pad mode enumeration (shared by TileView and TensorView)
*
* Defines the padding mode applied when a tile/tensor view access falls
* outside `valid_shape` but still within the physical shape:
* - null: No padding
* - zero: Pad with zero
* - max: Pad with maximum value of the element type
* - min: Pad with minimum value of the element type
*
* `null` is the value TensorView::pad takes when no pad mode is supplied.
* The enum is declared ahead of TensorView because TensorView stores a
* PadValue member.
*/
enum class PadValue {
null, ///< No padding
zero, ///< Zero padding
max, ///< Max value padding
min ///< Min value padding
};

/**
* @brief Tensor view representation
*
* Represents the view information for a tensor, including stride and layout.
* The shape is stored in TensorType itself, so TensorView only needs
* stride and layout information.
* Represents the view information for a tensor, including stride, layout,
* valid_shape, and pad mode. The shape is stored in TensorType itself.
*/
struct TensorView {
std::vector<ExprPtr> stride; ///< Stride for each dimension
TensorLayout layout; ///< Tensor layout type
std::vector<ExprPtr>
valid_shape; ///< Valid shape for each dimension (optional, empty means use full shape)
valid_shape; ///< Valid shape for each dimension (optional, empty means use full shape)
PadValue pad = PadValue::null; ///< Pad mode for accesses outside valid_shape but within shape

/**
* @brief Default constructor with ND layout and empty stride/valid_shape
Expand All @@ -177,19 +194,22 @@ struct TensorView {
* @param stride Stride for each dimension
* @param layout Tensor layout type
* @param valid_shape Valid shape for each dimension (optional, defaults to empty)
* @param pad Pad mode (optional, defaults to PadValue::null)
*/
TensorView(std::vector<ExprPtr> stride, TensorLayout layout, std::vector<ExprPtr> valid_shape = {})
: stride(std::move(stride)), layout(layout), valid_shape(std::move(valid_shape)) {}
TensorView(std::vector<ExprPtr> stride, TensorLayout layout, std::vector<ExprPtr> valid_shape = {},
PadValue pad = PadValue::null)
: stride(std::move(stride)), layout(layout), valid_shape(std::move(valid_shape)), pad(pad) {}

/**
* @brief Constructor with integer stride and valid_shape (auto-converted to ConstInt)
*
* @param stride Stride for each dimension (int64, converted to ConstInt with INDEX dtype)
* @param layout Tensor layout type
* @param valid_shape Valid shape for each dimension (int64, defaults to empty)
* @param pad Pad mode (optional, defaults to PadValue::null)
*/
TensorView(const std::vector<int64_t>& stride, TensorLayout layout,
const std::vector<int64_t>& valid_shape = {});
const std::vector<int64_t>& valid_shape = {}, PadValue pad = PadValue::null);

/**
* @brief Get field descriptors for reflection-based visitation
Expand All @@ -199,7 +219,8 @@ struct TensorView {
static constexpr auto GetFieldDescriptors() {
return std::make_tuple(reflection::UsualField(&TensorView::stride, "stride"),
reflection::UsualField(&TensorView::layout, "layout"),
reflection::UsualField(&TensorView::valid_shape, "valid_shape"));
reflection::UsualField(&TensorView::valid_shape, "valid_shape"),
reflection::UsualField(&TensorView::pad, "pad"));
}
};

Expand Down Expand Up @@ -227,22 +248,6 @@ std::string TileLayoutToString(TileLayout layout);
*/
TileLayout StringToTileLayout(const std::string& str);

/**
* @brief Tile pad enumeration
*
* Defines the padding mode for out-of-bound tile accesses:
* - null: No padding
* - zero: Pad with zero
* - max: Pad with maximum value of the element type
* - min: Pad with minimum value of the element type
*/
enum class PadValue {
null, ///< No padding
zero, ///< Zero padding
max, ///< Max value padding
min ///< Min value padding
};

/**
* @brief Tile view representation
*
Expand Down
32 changes: 18 additions & 14 deletions python/bindings/modules/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -291,18 +291,30 @@ void BindIR(nb::module_& m) {
.value("NZ", TensorLayout::NZ, "NZ layout")
.export_values();

// PadValue enum - must be before both TensorView and TileView since both carry it
nb::enum_<PadValue>(ir, "PadValue", "Pad mode enumeration for tile/tensor views")
.value("null", PadValue::null, "No padding")
.value("zero", PadValue::zero, "Zero padding")
.value("max", PadValue::max, "Max value padding")
.value("min", PadValue::min, "Min value padding")
.export_values();

// TensorView - struct for tensor view information - must be before TensorType
nb::class_<TensorView>(ir, "TensorView", "Tensor view representation with stride, layout and valid shape")
nb::class_<TensorView>(ir, "TensorView",
"Tensor view representation with stride, layout, valid shape, and pad mode")
.def(nb::init<>(), "Create an empty tensor view")
.def(nb::init<const std::vector<ExprPtr>&, TensorLayout, const std::vector<ExprPtr>&>(),
.def(nb::init<const std::vector<ExprPtr>&, TensorLayout, const std::vector<ExprPtr>&, PadValue>(),
nb::arg("stride"), nb::arg("layout"), nb::arg("valid_shape") = std::vector<ExprPtr>{},
"Create a tensor view with stride, layout and optional valid shape")
.def(nb::init<const std::vector<int64_t>&, TensorLayout, const std::vector<int64_t>&>(),
nb::arg("pad") = PadValue::null,
"Create a tensor view with stride, layout, optional valid shape, and optional pad")
.def(nb::init<const std::vector<int64_t>&, TensorLayout, const std::vector<int64_t>&, PadValue>(),
nb::arg("stride"), nb::arg("layout"), nb::arg("valid_shape") = std::vector<int64_t>{},
"Create a tensor view with integer stride, layout and optional integer valid shape")
nb::arg("pad") = PadValue::null,
"Create a tensor view with integer stride, layout, optional integer valid shape, and optional pad")
.def_rw("stride", &TensorView::stride, "Stride for each dimension")
.def_rw("layout", &TensorView::layout, "Tensor layout type")
.def_rw("valid_shape", &TensorView::valid_shape, "Valid shape for each dimension");
.def_rw("valid_shape", &TensorView::valid_shape, "Valid shape for each dimension")
.def_rw("pad", &TensorView::pad, "Pad mode for out-of-valid-shape accesses");

// TensorType - const shared_ptr
auto tensor_type_class = nb::class_<TensorType, ShapedType>(ir, "TensorType", "Tensor type representation");
Expand Down Expand Up @@ -392,14 +404,6 @@ void BindIR(nb::module_& m) {
.value("col_major", TileLayout::col_major, "Column-major layout")
.export_values();

// PadValue enum - must be before TileView
nb::enum_<PadValue>(ir, "PadValue", "Tile pad mode enumeration")
.value("null", PadValue::null, "No padding")
.value("zero", PadValue::zero, "Zero padding")
.value("max", PadValue::max, "Max value padding")
.value("min", PadValue::min, "Min value padding")
.export_values();

// TileView - struct for tile view information
nb::class_<TileView>(
ir, "TileView",
Expand Down
19 changes: 18 additions & 1 deletion python/pypto/ir/op/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,7 @@ def slice(
shape: list[int | Expr] | _ir_core.MakeTuple,
offset: list[int | Expr] | _ir_core.MakeTuple,
valid_shape: list[int | Expr] | _ir_core.MakeTuple | None = None,
pad_value: PadValue | int | float | None = None,
span: Span | None = None,
) -> Call:
"""Create a slice of a tensor with new shape and offset.
Expand All @@ -171,6 +172,13 @@ def slice(
shape: New shape dimensions, or a MakeTuple
offset: Offset dimensions for the slice, or a MakeTuple
valid_shape: Valid shape dimensions (optional, defaults to empty)
pad_value: Optional padding mode for out-of-valid-shape elements.
Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or
the literal sugars ``0``, ``math.inf``, ``-math.inf`` (normalized
via :func:`normalize_pad_value`). ``PadValue.null`` is passed
through unchanged and means "no padding". When omitted (``None``),
the kwarg is not forwarded — the deducer defaults to
``PadValue.null``.
span: Optional source span for debugging (auto-captured if not provided)

Returns:
Expand All @@ -184,7 +192,16 @@ def slice(
args = [tensor, shape_tuple, offset_tuple]
if valid_shape is not None:
args.append(_to_make_tuple(valid_shape, actual_span))
return _ir_core.create_op_call("tensor.slice", args, {}, actual_span)

kwargs: dict[str, Any] = {}
if pad_value is not None:
# PadValue.null is a legal "no padding" signal for slice (unlike
# fillpad, which requires a real padding mode). Pass it through;
# normalize the rest via the shared helper so numeric sugar and
# validation match tensor.fillpad exactly.
kwargs["pad_value"] = pad_value if pad_value is PadValue.null else normalize_pad_value(pad_value)

return _ir_core.create_op_call("tensor.slice", args, kwargs, actual_span)


def fillpad(
Expand Down
7 changes: 4 additions & 3 deletions python/pypto/ir/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,12 +140,13 @@ def __call__(
stride: Sequence[Expr | int] | None = None,
layout: TensorLayout | None = None,
valid_shape: Sequence[Expr | int] | None = None,
pad: PadValue = PadValue.null,
) -> "_TensorViewBase":
if stride is None and layout is None and valid_shape is None:
if stride is None and layout is None and valid_shape is None and pad == PadValue.null:
return _TensorViewBase()
if layout is None:
raise ValueError("layout is required when stride or valid_shape is provided")
return _TensorViewBase(_normalize_seq(stride), layout, _normalize_seq(valid_shape))
raise ValueError("layout is required when stride, valid_shape, or pad is provided")
return _TensorViewBase(_normalize_seq(stride), layout, _normalize_seq(valid_shape), pad)


class TensorView(metaclass=_TensorViewMeta):
Expand Down
19 changes: 19 additions & 0 deletions python/pypto/language/op/tensor_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
that accept and return Tensor types instead of raw Expr/Call objects.
"""

import warnings
Comment thread
Hzfengsy marked this conversation as resolved.
from collections.abc import Sequence
from typing import overload

Expand Down Expand Up @@ -168,6 +169,7 @@ def slice(
shape: Sequence[IntLike],
offset: Sequence[IntLike],
valid_shape: Sequence[IntLike] | None = None,
pad_value: PadValue | int | float | None = None,
) -> Tensor:
"""Create a slice of a tensor with new shape and optional valid shape.

Expand All @@ -176,17 +178,34 @@ def slice(
shape: New shape dimensions
offset: Offset dimensions for the slice
valid_shape: Valid shape dimensions. When omitted, the full shape is valid.
pad_value: Optional padding mode for out-of-valid-shape elements.
``None`` or ``PadValue.null`` means no padding (the default).
Accepts ``PadValue.zero`` / ``PadValue.max`` / ``PadValue.min``, or
the literal sugars ``0``, ``math.inf``, ``-math.inf`` (same
spelling as :func:`tensor.fillpad`). Only meaningful when
``valid_shape`` is smaller than ``shape``.

Returns:
Tensor wrapping the slice operation
"""
if pad_value is not None and pad_value is not PadValue.null and valid_shape is None:
warnings.warn(
f"tensor.slice received pad_value={pad_value!r} but no valid_shape. "
f"pad_value has no effect unless valid_shape is smaller than shape. "
f"If you intend to narrow the valid region later via "
f"tensor.set_validshape, you can ignore this warning; otherwise "
f"pass valid_shape=... to tensor.slice.",
stacklevel=2,
)

tensor_expr = tensor.unwrap()
normalized_valid_shape = None if valid_shape is None else _normalize_intlike(valid_shape)
call_expr = _ir_ops.slice(
tensor_expr,
_normalize_intlike(shape),
_normalize_intlike(offset),
normalized_valid_shape,
pad_value=pad_value,
)
return Tensor(expr=call_expr)

Expand Down
9 changes: 7 additions & 2 deletions python/pypto/pypto_core/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -366,7 +366,7 @@ class PadValue(enum.Enum):
"""Min value padding."""

class TensorView:
"""Tensor view representation with stride, layout and valid shape."""
"""Tensor view representation with stride, layout, valid shape, and pad mode."""

stride: Sequence[Expr]
"""Stride for each dimension."""
Expand All @@ -377,6 +377,9 @@ class TensorView:
valid_shape: Sequence[Expr]
"""Valid shape for each dimension (empty means use full shape)."""

pad: PadValue
"""Pad mode for out-of-valid-shape accesses (default PadValue.null)."""

@overload
def __init__(self) -> None:
"""Create an empty tensor view with default ND layout."""
Expand All @@ -387,13 +390,15 @@ class TensorView:
stride: Sequence[Expr | int | Scalar],
layout: TensorLayout,
valid_shape: Sequence[Expr | int | Scalar] = ...,
pad: PadValue = ...,
) -> None:
"""Create a tensor view with stride, layout and optional valid shape.
"""Create a tensor view with stride, layout, optional valid shape, and optional pad.
Args:
stride: Stride for each dimension (Expr, int, or Scalar/DynVar)
layout: Tensor layout type (ND, DN, or NZ)
valid_shape: Valid shape for each dimension (optional, defaults to empty)
pad: Pad mode for out-of-valid-shape accesses (defaults to PadValue.null)
"""

class TensorType(ShapedType):
Expand Down
Loading
Loading