diff --git a/python/pypto/debug/torch_codegen.py b/python/pypto/debug/torch_codegen.py index 859cbd556..4db706c83 100644 --- a/python/pypto/debug/torch_codegen.py +++ b/python/pypto/debug/torch_codegen.py @@ -54,6 +54,35 @@ def _torch_dtype(dt: DataType) -> str: 5: ">=", # GE } + +def _sanitize_name_hint(hint: str) -> str: + """Convert an IR name hint into a valid Python identifier base.""" + base = hint or "v" + base = re.sub(r"[^a-zA-Z0-9_]", "_", base) + base = re.sub(r"__+", "_", base).strip("_") or "v" + if base[0].isdigit(): + base = f"v_{base}" + if keyword.iskeyword(base): + base = f"{base}_v" + return base + + +def _make_unique_names(hints: list[str]) -> list[str]: + """Make sanitized Python identifiers unique while preserving order.""" + unique_names: list[str] = [] + counts: dict[str, int] = {} + for hint in hints: + base = _sanitize_name_hint(hint) + count = counts.get(base, 0) + if count == 0: + unique_names.append(base) + counts[base] = 1 + else: + unique_names.append(f"{base}_{count}") + counts[base] = count + 1 + return unique_names + + # --------------------------------------------------------------------------- # Preamble inserted at top of every generated script # --------------------------------------------------------------------------- @@ -163,6 +192,185 @@ def _assemble(target, source, offsets): valid_slices = tuple(slice(0, s) for s in valid_shape) target[slices] = source[valid_slices] return target + +def _split_for_aiv_consumer(tensor, split_mode): + if split_mode == 0: + return [tensor.clone()] + elif split_mode == 1: # UpDown: split along rows (dim -2) + mid = tensor.shape[-2] // 2 + return [tensor[..., :mid, :].clone(), tensor[..., mid:, :].clone()] + elif split_mode == 2: # LeftRight: split along cols (dim -1) + mid = tensor.shape[-1] // 2 + return [tensor[..., :mid].clone(), tensor[..., mid:].clone()] + return [tensor.clone()] + +def _tpush_to_aiv(pipe, tensor, split_mode): + for chunk in _split_for_aiv_consumer(tensor, split_mode): + 
pipe.append(chunk) + return tensor + +def _tpush_to_aic(pipe, tensor, _split_mode): + # Outside the cooperative scheduler we do not model two AIV subblocks + # running in parallel, so the C2V push always carries the full tile and + # the split kwarg is informational. Scheduler-mode emission uses the + # ``_g`` variants which do honor split. + pipe.append(tensor.clone()) + return tensor + +def _tpop_from_aic(pipe, split_mode): + if split_mode == 0: + return pipe.popleft() + # split>0: ``_tpush_to_aiv`` queued two halves on the same deque; pop + # both and reassemble the full tile so legacy single-AIV-subblock code + # paths see the same end-to-end value as the scheduled path. + first = pipe.popleft() + second = pipe.popleft() + if split_mode == 1: # UpDown -> rows + return torch.cat([first, second], dim=-2) + if split_mode == 2: # LeftRight -> cols + return torch.cat([first, second], dim=-1) + return first + +def _tpop_from_aiv(pipe, _split_mode): + return pipe.popleft() + +# Per-subblock pipes (only used when split>0). In real hardware a split tile +# is delivered to two AIV subblocks (sb0 / sb1) which run the same kernel on +# their own half; their outputs are reassembled at the AIC pop side. These +# extra deques carry the per-subblock chunks so each AIV invocation in the +# scheduler picks up its own half, and AIC's tpop_from_aiv can ``cat`` the two +# halves back together. +_pipes_sb = { + 'to_aiv_sb0': deque(), 'to_aiv_sb1': deque(), + 'to_aic_sb0': deque(), 'to_aic_sb1': deque(), +} + +# Current subblock id (0 or 1). ``_run_scheduler`` writes this slot before +# resuming each task so the generator-mode pipe helpers and +# tile.get_subblock_idx codegen can read it via runtime context. 
+_current_sb = [0] + +def _reset_pipes(): + _pipes['to_aiv'].clear() + _pipes['to_aic'].clear() + for q in _pipes_sb.values(): + q.clear() + _current_sb[0] = 0 + +# --- Cooperative scheduler for cross-core simulation ----------------------- +# Used when a Group function couples AIC and AIV functions whose tpush/tpop +# operations cannot be modeled by simply running one side and then the other +# (e.g. bidirectional V<->C, producer/consumer feedback loops, or any pipe +# op carrying split>0 which needs two AIV subblocks running in parallel). +# Each AIC or AIV function is emitted as a Python generator that ``yield``s a +# ``_WaitPop`` / ``_WaitPush`` request at every pipe synchronization point. +# ``_run_scheduler`` advances each generator until it blocks, then switches +# to the next, mirroring the cooperative interleaving of the real cores. + +class _WaitPop: + __slots__ = ("pipe",) + def __init__(self, pipe): + self.pipe = pipe + +class _WaitPush: + __slots__ = ("pipe", "item") + def __init__(self, pipe, item): + self.pipe = pipe + self.item = item + +def _tpush_to_aiv_g(pipe, tensor, split_mode): + # split_mode == 0: single full-tile chunk on the unified to_aiv pipe. + # split_mode > 0: split the tile and route halves to per-subblock pipes + # so each AIV subblock picks up its own portion. + if split_mode == 0: + yield _WaitPush(pipe, tensor.clone()) + return tensor + chunks = _split_for_aiv_consumer(tensor, split_mode) + yield _WaitPush(_pipes_sb['to_aiv_sb0'], chunks[0]) + yield _WaitPush(_pipes_sb['to_aiv_sb1'], chunks[1]) + return tensor + +def _tpush_to_aic_g(pipe, tensor, split_mode): + # split_mode == 0: single chunk on the unified to_aic pipe. + # split_mode > 0: push each AIV subblock's contribution onto its own pipe; + # AIC's tpop_from_aiv reassembles them. 
+ if split_mode == 0: + yield _WaitPush(pipe, tensor.clone()) + return tensor + sb = _current_sb[0] + yield _WaitPush(_pipes_sb[f'to_aic_sb{sb}'], tensor.clone()) + return tensor + +def _tpop_from_aic_g(pipe, split_mode): + if split_mode == 0: + return (yield _WaitPop(pipe)) + sb = _current_sb[0] + return (yield _WaitPop(_pipes_sb[f'to_aiv_sb{sb}'])) + +def _tpop_from_aiv_g(pipe, split_mode): + if split_mode == 0: + return (yield _WaitPop(pipe)) + sb0 = yield _WaitPop(_pipes_sb['to_aic_sb0']) + sb1 = yield _WaitPop(_pipes_sb['to_aic_sb1']) + if split_mode == 1: # UpDown -> rows + return torch.cat([sb0, sb1], dim=-2) + if split_mode == 2: # LeftRight -> cols + return torch.cat([sb0, sb1], dim=-1) + return sb0 + +def _run_scheduler(tasks): + # Cooperative round-robin scheduler over generator-style AIC/AIV bodies. + # ``tasks`` is a list of ``(name, generator, subblock_id)`` tuples. Each + # generator yields _WaitPop / _WaitPush requests at pipe sync points. + # Pipes are unbounded deques so _WaitPush always succeeds; the + # interesting suspend point is _WaitPop on an empty pipe. The scheduler + # keeps cycling until every generator returns; if a full pass makes no + # progress and any task is still alive, it raises a deadlock error with + # the pending request kinds. Before resuming each task we set + # ``_current_sb[0]`` so the split-aware pipe helpers and + # tile.get_subblock_idx see the right id. 
+ states = [] + for name, gen, sb in tasks: + _current_sb[0] = sb + try: + req = next(gen) + states.append([name, gen, req, False, sb]) + except StopIteration: + states.append([name, gen, None, True, sb]) + while True: + progressed = False + all_done = True + for st in states: + if st[3]: + continue + all_done = False + req = st[2] + advance_value = None + advance = False + if isinstance(req, _WaitPush): + req.pipe.append(req.item) + advance = True + elif isinstance(req, _WaitPop): + if len(req.pipe) > 0: + advance_value = req.pipe.popleft() + advance = True + else: + # Defensive: unknown yield value treated as cooperative yield. + advance = True + if advance: + progressed = True + _current_sb[0] = st[4] + try: + st[2] = st[1].send(advance_value) + except StopIteration: + st[3] = True + if all_done: + return + if not progressed: + blocked = [(s[0], type(s[2]).__name__) for s in states if not s[3]] + raise RuntimeError( + "Cross-core simulation deadlock; tasks blocked: " + repr(blocked) + ) """ # --------------------------------------------------------------------------- @@ -306,6 +514,25 @@ def _handle_fillpad(a: list[str], kw: dict[str, Any]) -> str: # Build the dispatch table _OP_MAP: dict[str, OpHandler] = {} +# Cross-core scheduler-mode overrides: same op names, but emissions wrap into +# ``(yield from _..._g(...))`` so the enclosing AIC/AIV function becomes a +# Python generator that yields at every pipe sync point. Only populated for +# the four cross-core ops; everything else falls back to ``_OP_MAP``. 
+_OP_MAP_SCHED: dict[str, OpHandler] = { + "tile.tpush_to_aiv": ( + lambda a, kw: f"(yield from _tpush_to_aiv_g(_pipes['to_aiv'], {a[0]}, {kw.get('split', 0)}))" + ), + "tile.tpush_to_aic": ( + lambda a, kw: f"(yield from _tpush_to_aic_g(_pipes['to_aic'], {a[0]}, {kw.get('split', 0)}))" + ), + "tile.tpop_from_aic": ( + lambda _a, kw: f"(yield from _tpop_from_aic_g(_pipes['to_aiv'], {kw.get('split', 0)}))" + ), + "tile.tpop_from_aiv": ( + lambda _a, kw: f"(yield from _tpop_from_aiv_g(_pipes['to_aic'], {kw.get('split', 0)}))" + ), +} + def _register_ops() -> None: m = _OP_MAP @@ -393,6 +620,11 @@ def _register_ops() -> None: m["tile.read"] = lambda a, _kw: f"{a[0]}[{a[1]}]" m["tile.write"] = lambda a, _kw: f"_write_and_return({a[0]}, {a[1]}, {a[2]})" m["tile.get_block_idx"] = lambda _a, _kw: "0" + # tile.get_subblock_idx returns the active AIV subblock id at runtime so + # split-aware kernels can compute per-subblock offsets / slices. Outside + # a scheduled Group ``_current_sb[0]`` stays 0, matching the legacy + # single-subblock behavior for unidirectional / split=0 callers. 
+ m["tile.get_subblock_idx"] = lambda _a, _kw: "_current_sb[0]" # tile log / relu m["tile.log"] = _torch_fn("log") @@ -445,10 +677,10 @@ def _register_ops() -> None: m["tile.subsc"] = lambda a, _kw: f"({a[0]} - {a[1]} - {a[2]})" # --- Cross-core pipe ops --- - m["tile.tpush_to_aiv"] = lambda a, _kw: f"_pipes['to_aiv'].append({a[0]}.clone())" - m["tile.tpush_to_aic"] = lambda a, _kw: f"_pipes['to_aic'].append({a[0]}.clone())" - m["tile.tpop_from_aic"] = lambda _a, _kw: "_pipes['to_aic'].popleft()" - m["tile.tpop_from_aiv"] = lambda _a, _kw: "_pipes['to_aiv'].popleft()" + m["tile.tpush_to_aiv"] = lambda a, kw: f"_tpush_to_aiv(_pipes['to_aiv'], {a[0]}, {kw.get('split', 0)})" + m["tile.tpush_to_aic"] = lambda a, kw: f"_tpush_to_aic(_pipes['to_aic'], {a[0]}, {kw.get('split', 0)})" + m["tile.tpop_from_aic"] = lambda _a, kw: f"_tpop_from_aic(_pipes['to_aiv'], {kw.get('split', 0)})" + m["tile.tpop_from_aiv"] = lambda _a, kw: f"_tpop_from_aiv(_pipes['to_aic'], {kw.get('split', 0)})" # --- System ops (no-ops) --- for op_name in ( @@ -470,6 +702,61 @@ def _register_ops() -> None: _register_ops() +# --------------------------------------------------------------------------- +# Helpers for cross-core program simulation +# --------------------------------------------------------------------------- + + +def _extract_group_member_names(group_func: _ir.Function) -> list[str]: + """Extract AIC/AIV member function names from a Group function's body.""" + names: list[str] = [] + stmts = _ir.flatten_to_stmts(group_func.body) + for stmt in stmts: + call = None + if isinstance(stmt, _ir.EvalStmt): + call = stmt.expr + elif isinstance(stmt, _ir.AssignStmt): + call = stmt.value + if isinstance(call, _ir.Call) and isinstance(call.op, _ir.GlobalVar): + names.append(call.op.name) + return names + + +def _generate_entry_point(program: _ir.Program) -> str: + """Generate a ``run()`` entry-point wrapper for a Program. 
+ + Prefers an Orchestration function, falls back to the first Opaque + function, then Group, and returns an empty string if none exists. + """ + entry_func = None + for func in program.functions.values(): + if func.func_type == _ir.FunctionType.Orchestration: + entry_func = func + break + if entry_func is None: + for func in program.functions.values(): + if func.func_type == _ir.FunctionType.Opaque: + entry_func = func + break + if entry_func is None: + for func in program.functions.values(): + if func.func_type == _ir.FunctionType.Group: + entry_func = func + break + if entry_func is None: + return "" + # If the entry function itself is named ``run``, skip emitting a wrapper + # to avoid producing ``def run(...): return run(...)`` (infinite recursion). + if entry_func.name == "run": + return "" + param_names = _make_unique_names([p.name_hint for p in entry_func.params]) + return ( + f"# Entry point\n" + f"def run({', '.join(param_names)}):\n" + f" return {entry_func.name}({', '.join(param_names)})\n" + ) + + # --------------------------------------------------------------------------- # Binary / unary IR expression -> Python operator string # --------------------------------------------------------------------------- @@ -514,9 +801,19 @@ def __init__(self, *, check_shapes: bool = False) -> None: self._indent: int = 0 self._expr_result: str = "" self._var_names: dict[int, str] = {} # id(Var) -> unique name + self._stable_hints: dict[str, str] = {} # hint -> name for params/aliases only + self._var_refs: list[_ir.Var] = [] # prevent GC of Var wrappers self._name_counter: dict[str, int] = {} self._yield_targets: list[str] = [] # names to assign on yield self._check_shapes: bool = check_shapes + # Cross-core scheduler mode: function names that should be emitted as + # generators using ``_OP_MAP_SCHED`` for tpush/tpop ops. Populated by + # ``visit_program`` after pattern detection. 
+ self._sched_funcs: set[str] = set() + # AIV member functions that must be scheduled twice (one task per + # AIV subblock id) because they participate in split>0 transfers. + self._sched_aiv_dup: set[str] = set() + self._current_func_name: str = "" # -- helpers -- @@ -524,17 +821,7 @@ def _emit(self, line: str) -> None: self._lines.append(" " * self._indent + line) def _unique_name(self, hint: str) -> str: - base = hint or "v" - # Sanitize: replace non-identifier chars with underscore - base = re.sub(r"[^a-zA-Z0-9_]", "_", base) - # Collapse consecutive underscores - base = re.sub(r"__+", "_", base).strip("_") or "v" - # Ensure doesn't start with digit - if base[0].isdigit(): - base = f"v_{base}" - # Avoid Python keywords - if keyword.iskeyword(base): - base = f"{base}_v" + base = _sanitize_name_hint(hint) count = self._name_counter.get(base, 0) if count == 0: self._name_counter[base] = 1 @@ -545,7 +832,18 @@ def _unique_name(self, hint: str) -> str: def _name_of(self, var: _ir.Var) -> str: vid = id(var) if vid not in self._var_names: - self._var_names[vid] = self._unique_name(var.name_hint) + hint = var.name_hint + # Nanobind may create fresh Python wrappers for the same C++ + # Var, giving each a different id(). Fall back to stable hints + # (function params and loop aliases) where the mapping is + # unambiguous. Do NOT fall back for local SSA vars — different + # Vars can share a name_hint and must get unique names. 
+ if hint in self._stable_hints: + name = self._stable_hints[hint] + else: + name = self._unique_name(hint) + self._var_names[vid] = name + self._var_refs.append(var) # prevent GC id reuse return self._var_names[vid] def _visit_expr_str(self, expr: _ir.Expr) -> str: @@ -578,17 +876,285 @@ def _alias_return_vars(self, return_vars: list[_ir.Var], names: list[str]) -> No """Map return_vars to the same names as iter_args after a loop.""" for rv, name in zip(return_vars, names): self._var_names[id(rv)] = name + self._stable_hints[rv.name_hint] = name + self._var_refs.append(rv) # -- top-level -- + def _reset_var_scope(self) -> None: + """Reset per-function variable naming state. + + Each function in a Program gets its own naming scope so that + identically-named IR variables in different functions do not + collide. + """ + self._var_names.clear() + self._stable_hints.clear() + self._var_refs.clear() + self._name_counter.clear() + def visit_program(self, program: _ir.Program) -> None: + # Classify functions by type for dependency-ordered emission + aic_aiv_funcs: list[_ir.Function] = [] + group_funcs: list[_ir.Function] = [] + orch_funcs: list[_ir.Function] = [] + other_funcs: list[_ir.Function] = [] + for _gv, func in program.functions.items(): + ft = func.func_type + if ft in (_ir.FunctionType.AIC, _ir.FunctionType.AIV): + aic_aiv_funcs.append(func) + elif ft == _ir.FunctionType.Group: + group_funcs.append(func) + elif ft == _ir.FunctionType.Orchestration: + orch_funcs.append(func) + else: + other_funcs.append(func) + + # Cross-core scheduler detection: any Group whose AIC/AIV members + # together contain bidirectional pipe access (tpush_to_aiv AND + # tpush_to_aic), or any single-direction member that does both + # tpush and tpop on the same side, OR uses split>0 on any pipe op, + # must be emitted with the cooperative scheduler so that pipe sync + # points are honored and split tiles fan out to two AIV subblocks. 
+ funcs_by_name = {f.name: f for f in program.functions.values()} + scheduled_groups = self._detect_scheduled_groups(program, group_funcs) + for grp in scheduled_groups: + for member_name in _extract_group_member_names(grp): + self._sched_funcs.add(member_name) + member = funcs_by_name.get(member_name) + if member is None: + continue + # AIV members that touch the pipe with split>0 must run twice + # under the scheduler (once per AIV subblock 0/1) so both + # halves of a split tile are produced/consumed. + if member.func_type == _ir.FunctionType.AIV and self._function_uses_split(member): + self._sched_aiv_dup.add(member_name) + + # Emit in dependency order: AIC/AIV leaves → Opaque/InCore → Group → Orchestration + for func in aic_aiv_funcs: + self._reset_var_scope() + self.visit_function(func) + for func in other_funcs: + self._reset_var_scope() self.visit_function(func) + for func in group_funcs: + self._reset_var_scope() + if func in scheduled_groups: + self._visit_scheduled_group_function(func) + else: + self._visit_group_function(func, program) + for func in orch_funcs: + self._reset_var_scope() + self.visit_function(func) + + _PIPE_OP_NAMES = ( + "tile.tpush_to_aiv", + "tile.tpush_to_aic", + "tile.tpop_from_aic", + "tile.tpop_from_aiv", + ) + + @classmethod + def _walk_pipe_calls(cls, func: _ir.Function): + """Yield every cross-core pipe ``Call`` node found anywhere in ``func``.""" + targets = set(cls._PIPE_OP_NAMES) + + def walk(node): + if isinstance(node, _ir.Call): + name = node.op.name if hasattr(node.op, "name") else None + if name in targets: + yield node + for a in node.args: + yield from walk(a) + for v in (node.kwargs or {}).values(): + yield from walk(v) + return + for attr in ("body", "expr", "value", "args", "kwargs", "stmts"): + if not hasattr(node, attr): + continue + child = getattr(node, attr) + if isinstance(child, dict): + for v in child.values(): + yield from walk(v) + elif isinstance(child, (list, tuple)): + for v in child: + yield from 
walk(v) + elif child is not None: + yield from walk(child) + + yield from walk(func.body) + + @classmethod + def _scan_pipe_ops(cls, func: _ir.Function) -> set[str]: + """Return the set of cross-core pipe op names found anywhere in ``func``.""" + return {call.op.name for call in cls._walk_pipe_calls(func)} + + @classmethod + def _function_uses_split(cls, func: _ir.Function) -> bool: + """Return True if any pipe op in ``func`` carries split>0.""" + for call in cls._walk_pipe_calls(func): + kw = call.kwargs or {} + split = kw.get("split", 0) + if isinstance(split, int) and split > 0: + return True + return False + + def _detect_scheduled_groups( + self, program: _ir.Program, group_funcs: list[_ir.Function] + ) -> set[_ir.Function]: + """Identify Group functions that need the cooperative scheduler. + + Trigger conditions (any one is sufficient): + * Bidirectional: members collectively contain both ``tpush_to_aiv`` + and ``tpush_to_aic``. + * Same-side feedback: a single member contains both a ``tpush`` and + a ``tpop`` on the *same* pipe direction (would deadlock self). + * Split>0 on any pipe op: split tiles need both AIV subblocks to + run, which only the scheduler models correctly. + + Single-direction Groups with simple linear push/pop and split=0 + continue to use the legacy sequential emission path (zero behavior + change). 
+ """ + funcs_by_name = {f.name: f for f in program.functions.values()} + scheduled: set[_ir.Function] = set() + for grp in group_funcs: + members = [funcs_by_name[n] for n in _extract_group_member_names(grp) if n in funcs_by_name] + if not members: + continue + all_ops: set[str] = set() + same_side_feedback = False + uses_split = False + for m in members: + m_ops = self._scan_pipe_ops(m) + all_ops |= m_ops + # Same-side feedback within one member function: + if "tile.tpush_to_aiv" in m_ops and "tile.tpop_from_aic" in m_ops: + same_side_feedback = True + if "tile.tpush_to_aic" in m_ops and "tile.tpop_from_aiv" in m_ops: + same_side_feedback = True + if self._function_uses_split(m): + uses_split = True + bidirectional = "tile.tpush_to_aiv" in all_ops and "tile.tpush_to_aic" in all_ops + if bidirectional or same_side_feedback or uses_split: + scheduled.add(grp) + return scheduled + + def _visit_scheduled_group_function(self, func: _ir.Function) -> None: + """Emit a Group function in cooperative-scheduler form. + + Each AIC/AIV member call inside the Group body is converted into a + ``(name, generator)`` entry handed to ``_run_scheduler``. All other + statements (asserts, tensor wrap-up, etc.) are re-emitted verbatim + through the normal visitor so non-call setup still runs in order. + """ + params = [self._name_of(p) for p in func.params] + self._register_param_hints(func.params) + self._emit(f"def {func.name}({', '.join(params)}):") + self._indent += 1 + member_names = _extract_group_member_names(func) + if member_names: + self._emit(f"# Group (scheduled): {', '.join(member_names)}") + if self._check_shapes: + for p in func.params: + self._emit_shape_dtype_check(self._name_of(p), p.type, shape=False) + self._emit("_reset_pipes()") + + # Walk the (flattened) Group body and split call statements out into + # the scheduler tasks list. Anything else is emitted in place. 
+ stmts = _ir.flatten_to_stmts(func.body) + task_lines: list[str] = [] + for stmt in stmts: + call = None + if isinstance(stmt, _ir.EvalStmt): + call = stmt.expr + elif isinstance(stmt, _ir.AssignStmt): + call = stmt.value + if ( + isinstance(call, _ir.Call) + and isinstance(call.op, _ir.GlobalVar) + and call.op.name in self._sched_funcs + ): + arg_strs = [self._visit_expr_str(a) for a in call.args] + fname = call.op.name + joined = ", ".join(arg_strs) + if fname in self._sched_aiv_dup: + # Schedule the AIV body twice, once per subblock id. + for sb in (0, 1): + task_lines.append(f' ("{fname}#{sb}", {fname}({joined}), {sb}),') + else: + task_lines.append(f' ("{fname}", {fname}({joined}), 0),') + else: + # Fall back to default per-statement emission for anything + # that is not an AIC/AIV member call (rare in well-formed + # Group bodies but kept for safety). + self.visit_stmt(stmt) + if task_lines: + self._emit("_run_scheduler([") + for line in task_lines: + self._lines.append(line) + self._emit("])") + else: + self._emit("pass") + self._indent -= 1 + self._emit("") + + def _register_param_hints(self, params) -> None: + """Register param ``name_hint`` -> emitted name in ``_stable_hints``. + + ``name_hint`` collisions are dropped (not registered) so that fresh + nanobind wrappers for ambiguously-named params fall back to the + counter-based uniquing path in ``_unique_name`` instead of silently + resolving to the wrong param. 
+ """ + seen: dict[str, str] = {} + ambiguous: set[str] = set() + for p in params: + hint = p.name_hint + name = self._var_names[id(p)] + if hint in ambiguous: + continue + if hint in seen and seen[hint] != name: + ambiguous.add(hint) + self._stable_hints.pop(hint, None) + continue + seen[hint] = name + self._stable_hints[hint] = name + + def _visit_group_function(self, func: _ir.Function, _program: _ir.Program) -> None: + """Generate a Group function that calls its AIC+AIV members sequentially.""" + params = [self._name_of(p) for p in func.params] + # Register param hints as stable so nanobind wrapper GC doesn't + # break references to these vars inside the function body. + self._register_param_hints(func.params) + self._emit(f"def {func.name}({', '.join(params)}):") + self._indent += 1 + + member_names = _extract_group_member_names(func) + if member_names: + self._emit(f"# Group: {', '.join(member_names)}") + + if self._check_shapes: + for p in func.params: + self._emit_shape_dtype_check(self._name_of(p), p.type, shape=False) + + n_before = len(self._lines) + self.visit_stmt(func.body) + if len(self._lines) == n_before: + self._emit("pass") + self._indent -= 1 + self._emit("") def visit_function(self, func: _ir.Function) -> None: params = [self._name_of(p) for p in func.params] + # Register param hints as stable so nanobind wrapper GC doesn't + # break references to these vars inside the function body. 
+ self._register_param_hints(func.params) self._emit(f"def {func.name}({', '.join(params)}):") self._indent += 1 + prev_name = self._current_func_name + self._current_func_name = func.name if self._check_shapes: for p in func.params: # InCore kernel params may receive partial data (boundary tiles), @@ -599,6 +1165,7 @@ def visit_function(self, func: _ir.Function) -> None: if len(self._lines) == n_before: self._emit("pass") self._indent -= 1 + self._current_func_name = prev_name self._emit("") # -- expression visitors -- @@ -657,7 +1224,10 @@ def visit_unary_expr(self, op: _ir.UnaryExpr) -> None: def visit_call(self, op: _ir.Call) -> None: op_name = op.op.name - handler = _OP_MAP.get(op_name) + if self._current_func_name in self._sched_funcs and op_name in _OP_MAP_SCHED: + handler = _OP_MAP_SCHED[op_name] + else: + handler = _OP_MAP.get(op_name) # Evaluate arguments arg_strs = [self._visit_expr_str(a) for a in op.args] @@ -892,10 +1462,14 @@ def torch_codegen(node: _ir.Program | _ir.Function, *, check_shapes: bool = Fals if isinstance(node, _ir.Program): cg.visit_program(node) + lines.append(cg.get_output()) + entry = _generate_entry_point(node) + if entry: + lines.append(entry) elif isinstance(node, _ir.Function): cg.visit_function(node) + lines.append(cg.get_output()) else: raise TypeError(f"torch_codegen expects Program or Function, got {type(node).__name__}") - lines.append(cg.get_output()) return "\n".join(lines) diff --git a/tests/st/codegen/test_torch_codegen_cross_core.py b/tests/st/codegen/test_torch_codegen_cross_core.py new file mode 100644 index 000000000..f96c6e564 --- /dev/null +++ b/tests/st/codegen/test_torch_codegen_cross_core.py @@ -0,0 +1,475 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. 
You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""System tests for torch_codegen on hand-built cross-core (tpush/tpop) IR. + +Constructs minimal AIC + AIV + Group programs that directly use +``tile.tpush_to_aiv`` / ``tile.tpop_from_aic`` (V2C direction) and +``tile.tpush_to_aic`` / ``tile.tpop_from_aiv`` (C2V direction) for all three +split modes (NoSplit / UpDown / LeftRight), then compares torch_codegen output +against a hand-written torch golden. + +Why hand-built IR rather than ``pl.SplitMode.*`` lowering: + +* The ``tests/st/runtime/test_cross_core.py`` ST goes through the full lowering + pipeline (``pl.chunked_loop_optimizer`` → mixed-kernel expansion → tpush/tpop) + and validates the *end-to-end sim* output. It exercises tpush/tpop indirectly + but never compares torch_codegen against a reference per-tile. +* This test pins the torch_codegen contract for the four cross-core ops + (split=0/1/2 on each of the four ops) at the level the codegen actually + emits. Failures here point at the codegen helpers / split semantics + directly instead of at the lowering pipeline. 
+""" + +from collections import deque + +import pytest +import torch +from pypto import DataType, ir +from pypto.debug import torch_codegen + +# --------------------------------------------------------------------------- +# IR builders +# --------------------------------------------------------------------------- + + +def _span(): + return ir.Span.unknown() + + +def _tile(name: str, shape: list[int]) -> ir.Var: + return ir.Var(name, ir.TileType(shape, DataType.FP32), _span()) + + +def _tensor(name: str, shape: list[int]) -> ir.Var: + return ir.Var(name, ir.TensorType(shape, DataType.FP32), _span()) + + +def _int(val: int) -> ir.ConstInt: + return ir.ConstInt(val, DataType.INT64, _span()) + + +def _aiv_subblock_shape(full_shape: list[int], split: int) -> list[int]: + """Subblock shape that each AIV subblock processes after a tpop_from_aic.""" + if split == 1: # UpDown → row half + return [full_shape[0] // 2, full_shape[1]] + if split == 2: # LeftRight → column half + return [full_shape[0], full_shape[1] // 2] + return list(full_shape) + + +def _build_v2c_program(split: int): + """Build V2C: AIC matmul → tpush_to_aiv → AIV tpop + add residual → store. + + For ``split == 0`` the AIV consumer receives the full tile, adds residual, + and stores it. + + For ``split > 0`` the program runs in scheduler mode: AIC pushes the + matmul result split, both AIV subblocks pop their half and add the + matching half of residual (sliced via ``tile.get_subblock_idx``), then + push the half-result back to AIC. AIC pops with the same split (which + reassembles via ``cat``) and stores the full tile, so the final output + equals ``matmul(a, b) + residual`` regardless of split. + """ + full_shape = [4, 4] + + if split == 0: + # Legacy unidirectional path (no scheduler). 
+ a = _tile("a", full_shape) + b = _tile("b", full_shape) + result = _tile("result", full_shape) + matmul_call = ir.create_op_call("tile.matmul", [a, b], _span()) + push_call = ir.create_op_call("tile.tpush_to_aiv", [result], {"split": split}, _span()) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(result, matmul_call, _span()), + ir.EvalStmt(push_call, _span()), + ], + _span(), + ) + aic_func = ir.Function("aic_matmul", [a, b], [], aic_body, _span(), type=ir.FunctionType.AIC) + + residual = _tile("residual", full_shape) + out = _tensor("out", full_shape) + pop_call = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": 0}, + ir.TileType(full_shape, DataType.FP32), + _span(), + ) + add_result = _tile("add_result", full_shape) + add_call = ir.create_op_call("tile.add", [pop_call, residual], _span()) + offsets = ir.MakeTuple([_int(0), _int(0)], _span()) + store_call = ir.create_op_call("tile.store", [add_result, offsets, out], _span()) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(add_result, add_call, _span()), + ir.EvalStmt(store_call, _span()), + ], + _span(), + ) + aiv_func = ir.Function( + "aiv_add_residual", + [residual, out], + [], + aiv_body, + _span(), + type=ir.FunctionType.AIV, + ) + + aic_call = ir.Call(ir.GlobalVar("aic_matmul"), [a, b], _span()) + aiv_call = ir.Call(ir.GlobalVar("aiv_add_residual"), [residual, out], _span()) + group_body = ir.SeqStmts( + [ir.EvalStmt(aic_call, _span()), ir.EvalStmt(aiv_call, _span())], + _span(), + ) + group_func = ir.Function( + "v2c_matmul_add_group", + [a, b, residual, out], + [], + group_body, + _span(), + type=ir.FunctionType.Group, + ) + program = ir.Program([aic_func, aiv_func, group_func], "v2c_cross_core_test", _span()) + + specs = [("a", full_shape), ("b", full_shape), ("residual", full_shape), ("out", full_shape)] + + def golden(tensors: dict[str, torch.Tensor]) -> None: + tensors["out"][:] = torch.matmul(tensors["a"], tensors["b"]) + tensors["residual"] + + return program, specs, golden + + # split > 
0: bidirectional roundtrip with reassembly at AIC. + half_shape = _aiv_subblock_shape(full_shape, split) + + # AIC: matmul -> push split -> pop reassembled -> store full. + a = _tile("a", full_shape) + b = _tile("b", full_shape) + out = _tensor("out", full_shape) + matmul_call = ir.create_op_call("tile.matmul", [a, b], _span()) + result = _tile("result", full_shape) + push_call = ir.create_op_call("tile.tpush_to_aiv", [result], {"split": split}, _span()) + pop_back_call = ir.Call( + ir.get_op("tile.tpop_from_aiv"), + [], + {"split": split}, + ir.TileType(full_shape, DataType.FP32), + _span(), + ) + reassembled = _tile("reassembled", full_shape) + offsets = ir.MakeTuple([_int(0), _int(0)], _span()) + store_call = ir.create_op_call("tile.store", [reassembled, offsets, out], _span()) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(result, matmul_call, _span()), + ir.EvalStmt(push_call, _span()), + ir.AssignStmt(reassembled, pop_back_call, _span()), + ir.EvalStmt(store_call, _span()), + ], + _span(), + ) + aic_func = ir.Function("aic_matmul", [a, b, out], [], aic_body, _span(), type=ir.FunctionType.AIC) + + # AIV (runs once per subblock): pop half, add matching half of residual, + # push back. Per-subblock offset uses tile.get_subblock_idx. 
+ if split == 1: + offset_dim0 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], _span()), + _int(half_shape[0]), + DataType.INT64, + _span(), + ) + slice_offsets = ir.MakeTuple([offset_dim0, _int(0)], _span()) + else: + offset_dim1 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], _span()), + _int(half_shape[1]), + DataType.INT64, + _span(), + ) + slice_offsets = ir.MakeTuple([_int(0), offset_dim1], _span()) + + residual_full = _tile("residual", full_shape) + pop_call = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": split}, + ir.TileType(half_shape, DataType.FP32), + _span(), + ) + popped_half = _tile("popped_half", half_shape) + half_shape_tuple = ir.MakeTuple([_int(d) for d in half_shape], _span()) + residual_half_call = ir.create_op_call( + "tile.slice", [residual_full, half_shape_tuple, slice_offsets], _span() + ) + residual_half = _tile("residual_half", half_shape) + add_half_call = ir.create_op_call("tile.add", [popped_half, residual_half], _span()) + add_half = _tile("add_half", half_shape) + push_back_call = ir.create_op_call("tile.tpush_to_aic", [add_half], {"split": split}, _span()) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(popped_half, pop_call, _span()), + ir.AssignStmt(residual_half, residual_half_call, _span()), + ir.AssignStmt(add_half, add_half_call, _span()), + ir.EvalStmt(push_back_call, _span()), + ], + _span(), + ) + aiv_func = ir.Function( + "aiv_add_residual", + [residual_full], + [], + aiv_body, + _span(), + type=ir.FunctionType.AIV, + ) + + aic_call = ir.Call(ir.GlobalVar("aic_matmul"), [a, b, out], _span()) + aiv_call = ir.Call(ir.GlobalVar("aiv_add_residual"), [residual_full], _span()) + group_body = ir.SeqStmts( + [ir.EvalStmt(aic_call, _span()), ir.EvalStmt(aiv_call, _span())], + _span(), + ) + group_func = ir.Function( + "v2c_matmul_add_group", + [a, b, residual_full, out], + [], + group_body, + _span(), + type=ir.FunctionType.Group, + ) + program = ir.Program([aic_func, aiv_func, group_func], 
"v2c_cross_core_test", _span()) + + specs = [("a", full_shape), ("b", full_shape), ("residual", full_shape), ("out", full_shape)] + + def golden(tensors: dict[str, torch.Tensor]) -> None: + tensors["out"][:] = torch.matmul(tensors["a"], tensors["b"]) + tensors["residual"] + + return program, specs, golden + + +def _build_c2v_program(split: int): + """Build C2V: AIV tpush_to_aic(a) → AIC tpop_from_aiv → matmul(b) → store. + + For ``split == 0`` the AIV pushes the full tile and AIC pops the full tile. + + For ``split > 0`` each AIV subblock pushes only its own slice of ``a`` + (selected via ``tile.get_subblock_idx`` and ``tile.slice``); AIC's + ``tpop_from_aiv`` with the same split reassembles the two halves into a + full tile via ``cat``. The matmul output therefore equals + ``matmul(a, b)`` for every split mode. + """ + full_shape = [4, 4] + + if split == 0: + a = _tile("a", full_shape) + push_call = ir.create_op_call("tile.tpush_to_aic", [a], {"split": 0}, _span()) + aiv_body = ir.SeqStmts([ir.EvalStmt(push_call, _span())], _span()) + aiv_func = ir.Function("aiv_push_a", [a], [], aiv_body, _span(), type=ir.FunctionType.AIV) + else: + # AIV (per-subblock): slice ``a`` along the split axis at offset + # subblock_idx * half_size, push that half to AIC. 
+ half_shape = _aiv_subblock_shape(full_shape, split) + a = _tile("a", full_shape) + if split == 1: + offset_dim0 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], _span()), + _int(half_shape[0]), + DataType.INT64, + _span(), + ) + slice_offsets = ir.MakeTuple([offset_dim0, _int(0)], _span()) + else: + offset_dim1 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], _span()), + _int(half_shape[1]), + DataType.INT64, + _span(), + ) + slice_offsets = ir.MakeTuple([_int(0), offset_dim1], _span()) + half_shape_tuple = ir.MakeTuple([_int(d) for d in half_shape], _span()) + slice_call = ir.create_op_call("tile.slice", [a, half_shape_tuple, slice_offsets], _span()) + a_half = _tile("a_half", half_shape) + push_call = ir.create_op_call("tile.tpush_to_aic", [a_half], {"split": split}, _span()) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(a_half, slice_call, _span()), + ir.EvalStmt(push_call, _span()), + ], + _span(), + ) + aiv_func = ir.Function("aiv_push_a", [a], [], aiv_body, _span(), type=ir.FunctionType.AIV) + + # AIC: pop full tile (split-aware reassembly inside the helper); matmul; store. 
+ b = _tile("b", full_shape) + out = _tensor("out", full_shape) + pop_call = ir.Call( + ir.get_op("tile.tpop_from_aiv"), + [], + {"split": split}, + ir.TileType(full_shape, DataType.FP32), + _span(), + ) + mm = _tile("mm", full_shape) + mm_call = ir.create_op_call("tile.matmul", [pop_call, b], _span()) + offsets = ir.MakeTuple([_int(0), _int(0)], _span()) + store_call = ir.create_op_call("tile.store", [mm, offsets, out], _span()) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(mm, mm_call, _span()), + ir.EvalStmt(store_call, _span()), + ], + _span(), + ) + aic_func = ir.Function( + "aic_matmul_b", + [b, out], + [], + aic_body, + _span(), + type=ir.FunctionType.AIC, + ) + + aiv_call = ir.Call(ir.GlobalVar("aiv_push_a"), [a], _span()) + aic_call = ir.Call(ir.GlobalVar("aic_matmul_b"), [b, out], _span()) + group_body = ir.SeqStmts( + [ir.EvalStmt(aiv_call, _span()), ir.EvalStmt(aic_call, _span())], + _span(), + ) + group_func = ir.Function( + "c2v_push_matmul_group", + [a, b, out], + [], + group_body, + _span(), + type=ir.FunctionType.Group, + ) + + program = ir.Program([aiv_func, aic_func, group_func], "c2v_cross_core_test", _span()) + + specs = [("a", full_shape), ("b", full_shape), ("out", full_shape)] + + def golden(tensors: dict[str, torch.Tensor]) -> None: + tensors["out"][:] = torch.matmul(tensors["a"], tensors["b"]) + + return program, specs, golden + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _build_inputs(specs: list[tuple[str, list[int]]], seed: int = 42) -> dict[str, torch.Tensor]: + torch.manual_seed(seed) + out: dict[str, torch.Tensor] = {} + for name, shape in specs: + if name == "out": + out[name] = torch.zeros(shape, dtype=torch.float32) + else: + out[name] = torch.randn(shape, dtype=torch.float32) + return out + + +def _run_codegen(program, specs, seed=42) -> dict[str, torch.Tensor]: + code = torch_codegen(program, 
check_shapes=False) + tensors = _build_inputs(specs, seed=seed) + ns: dict = {} + exec(code, ns) # noqa: S102 + # Reset pipe state per test (the generated code creates module-level deques). + ns["_pipes"] = {"to_aiv": deque(), "to_aic": deque()} + args = [tensors[name] for name, _ in specs] + assert "run" in ns and callable(ns["run"]), "torch_codegen entry point `run(...)` was not generated" + ns["run"](*args) + return tensors + + +def _run_golden(specs, golden, seed=42) -> dict[str, torch.Tensor]: + tensors = _build_inputs(specs, seed=seed) + golden(tensors) + return tensors + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize( + ("split", "label"), + [(0, "nosplit"), (1, "updown"), (2, "leftright")], +) +def test_v2c_tpush_tpop_codegen_vs_golden(split, label): + """V2C (tpush_to_aiv / tpop_from_aic) torch_codegen ≈ torch golden.""" + program, specs, golden = _build_v2c_program(split) + + code = torch_codegen(program, check_shapes=False) + # Pin codegen contract for the V2C ops at this split. + if split == 0: + # Legacy unidirectional path: direct helper calls. + assert "_tpush_to_aiv(_pipes['to_aiv'], result, 0)" in code, ( + f"[{label}] tpush_to_aiv split kwarg not forwarded" + ) + assert "_tpop_from_aic(_pipes['to_aiv'], 0)" in code, ( + f"[{label}] tpop_from_aic split kwarg not forwarded" + ) + else: + # Bidirectional split>0 path: scheduler-mode generator wrappers. 
+ assert "_run_scheduler([" in code, f"[{label}] expected scheduler emission for split>0" + assert f"_tpush_to_aiv_g(_pipes['to_aiv'], result, {split})" in code + assert f"_tpop_from_aic_g(_pipes['to_aiv'], {split})" in code + assert f"_tpop_from_aiv_g(_pipes['to_aic'], {split})" in code + + cg = _run_codegen(program, specs) + gd = _run_golden(specs, golden) + diff = (cg["out"] - gd["out"]).abs() + assert torch.allclose(cg["out"], gd["out"], rtol=1e-5, atol=1e-5), ( + f"[{label}] V2C codegen vs golden max abs diff = {diff.max().item():.3e}" + ) + + +@pytest.mark.parametrize( + ("split", "label"), + [(0, "nosplit"), (1, "updown"), (2, "leftright")], +) +def test_c2v_tpush_tpop_codegen_vs_golden(split, label): + """C2V (tpush_to_aic / tpop_from_aiv) torch_codegen ≈ torch golden. + + For ``split == 0`` AIV pushes the full tile and AIC pops the full tile. + For ``split > 0`` each AIV subblock pushes its own slice of ``a`` and + AIC's ``tpop_from_aiv`` reassembles via ``cat``; the matmul output + therefore equals ``matmul(a, b)`` for every split mode. 
+ """ + program, specs, golden = _build_c2v_program(split) + + code = torch_codegen(program, check_shapes=False) + if split == 0: + assert "_tpush_to_aic(_pipes['to_aic'], a, 0)" in code, ( + f"[{label}] tpush_to_aic split kwarg not forwarded" + ) + assert "_tpop_from_aiv(_pipes['to_aic'], 0)" in code, ( + f"[{label}] tpop_from_aiv split kwarg not forwarded" + ) + else: + assert "_run_scheduler([" in code, f"[{label}] expected scheduler emission for split>0" + assert f"_tpush_to_aic_g(_pipes['to_aic'], a_half, {split})" in code + assert f"_tpop_from_aiv_g(_pipes['to_aic'], {split})" in code + + cg = _run_codegen(program, specs) + gd = _run_golden(specs, golden) + diff = (cg["out"] - gd["out"]).abs() + assert torch.allclose(cg["out"], gd["out"], rtol=1e-5, atol=1e-5), ( + f"[{label}] C2V codegen vs golden max abs diff = {diff.max().item():.3e}" + ) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/st/codegen/test_torch_codegen_scheduler.py b/tests/st/codegen/test_torch_codegen_scheduler.py new file mode 100644 index 000000000..acbc9f5cd --- /dev/null +++ b/tests/st/codegen/test_torch_codegen_scheduler.py @@ -0,0 +1,417 @@ +# Copyright (c) PyPTO Contributors. +# This program is free software, you can redistribute it and/or modify it under the terms and conditions of +# CANN Open Software License Agreement Version 2.0 (the "License"). +# Please refer to the License for details. You may not use this file except in compliance with the License. +# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, +# INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. +# See LICENSE in the root of the software repository for the full text of the License. +# ----------------------------------------------------------------------------------------------------------- + +"""System tests for torch_codegen cooperative scheduler on bidirectional / loop+sync IR. 
+ +Background +---------- +The original ``torch_codegen`` cross-core support modeled a pipe as a shared +``deque`` and emitted Group functions as ``aic_fn(); aiv_fn();`` (one side +fully runs before the other). That works for trivial unidirectional patterns +where pop never precedes its matching push, but it deadlocks (``IndexError`` +on empty deque) the moment the IR contains: + + * **bidirectional** flow (V→C and C→V on the same Group), or + * **same-side feedback** (a single function does both ``tpush`` and + ``tpop`` on the same pipe direction). + +This file pins the cooperative-scheduler path that handles those cases: + + 1. ``visit_program`` detects when a Group needs the scheduler. + 2. AIC/AIV members of such Groups are emitted as Python *generators* that + ``yield`` ``_WaitPop`` / ``_WaitPush`` at each pipe sync point. + 3. The Group function emits ``_run_scheduler([(name, gen, sb), ...])`` which + round-robin-advances each generator, suspending it on empty pipes and + resuming it when its peer pushes data. + +These tests construct hand-built IR (no compiler lowering involved) so that +failures point directly at the codegen / scheduler layer. 
+""" + +from collections import deque + +import pytest +import torch +from pypto import DataType, ir +from pypto.debug import torch_codegen + +# --------------------------------------------------------------------------- +# IR construction helpers +# --------------------------------------------------------------------------- + + +def _span(): + return ir.Span.unknown() + + +def _tile(name: str, shape: list[int]) -> ir.Var: + return ir.Var(name, ir.TileType(shape, DataType.FP32), _span()) + + +def _tensor(name: str, shape: list[int]) -> ir.Var: + return ir.Var(name, ir.TensorType(shape, DataType.FP32), _span()) + + +def _i(val: int) -> ir.ConstInt: + return ir.ConstInt(val, DataType.INT64, _span()) + + +# --------------------------------------------------------------------------- +# Bidirectional: AIV pushes A → AIC pops A, matmul(A, B), pushes result +# back → AIV pops result, adds residual, stores. +# +# Failure mode under the legacy emitter: +# Group runs aiv_fn() to completion first. After aiv_fn pushes A it +# immediately tries to pop from the to_aiv pipe (the result coming back +# from AIC), but AIC has not run yet → ``IndexError: pop from empty deque``. 
+# --------------------------------------------------------------------------- + + +def _build_bidirectional_program(): + shape = [4, 4] + + # AIV: push A, pop result, store(result + residual) + a = _tile("a", shape) + residual = _tile("residual", shape) + out = _tensor("out", shape) + push_a = ir.create_op_call("tile.tpush_to_aic", [a], {"split": 0}, _span()) + pop_result = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + add_r = _tile("add_r", shape) + add_call = ir.create_op_call("tile.add", [pop_result, residual], _span()) + offsets = ir.MakeTuple([_i(0), _i(0)], _span()) + store_call = ir.create_op_call("tile.store", [add_r, offsets, out], _span()) + aiv_body = ir.SeqStmts( + [ + ir.EvalStmt(push_a, _span()), + ir.AssignStmt(add_r, add_call, _span()), + ir.EvalStmt(store_call, _span()), + ], + _span(), + ) + aiv_func = ir.Function("aiv_fn", [a, residual, out], [], aiv_body, _span(), type=ir.FunctionType.AIV) + + # AIC: pop A, matmul(A, B), push result + b = _tile("b", shape) + pop_a = ir.Call( + ir.get_op("tile.tpop_from_aiv"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + mm = _tile("mm", shape) + mm_call = ir.create_op_call("tile.matmul", [pop_a, b], _span()) + push_mm = ir.create_op_call("tile.tpush_to_aiv", [mm], {"split": 0}, _span()) + aic_body = ir.SeqStmts([ir.AssignStmt(mm, mm_call, _span()), ir.EvalStmt(push_mm, _span())], _span()) + aic_func = ir.Function("aic_fn", [b], [], aic_body, _span(), type=ir.FunctionType.AIC) + + aiv_call = ir.Call(ir.GlobalVar("aiv_fn"), [a, residual, out], _span()) + aic_call = ir.Call(ir.GlobalVar("aic_fn"), [b], _span()) + group_body = ir.SeqStmts([ir.EvalStmt(aiv_call, _span()), ir.EvalStmt(aic_call, _span())], _span()) + group_func = ir.Function( + "bidir_grp", + [a, b, residual, out], + [], + group_body, + _span(), + type=ir.FunctionType.Group, + ) + + program = ir.Program([aiv_func, aic_func, group_func], 
"bidir_test", _span()) + specs = [ + ("a", shape), + ("b", shape), + ("residual", shape), + ("out", shape), + ] + + def golden(tensors): + tensors["out"][:] = torch.matmul(tensors["a"], tensors["b"]) + tensors["residual"] + + return program, specs, golden + + +# --------------------------------------------------------------------------- +# Same-side feedback: a single AIV function pushes to AIC then pops back from +# AIC on the same pipe direction. This forces same-side feedback detection +# (one function uses both tpush_to_aic and tpop_from_aiv) and exercises the +# scheduler under a single-task-pair feedback pattern. +# --------------------------------------------------------------------------- + + +def _build_same_side_feedback_program(): + """AIV does: push(A) -> AIC pops, doubles, pushes -> AIV pops, stores.""" + shape = [4, 4] + + a = _tile("a", shape) + out = _tensor("out", shape) + push_a = ir.create_op_call("tile.tpush_to_aic", [a], {"split": 0}, _span()) + pop_back = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + doubled = _tile("doubled", shape) + bind_doubled = ir.AssignStmt(doubled, pop_back, _span()) + offsets = ir.MakeTuple([_i(0), _i(0)], _span()) + store_call = ir.create_op_call("tile.store", [doubled, offsets, out], _span()) + aiv_body = ir.SeqStmts( + [ + ir.EvalStmt(push_a, _span()), + bind_doubled, + ir.EvalStmt(store_call, _span()), + ], + _span(), + ) + aiv_func = ir.Function("aiv_fb", [a, out], [], aiv_body, _span(), type=ir.FunctionType.AIV) + + pop_a = ir.Call( + ir.get_op("tile.tpop_from_aiv"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + captured = _tile("captured", shape) + twice = _tile("twice", shape) + twice_call = ir.create_op_call("tile.add", [captured, captured], _span()) + push_twice = ir.create_op_call("tile.tpush_to_aiv", [twice], {"split": 0}, _span()) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(captured, pop_a, _span()), 
+ ir.AssignStmt(twice, twice_call, _span()), + ir.EvalStmt(push_twice, _span()), + ], + _span(), + ) + aic_func = ir.Function("aic_fb", [], [], aic_body, _span(), type=ir.FunctionType.AIC) + + aiv_call = ir.Call(ir.GlobalVar("aiv_fb"), [a, out], _span()) + aic_call = ir.Call(ir.GlobalVar("aic_fb"), [], _span()) + group_body = ir.SeqStmts([ir.EvalStmt(aiv_call, _span()), ir.EvalStmt(aic_call, _span())], _span()) + group_func = ir.Function( + "fb_grp", + [a, out], + [], + group_body, + _span(), + type=ir.FunctionType.Group, + ) + + program = ir.Program([aiv_func, aic_func, group_func], "feedback_test", _span()) + specs = [("a", shape), ("out", shape)] + + def golden(tensors): + tensors["out"][:] = tensors["a"] + tensors["a"] + + return program, specs, golden + + +# --------------------------------------------------------------------------- +# Test runner helpers +# --------------------------------------------------------------------------- + + +def _build_inputs(specs, seed=42): + torch.manual_seed(seed) + out: dict = {} + for name, shape in specs: + if name == "out": + out[name] = torch.zeros(shape, dtype=torch.float32) + else: + out[name] = torch.randn(shape, dtype=torch.float32) + return out + + +def _run_codegen(program, specs, seed=42): + code = torch_codegen(program, check_shapes=False) + tensors = _build_inputs(specs, seed=seed) + ns: dict = {} + exec(code, ns) # noqa: S102 + # Reset pipes (the generated code uses module-level deques). 
+ ns["_pipes"] = {"to_aiv": deque(), "to_aic": deque()} + args = [tensors[name] for name, _ in specs] + assert "run" in ns and callable(ns["run"]), "torch_codegen entry point missing" + ns["run"](*args) + return tensors, code + + +def _run_golden(specs, golden, seed=42): + tensors = _build_inputs(specs, seed=seed) + golden(tensors) + return tensors + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +def test_bidirectional_emits_scheduler_and_matches_golden(): + """V↔C bidirectional Group must use the cooperative scheduler. + + Pins both the codegen contract (scheduler markers present, generator-style + helpers used) and the runtime numerical correctness against a torch + golden. Without the scheduler this case raises IndexError. + """ + program, specs, golden = _build_bidirectional_program() + tensors, code = _run_codegen(program, specs) + + # Codegen contract: scheduler-form markers must appear. + assert "_run_scheduler([" in code, "scheduler call not emitted for bidirectional Group" + assert "_reset_pipes()" in code, "pipe reset not emitted in scheduled Group" + assert "yield from _tpush_to_aic_g" in code, "AIV tpush_to_aic not in generator form" + assert "yield from _tpop_from_aiv_g" in code, "AIC tpop_from_aiv not in generator form" + assert "yield from _tpush_to_aiv_g" in code, "AIC tpush_to_aiv not in generator form" + assert "yield from _tpop_from_aic_g" in code, "AIV tpop_from_aic not in generator form" + + # Numerical correctness vs torch golden. 
+ gd = _run_golden(specs, golden) + diff = (tensors["out"] - gd["out"]).abs().max().item() + assert torch.allclose(tensors["out"], gd["out"], rtol=1e-5, atol=1e-5), ( + f"bidirectional: max abs diff = {diff:.3e}" + ) + + +def test_same_side_feedback_uses_scheduler_and_matches_golden(): + """AIV does push+pop on the same pipe direction; needs the scheduler.""" + program, specs, golden = _build_same_side_feedback_program() + tensors, code = _run_codegen(program, specs) + + assert "_run_scheduler([" in code, "scheduler call not emitted for same-side feedback" + assert "yield from _tpush_to_aic_g" in code + assert "yield from _tpop_from_aic_g" in code + + gd = _run_golden(specs, golden) + diff = (tensors["out"] - gd["out"]).abs().max().item() + assert torch.allclose(tensors["out"], gd["out"], rtol=1e-5, atol=1e-5), ( + f"same-side feedback: max abs diff = {diff:.3e}" + ) + + +def test_unidirectional_v2c_keeps_legacy_emission(): + """Single-direction V→C (no feedback) must still use the legacy path. + + Regression guard: the scheduler must not be triggered for trivial + unidirectional Groups so that the existing 60+ codegen tests are + unaffected. 
+ """ + shape = [4, 4] + a = _tile("a", shape) + b = _tile("b", shape) + result = _tile("result", shape) + matmul_call = ir.create_op_call("tile.matmul", [a, b], _span()) + push_call = ir.create_op_call("tile.tpush_to_aiv", [result], {"split": 0}, _span()) + aic_body = ir.SeqStmts( + [ir.AssignStmt(result, matmul_call, _span()), ir.EvalStmt(push_call, _span())], + _span(), + ) + aic_func = ir.Function("aic_only", [a, b], [], aic_body, _span(), type=ir.FunctionType.AIC) + + out = _tensor("out", shape) + pop_call = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + popped = _tile("popped", shape) + offsets = ir.MakeTuple([_i(0), _i(0)], _span()) + store_call = ir.create_op_call("tile.store", [popped, offsets, out], _span()) + aiv_body = ir.SeqStmts( + [ir.AssignStmt(popped, pop_call, _span()), ir.EvalStmt(store_call, _span())], + _span(), + ) + aiv_func = ir.Function("aiv_only", [out], [], aiv_body, _span(), type=ir.FunctionType.AIV) + + aic_call = ir.Call(ir.GlobalVar("aic_only"), [a, b], _span()) + aiv_call = ir.Call(ir.GlobalVar("aiv_only"), [out], _span()) + group_body = ir.SeqStmts([ir.EvalStmt(aic_call, _span()), ir.EvalStmt(aiv_call, _span())], _span()) + group_func = ir.Function("uni_grp", [a, b, out], [], group_body, _span(), type=ir.FunctionType.Group) + + prog = ir.Program([aic_func, aiv_func, group_func], "uni_test", _span()) + code = torch_codegen(prog, check_shapes=False) + + # The preamble defines _run_scheduler; check for the call-site form ([) only. + assert "_run_scheduler([" not in code, "unidirectional V2C must NOT trigger the scheduler" + assert "yield from _tpush_to_aiv_g" not in code, ( + "unidirectional V2C must use legacy non-generator helpers" + ) + # Legacy helpers still expected. 
+ assert "_tpush_to_aiv(_pipes['to_aiv']," in code + assert "_tpop_from_aic(_pipes['to_aiv']," in code + + +def test_scheduler_deadlock_raises(): + """Two AIC/AIV functions that both pop before either pushes must deadlock.""" + shape = [4, 4] + a = _tile("a", shape) + b = _tile("b", shape) + out = _tensor("out", shape) + + pop_back = ir.Call( + ir.get_op("tile.tpop_from_aic"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + pop_v = _tile("pop_v", shape) + push_pop_v = ir.create_op_call("tile.tpush_to_aic", [pop_v], {"split": 0}, _span()) + aiv_body = ir.SeqStmts( + [ir.AssignStmt(pop_v, pop_back, _span()), ir.EvalStmt(push_pop_v, _span())], + _span(), + ) + aiv_func = ir.Function("aiv_dl", [a, out], [], aiv_body, _span(), type=ir.FunctionType.AIV) + + pop_fwd = ir.Call( + ir.get_op("tile.tpop_from_aiv"), + [], + {"split": 0}, + ir.TileType(shape, DataType.FP32), + _span(), + ) + pop_c = _tile("pop_c", shape) + push_pop_c = ir.create_op_call("tile.tpush_to_aiv", [pop_c], {"split": 0}, _span()) + aic_body = ir.SeqStmts( + [ir.AssignStmt(pop_c, pop_fwd, _span()), ir.EvalStmt(push_pop_c, _span())], + _span(), + ) + aic_func = ir.Function("aic_dl", [b], [], aic_body, _span(), type=ir.FunctionType.AIC) + + aiv_call = ir.Call(ir.GlobalVar("aiv_dl"), [a, out], _span()) + aic_call = ir.Call(ir.GlobalVar("aic_dl"), [b], _span()) + group_body = ir.SeqStmts([ir.EvalStmt(aiv_call, _span()), ir.EvalStmt(aic_call, _span())], _span()) + group_func = ir.Function("dl_grp", [a, b, out], [], group_body, _span(), type=ir.FunctionType.Group) + + prog = ir.Program([aiv_func, aic_func, group_func], "deadlock_test", _span()) + code = torch_codegen(prog, check_shapes=False) + assert "_run_scheduler([" in code, "deadlock test should still go through scheduler" + + ns: dict = {} + exec(code, ns) # noqa: S102 + ns["_pipes"] = {"to_aiv": deque(), "to_aic": deque()} + ta = torch.randn(*shape) + tb = torch.randn(*shape) + tout = torch.zeros(*shape) + with 
pytest.raises(RuntimeError, match="deadlock"): + ns["run"](ta, tb, tout) + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/ut/debug/test_torch_codegen.py b/tests/ut/debug/test_torch_codegen.py index 41c8f1889..67441c4ed 100644 --- a/tests/ut/debug/test_torch_codegen.py +++ b/tests/ut/debug/test_torch_codegen.py @@ -409,13 +409,13 @@ def test_system_ops_are_noops(): def test_pipe_ops(): - """tile.tpush/tpop should emit pipe simulation.""" + """tile.tpush/tpop should emit pipe simulation with split support.""" tile = _tile_var("tile", [64, 64]) push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) body = ir.EvalStmt(push_call, _span()) func = _simple_function("f", [tile], body) code = torch_codegen(func) - assert "_pipes['to_aiv'].append" in code + assert "_tpush_to_aiv(_pipes['to_aiv'], tile, 0)" in code # --------------------------------------------------------------------------- @@ -817,5 +817,977 @@ def test_variable_name_uniquing(): assert cg._unique_name("a") == "a_2" +# --------------------------------------------------------------------------- +# Test: split-aware tpush/tpop +# --------------------------------------------------------------------------- + + +def test_tpush_no_split(): + """tpush with split=0 should push whole tensor (backward compatible).""" + tile = _tile_var("tile", [64, 64]) + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + body = ir.EvalStmt(push_call, _span()) + func = _simple_function("f", [tile], body) + code = torch_codegen(func) + assert "_tpush_to_aiv(_pipes['to_aiv'], tile, 0)" in code + + +def test_tpush_updown_split(): + """tpush with split=1 (UpDown) should use the AIC->AIV split helper.""" + tile = _tile_var("tile", [64, 64]) + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 1}) + body = ir.EvalStmt(push_call, _span()) + func = _simple_function("f", [tile], body) + code = torch_codegen(func) + assert "_tpush_to_aiv(_pipes['to_aiv'], tile, 1)" in code + + 
+def test_tpush_leftright_split(): + """tpush with split=2 (LeftRight) should use the V2C helper.""" + tile = _tile_var("tile", [64, 64]) + push_call = _op_call("tile.tpush_to_aic", [tile], {"split": 2}) + body = ir.EvalStmt(push_call, _span()) + func = _simple_function("f", [tile], body) + code = torch_codegen(func) + assert "_tpush_to_aic(_pipes['to_aic'], tile, 2)" in code + + +def test_tpop_no_split(): + """tpop_from_aic with split=0 should pop whole tensor (backward compatible).""" + pop_call = _op_call("tile.tpop_from_aic", [], {"split": 0}) + out = _tile_var("out", [64, 64]) + assign = ir.AssignStmt(out, pop_call, _span()) + func = _simple_function("f", [], assign) + code = torch_codegen(func) + # tpop_from_aic reads from the AIC→AIV pipe ('to_aiv') + assert "_tpop_from_aic(_pipes['to_aiv'], 0)" in code + + +def test_tpop_updown_split(): + """tpop_from_aiv with split=1 should still use the full-tile V2C helper.""" + pop_call = _op_call("tile.tpop_from_aiv", [], {"split": 1}) + out = _tile_var("out", [64, 64]) + assign = ir.AssignStmt(out, pop_call, _span()) + func = _simple_function("f", [], assign) + code = torch_codegen(func) + # tpop_from_aiv reads from the AIV→AIC pipe ('to_aic') + assert "_tpop_from_aiv(_pipes['to_aic'], 1)" in code + + +def test_tpop_leftright_split(): + """tpop_from_aic with split=2 should use the split-aware C2V helper.""" + pop_call = _op_call("tile.tpop_from_aic", [], {"split": 2}) + out = _tile_var("out", [64, 64]) + assign = ir.AssignStmt(out, pop_call, _span()) + func = _simple_function("f", [], assign) + code = torch_codegen(func) + # tpop_from_aic reads from the AIC→AIV pipe ('to_aiv') + assert "_tpop_from_aic(_pipes['to_aiv'], 2)" in code + + +# --------------------------------------------------------------------------- +# Test: numerical roundtrip for tpush/tpop with split +# --------------------------------------------------------------------------- + + +def test_numerical_roundtrip_tpush_tpop_no_split(): + """End-to-end: 
tpush then tpop with no split should preserve data.""" + tile = _tile_var("tile", [4, 4]) + out = _tile_var("out", [4, 4]) + + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + pop_call = _op_call("tile.tpop_from_aic", [], {"split": 0}) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push_call, _span()), + ir.AssignStmt(out, pop_call, _span()), + ir.ReturnStmt([out], _span()), + ], + _span(), + ) + func = _simple_function("f", [tile], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + result = ns["f"](t) + assert torch.allclose(result, t) + + +def test_numerical_roundtrip_tpush_tpop_updown_split(): + """Legacy non-Group roundtrip: tpush(split=UpDown) + tpop reassembles full tile.""" + tile = _tile_var("tile", [4, 4]) + out = _tile_var("out", [4, 4]) + + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 1}) + pop_call = _op_call("tile.tpop_from_aic", [], {"split": 1}) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push_call, _span()), + ir.AssignStmt(out, pop_call, _span()), + ir.ReturnStmt([out], _span()), + ], + _span(), + ) + func = _simple_function("f", [tile], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + result = ns["f"](t) + assert torch.allclose(result, t) + + +def test_numerical_roundtrip_tpush_tpop_leftright_split(): + """Legacy non-Group roundtrip: V2C tpush+tpop with split kwarg returns full tile. + + Outside the scheduler we do not model two AIV subblocks pushing + independent halves; the legacy single-subblock path pushes the full + tile and pops the full tile, so split is informational only. 
+ """ + tile = _tile_var("tile", [4, 4]) + out = _tile_var("out", [4, 4]) + + push_call = _op_call("tile.tpush_to_aic", [tile], {"split": 2}) + pop_call = _op_call("tile.tpop_from_aiv", [], {"split": 2}) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push_call, _span()), + ir.AssignStmt(out, pop_call, _span()), + ir.ReturnStmt([out], _span()), + ], + _span(), + ) + func = _simple_function("f", [tile], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + result = ns["f"](t) + assert torch.allclose(result, t) + + +# --------------------------------------------------------------------------- +# Test: cross-core program simulation +# --------------------------------------------------------------------------- + + +def test_program_with_aic_aiv_functions(): + """Program with AIC+AIV functions should emit both function definitions.""" + span = _span() + + # AIC function: takes tile, pushes to AIV + tile = _tile_var("tile", [4, 4]) + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + aic_body = ir.EvalStmt(push_call, span) + aic_func = ir.Function("aic_compute", [tile], [], aic_body, span, type=ir.FunctionType.AIC) + + # AIV function: pops from AIV, adds 1, returns + pop_op = ir.get_op("tile.tpop_from_aic") + pop_call = ir.Call(pop_op, [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span) + out = _tile_var("out", [4, 4]) + add_call = _op_call("tile.adds", [pop_call, _float(1.0)]) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(out, add_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + aiv_func = ir.Function( + "aiv_compute", [], [ir.TileType([4, 4], DataType.FP32)], aiv_body, span, type=ir.FunctionType.AIV + ) + + prog = _program([aic_func, aiv_func]) + code = torch_codegen(prog) + + assert "def aic_compute" in code + assert "def aiv_compute" in code + assert "_tpush_to_aiv" in code + assert "_tpop_from_aic" in code + + +def 
test_program_with_group_function(): + """Program with Group function should emit coordinated AIC+AIV calls.""" + span = _span() + + # AIC function + tile = _tile_var("tile", [4, 4]) + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + aic_body = ir.EvalStmt(push_call, span) + aic_func = ir.Function("aic_kernel", [tile], [], aic_body, span, type=ir.FunctionType.AIC) + + # AIV function + pop_op = ir.get_op("tile.tpop_from_aic") + pop_call = ir.Call(pop_op, [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span) + out = _tile_var("out", [4, 4]) + add_call = _op_call("tile.adds", [pop_call, _float(1.0)]) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(out, add_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + aiv_func = ir.Function( + "aiv_kernel", [], [ir.TileType([4, 4], DataType.FP32)], aiv_body, span, type=ir.FunctionType.AIV + ) + + # Group function that calls both + aic_gv = ir.GlobalVar("aic_kernel") + aiv_gv = ir.GlobalVar("aiv_kernel") + group_body = ir.SeqStmts( + [ + ir.EvalStmt(ir.Call(aic_gv, [tile], span), span), + ir.EvalStmt(ir.Call(aiv_gv, [], span), span), + ], + span, + ) + group_func = ir.Function( + "my_group", [tile], [ir.TileType([4, 4], DataType.FP32)], group_body, span, type=ir.FunctionType.Group + ) + + prog = _program([aic_func, aiv_func, group_func]) + code = torch_codegen(prog) + + assert "def aic_kernel" in code + assert "def aiv_kernel" in code + assert "def my_group" in code + assert "# Group:" in code + + +def test_program_with_multiple_group_functions(): + """Program with multiple Group functions should emit all with isolated variable scopes.""" + span = _span() + + # === Group 1: matmul pipeline === + # AIC function for group 1 + tile1 = _tile_var("tile", [4, 4]) # same name as tile2, but different scope + push_call1 = _op_call("tile.tpush_to_aiv", [tile1], {"split": 0}) + aic_body1 = ir.EvalStmt(push_call1, span) + aic_func1 = ir.Function("aic_matmul", [tile1], [], aic_body1, span, 
type=ir.FunctionType.AIC) + + # AIV function for group 1 + pop_call1 = ir.Call( + ir.get_op("tile.tpop_from_aic"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + out1 = _tile_var("out", [4, 4]) + add_call1 = _op_call("tile.adds", [pop_call1, _float(1.0)]) + aiv_body1 = ir.SeqStmts( + [ir.AssignStmt(out1, add_call1, span), ir.ReturnStmt([out1], span)], + span, + ) + aiv_func1 = ir.Function( + "aiv_gelu", [], [ir.TileType([4, 4], DataType.FP32)], aiv_body1, span, type=ir.FunctionType.AIV + ) + + # Group 1 function + group_body1 = ir.SeqStmts( + [ + ir.EvalStmt(ir.Call(ir.GlobalVar("aic_matmul"), [tile1], span), span), + ir.AssignStmt(out1, ir.Call(ir.GlobalVar("aiv_gelu"), [], span), span), + ir.ReturnStmt([out1], span), + ], + span, + ) + group_func1 = ir.Function( + "matmul_pipeline", + [tile1], + [ir.TileType([4, 4], DataType.FP32)], + group_body1, + span, + type=ir.FunctionType.Group, + ) + + # === Group 2: activation pipeline (AIV pushes to AIC, AIC consumes) === + # AIV function for group 2 + tile2 = _tile_var("tile", [4, 4]) # same name as tile1 + push_call2 = _op_call("tile.tpush_to_aic", [tile2], {"split": 0}) + aiv_body2 = ir.EvalStmt(push_call2, span) + aiv_func2 = ir.Function("aiv_activation", [tile2], [], aiv_body2, span, type=ir.FunctionType.AIV) + + # AIC function for group 2 + pop_call2 = ir.Call( + ir.get_op("tile.tpop_from_aiv"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + out2 = _tile_var("out", [4, 4]) + mul_call2 = _op_call("tile.muls", [pop_call2, _float(2.0)]) + aic_body2 = ir.SeqStmts( + [ir.AssignStmt(out2, mul_call2, span), ir.ReturnStmt([out2], span)], + span, + ) + aic_func2 = ir.Function( + "aic_norm", [], [ir.TileType([4, 4], DataType.FP32)], aic_body2, span, type=ir.FunctionType.AIC + ) + + # Group 2 function + group_body2 = ir.SeqStmts( + [ + ir.EvalStmt(ir.Call(ir.GlobalVar("aiv_activation"), [tile2], span), span), + ir.AssignStmt(out2, ir.Call(ir.GlobalVar("aic_norm"), [], span), span), + 
ir.ReturnStmt([out2], span), + ], + span, + ) + group_func2 = ir.Function( + "activation_pipeline", + [tile2], + [ir.TileType([4, 4], DataType.FP32)], + group_body2, + span, + type=ir.FunctionType.Group, + ) + + # Build program with both groups + prog = _program([aic_func1, aiv_func1, aiv_func2, aic_func2, group_func1, group_func2]) + code = torch_codegen(prog) + + # Verify all functions are generated + assert "def aic_matmul" in code + assert "def aiv_gelu" in code + assert "def aiv_activation" in code + assert "def aic_norm" in code + assert "def matmul_pipeline" in code + assert "def activation_pipeline" in code + + # Verify both Group comments are present + assert code.count("# Group:") == 2 + + # Execute and verify numerical correctness for both groups + ns: dict = {} + exec(code, ns) # noqa: S102 + + t = torch.randn(4, 4) + + # Group 1: adds 1.0 + result1 = ns["matmul_pipeline"](t) + assert torch.allclose(result1, t + 1.0, atol=1e-6) + + # Group 2: multiplies by 2.0 + result2 = ns["activation_pipeline"](t) + assert torch.allclose(result2, t * 2.0, atol=1e-6) + + +def test_program_emits_entry_point(): + """Program codegen should emit a run() entry point for Opaque functions.""" + span = _span() + + a = _tensor_var("a", [4, 4]) + ret = ir.ReturnStmt([a], span) + func = ir.Function( + "main", [a], [ir.TensorType([4, 4], DataType.FP32)], ret, span, type=ir.FunctionType.Opaque + ) + prog = _program([func]) + code = torch_codegen(prog) + + assert "def run(a):" in code + assert "return main(a)" in code + + +def test_program_entry_point_sanitizes_parameter_names(): + """Program entry point should sanitize/unique parameter names like function codegen.""" + span = _span() + + class_kw = _tensor_var("class", [4, 4]) + duplicate = _tensor_var("class", [4, 4]) + ret = ir.ReturnStmt([class_kw], span) + func = ir.Function( + "main", + [class_kw, duplicate], + [ir.TensorType([4, 4], DataType.FP32)], + ret, + span, + type=ir.FunctionType.Opaque, + ) + prog = _program([func]) 
+ code = torch_codegen(prog) + + assert "def main(class_v, class_v_1):" in code + assert "def run(class_v, class_v_1):" in code + compile(code, "", "exec") + + +def test_program_entry_point_prefers_orchestration(): + """Program entry point should prefer Orchestration function over Opaque.""" + span = _span() + + a = _tensor_var("a", [4, 4]) + + opaque_ret = ir.ReturnStmt([a], span) + opaque_func = ir.Function( + "helper", [a], [ir.TensorType([4, 4], DataType.FP32)], opaque_ret, span, type=ir.FunctionType.Opaque + ) + + orch_ret = ir.ReturnStmt([a], span) + orch_func = ir.Function( + "orch_main", + [a], + [ir.TensorType([4, 4], DataType.FP32)], + orch_ret, + span, + type=ir.FunctionType.Orchestration, + ) + + prog = _program([opaque_func, orch_func]) + code = torch_codegen(prog) + + assert "return orch_main(a)" in code + + +# --------------------------------------------------------------------------- +# Test: qwen3-style cross-core precision verification +# --------------------------------------------------------------------------- + + +def test_numerical_cross_core_matmul_residual(): + """End-to-end cross-core: AIC does matmul, pushes to AIV, AIV adds residual. + + Simulates the qwen3 pattern: output = matmul(a, b), then result = output + residual. 
+ """ + span = _span() + + # AIC function: matmul then tpush + a = _tile_var("a", [4, 4]) + b = _tile_var("b", [4, 4]) + matmul_call = _op_call("tile.matmul", [a, b]) + result = _tile_var("result", [4, 4]) + push_call = _op_call("tile.tpush_to_aiv", [result], {"split": 0}) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(result, matmul_call, span), + ir.EvalStmt(push_call, span), + ], + span, + ) + aic_func = ir.Function("aic_matmul", [a, b], [], aic_body, span, type=ir.FunctionType.AIC) + + # AIV function: tpop then add residual + residual = _tile_var("residual", [4, 4]) + pop_op = ir.get_op("tile.tpop_from_aic") + pop_call = ir.Call(pop_op, [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span) + out = _tile_var("out", [4, 4]) + add_call = _op_call("tile.add", [pop_call, residual]) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(out, add_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + aiv_func = ir.Function( + "aiv_add_residual", + [residual], + [ir.TileType([4, 4], DataType.FP32)], + aiv_body, + span, + type=ir.FunctionType.AIV, + ) + + # Group function: calls AIC then AIV + aic_call = ir.Call(ir.GlobalVar("aic_matmul"), [a, b], span) + aiv_call = ir.Call(ir.GlobalVar("aiv_add_residual"), [residual], span) + group_body = ir.SeqStmts( + [ + ir.EvalStmt(aic_call, span), + ir.AssignStmt(out, aiv_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + group_func = ir.Function( + "matmul_add_group", + [a, b, residual], + [ir.TileType([4, 4], DataType.FP32)], + group_body, + span, + type=ir.FunctionType.Group, + ) + + prog = _program([aic_func, aiv_func, group_func]) + code = torch_codegen(prog) + + # Execute and verify + ns: dict = {} + exec(code, ns) # noqa: S102 + + t_a = torch.randn(4, 4) + t_b = torch.randn(4, 4) + t_residual = torch.randn(4, 4) + + result_val = ns["matmul_add_group"](t_a, t_b, t_residual) + expected = torch.matmul(t_a, t_b).float() + t_residual + assert torch.allclose(result_val, expected, atol=1e-5) + + +def 
_build_cross_core_matmul_residual_program(split: int) -> ir.Program:
+    """Build a cross-core Program: AIC matmul → tpush → AIV tpop + add residual.
+
+    For split == 0: AIV pops the full tile, adds residual, returns it; the
+    Group binds the AIV return to ``out`` and returns it. For split > 0:
+    the kernel is restructured as a bidirectional roundtrip — AIC pushes the
+    matmul output with the given split, each AIV subblock pops its half,
+    adds its half of residual (sliced via ``tile.get_subblock_idx``), and
+    pushes the half-result back to AIC; AIC pops with the same split and
+    reassembles the full tile into the ``out_tensor`` output param. The
+    Group has no return value in this mode (``out_tensor`` is an output param).
+    """
+    span = _span()
+
+    if split == 0:
+        a = _tile_var("a", [4, 4])
+        b = _tile_var("b", [4, 4])
+        matmul_call = _op_call("tile.matmul", [a, b])
+        result = _tile_var("result", [4, 4])
+        push_call = _op_call("tile.tpush_to_aiv", [result], {"split": 0})
+        aic_body = ir.SeqStmts(
+            [ir.AssignStmt(result, matmul_call, span), ir.EvalStmt(push_call, span)],
+            span,
+        )
+        aic_func = ir.Function("aic_matmul", [a, b], [], aic_body, span, type=ir.FunctionType.AIC)
+
+        residual = _tile_var("residual", [4, 4])
+        pop_op = ir.get_op("tile.tpop_from_aic")
+        pop_call = ir.Call(pop_op, [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span)
+        out = _tile_var("out", [4, 4])
+        add_call = _op_call("tile.add", [pop_call, residual])
+        aiv_body = ir.SeqStmts(
+            [ir.AssignStmt(out, add_call, span), ir.ReturnStmt([out], span)],
+            span,
+        )
+        aiv_func = ir.Function(
+            "aiv_add_residual",
+            [residual],
+            [ir.TileType([4, 4], DataType.FP32)],
+            aiv_body,
+            span,
+            type=ir.FunctionType.AIV,
+        )
+
+        aic_call = ir.Call(ir.GlobalVar("aic_matmul"), [a, b], span)
+        aiv_call = ir.Call(ir.GlobalVar("aiv_add_residual"), [residual], span)
+        group_body = ir.SeqStmts(
+            [
+                ir.EvalStmt(aic_call, span),
+                ir.AssignStmt(out, aiv_call, span),
+                ir.ReturnStmt([out], span),
+            ],
+            span,
+        )
+        
group_func = ir.Function( + "matmul_add_group", + [a, b, residual], + [ir.TileType([4, 4], DataType.FP32)], + group_body, + span, + type=ir.FunctionType.Group, + ) + return _program([aic_func, aiv_func, group_func]) + + # split > 0: bidirectional roundtrip with reassembly at AIC. Output is a + # full-shape tensor written via tile.store at offset (0, 0). + full_shape = [4, 4] + + # AIC: result = matmul(a, b); tpush_to_aiv(result, split); reassembled = + # tpop_from_aiv(split); store(reassembled, (0,0), out_tensor) + a = _tile_var("a", full_shape) + b = _tile_var("b", full_shape) + out_tensor = _tensor_var("out_tensor", full_shape) + matmul_call = _op_call("tile.matmul", [a, b]) + result = _tile_var("result", full_shape) + push_call = _op_call("tile.tpush_to_aiv", [result], {"split": split}) + pop_back_op = ir.get_op("tile.tpop_from_aiv") + pop_back_call = ir.Call( + pop_back_op, + [], + {"split": split}, + ir.TileType(full_shape, DataType.FP32), + span, + ) + reassembled = _tile_var("reassembled", full_shape) + offsets_zero = ir.MakeTuple([_int(0), _int(0)], span) + store_call = _op_call("tile.store", [reassembled, offsets_zero, out_tensor]) + aic_body = ir.SeqStmts( + [ + ir.AssignStmt(result, matmul_call, span), + ir.EvalStmt(push_call, span), + ir.AssignStmt(reassembled, pop_back_call, span), + ir.EvalStmt(store_call, span), + ], + span, + ) + aic_func = ir.Function( + "aic_matmul", + [a, b, out_tensor], + [], + aic_body, + span, + type=ir.FunctionType.AIC, + ) + + # AIV body (runs once per subblock): pop half from AIC, add the matching + # half of residual sliced by subblock_idx, push back to AIC. 
+ if split == 1: # UpDown + half_shape = [full_shape[0] // 2, full_shape[1]] + # offset along dim 0 = subblock_idx * (H/2) + offset_dim0 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], span), + _int(half_shape[0]), + DataType.INT64, + span, + ) + slice_offsets = ir.MakeTuple([offset_dim0, _int(0)], span) + else: # LeftRight + half_shape = [full_shape[0], full_shape[1] // 2] + offset_dim1 = ir.Mul( + ir.create_op_call("tile.get_subblock_idx", [], span), + _int(half_shape[1]), + DataType.INT64, + span, + ) + slice_offsets = ir.MakeTuple([_int(0), offset_dim1], span) + + residual_full = _tile_var("residual", full_shape) + pop_op = ir.get_op("tile.tpop_from_aic") + pop_call = ir.Call( + pop_op, + [], + {"split": split}, + ir.TileType(half_shape, DataType.FP32), + span, + ) + popped_half = _tile_var("popped_half", half_shape) + half_shape_tuple = ir.MakeTuple([_int(d) for d in half_shape], span) + residual_half_call = _op_call("tile.slice", [residual_full, half_shape_tuple, slice_offsets]) + residual_half = _tile_var("residual_half", half_shape) + add_call = _op_call("tile.add", [popped_half, residual_half]) + add_half = _tile_var("add_half", half_shape) + push_back_call = _op_call("tile.tpush_to_aic", [add_half], {"split": split}) + aiv_body = ir.SeqStmts( + [ + ir.AssignStmt(popped_half, pop_call, span), + ir.AssignStmt(residual_half, residual_half_call, span), + ir.AssignStmt(add_half, add_call, span), + ir.EvalStmt(push_back_call, span), + ], + span, + ) + aiv_func = ir.Function( + "aiv_add_residual", + [residual_full], + [], + aiv_body, + span, + type=ir.FunctionType.AIV, + ) + + aic_call = ir.Call(ir.GlobalVar("aic_matmul"), [a, b, out_tensor], span) + aiv_call = ir.Call(ir.GlobalVar("aiv_add_residual"), [residual_full], span) + group_body = ir.SeqStmts( + [ + ir.EvalStmt(aic_call, span), + ir.EvalStmt(aiv_call, span), + ], + span, + ) + group_func = ir.Function( + "matmul_add_group", + [a, b, residual_full, out_tensor], + [], + group_body, + span, + 
type=ir.FunctionType.Group, + ) + return _program([aic_func, aiv_func, group_func]) + + +def test_numerical_cross_core_matmul_residual_updown_split(): + """Cross-core matmul+residual with UpDown split=1, full-tile correctness. + + The Group runs as a bidirectional scheduler: AIC pushes split=1, both AIV + subblocks pop their halves and add the matching half of residual, push + back, AIC reassembles into the full output. + """ + prog = _build_cross_core_matmul_residual_program(split=1) + code = torch_codegen(prog) + + ns: dict = {} + exec(code, ns) # noqa: S102 + + t_a = torch.randn(4, 4) + t_b = torch.randn(4, 4) + t_residual = torch.randn(4, 4) + t_out = torch.zeros(4, 4) + + ns["matmul_add_group"](t_a, t_b, t_residual, t_out) + expected = torch.matmul(t_a, t_b).float() + t_residual + assert torch.allclose(t_out, expected, atol=1e-5) + + +def test_numerical_cross_core_matmul_residual_leftright_split(): + """Cross-core matmul+residual with LeftRight split=2, full-tile correctness.""" + prog = _build_cross_core_matmul_residual_program(split=2) + code = torch_codegen(prog) + + ns: dict = {} + exec(code, ns) # noqa: S102 + + t_a = torch.randn(4, 4) + t_b = torch.randn(4, 4) + t_residual = torch.randn(4, 4) + t_out = torch.zeros(4, 4) + + ns["matmul_add_group"](t_a, t_b, t_residual, t_out) + expected = torch.matmul(t_a, t_b).float() + t_residual + assert torch.allclose(t_out, expected, atol=1e-5) + + +def test_nested_tpop_in_expression(): + """tpop used directly inside tile.add (nested Call) should work via dispatch workaround.""" + span = _span() + + tile = _tile_var("tile", [4, 4]) + residual = _tile_var("residual", [4, 4]) + out = _tile_var("out", [4, 4]) + + # Build: out = tile.add(tpop_from_aic(), residual) — tpop nested inside add + pop_op = ir.get_op("tile.tpop_from_aic") + pop_call = ir.Call(pop_op, [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span) + add_call = _op_call("tile.add", [pop_call, residual]) + + # Push first, then pop inside add + 
push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + body = ir.SeqStmts( + [ + ir.EvalStmt(push_call, span), + ir.AssignStmt(out, add_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + func = _simple_function("f", [tile, residual], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + # Verify nested tpop is generated correctly + assert "_tpop_from_aic(" in code + + # Execute and verify numerical correctness + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + r = torch.randn(4, 4) + result_val = ns["f"](t, r) + assert torch.allclose(result_val, t + r, atol=1e-6) + + +def test_bidirectional_pipe_communication(): + """AIC → tpush_to_aiv → AIV tpop_from_aic → tpush_to_aic → AIC tpop_from_aiv: bidirectional pipes.""" + span = _span() + + tile = _tile_var("tile", [4, 4]) + out = _tile_var("out", [4, 4]) + + # Step 1: push to to_aiv pipe (AIC→AIV direction) + push_to_aiv = _op_call("tile.tpush_to_aiv", [tile], {"split": 0}) + # Step 2: pop from to_aiv pipe (AIV reads data from AIC) + pop_from_aic = ir.Call( + ir.get_op("tile.tpop_from_aic"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + mid = _tile_var("mid", [4, 4]) + # Step 3: push to to_aic pipe (AIV→AIC direction) + push_to_aic = _op_call("tile.tpush_to_aic", [mid], {"split": 0}) + # Step 4: pop from to_aic pipe (AIC reads data from AIV) + pop_from_aiv = ir.Call( + ir.get_op("tile.tpop_from_aiv"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push_to_aiv, span), + ir.AssignStmt(mid, pop_from_aic, span), + ir.EvalStmt(push_to_aic, span), + ir.AssignStmt(out, pop_from_aiv, span), + ir.ReturnStmt([out], span), + ], + span, + ) + func = _simple_function("f", [tile], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + assert "_pipes['to_aiv']" in code + assert "_pipes['to_aic']" in code + + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 
4) + result_val = ns["f"](t) + assert torch.allclose(result_val, t, atol=1e-6) + + +def test_pipe_empty_after_balanced_pushpop(): + """After pushing N times and popping N times, the pipe should be empty.""" + span = _span() + + t1 = _tile_var("t1", [4, 4]) + t2 = _tile_var("t2", [4, 4]) + o1 = _tile_var("o1", [4, 4]) + o2 = _tile_var("o2", [4, 4]) + + push1 = _op_call("tile.tpush_to_aiv", [t1], {"split": 0}) + push2 = _op_call("tile.tpush_to_aiv", [t2], {"split": 0}) + pop1 = ir.Call( + ir.get_op("tile.tpop_from_aic"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + pop2 = ir.Call( + ir.get_op("tile.tpop_from_aic"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span + ) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push1, span), + ir.EvalStmt(push2, span), + ir.AssignStmt(o1, pop1, span), + ir.AssignStmt(o2, pop2, span), + ir.ReturnStmt([o1, o2], span), + ], + span, + ) + func = _simple_function( + "f", [t1, t2], body, [ir.TileType([4, 4], DataType.FP32), ir.TileType([4, 4], DataType.FP32)] + ) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + a = torch.randn(4, 4) + b = torch.randn(4, 4) + r1, r2 = ns["f"](a, b) + + # FIFO order: first pushed = first popped + assert torch.allclose(r1, a, atol=1e-6) + assert torch.allclose(r2, b, atol=1e-6) + + # Pipe should be empty after balanced push/pop + assert len(ns["_pipes"]["to_aiv"]) == 0 + + +def test_tpop_from_aiv_split_keeps_full_tile_with_shape_checks(): + """AIC-side tpop_from_aiv must keep full-tile shape under split mode.""" + span = _span() + + tile = _tile_var("tile", [4, 4]) + out = _tile_var("out", [4, 4]) + + push_call = _op_call("tile.tpush_to_aic", [tile], {"split": 1}) + pop_call = ir.Call( + ir.get_op("tile.tpop_from_aiv"), [], {"split": 1}, ir.TileType([4, 4], DataType.FP32), span + ) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push_call, span), + ir.AssignStmt(out, pop_call, span), + ir.ReturnStmt([out], span), + ], + span, + ) + func = 
_simple_function("f", [tile], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func, check_shapes=True) + + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + result = ns["f"](t) + assert torch.allclose(result, t, atol=1e-6) + + +# --------------------------------------------------------------------------- +# Test: edge cases and error handling +# --------------------------------------------------------------------------- + + +def test_invalid_split_mode_fallback(): + """Invalid split_mode (not 0, 1, 2) should fallback to no-split behavior.""" + span = _span() + + # Test split=3 (invalid) + tile = _tile_var("tile", [4, 4]) + push_call = _op_call("tile.tpush_to_aiv", [tile], {"split": 3}) + body = ir.EvalStmt(push_call, span) + func = _simple_function("f", [tile], body) + code = torch_codegen(func) + + # Should fallback to no-split: push 1 chunk + ns: dict = {} + exec(code, ns) # noqa: S102 + t = torch.randn(4, 4) + ns["f"](t) + # Pipe should have 1 element (not 2 like split=1 or split=2) + assert len(ns["_pipes"]["to_aiv"]) == 1 + + +def test_pop_from_empty_pipe_raises(): + """Popping from an empty pipe should raise IndexError.""" + span = _span() + + # Function that only pops without pushing + pop_call = _op_call("tile.tpop_from_aic", [], {"split": 0}) + out = _tile_var("out", [4, 4]) + body = ir.AssignStmt(out, pop_call, span) + func = _simple_function("f", [], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + + # Should raise IndexError when popping from empty pipe + with pytest.raises(IndexError, match="pop from an empty"): + ns["f"]() + + +def test_unbalanced_push_pop_pipe_state(): + """Push 2 times, pop 1 time: pipe should have 1 element remaining.""" + span = _span() + + t1 = _tile_var("t1", [4, 4]) + t2 = _tile_var("t2", [4, 4]) + out = _tile_var("out", [4, 4]) + + push1 = _op_call("tile.tpush_to_aiv", [t1], {"split": 0}) + push2 = 
_op_call("tile.tpush_to_aiv", [t2], {"split": 0}) + pop = ir.Call(ir.get_op("tile.tpop_from_aic"), [], {"split": 0}, ir.TileType([4, 4], DataType.FP32), span) + + body = ir.SeqStmts( + [ + ir.EvalStmt(push1, span), + ir.EvalStmt(push2, span), + ir.AssignStmt(out, pop, span), + ir.ReturnStmt([out], span), + ], + span, + ) + func = _simple_function("f", [t1, t2], body, [ir.TileType([4, 4], DataType.FP32)]) + code = torch_codegen(func) + + ns: dict = {} + exec(code, ns) # noqa: S102 + a = torch.randn(4, 4) + b = torch.randn(4, 4) + result = ns["f"](a, b) + + # First pushed should be popped (FIFO) + assert torch.allclose(result, a, atol=1e-6) + + # Pipe should have 1 element remaining (second push) + assert len(ns["_pipes"]["to_aiv"]) == 1 + # The remaining element should be the second pushed + assert torch.allclose(ns["_pipes"]["to_aiv"][0], b, atol=1e-6) + + if __name__ == "__main__": pytest.main([__file__, "-v"])