Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 47 additions & 6 deletions include/pypto/ir/op_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,22 @@ struct OpMemorySpaceSpec {
std::vector<std::vector<MemorySpace>> input_constraints;

/// Resolves output memory space from the Call's kwargs.
/// Returns nullopt to signal "inherit from first tile-typed input" (view ops).
/// Returns nullopt when the space cannot be resolved from kwargs alone — either
/// because the op inherits from its input (see `output_inherits_input`) or
/// because a retargetable kwarg is absent and InferTileMemorySpace must decide.
using OutputResolver =
std::function<std::optional<MemorySpace>(const std::vector<std::pair<std::string, std::any>>& kwargs)>;
OutputResolver deduce_output_memory;

/// When set, the output reuses the MemRef of the input argument at this index.
/// Used by accumulate ops (matmul_acc, gemv_acc) where the output IS the input buffer.
std::optional<size_t> output_reuses_input_arg;

/// True when the output memory space is defined to equal the first tile-typed
/// input's memory space (set via `set_output_memory_inherit_input`).
/// InferTileMemorySpace uses this for forward inheritance and backward-demand
/// propagation through view-like ops; memory reuse uses it to skip retargeting.
bool output_inherits_input = false;
};

/**
Expand Down Expand Up @@ -311,9 +319,14 @@ class OpRegistryEntry {
return *this;
}

/// Set output memory from kwarg (e.g., tile.load reads target_memory)
inline OpRegistryEntry& set_output_memory_from_kwarg(const std::string& kwarg_key = "target_memory",
MemorySpace default_space = MemorySpace::Vec) {
/// Set output memory from kwarg (e.g., tile.load reads target_memory).
/// When the kwarg is absent, the resolver falls back to `default_space`. Pass
/// `std::nullopt` (the default) to mark the op as retargetable: the resolver
/// returns nullopt and InferTileMemorySpace decides the final memory space
/// from producer/consumer context.
inline OpRegistryEntry& set_output_memory_from_kwarg(
const std::string& kwarg_key = "target_memory",
std::optional<MemorySpace> default_space = std::nullopt) {
EnsureMemorySpec();
auto& spec = *memory_spec_; // NOLINT(bugprone-unchecked-optional-access)
spec.deduce_output_memory = [kwarg_key,
Expand All @@ -323,15 +336,18 @@ class OpRegistryEntry {
return std::optional<MemorySpace>(AnyCast<MemorySpace>(v, kwarg_key));
}
}
return std::optional<MemorySpace>(default_space);
return default_space;
};
return *this;
}

/// Set output memory inherited from first tile-typed input (view ops)
/// Set output memory inherited from first tile-typed input (view ops).
/// The resolver returns nullopt; InferTileMemorySpace resolves by copying the input's
/// (already-resolved) memory space onto the output.
inline OpRegistryEntry& set_output_memory_inherit_input() {
EnsureMemorySpec();
auto& spec = *memory_spec_; // NOLINT(bugprone-unchecked-optional-access)
spec.output_inherits_input = true;
spec.deduce_output_memory =
[](const std::vector<std::pair<std::string, std::any>>&) -> std::optional<MemorySpace> {
return std::nullopt;
Expand Down Expand Up @@ -365,6 +381,31 @@ class OpRegistryEntry {
/// Get memory spec (nullopt if not annotated)
[[nodiscard]] const std::optional<OpMemorySpaceSpec>& GetMemorySpec() const { return memory_spec_; }

/// True when this op's output memory space equals its first tile-typed input's
/// (registered via `set_output_memory_inherit_input`). The single source of truth
/// for passes that need to propagate memory-space information through view-like ops
/// (InferTileMemorySpace, memory reuse).
/// An op may combine this with `set_output_reuses_input(idx)` (e.g. in-place
/// variants like tile.fillpad_inplace that reuse the input's MemRef in place);
/// the memory-space-inheritance relation still holds.
[[nodiscard]] bool OutputMemoryInheritsInput() const {
  // No memory spec annotated — the inheritance relation cannot hold.
  if (!memory_spec_) return false;
  return memory_spec_->output_inherits_input;
}

/// True when this op's output memory space can be chosen by the compiler
/// (e.g. `tile.load`, `tile.create`): the op carries a writable `target_memory`
/// kwarg that InferTileMemorySpace can rewrite to match consumer demand.
/// Inherit-input and fixed-output ops don't participate in retargeting.
/// Distinguishes true deferral (resolver returns nullopt when the kwarg is
/// absent) from ops that carry a `target_memory` kwarg but still produce a
/// concrete default (e.g. `tile.move` → Vec) — those are not retargetable.
[[nodiscard]] bool HasRetargetableMemoryKwarg() const {
  // A retargetable op needs an annotated spec with a resolver, must not be
  // inherit-input, and must expose a writable `target_memory` attribute.
  if (!memory_spec_ || !memory_spec_->deduce_output_memory || memory_spec_->output_inherits_input) {
    return false;
  }
  if (op_ == nullptr || !op_->HasAttr("target_memory")) {
    return false;
  }
  // True deferral: with no kwargs supplied, the resolver yields no concrete
  // space — ops with a kwarg default (e.g. a fixed Vec fallback) resolve here
  // and are therefore not retargetable.
  const auto resolved = memory_spec_->deduce_output_memory({});
  return !resolved.has_value();
}

/// Declare that this op's output reuses the MemRef of the input at arg_index.
/// Used for accumulate ops where the output writes into the input buffer.
inline OpRegistryEntry& set_output_reuses_input(size_t arg_index) {
Expand Down
16 changes: 12 additions & 4 deletions python/bindings/modules/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,7 +472,8 @@ void BindIR(nb::module_& m) {
[](const std::string& op_name) -> nb::object {
auto& registry = OpRegistry::GetInstance();
if (!registry.IsRegistered(op_name)) return nb::none();
const auto& spec = registry.GetEntry(op_name).GetMemorySpec();
const auto& entry = registry.GetEntry(op_name);
const auto& spec = entry.GetMemorySpec();
if (!spec.has_value()) return nb::none();
// Empty spec (from no_memory_spec()) — no constraints and no resolver
if (spec->input_constraints.empty() && !spec->deduce_output_memory) return nb::none();
Expand All @@ -486,13 +487,20 @@ void BindIR(nb::module_& m) {
inputs.append(nb::cast(allowed));
}
result["input_constraints"] = inputs;
// Output (resolve with empty kwargs for display)
if (spec->deduce_output_memory) {
// Output (resolve with empty kwargs for display). Distinguishes:
// - Fixed/default-seeded: resolver returns a concrete MemorySpace.
// - Inherit-input (slice/reshape/...): marker string "inherit_from_input".
// - Retargetable with no kwarg default (tile.load/tile.create without
// target_memory): deferred — InferTileMemorySpace resolves from
// consumer demand. Reported as the string "deferred".
if (entry.OutputMemoryInheritsInput()) {
result["output_memory"] = "inherit_from_input";
} else if (spec->deduce_output_memory) {
auto out = spec->deduce_output_memory({});
if (out.has_value()) {
result["output_memory"] = *out;
} else {
result["output_memory"] = "inherit_from_input";
result["output_memory"] = "deferred";
}
} else {
result["output_memory"] = nb::none();
Expand Down
13 changes: 11 additions & 2 deletions python/pypto/pypto_core/ir.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2493,8 +2493,17 @@ def get_op_memory_spec(op_name: str) -> dict[str, Any] | None:

Returns:
Dict with 'input_constraints' (list of lists of MemorySpace) and
'output_memory' (MemorySpace, 'inherit_from_input', or None) keys,
or None if the operator has no memory spec or is not registered.
'output_memory' keys, or None if the operator has no memory spec or
is not registered. 'output_memory' is one of:

* A ``MemorySpace`` — fixed or kwarg-resolved (e.g. `tile.matmul` →
``MemorySpace.Acc``).
* ``'inherit_from_input'`` — the output takes its memory space from
the first tile-typed input (e.g. `tile.slice`, `tile.reshape`).
* ``'deferred'`` — a retargetable producer whose ``target_memory``
kwarg was absent; `InferTileMemorySpace` resolves it later from
consumer demand (e.g. `tile.load`, `tile.create`).
* ``None`` — no resolver registered for this op.
"""

# ========== Op Conversion Registry ==========
Expand Down
53 changes: 36 additions & 17 deletions src/ir/op/tile_ops/memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include <utility>
#include <vector>

#include "pypto/core/any_cast.h"
#include "pypto/core/dtype.h"
#include "pypto/core/error.h"
#include "pypto/core/logging.h"
Expand Down Expand Up @@ -126,30 +127,44 @@ TypePtr DeduceTileLoadType(const std::vector<ExprPtr>& args,
CHECK(shapes_tuple->elements_.size() > 0)
<< "The operator " << op_name << " requires at least one dimension, but got empty shapes tuple";

auto target_memory = GetKwarg<MemorySpace>(kwargs, "target_memory");
// target_memory is optional: when absent, memory_space stays unresolved and
// InferTileMemorySpace will pick it from consumer demand. Layout is deferred in
// that case — the pass recomputes TileView via GetImplicitTileView once the
// space is known.
std::optional<MemorySpace> target_memory_opt;
for (const auto& [k, v] : kwargs) {
if (k == "target_memory") {
target_memory_opt = AnyCast<MemorySpace>(v, "target_memory");
break;
}
}
bool transpose = GetKwarg<bool>(kwargs, "transpose", false);

// Transpose is only supported when loading to L1 (Mat)
CHECK(!transpose || target_memory == MemorySpace::Mat)
<< "The operator " << op_name
<< " only supports transpose=true when target_memory is Mat (L1), but got "
<< static_cast<int>(target_memory);
// Transpose semantics are Mat-specific. Callers that use transpose=true must
// commit to target_memory=Mat at construction — InferTileMemorySpace does not
// revisit transpose decisions.
CHECK(!transpose || (target_memory_opt.has_value() && *target_memory_opt == MemorySpace::Mat))
<< "The operator " << op_name << " only supports transpose=true when target_memory is Mat (L1)";

CHECK(!transpose || shapes_tuple->elements_.size() >= 2)
<< "The operator " << op_name << " requires at least 2D shapes for transpose=true, but got "
<< shapes_tuple->elements_.size() << "D";

// Nz/Zn for transpose false/true
// Nz/Zn layout: only chosen when target_memory is known. If it is absent,
// the default-constructed view is kept and InferTileMemorySpace rebuilds it
// once the memory space is resolved.
TileView tile_view;
if (target_memory == MemorySpace::Mat) {
tile_view.blayout = TileLayout::col_major;
tile_view.slayout = TileLayout::row_major;
if (transpose) {
std::swap(tile_view.blayout, tile_view.slayout);
if (target_memory_opt.has_value()) {
if (*target_memory_opt == MemorySpace::Mat) {
tile_view.blayout = TileLayout::col_major;
tile_view.slayout = TileLayout::row_major;
if (transpose) {
std::swap(tile_view.blayout, tile_view.slayout);
}
} else if (auto last_dim = As<ConstInt>(shapes_tuple->elements_.back());
last_dim && last_dim->value_ == 1) {
tile_view.blayout = TileLayout::col_major;
}
} else if (auto last_dim = As<ConstInt>(shapes_tuple->elements_.back());
last_dim && last_dim->value_ == 1) {
tile_view.blayout = TileLayout::col_major;
}

// Build tile shape from shapes tuple.
Expand Down Expand Up @@ -515,7 +530,9 @@ REGISTER_OP("tile.create")
.add_argument("shape", "Shape dimensions (TupleType of ScalarType(INT64))")
.set_attr<DataType>("dtype")
.set_attr<MemorySpace>("target_memory")
.set_output_memory_from_kwarg("target_memory", MemorySpace::Vec)
// No fallback: when target_memory is absent, memory_space stays unresolved and
// InferTileMemorySpace picks the space from consumer demand.
.set_output_memory_from_kwarg("target_memory")
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceTileCreateTileType(args, kwargs, "tile.create");
Expand All @@ -535,7 +552,9 @@ REGISTER_OP("tile.load")
"Valid shape of tile in each dimension, in source tensor coordinates (TupleType of ScalarType). ")
.set_attr<MemorySpace>("target_memory")
.set_attr<bool>("transpose")
.set_output_memory_from_kwarg("target_memory", MemorySpace::Vec)
// No fallback: when target_memory is absent, memory_space stays unresolved and
// InferTileMemorySpace picks the space from consumer demand.
.set_output_memory_from_kwarg("target_memory")
.f_deduce_type([](const std::vector<ExprPtr>& args,
const std::vector<std::pair<std::string, std::any>>& kwargs) {
return DeduceTileLoadType(args, kwargs, "tile.load");
Expand Down
56 changes: 56 additions & 0 deletions src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -227,9 +227,58 @@ class ConsumerSpaceCollector : public IRVisitor {
return it != consumer_reqs_.end() ? std::optional{it->second} : std::nullopt;
}

/// Second phase: propagate collected requirements backward through
/// (a) ops registered with `set_output_memory_inherit_input()` — output
/// memory equals the first tile/tensor-typed input's, so a demand on
/// the output is equivalently a demand on that input, and
/// (b) plain SSA aliases `y = x` where both sides are shaped Vars (the
/// parser elides no-op `tensor.fillpad(pad=zero)` into this form when
/// the input's valid_shape already zeroes the pad region).
///
/// Edges are recorded in program order during the forward visit. Since the
/// inherit-input and alias relations are acyclic and flow strictly backward
/// (output/dst defined after input/src), a single reverse-order sweep
/// reaches the fixed point in O(N). Total pass cost stays O(N log N).
void PropagateThroughInheritInputOps() {
for (auto it = propagation_edges_.rbegin(); it != propagation_edges_.rend(); ++it) {
const auto& [dst, src] = *it;
auto out_it = consumer_reqs_.find(dst);
if (out_it == consumer_reqs_.end()) continue;
const auto& req = out_it->second;
auto [ins_it, inserted] = consumer_reqs_.try_emplace(src, req);
if (!inserted && ins_it->second.space == MemorySpace::Vec && req.space != MemorySpace::Vec) {
ins_it->second = req;
}
}
}

protected:
void VisitStmt_(const AssignStmtPtr& op) override {
if (!op) return;
auto is_shaped = [](const TypePtr& t) { return As<TensorType>(t) || As<TileType>(t); };

// Record a propagation edge `dst -> src` in program order when the RHS is
// either a plain SSA alias (both sides shaped) or an inherit-input Call
// (first shaped input carries the memory-space relation). The reverse walk
// in phase 2 then resolves all back-propagation in a single pass.
if (op->var_ && is_shaped(op->var_->GetType())) {
if (auto src_var = As<Var>(op->value_); src_var && is_shaped(src_var->GetType())) {
propagation_edges_.emplace_back(op->var_.get(), src_var.get());
} else if (auto call = As<Call>(op->value_);
call && !std::dynamic_pointer_cast<const GlobalVar>(call->op_)) {
auto& op_reg = OpRegistry::GetInstance();
if (op_reg.IsRegistered(call->op_->name_) &&
op_reg.GetEntry(call->op_->name_).OutputMemoryInheritsInput()) {
for (const auto& arg : call->args_) {
if (auto arg_var = As<Var>(arg); arg_var && is_shaped(arg_var->GetType())) {
propagation_edges_.emplace_back(op->var_.get(), arg_var.get());
break;
}
}
}
}
}

auto call = As<Call>(op->value_);
if (!call || std::dynamic_pointer_cast<const GlobalVar>(call->op_)) {
IRVisitor::VisitStmt_(op);
Expand Down Expand Up @@ -261,6 +310,10 @@ class ConsumerSpaceCollector : public IRVisitor {
private:
const OpConversionRegistry& registry_;
std::unordered_map<const Var*, ConsumerSpaceReq> consumer_reqs_;
// `dst -> src` edges captured in program order — covers both Call-valued
// inherit-input ops and plain SSA aliases. A single reverse-order walk in
// PropagateThroughInheritInputOps reaches the fixed point.
std::vector<std::pair<const Var*, const Var*>> propagation_edges_;
};

// ============================================================================
Expand Down Expand Up @@ -1186,8 +1239,11 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func) {

// Pre-scan: collect consumer memory space requirements (e.g. tensor.slice → tensor.matmul
// needs Mat-space loads). Driven by InputSpaceReq metadata in OpConversionRegistry.
// Then propagate demands backward through pass-through ops (tensor.fillpad etc.) so a
// chain like `slice → fillpad → matmul` routes the slice's load directly into Mat.
ConsumerSpaceCollector consumer_collector(conv_registry);
consumer_collector.VisitStmt(func->body_);
consumer_collector.PropagateThroughInheritInputOps();

// Create the body mutator
TensorToTileMutator mutator(conv_registry, op_registry, consumer_collector);
Expand Down
Loading
Loading