diff --git a/include/pypto/ir/op_registry.h b/include/pypto/ir/op_registry.h
index c37ba4170..005298ecc 100644
--- a/include/pypto/ir/op_registry.h
+++ b/include/pypto/ir/op_registry.h
@@ -55,7 +55,9 @@ struct OpMemorySpaceSpec {
   std::vector<std::vector<MemorySpace>> input_constraints;
 
   /// Resolves output memory space from the Call's kwargs.
-  /// Returns nullopt to signal "inherit from first tile-typed input" (view ops).
+  /// Returns nullopt when the space cannot be resolved from kwargs alone — either
+  /// because the op inherits from its input (see `output_inherits_input`) or
+  /// because a retargetable kwarg is absent and InferTileMemorySpace must decide.
   using OutputResolver = std::function<std::optional<MemorySpace>(const std::vector<std::pair<std::string, std::any>>& kwargs)>;
   OutputResolver deduce_output_memory;
 
@@ -63,6 +65,12 @@ struct OpMemorySpaceSpec {
   /// When set, the output reuses the MemRef of the input argument at this index.
   /// Used by accumulate ops (matmul_acc, gemv_acc) where the output IS the input buffer.
   std::optional<size_t> output_reuses_input_arg;
+
+  /// True when the output memory space is defined to equal the first tile-typed
+  /// input's memory space (set via `set_output_memory_inherit_input`).
+  /// InferTileMemorySpace uses this for forward inheritance and backward-demand
+  /// propagation through view-like ops; memory reuse uses it to skip retargeting.
+  bool output_inherits_input = false;
 };
 
 /**
@@ -311,9 +319,14 @@ class OpRegistryEntry {
     return *this;
   }
 
-  /// Set output memory from kwarg (e.g., tile.load reads target_memory)
-  inline OpRegistryEntry& set_output_memory_from_kwarg(const std::string& kwarg_key = "target_memory",
-                                                       MemorySpace default_space = MemorySpace::Vec) {
+  /// Set output memory from kwarg (e.g., tile.load reads target_memory).
+  /// When the kwarg is absent, the resolver falls back to `default_space`. Pass
+  /// `std::nullopt` (the default) to mark the op as retargetable: the resolver
+  /// returns nullopt and InferTileMemorySpace decides the final memory space
+  /// from producer/consumer context.
+  inline OpRegistryEntry& set_output_memory_from_kwarg(
+      const std::string& kwarg_key = "target_memory",
+      std::optional<MemorySpace> default_space = std::nullopt) {
     EnsureMemorySpec();
     auto& spec = *memory_spec_;  // NOLINT(bugprone-unchecked-optional-access)
     spec.deduce_output_memory = [kwarg_key,
@@ -323,15 +336,18 @@
           return std::optional<MemorySpace>(AnyCast<MemorySpace>(v, kwarg_key));
         }
       }
-      return std::optional<MemorySpace>(default_space);
+      return default_space;
     };
     return *this;
   }
 
-  /// Set output memory inherited from first tile-typed input (view ops)
+  /// Set output memory inherited from first tile-typed input (view ops).
+  /// The resolver returns nullopt; InferTileMemorySpace resolves by copying the input's
+  /// (already-resolved) memory space onto the output.
   inline OpRegistryEntry& set_output_memory_inherit_input() {
     EnsureMemorySpec();
     auto& spec = *memory_spec_;  // NOLINT(bugprone-unchecked-optional-access)
+    spec.output_inherits_input = true;
     spec.deduce_output_memory = [](const std::vector<std::pair<std::string, std::any>>&) -> std::optional<MemorySpace> {
       return std::nullopt;
     };
@@ -365,6 +381,31 @@
   /// Get memory spec (nullopt if not annotated)
   [[nodiscard]] const std::optional<OpMemorySpaceSpec>& GetMemorySpec() const { return memory_spec_; }
 
+  /// True when this op's output memory space equals its first tile-typed input's
+  /// (registered via `set_output_memory_inherit_input`). This is the single source
+  /// of truth for passes that propagate memory-space information through view-like
+  /// ops (InferTileMemorySpace, memory reuse).
+  /// An op may combine this with `set_output_reuses_input(idx)` (e.g. in-place
+  /// variants like tile.fillpad_inplace that reuse the input's MemRef in place);
+  /// the memory-space-inheritance relation still holds.
+  [[nodiscard]] bool OutputMemoryInheritsInput() const {
+    return memory_spec_.has_value() && memory_spec_->output_inherits_input;
+  }
+
+  /// True when this op's output memory space can be chosen by the compiler
+  /// (e.g. `tile.load`, `tile.create`): the op carries a writable `target_memory`
+  /// kwarg that InferTileMemorySpace can rewrite to match consumer demand.
+  /// Inherit-input and fixed-output ops don't participate in retargeting. This
+  /// distinguishes true deferral (the resolver returns nullopt when the kwarg is
+  /// absent) from ops that carry a `target_memory` kwarg but still produce a
+  /// concrete default (e.g. `tile.move` → Vec) — those are not retargetable.
+  [[nodiscard]] bool HasRetargetableMemoryKwarg() const {
+    if (!memory_spec_.has_value() || !memory_spec_->deduce_output_memory) return false;
+    if (memory_spec_->output_inherits_input) return false;
+    if (!op_ || !op_->HasAttr("target_memory")) return false;
+    return !memory_spec_->deduce_output_memory({}).has_value();
+  }
+
   /// Declare that this op's output reuses the MemRef of the input at arg_index.
   /// Used for accumulate ops where the output writes into the input buffer.
   inline OpRegistryEntry& set_output_reuses_input(size_t arg_index) {
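Aside (not part of the patch): the registration contract above reduces to a resolver that may carry its own default. A minimal self-contained sketch, with a simplified `MemorySpace` and kwarg shape standing in for the real registry types:

```cpp
#include <any>
#include <functional>
#include <iostream>
#include <optional>
#include <string>
#include <utility>
#include <vector>

enum class MemorySpace { Vec, Mat, Acc };

using Kwargs = std::vector<std::pair<std::string, std::any>>;
using OutputResolver = std::function<std::optional<MemorySpace>(const Kwargs&)>;

// Mirrors the idea behind set_output_memory_from_kwarg: read the kwarg if
// present, otherwise fall back to default_space, which may itself be nullopt,
// turning "kwarg absent" into "deferred to InferTileMemorySpace".
OutputResolver FromKwarg(std::string key, std::optional<MemorySpace> default_space) {
  return [key = std::move(key), default_space](const Kwargs& kwargs) -> std::optional<MemorySpace> {
    for (const auto& [k, v] : kwargs) {
      if (k == key) return std::any_cast<MemorySpace>(v);
    }
    return default_space;  // nullopt => retargetable, the pass decides later
  };
}

int main() {
  auto move_like = FromKwarg("target_memory", MemorySpace::Vec);  // concrete default
  auto load_like = FromKwarg("target_memory", std::nullopt);      // retargetable

  Kwargs none;
  Kwargs mat = {{"target_memory", std::any(MemorySpace::Mat)}};

  std::cout << "move, no kwarg:  " << move_like(none).has_value() << "\n";           // 1 (Vec)
  std::cout << "load, no kwarg:  " << load_like(none).has_value() << "\n";           // 0 (deferred)
  std::cout << "load, kwarg=Mat: " << (load_like(mat) == MemorySpace::Mat) << "\n";  // 1
}
```

This is why `HasRetargetableMemoryKwarg` can probe the resolver with empty kwargs: only a nullopt default leaves the answer open.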
diff --git a/python/bindings/modules/ir.cpp b/python/bindings/modules/ir.cpp
index ae3ee96f8..e61f35559 100644
--- a/python/bindings/modules/ir.cpp
+++ b/python/bindings/modules/ir.cpp
@@ -472,7 +472,8 @@ void BindIR(nb::module_& m) {
       [](const std::string& op_name) -> nb::object {
         auto& registry = OpRegistry::GetInstance();
         if (!registry.IsRegistered(op_name)) return nb::none();
-        const auto& spec = registry.GetEntry(op_name).GetMemorySpec();
+        const auto& entry = registry.GetEntry(op_name);
+        const auto& spec = entry.GetMemorySpec();
         if (!spec.has_value()) return nb::none();
         // Empty spec (from no_memory_spec()) — no constraints and no resolver
         if (spec->input_constraints.empty() && !spec->deduce_output_memory) return nb::none();
@@ -486,13 +487,20 @@
           inputs.append(nb::cast(allowed));
         }
         result["input_constraints"] = inputs;
-        // Output (resolve with empty kwargs for display)
-        if (spec->deduce_output_memory) {
+        // Output (resolve with empty kwargs for display). Distinguishes:
+        // - Fixed/default-seeded: resolver returns a concrete MemorySpace.
+        // - Inherit-input (slice/reshape/...): marker string "inherit_from_input".
+        // - Retargetable with no kwarg default (tile.load/tile.create without
+        //   target_memory): deferred — InferTileMemorySpace resolves from
+        //   consumer demand. Reported as the string "deferred".
+        if (entry.OutputMemoryInheritsInput()) {
+          result["output_memory"] = "inherit_from_input";
+        } else if (spec->deduce_output_memory) {
           auto out = spec->deduce_output_memory({});
           if (out.has_value()) {
             result["output_memory"] = *out;
           } else {
-            result["output_memory"] = "inherit_from_input";
+            result["output_memory"] = "deferred";
           }
         } else {
           result["output_memory"] = nb::none();
diff --git a/python/pypto/pypto_core/ir.pyi b/python/pypto/pypto_core/ir.pyi
index cc0839075..bcd3836f8 100644
--- a/python/pypto/pypto_core/ir.pyi
+++ b/python/pypto/pypto_core/ir.pyi
@@ -2493,8 +2493,17 @@ def get_op_memory_spec(op_name: str) -> dict[str, Any] | None:
 
     Returns:
         Dict with 'input_constraints' (list of lists of MemorySpace) and
-        'output_memory' (MemorySpace, 'inherit_from_input', or None) keys,
-        or None if the operator has no memory spec or is not registered.
+        'output_memory' keys, or None if the operator has no memory spec or
+        is not registered. 'output_memory' is one of:
+
+        * A ``MemorySpace`` — fixed or kwarg-resolved (e.g. `tile.matmul` →
+          ``MemorySpace.Acc``).
+        * ``'inherit_from_input'`` — the output takes its memory space from
+          the first tile-typed input (e.g. `tile.slice`, `tile.reshape`).
+        * ``'deferred'`` — a retargetable producer whose ``target_memory``
+          kwarg was absent; `InferTileMemorySpace` resolves it later from
+          consumer demand (e.g. `tile.load`, `tile.create`).
+        * ``None`` — no resolver registered for this op.
     """
 
 # ========== Op Conversion Registry ==========
diff --git a/src/ir/op/tile_ops/memory.cpp b/src/ir/op/tile_ops/memory.cpp
index 7a01ed8e1..563d02dab 100644
--- a/src/ir/op/tile_ops/memory.cpp
+++ b/src/ir/op/tile_ops/memory.cpp
@@ -26,6 +26,7 @@
 #include
 #include
+#include "pypto/core/any_cast.h"
 #include "pypto/core/dtype.h"
 #include "pypto/core/error.h"
 #include "pypto/core/logging.h"
@@ -126,30 +127,44 @@ TypePtr DeduceTileLoadType(const std::vector<ExprPtr>& args,
   CHECK(shapes_tuple->elements_.size() > 0)
       << "The operator " << op_name << " requires at least one dimension, but got empty shapes tuple";
 
-  auto target_memory = GetKwarg<MemorySpace>(kwargs, "target_memory");
+  // target_memory is optional: when absent, memory_space stays unresolved and
+  // InferTileMemorySpace will pick it from consumer demand. Layout is deferred in
+  // that case — the pass recomputes TileView via GetImplicitTileView once the
+  // space is known.
+  std::optional<MemorySpace> target_memory_opt;
+  for (const auto& [k, v] : kwargs) {
+    if (k == "target_memory") {
+      target_memory_opt = AnyCast<MemorySpace>(v, "target_memory");
+      break;
+    }
+  }
   bool transpose = GetKwarg<bool>(kwargs, "transpose", false);
 
-  // Transpose is only supported when loading to L1 (Mat)
-  CHECK(!transpose || target_memory == MemorySpace::Mat)
-      << "The operator " << op_name
-      << " only supports transpose=true when target_memory is Mat (L1), but got "
-      << static_cast<int>(target_memory);
+  // Transpose semantics are Mat-specific. Callers that use transpose=true must
+  // commit to target_memory=Mat at construction — InferTileMemorySpace does not
+  // revisit transpose decisions.
+  CHECK(!transpose || (target_memory_opt.has_value() && *target_memory_opt == MemorySpace::Mat))
+      << "The operator " << op_name << " only supports transpose=true when target_memory is Mat (L1)";
   CHECK(!transpose || shapes_tuple->elements_.size() >= 2)
       << "The operator " << op_name << " requires at least 2D shapes for transpose=true, but got "
       << shapes_tuple->elements_.size() << "D";
 
-  // Nz/Zn for transpose false/true
+  // Nz/Zn layout: only chosen when target_memory is known. If it is absent,
+  // the default-constructed view is kept and InferTileMemorySpace rebuilds it
+  // once the memory space is resolved.
   TileView tile_view;
-  if (target_memory == MemorySpace::Mat) {
-    tile_view.blayout = TileLayout::col_major;
-    tile_view.slayout = TileLayout::row_major;
-    if (transpose) {
-      std::swap(tile_view.blayout, tile_view.slayout);
+  if (target_memory_opt.has_value()) {
+    if (*target_memory_opt == MemorySpace::Mat) {
+      tile_view.blayout = TileLayout::col_major;
+      tile_view.slayout = TileLayout::row_major;
+      if (transpose) {
+        std::swap(tile_view.blayout, tile_view.slayout);
+      }
+    } else if (auto last_dim = As(shapes_tuple->elements_.back());
+               last_dim && last_dim->value_ == 1) {
+      tile_view.blayout = TileLayout::col_major;
     }
-  } else if (auto last_dim = As(shapes_tuple->elements_.back());
-             last_dim && last_dim->value_ == 1) {
-    tile_view.blayout = TileLayout::col_major;
   }
 
   // Build tile shape from shapes tuple.
@@ -515,7 +530,9 @@ REGISTER_OP("tile.create")
     .add_argument("shape", "Shape dimensions (TupleType of ScalarType(INT64))")
     .set_attr("dtype")
     .set_attr("target_memory")
-    .set_output_memory_from_kwarg("target_memory", MemorySpace::Vec)
+    // No fallback: when target_memory is absent, memory_space stays unresolved and
+    // InferTileMemorySpace picks the space from consumer demand.
+    .set_output_memory_from_kwarg("target_memory")
     .f_deduce_type([](const std::vector<ExprPtr>& args,
                       const std::vector<std::pair<std::string, std::any>>& kwargs) {
       return DeduceTileCreateTileType(args, kwargs, "tile.create");
@@ -535,7 +552,9 @@ REGISTER_OP("tile.load")
         "Valid shape of tile in each dimension, in source tensor coordinates (TupleType of ScalarType). ")
     .set_attr("target_memory")
     .set_attr("transpose")
-    .set_output_memory_from_kwarg("target_memory", MemorySpace::Vec)
+    // No fallback: when target_memory is absent, memory_space stays unresolved and
+    // InferTileMemorySpace picks the space from consumer demand.
+    .set_output_memory_from_kwarg("target_memory")
     .f_deduce_type([](const std::vector<ExprPtr>& args,
                       const std::vector<std::pair<std::string, std::any>>& kwargs) {
       return DeduceTileLoadType(args, kwargs, "tile.load");
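Aside (not part of the patch): the Nz/Zn decision in DeduceTileLoadType reads as a three-case table. A standalone sketch with stand-in types; the default TileView layouts here are illustrative only:

```cpp
#include <cassert>
#include <optional>
#include <utility>

enum class MemorySpace { Vec, Mat };
enum class TileLayout { row_major, col_major };

struct TileView {
  TileLayout blayout = TileLayout::row_major;  // defaults are illustrative only
  TileLayout slayout = TileLayout::col_major;
};

// Mirrors the deduction above: no target_memory => keep the default view and
// let InferTileMemorySpace rebuild it later; Mat => Nz (col/row), swapped to Zn
// for transpose; otherwise col_major blayout only when the innermost dim is 1.
TileView ChooseView(std::optional<MemorySpace> target, bool transpose, long last_dim) {
  TileView view;
  if (!target.has_value()) return view;  // deferred, resolved by the pass
  if (*target == MemorySpace::Mat) {
    view.blayout = TileLayout::col_major;
    view.slayout = TileLayout::row_major;
    if (transpose) std::swap(view.blayout, view.slayout);
  } else if (last_dim == 1) {
    view.blayout = TileLayout::col_major;
  }
  return view;
}

int main() {
  assert(ChooseView(MemorySpace::Mat, false, 128).blayout == TileLayout::col_major);  // Nz
  assert(ChooseView(MemorySpace::Mat, true, 128).blayout == TileLayout::row_major);   // Zn
  assert(ChooseView(MemorySpace::Vec, false, 1).blayout == TileLayout::col_major);    // column tile
  assert(ChooseView(std::nullopt, false, 128).blayout == TileLayout::row_major);      // deferred
}
```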
diff --git a/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp b/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
index db8ea9031..c518a2b0d 100644
--- a/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
+++ b/src/ir/transforms/convert_tensor_to_tile_ops_pass.cpp
@@ -227,9 +227,58 @@ class ConsumerSpaceCollector : public IRVisitor {
     return it != consumer_reqs_.end() ? std::optional{it->second} : std::nullopt;
   }
 
+  /// Second phase: propagate collected requirements backward through
+  /// (a) ops registered with `set_output_memory_inherit_input()` — output
+  ///     memory equals the first tile/tensor-typed input's, so a demand on
+  ///     the output is equivalently a demand on that input, and
+  /// (b) plain SSA aliases `y = x` where both sides are shaped Vars (the
+  ///     parser elides no-op `tensor.fillpad(pad=zero)` into this form when
+  ///     the input's valid_shape already zeroes the pad region).
+  ///
+  /// Edges are recorded in program order during the forward visit. Since the
+  /// inherit-input and alias relations are acyclic and flow strictly backward
+  /// (output/dst defined after input/src), a single reverse-order sweep
+  /// reaches the fixed point in O(N). Total pass cost stays O(N log N).
+  void PropagateThroughInheritInputOps() {
+    for (auto it = propagation_edges_.rbegin(); it != propagation_edges_.rend(); ++it) {
+      const auto& [dst, src] = *it;
+      auto out_it = consumer_reqs_.find(dst);
+      if (out_it == consumer_reqs_.end()) continue;
+      const auto& req = out_it->second;
+      auto [ins_it, inserted] = consumer_reqs_.try_emplace(src, req);
+      if (!inserted && ins_it->second.space == MemorySpace::Vec && req.space != MemorySpace::Vec) {
+        ins_it->second = req;
+      }
+    }
+  }
+
  protected:
   void VisitStmt_(const AssignStmtPtr& op) override {
     if (!op) return;
+    auto is_shaped = [](const TypePtr& t) { return As<TileType>(t) || As<TensorType>(t); };
+
+    // Record a propagation edge `dst -> src` in program order when the RHS is
+    // either a plain SSA alias (both sides shaped) or an inherit-input Call
+    // (first shaped input carries the memory-space relation). The reverse walk
+    // in phase 2 then resolves all back-propagation in a single pass.
+    if (op->var_ && is_shaped(op->var_->GetType())) {
+      if (auto src_var = As<Var>(op->value_); src_var && is_shaped(src_var->GetType())) {
+        propagation_edges_.emplace_back(op->var_.get(), src_var.get());
+      } else if (auto call = As<Call>(op->value_);
+                 call && !std::dynamic_pointer_cast(call->op_)) {
+        auto& op_reg = OpRegistry::GetInstance();
+        if (op_reg.IsRegistered(call->op_->name_) &&
+            op_reg.GetEntry(call->op_->name_).OutputMemoryInheritsInput()) {
+          for (const auto& arg : call->args_) {
+            if (auto arg_var = As<Var>(arg); arg_var && is_shaped(arg_var->GetType())) {
+              propagation_edges_.emplace_back(op->var_.get(), arg_var.get());
+              break;
+            }
+          }
+        }
+      }
+    }
+
     auto call = As<Call>(op->value_);
     if (!call || std::dynamic_pointer_cast(call->op_)) {
       IRVisitor::VisitStmt_(op);
@@ -261,6 +310,10 @@
  private:
   const OpConversionRegistry& registry_;
   std::unordered_map consumer_reqs_;
+  // `dst -> src` edges captured in program order — covers both Call-valued
+  // inherit-input ops and plain SSA aliases. A single reverse-order walk in
+  // PropagateThroughInheritInputOps reaches the fixed point.
+  std::vector<std::pair<const Var*, const Var*>> propagation_edges_;
 };
 
 // ============================================================================
@@ -1186,8 +1239,11 @@ IncoreTransformResult TransformIncoreFunction(const FunctionPtr& func) {
 
   // Pre-scan: collect consumer memory space requirements (e.g. tensor.slice → tensor.matmul
   // needs Mat-space loads). Driven by InputSpaceReq metadata in OpConversionRegistry.
+  // Then propagate demands backward through pass-through ops (tensor.fillpad etc.) so a
+  // chain like `slice → fillpad → matmul` routes the slice's load directly into Mat.
   ConsumerSpaceCollector consumer_collector(conv_registry);
   consumer_collector.VisitStmt(func->body_);
+  consumer_collector.PropagateThroughInheritInputOps();
 
   // Create the body mutator
   TensorToTileMutator mutator(conv_registry, op_registry, consumer_collector);
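Aside (not part of the patch): the one-reverse-sweep argument used by both collectors is easiest to check on a toy chain, here with ints standing in for SSA vars and the same Vec-is-permissive merge rule:

```cpp
#include <iostream>
#include <map>
#include <utility>
#include <vector>

enum class MemorySpace { Vec, Mat };

int main() {
  // Program order: v0 = slice(...); v1 = v0; v2 = v1; matmul(v2) demands Mat on v2.
  // Edges are recorded dst -> src in program order during the forward visit.
  std::vector<std::pair<int, int>> edges = {{1, 0}, {2, 1}};
  std::map<int, MemorySpace> demands = {{2, MemorySpace::Mat}};

  // Because every dst is defined after its src, walking the edges once in
  // reverse order propagates the demand transitively: v2 -> v1 -> v0.
  for (auto it = edges.rbegin(); it != edges.rend(); ++it) {
    auto found = demands.find(it->first);
    if (found == demands.end()) continue;
    auto [ins, inserted] = demands.try_emplace(it->second, found->second);
    // Vec is the permissive default; a specialized demand overrides it.
    if (!inserted && ins->second == MemorySpace::Vec && found->second != MemorySpace::Vec) {
      ins->second = found->second;
    }
  }

  std::cout << "v0 demanded Mat: " << (demands.at(0) == MemorySpace::Mat) << "\n";  // 1
}
```

Processing the edges in forward order instead would visit `v1 -> v0` before any demand exists on `v1`, which is exactly why the reverse order reaches the fixed point in one pass.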
diff --git a/src/ir/transforms/infer_tile_memory_space_pass.cpp b/src/ir/transforms/infer_tile_memory_space_pass.cpp
index c144fd50c..aa7b66272 100644
--- a/src/ir/transforms/infer_tile_memory_space_pass.cpp
+++ b/src/ir/transforms/infer_tile_memory_space_pass.cpp
@@ -64,13 +64,107 @@ const std::vector<std::vector<MemorySpace>>* GetInputConstraints(const std::stri
   return &spec_opt->input_constraints;
 }
 
+// Prefer the non-Vec space when two demands collide on the same var. Vec acts as
+// the permissive default, so a specialized demand (Mat, Left, Right, Acc) wins.
+bool ShouldOverrideDemand(MemorySpace existing, MemorySpace incoming) {
+  return existing == MemorySpace::Vec && incoming != MemorySpace::Vec;
+}
+
+// ============================================================================
+// Phase 0: Backward demand collection
+//
+// For each op with `input_constraints`, record "this input var is demanded to
+// live in this space". Then propagate demands backward through ops registered
+// with `set_output_memory_inherit_input()` to a fixed point so that chains like
+//   slice(tensor) -> fillpad -> matmul
+// push the matmul's Mat demand back through fillpad onto the slice's output,
+// enabling the downstream Phase 1 analyzer to resolve the slice-produced tile
+// directly to Mat instead of routing through Vec.
+// ============================================================================
+
+class DemandCollector : public IRVisitor {
+ public:
+  [[nodiscard]] const std::map<VarPtr, MemorySpace>& GetDemands() const { return demands_; }
+
+  void VisitStmt_(const AssignStmtPtr& op) override {
+    if (auto call = As<Call>(op->value_)) {
+      RecordDirectDemands(call);
+      RecordInheritInputEdge(op->var_, call);
+    }
+    IRVisitor::VisitStmt_(op);
+  }
+
+  void VisitStmt_(const EvalStmtPtr& op) override {
+    if (auto call = As<Call>(op->expr_)) RecordDirectDemands(call);
+    IRVisitor::VisitStmt_(op);
+  }
+
+  /// Propagate demand backward through OutputMemoryInheritsInput() ops.
+  /// Edges `dst -> src` are captured in program order during the forward visit;
+  /// since the inherit-input relation flows strictly backward (dst defined
+  /// after src), a single reverse-order sweep reaches the fixed point in O(N).
+  void PropagateThroughInheritInputOps() {
+    for (auto it = edges_.rbegin(); it != edges_.rend(); ++it) {
+      const auto& [dst, src] = *it;
+      auto out_it = demands_.find(dst);
+      if (out_it == demands_.end()) continue;
+      auto [ins_it, inserted] = demands_.try_emplace(src, out_it->second);
+      if (!inserted && ShouldOverrideDemand(ins_it->second, out_it->second)) {
+        ins_it->second = out_it->second;
+      }
+    }
+  }
+
+ private:
+  std::map<VarPtr, MemorySpace> demands_;
+  // `dst -> src` edges for ops with OutputMemoryInheritsInput(), captured in
+  // program order. Walked in reverse in PropagateThroughInheritInputOps.
+  std::vector<std::pair<VarPtr, VarPtr>> edges_;
+
+  void RecordDirectDemands(const CallPtr& call) {
+    auto& reg = OpRegistry::GetInstance();
+    if (!reg.IsRegistered(call->op_->name_)) return;
+    const auto& spec = reg.GetEntry(call->op_->name_).GetMemorySpec();
+    if (!spec.has_value()) return;
+    for (size_t i = 0; i < spec->input_constraints.size() && i < call->args_.size(); ++i) {
+      const auto& allowed = spec->input_constraints[i];
+      if (allowed.empty()) continue;
+      auto var = As<Var>(call->args_[i]);
+      if (!var) continue;
+      // Preferred space: the first allowed entry. Backends are expected to list
+      // the canonical choice first (e.g. tile.store uses {Vec, Acc} — a Vec
+      // producer needs no move, and Acc-origin tiles keep their space).
+      MemorySpace demand = allowed[0];
+      auto [it, inserted] = demands_.try_emplace(var, demand);
+      if (!inserted && ShouldOverrideDemand(it->second, demand)) {
+        it->second = demand;
+      }
+    }
+  }
+
+  void RecordInheritInputEdge(const VarPtr& dst, const CallPtr& call) {
+    if (!dst) return;
+    auto& reg = OpRegistry::GetInstance();
+    if (!reg.IsRegistered(call->op_->name_)) return;
+    if (!reg.GetEntry(call->op_->name_).OutputMemoryInheritsInput()) return;
+    for (const auto& arg : call->args_) {
+      auto var = As<Var>(arg);
+      if (!var) continue;
+      if (!As<TileType>(var->GetType()) && !As<TensorType>(var->GetType())) continue;
+      edges_.emplace_back(dst, var);
+      break;  // first tile-typed input only (matches inherit-input semantics)
+    }
+  }
+};
+
 // ============================================================================
 // Phase 1: Analyze - infer memory_space for each tile variable
 // ============================================================================
 
 class TileMemorySpaceAnalyzer : public IRVisitor {
  public:
-  explicit TileMemorySpaceAnalyzer(const std::vector<VarPtr>& params) {
+  TileMemorySpaceAnalyzer(const std::vector<VarPtr>& params, const std::map<VarPtr, MemorySpace>& demands)
+      : demands_(demands) {
     for (const auto& var : params) {
       CHECK(!As<TileType>(var->GetType())) << "InCore function parameter '" << var->name_hint_
                                            << "' has TileType, but InCore parameters must be TensorType";
@@ -88,11 +182,21 @@
     if (auto call = As<Call>(op->value_)) {
       const std::string& op_name = call->op_->name_;
       if (op_name.rfind("tile.", 0) == 0) {
-        var_memory_[op->var_] = InferFromOp(op_name, call);
+        var_memory_[op->var_] = InferFromOp(op_name, call, op->var_);
       } else {
         // Non-tile ops producing TileType: default to Vec
         var_memory_[op->var_] = MemorySpace::Vec;
       }
+    } else if (auto src_var = As<Var>(op->value_)) {
+      // Plain SSA alias `y = x`. Inherit x's memory space onto y so later
+      // phases (MoveCollector, Phase 3) see a consistent memory_space on the
+      // alias. The Python frontend emits these when eliding no-op
+      // tensor.fillpad(pad=zero) calls whose input already has a matching
+      // valid_shape — the alias is value-identical to its source.
+      auto src_it = var_memory_.find(src_var);
+      if (src_it != var_memory_.end()) {
+        var_memory_[op->var_] = src_it->second;
+      }
     }
 
     IRVisitor::VisitStmt_(op);
@@ -145,9 +249,10 @@
   }
 
  private:
+  const std::map<VarPtr, MemorySpace>& demands_;
   std::map<VarPtr, MemorySpace> var_memory_;
 
-  MemorySpace InferFromOp(const std::string& op_name, const CallPtr& call) {
+  MemorySpace InferFromOp(const std::string& op_name, const CallPtr& call, const VarPtr& out_var) {
     auto& registry = OpRegistry::GetInstance();
 
     // Handle unregistered ops (backward compat)
@@ -156,7 +261,8 @@
       return MemorySpace::Vec;
     }
 
-    const auto& spec_opt = registry.GetEntry(op_name).GetMemorySpec();
+    const auto& entry = registry.GetEntry(op_name);
+    const auto& spec_opt = entry.GetMemorySpec();
     if (!spec_opt.has_value() || !spec_opt->deduce_output_memory) {
       // no_memory_spec ops (e.g. tile.tpop_*): read memory_space from Call return type
       if (auto tile_type = As<TileType>(call->GetType())) {
@@ -171,11 +277,35 @@
     if (result.has_value()) {
       return *result;
     }
-    // nullopt -> inherit from first tile-typed input (view ops)
-    return InheritFromInput(call);
+
+    // Resolver returned nullopt — kwarg absent. Two cases:
+    // (1) Inherit-input op (fillpad/slice/...): output = first tile input's
+    //     space. Demand back-prop ensures the input is or will be resolved
+    //     to match consumer demand.
+    // (2) Retargetable producer whose kwarg is absent (e.g. a converter chose
+    //     to let the pass decide): consult backward demand, then fall back.
+    // We never override a present kwarg — a Left/Right/Acc demand from a
+    // compute op (matmul) cannot be satisfied by a DDR load directly and must
+    // still route through Mat with a subsequent tile.move.
+    if (spec_opt->output_inherits_input) {
+      return InheritFromInput(call).value_or(MemorySpace::Vec);
+    }
+    if (entry.HasRetargetableMemoryKwarg()) {
+      auto demand_it = demands_.find(out_var);
+      if (demand_it != demands_.end()) {
+        MemorySpace demand = demand_it->second;
+        // Retargetable DDR-facing producers (tile.load) can only directly
+        // produce {Vec, Mat}; specialized demands (Left/Right/Acc/Bias) from
+        // downstream compute ops (matmul etc.) must be reached via a
+        // tile.move inserted by Phase 2 MoveCollector. Clamping here keeps
+        // the producer's output hardware-valid and preserves the move chain.
+        if (demand == MemorySpace::Vec || demand == MemorySpace::Mat) return demand;
+      }
+    }
+    return InheritFromInput(call).value_or(MemorySpace::Vec);
   }
 
-  MemorySpace InheritFromInput(const CallPtr& call) {
+  std::optional<MemorySpace> InheritFromInput(const CallPtr& call) {
     for (const auto& arg : call->args_) {
       if (auto var = As<Var>(arg)) {
         auto it = var_memory_.find(var);
@@ -184,7 +314,7 @@
         }
       }
     }
-    return MemorySpace::Vec;
+    return std::nullopt;
   }
 };
 
@@ -333,13 +463,17 @@
     return std::make_shared<AssignStmt>(As<Var>(new_var_expr), new_value, op->span_);
   }
 
-  // Rewrite tile.create's target_memory kwarg when the LHS var was promoted
-  // (e.g. the for-loop accumulator back-propagation in Phase 1 moved the
-  // init from Vec to Acc). The new result type uses the implicit TileView
-  // for the promoted memory so later passes see a consistent layout.
-  // OpRegistry deduction would otherwise keep Vec-style layout defaults.
+  // Rewrite retargetable producers' target_memory kwarg so it matches the
+  // resolved memory space. Covers tile.create / tile.load / any op registered
+  // with HasRetargetableMemoryKwarg(): if Phase 1 resolved the output to a
+  // different space than the kwarg says (or the kwarg is absent because the
+  // converter let the pass decide), we rewrite the call so codegen reads a
+  // consistent value and the result type gets a fresh implicit TileView.
   if (auto call = As<Call>(new_value); call) {
-    if (auto op_name_node = As(call->op_); op_name_node && op_name_node->name_ == "tile.create") {
+    auto& registry = OpRegistry::GetInstance();
+    const std::string& call_op_name = call->op_->name_;
+    if (registry.IsRegistered(call_op_name) &&
+        registry.GetEntry(call_op_name).HasRetargetableMemoryKwarg()) {
       auto mem_it = var_memory_.find(op->var_);
       auto old_call_type = As<TileType>(call->GetType());
       if (mem_it != var_memory_.end() && old_call_type) {
@@ -351,10 +485,6 @@
             break;
           }
         }
-        // tile.create defaults target_memory to Vec, so an explicit Vec call
-        // may omit the kwarg entirely. Rewrite when the kwarg is missing or
-        // differs from the promoted space: preserve other kwargs, overwrite
-        // target_memory if present, and inject it otherwise.
        if (!kwarg_target.has_value() || *kwarg_target != promoted) {
          std::vector<std::pair<std::string, std::any>> new_kwargs;
          new_kwargs.reserve(call->kwargs_.size() + 1);
@@ -535,8 +665,16 @@ class TileMemorySpaceMutator : public IRMutator {
 // ============================================================================
 
 FunctionPtr TransformInferTileMemorySpace(const FunctionPtr& func) {
-  // Phase 1: Analyze — infer memory space for each tile variable
-  TileMemorySpaceAnalyzer analyzer(func->params_);
+  // Phase 0: Collect backward demand from op input_constraints; propagate
+  // through OutputMemoryInheritsInput() ops so demand reaches retargetable
+  // producers (tile.load/tile.create) even through view chains (slice/fillpad).
+  DemandCollector demand_collector;
+  demand_collector.VisitStmt(func->body_);
+  demand_collector.PropagateThroughInheritInputOps();
+
+  // Phase 1: Analyze — infer memory space for each tile variable, using Phase-0
+  // demand as fallback for retargetable producers whose target_memory is absent.
+  TileMemorySpaceAnalyzer analyzer(func->params_, demand_collector.GetDemands());
   analyzer.VisitStmt(func->body_);
 
   const auto& var_memory = analyzer.GetVarMemory();
@@ -544,11 +682,13 @@
     return func;
   }
 
-  // Phase 2: Collect needed tile.move insertions
+  // Phase 2: Collect needed tile.move insertions for residual input-constraint
+  // mismatches (producer and demand both resolved to different fixed spaces).
   MoveCollector collector(var_memory);
   collector.VisitStmt(func->body_);
 
-  // Phase 3: Mutate — set memory_space_ on types, insert moves, substitute args
+  // Phase 3: Mutate — set memory_space_ on types, insert moves, substitute args,
+  // rewrite target_memory kwargs on retargetable producers to stay consistent.
   TileMemorySpaceMutator mutator(var_memory, collector.GetNeededMoves());
   auto new_body = mutator.VisitStmt(func->body_);
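Aside (not part of the patch): a condensed model of the retargetable branch of the Phase 1 fallback chain; `ResolveRetargetable` is a hypothetical stand-in that only captures the {Vec, Mat} clamp described above:

```cpp
#include <iostream>
#include <map>

enum class MemorySpace { Vec, Mat, Left, Right, Acc };

// Mirrors InferFromOp's fallback for retargetable producers: adopt a backward
// demand only if it is directly producible ({Vec, Mat}); specialized spaces
// still come from a later tile.move inserted by Phase 2.
MemorySpace ResolveRetargetable(int out_var, const std::map<int, MemorySpace>& demands) {
  auto it = demands.find(out_var);
  if (it != demands.end() &&
      (it->second == MemorySpace::Vec || it->second == MemorySpace::Mat)) {
    return it->second;
  }
  return MemorySpace::Vec;  // no usable demand: fall back to the default
}

int main() {
  std::map<int, MemorySpace> demands = {{0, MemorySpace::Mat}, {1, MemorySpace::Left}};
  std::cout << (ResolveRetargetable(0, demands) == MemorySpace::Mat) << "\n";  // 1: adopted
  std::cout << (ResolveRetargetable(1, demands) == MemorySpace::Vec) << "\n";  // 1: clamped
  std::cout << (ResolveRetargetable(2, demands) == MemorySpace::Vec) << "\n";  // 1: no demand
}
```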
diff --git a/src/ir/transforms/init_memref.cpp b/src/ir/transforms/init_memref.cpp
index 116c89c1e..326e81960 100644
--- a/src/ir/transforms/init_memref.cpp
+++ b/src/ir/transforms/init_memref.cpp
@@ -48,16 +48,14 @@ namespace ir {
 
 namespace {
 
-// Check if operation is a view operation (zero-copy metadata transform)
-// using the registry: deduce_output_memory returning nullopt = view op.
+// Check if operation is a view operation (zero-copy metadata transform).
+// A view op is one registered with set_output_memory_inherit_input() — its
+// output reuses the input's MemRef view. Delegates to the shared registry
+// predicate so InferTileMemorySpace and InitMemRef agree on the set.
 bool IsViewOperation(const std::string& op_name) {
   auto& registry = OpRegistry::GetInstance();
   if (!registry.IsRegistered(op_name)) return false;
-
-  const auto& spec_opt = registry.GetEntry(op_name).GetMemorySpec();
-  if (!spec_opt.has_value() || !spec_opt->deduce_output_memory) return false;
-
-  return !spec_opt->deduce_output_memory({}).has_value();
+  return registry.GetEntry(op_name).OutputMemoryInheritsInput();
 }
 
 // Check if an operation's output should reuse the MemRef of a specific input argument.
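Aside (not part of the patch): the shared predicate matters because the old duck-typed check cannot tell a view op from a retargetable producer once both resolvers return nullopt on empty kwargs. A standalone contrast with a simplified spec struct standing in for the real OpRegistryEntry:

```cpp
#include <functional>
#include <iostream>
#include <optional>

enum class MemorySpace { Vec, Mat };

struct Spec {
  std::function<std::optional<MemorySpace>()> resolve;  // resolver probed with empty kwargs
  bool output_inherits_input = false;                   // the new explicit flag
};

// Old predicate: "resolver returns nullopt" == view op.
bool IsViewOld(const Spec& s) { return s.resolve && !s.resolve().has_value(); }
// New predicate: only ops registered as inherit-input count.
bool IsViewNew(const Spec& s) { return s.output_inherits_input; }

int main() {
  Spec slice{[] { return std::optional<MemorySpace>{}; }, true};   // inherit-input view op
  Spec load{[] { return std::optional<MemorySpace>{}; }, false};   // retargetable producer

  std::cout << "old: slice=" << IsViewOld(slice) << " load=" << IsViewOld(load) << "\n";  // 1 1 (!)
  std::cout << "new: slice=" << IsViewNew(slice) << " load=" << IsViewNew(load) << "\n";  // 1 0
}
```

Under the old check, a defaultless tile.load would have been misclassified as a zero-copy view once this patch removed its Vec default; the explicit flag keeps InitMemRef and the memory-reuse pass correct.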
diff --git a/src/ir/transforms/memory_reuse_pass.cpp b/src/ir/transforms/memory_reuse_pass.cpp
index c5b14e224..bdab19188 100644
--- a/src/ir/transforms/memory_reuse_pass.cpp
+++ b/src/ir/transforms/memory_reuse_pass.cpp
@@ -394,16 +394,11 @@ class TopDownRetargeter {
     return c.bases.count(target_base) > 0;
   }
 
-  /// True when the op is registered with set_output_memory_inherit_input:
-  /// the memory spec exists, has no output_reuses_input_arg, and its
-  /// deduce_output_memory lambda returns nullopt for empty kwargs (the
-  /// signature the inherit-input registration leaves behind).
+  /// True when the op is registered with set_output_memory_inherit_input.
+  /// Delegates to the shared OpRegistryEntry predicate so passes that reason
+  /// about pass-through ops (here and InferTileMemorySpace) agree on the set.
   static bool IsOutputMemoryInheritInput(const OpRegistryEntry& entry) {
-    const auto& spec = entry.GetMemorySpec();
-    if (!spec.has_value()) return false;
-    if (spec->output_reuses_input_arg.has_value()) return false;
-    if (!spec->deduce_output_memory) return false;
-    return !spec->deduce_output_memory({}).has_value();
+    return entry.OutputMemoryInheritsInput();
   }
 
   static bool HasKwarg(const Call& call, const std::string& key) {
diff --git a/tests/ut/ir/operators/test_op_registry.py b/tests/ut/ir/operators/test_op_registry.py
index 9330331a1..83ee1f44c 100644
--- a/tests/ut/ir/operators/test_op_registry.py
+++ b/tests/ut/ir/operators/test_op_registry.py
@@ -522,10 +522,12 @@ def test_matmul_acc_spec(self):
         assert constraints[2] == [ir.MemorySpace.Right]
 
     def test_load_spec(self):
-        """tile.load output is from kwarg, defaults to Vec."""
+        """tile.load output is retargetable: resolves from the target_memory kwarg
+        if present, otherwise deferred for InferTileMemorySpace to decide from
+        consumer demand."""
         spec = ir.get_op_memory_spec("tile.load")
         assert spec is not None
-        assert spec["output_memory"] == ir.MemorySpace.Vec
+        assert spec["output_memory"] == "deferred"
         assert spec["input_constraints"] == []
 
     def test_store_spec(self):
@@ -682,10 +684,11 @@ def test_move_spec(self):
         assert spec["output_memory"] == ir.MemorySpace.Vec
 
     def test_create_spec(self):
-        """tile.create output is from kwarg, defaults to Vec."""
+        """tile.create output is retargetable: resolves from the target_memory
+        kwarg if present, otherwise deferred for InferTileMemorySpace to decide."""
         spec = ir.get_op_memory_spec("tile.create")
         assert spec is not None
-        assert spec["output_memory"] == ir.MemorySpace.Vec
+        assert spec["output_memory"] == "deferred"
 
 
 class TestRegistryInfrastructure:
@@ -703,11 +706,13 @@ def test_fixed_output_returns_enum(self):
         assert spec is not None
         assert isinstance(spec["output_memory"], ir.MemorySpace)
 
-    def test_kwarg_output_returns_default_enum(self):
-        """Kwarg-based output resolves to default MemorySpace enum."""
+    def test_kwarg_output_returns_deferred(self):
+        """Retargetable ops (tile.load/tile.create) report 'deferred' when the
+        target_memory kwarg is absent — InferTileMemorySpace resolves from
+        consumer demand."""
         spec = ir.get_op_memory_spec("tile.load")
         assert spec is not None
-        assert isinstance(spec["output_memory"], ir.MemorySpace)
+        assert spec["output_memory"] == "deferred"
 
     def test_inherit_output_returns_string(self):
         """Inherit-from-input output returns the string 'inherit_from_input'."""
diff --git a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
index 180924a84..7947bed21 100644
--- a/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
+++ b/tests/ut/ir/transforms/test_convert_tensor_to_tile_ops.py
@@ -1511,6 +1511,142 @@ def expected_body(ib, params):
         )
         _assert_convert_equal(before, expected)
 
+    def test_slice_alias_then_matmul_routes_load_to_mat(self):
+        """tensor.slice → SSA alias → tensor.matmul emits tile.load(Mat).
+
+        Reproduction of the qwen3 decode MLP-down pattern: the parser elides
+        `y = x` aliases (e.g. from commented-out `pl.fillpad` wrappers), leaving
+        a chain `slice → alias → matmul`. ConsumerSpaceCollector must propagate
+        matmul's Mat demand backward through the alias so the slice lowers to
+        a Mat-targeted load instead of the default Vec (which otherwise routes
+        the whole scope through AIV and breaks the AIC/AIV split).
+        """
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.BF16]],
+            ) -> pl.Tensor[[16, 128], pl.BF16]:
+                a_slice: pl.Tensor[[16, 128], pl.BF16] = pl.slice(a, [16, 128], [0, 0])
+                b_slice: pl.Tensor[[128, 128], pl.BF16] = pl.slice(b, [128, 128], [0, 0])
+                a_alias: pl.Tensor[[16, 128], pl.BF16] = a_slice
+                b_alias: pl.Tensor[[128, 128], pl.BF16] = b_slice
+                c: pl.Tensor[[16, 128], pl.BF16] = pl.matmul(a_alias, b_alias)
+                out_0: pl.Tensor[[16, 128], pl.BF16] = pl.assemble(out_0, c, [0, 0])
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.BF16]:
+                out_0: pl.Tensor[[16, 128], pl.BF16] = pl.create_tensor([16, 128], dtype=pl.BF16)
+                return self.main_incore_0(a, b, out_0)
+
+        @pl.program
+        class Expected:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.BF16]],
+            ) -> pl.Tensor[[16, 128], pl.BF16]:
+                a_slice__tile: pl.Tile[[16, 128], pl.BF16, pl.Mem.Mat] = pl.tile.load(
+                    a, [0, 0], [16, 128], [16, 128], target_memory=pl.Mem.Mat, transpose=False
+                )
+                b_slice__tile: pl.Tile[[128, 128], pl.BF16, pl.Mem.Mat] = pl.tile.load(
+                    b, [0, 0], [128, 128], [128, 128], target_memory=pl.Mem.Mat, transpose=False
+                )
+                a_alias: pl.Tile[[16, 128], pl.BF16, pl.Mem.Mat] = a_slice__tile
+                b_alias: pl.Tile[[128, 128], pl.BF16, pl.Mem.Mat] = b_slice__tile
+                c__tile: pl.Tile[[16, 128], pl.FP32, pl.Mem.Acc] = pl.tile.matmul(a_alias, b_alias)
+                out_0__tile: pl.Tensor[[16, 128], pl.BF16] = pl.tile.store(c__tile, [0, 0], out_0)
+                return out_0__tile
+
+            @pl.function
+            def main(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.BF16]:
+                out_0: pl.Tensor[[16, 128], pl.BF16] = pl.create_tensor([16, 128], dtype=pl.BF16)
+                return self.main_incore_0(a, b, out_0)
+
+        After = passes.convert_tensor_to_tile_ops()(Before)
+        ir.assert_structural_equal(After, Expected)
+
+    def test_slice_chain_of_aliases_then_matmul(self):
+        """Demand propagates through a chain of SSA aliases, not just one hop.
+
+        Ensures the single reverse-order sweep over ``propagation_edges_`` handles
+        transitive closure: slice → alias1 → alias2 → matmul must still reach
+        the slice-produced var and push Mat onto the emitted tile.load.
+        """
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 64], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 64], pl.BF16]],
+            ) -> pl.Tensor[[16, 64], pl.BF16]:
+                a_slice: pl.Tensor[[16, 128], pl.BF16] = pl.slice(a, [16, 128], [0, 0])
+                a_alias1: pl.Tensor[[16, 128], pl.BF16] = a_slice
+                a_alias2: pl.Tensor[[16, 128], pl.BF16] = a_alias1
+                c: pl.Tensor[[16, 64], pl.BF16] = pl.matmul(a_alias2, b)
+                out_0: pl.Tensor[[16, 64], pl.BF16] = pl.assemble(out_0, c, [0, 0])
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 64], pl.BF16],
+            ) -> pl.Tensor[[16, 64], pl.BF16]:
+                out_0: pl.Tensor[[16, 64], pl.BF16] = pl.create_tensor([16, 64], dtype=pl.BF16)
+                return self.main_incore_0(a, b, out_0)
+
+        @pl.program
+        class Expected:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 64], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 64], pl.BF16]],
+            ) -> pl.Tensor[[16, 64], pl.BF16]:
+                a_slice__tile: pl.Tile[[16, 128], pl.BF16, pl.Mem.Mat] = pl.tile.load(
+                    a, [0, 0], [16, 128], [16, 128], target_memory=pl.Mem.Mat, transpose=False
+                )
+                a_alias1: pl.Tile[[16, 128], pl.BF16, pl.Mem.Mat] = a_slice__tile
+                a_alias2: pl.Tile[[16, 128], pl.BF16, pl.Mem.Mat] = a_alias1
+                b__tile: pl.Tile[[128, 64], pl.BF16, pl.Mem.Mat] = pl.tile.load(
+                    b, [0, 0], [128, 64], [128, 64], target_memory=pl.Mem.Mat, transpose=False
+                )
+                c__tile: pl.Tile[[16, 64], pl.FP32, pl.Mem.Acc] = pl.tile.matmul(a_alias2, b__tile)
+                out_0__tile: pl.Tensor[[16, 64], pl.BF16] = pl.tile.store(c__tile, [0, 0], out_0)
+                return out_0__tile
+
+            @pl.function
+            def main(
+                self,
+                a: pl.Tensor[[16, 128], pl.BF16],
+                b: pl.Tensor[[128, 64], pl.BF16],
+            ) -> pl.Tensor[[16, 64], pl.BF16]:
+                out_0: pl.Tensor[[16, 64], pl.BF16] = pl.create_tensor([16, 64], dtype=pl.BF16)
+                return self.main_incore_0(a, b, out_0)
+
+        After = passes.convert_tensor_to_tile_ops()(Before)
+        ir.assert_structural_equal(After, Expected)
+
 
 class TestScatterUpdateConversion:
     """Tests for tensor.scatter_update → tile.scatter_update conversion."""
diff --git a/tests/ut/ir/transforms/test_infer_tile_memory_space.py b/tests/ut/ir/transforms/test_infer_tile_memory_space.py
index 43632e965..14da54c2a 100644
--- a/tests/ut/ir/transforms/test_infer_tile_memory_space.py
+++ b/tests/ut/ir/transforms/test_infer_tile_memory_space.py
@@ -1462,5 +1462,165 @@ def main(
         ir.assert_structural_equal(After, Expected)
 
 
+class TestInferTileMemorySpaceSSAAlias:
+    """SSA-alias propagation added by the backward-demand-inference refactor.
+
+    `y = x` where both sides are Tile-typed must forward x's resolved memory
+    space onto y. The pl.DSL parser emits these aliases when eliding no-op
+    wrappers (e.g. commented-out `pl.fillpad`), and earlier pipeline stages
+    also produce them. Before the refactor, aliases without an explicit
+    `pl.Mem.*` annotation left y with no memory_space set and later-phase
+    consumers (MoveCollector, Phase 3) diverged from x.
+    """
+
+    def test_ssa_alias_inherits_memory_space_from_source(self):
+        """`y = x` inherits x's resolved memory_space.
+
+        The alias feeds matmul, so Phase 1 must see Mat on x_alias for Phase 2
+        to insert the usual Mat→Left/Right moves; the Acc result then satisfies
+        tile.store's {Vec, Acc} constraint directly. Only the alias's
+        memory_space propagation is new here — the move insertion itself is
+        standard Phase 2 behavior, which isolates what this test verifies."""
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                x_tile: pl.Tile[[16, 128], pl.BF16] = pl.load(
+                    x, [0, 0], [16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                y_tile: pl.Tile[[128, 128], pl.BF16] = pl.load(
+                    y, [0, 0], [128, 128], target_memory=pl.MemorySpace.Mat
+                )
+                # The alias carries no memory_space annotation — Phase 1 must
+                # copy Mat over from x_tile.
+                x_alias: pl.Tile[[16, 128], pl.BF16] = x_tile
+                z_tile: pl.Tile[[16, 128], pl.FP32] = pl.matmul(x_alias, y_tile)
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.store(z_tile, [0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.create_tensor([16, 128], dtype=pl.FP32)
+                return self.main_incore_0(x, y, out_0)
+
+        @pl.program
+        class Expected:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                x_tile: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Mat] = pl.load(
+                    x, [0, 0], [16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                y_tile: pl.Tile[[128, 128], pl.BF16, pl.MemorySpace.Mat] = pl.load(
+                    y, [0, 0], [128, 128], target_memory=pl.MemorySpace.Mat
+                )
+                x_alias: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Mat] = x_tile
+                x_alias_L: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Left] = pl.move(
+                    x_alias, target_memory=pl.MemorySpace.Left
+                )
+                y_tile_R: pl.Tile[[128, 128], pl.BF16, pl.MemorySpace.Right] = pl.move(
+                    y_tile, target_memory=pl.MemorySpace.Right
+                )
+                z_tile: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Acc] = pl.matmul(x_alias_L, y_tile_R)
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.store(z_tile, [0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.create_tensor([16, 128], dtype=pl.FP32)
+                return self.main_incore_0(x, y, out_0)
+
+        After = passes.infer_tile_memory_space()(Before)
+        ir.assert_structural_equal(After, Expected)
+
+    def test_ssa_alias_chain_feeds_matmul(self):
+        """`y = x`, `z = y`: both aliases inherit x's memory_space. Verifies
+        Phase 1 handles transitive SSA-alias chains in a single forward sweep."""
+
+        @pl.program
+        class Before:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                x_tile: pl.Tile[[16, 128], pl.BF16] = pl.load(
+                    x, [0, 0], [16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                alias_1: pl.Tile[[16, 128], pl.BF16] = x_tile
+                alias_2: pl.Tile[[16, 128], pl.BF16] = alias_1
+                y_tile: pl.Tile[[128, 128], pl.BF16] = pl.load(
+                    y, [0, 0], [128, 128], target_memory=pl.MemorySpace.Mat
+                )
+                z_tile: pl.Tile[[16, 128], pl.FP32] = pl.matmul(alias_2, y_tile)
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.store(z_tile, [0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.create_tensor([16, 128], dtype=pl.FP32)
+                return self.main_incore_0(x, y, out_0)
+
+        @pl.program
+        class Expected:
+            @pl.function(type=pl.FunctionType.InCore)
+            def main_incore_0(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+                out_0: pl.Out[pl.Tensor[[16, 128], pl.FP32]],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                x_tile: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Mat] = pl.load(
+                    x, [0, 0], [16, 128], target_memory=pl.MemorySpace.Mat
+                )
+                alias_1: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Mat] = x_tile
+                alias_2: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Mat] = alias_1
+                y_tile: pl.Tile[[128, 128], pl.BF16, pl.MemorySpace.Mat] = pl.load(
+                    y, [0, 0], [128, 128], target_memory=pl.MemorySpace.Mat
+                )
+                alias_2_L: pl.Tile[[16, 128], pl.BF16, pl.MemorySpace.Left] = pl.move(
+                    alias_2, target_memory=pl.MemorySpace.Left
+                )
+                y_tile_R: pl.Tile[[128, 128], pl.BF16, pl.MemorySpace.Right] = pl.move(
+                    y_tile, target_memory=pl.MemorySpace.Right
+                )
+                z_tile: pl.Tile[[16, 128], pl.FP32, pl.MemorySpace.Acc] = pl.matmul(alias_2_L, y_tile_R)
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.store(z_tile, [0, 0], out_0)
+                return out_0
+
+            @pl.function
+            def main(
+                self,
+                x: pl.Tensor[[16, 128], pl.BF16],
+                y: pl.Tensor[[128, 128], pl.BF16],
+            ) -> pl.Tensor[[16, 128], pl.FP32]:
+                out_0: pl.Tensor[[16, 128], pl.FP32] = pl.create_tensor([16, 128], dtype=pl.FP32)
+                return self.main_incore_0(x, y, out_0)
+
+        After = passes.infer_tile_memory_space()(Before)
+        ir.assert_structural_equal(After, Expected)
+
 
 if __name__ == "__main__":
     pytest.main([__file__, "-v"])